# Source code for mypackage._mypyromodel

import logging
from typing import List, Optional, Sequence, Union

import numpy as np
import torch
from anndata import AnnData
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import (
    CategoricalJointObsField,
    CategoricalObsField,
    LayerField,
    NumericalJointObsField,
)
from scvi.dataloaders import DataSplitter
from scvi.model.base import BaseModelClass
from scvi.train import PyroTrainingPlan, TrainRunner
from scvi.utils import setup_anndata_dsp

from ._mypyromodule import MyPyroModule

logger = logging.getLogger(__name__)


class MyPyroModel(BaseModelClass):
    """
    Skeleton for a pyro version of a scvi-tools model.

    Please use this skeleton to create new models.

    Parameters
    ----------
    adata
        AnnData object that has been registered via :meth:`~mypackage.MyPyroModel.setup_anndata`.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    **model_kwargs
        Keyword args for :class:`~mypackage.MyPyroModule`

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> mypackage.MyPyroModel.setup_anndata(adata, batch_key="batch")
    >>> vae = mypackage.MyPyroModel(adata)
    >>> vae.train()
    >>> adata.obsm["X_mymodel"] = vae.get_latent()
    """

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        **model_kwargs,
    ):
        super(MyPyroModel, self).__init__(adata)

        # self.summary_stats provides information about anndata dimensions and other tensor info
        self.module = MyPyroModule(
            n_input=self.summary_stats["n_vars"],
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            **model_kwargs,
        )
        self._model_summary_string = "Overwrite this attribute to get an informative representation for your model"
        # necessary line to get params that will be used for saving/loading;
        # locals() is captured as-is, so do not define extra local variables above this line
        self.init_params_ = self._get_init_params(locals())

        logger.info("The model has been initialized")

    def get_latent(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
    ):
        r"""
        Return the latent representation for each cell.

        This is denoted as :math:`z_n` in our manuscripts.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        latent_representation : np.ndarray
            Low-dimensional representation for each cell
        """
        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size
        )

        latent = []
        # Inference only: skip autograd bookkeeping for the forward passes.
        with torch.no_grad():
            for tensors in scdl:
                qz_m = self.module.get_latent(tensors)
                # detach() guarantees the NumPy conversion below cannot fail on a
                # tensor that still requires grad.
                latent.append(qz_m.detach().cpu())
        return torch.cat(latent).numpy()

    def train(
        self,
        max_epochs: Optional[int] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        batch_size: int = 128,
        plan_kwargs: Optional[dict] = None,
        **trainer_kwargs,
    ):
        """
        Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset. If `None`, defaults to
            `max(1, min(round((20000 / n_cells) * 400), 400))`.
        use_gpu
            Use default GPU if available (if None or True), or index of GPU to use (if int),
            or name of GPU (if str), or use CPU (if False).
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the validation set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training.
        plan_kwargs
            Keyword args for :class:`~scvi.train.PyroTrainingPlan`. Anything that is not a
            `dict` (including `None`) is treated as no extra arguments.
        **trainer_kwargs
            Other keyword args forwarded to :class:`~scvi.train.TrainRunner`.

        Returns
        -------
        Whatever calling the configured :class:`~scvi.train.TrainRunner` returns.
        """
        if max_epochs is None:
            n_cells = self.adata.n_obs
            # Builtin min/max keep this a plain Python int (np.min would produce a
            # NumPy scalar); the floor of 1 stops very large datasets from rounding
            # down to zero epochs.
            max_epochs = max(1, min(round((20000 / n_cells) * 400), 400))

        plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else {}

        data_splitter = DataSplitter(
            self.adata_manager,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
        training_plan = PyroTrainingPlan(self.module, **plan_kwargs)
        runner = TrainRunner(
            self,
            training_plan=training_plan,
            data_splitter=data_splitter,
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            **trainer_kwargs,
        )
        return runner()

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: Optional[str] = None,
        labels_key: Optional[str] = None,
        layer: Optional[str] = None,
        categorical_covariate_keys: Optional[List[str]] = None,
        continuous_covariate_keys: Optional[List[str]] = None,
        **kwargs,
    ) -> Optional[AnnData]:
        """
        %(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_batch_key)s
        %(param_labels_key)s
        %(param_layer)s
        %(param_cat_cov_keys)s
        %(param_cont_cov_keys)s

        Returns
        -------
        %(returns)s
        """
        # Must stay the first statement: locals() is forwarded wholesale, so any
        # local defined before this call would be passed along as an extra argument.
        setup_method_args = cls._get_setup_method_args(**locals())
        # One field per piece of data registered from `adata`.
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
            CategoricalJointObsField(
                REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
            ),
            NumericalJointObsField(
                REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
            ),
        ]
        adata_manager = AnnDataManager(
            fields=anndata_fields, setup_method_args=setup_method_args
        )
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)