import logging
from typing import List, Optional
from anndata import AnnData
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import (
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
)
from scvi.model._utils import _init_library_size
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin
from scvi.utils import setup_anndata_dsp
from ._mymodule import MyModule
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class MyModel(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
    """
    Skeleton for an scvi-tools model.

    Please use this skeleton to create new models.

    Parameters
    ----------
    adata
        AnnData object that has been registered via :meth:`~mypackage.MyModel.setup_anndata`.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    **model_kwargs
        Keyword args for :class:`~mypackage.MyModule`

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> mypackage.MyModel.setup_anndata(adata, batch_key="batch")
    >>> vae = mypackage.MyModel(adata)
    >>> vae.train()
    >>> adata.obsm["X_mymodel"] = vae.get_latent_representation()
    """

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        **model_kwargs,
    ):
        super().__init__(adata)

        # self.summary_stats provides information about anndata dimensions and
        # other tensor info (e.g. "n_batch", "n_vars" used below).
        # NOTE(review): _init_library_size presumably returns per-batch log
        # library-size statistics that MyModule uses as priors -- confirm.
        library_log_means, library_log_vars = _init_library_size(
            self.adata_manager, self.summary_stats["n_batch"]
        )

        self.module = MyModule(
            n_input=self.summary_stats["n_vars"],
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            library_log_means=library_log_means,
            library_log_vars=library_log_vars,
            **model_kwargs,
        )

        # Human-readable summary of this model's hyperparameters; extend it
        # with any additional settings a concrete model introduces.
        self._model_summary_string = (
            "MyModel with the following params: \n"
            f"n_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}"
        )

        # necessary line to get params that will be used for saving/loading
        self.init_params_ = self._get_init_params(locals())

        logger.info("The model has been initialized")

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: Optional[str] = None,
        labels_key: Optional[str] = None,
        layer: Optional[str] = None,
        categorical_covariate_keys: Optional[List[str]] = None,
        continuous_covariate_keys: Optional[List[str]] = None,
        **kwargs,
    ) -> None:
        """
        %(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_batch_key)s
        %(param_labels_key)s
        %(param_layer)s
        %(param_cat_cov_keys)s
        %(param_cont_cov_keys)s
        **kwargs
            Extra keyword arguments forwarded to
            :meth:`~scvi.data.AnnDataManager.register_fields`.

        Returns
        -------
        %(returns)s
        """
        # Must remain the first statement: at this point ``locals()`` holds
        # only the call arguments, which is exactly what should be captured.
        setup_method_args = cls._get_setup_method_args(**locals())

        anndata_fields = [
            # X (or the requested layer), flagged as count data.
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
            CategoricalJointObsField(
                REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
            ),
            NumericalJointObsField(
                REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
            ),
        ]
        adata_manager = AnnDataManager(
            fields=anndata_fields, setup_method_args=setup_method_args
        )
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)