mypackage.MyPyroModel

class mypackage.MyPyroModel(adata, n_hidden=128, n_latent=10, n_layers=1, **model_kwargs)

Skeleton for a Pyro version of a scvi-tools model.

Please use this skeleton to create new models.

Parameters
adata : AnnData

AnnData object that has been registered via setup_anndata().

n_hidden : int (default: 128)

Number of nodes per hidden layer.

n_latent : int (default: 10)

Dimensionality of the latent space.

n_layers : int (default: 1)

Number of hidden layers used for encoder and decoder NNs.

**model_kwargs

Keyword arguments passed to MyModule.
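To make the parameters concrete, here is a minimal sketch under stated assumptions. The encoder is taken to be a plain fully connected stack of `n_layers` hidden layers of width `n_hidden` ending in an `n_latent`-dimensional layer, the decoder is assumed to mirror it, and `**model_kwargs` is forwarded unchanged to the module constructor. `ToyModule`, `ToyModel`, and `dropout_rate` are hypothetical names used only for illustration; they are not part of mypackage or scvi-tools.

```python
from typing import Any, Dict, List


class ToyModule:
    """Hypothetical stand-in for the module a model builds (not the real MyModule)."""

    def __init__(self, n_input: int, n_hidden: int, n_latent: int, n_layers: int, **kwargs: Any) -> None:
        # Encoder widths: input layer, n_layers hidden layers of width n_hidden, then the latent layer.
        self.encoder_sizes: List[int] = [n_input] + [n_hidden] * n_layers + [n_latent]
        # Assumption for illustration: the decoder mirrors the encoder.
        self.decoder_sizes: List[int] = list(reversed(self.encoder_sizes))
        # Keep the forwarded keyword arguments so we can inspect them.
        self.extra: Dict[str, Any] = kwargs


class ToyModel:
    """Hypothetical model wrapper that forwards **model_kwargs to its module."""

    def __init__(self, n_input: int, n_hidden: int = 128, n_latent: int = 10, n_layers: int = 1, **model_kwargs: Any) -> None:
        self.module = ToyModule(
            n_input=n_input,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            **model_kwargs,  # passed through untouched
        )


model = ToyModel(n_input=2000, n_layers=2, dropout_rate=0.1)  # dropout_rate is a hypothetical module option
print(model.module.encoder_sizes)  # [2000, 128, 128, 10]
print(model.module.extra)          # {'dropout_rate': 0.1}
```

Forwarding `**model_kwargs` untouched keeps the model's own signature small while still letting callers tune module-level options.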

Examples

>>> import anndata
>>> import mypackage
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> mypackage.MyPyroModel.setup_anndata(adata, batch_key="batch")
>>> vae = mypackage.MyPyroModel(adata)
>>> vae.train()
>>> adata.obsm["X_mymodel"] = vae.get_latent()
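The example relies on `setup_anndata()` being called on the class before the model is instantiated. The sketch below mimics that ordering with a plain dict standing in for an AnnData object. `ToyRegisteredModel`, its `_registry` dict, and the error message are hypothetical illustrations, not the scvi-tools implementation, which records setup through an `AnnDataManager`.

```python
from typing import Any, Dict, Optional


class ToyRegisteredModel:
    """Hypothetical model that requires class-level setup before construction."""

    # Maps id(data) -> setup arguments recorded by setup_anndata (illustrative only).
    _registry: Dict[int, Dict[str, Any]] = {}

    @classmethod
    def setup_anndata(cls, data: Dict[str, Any], batch_key: Optional[str] = None) -> None:
        # Record how this data object should be read, keyed by object identity.
        cls._registry[id(data)] = {"batch_key": batch_key}

    def __init__(self, data: Dict[str, Any]) -> None:
        # Refuse to build a model on data that was never set up.
        if id(data) not in self._registry:
            raise ValueError("Call setup_anndata() on this data before creating the model.")
        self.data = data
        self.setup_args = self._registry[id(data)]


data = {"X": [[0, 1], [2, 3]], "batch": ["a", "b"]}
ToyRegisteredModel.setup_anndata(data, batch_key="batch")
toy = ToyRegisteredModel(data)
print(toy.setup_args)  # {'batch_key': 'batch'}
```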

Attributes

adata

Data attached to the model instance.

adata_manager

Manager instance associated with self.adata.

device

The device that the module's parameters are currently on.

history

Metrics computed during training.

is_trained

Whether the model has been trained.

test_indices

Observations in the test set.

train_indices

Observations in the training set.

validation_indices

Observations in the validation set.
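The three index attributes partition the observations. A rough sketch of such a partition follows, assuming a seeded random shuffle in which the training set takes `ceil(train_size * n_obs)` cells and the validation set takes either `floor(validation_size * n_obs)` cells or, when `validation_size` is omitted, all remaining cells. `split_indices` and `validation_size` are hypothetical here; of these names, only `train_size` appears on this page, as an argument of `train()`.

```python
import math
import random
from typing import List, Optional, Tuple


def split_indices(
    n_obs: int,
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    seed: int = 0,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical random train/validation/test split of observation indices."""
    if not 0.0 < train_size <= 1.0:
        raise ValueError("train_size must be in (0, 1].")
    n_train = math.ceil(train_size * n_obs)
    if validation_size is None:
        # Everything not used for training goes to validation; the test set is empty.
        n_val = n_obs - n_train
    else:
        n_val = math.floor(validation_size * n_obs)
    if n_train + n_val > n_obs:
        raise ValueError("train_size + validation_size exceeds the number of observations.")

    # Shuffle all indices reproducibly, then slice into the three sets.
    perm = list(range(n_obs))
    random.Random(seed).shuffle(perm)
    validation = perm[:n_val]
    train = perm[n_val:n_val + n_train]
    test = perm[n_val + n_train:]
    return train, validation, test


train, val, test = split_indices(8, train_size=0.5, validation_size=0.25)
print(len(train), len(val), len(test))  # 4 2 2
```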

Methods

get_anndata_manager(adata[, required])

Retrieves the AnnDataManager for a given AnnData object specific to this model instance.

get_from_registry(adata, registry_key)

Returns the object in AnnData associated with the key in the data registry.

get_latent([adata, indices, batch_size])

Return the latent representation for each cell.

load(dir_path[, adata, use_gpu, prefix])

Instantiate a model from the saved output.

register_manager(adata_manager)

Registers an AnnDataManager instance with this model class.

save(dir_path[, prefix, overwrite, save_anndata])

Save the state of the model.

setup_anndata(adata[, batch_key, ...])

Sets up the AnnData object for this model.

to_device(device)

Move model to device.

train([max_epochs, use_gpu, train_size, ...])

Train the model.

view_anndata_setup([adata, ...])

Print summary of the setup for the initial AnnData or a given AnnData object.

view_setup_args(dir_path[, prefix])

Print args used to setup a saved model.
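`get_latent([adata, indices, batch_size])` returns one latent vector per cell. A minimal sketch of how such a method could restrict itself to `indices` and process cells in minibatches of `batch_size` is shown below. `encode_batch` is a hypothetical stand-in for the trained encoder (it simply sums the two halves of each feature vector), and the default `batch_size` of 128 is an assumption, not the documented default.

```python
from typing import Callable, List, Optional, Sequence

Matrix = List[List[float]]


def encode_batch(batch: Matrix) -> Matrix:
    """Hypothetical encoder: maps each cell to a 2-D 'latent' vector by summing feature halves."""
    return [[sum(row[: len(row) // 2]), sum(row[len(row) // 2 :])] for row in batch]


def get_latent(
    X: Matrix,
    encoder: Callable[[Matrix], Matrix] = encode_batch,
    indices: Optional[Sequence[int]] = None,
    batch_size: int = 128,
) -> Matrix:
    """Return one latent vector per selected cell, computed in minibatches (sketch)."""
    if indices is None:
        indices = range(len(X))  # default: every cell
    indices = list(indices)
    latent: Matrix = []
    for start in range(0, len(indices), batch_size):
        # Gather the next minibatch of cells and encode it.
        batch = [X[i] for i in indices[start : start + batch_size]]
        latent.extend(encoder(batch))
    return latent


X = [[1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 0.0, 1.0], [5.0, 5.0, 0.0, 0.0]]
print(get_latent(X, indices=[2, 0], batch_size=1))  # [[10.0, 0.0], [3.0, 7.0]]
```

Batching bounds peak memory during inference, and because the encoder is applied row by row the result does not depend on `batch_size`.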