"""Source code for ``mypackage._mymodel``: the skeleton scvi-tools model ``MyModel``."""

import logging
from typing import List, Optional

from anndata import AnnData
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import (
    CategoricalJointObsField,
    CategoricalObsField,
    LayerField,
    NumericalJointObsField,
)
from scvi.model._utils import _init_library_size
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin
from scvi.utils import setup_anndata_dsp

from ._mymodule import MyModule

logger = logging.getLogger(__name__)


class MyModel(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
    """
    Skeleton for an scvi-tools model.

    Please use this skeleton to create new models.

    Parameters
    ----------
    adata
        AnnData object that has been registered via :meth:`~mypackage.MyModel.setup_anndata`.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    **model_kwargs
        Keyword args for :class:`~mypackage.MyModule`

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> mypackage.MyModel.setup_anndata(adata, batch_key="batch")
    >>> vae = mypackage.MyModel(adata)
    >>> vae.train()
    >>> adata.obsm["X_mymodel"] = vae.get_latent_representation()
    """

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        **model_kwargs,
    ):
        super().__init__(adata)

        # self.summary_stats provides information about anndata dimensions and
        # other tensor info; "n_batch" and "n_vars" are read from it below.
        # Log-library-size statistics, presumably computed per batch since
        # n_batch is passed in -- see scvi.model._utils._init_library_size.
        library_log_means, library_log_vars = _init_library_size(
            self.adata_manager, self.summary_stats["n_batch"]
        )

        self.module = MyModule(
            n_input=self.summary_stats["n_vars"],
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            library_log_means=library_log_means,
            library_log_vars=library_log_vars,
            **model_kwargs,
        )
        # Informative summary of the chosen architecture instead of the
        # template placeholder text.
        self._model_summary_string = (
            "MyModel Model with the following params: \n"
            f"n_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}"
        )
        # necessary line to get params that will be used for saving/loading;
        # it reads the constructor arguments from locals(), so it must run
        # after all of them are bound.
        self.init_params_ = self._get_init_params(locals())

        logger.info("The model has been initialized")

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: Optional[str] = None,
        labels_key: Optional[str] = None,
        layer: Optional[str] = None,
        categorical_covariate_keys: Optional[List[str]] = None,
        continuous_covariate_keys: Optional[List[str]] = None,
        **kwargs,
    ) -> Optional[AnnData]:
        """
        %(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_batch_key)s
        %(param_labels_key)s
        %(param_layer)s
        %(param_cat_cov_keys)s
        %(param_cont_cov_keys)s

        Returns
        -------
        %(returns)s
        """
        # Keep this as the first statement: every local variable is forwarded
        # via **locals(), so any local created earlier would be captured too.
        setup_method_args = cls._get_setup_method_args(**locals())

        anndata_fields = [
            # Expression matrix taken from ``layer``, flagged as count data.
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
            CategoricalJointObsField(
                REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
            ),
            NumericalJointObsField(
                REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
            ),
        ]
        adata_manager = AnnDataManager(
            fields=anndata_fields, setup_method_args=setup_method_args
        )
        # Register the fields on this AnnData, then record the manager on the
        # model class so instances created from ``adata`` can find it.
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)