from __future__ import annotations
import ctypes
import h5py
import mmap
import numpy as np
import os
import polars as pl
import pyarrow as pa
import re
import signal
import sys
import warnings
from collections.abc import Iterable
from itertools import chain, islice
from pathlib import Path
from scipy import sparse
from textwrap import fill
from threadpoolctl import threadpool_info, threadpool_limits
from typing import Any, Callable, Literal, Mapping, NoReturn, Sequence
from .pseudobulk import Pseudobulk
from .sparse import csc_array, csr_array
from .type_aliases import Color, Indexer, Scalar, UnsDict, UnsItem, \
SingleCellColumn
from .utils import array_equal, cast_to_Enum, check_bounds, check_dtype, \
check_R_variable_name, check_type, check_types, concatenate, \
filter_columns, generate_palette, getnnz, getnnz_at_least_threshold, \
import_cython, ix_symmetric, numa_zeros, parallel_subset_1d, \
parallel_subset_2d, plural, read_parallel_multiprocessing, sparse_equal, \
sparse_major_stack, sparse_minor_stack, to_tuple, to_tuple_checked
from .validated_dict import Obsm, Obsp, Uns, Varm, Varp
# Call `import_cython` (from `.utils`) with a dict whose keys are short names
# ('doublets', 'embed', 'hvg', ...) and whose values are either a single string
# or a tuple of strings.
# NOTE(review): presumably each key names a Cython extension module and each
# value the function(s) to pull from it; how those names are made available to
# this module is not visible here -- confirm against `import_cython`.
import_cython({
'doublets': ('call_doublets', 'compute_cxds', 'compute_obs', 'compute_S',
'get_hvgs', 'simulate_doublets'),
'embed': ('localmap', 'pacmap', 'umap_fuzzy_weights', 'umap_optimize'),
'harmonize': ('label_transfer', 'harmony', 'harmony_original',
'normalize_rows'),
'hdf5': 'read_all_datasets',
'hvg': ('clipped_sum_csc', 'clipped_sum_csr', 'gene_mean_and_variance_csc',
'gene_mean_and_variance_csr'),
'kmeans': 'kmeans',
'knn': ('knn_self', 'knn_cross'),
'leiden': ('leiden', 'leiden_multiresolution'),
'normalize': ('normalize_csc', 'normalize_csr'),
'pca': 'irlba',
'pseudobulk_and_markers': (
'get_detection_rate', 'get_detection_rate_and_fold_change',
'get_detection_rate_and_fold_change_and_pareto_candidates',
'groupby_getnnz_csc', 'groupby_getnnz_csc_for_gene_subset',
'groupby_getnnz_and_total_csc_for_gene_subset',
'groupby_getnnz_csr', 'groupby_getnnz_csr_for_gene_subset',
'groupby_getnnz_and_total_csr_for_gene_subset', 'groupby_sum_csr',
'groupby_sum_csc', 'pareto_front'),
'qc': ('malat1_mask_csr', 'malat1_mask_csr_check', 'malat1_mask_csr_scan',
'mito_mask_csr', 'mito_mask_csc', 'qc_metrics_csr',
'qc_metrics_csc'),
'snn': 'snn'})
[docs]
class SingleCell:
"""
A single-cell dataset.
Has slots for:
- `X`: a scipy sparse array of counts per cell and gene
- `obs`: a polars DataFrame of cell metadata
- `var`: a polars DataFrame of gene metadata
- `obsm`: a dictionary of NumPy arrays and polars DataFrames of cell
metadata
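- `varm`: a dictionary of NumPy arrays and polars DataFrames of gene
metadata
- `obsp`: a dictionary of sparse arrays of pairwise cell-cell information
- `varp`: a dictionary of sparse arrays of pairwise gene-gene information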
- `uns`: a dictionary of scalars (strings, numbers or Booleans) or NumPy
arrays, or nested dictionaries thereof
- `num_threads`: the default number of threads to use for operations on the
dataset that support multithreading (which can be overridden by
individual functions)
as well as `obs_names` and `var_names`, aliases for `obs[:, 0]` and
`var[:, 0]`.
"""
[docs]
def __init__(self,
source: str | Path | 'AnnData' | None = None,
/,
*,
X: sparse.csr_array | sparse.csc_array | sparse.csr_matrix |
sparse.csc_matrix | Literal[False] | None = None,
obs: pl.DataFrame | None = None,
var: pl.DataFrame | None = None,
obsm: dict[str, np.ndarray | pl.DataFrame] |
Literal[False] | None = None,
varm: dict[str, np.ndarray | pl.DataFrame] |
Literal[False] | None = None,
obsp: dict[str, sparse.csr_array | sparse.csc_array |
sparse.csr_matrix | sparse.csc_matrix] |
Literal[False] | None = None,
varp: dict[str, sparse.csr_array | sparse.csc_array |
sparse.csr_matrix | sparse.csc_matrix] |
Literal[False] | None = None,
uns: UnsDict | Literal[False] | None = None,
X_key: str | None = None,
assay: str | None = None,
obs_columns: str | Iterable[str] = None,
var_columns: str | Iterable[str] = None,
num_threads: int | np.integer = -1) -> None:
"""
Load a SingleCell dataset from a file, or create one from an in-memory
AnnData object or count matrix + metadata.
SingleCell supports reading and writing files from each of the three
major single-cell ecosystems:
- scverse/Scanpy AnnData (`.h5ad`)
- Seurat (`.rds` and `.h5Seurat`)
- Bioconductor SingleCellExperiment (`.rds`)
as well as raw 10x data files (`.h5` or `.mtx`/`.mtx.gz`).
By default, when an AnnData object, `.h5ad` file, `.h5Seurat` file, or
`.rds` file contains both raw and normalized counts, only the raw
counts will be loaded. To load normalized counts instead, use the `X`
argument (for AnnData objects) or `X_key` argument (for files).
Reading and writing `.rds` files requires the ryp Python-R bridge.
To create a SingleCell dataset from an in-memory Seurat or
SingleCellExperiment object in the ryp R workspace, use
`SingleCell.from_seurat()` or `SingleCell.from_sce()`.
Reading and writing loom files is not supported because SingleCell only
supports sparse count matrices, and loom only supports dense matrices.
Using loom files for SingleCell data is not recommended due to this
wastefulness. If you must, load them with
`SingleCell(scanpy.read_loom(loom_filename))`; `read_loom()` implicitly
converts the counts to a sparse matrix by default.
Args:
source: a filename or AnnData object, or `None` if specifying `X`,
`obs`, and `var` instead. Supported file formats are
scverse/Scanpy AnnData (`.h5ad`), Seurat (`.rds` and
`.h5Seurat`), Bioconductor SingleCellExperiment (`.rds`),
and raw 10x data files (`.h5` or `.mtx`/`.mtx.gz`). If
`source` is a 10x `.mtx`/`.mtx.gz` filename,
`barcodes.tsv`/`barcodes.tsv.gz` and
`features.tsv`/`features.tsv.gz` are assumed to be in the
same directory (with the ungzipped versions used
preferentially), unless custom paths to these files are
specified via the `obs` and/or `var` arguments.
X: If `source` is `None`, the data as a sparse array or matrix
(with rows = cells, columns = genes).
If `source` is an AnnData object, an optional sparse array or
matrix to use as `X`. By default, `X` will be loaded from
`source.layers['UMIs']` or `source.raw.X` if present and
`source.X` otherwise.
If `X` is `None` when `source` is `None`, or `False` when
`source` is a filename, do not store any data in `X` and set it
to `None`. This helps save memory, but the resulting dataset
cannot be saved, converted to another format, or used to run
analyses that require `X`.
obs: a polars DataFrame of metadata for each cell (row of `X`), or
`None` if specifying `source` instead. Or, if `source` is a
10x `.mtx`/`.mtx.gz` filename, an optional filename for
cell-level metadata, which is otherwise assumed to be at
`barcodes.tsv` (or `barcodes.tsv.gz`) in the same directory as
the `.mtx`/`.mtx.gz` file.
var: a polars DataFrame of metadata for each gene (column of `X`),
or `None` if specifying `source` instead. Or, if `source` is a
10x `.mtx`/`.mtx.gz` filename, an optional filename for
gene-level metadata, which is otherwise assumed to be at
`features.tsv` (or `features.tsv.gz`) in the same directory as
the `.mtx`/`.mtx.gz` file.
obsm: an optional dictionary mapping string names to NumPy arrays
and polars DataFrames of metadata for each cell, or `False`
to skip loading `obsm` when reading `.h5ad` and `.h5Seurat`
files
varm: an optional dictionary mapping string names to NumPy arrays
and polars DataFrames of metadata for each gene, or `False`
to skip loading `varm` when reading `.h5ad` files
obsp: an optional dictionary mapping string names to sparse arrays
or matrices containing pairwise cell-cell information like
nearest-neighbors graphs, or `False` to skip loading `obsp`
when reading `.h5ad` and `.h5Seurat` files
varp: an optional dictionary mapping string names to sparse arrays
or matrices containing pairwise gene-gene information, or
`False` to skip loading `varp` when reading `.h5ad` files
uns: an optional dictionary mapping string names to unstructured
metadata - scalars (strings, numbers or Booleans), NumPy
arrays, or nested dictionaries thereof - or `False` to skip
loading `uns` when reading `.h5ad` or `.h5Seurat` files
X_key: if `source` is an AnnData `.h5ad`, Seurat `.rds` or
`.h5Seurat` filename, or SingleCellExperiment `.rds`
filename, the location within `source` to use as `X`:
- If `source` is an `.h5ad` filename, the name of the key in
the `.h5ad` file to use as `X`. If `None`, defaults to
`'layers/UMIs'` (i.e. `self.layers['UMIs']` in Scanpy) or
`'raw/X'` (i.e. `self.raw.X` in Scanpy) if present,
otherwise `'X'`.
Tip: `SingleCell.ls(h5ad_file)` shows the structure of an
`.h5ad` file without loading it, allowing you to figure
out which key to use as `X`.
- If `source` is a Seurat `.rds` or `.h5Seurat` filename,
the layer within the active assay (or the assay specified
by the `assay` argument, if not `None`) to use as `X`. Set
to `'data'` to load the normalized counts, or
`'scale.data'` to load the normalized and scaled counts,
if available. If `None`, defaults to `'counts'`.
- If `source` is a SingleCellExperiment `.rds` filename, the
element within `@assays@data` to use as `X`. Set to
`'logcounts'` to load the normalized counts, if available.
If `None`, defaults to `'counts'`.
assay: if `source` is a Seurat `.rds` or `.h5Seurat`/`.h5seurat`
filename, the name of the assay within the Seurat object to
load data from. Defaults to the Seurat object's
`active.assay` attribute (usually `'RNA'`).
obs_columns: if `source` is an `.h5ad` or `.h5Seurat` filename, the
columns of `obs` to load. If not specified, load all
columns. Specifying only a subset of columns can speed
up reading. Not supported for `.h5` files, since they
only have a single `obs` column (`'barcodes'`), nor
for Seurat and SingleCellExperiment `.rds` files,
since `.rds` files do not support partial loading.
var_columns: if `source` is an `.h5ad`, `.h5`, or `.h5Seurat`
filename, the columns of `var` to load. If not
specified, load all columns. Specifying only a subset
of columns can speed up reading. Not supported for
Seurat and SingleCellExperiment `.rds` files, since
the `.rds` file format does not support partial
loading.
num_threads: the number of threads to use when reading `.h5ad`,
`.h5Seurat` and `.h5` files, and the default number of threads to use
for all subsequent operations on this SingleCell
dataset. Also sets the number of threads for this
SingleCell dataset's count matrix, if present. By
default (`num_threads=-1`), use all available cores,
as determined by `os.cpu_count()`.
Examples:
Load an `.h5ad` file:
>>> sc = SingleCell('data.h5ad')
Load an `.h5ad` file where raw counts are stored in a non-default
location, `adata.raw.counts` (use `SingleCell.ls('data.h5ad')` to
inspect its structure before loading):
>>> sc = SingleCell('data.h5ad', X_key='raw/counts')
Load only selected metadata columns from an `.h5ad` file
to reduce loading time and memory usage:
>>> sc = SingleCell('data.h5ad',
... obs_columns=['cell_type', 'batch'],
... var_columns=['gene_symbol'])
Skip loading the count matrix to minimize memory usage (note:
dataset cannot be saved, converted, or used for analyses that
require `X`):
>>> sc = SingleCell('large_data.h5ad', X=False)
Load a Seurat `.h5Seurat` file:
>>> sc = SingleCell('seurat_obj.h5Seurat')
Load a Seurat `.rds` file:
>>> sc = SingleCell('seurat_obj.rds')
Load a Bioconductor SingleCellExperiment `.rds` file and use
log-normalized counts:
>>> sc = SingleCell('sce_obj.rds', X_key='logcounts')
Load raw 10x Genomics data from an `.h5` file:
>>> sc = SingleCell('matrix.h5')
Load raw 10x Genomics data from an `.mtx.gz` file, with barcodes
and features stored in the same directory as `barcodes.tsv.gz` and
`features.tsv.gz` in the usual way:
>>> sc = SingleCell('matrix.mtx.gz')
Manually create a SingleCell dataset from an in-memory sparse
matrix and metadata:
>>> import polars as pl
>>> from scipy.sparse import csr_array
>>> X = csr_array([[1, 0, 3], [0, 2, 0]])
>>> obs = pl.DataFrame({'cell_id': ['cell1', 'cell2']})
>>> var = pl.DataFrame({'gene_id': ['g1', 'g2', 'g3']})
>>> sc = SingleCell(X=X, obs=obs, var=var)
Note:
Both ordered and unordered categorical columns of `obs` and `var`
will be loaded as polars Enums rather than polars Categoricals.
This is because polars Categoricals use a shared numerical encoding
across columns, so their codes are not `[0, 1, 2, ...]` as they are for
pandas categoricals and polars Enums. Using Categoricals leads to a
large overhead (~25%) when loading `obs` from an `.h5ad` file, for
example.
Note:
SingleCell does not support dense matrices, which are highly
memory-inefficient for single-cell data. Passing a NumPy array as
the `X` argument will give an error; if for some reason your data
has been improperly stored as a dense matrix, convert it to a
sparse array first with `csr_array(numpy_array)`. However, when
loading from disk or converting from other formats, dense matrices
will be automatically converted to sparse matrices, to avoid giving
an error when loading or converting.
"""
# Initialize this SingleCell dataset's `num_threads`
num_threads = SingleCell._process_num_threads_static(num_threads)
self._num_threads = num_threads
# Initialize the SingleCell dataset depending on which arguments were
# specified
if source is None:
# Sparse array or matrix
if X is not None:
for prop, prop_name in (
(X_key, 'X_key'), (assay, 'assay'),
(obs_columns, 'obs_columns'),
(var_columns, 'var_columns')):
if prop is not None:
error_message = (
f'when X is a sparse array or matrix, {prop_name} '
f'must be None')
raise ValueError(error_message)
self.X = X # not self._X; triggers validation
else:
for prop, prop_name in (
(X_key, 'X_key'), (assay, 'assay'),
(obs_columns, 'obs_columns'),
(var_columns, 'var_columns')):
if prop is not None:
error_message = (
f'when X and source are both None, {prop_name} '
f'must be None')
raise ValueError(error_message)
self._X = None
self.obs = obs
self.var = var
self.obsm = obsm if obsm is not None else {}
self.varm = varm if varm is not None else {}
self.obsp = obsp if obsp is not None else {}
self.varp = varp if varp is not None else {}
self.uns = uns if uns is not None else {}
elif str(type(source)).startswith("<class 'anndata"):
# AnnData object
signal.signal(signal.SIGINT, signal.SIG_IGN)
try:
from anndata import AnnData
finally:
signal.signal(signal.SIGINT, signal.default_int_handler)
check_type(source, 'source', AnnData,
'a filename, Path, or Anndata object')
for prop, prop_name in (
(obs, 'obs'), (var, 'var'), (obsm, 'obsm'),
(varm, 'varm'), (obsp, 'obsp'), (varp, 'varp'),
(uns, 'uns'), (X_key, 'X_key'), (assay, 'assay'),
(obs_columns, 'obs_columns'),
(var_columns, 'var_columns')):
if prop is not None:
error_message = (
f'when initializing a SingleCell dataset from an '
f'AnnData object, {prop_name} must be None')
raise ValueError(error_message)
# Get `X`
if X is False:
self._X = None
else:
if X is None:
has_layers_UMIs = 'UMIs' in source._layers
has_raw_X = hasattr(source._raw, '_X')
if has_layers_UMIs and has_raw_X:
error_message = (
"both layers['UMIs'] and raw.X are present in "
"this AnnData object; this should never "
"happen in well-formed AnnData objects")
raise ValueError(error_message)
X = source._layers['UMIs'] if has_layers_UMIs else \
source._raw._X if has_raw_X else source._X
if isinstance(X, np.ndarray):
X_string = f"layers['UMIs']" if has_layers_UMIs else \
f'raw.X' if has_raw_X else 'X'
warning_message = (
f"this AnnData object's {X_string} is stored as a "
f"dense matrix; auto-converting to a sparse csr_array")
warnings.warn(warning_message)
X = csr_array(X)
self.X = X
# Get `obs` and `var`
for attr in '_obs', '_var':
df = getattr(source, attr)
if df.index.name is None:
# Make the index name consistent with what you'd get if
# you had loaded the same AnnData object directly from
# an `.h5ad` file
df = df.rename_axis('_index')
# Convert Categoricals with string categories to polars Enums,
# and Categoricals with other types of categories (e.g.
# integers) to non-categorical columns of the corresponding
# polars dtype (e.g. pl.Int64) since polars only supports
# string categories
schema_overrides = {}
cast_dict = {}
for column, dtype in \
df.dtypes[df.dtypes == 'category'].items():
categories_dtype = dtype.categories.dtype
if categories_dtype == object:
schema_overrides[column] = pl.Enum(dtype.categories)
else:
cast_dict[column] = categories_dtype
setattr(self, attr, pl.from_pandas(
df.astype(cast_dict),
schema_overrides=schema_overrides, include_index=True))
# Get the other fields
self.obsm = source._obsm
self.varm = source._varm
self.obsp = source._obsp
self.varp = source._varp
self.uns = source._uns
else:
# Filename
check_type(source, 'source', (str, Path),
'a filename, Path, or AnnData object')
source = os.path.expanduser(source)
if source.endswith('.h5ad'):
# For the AnnData on-disk format specification, see
# anndata.readthedocs.io/en/latest/fileformat-prose.html
if not os.path.exists(source):
error_message = f'.h5ad file {source} does not exist'
raise FileNotFoundError(error_message)
for prop, prop_name in \
(obs, 'obs'), (var, 'var'), (assay, 'assay'):
if prop is not None:
error_message = (
f'when loading an .h5ad file, {prop_name} '
f'must be None')
raise ValueError(error_message)
for prop, prop_name in (
(obsm, 'obsm'), (varm, 'varm'), (obsp, 'obsp'),
(varp, 'varp'), (uns, 'uns')):
if prop is not None and prop is not False:
error_message = (
f'when loading an .h5ad file, {prop_name} '
f'must be None or False')
raise ValueError(error_message)
if obs_columns is not None:
obs_columns = to_tuple_checked(
obs_columns, 'obs_columns', str, 'strings')
if var_columns is not None:
var_columns = to_tuple_checked(
var_columns, 'var_columns', str, 'strings')
# Loading happens in three stages:
# 1. Tabulate which HDF5 datasets we need to load (except uns)
# 2. "Preload" datasets in parallel, using threads + a custom
# Cython reader for most h5ad files and falling back to
# multiprocessing + shared memory using h5py for h5ad files
# with chunking and/or compression. Variable-length
# (dtype=object) string datasets will only be preloaded when
# using threads, since they are not compatible with shared
# memory. Store the preloaded datasets in a cache.
# 3. Assemble the SingleCell dataset, loading from the
# cache if the dataset is there, and otherwise loading
# it single-threaded. uns is always loaded single-threaded.
# When loading single-threaded, skip steps 1 and 2.
h5ad_file = h5py.File(source)
try:
if num_threads == 1:
preloaded_datasets = {}
else:
# 1. Tabulate which HDF5 datasets need to be loaded
datasets_to_load = set()
# `X`
if X is False:
if X_key is not None:
error_message = (
'when loading an .h5ad file with X=False, '
'X_key must be None')
raise ValueError(error_message)
self._X = None
else:
if X_key is None:
has_layers_UMIs = 'layers/UMIs' in h5ad_file
has_raw_X = 'raw/X' in h5ad_file
if has_layers_UMIs and has_raw_X:
error_message = (
"both layers['UMIs'] and raw.X are "
"present; this should never happen in "
"well-formed .h5ad files")
raise ValueError(error_message)
X_key = 'layers/UMIs' if has_layers_UMIs \
else 'raw/X' if has_raw_X else 'X'
else:
check_type(X_key, 'X_key', str, 'a string')
if X_key not in h5ad_file:
error_message = (
f'X_key {X_key!r} is not present in '
f'the .h5ad file')
raise ValueError(error_message)
X = h5ad_file[X_key]
if isinstance(X, h5py.Dataset):
datasets_to_load.add(X.name)
else:
datasets_to_load.add(X['data'].name)
datasets_to_load.add(X['indices'].name)
datasets_to_load.add(X['indptr'].name)
# `obs`
SingleCell._tabulate_h5ad_dataframe(
h5ad_file, 'obs', datasets_to_load,
columns=obs_columns)
# `var`
SingleCell._tabulate_h5ad_dataframe(
h5ad_file, 'var', datasets_to_load,
columns=var_columns)
# `obsm`
if obsm is None and 'obsm' in h5ad_file:
for key, value in h5ad_file['obsm'].items():
if isinstance(value, h5py.Dataset):
datasets_to_load.add(f'obsm/{key}')
else:
SingleCell._tabulate_h5ad_dataframe(
h5ad_file, f'obsm/{key}',
datasets_to_load)
# `varm`
if varm is None and 'varm' in h5ad_file:
for key, value in h5ad_file['varm'].items():
if isinstance(value, h5py.Dataset):
datasets_to_load.add(f'varm/{key}')
else:
SingleCell._tabulate_h5ad_dataframe(
h5ad_file, f'varm/{key}',
datasets_to_load)
# `obsp`
if obsp is None and 'obsp' in h5ad_file:
for value in h5ad_file['obsp'].values():
datasets_to_load.add(value['data'].name)
datasets_to_load.add(value['indices'].name)
datasets_to_load.add(value['indptr'].name)
# `varp`
if varp is None and 'varp' in h5ad_file:
for value in h5ad_file['varp'].values():
datasets_to_load.add(value['data'].name)
datasets_to_load.add(value['indices'].name)
datasets_to_load.add(value['indptr'].name)
# Infer the number of cells
obs_group = h5ad_file['obs']
if isinstance(obs_group, h5py.Group):
num_cells = obs_group[
obs_group.attrs['_index']].shape[0]
else:
num_cells = obs_group.shape[0]
# 2. Preload datasets in parallel
# If any fixed-width dataset is chunked or compressed
# (indicated by a `None` `file_offset`), fall back to
# the multiprocessing-based reader.
# Crucial: the HDF5 file must be closed before
# `read_parallel_multiprocessing()` to avoid the child
# processes inheriting the HDF5 library's state
# associated with the open file handle when forking.
datasets_to_preload = \
SingleCell._get_datasets_to_preload(
h5ad_file, datasets_to_load)
if any(dataset[-1] is None
for dataset in datasets_to_preload[0]):
h5ad_file.close()
preloaded_datasets = \
read_parallel_multiprocessing(
datasets_to_preload, source,
num_threads=num_threads)
h5ad_file = h5py.File(source)
else:
preloaded_datasets = SingleCell._read_parallel(
datasets_to_preload, source,
num_cells=num_cells, num_threads=num_threads)
# 3. Assemble the SingleCell dataset, loading from the
# cache if the dataset is there, and otherwise loading
# it single-threaded. Always load `uns` single-threaded
# for simplicity, since it's rarely that large.
# Load `X`
if X is False:
if X_key is not None:
error_message = (
'when loading an .h5ad file with X=False, '
'X_key must be None')
raise ValueError(error_message)
self._X = None
else:
if X_key is None:
has_layers_UMIs = \
'layers/UMIs' in h5ad_file
has_raw_X = 'raw/X' in h5ad_file
if has_layers_UMIs and has_raw_X:
error_message = (
"both layers['UMIs'] and raw.X are "
"present; this should never happen in "
"well-formed .h5ad files")
raise ValueError(error_message)
X_key = 'layers/UMIs' if has_layers_UMIs \
else 'raw/X' if has_raw_X else 'X'
else:
check_type(X_key, 'X_key', str, 'a string')
if X_key not in h5ad_file:
error_message = (
f'X_key {X_key!r} is not present in '
f'the .h5ad file')
raise ValueError(error_message)
X = h5ad_file[X_key]
if isinstance(X, h5py.Dataset):
warning_message = (
f'{X_key!r} is stored as a dense matrix; '
f'auto-converting to a sparse csr_array')
warnings.warn(warning_message)
X = SingleCell._read_dataset(X, preloaded_datasets)
self.X = csr_array(X)
else:
matrix_class = X.attrs['encoding-type'] \
if 'encoding-type' in X.attrs else \
X.attrs['h5sparse_format'] + '_matrix'
if matrix_class == 'csr_matrix':
array_class = csr_array
elif matrix_class == 'csc_matrix':
array_class = csc_array
else:
error_message = (
f"X has unsupported encoding-type "
f"{matrix_class!r}, but should be "
f"'csr_matrix' or 'csc_matrix'")
raise ValueError(error_message)
data = SingleCell._read_dataset(
X['data'], preloaded_datasets)
self.X = array_class((
data,
SingleCell._read_dataset(
X['indices'], preloaded_datasets),
SingleCell._read_dataset(
X['indptr'], preloaded_datasets)),
shape=X.attrs['shape']
if 'shape' in X.attrs
else X.attrs['h5sparse_shape'])
# Load `obs` and `var`
self.obs = SingleCell._read_h5ad_dataframe(
h5ad_file, 'obs', preloaded_datasets,
columns=obs_columns)
self.var = SingleCell._read_h5ad_dataframe(
h5ad_file, 'var', preloaded_datasets,
columns=var_columns)
# Load `obsm`
if obsm is None and 'obsm' in h5ad_file:
self.obsm = {
key: SingleCell._read_dataset(
value, preloaded_datasets)
if isinstance(value, h5py.Dataset)
else SingleCell._read_h5ad_dataframe(
h5ad_file, f'obsm/{key}',
preloaded_datasets)
for key, value in h5ad_file['obsm'].items()}
else:
self.obsm = {}
# Load `varm`
if varm is None and 'varm' in h5ad_file:
self.varm = {
key: SingleCell._read_dataset(
value, preloaded_datasets)
if isinstance(value, h5py.Dataset)
else SingleCell._read_h5ad_dataframe(
h5ad_file, f'varm/{key}',
preloaded_datasets)
for key, value in h5ad_file['varm'].items()}
else:
self.varm = {}
# Load `obsp`
if obsp is None and 'obsp' in h5ad_file:
self.obsp = {
key: (csr_array
if value.attrs['encoding-type'] ==
'csr_matrix' else csc_array)(
(SingleCell._read_dataset(
value['data'], preloaded_datasets),
SingleCell._read_dataset(
value['indices'], preloaded_datasets),
SingleCell._read_dataset(
value['indptr'], preloaded_datasets)),
shape=value.attrs['shape'])
for key, value in h5ad_file['obsp'].items()}
else:
self.obsp = {}
# Load `varp`
if varp is None and 'varp' in h5ad_file:
self.varp = {
key: (csr_array
if value.attrs['encoding-type'] ==
'csr_matrix' else csc_array)(
(SingleCell._read_dataset(
value['data'], preloaded_datasets),
SingleCell._read_dataset(
value['indices'], preloaded_datasets),
SingleCell._read_dataset(
value['indptr'], preloaded_datasets)),
shape=value.attrs['shape'])
for key, value in h5ad_file['varp'].items()}
else:
self.varp = {}
# Load `uns`
if uns is None and 'uns' in h5ad_file:
self.uns = SingleCell._read_uns(h5ad_file['uns'])
else:
self.uns = {}
finally:
h5ad_file.close()
elif source.endswith('.h5Seurat') or \
source.endswith('.h5seurat'):
if not os.path.exists(source):
error_message = \
f'.h5Seurat file {source} does not exist'
raise FileNotFoundError(error_message)
for prop, prop_name in (
(obs, 'obs'), (var, 'var'), (varm, 'varm'),
(varp, 'varp')):
if prop is not None:
error_message = (
f'when loading an .h5Seurat file, {prop_name} '
f'must be None')
raise ValueError(error_message)
for prop, prop_name in \
(obsm, 'obsm'), (obsp, 'obsp'), (uns, 'uns'):
if prop is not None and prop is not False:
error_message = (
f'when loading an .h5Seurat file, {prop_name} '
f'must be None or False')
raise ValueError(error_message)
if obs_columns is not None:
obs_columns = to_tuple_checked(
obs_columns, 'obs_columns', str, 'strings')
if var_columns is not None:
var_columns = to_tuple_checked(
var_columns, 'var_columns', str, 'strings')
# The logic here is similar to `.h5ad` files: preload in
# parallel where we can
h5Seurat_file = h5py.File(source)
try:
if num_threads == 1:
preloaded_datasets = {}
else:
# Check that `assay` is an assay in the `.h5Seurat`
# file, or set it to `active.assay` if `None`
if assay is None:
assay = \
h5Seurat_file.attrs['active.assay'].item()
elif assay not in h5Seurat_file['assays']:
error_message = (
f'assay {assay!r} does not exist in the '
f'.h5Seurat file; specify a different '
f'assay than {assay!r}')
raise ValueError(error_message)
# 1. Tabulate which HDF5 datasets need to be loaded
datasets_to_load = set()
# `X`
assay_group = h5Seurat_file['assays'][assay]
if X is False:
if X_key is not None:
error_message = (
'when loading an .h5Seurat file with '
'X=False, X_key must be None')
raise ValueError(error_message)
self._X = None
else:
if X_key is None:
X_key = 'counts'
if X_key not in assay_group:
error_message = (
f"the 'counts' key is not present "
f"in the .h5Seurat file as part "
f"of assay {assay!r}; specify a "
f"different assay than {assay!r} "
f"or specify X_key as something "
f"other than 'counts'")
raise ValueError(error_message)
else:
check_type(X_key, 'X_key', str,
'a string or False')
if X_key not in assay_group:
error_message = (
f'X_key {X_key!r} is not present '
f'in the .h5Seurat file as part '
f'of assay {assay!r}; specify a '
f'different assay than {assay!r} '
f'or a different X_key than '
f'{X_key!r}')
raise ValueError(error_message)
X = assay_group[X_key]
if isinstance(X, h5py.Dataset):
datasets_to_load.add(X.name)
else:
datasets_to_load.add(X['data'].name)
datasets_to_load.add(X['indices'].name)
datasets_to_load.add(X['indptr'].name)
# `obs`
SingleCell._tabulate_h5Seurat_dataframe(
h5Seurat_file['meta.data'], datasets_to_load,
columns=obs_columns)
# `var`
if 'meta.features' in assay_group:
SingleCell._tabulate_h5Seurat_dataframe(
assay_group['meta.features'],
datasets_to_load, columns=var_columns)
# `obsm`
if obsm is None:
for key, value in \
h5Seurat_file['reductions'].items():
if value.attrs['active.assay'] != assay:
continue
datasets_to_load.add(
f'reductions/{key}/cell.embeddings')
# `obsp`
if obsp is None:
for value in h5Seurat_file['graphs'].values():
if value.attrs['assay.used'] != assay:
continue
datasets_to_load.add(value['data'].name)
datasets_to_load.add(value['indices'].name)
datasets_to_load.add(value['indptr'].name)
# Infer the number of cells
meta_data = h5Seurat_file['meta.data']
num_cells = meta_data['_index'].shape[0]
# 2. Preload datasets in parallel
datasets_to_preload = \
SingleCell._get_datasets_to_preload(
h5Seurat_file, datasets_to_load)
if any(dataset[-1] is None
for dataset in datasets_to_preload[0]):
h5Seurat_file.close()
preloaded_datasets = \
read_parallel_multiprocessing(
datasets_to_preload, source,
num_threads=num_threads)
h5Seurat_file = h5py.File(source)
else:
preloaded_datasets = \
SingleCell._read_parallel(
datasets_to_preload, source,
num_cells=num_cells,
num_threads=num_threads)
# 3. Assemble the SingleCell dataset
# Check that `assay` is an assay in the `.h5Seurat`
# file, or set it to `active.assay` if `None`
if assay is None:
assay = h5Seurat_file.attrs['active.assay'].item()
elif assay not in h5Seurat_file['assays']:
error_message = (
f'assay {assay!r} does not exist in the '
f'.h5Seurat file; specify a different assay '
f'than {assay!r}')
raise ValueError(error_message)
# Load `X`
assay_group = h5Seurat_file['assays'][assay]
if X is False:
if X_key is not None:
error_message = (
'when loading an .h5Seurat file with '
'X=False, X_key must be None')
raise ValueError(error_message)
self._X = None
else:
if X_key is None:
X_key = 'counts'
if X_key not in assay_group:
error_message = (
f"the 'counts' key is not present in "
f"the .h5Seurat file as part of assay "
f"{assay!r}; specify a different "
f"assay than {assay!r} or specify "
f"X_key as something other than "
f"'counts'")
raise ValueError(error_message)
else:
check_type(X_key, 'X_key', str,
'a string or False')
if X_key not in assay_group:
error_message = (
f'X_key {X_key!r} is not present in '
f'the .h5Seurat file as part of assay '
f'{assay!r}; specify a different '
f'assay than {assay!r} or a different '
f'X_key than {X_key!r}')
raise ValueError(error_message)
X = assay_group[X_key]
if isinstance(X, h5py.Dataset):
warning_message = (
f'{X_key!r} is stored as a dense matrix; '
f'auto-converting to a sparse csr_array')
warnings.warn(warning_message)
X = SingleCell._read_dataset(X, preloaded_datasets)
self.X = csr_array(X.T)
else:
data = SingleCell._read_dataset(
X['data'], preloaded_datasets)
self.X = csr_array((
data,
SingleCell._read_dataset(
X['indices'], preloaded_datasets),
SingleCell._read_dataset(
X['indptr'], preloaded_datasets)),
shape=X.attrs['dims'][::-1])
# Load `obs`
self.obs = SingleCell._read_h5Seurat_dataframe(
h5Seurat_file['meta.data'], preloaded_datasets,
columns=obs_columns)
# Load `var`
self.var = SingleCell._read_h5Seurat_dataframe(
assay_group['meta.features'], preloaded_datasets,
columns=var_columns) \
if 'meta.features' in assay_group else \
pl.Series('feature.names',
assay_group['features'][:])\
.cast(pl.String)\
.to_frame()
# Load `obsm`
self.obsm = {
key: value['cell.embeddings'][:].T
for key, value in
h5Seurat_file['reductions'].items()
if value.attrs['active.assay'] == assay}
# Load `obsp`
self.obsp = {
key: csr_array((
SingleCell._read_dataset(
value['data'], preloaded_datasets),
SingleCell._read_dataset(
value['indices'], preloaded_datasets),
SingleCell._read_dataset(
value['indptr'], preloaded_datasets)),
shape=value.attrs['dims'][::-1])
for key, value in h5Seurat_file['graphs'].items()
if value.attrs['assay.used'] == assay}
# Load `uns`
self.uns = SingleCell._read_h5Seurat_uns(
h5Seurat_file['misc'])
self.varm = {}
self.varp = {}
finally:
h5Seurat_file.close()
elif source.endswith('.h5'):
if not os.path.exists(source):
error_message = f'10x .h5 file {source} does not exist'
raise FileNotFoundError(error_message)
for prop, prop_name in (
(obs, 'obs'), (var, 'var'), (obsm, 'obsm'),
(varm, 'varm'), (obsp, 'obsp'), (varp, 'varp'),
(uns, 'uns'), (X_key, 'X_key'), (assay, 'assay'),
(obs_columns, 'obs_columns')):
if prop is not None:
error_message = (
f'when loading a 10x .h5 file, {prop_name} '
f'must be None')
raise ValueError(error_message)
if var_columns is not None:
var_columns = to_tuple_checked(
var_columns, 'var_columns', str, 'strings')
# The logic here is similar to `.h5ad` files: preload in
# parallel where we can
h5_file = h5py.File(source)
try:
if num_threads == 1:
preloaded_datasets = {}
else:
# 1. Tabulate which HDF5 datasets need to be loaded
datasets_to_load = set()
# `X`
matrix = h5_file['matrix']
if X is None:
datasets_to_load.add(matrix['data'].name)
datasets_to_load.add(matrix['indices'].name)
datasets_to_load.add(matrix['indptr'].name)
elif X is not False:
error_message = (
'when loading a 10x .h5 file, X must be '
'None or False')
raise ValueError(error_message)
# `obs`
datasets_to_load.add(matrix['barcodes'].name)
# `var`
features = matrix['features']
if var_columns is not None:
for column in var_columns:
if column not in features:
error_message = (
f'var_columns contains the column '
f'{column!r}, which is not '
f'present in the .h5 file')
raise ValueError(error_message)
datasets_to_load.add(features[column].name)
else:
for column in 'name', 'id', 'feature_type', \
'genome':
datasets_to_load.add(features[column].name)
for column in 'pattern', 'read', 'sequence':
if column in features:
datasets_to_load.add(
features[column].name)
# Infer the number of cells
num_cells = matrix['barcodes'].shape[0]
# 2. Preload datasets in parallel
datasets_to_preload = \
SingleCell._get_datasets_to_preload(
h5_file, datasets_to_load)
if any(dataset[-1] is None
for dataset in datasets_to_preload[0]):
h5_file.close()
preloaded_datasets = \
read_parallel_multiprocessing(
datasets_to_preload, source,
num_threads=num_threads)
h5_file = h5py.File(source)
else:
preloaded_datasets = \
SingleCell._read_parallel(
datasets_to_preload, source,
num_cells=num_cells,
num_threads=num_threads)
# 3. Assemble the SingleCell dataset
matrix = h5_file['matrix']
features = matrix['features']
# Load `X`
if X is None:
self.X = csr_array((
SingleCell._read_dataset(
matrix['data'], preloaded_datasets),
SingleCell._read_dataset(
matrix['indices'], preloaded_datasets),
SingleCell._read_dataset(
matrix['indptr'], preloaded_datasets)),
shape=matrix['shape'][:][::-1])
elif X is False:
self._X = None
else:
error_message = (
'when loading a 10x .h5 file, X must be None '
'or False')
raise ValueError(error_message)
# Load `obs` and `var`
self.obs = pl.Series('barcodes',
matrix['barcodes'][:])\
.cast(pl.String)\
.to_frame()
if var_columns is not None:
for column in var_columns:
if column not in features:
error_message = (
f'var_columns contains the column '
f'{column!r}, which is not present in '
f'the .h5 file')
raise ValueError(error_message)
else:
var_columns = \
['name', 'id', 'feature_type', 'genome'] + \
[column for column in
('pattern', 'read', 'sequence')
if column in features]
self.var = pl.DataFrame([
pl.Series(column, SingleCell._read_dataset(
features[column], preloaded_datasets))
.cast(pl.String) for column in var_columns])
self.obsm = {}
self.varm = {}
self.obsp = {}
self.varp = {}
self.uns = {}
finally:
h5_file.close()
elif source.endswith('.mtx') or source.endswith('.mtx.gz'):
if not os.path.exists(source):
error_message = f'10x file {source} does not exist'
raise FileNotFoundError(error_message)
if obs is None:
ungzipped_barcode_file = \
f'{os.path.dirname(source)}/barcodes.tsv'
if os.path.exists(ungzipped_barcode_file):
barcode_file = ungzipped_barcode_file
else:
gzipped_barcode_file = \
f'{os.path.dirname(source)}/barcodes.tsv.gz'
if os.path.exists(gzipped_barcode_file):
barcode_file = gzipped_barcode_file
else:
error_message = (
f'the cell-level metadata file '
f'corresponding to {source} was not found '
f'at either {ungzipped_barcode_file} or '
f'{gzipped_barcode_file}; you can specify '
f'a custom location via the obs argument')
raise FileNotFoundError(error_message)
else:
if not isinstance(obs, (str, Path)):
error_message = (
f'when loading a 10x .mtx or .mtx.gz file, '
f'obs must be None or the path to a '
f'barcodes.tsv or barcodes.tsv.gz file of '
f'cell-level metadata, but it has type '
f'{type(obs).__name__!r}')
raise TypeError(error_message)
barcode_file = str(obs)
if not os.path.exists(barcode_file):
error_message = (
f'the cell-level metadata file {barcode_file} '
f'does not exist')
raise FileNotFoundError(error_message)
if var is None:
ungzipped_feature_file = \
f'{os.path.dirname(source)}/features.tsv'
if os.path.exists(ungzipped_feature_file):
feature_file = ungzipped_feature_file
else:
gzipped_feature_file = \
f'{os.path.dirname(source)}/features.tsv.gz'
if os.path.exists(gzipped_feature_file):
feature_file = gzipped_feature_file
else:
error_message = (
f'the gene-level metadata file '
f'corresponding to {source} was not found '
f'at either {ungzipped_feature_file} or '
f'{gzipped_feature_file}; you can specify '
f'a custom location via the var argument')
raise FileNotFoundError(error_message)
else:
if not isinstance(var, (str, Path)):
error_message = (
f'when loading a 10x .mtx or .mtx.gz file, '
f'var must be None or the path to a '
f'features.tsv or features.tsv.gz file of '
f'gene-level metadata, but it has type '
f'{type(var).__name__!r}')
raise TypeError(error_message)
feature_file = str(var)
if not os.path.exists(feature_file):
error_message = (
f'the gene-level metadata file {feature_file} '
f'does not exist')
raise FileNotFoundError(error_message)
for prop, prop_name in (
(obsm, 'obsm'), (varm, 'varm'), (obsp, 'obsp'),
(varp, 'varp'), (uns, 'uns'), (X_key, 'X_key'),
(assay, 'assay'), (obs_columns, 'obs_columns'),
(var_columns, 'var_columns')):
if prop is not None:
error_message = (
f'when loading a 10 .mtx or .mtx.gz file, '
f'{prop_name} must be None')
raise ValueError(error_message)
from scipy.io import mmread
# Load `obs` and `var`
self.obs = pl.read_csv(barcode_file, has_header=False,
new_columns=['cell'])
self.var = pl.read_csv(feature_file, has_header=False,
new_columns=['gene'])
# Load `X`
if X is None:
self.X = csr_array(mmread(source).T.tocsr())
elif X is False:
self._X = None
else:
error_message = (
'when loading a 10x .mtx or .mtx.gz file, X must '
'be None or False')
raise ValueError(error_message)
self.obsm = {}
self.varm = {}
self.obsp = {}
self.varp = {}
self.uns = {}
elif source.endswith('.rds'):
if not os.path.exists(source):
error_message = f'.rds file {source} does not exist'
raise FileNotFoundError(error_message)
for prop, prop_name in (
(X, 'X'), (obs, 'obs'), (var, 'var'),
(varm, 'varm'), (obsp, 'obsp'), (varp, 'varp'),
(uns, 'uns'), (obs_columns, 'obs_columns'),
(var_columns, 'var_columns')):
if prop is not None:
error_message = (
f'when loading a .rds file, {prop_name} must '
f'be None')
raise ValueError(error_message)
from ryp import r, to_py, to_r
r(f'.SingleCell.object = readRDS({source!r})')
try:
if X_key is None:
X_key = 'counts'
else:
check_type(X_key, 'X_key', str, 'a string')
classes = to_py('class(.SingleCell.object)',
squeeze=False)
if len(classes) == 1:
if classes[0] == 'Seurat':
r('suppressPackageStartupMessages('
'library(SeuratObject))')
X, self.obs, self.var, self.obsm, self.obsp, \
self.uns = SingleCell._from_seurat(
'.SingleCell.object', assay=assay,
layer=X_key, layer_name='X_key')
self.varm = {}
self.varp = {}
elif classes[0] == 'SingleCellExperiment':
if assay is not None:
error_message = (
f'when loading a SingleCellExperiment '
f'.rds file, assay must be None')
raise ValueError(error_message)
r('suppressPackageStartupMessages('
'library(SingleCellExperiment))')
X, self.obs, self.var, self.obsm, self.uns = \
SingleCell._from_sce(
'.SingleCell.object', assay=X_key,
assay_name='X_key')
self.obsp = {}
self.varm = {}
self.varp = {}
else:
error_message = (
f'the R object loaded from {source} must '
f'be a Seurat or SingleCellExperiment '
f'object, but has class {classes[0]!r}')
raise TypeError(error_message)
self.X = X
elif len(classes) == 0:
error_message = (
f'the R object loaded from {source} must be a '
f'Seurat or SingleCellExperiment object, but '
f'has no classes')
raise TypeError(error_message)
else:
classes_string = \
', '.join(f'{c!r}' for c in classes[:-1])
error_message = (
f'the R object loaded from {source} must be a '
f'Seurat object, but has classes '
f'{classes_string} and {classes[-1]!r}')
raise TypeError(error_message)
finally:
r('rm(.SingleCell.object)')
else:
error_message = (
f'source is a filename with unsupported extension '
f'.{".".join(source.split(".")[1:])}; it must be '
f'.h5ad (AnnData), .h5 or .mtx/.mtx.gz (10x), .rds '
f'(Seurat or SingleCellExperiment), or '
f'.h5Seurat/.h5seurat (Seurat)')
raise ValueError(error_message)
# Propagate this SingleCell dataset's `num_threads` to its count matrix
if self._X is not None:
self._X._num_threads = num_threads
# Check that each dimension is non-zero and ≤INT32_MAX
num_cells = len(self._obs)
num_genes = len(self._var)
if num_cells == 0:
error_message = 'len(obs) is 0: no cells remain'
raise ValueError(error_message)
if num_genes == 0:
error_message = 'len(var) is 0: no genes remain'
raise ValueError(error_message)
if num_cells > 2_147_483_647:
error_message = (
'X has more than 2,147,483,647 (INT32_MAX) cells, which is '
'not currently supported')
raise ValueError(error_message)
if num_genes > 2_147_483_647:
error_message = (
'X has more than 2,147,483,647 (INT32_MAX) genes, which is '
'not currently supported')
raise ValueError(error_message)
# Set `uns['QCed']` and `uns['normalized']` to `False` if not set yet;
# if already set but not a Boolean, back it up to
# `uns['_QCed']`/`uns['_normalized']`
for key in 'QCed', 'normalized':
if key in self._uns:
if not isinstance(self._uns[key], bool):
new_key = f'_{key}'
while new_key in self._uns:
new_key = f'_{new_key}'
warning_message = (
f'uns[{key!r}] already exists and is not Boolean; '
f'moving it to uns[{new_key!r}]')
warnings.warn(warning_message)
self._uns[new_key] = self._uns[key]
self._uns[key] = False
else:
self._uns[key] = False
@property
def X(self) -> csr_array | csc_array | None:
    """
    The count matrix as a sparse array, or `None` when no count matrix is
    stored. It behaves like a Scipy sparse array, but additionally carries a
    `num_threads` attribute controlling how many threads operations such as
    subsetting use.
    """
    return self._X

@X.setter
def X(self, X: sparse.csr_array | sparse.csc_array | sparse.csr_matrix |
              sparse.csc_matrix) -> None:
    """
    Validate a new count matrix and store it.

    Scipy CSR/CSC arrays and matrices are wrapped in this package's
    `csr_array`/`csc_array`. The new matrix must agree in shape with the
    current `X`, `obs` and `var` (whichever have been assigned so far). A
    matrix with no stored entries (`nnz == 0`) is stored as `None`.

    Args:
        X: the new count matrix

    Raises:
        ValueError: if `X` is `None`, if its shape disagrees with the
                    current `X`, `obs` or `var`, or if `X.indptr` has the
                    wrong length for `X.shape`
        TypeError: if `X` is not a supported sparse container, or its data
                   type is not (u)int32/64 or float32/64
    """
    if X is None:
        error_message = (
            'attempting to set X to None; if you want to remove X to save '
            'memory, use drop_X() instead')
        raise ValueError(error_message)
    # Wrap plain Scipy containers in this package's sparse classes
    if not isinstance(X, (csr_array, csc_array)):
        if isinstance(X, (sparse.csr_array, sparse.csr_matrix)):
            X = csr_array(X)
        elif isinstance(X, (sparse.csc_array, sparse.csc_matrix)):
            X = csc_array(X)
        else:
            error_message = (
                f'X must be a csr_array, csc_array, csr_matrix, or '
                f'csc_matrix, but has type {type(X).__name__!r}')
            raise TypeError(error_message)
    # `_X` may be unassigned (during construction) or None (no counts)
    previous_X = getattr(self, '_X', None)
    if previous_X is not None and X.shape != previous_X.shape:
        error_message = (
            f'new X is {X.shape[0]:,} × {X.shape[1]:,}, but old X is '
            f'{previous_X.shape[0]:,} × {previous_X.shape[1]:,}')
        raise ValueError(error_message)
    dtype = X.dtype
    supported_dtypes = (np.int32, np.int64, np.float32, np.float64,
                        np.uint32, np.uint64)
    if all(dtype != supported for supported in supported_dtypes):
        error_message = (
            f'X must be (u)int32/64 or float32/64, but has data type '
            f'{dtype!s}')
        raise TypeError(error_message)
    # `_obs` and `_var` may not have been assigned yet
    current_obs = getattr(self, '_obs', None)
    if current_obs is not None and X.shape[0] != len(current_obs):
        error_message = (
            f'len(obs) is {len(current_obs):,}, but X.shape[0] is '
            f'{X.shape[0]:,}')
        raise ValueError(error_message)
    current_var = getattr(self, '_var', None)
    if current_var is not None and X.shape[1] != len(current_var):
        error_message = (
            f'len(var) is {len(current_var):,}, but X.shape[1] is '
            f'{X.shape[1]:,}')
        raise ValueError(error_message)
    # `indptr` has one entry per row (CSR) or per column (CSC), plus one
    major_axis = 1 if isinstance(X, csc_array) else 0
    num_major = len(X.indptr) - 1
    if num_major != X.shape[major_axis]:
        error_message = (
            f'X is corrupted: len(X.indptr) - 1 is '
            f'{num_major:,}, but X.shape[{major_axis}] is '
            f'{X.shape[major_axis]:,}')
        raise ValueError(error_message)
    if X.nnz != 0:
        self._X = X
    else:
        self._X = None
@property
def obs(self) -> pl.DataFrame:
    """
    A Polars DataFrame of metadata for each cell.
    """
    return self._obs

@obs.setter
def obs(self, obs: pl.DataFrame) -> None:
    """
    Validate and store a new `obs` DataFrame.

    Args:
        obs: the new cell-level metadata; its first column (the
             `obs_names`) must be String, Enum, Categorical, or integer,
             and its length must match the current `obs`, if any

    Raises:
        ValueError: if the first column has an unsupported data type, or
                    the length differs from the current `obs`
    """
    check_type(obs, 'obs', pl.DataFrame, 'a polars DataFrame')
    obs_names_dtype = obs[:, 0].dtype
    # `DataType.is_integer()` replaces the deprecated top-level
    # `pl.INTEGER_DTYPES` constant
    if obs_names_dtype not in (pl.String, pl.Enum, pl.Categorical) and \
            not obs_names_dtype.is_integer():
        error_message = (
            f'the first column of obs ({obs.columns[0]!r}) must be '
            f'String, Enum, Categorical, or integer, but has data type '
            f'{obs_names_dtype.base_type()!r}')
        raise ValueError(error_message)
    try:
        if len(obs) != len(self._obs):
            error_message = (
                f'new obs has length {len(obs):,}, but old obs has length '
                f'{len(self._obs):,}')
            raise ValueError(error_message)
    except AttributeError:  # `self._obs` may not have been assigned yet
        pass
    self._obs = obs
@property
def var(self) -> pl.DataFrame:
    """
    A Polars DataFrame of metadata for each gene.
    """
    return self._var

@var.setter
def var(self, var: pl.DataFrame) -> None:
    """
    Validate and store a new `var` DataFrame.

    Args:
        var: the new gene-level metadata; its first column (the
             `var_names`) must be String, Enum, Categorical, or integer,
             and its length must match the current `var`, if any

    Raises:
        ValueError: if the first column has an unsupported data type, or
                    the length differs from the current `var`
    """
    check_type(var, 'var', pl.DataFrame, 'a polars DataFrame')
    var_names_dtype = var[:, 0].dtype
    # `DataType.is_integer()` replaces the deprecated top-level
    # `pl.INTEGER_DTYPES` constant
    if var_names_dtype not in (pl.String, pl.Enum, pl.Categorical) and \
            not var_names_dtype.is_integer():
        error_message = (
            f'the first column of var ({var.columns[0]!r}) must be '
            f'String, Enum, Categorical, or integer, but has data type '
            f'{var_names_dtype.base_type()!r}')
        raise ValueError(error_message)
    try:
        if len(var) != len(self._var):
            error_message = (
                f'new var has length {len(var):,}, but old var has length '
                f'{len(self._var):,}')
            raise ValueError(error_message)
    except AttributeError:  # `self._var` may not have been assigned yet
        pass
    self._var = var
@property
def obsm(self) -> dict[str, np.ndarray | pl.DataFrame]:
    """
    A dictionary of per-cell 2D NumPy arrays; the first dimension of each
    array has length equal to the number of cells.
    """
    return self._obsm

@obsm.setter
def obsm(self, obsm: dict[str, np.ndarray | pl.DataFrame]) -> None:
    """
    Replace `obsm`, wrapping the dictionary in an `Obsm` tied to the
    current number of cells.

    Args:
        obsm: the new `obsm` dictionary
    """
    check_type(obsm, 'obsm', dict, 'a dictionary')
    num_cells = len(self._obs)
    self._obsm = Obsm(obsm, length=num_cells)
@property
def varm(self) -> dict[str, np.ndarray | pl.DataFrame]:
    """
    A dictionary of per-gene 2D NumPy arrays; the first dimension of each
    array has length equal to the number of genes.
    """
    return self._varm

@varm.setter
def varm(self, varm: dict[str, np.ndarray | pl.DataFrame]) -> None:
    """
    Replace `varm`, wrapping the dictionary in a `Varm` tied to the current
    number of genes.

    Args:
        varm: the new `varm` dictionary
    """
    check_type(varm, 'varm', dict, 'a dictionary')
    num_genes = len(self._var)
    self._varm = Varm(varm, length=num_genes)
@property
def obsp(self) -> dict[str, csr_array | csc_array]:
    """
    A dictionary of cell-by-cell 2D sparse arrays: both dimensions of each
    array equal the number of cells. The arrays behave like Scipy sparse
    arrays, but also carry a `num_threads` attribute controlling how many
    threads operations such as subsetting use.
    """
    return self._obsp

@obsp.setter
def obsp(self,
         obsp: dict[str, sparse.csr_array | sparse.csc_array |
                         sparse.csr_matrix | sparse.csc_matrix]) -> None:
    """
    Replace `obsp`, wrapping the dictionary in an `Obsp` tied to the
    current number of cells.

    Args:
        obsp: the new `obsp` dictionary
    """
    check_type(obsp, 'obsp', dict, 'a dictionary')
    num_cells = len(self._obs)
    self._obsp = Obsp(obsp, length=num_cells)
@property
def varp(self) -> dict[str, csr_array | csc_array]:
    """
    A dictionary of gene-by-gene 2D sparse arrays: both dimensions of each
    array equal the number of genes. The arrays behave like Scipy sparse
    arrays, but also carry a `num_threads` attribute controlling how many
    threads operations such as subsetting use.
    """
    return self._varp

@varp.setter
def varp(self,
         varp: dict[str, sparse.csr_array | sparse.csc_array |
                         sparse.csr_matrix | sparse.csc_matrix]) -> None:
    """
    Replace `varp`, wrapping the dictionary in a `Varp` tied to the current
    number of genes.

    Args:
        varp: the new `varp` dictionary
    """
    check_type(varp, 'varp', dict, 'a dictionary')
    num_genes = len(self._var)
    self._varp = Varp(varp, length=num_genes)
@property
def uns(self) -> UnsDict:
    """
    A dictionary of miscellaneous metadata. Keys must be strings; values
    may be scalars, NumPy arrays, or nested dictionaries.
    """
    return self._uns

@uns.setter
def uns(self, uns: UnsDict) -> None:
    """
    Replace `uns`, wrapping the dictionary in an `Uns`.

    Args:
        uns: the new `uns` dictionary
    """
    check_type(uns, 'uns', dict, 'a dictionary')
    wrapped_uns = Uns(uns)
    self._uns = wrapped_uns
@staticmethod
def _read_uns(uns_group: h5py.Group) -> UnsDict:
"""
Recursively load `uns` from an `.h5ad` file.
Args:
uns_group: `uns` as an `h5py.Group`
Returns:
The loaded `uns`.
"""
return {key: SingleCell._read_uns(value)
if isinstance(value, h5py.Group) else
None if value.shape is None else
(pl.Series(value[:]).cast(pl.String).to_numpy()
if value.shape else value[()].decode('utf-8'))
if value.dtype == object else
(value[:] if value.shape else value[()].item())
for key, value in uns_group.items()}
@staticmethod
def _read_h5Seurat_uns(uns_group: h5py.Group) -> UnsDict:
"""
Recursively load `uns` (i.e. `misc`) from an `.h5Seurat` file.
Args:
uns_group: `uns` as an `h5py.Group`
Returns:
The loaded `uns`.
"""
return {key: SingleCell._read_h5Seurat_uns(value)
if isinstance(value, h5py.Group) else
None if value.shape is None else
(pl.Series(value[:]).cast(pl.String).to_numpy()
if len(value) > 1 else value[:].item().decode('utf-8'))
if value.dtype == object else
(value[:] if len(value) > 1 else value[:].item())
for key, value in uns_group.items()}
@staticmethod
def _save_uns(uns: UnsDict,
uns_group: h5py.Group,
h5ad_file: h5py.File) -> None:
"""
Recursively save `uns` to an `.h5ad` file.
Args:
uns: an `uns` dictionary
uns_group: `uns` as an `h5py.Group`
h5ad_file: an `h5py.File` open in write mode
"""
uns_group.attrs['encoding-type'] = 'dict'
uns_group.attrs['encoding-version'] = '0.1.0'
for key, value in uns.items():
if isinstance(value, dict):
SingleCell._save_uns(value, uns_group.create_group(key),
h5ad_file)
else:
dataset = uns_group.create_dataset(key, data=value)
dataset.attrs['encoding-type'] = \
('string-array' if value.dtype == object else 'array') \
if isinstance(value, np.ndarray) else \
'string' if isinstance(value, str) else 'numeric-scalar'
dataset.attrs['encoding-version'] = '0.2.0'
@staticmethod
def _save_h5Seurat_uns(uns: UnsDict,
misc_group: h5py.Group,
h5Seurat_file: h5py.File) -> None:
"""
Recursively save `uns` (i.e. `misc`) to an `.h5Seurat` file. Only
string values will be saved.
Args:
uns: an `uns` dictionary
misc_group: `uns` as an `h5py.Group`
h5Seurat_file: an `h5py.File` open in write mode
"""
for key, value in uns.items():
if isinstance(value, dict):
SingleCell._save_h5Seurat_uns(
value, misc_group.create_group(key), h5Seurat_file)
elif isinstance(value, str):
misc_group.create_dataset(key, data=value)
@property
def obs_names(self) -> pl.Series:
    """
    The first column of `obs`, which generally holds the cell barcodes.
    """
    first_obs_column = self._obs[:, 0]
    return first_obs_column
@property
def var_names(self) -> pl.Series:
    """
    The first column of `var`, which generally holds the gene names.
    """
    first_var_column = self._var[:, 0]
    return first_var_column
[docs]
def set_obs_names(self, column: str, /) -> SingleCell:
    """
    Make a column of `obs` the new first column, i.e. the `obs_names`.

    Args:
        column: the name of a column of `obs`; must be String, Enum,
                Categorical, or integer

    Returns:
        A new SingleCell dataset where `column` is the first column of
        `obs` and the remaining columns keep their relative order. If
        `column` is already the first column, this dataset is returned
        unchanged.
    """
    check_type(column, 'column', str, 'a string')
    current_obs = self._obs
    # Nothing to do if the column already leads `obs`
    if current_obs.columns[0] == column:
        return self
    if column not in current_obs:
        error_message = f'{column!r} is not a column of obs'
        raise ValueError(error_message)
    allowed_dtypes = pl.String, pl.Enum, pl.Categorical, 'integer'
    check_dtype(current_obs[column], f'obs[{column!r}]', allowed_dtypes)
    reordered_obs = current_obs.select(column, pl.exclude(column))
    return SingleCell(
        X=self._X, obs=reordered_obs, var=self._var, obsm=self._obsm,
        varm=self._varm, obsp=self._obsp, varp=self._varp, uns=self._uns,
        num_threads=self._num_threads)
[docs]
def set_var_names(self, column: str, /) -> SingleCell:
    """
    Make a column of `var` the new first column, i.e. the `var_names`.

    Args:
        column: the name of a column of `var`; must be String, Enum,
                Categorical, or integer

    Returns:
        A new SingleCell dataset where `column` is the first column of
        `var` and the remaining columns keep their relative order. If
        `column` is already the first column, this dataset is returned
        unchanged.
    """
    check_type(column, 'column', str, 'a string')
    current_var = self._var
    # Nothing to do if the column already leads `var`
    if current_var.columns[0] == column:
        return self
    if column not in current_var:
        error_message = f'{column!r} is not a column of var'
        raise ValueError(error_message)
    allowed_dtypes = pl.String, pl.Enum, pl.Categorical, 'integer'
    check_dtype(current_var[column], f'var[{column!r}]', allowed_dtypes)
    reordered_var = current_var.select(column, pl.exclude(column))
    return SingleCell(
        X=self._X, obs=self._obs, var=reordered_var, obsm=self._obsm,
        varm=self._varm, obsp=self._obsp, varp=self._varp, uns=self._uns,
        num_threads=self._num_threads)
@property
def num_threads(self) -> int:
    """
    The default number of threads used for this SingleCell dataset's
    operations.
    """
    return self._num_threads

@num_threads.setter
def num_threads(self, num_threads: int | np.integer) -> None:
    """
    Set the default number of threads used for this SingleCell dataset's
    operations. Also sets the number of threads for this SingleCell
    object's count matrix, if present.

    Args:
        num_threads: the new default number of threads. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()` (falling back to 1 if
                     the core count cannot be determined).

    Raises:
        ValueError: if `num_threads` is neither positive nor -1
    """
    check_type(num_threads, 'num_threads', int, 'a positive integer or -1')
    if num_threads == -1:
        # `os.cpu_count()` returns None when the count is undeterminable
        num_threads = os.cpu_count() or 1
    else:
        num_threads = int(num_threads)
        if num_threads <= 0:
            error_message = (
                f'num_threads is {num_threads:,}, but must be a positive '
                f'integer or -1')
            raise ValueError(error_message)
    self._num_threads = num_threads
    if self._X is not None:
        self._X.num_threads = num_threads
[docs]
def set_num_threads(self, num_threads: int | np.integer, /) -> SingleCell:
    """
    Return a new SingleCell dataset with a different default number of
    threads. The count matrix, if present, also has its number of threads
    set to `num_threads`, and is then passed on to the new dataset.

    Args:
        num_threads: the new default number of threads. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`.

    Returns:
        A new SingleCell dataset built from this dataset's fields, with
        `num_threads` as its default number of threads.
    """
    count_matrix = self._X
    if count_matrix is not None:
        count_matrix.num_threads = num_threads
    return SingleCell(
        X=count_matrix, obs=self._obs, var=self._var, obsm=self._obsm,
        varm=self._varm, obsp=self._obsp, varp=self._varp, uns=self._uns,
        num_threads=num_threads)
def _process_num_threads(self,
num_threads: int | np.integer | None) -> int:
"""
Process a `num_threads` value specified by the user as an argument to a
SingleCell function.
Check that `num_threads` is a positive integer, -1 or `None`; if
`None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`.
Args:
num_threads: the number of threads specified by the user
Returns:
The actual number of threads to use.
"""
if num_threads is None:
return self._num_threads
check_type(num_threads, 'num_threads', int,
'a positive integer, -1, or None')
if num_threads == -1:
return os.cpu_count()
else:
num_threads = int(num_threads)
if num_threads <= 0:
error_message = (
f'num_threads is {num_threads:,}, but must be a positive '
f'integer, -1, or None')
raise ValueError(error_message)
return num_threads
@staticmethod
def _process_num_threads_static(num_threads: int | np.integer) -> int:
    """
    Process a `num_threads` value specified by the user as an argument to a
    SingleCell function.

    Check that `num_threads` is a positive integer or -1; if -1, set to
    `os.cpu_count()` (falling back to 1 if the core count cannot be
    determined). Unlike `_process_num_threads()`, there is no SingleCell
    object, so `num_threads` cannot be `None`.

    Args:
        num_threads: the number of threads specified by the user

    Returns:
        The actual number of threads to use.

    Raises:
        ValueError: if `num_threads` is neither positive nor -1
    """
    check_type(num_threads, 'num_threads', int, 'a positive integer or -1')
    if num_threads == -1:
        # `os.cpu_count()` returns None when the count is undeterminable
        return os.cpu_count() or 1
    num_threads = int(num_threads)
    if num_threads <= 0:
        error_message = (
            f'num_threads is {num_threads:,}, but must be a positive '
            f'integer or -1')
        raise ValueError(error_message)
    return num_threads
@staticmethod
def _get_datasets_to_preload(hdf5_file: h5py.File,
datasets_to_load: set[str]) -> \
tuple[list[tuple[str, tuple[int, ...], np.dtype, int | None]],
list[tuple[str, tuple[int, ...], np.dtype, int | None]]]:
"""
Given the names of datasets that need to be loaded, stratify into
fixed-length datasets (numeric or fixed-length string data) and
variable-length string datasets (dtype == object). Also get each
dataset's offset within the hdf5 file.
Args:
hdf5_file: an `h5py.File` open in read mode
datasets_to_load: the set of datasets to be loaded
Returns:
Two lists of `(name, shape, dtype, file_offset)` tuples of datasets
to preload: the first is for fixed-width datasets (numeric and
fixed-length strings), the second for variable-length string
datasets. `file_offset` is the byte offset of the dataset's
contiguous storage in the file, or `None` if the dataset is chunked
and/or compressed. If `None` for any dataset, we will use the
multiprocessing-based parallel reader rather than the thread-based
one.
"""
fixed_datasets = []
vlen_datasets = []
for key in datasets_to_load:
dataset = hdf5_file[key]
file_offset = dataset.id.get_offset()
dataset_info = \
dataset.name, dataset.shape, dataset.dtype, file_offset
if dataset.dtype.kind == 'O':
vlen_datasets.append(dataset_info)
else:
fixed_datasets.append(dataset_info)
return fixed_datasets, vlen_datasets
@staticmethod
def _read_dataset(dataset: h5py.Dataset,
preloaded_datasets: dict[str, np.ndarray]) -> \
np.ndarray:
"""
Read an HDF5 dataset into a NumPy array, if not already preloaded.
Read in ~32-MB chunks, to allow KeyboardInterrupts.
Args:
dataset: the dataset to read
preloaded_datasets: a dictionary of preloaded datasets. If `key` is
present, load it from there instead of the
HDF5 file.
Returns:
A 1D NumPy array with the contents of the dataset.
"""
dataset_name = dataset.name
if dataset_name in preloaded_datasets:
return preloaded_datasets[dataset_name]
# Fast path for datasets under 32 MB
if not dataset.shape:
return dataset[()]
elif dataset.size * dataset.dtype.itemsize <= 33_554_432:
return dataset[:]
# Calculate how many rows fit into 32 MB, minimum 1
step = max(1, 33_554_432 // (
int(np.prod(dataset.shape[1:])) * dataset.dtype.itemsize))
# Align reading to HDF5 internal chunks
if dataset.chunks is not None:
step = max(dataset.chunks[0], (
step // dataset.chunks[0]) * dataset.chunks[0])
result = np.empty(dataset.shape, dtype=dataset.dtype)
for start in range(0, dataset.shape[0], step):
end = min(start + step, dataset.shape[0])
# `read_direct()` skips temporary array allocation overhead
dataset.read_direct(result, np.s_[start:end], np.s_[start:end])
return result
@staticmethod
def _read_parallel(datasets_to_preload: tuple[
list[tuple[str, tuple[int, ...], np.dtype,
int | None]],
list[tuple[str, tuple[int, ...], np.dtype,
int | None]]],
filename: str,
*,
num_cells: int,
num_threads: int | np.integer) -> \
dict[str, np.ndarray | pl.Series]:
"""
Read a sequence of HDF5 datasets into a dictionary of 1D NumPy arrays
(or polars String Series, for variable-length strings) in parallel.
Optimized for the common case of unchunked, uncompressed datasets: uses
`pread()` with OpenMP threads to bypass HDF5's single-threaded I/O.
Variable-length string datasets are read by parsing HDF5's global heap
format directly, avoiding both HDF5's global lock and the overhead of
creating Python string objects. The result is a polars String Series
backed by Arrow.
All datasets >= `num_cells` elements are loaded with per-thread byte
slices, so that Linux's first-touch NUMA policy distributes pages
across NUMA nodes in the same pattern as downstream `prange` over
cells. Smaller datasets are split into ~1 MiB chunks and distributed
among threads via best-fit decreasing bin-packing.
Args:
datasets_to_preload: two lists of (name, shape, dtype, file_offset)
tuples of datasets to preload: the first is
for fixed-width datasets, the second for
variable-length string datasets. file_offset
is None for chunked/compressed datasets.
filename: the HDF5 filename
num_cells: the number of cells (observations) in the dataset;
datasets of at least `num_cells` in size will be loaded
in a NUMA-aware way across all threads, while smaller
datasets will undergo chunk-based loading
num_threads: the number of threads (or processes, in the fallback)
to spawn when reading
Returns:
A dictionary mapping dataset names to their contents: 1D NumPy
arrays for fixed-width datasets, and polars String Series for
variable-length string datasets.
"""
fixed_datasets, vlen_datasets = datasets_to_preload
if len(fixed_datasets) == 0 and len(vlen_datasets) == 0:
return {}
result = {}
# Allocate arrays for fixed-width datasets and separate into large
# (>= num_cells elements, loaded in a NUMA-aware way across all
# threads) and small (loaded in ~1 MiB chunks, bin-packed among
# threads)
large_fixed_arrays = []
large_fixed_file_offsets = []
small_fixed_info = []
total_small_bytes = 0
for name, shape, dtype, file_offset in fixed_datasets:
if np.prod(shape) >= num_cells:
array = numa_zeros(shape, dtype=dtype)
large_fixed_arrays.append(array)
large_fixed_file_offsets.append(file_offset)
else:
array = np.empty(shape, dtype=dtype)
rows = shape[0]
bytes_per_row = np.prod(shape[1:]) * dtype.itemsize
bytes_per_dataset = rows * bytes_per_row
small_fixed_info.append((array, file_offset, rows,
bytes_per_row, bytes_per_dataset))
total_small_bytes += bytes_per_dataset
result[name] = array
large_fixed_file_offsets = np.array(large_fixed_file_offsets,
dtype=np.uint64)
# Bin-pack small fixed-width datasets among threads using a best-fit
# decreasing (BFD) bin-packing algorithm
page_size = 16_384 if sys.platform == 'darwin' else 4096
page_size_minus_1 = page_size - 1
target = (total_small_bytes + num_threads - 1) // num_threads
target = min(max(target, 4_194_304), 67_108_864) # 4 MB to 64 MB
target = (target // page_size) * page_size or \
page_size # multiple of `page_size` bytes
task_records = []
for array, file_offset, rows, bytes_per_row, bytes_per_dataset \
in small_fixed_info:
destination_base = array.ctypes.data
if bytes_per_dataset <= target:
task_records.append((bytes_per_dataset, file_offset,
destination_base))
else:
rows_per_chunk = min(max(1, target // bytes_per_row), rows)
for start in range(0, rows, rows_per_chunk):
end = min(start + rows_per_chunk, rows)
chunk_bytes = (end - start) * bytes_per_row
# Since we are doing direct reading, as an optimization,
# align the chunk size when the file offset is
# block-aligned
offset = file_offset + start * bytes_per_row
if offset & page_size_minus_1 == 0:
chunk_bytes &= ~page_size_minus_1
end = start + chunk_bytes // bytes_per_row
task_records.append((
chunk_bytes, offset,
destination_base + start * bytes_per_row))
task_records.sort(reverse=True)
tasks_by_thread = [[] for _ in range(num_threads)]
loads = np.zeros(num_threads, dtype=np.int64)
for chunk_bytes, chunk_file_offset, chunk_destination in task_records:
best_slack = None
for index, load in enumerate(loads):
slack = target - (load + chunk_bytes)
if slack >= 0 and (best_slack is None or slack < best_slack):
best_index = index
best_slack = slack
if best_slack is None:
best_index = loads.argmin()
tasks_by_thread[best_index].append(
(chunk_file_offset, chunk_destination, chunk_bytes))
loads[best_index] += chunk_bytes
# Flatten per-thread tasks into arrays for Cython
chunk_file_offsets = []
chunk_byte_sizes = []
chunk_destinations = []
chunk_thread_boundaries = [0]
for thread in range(num_threads):
for chunk_file_offset, chunk_destination, chunk_bytes \
in tasks_by_thread[thread]:
chunk_file_offsets.append(chunk_file_offset)
chunk_byte_sizes.append(chunk_bytes)
chunk_destinations.append(chunk_destination)
chunk_thread_boundaries.append(len(chunk_file_offsets))
chunk_file_offsets = np.array(chunk_file_offsets, dtype=np.uint64)
chunk_byte_sizes = np.array(chunk_byte_sizes, dtype=np.uint64)
chunk_destinations = np.array(chunk_destinations, dtype=np.uint64)
chunk_thread_boundaries = \
np.array(chunk_thread_boundaries, dtype=np.uint64)
# Prepare for variable-length string loading
num_vlen_datasets = len(vlen_datasets)
vlen_num_strings = np.empty(num_vlen_datasets, dtype=np.uint64)
vlen_file_offsets = np.empty(num_vlen_datasets, dtype=np.uint64)
if num_vlen_datasets > 0:
with h5py.File(filename) as hdf5_file:
offset_size, length_size = \
hdf5_file.id.get_create_plist().get_sizes()
for index, (name, shape, dtype, file_offset) in \
enumerate(vlen_datasets):
vlen_num_strings[index] = np.prod(shape)
vlen_file_offsets[index] = file_offset
else:
offset_size = length_size = 0
# Open the file and read the datasets. Direct reads (`O_DIRECT`) bypass
# the page cache, giving a several-fold speedup on NFS filesystems.
# However, certain file systems (Lustre < 2.16, ext4, xfs, etc.) only
# support direct reads that are aligned to a multiple of e.g. 4096
# bytes. Thus, we first probe the filesystem with a deliberately
# unaligned direct read, and use `O_DIRECT` only if the probe succeeds,
# falling back to buffered reading if it fails.
fd = -1
direct = False
try:
if hasattr(os, 'O_DIRECT'):
try:
candidate = os.open(filename, os.O_RDONLY | os.O_DIRECT)
except OSError:
# Some filesystems (tmpfs, some overlay/FUSE mounts) reject
# O_DIRECT at open() time
candidate = -1
if candidate != -1:
# Probe: force all three O_DIRECT alignment axes to be
# unaligned at once - file offset (1), length (1), and
# destination buffer. An mmap region is page-aligned, so
# reading into buf[1:2] guarantees an unaligned memory
# address.
probe_buffer = mmap.mmap(-1, 4096)
try:
os.preadv(candidate,
[memoryview(probe_buffer)[1:2]], 1)
fd = candidate
direct = True
except OSError:
os.close(candidate)
finally:
probe_buffer.close()
if fd == -1:
fd = os.open(filename, os.O_RDONLY)
if sys.platform == 'darwin':
# macOS equivalent of `O_DIRECT`; no alignment constraints
import fcntl
fcntl.fcntl(fd, fcntl.F_NOCACHE, 1)
if not direct and hasattr(os, 'posix_fadvise'):
# Buffered fallback path only: `posix_fadvise()` is a
# page-cache hint and is a no-op under O_DIRECT. Use
# `POSIX_FADV_SEQUENTIAL` because the gain on the large
# fixed-width datasets likely outweighs the cost on the small
# scattered vlen header reads.
os.posix_fadvise(fd, 0, 0, os.POSIX_FADV_SEQUENTIAL)
vlen_offsets_out, vlen_data_out = read_all_datasets(
fd, large_fixed_arrays, large_fixed_file_offsets,
chunk_file_offsets, chunk_byte_sizes, chunk_destinations,
chunk_thread_boundaries, vlen_file_offsets, vlen_num_strings,
length_size, offset_size, num_threads)
finally:
if fd != -1:
os.close(fd)
for dataset_index, dataset in enumerate(vlen_datasets):
output_offsets = vlen_offsets_out[dataset_index]
output_data = vlen_data_out[dataset_index]
num_strings = vlen_num_strings[dataset_index]
total_string_bytes = output_offsets[num_strings]
name = dataset[0]
result[name] = pl.from_arrow(
pa.LargeStringArray.from_buffers(
num_strings, pa.py_buffer(output_offsets),
pa.py_buffer(output_data[:total_string_bytes])))
return result
@staticmethod
def _tabulate_h5ad_dataframe(
        h5ad_file: h5py.File,
        key: str,
        datasets_to_load: set[str],
        *,
        columns: str | Sequence[str] | None = None) -> None:
    """
    Record which HDF5 datasets back a DataFrame stored in an `.h5ad` file.

    Walks `h5ad_file[key]` (e.g. `obs`, `var`, or a DataFrame key of
    `obsm` or `varm`) and adds, to `datasets_to_load`, the name of every
    dataset needed for the index column and for the requested columns.
    No dataset contents are read here.

    Args:
        h5ad_file: an `h5py.File` open in read mode
        key: the key to be loaded as a DataFrame, e.g. `'obs'` or `'var'`
        datasets_to_load: the set of required datasets; modified in place
        columns: the column(s) of the DataFrame requested; the index
                 column is always included first, whether or not it is
                 listed here. If `None`, every column listed in the
                 group's `column-order` attribute is included.

    Raises:
        ValueError: if a requested column is not in `column-order`, or
                    if a column has an encoding this function does not
                    recognize.
    """
    def holds_only(node, allowed_keys):
        # True when `node` is a group whose keys all lie in `allowed_keys`
        return isinstance(node, h5py.Group) and \
            set(node.keys()) <= allowed_keys

    frame = h5ad_file[key]
    # Rare layout: the whole DataFrame is one NumPy structured array
    # (`dtype=void`), so that single dataset is all that is needed
    if isinstance(frame, h5py.Dataset) and \
            np.issubdtype(frame.dtype, np.void):
        datasets_to_load.add(key)
        return
    # Rare layout: the whole DataFrame is a single, unnested column
    unnested = frame.attrs['encoding-type'] != 'dataframe'
    if unnested:
        entries = (('_index', frame),)
    else:
        # Work out which columns to visit, and in which order
        if columns is None:
            columns = frame.attrs['column-order']
        else:
            columns = [name for name in to_tuple(columns)
                       if name != frame.attrs['_index']]
            for name in columns:
                if name not in frame.attrs['column-order']:
                    error_message = f'{name!r} is not a column of {key}'
                    raise ValueError(error_message)
        # Lazy on purpose: each column is looked up only when reached
        entries = ((name, frame[name]) for name in
                   chain((frame.attrs['_index'],), columns))
    # For each column, enumerate the dataset(s) required to load it
    for name, value in entries:
        encoding_type = value.attrs.get('encoding-type')
        # Sometimes the categories live elsewhere in the file, pointed to
        # by `value.attrs['categories']`
        categories_elsewhere = 'categories' in value.attrs
        if encoding_type == 'categorical' or categories_elsewhere or \
                holds_only(value, {'categories', 'codes'}):
            if categories_elsewhere:
                category_object = h5ad_file[value.attrs['categories']]
                category_encoding_type = None
                datasets_to_load.add(value.name)
            else:
                category_object = value['categories']
                category_encoding_type = \
                    category_object.attrs.get('encoding-type')
                datasets_to_load.add(value['codes'].name)
            # Sometimes the categories are themselves nullable integer or
            # Boolean arrays, stored as a values/mask pair
            if category_encoding_type in ('nullable-integer',
                                          'nullable-boolean') or \
                    holds_only(category_object, {'values', 'mask'}):
                datasets_to_load.add(category_object['values'].name)
                datasets_to_load.add(category_object['mask'].name)
            else:
                datasets_to_load.add(category_object.name)
        elif encoding_type in ('nullable-integer', 'nullable-boolean') or \
                holds_only(value, {'values', 'mask'}):
            datasets_to_load.add(value['values'].name)
            datasets_to_load.add(value['mask'].name)
        elif encoding_type in ('array', 'string-array') or \
                isinstance(value, h5py.Dataset):
            datasets_to_load.add(value.name)
        else:
            encoding = 'encoding' if encoding_type is None else \
                f'encoding-type {encoding_type!r}'
            if unnested:
                error_message = f'{key!r} has unsupported {encoding}'
            else:
                error_message = (
                    f'{name!r} column of {key!r} has unsupported '
                    f'{encoding}')
            raise ValueError(error_message)
@staticmethod
def _read_h5ad_dataframe(h5ad_file: h5py.File,
                         key: str,
                         preloaded_datasets: dict[str, np.ndarray],
                         *,
                         columns: str | Sequence[str] | None = None) \
        -> pl.DataFrame:
    """
    Build a polars DataFrame from a DataFrame stored in an `.h5ad` file.

    Each column is read through `SingleCell._read_dataset()`, which is
    handed `preloaded_datasets`. Enum casts for string categoricals, and
    the Binary-to-String cast, are collected and applied together in one
    `with_columns()` call at the end so polars can parallelize them
    across columns.

    Args:
        h5ad_file: an `h5py.File` open in read mode
        key: the key to load as a DataFrame, e.g. `'obs'` or `'var'`
        preloaded_datasets: a dictionary of preloaded datasets
        columns: the column(s) of the DataFrame to load; the index column
                 is always loaded first, whether or not it is listed
                 here, followed by the remaining columns in the order
                 given. If `None`, every column listed in the group's
                 `column-order` attribute is loaded.

    Returns:
        A polars DataFrame of the data in `h5ad_file[key]`.

    Raises:
        ValueError: if a requested column is not in `column-order`, or
                    if a column (or a categorical's categories) has an
                    encoding this function does not recognize.
    """
    def holds_only(node, allowed_keys):
        # True when `node` is a group whose keys all lie in `allowed_keys`
        return isinstance(node, h5py.Group) and \
            set(node.keys()) <= allowed_keys

    load = SingleCell._read_dataset
    frame = h5ad_file[key]
    # Rare layout: the whole DataFrame is one NumPy structured array
    # (`dtype=void`); read it directly and decode Binary columns
    if isinstance(frame, h5py.Dataset) and \
            np.issubdtype(frame.dtype, np.void):
        return pl.from_numpy(frame[:])\
            .with_columns(pl.col(pl.Binary).cast(pl.String))
    # Rare layout: the whole DataFrame is a single, unnested column
    unnested = frame.attrs['encoding-type'] != 'dataframe'
    if unnested:
        entries = (('_index', frame),)
    else:
        # Work out which columns to load, and in which order
        if columns is None:
            columns = frame.attrs['column-order']
        else:
            columns = [name for name in to_tuple(columns)
                       if name != frame.attrs['_index']]
            for name in columns:
                if name not in frame.attrs['column-order']:
                    error_message = \
                        f'{name!r} is not a column of {key}'
                    raise ValueError(error_message)
        # Lazy on purpose: each column is looked up only when reached
        entries = ((name, frame[name]) for name in
                   chain((frame.attrs['_index'],), columns))
    series_by_name = {}
    # Casts to apply in a single pass once the DataFrame exists
    pending_casts = []
    for name, value in entries:
        encoding_type = value.attrs.get('encoding-type')
        # Sometimes the categories live elsewhere in the file, pointed to
        # by `value.attrs['categories']`
        categories_elsewhere = 'categories' in value.attrs
        if encoding_type == 'categorical' or categories_elsewhere or \
                holds_only(value, {'categories', 'codes'}):
            if categories_elsewhere:
                category_object = h5ad_file[value.attrs['categories']]
                category_encoding_type = None
                codes = load(value, preloaded_datasets)
            else:
                category_object = value['categories']
                category_encoding_type = \
                    category_object.attrs.get('encoding-type')
                codes = load(value['codes'], preloaded_datasets)
            if category_encoding_type in ('nullable-integer',
                                          'nullable-boolean') or \
                    holds_only(category_object, {'values', 'mask'}):
                # The categories are themselves a nullable values/mask
                # pair: map codes straight to values, and null out every
                # entry that is masked or has code -1
                series = pl.Series(load(
                    category_object['values'],
                    preloaded_datasets)[codes])
                missing = pl.Series(load(
                    category_object['mask'],
                    preloaded_datasets)[codes] | (codes == -1))
                if missing.any():
                    series = series.set(missing, None)
                series_by_name[name] = series
            else:
                categories = load(category_object, preloaded_datasets)
                missing = pl.Series(codes == -1)
                has_missing = missing.any()
                categories_are_dataset = \
                    isinstance(category_object, h5py.Dataset)
                # polars does not (as of version 1.0) support
                # Categoricals or Enums with non-string categories, so
                # non-string categories are just mapped through the codes
                if category_encoding_type == 'array' or (
                        categories_are_dataset and
                        category_object.dtype != object):
                    series = pl.Series(categories[codes],
                                       nan_to_null=True)
                    if has_missing:
                        series = series.set(missing, None)
                    series_by_name[name] = series
                elif category_encoding_type == 'string-array' or (
                        categories_are_dataset and
                        category_object.dtype == object):
                    # Overwrite the -1 codes with 0 before building the
                    # UInt32 Series; those entries are nulled right after
                    if has_missing:
                        codes[missing] = 0
                    series = pl.Series(codes, dtype=pl.UInt32)
                    if has_missing:
                        series = series.set(missing, None)
                    series_by_name[name] = series
                    pending_casts.append(pl.col(name).cast(
                        pl.Enum(pl.Series(categories).cast(pl.String))))
                else:
                    encoding = 'encoding' \
                        if category_encoding_type is None else \
                        f'encoding-type {category_encoding_type!r}'
                    error_message = (
                        f'{name!r} column of {key!r} is a categorical '
                        f'with unsupported {encoding}')
                    raise ValueError(error_message)
        elif encoding_type in ('nullable-integer', 'nullable-boolean') or \
                holds_only(value, {'values', 'mask'}):
            values = load(value['values'], preloaded_datasets)
            mask = load(value['mask'], preloaded_datasets)
            series_by_name[name] = \
                pl.Series(values).set(pl.Series(mask), None)
        elif encoding_type == 'array' or (
                isinstance(value, h5py.Dataset) and
                value.dtype != object):
            series_by_name[name] = pl.Series(
                load(value, preloaded_datasets), nan_to_null=True)
        elif encoding_type == 'string-array' or (
                isinstance(value, h5py.Dataset) and
                value.dtype == object):
            series_by_name[name] = load(value, preloaded_datasets)
        else:
            encoding = 'encoding' if encoding_type is None else \
                f'encoding-type {encoding_type!r}'
            error_message = (
                f'{name!r} column of {key!r} has unsupported '
                f'{encoding}')
            raise ValueError(error_message)
    # Apply the deferred Enum casts and the Binary-to-String cast in one
    # pass, allowing polars to parallelize across columns
    pending_casts.append(pl.col(pl.Binary).cast(pl.String))
    return pl.DataFrame(series_by_name).with_columns(pending_casts)
@staticmethod
def _tabulate_h5Seurat_dataframe(group: h5py.Group,
                                 datasets_to_load: set[str],
                                 *,
                                 columns: str | Sequence[str] |
                                          None = None) -> None:
    """
    Record which HDF5 datasets back a DataFrame group of an `.h5Seurat`
    file (e.g. `meta.data` or `meta.features`).

    Adds, to `datasets_to_load`, the name of every dataset needed for the
    index column(s) named by the group's `_index` attribute and for the
    requested columns. No dataset contents are read here.

    Args:
        group: the group to load as a DataFrame, e.g. `'meta.data'` or
               `'meta.features'`
        datasets_to_load: the set of datasets that will eventually need
                          to be loaded; modified in place
        columns: the column(s) of the DataFrame requested; the index
                 column is always included first, whether or not it is
                 listed here. If `None`, every column listed in the
                 group's `colnames` attribute is included.

    Raises:
        ValueError: if a requested column is not in `colnames`, or if a
                    column is a group that does not consist of exactly
                    `levels` and `values`.
    """
    # Work out which columns to visit, and in which order
    if columns is None:
        columns = group.attrs['colnames']
    else:
        columns = [name for name in to_tuple(columns)
                   if name != group.attrs['_index']]
        for name in columns:
            if name not in group.attrs['colnames']:
                error_message = f'{name!r} is not a column of meta.data'
                raise ValueError(error_message)
    # For each column, enumerate the dataset(s) required to load it.
    # NOTE(review): `_index` is iterated directly, so it is presumably
    # stored as an array of names rather than a bare string - confirm.
    for name in chain(group.attrs['_index'], columns):
        node = group[name]
        if not isinstance(node, h5py.Group):
            # Plain column: a single dataset
            datasets_to_load.add(node.name)
            continue
        # Group column: must be a factor, i.e. exactly levels + values
        if not (len(node) == 2 and 'levels' in node and
                'values' in node):
            error_message = (
                f"the h5Seurat file's meta.data contains a group "
                f"of unknown format, ")
            if len(node) > 5:
                error_message += (
                    f'with {len(node):,} keys, including the '
                    f'following: '
                    f'{", ".join(map(repr, islice(node, 5)))}')
            else:
                error_message += (
                    f'with the following keys: '
                    f'{", ".join(map(repr, node))}')
            raise ValueError(error_message)
        datasets_to_load.add(node['values'].name)
        datasets_to_load.add(node['levels'].name)
@staticmethod
def _read_h5Seurat_dataframe(group: h5py.Group,
                             preloaded_datasets: dict[
                                 str, np.ndarray],
                             *,
                             columns: str | Sequence[str] |
                                      None = None) -> pl.DataFrame:
    """
    Build a polars DataFrame from a DataFrame group of an `.h5Seurat`
    file, i.e. `meta.data` (`obs`) or `meta.features` (`var`).

    Args:
        group: the group to load as a DataFrame, e.g. `'meta.data'` or
               `'meta.features'`
        preloaded_datasets: a dictionary of preloaded datasets
        columns: the column(s) of the DataFrame to load; the index column
                 is always loaded first, whether or not it is listed
                 here, followed by the remaining columns in the order
                 given. If `None`, every column listed in the group's
                 `colnames` attribute is loaded.

    Returns:
        A polars DataFrame of the data in `group`.

    Raises:
        ValueError: if a requested column is not in `colnames`, or if a
                    column is a group that does not consist of exactly
                    `levels` and `values`.
    """
    load = SingleCell._read_dataset
    # Work out which columns to load, and in which order
    if columns is None:
        columns = group.attrs['colnames']
    else:
        columns = [name for name in to_tuple(columns)
                   if name != group.attrs['_index']]
        for name in columns:
            if name not in group.attrs['colnames']:
                error_message = f'{name!r} is not a column of meta.data'
                raise ValueError(error_message)
    # Columns named by the `logicals` attribute are treated as Boolean
    logical_columns = group.attrs.get('logicals', ())
    series_by_name = {}
    for name in chain(group.attrs['_index'], columns):
        node = group[name]
        if not isinstance(node, h5py.Group):
            # Plain column
            series = pl.Series(load(node, preloaded_datasets),
                               nan_to_null=True)
            if name in logical_columns:
                # 2 becomes null before the Boolean cast (presumably the
                # file's encoding of a missing logical - confirm)
                series = series.replace({2: None}).cast(pl.Boolean)
            elif series.dtype == pl.Int32:
                # The minimum 32-bit integer becomes null (presumably R's
                # integer NA sentinel - confirm)
                series = series.replace({-2147483648: None})
            series_by_name[name] = series
            continue
        # Group column: must be a factor, i.e. exactly levels + values
        if not (len(node) == 2 and 'levels' in node and
                'values' in node):
            error_message = (
                f"the h5Seurat file's meta.data contains a group "
                f"of unknown format, ")
            if len(node) > 5:
                error_message += (
                    f'with {len(node):,} keys, including the '
                    f'following: '
                    f'{", ".join(map(repr, islice(node, 5)))}')
            else:
                error_message += (
                    f'with the following keys: '
                    f'{", ".join(map(repr, node))}')
            raise ValueError(error_message)
        factor_codes = load(node['values'], preloaded_datasets)
        levels = node['levels'][:]
        # Shift the stored codes down by one so they index into `levels`
        # from zero, then cast to an Enum over the levels
        zero_based = pl.Series(factor_codes) - 1
        series_by_name[name] = zero_based.cast(
            pl.Enum(pl.Series(levels).cast(pl.String)))
    # NumPy doesn't support encoding object-dtyped string arrays as UTF-8,
    # so do the conversion in polars instead. Also convert `b'NA'`
    # (h5Seurat's missing value indicator) to `null`.
    return pl.DataFrame(series_by_name).with_columns(
        pl.col(pl.Binary).replace({b'NA': None}).cast(pl.String))
[docs]
@staticmethod
def read_obs(h5ad_file: str | Path,
             *,
             columns: str | Iterable[str] | None = None,
             num_threads: int | np.integer = -1) -> pl.DataFrame:
    """
    Load just `obs` from an `.h5ad` file as a polars DataFrame.

    Args:
        h5ad_file: an `.h5ad` filename; a leading `~` is expanded
        columns: the column(s) of `obs` to load; if `None`, load all
                 columns
        num_threads: the number of threads to use when reading. By default
                     (`num_threads=-1`), use all available cores, as
                     determined by `os.cpu_count()`.

    Returns:
        A polars DataFrame of the data in `obs`.

    Raises:
        ValueError: if `h5ad_file` does not end with `'.h5ad'`
        FileNotFoundError: if `h5ad_file` does not exist
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    h5ad_file = str(h5ad_file)
    if not h5ad_file.endswith('.h5ad'):
        error_message = f".h5ad file {h5ad_file!r} must end with '.h5ad'"
        raise ValueError(error_message)
    # Every later open uses the expanded path, so that e.g. `~/x.h5ad`
    # is opened at the same location the existence check looked at
    filename = os.path.expanduser(h5ad_file)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {h5ad_file} does not exist'
        raise FileNotFoundError(error_message)
    if columns is not None:
        columns = to_tuple_checked(columns, 'columns', str, 'strings')
    # Check that `num_threads` is a positive integer or -1; if -1, set to
    # `os.cpu_count()`
    num_threads = SingleCell._process_num_threads_static(num_threads)
    # Load `obs`, using a similar approach to loading the full `.h5ad` file
    if num_threads == 1:
        preloaded_datasets = {}
    else:
        with h5py.File(filename) as f:
            # 1. Tabulate which HDF5 datasets need to be loaded
            datasets_to_load = set()
            SingleCell._tabulate_h5ad_dataframe(
                f, 'obs', datasets_to_load, columns=columns)
            # 2. Preload datasets in parallel. If any fixed-width dataset
            #    is chunked or compressed (indicated by a `None`
            #    `file_offset`), fall back to the multiprocessing-based
            #    reader.
            datasets_to_preload = \
                SingleCell._get_datasets_to_preload(f, datasets_to_load)
            use_multiprocessing = any(
                dataset[-1] is None for dataset in datasets_to_preload[0])
            if not use_multiprocessing:
                preloaded_datasets = SingleCell._read_parallel(
                    datasets_to_preload, filename,
                    num_threads=num_threads)
        if use_multiprocessing:
            # Crucial: this runs only after the `with` block has closed
            # the HDF5 file, to avoid the child processes inheriting the
            # HDF5 library's state associated with the open file handle
            # when forking. No handle is reopened afterwards, so nothing
            # is leaked.
            preloaded_datasets = read_parallel_multiprocessing(
                datasets_to_preload, filename, num_threads=num_threads)
    # 3. Assemble `obs`
    with h5py.File(filename) as f:
        return SingleCell._read_h5ad_dataframe(
            f, 'obs', preloaded_datasets, columns=columns)
[docs]
@staticmethod
def read_var(h5ad_file: str | Path,
             *,
             columns: str | Iterable[str] | None = None,
             num_threads: int | np.integer = -1) -> pl.DataFrame:
    """
    Load just `var` from an `.h5ad` file as a polars DataFrame.

    Args:
        h5ad_file: an `.h5ad` filename; a leading `~` is expanded
        columns: the column(s) of `var` to load; if `None`, load all
                 columns
        num_threads: the number of threads to use when reading. By default
                     (`num_threads=-1`), use all available cores, as
                     determined by `os.cpu_count()`.

    Returns:
        A polars DataFrame of the data in `var`.

    Raises:
        ValueError: if `h5ad_file` does not end with `'.h5ad'`
        FileNotFoundError: if `h5ad_file` does not exist
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    h5ad_file = str(h5ad_file)
    if not h5ad_file.endswith('.h5ad'):
        error_message = f".h5ad file {h5ad_file!r} must end with '.h5ad'"
        raise ValueError(error_message)
    # Every later open uses the expanded path, so that e.g. `~/x.h5ad`
    # is opened at the same location the existence check looked at
    filename = os.path.expanduser(h5ad_file)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {h5ad_file} does not exist'
        raise FileNotFoundError(error_message)
    if columns is not None:
        columns = to_tuple_checked(columns, 'columns', str, 'strings')
    # Check that `num_threads` is a positive integer or -1; if -1, set to
    # `os.cpu_count()`
    num_threads = SingleCell._process_num_threads_static(num_threads)
    # Load `var`, using a similar approach to loading the full `.h5ad` file
    if num_threads == 1:
        preloaded_datasets = {}
    else:
        with h5py.File(filename) as f:
            # 1. Tabulate which HDF5 datasets need to be loaded
            datasets_to_load = set()
            SingleCell._tabulate_h5ad_dataframe(
                f, 'var', datasets_to_load, columns=columns)
            # 2. Preload datasets in parallel. If any fixed-width dataset
            #    is chunked or compressed (indicated by a `None`
            #    `file_offset`), fall back to the multiprocessing-based
            #    reader.
            datasets_to_preload = \
                SingleCell._get_datasets_to_preload(f, datasets_to_load)
            use_multiprocessing = any(
                dataset[-1] is None for dataset in datasets_to_preload[0])
            if not use_multiprocessing:
                preloaded_datasets = SingleCell._read_parallel(
                    datasets_to_preload, filename,
                    num_threads=num_threads)
        if use_multiprocessing:
            # Crucial: this runs only after the `with` block has closed
            # the HDF5 file, to avoid the child processes inheriting the
            # HDF5 library's state associated with the open file handle
            # when forking. No handle is reopened afterwards, so nothing
            # is leaked.
            preloaded_datasets = read_parallel_multiprocessing(
                datasets_to_preload, filename, num_threads=num_threads)
    # 3. Assemble `var`
    with h5py.File(filename) as f:
        return SingleCell._read_h5ad_dataframe(
            f, 'var', preloaded_datasets, columns=columns)
[docs]
@staticmethod
def read_obsm(h5ad_file: str | Path,
              *,
              keys: str | Iterable[str] | None = None,
              num_threads: int | np.integer = -1) -> \
        dict[str, np.ndarray | pl.DataFrame]:
    """
    Load just `obsm` from an `.h5ad` file as a dictionary of NumPy arrays
    or DataFrames.

    Args:
        h5ad_file: an `.h5ad` filename; a leading `~` is expanded
        keys: the key(s) of `obsm` to load, returned in the order given;
              if `None`, load all keys
        num_threads: the number of threads to use when reading. By default
                     (`num_threads=-1`), use all available cores, as
                     determined by `os.cpu_count()`.

    Returns:
        A dictionary of NumPy arrays and polars DataFrames of the data in
        `obsm`, restricted to `keys` when specified. Empty if the file
        has no `obsm` and `keys` is `None`.

    Raises:
        ValueError: if `h5ad_file` does not end with `'.h5ad'`, or if
                    `keys` names a key that is not in `obsm` (regardless
                    of `num_threads`)
        FileNotFoundError: if `h5ad_file` does not exist
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    h5ad_file = str(h5ad_file)
    if not h5ad_file.endswith('.h5ad'):
        error_message = f".h5ad file {h5ad_file!r} must end with '.h5ad'"
        raise ValueError(error_message)
    # Every later open uses the expanded path, so that e.g. `~/x.h5ad`
    # is opened at the same location the existence check looked at
    filename = os.path.expanduser(h5ad_file)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {h5ad_file} does not exist'
        raise FileNotFoundError(error_message)
    if keys is not None:
        keys = to_tuple_checked(keys, 'keys', str, 'strings')
    # Check that `num_threads` is a positive integer or -1; if -1, set to
    # `os.cpu_count()`
    num_threads = SingleCell._process_num_threads_static(num_threads)
    # Load `obsm`, using a similar approach to loading the full `.h5ad`
    # file
    use_multiprocessing = False
    with h5py.File(filename) as f:
        if 'obsm' not in f:
            if keys is not None:
                error_message = 'keys was specified, but obsm is empty'
                raise ValueError(error_message)
            return {}
        obsm_group = f['obsm']
        # Resolve and validate the keys up front, so the single-threaded
        # and multithreaded paths accept and return exactly the same keys
        if keys is None:
            keys = tuple(obsm_group.keys())
        else:
            for key_index, key in enumerate(keys):
                if key not in obsm_group:
                    error_message = (
                        f'keys[{key_index}] is {key!r}, which is '
                        f'not a key of obsm')
                    raise ValueError(error_message)
        if num_threads == 1:
            preloaded_datasets = {}
        else:
            # 1. Tabulate which HDF5 datasets need to be loaded
            datasets_to_load = set()
            for key in keys:
                if isinstance(obsm_group[key], h5py.Dataset):
                    datasets_to_load.add(f'obsm/{key}')
                else:
                    SingleCell._tabulate_h5ad_dataframe(
                        f, f'obsm/{key}', datasets_to_load)
            # 2. Preload datasets in parallel. If any fixed-width dataset
            #    is chunked or compressed (indicated by a `None`
            #    `file_offset`), fall back to the multiprocessing-based
            #    reader.
            datasets_to_preload = \
                SingleCell._get_datasets_to_preload(f, datasets_to_load)
            use_multiprocessing = any(
                dataset[-1] is None for dataset in datasets_to_preload[0])
            if not use_multiprocessing:
                preloaded_datasets = SingleCell._read_parallel(
                    datasets_to_preload, filename,
                    num_threads=num_threads)
    if use_multiprocessing:
        # Crucial: this runs only after the `with` block has closed the
        # HDF5 file, to avoid the child processes inheriting the HDF5
        # library's state associated with the open file handle when
        # forking. No handle is reopened afterwards, so nothing is leaked.
        preloaded_datasets = read_parallel_multiprocessing(
            datasets_to_preload, filename, num_threads=num_threads)
    # 3. Assemble `obsm` from the requested keys only
    obsm = {}
    with h5py.File(filename) as f:
        obsm_group = f['obsm']
        for key in keys:
            value = obsm_group[key]
            if isinstance(value, h5py.Dataset):
                obsm[key] = SingleCell._read_dataset(
                    value, preloaded_datasets)
            else:
                obsm[key] = SingleCell._read_h5ad_dataframe(
                    f, f'obsm/{key}', preloaded_datasets)
    return obsm
[docs]
@staticmethod
def read_varm(h5ad_file: str | Path,
              *,
              keys: str | Iterable[str] | None = None,
              num_threads: int | np.integer = -1) -> \
        dict[str, np.ndarray | pl.DataFrame]:
    """
    Load just `varm` from an `.h5ad` file as a dictionary of NumPy arrays
    or DataFrames.

    Args:
        h5ad_file: an `.h5ad` filename; a leading `~` is expanded
        keys: the key(s) of `varm` to load, returned in the order given;
              if `None`, load all keys
        num_threads: the number of threads to use when reading. By default
                     (`num_threads=-1`), use all available cores, as
                     determined by `os.cpu_count()`.

    Returns:
        A dictionary of NumPy arrays and polars DataFrames of the data in
        `varm`, restricted to `keys` when specified. Empty if the file
        has no `varm` and `keys` is `None`.

    Raises:
        ValueError: if `h5ad_file` does not end with `'.h5ad'`, or if
                    `keys` names a key that is not in `varm` (regardless
                    of `num_threads`)
        FileNotFoundError: if `h5ad_file` does not exist
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    h5ad_file = str(h5ad_file)
    if not h5ad_file.endswith('.h5ad'):
        error_message = f".h5ad file {h5ad_file!r} must end with '.h5ad'"
        raise ValueError(error_message)
    # Every later open uses the expanded path, so that e.g. `~/x.h5ad`
    # is opened at the same location the existence check looked at
    filename = os.path.expanduser(h5ad_file)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {h5ad_file} does not exist'
        raise FileNotFoundError(error_message)
    if keys is not None:
        keys = to_tuple_checked(keys, 'keys', str, 'strings')
    # Check that `num_threads` is a positive integer or -1; if -1, set to
    # `os.cpu_count()`
    num_threads = SingleCell._process_num_threads_static(num_threads)
    # Load `varm`, using a similar approach to loading the full `.h5ad`
    # file
    use_multiprocessing = False
    with h5py.File(filename) as f:
        if 'varm' not in f:
            if keys is not None:
                error_message = 'keys was specified, but varm is empty'
                raise ValueError(error_message)
            return {}
        varm_group = f['varm']
        # Resolve and validate the keys up front, so the single-threaded
        # and multithreaded paths accept and return exactly the same keys
        if keys is None:
            keys = tuple(varm_group.keys())
        else:
            for key_index, key in enumerate(keys):
                if key not in varm_group:
                    error_message = (
                        f'keys[{key_index}] is {key!r}, which is '
                        f'not a key of varm')
                    raise ValueError(error_message)
        if num_threads == 1:
            preloaded_datasets = {}
        else:
            # 1. Tabulate which HDF5 datasets need to be loaded
            datasets_to_load = set()
            for key in keys:
                if isinstance(varm_group[key], h5py.Dataset):
                    datasets_to_load.add(f'varm/{key}')
                else:
                    SingleCell._tabulate_h5ad_dataframe(
                        f, f'varm/{key}', datasets_to_load)
            # 2. Preload datasets in parallel. If any fixed-width dataset
            #    is chunked or compressed (indicated by a `None`
            #    `file_offset`), fall back to the multiprocessing-based
            #    reader.
            datasets_to_preload = \
                SingleCell._get_datasets_to_preload(f, datasets_to_load)
            use_multiprocessing = any(
                dataset[-1] is None for dataset in datasets_to_preload[0])
            if not use_multiprocessing:
                preloaded_datasets = SingleCell._read_parallel(
                    datasets_to_preload, filename,
                    num_threads=num_threads)
    if use_multiprocessing:
        # Crucial: this runs only after the `with` block has closed the
        # HDF5 file, to avoid the child processes inheriting the HDF5
        # library's state associated with the open file handle when
        # forking. No handle is reopened afterwards, so nothing is leaked.
        preloaded_datasets = read_parallel_multiprocessing(
            datasets_to_preload, filename, num_threads=num_threads)
    # 3. Assemble `varm` from the requested keys only
    varm = {}
    with h5py.File(filename) as f:
        varm_group = f['varm']
        for key in keys:
            value = varm_group[key]
            if isinstance(value, h5py.Dataset):
                varm[key] = SingleCell._read_dataset(
                    value, preloaded_datasets)
            else:
                varm[key] = SingleCell._read_h5ad_dataframe(
                    f, f'varm/{key}', preloaded_datasets)
    return varm
[docs]
@staticmethod
def read_obsp(h5ad_file: str | Path,
              *,
              keys: str | Iterable[str] | None = None,
              num_threads: int | np.integer = -1) -> \
        dict[str, csr_array | csc_array]:
    """
    Load just `obsp` from an `.h5ad` file as a dictionary of sparse arrays.

    Args:
        h5ad_file: an `.h5ad` filename; a leading `~` is expanded
        keys: the key(s) of `obsp` to load, returned in the order given;
              if `None`, load all keys
        num_threads: the number of threads to use when reading. By default
                     (`num_threads=-1`), use all available cores, as
                     determined by `os.cpu_count()`.

    Returns:
        A dictionary of sparse arrays of the data in `obsp`, restricted to
        `keys` when specified. Each array is a `csr_array` when the key's
        `encoding-type` attribute is `'csr_matrix'`, and a `csc_array`
        otherwise. Empty if the file has no `obsp` and `keys` is `None`.

    Raises:
        ValueError: if `h5ad_file` does not end with `'.h5ad'`, or if
                    `keys` names a key that is not in `obsp` (regardless
                    of `num_threads`)
        FileNotFoundError: if `h5ad_file` does not exist
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    h5ad_file = str(h5ad_file)
    if not h5ad_file.endswith('.h5ad'):
        error_message = f".h5ad file {h5ad_file!r} must end with '.h5ad'"
        raise ValueError(error_message)
    # Every later open uses the expanded path, so that e.g. `~/x.h5ad`
    # is opened at the same location the existence check looked at
    filename = os.path.expanduser(h5ad_file)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {h5ad_file} does not exist'
        raise FileNotFoundError(error_message)
    if keys is not None:
        keys = to_tuple_checked(keys, 'keys', str, 'strings')
    # Check that `num_threads` is a positive integer or -1; if -1, set to
    # `os.cpu_count()`. Without this, the default of -1 would be handed
    # to the parallel readers as-is.
    num_threads = SingleCell._process_num_threads_static(num_threads)
    # Load `obsp`, using a similar approach to loading the full `.h5ad`
    # file
    use_multiprocessing = False
    with h5py.File(filename) as f:
        if 'obsp' not in f:
            if keys is not None:
                error_message = 'keys was specified, but obsp is empty'
                raise ValueError(error_message)
            return {}
        obsp_group = f['obsp']
        # Resolve and validate the keys up front, so the single-threaded
        # and multithreaded paths accept and return exactly the same keys
        if keys is None:
            keys = tuple(obsp_group.keys())
        else:
            for key_index, key in enumerate(keys):
                if key not in obsp_group:
                    error_message = (
                        f'keys[{key_index}] is {key!r}, which is '
                        f'not a key of obsp')
                    raise ValueError(error_message)
        if num_threads == 1:
            preloaded_datasets = {}
        else:
            # 1. Tabulate which HDF5 datasets need to be loaded: the three
            #    component arrays of each sparse matrix
            datasets_to_load = set()
            for key in keys:
                value = obsp_group[key]
                datasets_to_load.add(value['data'].name)
                datasets_to_load.add(value['indices'].name)
                datasets_to_load.add(value['indptr'].name)
            # 2. Preload datasets in parallel. If any fixed-width dataset
            #    is chunked or compressed (indicated by a `None`
            #    `file_offset`), fall back to the multiprocessing-based
            #    reader.
            datasets_to_preload = \
                SingleCell._get_datasets_to_preload(f, datasets_to_load)
            use_multiprocessing = any(
                dataset[-1] is None for dataset in datasets_to_preload[0])
            if not use_multiprocessing:
                preloaded_datasets = SingleCell._read_parallel(
                    datasets_to_preload, filename,
                    num_threads=num_threads)
    if use_multiprocessing:
        # Crucial: this runs only after the `with` block has closed the
        # HDF5 file, to avoid the child processes inheriting the HDF5
        # library's state associated with the open file handle when
        # forking. No handle is reopened afterwards, so nothing is leaked.
        preloaded_datasets = read_parallel_multiprocessing(
            datasets_to_preload, filename, num_threads=num_threads)
    # 3. Assemble `obsp` from the requested keys only
    obsp = {}
    with h5py.File(filename) as f:
        obsp_group = f['obsp']
        for key in keys:
            value = obsp_group[key]
            array_type = csr_array \
                if value.attrs['encoding-type'] == 'csr_matrix' \
                else csc_array
            obsp[key] = array_type(
                (SingleCell._read_dataset(value['data'],
                                          preloaded_datasets),
                 SingleCell._read_dataset(value['indices'],
                                          preloaded_datasets),
                 SingleCell._read_dataset(value['indptr'],
                                          preloaded_datasets)),
                shape=value.attrs['shape'])
    return obsp
[docs]
@staticmethod
def read_varp(h5ad_file: str | Path,
              *,
              keys: str | Iterable[str] | None = None,
              num_threads: int | np.integer = -1) -> \
        dict[str, csr_array | csc_array]:
    """
    Load just `varp` from an `.h5ad` file as a dictionary of sparse arrays.

    Args:
        h5ad_file: an `.h5ad` filename; a leading `~` is expanded
        keys: the key(s) of `varp` to load, returned in the order given;
              if `None`, load all keys
        num_threads: the number of threads to use when reading. By default
                     (`num_threads=-1`), use all available cores, as
                     determined by `os.cpu_count()`.

    Returns:
        A dictionary of sparse arrays of the data in `varp`, restricted to
        `keys` when specified. Each array is a `csr_array` when the key's
        `encoding-type` attribute is `'csr_matrix'`, and a `csc_array`
        otherwise. Empty if the file has no `varp` and `keys` is `None`.

    Raises:
        ValueError: if `h5ad_file` does not end with `'.h5ad'`, or if
                    `keys` names a key that is not in `varp` (regardless
                    of `num_threads`)
        FileNotFoundError: if `h5ad_file` does not exist
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    h5ad_file = str(h5ad_file)
    if not h5ad_file.endswith('.h5ad'):
        error_message = f".h5ad file {h5ad_file!r} must end with '.h5ad'"
        raise ValueError(error_message)
    # Every later open uses the expanded path, so that e.g. `~/x.h5ad`
    # is opened at the same location the existence check looked at
    filename = os.path.expanduser(h5ad_file)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {h5ad_file} does not exist'
        raise FileNotFoundError(error_message)
    if keys is not None:
        keys = to_tuple_checked(keys, 'keys', str, 'strings')
    # Check that `num_threads` is a positive integer or -1; if -1, set to
    # `os.cpu_count()`. Without this, the default of -1 would be handed
    # to the parallel readers as-is.
    num_threads = SingleCell._process_num_threads_static(num_threads)
    # Load `varp`, using a similar approach to loading the full `.h5ad`
    # file
    use_multiprocessing = False
    with h5py.File(filename) as f:
        if 'varp' not in f:
            if keys is not None:
                error_message = 'keys was specified, but varp is empty'
                raise ValueError(error_message)
            return {}
        varp_group = f['varp']
        # Resolve and validate the keys up front, so the single-threaded
        # and multithreaded paths accept and return exactly the same keys
        if keys is None:
            keys = tuple(varp_group.keys())
        else:
            for key_index, key in enumerate(keys):
                if key not in varp_group:
                    error_message = (
                        f'keys[{key_index}] is {key!r}, which is '
                        f'not a key of varp')
                    raise ValueError(error_message)
        if num_threads == 1:
            preloaded_datasets = {}
        else:
            # 1. Tabulate which HDF5 datasets need to be loaded: the three
            #    component arrays of each sparse matrix
            datasets_to_load = set()
            for key in keys:
                value = varp_group[key]
                datasets_to_load.add(value['data'].name)
                datasets_to_load.add(value['indices'].name)
                datasets_to_load.add(value['indptr'].name)
            # 2. Preload datasets in parallel. If any fixed-width dataset
            #    is chunked or compressed (indicated by a `None`
            #    `file_offset`), fall back to the multiprocessing-based
            #    reader.
            datasets_to_preload = \
                SingleCell._get_datasets_to_preload(f, datasets_to_load)
            use_multiprocessing = any(
                dataset[-1] is None for dataset in datasets_to_preload[0])
            if not use_multiprocessing:
                preloaded_datasets = SingleCell._read_parallel(
                    datasets_to_preload, filename,
                    num_threads=num_threads)
    if use_multiprocessing:
        # Crucial: this runs only after the `with` block has closed the
        # HDF5 file, to avoid the child processes inheriting the HDF5
        # library's state associated with the open file handle when
        # forking. No handle is reopened afterwards, so nothing is leaked.
        preloaded_datasets = read_parallel_multiprocessing(
            datasets_to_preload, filename, num_threads=num_threads)
    # 3. Assemble `varp` from the requested keys only
    varp = {}
    with h5py.File(filename) as f:
        varp_group = f['varp']
        for key in keys:
            value = varp_group[key]
            array_type = csr_array \
                if value.attrs['encoding-type'] == 'csr_matrix' \
                else csc_array
            varp[key] = array_type(
                (SingleCell._read_dataset(value['data'],
                                          preloaded_datasets),
                 SingleCell._read_dataset(value['indices'],
                                          preloaded_datasets),
                 SingleCell._read_dataset(value['indptr'],
                                          preloaded_datasets)),
                shape=value.attrs['shape'])
    return varp
[docs]
@staticmethod
def read_uns(h5ad_file: str | Path) -> UnsDict:
    """
    Read only the `uns` field of an `.h5ad` file.

    Args:
        h5ad_file: the path to an `.h5ad` file, as a string or
                   `pathlib.Path`; a leading `~` is expanded before the
                   file is opened

    Returns:
        The contents of `uns` as a dictionary, or an empty dictionary if
        the file has no `uns` key.

    Raises:
        ValueError: if `h5ad_file` does not end with `'.h5ad'`
        FileNotFoundError: if the (expanded) path does not exist
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    path_string = str(h5ad_file)
    if not path_string.endswith('.h5ad'):
        raise ValueError(
            f".h5ad file {path_string!r} must end with '.h5ad'")
    # Error messages report the path as given; only the existence check
    # and the open use the `~`-expanded form
    expanded_path = os.path.expanduser(path_string)
    if not os.path.exists(expanded_path):
        raise FileNotFoundError(
            f'.h5ad file {path_string} does not exist')
    with h5py.File(expanded_path) as hdf5_file:
        return SingleCell._read_uns(hdf5_file['uns']) \
            if 'uns' in hdf5_file else {}
@staticmethod
def _print_matrix_info(X: h5py.Group | h5py.Dataset, X_name: str) -> None:
    """
    Given a key of an `.h5ad` file representing a sparse or dense matrix,
    print its shape, data type and (if sparse) number of non-zero elements,
    followed by its first stored element when it has one.

    Empty matrices (a sparse matrix with no stored elements, or a dense
    matrix with a zero-length dimension) are summarized without the
    "first element" suffix, rather than failing on the out-of-range read.

    Args:
        X: the key in the `.h5ad` file representing the matrix, as a
           `Group` (treated as sparse) or `Dataset` (treated as dense)
           object
        X_name: the name of the key, used as the label of the printed line
    """
    if isinstance(X, h5py.Group):
        data = X['data']
        # Fall back to the `h5sparse_shape` attribute when `shape` is
        # absent
        shape = X.attrs['shape'] if 'shape' in X.attrs else \
            X.attrs['h5sparse_shape']
        dtype = str(data.dtype)
        nnz = data.shape[0]
        summary = (
            f'{X_name}: {shape[0]:,} × {shape[1]:,} {dtype} sparse '
            f'array with {nnz:,} non-zero elements')
        # Only read `data[0]` when there is at least one stored element
        if nnz > 0:
            summary += f' and first non-zero element = {data[0]:.6g}'
    else:
        shape = X.shape
        dtype = str(X.dtype)
        summary = (
            f'{X_name}: {shape[0]:,} × {shape[1]:,} {dtype} dense '
            f'matrix')
        # Only read `X[0, 0]` when both dimensions are non-empty
        if shape[0] > 0 and shape[1] > 0:
            summary += f' with first element = {X[0, 0]:.6g}'
    print(summary)
[docs]
@staticmethod
def ls(h5ad_file: str | Path) -> None:
    """
    Print the fields in an `.h5ad` file. This can be useful e.g. when
    deciding which count matrix to load via the `X_key` argument to
    `SingleCell()`.

    The top-level `X` and each entry of `layers` are summarized with
    `_print_matrix_info()`; the keys of `obs`, `var`, `obsm`, `varm`,
    `obsp`, `varp` and `uns` are listed by name, wrapped to the terminal
    width. If the file has a non-empty `raw` group, the same listing is
    printed for it under a `raw:` heading.

    Args:
        h5ad_file: an `.h5ad` filename

    Raises:
        ValueError: if `h5ad_file` does not end with `'.h5ad'`
        FileNotFoundError: if the (`~`-expanded) path does not exist
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    h5ad_file = str(h5ad_file)
    if not h5ad_file.endswith('.h5ad'):
        error_message = f".h5ad file {h5ad_file!r} must end with '.h5ad'"
        raise ValueError(error_message)
    filename = os.path.expanduser(h5ad_file)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {h5ad_file} does not exist'
        raise FileNotFoundError(error_message)
    # `os.get_terminal_size()` raises `OSError` when stdout is not
    # connected to a terminal (e.g. redirected output, some notebook
    # kernels), so fall back to 80 columns in that case too
    try:
        terminal_width = os.get_terminal_size().columns
    except (AttributeError, OSError, ValueError):
        terminal_width = 80
    attrs = 'obs', 'var', 'obsm', 'varm', 'obsp', 'varp', 'uns'

    def print_contents(parent, prefix: str,
                       continuation_indent: int) -> None:
        # Print `X`, the layers, and the field names found in `parent`,
        # labelling every line with `prefix`
        if 'X' in parent:
            SingleCell._print_matrix_info(parent['X'], f'{prefix}X')
        if 'layers' in parent:
            for layer_name, layer in parent['layers'].items():
                SingleCell._print_matrix_info(
                    layer, f'{prefix}layers[{layer_name!r}]')
        for attr in attrs:
            if attr not in parent:
                continue
            entries = parent[attr]
            # `obs`/`var` stored as a single structured-dtype dataset:
            # list the dtype's field names instead of the dataset itself
            if (attr == 'obs' or attr == 'var') and \
                    isinstance(entries, h5py.Dataset) and \
                    np.issubdtype(entries.dtype, np.void):
                entries = entries.dtype.fields
            if len(entries) > 0:
                print(fill(f'{prefix}{attr}: {", ".join(entries)}',
                           width=terminal_width,
                           subsequent_indent=' ' * (
                               len(attr) + continuation_indent)))

    with h5py.File(filename) as f:
        print_contents(f, '', 2)
        if 'raw' in f:
            raw = f['raw']
            if len(raw) > 0:
                print('raw:')
                print_contents(raw, ' ', 6)
def __eq__(self, other: SingleCell) -> bool:
    """
    Test for equality with another SingleCell dataset.

    Two datasets are equal when they agree on whether `X` is present, have
    the same `num_threads`, equal `obs` and `var`, the same keys and equal
    values in `obsm` and `varm`, equal `uns`, and (when present) equal
    `X`.

    Args:
        other: the other SingleCell dataset to test for equality with

    Returns:
        Whether the two SingleCell datasets are identical.

    Raises:
        TypeError: if `other` is not a SingleCell dataset
    """
    if not isinstance(other, SingleCell):
        error_message = (
            f'the left-hand operand of `==` is a SingleCell dataset, but '
            f'the right-hand operand has type {type(other).__name__!r}')
        raise TypeError(error_message)

    def fields_equal(field, other_field) -> bool:
        # Same keys, and for every key the same value type and an equal
        # value (NumPy arrays via `array_equal()`, everything else via
        # the value's own `.equals()`)
        return field.keys() == other_field.keys() and all(
            type(other_field[key]) is type(value) and
            (array_equal(other_field[key], value)
             if isinstance(value, np.ndarray) else
             other_field[key].equals(value))
            for key, value in field.items())

    # NOTE(review): `obsp` and `varp` are not compared here - confirm
    # whether that is intentional.
    # `X` is compared last, and only when present: when both datasets
    # lack `X`, calling `self._X.equals()` would fail on `None`.
    return (self._X is None) == (other._X is None) and \
        self._num_threads == other._num_threads and \
        self._obs.equals(other._obs) and \
        self._var.equals(other._var) and \
        fields_equal(self._obsm, other._obsm) and \
        fields_equal(self._varm, other._varm) and \
        SingleCell._eq_uns(self._uns, other._uns) and \
        (self._X is None or self._X.equals(other._X))
@staticmethod
def _eq_uns(uns: UnsDict,
            other_uns: UnsDict,
            different_order_ok: bool = False) -> bool:
    """
    Test whether two `uns` are equal.

    Keys are compared first; values are then compared key by key. Nested
    dictionaries are compared recursively, NumPy arrays with
    `array_equal()`, and everything else with `==`. Values of mismatched
    kinds (e.g. an array versus a scalar) are never equal.

    Args:
        uns: an `uns`
        other_uns: another `uns`
        different_order_ok: whether to consider `uns` and `other_uns` equal
                            when they have the same keys and values, but in
                            a different order

    Returns:
        Whether `uns` and `other_uns` are equal.
    """
    # The key check is computed on its own so that the value comparison
    # below runs for both settings of `different_order_ok`; a conditional
    # expression spanning the whole `... and all(...)` would skip it.
    # NOTE(review): for plain dicts, key views compare like sets, so the
    # `different_order_ok=False` branch does not itself enforce key
    # order - confirm whether the `uns` container's `keys()` differs.
    keys_match = set(uns.keys()) == set(other_uns.keys()) \
        if different_order_ok else uns.keys() == other_uns.keys()
    if not keys_match:
        return False
    for key, value in uns.items():
        other_value = other_uns[key]
        if isinstance(value, dict):
            if not (isinstance(other_value, dict) and
                    SingleCell._eq_uns(value, other_value,
                                       different_order_ok)):
                return False
        elif isinstance(value, np.ndarray):
            # Never fall through to `==` with an array operand: the
            # result would be an array with an ambiguous truth value
            if not (isinstance(other_value, np.ndarray) and
                    array_equal(value, other_value)):
                return False
        elif isinstance(other_value, (dict, np.ndarray)) or \
                not value == other_value:
            return False
    return True
@staticmethod
def _getitem_error(item: Indexer | tuple[Indexer, Indexer]) -> NoReturn:
    """
    Report an invalid indexer by raising a `ValueError` that names the
    type(s) the caller indexed with. This function never returns.

    Args:
        item: the indexer that was rejected

    Raises:
        ValueError: always
    """
    type_names = tuple(
        type(element).__name__ for element in to_tuple(item))
    # Show a bare type name, rather than a 1-tuple, for a single indexer
    described = type_names[0] if len(type_names) == 1 else type_names
    raise ValueError(
        f'SingleCell indices must be cells, a length-1 tuple of (cells,), '
        f'or a length-2 tuple of (cells, genes). Cells and genes must '
        f'each be a string or integer; a slice of strings or integers; or '
        f'a list, NumPy array, or polars Series of strings, integers, or '
        f'Booleans. You indexed with: {described}.')
@staticmethod
def _getitem_by_string(df: pl.DataFrame, string: str) -> int:
    """
    Look up the position of the single row whose first-column value equals
    `string`.

    Args:
        df: a DataFrame (`obs` or `var`)
        string: the value to search for in the first column of `df`

    Returns:
        The integer index of the matching row.

    Raises:
        KeyError: if no row matches `string`. Any other error raised by
                  `DataFrame.row()` (e.g. when several rows match) is
                  propagated unchanged.
    """
    name_column = df.columns[0]
    # Pair every name with its row number, so the matching row carries
    # its own position as its first element
    row_numbers = pl.int_range(pl.len(), dtype=pl.Int32)\
        .alias('_SingleCell_getitem')
    numbered_names = df.select(row_numbers, name_column)
    try:
        matching_row = numbered_names.row(
            by_predicate=pl.col(name_column) == string)
    except pl.exceptions.NoRowsReturnedError:
        raise KeyError(string)
    return matching_row[0]
@staticmethod
def _getitem_process(item: Indexer | tuple[Indexer, Indexer], index: int,
                     df: pl.DataFrame) -> list[int] | slice | pl.Series:
    """
    Normalize one element of the tuple passed to `__getitem__()` into an
    indexer over rows of `df`.

    Args:
        item: the full tuple of indexers
        index: which element of `item` to process
        df: the DataFrame (`obs` or `var`) whose first column is used to
            resolve string names to positions

    Returns:
        - a one-element list, for a single integer or string
        - a slice with any string endpoints replaced by their positions
        - a polars Series of integer positions or Booleans, for tuples,
          lists, NumPy arrays and polars Series

    Raises:
        ValueError: if the indexer has an unsupported type, or contains
                    missing values
        KeyError: if a string in the indexer is not found in the first
                  column of `df`
    """
    selector = item[index]
    integer_types = (int, np.integer)
    # Scalars become one-element lists
    if isinstance(selector, integer_types):
        return [selector]
    if isinstance(selector, str):
        return [SingleCell._getitem_by_string(df, selector)]
    # Slices: resolve string endpoints, validate the rest
    if isinstance(selector, slice):
        endpoints = []
        for endpoint in (selector.start, selector.stop):
            if isinstance(endpoint, str):
                endpoint = SingleCell._getitem_by_string(df, endpoint)
            elif endpoint is not None and \
                    not isinstance(endpoint, integer_types):
                SingleCell._getitem_error(item)
            endpoints.append(endpoint)
        stride = selector.step
        if stride is not None and not isinstance(stride, integer_types):
            SingleCell._getitem_error(item)
        return slice(endpoints[0], endpoints[1], stride)
    if not isinstance(selector, (tuple, list, np.ndarray, pl.Series)):
        SingleCell._getitem_error(item)
    # Array-likes are handled uniformly as a polars Series
    values = pl.Series(selector)
    if values.is_null().any():
        error_message = 'your indexer contains missing values'
        raise ValueError(error_message)
    values_dtype = values.dtype
    if values_dtype in (pl.String, pl.Enum, pl.Categorical):
        names_column = df.columns[0]
        names_dtype = df[:, 0].dtype
        if values_dtype != names_dtype:
            values = values.cast(names_dtype)
        # Left-join the requested names against the numbered names of
        # `df`; names that are absent from `df` get a null position
        numbered = df.with_columns(_SingleCell_index=pl.int_range(
            pl.len(), dtype=pl.UInt32))
        positions = values\
            .to_frame(names_column)\
            .join(numbered, on=names_column, how='left')\
            ['_SingleCell_index']
        if positions.null_count():
            # Report the first name that could not be found
            error_message = values.filter(positions.is_null())[0]
            raise KeyError(error_message)
        return positions
    if values_dtype.is_integer() or values_dtype == pl.Boolean:
        return values
    SingleCell._getitem_error(item)
def __getitem__(self, item: Indexer | tuple[Indexer, Indexer]) -> \
        SingleCell:
    """
    Subset to specific cell(s) and/or gene(s).

    Index with a tuple of `(cells, genes)`. If `cells` and `genes` are
    integers, arrays/lists/slices of integers, or arrays/lists of Booleans,
    the result will be a SingleCell dataset subset to `X[cells, genes]`,
    `obs[cells]`, `var[genes]`, `obsm[cells]`, `varm[genes]`,
    `obsp[cells][:, cells]`, and `varp[genes][:, genes]`. However, `cells`
    and/or `genes` can instead be strings (or arrays or slices of strings),
    in which case they refer to the first column of `obs` (`obs_names`)
    and/or `var` (`var_names`), respectively.

    Args:
        item: the item to index with

    Returns:
        A new SingleCell dataset subset to the specified cells and/or
        genes.

    Raises:
        TypeError: if `item` is not an integer, string, slice, tuple,
                   list, NumPy array, or polars Series
        ValueError: if `item` is a tuple that does not have 1 or 2
                    elements, or one of its elements is not a valid
                    indexer
        KeyError: if a cell or gene name is not found

    Examples:
        Subset to one cell (all genes):
        >>> sc[2]
        >>> sc['CGAATTGGTGACAGGT-L8TX_210916_01_B05-1131590416']

        Subset to one gene (all cells):
        >>> sc[:, 13196]
        >>> sc[:, 'APOE']

        Subset to one cell and one gene:
        >>> sc[2, 13196]
        >>> sc['CGAATTGGTGACAGGT-L8TX_210916_01_B05-1131590416', 'APOE']

        Subset to a range of cells and genes:
        >>> sc[2:6, 13196:34268]
        >>> sc['CGAATTGGTGACAGGT-L8TX_210916_01_B05-1131590416':
        ...    'CCCTCTCAGCAGCCTC-L8TX_211007_01_A09-1135034522',
        ...    'APOE':'TREM2']

        Subset to specific cells or genes:
        >>> sc[[2, 5, 9]]
        >>> sc[:, ['APOE', 'TREM2']]

        Boolean indexing:
        >>> sc[sc.obs['cell_type'] == 'microglia']
        >>> sc[:, sc.var['gene_type'] == 'protein_coding']

        Use different indexing types for cells and genes:
        >>> sc[sc.obs['batch'] == 'A', ['APOE', 'TREM2']]
    """
    # NumPy integers are accepted here as well, matching what
    # `_getitem_process()` accepts for the elements of a tuple
    if not isinstance(item, (int, np.integer, str, slice, tuple, list,
                             np.ndarray, pl.Series)):
        error_message = (
            f'SingleCell datasets must be indexed with an integer, '
            f'string, slice, tuple, list, NumPy array, or polars Series, '
            f'but you tried to index with an object of type '
            f'{type(item).__name__!r}')
        raise TypeError(error_message)
    if isinstance(item, tuple):
        if not 1 <= len(item) <= 2:
            self._getitem_error(item)
    else:
        item = item,

    def subset_axis(position, df, matrices, pairwise):
        # Subset one axis: the DataFrame (`obs`/`var`), the per-item
        # matrices (`obsm`/`varm`) and the pairwise matrices
        # (`obsp`/`varp`). Also returns the indexer in a form usable on
        # NumPy/sparse arrays, and whether that indexer is a slice.
        indexer = self._getitem_process(item, position, df)
        is_slice = isinstance(indexer, slice)
        if isinstance(indexer, pl.Series):
            # Boolean masks must go through `filter()`; integer
            # positions through ordinary indexing
            is_boolean = indexer.dtype == pl.Boolean
            df_subset = df.filter(indexer) if is_boolean else df[indexer]
            indexer_NumPy = indexer.to_numpy()
        else:
            is_boolean = False
            df_subset = df[indexer]
            indexer_NumPy = indexer if is_slice else np.asarray(indexer)
        matrices_subset = {
            key: (value.filter(indexer) if is_boolean else value[indexer])
            if isinstance(value, pl.DataFrame) else value[indexer_NumPy]
            for key, value in matrices.items()} if matrices else {}
        if not pairwise:
            pairwise_subset = {}
        elif is_slice:
            pairwise_subset = {
                key: value[indexer_NumPy, indexer_NumPy]
                for key, value in pairwise.items()}
        else:
            pairwise_subset = {
                key: value[ix_symmetric(indexer_NumPy)]
                for key, value in pairwise.items()}
        return df_subset, indexer_NumPy, is_slice, matrices_subset, \
            pairwise_subset

    obs, rows_NumPy, rows_is_slice, obsm, obsp = subset_axis(
        0, self._obs, self._obsm, self._obsp)
    if len(item) == 1:
        # Cells only: everything gene-related is passed through as-is
        return SingleCell(
            X=self._X[rows_NumPy] if self._X is not None else None,
            obs=obs, var=self._var, obsm=obsm, varm=self._varm, obsp=obsp,
            varp=self._varp, uns=self._uns, num_threads=self._num_threads)
    var, cols_NumPy, cols_is_slice, varm, varp = subset_axis(
        1, self._var, self._varm, self._varp)
    if self._X is None:
        X = None
    elif rows_is_slice or cols_is_slice:
        X = self._X[rows_NumPy, cols_NumPy]
    else:
        # Two array indexers: take their outer product rather than
        # pairing them element by element
        X = self._X[np.ix_(rows_NumPy, cols_NumPy)]
    return SingleCell(X=X, obs=obs, var=var, obsm=obsm, varm=varm,
                      obsp=obsp, varp=varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def cell(self,
cell: str,
/,
*,
num_threads: int | np.integer | None = None) -> \
np.ndarray:
"""
Get the row of `X` corresponding to a single cell, based on the cell's
name in `obs_names`.
Args:
cell: the name of the cell in `obs_names`
num_threads: the number of threads to use when retrieving the
row of `X`. Set `num_threads=-1` to use all available
cores, as determined by `os.cpu_count()`. By default
(`num_threads=None`), use `self.num_threads` cores
when `X` is a CSC array and 1 thread when `X` is a CSR
array. Cannot be specified when `X` is a CSR array,
since there is no benefit to parallelism in that case.
Returns:
The corresponding row of `X`, as a dense 1D NumPy array with zeros
included.
"""
# Check that `X` is present
if self._X is None:
error_message = 'X is None, so getting a row of X is not possible'
raise ValueError(error_message)
# Check that `num_threads` is not specified when `X` is a CSR array
if num_threads is not None and isinstance(self._X, csr_array):
error_message = (
'num_threads cannot be specified when X is a CSR array, since '
'cell() will always run single-threaded')
raise ValueError(error_message)
# Check that `num_threads` is a positive integer, -1 or `None`; if
# `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
num_threads = self._process_num_threads(num_threads)
row_index = SingleCell._getitem_by_string(self._obs, cell)
original_num_threads = self._X._num_threads
try:
self._X._num_threads = num_threads
return self._X[row_index].toarray().squeeze()
finally:
self._X._num_threads = original_num_threads
[docs]
def gene(self,
gene: str,
/,
*,
num_threads: int | np.integer | None = None) -> \
np.ndarray:
"""
Get the column of `X` corresponding to a single gene, based on the
gene's name in `var_names`.
Args:
gene: the name of the gene in `var_names`
num_threads: the number of threads to use when retrieving the
row of `X`. Set `num_threads=-1` to use all available
cores, as determined by `os.cpu_count()`. By default
(`num_threads=None`), use `self.num_threads` cores
when `X` is a CSR array and 1 thread when `X` is a CSC
array. Cannot be specified when `X` is a CSC array,
since there is no benefit to parallelism in that case.
Returns:
The corresponding column of `X`, as a dense 1D NumPy array with
zeros included.
"""
# Check that `X` is present
if self._X is None:
error_message = \
'X is None, so getting a column of X is not possible'
raise ValueError(error_message)
# Check that `num_threads` is not specified when `X` is a CSC array
if num_threads is not None and isinstance(self._X, csc_array):
error_message = (
'num_threads cannot be specified when X is a CSC array, since '
'cell() will always run single-threaded')
raise ValueError(error_message)
# Check that `num_threads` is a positive integer, -1 or `None`; if
# `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
num_threads = self._process_num_threads(num_threads)
column_index = SingleCell._getitem_by_string(self._var, gene)
original_num_threads = self._X._num_threads
try:
self._X._num_threads = num_threads
return self._X[:, column_index].toarray().squeeze()
finally:
self._X._num_threads = original_num_threads
def __len__(self) -> int:
    """
    Return the length of this SingleCell dataset, defined as its number of
    cells (the number of rows of `obs`).

    Returns:
        The number of cells.
    """
    num_cells = len(self._obs)
    return num_cells
def __repr__(self) -> str:
    """
    Get a string representation of this SingleCell dataset.

    The first line gives the sparse format, the number of cells and genes,
    and the number and data type of the non-zero entries of `X` (or
    `no X`). It is followed by one wrapped line per non-empty field among
    `obs`, `var`, `obsm`, `varm`, `obsp`, `varp` and `uns`, listing that
    field's columns or keys.

    Returns:
        A string summarizing the dataset.
    """
    matrix_format = 'CSR' if isinstance(self._X, csr_array) else 'CSC'
    num_cells = len(self._obs)
    num_genes = len(self._var)
    descr = (
        f'SingleCell dataset in {matrix_format} format '
        f'with {num_cells:,} {plural("cell", num_cells)} (obs), '
        f'{num_genes:,} {plural("gene", num_genes)} (var), and ')
    if self._X is not None:
        nnz = self._X.nnz
        descr += (f'{nnz:,} non-zero {self._X.dtype} '
                  f'{"entries" if nnz != 1 else "entry"} (X)')
    else:
        descr += 'no X'
    # `os.get_terminal_size()` raises `OSError` when stdout is not
    # connected to a terminal (e.g. redirected output, some notebook
    # kernels); `repr()` must not fail there, so fall back to 80 columns
    try:
        terminal_width = os.get_terminal_size().columns
    except (AttributeError, OSError, ValueError):
        terminal_width = 80
    for attr in 'obs', 'var', 'obsm', 'varm', 'obsp', 'varp', 'uns':
        field = getattr(self, attr)
        # DataFrames are listed by column name, dictionaries by key
        entries = field.columns if attr == 'obs' or attr == 'var' \
            else field
        if len(entries) > 0:
            descr += '\n' + fill(
                f' {attr}: {", ".join(entries)}',
                width=terminal_width,
                subsequent_indent=' ' * (len(attr) + 6))
    return descr
def __iter__(self):
    """
    Refuse iteration over a SingleCell dataset.

    Because the class defines `__getitem__()`, Python would otherwise fall
    back to iterating by calling it with successive integers; defining
    `__iter__()` to raise makes `iter()` and `for` loops fail immediately
    instead.

    Raises:
        TypeError: always
    """
    raise TypeError('SingleCell datasets do not support iteration')
@property
def shape(self) -> tuple[int, int]:
    """
    The dimensions of this SingleCell dataset, as a length-2 tuple of
    `(number of cells, number of genes)`, i.e. the number of rows of `obs`
    and of `var`.
    """
    num_cells = len(self._obs)
    num_genes = len(self._var)
    return num_cells, num_genes
@staticmethod
def _save_h5ad_dataframe(h5ad_file: h5py.File,
                         df: pl.DataFrame,
                         key: str,
                         preserve_strings: bool) -> None:
    """
    Write a DataFrame (e.g. `obs` or `var`) into an `.h5ad` file as a
    group with one dataset or subgroup per column.

    Columns are encoded by data type:
    - String: a `string-array` dataset, unless the column contains nulls
      or (when `preserve_strings=False`) duplicate values, in which case
      it is first converted to an Enum and written as a categorical
    - Enum or Categorical: a `categorical` subgroup with `codes` (-1 for
      null) and `categories`
    - float: an `array` dataset, with nulls written as NaN
    - Boolean or integer: an `array` dataset, or a `nullable-boolean` /
      `nullable-integer` subgroup with `values` and `mask` when the
      column contains nulls

    Args:
        h5ad_file: an `h5py.File` open in write mode
        df: the DataFrame to write, e.g. `obs` or `var`
        key: the key to create in `h5ad_file`, e.g. `'obs'` or `'var'`
        preserve_strings: if `False`, encode string columns with duplicate
                          values as Enums to save space; if `True`,
                          preserve these columns as string columns
    """

    def set_encoding(node, encoding_type: str,
                     encoding_version: str = '0.2.0') -> None:
        # Tag a group or dataset with its encoding metadata
        node.attrs['encoding-type'] = encoding_type
        node.attrs['encoding-version'] = encoding_version

    # The first column is recorded as the index; the rest, in order, as
    # the regular columns
    frame_group = h5ad_file.create_group(key)
    frame_group.attrs['_index'] = df.columns[0]
    frame_group.attrs['column-order'] = df.columns[1:]
    set_encoding(frame_group, 'dataframe')
    for series in df:
        name = series.name
        if series.dtype == pl.String:
            if not (series.null_count() or not preserve_strings and
                    series.is_duplicated().any()):
                # Plain string column: write as-is
                set_encoding(
                    frame_group.create_dataset(name,
                                               data=series.to_numpy()),
                    'string-array')
                continue
            # Otherwise convert to an Enum over the distinct non-null
            # values and fall through to the categorical branch
            series = series.cast(
                pl.Enum(series.unique(maintain_order=True).drop_nulls()))
        series_dtype = series.dtype
        if series_dtype == pl.Enum or series_dtype == pl.Categorical:
            is_Enum = series_dtype == pl.Enum
            categorical_group = frame_group.create_group(name)
            set_encoding(categorical_group, 'categorical')
            categorical_group.attrs['ordered'] = is_Enum
            categories = series.cat.get_categories()
            if not is_Enum:
                series = series.cast(pl.Enum(categories))
            codes = series.to_physical().fill_null(-1)
            categorical_group.create_dataset('codes',
                                             data=codes.to_numpy())
            if len(categories) == 0:
                # No categories: create an empty variable-length string
                # dataset explicitly
                categorical_group.create_dataset(
                    'categories', shape=(0,),
                    dtype=h5py.special_dtype(vlen=str))
            else:
                categorical_group.create_dataset(
                    'categories', data=categories.to_numpy())
        elif series_dtype.is_float():
            # Nullable floats are not supported, so convert `null` to `NaN`
            set_encoding(
                frame_group.create_dataset(
                    name, data=series.fill_null(np.nan).to_numpy()),
                'array')
        elif series.null_count():
            # Boolean or integer with nulls: store placeholder values
            # alongside a mask marking which entries are null
            is_Boolean = series_dtype == pl.Boolean
            nullable_group = frame_group.create_group(name)
            set_encoding(
                nullable_group,
                f'nullable-{"boolean" if is_Boolean else "integer"}',
                '0.1.0')
            nullable_group.create_dataset(
                'values',
                data=series.fill_null(False if is_Boolean else 1)
                .to_numpy())
            nullable_group.create_dataset(
                'mask', data=series.is_null().to_numpy())
        else:
            # Boolean or integer without nulls: store as a regular array
            set_encoding(
                frame_group.create_dataset(name, data=series.to_numpy()),
                'array')
@staticmethod
def _save_h5Seurat_dataframe(h5ad_file: h5py.File,
                             df: pl.DataFrame,
                             key: str,
                             preserve_strings: bool) -> None:
    """
    Write a DataFrame (e.g. `obs` or `var`) into an `.h5Seurat` file as a
    group with one dataset or subgroup per column.

    Columns are encoded by data type:
    - String: a plain dataset, unless the column contains nulls or (when
      `preserve_strings=False`) duplicate values, in which case it is
      first converted to an Enum and written like one
    - Enum or Categorical: a subgroup with 1-based `values` and `levels`
    - float: a dataset with nulls written as NaN
    - Boolean: an Int32 dataset with nulls written as 2
    - integer: a dataset with nulls written as -2147483648

    Args:
        h5ad_file: an `h5py.File` open in write mode (despite the
                   parameter's name, this is the `.h5Seurat` file)
        df: the DataFrame to write, e.g. `obs` or `var`
        key: the key to create in `h5ad_file`, e.g. `'meta.data'` or
             `'RNA/meta.features'`
        preserve_strings: if `False`, encode string columns with duplicate
                          values as Enums to save space; if `True`,
                          preserve these columns as string columns
    """
    # The first column is recorded as the index; the rest, in order, as
    # the column names
    frame_group = h5ad_file.create_group(key)
    column_names = df.columns
    frame_group.attrs['_index'] = np.array([column_names[0]], dtype=object)
    frame_group.attrs['colnames'] = column_names[1:]
    # Record which columns are Boolean, since they are stored as integers
    Boolean_columns = \
        pl.selectors.expand_selector(df, pl.selectors.by_dtype(pl.Boolean))
    if len(Boolean_columns) > 0:
        frame_group.attrs['logicals'] = Boolean_columns
    for series in df:
        name = series.name
        if series.dtype == pl.String:
            if not (series.null_count() or not preserve_strings and
                    series.is_duplicated().any()):
                # Plain string column: write as-is
                frame_group.create_dataset(name, data=series.to_numpy())
                continue
            # Otherwise convert to an Enum over the distinct non-null
            # values and fall through to the Enum/Categorical branch
            series = series.cast(
                pl.Enum(series.unique(maintain_order=True).drop_nulls()))
        series_dtype = series.dtype
        if series_dtype == pl.Enum or series_dtype == pl.Categorical:
            factor_group = frame_group.create_group(name)
            levels = series.cat.get_categories()
            if series_dtype != pl.Enum:
                series = series.cast(pl.Enum(levels))
            # Shift the codes to be 1-based; nulls become -2147483648
            # (presumably R's integer NA - TODO confirm)
            one_based_codes = \
                (series.to_physical() + 1).fill_null(-2147483648)
            factor_group.create_dataset('values',
                                        data=one_based_codes.to_numpy())
            if len(levels) == 0:
                # No levels: create an empty variable-length string
                # dataset explicitly
                factor_group.create_dataset(
                    'levels', shape=(0,),
                    dtype=h5py.special_dtype(vlen=str))
            else:
                factor_group.create_dataset('levels',
                                            data=levels.to_numpy())
            continue
        # Remaining types are written as a single dataset, with nulls
        # replaced by a type-specific placeholder
        if series_dtype.is_float():
            series = series.fill_null(np.nan)
        elif series_dtype == pl.Boolean:
            series = series.cast(pl.Int32).fill_null(2)
        else:  # integer
            series = series.fill_null(-2147483648)
        frame_group.create_dataset(name, data=series.to_numpy())
[docs]
def save(self,
filename: str | Path,
/,
*,
assay: str = 'RNA',
X_key: str = 'counts',
overwrite: bool = False,
preserve_strings: bool = False,
empty_X: bool = False,
v3: bool = False,
sce: bool = False) -> None:
"""
Save this SingleCell dataset to a file. File format will be inferred
from the file extension (e.g. `.h5ad`).
Args:
filename: an AnnData `.h5ad` file, Seurat `.rds` or `.h5Seurat`
file, SingleCellExperiment `.rds` file, or 10x `.h5` or
`.mtx`/`.mtx.gz` file to save to. If the extension is
`.rds`, the `sce` argument will determine whether to save
to a Seurat or a SingleCellExperiment object.
- When saving to a Seurat `.rds` file, to match the
requirements of Seurat objects, the `'X_'` prefix
(often used by Scanpy) will be removed from each key of
obsm where it is present (e.g. `'X_umap'` will become
`'umap'`).
- When saving to a Seurat `.rds` file, Seurat will add
`'orig.ident'`, `'nCount_RNA'` and `'nFeature_RNA'` as
gene-level metadata by default; you can disable the
calculation of the latter two columns with:
```python
from ryp import r
r('options(Seurat.object.assay.calcn = FALSE)')
```
- When saving to a Seurat `.rds` or `.h5Seurat` file or a
SingleCellExperiment `.rds` file, `varm` will not be
saved.
- When saving to a 10x `.h5` file, `obs['barcodes']`,
`var['feature_type']`, `var['genome']`, `var['id']`,
and `var['name']` must all exist. Only `X` and these
columns will be saved, along with whichever of
`var['pattern']`, `var['read']`, and `var['sequence']`
exist. All of these columns (if they exist) must be
String, Enum, Categorical, or integer.
- When saving to a 10x `.mtx` file, `barcodes.tsv` and
`features.tsv` will be created in the same directory.
Only `X`, `obs` and `var` will be saved.
- When saving to a 10x `.mtx.gz` file, `barcodes.tsv.gz`
and `features.tsv.gz` will be created in the same
directory. Only `X`, `obs` and `var` will be saved.
assay: when saving to a Seurat `.rds` or `.h5Seurat` file, the name
to use for the active assay
X_key: when saving to a Seurat `.rds` or `.h5Seurat` file, the name
of the layer within the active assay to save `X` to; must be
`'counts'` or `'data'`. When saving to a
SingleCellExperiment `.rds` file, the name of the layer
within `@assays@data` to save `X` to.
overwrite: if `False`, raises an error if (any of) the file(s)
exist; if `True`, overwrites them
preserve_strings: if `False`, encode string columns with duplicate
values as Enums to save space, when saving to
AnnData `.h5ad` or Seurat or SingleCellExperiment
`.rds`; if `True`, preserve these columns as
string columns. (Regardless of the value of
`preserve_strings`, String columns with `null`
values will be encoded as Enums when saving to
`.h5ad`, since the `.h5ad` format cannot
represent them otherwise.)
empty_X: if `True`, allow saving of SingleCell datasets with
missing counts (i.e. where `self.X is None`), by saving an
empty, dummy count matrix with no non-zero entries
v3: if `True`, use Seurat's version 3 format instead of the
more current version 5 format when saving to a Seurat `.rds`
file. When `v3=True`, the Seurat object's assay is created with
`CreateAssayObject()` instead of `CreateAssay5Object()`.
`v3=True` cannot be specified when saving to `.h5Seurat`, since
`.h5Seurat` files
[do not support](https://github.com/mojaveazure/seurat-disk/issues/147)
Seurat version 5 and so saving to `.h5Seurat` will always save
to version 3.
sce: if `True` and the extension of filename is `.rds`, save to a
SingleCellExperiment object instead of a Seurat object
Examples:
Save to an AnnData `.h5ad` file:
>>> sc.save('data.h5ad')
Overwrite an existing file:
>>> sc.save('data.h5ad', overwrite=True)
Save to a Seurat `.h5Seurat` file:
>>> sc.save('seurat_obj.h5Seurat')
Save to a Seurat `.rds` file:
>>> sc.save('seurat_obj.rds')
Save to a Seurat `.rds` file using the `'data'` layer instead of
`'counts'`:
>>> sc.save('seurat_obj.rds', X_key='data')
Save to a Seurat version 3 `.rds` file:
>>> sc.save('seurat_v3.rds', v3=True)
Save to a SingleCellExperiment `.rds` file:
>>> sc.save('sce_obj.rds', sce=True)
Save to a 10x Genomics `.h5` file (requires `obs` to contain a
`'barcodes'` column and `var` to contain `'feature_type'`,
`'genome'`, `'id'`, and `'name'` columns):
>>> sc.save('matrix.h5')
Save to a 10x Genomics `.mtx.gz` file (creates `barcodes.tsv.gz`
and `features.tsv.gz` in the same directory; unlike when saving to
`.h5`, `obs` and `var` are not required to contain any specific
columns):
>>> sc.save('matrix.mtx.gz')
Allow saving when `X` is missing by writing an empty count matrix:
>>> sc.save('empty.h5ad', empty_X=True)
Preserve string columns instead of encoding them as Enums:
>>> sc.save('data_preserve_strings.h5ad',
... preserve_strings=True)
"""
# Check that `filename` is a string or `Path`; convert it to a string
# and expand `~` into home directories
check_type(filename, 'filename', (str, Path),
'a string or pathlib.Path')
filename = str(filename)
filename_expanduser = os.path.expanduser(filename)
# Check that `overwrite` is Boolean
check_type(overwrite, 'overwrite', bool, 'Boolean')
# Raise an error if the filename already exists, unless
# `overwrite=True`
if not overwrite and os.path.exists(filename_expanduser):
error_message = (
f'filename {filename!r} already exists; set overwrite=True '
f'to overwrite')
raise FileExistsError(error_message)
# Check that `empty_X` is `True` when `X` is absent, and vice versa.
# If `empty_X` is `True`, make an empty R-compatible count matrix.
check_type(empty_X, 'empty_X', bool, 'Boolean')
if empty_X:
if self._X is not None:
error_message = \
'X is not None, so empty_X=True cannot be specified'
raise ValueError(error_message)
counts = csr_array((np.array([], dtype=np.float64),
np.array([], dtype=np.int32),
np.array([0], dtype=np.int32)), shape=(0, 0))
else:
if self._X is None:
error_message = (
'X is None, so saving requires creating an empty count '
'matrix; specify empty_X=True to do this automatically '
'when saving')
raise ValueError(error_message)
counts = self._X
# Get the file type from the extension, raising an error if invalid
is_h5ad = filename.endswith('.h5ad')
is_rds = filename.endswith('.rds')
is_h5Seurat = filename.endswith('.h5Seurat') or \
filename.endswith('.h5seurat')
is_h5 = filename.endswith('.h5')
is_hdf5 = is_h5ad or is_h5 or is_h5Seurat
is_mtx = filename.endswith('.mtx')
is_mtx_gz = filename.endswith('.mtx.gz')
is_Seurat = is_h5Seurat or is_rds and not sce
if not (is_hdf5 or is_mtx or is_mtx_gz or is_rds):
error_message = (
f"filename {filename!r} does not end with '.h5ad', '.rds', "
f"'.h5Seurat'/'.h5seurat', '.h5', or '.mtx.gz'")
raise ValueError(error_message)
# Check that `assay` is a string, and that `assay` is not specified
# (i.e. retains its default value) unless saving to a Seurat object
check_type(assay, 'assay', str, 'a string')
if not is_Seurat and assay != 'RNA':
error_message = \
'assay cannot be specified unless saving to a Seurat object'
raise ValueError(error_message)
# Check that `X_key` is `'counts'` or `'data'` when saving to a Seurat
# object, any string when saving to a SingleCellExperiment object, and
# not specified (i.e. retains its default value) otherwise
check_type(X_key, 'X_key', str, 'a string')
if is_Seurat:
if X_key != 'counts' and X_key != 'data':
error_message = (
f"when saving to a Seurat object, X_key must be 'counts' "
f"or 'data', not {X_key!r}")
raise ValueError(error_message)
elif not sce and X_key != 'counts':
error_message = (
'X_key cannot be specified unless saving to a Seurat or '
'SingleCellExperiment object')
raise ValueError(error_message)
# Check that `preserve_strings` and `sce` are Boolean, and that `sce`
# is only `True` when saving to an `.rds` file
check_type(preserve_strings, 'preserve_strings', bool, 'Boolean')
check_type(sce, 'sce', bool, 'Boolean')
if sce and not is_rds:
error_message = 'sce can only be True when saving to an .rds file'
raise ValueError(error_message)
# Check that `v3` is Boolean, and only `True` when saving to an `.rds`
# file with `sce=False`
check_type(v3, 'v3', bool, 'Boolean')
if v3 and (not is_rds or sce):
error_message = (
'v3=True can only be specified when saving to a Seurat .rds '
'file')
raise ValueError(error_message)
# Raise an error if `obs` or `var` (or, if saving to `.h5ad`, DataFrame
# keys of `obsm` or `varm`) contain columns with unsupported data types
# (anything but float, int, String, Enum, Categorical, Boolean)
valid_dtypes = pl.FLOAT_DTYPES | pl.INTEGER_DTYPES | \
{pl.String, pl.Enum, pl.Categorical, pl.Boolean}
for df, df_name in (self._obs, 'obs'), (self._var, 'var'):
for column, dtype in df.schema.items():
if dtype.base_type() not in valid_dtypes:
error_message = (
f'{df_name}[{column!r}] has the data type '
f'{dtype.base_type()!r}, which is not supported when '
f'saving')
raise TypeError(error_message)
if is_h5ad:
for field, field_name in (self._obsm, 'obsm'), \
(self._varm, 'varm'):
for key, value in field.items():
if not isinstance(value, pl.DataFrame):
continue
for column, dtype in value.schema.items():
if dtype.base_type() not in valid_dtypes:
error_message = (
f'{field}[{key!r}][{column!r}] has the data '
f'type {dtype.base_type()!r}, which is not '
f'supported when saving')
raise TypeError(error_message)
# Raise an error if `obsm`, `varm` or `uns` contain NumPy arrays with
# unsupported data types (`datetime64`, `timedelta64`, unstructured
# `void`). Do not specifically check `dtype=object` to avoid extra
# overhead.
for field, field_name in \
(self._obsm, 'obsm'), (self._varm, 'varm'), (self._uns, 'uns'):
for key, value in field.items():
if not isinstance(value, np.ndarray):
continue
if value.dtype.type == np.void and value.dtype.names is None:
error_message = (
f'{field_name}[{key!r}] is an unstructured void '
f'array, which is not supported when saving')
raise TypeError(error_message)
elif value.dtype == np.datetime64:
error_message = (
f'{field_name}[{key!r}] is a datetime64 array, which '
f'is not supported when saving')
raise TypeError(error_message)
elif value.dtype == np.timedelta64:
error_message = (
f'{field_name}[{key!r}] is a timedelta64 array, which '
f'is not supported when saving')
raise TypeError(error_message)
# Save, depending on the file extension
if is_hdf5:
try:
with h5py.File(filename_expanduser, 'w') as hdf5_file:
if is_h5ad:
# Add top-level metadata
hdf5_file.attrs['encoding-type'] = 'anndata'
hdf5_file.attrs['encoding-version'] = '0.1.0'
# Save `obs` and `var`
SingleCell._save_h5ad_dataframe(
hdf5_file, self._obs, 'obs', preserve_strings)
SingleCell._save_h5ad_dataframe(
hdf5_file, self._var, 'var', preserve_strings)
# Save `obsm`
if self._obsm:
obsm = hdf5_file.create_group('obsm')
obsm.attrs['encoding-type'] = 'dict'
obsm.attrs['encoding-version'] = '0.1.0'
for key, value in self._obsm.items():
if isinstance(value, pl.DataFrame):
SingleCell._save_h5ad_dataframe(
hdf5_file, value, f'obsm/{key}',
preserve_strings)
else:
obsm.create_dataset(key, data=value)
# Save `varm`
if self._varm:
varm = hdf5_file.create_group('varm')
varm.attrs['encoding-type'] = 'dict'
varm.attrs['encoding-version'] = '0.1.0'
for key, value in self._varm.items():
if isinstance(value, pl.DataFrame):
SingleCell._save_h5ad_dataframe(
hdf5_file, value, f'varm/{key}',
preserve_strings)
else:
varm.create_dataset(key, data=value)
# Save `obsp`
obsp = hdf5_file.create_group('obsp')
obsp.attrs['encoding-type'] = 'dict'
obsp.attrs['encoding-version'] = '0.1.0'
for key, value in self._obsp.items():
group = obsp.create_group(key)
group.attrs['encoding-type'] = 'csr_matrix' \
if isinstance(value, csr_array) else \
'csc_matrix'
group.attrs['encoding-version'] = '0.1.0'
group.attrs['shape'] = value.shape
group.create_dataset('data', data=value.data)
group.create_dataset('indices', data=value.indices)
group.create_dataset('indptr', data=value.indptr)
# Save `varp`
varp = hdf5_file.create_group('varp')
varp.attrs['encoding-type'] = 'dict'
varp.attrs['encoding-version'] = '0.1.0'
for key, value in self._varp.items():
group = varp.create_group(key)
group.attrs['encoding-type'] = 'csr_matrix' \
if isinstance(value, csr_array) else \
'csc_matrix'
group.attrs['encoding-version'] = '0.1.0'
group.attrs['shape'] = value.shape
group.create_dataset('data', data=value.data)
group.create_dataset('indices', data=value.indices)
group.create_dataset('indptr', data=value.indptr)
# Save `uns`
if self._uns:
SingleCell._save_uns(self._uns,
hdf5_file.create_group('uns'),
hdf5_file)
# Save `X`
X = hdf5_file.create_group('X')
X.attrs['encoding-type'] = 'csr_matrix' \
if isinstance(counts, csr_array) else 'csc_matrix'
X.attrs['encoding-version'] = '0.1.0'
X.attrs['shape'] = counts.shape
X.create_dataset('data', data=counts.data)
X.create_dataset('indices', data=counts.indices)
X.create_dataset('indptr', data=counts.indptr)
elif is_h5Seurat:
obs_names = self.obs_names
var_names = self.var_names
for names, names_name in (obs_names, 'obs_names'), \
(var_names, 'var_names'):
null_count = names.null_count()
if null_count:
error_message = (
f'{names_name} contains {null_count:,} '
f'null values, but must not contain any '
f'when saving to an .h5Seurat file')
raise ValueError(error_message)
# Add top-level metadata and required groups/datasets
hdf5_file.attrs['active.assay'] = \
np.array([assay], dtype=object)
hdf5_file.attrs['project'] = \
np.array(['SeuratProject'], dtype=object)
hdf5_file.attrs['version'] = \
np.array(['3.0.0'], dtype=object)
for required_group in 'commands', 'images', 'tools':
hdf5_file.create_group(required_group)
active_ident = hdf5_file.create_group('active.ident')
active_ident.create_dataset('levels', data=[b'local'])
active_ident.create_dataset(
'values', data=np.ones(len(self._obs),
dtype=np.int32))
hdf5_file.create_dataset('cell.names',
data=obs_names.to_numpy())
active_assay = \
hdf5_file.create_group(f'assays/{assay}')
active_assay.attrs['key'] = \
np.array([f'{assay.lower()}_'], dtype=object)
active_assay.create_dataset(
'features', data=var_names.to_numpy())
active_assay.create_group('misc')
# Save `obs` and `var`
SingleCell._save_h5Seurat_dataframe(
hdf5_file, self._obs, 'meta.data',
preserve_strings)
SingleCell._save_h5Seurat_dataframe(
hdf5_file, self._var,
f'assays/{assay}/meta.features', preserve_strings)
# Save `obsm`
reductions = hdf5_file.create_group('reductions')
if self._obsm:
for key, value in self._obsm.items():
if isinstance(value, pl.DataFrame):
continue
group = reductions.create_group(key)
group.attrs['active.assay'] = \
np.array([assay], dtype=object)
group.attrs['global'] = \
np.array([0], dtype=np.int32)
group.attrs['key'] = \
np.array([f'{key.upper()}_'], dtype=object)
group.create_dataset('cell.embeddings',
data=value.T)
# Save `obsp`
graphs = hdf5_file.create_group('graphs')
if self._obsp:
for key, value in self._obsp.items():
if isinstance(value, csc_array):
value = value.tocsr()
group = graphs.create_group(key)
group.attrs['assay.used'] = \
np.array([assay], dtype=object)
group.attrs['dims'] = value.shape[::-1]
group.create_dataset('data',
data=value.data)
group.create_dataset('indices',
data=value.indices)
group.create_dataset('indptr',
data=value.indptr)
# Save `uns`
misc = hdf5_file.create_group('misc')
if self._uns:
SingleCell._save_h5Seurat_uns(
self._uns, misc, hdf5_file)
# Save `X`
if isinstance(counts, csc_array):
counts = counts.tocsr()
X = active_assay.create_group(X_key)
X.attrs['dims'] = counts.shape[::-1]
X.create_dataset('data', data=counts.data)
X.create_dataset('indices', data=counts.indices)
X.create_dataset('indptr', data=counts.indptr)
else: # `.h5`
obs_columns = ['barcodes']
var_columns = ['feature_type', 'genome', 'id', 'name']
for columns, df, df_name in \
(obs_columns, self._obs, 'obs'), \
(var_columns, self._var, 'var'):
for column in columns:
if column not in df:
error_message = (
f'{column!r} was not found in '
f'{df_name}, but is a required column '
f'when saving to a 10x .h5 file')
raise ValueError(error_message)
check_dtype(df[column],
f'{df_name}[{column!r}]',
(pl.String, pl.Categorical,
pl.Enum))
all_tag_keys = ['genome']
for column in 'pattern', 'read', 'sequence':
if column in self._var:
check_dtype(self._var[column],
f'var[{column!r}]',
(pl.String, pl.Categorical,
pl.Enum))
var_columns.append(column)
all_tag_keys.append(column)
matrix = hdf5_file.create_group('matrix')
matrix.create_dataset('barcodes',
data=self._obs[:, 0].to_numpy())
matrix.create_dataset('data', data=counts.data)
features = matrix.create_group('features')
matrix.create_dataset('indices', data=counts.indices)
matrix.create_dataset('indptr', data=counts.indptr)
matrix.create_dataset('shape',
data=counts.shape[::-1])
features.create_dataset('_all_tag_keys',
data=all_tag_keys)
for column in var_columns:
features.create_dataset(
column, data=self._var[column].to_numpy())
except:
if os.path.exists(filename_expanduser):
os.unlink(filename_expanduser)
raise
elif is_mtx or is_mtx_gz:
barcode_filename = os.path.join(
os.path.dirname(filename_expanduser),
'barcodes.tsv' if is_mtx else 'barcodes.tsv.gz')
feature_filename = os.path.join(
os.path.dirname(filename_expanduser),
'features.tsv' if is_mtx else 'features.tsv.gz')
if not overwrite:
for ancillary_filename in barcode_filename, feature_filename:
if os.path.exists(ancillary_filename):
error_message = (
f'{ancillary_filename!r} already exists; set '
f'overwrite=True to overwrite')
raise FileExistsError(error_message)
from scipy.io import mmwrite
try:
mmwrite(filename_expanduser, counts.T)
if is_mtx:
self._obs.write_csv(barcode_filename, include_header=False)
self._var.write_csv(feature_filename, include_header=False)
else:
import gzip
with gzip.open(barcode_filename, 'wb') as f:
self._obs.write_csv(f, include_header=False)
with gzip.open(feature_filename, 'wb') as f:
self._var.write_csv(f, include_header=False)
except:
if os.path.exists(filename_expanduser):
os.unlink(filename_expanduser)
if os.path.exists(barcode_filename):
os.unlink(barcode_filename)
if os.path.exists(feature_filename):
os.unlink(feature_filename)
raise
else:
from ryp import r
if preserve_strings:
sc = self
else:
# Convert string columns with duplicate values to Enum
enumify = lambda df: df.cast({
row[0]: pl.Enum(row[1]) for row in df
.select(pl.selectors.string()
.unique(maintain_order=True)
.implode()
.list.drop_nulls())
.unpivot()
.filter(pl.col.value.list.len() < len(df))
.rows()})
sc = SingleCell(X=counts, obs=enumify(self._obs),
var=enumify(self._var), obsm=self._obsm,
uns=self._uns, num_threads=self._num_threads)
# Do not include `empty_X=empty_X` in `to_sce()` and `to_seurat()`,
# since we've already handled it above
if sce:
sc.to_sce('.SingleCell.object', assay=X_key)
else:
sc.to_seurat('.SingleCell.object', assay=assay, layer=X_key,
v3=v3)
try:
r(f'saveRDS(.SingleCell.object, {filename_expanduser!r})')
except:
if os.path.exists(filename_expanduser):
os.unlink(filename_expanduser)
raise
finally:
r('rm(.SingleCell.object)')
def _get_column(self,
                obs_or_var_name: Literal['obs', 'var'],
                column: SingleCellColumn,
                variable_name: str,
                dtypes: pl.datatypes.classes.DataTypeClass | str |
                        tuple[pl.datatypes.classes.DataTypeClass | str,
                              ...],
                *,
                QC_column: pl.Series | None = None,
                allow_missing: bool = False,
                allow_null: bool = False,
                custom_error: str | None = None) -> pl.Series | None:
    """
    Get a column of the same length as `obs`/`var`, or `None` if the column
    is missing from `obs`/`var` and `allow_missing=True`.

    Args:
        obs_or_var_name: the name of the DataFrame the column is with
                         respect to, i.e. `'obs'` or `'var'`
        column: a string naming a column of `obs`/`var`, a polars
                expression that evaluates to a single column when applied
                to `obs`/`var`, a polars Series or 1D NumPy array of the
                same length as `obs`/`var`, or a function that takes in
                `self` and returns a polars Series or 1D NumPy array of the
                same length as `obs`/`var`
        variable_name: the name of the variable corresponding to `column`
        dtypes: the required dtype(s) of the column
        QC_column: an optional column of cells passing QC. If specified,
                   the presence of `null` values will only raise an error
                   for cells passing QC. Has no effect when
                   `allow_null=True`.
        allow_missing: whether to allow `column` to be a string missing
                       from `obs`/`var`, returning `None` in this case
        allow_null: whether to allow `column` to contain `null` values
        custom_error: a custom error message for when `column` is a string
                      and is not found in `obs`/`var`, and
                      `allow_missing=False`; use `{}` as a placeholder for
                      the name of the column

    Returns:
        A polars Series of the same length as `obs`/`var`, or `None` if the
        column is missing from `obs`/`var` and `allow_missing=True`.

    Raises:
        ValueError: if `column` names a missing column (and
                    `allow_missing=False`), is an expression that does not
                    expand to exactly one column, is not 1D, has the wrong
                    length, or contains disallowed `null` values
        TypeError: if `column` (or the value returned by a callable
                   `column`) has an unsupported type
    """
    obs_or_var = self._obs if obs_or_var_name == 'obs' else self._var
    if isinstance(column, str):
        # Include the column's name in all subsequent error messages
        variable_name = f'{variable_name} {column!r}'
        if column in obs_or_var:
            column = obs_or_var[column]
        elif allow_missing:
            return None
        else:
            error_message = \
                f'{variable_name} is not a column of {obs_or_var_name}' \
                if custom_error is None else \
                custom_error.format(f'{column!r}')
            raise ValueError(error_message)
    elif isinstance(column, pl.Expr):
        column = obs_or_var.select(column)
        # Reject zero-column expansions too: `to_series()` would otherwise
        # fail with an unrelated error on a DataFrame without columns
        if column.width != 1:
            error_message = (
                f'{variable_name} is a polars expression that expands to '
                f'{column.width:,} columns rather than 1')
            raise ValueError(error_message)
        column = column.to_series()
        # Aggregating or literal expressions can change the number of rows
        if len(column) != len(obs_or_var):
            error_message = (
                f'{variable_name} is a polars expression that evaluates '
                f'to a column of length {len(column):,}, which differs '
                f'from the length of {obs_or_var_name} '
                f'({len(obs_or_var):,})')
            raise ValueError(error_message)
    elif isinstance(column, pl.Series):
        if len(column) != len(obs_or_var):
            error_message = (
                f'{variable_name} is a polars Series of length '
                f'{len(column):,}, which differs from the length of '
                f'{obs_or_var_name} ({len(obs_or_var):,})')
            raise ValueError(error_message)
    elif isinstance(column, np.ndarray):
        # Check the dimensionality first: `len()` raises on 0D arrays and
        # only measures the first axis of multidimensional arrays
        if column.ndim != 1:
            error_message = (
                f'{variable_name} is a {column.ndim:,}D NumPy array, but '
                f'must be a 1D NumPy array')
            raise ValueError(error_message)
        if len(column) != len(obs_or_var):
            error_message = (
                f'{variable_name} is a NumPy array of length '
                f'{len(column):,}, which differs from the length of '
                f'{obs_or_var_name} ({len(obs_or_var):,})')
            raise ValueError(error_message)
        column = pl.Series(variable_name, column)
    elif callable(column):
        column = column(self)
        if isinstance(column, np.ndarray):
            if column.ndim != 1:
                error_message = (
                    f'{variable_name} is a function that returns a '
                    f'{column.ndim:,}D NumPy array, but must return a '
                    f'polars Series or 1D NumPy array')
                raise ValueError(error_message)
            column = pl.Series(variable_name, column)
        elif not isinstance(column, pl.Series):
            error_message = (
                f'{variable_name} is a function that returns a variable '
                f'of type {type(column).__name__}, but must return a '
                f'polars Series or 1D NumPy array')
            raise TypeError(error_message)
        if len(column) != len(obs_or_var):
            error_message = (
                f'{variable_name} is a function that returns a column of '
                f'length {len(column):,}, which differs from the length '
                f'of {obs_or_var_name} ({len(obs_or_var):,})')
            raise ValueError(error_message)
    else:
        error_message = (
            f'{variable_name} must be a string column name, a polars '
            f'expression, a polars Series, a 1D NumPy array, or a '
            f'function that returns a polars Series or 1D NumPy array '
            f'when applied to this SingleCell dataset, but has type '
            f'{type(column).__name__!r}')
        raise TypeError(error_message)
    check_dtype(column, variable_name, dtypes)
    if not allow_null:
        if QC_column is None:
            null_count = column.null_count()
            if null_count > 0:
                error_message = (
                    f'{variable_name} contains {null_count:,} '
                    f'{plural("null value", null_count)}, but must not '
                    f'contain any')
                raise ValueError(error_message)
        else:
            # Only count nulls among the cells flagged in `QC_column`
            null_count = (column.is_null() & QC_column).sum()
            if null_count > 0:
                error_message = (
                    f'{variable_name} contains {null_count:,} '
                    f'{plural("null value", null_count)} for cells '
                    f'passing QC, but must not contain any')
                raise ValueError(error_message)
    return column
@staticmethod
def _get_columns(obs_or_var_name: Literal['obs', 'var'],
                 datasets: Sequence[SingleCell],
                 columns: SingleCellColumn | None |
                          Sequence[SingleCellColumn | None],
                 variable_name: str,
                 dtypes: pl.datatypes.classes.DataTypeClass | str |
                         tuple[pl.datatypes.classes.DataTypeClass | str,
                               ...],
                 *,
                 QC_columns: list[pl.Series | None] = None,
                 allow_None: bool = True,
                 allow_missing: bool = False,
                 allow_null: bool = False,
                 custom_error: str | None = None) -> \
        list[pl.Series | None]:
    """
    Resolve one `obs`/`var` column per dataset by delegating to each
    dataset's `_get_column()`.

    Args:
        obs_or_var_name: which DataFrame the columns refer to, i.e. `'obs'`
                         or `'var'`
        datasets: a sequence of SingleCell datasets
        columns: a single column specification shared by every dataset (a
                 string column name, a polars expression, a polars Series,
                 a 1D NumPy array, or a function of the dataset), or a
                 non-string Sequence with one specification per dataset.
                 May be `None`, or contain `None`, if `allow_None=True`.
        variable_name: the name of the variable corresponding to `columns`,
                       used in error messages
        dtypes: the required dtype(s) of the columns
        QC_columns: an optional per-dataset list of columns of cells
                    passing QC, each forwarded to `_get_column()` as
                    `QC_column`; entries may be `None`
        allow_None: whether `columns`, or elements of it, may be `None`
        allow_missing: forwarded to `_get_column()`: whether string column
                       names may be missing from `obs`/`var`
        allow_null: forwarded to `_get_column()`: whether the columns may
                    contain `null` values
        custom_error: forwarded to `_get_column()`: a custom error message
                      for string columns missing from `obs`/`var`, with
                      `{}` as a placeholder for the column's name

    Returns:
        A list with one entry per dataset: the result of `_get_column()`
        for that dataset, or `None` where the corresponding column
        specification is `None`.
    """
    def fetch(dataset, column, QC_column=None):
        # `None` specifications are passed through without a lookup
        if column is None:
            return None
        return dataset._get_column(
            obs_or_var_name=obs_or_var_name, column=column,
            variable_name=variable_name, dtypes=dtypes,
            QC_column=QC_column, allow_null=allow_null,
            allow_missing=allow_missing, custom_error=custom_error)

    # No specification at all: every dataset gets `None`
    if columns is None:
        if allow_None:
            return [None] * len(datasets)
        error_message = f'{variable_name} is None'
        raise TypeError(error_message)

    one_per_dataset = \
        isinstance(columns, Sequence) and not isinstance(columns, str)

    # A single specification shared by all the datasets
    if not one_per_dataset:
        if QC_columns is None:
            return [fetch(dataset, columns) for dataset in datasets]
        return [fetch(dataset, columns, QC_column)
                for dataset, QC_column in zip(datasets, QC_columns)]

    # One specification per dataset
    if len(columns) != len(datasets):
        error_message = (
            f'{variable_name} has length {len(columns):,}, but you '
            f'specified {len(datasets):,} datasets')
        raise ValueError(error_message)
    if not allow_None and any(column is None for column in columns):
        error_message = \
            f'{variable_name} contains an element that is None'
        raise TypeError(error_message)
    if QC_columns is None:
        return [fetch(dataset, column)
                for dataset, column in zip(datasets, columns)]
    return [fetch(dataset, column, QC_column)
            for dataset, column, QC_column in
            zip(datasets, columns, QC_columns)]
@staticmethod
def _describe_column(column_name: str, column: SingleCellColumn):
    """
    Build the description of a column argument for use in error messages.

    Args:
        column_name: the name of the column argument
        column: the value passed for the column argument

    Returns:
        `column_name` followed by the `repr()` of `column` when `column` is
        a string (i.e. the name of a column of `obs` or `var`); otherwise
        just `column_name`.
    """
    # Only string values are informative enough to echo back to the user
    if isinstance(column, str):
        return f'{column_name} {column!r}'
    return column_name
[docs]
def to_scanpy(self, *, QC_column: str | None = 'passed_QC') -> 'AnnData':
    """
    Converts this SingleCell dataset to an AnnData object, the
    representation used by Scanpy.

    Make sure to remove cells failing QC with `sc.filter_obs(QC_column)`
    first, or specify `subset=True` in `qc()`. Alternatively, to include
    cells failing QC in the AnnData object, set `QC_column` to `None`.

    Note that there is no `from_scanpy()`; simply do
    `SingleCell(anndata_object)` to initialize a SingleCell dataset from an
    in-memory AnnData object.

    Args:
        QC_column: if not `None`, raise an error if this column is present
                   in `obs` and not all cells pass QC

    Returns:
        An AnnData object. Counts are passed to AnnData as
        `csr_matrix`/`csc_matrix`, which are also accepted by AnnData
        versions older than 0.11.0 that do not support
        `csr_array`/`csc_array`.

    Raises:
        TypeError: if `obs` or `var` contain a column with a data type
                   other than float, int, String, Enum, Categorical or
                   Boolean
        ValueError: if `QC_column` is present in `obs` and not all cells
                    pass QC

    Note:
        The count matrix is not copied during the conversion to save
        memory. This means modifying the AnnData object's count matrix
        will also modify the original SingleCell dataset. If this behavior
        is undesirable, explicitly copy the dataset before converting with
        `sc.copy().to_scanpy()`.
    """
    # Ignore Ctrl-C while the imports run, then reinstate Python's default
    # SIGINT handler
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        from anndata import AnnData
        import pandas as pd
    finally:
        signal.signal(signal.SIGINT, signal.default_int_handler)
    valid_dtypes = pl.FLOAT_DTYPES | pl.INTEGER_DTYPES | \
        {pl.String, pl.Enum, pl.Categorical, pl.Boolean}
    for df, df_name in (self._obs, 'obs'), (self._var, 'var'):
        for column, dtype in df.schema.items():
            if dtype.base_type() not in valid_dtypes:
                error_message = (
                    f'{df_name}[{column!r}] has the data type '
                    f'{dtype.base_type()!r}, which is not supported by '
                    f'AnnData')
                raise TypeError(error_message)
    if QC_column is not None:
        check_type(QC_column, 'QC_column', str, 'a string')
        if QC_column in self._obs:
            QCed_cells = self._obs[QC_column]
            check_dtype(QCed_cells, f'obs[{QC_column!r}]',
                        pl.Boolean)
            # Null QC flags count as failing QC
            if QCed_cells.null_count() or not QCed_cells.all():
                error_message = (
                    f'not all cells pass QC; remove cells failing QC with '
                    f'filter_obs({QC_column!r}) or by specifying '
                    f'subset=True in qc(), or set QC_column=None to '
                    f'include them in the AnnData object')
                raise ValueError(error_message)
    # Map Arrow types to pandas' nullable extension dtypes when converting
    # `obs` and `var`
    type_mapping = {
        pa.int8(): pd.Int8Dtype(), pa.int16(): pd.Int16Dtype(),
        pa.int32(): pd.Int32Dtype(), pa.int64(): pd.Int64Dtype(),
        pa.uint8(): pd.UInt8Dtype(), pa.uint16(): pd.UInt16Dtype(),
        pa.uint32(): pd.UInt32Dtype(), pa.uint64(): pd.UInt64Dtype(),
        pa.string(): pd.StringDtype(storage='pyarrow')
        if int(pd.__version__.split('.')[0]) >= 2 else object,
        pa.bool_(): pd.BooleanDtype()}
    # The first column of `obs`/`var` becomes the pandas index
    to_pandas = lambda df: df\
        .to_pandas(split_blocks=True, types_mapper=type_mapping.get)\
        .set_index(df.columns[0])
    # Use the imported `AnnData` class directly: only `AnnData` (not the
    # `anndata` module) was brought into scope above
    return AnnData(
        X=None if self._X is None
        else sparse.csr_matrix(self._X) if isinstance(self._X, csr_array)
        else sparse.csc_matrix(self._X),
        obs=to_pandas(self._obs), var=to_pandas(self._var),
        obsm=dict(self._obsm), varm=dict(self._varm),
        obsp=dict(self._obsp), varp=dict(self._varp),
        uns=Uns._copy_uns(self._uns))
@staticmethod
def _from_seurat(seurat_object_name: str,
                 *,
                 assay: str | None,
                 layer: str,
                 layer_name: str) -> \
        tuple[csr_array | csc_array, pl.DataFrame, pl.DataFrame,
              dict[str, np.ndarray], dict[str, csc_array],
              UnsDict]:
    """
    Create a SingleCell dataset from an in-memory Seurat object loaded with
    the ryp Python-R bridge. Used by `__init__()` and `from_seurat()`.

    Args:
        seurat_object_name: the name of the Seurat object in the ryp R
                            workspace
        assay: the name of the assay within the Seurat object to load data
               from; if `None`, defaults to the Seurat object's
               `active.assay` attribute (usually `'RNA'`)
        layer: the layer within the active assay (or the assay specified by
               the `assay` argument, if not `None`) to use as `X`. Set to
               `'data'` to load the normalized counts, or `'scale.data'` to
               load the normalized and scaled counts, if available. If
               dense, will be automatically converted to a sparse array.
        layer_name: the name of the variable passed via the `layer`
                    argument

    Returns:
        A length-6 tuple of (`X`, `obs`, `var`, `obsm`, `obsp`, `uns`).

    Raises:
        ValueError: if `assay` does not exist in the Seurat object, or
                    `layer` does not exist in the assay
        TypeError: if the layer is neither a dgCMatrix nor a dense matrix
    """
    from ryp import r, to_py
    if assay is None:
        assay = to_py(f'{seurat_object_name}@active.assay')
    elif assay not in to_py(f'Assays({seurat_object_name})',
                            squeeze=False):
        error_message = (
            f'assay {assay!r} does not exist in '
            f'{seurat_object_name}@assays; specify a different assay than '
            f'{assay!r}')
        raise ValueError(error_message)
    assay_slot = f'{seurat_object_name}@assays${assay}'
    # If Seurat v5, merge layers if necessary, and use `$slot` instead of
    # `@slot` for `X` and `meta.data` instead of `meta.features` for `var`
    v5 = to_py(f'inherits({assay_slot}, "Assay5")')
    if v5:
        r(f'.SingleCell.layers = names({assay_slot}@layers)')
        try:
            layers = to_py('.SingleCell.layers')
            if layer not in layers:
                # The exact layer doesn't exist. Check for sharded versions
                # like `'counts.1'`, `'counts.2'`, etc.
                r(rf'.SingleCell.layers = grep('
                  rf'"^{layer}\\.[0-9]+$", .SingleCell.layers, '
                  rf'value=TRUE)')
                if to_py('length(.SingleCell.layers)') > 0:
                    r(f'{assay_slot} = JoinLayers('
                      f'{assay_slot}, layers=.SingleCell.layers)')
                else:
                    # Collapse shard names like `'counts.1'`, `'counts.2'`
                    # etc. into just `'counts'`
                    current_assay_layers = to_py(
                        rf'unique(sub("\\.[0-9]+$", "", '
                        rf'names({assay_slot}@layers)))')
                    error_message = (
                        f'{layer_name} {layer!r} does not exist in the '
                        f'current assay ({assay!r}); specify a different '
                        f'assay than {assay!r} or a different '
                        f'{layer_name} than {layer!r}. The layers in the '
                        f'current assay are: '
                        f'{", ".join(map(repr, current_assay_layers))}')
                    raise ValueError(error_message)
        finally:
            # Always remove the temporary variable from the R workspace
            r('rm(".SingleCell.layers")')
        # In Seurat v5, feature metadata (`meta.data`) is per assay,
        # while the features themselves (`features`) are per layer. To
        # reconcile them, we must get the layer-specific feature names and
        # use them to filter the assay-level metadata.
        var = to_py(f'''
            (function() {{
            target_features = {assay_slot}@features[[{layer!r}]]
            meta.data = {assay_slot}@meta.data
            if (nrow(meta.data) == nrow({assay_slot})) {{
            rownames(meta.data) = rownames({assay_slot})
            }}
            res <- meta.data[target_features, , drop=FALSE]
            rownames(res) = target_features
            res
            }})()''', index='gene')
        X_slot = f'{assay_slot}@layers${layer}'
    else:
        # unlike v5 objects, v3 objects indicate the absence of a layer
        # with a 0 x 0 matrix
        if not (to_py(f'"{layer}" %in% slotNames({assay_slot})') and
                to_py(f'prod(dim({assay_slot}@{layer}))') > 0):
            current_assay_slots = to_py(f"slotNames({assay_slot})")
            error_message = (
                f'{layer_name} {layer!r} does not exist in the current '
                f'assay ({assay!r}); specify a different assay than '
                f'{assay!r} or a different {layer_name} than {layer!r}. '
                f'The layers in the current assay are: '
                f'{", ".join(map(repr, current_assay_slots))}')
            raise ValueError(error_message)
        X_slot = f'{assay_slot}@{layer}'
        var = to_py(f'{assay_slot}@meta.features')
    X_classes = tuple(to_py(f'class({X_slot})', squeeze=False))
    if X_classes == ('dgCMatrix',):
        # Transpose the converted genes x cells matrix
        X = to_py(X_slot).T
    elif X_classes == ('matrix', 'array'):
        warning_message = (
            f"this Seurat object's {layer_name} {layer!r} is stored as a "
            f"dense matrix; auto-converting to a sparse csr_array")
        warnings.warn(warning_message)
        X = csr_array(to_py(X_slot, format='numpy').T)
    else:
        error_message = (
            f'{layer_name} {layer!r} exists in {assay_slot} but is not a '
            f'dgCMatrix (column-oriented sparse matrix) or matrix, '
            f'instead having ')
        if len(X_classes) == 0:
            error_message += 'no classes'
        elif len(X_classes) == 1:
            error_message += f'the class {X_classes[0]!r}'
        else:
            # Quote every class name, including the last one
            error_message += (
                f'the classes '
                f'{", ".join(f"{c!r}" for c in X_classes[:-1])} and '
                f'{X_classes[-1]!r}')
        error_message += (
            f'; specify a different assay than {assay!r} or a different '
            f'{layer_name} than {layer!r}')
        raise TypeError(error_message)
    obs_key = f'{seurat_object_name}@meta.data'
    # Name the index column 'cell' unless the metadata already has a
    # column with that name, in which case fall back to '_index'. The check
    # must be against the column names: `%in%` applied to the data frame
    # itself would compare against its columns' contents.
    has_cell_column = to_py(f'"cell" %in% colnames({obs_key})')
    obs = to_py(obs_key, index='_index' if has_cell_column else 'cell')
    if var is None:
        var = to_py(f'rownames({assay_slot}@{layer})').to_frame('gene')
    # Convert Categorical columns to Enums with the same categories
    obs = obs.cast({column.name: pl.Enum(column.cat.get_categories())
                    for column in obs.select(pl.col(pl.Categorical))})
    var = var.cast({column.name: pl.Enum(column.cat.get_categories())
                    for column in var.select(pl.col(pl.Categorical))})
    # Only load the reductions computed on the selected assay
    reduction_names = to_py(f'names({seurat_object_name}@reductions)')
    if reduction_names is not None:
        obsm = {reduction_name: to_py(f'{seurat_object_name}@reductions$'
                                      f'{reduction_name}@cell.embeddings',
                                      format='numpy')
                for reduction_name in reduction_names
                if not to_py(f'is.null({seurat_object_name}@reductions$'
                             f'{reduction_name})') and
                to_py(f'{seurat_object_name}@reductions${reduction_name}'
                      f'@assay.used') == assay}
    else:
        obsm = {}
    # Only load the graphs with no recorded assay, or computed on the
    # selected assay
    graph_names = to_py(f'names({seurat_object_name}@graphs)')
    if graph_names is not None:
        obsp = {graph_name: csc_array((
            to_py(f'{seurat_object_name}${graph_name}@x', format='numpy'),
            to_py(f'{seurat_object_name}${graph_name}@i', format='numpy'),
            to_py(f'{seurat_object_name}${graph_name}@p', format='numpy')),
            shape=to_py(f'{seurat_object_name}${graph_name}@Dim',
                        format='numpy'))
            for graph_name in graph_names
            if to_py(f'length({seurat_object_name}${graph_name}@'
                     f'assay.used)') == 0 or
            to_py(f'{seurat_object_name}${graph_name}@assay.used') == assay
        }
    else:
        obsp = {}
    uns = to_py(f'{seurat_object_name}@misc')
    if not uns:
        # uns may be an empty unnamed list in R, which ryp converts to a
        # Python list rather than a dictionary
        uns = {}
    return X, obs, var, obsm, obsp, uns
[docs]
@staticmethod
def from_seurat(seurat_object_name: str,
                *,
                assay: str | None = None,
                layer: str = 'counts',
                num_threads: int | np.integer = -1) -> SingleCell:
    """
    Build a SingleCell dataset from a Seurat object that already lives in
    the R workspace of the ryp Python-R bridge. To load a Seurat object
    from disk, use e.g. `SingleCell('filename.rds')`. Both version 3 and
    version 5 Seurat objects are supported.

    Args:
        seurat_object_name: the name of the Seurat object in the ryp R
                            workspace
        assay: the name of the assay within the Seurat object to load data
               from; if `None`, defaults to the Seurat object's
               `active.assay` attribute (usually `'RNA'`)
        layer: the layer within the active assay (or the assay specified by
               the `assay` argument, if not `None`) to use as `X`. Defaults
               to `'counts'`. Set to `'data'` to load the normalized
               counts, or `'scale.data'` to load the normalized and scaled
               counts, if available. If dense, will be automatically
               converted to a sparse array.
        num_threads: the default number of threads to use for all
                     subsequent operations on this SingleCell dataset. Also
                     sets the number of threads for this SingleCell
                     dataset's count matrix, if present. Does not affect
                     the number of threads used for data loading; this will
                     always be single-threaded for Seurat objects.

    Returns:
        The corresponding SingleCell dataset.

    Raises:
        TypeError: if the R object named by `seurat_object_name` is not a
                   Seurat object
    """
    from ryp import to_py
    # Validate the arguments
    check_type(seurat_object_name, 'seurat_object_name', str, 'a string')
    check_R_variable_name(seurat_object_name, 'seurat_object_name')
    if assay is not None:
        check_type(assay, 'assay', str, 'a string')
    check_type(layer, 'layer', str, 'a string')
    num_threads = SingleCell._process_num_threads_static(num_threads)
    # Make sure the R object really is a Seurat object; if not, report the
    # classes it has instead
    is_seurat = to_py(f'inherits({seurat_object_name}, "Seurat")')
    if not is_seurat:
        r_classes = to_py(f'class({seurat_object_name})', squeeze=False)
        num_classes = len(r_classes)
        if num_classes == 0:
            class_description = 'no classes'
        elif num_classes == 1:
            class_description = f'the class {r_classes[0]!r}'
        else:
            leading_classes = ', '.join(f'{c!r}' for c in r_classes[:-1])
            class_description = \
                f'the classes {leading_classes} and {r_classes[-1]!r}'
        error_message = (
            f'the R object named by seurat_object_name, '
            f'{seurat_object_name}, must be a Seurat object, but has '
            f'{class_description}')
        raise TypeError(error_message)
    # Extract the dataset's components and assemble the SingleCell dataset
    components = SingleCell._from_seurat(
        seurat_object_name, assay=assay, layer=layer, layer_name='layer')
    X, obs, var, obsm, obsp, uns = components
    return SingleCell(X=X, obs=obs, var=var, obsm=obsm, obsp=obsp, uns=uns,
                      num_threads=num_threads)
[docs]
def to_seurat(self,
              seurat_object_name: str,
              /,
              *,
              QC_column: str | None = 'passed_QC',
              assay: str = 'RNA',
              layer: str = 'counts',
              empty_X: bool = False,
              v3: bool = False) -> None:
    """
    Convert this SingleCell dataset to a Seurat object in the R workspace
    of the ryp Python-R bridge.

    Make sure to remove cells failing QC with `sc.filter_obs(QC_column)`
    first, or specify `subset=True` in `qc()`. Alternatively, to include
    cells failing QC in the Seurat object, set `QC_column` to `None`.

    When converting to Seurat, to match the requirements of Seurat objects,
    the `'X_'` prefix (often used by Scanpy) will be removed from each key
    of `obsm` where it is present (e.g. `'X_umap'` will become `'umap'`).
    Seurat will also add `'orig.ident'`, `'nCount_RNA'` and
    `'nFeature_RNA'` as cell-level metadata by default; you can disable
    the calculation of the latter two columns with:

    ```python
    from ryp import r
    r('options(Seurat.object.assay.calcn = FALSE)')
    ```

    `varm` and DataFrame keys of `obsm` will not be converted.

    Args:
        seurat_object_name: the name of the R variable to assign the Seurat
                            object to
        QC_column: if not `None`, raise an error if this column is present
                   in `obs` and not all cells pass QC
        assay: the name to use for the active assay
        layer: the name of the layer within the active assay to save `X`
               to; must be `'counts'` or `'data'`
        empty_X: if `True`, allow converting SingleCell datasets with
                 missing counts (i.e. where `self.X is None`), by using an
                 empty, dummy matrix with no non-zero entries as the Seurat
                 object's count matrix
        v3: if `True`, save to Seurat's version 3 format instead of the
            more current version 5 format. When `v3=True`, the Seurat
            object's assay is created with `CreateAssayObject()` instead of
            `CreateAssay5Object()`.

    Raises:
        ValueError: if `empty_X` is inconsistent with whether `X` is
                    present, `X` has unsorted indices or more than
                    INT32_MAX non-zero elements, some cells fail QC,
                    `layer` is not `'counts'` or `'data'`, `obs_names` or
                    `var_names` contain nulls, or `var_names` contain
                    underscores
        TypeError: if `X` is a `csc_array`, or a column of `obs` or `var`
                   or a non-DataFrame value of `obsm` or `obsp` has an
                   unsupported data type

    Note:
        Seurat objects do not support count matrices with more than
        2,147,483,647 (INT32_MAX) non-zero elements. If your SingleCell
        dataset is larger than this, it cannot be converted!

    Note:
        Seurat requires the counts to be float64 and the indices of the
        count matrix to be int32. To avoid copying data when converting,
        consider converting the counts to float64 with
        `sc.X = sc.X.astype(np.float64, copy=False)` and the indices with
        `sc.X.shrink_indices()` (an in-place operation).

    Note:
        In the special case where the counts are float64 or the indices of
        the count matrix are int32, the counts or indices will not be
        copied during the conversion to save memory. This means modifying
        the Seurat object's count matrix will also modify the original
        SingleCell dataset. If this behavior is undesirable, explicitly
        copy the dataset before converting with `sc.copy().to_seurat()`.
    """
    # Check that `empty_X` is `True` when `X` is absent, and vice versa.
    # If `empty_X` is `True`, make an empty R-compatible count matrix.
    check_type(empty_X, 'empty_X', bool, 'Boolean')
    if empty_X:
        if self._X is not None:
            error_message = \
                'X is not None, so empty_X=True cannot be specified'
            raise ValueError(error_message)
        # The dummy matrix must be cells x genes (one `indptr` entry per
        # cell, plus one) rather than 0 x 0; otherwise, R rejects
        # `var_names` and `obs_names` as its dimnames below
        num_cells = len(self._obs)
        num_genes = len(self._var)
        X = csr_array((np.array([], dtype=np.float64),
                       np.array([], dtype=np.int32),
                       np.zeros(num_cells + 1, dtype=np.int32)),
                      shape=(num_cells, num_genes))
    else:
        if self._X is None:
            error_message = (
                'X is None, so converting to Seurat requires creating an '
                'empty count matrix; specify empty_X=True to do this '
                'automatically when saving')
            raise ValueError(error_message)
        X = self._X
        # Check that `X` is CSR
        if isinstance(self._X, csc_array):
            error_message = (
                'X is a csc_array; run `sc.tocsr()` before converting to '
                'Seurat')
            raise TypeError(error_message)
        # Check that `X` has sorted indices
        if not X.has_sorted_indices:
            error_message = (
                'X does not have sorted indices; run '
                '`sc.X.sort_indices()` (an in-place operation) before '
                'converting to Seurat')
            raise ValueError(error_message)
        # Check that `X` does not have too many non-zero elements
        if X.nnz > 2_147_483_647:
            error_message = (
                f'X has {X.nnz:,} non-zero elements, more than '
                f'2,147,483,647 (INT32_MAX), the maximum supported in R')
            raise ValueError(error_message)

    from ryp import r, to_r
    r('suppressPackageStartupMessages(library(Seurat))')

    # Check the remaining arguments
    check_type(seurat_object_name, 'seurat_object_name', str, 'a string')
    check_R_variable_name(seurat_object_name, 'seurat_object_name')
    if QC_column is not None:
        check_type(QC_column, 'QC_column', str, 'a string')
        if QC_column in self._obs:
            QCed_cells = self._obs[QC_column]
            check_dtype(QCed_cells, f'obs[{QC_column!r}]',
                        pl.Boolean)
            if QCed_cells.null_count() or not QCed_cells.all():
                error_message = (
                    f'not all cells pass QC; remove cells failing QC with '
                    f'filter_obs({QC_column!r}) or by specifying '
                    f'subset=True in qc(), or set QC_column=None to '
                    f'include them in the Seurat object')
                raise ValueError(error_message)
    check_type(assay, 'assay', str, 'a string')
    check_type(layer, 'layer', str, 'a string')
    check_type(v3, 'v3', bool, 'Boolean')
    if layer != 'counts' and layer != 'data':
        error_message = (
            f"when converting to a Seurat object, layer must be 'counts' "
            f"or 'data', not {layer!r}")
        raise ValueError(error_message)

    # Check that every column of `obs` and `var` has a convertible dtype
    valid_dtypes = pl.FLOAT_DTYPES | pl.INTEGER_DTYPES | \
        {pl.String, pl.Enum, pl.Categorical, pl.Boolean}
    for df, df_name in (self._obs, 'obs'), (self._var, 'var'):
        for column, dtype in df.schema.items():
            if dtype.base_type() not in valid_dtypes:
                error_message = (
                    f'{df_name}[{column!r}] has the data type '
                    f'{dtype.base_type()!r}, which is not supported when '
                    f'converting to a Seurat object')
                raise TypeError(error_message)

    # Check that `obs_names` and `var_names` contain no nulls, and that
    # `var_names` contain no underscores
    obs_names = self.obs_names
    var_names = self.var_names
    for names, names_name in (obs_names, 'obs_names'), \
            (var_names, 'var_names'):
        null_count = names.null_count()
        if null_count:
            error_message = (
                f'{names_name} contains {null_count:,} null values, but '
                f'must not contain any when converting to a Seurat object')
            raise ValueError(error_message)
    is_string = var_names.dtype == pl.String
    num_with_underscores = var_names.str.contains('_').sum() \
        if is_string else \
        var_names.cat.get_categories().str.contains('_').sum()
    if num_with_underscores:
        var_names_expression = f'pl.col.{var_names.name}' \
            if var_names.name.isidentifier() else \
            f'pl.col({var_names.name!r})'
        error_message = (
            f"var_names contains {num_with_underscores:,}"
            f"{'' if is_string else ' unique'} gene "
            f"{plural('name', num_with_underscores)} with "
            f"underscores, which are not supported by Seurat; Seurat "
            f"recommends changing the underscores to dashes, which you "
            f"can do with .with_columns_var({var_names_expression}"
            f"{'' if is_string else '.cast(pl.String)'}"
            f".str.replace_all('_', '-'))")
        raise ValueError(error_message)

    # The R sparse matrix class Seurat uses (dgCMatrix) requires float64
    # data and int32 indices and indptr. Convert each of these three to the
    # correct type, if not already correct.
    X = X.astype(np.float64, copy=False)
    if X.indices.dtype == np.int64:
        X = X.shrink_indices(copy=True)

    # Derive each cell's identity class from the part of its name before
    # the first underscore (or from the name itself, for integer names)
    if obs_names.dtype in pl.INTEGER_DTYPES:
        idents = obs_names.cast(pl.String)  # dummy idents
    elif obs_names.dtype == pl.String:
        idents = obs_names\
            .str.splitn('_', 2)\
            .struct.field('field_0')\
            .cast(pl.Categorical)
    else:  # Categorical or Enum
        idents = obs_names.to_frame().join(
            obs_names.to_frame().unique().with_columns(
                idents=pl.first()
                .cast(pl.String)
                .str.splitn('_', 2)
                .struct.field('field_0').cast(pl.Categorical)),
            on=obs_names.name)['idents']

    # Create the Seurat object, with just counts (X) and meta.data (obs)
    try:
        to_r(obs_names, '.SingleCell.obs_names')
        to_r(var_names, '.SingleCell.var_names')
        to_r(X.T, '.SingleCell.X.T')
        r('rownames(.SingleCell.X.T) = .SingleCell.var_names')
        r('colnames(.SingleCell.X.T) = .SingleCell.obs_names')
        # NOTE(review): the identity column is written as 'orig.idents',
        # whereas the docstring refers to 'orig.ident' - confirm which
        # name is intended before renaming
        to_r(self._obs.drop(obs_names.name)
             .with_columns(idents.alias('orig.idents'), pl.all()),
             '.SingleCell.obs')
        r('rownames(.SingleCell.obs) = .SingleCell.obs_names')
        r('.SingleCell.active_ident = .SingleCell.obs$orig.idents')
        r('names(.SingleCell.active_ident) = .SingleCell.obs_names')
        r(f'.SingleCell.assays = list(CreateAssay{"" if v3 else "5"}'
          f'Object({layer}=.SingleCell.X.T))')
        r(f'names(.SingleCell.assays) = {assay!r}')
        r(f'Key(.SingleCell.assays[[{assay!r}]]) = "{assay.lower()}_"')
        r(f'''.SingleCell.seurat = new(
            Class = 'Seurat',
            assays = .SingleCell.assays,
            meta.data = .SingleCell.obs,
            active.assay = {assay!r},
            active.ident = .SingleCell.active_ident,
            project.name = 'SeuratProject',
            version = packageVersion(pkg = 'SeuratObject'),
            graphs = list(),
            neighbors = list(),
            reductions = list(),
            images = list(),
            misc = list(),
            commands = list(),
            tools = list())''')

        # Add var
        to_r(self._var.drop(var_names.name), '.SingleCell.var',
             rownames=var_names)
        r(f'.SingleCell.seurat@assays${assay}@'
          f'meta.{"features" if v3 else "data"} = .SingleCell.var')

        # Add obsm
        if self._obsm:
            for original_key, value in self._obsm.items():
                if isinstance(value, pl.DataFrame):
                    continue
                dtype = value.dtype
                if np.issubdtype(dtype, np.floating):
                    value = value.astype(np.float64, copy=False)
                # ryp converts uint32 to int32 zero-copy, so skip the cast
                elif np.issubdtype(dtype, np.integer) and \
                        dtype != np.uint32:
                    value = value.astype(np.int32, copy=False)
                else:
                    error_message = (
                        f'obsm[{original_key!r}] has data type {dtype}, '
                        f'but it must be integer or floating-point to be '
                        f'convertible to Seurat')
                    raise TypeError(error_message)
                # Remove the initial X_ from the reduction name and suffix
                # with `'_'` when creating the key, like
                # mojaveazure.github.io/seurat-disk/reference/Convert.html
                key = original_key.removeprefix('X_')
                to_r(value, '.SingleCell.value', colnames=[
                    f'{key}_{i}' for i in range(1, value.shape[1] + 1)])
                r('rownames(.SingleCell.value) = .SingleCell.obs_names')
                r(f'.SingleCell.seurat@reductions${key} = '
                  f'CreateDimReducObject(.SingleCell.value, '
                  f'key="{key}_", assay="{assay}")')

        # Add obsp
        if self._obsp:
            obsp = {}
            for key, value in self._obsp.items():
                dtype = value.dtype
                if np.issubdtype(dtype, np.floating):
                    value = value.astype(np.float64, copy=False)
                # ryp converts uint32 to int32 zero-copy, so skip the cast
                elif np.issubdtype(dtype, np.integer) and \
                        dtype != np.uint32:
                    value = value.astype(np.int32, copy=False)
                else:
                    error_message = (
                        f'obsp[{key!r}] has data type {dtype}, but it '
                        f'must be integer or floating-point to be '
                        f'convertible to Seurat')
                    raise TypeError(error_message)
                if value.indices.dtype == np.int64:
                    value = value.shrink_indices(copy=True)
                obsp[key] = value
            to_r(obsp, '.SingleCell.obsp', rownames=obs_names,
                 colnames=obs_names)
            r('''
            for (i in seq_along(.SingleCell.obsp)) {
                rownames(.SingleCell.obsp[[i]]) <- .SingleCell.obs_names
                colnames(.SingleCell.obsp[[i]]) <- .SingleCell.obs_names
            }''')
            r('.SingleCell.seurat@graphs = '
              'lapply(.SingleCell.obsp, as.Graph)')
            r('rm(.SingleCell.obsp)')

        # Add uns
        if self._uns:
            to_r(self._uns, '.SingleCell.uns')
            r('.SingleCell.seurat@misc = .SingleCell.uns')

        # Atomically assign the Seurat object to its final name at the end
        r(f'{seurat_object_name} = .SingleCell.seurat')
    finally:
        # Remove whichever temporary R variables were created, whether or
        # not the conversion succeeded
        r('rm(list = Filter(exists, c(".SingleCell.obs_names", '
          '".SingleCell.var_names", ".SingleCell.X.T", '
          '".SingleCell.obs", ".SingleCell.active_ident", '
          '".SingleCell.assays", ".SingleCell.seurat", '
          '".SingleCell.var", ".SingleCell.value", ".SingleCell.obsp", '
          '".SingleCell.uns")))')
@staticmethod
def _get_DFrame(dframe: str, *, index: str) -> pl.DataFrame:
    """
    Build a polars DataFrame from a DFrame living in the ryp R workspace.

    The DFrame's `@listData` slot supplies the columns and its `@rownames`
    slot supplies an extra leading column. Every element of `@listData`
    must convert to a polars Series; anything else is treated as nested
    data and rejected.

    Args:
        dframe: the name of the DFrame object in the ryp R workspace
        index: the name to give the column holding the rownames in the
               output DataFrame; if the DFrame already has a column with
               this name, the rownames column is called `'_index'`
               instead

    Returns:
        A polars DataFrame containing the rownames followed by the data in
        `dframe`.

    Raises:
        ValueError: if any element of `{dframe}@listData` is not converted
                    to a polars Series
    """
    from ryp import to_py
    columns = to_py(f'{dframe}@listData', index=False)
    for column in columns.values():
        if not isinstance(column, pl.Series):
            error_message = (
                f'{dframe} contains nested data; unnest before converting to '
                f'a SingleCell dataset')
            raise ValueError(error_message)
    # Avoid clashing with an existing column of the same name
    index_name = '_index' if index in columns else index
    rownames = to_py(f'{dframe}@rownames')
    return pl.DataFrame({index_name: rownames, **columns})
@staticmethod
def _from_sce(sce_object_name: str,
              *,
              assay: str,
              assay_name: str) -> \
        tuple[csr_array | csc_array, pl.DataFrame, pl.DataFrame,
              dict[str, np.ndarray], UnsDict]:
    """
    Create a SingleCell dataset from an in-memory SingleCellExperiment
    object loaded with the ryp Python-R bridge. Used by `__init__()` and
    `from_sce()`.

    Args:
        sce_object_name: the name of the SingleCellExperiment object in the
                         ryp R workspace
        assay: the element within `sce_object@assays@data` to use as `X`.
               Set to `'counts'` to load raw counts, or `'logcounts'` to
               load the normalized counts if available. If dense, will be
               automatically converted to a sparse array.
        assay_name: the name of the variable passed via the `assay`
                    argument; used in warning and error messages

    Returns:
        A length-5 tuple of (`X`, `obs`, `var`, `obsm`, `uns`).

    Raises:
        ValueError: if `assay` is not present in
                    `{sce_object_name}@assays@data`
        TypeError: if the assay is neither a dgCMatrix nor a dense matrix
    """
    from ryp import to_py
    X_slot = f'{sce_object_name}@assays@data${assay}'
    if not to_py(f'"{assay}" %in% names({sce_object_name}@assays@data)'):
        error_message = (
            f'{assay_name} {assay!r} does not exist in '
            f'{sce_object_name}@assays@data${assay}; specify a different '
            f'{assay_name} than {assay!r}')
        raise ValueError(error_message)

    # Load `X`, transposing from R's genes x cells to cells x genes
    X_classes = tuple(to_py(f'class({X_slot})', squeeze=False))
    if X_classes == ('dgCMatrix',):
        X = to_py(X_slot).T
    elif X_classes == ('matrix', 'array'):
        warning_message = (
            f"this SingleCellExperiment object's {assay_name} {assay!r} "
            f"is stored as a dense matrix; auto-converting to a sparse "
            f"csr_array")
        warnings.warn(warning_message)
        X = csr_array(to_py(X_slot, format='numpy').T)
    else:
        error_message = (
            f'{assay_name} {assay!r} exists in '
            f'{sce_object_name}@assays@data${assay} but is not a '
            f'dgCMatrix (column-oriented sparse matrix) or matrix, '
            f'instead having ')
        if len(X_classes) == 0:
            error_message += 'no classes'
        elif len(X_classes) == 1:
            error_message += f'the class {X_classes[0]!r}'
        else:
            error_message += (
                f'the classes '
                f'{", ".join(f"{c!r}" for c in X_classes[:-1])} and '
                f'{X_classes[-1]!r}')
        error_message += (
            f'; specify a different {assay_name} than {assay!r}')
        raise TypeError(error_message)

    # Load `obs` and `var`, converting Categorical columns to Enum
    obs = SingleCell._get_DFrame(f'colData({sce_object_name})',
                                 index='cell')
    var = SingleCell._get_DFrame(f'rowData({sce_object_name})',
                                 index='gene')
    obs = obs.cast({column.name: pl.Enum(column.cat.get_categories())
                    for column in obs.select(pl.col(pl.Categorical))})
    var = var.cast({column.name: pl.Enum(column.cat.get_categories())
                    for column in var.select(pl.col(pl.Categorical))})

    # Load `obsm` and `uns`
    obsm = to_py(f'reducedDims({sce_object_name})@listData',
                 format='numpy')
    uns = to_py(f'{sce_object_name}@metadata', format='numpy')
    # `@metadata` can be a non-empty unnamed list (though this is not
    # recommended), but `uns` requires names, so add dummy ones
    if isinstance(uns, list):
        uns = {f'element_{i}': uns_element
               for i, uns_element in enumerate(uns)}
    return X, obs, var, obsm, uns
[docs]
@staticmethod
def from_sce(sce_object_name: str,
             *,
             assay: str = 'counts',
             num_threads: int | np.integer = -1) -> SingleCell:
    """
    Build a SingleCell dataset out of a SingleCellExperiment object that
    is already present in the R workspace of the ryp Python-R bridge. To
    load a SingleCellExperiment object from disk instead, use e.g.
    `SingleCell('filename.rds')`.

    Args:
        sce_object_name: the name of the SingleCellExperiment object in the
                         ryp R workspace
        assay: the element within `{sce_object_name}@assays@data` to use as
               `X`. Defaults to `'counts'`. If available, set to
               `'logcounts'` to load the normalized counts. If dense, will
               be automatically converted to a sparse array.
        num_threads: the default number of threads to use for all
                     subsequent operations on this SingleCell dataset. Also
                     sets the number of threads for this SingleCell
                     dataset's count matrix, if present. Does not affect
                     the number of threads used for data loading; this will
                     always be single-threaded for SingleCellExperiment
                     objects.

    Returns:
        The corresponding SingleCell dataset.

    Raises:
        TypeError: if the R object named by `sce_object_name` does not
                   inherit from SingleCellExperiment
    """
    from ryp import r, to_py
    r('suppressPackageStartupMessages(library(SingleCellExperiment))')

    # Validate the arguments
    check_type(sce_object_name, 'sce_object_name', str, 'a string')
    check_R_variable_name(sce_object_name, 'sce_object_name')
    check_type(assay, 'assay', str, 'a string')
    num_threads = SingleCell._process_num_threads_static(num_threads)

    # Make sure the R object really is a SingleCellExperiment; if not,
    # report the classes it does have
    is_sce = to_py(
        f'inherits({sce_object_name}, "SingleCellExperiment")')
    if not is_sce:
        classes = to_py(f'class({sce_object_name})', squeeze=False)
        num_classes = len(classes)
        if num_classes == 0:
            class_description = 'no classes'
        elif num_classes == 1:
            class_description = f'the class {classes[0]!r}'
        else:
            leading_classes = ", ".join(f"{c!r}" for c in classes[:-1])
            class_description = \
                f'the classes {leading_classes} and {classes[-1]!r}'
        error_message = (
            f'the R object named by sce_object_name, {sce_object_name}, '
            f'must be a SingleCellExperiment object, but has '
            f'{class_description}')
        raise TypeError(error_message)

    # Convert each component, then assemble the dataset
    components = SingleCell._from_sce(sce_object_name, assay=assay,
                                      assay_name='assay')
    X, obs, var, obsm, uns = components
    return SingleCell(X=X, obs=obs, var=var, obsm=obsm, uns=uns,
                      num_threads=num_threads)
[docs]
def to_sce(self,
           sce_object_name: str,
           /,
           *,
           QC_column: str | None = 'passed_QC',
           assay: str = 'counts',
           empty_X: bool = False) -> None:
    """
    Convert this SingleCell dataset to a SingleCellExperiment object in the
    R workspace of the ryp Python-R bridge.

    Make sure to remove cells failing QC with `sc.filter_obs(QC_column)`
    first, or specify `subset=True` in `qc()`. Alternatively, to include
    cells failing QC in the SingleCellExperiment object, set `QC_column` to
    `None`.

    `varm`, `obsp`, `varp`, and DataFrame keys of `obsm` will not be
    converted.

    Args:
        sce_object_name: the name of the R variable to assign the
                         SingleCellExperiment object to
        QC_column: if not `None`, raise an error if this column is present
                   in `obs` and not all cells pass QC
        assay: the name of the slot within `{sce_object_name}@assays@data`
               to save `X` to
        empty_X: if `True`, allow converting SingleCell datasets with
                 missing counts (i.e. where `self.X is None`), by using an
                 empty, dummy matrix with no non-zero entries as the
                 SingleCellExperiment object's count matrix

    Raises:
        ValueError: if `empty_X` is inconsistent with whether `X` is
                    present, `X` has unsorted indices or more than
                    INT32_MAX non-zero elements, some cells fail QC, or
                    `obs_names` or `var_names` contain nulls
        TypeError: if `X` is a `csc_array`, or a column of `obs` or `var`
                   has an unsupported data type

    Note:
        SingleCellExperiment objects do not support count matrices with
        more than 2,147,483,647 (INT32_MAX) non-zero elements. If your
        SingleCell dataset is larger than this, it cannot be converted!

    Note:
        SingleCellExperiment requires the counts to be float64 and the
        indices of the count matrix to be int32. To avoid copying data when
        converting, consider converting the counts to float64 with
        `sc.X = sc.X.astype(np.float64, copy=False)` and the indices with
        `sc.X.shrink_indices()` (an in-place operation).

    Note:
        In the special case where the counts are float64 or the indices of
        the count matrix are int32, the counts or indices will not be
        copied during the conversion to save memory. This means modifying
        the SingleCellExperiment object's count matrix will also modify the
        original SingleCell dataset. If this behavior is undesirable,
        explicitly copy the dataset before converting with
        `sc.copy().to_sce()`.
    """
    # Check that `empty_X` is `True` when `X` is absent, and vice versa.
    # If `empty_X` is `True`, make an empty R-compatible count matrix.
    check_type(empty_X, 'empty_X', bool, 'Boolean')
    if empty_X:
        if self._X is not None:
            error_message = \
                'X is not None, so empty_X=True cannot be specified'
            raise ValueError(error_message)
        # The dummy matrix must be cells x genes (one `indptr` entry per
        # cell, plus one) rather than 0 x 0; otherwise, its dimensions do
        # not match those of `obs` and `var` when building the object in R
        num_cells = len(self._obs)
        num_genes = len(self._var)
        X = csr_array((np.array([], dtype=np.float64),
                       np.array([], dtype=np.int32),
                       np.zeros(num_cells + 1, dtype=np.int32)),
                      shape=(num_cells, num_genes))
    else:
        if self._X is None:
            error_message = (
                'X is None, so converting to SingleCellExperiment '
                'requires creating an empty count matrix; specify '
                'empty_X=True to do this automatically when saving')
            raise ValueError(error_message)
        X = self._X
        # Check that `X` is CSR
        if isinstance(self._X, csc_array):
            error_message = (
                'X is a csc_array; run `sc.tocsr()` before converting to '
                'SingleCellExperiment')
            raise TypeError(error_message)
        # Check that `X` has sorted indices
        if not X.has_sorted_indices:
            error_message = (
                'X does not have sorted indices; run '
                '`sc.X.sort_indices()` (an in-place operation) before '
                'converting to SingleCellExperiment')
            raise ValueError(error_message)
        # Check that `X` does not have too many non-zero elements
        if X.nnz > 2_147_483_647:
            error_message = (
                f'X has {X.nnz:,} non-zero elements, more than '
                f'2,147,483,647 (INT32_MAX), the maximum supported in R')
            raise ValueError(error_message)

    from ryp import r, to_r
    r('suppressPackageStartupMessages(library(SingleCellExperiment))')

    # Check the remaining arguments
    check_type(sce_object_name, 'sce_object_name', str, 'a string')
    check_R_variable_name(sce_object_name, 'sce_object_name')
    if QC_column is not None:
        check_type(QC_column, 'QC_column', str, 'a string')
        if QC_column in self._obs:
            QCed_cells = self._obs[QC_column]
            check_dtype(QCed_cells, f'obs[{QC_column!r}]',
                        pl.Boolean)
            if QCed_cells.null_count() or not QCed_cells.all():
                error_message = (
                    f'not all cells pass QC; remove cells failing QC with '
                    f'filter_obs({QC_column!r}) or by specifying '
                    f'subset=True in qc(), or set QC_column=None to '
                    f'include them in the SingleCellExperiment object')
                raise ValueError(error_message)
    check_type(assay, 'assay', str, 'a string')

    # Check that every column of `obs` and `var` has a convertible dtype
    valid_dtypes = pl.FLOAT_DTYPES | pl.INTEGER_DTYPES | \
        {pl.String, pl.Enum, pl.Categorical, pl.Boolean}
    for df, df_name in (self._obs, 'obs'), (self._var, 'var'):
        for column, dtype in df.schema.items():
            if dtype.base_type() not in valid_dtypes:
                error_message = (
                    f'{df_name}[{column!r}] has the data type '
                    f'{dtype.base_type()!r}, which is not supported when '
                    f'converting to a SingleCellExperiment object')
                raise TypeError(error_message)

    # The R sparse matrix class SingleCellExperiment uses (dgCMatrix)
    # requires float64 data and int32 indices and indptr. Convert each of
    # these three to the correct type, if not already correct.
    X = X.astype(np.float64, copy=False)
    if X.indices.dtype == np.int64:
        X = X.shrink_indices(copy=True)

    try:
        # Check that `obs_names` and `var_names` contain no nulls
        obs_names = self.obs_names
        var_names = self.var_names
        for names, names_name in (obs_names, 'obs_names'), \
                (var_names, 'var_names'):
            null_count = names.null_count()
            if null_count:
                error_message = (
                    f'{names_name} contains {null_count:,} null values, '
                    f'but must not contain any when converting to a '
                    f'SingleCellExperiment object')
                raise ValueError(error_message)

        # Create the SingleCellExperiment object from `X`, `obs` and `var`
        to_r(X.T, '.SingleCell.X.T')
        to_r(self._obs.drop(obs_names.name), '.SingleCell.obs',
             rownames=obs_names)
        to_r(self._var.drop(var_names.name), '.SingleCell.var',
             rownames=var_names)
        r('.SingleCell.obs_names = rownames(.SingleCell.obs)')
        r('.SingleCell.var_names = rownames(.SingleCell.var)')
        r('rownames(.SingleCell.X.T) = .SingleCell.var_names')
        r('colnames(.SingleCell.X.T) = .SingleCell.obs_names')
        r(f'.SingleCell.sce = SingleCellExperiment('
          f'assays = list({assay} = .SingleCell.X.T), '
          f'colData = S4Vectors::DataFrame('
          f'    .SingleCell.obs, check.names=FALSE), '
          f'rowData = S4Vectors::DataFrame('
          f'    .SingleCell.var, check.names=FALSE))')

        # Add obsm, skipping DataFrame values
        if self._obsm:
            for key, value in self._obsm.items():
                if isinstance(value, pl.DataFrame):
                    continue
                to_r(value, '.SingleCell.value', colnames=[
                    f'{key}_{i}' for i in range(1, value.shape[1] + 1)])
                r('rownames(.SingleCell.value) = .SingleCell.obs_names')
                r(f'reducedDim(.SingleCell.sce, {key!r}) = '
                  f'.SingleCell.value')

        # Add uns
        if self._uns:
            to_r(self._uns, '.SingleCell.uns')
            r('.SingleCell.sce@metadata = .SingleCell.uns')

        # Assign the finished object to its final name at the end
        r(f'{sce_object_name} = .SingleCell.sce')
    finally:
        # Remove whichever temporary R variables were created, whether or
        # not the conversion succeeded
        r('rm(list = Filter(exists, c(".SingleCell.obs_names", '
          '".SingleCell.var_names", ".SingleCell.X.T", '
          '".SingleCell.obs", ".SingleCell.var", ".SingleCell.sce", '
          '".SingleCell.value", ".SingleCell.uns")))')
[docs]
def copy(self,
         *,
         deep: bool = False,
         num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Make a copy of this SingleCell dataset.

    Args:
        deep: whether to perform a deep or shallow copy. Since polars
              DataFrames are immutable, `obs` and `var` will always point
              to the same underlying data as the original. The difference
              when `deep=True` is that `X` and any NumPy arrays in `obsm`,
              `varm`, `obsp`, `varp`, and `uns` will point to fresh copies
              of the underlying data, instead of the same data as the
              original SingleCell dataset. When `deep=False`, any
              modifications to these NumPy arrays will modify both copies!
        num_threads: the number of threads to use when making a deep copy
                     of `X`. Set `num_threads=-1` to use all available
                     cores, as determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores.
                     Does not affect the copied SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     original dataset's `num_threads`. Can only be
                     specified when `deep=True` and `X` is not `None`.

    Returns:
        A copy of the SingleCell dataset.

    Raises:
        ValueError: if `num_threads` is specified when `deep=False` or
                    when `X` is `None`
    """
    check_type(deep, 'deep', bool, 'Boolean')
    if deep:
        if self._X is not None:
            # Only resolve `num_threads` once we know `X` is present, so
            # that the "was it specified?" check in the other branch sees
            # the caller's original value rather than a resolved default
            num_threads = self._process_num_threads(num_threads)
            # Temporarily override the count matrix's thread count so the
            # copy uses `num_threads`, then restore it
            original_num_threads = self._X._num_threads
            try:
                self._X._num_threads = num_threads
                X = self._X.copy()
            finally:
                self._X._num_threads = original_num_threads
        else:
            if num_threads is not None:
                error_message = \
                    'num_threads can only be specified when X is not None'
                raise ValueError(error_message)
            X = None
        # polars DataFrames are shared; everything else is copied
        obsm = {key: value if isinstance(value, pl.DataFrame) else
                value.copy() for key, value in self._obsm.items()}
        varm = {key: value if isinstance(value, pl.DataFrame) else
                value.copy() for key, value in self._varm.items()}
        obsp = {key: value.copy() for key, value in self._obsp.items()}
        varp = {key: value.copy() for key, value in self._varp.items()}
        uns = self._uns.copy(deep=True)
    else:
        if num_threads is not None:
            error_message = \
                'num_threads can only be specified when deep=True'
            raise ValueError(error_message)
        X = self._X
        obsm = self._obsm.copy()
        varm = self._varm.copy()
        obsp = self._obsp.copy()
        varp = self._varp.copy()
        uns = self._uns.copy()
    return SingleCell(X=X, obs=self._obs, var=self._var, obsm=obsm,
                      varm=varm, obsp=obsp, varp=varp, uns=uns,
                      num_threads=self._num_threads)
[docs]
def concat_obs(self,
datasets: SingleCell | Iterable[SingleCell],
/,
*more_datasets: SingleCell,
dataset_column: str | None = None,
dataset_labels: Iterable[str] | None = None,
flexible: bool = False,
num_threads: int | np.integer | None = None) -> SingleCell:
"""
Concatenate one or more other SingleCell datasets with this one,
cell-wise. All datasets must have distinct `obs_names`.
By default, all datasets must have the same `var`, `varm`, `varp`, and
`uns`. They must also have the same columns in `obs` and the same keys
in `obsm`, with the same data types. `obsp` will be discarded during
the concatenation.
Conversely, if `flexible=True`, subset to genes present in all datasets
(according to the first column of `var`, i.e. the `var_names`) before
concatenating. Subset to columns of `var` and keys of `varm`, `varp`,
and `uns` that are identical in all datasets after this subsetting.
Also, subset to columns of `obs` and keys of `obsm` that are present in
all datasets, and have the same data types. All datasets' `obs_names`
must have the same name and data type, and similarly for their
`var_names`.
The one exception to the `obs` "same data type" rule: if a column is
Enum in some datasets and Categorical in others, or Enum in all
datasets but with different categories in each dataset, that column
will be retained as an Enum column (with the union of the categories)
in the concatenated `obs`.
If the datasets' `X` are a mix of CSR and CSC sparse arrays, they will
all be coerced to CSR.
Args:
datasets: one or more SingleCell datasets to concatenate with this
one
*more_datasets: additional SingleCell datasets to concatenate with
this one, specified as positional arguments
dataset_column: the name of an Enum column to be added to the
concatenated dataset's `obs` labeling which dataset
each cell came from. The labels themselves are
determined by the `dataset_labels` argument.
dataset_labels: a sequence of labels for each dataset, used to
populate `dataset_column`. There must be one label
per dataset being concatenated. If `dataset_labels`
is not specified, the labels default to
`{dataset_column}_0`, `{dataset_column}_1`, ...,
`{dataset_column}_{N - 1}`. Can only be specified
when `dataset_column` is not `None`.
flexible: whether to subset to genes, columns of `obs` and `var`,
and keys of `obsm`, `varm` and `uns` common to all
datasets before concatenating, rather than raising an
error on any mismatches
num_threads: the number of threads to use when concatenating. Does
not affect the concatenated SingleCell dataset's
`num_threads`; this will always be the same as the
first dataset's `num_threads`.
Returns:
The concatenated SingleCell dataset.
"""
# Check inputs
datasets = (self,) + to_tuple(datasets) + more_datasets
if len(datasets) == 1:
error_message = \
'need at least one other SingleCell dataset to concatenate'
raise ValueError(error_message)
check_types(datasets[1:], 'datasets', SingleCell,
'SingleCell datasets')
if self._X is not None:
if all(dataset._X is not None for dataset in datasets):
X_present = True
else:
error_message = (
'some datasets being concatenated have X missing, while '
'others do not')
raise ValueError(error_message)
else:
if all(dataset._X is None for dataset in datasets):
X_present = False
else:
error_message = (
'some datasets being concatenated have X missing, while '
'others do not')
raise ValueError(error_message)
if dataset_column is not None:
check_type(dataset_column, 'dataset_column', str, 'a string')
if any(dataset_column in dataset._obs for dataset in datasets):
error_message = (
f"dataset_column {dataset_column!r} is already a column "
f"of at least one dataset's obs; specify a different name "
f"for dataset_column")
raise ValueError(error_message)
if dataset_labels is not None:
dataset_labels = to_tuple_checked(
dataset_labels, 'dataset_labels', str, 'strings')
if len(dataset_labels) != len(datasets):
error_message = (
f'dataset_labels has length {len(dataset_labels):,}, '
f'but there are {len(datasets):,} datasets being '
f'concatenated')
raise ValueError(error_message)
else:
dataset_labels = (f'dataset_{i}' for i in range(len(datasets)))
elif dataset_labels is not None:
error_message = (
'when dataset_labels is specified, dataset_column must also '
'be specified')
raise ValueError(error_message)
check_type(flexible, 'flexible', bool, 'Boolean')
num_threads = self._process_num_threads(num_threads)
# Perform either flexible or non-flexible concatenation
if flexible:
# Check that `obs_names` and `var_names` have the same name and
# data type across all datasets
obs_names_name = self.obs_names.name
if not all(dataset.obs_names.name == obs_names_name
for dataset in datasets[1:]):
error_message = (
'not all SingleCell datasets have the same name for the '
'first column of obs (the obs_names column)')
raise ValueError(error_message)
var_names_name = self.var_names.name
if not all(dataset.var_names.name == var_names_name
for dataset in datasets[1:]):
error_message = (
'not all SingleCell datasets have the same name for the '
'first column of var (the var_names column)')
raise ValueError(error_message)
obs_names_dtype = self.obs_names.dtype
if not all(dataset.obs_names.dtype == obs_names_dtype
for dataset in datasets[1:]):
error_message = (
'not all SingleCell datasets have the same data type for '
'the first column of obs (the obs_names column)')
raise TypeError(error_message)
var_names_dtype = self.var_names.dtype
if not all(dataset.var_names.dtype == var_names_dtype
for dataset in datasets[1:]):
error_message = (
'not all SingleCell datasets have the same data type for '
'the first column of var (the var_names column)')
raise TypeError(error_message)
# Subset to genes in common across all datasets
genes_in_common = self.var_names\
.to_frame()\
.filter(pl.all_horizontal(
self.var_names.is_in(dataset.var_names)
for dataset in datasets))\
.to_series()
if len(genes_in_common) == 0:
error_message = \
'no genes are shared across all SingleCell datasets'
raise ValueError(error_message)
datasets = [dataset
if len(genes_in_common) == len(dataset.var_names) and
dataset.var_names.equals(genes_in_common) else
dataset[:, genes_in_common] for dataset in datasets]
# Subset to columns of `var` and keys of `varm`, `varp`, and `uns`
# that are identical in all datasets after this subsetting
var_columns_in_common = [
column.name for column in datasets[0]._var[:, 1:]
if all(column.name in dataset._var and
dataset._var[column.name].equals(column)
for dataset in datasets[1:])]
varm_keys_in_common = [
key for key, value in self._varm.items()
if all(key in dataset._varm and
type(value) is type(dataset._varm[key]) and
(dataset._varm[key].dtype == value.dtype and
array_equal(dataset._varm[key], value)
if isinstance(value, np.ndarray) else
dataset._varm[key].equals(value))
for dataset in datasets[1:])]
varp_keys_in_common = [
key for key, value in self._varp.items()
if all(key in dataset._varp and
dataset._varp[key].dtype == value.dtype and
sparse_equal(dataset._varp[key], value)
for dataset in datasets[1:])]
uns_keys_in_common = [
key for key, value in self._uns.items()
if isinstance(value, dict) and
all(isinstance(dataset._uns[key], dict) and
SingleCell._eq_uns(value, dataset._uns[key],
different_order_ok=True)
for dataset in datasets[1:]) or
isinstance(value, np.ndarray) and
all(isinstance(dataset._uns[key], np.ndarray) and
array_equal(dataset._uns[key], value)
for dataset in datasets[1:]) or
all(not isinstance(dataset._uns[key], (dict, np.ndarray))
and dataset._uns[key] == value
for dataset in datasets[1:])]
for dataset in datasets:
dataset._var = dataset._var.select(dataset.var_names,
*var_columns_in_common)
dataset._varm = {key: dataset._varm[key]
for key in varm_keys_in_common}
dataset._varp = {key: dataset._varp[key]
for key in varp_keys_in_common}
dataset._uns = {key: dataset._uns[key]
for key in uns_keys_in_common}
# Subset to columns of `obs` and keys of `obsm` that are present in
# all datasets, and have the same data types. Also include columns
# of `obs` that are Enum in some datasets and Categorical in
# others, or Enum in all datasets but with different categories in
# each dataset; cast these to Enum.
obs_mismatched_categoricals = {
column for column, dtype in self._obs[:, 1:]
.select(pl.col(pl.Categorical, pl.Enum)).schema.items()
if all(column in dataset._obs and
dataset._obs[column].dtype in (pl.Categorical, pl.Enum)
for dataset in datasets[1:]) and
not all(dataset._obs[column].dtype == dtype
for dataset in datasets[1:])}
obs_columns_in_common = [
column
for column, dtype in islice(self._obs.schema.items(), 1, None)
if column in obs_mismatched_categoricals or
all(column in dataset._obs and
dataset._obs[column].dtype == dtype
for dataset in datasets[1:])]
cast_dict = {column: pl.Enum(
pl.concat([dataset._obs[column].cat.get_categories()
for dataset in datasets])
.unique(maintain_order=True))
for column in obs_mismatched_categoricals}
for dataset in datasets:
# the `.with_columns(...)` is a faster `.cast(cast_dict)`
dataset._obs = dataset._obs\
.select(dataset.obs_names, *obs_columns_in_common)\
.with_columns(cast_to_Enum(dataset._obs[column], enum_type)
.alias(column)
for column, enum_type in cast_dict.items())
obsm_keys_in_common = [
key for key, value in self._obsm.items()
if all(key in dataset._obsm and
type(dataset._obsm[key]) is type(value) and
(dataset._obsm[key].dtype == value.dtype
if isinstance(value, np.ndarray) else
dataset._obsm[key].schema == value.schema)
for dataset in datasets[1:])]
for dataset in datasets:
dataset._obsm = {key: dataset._obsm[key]
for key in obsm_keys_in_common}
else: # non-flexible
# Check that all `var`, `varm`, `varp`, and `uns` are identical
var = self._var
for dataset in datasets[1:]:
if not dataset._var.equals(var):
error_message = (
'all SingleCell datasets must have the same var, '
'unless flexible=True')
raise ValueError(error_message)
varm = self._varm
for dataset in datasets[1:]:
if dataset._varm.keys() != varm.keys() or \
any(type(dataset._varm[key]) is not type(value) or
(dataset._varm[key].dtype != value.dtype
if isinstance(value, np.ndarray) else
dataset._varm[key].schema != value.schema)
for key, value in varm.items()) or not \
all(array_equal(dataset._varm[key], value)
for key, value in varm.items()):
error_message = (
'all SingleCell datasets must have the same varm, '
'unless flexible=True')
raise ValueError(error_message)
varp = self._varp
for dataset in datasets[1:]:
if dataset._varp.keys() != varp.keys() or \
any(type(dataset._varp[key]) is not type(value) or
dataset._varp[key].dtype != value.dtype
for key, value in varp.items()) or not \
all(sparse_equal(dataset._varp[key], value)
for key, value in varp.items()):
error_message = (
'all SingleCell datasets must have the same varp, '
'unless flexible=True')
raise ValueError(error_message)
for dataset in datasets[1:]:
if not SingleCell._eq_uns(self._uns, dataset._uns):
error_message = (
'all SingleCell datasets must have the same uns, '
'unless flexible=True')
raise ValueError(error_message)
# Check that all `obs` have the same columns and data types
schema = self._obs.schema
for dataset in datasets[1:]:
if dataset._obs.schema != schema:
if dataset._obs.columns != self._obs.columns:
error_message = (
'all SingleCell datasets must have the same '
'columns in obs, unless flexible=True')
raise ValueError(error_message)
else:
error_message = (
'all SingleCell datasets must have the same data '
'type for each column of obs, unless '
'flexible=True')
raise TypeError(error_message)
# Check that all `obsm` have the same keys and data types
obsm = self._obsm
for dataset in datasets[1:]:
if dataset._obsm.keys() != obsm.keys():
error_message = (
'all SingleCell datasets must have the same keys in '
'obsm, unless flexible=True')
raise ValueError(error_message)
if any(type(dataset._obsm[key]) is not type(value) or
(dataset._obsm[key].dtype != value.dtype
if isinstance(value, np.ndarray) else
dataset._obsm[key].schema != value.schema)
for key, value in obsm.items()):
error_message = (
'all SingleCell datasets must have the same data '
'type for each key in obsm, unless flexible=True')
raise TypeError(error_message)
# If `dataset_column` is not `None`, add labels for each dataset
if dataset_column is not None:
for dataset, label in zip(datasets, dataset_labels):
dataset._obs = dataset._obs\
.with_columns(pl.lit(label).alias(dataset_column))
# Concatenate; output should be CSR when there's a mix of inputs
obs = pl.concat([dataset._obs for dataset in datasets])
num_unique = obs[:, 0].n_unique()
if num_unique < len(obs):
error_message = (
f'obs_names contains {len(obs) - num_unique:,} duplicates '
f'after concatenation')
raise ValueError(error_message)
if X_present:
if all(isinstance(dataset._X, csr_array) for dataset in datasets):
X = sparse_major_stack([dataset._X for dataset in datasets],
num_threads=num_threads)
elif all(isinstance(dataset._X, csc_array)
for dataset in datasets):
X = sparse_minor_stack([dataset._X for dataset in datasets],
num_threads=num_threads)
else:
X = sparse_major_stack([dataset._X.tocsr()
if isinstance(dataset._X, csc_array)
else dataset._X
for dataset in datasets],
num_threads=num_threads)
else:
X = None
obsm = {key: concatenate([dataset._obsm[key] for dataset in datasets],
num_threads=num_threads)
if isinstance(value, np.ndarray) else
pl.concat([dataset._obsm[key] for dataset in datasets])
for key, value in datasets[0]._obsm.items()}
return SingleCell(X=X, obs=obs, var=datasets[0]._var, obsm=obsm,
varm=datasets[0]._varm, varp=datasets[0]._varp,
uns=datasets[0]._uns,
num_threads=datasets[0]._num_threads)
[docs]
def concat_var(self,
               datasets: SingleCell | Iterable[SingleCell],
               /,
               *more_datasets: SingleCell,
               dataset_column: str | None = None,
               dataset_labels: Iterable[str] | None = None,
               flexible: bool = False,
               num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Concatenate one or more other SingleCell datasets with this one,
    gene-wise. This is much less common than the cell-wise concatenation
    provided by `concat_obs()`. All datasets must have distinct
    `var_names`. None of the input datasets are modified.

    By default, all datasets must have the same `obs`, `obsm`, `obsp`, and
    `uns`. They must also have the same columns in `var` and the same keys
    in `varm`, with the same data types. `varp` will be discarded during
    the concatenation.

    Conversely, if `flexible=True`, subset to cells present in all datasets
    (according to the first column of `obs`, i.e. the `obs_names`) before
    concatenating. Subset to columns of `obs` and keys of `obsm`, `obsp`,
    and `uns` that are identical in all datasets after this subsetting.
    Also, subset to columns of `var` and keys of `varm` that are present in
    all datasets, and have the same data types. All datasets' `obs_names`
    must have the same name and data type, and similarly for their
    `var_names`.

    The one exception to the `var` "same data type" rule: if a column is
    Enum in some datasets and Categorical in others, or Enum in all
    datasets but with different categories in each dataset, that column
    will be retained as an Enum column (with the union of the categories)
    in the concatenated `var`.

    If the datasets' `X` are a mix of CSR and CSC sparse arrays, they will
    all be coerced to CSR.

    Args:
        datasets: one or more SingleCell datasets to concatenate with this
                  one
        *more_datasets: additional SingleCell datasets to concatenate with
                        this one, specified as positional arguments
        dataset_column: the name of a column to be added to the
                        concatenated dataset's `var` labeling which dataset
                        each gene came from. Must not already be a column
                        of any dataset's `var`. The labels themselves are
                        determined by the `dataset_labels` argument.
        dataset_labels: a sequence of labels for each dataset, used to
                        populate `dataset_column`. There must be one label
                        per dataset being concatenated (including this
                        one). If `dataset_labels` is not specified, the
                        labels default to `dataset_0`, `dataset_1`, ...,
                        `dataset_{N - 1}`. Can only be specified when
                        `dataset_column` is not `None`.
        flexible: whether to subset to cells, columns of `obs` and `var`,
                  and keys of `obsm`, `obsp`, `varm` and `uns` common to
                  all datasets before concatenating, rather than raising
                  an error on any mismatches
        num_threads: the number of threads to use when concatenating. Does
                     not affect the concatenated SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     first dataset's `num_threads`.

    Returns:
        The concatenated SingleCell dataset.

    Raises:
        ValueError: if no other datasets are given, if only some datasets
                    have `X`, if `dataset_column`/`dataset_labels` are
                    inconsistent, if the datasets mismatch in a way not
                    allowed by the current `flexible` setting, or if the
                    concatenated `var_names` contain duplicates.
        TypeError: if data types mismatch in a way not allowed by the
                   current `flexible` setting.
    """
    # Check inputs
    datasets = (self,) + to_tuple(datasets) + more_datasets
    if len(datasets) == 1:
        error_message = \
            'need at least one other SingleCell dataset to concatenate'
        raise ValueError(error_message)
    check_types(datasets[1:], 'datasets', SingleCell,
                'SingleCell datasets')
    # Either every dataset has `X`, or none does
    X_present = self._X is not None
    if any((dataset._X is not None) != X_present
           for dataset in datasets[1:]):
        error_message = (
            'some datasets being concatenated have X missing, while '
            'others do not')
        raise ValueError(error_message)
    if dataset_column is not None:
        check_type(dataset_column, 'dataset_column', str, 'a string')
        # `dataset_column` is added to `var`, so that is where a name
        # clash has to be detected
        if any(dataset_column in dataset._var for dataset in datasets):
            error_message = (
                f"dataset_column {dataset_column!r} is already a column "
                f"of at least one dataset's var; specify a different name "
                f"for dataset_column")
            raise ValueError(error_message)
        if dataset_labels is not None:
            dataset_labels = to_tuple_checked(
                dataset_labels, 'dataset_labels', str, 'strings')
            if len(dataset_labels) != len(datasets):
                error_message = (
                    f'dataset_labels has length {len(dataset_labels):,}, '
                    f'but there are {len(datasets):,} datasets being '
                    f'concatenated')
                raise ValueError(error_message)
        else:
            dataset_labels = tuple(f'dataset_{i}'
                                   for i in range(len(datasets)))
    elif dataset_labels is not None:
        error_message = (
            'when dataset_labels is specified, dataset_column must also '
            'be specified')
        raise ValueError(error_message)
    check_type(flexible, 'flexible', bool, 'Boolean')
    num_threads = self._process_num_threads(num_threads)
    # Perform either flexible or non-flexible concatenation. Neither branch
    # assigns to the input datasets' attributes: the pieces of the output
    # are collected in the locals `obs`, `obsm`, `obsp`, `uns`,
    # `var_frames` and `varm_dicts` instead, so callers' datasets are left
    # untouched.
    if flexible:
        # Check that `obs_names` and `var_names` have the same name and
        # data type across all datasets
        obs_names_name = self.obs_names.name
        if not all(dataset.obs_names.name == obs_names_name
                   for dataset in datasets[1:]):
            error_message = (
                'not all SingleCell datasets have the same name for the '
                'first column of obs (the obs_names column)')
            raise ValueError(error_message)
        var_names_name = self.var_names.name
        if not all(dataset.var_names.name == var_names_name
                   for dataset in datasets[1:]):
            error_message = (
                'not all SingleCell datasets have the same name for the '
                'first column of var (the var_names column)')
            raise ValueError(error_message)
        obs_names_dtype = self.obs_names.dtype
        if not all(dataset.obs_names.dtype == obs_names_dtype
                   for dataset in datasets[1:]):
            error_message = (
                'not all SingleCell datasets have the same data type for '
                'the first column of obs (the obs_names column)')
            raise TypeError(error_message)
        var_names_dtype = self.var_names.dtype
        if not all(dataset.var_names.dtype == var_names_dtype
                   for dataset in datasets[1:]):
            error_message = (
                'not all SingleCell datasets have the same data type for '
                'the first column of var (the var_names column)')
            raise TypeError(error_message)
        # Subset to cells in common across all datasets
        cells_in_common = self.obs_names\
            .to_frame()\
            .filter(pl.all_horizontal(
                self.obs_names.is_in(dataset.obs_names)
                for dataset in datasets))\
            .to_series()
        if len(cells_in_common) == 0:
            error_message = \
                'no cells are shared across all SingleCell datasets'
            raise ValueError(error_message)
        # Datasets that already have exactly these cells, in this order,
        # are reused as-is (and must therefore not be mutated below)
        datasets = [dataset
                    if len(cells_in_common) == len(dataset.obs_names) and
                    dataset.obs_names.equals(cells_in_common) else
                    dataset[cells_in_common] for dataset in datasets]
        first, others = datasets[0], datasets[1:]

        def uns_values_match(value: Any, other: Any) -> bool:
            # Compare two `uns` entries by kind: dicts recursively,
            # arrays element-wise, anything else with `==`. Entries of
            # different kinds never match, which also avoids evaluating
            # the truth value of an array-vs-scalar comparison.
            if isinstance(value, dict):
                return isinstance(other, dict) and SingleCell._eq_uns(
                    value, other, different_order_ok=True)
            if isinstance(value, np.ndarray):
                return isinstance(other, np.ndarray) and \
                    array_equal(other, value)
            return not isinstance(other, (dict, np.ndarray)) and \
                other == value

        # Subset to columns of `obs` and keys of `obsm`, `obsp`, and `uns`
        # that are identical in all datasets after this subsetting. The
        # reference values come from the (possibly subset) first dataset
        # rather than from `self`, so like is compared with like.
        obs_columns_in_common = [
            column.name for column in first._obs[:, 1:]
            if all(column.name in dataset._obs and
                   dataset._obs[column.name].equals(column)
                   for dataset in others)]
        obsm_keys_in_common = [
            key for key, value in first._obsm.items()
            if all(key in dataset._obsm and
                   type(value) is type(dataset._obsm[key]) and
                   (dataset._obsm[key].dtype == value.dtype and
                    array_equal(dataset._obsm[key], value)
                    if isinstance(value, np.ndarray) else
                    dataset._obsm[key].equals(value))
                   for dataset in others)]
        obsp_keys_in_common = [
            key for key, value in first._obsp.items()
            if all(key in dataset._obsp and
                   dataset._obsp[key].dtype == value.dtype and
                   sparse_equal(dataset._obsp[key], value)
                   for dataset in others)]
        uns_keys_in_common = [
            key for key, value in first._uns.items()
            if all(key in dataset._uns and
                   uns_values_match(value, dataset._uns[key])
                   for dataset in others)]
        obs = first._obs.select(first.obs_names, *obs_columns_in_common)
        obsm = {key: first._obsm[key] for key in obsm_keys_in_common}
        obsp = {key: first._obsp[key] for key in obsp_keys_in_common}
        uns = {key: first._uns[key] for key in uns_keys_in_common}
        # Subset to columns of `var` and keys of `varm` that are present in
        # all datasets, and have the same data types. Also include columns
        # of `var` that are Enum in some datasets and Categorical in
        # others, or Enum in all datasets but with different categories in
        # each dataset; cast these to Enum.
        var_mismatched_categoricals = {
            column for column, dtype in first._var[:, 1:]
            .select(pl.col(pl.Categorical, pl.Enum)).schema.items()
            if all(column in dataset._var and
                   dataset._var[column].dtype in (pl.Categorical, pl.Enum)
                   for dataset in others) and
            not all(dataset._var[column].dtype == dtype
                    for dataset in others)}
        var_columns_in_common = [
            column
            for column, dtype in islice(first._var.schema.items(), 1, None)
            if column in var_mismatched_categoricals or
            all(column in dataset._var and
                dataset._var[column].dtype == dtype
                for dataset in others)]
        cast_dict = {column: pl.Enum(
            pl.concat([dataset._var[column].cat.get_categories()
                       for dataset in datasets])
            .unique(maintain_order=True))
            for column in var_mismatched_categoricals}
        # the `.with_columns(...)` is a faster `.cast(cast_dict)`
        var_frames = [
            dataset._var
            .select(dataset.var_names, *var_columns_in_common)
            .with_columns(cast_to_Enum(dataset._var[column], enum_type)
                          .alias(column)
                          for column, enum_type in cast_dict.items())
            for dataset in datasets]
        varm_keys_in_common = [
            key for key, value in first._varm.items()
            if all(key in dataset._varm and
                   type(dataset._varm[key]) is type(value) and
                   (dataset._varm[key].dtype == value.dtype
                    if isinstance(value, np.ndarray) else
                    dataset._varm[key].schema == value.schema)
                   for dataset in others)]
        varm_dicts = [{key: dataset._varm[key]
                       for key in varm_keys_in_common}
                      for dataset in datasets]
    else:  # non-flexible
        others = datasets[1:]
        # Check that all `obs`, `obsm`, `obsp`, and `uns` are identical
        obs = self._obs
        for dataset in others:
            if not dataset._obs.equals(obs):
                error_message = (
                    'all SingleCell datasets must have the same obs, '
                    'unless flexible=True')
                raise ValueError(error_message)
        obsm = self._obsm
        for dataset in others:
            # NumPy entries are compared with `array_equal()` and
            # DataFrame entries with `.equals()`, as in the flexible
            # branch
            if dataset._obsm.keys() != obsm.keys() or \
                    any(type(dataset._obsm[key]) is not type(value) or
                        (dataset._obsm[key].dtype != value.dtype
                         if isinstance(value, np.ndarray) else
                         dataset._obsm[key].schema != value.schema)
                        for key, value in obsm.items()) or not \
                    all(array_equal(dataset._obsm[key], value)
                        if isinstance(value, np.ndarray) else
                        dataset._obsm[key].equals(value)
                        for key, value in obsm.items()):
                error_message = (
                    'all SingleCell datasets must have the same obsm, '
                    'unless flexible=True')
                raise ValueError(error_message)
        obsp = self._obsp
        for dataset in others:
            if dataset._obsp.keys() != obsp.keys() or \
                    any(type(dataset._obsp[key]) is not type(value) or
                        dataset._obsp[key].dtype != value.dtype
                        for key, value in obsp.items()) or not \
                    all(sparse_equal(dataset._obsp[key], value)
                        for key, value in obsp.items()):
                error_message = (
                    'all SingleCell datasets must have the same obsp, '
                    'unless flexible=True')
                raise ValueError(error_message)
        uns = self._uns
        for dataset in others:
            if not SingleCell._eq_uns(uns, dataset._uns):
                error_message = (
                    'all SingleCell datasets must have the same uns, '
                    'unless flexible=True')
                raise ValueError(error_message)
        # Check that all `var` have the same columns and data types
        schema = self._var.schema
        for dataset in others:
            if dataset._var.schema != schema:
                if dataset._var.columns != self._var.columns:
                    error_message = (
                        'all SingleCell datasets must have the same '
                        'columns in var, unless flexible=True')
                    raise ValueError(error_message)
                else:
                    error_message = (
                        'all SingleCell datasets must have the same data '
                        'type for each column of var, unless '
                        'flexible=True')
                    raise TypeError(error_message)
        # Check that all `varm` have the same keys and data types
        varm = self._varm
        for dataset in others:
            if dataset._varm.keys() != varm.keys():
                error_message = (
                    'all SingleCell datasets must have the same keys in '
                    'varm, unless flexible=True')
                raise ValueError(error_message)
            if any(type(dataset._varm[key]) is not type(value) or
                   (dataset._varm[key].dtype != value.dtype
                    if isinstance(value, np.ndarray) else
                    dataset._varm[key].schema != value.schema)
                   for key, value in varm.items()):
                error_message = (
                    'all SingleCell datasets must have the same data '
                    'type for each key in varm, unless flexible=True')
                raise TypeError(error_message)
        var_frames = [dataset._var for dataset in datasets]
        varm_dicts = [dataset._varm for dataset in datasets]
    # If `dataset_column` is not `None`, add labels for each dataset
    if dataset_column is not None:
        var_frames = [
            var_frame.with_columns(pl.lit(label).alias(dataset_column))
            for var_frame, label in zip(var_frames, dataset_labels)]
    # Concatenate; output should be CSR when there's a mix of inputs
    var = pl.concat(var_frames)
    num_unique = var[:, 0].n_unique()
    if num_unique != len(var):
        error_message = (
            f'var_names contains {len(var) - num_unique:,} duplicates '
            f'after concatenation')
        raise ValueError(error_message)
    if X_present:
        # Genes are the minor axis of a CSR array and the major axis of a
        # CSC array, hence the different stacking functions
        if all(isinstance(dataset._X, csr_array) for dataset in datasets):
            X = sparse_minor_stack([dataset._X for dataset in datasets],
                                   num_threads=num_threads)
        elif all(isinstance(dataset._X, csc_array)
                 for dataset in datasets):
            X = sparse_major_stack([dataset._X for dataset in datasets],
                                   num_threads=num_threads)
        else:
            X = sparse_minor_stack([dataset._X.tocsr()
                                    if isinstance(dataset._X, csc_array)
                                    else dataset._X
                                    for dataset in datasets],
                                   num_threads=num_threads)
    else:
        X = None
    varm = {key: concatenate([varm_dict[key] for varm_dict in varm_dicts],
                             num_threads=num_threads)
            if isinstance(value, np.ndarray) else
            pl.concat([varm_dict[key] for varm_dict in varm_dicts])
            for key, value in varm_dicts[0].items()}
    # `varp` is deliberately not passed: it is discarded on concatenation
    return SingleCell(X=X, obs=obs, var=var, obsm=obsm, varm=varm,
                      obsp=obsp, uns=uns,
                      num_threads=datasets[0]._num_threads)
[docs]
def split_by_obs(self,
                 by_column: SingleCellColumn,
                 /,
                 *,
                 QC_column: SingleCellColumn | None = 'passed_QC',
                 sort_by_size: bool = False,
                 num_threads: int | np.integer | None = None) -> \
        dict[str | int, SingleCell]:
    """
    The opposite of `concat_obs()`: splits a SingleCell dataset into a
    dictionary of SingleCell datasets, one per unique value of a column of
    `obs`.

    Args:
        by_column: a String, Enum, Categorical, or integer column of `obs`
                   to split by. Can be a column name, a polars expression,
                   a polars Series, a 1D NumPy array, or a function that
                   takes in this SingleCell dataset and returns a polars
                   Series or 1D NumPy array. Can contain `null` entries:
                   the corresponding cells will not be included in the
                   result.
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will not be
                   selected when splitting.
        sort_by_size: if `True`, datasets in the returned dictionary will
                      be sorted in decreasing order of size. If `False`,
                      they will be sorted in ascending order, according to
                      the sort order of `by_column`'s data type.
        num_threads: the number of threads to use when splitting `X`. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores.
                     Can only be specified when `X` is not `None`.

    Returns:
        A dictionary mapping each unique non-null value of `by_column` to
        a SingleCell dataset subset to cells where `by_column` has that
        value.

    Raises:
        ValueError: if `num_threads` is specified but `X` is `None`.
    """
    if QC_column is not None:
        # Only the default column name may be silently absent. The
        # `isinstance` guard keeps `allow_missing` a plain bool when
        # `QC_column` is an expression, Series or array, for which `==`
        # would not produce a bool.
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=isinstance(QC_column, str) and
                          QC_column == 'passed_QC')
    by_column = self._get_column('obs', by_column, 'by_column',
                                 (pl.String, pl.Enum, pl.Categorical,
                                  'integer'),
                                 QC_column=QC_column, allow_null=True)
    check_type(sort_by_size, 'sort_by_size', bool, 'Boolean')
    # `filter_obs()` rejects a non-`None` `num_threads` when `X` is `None`,
    # so only resolve the default when there is an `X` to split
    if self._X is None:
        if num_threads is not None:
            error_message = \
                'num_threads can only be specified when X is not None'
            raise ValueError(error_message)
    else:
        num_threads = self._process_num_threads(num_threads)
    # Nulls are dropped in both orderings, so that cells with a null
    # `by_column` never produce a dictionary entry.
    # NOTE(review): when `sort_by_size=True`, the counts come from
    # `by_column` as returned by `_get_column()`; whether cells failing QC
    # are already excluded there is not visible here - confirm.
    if sort_by_size:
        values = by_column.value_counts(sort=True).to_series().drop_nulls()
    else:
        values = by_column.unique().drop_nulls().sort()
    if QC_column is None:
        return {value: self.filter_obs(by_column.eq(value),
                                       num_threads=num_threads)
                for value in values}
    else:
        return {value: self.filter_obs(by_column.eq(value) & QC_column,
                                       num_threads=num_threads)
                for value in values}
[docs]
def split_by_var(self,
                 by_column: SingleCellColumn,
                 /,
                 *,
                 QC_column: SingleCellColumn | None = 'passed_QC',
                 sort_by_size: bool = False,
                 num_threads: int | np.integer | None = None) -> \
        dict[str | int, SingleCell]:
    """
    The opposite of `concat_var()`: splits a SingleCell dataset into a
    dictionary of SingleCell datasets, one per unique value of a column of
    `var`.

    Args:
        by_column: a String, Enum, Categorical, or integer column of `var`
                   to split by. Can be a column name, a polars expression,
                   a polars Series, a 1D NumPy array, or a function that
                   takes in this SingleCell dataset and returns a polars
                   Series or 1D NumPy array. Can contain `null` entries:
                   the corresponding genes will not be included in the
                   result.
        QC_column: accepted in the signature but not used by this method;
                   all cells are kept regardless of its value.
        sort_by_size: if `True`, datasets in the returned dictionary will
                      be sorted in decreasing order of size. If `False`,
                      they will be sorted in ascending order, according to
                      the sort order of `by_column`'s data type.
        num_threads: the number of threads to use when splitting `X`. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores.
                     Can only be specified when `X` is not `None`.

    Returns:
        A dictionary mapping each unique non-null value of `by_column` to
        a SingleCell dataset subset to genes where `by_column` has that
        value.

    Raises:
        ValueError: if `num_threads` is specified but `X` is `None`.
    """
    by_column = self._get_column('var', by_column, 'by_column',
                                 (pl.String, pl.Enum, pl.Categorical,
                                  'integer'), allow_null=True)
    check_type(sort_by_size, 'sort_by_size', bool, 'Boolean')
    # `filter_var()` rejects a non-`None` `num_threads` when `X` is `None`,
    # so only resolve the default when there is an `X` to split
    if self._X is None:
        if num_threads is not None:
            error_message = \
                'num_threads can only be specified when X is not None'
            raise ValueError(error_message)
    else:
        num_threads = self._process_num_threads(num_threads)
    # Nulls are dropped in both orderings, so that genes with a null
    # `by_column` never produce a dictionary entry
    if sort_by_size:
        values = by_column.value_counts(sort=True).to_series().drop_nulls()
    else:
        values = by_column.unique().drop_nulls().sort()
    return {value: self.filter_var(by_column.eq(value),
                                   num_threads=num_threads)
            for value in values}
[docs]
def tocsr(self,
          *,
          num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Return a copy of this SingleCell dataset whose `X` has been converted
    to a `csr_array`. It is an error to call this when `X` is already a
    `csr_array`.

    Args:
        num_threads: the number of threads to use when converting to CSR.
                     Set `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores.
                     Does not affect the copied SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     original dataset's `num_threads`.

    Returns:
        A copy of this SingleCell dataset, with `X` as a `csr_array`.
    """
    X = self._X
    # Guard clauses: `X` must exist and must not already be CSR
    if X is None:
        error_message = 'X is None, so converting to CSR is not possible'
        raise ValueError(error_message)
    if isinstance(X, csr_array):
        error_message = 'X is already a csr_array'
        raise TypeError(error_message)
    conversion_threads = self._process_num_threads(num_threads)
    # Temporarily override the thread count stored on `X` for the
    # conversion, restoring the previous value even if something fails
    saved_num_threads = X._num_threads
    try:
        X._num_threads = conversion_threads
        converted = SingleCell(
            X=X.tocsr(), obs=self._obs, var=self._var, obsm=self._obsm,
            varm=self._varm, obsp=self._obsp, varp=self._varp,
            uns=self._uns, num_threads=self._num_threads)
    finally:
        X._num_threads = saved_num_threads
    return converted
[docs]
def tocsc(self,
          *,
          num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Return a copy of this SingleCell dataset whose `X` has been converted
    to a `csc_array`. It is an error to call this when `X` is already a
    `csc_array`.

    This function is provided for completeness, but `csr_array` is a far
    better format than `csc_array` for cell-wise operations like
    pseudobulking, so using `tocsc()` is rarely advisable.

    Args:
        num_threads: the number of threads to use when converting to CSC.
                     Set `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores.
                     Does not affect the copied SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     original dataset's `num_threads`.

    Returns:
        A copy of this SingleCell dataset, with `X` as a `csc_array`.
    """
    X = self._X
    # Guard clauses: `X` must exist and must not already be CSC
    if X is None:
        error_message = 'X is None, so converting to CSC is not possible'
        raise ValueError(error_message)
    if isinstance(X, csc_array):
        error_message = 'X is already a csc_array'
        raise TypeError(error_message)
    conversion_threads = self._process_num_threads(num_threads)
    # Temporarily override the thread count stored on `X` for the
    # conversion, restoring the previous value even if something fails
    saved_num_threads = X._num_threads
    try:
        X._num_threads = conversion_threads
        converted = SingleCell(
            X=X.tocsc(), obs=self._obs, var=self._var, obsm=self._obsm,
            varm=self._varm, obsp=self._obsp, varp=self._varp,
            uns=self._uns, num_threads=self._num_threads)
    finally:
        X._num_threads = saved_num_threads
    return converted
[docs]
def filter_obs(self,
               *predicates: pl.Expr | pl.Series | str |
                            Iterable[pl.Expr | pl.Series | str] | bool |
                            list[bool] | np.ndarray[np.dtype[np.bool_]],
               num_threads: int | np.integer | None = None,
               **constraints: Any) -> SingleCell:
    """
    Equivalent to `df.filter()` from polars, but applied to both
    `obs`/`obsm`/`obsp` and `X`.

    Args:
        *predicates: one or more column names, expressions that evaluate to
                     Boolean Series, Boolean Series, lists of Booleans,
                     and/or 1D Boolean NumPy arrays
        num_threads: the number of threads to use when filtering `X`. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores.
                     Does not affect the filtered SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     original dataset's `num_threads`. Can only be
                     specified when `X` is not `None`.
        **constraints: column filters: `name=value` filters to cells
                       where the column named `name` has the value `value`

    Returns:
        A new SingleCell dataset filtered to cells passing all the
        Boolean filters in `predicates` and `constraints`.

    Raises:
        ValueError: if `num_threads` is specified but `X` is `None`.
    """
    # Tag every row of `obs` with its position before filtering, so the
    # positions of the surviving cells can be reused to subset `X`, `obsm`
    # and `obsp`
    obs = self._obs\
        .with_columns(_SingleCell_index=pl.int_range(pl.len(),
                                                     dtype=pl.Int32))\
        .filter(*predicates, **constraints)
    indices = obs['_SingleCell_index'].to_numpy()
    if self._X is None:
        if num_threads is not None:
            error_message = \
                'num_threads can only be specified when X is not None'
            raise ValueError(error_message)
        return SingleCell(X=None,
                          obs=obs.drop('_SingleCell_index'), var=self._var,
                          obsm={key: value[indices]
                                for key, value in self._obsm.items()},
                          varm=self._varm,
                          obsp={key: value[ix_symmetric(indices)]
                                for key, value in self._obsp.items()},
                          varp=self._varp, uns=self._uns,
                          num_threads=self._num_threads)
    else:
        num_threads = self._process_num_threads(num_threads)
        # Temporarily override the thread count stored on `X` while
        # subsetting, restoring the previous value even on failure
        original_num_threads = self._X._num_threads
        try:
            self._X._num_threads = num_threads
            return SingleCell(X=self._X[indices],
                              obs=obs.drop('_SingleCell_index'),
                              var=self._var,
                              obsm={key: value[indices]
                                    for key, value in self._obsm.items()},
                              varm=self._varm,
                              obsp={key: value[ix_symmetric(indices)]
                                    for key, value in self._obsp.items()},
                              varp=self._varp, uns=self._uns,
                              num_threads=self._num_threads)
        finally:
            self._X._num_threads = original_num_threads
[docs]
def filter_var(self,
               *predicates: pl.Expr | pl.Series | str |
                            Iterable[pl.Expr | pl.Series | str] | bool |
                            list[bool] | np.ndarray[np.dtype[np.bool_]],
               num_threads: int | np.integer | None = None,
               **constraints: Any) -> SingleCell:
    """
    Gene-wise counterpart of polars' `df.filter()`: filters `var`, and
    applies the same selection of genes to `varm`, `varp` and `X`.

    Args:
        *predicates: one or more column names, expressions that evaluate to
                     Boolean Series, Boolean Series, lists of Booleans,
                     and/or 1D Boolean NumPy arrays
        num_threads: the number of threads to use when filtering `X`. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores.
                     Does not affect the filtered SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     original dataset's `num_threads`. Can only be
                     specified when `X` is not `None`.
        **constraints: column filters: `name=value` filters to genes
                       where the column named `name` has the value `value`

    Returns:
        A new SingleCell dataset filtered to genes passing all the
        Boolean filters in `predicates` and `constraints`.
    """
    # Number the rows of `var` before filtering; the numbers that survive
    # are the positions of the retained genes
    row_numbers = pl.int_range(pl.len(), dtype=pl.Int32)
    indexed_var = self._var\
        .with_columns(_SingleCell_index=row_numbers)\
        .filter(*predicates, **constraints)
    indices = indexed_var['_SingleCell_index'].to_numpy()
    X = self._X
    if X is None:
        if num_threads is not None:
            error_message = \
                'num_threads can only be specified when X is not None'
            raise ValueError(error_message)
    else:
        num_threads = self._process_num_threads(num_threads)
    # Everything except `X` is assembled the same way in both cases
    shared_fields = dict(
        obs=self._obs,
        var=indexed_var.drop('_SingleCell_index'),
        obsm=self._obsm,
        varm={key: value[indices] for key, value in self._varm.items()},
        obsp=self._obsp,
        varp={key: value[ix_symmetric(indices)]
              for key, value in self._varp.items()},
        uns=self._uns,
        num_threads=self._num_threads)
    if X is None:
        return SingleCell(X=None, **shared_fields)
    # Temporarily override the thread count stored on `X` while
    # subsetting its columns, restoring the previous value afterwards
    saved_num_threads = X._num_threads
    try:
        X._num_threads = num_threads
        return SingleCell(X=X[:, indices], **shared_fields)
    finally:
        X._num_threads = saved_num_threads
[docs]
def select_obs(self,
               *exprs: Scalar | pl.Expr | pl.Series |
                       Iterable[Scalar | pl.Expr | pl.Series],
               **named_exprs: Scalar | pl.Expr | pl.Series) -> SingleCell:
    """
    Equivalent to `df.select()` from polars, but applied to `obs`.
    `obs_names` is always prepended automatically as the first column, so
    it must not be among the selected columns.

    Args:
        *exprs: column(s) to select, specified as positional arguments.
                Accepts expression input. Strings are parsed as column
                names, other non-expression inputs are parsed as literals.
        **named_exprs: additional columns to select, specified as keyword
                       arguments. The columns will be renamed to the
                       keyword used.

    Returns:
        A new SingleCell dataset whose `obs` is `obs_names` followed by
        the columns of `obs.select(*exprs, **named_exprs)`.

    Raises:
        ValueError: if one of the selected columns has the same name as
                    `obs_names`.
    """
    selected = self._obs.select(*exprs, **named_exprs)
    obs_names = self.obs_names
    if obs_names.name in selected.columns:
        error_message = (
            f'one of the selected columns is the obs_names, '
            f'{obs_names.name!r}, but the obs_names will always be '
            f'selected automatically as the first column and thus should '
            f'not be specified explicitly')
        raise ValueError(error_message)
    # Put `obs_names` in front of whatever was selected
    return SingleCell(X=self._X,
                      obs=selected.select(obs_names, pl.all()),
                      var=self._var, obsm=self._obsm, varm=self._varm,
                      obsp=self._obsp, varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def select_var(self,
               *exprs: Scalar | pl.Expr | pl.Series |
                       Iterable[Scalar | pl.Expr | pl.Series],
               **named_exprs: Scalar | pl.Expr | pl.Series) -> SingleCell:
    """
    Equivalent to `df.select()` from polars, but applied to `var`.
    `var_names` is always prepended automatically as the first column, so
    it must not be among the selected columns.

    Args:
        *exprs: column(s) to select, specified as positional arguments.
                Accepts expression input. Strings are parsed as column
                names, other non-expression inputs are parsed as literals.
        **named_exprs: additional columns to select, specified as keyword
                       arguments. The columns will be renamed to the
                       keyword used.

    Returns:
        A new SingleCell dataset whose `var` is `var_names` followed by
        the columns of `var.select(*exprs, **named_exprs)`.

    Raises:
        ValueError: if one of the selected columns has the same name as
                    `var_names`.
    """
    selected = self._var.select(*exprs, **named_exprs)
    var_names = self.var_names
    if var_names.name in selected.columns:
        error_message = (
            f'one of the selected columns is the var_names, '
            f'{var_names.name!r}, but the var_names will always be '
            f'selected automatically as the first column and thus should '
            f'not be specified explicitly')
        raise ValueError(error_message)
    # Put `var_names` in front of whatever was selected
    return SingleCell(X=self._X, obs=self._obs,
                      var=selected.select(var_names, pl.all()),
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def select_obsm(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Keep only the specified key(s) of `obsm`.

    Args:
        keys: key(s) to select
        *more_keys: additional keys to select, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with `obsm` subset to the specified
        key(s).

    Raises:
        ValueError: if any requested key is not a key of `obsm`.
    """
    keys = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    keys += more_keys
    current = self._obsm
    for key in keys:
        if key in current:
            continue
        error_message = \
            f'tried to select {key!r}, which is not a key of obsm'
        raise ValueError(error_message)
    # Walk the existing mapping (not `keys`) so that the current key order
    # is kept and repeated requests yield a single entry
    selected = {name: value for name, value in current.items()
                if name in keys}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=selected, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def select_varm(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Keep only the specified key(s) of `varm`.

    Args:
        keys: key(s) to select
        *more_keys: additional keys to select, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with `varm` subset to the specified
        key(s).

    Raises:
        ValueError: if any requested key is not a key of `varm`.
    """
    keys = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    keys += more_keys
    current = self._varm
    for key in keys:
        if key in current:
            continue
        error_message = \
            f'tried to select {key!r}, which is not a key of varm'
        raise ValueError(error_message)
    # Walk the existing mapping (not `keys`) so that the current key order
    # is kept and repeated requests yield a single entry
    selected = {name: value for name, value in current.items()
                if name in keys}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=selected, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def select_obsp(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Keep only the specified key(s) of `obsp`.

    Args:
        keys: key(s) to select
        *more_keys: additional keys to select, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with `obsp` subset to the specified
        key(s).

    Raises:
        ValueError: if any requested key is not a key of `obsp`.
    """
    keys = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    keys += more_keys
    current = self._obsp
    for key in keys:
        if key in current:
            continue
        error_message = \
            f'tried to select {key!r}, which is not a key of obsp'
        raise ValueError(error_message)
    # Walk the existing mapping (not `keys`) so that the current key order
    # is kept and repeated requests yield a single entry
    selected = {name: value for name, value in current.items()
                if name in keys}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=selected,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def select_varp(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Keep only the specified key(s) of `varp`.

    Args:
        keys: key(s) to select
        *more_keys: additional keys to select, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with `varp` subset to the specified
        key(s).

    Raises:
        ValueError: if any requested key is not a key of `varp`.
    """
    keys = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    keys += more_keys
    current = self._varp
    for key in keys:
        if key in current:
            continue
        error_message = \
            f'tried to select {key!r}, which is not a key of varp'
        raise ValueError(error_message)
    # Walk the existing mapping (not `keys`) so that the current key order
    # is kept and repeated requests yield a single entry
    selected = {name: value for name, value in current.items()
                if name in keys}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=selected, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def select_uns(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Keep only the specified key(s) of `uns`.

    Args:
        keys: key(s) to select
        *more_keys: additional keys to select, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with `uns` subset to the specified key(s).

    Raises:
        ValueError: if any requested key is not a key of `uns`.
    """
    keys = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    keys += more_keys
    current = self._uns
    for key in keys:
        if key in current:
            continue
        error_message = \
            f'tried to select {key!r}, which is not a key of uns'
        raise ValueError(error_message)
    # Walk the existing mapping (not `keys`) so that the current key order
    # is kept and repeated requests yield a single entry
    selected = {name: value for name, value in current.items()
                if name in keys}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=selected,
                      num_threads=self._num_threads)
[docs]
def with_columns_obs(self,
                     *exprs: Scalar | pl.Expr | pl.Series |
                             Iterable[Scalar | pl.Expr | pl.Series],
                     **named_exprs: Scalar | pl.Expr | pl.Series) -> \
        SingleCell:
    """
    Add or replace columns of `obs`, in the manner of polars'
    `df.with_columns()`.

    Args:
        *exprs: column(s) to add, specified as positional arguments.
                Accepts expression input. Strings are parsed as column
                names, other non-expression inputs are parsed as literals.
        **named_exprs: additional columns to add, specified as keyword
                       arguments. The columns will be renamed to the
                       keyword used.

    Returns:
        A new SingleCell dataset with
        `obs=obs.with_columns(*exprs, **named_exprs)`.
    """
    updated_obs = self._obs.with_columns(*exprs, **named_exprs)
    return SingleCell(X=self._X, obs=updated_obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def with_columns_var(self,
                     *exprs: Scalar | pl.Expr | pl.Series |
                             Iterable[Scalar | pl.Expr | pl.Series],
                     **named_exprs: Scalar | pl.Expr | pl.Series) -> \
        SingleCell:
    """
    Add or replace columns of `var`, in the manner of polars'
    `df.with_columns()`.

    Args:
        *exprs: column(s) to add, specified as positional arguments.
                Accepts expression input. Strings are parsed as column
                names, other non-expression inputs are parsed as literals.
        **named_exprs: additional columns to add, specified as keyword
                       arguments. The columns will be renamed to the
                       keyword used.

    Returns:
        A new SingleCell dataset with
        `var=var.with_columns(*exprs, **named_exprs)`.
    """
    updated_var = self._var.with_columns(*exprs, **named_exprs)
    return SingleCell(X=self._X, obs=self._obs, var=updated_var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def with_obsm(self,
              obsm: dict[str, np.ndarray | pl.DataFrame] = {},
              /,
              **more_obsm: np.ndarray) -> SingleCell:
    """
    Adds one or more keys to `obsm`, overwriting existing keys with the
    same names if present.

    Neither the `obsm` argument nor this dataset is modified.

    Args:
        obsm: a dictionary of keys to add to (or overwrite in) `obsm`
        **more_obsm: additional keys to add to (or overwrite in) `obsm`,
                     specified as keyword arguments. On a name clash, the
                     keyword argument takes precedence over the entry in
                     the `obsm` dictionary.

    Returns:
        A new SingleCell dataset with the new key(s) added to or
        overwritten in `obsm`.

    Raises:
        TypeError: if the `obsm` dictionary contains a non-string key
        ValueError: if no keys were given at all
    """
    check_type(obsm, 'obsm', dict, 'a dictionary')
    for key in obsm:
        if not isinstance(key, str):
            error_message = (
                f'all keys of obsm must be strings, but new obsm contains '
                f'a key of type {type(key).__name__!r}')
            raise TypeError(error_message)
    # Merge into a fresh dictionary. An in-place `|=` would write the
    # keyword arguments into the caller's dictionary or, when the argument
    # is omitted, into the shared `{}` default, leaking keys from one call
    # into every later call.
    obsm = obsm | more_obsm
    if not obsm:
        error_message = \
            'obsm is empty and no keyword arguments were specified'
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm | obsm, varm=self._varm,
                      obsp=self._obsp, varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def with_varm(self,
              varm: dict[str, np.ndarray | pl.DataFrame] = {},
              /,
              **more_varm: np.ndarray) -> SingleCell:
    """
    Adds one or more keys to `varm`, overwriting existing keys with the
    same names if present.

    Neither the `varm` argument nor this dataset is modified.

    Args:
        varm: a dictionary of keys to add to (or overwrite in) `varm`
        **more_varm: additional keys to add to (or overwrite in) `varm`,
                     specified as keyword arguments. On a name clash, the
                     keyword argument takes precedence over the entry in
                     the `varm` dictionary.

    Returns:
        A new SingleCell dataset with the new key(s) added to or
        overwritten in `varm`.

    Raises:
        TypeError: if the `varm` dictionary contains a non-string key
        ValueError: if no keys were given at all
    """
    check_type(varm, 'varm', dict, 'a dictionary')
    for key in varm:
        if not isinstance(key, str):
            error_message = (
                f'all keys of varm must be strings, but new varm contains '
                f'a key of type {type(key).__name__!r}')
            raise TypeError(error_message)
    # Merge into a fresh dictionary. An in-place `|=` would write the
    # keyword arguments into the caller's dictionary or, when the argument
    # is omitted, into the shared `{}` default, leaking keys from one call
    # into every later call.
    varm = varm | more_varm
    if not varm:
        error_message = \
            'varm is empty and no keyword arguments were specified'
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm | varm,
                      obsp=self._obsp, varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def with_obsp(self,
              obsp: dict[str, csr_array | csc_array] = {},
              /,
              **more_obsp: csr_array | csc_array) -> SingleCell:
    """
    Adds one or more keys to `obsp`, overwriting existing keys with the
    same names if present.

    Neither the `obsp` argument nor this dataset is modified.

    Args:
        obsp: a dictionary of keys to add to (or overwrite in) `obsp`
        **more_obsp: additional keys to add to (or overwrite in) `obsp`,
                     specified as keyword arguments. On a name clash, the
                     keyword argument takes precedence over the entry in
                     the `obsp` dictionary.

    Returns:
        A new SingleCell dataset with the new key(s) added to or
        overwritten in `obsp`.

    Raises:
        TypeError: if the `obsp` dictionary contains a non-string key
        ValueError: if no keys were given at all
    """
    check_type(obsp, 'obsp', dict, 'a dictionary')
    for key in obsp:
        if not isinstance(key, str):
            error_message = (
                f'all keys of obsp must be strings, but new obsp contains '
                f'a key of type {type(key).__name__!r}')
            raise TypeError(error_message)
    # Merge into a fresh dictionary. An in-place `|=` would write the
    # keyword arguments into the caller's dictionary or, when the argument
    # is omitted, into the shared `{}` default, leaking keys from one call
    # into every later call.
    obsp = obsp | more_obsp
    if not obsp:
        error_message = \
            'obsp is empty and no keyword arguments were specified'
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm,
                      obsp=self._obsp | obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
[docs]
def with_varp(self,
              varp: dict[str, sparse.csr_array | sparse.csc_array |
                              sparse.csr_matrix | sparse.csc_matrix] = {},
              /,
              **more_varp: sparse.csr_array | sparse.csc_array |
                           sparse.csr_matrix | sparse.csc_matrix) -> \
        SingleCell:
    """
    Adds one or more keys to `varp`, overwriting existing keys with the
    same names if present.

    Neither the `varp` argument nor this dataset is modified.

    Args:
        varp: a dictionary of keys to add to (or overwrite in) `varp`
        **more_varp: additional keys to add to (or overwrite in) `varp`,
                     specified as keyword arguments. On a name clash, the
                     keyword argument takes precedence over the entry in
                     the `varp` dictionary.

    Returns:
        A new SingleCell dataset with the new key(s) added to or
        overwritten in `varp`.

    Raises:
        TypeError: if the `varp` dictionary contains a non-string key
        ValueError: if no keys were given at all
    """
    check_type(varp, 'varp', dict, 'a dictionary')
    for key in varp:
        if not isinstance(key, str):
            error_message = (
                f'all keys of varp must be strings, but new varp contains '
                f'a key of type {type(key).__name__!r}')
            raise TypeError(error_message)
    # Merge into a fresh dictionary. An in-place `|=` would write the
    # keyword arguments into the caller's dictionary or, when the argument
    # is omitted, into the shared `{}` default, leaking keys from one call
    # into every later call.
    varp = varp | more_varp
    if not varp:
        error_message = \
            'varp is empty and no keyword arguments were specified'
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp | varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def with_uns(self,
             uns: dict[str, UnsDict] = {},
             /,
             **more_uns: UnsItem | UnsDict) -> SingleCell:
    """
    Adds one or more keys to `uns`, overwriting existing keys with the same
    names if present.

    Neither the `uns` argument nor this dataset is modified.

    Args:
        uns: a dictionary of keys to add to (or overwrite in) `uns`
        **more_uns: additional keys to add to (or overwrite in) `uns`,
                    specified as keyword arguments. On a name clash, the
                    keyword argument takes precedence over the entry in
                    the `uns` dictionary.

    Returns:
        A new SingleCell dataset with the new key(s) added to or
        overwritten in `uns`.

    Raises:
        TypeError: if the `uns` dictionary contains a non-string key
        ValueError: if no keys were given at all
    """
    check_type(uns, 'uns', dict, 'a dictionary')
    for key in uns:
        if not isinstance(key, str):
            error_message = (
                f'all keys of uns must be strings, but new uns contains a '
                f'key of type {type(key).__name__!r}')
            raise TypeError(error_message)
    # Merge into a fresh dictionary. An in-place `|=` would write the
    # keyword arguments into the caller's dictionary or, when the argument
    # is omitted, into the shared `{}` default, leaking keys from one call
    # into every later call.
    uns = uns | more_uns
    if not uns:
        error_message = \
            'uns is empty and no keyword arguments were specified'
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns | uns,
                      num_threads=self._num_threads)
[docs]
def drop_X(self) -> SingleCell:
    """
    Create a new SingleCell dataset with `X` removed, to reduce memory use.

    Returns:
        A new SingleCell dataset built from this dataset's `obs`, `var`,
        `obsm`, `varm`, `obsp`, `varp` and `uns`, but without `X`.

    Raises:
        TypeError: if `X` is already `None`.
    """
    if self._X is None:
        # TypeError (rather than ValueError) is kept so that existing
        # callers catching it continue to work
        error_message = 'X is None, so it cannot be dropped'
        raise TypeError(error_message)
    return SingleCell(obs=self._obs, var=self._var, obsm=self._obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
[docs]
def drop_obs(self,
             columns: pl.type_aliases.ColumnNameOrSelector |
                      Iterable[pl.type_aliases.ColumnNameOrSelector],
             /,
             *more_columns: pl.type_aliases.ColumnNameOrSelector) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with `columns` and `more_columns`
    removed from `obs`.

    Args:
        columns: column(s) to drop
        *more_columns: additional columns to drop, specified as
                       positional arguments

    Returns:
        A new SingleCell dataset with the column(s) removed.
    """
    to_drop = to_tuple(columns) + more_columns
    trimmed_obs = self._obs.drop(to_drop)
    return SingleCell(X=self._X, obs=trimmed_obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def drop_var(self,
             columns: pl.type_aliases.ColumnNameOrSelector |
                      Iterable[pl.type_aliases.ColumnNameOrSelector],
             /,
             *more_columns: pl.type_aliases.ColumnNameOrSelector) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with `columns` and `more_columns`
    removed from `var`.

    Args:
        columns: column(s) to drop
        *more_columns: additional columns to drop, specified as
                       positional arguments

    Returns:
        A new SingleCell dataset with the column(s) removed.
    """
    to_drop = to_tuple(columns) + more_columns
    trimmed_var = self._var.drop(to_drop)
    return SingleCell(X=self._X, obs=self._obs, var=trimmed_var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def drop_obsm(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with `keys` and `more_keys` removed
    from `obsm`.

    Args:
        keys: key(s) to drop
        *more_keys: additional keys to drop, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with the specified key(s) removed from
        obsm.

    Raises:
        ValueError: if any of the keys is not a key of `obsm`.
    """
    keys = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    keys += more_keys
    current = self._obsm
    for key in keys:
        if key in current:
            continue
        error_message = \
            f'tried to drop {key!r}, which is not a key of obsm'
        raise ValueError(error_message)
    remaining = {name: value for name, value in current.items()
                 if name not in keys}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=remaining, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def drop_varm(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with `keys` and `more_keys` removed
    from `varm`.

    Args:
        keys: key(s) to drop
        *more_keys: additional keys to drop, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with the specified key(s) removed from
        varm.

    Raises:
        ValueError: if any of the keys is not a key of `varm`.
    """
    keys = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    keys += more_keys
    current = self._varm
    for key in keys:
        if key in current:
            continue
        error_message = \
            f'tried to drop {key!r}, which is not a key of varm'
        raise ValueError(error_message)
    remaining = {name: value for name, value in current.items()
                 if name not in keys}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=remaining, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def drop_obsp(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with `keys` and `more_keys` removed
    from `obsp`.

    Args:
        keys: key(s) to drop
        *more_keys: additional keys to drop, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with the specified key(s) removed from
        obsp.

    Raises:
        ValueError: if any of the keys is not a key of `obsp`.
    """
    keys = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    keys += more_keys
    current = self._obsp
    for key in keys:
        if key in current:
            continue
        error_message = \
            f'tried to drop {key!r}, which is not a key of obsp'
        raise ValueError(error_message)
    remaining = {name: value for name, value in current.items()
                 if name not in keys}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=remaining,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def drop_varp(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with `keys` and `more_keys` removed
    from `varp`.

    Args:
        keys: key(s) to drop
        *more_keys: additional keys to drop, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with the specified key(s) removed from
        varp.

    Raises:
        ValueError: if any of the keys is not a key of `varp`.
    """
    keys = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    keys += more_keys
    current = self._varp
    for key in keys:
        if key in current:
            continue
        error_message = \
            f'tried to drop {key!r}, which is not a key of varp'
        raise ValueError(error_message)
    remaining = {name: value for name, value in current.items()
                 if name not in keys}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=remaining, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def drop_uns(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with `keys` and `more_keys` removed
    from `uns`.

    Args:
        keys: key(s) to drop
        *more_keys: additional keys to drop, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with the specified key(s) removed from
        uns.

    Raises:
        ValueError: if any of the keys is not a key of `uns`.
    """
    keys = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    keys += more_keys
    current = self._uns
    for key in keys:
        if key in current:
            continue
        error_message = \
            f'tried to drop {key!r}, which is not a key of uns'
        raise ValueError(error_message)
    remaining = {name: value for name, value in current.items()
                 if name not in keys}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=remaining,
                      num_threads=self._num_threads)
[docs]
def rename_obs(self,
               mapping: dict[str, str] | Callable[[str], str],
               /) -> SingleCell:
    """
    Create a new SingleCell dataset with column(s) of `obs` renamed.

    The renaming itself is delegated to `obs.rename(mapping)`.

    Args:
        mapping: the renaming to apply, either as a dictionary with the old
                 names as keys and the new names as values, or a function
                 that takes an old name and returns a new name

    Returns:
        A new SingleCell dataset with the column(s) of `obs` renamed.
    """
    renamed_obs = self._obs.rename(mapping)
    return SingleCell(X=self._X, obs=renamed_obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def rename_var(self,
               mapping: dict[str, str] | Callable[[str], str],
               /) -> SingleCell:
    """
    Create a new SingleCell dataset with column(s) of `var` renamed.

    The renaming itself is delegated to `var.rename(mapping)`.

    Args:
        mapping: the renaming to apply, either as a dictionary with the old
                 names as keys and the new names as values, or a function
                 that takes an old name and returns a new name

    Returns:
        A new SingleCell dataset with the column(s) of `var` renamed.
    """
    renamed_var = self._var.rename(mapping)
    return SingleCell(X=self._X, obs=self._obs, var=renamed_var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def rename_obsm(self,
                mapping: dict[str, str] | Callable[[str], str],
                /) -> SingleCell:
    """
    Create a new SingleCell dataset with key(s) of `obsm` renamed.

    Args:
        mapping: the renaming to apply, either as a dictionary with the old
                 names as keys and the new names as values, or a function
                 that takes an old name and returns a new name. A function
                 is applied to every key of `obsm`.

    Returns:
        A new SingleCell dataset with the key(s) of `obsm` renamed.

    Raises:
        TypeError: if `mapping` is neither a dictionary nor a function, or
                   if a function `mapping` returns a non-string
        ValueError: if an old name in a dictionary `mapping` is not a key
                    of `obsm`, or if any new name is already a key of
                    `obsm` (this includes a function returning a key
                    unchanged)
    """
    if isinstance(mapping, dict):
        # Keys and values can only be validated for a dictionary; doing so
        # before the type dispatch would crash on a function
        check_types(mapping.keys(), 'mapping.keys()', str, 'strings')
        check_types(mapping.values(), 'mapping.values()', str, 'strings')
        for key, new_key in mapping.items():
            if key not in self._obsm:
                error_message = \
                    f'tried to rename {key!r}, which is not a key of obsm'
                raise ValueError(error_message)
            if new_key in self._obsm:
                error_message = (
                    f'tried to rename obsm[{key!r}] to obsm[{new_key!r}], '
                    f'but obsm[{new_key!r}] already exists')
                raise ValueError(error_message)
        # Keys absent from the mapping keep their current name
        obsm = {mapping.get(key, key): value
                for key, value in self._obsm.items()}
    elif callable(mapping):
        obsm = {}
        for key, value in self._obsm.items():
            new_key = mapping(key)
            if not isinstance(new_key, str):
                error_message = (
                    f'tried to rename obsm[{key!r}] to a non-string value '
                    f'of type {type(new_key).__name__!r}')
                raise TypeError(error_message)
            if new_key in self._obsm:
                error_message = (
                    f'tried to rename obsm[{key!r}] to obsm[{new_key!r}], '
                    f'but obsm[{new_key!r}] already exists')
                raise ValueError(error_message)
            obsm[new_key] = value
    else:
        error_message = (
            f'mapping must be a dictionary or function, but has type '
            f'{type(mapping).__name__!r}')
        raise TypeError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var, obsm=obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
[docs]
def rename_varm(self,
                mapping: dict[str, str] | Callable[[str], str],
                /) -> SingleCell:
    """
    Create a new SingleCell dataset with key(s) of `varm` renamed.

    Args:
        mapping: the renaming to apply, either as a dictionary with the old
                 names as keys and the new names as values, or a function
                 that takes an old name and returns a new name. A function
                 is applied to every key of `varm`.

    Returns:
        A new SingleCell dataset with the key(s) of `varm` renamed.

    Raises:
        TypeError: if `mapping` is neither a dictionary nor a function, or
                   if a function `mapping` returns a non-string
        ValueError: if an old name in a dictionary `mapping` is not a key
                    of `varm`, or if any new name is already a key of
                    `varm` (this includes a function returning a key
                    unchanged)
    """
    if isinstance(mapping, dict):
        # Keys and values can only be validated for a dictionary; doing so
        # before the type dispatch would crash on a function
        check_types(mapping.keys(), 'mapping.keys()', str, 'strings')
        check_types(mapping.values(), 'mapping.values()', str, 'strings')
        for key, new_key in mapping.items():
            if key not in self._varm:
                error_message = \
                    f'tried to rename {key!r}, which is not a key of varm'
                raise ValueError(error_message)
            if new_key in self._varm:
                error_message = (
                    f'tried to rename varm[{key!r}] to varm[{new_key!r}], '
                    f'but varm[{new_key!r}] already exists')
                raise ValueError(error_message)
        # Keys absent from the mapping keep their current name
        varm = {mapping.get(key, key): value
                for key, value in self._varm.items()}
    elif callable(mapping):
        varm = {}
        for key, value in self._varm.items():
            new_key = mapping(key)
            if not isinstance(new_key, str):
                error_message = (
                    f'tried to rename varm[{key!r}] to a non-string value '
                    f'of type {type(new_key).__name__!r}')
                raise TypeError(error_message)
            if new_key in self._varm:
                error_message = (
                    f'tried to rename varm[{key!r}] to varm[{new_key!r}], '
                    f'but varm[{new_key!r}] already exists')
                raise ValueError(error_message)
            varm[new_key] = value
    else:
        error_message = (
            f'mapping must be a dictionary or function, but has type '
            f'{type(mapping).__name__!r}')
        raise TypeError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def rename_obsp(self,
                mapping: dict[str, str] | Callable[[str], str],
                /) -> SingleCell:
    """
    Create a new SingleCell dataset with key(s) of `obsp` renamed.

    Args:
        mapping: the renaming to apply, either as a dictionary with the old
                 names as keys and the new names as values, or a function
                 that takes an old name and returns a new name. A function
                 is applied to every key of `obsp`.

    Returns:
        A new SingleCell dataset with the key(s) of `obsp` renamed.

    Raises:
        TypeError: if `mapping` is neither a dictionary nor a function, or
                   if a function `mapping` returns a non-string
        ValueError: if an old name in a dictionary `mapping` is not a key
                    of `obsp`, or if any new name is already a key of
                    `obsp` (this includes a function returning a key
                    unchanged)
    """
    if isinstance(mapping, dict):
        # Keys and values can only be validated for a dictionary; doing so
        # before the type dispatch would crash on a function
        check_types(mapping.keys(), 'mapping.keys()', str, 'strings')
        check_types(mapping.values(), 'mapping.values()', str, 'strings')
        for key, new_key in mapping.items():
            if key not in self._obsp:
                error_message = \
                    f'tried to rename {key!r}, which is not a key of obsp'
                raise ValueError(error_message)
            if new_key in self._obsp:
                error_message = (
                    f'tried to rename obsp[{key!r}] to obsp[{new_key!r}], '
                    f'but obsp[{new_key!r}] already exists')
                raise ValueError(error_message)
        # Keys absent from the mapping keep their current name
        obsp = {mapping.get(key, key): value
                for key, value in self._obsp.items()}
    elif callable(mapping):
        obsp = {}
        for key, value in self._obsp.items():
            new_key = mapping(key)
            if not isinstance(new_key, str):
                error_message = (
                    f'tried to rename obsp[{key!r}] to a non-string value '
                    f'of type {type(new_key).__name__!r}')
                raise TypeError(error_message)
            if new_key in self._obsp:
                error_message = (
                    f'tried to rename obsp[{key!r}] to obsp[{new_key!r}], '
                    f'but obsp[{new_key!r}] already exists')
                raise ValueError(error_message)
            obsp[new_key] = value
    else:
        error_message = (
            f'mapping must be a dictionary or function, but has type '
            f'{type(mapping).__name__!r}')
        raise TypeError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def rename_varp(self,
                mapping: dict[str, str] | Callable[[str], str],
                /) -> SingleCell:
    """
    Create a new SingleCell dataset with key(s) of `varp` renamed.

    Args:
        mapping: the renaming to apply, either as a dictionary with the old
                 names as keys and the new names as values, or a function
                 that takes an old name and returns a new name. A function
                 is applied to every key of `varp`.

    Returns:
        A new SingleCell dataset with the key(s) of `varp` renamed.

    Raises:
        TypeError: if `mapping` is neither a dictionary nor a function, or
                   if a function `mapping` returns a non-string
        ValueError: if an old name in a dictionary `mapping` is not a key
                    of `varp`, or if any new name is already a key of
                    `varp` (this includes a function returning a key
                    unchanged)
    """
    if isinstance(mapping, dict):
        # Keys and values can only be validated for a dictionary; doing so
        # before the type dispatch would crash on a function
        check_types(mapping.keys(), 'mapping.keys()', str, 'strings')
        check_types(mapping.values(), 'mapping.values()', str, 'strings')
        for key, new_key in mapping.items():
            if key not in self._varp:
                error_message = \
                    f'tried to rename {key!r}, which is not a key of varp'
                raise ValueError(error_message)
            if new_key in self._varp:
                error_message = (
                    f'tried to rename varp[{key!r}] to varp[{new_key!r}], '
                    f'but varp[{new_key!r}] already exists')
                raise ValueError(error_message)
        # Keys absent from the mapping keep their current name
        varp = {mapping.get(key, key): value
                for key, value in self._varp.items()}
    elif callable(mapping):
        varp = {}
        for key, value in self._varp.items():
            new_key = mapping(key)
            if not isinstance(new_key, str):
                error_message = (
                    f'tried to rename varp[{key!r}] to a non-string value '
                    f'of type {type(new_key).__name__!r}')
                raise TypeError(error_message)
            if new_key in self._varp:
                error_message = (
                    f'tried to rename varp[{key!r}] to varp[{new_key!r}], '
                    f'but varp[{new_key!r}] already exists')
                raise ValueError(error_message)
            varp[new_key] = value
    else:
        error_message = (
            f'mapping must be a dictionary or function, but has type '
            f'{type(mapping).__name__!r}')
        raise TypeError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def rename_uns(self,
               mapping: dict[str, str] | Callable[[str], str],
               /) -> SingleCell:
    """
    Create a new SingleCell dataset with key(s) of `uns` renamed.

    Args:
        mapping: the renaming to apply, either as a dictionary with the old
                 names as keys and the new names as values, or a function
                 that takes an old name and returns a new name. A function
                 is applied to every key of `uns`.

    Returns:
        A new SingleCell dataset with the key(s) of `uns` renamed.

    Raises:
        TypeError: if `mapping` is neither a dictionary nor a function, or
                   if a function `mapping` returns a non-string
        ValueError: if an old name in a dictionary `mapping` is not a key
                    of `uns`, or if any new name is already a key of `uns`
                    (this includes a function returning a key unchanged)
    """
    if isinstance(mapping, dict):
        # Keys and values can only be validated for a dictionary; doing so
        # before the type dispatch would crash on a function
        check_types(mapping.keys(), 'mapping.keys()', str, 'strings')
        check_types(mapping.values(), 'mapping.values()', str, 'strings')
        for key, new_key in mapping.items():
            if key not in self._uns:
                error_message = \
                    f'tried to rename {key!r}, which is not a key of uns'
                raise ValueError(error_message)
            if new_key in self._uns:
                error_message = (
                    f'tried to rename uns[{key!r}] to uns[{new_key!r}], '
                    f'but uns[{new_key!r}] already exists')
                raise ValueError(error_message)
        # Keys absent from the mapping keep their current name
        uns = {mapping.get(key, key): value
               for key, value in self._uns.items()}
    elif callable(mapping):
        uns = {}
        for key, value in self._uns.items():
            new_key = mapping(key)
            if not isinstance(new_key, str):
                error_message = (
                    f'tried to rename uns[{key!r}] to a non-string value '
                    f'of type {type(new_key).__name__!r}')
                raise TypeError(error_message)
            if new_key in self._uns:
                error_message = (
                    f'tried to rename uns[{key!r}] to uns[{new_key!r}], '
                    f'but uns[{new_key!r}] already exists')
                raise ValueError(error_message)
            uns[new_key] = value
    else:
        error_message = (
            f'mapping must be a dictionary or function, but has type '
            f'{type(mapping).__name__!r}')
        raise TypeError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=uns,
                      num_threads=self._num_threads)
[docs]
def cast_X(self, dtype: np._typing.DTypeLike, /) -> SingleCell:
    """
    Convert `X` to the specified data type via `X.astype(dtype)`.

    Args:
        dtype: a NumPy data type

    Returns:
        A new SingleCell dataset with `X` cast to the specified data type.

    Raises:
        ValueError: if `X` is `None`.
    """
    X = self._X
    # There is nothing to cast when the dataset carries no `X`
    if X is None:
        error_message = 'X is None, so casting it is not possible'
        raise ValueError(error_message)
    converted = X.astype(dtype)
    return SingleCell(X=converted, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def cast_obs(self,
             dtypes: Mapping[pl.type_aliases.ColumnNameOrSelector |
                             pl.type_aliases.PolarsDataType,
                             pl.type_aliases.PolarsDataType] |
                     pl.type_aliases.PolarsDataType,
             /,
             *,
             strict: bool = True) -> SingleCell:
    """
    Convert column(s) of `obs` to the specified data type(s) via
    `obs.cast(dtypes, strict=strict)`.

    Args:
        dtypes: a mapping of column names (or selectors) to data types, or
                a single data type to which all columns will be cast
        strict: whether to raise an error if a cast could not be performed
                (for instance, due to numerical overflow)

    Returns:
        A new SingleCell dataset with column(s) of `obs` cast to the
        specified data type(s).
    """
    converted_obs = self._obs.cast(dtypes, strict=strict)
    return SingleCell(X=self._X, obs=converted_obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def cast_var(self,
             dtypes: Mapping[pl.type_aliases.ColumnNameOrSelector |
                             pl.type_aliases.PolarsDataType,
                             pl.type_aliases.PolarsDataType] |
                     pl.type_aliases.PolarsDataType,
             /,
             *,
             strict: bool = True) -> SingleCell:
    """
    Convert column(s) of `var` to the specified data type(s) via
    `var.cast(dtypes, strict=strict)`.

    Args:
        dtypes: a mapping of column names (or selectors) to data types, or
                a single data type to which all columns will be cast
        strict: whether to raise an error if a cast could not be performed
                (for instance, due to numerical overflow)

    Returns:
        A new SingleCell dataset with column(s) of `var` cast to the
        specified data type(s).
    """
    converted_var = self._var.cast(dtypes, strict=strict)
    return SingleCell(X=self._X, obs=self._obs, var=converted_var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def join_obs(self,
             other: pl.DataFrame,
             /,
             *,
             on: str | pl.Expr | Sequence[str | pl.Expr] | None = None,
             left_on: str | pl.Expr | Sequence[str | pl.Expr] |
                      None = None,
             right_on: str | pl.Expr | Sequence[str | pl.Expr] |
                       None = None,
             suffix: str = '_right',
             validate: Literal['m:m', 'm:1', '1:m', '1:1'] = 'm:m',
             nulls_equal: bool = False,
             coalesce: bool = True) -> SingleCell:
    """
    Left-join `obs` with another DataFrame, using the same logic as
    [`polars.DataFrame.join()`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join.html).

    Args:
        other: a polars DataFrame to join `obs` with
        on: the name(s) of the join column(s) in both DataFrames
        left_on: the name(s) of the join column(s) in `obs`
        right_on: the name(s) of the join column(s) in `other`
        suffix: a suffix to append to columns with a duplicate name
        validate: checks whether the join is of the specified type. Can be:
                  - 'm:m' (many-to-many): the default, no checks performed.
                  - '1:1' (one-to-one): check that none of the values in
                    the join column(s) appear more than once in `obs` or
                    more than once in `other`.
                  - '1:m' (one-to-many): check that none of the values in
                    the join column(s) appear more than once in `obs`.
                  - 'm:1' (many-to-one): check that none of the values in
                    the join column(s) appear more than once in `other`.
        nulls_equal: whether to include `null` as a valid value to join on.
                     By default, `null` values will never produce matches.
        coalesce: if `True`, coalesce each of the pairs of join columns
                  (the columns in `on` or `left_on`/`right_on`) from `obs`
                  and `other` into a single column, filling missing values
                  from one with the corresponding values from the other.
                  If `False`, include both as separate columns, adding
                  `suffix` to the join columns from `other`.

    Returns:
        A new SingleCell dataset with the columns from `other` joined to
        obs.

    Raises:
        ValueError: if `on`, `left_on` and `right_on` are specified in an
                    inconsistent combination, or if the join would add
                    rows to `obs` because `other` has duplicate join keys.
        TypeError: if a pair of join columns has mismatched data types
                   that are not both Enum/Categorical.

    Note:
        If a column of `on`, `left_on` or `right_on` is Enum in `obs` and
        Categorical in `other` (or vice versa), or Enum in both but with
        different categories in each, that pair of columns will be
        automatically cast to a common Enum data type (with the union of
        the categories) before joining.
    """
    check_type(other, 'other', pl.DataFrame, 'a polars DataFrame')
    left = self._obs
    right = other
    # Validate the on/left_on/right_on combination and pull out the join
    # columns from each side so their data types can be compared pairwise
    if on is None:
        if left_on is None and right_on is None:
            error_message = (
                "either 'on' or both of 'left_on' and 'right_on' must be "
                "specified")
            raise ValueError(error_message)
        elif left_on is None:
            error_message = \
                'right_on is specified, so left_on must be specified'
            raise ValueError(error_message)
        elif right_on is None:
            error_message = \
                'left_on is specified, so right_on must be specified'
            raise ValueError(error_message)
        left_columns = left.select(left_on)
        right_columns = right.select(right_on)
    else:
        if left_on is not None:
            error_message = "'on' is specified, so 'left_on' must be None"
            raise ValueError(error_message)
        if right_on is not None:
            error_message = "'on' is specified, so 'right_on' must be None"
            raise ValueError(error_message)
        left_columns = left.select(on)
        right_columns = right.select(on)
    # Reconcile Enum/Categorical mismatches by casting both sides to an
    # Enum over the union of their categories; any other mismatch is an
    # error
    left_cast_dict = {}
    right_cast_dict = {}
    for left_column, right_column in zip(left_columns, right_columns):
        left_dtype = left_column.dtype
        right_dtype = right_column.dtype
        if left_dtype == right_dtype:
            continue
        if (left_dtype == pl.Enum or left_dtype == pl.Categorical) and (
                right_dtype == pl.Enum or right_dtype == pl.Categorical):
            common_dtype = \
                pl.Enum(pl.concat([left_column.cat.get_categories(),
                                   right_column.cat.get_categories()])
                        .unique(maintain_order=True))
            left_cast_dict[left_column.name] = common_dtype
            right_cast_dict[right_column.name] = common_dtype
        else:
            error_message = (
                f'obs[{left_column.name!r}] has data type '
                f'{left_dtype.base_type()!r}, but '
                f'other[{right_column.name!r}] has data type '
                f'{right_dtype.base_type()!r}')
            raise TypeError(error_message)
    # Only cast when at least one pair of columns needs it (the dicts are
    # always populated together)
    if left_cast_dict:
        left = left.cast(left_cast_dict)
        right = right.cast(right_cast_dict)
    obs = left.join(right, on=on, how='left', left_on=left_on,
                    right_on=right_on, suffix=suffix, validate=validate,
                    nulls_equal=nulls_equal, coalesce=coalesce,
                    maintain_order='left')
    # A left join can only add rows when `other` has duplicate join keys;
    # compare against `obs` itself so the check does not depend on how
    # `len(self)` is defined
    if len(obs) > len(self._obs):
        other_on = to_tuple(right_on if right_on is not None else on)
        assert other.select(other_on).is_duplicated().any()
        # `select()` rather than `other[column]`, so that expression keys
        # are supported as well as column names
        duplicate_column = other_on[0] if len(other_on) == 1 else \
            next(column for column in other_on
                 if other.select(column).to_series()
                 .is_duplicated().any())
        error_message = (
            f'other[{duplicate_column!r}] contains duplicate values, so '
            f'it must be deduplicated before being joined on')
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=obs, var=self._var, obsm=self._obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
[docs]
def join_var(self,
             other: pl.DataFrame,
             /,
             *,
             on: str | pl.Expr | Sequence[str | pl.Expr] | None = None,
             left_on: str | pl.Expr | Sequence[str | pl.Expr] |
                      None = None,
             right_on: str | pl.Expr | Sequence[str | pl.Expr] |
                       None = None,
             suffix: str = '_right',
             validate: Literal['m:m', 'm:1', '1:m', '1:1'] = 'm:m',
             nulls_equal: bool = False,
             coalesce: bool = True) -> SingleCell:
    """
    Left-join `var` with another DataFrame, using the same logic as
    [`polars.DataFrame.join()`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join.html).

    Args:
        other: a polars DataFrame to join `var` with
        on: the name(s) of the join column(s) in both DataFrames
        left_on: the name(s) of the join column(s) in `var`
        right_on: the name(s) of the join column(s) in `other`
        suffix: a suffix to append to columns with a duplicate name
        validate: checks whether the join is of the specified type. Can be:
                  - 'm:m' (many-to-many): the default, no checks performed.
                  - '1:1' (one-to-one): check that none of the values in
                    the join column(s) appear more than once in `var` or
                    more than once in `other`.
                  - '1:m' (one-to-many): check that none of the values in
                    the join column(s) appear more than once in `var`.
                  - 'm:1' (many-to-one): check that none of the values in
                    the join column(s) appear more than once in `other`.
        nulls_equal: whether to include `null` as a valid value to join on.
                     By default, `null` values will never produce matches.
        coalesce: if `True`, coalesce each of the pairs of join columns
                  (the columns in `on` or `left_on`/`right_on`) from `var`
                  and `other` into a single column, filling missing values
                  from one with the corresponding values from the other.
                  If `False`, include both as separate columns, adding
                  `suffix` to the join columns from `other`.

    Returns:
        A new SingleCell dataset with the columns from `other` joined to
        var.

    Raises:
        ValueError: if `on`, `left_on` and `right_on` are specified in an
                    inconsistent combination, or if the join would add
                    rows to `var` because `other` has duplicate join keys.
        TypeError: if a pair of join columns has mismatched data types
                   that are not both Enum/Categorical.

    Note:
        If a column of `on`, `left_on` or `right_on` is Enum in `var` and
        Categorical in `other` (or vice versa), or Enum in both but with
        different categories in each, that pair of columns will be
        automatically cast to a common Enum data type (with the union of
        the categories) before joining.
    """
    check_type(other, 'other', pl.DataFrame, 'a polars DataFrame')
    left = self._var
    right = other
    # Validate the on/left_on/right_on combination and pull out the join
    # columns from each side so their data types can be compared pairwise
    if on is None:
        if left_on is None and right_on is None:
            error_message = (
                "either 'on' or both of 'left_on' and 'right_on' must be "
                "specified")
            raise ValueError(error_message)
        elif left_on is None:
            error_message = \
                'right_on is specified, so left_on must be specified'
            raise ValueError(error_message)
        elif right_on is None:
            error_message = \
                'left_on is specified, so right_on must be specified'
            raise ValueError(error_message)
        left_columns = left.select(left_on)
        right_columns = right.select(right_on)
    else:
        if left_on is not None:
            error_message = "'on' is specified, so 'left_on' must be None"
            raise ValueError(error_message)
        if right_on is not None:
            error_message = "'on' is specified, so 'right_on' must be None"
            raise ValueError(error_message)
        left_columns = left.select(on)
        right_columns = right.select(on)
    # Reconcile Enum/Categorical mismatches by casting both sides to an
    # Enum over the union of their categories; any other mismatch is an
    # error
    left_cast_dict = {}
    right_cast_dict = {}
    for left_column, right_column in zip(left_columns, right_columns):
        left_dtype = left_column.dtype
        right_dtype = right_column.dtype
        if left_dtype == right_dtype:
            continue
        if (left_dtype == pl.Enum or left_dtype == pl.Categorical) and (
                right_dtype == pl.Enum or right_dtype == pl.Categorical):
            common_dtype = \
                pl.Enum(pl.concat([left_column.cat.get_categories(),
                                   right_column.cat.get_categories()])
                        .unique(maintain_order=True))
            left_cast_dict[left_column.name] = common_dtype
            right_cast_dict[right_column.name] = common_dtype
        else:
            error_message = (
                f'var[{left_column.name!r}] has data type '
                f'{left_dtype.base_type()!r}, but '
                f'other[{right_column.name!r}] has data type '
                f'{right_dtype.base_type()!r}')
            raise TypeError(error_message)
    # Only cast when at least one pair of columns needs it (the dicts are
    # always populated together)
    if left_cast_dict:
        left = left.cast(left_cast_dict)
        right = right.cast(right_cast_dict)
    var = left.join(right, on=on, how='left', left_on=left_on,
                    right_on=right_on, suffix=suffix, validate=validate,
                    nulls_equal=nulls_equal, coalesce=coalesce,
                    maintain_order='left')
    # A left join can only add rows when `other` has duplicate join keys.
    # The baseline must be the number of rows of `var` (genes), not
    # `len(self)`.
    if len(var) > len(self._var):
        other_on = to_tuple(right_on if right_on is not None else on)
        assert other.select(other_on).is_duplicated().any()
        # `select()` rather than `other[column]`, so that expression keys
        # are supported as well as column names
        duplicate_column = other_on[0] if len(other_on) == 1 else \
            next(column for column in other_on
                 if other.select(column).to_series()
                 .is_duplicated().any())
        error_message = (
            f'other[{duplicate_column!r}] contains duplicate values, so '
            f'it must be deduplicated before being joined on')
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=var, obsm=self._obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
[docs]
def peek_obs(self, *, row: int = 0) -> None:
    """
    Print one row of `obs`, showing each column on a separate line.

    Args:
        row: the index of the row of `obs` to print; defaults to the first
             row
    """
    check_type(row, 'row', int, 'an integer')
    # Reshape the single selected row into long format, so that each
    # column of `obs` becomes one printed line
    long_format = self._obs[row].unpivot(variable_name='column')
    # tbl_rows=-1 removes the limit on the number of printed rows
    with pl.Config(tbl_rows=-1):
        print(long_format)
[docs]
def peek_var(self, *, row: int = 0) -> None:
    """
    Print one row of `var`, showing each column on a separate line.

    Args:
        row: the index of the row of `var` to print; defaults to the first
             row
    """
    check_type(row, 'row', int, 'an integer')
    # Reshape the single selected row into long format, so that each
    # column of `var` becomes one printed line
    long_format = self._var[row].unpivot(variable_name='column')
    # tbl_rows=-1 removes the limit on the number of printed rows
    with pl.Config(tbl_rows=-1):
        print(long_format)
[docs]
def subsample_obs(self,
                  *,
                  n: int | np.integer | None = None,
                  fraction: int | float | np.integer | np.floating |
                            None = None,
                  QC_column: SingleCellColumn | None = 'passed_QC',
                  by_column: SingleCellColumn | None = None,
                  subsample_column: str | None = None,
                  seed: int | np.integer = 0,
                  overwrite: bool = False,
                  num_threads: int | np.integer | None = None) -> \
        SingleCell:
    """
    Subsample a specific number or fraction of cells.

    Args:
        n: the number of cells to return; mutually exclusive with
           `fraction`
        fraction: the fraction of cells to return; mutually exclusive with
                  `n`
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will not be
                   selected when subsampling, and will not count towards
                   the denominator of `fraction`; QC_column will not appear
                   in the returned SingleCell dataset, since it would be
                   redundant.
        by_column: an optional String, Enum, Categorical, or integer column
                   of `obs` to subsample by. Can be a column name, a
                   polars expression, a polars Series, a 1D NumPy array, or
                   a function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Specifying
                   `by_column` ensures that the same fraction of cells with
                   each value of `by_column` are subsampled. When combined
                   with `n`, to make sure the total number of samples is
                   exactly `n`, some of the smallest groups may be
                   oversampled by one element, or some of the largest
                   groups may be undersampled by one element. Can contain
                   `null` entries: the corresponding cells will not be
                   included in the result.
        subsample_column: an optional name of a Boolean column to add to
                          obs indicating the subsampled cells; if `None`,
                          subset to these cells instead
        seed: the random seed to use when subsampling
        overwrite: if `True`, overwrite `subsample_column` if already
                   present in `obs`, instead of raising an error. Must be
                   `False` when `subsample_column` is `None`.
        num_threads: the number of threads to use when subsetting `X` to
                     the sampled cells. Set `num_threads=-1` to use all
                     available cores, as determined by `os.cpu_count()`. By
                     default (`num_threads=None`), use `self.num_threads`
                     cores. Does not affect the subsampled SingleCell
                     dataset's `num_threads`; this will always be the same
                     as the original dataset's `num_threads`. Can only be
                     specified when `subsample_column` is `None`.

    Returns:
        A new SingleCell dataset subset to the subsampled cells, or if
        `subsample_column` is specified, the full dataset with
        `subsample_column` added to `obs`. If `QC_column` is a string and
        a QC column exists in the original dataset, it will be removed from
        the subsampled dataset, since all subsampled cells pass QC and it
        would be redundant.

    Raises:
        ValueError: if the arguments are specified in an inconsistent
                    combination (both or neither of `n` and `fraction`,
                    `overwrite` or `num_threads` combined with an
                    incompatible `subsample_column`, or a
                    `subsample_column` that already exists without
                    `overwrite=True`).
    """
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    if subsample_column is not None:
        check_type(subsample_column, 'subsample_column', str, 'a string')
        if not overwrite and subsample_column in self._obs:
            error_message = (
                f'subsample_column {subsample_column!r} is already a '
                f'column of obs; did you already run subsample_obs()? Set '
                f'overwrite=True to overwrite.')
            raise ValueError(error_message)
    elif overwrite:
        error_message = \
            'overwrite must be False when subsample_column is None'
        raise ValueError(error_message)
    if n is not None and fraction is not None:
        error_message = 'only one of n and fraction can be specified'
        raise ValueError(error_message)
    if n is not None:
        check_type(n, 'n', int, 'a positive integer')
        check_bounds(n, 'n', 1)
    elif fraction is not None:
        check_type(fraction, 'fraction', float,
                   'a floating-point number between 0 and 1')
        check_bounds(fraction, 'fraction', 0, 1, left_open=True,
                     right_open=True)
    else:
        error_message = 'either n or fraction must be specified'
        raise ValueError(error_message)
    # Subsampling subsets to QCed cells by definition, so we should drop
    # the QC column from the subsampled dataset at the end - but only if
    # it was passed as an explicit column name.
    if QC_column is not None:
        QC_column_is_string = isinstance(QC_column, str)
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')
    check_type(seed, 'seed', int, 'an integer')
    if subsample_column is None:
        num_threads = self._process_num_threads(num_threads)
    else:
        if num_threads is not None:
            error_message = (
                'num_threads can only be specified when subsample_column '
                'is None')
            raise ValueError(error_message)
    # Get the indices of cells passing QC, used both for the index-based
    # fast path and to get `N` cheaply.
    if QC_column is not None:
        QC_indices = np.flatnonzero(QC_column.to_numpy())
        N = len(QC_indices)
    else:
        N = len(self._obs)
    # When there is a `QC_column`, the boolean mask and indices we compute
    # below live in "QCed space" (`[0, number of cells passing QC)`); we
    # back-project to the full dataset afterwards
    if by_column is not None:
        # Grouped: use polars' `shuffle().over()`
        by_column = self._get_column(
            'obs', by_column, 'by_column',
            (pl.String, pl.Enum, pl.Categorical, 'integer'),
            QC_column=QC_column, allow_null=True)
        if QC_column is not None:
            by_column = by_column.filter(QC_column)
        by_frame = by_column.to_frame()
        by_name = by_column.name
        if n is not None:
            # Get a vector of the number of elements to sample per group.
            # The total sample size should exactly match the original n; if
            # necessary, oversample the smallest groups or undersample the
            # largest groups to make this happen.
            group_counts = by_frame\
                .group_by(by_name)\
                .agg(pl.len(), n=(n / len(by_column) * pl.len())
                     .round().cast(pl.Int32))\
                .drop_nulls(by_name)
            diff = n - group_counts['n'].sum()
            if diff != 0:
                group_counts = group_counts\
                    .sort('len', descending=diff < 0)\
                    .with_columns(n=pl.col.n +
                                  pl.int_range(pl.len(), dtype=pl.Int32)
                                  .lt(abs(diff)).cast(pl.Int32) *
                                  pl.lit(diff).sign())
            # The resulting mask is used positionally, so it must have one
            # entry per row of `by_frame`, in the same order. Use an
            # order-preserving left join (an inner join would drop rows
            # with a null `by_column` and does not guarantee row order),
            # and treat rows without a per-group `n` (null `by_column`) as
            # not selected.
            selected = by_frame\
                .join(group_counts, on=by_name, how='left',
                      maintain_order='left')\
                .select(pl.int_range(pl.len(), dtype=pl.Int32)
                        .shuffle(seed=seed)
                        .over(by_name)
                        .lt(pl.col.n)
                        .fill_null(False))\
                .to_series()
        else:
            selected = by_frame\
                .select(pl.int_range(pl.len(), dtype=pl.Int32)
                        .shuffle(seed=seed)
                        .over(by_name)
                        .lt((fraction * pl.len().over(by_name)).round()))\
                .to_series()
            # Rows with a null `by_column` would otherwise be sampled as a
            # group of their own; mask them out so they are never
            # selected, as documented
            selected = selected & by_column.is_not_null()
        if subsample_column is None:
            # Get the (sorted) indices of the subsampled cells
            indices = np.flatnonzero(selected.to_numpy())
            if QC_column is not None:
                # Back-project to the full dataset
                indices = QC_indices[indices]
        elif QC_column is not None:
            # Back-project to the full dataset: `cum_sum() - QC_column` is
            # each QC-passing cell's position in QCed space; cells failing
            # QC get null since there is no `otherwise()`
            selected = pl.when(QC_column)\
                .then(selected.gather(QC_column.cum_sum() - QC_column))
    else:
        # Ungrouped: `np.random.choice()`
        if n is None:
            n = int(round(fraction * N))
        indices = \
            np.random.default_rng(seed).choice(N, size=n, replace=False)
        if subsample_column is None:
            # Sort to preserve the original ordering of obs
            indices.sort()
            if QC_column is not None:
                # Back-project to the full dataset
                indices = QC_indices[indices]
        else:
            selected = np.zeros(N, dtype=bool)
            selected[indices] = True
            selected = pl.Series(selected)
            if QC_column is not None:
                # Back-project to the full dataset (see the grouped case)
                selected = pl.when(QC_column)\
                    .then(selected.gather(QC_column.cum_sum() - QC_column))
    if subsample_column is None:
        # Temporarily override `X`'s thread count for the subsetting
        # operation, restoring it even if subsetting fails
        original_num_threads = self._X._num_threads
        try:
            self._X._num_threads = num_threads
            sc = self[indices]
            sc._X.num_threads = sc._num_threads  # i.e. self._num_threads
        finally:
            self._X._num_threads = original_num_threads
    else:
        sc = self.with_columns_obs(selected.alias(subsample_column))
    # Drop the now-redundant QC column, since all subsampled cells pass QC
    # NOTE(review): this also runs when `subsample_column` is specified
    # and the dataset is not subset - confirm this is intended
    if QC_column is not None and QC_column_is_string:
        sc._obs = sc._obs.drop(QC_column.name)
    return sc
[docs]
def subsample_var(self,
                  *,
                  n: int | np.integer | None = None,
                  fraction: int | float | np.integer | np.floating |
                            None = None,
                  by_column: SingleCellColumn | None = None,
                  subsample_column: str | None = None,
                  seed: int | np.integer = 0,
                  overwrite: bool = False,
                  num_threads: int | np.integer | None = None) -> \
        SingleCell:
    """
    Subsample a specific number or fraction of genes.

    Args:
        n: the number of genes to return; mutually exclusive with
           `fraction`
        fraction: the fraction of genes to return; mutually exclusive with
                  `n`
        by_column: an optional String, Enum, Categorical, or integer column
                   of `var` to subsample by. Can be a column name, a
                   polars expression, a polars Series, a 1D NumPy array, or
                   a function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Specifying
                   `by_column` ensures that the same fraction of genes with
                   each value of `by_column` are subsampled. When combined
                   with `n`, to make sure the total number of samples is
                   exactly `n`, some of the smallest groups may be
                   oversampled by one element, or some of the largest
                   groups may be undersampled by one element. Can contain
                   `null` entries: the corresponding genes will not be
                   included in the result.
        subsample_column: an optional name of a Boolean column to add to
                          var indicating the subsampled genes; if `None`,
                          subset to these genes instead
        seed: the random seed to use when subsampling
        overwrite: if `True`, overwrite `subsample_column` if already
                   present in `var`, instead of raising an error. Must be
                   `False` when `subsample_column` is `None`.
        num_threads: the number of threads to use when subsetting `X` to
                     the sampled genes. Set `num_threads=-1` to use all
                     available cores, as determined by `os.cpu_count()`. By
                     default (`num_threads=None`), use `self.num_threads`
                     cores. Does not affect the subsampled SingleCell
                     dataset's `num_threads`; this will always be the same
                     as the original dataset's `num_threads`. Can only be
                     specified when `subsample_column` is `None`.

    Returns:
        A new SingleCell dataset subset to the subsampled genes, or if
        `subsample_column` is specified, the full dataset with
        `subsample_column` added to `var`.

    Raises:
        ValueError: if the arguments are specified in an inconsistent
                    combination (both or neither of `n` and `fraction`,
                    `overwrite` or `num_threads` combined with an
                    incompatible `subsample_column`, or a
                    `subsample_column` that already exists without
                    `overwrite=True`).
    """
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    if subsample_column is not None:
        check_type(subsample_column, 'subsample_column', str, 'a string')
        if not overwrite and subsample_column in self._var:
            error_message = (
                f'subsample_column {subsample_column!r} is already a '
                f'column of var; did you already run subsample_var()? Set '
                f'overwrite=True to overwrite.')
            raise ValueError(error_message)
    elif overwrite:
        error_message = \
            'overwrite must be False when subsample_column is None'
        raise ValueError(error_message)
    if n is not None and fraction is not None:
        error_message = 'only one of n and fraction can be specified'
        raise ValueError(error_message)
    if n is not None:
        check_type(n, 'n', int, 'a positive integer')
        check_bounds(n, 'n', 1)
    elif fraction is not None:
        check_type(fraction, 'fraction', float,
                   'a floating-point number between 0 and 1')
        check_bounds(fraction, 'fraction', 0, 1, left_open=True,
                     right_open=True)
    else:
        error_message = 'either n or fraction must be specified'
        raise ValueError(error_message)
    check_type(seed, 'seed', int, 'an integer')
    if subsample_column is None:
        num_threads = self._process_num_threads(num_threads)
    else:
        if num_threads is not None:
            error_message = (
                'num_threads can only be specified when subsample_column '
                'is None')
            raise ValueError(error_message)
    if by_column is not None:
        # Grouped: use polars' `shuffle().over()`
        by_column = self._get_column(
            'var', by_column, 'by_column',
            (pl.String, pl.Enum, pl.Categorical, 'integer'),
            allow_null=True)
        by_frame = by_column.to_frame()
        by_name = by_column.name
        if n is not None:
            # Get a vector of the number of elements to sample per group.
            # The total sample size should exactly match the original n; if
            # necessary, oversample the smallest groups or undersample the
            # largest groups to make this happen.
            group_counts = by_frame\
                .group_by(by_name)\
                .agg(pl.len(), n=(n / len(by_column) * pl.len())
                     .round().cast(pl.Int32))\
                .drop_nulls(by_name)
            diff = n - group_counts['n'].sum()
            if diff != 0:
                group_counts = group_counts\
                    .sort('len', descending=diff < 0)\
                    .with_columns(n=pl.col.n +
                                  pl.int_range(pl.len(), dtype=pl.Int32)
                                  .lt(abs(diff)).cast(pl.Int32) *
                                  pl.lit(diff).sign())
            # The resulting mask is used positionally, so it must have one
            # entry per row of `by_frame`, in the same order. Use an
            # order-preserving left join (an inner join would drop rows
            # with a null `by_column` and does not guarantee row order),
            # and treat rows without a per-group `n` (null `by_column`) as
            # not selected.
            selected = by_frame\
                .join(group_counts, on=by_name, how='left',
                      maintain_order='left')\
                .select(pl.int_range(pl.len(), dtype=pl.Int32)
                        .shuffle(seed=seed)
                        .over(by_name)
                        .lt(pl.col.n)
                        .fill_null(False))\
                .to_series()
        else:
            selected = by_frame\
                .select(pl.int_range(pl.len(), dtype=pl.Int32)
                        .shuffle(seed=seed)
                        .over(by_name)
                        .lt((fraction * pl.len().over(by_name)).round()))\
                .to_series()
            # Rows with a null `by_column` would otherwise be sampled as a
            # group of their own; mask them out so they are never
            # selected, as documented
            selected = selected & by_column.is_not_null()
        if subsample_column is None:
            # Get the (sorted) indices of the subsampled genes
            indices = np.flatnonzero(selected.to_numpy())
    else:
        # Ungrouped: `np.random.choice()`
        N = len(self._var)
        if n is None:
            n = int(round(fraction * N))
        indices = \
            np.random.default_rng(seed).choice(N, size=n, replace=False)
        if subsample_column is None:
            # Sort to preserve the original ordering of var
            indices.sort()
        else:
            selected = np.zeros(N, dtype=bool)
            selected[indices] = True
            selected = pl.Series(selected)
    if subsample_column is None:
        # Temporarily override `X`'s thread count for the subsetting
        # operation, restoring it even if subsetting fails
        original_num_threads = self._X._num_threads
        try:
            self._X._num_threads = num_threads
            sc = self[:, indices]
            sc._X.num_threads = sc._num_threads  # i.e. self._num_threads
        finally:
            self._X._num_threads = original_num_threads
    else:
        sc = self.with_columns_var(selected.alias(subsample_column))
    # Return the result; without this the method would always return None
    return sc
[docs]
def pipe(self,
         function: Callable[[SingleCell, ...], Any],
         /,
         *args: Any,
         **kwargs: Any) -> Any:
    """
    Call a function with this SingleCell dataset as its first argument.

    `sc.pipe(func)` gives the same result as `func(sc)`, and
    `sc.pipe(func, 1, a=2)` gives the same result as `func(sc, 1, a=2)`.

    Args:
        function: the callable to invoke. Its first parameter receives this
                  SingleCell dataset; any further parameters after the
                  dataset are filled from `args` and `kwargs`. It may
                  return any value.
        *args: extra positional arguments forwarded to `function`
        **kwargs: extra keyword arguments forwarded to `function`

    Returns:
        Whatever `function` returns.

    Raises:
        TypeError: if `function` is not callable.
    """
    if callable(function):
        return function(self, *args, **kwargs)
    # Not callable: report the offending type
    type_name = type(function).__name__
    raise TypeError(f'function is not callable; it has type {type_name}')
[docs]
def pipe_X(self,
           function: Callable[[csr_array | csc_array, ...],
                              csr_array | csc_array],
           /,
           *args: Any,
           **kwargs: Any) -> SingleCell:
    """
    Build a new SingleCell dataset whose `X` is `function` applied to the
    current `X`.

    `sc = sc.pipe_X(func)` corresponds to `sc.X = func(sc.X)`, and
    `sc = sc.pipe_X(func, 1, a=2)` corresponds to
    `sc.X = func(sc.X, 1, a=2)`.

    Args:
        function: the callable to invoke. Its first parameter receives the
                  current `X` and it must return the replacement `X`; any
                  further parameters are filled from `args` and `kwargs`.
        *args: extra positional arguments forwarded to `function`
        **kwargs: extra keyword arguments forwarded to `function`

    Returns:
        A new SingleCell dataset with the transformed `X`; every other
        field is passed through from this dataset.

    Raises:
        ValueError: if `X` is `None`.
        TypeError: if `function` is not callable.
    """
    # There must be an `X` to transform
    if self._X is None:
        raise ValueError('X is None, so piping it is not possible')
    if not callable(function):
        type_name = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {type_name}')
    new_X = function(self._X, *args, **kwargs)
    return SingleCell(X=new_X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def pipe_obs(self,
             function: Callable[[pl.DataFrame, ...], pl.DataFrame],
             /,
             *args: Any,
             **kwargs: Any) -> SingleCell:
    """
    Build a new SingleCell dataset whose `obs` is `function` applied to
    the current `obs`.

    `sc = sc.pipe_obs(func)` corresponds to `sc.obs = func(sc.obs)`, and
    `sc = sc.pipe_obs(func, 1, a=2)` corresponds to
    `sc.obs = func(sc.obs, 1, a=2)`.

    Args:
        function: the callable to invoke. Its first parameter receives the
                  current `obs` and it must return the replacement `obs`;
                  any further parameters are filled from `args` and
                  `kwargs`.
        *args: extra positional arguments forwarded to `function`
        **kwargs: extra keyword arguments forwarded to `function`

    Returns:
        A new SingleCell dataset with the transformed `obs`; every other
        field is passed through from this dataset.

    Raises:
        TypeError: if `function` is not callable.
    """
    if not callable(function):
        type_name = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {type_name}')
    new_obs = function(self._obs, *args, **kwargs)
    return SingleCell(X=self._X, obs=new_obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def pipe_var(self,
             function: Callable[[pl.DataFrame, ...], pl.DataFrame],
             /,
             *args: Any,
             **kwargs: Any) -> SingleCell:
    """
    Build a new SingleCell dataset whose `var` is `function` applied to
    the current `var`.

    `sc = sc.pipe_var(func)` corresponds to `sc.var = func(sc.var)`, and
    `sc = sc.pipe_var(func, 1, a=2)` corresponds to
    `sc.var = func(sc.var, 1, a=2)`.

    Args:
        function: the callable to invoke. Its first parameter receives the
                  current `var` and it must return the replacement `var`;
                  any further parameters are filled from `args` and
                  `kwargs`.
        *args: extra positional arguments forwarded to `function`
        **kwargs: extra keyword arguments forwarded to `function`

    Returns:
        A new SingleCell dataset with the transformed `var`; every other
        field is passed through from this dataset.

    Raises:
        TypeError: if `function` is not callable.
    """
    if not callable(function):
        type_name = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {type_name}')
    new_var = function(self._var, *args, **kwargs)
    return SingleCell(X=self._X, obs=self._obs, var=new_var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def pipe_obsm(self,
              function: Callable[[dict[str, np.ndarray |
                                            pl.DataFrame], ...],
                                 dict[str, np.ndarray |
                                           pl.DataFrame]],
              /,
              *args: Any,
              **kwargs: Any) -> SingleCell:
    """
    Build a new SingleCell dataset whose `obsm` is `function` applied to
    the current `obsm`.

    `sc = sc.pipe_obsm(func)` corresponds to `sc.obsm = func(sc.obsm)`,
    and `sc = sc.pipe_obsm(func, 1, a=2)` corresponds to
    `sc.obsm = func(sc.obsm, 1, a=2)`.

    This transforms `obsm` as a whole; to transform a single key of
    `obsm`, use `pipe_obsm_key()` instead.

    Args:
        function: the callable to invoke. Its first parameter receives the
                  current `obsm` and it must return the replacement
                  `obsm`; any further parameters are filled from `args`
                  and `kwargs`.
        *args: extra positional arguments forwarded to `function`
        **kwargs: extra keyword arguments forwarded to `function`

    Returns:
        A new SingleCell dataset with the transformed `obsm`; every other
        field is passed through from this dataset.

    Raises:
        TypeError: if `function` is not callable.
    """
    if not callable(function):
        type_name = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {type_name}')
    new_obsm = function(self._obsm, *args, **kwargs)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=new_obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def pipe_obsm_key(self,
                  key: str,
                  function: Callable[[np.ndarray |
                                      pl.DataFrame, ...],
                                     np.ndarray | pl.DataFrame],
                  /,
                  *args: Any,
                  **kwargs: Any) -> SingleCell:
    """
    Build a new SingleCell dataset where one entry of `obsm` is replaced
    by `function` applied to that entry.

    `sc = sc.pipe_obsm_key(key, func)` corresponds to
    `sc.obsm[key] = func(sc.obsm[key])`, and
    `sc = sc.pipe_obsm_key(key, func, 1, a=2)` corresponds to
    `sc.obsm[key] = func(sc.obsm[key], 1, a=2)`.

    This transforms a single key; to transform `obsm` as a whole, use
    `pipe_obsm()` instead.

    Args:
        key: the existing key of `obsm` whose value will be transformed
        function: the callable to invoke. Its first parameter receives the
                  current `obsm[key]` and it must return the replacement
                  `obsm[key]`; any further parameters are filled from
                  `args` and `kwargs`.
        *args: extra positional arguments forwarded to `function`
        **kwargs: extra keyword arguments forwarded to `function`

    Returns:
        A new SingleCell dataset with the transformed `obsm[key]`; every
        other field is passed through from this dataset.

    Raises:
        ValueError: if `key` is not a key in `obsm`.
        TypeError: if `function` is not callable.
    """
    # The key must be a string that is already present in `obsm`
    check_type(key, 'key', str, 'a string')
    if key not in self._obsm:
        raise ValueError(f'{key!r} is not a key in obsm')
    if not callable(function):
        type_name = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {type_name}')
    # Merge the single transformed entry into a copy of `obsm`
    replacement = {key: function(self._obsm[key], *args, **kwargs)}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm | replacement, varm=self._varm,
                      obsp=self._obsp, varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def pipe_varm(self,
              function: Callable[[dict[str, np.ndarray |
                                            pl.DataFrame], ...],
                                 dict[str, np.ndarray |
                                           pl.DataFrame]],
              /,
              *args: Any,
              **kwargs: Any) -> SingleCell:
    """
    Build a new SingleCell dataset whose `varm` is `function` applied to
    the current `varm`.

    `sc = sc.pipe_varm(func)` corresponds to `sc.varm = func(sc.varm)`,
    and `sc = sc.pipe_varm(func, 1, a=2)` corresponds to
    `sc.varm = func(sc.varm, 1, a=2)`.

    This transforms `varm` as a whole; to transform a single key of
    `varm`, use `pipe_varm_key()` instead.

    Args:
        function: the callable to invoke. Its first parameter receives the
                  current `varm` and it must return the replacement
                  `varm`; any further parameters are filled from `args`
                  and `kwargs`.
        *args: extra positional arguments forwarded to `function`
        **kwargs: extra keyword arguments forwarded to `function`

    Returns:
        A new SingleCell dataset with the transformed `varm`; every other
        field is passed through from this dataset.

    Raises:
        TypeError: if `function` is not callable.
    """
    if not callable(function):
        type_name = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {type_name}')
    new_varm = function(self._varm, *args, **kwargs)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=new_varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def pipe_varm_key(self,
                  key: str,
                  function: Callable[[np.ndarray |
                                      pl.DataFrame, ...],
                                     np.ndarray | pl.DataFrame],
                  /,
                  *args: Any,
                  **kwargs: Any) -> SingleCell:
    """
    Build a new SingleCell dataset where one entry of `varm` is replaced
    by `function` applied to that entry.

    `sc = sc.pipe_varm_key(key, func)` corresponds to
    `sc.varm[key] = func(sc.varm[key])`, and
    `sc = sc.pipe_varm_key(key, func, 1, a=2)` corresponds to
    `sc.varm[key] = func(sc.varm[key], 1, a=2)`.

    This transforms a single key; to transform `varm` as a whole, use
    `pipe_varm()` instead.

    Args:
        key: the existing key of `varm` whose value will be transformed
        function: the callable to invoke. Its first parameter receives the
                  current `varm[key]` and it must return the replacement
                  `varm[key]`; any further parameters are filled from
                  `args` and `kwargs`.
        *args: extra positional arguments forwarded to `function`
        **kwargs: extra keyword arguments forwarded to `function`

    Returns:
        A new SingleCell dataset with the transformed `varm[key]`; every
        other field is passed through from this dataset.

    Raises:
        ValueError: if `key` is not a key in `varm`.
        TypeError: if `function` is not callable.
    """
    # The key must be a string that is already present in `varm`
    check_type(key, 'key', str, 'a string')
    if key not in self._varm:
        raise ValueError(f'{key!r} is not a key in varm')
    if not callable(function):
        type_name = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {type_name}')
    # Merge the single transformed entry into a copy of `varm`
    replacement = {key: function(self._varm[key], *args, **kwargs)}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm | replacement,
                      obsp=self._obsp, varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def pipe_obsp(self,
              function: Callable[[dict[str, csr_array | csc_array], ...],
                                 dict[str, csr_array | csc_array]],
              /,
              *args: Any,
              **kwargs: Any) -> SingleCell:
    """
    Build a new SingleCell dataset whose `obsp` is the result of calling
    `function` on the current `obsp`.

    `sc = sc.pipe_obsp(func)` is equivalent to `sc.obsp = func(sc.obsp)`,
    and `sc = sc.pipe_obsp(func, 1, a=2)` is equivalent to
    `sc.obsp = func(sc.obsp, 1, a=2)`.

    Use `pipe_obsp_key()` instead to transform a single key of `obsp`
    rather than `obsp` as a whole.

    Args:
        function: a callable that receives the old `obsp` as its first
                  argument and returns the new `obsp`. Any further
                  arguments it needs are supplied through `args` and
                  `kwargs`.
        *args: extra positional arguments forwarded to `function`
        **kwargs: extra keyword arguments forwarded to `function`

    Returns:
        A new SingleCell dataset in which `obsp` has been replaced by the
        output of `function`; all other fields are passed through as-is.

    Raises:
        TypeError: if `function` is not callable.
    """
    if callable(function):
        # Only `obsp` changes; every other field is handed to the
        # constructor untouched
        fields = dict(
            X=self._X, obs=self._obs, var=self._var,
            obsm=self._obsm, varm=self._varm,
            obsp=function(self._obsp, *args, **kwargs),
            varp=self._varp, uns=self._uns,
            num_threads=self._num_threads)
        return SingleCell(**fields)
    type_name = type(function).__name__
    raise TypeError(f'function is not callable; it has type {type_name}')
[docs]
def pipe_obsp_key(self,
                  key: str,
                  function: Callable[[csr_array | csc_array, ...],
                                     csr_array | csc_array],
                  /,
                  *args: Any,
                  **kwargs: Any) -> SingleCell:
    """
    Apply a function to a specific key in a SingleCell dataset's `obsp`.

    `sc = sc.pipe_obsp_key(key, func)` is equivalent to
    `sc.obsp[key] = func(sc.obsp[key])`.

    `sc = sc.pipe_obsp_key(key, func, 1, a=2)` is equivalent to
    `sc.obsp[key] = func(sc.obsp[key], 1, a=2)`.

    To apply a function to `obsp` as a whole, rather than to a specific key
    of `obsp`, use `pipe_obsp()`.

    Args:
        key: the key in `obsp` to which the function will be applied.
             Positional-only, and must come before `function`.
        function: the function to apply to `obsp[key]`. It must take the
                  old `obsp[key]` as its first argument and return the new
                  `obsp[key]`. The function may also take other arguments
                  after `obsp[key]`, which can be specified via `args` and
                  `kwargs`.
        *args: the positional arguments to the function
        **kwargs: the keyword arguments to the function

    Returns:
        A new SingleCell dataset where the function has been applied to
        `obsp[key]`.

    Raises:
        ValueError: if `key` is not a key in `obsp`.
        TypeError: if `function` is not callable.
    """
    # Check that `key` is a key in `obsp`
    check_type(key, 'key', str, 'a string')
    if key not in self._obsp:
        error_message = f'{key!r} is not a key in obsp'
        raise ValueError(error_message)
    # Check that `function` is callable
    if not callable(function):
        error_message = (
            f'function is not callable; it has type '
            f'{type(function).__name__}')
        raise TypeError(error_message)
    # Merge the transformed entry into `obsp` with `|` rather than assigning
    # into `self._obsp`
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm,
                      obsp=self._obsp | {
                          key: function(self._obsp[key], *args, **kwargs)},
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def pipe_varp(self,
              function: Callable[[dict[str, csr_array | csc_array], ...],
                                 dict[str, csr_array | csc_array]],
              /,
              *args: Any,
              **kwargs: Any) -> SingleCell:
    """
    Build a new SingleCell dataset whose `varp` is the result of calling
    `function` on the current `varp`.

    `sc = sc.pipe_varp(func)` is equivalent to `sc.varp = func(sc.varp)`,
    and `sc = sc.pipe_varp(func, 1, a=2)` is equivalent to
    `sc.varp = func(sc.varp, 1, a=2)`.

    Use `pipe_varp_key()` instead to transform a single key of `varp`
    rather than `varp` as a whole.

    Args:
        function: a callable that receives the old `varp` as its first
                  argument and returns the new `varp`. Any further
                  arguments it needs are supplied through `args` and
                  `kwargs`.
        *args: extra positional arguments forwarded to `function`
        **kwargs: extra keyword arguments forwarded to `function`

    Returns:
        A new SingleCell dataset in which `varp` has been replaced by the
        output of `function`; all other fields are passed through as-is.

    Raises:
        TypeError: if `function` is not callable.
    """
    if callable(function):
        # Only `varp` changes; every other field is handed to the
        # constructor untouched
        fields = dict(
            X=self._X, obs=self._obs, var=self._var,
            obsm=self._obsm, varm=self._varm, obsp=self._obsp,
            varp=function(self._varp, *args, **kwargs),
            uns=self._uns, num_threads=self._num_threads)
        return SingleCell(**fields)
    type_name = type(function).__name__
    raise TypeError(f'function is not callable; it has type {type_name}')
[docs]
def pipe_varp_key(self,
                  key: str,
                  function: Callable[[csr_array | csc_array, ...],
                                     csr_array | csc_array],
                  /,
                  *args: Any,
                  **kwargs: Any) -> SingleCell:
    """
    Apply a function to a specific key in a SingleCell dataset's `varp`.

    `sc = sc.pipe_varp_key(key, func)` is equivalent to
    `sc.varp[key] = func(sc.varp[key])`.

    `sc = sc.pipe_varp_key(key, func, 1, a=2)` is equivalent to
    `sc.varp[key] = func(sc.varp[key], 1, a=2)`.

    To apply a function to `varp` as a whole, rather than to a specific key
    of `varp`, use `pipe_varp()`.

    Args:
        key: the key in `varp` to which the function will be applied.
             Positional-only, and must come before `function`.
        function: the function to apply to `varp[key]`. It must take the
                  old `varp[key]` as its first argument and return the new
                  `varp[key]`. The function may also take other arguments
                  after `varp[key]`, which can be specified via `args` and
                  `kwargs`.
        *args: the positional arguments to the function
        **kwargs: the keyword arguments to the function

    Returns:
        A new SingleCell dataset where the function has been applied to
        `varp[key]`.

    Raises:
        ValueError: if `key` is not a key in `varp`.
        TypeError: if `function` is not callable.
    """
    # Check that `key` is a key in `varp`
    check_type(key, 'key', str, 'a string')
    if key not in self._varp:
        error_message = f'{key!r} is not a key in varp'
        raise ValueError(error_message)
    # Check that `function` is callable
    if not callable(function):
        error_message = (
            f'function is not callable; it has type '
            f'{type(function).__name__}')
        raise TypeError(error_message)
    # Merge the transformed entry into `varp` with `|` rather than assigning
    # into `self._varp`
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp | {
                          key: function(self._varp[key], *args, **kwargs)},
                      uns=self._uns, num_threads=self._num_threads)
[docs]
def pipe_uns(self,
             function: Callable[[UnsDict, ...], UnsDict],
             /,
             *args: Any,
             **kwargs: Any) -> SingleCell:
    """
    Build a new SingleCell dataset whose `uns` is the result of calling
    `function` on the current `uns`.

    `sc = sc.pipe_uns(func)` is equivalent to `sc.uns = func(sc.uns)`, and
    `sc = sc.pipe_uns(func, 1, a=2)` is equivalent to
    `sc.uns = func(sc.uns, 1, a=2)`.

    Use `pipe_uns_key()` instead to transform a single key of `uns` rather
    than `uns` as a whole.

    Args:
        function: a callable that receives the old `uns` as its first
                  argument and returns the new `uns`. Any further
                  arguments it needs are supplied through `args` and
                  `kwargs`.
        *args: extra positional arguments forwarded to `function`
        **kwargs: extra keyword arguments forwarded to `function`

    Returns:
        A new SingleCell dataset in which `uns` has been replaced by the
        output of `function`; all other fields are passed through as-is.

    Raises:
        TypeError: if `function` is not callable.
    """
    if callable(function):
        # Only `uns` changes; every other field is handed to the
        # constructor untouched
        fields = dict(
            X=self._X, obs=self._obs, var=self._var,
            obsm=self._obsm, varm=self._varm, obsp=self._obsp,
            varp=self._varp,
            uns=function(self._uns, *args, **kwargs),
            num_threads=self._num_threads)
        return SingleCell(**fields)
    type_name = type(function).__name__
    raise TypeError(f'function is not callable; it has type {type_name}')
[docs]
def pipe_uns_key(self,
                 key: str,
                 function: Callable[[UnsItem, ...], UnsItem],
                 /,
                 *args: Any,
                 **kwargs: Any) -> SingleCell:
    """
    Apply a function to a specific key in a SingleCell dataset's `uns`.

    `sc = sc.pipe_uns_key(key, func)` is equivalent to
    `sc.uns[key] = func(sc.uns[key])`.

    `sc = sc.pipe_uns_key(key, func, 1, a=2)` is equivalent to
    `sc.uns[key] = func(sc.uns[key], 1, a=2)`.

    To apply a function to `uns` as a whole, rather than to a specific key
    of `uns`, use `pipe_uns()`.

    Args:
        key: the key in `uns` to which the function will be applied.
             Positional-only, and must come before `function`.
        function: the function to apply to `uns[key]`. It must take the
                  old `uns[key]` as its first argument and return the new
                  `uns[key]`. The function may also take other arguments
                  after `uns[key]`, which can be specified via `args` and
                  `kwargs`.
        *args: the positional arguments to the function
        **kwargs: the keyword arguments to the function

    Returns:
        A new SingleCell dataset where the function has been applied to
        `uns[key]`.

    Raises:
        ValueError: if `key` is not a key in `uns`.
        TypeError: if `function` is not callable.
    """
    # Check that `key` is a key in `uns`
    check_type(key, 'key', str, 'a string')
    if key not in self._uns:
        error_message = f'{key!r} is not a key in uns'
        raise ValueError(error_message)
    # Check that `function` is callable
    if not callable(function):
        error_message = (
            f'function is not callable; it has type '
            f'{type(function).__name__}')
        raise TypeError(error_message)
    # Merge the transformed entry into `uns` with `|` rather than assigning
    # into `self._uns`
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm,
                      obsp=self._obsp, varp=self._varp,
                      uns=self._uns | {
                          key: function(self._uns[key], *args, **kwargs)},
                      num_threads=self._num_threads)
[docs]
def qc_metrics(self,
               *,
               num_counts_column: str = 'num_counts',
               num_genes_column: str = 'num_genes',
               mito_fraction_column: str = 'mito_fraction',
               allow_float: bool = False,
               overwrite: bool = False,
               num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Adds quality-control metrics to `obs` for each cell: the sum of counts
    across all genes (`num_counts`), the number of genes with non-zero
    expression (`num_genes`), and the fraction of counts that are
    mitochondrial (`mito_fraction`).

    This function is intended to be run before `qc()` for users interested
    in better understanding the quality of their dataset. It is not a
    required step, since `qc()` calculates its own filters internally.

    Args:
        num_counts_column: the name of an integer column to be added to
                           `obs` containing each cell's sum of counts
                           across all genes
        num_genes_column: the name of an integer column to be added to
                          `obs` containing each cell's number of genes
                          with non-zero expression
        mito_fraction_column: the name of a floating-point column to be
                              added to `obs` containing each cell's
                              fraction of counts that are mitochondrial
                              (i.e. from genes whose names start with
                              `'MT-'`, case-insensitively)
        allow_float: if `False`, raise an error if `self.X.dtype` is
                     floating-point (suggesting the user may not be using
                     the raw counts); if `True`, disable this sanity check.
        overwrite: if `False`, raise an error if any of the new columns
                   already exist in `obs`; if `True`, overwrite them.
        num_threads: the number of threads to use when calculating the
                     quality-control metrics. Set `num_threads=-1` to use
                     all available cores, as determined by
                     `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores. Does not affect the
                     resulting SingleCell dataset's `num_threads`; this
                     will always be the same as the original dataset's
                     `num_threads`.

    Returns:
        A new SingleCell dataset with the three metrics added as columns of
        `obs`.

    Raises:
        ValueError: if any of the three columns already exists in `obs`
                    and `overwrite=False`, if `X` is `None`, or if
                    `obs_names` or `var_names` contain duplicates.
        TypeError: if `X` is floating-point and `allow_float=False`.

    Note:
        This function will give an incorrect output when run on normalized
        data, since floating-point counts will be truncated to integers.

    Note:
        This function may give an incorrect output if the count matrix
        contains explicit zeros (i.e. if `(sc.X.data == 0).any()`): this is
        not checked for, due to speed considerations. In the unlikely event
        that your dataset contains explicit zeros, remove them by running
        `sc.X.eliminate_zeros()` (an in-place operation) first.
    """
    # Check that `overwrite` is Boolean and that each column name is a
    # string that is not already a column of `obs` (unless overwriting)
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    for column_name, column in (
            ('num_counts_column', num_counts_column),
            ('num_genes_column', num_genes_column),
            ('mito_fraction_column', mito_fraction_column)):
        check_type(column, column_name, str, 'a string')
        if not overwrite and column in self._obs:
            error_message = (
                f'{column_name} {column!r} is already a column of obs; '
                f'did you already run qc_metrics()? Set overwrite=True to '
                f'overwrite.')
            raise ValueError(error_message)
    check_type(allow_float, 'allow_float', bool, 'Boolean')
    num_threads = self._process_num_threads(num_threads)
    # Check that `X` is present
    X = self._X
    if X is None:
        error_message = \
            'X is None, so calculating QC metrics is not possible'
        raise ValueError(error_message)
    # If `allow_float=False`, raise an error if `X` is floating-point
    if not allow_float and np.issubdtype(X.dtype, np.floating):
        error_message = (
            f'qc_metrics() requires raw counts but X has data type '
            f'{str(X.dtype)!r}, a floating-point data type. If you are '
            f'sure that all values are raw integer counts, i.e. that '
            f'(X.data == X.data.astype(int)).all(), then set '
            f'allow_float=True.')
        raise TypeError(error_message)
    # Check that `obs_names` and `var_names` are unique. The number of
    # duplicates is the total minus the number of unique names (so that
    # the reported count is positive).
    num_unique = self.obs_names.n_unique()
    if num_unique < len(self._obs):
        error_message = (
            f'obs_names contains {len(self._obs) - num_unique:,} '
            f'duplicates; deduplicate with make_obs_names_unique()')
        raise ValueError(error_message)
    num_unique = self.var_names.n_unique()
    if num_unique < len(self._var):
        error_message = (
            f'var_names contains {len(self._var) - num_unique:,} '
            f'duplicates; deduplicate with make_var_names_unique()')
        raise ValueError(error_message)
    # Allocate the per-cell output buffers, which are filled in place below
    num_cells = X.shape[0]
    num_counts = np.empty(num_cells, dtype=np.uint32)
    num_genes = np.empty(num_cells, dtype=np.uint32)
    mito_fraction = np.empty(num_cells, dtype=np.float32)
    # Flag mitochondrial genes: names starting with 'MT-', ignoring case
    var_names = self.var_names
    if var_names.dtype != pl.String:
        var_names = var_names.cast(pl.String)
    mt_genes = var_names.str.to_uppercase().str.starts_with('MT-')
    # Compute the metrics with the routine matching the storage format of
    # `X`; both take the same arguments
    compute_metrics = \
        qc_metrics_csr if isinstance(X, csr_array) else qc_metrics_csc
    compute_metrics(data=X.data, indices=X.indices, indptr=X.indptr,
                    mt_genes=mt_genes.to_numpy(), num_counts=num_counts,
                    num_genes_per_cell=num_genes,
                    mito_fraction=mito_fraction,
                    num_threads=num_threads)
    return self.with_columns_obs(
        pl.Series(num_counts_column, num_counts),
        pl.Series(num_genes_column, num_genes),
        pl.Series(mito_fraction_column, mito_fraction))
[docs]
def qc(self,
       *,
       custom_filter: SingleCellColumn | None = None,
       subset: bool = False,
       QC_column: str = 'passed_QC',
       max_mito_fraction: int | float | np.integer | np.floating |
                          None = 0.05,
       min_genes: int | np.integer | None = 100,
       nonzero_MALAT1: bool = True,
       remove_doublets: bool = False,
       batch_column: SingleCellColumn | None = None,
       doublet_fraction: float | np.floating | None = None,
       num_doublet_genes: int | np.integer = 500,
       allow_float: bool = False,
       overwrite: bool = False,
       verbose: bool = True,
       num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Adds a Boolean column to `obs` indicating which cells passed quality
    control (QC), or subsets to these cells if `subset=True`.

    By default, filters to cells with ≤5% mitochondrial reads, ≥100 genes
    detected, and non-zero MALAT1 or Malat1 expression. Can also filter out
    doublets when `remove_doublets=True`.

    Raises an error if any cell names appear more than once in `obs_names`
    (they can be deduplicated with `make_obs_names_unique()`) or any gene
    names appear more than once in `var_names` (they can be deduplicated
    with `make_var_names_unique()`).

    Args:
        custom_filter: an optional Boolean column of `obs` containing a
                       filter to apply on top of the other QC filters;
                       `True` elements will be kept. Can be a column name,
                       a polars expression, a polars Series, a 1D NumPy
                       array, or a function that takes in this SingleCell
                       dataset and returns a polars Series or 1D NumPy
                       array.
        subset: whether to subset to cells passing QC, instead of merely
                adding a `QC_column` to `obs`. This will roughly double
                memory usage, but speed up subsequent operations.
        QC_column: the name of a Boolean column to add to `obs` indicating
                   which cells passed QC, if `subset=False`. Gives an error
                   if `obs` already has a column with this name, unless
                   `overwrite=True`.
        max_mito_fraction: if not `None`, filter to cells with <= this
                           fraction of mitochondrial counts (i.e. from
                           genes whose names start with `'MT-'`,
                           case-insensitively). The default of 5% matches
                           Seurat's recommended value.
        min_genes: if not `None`, filter to cells with >= this many genes
                   detected (with non-zero count). The default of 100
                   matches Scanpy's recommended value, while Seurat
                   recommends a minimum of 200.
        nonzero_MALAT1: if `True`, filter out cells with 0 expression of
                        the nuclear-expressed lncRNA MALAT1, which
                        [likely represent](https://biorxiv.org/content/10.1101/2024.07.14.603469v1)
                        empty droplets or poor-quality cells. There must be
                        exactly one gene in `var_names` with the name
                        `'MALAT1'` or `'Malat1'` to use this filter.
        remove_doublets: if `True`, remove predicted doublets (see
                         `find_doublets()`). Doublet detection uses the
                         cxds algorithm to score each cell, then thresholds
                         this continuous score to a binary one (doublet
                         versus non-doublet) using a threshold derived from
                         simulated doublets.
        batch_column: an optional String, Enum, Categorical, or integer
                      column of `obs` indicating which batch each cell is
                      from. Can be a column name, a polars expression, a
                      polars Series, a 1D NumPy array, or a function that
                      takes in this SingleCell dataset and returns a polars
                      Series or 1D NumPy array. Only used during doublet
                      detection; doublet detection will be performed
                      separately for each batch; cells where `batch_column`
                      is `null` will collectively be treated as a single
                      batch. Set to `None` if all cells belong to the same
                      sequencing batch. Can only be specified when
                      `remove_doublets=True`.
        doublet_fraction: an optional fraction of cells (within each batch,
                          if `batch_column` is specified) to be classified
                          as doublets. If `None`, automatically detect the
                          threshold via the approach described in
                          `find_doublets()`.
        num_doublet_genes: the number of highly variable genes, i.e. genes
                           expressed in as close to 50% of cells as
                           possible, to use during doublet detection. This
                           parameter usually has a minimal influence on
                           accuracy as long as it is sufficiently large (in
                           the hundreds), so increasing it further will
                           mainly just increase runtime. If
                           `num_doublet_genes` is greater than the number
                           of genes in the dataset, all genes will be used.
        allow_float: if `False`, raise an error if `self.X.dtype` is
                     floating-point (suggesting the user may not be using
                     the raw counts); if `True`, disable this sanity check.
                     Note that all steps except mitochondrial percent
                     filtering give the same result on normalized counts,
                     so if `max_mito_fraction=None` were specified (not
                     recommended), this function would give the same result
                     on raw and normalized counts.
        overwrite: if `False`, raise an error if `uns['QCed']` is `True`
                   (indicating the dataset has already been QCed) or
                   `QC_column` is already present in `obs`; if `True`,
                   disable these two sanity checks and, when
                   `subset=False`, overwrite `QC_column` if present.
        verbose: whether to print how many cells were filtered out at each
                 step of the QC process
        num_threads: the number of threads to use when filtering based on
                     mitochondrial counts and MALAT1 expression, and for
                     doublet detection. Set `num_threads=-1` to use all
                     available cores, as determined by `os.cpu_count()`,
                     or leave unset to use `self.num_threads` cores. Does
                     not affect the QCed SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     original dataset's `num_threads`.

    Returns:
        A new SingleCell dataset with `QC_column` added to `obs` (or subset
        to QCed cells if `subset=True`) and `uns['QCed']` set to `True`.

    Raises:
        ValueError: if the dataset was already QCed or `QC_column` already
                    exists (and `overwrite=False`), if `X` is `None`, if
                    `QC_column` is customized while `subset=True`, if
                    `batch_column` is given while `remove_doublets=False`,
                    if `obs_names` or `var_names` contain duplicates, if
                    no mitochondrial genes or no unique MALAT1 gene can be
                    found when the corresponding filter is enabled, if no
                    cells remain after the mitochondrial or gene-count
                    filter, or if no QC filters were specified at all.
        TypeError: if `X` is floating-point and `allow_float=False`.
    """
    # Check that `QC_column` is a string
    check_type(QC_column, 'QC_column', str, 'a string')
    # Check that `overwrite` is Boolean
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    # If `overwrite=False`, check that `uns['QCed']=False` and that
    # `QC_column` is not present in `obs`
    if not overwrite:
        if self._uns['QCed']:
            error_message = (
                "uns['QCed'] is True; did you already run qc()? Specify "
                "overwrite=True, set uns['QCed'] = False, or run "
                "with_uns(QCed=False) to bypass this check.")
            raise ValueError(error_message)
        if QC_column in self._obs:
            error_message = (
                f'QC_column {QC_column!r} is already a column of obs; did '
                f'you already run qc()? Set overwrite=True to overwrite.')
            raise ValueError(error_message)
    # Check that `X` is present
    X = self._X
    if X is None:
        error_message = 'X is None, so QCing is not possible'
        raise ValueError(error_message)
    # Get the `custom_filter`, if specified
    if custom_filter is not None:
        custom_filter = self._get_column(
            'obs', custom_filter, 'custom_filter', pl.Boolean)
    # Check that `subset` is Boolean; if `subset=True`, check that
    # `QC_column` has the default value of `'passed_QC'`
    check_type(subset, 'subset', bool, 'Boolean')
    if subset and QC_column != 'passed_QC':
        error_message = 'QC_column can only be specified when subset=False'
        raise ValueError(error_message)
    # If `max_mito_fraction` was specified, check that it is a number
    # between 0 and 1, inclusive
    if max_mito_fraction is not None:
        check_type(max_mito_fraction, 'max_mito_fraction', (int, float),
                   'a number between 0 and 1, inclusive')
        check_bounds(max_mito_fraction, 'max_mito_fraction', 0, 1)
    # If `min_genes` was specified, check that it is a non-negative integer
    if min_genes is not None:
        check_type(min_genes, 'min_genes', int, 'a non-negative integer')
        check_bounds(min_genes, 'min_genes', 0)
    # Check that `nonzero_MALAT1` and `remove_doublets` are Boolean
    check_type(nonzero_MALAT1, 'nonzero_MALAT1', bool, 'Boolean')
    check_type(remove_doublets, 'remove_doublets', bool, 'Boolean')
    # If `batch_column` was specified, get it after checking that
    # `remove_doublets=True`
    if batch_column is not None:
        if not remove_doublets:
            error_message = (
                'batch_column must be None when remove_doublets is False, '
                'since its only use within qc() is for doublet detection')
            raise ValueError(error_message)
        batch_column = self._get_column(
            'obs', batch_column, 'batch_column',
            (pl.String, pl.Enum, pl.Categorical, 'integer'))
    # If `doublet_fraction` was specified, check that it is a number
    # between 0 and 1, exclusive
    if doublet_fraction is not None:
        check_type(doublet_fraction, 'doublet_fraction', float,
                   'a number greater than 0 and less than 1')
        check_bounds(doublet_fraction, 'doublet_fraction', 0, 1,
                     left_open=True, right_open=True)
    # Check that `num_doublet_genes` is a positive integer
    check_type(num_doublet_genes, 'num_doublet_genes', int,
               'a positive integer')
    check_bounds(num_doublet_genes, 'num_doublet_genes', 1)
    # Check that `allow_float` is Boolean
    check_type(allow_float, 'allow_float', bool, 'Boolean')
    # Check that `verbose` is Boolean
    check_type(verbose, 'verbose', bool, 'Boolean')
    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)
    # If `allow_float=False`, raise an error if `X` is floating-point
    if not allow_float and np.issubdtype(X.dtype, np.floating):
        error_message = (
            f'qc() requires raw counts but X has data type '
            f'{str(X.dtype)!r}, a floating-point data type. If you are '
            f'sure that all values are raw integer counts, i.e. that '
            f'(X.data == X.data.astype(int)).all(), then set '
            f'allow_float=True.')
        raise TypeError(error_message)
    # Check that `obs_names` and `var_names` are unique. The number of
    # duplicates is the total minus the number of unique names (so that
    # the reported count is positive).
    num_unique = self.obs_names.n_unique()
    if num_unique < len(self._obs):
        error_message = (
            f'obs_names contains {len(self._obs) - num_unique:,} '
            f'duplicates; deduplicate with make_obs_names_unique()')
        raise ValueError(error_message)
    num_unique = self.var_names.n_unique()
    if num_unique < len(self._var):
        error_message = (
            f'var_names contains {len(self._var) - num_unique:,} '
            f'duplicates; deduplicate with make_var_names_unique()')
        raise ValueError(error_message)
    # Apply the custom filter, if specified. `mask` accumulates the
    # conjunction of all enabled filters and stays `None` until the first
    # filter is applied.
    if verbose:
        print(f'Starting with {len(self):,} {plural("cell", len(self))}.')
    mask = None
    if custom_filter is not None:
        if verbose:
            print('Applying the custom filter...')
        mask = custom_filter
        if verbose:
            num_cells = mask.sum()
            print(
                f'{num_cells:,} '
                f'{"cell remains" if num_cells == 1 else "cells remain"} '
                f'after applying the custom filter.')
    # Filter to cells with ≤ `100 * max_mito_fraction`% mitochondrial
    # counts, if `max_mito_fraction` was specified
    if max_mito_fraction is not None:
        if verbose:
            print(f'Filtering to cells with ≤{100 * max_mito_fraction}% '
                  f'mitochondrial counts...')
        var_names = self.var_names
        if var_names.dtype != pl.String:
            var_names = var_names.cast(pl.String)
        mt_genes = var_names.str.to_uppercase().str.starts_with('MT-')
        if not mt_genes.any():
            error_message = (
                'no genes are mitochondrial (start with "MT-", '
                'case-insensitively); this may happen if your var_names '
                'are Ensembl IDs (ENSG) rather than gene symbols (in '
                'which case you should set the gene symbols as the '
                'var_names with set_var_names()), or if mitochondrial '
                'genes have already been filtered out (in which case you '
                'can set max_mito_fraction to None)')
            raise ValueError(error_message)
        # The per-cell mask is filled in place by the routine matching the
        # storage format of `X`
        mito_mask = np.empty(X.shape[0], dtype=bool)
        if isinstance(X, csr_array):
            mito_mask_csr(data=X.data, indices=X.indices, indptr=X.indptr,
                          mt_genes=mt_genes.to_numpy(),
                          max_mito_fraction=max_mito_fraction,
                          mito_mask=mito_mask, num_threads=num_threads)
        else:
            mito_mask_csc(data=X.data, indices=X.indices, indptr=X.indptr,
                          mt_genes=mt_genes.to_numpy(),
                          max_mito_fraction=max_mito_fraction,
                          mito_mask=mito_mask, num_threads=num_threads)
        mito_mask = pl.Series(mito_mask)
        if not mito_mask.any():
            error_message = (
                f'no cells remain after filtering to cells with '
                f'≤{100 * max_mito_fraction}% mitochondrial counts')
            raise ValueError(error_message)
        if mask is None:
            mask = mito_mask
        else:
            mask &= mito_mask
        if verbose:
            num_cells = mask.sum()
            print(
                f'{num_cells:,} '
                f'{"cell remains" if num_cells == 1 else "cells remain"} '
                f'after filtering to cells with '
                f'≤{100 * max_mito_fraction}% mitochondrial counts.')
    # Filter to cells with ≥ `min_genes` genes detected, if specified
    if min_genes is not None:
        if verbose:
            print(f'Filtering to cells with ≥{min_genes:,} '
                  f'{plural("gene", min_genes)} detected (with non-zero '
                  f'count)...')
        gene_mask = pl.Series(getnnz_at_least_threshold(
            X, axis=1, num_threads=num_threads, threshold=min_genes))
        if not gene_mask.any():
            error_message = (
                f'no cells remain after filtering to cells with '
                f'≥{min_genes:,} {plural("gene", min_genes)} detected')
            raise ValueError(error_message)
        if mask is None:
            mask = gene_mask
        else:
            mask &= gene_mask
        if verbose:
            num_cells = mask.sum()
            print(
                f'{num_cells:,} '
                f'{"cell remains" if num_cells == 1 else "cells remain"} '
                f'after filtering to cells with ≥{min_genes:,} '
                f'{plural("gene", min_genes)} detected.')
    # Filter to cells with non-zero MALAT1 expression, if
    # `nonzero_MALAT1=True`
    if nonzero_MALAT1:
        if verbose:
            print('Filtering to cells with non-zero MALAT1 expression...')
        # Row positions in `var` of the gene(s) named 'MALAT1' or 'Malat1'
        MALAT1_index = self._var\
            .select(pl.arg_where(pl.col(self.var_names.name)
                                 .is_in(('MALAT1', 'Malat1'))))
        if len(MALAT1_index) == 0:
            error_message = (
                "neither 'MALAT1' nor 'Malat1' was found in var_names; "
                "this may happen if your var_names are Ensembl IDs "
                "(ENSG) rather than gene symbols (in which case you "
                "should set the gene symbols as the var_names with "
                "set_var_names()). Alternatively, set "
                "nonzero_MALAT1=False to disable filtering on MALAT1 "
                "expression.")
            raise ValueError(error_message)
        # `var_names` was verified to be unique above, so two matches can
        # only mean one of each spelling
        if len(MALAT1_index) == 2:
            error_message = (
                "both 'MALAT1' and 'Malat1' were found in var_names; if "
                "this is intentional, rename one of them before running "
                "qc(), or set nonzero_MALAT1=False to disable filtering "
                "on MALAT1 expression")
            raise ValueError(error_message)
        MALAT1_index = MALAT1_index.item()
        if isinstance(X, csr_array):
            unsorted_warning = (
                'Warning: X does not have sorted indices, so some '
                'operations may be slower. You may want to sort indices '
                'with `sc.X.sort_indices()` (an in-place operation) as '
                'the first step after loading, though be aware that this '
                'may take a while.')
            MALAT1_mask = np.empty(X.shape[0], dtype=bool)
            if X._has_sorted_indices:
                # Known sorted: binary search (works for sorted indices,
                # whether canonical or not)
                malat1_mask_csr(indices=X.indices, indptr=X.indptr,
                                MALAT1_index=MALAT1_index,
                                MALAT1_mask=MALAT1_mask,
                                num_threads=num_threads)
            elif X._has_sorted_indices is False:
                # Known unsorted: linear scan with short-circuiting
                if verbose:
                    print(unsorted_warning)
                malat1_mask_csr_scan(indices=X.indices, indptr=X.indptr,
                                     MALAT1_index=MALAT1_index,
                                     MALAT1_mask=MALAT1_mask,
                                     num_threads=num_threads)
            else:
                # Unknown sortedness: linear scan without short-circuiting
                # checking canonicalness/sortedness along the way
                has_canonical_format, has_sorted_indices = \
                    malat1_mask_csr_check(
                        indices=X.indices, indptr=X.indptr,
                        MALAT1_index=MALAT1_index, MALAT1_mask=MALAT1_mask,
                        num_threads=num_threads)
                # Cache what the scan learned on `X` for later calls
                X._has_canonical_format = has_canonical_format
                X._has_sorted_indices = has_sorted_indices
                if verbose and not has_sorted_indices:
                    print(unsorted_warning)
        else:
            # CSC: the stored row indices of the MALAT1 column are exactly
            # the cells with a stored MALAT1 entry
            start = X.indptr[MALAT1_index]
            end = X.indptr[MALAT1_index + 1]
            MALAT1_mask = np.zeros(X.shape[0], dtype=bool)
            MALAT1_mask[X.indices[start:end]] = True
        MALAT1_mask = pl.Series(MALAT1_mask)
        if mask is None:
            mask = MALAT1_mask
        else:
            mask &= MALAT1_mask
        if verbose:
            num_cells = mask.sum()
            print(
                f'{num_cells:,} '
                f'{"cell remains" if num_cells == 1 else "cells remain"} '
                f'after filtering to cells with non-zero MALAT1 '
                f'expression.')
    # Remove predicted doublets, if `remove_doublets=True`. Exclude cells
    # that have failed earlier QC steps from being considered in the
    # doublet detection by passing `QC_column=mask`.
    if remove_doublets:
        if verbose:
            print('Removing predicted doublets...')
        singlets = pl.Series(SingleCell._find_doublets(
            X=X, batch_column=batch_column, QC_column=mask,
            doublet_fraction=doublet_fraction, num_genes=num_doublet_genes,
            return_scores=False, num_threads=num_threads))
        if mask is None:
            mask = singlets
        else:
            mask &= singlets
        if verbose:
            num_cells = mask.sum()
            print(
                f'{num_cells:,} '
                f'{"cell remains" if num_cells == 1 else "cells remain"} '
                f'after removing predicted doublets.')
    # Add the mask of QCed cells as a column, or subset if `subset=True`
    if mask is None:
        error_message = 'no QC filters were specified'
        raise ValueError(error_message)
    if subset:
        if verbose:
            print('Subsetting to cells passing QC (note: you can reduce '
                  'memory usage by specifying subset=False)...')
        sc = self.filter_obs(mask)
    else:
        if verbose:
            print(f'Adding a Boolean column, obs[{QC_column!r}], '
                  f'indicating which cells passed QC...')
        # Adding a column to `obs` does not change the set of cells or
        # genes, so `obsp` and `varp` are carried over along with the other
        # fields
        sc = SingleCell(X=X,
                        obs=self._obs.with_columns(
                            pl.lit(mask).alias(QC_column)),
                        var=self._var, obsm=self._obsm, varm=self._varm,
                        obsp=self._obsp, varp=self._varp,
                        uns=self._uns, num_threads=self._num_threads)
    sc._uns['QCed'] = True
    return sc
[docs]
def skip_qc(self) -> SingleCell:
    """
    Mark the dataset as QCed without applying any QC filters, so that it
    can be passed to downstream functions that require QCed data.

    This is shorthand for `self.with_uns(QCed=True)`.

    Returns:
        The result of `self.with_uns(QCed=True)`, i.e. the dataset with
        `uns['QCed']` set to `True`.
    """
    qc_flag = {'QCed': True}
    return self.with_uns(**qc_flag)
@staticmethod
def _find_doublets(X: csr_array | csc_array,
                   batch_column: SingleCellColumn | None,
                   QC_column: SingleCellColumn | None,
                   doublet_fraction: float | np.floating | None,
                   num_genes: int | np.integer,
                   return_scores: bool,
                   num_threads: int | np.integer | None) -> \
        np.ndarray[np.dtype[np.bool_]] | \
        tuple[np.ndarray[np.dtype[np.bool_]],
              np.ndarray[np.dtype[np.float32]]]:
    """
    Find doublets using cxds (co-expression-based doublet scoring;
    academic.oup.com/bioinformatics/article/36/4/1150/5566507). Used by
    `qc()` (when `remove_doublets=True`) and `find_doublets()`.

    Args:
        X: the count matrix; may be either normalized or unnormalized.
           `X._num_threads` is temporarily set to `num_threads` while this
           function runs, and restored afterwards.
        batch_column: an optional column indicating which batch each cell
                      is from, or `None` if all cells belong to the same
                      sequencing batch. As used here it must already be a
                      polars Series (its `.name` and `.to_frame()` are
                      used). Doublet detection is performed separately for
                      each batch; cells where `batch_column` is `null`
                      are collectively treated as a single batch.
        QC_column: an optional Boolean column indicating which cells
                   passed QC, or `None` to include all cells. As used here
                   it is passed directly to `LazyFrame.filter()`. Cells
                   failing QC are ignored.
        doublet_fraction: an optional fraction of cells (within each batch,
                          if `batch_column` is specified) to be classified
                          as doublets. If `None`, automatically detect the
                          threshold via the approach described in
                          `find_doublets()`.
        num_genes: the number of highly variable genes, i.e. genes
                   expressed in as close to 50% of cells as possible, to
                   use during doublet detection. This parameter usually has
                   a minimal influence on accuracy as long as it is
                   sufficiently large (in the hundreds), so increasing it
                   further will mainly just increase runtime. If
                   `num_genes` is greater than the number of genes in the
                   dataset, all genes will be used.
        return_scores: whether to return the doublet scores in addition to
                       the doublet calls
        num_threads: the number of threads to use when finding doublets.
                     `find_doublets()` resolves `None` and `-1` via
                     `_process_num_threads()` before calling this function.

    Returns:
        A NumPy array with the binary doublet calls (`True` for singlets,
        `False` for doublets), or a tuple of two arrays with the doublet
        calls and doublet scores when `return_scores=True`. For cells where
        `QC_column` is `False`, doublet calls will be `False` (though could
        equivalently be uninitialized, if not for this Polars bug:
        https://github.com/pola-rs/polars/issues/24296) and doublet scores
        should be treated as uninitialized.
    """
    is_csr = isinstance(X, csr_array)
    # If `batch_column` was specified, get the row indices of each batch,
    # ignoring cells failing QC when `QC_column` is present in `obs`. If
    # no batches were specified but `QC_column` is present, use `QC_column`
    # as the batch labels.
    if QC_column is not None and batch_column is None:
        batch_column = QC_column
    if batch_column is not None:
        batch_column_name = batch_column.name
        indexed_cells = batch_column\
            .to_frame()\
            .lazy()\
            .with_columns(_SingleCell_batch_indices=pl.int_range(
                pl.len(), dtype=pl.UInt32))
        if QC_column is not None:
            indexed_cells = indexed_cells.filter(QC_column)
        batches = indexed_cells\
            .group_by(batch_column_name, maintain_order=True)\
            .agg('_SingleCell_batch_indices')\
            .select('_SingleCell_batch_indices')\
            .collect()\
            .to_series()
    else:
        # If neither `batch_column` nor `QC_column` were specified, use a
        # single dummy batch. The trailing comma is deliberate: `batches`
        # is a 1-tuple, so the loop below runs exactly once with an empty
        # `batch_indices`.
        batches = pl.Series([], dtype=pl.UInt32),
    # Preallocate
    total_num_cells = X.shape[0]
    if num_threads == 1:
        singlets = np.zeros(total_num_cells, dtype=bool)
    else:
        singlets = numa_zeros(total_num_cells, dtype=bool)
    if not return_scores:
        doublet_scores = np.array([], dtype=np.float32)
    elif num_threads == 1:
        doublet_scores = np.empty(total_num_cells, dtype=np.float32)
    else:
        doublet_scores = numa_zeros(total_num_cells, dtype=np.float32)
    if batch_column is not None:
        # `max()` is `None` when there are no batches at all (e.g. no cell
        # passes QC); fall back to 0 so the buffers below can still be
        # allocated and the (empty) loop is simply skipped
        max_cells = batches.list.len().max() or 0
    else:
        max_cells = total_num_cells
    if doublet_fraction is not None:
        doublet_indices = np.empty(max_cells, dtype=np.uint32)
    else:
        # Set `doublet_fraction` to `-1` if `None`, so it can be passed to
        # Cython as a float
        doublet_fraction = -1
        doublet_indices = np.array([], dtype=np.uint32)
    num_total_genes = X.shape[1]
    all_detection_counts = np.empty(num_total_genes, dtype=np.uint32)
    detection_counts = np.empty(num_genes, dtype=np.uint32)
    ps = np.empty(num_genes, dtype=np.float32)
    hvgs = np.empty(num_genes, dtype=np.uint32)
    distances = np.empty(num_genes, dtype=np.float32)
    obs_buffer = np.empty(num_genes * num_genes, dtype=np.uint32)
    S_buffer = np.empty(num_genes * num_genes, dtype=np.float32)
    cxds_scores_buffer = np.empty(max_cells, dtype=np.float32)
    sim_indptr_buffer = np.empty(max_cells + 1, dtype=X.indptr.dtype)
    cxds_scores_sim_buffer = np.empty(max_cells, dtype=np.float32)
    original_num_threads = X._num_threads
    try:
        X._num_threads = num_threads
        # For each batch...
        for batch_index, batch_indices in enumerate(batches):
            # Subset to cells in this batch
            if batch_column is None:
                X_batch = X
            else:
                # NOTE(review): single-cell batches are skipped, so their
                # entry in `singlets` stays `False` - confirm this is the
                # intended label for such cells
                if len(batch_indices) == 1:
                    continue
                X_batch = X[batch_indices]
            # Get the detection count of each gene
            getnnz(X_batch, axis=0, num_threads=num_threads,
                   output=all_detection_counts)
            # Normalize `detection_counts` by `num_cells` to get the
            # detection rate `p`. Subset to the `num_genes` genes with
            # detection rates closest to 50%. Exclude genes with detection
            # rates of 0% or 100%.
            num_cells = X_batch.shape[0]
            batch_num_genes = get_hvgs(
                all_detection_counts=all_detection_counts,
                detection_counts=detection_counts, ps=ps, hvgs=hvgs,
                distances=distances, num_cells=num_cells,
                num_total_genes=num_total_genes, num_genes=num_genes)
            # Take per-batch views rather than rebinding the buffers
            # themselves: `detection_counts` and `ps` must keep their full
            # `num_genes` length, since they are passed to `get_hvgs()`
            # again for the next batch, which may select more genes than
            # this one
            batch_detection_counts = detection_counts[:batch_num_genes]
            batch_ps = ps[:batch_num_genes]
            X_batch = X_batch[:, hvgs[:batch_num_genes]]
            # Convert `X_batch` to CSR, if CSC
            if not is_csr:
                X_batch = X_batch.tocsr()
            # Sort indices, if not already sorted (necessary for
            # `simulate_doublets`).
            # NOTE(review): the flag is read from `X`, not `X_batch` -
            # presumably subsetting preserves sortedness; confirm
            if not X.has_sorted_indices:
                X_batch.sort_indices()
            # Get `obs`, where `obs[i, j]` is the number of cells that
            # express exactly one of genes `i` and `j`
            obs = obs_buffer[:batch_num_genes * batch_num_genes]\
                .reshape((batch_num_genes, batch_num_genes))
            compute_obs(detection_counts=batch_detection_counts,
                        indices=X_batch.indices, indptr=X_batch.indptr,
                        obs=obs, num_threads=num_threads)
            # Get `S`, the upper-tail log binomial p-values of `obs`
            S = S_buffer[:batch_num_genes * batch_num_genes]\
                .reshape((batch_num_genes, batch_num_genes))
            compute_S(obs=obs, ps=batch_ps, num_cells=num_cells, S=S,
                      num_threads=num_threads)
            # Calculate each cell's cxds score: the sum of `-S[i, j]`
            # across all gene pairs `i` and `j` that are both expressed by
            # the cell
            cxds_scores = cxds_scores_buffer[:num_cells]
            compute_cxds(indices=X_batch.indices, indptr=X_batch.indptr,
                         S=S, cxds_scores=cxds_scores,
                         num_threads=num_threads)
            # Now simulate doublets within the batch and compute their cxds
            # scores, using the original `S` matrix derived from the real
            # data. Conservatively allocate twice as much memory for the
            # indices as the original indices, since in the worst-case
            # scenario none of the indices will match up and all coinflips
            # will be 1.
            sim_indptr = sim_indptr_buffer[:num_cells + 1]
            sim_indices = np.empty(2 * len(X_batch.indices),
                                   dtype=X_batch.indices.dtype)
            simulate_doublets(data=X_batch.data, indices=X_batch.indices,
                              indptr=X_batch.indptr,
                              sim_indices=sim_indices,
                              sim_indptr=sim_indptr, num_cells=num_cells,
                              seed=batch_index)
            sim_indices = sim_indices[:sim_indptr[-1]]
            cxds_scores_sim = cxds_scores_sim_buffer[:num_cells]
            compute_cxds(indices=sim_indices, indptr=sim_indptr, S=S,
                         cxds_scores=cxds_scores_sim,
                         num_threads=num_threads)
            # Call doublets. If using multiple batches, map doublet labels
            # (and scores, if `return_scores`) back to the full dataset.
            call_doublets(cxds_scores=cxds_scores,
                          median_cxds_score_sim=np.median(cxds_scores_sim)
                              if doublet_fraction == -1 else 0,
                          doublet_indices=doublet_indices,
                          batch_indices=batch_indices.to_numpy(),
                          singlets=singlets,
                          doublet_scores=doublet_scores,
                          doublet_fraction=doublet_fraction,
                          num_threads=num_threads)
    finally:
        # Always restore the matrix's own thread count, even on error
        X._num_threads = original_num_threads
    if return_scores:
        # With a single dummy batch, the per-batch scores already cover
        # every cell in order, so return them directly
        if batch_column is None:
            doublet_scores = cxds_scores
        return singlets, doublet_scores
    else:
        return singlets
[docs]
def find_doublets(self,
                  *,
                  batch_column: SingleCellColumn | None,
                  QC_column: SingleCellColumn | None = 'passed_QC',
                  doublet_fraction: float | np.floating | None = None,
                  num_genes: int | np.integer = 500,
                  doublet_column: str = 'doublet',
                  doublet_score_column: str | None = 'doublet_score',
                  overwrite: bool = False,
                  num_threads: int | np.integer | None = None) -> \
        SingleCell:
    """
    Find doublets using cxds ([co-expression-based doublet scoring]
    (academic.oup.com/bioinformatics/article/36/4/1150/5566507)).

    The standard way to filter out doublets is by specifying
    `remove_doublets=True` in `qc()`. If you did that, do not use this
    function! This function should only be used in the unusual scenario
    where you want to find doublets without running any other
    quality-control steps.

    This function gives the same result regardless of whether it is run
    before or after normalization. The actual expression value does not
    matter, only whether or not it is zero.

    Doublets cannot occur across sequencing batches, so make sure to
    specify `batch_column` if your dataset has multiple batches! Doublet
    detection will be done independently within each batch.

    Since the cxds score is continuous, it needs to be converted into a
    binary classification of doublets versus non-doublets. This problem can
    be framed as finding a cxds score threshold above which a cell is
    deemed to be a doublet. To determine this threshold, we simulate
    doublets by combining the counts from randomly selected pairs of cells,
    via the following steps:

    1) Sample as many random pairs of cells (with replacement) as there are
       real cells.
    2) Combine the counts from each pair of cells into a simulated doublet.
       Because cxds operates on binarized count matrices, we average the
       two cells' count matrices in a binary sense: if a gene is expressed
       in either cell, it is deemed to be expressed in the simulated
       doublet, but if it has a count of 1 in one cell and 0 in the other,
       it is randomly chosen to be either expressed or not expressed with
       equal probability (since the average count would be 0.5).
    3) Calculate cxds scores for these simulated doublets, based on the
       coexpression patterns (the `S` matrix from cxds) learned from the
       real data.
    4) Take the median cxds score of the simulated doublets as the
       threshold. In other words, if a real cell has a higher doublet score
       than the average simulated doublet, we call it a doublet.

    Alternatively, specify `doublet_fraction` to force a specific fraction
    of cells to be classified as doublets.

    Args:
        batch_column: an optional String, Enum, Categorical, or integer
                      column of `obs` indicating which batch each cell is
                      from. Can be a column name, a polars expression, a
                      polars Series, a 1D NumPy array, or a function that
                      takes in this SingleCell dataset and returns a polars
                      Series or 1D NumPy array. Doublet detection will be
                      performed separately for each batch; cells where
                      `batch_column` is `null` will collectively be treated
                      as a single batch. Set to `None` if all cells belong
                      to the same sequencing batch.
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored
                   and have their doublet labels and doublet scores set to
                   `null`.
        doublet_fraction: an optional fraction of cells (within each batch,
                          if `batch_column` is specified) to be classified
                          as doublets. If `None`, automatically detect the
                          threshold via the approach described above.
        num_genes: the number of highly variable genes, i.e. genes
                   expressed in as close to 50% of cells as possible, to
                   use during doublet detection. This parameter usually has
                   a minimal influence on accuracy as long as it is
                   sufficiently large (in the hundreds), so increasing it
                   further will mainly just increase runtime. If
                   `num_genes` is greater than the number of genes in the
                   dataset, all genes will be used.
        doublet_column: the name of a Boolean column to be added to `obs`
                        containing the doublet labels, i.e. whether each
                        cell is predicted to be a doublet
        doublet_score_column: the name of a column to be added to `obs`
                              containing each cell's doublet score. Higher
                              scores indicate greater likelihood of being a
                              doublet. Scores are not normalized and are
                              not comparable across datasets or batches,
                              but are guaranteed to be positive (since they
                              are sums of negative log p-values). Set
                              `doublet_score_column=None` to not return
                              doublet scores, for a slight memory reduction
                              and speed increase.
        overwrite: if `True`, overwrite `doublet_column` and/or
                   `doublet_score_column` if already present in `obs`,
                   instead of raising an error
        num_threads: the number of threads to use when finding doublets.
                     Set `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores.

    Returns:
        A new SingleCell dataset where `obs` contains two additional
        columns, `doublet_column` (default: `'doublet'`), indicating
        whether each cell is predicted to be a doublet, and
        `doublet_score_column` (default: `'doublet_score'`), containing
        each cell's doublet score. The score column is omitted when
        `doublet_score_column=None`.

    Raises:
        ValueError: if `doublet_column` or `doublet_score_column` is
                    already in `obs` and `overwrite=False`, if `X` is
                    `None`, or if `uns['QCed']` is `False`.

    Note:
        This function's cxds scores are almost exactly half the original
        implementation's, because it avoids double-counting the two genes
        in each gene pair. Slight deviations from this one-half (usually
        by less than one part in a million) may occur because this function
        uses a normal approximation to the binomial p-value to avoid long
        runtimes on large datasets.

    Note:
        This function may give an incorrect output if the count matrix
        contains explicit zeros (i.e. if `(sc.X.data == 0).any()`): this is
        not checked for, due to speed considerations. In the unlikely event
        that your dataset contains explicit zeros, remove them by running
        `sc.X.eliminate_zeros()` (an in-place operation) first.
    """
    # Check that `overwrite` is Boolean
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    # Check that `doublet_column` and (if not `None`)
    # `doublet_score_column` are strings and, unless `overwrite=True`, not
    # already in `obs`
    check_type(doublet_column, 'doublet_column', str, 'a string')
    if not overwrite and doublet_column in self._obs:
        error_message = (
            f'doublet_column {doublet_column!r} is already a column of '
            f'obs; did you already run find_doublets()? Set '
            f'overwrite=True to overwrite.')
        raise ValueError(error_message)
    return_scores = doublet_score_column is not None
    if return_scores:
        check_type(doublet_score_column, 'doublet_score_column', str,
                   'a string')
        if not overwrite and doublet_score_column in self._obs:
            error_message = (
                f'doublet_score_column {doublet_score_column!r} is '
                f'already a column of obs; did you already run '
                f'find_doublets()? Set overwrite=True to overwrite.')
            raise ValueError(error_message)
    # Check that `X` is present
    if self._X is None:
        error_message = 'X is None, so doublet finding is not possible'
        raise ValueError(error_message)
    # Check that `self` is QCed
    if not self._uns['QCed']:
        error_message = (
            "uns['QCed'] is False; did you forget to run qc() before "
            "find_doublets()? Set uns['QCed'] = True or run skip_qc() to "
            "bypass this check.")
        raise ValueError(error_message)
    # Get `QC_column` and `batch_column`, if not `None`
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')
    if batch_column is not None:
        batch_column = self._get_column(
            'obs', batch_column, 'batch_column',
            (pl.String, pl.Enum, pl.Categorical, 'integer'),
            QC_column=QC_column)
    # Check that `doublet_fraction`, if specified, is > 0 and < 1
    if doublet_fraction is not None:
        check_type(doublet_fraction, 'doublet_fraction', float,
                   'a number greater than 0 and less than 1')
        check_bounds(doublet_fraction, 'doublet_fraction', 0, 1,
                     left_open=True, right_open=True)
    # Check that `num_genes` is a positive integer
    check_type(num_genes, 'num_genes', int, 'a positive integer')
    check_bounds(num_genes, 'num_genes', 1)
    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)
    # Run doublet detection
    result = SingleCell._find_doublets(
        X=self._X, batch_column=batch_column, QC_column=QC_column,
        doublet_fraction=doublet_fraction, num_genes=num_genes,
        return_scores=return_scores, num_threads=num_threads)
    # Convert doublet labels (and scores, if `doublet_score_column` is not
    # `None`) to polars Series to add to `obs`. `_find_doublets()` returns
    # singlet calls, so negate them to get doublet labels.
    if return_scores:
        singlets, doublet_scores = result
        new_columns = [pl.Series(doublet_column, ~singlets),
                       pl.Series(doublet_score_column, doublet_scores)]
    else:
        singlets = result
        new_columns = [pl.Series(doublet_column, ~singlets)]
    # If `QC_column` exists, set doublet labels (and scores) to `null` for
    # cells failing QC
    if QC_column is not None:
        new_columns = [pl.when(QC_column).then(new_column)
                       for new_column in new_columns]
    return self.with_columns_obs(*new_columns)
[docs]
def make_obs_names_unique(self, *, separator: str = '-') -> SingleCell:
    """
    Make `obs_names` unique by appending `'-1'` to the second occurrence of
    a given name, `'-2'` to the third occurrence, and so on, where `'-'`
    can be switched to a different string via the `separator` argument.
    Raises an error if any `obs_names` already contain `separator`.

    Args:
        separator: the string connecting the original name and the integer
                   suffix. Always matched literally, never as a regular
                   expression.

    Returns:
        A new SingleCell dataset with the `obs_names` made unique.

    Raises:
        ValueError: if any of the `obs_names` already contain `separator`.
    """
    check_type(separator, 'separator', str, 'a string')
    # For non-String `obs_names`, check the categories rather than the
    # values themselves
    unique_obs_names = self.obs_names \
        if self.obs_names.dtype == pl.String else \
        self.obs_names.cat.get_categories()
    # `str.contains()` interprets its pattern as a regular expression by
    # default, so separators like '.', '|' or '+' would either match
    # spuriously or fail to compile; `literal=True` searches for the exact
    # substring instead
    if unique_obs_names.str.contains(separator, literal=True).any():
        error_message = (
            f'some obs_names already contain the separator {separator!r}; '
            f'did you already run make_obs_names_unique()? If not, set '
            f'the separator argument to a different string.')
        raise ValueError(error_message)
    obs_names = pl.col(self.obs_names.name)
    # The number of earlier rows with the same name: 0 for the first
    # occurrence of each name, 1 for the second, and so on
    num_times_seen = pl.int_range(pl.len(), dtype=pl.Int32).over(obs_names)
    return SingleCell(X=self._X,
                      obs=self._obs.with_columns(
                          pl.when(num_times_seen > 0)
                          .then(obs_names + separator +
                                num_times_seen.cast(str))
                          .otherwise(obs_names)),
                      var=self._var, obsm=self._obsm, varm=self._varm,
                      obsp=self._obsp, varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def make_var_names_unique(self, *, separator: str = '-') -> SingleCell:
    """
    Make `var_names` unique by appending `'-1'` to the second occurrence of
    a given name, `'-2'` to the third occurrence, and so on, where `'-'`
    can be switched to a different string via the `separator` argument.
    Raises an error if any `var_names` already contain `separator`.

    Args:
        separator: the string connecting the original name and the integer
                   suffix. Always matched literally, never as a regular
                   expression.

    Returns:
        A new SingleCell dataset with the `var_names` made unique.

    Raises:
        ValueError: if any of the `var_names` already contain `separator`.
    """
    check_type(separator, 'separator', str, 'a string')
    # For non-String `var_names`, check the categories rather than the
    # values themselves
    unique_var_names = self.var_names \
        if self.var_names.dtype == pl.String else \
        self.var_names.cat.get_categories()
    # `str.contains()` interprets its pattern as a regular expression by
    # default, so separators like '.', '|' or '+' would either match
    # spuriously or fail to compile; `literal=True` searches for the exact
    # substring instead
    if unique_var_names.str.contains(separator, literal=True).any():
        error_message = (
            f'some var_names already contain the separator {separator!r}; '
            f'did you already run make_var_names_unique()? If not, set '
            f'the separator argument to a different string.')
        raise ValueError(error_message)
    var_names = pl.col(self.var_names.name)
    # The number of earlier rows with the same name: 0 for the first
    # occurrence of each name, 1 for the second, and so on
    num_times_seen = pl.int_range(pl.len(), dtype=pl.Int32).over(var_names)
    return SingleCell(X=self._X,
                      obs=self._obs,
                      var=self._var.with_columns(
                          pl.when(num_times_seen > 0)
                          .then(var_names + separator +
                                num_times_seen.cast(str))
                          .otherwise(var_names)),
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def get_sample_covariates(self,
                          *,
                          ID_column: SingleCellColumn,
                          QC_column: SingleCellColumn |
                                     None = 'passed_QC') -> pl.DataFrame:
    """
    Collect the sample-level covariates: those columns of `obs` whose
    per-sample minimum and maximum coincide for every sample, i.e. that
    take the same value for all cells within each sample.

    Args:
        ID_column: a String, Enum, Categorical, or integer column of `obs`
                   containing sample IDs. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array.
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored.

    Returns:
        A DataFrame with one row per sample, sorted by `ID_column`, which
        is the first column; the remaining columns are the sample-level
        covariates.
    """
    # Resolve the QC and ID columns
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')
    ID_column = self._get_column(
        'obs', ID_column, 'ID_column',
        (pl.String, pl.Enum, pl.Categorical, 'integer'),
        QC_column=QC_column)
    sample_ID = ID_column.name
    # Attach the ID column to `obs` and keep only the cells passing QC
    cells = self._obs.with_columns(ID_column)
    if QC_column is not None:
        cells = cells.filter(QC_column)
    # Every column other than the ID column is a candidate covariate
    if sample_ID in cells:
        candidates = pl.exclude(sample_ID)
    else:
        candidates = pl.all()
    # Per sample, record each candidate's first value, plus its minimum
    # and maximum under temporary prefixed names
    min_prefix = '_SingleCell_min_'
    max_prefix = '_SingleCell_max_'
    per_sample = cells.group_by(sample_ID).agg(
        candidates.first(),
        candidates.min().name.prefix(min_prefix),
        candidates.max().name.prefix(max_prefix))
    # A candidate is a sample-level covariate when its minimum equals its
    # maximum in every sample
    min_equals_max = per_sample.select(
        pl.col('^_SingleCell_min_.*$') == pl.col('^_SingleCell_max_.*$'))
    min_equals_max = min_equals_max.rename(
        lambda name: name.removeprefix(min_prefix))
    constant_columns = \
        filter_columns(min_equals_max, pl.all().all()).columns
    return per_sample\
        .select(sample_ID, *constant_columns)\
        .sort(sample_ID)
[docs]
def pseudobulk(self,
               ID_column: SingleCellColumn,
               cell_type_column: SingleCellColumn,
               /,
               *,
               QC_column: SingleCellColumn | None = 'passed_QC',
               cell_types: str | Iterable[str] | int | Iterable[int] |
                           None = None,
               excluded_cell_types: str | Iterable[str] | int |
                                    Iterable[int] | None = None,
               additional_obs: pl.DataFrame | None = None,
               include_nulls: bool = False,
               sort_genes: bool = False,
               num_threads: int | np.integer | None = None,
               verbose: bool = True) -> Pseudobulk:
    """
    Pseudobulk a SingleCell dataset with sample ID and cell type columns.
    Operates on raw counts, so cannot be run after `normalize()`. Must be
    run after `qc()`.

    Counts from cells with the same pair of values in `ID_column` and
    `cell_type_column` will be summed to a single value. Cells with `null`
    in either column are excluded, unless `include_nulls=True`.

    You can run this function multiple times at different cell type
    resolutions by setting a different `cell_type_column` each time, then
    combining the results afterwards with the `|` operator (assuming none
    of the cell types overlap between the two resolutions):

    ```
    pb_broad = sc.pseudobulk('ID', 'broad_cell_type')
    pb_fine = sc.pseudobulk('ID', 'fine_grained_cell_type')
    pb = pb_broad | pb_fine
    ```

    Args:
        ID_column: a String, Enum, Categorical, or integer column of `obs`
                   containing sample IDs. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array.
        cell_type_column: a String, Enum, Categorical, or integer column of
                          `obs` containing cell-type labels. Can be a
                          column name, a polars expression, a polars
                          Series, a 1D NumPy array, or a function that
                          takes in this SingleCell dataset and returns a
                          polars Series or 1D NumPy array. If
                          cell_type_column is an integer column, the cell
                          types in the Pseudobulk dataset will be coerced
                          to strings.
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be excluded
                   from the pseudobulk.
        cell_types: one or more cell types to pseudobulk; by default,
                    pseudobulks all cell types in `cell_type_column`.
                    Specifying `cell_types` is exactly equivalent to
                    filtering the result to these cell types, but will be
                    faster when there are many cell types and pseudobulks
                    are only desired for a few of them. Can also be used to
                    change the order in which cell types appear in the
                    resulting Pseudobulk dataset, even if pseudobulking all
                    cell types. Mutually exclusive with
                    `excluded_cell_types` and `include_nulls`.
        excluded_cell_types: one or more cell types to exclude from
                             pseudobulking. Mutually exclusive with
                             `cell_types` and `include_nulls`.
        additional_obs: an optional DataFrame of additional sample-level
                        covariates, which will be joined to the
                        pseudobulk's `obs` for each cell type. Must contain
                        `ID_column`, with the same data type as in this
                        dataset.
        include_nulls: whether to include cells with `null` values in
                       `ID_column` and/or `cell_type_column` in the
                       pseudobulk. If `include_nulls=True`, `null` will be
                       treated just like any other value. This means that,
                       for instance, all cells from a given cell type that
                       have `null` as the sample ID will be pseudobulked
                       together, as will all cells from a given sample ID
                       that have `null` as the cell type. Mutually
                       exclusive with `cell_types` and
                       `excluded_cell_types`.
        sort_genes: whether to sort genes in alphabetical order in the
                    pseudobulk; by default, genes appear in the same order
                    as in the SingleCell dataset
        num_threads: the number of threads to use when pseudobulking;
                     parallelism happens across {sample, cell type} pairs
                     (or just samples, if `cell_type_column` is `None`).
                     Set `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores. For count matrices stored in
                     the usual CSR format, parallelization takes place
                     across cell types and samples, so specifying more
                     threads than the number of cell type-sample pairs will
                     not provide additional speedup. Does not affect the
                     Pseudobulk dataset's `num_threads`; this will always
                     be the same as the SingleCell dataset's `num_threads`.
        verbose: whether to print the number of cells excluded when
                 `include_nulls=False` (and neither `cell_types` nor
                 `excluded_cell_types` are specified)

    Returns:
        A Pseudobulk dataset with `X` (the pseudobulked counts), `obs`
        (metadata per sample), and `var` (metadata per gene) fields, each
        of which are dictionaries across cell types. The columns of each
        cell type's `obs` will be:
        - `ID_column`
        - `'num_cells'` (the number of cells for that sample and cell type)
        followed by whichever columns of the SingleCell dataset's `obs` are
        constant across the cells within each sample, and then any columns
        of `additional_obs`. `var` will be identical to the SingleCell
        dataset's `var` (with its rows sorted, if `sort_genes=True`).

    Raises:
        ValueError: if `X` is `None`; if the dataset is not QCed or is
                    already normalized; if `ID_column` or
                    `cell_type_column` is named `'num_cells'`; if
                    `include_nulls=True` is combined with `cell_types` or
                    `excluded_cell_types`; if `ID_column` is missing from
                    `additional_obs`; if `var_names` has an unsupported
                    data type; or if no cells remain after excluding
                    nulls.
        TypeError: if `ID_column` has a different data type in
                   `additional_obs` than in `obs`.

    Note:
        This function may give an incorrect output if the count matrix
        contains negative values: this is not checked for, due to speed
        considerations.
    """
    # Check that `X` is present
    X = self._X
    if X is None:
        error_message = 'X is None, so pseudobulking is not possible'
        raise ValueError(error_message)
    # Check that `self` is QCed and not normalized
    if not self._uns['QCed']:
        error_message = (
            "uns['QCed'] is False; did you forget to run qc() before "
            "pseudobulk()? Set uns['QCed'] = True or run skip_qc() to "
            "bypass this check.")
        raise ValueError(error_message)
    if self._uns['normalized']:
        error_message = (
            "uns['normalized'] is True; did you already run normalize()?")
        raise ValueError(error_message)
    # Get the QC, ID, and cell-type columns
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')
    original_ID_column = ID_column
    ID_column = self._get_column('obs', ID_column, 'ID_column',
                                 (pl.String, pl.Enum, pl.Categorical,
                                  'integer'), QC_column=QC_column,
                                 allow_null=True)
    ID_column_name = ID_column.name
    cell_type_column = \
        self._get_column('obs', cell_type_column, 'cell_type_column',
                         (pl.String, pl.Enum, pl.Categorical, 'integer'),
                         QC_column=QC_column,
                         allow_null=True)
    cell_type_column_name = cell_type_column.name
    num_cells_column_name = 'num_cells'
    for column_description, column_name in ('ID_column', ID_column_name), \
            ('cell_type_column', cell_type_column_name):
        if column_name == num_cells_column_name:
            error_message = (
                f'{column_description} has the name '
                f'{num_cells_column_name!r}, which conflicts with the '
                f'name of the column to be added to the Pseudobulk '
                f'dataset containing the number of cells of each cell '
                f'type')
            raise ValueError(error_message)
    # Check that `include_nulls`, `sort_genes`, and `verbose` are Boolean
    check_type(include_nulls, 'include_nulls', bool, 'Boolean')
    check_type(sort_genes, 'sort_genes', bool, 'Boolean')
    check_type(verbose, 'verbose', bool, 'Boolean')
    # Check that `include_nulls=True` is not specified alongside either
    # `cell_types` or `excluded_cell_types`
    if include_nulls:
        if cell_types is not None:
            error_message = \
                'cell_types cannot be specified when include_nulls=True'
            raise ValueError(error_message)
        if excluded_cell_types is not None:
            error_message = (
                'excluded_cell_types cannot be specified when '
                'include_nulls=True')
            raise ValueError(error_message)
    # Check that `cell_types` and `excluded_cell_types` are not both
    # specified. If `cell_types` is specified, check it contains only cell
    # type names present in `cell_type_column`, then set non-matching cell
    # types to `null` so that they are treated as a single background cell
    # type. If `excluded_cell_types` is specified, do the opposite.
    cell_type_column = SingleCell._process_cell_types(
        cell_types, excluded_cell_types, cell_type_column)
    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)
    # Check that `additional_obs`, if specified, is a DataFrame containing
    # `ID_column` with the same data type as in `obs`
    if additional_obs is not None:
        check_type(additional_obs, 'additional_obs', pl.DataFrame,
                   'a polars DataFrame')
        if ID_column_name not in additional_obs:
            ID_column_description = SingleCell._describe_column(
                'ID_column', original_ID_column)
            error_message = (
                f'{ID_column_description} is not a column of '
                f'additional_obs')
            raise ValueError(error_message)
        if ID_column.dtype != additional_obs[ID_column_name].dtype:
            ID_column_description = SingleCell._describe_column(
                'ID_column', original_ID_column)
            error_message = (
                f"{ID_column_description} has a different data type in "
                f"additional_obs than in this SingleCell dataset's obs")
            raise TypeError(error_message)
    # Check that the first column of `var` is String, Enum, Categorical,
    # or integer: this is a requirement of the Pseudobulk class. (The first
    # column of `obs` must be as well, but this will always be true by
    # construction, since it will always be the sample ID.)
    dtype = self.var_names.dtype
    if dtype not in (pl.String, pl.Enum, pl.Categorical) and \
            dtype not in pl.INTEGER_DTYPES:
        error_message = (
            f'the first column of var (var_names) must be String, Enum, '
            f'Categorical, or integer, but has data type {dtype!r}')
        raise ValueError(error_message)
    # Get the row indices that will be pseudobulked across for each group
    # (cell type-sample pair), ignoring cells failing QC when `QC_column`
    # is present in `obs`
    if QC_column is None:
        grouping_columns = cell_type_column, ID_column
    else:
        grouping_columns = cell_type_column, ID_column, QC_column
    groups = pl.LazyFrame(grouping_columns)\
        .with_columns(
            _SingleCell_group_indices=pl.int_range(pl.len(),
                                                   dtype=pl.UInt32))
    if QC_column is not None:
        groups = groups.filter(QC_column.name)
    groups = groups\
        .group_by(cell_type_column_name, ID_column_name)\
        .agg('_SingleCell_group_indices',
             pl.len().alias(num_cells_column_name))\
        .sort(cell_type_column_name, ID_column_name)\
        .collect()
    # Exclude cells with null values in `ID_column` and/or
    # `cell_type_column`, if `include_nulls=False`. When `cell_types` or
    # `excluded_cell_types` are specified, this includes cells from cell
    # types that are not in `cell_types` or in `excluded_cell_types`,
    # because of the processing we did above in
    # `SingleCell._process_cell_types()`.
    if not include_nulls:
        # Only report the exclusions when no cell-type filter is active;
        # otherwise, the nulls would also count the filtered-out cell types
        if verbose and cell_types is None and excluded_cell_types is None:
            excluded = groups\
                .filter(pl.any_horizontal(pl.col(
                    cell_type_column_name, ID_column_name).is_null()))
            num_ID_null_only = excluded\
                .filter(pl.col(ID_column_name).is_null(),
                        pl.col(cell_type_column_name).is_not_null())\
                [num_cells_column_name]\
                .sum()
            num_cell_type_null_only = excluded\
                .filter(pl.col(ID_column_name).is_not_null(),
                        pl.col(cell_type_column_name).is_null())\
                [num_cells_column_name]\
                .sum()
            num_excluded = excluded[num_cells_column_name].sum()
            num_both_null = num_excluded - num_ID_null_only - \
                num_cell_type_null_only
            if num_excluded > 0:
                print(f'Excluding {num_excluded:,} '
                      f'{plural("cell", num_excluded)} when '
                      f'pseudobulking: {num_both_null:,} with nulls '
                      f'in both the ID ({ID_column_name!r}) and '
                      f'cell-type ({cell_type_column_name!r}) '
                      f'columns, {num_ID_null_only:,} with nulls in '
                      f'just the ID column, and '
                      f'{num_cell_type_null_only:,} with nulls in '
                      f'just the cell-type column.')
        groups = groups.drop_nulls()
        if len(groups) == 0:
            error_message = (
                'no cells remain after excluding cells with nulls in '
                'the ID and/or cell-type columns')
            raise ValueError(error_message)
    # Pseudobulk, storing the result in a preallocated NumPy array
    result = np.empty((len(groups), X.shape[1]), dtype=np.uint32)
    if isinstance(X, csr_array):
        # For CSR, pass the concatenated row indices of all groups, along
        # with the position in that array where each group ends
        group_indices = \
            groups['_SingleCell_group_indices'].explode().to_numpy()
        group_ends = groups[num_cells_column_name].cum_sum().to_numpy()
        groupby_sum_csr(data=X.data, indices=X.indices, indptr=X.indptr,
                        group_indices=group_indices, group_ends=group_ends,
                        result=result, num_threads=num_threads)
    else:
        # For CSC, map each cell (row of `X`) to the index of its group, or
        # -1 if the cell is not in any group. `group_map` is indexed by
        # cell, so the join must preserve the order of the left frame:
        # request this explicitly rather than relying on the unspecified
        # default row order.
        group_map = pl.int_range(X.shape[0], dtype=pl.UInt32, eager=True)\
            .to_frame('_SingleCell_group_indices')\
            .join(groups
                  .select('_SingleCell_group_indices',
                          _SingleCell_index=pl.int_range(pl.len(),
                                                         dtype=pl.Int32))
                  .explode('_SingleCell_group_indices'),
                  on='_SingleCell_group_indices', how='left',
                  maintain_order='left')\
            ['_SingleCell_index']
        has_missing = group_map.null_count() > 0
        if has_missing:
            group_map = group_map.fill_null(-1)
        group_map = group_map.to_numpy()
        groupby_sum_csc(data=X.data, indices=X.indices, indptr=X.indptr,
                        group_map=group_map, has_missing=has_missing,
                        result=result, num_threads=num_threads)
    # Sort genes, if `sort_genes=True`
    cell_type_var = self._var
    if sort_genes:
        result = result[:, cell_type_var[:, 0].arg_sort().to_numpy()]
        cell_type_var = cell_type_var.sort(cell_type_var.columns[0])
    # Get sample covariates: the columns of `obs` whose minimum and maximum
    # coincide within every sample, among cells passing QC
    cell_obs = self._obs.with_columns(ID_column)
    if QC_column is not None:
        cell_obs = cell_obs.filter(QC_column)
    other_columns = pl.exclude(ID_column_name) \
        if ID_column_name in cell_obs else pl.all()
    aggregated = cell_obs.group_by(ID_column_name).agg(
        other_columns.first(),
        other_columns.min().name.prefix('_SingleCell_min_'),
        other_columns.max().name.prefix('_SingleCell_max_'))
    constant_columns = aggregated\
        .select(pl.col('^_SingleCell_min_.*$') ==
                pl.col('^_SingleCell_max_.*$'))\
        .rename(lambda c: c.removeprefix('_SingleCell_min_'))\
        .pipe(filter_columns, pl.all().all())\
        .columns
    sample_covariates = aggregated\
        .select(ID_column_name, *constant_columns)\
        .sort(ID_column_name)
    # Break up the results by cell type. `groups` is sorted by cell type,
    # so each cell type occupies a contiguous run of rows of `result`.
    X_by_cell_type, obs_by_cell_type, var_by_cell_type = {}, {}, {}
    start_index = 0
    for cell_type, count in groups[cell_type_column_name]\
            .value_counts().sort(cell_type_column_name).iter_rows():
        cell_type = str(cell_type)
        end_index = start_index + count
        X_by_cell_type[cell_type] = result[start_index:end_index]
        # The rows of `obs` must stay aligned with the rows of `X`, so
        # every join explicitly preserves the order of the left frame
        cell_type_obs = groups.lazy()\
            .select(ID_column_name, num_cells_column_name)\
            .slice(start_index, count)\
            .join(sample_covariates.lazy(), on=ID_column_name, how='left',
                  maintain_order='left')
        if additional_obs is not None:
            cell_type_obs = cell_type_obs.join(
                additional_obs.lazy(), on=ID_column_name, how='left',
                maintain_order='left')
        if QC_column is not None:
            # The QC column is constant (all `True`) after filtering, so it
            # shows up as a sample covariate; drop it. `strict=False`
            # avoids an error when no column of that name was joined in,
            # e.g. if `QC_column` was not taken from `obs`.
            cell_type_obs = cell_type_obs.drop(QC_column.name,
                                               strict=False)
        obs_by_cell_type[cell_type] = cell_type_obs.collect()
        var_by_cell_type[cell_type] = cell_type_var
        start_index = end_index
    # Propagate the SingleCell dataset's `num_threads` to the Pseudobulk
    # dataset
    return Pseudobulk(X=X_by_cell_type, obs=obs_by_cell_type,
                      var=var_by_cell_type, num_threads=self._num_threads)
[docs]
def hvg(self,
        *others: SingleCell,
        QC_column: SingleCellColumn | None |
                   Sequence[SingleCellColumn | None] = 'passed_QC',
        batch_column: SingleCellColumn | None |
                      Sequence[SingleCellColumn | None] = None,
        num_genes: int | np.integer = 2000,
        min_cells: int | np.integer = 3,
        exclude: str | Iterable[str] | None = None,
        flavor: Literal['seurat_v3', 'seurat_v3_paper'] = 'seurat_v3',
        span: int | float | np.integer | np.floating = 0.3,
        hvg_column: str = 'highly_variable',
        rank_column: str = 'highly_variable_rank',
        overwrite: bool = False,
        verbose: bool = True,
        num_threads: int | np.integer | None = None) -> \
        SingleCell | tuple[SingleCell, ...]:
    """
    Select highly variable genes using the same approach as Seurat.

    Operates on raw counts, so must be run before `normalize()` (but after
    `qc()`). When run with multiple datasets, only considers genes present
    in every dataset.

    By default, uses the same approach as Seurat's `FindVariableFeatures()`
    function, and Scanpy's `scanpy.pp.highly_variable_genes()` function
    with the `flavor` argument set to the non-default value `'seurat_v3'`.

    The general idea is that since genes with higher mean expression tend
    to have higher variance in expression (because they have more non-zero
    values), we want to select genes that have a high variance *relative to
    their mean expression*. Otherwise, we'd only be picking highly
    expressed genes! To correct for the mean-variance relationship, fit a
    LOESS curve to the mean-variance trend.

    Args:
        others: optional SingleCell datasets to jointly compute highly
                variable genes across, alongside this one. Each dataset
                will be treated as a separate batch. If `batch_column` is
                not `None`, each dataset AND each distinct value of
                `batch_column` within each dataset will be treated as a
                separate batch. Variances will be computed per batch and
                then aggregated (see `flavor`) across batches.
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored.
                   When `others` is specified, `QC_column` can be a
                   length-`1 + len(others)` sequence of columns,
                   expressions, Series, functions, or `None` for each
                   dataset (for `self`, followed by each dataset in
                   `others`).
        batch_column: an optional String, Enum, Categorical, or integer
                      column of `obs` indicating which batch each cell is
                      from. Can be a column name, a polars expression, a
                      polars Series, a 1D NumPy array, or a function that
                      takes in this SingleCell dataset and returns a polars
                      Series or 1D NumPy array. Each batch will be treated
                      as if it were a distinct dataset; this is exactly
                      equivalent to splitting the dataset with
                      `split_by(batch_column)` and then passing each of the
                      resulting datasets to `hvg()`, except that the
                      `min_cells` filter will always be calculated
                      per-dataset rather than per-batch. Variances will be
                      computed per batch and then aggregated (see `flavor`)
                      across batches. Set to `None` to treat each dataset
                      as having a single batch. When `others` is specified,
                      `batch_column` can be a length-`1 + len(others)`
                      sequence of columns, expressions, Series, functions,
                      or `None` for each dataset (for `self`, followed by
                      each dataset in `others`).
        num_genes: the number of highly variable genes to select. The
                   default of 2000 matches Seurat and Scanpy's recommended
                   value. Fewer than `num_genes` genes will be selected if
                   not enough genes have non-zero count in >= `min_cells`
                   cells (or when `min_cells` is 0, if not enough genes
                   are present).
        min_cells: filter to genes detected (with non-zero count) in >=
                   this many cells in every dataset, before calculating
                   highly variable genes. Set to 0 to disable this filter.
                   Must not exceed the number of cells in any dataset. The
                   default value of 3 matches Seurat and Scanpy's
                   recommended value. Note that genes with zero variance
                   in any dataset will always be filtered out, even if
                   `min_cells` is 0.
        exclude: one or more optional case-insensitive regular expressions
                 matching genes to exclude from the highly variable gene
                 calculation. For instance, to exclude mitochondrial genes
                 (starting with `'MT-'`) and ribosomal genes (starting with
                 `'RPL'`, `'RPS'`, `'MRPL'`, or `'MRPS'`), specify
                 `exclude=('^MT-', '^RPL', '^RPS', '^MRPL', '^MRPS')`.
        flavor: the highly variable gene algorithm to use. Must be one of
                `seurat_v3` and `seurat_v3_paper`, both of which match the
                algorithms with the same name in Scanpy. Both algorithms
                select genes based on two criteria: 1) which genes are
                ranked as most variable (taking the median of the ranks
                across batches where the gene is among the top `num_genes`
                highly variable genes) and 2) the number of batches in
                which a gene is ranked in among the top `num_genes` in
                variability. `seurat_v3` ranks genes by 1) and uses 2) to
                tiebreak, whereas `seurat_v3_paper` ranks genes by 2) and
                uses 1) to tiebreak. When there is only one batch, both
                algorithms are the same and only rank based on 1).
        span: the span of the LOESS fit; higher values will lead to more
              smoothing
        hvg_column: the name of a Boolean column to be added to (each
                    dataset's) `var` indicating the highly variable genes
        rank_column: the name of an integer column to be added to (each
                     dataset's) `var` with the rank of each highly variable
                     gene's variance (1 = highest variance, 2 =
                     next-highest, etc.); will be `null` for non-highly
                     variable genes. In the very unlikely event of ties,
                     the gene that appears first in `var` will get the
                     lowest rank.
        overwrite: if `True`, overwrite `hvg_column` and/or `rank_column`
                   if already present in `var`, instead of raising an error
        verbose: whether to print the number of genes present in every
                 dataset, when jointly computing highly variable genes
                 across multiple datasets
        num_threads: the number of threads to use when finding highly
                     variable genes. Set `num_threads=-1` to use all
                     available cores, as determined by `os.cpu_count()`, or
                     leave unset to use `self.num_threads` cores. Does not
                     affect the `num_threads` of the returned SingleCell
                     dataset(s); this will always be the same as the
                     `num_threads` of the original dataset(s).

    Returns:
        A new SingleCell dataset where `var` contains an additional Boolean
        column, `hvg_column` (default: `'highly_variable'`), indicating the
        `num_genes` most highly variable genes, and `rank_column` (default:
        'highly_variable_rank') indicating the (one-based) rank of each
        highly variable gene's variance, with `null` values for non-highly
        variable genes. Or, if additional SingleCell dataset(s) are
        specified via the `others` argument, a length-`1 + len(others)`
        tuple of SingleCell datasets with these two columns added: `self`,
        followed by each dataset in `others`.

    Raises:
        ValueError: if `X` is missing, a dataset is not QCed or is already
                    normalized, a dataset has fewer than 3 cells or fewer
                    than `min_cells` cells, `hvg_column`/`rank_column`
                    already exist and `overwrite=False`, `flavor` is
                    invalid, no genes are shared by all datasets, or the
                    LOESS fit fails.

    Note:
        This function may give an incorrect output if the count matrix
        contains explicit zeros (i.e. if `(sc.X.data == 0).any()`): this is
        not checked for, due to speed considerations. In the unlikely event
        that your dataset contains explicit zeros, remove them by running
        `sc.X.eliminate_zeros()` (an in-place operation) first.

    Note:
        This function may give an incorrect output if the count matrix
        contains negative values: this is not checked for, due to speed
        considerations.

    Note:
        This function may not give identical results to Seurat and Scanpy.
        It avoids floating-point summation, which is more numerically
        stable than Scanpy and Seurat's calculations. If multiple genes are
        tied as the `num_genes`-th most highly variable gene in a batch or
        dataset, this function includes all of them, whereas Seurat and
        Scanpy arbitrarily pick one (or a subset) of them. Also, this
        function uses the ordering from a stable sort to break ties when
        selecting the final list of highly variable genes, instead of the
        unstable sort used by Seurat and Scanpy.
    """
    from skmisc.loess import loess

    # If `others` was specified, check that all elements of `others` are
    # SingleCell datasets
    if others:
        check_types(others, 'others', SingleCell, 'SingleCell datasets')
    datasets = [self] + list(others)

    # Suffix appended to error messages when multiple datasets were given.
    # This must be defined before the first validation check below, since
    # every one of those checks interpolates it into its error message.
    suffix = ' for at least one dataset' if others else ''

    # Check that `X` is present for every dataset
    if any(dataset._X is None for dataset in datasets):
        error_message = (
            f'X is None{suffix}, so highly variable gene finding is not '
            f'possible')
        raise ValueError(error_message)

    # Check that all datasets are QCed and not normalized
    if not all(dataset._uns['QCed'] for dataset in datasets):
        error_message = (
            f"uns['QCed'] is False{suffix}; did you forget to run qc() "
            f"before hvg()? Set uns['QCed'] = True or run skip_qc() to "
            f"bypass this check.")
        raise ValueError(error_message)
    if any(dataset._uns['normalized'] for dataset in datasets):
        error_message = (
            f"hvg() requires raw counts but uns['normalized'] is "
            f"True{suffix}; did you already run normalize()?")
        raise ValueError(error_message)

    # Check that there are at least 3 cells in each dataset (since LOESS
    # seems to need at least three observations to converge)
    if any(len(dataset._obs) < 3 for dataset in datasets):
        error_message = (
            f'there are fewer than 3 cells{suffix}, so highly variable '
            f'genes cannot be calculated')
        raise ValueError(error_message)

    # Check that `overwrite` is Boolean
    check_type(overwrite, 'overwrite', bool, 'Boolean')

    # Check that `hvg_column` and `rank_column` are strings and, unless
    # `overwrite=True`, not already in `var` for any dataset
    for column, column_name in (hvg_column, 'hvg_column'), \
            (rank_column, 'rank_column'):
        check_type(column, column_name, str, 'a string')
        if not overwrite and \
                any(column in dataset._var for dataset in datasets):
            error_message = (
                f'{column_name} {column!r} is already a column of '
                f'var{suffix}; did you already run hvg()? Set '
                f'overwrite=True to overwrite.')
            raise ValueError(error_message)

    # Check that `num_genes` is a positive integer
    check_type(num_genes, 'num_genes', int, 'a positive integer')
    check_bounds(num_genes, 'num_genes', 1)

    # Check that `min_cells` is a non-negative integer and no larger than
    # the number of cells in each dataset. Report the size of the smallest
    # dataset, since that is the one that triggers the error.
    check_type(min_cells, 'min_cells', int, 'a non-negative integer')
    check_bounds(min_cells, 'min_cells', 0)
    min_num_cells = min(len(dataset._obs) for dataset in datasets)
    if min_num_cells < min_cells:
        error_message = (
            f'the number of cells ({min_num_cells:,}) is less than '
            f'min_cells ({min_cells:,}){suffix}; decrease min_cells')
        raise ValueError(error_message)

    # Check that `exclude`, if specified, is a string or sequence thereof
    if exclude is not None:
        exclude = to_tuple_checked(exclude, 'exclude', str, 'strings')

    # Check that `flavor` is 'seurat_v3' or 'seurat_v3_paper'
    check_type(flavor, 'flavor', str, 'a string')
    if flavor not in ('seurat_v3', 'seurat_v3_paper'):
        error_message = (
            f"flavor must be 'seurat_v3' or 'seurat_v3_paper', "
            f"not {flavor!r}")
        raise ValueError(error_message)

    # Check that `span` is a positive number
    check_type(span, 'span', (int, float), 'a positive number')
    check_bounds(span, 'span', 0, left_open=True)

    # Check that `verbose` is Boolean
    check_type(verbose, 'verbose', bool, 'Boolean')

    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)

    # Get `QC_column` and `batch_column` from every dataset, if not `None`
    QC_columns = SingleCell._get_columns(
        'obs', datasets, QC_column, 'QC_column', pl.Boolean,
        allow_missing=True)
    batch_columns = SingleCell._get_columns(
        'obs', datasets, batch_column, 'batch_column',
        (pl.String, pl.Enum, pl.Categorical, 'integer'),
        QC_columns=QC_columns)

    # Get the universe of genes we'll be considering: those present in all
    # datasets, and not matching `exclude` (if specified). If there are
    # multiple datasets, also get the indices of these genes in each
    # dataset (with `null` for genes not present in that particular
    # dataset). Raise an error if no genes are found in all datasets.
    if others:
        # The use of `align_frames` here is a bit wasteful memory-wise,
        # because it creates an identical `'gene'` column for every
        # DataFrame in `genes_and_indices`. Fortunately, it's only one
        # small string column per dataset.
        genes_and_indices = pl.align_frames(*(
            dataset.var[:, 0]
            .to_frame('gene')
            .pipe(lambda df: df.filter(~pl.col.gene.str.contains(
                      '(?i)' + '|'.join(exclude)))  # case-insensitive
                  if exclude is not None else df)
            .with_columns(_SingleCell_index=pl.int_range(pl.len(),
                                                         dtype=pl.UInt32))
            for dataset in datasets), on='gene', how='inner')
        genes_in_all_datasets = genes_and_indices[0]['gene']
        num_genes_in_all_datasets = len(genes_in_all_datasets)
        if num_genes_in_all_datasets == 0:
            error_message = 'no genes are present in every dataset'
            if exclude is not None:
                error_message += (
                    f', after applying the '
                    f'{plural("filter", len(exclude))} specified via the '
                    f'exclude argument')
            raise ValueError(error_message)
        if verbose:
            print(f'{num_genes_in_all_datasets:,} '
                  f'{plural("gene", num_genes_in_all_datasets)} are '
                  f'present in every dataset.')
        dataset_gene_indices = [df['_SingleCell_index']
                                for df in genes_and_indices]
        del genes_and_indices
    else:
        genes_in_all_datasets = self.var_names\
            .rename('gene')\
            .to_frame()\
            .pipe(lambda df: df.filter(~pl.col.gene.str.contains(
                      '(?i)' + '|'.join(exclude)))  # case-insensitive
                  if exclude is not None else df)\
            .to_series()

    # Get the batches to calculate variance across (datasets + batches
    # within each dataset). For CSR matrices with `batch_column`,
    # pre-compute per-batch cell indices in a single pass via polars
    # `partition_by`. For CSR, `cell_mask` is always `None` or an int64
    # array of cell indices; for CSC, `cell_mask` is always `None` or a
    # Boolean array.
    if others:
        if batch_column is None:
            batches = [(dataset._X,
                        None if dataset_QC_column is None else
                        np.flatnonzero(dataset_QC_column.to_numpy())
                        if isinstance(dataset._X, csr_array)
                        else dataset_QC_column.to_numpy(),
                        gene_indices)
                       for dataset, dataset_QC_column, gene_indices
                       in zip(datasets, QC_columns, dataset_gene_indices)]
        else:
            batches = []
            for dataset, dataset_QC_column, dataset_batch_column, \
                    gene_indices in zip(datasets, QC_columns, batch_columns,
                                        dataset_gene_indices):
                if dataset_batch_column is None:
                    batches.append((
                        dataset._X,
                        None if dataset_QC_column is None else
                        np.flatnonzero(dataset_QC_column.to_numpy())
                        if isinstance(dataset._X, csr_array)
                        else dataset_QC_column.to_numpy(),
                        gene_indices))
                elif isinstance(dataset._X, csr_array):
                    for partition in (
                            dataset_batch_column
                            .to_frame('_SingleCell_batch')
                            .with_columns(
                                _SingleCell_idx=pl.int_range(pl.len()))
                            .pipe(lambda df: df.filter(dataset_QC_column)
                                  if dataset_QC_column is not None else df)
                            .partition_by('_SingleCell_batch')):
                        batches.append((
                            dataset._X,
                            partition['_SingleCell_idx'].to_numpy(),
                            gene_indices))
                else:
                    for batch in dataset_batch_column.unique():
                        batches.append((
                            dataset._X,
                            (dataset_batch_column.eq(batch)
                             if dataset_QC_column is None else
                             dataset_batch_column.eq(batch) &
                             dataset_QC_column).to_numpy(),
                            gene_indices))
    else:
        X = self._X
        # From here on, `batch_column` is the resolved column for `self`
        # (or `None`), not the raw user-specified argument
        batch_column = batch_columns[0]
        if batch_column is None:
            if QC_column is not None and QC_columns[0] is not None:
                if isinstance(X, csr_array):
                    batches = (X, np.flatnonzero(
                        QC_columns[0].to_numpy()), None),
                else:
                    batches = (X, QC_columns[0].to_numpy(), None),
            else:
                batches = (X, None, None),
        elif isinstance(X, csr_array):
            df = batch_column\
                .to_frame('_SingleCell_batch')\
                .with_columns(_SingleCell_idx=pl.int_range(pl.len()))
            if QC_column is not None and QC_columns[0] is not None:
                df = df.filter(QC_columns[0])
            batches = [
                (X, partition['_SingleCell_idx'].to_numpy(), None)
                for partition in df.partition_by(
                    '_SingleCell_batch')]
        else:
            if QC_column is not None and QC_columns[0] is not None:
                batches = ((X, (batch_column.eq(batch) & QC_columns[0])
                            .to_numpy(), None)
                           for batch in batch_column.unique())
            else:
                batches = ((X, batch_column.eq(batch).to_numpy(), None)
                           for batch in batch_column.unique())

    # Get the variance of each gene in each batch across cells passing QC
    norm_gene_vars = []
    for X, cell_mask, gene_indices in batches:
        num_dataset_genes = X.shape[1]
        mean = np.empty(num_dataset_genes, dtype=np.float32)
        var = np.empty(num_dataset_genes, dtype=np.float32)
        nonzero_count = np.empty(num_dataset_genes, dtype=np.uint32)
        is_csr = isinstance(X, csr_array)
        if is_csr:
            # For CSR, `cell_mask` is either `None` (all cells) or a
            # pre-computed int64 array of cell indices
            if cell_mask is None:
                cell_indices = np.array([], dtype=np.int64)
                num_cells = X.shape[0]
            else:
                cell_indices = cell_mask
                num_cells = len(cell_indices)
            gene_mean_and_variance_csr(
                data=X.data, indices=X.indices, indptr=X.indptr,
                cell_indices=cell_indices, num_cells=num_cells,
                num_dataset_genes=num_dataset_genes, mean=mean, var=var,
                nonzero_count=nonzero_count, num_threads=num_threads)
        else:
            if cell_mask is None:
                cell_mask = np.array([], dtype=bool)
                num_cells = X.shape[0]
            else:
                num_cells = cell_mask.sum()
            gene_mean_and_variance_csc(
                data=X.data, indices=X.indices, indptr=X.indptr,
                cell_mask=cell_mask, num_cells=num_cells,
                num_dataset_genes=num_dataset_genes, mean=mean,
                var=var, nonzero_count=nonzero_count,
                num_threads=num_threads)

        # Fit the LOESS curve to the log mean-variance trend, using only
        # the genes with non-zero variance
        not_constant = var > 0
        y = np.log10(var[not_constant])
        x = np.log10(mean[not_constant])
        model = loess(x, y, span=span)
        try:
            model.fit()
        except ValueError as e:
            # Use the resolved per-dataset batch columns here, rather than
            # `batch_column`: with multiple datasets, `batch_column` is
            # still the raw user-specified argument (e.g. a column name),
            # which has no `value_counts()` method
            any_small_batches = any(
                dataset_batch_column is not None and
                dataset_batch_column.value_counts()['count'].min() < 500
                for dataset_batch_column in batch_columns)
            error_message = (
                f'LOESS model fitting failed; this usually only happens '
                f'when there are very few cells (e.g. under 500), which '
                f'{"is" if num_cells < 500 else "is not"} the case here, '
                f'or when batch_column is specified and certain batches '
                f'have very few cells (e.g. under 500), which '
                f'{"is" if any_small_batches else "is not"} the case here')
            raise ValueError(error_message) from e

        # `estimated_variance` is on the log10 scale; constant genes get
        # 0, i.e. an estimated standard deviation of 1
        estimated_variance = np.empty(num_dataset_genes, dtype=np.float32)
        estimated_variance[not_constant] = model.outputs.fitted_values
        estimated_variance[~not_constant] = 0
        estimated_stddev = np.sqrt(10 ** estimated_variance)
        clip_val = mean + estimated_stddev * np.sqrt(num_cells,
                                                     dtype=np.float32)
        batch_counts_sum = np.empty(num_dataset_genes, dtype=np.float32)
        squared_batch_counts_sum = np.empty(num_dataset_genes,
                                            dtype=np.float32)
        if is_csr:
            clipped_sum_csr(
                data=X.data, indices=X.indices, indptr=X.indptr,
                num_cells=num_cells, num_dataset_genes=num_dataset_genes,
                cell_indices=cell_indices, clip_val=clip_val,
                batch_counts_sum=batch_counts_sum,
                squared_batch_counts_sum=squared_batch_counts_sum,
                num_threads=num_threads)
        else:
            clipped_sum_csc(
                data=X.data, indices=X.indices, indptr=X.indptr,
                cell_mask=cell_mask, clip_val=clip_val,
                batch_counts_sum=batch_counts_sum,
                squared_batch_counts_sum=squared_batch_counts_sum,
                num_threads=num_threads)
        norm_gene_var = pl.Series(
            (1 / ((num_cells - 1) * np.square(estimated_stddev))) *
            ((num_cells * np.square(mean)) + squared_batch_counts_sum -
             2 * batch_counts_sum * mean))

        # If `min_cells` is non-zero, set variances to `null` for genes
        # with a non-zero count less than `min_cells`
        if min_cells:
            norm_gene_var = norm_gene_var\
                .set(pl.Series(nonzero_count < min_cells), None)

        # If there are multiple datasets, `norm_gene_var` is currently with
        # respect to the genes in `dataset.var_names`; map back to the
        # genes in `genes_in_all_datasets`, filling with `null`
        if others:
            norm_gene_var = norm_gene_var[gene_indices]
        norm_gene_vars.append(norm_gene_var)

    # Rank genes within each batch, then combine the per-batch ranks into
    # a final ranking according to `flavor`
    rank = pl.exclude('gene').rank('min', descending=True)
    final_rank = pl.struct(
        ('median_rank', 'nbatches') if flavor == 'seurat_v3' else
        ('nbatches', 'median_rank')).rank('ordinal', descending=True)
    # Note: the expression for `median_rank` can be replaced by
    # `pl.median_horizontal(pl.exclude('gene'))` once polars implements it
    hvgs = pl.DataFrame([genes_in_all_datasets] + norm_gene_vars)\
        .lazy()\
        .pipe(lambda df: df.drop_nulls(pl.selectors.exclude('gene'))
              if min_cells or others else df)\
        .with_columns(pl.when(rank <= num_genes).then(rank))\
        .with_columns(nbatches=pl.sum_horizontal(pl.exclude('gene')
                                                 .is_not_null()),
                      median_rank=-pl.concat_list(pl.exclude('gene'))
                      .explode()
                      .median().over(pl.int_range(
                          pl.len(), dtype=pl.UInt32)))\
        .select('gene', (final_rank <= num_genes).alias(hvg_column),
                pl.when(final_rank <= num_genes).then(final_rank)
                .alias(rank_column))\
        .collect()

    # Return a new SingleCell dataset (or a tuple of datasets, if others
    # is non-empty) containing the highly variable genes
    for dataset_index, dataset in enumerate(datasets):
        new_var = dataset._var\
            .pipe(lambda df: df.drop(hvg_column, rank_column, strict=False)
                  if overwrite else df)\
            .join(hvgs.rename({'gene': dataset.var_names.name}),
                  on=dataset.var_names.name, how='left')\
            .with_columns(pl.col(hvg_column).fill_null(False))
        datasets[dataset_index] = \
            SingleCell(X=dataset._X, obs=dataset._obs, var=new_var,
                       obsm=dataset._obsm, varm=dataset._varm,
                       uns=dataset._uns, num_threads=dataset._num_threads)
    return tuple(datasets) if others else datasets[0]
[docs]
def normalize(self,
              *,
              QC_column: SingleCellColumn | None = 'passed_QC',
              method: Literal['PFlog1pPF', 'log1pPF',
                              'logCP10k'] = 'log1pPF',
              inplace: bool = False,
              num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Normalize this SingleCell dataset's counts.

    Must be run after `hvg()` and before `pca()`.

    `normalize()` supports three normalization methods. All three methods
    normalize each cell independently of the rest and log-transform the
    counts in some way, but differ in the details.

    The simplest approach, `method='logCP10k'`, computes the log of the
    counts per 10 thousand: `normalized_count = log(count / 10000 + 1)`.
    It matches the default settings of Seurat's `NormalizeData()` function,
    aside from differences in floating-point error. This method is not
    recommended because it implicitly assumes an unrealistically large
    amount of overdispersion, and performs worse in the benchmarks of the
    papers discussed below.

    The next-simplest approach, `method='log1pPF'`, is the default. Instead
    of using the same denominator of 10 thousand for every cell, it
    uses `X.sum(axis=1) / X.sum(axis=1).mean()` as the denominator. In
    other words, a cell's denominator is the cell's library size, relative
    to the mean library size across all cells. (By library size, we mean
    the sum of a cell's counts across all genes.) This approach of dividing
    each cell's counts by its relative library size is sometimes called
    "proportional fitting" (PF).
    [Ahlmann-Eltze and Huber 2023](https://nature.com/articles/s41592-023-01814-1)
    recommend using this method instead of normalizing by a fixed
    denominator, like `method='logCP10k'` does. Scanpy's
    `normalize_total()` uses this method with a slight variation: it uses
    median instead of mean to define the relative library size.

    The most complex approach, `method='PFlog1pPF'`, takes the output of
    `log1pPF` and applies an additional round of proportional fitting after
    the log-transformation.
    [Booeshaghi et al. 2022](https://biorxiv.org/content/10.1101/2022.05.06.490859v1.full)
    recommend this approach, arguing that log1pPF does not
    fully normalize for read depth, because the log transform partially
    undoes the first round of proportional fitting.

    Args:
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will still be
                   normalized, but will not count towards the calculation
                   of the mean total count across cells when `method` is
                   `'PFlog1pPF'` or `'log1pPF'`. Has no effect when
                   `method` is `'logCP10k'`.
        method: the normalization method to use (see above)
        inplace: whether to do in-place normalization. This reduces memory
                 usage, but is only possible for float32 count matrices and
                 will raise an error if the count matrix has any other data
                 type.
        num_threads: the number of threads to use when normalizing. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores. Does not affect the
                     normalized SingleCell dataset's `num_threads`; this
                     will always be the same as the original dataset's
                     `num_threads`.

    Returns:
        A new SingleCell dataset with the normalized counts, and
        `uns['normalized']` set to `True`. Or when `inplace=True`, return
        the original dataset with the counts normalized in-place and its
        `uns['normalized']` set to `True`.

    Raises:
        ValueError: if `X` is `None`, `uns['QCed']` is `False`,
                    `uns['normalized']` is already `True`, or `method` is
                    not one of the three valid methods.
        TypeError: if `inplace=True` and `X` is not float32.
    """
    # Check that `X` is present
    if self._X is None:
        error_message = 'X is None, so normalizing is not possible'
        raise ValueError(error_message)

    # Check that `self` is QCed and not already normalized
    if not self._uns['QCed']:
        error_message = (
            "uns['QCed'] is False; did you forget to run qc() (and "
            "possibly hvg()) before normalize()? Set uns['QCed'] = True "
            "or run skip_qc() to bypass this check.")
        raise ValueError(error_message)
    if self._uns['normalized']:
        error_message = \
            "uns['normalized'] is True; did you already run normalize()?"
        raise ValueError(error_message)

    # Get the QC column, if not `None`. Only tolerate a missing column when
    # the caller left `QC_column` at its default name; the `isinstance`
    # guard keeps `==` from being evaluated elementwise (or raising) when
    # `QC_column` is a Series, array, or expression.
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=isinstance(QC_column, str) and
                          QC_column == 'passed_QC')

    # Check that `method` is one of the three valid methods
    if method == 'PFlog1pPF':
        method_number = 2
    elif method == 'log1pPF':
        method_number = 1
    else:
        if method != 'logCP10k':
            error_message = (
                "method must be one of 'PFlog1pPF', 'log1pPF', or "
                "'logCP10k'")
            raise ValueError(error_message)
        method_number = 0

    # Check that `inplace` is Boolean
    check_type(inplace, 'inplace', bool, 'Boolean')

    # When `inplace=True`, raise an error if the count matrix is not
    # float32
    if inplace and self._X.dtype != np.float32:
        error_message = (
            f'inplace normalization requires a float32 count matrix, but '
            f'X has data type {str(self._X.dtype)!r}')
        raise TypeError(error_message)

    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)

    # Pick the kernel and array type matching the sparse format of `X`
    X = self._X
    if isinstance(X, csr_array):
        normalize = normalize_csr
        sparse_array = csr_array
    else:
        normalize = normalize_csc
        sparse_array = csc_array

    # An empty Boolean array tells the kernel that there is no QC mask
    QC_mask = QC_column.to_numpy() if QC_column is not None else \
        np.array([], dtype=bool)

    num_cells = X.shape[0]
    if num_threads == 1:
        row_sums = np.empty(num_cells, dtype=np.uint64)
    else:
        row_sums = numa_zeros(num_cells, dtype=np.uint64)

    if inplace:
        # Write the normalized values straight back into `X.data`
        normalize(data=X.data, indices=X.indices, indptr=X.indptr,
                  QC_column=QC_mask, normalized_data=X.data,
                  row_sums=row_sums, num_cells=num_cells,
                  method_number=method_number, num_threads=num_threads)
        # Record that the counts are now normalized, so that a second
        # `normalize()` call is rejected and downstream steps that check
        # this flag see the correct state
        self._uns['normalized'] = True
        return self

    if num_threads == 1:
        normalized_data = np.empty(len(X.data), dtype=np.float32)
    else:
        normalized_data = numa_zeros(len(X.data), dtype=np.float32)
    original_num_threads = X._num_threads
    normalize(data=X.data, indices=X.indices, indptr=X.indptr,
              QC_column=QC_mask, normalized_data=normalized_data,
              row_sums=row_sums, num_cells=num_cells,
              method_number=method_number, num_threads=num_threads)
    # The new matrix shares `indices` and `indptr` with the original
    X = sparse_array((normalized_data, X.indices, X.indptr),
                     shape=X.shape)
    X._num_threads = original_num_threads
    # NOTE(review): presumably the SingleCell constructor copies `uns`, so
    # that setting the flag below does not also flip it on `self` - confirm
    sc = SingleCell(X=X, obs=self._obs, var=self._var, obsm=self._obsm,
                    varm=self._varm, uns=self._uns,
                    num_threads=self._num_threads)
    sc._uns['normalized'] = True
    return sc
[docs]
def pca(self,
        *others: SingleCell,
        QC_column: SingleCellColumn | None |
                   Sequence[SingleCellColumn | None] = 'passed_QC',
        hvg_column: SingleCellColumn |
                    Sequence[SingleCellColumn] | None = 'highly_variable',
        PC_key: str = 'pca',
        num_PCs: int | np.integer = 50,
        subspace_size: int | np.integer = 100,
        tolerance: int | np.integer | float | np.floating = 1e-6,
        max_iterations: int | np.integer = 100,
        chunk_size: int | np.integer = 1024,
        seed: int | np.integer = 0,
        match_parallel: bool = False,
        overwrite: bool = False,
        verbose: bool = True,
        num_threads: int | np.integer | None = None) -> \
        SingleCell | tuple[SingleCell, ...]:
    """
    Compute principal components (PCs) across cells.

    Requires normalized counts, so must be run after `normalize()`. By
    default, only the highly variable genes from `hvg()` are used to
    compute PCs.

    Uses approximate singular value decomposition (SVD) via the Implicitly
    Restarted Lanczos Bidiagonalization Algorithm (IRLBA). Seurat uses a
    different implementation of the same IRLBA algorithm.

    Args:
        others: optional SingleCell datasets to jointly compute principal
                components across, alongside this one.
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored
                   and have their PCs set to `NaN`. When `others` is
                   specified, `QC_column` can be a length-`1 + len(others)`
                   sequence of columns, expressions, Series, functions, or
                   `None` for each dataset (for `self`, followed by each
                   dataset in `others`).
        hvg_column: an optional Boolean column of `var` indicating the
                    highly variable genes. Set to `None` to include all
                    genes. Can be a column name, a polars expression, a
                    polars Series, a 1D NumPy array, or a function that
                    takes in this SingleCell dataset and returns a polars
                    Series or 1D NumPy array. When `others` is specified,
                    `hvg_column` can be a length-`1 + len(others)` sequence
                    of columns, expressions, Series, functions, or `None`
                    for each dataset (for `self`, followed by each dataset
                    in `others`).
        PC_key: the key of `obsm` where the principal components will be
                stored
        num_PCs: the number of top principal components to calculate
        subspace_size: the size of the Krylov subspace used by IRLBA when
                       calculating PCs. Must be greater than or equal to
                       `num_PCs`, and about twice `num_PCs` is recommended.
        tolerance: the relative tolerance (expressed as the ratio of a
                   singular value's residual to the maximum singular value)
                   required to deem a singular value converged. IRLBA will
                   stop early, before `max_iterations` iterations, if all
                   singular values have converged.
        max_iterations: the maximum number of iterations to run IRLBA for,
                        stopping early if all singular values have
                        converged (see `tolerance`)
        chunk_size: the number of rows per fixed block in deterministic
                    parallel reductions. Used to parallelize operations
                    like mean and norm that would otherwise be serial.
                    Block boundaries are fixed regardless of thread count,
                    ensuring floating-point identical results.
        seed: the random seed to use when initializing the PCs, via R's
              `set.seed()` function
        match_parallel: if `False`, use a different order of operations for
                        single-threaded PCA. This gives a moderate (~60%)
                        boost in single-threaded performance, and lower
                        memory usage, at the cost of no longer exactly
                        matching the PCs produced by the multithreaded
                        version (due to differences in floating-point error
                        arising from the different order of operations).
                        When `match_parallel=False`, `PCA()` will also give
                        slightly different results when run with CSR vs CSC
                        input; when multiple datasets are provided with a
                        mix of CSR and CSC formats, all datasets are
                        converted to the format shared by the most total
                        cells across datasets. If `True`, exactly match the
                        results of the multithreaded version when
                        `num_threads=1`. Must be `False` unless
                        `num_threads=1`.
        overwrite: if `True`, overwrite `PC_key` if already present in
                   obsm, instead of raising an error
        verbose: whether to print a message when the singular values did
                 not converge to a tolerance of `tolerance` within
                 `max_iterations` iterations
        num_threads: the number of threads to use for PCA. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores. Does not affect the returned
                     SingleCell dataset's `num_threads`; this will always
                     be the same as the original dataset's `num_threads`.

    Returns:
        A new SingleCell dataset where `obsm` contains an additional key,
        `PC_key` (default: `'pca'`), containing the top `num_PCs` principal
        components. Or, if additional SingleCell dataset(s) are specified
        via the `others` argument, a length-`1 + len(others)` tuple of
        SingleCell datasets with the PCs added: `self`, followed by each
        dataset in `others`. Each returned dataset keeps its own `X`,
        `obs`, `var`, `varm`, `uns` and `num_threads`.

    Raises:
        ValueError: if `PC_key` is already in `obsm` and `overwrite` is
                    `False`; if `X` is `None`; if any dataset is not
                    normalized; if `subspace_size < num_PCs`; if
                    `match_parallel=True` with `num_threads != 1`; or if
                    there are fewer cells or genes than `num_PCs`, or only
                    one cell or one gene.
        TypeError: if `X` does not have data type float32.

    Note:
        Unlike Seurat's `RunPCA()` function, which requires `ScaleData()`
        to be run first, this function does not require the data to be
        scaled beforehand. Instead, it implicitly scales the data to zero
        mean and unit variance while performing PCA.
    """
    # If `others` was specified, check that all elements of `others` are
    # SingleCell datasets
    if others:
        check_types(others, 'others', SingleCell, 'SingleCell datasets')
    datasets = [self] + list(others)

    # Check that `PC_key` is a string
    check_type(PC_key, 'PC_key', str, 'a string')

    # Check that `overwrite` is Boolean
    check_type(overwrite, 'overwrite', bool, 'Boolean')

    # Check that `PC_key` is not already in `obsm`, unless `overwrite=True`
    suffix = ' for at least one dataset' if others else ''
    for dataset in datasets:
        if not overwrite and PC_key in dataset._obsm:
            error_message = (
                f'PC_key {PC_key!r} is already a key of obsm{suffix}; did '
                f'you already run PCA()? Set overwrite=True to overwrite.')
            raise ValueError(error_message)

    # Check that `X` is present for every dataset
    if any(dataset._X is None for dataset in datasets):
        error_message = (
            f'X is None{suffix}, so finding principal components is not '
            f'possible')
        raise ValueError(error_message)

    # Get `QC_column` and `hvg_column` (if not `None`) from every dataset
    QC_columns = SingleCell._get_columns(
        'obs', datasets, QC_column, 'QC_column', pl.Boolean,
        allow_missing=True)
    hvg_columns = SingleCell._get_columns(
        'var', datasets, hvg_column, 'hvg_column', pl.Boolean,
        custom_error=f'hvg_column {{}} is not a column of var{suffix}; '
                     f'did you forget to run hvg() (and possibly '
                     f'normalize()) before PCA()?')

    # Check that all datasets are normalized
    if not all(dataset._uns['normalized'] for dataset in datasets):
        error_message = (
            f"PCA() requires normalized counts but uns['normalized'] is "
            f"False{suffix}; did you forget to run normalize() before "
            f"PCA()?")
        raise ValueError(error_message)

    # Raise an error if `X` is not float32 for every dataset
    for dataset in datasets:
        if dataset._X.dtype != np.float32:
            error_message = (
                f'PCA() requires normalized counts with data type '
                f'float32, but X has data type '
                f'{str(dataset._X.dtype)!r}{suffix}; did you forget to '
                f'run normalize() before PCA()?')
            raise TypeError(error_message)

    # Check that `num_PCs` is a positive integer
    check_type(num_PCs, 'num_PCs', int, 'a positive integer')
    check_bounds(num_PCs, 'num_PCs', 1)

    # Check that `subspace_size` is a positive integer, and >= `num_PCs`
    # (positivity follows from `subspace_size >= num_PCs >= 1`)
    check_type(subspace_size, 'subspace_size', int, 'a positive integer')
    if subspace_size < num_PCs:
        error_message = (
            f'subspace_size is {subspace_size:,}, but must be ≥ num_PCs '
            f'({num_PCs:,})')
        raise ValueError(error_message)

    # Check that `tolerance` is a positive number
    check_type(tolerance, 'tolerance', (int, float), 'a positive number')
    check_bounds(tolerance, 'tolerance', 0, left_open=True)

    # Check that `max_iterations` is a positive integer
    check_type(max_iterations, 'max_iterations', int, 'a positive integer')
    check_bounds(max_iterations, 'max_iterations', 1)

    # Check that `chunk_size` is a positive integer
    check_type(chunk_size, 'chunk_size', int, 'a positive integer')
    check_bounds(chunk_size, 'chunk_size', 1)

    # Check that `seed` is an integer
    check_type(seed, 'seed', int, 'an integer')

    # Check that `verbose` is Boolean
    check_type(verbose, 'verbose', bool, 'Boolean')

    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to
    # `os.cpu_count()`.
    num_threads = self._process_num_threads(num_threads)

    # Check that `match_parallel` is Boolean, and `False` unless
    # `num_threads=1`
    check_type(match_parallel, 'match_parallel', bool, 'Boolean')
    if match_parallel and num_threads != 1:
        error_message = \
            'match_parallel must be False unless num_threads is 1'
        raise ValueError(error_message)

    # Get the matrix to compute PCA across: a sparse matrix of counts for
    # highly variable genes (or all genes, if `hvg_column` is `None`)
    # across cells passing QC. Use `X[np.ix_(rows, columns)]` as a faster,
    # more memory-efficient alternative to `X[rows][:, columns]`.
    # `self._X._num_threads` is temporarily overridden while subsetting and
    # always restored in the `finally` clause.
    original_num_threads = self._X._num_threads
    try:
        self._X._num_threads = num_threads
        if others:
            # Multiple datasets: restrict to genes shared across datasets
            # and look up each shared gene's column index per dataset.
            # NOTE(review): `is_in(pl.concat(...))` keeps genes of `self`
            # that appear in *any* of `others`; with more than one other
            # dataset, a gene missing from one of them would get a null
            # index from the left join below. Confirm whether a strict
            # intersection across all datasets is intended.
            if hvg_column is None:
                genes_in_all_datasets = self.var_names.filter(
                    self.var_names.is_in(pl.concat(
                        [dataset.var_names for dataset in others])))
            else:
                hvg_in_self = \
                    self._var.filter(hvg_columns[0]).to_series() \
                    if hvg_columns[0] is not None else \
                    self._var.to_series()
                genes_in_all_datasets = hvg_in_self.filter(
                    hvg_in_self.is_in(pl.concat([
                        dataset._var.filter(hvg_col).to_series()
                        if hvg_col is not None else
                        dataset._var.to_series()
                        for dataset, hvg_col in
                        zip(others, hvg_columns[1:])])))
            # Lazily computed per-dataset column indices of the shared
            # genes; consumed exactly once by the `zip()` calls below
            gene_indices = (
                genes_in_all_datasets
                .to_frame()
                .join(dataset._var.with_columns(
                          _SingleCell_index=pl.int_range(
                              pl.len(), dtype=pl.Int32)),
                      left_on=genes_in_all_datasets.name,
                      right_on=dataset.var_names.name, how='left')
                ['_SingleCell_index']
                .to_numpy()
                for dataset in datasets)
            if QC_column is None:
                Xs = [dataset._X[:, genes]
                      for dataset, genes in zip(datasets, gene_indices)]
            else:
                Xs = [dataset._X[np.ix_(QC_col.to_numpy(), genes)]
                      if QC_col is not None else dataset._X[:, genes]
                      for dataset, genes, QC_col in
                      zip(datasets, gene_indices, QC_columns)]
        else:
            # Single dataset: subset by whichever of the QC and HVG masks
            # are available
            if QC_column is None:
                if hvg_column is None:
                    Xs = [dataset._X for dataset in datasets]
                else:
                    Xs = [dataset._X[:, hvg_col.to_numpy()]
                          if hvg_col is not None else dataset._X
                          for dataset, hvg_col in
                          zip(datasets, hvg_columns)]
            else:
                if hvg_column is None:
                    Xs = [dataset._X[QC_col.to_numpy()]
                          if QC_col is not None else dataset._X
                          for dataset, QC_col in
                          zip(datasets, QC_columns)]
                else:
                    Xs = [(dataset._X[np.ix_(QC_col.to_numpy(),
                                             hvg_col.to_numpy())]
                           if QC_col is not None else
                           dataset._X[:, hvg_col.to_numpy()])
                          if hvg_col is not None else
                          (dataset._X[QC_col.to_numpy()]
                           if QC_col is not None else dataset._X)
                          for dataset, QC_col, hvg_col in
                          zip(datasets, QC_columns, hvg_columns)]
    finally:
        self._X._num_threads = original_num_threads

    # Stack the per-dataset matrices along the cell axis
    num_cells_per_dataset = np.array([X.shape[0] for X in Xs])
    if len(Xs) == 1:
        X = Xs[0]
    elif all(isinstance(X, csr_array) for X in Xs):
        X = sparse_major_stack(Xs, num_threads=num_threads)
    elif all(isinstance(X, csc_array) for X in Xs):
        X = sparse_minor_stack(Xs, num_threads=num_threads)
    else:
        # Mix of CSR and CSC: convert to whichever format has the most
        # total cells in that format, to reduce the number of datasets that
        # need to be flipped. Compare `2 * total_csr > total_cells` rather
        # than dividing, so that zero total cells does not trigger a 0 / 0.
        total_cells = num_cells_per_dataset.sum()
        total_csr = sum(X.shape[0] for X in Xs if isinstance(X, csr_array))
        if 2 * total_csr > total_cells:  # more CSR than CSC
            X = sparse_major_stack([
                X_i.tocsr() if isinstance(X_i, csc_array) else X_i
                for X_i in Xs], num_threads=num_threads)
        else:
            X = sparse_minor_stack([
                X_i.tocsc() if isinstance(X_i, csr_array) else X_i
                for X_i in Xs], num_threads=num_threads)
    del Xs

    # Check that the number of cells and genes are >= `num_PCs`, and that
    # there are at least two cells and two genes even if `num_PCs` is 1
    num_cells, num_genes = X.shape
    if num_cells < num_PCs:
        error_message = (
            f'num_PCs is {num_PCs:,}, but must be ≤ the number of cells '
            f'({num_cells:,})')
        raise ValueError(error_message)
    if num_genes < num_PCs:
        error_message = (
            f'num_PCs is {num_PCs:,}, but must be ≤ the number of genes '
            f'({num_genes:,})')
        raise ValueError(error_message)
    if num_cells == 1:
        error_message = (
            f'there is only one cell, so principal components cannot be '
            f'calculated')
        raise ValueError(error_message)
    if num_genes == 1:
        error_message = (
            f'there is only one gene, so principal components cannot be '
            f'calculated')
        raise ValueError(error_message)

    # We want to perform SVD on `X`, scaled to zero mean and unit variance.
    # Key ideas:
    # 1. Because `X` is a sparse matrix, mean-centering cannot be done
    #    without converting to a dense matrix. So scaling `X` cannot be
    #    done in the conventional way.
    # 2. Fortunately, we can represent scaling as a matrix product:
    #    `scale(X) = C @ X @ inv(W)`, where `W` is a diagonal matrix of the
    #    standard deviations for each column (gene) and `C` is a "centering
    #    matrix" that, when applied to any vector or matrix, yields the
    #    mean-centered version of it.
    # 3. We need to calculate the matrix-vector product of our operator
    #    with some vector `V`, i.e. `scale(X) @ V`. Using the formula from
    #    point #2, this is equivalent to `C @ X @ inv(W) @ V`. Since `W` is
    #    diagonal, this is equivalent to `C @ (X @ (V / diag(W)))`. In
    #    other words:
    #    - divide `V` (which has length `num_genes`) elementwise by
    #      `diag(W)`, the genewise standard deviations
    #    - matrix-vector multiply by `X`
    #    - mean-center the resulting vector, which has length `num_cells`
    # 4. We also need to calculate `scale(X).T @ V` for the `rmatvec()`
    #    part of our operator. Rewriting as `inv(W).T @ X.T @ C.T @ V` and
    #    leveraging the fact that `C` turns out to be symmetric (as is `W`,
    #    since it's diagonal), this is equivalent to
    #    `(X.T @ (C @ V)) / diag(W)`. In other words:
    #    - mean-center `V`, which has length `num_cells`
    #    - matrix-vector multiply by `X.T`
    #    - divide the result (which has length `num_genes`) elementwise by
    #      `diag(W)`, the genewise standard deviations
    # 5. CSR matrix-vector multiplication can be done efficiently
    #    multithreaded, but CSC can't (except by maintaining thread-local
    #    versions of the output vector and summing them across threads at
    #    the end, which leads to differences in floating-point error
    #    depending on the number of threads). This is problematic because,
    #    regardless of whether `X` is a CSR or a CSC matrix, we need to do
    #    some multiplications involving `X.T` and others involving `X`. So
    #    if `X` is CSR, the `X.T` multiplications will be single-threaded
    #    since `X.T` is a CSC matrix. If `X` is CSC, the `X`
    #    multiplications will be single-threaded. So one of the two
    #    multiplications will always be single-threaded. There are hundreds
    #    of these matrix-vector multiplications and they take up a large
    #    majority of the total runtime for PCA, so not being able to fully
    #    multithread them is a huge disadvantage.
    # 6. To address the issue in the previous point, make both a CSR and a
    #    CSC copy of `X` when running PCA in parallel. The first
    #    multiplication (involving `X.T`) will use the CSC copy of `X`, but
    #    plugged into the CSR matrix-vector multiplication function. The
    #    second multiplication (involving `X`) will use CSR multiplication
    #    as normal. This works because plugging a CSC copy of `X` into a
    #    matrix-vector multiplication routine designed for CSR matrices (or
    #    vice versa) is equivalent to multiplying by `X.T` instead of `X`.
    #    The result: both matrix multiplications can be done in parallel.
    #    This also has the nice side benefit that the final result is the
    #    same regardless of whether the counts are input as CSR or CSC.
    # 7. When running single-threaded with `match_parallel=False`,
    #    however, just use whichever version of `X` (CSR or CSC) we happen
    #    to have available, to avoid the runtime and memory overhead of
    #    creating both versions. However, this means that CSR and CSC no
    #    longer give exactly the same PCs, due to differences in
    #    floating-point error.
    # 8. When calculating the scaled variance for each gene, use the CSC
    #    version of `X` if available, for speed. Clip tiny standard
    #    deviations (less than 1e-8) to 1e-8, like Seurat.
    # 9. In the parallel code path, all U columns produced by the matvec
    #    are zero-mean by construction (the centering matrix is applied).
    #    Reorthogonalization subtracts a linear combination of previous
    #    zero-mean columns, preserving zero-mean. Scaling by 1/alpha
    #    preserves zero-mean. This invariant enables two optimizations:
    #    (a) the matvec can skip mean-centering and defer it to a fused
    #    pass that also performs reorthogonalization and norm computation;
    #    (b) the rmatvec can skip mean-centering entirely since its input
    #    (a U column) is already zero-mean.
    #    Note: this optimization changes the floating-point result versus
    #    the single-threaded path (which does redundant mean-centering),
    #    but match_parallel=True is not allowed when num_threads > 1, so
    #    the parallel path only has to be self-consistent across thread
    #    counts.

    # Get the data, indices and indptr to be used for the matvec and
    # rmatvec. Unless `num_threads=1` and `match_parallel=False`, this
    # requires calculating the "opposite" version of the array (CSC if `X`
    # is CSR, CSR if CSC).
    is_csr = isinstance(X, csr_array)
    reuse_single_format = num_threads == 1 and not match_parallel
    if is_csr:
        # `X` is CSR, so always use CSR for the matvec. If `num_threads=1`
        # and `match_parallel=False`, use CSR for the rmatvec as well,
        # otherwise use CSC for the rmatvec to enable parallelism.
        data_matvec = X.data
        indices_matvec = X.indices
        indptr_matvec = X.indptr
        if reuse_single_format:
            data_rmatvec = data_matvec
            indices_rmatvec = indices_matvec
            indptr_rmatvec = indptr_matvec
        else:
            X_csc = X.tocsc()
            data_rmatvec = X_csc.data
            indices_rmatvec = X_csc.indices
            indptr_rmatvec = X_csc.indptr
    else:
        # `X` is CSC, so always use CSC for the rmatvec. If `num_threads=1`
        # and `match_parallel=False`, use CSC for the matvec as well,
        # otherwise use CSR for the matvec to enable parallelism.
        data_rmatvec = X.data
        indices_rmatvec = X.indices
        indptr_rmatvec = X.indptr
        if reuse_single_format:
            data_matvec = data_rmatvec
            indices_matvec = indices_rmatvec
            indptr_matvec = indptr_rmatvec
        else:
            X_csr = X.tocsr()
            data_matvec = X_csr.data
            indices_matvec = X_csr.indices
            indptr_matvec = X_csr.indptr

    # Run PCA with irlba. Use `threadpool_limits()` to run BLAS
    # single-threaded when `num_threads=1`. (Unlike the other functions in
    # the library that use `threadpool_limits()`, here the multi-threaded
    # code path contains BLAS calls outside `prange()`, and these need to
    # be explicitly limited to one thread with `threadpool_limits(1)`
    # inside Cython.)
    if num_threads == 1:
        PCs = np.empty((num_cells, num_PCs), dtype=np.float32)
    else:
        PCs = numa_zeros((num_cells, num_PCs), dtype=np.float32)
    # `X` may be `self._X` itself (when no subsetting was needed), so its
    # `_num_threads` must be restored afterwards
    original_num_threads = X._num_threads
    X._num_threads = num_threads
    try:
        with threadpool_limits(num_threads):
            converged = irlba(
                data_matvec=data_matvec, indices_matvec=indices_matvec,
                indptr_matvec=indptr_matvec, data_rmatvec=data_rmatvec,
                indices_rmatvec=indices_rmatvec,
                indptr_rmatvec=indptr_rmatvec, num_cells=num_cells,
                num_genes=num_genes, k=num_PCs,
                subspace_size=subspace_size, tolerance=tolerance,
                max_iterations=max_iterations, seed=seed,
                match_parallel=match_parallel, is_csr=is_csr,
                num_threads=num_threads, chunk_size=chunk_size, PCs=PCs)
    finally:
        X._num_threads = original_num_threads

    # Print a message if PCs did not converge and `verbose=True`
    if not converged and verbose:
        print(f'PCA did not converge to a tolerance of {tolerance:.2g} '
              f'after {max_iterations:,} iterations; consider increasing '
              f'max_iterations or tolerance')

    # Store each dataset's PCs in its `obsm`
    if not others:  # just one dataset
        QC_col = QC_columns[0]
        # If `QC_col` is not `None`, back-project from QCed cells to all
        # cells, filling with `NaN`
        if QC_col is not None:
            PCs_QCed = PCs
            PCs = np.full((len(self), PCs_QCed.shape[1]), np.nan,
                          dtype=np.float32)
            PCs[QC_col.to_numpy()] = PCs_QCed
        return SingleCell(
            X=self._X, obs=self._obs, var=self._var,
            obsm=self._obsm | {PC_key: PCs}, varm=self._varm,
            uns=self._uns, num_threads=self._num_threads)

    # Multiple datasets: rows of `PCs` are the datasets' (QCed) cells,
    # concatenated in order, so slice out each dataset's rows in turn
    results = []
    end_indices = num_cells_per_dataset.cumsum()
    for dataset, QC_col, dataset_num_cells, end_index in zip(
            datasets, QC_columns, num_cells_per_dataset, end_indices):
        start_index = end_index - dataset_num_cells
        dataset_PCs = PCs[start_index:end_index]
        # If `QC_col` is not `None` for this dataset, back-project from
        # QCed cells to all cells, filling with `NaN`
        if QC_col is not None:
            dataset_PCs_QCed = dataset_PCs
            dataset_PCs = np.full(
                (len(dataset), dataset_PCs_QCed.shape[1]), np.nan,
                dtype=np.float32)
            dataset_PCs[QC_col.to_numpy()] = dataset_PCs_QCed
        # Each dataset keeps its own `varm`, `uns` and `num_threads`;
        # only `obsm` gains the new `PC_key` entry
        results.append(SingleCell(
            X=dataset._X, obs=dataset._obs, var=dataset._var,
            obsm=dataset._obsm | {PC_key: dataset_PCs},
            varm=dataset._varm, uns=dataset._uns,
            num_threads=dataset._num_threads))
    return tuple(results)
[docs]
def neighbors(self,
*,
QC_column: SingleCellColumn | None = 'passed_QC',
PC_key: str = 'pca',
neighbors_key: str = 'neighbors',
distances_key: str = 'distances',
num_neighbors: int | np.integer = 20,
num_clusters: int | np.integer | None = None,
num_clusters_searched: int | np.integer | None = None,
num_kmeans_iterations: int | np.integer = 2,
kmeans_tolerance: int | np.integer | float |
np.floating = 1e-2,
kmeans_barbar: bool = False,
num_init_iterations: int | np.integer = 5,
oversampling_factor: int | np.integer | float |
np.floating = 1,
chunk_size_kmeans: int | np.integer | None = None,
chunk_size_search: int | np.integer | None = None,
seed: int | np.integer = 0,
overwrite: bool = False,
verbose: bool = True,
num_threads: int | np.integer | None = None) -> SingleCell:
"""
Calculate the `num_neighbors` nearest neighbors of each cell.
`neighbors()` is intended to be run after `PCA()`; by default, it uses
`obsm['pca']` as the input to the nearest-neighbors calculation.
`neighbors()` must be re-run if the dataset is subset; not doing so
will raise an error. A cell is not considered its own nearest neighbor.
`neighbors()` is based on the `IndexIVFFlat` search strategy from the
widely-used `faiss` nearest-neighbor search library. It works in two
main steps:
1) Perform k-means clustering to subdivide the dataset into
`num_clusters` clusters. For large datasets, the number of clusters
is four times the square root of the number of cells, rounding up;
for small datasets, 1% of the number of cells.
2) Find each cell's `num_clusters_searched` (default: 64) nearest
cluster centroids (one of which is the cell's own cluster), then
search these clusters exhaustively for nearest-neighbor candidates.
The key optimization: instead of going one cell at a time and searching
its 64 nearest clusters, group cells into batches of
`chunk_size_search` (default: 256), enumerate which clusters are one of
the 64 nearest to at least one cell in the batch, then run the searches
across each of these clusters, one cluster at a time. Since cells in a
block will tend to share many of their nearest 64 clusters, this
amortizes the cost of streaming the cluster's cells from memory across
all the cells in the block that need them, speeding up the search by an
order of magnitude for large datasets.
The key parameters that affect accuracy and runtime are `num_clusters`
and `num_clusters_searched`. Runtime goes up about linearly with the
fraction of cells searched, which is roughly
`num_clusters_searched / num_clusters`. Accuracy will also increase,
but with strongly diminishing returns. We recommend increasing
`num_clusters_searched` (e.g. to 128) if greater accuracy is desired,
and decreasing it (e.g. to 32) if greater speed is desired.
Args:
QC_column: an optional Boolean column of `obs` indicating which
cells passed QC. Can be a column name, a polars
expression, a polars Series, a 1D NumPy array, or a
function that takes in this SingleCell dataset and
returns a polars Series or 1D NumPy array. Set to `None`
to include all cells. Cells failing QC will be ignored
and have their nearest neighbors set to `-1`.
PC_key: the key of `obsm` containing the principal components
calculated with `PCA()`, to use as the input for the
nearest-neighbors calculation
neighbors_key: the key of `obsm` where the nearest neighbors will
be stored
distances_key: the key of `obsm` where the squared Euclidean
distance to each nearest neighbor will be stored
num_neighbors: the number of nearest neighbors to report for each
cell; must be less than or equal to the number of
cells
num_clusters: the number of k-means clusters to use during the
nearest-neighbor search. Must be less than the number
of cells. If `None`, will be set to
`ceil(min(4 * sqrt(num_cells), num_cells / 100))`
clusters, i.e. the minimum of four times the square
root of the number of cells in `other` and 1% of the
number of cells in `other`, rounding up. The core of
the heuristic, `4 * sqrt(num_cells)`, is the low end
of the range
[recommended](https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index)
by faiss, 4 to 16 times the square root. However,
faiss also recommends using between 39 and 256 data
points per centroid when training the k-means
clustering used in the k-nearest neighbors search. To
avoid going below 39, we switch to using
`num_cells / 100` centroids for small datasets (fewer
than 640,000 cells), since 100 is the midpoint of 39
and 256 in log space. For datasets of at least 10,000
cells, clusters with fewer than 100 cells will be
merged into the adjacent cluster with the nearest
centroid, so the actual number of clusters used may
be smaller.
num_clusters_searched: the number of a cell's nearest clusters to
search; must be between 1 and
`num_clusters`. Defaults to
`min(64, num_clusters)`.
num_kmeans_iterations: the maximum number of iterations of k-means
clustering to perform before starting the
nearest-neighbor search, stopping early if a
relative convergence of `kmeans_tolerance`
is reached
kmeans_tolerance: the relative change in inertia (the sum of
squared distances from each cell to its assigned
centroid) used to determine whether to stop
optimizing the k-means clustering before
`num_kmeans_iterations` iterations
kmeans_barbar: whether to use
[k-means||](https://arxiv.org/abs/1203.6402)
initialization (a parallel version of k-means++)
to initialize the k-means clustering centroids,
instead of random initialization. This is more
accurate but takes considerably longer for large
datasets and `num_clusters`.
num_init_iterations: the number of k-means|| iterations used to
initialize the k-means clustering that
constitutes the first step of the
nearest-neighbor search. k-means|| is a
parallel version of the widely used k-means++
initialization scheme for k-means clustering.
The default value of 5 is recommended by the
[k-means|| paper](https://arxiv.org/abs/1203.6402).
Only used when `kmeans_barbar=True`.
oversampling_factor: the number of candidate centroids selected, on
average, at each of the `num_init_iterations`
iterations of k-means||, as a multiple of
`num_clusters`. The default value of 1 is the
midpoint (in log space) of the values explored
by the
[k-means|| paper](https://arxiv.org/abs/1203.6402),
namely 0.1 to 10. The total number of candidate
centroids selected, on average, will be
`oversampling_factor * num_clusters + 1`, from
which the final `num_clusters` centroids will
then be selected via k-means++. Only used when
`kmeans_barbar=True`.
chunk_size_kmeans: the chunk size used for distance calculations
during k-means clustering, and also during the
per-query centroid ranking step of the
nearest-neighbor search. Setting this to a power
of 2 is recommended. Defaults to
`min(4096, num_cells)`.
chunk_size_search: the chunk size used to group query cells
together during the nearest-neighbor search.
Overly small values will tend to increase
runtime by reducing the reuse of information
during the search, whereas overly large values
will lead to excessive memory use. Defaults to
`min(256, num_cells)`.
seed: the random seed to use when finding nearest neighbors
overwrite: if `True`, overwrite `neighbors_key` if already present
in `obsm`, instead of raising an error
verbose: whether to print details of the nearest-neighbor search
num_threads: the number of threads to use when finding nearest
neighbors. Set `num_threads=-1` to use all available
cores, as determined by `os.cpu_count()`, or leave
unset to use `self.num_threads` cores. Does not affect
the returned SingleCell dataset's `num_threads`; this
will always be the same as the original dataset's
`num_threads`. `num_threads` will be capped to 64
when running with Scipy linked against OpenBLAS (see
warning below).
Returns:
A new SingleCell dataset with the indices of each cell's nearest
neighbors - not counting the cell itself - stored in
`obsm[neighbors_key]` as a `len(obs)` × `num_neighbors` NumPy
array, where `obsm[neighbors_key][i, j]` stores the index of the
`i`th cell's `j + 1`th nearest neighbor. (This differs from Scanpy
and Seurat, which use a less compact sparse matrix representation
instead.) For instance, if `num_neighbors=2` and the 0th cell's
nearest neighbors are the 4th cell and the 6th cell, then
`obsm[neighbors_key][0]` will be `np.array([4, 6])`. Note that if
`QC_column` is not `None`, these integer indices are with respect
to QCed cells, not all cells. The squared Euclidean distance to
each nearest neighbor will be stored in `obsm[distances_key]` as an
array of the same shape as `obsm[neighbors_key]`.
Warning:
If you installed Scipy via pip, it will be linked against OpenBLAS,
and `neighbors()` will be limited to 64 threads due to the
limitations of OpenBLAS. To use more than 64 threads, install Scipy
linked against MKL BLAS. This is done automatically when installing
brisc via conda, but you can also do it manually via
`conda install "libblas=*=*mkl" scipy`.
"""
# Check that `neighbors_key` is a string
check_type(neighbors_key, 'neighbors_key', str, 'a string')
# Check that `overwrite` is Boolean
check_type(overwrite, 'overwrite', bool, 'Boolean')
# Check that `neighbors_key` is not already a key in `obsm`, unless
# `overwrite=True`
if not overwrite and neighbors_key in self._obsm:
error_message = (
f'neighbors_key {neighbors_key!r} is already a key of obsm; '
f'did you already run neighbors()? Set overwrite=True to '
f'overwrite.')
raise ValueError(error_message)
# Get the QC column, if not `None`
if QC_column is not None:
QC_column = self._get_column(
'obs', QC_column, 'QC_column', pl.Boolean,
allow_missing=QC_column == 'passed_QC')
# Check that `PC_key` is the name of a key in `obsm`
check_type(PC_key, 'PC_key', str, 'a string')
if PC_key not in self._obsm:
error_message = f'PC_key {PC_key!r} is not a key of obsm'
if PC_key == 'pca':
error_message += \
'; did you forget to run PCA() before neighbors()?'
raise ValueError(error_message)
# Get PCs, and check that they are float32 and C-contiguous
PCs = self._obsm[PC_key]
if PCs.dtype != np.float32:
error_message = \
f'obsm[{PC_key!r}].dtype is {PCs.dtype!r}, but must be float32'
raise TypeError(error_message)
if not PCs.flags['C_CONTIGUOUS']:
error_message = (
f'obsm[{PC_key!r}] is not C-contiguous; make it C-contiguous '
f'with pipe_obsm_key({PC_key!r}, np.ascontiguousarray)')
raise ValueError(error_message)
# Check that `num_kmeans_iterations` is a positive integer
check_type(num_kmeans_iterations, 'num_kmeans_iterations', int,
'a positive integer')
check_bounds(num_kmeans_iterations, 'num_kmeans_iterations', 1)
# Check that `kmeans_tolerance` is a positive number
check_type(kmeans_tolerance, 'kmeans_tolerance', (int, float),
'a positive number')
check_bounds(kmeans_tolerance, 'kmeans_tolerance', 0, left_open=True)
# Check that `kmeans_barbar` is Boolean
check_type(kmeans_barbar, 'kmeans_barbar', bool, 'Boolean')
# Check that `num_init_iterations` is a positive integer
check_type(num_init_iterations, 'num_init_iterations', int,
'a positive integer')
check_bounds(num_init_iterations, 'num_init_iterations', 1)
# Check that `oversampling_factor` is a positive number
check_type(oversampling_factor, 'oversampling_factor', (int, float),
'a positive number')
check_bounds(oversampling_factor, 'oversampling_factor', 0,
left_open=True)
# If `kmeans_barbar=False`, check that `num_init_iterations` and
# `oversampling_factor` have their default values
if not kmeans_barbar:
if num_init_iterations != 5:
error_message = (
'num_init_iterations can only be specified when '
'kmeans_barbar=True')
raise ValueError(error_message)
if oversampling_factor != 1:
error_message = (
'oversampling_factor can only be specified when '
'kmeans_barbar=True')
raise ValueError(error_message)
# Check that `seed` is an integer
check_type(seed, 'seed', int, 'an integer')
# Check that `verbose` is Boolean
check_type(verbose, 'verbose', bool, 'Boolean')
# Check that `num_threads` is a positive integer, -1 or `None`; if
# `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
num_threads = self._process_num_threads(num_threads)
# Cap to 64 threads for OpenBLAS to avoid the error "OpenBLAS : Program
# is Terminated. Because you tried to allocate too many memory
# regions".
if num_threads > 64 and any(lib.get('internal_api') == 'openblas'
for lib in threadpool_info()):
num_threads = 64
# Subset PCs to QCed cells only, if `QC_column` is not `None`
if QC_column is not None:
QC_column_NumPy = QC_column.to_numpy()
if num_threads == 1:
PCs = PCs[QC_column_NumPy]
else:
indices = np.flatnonzero(QC_column_NumPy)
PCs = parallel_subset_2d(PCs, indices, num_threads)
# Check that `num_neighbors` is between 1 and `num_cells - 1`.
check_type(num_neighbors, 'num_neighbors', int, 'a positive integer')
num_cells = len(PCs)
if not 1 <= num_neighbors < num_cells:
error_message = (
f'num_neighbors is {num_neighbors:,}, but must be ≥ 1 and '
f'less than the number of cells ({num_cells:,})')
raise ValueError(error_message)
# Check that `num_clusters` is between 1 and `num_cells`; if `None`,
# set to `ceil(min(4 * sqrt(num_cells), num_cells / 100))`
if num_clusters is None:
num_clusters = \
int(np.ceil(min(4 * np.sqrt(num_cells), num_cells / 100)))
else:
check_type(num_clusters, 'num_clusters', int, 'a positive integer')
if not 1 <= num_clusters <= num_cells:
error_message = (
f'num_clusters is {num_clusters:,}, but must be ≥ 1 and ≤ '
f'the number of cells ({num_cells:,})')
raise ValueError(error_message)
# Check that `num_clusters_searched` is between 1 and `num_clusters`;
# if `None`, set to `min(64, num_clusters)`
if num_clusters_searched is None:
num_clusters_searched = min(64, num_clusters)
else:
check_type(num_clusters_searched, 'num_clusters_searched', int,
'a positive integer')
if not 1 <= num_clusters_searched <= num_clusters:
error_message = (
f'num_clusters_searched is {num_clusters_searched:,}, but '
f'must be ≥ 1 and ≤ num_clusters ({num_clusters:,})')
raise ValueError(error_message)
# Check that `chunk_size_kmeans` is between 1 and `num_cells`; if
# `None`, set to `min(4096, num_cells)`
if chunk_size_kmeans is None:
chunk_size_kmeans = min(4096, num_cells)
else:
check_type(chunk_size_kmeans, 'chunk_size_kmeans', int,
'a positive integer')
if not 1 <= chunk_size_kmeans <= num_cells:
error_message = (
f'chunk_size_kmeans is {chunk_size_kmeans:,}, but must '
f'be ≥ 1 and ≤ the number of cells ({num_cells:,})')
raise ValueError(error_message)
# Check that `chunk_size_search` is between 1 and `num_cells`; if
# `None`, set to `min(256, num_cells)`
if chunk_size_search is None:
chunk_size_search = min(256, num_cells)
else:
check_type(chunk_size_search, 'chunk_size_search', int,
'a positive integer')
if not 1 <= chunk_size_search <= num_cells:
error_message = (
f'chunk_size_search is {chunk_size_search:,}, but must be '
f'≥ 1 and ≤ the number of cells ({num_cells:,})')
raise ValueError(error_message)
# Run k-means clustering on `PCs`. Since `centroids` and
# `centroids_new` are swapped every iteration, `centroids_new` will
# contain the final centroids when doing an odd number of k-means
# iterations; if so, swap them at the end. Use `threadpool_limits()` to
# run BLAS single-threaded (directly in the single-threaded case, by
# making everything single-threaded including BLAS; indirectly in the
# parallel case, by disabling nested parallelism inside `prange()`).
# `kmeans()` populates `cell_norms` (||X||²) for reuse during the
# nearest-neighbor search.
num_dimensions = PCs.shape[1]
if num_threads == 1:
cluster_labels = np.empty(num_cells, dtype=np.uint32)
min_distances = np.empty(num_cells, dtype=np.float32)
cell_norms = np.empty(num_cells, dtype=np.float32)
else:
cluster_labels = numa_zeros(num_cells, dtype=np.uint32)
min_distances = numa_zeros(num_cells, dtype=np.float32)
cell_norms = numa_zeros(num_cells, dtype=np.float32)
centroids = np.empty((num_clusters, num_dimensions), dtype=np.float32)
centroids_new = np.empty((num_clusters, num_dimensions),
dtype=np.float32)
num_cells_per_cluster = np.empty(num_clusters, dtype=np.uint32)
with threadpool_limits(num_threads):
iterations_until_convergence, num_clusters = kmeans(
X=PCs, cluster_labels=cluster_labels, centroids=centroids,
centroids_new=centroids_new,
num_cells_per_cluster=num_cells_per_cluster,
min_distances=min_distances, cell_norms=cell_norms,
kmeans_barbar=kmeans_barbar,
num_init_iterations=num_init_iterations,
num_kmeans_iterations=num_kmeans_iterations,
tolerance=kmeans_tolerance,
oversampling_factor=oversampling_factor, seed=seed,
chunk_size=chunk_size_kmeans, num_threads=num_threads)
if iterations_until_convergence & 1:
centroids = centroids_new
del centroids_new, min_distances
centroids = centroids[:num_clusters]
num_cells_per_cluster = num_cells_per_cluster[:num_clusters]
# As an optimization, renumber clusters so that cluster `N` tends to be
# near cluster `N + 1` in PC space, by sorting the centroids by PC1.
# This reduces cache misses during the nearest-neighbor search.
centroid_order = np.argsort(centroids[:, 0]).astype(np.uint32)
centroids = np.ascontiguousarray(centroids[centroid_order])
num_cells_per_cluster = num_cells_per_cluster[centroid_order]
inverse_centroid_order = np.empty_like(centroid_order)
inverse_centroid_order[centroid_order] = np.arange(len(centroid_order),
dtype=np.uint32)
cluster_labels = inverse_centroid_order[cluster_labels]
# As an optimization, sort by cluster to reduce cache misses in the
# nearest-neighbor search. This also lets us skip building an explicit
# inverted file index during the nearest-neighbor search.
sorted_order = np.argsort(cluster_labels).astype(np.uint32)
del cluster_labels
if num_threads == 1:
PCs = PCs[sorted_order]
cell_norms = cell_norms[sorted_order]
else:
PCs = parallel_subset_2d(PCs, sorted_order, num_threads)
cell_norms = parallel_subset_1d(cell_norms, sorted_order,
num_threads)
# Find the `num_neighbors` nearest neighbors of each cell, according to
# `PCs`. Allocate the heap arrays, plus the centroid-distance and
# nearest-cluster scratch buffers, NUMA-aware. Use
# `threadpool_limits()` to run BLAS single-threaded (directly in the
# single-threaded case, by making everything single-threaded including
# BLAS; indirectly in the parallel case, by disabling nested
# parallelism inside `prange()`).
if num_threads == 1:
neighbors = np.empty((num_cells, num_neighbors), dtype=np.uint32)
distances = np.empty((num_cells, num_neighbors), dtype=np.float32)
centroid_distances = np.empty((num_cells, num_clusters_searched),
dtype=np.float32)
nearest_clusters = np.empty((num_cells, num_clusters_searched),
dtype=np.uint32)
else:
neighbors = numa_zeros((num_cells, num_neighbors), dtype=np.uint32)
distances = numa_zeros((num_cells, num_neighbors),
dtype=np.float32)
centroid_distances = numa_zeros(
(num_cells, num_clusters_searched), dtype=np.float32)
nearest_clusters = numa_zeros(
(num_cells, num_clusters_searched), dtype=np.uint32)
with threadpool_limits(num_threads):
knn_self(X=PCs, sorted_order=sorted_order, centroids=centroids,
num_cells_per_cluster=num_cells_per_cluster,
neighbors=neighbors, distances=distances,
cell_norms=cell_norms,
centroid_distances=centroid_distances,
nearest_clusters=nearest_clusters,
num_neighbors=num_neighbors,
num_clusters_searched=num_clusters_searched,
chunk_size_kmeans=chunk_size_kmeans,
chunk_size_search=chunk_size_search,
num_threads=num_threads)
del centroid_distances, nearest_clusters, cell_norms
# Invert the argsort so that `neighbors` and `distances` are with
# respect to the original dataset, rather than being sorted by cluster.
# (The remapping of the nearest-neighbor indices themselves is taken
# care of in `knn_self()` via the `sorted_order` argument.)
if num_threads == 1:
unsorted_neighbors = np.empty_like(neighbors)
unsorted_neighbors[sorted_order] = neighbors
neighbors = unsorted_neighbors
unsorted_distances = np.empty_like(distances)
unsorted_distances[sorted_order] = distances
distances = unsorted_distances
else:
inverse_order = np.empty_like(sorted_order)
inverse_order[sorted_order] = np.arange(len(sorted_order),
dtype=sorted_order.dtype)
neighbors = parallel_subset_2d(neighbors, inverse_order,
num_threads)
distances = parallel_subset_2d(distances, inverse_order,
num_threads)
# If `QC_column` was specified, back-project from QCed cells to all
# cells, filling with `NaN` (for the distances) and `UINT32_MAX` (for
# the nearest-neighbor indices)
if QC_column is not None:
neighbors_QCed = neighbors
neighbors = np.full((len(self), neighbors_QCed.shape[1]),
4_294_967_295, dtype=np.uint32)
neighbors[QC_column_NumPy] = neighbors_QCed
distances_QCed = distances
distances = np.full((len(self), distances_QCed.shape[1]), np.nan,
dtype=np.float32)
distances[QC_column_NumPy] = distances_QCed
return SingleCell(X=self._X, obs=self._obs, var=self._var,
obsm=self._obsm | {neighbors_key: neighbors,
distances_key: distances},
varm=self._varm, obsp=self._obsp, varp=self._varp,
uns=self._uns, num_threads=self._num_threads)
[docs]
def shared_neighbors(self,
                     *,
                     QC_column: SingleCellColumn | None |
                                Sequence[SingleCellColumn |
                                         None] = 'passed_QC',
                     neighbors_key: str = 'neighbors',
                     shared_neighbors_key: str = 'shared_neighbors',
                     min_shared_neighbors: int | np.integer = 3,
                     overwrite: bool = False,
                     num_threads: int | np.integer | None = None) -> \
        SingleCell:
    """
    Build the shared nearest neighbor (SNN) graph of this dataset's cells.

    Intended to be called after `neighbors()`: unless told otherwise, the
    nearest-neighbor indices are read from `obsm['neighbors']`.

    Edge weights are Jaccard indices of the two cells' neighbor sets. The
    result matches Seurat's, except that diagonal elements are omitted
    instead of being set to 1. It does not match Scanpy, which derives its
    shared nearest neighbor graph from the connectivity of the UMAP
    manifold.

    If the dataset is subset, this function must be run again; not doing
    so will raise an error.

    Args:
        QC_column: an optional Boolean column of `obs` flagging the cells
                   that passed QC. May be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that receives this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Use `None`
                   to include every cell. Cells that fail QC are ignored
                   and left out of the shared nearest neighbor graph.
        neighbors_key: the key of `obsm` holding each cell's nearest
                       neighbors, as computed by `neighbors()`; this is
                       the input to the shared nearest neighbor
                       calculation. Must be a uint32, C-contiguous array.
        shared_neighbors_key: the key of `obsp` under which the shared
                              nearest neighbor graph will be stored
        min_shared_neighbors: the smallest number of neighbors two cells
                              must have in common for an edge to be kept
                              between them. Must be ≥ 0 and less than the
                              number of neighbors in `obsm[neighbors_key]`.
                              With 20 nearest neighbors (the default
                              `num_neighbors` in `neighbors()`) + 1 for
                              the cell itself, the default of 3
                              corresponds to Seurat's default
                              `prune.SNN = 1 / 15` in `FindNeighbors()`:
                              3 shared neighbors give a weight of
                              `3 / (42 - 3)`, about 0.077, which exceeds
                              `1 / 15`, whereas 2 shared neighbors give
                              only `2 / (42 - 2)`, i.e. 0.05, which does
                              not.
        overwrite: if `True`, replace `shared_neighbors_key` when it is
                   already a key of `obsp`, rather than raising an error
        num_threads: the number of threads to use when finding shared
                     nearest neighbors. `num_threads=-1` uses all
                     available cores, as determined by `os.cpu_count()`;
                     leaving it unset uses `self.num_threads` cores. The
                     returned SingleCell dataset's `num_threads` is not
                     affected; it is always the same as the original
                     dataset's `num_threads`.

    Returns:
        A new SingleCell dataset whose `obsp[shared_neighbors_key]` holds
        the shared nearest neighbor graph as a symmetric `len(obs)` ×
        `len(obs)` sparse array. Entry `[i, j]` is the Jaccard index of
        the nearest neighbors of cells `i` and `j`: the number of cells
        that are neighbors of both, divided by the number of cells that
        are neighbors of at least one of them. Diagonal elements are
        omitted.

        For example, with 20 nearest neighbors per cell (i.e.
        `obsm[neighbors_key].shape[1] == 20`), if 8 of the 20 cells in
        `obsm[neighbors_key][i]` also appear in `obsm[neighbors_key][j]`,
        then `obsp[shared_neighbors_key][i, j]` is 0.25
        (`8 / (20 + 20 - 8)`).
    """
    # Validate the output key and the `overwrite` flag
    check_type(shared_neighbors_key, 'shared_neighbors_key', str,
               'a string')
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    # Refuse to replace an existing `obsp` entry unless `overwrite=True`
    key_is_taken = not overwrite and shared_neighbors_key in self._obsp
    if key_is_taken:
        msg = (
            f'shared_neighbors_key {shared_neighbors_key!r} is already a '
            f'key of obsp; did you already run shared_neighbors()? Set '
            f'overwrite=True to overwrite.')
        raise ValueError(msg)
    # Resolve the QC column, if one was requested; only the default name
    # 'passed_QC' is allowed to be missing from `obs`
    if QC_column is not None:
        QC_may_be_missing = QC_column == 'passed_QC'
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_may_be_missing)
    # Look up the nearest-neighbor indices and make sure they are uint32
    # and C-contiguous
    check_type(neighbors_key, 'neighbors_key', str, 'a string')
    if neighbors_key not in self._obsm:
        # Only suggest running neighbors() when the default key was used
        if neighbors_key == 'neighbors':
            hint = ('; did you forget to run neighbors() before '
                    'shared_neighbors()?')
        else:
            hint = ''
        msg = f'neighbors_key {neighbors_key!r} is not a key of obsm{hint}'
        raise ValueError(msg)
    nn_indices = self._obsm[neighbors_key]
    if nn_indices.dtype != np.uint32:
        msg = (
            f'obsm[{neighbors_key!r}] must have uint32 data type, but has '
            f'data type {str(nn_indices.dtype)!r}')
        raise TypeError(msg)
    if not nn_indices.flags['C_CONTIGUOUS']:
        msg = (
            f'obsm[{neighbors_key!r}] is not C-contiguous; make it '
            f'C-contiguous with '
            f'pipe_obsm_key({neighbors_key!r}, np.ascontiguousarray)')
        raise ValueError(msg)
    # `min_shared_neighbors` must be a non-negative integer strictly
    # smaller than the number of neighbors stored per cell
    check_type(min_shared_neighbors, 'min_shared_neighbors', int,
               'a non-negative integer')
    num_cells, num_neighbors = nn_indices.shape
    if min_shared_neighbors < 0 or min_shared_neighbors >= num_neighbors:
        msg = (
            f'min_shared_neighbors is {min_shared_neighbors:,}, but must '
            f'be ≥ 0 and less than the number of neighbors in '
            f'obsm[{neighbors_key!r}] ({num_neighbors:,})')
        raise ValueError(msg)
    # Normalize `num_threads` (a positive integer, -1 or `None`) via the
    # dataset's own helper
    num_threads = self._process_num_threads(num_threads)
    # When a QC column is in play, keep only the rows of cells passing QC
    # and remember which original row each kept row came from; otherwise
    # pass an empty map
    if QC_column is None:
        QCed_to_full_map = np.array([], dtype=np.int64)
    else:
        passed_QC = QC_column.to_numpy()
        QCed_to_full_map = np.flatnonzero(passed_QC)
        if num_threads == 1:
            nn_indices = nn_indices[QCed_to_full_map]
        else:
            nn_indices = parallel_subset_2d(
                nn_indices, QCed_to_full_map, num_threads)
    # Compute the shared nearest neighbor graph. `num_cells` is the row
    # count taken before the QC subsetting above, and sizes both `indptr`
    # and the resulting square sparse array.
    indptr = np.empty(num_cells + 1, dtype=np.int64)
    indices, data = snn(
        neighbors=nn_indices, QCed_to_full_map=QCed_to_full_map,
        indptr=indptr, min_shared_neighbors=min_shared_neighbors,
        neighbors_key=neighbors_key, num_threads=num_threads)
    graph = csr_array((data, indices, indptr),
                      shape=(num_cells, num_cells))
    graph._num_threads = self._num_threads
    # Return a new dataset that differs only in `obsp`
    updated_obsp = self._obsp | {shared_neighbors_key: graph}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm,
                      obsp=updated_obsp, varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def cluster(self,
            *,
            QC_column: SingleCellColumn | None |
                       Sequence[SingleCellColumn | None] = 'passed_QC',
            resolution: int | float | np.integer | np.floating |
                        Iterable[int | float | np.integer |
                                 np.floating] = 1,
            min_cluster_size: int | np.integer = 5,
            shared_neighbors_key: str = 'shared_neighbors',
            neighbors_key: str = 'neighbors',
            cluster_column: str | Iterable[str] = 'cluster',
            seed: int | np.integer = 0,
            overwrite: bool = False,
            verbose: bool = False,
            num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Cluster cells into cell types using Leiden clustering.

    This function is intended to be run after `shared_neighbors()`; by
    default, it uses `obsp['shared_neighbors']` as the input to the
    clustering.

    Our Leiden implementation is based on
    [GVE-Leiden](https://arxiv.org/abs/2312.13936), a non-deterministic
    parallel version of the Leiden algorithm. To ensure deterministic
    output, our implementation is not parallelized. However, when multiple
    resolutions are specified, clusterings for all resolutions run in
    parallel by default.

    Like Seurat and Scanpy's Leiden clustering, our implementation
    optimizes the objective function corresponding to Reichardt and
    Bornholdt's Potts model (`RBConfigurationVertexPartition` in the
    reference implementation of Leiden clustering, `leidenalg`). This
    formulation extends modularity by introducing a `resolution` parameter;
    with the default `resolution = 1`, it is exactly equivalent to
    maximizing modularity.

    Args:
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored
                   and have their cell-type labels set to `null`.
        resolution: the key parameter of Leiden clustering; must be
                    positive. Larger values result in more clusters. If
                    multiple resolutions are specified, the clustering for
                    each resolution is executed in parallel by default.
        min_cluster_size: the minimum cluster size. Cells in clusters of
                          size less than `min_cluster_size` will be merged
                          into the cluster of their nearest neighbor that
                          is in a cluster of size ≥ `min_cluster_size`.
                          Set `min_cluster_size=1` to disable this merging.
                          Clusters with only one or two cells may occur if
                          they are disconnected from the rest of the shared
                          nearest neighbor graph.
        shared_neighbors_key: the key of `obsp` containing the shared
                              nearest neighbor graph calculated with
                              `shared_neighbors()`, to use as the input for
                              Leiden clustering. Must be symmetric,
                              although this is not checked for, due to
                              speed considerations. Diagonals are assumed
                              to be 1; the actual values are ignored.
        neighbors_key: the key of `obsm` containing the nearest neighbors
                       of each cell calculated with `neighbors()`. This is
                       only used to merge disconnected clusters of size
                       less than `min_cluster_size`. Not used when
                       `min_cluster_size=1`, in which case it must be left
                       at its default value.
        cluster_column: the name of the Enum column to be added to `obs`
                        indicating the cell-type labels. If `N`
                        resolutions are specified, `N` columns named
                        `f'{cluster_column}_0'` through
                        `f'{cluster_column}_{N - 1}'` will be added.
        seed: the random seed to use when clustering
        overwrite: if `True`, overwrite `cluster_column` if already present
                   in `obs`, instead of raising an error
        verbose: whether to print details of the Leiden clustering; must be
                 `False` when running multithreaded
        num_threads: the number of threads to use when clustering.
                     Parallelization takes place across resolutions. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores. In both cases, only as many
                     cores will be used as the number of resolutions
                     specified. Specifying `num_threads=-1` when only one
                     resolution is specified will raise an error, as will
                     specifying a positive value for `num_threads` that is
                     greater than the number of resolutions specified. Does
                     not affect the returned SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     original dataset's `num_threads`.

    Returns:
        A new SingleCell dataset where `obs[cluster_column]` is an Enum
        column containing an integer cell-type label for each cell (`'0'`,
        `'1'`, etc.). Or, if `N` resolutions are specified, a dataset where
        `obs[f'{cluster_column}_0']` through
        `obs[f'{cluster_column}_{N - 1}']` contain `N` sets of cell-type
        labels.

    Note:
        This function may give an incorrect output if you specified a
        custom shared nearest-neighbor graph that a) is non-symmetric (i.e.
        `(shared_neighbors != shared_neighbors.T).nnz`, where
        `shared_neighbors = sc.obsp[shared_neighbors_key]`, is non-zero),
        b) contains explicit zeros (i.e. if
        `(shared_neighbors.data == 0).any()`), or c) contains negative
        values: these are not checked for, due to speed considerations. In
        the unlikely event that your custom shared nearest-neighbor graph
        contains explicit zeros, remove them by running
        `sc.obsp[shared_neighbors_key].eliminate_zeros()` (an in-place
        operation) first.
    """
    # Check that `cluster_column` is a string
    check_type(cluster_column, 'cluster_column', str, 'a string')
    # Check that `overwrite` is Boolean
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    # Get the resolution(s). Check that `cluster_column` (if one
    # resolution) or `f'{cluster_column}_0'` through
    # `f'{cluster_column}_{N - 1}'` (if `N` resolutions) are not already
    # columns of `obs`, unless `overwrite=True`.
    resolutions = to_tuple_checked(resolution, 'resolution', (int, float),
                                   'positive numbers')
    num_resolutions = len(resolutions)
    if num_resolutions == 1:
        # Unwrap the single resolution so that a one-element iterable
        # (e.g. `resolution=[0.5]`) behaves like the equivalent scalar
        resolution = resolutions[0]
        check_bounds(resolution, 'resolution', 0, left_open=True)
        if not overwrite and cluster_column in self._obs:
            error_message = (
                f'cluster_column {cluster_column!r} is already a column '
                f'of obs; did you already run cluster()? Set '
                f'overwrite=True to overwrite.')
            raise ValueError(error_message)
    else:
        resolutions = np.array(resolutions, dtype=np.float32)
        # Resolutions must be strictly positive, matching both the error
        # message and the single-resolution check above
        for resolution_value in resolutions:
            if resolution_value <= 0:
                error_message = (
                    f'a resolution is {resolution_value:,}, but all '
                    f'resolutions must be > 0')
                raise ValueError(error_message)
        if not overwrite:
            # Check exactly the column names that are created below:
            # suffixes 0 through `num_resolutions - 1`
            for resolution_index in range(num_resolutions):
                composite_cluster_column = \
                    f'{cluster_column}_{resolution_index}'
                if composite_cluster_column in self._obs:
                    error_message = (
                        f'{composite_cluster_column!r} is already a '
                        f'column of obs; did you already run cluster()? '
                        f'Set overwrite=True to overwrite.')
                    raise ValueError(error_message)
    # Check that `min_cluster_size` is a positive integer
    check_type(min_cluster_size, 'min_cluster_size', int,
               'a positive integer')
    check_bounds(min_cluster_size, 'min_cluster_size', 1)
    # Get the shared nearest neighbor graph, and check that it is float32
    # and non-empty
    check_type(shared_neighbors_key, 'shared_neighbors_key', str,
               'a string')
    if shared_neighbors_key not in self._obsp:
        error_message = (
            f'shared_neighbors_key {shared_neighbors_key!r} is not a key '
            f'of obsp')
        if shared_neighbors_key == 'shared_neighbors':
            error_message += (
                '; did you forget to run shared_neighbors() before '
                'cluster()?')
        raise ValueError(error_message)
    snn_graph = self._obsp[shared_neighbors_key]
    if snn_graph.dtype != np.float32:
        error_message = (
            f'obsp[{shared_neighbors_key!r}] must have data type float32, '
            f'but has data type {str(snn_graph.dtype)!r}')
        raise TypeError(error_message)
    if snn_graph.nnz == 0:
        error_message = f'obsp[{shared_neighbors_key!r}] is an empty graph'
        raise ValueError(error_message)
    # Get the QC column, if not `None`
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')
    # If `min_cluster_size > 1`, get the nearest-neighbor indices, and
    # check that they are uint32 and C-contiguous. If `min_cluster_size=1`,
    # check that `neighbors_key` is not specified (i.e. retains its default
    # value) and use an empty placeholder array instead.
    uses_neighbors = min_cluster_size > 1
    if uses_neighbors:
        check_type(neighbors_key, 'neighbors_key', str, 'a string')
        if neighbors_key not in self._obsm:
            error_message = \
                f'neighbors_key {neighbors_key!r} is not a key of obsm'
            if neighbors_key == 'neighbors':
                error_message += (
                    '; did you forget to run neighbors() before '
                    'cluster()?')
            raise ValueError(error_message)
        neighbors = self._obsm[neighbors_key]
        if neighbors.dtype != np.uint32:
            error_message = (
                f'obsm[{neighbors_key!r}] must have uint32 data type, but '
                f'has data type {str(neighbors.dtype)!r}')
            raise TypeError(error_message)
        if not neighbors.flags['C_CONTIGUOUS']:
            error_message = (
                f'obsm[{neighbors_key!r}] is not C-contiguous; make it '
                f'C-contiguous with '
                f'pipe_obsm_key({neighbors_key!r}, np.ascontiguousarray)')
            raise ValueError(error_message)
    else:
        neighbors = np.array([[]], dtype=np.uint32)
        if neighbors_key != 'neighbors':
            error_message = \
                'neighbors_key cannot be specified when min_cluster_size=1'
            raise ValueError(error_message)
    # Check that `seed` is an integer
    check_type(seed, 'seed', int, 'an integer')
    # Check that `num_threads` is a positive integer, -1 or `None`. If
    # `None`, set to `min(self.num_threads, len(resolutions))`. If -1, set
    # to `min(os.cpu_count(), len(resolutions))`, but raise an error if the
    # user only specified one resolution. If a positive integer, raise an
    # error if the user specified more threads than resolutions.
    if num_threads is None:
        num_threads = min(self._num_threads, num_resolutions)
    else:
        check_type(num_threads, 'num_threads', int,
                   'a positive integer, -1, or None')
        if num_threads == -1:
            if num_resolutions == 1:
                error_message = (
                    'only one resolution was specified, so num_threads '
                    'must be 1 or None, not -1')
                raise ValueError(error_message)
            # `os.cpu_count()` returns `None` when the count cannot be
            # determined; fall back to a single thread in that case
            num_threads = min(os.cpu_count() or 1, num_resolutions)
        else:
            num_threads = int(num_threads)
            if num_threads <= 0:
                error_message = (
                    f'num_threads is {num_threads:,}, but must be a '
                    f'positive integer, -1, or None')
                raise ValueError(error_message)
            if num_threads > num_resolutions:
                error_message = (
                    f'num_threads is {num_threads:,}, but cannot be '
                    f'greater than the number of resolutions specified '
                    f'({num_resolutions:,})')
                raise ValueError(error_message)
    # Check that `verbose` is Boolean, and `False` when running
    # multithreaded
    check_type(verbose, 'verbose', bool, 'Boolean')
    if verbose and num_threads > 1:
        error_message = 'verbose must be False when running multithreaded'
        raise ValueError(error_message)
    # Subset the nearest-neighbor indices and SNN graph to cells passing
    # QC, if `QC_column` is not `None`. The nearest-neighbor indices are
    # only subset when they are real: with `min_cluster_size=1`,
    # `neighbors` is a `(1, 0)` placeholder whose row count does not match
    # the QC mask, so indexing it with the mask would fail.
    if QC_column is not None:
        QC_column_NumPy = QC_column.to_numpy()
        if uses_neighbors:
            if num_threads == 1:
                neighbors = neighbors[QC_column_NumPy]
            else:
                indices = np.flatnonzero(QC_column_NumPy)
                neighbors = parallel_subset_2d(neighbors, indices,
                                               num_threads)
        snn_graph = snn_graph[ix_symmetric(QC_column_NumPy)]
    # Perform Leiden clustering. `leiden()` and `leiden_multiresolution()`
    # fill `clusters` in place and report whether any nearest-neighbor
    # index was out of bounds (`too_large`) or equal to the cell's own
    # index (`self_neighbors`).
    num_cells = snn_graph.shape[0]
    if num_resolutions == 1:
        clusters = np.empty(num_cells, dtype=np.uint32)
        num_final_communities, too_large, self_neighbors = leiden(
            data=snn_graph.data, indices=snn_graph.indices,
            indptr=snn_graph.indptr, neighbors=neighbors,
            final_communities=clusters, resolution=resolution,
            min_cluster_size=min_cluster_size, seed=seed, verbose=verbose)
    else:
        clusters = np.empty((num_resolutions, num_cells), dtype=np.uint32)
        num_final_communities = np.empty(num_resolutions, dtype=np.uint32)
        if num_threads == 1:
            for resolution_index, resolution in enumerate(resolutions):
                if verbose:
                    print(f'\nresolution = {resolution}:')
                num_final_communities[resolution_index], too_large, \
                    self_neighbors = leiden(
                        data=snn_graph.data, indices=snn_graph.indices,
                        indptr=snn_graph.indptr, neighbors=neighbors,
                        final_communities=clusters[resolution_index],
                        resolution=resolution,
                        min_cluster_size=min_cluster_size, seed=seed,
                        verbose=verbose)
                # Stop at the first failure; the error is raised below
                if too_large or self_neighbors:
                    break
        else:
            too_large, self_neighbors = leiden_multiresolution(
                data=snn_graph.data, indices=snn_graph.indices,
                indptr=snn_graph.indptr, neighbors=neighbors,
                final_communities=clusters,
                num_final_communities=num_final_communities,
                resolutions=resolutions, min_cluster_size=min_cluster_size,
                seed=seed, num_threads=num_threads)
    # If any nearest-neighbor indices were out of bounds or equal to the
    # cell's own index, raise an error
    if too_large:
        error_message = (
            f'some nearest-neighbor indices in '
            f'obsm[{neighbors_key!r}] are >= the total number of '
            f'cells, {neighbors.shape[0]:,}. This may happen if you '
            f'subset this SingleCell dataset between neighbors() and '
            f'cluster(); if so, make sure to run neighbors() after, '
            f'not before, subsetting.')
        raise ValueError(error_message)
    elif self_neighbors:
        error_message = (
            f'some nearest-neighbor indices in '
            f'obsm[{neighbors_key!r}] indicate that a cell is its own '
            f'neighbor, i.e. obsm[{neighbors_key!r}][i, j] == i for '
            f'some i and j. This may happen if you created '
            f'obsm[{neighbors_key!r}] manually rather than following '
            f'the recommended approach of running neighbors().')
        raise ValueError(error_message)
    # Convert the integer labels to Enum column(s) with string categories
    # `'0'`, `'1'`, etc.
    if num_resolutions == 1:
        clusters = pl.Series(cluster_column, clusters)\
            .cast(pl.Enum(map(str, range(num_final_communities + 1))))
        if QC_column is not None:
            # Back-project from QCed cells to all cells, filling with
            # `null`
            clusters = pl.when(QC_column)\
                .then(clusters
                      .gather(QC_column.cum_sum() - QC_column))
    else:
        clusters = pl.DataFrame(clusters.T, schema=[
            f'{cluster_column}_{resolution_index}'
            for resolution_index in range(num_resolutions)])\
            .cast({f'{cluster_column}_{resolution_index}': pl.Enum(map(
                str, range(num_final_communities[resolution_index] + 1)))
                for resolution_index in range(num_resolutions)})
        if QC_column is not None:
            # Back-project from QCed cells to all cells, filling with
            # `null`
            clusters = clusters.select(
                pl.when(QC_column).then(pl.all().gather(
                    QC_column.cum_sum() - QC_column)))
    return SingleCell(X=self._X, obs=self._obs.with_columns(clusters),
                      var=self._var, obsm=self._obsm, varm=self._varm,
                      obsp=self._obsp, varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs]
def harmonize(self,
              *others: SingleCell,
              QC_column: SingleCellColumn | None |
                         Sequence[SingleCellColumn | None] = 'passed_QC',
              batch_column: SingleCellColumn | None |
                            Sequence[SingleCellColumn | None] = None,
              PC_key: str = 'pca',
              Harmony_key: str = 'harmony',
              num_clusters: int | np.integer | None = None,
              max_iterations: int | np.integer | None = 10,
              num_kmeans_iterations: int | np.integer = 25,
              kmeans_tolerance: int | np.integer | float |
                                np.floating = 1e-2,
              kmeans_barbar: bool = False,
              num_init_iterations: int | np.integer = 5,
              oversampling_factor: int | np.integer | float |
                                   np.floating = 1,
              chunk_size_kmeans: int | np.integer | None = None,
              max_clustering_iterations: int | np.integer = 5,
              block_proportion: int | float | np.integer |
                                np.floating = 0.05,
              tolerance: int | float | np.integer | np.floating = 0.01,
              early_stopping: bool = False,
              clustering_tolerance: int | float | np.integer |
                                    np.floating = 0.001,
              theta: int | float | np.integer | np.floating = 2,
              tau: int | float | np.integer | np.floating = 0,
              alpha: float | np.floating = 0.2,
              sigma: int | float | np.integer | np.floating = 0.1,
              chunk_size_Harmony: int | np.integer | None = None,
              seed: int | np.integer = 0,
              original: bool = False,
              overwrite: bool = False,
              verbose: bool = True,
              num_threads: int | np.integer | None = None) -> \
        SingleCell | tuple[SingleCell, ...]:
    """
    Harmonize this SingleCell dataset with other datasets, or harmonize
    multiple batches of the same dataset, with
    [Harmony2](https://github.com/immunogenomics/harmony).

    Our implementation differs from the original Harmony2 implementation in
    three key ways:

    First, we parallelize Harmony via an innovative nested block strategy.
    The original implementation partitions cells randomly into blocks, each
    containing a fixed fraction (`block_proportion`, 5% by default) of the
    total cells. It iterates over each block, subtracting the contribution
    of the cells in the block to the observed and expected cluster-batch
    co-occurrence matrices `O` and `E`, re-calculating the soft-clustering
    assignments `R` based on the residual `O` and `E`, then adding back the
    contribution of the cells in the block to `O` and `E` based on the new
    `R`. This approach resists straightforward parallelization because
    updates to `R` for future blocks depend on updates to `O` and `E` from
    previous blocks, leading to convergence failure if naively
    parallelized. Instead, we parallelize within blocks by dividing them
    into chunks of `chunk_size_Harmony` cells (256 by default). We process
    each chunk in parallel, subtracting only the `O` and `E` contributions
    of the chunk itself and updating `R` for the chunk, and, after
    processing all chunks in the block, add back the `O` and `E`
    contributions for all chunks based on the updated `R`. Updating `O` and
    `E` at the end of each block (rather than after processing every cell
    in the dataset) ensures convergence, while the inner chunking enables
    parallelization without disrupting convergence. To use the original
    implementation's chunkless strategy for updating `R`, `O`, and `E`,
    specify `original=True, num_threads=1`.

    Second, we reduce the default number of clustering iterations
    (Harmony's inner loop) from 20 to 5, but always complete all 5
    iterations without early stopping based on convergence. In practice,
    the largest changes to the Harmony objective function result from
    updating the PCs (outer loop), not updating the soft-clustering
    assignments (inner loop). Our implementation makes updating the PCs
    sufficiently fast that there are rapidly diminishing returns from
    performing lots of clustering updates, versus just skipping directly to
    updating the PCs after a few clustering iterations. By skipping
    convergence checks, we reduce the number of objective function
    evaluations (which are expensive) from once per inner iteration to once
    per outer iteration. To use the original implementation's default of 20
    iterations with early stopping, specify
    `early_stopping=True, max_clustering_iterations=20`.

    Third, we use random k-means initialization, like version 1 of Harmony,
    rather than Harmony2's kmeans++ initialization. We implement a parallel
    version of kmeans++ called [k-means||](https://arxiv.org/abs/1203.6402),
    but find it does not improve the accuracy of label transfer relative to
    random initialization and increases runtime. kmeans|| initialization can
    be enabled with `kmeans_barbar=True`.

    Args:
        others: the other SingleCell datasets to harmonize this one with.
                Can be omitted if harmonizing between batches of the same
                dataset, but then `batch_column` must be specified.
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored
                   and have their Harmony embeddings set to `NaN`. When
                   `others` is specified, `QC_column` can be a
                   length-`1 + len(others)` sequence of columns,
                   expressions, Series, functions, or `None` for each
                   dataset (for `self`, followed by each dataset in
                   `others`).
        batch_column: an optional String, Enum, Categorical, or integer
                      column of `obs` indicating which batch each cell is
                      from. Can be a column name, a polars expression, a
                      polars Series, a 1D NumPy array, or a function that
                      takes in this SingleCell dataset and returns a polars
                      Series or 1D NumPy array. Each batch will be treated
                      as if it were a distinct dataset; this is exactly
                      equivalent to splitting the dataset with
                      `split_by(batch_column)` and then passing each of the
                      resulting datasets to `harmonize()`. Set to `None` to
                      treat each dataset as having a single batch. When
                      `others` is specified, `batch_column` may be a
                      length-`1 + len(others)` sequence of columns,
                      expressions, Series, functions, or `None` for each
                      dataset (for `self`, followed by each dataset in
                      `others`).
        PC_key: the key of `obsm` containing the principal components
                calculated with `PCA()`, to use as the input to Harmony
        Harmony_key: the key of `obsm` where the Harmony embeddings will be
                     stored in each of the returned datasets
        num_clusters: the number of clusters used in the Harmony algorithm,
                      including in the initial k-means clustering. If not
                      specified, use 100 clusters if ≥3000 cells, 1 cluster
                      if ≤30 cells, and `round(num_cells / 30)` if between
                      30 and 3000 cells. For datasets of at least 10,000
                      cells, clusters with fewer than 100 cells will be
                      merged into the adjacent cluster with the nearest
                      centroid, so the actual number of clusters used may
                      be smaller.
        max_iterations: the maximum number of iterations to run Harmony
                        for, if convergence is not achieved. Defaults to
                        10, like the original Harmony R package,
                        harmony-pytorch, and harmonypy. Set to `None` to
                        use as many iterations as necessary to achieve
                        convergence.
        num_kmeans_iterations: the maximum number of iterations of
                               k-means clustering to perform before
                               running Harmony, stopping early if a
                               relative convergence of
                               `kmeans_tolerance` is reached. Defaults to
                               25, like the original Harmony R package,
                               harmony-pytorch, and harmonypy. However,
                               unlike these packages, only one
                               initialization is tried rather than 10 to
                               reduce runtime.
        kmeans_tolerance: the relative change in inertia (the sum of
                          squared distances from each cell to its assigned
                          centroid) used to determine whether to stop
                          optimizing the k-means clustering before
                          `num_kmeans_iterations` iterations
        kmeans_barbar: whether to use
                       [k-means||](https://arxiv.org/abs/1203.6402)
                       initialization (a parallel version of k-means++)
                       to initialize the k-means clustering centroids,
                       instead of random initialization. This is more
                       accurate but takes considerably longer for large
                       datasets and `num_clusters`.
        num_init_iterations: the number of k-means|| iterations used to
                             initialize the k-means clustering that
                             precedes Harmony. k-means|| is a parallel
                             version of the widely used k-means++
                             initialization scheme for k-means clustering.
                             The default value of 5 is recommended by the
                             [k-means|| paper](https://arxiv.org/abs/1203.6402).
                             Only used when `kmeans_barbar=True`.
        oversampling_factor: the number of candidate centroids selected, on
                             average, at each of the `num_init_iterations`
                             iterations of k-means||, as a multiple of
                             `num_clusters`. The default value of 1 is the
                             midpoint (in log space) of the values explored
                             by the
                             [k-means|| paper](https://arxiv.org/abs/1203.6402),
                             namely 0.1 to 10. The total number of candidate
                             centroids selected, on average, will be
                             `oversampling_factor * num_clusters + 1`, from
                             which the final `num_clusters` centroids will
                             then be selected via k-means++. Only used when
                             `kmeans_barbar=True`.
        chunk_size_kmeans: the chunk size to use for distance calculations
                           in the initial k-means clustering. Setting this
                           to a power of 2 is recommended. Defaults to
                           `min(4096, total number of cells)`.
        max_clustering_iterations: the number of iterations to run the
                                   clustering step within each Harmony
                                   iteration for, or the maximum number of
                                   iterations if `early_stopping=True`.
                                   Unlike the original Harmony algorithm,
                                   convergence of the clustering is not
                                   checked unless `early_stopping=True`.
                                   Defaults to 5 iterations; this differs
                                   from the 20 used by the original harmony
                                   R package and harmonypy and the 200
                                   iterations used by harmony-pytorch,
                                   which do have convergence checks.
                                   Must be 4 or more when
                                   `early_stopping=True`, since Harmony's
                                   clustering convergence check requires
                                   knowing the errors from the past 3
                                   iterations.
        block_proportion: the proportion of cells to use in each batch
                          update in the clustering step; must be greater
                          than zero and less than or equal to 1
        tolerance: the relative tolerance used to determine whether to stop
                   optimizing the Harmony embeddings before
                   `max_iterations` iterations
        early_stopping: whether to stop clustering before
                        `max_clustering_iterations` iterations if
                        convergence is reached, like in the original
                        Harmony algorithm
        clustering_tolerance: the relative tolerance used to determine
                              whether to stop clustering before
                              `max_clustering_iterations` iterations. Only
                              used when `early_stopping=True`.
        theta: the weight of the diversity penalty term in the Harmony
               objective function; must be non-negative. Larger values
               result in more diverse clusters.
        tau: the discounting factor on theta; must be non-negative. By
             default, `tau = 0`, so there is no discounting.
        alpha: the scaling factor for the ridge regression penalty used
               when correcting the principal components to get the Harmony
               embeddings; must be greater than 0 and less than 1. The
               ridge penalty `lambda` is determined by `alpha` and the
               expected number of cells, assuming independence between
               batches and clusters:
               `lambda = alpha * expected number of cells`. Smaller values
               result in more aggressive correction.
        sigma: the weight of the entropy term in the Harmony objective
               function; must be non-negative
        chunk_size_Harmony: the chunk size to use for Harmony. Setting this
                            to a power of 2 is recommended. Defaults to
                            `min(256, total number of cells)`. Not used
                            when `original=True`.
        seed: the random seed to use for the initial k-means clustering
        original: if `True`, use the original Harmony algorithm's blocking
                  strategy, rather than our nested chunks-within-blocks
                  strategy. `original=True` requires `num_threads=1`. This
                  gives lower memory usage and closer correspondence to the
                  original algorithm, at the cost of a) a moderate
                  (~20-25%) degradation in performance and b) no longer
                  matching the Harmony embeddings produced by the
                  multithreaded version. If `False`, exactly match the
                  results of the multithreaded version when
                  `num_threads=1`. Must be `False` unless `num_threads=1`.
        overwrite: if `True`, overwrite `Harmony_key` if already present in
                   obsm, instead of raising an error
        verbose: whether to print details of the harmonization process
        num_threads: the number of threads to use when concatenating
                     principal components and batch/dataset labels across
                     datasets, for the initial k-means clustering, and for
                     the matrix and matrix-vector multiplications within
                     Harmony. Set `num_threads=-1` to use all available
                     cores, as determined by `os.cpu_count()`, or leave
                     unset to use `self.num_threads` cores. Does not affect
                     the returned SingleCell dataset's `num_threads`; this
                     will always be the same as the original dataset's
                     `num_threads`. `num_threads` will be capped to 64 when
                     running with Scipy linked against OpenBLAS (see
                     warning below).

    Returns:
        A length-`1 + len(others)` tuple of SingleCell datasets with the
        Harmony embeddings stored in `obsm[Harmony_key]`: `self`, followed
        by each dataset in `others`. Or, if `others` is omitted, a single
        SingleCell dataset with the Harmony embeddings stored in
        `obsm[Harmony_key]`.

    Raises:
        ValueError: if `others` is empty and `batch_column` is `None`; if
                    `Harmony_key` is already a key of `obsm` and
                    `overwrite=False`; if `PC_key` is missing from any
                    dataset's `obsm`; if the PCs are not C-contiguous or
                    have different widths across datasets; or if an
                    argument is out of bounds or inconsistent with the
                    other arguments.
        TypeError: if the PCs are not float32.

    Warning:
        If you installed Scipy via pip, it will be linked against OpenBLAS,
        and `harmonize()` will be limited to 64 threads due to the
        limitations of OpenBLAS. To use more than 64 threads, install Scipy
        linked against MKL BLAS. This is done automatically when installing
        brisc via conda, but you can also do it manually via
        `conda install "libblas=*=*mkl" scipy`.
    """
    # If `others` was specified, check that all elements of `others` are
    # SingleCell datasets
    if others:
        check_types(others, 'others', SingleCell, 'SingleCell datasets')
    elif batch_column is None:
        error_message = \
            'others cannot be empty unless batch_column is specified'
        raise ValueError(error_message)
    datasets = [self] + list(others)
    # Check that `Harmony_key` is a string
    check_type(Harmony_key, 'Harmony_key', str, 'a string')
    # Check that `overwrite` is Boolean
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    # Check that `Harmony_key` is not already in `obsm` for any dataset,
    # unless `overwrite=True`
    suffix = ' for at least one dataset' if others else ''
    if not overwrite and \
            any(Harmony_key in dataset._obsm for dataset in datasets):
        error_message = (
            f'Harmony_key {Harmony_key!r} is already a key of '
            f'obsm{suffix}; did you already run harmonize()? Set '
            f'overwrite=True to overwrite.')
        raise ValueError(error_message)
    # Get `QC_column` and `batch_column` from every dataset, if not `None`
    QC_columns = SingleCell._get_columns(
        'obs', datasets, QC_column, 'QC_column', pl.Boolean,
        allow_missing=True)
    QC_columns_NumPy = [QC_col.to_numpy() if QC_col is not None else None
                        for QC_col in QC_columns]
    batch_columns = SingleCell._get_columns(
        'obs', datasets, batch_column, 'batch_column',
        (pl.String, pl.Enum, pl.Categorical, 'integer'),
        QC_columns=QC_columns)
    # Check that `PC_key` is a key of `obsm` for every dataset
    check_type(PC_key, 'PC_key', str, 'a string')
    if not all(PC_key in dataset._obsm for dataset in datasets):
        error_message = (
            f'PC_key {PC_key!r} is not a key of obsm{suffix}; did you '
            f'forget to run PCA() before harmonize()?')
        raise ValueError(error_message)
    # If `num_clusters` is not `None`, check that it is a positive integer
    # We will assign `num_clusters` its default value (if `None`) and check
    # its upper bound later, once we know the total number of cells across
    # all datasets.
    if num_clusters is not None:
        check_type(num_clusters, 'num_clusters', int, 'a positive integer')
        check_bounds(num_clusters, 'num_clusters', 1)
    # Check that `max_iterations` is `None` or a positive integer; if
    # `None`, set to `INT32_MAX`
    if max_iterations is None:
        max_iterations = 2_147_483_647
    else:
        check_type(max_iterations, 'max_iterations', int,
                   'a positive integer')
        check_bounds(max_iterations, 'max_iterations', 1)
    # Check that `num_kmeans_iterations` is a positive integer
    check_type(num_kmeans_iterations, 'num_kmeans_iterations', int,
               'a positive integer')
    check_bounds(num_kmeans_iterations, 'num_kmeans_iterations', 1)
    # Check that `kmeans_tolerance` is a positive number
    check_type(kmeans_tolerance, 'kmeans_tolerance', (int, float),
               'a positive number')
    check_bounds(kmeans_tolerance, 'kmeans_tolerance', 0, left_open=True)
    # Check that `kmeans_barbar` is Boolean
    check_type(kmeans_barbar, 'kmeans_barbar', bool, 'Boolean')
    # Check that `num_init_iterations` is a positive integer
    check_type(num_init_iterations, 'num_init_iterations', int,
               'a positive integer')
    check_bounds(num_init_iterations, 'num_init_iterations', 1)
    # Check that `oversampling_factor` is a positive number
    check_type(oversampling_factor, 'oversampling_factor', (int, float),
               'a positive number')
    check_bounds(oversampling_factor, 'oversampling_factor', 0,
                 left_open=True)
    # If `kmeans_barbar=False`, check that `num_init_iterations` and
    # `oversampling_factor` have their default values
    if not kmeans_barbar:
        if num_init_iterations != 5:
            error_message = (
                'num_init_iterations can only be specified when '
                'kmeans_barbar=True')
            raise ValueError(error_message)
        if oversampling_factor != 1:
            error_message = (
                'oversampling_factor can only be specified when '
                'kmeans_barbar=True')
            raise ValueError(error_message)
    # Check that `max_clustering_iterations` is a positive integer
    check_type(max_clustering_iterations, 'max_clustering_iterations', int,
               'a positive integer')
    check_bounds(max_clustering_iterations, 'max_clustering_iterations', 1)
    # Check that `block_proportion` is a number and that
    # `0 < block_proportion <= 1`
    check_type(block_proportion, 'block_proportion', (int, float),
               'a number greater than zero and less than or equal to 1')
    check_bounds(block_proportion, 'block_proportion', 0, 1,
                 left_open=True)
    # Check that `tolerance` is a positive number; if an integer, cast it
    # to a float
    check_type(tolerance, 'tolerance', (int, float), 'a positive number')
    check_bounds(tolerance, 'tolerance', 0, left_open=True)
    tolerance = float(tolerance)
    # Check that `early_stopping` is Boolean
    check_type(early_stopping, 'early_stopping', bool, 'Boolean')
    # Check that `clustering_tolerance` is a positive number; if an
    # integer, cast it to a float
    check_type(clustering_tolerance, 'clustering_tolerance', (int, float),
               'a positive number')
    check_bounds(clustering_tolerance, 'clustering_tolerance', 0,
                 left_open=True)
    clustering_tolerance = float(clustering_tolerance)
    # If `early_stopping=False`, check that `clustering_tolerance` has its
    # default value
    if not early_stopping:
        if clustering_tolerance != 0.001:
            error_message = (
                'clustering_tolerance can only be specified when '
                'early_stopping=True')
            raise ValueError(error_message)
    # Check that `theta` and `tau` are non-negative numbers; if either is
    # an integer, cast it to a float
    for parameter, parameter_name in (theta, 'theta'), (tau, 'tau'):
        check_type(parameter, parameter_name, (int, float),
                   'a non-negative number')
        check_bounds(parameter, parameter_name, 0)
    theta = float(theta)
    tau = float(tau)
    # Check that `alpha` is greater than 0 and less than 1
    check_type(alpha, 'alpha', float,
               'a number greater than 0 and less than 1')
    check_bounds(alpha, 'alpha', 0, 1, left_open=True, right_open=True)
    # Check that `sigma` is a non-negative number; if an integer, cast it
    # to a float
    check_type(sigma, 'sigma', (int, float), 'a non-negative number')
    check_bounds(sigma, 'sigma', 0)
    sigma = float(sigma)
    # If `chunk_size_kmeans` and/or `chunk_size_Harmony` are not `None`,
    # check that they are positive integers. We will check their upper
    # bounds and set their default values (if `None`) later, once we know
    # the total number of cells across all datasets.
    for chunk_size, chunk_size_name in \
            (chunk_size_kmeans, 'chunk_size_kmeans'), \
            (chunk_size_Harmony, 'chunk_size_Harmony'):
        if chunk_size is not None:
            check_type(chunk_size, chunk_size_name, int,
                       'a positive integer')
            check_bounds(chunk_size, chunk_size_name, 1)
    # Check that `seed` is an integer
    check_type(seed, 'seed', int, 'an integer')
    # Check that `verbose` is Boolean
    check_type(verbose, 'verbose', bool, 'Boolean')
    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)
    # Cap to 64 threads for OpenBLAS to avoid the error "OpenBLAS : Program
    # is Terminated. Because you tried to allocate too many memory
    # regions".
    if num_threads > 64 and any(lib.get('internal_api') == 'openblas'
                                for lib in threadpool_info()):
        num_threads = 64
    # Check that `original` is Boolean, and `False` unless `num_threads=1`
    check_type(original, 'original', bool, 'Boolean')
    if original and num_threads != 1:
        error_message = 'original must be False unless num_threads is 1'
        raise ValueError(error_message)
    # Concatenate PCs across datasets; get labels indicating which rows of
    # these concatenated PCs come from each dataset or batch. Check that
    # the PCs are float32 and C-contiguous and all have the same width.
    all_PCs = [dataset._obsm[PC_key] for dataset in datasets]
    for PCs in all_PCs:
        dtype = PCs.dtype
        if dtype != np.float32:
            error_message = (
                f'obsm[{PC_key!r}].dtype is {dtype!r}{suffix}, but must '
                f'be float32')
            raise TypeError(error_message)
        if not PCs.flags['C_CONTIGUOUS']:
            error_message = (
                f'obsm[{PC_key!r}] is not C-contiguous{suffix}; '
                f'make it C-contiguous with pipe_obsm_key({PC_key!r}, '
                f'np.ascontiguousarray)')
            raise ValueError(error_message)
    width = all_PCs[0].shape[1]
    for PCs in all_PCs[1:]:
        if PCs.shape[1] != width:
            error_message = (
                f"two datasets' PCs have different numbers of columns "
                f"({width:,} vs {PCs.shape[1]:,})")
            raise ValueError(error_message)
    # Restrict each dataset's PCs to the cells passing QC, where a QC
    # column is available for that dataset
    if QC_column is not None:
        all_PCs = [PCs[QCed] if QCed is not None else PCs
                   for PCs, QCed in zip(all_PCs, QC_columns_NumPy)]
    num_cells_per_dataset = np.array(list(map(len, all_PCs)))
    if batch_column is None:
        # One batch per dataset: label every cell with its dataset's index
        batch_labels = np.repeat(np.arange(len(num_cells_per_dataset),
                                           dtype=np.uint32),
                                 num_cells_per_dataset)
        num_batches = len(num_cells_per_dataset)
    else:
        # `batch_index` is the offset added to each dataset's batch codes
        # so that labels from different datasets never collide
        batch_labels = []
        batch_index = 0
        for dataset, QC_col, batch_col in \
                zip(datasets, QC_columns, batch_columns):
            if batch_col is not None:
                if QC_col is not None:
                    batch_col = batch_col.filter(QC_col)
                if batch_col.dtype in (pl.String, pl.Enum, pl.Categorical):
                    if batch_col.dtype != pl.Enum:
                        batch_col = batch_col \
                            .cast(pl.Enum(batch_col.unique().drop_nulls()))
                    batch_col = batch_col.to_physical()
                if batch_col.dtype != pl.UInt32:
                    batch_col = batch_col.cast(pl.UInt32)
                batch_labels.append(batch_col.to_numpy() + batch_index)
                batch_index += batch_col.n_unique()
            else:
                # No batch column for this dataset: treat the whole dataset
                # as a single batch, i.e. one label (`batch_index`) per
                # QCed cell. `np.full()` takes the length first and the
                # fill value second; use uint32 to match the other labels.
                batch_labels.append(np.full(
                    len(dataset) if QC_col is None else QC_col.sum(),
                    batch_index, dtype=np.uint32))
                batch_index += 1
        batch_labels = concatenate(batch_labels, num_threads=num_threads)
        num_batches = batch_index
    PCs = concatenate(all_PCs, num_threads=num_threads)
    num_cells, num_PCs = PCs.shape
    # If `num_clusters` is `None`, set it to `num_cells / 30`, rounded to
    # the nearest integer and clipped to be between 1 and 100. Otherwise,
    # check that it is less than the total number of cells across all
    # datasets.
    if num_clusters is None:
        num_clusters = max(min(100, int(round(num_cells / 30))), 1)
    elif num_clusters >= num_cells:
        error_message = (
            f'num_clusters is {num_clusters:,}, but must be less than the '
            f'total number of cells across all datasets ({num_cells:,})')
        raise ValueError(error_message)
    # If `chunk_size_kmeans` is `None`, set it to `min(4096, num_cells)`.
    # Otherwise, check that it is less than the total number of cells
    # across all datasets.
    if chunk_size_kmeans is None:
        chunk_size_kmeans = min(4096, num_cells)
    elif chunk_size_kmeans >= num_cells:
        error_message = (
            f'chunk_size_kmeans is {chunk_size_kmeans:,}, but must be '
            f'less than the total number of cells across all datasets '
            f'({num_cells:,})')
        raise ValueError(error_message)
    # If `chunk_size_Harmony` is `None`, set it to `min(256, num_cells)`.
    # Otherwise, check that it is less than the total number of cells
    # across all datasets.
    if chunk_size_Harmony is None:
        chunk_size_Harmony = min(256, num_cells)
    elif chunk_size_Harmony >= num_cells:
        error_message = (
            f'chunk_size_Harmony is {chunk_size_Harmony:,}, but must be '
            f'less than the total number of cells across all datasets '
            f'({num_cells:,})')
        raise ValueError(error_message)
    # Get `Z`, the row-normalized PCs; allocate temporary buffers
    # `cluster_labels` and `min_distances`, used in k-means
    if num_threads == 1:
        Z = np.empty((num_cells, num_PCs), dtype=np.float32)
        cluster_labels = np.empty(num_cells, dtype=np.uint32)
        min_distances = np.empty(num_cells, dtype=np.float32)
    else:
        Z = numa_zeros((num_cells, num_PCs), dtype=np.float32)
        cluster_labels = numa_zeros(num_cells, dtype=np.uint32)
        min_distances = numa_zeros(num_cells, dtype=np.float32)
    normalize_rows(arr=PCs, out=Z, num_threads=num_threads)
    # Run k-means clustering on `Z`. Since `Y` and `Y_new` are swapped
    # every iteration, `Y_new` will contain the final `Y` when doing an odd
    # number of k-means iterations; if so, swap them at the end. Use
    # `threadpool_limits()` to run BLAS single-threaded (directly in the
    # single-threaded case, by making everything single-threaded including
    # BLAS; indirectly in the parallel case, by disabling nested
    # parallelism inside `prange()`).
    Y = np.empty((num_clusters, num_PCs), dtype=np.float32)
    Y_new = np.empty((num_clusters, num_PCs), dtype=np.float32)
    with threadpool_limits(num_threads):
        iterations_until_convergence, num_clusters = kmeans(
            X=Z, cluster_labels=cluster_labels, centroids=Y,
            centroids_new=Y_new,
            num_cells_per_cluster=np.empty(num_clusters, dtype=np.uint32),
            min_distances=min_distances,
            cell_norms=np.array([], dtype=np.float32),
            kmeans_barbar=kmeans_barbar,
            num_init_iterations=num_init_iterations,
            num_kmeans_iterations=num_kmeans_iterations,
            tolerance=kmeans_tolerance,
            oversampling_factor=oversampling_factor, seed=seed,
            chunk_size=chunk_size_kmeans, num_threads=num_threads)
    if iterations_until_convergence & 1:
        Y = Y_new
    del Y_new, cluster_labels, min_distances
    # `kmeans()` returns the number of clusters actually used, which may be
    # smaller than the number requested; keep only those centroids
    Y = Y[:num_clusters]
    # Run Harmony. Unlike above, here we need to use single-threaded matrix
    # multiplication to ensure consistent floating-point roundoff. Use
    # `threadpool_limits()` to run BLAS single-threaded (directly in the
    # single-threaded case, by making everything single-threaded including
    # BLAS; indirectly in the parallel case, by disabling nested
    # parallelism inside `prange()`).
    if num_threads == 1:
        R = np.empty((num_cells, num_clusters), dtype=np.float32)
    else:
        R = numa_zeros((num_cells, num_clusters), dtype=np.float32)
    if original:
        with threadpool_limits(1):
            harmony_original(
                PCs=PCs, Z=Z, Y=Y, R=R, batch_labels=batch_labels,
                num_batches=num_batches, max_iterations=max_iterations,
                max_clustering_iterations=max_clustering_iterations,
                block_proportion=block_proportion, tolerance=tolerance,
                early_stopping=early_stopping,
                clustering_tolerance=clustering_tolerance, theta=theta,
                tau=tau, alpha=alpha, sigma=sigma,
                chunk_size=chunk_size_Harmony, seed=seed, verbose=verbose)
    else:
        num_threads = min(num_threads, num_cells)
        with threadpool_limits(num_threads):
            harmony(PCs=PCs, Z=Z, Y=Y, R=R, batch_labels=batch_labels,
                    num_batches=num_batches, max_iterations=max_iterations,
                    max_clustering_iterations=max_clustering_iterations,
                    block_proportion=block_proportion, tolerance=tolerance,
                    early_stopping=early_stopping,
                    clustering_tolerance=clustering_tolerance, theta=theta,
                    tau=tau, alpha=alpha, sigma=sigma,
                    chunk_size=chunk_size_Harmony, seed=seed,
                    verbose=verbose, num_threads=num_threads)
    del batch_labels, PCs, Y
    # Store each dataset's Harmony embedding (the relevant rows of `Z`) in
    # its `obsm`
    if not others:  # just one dataset
        QC_col = QC_columns_NumPy[0]
        Harmony_embedding = Z
        # If `QC_col` is not `None`, back-project from QCed cells to all
        # cells, filling with `NaN`
        if QC_col is not None:
            Harmony_embedding = np.full(
                (len(self), Z.shape[1]), np.nan, dtype=np.float32)
            Harmony_embedding[QC_col] = Z
        # Build the result from `self` explicitly (rather than from a loop
        # variable left over from above), carrying over every field
        return SingleCell(
            X=self._X, obs=self._obs, var=self._var,
            obsm=self._obsm | {Harmony_key: Harmony_embedding},
            varm=self._varm, obsp=self._obsp, varp=self._varp,
            uns=self._uns, num_threads=self._num_threads)
    for dataset_index, (dataset, QC_col, dataset_num_cells, end_index) in \
            enumerate(zip(datasets, QC_columns_NumPy,
                          num_cells_per_dataset,
                          num_cells_per_dataset.cumsum())):
        start_index = end_index - dataset_num_cells
        dataset_Harmony_embedding = Z[start_index:end_index]
        # If `QC_col` is not `None` for this dataset, back-project from
        # QCed cells to all cells, filling with `NaN`
        if QC_col is not None:
            dataset_Harmony_embedding_QCed = dataset_Harmony_embedding
            dataset_Harmony_embedding = np.full(
                (len(dataset),
                 dataset_Harmony_embedding_QCed.shape[1]),
                np.nan, dtype=np.float32)
            dataset_Harmony_embedding[QC_col] = \
                dataset_Harmony_embedding_QCed
        # Each returned dataset keeps its own fields; only `obsm` gains
        # the Harmony embedding
        datasets[dataset_index] = SingleCell(
            X=dataset._X, obs=dataset._obs, var=dataset._var,
            obsm=dataset._obsm | {
                Harmony_key: dataset_Harmony_embedding},
            varm=dataset._varm, obsp=dataset._obsp,
            varp=dataset._varp, uns=dataset._uns,
            num_threads=dataset._num_threads)
    return tuple(datasets)
[docs]
def label_transfer_from(
self,
other: SingleCell,
original_cell_type_column: SingleCellColumn,
*,
QC_column: SingleCellColumn | None = 'passed_QC',
other_QC_column: SingleCellColumn | None = 'passed_QC',
Harmony_key: str = 'harmony',
cell_type_column: str = 'cell_type',
confidence_column: str | None = None,
next_best: bool = False,
next_best_cell_type_column: str | None = None,
next_best_confidence_column: str | None = None,
num_neighbors: int | np.integer = 20,
num_clusters: int | np.integer | None = None,
num_clusters_searched: int | np.integer | None = None,
num_kmeans_iterations: int | np.integer = 2,
kmeans_tolerance: int | np.integer | float | np.floating = 1e-2,
kmeans_barbar: bool = False,
num_init_iterations: int | np.integer = 5,
oversampling_factor: int | np.integer | float | np.floating = 1,
chunk_size_kmeans: int | np.integer | None = None,
chunk_size_search: int | np.integer | None = None,
seed: int | np.integer = 0,
overwrite: bool = False,
verbose: bool = True,
num_threads: int | np.integer | None = None) -> SingleCell:
"""
Transfer cell-type labels from another dataset to this one, using the
two datasets' Harmony embeddings from `harmonize()`.
For each cell in `self`, the transferred cell-type label is the most
common cell-type label among the `num_neighbors` cells in `other` with
the nearest Harmony embeddings. The cell-type confidence is the
fraction of these neighbors that share this most common cell-type
label.
The nearest-neighbor search is conducted using the same method as
`neighbors()`, with one crucial difference: whereas `neighbors()`
searches for a cell's nearest neighbors in its own dataset, this
function searches for a cell's nearest neighbors in another dataset,
i.e. `other`.
Args:
other: the dataset to transfer cell-type labels from
original_cell_type_column: a String, Enum, Categorical, or integer
column of `other.obs` containing
cell-type labels. Can be a column name,
a polars expression, a polars Series, a
1D NumPy array, or a function that takes
in `other` and returns a polars Series
or 1D NumPy array.
QC_column: an optional Boolean column of `self.obs` indicating
which cells passed QC. Can be a column name, a polars
expression, a polars Series, a 1D NumPy array, or a
function that takes in `self` and returns a polars
Series or 1D NumPy array. Set to `None` to include all
cells. Cells failing QC will have their cell-type labels
and confidences set to `null`.
other_QC_column: an optional Boolean column of `other.obs`
indicating which cells passed QC. Can be a column
name, a polars expression, a polars Series, a 1D
NumPy array, or a function that takes in `other`
and returns a polars Series or 1D NumPy array. Set
to `None` to include all cells. Cells failing QC
will be ignored during the label transfer.
Harmony_key: the key of `self.obsm` and `other.obsm` containing the
Harmony embeddings for each dataset
cell_type_column: the name of a column to be added to `self.obs`
indicating each cell's most likely cell type,
i.e. the most common cell-type label among the
cell's `num_neighbors` nearest neighbors in
`other`
confidence_column: the name of a column to be added to `self.obs`
indicating each cell's cell-type confidence,
i.e. the fraction of the cell's `num_neighbors`
nearest neighbors in `other` that share the most
common cell-type label. If multiple cell types
are equally common among the nearest neighbors,
tiebreak based on which of them is most common
in `original_cell_type_column`. If `None`,
defaults to `f'{cell_type_column}_confidence'`.
next_best: whether to also compute each cell's second-most likely
cell type and confidence, or just its most likely
next_best_cell_type_column: the name of a column to be added to
`self.obs` indicating each cell's
second-most likely cell type, i.e. the
second-most common cell-type label
among the cell's `num_neighbors`
nearest neighbors in
`original_cell_type_column`. If `None`,
defaults to
`f'next_best_{cell_type_column}'`. Can
only be specified when
`next_best=True`.
next_best_confidence_column: the name of a column to be added to
`self.obs` indicating each cell's
next-best confidence, i.e. the
fraction of the cell's `num_neighbors`
nearest neighbors in
`original_cell_type_column` that share
the second-most common cell-type
label. If multiple cell types are
equally common among the nearest
neighbors, tiebreak based on which of
them is most common in `other`. If
`None`, defaults to
`f'next_best_{cell_type_column}_confidence'`.
Can only be specified when
`next_best=True`.
num_neighbors: the number of nearest neighbors to use when
determining a cell's label. All cell-type
confidences will be multiples of
`1 / num_neighbors`.
num_clusters: the number of k-means clusters to use during the
nearest-neighbor search. Must be between 1 and the
number of cells in `other`. If `None`, will be set to
`ceil(min(4 * sqrt(num_cells), num_cells / 100))`
clusters, i.e. the minimum of four times the square
root of the number of cells in `other` and 1% of the
number of cells in `other`, rounding up. The core of
the heuristic, `4 * sqrt(num_cells)`, is the low end
of the range
[recommended](https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index)
by faiss, 4 to 16 times the square root. However,
faiss also recommends using between 39 and 256 data
points per centroid when training the k-means
clustering used in the k-nearest neighbors search. To
avoid going below 39, we switch to using
`num_cells / 100` centroids for small datasets (fewer
than 640,000 cells), since 100 is the midpoint of 39
and 256 in log space. For datasets of at least 10,000
cells, clusters with fewer than 100 cells will be
merged into the adjacent cluster with the nearest
centroid, so the actual number of clusters used may
be smaller.
num_clusters_searched: the number of a cell's nearest clusters to
search; must be between 1 and
`num_clusters`. Defaults to
`min(64, num_clusters)`.
num_kmeans_iterations: the maximum number of iterations of
k-means clustering to perform before
starting the nearest-neighbor search,
stopping early if a relative convergence of
`kmeans_tolerance` is reached
kmeans_tolerance: the relative change in inertia (the sum of
squared distances from each cell to its assigned
centroid) used to determine whether to stop
optimizing the k-means clustering before
`num_kmeans_iterations` iterations
kmeans_barbar: whether to use
[k-means||](https://arxiv.org/abs/1203.6402)
initialization (a parallel version of k-means++)
to initialize the k-means clustering centroids,
instead of random initialization. This is more
accurate but takes considerably longer for large
datasets and `num_clusters`.
num_init_iterations: the number of k-means|| iterations used to
initialize the k-means clustering that
constitutes the first step of the
nearest-neighbor search. k-means|| is a
parallel version of the widely used k-means++
initialization scheme for k-means clustering.
The default value of 5 is recommended by the
[k-means|| paper](https://arxiv.org/abs/1203.6402).
Only used when `kmeans_barbar=True`.
oversampling_factor: the number of candidate centroids selected, on
average, at each of the `num_init_iterations`
iterations of k-means||, as a multiple of
`num_clusters`. The default value of 1 is the
midpoint (in log space) of the values explored
by the
[k-means|| paper](https://arxiv.org/abs/1203.6402),
namely 0.1 to 10. The total number of candidate
centroids selected, on average, will be
`oversampling_factor * num_clusters + 1`, from
which the final `num_clusters` centroids will
then be selected via k-means++. Only used when
`kmeans_barbar=True`.
chunk_size_kmeans: the chunk size used for distance calculations
during k-means clustering, and also during the
per-query centroid ranking step of the
nearest-neighbor search. Setting this to a power
of 2 is recommended. Defaults to
`min(4096, number of cells in other)`.
chunk_size_search: the chunk size used to group query cells
together during the nearest-neighbor search.
Overly small values will tend to increase
runtime by reducing the reuse of information
during the search, whereas overly large values
will lead to excessive memory use. Defaults to
`min(256, number of cells in other)`.
seed: the random seed to use when finding nearest neighbors
overwrite: if `True`, overwrite `cell_type_column`,
`confidence_column`, `next_best_cell_type_column` and/or
`next_best_confidence_column` if already present in this
dataset's obs, instead of raising an error
verbose: whether to print details of the nearest-neighbor search
num_threads: the number of threads to use for the nearest-neighbor
search and label transfer. Set `num_threads=-1` to use
all available cores, as determined by
`os.cpu_count()`, or leave unset to use
`self.num_threads` cores. Does not affect the returned
SingleCell dataset's `num_threads`; this will always
be the same as the original dataset's `num_threads`.
`num_threads` will be capped to 64 when running with
Scipy linked against OpenBLAS (see warning below).
Returns:
`self`, but with two columns added to `obs`: `cell_type_column`,
containing the transferred cell-type labels, and
`confidence_column`, containing the cell-type confidences. If
`next_best=True`, also adds the columns
`next_best_cell_type_column` and `next_best_confidence_column`,
containing the second-most likely cell type and its confidence.
Warning:
If you installed Scipy via pip, it will be linked against OpenBLAS,
and `label_transfer_from()` will be limited to 64 threads due to
the limitations of OpenBLAS. To use more than 64 threads, install
Scipy linked against MKL BLAS. This is done automatically when
installing brisc via conda, but you can also do it manually via
`conda install "libblas=*=*mkl" scipy`.
"""
# Check that `cell_type_column` is a string
check_type(cell_type_column, 'cell_type_column', str, 'a string')
# Check that `confidence_column` is a string or `None`; if `None`, set
# to `f'{cell_type_column}_confidence'`.
if confidence_column is None:
confidence_column = f'{cell_type_column}_confidence'
else:
check_type(confidence_column, 'confidence_column', str,
'a string or None')
# Check that `next_best` is Boolean
check_type(next_best, 'next_best', bool, 'Boolean')
# If `next_best=False`, check that `next_best_cell_type_column` and
# `next_best_confidence_column` are `None`. If `next_best=True`, check
# that they are strings or `None`; if `None`, set to
# `f'next_best_{cell_type_column}'` and
# `f'next_best_{cell_type_column}_confidence'` respectively.
if next_best:
if next_best_cell_type_column is None:
next_best_cell_type_column = f'next_best_{cell_type_column}'
else:
check_type(next_best_cell_type_column,
'next_best_cell_type_column', str,
'a string or None')
if next_best_confidence_column is None:
next_best_confidence_column = \
f'next_best_{cell_type_column}_confidence'
else:
check_type(next_best_confidence_column,
'next_best_confidence_column', str,
'a string or None')
elif next_best_cell_type_column is not None:
error_message = \
'next_best_cell_type_column must be None unless next_best=True'
raise ValueError(error_message)
elif next_best_confidence_column is not None:
error_message = (
'next_best_confidence_column must be None unless '
'next_best=True')
raise ValueError(error_message)
# Check that `overwrite` is Boolean
check_type(overwrite, 'overwrite', bool, 'Boolean')
# If `overwrite=False`, check that `cell_type_column` and
# `confidence_column` (and `next_best_cell_type_column` and
# `next_best_confidence_column`, if `next_best=True`) are not already
# columns of `self.obs`
if not overwrite:
for column, column_name in (
(cell_type_column, 'cell_type_column'),
(confidence_column, 'confidence_column')):
if column in self._obs:
error_message = (
f'{column_name} {column!r} is already a column '
f'of obs; did you already run label_transfer_from()? '
f'Set overwrite=True to overwrite.')
raise ValueError(error_message)
for column, column_name in (
(next_best_cell_type_column, 'next_best_cell_type_column'),
(next_best_confidence_column,
'next_best_confidence_column')):
if column in self._obs:
error_message = (
f'{column_name} {column!r} is already a column '
f'of obs; did you already run label_transfer_from()? '
f'Set overwrite=True to overwrite, or next_best=False '
f'to not compute {column_name}.')
raise ValueError(error_message)
# Check that `other` is a SingleCell dataset
check_type(other, 'other', SingleCell, 'a SingleCell dataset')
# Get `QC_column` from `self` and `other_QC_column` from `other`
if QC_column is not None:
QC_column = self._get_column(
'obs', QC_column, 'QC_column', pl.Boolean,
allow_missing=QC_column == 'passed_QC')
if other_QC_column is not None:
other_QC_column = other._get_column(
'obs', other_QC_column, 'other_QC_column', pl.Boolean,
allow_missing=other_QC_column == 'passed_QC')
# Get `original_cell_type_column` from `other`
original_original_cell_type_column = original_cell_type_column
original_cell_type_column = other._get_column(
'obs', original_cell_type_column, 'original_cell_type_column',
(pl.String, pl.Enum, pl.Categorical, 'integer'),
QC_column=other_QC_column)
# If `other_QC_column` was specified, filter the cell type labels in
# `original_cell_type_column` to cells passing QC
if other_QC_column is not None:
original_cell_type_column = \
original_cell_type_column.filter(other_QC_column)
# Check that `original_cell_type_column` has at least two distinct cell
# types
most_common_cell_types = \
original_cell_type_column.value_counts(sort=True).to_series()
if len(most_common_cell_types) == 1:
original_cell_type_column_description = \
SingleCell._describe_column('original_cell_type_column',
original_original_cell_type_column)
error_message = (
f'{original_cell_type_column_description} must have at least '
f'two distinct cell types')
if other_QC_column is not None:
error_message += ' after filtering to cells passing QC'
raise ValueError(error_message)
# Check that `Harmony_key` is a string and in both `self.obsm` and
# `other.obsm`
check_type(Harmony_key, 'Harmony_key', str, 'a string')
datasets = (self, 'self'), (other, 'other')
for dataset, dataset_name in datasets:
if Harmony_key not in dataset._obsm:
error_message = (
f'Harmony_key {Harmony_key!r} is not a column of '
f'{dataset_name}.obs; did you forget to run harmonize() '
f'before label_transfer_from()?')
raise ValueError(error_message)
# Check that `num_kmeans_iterations` is a positive integer
check_type(num_kmeans_iterations, 'num_kmeans_iterations', int,
'a positive integer')
# Check that `kmeans_tolerance` is a positive number
check_type(kmeans_tolerance, 'kmeans_tolerance', (int, float),
'a positive number')
check_bounds(kmeans_tolerance, 'kmeans_tolerance', 0, left_open=True)
# Check that `kmeans_barbar` is Boolean
check_type(kmeans_barbar, 'kmeans_barbar', bool, 'Boolean')
# Check that `num_init_iterations` is a positive integer
check_type(num_init_iterations, 'num_init_iterations', int,
'a positive integer')
check_bounds(num_init_iterations, 'num_init_iterations', 1)
# Check that `oversampling_factor` is a positive number
check_type(oversampling_factor, 'oversampling_factor', (int, float),
'a positive number')
check_bounds(oversampling_factor, 'oversampling_factor', 0,
left_open=True)
# If `kmeans_barbar=False`, check that `num_init_iterations` and
# `oversampling_factor` have their default values
if not kmeans_barbar:
if num_init_iterations != 5:
error_message = (
'num_init_iterations can only be specified when '
'kmeans_barbar=True')
raise ValueError(error_message)
if oversampling_factor != 1:
error_message = (
'oversampling_factor can only be specified when '
'kmeans_barbar=True')
raise ValueError(error_message)
# Check that `seed` is an integer
check_type(seed, 'seed', int, 'an integer')
# Check that `num_threads` is a positive integer, -1 or `None`; if
# `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
num_threads = self._process_num_threads(num_threads)
# Cap to 64 threads for OpenBLAS to avoid the error "OpenBLAS : Program
# is Terminated. Because you tried to allocate too many memory
# regions".
if num_threads > 64 and any(lib.get('internal_api') == 'openblas'
for lib in threadpool_info()):
num_threads = 64
# Check that `verbose` is Boolean
check_type(verbose, 'verbose', bool, 'Boolean')
# Get the Harmony embeddings for self and other; check that they are
# float32 and C-contiguous and have the same width
if QC_column is None:
self_Harmony_embeddings = self._obsm[Harmony_key]
else:
if num_threads == 1:
self_Harmony_embeddings = \
self._obsm[Harmony_key][QC_column.to_numpy()]
else:
indices = np.flatnonzero(QC_column.to_numpy())
self_Harmony_embeddings = parallel_subset_2d(
self._obsm[Harmony_key], indices, num_threads)
if other_QC_column is None:
other_Harmony_embeddings = other._obsm[Harmony_key]
else:
if num_threads == 1:
other_Harmony_embeddings = \
other._obsm[Harmony_key][other_QC_column.to_numpy()]
else:
indices = np.flatnonzero(other_QC_column.to_numpy())
other_Harmony_embeddings = parallel_subset_2d(
other._obsm[Harmony_key], indices, num_threads)
if self_Harmony_embeddings.dtype != np.float32:
error_message = (
f'obsm[{Harmony_key!r}].dtype is '
f'{self_Harmony_embeddings.dtype!r}, but must be float32')
raise TypeError(error_message)
if not self_Harmony_embeddings.flags['C_CONTIGUOUS']:
error_message = (
f'obsm[{Harmony_key!r}] is not C-contiguous; make it '
f'C-contiguous with '
f'pipe_obsm_key({Harmony_key!r}, np.ascontiguousarray)')
raise ValueError(error_message)
if other_Harmony_embeddings.dtype != np.float32:
error_message = (
f'other.obsm[{Harmony_key!r}].dtype is '
f'{other_Harmony_embeddings.dtype!r}, but must be float32')
raise TypeError(error_message)
if not other_Harmony_embeddings.flags['C_CONTIGUOUS']:
error_message = (
f'other.obsm[{Harmony_key!r}] is not C-contiguous; make it '
f'C-contiguous with '
f'pipe_obsm_key({Harmony_key!r}, np.ascontiguousarray)')
raise ValueError(error_message)
num_dimensions = self_Harmony_embeddings.shape[1]
if other_Harmony_embeddings.shape[1] != num_dimensions:
error_message = (
f"the two datasets' Harmony embeddings have different numbers "
f"of columns ({num_dimensions:,} vs "
f"{other_Harmony_embeddings.shape[1]:,}")
raise ValueError(error_message)
# Get the number of cells in `self` and `other`
num_self = len(self_Harmony_embeddings)
num_other = len(other_Harmony_embeddings)
# Check that `num_neighbors` is between 1 and `num_other - 1`
check_type(num_neighbors, 'num_neighbors', int, 'a positive integer')
if not 1 <= num_neighbors < num_other:
error_message = (
f'num_neighbors is {num_neighbors:,}, but must be ≥ 1 and '
f'less than the number of cells in other ({num_other:,})')
raise ValueError(error_message)
# Check that `num_clusters` is between 1 and `num_other`; if `None`,
# set to `ceil(min(4 * sqrt(num_other), num_other / 100))`
if num_clusters is None:
num_clusters = \
int(np.ceil(min(4 * np.sqrt(num_other), num_other / 100)))
else:
check_type(num_clusters, 'num_clusters', int, 'a positive integer')
if not 1 <= num_clusters <= num_other:
error_message = (
f'num_clusters is {num_clusters:,}, but must be ≥ 1 and ≤ '
f'the number of cells in other ({num_other:,})')
raise ValueError(error_message)
# Check that `num_clusters_searched` is between 1 and `num_clusters`;
# if `None`, set to `min(64, num_clusters)`
if num_clusters_searched is None:
num_clusters_searched = min(64, num_clusters)
else:
check_type(num_clusters_searched, 'num_clusters_searched', int,
'a positive integer')
if not 1 <= num_clusters_searched <= num_clusters:
error_message = (
f'num_clusters_searched is {num_clusters_searched:,}, but '
f'must be ≥ 1 and ≤ num_clusters ({num_clusters:,})')
raise ValueError(error_message)
# Check that `chunk_size_kmeans` is between 1 and `num_other`; if
# `None`, set to `min(4096, num_other)`
if chunk_size_kmeans is None:
chunk_size_kmeans = min(4096, num_other)
else:
check_type(chunk_size_kmeans, 'chunk_size_kmeans', int,
'a positive integer')
if not 1 <= chunk_size_kmeans <= num_other:
error_message = (
f'chunk_size_kmeans is {chunk_size_kmeans:,}, but must '
f'be ≥ 1 and ≤ the number of cells in other '
f'({num_other:,})')
raise ValueError(error_message)
# Check that `chunk_size_search` is between 1 and `num_other`; if
# `None`, set to `min(256, num_other)`
if chunk_size_search is None:
chunk_size_search = min(256, num_other)
else:
check_type(chunk_size_search, 'chunk_size_search', int,
'a positive integer')
if not 1 <= chunk_size_search <= num_other:
error_message = (
f'chunk_size_search is {chunk_size_search:,}, but must be '
f'≥ 1 and ≤ the number of cells in other ({num_other:,})')
raise ValueError(error_message)
# Recode cell types so the most common is 0, the next-most common 1,
# etc. This has the effect of breaking ties by taking the most common
# cell type: we pick the first element in case of ties.
original_cell_type_column = original_cell_type_column.replace_strict(
most_common_cell_types,
pl.arange(len(most_common_cell_types), eager=True,
dtype=pl.UInt32))\
.to_numpy()
# Run k-means clustering on `other_Harmony_embeddings`. Since
# `centroids` and `centroids_new` are swapped every iteration,
# `centroids_new` will contain the final centroids when doing an odd
# number of k-means iterations; if so, swap them at the end. Use
# `threadpool_limits()` to run BLAS single-threaded (directly in the
# single-threaded case, by making everything single-threaded including
# BLAS; indirectly in the parallel case, by disabling nested
# parallelism inside `prange()`).
if num_threads == 1:
cluster_labels = np.empty(num_other, dtype=np.uint32)
min_distances = np.empty(num_other, dtype=np.float32)
cell_norms = np.empty(num_other, dtype=np.float32)
else:
cluster_labels = numa_zeros(num_other, dtype=np.uint32)
min_distances = numa_zeros(num_other, dtype=np.float32)
cell_norms = numa_zeros(num_other, dtype=np.float32)
centroids = np.empty((num_clusters, num_dimensions), dtype=np.float32)
centroids_new = np.empty((num_clusters, num_dimensions),
dtype=np.float32)
num_cells_per_cluster = np.empty(num_clusters, dtype=np.uint32)
with threadpool_limits(num_threads):
iterations_until_convergence, num_clusters = kmeans(
X=other_Harmony_embeddings, cluster_labels=cluster_labels,
centroids=centroids, centroids_new=centroids_new,
num_cells_per_cluster=num_cells_per_cluster,
min_distances=min_distances, cell_norms=cell_norms,
kmeans_barbar=kmeans_barbar,
num_init_iterations=num_init_iterations,
num_kmeans_iterations=num_kmeans_iterations,
tolerance=kmeans_tolerance,
oversampling_factor=oversampling_factor, seed=seed,
chunk_size=chunk_size_kmeans, num_threads=num_threads)
if iterations_until_convergence & 1:
centroids = centroids_new
del centroids_new, min_distances
centroids = centroids[:num_clusters]
num_cells_per_cluster = num_cells_per_cluster[:num_clusters]
# As an optimization, renumber clusters so that cluster `N` tends to be
# near cluster `N + 1` in Harmony embedding space, by sorting the
# centroids by their first embedding dimension. This reduces cache
# misses during the nearest-neighbor search.
centroid_order = np.argsort(centroids[:, 0]).astype(np.uint32)
centroids = np.ascontiguousarray(centroids[centroid_order])
num_cells_per_cluster = num_cells_per_cluster[centroid_order]
inverse_centroid_order = np.empty_like(centroid_order)
inverse_centroid_order[centroid_order] = np.arange(len(centroid_order),
dtype=np.uint32)
cluster_labels = inverse_centroid_order[cluster_labels]
# As an optimization, sort by cluster to reduce cache misses in the
# nearest-neighbor search. This also lets us skip building an explicit
# inverted file index during the nearest-neighbor search. Unlike in
# `neighbors()`, there is no need to pass `sorted_order` to the
# k-nearest neighbors function or to create the reverse mapping, since
# `self` is not being sorted, only `other`.
sorted_order = np.argsort(cluster_labels).astype(np.uint32)
del cluster_labels
if num_threads == 1:
other_Harmony_embeddings = other_Harmony_embeddings[sorted_order]
original_cell_type_column = original_cell_type_column[sorted_order]
cell_norms = cell_norms[sorted_order]
else:
other_Harmony_embeddings = parallel_subset_2d(
other_Harmony_embeddings, sorted_order, num_threads)
original_cell_type_column = parallel_subset_1d(
original_cell_type_column, sorted_order, num_threads)
cell_norms = parallel_subset_1d(cell_norms, sorted_order,
num_threads)
# Find the `num_neighbors` nearest neighbors in `other` of each cell in
# `self`, according to `self_Harmony_embeddings` and
# `other_Harmony_embeddings`. Use `threadpool_limits()` to run BLAS
# single-threaded (directly in the single-threaded case, by making
# everything single-threaded including BLAS; indirectly in the parallel
# case, by disabling nested parallelism inside `prange()`).
if num_threads == 1:
neighbors = np.empty((num_self, num_neighbors), dtype=np.uint32)
distances = np.empty((num_self, num_neighbors), dtype=np.float32)
centroid_distances = np.empty((num_self, num_clusters_searched),
dtype=np.float32)
nearest_clusters = np.empty((num_self, num_clusters_searched),
dtype=np.uint32)
query_norms = np.empty(num_self, dtype=np.float32)
else:
neighbors = numa_zeros((num_self, num_neighbors), dtype=np.uint32)
distances = numa_zeros((num_self, num_neighbors),
dtype=np.float32)
centroid_distances = numa_zeros((num_self, num_clusters_searched),
dtype=np.float32)
nearest_clusters = numa_zeros((num_self, num_clusters_searched),
dtype=np.uint32)
query_norms = numa_zeros(num_self, dtype=np.float32)
with threadpool_limits(num_threads):
knn_cross(Y=self_Harmony_embeddings, X=other_Harmony_embeddings,
centroids=centroids,
num_cells_per_cluster=num_cells_per_cluster,
cell_norms=cell_norms, neighbors=neighbors,
distances=distances,
centroid_distances=centroid_distances,
nearest_clusters=nearest_clusters,
query_norms=query_norms, num_neighbors=num_neighbors,
num_clusters_searched=num_clusters_searched,
chunk_size_kmeans=chunk_size_kmeans,
chunk_size_search=chunk_size_search,
num_threads=num_threads)
del centroid_distances, nearest_clusters, query_norms
# Get the (two) most common cell type(s) for each cell in `self` among
# its `num_neighbors` nearest neighbors in `other`. Pick the first
# element in case of ties, which according to our encoding is the most
# common cell type. Also get the cell-type confidence of each cell's
# type(s), i.e. the frequency of the cell type among the cell's
# `num_neighbors` nearest neighbors.
if num_threads == 1:
cell_types = np.empty(num_self, dtype=np.uint32)
confidences = np.empty(num_self, dtype=np.float32)
else:
cell_types = numa_zeros(num_self, dtype=np.uint32)
confidences = numa_zeros(num_self, dtype=np.float32)
if not next_best:
next_best_cell_types = np.array([], dtype=np.uint32)
next_best_confidences = np.array([], dtype=np.float32)
elif num_threads == 1:
next_best_cell_types = np.empty(num_self, dtype=np.uint32)
next_best_confidences = np.empty(num_self, dtype=np.float32)
else:
next_best_cell_types = numa_zeros(num_self, dtype=np.uint32)
next_best_confidences = numa_zeros(num_self, dtype=np.float32)
label_transfer(neighbors=neighbors,
original_cell_type_column=original_cell_type_column,
num_cell_types=len(most_common_cell_types),
cell_types=cell_types, confidences=confidences,
next_best_cell_types=next_best_cell_types,
next_best_confidences=next_best_confidences,
num_threads=num_threads)
# Map the cell-type codes back to their labels by constructing a polars
# Series from the codes, then casting it to an Enum. Also convert
# cell-type confidences to Series.
cell_types = pl.Series(cell_type_column, cell_types)\
.cast(pl.Enum(most_common_cell_types.to_list()))
confidences = pl.Series(confidence_column, confidences)
if next_best:
next_best_cell_types = \
pl.Series(next_best_cell_type_column, next_best_cell_types)\
.cast(pl.Enum(most_common_cell_types.to_list()))
next_best_confidences = \
pl.Series(next_best_confidence_column, next_best_confidences)
columns = cell_types, confidences, next_best_cell_types, \
next_best_confidences
else:
columns = cell_types, confidences
# Add the cell-type labels and confidences to `self.obs`. If
# `QC_column` was specified, back-project from QCed cells to all cells,
# filling with `null`.
if QC_column is None:
obs = self._obs.with_columns(columns)
else:
expand = lambda series: pl.when(QC_column.name)\
.then(pl.lit(series).gather(pl.col(QC_column.name).cum_sum() -
pl.col(QC_column.name)))
obs = self._obs.with_columns(map(expand, columns))
# Return a new SingleCell dataset containing the cell-type labels and
# confidences
return SingleCell(X=self._X, obs=obs, var=self._var, obsm=self._obsm,
varm=self._varm, obsp=self._obsp, varp=self._varp,
uns=self._uns, num_threads=self._num_threads)
@staticmethod
def _get_rocket_r() -> 'LinearSegmentedColormap':
"""
Define Seaborn's rocket_r colormap using base Matplotlib.
Returns:
the rocket_r colormap, as a Matplotlib `LinearSegmentedColormap`
"""
signal.signal(signal.SIGINT, signal.SIG_IGN)
try:
import matplotlib.pyplot as plt
finally:
signal.signal(signal.SIGINT, signal.default_int_handler)
rocket_colors = [
[0.01060815, 0.01808215, 0.10018654],
[0.01428972, 0.02048237, 0.10374486],
[0.01831941, 0.0229766, 0.10738511],
[0.02275049, 0.02554464, 0.11108639],
[0.02759119, 0.02818316, 0.11483751],
[0.03285175, 0.03088792, 0.11863035],
[0.03853466, 0.03365771, 0.12245873],
[0.04447016, 0.03648425, 0.12631831],
[0.05032105, 0.03936808, 0.13020508],
[0.05611171, 0.04224835, 0.13411624],
[0.0618531, 0.04504866, 0.13804929],
[0.06755457, 0.04778179, 0.14200206],
[0.0732236, 0.05045047, 0.14597263],
[0.0788708, 0.05305461, 0.14995981],
[0.08450105, 0.05559631, 0.15396203],
[0.09011319, 0.05808059, 0.15797687],
[0.09572396, 0.06050127, 0.16200507],
[0.10132312, 0.06286782, 0.16604287],
[0.10692823, 0.06517224, 0.17009175],
[0.1125315, 0.06742194, 0.17414848],
[0.11813947, 0.06961499, 0.17821272],
[0.12375803, 0.07174938, 0.18228425],
[0.12938228, 0.07383015, 0.18636053],
[0.13501631, 0.07585609, 0.19044109],
[0.14066867, 0.0778224, 0.19452676],
[0.14633406, 0.07973393, 0.1986151],
[0.15201338, 0.08159108, 0.20270523],
[0.15770877, 0.08339312, 0.20679668],
[0.16342174, 0.0851396, 0.21088893],
[0.16915387, 0.08682996, 0.21498104],
[0.17489524, 0.08848235, 0.2190294],
[0.18065495, 0.09009031, 0.22303512],
[0.18643324, 0.09165431, 0.22699705],
[0.19223028, 0.09317479, 0.23091409],
[0.19804623, 0.09465217, 0.23478512],
[0.20388117, 0.09608689, 0.23860907],
[0.20973515, 0.09747934, 0.24238489],
[0.21560818, 0.09882993, 0.24611154],
[0.22150014, 0.10013944, 0.2497868],
[0.22741085, 0.10140876, 0.25340813],
[0.23334047, 0.10263737, 0.25697736],
[0.23928891, 0.10382562, 0.2604936],
[0.24525608, 0.10497384, 0.26395596],
[0.25124182, 0.10608236, 0.26736359],
[0.25724602, 0.10715148, 0.27071569],
[0.26326851, 0.1081815, 0.27401148],
[0.26930915, 0.1091727, 0.2772502],
[0.27536766, 0.11012568, 0.28043021],
[0.28144375, 0.11104133, 0.2835489],
[0.2875374, 0.11191896, 0.28660853],
[0.29364846, 0.11275876, 0.2896085],
[0.29977678, 0.11356089, 0.29254823],
[0.30592213, 0.11432553, 0.29542718],
[0.31208435, 0.11505284, 0.29824485],
[0.31826327, 0.1157429, 0.30100076],
[0.32445869, 0.11639585, 0.30369448],
[0.33067031, 0.11701189, 0.30632563],
[0.33689808, 0.11759095, 0.3088938],
[0.34314168, 0.11813362, 0.31139721],
[0.34940101, 0.11863987, 0.3138355],
[0.355676, 0.11910909, 0.31620996],
[0.36196644, 0.1195413, 0.31852037],
[0.36827206, 0.11993653, 0.32076656],
[0.37459292, 0.12029443, 0.32294825],
[0.38092887, 0.12061482, 0.32506528],
[0.38727975, 0.12089756, 0.3271175],
[0.39364518, 0.12114272, 0.32910494],
[0.40002537, 0.12134964, 0.33102734],
[0.40642019, 0.12151801, 0.33288464],
[0.41282936, 0.12164769, 0.33467689],
[0.41925278, 0.12173833, 0.33640407],
[0.42569057, 0.12178916, 0.33806605],
[0.43214263, 0.12179973, 0.33966284],
[0.43860848, 0.12177004, 0.34119475],
[0.44508855, 0.12169883, 0.34266151],
[0.45158266, 0.12158557, 0.34406324],
[0.45809049, 0.12142996, 0.34540024],
[0.46461238, 0.12123063, 0.34667231],
[0.47114798, 0.12098721, 0.34787978],
[0.47769736, 0.12069864, 0.34902273],
[0.48426077, 0.12036349, 0.35010104],
[0.49083761, 0.11998161, 0.35111537],
[0.49742847, 0.11955087, 0.35206533],
[0.50403286, 0.11907081, 0.35295152],
[0.51065109, 0.11853959, 0.35377385],
[0.51728314, 0.1179558, 0.35453252],
[0.52392883, 0.11731817, 0.35522789],
[0.53058853, 0.11662445, 0.35585982],
[0.53726173, 0.11587369, 0.35642903],
[0.54394898, 0.11506307, 0.35693521],
[0.5506426, 0.11420757, 0.35737863],
[0.55734473, 0.11330456, 0.35775059],
[0.56405586, 0.11235265, 0.35804813],
[0.57077365, 0.11135597, 0.35827146],
[0.5774991, 0.11031233, 0.35841679],
[0.58422945, 0.10922707, 0.35848469],
[0.59096382, 0.10810205, 0.35847347],
[0.59770215, 0.10693774, 0.35838029],
[0.60444226, 0.10573912, 0.35820487],
[0.61118304, 0.10450943, 0.35794557],
[0.61792306, 0.10325288, 0.35760108],
[0.62466162, 0.10197244, 0.35716891],
[0.63139686, 0.10067417, 0.35664819],
[0.63812122, 0.09938212, 0.35603757],
[0.64483795, 0.0980891, 0.35533555],
[0.65154562, 0.09680192, 0.35454107],
[0.65824241, 0.09552918, 0.3536529],
[0.66492652, 0.09428017, 0.3526697],
[0.67159578, 0.09306598, 0.35159077],
[0.67824099, 0.09192342, 0.3504148],
[0.684863, 0.09085633, 0.34914061],
[0.69146268, 0.0898675, 0.34776864],
[0.69803757, 0.08897226, 0.3462986],
[0.70457834, 0.0882129, 0.34473046],
[0.71108138, 0.08761223, 0.3430635],
[0.7175507, 0.08716212, 0.34129974],
[0.72398193, 0.08688725, 0.33943958],
[0.73035829, 0.0868623, 0.33748452],
[0.73669146, 0.08704683, 0.33543669],
[0.74297501, 0.08747196, 0.33329799],
[0.74919318, 0.08820542, 0.33107204],
[0.75535825, 0.08919792, 0.32876184],
[0.76145589, 0.09050716, 0.32637117],
[0.76748424, 0.09213602, 0.32390525],
[0.77344838, 0.09405684, 0.32136808],
[0.77932641, 0.09634794, 0.31876642],
[0.78513609, 0.09892473, 0.31610488],
[0.79085854, 0.10184672, 0.313391],
[0.7965014, 0.10506637, 0.31063031],
[0.80205987, 0.10858333, 0.30783],
[0.80752799, 0.11239964, 0.30499738],
[0.81291606, 0.11645784, 0.30213802],
[0.81820481, 0.12080606, 0.29926105],
[0.82341472, 0.12535343, 0.2963705],
[0.82852822, 0.13014118, 0.29347474],
[0.83355779, 0.13511035, 0.29057852],
[0.83850183, 0.14025098, 0.2876878],
[0.84335441, 0.14556683, 0.28480819],
[0.84813096, 0.15099892, 0.281943],
[0.85281737, 0.15657772, 0.27909826],
[0.85742602, 0.1622583, 0.27627462],
[0.86196552, 0.16801239, 0.27346473],
[0.86641628, 0.17387796, 0.27070818],
[0.87079129, 0.17982114, 0.26797378],
[0.87507281, 0.18587368, 0.26529697],
[0.87925878, 0.19203259, 0.26268136],
[0.8833417, 0.19830556, 0.26014181],
[0.88731387, 0.20469941, 0.25769539],
[0.89116859, 0.21121788, 0.2553592],
[0.89490337, 0.21785614, 0.25314362],
[0.8985026, 0.22463251, 0.25108745],
[0.90197527, 0.23152063, 0.24918223],
[0.90530097, 0.23854541, 0.24748098],
[0.90848638, 0.24568473, 0.24598324],
[0.911533, 0.25292623, 0.24470258],
[0.9144225, 0.26028902, 0.24369359],
[0.91717106, 0.26773821, 0.24294137],
[0.91978131, 0.27526191, 0.24245973],
[0.92223947, 0.28287251, 0.24229568],
[0.92456587, 0.29053388, 0.24242622],
[0.92676657, 0.29823282, 0.24285536],
[0.92882964, 0.30598085, 0.24362274],
[0.93078135, 0.31373977, 0.24468803],
[0.93262051, 0.3215093, 0.24606461],
[0.93435067, 0.32928362, 0.24775328],
[0.93599076, 0.33703942, 0.24972157],
[0.93752831, 0.34479177, 0.25199928],
[0.93899289, 0.35250734, 0.25452808],
[0.94036561, 0.36020899, 0.25734661],
[0.94167588, 0.36786594, 0.2603949],
[0.94291042, 0.37549479, 0.26369821],
[0.94408513, 0.3830811, 0.26722004],
[0.94520419, 0.39062329, 0.27094924],
[0.94625977, 0.39813168, 0.27489742],
[0.94727016, 0.4055909, 0.27902322],
[0.94823505, 0.41300424, 0.28332283],
[0.94914549, 0.42038251, 0.28780969],
[0.95001704, 0.42771398, 0.29244728],
[0.95085121, 0.43500005, 0.29722817],
[0.95165009, 0.44224144, 0.30214494],
[0.9524044, 0.44944853, 0.3072105],
[0.95312556, 0.45661389, 0.31239776],
[0.95381595, 0.46373781, 0.31769923],
[0.95447591, 0.47082238, 0.32310953],
[0.95510255, 0.47787236, 0.32862553],
[0.95569679, 0.48489115, 0.33421404],
[0.95626788, 0.49187351, 0.33985601],
[0.95681685, 0.49882008, 0.34555431],
[0.9573439, 0.50573243, 0.35130912],
[0.95784842, 0.51261283, 0.35711942],
[0.95833051, 0.51946267, 0.36298589],
[0.95879054, 0.52628305, 0.36890904],
[0.95922872, 0.53307513, 0.3748895],
[0.95964538, 0.53983991, 0.38092784],
[0.96004345, 0.54657593, 0.3870292],
[0.96042097, 0.55328624, 0.39319057],
[0.96077819, 0.55997184, 0.39941173],
[0.9611152, 0.5666337, 0.40569343],
[0.96143273, 0.57327231, 0.41203603],
[0.96173392, 0.57988594, 0.41844491],
[0.96201757, 0.58647675, 0.42491751],
[0.96228344, 0.59304598, 0.43145271],
[0.96253168, 0.5995944, 0.43805131],
[0.96276513, 0.60612062, 0.44471698],
[0.96298491, 0.6126247, 0.45145074],
[0.96318967, 0.61910879, 0.45824902],
[0.96337949, 0.6255736, 0.46511271],
[0.96355923, 0.63201624, 0.47204746],
[0.96372785, 0.63843852, 0.47905028],
[0.96388426, 0.64484214, 0.4861196],
[0.96403203, 0.65122535, 0.4932578],
[0.96417332, 0.65758729, 0.50046894],
[0.9643063, 0.66393045, 0.5077467],
[0.96443322, 0.67025402, 0.51509334],
[0.96455845, 0.67655564, 0.52251447],
[0.96467922, 0.68283846, 0.53000231],
[0.96479861, 0.68910113, 0.53756026],
[0.96492035, 0.69534192, 0.5451917],
[0.96504223, 0.7015636, 0.5528892],
[0.96516917, 0.70776351, 0.5606593],
[0.96530224, 0.71394212, 0.56849894],
[0.96544032, 0.72010124, 0.57640375],
[0.96559206, 0.72623592, 0.58438387],
[0.96575293, 0.73235058, 0.59242739],
[0.96592829, 0.73844258, 0.60053991],
[0.96612013, 0.74451182, 0.60871954],
[0.96632832, 0.75055966, 0.61696136],
[0.96656022, 0.75658231, 0.62527295],
[0.96681185, 0.76258381, 0.63364277],
[0.96709183, 0.76855969, 0.64207921],
[0.96739773, 0.77451297, 0.65057302],
[0.96773482, 0.78044149, 0.65912731],
[0.96810471, 0.78634563, 0.66773889],
[0.96850919, 0.79222565, 0.6764046],
[0.96893132, 0.79809112, 0.68512266],
[0.96935926, 0.80395415, 0.69383201],
[0.9698028, 0.80981139, 0.70252255],
[0.97025511, 0.81566605, 0.71120296],
[0.97071849, 0.82151775, 0.71987163],
[0.97120159, 0.82736371, 0.72851999],
[0.97169389, 0.83320847, 0.73716071],
[0.97220061, 0.83905052, 0.74578903],
[0.97272597, 0.84488881, 0.75440141],
[0.97327085, 0.85072354, 0.76299805],
[0.97383206, 0.85655639, 0.77158353],
[0.97441222, 0.86238689, 0.78015619],
[0.97501782, 0.86821321, 0.78871034],
[0.97564391, 0.87403763, 0.79725261],
[0.97628674, 0.87986189, 0.8057883],
[0.97696114, 0.88568129, 0.81430324],
[0.97765722, 0.89149971, 0.82280948],
[0.97837585, 0.89731727, 0.83130786],
[0.97912374, 0.90313207, 0.83979337],
[0.979891, 0.90894778, 0.84827858],
[0.98067764, 0.91476465, 0.85676611],
[0.98137749, 0.92061729, 0.86536915]]
return plt.matplotlib.colors.LinearSegmentedColormap.from_list(
'rocket_r', rocket_colors[::-1])
[docs]
def plot_heatmap(self,
                 x: SingleCellColumn,
                 y: SingleCellColumn,
                 filename: str | Path | None = None,
                 /,
                 *,
                 cells_to_plot_column: SingleCellColumn |
                                       None = 'passed_QC',
                 normalize_rows: bool = False,
                 normalize_columns: bool = False,
                 ax: 'Axes' | None = None,
                 figure_kwargs: dict[str, Any] | None = None,
                 colormap: str | 'Colormap' | None = None,
                 heatmap_kwargs: dict[str, Any] | None = None,
                 label: bool = False,
                 label_format: str | None = None,
                 label_kwargs: dict[str, Any] | None = None,
                 colorbar: bool = True,
                 colorbar_kwargs: dict[str, Any] | None = None,
                 title: str | None = None,
                 title_kwargs: dict[str, Any] | None = None,
                 xlabel: str | Literal[True] | None = True,
                 xlabel_kwargs: dict[str, Any] | None = None,
                 ylabel: str | Literal[True] | None = True,
                 ylabel_kwargs: dict[str, Any] | None = None,
                 despine: bool = True,
                 savefig_kwargs: dict[str, Any] | None = None) -> None:
    """
    Plot a heatmap of the count of each combination of two categorical
    columns, `x` and `y`. If `normalize_rows` or `normalize_columns` is
    specified, plot percentages instead of counts, so that each row or
    column sums to 100%.

    Args:
        x: the first column; must be String, Enum, Categorical, or integer
        y: the second column; must be String, Enum, Categorical, or integer
        filename: the file to save to. If `None`, generate the plot but do
                  not save it, which allows it to be shown interactively or
                  modified further before saving.
        cells_to_plot_column: an optional Boolean column of `obs`
                              indicating which cells to plot. Can be a
                              column name, a polars expression, a polars
                              Series, a 1D NumPy array, or a function that
                              takes in this SingleCell dataset and returns
                              a polars Series or 1D NumPy array. Set to
                              `None` to plot all cells.
        normalize_rows: whether to plot percentages instead of counts, so
                        that each row sums to 100%. Mutually exclusive with
                        `normalize_columns`.
        normalize_columns: whether to plot percentages instead of counts,
                           so that each column sums to 100%. Mutually
                           exclusive with `normalize_rows`.
        ax: the Matplotlib axes to save the plot onto; if `None`, create a
            new figure with Matplotlib's constrained layout and plot onto
            it
        figure_kwargs: a dictionary of keyword arguments to be passed to
                       `plt.figure` when `ax` is `None`, such as:
                       - `figsize`: a two-element sequence of the width and
                         height of the figure in inches. The default is
                         computed from the number of rows and columns of
                         the heatmap (and whether a colorbar is drawn),
                         unlike Matplotlib's default of `[6.4, 4.8]`.
                       - `layout`: the layout mechanism used by Matplotlib
                         to avoid overlapping plot elements. Defaults to
                         `'constrained'`, instead of Matplotlib's default
                         of `None`.
        colormap: a string or Colormap object indicating the Matplotlib
                  colormap to use in the heatmap, or `None` to use
                  Seaborn's `'rocket_r'` colormap.
        heatmap_kwargs: a dictionary of keyword arguments to be passed to
                        `ax.pcolormesh()` when generating the heatmap, such
                        as:
                        - `rasterized`: whether to convert the heatmap
                          cells to a raster (bitmap) image when saving to
                          a vector format like PDF. Defaults to `True`,
                          instead of Matplotlib's default of `False`.
                        - `norm`, `vmin`, and `vmax`: control how the
                          colormap maps counts or percentages to heatmap
                          colors
                        - `edgecolors`: the border color of each heatmap
                          cell; defaults to `'none'`, meaning no borders.
                        Specifying `cmap` will raise an error, since it
                        conflicts with the `colormap` argument.
        label: whether to label each cell of the heatmap with its count
               (or percentage, if `normalize_rows=True` or
               `normalize_columns=True`)
        label_format: a format string to apply to the value of each
                      heatmap cell. If `None`, use `'{:,}'` for counts and
                      `'{:.2%}'` for percentages. When normalizing, the
                      value being formatted is a fraction between 0 and 1,
                      so use a `%` presentation type (e.g. `'{:.1%}'`) to
                      display it as a percentage. Can only be specified
                      when `label=True`.
        label_kwargs: a dictionary of keyword arguments to be passed to
                      `ax.text()` when adding labels to control the text
                      properties, such as:
                      - `color` and `size` to modify the text color/size.
                        By default, the color is dark gray for
                        light-colored cells, and white for dark-colored
                        ones.
                      - `verticalalignment` and `horizontalalignment` to
                        control vertical and horizontal alignment. By
                        default, unlike Matplotlib, these are both set to
                        `'center'`.
                      Can only be specified when `label=True`.
        colorbar: whether to add a colorbar
        colorbar_kwargs: a dictionary of keyword arguments to be passed to
                         `plt.colorbar()`, such as:
                         - `location`: `'left'`, `'right'`, `'top'`, or
                           `'bottom'`
                         - `orientation`: `'vertical'` or `'horizontal'`
                         - `fraction`: the fraction of the axes to
                           allocate to the colorbar. Defaults to 0.15.
                         - `shrink`: the fraction to multiply the size of
                           the colorbar by. Defaults to 0.5, instead of
                           Matplotlib's default of 1.
                         - `aspect`: the ratio of the colorbar's long to
                           short dimensions. Defaults to 20.
                         - `pad`: the fraction of the axes between the
                           colorbar and the rest of the figure. Defaults to
                           0.01, instead of Matplotlib's default of 0.05 if
                           vertical and 0.15 if horizontal.
                         Can only be specified when `colorbar=True`.
        title: the title of the plot, or `None` to not add a title
        title_kwargs: a dictionary of keyword arguments to be passed to
                      `ax.set_title()` to control text properties, such as
                      `color` and `size`. Can only be specified when
                      `title` is not `None`.
        xlabel: the x-axis label, `True` to use the name of `x` as the
                x-axis label, or `None` to not label the x-axis
        xlabel_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.set_xlabel()` to control text properties, such
                       as `color` and `size`. Can only be specified when
                       `xlabel` is not `None`.
        ylabel: the y-axis label, `True` to use the name of `y` as the
                y-axis label, or `None` to not label the y-axis
        ylabel_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.set_ylabel()` to control text properties, such
                       as `color` and `size`. Can only be specified when
                       `ylabel` is not `None`.
        despine: whether to remove the spines (borders of the plot area)
                 from the plot; unlike the other plotting functions in this
                 library, this also removes the left and bottom spines
        savefig_kwargs: a dictionary of keyword arguments to be passed to
                        `plt.savefig()`, such as:
                        - `dpi`: defaults to 300 instead of Matplotlib's
                          default of `'figure'` (the figure's own DPI)
                        - `bbox_inches`: the bounding box of the portion of
                          the figure to save; defaults to `'tight'` (crop
                          out any blank borders) instead of Matplotlib's
                          default of `None` (save the entire figure)
                        - `pad_inches`: the number of inches of padding to
                          add on each of the four sides of the figure when
                          saving. Defaults to `'layout'` (use the padding
                          from the constrained layout engine, when `ax` is
                          `None`) instead of Matplotlib's default of 0.1.
                        - `transparent`: whether to save with a transparent
                          background; defaults to `True` if saving to a PDF
                          (i.e. when `filename` ends with `'.pdf'`) and
                          `False` otherwise, instead of Matplotlib's
                          default of always being `False`.
                        Can only be specified when `filename` is specified.

    Raises:
        NotADirectoryError: if the directory of `filename` does not exist
        ValueError: if mutually exclusive or dependent arguments are
                    combined incorrectly, if `label_format` has no curly
                    braces or has mismatched ones, or if `heatmap_kwargs`
                    contains `'cmap'`
        TypeError: if a kwargs dictionary contains a non-string key
    """
    # Import matplotlib; ignore Ctrl + C while the import is in progress
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        import matplotlib.pyplot as plt
    finally:
        signal.signal(signal.SIGINT, signal.default_int_handler)

    # Get `cells_to_plot_column`, if not `None`. The default column name is
    # passed with `allow_missing=True` (presumably so that datasets without
    # a `passed_QC` column still work; confirm against `_get_column`).
    if cells_to_plot_column is not None:
        cells_to_plot_column = self._get_column(
            'obs', cells_to_plot_column, 'cells_to_plot_column',
            pl.Boolean,
            allow_missing=isinstance(cells_to_plot_column, str) and
                          cells_to_plot_column == 'passed_QC')

    # Get `x` and `y`, and check that they are String, Enum, Categorical,
    # or integer
    x = self._get_column('obs', x, 'x',
                         (pl.String, pl.Enum, pl.Categorical, 'integer'),
                         QC_column=cells_to_plot_column)
    y = self._get_column('obs', y, 'y',
                         (pl.String, pl.Enum, pl.Categorical, 'integer'),
                         QC_column=cells_to_plot_column)
    if cells_to_plot_column is not None:
        x = x.filter(cells_to_plot_column)
        y = y.filter(cells_to_plot_column)

    # If `filename` was specified, check that it is a string or
    # `pathlib.Path` and that its base directory exists; if `filename` is
    # `None`, make sure `savefig_kwargs` is also `None`
    if filename is not None:
        check_type(filename, 'filename', (str, Path),
                   'a string or pathlib.Path')
        directory = os.path.dirname(filename)
        if directory and not os.path.isdir(directory):
            error_message = (
                f'{filename} refers to a file in the directory '
                f'{directory!r}, but this directory does not exist')
            raise NotADirectoryError(error_message)
        filename = str(filename)
    elif savefig_kwargs is not None:
        error_message = 'savefig_kwargs must be None when filename is None'
        raise ValueError(error_message)

    # Check that `normalize_rows`, `normalize_columns`, `label`,
    # `colorbar`, and `despine` are Boolean
    check_type(normalize_rows, 'normalize_rows', bool, 'Boolean')
    check_type(normalize_columns, 'normalize_columns', bool, 'Boolean')
    check_type(label, 'label', bool, 'Boolean')
    check_type(colorbar, 'colorbar', bool, 'Boolean')
    check_type(despine, 'despine', bool, 'Boolean')

    # Check that `normalize_rows` and `normalize_columns` are mutually
    # exclusive
    if normalize_rows and normalize_columns:
        error_message = \
            'only one of normalize_rows and normalize_columns can be True'
        raise ValueError(error_message)
    normalize = normalize_rows or normalize_columns

    # If `figure_kwargs` was specified, check that `ax` is `None`
    if figure_kwargs is not None and ax is not None:
        error_message = (
            'figure_kwargs must be None when ax is not None, since a new '
            'figure does not need to be generated when plotting onto an '
            'existing axis')
        raise ValueError(error_message)

    # Check that `colormap` is a string in `plt.colormaps`, a Colormap
    # object, or `None`; if `None`, default to Seaborn's rocket_r colormap
    if colormap is None:
        colormap = SingleCell._get_rocket_r()
    else:
        check_type(colormap, 'colormap',
                   (str, plt.matplotlib.colors.Colormap),
                   'a string or matplotlib Colormap object')
        if isinstance(colormap, str):
            colormap = plt.colormaps[colormap]

    # If `label=False`, check that `label_format` and `label_kwargs` are
    # `None`.
    if not label:
        if label_format is not None:
            error_message = 'label_format must be None when label=False'
            raise ValueError(error_message)
        if label_kwargs is not None:
            error_message = 'label_kwargs must be None when label=False'
            raise ValueError(error_message)

    # If not `None`, check that `label_format` is a valid format string.
    # For simplicity, just check that it has curly braces and that all
    # braces are matched. If `label_format` is `None`, use `'{:,}'` for
    # counts or, if normalizing, `'{:.2%}'`: the normalized heatmap values
    # are fractions in [0, 1], so the `%` presentation type is what turns
    # them into percentages (`'{:.2f}%'` would label 25% as "0.25%").
    if label_format is None:
        label_format = '{:.2%}' if normalize else '{:,}'
    else:
        check_type(label_format, 'label_format', str, 'a string')
        open_braces = 0
        has_braces = False
        for char in label_format:
            if char == '{':
                has_braces = True
                open_braces += 1
            elif char == '}':
                open_braces -= 1
                # A closing brace with no matching opening brace
                if open_braces < 0:
                    error_message = \
                        'label_format contains mismatched curly braces'
                    raise ValueError(error_message)
        # Any opening braces left over were never closed
        if open_braces != 0:
            error_message = 'label_format contains mismatched curly braces'
            raise ValueError(error_message)
        if not has_braces:
            error_message = 'label_format must contain curly braces'
            raise ValueError(error_message)

    # If `colorbar=False`, check that `colorbar_kwargs` is None
    if not colorbar and colorbar_kwargs is not None:
        error_message = 'colorbar_kwargs must be None when colorbar=False'
        raise ValueError(error_message)

    # Check that `title` is a string or `None`; if `None`, check that
    # `title_kwargs` is `None` as well.
    if title is not None:
        check_type(title, 'title', str, 'a string')
    elif title_kwargs is not None:
        error_message = 'title_kwargs must be None when title is None'
        raise ValueError(error_message)

    # Check that `xlabel` is a string, `True` (in which case set it to
    # `x.name`), or `None`; if `None`, check that `xlabel_kwargs` is `None`
    # as well. Ditto for `ylabel`.
    if xlabel is not None:
        if xlabel is True:
            xlabel = x.name
        else:
            check_type(xlabel, 'xlabel', str, 'a string')
    elif xlabel_kwargs is not None:
        error_message = 'xlabel_kwargs must be None when xlabel is None'
        raise ValueError(error_message)
    if ylabel is not None:
        if ylabel is True:
            ylabel = y.name
        else:
            check_type(ylabel, 'ylabel', str, 'a string')
    elif ylabel_kwargs is not None:
        error_message = 'ylabel_kwargs must be None when ylabel is None'
        raise ValueError(error_message)

    # For each of the kwargs arguments, if the argument was specified,
    # check that it is a dictionary and that all its keys are strings.
    for kwargs, kwargs_name in ((figure_kwargs, 'figure_kwargs'),
                                (heatmap_kwargs, 'heatmap_kwargs'),
                                (label_kwargs, 'label_kwargs'),
                                (colorbar_kwargs, 'colorbar_kwargs'),
                                (title_kwargs, 'title_kwargs'),
                                (xlabel_kwargs, 'xlabel_kwargs'),
                                (ylabel_kwargs, 'ylabel_kwargs'),
                                (savefig_kwargs, 'savefig_kwargs')):
        if kwargs is not None:
            check_type(kwargs, kwargs_name, dict, 'a dictionary')
            for key in kwargs:
                if not isinstance(key, str):
                    error_message = (
                        f'all keys of {kwargs_name} must be strings, but '
                        f'it contains a key of type '
                        f'{type(key).__name__!r}')
                    raise TypeError(error_message)

    # Override the defaults for certain values of `heatmap_kwargs`; if
    # specified, check that `heatmap_kwargs` does not contain the `cmap`
    # argument
    default_heatmap_kwargs = dict(rasterized=True)
    if heatmap_kwargs is None:
        heatmap_kwargs = default_heatmap_kwargs
    else:
        if 'cmap' in heatmap_kwargs:
            error_message = (
                f"'cmap' cannot be specified as a key in heatmap_kwargs; "
                f"specify the colormap argument instead")
            raise ValueError(error_message)
        # The user's values go on the right so that they take precedence
        # over the defaults
        heatmap_kwargs = default_heatmap_kwargs | heatmap_kwargs

    # Get the heatmap data: one row per value of `y`, one column per value
    # of `x`; the first column of `count` holds the `y` values
    count = pl.DataFrame((x, y))\
        .group_by(pl.all(), maintain_order=True)\
        .len(name='_SingleCell_count')\
        .pivot(index=y.name, columns=x.name, values='_SingleCell_count')\
        .fill_null(0)
    heatmap_data = count[:, 1:].to_numpy()

    # Normalize, if `normalize_rows=True` or `normalize_columns=True`
    if normalize_rows:
        heatmap_data = \
            heatmap_data / heatmap_data.sum(axis=1, keepdims=True)
    elif normalize_columns:
        heatmap_data = \
            heatmap_data / heatmap_data.sum(axis=0, keepdims=True)

    # If `ax` is `None`, create a new figure; otherwise, check that it is a
    # Matplotlib axis
    make_new_figure = ax is None
    try:
        num_rows, num_columns = heatmap_data.shape
        if make_new_figure:
            default_figure_kwargs = dict(layout='constrained')
            if figure_kwargs is None or 'figsize' not in figure_kwargs:
                if colorbar:
                    width_ratio = max(4, 0.2 * num_columns)
                    width = 6.4 / (1 + width_ratio) + \
                        6.4 * width_ratio / (1 + width_ratio) * \
                        max(num_columns, 5) / 20
                else:
                    width = 6.4 * max(num_columns, 5) / 20
                height = max(4.8, 1 + 3.8 * num_rows / 20)
                default_figure_kwargs['figsize'] = width, height
            figure_kwargs = default_figure_kwargs | figure_kwargs \
                if figure_kwargs is not None else default_figure_kwargs
            plt.figure(**figure_kwargs)
            ax = plt.gca()
        else:
            check_type(ax, 'ax', plt.Axes, 'a Matplotlib axis')

        # Make the heatmap; ticks sit at the center of each cell
        xticks = np.arange(0.5, num_columns)
        yticks = np.arange(0.5, num_rows)
        heatmap = \
            ax.pcolormesh(heatmap_data, cmap=colormap, **heatmap_kwargs)
        ax.set_xticks(xticks, count.columns[1:], rotation=90)
        ax.set_yticks(yticks, count[:, 0].to_numpy())
        ax.set_aspect('equal')

        # Add the colorbar; override the defaults for certain keys of
        # `colorbar_kwargs`. If normalizing rows or columns, make the
        # colorbar ticks percentages.
        if colorbar:
            default_colorbar_kwargs = dict(shrink=0.5, pad=0.01)
            colorbar_kwargs = default_colorbar_kwargs | colorbar_kwargs \
                if colorbar_kwargs is not None else default_colorbar_kwargs
            cbar = plt.colorbar(heatmap, ax=ax, **colorbar_kwargs)
            cbar.outline.set_visible(False)
            if normalize:
                cbar.ax.yaxis.set_major_formatter(plt.FuncFormatter(
                    lambda value, _pos: f'{100 * value:.0f}%'))

        # Add labels; this code is edited from `_annotate_heatmap()` at
        # github.com/mwaskom/seaborn/blob/master/seaborn/matrix.py
        if label:
            heatmap.update_scalarmappable()
            xpos, ypos = np.meshgrid(xticks, yticks)
            # Work on a copy so the caller's dictionary is not modified
            label_kwargs = {} if label_kwargs is None else \
                dict(label_kwargs)
            # Always consume the short aliases; the long names take
            # precedence when both are given
            short_ha = label_kwargs.pop('ha', 'center')
            short_va = label_kwargs.pop('va', 'center')
            label_kwargs['horizontalalignment'] = \
                label_kwargs.pop('horizontalalignment', short_ha)
            label_kwargs['verticalalignment'] = \
                label_kwargs.pop('verticalalignment', short_va)
            if 'c' in label_kwargs or 'color' in label_kwargs:
                # Use the same color for all labels
                for x_pos, y_pos, value in zip(xpos.ravel(), ypos.ravel(),
                                               heatmap_data.ravel()):
                    ax.text(x_pos, y_pos, s=label_format.format(value),
                            **label_kwargs)
            else:
                # Use either dark gray or white for the label, depending on
                # the cell's luminance
                rgb_weights = np.array([0.2126, 0.7152, 0.0722])
                for x_pos, y_pos, color, value in zip(
                        xpos.ravel(), ypos.ravel(),
                        heatmap.get_facecolors(), heatmap_data.ravel()):
                    rgb = plt.matplotlib.colors.to_rgba_array(color)[:, :3]
                    rgb = np.where(rgb <= 0.03928, rgb / 12.92,
                                   ((rgb + 0.055) / 1.055) ** 2.4)
                    lum = rgb.dot(rgb_weights).item()
                    ax.text(x_pos, y_pos, s=label_format.format(value),
                            c='.15' if lum > .408 else 'w', **label_kwargs)

        # Add the title and axis labels
        if xlabel is not None:
            if xlabel_kwargs is None:
                xlabel_kwargs = {}
            ax.set_xlabel(xlabel, **xlabel_kwargs)
        if ylabel is not None:
            if ylabel_kwargs is None:
                ylabel_kwargs = {}
            ax.set_ylabel(ylabel, **ylabel_kwargs)
        if title is not None:
            if title_kwargs is None:
                title_kwargs = {}
            ax.set_title(title, **title_kwargs)

        # Despine, if specified
        if despine:
            spines = ax.spines
            for direction in 'top', 'bottom', 'left', 'right':
                spines[direction].set_visible(False)

        # Save; override the defaults for certain keys of `savefig_kwargs`
        if filename is not None:
            default_savefig_kwargs = \
                dict(dpi=300, bbox_inches='tight', pad_inches='layout',
                     transparent=filename.endswith('.pdf'))
            savefig_kwargs = default_savefig_kwargs | savefig_kwargs \
                if savefig_kwargs is not None else default_savefig_kwargs
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', UserWarning)
                plt.savefig(filename, **savefig_kwargs)
            if make_new_figure:
                plt.close()
    except BaseException:
        # If we made a new figure, make sure to close it if there's an
        # exception (but not if there was no error and `filename` is
        # `None`, in case the user wants to modify it further before
        # saving). `BaseException` so that Ctrl + C also triggers cleanup;
        # the exception is always re-raised.
        if make_new_figure:
            plt.close()
        raise
@staticmethod
def _process_cell_types(cell_types: str | Iterable[str] | None,
                        excluded_cell_types: str | Iterable[str] | None,
                        cell_type_column: pl.Series) -> pl.Series:
    """
    Process the `cell_types` and `excluded_cell_types` arguments of various
    SingleCell functions.

    Args:
        cell_types: one or more cell types to include in the calling
                    function's operation
        excluded_cell_types: one or more cell types to exclude from the
                             calling function's operation
        cell_type_column: a column containing cell-type labels

    Returns:
        The cell-type column, with cell types not present in `cell_types`
        OR present in `excluded_cell_types` set to `null`. If both
        arguments are `None`, the column is returned unchanged.

    Raises:
        ValueError: if both `cell_types` and `excluded_cell_types` are
                    specified, if a requested cell type is not present in
                    `cell_type_column`, or if `excluded_cell_types` would
                    exclude every cell type.
    """
    # Work out which of the two (mutually exclusive) arguments was given;
    # both are validated identically and differ only in whether the listed
    # cell types are kept or removed
    if cell_types is not None:
        if excluded_cell_types is not None:
            error_message = (
                'cell_types and excluded_cell_types cannot both be '
                'specified')
            raise ValueError(error_message)
        requested = cell_types
        argument_name = 'cell_types'
        excluding = False
    elif excluded_cell_types is not None:
        requested = excluded_cell_types
        argument_name = 'excluded_cell_types'
        excluding = True
    else:
        return cell_type_column

    # Remember whether a single string was passed, to tailor error messages
    is_string = isinstance(requested, str)

    # Integer columns take integer cell types; all others take strings
    if cell_type_column.dtype.is_integer():
        requested = to_tuple_checked(requested, argument_name, int,
                                     'integers')
    else:
        requested = to_tuple_checked(requested, argument_name, str,
                                     'strings')

    # Get the cell types available in the column. Nulls are dropped because
    # a missing label is not a cell type, and would otherwise inflate the
    # count used by the "everything was excluded" check below.
    if cell_type_column.dtype == pl.Enum or \
            cell_type_column.dtype == pl.Categorical:
        unique_cell_types = cell_type_column.cat.get_categories()
    else:
        unique_cell_types = \
            cell_type_column.unique(maintain_order=True).drop_nulls()

    # Every requested cell type must exist in the column
    for cell_type in requested:
        if cell_type not in unique_cell_types:
            if is_string:
                error_message = (
                    f'{argument_name} is {cell_type!r}, which is not a '
                    f'cell type in cell_type_column')
            else:
                error_message = (
                    f'{argument_name} contains a cell type, '
                    f'{cell_type!r}, not present in cell_type_column')
            raise ValueError(error_message)

    # Count distinct excluded cell types, so that repeated entries cannot
    # cause a false "all excluded" error or hide a genuine one. Every
    # requested type is known to be in `unique_cell_types` at this point,
    # so equal sizes mean that nothing would be left.
    if excluding and len(set(requested)) == len(unique_cell_types):
        error_message = \
            'all cell types were excluded by excluded_cell_types'
        raise ValueError(error_message)

    # Null out the cells that are not selected (`when` without `otherwise`
    # gives null where the condition is false)
    selected = pl.first().is_in(requested)
    if excluding:
        selected = ~selected
    return cell_type_column\
        .to_frame()\
        .select(pl.when(selected).then(pl.first()))\
        .to_series()
[docs]
def find_markers(self,
cell_type_column: SingleCellColumn,
/,
*,
QC_column: SingleCellColumn | None = 'passed_QC',
cell_types: str | Iterable[str] | int | Iterable[int] |
None = None,
excluded_cell_types: str | Iterable[str] | int |
Iterable[int] | None = None,
min_detection_rate: int | float | np.integer |
np.floating = 0.25,
min_fold_change: int | float | np.integer |
np.floating = 2,
pareto: bool = True,
all_genes: bool = False,
num_threads: int | np.integer | None = None) -> \
pl.DataFrame:
"""
Find "marker genes" that distinguish each cell type from all other cell
types. This function gives the same result regardless of whether it is
run before or after normalization.
Marker genes are chosen via an adaptation of the strategy of Fischer
and Gillis 2021 (ncbi.nlm.nih.gov/pmc/articles/PMC8571500). For a given
cell type, genes are scored based on a) their "detection rate" in that
cell type (the fraction of cells of that type that have non-zero count
for that gene), as well as b) the fold change in detection rate between
that cell type and every other cell type. Genes must also have a
detection rate of at least `min_detection_rate` (25% by default) and a
minimum fold change of at least `min_fold_change` (2-fold by default)
to be considered as markers.
There is an inherent tradeoff between these two metrics. For instance,
candidate marker genes with high enough expression to be expressed in
every cell of a given type (i.e. to have a high detection rate) tend to
also have at least some expression in other cell types (i.e. a low fold
change in detection rate).
Thus, marker genes are selected to optimally trade off between these
two metrics: all genes on the Pareto front of the two metrics (i.e.
genes for which there is no other gene that does better on both metrics
simultaneously) are selected as marker genes.
Note that Fischer and Gillis use AUROC versus log2 fold change in
detection rate, instead of detection rate versus fold change in
detection rate. However, detection rate is much faster to compute than
AUROC, and is a very accurate proxy for AUROC: as Figure 1D in their
paper shows, AUROC is almost perfectly correlated with detection rate
across marker genes.
Args:
cell_type_column: a String, Categorical, Enum, or integer column of
`obs` containing cell-type labels. Can be a
column name, a polars expression, a polars
Series, a 1D NumPy array, or a function that
takes in this SingleCell dataset and returns a
polars Series or 1D NumPy array.
QC_column: an optional Boolean column of `obs` indicating which
cells passed QC. Can be a column name, a polars
expression, a polars Series, a 1D NumPy array, or a
function that takes in this SingleCell dataset and
returns a polars Series or 1D NumPy array. Set to `None`
to include all cells. Cells failing QC will be ignored.
cell_types: one or more cell types to find markers for; by default,
finds markers for all cell types in `cell_type_column`.
Specifying `cell_types` is exactly equivalent to
filtering the result to these cell types, but will be
faster when there are many cell types and markers are
only desired for a few of them. Can also be used to
change the order in which cell types are reported, even
if finding markers for all cell types. Mutually
exclusive with `excluded_cell_types`.
excluded_cell_types: one or more cell types to exclude from marker
finding. Mutually exclusive with `cell_types`.
min_detection_rate: the minimum detection rate required to select
a gene as a marker gene; must be greater than 0
and less than or equal to 1
min_fold_change: the minimum fold change in detection rate required
to select a gene as a marker gene; must be greater
than 1
pareto: if `True`, include only genes on the Pareto front of
detection rate and fold change as markers; if `False`,
include all genes that pass the `min_detection_rate` and
`min_fold_change` thresholds as markers
all_genes: if `True`, include all genes in the output, not just
marker genes. An additional Boolean column will be
included to specify which genes are the marker genes.
Note that this option does not change which marker genes
are selected, only which information is returned.
num_threads: the number of threads to use for marker-gene finding.
Set `num_threads=-1` to use all available cores, as
determined by `os.cpu_count()`, or leave unset to use
`self.num_threads` cores.
Returns:
By default, a DataFrame with one row per marker gene, with columns:
- `'cell_type'`: a cell-type name from `cell_type_column`
- `'gene'`: a gene symbol from `var_names`
- `'detection_rate'`: the gene's detection rate in that cell type
- `'fold_change'`: the gene's fold change in detection rate
between that cell type and all other cell types
If `all_genes=True`, a DataFrame with one row per cell type-gene
pair, with those four columns plus one other:
- `'marker'`: a Boolean column listing whether the gene is a marker
for that cell type
If `all_genes=False`, marker genes within each cell type will be
sorted in decreasing order of fold change.
Note:
This function may give an incorrect output if the count matrix
contains explicit zeros (i.e. if `(sc.X.data == 0).any()`): this is
not checked for, due to speed considerations. In the unlikely event
that your dataset contains explicit zeros, remove them by running
`sc.X.eliminate_zeros()` (an in-place operation) first.
Note:
This function may give an incorrect output if the count matrix
contains negative values: this is not checked for, due to speed
considerations.
"""
# Check that `X` is present
X = self._X
if X is None:
error_message = 'X is None, so marker gene finding is not possible'
raise ValueError(error_message)
# Check that `self` is QCed
if not self._uns['QCed']:
error_message = (
"uns['QCed'] is False; did you forget to run qc() before "
"find_markers()? Set uns['QCed'] = True or run skip_qc() to "
"bypass this check.")
raise ValueError(error_message)
# Get the QC column, if not `None`
if QC_column is not None:
QC_column = self._get_column(
'obs', QC_column, 'QC_column', pl.Boolean,
allow_missing=QC_column == 'passed_QC')
# Get the cell-type column
original_cell_type_column = cell_type_column
cell_type_column = \
self._get_column('obs', cell_type_column, 'cell_type_column',
(pl.String, pl.Enum, pl.Categorical, 'integer'),
QC_column=QC_column)
cell_type_column_name = cell_type_column.name
# Check that `cell_types` and `excluded_cell_types` are not both
# specified. If `cell_types` is specified, check it contains only cell
# type names present in `cell_type_column`, then set non-matching cell
# types to `null` so that they are treated as a single background cell
# type. If `excluded_cell_types` is specified, do the opposite.
cell_type_column = SingleCell._process_cell_types(
cell_types, excluded_cell_types, cell_type_column)
# Check that `num_threads` is a positive integer, -1 or `None`; if
# `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
num_threads = self._process_num_threads(num_threads)
# Check that `min_detection_rate` and `min_fold_change` are numeric and
# have the correct ranges: 0 < min_detection_rate <= 1,
# min_fold_change > 1
check_type(min_detection_rate, 'min_detection_rate',
(int, float), 'a positive number less than or equal to 1')
check_bounds(min_detection_rate, 'min_detection_rate', 0, 1,
left_open=True)
check_type(min_fold_change, 'min_fold_change', (int, float),
'a number greater than 1')
check_bounds(min_fold_change, 'min_fold_change', 1, left_open=True)
# Check that `all_genes` is Boolean
check_type(all_genes, 'all_genes', bool, 'Boolean')
# Get the indices of the cells of each cell type, ignoring cells
# failing QC when `QC_column` is present in `obs`
groups = (pl.LazyFrame((cell_type_column,))
.with_columns(
_SingleCell_group_indices=pl.int_range(pl.len(),
dtype=pl.UInt32))
if QC_column is None else
pl.LazyFrame((cell_type_column, QC_column))
.with_columns(
_SingleCell_group_indices=pl.int_range(pl.len(),
dtype=pl.UInt32))
.filter(QC_column.name))\
.group_by(cell_type_column_name, maintain_order=True)\
.agg('_SingleCell_group_indices', _SingleCell_num_cells=pl.len())\
.sort(cell_type_column_name)\
.collect()
# Check that `cell_type_column` contains at least two cell types
num_cell_types = len(groups)
if num_cell_types == 1:
cell_type_column_description = \
SingleCell._describe_column('cell_type_column',
original_cell_type_column)
error_message = (
f'{cell_type_column_description} only contains one unique '
f'value')
raise ValueError(error_message)
# If `cell_types` is not `None`, reorder `groups` to be in the same
# order as `cell_types`
if cell_types is not None:
groups = groups.sort(pl.first().cast(pl.Enum(cell_types)),
nulls_last=True)
# Get a cell-type-by-gene matrix of the number of cells of each type
# with non-zero expression of each gene, i.e. the gene's detection
# count in that cell type
num_genes = X.shape[1]
detection_count = \
np.empty((num_cell_types, num_genes), dtype=np.uint32)
if isinstance(X, csr_array):
group_indices = \
groups['_SingleCell_group_indices'].explode().to_numpy()
group_ends = \
groups['_SingleCell_num_cells'].cum_sum().to_numpy()
groupby_getnnz_csr(indices=X.indices, indptr=X.indptr,
group_indices=group_indices,
group_ends=group_ends, nnz=detection_count,
num_threads=num_threads)
else:
group_map = pl.int_range(X.shape[0], dtype=pl.UInt32, eager=True)\
.to_frame('_SingleCell_group_indices')\
.join(groups
.select('_SingleCell_group_indices',
_SingleCell_index=pl.int_range(pl.len(),
dtype=pl.Int32))
.explode('_SingleCell_group_indices'),
on='_SingleCell_group_indices', how='left')\
['_SingleCell_index']
has_missing = group_map.null_count() > 0
if has_missing:
group_map = group_map.fill_null(-1)
group_map = group_map.to_numpy()
groupby_getnnz_csc(indices=X.indices, indptr=X.indptr,
group_map=group_map, has_missing=has_missing,
nnz=detection_count, num_threads=num_threads)
# For each cell type, calculate the detection rate and the fold change
# of the detection rate. Also, initialize the candidate set of points
# on the Pareto front to those with detection rate of at least
# `min_detection_rate` and fold change of at least `min_fold_change`.
total_detection_count = detection_count.sum(axis=0, dtype=np.uint32)
num_cells_per_cell_type = groups['_SingleCell_num_cells'].to_numpy()
total_num_cells = num_cells_per_cell_type.sum()
detection_rate = np.empty((num_cell_types, num_genes),
dtype=np.float32)
fold_change = np.empty((num_cell_types, num_genes), dtype=np.float32)
is_pareto = np.empty((num_cell_types, num_genes), dtype=bool)
get_detection_rate_and_fold_change_and_pareto_candidates(
detection_count=detection_count,
total_detection_count=total_detection_count,
num_cells_per_cell_type=num_cells_per_cell_type,
total_num_cells=total_num_cells,
min_detection_rate=min_detection_rate,
min_fold_change=min_fold_change, detection_rate=detection_rate,
fold_change=fold_change, is_pareto=is_pareto,
num_threads=num_threads)
# If `pareto=True`, find the genes on the Pareto front of the two
# metrics (tie-breaking by higher gene index); these are the marker
# genes. If `pareto=False`, all genes passing the `min_detection_rate`
# and `min_fold_change` thresholds are marker genes, so no additional
# work needed.
if pareto:
pareto_front(detection_rate=detection_rate,
fold_change=fold_change, is_pareto=is_pareto,
num_threads=num_threads)
# If `cell_types` or `excluded_cell_types` is not `None`, filter out
# the `null` background cell type we introduced earlier
if cell_types is not None or excluded_cell_types is not None:
groups = groups[:-1]
is_pareto = is_pareto[:-1]
detection_rate = detection_rate[:-1]
fold_change = fold_change[:-1]
num_cell_types -= 1
# Return a DataFrame of the selected marker genes, or all genes if
# `all_genes=True`
cell_types = groups[cell_type_column_name].rename('cell_type')
genes = self._var[:, 0].rename('gene')
if all_genes:
cell_types = pl.select(pl.lit(cell_types).repeat_by(num_genes))\
.explode('cell_type')\
.to_series()
genes = pl.concat([genes] * num_cell_types)
return pl.DataFrame((
cell_types, genes,
pl.Series('marker', is_pareto.ravel()),
pl.Series('detection_rate', detection_rate.ravel()),
pl.Series('fold_change', fold_change.ravel())))
else:
cell_type_indices, gene_indices = is_pareto.nonzero()
return pl.DataFrame((
cell_types[cell_type_indices], genes[gene_indices],
pl.Series('detection_rate', detection_rate[
cell_type_indices, gene_indices].ravel()),
pl.Series('fold_change', fold_change[
cell_type_indices, gene_indices].ravel())))\
.select(pl.all().sort_by('fold_change', descending=True)
.over('cell_type'))
[docs]
def plot_markers(self,
genes: str | Iterable[str],
cell_type_column: SingleCellColumn,
filename: str | Path | None = None,
/,
*,
cells_to_plot_column: SingleCellColumn |
None = 'passed_QC',
color: Literal['expression',
'fold_change'] = 'expression',
cell_types: str | Iterable[str] | int | Iterable[int] |
None = None,
excluded_cell_types: str | Iterable[str] | int |
Iterable[int] | None = None,
alphabetical_cell_types: bool = True,
figure_kwargs: dict[str, Any] | None = None,
colormap: str | 'Colormap' | None = None,
infinity_color: Color = 'limegreen',
NaN_color: Color = 'lightgray',
colorbar: bool = True,
colorbar_kwargs: dict[str, Any] | None = None,
swap_axes: bool = False,
scatter_kwargs: dict[str, Any] | None = None,
legend_kwargs: dict[str, Any] | None = None,
title: str | None = None,
title_kwargs: dict[str, Any] | None = None,
xlabel: str | None = None,
xlabel_kwargs: dict[str, Any] | None = None,
ylabel: str | None = None,
ylabel_kwargs: dict[str, Any] | None = None,
despine: bool = True,
savefig_kwargs: dict[str, Any] | None = None,
num_threads: int | np.integer | None = None) -> None:
"""
Make a dot plot of a set of marker genes of interest across cell types.
The size of a gene's dot represents its "detection rate" in that cell
type: the fraction of cells of that type where the gene has non-zero
count. By default (`color='expression'`), the color of a gene's dot
represents its expression.
When `color='fold_change'`, the color instead represents the gene's
fold change in detection rate between cells of that cell type and cells
of every other cell type; this is one of two metrics (along with the
detection rate) that are used to select marker genes in
`find_markers()`. `color='fold_change'` gives the same result whether
it is run before or after normalization, since it only depends on the
pattern of non-zero entries in the data, which is unaffected by
normalization.
Unlike the other plotting functions, this is a figure-level rather than
an axis-level function, and does not take an `axis` argument.
Args:
genes: a list of genes to plot: for instance, marker genes found by
`find_markers()`, or marker genes from the literature
cell_type_column: a String, Enum, Categorical, or integer column of
`obs` containing cell-type labels. Can be a
column name, a polars expression, a polars
Series, a 1D NumPy array, or a function that
takes in this SingleCell dataset and returns a
polars Series or 1D NumPy array.
filename: the file to save to. If `None`, generate the plot but do
not save it, which allows it to be shown interactively or
modified further before saving.
cells_to_plot_column: an optional Boolean column of `obs`
indicating which cells to plot. Can be a
column name, a polars expression, a polars
Series, a 1D NumPy array, or a function that
takes in this SingleCell dataset and returns
a polars Series or 1D NumPy array. Set to
`None` to plot all cells.
color: whether the color of a gene's dot represents its expression
(`color='expression'`, the default) or its fold change in
detection rate between cells of that cell type and cells of
every other cell type (`color='fold_change'`). When
`color='fold_change'`, the colorbar ticks show the raw fold
changes, but on a log scale, so that a fold change of 4
is twice as red as a fold change of 2, and a fold change of
0.25 is twice as blue as a fold change of 0.5.
cell_types: one or more cell types to plot; by default, plots all
cell types in `cell_type_column`. Can also be used to
change the order in which cell types are displayed,
even if plotting all cell types. Mutually exclusive
with `excluded_cell_types`.
excluded_cell_types: one or more cell types to exclude from the
plot. Mutually exclusive with `cell_types`.
alphabetical_cell_types: whether to force the cell types to be
listed in alphabetical order, even when
`cell_type_column` is a Categorical or
Enum column where the categories are in
non-alphabetical order. If
`alphabetical_cell_types=False`, cell
types will appear in the order specified
by the Categorical or Enum column. Has no
effect when `cell_type_column` is a String
column, since cell types will always be
plotted in alphabetical order by default.
`alphabetical_cell_types=False` can only
be specified when `cell_types` is `None`,
since when `cell_types` is specified, it
defines the order of the cell types
regardless.
figure_kwargs: a dictionary of keyword arguments to be passed to
`plt.figure`, such as:
- `figsize`: a two-element sequence of the width and
height of the figure in inches. The default is a
complicated formula based on the number of genes
and cell types being plotted, unlike Matplotlib's
default of `[6.4, 4.8]`.
- `layout`: the layout mechanism used by Matplotlib
to avoid overlapping plot elements. Defaults to
`'constrained'`, instead of Matplotlib's default
of `None`.
colormap: a string or Colormap object indicating the Matplotlib
colormap to use in `ax.scatter()` for representing
expression values or (if `color='fold_change'`) fold
changes. Defaults to `'Reds'` for expression and
`'RdBu_r'` for fold changes.
infinity_color: when `color='fold_change'`, the color used to plot
infinite log-fold changes (i.e. where a gene is
only expressed in the cell type where it is a
marker). Can be any valid Matplotlib color, like a
hex string (e.g. `'#FF0000'`), a named color (e.g.
'red'), a 3- or 4-element RGB/RGBA tuple of
integers 0-255 or floats 0-1, or a single float 0-1
for grayscale.
NaN_color: when `color='fold_change'`, the color used to plot NaN
log-fold changes (i.e. where a gene is not expressed in
any cell type). Can be any valid Matplotlib color, like
a hex string (e.g. `'#FF0000'`), a named color (e.g.
'red'), a 3- or 4-element RGB/RGBA tuple of integers
0-255 or floats 0-1, or a single float 0-1 for
grayscale.
colorbar: whether to add a colorbar
colorbar_kwargs: a dictionary of keyword arguments to be passed to
`plt.colorbar()`, such as:
- `location`: `'left'`, `'right'`, `'top'`, or
`'bottom'`
- `orientation`: `'vertical'` or `'horizontal'`
- `fraction`: the fraction of the axes to
allocate to the colorbar. Defaults to 0.15.
- `shrink`: the fraction to multiply the size of
the colorbar by. Defaults to 0.5, instead of
Matplotlib's default of 1.
- `aspect`: the ratio of the colorbar's long to
short dimensions. Defaults to 20.
- `pad`: the fraction of the axes between the
colorbar and the rest of the figure. Defaults to
0.01, instead of Matplotlib's default of 0.05 if
vertical and 0.15 if horizontal.
Can only be specified when `colorbar=True`.
swap_axes: if `True`, plot genes on the y-axis and cell types on
the x-axis, instead of the other way around
scatter_kwargs: a dictionary of keyword arguments to be passed to
`ax.scatter()`, such as:
- `rasterized`: whether to convert the scatter plot
points to a raster (bitmap) image when saving to
a vector format like PDF. Defaults to `False`.
- `marker`: the shape to use for plotting each dot
- `norm`, `vmin`, and `vmax`: control how the
`colormap` maps expression values or fold changes to
colors. If neither
`vmin` nor `vmax` are specified, the default
behavior depends on what is being plotted. When
`color='expression'` and all expression values
are positive, `vmin` will be set to 0 so that the
color scale includes 0. When
`color='fold_change'`, `vmin` will be set to `-M`
and `vmax` to `M` where `M` is the magnitude of
the largest fold change, so that the colors
for positive and negative fold changes are
symmetrical.
- `alpha`: the transparency of each point
- `linewidths` and `edgecolors`: the width and
color of the borders around each marker. These
are absent by default (`linewidths=0`,
`edgecolors=(0, 0, 0, 0)`), unlike Matplotlib's
default. Both arguments can be either single
values or sequences.
- `zorder`: the order in which the dots are
plotted, with higher values appearing on top of
lower ones.
Specifying `s`, `c`/`color`, or `cmap` will raise
an error, since the size and color of each point
are set automatically, and `cmap` conflicts with
the `colormap` argument.
legend_kwargs: a dictionary of keyword arguments to be passed to
`ax.legend()` to modify the legend, such as:
- `loc`, `bbox_to_anchor`, and `bbox_transform` to
set its location. The legend will be placed in its
own axis in the top right of the plot, and by
default, `loc` is set to `'center'`.
- `ncols` to set its number of columns
- `prop`, `fontsize`, and `labelcolor` to set its
font properties
- `facecolor` and `framealpha` to set its background
color and transparency
- `frameon=True` or `edgecolor` to add or color its
border. `frameon` defaults to `False`, instead of
Matplotlib's default of `True`.
- `title` to modify the legend title. Defaults to
`'Detection rate'`.
title: the title of the plot, or `None` to not add a title
title_kwargs: a dictionary of keyword arguments to be passed to
`ax.set_title()` to control text properties, such as
`color` and `size`. Can only be specified when
`title` is not `None`.
xlabel: the x-axis label, or `None` to not label the x-axis
xlabel_kwargs: a dictionary of keyword arguments to be passed to
`ax.set_xlabel()` to control text properties, such
as `color` and `size`. Can only be specified when
`xlabel` is not `None`.
ylabel: the y-axis label, or `None` to not label the y-axis
ylabel_kwargs: a dictionary of keyword arguments to be passed to
`ax.set_ylabel()` to control text properties, such
as `color` and `size`. Can only be specified when
`ylabel` is not `None`.
despine: whether to remove the top and right spines (borders of the
plot area) from the plot
savefig_kwargs: a dictionary of keyword arguments to be passed to
`plt.savefig()`, such as:
- `dpi`: defaults to 300 instead of Matplotlib's
default of 150
- `bbox_inches`: the bounding box of the portion of
the figure to save; defaults to `'tight'` (crop
out any blank borders) instead of Matplotlib's
default of `None` (save the entire figure)
- `pad_inches`: the number of inches of padding to
add on each of the four sides of the figure when
saving. Defaults to `'layout'` (use the padding
from the constrained layout engine), instead of
Matplotlib's default of 0.1.
- `transparent`: whether to save with a transparent
background; defaults to `True` if saving to a PDF
(i.e. when `filename` ends with `'.pdf'`) and
`False` otherwise, instead of Matplotlib's
default of always being `False`.
Can only be specified when `filename` is specified.
num_threads: the number of threads to use when tabulating each
gene's detection rate and fold change. Set
`num_threads=-1` to use all available cores, as
determined by `os.cpu_count()`, or leave unset to use
`self.num_threads` cores.
Note:
This function may give an incorrect output if the count matrix
contains explicit zeros (i.e. if `(sc.X.data == 0).any()`): this is
not checked for, due to speed considerations. In the unlikely event
that your dataset contains explicit zeros, remove them by running
`sc.X.eliminate_zeros()` (an in-place operation) first.
Note:
This function may give an incorrect output if the count matrix
contains negative values: this is not checked for, due to speed
considerations.
"""
# Check that `X` is present
X = self._X
if X is None:
error_message = \
'X is None, so marker gene plotting is not possible'
raise ValueError(error_message)
# Import matplotlib
signal.signal(signal.SIGINT, signal.SIG_IGN)
try:
import matplotlib.pyplot as plt
finally:
signal.signal(signal.SIGINT, signal.default_int_handler)
# Check that `self` is QCed
if not self._uns['QCed']:
error_message = (
"uns['QCed'] is False; did you forget to run qc() before "
"plot_markers()? Set uns['QCed'] = True or run skip_qc() to "
"bypass this check.")
raise ValueError(error_message)
# Get `genes` as a polars Series of the same dtype as `var_names`;
# uniquify; make sure all its entries are present in `var_names`
genes = to_tuple_checked(genes, 'genes', str, 'strings')
genes = pl.Series(genes).unique(maintain_order=True)
var_names = self._var[:, 0]
if not genes.is_in(var_names).all():
if not genes.is_in(var_names).any():
error_message = \
'none of the specified genes were found in var_names'
raise ValueError(error_message)
else:
for gene in genes:
if gene not in var_names:
error_message = (
f'one of the specified genes, {gene!r}, was not '
f'found in var_names')
raise ValueError(error_message)
if var_names.dtype != pl.String:
genes = genes.cast(var_names.dtype)
# Get `cells_to_plot_column`, if not `None`
if cells_to_plot_column is not None:
cells_to_plot_column = self._get_column(
'obs', cells_to_plot_column, 'cells_to_plot_column',
pl.Boolean,
allow_missing=isinstance(cells_to_plot_column, str) and
cells_to_plot_column == 'passed_QC')
# Check that `color` is `'expression'` or `'fold_change'`
check_type(color, 'color', str, 'a string')
if color != 'expression' and color != 'fold_change':
error_message = "color must be 'expression' or 'fold_change'"
raise ValueError(error_message)
# Get the cell-type column
original_cell_type_column = cell_type_column
cell_type_column = \
self._get_column('obs', cell_type_column, 'cell_type_column',
(pl.String, pl.Enum, pl.Categorical, 'integer'),
QC_column=cells_to_plot_column)
cell_type_column_name = cell_type_column.name
# Check that `cell_types` and `excluded_cell_types` are not both
# specified. If `cell_types` is specified, check it contains only cell
# type names present in `cell_type_column`, then set non-matching cell
# types to `null` so that they are treated as a single background cell
# type. If `excluded_cell_types` is specified, do the opposite.
cell_type_column = SingleCell._process_cell_types(
cell_types, excluded_cell_types, cell_type_column)
# Check that `alphabetical_cell_types` is Boolean, and that if `False`,
# `cell_types` is `None`
check_type(alphabetical_cell_types, 'alphabetical_cell_types', bool,
'Boolean')
if not alphabetical_cell_types and cell_types is not None:
error_message = (
'alphabetical_cell_types=False can only be specified when '
'cell_types is None')
raise ValueError(error_message)
# If `filename` was specified, check that it is a string or
# `pathlib.Path` and that its base directory exists; if `filename` is
# `None`, make sure `savefig_kwargs` is also `None`
if filename is not None:
check_type(filename, 'filename', (str, Path),
'a string or pathlib.Path')
directory = os.path.dirname(filename)
if directory and not os.path.isdir(directory):
error_message = (
f'{filename} refers to a file in the directory '
f'{directory!r}, but this directory does not exist')
raise NotADirectoryError(error_message)
filename = str(filename)
elif savefig_kwargs is not None:
error_message = 'savefig_kwargs must be None when filename is None'
raise ValueError(error_message)
# Check that `colormap` is a string in `plt.colormaps`, a Colormap
# object, or `None`. If `None`, set to `'Reds'` for expression and
# `'RdBu_r'` for fold changes. If a string, map to a Colormap object
# via the `plt.colormaps` map.
if colormap is None:
colormap = 'RdBu_r' if color == 'fold_change' else 'Reds'
else:
check_type(colormap, 'colormap',
(str, plt.matplotlib.colors.Colormap),
'a string or matplotlib Colormap object')
if isinstance(colormap, str):
colormap = plt.colormaps[colormap]
# If `color='fold_change'`, check that `infinity_color` and `NaN_color`
# are Matplotlib colors, and convert them to hex. Otherwise, check that
# they have their default values, since they are unused.
if color == 'fold_change':
if not plt.matplotlib.colors.is_color_like(infinity_color):
error_message = \
'infinity_color is not a valid Matplotlib color'
raise ValueError(error_message)
infinity_color = plt.matplotlib.colors.to_hex(infinity_color)
if not plt.matplotlib.colors.is_color_like(NaN_color):
error_message = 'NaN_color is not a valid Matplotlib color'
raise ValueError(error_message)
NaN_color = plt.matplotlib.colors.to_hex(NaN_color)
else:
if not isinstance(infinity_color, str) or \
infinity_color != 'limegreen':
error_message = (
"infinity_color cannot be specified unless "
"color='fold_change'")
raise ValueError(error_message)
if not isinstance(NaN_color, str) or NaN_color != 'lightgray':
error_message = \
"NaN_color cannot be specified unless color='fold_change'"
raise ValueError(error_message)
# Check that `colorbar` is Boolean
check_type(colorbar, 'colorbar', bool, 'Boolean')
# If `colorbar=False`, check that `colorbar_kwargs` is None
if not colorbar and colorbar_kwargs is not None:
error_message = 'colorbar_kwargs must be None when colorbar=False'
raise ValueError(error_message)
# Check that `swap_axes` and `despine` are Boolean
check_type(swap_axes, 'swap_axes', bool, 'Boolean')
check_type(despine, 'despine', bool, 'Boolean')
# Check that `title` is a string or `None`; if `None`, check that
# `title_kwargs` is `None` as well. Ditto for `xlabel` and `ylabel`.
for arg, arg_name, arg_kwargs in (
(title, 'title', title_kwargs),
(xlabel, 'xlabel', xlabel_kwargs),
(ylabel, 'ylabel', ylabel_kwargs)):
if arg is not None:
check_type(arg, arg_name, str, 'a string')
elif arg_kwargs is not None:
error_message = \
f'{arg_name}_kwargs must be None when {arg_name} is None'
raise ValueError(error_message)
# For each of the kwargs arguments, if the argument was specified,
# check that it is a dictionary and that all its keys are strings.
for kwargs, kwargs_name in ((figure_kwargs, 'figure_kwargs'),
(colorbar_kwargs, 'colorbar_kwargs'),
(scatter_kwargs, 'scatter_kwargs'),
(legend_kwargs, 'legend_kwargs'),
(title_kwargs, 'title_kwargs'),
(xlabel_kwargs, 'xlabel_kwargs'),
(ylabel_kwargs, 'ylabel_kwargs'),
(savefig_kwargs, 'savefig_kwargs')):
if kwargs is not None:
check_type(kwargs, kwargs_name, dict, 'a dictionary')
for key in kwargs:
if not isinstance(key, str):
error_message = (
f'all keys of {kwargs_name} must be strings, but '
f'it contains a key of type '
f'{type(key).__name__!r}')
raise TypeError(error_message)
# Check that `num_threads` is a positive integer, -1 or `None`; if
# `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
num_threads = self._process_num_threads(num_threads)
# Get the indices of the cells of each cell type, ignoring cells
# excluded by `cells_to_plot_column` when it is not `None`. Also get the
# number of cells of each type.
groups = (pl.LazyFrame((cell_type_column,))
.with_columns(
_SingleCell_group_indices=pl.int_range(pl.len(),
dtype=pl.UInt32))
if cells_to_plot_column is None else
pl.LazyFrame((cell_type_column, cells_to_plot_column))
.with_columns(
_SingleCell_group_indices=pl.int_range(pl.len(),
dtype=pl.UInt32))
.filter(cells_to_plot_column.name))\
.group_by(cell_type_column_name, maintain_order=True)\
.agg('_SingleCell_group_indices', _SingleCell_num_cells=pl.len())\
.sort(cell_type_column_name, nulls_last=True)\
.collect()
# Check that `cell_type_column` contains at least two cell types
num_cell_types = len(groups)
if num_cell_types == 1:
cell_type_column_description = \
SingleCell._describe_column('cell_type_column',
original_cell_type_column)
error_message = (
f'{cell_type_column_description} only contains one unique '
f'value')
raise ValueError(error_message)
# If `cell_types` is not `None`, reorder `groups` to be in the same
# order as `cell_types`. Otherwise, get the list of cell types from
# `groups`, sorting alphabetically if `alphabetical_cell_types` and
# `cell_type_column` is Enum or Categorical.
if cell_types is not None:
groups = groups.sort(pl.first().cast(pl.Enum(cell_types)),
nulls_last=True)
else:
if alphabetical_cell_types and (
cell_type_column.dtype == pl.Enum or
cell_type_column.dtype == pl.Categorical):
groups = groups\
.cast({cell_type_column_name: pl.String})\
.sort(cell_type_column_name)
cell_types = groups[cell_type_column_name]
# Get a cell-type-by-gene matrix of the number of cells of each type
# with non-zero expression of each gene in `genes`, i.e. the gene's
# detection count in that cell type. If `color='expression'`, also
# get a second cell-type-by-gene matrix of the total expression of each
# gene in each cell type; we will normalize to get the mean later.
num_genes = len(genes)
detection_count = \
np.empty((num_cell_types, num_genes), dtype=np.uint32)
if isinstance(X, csr_array):
group_indices = \
groups['_SingleCell_group_indices'].explode().to_numpy()
group_ends = \
groups['_SingleCell_num_cells'].cum_sum().to_numpy()
# Get an array mapping each gene in `var_names` to its position in
# `genes` (-1 if missing from `genes`)
gene_map = var_names\
.to_frame()\
.join(genes
.to_frame(var_names.name)
.with_columns(_SingleCell_index=pl.int_range(
pl.len(), dtype=pl.Int32)),
on=var_names.name, how='left')\
.select('_SingleCell_index')\
.to_series()
gene_map = gene_map.fill_null(-1).to_numpy()
if color == 'fold_change':
groupby_getnnz_csr_for_gene_subset(
indices=X.indices, indptr=X.indptr,
group_indices=group_indices, group_ends=group_ends,
gene_map=gene_map, nnz=detection_count,
num_threads=num_threads)
else:
total_expression = np.empty((num_cell_types, num_genes))
groupby_getnnz_and_total_csr_for_gene_subset(
data=X.data, indices=X.indices, indptr=X.indptr,
group_indices=group_indices, group_ends=group_ends,
gene_map=gene_map, nnz=detection_count,
total=total_expression, num_threads=num_threads)
else:
group_map = pl.int_range(X.shape[0], dtype=pl.UInt32, eager=True)\
.to_frame('_SingleCell_group_indices')\
.join(groups
.select('_SingleCell_group_indices',
_SingleCell_index=pl.int_range(pl.len(),
dtype=pl.Int32))
.explode('_SingleCell_group_indices'),
on='_SingleCell_group_indices', how='left')\
['_SingleCell_index']
has_missing = group_map.null_count() > 0
if has_missing:
group_map = group_map.fill_null(-1)
group_map = group_map.to_numpy()
# Get an array mapping each gene in `genes` to its position in
# `var_names`
gene_map = genes\
.to_frame(var_names.name)\
.join(var_names
.to_frame()
.with_columns(_SingleCell_index=pl.int_range(
pl.len(), dtype=pl.UInt32)),
on=var_names.name, how='left')\
.select('_SingleCell_index')\
.to_series()\
.to_numpy()
if color == 'fold_change':
groupby_getnnz_csc_for_gene_subset(
indices=X.indices, indptr=X.indptr, group_map=group_map,
gene_map=gene_map, has_missing=has_missing,
nnz=detection_count, num_threads=num_threads)
else:
total_expression = np.empty((num_cell_types, num_genes))
groupby_getnnz_and_total_csc_for_gene_subset(
data=X.data, indices=X.indices, indptr=X.indptr,
group_map=group_map, gene_map=gene_map,
has_missing=has_missing, nnz=detection_count,
total=total_expression, num_threads=num_threads)
# For each cell type, calculate the detection rate and (if
# `color == 'fold_change'`) the fold change of the detection rate
total_detection_count = detection_count.sum(axis=0, dtype=np.uint32)
detection_rate = np.empty((num_cell_types, num_genes),
dtype=np.float32)
num_cells_per_cell_type = groups['_SingleCell_num_cells'].to_numpy()
total_num_cells = num_cells_per_cell_type.sum()
if color == 'fold_change':
fold_change = np.empty((num_cell_types, num_genes),
dtype=np.float32)
get_detection_rate_and_fold_change(
detection_count=detection_count,
total_detection_count=total_detection_count,
num_cells_per_cell_type=num_cells_per_cell_type,
total_num_cells=total_num_cells, detection_rate=detection_rate,
fold_change=fold_change, num_threads=num_threads)
else:
get_detection_rate(
detection_count=detection_count,
total_detection_count=total_detection_count,
num_cells_per_cell_type=num_cells_per_cell_type,
total_num_cells=total_num_cells, detection_rate=detection_rate,
num_threads=num_threads)
# If plotting expression, convert total expression to mean expression
if color == 'expression':
mean_expression = \
total_expression / num_cells_per_cell_type[:, None]
# If `cell_types` or `excluded_cell_types` were specified and we
# introduced a `null` background cell type as a result, remove it now
if groups[cell_type_column_name][-1] is None:
detection_rate = detection_rate[:-1]
if color == 'fold_change':
fold_change = fold_change[:-1]
else:
mean_expression = mean_expression[:-1]
if excluded_cell_types is not None:
cell_types = cell_types.drop_nulls()
num_cell_types -= 1
# If `swap_axes=True`, swap cell types and genes
if swap_axes:
cell_types, genes = genes, cell_types
num_cell_types, num_genes = num_genes, num_cell_types
detection_rate = detection_rate.T
if color == 'fold_change':
fold_change = fold_change.T
else:
mean_expression = mean_expression.T
# Calculate the range of the legend, and the multiplier to multiply
# each point's size by
max_detection_rate = detection_rate.max()
interval = 0.2 if max_detection_rate > 0.5 else \
0.1 if max_detection_rate > 0.2 else 0.05
max_detection_rate = \
np.ceil(max_detection_rate / interval) * interval
legend_point_sizes = \
np.arange(interval, max_detection_rate + interval / 2, interval)
point_size_multiplier = 180 / max_detection_rate
try:
# Make the figure, including separate portions on the left for the
# legend and colorbar (if `colorbar=True`)
default_figure_kwargs = dict(layout='constrained')
width_ratio = max(4, 0.2 * num_genes)
if figure_kwargs is None or 'figsize' not in figure_kwargs:
gene_multiplier = 0.4 \
if num_genes < 5 else 0.05 * (num_genes - 5) + 0.4 \
if num_genes < 10 else 0.04 * (num_genes - 10) + 0.65 \
if num_genes < 20 else 0.03 * (num_genes - 20) + 1.05
if colorbar:
width = 6.4 / (1 + width_ratio) + \
6.4 * width_ratio / (1 + width_ratio) * \
gene_multiplier
else:
width = 6.4 * gene_multiplier
height = max(4.8, 1 + 3.8 * num_cell_types / 20)
default_figure_kwargs['figsize'] = width, height
figure_kwargs = default_figure_kwargs | figure_kwargs \
if figure_kwargs is not None else default_figure_kwargs
fig = plt.figure(**figure_kwargs)
if colorbar:
gs = fig.add_gridspec(2, 2, width_ratios=[width_ratio, 1],
height_ratios=[1, 1])
else:
gs = fig.add_gridspec(2, 1, height_ratios=[1, 1])
# Plot the circles; override the defaults for certain keys of
# `scatter_kwargs`. If neither `vmin` nor `vmax` are specified,
# set `vmin=0` when `color='expression'` and all expression values
# are positive to ensure the color scale includes 0, and specify
# `vmin` and `vmax` to be centered at 0 when `color='fold_change'`
# so that the colors are symmetrical.
ax_main = fig.add_subplot(gs[:, 0]) # the main plot spans all rows
point_size = detection_rate * point_size_multiplier
x, y = np.meshgrid(range(len(genes)), range(len(cell_types)))
default_scatter_kwargs = \
dict(linewidths=0, edgecolors=(0, 0, 0, 0))
scatter_kwargs = default_scatter_kwargs | scatter_kwargs \
if scatter_kwargs is not None else default_scatter_kwargs
if color == 'fold_change':
with np.errstate(divide='ignore'):
c = np.log2(fold_change)
c_finite = c.copy()
c_finite[~np.isfinite(c_finite)] = np.nan
if 'vmin' not in scatter_kwargs and \
'vmax' not in scatter_kwargs:
vmax = np.nanmax(np.abs(c_finite))
vmin = -vmax
scatter_kwargs['vmax'] = vmax
scatter_kwargs['vmin'] = vmin
else:
c_finite = mean_expression
if 'vmin' not in scatter_kwargs and \
'vmax' not in scatter_kwargs and \
mean_expression.min() > 0:
scatter_kwargs['vmin'] = 0
scatter = ax_main.scatter(x.ravel(), y.ravel(),
s=point_size.ravel(), c=c_finite.ravel(),
cmap=colormap, **scatter_kwargs)
ax_main.set_aspect('equal')
# If `color='fold_change'`, plot infinite and NaN log-fold changes
# in `infinity_color` and `NaN_color`, respectively. This has to be
# done separately since we are passing in actual colors, not
# numbers that are being mapped to colors.
if color == 'fold_change':
del scatter_kwargs['vmin']
del scatter_kwargs['vmax']
infinite_indices = np.where(np.isinf(c))
if infinite_indices[0].size > 0:
ax_main.scatter(infinite_indices[1], infinite_indices[0],
s=point_size[infinite_indices],
c=infinity_color, **scatter_kwargs)
NaN_indices = np.where(np.isnan(c))
if NaN_indices[0].size > 0:
ax_main.scatter(NaN_indices[1], NaN_indices[0],
s=point_size[NaN_indices], c=NaN_color,
**scatter_kwargs)
# Set x and y limits
padding = 0.6
ax_main.set_xlim((-padding, len(genes) - 1 + padding))
ax_main.set_ylim((-padding, len(cell_types) - 1 + padding))
# Invert the y-axis (must be done after setting x and y limits)
ax_main.invert_yaxis()
# Add x and y ticks and tick labels
ax_main.set_xticks(range(len(genes)), genes, rotation=90)
ax_main.set_yticks(range(len(cell_types)), cell_types)
if xlabel is not None:
if xlabel_kwargs is None:
ax_main.set_xlabel(xlabel)
else:
ax_main.set_xlabel(xlabel, **xlabel_kwargs)
if ylabel is not None:
if ylabel_kwargs is None:
ax_main.set_ylabel(ylabel)
else:
ax_main.set_ylabel(ylabel, **ylabel_kwargs)
# Add a legend for detection rate; markers should be at intervals
# of `X`% (`X`%, `2X`%, `3X`%, ...) up to the maximum detection
# rate (rounded up to the nearest `X`%). Override the defaults for
# certain keys of `legend_kwargs`.
ax_legend = fig.add_subplot(gs[0, 1])
ax_legend.axis('off')
legend_elements = [
plt.Line2D([0], [0], label=f'{100 * size:.0f}%',
markersize=np.sqrt(size * point_size_multiplier),
marker='o', linestyle='None',
markerfacecolor='black', markeredgecolor='None')
for size in legend_point_sizes]
default_legend_kwargs = dict(title='Detection rate', loc='center',
frameon=False)
legend_kwargs = default_legend_kwargs | legend_kwargs \
if legend_kwargs is not None else default_legend_kwargs
if legend_elements:
ax_legend.legend(handles=legend_elements, **legend_kwargs)
# Add a colorbar for expression, or if `color='fold_change'`, fold
# change with labels at powers of 2.
if colorbar:
default_colorbar_kwargs = dict(shrink=0.5, pad=0.01)
colorbar_kwargs = default_colorbar_kwargs | colorbar_kwargs \
if colorbar_kwargs is not None else \
default_colorbar_kwargs
ax_colorbar = fig.add_subplot(gs[1, 1])
cbar = plt.colorbar(scatter, cax=ax_colorbar,
**colorbar_kwargs)
cbar.outline.set_visible(False)
cbar.ax.set_box_aspect(12)
cbar.ax.set_title('Fold change of\ndetection rate'
if color == 'fold_change' else
'Mean\nexpression', size='medium')
if color == 'fold_change':
cbar.ax.yaxis.set_major_locator(
plt.MaxNLocator(integer=True))
cbar.ax.yaxis.set_major_formatter(plt.FuncFormatter(
lambda x, pos: f'{2 ** x:.4f}'.rstrip(
'0').rstrip('.')))
# Add the title
if title is not None:
if title_kwargs is None:
ax_main.set_title(title)
else:
ax_main.set_title(title, **title_kwargs)
# Despine, if specified
if despine:
spines = ax_main.spines
spines['top'].set_visible(False)
spines['right'].set_visible(False)
# Save, if `filename` is not `None`; override the defaults for
# certain keys of `savefig_kwargs`
if filename is not None:
default_savefig_kwargs = \
dict(dpi=300, bbox_inches='tight', pad_inches='layout',
transparent=filename is not None and
filename.endswith('.pdf'))
savefig_kwargs = default_savefig_kwargs | savefig_kwargs \
if savefig_kwargs is not None else default_savefig_kwargs
with warnings.catch_warnings():
warnings.simplefilter('ignore', UserWarning)
plt.savefig(filename, **savefig_kwargs)
plt.close()
except:
# Since we made a new figure, make sure to close it if there's an
# exception (but not if there was no error and `filename` is
# `None`, in case the user wants to modify it further before
# saving)
plt.close()
raise
[docs]
def pacmap(self,
           *,
           QC_column: SingleCellColumn | None = 'passed_QC',
           PC_key: str = 'pca',
           neighbors_key: str = 'neighbors',
           distances_key: str = 'distances',
           embedding_key: str = 'pacmap',
           num_neighbors: int | np.integer = 10,
           num_extra_neighbors: int | np.integer = 10,
           num_mid_near_pairs: int | np.integer = 5,
           num_further_pairs: int | np.integer = 20,
           num_iterations: int | np.integer |
                           tuple[int | np.integer, int | np.integer,
                                 int | np.integer] = (100, 100, 250),
           learning_rate: int | float | np.integer | np.floating = 1,
           seed: int | np.integer = 0,
           match_parallel: bool = False,
           overwrite: bool = False,
           num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Calculate a two-dimensional embedding of this SingleCell dataset
    suitable for plotting with `plot_embedding()`.

    Uses [PaCMAP](https://arxiv.org/pdf/2012.04456), a relative of UMAP
    that captures global structure better.

    This function is intended to be run after `PCA()` and `neighbors()`. By
    default, it uses `obsm['pca']`, `obsm['neighbors']` and
    `obsm['distances']` as the inputs to PaCMAP, and stores the output in
    `obsm['pacmap']` as a `len(obs)` × 2 NumPy array. It can also be run on
    Harmony embeddings by running `harmonize()` and then specifying
    `PC_key='harmony'`.

    Args:
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored
                   and have their embeddings set to `NaN`.
        PC_key: the key of `obsm` containing the principal components
                calculated with `PCA()`, to use as an input for the
                embedding calculation. Can also be set to the Harmony
                embeddings calculated by `harmonize()`, by specifying
                `PC_key='harmony'`. Must be a float32, C-contiguous array
                with at least 2 columns.
        neighbors_key: the key of `obsm` containing the nearest-neighbor
                       indices for each cell, to use as an input for the
                       embedding calculation. Must be a uint32,
                       C-contiguous array.
        distances_key: the key of `obsm` containing the squared Euclidean
                       distance to each nearest neighbor in
                       `neighbors_key`, to use as an input for the
                       embedding calculation. Must be a float32,
                       C-contiguous array with the same number of columns
                       as `obsm[neighbors_key]`.
        embedding_key: the key of `obsm` where the embeddings will be
                       stored
        num_neighbors: the number of nearest neighbors in the original
                       high-dimensional space to consider for each point.
                       Higher values focus on preserving the broader
                       topological structure of local neighborhoods,
                       potentially merging close clusters. Lower values
                       prioritize the very fine-grained local structure,
                       which can reveal intricate patterns but may also
                       fragment larger clusters.
        num_extra_neighbors: the number of extra nearest neighbors (on top
                             of `num_neighbors`) to search for initially,
                             before pruning to the `num_neighbors` of these
                             `num_neighbors + num_extra_neighbors` cells
                             with the smallest scaled distances. For a pair
                             of cells `i` and `j`, the scaled distance
                             between `i` and `j` is its squared Euclidean
                             distance, divided by `i`'s average Euclidean
                             distance to its 3rd, 4th, and 5th nearest
                             neighbors, divided by `j`'s average Euclidean
                             distance to its 3rd, 4th, and 5th nearest
                             neighbors. Must be a non-negative integer.
                             Defaults to 10, instead of PaCMAP's original
                             default of 50. `neighbors_key` and
                             `distances_key` must contain at least
                             `num_neighbors + num_extra_neighbors` nearest
                             neighbors.
        num_mid_near_pairs: the number of moderately close cells (not
                            nearest neighbors) to sample for each cell,
                            used to attract distinct local neighborhoods
                            together. Higher values add more "scaffolding"
                            to preserve the large-scale global structure
                            and the relationships between clusters. Lower
                            values reduce this effect, allowing local
                            structures to be placed more independently of
                            one another.
        num_further_pairs: the number of distant cells to sample for each
                           cell, used to create repulsive forces that
                           prevent crowding and shape the final layout.
                           Higher values increase this repulsive force,
                           leading to a more spread-out embedding with
                           clearer separation between clusters. Lower
                           values reduce the force, which can result in a
                           more compact layout where clusters may be closer
                           or overlap.
        num_iterations: the number of iterations to run PaCMAP for. Can
                        be a length-3 tuple of the number of iterations for
                        each of the 3 stages of optimization, or a single
                        integer of the number of iterations for the third
                        stage (in which case the number of iterations for
                        the first two stages will be set to 100).
        learning_rate: the learning rate of the Adam optimizer for PaCMAP
        seed: the random seed to use for PaCMAP
        match_parallel: if `False`, use a different order of operations for
                        single-threaded PaCMAP. This gives a modest
                        (~15%) boost in single-threaded performance at the
                        cost of no longer exactly matching the embedding
                        produced by the multithreaded version (due to
                        differences in floating-point error arising from
                        the different order of operations). Must be `False`
                        unless `num_threads=1`.
        overwrite: if `True`, overwrite `embedding_key` if already present
                   in `obsm`, instead of raising an error
        num_threads: the number of threads to use when running PaCMAP.
                     Set `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores. Does not affect the returned
                     SingleCell dataset's `num_threads`; this will always
                     be the same as the original dataset's `num_threads`.

    Returns:
        A new SingleCell dataset with the PaCMAP embedding stored in
        `obsm[embedding_key]`.

    Note:
        PaCMAP's original implementation assumes generic input data, so
        it initializes the embedding by standardizing the input data,
        running PCA on it, and taking the first two PCs. Because our input
        data is already PCs (or harmonized PCs), we avoid redundant
        calculations by omitting this step and directly initializing the
        embedding with the first two columns of our input data, i.e. the
        first two PCs.
    """
    # Check that `embedding_key` is a string
    check_type(embedding_key, 'embedding_key', str, 'a string')
    # Check that `overwrite` is Boolean
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    # Check that `embedding_key` is not already a key in `obsm`, unless
    # `overwrite=True`
    if not overwrite and embedding_key in self._obsm:
        error_message = (
            f'embedding_key {embedding_key!r} is already a key of obsm; '
            f'did you already run pacmap()? Set overwrite=True to '
            f'overwrite.')
        raise ValueError(error_message)
    # Get the QC column, if not `None`. The default `'passed_QC'` column is
    # allowed to be missing; any other specification must resolve.
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')
    # Get PCs, and check that they are float32, C-contiguous, and have at
    # least two columns
    check_type(PC_key, 'PC_key', str, 'a string')
    if PC_key not in self._obsm:
        error_message = f'PC_key {PC_key!r} is not a key of obsm'
        if PC_key == 'pca':
            error_message += (
                '; did you forget to run PCA() (and possibly neighbors()) '
                'before pacmap()?')
        raise ValueError(error_message)
    PCs = self._obsm[PC_key]
    if PCs.dtype != np.float32:
        error_message = \
            f'obsm[{PC_key!r}].dtype is {PCs.dtype!r}, but must be float32'
        raise TypeError(error_message)
    if not PCs.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{PC_key!r}] is not C-contiguous; make it C-contiguous '
            f'with pipe_obsm_key({PC_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)
    # The embedding is initialized from the first two columns of the input
    # (see the docstring's Note), so fewer than two columns cannot work;
    # fail here with a clear message rather than inside the compiled code.
    if PCs.ndim != 2 or PCs.shape[1] < 2:
        error_message = (
            f'obsm[{PC_key!r}] must be a 2D array with at least 2 '
            f'columns, but has shape {PCs.shape}')
        raise ValueError(error_message)
    # Get the nearest-neighbor indices and distances, and check that they
    # are uint32 and float32, respectively, C-contiguous, and have the same
    # width
    check_type(neighbors_key, 'neighbors_key', str, 'a string')
    if neighbors_key not in self._obsm:
        error_message = \
            f'neighbors_key {neighbors_key!r} is not a key of obsm'
        if neighbors_key == 'neighbors':
            error_message += (
                '; did you forget to run neighbors() before pacmap()?')
        raise ValueError(error_message)
    neighbors = self._obsm[neighbors_key]
    if neighbors.dtype != np.uint32:
        error_message = (
            f'obsm[{neighbors_key!r}] must have uint32 data type, but '
            f'has data type {str(neighbors.dtype)!r}')
        raise TypeError(error_message)
    if not neighbors.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{neighbors_key!r}] is not C-contiguous; make it '
            f'C-contiguous with '
            f'pipe_obsm_key({neighbors_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)
    check_type(distances_key, 'distances_key', str, 'a string')
    if distances_key not in self._obsm:
        error_message = \
            f'distances_key {distances_key!r} is not a key of obsm'
        if distances_key == 'distances':
            error_message += (
                '; did you forget to run neighbors() before pacmap()?')
        raise ValueError(error_message)
    distances = self._obsm[distances_key]
    if distances.dtype != np.float32:
        error_message = (
            f'obsm[{distances_key!r}] must have float32 data type, but '
            f'has data type {str(distances.dtype)!r}')
        raise TypeError(error_message)
    if not distances.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{distances_key!r}] is not C-contiguous; make it '
            f'C-contiguous with '
            f'pipe_obsm_key({distances_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)
    if neighbors.shape[1] != distances.shape[1]:
        error_message = (
            f'obsm[{neighbors_key!r}] and obsm[{distances_key!r}] have '
            f'different numbers of columns ({neighbors.shape[1]:,} vs '
            f'{distances.shape[1]:,})')
        raise ValueError(error_message)
    # Check that `num_extra_neighbors` is ≥ 0
    check_type(num_extra_neighbors, 'num_extra_neighbors', int,
               'a non-negative integer')
    check_bounds(num_extra_neighbors, 'num_extra_neighbors', 0)
    # Check that `num_iterations` is a positive integer or length-3 tuple
    # thereof
    check_type(num_iterations, 'num_iterations', (int, tuple),
               'a positive integer or length-3 tuple of positive '
               'integers')
    if isinstance(num_iterations, tuple):
        if len(num_iterations) != 3:
            error_message = (
                f'num_iterations must be a positive integer or '
                f'length-3 tuple of positive integers, but has length '
                f'{len(num_iterations):,}')
            raise ValueError(error_message)
        for step, step_num_iterations in enumerate(num_iterations):
            check_type(step_num_iterations,
                       f'num_iterations[{step!r}]', int,
                       'a positive integer')
            check_bounds(step_num_iterations,
                         f'num_iterations[{step!r}]', 1)
        num_phase_1_iterations, num_phase_2_iterations, \
            num_phase_3_iterations = num_iterations
    else:
        # A single integer only sets the third phase; the first two phases
        # keep their default of 100 iterations each
        check_bounds(num_iterations, 'num_iterations', 1)
        num_phase_1_iterations = num_phase_2_iterations = 100
        num_phase_3_iterations = num_iterations
    # Check that `learning_rate` is a positive number
    check_type(learning_rate, 'learning_rate', (int, float),
               'a positive number')
    check_bounds(learning_rate, 'learning_rate', 0, left_open=True)
    # Check that `seed` is an integer
    check_type(seed, 'seed', int, 'an integer')
    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)
    # Check that `match_parallel` is Boolean, and `False` unless
    # `num_threads=1`
    check_type(match_parallel, 'match_parallel', bool, 'Boolean')
    if match_parallel and num_threads != 1:
        error_message = \
            'match_parallel must be False unless num_threads is 1'
        raise ValueError(error_message)
    # Subset PCs, nearest-neighbor indices and distances to QCed cells
    # only, if `QC_column` is not `None`
    if QC_column is not None:
        QC_column_NumPy = QC_column.to_numpy()
        if num_threads == 1:
            PCs = PCs[QC_column_NumPy]
            neighbors = neighbors[QC_column_NumPy]
            distances = distances[QC_column_NumPy]
        else:
            indices = np.flatnonzero(QC_column_NumPy)
            PCs = parallel_subset_2d(PCs, indices, num_threads)
            neighbors = parallel_subset_2d(neighbors, indices, num_threads)
            distances = parallel_subset_2d(distances, indices, num_threads)
    # Check that there are at least 7 cells (since
    # `sample_mid_near_pairs()` requires 6 other cells)
    num_cells = PCs.shape[0]
    if num_cells < 7:
        error_message = (
            'there are fewer than 7 cells, so the embedding cannot be '
            'calculated')
        raise ValueError(error_message)
    # Check that `num_neighbors` is between 1 and `num_cells - 1`
    check_type(num_neighbors, 'num_neighbors', int, 'a positive integer')
    if not 1 <= num_neighbors < num_cells:
        error_message = (
            f'num_neighbors is {num_neighbors:,}, but must be ≥ 1 and '
            f'less than the number of cells ({num_cells:,})')
        raise ValueError(error_message)
    # Check that `num_mid_near_pairs` and `num_further_pairs` are between 1
    # and `num_cells`
    for variable, variable_name in (
            (num_mid_near_pairs, 'num_mid_near_pairs'),
            (num_further_pairs, 'num_further_pairs')):
        check_type(variable, variable_name, int, 'a positive integer')
        if not 1 <= variable <= num_cells:
            error_message = (
                f'{variable_name} is {variable:,}, but must be ≥ 1 and ≤ '
                f'the number of cells ({num_cells:,})')
            raise ValueError(error_message)
    # Check that there are at least `num_neighbors + num_further_pairs + 1`
    # cells (since `sample_further_pairs()` requires
    # `num_neighbors + num_further_pairs` other cells)
    if num_cells < num_neighbors + num_further_pairs + 1:
        error_message = (
            f'there are fewer than '
            f'{num_neighbors + num_further_pairs + 1} (num_neighbors + '
            f'num_further_pairs + 1) cells, so the embedding cannot be '
            f'calculated')
        raise ValueError(error_message)
    # Define `num_total_neighbors` as `num_neighbors + num_extra_neighbors`
    num_total_neighbors = num_neighbors + num_extra_neighbors
    # Check that `num_total_neighbors` is less than `num_cells`
    if num_total_neighbors >= num_cells:
        error_message = (
            f'num_neighbors ({num_neighbors:,}) + num_extra_neighbors '
            f'({num_extra_neighbors:,}) is {num_total_neighbors:,}, but '
            f'must be less than the number of cells ({num_cells:,})')
        raise ValueError(error_message)
    # Check that `neighbors` and `distances` contain at least
    # `num_total_neighbors` nearest neighbors. (The two arrays were already
    # verified to have the same width above, so the second check is purely
    # defensive.)
    if num_total_neighbors > neighbors.shape[1]:
        error_message = (
            f'num_neighbors ({num_neighbors:,}) + num_extra_neighbors '
            f'({num_extra_neighbors:,}) is {num_total_neighbors:,}, but '
            f'obsm[{neighbors_key!r}] has only {neighbors.shape[1]} '
            f'columns')
        raise ValueError(error_message)
    if num_total_neighbors > distances.shape[1]:
        error_message = (
            f'num_neighbors ({num_neighbors:,}) + num_extra_neighbors '
            f'({num_extra_neighbors:,}) is {num_total_neighbors:,}, but '
            f'obsm[{distances_key!r}] has only {distances.shape[1]} '
            f'columns')
        raise ValueError(error_message)
    # Allocate the work buffers and run PaCMAP. The single-threaded path
    # uses uninitialized `np.empty()` buffers; the multithreaded path
    # allocates via `numa_zeros()`.
    if num_threads == 1:
        embedding = np.empty((num_cells, 2), dtype=np.float32)
        momentum = np.empty((num_cells, 2), dtype=np.float32)
        velocity = np.empty((num_cells, 2), dtype=np.float32)
        gradients = np.empty((num_cells, 2), dtype=np.float32)
        average_distances = np.empty(num_cells, dtype=np.float32)
        neighbor_pairs = np.empty((num_cells, num_neighbors),
                                  dtype=np.uint32)
        mid_near_pairs = np.empty((num_cells, num_mid_near_pairs),
                                  dtype=np.uint32)
        further_pairs = np.empty((num_cells, num_further_pairs),
                                 dtype=np.uint32)
        if match_parallel:
            # Same index/indptr buffers as the multithreaded path
            neighbor_pair_indices = \
                np.empty(2 * num_cells * num_neighbors, dtype=np.uint32)
            neighbor_pair_indptr = np.empty(num_cells + 1, dtype=np.uint32)
            mid_near_pair_indices = np.empty(
                2 * num_cells * num_mid_near_pairs, dtype=np.uint32)
            mid_near_pair_indptr = np.empty(num_cells + 1, dtype=np.uint32)
            further_pair_indices = np.empty(
                2 * num_cells * num_further_pairs, dtype=np.uint32)
            further_pair_indptr = np.empty(num_cells + 1, dtype=np.uint32)
        else:
            # Pass empty placeholders for the index/indptr buffers
            neighbor_pair_indices = np.array([], dtype=np.uint32)
            neighbor_pair_indptr = np.array([], dtype=np.uint32)
            mid_near_pair_indices = np.array([], dtype=np.uint32)
            mid_near_pair_indptr = np.array([], dtype=np.uint32)
            further_pair_indices = np.array([], dtype=np.uint32)
            further_pair_indptr = np.array([], dtype=np.uint32)
    else:
        embedding = numa_zeros((num_cells, 2), dtype=np.float32)
        momentum = numa_zeros((num_cells, 2), dtype=np.float32)
        velocity = numa_zeros((num_cells, 2), dtype=np.float32)
        gradients = numa_zeros((num_cells, 2), dtype=np.float32)
        average_distances = numa_zeros(num_cells, dtype=np.float32)
        neighbor_pairs = numa_zeros((num_cells, num_neighbors),
                                    dtype=np.uint32)
        mid_near_pairs = numa_zeros((num_cells, num_mid_near_pairs),
                                    dtype=np.uint32)
        further_pairs = numa_zeros((num_cells, num_further_pairs),
                                   dtype=np.uint32)
        neighbor_pair_indices = \
            numa_zeros(2 * num_cells * num_neighbors, dtype=np.uint32)
        neighbor_pair_indptr = numa_zeros(num_cells + 1, dtype=np.uint32)
        mid_near_pair_indices = numa_zeros(
            2 * num_cells * num_mid_near_pairs, dtype=np.uint32)
        mid_near_pair_indptr = numa_zeros(num_cells + 1, dtype=np.uint32)
        further_pair_indices = numa_zeros(
            2 * num_cells * num_further_pairs, dtype=np.uint32)
        further_pair_indptr = numa_zeros(num_cells + 1, dtype=np.uint32)
    # Inside this method body, `pacmap` resolves to the module-level
    # compiled routine, not to this method
    pacmap(PCs, embedding, momentum, velocity, gradients,
           average_distances, neighbor_pairs, mid_near_pairs,
           further_pairs, neighbor_pair_indices, neighbor_pair_indptr,
           mid_near_pair_indices, mid_near_pair_indptr,
           further_pair_indices, further_pair_indptr, neighbors, distances,
           num_neighbors, num_extra_neighbors, num_mid_near_pairs,
           num_further_pairs, num_phase_1_iterations,
           num_phase_2_iterations, num_phase_3_iterations, learning_rate,
           seed, match_parallel, neighbors_key, num_threads)
    # If `QC_column` was specified, back-project from QCed cells to all
    # cells, filling with `NaN`
    if QC_column is not None:
        embedding_QCed = embedding
        embedding = np.full((len(self), embedding_QCed.shape[1]), np.nan,
                            dtype=np.float32)
        embedding[QC_column_NumPy] = embedding_QCed
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm | {embedding_key: embedding},
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
[docs]
def localmap(self,
             *,
             QC_column: SingleCellColumn | None = 'passed_QC',
             PC_key: str = 'pca',
             neighbors_key: str = 'neighbors',
             distances_key: str = 'distances',
             embedding_key: str = 'localmap',
             num_neighbors: int | np.integer = 10,
             num_extra_neighbors: int | np.integer = 10,
             num_mid_near_pairs: int | np.integer = 5,
             num_further_pairs: int | np.integer = 20,
             num_iterations: int | np.integer |
                             tuple[int | np.integer, int | np.integer,
                                   int | np.integer] = (100, 100, 250),
             learning_rate: int | float | np.integer | np.floating = 1,
             max_distance: int | float | np.integer | np.floating = 10,
             seed: int | np.integer = 0,
             match_parallel: bool = False,
             overwrite: bool = False,
             num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Calculate a two-dimensional embedding of this SingleCell dataset
    suitable for plotting with `plot_embedding()`.

    Uses [LocalMAP](https://arxiv.org/abs/2412.15426), a relative of UMAP
    that captures global structure better.

    This function is intended to be run after `PCA()` and `neighbors()`. By
    default, it uses `obsm['pca']`, `obsm['neighbors']`, and
    `obsm['distances']` as the inputs to LocalMAP, and stores the output in
    `obsm['localmap']` as a `len(obs)` × 2 NumPy array. It can also be run
    on Harmony embeddings by running `harmonize()` and then specifying
    `PC_key='harmony'`.

    Args:
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored
                   and have their embeddings set to `NaN`.
        PC_key: the key of `obsm` containing the principal components
                calculated with `PCA()`, to use as an input for the
                embedding calculation. Can also be set to the Harmony
                embeddings calculated by `harmonize()`, by specifying
                `PC_key='harmony'`.
        neighbors_key: the key of `obsm` containing the nearest-neighbor
                       indices for each cell, to use as an input for the
                       embedding calculation
        distances_key: the key of `obsm` containing the squared Euclidean
                       distance to each nearest neighbor in
                       `neighbors_key`, to use as an input for the
                       embedding calculation
        embedding_key: the key of `obsm` where the embeddings will be
                       stored
        num_neighbors: the number of nearest neighbors in the original
                       high-dimensional space to consider for each point.
                       Higher values focus on preserving the broader
                       topological structure of local neighborhoods,
                       potentially merging close clusters. Lower values
                       prioritize the very fine-grained local structure,
                       which can reveal intricate patterns but may also
                       fragment larger clusters.
        num_extra_neighbors: the number of extra nearest neighbors (on top
                             of `num_neighbors`) to search for initially,
                             before pruning to the `num_neighbors` of these
                             `num_neighbors + num_extra_neighbors` cells
                             with the smallest scaled distances. For a pair
                             of cells `i` and `j`, the scaled distance
                             between `i` and `j` is its squared Euclidean
                             distance, divided by `i`'s average Euclidean
                             distance to its 3rd, 4th, and 5th nearest
                             neighbors, divided by `j`'s average Euclidean
                             distance to its 3rd, 4th, and 5th nearest
                             neighbors. Must be a non-negative integer.
                             Defaults to 10, instead of LocalMAP's original
                             default of 50. `neighbors_key` and
                             `distances_key` must contain at least
                             `num_neighbors + num_extra_neighbors` nearest
                             neighbors.
        num_mid_near_pairs: the number of moderately close cells (not
                            nearest neighbors) to sample for each cell,
                            used to attract distinct local neighborhoods
                            together. Higher values add more "scaffolding"
                            to preserve the large-scale global structure
                            and the relationships between clusters. Lower
                            values reduce this effect, allowing local
                            structures to be placed more independently of
                            one another.
        num_further_pairs: the number of distant cells to sample for each
                           cell, used to create repulsive forces that
                           prevent crowding and shape the final layout.
                           Higher values increase this repulsive force,
                           leading to a more spread-out embedding with
                           clearer separation between clusters. Lower
                           values reduce the force, which can result in a
                           more compact layout where clusters may be closer
                           or overlap.
        num_iterations: the number of iterations to run LocalMAP for. Can
                        be a length-3 tuple of the number of iterations for
                        each of the 3 stages of optimization, or a single
                        integer of the number of iterations for the third
                        stage (in which case the number of iterations for
                        the first two stages will be set to 100).
        learning_rate: the learning rate of the Adam optimizer for LocalMAP
        max_distance: the distance cutoff (in the embedding space) above
                      which cells will not be considered as further pair
                      candidates during the final stage of optimization,
                      also used to define a scaling factor for the
                      nearest-neighbor gradients during this stage. This
                      parameter must be set near its default value of 10 to
                      produce a good-quality embedding.
        seed: the random seed to use for LocalMAP
        match_parallel: if `False`, use a different order of operations for
                        single-threaded LocalMAP. This gives a modest
                        (~15%) boost in single-threaded performance at the
                        cost of no longer exactly matching the embedding
                        produced by the multithreaded version (due to
                        differences in floating-point error arising from
                        the different order of operations). Must be `False`
                        unless `num_threads=1`.
        overwrite: if `True`, overwrite `embedding_key` if already present
                   in `obsm`, instead of raising an error
        num_threads: the number of threads to use when running LocalMAP.
                     Set `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores. Does not affect the returned
                     SingleCell dataset's `num_threads`; this will always
                     be the same as the original dataset's `num_threads`.

    Returns:
        A new SingleCell dataset with the LocalMAP embedding stored in
        `obsm[embedding_key]`.

    Raises:
        TypeError: if `obsm[PC_key]` or `obsm[distances_key]` is not
                   float32, or `obsm[neighbors_key]` is not uint32.
        ValueError: if `embedding_key` is already in `obsm` and
                    `overwrite=False`; if any of the input keys is missing
                    from `obsm`; if an input array is not C-contiguous; if
                    the neighbor and distance arrays have different widths;
                    if `match_parallel=True` with more than one thread; or
                    if there are too few cells or too few stored nearest
                    neighbors for the requested parameters.

    Note:
        LocalMAP's original implementation assumes generic input data, so
        it initializes the embedding by standardizing the input data,
        running PCA on it, and taking the first two PCs. Because our input
        data is already PCs (or harmonized PCs), we avoid redundant
        calculations by omitting this step and directly initializing the
        embedding with the first two columns of our input data, i.e. the
        first two PCs.
    """
    # Check that `embedding_key` is a string
    check_type(embedding_key, 'embedding_key', str, 'a string')

    # Check that `overwrite` is Boolean
    check_type(overwrite, 'overwrite', bool, 'Boolean')

    # Check that `embedding_key` is not already a key in `obsm`, unless
    # `overwrite=True`
    if not overwrite and embedding_key in self._obsm:
        error_message = (
            f'embedding_key {embedding_key!r} is already a key of obsm; '
            f'did you already run localmap()? Set overwrite=True to '
            f'overwrite.')
        raise ValueError(error_message)

    # Get the QC column, if not `None`. The default 'passed_QC' column is
    # allowed to be missing from `obs`.
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')

    # Get PCs, and check that they are float32 and C-contiguous
    check_type(PC_key, 'PC_key', str, 'a string')
    if PC_key not in self._obsm:
        error_message = f'PC_key {PC_key!r} is not a key of obsm'
        if PC_key == 'pca':
            error_message += (
                '; did you forget to run PCA() (and possibly neighbors()) '
                'before localmap()?')
        raise ValueError(error_message)
    PCs = self._obsm[PC_key]
    if PCs.dtype != np.float32:
        error_message = \
            f'obsm[{PC_key!r}].dtype is {PCs.dtype!r}, but must be float32'
        raise TypeError(error_message)
    if not PCs.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{PC_key!r}] is not C-contiguous; make it C-contiguous '
            f'with pipe_obsm_key({PC_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)

    # Get the nearest-neighbor indices and distances, and check that they
    # are uint32 and float32, respectively, C-contiguous, and have the same
    # width
    check_type(neighbors_key, 'neighbors_key', str, 'a string')
    if neighbors_key not in self._obsm:
        error_message = \
            f'neighbors_key {neighbors_key!r} is not a key of obsm'
        if neighbors_key == 'neighbors':
            error_message += (
                '; did you forget to run neighbors() before localmap()?')
        raise ValueError(error_message)
    neighbors = self._obsm[neighbors_key]
    if neighbors.dtype != np.uint32:
        error_message = (
            f'obsm[{neighbors_key!r}] must have uint32 data type, but '
            f'has data type {str(neighbors.dtype)!r}')
        raise TypeError(error_message)
    if not neighbors.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{neighbors_key!r}] is not C-contiguous; make it '
            f'C-contiguous with '
            f'pipe_obsm_key({neighbors_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)
    check_type(distances_key, 'distances_key', str, 'a string')
    if distances_key not in self._obsm:
        error_message = \
            f'distances_key {distances_key!r} is not a key of obsm'
        if distances_key == 'distances':
            error_message += (
                '; did you forget to run neighbors() before localmap()?')
        raise ValueError(error_message)
    distances = self._obsm[distances_key]
    if distances.dtype != np.float32:
        error_message = (
            f'obsm[{distances_key!r}] must have float32 data type, but '
            f'has data type {str(distances.dtype)!r}')
        raise TypeError(error_message)
    if not distances.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{distances_key!r}] is not C-contiguous; make it '
            f'C-contiguous with '
            f'pipe_obsm_key({distances_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)
    if neighbors.shape[1] != distances.shape[1]:
        # The closing parenthesis was previously missing from this message
        error_message = (
            f'obsm[{neighbors_key!r}] and obsm[{distances_key!r}] have '
            f'different numbers of columns ({neighbors.shape[1]:,} vs '
            f'{distances.shape[1]:,})')
        raise ValueError(error_message)

    # Check that `num_extra_neighbors` is ≥ 0
    check_type(num_extra_neighbors, 'num_extra_neighbors', int,
               'a non-negative integer')
    check_bounds(num_extra_neighbors, 'num_extra_neighbors', 0)

    # Check that `num_iterations` is a positive integer or length-3 tuple
    # thereof
    check_type(num_iterations, 'num_iterations', (int, tuple),
               'a positive integer or length-3 tuple of positive '
               'integers')
    if isinstance(num_iterations, tuple):
        if len(num_iterations) != 3:
            error_message = (
                f'num_iterations must be a positive integer or '
                f'length-3 tuple of positive integers, but has length '
                f'{len(num_iterations):,}')
            raise ValueError(error_message)
        for step, step_num_iterations in enumerate(num_iterations):
            check_type(step_num_iterations,
                       f'num_iterations[{step!r}]', int,
                       'a positive integer')
            check_bounds(step_num_iterations,
                         f'num_iterations[{step!r}]', 1)
        num_phase_1_iterations, num_phase_2_iterations, \
            num_phase_3_iterations = num_iterations
    else:
        # A single integer only sets the third phase; the first two phases
        # run for 100 iterations each
        check_bounds(num_iterations, 'num_iterations', 1)
        num_phase_1_iterations = num_phase_2_iterations = 100
        num_phase_3_iterations = num_iterations

    # Check that `learning_rate` and `max_distance` are positive numbers
    check_type(learning_rate, 'learning_rate', (int, float),
               'a positive number')
    check_bounds(learning_rate, 'learning_rate', 0, left_open=True)
    check_type(max_distance, 'max_distance',
               (int, float), 'a positive number')
    check_bounds(max_distance, 'max_distance', 0,
                 left_open=True)

    # Check that `seed` is an integer
    check_type(seed, 'seed', int, 'an integer')

    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)

    # Check that `match_parallel` is Boolean, and `False` unless
    # `num_threads=1`
    check_type(match_parallel, 'match_parallel', bool, 'Boolean')
    if match_parallel and num_threads != 1:
        error_message = \
            'match_parallel must be False unless num_threads is 1'
        raise ValueError(error_message)

    # Subset PCs, nearest-neighbor indices and distances to QCed cells
    # only, if `QC_column` is not `None`
    if QC_column is not None:
        QC_column_NumPy = QC_column.to_numpy()
        if num_threads == 1:
            PCs = PCs[QC_column_NumPy]
            neighbors = neighbors[QC_column_NumPy]
            distances = distances[QC_column_NumPy]
        else:
            indices = np.flatnonzero(QC_column_NumPy)
            PCs = parallel_subset_2d(PCs, indices, num_threads)
            neighbors = parallel_subset_2d(neighbors, indices, num_threads)
            distances = parallel_subset_2d(distances, indices, num_threads)

    # Check that there are at least 7 cells (since
    # `sample_mid_near_pairs()` requires 6 other cells)
    num_cells = PCs.shape[0]
    if num_cells < 7:
        error_message = (
            f'there are fewer than 7 cells, so the embedding cannot be '
            f'calculated')
        raise ValueError(error_message)

    # Check that `num_neighbors` is between 1 and `num_cells - 1`
    check_type(num_neighbors, 'num_neighbors', int, 'a positive integer')
    if not 1 <= num_neighbors < num_cells:
        error_message = (
            f'num_neighbors is {num_neighbors:,}, but must be ≥ 1 and '
            f'less than the number of cells ({num_cells:,})')
        raise ValueError(error_message)

    # Check that `num_mid_near_pairs` and `num_further_pairs` are between 1
    # and `num_cells`
    for variable, variable_name in (
            (num_mid_near_pairs, 'num_mid_near_pairs'),
            (num_further_pairs, 'num_further_pairs')):
        check_type(variable, variable_name, int, 'a positive integer')
        if not 1 <= variable <= num_cells:
            error_message = (
                f'{variable_name} is {variable:,}, but must be ≥ 1 and ≤ '
                f'the number of cells ({num_cells:,})')
            raise ValueError(error_message)

    # Check that there are at least `num_neighbors + num_further_pairs + 1`
    # cells (since `sample_further_pairs()` requires
    # `num_neighbors + num_further_pairs` other cells)
    if num_cells < num_neighbors + num_further_pairs + 1:
        error_message = (
            f'there are fewer than '
            f'{num_neighbors + num_further_pairs + 1} (num_neighbors + '
            f'num_further_pairs + 1) cells, so the embedding cannot be '
            f'calculated')
        raise ValueError(error_message)

    # Define `num_total_neighbors` as `num_neighbors + num_extra_neighbors`
    num_total_neighbors = num_neighbors + num_extra_neighbors

    # Check that `num_total_neighbors` is less than `num_cells`
    if num_total_neighbors >= num_cells:
        error_message = (
            f'num_neighbors ({num_neighbors:,}) + num_extra_neighbors '
            f'({num_extra_neighbors:,}) is {num_total_neighbors:,}, but '
            f'must be less than the number of cells ({num_cells:,})')
        raise ValueError(error_message)

    # Check that `neighbors` and `distances` contain at least
    # `num_total_neighbors` nearest neighbors
    if num_total_neighbors > neighbors.shape[1]:
        error_message = (
            f'num_neighbors ({num_neighbors:,}) + num_extra_neighbors '
            f'({num_extra_neighbors:,}) is {num_total_neighbors:,}, but '
            f'obsm[{neighbors_key!r}] has only {neighbors.shape[1]} '
            f'columns')
        raise ValueError(error_message)
    if num_total_neighbors > distances.shape[1]:
        error_message = (
            f'num_neighbors ({num_neighbors:,}) + num_extra_neighbors '
            f'({num_extra_neighbors:,}) is {num_total_neighbors:,}, but '
            f'obsm[{distances_key!r}] has only {distances.shape[1]} '
            f'columns')
        raise ValueError(error_message)

    # Allocate the output embedding and the work buffers passed to the
    # Cython kernel: uninitialized NumPy arrays when single-threaded,
    # `numa_zeros()` arrays when multithreaded
    if num_threads == 1:
        embedding = np.empty((num_cells, 2), dtype=np.float32)
        momentum = np.empty((num_cells, 2), dtype=np.float32)
        velocity = np.empty((num_cells, 2), dtype=np.float32)
        gradients = np.empty((num_cells, 2), dtype=np.float32)
        average_distances = np.empty(num_cells, dtype=np.float32)
        neighbor_pairs = np.empty((num_cells, num_neighbors),
                                  dtype=np.uint32)
        mid_near_pairs = np.empty((num_cells, num_mid_near_pairs),
                                  dtype=np.uint32)
        further_pairs = np.empty((num_cells, num_further_pairs),
                                 dtype=np.uint32)
        if match_parallel:
            neighbor_pair_indices = \
                np.empty(2 * num_cells * num_neighbors, dtype=np.uint32)
            neighbor_pair_indptr = np.empty(num_cells + 1, dtype=np.uint32)
            mid_near_pair_indices = np.empty(
                2 * num_cells * num_mid_near_pairs, dtype=np.uint32)
            mid_near_pair_indptr = np.empty(num_cells + 1, dtype=np.uint32)
            further_pair_indices = np.empty(
                2 * num_cells * num_further_pairs, dtype=np.uint32)
            further_pair_indptr = np.empty(num_cells + 1, dtype=np.uint32)
        else:
            # Only empty placeholders are passed in this mode; presumably
            # the kernel does not read the pair index buffers when
            # `match_parallel=False` (confirm against the Cython source)
            neighbor_pair_indices = np.array([], dtype=np.uint32)
            neighbor_pair_indptr = np.array([], dtype=np.uint32)
            mid_near_pair_indices = np.array([], dtype=np.uint32)
            mid_near_pair_indptr = np.array([], dtype=np.uint32)
            further_pair_indices = np.array([], dtype=np.uint32)
            further_pair_indptr = np.array([], dtype=np.uint32)
    else:
        embedding = numa_zeros((num_cells, 2), dtype=np.float32)
        momentum = numa_zeros((num_cells, 2), dtype=np.float32)
        velocity = numa_zeros((num_cells, 2), dtype=np.float32)
        gradients = numa_zeros((num_cells, 2), dtype=np.float32)
        average_distances = numa_zeros(num_cells, dtype=np.float32)
        neighbor_pairs = numa_zeros((num_cells, num_neighbors),
                                    dtype=np.uint32)
        mid_near_pairs = numa_zeros((num_cells, num_mid_near_pairs),
                                    dtype=np.uint32)
        further_pairs = numa_zeros((num_cells, num_further_pairs),
                                   dtype=np.uint32)
        neighbor_pair_indices = \
            numa_zeros(2 * num_cells * num_neighbors, dtype=np.uint32)
        neighbor_pair_indptr = numa_zeros(num_cells + 1, dtype=np.uint32)
        mid_near_pair_indices = numa_zeros(
            2 * num_cells * num_mid_near_pairs, dtype=np.uint32)
        mid_near_pair_indptr = numa_zeros(num_cells + 1, dtype=np.uint32)
        further_pair_indices = numa_zeros(
            2 * num_cells * num_further_pairs, dtype=np.uint32)
        further_pair_indptr = numa_zeros(num_cells + 1, dtype=np.uint32)

    # Run LocalMAP; the result is read back from `embedding` afterwards
    localmap(PCs, embedding, momentum, velocity, gradients,
             average_distances, neighbor_pairs, mid_near_pairs,
             further_pairs, neighbor_pair_indices, neighbor_pair_indptr,
             mid_near_pair_indices, mid_near_pair_indptr,
             further_pair_indices, further_pair_indptr, neighbors,
             distances, num_neighbors, num_extra_neighbors,
             num_mid_near_pairs, num_further_pairs, num_phase_1_iterations,
             num_phase_2_iterations, num_phase_3_iterations, learning_rate,
             max_distance, seed, match_parallel, neighbors_key,
             num_threads)

    # If `QC_column` was specified, back-project from QCed cells to all
    # cells, filling with `NaN`
    if QC_column is not None:
        embedding_QCed = embedding
        embedding = np.full((len(self), embedding_QCed.shape[1]), np.nan,
                            dtype=np.float32)
        embedding[QC_column_NumPy] = embedding_QCed
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm | {embedding_key: embedding},
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
[docs]
def umap(self,
         *,
         QC_column: SingleCellColumn | None = 'passed_QC',
         PC_key: str = 'pca',
         neighbors_key: str = 'neighbors',
         distances_key: str = 'distances',
         embedding_key: str = 'umap',
         num_iterations: int | np.integer = 200,
         alpha: int | float | np.integer | np.floating = 1,
         gamma: int | float | np.integer | np.floating = 1,
         negative_sample_rate: int | np.integer = 5,
         a: int | float | np.integer | np.floating | None = None,
         b: int | float | np.integer | np.floating | None = None,
         spread: int | float | np.integer | np.floating = 1,
         min_dist: int | float | np.integer | np.floating = 0.5,
         seed: int | np.integer = 0,
         overwrite: bool = False,
         hogwild: bool = False,
         num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Calculate a two-dimensional embedding of this SingleCell dataset with
    UMAP (Uniform Manifold Approximation and Projection), suitable for
    plotting with `plot_embedding()`.

    Use `hogwild=True` to run in parallel. Results will not be
    reproducible!

    This function is intended to be run after `PCA()` and `neighbors()`. By
    default, it uses `obsm['pca']`, `obsm['neighbors']`, and
    `obsm['distances']` as the inputs to UMAP, and stores the output in
    `obsm['umap']` as a `len(obs)` × 2 NumPy array. It can also be run on
    Harmony embeddings by running `harmonize()` and then specifying
    `PC_key='harmony'`.

    Args:
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored
                   and have their embeddings set to `NaN`.
        PC_key: the key of `obsm` containing the principal components
                calculated with `PCA()`, to use as an input for the
                embedding calculation
        neighbors_key: the key of `obsm` containing the nearest-neighbor
                       indices for each cell, to use as an input for the
                       embedding calculation
        distances_key: the key of `obsm` containing the squared Euclidean
                       distance to each nearest neighbor in
                       `neighbors_key`, to use as an input for the
                       embedding calculation
        embedding_key: the key of `obsm` where the embeddings will be
                       stored
        num_iterations: the number of optimization iterations. In
                        umap-learn, this defaulted to 500 for datasets of
                        10,000 elements or less, and 200 for datasets
                        larger than 10,000 elements.
        alpha: the initial learning rate for optimization
        gamma: the weight applied to negative samples during optimization
        negative_sample_rate: the number of negative samples per positive
                              sample
        a: UMAP curve parameter; if `None`, will be fit based on the values
           of `spread` and `min_dist`. Either both or neither of `a` and
           `b` must be `None`.
        b: UMAP curve parameter; if `None`, will be fit based on the values
           of `spread` and `min_dist`. Either both or neither of `a` and
           `b` must be `None`.
        spread: the effective scale of embedded points. Only used when `a`
                and `b` are `None`.
        min_dist: the minimum distance between points in the embedding.
                  Only used when `a` and `b` are `None`.
        seed: the random seed to use for UMAP
        overwrite: if `True`, overwrite `embedding_key` if already present
                   in `obsm`, instead of raising an error
        hogwild: if `True`, go [Hogwild!](https://arxiv.org/abs/1106.5730)
                 and optimize the embedding in parallel. Results will not
                 be reproducible!
        num_threads: the number of threads to use when running `umap()`.
                     Cannot be specified when `hogwild=False`. When
                     `hogwild=True`, must be explicitly specified and
                     greater than 1 unless `self.num_threads` is greater
                     than 1, in which case it can be left unset. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores when `hogwild=True` and one
                     core when `hogwild=False`. Does not affect the
                     returned SingleCell dataset's `num_threads`; this will
                     always be the same as the original dataset's
                     `num_threads`.

    Returns:
        A new SingleCell dataset with the UMAP embedding stored in
        `obsm[embedding_key]`.

    Raises:
        TypeError: if `obsm[PC_key]` or `obsm[distances_key]` is not
                   float32, or `obsm[neighbors_key]` is not uint32.
        ValueError: if any of the input keys is missing from `obsm`; if an
                    input array is not C-contiguous; if the neighbor and
                    distance arrays have different widths; if
                    `embedding_key` is already in `obsm` and
                    `overwrite=False`; if exactly one of `a` and `b` is
                    `None`; if `num_threads` is inconsistent with
                    `hogwild`; or if fewer than 2 cells remain after
                    applying `QC_column`.
    """
    # Check that `QC_column` is valid, if not `None`. The default
    # 'passed_QC' column is allowed to be missing from `obs`.
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')

    # Check that `PC_key` is a string, and get PCs
    check_type(PC_key, 'PC_key', str, 'a string')
    if PC_key not in self._obsm:
        error_message = f'PC_key {PC_key!r} is not a key of obsm'
        if PC_key == 'pca':
            error_message += (
                '; did you forget to run PCA() (and possibly neighbors()) '
                'before umap()?')
        raise ValueError(error_message)
    PCs = self._obsm[PC_key]
    if PCs.dtype != np.float32:
        error_message = \
            f'obsm[{PC_key!r}].dtype is {PCs.dtype!r}, but must be float32'
        raise TypeError(error_message)
    if not PCs.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{PC_key!r}] is not C-contiguous; make it C-contiguous '
            f'with pipe_obsm_key({PC_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)

    # Check that `neighbors_key` is a string, and get nearest-neighbor
    # indices
    check_type(neighbors_key, 'neighbors_key', str, 'a string')
    if neighbors_key not in self._obsm:
        error_message = \
            f'neighbors_key {neighbors_key!r} is not a key of obsm'
        if neighbors_key == 'neighbors':
            error_message += (
                '; did you forget to run neighbors() before '
                'umap()?')
        raise ValueError(error_message)
    neighbors = self._obsm[neighbors_key]
    if neighbors.dtype != np.uint32:
        error_message = (
            f'obsm[{neighbors_key!r}] must have uint32 data type, but '
            f'has data type {str(neighbors.dtype)!r}')
        raise TypeError(error_message)
    if not neighbors.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{neighbors_key!r}] is not C-contiguous; make it '
            f'C-contiguous with '
            f'pipe_obsm_key({neighbors_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)

    # Check that `distances_key` is a string, and get nearest-neighbor
    # distances
    check_type(distances_key, 'distances_key', str, 'a string')
    if distances_key not in self._obsm:
        error_message = \
            f'distances_key {distances_key!r} is not a key of obsm'
        if distances_key == 'distances':
            error_message += (
                '; did you forget to run neighbors() before '
                'umap()?')
        raise ValueError(error_message)
    distances = self._obsm[distances_key]
    if distances.dtype != np.float32:
        error_message = (
            f'obsm[{distances_key!r}] must have float32 data type, but '
            f'has data type {str(distances.dtype)!r}')
        raise TypeError(error_message)
    if not distances.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{distances_key!r}] is not C-contiguous; make it '
            f'C-contiguous with '
            f'pipe_obsm_key({distances_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)
    if neighbors.shape[1] != distances.shape[1]:
        error_message = (
            f'obsm[{neighbors_key!r}] and obsm[{distances_key!r}] have '
            f'different numbers of columns ({neighbors.shape[1]:,} vs '
            f'{distances.shape[1]:,})')
        raise ValueError(error_message)

    # Check that `embedding_key` is a string
    check_type(embedding_key, 'embedding_key', str, 'a string')

    # Check that `embedding_key` is not already a key in `obsm`, unless
    # `overwrite=True`
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    if not overwrite and embedding_key in self._obsm:
        error_message = (
            f'embedding_key {embedding_key!r} is already a key of obsm; '
            f'did you already run umap()? Set overwrite=True to '
            f'overwrite.')
        raise ValueError(error_message)

    # Check that `num_iterations` is a positive integer
    check_type(num_iterations, 'num_iterations', int,
               'a positive integer')
    check_bounds(num_iterations, 'num_iterations', 1)

    # Check that `alpha` and `gamma` are positive numbers
    for variable, variable_name in (alpha, 'alpha'), (gamma, 'gamma'):
        check_type(variable, variable_name, (int, float),
                   'a positive number')
        check_bounds(variable, variable_name, 0, left_open=True)

    # Check that `negative_sample_rate` is a positive integer
    check_type(negative_sample_rate, 'negative_sample_rate', int,
               'a positive integer')
    check_bounds(negative_sample_rate, 'negative_sample_rate', 1)

    # Check that either both or neither of `a` and `b` are `None`
    if (a is None) != (b is None):
        error_message = (
            f'either both or neither of a and b must be None')
        raise ValueError(error_message)

    # If `a` and `b` are `None`, check `spread` and `min_dist`
    if a is None and b is None:
        check_type(spread, 'spread', (int, float), 'a positive number')
        check_bounds(spread, 'spread', 0, left_open=True)
        check_type(min_dist, 'min_dist', (int, float),
                   'a non-negative number')
        check_bounds(min_dist, 'min_dist', 0)

    # Check that `hogwild` is Boolean
    check_type(hogwild, 'hogwild', bool, 'Boolean')

    # Check that `num_threads` is valid. When `hogwild=False`,
    # `num_threads` cannot be specified and defaults to 1. When
    # `hogwild=True`, `num_threads` must be explicitly specified and
    # greater than 1 unless `self.num_threads` is greater than 1, in which
    # case it can be left unset. If `num_threads` is -1, use all available
    # cores.
    if not hogwild:
        if num_threads is not None:
            error_message = \
                'num_threads cannot be specified when hogwild=False'
            raise ValueError(error_message)
        num_threads = 1
    elif num_threads is None:
        if self._num_threads <= 1:
            error_message = (
                'when hogwild=True and self.num_threads is 1, num_threads '
                'must be explicitly specified and greater than 1')
            raise ValueError(error_message)
        num_threads = self._num_threads
    else:
        num_threads = self._process_num_threads(num_threads)
        if num_threads == 1:
            error_message = \
                'num_threads must be greater than 1 when hogwild=True'
            raise ValueError(error_message)

    # Subset PCs, nearest-neighbor indices and distances to QCed cells
    # only, if `QC_column` is not `None`
    if QC_column is not None:
        QC_column_NumPy = QC_column.to_numpy()
        PCs = PCs[QC_column_NumPy]
        neighbors = neighbors[QC_column_NumPy]
        distances = distances[QC_column_NumPy]

    # Check that there are at least 2 cells. This must come after the QC
    # subsetting above, so that it counts only the cells that will
    # actually be embedded.
    num_cells = PCs.shape[0]
    if num_cells < 2:
        error_message = (
            f'there are fewer than 2 cells, so the embedding cannot be '
            f'calculated')
        raise ValueError(error_message)

    # If `a` and `b` are `None`, select them based on `spread` and
    # `min_dist`, by fitting `1 / (1 + a * x ** (2 * b))` to a curve that
    # is 1 below `min_dist` and decays exponentially above it
    if a is None and b is None:
        from scipy.optimize import curve_fit
        xv = np.linspace(0, spread * 3, 300)
        yv = np.zeros(xv.shape)
        yv[xv < min_dist] = 1
        yv[xv >= min_dist] = \
            np.exp(-(xv[xv >= min_dist] - min_dist) / spread)
        a, b = curve_fit(
            lambda x, a, b: 1.0 / (1.0 + a * x ** (2 * b)), xv, yv)[0]

    def fuzzy_simplicial_set(neighbors, distances, num_threads):
        # Build an N × N sparse graph with one row per cell and one stored
        # entry per nearest neighbor, then symmetrize it
        N, K = neighbors.shape
        if num_threads == 1:
            data = np.empty(neighbors.size, dtype=np.float32)
        else:
            data = numa_zeros(neighbors.size, dtype=np.float32)
        # Reinterpret the uint32 neighbor indices as int32 without copying
        indices = neighbors.ravel().view(np.int32)
        # Every row has exactly K entries
        indptr = np.arange(0, (N + 1) * K, K, dtype=np.int32)
        # Fills `data` in place from `distances`
        umap_fuzzy_weights(distances, data, num_threads)
        result = csr_array((data, indices, indptr), shape=(N, N))
        # Fuzzy set union: A + Aᵀ - A ∘ Aᵀ
        result = result + result.T - result * result.T
        return result

    graph = fuzzy_simplicial_set(neighbors, distances, num_threads)

    def spectral_layout(data, graph, random_state):
        # Compute the initial 2D layout from the graph's spectrum, laying
        # out each connected component separately when there are several
        from scipy.linalg import eigh
        from scipy.sparse.csgraph import connected_components
        from scipy.sparse.linalg import eigsh
        from scipy.spatial.distance import cdist
        n_connected_components, component_labels = \
            connected_components(graph)
        if n_connected_components > 1:
            # Multi-component layout
            result = np.empty((graph.shape[0], 2), dtype=np.float32)
            if n_connected_components > 4:
                component_centroids = \
                    np.empty((n_connected_components, data.shape[1]))
                for label in range(n_connected_components):
                    component_centroids[label] = \
                        data[component_labels == label].mean(axis=0)
                distance_matrix = \
                    cdist(component_centroids, component_centroids)
                affinity_matrix = np.exp(-(distance_matrix ** 2))
                # Spectral embedding of the affinity matrix
                inv_sqrt_degree = \
                    affinity_matrix.sum(axis=1).ravel() ** -0.5
                normalized_affinity_matrix = \
                    affinity_matrix * inv_sqrt_degree[:, None] * \
                    inv_sqrt_degree[None, :]
                eigenvectors = eigh(normalized_affinity_matrix)[1]
                # skip trivial largest eigenvector, reorder largest to
                # smallest, just take the next-largest two
                meta_embedding = eigenvectors[:, [-2, -3]]
                meta_embedding /= np.abs(meta_embedding).max()
            else:
                # With 2 to 4 components, place them at fixed positions:
                # the rows of `base` and their negations
                k = int(np.ceil(n_connected_components / 2))
                base = np.hstack([np.eye(k), np.zeros((k, 2 - k))])
                meta_embedding = \
                    np.vstack([base, -base])[:n_connected_components]
            for label in range(n_connected_components):
                mask = component_labels == label
                component_graph = graph[np.ix_(mask, mask)]
                # Scale each component to half the distance to the nearest
                # other component's position
                meta_distances = cdist(meta_embedding[label:label + 1],
                                       meta_embedding)
                data_range = meta_distances[meta_distances > 0].min() / 2
                if component_graph.shape[0] < 4:
                    # Too small for a spectral layout: place randomly
                    # around the component's position
                    result[mask] = \
                        random_state.uniform(
                            low=-data_range, high=data_range,
                            size=(component_graph.shape[0], 2)) + \
                        meta_embedding[label]
                else:
                    # `data` is only read in the multi-component branch,
                    # which a single connected component does not enter
                    component_embedding = spectral_layout(
                        data=None, graph=component_graph,
                        random_state=random_state)
                    expansion = \
                        data_range / np.abs(component_embedding).max()
                    component_embedding *= expansion
                    result[mask] = \
                        component_embedding + meta_embedding[label]
            return result
        # UMAP originally took the smallest-magnitude eigenvalues
        # (`which='SM'`) of `I - D @ graph @ D` where
        # `D = diag(inv_sqrt_degree)`, but this is equivalent to taking the
        # most positive eigenvalues (`which='LA'`) of `D @ graph @ D`,
        # which converges faster. Similarly, use the trivial eigenvector as
        # an initial guess via `v0=sqrt_degree` to speed up convergence,
        # rather than the original's `v0=np.ones(L.shape[0])`.
        sqrt_degree = np.sqrt(graph.sum(axis=1).ravel())
        inv_sqrt_degree = 1 / sqrt_degree
        normalized_graph = graph\
            .multiply(inv_sqrt_degree[:, None])\
            .multiply(inv_sqrt_degree[None, :])
        # NOTE(review): SciPy documents that `eigsh` needs `k` smaller
        # than the matrix dimension, so a connected graph with 3 or fewer
        # cells presumably fails here; confirm whether that case matters
        with threadpool_limits(1):
            eigenvalues, eigenvectors = eigsh(normalized_graph, k=3,
                                              which='LA', tol=1e-4,
                                              v0=sqrt_degree)
        # skip trivial largest eigenvector, reorder largest to smallest
        return np.ascontiguousarray(eigenvectors[:, [1, 0]])

    embedding = spectral_layout(
        PCs, graph, random_state=np.random.RandomState(seed))

    # Create the graph in COO format, equivalent to
    # `head = graph.tocoo().row` and `tail = graph.tocoo().col`
    head = np.repeat(np.arange(graph.shape[0], dtype=graph.indices.dtype),
                     repeats=np.diff(graph.indptr))
    tail = graph.indices

    # Optimize the embedding via stochastic gradient descent, in parallel
    # without locks if `hogwild=True`
    umap_optimize(
        embedding=embedding, head=head, tail=tail,
        weights=graph.data, num_iterations=num_iterations, a=a, b=b,
        gamma=gamma, initial_alpha=alpha,
        negative_sample_rate=negative_sample_rate, seed=seed,
        num_threads=num_threads)

    # If `QC_column` was specified, back-project from QCed cells to all
    # cells, filling with `NaN`
    if QC_column is not None:
        embedding_QCed = embedding
        embedding = np.full((len(self), embedding_QCed.shape[1]), np.nan,
                            dtype=np.float32)
        embedding[QC_column_NumPy] = embedding_QCed
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm | {embedding_key: embedding},
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
[docs]
def plot_pacmap(
        self,
        color_column: SingleCellColumn | None,
        filename: str | Path | None = None,
        /,
        *,
        cells_to_plot_column: SingleCellColumn | None = 'passed_QC',
        ax: 'Axes' | None = None,
        figure_kwargs: dict[str, Any] | None = None,
        point_size: int | float | np.integer | np.floating | str |
                    None = None,
        sort_by_frequency: bool = False,
        colormap: str | 'Colormap' | dict[Any, Color] = None,
        lightness_range: tuple[float | np.floating, float | np.floating] |
                         None = (100 / 3, 200 / 3),
        chroma_range: tuple[float | np.floating, float | np.floating] |
                      None = (50, 100),
        hue_range: tuple[float | np.floating, float | np.floating] |
                   None = None,
        first_color: Color = '#008cb9',
        stride: int | np.integer = 5,
        default_color: Color = 'lightgray',
        scatter_kwargs: dict[str, Any] | None = None,
        label: bool = False,
        label_kwargs: dict[str, Any] | None = None,
        legend: bool = True,
        legend_kwargs: dict[str, Any] | None = None,
        colorbar: bool = True,
        colorbar_kwargs: dict[str, Any] | None = None,
        title: str | None = None,
        title_kwargs: dict[str, Any] | None = None,
        xlabel: str | None = 'Component 1',
        xlabel_kwargs: dict[str, Any] | None = None,
        ylabel: str | None = 'Component 2',
        ylabel_kwargs: dict[str, Any] | None = None,
        xlim: tuple[int | float | np.integer | np.floating,
                    int | float | np.integer | np.floating] | None = None,
        ylim: tuple[int | float | np.integer | np.floating,
                    int | float | np.integer | np.floating] | None = None,
        despine: bool = True,
        savefig_kwargs: dict[str, Any] | None = None) -> None:
    """
    Plot a PaCMAP embedding previously computed with `pacmap()`.

    This is a convenience wrapper: every argument is forwarded, unchanged,
    to `plot_embedding('pacmap', ...)`.

    Args:
        color_column: an optional column of `obs` that determines the
            color of each cell. May be given as a column name, a polars
            expression, a polars Series, a 1D NumPy array, or a function
            that takes in this SingleCell dataset and returns a polars
            Series or 1D NumPy array. The column may be discrete (e.g.
            cell-type labels), i.e. String/Enum/Categorical, or
            quantitative (e.g. the number of UMIs per cell), i.e.
            integer/floating-point. Cells with missing (`null`) values
            are drawn in `default_color`. Pass `None` to draw every cell
            in `default_color`.
        filename: the file to save the plot to. When `None`, the plot is
            generated but not saved, so it can be shown interactively or
            modified further before saving.
        cells_to_plot_column: an optional Boolean column of `obs` marking
            which cells to plot. May be given as a column name, a polars
            expression, a polars Series, a 1D NumPy array, or a function
            that takes in this SingleCell dataset and returns a polars
            Series or 1D NumPy array. Set to `None` to plot all cells
            passing QC.
        ax: the Matplotlib axes to plot onto. When `None`, a new figure
            using Matplotlib's constrained layout is created and plotted
            onto.
        figure_kwargs: a dictionary of keyword arguments passed to
            `plt.figure` when `ax` is `None`, such as:
            - `figsize`: a two-element sequence giving the figure's
              width and height in inches. Defaults to `[8, 6]`, 25%
              larger than Matplotlib's default of `[6.4, 4.8]`.
            - `layout`: the mechanism Matplotlib uses to keep plot
              elements from overlapping. Defaults to `'constrained'`
              rather than Matplotlib's default of `None`.
        point_size: the size of the point drawn for each cell; defaults
            to 20,000 divided by the number of cells. Either a single
            number, or the name of a column of `obs` that gives each
            point its own size.
        sort_by_frequency: if `True`, colors are assigned and the legend
            is sorted in order of decreasing frequency; if `False` (the
            default),
            [natural sort order](en.wikipedia.org/wiki/Natural_sort_order)
            is used. Cannot be `True` unless `colormap` is `None` and
            `color_column` is discrete; when `colormap` is not `None`,
            the plot order follows the order of the keys in `colormap`.
        colormap: a string or Colormap object naming the Matplotlib
            colormap to use; or, when `color_column` is discrete, a
            dictionary mapping values in `color_column` to Matplotlib
            colors (cells whose value of `color_column` is absent from
            the dictionary are drawn in `default_color`). Defaults to
            `plt.rcParams['image.cmap']` (`'viridis'` by default) when
            `color_column` is continuous, or to the colors of a
            maximally perceptually distinct colormap when `color_column`
            is discrete (with colors assigned in decreasing order of
            frequency). Cannot be specified if `color_column` is `None`.
        lightness_range: a two-element tuple giving the lightness range
            of the colors to generate, or `None` for the full range:
            `[0, 100]`. Can only be specified when `color_column` is
            discrete and `colormap` is `None`.
        chroma_range: a two-element tuple giving the chroma range of the
            colors to generate, or `None` for the full range:
            `[0, 100]`. Grays have low chroma and vivid colors have high
            chroma. Can only be specified when `color_column` is
            discrete and `colormap` is `None`.
        hue_range: a two-element tuple giving the hue range of the colors
            to generate, or `None` for the full range: `[0, 360]`. Red
            is at 0°, green at 120°, and blue at 240°. Since hue wraps
            around, the first element may be greater than the second,
            unlike for `lightness_range` and `chroma_range`. Can only be
            specified when `color_column` is discrete and `colormap` is
            `None`.
        first_color: the first color of the palette. Any valid Matplotlib
            color is accepted: a hex string (e.g. `'#FF0000'`), a named
            color (e.g. 'red'), a 3- or 4-element RGB/RGBA tuple of
            integers 0-255 or floats 0-1, or a single float 0-1 for
            grayscale.
        stride: as an optimization, only RGB colors whose R, G, and B are
            all multiples of this value are considered. Must be a small
            divisor of 255: 1, 3, 5, 15, or 17. Set to 1 for the best
            possible solution, at orders of magnitude more computational
            cost.
        default_color: the color used for cells when `color_column` is
            `None`, for cells with missing (`null`) values of
            `color_column`, and, when `colormap` is a dictionary, for
            cells whose value of `color_column` is not in the
            dictionary. Any valid Matplotlib color is accepted: a hex
            string (e.g. `'#FF0000'`), a named color (e.g. 'red'), a 3-
            or 4-element RGB/RGBA tuple of integers 0-255 or floats 0-1,
            or a single float 0-1 for grayscale.
        scatter_kwargs: a dictionary of keyword arguments passed to
            `ax.scatter()`, such as:
            - `rasterized`: whether to convert the scatter plot points
              to a raster (bitmap) image when saving to a vector format
              like PDF. Defaults to `True` rather than Matplotlib's
              default of `False`.
            - `marker`: the shape used to plot each cell
            - `norm`, `vmin`, and `vmax`: control how the `colormap`
              maps the numbers in `color_column` to colors, if
              `color_column` is numeric
            - `alpha`: the transparency of each point
            - `linewidths` and `edgecolors`: the width and color of the
              borders around each marker. These are absent by default
              (`linewidths=0`, `edgecolors=(0, 0, 0, 0)`), unlike
              Matplotlib's default. Both arguments can be either single
              values or sequences.
            - `zorder`: the order in which the cells are plotted, with
              higher values appearing on top of lower ones.
            Specifying `s`, `c`/`color`, or `cmap` will raise an error,
            since these arguments conflict with the `point_size`,
            `color_column`, and `colormap` arguments, respectively.
        label: whether to label cells with each distinct value of
            `color_column`. Each label is placed at the median x and y
            position of the points with that color. Can only be `True`
            when `color_column` is discrete. When set to `True`, you may
            also want to set `legend=False` to avoid redundancy.
        label_kwargs: a dictionary of keyword arguments passed to
            `ax.text()` when adding labels, to control the text
            properties, such as:
            - `color` and `size` to modify the text color/size
            - `verticalalignment` and `horizontalalignment` to control
              vertical and horizontal alignment. By default, unlike
              Matplotlib, these are both set to `'center'`.
            - `path_effects` to set properties for the border around the
              text. By default, set to
              `matplotlib.patheffects.withStroke(linewidth=3, foreground='white', alpha=0.75)`
              instead of Matplotlib's default of `None`, to put a
              semi-transparent white border around the labels for better
              contrast.
            Can only be specified when `label=True`.
        legend: whether to add a legend for each value in `color_column`.
            Ignored unless `color_column` is discrete.
        legend_kwargs: a dictionary of keyword arguments passed to
            `ax.legend()` to modify the legend, such as:
            - `loc`, `bbox_to_anchor`, and `bbox_transform` to set its
              location. By default, `loc` is set to `'center left'` and
              `bbox_to_anchor` to `(1, 0.5)` to put the legend to the
              right of the plot, anchored at the middle.
            - `ncols` to set its number of columns. By default, set to
              `ceil(obs[color_column].n_unique() / 24)` to have at most
              24 items per column.
            - `prop`, `fontsize`, and `labelcolor` to set its font
              properties
            - `facecolor` and `framealpha` to set its background color
              and transparency
            - `frameon=True` or `edgecolor` to add or color its border.
              `frameon` defaults to `False`, instead of Matplotlib's
              default of `True`.
            - `title` to add a legend title
            Can only be specified when `color_column` is discrete and
            `legend=True`.
        colorbar: whether to add a colorbar. Ignored unless
            `color_column` is quantitative.
        colorbar_kwargs: a dictionary of keyword arguments passed to
            `plt.colorbar()`, such as:
            - `location`: `'left'`, `'right'`, `'top'`, or `'bottom'`
            - `orientation`: `'vertical'` or `'horizontal'`
            - `fraction`: the fraction of the axes to allocate to the
              colorbar. Defaults to 0.15.
            - `shrink`: the fraction to multiply the size of the
              colorbar by. Defaults to 0.5, instead of Matplotlib's
              default of 1.
            - `aspect`: the ratio of the colorbar's long to short
              dimensions. Defaults to 20.
            - `pad`: the fraction of the axes between the colorbar and
              the rest of the figure. Defaults to 0.01, instead of
              Matplotlib's default of 0.05 if vertical and 0.15 if
              horizontal.
            Can only be specified when `color_column` is quantitative
            and `colorbar=True`.
        title: the title of the plot, or `None` to not add a title
        title_kwargs: a dictionary of keyword arguments passed to
            `ax.set_title()` to control text properties, such as `color`
            and `size`. Can only be specified when `title` is not
            `None`.
        xlabel: the x-axis label, or `None` to not label the x-axis
        xlabel_kwargs: a dictionary of keyword arguments passed to
            `ax.set_xlabel()` to control text properties, such as
            `color` and `size`. Can only be specified when `xlabel` is
            not `None`.
        ylabel: the y-axis label, or `None` to not label the y-axis
        ylabel_kwargs: a dictionary of keyword arguments passed to
            `ax.set_ylabel()` to control text properties, such as
            `color` and `size`. Can only be specified when `ylabel` is
            not `None`.
        xlim: a length-2 tuple of the left and right x-axis limits, or
            `None` to set the limits based on the data
        ylim: a length-2 tuple of the bottom and top y-axis limits, or
            `None` to set the limits based on the data
        despine: whether to remove the top and right spines (borders of
            the plot area) from the plot
        savefig_kwargs: a dictionary of keyword arguments passed to
            `plt.savefig()`, such as:
            - `dpi`: defaults to 300 instead of Matplotlib's default of
              150
            - `bbox_inches`: the bounding box of the portion of the
              figure to save; defaults to `'tight'` (crop out any blank
              borders) instead of Matplotlib's default of `None` (save
              the entire figure)
            - `pad_inches`: the number of inches of padding to add on
              each of the four sides of the figure when saving.
              Defaults to `'layout'` (use the padding from the
              constrained layout engine, when `ax` is not `None`)
              instead of Matplotlib's default of 0.1.
            - `transparent`: whether to save with a transparent
              background; defaults to `True` if saving to a PDF (i.e.
              when `filename` ends with `'.pdf'`) and `False` otherwise,
              instead of Matplotlib's default of always being `False`.
            Can only be specified when `filename` is specified.
    """
    # Collect the keyword-only options once, then hand them to the generic
    # embedding plotter together with the fixed embedding name
    plot_options = {
        'cells_to_plot_column': cells_to_plot_column,
        'ax': ax,
        'figure_kwargs': figure_kwargs,
        'point_size': point_size,
        'sort_by_frequency': sort_by_frequency,
        'colormap': colormap,
        'lightness_range': lightness_range,
        'chroma_range': chroma_range,
        'hue_range': hue_range,
        'first_color': first_color,
        'stride': stride,
        'default_color': default_color,
        'scatter_kwargs': scatter_kwargs,
        'label': label,
        'label_kwargs': label_kwargs,
        'legend': legend,
        'legend_kwargs': legend_kwargs,
        'colorbar': colorbar,
        'colorbar_kwargs': colorbar_kwargs,
        'title': title,
        'title_kwargs': title_kwargs,
        'xlabel': xlabel,
        'xlabel_kwargs': xlabel_kwargs,
        'ylabel': ylabel,
        'ylabel_kwargs': ylabel_kwargs,
        'xlim': xlim,
        'ylim': ylim,
        'despine': despine,
        'savefig_kwargs': savefig_kwargs,
    }
    self.plot_embedding('pacmap', color_column, filename, **plot_options)
[docs]
def plot_localmap(
        self,
        color_column: SingleCellColumn | None,
        filename: str | Path | None = None,
        /,
        *,
        cells_to_plot_column: SingleCellColumn | None = 'passed_QC',
        ax: 'Axes' | None = None,
        figure_kwargs: dict[str, Any] | None = None,
        point_size: int | float | np.integer | np.floating | str |
                    None = None,
        sort_by_frequency: bool = False,
        colormap: str | 'Colormap' | dict[Any, Color] = None,
        lightness_range: tuple[float | np.floating, float | np.floating] |
                         None = (100 / 3, 200 / 3),
        chroma_range: tuple[float | np.floating, float | np.floating] |
                      None = (50, 100),
        hue_range: tuple[float | np.floating, float | np.floating] |
                   None = None,
        first_color: Color = '#008cb9',
        stride: int | np.integer = 5,
        default_color: Color = 'lightgray',
        scatter_kwargs: dict[str, Any] | None = None,
        label: bool = False,
        label_kwargs: dict[str, Any] | None = None,
        legend: bool = True,
        legend_kwargs: dict[str, Any] | None = None,
        colorbar: bool = True,
        colorbar_kwargs: dict[str, Any] | None = None,
        title: str | None = None,
        title_kwargs: dict[str, Any] | None = None,
        xlabel: str | None = 'Component 1',
        xlabel_kwargs: dict[str, Any] | None = None,
        ylabel: str | None = 'Component 2',
        ylabel_kwargs: dict[str, Any] | None = None,
        xlim: tuple[int | float | np.integer | np.floating,
                    int | float | np.integer | np.floating] | None = None,
        ylim: tuple[int | float | np.integer | np.floating,
                    int | float | np.integer | np.floating] | None = None,
        despine: bool = True,
        savefig_kwargs: dict[str, Any] | None = None) -> None:
    """
    Plot a LocalMAP embedding previously computed with `localmap()`.

    This is a convenience wrapper: every argument is forwarded, unchanged,
    to `plot_embedding('localmap', ...)`.

    Args:
        color_column: an optional column of `obs` that determines the
            color of each cell. May be given as a column name, a polars
            expression, a polars Series, a 1D NumPy array, or a function
            that takes in this SingleCell dataset and returns a polars
            Series or 1D NumPy array. The column may be discrete (e.g.
            cell-type labels), i.e. String/Enum/Categorical, or
            quantitative (e.g. the number of UMIs per cell), i.e.
            integer/floating-point. Cells with missing (`null`) values
            are drawn in `default_color`. Pass `None` to draw every cell
            in `default_color`.
        filename: the file to save the plot to. When `None`, the plot is
            generated but not saved, so it can be shown interactively or
            modified further before saving.
        cells_to_plot_column: an optional Boolean column of `obs` marking
            which cells to plot. May be given as a column name, a polars
            expression, a polars Series, a 1D NumPy array, or a function
            that takes in this SingleCell dataset and returns a polars
            Series or 1D NumPy array. Set to `None` to plot all cells
            passing QC.
        ax: the Matplotlib axes to plot onto. When `None`, a new figure
            using Matplotlib's constrained layout is created and plotted
            onto.
        figure_kwargs: a dictionary of keyword arguments passed to
            `plt.figure` when `ax` is `None`, such as:
            - `figsize`: a two-element sequence giving the figure's
              width and height in inches. Defaults to `[8, 6]`, 25%
              larger than Matplotlib's default of `[6.4, 4.8]`.
            - `layout`: the mechanism Matplotlib uses to keep plot
              elements from overlapping. Defaults to `'constrained'`
              rather than Matplotlib's default of `None`.
        point_size: the size of the point drawn for each cell; defaults
            to 20,000 divided by the number of cells. Either a single
            number, or the name of a column of `obs` that gives each
            point its own size.
        sort_by_frequency: if `True`, colors are assigned and the legend
            is sorted in order of decreasing frequency; if `False` (the
            default),
            [natural sort order](en.wikipedia.org/wiki/Natural_sort_order)
            is used. Cannot be `True` unless `colormap` is `None` and
            `color_column` is discrete; when `colormap` is not `None`,
            the plot order follows the order of the keys in `colormap`.
        colormap: a string or Colormap object naming the Matplotlib
            colormap to use; or, when `color_column` is discrete, a
            dictionary mapping values in `color_column` to Matplotlib
            colors (cells whose value of `color_column` is absent from
            the dictionary are drawn in `default_color`). Defaults to
            `plt.rcParams['image.cmap']` (`'viridis'` by default) when
            `color_column` is continuous, or to the colors of a
            maximally perceptually distinct colormap when `color_column`
            is discrete (with colors assigned in decreasing order of
            frequency). Cannot be specified if `color_column` is `None`.
        lightness_range: a two-element tuple giving the lightness range
            of the colors to generate, or `None` for the full range:
            `[0, 100]`. Can only be specified when `color_column` is
            discrete and `colormap` is `None`.
        chroma_range: a two-element tuple giving the chroma range of the
            colors to generate, or `None` for the full range:
            `[0, 100]`. Grays have low chroma and vivid colors have high
            chroma. Can only be specified when `color_column` is
            discrete and `colormap` is `None`.
        hue_range: a two-element tuple giving the hue range of the colors
            to generate, or `None` for the full range: `[0, 360]`. Red
            is at 0°, green at 120°, and blue at 240°. Since hue wraps
            around, the first element may be greater than the second,
            unlike for `lightness_range` and `chroma_range`. Can only be
            specified when `color_column` is discrete and `colormap` is
            `None`.
        first_color: the first color of the palette. Any valid Matplotlib
            color is accepted: a hex string (e.g. `'#FF0000'`), a named
            color (e.g. 'red'), a 3- or 4-element RGB/RGBA tuple of
            integers 0-255 or floats 0-1, or a single float 0-1 for
            grayscale.
        stride: as an optimization, only RGB colors whose R, G, and B are
            all multiples of this value are considered. Must be a small
            divisor of 255: 1, 3, 5, 15, or 17. Set to 1 for the best
            possible solution, at orders of magnitude more computational
            cost.
        default_color: the color used for cells when `color_column` is
            `None`, for cells with missing (`null`) values of
            `color_column`, and, when `colormap` is a dictionary, for
            cells whose value of `color_column` is not in the
            dictionary. Any valid Matplotlib color is accepted: a hex
            string (e.g. `'#FF0000'`), a named color (e.g. 'red'), a 3-
            or 4-element RGB/RGBA tuple of integers 0-255 or floats 0-1,
            or a single float 0-1 for grayscale.
        scatter_kwargs: a dictionary of keyword arguments passed to
            `ax.scatter()`, such as:
            - `rasterized`: whether to convert the scatter plot points
              to a raster (bitmap) image when saving to a vector format
              like PDF. Defaults to `True` rather than Matplotlib's
              default of `False`.
            - `marker`: the shape used to plot each cell
            - `norm`, `vmin`, and `vmax`: control how the `colormap`
              maps the numbers in `color_column` to colors, if
              `color_column` is numeric
            - `alpha`: the transparency of each point
            - `linewidths` and `edgecolors`: the width and color of the
              borders around each marker. These are absent by default
              (`linewidths=0`, `edgecolors=(0, 0, 0, 0)`), unlike
              Matplotlib's default. Both arguments can be either single
              values or sequences.
            - `zorder`: the order in which the cells are plotted, with
              higher values appearing on top of lower ones.
            Specifying `s`, `c`/`color`, or `cmap` will raise an error,
            since these arguments conflict with the `point_size`,
            `color_column`, and `colormap` arguments, respectively.
        label: whether to label cells with each distinct value of
            `color_column`. Each label is placed at the median x and y
            position of the points with that color. Can only be `True`
            when `color_column` is discrete. When set to `True`, you may
            also want to set `legend=False` to avoid redundancy.
        label_kwargs: a dictionary of keyword arguments passed to
            `ax.text()` when adding labels, to control the text
            properties, such as:
            - `color` and `size` to modify the text color/size
            - `verticalalignment` and `horizontalalignment` to control
              vertical and horizontal alignment. By default, unlike
              Matplotlib, these are both set to `'center'`.
            - `path_effects` to set properties for the border around the
              text. By default, set to
              `matplotlib.patheffects.withStroke(linewidth=3, foreground='white', alpha=0.75)`
              instead of Matplotlib's default of `None`, to put a
              semi-transparent white border around the labels for better
              contrast.
            Can only be specified when `label=True`.
        legend: whether to add a legend for each value in `color_column`.
            Ignored unless `color_column` is discrete.
        legend_kwargs: a dictionary of keyword arguments passed to
            `ax.legend()` to modify the legend, such as:
            - `loc`, `bbox_to_anchor`, and `bbox_transform` to set its
              location. By default, `loc` is set to `'center left'` and
              `bbox_to_anchor` to `(1, 0.5)` to put the legend to the
              right of the plot, anchored at the middle.
            - `ncols` to set its number of columns. By default, set to
              `ceil(obs[color_column].n_unique() / 24)` to have at most
              24 items per column.
            - `prop`, `fontsize`, and `labelcolor` to set its font
              properties
            - `facecolor` and `framealpha` to set its background color
              and transparency
            - `frameon=True` or `edgecolor` to add or color its border.
              `frameon` defaults to `False`, instead of Matplotlib's
              default of `True`.
            - `title` to add a legend title
            Can only be specified when `color_column` is discrete and
            `legend=True`.
        colorbar: whether to add a colorbar. Ignored unless
            `color_column` is quantitative.
        colorbar_kwargs: a dictionary of keyword arguments passed to
            `plt.colorbar()`, such as:
            - `location`: `'left'`, `'right'`, `'top'`, or `'bottom'`
            - `orientation`: `'vertical'` or `'horizontal'`
            - `fraction`: the fraction of the axes to allocate to the
              colorbar. Defaults to 0.15.
            - `shrink`: the fraction to multiply the size of the
              colorbar by. Defaults to 0.5, instead of Matplotlib's
              default of 1.
            - `aspect`: the ratio of the colorbar's long to short
              dimensions. Defaults to 20.
            - `pad`: the fraction of the axes between the colorbar and
              the rest of the figure. Defaults to 0.01, instead of
              Matplotlib's default of 0.05 if vertical and 0.15 if
              horizontal.
            Can only be specified when `color_column` is quantitative
            and `colorbar=True`.
        title: the title of the plot, or `None` to not add a title
        title_kwargs: a dictionary of keyword arguments passed to
            `ax.set_title()` to control text properties, such as `color`
            and `size`. Can only be specified when `title` is not
            `None`.
        xlabel: the x-axis label, or `None` to not label the x-axis
        xlabel_kwargs: a dictionary of keyword arguments passed to
            `ax.set_xlabel()` to control text properties, such as
            `color` and `size`. Can only be specified when `xlabel` is
            not `None`.
        ylabel: the y-axis label, or `None` to not label the y-axis
        ylabel_kwargs: a dictionary of keyword arguments passed to
            `ax.set_ylabel()` to control text properties, such as
            `color` and `size`. Can only be specified when `ylabel` is
            not `None`.
        xlim: a length-2 tuple of the left and right x-axis limits, or
            `None` to set the limits based on the data
        ylim: a length-2 tuple of the bottom and top y-axis limits, or
            `None` to set the limits based on the data
        despine: whether to remove the top and right spines (borders of
            the plot area) from the plot
        savefig_kwargs: a dictionary of keyword arguments passed to
            `plt.savefig()`, such as:
            - `dpi`: defaults to 300 instead of Matplotlib's default of
              150
            - `bbox_inches`: the bounding box of the portion of the
              figure to save; defaults to `'tight'` (crop out any blank
              borders) instead of Matplotlib's default of `None` (save
              the entire figure)
            - `pad_inches`: the number of inches of padding to add on
              each of the four sides of the figure when saving.
              Defaults to `'layout'` (use the padding from the
              constrained layout engine, when `ax` is not `None`)
              instead of Matplotlib's default of 0.1.
            - `transparent`: whether to save with a transparent
              background; defaults to `True` if saving to a PDF (i.e.
              when `filename` ends with `'.pdf'`) and `False` otherwise,
              instead of Matplotlib's default of always being `False`.
            Can only be specified when `filename` is specified.
    """
    # Collect the keyword-only options once, then hand them to the generic
    # embedding plotter together with the fixed embedding name
    plot_options = {
        'cells_to_plot_column': cells_to_plot_column,
        'ax': ax,
        'figure_kwargs': figure_kwargs,
        'point_size': point_size,
        'sort_by_frequency': sort_by_frequency,
        'colormap': colormap,
        'lightness_range': lightness_range,
        'chroma_range': chroma_range,
        'hue_range': hue_range,
        'first_color': first_color,
        'stride': stride,
        'default_color': default_color,
        'scatter_kwargs': scatter_kwargs,
        'label': label,
        'label_kwargs': label_kwargs,
        'legend': legend,
        'legend_kwargs': legend_kwargs,
        'colorbar': colorbar,
        'colorbar_kwargs': colorbar_kwargs,
        'title': title,
        'title_kwargs': title_kwargs,
        'xlabel': xlabel,
        'xlabel_kwargs': xlabel_kwargs,
        'ylabel': ylabel,
        'ylabel_kwargs': ylabel_kwargs,
        'xlim': xlim,
        'ylim': ylim,
        'despine': despine,
        'savefig_kwargs': savefig_kwargs,
    }
    self.plot_embedding('localmap', color_column, filename, **plot_options)
[docs]
def plot_umap(
        self,
        color_column: SingleCellColumn | None,
        filename: str | Path | None = None,
        /,
        *,
        cells_to_plot_column: SingleCellColumn | None = 'passed_QC',
        ax: 'Axes' | None = None,
        figure_kwargs: dict[str, Any] | None = None,
        point_size: int | float | np.integer | np.floating | str |
                    None = None,
        sort_by_frequency: bool = False,
        colormap: str | 'Colormap' | dict[Any, Color] | None = None,
        lightness_range: tuple[float | np.floating, float | np.floating] |
                         None = (100 / 3, 200 / 3),
        chroma_range: tuple[float | np.floating, float | np.floating] |
                      None = (50, 100),
        hue_range: tuple[float | np.floating, float | np.floating] |
                   None = None,
        first_color: Color = '#008cb9',
        stride: int | np.integer = 5,
        default_color: Color = 'lightgray',
        scatter_kwargs: dict[str, Any] | None = None,
        label: bool = False,
        label_kwargs: dict[str, Any] | None = None,
        legend: bool = True,
        legend_kwargs: dict[str, Any] | None = None,
        colorbar: bool = True,
        colorbar_kwargs: dict[str, Any] | None = None,
        title: str | None = None,
        title_kwargs: dict[str, Any] | None = None,
        xlabel: str | None = 'Component 1',
        xlabel_kwargs: dict[str, Any] | None = None,
        ylabel: str | None = 'Component 2',
        ylabel_kwargs: dict[str, Any] | None = None,
        xlim: tuple[int | float | np.integer | np.floating,
                    int | float | np.integer | np.floating] | None = None,
        ylim: tuple[int | float | np.integer | np.floating,
                    int | float | np.integer | np.floating] | None = None,
        despine: bool = True,
        savefig_kwargs: dict[str, Any] | None = None) -> None:
    """
    Plot a UMAP embedding created with `umap()`.

    Syntactic sugar for `plot_embedding('umap', ...)`: every argument is
    forwarded unchanged, so all validation and plotting happen there.

    Args:
        color_column: an optional column of `obs` indicating how to color
                      each cell in the plot. Can be a column name, a polars
                      expression, a polars Series, a 1D NumPy array, or a
                      function that takes in this SingleCell dataset and
                      returns a polars Series or 1D NumPy array. Can be
                      discrete (e.g. cell-type labels), specified as a
                      String/Enum/Categorical column, or quantitative (e.g.
                      the number of UMIs per cell), specified as an
                      integer/floating-point column. Missing (`null`) cells
                      will be plotted with the color `default_color`. Set
                      to `None` to use `default_color` for all cells.
        filename: the file to save to. If `None`, generate the plot but do
                  not save it, which allows it to be shown interactively or
                  modified further before saving.
        cells_to_plot_column: an optional Boolean column of `obs`
                              indicating which cells to plot. Can be a
                              column name, a polars expression, a polars
                              Series, a 1D NumPy array, or a function that
                              takes in this SingleCell dataset and returns
                              a polars Series or 1D NumPy array. Defaults
                              to `'passed_QC'`; set to `None` to plot all
                              cells, without any subsetting.
        ax: the Matplotlib axes to save the plot onto; if `None`, create a
            new figure with Matplotlib's constrained layout and plot onto
            it
        figure_kwargs: a dictionary of keyword arguments to be passed to
                       `plt.figure` when `ax` is `None`, such as:
                       - `figsize`: a two-element sequence of the width and
                         height of the figure in inches. Defaults to
                         `[8, 6]`, 25% larger than Matplotlib's default of
                         `[6.4, 4.8]`.
                       - `layout`: the layout mechanism used by Matplotlib
                         to avoid overlapping plot elements. Defaults to
                         `'constrained'`, instead of Matplotlib's default
                         of `None`.
                       Must be `None` when `ax` is not `None`.
        point_size: the size of the points for each cell; defaults to
                    20,000 divided by the number of cells. Can be a single
                    number, or the name of a column of `obs` to make each
                    point a different size.
        sort_by_frequency: if `True`, assign colors and sort the legend in
                           order of decreasing frequency; if `False` (the
                           default), use
                           [natural sort order](en.wikipedia.org/wiki/Natural_sort_order).
                           Cannot be `True` unless `colormap` is `None` and
                           `color_column` is discrete; if `colormap` is
                           not `None`, the plot order is determined by the
                           order of the keys in `colormap`.
        colormap: a string or Colormap object indicating the Matplotlib
                  colormap to use; or, if `color_column` is discrete, a
                  dictionary mapping values in `color_column` to Matplotlib
                  colors (cells with values of `color_column` that are not
                  in the dictionary will be plotted in the color
                  `default_color`). Defaults to
                  `plt.rcParams['image.cmap']` (`'viridis'` by default) if
                  `color_column` is continuous, or the colors from a
                  maximally perceptually distinct colormap if
                  `color_column` is discrete (with colors assigned in
                  natural sort order, or in decreasing order of frequency
                  if `sort_by_frequency=True`). Cannot be specified if
                  `color_column` is `None`.
        lightness_range: a two-element tuple with the lightness range of
                         colors to generate, or `None` to take the full
                         range: `[0, 100]`. Can only be specified when
                         `color_column` is discrete and `colormap` is
                         `None`.
        chroma_range: a two-element tuple with the chroma range of colors
                      to generate, or `None` to take the full range:
                      `[0, 100]`. Grays have low chroma, and vivid colors
                      have high chroma. Can only be specified when
                      `color_column` is discrete and `colormap` is `None`.
        hue_range: a two-element tuple with the hue range of colors to
                   generate, or `None` to take the full range: `[0, 360]`.
                   Red is at 0°, green at 120°, and blue at 240°. Because
                   it wraps around, the first element of the tuple can be
                   greater than the second, unlike for `lightness_range`
                   and `chroma_range`. Can only be specified when
                   `color_column` is discrete and `colormap` is `None`.
        first_color: the first color of the palette. Can be any valid
                     Matplotlib color, like a hex string (e.g.
                     `'#FF0000'`), a named color (e.g. 'red'), a 3- or
                     4-element RGB/RGBA tuple of integers 0-255 or floats
                     0-1, or a single float 0-1 for grayscale.
        stride: as an optimization, consider only RGB colors where R, G,
                and B are all multiples of this value. Must be a small
                divisor of 255: 1, 3, 5, 15, or 17. Set to 1 for the best
                possible solution, at orders of magnitude more
                computational cost.
        default_color: the default color to plot cells in when
                       `color_column` is `None`, or when certain cells have
                       missing (`null`) values for `color_column`, or when
                       `colormap` is a dictionary and some cells have
                       values of `color_column` that are not in the
                       dictionary. Can be any valid Matplotlib color, like
                       a hex string (e.g. `'#FF0000'`), a named color (e.g.
                       'red'), a 3- or 4-element RGB/RGBA tuple of integers
                       0-255 or floats 0-1, or a single float 0-1 for
                       grayscale.
        scatter_kwargs: a dictionary of keyword arguments to be passed to
                        `ax.scatter()`, such as:
                        - `rasterized`: whether to convert the scatter plot
                          points to a raster (bitmap) image when saving to
                          a vector format like PDF. Defaults to `True`,
                          instead of Matplotlib's default of `False`.
                        - `marker`: the shape to use for plotting each cell
                        - `norm`, `vmin`, and `vmax`: control how the
                          `colormap` maps the numbers in `color_column` to
                          colors, if `color_column` is numeric
                        - `alpha`: the transparency of each point
                        - `linewidths` and `edgecolors`: the width and
                          color of the borders around each marker. These
                          are absent by default (`linewidths=0`,
                          `edgecolors=(0, 0, 0, 0)`), unlike Matplotlib's
                          default. Both arguments can be either single
                          values or sequences.
                        - `zorder`: the order in which the cells are
                          plotted, with higher values appearing on top of
                          lower ones.
                        Specifying `s`, `c`/`color`, or `cmap` will raise
                        an error, since these arguments conflict with the
                        `point_size`, `color_column`, and `colormap`
                        arguments, respectively.
        label: whether to label cells with each distinct value of
               `color_column`. Labels will be placed at the median x and y
               position of the points with that color. Can only be `True`
               when `color_column` is discrete. When set to `True`, you may
               also want to set `legend=False` to avoid redundancy.
        label_kwargs: a dictionary of keyword arguments to be passed to
                      `ax.text()` when adding labels to control the text
                      properties, such as:
                      - `color` and `size` to modify the text color/size
                      - `verticalalignment` and `horizontalalignment` to
                        control vertical and horizontal alignment. By
                        default, unlike Matplotlib, these are both set to
                        `'center'`.
                      - `path_effects` to set properties for the border
                        around the text. By default, set to
                        `matplotlib.patheffects.withStroke(linewidth=3, foreground='white', alpha=0.75)`
                        instead of Matplotlib's default of `None`, to put a
                        semi-transparent white border around the labels for
                        better contrast.
                      Can only be specified when `label=True`.
        legend: whether to add a legend for each value in `color_column`.
                Ignored unless `color_column` is discrete.
        legend_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.legend()` to modify the legend, such as:
                       - `loc`, `bbox_to_anchor`, and `bbox_transform` to
                         set its location. By default, `loc` is set to
                         `'center left'` and `bbox_to_anchor` to `(1, 0.5)`
                         to put the legend to the right of the plot,
                         anchored at the middle.
                       - `ncols` to set its number of columns. By
                         default, set to
                         `ceil(obs[color_column].n_unique() / 24)` to have
                         at most 24 items per column.
                       - `prop`, `fontsize`, and `labelcolor` to set its
                         font properties
                       - `facecolor` and `framealpha` to set its background
                         color and transparency
                       - `frameon=True` or `edgecolor` to add or color its
                         border. `frameon` defaults to `False`, instead of
                         Matplotlib's default of `True`.
                       - `title` to add a legend title
                       Can only be specified when `color_column` is
                       discrete and `legend=True`.
        colorbar: whether to add a colorbar. Ignored unless `color_column`
                  is quantitative.
        colorbar_kwargs: a dictionary of keyword arguments to be passed to
                         `plt.colorbar()`, such as:
                         - `location`: `'left'`, `'right'`, `'top'`, or
                           `'bottom'`
                         - `orientation`: `'vertical'` or `'horizontal'`
                         - `fraction`: the fraction of the axes to
                           allocate to the colorbar. Defaults to 0.15.
                         - `shrink`: the fraction to multiply the size of
                           the colorbar by. Defaults to 0.5, instead of
                           Matplotlib's default of 1.
                         - `aspect`: the ratio of the colorbar's long to
                           short dimensions. Defaults to 20.
                         - `pad`: the fraction of the axes between the
                           colorbar and the rest of the figure. Defaults to
                           0.01, instead of Matplotlib's default of 0.05 if
                           vertical and 0.15 if horizontal.
                         Can only be specified when `color_column` is
                         quantitative and `colorbar=True`.
        title: the title of the plot, or `None` to not add a title
        title_kwargs: a dictionary of keyword arguments to be passed to
                      `ax.set_title()` to control text properties, such as
                      `color` and `size`. Can only be specified when
                      `title` is not `None`.
        xlabel: the x-axis label, or `None` to not label the x-axis
        xlabel_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.set_xlabel()` to control text properties, such
                       as `color` and `size`. Can only be specified when
                       `xlabel` is not `None`.
        ylabel: the y-axis label, or `None` to not label the y-axis
        ylabel_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.set_ylabel()` to control text properties, such
                       as `color` and `size`. Can only be specified when
                       `ylabel` is not `None`.
        xlim: a length-2 tuple of the left and right x-axis limits, or
              `None` to set the limits based on the data
        ylim: a length-2 tuple of the bottom and top y-axis limits, or
              `None` to set the limits based on the data
        despine: whether to remove the top and right spines (borders of the
                 plot area) from the plot
        savefig_kwargs: a dictionary of keyword arguments to be passed to
                        `plt.savefig()`, such as:
                        - `dpi`: defaults to 300 instead of Matplotlib's
                          default of 150
                        - `bbox_inches`: the bounding box of the portion of
                          the figure to save; defaults to `'tight'` (crop
                          out any blank borders) instead of Matplotlib's
                          default of `None` (save the entire figure)
                        - `pad_inches`: the number of inches of padding to
                          add on each of the four sides of the figure when
                          saving. Defaults to `'layout'` (use the padding
                          from the constrained layout engine, when `ax` is
                          not `None`) instead of Matplotlib's default of
                          0.1.
                        - `transparent`: whether to save with a transparent
                          background; defaults to `True` if saving to a PDF
                          (i.e. when `filename` ends with `'.pdf'`) and
                          `False` otherwise, instead of Matplotlib's
                          default of always being `False`.
                        Can only be specified when `filename` is specified.

    Raises:
        ValueError: propagated from `plot_embedding()`, e.g. when `obsm`
                    has no `'umap'` key because `umap()` has not been run
                    yet, or when mutually exclusive arguments are combined.
        TypeError: propagated from `plot_embedding()` when an argument has
                   the wrong type.
    """
    # NOTE(review): the non-`None` defaults of `first_color` and `stride`
    # are forwarded as-is, while `plot_embedding()` raises `ValueError` for
    # any non-`None` `first_color`/`stride` whenever `colormap` is given or
    # `color_column` is `None`/continuous - confirm whether those checks
    # should compare against the default values instead.
    self.plot_embedding('umap', color_column, filename,
                        cells_to_plot_column=cells_to_plot_column, ax=ax,
                        figure_kwargs=figure_kwargs, point_size=point_size,
                        sort_by_frequency=sort_by_frequency,
                        colormap=colormap, lightness_range=lightness_range,
                        chroma_range=chroma_range, hue_range=hue_range,
                        first_color=first_color, stride=stride,
                        default_color=default_color,
                        scatter_kwargs=scatter_kwargs, label=label,
                        label_kwargs=label_kwargs, legend=legend,
                        legend_kwargs=legend_kwargs, colorbar=colorbar,
                        colorbar_kwargs=colorbar_kwargs, title=title,
                        title_kwargs=title_kwargs, xlabel=xlabel,
                        xlabel_kwargs=xlabel_kwargs, ylabel=ylabel,
                        ylabel_kwargs=ylabel_kwargs, xlim=xlim, ylim=ylim,
                        despine=despine, savefig_kwargs=savefig_kwargs)
[docs]
def plot_embedding(
self,
embedding_key: str,
color_column: SingleCellColumn | None,
filename: str | Path | None = None,
/,
*,
cells_to_plot_column: SingleCellColumn | None = 'passed_QC',
ax: 'Axes' | None = None,
figure_kwargs: dict[str, Any] | None = None,
point_size: int | float | np.integer | np.floating | str |
None = None,
sort_by_frequency: bool = False,
colormap: str | 'Colormap' | dict[Any, Color] = None,
lightness_range: tuple[float | np.floating, float | np.floating] |
None = (100 / 3, 200 / 3),
chroma_range: tuple[float | np.floating, float | np.floating] |
None = (50, 100),
hue_range: tuple[float | np.floating, float | np.floating] |
None = None,
first_color: Color = '#008cb9',
stride: int | np.integer = 5,
default_color: Color = 'lightgray',
scatter_kwargs: dict[str, Any] | None = None,
label: bool = False,
label_kwargs: dict[str, Any] | None = None,
legend: bool = True,
legend_kwargs: dict[str, Any] | None = None,
colorbar: bool = True,
colorbar_kwargs: dict[str, Any] | None = None,
title: str | None = None,
title_kwargs: dict[str, Any] | None = None,
xlabel: str | None = 'Component 1',
xlabel_kwargs: dict[str, Any] | None = None,
ylabel: str | None = 'Component 2',
ylabel_kwargs: dict[str, Any] | None = None,
xlim: tuple[int | float | np.integer | np.floating,
int | float | np.integer | np.floating] | None = None,
ylim: tuple[int | float | np.integer | np.floating,
int | float | np.integer | np.floating] | None = None,
despine: bool = True,
savefig_kwargs: dict[str, Any] | None = None) -> None:
"""
Plot the specified 2D embedding.
Args:
embedding_key: the key of `obsm` containing the embedding to plot
color_column: an optional column of `obs` indicating how to color
each cell in the plot. Can be a column name, a polars
expression, a polars Series, a 1D NumPy array, or a
function that takes in this SingleCell dataset and
returns a polars Series or 1D NumPy array. Can be
discrete (e.g. cell-type labels), specified as a
String/Enum/Categorical column, or quantitative (e.g.
the number of UMIs per cell), specified as an
integer/floating-point column. Missing (`null`) cells
will be plotted with the color `default_color`. Set
to `None` to use `default_color` for all cells.
filename: the file to save to. If `None`, generate the plot but do
not save it, which allows it to be shown interactively or
modified further before saving.
cells_to_plot_column: an optional Boolean column of `obs`
indicating which cells to plot. Can be a
column name, a polars expression, a polars
Series, a 1D NumPy array, or a function that
takes in this SingleCell dataset and returns
a polars Series or 1D NumPy array. Set to
`None` to plot all cells passing QC.
ax: the Matplotlib axes to save the plot onto; if `None`, create a
new figure with Matpotlib's constrained layout and plot onto it
figure_kwargs: a dictionary of keyword arguments to be passed to
`plt.figure` when `ax` is `None`, such as:
- `figsize`: a two-element sequence of the width and
height of the figure in inches. Defaults to
`[8, 6]`, 25% larger than Matplotlib's default of
`[6.4, 4.8]`.
- `layout`: the layout mechanism used by Matplotlib
to avoid overlapping plot elements. Defaults to
`'constrained'`, instead of Matplotlib's default
of `None`.
point_size: the size of the points for each cell; defaults to
20,000 divided by the number of cells. Can be a single
number, or the name of a column of `obs` to make each
point a different size.
sort_by_frequency: if `True`, assign colors and sort the legend in
order of decreasing frequency; if `False` (the
default), use
[natural sort order](en.wikipedia.org/wiki/Natural_sort_order).
Cannot be `True` unless `colormap` is `None` and
`color_column` is discrete; if `colormap` is
not `None`, the plot order is determined by the
order of the keys in `colormap`.
colormap: a string or Colormap object indicating the Matplotlib
colormap to use; or, if `color_column` is discrete, a
dictionary mapping values in `color_column` to Matplotlib
colors (cells with values of `color_column` that are not
in the dictionary will be plotted in the color
`default_color`). Defaults to
`plt.rcParams['image.cmap']` (`'viridis'` by default) if
`color_column` is continous, or the colors from a
maximally perceptually distinct colormap if
`color_column` is discrete (with colors assigned in
decreasing order of frequency). Cannot be specified if
`color_column` is `None`.
lightness_range: a two-element tuple with the lightness range of
colors to generate, or `None` to take the full
range: `[0, 100]`. Can only be specified when
`color_column` is discrete and `colormap` is
`None`.
chroma_range: a two-element tuple with the chroma range of colors
to generate, or `None` to take the full range:
`[0, 100]`. Grays have low chroma, and vivid colors
have high chroma. Can only be specified when
`color_column` is discrete and `colormap` is `None`.
hue_range: a two-element tuple with the hue range of colors to
generate, or `None` to take the full range: `[0, 360]`.
Red is at 0°, green at 120°, and blue at 240°. Because
it wraps around, the first element of the tuple can be
greater than the second, unlike for `lightness_range`
and `chroma_range`. Can only be specified when
`color_column` is discrete and `colormap` is `None`.
first_color: the first color of the palette. Can be any valid
Matplotlib color, like a hex string (e.g.
`'#FF0000'`), a named color (e.g. 'red'), a 3- or
4-element RGB/RGBA tuple of integers 0-255 or floats
0-1, or a single float 0-1 for grayscale.
stride: as an optimization, consider only RGB colors where R, G,
and B are all multiples of this value. Must be a small
divisor of 255: 1, 3, 5, 15, or 17. Set to 1 for the best
possible solution, at orders of magnitude more
computational cost.
default_color: the default color to plot cells in when
`color_column` is `None`, or when certain cells have
missing (`null`) values for `color_column`, or when
`colormap` is a dictionary and some cells have
values of `color_column` that are not in the
dictionary. Can be any valid Matplotlib color, like
a hex string (e.g. `'#FF0000'`), a named color (e.g.
'red'), a 3- or 4-element RGB/RGBA tuple of integers
0-255 or floats 0-1, or a single float 0-1 for
grayscale.
scatter_kwargs: a dictionary of keyword arguments to be passed to
`ax.scatter()`, such as:
- `rasterized`: whether to convert the scatter plot
points to a raster (bitmap) image when saving to
a vector format like PDF. Defaults to `True`,
instead of Matplotlib's default of `False`.
- `marker`: the shape to use for plotting each cell
- `norm`, `vmin`, and `vmax`: control how the
`colormap` maps the numbers in `color_column` to
colors, if `color_column` is numeric
- `alpha`: the transparency of each point
- `linewidths` and `edgecolors`: the width and
color of the borders around each marker. These
are absent by default (`linewidths=0`,
`edgecolors=(0, 0, 0, 0)`), unlike Matplotlib's
default. Both arguments can be either single
values or sequences.
- `zorder`: the order in which the cells are
plotted, with higher values appearing on top of
lower ones.
Specifying `s`, `c`/`color`, or `cmap` will raise
an error, since these arguments conflict with the
`point_size`, `color_column`, and `colormap`
arguments, respectively.
label: whether to label cells with each distinct value of
`color_column`. Labels will be placed at the median x and y
position of the points with that color. Can only be `True`
when `color_column` is discrete. When set to `True`, you may
also want to set `legend=False` to avoid redundancy.
label_kwargs: a dictionary of keyword arguments to be passed to
`ax.text()` when adding labels to control the text
properties, such as:
- `color` and `size` to modify the text color/size
- `verticalalignment` and `horizontalalignment` to
control vertical and horizontal alignment. By
default, unlike Matplotlib, these are both set to
`'center'`.
- `path_effects` to set properties for the border
around the text. By default, set to
`matplotlib.patheffects.withStroke(linewidth=3, foreground='white', alpha=0.75)`
instead of Matplotlib's default of `None`, to put a
semi-transparent white border around the labels for
better contrast.
Can only be specified when `label=True`.
legend: whether to add a legend for each value in `color_column`.
Ignored unless `color_column` is discrete.
legend_kwargs: a dictionary of keyword arguments to be passed to
`ax.legend()` to modify the legend, such as:
- `loc`, `bbox_to_anchor`, and `bbox_transform` to
set its location. By default, `loc` is set to
`'center left'` and `bbox_to_anchor` to `(1, 0.5)`
to put the legend to the right of the plot,
anchored at the middle.
- `ncols` to set its number of columns. By
default, set to
`ceil(obs[color_column].n_unique() / 24)` to have
at most 24 items per column.
- `prop`, `fontsize`, and `labelcolor` to set its
font properties
- `facecolor` and `framealpha` to set its background
color and transparency
- `frameon=True` or `edgecolor` to add or color its
border. `frameon` defaults to `False`, instead of
Matplotlib's default of `True`.
- `title` to add a legend title
Can only be specified when `color_column` is
discrete and `legend=True`.
colorbar: whether to add a colorbar. Ignored unless `color_column`
is quantitative.
colorbar_kwargs: a dictionary of keyword arguments to be passed to
`plt.colorbar()`, such as:
- `location`: `'left'`, `'right'`, `'top'`, or
`'bottom'`
- `orientation`: `'vertical'` or `'horizontal'`
- `fraction`: the fraction of the axes to
allocate to the colorbar. Defaults to 0.15.
- `shrink`: the fraction to multiply the size of
the colorbar by. Defaults to 0.5, instead of
Matplotlib's default of 1.
- `aspect`: the ratio of the colorbar's long to
short dimensions. Defaults to 20.
- `pad`: the fraction of the axes between the
colorbar and the rest of the figure. Defaults to
0.01, instead of Matplotlib's default of 0.05 if
vertical and 0.15 if horizontal.
Can only be specified when `color_column` is
quantitative and `colorbar=True`.
title: the title of the plot, or `None` to not add a title
title_kwargs: a dictionary of keyword arguments to be passed to
`ax.set_title()` to control text properties, such as
`color` and `size`. Can only be specified when
`title` is not `None`.
xlabel: the x-axis label, or `None` to not label the x-axis
xlabel_kwargs: a dictionary of keyword arguments to be passed to
`ax.set_xlabel()` to control text properties, such
as `color` and `size`. Can only be specified when
`xlabel` is not `None`.
ylabel: the y-axis label, or `None` to not label the y-axis
ylabel_kwargs: a dictionary of keyword arguments to be passed to
`ax.set_ylabel()` to control text properties, such
as `color` and `size`. Can only be specified when
`ylabel` is not `None`.
xlim: a length-2 tuple of the left and right x-axis limits, or
`None` to set the limits based on the data
ylim: a length-2 tuple of the bottom and top y-axis limits, or
`None` to set the limits based on the data
despine: whether to remove the top and right spines (borders of the
plot area) from the plot
savefig_kwargs: a dictionary of keyword arguments to be passed to
`plt.savefig()`, such as:
- `dpi`: defaults to 300 instead of Matplotlib's
default of 150
- `bbox_inches`: the bounding box of the portion of
the figure to save; defaults to `'tight'` (crop
out any blank borders) instead of Matplotlib's
default of `None` (save the entire figure)
- `pad_inches`: the number of inches of padding to
add on each of the four sides of the figure when
saving. Defaults to `'layout'` (use the padding
from the constrained layout engine, when `ax` is
not `None`) instead of Matplotlib's default of
0.1.
- `transparent`: whether to save with a transparent
background; defaults to `True` if saving to a PDF
(i.e. when `filename` ends with `'.pdf'`) and
`False` otherwise, instead of Matplotlib's
default of always being `False`.
Can only be specified when `filename` is specified.
"""
# Import matplotlib
signal.signal(signal.SIGINT, signal.SIG_IGN)
try:
import matplotlib.pyplot as plt
finally:
signal.signal(signal.SIGINT, signal.default_int_handler)
# Get `cells_to_plot_column`, if not `None`
original_cells_to_plot_column = cells_to_plot_column
if cells_to_plot_column is not None:
cells_to_plot_column = self._get_column(
'obs', cells_to_plot_column, 'cells_to_plot_column',
pl.Boolean,
allow_missing=isinstance(cells_to_plot_column, str) and
cells_to_plot_column == 'passed_QC')
# If `color_column` was specified, check that it either discrete
# (String, Enum, or Categorical) or quantitative (integer or
# floating-point). If discrete, check that `color_column` has at least
# two distinct values.
original_color_column = color_column
if color_column is not None:
color_column = self._get_column(
'obs', color_column, 'color_column',
(pl.String, pl.Enum, pl.Categorical, 'integer',
'floating-point'), allow_null=True,
QC_column=cells_to_plot_column)
unique_color_column = color_column.unique().drop_nulls()
dtype = color_column.dtype
discrete = dtype in (pl.String, pl.Enum, pl.Categorical)
if discrete and len(unique_color_column) == 1:
color_column_description = \
SingleCell._describe_column('color_column',
original_color_column)
error_message = (
f'{color_column_description} must have at least two '
f'distinct values when its data '
f'type is {dtype.base_type()!r}')
raise ValueError(error_message)
# If `filename` was specified, check that it is a string or
# `pathlib.Path` and that its base directory exists; if `filename` is
# `None`, make sure `savefig_kwargs` is also `None`
if filename is not None:
check_type(filename, 'filename', (str, Path),
'a string or pathlib.Path')
directory = os.path.dirname(filename)
if directory and not os.path.isdir(directory):
error_message = (
f'{filename} refers to a file in the directory '
f'{directory!r}, but this directory does not exist')
raise NotADirectoryError(error_message)
filename = str(filename)
elif savefig_kwargs is not None:
error_message = 'savefig_kwargs must be None when filename is None'
raise ValueError(error_message)
# Check that `embedding_key` is the name of a key in `obsm`
check_type(embedding_key, 'embedding_key', str, 'a string')
if embedding_key not in self._obsm:
error_message = \
f'embedding_key {embedding_key!r} is not a key of obsm'
if embedding_key in ('umap', 'pacmap', 'localmap'):
error_message += (
f'; did you forget to run {embedding_key}()?')
raise ValueError(error_message)
# Check that the embedding `embedding_key` references is 2D.
embedding = self._obsm[embedding_key]
if embedding.shape[1] != 2:
error_message = (
f'the embedding at obsm[{embedding_key!r}] is '
f'{embedding.shape[1]:,}-dimensional, but must be '
f'2-dimensional to be plotted')
raise ValueError(error_message)
# If `cells_to_plot_column` was specified, subset to these cells
if cells_to_plot_column is not None:
embedding = embedding[cells_to_plot_column.to_numpy()]
if color_column is not None:
color_column = color_column.filter(cells_to_plot_column)
unique_color_column = color_column.unique().drop_nulls()
# Check that the embedding does not contain NaNs
if np.isnan(embedding).any():
error_message = \
f'the embedding at obsm[{embedding_key!r}] contains NaNs; '
if cells_to_plot_column is None and 'passed_QC' in self._obs:
error_message += (
'did you forget to set QC_column to None when creating the '
'embedding, to match the fact that you set '
'cells_to_plot_column to None here?')
else:
cells_to_plot_column_description = \
SingleCell._describe_column('cells_to_plot_column',
original_cells_to_plot_column)
error_message += (
f'does your {cells_to_plot_column_description} contain '
f'cells that were excluded by the QC_column used when '
f'creating the embedding?')
raise ValueError(error_message)
# For each of the kwargs arguments, if the argument was specified,
# check that it is a dictionary and that all its keys are strings.
for kwargs, kwargs_name in ((figure_kwargs, 'figure_kwargs'),
(scatter_kwargs, 'scatter_kwargs'),
(label_kwargs, 'label_kwargs'),
(legend_kwargs, 'legend_kwargs'),
(colorbar_kwargs, 'colorbar_kwargs'),
(title_kwargs, 'title_kwargs'),
(xlabel_kwargs, 'xlabel_kwargs'),
(ylabel_kwargs, 'ylabel_kwargs'),
(savefig_kwargs, 'savefig_kwargs')):
if kwargs is not None:
check_type(kwargs, kwargs_name, dict, 'a dictionary')
for key in kwargs:
if not isinstance(key, str):
error_message = (
f'all keys of {kwargs_name} must be strings, but '
f'it contains a key of type '
f'{type(key).__name__!r}')
raise TypeError(error_message)
# If `figure_kwargs` was specified, check that `ax` is `None`
if figure_kwargs is not None and ax is not None:
error_message = (
'figure_kwargs must be None when ax is not None, since a new '
'figure does not need to be generated when plotting onto an '
'existing axis')
raise ValueError(error_message)
# If `point_size` is `None`, default to 20,000 / num_cells; otherwise,
# check that it is a positive number or the name of a numeric column of
# `obs` with all-positive numbers
num_cells = \
len(self) if cells_to_plot_column is None else len(embedding)
if point_size is None:
point_size = 20_000 / num_cells
else:
check_type(point_size, 'point_size', (int, float, str),
'a positive number or string')
if isinstance(point_size, (int, float)):
check_bounds(point_size, 'point_size', 0, left_open=True)
else:
if point_size not in self._obs:
error_message = \
f'point_size {point_size!r} is not a column of obs'
raise ValueError(error_message)
point_size = self._obs[point_size]
if not (point_size.dtype.is_integer() or
point_size.dtype.is_float()):
error_message = (
f'the point_size column, obs[{point_size!r}], must '
f'have an integer or floating-point data type, but '
f'has data type {point_size.dtype.base_type()!r}')
raise TypeError(error_message)
if point_size.min() <= 0:
error_message = (
f'the point_size column, obs[{point_size!r}], does '
f'not have all-positive elements')
raise ValueError(error_message)
# If `sort_by_frequency=True`, ensure `colormap` is `None` and
# `color_column` is discrete
check_type(sort_by_frequency, 'sort_by_frequency', bool, 'Boolean')
if sort_by_frequency:
if colormap is not None:
error_message = (
f'sort_by_frequency must be False when colormap is '
f'specified')
raise ValueError(error_message)
if color_column is None:
error_message = \
'sort_by_frequency must be False when color_column is None'
raise ValueError(error_message)
if not discrete:
color_column_description = \
SingleCell._describe_column('color_column',
original_color_column)
error_message = (
f'sort_by_frequency must be False when '
f'{color_column_description} is continuous')
raise ValueError(error_message)
# Handle coloring based on the values of `colormap` and `color_column`
if colormap is not None:
# If `colormap` was specified, check that it is a string in
# `plt.colormaps`, Colormap object, or dictionary where all keys
# are in `color_column` and all values are valid Matplotlib colors.
# Normalize the color(s) to RGBA. Make sure `color_column` is not
# `None` and `lightness_range`, `chroma_range`, `hue_range`,
# `first_color`, and `stride` have their default values.
check_type(colormap, 'colormap',
(str, plt.matplotlib.colors.Colormap, dict),
'a string, matplotlib Colormap object, or dictionary')
if color_column is None:
error_message = \
'colormap must be None when color_column is None'
raise ValueError(error_message)
if not (isinstance(lightness_range, tuple) and
len(lightness_range) == 2 and
isinstance(lightness_range[0], float) and
lightness_range[0] == 100 / 3 and
isinstance(lightness_range[1], float) and
lightness_range[1] == 200 / 3):
error_message = (
f'lightness_range cannot be specified when colormap is '
f'specified')
raise ValueError(error_message)
if not (isinstance(chroma_range, tuple) and
len(chroma_range) == 2 and
isinstance(chroma_range[0], int) and
chroma_range[0] == 50 and
isinstance(chroma_range[1], int) and
chroma_range[1] == 100):
error_message = (
f'chroma_range cannot be specified when colormap is '
f'specified')
raise ValueError(error_message)
for arg, arg_name in ((hue_range, 'hue_range'),
(first_color, 'first_color'),
(stride, 'stride')):
if arg is not None:
error_message = \
f'{arg_name} must be None when colormap is specified'
raise ValueError(error_message)
if isinstance(colormap, str):
colormap = plt.colormaps[colormap]
elif isinstance(colormap, dict):
if not discrete:
color_column_description = \
SingleCell._describe_column('color_column',
original_color_column)
error_message = (
f'colormap cannot be a dictionary when '
f'{color_column_description} is continuous')
raise ValueError(error_message)
for key, value in colormap.items():
if not isinstance(key, str):
error_message = (
f'all keys of colormap must be strings, but it '
f'contains a key of type {type(key).__name__!r}')
raise TypeError(error_message)
if key not in unique_color_column:
error_message = (
f'colormap is a dictionary containing the key '
f'{key!r}, which is not one of the values in '
f'obs[{color_column!r}]')
raise ValueError(error_message)
if not plt.matplotlib.colors.is_color_like(value):
error_message = (
f'colormap[{key!r}] is not a valid Matplotlib '
f'color')
raise ValueError(error_message)
colormap[key] = plt.matplotlib.colors.to_rgba(value)
else:
if color_column is not None and discrete:
# `color_column` is discrete and `colormap` was not specified;
# generate a maximally perceptually distinct colormap. Assign
# colors in natural sort order, or decreasing order of
# frequency if `sort_by_frequency=True`.
color_order = color_column\
.value_counts(sort=True)\
.to_series()\
.drop_nulls() if sort_by_frequency else \
sorted(unique_color_column,
key=lambda color_label: [
int(text) if text.isdigit() else
text.lower() for text in
re.split('([0-9]+)', color_label)])
colormap = generate_palette(num_colors=len(color_order),
lightness_range=lightness_range,
chroma_range=chroma_range,
hue_range=hue_range,
first_color=first_color,
stride=stride)
colormap = np.c_[colormap, np.ones(len(colormap))] # add alpha
colormap = dict(zip(color_order, colormap))
else:
# `color_column` is `None` or continuous, so make sure
# `lightness_range`, `chroma_range`, `hue_range`,
# `first_color`, and `stride` have their default values
for arg, arg_name in ((lightness_range, 'lightness_range'),
(chroma_range, 'chroma_range'),
(hue_range, 'hue_range'),
(first_color, 'first_color'),
(stride, 'stride')):
if arg is lightness_range:
if isinstance(lightness_range, tuple) and \
len(lightness_range) == 2 and \
isinstance(lightness_range[0], float) and \
lightness_range[0] == 100 / 3 and \
isinstance(lightness_range[1], float) and \
lightness_range[1] == 200 / 3:
continue
elif arg is chroma_range:
if isinstance(chroma_range, tuple) and \
len(chroma_range) == 2 and \
isinstance(chroma_range[0], int) and \
chroma_range[0] == 50 and \
isinstance(chroma_range[1], int) and \
chroma_range[1] == 100:
continue
elif arg is None:
continue
if color_column is None:
error_message = (
f'{arg_name} must be None when color_column is '
f'None')
raise ValueError(error_message)
else:
color_column_description = \
SingleCell._describe_column(
'color_column', original_color_column)
error_message = (
f'{arg_name} must be None when '
f'{color_column_description} is continuous')
raise ValueError(error_message)
# Check that `default_color` is a valid Matplotlib color, and convert
# it to RGBA
if not plt.matplotlib.colors.is_color_like(default_color):
error_message = 'default_color is not a valid Matplotlib color'
raise ValueError(error_message)
default_color = plt.matplotlib.colors.to_rgba(default_color)
# Override the defaults for certain keys of `scatter_kwargs`
default_scatter_kwargs = dict(rasterized=True, linewidths=0,
edgecolors=(0, 0, 0, 0))
scatter_kwargs = default_scatter_kwargs | scatter_kwargs \
if scatter_kwargs is not None else default_scatter_kwargs
# Check that `scatter_kwargs` does not contain the `s`, `c`/`color`, or
# `cmap` keys
if 's' in scatter_kwargs:
error_message = (
"'s' cannot be specified as a key in scatter_kwargs; specify "
"the point_size argument instead")
raise ValueError(error_message)
for key in 'c', 'color', 'cmap':
if key in scatter_kwargs:
error_message = (
f'{key!r} cannot be specified as a key in scatter_kwargs; '
f'specify the color_column, colormap, lightness_range, '
f'chroma_range, hue_range, first_color, stride, and/or '
f'default_color arguments instead')
raise ValueError(error_message)
# If `label=True`, check that `color_column` is discrete.
# If `label=False`, check that `label_kwargs` is `None`.
check_type(label, 'label', bool, 'Boolean')
if label:
if color_column is None:
error_message = 'color_column cannot be None when label=True'
raise ValueError(error_message)
if not discrete:
color_column_description = \
SingleCell._describe_column('color_column',
original_color_column)
error_message = (
f'{color_column_description} cannot be continuous when '
f'label=True')
raise ValueError(error_message)
elif label_kwargs is not None:
error_message = 'label_kwargs must be None when label=False'
raise ValueError(error_message)
# Only add a legend if `legend=True` and `color_column` is discrete.
# If not adding a legend, check that `legend_kwargs` is `None`.
check_type(legend, 'legend', bool, 'Boolean')
add_legend = legend and color_column is not None and discrete
if not add_legend and legend_kwargs is not None:
if color_column is None:
error_message = \
'legend_kwargs must be None when color_column is None'
raise ValueError(error_message)
else:
color_column_description = SingleCell._describe_column(
'color_column', original_color_column)
error_message = (
f'legend_kwargs must be None when '
f'{color_column_description} is continuous')
raise ValueError(error_message)
# Only add a colorbar if `colorbar=True` and `color_column` is
# continuous. If not adding a colorbar, check that `colorbar_kwargs` is
# `None`.
check_type(colorbar, 'colorbar', bool, 'Boolean')
add_colorbar = colorbar and color_column is not None and not discrete
if not add_colorbar and colorbar_kwargs is not None:
if color_column is None:
error_message = \
'colorbar_kwargs must be None when color_column is None'
raise ValueError(error_message)
else:
color_column_description = SingleCell._describe_column(
'color_column', original_color_column)
error_message = (
f'colorbar_kwargs must be None when '
f'{color_column_description} is discrete')
raise ValueError(error_message)
# Check that `title` is a string or `None`; if `None`, check that
# `title_kwargs` is `None` as well. Ditto for `xlabel` and `ylabel`.
for arg, arg_name, arg_kwargs in (
(title, 'title', title_kwargs),
(xlabel, 'xlabel', xlabel_kwargs),
(ylabel, 'ylabel', ylabel_kwargs)):
if arg is not None:
check_type(arg, arg_name, str, 'a string')
elif arg_kwargs is not None:
error_message = \
f'{arg_name}_kwargs must be None when {arg_name} is None'
raise ValueError(error_message)
# Check that `xlim` and `ylim` are length-2 tuples or `None`, with the
# first element less than the second
for arg, arg_name in (xlim, 'xlim'), (ylim, 'ylim'):
if arg is not None:
check_type(arg, arg_name, tuple, 'a length-2 tuple')
if len(arg) != 2:
error_message = (
f'{arg_name} must be a length-2 tuple, but has length '
f'{len(arg):,}')
raise ValueError(error_message)
if arg[0] >= arg[1]:
error_message = \
f'{arg_name}[0] must be less than {arg_name}[1]'
raise ValueError(error_message)
# If `color_column` is `None`, plot all cells in `default_color`. If
# `colormap` is a dictionary, generate an explicit list of colors to
# plot each cell in. If `colormap` is a Colormap, just pass it as the
# `cmap` argument. If `colormap` is missing and `color_column` is
# continuous, set it to `plt.rcParams['image.cmap']` ('viridis' by
# default)
if color_column is None:
c = default_color
cmap = None
elif isinstance(colormap, dict):
# Fill both missing values and values missing from `colormap` with
# `default_color`
if color_column.dtype == pl.String:
color_column = color_column.cast(pl.Categorical)
categories = color_column.cat.get_categories()
lookup = np.array([colormap.get(cat, default_color)
for cat in categories], dtype=np.float32)
c = lookup[color_column.to_physical().fill_null(0)]
c[color_column.is_null()] = default_color
cmap = None
else:
# `colormap` is copied below since `set_bad()` is in-place
c = color_column.to_numpy()
if colormap is not None:
cmap = colormap.copy()
cmap.set_bad(default_color)
else: # `color_column` is continuous
cmap = plt.rcParams['image.cmap']
# Check that `despine` is Boolean
check_type(despine, 'despine', bool, 'Boolean')
# If `ax` is `None`, create a new figure; otherwise, check that it is a
# Matplotlib axis
make_new_figure = ax is None
try:
if make_new_figure:
default_figure_kwargs = \
dict(figsize=(8, 6), layout='constrained')
figure_kwargs = default_figure_kwargs | figure_kwargs \
if figure_kwargs is not None else default_figure_kwargs
plt.figure(**figure_kwargs)
ax = plt.gca()
else:
check_type(ax, 'ax', plt.Axes, 'a Matplotlib axis')
# Make a scatter plot of the embedding with equal x-y aspect
# ratios. If `color_column` is discrete (and so `colormap` is a
# dictionary), plot one color at a time, in order of decreasing
# frequency, so rarer cell types end up on top of more common
# ones. This also lets each `ax.scatter()` call use Matplotlib's
# fast path for when the color is scalar (i.e. all points being
# plotted are the same color). Cells with missing (`null`) values,
# or with values not in `colormap`, are plotted first (at the
# bottom) in `default_color`.
if color_column is not None and discrete and \
isinstance(colormap, dict):
color_values = color_column.to_numpy()
point_size_per_cell = isinstance(point_size, pl.Series)
if point_size_per_cell:
point_size = point_size.to_numpy()
unknown_mask = ~np.isin(color_values, tuple(colormap))
if unknown_mask.any():
scatter = ax.scatter(
embedding[unknown_mask, 0], embedding[unknown_mask, 1],
s=point_size[unknown_mask]
if point_size_per_cell else point_size,
color=default_color, **scatter_kwargs)
ordered_labels = color_column\
.value_counts(sort=True)\
.to_series()\
.drop_nulls()
for color_label in ordered_labels:
if color_label not in colormap:
continue
mask = color_values == color_label
scatter = ax.scatter(
embedding[mask, 0], embedding[mask, 1],
s=point_size[mask]
if point_size_per_cell else point_size,
color=colormap[color_label], **scatter_kwargs)
else:
scatter = ax.scatter(embedding[:, 0], embedding[:, 1],
s=point_size, c=c, cmap=cmap,
**scatter_kwargs)
ax.set_aspect('equal')
# Add the title, axis labels and axis limits
if title is not None:
if title_kwargs is None:
ax.set_title(title)
else:
ax.set_title(title, **title_kwargs)
if xlabel is not None:
if xlabel_kwargs is None:
ax.set_xlabel(xlabel)
else:
ax.set_xlabel(xlabel, **xlabel_kwargs)
if ylabel is not None:
if ylabel_kwargs is None:
ax.set_ylabel(ylabel)
else:
ax.set_ylabel(ylabel, **ylabel_kwargs)
if xlim is not None:
ax.set_xlim(*xlim)
if ylim is not None:
ax.set_ylim(*ylim)
# Add the legend; override the defaults for certain keys of
# `legend_kwargs`
if add_legend:
default_legend_kwargs = dict(
loc='center left', bbox_to_anchor=(1, 0.5), frameon=False,
ncols=(len(unique_color_column) + 23) // 24)
legend_kwargs = default_legend_kwargs | legend_kwargs \
if legend_kwargs is not None else default_legend_kwargs
if isinstance(colormap, dict):
for color_label, color in colormap.items():
ax.add_artist(plt.Line2D([], [], color=color,
label=color_label, marker='o',
markersize=4, linewidth=0))
plt.legend(**legend_kwargs)
else:
plt.legend(*scatter.legend_elements(), **legend_kwargs)
# Add the colorbar; override the defaults for certain keys of
# `colorbar_kwargs`
if add_colorbar:
default_colorbar_kwargs = dict(shrink=0.5, pad=0.01)
colorbar_kwargs = default_colorbar_kwargs | colorbar_kwargs \
if colorbar_kwargs is not None else default_colorbar_kwargs
cbar = plt.colorbar(scatter, ax=ax, **colorbar_kwargs)
cbar.outline.set_visible(False)
# Label cells; override the defaults for certain keys of
# `label_kwargs`
if label:
from matplotlib.patheffects import withStroke
if label_kwargs is None:
label_kwargs = {}
label_kwargs |= dict(
horizontalalignment=label_kwargs.pop(
'horizontalalignment',
label_kwargs.pop('ha', 'center')),
verticalalignment=label_kwargs.pop(
'verticalalignment',
label_kwargs.pop('va', 'center')),
path_effects=[withStroke(linewidth=3, foreground='white',
alpha=0.75)])
median_coordinates = pl.DataFrame({
'x': embedding[:, 0],
'y': embedding[:, 1],
'color_column': color_column})\
.group_by('color_column')\
.agg(pl.median('x', 'y'))\
.drop_nulls()\
.sort('color_column')
for color_label, median_x, median_y in \
median_coordinates.iter_rows():
ax.text(median_x, median_y, color_label, **label_kwargs)
# Despine, if specified
if despine:
spines = ax.spines
spines['top'].set_visible(False)
spines['right'].set_visible(False)
# Save, if `filename` is not `None`; override the defaults for
# certain keys of `savefig_kwargs`
if filename is not None:
default_savefig_kwargs = \
dict(dpi=300, bbox_inches='tight', pad_inches='layout',
transparent=filename is not None and
filename.endswith('.pdf'))
savefig_kwargs = default_savefig_kwargs | savefig_kwargs \
if savefig_kwargs is not None else default_savefig_kwargs
with warnings.catch_warnings():
warnings.simplefilter('ignore', UserWarning)
plt.savefig(filename, **savefig_kwargs)
if make_new_figure:
plt.close()
except:
# If we made a new figure, make sure to close it if there's an
# exception (but not if there was no error and `filename` is
# `None`, in case the user wants to modify it further before
# saving)
if make_new_figure:
plt.close()
raise