import numpy as np
import torch
from scvi import _CONSTANTS
from scvi.distributions import ZeroInflatedNegativeBinomial
from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
from scvi.nn import DecoderSCVI, Encoder
from torch.distributions import Normal
from torch.distributions import kl_divergence as kl
torch.backends.cudnn.benchmark = True


class MyModule(BaseModuleClass):
    """
    Skeleton variational auto-encoder model.

    Here we implement a basic version of scVI's underlying VAE [Lopez18]_.
    This implementation is for instructional purposes only.

    Parameters
    ----------
    n_input
        Number of input genes
    n_batch
        Number of batches; if 0, no batch correction is performed
    n_hidden
        Number of nodes per hidden layer
    n_latent
        Dimensionality of the latent space
    n_layers
        Number of hidden layers used for encoder and decoder NNs
    dropout_rate
        Dropout rate for neural networks
    """
    def __init__(
        self,
        n_input: int,
        n_batch: int = 0,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
    ):
        super().__init__()
        self.n_latent = n_latent
        self.n_batch = n_batch
        # this is needed to comply with some requirement of the VAEMixin class
        self.latent_distribution = "normal"

        # set up the parameters of your generative model, as well as your inference model
        self.px_r = torch.nn.Parameter(torch.randn(n_input))
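        # px_r holds one free dispersion parameter per gene; ``generative`` exponentiates
        # it so the negative binomial inverse-dispersion stays positive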
        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        self.z_encoder = Encoder(
            n_input,
            n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        # l encoder goes from n_input-dimensional data to 1-d library size
        self.l_encoder = Encoder(
            n_input,
            1,
            n_layers=1,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
        )
        # decoder goes from n_latent-dimensional space to n_input-d data
        self.decoder = DecoderSCVI(
            n_latent,
            n_input,
            n_layers=n_layers,
            n_hidden=n_hidden,
        )

    def _get_inference_input(self, tensors):
        """Parse the dictionary to get appropriate args"""
        x = tensors[_CONSTANTS.X_KEY]
        input_dict = dict(x=x)
        return input_dict

    def _get_generative_input(self, tensors, inference_outputs):
        z = inference_outputs["z"]
        library = inference_outputs["library"]
        input_dict = {
            "z": z,
            "library": library,
        }
        return input_dict

    @auto_move_data
    def inference(self, x, n_samples=1):
        """
        High level inference method.

        Runs the inference (encoder) model.
        """
        # log the input to the variational distribution for numerical stability
        x_ = torch.log(1 + x)
        # get variational parameters via the encoder networks
        qz_m, qz_v, z = self.z_encoder(x_)
        ql_m, ql_v, library = self.l_encoder(x_)
        if n_samples > 1:
            # expand the posterior parameters and redraw samples so downstream tensors
            # have shape (n_samples, n_cells, -1); this is what ``sample`` relies on
            qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
            qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
            z = Normal(qz_m, qz_v.sqrt()).sample()
            ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1)))
            ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1)))
            library = Normal(ql_m, ql_v.sqrt()).sample()
        outputs = dict(z=z, qz_m=qz_m, qz_v=qz_v, ql_m=ql_m, ql_v=ql_v, library=library)
        return outputs

    @auto_move_data
    def generative(self, z, library):
        """Runs the generative model."""
        # form the parameters of the ZINB likelihood
        px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library)
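        # px_scale is the normalized expression per gene, px_rate scales it by the
        # library size to give the ZINB mean, and px_dropout are the zero-inflation logits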
        px_r = torch.exp(self.px_r)
        return dict(
            px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout
        )

    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        kl_weight: float = 1.0,
    ):
        x = tensors[_CONSTANTS.X_KEY]
        local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY]
        local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY]

        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]
        ql_m = inference_outputs["ql_m"]
        ql_v = inference_outputs["ql_v"]
        px_rate = generative_outputs["px_rate"]
        px_r = generative_outputs["px_r"]
        px_dropout = generative_outputs["px_dropout"]
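        # KL(q(z|x) || p(z)) against a standard normal prior on z, and
        # KL(q(l|x) || p(l)) against the per-batch prior on the log library size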
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)
        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
            dim=1
        )
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)

        reconst_loss = (
            -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
            .log_prob(x)
            .sum(dim=-1)
        )
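        # kl_weight is typically annealed from 0 to 1 by the training plan (KL warmup);
        # only the z KL is warmed up, while the library KL enters at full weight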
        kl_local_for_warmup = kl_divergence_z
        kl_local_no_warmup = kl_divergence_l
        weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup
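        # minibatch objective: the mean negative ELBO (reconstruction NLL plus weighted KL terms)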
        loss = torch.mean(reconst_loss + weighted_kl_local)

        kl_local = dict(
            kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z
        )
        kl_global = 0.0
        return LossRecorder(loss, reconst_loss, kl_local, kl_global)

    @torch.no_grad()
    def sample(
        self,
        tensors,
        n_samples=1,
        library_size=1,
    ) -> torch.Tensor:
        r"""
        Generate observation samples from the posterior predictive distribution.

        The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.

        Parameters
        ----------
        tensors
            Tensors dict
        n_samples
            Number of required samples for each cell
        library_size
            Library size to scale samples to

        Returns
        -------
        x_new : :py:class:`torch.Tensor`
            tensor with shape (n_cells, n_genes, n_samples)
        """
        inference_kwargs = dict(n_samples=n_samples)
        inference_outputs, generative_outputs = self.forward(
            tensors,
            inference_kwargs=inference_kwargs,
            compute_loss=False,
        )
        px_r = generative_outputs["px_r"]
        px_rate = generative_outputs["px_rate"]
        px_dropout = generative_outputs["px_dropout"]

        dist = ZeroInflatedNegativeBinomial(
            mu=px_rate, theta=px_r, zi_logits=px_dropout
        )
        if n_samples > 1:
            exprs = dist.sample().permute(
                [1, 2, 0]
            )  # Shape : (n_cells_batch, n_genes, n_samples)
        else:
            exprs = dist.sample()
        return exprs.cpu()

    @torch.no_grad()
    @auto_move_data
    def marginal_ll(self, tensors, n_mc_samples):
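        """Estimate the marginal log-likelihood log p(x) via importance sampling, using ``n_mc_samples`` posterior samples per cell."""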
        sample_batch = tensors[_CONSTANTS.X_KEY]
        local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY]
        local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY]

        # one column of log importance weights per Monte Carlo sample, kept on the data's device
        to_sum = torch.zeros(
            sample_batch.size()[0], n_mc_samples, device=sample_batch.device
        )

        for i in range(n_mc_samples):
            # Distribution parameters and sampled variables
            inference_outputs, generative_outputs, losses = self.forward(tensors)
            qz_m = inference_outputs["qz_m"]
            qz_v = inference_outputs["qz_v"]
            z = inference_outputs["z"]
            ql_m = inference_outputs["ql_m"]
            ql_v = inference_outputs["ql_v"]
            library = inference_outputs["library"]

            # Reconstruction Loss
            reconst_loss = losses.reconstruction_loss

            # Log-probabilities under the priors and the variational posteriors
            p_l = Normal(local_l_mean, local_l_var.sqrt()).log_prob(library).sum(dim=-1)
            p_z = (
                Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))
                .log_prob(z)
                .sum(dim=-1)
            )
            p_x_zl = -reconst_loss
            q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1)
            q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1)

            to_sum[:, i] = p_z + p_l + p_x_zl - q_z_x - q_l_x
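        # average the importance weights in log space: log(1/K * sum_k w_k) = logsumexp(log w_k) - log K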
        batch_log_lkl = torch.logsumexp(to_sum, dim=-1) - np.log(n_mc_samples)
        log_lkl = torch.sum(batch_log_lkl).item()
        return log_lkl
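

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch only). It assumes an scvi-tools
# version that exposes ``_CONSTANTS`` and ``LossRecorder`` as imported above;
# the registry keys and tensor shapes follow the methods defined in this
# module, but the toy minibatch below is made up for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n_cells, n_genes = 8, 100
    module = MyModule(n_input=n_genes)

    # fake a minibatch in the format ``_get_inference_input`` / ``loss`` expect
    x = torch.randint(0, 20, (n_cells, n_genes)).float()
    log_library = torch.log(x.sum(dim=1, keepdim=True) + 1)
    tensors = {
        _CONSTANTS.X_KEY: x,
        _CONSTANTS.LOCAL_L_MEAN_KEY: torch.full((n_cells, 1), log_library.mean().item()),
        _CONSTANTS.LOCAL_L_VAR_KEY: torch.full((n_cells, 1), log_library.var().item()),
    }

    # full forward pass: inference -> generative -> loss
    inference_outputs, generative_outputs, losses = module(tensors)
    print("z shape:", inference_outputs["z"].shape)                # (n_cells, n_latent)
    print("px_rate shape:", generative_outputs["px_rate"].shape)   # (n_cells, n_genes)
    print("negative ELBO:", losses.loss.item())

    # posterior predictive samples and a rough marginal log-likelihood estimate
    print("sampled counts shape:", module.sample(tensors, n_samples=2).shape)
    print("marginal log-likelihood:", module.marginal_ll(tensors, n_mc_samples=5))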