import pyro
import pyro.distributions as dist
import torch
from scvi import _CONSTANTS
from scvi.module.base import PyroBaseModuleClass, auto_move_data
from scvi.nn import DecoderSCVI, Encoder
class MyPyroModule(PyroBaseModuleClass):
    """
    Skeleton variational auto-encoder Pyro module.

    Here we implement a basic version of scVI's underlying VAE [Lopez18]_.
    This implementation is for instructional purposes only.

    Parameters
    ----------
    n_input
        Number of input genes
    n_latent
        Dimensionality of the latent space
    n_hidden
        Number of nodes per hidden layer
    n_layers
        Number of hidden layers used for encoder and decoder NNs
    """

    def __init__(self, n_input: int, n_latent: int, n_hidden: int, n_layers: int):
        super().__init__()
        self.n_input = n_input
        self.n_latent = n_latent
        # Small constant added inside the logs in `model` to keep them finite.
        self.epsilon = 5.0e-3
        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.encoder = Encoder(
            n_input,
            n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=0.1,
        )
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = DecoderSCVI(
            n_latent,
            n_input,
            n_layers=n_layers,
            n_hidden=n_hidden,
        )
        # This gene-level parameter modulates the variance of the observation
        # distribution; it is exponentiated before use, so it is unconstrained here.
        self.px_r = torch.nn.Parameter(torch.ones(self.n_input))

    @staticmethod
    def _get_fn_args_from_batch(tensor_dict):
        """Extract the ``(x, log_library)`` positional args for model/guide from a batch.

        The 1e-6 offset guards against ``log(0)`` for all-zero cells.
        """
        x = tensor_dict[_CONSTANTS.X_KEY]
        log_library = torch.log(torch.sum(x, dim=1, keepdim=True) + 1e-6)
        return (x, log_library), {}

    def model(self, x, log_library):
        """Generative model: p(x | z) p(z) with a standard-normal prior on z."""
        # register PyTorch module `decoder` with Pyro
        pyro.module("scvi", self)
        with pyro.plate("data", x.shape[0]):
            # setup hyperparameters for prior p(z): N(0, I) per cell
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.n_latent)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.n_latent)))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z into count-distribution parameters
            px_scale, _, px_rate, px_dropout = self.decoder("gene", z, log_library)
            # build count distribution: logits = log(mean) - log(theta), where
            # theta = exp(px_r) is used as the NB total_count; epsilon keeps
            # both logs finite.
            nb_logits = (px_rate + self.epsilon).log() - (
                self.px_r.exp() + self.epsilon
            ).log()
            x_dist = dist.ZeroInflatedNegativeBinomial(
                gate_logits=px_dropout, total_count=self.px_r.exp(), logits=nb_logits
            )
            # score against actual counts
            pyro.sample("obs", x_dist.to_event(1), obs=x)

    def guide(self, x, log_library):
        """Variational distribution (guide): q(z | x), amortized via the encoder."""
        pyro.module("scvi", self)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x);
            # the encoder operates on log(1 + x) for numerical stability
            x_ = torch.log(1 + x)
            z_loc, z_scale, _ = self.encoder(x_)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    @torch.no_grad()
    @auto_move_data
    def get_latent(self, tensors):
        """Return the posterior mean of z for a batch (no sampling, no gradients)."""
        x = tensors[_CONSTANTS.X_KEY]
        x_ = torch.log(1 + x)
        z_loc, _, _ = self.encoder(x_)
        return z_loc