import logging
from typing import List, Optional, Sequence, Union
import numpy as np
import torch
from anndata import AnnData
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import (
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
)
from scvi.dataloaders import DataSplitter
from scvi.model.base import BaseModelClass
from scvi.train import PyroTrainingPlan, TrainRunner
from scvi.utils import setup_anndata_dsp
from ._mypyromodule import MyPyroModule
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class MyPyroModel(BaseModelClass):
    """
    Skeleton for a pyro version of a scvi-tools model.

    Please use this skeleton to create new models.

    Parameters
    ----------
    adata
        AnnData object that has been registered via :meth:`~mypackage.MyPyroModel.setup_anndata`.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    **model_kwargs
        Keyword args for :class:`~mypackage.MyPyroModule`

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> mypackage.MyPyroModel.setup_anndata(adata, batch_key="batch")
    >>> vae = mypackage.MyPyroModel(adata)
    >>> vae.train()
    >>> adata.obsm["X_mymodel"] = vae.get_latent()
    """

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        **model_kwargs,
    ):
        super(MyPyroModel, self).__init__(adata)

        # self.summary_stats provides information about anndata dimensions and other tensor info
        self.module = MyPyroModule(
            n_input=self.summary_stats["n_vars"],
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            **model_kwargs,
        )
        self._model_summary_string = (
            "Overwrite this attribute to get an informative representation for your model"
        )
        # necessary line to get params that will be used for saving/loading.
        # locals() is snapshotted here, so avoid introducing unrelated local
        # variables in this method above this line.
        self.init_params_ = self._get_init_params(locals())

        logger.info("The model has been initialized")

    def get_latent(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
    ) -> np.ndarray:
        r"""
        Return the latent representation for each cell.

        This is denoted as :math:`z_n` in our manuscripts.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns
        -------
        latent_representation : np.ndarray
            Low-dimensional representation for each cell
        """
        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size
        )

        # Inference only: disable autograd so no graph is built, and detach each
        # minibatch so the NumPy conversion below cannot fail on a tensor that
        # still requires grad.
        with torch.no_grad():
            latent = [
                self.module.get_latent(tensors).detach().cpu() for tensors in scdl
            ]
        return torch.cat(latent).numpy()

    def train(
        self,
        max_epochs: Optional[int] = None,
        use_gpu: Optional[Union[str, int, bool]] = None,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        batch_size: int = 128,
        plan_kwargs: Optional[dict] = None,
        **trainer_kwargs,
    ):
        """
        Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset. If `None`, defaults to
            `np.min([round((20000 / n_cells) * 400), 400])`
        use_gpu
            Use default GPU if available (if None or True), or index of GPU to use (if int),
            or name of GPU (if str), or use CPU (if False).
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the validation set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training.
        plan_kwargs
            Keyword args for :class:`~scvi.train.PyroTrainingPlan`. Anything that is not a
            `dict` (including `None`) is treated as no extra arguments.
        **trainer_kwargs
            Other keyword args, forwarded to :class:`~scvi.train.TrainRunner`.

        Returns
        -------
        The value returned by calling the :class:`~scvi.train.TrainRunner` instance.
        """
        if max_epochs is None:
            # Heuristic default: fewer epochs for larger datasets, capped at 400.
            n_cells = self.adata.n_obs
            max_epochs = np.min([round((20000 / n_cells) * 400), 400])

        plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict()

        data_splitter = DataSplitter(
            self.adata_manager,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
        training_plan = PyroTrainingPlan(self.module, **plan_kwargs)
        runner = TrainRunner(
            self,
            training_plan=training_plan,
            data_splitter=data_splitter,
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            **trainer_kwargs,
        )
        return runner()

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: Optional[str] = None,
        labels_key: Optional[str] = None,
        layer: Optional[str] = None,
        categorical_covariate_keys: Optional[List[str]] = None,
        continuous_covariate_keys: Optional[List[str]] = None,
        **kwargs,
    ) -> Optional[AnnData]:
        """
        %(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_batch_key)s
        %(param_labels_key)s
        %(param_layer)s
        %(param_cat_cov_keys)s
        %(param_cont_cov_keys)s

        Returns
        -------
        %(returns)s
        """
        # Must stay the first statement: locals() is expanded into keyword
        # arguments here, so it should hold only this method's call arguments.
        setup_method_args = cls._get_setup_method_args(**locals())
        # One field per piece of data the model reads from `adata`.
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
            CategoricalJointObsField(
                REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
            ),
            NumericalJointObsField(
                REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
            ),
        ]
        adata_manager = AnnDataManager(
            fields=anndata_fields, setup_method_args=setup_method_args
        )
        adata_manager.register_fields(adata, **kwargs)
        # No explicit return: the manager is recorded on the class instead.
        cls.register_manager(adata_manager)