import logging
from typing import Optional, Sequence, Union
import numpy as np
import torch
from anndata import AnnData
from scvi.dataloaders import DataSplitter
from scvi.model.base import BaseModelClass
from scvi.train import PyroTrainingPlan, TrainRunner
from ._mypyromodule import MyPyroModule
logger = logging.getLogger(__name__)
class MyPyroModel(BaseModelClass):
    """
    Skeleton for a Pyro version of an scvi-tools model.

    Please use this skeleton to create new models.

    Parameters
    ----------
    adata
        AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    **model_kwargs
        Keyword args for :class:`~mypackage.MyModule`

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> scvi.data.setup_anndata(adata, batch_key="batch")
    >>> vae = mypackage.MyModel(adata)
    >>> vae.train()
    >>> adata.obsm["X_mymodel"] = vae.get_latent_representation()
    """

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        **model_kwargs,
    ):
        super().__init__(adata)
        # self.summary_stats (populated by BaseModelClass from the registered
        # AnnData) provides anndata dimensions and other tensor info.
        self.module = MyPyroModule(
            n_input=self.summary_stats["n_vars"],
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            **model_kwargs,
        )
        self._model_summary_string = "Overwrite this attribute to get an informative representation for your model"
        # necessary line to get params that will be used for saving/loading
        self.init_params_ = self._get_init_params(locals())
        logger.info("The model has been initialized")

    def get_latent(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
    ) -> np.ndarray:
        r"""
        Return the latent representation for each cell.

        This is denoted as :math:`z_n` in our manuscripts.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        latent_representation : np.ndarray
            Low-dimensional representation for each cell
        """
        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size
        )
        # Collect per-minibatch latent means on CPU, then concatenate once.
        latent = []
        for tensors in scdl:
            qz_m = self.module.get_latent(tensors)
            latent.append(qz_m.cpu())
        return torch.cat(latent).numpy()

    def train(
        self,
        max_epochs: Optional[int] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        batch_size: int = 128,
        plan_kwargs: Optional[dict] = None,
        **trainer_kwargs,
    ):
        """
        Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset. If `None`, defaults to
            `min(round((20000 / n_cells) * 400), 400)`
        use_gpu
            Use default GPU if available (if None or True), or index of GPU to use (if int),
            or name of GPU (if str), or use CPU (if False).
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the test set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training.
        plan_kwargs
            Keyword args for :class:`~scvi.lightning.TrainingPlan`. Keyword arguments passed to
            `train()` will overwrite values present in `plan_kwargs`, when appropriate.
        **trainer_kwargs
            Other keyword args for :class:`~scvi.lightning.Trainer`.
        """
        if max_epochs is None:
            # Heuristic: smaller datasets get more epochs, capped at 400.
            n_cells = self.adata.n_obs
            max_epochs = int(min(round((20000 / n_cells) * 400), 400))
        plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else {}
        data_splitter = DataSplitter(
            self.adata,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
        training_plan = PyroTrainingPlan(self.module, **plan_kwargs)
        runner = TrainRunner(
            self,
            training_plan=training_plan,
            data_splitter=data_splitter,
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            **trainer_kwargs,
        )
        return runner()