Source code for mypackage._mypyromodel

import logging
from typing import Optional, Sequence, Union

import numpy as np
import torch
from anndata import AnnData
from scvi.dataloaders import DataSplitter
from scvi.model.base import BaseModelClass
from scvi.train import PyroTrainingPlan, TrainRunner

from ._mypyromodule import MyPyroModule

logger = logging.getLogger(__name__)


class MyPyroModel(BaseModelClass):
    """
    Skeleton for a pyro version of a scvi-tools model.

    Please use this skeleton to create new models.

    Parameters
    ----------
    adata
        AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    **model_kwargs
        Keyword args for :class:`~mypackage.MyPyroModule`

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> scvi.data.setup_anndata(adata, batch_key="batch")
    >>> vae = mypackage.MyPyroModel(adata)
    >>> vae.train()
    >>> adata.obsm["X_mymodel"] = vae.get_latent()
    """

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        **model_kwargs,
    ):
        super().__init__(adata)

        # self.summary_stats provides information about anndata dimensions and other tensor info
        self.module = MyPyroModule(
            n_input=self.summary_stats["n_vars"],
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            **model_kwargs,
        )
        self._model_summary_string = "Overwrite this attribute to get an informative representation for your model"
        # necessary line to get params that will be used for saving/loading;
        # locals() must be evaluated here, inside __init__, so that it captures
        # the constructor arguments.
        self.init_params_ = self._get_init_params(locals())

        logger.info("The model has been initialized")

    def get_latent(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
    ):
        r"""
        Return the latent representation for each cell.

        This is denoted as :math:`z_n` in our manuscripts.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        latent_representation : np.ndarray
            Low-dimensional representation for each cell
        """
        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size
        )

        latent = []
        # Inference only: disabling autograd avoids building a graph for every
        # minibatch, and detach() guarantees the NumPy conversion below cannot
        # fail on a tensor that still requires grad.
        with torch.no_grad():
            for tensors in scdl:
                qz_m = self.module.get_latent(tensors)
                latent.append(qz_m.detach().cpu())
        return torch.cat(latent).numpy()

    def train(
        self,
        max_epochs: Optional[int] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        batch_size: int = 128,
        plan_kwargs: Optional[dict] = None,
        **trainer_kwargs,
    ):
        """
        Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset. If `None`, defaults to
            `min(round((20000 / n_cells) * 400), 400)`, with a floor of one epoch.
        use_gpu
            Use default GPU if available (if None or True), or index of GPU to use (if int),
            or name of GPU (if str), or use CPU (if False).
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the validation set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training.
        plan_kwargs
            Keyword args for :class:`~scvi.train.PyroTrainingPlan`. Anything that is not a
            `dict` (including `None`) is treated as "no extra arguments".
        **trainer_kwargs
            Other keyword args forwarded to :class:`~scvi.train.TrainRunner`.

        Returns
        -------
        Whatever calling the configured :class:`~scvi.train.TrainRunner` returns.
        """
        if max_epochs is None:
            n_cells = self.adata.n_obs
            # Built-in min/max keep this a plain Python int (np.min would give
            # a NumPy scalar); the floor of 1 stops very large datasets from
            # rounding down to zero epochs and silently skipping training.
            max_epochs = max(1, min(round((20000 / n_cells) * 400), 400))

        plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else {}

        data_splitter = DataSplitter(
            self.adata,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
        training_plan = PyroTrainingPlan(self.module, **plan_kwargs)
        runner = TrainRunner(
            self,
            training_plan=training_plan,
            data_splitter=data_splitter,
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            **trainer_kwargs,
        )
        return runner()