Source code for scvi.data._anndata

import logging
import os
import pickle
import sys
import warnings
from typing import Dict, List, Optional, Tuple, Union

import anndata
import numpy as np
import pandas as pd
import rich
from anndata._core.anndata import AnnData
from pandas.api.types import CategoricalDtype
from rich.console import Console
from scipy.sparse import isspmatrix

import scvi
from scvi import _CONSTANTS
from scvi._compat import Literal

from ._utils import (
    _check_nonnegative_integers,
    _compute_library_size_batch,
    _get_batch_mask_protein_data,
)

logger = logging.getLogger(__name__)


def get_from_registry(adata: anndata.AnnData, key: str) -> np.ndarray:
    """
    Returns the object in AnnData associated with the key in ``.uns['_scvi']['data_registry']``.

    Parameters
    ----------
    adata
        anndata object already setup with `scvi.data.setup_anndata()`
    key
        key of object to get from ``adata.uns['_scvi]['data_registry']``

    Returns
    -------
    The requested data

    Examples
    --------
    >>> import scvi
    >>> adata = scvi.data.cortex()
    >>> adata.uns['_scvi']['data_registry']
    {'X': ['_X', None],
    'batch_indices': ['obs', 'batch'],
    'local_l_mean': ['obs', '_scvi_local_l_mean'],
    'local_l_var': ['obs', '_scvi_local_l_var'],
    'labels': ['obs', 'labels']}
    >>> batch = get_from_registry(adata, "batch_indices")
    >>> batch
    array([[0],
           [0],
           [0],
           ...,
           [0],
           [0],
           [0]])
    """
    data_loc = adata.uns["_scvi"]["data_registry"][key]
    attr_name, attr_key = data_loc["attr_name"], data_loc["attr_key"]

    data = getattr(adata, attr_name)
    if attr_key != "None":
        if isinstance(data, pd.DataFrame):
            data = data.loc[:, attr_key]
        else:
            data = data[attr_key]
    if isinstance(data, pd.Series):
        data = data.to_numpy().reshape(-1, 1)
    return data


[docs]def setup_anndata( adata: anndata.AnnData, batch_key: Optional[str] = None, labels_key: Optional[str] = None, layer: Optional[str] = None, protein_expression_obsm_key: Optional[str] = None, protein_names_uns_key: Optional[str] = None, categorical_covariate_keys: Optional[List[str]] = None, continuous_covariate_keys: Optional[List[str]] = None, copy: bool = False, ) -> Optional[anndata.AnnData]: """ Sets up :class:`~anndata.AnnData` object for models. A mapping will be created between data fields used by models to their respective locations in adata. This method will also compute the log mean and log variance per batch for the library size prior. None of the data in adata are modified. Only adds fields to adata. Parameters ---------- adata AnnData object containing raw counts. Rows represent cells, columns represent features. batch_key key in `adata.obs` for batch information. Categories will automatically be converted into integer categories and saved to `adata.obs['_scvi_batch']`. If `None`, assigns the same batch to all the data. labels_key key in `adata.obs` for label information. Categories will automatically be converted into integer categories and saved to `adata.obs['_scvi_labels']`. If `None`, assigns the same label to all the data. layer if not `None`, uses this as the key in `adata.layers` for raw count data. protein_expression_obsm_key key in `adata.obsm` for protein expression data, Required for :class:`~scvi.model.TOTALVI`. protein_names_uns_key key in `adata.uns` for protein names. If None, will use the column names of `adata.obsm[protein_expression_obsm_key]` if it is a DataFrame, else will assign sequential names to proteins. Only relevant but not required for :class:`~scvi.model.TOTALVI`. categorical_covariate_keys keys in `adata.obs` that correspond to categorical data. Used in some models. continuous_covariate_keys keys in `adata.obs` that correspond to continuous data. Used in some models. copy if `True`, a copy of adata is returned. Returns ------- If ``copy``, will return :class:`~anndata.AnnData`. Adds the following fields to adata: .uns['_scvi'] `scvi` setup dictionary .obs['_local_l_mean'] per batch library size mean .obs['_local_l_var'] per batch library size variance .obs['_scvi_labels'] labels encoded as integers .obs['_scvi_batch'] batch encoded as integers Examples -------- Example setting up a scanpy dataset with random gene data and no batch nor label information >>> import scanpy as sc >>> import scvi >>> import numpy as np >>> adata = scvi.data.synthetic_iid(run_setup_anndata=False) >>> adata AnnData object with n_obs × n_vars = 400 × 100 obs: 'batch', 'labels' uns: 'protein_names' obsm: 'protein_expression' Filter cells and run preprocessing before `setup_anndata` >>> sc.pp.filter_cells(adata, min_counts = 0) Since no batch_key nor labels_key was passed, setup_anndata() will assume all cells have the same batch and label >>> scvi.data.setup_anndata(adata) INFO No batch_key inputted, assuming all cells are same batch INFO No label_key inputted, assuming all cells have same label INFO Using data from adata.X INFO Computing library size prior per batch INFO Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels'] INFO Successfully registered anndata object containing 400 cells, 100 vars, 1 batches, 1 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra continuous covariates. Example setting up scanpy dataset with random gene data, batch, and protein expression >>> adata = scvi.data.synthetic_iid(run_setup_anndata=False) >>> scvi.data.setup_anndata(adata, batch_key='batch', protein_expression_obsm_key='protein_expression') INFO Using batches from adata.obs["batch"] INFO No label_key inputted, assuming all cells have same label INFO Using data from adata.X INFO Computing library size prior per batch INFO Using protein expression from adata.obsm['protein_expression'] INFO Generating sequential protein names INFO Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels', 'protein_expression'] INFO Successfully registered anndata object containing 400 cells, 100 vars, 2 batches, 1 labels, and 100 proteins. Also registered 0 extra categorical covariates and 0 extra continuous covariates. """ if copy: adata = adata.copy() if adata.is_view: raise ValueError( "Please run `adata = adata.copy()` or use the copy option in this function." ) adata.uns["_scvi"] = {} adata.uns["_scvi"]["scvi_version"] = scvi.__version__ batch_key = _setup_batch(adata, batch_key) labels_key = _setup_labels(adata, labels_key) x_loc, x_key = _setup_x(adata, layer) local_l_mean_key, local_l_var_key = _setup_library_size(adata, batch_key, layer) data_registry = { _CONSTANTS.X_KEY: {"attr_name": x_loc, "attr_key": x_key}, _CONSTANTS.BATCH_KEY: {"attr_name": "obs", "attr_key": batch_key}, _CONSTANTS.LOCAL_L_MEAN_KEY: {"attr_name": "obs", "attr_key": local_l_mean_key}, _CONSTANTS.LOCAL_L_VAR_KEY: {"attr_name": "obs", "attr_key": local_l_var_key}, _CONSTANTS.LABELS_KEY: {"attr_name": "obs", "attr_key": labels_key}, } if protein_expression_obsm_key is not None: protein_expression_obsm_key = _setup_protein_expression( adata, protein_expression_obsm_key, protein_names_uns_key, batch_key ) data_registry[_CONSTANTS.PROTEIN_EXP_KEY] = { "attr_name": "obsm", "attr_key": protein_expression_obsm_key, } if categorical_covariate_keys is not None: cat_loc, cat_key = _setup_extra_categorical_covs( adata, categorical_covariate_keys ) data_registry[_CONSTANTS.CAT_COVS_KEY] = { "attr_name": cat_loc, "attr_key": cat_key, } if continuous_covariate_keys is not None: cont_loc, cont_key = _setup_extra_continuous_covs( adata, continuous_covariate_keys ) data_registry[_CONSTANTS.CONT_COVS_KEY] = { "attr_name": cont_loc, "attr_key": cont_key, } # add the data_registry to anndata _register_anndata(adata, data_registry_dict=data_registry) logger.debug("Registered keys:{}".format(list(data_registry.keys()))) _setup_summary_stats( adata, batch_key, labels_key, protein_expression_obsm_key, categorical_covariate_keys, continuous_covariate_keys, ) logger.info("Please do not further modify adata until model is trained.") _verify_and_correct_data_format(adata, data_registry) if copy: return adata
def _set_data_in_registry(adata, data, key): """ Sets the data associated with key in adata.uns['_scvi']['data_registry'].keys() to data. Note: This is a dangerous method and will change the underlying data of the user's anndata Currently used to make the user's anndata C_CONTIGUOUS and csr if it is dense numpy or sparse respectively. Parameters ---------- adata anndata object to change data of data data to change to key key in adata.uns['_scvi]['data_registry'].keys() associated with the data """ data_loc = adata.uns["_scvi"]["data_registry"][key] attr_name, attr_key = data_loc["attr_name"], data_loc["attr_key"] if attr_key == "None": setattr(adata, attr_name, data) elif attr_key != "None": attribute = getattr(adata, attr_name) if isinstance(attribute, pd.DataFrame): attribute.loc[:, attr_key] = data else: attribute[attr_key] = data setattr(adata, attr_name, attribute) def _verify_and_correct_data_format(adata, data_registry): """ Will make sure that the user's anndata is C_CONTIGUOUS and csr if it is dense numpy or sparse respectively. Will iterate through all the keys of data_registry. Parameters ---------- adata anndata to check data_registry data registry of anndata """ keys_to_check = [_CONSTANTS.X_KEY, _CONSTANTS.PROTEIN_EXP_KEY] keys = [key for key in keys_to_check if key in data_registry.keys()] for k in keys: data = get_from_registry(adata, k) if isspmatrix(data) and (data.getformat() != "csr"): warnings.warn( "Training will be faster when sparse matrix is formatted as CSR. It is safe to cast before model initialization." ) elif isinstance(data, np.ndarray) and (data.flags["C_CONTIGUOUS"] is False): logger.debug( "{} is not C_CONTIGUOUS. Overwriting to C_CONTIGUOUS.".format(k) ) data = np.asarray(data, order="C") _set_data_in_registry(adata, data, k) elif isinstance(data, pd.DataFrame) and ( data.to_numpy().flags["C_CONTIGUOUS"] is False ): logger.debug( "{} is not C_CONTIGUOUS. Overwriting to C_CONTIGUOUS.".format(k) ) index = data.index vals = data.to_numpy() columns = data.columns data = pd.DataFrame( np.ascontiguousarray(vals), index=index, columns=columns ) _set_data_in_registry(adata, data, k) def register_tensor_from_anndata( adata: anndata.AnnData, registry_key: str, adata_attr_name: Literal["obs", "var", "obsm", "varm", "uns"], adata_key_name: str, is_categorical: Optional[bool] = False, adata_alternate_key_name: Optional[str] = None, ): """ Add another tensor to scvi data registry. This function is intended for contributors testing out new models. Parameters ---------- adata AnnData with "_scvi" key in `.uns` registry_key Key for tensor in registry, which will be the key in the dataloader output adata_attr_name AnnData attribute with tensor adata_key_name key in adata_attr_name with data is_categorical Whether or not data is categorical adata_alternate_key_name Added key in adata_attr_name for categorical codes if `is_categorical` is True """ if is_categorical is True: if adata_attr_name != "obs": raise ValueError("categorical handling only implemented for data in `.obs`") if adata_alternate_key_name is None: adata_alternate_key_name = adata_key_name + "_scvi" if is_categorical is True and adata_attr_name == "obs": adata_key_name = _make_obs_column_categorical( adata, column_key=adata_key_name, alternate_column_key=adata_alternate_key_name, ) new_dict = { registry_key: {"attr_name": adata_attr_name, "attr_key": adata_key_name} } data_registry = adata.uns["_scvi"]["data_registry"] data_registry.update(new_dict) _verify_and_correct_data_format(adata, data_registry) def transfer_anndata_setup( adata_source: Union[anndata.AnnData, dict], adata_target: anndata.AnnData, extend_categories: bool = False, ): """ Transfer anndata setup from a source object to a target object. This handles encoding for categorical data and is useful in the case where an anndata object has been subsetted and a category is lost. Parameters ---------- adata_source AnnData that has been setup with scvi. If `dict`, must be dictionary from source anndata containing scvi setup parameters. adata_target AnnData with equivalent organization as source, but possibly subsetted. extend_categories New categories in `adata_target` are added to the registry. """ adata_target.uns["_scvi"] = {} if isinstance(adata_source, anndata.AnnData): _scvi_dict = adata_source.uns["_scvi"] else: _scvi_dict = adata_source data_registry = _scvi_dict["data_registry"] summary_stats = _scvi_dict["summary_stats"] # transfer version adata_target.uns["_scvi"]["scvi_version"] = _scvi_dict["scvi_version"] x_loc = data_registry[_CONSTANTS.X_KEY]["attr_name"] if x_loc == "layers": layer = data_registry[_CONSTANTS.X_KEY]["attr_key"] else: layer = None target_n_vars = adata_target.shape[1] if target_n_vars != summary_stats["n_vars"]: raise ValueError( "Number of vars in adata_target not the same as source. " + "Expected: {} Received: {}".format(target_n_vars, summary_stats["n_vars"]) ) # transfer batch and labels categorical_mappings = _scvi_dict["categorical_mappings"] _transfer_batch_and_labels(adata_target, categorical_mappings, extend_categories) batch_key = "_scvi_batch" labels_key = "_scvi_labels" # transfer protein_expression protein_expression_obsm_key = _transfer_protein_expression( _scvi_dict, adata_target, batch_key ) # transfer X x_loc, x_key = _setup_x(adata_target, layer) local_l_mean_key, local_l_var_key = _setup_library_size( adata_target, batch_key, layer ) target_data_registry = data_registry.copy() target_data_registry.update( {_CONSTANTS.X_KEY: {"attr_name": x_loc, "attr_key": x_key}} ) # transfer extra categorical covs has_cat_cov = True if _CONSTANTS.CAT_COVS_KEY in data_registry.keys() else False if has_cat_cov: source_cat_dict = _scvi_dict["extra_categoricals"]["mappings"].copy() # extend categories if extend_categories: for key, mapping in source_cat_dict.items(): for c in np.unique(adata_target.obs[key]): if c not in mapping: mapping = np.concatenate([mapping, [c]]) source_cat_dict[key] = mapping # use ["keys"] to maintain correct order cat_loc, cat_key = _setup_extra_categorical_covs( adata_target, _scvi_dict["extra_categoricals"]["keys"], category_dict=source_cat_dict, ) target_data_registry.update( {_CONSTANTS.CAT_COVS_KEY: {"attr_name": cat_loc, "attr_key": cat_key}} ) else: source_cat_dict = None # transfer extra continuous covs has_cont_cov = True if _CONSTANTS.CONT_COVS_KEY in data_registry.keys() else False if has_cont_cov: obs_keys_names = _scvi_dict["extra_continuous_keys"] cont_loc, cont_key = _setup_extra_continuous_covs( adata_target, list(obs_keys_names) ) target_data_registry.update( {_CONSTANTS.CONT_COVS_KEY: {"attr_name": cont_loc, "attr_key": cont_key}} ) else: obs_keys_names = None # add the data_registry to anndata _register_anndata(adata_target, data_registry_dict=target_data_registry) logger.info("Registered keys:{}".format(list(target_data_registry.keys()))) _setup_summary_stats( adata_target, batch_key, labels_key, protein_expression_obsm_key, source_cat_dict, obs_keys_names, ) _verify_and_correct_data_format(adata_target, data_registry) def _transfer_batch_and_labels(adata_target, categorical_mappings, extend_categories): for key, val in categorical_mappings.items(): original_key = val["original_key"] if (key == original_key) and (original_key not in adata_target.obs.keys()): # case where original key and key are equal # caused when no batch or label key were given # when anndata_source was setup logger.info( ".obs[{}] not found in target, assuming every cell is same category".format( original_key ) ) adata_target.obs[original_key] = np.zeros( adata_target.shape[0], dtype=np.int64 ) elif (key != original_key) and (original_key not in adata_target.obs.keys()): raise KeyError( '.obs["{}"] was used to setup source, but not found in target.'.format( original_key ) ) mapping = val["mapping"].copy() # extend mapping for new categories if extend_categories: for c in np.unique(adata_target.obs[original_key]): if c not in mapping: mapping = np.concatenate([mapping, [c]]) cat_dtype = CategoricalDtype(categories=mapping, ordered=True) _make_obs_column_categorical( adata_target, original_key, key, categorical_dtype=cat_dtype ) def _transfer_protein_expression(_scvi_dict, adata_target, batch_key): data_registry = _scvi_dict["data_registry"] summary_stats = _scvi_dict["summary_stats"] has_protein = True if _CONSTANTS.PROTEIN_EXP_KEY in data_registry.keys() else False if has_protein is True: prev_protein_obsm_key = data_registry[_CONSTANTS.PROTEIN_EXP_KEY]["attr_key"] if prev_protein_obsm_key not in adata_target.obsm.keys(): raise KeyError( "Can't find {} in adata_target.obsm for protein expressions.".format( prev_protein_obsm_key ) ) else: assert ( summary_stats["n_proteins"] == adata_target.obsm[prev_protein_obsm_key].shape[1] ) protein_expression_obsm_key = prev_protein_obsm_key adata_target.uns["_scvi"]["protein_names"] = _scvi_dict["protein_names"] # batch mask totalVI batch_mask = _get_batch_mask_protein_data( adata_target, protein_expression_obsm_key, batch_key ) # check if it's actually needed if np.sum([~b[1] for b in batch_mask.items()]) > 0: logger.info("Found batches with missing protein expression") adata_target.uns["_scvi"]["totalvi_batch_mask"] = batch_mask else: protein_expression_obsm_key = None return protein_expression_obsm_key def _assert_key_in_obs(adata, key): assert key in adata.obs.keys(), "{} is not a valid key for in adata.obs".format(key) def _setup_labels(adata, labels_key): # checking labels if labels_key is None: logger.info("No label_key inputted, assuming all cells have same label") labels_key = "_scvi_labels" adata.obs[labels_key] = np.zeros(adata.shape[0], dtype=np.int64) alt_key = labels_key else: _assert_key_in_obs(adata, labels_key) logger.info('Using labels from adata.obs["{}"]'.format(labels_key)) alt_key = "_scvi_labels" labels_key = _make_obs_column_categorical( adata, column_key=labels_key, alternate_column_key=alt_key ) return labels_key def _setup_batch(adata, batch_key): # checking batch if batch_key is None: logger.info("No batch_key inputted, assuming all cells are same batch") batch_key = "_scvi_batch" adata.obs[batch_key] = np.zeros(adata.shape[0], dtype=np.int64) alt_key = batch_key else: _assert_key_in_obs(adata, batch_key) logger.info('Using batches from adata.obs["{}"]'.format(batch_key)) alt_key = "_scvi_batch" batch_key = _make_obs_column_categorical( adata, column_key=batch_key, alternate_column_key=alt_key ) return batch_key def _setup_extra_categorical_covs( adata: anndata.AnnData, categorical_covariate_keys: List[str], category_dict: Dict[str, List[str]] = None, ): """ Setup obsm df for extra categorical covariates. Parameters ---------- adata AnnData to setup categorical_covariate_keys List of keys in adata.obs with categorical data category_dict Optional dictionary with keys being keys of categorical data in obs and values being precomputed categories for each obs vector """ for key in categorical_covariate_keys: _assert_key_in_obs(adata, key) cat_loc = "obsm" cat_key = "_scvi_extra_categoricals" adata.uns["_scvi"]["extra_categoricals"] = {} categories = {} df = pd.DataFrame(index=adata.obs_names) for key in categorical_covariate_keys: if category_dict is None: categorical_obs = adata.obs[key].astype("category") mapping = categorical_obs.cat.categories.to_numpy(copy=True) categories[key] = mapping else: possible_cats = category_dict[key] categorical_obs = adata.obs[key].astype( CategoricalDtype(categories=possible_cats) ) codes = categorical_obs.cat.codes df[key] = codes adata.obsm[cat_key] = df store_cats = categories if category_dict is None else category_dict adata.uns["_scvi"]["extra_categoricals"]["mappings"] = store_cats # this preserves the order of the keys added to the df adata.uns["_scvi"]["extra_categoricals"]["keys"] = categorical_covariate_keys # how many cats per key, in the preserved order n_cats_per_key = [] for k in categorical_covariate_keys: n_cats_per_key.append(len(store_cats[k])) adata.uns["_scvi"]["extra_categoricals"]["n_cats_per_key"] = n_cats_per_key return cat_loc, cat_key def _setup_extra_continuous_covs( adata: anndata.AnnData, continuous_covariate_keys: List[str] ): """ Setup obsm df for extra continuous covariates. Parameters ---------- adata AnnData to setup continuous_covariate_keys List of keys in adata.obs with continuous data """ for key in continuous_covariate_keys: _assert_key_in_obs(adata, key) cont_loc = "obsm" cont_key = "_scvi_extra_continuous" series = [] for key in continuous_covariate_keys: s = adata.obs[key] series.append(s) adata.obsm[cont_key] = pd.concat(series, axis=1) adata.uns["_scvi"]["extra_continuous_keys"] = adata.obsm[ cont_key ].columns.to_numpy() return cont_loc, cont_key def _make_obs_column_categorical( adata, column_key, alternate_column_key, categorical_dtype=None ): """ Makes the data in column_key in obs all categorical. If adata.obs[column_key] is not categorical, will categorize and save to .obs[alternate_column_key] """ if categorical_dtype is None: categorical_obs = adata.obs[column_key].astype("category") else: categorical_obs = adata.obs[column_key].astype(categorical_dtype) # put codes in .obs[alternate_column_key] codes = categorical_obs.cat.codes mapping = categorical_obs.cat.categories.to_numpy(copy=True) if -1 in np.unique(codes): received_categories = adata.obs[column_key].astype("category").cat.categories raise ValueError( 'Making .obs["{}"] categorical failed. Expected categories: {}. ' "Received categories: {}. ".format(column_key, mapping, received_categories) ) adata.obs[alternate_column_key] = codes # store categorical mappings store_dict = { alternate_column_key: {"original_key": column_key, "mapping": mapping} } if "categorical_mappings" not in adata.uns["_scvi"].keys(): adata.uns["_scvi"].update({"categorical_mappings": store_dict}) else: adata.uns["_scvi"]["categorical_mappings"].update(store_dict) # make sure each category contains enough cells unique, counts = np.unique(adata.obs[alternate_column_key], return_counts=True) if np.min(counts) < 3: category = unique[np.argmin(counts)] warnings.warn( "Category {} in adata.obs['{}'] has fewer than 3 cells. SCVI may not train properly.".format( category, alternate_column_key ) ) # possible check for continuous? if len(unique) > (adata.shape[0] / 3): warnings.warn( "Is adata.obs['{}'] continuous? SCVI doesn't support continuous obs yet." ) return alternate_column_key def _setup_protein_expression( adata, protein_expression_obsm_key, protein_names_uns_key, batch_key ): assert ( protein_expression_obsm_key in adata.obsm.keys() ), "{} is not a valid key in adata.obsm".format(protein_expression_obsm_key) logger.info( "Using protein expression from adata.obsm['{}']".format( protein_expression_obsm_key ) ) pro_exp = adata.obsm[protein_expression_obsm_key] if _check_nonnegative_integers(pro_exp) is False: warnings.warn( "adata.obsm[{}] does not contain unnormalized count data. Are you sure this is what you want?".format( protein_expression_obsm_key ) ) # setup protein names if protein_names_uns_key is None and isinstance( adata.obsm[protein_expression_obsm_key], pd.DataFrame ): logger.info( "Using protein names from columns of adata.obsm['{}']".format( protein_expression_obsm_key ) ) protein_names = list(adata.obsm[protein_expression_obsm_key].columns) elif protein_names_uns_key is not None: logger.info( "Using protein names from adata.uns['{}']".format(protein_names_uns_key) ) protein_names = adata.uns[protein_names_uns_key] else: logger.info("Generating sequential protein names") protein_names = np.arange(adata.obsm[protein_expression_obsm_key].shape[1]) adata.uns["_scvi"]["protein_names"] = protein_names # batch mask totalVI batch_mask = _get_batch_mask_protein_data( adata, protein_expression_obsm_key, batch_key ) # check if it's actually needed if np.sum([~b[1] for b in batch_mask.items()]) > 0: logger.info("Found batches with missing protein expression") adata.uns["_scvi"]["totalvi_batch_mask"] = batch_mask return protein_expression_obsm_key def _setup_x(adata, layer): if layer is not None: assert ( layer in adata.layers.keys() ), "{} is not a valid key in adata.layers".format(layer) logger.info('Using data from adata.layers["{}"]'.format(layer)) x_loc = "layers" x_key = layer x = adata.layers[x_key] else: logger.info("Using data from adata.X") x_loc = "X" x_key = "None" x = adata.X if _check_nonnegative_integers(x) is False: logger_data_loc = ( "adata.X" if layer is None else "adata.layers[{}]".format(layer) ) warnings.warn( "{} does not contain unnormalized count data. Are you sure this is what you want?".format( logger_data_loc ) ) return x_loc, x_key def _setup_library_size(adata, batch_key, layer): # computes the library size per batch logger.info("Computing library size prior per batch") local_l_mean_key = "_scvi_local_l_mean" local_l_var_key = "_scvi_local_l_var" _compute_library_size_batch( adata, batch_key=batch_key, local_l_mean_key=local_l_mean_key, local_l_var_key=local_l_var_key, layer=layer, ) return local_l_mean_key, local_l_var_key def _setup_summary_stats( adata, batch_key, labels_key, protein_expression_obsm_key, categorical_covariate_keys, continuous_covariate_keys, ): categorical_mappings = adata.uns["_scvi"]["categorical_mappings"] n_batch = len(np.unique(categorical_mappings[batch_key]["mapping"])) n_cells = adata.shape[0] n_vars = adata.shape[1] n_labels = len(np.unique(categorical_mappings[labels_key]["mapping"])) if protein_expression_obsm_key is not None: n_proteins = adata.obsm[protein_expression_obsm_key].shape[1] else: n_proteins = 0 if categorical_covariate_keys is not None: n_cat_covs = len(categorical_covariate_keys) else: n_cat_covs = 0 if continuous_covariate_keys is not None: n_cont_covs = len(continuous_covariate_keys) else: n_cont_covs = 0 summary_stats = { "n_batch": n_batch, "n_cells": n_cells, "n_vars": n_vars, "n_labels": n_labels, "n_proteins": n_proteins, "n_continuous_covs": n_cont_covs, } adata.uns["_scvi"]["summary_stats"] = summary_stats logger.info( "Successfully registered anndata object containing {} cells, {} vars, " "{} batches, {} labels, and {} proteins. Also registered {} extra categorical " "covariates and {} extra continuous covariates.".format( n_cells, n_vars, n_batch, n_labels, n_proteins, n_cat_covs, n_cont_covs ) ) return summary_stats def _register_anndata(adata, data_registry_dict: Dict[str, Tuple[str, str]]): """ Registers the AnnData object by adding data_registry_dict to adata.uns['_scvi']['data_registry']. Format of data_registry_dict is: {<scvi_key>: (<anndata dataframe>, <dataframe key> )} Parameters ---------- adata anndata object data_registry_dict dictionary mapping keys used by scvi.model to their respective location in adata. Examples -------- >>> data_dict = {"batch" :("obs", "batch_idx"), "X": ("_X", None)} >>> _register_anndata(adata, data_dict) """ adata.uns["_scvi"]["data_registry"] = data_registry_dict.copy() def view_anndata_setup(source: Union[anndata.AnnData, dict, str]): """ Prints setup anndata. Parameters ---------- source Either AnnData, path to saved AnnData, path to folder with adata.h5ad, or scvi-setup-dict (adata.uns['_scvi']) Examples -------- >>> scvi.data.view_anndata_setup(adata) >>> scvi.data.view_anndata_setup('saved_model_folder/adata.h5ad') >>> scvi.data.view_anndata_setup('saved_model_folder/') >>> scvi.data.view_anndata_setup(adata.uns['_scvi']) """ if isinstance(source, anndata.AnnData): adata = source elif isinstance(source, str): # check if user passed in folder or anndata if source.endswith("h5ad"): path = source adata = anndata.read(path) else: path = os.path.join(source, "adata.h5ad") if os.path.exists(path): adata = anndata.read(path) else: path = os.path.join(source, "attr.pkl") with open(path, "rb") as handle: adata = None setup_dict = pickle.load(handle)["scvi_setup_dict_"] elif isinstance(source, dict): adata = None setup_dict = source else: raise ValueError( "Invalid source passed in. Must be either AnnData, path to saved AnnData, " + "path to folder with adata.h5ad or scvi-setup-dict (adata.uns['_scvi'])" ) if adata is not None: if "_scvi" not in adata.uns.keys(): raise ValueError("Please run setup_anndata() on your adata first.") setup_dict = adata.uns["_scvi"] summary_stats = setup_dict["summary_stats"] data_registry = setup_dict["data_registry"] mappings = setup_dict["categorical_mappings"] version = setup_dict["scvi_version"] rich.print("Anndata setup with scvi-tools version {}.".format(version)) n_cat = 0 n_covs = 0 if "extra_categoricals" in setup_dict.keys(): n_cat = len(setup_dict["extra_categoricals"]["mappings"]) if "extra_continuous_keys" in setup_dict.keys(): n_covs = len(setup_dict["extra_continuous_keys"]) in_colab = "google.colab" in sys.modules force_jupyter = None if not in_colab else True console = Console(force_jupyter=force_jupyter) t = rich.table.Table(title="Data Summary") t.add_column( "Data", justify="center", style="dodger_blue1", no_wrap=True, overflow="fold" ) t.add_column( "Count", justify="center", style="dark_violet", no_wrap=True, overflow="fold" ) data_summary = { "Cells": summary_stats["n_cells"], "Vars": summary_stats["n_vars"], "Labels": summary_stats["n_labels"], "Batches": summary_stats["n_batch"], "Proteins": summary_stats["n_proteins"], "Extra Categorical Covariates": n_cat, "Extra Continuous Covariates": n_covs, } for data, count in data_summary.items(): t.add_row(data, str(count)) console.print(t) t = rich.table.Table(title="SCVI Data Registry") t.add_column( "Data", justify="center", style="dodger_blue1", no_wrap=True, overflow="fold" ) t.add_column( "scvi-tools Location", justify="center", style="dark_violet", no_wrap=True, overflow="fold", ) for scvi_data_key, data_loc in data_registry.items(): attr_name = data_loc["attr_name"] attr_key = data_loc["attr_key"] if attr_key == "None": scvi_data_str = "adata.{}".format(attr_name) else: scvi_data_str = "adata.{}['{}']".format(attr_name, attr_key) t.add_row(scvi_data_key, scvi_data_str) console.print(t) t = _categorical_mappings_table("Label Categories", "_scvi_labels", mappings) console.print(t) t = _categorical_mappings_table("Batch Categories", "_scvi_batch", mappings) console.print(t) if "extra_categoricals" in setup_dict.keys(): t = _extra_categoricals_table(setup_dict) console.print(t) if "extra_continuous_keys" in setup_dict.keys(): t = _extra_continuous_table(adata, setup_dict) console.print(t) def _extra_categoricals_table(setup_dict: dict): """Returns rich.table.Table with info on extra categorical variables.""" t = rich.table.Table(title="Extra Categorical Variables") t.add_column( "Source Location", justify="center", style="dodger_blue1", no_wrap=True, overflow="fold", ) t.add_column( "Categories", justify="center", style="green", no_wrap=True, overflow="fold" ) t.add_column( "scvi-tools Encoding", justify="center", style="dark_violet", no_wrap=True, overflow="fold", ) for key, mappings in setup_dict["extra_categoricals"]["mappings"].items(): for i, mapping in enumerate(mappings): if i == 0: t.add_row("adata.obs['{}']".format(key), str(mapping), str(i)) else: t.add_row("", str(mapping), str(i)) t.add_row("", "") return t def _extra_continuous_table(adata: Optional[anndata.AnnData], setup_dict: dict): """Returns rich.table.Table with info on extra continuous variables.""" t = rich.table.Table(title="Extra Continuous Variables") t.add_column( "Source Location", justify="center", style="dodger_blue1", no_wrap=True, overflow="fold", ) if adata is not None: t.add_column( "Range", justify="center", style="dark_violet", no_wrap=True, overflow="fold", ) cont_covs = scvi.data.get_from_registry(adata, "cont_covs") for cov in cont_covs.iteritems(): col_name, values = cov[0], cov[1] min_val = np.min(values) max_val = np.max(values) t.add_row( "adata.obs['{}']".format(col_name), "{:.20g} -> {:.20g}".format(min_val, max_val), ) else: for key in setup_dict["extra_continuous_keys"]: t.add_row("adata.obs['{}']".format(key)) return t def _categorical_mappings_table(title: str, scvi_column: str, mappings: dict): """ Returns rich.table.Table with info on a categorical variable. Parameters ---------- title title of table scvi_column column used by scvi for categorical representation mappings output of adata.uns['_scvi']['categorical_mappings'], containing mapping between scvi_column and original column and categories """ source_key = mappings[scvi_column]["original_key"] mapping = mappings[scvi_column]["mapping"] t = rich.table.Table(title=title) t.add_column( "Source Location", justify="center", style="dodger_blue1", no_wrap=True, overflow="fold", ) t.add_column( "Categories", justify="center", style="green", no_wrap=True, overflow="fold" ) t.add_column( "scvi-tools Encoding", justify="center", style="dark_violet", no_wrap=True, overflow="fold", ) for i, cat in enumerate(mapping): if i == 0: t.add_row("adata.obs['{}']".format(source_key), str(cat), str(i)) else: t.add_row("", str(cat), str(i)) return t def _check_anndata_setup_equivalence( adata_source: Union[AnnData, dict], adata_target: AnnData ) -> bool: """ Checks if target setup is equivalent to source. Parameters ---------- adata_source Either AnnData already setup or scvi_setup_dict as the source adata_target Target AnnData to check setup equivalence Returns ------- Whether the adata_target should be run through `transfer_anndata_setup` """ if isinstance(adata_source, anndata.AnnData): _scvi_dict = adata_source.uns["_scvi"] else: _scvi_dict = adata_source adata = adata_target stats = _scvi_dict["summary_stats"] target_n_vars = adata.shape[1] error_msg = ( "Number of {} in anndata different from initial anndata used for training." ) if target_n_vars != stats["n_vars"]: raise ValueError(error_msg.format("vars")) error_msg = ( "There are more {} categories in the data than were originally registered. " + "Please check your {} categories as well as adata.uns['_scvi']['categorical_mappings']." ) self_categoricals = _scvi_dict["categorical_mappings"] self_batch_mapping = self_categoricals["_scvi_batch"]["mapping"] adata_categoricals = adata.uns["_scvi"]["categorical_mappings"] adata_batch_mapping = adata_categoricals["_scvi_batch"]["mapping"] # check if mappings are equal or needs transfer transfer_setup = _needs_transfer(self_batch_mapping, adata_batch_mapping, "batch") self_labels_mapping = self_categoricals["_scvi_labels"]["mapping"] adata_labels_mapping = adata_categoricals["_scvi_labels"]["mapping"] transfer_setup = transfer_setup or _needs_transfer( self_labels_mapping, adata_labels_mapping, "label" ) # validate any extra categoricals error_msg = ( "Registered categorical key order mismatch between " + "the anndata used to train and the anndata passed in." + "Expected categories & order {}. Received {}.\n" ) if "extra_categoricals" in _scvi_dict.keys(): target_dict = adata.uns["_scvi"]["extra_categoricals"] source_dict = _scvi_dict["extra_categoricals"] # check that order of keys setup is same if not np.array_equal(target_dict["keys"], source_dict["keys"]): raise ValueError(error_msg.format(source_dict["keys"], target_dict["keys"])) # check mappings are equivalent target_extra_cat_maps = adata.uns["_scvi"]["extra_categoricals"]["mappings"] for key, val in source_dict["mappings"].items(): target_map = target_extra_cat_maps[key] transfer_setup = transfer_setup or _needs_transfer(val, target_map, key) # validate any extra continuous covs if "extra_continuous_keys" in _scvi_dict.keys(): if "extra_continuous_keys" not in adata.uns["_scvi"].keys(): raise ValueError('extra_continuous_keys not in adata.uns["_scvi"]') target_cont_keys = adata.uns["_scvi"]["extra_continuous_keys"] source_cont_keys = _scvi_dict["extra_continuous_keys"] # check that order of keys setup is same if not np.array_equal(target_cont_keys, source_cont_keys): raise ValueError(error_msg.format(source_cont_keys, target_cont_keys)) return transfer_setup def _needs_transfer(mapping1, mapping2, category): needs_transfer = False error_msg = ( "Categorial encoding for {} is not the same between " + "the anndata used to train and the anndata passed in. " + "Categorical encoding needs to be same elements, same order, and same datatype.\n" + "Expected categories: {}. Received categories: {}.\n" ) warning_msg = ( "Categorical encoding for {} is similar but not equal between " + "the anndata used to train and the anndata passed in. " + "Will attempt transfer. Expected categories: {}. Received categories: {}.\n " ) if _is_equal_mapping(mapping1, mapping2): needs_transfer = False elif _is_similar_mapping(mapping1, mapping2): needs_transfer = True warnings.warn(warning_msg.format(category, mapping1, mapping2)) else: raise ValueError(error_msg.format(category, mapping1, mapping2)) return needs_transfer def _is_similar_mapping(mapping1, mapping2): """Returns True if mapping2 is a subset of mapping1.""" if len(set(mapping2) - set(mapping1)) == 0: return True else: return False def _is_equal_mapping(mapping1, mapping2): return pd.Index(mapping1).equals(pd.Index(mapping2))