Source code for mypackage._mymodel

import logging

from anndata import AnnData
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin

from ._mymodule import MyModule

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class MyModel(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
    """
    Template scvi-tools model.

    Copy this skeleton and adapt it when implementing a new model.

    Parameters
    ----------
    adata
        AnnData object already registered through
        :func:`~scvi.data.setup_anndata`.
    n_hidden
        Number of nodes in each hidden layer.
    n_latent
        Number of dimensions of the latent space.
    n_layers
        Number of hidden layers in the encoder and decoder neural networks.
    **model_kwargs
        Additional keyword arguments passed on to
        :class:`~mypackage.MyModule`.

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> scvi.data.setup_anndata(adata, batch_key="batch")
    >>> vae = mypackage.MyModel(adata)
    >>> vae.train()
    >>> adata.obsm["X_mymodel"] = vae.get_latent_representation()
    """

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        **model_kwargs,
    ):
        super(MyModel, self).__init__(adata)

        # self.summary_stats carries the anndata dimensions and other tensor
        # info; its "n_vars" entry is used as the module's input size.
        self.module = MyModule(
            n_input=self.summary_stats["n_vars"],
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            **model_kwargs,
        )

        self._model_summary_string = (
            "Overwrite this attribute to get an informative representation for your model"
        )

        # Required: records the params that will be used for saving/loading.
        # No extra local names are bound in this method, so locals() here
        # holds only `self` and the constructor arguments.
        self.init_params_ = self._get_init_params(locals())

        logger.info("The model has been initialized")