# Source code for brisc.single_cell

from __future__ import annotations
import ctypes
import h5py
import mmap
import numpy as np
import os
import polars as pl
import pyarrow as pa
import re
import signal
import sys
import warnings
from collections.abc import Iterable
from itertools import chain, islice
from pathlib import Path
from scipy import sparse
from textwrap import fill
from threadpoolctl import threadpool_info, threadpool_limits
from typing import Any, Callable, Literal, Mapping, NoReturn, Sequence
from .pseudobulk import Pseudobulk
from .sparse import csc_array, csr_array
from .type_aliases import Color, Indexer, Scalar, UnsDict, UnsItem, \
    SingleCellColumn
from .utils import array_equal, cast_to_Enum, check_bounds, check_dtype, \
    check_R_variable_name, check_type, check_types, concatenate, \
    filter_columns, generate_palette, getnnz, getnnz_at_least_threshold, \
    import_cython, ix_symmetric, numa_zeros, parallel_subset_1d, \
    parallel_subset_2d, plural, read_parallel_multiprocessing, sparse_equal, \
    sparse_major_stack, sparse_minor_stack, to_tuple, to_tuple_checked
from .validated_dict import Obsm, Obsp, Uns, Varm, Varp
# Declare the compiled helpers this module depends on. `import_cython` (from
# `.utils`) is called once, at import time, with a dict whose keys are short
# names ('doublets', 'embed', 'pca', ...) and whose values are either a single
# function name (a string) or a tuple of function names.
# NOTE(review): presumably each key names a Cython extension module and the
# listed functions are made available to this module's namespace - confirm
# against `utils.import_cython`, which is not visible from here. Until that is
# confirmed, keep the value types (bare string vs. tuple) and the key order
# exactly as they are.
import_cython({
    'doublets': ('call_doublets', 'compute_cxds', 'compute_obs', 'compute_S',
                 'get_hvgs', 'simulate_doublets'),
    'embed': ('localmap', 'pacmap', 'umap_fuzzy_weights', 'umap_optimize'),
    'harmonize': ('label_transfer', 'harmony', 'harmony_original',
                  'normalize_rows'),
    'hdf5': 'read_all_datasets',
    'hvg': ('clipped_sum_csc', 'clipped_sum_csr', 'gene_mean_and_variance_csc',
            'gene_mean_and_variance_csr'),
    'kmeans': 'kmeans',
    'knn': ('knn_self', 'knn_cross'),
    'leiden': ('leiden', 'leiden_multiresolution'),
    'normalize': ('normalize_csc', 'normalize_csr'),
    'pca': 'irlba',
    'pseudobulk_and_markers': (
        'get_detection_rate', 'get_detection_rate_and_fold_change',
        'get_detection_rate_and_fold_change_and_pareto_candidates',
        'groupby_getnnz_csc', 'groupby_getnnz_csc_for_gene_subset',
        'groupby_getnnz_and_total_csc_for_gene_subset',
        'groupby_getnnz_csr', 'groupby_getnnz_csr_for_gene_subset',
        'groupby_getnnz_and_total_csr_for_gene_subset', 'groupby_sum_csr',
        'groupby_sum_csc', 'pareto_front'),
    'qc': ('malat1_mask_csr', 'malat1_mask_csr_check', 'malat1_mask_csr_scan',
           'mito_mask_csr', 'mito_mask_csc', 'qc_metrics_csr',
           'qc_metrics_csc'),
    'snn': 'snn'})


[docs] class SingleCell: """ A single-cell dataset. Has slots for: - `X`: a scipy sparse array of counts per cell and gene - `obs`: a polars DataFrame of cell metadata - `var`: a polars DataFrame of gene metadata - `obsm`: a dictionary of NumPy arrays and polars DataFrames of cell metadata - `varm`: a dictionary of NumPy arrays and polars DataFrames of gene metadata - `uns`: a dictionary of scalars (strings, numbers or Booleans) or NumPy arrays, or nested dictionaries thereof - `num_threads`: the default number of threads to use for operations on the dataset that support multithreading (which can be overridden by individual functions) as well as `obs_names` and `var_names`, aliases for `obs[:, 0]` and `var[:, 0]`. """
[docs] def __init__(self, source: str | Path | 'AnnData' | None = None, /, *, X: sparse.csr_array | sparse.csc_array | sparse.csr_matrix | sparse.csc_matrix | Literal[False] | None = None, obs: pl.DataFrame | None = None, var: pl.DataFrame | None = None, obsm: dict[str, np.ndarray | pl.DataFrame] | Literal[False] | None = None, varm: dict[str, np.ndarray | pl.DataFrame] | Literal[False] | None = None, obsp: dict[str, sparse.csr_array | sparse.csc_array | sparse.csr_matrix | sparse.csc_matrix] | Literal[False] | None = None, varp: dict[str, sparse.csr_array | sparse.csc_array | sparse.csr_matrix | sparse.csc_matrix] | Literal[False] | None = None, uns: UnsDict | Literal[False] | None = None, X_key: str | None = None, assay: str | None = None, obs_columns: str | Iterable[str] = None, var_columns: str | Iterable[str] = None, num_threads: int | np.integer = -1) -> None: """ Load a SingleCell dataset from a file, or create one from an in-memory AnnData object or count matrix + metadata. SingleCell supports reading and writing files from each of the three major single-cell ecosystems: - scverse/Scanpy AnnData (`.h5ad`) - Seurat (`.rds` and `.h5Seurat`) - Bioconductor SingleCellExperiment (`.rds`) as well as raw 10x data files (`.h5` or `.mtx`/`.mtx.gz`). By default, when an AnnData object, `.h5ad` file, `.h5Seurat` file, or `.rds` file contains both raw and normalized counts, only the raw counts will be loaded. To load normalized counts instead, use the `X` argument (for AnnData objects) or `X_key` argument (for files). Reading and writing `.rds` files requires the ryp Python-R bridge. To create a SingleCell dataset from an in-memory Seurat or SingleCellExperiment object in the ryp R workspace, use `SingleCell.from_seurat()` or `SingleCell.from_sce()`. Reading and writing loom files is not supported because SingleCell only supports sparse count matrices, and loom only supports dense matrices. 
Using loom files for SingleCell data is not recommended due to this wastefulness. If you must, load them with `SingleCell(scanpy.read_loom(loom_filename))`; `read_loom()` implicitly converts the counts to a sparse matrix by default. Args: source: a filename or AnnData object, or `None` if specifying `X`, `obs`, and `var` instead. Supported file formats are scverse/Scanpy AnnData (`.h5ad`), Seurat (`.rds` and `.h5Seurat`), Bioconductor SingleCellExperiment (`.rds`), and raw 10x data files (`.h5` or `.mtx`/`.mtx.gz`). If `source` is a 10x `.mtx`/`.mtx.gz` filename, `barcodes.tsv`/`barcodes.tsv.gz` and `features.tsv`/`features.tsv.gz` are assumed to be in the same directory (with the ungzipped versions used preferentially), unless custom paths to these files are specified via the `obs` and/or `var` arguments. X: If `source` is `None`, the data as a sparse array or matrix (with rows = cells, columns = genes). If `source` is an AnnData object, an optional sparse array or matrix to use as `X`. By default, `X` will be loaded from `source.layers['UMIs']` or `source.raw.X` if present and `source.X` otherwise. If `X` is `None` when `source` is `None`, or `False` when `source` is a filename, do not store any data in `X` and set it to `None`. This helps save memory, but the resulting dataset cannot be saved, converted to another format, or used to run analyses that require `X`. obs: a polars DataFrame of metadata for each cell (row of `X`), or `None` if specifying `source` instead. Or, if `source` is a 10x `.mtx`/`.mtx.gz` filename, an optional filename for cell-level metadata, which is otherwise assumed to be at `barcodes.tsv` (or `barcodes.tsv.gz`) in the same directory as the `.mtx`/`.mtx.gz` file. var: a polars DataFrame of metadata for each gene (column of `X`), or `None` if specifying `source` instead. 
Or, if `source` is a 10x `.mtx`/`.mtx.gz` filename, an optional filename for gene-level metadata, which is otherwise assumed to be at `features.tsv` (or `features.tsv.gz`) in the same directory as the `.mtx`/`.mtx.gz` file. obsm: an optional dictionary mapping string names to NumPy arrays and polars DataFrames of metadata for each cell, or `False` to skip loading `obsm` when reading `.h5ad` and `.h5Seurat` files varm: an optional dictionary mapping string names to NumPy arrays and polars DataFrames of metadata for each gene, or `False` to skip loading `varm` when reading `.h5ad` files obsp: an optional dictionary mapping string names to sparse arrays or matrices containing pairwise cell-cell information like nearest-neighbors graphs, or `False` to skip loading `obsp` when reading `.h5ad` and `.h5Seurat` files varp: an optional dictionary mapping string names to sparse arrays or matrices containing pairwise gene-gene information, or `False` to skip loading `varp` when reading `.h5ad` files uns: an optional dictionary mapping string names to unstructured metadata - scalars (strings, numbers or Booleans), NumPy arrays, or nested dictionaries thereof - or `False` to skip loading `uns` when reading `.h5ad` or `.h5Seurat` files X_key: if `source` is an AnnData `.h5ad`, Seurat `.rds` or `.h5Seurat` filename, or SingleCellExperiment `.rds` filename, the location within `source` to use as `X`: - If `source` is an `.h5ad` filename, the name of the key in the `.h5ad` file to use as `X`. If `None`, defaults to `'layers/UMIs'` (i.e. `self.layers['UMIs']` in Scanpy) or `'raw/X'` (i.e. `self.raw.X` in Scanpy) if present, otherwise `'X'`. Tip: `SingleCell.ls(h5ad_file)` shows the structure of an `.h5ad` file without loading it, allowing you to figure out which key to use as `X`. - If `source` is a Seurat `.rds` or `.h5Seurat` filename, the layer within the active assay (or the assay specified by the `assay` argument, if not `None`) to use as `X`. 
Set to `'data'` to load the normalized counts, or `'scale.data'` to load the normalized and scaled counts, if available. If `None`, defaults to `'counts'`. - If `source` is a SingleCellExperiment `.rds` filename, the element within `@assays@data` to use as `X`. Set to `'logcounts'` to load the normalized counts, if available. If `None`, defaults to `'counts'`. assay: if `source` is a Seurat `.rds` or `.h5Seurat`/`.h5seurat` filename, the name of the assay within the Seurat object to load data from. Defaults to the Seurat object's `active.assay` attribute (usually `'RNA'`). obs_columns: if `source` is an `.h5ad` or `.h5Seurat` filename, the columns of `obs` to load. If not specified, load all columns. Specifying only a subset of columns can speed up reading. Not supported for `.h5` files, since they only have a single `obs` column (`'barcodes'`), nor for Seurat and SingleCellExperiment `.rds` files, since `.rds` files do not support partial loading. var_columns: if `source` is an `.h5ad`, `.h5`, or `.h5Seurat` filename, the columns of `var` to load. If not specified, load all columns. Specifying only a subset of columns can speed up reading. Not supported for Seurat and SingleCellExperiment `.rds` files, since the `.rds` file format does not support partial loading. num_threads: the number of threads to use when reading `.h5ad` and `.h5` files, and the default number of threads to use for all subsequent operations on this SingleCell dataset. Also sets the number of threads for this SingleCell dataset's count matrix, if present. By default (`num_threads=-1`), use all available cores, as determined by `os.cpu_count()`. 
Examples: Load an `.h5ad` file: >>> sc = SingleCell('data.h5ad') Load an `.h5ad` file where raw counts are stored in a non-default location, `adata.raw.counts` (use `SingleCell.ls('data.h5ad')` to inspect its structure before loading): >>> sc = SingleCell('data.h5ad', X_key='raw/counts') Load only selected metadata columns from an `.h5ad` file to reduce loading time and memory usage: >>> sc = SingleCell('data.h5ad', ... obs_columns=['cell_type', 'batch'], ... var_columns=['gene_symbol']) Skip loading the count matrix to minimize memory usage (note: dataset cannot be saved, converted, or used for analyses that require `X`): >>> sc = SingleCell('large_data.h5ad', X=False) Load a Seurat `.h5Seurat` file: >>> sc = SingleCell('seurat_obj.h5Seurat') Load a Seurat `.rds` file: >>> sc = SingleCell('seurat_obj.rds') Load a Bioconductor SingleCellExperiment `.rds` file and use log-normalized counts: >>> sc = SingleCell('sce_obj.rds', X_key='logcounts') Load raw 10x Genomics data from an `.h5` file: >>> sc = SingleCell('matrix.h5') Load raw 10x Genomics data from an `.mtx.gz` file, with barcodes and features stored in the same directory as `barcodes.tsv.gz` and `features.tsv.gz` in the usual way: >>> sc = SingleCell('matrix.mtx.gz') Manually create a SingleCell dataset from an in-memory sparse matrix and metadata: >>> import polars as pl >>> from scipy.sparse import csr_array >>> X = csr_array([[1, 0, 3], [0, 2, 0]]) >>> obs = pl.DataFrame({'cell_id': ['cell1', 'cell2']}) >>> var = pl.DataFrame({'gene_id': ['g1', 'g2', 'g3']}) >>> sc = SingleCell(X=X, obs=obs, var=var) Note: Both ordered and unordered categorical columns of `obs` and `var` will be loaded as polars Enums rather than polars Categoricals. This is because polars Categoricals use a shared numerical encoding across columns, so their codes are not `[0, 1, 2, ...]` like pandas categoricals and polars Enums are. Using Categoricals leads to a large overhead (~25%) when loading `obs` from an `.h5ad` file, for example. 
Note: SingleCell does not support dense matrices, which are highly memory-inefficient for single-cell data. Passing a NumPy array as the `X` argument will give an error; if for some reason your data has been improperly stored as a dense matrix, convert it to a sparse array first with `csr_array(numpy_array)`). However, when loading from disk or converting from other formats, dense matrices will be automatically converted to sparse matrices, to avoid giving an error when loading or converting. """ # Initialize this SingleCell dataset's `num_threads` num_threads = SingleCell._process_num_threads_static(num_threads) self._num_threads = num_threads # Initialize the SingleCell dataset depending on which arguments were # specified if source is None: # Sparse array or matrix if X is not None: for prop, prop_name in ( (X_key, 'X_key'), (assay, 'assay'), (obs_columns, 'obs_columns'), (var_columns, 'var_columns')): if prop is not None: error_message = ( f'when X is a sparse array or matrix, {prop_name} ' f'must be None') raise ValueError(error_message) self.X = X # not self._X; triggers validation else: for prop, prop_name in ( (X_key, 'X_key'), (assay, 'assay'), (obs_columns, 'obs_columns'), (var_columns, 'var_columns')): if prop is not None: error_message = ( f'when X and source are both None, {prop_name} ' f'must be None') raise ValueError(error_message) self._X = None self.obs = obs self.var = var self.obsm = obsm if obsm is not None else {} self.varm = varm if varm is not None else {} self.obsp = obsp if obsp is not None else {} self.varp = varp if varp is not None else {} self.uns = uns if uns is not None else {} elif str(type(source)).startswith("<class 'anndata"): # AnnData object signal.signal(signal.SIGINT, signal.SIG_IGN) try: from anndata import AnnData finally: signal.signal(signal.SIGINT, signal.default_int_handler) check_type(source, 'source', AnnData, 'a filename, Path, or Anndata object') for prop, prop_name in ( (obs, 'obs'), (var, 'var'), (obsm, 'obsm'), 
(varm, 'varm'), (obsp, 'obsp'), (varp, 'varp'), (uns, 'uns'), (X_key, 'X_key'), (assay, 'assay'), (obs_columns, 'obs_columns'), (var_columns, 'var_columns')): if prop is not None: error_message = ( f'when initializing a SingleCell dataset from an ' f'AnnData object, {prop_name} must be None') raise ValueError(error_message) # Get `X` if X is False: self._X = None else: if X is None: has_layers_UMIs = 'UMIs' in source._layers has_raw_X = hasattr(source._raw, '_X') if has_layers_UMIs and has_raw_X: error_message = ( "both layers['UMIs'] and raw.X are present in " "this AnnData object; this should never " "happen in well-formed AnnData objects") raise ValueError(error_message) X = source._layers['UMIs'] if has_layers_UMIs else \ source._raw._X if has_raw_X else source._X if isinstance(X, np.ndarray): X_string = f"layers['UMIs']" if has_layers_UMIs else \ f'raw.X' if has_raw_X else 'X' warning_message = ( f"this AnnData object's {X_string} is stored as a " f"dense matrix; auto-converting to a sparse csr_array") warnings.warn(warning_message) X = csr_array(X) self.X = X # Get `obs` and `var` for attr in '_obs', '_var': df = getattr(source, attr) if df.index.name is None: # Make the index name consistent with what you'd get if # you had loaded the same AnnData object directly from # an `.h5ad` file df = df.rename_axis('_index') # Convert Categoricals with string categories to polars Enums, # and Categoricals with other types of categories (e.g. # integers) to non-categorical columns of the corresponding # polars dtype (e.g. 
pl.Int64) since polars only supports # string categories schema_overrides = {} cast_dict = {} for column, dtype in \ df.dtypes[df.dtypes == 'category'].items(): categories_dtype = dtype.categories.dtype if categories_dtype == object: schema_overrides[column] = pl.Enum(dtype.categories) else: cast_dict[column] = categories_dtype setattr(self, attr, pl.from_pandas( df.astype(cast_dict), schema_overrides=schema_overrides, include_index=True)) # Get the other fields self.obsm = source._obsm self.varm = source._varm self.obsp = source._obsp self.varp = source._varp self.uns = source._uns else: # Filename check_type(source, 'source', (str, Path), 'a filename, Path, or AnnData object') source = os.path.expanduser(source) if source.endswith('.h5ad'): # For the AnnData on-disk format specification, see # anndata.readthedocs.io/en/latest/fileformat-prose.html if not os.path.exists(source): error_message = f'.h5ad file {source} does not exist' raise FileNotFoundError(error_message) for prop, prop_name in \ (obs, 'obs'), (var, 'var'), (assay, 'assay'): if prop is not None: error_message = ( f'when loading an .h5ad file, {prop_name} ' f'must be None') raise ValueError(error_message) for prop, prop_name in ( (obsm, 'obsm'), (varm, 'varm'), (obsp, 'obsp'), (varp, 'varp'), (uns, 'uns')): if prop is not None and prop is not False: error_message = ( f'when loading an .h5ad file, {prop_name} ' f'must be None or False') raise ValueError(error_message) if obs_columns is not None: obs_columns = to_tuple_checked( obs_columns, 'obs_columns', str, 'strings') if var_columns is not None: var_columns = to_tuple_checked( var_columns, 'var_columns', str, 'strings') # Loading happens in three stages: # 1. Tabulate which HDF5 datasets we need to load (except uns) # 2. "Preload" datasets in parallel, using threads + a custom # Cython reader for most h5ad files and falling back to # multiprocessing + shared memory using h5py for h5ad files # with chunking and/or compression. 
Variable-length # (dtype=object) string datasets will only be preloaded when # using threads, since they are not compatible with shared # memory. Store the preloaded datasets in a cache. # 3. Assemble the SingleCell dataset, loading from the # cache if the dataset is there, and otherwise loading # it single-threaded. uns is always loaded single-threaded. # When loading single-threaded, skip steps 1 and 2. h5ad_file = h5py.File(source) try: if num_threads == 1: preloaded_datasets = {} else: # 1. Tabulate which HDF5 datasets need to be loaded datasets_to_load = set() # `X` if X is False: if X_key is not None: error_message = ( 'when loading an .h5ad file with X=False, ' 'X_key must be None') raise ValueError(error_message) self._X = None else: if X_key is None: has_layers_UMIs = 'layers/UMIs' in h5ad_file has_raw_X = 'raw/X' in h5ad_file if has_layers_UMIs and has_raw_X: error_message = ( "both layers['UMIs'] and raw.X are " "present; this should never happen in " "well-formed .h5ad files") raise ValueError(error_message) X_key = 'layers/UMIs' if has_layers_UMIs \ else 'raw/X' if has_raw_X else 'X' else: check_type(X_key, 'X_key', str, 'a string') if X_key not in h5ad_file: error_message = ( f'X_key {X_key!r} is not present in ' f'the .h5ad file') raise ValueError(error_message) X = h5ad_file[X_key] if isinstance(X, h5py.Dataset): datasets_to_load.add(X.name) else: datasets_to_load.add(X['data'].name) datasets_to_load.add(X['indices'].name) datasets_to_load.add(X['indptr'].name) # `obs` SingleCell._tabulate_h5ad_dataframe( h5ad_file, 'obs', datasets_to_load, columns=obs_columns) # `var` SingleCell._tabulate_h5ad_dataframe( h5ad_file, 'var', datasets_to_load, columns=var_columns) # `obsm` if obsm is None and 'obsm' in h5ad_file: for key, value in h5ad_file['obsm'].items(): if isinstance(value, h5py.Dataset): datasets_to_load.add(f'obsm/{key}') else: SingleCell._tabulate_h5ad_dataframe( h5ad_file, f'obsm/{key}', datasets_to_load) # `varm` if varm is None and 'varm' in 
h5ad_file: for key, value in h5ad_file['varm'].items(): if isinstance(value, h5py.Dataset): datasets_to_load.add(f'varm/{key}') else: SingleCell._tabulate_h5ad_dataframe( h5ad_file, f'varm/{key}', datasets_to_load) # `obsp` if obsp is None and 'obsp' in h5ad_file: for value in h5ad_file['obsp'].values(): datasets_to_load.add(value['data'].name) datasets_to_load.add(value['indices'].name) datasets_to_load.add(value['indptr'].name) # `varp` if varp is None and 'varp' in h5ad_file: for value in h5ad_file['varp'].values(): datasets_to_load.add(value['data'].name) datasets_to_load.add(value['indices'].name) datasets_to_load.add(value['indptr'].name) # Infer the number of cells obs_group = h5ad_file['obs'] if isinstance(obs_group, h5py.Group): num_cells = obs_group[ obs_group.attrs['_index']].shape[0] else: num_cells = obs_group.shape[0] # 2. Preload datasets in parallel # If any fixed-width dataset is chunked or compressed # (indicated by a `None` `file_offset`), fall back to # the multiprocessing-based reader. # Crucial: the HDF5 file must be closed before # `read_parallel_multiprocessing()` to avoid the child # processes inheriting the HDF5 library's state # associated with the open file handle when forking. datasets_to_preload = \ SingleCell._get_datasets_to_preload( h5ad_file, datasets_to_load) if any(dataset[-1] is None for dataset in datasets_to_preload[0]): h5ad_file.close() preloaded_datasets = \ read_parallel_multiprocessing( datasets_to_preload, source, num_threads=num_threads) h5ad_file = h5py.File(source) else: preloaded_datasets = SingleCell._read_parallel( datasets_to_preload, source, num_cells=num_cells, num_threads=num_threads) # 3. Assemble the SingleCell dataset, loading from the # cache if the dataset is there, and otherwise loading # it single-threaded. Always load `uns` single-threaded # for simplicity, since it's rarely that large. 
# Load `X` if X is False: if X_key is not None: error_message = ( 'when loading an .h5ad file with X=False, ' 'X_key must be None') raise ValueError(error_message) self._X = None else: if X_key is None: has_layers_UMIs = \ 'layers/UMIs' in h5ad_file has_raw_X = 'raw/X' in h5ad_file if has_layers_UMIs and has_raw_X: error_message = ( "both layers['UMIs'] and raw.X are " "present; this should never happen in " "well-formed .h5ad files") raise ValueError(error_message) X_key = 'layers/UMIs' if has_layers_UMIs \ else 'raw/X' if has_raw_X else 'X' else: check_type(X_key, 'X_key', str, 'a string') if X_key not in h5ad_file: error_message = ( f'X_key {X_key!r} is not present in ' f'the .h5ad file') raise ValueError(error_message) X = h5ad_file[X_key] if isinstance(X, h5py.Dataset): warning_message = ( f'{X_key!r} is stored as a dense matrix; ' f'auto-converting to a sparse csr_array') warnings.warn(warning_message) X = SingleCell._read_dataset(X, preloaded_datasets) self.X = csr_array(X) else: matrix_class = X.attrs['encoding-type'] \ if 'encoding-type' in X.attrs else \ X.attrs['h5sparse_format'] + '_matrix' if matrix_class == 'csr_matrix': array_class = csr_array elif matrix_class == 'csc_matrix': array_class = csc_array else: error_message = ( f"X has unsupported encoding-type " f"{matrix_class!r}, but should be " f"'csr_matrix' or 'csc_matrix'") raise ValueError(error_message) data = SingleCell._read_dataset( X['data'], preloaded_datasets) self.X = array_class(( data, SingleCell._read_dataset( X['indices'], preloaded_datasets), SingleCell._read_dataset( X['indptr'], preloaded_datasets)), shape=X.attrs['shape'] if 'shape' in X.attrs else X.attrs['h5sparse_shape']) # Load `obs` and `var` self.obs = SingleCell._read_h5ad_dataframe( h5ad_file, 'obs', preloaded_datasets, columns=obs_columns) self.var = SingleCell._read_h5ad_dataframe( h5ad_file, 'var', preloaded_datasets, columns=var_columns) # Load `obsm` if obsm is None and 'obsm' in h5ad_file: self.obsm = { key: 
SingleCell._read_dataset( value, preloaded_datasets) if isinstance(value, h5py.Dataset) else SingleCell._read_h5ad_dataframe( h5ad_file, f'obsm/{key}', preloaded_datasets) for key, value in h5ad_file['obsm'].items()} else: self.obsm = {} # Load `varm` if varm is None and 'varm' in h5ad_file: self.varm = { key: SingleCell._read_dataset( value, preloaded_datasets) if isinstance(value, h5py.Dataset) else SingleCell._read_h5ad_dataframe( h5ad_file, f'varm/{key}', preloaded_datasets) for key, value in h5ad_file['varm'].items()} else: self.varm = {} # Load `obsp` if obsp is None and 'obsp' in h5ad_file: self.obsp = { key: (csr_array if value.attrs['encoding-type'] == 'csr_matrix' else csc_array)( (SingleCell._read_dataset( value['data'], preloaded_datasets), SingleCell._read_dataset( value['indices'], preloaded_datasets), SingleCell._read_dataset( value['indptr'], preloaded_datasets)), shape=value.attrs['shape']) for key, value in h5ad_file['obsp'].items()} else: self.obsp = {} # Load `varp` if varp is None and 'varp' in h5ad_file: self.varp = { key: (csr_array if value.attrs['encoding-type'] == 'csr_matrix' else csc_array)( (SingleCell._read_dataset( value['data'], preloaded_datasets), SingleCell._read_dataset( value['indices'], preloaded_datasets), SingleCell._read_dataset( value['indptr'], preloaded_datasets)), shape=value.attrs['shape']) for key, value in h5ad_file['varp'].items()} else: self.varp = {} # Load `uns` if uns is None and 'uns' in h5ad_file: self.uns = SingleCell._read_uns(h5ad_file['uns']) else: self.uns = {} finally: h5ad_file.close() elif source.endswith('.h5Seurat') or \ source.endswith('.h5seurat'): if not os.path.exists(source): error_message = \ f'.h5Seurat file {source} does not exist' raise FileNotFoundError(error_message) for prop, prop_name in ( (obs, 'obs'), (var, 'var'), (varm, 'varm'), (varp, 'varp')): if prop is not None: error_message = ( f'when loading an .h5Seurat file, {prop_name} ' f'must be None') raise ValueError(error_message) for 
prop, prop_name in \ (obsm, 'obsm'), (obsp, 'obsp'), (uns, 'uns'): if prop is not None and prop is not False: error_message = ( f'when loading an .h5Seurat file, {prop_name} ' f'must be None or False') raise ValueError(error_message) if obs_columns is not None: obs_columns = to_tuple_checked( obs_columns, 'obs_columns', str, 'strings') if var_columns is not None: var_columns = to_tuple_checked( var_columns, 'var_columns', str, 'strings') # The logic here is similar to `.h5ad` files: preload in # parallel where we can h5Seurat_file = h5py.File(source) try: if num_threads == 1: preloaded_datasets = {} else: # Check that `assay` is an assay in the `.h5Seurat` # file, or set it to `active.assay` if `None` if assay is None: assay = \ h5Seurat_file.attrs['active.assay'].item() elif assay not in h5Seurat_file['assays']: error_message = ( f'assay {assay!r} does not exist in the ' f'.h5Seurat file; specify a different ' f'assay than {assay!r}') raise ValueError(error_message) # 1. Tabulate which HDF5 datasets need to be loaded datasets_to_load = set() # `X` assay_group = h5Seurat_file['assays'][assay] if X is False: if X_key is not None: error_message = ( 'when loading an .h5Seurat file with ' 'X=False, X_key must be None') raise ValueError(error_message) self._X = None else: if X_key is None: X_key = 'counts' if X_key not in assay_group: error_message = ( f"the 'counts' key is not present " f"in the .h5Seurat file as part " f"of assay {assay!r}; specify a " f"different assay than {assay!r} " f"or specify X_key as something " f"other than 'counts'") raise ValueError(error_message) else: check_type(X_key, 'X_key', str, 'a string or False') if X_key not in assay_group: error_message = ( f'X_key {X_key!r} is not present ' f'in the .h5Seurat file as part ' f'of assay {assay!r}; specify a ' f'different assay than {assay!r} ' f'or a different X_key than ' f'{X_key!r}') raise ValueError(error_message) X = assay_group[X_key] if isinstance(X, h5py.Dataset): 
datasets_to_load.add(X.name) else: datasets_to_load.add(X['data'].name) datasets_to_load.add(X['indices'].name) datasets_to_load.add(X['indptr'].name) # `obs` SingleCell._tabulate_h5Seurat_dataframe( h5Seurat_file['meta.data'], datasets_to_load, columns=obs_columns) # `var` if 'meta.features' in assay_group: SingleCell._tabulate_h5Seurat_dataframe( assay_group['meta.features'], datasets_to_load, columns=var_columns) # `obsm` if obsm is None: for key, value in \ h5Seurat_file['reductions'].items(): if value.attrs['active.assay'] != assay: continue datasets_to_load.add( f'reductions/{key}/cell.embeddings') # `obsp` if obsp is None: for value in h5Seurat_file['graphs'].values(): if value.attrs['assay.used'] != assay: continue datasets_to_load.add(value['data'].name) datasets_to_load.add(value['indices'].name) datasets_to_load.add(value['indptr'].name) # Infer the number of cells meta_data = h5Seurat_file['meta.data'] num_cells = meta_data['_index'].shape[0] # 2. Preload datasets in parallel datasets_to_preload = \ SingleCell._get_datasets_to_preload( h5Seurat_file, datasets_to_load) if any(dataset[-1] is None for dataset in datasets_to_preload[0]): h5Seurat_file.close() preloaded_datasets = \ read_parallel_multiprocessing( datasets_to_preload, source, num_threads=num_threads) h5Seurat_file = h5py.File(source) else: preloaded_datasets = \ SingleCell._read_parallel( datasets_to_preload, source, num_cells=num_cells, num_threads=num_threads) # 3. 
Assemble the SingleCell dataset # Check that `assay` is an assay in the `.h5Seurat` # file, or set it to `active.assay` if `None` if assay is None: assay = h5Seurat_file.attrs['active.assay'].item() elif assay not in h5Seurat_file['assays']: error_message = ( f'assay {assay!r} does not exist in the ' f'.h5Seurat file; specify a different assay ' f'than {assay!r}') raise ValueError(error_message) # Load `X` assay_group = h5Seurat_file['assays'][assay] if X is False: if X_key is not None: error_message = ( 'when loading an .h5Seurat file with ' 'X=False, X_key must be None') raise ValueError(error_message) self._X = None else: if X_key is None: X_key = 'counts' if X_key not in assay_group: error_message = ( f"the 'counts' key is not present in " f"the .h5Seurat file as part of assay " f"{assay!r}; specify a different " f"assay than {assay!r} or specify " f"X_key as something other than " f"'counts'") raise ValueError(error_message) else: check_type(X_key, 'X_key', str, 'a string or False') if X_key not in assay_group: error_message = ( f'X_key {X_key!r} is not present in ' f'the .h5Seurat file as part of assay ' f'{assay!r}; specify a different ' f'assay than {assay!r} or a different ' f'X_key than {X_key!r}') raise ValueError(error_message) X = assay_group[X_key] if isinstance(X, h5py.Dataset): warning_message = ( f'{X_key!r} is stored as a dense matrix; ' f'auto-converting to a sparse csr_array') warnings.warn(warning_message) X = SingleCell._read_dataset(X, preloaded_datasets) self.X = csr_array(X.T) else: data = SingleCell._read_dataset( X['data'], preloaded_datasets) self.X = csr_array(( data, SingleCell._read_dataset( X['indices'], preloaded_datasets), SingleCell._read_dataset( X['indptr'], preloaded_datasets)), shape=X.attrs['dims'][::-1]) # Load `obs` self.obs = SingleCell._read_h5Seurat_dataframe( h5Seurat_file['meta.data'], preloaded_datasets, columns=obs_columns) # Load `var` self.var = SingleCell._read_h5Seurat_dataframe( assay_group['meta.features'], 
preloaded_datasets, columns=var_columns) \ if 'meta.features' in assay_group else \ pl.Series('feature.names', assay_group['features'][:])\ .cast(pl.String)\ .to_frame() # Load `obsm` self.obsm = { key: value['cell.embeddings'][:].T for key, value in h5Seurat_file['reductions'].items() if value.attrs['active.assay'] == assay} # Load `obsp` self.obsp = { key: csr_array(( SingleCell._read_dataset( value['data'], preloaded_datasets), SingleCell._read_dataset( value['indices'], preloaded_datasets), SingleCell._read_dataset( value['indptr'], preloaded_datasets)), shape=value.attrs['dims'][::-1]) for key, value in h5Seurat_file['graphs'].items() if value.attrs['assay.used'] == assay} # Load `uns` self.uns = SingleCell._read_h5Seurat_uns( h5Seurat_file['misc']) self.varm = {} self.varp = {} finally: h5Seurat_file.close() elif source.endswith('.h5'): if not os.path.exists(source): error_message = f'10x .h5 file {source} does not exist' raise FileNotFoundError(error_message) for prop, prop_name in ( (obs, 'obs'), (var, 'var'), (obsm, 'obsm'), (varm, 'varm'), (obsp, 'obsp'), (varp, 'varp'), (uns, 'uns'), (X_key, 'X_key'), (assay, 'assay'), (obs_columns, 'obs_columns')): if prop is not None: error_message = ( f'when loading a 10x .h5 file, {prop_name} ' f'must be None') raise ValueError(error_message) if var_columns is not None: var_columns = to_tuple_checked( var_columns, 'var_columns', str, 'strings') # The logic here is similar to `.h5ad` files: preload in # parallel where we can h5_file = h5py.File(source) try: if num_threads == 1: preloaded_datasets = {} else: # 1. 
Tabulate which HDF5 datasets need to be loaded datasets_to_load = set() # `X` matrix = h5_file['matrix'] if X is None: datasets_to_load.add(matrix['data'].name) datasets_to_load.add(matrix['indices'].name) datasets_to_load.add(matrix['indptr'].name) elif X is not False: error_message = ( 'when loading a 10x .h5 file, X must be ' 'None or False') raise ValueError(error_message) # `obs` datasets_to_load.add(matrix['barcodes'].name) # `var` features = matrix['features'] if var_columns is not None: for column in var_columns: if column not in features: error_message = ( f'var_columns contains the column ' f'{column!r}, which is not ' f'present in the .h5 file') raise ValueError(error_message) datasets_to_load.add(features[column].name) else: for column in 'name', 'id', 'feature_type', \ 'genome': datasets_to_load.add(features[column].name) for column in 'pattern', 'read', 'sequence': if column in features: datasets_to_load.add( features[column].name) # Infer the number of cells num_cells = matrix['barcodes'].shape[0] # 2. Preload datasets in parallel datasets_to_preload = \ SingleCell._get_datasets_to_preload( h5_file, datasets_to_load) if any(dataset[-1] is None for dataset in datasets_to_preload[0]): h5_file.close() preloaded_datasets = \ read_parallel_multiprocessing( datasets_to_preload, source, num_threads=num_threads) h5_file = h5py.File(source) else: preloaded_datasets = \ SingleCell._read_parallel( datasets_to_preload, source, num_cells=num_cells, num_threads=num_threads) # 3. 
Assemble the SingleCell dataset matrix = h5_file['matrix'] features = matrix['features'] # Load `X` if X is None: self.X = csr_array(( SingleCell._read_dataset( matrix['data'], preloaded_datasets), SingleCell._read_dataset( matrix['indices'], preloaded_datasets), SingleCell._read_dataset( matrix['indptr'], preloaded_datasets)), shape=matrix['shape'][:][::-1]) elif X is False: self._X = None else: error_message = ( 'when loading a 10x .h5 file, X must be None ' 'or False') raise ValueError(error_message) # Load `obs` and `var` self.obs = pl.Series('barcodes', matrix['barcodes'][:])\ .cast(pl.String)\ .to_frame() if var_columns is not None: for column in var_columns: if column not in features: error_message = ( f'var_columns contains the column ' f'{column!r}, which is not present in ' f'the .h5 file') raise ValueError(error_message) else: var_columns = \ ['name', 'id', 'feature_type', 'genome'] + \ [column for column in ('pattern', 'read', 'sequence') if column in features] self.var = pl.DataFrame([ pl.Series(column, SingleCell._read_dataset( features[column], preloaded_datasets)) .cast(pl.String) for column in var_columns]) self.obsm = {} self.varm = {} self.obsp = {} self.varp = {} self.uns = {} finally: h5_file.close() elif source.endswith('.mtx') or source.endswith('.mtx.gz'): if not os.path.exists(source): error_message = f'10x file {source} does not exist' raise FileNotFoundError(error_message) if obs is None: ungzipped_barcode_file = \ f'{os.path.dirname(source)}/barcodes.tsv' if os.path.exists(ungzipped_barcode_file): barcode_file = ungzipped_barcode_file else: gzipped_barcode_file = \ f'{os.path.dirname(source)}/barcodes.tsv.gz' if os.path.exists(gzipped_barcode_file): barcode_file = gzipped_barcode_file else: error_message = ( f'the cell-level metadata file ' f'corresponding to {source} was not found ' f'at either {ungzipped_barcode_file} or ' f'{gzipped_barcode_file}; you can specify ' f'a custom location via the obs argument') raise 
FileNotFoundError(error_message) else: if not isinstance(obs, (str, Path)): error_message = ( f'when loading a 10x .mtx or .mtx.gz file, ' f'obs must be None or the path to a ' f'barcodes.tsv or barcodes.tsv.gz file of ' f'cell-level metadata, but it has type ' f'{type(obs).__name__!r}') raise TypeError(error_message) barcode_file = str(obs) if not os.path.exists(barcode_file): error_message = ( f'the cell-level metadata file {barcode_file} ' f'does not exist') raise FileNotFoundError(error_message) if var is None: ungzipped_feature_file = \ f'{os.path.dirname(source)}/features.tsv' if os.path.exists(ungzipped_feature_file): feature_file = ungzipped_feature_file else: gzipped_feature_file = \ f'{os.path.dirname(source)}/features.tsv.gz' if os.path.exists(gzipped_feature_file): feature_file = gzipped_feature_file else: error_message = ( f'the gene-level metadata file ' f'corresponding to {source} was not found ' f'at either {ungzipped_feature_file} or ' f'{gzipped_feature_file}; you can specify ' f'a custom location via the var argument') raise FileNotFoundError(error_message) else: if not isinstance(var, (str, Path)): error_message = ( f'when loading a 10x .mtx or .mtx.gz file, ' f'var must be None or the path to a ' f'features.tsv or features.tsv.gz file of ' f'gene-level metadata, but it has type ' f'{type(var).__name__!r}') raise TypeError(error_message) feature_file = str(var) if not os.path.exists(feature_file): error_message = ( f'the gene-level metadata file {feature_file} ' f'does not exist') raise FileNotFoundError(error_message) for prop, prop_name in ( (obsm, 'obsm'), (varm, 'varm'), (obsp, 'obsp'), (varp, 'varp'), (uns, 'uns'), (X_key, 'X_key'), (assay, 'assay'), (obs_columns, 'obs_columns'), (var_columns, 'var_columns')): if prop is not None: error_message = ( f'when loading a 10 .mtx or .mtx.gz file, ' f'{prop_name} must be None') raise ValueError(error_message) from scipy.io import mmread # Load `obs` and `var` self.obs = pl.read_csv(barcode_file, 
has_header=False, new_columns=['cell']) self.var = pl.read_csv(feature_file, has_header=False, new_columns=['gene']) # Load `X` if X is None: self.X = csr_array(mmread(source).T.tocsr()) elif X is False: self._X = None else: error_message = ( 'when loading a 10x .mtx or .mtx.gz file, X must ' 'be None or False') raise ValueError(error_message) self.obsm = {} self.varm = {} self.obsp = {} self.varp = {} self.uns = {} elif source.endswith('.rds'): if not os.path.exists(source): error_message = f'.rds file {source} does not exist' raise FileNotFoundError(error_message) for prop, prop_name in ( (X, 'X'), (obs, 'obs'), (var, 'var'), (varm, 'varm'), (obsp, 'obsp'), (varp, 'varp'), (uns, 'uns'), (obs_columns, 'obs_columns'), (var_columns, 'var_columns')): if prop is not None: error_message = ( f'when loading a .rds file, {prop_name} must ' f'be None') raise ValueError(error_message) from ryp import r, to_py, to_r r(f'.SingleCell.object = readRDS({source!r})') try: if X_key is None: X_key = 'counts' else: check_type(X_key, 'X_key', str, 'a string') classes = to_py('class(.SingleCell.object)', squeeze=False) if len(classes) == 1: if classes[0] == 'Seurat': r('suppressPackageStartupMessages(' 'library(SeuratObject))') X, self.obs, self.var, self.obsm, self.obsp, \ self.uns = SingleCell._from_seurat( '.SingleCell.object', assay=assay, layer=X_key, layer_name='X_key') self.varm = {} self.varp = {} elif classes[0] == 'SingleCellExperiment': if assay is not None: error_message = ( f'when loading a SingleCellExperiment ' f'.rds file, assay must be None') raise ValueError(error_message) r('suppressPackageStartupMessages(' 'library(SingleCellExperiment))') X, self.obs, self.var, self.obsm, self.uns = \ SingleCell._from_sce( '.SingleCell.object', assay=X_key, assay_name='X_key') self.obsp = {} self.varm = {} self.varp = {} else: error_message = ( f'the R object loaded from {source} must ' f'be a Seurat or SingleCellExperiment ' f'object, but has class {classes[0]!r}') raise 
TypeError(error_message) self.X = X elif len(classes) == 0: error_message = ( f'the R object loaded from {source} must be a ' f'Seurat or SingleCellExperiment object, but ' f'has no classes') raise TypeError(error_message) else: classes_string = \ ', '.join(f'{c!r}' for c in classes[:-1]) error_message = ( f'the R object loaded from {source} must be a ' f'Seurat object, but has classes ' f'{classes_string} and {classes[-1]!r}') raise TypeError(error_message) finally: r('rm(.SingleCell.object)') else: error_message = ( f'source is a filename with unsupported extension ' f'.{".".join(source.split(".")[1:])}; it must be ' f'.h5ad (AnnData), .h5 or .mtx/.mtx.gz (10x), .rds ' f'(Seurat or SingleCellExperiment), or ' f'.h5Seurat/.h5seurat (Seurat)') raise ValueError(error_message) # Propagate this SingleCell dataset's `num_threads` to its count matrix if self._X is not None: self._X._num_threads = num_threads # Check that each dimension is non-zero and ≤INT32_MAX num_cells = len(self._obs) num_genes = len(self._var) if num_cells == 0: error_message = 'len(obs) is 0: no cells remain' raise ValueError(error_message) if num_genes == 0: error_message = 'len(var) is 0: no genes remain' raise ValueError(error_message) if num_cells > 2_147_483_647: error_message = ( 'X has more than 2,147,483,647 (INT32_MAX) cells, which is ' 'not currently supported') raise ValueError(error_message) if num_genes > 2_147_483_647: error_message = ( 'X has more than 2,147,483,647 (INT32_MAX) genes, which is ' 'not currently supported') raise ValueError(error_message) # Set `uns['QCed']` and `uns['normalized']` to `False` if not set yet; # if already set but not a Boolean, back it up to # `uns['_QCed']`/`uns['_normalized']` for key in 'QCed', 'normalized': if key in self._uns: if not isinstance(self._uns[key], bool): new_key = f'_{key}' while new_key in self._uns: new_key = f'_{new_key}' warning_message = ( f'uns[{key!r}] already exists and is not Boolean; ' f'moving it to uns[{new_key!r}]') 
warnings.warn(warning_message) self._uns[new_key] = self._uns[key] self._uns[key] = False else: self._uns[key] = False
@property
def X(self) -> csr_array | csc_array | None:
    """
    The count matrix, as a sparse array, or `None` if this dataset holds no
    count matrix. It behaves like a Scipy sparse array, but also has a
    `num_threads` attribute which determines the number of threads used for
    operations like subsetting.
    """
    return self._X

@X.setter
def X(self, X: sparse.csr_array | sparse.csc_array | sparse.csr_matrix |
               sparse.csc_matrix) -> None:
    """
    Validate and store a new count matrix.

    Scipy CSR/CSC arrays and matrices are converted to this package's
    `csr_array`/`csc_array`. The new matrix must have the same shape as
    the existing `X` (if any), agree with `len(obs)` and `len(var)` (if
    already assigned), and have a (u)int32/64 or float32/64 data type.

    Args:
        X: the new count matrix

    Raises:
        ValueError: if `X` is `None`, if its shape is inconsistent with
                    the old `X`, `obs` or `var`, or if its `indptr` does
                    not match its shape
        TypeError: if `X` is not a CSR/CSC sparse array or matrix, or has
                   an unsupported data type
    """
    if isinstance(X, (csr_array, csc_array)):
        pass
    elif isinstance(X, (sparse.csr_array, sparse.csr_matrix)):
        X = csr_array(X)
    elif isinstance(X, (sparse.csc_array, sparse.csc_matrix)):
        X = csc_array(X)
    elif X is None:
        error_message = (
            'attempting to set X to None; if you want to remove X to save '
            'memory, use drop_X() instead')
        raise ValueError(error_message)
    else:
        error_message = (
            f'X must be a csr_array, csc_array, csr_matrix, or '
            f'csc_matrix, but has type {type(X).__name__!r}')
        raise TypeError(error_message)
    try:
        if X.shape != self._X.shape:
            error_message = (
                f'new X is {X.shape[0]:,} × {X.shape[1]:,}, but old X is '
                f'{self._X.shape[0]:,} × {self._X.shape[1]:,}')
            raise ValueError(error_message)
    except AttributeError:
        # `self._X` may not have been assigned yet (or may be `None`, which
        # has no `shape`)
        pass
    dtype = X.dtype
    if dtype not in (np.int32, np.int64, np.float32, np.float64,
                     np.uint32, np.uint64):
        error_message = (
            f'X must be (u)int32/64 or float32/64, but has data type '
            f'{str(dtype)}')
        raise TypeError(error_message)
    try:
        if X.shape[0] != len(self._obs):
            error_message = (
                f'len(obs) is {len(self._obs):,}, but X.shape[0] is '
                f'{X.shape[0]:,}')
            raise ValueError(error_message)
    except AttributeError:
        # `self._obs` may not have been assigned yet
        pass
    try:
        if X.shape[1] != len(self._var):
            error_message = (
                f'len(var) is {len(self._var):,}, but X.shape[1] is '
                f'{X.shape[1]:,}')
            raise ValueError(error_message)
    except AttributeError:
        # `self._var` may not have been assigned yet
        pass
    # `indptr` has one entry per row (CSR) or per column (CSC), plus one
    is_csc = int(isinstance(X, csc_array))
    if len(X.indptr) - 1 != X.shape[is_csc]:
        error_message = (
            f'X is corrupted: len(X.indptr) - 1 is '
            f'{len(X.indptr) - 1:,}, but X.shape[{is_csc}] is '
            f'{X.shape[is_csc]:,}')
        raise ValueError(error_message)
    # A matrix without any stored entries is recorded as "no count matrix"
    self._X = X if X.nnz != 0 else None

@property
def obs(self) -> pl.DataFrame:
    """
    A Polars DataFrame of metadata for each cell.
    """
    return self._obs

@obs.setter
def obs(self, obs: pl.DataFrame) -> None:
    """
    Validate and store new cell-level metadata.

    Args:
        obs: the new `obs`; its first column must be String, Enum,
             Categorical, or integer, and its length must match the
             existing `obs` (if any)

    Raises:
        ValueError: if the first column has an unsupported data type, or
                    the length differs from the old `obs`
    """
    check_type(obs, 'obs', pl.DataFrame, 'a polars DataFrame')
    obs_names_dtype = obs[:, 0].dtype
    # `DataType.is_integer()` replaces the deprecated `pl.INTEGER_DTYPES`
    # data type group
    if obs_names_dtype not in (pl.String, pl.Enum, pl.Categorical) and \
            not obs_names_dtype.is_integer():
        error_message = (
            f'the first column of obs ({obs.columns[0]!r}) must be '
            f'String, Enum, Categorical, or integer, but has data type '
            f'{obs_names_dtype.base_type()!r}')
        raise ValueError(error_message)
    try:
        if len(obs) != len(self._obs):
            error_message = (
                f'new obs has length {len(obs):,}, but old obs has length '
                f'{len(self._obs):,}')
            raise ValueError(error_message)
    except AttributeError:
        # `self._obs` may not have been assigned yet
        pass
    self._obs = obs

@property
def var(self) -> pl.DataFrame:
    """
    A Polars DataFrame of metadata for each gene.
    """
    return self._var

@var.setter
def var(self, var: pl.DataFrame) -> None:
    """
    Validate and store new gene-level metadata.

    Args:
        var: the new `var`; its first column must be String, Enum,
             Categorical, or integer, and its length must match the
             existing `var` (if any)

    Raises:
        ValueError: if the first column has an unsupported data type, or
                    the length differs from the old `var`
    """
    check_type(var, 'var', pl.DataFrame, 'a polars DataFrame')
    var_names_dtype = var[:, 0].dtype
    # `DataType.is_integer()` replaces the deprecated `pl.INTEGER_DTYPES`
    # data type group
    if var_names_dtype not in (pl.String, pl.Enum, pl.Categorical) and \
            not var_names_dtype.is_integer():
        error_message = (
            f'the first column of var ({var.columns[0]!r}) must be '
            f'String, Enum, Categorical, or integer, but has data type '
            f'{var_names_dtype.base_type()!r}')
        raise ValueError(error_message)
    try:
        if len(var) != len(self._var):
            error_message = (
                f'new var has length {len(var):,}, but old var has length '
                f'{len(self._var):,}')
            raise ValueError(error_message)
    except AttributeError:
        # `self._var` may not have been assigned yet
        pass
    self._var = var

@property
def obsm(self) -> dict[str, np.ndarray | pl.DataFrame]:
    """
    A dictionary of per-cell multidimensional annotations: 2D NumPy arrays
    (or polars DataFrames), where the length of each array's first
    dimension is the number of cells.
    """
    return self._obsm

@obsm.setter
def obsm(self, obsm: dict[str, np.ndarray | pl.DataFrame]) -> None:
    """
    Wrap a plain dictionary in an `Obsm` container sized to `len(obs)`;
    `obs` must therefore already be assigned.
    """
    check_type(obsm, 'obsm', dict, 'a dictionary')
    self._obsm = Obsm(obsm, length=len(self._obs))

@property
def varm(self) -> dict[str, np.ndarray | pl.DataFrame]:
    """
    A dictionary of per-gene multidimensional annotations: 2D NumPy arrays
    (or polars DataFrames), where the length of each array's first
    dimension is the number of genes.
    """
    return self._varm

@varm.setter
def varm(self, varm: dict[str, np.ndarray | pl.DataFrame]) -> None:
    """
    Wrap a plain dictionary in a `Varm` container sized to `len(var)`;
    `var` must therefore already be assigned.
    """
    check_type(varm, 'varm', dict, 'a dictionary')
    self._varm = Varm(varm, length=len(self._var))

@property
def obsp(self) -> dict[str, csr_array | csc_array]:
    """
    A dictionary of 2D sparse arrays, where the length and width of each
    array is the number of cells. The arrays behave like Scipy sparse
    arrays, but also have a `num_threads` attribute which determines the
    number of threads used for operations like subsetting.
    """
    return self._obsp

@obsp.setter
def obsp(self, obsp: dict[str, sparse.csr_array | sparse.csc_array |
                               sparse.csr_matrix | sparse.csc_matrix]) -> \
        None:
    """
    Wrap a plain dictionary in an `Obsp` container sized to `len(obs)`;
    `obs` must therefore already be assigned.
    """
    check_type(obsp, 'obsp', dict, 'a dictionary')
    self._obsp = Obsp(obsp, length=len(self._obs))

@property
def varp(self) -> dict[str, csr_array | csc_array]:
    """
    A dictionary of 2D sparse arrays, where the length and width of each
    array is the number of genes. The arrays behave like Scipy sparse
    arrays, but also have a `num_threads` attribute which determines the
    number of threads used for operations like subsetting.
    """
    return self._varp

@varp.setter
def varp(self, varp: dict[str, sparse.csr_array | sparse.csc_array |
                               sparse.csr_matrix | sparse.csc_matrix]) -> \
        None:
    """
    Wrap a plain dictionary in a `Varp` container sized to `len(var)`;
    `var` must therefore already be assigned.
    """
    check_type(varp, 'varp', dict, 'a dictionary')
    self._varp = Varp(varp, length=len(self._var))

@property
def uns(self) -> UnsDict:
    """
    A dictionary of miscellaneous metadata. Keys must be strings and values
    may be scalars, NumPy arrays, or other dictionaries.
    """
    return self._uns

@uns.setter
def uns(self, uns: UnsDict) -> None:
    """
    Wrap a plain dictionary in a `Uns` container.
    """
    check_type(uns, 'uns', dict, 'a dictionary')
    self._uns = Uns(uns)

@staticmethod
def _read_uns(uns_group: h5py.Group) -> UnsDict:
    """
    Recursively load `uns` from an `.h5ad` file.

    Args:
        uns_group: `uns` as an `h5py.Group`

    Returns:
        The loaded `uns`. Sub-groups become nested dictionaries, datasets
        whose `shape` is `None` become `None`, object-dtype datasets are
        decoded as strings (an array of strings, or a single string for a
        scalar dataset), and all other datasets are loaded as NumPy arrays
        (or Python scalars for scalar datasets).
    """
    uns = {}
    for key, value in uns_group.items():
        if isinstance(value, h5py.Group):
            uns[key] = SingleCell._read_uns(value)
        elif value.shape is None:
            uns[key] = None
        elif value.dtype == object:
            # An empty `shape` tuple means a scalar dataset
            uns[key] = \
                pl.Series(value[:]).cast(pl.String).to_numpy() \
                if value.shape else value[()].decode('utf-8')
        else:
            uns[key] = value[:] if value.shape else value[()].item()
    return uns

@staticmethod
def _read_h5Seurat_uns(uns_group: h5py.Group) -> UnsDict:
    """
    Recursively load `uns` (i.e. `misc`) from an `.h5Seurat` file.

    Args:
        uns_group: `uns` as an `h5py.Group`

    Returns:
        The loaded `uns`. Unlike `_read_uns()`, a dataset is treated as a
        scalar when it has at most one element rather than when its shape
        is empty.
    """
    uns = {}
    for key, value in uns_group.items():
        if isinstance(value, h5py.Group):
            uns[key] = SingleCell._read_h5Seurat_uns(value)
        elif value.shape is None:
            uns[key] = None
        elif value.dtype == object:
            uns[key] = \
                pl.Series(value[:]).cast(pl.String).to_numpy() \
                if len(value) > 1 else value[:].item().decode('utf-8')
        else:
            uns[key] = value[:] if len(value) > 1 else value[:].item()
    return uns

@staticmethod
def _save_uns(uns: UnsDict,
              uns_group: h5py.Group,
              h5ad_file: h5py.File) -> None:
    """
    Recursively save `uns` to an `.h5ad` file.

    Args:
        uns: an `uns` dictionary
        uns_group: `uns` as an `h5py.Group`
        h5ad_file: an `h5py.File` open in write mode; only forwarded to
                   recursive calls
    """
    uns_group.attrs['encoding-type'] = 'dict'
    uns_group.attrs['encoding-version'] = '0.1.0'
    for key, value in uns.items():
        if isinstance(value, dict):
            SingleCell._save_uns(value, uns_group.create_group(key),
                                 h5ad_file)
        else:
            dataset = uns_group.create_dataset(key, data=value)
            # Arrays are tagged by whether they hold strings (object dtype)
            # or not; everything else is a string or numeric scalar
            dataset.attrs['encoding-type'] = \
                ('string-array' if value.dtype == object else 'array') \
                if isinstance(value, np.ndarray) else \
                'string' if isinstance(value, str) else 'numeric-scalar'
            dataset.attrs['encoding-version'] = '0.2.0'

@staticmethod
def _save_h5Seurat_uns(uns: UnsDict,
                       misc_group: h5py.Group,
                       h5Seurat_file: h5py.File) -> None:
    """
    Recursively save `uns` (i.e. `misc`) to an `.h5Seurat` file. Only
    string values will be saved; nested dictionaries become sub-groups and
    all other values are silently skipped.

    Args:
        uns: an `uns` dictionary
        misc_group: `uns` as an `h5py.Group`
        h5Seurat_file: an `h5py.File` open in write mode; only forwarded to
                       recursive calls
    """
    for key, value in uns.items():
        if isinstance(value, dict):
            SingleCell._save_h5Seurat_uns(
                value, misc_group.create_group(key), h5Seurat_file)
        elif isinstance(value, str):
            misc_group.create_dataset(key, data=value)

@property
def obs_names(self) -> pl.Series:
    """
    A shortcut to access the first column of `obs`. Generally holds cell
    barcodes.
    """
    return self._obs[:, 0]

@property
def var_names(self) -> pl.Series:
    """
    A shortcut to access the first column of `var`. Generally holds gene
    names.
    """
    return self._var[:, 0]
def set_obs_names(self, column: str, /) -> SingleCell:
    """
    Make an existing column of `obs` the first column, i.e. the
    `obs_names`.

    Args:
        column: the name of a column of `obs`; its data type must be
                String, Enum, Categorical, or integer

    Returns:
        A new SingleCell dataset whose `obs` starts with `column`, followed
        by the remaining columns in their existing order. When `column` is
        already the first column, this dataset itself is returned.
    """
    cell_metadata = self._obs
    check_type(column, 'column', str, 'a string')
    # Nothing to reorder: `column` already holds the obs names
    if cell_metadata.columns[0] == column:
        return self
    if column not in cell_metadata:
        error_message = f'{column!r} is not a column of obs'
        raise ValueError(error_message)
    check_dtype(cell_metadata[column], f'obs[{column!r}]',
                (pl.String, pl.Enum, pl.Categorical, 'integer'))
    # Move `column` to the front, keeping every other column in place
    reordered_obs = cell_metadata.select(column, pl.exclude(column))
    return SingleCell(
        X=self._X,
        obs=reordered_obs,
        var=self._var,
        obsm=self._obsm,
        varm=self._varm,
        obsp=self._obsp,
        varp=self._varp,
        uns=self._uns,
        num_threads=self._num_threads)
def set_var_names(self, column: str, /) -> SingleCell:
    """
    Make an existing column of `var` the first column, i.e. the
    `var_names`.

    Args:
        column: the name of a column of `var`; its data type must be
                String, Enum, Categorical, or integer

    Returns:
        A new SingleCell dataset whose `var` starts with `column`, followed
        by the remaining columns in their existing order. When `column` is
        already the first column, this dataset itself is returned.
    """
    gene_metadata = self._var
    check_type(column, 'column', str, 'a string')
    # Nothing to reorder: `column` already holds the var names
    if gene_metadata.columns[0] == column:
        return self
    if column not in gene_metadata:
        error_message = f'{column!r} is not a column of var'
        raise ValueError(error_message)
    check_dtype(gene_metadata[column], f'var[{column!r}]',
                (pl.String, pl.Enum, pl.Categorical, 'integer'))
    # Move `column` to the front, keeping every other column in place
    reordered_var = gene_metadata.select(column, pl.exclude(column))
    return SingleCell(
        X=self._X,
        obs=self._obs,
        var=reordered_var,
        obsm=self._obsm,
        varm=self._varm,
        obsp=self._obsp,
        varp=self._varp,
        uns=self._uns,
        num_threads=self._num_threads)
@property
def num_threads(self) -> int:
    """
    The default number of threads used for this SingleCell dataset's
    operations.
    """
    return self._num_threads

@num_threads.setter
def num_threads(self, num_threads: int | np.integer) -> None:
    """
    Set the default number of threads used for this SingleCell dataset's
    operations. Also sets the number of threads for this SingleCell
    object's count matrix, if present.

    Args:
        num_threads: the new default number of threads. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`.

    Raises:
        ValueError: if `num_threads` is not a positive integer or -1
    """
    # Validation and the -1 -> core-count conversion are shared with the
    # rest of the class rather than re-implemented here; the helper performs
    # the same type check and raises the same error message
    num_threads = self._process_num_threads_static(num_threads)
    self._num_threads = num_threads
    if self._X is not None:
        self._X.num_threads = num_threads
def set_num_threads(self, num_threads: int | np.integer, /) -> SingleCell:
    """
    Build a copy of this SingleCell dataset that uses a different default
    number of threads.

    The count matrix, if present, has its `num_threads` updated as well.
    The same matrix object is then handed to the new dataset.

    Args:
        num_threads: the new default number of threads. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`.

    Returns:
        A new SingleCell dataset with `num_threads` as its default number
        of threads.
    """
    count_matrix = self._X
    if count_matrix is not None:
        count_matrix.num_threads = num_threads
    return SingleCell(
        X=count_matrix,
        obs=self._obs,
        var=self._var,
        obsm=self._obsm,
        varm=self._varm,
        obsp=self._obsp,
        varp=self._varp,
        uns=self._uns,
        num_threads=num_threads)
def _process_num_threads(self,
                         num_threads: int | np.integer | None) -> int:
    """
    Process a `num_threads` value specified by the user as an argument to a
    SingleCell function. Check that `num_threads` is a positive integer, -1
    or `None`; if `None`, set to `self.num_threads`, and if -1, set to
    `os.cpu_count()` (or 1 if the core count cannot be determined).

    Args:
        num_threads: the number of threads specified by the user

    Returns:
        The actual number of threads to use.

    Raises:
        ValueError: if `num_threads` is zero or a negative integer other
                    than -1
    """
    if num_threads is None:
        return self._num_threads
    check_type(num_threads, 'num_threads', int,
               'a positive integer, -1, or None')
    if num_threads == -1:
        # `os.cpu_count()` returns `None` when the count is undeterminable
        return os.cpu_count() or 1
    num_threads = int(num_threads)
    if num_threads <= 0:
        error_message = (
            f'num_threads is {num_threads:,}, but must be a positive '
            f'integer, -1, or None')
        raise ValueError(error_message)
    return num_threads

@staticmethod
def _process_num_threads_static(num_threads: int | np.integer) -> int:
    """
    Process a `num_threads` value specified by the user as an argument to a
    SingleCell function. Check that `num_threads` is a positive integer or
    -1; if -1, set to `os.cpu_count()` (or 1 if the core count cannot be
    determined). Unlike `_process_num_threads()`, there is no SingleCell
    object, so `num_threads` cannot be `None`.

    Args:
        num_threads: the number of threads specified by the user

    Returns:
        The actual number of threads to use.

    Raises:
        ValueError: if `num_threads` is zero or a negative integer other
                    than -1
    """
    check_type(num_threads, 'num_threads', int,
               'a positive integer or -1')
    if num_threads == -1:
        # `os.cpu_count()` returns `None` when the count is undeterminable
        return os.cpu_count() or 1
    num_threads = int(num_threads)
    if num_threads <= 0:
        error_message = (
            f'num_threads is {num_threads:,}, but must be a positive '
            f'integer or -1')
        raise ValueError(error_message)
    return num_threads

@staticmethod
def _get_datasets_to_preload(hdf5_file: h5py.File,
                             datasets_to_load: set[str]) -> \
        tuple[list[tuple[str, tuple[int, ...], np.dtype, int | None]],
              list[tuple[str, tuple[int, ...], np.dtype, int | None]]]:
    """
    Given the names of datasets that need to be loaded, stratify into
    fixed-length datasets (numeric or fixed-length string data) and
    variable-length string datasets (dtype == object). Also get each
    dataset's offset within the hdf5 file.

    Args:
        hdf5_file: an `h5py.File` open in read mode
        datasets_to_load: the set of datasets to be loaded

    Returns:
        Two lists of `(name, shape, dtype, file_offset)` tuples of datasets
        to preload: the first is for fixed-width datasets (numeric and
        fixed-length strings), the second for variable-length string
        datasets. `file_offset` is the byte offset of the dataset's
        contiguous storage in the file, or `None` if the dataset is chunked
        and/or compressed. If `None` for any dataset, we will use the
        multiprocessing-based parallel reader rather than the thread-based
        one.
    """
    fixed_datasets = []
    vlen_datasets = []
    for key in datasets_to_load:
        dataset = hdf5_file[key]
        file_offset = dataset.id.get_offset()
        dataset_info = \
            dataset.name, dataset.shape, dataset.dtype, file_offset
        if dataset.dtype.kind == 'O':
            vlen_datasets.append(dataset_info)
        else:
            fixed_datasets.append(dataset_info)
    return fixed_datasets, vlen_datasets

@staticmethod
def _read_dataset(dataset: h5py.Dataset,
                  preloaded_datasets: dict[str, np.ndarray]) -> \
        np.ndarray:
    """
    Read an HDF5 dataset into a NumPy array, if not already preloaded. Read
    in ~32-MB chunks, to allow KeyboardInterrupts.

    Args:
        dataset: the dataset to read
        preloaded_datasets: a dictionary of preloaded datasets, keyed by
                            dataset name. If `dataset.name` is present,
                            return the preloaded value instead of reading
                            from the HDF5 file.

    Returns:
        The contents of the dataset: the preloaded object if available,
        otherwise a NumPy array with the dataset's shape and dtype (or the
        scalar value, for a dataset with an empty shape).
    """
    dataset_name = dataset.name
    if dataset_name in preloaded_datasets:
        return preloaded_datasets[dataset_name]
    # Fast path for scalar datasets and datasets under 32 MB
    if not dataset.shape:
        return dataset[()]
    elif dataset.size * dataset.dtype.itemsize <= 33_554_432:
        return dataset[:]
    # Calculate how many rows fit into 32 MB, minimum 1
    step = max(1, 33_554_432 // (
        int(np.prod(dataset.shape[1:])) * dataset.dtype.itemsize))
    # Align reading to HDF5 internal chunks
    if dataset.chunks is not None:
        step = max(dataset.chunks[0], (
            step // dataset.chunks[0]) * dataset.chunks[0])
    result = np.empty(dataset.shape, dtype=dataset.dtype)
    for start in range(0, dataset.shape[0], step):
        end = min(start + step, dataset.shape[0])
        # `read_direct()` skips temporary array allocation overhead
        dataset.read_direct(result, np.s_[start:end], np.s_[start:end])
    return result

@staticmethod
def _read_parallel(datasets_to_preload: tuple[
                       list[tuple[str, tuple[int, ...], np.dtype,
                                  int | None]],
                       list[tuple[str, tuple[int, ...], np.dtype,
                                  int | None]]],
                   filename: str,
                   *,
                   num_cells: int,
                   num_threads: int | np.integer) -> \
        dict[str, np.ndarray | pl.Series]:
    """
    Read a sequence of HDF5 datasets into a dictionary of NumPy arrays (or
    polars String Series, for variable-length strings) in parallel.

    Optimized for the common case of unchunked, uncompressed datasets: uses
    `pread()` with OpenMP threads to bypass HDF5's single-threaded I/O.

    Variable-length string datasets are read by parsing HDF5's global heap
    format directly, avoiding both HDF5's global lock and the overhead of
    creating Python string objects. The result is a polars String Series
    backed by Arrow.

    All datasets >= `num_cells` elements are loaded with per-thread byte
    slices, so that Linux's first-touch NUMA policy distributes pages
    across NUMA nodes in the same pattern as downstream `prange` over
    cells. Smaller datasets are loaded whole when they fit within a
    per-thread byte target (between 4 MiB and 64 MiB, rounded to the page
    size), or are otherwise split into row-aligned chunks of at most that
    size; the resulting tasks are distributed among threads via best-fit
    decreasing bin-packing.

    Args:
        datasets_to_preload: two lists of (name, shape, dtype, file_offset)
                             tuples of datasets to preload: the first is
                             for fixed-width datasets, the second for
                             variable-length string datasets. file_offset
                             is None for chunked/compressed datasets.
        filename: the HDF5 filename
        num_cells: the number of cells (observations) in the dataset;
                   datasets of at least `num_cells` in size will be loaded
                   in a NUMA-aware way across all threads, while smaller
                   datasets will undergo chunk-based loading
        num_threads: the number of threads to spawn when reading

    Returns:
        A dictionary mapping dataset names to their contents: NumPy arrays
        for fixed-width datasets, and polars String Series for
        variable-length string datasets.
    """
    fixed_datasets, vlen_datasets = datasets_to_preload
    if len(fixed_datasets) == 0 and len(vlen_datasets) == 0:
        return {}
    result = {}

    # Allocate arrays for fixed-width datasets and separate into large
    # (>= num_cells elements, loaded in a NUMA-aware way across all
    # threads) and small (loaded whole or in chunks, bin-packed among
    # threads)
    large_fixed_arrays = []
    large_fixed_file_offsets = []
    small_fixed_info = []
    total_small_bytes = 0
    for name, shape, dtype, file_offset in fixed_datasets:
        if np.prod(shape) >= num_cells:
            array = numa_zeros(shape, dtype=dtype)
            large_fixed_arrays.append(array)
            large_fixed_file_offsets.append(file_offset)
        else:
            array = np.empty(shape, dtype=dtype)
            rows = shape[0]
            # `np.prod(())` is the float 1.0, so cast to `int` (as
            # `_read_dataset()` does) to keep all of the byte arithmetic
            # below - `range()` steps, bit masks, offsets - in integers
            bytes_per_row = int(np.prod(shape[1:])) * dtype.itemsize
            bytes_per_dataset = rows * bytes_per_row
            small_fixed_info.append((array, file_offset, rows,
                                     bytes_per_row, bytes_per_dataset))
            total_small_bytes += bytes_per_dataset
        result[name] = array
    large_fixed_file_offsets = np.array(large_fixed_file_offsets,
                                        dtype=np.uint64)

    # Bin-pack small fixed-width datasets among threads using a best-fit
    # decreasing (BFD) bin-packing algorithm
    page_size = 16_384 if sys.platform == 'darwin' else 4096
    page_size_minus_1 = page_size - 1
    target = (total_small_bytes + num_threads - 1) // num_threads
    target = min(max(target, 4_194_304), 67_108_864)  # 4 MB to 64 MB
    target = (target // page_size) * page_size or \
        page_size  # multiple of `page_size` bytes
    task_records = []
    for array, file_offset, rows, bytes_per_row, bytes_per_dataset \
            in small_fixed_info:
        destination_base = array.ctypes.data
        if bytes_per_dataset <= target:
            task_records.append((bytes_per_dataset, file_offset,
                                 destination_base))
        else:
            rows_per_chunk = min(max(1, target // bytes_per_row), rows)
            start = 0
            while start < rows:
                end = min(start + rows_per_chunk, rows)
                chunk_bytes = (end - start) * bytes_per_row
                # Since we are doing direct reading, as an optimization,
                # align the chunk size when the file offset is
                # block-aligned. Only shorten the chunk when the aligned
                # size is non-empty and still a whole number of rows; the
                # next chunk then starts exactly where this one ends, so
                # that no bytes are skipped.
                offset = file_offset + start * bytes_per_row
                if offset & page_size_minus_1 == 0:
                    aligned_bytes = chunk_bytes & ~page_size_minus_1
                    if aligned_bytes > 0 and \
                            aligned_bytes % bytes_per_row == 0:
                        chunk_bytes = aligned_bytes
                        end = start + chunk_bytes // bytes_per_row
                task_records.append((
                    chunk_bytes, offset,
                    destination_base + start * bytes_per_row))
                start = end
    task_records.sort(reverse=True)
    tasks_by_thread = [[] for _ in range(num_threads)]
    loads = np.zeros(num_threads, dtype=np.int64)
    for chunk_bytes, chunk_file_offset, chunk_destination in task_records:
        # Best fit: the thread left with the least slack below `target`
        best_slack = None
        for index, load in enumerate(loads):
            slack = target - (load + chunk_bytes)
            if slack >= 0 and (best_slack is None or slack < best_slack):
                best_index = index
                best_slack = slack
        # No thread has room: give the task to the least-loaded thread
        if best_slack is None:
            best_index = loads.argmin()
        tasks_by_thread[best_index].append(
            (chunk_file_offset, chunk_destination, chunk_bytes))
        loads[best_index] += chunk_bytes

    # Flatten per-thread tasks into arrays for Cython
    chunk_file_offsets = []
    chunk_byte_sizes = []
    chunk_destinations = []
    chunk_thread_boundaries = [0]
    for thread in range(num_threads):
        for chunk_file_offset, chunk_destination, chunk_bytes \
                in tasks_by_thread[thread]:
            chunk_file_offsets.append(chunk_file_offset)
            chunk_byte_sizes.append(chunk_bytes)
            chunk_destinations.append(chunk_destination)
        chunk_thread_boundaries.append(len(chunk_file_offsets))
    chunk_file_offsets = np.array(chunk_file_offsets, dtype=np.uint64)
    chunk_byte_sizes = np.array(chunk_byte_sizes, dtype=np.uint64)
    chunk_destinations = np.array(chunk_destinations, dtype=np.uint64)
    chunk_thread_boundaries = \
        np.array(chunk_thread_boundaries, dtype=np.uint64)

    # Prepare for variable-length string loading
    num_vlen_datasets = len(vlen_datasets)
    vlen_num_strings = np.empty(num_vlen_datasets, dtype=np.uint64)
    vlen_file_offsets = np.empty(num_vlen_datasets, dtype=np.uint64)
    if num_vlen_datasets > 0:
        with h5py.File(filename) as hdf5_file:
            offset_size, length_size = \
                hdf5_file.id.get_create_plist().get_sizes()
        for index, (name, shape, dtype, file_offset) in \
                enumerate(vlen_datasets):
            vlen_num_strings[index] = np.prod(shape)
            vlen_file_offsets[index] = file_offset
    else:
        offset_size = length_size = 0

    # Open the file and read the datasets. Direct reads (`O_DIRECT`) bypass
    # the page cache, giving a several-fold speedup on NFS filesystems.
    # However, certain file systems (Lustre < 2.16, ext4, xfs, etc.) only
    # support direct reads that are aligned to a multiple of e.g. 4096
    # bytes. Thus, we first probe the filesystem with a deliberately
    # unaligned direct read, and use `O_DIRECT` only if the probe succeeds,
    # falling back to buffered reading if it fails.
    fd = -1
    direct = False
    try:
        if hasattr(os, 'O_DIRECT'):
            try:
                candidate = os.open(filename, os.O_RDONLY | os.O_DIRECT)
            except OSError:
                # Some filesystems (tmpfs, some overlay/FUSE mounts) reject
                # O_DIRECT at open() time
                candidate = -1
            if candidate != -1:
                # Probe: force all three O_DIRECT alignment axes to be
                # unaligned at once - file offset (1), length (1), and
                # destination buffer. An mmap region is page-aligned, so
                # reading into buf[1:2] guarantees an unaligned memory
                # address.
                probe_buffer = mmap.mmap(-1, 4096)
                try:
                    os.preadv(candidate,
                              [memoryview(probe_buffer)[1:2]], 1)
                    fd = candidate
                    direct = True
                except OSError:
                    os.close(candidate)
                finally:
                    probe_buffer.close()
        if fd == -1:
            fd = os.open(filename, os.O_RDONLY)
            if sys.platform == 'darwin':
                # macOS equivalent of `O_DIRECT`; no alignment constraints
                import fcntl
                fcntl.fcntl(fd, fcntl.F_NOCACHE, 1)
        if not direct and hasattr(os, 'posix_fadvise'):
            # Buffered fallback path only: `posix_fadvise()` is a
            # page-cache hint and is a no-op under O_DIRECT. Use
            # `POSIX_FADV_SEQUENTIAL` because the gain on the large
            # fixed-width datasets likely outweighs the cost on the small
            # scattered vlen header reads.
            os.posix_fadvise(fd, 0, 0, os.POSIX_FADV_SEQUENTIAL)
        vlen_offsets_out, vlen_data_out = read_all_datasets(
            fd, large_fixed_arrays, large_fixed_file_offsets,
            chunk_file_offsets, chunk_byte_sizes, chunk_destinations,
            chunk_thread_boundaries, vlen_file_offsets, vlen_num_strings,
            length_size, offset_size, num_threads)
    finally:
        if fd != -1:
            os.close(fd)

    # Wrap each variable-length string dataset's offsets and bytes in an
    # Arrow large-string array, and from there a polars Series
    for dataset_index, dataset in enumerate(vlen_datasets):
        output_offsets = vlen_offsets_out[dataset_index]
        output_data = vlen_data_out[dataset_index]
        num_strings = vlen_num_strings[dataset_index]
        total_string_bytes = output_offsets[num_strings]
        name = dataset[0]
        result[name] = pl.from_arrow(
            pa.LargeStringArray.from_buffers(
                num_strings, pa.py_buffer(output_offsets),
                pa.py_buffer(output_data[:total_string_bytes])))
    return result

@staticmethod
def _tabulate_h5ad_dataframe(
        h5ad_file: h5py.File,
        key: str,
        datasets_to_load: set[str],
        *,
        columns: str | Sequence[str] | None = None) -> None:
    """
    Add to `datasets_to_load` the datasets in an `.h5ad` file required to
    load `obs` or `var` (or dataframe keys of `obsm` or `varm`), or the
    requested columns thereof if `obs_columns` or `var_columns` was
    specified.

    Args:
        h5ad_file: an `h5py.File` open in read mode
        key: the key to be loaded as a DataFrame, e.g. `'obs'` or `'var'`
        datasets_to_load: the set of required datasets, which will be
                          modified by this function
        columns: the column(s) of the DataFrame requested; the index column
                 is always loaded as the first column, regardless of
                 whether it is specified here
    """
    # Get the group corresponding to `obs` or `var`
    group = h5ad_file[key]
    # Special case: the entire `obs` or `var` may rarely be a single NumPy
    # structured array (`dtype=void`)
    if isinstance(group, h5py.Dataset) and \
            np.issubdtype(group.dtype, np.void):
        datasets_to_load.add(key)
        return
    # Special case: the entire `obs` or `var` may rarely be a single,
    # unnested column.
unnested = group.attrs['encoding-type'] != 'dataframe' # Get the list of which columns to load, in which order if not unnested: if columns is None: columns = group.attrs['column-order'] else: columns = [column for column in to_tuple(columns) if column != group.attrs['_index']] for column in columns: if column not in group.attrs['column-order']: error_message = f'{column!r} is not a column of {key}' raise ValueError(error_message) # For each column, enumerate the dataset(s) required to load it for column, value in ((column, group[column]) for column in chain((group.attrs['_index'],), columns)) \ if not unnested else (('_index', group),): encoding_type = value.attrs.get('encoding-type') if encoding_type == 'categorical' or ( isinstance(value, h5py.Group) and all( key == 'categories' or key == 'codes' for key in value.keys())) or 'categories' in value.attrs: # Sometimes, the categories are stored in a different place # which is pointed to by `value.attrs['categories']` if 'categories' in value.attrs: category_object = h5ad_file[value.attrs['categories']] category_encoding_type = None datasets_to_load.add(value.name) else: category_object = value['categories'] category_encoding_type = \ category_object.attrs.get('encoding-type') datasets_to_load.add(value['codes'].name) # Sometimes, the categories are themselves nullable # integer or Boolean arrays if category_encoding_type == 'nullable-integer' or \ category_encoding_type == 'nullable-boolean' or ( isinstance(category_object, h5py.Group) and all( key == 'values' or key == 'mask' for key in category_object.keys())): datasets_to_load.add(category_object['values'].name) datasets_to_load.add(category_object['mask'].name) else: datasets_to_load.add(category_object.name) elif encoding_type == 'nullable-integer' or \ encoding_type == 'nullable-boolean' or ( isinstance(value, h5py.Group) and all( key == 'values' or key == 'mask' for key in value.keys())): datasets_to_load.add(value['values'].name) 
datasets_to_load.add(value['mask'].name) elif encoding_type == 'array' or \ encoding_type == 'string-array' or \ isinstance(value, h5py.Dataset): datasets_to_load.add(value.name) else: encoding = f'encoding-type {encoding_type!r}' \ if encoding_type is not None else 'encoding' if unnested: error_message = f'{key!r} has unsupported {encoding}' raise ValueError(error_message) else: error_message = ( f'{column!r} column of {key!r} has unsupported ' f'{encoding}') raise ValueError(error_message) @staticmethod def _read_h5ad_dataframe(h5ad_file: h5py.File, key: str, preloaded_datasets: dict[str, np.ndarray], *, columns: str | Sequence[str] | None = None) \ -> pl.DataFrame: """ Load `obs` or `var` from an `.h5ad` file as a polars DataFrame. Enum casts and Binary-to-String casts are deferred and applied in a single lazy pass, allowing polars to parallelize them across columns. Args: h5ad_file: an `h5py.File` open in read mode key: the key to load as a DataFrame, e.g. `'obs'` or `'var'` preloaded_datasets: a dictionary of preloaded datasets columns: the column(s) of the DataFrame to load; the index column is always loaded as the first column, regardless of whether it is specified here, and then the remaining columns are loaded in the order specified Returns: A polars DataFrame of the data in `h5ad_file[key]`. """ # Get the group corresponding to `obs` or `var` group = h5ad_file[key] # Special case: the entire `obs` or `var` may rarely be a single NumPy # structured array (`dtype=void`) if isinstance(group, h5py.Dataset) and \ np.issubdtype(group.dtype, np.void): data = pl.from_numpy(group[:]) data = data.with_columns(pl.col(pl.Binary).cast(pl.String)) return data # Special case: the entire `obs` or `var` may rarely be a single, # unnested column. 
unnested = group.attrs['encoding-type'] != 'dataframe' # Get the list of which columns to load, in which order if not unnested: if columns is None: columns = group.attrs['column-order'] else: columns = [column for column in to_tuple(columns) if column != group.attrs['_index']] for column in columns: if column not in group.attrs['column-order']: error_message = \ f'{column!r} is not a column of {key}' raise ValueError(error_message) # Create the DataFrame. Enum casts are deferred and applied in a # single lazy pass at the end, so that polars can parallelize them # across columns. data = {} deferred_casts = [] for column, value in ((column, group[column]) for column in chain((group.attrs['_index'],), columns)) \ if not unnested else (('_index', group),): encoding_type = value.attrs.get('encoding-type') if encoding_type == 'categorical' or ( isinstance(value, h5py.Group) and all( key == 'categories' or key == 'codes' for key in value.keys())) or 'categories' in value.attrs: # Sometimes, the categories are stored in a different place # which is pointed to by `value.attrs['categories']` if 'categories' in value.attrs: category_object = h5ad_file[value.attrs['categories']] category_encoding_type = None codes = SingleCell._read_dataset( value, preloaded_datasets) else: category_object = value['categories'] category_encoding_type = \ category_object.attrs.get('encoding-type') codes = SingleCell._read_dataset( value['codes'], preloaded_datasets) # Sometimes, the categories are themselves nullable # integer or Boolean arrays if category_encoding_type == 'nullable-integer' or \ category_encoding_type == 'nullable-boolean' or ( isinstance(category_object, h5py.Group) and all( key == 'values' or key == 'mask' for key in category_object.keys())): data[column] = pl.Series(SingleCell._read_dataset( category_object['values'], preloaded_datasets)[codes]) mask = pl.Series(SingleCell._read_dataset( category_object['mask'], preloaded_datasets)[codes] | (codes == -1)) has_missing = 
mask.any() if has_missing: data[column] = data[column].set(mask, None) continue categories = SingleCell._read_dataset( category_object, preloaded_datasets) mask = pl.Series(codes == -1) has_missing = mask.any() # polars does not (as of version 1.0) support Categoricals # or Enums with non-string categories, so if the categories # are not strings, just map the codes to the categories. if category_encoding_type == 'array' or ( isinstance(category_object, h5py.Dataset) and category_object.dtype != object): data[column] = pl.Series(categories[codes], nan_to_null=True) if has_missing: data[column] = data[column].set(mask, None) elif category_encoding_type == 'string-array' or ( isinstance(category_object, h5py.Dataset) and category_object.dtype == object): if has_missing: codes[mask] = 0 data[column] = pl.Series(codes, dtype=pl.UInt32) if has_missing: data[column] = data[column].set(mask, None) deferred_casts.append(pl.col(column).cast( pl.Enum(pl.Series(categories).cast(pl.String)))) else: encoding = \ f'encoding-type {category_encoding_type!r}' \ if category_encoding_type is not None \ else 'encoding' error_message = ( f'{column!r} column of {key!r} is a categorical ' f'with unsupported {encoding}') raise ValueError(error_message) elif encoding_type == 'nullable-integer' or \ encoding_type == 'nullable-boolean' or ( isinstance(value, h5py.Group) and all( key == 'values' or key == 'mask' for key in value.keys())): values = SingleCell._read_dataset( value['values'], preloaded_datasets) mask = SingleCell._read_dataset( value['mask'], preloaded_datasets) data[column] = pl.Series(values).set( pl.Series(mask), None) elif encoding_type == 'array' or ( isinstance(value, h5py.Dataset) and value.dtype != object): data[column] = pl.Series(SingleCell._read_dataset( value, preloaded_datasets), nan_to_null=True) elif encoding_type == 'string-array' or ( isinstance(value, h5py.Dataset) and value.dtype == object): data[column] = \ SingleCell._read_dataset(value, preloaded_datasets) 
else: encoding = f'encoding-type {encoding_type!r}' \ if encoding_type is not None else 'encoding' error_message = ( f'{column!r} column of {key!r} has unsupported ' f'{encoding}') raise ValueError(error_message) data = pl.DataFrame(data) # Apply deferred Enum casts and Binary-to-String casts in one pass, # allowing polars to parallelize across columns deferred_casts.append(pl.col(pl.Binary).cast(pl.String)) data = data.with_columns(deferred_casts) return data @staticmethod def _tabulate_h5Seurat_dataframe(group: h5py.Group, datasets_to_load: set[str], *, columns: str | Sequence[str] | None = None) -> None: """ Add to `datasets_to_load` the datasets in an `.h5ad` file required to load `obs` or `var` (or dataframe keys of `obsm` or `varm`), or the requested columns thereof if `obs_columns` or `var_columns` was specified. Args: group: the group to load as a DataFrame, e.g. `'meta.data'` or `'meta.features'` datasets_to_load: the set of datasets that will eventually need to be loaded; will be modified by this function columns: the column(s) of the DataFrame requested; the index column is always loaded as the first column, regardless of whether it is specified here """ # Get the list of which columns to load, in which order if columns is None: columns = group.attrs['colnames'] else: columns = [column for column in to_tuple(columns) if column != group.attrs['_index']] for column in columns: if column not in group.attrs['colnames']: error_message = f'{column!r} is not a column of meta.data' raise ValueError(error_message) # For each column, enumerate the dataset(s) required to load it for column in chain(group.attrs['_index'], columns): value = group[column] if isinstance(value, h5py.Group): # Factor if len(value) != 2 or 'levels' not in value or \ 'values' not in value: error_message = ( f"the h5Seurat file's meta.data contains a group " f"of unknown format, ") if len(value) <= 5: error_message += ( f'with the following keys: ' f'{", ".join(map(repr, value))}') else: 
error_message += ( f'with {len(value):,} keys, including the ' f'following: ' f'{", ".join(map(repr, islice(value, 5)))}') raise ValueError(error_message) datasets_to_load.add(value['values'].name) datasets_to_load.add(value['levels'].name) else: datasets_to_load.add(value.name) @staticmethod def _read_h5Seurat_dataframe(group: h5py.Group, preloaded_datasets: dict[ str, np.ndarray], *, columns: str | Sequence[str] | None = None) -> pl.DataFrame: """ Load `meta.data` (i.e. `obs`) or `meta.features` (i.e. `var`) from an `.h5Seurat` file as a polars DataFrame. Args: group: the group to load as a DataFrame, e.g. `'meta.data'` or `'meta.features'` preloaded_datasets: a dictionary of preloaded datasets columns: the column(s) of the DataFrame to load; the index column is always loaded as the first column, regardless of whether it is specified here, and then the remaining columns are loaded in the order specified Returns: A polars DataFrame of the data in `h5Seurat_file[key]`. """ # Get the list of which columns to load, in which order if columns is None: columns = group.attrs['colnames'] else: columns = [column for column in to_tuple(columns) if column != group.attrs['_index']] for column in columns: if column not in group.attrs['colnames']: error_message = f'{column!r} is not a column of meta.data' raise ValueError(error_message) # Get the list of which columns are Boolean Boolean_columns = group.attrs.get('logicals', ()) # Create the DataFrame data = {} for column in chain(group.attrs['_index'], columns): value = group[column] if isinstance(value, h5py.Group): # Factor if len(value) != 2 or 'levels' not in value or \ 'values' not in value: error_message = ( f"the h5Seurat file's meta.data contains a group " f"of unknown format, ") if len(value) <= 5: error_message += ( f'with the following keys: ' f'{", ".join(map(repr, value))}') else: error_message += ( f'with {len(value):,} keys, including the ' f'following: ' f'{", ".join(map(repr, islice(value, 5)))}') raise 
ValueError(error_message) values = SingleCell._read_dataset( value['values'], preloaded_datasets) levels = value['levels'][:] data[column] = (pl.Series(values) - 1)\ .cast(pl.Enum(pl.Series(levels).cast(pl.String))) else: data[column] = pl.Series(SingleCell._read_dataset( value, preloaded_datasets), nan_to_null=True) if column in Boolean_columns: data[column] = data[column]\ .replace({2: None})\ .cast(pl.Boolean) elif data[column].dtype == pl.Int32: data[column] = data[column].replace({-2147483648: None}) data = pl.DataFrame(data) # NumPy doesn't support encoding object-dtyped string arrays as UTF-8, # so do the conversion in polars instead. Also convert `b'NA'` # (h5Seurat's missing value indicator) to `null`. data = data.with_columns(pl.col(pl.Binary).replace({b'NA': None}) .cast(pl.String)) return data
@staticmethod
def read_obs(h5ad_file: str | Path,
             *,
             columns: str | Iterable[str] | None = None,
             num_threads: int | np.integer = -1) -> pl.DataFrame:
    """
    Load just `obs` from an `.h5ad` file as a polars DataFrame.

    Args:
        h5ad_file: an `.h5ad` filename; a leading `~` is expanded to the
                   user's home directory
        columns: the column(s) of `obs` to load; if `None`, load all
                 columns
        num_threads: the number of threads to use when reading. By default
                     (`num_threads=-1`), use all available cores, as
                     determined by `os.cpu_count()`.

    Returns:
        A polars DataFrame of the data in `obs`.

    Raises:
        ValueError: if `h5ad_file` does not end with `'.h5ad'`
        FileNotFoundError: if `h5ad_file` does not exist
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    h5ad_file = str(h5ad_file)
    if not h5ad_file.endswith('.h5ad'):
        error_message = f".h5ad file {h5ad_file!r} must end with '.h5ad'"
        raise ValueError(error_message)
    # Every open below must use the expanded `filename` (not `h5ad_file`):
    # neither h5py nor `os.open()` expands `~`
    filename = os.path.expanduser(h5ad_file)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {h5ad_file} does not exist'
        raise FileNotFoundError(error_message)
    if columns is not None:
        columns = to_tuple_checked(columns, 'columns', str, 'strings')

    # Check that `num_threads` is a positive integer or -1; if -1, set to
    # `os.cpu_count()`
    num_threads = SingleCell._process_num_threads_static(num_threads)

    # Load `obs`, using a similar approach to loading the full `.h5ad` file
    preloaded_datasets = {}
    use_multiprocessing = False
    if num_threads != 1:
        with h5py.File(filename) as f:
            # 1. Tabulate which HDF5 datasets need to be loaded
            datasets_to_load = set()
            SingleCell._tabulate_h5ad_dataframe(
                f, 'obs', datasets_to_load, columns=columns)

            # 2. Preload datasets in parallel. If any fixed-width dataset
            # is chunked or compressed (indicated by a `None`
            # `file_offset`), fall back to the multiprocessing-based
            # reader, which runs only after this `with` block has closed
            # the file.
            datasets_to_preload = SingleCell._get_datasets_to_preload(
                f, datasets_to_load)
            use_multiprocessing = any(
                dataset[-1] is None for dataset in datasets_to_preload[0])
            if not use_multiprocessing:
                preloaded_datasets = SingleCell._read_parallel(
                    datasets_to_preload, filename,
                    num_threads=num_threads)
        if use_multiprocessing:
            # Crucial: the HDF5 file must be closed before
            # `read_parallel_multiprocessing()` to avoid the child
            # processes inheriting the HDF5 library's state associated
            # with the open file handle when forking.
            preloaded_datasets = read_parallel_multiprocessing(
                datasets_to_preload, filename, num_threads=num_threads)

    # 3. Assemble `obs`
    with h5py.File(filename) as f:
        return SingleCell._read_h5ad_dataframe(
            f, 'obs', preloaded_datasets, columns=columns)
@staticmethod
def read_var(h5ad_file: str | Path,
             *,
             columns: str | Iterable[str] | None = None,
             num_threads: int | np.integer = -1) -> pl.DataFrame:
    """
    Load just `var` from an `.h5ad` file as a polars DataFrame.

    Args:
        h5ad_file: an `.h5ad` filename; a leading `~` is expanded to the
                   user's home directory
        columns: the column(s) of `var` to load; if `None`, load all
                 columns
        num_threads: the number of threads to use when reading. By default
                     (`num_threads=-1`), use all available cores, as
                     determined by `os.cpu_count()`.

    Returns:
        A polars DataFrame of the data in `var`.

    Raises:
        ValueError: if `h5ad_file` does not end with `'.h5ad'`
        FileNotFoundError: if `h5ad_file` does not exist
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    h5ad_file = str(h5ad_file)
    if not h5ad_file.endswith('.h5ad'):
        error_message = f".h5ad file {h5ad_file!r} must end with '.h5ad'"
        raise ValueError(error_message)
    # Every open below must use the expanded `filename` (not `h5ad_file`):
    # neither h5py nor `os.open()` expands `~`
    filename = os.path.expanduser(h5ad_file)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {h5ad_file} does not exist'
        raise FileNotFoundError(error_message)
    if columns is not None:
        columns = to_tuple_checked(columns, 'columns', str, 'strings')

    # Check that `num_threads` is a positive integer or -1; if -1, set to
    # `os.cpu_count()`
    num_threads = SingleCell._process_num_threads_static(num_threads)

    # Load `var`, using a similar approach to loading the full `.h5ad` file
    preloaded_datasets = {}
    use_multiprocessing = False
    if num_threads != 1:
        with h5py.File(filename) as f:
            # 1. Tabulate which HDF5 datasets need to be loaded
            datasets_to_load = set()
            SingleCell._tabulate_h5ad_dataframe(
                f, 'var', datasets_to_load, columns=columns)

            # 2. Preload datasets in parallel. If any fixed-width dataset
            # is chunked or compressed (indicated by a `None`
            # `file_offset`), fall back to the multiprocessing-based
            # reader, which runs only after this `with` block has closed
            # the file.
            datasets_to_preload = SingleCell._get_datasets_to_preload(
                f, datasets_to_load)
            use_multiprocessing = any(
                dataset[-1] is None for dataset in datasets_to_preload[0])
            if not use_multiprocessing:
                preloaded_datasets = SingleCell._read_parallel(
                    datasets_to_preload, filename,
                    num_threads=num_threads)
        if use_multiprocessing:
            # Crucial: the HDF5 file must be closed before
            # `read_parallel_multiprocessing()` to avoid the child
            # processes inheriting the HDF5 library's state associated
            # with the open file handle when forking.
            preloaded_datasets = read_parallel_multiprocessing(
                datasets_to_preload, filename, num_threads=num_threads)

    # 3. Assemble `var`
    with h5py.File(filename) as f:
        return SingleCell._read_h5ad_dataframe(
            f, 'var', preloaded_datasets, columns=columns)
@staticmethod
def read_obsm(h5ad_file: str | Path,
              *,
              keys: str | Iterable[str] | None = None,
              num_threads: int | np.integer = -1) -> \
        dict[str, np.ndarray | pl.DataFrame]:
    """
    Load just `obsm` from an `.h5ad` file as a dictionary of NumPy arrays
    or DataFrames.

    Args:
        h5ad_file: an `.h5ad` filename; a leading `~` is expanded to the
                   user's home directory
        keys: the key(s) of `obsm` to load; if `None`, load all keys
        num_threads: the number of threads to use when reading. By default
                     (`num_threads=-1`), use all available cores, as
                     determined by `os.cpu_count()`.

    Returns:
        A dictionary of NumPy arrays and polars DataFrames of the data in
        `obsm`. When `keys` is specified, only those keys are returned, in
        the order given.

    Raises:
        ValueError: if `h5ad_file` does not end with `'.h5ad'`, if `keys`
                    was specified but the file has no `obsm`, or if any
                    element of `keys` is not a key of `obsm`
        FileNotFoundError: if `h5ad_file` does not exist
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    h5ad_file = str(h5ad_file)
    if not h5ad_file.endswith('.h5ad'):
        error_message = f".h5ad file {h5ad_file!r} must end with '.h5ad'"
        raise ValueError(error_message)
    # Every open below must use the expanded `filename` (not `h5ad_file`):
    # neither h5py nor `os.open()` expands `~`
    filename = os.path.expanduser(h5ad_file)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {h5ad_file} does not exist'
        raise FileNotFoundError(error_message)
    if keys is not None:
        keys = to_tuple_checked(keys, 'keys', str, 'strings')

    # Check that `num_threads` is a positive integer or -1; if -1, set to
    # `os.cpu_count()`
    num_threads = SingleCell._process_num_threads_static(num_threads)

    # Load `obsm`, using a similar approach to loading the full `.h5ad`
    # file
    preloaded_datasets = {}
    use_multiprocessing = False
    with h5py.File(filename) as f:
        # Resolve and validate the keys up front, regardless of
        # `num_threads`, so that both code paths behave identically
        if 'obsm' not in f:
            if keys is not None:
                error_message = 'keys was specified, but obsm is empty'
                raise ValueError(error_message)
            return {}
        obsm = f['obsm']
        if keys is None:
            keys = tuple(obsm)
        else:
            for key_index, key in enumerate(keys):
                if key not in obsm:
                    error_message = (
                        f'keys[{key_index}] is {key!r}, which is not a '
                        f'key of obsm')
                    raise ValueError(error_message)

        if num_threads != 1:
            # 1. Tabulate which HDF5 datasets need to be loaded
            datasets_to_load = set()
            for key in keys:
                if isinstance(obsm[key], h5py.Dataset):
                    datasets_to_load.add(f'obsm/{key}')
                else:
                    SingleCell._tabulate_h5ad_dataframe(
                        f, f'obsm/{key}', datasets_to_load)

            # 2. Preload datasets in parallel. If any fixed-width dataset
            # is chunked or compressed (indicated by a `None`
            # `file_offset`), fall back to the multiprocessing-based
            # reader, which runs only after this `with` block has closed
            # the file.
            datasets_to_preload = SingleCell._get_datasets_to_preload(
                f, datasets_to_load)
            use_multiprocessing = any(
                dataset[-1] is None for dataset in datasets_to_preload[0])
            if not use_multiprocessing:
                preloaded_datasets = SingleCell._read_parallel(
                    datasets_to_preload, filename,
                    num_threads=num_threads)
    if use_multiprocessing:
        # Crucial: the HDF5 file must be closed before
        # `read_parallel_multiprocessing()` to avoid the child processes
        # inheriting the HDF5 library's state associated with the open
        # file handle when forking.
        preloaded_datasets = read_parallel_multiprocessing(
            datasets_to_preload, filename, num_threads=num_threads)

    # 3. Assemble `obsm` from the requested keys only
    result = {}
    with h5py.File(filename) as f:
        obsm = f['obsm']
        for key in keys:
            value = obsm[key]
            if isinstance(value, h5py.Dataset):
                result[key] = SingleCell._read_dataset(
                    value, preloaded_datasets)
            else:
                result[key] = SingleCell._read_h5ad_dataframe(
                    f, f'obsm/{key}', preloaded_datasets)
    return result
@staticmethod
def read_varm(h5ad_file: str | Path,
              *,
              keys: str | Iterable[str] | None = None,
              num_threads: int | np.integer = -1) -> \
        dict[str, np.ndarray | pl.DataFrame]:
    """
    Load just `varm` from an `.h5ad` file as a dictionary of NumPy arrays
    or DataFrames.

    Args:
        h5ad_file: an `.h5ad` filename; a leading `~` is expanded to the
                   user's home directory
        keys: the key(s) of `varm` to load; if `None`, load all keys
        num_threads: the number of threads to use when reading. By default
                     (`num_threads=-1`), use all available cores, as
                     determined by `os.cpu_count()`.

    Returns:
        A dictionary of NumPy arrays and polars DataFrames of the data in
        `varm`. When `keys` is specified, only those keys are returned, in
        the order given.

    Raises:
        ValueError: if `h5ad_file` does not end with `'.h5ad'`, if `keys`
                    was specified but the file has no `varm`, or if any
                    element of `keys` is not a key of `varm`
        FileNotFoundError: if `h5ad_file` does not exist
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    h5ad_file = str(h5ad_file)
    if not h5ad_file.endswith('.h5ad'):
        error_message = f".h5ad file {h5ad_file!r} must end with '.h5ad'"
        raise ValueError(error_message)
    # Every open below must use the expanded `filename` (not `h5ad_file`):
    # neither h5py nor `os.open()` expands `~`
    filename = os.path.expanduser(h5ad_file)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {h5ad_file} does not exist'
        raise FileNotFoundError(error_message)
    if keys is not None:
        keys = to_tuple_checked(keys, 'keys', str, 'strings')

    # Check that `num_threads` is a positive integer or -1; if -1, set to
    # `os.cpu_count()`
    num_threads = SingleCell._process_num_threads_static(num_threads)

    # Load `varm`, using a similar approach to loading the full `.h5ad`
    # file
    preloaded_datasets = {}
    use_multiprocessing = False
    with h5py.File(filename) as f:
        # Resolve and validate the keys up front, regardless of
        # `num_threads`, so that both code paths behave identically
        if 'varm' not in f:
            if keys is not None:
                error_message = 'keys was specified, but varm is empty'
                raise ValueError(error_message)
            return {}
        varm = f['varm']
        if keys is None:
            keys = tuple(varm)
        else:
            for key_index, key in enumerate(keys):
                if key not in varm:
                    error_message = (
                        f'keys[{key_index}] is {key!r}, which is not a '
                        f'key of varm')
                    raise ValueError(error_message)

        if num_threads != 1:
            # 1. Tabulate which HDF5 datasets need to be loaded
            datasets_to_load = set()
            for key in keys:
                if isinstance(varm[key], h5py.Dataset):
                    datasets_to_load.add(f'varm/{key}')
                else:
                    SingleCell._tabulate_h5ad_dataframe(
                        f, f'varm/{key}', datasets_to_load)

            # 2. Preload datasets in parallel. If any fixed-width dataset
            # is chunked or compressed (indicated by a `None`
            # `file_offset`), fall back to the multiprocessing-based
            # reader, which runs only after this `with` block has closed
            # the file.
            datasets_to_preload = SingleCell._get_datasets_to_preload(
                f, datasets_to_load)
            use_multiprocessing = any(
                dataset[-1] is None for dataset in datasets_to_preload[0])
            if not use_multiprocessing:
                preloaded_datasets = SingleCell._read_parallel(
                    datasets_to_preload, filename,
                    num_threads=num_threads)
    if use_multiprocessing:
        # Crucial: the HDF5 file must be closed before
        # `read_parallel_multiprocessing()` to avoid the child processes
        # inheriting the HDF5 library's state associated with the open
        # file handle when forking.
        preloaded_datasets = read_parallel_multiprocessing(
            datasets_to_preload, filename, num_threads=num_threads)

    # 3. Assemble `varm` from the requested keys only
    result = {}
    with h5py.File(filename) as f:
        varm = f['varm']
        for key in keys:
            value = varm[key]
            if isinstance(value, h5py.Dataset):
                result[key] = SingleCell._read_dataset(
                    value, preloaded_datasets)
            else:
                result[key] = SingleCell._read_h5ad_dataframe(
                    f, f'varm/{key}', preloaded_datasets)
    return result
@staticmethod
def read_obsp(h5ad_file: str | Path,
              *,
              keys: str | Iterable[str] | None = None,
              num_threads: int | np.integer = -1) -> \
        dict[str, csr_array | csc_array]:
    """
    Load just `obsp` from an `.h5ad` file as a dictionary of sparse arrays.

    Args:
        h5ad_file: an `.h5ad` filename; a leading `~` is expanded to the
                   user's home directory
        keys: the key(s) of `obsp` to load; if `None`, load all keys
        num_threads: the number of threads to use when reading. By default
                     (`num_threads=-1`), use all available cores, as
                     determined by `os.cpu_count()`.

    Returns:
        A dictionary of sparse arrays of the data in `obsp`. When `keys`
        is specified, only those keys are returned, in the order given.

    Raises:
        ValueError: if `h5ad_file` does not end with `'.h5ad'`, if `keys`
                    was specified but the file has no `obsp`, or if any
                    element of `keys` is not a key of `obsp`
        FileNotFoundError: if `h5ad_file` does not exist
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    h5ad_file = str(h5ad_file)
    if not h5ad_file.endswith('.h5ad'):
        error_message = f".h5ad file {h5ad_file!r} must end with '.h5ad'"
        raise ValueError(error_message)
    # Every open below must use the expanded `filename` (not `h5ad_file`):
    # neither h5py nor `os.open()` expands `~`
    filename = os.path.expanduser(h5ad_file)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {h5ad_file} does not exist'
        raise FileNotFoundError(error_message)
    if keys is not None:
        keys = to_tuple_checked(keys, 'keys', str, 'strings')

    # Check that `num_threads` is a positive integer or -1; if -1, set to
    # `os.cpu_count()`. Without this, the default of -1 would be passed
    # through to the readers unprocessed.
    num_threads = SingleCell._process_num_threads_static(num_threads)

    # Load `obsp`, using a similar approach to loading the full `.h5ad`
    # file
    preloaded_datasets = {}
    use_multiprocessing = False
    with h5py.File(filename) as f:
        # Resolve and validate the keys up front, regardless of
        # `num_threads`, so that both code paths behave identically
        if 'obsp' not in f:
            if keys is not None:
                error_message = 'keys was specified, but obsp is empty'
                raise ValueError(error_message)
            return {}
        obsp = f['obsp']
        if keys is None:
            keys = tuple(obsp)
        else:
            for key_index, key in enumerate(keys):
                if key not in obsp:
                    error_message = (
                        f'keys[{key_index}] is {key!r}, which is not a '
                        f'key of obsp')
                    raise ValueError(error_message)

        if num_threads != 1:
            # 1. Tabulate which HDF5 datasets need to be loaded: the three
            # component arrays of each requested sparse array
            datasets_to_load = set()
            for key in keys:
                value = obsp[key]
                datasets_to_load.add(value['data'].name)
                datasets_to_load.add(value['indices'].name)
                datasets_to_load.add(value['indptr'].name)

            # 2. Preload datasets in parallel. If any fixed-width dataset
            # is chunked or compressed (indicated by a `None`
            # `file_offset`), fall back to the multiprocessing-based
            # reader, which runs only after this `with` block has closed
            # the file.
            datasets_to_preload = SingleCell._get_datasets_to_preload(
                f, datasets_to_load)
            use_multiprocessing = any(
                dataset[-1] is None for dataset in datasets_to_preload[0])
            if not use_multiprocessing:
                preloaded_datasets = SingleCell._read_parallel(
                    datasets_to_preload, filename,
                    num_threads=num_threads)
    if use_multiprocessing:
        # Crucial: the HDF5 file must be closed before
        # `read_parallel_multiprocessing()` to avoid the child processes
        # inheriting the HDF5 library's state associated with the open
        # file handle when forking.
        preloaded_datasets = read_parallel_multiprocessing(
            datasets_to_preload, filename, num_threads=num_threads)

    # 3. Assemble `obsp` from the requested keys only
    result = {}
    with h5py.File(filename) as f:
        obsp = f['obsp']
        for key in keys:
            value = obsp[key]
            # Anything not encoded as `csr_matrix` is treated as CSC
            array_type = csr_array \
                if value.attrs['encoding-type'] == 'csr_matrix' \
                else csc_array
            result[key] = array_type(
                (SingleCell._read_dataset(value['data'],
                                          preloaded_datasets),
                 SingleCell._read_dataset(value['indices'],
                                          preloaded_datasets),
                 SingleCell._read_dataset(value['indptr'],
                                          preloaded_datasets)),
                shape=value.attrs['shape'])
    return result
@staticmethod
def read_varp(h5ad_file: str | Path,
              *,
              keys: str | Iterable[str] | None = None,
              num_threads: int | np.integer = -1) -> \
        dict[str, csr_array | csc_array]:
    """
    Load just `varp` from an `.h5ad` file as a dictionary of sparse arrays.

    Args:
        h5ad_file: an `.h5ad` filename; a leading `~` is expanded to the
                   user's home directory
        keys: the key(s) of `varp` to load; if `None`, load all keys
        num_threads: the number of threads to use when reading. By default
                     (`num_threads=-1`), use all available cores, as
                     determined by `os.cpu_count()`.

    Returns:
        A dictionary of sparse arrays of the data in `varp`. When `keys`
        is specified, only those keys are returned, in the order given.

    Raises:
        ValueError: if `h5ad_file` does not end with `'.h5ad'`, if `keys`
                    was specified but the file has no `varp`, or if any
                    element of `keys` is not a key of `varp`
        FileNotFoundError: if `h5ad_file` does not exist
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    h5ad_file = str(h5ad_file)
    if not h5ad_file.endswith('.h5ad'):
        error_message = f".h5ad file {h5ad_file!r} must end with '.h5ad'"
        raise ValueError(error_message)
    # Every open below must use the expanded `filename` (not `h5ad_file`):
    # neither h5py nor `os.open()` expands `~`
    filename = os.path.expanduser(h5ad_file)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {h5ad_file} does not exist'
        raise FileNotFoundError(error_message)
    if keys is not None:
        keys = to_tuple_checked(keys, 'keys', str, 'strings')

    # Check that `num_threads` is a positive integer or -1; if -1, set to
    # `os.cpu_count()`. Without this, the default of -1 would be passed
    # through to the readers unprocessed.
    num_threads = SingleCell._process_num_threads_static(num_threads)

    # Load `varp`, using a similar approach to loading the full `.h5ad`
    # file
    preloaded_datasets = {}
    use_multiprocessing = False
    with h5py.File(filename) as f:
        # Resolve and validate the keys up front, regardless of
        # `num_threads`, so that both code paths behave identically
        if 'varp' not in f:
            if keys is not None:
                error_message = 'keys was specified, but varp is empty'
                raise ValueError(error_message)
            return {}
        varp = f['varp']
        if keys is None:
            keys = tuple(varp)
        else:
            for key_index, key in enumerate(keys):
                if key not in varp:
                    error_message = (
                        f'keys[{key_index}] is {key!r}, which is not a '
                        f'key of varp')
                    raise ValueError(error_message)

        if num_threads != 1:
            # 1. Tabulate which HDF5 datasets need to be loaded: the three
            # component arrays of each requested sparse array
            datasets_to_load = set()
            for key in keys:
                value = varp[key]
                datasets_to_load.add(value['data'].name)
                datasets_to_load.add(value['indices'].name)
                datasets_to_load.add(value['indptr'].name)

            # 2. Preload datasets in parallel. If any fixed-width dataset
            # is chunked or compressed (indicated by a `None`
            # `file_offset`), fall back to the multiprocessing-based
            # reader, which runs only after this `with` block has closed
            # the file.
            datasets_to_preload = SingleCell._get_datasets_to_preload(
                f, datasets_to_load)
            use_multiprocessing = any(
                dataset[-1] is None for dataset in datasets_to_preload[0])
            if not use_multiprocessing:
                preloaded_datasets = SingleCell._read_parallel(
                    datasets_to_preload, filename,
                    num_threads=num_threads)
    if use_multiprocessing:
        # Crucial: the HDF5 file must be closed before
        # `read_parallel_multiprocessing()` to avoid the child processes
        # inheriting the HDF5 library's state associated with the open
        # file handle when forking.
        preloaded_datasets = read_parallel_multiprocessing(
            datasets_to_preload, filename, num_threads=num_threads)

    # 3. Assemble `varp` from the requested keys only
    result = {}
    with h5py.File(filename) as f:
        varp = f['varp']
        for key in keys:
            value = varp[key]
            # Anything not encoded as `csr_matrix` is treated as CSC
            array_type = csr_array \
                if value.attrs['encoding-type'] == 'csr_matrix' \
                else csc_array
            result[key] = array_type(
                (SingleCell._read_dataset(value['data'],
                                          preloaded_datasets),
                 SingleCell._read_dataset(value['indices'],
                                          preloaded_datasets),
                 SingleCell._read_dataset(value['indptr'],
                                          preloaded_datasets)),
                shape=value.attrs['shape'])
    return result
@staticmethod
def read_uns(h5ad_file: str | Path) -> UnsDict:
    """
    Read only the `uns` field of an `.h5ad` file.

    Args:
        h5ad_file: an `.h5ad` filename

    Returns:
        A dictionary of the data in `uns`, or an empty dictionary if the
        file has no `uns`.
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    path_string = str(h5ad_file)

    # Reject anything that is not named like an `.h5ad` file
    has_h5ad_extension = path_string.endswith('.h5ad')
    if not has_h5ad_extension:
        error_message = f".h5ad file {path_string!r} must end with '.h5ad'"
        raise ValueError(error_message)

    # Expand a leading `~` before touching the filesystem
    filename = os.path.expanduser(path_string)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {path_string} does not exist'
        raise FileNotFoundError(error_message)

    with h5py.File(filename) as hdf5_file:
        if 'uns' not in hdf5_file:
            return {}
        return SingleCell._read_uns(hdf5_file['uns'])
@staticmethod
def _print_matrix_info(X: h5py.Group | h5py.Dataset,
                       X_name: str) -> None:
    """
    Given a key of an `.h5ad` file representing a sparse or dense matrix,
    print its shape, data type and (if sparse) number of non-zero elements,
    followed by its first (non-zero) element when it has one.

    Args:
        X: the key in the `.h5ad` file representing the matrix, as a
           `Group` (treated as sparse) or `Dataset` (treated as dense)
           object
        X_name: the name of the key
    """
    if isinstance(X, h5py.Group):
        data = X['data']
        # The shape is stored under `'shape'`, or under `'h5sparse_shape'`
        # when `'shape'` is absent
        shape = X.attrs['shape'] if 'shape' in X.attrs else \
            X.attrs['h5sparse_shape']
        dtype = str(data.dtype)
        nnz = data.shape[0]
        description = (
            f'{X_name}: {shape[0]:,} × {shape[1]:,} {dtype} sparse '
            f'array with {nnz:,} non-zero elements')
        # A matrix with no stored entries (e.g. one written by
        # `save(empty_X=True)`) has no first element to show
        if nnz > 0:
            description += f' and first non-zero element = {data[0]:.6g}'
    else:
        shape = X.shape
        dtype = str(X.dtype)
        description = (
            f'{X_name}: {shape[0]:,} × {shape[1]:,} {dtype} dense matrix')
        # Likewise, skip the first element of an empty dense matrix
        if 0 not in shape:
            description += f' with first element = {X[0, 0]:.6g}'
    print(description)
@staticmethod
def ls(h5ad_file: str | Path) -> None:
    """
    Print the fields in an `.h5ad` file. This can be useful e.g. when
    deciding which count matrix to load via the `X_key` argument to
    `SingleCell()`.

    The top-level `X`, `layers`, `obs`, `var`, `obsm`, `varm`, `obsp`,
    `varp` and `uns` are listed first, followed by the same fields of
    `raw` (indented under a `raw:` header) when `raw` is present and
    non-empty. Fields that are absent or empty are skipped.

    Args:
        h5ad_file: an `.h5ad` filename
    """
    check_type(h5ad_file, 'h5ad_file', (str, Path),
               'a string or pathlib.Path')
    h5ad_file = str(h5ad_file)
    if not h5ad_file.endswith('.h5ad'):
        error_message = f".h5ad file {h5ad_file!r} must end with '.h5ad'"
        raise ValueError(error_message)
    filename = os.path.expanduser(h5ad_file)
    if not os.path.exists(filename):
        error_message = f'.h5ad file {h5ad_file} does not exist'
        raise FileNotFoundError(error_message)
    # `os.get_terminal_size()` raises `OSError` when standard output is not
    # attached to a terminal (e.g. Jupyter notebooks, pipes, redirected
    # output); fall back to 80 columns in that case
    try:
        terminal_width = os.get_terminal_size().columns
    except (AttributeError, OSError):
        terminal_width = 80
    attrs = 'obs', 'var', 'obsm', 'varm', 'obsp', 'varp', 'uns'

    def print_matrices(group, prefix: str) -> None:
        # Print a one-line summary of `X` and of each layer in `group`
        if 'X' in group:
            SingleCell._print_matrix_info(group['X'], f'{prefix}X')
        if 'layers' in group:
            for layer_name, layer in group['layers'].items():
                SingleCell._print_matrix_info(
                    layer, f'{prefix}layers[{layer_name!r}]')

    def print_entries(group, prefix: str) -> None:
        # Print the names of the entries of each non-empty attribute in
        # `group`, wrapped to the terminal width
        for attr in attrs:
            if attr not in group:
                continue
            entries = group[attr]
            # `obs`/`var` stored as a single structured (void-dtype)
            # dataset: list the names of its fields instead
            if (attr == 'obs' or attr == 'var') and \
                    isinstance(entries, h5py.Dataset) and \
                    np.issubdtype(entries.dtype, np.void):
                entries = entries.dtype.fields
            if len(entries) > 0:
                # Wrapped lines are indented to line up with the first
                # entry, i.e. past the prefix and the `attr: ` label
                print(fill(f'{prefix}{attr}: {", ".join(entries)}',
                           width=terminal_width,
                           subsequent_indent=' ' * (
                               len(prefix) + len(attr) + 2)))

    with h5py.File(filename) as f:
        # `X`, layers, `obs`, `var`, `obsm`, `varm`, `obsp`, `varp`, `uns`
        print_matrices(f, '')
        print_entries(f, '')
        # raw
        if 'raw' in f:
            raw = f['raw']
            if len(raw) > 0:
                print('raw:')
                print_matrices(raw, '    ')
                print_entries(raw, '    ')
def __eq__(self, other: SingleCell) -> bool:
    """
    Test for equality with another SingleCell dataset.

    Compares the presence of `X`, `num_threads`, `obs`, `var`, the keys
    and values of `obsm` and `varm`, `uns`, and finally `X` itself (when
    present).

    Args:
        other: the other SingleCell dataset to test for equality with

    Returns:
        Whether the two SingleCell datasets are identical.
    """
    if not isinstance(other, SingleCell):
        error_message = (
            f'the left-hand operand of `==` is a SingleCell dataset, but '
            f'the right-hand operand has type {type(other).__name__!r}')
        raise TypeError(error_message)

    def values_equal(field, other_field) -> bool:
        # Compare each value of an `obsm`/`varm`-like mapping: values must
        # have the same type, and be equal as NumPy arrays or DataFrames
        return all(
            type(other_field[key]) is type(value) and
            (array_equal(other_field[key], value)
             if isinstance(value, np.ndarray) else
             other_field[key].equals(value))
            for key, value in field.items())

    # NOTE(review): `obsp` and `varp` are not compared here - confirm
    # whether that is intentional
    return (self._X is None) == (other._X is None) and \
        self._num_threads == other._num_threads and \
        self._obs.equals(other._obs) and \
        self._var.equals(other._var) and \
        self._obsm.keys() == other._obsm.keys() and \
        self._varm.keys() == other._varm.keys() and \
        values_equal(self._obsm, other._obsm) and \
        values_equal(self._varm, other._varm) and \
        SingleCell._eq_uns(self._uns, other._uns) and \
        (self._X is None or self._X.equals(other._X))


@staticmethod
def _eq_uns(uns: UnsDict,
            other_uns: UnsDict,
            different_order_ok: bool = False) -> bool:
    """
    Test whether two `uns` are equal.

    Args:
        uns: an `uns`
        other_uns: another `uns`
        different_order_ok: whether to consider `uns` and `other_uns` equal
                            when they have the same keys and values, but
                            in a different order

    Returns:
        Whether `uns` and `other_uns` are equal.
    """
    # Key views compare like sets, so compare lists of keys when the
    # order matters
    keys = list(uns.keys())
    other_keys = list(other_uns.keys())
    if different_order_ok:
        if set(keys) != set(other_keys):
            return False
    elif keys != other_keys:
        return False
    # The keys match, so compare the values (in both modes)
    for key, value in uns.items():
        other_value = other_uns[key]
        if isinstance(value, dict):
            if not isinstance(other_value, dict) or \
                    not SingleCell._eq_uns(value, other_value,
                                           different_order_ok):
                return False
        elif isinstance(value, np.ndarray):
            if not isinstance(other_value, np.ndarray) or \
                    not array_equal(value, other_value):
                return False
        else:
            # `value` is a scalar: it can only equal another scalar
            # (comparing against an array would give an array, not a
            # Boolean)
            if isinstance(other_value, (dict, np.ndarray)) or \
                    not value == other_value:
                return False
    return True


@staticmethod
def _getitem_error(item: Indexer | tuple[Indexer, Indexer]) -> NoReturn:
    """
    Raise an error if the indexer is invalid.

    Args:
        item: the indexer
    """
    types = tuple(type(elem).__name__ for elem in to_tuple(item))
    if len(types) == 1:
        types = types[0]
    error_message = (
        f'SingleCell indices must be cells, a length-1 tuple of (cells,), '
        f'or a length-2 tuple of (cells, genes). Cells and genes must '
        f'each be a string or integer; a slice of strings or integers; or '
        f'a list, NumPy array, or polars Series of strings, integers, or '
        f'Booleans. You indexed with: {types}.')
    raise ValueError(error_message)


@staticmethod
def _getitem_by_string(df: pl.DataFrame, string: str) -> int:
    """
    Get the index where df[:, 0] == string, raising an error if no rows or
    multiple rows match.

    Args:
        df: a DataFrame (`obs` or `var`)
        string: the string to find the index of in the first column of df

    Returns:
        The integer index of the string within the first column of df.
    """
    first_column = df.columns[0]
    try:
        # Pair each name with its row number, then pick out the single row
        # whose name matches
        return df\
            .select(pl.int_range(pl.len(), dtype=pl.Int32)
                    .alias('_SingleCell_getitem'), first_column)\
            .row(by_predicate=pl.col(first_column) == string)\
            [0]
    except pl.exceptions.NoRowsReturnedError:
        # Report a missing name like a missing dictionary key, without the
        # polars-internal exception as context
        raise KeyError(string) from None


@staticmethod
def _getitem_process(item: Indexer | tuple[Indexer, Indexer],
                     index: int,
                     df: pl.DataFrame) -> list[int] | slice | pl.Series:
    """
    Process an element of an item passed to `__getitem__()`.

    Args:
        item: the item
        index: the index of the element to process
        df: the DataFrame (`obs` or `var`) to process the element with
            respect to

    Returns:
        A new indexer indicating the rows/columns to index: a length-1
        list for a single integer or string, a slice of integers for a
        slice, or a polars Series of integers or Booleans for a tuple,
        list, NumPy array or polars Series.
    """
    subitem = item[index]
    if isinstance(subitem, (int, np.integer)):
        return [subitem]
    elif isinstance(subitem, str):
        return [SingleCell._getitem_by_string(df, subitem)]
    elif isinstance(subitem, slice):
        # String endpoints are converted to their integer positions; any
        # other non-integer endpoint or step is an error
        start = subitem.start
        stop = subitem.stop
        step = subitem.step
        if isinstance(start, str):
            start = SingleCell._getitem_by_string(df, start)
        elif start is not None and \
                not isinstance(start, (int, np.integer)):
            SingleCell._getitem_error(item)
        if isinstance(stop, str):
            stop = SingleCell._getitem_by_string(df, stop)
        elif stop is not None and not isinstance(stop, (int, np.integer)):
            SingleCell._getitem_error(item)
        if step is not None and not isinstance(step, (int, np.integer)):
            SingleCell._getitem_error(item)
        return slice(start, stop, step)
    elif isinstance(subitem, (tuple, list, np.ndarray, pl.Series)):
        subitem = pl.Series(subitem)
        if subitem.is_null().any():
            error_message = 'your indexer contains missing values'
            raise ValueError(error_message)
        dtype = subitem.dtype
        if dtype in (pl.String, pl.Enum, pl.Categorical):
            # Names: cast to the data type of the first column of `df`,
            # then look up each name's row number with a left join
            names_dtype = df[:, 0].dtype
            if dtype != names_dtype:
                subitem = subitem.cast(names_dtype)
            indices = subitem\
                .to_frame(df.columns[0])\
                .join(df.with_columns(_SingleCell_index=pl.int_range(
                          pl.len(), dtype=pl.UInt32)),
                      on=df.columns[0], how='left')\
                ['_SingleCell_index']
            # Names without a match get a null index: report the first
            if indices.null_count():
                error_message = subitem.filter(indices.is_null())[0]
                raise KeyError(error_message)
            return indices
        elif dtype.is_integer() or dtype == pl.Boolean:
            return subitem
        else:
            SingleCell._getitem_error(item)
    else:
        SingleCell._getitem_error(item)


def __getitem__(self, item: Indexer | tuple[Indexer, Indexer]) -> \
        SingleCell:
    """
    Subset to specific cell(s) and/or gene(s).

    Index with a tuple of `(cells, genes)`. If `cells` and `genes` are
    integers, arrays/lists/slices of integers, or arrays/lists of Booleans,
    the result will be a SingleCell dataset subset to `X[cells, genes]`,
    `obs[cells]`, `var[genes]`, `obsm[cells]`, `varm[genes]`,
    `obsp[cells][:, cells]`, and `varp[genes][:, genes]`. However, `cells`
    and/or `genes` can instead be strings (or arrays or slices of strings),
    in which case they refer to the first column of `obs` (`obs_names`)
    and/or `var` (`var_names`), respectively.

    Args:
        item: the item to index with

    Returns:
        A new SingleCell dataset subset to the specified cells and/or
        genes.

    Examples:
        Subset to one cell (all genes):
        >>> sc[2]
        >>> sc['CGAATTGGTGACAGGT-L8TX_210916_01_B05-1131590416']

        Subset to one gene (all cells):
        >>> sc[:, 13196]
        >>> sc[:, 'APOE']

        Subset to one cell and one gene:
        >>> sc[2, 13196]
        >>> sc['CGAATTGGTGACAGGT-L8TX_210916_01_B05-1131590416', 'APOE']

        Subset to a range of cells and genes:
        >>> sc[2:6, 13196:34268]
        >>> sc['CGAATTGGTGACAGGT-L8TX_210916_01_B05-1131590416':
        ...    'CCCTCTCAGCAGCCTC-L8TX_211007_01_A09-1135034522',
        ...    'APOE':'TREM2']

        Subset to specific cells or genes:
        >>> sc[[2, 5, 9]]
        >>> sc[:, ['APOE', 'TREM2']]

        Boolean indexing:
        >>> sc[sc.obs['cell_type'] == 'microglia']
        >>> sc[:, sc.var['gene_type'] == 'protein_coding']

        Use different indexing types for cells and genes:
        >>> sc[sc.obs['batch'] == 'A', ['APOE', 'TREM2']]
    """
    # NumPy integers are accepted alongside Python integers, matching what
    # `_getitem_process()` supports
    if not isinstance(item, (int, np.integer, str, slice, tuple, list,
                             np.ndarray, pl.Series)):
        error_message = (
            f'SingleCell datasets must be indexed with an integer, '
            f'string, slice, tuple, list, NumPy array, or polars Series, '
            f'but you tried to index with an object of type '
            f'{type(item).__name__!r}')
        raise TypeError(error_message)
    if isinstance(item, tuple):
        if not 1 <= len(item) <= 2:
            self._getitem_error(item)
    else:
        item = item,

    def subset_axis(indexer, df, m, p):
        # Subset one axis: the DataFrame (`obs`/`var`), the per-element
        # dictionary (`obsm`/`varm`) and the pairwise dictionary
        # (`obsp`/`varp`). Also return the indexer in the form used to
        # index NumPy/sparse arrays, and whether it is a slice.
        if isinstance(indexer, pl.Series):
            is_boolean = indexer.dtype == pl.Boolean
            df_subset = df.filter(indexer) if is_boolean else df[indexer]
            indexer_NumPy = indexer.to_numpy()
        else:
            is_boolean = False
            df_subset = df[indexer]
            indexer_NumPy = indexer if isinstance(indexer, slice) else \
                np.asarray(indexer)
        m_subset = {
            key: (value.filter(indexer) if is_boolean else value[indexer])
            if isinstance(value, pl.DataFrame) else value[indexer_NumPy]
            for key, value in m.items()} if m else {}
        is_slice = isinstance(indexer, slice)
        if not p:
            p_subset = {}
        elif is_slice:
            p_subset = {key: value[indexer_NumPy, indexer_NumPy]
                        for key, value in p.items()}
        else:
            p_subset = {key: value[ix_symmetric(indexer_NumPy)]
                        for key, value in p.items()}
        return df_subset, m_subset, p_subset, indexer_NumPy, is_slice

    # Subset cells
    rows = self._getitem_process(item, 0, self._obs)
    obs, obsm, obsp, rows_NumPy, rows_is_slice = \
        subset_axis(rows, self._obs, self._obsm, self._obsp)
    if len(item) == 1:
        return SingleCell(
            X=self._X[rows_NumPy] if self._X is not None else None,
            obs=obs, var=self._var, obsm=obsm, varm=self._varm,
            obsp=obsp, varp=self._varp, uns=self._uns,
            num_threads=self._num_threads)

    # Subset genes as well
    cols = self._getitem_process(item, 1, self._var)
    var, varm, varp, cols_NumPy, cols_is_slice = \
        subset_axis(cols, self._var, self._varm, self._varp)
    if self._X is None:
        X = None
    elif rows_is_slice or cols_is_slice:
        X = self._X[rows_NumPy, cols_NumPy]
    else:
        # Two array indexers: take the outer product of rows and columns
        # rather than pairing them up element-wise
        X = self._X[np.ix_(rows_NumPy, cols_NumPy)]
    return SingleCell(X=X, obs=obs, var=var, obsm=obsm, varm=varm,
                      obsp=obsp, varp=varp, uns=self._uns,
                      num_threads=self._num_threads)
def cell(self,
         cell: str,
         /,
         *,
         num_threads: int | np.integer | None = None) -> \
        np.ndarray:
    """
    Get the row of `X` corresponding to a single cell, based on the cell's
    name in `obs_names`.

    Args:
        cell: the name of the cell in `obs_names`
        num_threads: the number of threads to use when retrieving the row
                     of `X`. Set `num_threads=-1` to use all available
                     cores, as determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores
                     when `X` is a CSC array and 1 thread when `X` is a CSR
                     array. Cannot be specified when `X` is a CSR array,
                     since there is no benefit to parallelism in that case.

    Returns:
        The corresponding row of `X`, as a dense 1D NumPy array with zeros
        included.
    """
    # Check that `X` is present
    if self._X is None:
        error_message = 'X is None, so getting a row of X is not possible'
        raise ValueError(error_message)

    # Check that `num_threads` is not specified when `X` is a CSR array
    if num_threads is not None and isinstance(self._X, csr_array):
        error_message = (
            'num_threads cannot be specified when X is a CSR array, since '
            'cell() will always run single-threaded')
        raise ValueError(error_message)

    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)

    row_index = self._getitem_by_string(self._obs, cell)
    # Temporarily override the number of threads `X` uses, restoring the
    # original value even if the lookup fails
    original_num_threads = self._X._num_threads
    try:
        self._X._num_threads = num_threads
        # `ravel()` rather than `squeeze()`, so that the result is 1D even
        # when `X` has a single gene (`squeeze()` would give a 0D array)
        return self._X[row_index].toarray().ravel()
    finally:
        self._X._num_threads = original_num_threads
def gene(self,
         gene: str,
         /,
         *,
         num_threads: int | np.integer | None = None) -> \
        np.ndarray:
    """
    Get the column of `X` corresponding to a single gene, based on the
    gene's name in `var_names`.

    Args:
        gene: the name of the gene in `var_names`
        num_threads: the number of threads to use when retrieving the
                     column of `X`. Set `num_threads=-1` to use all
                     available cores, as determined by `os.cpu_count()`. By
                     default (`num_threads=None`), use `self.num_threads`
                     cores when `X` is a CSR array and 1 thread when `X` is
                     a CSC array. Cannot be specified when `X` is a CSC
                     array, since there is no benefit to parallelism in
                     that case.

    Returns:
        The corresponding column of `X`, as a dense 1D NumPy array with
        zeros included.
    """
    # Check that `X` is present
    if self._X is None:
        error_message = \
            'X is None, so getting a column of X is not possible'
        raise ValueError(error_message)

    # Check that `num_threads` is not specified when `X` is a CSC array
    if num_threads is not None and isinstance(self._X, csc_array):
        error_message = (
            'num_threads cannot be specified when X is a CSC array, since '
            'gene() will always run single-threaded')
        raise ValueError(error_message)

    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)

    column_index = self._getitem_by_string(self._var, gene)
    # Temporarily override the number of threads `X` uses, restoring the
    # original value even if the lookup fails
    original_num_threads = self._X._num_threads
    try:
        self._X._num_threads = num_threads
        # `ravel()` rather than `squeeze()`, so that the result is 1D even
        # when `X` has a single cell (`squeeze()` would give a 0D array)
        return self._X[:, column_index].toarray().ravel()
    finally:
        self._X._num_threads = original_num_threads
def __len__(self) -> int:
    """
    Get the number of cells in this SingleCell dataset.

    Returns:
        The number of cells.
    """
    return len(self._obs)


def __repr__(self) -> str:
    """
    Get a string representation of this SingleCell dataset.

    Returns:
        A string summarizing the dataset: the storage format of `X` (when
        `X` is present), the number of cells, genes and non-zero entries,
        and the names of the entries of each non-empty attribute.
    """
    num_cells = len(self._obs)
    num_genes = len(self._var)
    # Only report a storage format when there is an `X` to have one
    if self._X is None:
        format_description = ''
    else:
        format_description = (
            f'in {"CSR" if isinstance(self._X, csr_array) else "CSC"} '
            f'format ')
    descr = (
        f'SingleCell dataset {format_description}'
        f'with {num_cells:,} {plural("cell", num_cells)} (obs), '
        f'{num_genes:,} {plural("gene", num_genes)} (var), and ')
    if self._X is not None:
        descr += (f'{self._X.nnz:,} non-zero {self._X.dtype} '
                  f'{"entries" if self._X.nnz != 1 else "entry"} (X)')
    else:
        descr += 'no X'
    # `os.get_terminal_size()` raises `OSError` when standard output is not
    # attached to a terminal (e.g. Jupyter notebooks, pipes, redirected
    # output); fall back to 80 columns in that case
    try:
        terminal_width = os.get_terminal_size().columns
    except (AttributeError, OSError):
        terminal_width = 80
    for attr in 'obs', 'var', 'obsm', 'varm', 'obsp', 'varp', 'uns':
        entries = getattr(self, attr).columns \
            if attr == 'obs' or attr == 'var' else getattr(self, attr)
        if len(entries) > 0:
            # Wrapped lines are indented past the 4-space prefix and the
            # `attr: ` label, to line up with the first entry
            descr += '\n' + fill(
                f'    {attr}: {", ".join(entries)}',
                width=terminal_width,
                subsequent_indent=' ' * (len(attr) + 6))
    return descr


def __iter__(self):
    """
    Disallow iteration over SingleCell datasets by always raising a
    `TypeError`.
    """
    error_message = 'SingleCell datasets do not support iteration'
    raise TypeError(error_message)


@property
def shape(self) -> tuple[int, int]:
    """
    The shape of this SingleCell dataset: a length-2 tuple where the first
    element is the number of cells, and the second is the number of genes.
    """
    return len(self._obs), len(self._var)


@staticmethod
def _save_h5ad_dataframe(h5ad_file: h5py.File,
                         df: pl.DataFrame,
                         key: str,
                         preserve_strings: bool) -> None:
    """
    Save `obs` or `var` to an `.h5ad` file.

    Args:
        h5ad_file: an `h5py.File` open in write mode
        df: the DataFrame to write, e.g. `obs` or `var`
        key: the key to create in `h5ad_file`, e.g. `'obs'` or `'var'`
        preserve_strings: if `False`, encode string columns with duplicate
                          values as Enums to save space; if `True`,
                          preserve these columns as string columns.
                          String columns with `null` values are encoded
                          as Enums either way.
    """
    # Create a group for the data frame and add top-level metadata; the
    # first column is recorded as the index, the rest as the column order
    group = h5ad_file.create_group(key)
    group.attrs['_index'] = df.columns[0]
    group.attrs['column-order'] = df.columns[1:]
    group.attrs['encoding-type'] = 'dataframe'
    group.attrs['encoding-version'] = '0.2.0'
    for column in df:
        dtype = column.dtype
        if dtype == pl.String:
            if column.null_count() or (
                    not preserve_strings and column.is_duplicated().any()):
                # Convert to Enum (categories in order of first
                # appearance) and fall through to the categorical branch
                column = column\
                    .cast(pl.Enum(column.unique(maintain_order=True)
                                  .drop_nulls()))
                dtype = column.dtype
            else:
                dataset = group.create_dataset(column.name,
                                               data=column.to_numpy())
                dataset.attrs['encoding-type'] = 'string-array'
                dataset.attrs['encoding-version'] = '0.2.0'
                continue
        if dtype == pl.Enum or dtype == pl.Categorical:
            # Store as a group of integer codes plus category labels;
            # Enums are flagged as ordered, Categoricals as unordered
            is_Enum = dtype == pl.Enum
            subgroup = group.create_group(column.name)
            subgroup.attrs['encoding-type'] = 'categorical'
            subgroup.attrs['encoding-version'] = '0.2.0'
            subgroup.attrs['ordered'] = is_Enum
            categories = column.cat.get_categories()
            if not is_Enum:
                column = column.cast(pl.Enum(categories))
            # Missing values are written with a code of -1
            codes = column.to_physical().fill_null(-1)
            subgroup.create_dataset('codes', data=codes.to_numpy())
            if len(categories) == 0:
                # No categories: create an empty variable-length string
                # dataset explicitly
                subgroup.create_dataset('categories', shape=(0,),
                                        dtype=h5py.special_dtype(vlen=str))
            else:
                subgroup.create_dataset('categories',
                                        data=categories.to_numpy())
        elif dtype.is_float():
            # Nullable floats are not supported, so convert `null` to `NaN`
            dataset = group.create_dataset(
                column.name, data=column.fill_null(np.nan).to_numpy())
            dataset.attrs['encoding-type'] = 'array'
            dataset.attrs['encoding-version'] = '0.2.0'
        else:  # Boolean or integer
            is_Boolean = dtype == pl.Boolean
            if column.null_count():
                # Store as nullable integer/Boolean: placeholder values
                # for the missing entries, plus a mask marking them
                subgroup = group.create_group(column.name)
                subgroup.attrs['encoding-type'] = \
                    f'nullable-{"boolean" if is_Boolean else "integer"}'
                subgroup.attrs['encoding-version'] = '0.1.0'
                subgroup.create_dataset(
                    'values',
                    data=column.fill_null(False if is_Boolean else 1)
                    .to_numpy())
                subgroup.create_dataset(
                    'mask', data=column.is_null().to_numpy())
            else:
                # Store as regular integer/Boolean
                dataset = group.create_dataset(column.name,
                                               data=column.to_numpy())
                dataset.attrs['encoding-type'] = 'array'
                dataset.attrs['encoding-version'] = '0.2.0'


@staticmethod
def _save_h5Seurat_dataframe(h5ad_file: h5py.File,
                             df: pl.DataFrame,
                             key: str,
                             preserve_strings: bool) -> None:
    """
    Save `obs` or `var` to an `.h5Seurat` file.

    Args:
        h5ad_file: an `h5py.File` open in write mode (despite the
                   parameter's name, this is the `.h5Seurat` file being
                   written)
        df: the DataFrame to write, e.g. `obs` or `var`
        key: the key to create in `h5ad_file`, e.g. `'meta.data'` or
             `'RNA/meta.features'`
        preserve_strings: if `False`, encode string columns with duplicate
                          values as Enums to save space; if `True`,
                          preserve these columns as string columns.
                          String columns with `null` values are encoded
                          as Enums either way.
    """
    # Create a group for the data frame and add top-level metadata,
    # including the names of the Boolean columns (if any)
    group = h5ad_file.create_group(key)
    group.attrs['_index'] = np.array([df.columns[0]], dtype=object)
    group.attrs['colnames'] = df.columns[1:]
    logicals = \
        pl.selectors.expand_selector(df, pl.selectors.by_dtype(pl.Boolean))
    if len(logicals) > 0:
        group.attrs['logicals'] = logicals
    for column in df:
        dtype = column.dtype
        if dtype == pl.String:
            if column.null_count() or (
                    not preserve_strings and column.is_duplicated().any()):
                # Convert to Enum (categories in order of first
                # appearance) and fall through to the factor branch
                column = column\
                    .cast(pl.Enum(column.unique(maintain_order=True)
                                  .drop_nulls()))
                dtype = column.dtype
            else:
                group.create_dataset(column.name, data=column.to_numpy())
                continue
        if dtype == pl.Enum or dtype == pl.Categorical:
            # Store as a group of 1-based integer values plus levels
            subgroup = group.create_group(column.name)
            levels = column.cat.get_categories()
            if dtype != pl.Enum:
                column = column.cast(pl.Enum(levels))
            # Missing values are written as -2147483648
            values = (column.to_physical() + 1).fill_null(-2147483648)
            subgroup.create_dataset('values', data=values.to_numpy())
            if len(levels) == 0:
                # No levels: create an empty variable-length string
                # dataset explicitly
                subgroup.create_dataset('levels', shape=(0,),
                                        dtype=h5py.special_dtype(vlen=str))
            else:
                subgroup.create_dataset('levels', data=levels.to_numpy())
        else:
            # Missing values are written as `NaN` for floats, 2 for
            # Booleans (stored as 32-bit integers) and -2147483648 for
            # integers
            if dtype.is_float():
                column = column.fill_null(np.nan)
            elif dtype == pl.Boolean:
                column = column.cast(pl.Int32).fill_null(2)
            else:  # integer
                column = column.fill_null(-2147483648)
            group.create_dataset(column.name, data=column.to_numpy())
[docs] def save(self, filename: str | Path, /, *, assay: str = 'RNA', X_key: str = 'counts', overwrite: bool = False, preserve_strings: bool = False, empty_X: bool = False, v3: bool = False, sce: bool = False) -> None: """ Save this SingleCell dataset to a file. File format will be inferred from the file extension (e.g. `.h5ad`). Args: filename: an AnnData `.h5ad` file, Seurat `.rds` or `.h5Seurat` file, SingleCellExperiment `.rds` file, or 10x `.h5` or `.mtx`/`.mtx.gz` file to save to. If the extension is `.rds`, the `sce` argument will determine whether to save to a Seurat or a SingleCellExperiment object. - When saving to a Seurat `.rds` file, to match the requirements of Seurat objects, the `'X_'` prefix (often used by Scanpy) will be removed from each key of obsm where it is present (e.g. `'X_umap'` will become `'umap'`). - When saving to a Seurat `.rds` file, Seurat will add `'orig.ident'`, `'nCount_RNA'` and `'nFeature_RNA'` as gene-level metadata by default; you can disable the calculation of the latter two columns with: ```python from ryp import r r('options(Seurat.object.assay.calcn = FALSE)') ``` - When saving to a Seurat `.rds` or `.h5Seurat` file or a SingleCellExperiment `.rds` file, `varm` will not be saved. - When saving to a 10x `.h5` file, `obs['barcodes']`, `var['feature_type']`, `var['genome']`, `var['id']`, and `var['name']` must all exist. Only `X` and these columns will be saved, along with whichever of `var['pattern']`, `var['read']`, and `var['sequence']` exist. All of these columns (if they exist) must be String, Enum, Categorical, or integer. - When saving to a 10x `.mtx` file, `barcodes.tsv` and `features.tsv` will be created in the same directory. Only `X`, `obs` and `var` will be saved. - When saving to a 10x `.mtx.gz` file, `barcodes.tsv.gz` and `features.tsv.gz` will be created in the same directory. Only `X`, `obs` and `var` will be saved. 
assay: when saving to a Seurat `.rds` or `.h5Seurat` file, the name to use for the active assay X_key: when saving to a Seurat `.rds` or `.h5Seurat` file, the name of the layer within the active assay to save `X` to; must be `'counts'` or `'data'`. When saving to a SingleCellExperiment `.rds` file, the name of the layer within `@assays@data` to save `X` to. overwrite: if `False`, raises an error if (any of) the file(s) exist; if `True`, overwrites them preserve_strings: if `False`, encode string columns with duplicate values as Enums to save space, when saving to AnnData `.h5ad` or Seurat or SingleCellExperiment `.rds`; if `True`, preserve these columns as string columns. (Regardless of the value of `preserve_strings`, String columns with `null` values will be encoded as Enums when saving to `.h5ad`, since the `.h5ad` format cannot represent them otherwise.) empty_X: if `True`, allow saving of SingleCell datasets with missing counts (i.e. where `self.X is None`), by saving an empty, dummy count matrix with no non-zero entries v3: if `True`, use Seurat's version 3 format instead of the more current version 5 format when saving to a Seurat `.rds` file. When `v3=True`, the Seurat object's assay is created with `CreateAssayObject()` instead of `CreateAssay5Object()`. `v3=True` cannot be specified when saving to `.h5Seurat`, since `.h5Seurat` files [do not support](https://github.com/mojaveazure/seurat-disk/issues/147) Seurat version 5 and so saving to `.h5Seurat` will always save to version 3. 
sce: if `True` and the extension of filename is `.rds`, save to a SingleCellExperiment object instead of a Seurat object Examples: Save to an AnnData `.h5ad` file: >>> sc.save('data.h5ad') Overwrite an existing file: >>> sc.save('data.h5ad', overwrite=True) Save to a Seurat `.h5Seurat` file: >>> sc.save('seurat_obj.h5Seurat') Save to a Seurat `.rds` file: >>> sc.save('seurat_obj.rds') Save to a Seurat `.rds` file using the `'data'` layer instead of `'counts'`: >>> sc.save('seurat_obj.rds', X_key='data') Save to a Seurat version 3 `.rds` file: >>> sc.save('seurat_v3.rds', v3=True) Save to a SingleCellExperiment `.rds` file: >>> sc.save('sce_obj.rds', sce=True) Save to a 10x Genomics `.h5` file (requires `obs` to contain a `'barcodes'` column and `var` to contain `'feature_type'`, `'genome'`, `'id'`, and `'name'` columns): >>> sc.save('matrix.h5') Save to a 10x Genomics `.mtx.gz` file (creates `barcodes.tsv.gz` and `features.tsv.gz` in the same directory; unlike when saving to `.h5`, `obs` and `var` are not required to contain any specific columns): >>> sc.save('matrix.mtx.gz') Allow saving when `X` is missing by writing an empty count matrix: >>> sc.save('empty.h5ad', empty_X=True) Preserve string columns instead of encoding them as Enums: >>> sc.save('data_preserve_strings.h5ad', ... 
preserve_strings=True) """ # Check that `filename` is a string or `Path`; convert it to a string # and expand `~` into home directories check_type(filename, 'filename', (str, Path), 'a string or pathlib.Path') filename = str(filename) filename_expanduser = os.path.expanduser(filename) # Check that `overwrite` is Boolean check_type(overwrite, 'overwrite', bool, 'Boolean') # Raise an error if the filename already exists, unless # `overwrite=True` if not overwrite and os.path.exists(filename_expanduser): error_message = ( f'filename {filename!r} already exists; set overwrite=True ' f'to overwrite') raise FileExistsError(error_message) # Check that `empty_X` is `True` when `X` is absent, and vice versa. # If `empty_X` is `True`, make an empty R-compatible count matrix. check_type(empty_X, 'empty_X', bool, 'Boolean') if empty_X: if self._X is not None: error_message = \ 'X is not None, so empty_X=True cannot be specified' raise ValueError(error_message) counts = csr_array((np.array([], dtype=np.float64), np.array([], dtype=np.int32), np.array([0], dtype=np.int32)), shape=(0, 0)) else: if self._X is None: error_message = ( 'X is None, so saving requires creating an empty count ' 'matrix; specify empty_X=True to do this automatically ' 'when saving') raise ValueError(error_message) counts = self._X # Get the file type from the extension, raising an error if invalid is_h5ad = filename.endswith('.h5ad') is_rds = filename.endswith('.rds') is_h5Seurat = filename.endswith('.h5Seurat') or \ filename.endswith('.h5seurat') is_h5 = filename.endswith('.h5') is_hdf5 = is_h5ad or is_h5 or is_h5Seurat is_mtx = filename.endswith('.mtx') is_mtx_gz = filename.endswith('.mtx.gz') is_Seurat = is_h5Seurat or is_rds and not sce if not (is_hdf5 or is_mtx or is_mtx_gz or is_rds): error_message = ( f"filename {filename!r} does not end with '.h5ad', '.rds', " f"'.h5Seurat'/'.h5seurat', '.h5', or '.mtx.gz'") raise ValueError(error_message) # Check that `assay` is a string, and that `assay` is not 
specified # (i.e. retains its default value) unless saving to a Seurat object check_type(assay, 'assay', str, 'a string') if not is_Seurat and assay != 'RNA': error_message = \ 'assay cannot be specified unless saving to a Seurat object' raise ValueError(error_message) # Check that `X_key` is `'counts'` or `'data'` when saving to a Seurat # object, any string when saving to a SingleCellExperiment object, and # not specified (i.e. retains its default value) otherwise check_type(X_key, 'X_key', str, 'a string') if is_Seurat: if X_key != 'counts' and X_key != 'data': error_message = ( f"when saving to a Seurat object, X_key must be 'counts' " f"or 'data', not {X_key!r}") raise ValueError(error_message) elif not sce and X_key != 'counts': error_message = ( 'X_key cannot be specified unless saving to a Seurat or ' 'SingleCellExperiment object') raise ValueError(error_message) # Check that `preserve_strings` and `sce` are Boolean, and that `sce` # is only `True` when saving to an `.rds` file check_type(preserve_strings, 'preserve_strings', bool, 'Boolean') check_type(sce, 'sce', bool, 'Boolean') if sce and not is_rds: error_message = 'sce can only be True when saving to an .rds file' raise ValueError(error_message) # Check that `v3` is Boolean, and only `True` when saving to an `.rds` # file with `sce=False` check_type(v3, 'v3', bool, 'Boolean') if v3 and (not is_rds or sce): error_message = ( 'v3=True can only be specified when saving to a Seurat .rds ' 'file') raise ValueError(error_message) # Raise an error if `obs` or `var` (or, if saving to `.h5ad`, DataFrame # keys of `obsm` or `varm`) contain columns with unsupported data types # (anything but float, int, String, Enum, Categorical, Boolean) valid_dtypes = pl.FLOAT_DTYPES | pl.INTEGER_DTYPES | \ {pl.String, pl.Enum, pl.Categorical, pl.Boolean} for df, df_name in (self._obs, 'obs'), (self._var, 'var'): for column, dtype in df.schema.items(): if dtype.base_type() not in valid_dtypes: error_message = ( 
f'{df_name}[{column!r}] has the data type ' f'{dtype.base_type()!r}, which is not supported when ' f'saving') raise TypeError(error_message) if is_h5ad: for field, field_name in (self._obsm, 'obsm'), \ (self._varm, 'varm'): for key, value in field.items(): if not isinstance(value, pl.DataFrame): continue for column, dtype in value.schema.items(): if dtype.base_type() not in valid_dtypes: error_message = ( f'{field}[{key!r}][{column!r}] has the data ' f'type {dtype.base_type()!r}, which is not ' f'supported when saving') raise TypeError(error_message) # Raise an error if `obsm`, `varm` or `uns` contain NumPy arrays with # unsupported data types (`datetime64`, `timedelta64`, unstructured # `void`). Do not specifically check `dtype=object` to avoid extra # overhead. for field, field_name in \ (self._obsm, 'obsm'), (self._varm, 'varm'), (self._uns, 'uns'): for key, value in field.items(): if not isinstance(value, np.ndarray): continue if value.dtype.type == np.void and value.dtype.names is None: error_message = ( f'{field_name}[{key!r}] is an unstructured void ' f'array, which is not supported when saving') raise TypeError(error_message) elif value.dtype == np.datetime64: error_message = ( f'{field_name}[{key!r}] is a datetime64 array, which ' f'is not supported when saving') raise TypeError(error_message) elif value.dtype == np.timedelta64: error_message = ( f'{field_name}[{key!r}] is a timedelta64 array, which ' f'is not supported when saving') raise TypeError(error_message) # Save, depending on the file extension if is_hdf5: try: with h5py.File(filename_expanduser, 'w') as hdf5_file: if is_h5ad: # Add top-level metadata hdf5_file.attrs['encoding-type'] = 'anndata' hdf5_file.attrs['encoding-version'] = '0.1.0' # Save `obs` and `var` SingleCell._save_h5ad_dataframe( hdf5_file, self._obs, 'obs', preserve_strings) SingleCell._save_h5ad_dataframe( hdf5_file, self._var, 'var', preserve_strings) # Save `obsm` if self._obsm: obsm = hdf5_file.create_group('obsm') 
obsm.attrs['encoding-type'] = 'dict' obsm.attrs['encoding-version'] = '0.1.0' for key, value in self._obsm.items(): if isinstance(value, pl.DataFrame): SingleCell._save_h5ad_dataframe( hdf5_file, value, f'obsm/{key}', preserve_strings) else: obsm.create_dataset(key, data=value) # Save `varm` if self._varm: varm = hdf5_file.create_group('varm') varm.attrs['encoding-type'] = 'dict' varm.attrs['encoding-version'] = '0.1.0' for key, value in self._varm.items(): if isinstance(value, pl.DataFrame): SingleCell._save_h5ad_dataframe( hdf5_file, value, f'varm/{key}', preserve_strings) else: varm.create_dataset(key, data=value) # Save `obsp` obsp = hdf5_file.create_group('obsp') obsp.attrs['encoding-type'] = 'dict' obsp.attrs['encoding-version'] = '0.1.0' for key, value in self._obsp.items(): group = obsp.create_group(key) group.attrs['encoding-type'] = 'csr_matrix' \ if isinstance(value, csr_array) else \ 'csc_matrix' group.attrs['encoding-version'] = '0.1.0' group.attrs['shape'] = value.shape group.create_dataset('data', data=value.data) group.create_dataset('indices', data=value.indices) group.create_dataset('indptr', data=value.indptr) # Save `varp` varp = hdf5_file.create_group('varp') varp.attrs['encoding-type'] = 'dict' varp.attrs['encoding-version'] = '0.1.0' for key, value in self._varp.items(): group = varp.create_group(key) group.attrs['encoding-type'] = 'csr_matrix' \ if isinstance(value, csr_array) else \ 'csc_matrix' group.attrs['encoding-version'] = '0.1.0' group.attrs['shape'] = value.shape group.create_dataset('data', data=value.data) group.create_dataset('indices', data=value.indices) group.create_dataset('indptr', data=value.indptr) # Save `uns` if self._uns: SingleCell._save_uns(self._uns, hdf5_file.create_group('uns'), hdf5_file) # Save `X` X = hdf5_file.create_group('X') X.attrs['encoding-type'] = 'csr_matrix' \ if isinstance(counts, csr_array) else 'csc_matrix' X.attrs['encoding-version'] = '0.1.0' X.attrs['shape'] = counts.shape 
X.create_dataset('data', data=counts.data) X.create_dataset('indices', data=counts.indices) X.create_dataset('indptr', data=counts.indptr) elif is_h5Seurat: obs_names = self.obs_names var_names = self.var_names for names, names_name in (obs_names, 'obs_names'), \ (var_names, 'var_names'): null_count = names.null_count() if null_count: error_message = ( f'{names_name} contains {null_count:,} ' f'null values, but must not contain any ' f'when saving to an .h5Seurat file') raise ValueError(error_message) # Add top-level metadata and required groups/datasets hdf5_file.attrs['active.assay'] = \ np.array([assay], dtype=object) hdf5_file.attrs['project'] = \ np.array(['SeuratProject'], dtype=object) hdf5_file.attrs['version'] = \ np.array(['3.0.0'], dtype=object) for required_group in 'commands', 'images', 'tools': hdf5_file.create_group(required_group) active_ident = hdf5_file.create_group('active.ident') active_ident.create_dataset('levels', data=[b'local']) active_ident.create_dataset( 'values', data=np.ones(len(self._obs), dtype=np.int32)) hdf5_file.create_dataset('cell.names', data=obs_names.to_numpy()) active_assay = \ hdf5_file.create_group(f'assays/{assay}') active_assay.attrs['key'] = \ np.array([f'{assay.lower()}_'], dtype=object) active_assay.create_dataset( 'features', data=var_names.to_numpy()) active_assay.create_group('misc') # Save `obs` and `var` SingleCell._save_h5Seurat_dataframe( hdf5_file, self._obs, 'meta.data', preserve_strings) SingleCell._save_h5Seurat_dataframe( hdf5_file, self._var, f'assays/{assay}/meta.features', preserve_strings) # Save `obsm` reductions = hdf5_file.create_group('reductions') if self._obsm: for key, value in self._obsm.items(): if isinstance(value, pl.DataFrame): continue group = reductions.create_group(key) group.attrs['active.assay'] = \ np.array([assay], dtype=object) group.attrs['global'] = \ np.array([0], dtype=np.int32) group.attrs['key'] = \ np.array([f'{key.upper()}_'], dtype=object) 
group.create_dataset('cell.embeddings', data=value.T) # Save `obsp` graphs = hdf5_file.create_group('graphs') if self._obsp: for key, value in self._obsp.items(): if isinstance(value, csc_array): value = value.tocsr() group = graphs.create_group(key) group.attrs['assay.used'] = \ np.array([assay], dtype=object) group.attrs['dims'] = value.shape[::-1] group.create_dataset('data', data=value.data) group.create_dataset('indices', data=value.indices) group.create_dataset('indptr', data=value.indptr) # Save `uns` misc = hdf5_file.create_group('misc') if self._uns: SingleCell._save_h5Seurat_uns( self._uns, misc, hdf5_file) # Save `X` if isinstance(counts, csc_array): counts = counts.tocsr() X = active_assay.create_group(X_key) X.attrs['dims'] = counts.shape[::-1] X.create_dataset('data', data=counts.data) X.create_dataset('indices', data=counts.indices) X.create_dataset('indptr', data=counts.indptr) else: # `.h5` obs_columns = ['barcodes'] var_columns = ['feature_type', 'genome', 'id', 'name'] for columns, df, df_name in \ (obs_columns, self._obs, 'obs'), \ (var_columns, self._var, 'var'): for column in columns: if column not in df: error_message = ( f'{column!r} was not found in ' f'{df_name}, but is a required column ' f'when saving to a 10x .h5 file') raise ValueError(error_message) check_dtype(df[column], f'{df_name}[{column!r}]', (pl.String, pl.Categorical, pl.Enum)) all_tag_keys = ['genome'] for column in 'pattern', 'read', 'sequence': if column in self._var: check_dtype(self._var[column], f'var[{column!r}]', (pl.String, pl.Categorical, pl.Enum)) var_columns.append(column) all_tag_keys.append(column) matrix = hdf5_file.create_group('matrix') matrix.create_dataset('barcodes', data=self._obs[:, 0].to_numpy()) matrix.create_dataset('data', data=counts.data) features = matrix.create_group('features') matrix.create_dataset('indices', data=counts.indices) matrix.create_dataset('indptr', data=counts.indptr) matrix.create_dataset('shape', data=counts.shape[::-1]) 
features.create_dataset('_all_tag_keys', data=all_tag_keys) for column in var_columns: features.create_dataset( column, data=self._var[column].to_numpy()) except: if os.path.exists(filename_expanduser): os.unlink(filename_expanduser) raise elif is_mtx or is_mtx_gz: barcode_filename = os.path.join( os.path.dirname(filename_expanduser), 'barcodes.tsv' if is_mtx else 'barcodes.tsv.gz') feature_filename = os.path.join( os.path.dirname(filename_expanduser), 'features.tsv' if is_mtx else 'features.tsv.gz') if not overwrite: for ancillary_filename in barcode_filename, feature_filename: if os.path.exists(ancillary_filename): error_message = ( f'{ancillary_filename!r} already exists; set ' f'overwrite=True to overwrite') raise FileExistsError(error_message) from scipy.io import mmwrite try: mmwrite(filename_expanduser, counts.T) if is_mtx: self._obs.write_csv(barcode_filename, include_header=False) self._var.write_csv(feature_filename, include_header=False) else: import gzip with gzip.open(barcode_filename, 'wb') as f: self._obs.write_csv(f, include_header=False) with gzip.open(feature_filename, 'wb') as f: self._var.write_csv(f, include_header=False) except: if os.path.exists(filename_expanduser): os.unlink(filename_expanduser) if os.path.exists(barcode_filename): os.unlink(barcode_filename) if os.path.exists(feature_filename): os.unlink(feature_filename) raise else: from ryp import r if preserve_strings: sc = self else: # Convert string columns with duplicate values to Enum enumify = lambda df: df.cast({ row[0]: pl.Enum(row[1]) for row in df .select(pl.selectors.string() .unique(maintain_order=True) .implode() .list.drop_nulls()) .unpivot() .filter(pl.col.value.list.len() < len(df)) .rows()}) sc = SingleCell(X=counts, obs=enumify(self._obs), var=enumify(self._var), obsm=self._obsm, uns=self._uns, num_threads=self._num_threads) # Do not include `empty_X=empty_X` in `to_sce()` and `to_seurat()`, # since we've already handled it above if sce: sc.to_sce('.SingleCell.object', 
assay=X_key) else: sc.to_seurat('.SingleCell.object', assay=assay, layer=X_key, v3=v3) try: r(f'saveRDS(.SingleCell.object, {filename_expanduser!r})') except: if os.path.exists(filename_expanduser): os.unlink(filename_expanduser) raise finally: r('rm(.SingleCell.object)')
def _get_column(self,
                obs_or_var_name: Literal['obs', 'var'],
                column: SingleCellColumn,
                variable_name: str,
                dtypes: pl.datatypes.classes.DataTypeClass | str |
                        tuple[pl.datatypes.classes.DataTypeClass | str,
                              ...],
                *,
                QC_column: pl.Series | None = None,
                allow_missing: bool = False,
                allow_null: bool = False,
                custom_error: str | None = None) -> pl.Series | None:
    """
    Get a column of the same length as `obs`/`var`, or `None` if the
    column is missing from `obs`/`var` and `allow_missing=True`.

    Args:
        obs_or_var_name: the name of the DataFrame the column is with
                         respect to, i.e. `'obs'` or `'var'`
        column: a string naming a column of `obs`/`var`, a polars
                expression that evaluates to a single column when applied
                to `obs`/`var`, a polars Series or 1D NumPy array of the
                same length as `obs`/`var`, or a function that takes in
                `self` and returns a polars Series or 1D NumPy array of
                the same length as `obs`/`var`
        variable_name: the name of the variable corresponding to `column`
        dtypes: the required dtype(s) of the column
        QC_column: an optional column of cells passing QC. If specified,
                   the presence of `null` values will only raise an error
                   for cells passing QC. Has no effect when
                   `allow_null=True`.
        allow_missing: whether to allow `column` to be a string missing
                       from `obs`/`var`, returning `None` in this case
        allow_null: whether to allow `column` to contain `null` values
        custom_error: a custom error message for when `column` is a string
                      and is not found in `obs`/`var`, and
                      `allow_missing=False`; use `{}` as a placeholder for
                      the name of the column

    Returns:
        A polars Series of the same length as `obs`/`var`, or `None` if the
        column is missing from `obs`/`var` and `allow_missing=True`.

    Raises:
        ValueError: if `column` names a missing column (and
                    `allow_missing=False`), is an expression that does not
                    expand to exactly one column, is not 1D, has the wrong
                    length, or contains disallowed `null` values
        TypeError: if `column` (or the value returned by a callable
                   `column`) has an unsupported type
    """
    obs_or_var = self._obs if obs_or_var_name == 'obs' else self._var
    if isinstance(column, str):
        # Include the column's name in all subsequent error messages
        variable_name = f'{variable_name} {column!r}'
        if column in obs_or_var:
            column = obs_or_var[column]
        elif allow_missing:
            return None
        else:
            error_message = \
                f'{variable_name} is not a column of {obs_or_var_name}' \
                if custom_error is None else \
                custom_error.format(f'{column!r}')
            raise ValueError(error_message)
    elif isinstance(column, pl.Expr):
        column = obs_or_var.select(column)
        # Use `!= 1` rather than `> 1`: an expression that selects no
        # columns at all would otherwise fall through to `to_series()`
        # and fail there with an unrelated error
        if column.width != 1:
            error_message = (
                f'{variable_name} is a polars expression that expands to '
                f'{column.width:,} columns rather than 1')
            raise ValueError(error_message)
        column = column.to_series()
    elif isinstance(column, pl.Series):
        if len(column) != len(obs_or_var):
            error_message = (
                f'{variable_name} is a polars Series of length '
                f'{len(column):,}, which differs from the length of '
                f'{obs_or_var_name} ({len(obs_or_var):,})')
            raise ValueError(error_message)
    elif isinstance(column, np.ndarray):
        # Mirror the dimensionality check applied below to arrays
        # returned by functions; this also rejects 0D arrays, for which
        # `len()` is undefined
        if column.ndim != 1:
            error_message = (
                f'{variable_name} is a {column.ndim:,}D NumPy array, but '
                f'must be a polars Series or 1D NumPy array')
            raise ValueError(error_message)
        if len(column) != len(obs_or_var):
            error_message = (
                f'{variable_name} is a NumPy array of length '
                f'{len(column):,}, which differs from the length of '
                f'{obs_or_var_name} ({len(obs_or_var):,})')
            raise ValueError(error_message)
        column = pl.Series(variable_name, column)
    elif callable(column):
        column = column(self)
        if isinstance(column, np.ndarray):
            if column.ndim != 1:
                error_message = (
                    f'{variable_name} is a function that returns a '
                    f'{column.ndim:,}D NumPy array, but must return a '
                    f'polars Series or 1D NumPy array')
                raise ValueError(error_message)
            column = pl.Series(variable_name, column)
        elif not isinstance(column, pl.Series):
            error_message = (
                f'{variable_name} is a function that returns a variable '
                f'of type {type(column).__name__}, but must return a '
                f'polars Series or 1D NumPy array')
            raise TypeError(error_message)
        if len(column) != len(obs_or_var):
            error_message = (
                f'{variable_name} is a function that returns a column of '
                f'length {len(column):,}, which differs from the length '
                f'of {obs_or_var_name} ({len(obs_or_var):,})')
            raise ValueError(error_message)
    else:
        error_message = (
            f'{variable_name} must be a string column name, a polars '
            f'expression, a polars Series, a 1D NumPy array, or a '
            f'function that returns a polars Series or 1D NumPy array '
            f'when applied to this SingleCell dataset, but has type '
            f'{type(column).__name__!r}')
        raise TypeError(error_message)
    check_dtype(column, variable_name, dtypes)
    if not allow_null:
        if QC_column is None:
            null_count = column.null_count()
            if null_count > 0:
                error_message = (
                    f'{variable_name} contains {null_count:,} '
                    f'{plural("null value", null_count)}, but must not '
                    f'contain any')
                raise ValueError(error_message)
        else:
            # Only nulls among cells passing QC count as errors
            null_count = (column.is_null() & QC_column).sum()
            if null_count > 0:
                error_message = (
                    f'{variable_name} contains {null_count:,} '
                    f'{plural("null value", null_count)} for cells '
                    f'passing QC, but must not contain any')
                raise ValueError(error_message)
    return column

@staticmethod
def _get_columns(obs_or_var_name: Literal['obs', 'var'],
                 datasets: Sequence[SingleCell],
                 columns: SingleCellColumn | None |
                          Sequence[SingleCellColumn | None],
                 variable_name: str,
                 dtypes: pl.datatypes.classes.DataTypeClass | str |
                         tuple[pl.datatypes.classes.DataTypeClass | str,
                               ...],
                 *,
                 QC_columns: list[pl.Series | None] | None = None,
                 allow_None: bool = True,
                 allow_missing: bool = False,
                 allow_null: bool = False,
                 custom_error: str | None = None) -> \
        list[pl.Series | None]:
    """
    Get a column of the same length as `obs`/`var` from each dataset.

    Args:
        obs_or_var_name: the name of the DataFrame the column is with
                         respect to, i.e. `'obs'` or `'var'`
        datasets: a sequence of SingleCell datasets
        columns: a string naming a column of `obs`/`var`, a polars
                 expression that evaluates to a single column when applied
                 to `obs`/`var`, a polars Series or 1D NumPy array of the
                 same length as `obs`/`var`, or a function that takes in
                 `self` and returns a polars Series or 1D NumPy array of
                 the same length as `obs`/`var`. Or, a Sequence of these,
                 one per dataset in `datasets`. May also be `None` (or a
                 Sequence containing `None`) if `allow_None=True`.
        variable_name: the name of the variable corresponding to `columns`
        dtypes: the required dtype(s) of the columns
        QC_columns: an optional column of cells passing QC for each
                    dataset. If not `None` for a given dataset, the
                    presence of `null` values for that dataset will only
                    raise an error for cells passing QC. Has no effect when
                    `allow_null=True`.
        allow_None: whether to allow `columns` or its elements to be `None`
        allow_missing: whether to allow `columns` to be a string (or
                       contain strings) missing from certain datasets'
                       `obs`/`var`, returning `None` for these datasets
        allow_null: whether to allow `columns` to contain `null` values
        custom_error: a custom error message for when `column` is a string
                      and is not found in `obs`/`var`, and
                      `allow_missing=False`; use `{}` as a placeholder for
                      the name of the column

    Returns:
        A list of polars Series of the same length as `datasets`, where
        each Series has the same length as the corresponding dataset's
        `obs`/`var`. Or, if `columns` is `None` (or if some elements are
        `None`) or missing from `obs`/`var` (when `allow_missing=True`), a
        list of `None` (or where the corresponding elements are `None`).

    Raises:
        TypeError: if `columns` is (or contains) `None` and
                   `allow_None=False`
        ValueError: if `columns` is a Sequence whose length differs from
                    the number of datasets
    """
    if columns is None:
        if not allow_None:
            error_message = f'{variable_name} is None'
            raise TypeError(error_message)
        return [None] * len(datasets)
    if isinstance(columns, Sequence) and not isinstance(columns, str):
        if len(columns) != len(datasets):
            error_message = (
                f'{variable_name} has length {len(columns):,}, but you '
                f'specified {len(datasets):,} datasets')
            raise ValueError(error_message)
        if not allow_None and any(column is None for column in columns):
            error_message = \
                f'{variable_name} contains an element that is None'
            raise TypeError(error_message)
        per_dataset_columns = columns
    else:
        # A single specification is shared by every dataset
        per_dataset_columns = [columns] * len(datasets)
    # `QC_column=None` is `_get_column()`'s default, so datasets without a
    # QC column are handled by the same call as datasets with one
    if QC_columns is None:
        QC_columns = [None] * len(datasets)
    return [dataset._get_column(
                obs_or_var_name=obs_or_var_name, column=column,
                variable_name=variable_name, dtypes=dtypes,
                QC_column=QC_column, allow_null=allow_null,
                allow_missing=allow_missing, custom_error=custom_error)
            if column is not None else None
            for dataset, column, QC_column in
            zip(datasets, per_dataset_columns, QC_columns)]

@staticmethod
def _describe_column(column_name: str, column: SingleCellColumn) -> str:
    """
    Describe a column-name argument in an error message.

    Args:
        column_name: the name of the column-name argument
        column: the value of the column-name argument

    Returns:
        The column's description: just the argument's name unless the
        value is a string (i.e. the column's name in `obs` or `var`), in
        which case also include the value.
    """
    return f'{column_name} {column!r}' \
        if isinstance(column, str) else column_name
def to_scanpy(self, *, QC_column: str | None = 'passed_QC') -> 'AnnData':
    """
    Converts this SingleCell dataset to an AnnData object, the
    representation used by Scanpy.

    Make sure to remove cells failing QC with `sc.filter_obs(QC_column)`
    first, or specify `subset=True` in `qc()`. Alternatively, to include
    cells failing QC in the AnnData object, set `QC_column` to `None`.

    Note that there is no `from_scanpy()`; simply do
    `SingleCell(anndata_object)` to initialize a SingleCell dataset from an
    in-memory AnnData object.

    Args:
        QC_column: if not `None`, raise an error if this column is present
                   in `obs` and not all cells pass QC

    Returns:
        An AnnData object. The counts are passed to AnnData as a SciPy
        `csr_matrix`/`csc_matrix` (AnnData versions older than 0.11.0 do
        not support `csr_array`/`csc_array`).

    Raises:
        TypeError: if `obs` or `var` contain columns with data types other
                   than float, int, String, Enum, Categorical or Boolean
        ValueError: if `QC_column` is present in `obs` and not all cells
                    pass QC

    Note:
        The count matrix is not copied during the conversion to save
        memory. This means modifying the AnnData object's count matrix will
        also modify the original SingleCell dataset. If this behavior is
        undesirable, explicitly copy the dataset before converting with
        `sc.copy().to_scanpy()`.
    """
    # SIGINT is ignored for the duration of the imports, then reset to
    # Python's default handler
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        from anndata import AnnData
        import pandas as pd
    finally:
        signal.signal(signal.SIGINT, signal.default_int_handler)
    valid_dtypes = pl.FLOAT_DTYPES | pl.INTEGER_DTYPES | \
        {pl.String, pl.Enum, pl.Categorical, pl.Boolean}
    for df, df_name in (self._obs, 'obs'), (self._var, 'var'):
        for column, dtype in df.schema.items():
            if dtype.base_type() not in valid_dtypes:
                error_message = (
                    f'{df_name}[{column!r}] has the data type '
                    f'{dtype.base_type()!r}, which is not supported by '
                    f'AnnData')
                raise TypeError(error_message)
    if QC_column is not None:
        check_type(QC_column, 'QC_column', str, 'a string')
        if QC_column in self._obs:
            QCed_cells = self._obs[QC_column]
            check_dtype(QCed_cells, f'obs[{QC_column!r}]', pl.Boolean)
            if QCed_cells.null_count() or not QCed_cells.all():
                error_message = (
                    f'not all cells pass QC; remove cells failing QC with '
                    f'filter_obs({QC_column!r}) or by specifying '
                    f'subset=True in qc(), or set QC_column=None to '
                    f'include them in the AnnData object')
                raise ValueError(error_message)
    # Map Arrow types to pandas' nullable extension dtypes, so that null
    # values survive the conversion from polars to pandas
    type_mapping = {
        pa.int8(): pd.Int8Dtype(), pa.int16(): pd.Int16Dtype(),
        pa.int32(): pd.Int32Dtype(), pa.int64(): pd.Int64Dtype(),
        pa.uint8(): pd.UInt8Dtype(), pa.uint16(): pd.UInt16Dtype(),
        pa.uint32(): pd.UInt32Dtype(), pa.uint64(): pd.UInt64Dtype(),
        pa.string(): pd.StringDtype(storage='pyarrow')
                     if int(pd.__version__.split('.')[0]) >= 2 else object,
        pa.bool_(): pd.BooleanDtype()}

    def to_pandas(df):
        # The first column of `obs`/`var` becomes the pandas index
        return df\
            .to_pandas(split_blocks=True, types_mapper=type_mapping.get)\
            .set_index(df.columns[0])

    # Only the `AnnData` class was imported above, not the `anndata`
    # module, so it must be referenced by its bare name
    return AnnData(
        X=None if self._X is None else
          sparse.csr_matrix(self._X) if isinstance(self._X, csr_array)
          else sparse.csc_matrix(self._X),
        obs=to_pandas(self._obs),
        var=to_pandas(self._var),
        obsm=dict(self._obsm),
        varm=dict(self._varm),
        obsp=dict(self._obsp),
        varp=dict(self._varp),
        uns=Uns._copy_uns(self._uns))
@staticmethod
def _from_seurat(seurat_object_name: str,
                 *,
                 assay: str | None,
                 layer: str,
                 layer_name: str) -> \
        tuple[csr_array | csc_array, pl.DataFrame, pl.DataFrame,
              dict[str, np.ndarray], dict[str, csc_array], UnsDict]:
    """
    Create a SingleCell dataset from an in-memory Seurat object loaded with
    the ryp Python-R bridge. Used by `__init__()` and `from_seurat()`.

    Args:
        seurat_object_name: the name of the Seurat object in the ryp R
                            workspace
        assay: the name of the assay within the Seurat object to load data
               from; if `None`, defaults to the Seurat object's
               `active.assay` attribute (usually `'RNA'`)
        layer: the layer within the active assay (or the assay specified by
               the `assay` argument, if not `None`) to use as `X`. Set to
               `'data'` to load the normalized counts, or `'scale.data'` to
               load the normalized and scaled counts, if available. If
               dense, will be automatically converted to a sparse array.
        layer_name: the name of the variable passed via the `layer`
                    argument

    Returns:
        A length-6 tuple of (`X`, `obs`, `var`, `obsm`, `obsp`, `uns`).

    Raises:
        ValueError: if `assay` or `layer` does not exist in the Seurat
                    object
        TypeError: if the requested layer is neither a dgCMatrix nor a
                   dense matrix
    """
    from ryp import r, to_py
    if assay is None:
        assay = to_py(f'{seurat_object_name}@active.assay')
    elif assay not in to_py(f'Assays({seurat_object_name})',
                            squeeze=False):
        error_message = (
            f'assay {assay!r} does not exist in '
            f'{seurat_object_name}@assays; specify a different assay than '
            f'{assay!r}')
        raise ValueError(error_message)
    assay_slot = f'{seurat_object_name}@assays${assay}'
    # If Seurat v5, merge layers if necessary, and use `$slot` instead of
    # `@slot` for `X` and `meta.data` instead of `meta.features` for `var`
    v5 = to_py(f'inherits({assay_slot}, "Assay5")')
    if v5:
        r(f'.SingleCell.layers = names({assay_slot}@layers)')
        try:
            # `squeeze=False` keeps a length-1 vector of layer names as a
            # vector, so that `in` tests membership rather than substring
            # containment within a single string
            layers = to_py('.SingleCell.layers', squeeze=False)
            if layers is None or layer not in layers:
                # The exact layer doesn't exist. Check for sharded versions
                # like `'counts.1'`, `'counts.2'`, etc.
                r(rf'.SingleCell.layers = grep('
                  rf'"^{layer}\\.[0-9]+$", .SingleCell.layers, '
                  rf'value=TRUE)')
                if to_py('length(.SingleCell.layers)') > 0:
                    r(f'{assay_slot} = JoinLayers('
                      f'{assay_slot}, layers=.SingleCell.layers)')
                else:
                    # Collapse shard names like `'counts.1'`, `'counts.2'`
                    # etc. into just `'counts'`
                    current_assay_layers = to_py(
                        rf'unique(sub("\\.[0-9]+$", "", '
                        rf'names({assay_slot}@layers)))', squeeze=False)
                    if current_assay_layers is None:
                        current_assay_layers = ()
                    error_message = (
                        f'{layer_name} {layer!r} does not exist in the '
                        f'current assay ({assay!r}); specify a different '
                        f'assay than {assay!r} or a different '
                        f'{layer_name} than {layer!r}. The layers in the '
                        f'current assay are: '
                        f'{", ".join(map(repr, current_assay_layers))}')
                    raise ValueError(error_message)
        finally:
            r('rm(".SingleCell.layers")')
        # In Seurat v5, feature metadata (`meta.data`) is per assay,
        # while the features themselves (`features`) are per layer. To
        # reconcile them, we must get the layer-specific feature names and
        # use them to filter the assay-level metadata.
        var = to_py(f'''
            (function() {{
                target_features = {assay_slot}@features[[{layer!r}]]
                meta.data = {assay_slot}@meta.data
                if (nrow(meta.data) == nrow({assay_slot})) {{
                    rownames(meta.data) = rownames({assay_slot})
                }}
                res <- meta.data[target_features, , drop=FALSE]
                rownames(res) = target_features
                res
            }})()''', index='gene')
        X_slot = f'{assay_slot}@layers${layer}'
    else:
        # unlike v5 objects, v3 objects indicate the absence of a layer
        # with a 0 x 0 matrix
        if not (to_py(f'"{layer}" %in% slotNames({assay_slot})') and
                to_py(f'prod(dim({assay_slot}@{layer}))') > 0):
            current_assay_slots = to_py(f"slotNames({assay_slot})",
                                        squeeze=False)
            error_message = (
                f'{layer_name} {layer!r} does not exist in the current '
                f'assay ({assay!r}); specify a different assay than '
                f'{assay!r} or a different {layer_name} than {layer!r}. '
                f'The layers in the current assay are: '
                f'{", ".join(map(repr, current_assay_slots))}')
            raise ValueError(error_message)
        X_slot = f'{assay_slot}@{layer}'
        var = to_py(f'{assay_slot}@meta.features')
    X_classes = tuple(to_py(f'class({X_slot})', squeeze=False))
    if X_classes == ('dgCMatrix',):
        # Seurat stores genes x cells; SingleCell stores cells x genes
        X = to_py(X_slot).T
    elif X_classes == ('matrix', 'array'):
        warning_message = (
            f"this Seurat object's {layer_name} {layer!r} is stored as a "
            f"dense matrix; auto-converting to a sparse csr_array")
        warnings.warn(warning_message)
        X = csr_array(to_py(X_slot, format='numpy').T)
    else:
        error_message = (
            f'{layer_name} {layer!r} exists in {assay_slot} but is not a '
            f'dgCMatrix (column-oriented sparse matrix) or matrix, '
            f'instead having ')
        if len(X_classes) == 0:
            error_message += 'no classes'
        elif len(X_classes) == 1:
            error_message += f'the class {X_classes[0]!r}'
        else:
            # Quote the final class like the others, matching the
            # equivalent message in `from_seurat()`
            error_message += (
                f'the classes '
                f'{", ".join(f"{c!r}" for c in X_classes[:-1])} and '
                f'{X_classes[-1]!r}')
        error_message += (
            f'; specify a different assay than {assay!r} or a different '
            f'{layer_name} than {layer!r}')
        raise TypeError(error_message)
    obs_key = f'{seurat_object_name}@meta.data'
    # Name the index column `'_index'` if `meta.data` already has a column
    # called `'cell'`. The check must be against the column names: `%in%`
    # applied to the data frame itself would compare against its values.
    obs = to_py(obs_key, index='_index'
                if to_py(f'"cell" %in% colnames({obs_key})') else 'cell')
    if var is None:
        var = to_py(f'rownames({assay_slot}@{layer})').to_frame('gene')
    obs = obs.cast({column.name: pl.Enum(column.cat.get_categories())
                    for column in obs.select(pl.col(pl.Categorical))})
    var = var.cast({column.name: pl.Enum(column.cat.get_categories())
                    for column in var.select(pl.col(pl.Categorical))})
    # `squeeze=False` ensures that a single reduction (or graph) name is
    # still returned as a vector of names rather than as one string, which
    # would otherwise be iterated over character by character
    reduction_names = to_py(f'names({seurat_object_name}@reductions)',
                            squeeze=False)
    obsm = {reduction_name:
            to_py(f'{seurat_object_name}@reductions$'
                  f'{reduction_name}@cell.embeddings', format='numpy')
            for reduction_name in reduction_names
            if not to_py(f'is.null({seurat_object_name}@reductions$'
                         f'{reduction_name})') and
            to_py(f'{seurat_object_name}@reductions${reduction_name}'
                  f'@assay.used') == assay} \
        if reduction_names is not None else {}
    graph_names = to_py(f'names({seurat_object_name}@graphs)',
                        squeeze=False)
    # Keep only graphs computed from this assay, or from no specific assay
    obsp = {graph_name: csc_array((
                to_py(f'{seurat_object_name}${graph_name}@x',
                      format='numpy'),
                to_py(f'{seurat_object_name}${graph_name}@i',
                      format='numpy'),
                to_py(f'{seurat_object_name}${graph_name}@p',
                      format='numpy')),
                shape=to_py(f'{seurat_object_name}${graph_name}@Dim',
                            format='numpy'))
            for graph_name in graph_names
            if to_py(f'length({seurat_object_name}${graph_name}@'
                     f'assay.used)') == 0 or
            to_py(f'{seurat_object_name}${graph_name}@assay.used') == assay
            } if graph_names is not None else {}
    uns = to_py(f'{seurat_object_name}@misc')
    if not uns:
        # uns may be an empty unnamed list in R, which ryp converts to a
        # Python list rather than a dictionary
        uns = {}
    return X, obs, var, obsm, obsp, uns
@staticmethod
def from_seurat(seurat_object_name: str,
                *,
                assay: str | None = None,
                layer: str = 'counts',
                num_threads: int | np.integer = -1) -> SingleCell:
    """
    Build a SingleCell dataset from a Seurat object that is already
    present in the R workspace of the ryp Python-R bridge.

    To load a Seurat object from disk, use e.g.
    `SingleCell('filename.rds')`. Both version 3 and version 5 Seurat
    objects are supported.

    Args:
        seurat_object_name: the name of the Seurat object in the ryp R
                            workspace
        assay: the name of the assay within the Seurat object to load data
               from; if `None`, defaults to the Seurat object's
               `active.assay` attribute (usually `'RNA'`)
        layer: the layer within the active assay (or the assay specified by
               the `assay` argument, if not `None`) to use as `X`. Defaults
               to `'counts'`. Set to `'data'` to load the normalized
               counts, or `'scale.data'` to load the normalized and scaled
               counts, if available. If dense, will be automatically
               converted to a sparse array.
        num_threads: the default number of threads to use for all
                     subsequent operations on this SingleCell dataset. Also
                     sets the number of threads for this SingleCell
                     dataset's count matrix, if present. Does not affect
                     the number of threads used for data loading; this will
                     always be single-threaded for Seurat objects.

    Returns:
        The corresponding SingleCell dataset.

    Raises:
        TypeError: if the R object named by `seurat_object_name` is not a
                   Seurat object
    """
    from ryp import to_py
    # Validate the arguments
    check_type(seurat_object_name, 'seurat_object_name', str, 'a string')
    check_R_variable_name(seurat_object_name, 'seurat_object_name')
    if assay is not None:
        check_type(assay, 'assay', str, 'a string')
    check_type(layer, 'layer', str, 'a string')
    num_threads = SingleCell._process_num_threads_static(num_threads)
    # Refuse anything that is not a Seurat object, reporting the R
    # class(es) the object actually has
    is_seurat = to_py(f'inherits({seurat_object_name}, "Seurat")')
    if not is_seurat:
        R_classes = to_py(f'class({seurat_object_name})', squeeze=False)
        num_classes = len(R_classes)
        if num_classes == 0:
            class_description = 'no classes'
        elif num_classes == 1:
            class_description = f'the class {R_classes[0]!r}'
        else:
            leading_classes = ", ".join(f"{c!r}" for c in R_classes[:-1])
            class_description = (
                f'the classes {leading_classes} and {R_classes[-1]!r}')
        error_message = (
            f'the R object named by seurat_object_name, '
            f'{seurat_object_name}, must be a Seurat object, but has '
            f'{class_description}')
        raise TypeError(error_message)
    # Pull each component out of the Seurat object, then assemble them
    components = SingleCell._from_seurat(
        seurat_object_name, assay=assay, layer=layer, layer_name='layer')
    X, obs, var, obsm, obsp, uns = components
    return SingleCell(X=X, obs=obs, var=var, obsm=obsm, obsp=obsp,
                      uns=uns, num_threads=num_threads)
def to_seurat(self,
              seurat_object_name: str,
              /,
              *,
              QC_column: str | None = 'passed_QC',
              assay: str = 'RNA',
              layer: str = 'counts',
              empty_X: bool = False,
              v3: bool = False) -> None:
    """
    Convert this SingleCell dataset to a Seurat object in the R workspace
    of the ryp Python-R bridge.

    Make sure to remove cells failing QC with `sc.filter_obs(QC_column)`
    first, or specify `subset=True` in `qc()`. Alternatively, to include
    cells failing QC in the Seurat object, set `QC_column` to `None`.

    When converting to Seurat, to match the requirements of Seurat
    objects, the `'X_'` prefix (often used by Scanpy) will be removed from
    each key of `obsm` where it is present (e.g. `'X_umap'` will become
    `'umap'`). Seurat will also add `'orig.ident'`, `'nCount_RNA'` and
    `'nFeature_RNA'` as cell-level metadata by default; you can disable
    the calculation of the latter two columns with:

    ```python
    from ryp import r
    r('options(Seurat.object.assay.calcn = FALSE)')
    ```

    `varm` and DataFrame keys of `obsm` will not be converted.

    Args:
        seurat_object_name: the name of the R variable to assign the
                            Seurat object to
        QC_column: if not `None`, raise an error if this column is present
                   in `obs` and not all cells pass QC
        assay: the name to use for the active assay
        layer: the name of the layer within the active assay to save `X`
               to; must be `'counts'` or `'data'`
        empty_X: if `True`, allow converting SingleCell datasets with
                 missing counts (i.e. where `self.X is None`), by using an
                 empty, dummy matrix with no non-zero entries as the
                 Seurat object's count matrix
        v3: if `True`, save to Seurat's version 3 format instead of the
            more current version 5 format. When `v3=True`, the Seurat
            object's assay is created with `CreateAssayObject()` instead
            of `CreateAssay5Object()`.

    Raises:
        ValueError: if `empty_X` is inconsistent with whether `X` is
                    present, `X` has unsorted indices or more than
                    INT32_MAX non-zero elements, some cells fail QC,
                    `layer` is invalid, `obs_names`/`var_names` contain
                    nulls, or gene names contain underscores
        TypeError: if `X` is a `csc_array`, a column of `obs`/`var` has an
                   unsupported data type, or an array in `obsm`/`obsp` is
                   neither integer nor floating-point

    Note:
        Seurat objects do not support count matrices with more than
        2,147,483,647 (INT32_MAX) non-zero elements. If your SingleCell
        dataset is larger than this, it cannot be converted!

    Note:
        Seurat requires the counts to be float64 and the indices of the
        count matrix to be int32. To avoid copying data when converting,
        consider converting the counts to float64 with
        `sc.X = sc.X.astype(np.float64, copy=False)` and the indices with
        `sc.X.shrink_indices()` (an in-place operation).

    Note:
        In the special case where the counts are float64 or the indices of
        the count matrix are int32, the counts or indices will not be
        copied during the conversion to save memory. This means modifying
        the Seurat object's count matrix will also modify the original
        SingleCell dataset. If this behavior is undesirable, explicitly
        copy the dataset before converting with
        `sc.copy().to_seurat()`.
    """
    def to_R_dtype(array, array_name):
        # Coerce `array` to a data type R can represent: float64 for
        # floating-point data and int32 for integer data
        dtype = array.dtype
        if np.issubdtype(dtype, np.floating):
            return array.astype(np.float64, copy=False)
        if np.issubdtype(dtype, np.integer):
            # ryp converts uint32 to int32 zero-copy, so skip the cast
            if dtype == np.uint32:
                return array
            return array.astype(np.int32, copy=False)
        error_message = (
            f'{array_name} has data type {dtype}, but it must be integer '
            f'or floating-point to be convertible to Seurat')
        raise TypeError(error_message)

    # Check that `empty_X` is `True` when `X` is absent, and vice versa.
    # If `empty_X` is `True`, make an empty R-compatible count matrix.
    check_type(empty_X, 'empty_X', bool, 'Boolean')
    if empty_X:
        if self._X is not None:
            error_message = \
                'X is not None, so empty_X=True cannot be specified'
            raise ValueError(error_message)
        X = csr_array((np.array([], dtype=np.float64),
                       np.array([], dtype=np.int32),
                       np.array([0], dtype=np.int32)),
                      shape=(0, 0))
    else:
        if self._X is None:
            error_message = (
                'X is None, so converting to Seurat requires creating an '
                'empty count matrix; specify empty_X=True to do this '
                'automatically when saving')
            raise ValueError(error_message)
        X = self._X
        # Check that `X` is CSR
        if isinstance(self._X, csc_array):
            error_message = (
                'X is a csc_array; run `sc.tocsr()` before converting to '
                'Seurat')
            raise TypeError(error_message)
        # Check that `X` has sorted indices
        if not X.has_sorted_indices:
            error_message = (
                'X does not have sorted indices; run '
                '`sc.X.sort_indices()` (an in-place operation) before '
                'converting to Seurat')
            raise ValueError(error_message)
        # Check that `X` does not have too many non-zero elements
        if X.nnz > 2_147_483_647:
            error_message = (
                f'X has {X.nnz:,} non-zero elements, more than '
                f'2,147,483,647 (INT32_MAX), the maximum supported in R')
            raise ValueError(error_message)

    from ryp import r, to_r
    r('suppressPackageStartupMessages(library(Seurat))')
    check_type(seurat_object_name, 'seurat_object_name', str, 'a string')
    check_R_variable_name(seurat_object_name, 'seurat_object_name')
    if QC_column is not None:
        check_type(QC_column, 'QC_column', str, 'a string')
        if QC_column in self._obs:
            QCed_cells = self._obs[QC_column]
            check_dtype(QCed_cells, f'obs[{QC_column!r}]', pl.Boolean)
            if QCed_cells.null_count() or not QCed_cells.all():
                error_message = (
                    f'not all cells pass QC; remove cells failing QC with '
                    f'filter_obs({QC_column!r}) or by specifying '
                    f'subset=True in qc(), or set QC_column=None to '
                    f'include them in the Seurat object')
                raise ValueError(error_message)
    check_type(assay, 'assay', str, 'a string')
    check_type(layer, 'layer', str, 'a string')
    check_type(v3, 'v3', bool, 'Boolean')
    if layer not in ('counts', 'data'):
        error_message = (
            f"when converting to a Seurat object, layer must be 'counts' "
            f"or 'data', not {layer!r}")
        raise ValueError(error_message)
    valid_dtypes = pl.FLOAT_DTYPES | pl.INTEGER_DTYPES | \
        {pl.String, pl.Enum, pl.Categorical, pl.Boolean}
    for df, df_name in (self._obs, 'obs'), (self._var, 'var'):
        for column, dtype in df.schema.items():
            if dtype.base_type() not in valid_dtypes:
                error_message = (
                    f'{df_name}[{column!r}] has the data type '
                    f'{dtype.base_type()!r}, which is not supported when '
                    f'converting to a Seurat object')
                raise TypeError(error_message)
    obs_names = self.obs_names
    var_names = self.var_names
    for names, names_name in (obs_names, 'obs_names'), \
            (var_names, 'var_names'):
        null_count = names.null_count()
        if null_count:
            error_message = (
                f'{names_name} contains {null_count:,} null values, but '
                f'must not contain any when converting to a Seurat object')
            raise ValueError(error_message)
    # Seurat does not support underscores in gene names; for
    # Categorical/Enum `var_names`, only the categories need checking
    is_string = var_names.dtype == pl.String
    num_with_underscores = var_names.str.contains('_').sum() \
        if is_string else \
        var_names.cat.get_categories().str.contains('_').sum()
    if num_with_underscores:
        var_names_expression = f'pl.col.{var_names.name}' \
            if var_names.name.isidentifier() else \
            f'pl.col({var_names.name!r})'
        error_message = (
            f"var_names contains {num_with_underscores:,}"
            f"{'' if is_string else ' unique'} gene "
            f"{plural('name', num_with_underscores)} with "
            f"underscores, which are not supported by Seurat; Seurat "
            f"recommends changing the underscores to dashes, which you "
            f"can do with .with_columns_var({var_names_expression}"
            f"{'' if is_string else '.cast(pl.String)'}"
            f".str.replace_all('_', '-'))")
        raise ValueError(error_message)

    # The R sparse matrix class Seurat uses (dgCMatrix) requires float64
    # data and int32 indices and indptr. Convert each of these three to the
    # correct type, if not already correct.
    X = X.astype(np.float64, copy=False)
    if X.indices.dtype == np.int64:
        X = X.shrink_indices(copy=True)

    # Derive each cell's identity from the part of its name before the
    # first underscore
    if obs_names.dtype in pl.INTEGER_DTYPES:
        idents = obs_names.cast(pl.String)  # dummy idents
    elif obs_names.dtype == pl.String:
        idents = obs_names\
            .str.splitn('_', 2)\
            .struct.field('field_0')\
            .cast(pl.Categorical)
    else:  # Categorical or Enum
        idents = obs_names.to_frame().join(
            obs_names.to_frame().unique().with_columns(
                idents=pl.first()
                .cast(pl.String)
                .str.splitn('_', 2)
                .struct.field('field_0').cast(pl.Categorical)),
            on=obs_names.name)['idents']

    # Create the Seurat object, with just counts (X) and meta.data (obs)
    try:
        to_r(obs_names, '.SingleCell.obs_names')
        to_r(var_names, '.SingleCell.var_names')
        to_r(X.T, '.SingleCell.X.T')
        r('rownames(.SingleCell.X.T) = .SingleCell.var_names')
        r('colnames(.SingleCell.X.T) = .SingleCell.obs_names')
        to_r(self._obs.drop(obs_names.name)
             .with_columns(idents.alias('orig.idents'), pl.all()),
             '.SingleCell.obs')
        r('rownames(.SingleCell.obs) = .SingleCell.obs_names')
        r('.SingleCell.active_ident = .SingleCell.obs$orig.idents')
        r('names(.SingleCell.active_ident) = .SingleCell.obs_names')
        r(f'.SingleCell.assays = list(CreateAssay{"" if v3 else "5"}'
          f'Object({layer}=.SingleCell.X.T))')
        r(f'names(.SingleCell.assays) = {assay!r}')
        r(f'Key(.SingleCell.assays[[{assay!r}]]) = "{assay.lower()}_"')
        r(f'''.SingleCell.seurat = new(
            Class = 'Seurat',
            assays = .SingleCell.assays,
            meta.data = .SingleCell.obs,
            active.assay = {assay!r},
            active.ident = .SingleCell.active_ident,
            project.name = 'SeuratProject',
            version = packageVersion(pkg = 'SeuratObject'),
            graphs = list(),
            neighbors = list(),
            reductions = list(),
            images = list(),
            misc = list(),
            commands = list(),
            tools = list())''')

        # Add var
        to_r(self._var.drop(var_names.name), '.SingleCell.var',
             rownames=var_names)
        r(f'.SingleCell.seurat@assays${assay}@'
          f'meta.{"features" if v3 else "data"} = .SingleCell.var')

        # Add obsm
        if self._obsm:
            for original_key, value in self._obsm.items():
                if isinstance(value, pl.DataFrame):
                    continue
                value = to_R_dtype(value, f'obsm[{original_key!r}]')
                # Remove the initial X_ from the reduction name and suffix
                # with `'_'` when creating the key, like
                # mojaveazure.github.io/seurat-disk/reference/Convert.html
                key = original_key.removeprefix('X_')
                to_r(value, '.SingleCell.value', colnames=[
                    f'{key}_{i}' for i in range(1, value.shape[1] + 1)])
                r('rownames(.SingleCell.value) = .SingleCell.obs_names')
                r(f'.SingleCell.seurat@reductions${key} = '
                  f'CreateDimReducObject(.SingleCell.value, '
                  f'key="{key}_", assay="{assay}")')

        # Add obsp
        if self._obsp:
            obsp = {}
            for key, value in self._obsp.items():
                value = to_R_dtype(value, f'obsp[{key!r}]')
                if value.indices.dtype == np.int64:
                    value = value.shrink_indices(copy=True)
                obsp[key] = value
            to_r(obsp, '.SingleCell.obsp', rownames=obs_names,
                 colnames=obs_names)
            r('''
                for (i in seq_along(.SingleCell.obsp)) {
                    rownames(.SingleCell.obsp[[i]]) <- .SingleCell.obs_names
                    colnames(.SingleCell.obsp[[i]]) <- .SingleCell.obs_names
                }''')
            r('.SingleCell.seurat@graphs = '
              'lapply(.SingleCell.obsp, as.Graph)')
            r('rm(.SingleCell.obsp)')

        # Add uns
        if self._uns:
            to_r(self._uns, '.SingleCell.uns')
            r('.SingleCell.seurat@misc = .SingleCell.uns')

        # Atomically assign the Seurat object to its final name at the end
        r(f'{seurat_object_name} = .SingleCell.seurat')
    finally:
        # Remove whichever temporary R variables were created, even if the
        # conversion failed partway through
        r('rm(list = Filter(exists, c(".SingleCell.obs_names", '
          '".SingleCell.var_names", ".SingleCell.X.T", '
          '".SingleCell.obs", ".SingleCell.active_ident", '
          '".SingleCell.assays", ".SingleCell.seurat", '
          '".SingleCell.var", ".SingleCell.value", ".SingleCell.obsp", '
          '".SingleCell.uns")))')
@staticmethod
def _get_DFrame(dframe: str, *, index: str) -> pl.DataFrame:
    """
    Convert a DFrame in the ryp R workspace to a Python DataFrame, raising
    an error if the DFrame contains any nested data structures.

    Args:
        dframe: the name of the DFrame object in the ryp R workspace
        index: the name of the column containing the rownames in the
               output DataFrame; if this column name is already present in
               the DFrame, fall back to calling the rownames column
               `'_index'`

    Returns:
        A polars DataFrame containing the data in `dframe`.

    Raises:
        ValueError: if any element of `dframe` does not convert to a
                    polars Series
    """
    from ryp import to_py
    df = to_py(f'{dframe}@listData', index=False)
    if not all(isinstance(value, pl.Series) for value in df.values()):
        error_message = (
            f'{dframe} contains nested data; unnest before converting to '
            f'a SingleCell dataset')
        raise ValueError(error_message)
    if index in df:
        index = '_index'
    # Put the rownames first, so they become the first column
    df = pl.DataFrame({index: to_py(f'{dframe}@rownames')} | df)
    return df


@staticmethod
def _from_sce(sce_object_name: str,
              *,
              assay: str,
              assay_name: str) -> \
        tuple[csr_array | csc_array, pl.DataFrame, pl.DataFrame,
              dict[str, np.ndarray], UnsDict]:
    """
    Create a SingleCell dataset from an in-memory SingleCellExperiment
    object loaded with the ryp Python-R bridge. Used by `__init__()` and
    `from_sce()`.

    Args:
        sce_object_name: the name of the SingleCellExperiment object in
                         the ryp R workspace
        assay: the element within `sce_object@assays@data` to use as `X`.
               Set to `'counts'` to load raw counts, or `'logcounts'` to
               load the normalized counts if available. If dense, will be
               automatically converted to a sparse array.
        assay_name: the name of the variable passed via the `assay`
                    argument

    Returns:
        A length-5 tuple of (`X`, `obs`, `var`, `obsm`, `uns`).

    Raises:
        ValueError: if `assay` is not present in the object's assays
        TypeError: if `assay` is neither a dgCMatrix nor a dense matrix
    """
    from ryp import to_py
    X_slot = f'{sce_object_name}@assays@data${assay}'
    if not to_py(f'"{assay}" %in% names({sce_object_name}@assays@data)'):
        error_message = (
            f'{assay_name} {assay!r} does not exist in '
            f'{sce_object_name}@assays@data${assay}; specify a different '
            f'{assay_name} than {assay!r}')
        raise ValueError(error_message)
    X_classes = tuple(to_py(f'class({X_slot})', squeeze=False))
    if X_classes == ('dgCMatrix',):
        # R stores genes as rows; transpose so that cells are rows
        X = to_py(X_slot).T
    elif X_classes == ('matrix', 'array'):
        warning_message = (
            f"this SingleCellExperiment object's {assay_name} {assay!r} "
            f"is stored as a dense matrix; auto-converting to a sparse "
            f"csr_array")
        warnings.warn(warning_message)
        X = csr_array(to_py(X_slot, format='numpy').T)
    else:
        error_message = (
            f'{assay_name} {assay!r} exists in '
            f'{sce_object_name}@assays@data${assay} but is not a '
            f'dgCMatrix (column-oriented sparse matrix) or matrix, '
            f'instead having ')
        if len(X_classes) == 0:
            error_message += 'no classes'
        elif len(X_classes) == 1:
            error_message += f'the class {X_classes[0]!r}'
        else:
            error_message += (
                f'the classes '
                f'{", ".join(f"{c!r}" for c in X_classes[:-1])} and '
                f'{X_classes[-1]!r}')
        error_message += (
            f'; specify a different {assay_name} than {assay!r}')
        raise TypeError(error_message)
    obs = SingleCell._get_DFrame(f'colData({sce_object_name})',
                                 index='cell')
    var = SingleCell._get_DFrame(f'rowData({sce_object_name})',
                                 index='gene')
    # Convert Categorical columns to Enums with the same categories
    obs = obs.cast({column.name: pl.Enum(column.cat.get_categories())
                    for column in obs.select(pl.col(pl.Categorical))})
    var = var.cast({column.name: pl.Enum(column.cat.get_categories())
                    for column in var.select(pl.col(pl.Categorical))})
    obsm = to_py(f'reducedDims({sce_object_name})@listData',
                 format='numpy')
    uns = to_py(f'{sce_object_name}@metadata', format='numpy')
    # `@metadata` can be a non-empty unnamed list (though this is not
    # recommended), but `uns` requires names, so add dummy ones
    if isinstance(uns, list):
        uns = {f'element_{i}': uns_element
               for i, uns_element in enumerate(uns)}
    return X, obs, var, obsm, uns
@staticmethod
def from_sce(sce_object_name: str,
             *,
             assay: str = 'counts',
             num_threads: int | np.integer = -1) -> SingleCell:
    """
    Build a SingleCell dataset from a SingleCellExperiment object that is
    already in memory in the R workspace of the ryp Python-R bridge.

    To load a SingleCellExperiment object from disk instead, use e.g.
    `SingleCell('filename.rds')`.

    Args:
        sce_object_name: the name of the SingleCellExperiment object in
                         the ryp R workspace
        assay: the element within `{sce_object_name}@assays@data` to use
               as `X`. Defaults to `'counts'`. If available, set to
               `'logcounts'` to load the normalized counts. If dense, will
               be automatically converted to a sparse array.
        num_threads: the default number of threads to use for all
                     subsequent operations on this SingleCell dataset.
                     Also sets the number of threads for this SingleCell
                     dataset's count matrix, if present. Does not affect
                     the number of threads used for data loading; this
                     will always be single-threaded for
                     SingleCellExperiment objects.

    Returns:
        The corresponding SingleCell dataset.

    Raises:
        TypeError: if the R object named by `sce_object_name` does not
                   inherit from SingleCellExperiment
    """
    from ryp import r, to_py
    r('suppressPackageStartupMessages(library(SingleCellExperiment))')

    # Validate the arguments
    check_type(sce_object_name, 'sce_object_name', str, 'a string')
    check_R_variable_name(sce_object_name, 'sce_object_name')
    check_type(assay, 'assay', str, 'a string')
    num_threads = SingleCell._process_num_threads_static(num_threads)

    # Make sure the R object really is a SingleCellExperiment; if not,
    # report which classes it has instead
    is_sce = to_py(
        f'inherits({sce_object_name}, "SingleCellExperiment")')
    if not is_sce:
        r_classes = to_py(f'class({sce_object_name})', squeeze=False)
        num_classes = len(r_classes)
        if num_classes == 0:
            description = 'no classes'
        elif num_classes == 1:
            description = f'the class {r_classes[0]!r}'
        else:
            all_but_last = ", ".join(f"{c!r}" for c in r_classes[:-1])
            description = \
                f'the classes {all_but_last} and {r_classes[-1]!r}'
        raise TypeError(
            f'the R object named by sce_object_name, {sce_object_name}, '
            f'must be a SingleCellExperiment object, but has '
            f'{description}')

    # Pull each component out of the R object and assemble the dataset
    components = SingleCell._from_sce(sce_object_name, assay=assay,
                                      assay_name='assay')
    X, obs, var, obsm, uns = components
    return SingleCell(X=X, obs=obs, var=var, obsm=obsm, uns=uns,
                      num_threads=num_threads)
def to_sce(self,
           sce_object_name: str,
           /,
           *,
           QC_column: str | None = 'passed_QC',
           assay: str = 'counts',
           empty_X: bool = False) -> None:
    """
    Convert this SingleCell dataset to a SingleCellExperiment object in
    the R workspace of the ryp Python-R bridge.

    Make sure to remove cells failing QC with `sc.filter_obs(QC_column)`
    first, or specify `subset=True` in `qc()`. Alternatively, to include
    cells failing QC in the SingleCellExperiment object, set `QC_column`
    to `None`.

    `varm`, `obsp`, `varp`, and DataFrame keys of `obsm` will not be
    converted.

    Args:
        sce_object_name: the name of the R variable to assign the
                         SingleCellExperiment object to
        QC_column: if not `None`, raise an error if this column is present
                   in `obs` and not all cells pass QC
        assay: the name of the slot within
               `{sce_object_name}@assays@data` to save `X` to
        empty_X: if `True`, allow converting SingleCell datasets with
                 missing counts (i.e. where `self.X is None`), by using an
                 empty, dummy matrix with no non-zero entries as the
                 SingleCellExperiment object's count matrix

    Raises:
        ValueError: if `empty_X` is inconsistent with whether `X` is
                    present, `X` has unsorted indices or more than
                    INT32_MAX non-zero elements, some cells fail QC, or
                    `obs_names`/`var_names` contain nulls
        TypeError: if `X` is a `csc_array`, or a column of `obs`/`var` has
                   an unsupported data type

    Note:
        SingleCellExperiment objects do not support count matrices with
        more than 2,147,483,647 (INT32_MAX) non-zero elements. If your
        SingleCell dataset is larger than this, it cannot be converted!

    Note:
        SingleCellExperiment requires the counts to be float64 and the
        indices of the count matrix to be int32. To avoid copying data
        when converting, consider converting the counts to float64 with
        `sc.X = sc.X.astype(np.float64, copy=False)` and the indices with
        `sc.X.shrink_indices()` (an in-place operation).

    Note:
        In the special case where the counts are float64 or the indices of
        the count matrix are int32, the counts or indices will not be
        copied during the conversion to save memory. This means modifying
        the SingleCellExperiment object's count matrix will also modify
        the original SingleCell dataset. If this behavior is undesirable,
        explicitly copy the dataset before converting with
        `sc.copy().to_sce()`.
    """
    # Check that `empty_X` is `True` when `X` is absent, and vice versa.
    # If `empty_X` is `True`, make an empty R-compatible count matrix.
    check_type(empty_X, 'empty_X', bool, 'Boolean')
    if empty_X:
        if self._X is not None:
            error_message = \
                'X is not None, so empty_X=True cannot be specified'
            raise ValueError(error_message)
        X = csr_array((np.array([], dtype=np.float64),
                       np.array([], dtype=np.int32),
                       np.array([0], dtype=np.int32)),
                      shape=(0, 0))
    else:
        if self._X is None:
            error_message = (
                'X is None, so converting to SingleCellExperiment '
                'requires creating an empty count matrix; specify '
                'empty_X=True to do this automatically when saving')
            raise ValueError(error_message)
        X = self._X
        # Check that `X` is CSR
        if isinstance(self._X, csc_array):
            error_message = (
                'X is a csc_array; run `sc.tocsr()` before converting to '
                'SingleCellExperiment')
            raise TypeError(error_message)
        # Check that `X` has sorted indices
        if not X.has_sorted_indices:
            error_message = (
                'X does not have sorted indices; run '
                '`sc.X.sort_indices()` (an in-place operation) before '
                'converting to SingleCellExperiment')
            raise ValueError(error_message)
        # Check that `X` does not have too many non-zero elements
        if X.nnz > 2_147_483_647:
            error_message = (
                f'X has {X.nnz:,} non-zero elements, more than '
                f'2,147,483,647 (INT32_MAX), the maximum supported in R')
            raise ValueError(error_message)

    from ryp import r, to_r
    r('suppressPackageStartupMessages(library(SingleCellExperiment))')
    check_type(sce_object_name, 'sce_object_name', str, 'a string')
    check_R_variable_name(sce_object_name, 'sce_object_name')
    if QC_column is not None:
        check_type(QC_column, 'QC_column', str, 'a string')
        if QC_column in self._obs:
            QCed_cells = self._obs[QC_column]
            check_dtype(QCed_cells, f'obs[{QC_column!r}]', pl.Boolean)
            if QCed_cells.null_count() or not QCed_cells.all():
                error_message = (
                    f'not all cells pass QC; remove cells failing QC with '
                    f'filter_obs({QC_column!r}) or by specifying '
                    f'subset=True in qc(), or set QC_column=None to '
                    f'include them in the SingleCellExperiment object')
                raise ValueError(error_message)
    check_type(assay, 'assay', str, 'a string')
    valid_dtypes = pl.FLOAT_DTYPES | pl.INTEGER_DTYPES | \
        {pl.String, pl.Enum, pl.Categorical, pl.Boolean}
    for df, df_name in (self._obs, 'obs'), (self._var, 'var'):
        for column, dtype in df.schema.items():
            if dtype.base_type() not in valid_dtypes:
                error_message = (
                    f'{df_name}[{column!r}] has the data type '
                    f'{dtype.base_type()!r}, which is not supported when '
                    f'converting to a SingleCellExperiment object')
                raise TypeError(error_message)

    # The R sparse matrix class SingleCellExperiment uses (dgCMatrix)
    # requires float64 data and int32 indices and indptr. Convert each of
    # these three to the correct type, if not already correct.
    X = X.astype(np.float64, copy=False)
    if X.indices.dtype == np.int64:
        X = X.shrink_indices(copy=True)

    try:
        obs_names = self.obs_names
        var_names = self.var_names
        for names, names_name in (obs_names, 'obs_names'), \
                (var_names, 'var_names'):
            null_count = names.null_count()
            if null_count:
                error_message = (
                    f'{names_name} contains {null_count:,} null values, '
                    f'but must not contain any when converting to a '
                    f'SingleCellExperiment object')
                raise ValueError(error_message)

        # Create the SingleCellExperiment object from X, obs and var
        to_r(X.T, '.SingleCell.X.T')
        to_r(self._obs.drop(obs_names.name), '.SingleCell.obs',
             rownames=obs_names)
        to_r(self._var.drop(var_names.name), '.SingleCell.var',
             rownames=var_names)
        r('.SingleCell.obs_names = rownames(.SingleCell.obs)')
        r('.SingleCell.var_names = rownames(.SingleCell.var)')
        r('rownames(.SingleCell.X.T) = .SingleCell.var_names')
        r('colnames(.SingleCell.X.T) = .SingleCell.obs_names')
        r(f'.SingleCell.sce = SingleCellExperiment('
          f'assays = list({assay} = .SingleCell.X.T), '
          f'colData = S4Vectors::DataFrame('
          f'    .SingleCell.obs, check.names=FALSE), '
          f'rowData = S4Vectors::DataFrame('
          f'    .SingleCell.var, check.names=FALSE))')

        # Add obsm (NumPy arrays only)
        if self._obsm:
            for key, value in self._obsm.items():
                if isinstance(value, pl.DataFrame):
                    continue
                to_r(value, '.SingleCell.value', colnames=[
                    f'{key}_{i}' for i in range(1, value.shape[1] + 1)])
                r('rownames(.SingleCell.value) = .SingleCell.obs_names')
                r(f'reducedDim(.SingleCell.sce, {key!r}) = '
                  f'.SingleCell.value')

        # Add uns
        if self._uns:
            to_r(self._uns, '.SingleCell.uns')
            r('.SingleCell.sce@metadata = .SingleCell.uns')

        # Assign the object to its final name only once it is complete
        r(f'{sce_object_name} = .SingleCell.sce')
    finally:
        # Remove whichever temporary R variables were created, even if the
        # conversion failed partway through
        r('rm(list = Filter(exists, c(".SingleCell.obs_names", '
          '".SingleCell.var_names", ".SingleCell.X.T", '
          '".SingleCell.obs", ".SingleCell.var", ".SingleCell.sce", '
          '".SingleCell.value", ".SingleCell.uns")))')
def copy(self,
         *,
         deep: bool = False,
         num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Make a copy of this SingleCell dataset.

    Args:
        deep: whether to perform a deep or shallow copy. Since polars
              DataFrames are immutable, `obs` and `var` will always point
              to the same underlying data as the original. The difference
              when `deep=True` is that `X` and any NumPy arrays in `obsm`,
              `varm`, `obsp`, `varp`, and `uns` will point to fresh copies
              of the underlying data, instead of the same data as the
              original SingleCell dataset. When `deep=False`, any
              modifications to these NumPy arrays will modify both copies!
        num_threads: the number of threads to use when making a deep copy
                     of `X`. Set `num_threads=-1` to use all available
                     cores, as determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores.
                     Does not affect the copied SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     original dataset's `num_threads`. Can only be
                     specified when `deep=True` and `X` is not `None`.

    Returns:
        A copy of the SingleCell dataset.

    Raises:
        ValueError: if `num_threads` is specified when `deep=False` or
                    when `X` is `None`
    """
    check_type(deep, 'deep', bool, 'Boolean')
    if deep:
        if self._X is None:
            # Check the user-supplied value before resolving the default:
            # `_process_num_threads()` presumably maps `None` to
            # `self.num_threads`, after which this check could no longer
            # tell whether `num_threads` was actually specified
            if num_threads is not None:
                error_message = \
                    'num_threads can only be specified when X is not None'
                raise ValueError(error_message)
            X = None
        else:
            num_threads = self._process_num_threads(num_threads)
            # Temporarily override the count matrix's thread count for the
            # copy, restoring it even if the copy fails
            original_num_threads = self._X._num_threads
            try:
                self._X._num_threads = num_threads
                X = self._X.copy()
            finally:
                self._X._num_threads = original_num_threads
        obsm = {key: value if isinstance(value, pl.DataFrame) else
                     value.copy() for key, value in self._obsm.items()}
        varm = {key: value if isinstance(value, pl.DataFrame) else
                     value.copy() for key, value in self._varm.items()}
        obsp = {key: value.copy() for key, value in self._obsp.items()}
        varp = {key: value.copy() for key, value in self._varp.items()}
        uns = self._uns.copy(deep=True)
    else:
        if num_threads is not None:
            error_message = \
                'num_threads can only be specified when deep=True'
            raise ValueError(error_message)
        X = self._X
        obsm = self._obsm.copy()
        varm = self._varm.copy()
        obsp = self._obsp.copy()
        varp = self._varp.copy()
        uns = self._uns.copy()
    return SingleCell(X=X, obs=self._obs, var=self._var, obsm=obsm,
                      varm=varm, obsp=obsp, varp=varp, uns=uns,
                      num_threads=self._num_threads)
def concat_obs(self,
               datasets: SingleCell | Iterable[SingleCell],
               /,
               *more_datasets: SingleCell,
               dataset_column: str | None = None,
               dataset_labels: Iterable[str] | None = None,
               flexible: bool = False,
               num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Concatenate one or more other SingleCell datasets with this one,
    cell-wise. The datasets being concatenated are not modified.

    All datasets must have distinct `obs_names`.

    By default, all datasets must have the same `var`, `varm`, `varp`, and
    `uns`. They must also have the same columns in `obs` and the same keys
    in `obsm`, with the same data types. `obsp` will be discarded during
    the concatenation.

    Conversely, if `flexible=True`, subset to genes present in all
    datasets (according to the first column of `var`, i.e. the
    `var_names`) before concatenating. Subset to columns of `var` and keys
    of `varm`, `varp`, and `uns` that are identical in all datasets after
    this subsetting. Also, subset to columns of `obs` and keys of `obsm`
    that are present in all datasets, and have the same data types. All
    datasets' `obs_names` must have the same name and data type, and
    similarly for their `var_names`.

    The one exception to the `obs` "same data type" rule: if a column is
    Enum in some datasets and Categorical in others, or Enum in all
    datasets but with different categories in each dataset, that column
    will be retained as an Enum column (with the union of the categories)
    in the concatenated `obs`.

    If the datasets' `X` are a mix of CSR and CSC sparse arrays, they will
    all be coerced to CSR.

    Args:
        datasets: one or more SingleCell datasets to concatenate with this
                  one
        *more_datasets: additional SingleCell datasets to concatenate with
                        this one, specified as positional arguments
        dataset_column: the name of a column to be added to the
                        concatenated dataset's `obs` labeling which
                        dataset each cell came from. The labels themselves
                        are determined by the `dataset_labels` argument.
        dataset_labels: a sequence of labels for each dataset, used to
                        populate `dataset_column`. There must be one label
                        per dataset being concatenated. If
                        `dataset_labels` is not specified, the labels
                        default to `dataset_0`, `dataset_1`, ...,
                        `dataset_{N - 1}`. Can only be specified when
                        `dataset_column` is not `None`.
        flexible: whether to subset to genes, columns of `obs` and `var`,
                  and keys of `obsm`, `varm` and `uns` common to all
                  datasets before concatenating, rather than raising an
                  error on any mismatches
        num_threads: the number of threads to use when concatenating. Does
                     not affect the concatenated SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     first dataset's `num_threads`.

    Returns:
        The concatenated SingleCell dataset.

    Raises:
        ValueError: if no other datasets are specified, `X` is missing
                    from some but not all datasets, `dataset_column` or
                    `dataset_labels` are invalid, the datasets are
                    incompatible, or `obs_names` contains duplicates after
                    concatenation
        TypeError: if the datasets have mismatched data types
    """
    def uns_values_match(value, other):
        # Two `uns` values match only if they are the same kind of value
        # (dict, NumPy array or scalar) and compare equal
        if isinstance(value, dict):
            return isinstance(other, dict) and SingleCell._eq_uns(
                value, other, different_order_ok=True)
        if isinstance(value, np.ndarray):
            return isinstance(other, np.ndarray) and \
                array_equal(other, value)
        return not isinstance(other, (dict, np.ndarray)) and \
            other == value

    # Check inputs
    datasets = (self,) + to_tuple(datasets) + more_datasets
    if len(datasets) == 1:
        error_message = \
            'need at least one other SingleCell dataset to concatenate'
        raise ValueError(error_message)
    check_types(datasets[1:], 'datasets', SingleCell,
                'SingleCell datasets')
    X_present = self._X is not None
    if any((dataset._X is not None) != X_present
           for dataset in datasets[1:]):
        error_message = (
            'some datasets being concatenated have X missing, while '
            'others do not')
        raise ValueError(error_message)
    if dataset_column is not None:
        check_type(dataset_column, 'dataset_column', str, 'a string')
        if any(dataset_column in dataset._obs for dataset in datasets):
            error_message = (
                f"dataset_column {dataset_column!r} is already a column "
                f"of at least one dataset's obs; specify a different name "
                f"for dataset_column")
            raise ValueError(error_message)
        if dataset_labels is not None:
            dataset_labels = to_tuple_checked(
                dataset_labels, 'dataset_labels', str, 'strings')
            if len(dataset_labels) != len(datasets):
                error_message = (
                    f'dataset_labels has length {len(dataset_labels):,}, '
                    f'but there are {len(datasets):,} datasets being '
                    f'concatenated')
                raise ValueError(error_message)
        else:
            dataset_labels = (f'dataset_{i}' for i in range(len(datasets)))
    elif dataset_labels is not None:
        error_message = (
            'when dataset_labels is specified, dataset_column must also '
            'be specified')
        raise ValueError(error_message)
    check_type(flexible, 'flexible', bool, 'Boolean')
    num_threads = self._process_num_threads(num_threads)

    # Perform either flexible or non-flexible concatenation. Both branches
    # build `obs_frames`, `obsm_dicts`, `var`, `varm`, `varp` and `uns` as
    # local variables, rather than reassigning attributes of the datasets,
    # so that the caller's datasets are left untouched.
    if flexible:
        # Check that `obs_names` and `var_names` have the same name and
        # data type across all datasets
        obs_names_name = self.obs_names.name
        if not all(dataset.obs_names.name == obs_names_name
                   for dataset in datasets[1:]):
            error_message = (
                'not all SingleCell datasets have the same name for the '
                'first column of obs (the obs_names column)')
            raise ValueError(error_message)
        var_names_name = self.var_names.name
        if not all(dataset.var_names.name == var_names_name
                   for dataset in datasets[1:]):
            error_message = (
                'not all SingleCell datasets have the same name for the '
                'first column of var (the var_names column)')
            raise ValueError(error_message)
        obs_names_dtype = self.obs_names.dtype
        if not all(dataset.obs_names.dtype == obs_names_dtype
                   for dataset in datasets[1:]):
            error_message = (
                'not all SingleCell datasets have the same data type for '
                'the first column of obs (the obs_names column)')
            raise TypeError(error_message)
        var_names_dtype = self.var_names.dtype
        if not all(dataset.var_names.dtype == var_names_dtype
                   for dataset in datasets[1:]):
            error_message = (
                'not all SingleCell datasets have the same data type for '
                'the first column of var (the var_names column)')
            raise TypeError(error_message)

        # Subset to genes in common across all datasets
        genes_in_common = self.var_names\
            .to_frame()\
            .filter(pl.all_horizontal(
                self.var_names.is_in(dataset.var_names)
                for dataset in datasets))\
            .to_series()
        if len(genes_in_common) == 0:
            error_message = \
                'no genes are shared across all SingleCell datasets'
            raise ValueError(error_message)
        datasets = [dataset
                    if len(genes_in_common) == len(dataset.var_names) and
                       dataset.var_names.equals(genes_in_common)
                    else dataset[:, genes_in_common]
                    for dataset in datasets]
        first = datasets[0]

        # Subset to columns of `var` and keys of `varm`, `varp`, and `uns`
        # that are identical in all datasets after this subsetting. Compare
        # against the gene-subsetted first dataset, not `self`, which may
        # still contain genes missing from the other datasets.
        var_columns_in_common = [
            column.name for column in first._var[:, 1:]
            if all(column.name in dataset._var and
                   dataset._var[column.name].equals(column)
                   for dataset in datasets[1:])]
        varm_keys_in_common = [
            key for key, value in first._varm.items()
            if all(key in dataset._varm and
                   type(value) is type(dataset._varm[key]) and
                   (dataset._varm[key].dtype == value.dtype and
                    array_equal(dataset._varm[key], value)
                    if isinstance(value, np.ndarray) else
                    dataset._varm[key].equals(value))
                   for dataset in datasets[1:])]
        varp_keys_in_common = [
            key for key, value in first._varp.items()
            if all(key in dataset._varp and
                   dataset._varp[key].dtype == value.dtype and
                   sparse_equal(dataset._varp[key], value)
                   for dataset in datasets[1:])]
        uns_keys_in_common = [
            key for key, value in self._uns.items()
            if all(key in dataset._uns and
                   uns_values_match(value, dataset._uns[key])
                   for dataset in datasets[1:])]
        var = first._var.select(first.var_names, *var_columns_in_common)
        varm = {key: first._varm[key] for key in varm_keys_in_common}
        varp = {key: first._varp[key] for key in varp_keys_in_common}
        uns = {key: first._uns[key] for key in uns_keys_in_common}

        # Subset to columns of `obs` and keys of `obsm` that are present in
        # all datasets, and have the same data types. Also include columns
        # of `obs` that are Enum in some datasets and Categorical in
        # others, or Enum in all datasets but with different categories in
        # each dataset; cast these to Enum.
        obs_mismatched_categoricals = {
            column for column, dtype in self._obs[:, 1:]
            .select(pl.col(pl.Categorical, pl.Enum)).schema.items()
            if all(column in dataset._obs and
                   dataset._obs[column].dtype in (pl.Categorical, pl.Enum)
                   for dataset in datasets[1:]) and
               not all(dataset._obs[column].dtype == dtype
                       for dataset in datasets[1:])}
        obs_columns_in_common = [
            column
            for column, dtype in islice(self._obs.schema.items(), 1, None)
            if column in obs_mismatched_categoricals or
               all(column in dataset._obs and
                   dataset._obs[column].dtype == dtype
                   for dataset in datasets[1:])]
        cast_dict = {column: pl.Enum(
            pl.concat([dataset._obs[column].cat.get_categories()
                       for dataset in datasets])
            .unique(maintain_order=True))
            for column in obs_mismatched_categoricals}
        # the `.with_columns(...)` is a faster `.cast(cast_dict)`
        obs_frames = [
            dataset._obs
            .select(dataset.obs_names, *obs_columns_in_common)
            .with_columns(cast_to_Enum(dataset._obs[column], enum_type)
                          .alias(column)
                          for column, enum_type in cast_dict.items())
            for dataset in datasets]
        obsm_keys_in_common = [
            key for key, value in self._obsm.items()
            if all(key in dataset._obsm and
                   type(dataset._obsm[key]) is type(value) and
                   (dataset._obsm[key].dtype == value.dtype
                    if isinstance(value, np.ndarray) else
                    dataset._obsm[key].schema == value.schema)
                   for dataset in datasets[1:])]
        obsm_dicts = [{key: dataset._obsm[key]
                       for key in obsm_keys_in_common}
                      for dataset in datasets]
    else:  # non-flexible
        # Check that all `var`, `varm`, `varp`, and `uns` are identical
        var = self._var
        for dataset in datasets[1:]:
            if not dataset._var.equals(var):
                error_message = (
                    'all SingleCell datasets must have the same var, '
                    'unless flexible=True')
                raise ValueError(error_message)
        varm = self._varm
        for dataset in datasets[1:]:
            if dataset._varm.keys() != varm.keys() or \
                    any(type(dataset._varm[key]) is not type(value) or
                        (dataset._varm[key].dtype != value.dtype
                         if isinstance(value, np.ndarray) else
                         dataset._varm[key].schema != value.schema)
                        for key, value in varm.items()) or not \
                    all(array_equal(dataset._varm[key], value)
                        if isinstance(value, np.ndarray) else
                        dataset._varm[key].equals(value)
                        for key, value in varm.items()):
                error_message = (
                    'all SingleCell datasets must have the same varm, '
                    'unless flexible=True')
                raise ValueError(error_message)
        varp = self._varp
        for dataset in datasets[1:]:
            if dataset._varp.keys() != varp.keys() or \
                    any(type(dataset._varp[key]) is not type(value) or
                        dataset._varp[key].dtype != value.dtype
                        for key, value in varp.items()) or not \
                    all(sparse_equal(dataset._varp[key], value)
                        for key, value in varp.items()):
                error_message = (
                    'all SingleCell datasets must have the same varp, '
                    'unless flexible=True')
                raise ValueError(error_message)
        uns = self._uns
        for dataset in datasets[1:]:
            if not SingleCell._eq_uns(uns, dataset._uns):
                error_message = (
                    'all SingleCell datasets must have the same uns, '
                    'unless flexible=True')
                raise ValueError(error_message)

        # Check that all `obs` have the same columns and data types
        schema = self._obs.schema
        for dataset in datasets[1:]:
            if dataset._obs.schema != schema:
                if dataset._obs.columns != self._obs.columns:
                    error_message = (
                        'all SingleCell datasets must have the same '
                        'columns in obs, unless flexible=True')
                    raise ValueError(error_message)
                else:
                    error_message = (
                        'all SingleCell datasets must have the same data '
                        'type for each column of obs, unless '
                        'flexible=True')
                    raise TypeError(error_message)

        # Check that all `obsm` have the same keys and data types
        obsm = self._obsm
        for dataset in datasets[1:]:
            if dataset._obsm.keys() != obsm.keys():
                error_message = (
                    'all SingleCell datasets must have the same keys in '
                    'obsm, unless flexible=True')
                raise ValueError(error_message)
            if any(type(dataset._obsm[key]) is not type(value) or
                   (dataset._obsm[key].dtype != value.dtype
                    if isinstance(value, np.ndarray) else
                    dataset._obsm[key].schema != value.schema)
                   for key, value in obsm.items()):
                error_message = (
                    'all SingleCell datasets must have the same data '
                    'type for each key in obsm, unless flexible=True')
                raise TypeError(error_message)
        obs_frames = [dataset._obs for dataset in datasets]
        obsm_dicts = [dataset._obsm for dataset in datasets]

    # If `dataset_column` is not `None`, add labels for each dataset
    if dataset_column is not None:
        obs_frames = [
            obs_frame.with_columns(pl.lit(label).alias(dataset_column))
            for obs_frame, label in zip(obs_frames, dataset_labels)]

    # Concatenate; output should be CSR when there's a mix of inputs
    obs = pl.concat(obs_frames)
    num_unique = obs[:, 0].n_unique()
    if num_unique < len(obs):
        error_message = (
            f'obs_names contains {len(obs) - num_unique:,} duplicates '
            f'after concatenation')
        raise ValueError(error_message)
    if X_present:
        if all(isinstance(dataset._X, csr_array) for dataset in datasets):
            X = sparse_major_stack([dataset._X for dataset in datasets],
                                   num_threads=num_threads)
        elif all(isinstance(dataset._X, csc_array)
                 for dataset in datasets):
            X = sparse_minor_stack([dataset._X for dataset in datasets],
                                   num_threads=num_threads)
        else:
            X = sparse_major_stack([dataset._X.tocsr()
                                    if isinstance(dataset._X, csc_array)
                                    else dataset._X
                                    for dataset in datasets],
                                   num_threads=num_threads)
    else:
        X = None
    obsm = {key: concatenate([obsm_dict[key] for obsm_dict in obsm_dicts],
                             num_threads=num_threads)
                 if isinstance(value, np.ndarray) else
                 pl.concat([obsm_dict[key] for obsm_dict in obsm_dicts])
            for key, value in obsm_dicts[0].items()}
    return SingleCell(X=X, obs=obs, var=var, obsm=obsm, varm=varm,
                      varp=varp, uns=uns,
                      num_threads=datasets[0]._num_threads)
def concat_var(self,
               datasets: SingleCell | Iterable[SingleCell],
               /,
               *more_datasets: SingleCell,
               dataset_column: str | None = None,
               dataset_labels: Iterable[str] | None = None,
               flexible: bool = False,
               num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Concatenate one or more other SingleCell datasets with this one,
    gene-wise. This is much less common than the cell-wise concatenation
    provided by `concat_obs()`.

    All datasets must have distinct `var_names`. By default, all datasets
    must have the same `obs`, `obsm`, `obsp`, and `uns`. They must also
    have the same columns in `var` and the same keys in `varm`, with the
    same data types. `varp` will be discarded during the concatenation.

    Conversely, if `flexible=True`, subset to cells present in all
    datasets (according to the first column of `obs`, i.e. the
    `obs_names`) before concatenating. Subset to columns of `obs` and keys
    of `obsm`, `obsp`, and `uns` that are identical in all datasets after
    this subsetting. Also, subset to columns of `var` and keys of `varm`
    that are present in all datasets, and have the same data types. All
    datasets' `obs_names` must have the same name and data type, and
    similarly for their `var_names`.

    The one exception to the `var` "same data type" rule: if a column is
    Enum in some datasets and Categorical in others, or Enum in all
    datasets but with different categories in each dataset, that column
    will be retained as an Enum column (with the union of the categories)
    in the concatenated `var`.

    If the datasets' `X` are a mix of CSR and CSC sparse arrays, they will
    all be coerced to CSR.

    The input datasets (including this one) are not modified: subsetting
    and labeling produce new `obs`/`var`/... objects that are only used to
    build the returned dataset.

    Args:
        datasets: one or more SingleCell datasets to concatenate with this
                  one
        *more_datasets: additional SingleCell datasets to concatenate with
                        this one, specified as positional arguments
        dataset_column: the name of a column to be added to the
                        concatenated dataset's `var`, labeling which
                        dataset each gene came from. The labels themselves
                        are determined by the `dataset_labels` argument.
                        Must not already be a column of any dataset's
                        `var`.
        dataset_labels: a sequence of labels for each dataset, used to
                        populate `dataset_column`. There must be one label
                        per dataset being concatenated (counting this
                        one). If `dataset_labels` is not specified, the
                        labels default to `dataset_0`, `dataset_1`, ...,
                        `dataset_{N - 1}`. Can only be specified when
                        `dataset_column` is not `None`.
        flexible: whether to subset to cells, columns of `obs` and `var`,
                  and keys of `obsm`, `obsp`, `varm` and `uns` common to
                  all datasets before concatenating, rather than raising
                  an error on any mismatches
        num_threads: the number of threads to use when concatenating. Does
                     not affect the concatenated SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     first dataset's `num_threads`.

    Returns:
        The concatenated SingleCell dataset.

    Raises:
        ValueError: if no other dataset is given, if only some datasets
                    have `X`, if `dataset_column`/`dataset_labels` are
                    inconsistent, if the datasets mismatch in a way that
                    is not allowed for the chosen mode, or if the
                    concatenated `var_names` contain duplicates
        TypeError: if data types mismatch in a way that is not allowed
                   for the chosen mode
    """
    # Check inputs
    datasets = (self,) + to_tuple(datasets) + more_datasets
    if len(datasets) == 1:
        error_message = \
            'need at least one other SingleCell dataset to concatenate'
        raise ValueError(error_message)
    check_types(datasets[1:], 'datasets', SingleCell,
                'SingleCell datasets')
    X_present = self._X is not None
    if any((dataset._X is not None) != X_present
           for dataset in datasets[1:]):
        error_message = (
            'some datasets being concatenated have X missing, while '
            'others do not')
        raise ValueError(error_message)
    if dataset_column is not None:
        check_type(dataset_column, 'dataset_column', str, 'a string')
        # The label column is added to `var`, so that is where a name
        # collision has to be detected
        if any(dataset_column in dataset._var for dataset in datasets):
            error_message = (
                f"dataset_column {dataset_column!r} is already a column "
                f"of at least one dataset's var; specify a different name "
                f"for dataset_column")
            raise ValueError(error_message)
        if dataset_labels is not None:
            dataset_labels = to_tuple_checked(
                dataset_labels, 'dataset_labels', str, 'strings')
            if len(dataset_labels) != len(datasets):
                error_message = (
                    f'dataset_labels has length {len(dataset_labels):,}, '
                    f'but there are {len(datasets):,} datasets being '
                    f'concatenated')
                raise ValueError(error_message)
        else:
            dataset_labels = tuple(f'dataset_{i}'
                                   for i in range(len(datasets)))
    elif dataset_labels is not None:
        error_message = (
            'when dataset_labels is specified, dataset_column must also '
            'be specified')
        raise ValueError(error_message)
    check_type(flexible, 'flexible', bool, 'Boolean')
    num_threads = self._process_num_threads(num_threads)

    # Perform either flexible or non-flexible concatenation. Both branches
    # leave the inputs untouched and instead define `obs`, `obsm`, `obsp`
    # and `uns` for the result, plus one `var` and one `varm` per dataset.
    if flexible:
        # Check that `obs_names` and `var_names` have the same name and
        # data type across all datasets
        obs_names_name = self.obs_names.name
        if not all(dataset.obs_names.name == obs_names_name
                   for dataset in datasets[1:]):
            error_message = (
                'not all SingleCell datasets have the same name for the '
                'first column of obs (the obs_names column)')
            raise ValueError(error_message)
        var_names_name = self.var_names.name
        if not all(dataset.var_names.name == var_names_name
                   for dataset in datasets[1:]):
            error_message = (
                'not all SingleCell datasets have the same name for the '
                'first column of var (the var_names column)')
            raise ValueError(error_message)
        obs_names_dtype = self.obs_names.dtype
        if not all(dataset.obs_names.dtype == obs_names_dtype
                   for dataset in datasets[1:]):
            error_message = (
                'not all SingleCell datasets have the same data type for '
                'the first column of obs (the obs_names column)')
            raise TypeError(error_message)
        var_names_dtype = self.var_names.dtype
        if not all(dataset.var_names.dtype == var_names_dtype
                   for dataset in datasets[1:]):
            error_message = (
                'not all SingleCell datasets have the same data type for '
                'the first column of var (the var_names column)')
            raise TypeError(error_message)

        # Subset to cells in common across all datasets
        cells_in_common = self.obs_names\
            .to_frame()\
            .filter(pl.all_horizontal(
                self.obs_names.is_in(dataset.obs_names)
                for dataset in datasets))\
            .to_series()
        if len(cells_in_common) == 0:
            error_message = \
                'no cells are shared across all SingleCell datasets'
            raise ValueError(error_message)
        datasets = [dataset
                    if len(cells_in_common) == len(dataset.obs_names) and
                       dataset.obs_names.equals(cells_in_common)
                    else dataset[cells_in_common] for dataset in datasets]
        # Compare against the (possibly cell-subset) first dataset rather
        # than `self`, so per-cell data line up with the other datasets
        first, others = datasets[0], datasets[1:]

        # Subset to columns of `obs` and keys of `obsm`, `obsp`, and `uns`
        # that are identical in all datasets after this subsetting
        obs_columns_in_common = [
            column.name for column in first._obs[:, 1:]
            if all(column.name in dataset._obs and
                   dataset._obs[column.name].equals(column)
                   for dataset in others)]
        obsm_keys_in_common = [
            key for key, value in first._obsm.items()
            if all(key in dataset._obsm and
                   type(value) is type(dataset._obsm[key]) and
                   (dataset._obsm[key].dtype == value.dtype and
                    array_equal(dataset._obsm[key], value)
                    if isinstance(value, np.ndarray) else
                    dataset._obsm[key].equals(value))
                   for dataset in others)]
        obsp_keys_in_common = [
            key for key, value in first._obsp.items()
            if all(key in dataset._obsp and
                   dataset._obsp[key].dtype == value.dtype and
                   sparse_equal(dataset._obsp[key], value)
                   for dataset in others)]

        def uns_value_is_shared(key, value):
            # Whether every other dataset has `key` in `uns` with a value
            # equal to `value`. Dispatching on the type first avoids
            # comparing an array with a scalar via `==`.
            for dataset in others:
                if key not in dataset._uns:
                    return False
                other_value = dataset._uns[key]
                if isinstance(value, dict):
                    if not (isinstance(other_value, dict) and
                            SingleCell._eq_uns(value, other_value,
                                               different_order_ok=True)):
                        return False
                elif isinstance(value, np.ndarray):
                    if not (isinstance(other_value, np.ndarray) and
                            array_equal(other_value, value)):
                        return False
                elif isinstance(other_value, (dict, np.ndarray)) or \
                        not (other_value == value):
                    return False
            return True

        uns_keys_in_common = [key for key, value in first._uns.items()
                              if uns_value_is_shared(key, value)]
        obs = first._obs.select(first.obs_names, *obs_columns_in_common)
        obsm = {key: first._obsm[key] for key in obsm_keys_in_common}
        obsp = {key: first._obsp[key] for key in obsp_keys_in_common}
        uns = {key: first._uns[key] for key in uns_keys_in_common}

        # Subset to columns of `var` and keys of `varm` that are present in
        # all datasets, and have the same data types. Also include columns
        # of `var` that are Enum in some datasets and Categorical in
        # others, or Enum in all datasets but with different categories in
        # each dataset; cast these to Enum.
        var_mismatched_categoricals = {
            column for column, dtype in first._var[:, 1:]
            .select(pl.col(pl.Categorical, pl.Enum)).schema.items()
            if all(column in dataset._var and
                   dataset._var[column].dtype in (pl.Categorical, pl.Enum)
                   for dataset in others) and
               not all(dataset._var[column].dtype == dtype
                       for dataset in others)}
        var_columns_in_common = [
            column
            for column, dtype in islice(first._var.schema.items(), 1, None)
            if column in var_mismatched_categoricals or
               all(column in dataset._var and
                   dataset._var[column].dtype == dtype
                   for dataset in others)]
        cast_dict = {column: pl.Enum(
            pl.concat([dataset._var[column].cat.get_categories()
                       for dataset in datasets])
            .unique(maintain_order=True))
            for column in var_mismatched_categoricals}
        # the `.with_columns(...)` is a faster `.cast(cast_dict)`
        var_frames = [
            dataset._var
            .select(dataset.var_names, *var_columns_in_common)
            .with_columns([
                cast_to_Enum(dataset._var[column], enum_type).alias(column)
                for column, enum_type in cast_dict.items()])
            for dataset in datasets]
        varm_keys_in_common = [
            key for key, value in first._varm.items()
            if all(key in dataset._varm and
                   type(dataset._varm[key]) is type(value) and
                   (dataset._varm[key].dtype == value.dtype
                    if isinstance(value, np.ndarray) else
                    dataset._varm[key].schema == value.schema)
                   for dataset in others)]
        varm_dicts = [{key: dataset._varm[key]
                       for key in varm_keys_in_common}
                      for dataset in datasets]
    else:  # non-flexible
        # Check that all `obs`, `obsm`, `obsp`, and `uns` are identical
        obs = self._obs
        for dataset in datasets[1:]:
            if not dataset._obs.equals(obs):
                error_message = (
                    'all SingleCell datasets must have the same obs, '
                    'unless flexible=True')
                raise ValueError(error_message)
        obsm = self._obsm
        for dataset in datasets[1:]:
            # Non-array values are compared with `.equals()`, as in the
            # flexible branch
            if dataset._obsm.keys() != obsm.keys() or not all(
                    type(dataset._obsm[key]) is type(value) and
                    (dataset._obsm[key].dtype == value.dtype and
                     array_equal(dataset._obsm[key], value)
                     if isinstance(value, np.ndarray) else
                     dataset._obsm[key].schema == value.schema and
                     dataset._obsm[key].equals(value))
                    for key, value in obsm.items()):
                error_message = (
                    'all SingleCell datasets must have the same obsm, '
                    'unless flexible=True')
                raise ValueError(error_message)
        obsp = self._obsp
        for dataset in datasets[1:]:
            if dataset._obsp.keys() != obsp.keys() or \
                    any(type(dataset._obsp[key]) is not type(value) or
                        dataset._obsp[key].dtype != value.dtype
                        for key, value in obsp.items()) or not \
                    all(sparse_equal(dataset._obsp[key], value)
                        for key, value in obsp.items()):
                error_message = (
                    'all SingleCell datasets must have the same obsp, '
                    'unless flexible=True')
                raise ValueError(error_message)
        uns = self._uns
        for dataset in datasets[1:]:
            if not SingleCell._eq_uns(uns, dataset._uns):
                error_message = (
                    'all SingleCell datasets must have the same uns, '
                    'unless flexible=True')
                raise ValueError(error_message)

        # Check that all `var` have the same columns and data types
        schema = self._var.schema
        for dataset in datasets[1:]:
            if dataset._var.schema != schema:
                if dataset._var.columns != self._var.columns:
                    error_message = (
                        'all SingleCell datasets must have the same '
                        'columns in var, unless flexible=True')
                    raise ValueError(error_message)
                else:
                    error_message = (
                        'all SingleCell datasets must have the same data '
                        'type for each column of var, unless '
                        'flexible=True')
                    raise TypeError(error_message)

        # Check that all `varm` have the same keys and data types
        varm = self._varm
        for dataset in datasets[1:]:
            if dataset._varm.keys() != varm.keys():
                error_message = (
                    'all SingleCell datasets must have the same keys in '
                    'varm, unless flexible=True')
                raise ValueError(error_message)
            if any(type(dataset._varm[key]) is not type(value) or
                   (dataset._varm[key].dtype != value.dtype
                    if isinstance(value, np.ndarray) else
                    dataset._varm[key].schema != value.schema)
                   for key, value in varm.items()):
                error_message = (
                    'all SingleCell datasets must have the same data '
                    'type for each key in varm, unless flexible=True')
                raise TypeError(error_message)
        var_frames = [dataset._var for dataset in datasets]
        varm_dicts = [dataset._varm for dataset in datasets]

    # If `dataset_column` is not `None`, add labels for each dataset
    if dataset_column is not None:
        var_frames = [
            var_frame.with_columns(pl.lit(label).alias(dataset_column))
            for var_frame, label in zip(var_frames, dataset_labels)]

    # Concatenate; output should be CSR when there's a mix of inputs
    var = pl.concat(var_frames)
    num_unique = var[:, 0].n_unique()
    if num_unique != len(var):
        error_message = (
            f'var_names contains {len(var) - num_unique:,} duplicates '
            f'after concatenation')
        raise ValueError(error_message)
    if X_present:
        # Genes are the minor axis of a CSR array and the major axis of a
        # CSC array
        if all(isinstance(dataset._X, csr_array) for dataset in datasets):
            X = sparse_minor_stack([dataset._X for dataset in datasets],
                                   num_threads=num_threads)
        elif all(isinstance(dataset._X, csc_array)
                 for dataset in datasets):
            X = sparse_major_stack([dataset._X for dataset in datasets],
                                   num_threads=num_threads)
        else:
            X = sparse_minor_stack(
                [dataset._X.tocsr() if isinstance(dataset._X, csc_array)
                 else dataset._X for dataset in datasets],
                num_threads=num_threads)
    else:
        X = None
    varm = {key: concatenate([varm_dict[key] for varm_dict in varm_dicts],
                             num_threads=num_threads)
            if isinstance(value, np.ndarray) else
            pl.concat([varm_dict[key] for varm_dict in varm_dicts])
            for key, value in varm_dicts[0].items()}
    return SingleCell(X=X, obs=obs, var=var, obsm=obsm, varm=varm,
                      obsp=obsp, uns=uns,
                      num_threads=datasets[0]._num_threads)
def split_by_obs(self,
                 by_column: SingleCellColumn,
                 /,
                 *,
                 QC_column: SingleCellColumn | None = 'passed_QC',
                 sort_by_size: bool = False,
                 num_threads: int | np.integer | None = None) -> \
        dict[str, SingleCell]:
    """
    The opposite of `concat_obs()`: splits a SingleCell dataset into a
    dictionary of SingleCell datasets, one per unique value of a column of
    `obs`.

    Args:
        by_column: a String, Enum, Categorical, or integer column of `obs`
                   to split by. Can be a column name, a polars expression,
                   a polars Series, a 1D NumPy array, or a function that
                   takes in this SingleCell dataset and returns a polars
                   Series or 1D NumPy array. Can contain `null` entries:
                   the corresponding cells will not be included in the
                   result.
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will not be
                   selected when splitting.
        sort_by_size: if `True`, datasets in the returned dictionary will
                      be sorted in decreasing order of size. If `False`,
                      they will be sorted in ascending order, according to
                      the sort order of `by_column`'s data type.
        num_threads: the number of threads to use when splitting `X`. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores.
                     Can only be specified when `X` is not `None`.

    Returns:
        A dictionary mapping each unique non-null value of `by_column` to
        a SingleCell dataset subset to cells where `by_column` has that
        value.

    Raises:
        ValueError: if `num_threads` is specified but `X` is `None`
    """
    if QC_column is not None:
        # Only the default column name may be missing; compare as a string
        # so that Series/array/expression inputs yield a plain `False`
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=isinstance(QC_column, str) and
                          QC_column == 'passed_QC')
    by_column = self._get_column('obs', by_column, 'by_column',
                                 (pl.String, pl.Enum, pl.Categorical,
                                  'integer'),
                                 QC_column=QC_column, allow_null=True)
    check_type(sort_by_size, 'sort_by_size', bool, 'Boolean')
    if self._X is None:
        # `filter_obs()` rejects any non-`None` `num_threads` when `X` is
        # missing, so leave `None` as-is rather than resolving it
        if num_threads is not None:
            error_message = \
                'num_threads can only be specified when X is not None'
            raise ValueError(error_message)
    else:
        num_threads = self._process_num_threads(num_threads)
    values = by_column.value_counts(sort=True).to_series() \
        if sort_by_size else by_column.unique().sort()
    # Null entries never form a group, whichever ordering is used
    values = values.drop_nulls()
    if QC_column is None:
        return {value: self.filter_obs(by_column.eq(value),
                                       num_threads=num_threads)
                for value in values}
    return {value: self.filter_obs(by_column.eq(value) & QC_column,
                                   num_threads=num_threads)
            for value in values}
def split_by_var(self,
                 by_column: SingleCellColumn,
                 /,
                 *,
                 QC_column: SingleCellColumn | None = 'passed_QC',
                 sort_by_size: bool = False,
                 num_threads: int | np.integer | None = None) -> \
        dict[str, SingleCell]:
    """
    The opposite of `concat_var()`: splits a SingleCell dataset into a
    dictionary of SingleCell datasets, one per unique value of a column of
    `var`.

    Args:
        by_column: a String, Enum, Categorical, or integer column of `var`
                   to split by. Can be a column name, a polars expression,
                   a polars Series, a 1D NumPy array, or a function that
                   takes in this SingleCell dataset and returns a polars
                   Series or 1D NumPy array. Can contain `null` entries:
                   the corresponding genes will not be included in the
                   result.
        QC_column: accepted so that existing calls keep working, but never
                   read by this method: all cells are kept in every
                   returned dataset.
        sort_by_size: if `True`, datasets in the returned dictionary will
                      be sorted in decreasing order of size. If `False`,
                      they will be sorted in ascending order, according to
                      the sort order of `by_column`'s data type.
        num_threads: the number of threads to use when splitting `X`. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores.
                     Can only be specified when `X` is not `None`.

    Returns:
        A dictionary mapping each unique non-null value of `by_column` to
        a SingleCell dataset subset to genes where `by_column` has that
        value.

    Raises:
        ValueError: if `num_threads` is specified but `X` is `None`
    """
    by_column = self._get_column('var', by_column, 'by_column',
                                 (pl.String, pl.Enum, pl.Categorical,
                                  'integer'),
                                 allow_null=True)
    check_type(sort_by_size, 'sort_by_size', bool, 'Boolean')
    if self._X is None:
        # `filter_var()` rejects any non-`None` `num_threads` when `X` is
        # missing, so leave `None` as-is rather than resolving it
        if num_threads is not None:
            error_message = \
                'num_threads can only be specified when X is not None'
            raise ValueError(error_message)
    else:
        num_threads = self._process_num_threads(num_threads)
    values = by_column.value_counts(sort=True).to_series() \
        if sort_by_size else by_column.unique().sort()
    # Null entries never form a group, whichever ordering is used
    values = values.drop_nulls()
    return {value: self.filter_var(by_column == value,
                                   num_threads=num_threads)
            for value in values}
def tocsr(self,
          *,
          num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Return a copy of this SingleCell dataset whose `X` has been converted
    to a `csr_array`. It is an error to call this when `X` is already a
    `csr_array`.

    Args:
        num_threads: the number of threads to use when converting to CSR.
                     Set `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores.
                     Does not affect the copied SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     original dataset's `num_threads`.

    Returns:
        A copy of this SingleCell dataset, with `X` as a `csr_array`.

    Raises:
        ValueError: if `X` is `None`
        TypeError: if `X` is already a `csr_array`
    """
    X = self._X
    # Nothing to convert without a count matrix
    if X is None:
        error_message = 'X is None, so converting to CSR is not possible'
        raise ValueError(error_message)
    if isinstance(X, csr_array):
        error_message = 'X is already a csr_array'
        raise TypeError(error_message)
    conversion_threads = self._process_num_threads(num_threads)
    # The conversion reads the thread count stored on `X`, so override it
    # for the duration of the call and always put the old value back
    saved_threads = X._num_threads
    try:
        X._num_threads = conversion_threads
        return SingleCell(X=X.tocsr(), obs=self._obs, var=self._var,
                          obsm=self._obsm, varm=self._varm,
                          obsp=self._obsp, varp=self._varp,
                          uns=self._uns, num_threads=self._num_threads)
    finally:
        X._num_threads = saved_threads
def tocsc(self,
          *,
          num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Return a copy of this SingleCell dataset whose `X` has been converted
    to a `csc_array`. It is an error to call this when `X` is already a
    `csc_array`.

    This function is provided for completeness, but `csr_array` is a far
    better format than `csc_array` for cell-wise operations like
    pseudobulking, so using `tocsc()` is rarely advisable.

    Args:
        num_threads: the number of threads to use when converting to CSC.
                     Set `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores.
                     Does not affect the copied SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     original dataset's `num_threads`.

    Returns:
        A copy of this SingleCell dataset, with `X` as a `csc_array`.

    Raises:
        ValueError: if `X` is `None`
        TypeError: if `X` is already a `csc_array`
    """
    source = self._X
    if source is None:
        error_message = 'X is None, so converting to CSC is not possible'
        raise ValueError(error_message)
    if isinstance(source, csc_array):
        error_message = 'X is already a csc_array'
        raise TypeError(error_message)
    requested_threads = self._process_num_threads(num_threads)
    # Temporarily swap in the requested thread count on `X`; the `finally`
    # restores the previous value even if the conversion fails
    previous_threads = source._num_threads
    try:
        source._num_threads = requested_threads
        return SingleCell(X=source.tocsc(), obs=self._obs, var=self._var,
                          obsm=self._obsm, varm=self._varm,
                          obsp=self._obsp, varp=self._varp,
                          uns=self._uns, num_threads=self._num_threads)
    finally:
        source._num_threads = previous_threads
def filter_obs(self,
               *predicates: pl.Expr | pl.Series | str |
                            Iterable[pl.Expr | pl.Series | str] | bool |
                            list[bool] | np.ndarray[np.dtype[np.bool_]],
               num_threads: int | np.integer | None = None,
               **constraints: Any) -> SingleCell:
    """
    Equivalent to `df.filter()` from polars, but applied to `obs`, `obsm`,
    `obsp` and `X` together.

    Args:
        *predicates: one or more column names, expressions that evaluate to
                     Boolean Series, Boolean Series, lists of Booleans,
                     and/or 1D Boolean NumPy arrays
        num_threads: the number of threads to use when filtering `X`. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores.
                     Does not affect the filtered SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     original dataset's `num_threads`. Can only be
                     specified when `X` is not `None`.
        **constraints: column filters: `name=value` filters to cells where
                       the column named `name` has the value `value`

    Returns:
        A new SingleCell dataset filtered to cells passing all the Boolean
        filters in `predicates` and `constraints`.

    Raises:
        ValueError: if `num_threads` is specified but `X` is `None`
    """
    # Tag each cell with its row number, so the rows that survive the
    # filter can also be used to index `X`, `obsm` and `obsp`
    obs = self._obs\
        .with_columns(_SingleCell_index=pl.int_range(pl.len(),
                                                     dtype=pl.Int32))\
        .filter(*predicates, **constraints)
    indices = obs['_SingleCell_index'].to_numpy()

    def build(X):
        # Assemble the filtered dataset around an already-filtered `X`
        return SingleCell(X=X, obs=obs.drop('_SingleCell_index'),
                          var=self._var,
                          obsm={key: value[indices]
                                for key, value in self._obsm.items()},
                          varm=self._varm,
                          obsp={key: value[ix_symmetric(indices)]
                                for key, value in self._obsp.items()},
                          varp=self._varp, uns=self._uns,
                          num_threads=self._num_threads)

    if self._X is None:
        if num_threads is not None:
            error_message = \
                'num_threads can only be specified when X is not None'
            raise ValueError(error_message)
        return build(None)
    num_threads = self._process_num_threads(num_threads)
    # Indexing `X` uses the thread count stored on it; override it only
    # for this call and restore it afterwards
    original_num_threads = self._X._num_threads
    try:
        self._X._num_threads = num_threads
        return build(self._X[indices])
    finally:
        self._X._num_threads = original_num_threads
def filter_var(self,
               *predicates: pl.Expr | pl.Series | str |
                            Iterable[pl.Expr | pl.Series | str] | bool |
                            list[bool] | np.ndarray[np.dtype[np.bool_]],
               num_threads: int | np.integer | None = None,
               **constraints: Any) -> SingleCell:
    """
    Equivalent to `df.filter()` from polars, but applied to `var`, `varm`,
    `varp` and `X` together.

    Args:
        *predicates: one or more column names, expressions that evaluate to
                     Boolean Series, Boolean Series, lists of Booleans,
                     and/or 1D Boolean NumPy arrays
        num_threads: the number of threads to use when filtering `X`. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`. By default
                     (`num_threads=None`), use `self.num_threads` cores.
                     Does not affect the filtered SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     original dataset's `num_threads`. Can only be
                     specified when `X` is not `None`.
        **constraints: column filters: `name=value` filters to genes where
                       the column named `name` has the value `value`

    Returns:
        A new SingleCell dataset filtered to genes passing all the Boolean
        filters in `predicates` and `constraints`.

    Raises:
        ValueError: if `num_threads` is specified but `X` is `None`
    """
    # Number the genes before filtering so the surviving row numbers can
    # index `X`, `varm` and `varp` as well
    numbered_var = self._var.with_columns(
        _SingleCell_index=pl.int_range(pl.len(), dtype=pl.Int32))
    filtered_var = numbered_var.filter(*predicates, **constraints)
    kept = filtered_var['_SingleCell_index'].to_numpy()

    def assemble(X):
        # Build the result around an already-filtered `X` (or `None`)
        return SingleCell(X=X, obs=self._obs,
                          var=filtered_var.drop('_SingleCell_index'),
                          obsm=self._obsm,
                          varm={name: array[kept]
                                for name, array in self._varm.items()},
                          obsp=self._obsp,
                          varp={name: array[ix_symmetric(kept)]
                                for name, array in self._varp.items()},
                          uns=self._uns, num_threads=self._num_threads)

    if self._X is not None:
        thread_count = self._process_num_threads(num_threads)
        # Column-indexing `X` uses the thread count stored on it, so set
        # it just for this call and restore it no matter what
        previous_threads = self._X._num_threads
        try:
            self._X._num_threads = thread_count
            return assemble(self._X[:, kept])
        finally:
            self._X._num_threads = previous_threads
    if num_threads is not None:
        error_message = \
            'num_threads can only be specified when X is not None'
        raise ValueError(error_message)
    return assemble(None)
def select_obs(self,
               *exprs: Scalar | pl.Expr | pl.Series |
                       Iterable[Scalar | pl.Expr | pl.Series],
               **named_exprs: Scalar | pl.Expr | pl.Series) -> SingleCell:
    """
    Equivalent to `df.select()` from polars, but applied to `obs`.
    `obs_names` is always prepended as the first column of the result and
    must not be among the selected columns.

    Args:
        *exprs: column(s) to select, specified as positional arguments.
                Accepts expression input. Strings are parsed as column
                names, other non-expression inputs are parsed as literals.
        **named_exprs: additional columns to select, specified as keyword
                       arguments. The columns will be renamed to the
                       keyword used.

    Returns:
        A new SingleCell dataset with
        `obs=obs.select(*exprs, **named_exprs)`, preceded by `obs_names`
        as the first column.

    Raises:
        ValueError: if one of the selected columns has the same name as
                    the `obs_names` column
    """
    obs_names = self.obs_names
    selected = self._obs.select(*exprs, **named_exprs)
    # `obs_names` is re-attached below, so selecting it explicitly would
    # produce a duplicate column
    if obs_names.name in selected:
        error_message = (
            f'one of the selected columns is the obs_names, '
            f'{obs_names.name!r}, but the obs_names will always be '
            f'selected automatically as the first column and thus should '
            f'not be specified explicitly')
        raise ValueError(error_message)
    new_obs = selected.select(obs_names, pl.all())
    return SingleCell(X=self._X, obs=new_obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def select_var(self,
               *exprs: Scalar | pl.Expr | pl.Series |
                       Iterable[Scalar | pl.Expr | pl.Series],
               **named_exprs: Scalar | pl.Expr | pl.Series) -> SingleCell:
    """
    Equivalent to `df.select()` from polars, but applied to `var`.
    `var_names` is always prepended as the first column of the result and
    must not be among the selected columns.

    Args:
        *exprs: column(s) to select, specified as positional arguments.
                Accepts expression input. Strings are parsed as column
                names, other non-expression inputs are parsed as literals.
        **named_exprs: additional columns to select, specified as keyword
                       arguments. The columns will be renamed to the
                       keyword used.

    Returns:
        A new SingleCell dataset with
        `var=var.select(*exprs, **named_exprs)`, preceded by `var_names`
        as the first column.

    Raises:
        ValueError: if one of the selected columns has the same name as
                    the `var_names` column
    """
    gene_names = self.var_names
    picked = self._var.select(*exprs, **named_exprs)
    if gene_names.name not in picked:
        # Put `var_names` back in front of the selected columns
        return SingleCell(X=self._X, obs=self._obs,
                          var=picked.select(gene_names, pl.all()),
                          obsm=self._obsm, varm=self._varm,
                          obsp=self._obsp, varp=self._varp, uns=self._uns,
                          num_threads=self._num_threads)
    error_message = (
        f'one of the selected columns is the var_names, '
        f'{gene_names.name!r}, but the var_names will always be '
        f'selected automatically as the first column and thus should '
        f'not be specified explicitly')
    raise ValueError(error_message)
def select_obsm(self,
                keys: str | Iterable[str],
                /,
                *more_keys: str) -> SingleCell:
    """
    Subsets `obsm` to the specified key(s). The retained entries keep the
    order they have in `obsm`.

    Args:
        keys: key(s) to select
        *more_keys: additional keys to select, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with `obsm` subset to the specified
        key(s).

    Raises:
        ValueError: if any requested key is not a key of `obsm`
    """
    keys = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    keys += more_keys
    # Report the first requested key that does not exist, if any
    missing = next((key for key in keys if key not in self._obsm), None)
    if missing is not None:
        error_message = \
            f'tried to select {missing!r}, which is not a key of obsm'
        raise ValueError(error_message)
    obsm = {key: self._obsm[key] for key in self._obsm if key in keys}
    return SingleCell(X=self._X, obs=self._obs, var=self._var, obsm=obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def select_varm(self,
                keys: str | Iterable[str],
                /,
                *more_keys: str) -> SingleCell:
    """
    Subsets `varm` to the specified key(s). The retained entries keep the
    order they have in `varm`.

    Args:
        keys: key(s) to select
        *more_keys: additional keys to select, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with `varm` subset to the specified
        key(s).

    Raises:
        ValueError: if any requested key is not a key of `varm`
    """
    keys = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    keys += more_keys
    current_varm = self._varm
    for requested in keys:
        if requested in current_varm:
            continue
        error_message = \
            f'tried to select {requested!r}, which is not a key of varm'
        raise ValueError(error_message)
    # Walk `varm` itself (not `keys`) so the original ordering is kept
    subset = {}
    for name, entry in current_varm.items():
        if name in keys:
            subset[name] = entry
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=subset, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def select_obsp(self,
                keys: str | Iterable[str],
                /,
                *more_keys: str) -> SingleCell:
    """
    Subsets `obsp` to the specified key(s). The retained entries keep the
    order they have in `obsp`.

    Args:
        keys: key(s) to select
        *more_keys: additional keys to select, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with `obsp` subset to the specified
        key(s).

    Raises:
        ValueError: if any requested key is not a key of `obsp`
    """
    keys = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    keys += more_keys
    # Only the first unknown key (in the order requested) is reported
    unknown = [key for key in keys if key not in self._obsp]
    if unknown:
        error_message = \
            f'tried to select {unknown[0]!r}, which is not a key of obsp'
        raise ValueError(error_message)
    kept_obsp = {name: matrix for name, matrix in self._obsp.items()
                 if name in keys}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=kept_obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def select_varp(self,
                keys: str | Iterable[str],
                /,
                *more_keys: str) -> SingleCell:
    """
    Subsets `varp` to the specified key(s). The retained entries keep the
    order they have in `varp`.

    Args:
        keys: key(s) to select
        *more_keys: additional keys to select, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with `varp` subset to the specified
        key(s).

    Raises:
        ValueError: if any requested key is not a key of `varp`
    """
    keys = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    keys += more_keys
    varp = self._varp
    # Stop at the first requested key that `varp` does not contain
    absent = next((key for key in keys if key not in varp), None)
    if absent is not None:
        error_message = \
            f'tried to select {absent!r}, which is not a key of varp'
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp={name: varp[name] for name in varp
                            if name in keys},
                      uns=self._uns, num_threads=self._num_threads)
def select_uns(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Keep only the requested key(s) of `uns`.

    Args:
        keys: key(s) to select
        *more_keys: additional keys to select, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with `uns` subset to the specified
        key(s). Entries keep the order they have in the current `uns`.
    """
    selected = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    selected += more_keys
    # Report the first requested key that uns does not contain
    missing = next((key for key in selected if key not in self._uns), None)
    if missing is not None:
        error_message = \
            f'tried to select {missing!r}, which is not a key of uns'
        raise ValueError(error_message)
    uns = {name: entry for name, entry in self._uns.items()
           if name in selected}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=uns,
                      num_threads=self._num_threads)
def with_columns_obs(self,
                     *exprs: Scalar | pl.Expr | pl.Series |
                             Iterable[Scalar | pl.Expr | pl.Series],
                     **named_exprs: Scalar | pl.Expr | pl.Series) -> \
        SingleCell:
    """
    Add or replace columns of `obs`, mirroring polars'
    `df.with_columns()`.

    Args:
        *exprs: column(s) to add, specified as positional arguments.
                Accepts expression input. Strings are parsed as column
                names, other non-expression inputs are parsed as literals.
        **named_exprs: additional columns to add, specified as keyword
                       arguments. The columns will be renamed to the
                       keyword used.

    Returns:
        A new SingleCell dataset with
        `obs=obs.with_columns(*exprs, **named_exprs)`.
    """
    # All arguments are forwarded to polars untouched
    obs = self._obs.with_columns(*exprs, **named_exprs)
    return SingleCell(X=self._X, obs=obs, var=self._var, obsm=self._obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def with_columns_var(self,
                     *exprs: Scalar | pl.Expr | pl.Series |
                             Iterable[Scalar | pl.Expr | pl.Series],
                     **named_exprs: Scalar | pl.Expr | pl.Series) -> \
        SingleCell:
    """
    Add or replace columns of `var`, mirroring polars'
    `df.with_columns()`.

    Args:
        *exprs: column(s) to add, specified as positional arguments.
                Accepts expression input. Strings are parsed as column
                names, other non-expression inputs are parsed as literals.
        **named_exprs: additional columns to add, specified as keyword
                       arguments. The columns will be renamed to the
                       keyword used.

    Returns:
        A new SingleCell dataset with
        `var=var.with_columns(*exprs, **named_exprs)`.
    """
    # All arguments are forwarded to polars untouched
    var = self._var.with_columns(*exprs, **named_exprs)
    return SingleCell(X=self._X, obs=self._obs, var=var, obsm=self._obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def with_obsm(self, obsm: dict[str, np.ndarray | pl.DataFrame] = {}, /,
              **more_obsm: np.ndarray) -> SingleCell:
    """
    Adds one or more keys to `obsm`, overwriting existing keys with the
    same names if present.

    Args:
        obsm: a dictionary of keys to add to (or overwrite in) `obsm`.
              It is never modified.
        **more_obsm: additional keys to add to (or overwrite in) `obsm`,
                     specified as keyword arguments. On a name clash
                     these take precedence over entries of `obsm`.

    Returns:
        A new SingleCell dataset with the new key(s) added to or
        overwritten in `obsm`.

    Raises:
        TypeError: if a key of `obsm` is not a string (in addition to
                   whatever `check_type` raises when `obsm` is not a
                   dictionary).
        ValueError: if `obsm` is empty and no keyword arguments were
                    specified.
    """
    check_type(obsm, 'obsm', dict, 'a dictionary')
    for key in obsm:
        if not isinstance(key, str):
            error_message = (
                f'all keys of obsm must be strings, but new obsm contains '
                f'a key of type {type(key).__name__!r}')
            raise TypeError(error_message)
    # Merge into a fresh dictionary. An in-place `|=` would modify the
    # caller's dictionary and, when `obsm` is omitted, the shared `{}`
    # default, leaking keys from one call into later ones.
    obsm = obsm | more_obsm
    if len(obsm) == 0:
        error_message = \
            'obsm is empty and no keyword arguments were specified'
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm | obsm, varm=self._varm,
                      obsp=self._obsp, varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def with_varm(self, varm: dict[str, np.ndarray | pl.DataFrame] = {}, /,
              **more_varm: np.ndarray) -> SingleCell:
    """
    Adds one or more keys to `varm`, overwriting existing keys with the
    same names if present.

    Args:
        varm: a dictionary of keys to add to (or overwrite in) `varm`.
              It is never modified.
        **more_varm: additional keys to add to (or overwrite in) `varm`,
                     specified as keyword arguments. On a name clash
                     these take precedence over entries of `varm`.

    Returns:
        A new SingleCell dataset with the new key(s) added to or
        overwritten in `varm`.

    Raises:
        TypeError: if a key of `varm` is not a string (in addition to
                   whatever `check_type` raises when `varm` is not a
                   dictionary).
        ValueError: if `varm` is empty and no keyword arguments were
                    specified.
    """
    check_type(varm, 'varm', dict, 'a dictionary')
    for key in varm:
        if not isinstance(key, str):
            error_message = (
                f'all keys of varm must be strings, but new varm contains '
                f'a key of type {type(key).__name__!r}')
            raise TypeError(error_message)
    # Merge into a fresh dictionary. An in-place `|=` would modify the
    # caller's dictionary and, when `varm` is omitted, the shared `{}`
    # default, leaking keys from one call into later ones.
    varm = varm | more_varm
    if len(varm) == 0:
        error_message = \
            'varm is empty and no keyword arguments were specified'
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm | varm,
                      obsp=self._obsp, varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def with_obsp(self, obsp: dict[str, csr_array | csc_array] = {}, /,
              **more_obsp: csr_array | csc_array) -> SingleCell:
    """
    Adds one or more keys to `obsp`, overwriting existing keys with the
    same names if present.

    Args:
        obsp: a dictionary of keys to add to (or overwrite in) `obsp`.
              It is never modified.
        **more_obsp: additional keys to add to (or overwrite in) `obsp`,
                     specified as keyword arguments. On a name clash
                     these take precedence over entries of `obsp`.

    Returns:
        A new SingleCell dataset with the new key(s) added to or
        overwritten in `obsp`.

    Raises:
        TypeError: if a key of `obsp` is not a string (in addition to
                   whatever `check_type` raises when `obsp` is not a
                   dictionary).
        ValueError: if `obsp` is empty and no keyword arguments were
                    specified.
    """
    check_type(obsp, 'obsp', dict, 'a dictionary')
    for key in obsp:
        if not isinstance(key, str):
            error_message = (
                f'all keys of obsp must be strings, but new obsp contains '
                f'a key of type {type(key).__name__!r}')
            raise TypeError(error_message)
    # Merge into a fresh dictionary. An in-place `|=` would modify the
    # caller's dictionary and, when `obsp` is omitted, the shared `{}`
    # default, leaking keys from one call into later ones.
    obsp = obsp | more_obsp
    if len(obsp) == 0:
        error_message = \
            'obsp is empty and no keyword arguments were specified'
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm,
                      obsp=self._obsp | obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def with_varp(self,
              varp: dict[str, sparse.csr_array | sparse.csc_array |
                              sparse.csr_matrix | sparse.csc_matrix] = {},
              /,
              **more_varp: sparse.csr_array | sparse.csc_array |
                           sparse.csr_matrix | sparse.csc_matrix) -> \
        SingleCell:
    """
    Adds one or more keys to `varp`, overwriting existing keys with the
    same names if present.

    Args:
        varp: a dictionary of keys to add to (or overwrite in) `varp`.
              It is never modified.
        **more_varp: additional keys to add to (or overwrite in) `varp`,
                     specified as keyword arguments. On a name clash
                     these take precedence over entries of `varp`.

    Returns:
        A new SingleCell dataset with the new key(s) added to or
        overwritten in `varp`.

    Raises:
        TypeError: if a key of `varp` is not a string (in addition to
                   whatever `check_type` raises when `varp` is not a
                   dictionary).
        ValueError: if `varp` is empty and no keyword arguments were
                    specified.
    """
    check_type(varp, 'varp', dict, 'a dictionary')
    for key in varp:
        if not isinstance(key, str):
            error_message = (
                f'all keys of varp must be strings, but new varp contains '
                f'a key of type {type(key).__name__!r}')
            raise TypeError(error_message)
    # Merge into a fresh dictionary. An in-place `|=` would modify the
    # caller's dictionary and, when `varp` is omitted, the shared `{}`
    # default, leaking keys from one call into later ones.
    varp = varp | more_varp
    if len(varp) == 0:
        error_message = \
            'varp is empty and no keyword arguments were specified'
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp | varp, uns=self._uns,
                      num_threads=self._num_threads)
def with_uns(self, uns: dict[str, UnsDict] = {}, /,
             **more_uns: UnsItem | UnsDict) -> SingleCell:
    """
    Adds one or more keys to `uns`, overwriting existing keys with the
    same names if present.

    Args:
        uns: a dictionary of keys to add to (or overwrite in) `uns`.
             It is never modified.
        **more_uns: additional keys to add to (or overwrite in) `uns`,
                    specified as keyword arguments. On a name clash
                    these take precedence over entries of `uns`.

    Returns:
        A new SingleCell dataset with the new key(s) added to or
        overwritten in `uns`.

    Raises:
        TypeError: if a key of `uns` is not a string (in addition to
                   whatever `check_type` raises when `uns` is not a
                   dictionary).
        ValueError: if `uns` is empty and no keyword arguments were
                    specified.
    """
    check_type(uns, 'uns', dict, 'a dictionary')
    for key in uns:
        if not isinstance(key, str):
            error_message = (
                f'all keys of uns must be strings, but new uns contains a '
                f'key of type {type(key).__name__!r}')
            raise TypeError(error_message)
    # Merge into a fresh dictionary. An in-place `|=` would modify the
    # caller's dictionary and, when `uns` is omitted, the shared `{}`
    # default, leaking keys from one call into later ones.
    uns = uns | more_uns
    if len(uns) == 0:
        error_message = \
            'uns is empty and no keyword arguments were specified'
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns | uns,
                      num_threads=self._num_threads)
def drop_X(self):
    """
    Build a copy of this dataset that carries no `X`, to reduce memory
    use.

    Returns:
        A new SingleCell dataset constructed without `X`; every other
        field is passed through unchanged.

    Raises:
        TypeError: if `X` is already `None`.
    """
    if self._X is not None:
        # `X` is simply not passed to the constructor
        return SingleCell(obs=self._obs, var=self._var, obsm=self._obsm,
                          varm=self._varm, obsp=self._obsp,
                          varp=self._varp, uns=self._uns,
                          num_threads=self._num_threads)
    error_message = 'X is None, so it cannot be dropped'
    raise TypeError(error_message)
def drop_obs(self,
             columns: pl.type_aliases.ColumnNameOrSelector |
                      Iterable[pl.type_aliases.ColumnNameOrSelector], /,
             *more_columns: pl.type_aliases.ColumnNameOrSelector) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with `columns` and `more_columns`
    removed from `obs`.

    Args:
        columns: column(s) to drop
        *more_columns: additional columns to drop, specified as
                       positional arguments

    Returns:
        A new SingleCell dataset with the column(s) removed.
    """
    # Combine both argument styles into one tuple for polars
    obs = self._obs.drop(to_tuple(columns) + more_columns)
    return SingleCell(X=self._X, obs=obs, var=self._var, obsm=self._obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def drop_var(self,
             columns: pl.type_aliases.ColumnNameOrSelector |
                      Iterable[pl.type_aliases.ColumnNameOrSelector], /,
             *more_columns: pl.type_aliases.ColumnNameOrSelector) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with `columns` and `more_columns`
    removed from `var`.

    Args:
        columns: column(s) to drop
        *more_columns: additional columns to drop, specified as
                       positional arguments

    Returns:
        A new SingleCell dataset with the column(s) removed.
    """
    # Combine both argument styles into one tuple for polars
    var = self._var.drop(to_tuple(columns) + more_columns)
    return SingleCell(X=self._X, obs=self._obs, var=var, obsm=self._obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def drop_obsm(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with `keys` and `more_keys` removed
    from `obsm`.

    Args:
        keys: key(s) to drop
        *more_keys: additional keys to drop, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with the specified key(s) removed from
        obsm.
    """
    unwanted = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    unwanted += more_keys
    # Report the first key to drop that obsm does not contain
    missing = next((key for key in unwanted if key not in self._obsm), None)
    if missing is not None:
        error_message = \
            f'tried to drop {missing!r}, which is not a key of obsm'
        raise ValueError(error_message)
    obsm = {name: entry for name, entry in self._obsm.items()
            if name not in unwanted}
    return SingleCell(X=self._X, obs=self._obs, var=self._var, obsm=obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def drop_varm(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with `keys` and `more_keys` removed
    from `varm`.

    Args:
        keys: key(s) to drop
        *more_keys: additional keys to drop, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with the specified key(s) removed from
        varm.
    """
    unwanted = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    unwanted += more_keys
    # Report the first key to drop that varm does not contain
    missing = next((key for key in unwanted if key not in self._varm), None)
    if missing is not None:
        error_message = \
            f'tried to drop {missing!r}, which is not a key of varm'
        raise ValueError(error_message)
    varm = {name: entry for name, entry in self._varm.items()
            if name not in unwanted}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def drop_obsp(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with `keys` and `more_keys` removed
    from `obsp`.

    Args:
        keys: key(s) to drop
        *more_keys: additional keys to drop, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with the specified key(s) removed from
        obsp.
    """
    unwanted = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    unwanted += more_keys
    # Report the first key to drop that obsp does not contain
    missing = next((key for key in unwanted if key not in self._obsp), None)
    if missing is not None:
        error_message = \
            f'tried to drop {missing!r}, which is not a key of obsp'
        raise ValueError(error_message)
    obsp = {name: entry for name, entry in self._obsp.items()
            if name not in unwanted}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def drop_varp(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with `keys` and `more_keys` removed
    from `varp`.

    Args:
        keys: key(s) to drop
        *more_keys: additional keys to drop, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with the specified key(s) removed from
        varp.
    """
    unwanted = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    unwanted += more_keys
    # Report the first key to drop that varp does not contain
    missing = next((key for key in unwanted if key not in self._varp), None)
    if missing is not None:
        error_message = \
            f'tried to drop {missing!r}, which is not a key of varp'
        raise ValueError(error_message)
    varp = {name: entry for name, entry in self._varp.items()
            if name not in unwanted}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=varp, uns=self._uns,
                      num_threads=self._num_threads)
def drop_uns(self, keys: str | Iterable[str], /, *more_keys: str) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with `keys` and `more_keys` removed
    from `uns`.

    Args:
        keys: key(s) to drop
        *more_keys: additional keys to drop, specified as positional
                    arguments

    Returns:
        A new SingleCell dataset with the specified key(s) removed from
        uns.
    """
    unwanted = to_tuple_checked(keys, 'keys', str, 'strings')
    check_types(more_keys, 'more_keys', str, 'strings')
    unwanted += more_keys
    # Report the first key to drop that uns does not contain
    missing = next((key for key in unwanted if key not in self._uns), None)
    if missing is not None:
        error_message = \
            f'tried to drop {missing!r}, which is not a key of uns'
        raise ValueError(error_message)
    uns = {name: entry for name, entry in self._uns.items()
           if name not in unwanted}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=uns,
                      num_threads=self._num_threads)
def rename_obs(self, mapping: dict[str, str] | Callable[[str], str], /) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with column(s) of `obs` renamed.

    Args:
        mapping: the renaming to apply, either as a dictionary with the old
                 names as keys and the new names as values, or a function
                 that takes an old name and returns a new name

    Returns:
        A new SingleCell dataset with the column(s) of `obs` renamed.
    """
    # `mapping` is handed to polars as-is
    obs = self._obs.rename(mapping)
    return SingleCell(X=self._X, obs=obs, var=self._var, obsm=self._obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def rename_var(self, mapping: dict[str, str] | Callable[[str], str], /) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with column(s) of `var` renamed.

    Args:
        mapping: the renaming to apply, either as a dictionary with the old
                 names as keys and the new names as values, or a function
                 that takes an old name and returns a new name

    Returns:
        A new SingleCell dataset with the column(s) of `var` renamed.
    """
    # `mapping` is handed to polars as-is
    var = self._var.rename(mapping)
    return SingleCell(X=self._X, obs=self._obs, var=var, obsm=self._obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def rename_obsm(self, mapping: dict[str, str] | Callable[[str], str], /) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with key(s) of `obsm` renamed.

    Args:
        mapping: the renaming to apply, either as a dictionary with the old
                 names as keys and the new names as values, or a function
                 that takes an old name and returns a new name. A function
                 is applied to every key of `obsm`; keys for which it
                 returns the name unchanged are kept as they are.

    Returns:
        A new SingleCell dataset with the key(s) of `obsm` renamed.

    Raises:
        TypeError: if `mapping` is neither a dictionary nor a function, or
                   if the function returns a non-string (in addition to
                   whatever `check_types` raises for non-string dictionary
                   keys or values).
        ValueError: if a key to rename is not in `obsm`, if a new name is
                    already a key of `obsm`, or if more than one key would
                    be renamed to the same new name.
    """
    if isinstance(mapping, dict):
        # Only dictionaries have keys()/values(); running these checks
        # before the type dispatch made every function `mapping` fail
        # with an AttributeError
        check_types(mapping.keys(), 'mapping.keys()', str, 'strings')
        check_types(mapping.values(), 'mapping.values()', str, 'strings')
        new_keys = set()
        for key, new_key in mapping.items():
            if key not in self._obsm:
                error_message = \
                    f'tried to rename {key!r}, which is not a key of obsm'
                raise ValueError(error_message)
            if new_key in self._obsm:
                error_message = (
                    f'tried to rename obsm[{key!r}] to obsm[{new_key!r}], '
                    f'but obsm[{new_key!r}] already exists')
                raise ValueError(error_message)
            # Two sources with one target would silently drop an entry
            if new_key in new_keys:
                error_message = (
                    f'tried to rename more than one key of obsm to '
                    f'{new_key!r}')
                raise ValueError(error_message)
            new_keys.add(new_key)
        obsm = {mapping.get(key, key): value
                for key, value in self._obsm.items()}
    elif callable(mapping):
        obsm = {}
        for key, value in self._obsm.items():
            new_key = mapping(key)
            if not isinstance(new_key, str):
                error_message = (
                    f'tried to rename obsm[{key!r}] to a non-string value '
                    f'of type {type(new_key).__name__!r}')
                raise TypeError(error_message)
            # The function sees every key, so returning a key unchanged
            # is a legitimate "do not rename" and must not be an error
            if new_key != key and new_key in self._obsm:
                error_message = (
                    f'tried to rename obsm[{key!r}] to obsm[{new_key!r}], '
                    f'but obsm[{new_key!r}] already exists')
                raise ValueError(error_message)
            # Two sources with one target would silently drop an entry
            if new_key in obsm:
                error_message = (
                    f'tried to rename more than one key of obsm to '
                    f'{new_key!r}')
                raise ValueError(error_message)
            obsm[new_key] = value
    else:
        error_message = (
            f'mapping must be a dictionary or function, but has type '
            f'{type(mapping).__name__!r}')
        raise TypeError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var, obsm=obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def rename_varm(self, mapping: dict[str, str] | Callable[[str], str], /) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with key(s) of `varm` renamed.

    Args:
        mapping: the renaming to apply, either as a dictionary with the old
                 names as keys and the new names as values, or a function
                 that takes an old name and returns a new name. A function
                 is applied to every key of `varm`; keys for which it
                 returns the name unchanged are kept as they are.

    Returns:
        A new SingleCell dataset with the key(s) of `varm` renamed.

    Raises:
        TypeError: if `mapping` is neither a dictionary nor a function, or
                   if the function returns a non-string (in addition to
                   whatever `check_types` raises for non-string dictionary
                   keys or values).
        ValueError: if a key to rename is not in `varm`, if a new name is
                    already a key of `varm`, or if more than one key would
                    be renamed to the same new name.
    """
    if isinstance(mapping, dict):
        # Only dictionaries have keys()/values(); running these checks
        # before the type dispatch made every function `mapping` fail
        # with an AttributeError
        check_types(mapping.keys(), 'mapping.keys()', str, 'strings')
        check_types(mapping.values(), 'mapping.values()', str, 'strings')
        new_keys = set()
        for key, new_key in mapping.items():
            if key not in self._varm:
                error_message = \
                    f'tried to rename {key!r}, which is not a key of varm'
                raise ValueError(error_message)
            if new_key in self._varm:
                error_message = (
                    f'tried to rename varm[{key!r}] to varm[{new_key!r}], '
                    f'but varm[{new_key!r}] already exists')
                raise ValueError(error_message)
            # Two sources with one target would silently drop an entry
            if new_key in new_keys:
                error_message = (
                    f'tried to rename more than one key of varm to '
                    f'{new_key!r}')
                raise ValueError(error_message)
            new_keys.add(new_key)
        varm = {mapping.get(key, key): value
                for key, value in self._varm.items()}
    elif callable(mapping):
        varm = {}
        for key, value in self._varm.items():
            new_key = mapping(key)
            if not isinstance(new_key, str):
                error_message = (
                    f'tried to rename varm[{key!r}] to a non-string value '
                    f'of type {type(new_key).__name__!r}')
                raise TypeError(error_message)
            # The function sees every key, so returning a key unchanged
            # is a legitimate "do not rename" and must not be an error
            if new_key != key and new_key in self._varm:
                error_message = (
                    f'tried to rename varm[{key!r}] to varm[{new_key!r}], '
                    f'but varm[{new_key!r}] already exists')
                raise ValueError(error_message)
            # Two sources with one target would silently drop an entry
            if new_key in varm:
                error_message = (
                    f'tried to rename more than one key of varm to '
                    f'{new_key!r}')
                raise ValueError(error_message)
            varm[new_key] = value
    else:
        error_message = (
            f'mapping must be a dictionary or function, but has type '
            f'{type(mapping).__name__!r}')
        raise TypeError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def rename_obsp(self, mapping: dict[str, str] | Callable[[str], str], /) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with key(s) of `obsp` renamed.

    Args:
        mapping: the renaming to apply, either as a dictionary with the old
                 names as keys and the new names as values, or a function
                 that takes an old name and returns a new name. A function
                 is applied to every key of `obsp`; keys for which it
                 returns the name unchanged are kept as they are.

    Returns:
        A new SingleCell dataset with the key(s) of `obsp` renamed.

    Raises:
        TypeError: if `mapping` is neither a dictionary nor a function, or
                   if the function returns a non-string (in addition to
                   whatever `check_types` raises for non-string dictionary
                   keys or values).
        ValueError: if a key to rename is not in `obsp`, if a new name is
                    already a key of `obsp`, or if more than one key would
                    be renamed to the same new name.
    """
    if isinstance(mapping, dict):
        # Only dictionaries have keys()/values(); running these checks
        # before the type dispatch made every function `mapping` fail
        # with an AttributeError
        check_types(mapping.keys(), 'mapping.keys()', str, 'strings')
        check_types(mapping.values(), 'mapping.values()', str, 'strings')
        new_keys = set()
        for key, new_key in mapping.items():
            if key not in self._obsp:
                error_message = \
                    f'tried to rename {key!r}, which is not a key of obsp'
                raise ValueError(error_message)
            if new_key in self._obsp:
                error_message = (
                    f'tried to rename obsp[{key!r}] to obsp[{new_key!r}], '
                    f'but obsp[{new_key!r}] already exists')
                raise ValueError(error_message)
            # Two sources with one target would silently drop an entry
            if new_key in new_keys:
                error_message = (
                    f'tried to rename more than one key of obsp to '
                    f'{new_key!r}')
                raise ValueError(error_message)
            new_keys.add(new_key)
        obsp = {mapping.get(key, key): value
                for key, value in self._obsp.items()}
    elif callable(mapping):
        obsp = {}
        for key, value in self._obsp.items():
            new_key = mapping(key)
            if not isinstance(new_key, str):
                error_message = (
                    f'tried to rename obsp[{key!r}] to a non-string value '
                    f'of type {type(new_key).__name__!r}')
                raise TypeError(error_message)
            # The function sees every key, so returning a key unchanged
            # is a legitimate "do not rename" and must not be an error
            if new_key != key and new_key in self._obsp:
                error_message = (
                    f'tried to rename obsp[{key!r}] to obsp[{new_key!r}], '
                    f'but obsp[{new_key!r}] already exists')
                raise ValueError(error_message)
            # Two sources with one target would silently drop an entry
            if new_key in obsp:
                error_message = (
                    f'tried to rename more than one key of obsp to '
                    f'{new_key!r}')
                raise ValueError(error_message)
            obsp[new_key] = value
    else:
        error_message = (
            f'mapping must be a dictionary or function, but has type '
            f'{type(mapping).__name__!r}')
        raise TypeError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def rename_varp(self, mapping: dict[str, str] | Callable[[str], str], /) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with key(s) of `varp` renamed.

    Args:
        mapping: the renaming to apply, either as a dictionary with the old
                 names as keys and the new names as values, or a function
                 that takes an old name and returns a new name. A function
                 is applied to every key of `varp`; keys for which it
                 returns the name unchanged are kept as they are.

    Returns:
        A new SingleCell dataset with the key(s) of `varp` renamed.

    Raises:
        TypeError: if `mapping` is neither a dictionary nor a function, or
                   if the function returns a non-string (in addition to
                   whatever `check_types` raises for non-string dictionary
                   keys or values).
        ValueError: if a key to rename is not in `varp`, if a new name is
                    already a key of `varp`, or if more than one key would
                    be renamed to the same new name.
    """
    if isinstance(mapping, dict):
        # Only dictionaries have keys()/values(); running these checks
        # before the type dispatch made every function `mapping` fail
        # with an AttributeError
        check_types(mapping.keys(), 'mapping.keys()', str, 'strings')
        check_types(mapping.values(), 'mapping.values()', str, 'strings')
        new_keys = set()
        for key, new_key in mapping.items():
            if key not in self._varp:
                error_message = \
                    f'tried to rename {key!r}, which is not a key of varp'
                raise ValueError(error_message)
            if new_key in self._varp:
                error_message = (
                    f'tried to rename varp[{key!r}] to varp[{new_key!r}], '
                    f'but varp[{new_key!r}] already exists')
                raise ValueError(error_message)
            # Two sources with one target would silently drop an entry
            if new_key in new_keys:
                error_message = (
                    f'tried to rename more than one key of varp to '
                    f'{new_key!r}')
                raise ValueError(error_message)
            new_keys.add(new_key)
        varp = {mapping.get(key, key): value
                for key, value in self._varp.items()}
    elif callable(mapping):
        varp = {}
        for key, value in self._varp.items():
            new_key = mapping(key)
            if not isinstance(new_key, str):
                error_message = (
                    f'tried to rename varp[{key!r}] to a non-string value '
                    f'of type {type(new_key).__name__!r}')
                raise TypeError(error_message)
            # The function sees every key, so returning a key unchanged
            # is a legitimate "do not rename" and must not be an error
            if new_key != key and new_key in self._varp:
                error_message = (
                    f'tried to rename varp[{key!r}] to varp[{new_key!r}], '
                    f'but varp[{new_key!r}] already exists')
                raise ValueError(error_message)
            # Two sources with one target would silently drop an entry
            if new_key in varp:
                error_message = (
                    f'tried to rename more than one key of varp to '
                    f'{new_key!r}')
                raise ValueError(error_message)
            varp[new_key] = value
    else:
        error_message = (
            f'mapping must be a dictionary or function, but has type '
            f'{type(mapping).__name__!r}')
        raise TypeError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=varp, uns=self._uns,
                      num_threads=self._num_threads)
def rename_uns(self, mapping: dict[str, str] | Callable[[str], str], /) -> \
        SingleCell:
    """
    Create a new SingleCell dataset with key(s) of `uns` renamed.

    Args:
        mapping: the renaming to apply, either as a dictionary with the old
                 names as keys and the new names as values, or a function
                 that takes an old name and returns a new name. A function
                 is applied to every key of `uns`; keys for which it
                 returns the name unchanged are kept as they are.

    Returns:
        A new SingleCell dataset with the key(s) of `uns` renamed.

    Raises:
        TypeError: if `mapping` is neither a dictionary nor a function, or
                   if the function returns a non-string (in addition to
                   whatever `check_types` raises for non-string dictionary
                   keys or values).
        ValueError: if a key to rename is not in `uns`, if a new name is
                    already a key of `uns`, or if more than one key would
                    be renamed to the same new name.
    """
    if isinstance(mapping, dict):
        # Only dictionaries have keys()/values(); running these checks
        # before the type dispatch made every function `mapping` fail
        # with an AttributeError
        check_types(mapping.keys(), 'mapping.keys()', str, 'strings')
        check_types(mapping.values(), 'mapping.values()', str, 'strings')
        new_keys = set()
        for key, new_key in mapping.items():
            if key not in self._uns:
                error_message = \
                    f'tried to rename {key!r}, which is not a key of uns'
                raise ValueError(error_message)
            if new_key in self._uns:
                error_message = (
                    f'tried to rename uns[{key!r}] to uns[{new_key!r}], '
                    f'but uns[{new_key!r}] already exists')
                raise ValueError(error_message)
            # Two sources with one target would silently drop an entry
            if new_key in new_keys:
                error_message = (
                    f'tried to rename more than one key of uns to '
                    f'{new_key!r}')
                raise ValueError(error_message)
            new_keys.add(new_key)
        uns = {mapping.get(key, key): value
               for key, value in self._uns.items()}
    elif callable(mapping):
        uns = {}
        for key, value in self._uns.items():
            new_key = mapping(key)
            if not isinstance(new_key, str):
                error_message = (
                    f'tried to rename uns[{key!r}] to a non-string value '
                    f'of type {type(new_key).__name__!r}')
                raise TypeError(error_message)
            # The function sees every key, so returning a key unchanged
            # is a legitimate "do not rename" and must not be an error
            if new_key != key and new_key in self._uns:
                error_message = (
                    f'tried to rename uns[{key!r}] to uns[{new_key!r}], '
                    f'but uns[{new_key!r}] already exists')
                raise ValueError(error_message)
            # Two sources with one target would silently drop an entry
            if new_key in uns:
                error_message = (
                    f'tried to rename more than one key of uns to '
                    f'{new_key!r}')
                raise ValueError(error_message)
            uns[new_key] = value
    else:
        error_message = (
            f'mapping must be a dictionary or function, but has type '
            f'{type(mapping).__name__!r}')
        raise TypeError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=uns,
                      num_threads=self._num_threads)
def cast_X(self, dtype: np._typing.DTypeLike, /) -> SingleCell:
    """
    Convert `X` to another data type via its `astype()` method.

    Args:
        dtype: a NumPy data type

    Returns:
        A new SingleCell dataset with `X` cast to the specified data type.

    Raises:
        ValueError: if `X` is `None`.
    """
    if self._X is not None:
        X = self._X.astype(dtype)
        return SingleCell(X=X, obs=self._obs, var=self._var,
                          obsm=self._obsm, varm=self._varm,
                          obsp=self._obsp, varp=self._varp, uns=self._uns,
                          num_threads=self._num_threads)
    # Nothing to cast when the dataset carries no X
    error_message = 'X is None, so casting it is not possible'
    raise ValueError(error_message)
def cast_obs(self,
             dtypes: Mapping[pl.type_aliases.ColumnNameOrSelector |
                             pl.type_aliases.PolarsDataType,
                             pl.type_aliases.PolarsDataType] |
                     pl.type_aliases.PolarsDataType, /, *,
             strict: bool = True) -> SingleCell:
    """
    Cast column(s) of `obs` to the specified data type(s).

    Args:
        dtypes: a mapping of column names (or selectors) to data types, or
                a single data type to which all columns will be cast
        strict: whether to raise an error if a cast could not be performed
                (for instance, due to numerical overflow)

    Returns:
        A new SingleCell dataset with column(s) of `obs` cast to the
        specified data type(s).
    """
    # Both arguments are forwarded to polars unchanged
    obs = self._obs.cast(dtypes, strict=strict)
    return SingleCell(X=self._X, obs=obs, var=self._var, obsm=self._obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def cast_var(self,
             dtypes: Mapping[pl.type_aliases.ColumnNameOrSelector |
                             pl.type_aliases.PolarsDataType,
                             pl.type_aliases.PolarsDataType] |
                     pl.type_aliases.PolarsDataType, /, *,
             strict: bool = True) -> SingleCell:
    """
    Cast column(s) of `var` to the specified data type(s).

    Args:
        dtypes: a mapping of column names (or selectors) to data types, or
                a single data type to which all columns will be cast
        strict: whether to raise an error if a cast could not be performed
                (for instance, due to numerical overflow)

    Returns:
        A new SingleCell dataset with column(s) of `var` cast to the
        specified data type(s).
    """
    # Both arguments are forwarded to polars unchanged
    var = self._var.cast(dtypes, strict=strict)
    return SingleCell(X=self._X, obs=self._obs, var=var, obsm=self._obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def join_obs(self, other: pl.DataFrame, /, *,
             on: str | pl.Expr | Sequence[str | pl.Expr] | None = None,
             left_on: str | pl.Expr | Sequence[str | pl.Expr] |
                      None = None,
             right_on: str | pl.Expr | Sequence[str | pl.Expr] |
                       None = None,
             suffix: str = '_right',
             validate: Literal['m:m', 'm:1', '1:m', '1:1'] = 'm:m',
             nulls_equal: bool = False,
             coalesce: bool = True) -> SingleCell:
    """
    Left-join `obs` with another DataFrame, using the same logic as
    [`polars.DataFrame.join()`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join.html).

    Args:
        other: a polars DataFrame to join `obs` with
        on: the name(s) of the join column(s) in both DataFrames
        left_on: the name(s) of the join column(s) in `obs`
        right_on: the name(s) of the join column(s) in `other`
        suffix: a suffix to append to columns with a duplicate name
        validate: checks whether the join is of the specified type. Can be:
                  - 'm:m' (many-to-many): the default, no checks performed.
                  - '1:1' (one-to-one): check that none of the values in
                    the join column(s) appear more than once in `obs` or
                    more than once in `other`.
                  - '1:m' (one-to-many): check that none of the values in
                    the join column(s) appear more than once in `obs`.
                  - 'm:1' (many-to-one): check that none of the values in
                    the join column(s) appear more than once in `other`.
        nulls_equal: whether to include `null` as a valid value to join on.
                     By default, `null` values will never produce matches.
        coalesce: if `True`, coalesce each of the pairs of join columns
                  (the columns in `on` or `left_on`/`right_on`) from `obs`
                  and `other` into a single column, filling missing values
                  from one with the corresponding values from the other.
                  If `False`, include both as separate columns, adding
                  `suffix` to the join columns from `other`.

    Returns:
        A new SingleCell dataset with the columns from `other` joined to
        obs.

    Raises:
        ValueError: if `on` is combined with `left_on`/`right_on`, if
                    neither `on` nor both of `left_on` and `right_on` are
                    specified, or if the join would add rows to `obs`
                    because `other` contains duplicate join values.
        TypeError: if a pair of join columns has different data types that
                   are not both Enum/Categorical.

    Note:
        If a column of `on`, `left_on` or `right_on` is Enum in `obs` and
        Categorical in `other` (or vice versa), or Enum in both but with
        different categories in each, that pair of columns will be
        automatically cast to a common Enum data type (with the union of
        the categories) before joining.
    """
    check_type(other, 'other', pl.DataFrame, 'a polars DataFrame')
    left = self._obs
    right = other
    if on is None:
        if left_on is None and right_on is None:
            error_message = (
                "either 'on' or both of 'left_on' and 'right_on' must be "
                "specified")
            raise ValueError(error_message)
        elif left_on is None:
            error_message = \
                'right_on is specified, so left_on must be specified'
            raise ValueError(error_message)
        elif right_on is None:
            error_message = \
                'left_on is specified, so right_on must be specified'
            raise ValueError(error_message)
        left_columns = left.select(left_on)
        right_columns = right.select(right_on)
    else:
        if left_on is not None:
            error_message = "'on' is specified, so 'left_on' must be None"
            raise ValueError(error_message)
        if right_on is not None:
            error_message = "'on' is specified, so 'right_on' must be None"
            raise ValueError(error_message)
        left_columns = left.select(on)
        right_columns = right.select(on)
    # Collect the casts needed to give mismatched Enum/Categorical join
    # columns a common Enum data type
    left_cast_dict = {}
    right_cast_dict = {}
    for left_column, right_column in zip(left_columns, right_columns):
        left_dtype = left_column.dtype
        right_dtype = right_column.dtype
        if left_dtype == right_dtype:
            continue
        if (left_dtype == pl.Enum or left_dtype == pl.Categorical) and (
                right_dtype == pl.Enum or right_dtype == pl.Categorical):
            common_dtype = \
                pl.Enum(pl.concat([left_column.cat.get_categories(),
                                   right_column.cat.get_categories()])
                        .unique(maintain_order=True))
            left_cast_dict[left_column.name] = common_dtype
            right_cast_dict[right_column.name] = common_dtype
        else:
            error_message = (
                f'obs[{left_column.name!r}] has data type '
                f'{left_dtype.base_type()!r}, but '
                f'other[{right_column.name!r}] has data type '
                f'{right_dtype.base_type()!r}')
            raise TypeError(error_message)
    # Both dictionaries are filled together, so one truthiness test is
    # enough. The previous `is not None` test was always true.
    if left_cast_dict:
        left = left.cast(left_cast_dict)
        right = right.cast(right_cast_dict)
    obs = left.join(right, on=on, how='left', left_on=left_on,
                    right_on=right_on, suffix=suffix, validate=validate,
                    nulls_equal=nulls_equal, coalesce=coalesce,
                    maintain_order='left')
    # A left join can only add rows when `other` has repeated join values
    if len(obs) > len(self._obs):
        other_on = to_tuple(right_on if right_on is not None else on)
        assert other.select(other_on).is_duplicated().any()
        duplicate_column = other_on[0] if len(other_on) == 1 else \
            next(column for column in other_on
                 if other[column].is_duplicated().any())
        error_message = (
            f'other[{duplicate_column!r}] contains duplicate values, so '
            f'it must be deduplicated before being joined on')
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=obs, var=self._var, obsm=self._obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def join_var(self,
             other: pl.DataFrame,
             /,
             *,
             on: str | pl.Expr | Sequence[str | pl.Expr] | None = None,
             left_on: str | pl.Expr | Sequence[str | pl.Expr] | None = None,
             right_on: str | pl.Expr | Sequence[str | pl.Expr] | None = None,
             suffix: str = '_right',
             validate: Literal['m:m', 'm:1', '1:m', '1:1'] = 'm:m',
             nulls_equal: bool = False,
             coalesce: bool = True) -> SingleCell:
    """
    Left-join `var` with another DataFrame, using the same logic as
    [`polars.DataFrame.join()`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join.html).

    Args:
        other: a polars DataFrame to join `var` with
        on: the name(s) of the join column(s) in both DataFrames
        left_on: the name(s) of the join column(s) in `var`
        right_on: the name(s) of the join column(s) in `other`
        suffix: a suffix to append to columns with a duplicate name
        validate: checks whether the join is of the specified type. Can be:
                  - 'm:m' (many-to-many): the default, no checks performed.
                  - '1:1' (one-to-one): check that none of the values in
                    the join column(s) appear more than once in `var` or
                    more than once in `other`.
                  - '1:m' (one-to-many): check that none of the values in
                    the join column(s) appear more than once in `var`.
                  - 'm:1' (many-to-one): check that none of the values in
                    the join column(s) appear more than once in `other`.
        nulls_equal: whether to include `null` as a valid value to join on.
                     By default, `null` values will never produce matches.
        coalesce: if `True`, coalesce each of the pairs of join columns
                  (the columns in `on` or `left_on`/`right_on`) from `var`
                  and `other` into a single column, filling missing values
                  from one with the corresponding values from the other.
                  If `False`, include both as separate columns, adding
                  `suffix` to the join columns from `other`.

    Returns:
        A new SingleCell dataset with the columns from `other` joined to
        var.

    Raises:
        ValueError: if `on`, `left_on` and `right_on` are specified in an
                    inconsistent combination, or if the join column(s) of
                    `other` contain duplicate values (which would add rows
                    to `var`).
        TypeError: if a pair of join columns has mismatched data types
                   that are not both Enum/Categorical.

    Note:
        If a column of `on`, `left_on` or `right_on` is Enum in `var` and
        Categorical in `other` (or vice versa), or Enum in both but with
        different categories in each, that pair of columns will be
        automatically cast to a common Enum data type (with the union of
        the categories) before joining.
    """
    check_type(other, 'other', pl.DataFrame, 'a polars DataFrame')
    left = self._var
    right = other
    if on is None:
        if left_on is None and right_on is None:
            error_message = (
                "either 'on' or both of 'left_on' and 'right_on' must be "
                "specified")
            raise ValueError(error_message)
        elif left_on is None:
            error_message = \
                'right_on is specified, so left_on must be specified'
            raise ValueError(error_message)
        elif right_on is None:
            error_message = \
                'left_on is specified, so right_on must be specified'
            raise ValueError(error_message)
        left_columns = left.select(left_on)
        right_columns = right.select(right_on)
    else:
        if left_on is not None:
            error_message = "'on' is specified, so 'left_on' must be None"
            raise ValueError(error_message)
        if right_on is not None:
            error_message = "'on' is specified, so 'right_on' must be None"
            raise ValueError(error_message)
        left_columns = left.select(on)
        right_columns = right.select(on)
    # Reconcile Enum/Categorical join columns whose data types differ by
    # casting both sides to an Enum with the union of their categories
    left_cast_dict = {}
    right_cast_dict = {}
    for left_column, right_column in zip(left_columns, right_columns):
        left_dtype = left_column.dtype
        right_dtype = right_column.dtype
        if left_dtype == right_dtype:
            continue
        if (left_dtype == pl.Enum or left_dtype == pl.Categorical) and (
                right_dtype == pl.Enum or right_dtype == pl.Categorical):
            common_dtype = \
                pl.Enum(pl.concat([left_column.cat.get_categories(),
                                   right_column.cat.get_categories()])
                        .unique(maintain_order=True))
            left_cast_dict[left_column.name] = common_dtype
            right_cast_dict[right_column.name] = common_dtype
        else:
            error_message = (
                f'var[{left_column.name!r}] has data type '
                f'{left_dtype.base_type()!r}, but '
                f'other[{right_column.name!r}] has data type '
                f'{right_dtype.base_type()!r}')
            raise TypeError(error_message)
    # Only cast when at least one pair of columns needs reconciling
    if left_cast_dict:
        left = left.cast(left_cast_dict)
        right = right.cast(right_cast_dict)
    var = left.join(right, on=on, how='left', left_on=left_on,
                    right_on=right_on, suffix=suffix, validate=validate,
                    nulls_equal=nulls_equal, coalesce=coalesce,
                    maintain_order='left')
    # A left join can only gain rows when `other` has duplicate join keys.
    # Compare against the number of genes (rows of `var`), not `len(self)`.
    if len(var) > len(self._var):
        other_on = to_tuple(right_on if right_on is not None else on)
        assert other.select(other_on).is_duplicated().any()
        duplicate_column = other_on[0] if len(other_on) == 1 else \
            next(column for column in other_on
                 if other[column].is_duplicated().any())
        error_message = (
            f'other[{duplicate_column!r}] contains duplicate values, so '
            f'it must be deduplicated before being joined on')
        raise ValueError(error_message)
    return SingleCell(X=self._X, obs=self._obs, var=var, obsm=self._obsm,
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def peek_obs(self, *, row: int = 0) -> None:
    """
    Print a single row of `obs` vertically, with one line per column.

    Args:
        row: the index of the row of `obs` to print; the first row by
             default
    """
    check_type(row, 'row', int, 'an integer')
    # Unpivot the one-row DataFrame so that each column of `obs` becomes
    # its own row of the printed table
    transposed_row = self._obs[row].unpivot(variable_name='column')
    # `tbl_rows=-1` makes polars print every row instead of truncating
    with pl.Config(tbl_rows=-1):
        print(transposed_row)
def peek_var(self, *, row: int = 0) -> None:
    """
    Print a single row of `var` vertically, with one line per column.

    Args:
        row: the index of the row of `var` to print; the first row by
             default
    """
    check_type(row, 'row', int, 'an integer')
    # Unpivot the one-row DataFrame so that each column of `var` becomes
    # its own row of the printed table
    transposed_row = self._var[row].unpivot(variable_name='column')
    # `tbl_rows=-1` makes polars print every row instead of truncating
    with pl.Config(tbl_rows=-1):
        print(transposed_row)
def subsample_obs(self,
                  *,
                  n: int | np.integer | None = None,
                  fraction: int | float | np.integer | np.floating |
                            None = None,
                  QC_column: SingleCellColumn | None = 'passed_QC',
                  by_column: SingleCellColumn | None = None,
                  subsample_column: str | None = None,
                  seed: int | np.integer = 0,
                  overwrite: bool = False,
                  num_threads: int | np.integer | None = None) -> \
        SingleCell:
    """
    Subsample a specific number or fraction of cells.

    Args:
        n: the number of cells to return; mutually exclusive with
           `fraction`
        fraction: the fraction of cells to return; mutually exclusive with
                  `n`
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will not be
                   selected when subsampling, and will not count towards
                   the denominator of `fraction`; QC_column will not appear
                   in the returned SingleCell dataset, since it would be
                   redundant.
        by_column: an optional String, Enum, Categorical, or integer column
                   of `obs` to subsample by. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Specifying
                   `by_column` ensures that the same fraction of cells with
                   each value of `by_column` are subsampled. When combined
                   with `n`, to make sure the total number of samples is
                   exactly `n`, some of the smallest groups may be
                   oversampled by one element, or some of the largest
                   groups may be undersampled by one element. Can contain
                   `null` entries: the corresponding cells will not be
                   included in the result.
        subsample_column: an optional name of a Boolean column to add to
                          obs indicating the subsampled cells; if `None`,
                          subset to these cells instead
        seed: the random seed to use when subsampling
        overwrite: if `True`, overwrite `subsample_column` if already
                   present in `obs`, instead of raising an error. Must be
                   `False` when `subsample_column` is `None`.
        num_threads: the number of threads to use when subsetting `X` to
                     the sampled cells. Set `num_threads=-1` to use all
                     available cores, as determined by `os.cpu_count()`. By
                     default (`num_threads=None`), use `self.num_threads`
                     cores. Does not affect the subsampled SingleCell
                     dataset's `num_threads`; this will always be the same
                     as the original dataset's `num_threads`. Can only be
                     specified when `subsample_column` is `None`.

    Returns:
        A new SingleCell dataset subset to the subsampled cells, or if
        `subsample_column` is specified, the full dataset with
        `subsample_column` added to `obs`. If `QC_column` is a string and a
        QC column exists in the original dataset, it will be removed from
        the subsampled dataset, since all subsampled cells pass QC and it
        would be redundant.

    Raises:
        ValueError: if `subsample_column` already exists and `overwrite` is
                    `False`; if `overwrite` is `True` without a
                    `subsample_column`; if both or neither of `n` and
                    `fraction` are specified; or if `num_threads` is
                    specified together with `subsample_column`.
    """
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    if subsample_column is not None:
        check_type(subsample_column, 'subsample_column', str, 'a string')
        if not overwrite and subsample_column in self._obs:
            error_message = (
                f'subsample_column {subsample_column!r} is already a '
                f'column of obs; did you already run subsample_obs()? Set '
                f'overwrite=True to overwrite.')
            raise ValueError(error_message)
    elif overwrite:
        error_message = \
            'overwrite must be False when subsample_column is None'
        raise ValueError(error_message)
    if n is not None and fraction is not None:
        error_message = 'only one of n and fraction can be specified'
        raise ValueError(error_message)
    if n is not None:
        check_type(n, 'n', int, 'a positive integer')
        check_bounds(n, 'n', 1)
    elif fraction is not None:
        check_type(fraction, 'fraction', float,
                   'a floating-point number between 0 and 1')
        check_bounds(fraction, 'fraction', 0, 1, left_open=True,
                     right_open=True)
    else:
        error_message = 'either n or fraction must be specified'
        raise ValueError(error_message)
    # Subsampling subsets to QCed cells by definition, so we should drop
    # the QC column from the subsampled dataset at the end - but only if
    # it was passed as an explicit column name.
    if QC_column is not None:
        QC_column_is_string = isinstance(QC_column, str)
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')
    check_type(seed, 'seed', int, 'an integer')
    if subsample_column is None:
        num_threads = self._process_num_threads(num_threads)
    else:
        if num_threads is not None:
            error_message = (
                'num_threads can only be specified when subsample_column '
                'is None')
            raise ValueError(error_message)
    # Get the indices of cells passing QC, used both for the index-based
    # fast path and to get `N` cheaply.
    if QC_column is not None:
        QC_indices = np.flatnonzero(QC_column.to_numpy())
        N = len(QC_indices)
    else:
        N = len(self._obs)
    # When there is a `QC_column`, the boolean mask and indices we compute
    # below live in "QCed space" (`[0, number of cells passing QC)`); we
    # back-project to the full dataset afterwards
    if by_column is not None:
        # Grouped: use polars' `shuffle().over()`
        by_column = self._get_column(
            'obs', by_column, 'by_column',
            (pl.String, pl.Enum, pl.Categorical, 'integer'),
            QC_column=QC_column, allow_null=True)
        if QC_column is not None:
            by_column = by_column.filter(QC_column)
        by_frame = by_column.to_frame()
        by_name = by_column.name
        if n is not None:
            # Get a vector of the number of elements to sample per group.
            # The total sample size should exactly match the original n; if
            # necessary, oversample the smallest groups or undersample the
            # largest groups to make this happen.
            group_counts = by_frame\
                .group_by(by_name)\
                .agg(pl.len(), n=(n / len(by_column) * pl.len())
                                 .round().cast(pl.Int32))\
                .drop_nulls(by_name)
            diff = n - group_counts['n'].sum()
            if diff != 0:
                group_counts = group_counts\
                    .sort('len', descending=diff < 0)\
                    .with_columns(n=pl.col.n +
                                    pl.int_range(pl.len(), dtype=pl.Int32)
                                    .lt(abs(diff)).cast(pl.Int32) *
                                    pl.lit(diff).sign())
            # Use an order-preserving left join so that `selected` stays
            # row-aligned with `by_frame`: an inner join would drop the
            # rows where `by_column` is null (and does not guarantee row
            # order), shifting every subsequent index. Null rows get a null
            # per-group `n`, so the comparison is null for them; treat that
            # as "not selected".
            selected = by_frame\
                .join(group_counts, on=by_name, how='left',
                      maintain_order='left')\
                .select(pl.int_range(pl.len(), dtype=pl.Int32)
                        .shuffle(seed=seed)
                        .over(by_name)
                        .lt(pl.col.n)
                        .fill_null(False))\
                .to_series()
        else:
            # `over()` treats null as a group of its own, so explicitly
            # exclude cells where `by_column` is null
            selected = by_frame\
                .select(pl.int_range(pl.len(), dtype=pl.Int32)
                        .shuffle(seed=seed)
                        .over(by_name)
                        .lt((fraction * pl.len().over(by_name)).round())
                        & pl.col(by_name).is_not_null())\
                .to_series()
        if subsample_column is None:
            # Get the (sorted) indices of the subsampled cells
            indices = np.flatnonzero(selected.to_numpy())
            if QC_column is not None:
                # Back-project to the full dataset
                indices = QC_indices[indices]
        elif QC_column is not None:
            # Back-project to the full dataset
            selected = pl.when(QC_column)\
                .then(selected.gather(QC_column.cum_sum() - QC_column))
    else:
        # Ungrouped: `np.random.choice()`
        if n is None:
            n = int(round(fraction * N))
        indices = \
            np.random.default_rng(seed).choice(N, size=n, replace=False)
        if subsample_column is None:
            # Sort to preserve the original ordering of obs
            indices.sort()
            if QC_column is not None:
                # Back-project to the full dataset
                indices = QC_indices[indices]
        else:
            selected = np.zeros(N, dtype=bool)
            selected[indices] = True
            selected = pl.Series(selected)
            if QC_column is not None:
                selected = pl.when(QC_column)\
                    .then(selected.gather(QC_column.cum_sum() - QC_column))
    if subsample_column is None:
        # Temporarily override the thread count of `X` for the subsetting
        # step, restoring it even if subsetting fails
        original_num_threads = self._X._num_threads
        try:
            self._X._num_threads = num_threads
            sc = self[indices]
            sc._X.num_threads = sc._num_threads  # i.e. self._num_threads
        finally:
            self._X._num_threads = original_num_threads
    else:
        sc = self.with_columns_obs(selected.alias(subsample_column))
    # Drop the now-redundant QC column, since all subsampled cells pass QC
    if QC_column is not None and QC_column_is_string:
        sc._obs = sc._obs.drop(QC_column.name)
    return sc
def subsample_var(self,
                  *,
                  n: int | np.integer | None = None,
                  fraction: int | float | np.integer | np.floating |
                            None = None,
                  by_column: SingleCellColumn | None = None,
                  subsample_column: str | None = None,
                  seed: int | np.integer = 0,
                  overwrite: bool = False,
                  num_threads: int | np.integer | None = None) -> \
        SingleCell:
    """
    Subsample a specific number or fraction of genes.

    Args:
        n: the number of genes to return; mutually exclusive with
           `fraction`
        fraction: the fraction of genes to return; mutually exclusive with
                  `n`
        by_column: an optional String, Enum, Categorical, or integer column
                   of `var` to subsample by. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Specifying
                   `by_column` ensures that the same fraction of genes with
                   each value of `by_column` are subsampled. When combined
                   with `n`, to make sure the total number of samples is
                   exactly `n`, some of the smallest groups may be
                   oversampled by one element, or some of the largest
                   groups may be undersampled by one element. Can contain
                   `null` entries: the corresponding genes will not be
                   included in the result.
        subsample_column: an optional name of a Boolean column to add to
                          var indicating the subsampled genes; if `None`,
                          subset to these genes instead
        seed: the random seed to use when subsampling
        overwrite: if `True`, overwrite `subsample_column` if already
                   present in `var`, instead of raising an error. Must be
                   `False` when `subsample_column` is `None`.
        num_threads: the number of threads to use when subsetting `X` to
                     the sampled genes. Set `num_threads=-1` to use all
                     available cores, as determined by `os.cpu_count()`. By
                     default (`num_threads=None`), use `self.num_threads`
                     cores. Does not affect the subsampled SingleCell
                     dataset's `num_threads`; this will always be the same
                     as the original dataset's `num_threads`. Can only be
                     specified when `subsample_column` is `None`.

    Returns:
        A new SingleCell dataset subset to the subsampled genes, or if
        `subsample_column` is specified, the full dataset with
        `subsample_column` added to `var`.

    Raises:
        ValueError: if `subsample_column` already exists and `overwrite` is
                    `False`; if `overwrite` is `True` without a
                    `subsample_column`; if both or neither of `n` and
                    `fraction` are specified; or if `num_threads` is
                    specified together with `subsample_column`.
    """
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    if subsample_column is not None:
        check_type(subsample_column, 'subsample_column', str, 'a string')
        if not overwrite and subsample_column in self._var:
            error_message = (
                f'subsample_column {subsample_column!r} is already a '
                f'column of var; did you already run subsample_var()? Set '
                f'overwrite=True to overwrite.')
            raise ValueError(error_message)
    elif overwrite:
        error_message = \
            'overwrite must be False when subsample_column is None'
        raise ValueError(error_message)
    if n is not None and fraction is not None:
        error_message = 'only one of n and fraction can be specified'
        raise ValueError(error_message)
    if n is not None:
        check_type(n, 'n', int, 'a positive integer')
        check_bounds(n, 'n', 1)
    elif fraction is not None:
        check_type(fraction, 'fraction', float,
                   'a floating-point number between 0 and 1')
        check_bounds(fraction, 'fraction', 0, 1, left_open=True,
                     right_open=True)
    else:
        error_message = 'either n or fraction must be specified'
        raise ValueError(error_message)
    check_type(seed, 'seed', int, 'an integer')
    if subsample_column is None:
        num_threads = self._process_num_threads(num_threads)
    else:
        if num_threads is not None:
            error_message = (
                'num_threads can only be specified when subsample_column '
                'is None')
            raise ValueError(error_message)
    if by_column is not None:
        # Grouped: use polars' `shuffle().over()`
        by_column = self._get_column(
            'var', by_column, 'by_column',
            (pl.String, pl.Enum, pl.Categorical, 'integer'),
            allow_null=True)
        by_frame = by_column.to_frame()
        by_name = by_column.name
        if n is not None:
            # Get a vector of the number of elements to sample per group.
            # The total sample size should exactly match the original n; if
            # necessary, oversample the smallest groups or undersample the
            # largest groups to make this happen.
            group_counts = by_frame\
                .group_by(by_name)\
                .agg(pl.len(), n=(n / len(by_column) * pl.len())
                                 .round().cast(pl.Int32))\
                .drop_nulls(by_name)
            diff = n - group_counts['n'].sum()
            if diff != 0:
                group_counts = group_counts\
                    .sort('len', descending=diff < 0)\
                    .with_columns(n=pl.col.n +
                                    pl.int_range(pl.len(), dtype=pl.Int32)
                                    .lt(abs(diff)).cast(pl.Int32) *
                                    pl.lit(diff).sign())
            # Use an order-preserving left join so that `selected` stays
            # row-aligned with `var`: an inner join would drop the rows
            # where `by_column` is null (and does not guarantee row order),
            # shifting every subsequent index. Null rows get a null
            # per-group `n`, so the comparison is null for them; treat that
            # as "not selected".
            selected = by_frame\
                .join(group_counts, on=by_name, how='left',
                      maintain_order='left')\
                .select(pl.int_range(pl.len(), dtype=pl.Int32)
                        .shuffle(seed=seed)
                        .over(by_name)
                        .lt(pl.col.n)
                        .fill_null(False))\
                .to_series()
        else:
            # `over()` treats null as a group of its own, so explicitly
            # exclude genes where `by_column` is null
            selected = by_frame\
                .select(pl.int_range(pl.len(), dtype=pl.Int32)
                        .shuffle(seed=seed)
                        .over(by_name)
                        .lt((fraction * pl.len().over(by_name)).round())
                        & pl.col(by_name).is_not_null())\
                .to_series()
        if subsample_column is None:
            # Get the (sorted) indices of the subsampled genes
            indices = np.flatnonzero(selected.to_numpy())
    else:
        # Ungrouped: `np.random.choice()`
        N = len(self._var)
        if n is None:
            n = int(round(fraction * N))
        indices = \
            np.random.default_rng(seed).choice(N, size=n, replace=False)
        if subsample_column is None:
            # Sort to preserve the original ordering of var
            indices.sort()
        else:
            selected = np.zeros(N, dtype=bool)
            selected[indices] = True
            selected = pl.Series(selected)
    if subsample_column is None:
        # Temporarily override the thread count of `X` for the subsetting
        # step, restoring it even if subsetting fails
        original_num_threads = self._X._num_threads
        try:
            self._X._num_threads = num_threads
            sc = self[:, indices]
            sc._X.num_threads = sc._num_threads  # i.e. self._num_threads
        finally:
            self._X._num_threads = original_num_threads
    else:
        sc = self.with_columns_var(selected.alias(subsample_column))
    # Return the result, as documented; without this the method would
    # implicitly return `None`
    return sc
def pipe(self,
         function: Callable[[SingleCell, ...], Any],
         /,
         *args: Any,
         **kwargs: Any) -> Any:
    """
    Call a function with this SingleCell dataset as its first argument.

    `sc.pipe(func)` is equivalent to `func(sc)`.

    `sc.pipe(func, 1, a=2)` is equivalent to `func(sc, 1, a=2)`.

    Args:
        function: the function to call. Its first argument must be a
                  SingleCell dataset; it may return any value. Further
                  arguments after the SingleCell dataset can be supplied
                  via `args` and `kwargs`.
        *args: positional arguments forwarded to the function
        **kwargs: keyword arguments forwarded to the function

    Returns:
        Whatever the function returns when called on this SingleCell
        dataset.
    """
    # Reject non-callables up front, with a clear error message
    if not callable(function):
        function_type = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {function_type}')
    result = function(self, *args, **kwargs)
    return result
def pipe_X(self,
           function: Callable[[csr_array | csc_array, ...],
                              csr_array | csc_array],
           /,
           *args: Any,
           **kwargs: Any) -> SingleCell:
    """
    Transform a SingleCell dataset's `X` with a function.

    `sc = sc.pipe_X(func)` is equivalent to `sc.X = func(sc.X)`.

    `sc = sc.pipe_X(func, 1, a=2)` is equivalent to
    `sc.X = func(sc.X, 1, a=2)`.

    Args:
        function: the function to call on `X`. Its first argument must be
                  the old `X`, and it must return the new `X`. Further
                  arguments after `X` can be supplied via `args` and
                  `kwargs`.
        *args: positional arguments forwarded to the function
        **kwargs: keyword arguments forwarded to the function

    Returns:
        A new SingleCell dataset whose `X` is the function's output.
    """
    # There is nothing to pipe if `X` is absent
    if self._X is None:
        raise ValueError('X is None, so piping it is not possible')
    # Reject non-callables up front, with a clear error message
    if not callable(function):
        function_type = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {function_type}')
    new_X = function(self._X, *args, **kwargs)
    return SingleCell(X=new_X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def pipe_obs(self,
             function: Callable[[pl.DataFrame, ...], pl.DataFrame],
             /,
             *args: Any,
             **kwargs: Any) -> SingleCell:
    """
    Transform a SingleCell dataset's `obs` with a function.

    `sc = sc.pipe_obs(func)` is equivalent to `sc.obs = func(sc.obs)`.

    `sc = sc.pipe_obs(func, 1, a=2)` is equivalent to
    `sc.obs = func(sc.obs, 1, a=2)`.

    Args:
        function: the function to call on `obs`. Its first argument must be
                  the old `obs`, and it must return the new `obs`. Further
                  arguments after `obs` can be supplied via `args` and
                  `kwargs`.
        *args: positional arguments forwarded to the function
        **kwargs: keyword arguments forwarded to the function

    Returns:
        A new SingleCell dataset whose `obs` is the function's output.
    """
    # Reject non-callables up front, with a clear error message
    if not callable(function):
        function_type = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {function_type}')
    new_obs = function(self._obs, *args, **kwargs)
    return SingleCell(X=self._X, obs=new_obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def pipe_var(self,
             function: Callable[[pl.DataFrame, ...], pl.DataFrame],
             /,
             *args: Any,
             **kwargs: Any) -> SingleCell:
    """
    Transform a SingleCell dataset's `var` with a function.

    `sc = sc.pipe_var(func)` is equivalent to `sc.var = func(sc.var)`.

    `sc = sc.pipe_var(func, 1, a=2)` is equivalent to
    `sc.var = func(sc.var, 1, a=2)`.

    Args:
        function: the function to call on `var`. Its first argument must be
                  the old `var`, and it must return the new `var`. Further
                  arguments after `var` can be supplied via `args` and
                  `kwargs`.
        *args: positional arguments forwarded to the function
        **kwargs: keyword arguments forwarded to the function

    Returns:
        A new SingleCell dataset whose `var` is the function's output.
    """
    # Reject non-callables up front, with a clear error message
    if not callable(function):
        function_type = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {function_type}')
    new_var = function(self._var, *args, **kwargs)
    return SingleCell(X=self._X, obs=self._obs, var=new_var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def pipe_obsm(self,
              function: Callable[[dict[str, np.ndarray | pl.DataFrame],
                                  ...],
                                 dict[str, np.ndarray | pl.DataFrame]],
              /,
              *args: Any,
              **kwargs: Any) -> SingleCell:
    """
    Transform a SingleCell dataset's `obsm` with a function.

    `sc = sc.pipe_obsm(func)` is equivalent to `sc.obsm = func(sc.obsm)`.

    `sc = sc.pipe_obsm(func, 1, a=2)` is equivalent to
    `sc.obsm = func(sc.obsm, 1, a=2)`.

    Use `pipe_obsm_key()` instead to transform a single key of `obsm`
    rather than `obsm` as a whole.

    Args:
        function: the function to call on `obsm`. Its first argument must
                  be the old `obsm`, and it must return the new `obsm`.
                  Further arguments after `obsm` can be supplied via `args`
                  and `kwargs`.
        *args: positional arguments forwarded to the function
        **kwargs: keyword arguments forwarded to the function

    Returns:
        A new SingleCell dataset whose `obsm` is the function's output.
    """
    # Reject non-callables up front, with a clear error message
    if not callable(function):
        function_type = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {function_type}')
    new_obsm = function(self._obsm, *args, **kwargs)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=new_obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def pipe_obsm_key(self,
                  key: str,
                  function: Callable[[np.ndarray | pl.DataFrame, ...],
                                     np.ndarray | pl.DataFrame],
                  /,
                  *args: Any,
                  **kwargs: Any) -> SingleCell:
    """
    Transform one entry of a SingleCell dataset's `obsm` with a function.

    `sc = sc.pipe_obsm_key(key, func)` is equivalent to
    `sc.obsm[key] = func(sc.obsm[key])`.

    `sc = sc.pipe_obsm_key(key, func, 1, a=2)` is equivalent to
    `sc.obsm[key] = func(sc.obsm[key], 1, a=2)`.

    Use `pipe_obsm()` instead to transform `obsm` as a whole rather than a
    single key.

    Args:
        key: the key of `obsm` whose value will be transformed
        function: the function to call on `obsm[key]`. Its first argument
                  must be the old `obsm[key]`, and it must return the new
                  `obsm[key]`. Further arguments after `obsm[key]` can be
                  supplied via `args` and `kwargs`.
        *args: positional arguments forwarded to the function
        **kwargs: keyword arguments forwarded to the function

    Returns:
        A new SingleCell dataset where `obsm[key]` is the function's
        output.
    """
    # `key` must be a string naming an existing entry of `obsm`
    check_type(key, 'key', str, 'a string')
    if key not in self._obsm:
        raise ValueError(f'{key!r} is not a key in obsm')
    # Reject non-callables up front, with a clear error message
    if not callable(function):
        function_type = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {function_type}')
    new_value = function(self._obsm[key], *args, **kwargs)
    new_obsm = self._obsm | {key: new_value}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=new_obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def pipe_varm(self,
              function: Callable[[dict[str, np.ndarray | pl.DataFrame],
                                  ...],
                                 dict[str, np.ndarray | pl.DataFrame]],
              /,
              *args: Any,
              **kwargs: Any) -> SingleCell:
    """
    Transform a SingleCell dataset's `varm` with a function.

    `sc = sc.pipe_varm(func)` is equivalent to `sc.varm = func(sc.varm)`.

    `sc = sc.pipe_varm(func, 1, a=2)` is equivalent to
    `sc.varm = func(sc.varm, 1, a=2)`.

    Use `pipe_varm_key()` instead to transform a single key of `varm`
    rather than `varm` as a whole.

    Args:
        function: the function to call on `varm`. Its first argument must
                  be the old `varm`, and it must return the new `varm`.
                  Further arguments after `varm` can be supplied via `args`
                  and `kwargs`.
        *args: positional arguments forwarded to the function
        **kwargs: keyword arguments forwarded to the function

    Returns:
        A new SingleCell dataset whose `varm` is the function's output.
    """
    # Reject non-callables up front, with a clear error message
    if not callable(function):
        function_type = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {function_type}')
    new_varm = function(self._varm, *args, **kwargs)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=new_varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def pipe_varm_key(self,
                  key: str,
                  function: Callable[[np.ndarray | pl.DataFrame, ...],
                                     np.ndarray | pl.DataFrame],
                  /,
                  *args: Any,
                  **kwargs: Any) -> SingleCell:
    """
    Transform one entry of a SingleCell dataset's `varm` with a function.

    `sc = sc.pipe_varm_key(key, func)` is equivalent to
    `sc.varm[key] = func(sc.varm[key])`.

    `sc = sc.pipe_varm_key(key, func, 1, a=2)` is equivalent to
    `sc.varm[key] = func(sc.varm[key], 1, a=2)`.

    Use `pipe_varm()` instead to transform `varm` as a whole rather than a
    single key.

    Args:
        key: the key of `varm` whose value will be transformed
        function: the function to call on `varm[key]`. Its first argument
                  must be the old `varm[key]`, and it must return the new
                  `varm[key]`. Further arguments after `varm[key]` can be
                  supplied via `args` and `kwargs`.
        *args: positional arguments forwarded to the function
        **kwargs: keyword arguments forwarded to the function

    Returns:
        A new SingleCell dataset where `varm[key]` is the function's
        output.
    """
    # `key` must be a string naming an existing entry of `varm`
    check_type(key, 'key', str, 'a string')
    if key not in self._varm:
        raise ValueError(f'{key!r} is not a key in varm')
    # Reject non-callables up front, with a clear error message
    if not callable(function):
        function_type = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {function_type}')
    new_value = function(self._varm[key], *args, **kwargs)
    new_varm = self._varm | {key: new_value}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=new_varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def pipe_obsp(self,
              function: Callable[[dict[str, csr_array | csc_array], ...],
                                 dict[str, csr_array | csc_array]],
              /,
              *args: Any,
              **kwargs: Any) -> SingleCell:
    """
    Transform a SingleCell dataset's `obsp` with a function.

    `sc = sc.pipe_obsp(func)` is equivalent to `sc.obsp = func(sc.obsp)`.

    `sc = sc.pipe_obsp(func, 1, a=2)` is equivalent to
    `sc.obsp = func(sc.obsp, 1, a=2)`.

    Use `pipe_obsp_key()` instead to transform a single key of `obsp`
    rather than `obsp` as a whole.

    Args:
        function: the function to call on `obsp`. Its first argument must
                  be the old `obsp`, and it must return the new `obsp`.
                  Further arguments after `obsp` can be supplied via `args`
                  and `kwargs`.
        *args: positional arguments forwarded to the function
        **kwargs: keyword arguments forwarded to the function

    Returns:
        A new SingleCell dataset whose `obsp` is the function's output.
    """
    # Reject non-callables up front, with a clear error message
    if not callable(function):
        function_type = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {function_type}')
    new_obsp = function(self._obsp, *args, **kwargs)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=new_obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def pipe_obsp_key(self,
                  key: str,
                  function: Callable[[csr_array | csc_array, ...],
                                     csr_array | csc_array],
                  /,
                  *args: Any,
                  **kwargs: Any) -> SingleCell:
    """
    Transform one entry of a SingleCell dataset's `obsp` with a function.

    `sc = sc.pipe_obsp_key(key, func)` is equivalent to
    `sc.obsp[key] = func(sc.obsp[key])`.

    `sc = sc.pipe_obsp_key(key, func, 1, a=2)` is equivalent to
    `sc.obsp[key] = func(sc.obsp[key], 1, a=2)`.

    Use `pipe_obsp()` instead to transform `obsp` as a whole rather than a
    single key.

    Args:
        key: the key of `obsp` whose value will be transformed
        function: the function to call on `obsp[key]`. Its first argument
                  must be the old `obsp[key]`, and it must return the new
                  `obsp[key]`. Further arguments after `obsp[key]` can be
                  supplied via `args` and `kwargs`.
        *args: positional arguments forwarded to the function
        **kwargs: keyword arguments forwarded to the function

    Returns:
        A new SingleCell dataset where `obsp[key]` is the function's
        output.
    """
    # `key` must be a string naming an existing entry of `obsp`
    check_type(key, 'key', str, 'a string')
    if key not in self._obsp:
        raise ValueError(f'{key!r} is not a key in obsp')
    # Reject non-callables up front, with a clear error message
    if not callable(function):
        function_type = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {function_type}')
    new_value = function(self._obsp[key], *args, **kwargs)
    new_obsp = self._obsp | {key: new_value}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=new_obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def pipe_varp(self,
              function: Callable[[dict[str, csr_array | csc_array], ...],
                                 dict[str, csr_array | csc_array]],
              /,
              *args: Any,
              **kwargs: Any) -> SingleCell:
    """
    Transform a SingleCell dataset's `varp` with a function.

    `sc = sc.pipe_varp(func)` is equivalent to `sc.varp = func(sc.varp)`.

    `sc = sc.pipe_varp(func, 1, a=2)` is equivalent to
    `sc.varp = func(sc.varp, 1, a=2)`.

    Use `pipe_varp_key()` instead to transform a single key of `varp`
    rather than `varp` as a whole.

    Args:
        function: the function to call on `varp`. Its first argument must
                  be the old `varp`, and it must return the new `varp`.
                  Further arguments after `varp` can be supplied via `args`
                  and `kwargs`.
        *args: positional arguments forwarded to the function
        **kwargs: keyword arguments forwarded to the function

    Returns:
        A new SingleCell dataset whose `varp` is the function's output.
    """
    # Reject non-callables up front, with a clear error message
    if not callable(function):
        function_type = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {function_type}')
    new_varp = function(self._varp, *args, **kwargs)
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=new_varp, uns=self._uns,
                      num_threads=self._num_threads)
def pipe_varp_key(self,
                  key: str,
                  function: Callable[[csr_array | csc_array, ...],
                                     csr_array | csc_array],
                  /,
                  *args: Any,
                  **kwargs: Any) -> SingleCell:
    """
    Transform one entry of a SingleCell dataset's `varp` with a function.

    `sc = sc.pipe_varp_key(key, func)` is equivalent to
    `sc.varp[key] = func(sc.varp[key])`.

    `sc = sc.pipe_varp_key(key, func, 1, a=2)` is equivalent to
    `sc.varp[key] = func(sc.varp[key], 1, a=2)`.

    Use `pipe_varp()` instead to transform `varp` as a whole rather than a
    single key.

    Args:
        key: the key of `varp` whose value will be transformed
        function: the function to call on `varp[key]`. Its first argument
                  must be the old `varp[key]`, and it must return the new
                  `varp[key]`. Further arguments after `varp[key]` can be
                  supplied via `args` and `kwargs`.
        *args: positional arguments forwarded to the function
        **kwargs: keyword arguments forwarded to the function

    Returns:
        A new SingleCell dataset where `varp[key]` is the function's
        output.
    """
    # `key` must be a string naming an existing entry of `varp`
    check_type(key, 'key', str, 'a string')
    if key not in self._varp:
        raise ValueError(f'{key!r} is not a key in varp')
    # Reject non-callables up front, with a clear error message
    if not callable(function):
        function_type = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {function_type}')
    new_value = function(self._varp[key], *args, **kwargs)
    new_varp = self._varp | {key: new_value}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=new_varp, uns=self._uns,
                      num_threads=self._num_threads)
def pipe_uns(self,
             function: Callable[[UnsDict, ...], UnsDict],
             /,
             *args: Any,
             **kwargs: Any) -> SingleCell:
    """
    Pipe this dataset's `uns` through a function.

    `sc = sc.pipe_uns(func)` behaves like `sc.uns = func(sc.uns)`, and
    `sc = sc.pipe_uns(func, 1, a=2)` behaves like
    `sc.uns = func(sc.uns, 1, a=2)`.

    To transform a single key of `uns` instead of `uns` as a whole, use
    `pipe_uns_key()`.

    Args:
        function: the function to apply. It receives the current `uns` as
                  its first argument, followed by `args` and `kwargs`, and
                  must return the new `uns`.
        *args: extra positional arguments forwarded to `function`
        **kwargs: extra keyword arguments forwarded to `function`

    Returns:
        A new SingleCell dataset whose `uns` is the function's return
        value; all other fields are passed through from this dataset.

    Raises:
        TypeError: if `function` is not callable.
    """
    # Reject non-callables before touching `uns`
    if not callable(function):
        type_name = type(function).__name__
        raise TypeError(
            f'function is not callable; it has type {type_name}')

    new_uns = function(self._uns, *args, **kwargs)
    return SingleCell(X=self._X,
                      obs=self._obs,
                      var=self._var,
                      obsm=self._obsm,
                      varm=self._varm,
                      obsp=self._obsp,
                      varp=self._varp,
                      uns=new_uns,
                      num_threads=self._num_threads)
def pipe_uns_key(self,
                 key: str,
                 function: Callable[[UnsItem, ...], UnsItem],
                 /,
                 *args: Any,
                 **kwargs: Any) -> SingleCell:
    """
    Apply a function to a specific key in a SingleCell dataset's `uns`.

    `sc = sc.pipe_uns_key(key, func)` is equivalent to
    `sc.uns[key] = func(sc.uns[key])`.

    `sc = sc.pipe_uns_key(key, func, 1, a=2)` is equivalent to
    `sc.uns[key] = func(sc.uns[key], 1, a=2)`.

    To apply a function to `uns` as a whole, rather than to a specific key
    of `uns`, use `pipe_uns()`.

    Args:
        key: the key in `uns` to which the function will be applied.
        function: the function to apply to `uns[key]`. It must take the old
                  `uns[key]` as its first argument and return the new
                  `uns[key]`. The function may also take other arguments
                  after `uns[key]`, which can be specified via `args` and
                  `kwargs`.
        *args: the positional arguments to the function
        **kwargs: the keyword arguments to the function

    Returns:
        A new SingleCell dataset where the function has been applied to
        `uns[key]`.

    Raises:
        ValueError: if `key` is not a key in `uns`.
        TypeError: if `function` is not callable.
    """
    # Check that `key` is a key in `uns`
    check_type(key, 'key', str, 'a string')
    if key not in self._uns:
        error_message = f'{key!r} is not a key in uns'
        raise ValueError(error_message)

    # Check that `function` is callable
    if not callable(function):
        error_message = (
            f'function is not callable; it has type '
            f'{type(function).__name__}')
        raise TypeError(error_message)

    # Merge the transformed entry into a new mapping via `|`, rather than
    # assigning into `self._uns`
    new_uns = self._uns | {
        key: function(self._uns[key], *args, **kwargs)}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=new_uns,
                      num_threads=self._num_threads)
def qc_metrics(self,
               *,
               num_counts_column: str = 'num_counts',
               num_genes_column: str = 'num_genes',
               mito_fraction_column: str = 'mito_fraction',
               allow_float: bool = False,
               overwrite: bool = False,
               num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Adds quality-control metrics to `obs` for each cell: the sum of counts
    across all genes (`num_counts`), the number of genes with non-zero
    expression (`num_genes`), and the fraction of counts that are
    mitochondrial (`mito_fraction`).

    This function is intended to be run before `qc()` for users interested
    in better understanding the quality of their dataset. It is not a
    required step, since `qc()` calculates its own filters internally.

    Args:
        num_counts_column: the name of an integer (uint32) column to be
                           added to `obs` containing each cell's sum of
                           counts across all genes
        num_genes_column: the name of an integer (uint32) column to be
                          added to `obs` containing each cell's number of
                          genes with non-zero expression
        mito_fraction_column: the name of a floating-point (float32)
                              column to be added to `obs` containing each
                              cell's fraction of counts that are
                              mitochondrial (i.e. from genes whose names
                              start with `'MT-'`, case-insensitively)
        allow_float: if `False`, raise an error if `self.X.dtype` is
                     floating-point (suggesting the user may not be using
                     the raw counts); if `True`, disable this sanity check.
        overwrite: if `False`, raise an error if any of the new columns
                   already exist in `obs`; if `True`, overwrite them.
        num_threads: the number of threads to use when calculating the
                     quality-control metrics. Set `num_threads=-1` to use
                     all available cores, as determined by
                     `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores. Does not affect the
                     resulting SingleCell dataset's `num_threads`; this
                     will always be the same as the original dataset's
                     `num_threads`.

    Returns:
        A new SingleCell dataset with the three metrics added as columns of
        `obs`.

    Raises:
        ValueError: if `X` is `None`, if any of the new columns already
                    exists in `obs` and `overwrite=False`, or if
                    `obs_names` or `var_names` contain duplicates.
        TypeError: if `X` is floating-point and `allow_float=False`.

    Note:
        This function will give an incorrect output when run on normalized
        data, since floating-point counts will be truncated to integers.

    Note:
        This function may give an incorrect output if the count matrix
        contains explicit zeros (i.e. if `(sc.X.data == 0).any()`): this is
        not checked for, due to speed considerations. In the unlikely event
        that your dataset contains explicit zeros, remove them by running
        `sc.X.eliminate_zeros()` (an in-place operation) first.
    """
    # Check that `overwrite` is Boolean and that the three column names are
    # strings that do not clash with existing columns of `obs` (unless
    # `overwrite=True`)
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    for column_name, column in (
            ('num_counts_column', num_counts_column),
            ('num_genes_column', num_genes_column),
            ('mito_fraction_column', mito_fraction_column)):
        check_type(column, column_name, str, 'a string')
        if not overwrite and column in self._obs:
            error_message = (
                f'{column_name} {column!r} is already a column of obs; '
                f'did you already run qc_metrics()? Set overwrite=True to '
                f'overwrite.')
            raise ValueError(error_message)

    check_type(allow_float, 'allow_float', bool, 'Boolean')
    num_threads = self._process_num_threads(num_threads)

    # Check that `X` is present
    X = self._X
    if X is None:
        error_message = \
            'X is None, so calculating QC metrics is not possible'
        raise ValueError(error_message)

    # If `allow_float=False`, raise an error if `X` is floating-point
    if not allow_float and np.issubdtype(X.dtype, np.floating):
        error_message = (
            f'qc_metrics() requires raw counts but X has data type '
            f'{str(X.dtype)!r}, a floating-point data type. If you are '
            f'sure that all values are raw integer counts, i.e. that '
            f'(X.data == X.data.astype(int)).all(), then set '
            f'allow_float=True.')
        raise TypeError(error_message)

    # Check that `obs_names` and `var_names` are unique. The number of
    # duplicates is the total minus the number of distinct names, so it is
    # positive whenever the check fails.
    num_duplicates = len(self._obs) - self.obs_names.n_unique()
    if num_duplicates > 0:
        error_message = (
            f'obs_names contains {num_duplicates:,} '
            f'duplicates; deduplicate with make_obs_names_unique()')
        raise ValueError(error_message)
    num_duplicates = len(self._var) - self.var_names.n_unique()
    if num_duplicates > 0:
        error_message = (
            f'var_names contains {num_duplicates:,} '
            f'duplicates; deduplicate with make_var_names_unique()')
        raise ValueError(error_message)

    # Allocate one output slot per cell; the kernels fill these in place
    num_cells = X.shape[0]
    num_counts = np.empty(num_cells, dtype=np.uint32)
    num_genes = np.empty(num_cells, dtype=np.uint32)
    mito_fraction = np.empty(num_cells, dtype=np.float32)

    # Flag mitochondrial genes: names starting with 'MT-',
    # case-insensitively
    var_names = self.var_names
    if var_names.dtype != pl.String:
        var_names = var_names.cast(pl.String)
    mt_genes = var_names.str.to_uppercase().str.starts_with('MT-')

    # Dispatch to the kernel matching the sparse layout of `X`; anything
    # that is not a `csr_array` is treated as CSC
    compute_metrics = \
        qc_metrics_csr if isinstance(X, csr_array) else qc_metrics_csc
    compute_metrics(data=X.data, indices=X.indices, indptr=X.indptr,
                    mt_genes=mt_genes.to_numpy(), num_counts=num_counts,
                    num_genes_per_cell=num_genes,
                    mito_fraction=mito_fraction, num_threads=num_threads)

    return self.with_columns_obs(
        pl.Series(num_counts_column, num_counts),
        pl.Series(num_genes_column, num_genes),
        pl.Series(mito_fraction_column, mito_fraction))
def qc(self,
       *,
       custom_filter: SingleCellColumn | None = None,
       subset: bool = False,
       QC_column: str = 'passed_QC',
       max_mito_fraction: int | float | np.integer | np.floating |
                          None = 0.05,
       min_genes: int | np.integer | None = 100,
       nonzero_MALAT1: bool = True,
       remove_doublets: bool = False,
       batch_column: SingleCellColumn | None = None,
       doublet_fraction: float | np.floating | None = None,
       num_doublet_genes: int | np.integer = 500,
       allow_float: bool = False,
       overwrite: bool = False,
       verbose: bool = True,
       num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Adds a Boolean column to `obs` indicating which cells passed quality
    control (QC), or subsets to these cells if `subset=True`.

    By default, filters to cells with ≤5% mitochondrial reads, ≥100 genes
    detected, and non-zero MALAT1 or Malat1 expression. Can also filter out
    doublets when `remove_doublets=True`.

    Raises an error if any cell names appear more than once in `obs_names`
    (they can be deduplicated with `make_obs_names_unique()`) or any gene
    names appear more than once in `var_names` (they can be deduplicated
    with `make_var_names_unique()`).

    Args:
        custom_filter: an optional Boolean column of `obs` containing a
                       filter to apply on top of the other QC filters;
                       `True` elements will be kept. Can be a column name,
                       a polars expression, a polars Series, a 1D NumPy
                       array, or a function that takes in this SingleCell
                       dataset and returns a polars Series or 1D NumPy
                       array.
        subset: whether to subset to cells passing QC, instead of merely
                adding a `QC_column` to `obs`. This will roughly double
                memory usage, but speed up subsequent operations.
        QC_column: the name of a Boolean column to add to `obs` indicating
                   which cells passed QC, if `subset=False`. Gives an error
                   if `obs` already has a column with this name, unless
                   `overwrite=True`.
        max_mito_fraction: if not `None`, filter to cells with <= this
                           fraction of mitochondrial counts (i.e. from
                           genes whose names start with `'MT-'`,
                           case-insensitively). The default of 5% matches
                           Seurat's recommended value.
        min_genes: if not `None`, filter to cells with >= this many genes
                   detected (with non-zero count). The default of 100
                   matches Scanpy's recommended value, while Seurat
                   recommends a minimum of 200.
        nonzero_MALAT1: if `True`, filter out cells with 0 expression of
                        the nuclear-expressed lncRNA MALAT1, which [likely
                        represent](https://biorxiv.org/content/10.1101/2024.07.14.603469v1)
                        empty droplets or poor-quality cells. There must be
                        exactly one gene in `var_names` with the name
                        `'MALAT1'` or `'Malat1'` to use this filter.
        remove_doublets: if `True`, remove predicted doublets (see
                         `find_doublets()`). Doublet detection uses the
                         cxds algorithm to score each cell, then thresholds
                         this continuous score to a binary one (doublet
                         versus non-doublet) using a threshold derived from
                         simulated doublets.
        batch_column: an optional String, Enum, Categorical, or integer
                      column of `obs` indicating which batch each cell is
                      from. Can be a column name, a polars expression, a
                      polars Series, a 1D NumPy array, or a function that
                      takes in this SingleCell dataset and returns a polars
                      Series or 1D NumPy array. Only used during doublet
                      detection; doublet detection will be performed
                      separately for each batch; cells where `batch_column`
                      is `null` will collectively be treated as a single
                      batch. Set to `None` if all cells belong to the same
                      sequencing batch. Can only be specified when
                      `remove_doublets=True`.
        doublet_fraction: an optional fraction of cells (within each batch,
                          if `batch_column` is specified) to be classified
                          as doublets. If `None`, automatically detect the
                          threshold via the approach described in
                          `find_doublets()`.
        num_doublet_genes: the number of highly variable genes, i.e. genes
                           expressed in as close to 50% of cells as
                           possible, to use during doublet detection. This
                           parameter usually has a minimal influence on
                           accuracy as long as it is sufficiently large (in
                           the hundreds), so increasing it further will
                           mainly just increase runtime. If
                           `num_doublet_genes` is greater than the number
                           of genes in the dataset, all genes will be used.
        allow_float: if `False`, raise an error if `self.X.dtype` is
                     floating-point (suggesting the user may not be using
                     the raw counts); if `True`, disable this sanity check.
                     Note that all steps except mitochondrial percent
                     filtering give the same result on normalized counts,
                     so if `max_mito_fraction=None` were specified (not
                     recommended), this function would give the same result
                     on raw and normalized counts.
        overwrite: if `False`, raise an error if `uns['QCed']` is `True`
                   (indicating the dataset has already been QCed) or
                   `QC_column` is already present in `obs`; if `True`,
                   disable these two sanity checks and, when
                   `subset=False`, overwrite `QC_column` if present.
        verbose: whether to print how many cells were filtered out at each
                 step of the QC process
        num_threads: the number of threads to use when filtering based on
                     mitochondrial counts and MALAT1 expression, and for
                     doublet detection. Set `num_threads=-1` to use all
                     available cores, as determined by `os.cpu_count()`, or
                     leave unset to use `self.num_threads` cores. Does not
                     affect the QCed SingleCell dataset's `num_threads`;
                     this will always be the same as the original dataset's
                     `num_threads`.

    Returns:
        A new SingleCell dataset with `QC_column` added to `obs` (or subset
        to QCed cells if `subset=True`) and `uns['QCed']` set to `True`.

    Raises:
        ValueError: if the dataset appears to have been QCed already and
                    `overwrite=False`; if `X` is `None`; if `QC_column` is
                    non-default while `subset=True`; if `batch_column` is
                    given while `remove_doublets=False`; if `obs_names` or
                    `var_names` contain duplicates; if no mitochondrial
                    genes or no (or two) MALAT1 genes are found when the
                    corresponding filter is enabled; if the mitochondrial
                    or gene-count filter on its own keeps no cells; or if
                    no QC filter at all was requested.
        TypeError: if `X` is floating-point and `allow_float=False`.
    """
    # Check that `QC_column` is a string
    check_type(QC_column, 'QC_column', str, 'a string')

    # Check that `overwrite` is Boolean
    check_type(overwrite, 'overwrite', bool, 'Boolean')

    # If `overwrite=False`, check that `uns['QCed']=False` and that
    # `QC_column` is not present in `obs`
    if not overwrite:
        if self._uns['QCed']:
            error_message = (
                "uns['QCed'] is True; did you already run qc()? Specify "
                "overwrite=True, set uns['QCed'] = False, or run "
                "with_uns(QCed=False) to bypass this check.")
            raise ValueError(error_message)
        if QC_column in self._obs:
            error_message = (
                f'QC_column {QC_column!r} is already a column of obs; did '
                f'you already run qc()? Set overwrite=True to overwrite.')
            raise ValueError(error_message)

    # Check that `X` is present
    X = self._X
    if X is None:
        error_message = 'X is None, so QCing is not possible'
        raise ValueError(error_message)

    # Get the `custom_filter`, if specified
    if custom_filter is not None:
        custom_filter = self._get_column(
            'obs', custom_filter, 'custom_filter', pl.Boolean)

    # Check that `subset` is Boolean; if `subset=True`, check that
    # `QC_column` has the default value of `'passed_QC'`
    check_type(subset, 'subset', bool, 'Boolean')
    if subset and QC_column != 'passed_QC':
        error_message = 'QC_column can only be specified when subset=False'
        raise ValueError(error_message)

    # If `max_mito_fraction` was specified, check that it is a number
    # between 0 and 1, inclusive
    if max_mito_fraction is not None:
        check_type(max_mito_fraction, 'max_mito_fraction', (int, float),
                   'a number between 0 and 1, inclusive')
        check_bounds(max_mito_fraction, 'max_mito_fraction', 0, 1)

    # If `min_genes` was specified, check that it is a non-negative integer
    if min_genes is not None:
        check_type(min_genes, 'min_genes', int, 'a non-negative integer')
        check_bounds(min_genes, 'min_genes', 0)

    # Check that `nonzero_MALAT1` and `remove_doublets` are Boolean
    check_type(nonzero_MALAT1, 'nonzero_MALAT1', bool, 'Boolean')
    check_type(remove_doublets, 'remove_doublets', bool, 'Boolean')

    # If `batch_column` was specified, get it after checking that
    # `remove_doublets=True`
    if batch_column is not None:
        if not remove_doublets:
            error_message = (
                'batch_column must be None when remove_doublets is False, '
                'since its only use within qc() is for doublet detection')
            raise ValueError(error_message)
        batch_column = self._get_column(
            'obs', batch_column, 'batch_column',
            (pl.String, pl.Enum, pl.Categorical, 'integer'))

    # If `doublet_fraction` was specified, check that it is a number
    # between 0 and 1, exclusive
    if doublet_fraction is not None:
        check_type(doublet_fraction, 'doublet_fraction', float,
                   'a number greater than 0 and less than 1')
        check_bounds(doublet_fraction, 'doublet_fraction', 0, 1,
                     left_open=True, right_open=True)

    # Check that `num_doublet_genes` is a positive integer
    check_type(num_doublet_genes, 'num_doublet_genes', int,
               'a positive integer')
    check_bounds(num_doublet_genes, 'num_doublet_genes', 1)

    # Check that `allow_float` is Boolean
    check_type(allow_float, 'allow_float', bool, 'Boolean')

    # Check that `verbose` is Boolean
    check_type(verbose, 'verbose', bool, 'Boolean')

    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)

    # If `allow_float=False`, raise an error if `X` is floating-point
    if not allow_float and np.issubdtype(X.dtype, np.floating):
        error_message = (
            f'qc() requires raw counts but X has data type '
            f'{str(X.dtype)!r}, a floating-point data type. If you are '
            f'sure that all values are raw integer counts, i.e. that '
            f'(X.data == X.data.astype(int)).all(), then set '
            f'allow_float=True.')
        raise TypeError(error_message)

    # Check that `obs_names` and `var_names` are unique. The number of
    # duplicates is the total minus the number of distinct names, so it is
    # positive whenever the check fails.
    num_duplicates = len(self._obs) - self.obs_names.n_unique()
    if num_duplicates > 0:
        error_message = (
            f'obs_names contains {num_duplicates:,} '
            f'duplicates; deduplicate with make_obs_names_unique()')
        raise ValueError(error_message)
    num_duplicates = len(self._var) - self.var_names.n_unique()
    if num_duplicates > 0:
        error_message = (
            f'var_names contains {num_duplicates:,} '
            f'duplicates; deduplicate with make_var_names_unique()')
        raise ValueError(error_message)

    # Apply the custom filter, if specified. `mask` accumulates the AND of
    # every enabled filter and stays `None` until the first one runs.
    if verbose:
        print(f'Starting with {len(self):,} {plural("cell", len(self))}.')
    mask = None
    if custom_filter is not None:
        if verbose:
            print('Applying the custom filter...')
        mask = custom_filter
        if verbose:
            num_cells = mask.sum()
            print(
                f'{num_cells:,} '
                f'{"cell remains" if num_cells == 1 else "cells remain"} '
                f'after applying the custom filter.')

    # Filter to cells with ≤ `100 * max_mito_fraction`% mitochondrial
    # counts, if `max_mito_fraction` was specified
    if max_mito_fraction is not None:
        if verbose:
            print(f'Filtering to cells with ≤{100 * max_mito_fraction}% '
                  f'mitochondrial counts...')
        var_names = self.var_names
        if var_names.dtype != pl.String:
            var_names = var_names.cast(pl.String)
        mt_genes = var_names.str.to_uppercase().str.starts_with('MT-')
        if not mt_genes.any():
            error_message = (
                'no genes are mitochondrial (start with "MT-", '
                'case-insensitively); this may happen if your var_names '
                'are Ensembl IDs (ENSG) rather than gene symbols (in '
                'which case you should set the gene symbols as the '
                'var_names with set_var_names()), or if mitochondrial '
                'genes have already been filtered out (in which case you '
                'can set max_mito_fraction to None)')
            raise ValueError(error_message)
        mito_mask = np.empty(X.shape[0], dtype=bool)
        # Dispatch to the kernel matching the sparse layout of `X`; the
        # kernel fills `mito_mask` in place
        compute_mito_mask = \
            mito_mask_csr if isinstance(X, csr_array) else mito_mask_csc
        compute_mito_mask(data=X.data, indices=X.indices,
                          indptr=X.indptr, mt_genes=mt_genes.to_numpy(),
                          max_mito_fraction=max_mito_fraction,
                          mito_mask=mito_mask, num_threads=num_threads)
        mito_mask = pl.Series(mito_mask)
        if not mito_mask.any():
            error_message = (
                f'no cells remain after filtering to cells with '
                f'≤{100 * max_mito_fraction}% mitochondrial counts')
            raise ValueError(error_message)
        if mask is None:
            mask = mito_mask
        else:
            mask &= mito_mask
        if verbose:
            num_cells = mask.sum()
            print(
                f'{num_cells:,} '
                f'{"cell remains" if num_cells == 1 else "cells remain"} '
                f'after filtering to cells with '
                f'≤{100 * max_mito_fraction}% mitochondrial counts.')

    # Filter to cells with ≥ `min_genes` genes detected, if specified
    if min_genes is not None:
        if verbose:
            print(f'Filtering to cells with ≥{min_genes:,} '
                  f'{plural("gene", min_genes)} detected (with non-zero '
                  f'count)...')
        gene_mask = pl.Series(getnnz_at_least_threshold(
            X, axis=1, num_threads=num_threads, threshold=min_genes))
        if not gene_mask.any():
            error_message = (
                f'no cells remain after filtering to cells with '
                f'≥{min_genes:,} {plural("gene", min_genes)} detected')
            raise ValueError(error_message)
        if mask is None:
            mask = gene_mask
        else:
            mask &= gene_mask
        if verbose:
            num_cells = mask.sum()
            print(
                f'{num_cells:,} '
                f'{"cell remains" if num_cells == 1 else "cells remain"} '
                f'after filtering to cells with ≥{min_genes:,} '
                f'{plural("gene", min_genes)} detected.')

    # Filter to cells with non-zero MALAT1 expression, if
    # `nonzero_MALAT1=True`
    if nonzero_MALAT1:
        if verbose:
            print('Filtering to cells with non-zero MALAT1 expression...')
        # Row positions in `var` whose name is 'MALAT1' or 'Malat1';
        # `var_names` is unique, so there are 0, 1 or 2 matches
        MALAT1_index = self._var\
            .select(pl.arg_where(pl.col(self.var_names.name)
                                 .is_in(('MALAT1', 'Malat1'))))
        if len(MALAT1_index) == 0:
            error_message = (
                f"neither 'MALAT1' nor 'Malat1' was found in var_names; "
                f"this may happen if your var_names are Ensembl IDs "
                f"(ENSG) rather than gene symbols (in which case you "
                f"should set the gene symbols as the var_names with "
                f"set_var_names()). Alternatively, set "
                f"nonzero_MALAT1=False to disable filtering on MALAT1 "
                f"expression.")
            raise ValueError(error_message)
        if len(MALAT1_index) == 2:
            error_message = (
                "both 'MALAT1' and 'Malat1' were found in var_names; if "
                "this is intentional, rename one of them before running "
                "qc(), or set nonzero_MALAT1=False to disable filtering "
                "on MALAT1 expression")
            raise ValueError(error_message)
        MALAT1_index = MALAT1_index.item()
        if isinstance(X, csr_array):
            unsorted_warning = (
                'Warning: X does not have sorted indices, so '
                'some operations may be slower. You may want to '
                'sort indices with `sc.X.sort_indices()` (an '
                'in-place operation) as the first step after '
                'loading, though be aware that this may take a '
                'while.')
            MALAT1_mask = np.empty(X.shape[0], dtype=bool)
            if X._has_sorted_indices:
                # Known sorted: binary search (works for sorted indices,
                # whether canonical or not)
                malat1_mask_csr(indices=X.indices, indptr=X.indptr,
                                MALAT1_index=MALAT1_index,
                                MALAT1_mask=MALAT1_mask,
                                num_threads=num_threads)
            elif X._has_sorted_indices is False:
                # Known unsorted: linear scan with short-circuiting
                if verbose:
                    print(unsorted_warning)
                malat1_mask_csr_scan(indices=X.indices, indptr=X.indptr,
                                     MALAT1_index=MALAT1_index,
                                     MALAT1_mask=MALAT1_mask,
                                     num_threads=num_threads)
            else:
                # Unknown sortedness: linear scan without short-circuiting
                # checking canonicalness/sortedness along the way, and
                # cache what was learned on `X` for later calls
                has_canonical_format, has_sorted_indices = \
                    malat1_mask_csr_check(
                        indices=X.indices, indptr=X.indptr,
                        MALAT1_index=MALAT1_index,
                        MALAT1_mask=MALAT1_mask,
                        num_threads=num_threads)
                X._has_canonical_format = has_canonical_format
                X._has_sorted_indices = has_sorted_indices
                if verbose and not has_sorted_indices:
                    print(unsorted_warning)
        else:
            # CSC: the stored row indices of the MALAT1 column are exactly
            # the cells with an entry for MALAT1
            start = X.indptr[MALAT1_index]
            end = X.indptr[MALAT1_index + 1]
            MALAT1_mask = np.zeros(X.shape[0], dtype=bool)
            MALAT1_mask[X.indices[start:end]] = True
        MALAT1_mask = pl.Series(MALAT1_mask)
        if mask is None:
            mask = MALAT1_mask
        else:
            mask &= MALAT1_mask
        if verbose:
            num_cells = mask.sum()
            print(
                f'{num_cells:,} '
                f'{"cell remains" if num_cells == 1 else "cells remain"} '
                f'after filtering to cells with non-zero MALAT1 '
                f'expression.')

    # Remove predicted doublets, if `remove_doublets=True`. Exclude cells
    # that have failed earlier QC steps from being considered in the
    # doublet detection by passing `QC_column=mask`.
    if remove_doublets:
        if verbose:
            print('Removing predicted doublets...')
        singlets = pl.Series(SingleCell._find_doublets(
            X=X, batch_column=batch_column, QC_column=mask,
            doublet_fraction=doublet_fraction,
            num_genes=num_doublet_genes, return_scores=False,
            num_threads=num_threads))
        if mask is None:
            mask = singlets
        else:
            mask &= singlets
        if verbose:
            num_cells = mask.sum()
            print(
                f'{num_cells:,} '
                f'{"cell remains" if num_cells == 1 else "cells remain"} '
                f'after removing predicted doublets.')

    # Add the mask of QCed cells as a column, or subset if `subset=True`
    if mask is None:
        error_message = 'no QC filters were specified'
        raise ValueError(error_message)
    if subset:
        if verbose:
            print('Subsetting to cells passing QC (note: you can reduce '
                  'memory usage by specifying subset=False)...')
        sc = self.filter_obs(mask)
    else:
        if verbose:
            print(f'Adding a Boolean column, obs[{QC_column!r}], '
                  f'indicating which cells passed QC...')
        # No cells are removed here, so `obsp` and `varp` are carried over
        # along with every other field
        sc = SingleCell(X=X,
                        obs=self._obs.with_columns(pl.lit(mask)
                                                   .alias(QC_column)),
                        var=self._var, obsm=self._obsm, varm=self._varm,
                        obsp=self._obsp, varp=self._varp, uns=self._uns,
                        num_threads=self._num_threads)
    # NOTE(review): if the constructor keeps a reference to `self._uns`
    # rather than copying it, this also sets the flag on `self` - confirm
    sc._uns['QCed'] = True
    return sc
def skip_qc(self) -> SingleCell:
    """
    Mark the dataset as QCed without running any QC filters, so that it
    can be passed to downstream functions that require QCed data.

    Shorthand for `self.with_uns(QCed=True)`.

    Returns:
        The result of `self.with_uns(QCed=True)`, i.e. the dataset with
        `uns['QCed']` set to `True`.
    """
    flagged = self.with_uns(QCed=True)
    return flagged
@staticmethod
def _find_doublets(X: csr_array | csc_array,
                   batch_column: SingleCellColumn | None,
                   QC_column: SingleCellColumn | None,
                   doublet_fraction: float | np.floating | None,
                   num_genes: int | np.integer,
                   return_scores: bool,
                   num_threads: int | np.integer | None) -> \
        np.ndarray[np.dtype[np.bool_]] | \
        tuple[np.ndarray[np.dtype[np.bool_]],
              np.ndarray[np.dtype[np.float32]]]:
    """
    Find doublets using cxds (co-expression-based doublet scoring;
    academic.oup.com/bioinformatics/article/36/4/1150/5566507). Used by
    `qc()` (when `remove_doublets=True`) and `find_doublets()`.

    Args:
        X: the count matrix; may be either normalized or unnormalized
        batch_column: an optional String, Enum, Categorical, or integer
                      column of `obs` indicating which batch each cell is
                      from. Can be a column name, a polars expression, a
                      polars Series, a 1D NumPy array, or a function that
                      takes in this SingleCell dataset and returns a polars
                      Series or 1D NumPy array. Doublet detection will be
                      performed separately for each batch; cells where
                      `batch_column` is `null` will collectively be treated
                      as a single batch. Set to `None` if all cells belong
                      to the same sequencing batch.
                      NOTE(review): this helper calls `.name` and
                      `.to_frame()` on `batch_column` directly, so callers
                      appear to need to pass an already-resolved polars
                      Series here.
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored
                   and have their doublet labels and doublet scores set to
                   `null`.
        doublet_fraction: an optional fraction of cells (within each batch,
                          if `batch_column` is specified) to be classified
                          as doublets. If `None`, automatically detect the
                          threshold via the approach described in
                          `find_doublets()`.
        num_genes: the number of highly variable genes, i.e. genes
                   expressed in as close to 50% of cells as possible, to
                   use during doublet detection. This parameter usually has
                   a minimal influence on accuracy as long as it is
                   sufficiently large (in the hundreds), so increasing it
                   further will mainly just increase runtime. If
                   `num_genes` is greater than the number of genes in the
                   dataset, all genes will be used.
        return_scores: whether to return the per-cell doublet scores along
                       with the binary doublet calls
        num_threads: the number of threads to use when finding doublets.
                     Set `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores.

    Returns:
        A NumPy array with the binary doublet calls (`True` for singlets,
        `False` for doublets), or a tuple of two arrays with the doublet
        calls and doublet scores when `return_scores=True`. For cells where
        `QC_column` is `False`, doublet calls will be `False` (though could
        equivalently be uninitialized, if not for this Polars bug:
        https://github.com/pola-rs/polars/issues/24296) and doublet scores
        will be uninitialized.
    """
    is_csr = isinstance(X, csr_array)

    # If `batch_column` was specified, get the row indices of each batch,
    # ignoring cells failing QC when `QC_column` is present in `obs`. If
    # no batches were specified but `QC_column` is present, use `QC_column`
    # as the batch labels.
    if QC_column is not None and batch_column is None:
        batch_column = QC_column
    if batch_column is not None:
        batch_column_name = batch_column.name
        indexed_batches = batch_column \
            .to_frame() \
            .lazy() \
            .with_columns(_SingleCell_batch_indices=pl.int_range(
                pl.len(), dtype=pl.UInt32))
        if QC_column is not None:
            indexed_batches = indexed_batches.filter(QC_column)
        batches = indexed_batches \
            .group_by(batch_column_name, maintain_order=True) \
            .agg('_SingleCell_batch_indices') \
            .select('_SingleCell_batch_indices') \
            .collect() \
            .to_series()
    else:
        # If neither `batch_column` nor `QC_column` were specified, use a
        # single dummy batch (a 1-tuple holding one empty index Series)
        batches = pl.Series([], dtype=pl.UInt32),

    # Preallocate
    total_num_cells = X.shape[0]
    if num_threads == 1:
        singlets = np.zeros(total_num_cells, dtype=bool)
    else:
        singlets = numa_zeros(total_num_cells, dtype=bool)
    if not return_scores:
        doublet_scores = np.array([], dtype=np.float32)
    elif num_threads == 1:
        doublet_scores = np.empty(total_num_cells, dtype=np.float32)
    else:
        doublet_scores = numa_zeros(total_num_cells, dtype=np.float32)
    if batch_column is not None:
        # `max()` of an empty Series is `None` (e.g. when no cell passes
        # QC); fall back to 0 so the buffer allocations below do not fail
        max_cells = batches.list.len().max() or 0
    else:
        max_cells = total_num_cells
    if doublet_fraction is not None:
        doublet_indices = np.empty(max_cells, dtype=np.uint32)
    else:
        # Set `doublet_fraction` to `-1` if `None`, so it can be passed to
        # Cython as a float
        doublet_fraction = -1
        doublet_indices = np.array([], dtype=np.uint32)
    num_total_genes = X.shape[1]
    all_detection_counts = np.empty(num_total_genes, dtype=np.uint32)
    detection_counts = np.empty(num_genes, dtype=np.uint32)
    ps = np.empty(num_genes, dtype=np.float32)
    hvgs = np.empty(num_genes, dtype=np.uint32)
    distances = np.empty(num_genes, dtype=np.float32)
    obs_buffer = np.empty(num_genes * num_genes, dtype=np.uint32)
    S_buffer = np.empty(num_genes * num_genes, dtype=np.float32)
    cxds_scores_buffer = np.empty(max_cells, dtype=np.float32)
    sim_indptr_buffer = np.empty(max_cells + 1, dtype=X.indptr.dtype)
    cxds_scores_sim_buffer = np.empty(max_cells, dtype=np.float32)

    original_num_threads = X._num_threads
    try:
        X._num_threads = num_threads

        # For each batch...
        for batch_index, batch_indices in enumerate(batches):
            # Subset to cells in this batch
            if batch_column is None:
                X_batch = X
            else:
                if len(batch_indices) == 1:
                    continue
                X_batch = X[batch_indices]

            # Get the detection count of each gene
            getnnz(X_batch, axis=0, num_threads=num_threads,
                   output=all_detection_counts)

            # Normalize `detection_counts` by `num_cells` to get the
            # detection rate `p`. Subset to the `num_genes` genes with
            # detection rates closest to 50%. Exclude genes with detection
            # rates of 0% or 100%.
            num_cells = X_batch.shape[0]
            batch_num_genes = get_hvgs(
                all_detection_counts=all_detection_counts,
                detection_counts=detection_counts, ps=ps, hvgs=hvgs,
                distances=distances, num_cells=num_cells,
                num_total_genes=num_total_genes, num_genes=num_genes)

            # Take per-batch views of the scratch buffers under separate
            # names. Rebinding `detection_counts`/`ps` themselves would
            # permanently shrink the buffers, so a later batch with more
            # usable genes would hand `get_hvgs()` undersized outputs.
            batch_detection_counts = detection_counts[:batch_num_genes]
            batch_ps = ps[:batch_num_genes]
            X_batch = X_batch[:, hvgs[:batch_num_genes]]

            # Convert `X_batch` to CSR, if CSC
            if not is_csr:
                X_batch = X_batch.tocsr()

            # Sort indices, if not already sorted (necessary for
            # `simulate_doublets`)
            if not X.has_sorted_indices:
                X_batch.sort_indices()

            # Get `obs`, where `obs[i, j]` is the number of cells that
            # express exactly one of genes `i` and `j`
            obs = obs_buffer[:batch_num_genes * batch_num_genes]\
                .reshape((batch_num_genes, batch_num_genes))
            compute_obs(detection_counts=batch_detection_counts,
                        indices=X_batch.indices, indptr=X_batch.indptr,
                        obs=obs, num_threads=num_threads)

            # Get `S`, the upper-tail log binomial p-values of `obs`
            S = S_buffer[:batch_num_genes * batch_num_genes]\
                .reshape((batch_num_genes, batch_num_genes))
            compute_S(obs=obs, ps=batch_ps, num_cells=num_cells, S=S,
                      num_threads=num_threads)

            # Calculate each cell's cxds score: the sum of `-S[i, j]`
            # across all gene pairs `i` and `j` that are both expressed by
            # the cell
            cxds_scores = cxds_scores_buffer[:num_cells]
            compute_cxds(indices=X_batch.indices, indptr=X_batch.indptr,
                         S=S, cxds_scores=cxds_scores,
                         num_threads=num_threads)

            # Now simulate doublets within the batch and compute their cxds
            # scores, using the original `S` matrix derived from the real
            # data. Conservatively allocate twice as much memory for the
            # indices as the original indices, since in the worst-case
            # scenario none of the indices will match up and all coinflips
            # will be 1.
            sim_indptr = sim_indptr_buffer[:num_cells + 1]
            sim_indices = np.empty(2 * len(X_batch.indices),
                                   dtype=X_batch.indices.dtype)
            simulate_doublets(data=X_batch.data, indices=X_batch.indices,
                              indptr=X_batch.indptr,
                              sim_indices=sim_indices,
                              sim_indptr=sim_indptr, num_cells=num_cells,
                              seed=batch_index)
            sim_indices = sim_indices[:sim_indptr[-1]]
            cxds_scores_sim = cxds_scores_sim_buffer[:num_cells]
            compute_cxds(indices=sim_indices, indptr=sim_indptr, S=S,
                         cxds_scores=cxds_scores_sim,
                         num_threads=num_threads)

            # Call doublets. If using multiple batches, map doublet labels
            # (and scores, if `return_scores`) back to the full dataset.
            call_doublets(cxds_scores=cxds_scores,
                          median_cxds_score_sim=np.median(cxds_scores_sim)
                          if doublet_fraction == -1 else 0,
                          doublet_indices=doublet_indices,
                          batch_indices=batch_indices.to_numpy(),
                          singlets=singlets,
                          doublet_scores=doublet_scores,
                          doublet_fraction=doublet_fraction,
                          num_threads=num_threads)
    finally:
        # Always restore the matrix's original thread setting
        X._num_threads = original_num_threads

    if return_scores:
        # With a single dummy batch, the per-batch scores already cover
        # every cell, so return them directly
        if batch_column is None:
            doublet_scores = cxds_scores
        return singlets, doublet_scores
    else:
        return singlets
def find_doublets(self,
                  *,
                  batch_column: SingleCellColumn | None,
                  QC_column: SingleCellColumn | None = 'passed_QC',
                  doublet_fraction: float | np.floating | None = None,
                  num_genes: int | np.integer = 500,
                  doublet_column: str = 'doublet',
                  doublet_score_column: str | None = 'doublet_score',
                  overwrite: bool = False,
                  num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Find doublets using cxds ([co-expression-based doublet scoring]
    (academic.oup.com/bioinformatics/article/36/4/1150/5566507)).

    The standard way to filter out doublets is by specifying
    `remove_doublets=True` in `qc()`. If you did that, do not use this
    function! This function should only be used in the unusual scenario
    where you want to find doublets without running any other
    quality-control steps.

    This function gives the same result regardless of whether it is run
    before or after normalization. The actual expression value does not
    matter, only whether or not it is zero.

    Doublets cannot occur across sequencing batches, so make sure to
    specify `batch_column` if your dataset has multiple batches! Doublet
    detection will be done independently within each batch.

    Since the cxds score is continuous, it needs to be converted into a
    binary classification of doublets versus non-doublets. This problem
    can be framed as finding a cxds score threshold above which a cell is
    deemed to be a doublet. To determine this threshold, we simulate
    doublets by combining the counts from randomly selected pairs of cells,
    via the following steps:

    1) Sample as many random pairs of cells (with replacement) as there are
       real cells.
    2) Combine the counts from each pair of cells into a simulated doublet.
       Because cxds operates on binarized count matrices, we average the
       two cells' count matrices in a binary sense: if a gene is expressed
       in either cell, it is deemed to be expressed in the simulated
       doublet, but if it has a count of 1 in one cell and 0 in the other,
       it is randomly chosen to be either expressed or not expressed with
       equal probability (since the average count would be 0.5).
    3) Calculate cxds scores for these simulated doublets, based on the
       coexpression patterns (the `S` matrix from cxds) learned from the
       real data.
    4) Take the median cxds score of the simulated doublets as the
       threshold. In other words, if a real cell has a higher doublet score
       than the average simulated doublet, we call it a doublet.

    Alternatively, specify `doublet_fraction` to force a specific fraction
    of cells to be classified as doublets.

    Args:
        batch_column: an optional String, Enum, Categorical, or integer
                      column of `obs` indicating which batch each cell is
                      from. Can be a column name, a polars expression, a
                      polars Series, a 1D NumPy array, or a function that
                      takes in this SingleCell dataset and returns a polars
                      Series or 1D NumPy array. Doublet detection will be
                      performed separately for each batch; cells where
                      `batch_column` is `null` will collectively be treated
                      as a single batch. Set to `None` if all cells belong
                      to the same sequencing batch.
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored
                   and have their doublet labels and doublet scores set to
                   `null`.
        doublet_fraction: an optional fraction of cells (within each batch,
                          if `batch_column` is specified) to be classified
                          as doublets. If `None`, automatically detect the
                          threshold via the approach described above.
        num_genes: the number of highly variable genes, i.e. genes
                   expressed in as close to 50% of cells as possible, to
                   use during doublet detection. This parameter usually has
                   a minimal influence on accuracy as long as it is
                   sufficiently large (in the hundreds), so increasing it
                   further will mainly just increase runtime. If
                   `num_genes` is greater than the number of genes in the
                   dataset, all genes will be used.
        doublet_column: the name of a Boolean column to be added to `obs`
                        containing the doublet labels, i.e. whether each
                        cell is predicted to be a doublet
        doublet_score_column: the name of a column to be added to `obs`
                              containing each cell's doublet score. Higher
                              scores indicate greater likelihood of being a
                              doublet. Scores are not normalized and are
                              not comparable across datasets or batches,
                              but are guaranteed to be positive (since they
                              are sums of negative log p-values). Set
                              `doublet_score_column=None` to not return
                              doublet scores, for a slight memory reduction
                              and speed increase.
        overwrite: if `True`, overwrite `doublet_column` and/or
                   `doublet_score_column` if already present in `obs`,
                   instead of raising an error
        num_threads: the number of threads to use when finding doublets.
                     Set `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores.

    Returns:
        A new SingleCell dataset where `obs` contains two additional
        columns, `doublet_column` (default: `'doublet'`), indicating
        whether each cell is predicted to be a doublet, and (unless
        `doublet_score_column=None`) `doublet_score_column` (default:
        `'doublet_score'`), containing each cell's doublet score.

    Raises:
        ValueError: if `doublet_column` or `doublet_score_column` is
                    already a column of `obs` and `overwrite=False`, if
                    `X` is `None`, or if `uns['QCed']` is `False`.

    Note:
        This function's cxds scores are almost exactly half the original
        implementation's, because it avoids double-counting the two genes
        in each gene pair. Slight deviations from this one-half (usually by
        less than one part in a million) may occur because this function
        uses a normal approximation to the binomial p-value to avoid long
        runtimes on large datasets.

    Note:
        This function may give an incorrect output if the count matrix
        contains explicit zeros (i.e. if `(sc.X.data == 0).any()`): this is
        not checked for, due to speed considerations. In the unlikely event
        that your dataset contains explicit zeros, remove them by running
        `sc.X.eliminate_zeros()` (an in-place operation) first.
    """
    # Check that `overwrite` is Boolean
    check_type(overwrite, 'overwrite', bool, 'Boolean')

    # Check that `doublet_column` and (if not `None`)
    # `doublet_score_column` are strings and, unless `overwrite=True`, not
    # already in `obs`
    check_type(doublet_column, 'doublet_column', str, 'a string')
    if not overwrite and doublet_column in self._obs:
        error_message = (
            f'doublet_column {doublet_column!r} is already a column of '
            f'obs; did you already run find_doublets()? Set '
            f'overwrite=True to overwrite.')
        raise ValueError(error_message)
    # Scores are only computed and returned when a score column name is
    # given
    return_scores = doublet_score_column is not None
    if return_scores:
        check_type(doublet_score_column, 'doublet_score_column', str,
                   'a string')
        if not overwrite and doublet_score_column in self._obs:
            error_message = (
                f'doublet_score_column {doublet_score_column!r} is '
                f'already a column of obs; did you already run '
                f'find_doublets()? Set overwrite=True to overwrite.')
            raise ValueError(error_message)

    # Check that `X` is present
    if self._X is None:
        error_message = 'X is None, so doublet finding is not possible'
        raise ValueError(error_message)

    # Check that `self` is QCed
    if not self._uns['QCed']:
        error_message = (
            "uns['QCed'] is False; did you forget to run qc() before "
            "find_doublets()? Set uns['QCed'] = True or run skip_qc() to "
            "bypass this check.")
        raise ValueError(error_message)

    # Get `QC_column` and `batch_column`, if not `None`
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')
    if batch_column is not None:
        batch_column = self._get_column(
            'obs', batch_column, 'batch_column',
            (pl.String, pl.Enum, pl.Categorical, 'integer'),
            QC_column=QC_column)

    # Check that `doublet_fraction`, if specified, is > 0 and < 1
    if doublet_fraction is not None:
        check_type(doublet_fraction, 'doublet_fraction', float,
                   'a number greater than 0 and less than 1')
        check_bounds(doublet_fraction, 'doublet_fraction', 0, 1,
                     left_open=True, right_open=True)

    # Check that `num_genes` is a positive integer
    check_type(num_genes, 'num_genes', int, 'a positive integer')
    check_bounds(num_genes, 'num_genes', 1)

    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)

    # Run doublet detection. The result is a single array of singlet calls
    # when `return_scores=False`, or a (calls, scores) tuple otherwise.
    singlets = SingleCell._find_doublets(
        X=self._X, batch_column=batch_column, QC_column=QC_column,
        doublet_fraction=doublet_fraction, num_genes=num_genes,
        return_scores=return_scores, num_threads=num_threads)

    # Convert doublet labels (and scores, if `doublet_score_column` is not
    # `None`) to polars Series to add to `obs`. If `QC_column` exists, set
    # doublet labels (and scores) to `null` for cells failing QC.
    if not return_scores:
        # `_find_doublets()` marks singlets as `True`, so negate to get
        # the doublet labels
        doublet_column = pl.Series(doublet_column, ~singlets)
        if QC_column is None:
            return self.with_columns_obs(doublet_column)
        else:
            return self.with_columns_obs(
                pl.when(QC_column).then(doublet_column))
    else:
        singlets, doublet_scores = singlets
        doublet_column = pl.Series(doublet_column, ~singlets)
        doublet_score_column = \
            pl.Series(doublet_score_column, doublet_scores)
        if QC_column is None:
            return self.with_columns_obs(doublet_column,
                                         doublet_score_column)
        else:
            return self.with_columns_obs(
                pl.when(QC_column).then(doublet_column),
                pl.when(QC_column).then(doublet_score_column))
def make_obs_names_unique(self, *, separator: str = '-') -> SingleCell:
    """
    Make `obs_names` unique by appending `'-1'` to the second occurrence of
    a given name, `'-2'` to the third occurrence, and so on, where `'-'`
    can be switched to a different string via the `separator` argument.

    Raises an error if any `obs_names` already contain `separator`.

    Args:
        separator: the string connecting the original name and the integer
                   suffix. It is matched literally (not as a regular
                   expression) when checking existing names.

    Returns:
        A new SingleCell dataset with the `obs_names` made unique.

    Raises:
        ValueError: if any of the `obs_names` already contain `separator`.
    """
    check_type(separator, 'separator', str, 'a string')
    # For non-String `obs_names`, only the distinct categories need to be
    # checked for the separator
    unique_obs_names = self.obs_names \
        if self.obs_names.dtype == pl.String else \
        self.obs_names.cat.get_categories()
    # `literal=True` is essential: by default `str.contains()` interprets
    # its pattern as a regex, so separators like '.', '|' or '+' would
    # match (or fail to match) the wrong names, or fail to compile
    if unique_obs_names.str.contains(separator, literal=True).any():
        error_message = (
            f'some obs_names already contain the separator {separator!r}; '
            f'did you already run make_obs_names_unique()? If not, set '
            f'the separator argument to a different string.')
        raise ValueError(error_message)
    obs_names = pl.col(self.obs_names.name)
    # 0 for the first occurrence of each name, 1 for the second, etc.
    num_times_seen = pl.int_range(pl.len(), dtype=pl.Int32).over(obs_names)
    return SingleCell(X=self._X,
                      obs=self._obs.with_columns(
                          pl.when(num_times_seen > 0)
                          .then(obs_names + separator +
                                num_times_seen.cast(str))
                          .otherwise(obs_names)),
                      var=self._var, obsm=self._obsm, varm=self._varm,
                      obsp=self._obsp, varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def make_var_names_unique(self, *, separator: str = '-') -> SingleCell:
    """
    Make `var_names` unique by appending `'-1'` to the second occurrence of
    a given name, `'-2'` to the third occurrence, and so on, where `'-'`
    can be switched to a different string via the `separator` argument.

    Raises an error if any `var_names` already contain `separator`.

    Args:
        separator: the string connecting the original name and the integer
                   suffix. It is matched literally (not as a regular
                   expression) when checking existing names.

    Returns:
        A new SingleCell dataset with the `var_names` made unique.

    Raises:
        ValueError: if any of the `var_names` already contain `separator`.
    """
    check_type(separator, 'separator', str, 'a string')
    # For non-String `var_names`, only the distinct categories need to be
    # checked for the separator
    unique_var_names = self.var_names \
        if self.var_names.dtype == pl.String else \
        self.var_names.cat.get_categories()
    # `literal=True` is essential: by default `str.contains()` interprets
    # its pattern as a regex, so separators like '.', '|' or '+' would
    # match (or fail to match) the wrong names, or fail to compile
    if unique_var_names.str.contains(separator, literal=True).any():
        error_message = (
            f'some var_names already contain the separator {separator!r}; '
            f'did you already run make_var_names_unique()? If not, set '
            f'the separator argument to a different string.')
        raise ValueError(error_message)
    var_names = pl.col(self.var_names.name)
    # 0 for the first occurrence of each name, 1 for the second, etc.
    num_times_seen = pl.int_range(pl.len(), dtype=pl.Int32).over(var_names)
    return SingleCell(X=self._X, obs=self._obs,
                      var=self._var.with_columns(
                          pl.when(num_times_seen > 0)
                          .then(var_names + separator +
                                num_times_seen.cast(str))
                          .otherwise(var_names)),
                      obsm=self._obsm, varm=self._varm, obsp=self._obsp,
                      varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def get_sample_covariates(self,
                          *,
                          ID_column: SingleCellColumn,
                          QC_column: SingleCellColumn | None =
                          'passed_QC') -> pl.DataFrame:
    """
    Collect the sample-level covariates of this dataset: those columns of
    `obs` whose value does not vary among the cells of any one sample.

    Args:
        ID_column: a String, Enum, Categorical, or integer column of `obs`
                   containing sample IDs. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array.
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored.

    Returns:
        A DataFrame with one row per sample, sorted by `ID_column`, which
        is the first column; the remaining columns are the sample-level
        covariates.
    """
    # Resolve the QC column first, since it is needed to resolve the ID
    # column
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')
    ID_column = self._get_column(
        'obs', ID_column, 'ID_column',
        (pl.String, pl.Enum, pl.Categorical, 'integer'),
        QC_column=QC_column)
    ID_column_name = ID_column.name

    # Attach the sample IDs to `obs` and keep only cells passing QC
    cells = self._obs.with_columns(ID_column)
    if QC_column is not None:
        cells = cells.filter(QC_column)

    # Every column other than the sample ID is a candidate covariate
    if ID_column_name in cells:
        candidates = pl.exclude(ID_column_name)
    else:
        candidates = pl.all()

    # Per sample, record each candidate's first value plus its minimum and
    # maximum; a candidate is constant within a sample when the two agree
    per_sample = cells.group_by(ID_column_name).agg(
        candidates.first(),
        candidates.min().name.prefix('_SingleCell_min_'),
        candidates.max().name.prefix('_SingleCell_max_'))
    min_equals_max = per_sample.select(
        pl.col('^_SingleCell_min_.*$') == pl.col('^_SingleCell_max_.*$'))
    min_equals_max = min_equals_max.rename(
        lambda c: c.removeprefix('_SingleCell_min_'))

    # Keep only the candidates that are constant within every sample
    constant_columns = \
        filter_columns(min_equals_max, pl.all().all()).columns
    return per_sample\
        .select(ID_column_name, *constant_columns)\
        .sort(ID_column_name)
[docs] def pseudobulk(self, ID_column: SingleCellColumn, cell_type_column: SingleCellColumn, /, *, QC_column: SingleCellColumn | None = 'passed_QC', cell_types: str | Iterable[str] | int | Iterable[int] | None = None, excluded_cell_types: str | Iterable[str] | int | Iterable[int] | None = None, additional_obs: pl.DataFrame | None = None, include_nulls: bool = False, sort_genes: bool = False, num_threads: int | np.integer | None = None, verbose: bool = True) -> Pseudobulk: """ Pseudobulk a SingleCell dataset with sample ID and cell type columns. Operates on raw counts, so cannot be run after `normalize()`. Must be run after `qc()`. Counts from cells with the same pair of values in `ID_column` and `cell_type_column` will be summed to a single value. Cells with `null` in either column are excluded, unless `include_nulls=True`. You can run this function multiple times at different cell type resolutions by setting a different `cell_type_column` each time, then combining the results afterwards with the `|` operator (assuming none of the cell types overlap between the two resolutions): ``` pb_broad = sc.pseudobulk('ID', 'broad_cell_type') pb_fine = sc.pseudobulk('ID', 'fine_grained_cell_type') pb = pb_broad | pb_fine ``` Args: ID_column: a String, Enum, Categorical, or integer column of `obs` containing sample IDs. Can be a column name, a polars expression, a polars Series, a 1D NumPy array, or a function that takes in this SingleCell dataset and returns a polars Series or 1D NumPy array. cell_type_column: a String, Enum, Categorical, or integer column of `obs` containing cell-type labels. Can be a column name, a polars expression, a polars Series, a 1D NumPy array, or a function that takes in this SingleCell dataset and returns a polars Series or 1D NumPy array. If cell_type_column is an integer column, the cell types in the Pseudobulk dataset will be coerced to strings. QC_column: an optional Boolean column of `obs` indicating which cells passed QC. 
Can be a column name, a polars expression, a polars Series, a 1D NumPy array, or a function that takes in this SingleCell dataset and returns a polars Series or 1D NumPy array. Set to `None` to include all cells. Cells failing QC will be excluded from the pseudobulk. cell_types: one or more cell types to pseudobulk; by default, pseudobulks all cell types in `cell_type_column`. Specifying `cell_types` is exactly equivalent to filtering the result to these cell types, but will be faster when there are many cell types and pseudobulks are only desired for a few of them. Can also be used to change the order in which cell types appear in the resulting Pseudobulk dataset, even if pseuodbulking all cell types. Mutually exclusive with `excluded_cell_types` and `include_nulls`. excluded_cell_types: one or more cell types to exclude from pseudobulking. Mutually exclusive with `cell_types` and `include_nulls`. additional_obs: an optional DataFrame of additional sample-level covariates, which will be joined to the pseudobulk's `obs` for each cell type include_nulls: whether to exclude cells with `null` values in `ID_column` and/or `cell_type_column` from the pseudobulk. If `include_nulls=True`, `null` will be treated just like any other value. This means that, for instance, all cells from a given cell type that have `null` as the sample ID will be pseudobulked together, as will all cells from a given sample ID that have `null` as the cell type. Mutually exclusive with `cell_types` and `excluded_cell_types`. sort_genes: whether to sort genes in alphabetical order in the pseudobulk; by default, genes appear in the same order as in the SingleCell dataset num_threads: the number of threads to use when pseudobulking; parallelism happens across {sample, cell type} pairs (or just samples, if `cell_type_column` is `None`). Set `num_threads=-1` to use all available cores, as determined by `os.cpu_count()`, or leave unset to use `self.num_threads` cores. 
For count matrices stored in the usual CSR format, parallelization takes place across cell types and samples, so specifying more threads than the number of cell type-sample pairs will not provide additional speedup. Does not affect the Pseudobulk dataset's `num_threads`; this will always be the same as the SingleCell dataset's `num_threads`. verbose: whether to print the number of cells excluded when `include_nulls=False` (and neither `cell_types` nor `excluded_cell_types` are specified) Returns: A Pseudobulk dataset with `X` (the pseudobulked counts), `obs` (metadata per sample), and `var` (metadata per gene) fields, each of which are dictionaries across cell types. The columns of each cell type's `obs` will be: - `ID_column` - `'num_cells'` (the number of cells for that sample and cell type) followed by whichever columns of the SingleCell dataset's `obs` are constant across samples. `var` will be identical to the SingleCell dataset's `var`. Note: This function may give an incorrect output if the count matrix contains negative values: this is not checked for, due to speed considerations. """ # Check that `X` is present X = self._X if X is None: error_message = 'X is None, so pseudobulking is not possible' raise ValueError(error_message) # Check that `self` is QCed and not normalized if not self._uns['QCed']: error_message = ( "uns['QCed'] is False; did you forget to run qc() before " "pseudobulk()? 
Set uns['QCed'] = True or run skip_qc() to " "bypass this check.") raise ValueError(error_message) if self._uns['normalized']: error_message = ( "uns['normalized'] is True; did you already run normalize()?") raise ValueError(error_message) # Get the QC, ID, and cell-type columns if QC_column is not None: QC_column = self._get_column( 'obs', QC_column, 'QC_column', pl.Boolean, allow_missing=QC_column == 'passed_QC') original_ID_column = ID_column ID_column = self._get_column('obs', ID_column, 'ID_column', (pl.String, pl.Enum, pl.Categorical, 'integer'), QC_column=QC_column, allow_null=True) ID_column_name = ID_column.name cell_type_column = \ self._get_column('obs', cell_type_column, 'cell_type_column', (pl.String, pl.Enum, pl.Categorical, 'integer'), QC_column=QC_column, allow_null=True) cell_type_column_name = cell_type_column.name num_cells_column_name = 'num_cells' for column_description, column_name in ('ID_column', ID_column_name), \ ('cell_type_column', cell_type_column_name): if column_name == num_cells_column_name: error_message = ( f'{column_description} has the name ' f'{num_cells_column_name!r}, which conflicts with the ' f'name of the column to be added to the Pseudobulk ' f'dataset containing the number of cells of each cell ' f'type') raise ValueError(error_message) # Check that `include_nulls`, `sort_genes`, and `verbose` are Boolean check_type(include_nulls, 'include_nulls', bool, 'Boolean') check_type(sort_genes, 'sort_genes', bool, 'Boolean') check_type(verbose, 'verbose', bool, 'Boolean') # Check that `include_nulls=True` is not specified alongside either # `cell_types` or `excluded_cell_types` if include_nulls: if cell_types is not None: error_message = \ 'cell_types cannot be specified when include_nulls=True' raise ValueError(error_message) if excluded_cell_types is not None: error_message = ( 'excluded_cell_types cannot be specified when ' 'include_nulls=True') raise ValueError(error_message) # Check that `cell_types` and 
`excluded_cell_types` are not both # specified. If `cell_types` is specified, check it contains only cell # type names present in `cell_type_column`, then set non-matching cell # types to `null` so that they are treated as a single background cell # type. If `excluded_cell_types` is specified, do the opposite. cell_type_column = SingleCell._process_cell_types( cell_types, excluded_cell_types, cell_type_column) # Check that `num_threads` is a positive integer, -1 or `None`; if # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()` num_threads = self._process_num_threads(num_threads) if additional_obs is not None: check_type(additional_obs, 'additional_obs', pl.DataFrame, 'a polars DataFrame') if ID_column_name not in additional_obs: ID_column_description = SingleCell._describe_column( 'ID_column', original_ID_column) error_message = ( f'{ID_column_description} is not a column of ' f'additional_obs') raise ValueError(error_message) if ID_column.dtype != additional_obs[ID_column_name].dtype: ID_column_description = SingleCell._describe_column( 'ID_column', original_ID_column) error_message = ( f"{ID_column_description} has a different data type in " f"additional_obs than in this SingleCell dataset's obs") raise TypeError(error_message) # Check that the first column of `var` is String, Enum, Categorical, # or integer: this is a requirement of the Pseudobulk class. (The first # column of `obs` must be as well, but this will always be true by # construction, since it will always be the sample ID.) 
dtype = self.var_names.dtype if dtype not in (pl.String, pl.Enum, pl.Categorical) and \ dtype not in pl.INTEGER_DTYPES: error_message = ( f'the first column of var (var_names) must be String, Enum, ' f'Categorical, or integer, but has data type {dtype!r}') raise ValueError(error_message) # Get the row indices that will be pseudobulked across for each group # (cell type-sample pair), ignoring cells failing QC when `QC_column` # is present in `obs` groups = (pl.LazyFrame((cell_type_column, ID_column)) .with_columns( _SingleCell_group_indices=pl.int_range(pl.len(), dtype=pl.UInt32)) if QC_column is None else pl.LazyFrame((cell_type_column, ID_column, QC_column)) .with_columns( _SingleCell_group_indices=pl.int_range(pl.len(), dtype=pl.UInt32)) .filter(QC_column.name))\ .group_by(cell_type_column_name, ID_column_name)\ .agg('_SingleCell_group_indices', pl.len().alias(num_cells_column_name))\ .sort(cell_type_column_name, ID_column_name)\ .collect() # Exclude cells with null values in `ID_column` and/or # `cell_type_column`, if `include_nulls=False`. When `cell_types` or # `excluded_cell_types` are specified, this includes cells from cell # types that are not in `cell_types` or in `excluded_cell_types`, # because of the processing we did above in # `SingleCell._process_cell_types()`. 
if not include_nulls: if cell_types is not None or excluded_cell_types is not None: groups = groups.drop_nulls() else: if verbose: excluded = groups\ .filter(pl.any_horizontal(pl.col( cell_type_column_name, ID_column_name).is_null())) num_ID_null_only = excluded\ .filter(pl.col(ID_column_name).is_null(), pl.col(cell_type_column_name).is_not_null())\ [num_cells_column_name]\ .sum() num_cell_type_null_only = excluded\ .filter(pl.col(ID_column_name).is_not_null(), pl.col(cell_type_column_name).is_null())\ [num_cells_column_name]\ .sum() num_excluded = excluded[num_cells_column_name].sum() num_both_null = num_excluded - num_ID_null_only - \ num_cell_type_null_only if num_excluded > 0: print(f'Excluding {num_excluded:,} ' f'{plural("cell", num_excluded)} when ' f'pseudobulking: {num_both_null:,} with nulls ' f'in both the ID ({ID_column_name!r}) and ' f'cell-type ({cell_type_column_name!r}) ' f'columns, {num_ID_null_only:,} with nulls in ' f'just the ID column, and ' f'{num_cell_type_null_only:,} with nulls in ' f'just the cell-type column.') groups = groups.drop_nulls() else: groups = groups.drop_nulls() if len(groups) == 0: error_message = ( f'no cells remain after excluding cells with nulls in ' f'the the ID and/or cell-type columns') raise ValueError(error_message) # Pseudobulk, storing the result in a preallocated NumPy array result = np.empty((len(groups), X.shape[1]), dtype=np.uint32) if isinstance(X, csr_array): group_indices = \ groups['_SingleCell_group_indices'].explode().to_numpy() group_ends = groups[num_cells_column_name].cum_sum().to_numpy() groupby_sum_csr(data=X.data, indices=X.indices, indptr=X.indptr, group_indices=group_indices, group_ends=group_ends, result=result, num_threads=num_threads) else: group_map = pl.int_range(X.shape[0], dtype=pl.UInt32, eager=True)\ .to_frame('_SingleCell_group_indices')\ .join(groups .select('_SingleCell_group_indices', _SingleCell_index=pl.int_range(pl.len(), dtype=pl.Int32)) .explode('_SingleCell_group_indices'), 
on='_SingleCell_group_indices', how='left')\ ['_SingleCell_index'] has_missing = group_map.null_count() > 0 if has_missing: group_map = group_map.fill_null(-1) group_map = group_map.to_numpy() groupby_sum_csc(data=X.data, indices=X.indices, indptr=X.indptr, group_map=group_map, has_missing=has_missing, result=result, num_threads=num_threads) # Sort genes, if `sort_genes=True` cell_type_var = self._var if sort_genes: result = result[:, cell_type_var[:, 0].arg_sort().to_numpy()] cell_type_var = cell_type_var.sort(cell_type_var.columns[0]) # Get sample covariates obs = self._obs.with_columns(ID_column) if QC_column is not None: obs = obs.filter(QC_column) other_columns = \ pl.exclude(ID_column_name) if ID_column_name in obs else pl.all() aggregated = obs.group_by(ID_column_name).agg( other_columns.first(), other_columns.min().name.prefix('_SingleCell_min_'), other_columns.max().name.prefix('_SingleCell_max_')) constant_columns = aggregated\ .select(pl.col('^_SingleCell_min_.*$') == pl.col('^_SingleCell_max_.*$'))\ .rename(lambda c: c.removeprefix('_SingleCell_min_'))\ .pipe(filter_columns, pl.all().all())\ .columns sample_covariates = aggregated\ .select(ID_column_name, *constant_columns)\ .sort(ID_column_name) # Break up the results by cell type X, obs, var = {}, {}, {} start_index = 0 for cell_type, count in groups[cell_type_column_name]\ .value_counts().sort(cell_type_column_name).iter_rows(): cell_type = str(cell_type) end_index = start_index + count X[cell_type] = result[start_index:end_index] obs[cell_type] = groups.lazy()\ .select(ID_column_name, num_cells_column_name)\ .slice(start_index, count)\ .join(sample_covariates.lazy(), on=ID_column_name, how='left')\ .pipe(lambda df: df.join(additional_obs.lazy(), on=ID_column_name, how='left') if additional_obs is not None else df)\ .pipe(lambda df: df if QC_column is None else df.drop(QC_column.name))\ .collect() var[cell_type] = cell_type_var start_index = end_index # Propagate the SingleCell dataset's 
`num_threads` to the Pseudobulk # dataset return Pseudobulk(X=X, obs=obs, var=var, num_threads=self._num_threads)
def hvg(self,
        *others: SingleCell,
        QC_column: SingleCellColumn | None |
                   Sequence[SingleCellColumn | None] = 'passed_QC',
        batch_column: SingleCellColumn | None |
                      Sequence[SingleCellColumn | None] = None,
        num_genes: int | np.integer = 2000,
        min_cells: int | np.integer = 3,
        exclude: str | Iterable[str] | None = None,
        flavor: Literal['seurat_v3', 'seurat_v3_paper'] = 'seurat_v3',
        span: int | float | np.integer | np.floating = 0.3,
        hvg_column: str = 'highly_variable',
        rank_column: str = 'highly_variable_rank',
        overwrite: bool = False,
        verbose: bool = True,
        num_threads: int | np.integer | None = None) -> \
        SingleCell | tuple[SingleCell, ...]:
    """
    Select highly variable genes using the same approach as Seurat.
    Operates on raw counts, so must be run before `normalize()` (but after
    `qc()`). When run with multiple datasets, only considers genes present
    in every dataset.

    By default, uses the same approach as Seurat's `FindVariableFeatures()`
    function, and Scanpy's `scanpy.pp.highly_variable_genes()` function
    with the `flavor` argument set to the non-default value `'seurat_v3'`.

    The general idea is that since genes with higher mean expression tend
    to have higher variance in expression (because they have more non-zero
    values), we want to select genes that have a high variance *relative
    to their mean expression*. Otherwise, we'd only be picking highly
    expressed genes! To correct for the mean-variance relationship, fit a
    LOESS curve to the mean-variance trend.

    Args:
        others: optional SingleCell datasets to jointly compute highly
                variable genes across, alongside this one. Each dataset
                will be treated as a separate batch. If `batch_column` is
                not `None`, each dataset AND each distinct value of
                `batch_column` within each dataset will be treated as a
                separate batch. Variances will be computed per batch and
                then aggregated (see `flavor`) across batches.
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to
                   `None` to include all cells. Cells failing QC will be
                   ignored. When `others` is specified, `QC_column` can be
                   a length-`1 + len(others)` sequence of columns,
                   expressions, Series, functions, or `None` for each
                   dataset (for `self`, followed by each dataset in
                   `others`).
        batch_column: an optional String, Enum, Categorical, or integer
                      column of `obs` indicating which batch each cell is
                      from. Can be a column name, a polars expression, a
                      polars Series, a 1D NumPy array, or a function that
                      takes in this SingleCell dataset and returns a
                      polars Series or 1D NumPy array. Each batch will be
                      treated as if it were a distinct dataset; this is
                      exactly equivalent to splitting the dataset with
                      `split_by(batch_column)` and then passing each of
                      the resulting datasets to `hvg()`, except that the
                      `min_cells` filter will always be calculated
                      per-dataset rather than per-batch. Variances will be
                      computed per batch and then aggregated (see
                      `flavor`) across batches. Set to `None` to treat
                      each dataset as having a single batch. When `others`
                      is specified, `batch_column` can be a
                      length-`1 + len(others)` sequence of columns,
                      expressions, Series, functions, or `None` for each
                      dataset (for `self`, followed by each dataset in
                      `others`).
        num_genes: the number of highly variable genes to select. The
                   default of 2000 matches Seurat and Scanpy's recommended
                   value. Fewer than `num_genes` genes will be selected if
                   not enough genes have non-zero count in >= `min_cells`
                   cells (or when `min_cells` is 0, if not enough genes
                   are present).
        min_cells: if non-zero, filter to genes detected (with non-zero
                   count) in >= this many cells in every dataset, before
                   calculating highly variable genes. The default value of
                   3 matches Seurat and Scanpy's recommended value. Note
                   that genes with zero variance in any dataset will
                   always be filtered out, even if `min_cells` is 0.
        exclude: one or more optional case-insensitive regular expressions
                 matching genes to exclude from the highly variable gene
                 calculation. For instance, to exclude mitochondrial genes
                 (starting with `'MT-'`) and ribosomal genes (starting
                 with `'RPL'`, `'RPS'`, `'MRPL'`, or `'MRPS'`), specify
                 `exclude=('^MT-', '^RPL', '^RPS', '^MRPL', '^MRPS')`.
        flavor: the highly variable gene algorithm to use. Must be one of
                `seurat_v3` and `seurat_v3_paper`, both of which match the
                algorithms with the same name in Scanpy. Both algorithms
                select genes based on two criteria: 1) which genes are
                ranked as most variable (taking the median of the ranks
                across batches where the gene is among the top `num_genes`
                highly variable genes) and 2) the number of batches in
                which a gene is ranked in among the top `num_genes` in
                variability. `seurat_v3` ranks genes by 1) and uses 2) to
                tiebreak, whereas `seurat_v3_paper` ranks genes by 2) and
                uses 1) to tiebreak. When there is only one batch, both
                algorithms are the same and only rank based on 1).
        span: the span of the LOESS fit; higher values will lead to more
              smoothing
        hvg_column: the name of a Boolean column to be added to (each
                    dataset's) `var` indicating the highly variable genes
        rank_column: the name of an integer column to be added to (each
                     dataset's) `var` with the rank of each highly
                     variable gene's variance (1 = highest variance, 2 =
                     next-highest, etc.); will be `null` for non-highly
                     variable genes. In the very unlikely event of ties,
                     the gene that appears first in `var` will get the
                     lowest rank.
        overwrite: if `True`, overwrite `hvg_column` and/or `rank_column`
                   if already present in `var`, instead of raising an
                   error
        verbose: whether to print the number of genes present in every
                 dataset, when jointly computing highly variable genes
                 across multiple datasets
        num_threads: the number of threads to use when finding highly
                     variable genes. Set `num_threads=-1` to use all
                     available cores, as determined by `os.cpu_count()`,
                     or leave unset to use `self.num_threads` cores. Does
                     not affect the `num_threads` of the returned
                     SingleCell dataset(s); this will always be the same
                     as the `num_threads` of the original dataset(s).

    Returns:
        A new SingleCell dataset where `var` contains an additional
        Boolean column, `hvg_column` (default: `'highly_variable'`),
        indicating the `num_genes` most highly variable genes, and
        `rank_column` (default: 'highly_variable_rank') indicating the
        (one-based) rank of each highly variable gene's variance, with
        `null` values for non-highly variable genes. Or, if additional
        SingleCell dataset(s) are specified via the `others` argument, a
        length-`1 + len(others)` tuple of SingleCell datasets with these
        two columns added: `self`, followed by each dataset in `others`.

    Raises:
        ValueError: if any dataset has no `X`, is not QCed, is already
                    normalized, has fewer than 3 cells or fewer than
                    `min_cells` cells; if `hvg_column` or `rank_column`
                    already exists and `overwrite=False`; if `flavor` is
                    invalid; if no genes are shared by every dataset; or
                    if the LOESS fit fails.

    Note:
        This function may give an incorrect output if the count matrix
        contains explicit zeros (i.e. if `(sc.X.data == 0).any()`): this
        is not checked for, due to speed considerations. In the unlikely
        event that your dataset contains explicit zeros, remove them by
        running `sc.X.eliminate_zeros()` (an in-place operation) first.

    Note:
        This function may give an incorrect output if the count matrix
        contains negative values: this is not checked for, due to speed
        considerations.

    Note:
        This function may not give identical results to Seurat and Scanpy.
        It avoids floating-point summation, which is more numerically
        stable than Scanpy and Seurat's calculations. If multiple genes
        are tied as the `num_genes`-th most highly variable gene in a
        batch or dataset, this function includes all of them, whereas
        Seurat and Scanpy arbitrarily pick one (or a subset) of them.
        Also, this function uses the ordering from a stable sort to break
        ties when selecting the final list of highly variable genes,
        instead of the unstable sort used by Seurat and Scanpy.
    """
    # If `others` was specified, check that all elements of `others` are
    # SingleCell datasets
    if others:
        check_types(others, 'others', SingleCell, 'SingleCell datasets')
    datasets = [self] + list(others)

    # Suffix appended to error messages when several datasets are being
    # processed jointly. This must be defined before the first check that
    # uses it; otherwise those checks raise NameError instead of the
    # intended ValueError.
    suffix = ' for at least one dataset' if others else ''

    # Check that `X` is present for every dataset
    if any(dataset._X is None for dataset in datasets):
        error_message = (
            f'X is None{suffix}, so highly variable gene finding is not '
            f'possible')
        raise ValueError(error_message)

    # Check that all datasets are QCed and not normalized
    if not all(dataset._uns['QCed'] for dataset in datasets):
        error_message = (
            f"uns['QCed'] is False{suffix}; did you forget to run qc() "
            f"before hvg()? Set uns['QCed'] = True or run skip_qc() to "
            f"bypass this check.")
        raise ValueError(error_message)
    if any(dataset._uns['normalized'] for dataset in datasets):
        error_message = (
            f"hvg() requires raw counts but uns['normalized'] is "
            f"True{suffix}; did you already run normalize()?")
        raise ValueError(error_message)

    # Check that there are at least 3 cells in each dataset (since LOESS
    # seems to need at least three observations to converge)
    if any(len(dataset._obs) < 3 for dataset in datasets):
        error_message = (
            f'there are fewer than 3 cells{suffix}, so highly variable '
            f'genes cannot be calculated')
        raise ValueError(error_message)

    # Check that `overwrite` is Boolean
    check_type(overwrite, 'overwrite', bool, 'Boolean')

    # Check that `hvg_column` and `rank_column` are strings and, unless
    # `overwrite=True`, not already in `var` for any dataset
    for column, column_name in (hvg_column, 'hvg_column'), \
            (rank_column, 'rank_column'):
        check_type(column, column_name, str, 'a string')
        if not overwrite and \
                any(column in dataset._var for dataset in datasets):
            error_message = (
                f'{column_name} {column!r} is already a column of '
                f'var{suffix}; did you already run hvg()? Set '
                f'overwrite=True to overwrite.')
            raise ValueError(error_message)

    # Check that `num_genes` is a positive integer
    check_type(num_genes, 'num_genes', int, 'a positive integer')
    check_bounds(num_genes, 'num_genes', 1)

    # Check that `min_cells` is a non-negative integer and no larger than
    # the number of cells in each dataset
    check_type(min_cells, 'min_cells', int, 'a non-negative integer')
    check_bounds(min_cells, 'min_cells', 0)
    fewest_cells = min(len(dataset._obs) for dataset in datasets)
    if fewest_cells < min_cells:
        # Report the smallest dataset (not necessarily `self`), since that
        # is the one violating the constraint
        error_message = (
            f'the number of cells ({fewest_cells:,}) is less than '
            f'min_cells ({min_cells:,}){suffix}; decrease min_cells')
        raise ValueError(error_message)

    # Check that `exclude`, if specified, is a string or sequence thereof
    if exclude is not None:
        exclude = to_tuple_checked(exclude, 'exclude', str, 'strings')

    # Check that `flavor` is 'seurat_v3' or 'seurat_v3_paper'
    check_type(flavor, 'flavor', str, 'a string')
    if flavor not in ('seurat_v3', 'seurat_v3_paper'):
        error_message = (
            f"flavor must be 'seurat_v3' or 'seurat_v3_paper', "
            f"not {flavor!r}")
        raise ValueError(error_message)

    # Check that `span` is a positive number
    check_type(span, 'span', (int, float), 'a positive number')
    check_bounds(span, 'span', 0, left_open=True)

    # Check that `verbose` is Boolean
    check_type(verbose, 'verbose', bool, 'Boolean')

    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)

    # Get `QC_column` and `batch_column` from every dataset, if not `None`
    QC_columns = SingleCell._get_columns(
        'obs', datasets, QC_column, 'QC_column', pl.Boolean,
        allow_missing=True)
    batch_columns = SingleCell._get_columns(
        'obs', datasets, batch_column, 'batch_column',
        (pl.String, pl.Enum, pl.Categorical, 'integer'),
        QC_columns=QC_columns)

    # Get the universe of genes we'll be considering: those present in all
    # datasets, and not matching `exclude` (if specified). If there are
    # multiple datasets, also get the indices of these genes in each
    # dataset. Raise an error if no genes are found in all datasets.
    if others:
        # The use of `align_frames` here is a bit wasteful memory-wise,
        # because it creates an identical `'gene'` column for every
        # DataFrame in `genes_and_indices`. Fortunately, it's only one
        # small string column per dataset.
        genes_and_indices = pl.align_frames(*(
            dataset.var[:, 0]
            .to_frame('gene')
            .pipe(lambda df: df.filter(~pl.col.gene.str.contains(
                '(?i)' + '|'.join(exclude)))  # case-insensitive
                  if exclude is not None else df)
            .with_columns(_SingleCell_index=pl.int_range(pl.len(),
                                                         dtype=pl.UInt32))
            for dataset in datasets), on='gene', how='inner')
        genes_in_all_datasets = genes_and_indices[0]['gene']
        num_genes_in_all_datasets = len(genes_in_all_datasets)
        if num_genes_in_all_datasets == 0:
            error_message = 'no genes are present in every dataset'
            if exclude is not None:
                error_message += (
                    f', after applying the '
                    f'{plural("filter", len(exclude))} specified via the '
                    f'exclude argument')
            raise ValueError(error_message)
        if verbose:
            print(f'{num_genes_in_all_datasets:,} '
                  f'{plural("gene", num_genes_in_all_datasets)} are '
                  f'present in every dataset.')
        dataset_gene_indices = [df['_SingleCell_index']
                                for df in genes_and_indices]
        del genes_and_indices
    else:
        genes_in_all_datasets = self.var_names\
            .rename('gene')\
            .to_frame()\
            .pipe(lambda df: df.filter(~pl.col.gene.str.contains(
                '(?i)' + '|'.join(exclude)))  # case-insensitive
                  if exclude is not None else df)\
            .to_series()

    # Get the batches to calculate variance across (datasets + batches
    # within each dataset). For CSR matrices with `batch_column`,
    # pre-compute per-batch cell indices in a single pass via polars
    # `partition_by`. For CSR, `cell_mask` is always `None` or an int64
    # array of cell indices; for CSC, `cell_mask` is always `None` or a
    # Boolean array.
    if others:
        if batch_column is None:
            batches = [(dataset._X,
                        None if dataset_QC_column is None else
                        np.flatnonzero(dataset_QC_column.to_numpy())
                        if isinstance(dataset._X, csr_array) else
                        dataset_QC_column.to_numpy(),
                        gene_indices)
                       for dataset, dataset_QC_column, gene_indices in
                       zip(datasets, QC_columns, dataset_gene_indices)]
        else:
            batches = []
            for dataset, dataset_QC_column, dataset_batch_column, \
                    gene_indices in zip(datasets, QC_columns,
                                        batch_columns,
                                        dataset_gene_indices):
                if dataset_batch_column is None:
                    batches.append((
                        dataset._X,
                        None if dataset_QC_column is None else
                        np.flatnonzero(dataset_QC_column.to_numpy())
                        if isinstance(dataset._X, csr_array) else
                        dataset_QC_column.to_numpy(),
                        gene_indices))
                elif isinstance(dataset._X, csr_array):
                    for partition in (
                            dataset_batch_column
                            .to_frame('_SingleCell_batch')
                            .with_columns(
                                _SingleCell_idx=pl.int_range(pl.len()))
                            .pipe(lambda df: df.filter(dataset_QC_column)
                                  if dataset_QC_column is not None else df)
                            .partition_by('_SingleCell_batch')):
                        batches.append((
                            dataset._X,
                            partition['_SingleCell_idx'].to_numpy(),
                            gene_indices))
                else:
                    for batch in dataset_batch_column.unique():
                        batches.append((
                            dataset._X,
                            (dataset_batch_column.eq(batch)
                             if dataset_QC_column is None else
                             dataset_batch_column.eq(batch) &
                             dataset_QC_column).to_numpy(),
                            gene_indices))
    else:
        X = self._X
        batch_column = batch_columns[0]
        if batch_column is None:
            if QC_column is not None and QC_columns[0] is not None:
                if isinstance(X, csr_array):
                    batches = (X, np.flatnonzero(
                        QC_columns[0].to_numpy()), None),
                else:
                    batches = (X, QC_columns[0].to_numpy(), None),
            else:
                batches = (X, None, None),
        elif isinstance(X, csr_array):
            df = batch_column\
                .to_frame('_SingleCell_batch')\
                .with_columns(_SingleCell_idx=pl.int_range(pl.len()))
            if QC_column is not None and QC_columns[0] is not None:
                df = df.filter(QC_columns[0])
            batches = [
                (X, partition['_SingleCell_idx'].to_numpy(), None)
                for partition in df.partition_by(
                    '_SingleCell_batch')]
        else:
            if QC_column is not None and QC_columns[0] is not None:
                batches = ((X, (batch_column.eq(batch) & QC_columns[0])
                            .to_numpy(), None)
                           for batch in batch_column.unique())
            else:
                batches = ((X, batch_column.eq(batch).to_numpy(), None)
                           for batch in batch_column.unique())

    # Import the LOESS implementation only once it is needed, so that
    # argument validation above does not depend on it
    from skmisc.loess import loess

    # Get the variance of each gene in each batch across cells passing QC
    norm_gene_vars = []
    for X, cell_mask, gene_indices in batches:
        num_dataset_genes = X.shape[1]
        mean = np.empty(num_dataset_genes, dtype=np.float32)
        var = np.empty(num_dataset_genes, dtype=np.float32)
        nonzero_count = np.empty(num_dataset_genes, dtype=np.uint32)
        is_csr = isinstance(X, csr_array)
        if is_csr:
            # For CSR, `cell_mask` is either `None` (all cells) or a
            # pre-computed int64 array of cell indices
            if cell_mask is None:
                cell_indices = np.array([], dtype=np.int64)
                num_cells = X.shape[0]
            else:
                cell_indices = cell_mask
                num_cells = len(cell_indices)
            gene_mean_and_variance_csr(
                data=X.data, indices=X.indices, indptr=X.indptr,
                cell_indices=cell_indices, num_cells=num_cells,
                num_dataset_genes=num_dataset_genes, mean=mean, var=var,
                nonzero_count=nonzero_count, num_threads=num_threads)
        else:
            if cell_mask is None:
                cell_mask = np.array([], dtype=bool)
                num_cells = X.shape[0]
            else:
                num_cells = cell_mask.sum()
            gene_mean_and_variance_csc(
                data=X.data, indices=X.indices, indptr=X.indptr,
                cell_mask=cell_mask, num_cells=num_cells,
                num_dataset_genes=num_dataset_genes, mean=mean, var=var,
                nonzero_count=nonzero_count, num_threads=num_threads)

        # Fit the mean-variance trend on log10 scale, using only genes
        # with non-zero variance
        not_constant = var > 0
        y = np.log10(var[not_constant])
        x = np.log10(mean[not_constant])
        model = loess(x, y, span=span)
        try:
            model.fit()
        except ValueError as e:
            # Use the resolved per-dataset batch columns rather than the
            # raw `batch_column` argument, which may be a column name,
            # expression or sequence when `others` is specified
            any_small_batches = any(
                dataset_batch_column is not None and
                dataset_batch_column.value_counts()['count'].min() < 500
                for dataset_batch_column in batch_columns)
            error_message = (
                f'LOESS model fitting failed; this usually only happens '
                f'when there are very few cells (e.g. under 500), which '
                f'{"is" if num_cells < 500 else "is not"} the case here, '
                f'or when batch_column is specified and certain batches '
                f'have very few cells (e.g. under 500), which '
                f'{"is" if any_small_batches else "is not"} the case here')
            raise ValueError(error_message) from e
        estimated_variance = np.empty(num_dataset_genes, dtype=np.float32)
        estimated_variance[not_constant] = model.outputs.fitted_values
        estimated_variance[~not_constant] = 0
        estimated_stddev = np.sqrt(10 ** estimated_variance)
        clip_val = mean + estimated_stddev * np.sqrt(num_cells,
                                                     dtype=np.float32)
        batch_counts_sum = np.empty(num_dataset_genes, dtype=np.float32)
        squared_batch_counts_sum = np.empty(num_dataset_genes,
                                            dtype=np.float32)
        if is_csr:
            clipped_sum_csr(
                data=X.data, indices=X.indices, indptr=X.indptr,
                num_cells=num_cells, num_dataset_genes=num_dataset_genes,
                cell_indices=cell_indices, clip_val=clip_val,
                batch_counts_sum=batch_counts_sum,
                squared_batch_counts_sum=squared_batch_counts_sum,
                num_threads=num_threads)
        else:
            clipped_sum_csc(
                data=X.data, indices=X.indices, indptr=X.indptr,
                cell_mask=cell_mask, clip_val=clip_val,
                batch_counts_sum=batch_counts_sum,
                squared_batch_counts_sum=squared_batch_counts_sum,
                num_threads=num_threads)
        norm_gene_var = pl.Series(
            (1 / ((num_cells - 1) * np.square(estimated_stddev))) *
            ((num_cells * np.square(mean)) + squared_batch_counts_sum -
             2 * batch_counts_sum * mean))

        # If `min_cells` is non-zero, set variances to `null` for genes
        # with a non-zero count less than `min_cells`
        if min_cells:
            norm_gene_var = norm_gene_var\
                .set(pl.Series(nonzero_count < min_cells), None)

        # If there are multiple datasets, `norm_gene_var` is currently with
        # respect to the genes in `dataset.var_names`; map back to the
        # genes in `genes_in_all_datasets`
        if others:
            norm_gene_var = norm_gene_var[gene_indices]
        norm_gene_vars.append(norm_gene_var)

    # Rank genes within each batch, then combine the per-batch ranks
    # according to `flavor`
    rank = pl.exclude('gene').rank('min', descending=True)
    final_rank = pl.struct(
        ('median_rank', 'nbatches') if flavor == 'seurat_v3' else
        ('nbatches', 'median_rank')).rank('ordinal', descending=True)

    # Note: the expression for `median_rank` can be replaced by
    # `pl.median_horizontal(pl.exclude('gene'))` once polars implements it
    hvgs = pl.DataFrame([genes_in_all_datasets] + norm_gene_vars)\
        .lazy()\
        .pipe(lambda df: df.drop_nulls(pl.selectors.exclude('gene'))
              if min_cells or others else df)\
        .with_columns(pl.when(rank <= num_genes).then(rank))\
        .with_columns(nbatches=pl.sum_horizontal(pl.exclude('gene')
                                                 .is_not_null()),
                      median_rank=-pl.concat_list(pl.exclude('gene'))
                      .explode()
                      .median().over(pl.int_range(
                          pl.len(), dtype=pl.UInt32)))\
        .select('gene', (final_rank <= num_genes).alias(hvg_column),
                pl.when(final_rank <= num_genes).then(final_rank)
                .alias(rank_column))\
        .collect()

    # Return a new SingleCell dataset (or a tuple of datasets, if others
    # is non-empty) containing the highly variable genes
    for dataset_index, dataset in enumerate(datasets):
        new_var = dataset._var\
            .pipe(lambda df: df.drop(hvg_column, rank_column, strict=False)
                  if overwrite else df)\
            .join(hvgs.rename({'gene': dataset.var_names.name}),
                  on=dataset.var_names.name, how='left')\
            .with_columns(pl.col(hvg_column).fill_null(False))
        datasets[dataset_index] = \
            SingleCell(X=dataset._X, obs=dataset._obs, var=new_var,
                       obsm=dataset._obsm, varm=dataset._varm,
                       uns=dataset._uns,
                       num_threads=dataset._num_threads)
    return tuple(datasets) if others else datasets[0]
def normalize(self,
              *,
              QC_column: SingleCellColumn | None = 'passed_QC',
              method: Literal['PFlog1pPF', 'log1pPF',
                              'logCP10k'] = 'log1pPF',
              inplace: bool = False,
              num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Normalize this SingleCell dataset's counts. Must be run after `hvg()`
    and before `PCA()`.

    `normalize()` supports three normalization methods. All three methods
    normalize each cell independently of the rest and log-transform the
    counts in some way, but differ in the details.

    The simplest approach, `method='logCP10k'`, computes the log of the
    counts per 10 thousand: `normalized_count = log(count / 10000 + 1)`.
    It matches the default settings of Seurat's `NormalizeData()`
    function, aside from differences in floating-point error. This method
    is not recommended because it implicitly assumes an unrealistically
    large amount of overdispersion, and performs worse in the benchmarks
    of the papers discussed below.

    The next-simplest approach, `method='log1pPF'`, is the default.
    Instead of using the same denominator of 10 thousand for every cell,
    it uses `X.sum(axis=1) / X.sum(axis=1).mean()` as the denominator. In
    other words, a cell's denominator is the cell's library size, relative
    to the mean library size across all cells. (By library size, we mean
    the sum of a cell's counts across all genes.) This approach of
    dividing each cell's counts by its relative library size is sometimes
    called "proportional fitting" (PF).
    [Ahlmann-Eltze and Huber 2023](https://nature.com/articles/s41592-023-01814-1)
    recommend using this method instead of normalizing by a fixed
    denominator, like `method='logCP10k'` does. Scanpy's
    `normalize_total()` uses this method with a slight variation: it uses
    median instead of mean to define the relative library size.

    The most complex approach, `method='PFlog1pPF'`, takes the output of
    `log1pPF` and applies an additional round of proportional fitting
    after the log-transformation.
    [Booeshaghi et al. 2022](https://biorxiv.org/content/10.1101/2022.05.06.490859v1.full)
    recommend this approach, arguing that log1pPF does not fully normalize
    for read depth, because the log transform partially undoes the first
    round of proportional fitting.

    Args:
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to
                   `None` to include all cells. Cells failing QC will
                   still be normalized, but will not count towards the
                   calculation of the mean total count across cells when
                   `method` is `'PFlog1pPF'` or `'log1pPF'`. Has no effect
                   when `method` is `'logCP10k'`.
        method: the normalization method to use (see above)
        inplace: whether to do in-place normalization. This reduces memory
                 usage, but is only possible for float32 count matrices
                 and will raise an error if the count matrix has any other
                 data type.
        num_threads: the number of threads to use when normalizing. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores. Does not affect the
                     normalized SingleCell dataset's `num_threads`; this
                     will always be the same as the original dataset's
                     `num_threads`.

    Returns:
        A new SingleCell dataset with the normalized counts, and
        `uns['normalized']` set to `True`. Or when `inplace=True`, return
        the original dataset with the counts normalized in-place and
        `uns['normalized']` set to `True`.

    Raises:
        ValueError: if `X` is `None`, if `uns['QCed']` is `False`, if
                    `uns['normalized']` is already `True`, or if `method`
                    is not one of the three valid methods.
        TypeError: if `inplace=True` and `X` is not float32.
    """
    # Check that `X` is present
    if self._X is None:
        error_message = 'X is None, so normalizing is not possible'
        raise ValueError(error_message)

    # Check that `self` is QCed and not already normalized
    if not self._uns['QCed']:
        error_message = (
            "uns['QCed'] is False; did you forget to run qc() (and "
            "possibly hvg()) before normalize()? Set uns['QCed'] = True "
            "or run skip_qc() to bypass this check.")
        raise ValueError(error_message)
    if self._uns['normalized']:
        error_message = \
            "uns['normalized'] is True; did you already run normalize()?"
        raise ValueError(error_message)

    # Get the QC column, if not `None`. Only tolerate a missing column
    # when the caller left `QC_column` at its default name; the
    # `isinstance` guard keeps `allow_missing` a plain bool when
    # `QC_column` is a Series, array or expression (for which `==` is not
    # a scalar comparison).
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=isinstance(QC_column, str) and
                          QC_column == 'passed_QC')

    # Check that `method` is one of the three valid methods, and map it to
    # the integer code expected by the normalization kernels
    if method == 'PFlog1pPF':
        method_number = 2
    elif method == 'log1pPF':
        method_number = 1
    else:
        if method != 'logCP10k':
            error_message = (
                "method must be one of 'PFlog1pPF', 'log1pPF', or "
                "'logCP10k'")
            raise ValueError(error_message)
        method_number = 0

    # Check that `inplace` is Boolean
    check_type(inplace, 'inplace', bool, 'Boolean')

    # When `inplace=True`, raise an error if the count matrix is not
    # float32
    if inplace and self._X.dtype != np.float32:
        error_message = (
            f'inplace normalization requires float32 count matrices, but '
            f'X has data type {str(self._X.dtype)!r}')
        raise TypeError(error_message)

    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)

    # Pick the kernel and sparse array class matching the format of `X`
    X = self._X
    if isinstance(X, csr_array):
        normalize_kernel = normalize_csr
        sparse_array = csr_array
    else:
        normalize_kernel = normalize_csc
        sparse_array = csc_array

    # Scratch buffer for the per-cell totals
    num_cells = X.shape[0]
    if num_threads == 1:
        row_sums = np.empty(num_cells, dtype=np.uint64)
    else:
        row_sums = numa_zeros(num_cells, dtype=np.uint64)

    # An empty mask tells the kernel that no QC column was given
    QC_mask = QC_column.to_numpy() if QC_column is not None else \
        np.array([], dtype=bool)

    if inplace:
        # Write the normalized values directly over the counts
        normalize_kernel(data=X.data, indices=X.indices, indptr=X.indptr,
                         QC_column=QC_mask, normalized_data=X.data,
                         row_sums=row_sums, num_cells=num_cells,
                         method_number=method_number,
                         num_threads=num_threads)
        # Record that the counts are now normalized, so that downstream
        # steps accept this dataset and a second normalize() is rejected
        self._uns['normalized'] = True
        return self

    # Otherwise, write the normalized values to a new buffer and wrap it
    # in a new sparse array sharing the original's `indices` and `indptr`
    if num_threads == 1:
        normalized_data = np.empty(len(X.data), dtype=np.float32)
    else:
        normalized_data = numa_zeros(len(X.data), dtype=np.float32)
    original_num_threads = X._num_threads
    normalize_kernel(data=X.data, indices=X.indices, indptr=X.indptr,
                     QC_column=QC_mask, normalized_data=normalized_data,
                     row_sums=row_sums, num_cells=num_cells,
                     method_number=method_number,
                     num_threads=num_threads)
    X = sparse_array((normalized_data, X.indices, X.indptr),
                     shape=X.shape)
    X._num_threads = original_num_threads
    sc = SingleCell(X=X, obs=self._obs, var=self._var, obsm=self._obsm,
                    varm=self._varm, uns=self._uns,
                    num_threads=self._num_threads)
    sc._uns['normalized'] = True
    return sc
def pca(self,
        *others: SingleCell,
        QC_column: SingleCellColumn | None |
                   Sequence[SingleCellColumn | None] = 'passed_QC',
        hvg_column: SingleCellColumn | Sequence[SingleCellColumn] | None =
                    'highly_variable',
        PC_key: str = 'pca',
        num_PCs: int | np.integer = 50,
        subspace_size: int | np.integer = 100,
        tolerance: int | np.integer | float | np.floating = 1e-6,
        max_iterations: int | np.integer = 100,
        chunk_size: int | np.integer = 1024,
        seed: int | np.integer = 0,
        match_parallel: bool = False,
        overwrite: bool = False,
        verbose: bool = True,
        num_threads: int | np.integer | None = None) -> \
        SingleCell | tuple[SingleCell, ...]:
    """
    Compute principal components (PCs) across cells.

    Requires normalized counts, so must be run after `normalize()`. By
    default, only the highly variable genes from `hvg()` are used to
    compute PCs.

    Uses approximate singular value decomposition (SVD) via the Implicitly
    Restarted Lanczos Bidiagonalization Algorithm (IRLBA). Seurat uses a
    different implementation of the same IRLBA algorithm.

    Args:
        others: optional SingleCell datasets to jointly compute principal
                components across, alongside this one.
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored
                   and have their PCs set to `NaN`. When `others` is
                   specified, `QC_column` can be a
                   length-`1 + len(others)` sequence of columns,
                   expressions, Series, functions, or `None` for each
                   dataset (for `self`, followed by each dataset in
                   `others`).
        hvg_column: an optional Boolean column of `var` indicating the
                    highly variable genes. Set to `None` to include all
                    genes. Can be a column name, a polars expression, a
                    polars Series, a 1D NumPy array, or a function that
                    takes in this SingleCell dataset and returns a polars
                    Series or 1D NumPy array. When `others` is specified,
                    `hvg_column` can be a length-`1 + len(others)` sequence
                    of columns, expressions, Series, functions, or `None`
                    for each dataset (for `self`, followed by each dataset
                    in `others`).
        PC_key: the key of `obsm` where the principal components will be
                stored
        num_PCs: the number of top principal components to calculate
        subspace_size: the size of the Krylov subspace used by IRLBA when
                       calculating PCs. Must be greater than or equal to
                       `num_PCs`, and about twice `num_PCs` is recommended.
        tolerance: the relative tolerance (expressed as the ratio of a
                   singular value's residual to the maximum singular value)
                   required to deem a singular value converged. IRLBA will
                   stop early, before `max_iterations` iterations, if all
                   singular values have converged.
        max_iterations: the maximum number of iterations to run IRLBA for,
                        stopping early if all singular values have
                        converged (see `tolerance`)
        chunk_size: the number of rows per fixed block in deterministic
                    parallel reductions. Used to parallelize operations
                    like mean and norm that would otherwise be serial.
                    Block boundaries are fixed regardless of thread count,
                    ensuring floating-point identical results.
        seed: the random seed to use when initializing the PCs, via R's
              `set.seed()` function
        match_parallel: if `False`, use a different order of operations for
                        single-threaded PCA. This gives a moderate (~60%)
                        boost in single-threaded performance, and lower
                        memory usage, at the cost of no longer exactly
                        matching the PCs produced by the multithreaded
                        version (due to differences in floating-point error
                        arising from the different order of operations).
                        When `match_parallel=False`, `pca()` will also give
                        slightly different results when run with CSR vs CSC
                        input; when multiple datasets are provided with a
                        mix of CSR and CSC formats, all datasets are
                        converted to the format shared by the most total
                        cells across datasets. If `True`, exactly match the
                        results of the multithreaded version when
                        `num_threads=1`. Must be `False` unless
                        `num_threads=1`.
        overwrite: if `True`, overwrite `PC_key` if already present in
                   obsm, instead of raising an error
        verbose: whether to print a message when the singular values did
                 not converge to a tolerance of `tolerance` within
                 `max_iterations` iterations
        num_threads: the number of threads to use for PCA. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores. Does not affect the returned
                     SingleCell dataset's `num_threads`; this will always
                     be the same as the original dataset's `num_threads`.

    Returns:
        A new SingleCell dataset where `obsm` contains an additional key,
        `PC_key` (default: `'pca'`), containing the top `num_PCs` principal
        components. Or, if additional SingleCell dataset(s) are specified
        via the `others` argument, a length-`1 + len(others)` tuple of
        SingleCell datasets with the PCs added: `self`, followed by each
        dataset in `others`. Each returned dataset keeps its own `X`,
        `obs`, `var`, `varm`, `uns` and `num_threads`.

    Raises:
        ValueError: if `PC_key` is already in `obsm` and `overwrite=False`,
                    if `X` is missing, if the counts are not flagged as
                    normalized, if `subspace_size < num_PCs`, if
                    `match_parallel=True` with `num_threads != 1`, or if
                    there are too few cells or genes for `num_PCs`.
        TypeError: if `X` is not float32 for some dataset.

    Note:
        Unlike Seurat's `RunPCA()` function, which requires `ScaleData()`
        to be run first, this function does not require the data to be
        scaled beforehand. Instead, it implicitly scales the data to zero
        mean and unit variance while performing PCA.
    """
    # If `others` was specified, check that all elements of `others` are
    # SingleCell datasets
    if others:
        check_types(others, 'others', SingleCell, 'SingleCell datasets')
    datasets = [self, *others]

    # Check that `PC_key` is a string
    check_type(PC_key, 'PC_key', str, 'a string')

    # Check that `overwrite` is Boolean
    check_type(overwrite, 'overwrite', bool, 'Boolean')

    # Check that `PC_key` is not already in `obsm`, unless `overwrite=True`
    suffix = ' for at least one dataset' if others else ''
    if not overwrite and \
            any(PC_key in dataset._obsm for dataset in datasets):
        error_message = (
            f'PC_key {PC_key!r} is already a key of obsm{suffix}; did '
            f'you already run PCA()? Set overwrite=True to overwrite.')
        raise ValueError(error_message)

    # Check that `X` is present for every dataset
    if any(dataset._X is None for dataset in datasets):
        error_message = (
            f'X is None{suffix}, so finding principal components is not '
            f'possible')
        raise ValueError(error_message)

    # Get `QC_column` and `hvg_column` (if not `None`) from every dataset
    QC_columns = SingleCell._get_columns(
        'obs', datasets, QC_column, 'QC_column', pl.Boolean,
        allow_missing=True)
    hvg_columns = SingleCell._get_columns(
        'var', datasets, hvg_column, 'hvg_column', pl.Boolean,
        custom_error=f'hvg_column {{}} is not a column of var{suffix}; '
                     f'did you forget to run hvg() (and possibly '
                     f'normalize()) before PCA()?')

    # Check that all datasets are normalized
    if not all(dataset._uns['normalized'] for dataset in datasets):
        error_message = (
            f"PCA() requires normalized counts but uns['normalized'] is "
            f"False{suffix}; did you forget to run normalize() before "
            f"PCA()?")
        raise ValueError(error_message)

    # Raise an error if `X` is not float32 for every dataset
    for dataset in datasets:
        if dataset._X.dtype != np.float32:
            error_message = (
                f'PCA() requires normalized counts with data type '
                f'float32, but X has data type '
                f'{str(dataset._X.dtype)!r}{suffix}; did you forget to '
                f'run normalize() before PCA()?')
            raise TypeError(error_message)

    # Check that `num_PCs` is a positive integer
    check_type(num_PCs, 'num_PCs', int, 'a positive integer')
    check_bounds(num_PCs, 'num_PCs', 1)

    # Check that `subspace_size` is a positive integer, and >= `num_PCs`
    # (positivity follows from `num_PCs >= 1`)
    check_type(subspace_size, 'subspace_size', int, 'a positive integer')
    if subspace_size < num_PCs:
        error_message = (
            f'subspace_size is {subspace_size:,}, but must be ≥ num_PCs '
            f'({num_PCs:,})')
        raise ValueError(error_message)

    # Check that `tolerance` is a positive number
    check_type(tolerance, 'tolerance', (int, float), 'a positive number')
    check_bounds(tolerance, 'tolerance', 0, left_open=True)

    # Check that `max_iterations` is a positive integer
    check_type(max_iterations, 'max_iterations', int, 'a positive integer')
    check_bounds(max_iterations, 'max_iterations', 1)

    # Check that `chunk_size` is a positive integer
    check_type(chunk_size, 'chunk_size', int, 'a positive integer')
    check_bounds(chunk_size, 'chunk_size', 1)

    # Check that `seed` is an integer
    check_type(seed, 'seed', int, 'an integer')

    # Check that `verbose` is Boolean
    check_type(verbose, 'verbose', bool, 'Boolean')

    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to
    # `os.cpu_count()`.
    num_threads = self._process_num_threads(num_threads)

    # Check that `match_parallel` is Boolean, and `False` unless
    # `num_threads=1`
    check_type(match_parallel, 'match_parallel', bool, 'Boolean')
    if match_parallel and num_threads != 1:
        error_message = \
            'match_parallel must be False unless num_threads is 1'
        raise ValueError(error_message)

    # Get the matrix to compute PCA across: a sparse matrix of counts for
    # highly variable genes (or all genes, if `hvg_column` is `None`)
    # across cells passing QC. Use `X[np.ix_(rows, columns)]` as a faster,
    # more memory-efficient alternative to `X[rows][:, columns]`.
    # `self._X._num_threads` is temporarily overridden while subsetting and
    # restored in the `finally` block, even if subsetting raises.
    original_num_threads = self._X._num_threads
    try:
        self._X._num_threads = num_threads
        if others:
            # NOTE(review): `is_in()` is applied to the concatenation of
            # the genes from every dataset in `others`, so a gene is kept
            # when it appears in `self` and in at least one other dataset.
            # With two or more `others`, confirm whether an intersection
            # across all datasets was intended.
            if hvg_column is None:
                genes_in_all_datasets = self.var_names\
                    .filter(self.var_names
                            .is_in(pl.concat([dataset.var_names
                                              for dataset in others])))
            else:
                hvg_in_self = \
                    self._var.filter(hvg_columns[0]).to_series() \
                    if hvg_columns[0] is not None else \
                    self._var.to_series()
                genes_in_all_datasets = hvg_in_self\
                    .filter(hvg_in_self.is_in(pl.concat([
                        dataset._var.filter(hvg_col).to_series()
                        if hvg_col is not None else
                        dataset._var.to_series()
                        for dataset, hvg_col in
                        zip(others, hvg_columns[1:])])))

            # For each dataset, look up the position of every selected
            # gene in that dataset's `var` (lazily, one dataset at a time)
            gene_indices = (
                genes_in_all_datasets
                .to_frame()
                .join(dataset._var.with_columns(
                          _SingleCell_index=pl.int_range(pl.len(),
                                                         dtype=pl.Int32)),
                      left_on=genes_in_all_datasets.name,
                      right_on=dataset.var_names.name, how='left')
                ['_SingleCell_index']
                .to_numpy()
                for dataset in datasets)
            if QC_column is None:
                Xs = [dataset._X[:, genes]
                      for dataset, genes in zip(datasets, gene_indices)]
            else:
                Xs = [dataset._X[np.ix_(QC_col.to_numpy(), genes)]
                      if QC_col is not None else dataset._X[:, genes]
                      for dataset, genes, QC_col in
                      zip(datasets, gene_indices, QC_columns)]
        else:
            if QC_column is None:
                if hvg_column is None:
                    Xs = [dataset._X for dataset in datasets]
                else:
                    Xs = [dataset._X[:, hvg_col.to_numpy()]
                          if hvg_col is not None else dataset._X
                          for dataset, hvg_col in
                          zip(datasets, hvg_columns)]
            else:
                if hvg_column is None:
                    Xs = [dataset._X[QC_col.to_numpy()]
                          if QC_col is not None else dataset._X
                          for dataset, QC_col in zip(datasets, QC_columns)]
                else:
                    Xs = [(dataset._X[np.ix_(QC_col.to_numpy(),
                                             hvg_col.to_numpy())]
                           if QC_col is not None else
                           dataset._X[:, hvg_col.to_numpy()])
                          if hvg_col is not None else
                          (dataset._X[QC_col.to_numpy()]
                           if QC_col is not None else dataset._X)
                          for dataset, QC_col, hvg_col in
                          zip(datasets, QC_columns, hvg_columns)]
    finally:
        self._X._num_threads = original_num_threads

    # Stack the per-dataset matrices along the cell axis, remembering how
    # many cells each dataset contributed so the PCs can be split up again
    num_cells_per_dataset = np.array([X.shape[0] for X in Xs])
    if len(Xs) == 1:
        X = Xs[0]
    elif all(isinstance(X, csr_array) for X in Xs):
        X = sparse_major_stack(Xs, num_threads=num_threads)
    elif all(isinstance(X, csc_array) for X in Xs):
        X = sparse_minor_stack(Xs, num_threads=num_threads)
    else:
        # Mix of CSR and CSC: convert to whichever format has the most
        # total cells in that format, to reduce the number of datasets that
        # need to be flipped. Compare with integer arithmetic rather than
        # dividing, so that zero total cells does not produce 0 / 0.
        total_cells = num_cells_per_dataset.sum()
        total_csr = sum(X.shape[0] for X in Xs
                        if isinstance(X, csr_array))
        if 2 * total_csr > total_cells:  # more CSR than CSC
            X = sparse_major_stack([
                X.tocsr() if isinstance(X, csc_array) else X
                for X in Xs], num_threads=num_threads)
        else:
            X = sparse_minor_stack([
                X.tocsc() if isinstance(X, csr_array) else X
                for X in Xs], num_threads=num_threads)
    del Xs

    # Check that the number of cells and genes are >= `num_PCs`, and that
    # there are at least two cells and two genes even if `num_PCs` is 1
    num_cells, num_genes = X.shape
    if num_cells < num_PCs:
        error_message = (
            f'num_PCs is {num_PCs:,}, but must be ≤ the number of cells '
            f'({num_cells:,})')
        raise ValueError(error_message)
    if num_genes < num_PCs:
        error_message = (
            f'num_PCs is {num_PCs:,}, but must be ≤ the number of genes '
            f'({num_genes:,})')
        raise ValueError(error_message)
    if num_cells == 1:
        error_message = (
            'there is only one cell, so principal components cannot be '
            'calculated')
        raise ValueError(error_message)
    if num_genes == 1:
        error_message = (
            'there is only one gene, so principal components cannot be '
            'calculated')
        raise ValueError(error_message)

    # We want to perform SVD on `X`, scaled to zero mean and unit variance.
    # Key ideas:
    # 1. Because `X` is a sparse matrix, mean-centering cannot be done
    #    without converting to a dense matrix. So scaling `X` cannot be
    #    done in the conventional way.
    # 2. Fortunately, we can represent scaling as a matrix product:
    #    `scale(X) = C @ X @ W`, where `W` is a diagonal matrix of the
    #    standard deviations for each column (gene) and `C` is a "centering
    #    matrix" that, when applied to any vector or matrix, yields the
    #    mean-centered version of it.
    # 3. We need to calculate the matrix-vector product of our operator
    #    with some vector `V`, i.e. `scale(X) @ V`. Using the formula from
    #    point #2, this is equivalent to `C @ X @ W @ V`. Since `W` is
    #    diagonal, this is equivalent to `C @ (X @ (V / diag(W)))`. In
    #    other words:
    #    - divide `V` (which has length `num_genes`) elementwise by
    #      `diag(W)`, the genewise standard deviations
    #    - matrix-vector multiply by `X`
    #    - mean-center the resulting vector, which has length `num_cells`
    # 4. We also need to calculate `scale(X).T @ V` for the `rmatvec()`
    #    part of our operator. Rewriting as `W.T @ X @ C.T @ V` and
    #    leveraging the fact that `C` turns out to be symmetric (as is `W`,
    #    since it's diagonal), this is equivalent to
    #    `(X @ (C @ V)) / diag(W)`. In other words:
    #    - mean-center `V`, which has length `num_cells`
    #    - matrix-vector multiply by `X.T`
    #    - divide the result (which has length `num_genes`) elementwise by
    #      `diag(W)`, the genewise standard deviations
    # 5. CSR matrix-vector multiplication can be done efficiently
    #    multithreaded, but CSC can't (except by maintaining thread-local
    #    versions of the output vector and summing them across threads at
    #    the end, which leads to differences in floating-point error
    #    depending on the number of threads). This is problematic because,
    #    regardless of whether `X` is a CSR or a CSC matrix, we need to do
    #    some multiplications involving `X.T` and others involving `X`. So
    #    if `X` is CSR, the `X.T` multiplications will be single-threaded
    #    since `X.T` is a CSC matrix. If `X` is CSC, the `X`
    #    multiplications will be single-threaded. So one of the two
    #    multiplications will always be single-threaded. There are hundreds
    #    of these matrix-vector multiplications and they take up a large
    #    majority of the total runtime for PCA, so not being able to fully
    #    multithread them is a huge disadvantage.
    # 6. To address the issue in the previous point, make both a CSR and a
    #    CSC copy of `X` when running PCA in parallel. The first
    #    multiplication (involving `X.T`) will use the CSC copy of `X`, but
    #    plugged into the CSR matrix-vector multiplication function. The
    #    second multiplication (involving `X`) will use CSR multiplication
    #    as normal. This works because plugging a CSC copy of `X` into a
    #    matrix-vector multiplication routine designed for CSR matrices (or
    #    vice versa) is equivalent to multiplying by `X.T` instead of `X`.
    #    The result: both matrix multiplications can be done in parallel.
    #    This also has the nice side benefit that the final result is the
    #    same regardless of whether the counts are input as CSR or CSC.
    # 7. When running single-threaded with `match_parallel=False`,
    #    however, just use whichever version of `X` (CSR or CSC) we happen
    #    to have available, to avoid the runtime and memory overhead of
    #    creating both versions. However, this means that CSR and CSC no
    #    longer give exactly the same PCs, due to differences in
    #    floating-point error.
    # 8. When calculating the scaled variance for each gene, use the CSC
    #    version of `X` if available, for speed. Clip tiny standard
    #    deviations (less than 1e-8) to 1e-8, like Seurat.
    # 9. In the parallel code path, all U columns produced by the matvec
    #    are zero-mean by construction (the centering matrix is applied).
    #    Reorthogonalization subtracts a linear combination of previous
    #    zero-mean columns, preserving zero-mean. Scaling by 1/alpha
    #    preserves zero-mean. This invariant enables two optimizations:
    #    (a) the matvec can skip mean-centering and defer it to a fused
    #    pass that also performs reorthogonalization and norm computation;
    #    (b) the rmatvec can skip mean-centering entirely since its input
    #    (a U column) is already zero-mean.
    #    Note: this optimization changes the floating-point result versus
    #    the single-threaded path (which does redundant mean-centering),
    #    but match_parallel=True is not allowed when num_threads > 1, so
    #    the parallel path only has to be self-consistent across thread
    #    counts.

    # Get the data, indices and indptr to be used for the matvec and
    # rmatvec. Unless `num_threads=1` and `match_parallel=False`, this
    # requires calculating the "opposite" version of the array (CSC if `X`
    # is CSR, CSR if CSC).
    is_csr = isinstance(X, csr_array)
    reuse_single_format = num_threads == 1 and not match_parallel
    if is_csr:
        # `X` is CSR, so always use CSR for the matvec. If `num_threads=1`
        # and `match_parallel=False`, use CSR for the rmatvec as well,
        # otherwise use CSC for the rmatvec to enable parallelism.
        data_matvec = X.data
        indices_matvec = X.indices
        indptr_matvec = X.indptr
        if reuse_single_format:
            data_rmatvec = data_matvec
            indices_rmatvec = indices_matvec
            indptr_rmatvec = indptr_matvec
        else:
            X_csc = X.tocsc()
            data_rmatvec = X_csc.data
            indices_rmatvec = X_csc.indices
            indptr_rmatvec = X_csc.indptr
    else:
        # `X` is CSC, so always use CSC for the rmatvec. If `num_threads=1`
        # and `match_parallel=False`, use CSC for the matvec as well,
        # otherwise use CSR for the matvec to enable parallelism.
        data_rmatvec = X.data
        indices_rmatvec = X.indices
        indptr_rmatvec = X.indptr
        if reuse_single_format:
            data_matvec = data_rmatvec
            indices_matvec = indices_rmatvec
            indptr_matvec = indptr_rmatvec
        else:
            X_csr = X.tocsr()
            data_matvec = X_csr.data
            indices_matvec = X_csr.indices
            indptr_matvec = X_csr.indptr

    # Run PCA with irlba. Use `threadpool_limits()` to run BLAS
    # single-threaded when `num_threads=1`. (Unlike the other functions in
    # the library that use `threadpool_limits()`, here the multi-threaded
    # code path contains BLAS calls outside `prange()`, and these need to
    # be explicitly limited to one thread with `threadpool_limits(1)`
    # inside Cython.)
    if num_threads == 1:
        PCs = np.empty((num_cells, num_PCs), dtype=np.float32)
    else:
        PCs = numa_zeros((num_cells, num_PCs), dtype=np.float32)
    # `X` may be `self._X` itself (single dataset, no subsetting), so its
    # thread count is restored afterwards even if `irlba()` raises
    original_num_threads = X._num_threads
    X._num_threads = num_threads
    try:
        with threadpool_limits(num_threads):
            converged = irlba(
                data_matvec=data_matvec, indices_matvec=indices_matvec,
                indptr_matvec=indptr_matvec, data_rmatvec=data_rmatvec,
                indices_rmatvec=indices_rmatvec,
                indptr_rmatvec=indptr_rmatvec, num_cells=num_cells,
                num_genes=num_genes, k=num_PCs,
                subspace_size=subspace_size, tolerance=tolerance,
                max_iterations=max_iterations, seed=seed,
                match_parallel=match_parallel, is_csr=is_csr,
                num_threads=num_threads, chunk_size=chunk_size, PCs=PCs)
    finally:
        X._num_threads = original_num_threads

    # Print a message if PCs did not converge and `verbose=True`
    if not converged and verbose:
        print(f'PCA did not converge to a tolerance of {tolerance:.2g} '
              f'after {max_iterations:,} iterations; consider increasing '
              f'max_iterations or tolerance')

    # Store each dataset's PCs in its `obsm`
    if not others:  # just one dataset
        QC_col = QC_columns[0]
        # If `QC_col` is not `None`, back-project from QCed cells to all
        # cells, filling with `NaN`
        if QC_col is not None:
            PCs_QCed = PCs
            PCs = np.full((len(self), PCs_QCed.shape[1]), np.nan,
                          dtype=np.float32)
            PCs[QC_col.to_numpy()] = PCs_QCed
        return SingleCell(
            X=self._X, obs=self._obs, var=self._var,
            obsm=self._obsm | {PC_key: PCs}, varm=self._varm,
            uns=self._uns, num_threads=self._num_threads)

    # Multiple datasets: rows of `PCs` are laid out dataset by dataset, in
    # the same order as `datasets`, so slice out each dataset's rows
    results = []
    end_indices = num_cells_per_dataset.cumsum()
    for dataset, QC_col, dataset_num_cells, end_index in \
            zip(datasets, QC_columns, num_cells_per_dataset, end_indices):
        start_index = end_index - dataset_num_cells
        dataset_PCs = PCs[start_index:end_index]
        # If `QC_col` is not `None` for this dataset, back-project from
        # QCed cells to all cells, filling with `NaN`
        if QC_col is not None:
            dataset_PCs_QCed = dataset_PCs
            dataset_PCs = np.full(
                (len(dataset), dataset_PCs_QCed.shape[1]), np.nan,
                dtype=np.float32)
            dataset_PCs[QC_col.to_numpy()] = dataset_PCs_QCed
        # Carry over each dataset's own `varm`, `uns` and `num_threads`,
        # rather than those of `self`
        results.append(SingleCell(
            X=dataset._X, obs=dataset._obs, var=dataset._var,
            obsm=dataset._obsm | {PC_key: dataset_PCs},
            varm=dataset._varm, uns=dataset._uns,
            num_threads=dataset._num_threads))
    return tuple(results)
[docs] def neighbors(self, *, QC_column: SingleCellColumn | None = 'passed_QC', PC_key: str = 'pca', neighbors_key: str = 'neighbors', distances_key: str = 'distances', num_neighbors: int | np.integer = 20, num_clusters: int | np.integer | None = None, num_clusters_searched: int | np.integer | None = None, num_kmeans_iterations: int | np.integer = 2, kmeans_tolerance: int | np.integer | float | np.floating = 1e-2, kmeans_barbar: bool = False, num_init_iterations: int | np.integer = 5, oversampling_factor: int | np.integer | float | np.floating = 1, chunk_size_kmeans: int | np.integer | None = None, chunk_size_search: int | np.integer | None = None, seed: int | np.integer = 0, overwrite: bool = False, verbose: bool = True, num_threads: int | np.integer | None = None) -> SingleCell: """ Calculate the `num_neighbors` nearest neighbors of each cell. `neighbors()` is intended to be run after `PCA()`; by default, it uses `obsm['pca']` as the input to the nearest-neighbors calculation. `neighbors()` must be re-run if the dataset is subset; not doing so will raise an error. A cell is not considered its own nearest neighbor. `neighbors()` is based on the `IndexIVFFlat` search strategy from the widely-used `faiss` nearest-neighbor search library. It works in two main steps: 1) Perform k-means clustering to subdivide the dataset into `num_clusters` clusters. For large datasets, the number of clusters is four times the square root of the number of cells, rounding up; for small datasets, 1% of the number of cells. 2) Find each cell's `num_clusters_searched` (default: 64) nearest cluster centroids (one of which is the cell's own cluster), then search these clusters exhaustively for nearest-neighbor candidates. 
The key optimization: instead of going one cell at a time and searching its 64 nearest clusters, group cells into batches of `chunk_size_search` (default: 256), enumerate which clusters are one of the 64 nearest to at least one cell in the batch, then run the searches across each of these clusters, one cluster at a time. Since cells in a block will tend to share many of their nearest 64 clusters, this amortizes the cost of streaming the cluster's cells from memory across all the cells in the block that need them, speeding up the search by an order of magnitude for large datasets. The key parameters that affect accuracy and runtime are `num_clusters` and `num_clusters_searched`. Runtime goes up about linearly with the fraction of cells searched, which is roughly `num_clusters_searched / num_clusters`. Accuracy will also increase, but with strongly diminishing returns. We recommend increasing `num_clusters_searched` (e.g. to 128) if greater accuracy is desired, and decreasing it (e.g. to 32) if greater speed is desired. Args: QC_column: an optional Boolean column of `obs` indicating which cells passed QC. Can be a column name, a polars expression, a polars Series, a 1D NumPy array, or a function that takes in this SingleCell dataset and returns a polars Series or 1D NumPy array. Set to `None` to include all cells. Cells failing QC will be ignored and have their nearest neighbors set to `-1`. PC_key: the key of `obsm` containing the principal components calculated with `PCA()`, to use as the input for the nearest-neighbors calculation neighbors_key: the key of `obsm` where the nearest neighbors will be stored distances_key: the key of `obsm` where the squared Euclidean distance to each nearest neighbor will be stored num_neighbors: the number of nearest neighbors to report for each cell; must be less than or equal to the number of cells num_clusters: the number of k-means clusters to use during the nearest-neighbor search. Must be less than the number of cells. 
If `None`, will be set to `ceil(min(4 * sqrt(num_cells), num_cells / 100))` clusters, i.e. the minimum of four times the square root of the number of cells in `other` and 1% of the number of cells in `other`, rounding up. The core of the heuristic, `4 * sqrt(num_cells)`, is the low end of the range [recommended](https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index) by faiss, 4 to 16 times the square root. However, faiss also recommends using between 39 and 256 data points per centroid when training the k-means clustering used in the k-nearest neighbors search. To avoid going below 39, we switch to using `num_cells / 100` centroids for small datasets (fewer than 640,000 cells), since 100 is the midpoint of 39 and 256 in log space. For datasets of at least 10,000 cells, clusters with fewer than 100 cells will be merged into the adjacent cluster with the nearest centroid, so the actual number of clusters used may be smaller. num_clusters_searched: the number of a cell's nearest clusters to search; must be between 1 and `num_clusters`. Defaults to `min(64, num_clusters)`. num_kmeans_iterations: the maximum number of iterations of k-means clustering to perform before starting the nearest-neighbor search, stopping early if a relative convergence of `kmeans_tolerance` is reached kmeans_tolerance: the relative change in inertia (the sum of squared distances from each cell to its assigned centroid) used to determine whether to stop optimizing the k-means clustering before `num_kmeans_iterations` iterations kmeans_barbar: whether to use [k-means||](https://arxiv.org/abs/1203.6402) initialization (a parallel version of k-means++) to initialize the k-means clustering centroids, instead of random initialization. This is more accurate but takes considerably longer for large datasets and `num_clusters`. 
num_init_iterations: the number of k-means|| iterations used to initialize the k-means clustering that constitutes the first step of the nearest-neighbor search. k-means|| is a parallel version of the widely used k-means++ initialization scheme for k-means clustering. The default value of 5 is recommended by the [k-means|| paper](https://arxiv.org/abs/1203.6402). Only used when `kmeans_barbar=True`. oversampling_factor: the number of candidate centroids selected, on average, at each of the `num_init_iterations` iterations of k-means||, as a multiple of `num_clusters`. The default value of 1 is the midpoint (in log space) of the values explored by the [k-means|| paper](https://arxiv.org/abs/1203.6402), namely 0.1 to 10. The total number of candidate centroids selected, on average, will be `oversampling_factor * num_clusters + 1`, from which the final `num_clusters` centroids will then be selected via k-means++. Only used when `kmeans_barbar=True`. chunk_size_kmeans: the chunk size used for distance calculations during k-means clustering, and also during the per-query centroid ranking step of the nearest-neighbor search. Setting this to a power of 2 is recommended. Defaults to `min(4096, num_cells)`. chunk_size_search: the chunk size used to group query cells together during the nearest-neighbor search. Overly small values will tend to increase runtime by reducing the reuse of information during the search, whereas overly large values will lead to excessive memory use. Defaults to `min(256, num_cells)`. seed: the random seed to use when finding nearest neighbors overwrite: if `True`, overwrite `neighbors_key` if already present in `obsm`, instead of raising an error verbose: whether to print details of the nearest-neighbor search num_threads: the number of threads to use when finding nearest neighbors. Set `num_threads=-1` to use all available cores, as determined by `os.cpu_count()`, or leave unset to use `self.num_threads` cores. 
Does not affect the returned SingleCell dataset's `num_threads`; this will always be the same as the original dataset's `num_threads`. `num_threads` will be capped to 64 when running with Scipy linked against OpenBLAS (see warning below). Returns: A new SingleCell dataset with the indices of each cell's nearest neighbors - not counting the cell itself - stored in `obsm[neighbors_key]` as a `len(obs)` × `num_neighbors` NumPy array, where `obsm[neighbors_key][i, j]` stores the index of the `i`th cell's `j + 1`th nearest neighbor. (This differs from Scanpy and Seurat, which use a less compact sparse matrix representation instead.) For instance, if `num_neighbors=2` and the 0th cell's nearest neighbors are the 4th cell and the 6th cell, then `obsm[neighbors_key][0]` will be `np.array([4, 6])`. Note that if `QC_column` is not `None`, these integer indices are with respect to QCed cells, not all cells. The squared Euclidean distance to each nearest neighbor will be stored in `obsm[distances_key]` as an array of the same shape as `obsm[neighbors_key]`. Warning: If you installed Scipy via pip, it will be linked against OpenBLAS, and `neighbors()` will be limited to 64 threads due to the limitations of OpenBLAS. To use more than 64 threads, install Scipy linked against MKL BLAS. This is done automatically when installing brisc via conda, but you can also do it manually via `conda install "libblas=*=*mkl" scipy`. """ # Check that `neighbors_key` is a string check_type(neighbors_key, 'neighbors_key', str, 'a string') # Check that `overwrite` is Boolean check_type(overwrite, 'overwrite', bool, 'Boolean') # Check that `neighbors_key` is not already a key in `obsm`, unless # `overwrite=True` if not overwrite and neighbors_key in self._obsm: error_message = ( f'neighbors_key {neighbors_key!r} is already a key of obsm; ' f'did you already run neighbors()? 
Set overwrite=True to ' f'overwrite.') raise ValueError(error_message) # Get the QC column, if not `None` if QC_column is not None: QC_column = self._get_column( 'obs', QC_column, 'QC_column', pl.Boolean, allow_missing=QC_column == 'passed_QC') # Check that `PC_key` is the name of a key in `obsm` check_type(PC_key, 'PC_key', str, 'a string') if PC_key not in self._obsm: error_message = f'PC_key {PC_key!r} is not a key of obsm' if PC_key == 'pca': error_message += \ '; did you forget to run PCA() before neighbors()?' raise ValueError(error_message) # Get PCs, and check that they are float32 and C-contiguous PCs = self._obsm[PC_key] if PCs.dtype != np.float32: error_message = \ f'obsm[{PC_key!r}].dtype is {PCs.dtype!r}, but must be float32' raise TypeError(error_message) if not PCs.flags['C_CONTIGUOUS']: error_message = ( f'obsm[{PC_key!r}] is not C-contiguous; make it C-contiguous ' f'with pipe_obsm_key({PC_key!r}, np.ascontiguousarray)') raise ValueError(error_message) # Check that `num_kmeans_iterations` is a positive integer check_type(num_kmeans_iterations, 'num_kmeans_iterations', int, 'a positive integer') check_bounds(num_kmeans_iterations, 'num_kmeans_iterations', 1) # Check that `kmeans_tolerance` is a positive number check_type(kmeans_tolerance, 'kmeans_tolerance', (int, float), 'a positive number') check_bounds(kmeans_tolerance, 'kmeans_tolerance', 0, left_open=True) # Check that `kmeans_barbar` is Boolean check_type(kmeans_barbar, 'kmeans_barbar', bool, 'Boolean') # Check that `num_init_iterations` is a positive integer check_type(num_init_iterations, 'num_init_iterations', int, 'a positive integer') check_bounds(num_init_iterations, 'num_init_iterations', 1) # Check that `oversampling_factor` is a positive number check_type(oversampling_factor, 'oversampling_factor', (int, float), 'a positive number') check_bounds(oversampling_factor, 'oversampling_factor', 0, left_open=True) # If `kmeans_barbar=False`, check that `num_init_iterations` and # 
`oversampling_factor` have their default values if not kmeans_barbar: if num_init_iterations != 5: error_message = ( 'num_init_iterations can only be specified when ' 'kmeans_barbar=True') raise ValueError(error_message) if oversampling_factor != 1: error_message = ( 'oversampling_factor can only be specified when ' 'kmeans_barbar=True') raise ValueError(error_message) # Check that `seed` is an integer check_type(seed, 'seed', int, 'an integer') # Check that `verbose` is Boolean check_type(verbose, 'verbose', bool, 'Boolean') # Check that `num_threads` is a positive integer, -1 or `None`; if # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()` num_threads = self._process_num_threads(num_threads) # Cap to 64 threads for OpenBLAS to avoid the error "OpenBLAS : Program # is Terminated. Because you tried to allocate too many memory # regions". if num_threads > 64 and any(lib.get('internal_api') == 'openblas' for lib in threadpool_info()): num_threads = 64 # Subset PCs to QCed cells only, if `QC_column` is not `None` if QC_column is not None: QC_column_NumPy = QC_column.to_numpy() if num_threads == 1: PCs = PCs[QC_column_NumPy] else: indices = np.flatnonzero(QC_column_NumPy) PCs = parallel_subset_2d(PCs, indices, num_threads) # Check that `num_neighbors` is between 1 and `num_cells - 1`. 
check_type(num_neighbors, 'num_neighbors', int, 'a positive integer') num_cells = len(PCs) if not 1 <= num_neighbors < num_cells: error_message = ( f'num_neighbors is {num_neighbors:,}, but must be ≥ 1 and ' f'less than the number of cells ({num_cells:,})') raise ValueError(error_message) # Check that `num_clusters` is between 1 and `num_cells`; if `None`, # set to `ceil(min(4 * sqrt(num_cells), num_cells / 100)))` if num_clusters is None: num_clusters = \ int(np.ceil(min(4 * np.sqrt(num_cells), num_cells / 100))) else: check_type(num_clusters, 'num_clusters', int, 'a positive integer') if not 1 <= num_clusters <= num_cells: error_message = ( f'num_clusters is {num_clusters:,}, but must be ≥ 1 and ≤ ' f'the number of cells ({num_cells:,})') raise ValueError(error_message) # Check that `num_clusters_searched` is between 1 and `num_clusters`; # if `None`, set to `min(64, num_clusters)` if num_clusters_searched is None: num_clusters_searched = min(64, num_clusters) else: check_type(num_clusters_searched, 'num_clusters_searched', int, 'a positive integer') if not 1 <= num_clusters_searched <= num_clusters: error_message = ( f'num_clusters_searched is {num_clusters_searched:,}, but ' f'must be ≥ 1 and ≤ num_clusters ({num_clusters:,})') raise ValueError(error_message) # Check that `chunk_size_kmeans` is between 1 and `num_cells`; if # `None`, set to `min(4096, num_cells)` if chunk_size_kmeans is None: chunk_size_kmeans = min(4096, num_cells) else: check_type(chunk_size_kmeans, 'chunk_size_kmeans', int, 'a positive integer') if not 1 <= chunk_size_kmeans <= num_cells: error_message = ( f'chunk_size_kmeans is {chunk_size_kmeans:,}, but must ' f'be ≥ 1 and ≤ the number of cells ({num_cells:,})') raise ValueError(error_message) # Check that `chunk_size_search` is between 1 and `num_cells`; if # `None`, set to `min(256, num_cells)` if chunk_size_search is None: chunk_size_search = min(256, num_cells) else: check_type(chunk_size_search, 'chunk_size_search', int, 'a positive 
integer') if not 1 <= chunk_size_search <= num_cells: error_message = ( f'chunk_size_search is {chunk_size_search:,}, but must be ' f'≥ 1 and ≤ the number of cells ({num_cells:,})') raise ValueError(error_message) # Run k-means clustering on `PCs`. Since `centroids` and # `centroids_new` are swapped every iteration, `centroids_new` will # contain the final centroids when doing an odd number of k-means # iterations; if so, swap them at the end. Use `threadpool_limits()` to # run BLAS single-threaded (directly in the single-threaded case, by # making everything single-threaded including BLAS; indirectly in the # parallel case, by disabling nested parallelism inside `prange()`). # `kmeans()` populates `cell_norms` (||X||²) for reuse during the # nearest-neighbor search. num_dimensions = PCs.shape[1] if num_threads == 1: cluster_labels = np.empty(num_cells, dtype=np.uint32) min_distances = np.empty(num_cells, dtype=np.float32) cell_norms = np.empty(num_cells, dtype=np.float32) else: cluster_labels = numa_zeros(num_cells, dtype=np.uint32) min_distances = numa_zeros(num_cells, dtype=np.float32) cell_norms = numa_zeros(num_cells, dtype=np.float32) centroids = np.empty((num_clusters, num_dimensions), dtype=np.float32) centroids_new = np.empty((num_clusters, num_dimensions), dtype=np.float32) num_cells_per_cluster = np.empty(num_clusters, dtype=np.uint32) with threadpool_limits(num_threads): iterations_until_convergence, num_clusters = kmeans( X=PCs, cluster_labels=cluster_labels, centroids=centroids, centroids_new=centroids_new, num_cells_per_cluster=num_cells_per_cluster, min_distances=min_distances, cell_norms=cell_norms, kmeans_barbar=kmeans_barbar, num_init_iterations=num_init_iterations, num_kmeans_iterations=num_kmeans_iterations, tolerance=kmeans_tolerance, oversampling_factor=oversampling_factor, seed=seed, chunk_size=chunk_size_kmeans, num_threads=num_threads) if iterations_until_convergence & 1: centroids = centroids_new del centroids_new, min_distances centroids 
= centroids[:num_clusters] num_cells_per_cluster = num_cells_per_cluster[:num_clusters] # As an optimization, renumber clusters so that cluster `N` tends to be # near cluster `N + 1` in PC space, by sorting the centroids by PC1. # This reduces cache misses during the nearest-neighbor search. centroid_order = np.argsort(centroids[:, 0]).astype(np.uint32) centroids = np.ascontiguousarray(centroids[centroid_order]) num_cells_per_cluster = num_cells_per_cluster[centroid_order] inverse_centroid_order = np.empty_like(centroid_order) inverse_centroid_order[centroid_order] = np.arange(len(centroid_order), dtype=np.uint32) cluster_labels = inverse_centroid_order[cluster_labels] # As an optimization, sort by cluster to reduce cache misses in the # nearest-neighbor search. This also lets us skip building an explicit # inverted file index during the nearest-neighbor search. sorted_order = np.argsort(cluster_labels).astype(np.uint32) del cluster_labels if num_threads == 1: PCs = PCs[sorted_order] cell_norms = cell_norms[sorted_order] else: PCs = parallel_subset_2d(PCs, sorted_order, num_threads) cell_norms = parallel_subset_1d(cell_norms, sorted_order, num_threads) # Find the `num_neighbors` nearest neighbors of each cell, according to # `PCs`. Allocate the heap arrays, plus the centroid-distance and # nearest-cluster scratch buffers, NUMA-aware. Use # `threadpool_limits()` to run BLAS single-threaded (directly in the # single-threaded case, by making everything single-threaded including # BLAS; indirectly in the parallel case, by disabling nested # parallelism inside `prange()`). 
if num_threads == 1: neighbors = np.empty((num_cells, num_neighbors), dtype=np.uint32) distances = np.empty((num_cells, num_neighbors), dtype=np.float32) centroid_distances = np.empty((num_cells, num_clusters_searched), dtype=np.float32) nearest_clusters = np.empty((num_cells, num_clusters_searched), dtype=np.uint32) else: neighbors = numa_zeros((num_cells, num_neighbors), dtype=np.uint32) distances = numa_zeros((num_cells, num_neighbors), dtype=np.float32) centroid_distances = numa_zeros( (num_cells, num_clusters_searched), dtype=np.float32) nearest_clusters = numa_zeros( (num_cells, num_clusters_searched), dtype=np.uint32) with threadpool_limits(num_threads): knn_self(X=PCs, sorted_order=sorted_order, centroids=centroids, num_cells_per_cluster=num_cells_per_cluster, neighbors=neighbors, distances=distances, cell_norms=cell_norms, centroid_distances=centroid_distances, nearest_clusters=nearest_clusters, num_neighbors=num_neighbors, num_clusters_searched=num_clusters_searched, chunk_size_kmeans=chunk_size_kmeans, chunk_size_search=chunk_size_search, num_threads=num_threads) del centroid_distances, nearest_clusters, cell_norms # Invert the argsort so that `neighbors` and `distances` are with # respect to the original dataset, rather than being sorted by cluster. # (The remapping of the nearest-neighbor indices themselves is taken # care of in `knn_self()` via the `sorted_order` argument.) 
if num_threads == 1: unsorted_neighbors = np.empty_like(neighbors) unsorted_neighbors[sorted_order] = neighbors neighbors = unsorted_neighbors unsorted_distances = np.empty_like(distances) unsorted_distances[sorted_order] = distances distances = unsorted_distances else: inverse_order = np.empty_like(sorted_order) inverse_order[sorted_order] = np.arange(len(sorted_order), dtype=sorted_order.dtype) neighbors = parallel_subset_2d(neighbors, inverse_order, num_threads) distances = parallel_subset_2d(distances, inverse_order, num_threads) # If `QC_column` was specified, back-project from QCed cells to all # cells, filling with `NaN` (for the distances) and `UINT32_MAX` (for # the nearest-neighbor indices) if QC_column is not None: neighbors_QCed = neighbors neighbors = np.full((len(self), neighbors_QCed.shape[1]), 4_294_967_295, dtype=np.uint32) neighbors[QC_column_NumPy] = neighbors_QCed distances_QCed = distances distances = np.full((len(self), distances_QCed.shape[1]), np.nan, dtype=np.float32) distances[QC_column_NumPy] = distances_QCed return SingleCell(X=self._X, obs=self._obs, var=self._var, obsm=self._obsm | {neighbors_key: neighbors, distances_key: distances}, varm=self._varm, obsp=self._obsp, varp=self._varp, uns=self._uns, num_threads=self._num_threads)
def shared_neighbors(self,
                     *,
                     QC_column: SingleCellColumn | None |
                                Sequence[SingleCellColumn | None] =
                                'passed_QC',
                     neighbors_key: str = 'neighbors',
                     shared_neighbors_key: str = 'shared_neighbors',
                     min_shared_neighbors: int | np.integer = 3,
                     overwrite: bool = False,
                     num_threads: int | np.integer | None = None) -> \
        SingleCell:
    """
    Build the shared nearest neighbor (SNN) graph of this dataset's cells.

    Run this after `neighbors()`: unless told otherwise, the nearest
    neighbors are read from `obsm['neighbors']`.

    Edge weights are Jaccard indices between the neighbor sets of pairs of
    cells. The output matches Seurat's, except that diagonal elements are
    omitted rather than being set to 1. It does not match Scanpy, which
    estimates the shared nearest neighbor graph from the connectivity of
    the UMAP manifold.

    If the dataset is subset, this function must be run again; failing to
    do so will raise an error.

    Args:
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored
                   and excluded from the shared nearest neighbor graph.
        neighbors_key: the key of `obsm` holding each cell's nearest
                       neighbors, as calculated by `neighbors()`; this is
                       the input to the shared nearest neighbor graph
                       calculation
        shared_neighbors_key: the key of `obsp` under which the shared
                              nearest neighbor graph will be stored
        min_shared_neighbors: the minimum number of neighbors two cells
                              must have in common for an edge between them
                              to be included in the graph. With 20 nearest
                              neighbors (the default `num_neighbors` in
                              `neighbors()`) + 1 for the cell itself, the
                              default of `min_shared_neighbors=3`
                              corresponds to Seurat's default of
                              `prune.SNN = 1 / 15` in `FindNeighbors()`:
                              with 3 shared neighbors the weight is
                              `3 / (42 - 3)` or about 0.077, which is
                              greater than `1 / 15`, but with 2 it is only
                              `2 / (42 - 2)` or 0.05, which is less than
                              `1 / 15`.
        overwrite: if `True`, overwrite `shared_neighbors_key` if already
                   present in `obsp`, instead of raising an error
        num_threads: the number of threads to use when finding shared
                     nearest neighbors. Set `num_threads=-1` to use all
                     available cores, as determined by `os.cpu_count()`, or
                     leave unset to use `self.num_threads` cores. Does not
                     affect the returned SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     original dataset's `num_threads`.

    Returns:
        A new SingleCell dataset with the shared nearest neighbor graph
        stored in `obsp[shared_neighbors_key]` as a symmetric
        `len(obs)` × `len(obs)` sparse array.
        `obsp[shared_neighbors_key][i, j]` is the Jaccard index of the
        nearest neighbors of cells `i` and `j`: the number of cells that
        are neighbors of both, divided by the number of cells that are
        neighbors of at least one. Diagonal elements are omitted. For
        example, if 20 nearest neighbors were calculated (i.e.
        `obsm[neighbors_key].shape[1] == 20`) and 8 of the 20 cells in
        `obsm[neighbors_key][i]` also appear in `obsm[neighbors_key][j]`,
        then `obsp[shared_neighbors_key][i, j]` will be 0.25
        (`8 / (20 + 20 - 8)`).
    """
    # Validate the output key and the `overwrite` flag, then refuse to
    # clobber an existing `obsp` entry unless explicitly permitted
    check_type(shared_neighbors_key, 'shared_neighbors_key', str,
               'a string')
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    if not overwrite and shared_neighbors_key in self._obsp:
        message = (
            f'shared_neighbors_key {shared_neighbors_key!r} is already a '
            f'key of obsp; did you already run shared_neighbors()? Set '
            f'overwrite=True to overwrite.')
        raise ValueError(message)

    # Resolve the QC column (left as `None` when no QC filter applies)
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')

    # Fetch the nearest-neighbor indices; they must be a C-contiguous
    # uint32 array
    check_type(neighbors_key, 'neighbors_key', str, 'a string')
    if neighbors_key not in self._obsm:
        hint = ('; did you forget to run neighbors() before '
                'shared_neighbors()?') \
            if neighbors_key == 'neighbors' else ''
        message = \
            f'neighbors_key {neighbors_key!r} is not a key of obsm{hint}'
        raise ValueError(message)
    neighbors = self._obsm[neighbors_key]
    if neighbors.dtype != np.uint32:
        message = (
            f'obsm[{neighbors_key!r}] must have uint32 data type, but '
            f'has data type {str(neighbors.dtype)!r}')
        raise TypeError(message)
    if not neighbors.flags['C_CONTIGUOUS']:
        message = (
            f'obsm[{neighbors_key!r}] is not C-contiguous; make it '
            f'C-contiguous with '
            f'pipe_obsm_key({neighbors_key!r}, np.ascontiguousarray)')
        raise ValueError(message)

    # `min_shared_neighbors` must be a non-negative integer smaller than
    # the number of neighbors per cell. The shape is taken before any QC
    # subsetting, so `num_cells` is the size of the full dataset.
    check_type(min_shared_neighbors, 'min_shared_neighbors', int,
               'a non-negative integer')
    num_cells, num_neighbors = neighbors.shape
    if min_shared_neighbors < 0 or min_shared_neighbors >= num_neighbors:
        message = (
            f'min_shared_neighbors is {min_shared_neighbors:,}, but must '
            f'be ≥ 0 and less than the number of neighbors in '
            f'obsm[{neighbors_key!r}] ({num_neighbors:,})')
        raise ValueError(message)

    # Resolve `num_threads` (positive integer, -1 or `None`)
    num_threads = self._process_num_threads(num_threads)

    # With a QC column, keep only the rows of QCed cells and remember
    # which row of the full array each kept row came from; without one,
    # pass an empty map
    if QC_column is None:
        QCed_to_full_map = np.array([], dtype=np.int64)
    else:
        QC_mask = QC_column.to_numpy()
        QCed_to_full_map = np.flatnonzero(QC_mask)
        neighbors = neighbors[QCed_to_full_map] if num_threads == 1 else \
            parallel_subset_2d(neighbors, QCed_to_full_map, num_threads)

    # Compute the graph; `snn()` receives the preallocated `indptr` and
    # returns the matching `indices` and `data`
    indptr = np.empty(num_cells + 1, dtype=np.int64)
    indices, data = snn(
        neighbors=neighbors, QCed_to_full_map=QCed_to_full_map,
        indptr=indptr, min_shared_neighbors=min_shared_neighbors,
        neighbors_key=neighbors_key, num_threads=num_threads)
    graph = csr_array((data, indices, indptr),
                      shape=(num_cells, num_cells))
    graph._num_threads = self._num_threads

    # Return a new dataset that differs only in `obsp`
    updated_obsp = self._obsp | {shared_neighbors_key: graph}
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm, varm=self._varm,
                      obsp=updated_obsp, varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
def cluster(self,
            *,
            QC_column: SingleCellColumn | None |
                       Sequence[SingleCellColumn | None] = 'passed_QC',
            resolution: int | float | np.integer | np.floating |
                        Iterable[int | float | np.integer |
                                 np.floating] = 1,
            min_cluster_size: int | np.integer = 5,
            shared_neighbors_key: str = 'shared_neighbors',
            neighbors_key: str = 'neighbors',
            cluster_column: str | Iterable[str] = 'cluster',
            seed: int | np.integer = 0,
            overwrite: bool = False,
            verbose: bool = False,
            num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Cluster cells into cell types using Leiden clustering.

    This function is intended to be run after `shared_neighbors()`; by
    default, it uses `obsp['shared_neighbors']` as the input to the
    clustering.

    Our Leiden implementation is based on
    [GVE-Leiden](https://arxiv.org/abs/2312.13936), a non-deterministic
    parallel version of the Leiden algorithm. To ensure deterministic
    output, our implementation is not parallelized. However, when multiple
    resolutions are specified, clusterings for all resolutions run in
    parallel by default.

    Like Seurat and Scanpy's Leiden clustering, our implementation
    optimizes the objective function corresponding to Reichardt and
    Bornholdt's Potts model (`RBConfigurationVertexPartition` in the
    reference implementation of Leiden clustering, `leidenalg`). This
    formulation extends modularity by introducing a `resolution` parameter;
    with the default `resolution = 1`, it is exactly equivalent to
    maximizing modularity.

    Args:
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored
                   and have their cell-type labels set to `null`.
        resolution: the key parameter of Leiden clustering; must be
                    positive. Larger values result in more clusters. If
                    multiple resolutions are specified, the clustering for
                    each resolution is executed in parallel by default.
        min_cluster_size: the minimum cluster size. Cells in clusters of
                          size less than `min_cluster_size` will be merged
                          into the cluster of their nearest neighbor that
                          is in a cluster of size ≥ `min_cluster_size`. Set
                          `min_cluster_size=1` to disable this merging.
                          Clusters with only one or two cells may occur if
                          they are disconnected from the rest of the shared
                          nearest neighbor graph.
        shared_neighbors_key: the key of `obsp` containing the shared
                              nearest neighbor graph calculated with
                              `shared_neighbors()`, to use as the input for
                              Leiden clustering. Must be symmetric,
                              although this is not checked for, due to
                              speed considerations. Diagonals are assumed
                              to be 1; the actual values are ignored.
        neighbors_key: the key of `obsm` containing the nearest neighbors
                       of each cell calculated with `neighbors()`. This is
                       only used to merge disconnected clusters of size
                       less than `min_cluster_size`. Not used when
                       `min_cluster_size=1`.
        cluster_column: the name of an integer column to be added to obs
                        indicating the cell-type labels. If `N` resolutions
                        are specified, `N` columns named
                        `f'{cluster_column}_0'` through
                        `f'{cluster_column}_{N - 1}'` will be added.
        seed: the random seed to use when clustering
        overwrite: if `True`, overwrite `cluster_column` if already present
                   in `obs`, instead of raising an error
        verbose: whether to print details of the Leiden clustering; must be
                 `False` when running multithreaded
        num_threads: the number of threads to use when clustering.
                     Parallelization takes place across resolutions. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores. In both cases, only as many
                     cores will be used as the number of resolutions
                     specified. Specifying `num_threads=-1` when only one
                     resolution is specified will raise an error, as will
                     specifying a positive value for `num_threads` that is
                     greater than the number of resolutions specified. Does
                     not affect the returned SingleCell dataset's
                     `num_threads`; this will always be the same as the
                     original dataset's `num_threads`.

    Returns:
        A new SingleCell dataset where `obs[cluster_column]` is an Enum
        column containing an integer cell-type label for each cell (`'0'`,
        `'1'`, etc.). Or, if `N` resolutions are specified, a dataset where
        `obs[f'{cluster_column}_0']` through
        `obs[f'{cluster_column}_{N - 1}']` contain `N` sets of cell-type
        labels.

    Note:
        This function may give an incorrect output if you specified a
        custom shared nearest-neighbor graph that a) is non-symmetric (i.e.
        `(shared_neighbors != shared_neighbors.T).nnz`, where
        `shared_neighbors = sc.obsp[shared_neighbors_key]`, is non-zero),
        b) contains explicit zeros (i.e. if
        `(shared_neighbors.data == 0).any()`), or c) contains negative
        values: these are not checked for, due to speed considerations. In
        the unlikely event that your custom shared nearest-neighbor graph
        contains explicit zeros, remove them by running
        `sc.obsp[shared_neighbors_key].eliminate_zeros()` (an in-place
        operation) first.
    """
    # Check that `cluster_column` is a string
    check_type(cluster_column, 'cluster_column', str, 'a string')

    # Check that `overwrite` is Boolean
    check_type(overwrite, 'overwrite', bool, 'Boolean')

    # Get the resolution(s). Check that `cluster_column` (if one
    # resolution) or `f'{cluster_column}_0'` through
    # `f'{cluster_column}_{N - 1}'` (if `N` resolutions) are not already
    # columns of `obs`, unless `overwrite=True`.
    resolutions = to_tuple_checked(resolution, 'resolution', (int, float),
                                   'positive numbers')
    num_resolutions = len(resolutions)
    if num_resolutions == 1:
        # Unwrap the scalar so that a length-1 iterable such as `[0.5]` is
        # validated and passed to `leiden()` exactly like the scalar `0.5`
        resolution = resolutions[0]
        check_bounds(resolution, 'resolution', 0, left_open=True)
        if not overwrite and cluster_column in self._obs:
            error_message = (
                f'cluster_column {cluster_column!r} is already a column '
                f'of obs; did you already run cluster()? Set '
                f'overwrite=True to overwrite.')
            raise ValueError(error_message)
    else:
        resolutions = np.array(resolutions, dtype=np.float32)
        for resolution in resolutions:
            # Zero is rejected too, matching the single-resolution check
            if resolution <= 0:
                error_message = (
                    f'a resolution is {resolution:,}, but all resolutions '
                    f'must be > 0')
                raise ValueError(error_message)
        # The output column names are built once and used both for the
        # overwrite check here and for the output columns below, so the
        # two cannot drift apart
        cluster_columns = [f'{cluster_column}_{resolution_index}'
                           for resolution_index in range(num_resolutions)]
        if not overwrite:
            for composite_cluster_column in cluster_columns:
                if composite_cluster_column in self._obs:
                    error_message = (
                        f'{composite_cluster_column!r} is already a '
                        f'column of obs; did you already run cluster()? '
                        f'Set overwrite=True to overwrite.')
                    raise ValueError(error_message)

    # Check that `min_cluster_size` is a positive integer
    check_type(min_cluster_size, 'min_cluster_size', int,
               'a positive integer')
    check_bounds(min_cluster_size, 'min_cluster_size', 1)

    # Get the shared nearest neighbor graph, and check that it is float32
    check_type(shared_neighbors_key, 'shared_neighbors_key', str,
               'a string')
    if shared_neighbors_key not in self._obsp:
        error_message = (
            f'shared_neighbors_key {shared_neighbors_key!r} is not a key '
            f'of obsp')
        if shared_neighbors_key == 'shared_neighbors':
            error_message += (
                '; did you forget to run shared_neighbors() before '
                'cluster()?')
        raise ValueError(error_message)
    snn_graph = self._obsp[shared_neighbors_key]
    if snn_graph.dtype != np.float32:
        error_message = (
            f'obsp[{shared_neighbors_key!r}] must have data type float32, '
            f'but has data type {str(snn_graph.dtype)!r}')
        raise TypeError(error_message)
    if snn_graph.nnz == 0:
        error_message = f'obsp[{shared_neighbors_key!r}] is an empty graph'
        raise ValueError(error_message)

    # Get the QC column, if not `None`
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')

    # If `min_cluster_size > 1`, get the nearest-neighbor indices, and
    # check that they are uint32 and C-contiguous. If `min_cluster_size=1`,
    # check that `neighbors_key` is not specified (i.e. retains its default
    # value).
    if min_cluster_size > 1:
        check_type(neighbors_key, 'neighbors_key', str, 'a string')
        if neighbors_key not in self._obsm:
            error_message = \
                f'neighbors_key {neighbors_key!r} is not a key of obsm'
            if neighbors_key == 'neighbors':
                error_message += (
                    '; did you forget to run neighbors() before '
                    'cluster()?')
            raise ValueError(error_message)
        neighbors = self._obsm[neighbors_key]
        if neighbors.dtype != np.uint32:
            error_message = (
                f'obsm[{neighbors_key!r}] must have uint32 data type, but '
                f'has data type {str(neighbors.dtype)!r}')
            raise TypeError(error_message)
        if not neighbors.flags['C_CONTIGUOUS']:
            error_message = (
                f'obsm[{neighbors_key!r}] is not C-contiguous; make it '
                f'C-contiguous with '
                f'pipe_obsm_key({neighbors_key!r}, np.ascontiguousarray)')
            raise ValueError(error_message)
    else:
        # Placeholder array passed to the Leiden routines in place of
        # real neighbors, since no merging of small clusters is requested
        neighbors = np.array([[]], dtype=np.uint32)
        if neighbors_key != 'neighbors':
            error_message = \
                'neighbors_key cannot be specified when min_cluster_size=1'
            raise ValueError(error_message)

    # Check that `seed` is an integer
    check_type(seed, 'seed', int, 'an integer')

    # Check that `num_threads` is a positive integer, -1 or `None`. If
    # `None`, set to `min(self.num_threads, len(resolutions))`. If -1, set
    # to `min(os.cpu_count(), len(resolutions))`, but raise an error if the
    # user only specified one resolution. If a positive integer, raise an
    # error if the user specified more threads than resolutions.
    if num_threads is None:
        num_threads = min(self._num_threads, num_resolutions)
    else:
        check_type(num_threads, 'num_threads', int,
                   'a positive integer, -1, or None')
        if num_threads == -1:
            if num_resolutions == 1:
                error_message = (
                    'only one resolution was specified, so num_threads '
                    'must be 1 or None, not -1')
                raise ValueError(error_message)
            num_threads = min(os.cpu_count(), num_resolutions)
        else:
            num_threads = int(num_threads)
            if num_threads <= 0:
                error_message = (
                    f'num_threads is {num_threads:,}, but must be a '
                    f'positive integer, -1, or None')
                raise ValueError(error_message)
            if num_threads > num_resolutions:
                error_message = (
                    f'num_threads is {num_threads:,}, but cannot be '
                    f'greater than the number of resolutions specified '
                    f'({num_resolutions:,})')
                raise ValueError(error_message)

    # Check that `verbose` is Boolean, and `False` when running
    # multithreaded
    check_type(verbose, 'verbose', bool, 'Boolean')
    if verbose and num_threads > 1:
        error_message = 'verbose must be False when running multithreaded'
        raise ValueError(error_message)

    # Subset the nearest-neighbor indices and SNN graph to cells passing
    # QC, if `QC_column` is not `None`. The neighbors are only subset when
    # `min_cluster_size > 1`: otherwise `neighbors` is the (1, 0)
    # placeholder, and indexing it with a length-`len(obs)` mask would
    # raise an IndexError.
    if QC_column is not None:
        QC_column_NumPy = QC_column.to_numpy()
        if min_cluster_size > 1:
            if num_threads == 1:
                neighbors = neighbors[QC_column_NumPy]
            else:
                indices = np.flatnonzero(QC_column_NumPy)
                neighbors = parallel_subset_2d(neighbors, indices,
                                               num_threads)
        snn_graph = snn_graph[ix_symmetric(QC_column_NumPy)]

    # Perform Leiden clustering: once for a single resolution, in a serial
    # loop for multiple resolutions on one thread, or across resolutions
    # in parallel otherwise. Each path sets `too_large` and
    # `self_neighbors`, which flag invalid nearest-neighbor indices.
    num_cells = snn_graph.shape[0]
    if num_resolutions == 1:
        clusters = np.empty(num_cells, dtype=np.uint32)
        num_final_communities, too_large, self_neighbors = leiden(
            data=snn_graph.data, indices=snn_graph.indices,
            indptr=snn_graph.indptr, neighbors=neighbors,
            final_communities=clusters, resolution=resolution,
            min_cluster_size=min_cluster_size, seed=seed, verbose=verbose)
    else:
        clusters = np.empty((num_resolutions, num_cells), dtype=np.uint32)
        num_final_communities = np.empty(num_resolutions, dtype=np.uint32)
        if num_threads == 1:
            for resolution_index, resolution in enumerate(resolutions):
                if verbose:
                    print(f'\nresolution = {resolution}:')
                num_final_communities[resolution_index], too_large, \
                    self_neighbors = leiden(
                        data=snn_graph.data, indices=snn_graph.indices,
                        indptr=snn_graph.indptr, neighbors=neighbors,
                        final_communities=clusters[resolution_index],
                        resolution=resolution,
                        min_cluster_size=min_cluster_size, seed=seed,
                        verbose=verbose)
                # Stop at the first invalid input; the error is raised
                # below
                if too_large or self_neighbors:
                    break
        else:
            too_large, self_neighbors = leiden_multiresolution(
                data=snn_graph.data, indices=snn_graph.indices,
                indptr=snn_graph.indptr, neighbors=neighbors,
                final_communities=clusters,
                num_final_communities=num_final_communities,
                resolutions=resolutions,
                min_cluster_size=min_cluster_size, seed=seed,
                num_threads=num_threads)

    # If any nearest-neighbor indices were out of bounds or equal to the
    # cell's own index, raise an error
    if too_large:
        error_message = (
            f'some nearest-neighbor indices in '
            f'obsm[{neighbors_key!r}] are >= the total number of '
            f'cells, {neighbors.shape[0]:,}. This may happen if you '
            f'subset this SingleCell dataset between neighbors() and '
            f'cluster(); if so, make sure to run neighbors() after, '
            f'not before, subsetting.')
        raise ValueError(error_message)
    if self_neighbors:
        error_message = (
            f'some nearest-neighbor indices in '
            f'obsm[{neighbors_key!r}] indicate that a cell is its own '
            f'neighbor, i.e. obsm[{neighbors_key!r}][i, j] == i for '
            f'some i and j. This may happen if you created '
            f'obsm[{neighbors_key!r}] manually rather than following '
            f'the recommended approach of running neighbors().')
        raise ValueError(error_message)

    # Convert the integer labels to Enum column(s). If `QC_column` was
    # specified, back-project from QCed cells to all cells, filling with
    # `null`: `QC_column.cum_sum() - QC_column` is each QCed cell's row
    # index among the QCed cells.
    if num_resolutions == 1:
        clusters = pl.Series(cluster_column, clusters)\
            .cast(pl.Enum(map(str, range(num_final_communities + 1))))
        if QC_column is not None:
            clusters = pl.when(QC_column)\
                .then(clusters.gather(QC_column.cum_sum() - QC_column))
    else:
        clusters = pl.DataFrame(clusters.T, schema=cluster_columns)\
            .cast({composite_cluster_column:
                       pl.Enum(map(str, range(num_communities + 1)))
                   for composite_cluster_column, num_communities in
                   zip(cluster_columns, num_final_communities)})
        if QC_column is not None:
            clusters = clusters.select(
                pl.when(QC_column).then(pl.all().gather(
                    QC_column.cum_sum() - QC_column)))
    return SingleCell(X=self._X, obs=self._obs.with_columns(clusters),
                      var=self._var, obsm=self._obsm, varm=self._varm,
                      obsp=self._obsp, varp=self._varp, uns=self._uns,
                      num_threads=self._num_threads)
[docs] def harmonize(self, *others: SingleCell, QC_column: SingleCellColumn | None | Sequence[SingleCellColumn | None] = 'passed_QC', batch_column: SingleCellColumn | None | Sequence[SingleCellColumn | None] = None, PC_key: str = 'pca', Harmony_key: str = 'harmony', num_clusters: int | np.integer | None = None, max_iterations: int | np.integer = 10, num_kmeans_iterations: int | np.integer = 25, kmeans_tolerance: int | np.integer | float | np.floating = 1e-2, kmeans_barbar: bool = False, num_init_iterations: int | np.integer = 5, oversampling_factor: int | np.integer | float | np.floating = 1, chunk_size_kmeans: int | np.integer | None = None, max_clustering_iterations: int | np.integer = 5, block_proportion: int | float | np.integer | np.floating = 0.05, tolerance: int | float | np.integer | np.floating = 0.01, early_stopping: bool = False, clustering_tolerance: int | float | np.integer | np.floating = 0.001, theta: int | float | np.integer | np.floating = 2, tau: int | float | np.integer | np.floating = 0, alpha: float | np.floating = 0.2, sigma: int | float | np.integer | np.floating = 0.1, chunk_size_Harmony: int | np.integer | None = None, seed: int | np.integer = 0, original: bool = False, overwrite: bool = False, verbose: bool = True, num_threads: int | np.integer | None = None) -> \ SingleCell | tuple[SingleCell, ...]: """ Harmonize this SingleCell dataset with other datasets, or harmonize multiple batches of the same dataset, with [Harmony2](https://github.com/immunogenomics/harmony). Our implementation differs from the original Harmony2 implementation in three key ways: First, we parallelize Harmony via an innovative nested block strategy. The original implementation partitions cells randomly into blocks, each containing a fixed fraction (`block_proportion`, 5% by default) of the total cells. 
It iterates over each block, subtracting the contribution of the cells in the block to the observed and expected cluster-batch co-occurence matrices `O` and `E`, re-calculating the soft-clustering assignments `R` based on the residual `O` and `E`, then adding back the contribution of the cells in the block to `O` and `E` based on the new `R`. This approach resists straightforward parallelization because updates to `R` for future blocks depend on updates to `O` and `E` from previous blocks, leading to convergence failure if naively parallelized. Instead, we parallelize within blocks by dividing them into chunks of `chunk_size` cells (512 by default). We process each chunk in parallel, subtracting only the `O` and `E` contributions of the chunk itself, updating `R` for the chunk, and, after and, after processing all chunks in the block, add back the `O` and `E` contributions for all chunks based on the updated `R`. Updating `O` and `E` at the end of each block (rather than after processing every cell in the dataset) ensures convergence, while the inner chunking enables parallelization without disrupting convergence. To use the original implementation's chunkless strategy for updating `R`, `O`, and `E`, specify `original=True, num_threads=1`. Second, we reduce the default number of clustering iterations (Harmony's inner loop) from 20 to 5, but always complete all 5 iterations without early stopping based on convergence. In practice, the largest changes to the Harmony objective function result from updating the PCs (outer loop), not updating the soft-clustering assignments (inner loop). Our implementation makes updating the PCs sufficiently fast that there are rapidly diminishing returns from performing lots of clustering updates, versus just skipping directly to updating the PCs after a few clustering iterations. 
By skipping convergence checks, we reduce the number of objective function evaluations (which are expensive) from once per inner iteration to once per outer iteration. To use the original implementation's default of 20 iterations with early stopping, specify `early_stopping=True, max_clustering_iterations=20`. Third, we use random k-means initialization, like version 1 of Harmony, rather than Harmony2's kmeans++ initialization. We implement a parallel version of kmeans++ called [k-means||](https://arxiv.org/abs/1203.6402), but find it does not improve the accuracy of label transfer relative to random initialization and increases runtime. kmeans|| initialization can be enabled with `kmeans_barbar=True`. Args: others: the other SingleCell datasets to harmonize this one with. Can be omitted if harmonizing between batches of the same dataset, but then `batch_column` must be specified. QC_column: an optional Boolean column of `obs` indicating which cells passed QC. Can be a column name, a polars expression, a polars Series, a 1D NumPy array, or a function that takes in this SingleCell dataset and returns a polars Series or 1D NumPy array. Set to `None` to include all cells. Cells failing QC will be ignored and have their Harmony embeddings set to `NaN`. When `others` is specified, `QC_column` can be a length-`1 + len(others)` sequence of columns, expressions, Series, functions, or `None` for each dataset (for `self`, followed by each dataset in `others`). batch_column: an optional String, Enum, Categorical, or integer column of `obs` indicating which batch each cell is from. Can be a column name, a polars expression, a polars Series, a 1D NumPy array, or a function that takes in this SingleCell dataset and returns a polars Series or 1D NumPy array. Each batch will be treated as if it were a distinct dataset; this is exactly equivalent to splitting the dataset with `split_by(batch_column)` and then passing each of the resulting datasets to `harmonize()`. 
Set to `None` to treat each dataset as having a single batch. When `others` is specified, `batch_column` may be a length-`1 + len(others)` sequence of columns, expressions, Series, functions, or `None` for each dataset (for `self`, followed by each dataset in `others`). PC_key: the key of `obsm` containing the principal components calculated with `PCA()`, to use as the input to Harmony Harmony_key: the key of `obsm` where the Harmony embeddings will be stored; will be added in-place to both `self` and each of the datasets in `others`! num_clusters: the number of clusters used in the Harmony algorithm, including in the initial k-means clustering. If not specified, use 100 clusters if ≥3000 cells, 1 cluster if ≤30 cells, and `round(num_cells / 30)` if between 30 and 3000 cells. For datasets of at least 10,000 cells, clusters with fewer than 100 cells will be merged into the adjacent cluster with the nearest centroid, so the actual number of clusters used may be smaller. max_iterations: the maximum number of iterations to run Harmony for, if convergence is not achieved. Defaults to 10, like the original Harmony R package, harmony-pytorch, and harmonypy. Set to `None` to use as many iterations as necessary to achieve convergence. num_kmeans_iterations: the maximum number of iterations of k-means clustering to perform before starting the nearest-neighbor search, stopping early if a relative convergence of `kmeans_tolerance` is reached. Defaults to 25, like the original Harmony R package, harmony-pytorch, and harmonypy. However, unlike these packages, only one initialization is tried rather than 10 to reduce runtime. 
kmeans_tolerance: the relative change in inertia (the sum of squared distances from each cell to its assigned centroid) used to determine whether to stop optimizing the k-means clustering before `num_kmeans_iterations` iterations kmeans_barbar: whether to use [k-means||](https://arxiv.org/abs/1203.6402) initialization (a parallel version of k-means++) to initialize the k-means clustering centroids, instead of random initialization. This is more accurate but takes considerably longer for large datasets and `num_clusters`. num_init_iterations: the number of k-means|| iterations used to initialize the k-means clustering that constitutes the first step of the nearest-neighbor search. k-means|| is a parallel version of the widely used k-means++ initialization scheme for k-means clustering. The default value of 5 is recommended by the [k-means|| paper](https://arxiv.org/abs/1203.6402). Only used when `kmeans_barbar=True`. oversampling_factor: the number of candidate centroids selected, on average, at each of the `num_init_iterations` iterations of k-means||, as a multiple of `num_clusters`. The default value of 1 is the midpoint (in log space) of the values explored by the [k-means|| paper](https://arxiv.org/abs/1203.6402), namely 0.1 to 10. The total number of candidate centroids selected, on average, will be `oversampling_factor * num_clusters + 1`, from which the final `num_clusters` centroids will then be selected via k-means++. Only used when `kmeans_barbar=True`. chunk_size_kmeans: the chunk size to use for distance calculations in the initial k-means clustering. Setting this to a power of 2 is recommended. Defaults to `min(4096, total number of cells)`. max_clustering_iterations: the number of iterations to run the clustering step within each Harmony iteration for, or the maximum number of iterations if `early_stopping=True`. Unlike the original Harmony algorithm, convergence of the clustering is not checked unless `early_stopping=True`. 
Defaults to 5 iterations; this differs from the 20 used by the original harmony R package and harmonypy and the 200 iterations used by harmony-pytorch, which do have convergence checks. Must be 4 or more when `early_stopping=True`, since Harmony's clustering convergence check requires knowing the errors from the past 3 iterations. block_proportion: the proportion of cells to use in each batch update in the clustering step; must be greater than zero and less than or equal to 1 tolerance: the relative tolerance used to determine whether to stop optimizing the Harmony embeddings before `max_iterations` iterations early_stopping: whether to stop clustering before `max_clustering_iterations` iterations if convergence is reached, like in the original Harmony algorithm clustering_tolerance: the relative tolerance used to determine whether to stop clustering before `max_clustering_iterations` iterations. Only used when `early_stopping=True`. theta: the weight of the diversity penalty term in the Harmony objective function; must be non-negative. Larger values result in more diverse clusters. tau: the discounting factor on theta; must be non-negative. By default, `tau = 0`, so there is no discounting. alpha: the scaling factor for the ridge regression penalty used when correcting the principal components to get the Harmony embeddings; must be greater than 0 and less than 1. The ridge penalty `lambda` is determined by `alpha` and the expected number of cells, assuming independence between batches and clusters: `lambda = alpha * expected number of cells`. Smaller values result in more aggressive correction. sigma: the weight of the entropy term in the Harmony objective function; must be non-negative chunk_size_Harmony: the chunk size to use for Harmony. Setting this to a power of 2 is recommended. Defaults to `min(256, total number of cells)`. Not used when `original=True`. 
seed: the random seed to use for the initial k-means clustering original: if `True`, use the original Harmony algorithm's blocking strategy, rather than our nested chunks-within-blocks strategy. `original=True` requires `num_threads=1`. This gives lower memory usage and closer correspondence to the original algorithm, at the cost of a) a moderate (~20-25%) degradation in performance and b) no longer matching the Harmony embeddings produced by the multithreaded version. If `True`, exactly match the results of the multithreaded version when `num_threads=1`. Must be `False` unless `num_threads=1`. overwrite: if `True`, overwrite `Harmony_key` if already present in obsm, instead of raising an error verbose: whether to print details of the harmonization process num_threads: the number of threads to use when concatenating principal components and batch/dataset labels across datasets, for the initial k-means clustering, and for the matrix and matrix-vector multiplications within Harmony. Set `num_threads=-1` to use all available cores, as determined by `os.cpu_count()`, or leave unset to use `self.num_threads` cores. Does not affect the returned SingleCell dataset's `num_threads`; this will always be the same as the original dataset's `num_threads`. `num_threads` will be capped to 64 when running with Scipy linked against OpenBLAS (see warning below). Returns: A length-`1 + len(others)` tuple of SingleCell datasets with the Harmony embeddings stored in `obsm[Harmony_key]`: `self`, followed by each dataset in `others`. Or, if `others` is omitted, a single SingleCell dataset with the Harmony embeddings stored in `obsm[Harmony_key]`. Warning: If you installed Scipy via pip, it will be linked against OpenBLAS, and `harmonize()` will be limited to 64 threads due to the limitations of OpenBLAS. To use more than 64 threads, install Scipy linked against MKL BLAS. 
This is done automatically when installing brisc via conda, but you can also do it manually via `conda install "libblas=*=*mkl" scipy`. """ # If `others` was specified, check that all elements of `others` are # SingleCell datasets if others: check_types(others, 'others', SingleCell, 'SingleCell datasets') elif batch_column is None: error_message = \ 'others cannot be empty unless batch_column is specified' raise ValueError(error_message) datasets = [self] + list(others) # Check that `Harmony_key` is a string check_type(Harmony_key, 'Harmony_key', str, 'a string') # Check that `overwrite` is Boolean check_type(overwrite, 'overwrite', bool, 'Boolean') # Check that `Harmony_key` is not already in `obsm` for any dataset, # unless `overwrite=True` suffix = ' for at least one dataset' if others else '' if not overwrite and \ any(Harmony_key in dataset._obsm for dataset in datasets): error_message = ( f'Harmony_key {Harmony_key!r} is already a key of ' f'obsm{suffix}; did you already run harmonize()? 
Set ' f'overwrite=True to overwrite.') raise ValueError(error_message) # Get `QC_column` and `batch_column` from every dataset, if not `None` QC_columns = SingleCell._get_columns( 'obs', datasets, QC_column, 'QC_column', pl.Boolean, allow_missing=True) QC_columns_NumPy = [QC_col.to_numpy() if QC_col is not None else None for QC_col in QC_columns] batch_columns = SingleCell._get_columns( 'obs', datasets, batch_column, 'batch_column', (pl.String, pl.Enum, pl.Categorical, 'integer'), QC_columns=QC_columns) # Check that `PC_key` is a key of `obsm` for every dataset check_type(PC_key, 'PC_key', str, 'a string') if not all(PC_key in dataset._obsm for dataset in datasets): error_message = ( f'PC_key {PC_key!r} is not a column of obs{suffix}; did you ' f'forget to run PCA() before harmonize()?') raise ValueError(error_message) # If `num_clusters` is not `None`, check that it is a positive integer # We will assign `num_clusters` its default value (if `None`) and check # its upper bound later, once we know the total number of cells across # all datasets. 
if num_clusters is not None: check_type(num_clusters, 'num_clusters', int, 'a positive integer') check_bounds(num_clusters, 'num_clusters', 1) # Check that `max_iterations` is `None` or a positive integer; if # `None`, set to `INT32_MAX` if max_iterations is None: max_iterations = 2_147_483_647 else: check_type(max_iterations, 'max_iterations', int, 'a positive integer') check_bounds(max_iterations, 'max_iterations', 1) # Check that `num_kmeans_iterations` is a positive integer check_type(num_kmeans_iterations, 'num_kmeans_iterations', int, 'a positive integer') check_bounds(num_kmeans_iterations, 'num_kmeans_iterations', 1) # Check that `kmeans_tolerance` is a positive number check_type(kmeans_tolerance, 'kmeans_tolerance', (int, float), 'a positive number') check_bounds(kmeans_tolerance, 'kmeans_tolerance', 0, left_open=True) # Check that `kmeans_barbar` is Boolean check_type(kmeans_barbar, 'kmeans_barbar', bool, 'Boolean') # Check that `num_init_iterations` is a positive integer check_type(num_init_iterations, 'num_init_iterations', int, 'a positive integer') check_bounds(num_init_iterations, 'num_init_iterations', 1) # Check that `oversampling_factor` is a positive number check_type(oversampling_factor, 'oversampling_factor', (int, float), 'a positive number') check_bounds(oversampling_factor, 'oversampling_factor', 0, left_open=True) # If `kmeans_barbar=False`, check that `num_init_iterations` and # `oversampling_factor` have their default values if not kmeans_barbar: if num_init_iterations != 5: error_message = ( 'num_init_iterations can only be specified when ' 'kmeans_barbar=True') raise ValueError(error_message) if oversampling_factor != 1: error_message = ( 'oversampling_factor can only be specified when ' 'kmeans_barbar=True') raise ValueError(error_message) # Check that `max_clustering_iterations` is a positive integer check_type(max_clustering_iterations, 'max_clustering_iterations', int, 'a positive integer') check_bounds(max_clustering_iterations, 
'max_clustering_iterations', 1) # Check that `block_proportion` is a number and that # `0 < block_proportion <= 1` check_type(block_proportion, 'block_proportion', (int, float), 'a number greater than zero and less than or equal to 1') check_bounds(block_proportion, 'block_proportion', 0, 1, left_open=True) # Check that `tolerance` is a positive number; if an integer, cast it # to a float check_type(tolerance, 'tolerance', (int, float), 'a positive number') check_bounds(tolerance, 'tolerance', 0, left_open=True) tolerance = float(tolerance) # Check that `early_stopping` is Boolean check_type(early_stopping, 'early_stopping', bool, 'Boolean') # Check that `clustering_tolerance` is a positive number; if an # integer, cast it to a float check_type(clustering_tolerance, 'clustering_tolerance', (int, float), 'a positive number') check_bounds(clustering_tolerance, 'clustering_tolerance', 0, left_open=True) clustering_tolerance = float(clustering_tolerance) # If `early_stopping=False`, check that `clustering_tolerance` has its # default value if not early_stopping: if clustering_tolerance != 0.001: error_message = ( 'clustering_tolerance can only be specified when ' 'early_stopping=True') raise ValueError(error_message) # Check that `theta` and `tau` are non-negative numbers; if either is # an integer, cast it to a float for parameter, parameter_name in (theta, 'theta'), (tau, 'tau'): check_type(parameter, parameter_name, (int, float), 'a non-negative number') check_bounds(parameter, parameter_name, 0) theta = float(theta) tau = float(tau) # Check that `alpha` is greater than 0 and less than 1 check_type(alpha, 'alpha', float, 'a number greater than 0 and less than 1') check_bounds(alpha, 'alpha', 0, 1, left_open=True, right_open=True) # Check that `sigma` is a non-negative number; if an integer, cast it # to a float check_type(sigma, 'sigma', (int, float), 'a non-negative number') check_bounds(sigma, 'sigma', 0) sigma = float(sigma) # If `chunk_size_kmeans` and/or 
`chunk_size_Harmony` are not `None`, # check that they are positive integers. We will check their upper # bounds and set their default values (if `None`) later, once we know # the total number of cells across all datasets. for chunk_size, chunk_size_name in \ (chunk_size_kmeans, 'chunk_size_kmeans'), \ (chunk_size_Harmony, 'chunk_size_Harmony'): if chunk_size is not None: check_type(chunk_size, chunk_size_name, int, 'a positive integer') check_bounds(chunk_size, chunk_size_name, 1) # Check that `seed` is an integer check_type(seed, 'seed', int, 'an integer') # Check that `verbose` is Boolean check_type(verbose, 'verbose', bool, 'Boolean') # Check that `num_threads` is a positive integer, -1 or `None`; if # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()` num_threads = self._process_num_threads(num_threads) # Cap to 64 threads for OpenBLAS to avoid the error "OpenBLAS : Program # is Terminated. Because you tried to allocate too many memory # regions". if num_threads > 64 and any(lib.get('internal_api') == 'openblas' for lib in threadpool_info()): num_threads = 64 # Check that `original` is Boolean, and `False` unless `num_threads=1` check_type(original, 'original', bool, 'Boolean') if original and num_threads != 1: error_message = 'original must be False unless num_threads is 1' raise ValueError(error_message) # Concatenate PCs across datasets; get labels indicating which rows of # these concatenated PCs come from each dataset or batch. Check that # the PCs are float32 and C-contiguous and all have the same width. 
all_PCs = [dataset._obsm[PC_key] for dataset in datasets] for PCs in all_PCs: dtype = PCs.dtype if dtype != np.float32: error_message = ( f'obsm[{PC_key!r}].dtype is {dtype!r}{suffix}, but must ' f'be float32') raise TypeError(error_message) if not PCs.flags['C_CONTIGUOUS']: error_message = ( f'obsm[{PC_key!r}].dtype is not C-contiguous{suffix}; ' f'make it C-contiguous with pipe_obsm_key({PC_key!r}, ' f'np.ascontiguousarray)') raise ValueError(error_message) width = all_PCs[0].shape[1] for PCs in all_PCs[1:]: if PCs.shape[1] != width: error_message = ( f"two datasets' PCs have different numbers of columns " f"({width:,} vs {PCs.shape[1]:,}") raise ValueError(error_message) if QC_column is not None: all_PCs = [PCs[QCed] if QCed is not None else PCs for PCs, QCed in zip(all_PCs, QC_columns_NumPy)] num_cells_per_dataset = np.array(list(map(len, all_PCs))) if batch_column is None: batch_labels = np.repeat(np.arange(len(num_cells_per_dataset), dtype=np.uint32), num_cells_per_dataset) num_batches = len(num_cells_per_dataset) else: batch_labels = [] batch_index = 0 for dataset, QC_col, batch_col in \ zip(datasets, QC_columns, batch_columns): if batch_col is not None: if QC_col is not None: batch_col = batch_col.filter(QC_col) if batch_col.dtype in (pl.String, pl.Enum, pl.Categorical): if batch_col.dtype != pl.Enum: batch_col = batch_col \ .cast(pl.Enum(batch_col.unique().drop_nulls())) batch_col = batch_col.to_physical() if batch_col.dtype != pl.UInt32: batch_col = batch_col.cast(pl.UInt32) batch_labels.append(batch_col.to_numpy() + batch_index) batch_index += batch_col.n_unique() else: batch_labels.append(np.full(batch_index, len(dataset) if QC_col is None else QC_col.sum(), dtype=np.float32)) batch_index += 1 batch_labels = concatenate(batch_labels, num_threads=num_threads) num_batches = batch_index PCs = concatenate(all_PCs, num_threads=num_threads) num_cells, num_PCs = PCs.shape # If `num_clusters` is `None`, set it to `num_cells / 30`, rounded to # the nearest 
integer and clipped to be between 1 and 100. Otherwise, # check that it is less than the total number of cells across all # datasets. if num_clusters is None: num_clusters = max(min(100, int(round(num_cells / 30))), 1) elif num_clusters >= num_cells: error_message = ( f'num_clusters is {num_clusters:,}, but must be less than the ' f'total number of cells across all datasets ({num_cells:,})') raise ValueError(error_message) # If `chunk_size_kmeans` is `None`, set it to `min(4096, num_cells)`. # Otherwise, check that it is less than the total number of cells # across all datasets. if chunk_size_kmeans is None: chunk_size_kmeans = min(4096, num_cells) elif chunk_size_kmeans >= num_cells: error_message = ( f'chunk_size_kmeans is {chunk_size_kmeans:,}, but must be ' f'less than the total number of cells across all datasets ' f'({num_cells:,})') raise ValueError(error_message) # If `chunk_size_Harmony` is `None`, set it to `min(256, num_cells)`. # Otherwise, check that it is less than the total number of cells # across all datasets. if chunk_size_Harmony is None: chunk_size_Harmony = min(256, num_cells) elif chunk_size_Harmony >= num_cells: error_message = ( f'chunk_size_Harmony is {chunk_size_Harmony:,}, but must be ' f'less than the total number of cells across all datasets ' f'({num_cells:,})') raise ValueError(error_message) # Get `Z`, the row-normalized PCs; allocate temporary buffers # `cluster_labels` and `min_distances`, used in k-means if num_threads == 1: Z = np.empty((num_cells, num_PCs), dtype=np.float32) cluster_labels = np.empty(num_cells, dtype=np.uint32) min_distances = np.empty(num_cells, dtype=np.float32) else: Z = numa_zeros((num_cells, num_PCs), dtype=np.float32) cluster_labels = numa_zeros(num_cells, dtype=np.uint32) min_distances = numa_zeros(num_cells, dtype=np.float32) normalize_rows(arr=PCs, out=Z, num_threads=num_threads) # Run k-means clustering on `Z`. 
Since `Y` and `Y_new` are swapped # every iteration, `Y_new` will contain the final `Y` when doing an odd # number of k-means iterations; if so, swap them at the end. Use # `threadpool_limits()` to run BLAS single-threaded (directly in the # single-threaded case, by making everything single-threaded including # BLAS; indirectly in the parallel case, by disabling nested # parallelism inside `prange()`). Y = np.empty((num_clusters, num_PCs), dtype=np.float32) Y_new = np.empty((num_clusters, num_PCs), dtype=np.float32) with threadpool_limits(num_threads): iterations_until_convergence, num_clusters = kmeans( X=Z, cluster_labels=cluster_labels, centroids=Y, centroids_new=Y_new, num_cells_per_cluster=np.empty(num_clusters, dtype=np.uint32), min_distances=min_distances, cell_norms=np.array([], dtype=np.float32), kmeans_barbar=kmeans_barbar, num_init_iterations=num_init_iterations, num_kmeans_iterations=num_kmeans_iterations, tolerance=kmeans_tolerance, oversampling_factor=oversampling_factor, seed=seed, chunk_size=chunk_size_kmeans, num_threads=num_threads) if iterations_until_convergence & 1: Y = Y_new del Y_new, cluster_labels, min_distances Y = Y[:num_clusters] # Run Harmony. Unlike above, here we need to use single-threaded matrix # multiplication to ensure consistent floating-point roundoff. Use # `threadpool_limits()` to run BLAS single-threaded (directly in the # single-threaded case, by making everything single-threaded including # BLAS; indirectly in the parallel case, by disabling nested # parallelism inside `prange()`). 
if num_threads == 1: R = np.empty((num_cells, num_clusters), dtype=np.float32) else: R = numa_zeros((num_cells, num_clusters), dtype=np.float32) if original: with threadpool_limits(1): harmony_original( PCs=PCs, Z=Z, Y=Y, R=R, batch_labels=batch_labels, num_batches=num_batches, max_iterations=max_iterations, max_clustering_iterations=max_clustering_iterations, block_proportion=block_proportion, tolerance=tolerance, early_stopping=early_stopping, clustering_tolerance=clustering_tolerance, theta=theta, tau=tau, alpha=alpha, sigma=sigma, chunk_size=chunk_size_Harmony, seed=seed, verbose=verbose) else: num_threads = min(num_threads, num_cells) with threadpool_limits(num_threads): harmony(PCs=PCs, Z=Z, Y=Y, R=R, batch_labels=batch_labels, num_batches=num_batches, max_iterations=max_iterations, max_clustering_iterations=max_clustering_iterations, block_proportion=block_proportion, tolerance=tolerance, early_stopping=early_stopping, clustering_tolerance=clustering_tolerance, theta=theta, tau=tau, alpha=alpha, sigma=sigma, chunk_size=chunk_size_Harmony, seed=seed, verbose=verbose, num_threads=num_threads) del batch_labels, PCs, Y # Store each dataset's Harmony embedding in its `obsm` if not others: # just one dataset QC_col = QC_columns[0] Harmony_embedding = Z # If `QC_col` is not `None`, back-project from QCed cells to all # cells, filling with `NaN` if QC_col is not None: Harmony_embedding_QCed = Harmony_embedding Harmony_embedding = np.full( (len(self), Harmony_embedding_QCed.shape[1]), np.nan, dtype=np.float32) Harmony_embedding[QC_columns_NumPy[0]] = Harmony_embedding_QCed return SingleCell( X=dataset._X, obs=dataset._obs, var=dataset._var, obsm=dataset._obsm | {Harmony_key: Harmony_embedding}, varm=self._varm, uns=self._uns, num_threads=self._num_threads) else: for dataset_index, (dataset, QC_col, num_cells, end_index) in \ enumerate(zip(datasets, QC_columns_NumPy, num_cells_per_dataset, num_cells_per_dataset.cumsum())): start_index = end_index - num_cells 
dataset_Harmony_embedding = Z[start_index:end_index] # If `QC_col` is not `None` for this dataset, back-project from # QCed cells to all cells, filling with `NaN` if QC_col is not None: dataset_Harmony_embedding_QCed = dataset_Harmony_embedding dataset_Harmony_embedding = np.full( (len(dataset), dataset_Harmony_embedding_QCed.shape[1]), np.nan, dtype=np.float32) dataset_Harmony_embedding[QC_col] = \ dataset_Harmony_embedding_QCed datasets[dataset_index] = SingleCell( X=dataset._X, obs=dataset._obs, var=dataset._var, obsm=dataset._obsm | { Harmony_key: dataset_Harmony_embedding}, varm=self._varm, uns=self._uns, num_threads=self._num_threads) return tuple(datasets)
[docs] def label_transfer_from( self, other: SingleCell, original_cell_type_column: SingleCellColumn, *, QC_column: SingleCellColumn | None = 'passed_QC', other_QC_column: SingleCellColumn | None = 'passed_QC', Harmony_key: str = 'harmony', cell_type_column: str = 'cell_type', confidence_column: str | None = None, next_best: bool = False, next_best_cell_type_column: str | None = None, next_best_confidence_column: str | None = None, num_neighbors: int | np.integer = 20, num_clusters: int | np.integer | None = None, num_clusters_searched: int | np.integer | None = None, num_kmeans_iterations: int | np.integer = 2, kmeans_tolerance: int | np.integer | float | np.floating = 1e-2, kmeans_barbar: bool = False, num_init_iterations: int | np.integer = 5, oversampling_factor: int | np.integer | float | np.floating = 1, chunk_size_kmeans: int | np.integer | None = None, chunk_size_search: int | np.integer | None = None, seed: int | np.integer = 0, overwrite: bool = False, verbose: bool = True, num_threads: int | np.integer | None = None) -> SingleCell: """ Transfer cell-type labels from another dataset to this one, using the two datasets' Harmony embeddings from `harmonize()`. For each cell in `self`, the transferred cell-type label is the most common cell-type label among the `num_neighbors` cells in `other` with the nearest Harmony embeddings. The cell-type confidence is the fraction of these neighbors that share this most common cell-type label. The nearest-neighbor search is conducted using the same method as `neighbors()`, with one crucial difference: whereas `neighbors()` searches for a cell's nearest neighbors in its own dataset, this function searches for a cell's nearest neighbors in another dataset, i.e. `other`. Args: other: the dataset to transfer cell-type labels from original_cell_type_column: a String, Enum, Categorical, or integer column of `other.obs` containing cell-type labels. 
Can be a column name, a polars expression, a polars Series, a 1D NumPy array, or a function that takes in `other` and returns a polars Series or 1D NumPy array. QC_column: an optional Boolean column of `self.obs` indicating which cells passed QC. Can be a column name, a polars expression, a polars Series, a 1D NumPy array, or a function that takes in `self` and returns a polars Series or 1D NumPy array. Set to `None` to include all cells. Cells failing QC will have their cell-type labels and confidences set to `null`. other_QC_column: an optional Boolean column of `other.obs` indicating which cells passed QC. Can be a column name, a polars expression, a polars Series, a 1D NumPy array, or a function that takes in `other` and returns a polars Series or 1D NumPy array. Set to `None` to include all cells. Cells failing QC will be ignored during the label transfer. Harmony_key: the key of `self.obsm` and `other.obsm` containing the Harmony embeddings for each dataset cell_type_column: the name of a column to be added to `self.obs` indicating each cell's most likely cell type, i.e. the most common cell-type label among the cell's `num_neighbors` nearest neighbors in `other` confidence_column: the name of a column to be added to `self.obs` indicating each cell's cell-type confidence, i.e. the fraction of the cell's `num_neighbors` nearest neighbors in `other` that share the most common cell-type label. If multiple cell types are equally common among the nearest neighbors, tiebreak based on which of them is most common in `original_cell_type_column`. If `None`, defaults to `f'{cell_type_column}_confidence'`. next_best: whether to also compute each cell's second-most likely cell type and confidence, or just its most likely next_best_cell_type_column: the name of a column to be added to `self.obs` indicating each cell's second-most likely cell type, i.e. the second-most common cell-type label among the cell's `num_neighbors` nearest neighbors in `other`.
If `None`, defaults to `f'next_best_{cell_type_column}'`. Can only be specified when `next_best=True`. next_best_confidence_column: the name of a column to be added to `self.obs` indicating the confidence of each cell's second-most likely cell type, i.e. the fraction of the cell's `num_neighbors` nearest neighbors in `other` that share the second-most common cell-type label. If multiple cell types are equally common among the nearest neighbors, tiebreak based on which of them is most common in `original_cell_type_column`. If `None`, defaults to `f'next_best_{cell_type_column}_confidence'`. Can only be specified when `next_best=True`. num_neighbors: the number of nearest neighbors to use when determining a cell's label. All cell-type confidences will be multiples of `1 / num_neighbors`. num_clusters: the number of k-means clusters to use during the nearest-neighbor search. Must be less than the number of cells. If `None`, will be set to `ceil(min(4 * sqrt(num_cells), num_cells / 100))` clusters, i.e. the minimum of four times the square root of the number of cells in `other` and 1% of the number of cells in `other`, rounding up. The core of the heuristic, `4 * sqrt(num_cells)`, is the low end of the range [recommended](https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index) by faiss, 4 to 16 times the square root. However, faiss also recommends using between 39 and 256 data points per centroid when training the k-means clustering used in the k-nearest neighbors search. To avoid going below 39, we switch to using `num_cells / 100` centroids for small datasets (fewer than 640,000 cells), since 100 is the midpoint of 39 and 256 in log space. For datasets of at least 10,000 cells, clusters with fewer than 100 cells will be merged into the adjacent cluster with the nearest centroid, so the actual number of clusters used may be smaller. num_clusters_searched: the number of a cell's nearest clusters to search; must be between 1 and `num_clusters`.
Defaults to `min(64, num_clusters)`. num_kmeans_iterations: the maximum number of iterations of k-means clustering to perform before starting the nearest-neighbor search, stopping early if a relative convergence of `kmeans_tolerance` is reached kmeans_tolerance: the relative change in inertia (the sum of squared distances from each cell to its assigned centroid) used to determine whether to stop optimizing the k-means clustering before `num_kmeans_iterations` iterations kmeans_barbar: whether to use [k-means||](https://arxiv.org/abs/1203.6402) initialization (a parallel version of k-means++) to initialize the k-means clustering centroids, instead of random initialization. This is more accurate but takes considerably longer for large datasets and `num_clusters`. num_init_iterations: the number of k-means|| iterations used to initialize the k-means clustering that constitutes the first step of the nearest-neighbor search. k-means|| is a parallel version of the widely used k-means++ initialization scheme for k-means clustering. The default value of 5 is recommended by the [k-means|| paper](https://arxiv.org/abs/1203.6402). Only used when `kmeans_barbar=True`. oversampling_factor: the number of candidate centroids selected, on average, at each of the `num_init_iterations` iterations of k-means||, as a multiple of `num_clusters`. The default value of 1 is the midpoint (in log space) of the values explored by the [k-means|| paper](https://arxiv.org/abs/1203.6402), namely 0.1 to 10. The total number of candidate centroids selected, on average, will be `oversampling_factor * num_clusters + 1`, from which the final `num_clusters` centroids will then be selected via k-means++. Only used when `kmeans_barbar=True`. chunk_size_kmeans: the chunk size used for distance calculations during k-means clustering, and also during the per-query centroid ranking step of the nearest-neighbor search. Setting this to a power of 2 is recommended. 
Defaults to `min(4096, number of cells in other)`. chunk_size_search: the chunk size used to group query cells together during the nearest-neighbor search. Overly small values will tend to increase runtime by reducing the reuse of information during the search, whereas overly large values will lead to excessive memory use. Defaults to `min(256, number of cells in other)`. seed: the random seed to use when finding nearest neighbors overwrite: if `True`, overwrite `cell_type_column` and/or `confidence_column` if already present in this dataset's obs, instead of raising an error verbose: whether to print details of the nearest-neighbor search num_threads: the number of threads to use for the nearest-neighbor search and label transfer. Set `num_threads=-1` to use all available cores, as determined by `os.cpu_count()`, or leave unset to use `self.num_threads` cores. Does not affect the returned SingleCell dataset's `num_threads`; this will always be the same as the original dataset's `num_threads`. `num_threads` will be capped to 64 when running with Scipy linked against OpenBLAS (see warning below). Returns: `self`, but with two columns added to `obs`: `cell_type_column`, containing the transferred cell-type labels, and `confidence_column`, containing the cell-type confidences. If `next_best=True`, also adds the columns `next_best_cell_type_column` and `next_best_confidence_column`, containing the second-most likely cell type and its confidence. Warning: If you installed Scipy via pip, it will be linked against OpenBLAS, and `label_transfer_from()` will be limited to 64 threads due to the limitations of OpenBLAS. To use more than 64 threads, install Scipy linked against MKL BLAS. This is done automatically when installing brisc via conda, but you can also do it manually via `conda install "libblas=*=*mkl" scipy`. 
""" # Check that `cell_type_column` is a string check_type(cell_type_column, 'cell_type_column', str, 'a string') # Check that `confidence_column` is a string or `None`; if `None`, set # to `f'{cell_type_column}_confidence'`. if confidence_column is None: confidence_column = f'{cell_type_column}_confidence' else: check_type(confidence_column, 'confidence_column', str, 'a string or None') # Check that `next_best` is Boolean check_type(next_best, 'next_best', bool, 'Boolean') # If `next_best=False`, check that `next_best_cell_type_column` and # `next_best_confidence_column` are `None`. If `next_best=True`, check # that they are strings or `None`; if `None, set to # `f'next_best_{cell_type_column}'` and # `f'next_best_{cell_type_column}_confidence'` respectively. if next_best: if next_best_cell_type_column is None: next_best_cell_type_column = f'next_best_{cell_type_column}' else: check_type(next_best_cell_type_column, 'next_best_cell_type_column', str, 'a string or None') if next_best_confidence_column is None: next_best_confidence_column = \ f'next_best_{cell_type_column}_confidence' else: check_type(next_best_confidence_column, 'next_best_confidence_column', str, 'a string or None') elif next_best_cell_type_column is not None: error_message = \ 'next_best_cell_type_column must be None unless next_best=True' raise ValueError(error_message) elif next_best_confidence_column is not None: error_message = ( 'next_best_confidence_column must be None unless ' 'next_best=True') raise ValueError(error_message) # Check that `overwrite` is Boolean check_type(overwrite, 'overwrite', bool, 'Boolean') # If `overwrite=False`, check that `cell_type_column` and # `confidence_column` (and `next_best_cell_type_column` and # `next_best_confidence_column`, if `next_best=True`) are not already # columns of `self.obs` if not overwrite: for column, column_name in ( (cell_type_column, 'cell_type_column'), (confidence_column, 'confidence_column')): if column in self._obs: error_message = ( 
f'{column_name} {column!r} is already a column ' f'of obs; did you already run label_transfer_from()? ' f'Set overwrite=True to overwrite.') raise ValueError(error_message) for column, column_name in ( (next_best_cell_type_column, 'next_best_cell_type_column'), (next_best_confidence_column, 'next_best_confidence_column')): if column in self._obs: error_message = ( f'{column_name} {column!r} is already a column ' f'of obs; did you already run label_transfer_from()? ' f'Set overwrite=True to overwrite, or next_best=False ' f'to not compute {column_name}.') raise ValueError(error_message) # Check that `other` is a SingleCell dataset check_type(other, 'other', SingleCell, 'a SingleCell dataset') # Get `QC_column` from `self` and `other_QC_column` from `other` if QC_column is not None: QC_column = self._get_column( 'obs', QC_column, 'QC_column', pl.Boolean, allow_missing=QC_column == 'passed_QC') if other_QC_column is not None: other_QC_column = other._get_column( 'obs', other_QC_column, 'other_QC_column', pl.Boolean, allow_missing=other_QC_column == 'passed_QC') # Get `original_cell_type_column` from `other` original_original_cell_type_column = original_cell_type_column original_cell_type_column = other._get_column( 'obs', original_cell_type_column, 'original_cell_type_column', (pl.String, pl.Enum, pl.Categorical, 'integer'), QC_column=other_QC_column) # If `other_QC_column` was specified, filter the cell type labels in # `original_cell_type_column` to cells passing QC if other_QC_column is not None: original_cell_type_column = \ original_cell_type_column.filter(other_QC_column) # Check that `original_cell_type_column` has at least two distinct cell # types most_common_cell_types = \ original_cell_type_column.value_counts(sort=True).to_series() if len(most_common_cell_types) == 1: original_cell_type_column_description = \ SingleCell._describe_column('original_cell_type_column', original_original_cell_type_column) error_message = ( 
f'{original_cell_type_column_description} must have at least ' f'two distinct cell types') if other_QC_column is not None: error_message += ' after filtering to cells passing QC' raise ValueError(error_message) # Check that `Harmony_key` is a string and in both `self.obsm` and # `other.obsm` check_type(Harmony_key, 'Harmony_key', str, 'a string') datasets = (self, 'self'), (other, 'other') for dataset, dataset_name in datasets: if Harmony_key not in dataset._obsm: error_message = ( f'Harmony_key {Harmony_key!r} is not a column of ' f'{dataset_name}.obs; did you forget to run harmonize() ' f'before label_transfer_from()?') raise ValueError(error_message) # Check that `num_kmeans_iterations` is a positive integer check_type(num_kmeans_iterations, 'num_kmeans_iterations', int, 'a positive integer') # Check that `kmeans_tolerance` is a positive number check_type(kmeans_tolerance, 'kmeans_tolerance', (int, float), 'a positive number') check_bounds(kmeans_tolerance, 'kmeans_tolerance', 0, left_open=True) # Check that `kmeans_barbar` is Boolean check_type(kmeans_barbar, 'kmeans_barbar', bool, 'Boolean') # Check that `num_init_iterations` is a positive integer check_type(num_init_iterations, 'num_init_iterations', int, 'a positive integer') check_bounds(num_init_iterations, 'num_init_iterations', 1) # Check that `oversampling_factor` is a positive number check_type(oversampling_factor, 'oversampling_factor', (int, float), 'a positive number') check_bounds(oversampling_factor, 'oversampling_factor', 0, left_open=True) # If `kmeans_barbar=False`, check that `num_init_iterations` and # `oversampling_factor` have their default values if not kmeans_barbar: if num_init_iterations != 5: error_message = ( 'num_init_iterations can only be specified when ' 'kmeans_barbar=True') raise ValueError(error_message) if oversampling_factor != 1: error_message = ( 'oversampling_factor can only be specified when ' 'kmeans_barbar=True') raise ValueError(error_message) # Check that `seed` is an 
integer check_type(seed, 'seed', int, 'an integer') # Check that `num_threads` is a positive integer, -1 or `None`; if # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()` num_threads = self._process_num_threads(num_threads) # Cap to 64 threads for OpenBLAS to avoid the error "OpenBLAS : Program # is Terminated. Because you tried to allocate too many memory # regions". if num_threads > 64 and any(lib.get('internal_api') == 'openblas' for lib in threadpool_info()): num_threads = 64 # Check that `verbose` is Boolean check_type(verbose, 'verbose', bool, 'Boolean') # Get the Harmony embeddings for self and other; check that they are # float32 and C-contiguous and have the same width if QC_column is None: self_Harmony_embeddings = self._obsm[Harmony_key] else: if num_threads == 1: self_Harmony_embeddings = \ self._obsm[Harmony_key][QC_column.to_numpy()] else: indices = np.flatnonzero(QC_column.to_numpy()) self_Harmony_embeddings = parallel_subset_2d( self._obsm[Harmony_key], indices, num_threads) if other_QC_column is None: other_Harmony_embeddings = other._obsm[Harmony_key] else: if num_threads == 1: other_Harmony_embeddings = \ other._obsm[Harmony_key][other_QC_column.to_numpy()] else: indices = np.flatnonzero(other_QC_column.to_numpy()) other_Harmony_embeddings = parallel_subset_2d( other._obsm[Harmony_key], indices, num_threads) if self_Harmony_embeddings.dtype != np.float32: error_message = ( f'obsm[{Harmony_key!r}].dtype is ' f'{self_Harmony_embeddings.dtype!r}, but must be float32') raise TypeError(error_message) if not self_Harmony_embeddings.flags['C_CONTIGUOUS']: error_message = ( f'obsm[{Harmony_key!r}] is not C-contiguous; make it ' f'C-contiguous with ' f'pipe_obsm_key({Harmony_key!r}, np.ascontiguousarray)') raise ValueError(error_message) if other_Harmony_embeddings.dtype != np.float32: error_message = ( f'other.obsm[{Harmony_key!r}].dtype is ' f'{other_Harmony_embeddings.dtype!r}, but must be float32') raise TypeError(error_message) if 
not other_Harmony_embeddings.flags['C_CONTIGUOUS']: error_message = ( f'other.obsm[{Harmony_key!r}] is not C-contiguous; make it ' f'C-contiguous with ' f'pipe_obsm_key({Harmony_key!r}, np.ascontiguousarray)') raise ValueError(error_message) num_dimensions = self_Harmony_embeddings.shape[1] if other_Harmony_embeddings.shape[1] != num_dimensions: error_message = ( f"the two datasets' Harmony embeddings have different numbers " f"of columns ({num_dimensions:,} vs " f"{other_Harmony_embeddings.shape[1]:,}") raise ValueError(error_message) # Get the number of cells in `self` and `other` num_self = len(self_Harmony_embeddings) num_other = len(other_Harmony_embeddings) # Check that `num_neighbors` is between 1 and `num_other - 1` check_type(num_neighbors, 'num_neighbors', int, 'a positive integer') if not 1 <= num_neighbors < num_other: error_message = ( f'num_neighbors is {num_neighbors:,}, but must be ≥ 1 and ' f'less than the number of cells in other ({num_other:,})') raise ValueError(error_message) # Check that `num_clusters` is between 1 and `num_other`; if `None`, # set to `ceil(min(4 * sqrt(num_other), num_other / 100)))` if num_clusters is None: num_clusters = \ int(np.ceil(min(4 * np.sqrt(num_other), num_other / 100))) else: check_type(num_clusters, 'num_clusters', int, 'a positive integer') if not 1 <= num_clusters <= num_other: error_message = ( f'num_clusters is {num_clusters:,}, but must be ≥ 1 and ≤ ' f'the number of cells in other ({num_other:,})') raise ValueError(error_message) # Check that `num_clusters_searched` is between 1 and `num_clusters`; # if `None`, set to `min(64, num_clusters)` if num_clusters_searched is None: num_clusters_searched = min(64, num_clusters) else: check_type(num_clusters_searched, 'num_clusters_searched', int, 'a positive integer') if not 1 <= num_clusters_searched <= num_clusters: error_message = ( f'num_clusters_searched is {num_clusters_searched:,}, but ' f'must be ≥ 1 and ≤ num_clusters ({num_clusters:,})') raise 
ValueError(error_message) # Check that `chunk_size_kmeans` is between 1 and `num_other`; if # `None`, set to `min(4096, num_other)` if chunk_size_kmeans is None: chunk_size_kmeans = min(4096, num_other) else: check_type(chunk_size_kmeans, 'chunk_size_kmeans', int, 'a positive integer') if not 1 <= chunk_size_kmeans <= num_other: error_message = ( f'chunk_size_kmeans is {chunk_size_kmeans:,}, but must ' f'be ≥ 1 and ≤ the number of cells in other ' f'({num_other:,})') raise ValueError(error_message) # Check that `chunk_size_search` is between 1 and `num_other`; if # `None`, set to `min(256, num_other)` if chunk_size_search is None: chunk_size_search = min(256, num_other) else: check_type(chunk_size_search, 'chunk_size_search', int, 'a positive integer') if not 1 <= chunk_size_search <= num_other: error_message = ( f'chunk_size_search is {chunk_size_search:,}, but must be ' f'≥ 1 and ≤ the number of cells in other ({num_other:,})') raise ValueError(error_message) # Recode cell types so the most common is 0, the next-most common 1, # etc. This has the effect of breaking ties by taking the most common # cell type: we pick the first element in case of ties. original_cell_type_column = original_cell_type_column.replace_strict( most_common_cell_types, pl.arange(len(most_common_cell_types), eager=True, dtype=pl.UInt32))\ .to_numpy() # Run k-means clustering on `other_Harmony_embeddings`. Since # `centroids` and `centroids_new` are swapped every iteration, # `centroids_new` will contain the final centroids when doing an odd # number of k-means iterations; if so, swap them at the end. Use # `threadpool_limits()` to run BLAS single-threaded (directly in the # single-threaded case, by making everything single-threaded including # BLAS; indirectly in the parallel case, by disabling nested # parallelism inside `prange()`). 
if num_threads == 1: cluster_labels = np.empty(num_other, dtype=np.uint32) min_distances = np.empty(num_other, dtype=np.float32) cell_norms = np.empty(num_other, dtype=np.float32) else: cluster_labels = numa_zeros(num_other, dtype=np.uint32) min_distances = numa_zeros(num_other, dtype=np.float32) cell_norms = numa_zeros(num_other, dtype=np.float32) centroids = np.empty((num_clusters, num_dimensions), dtype=np.float32) centroids_new = np.empty((num_clusters, num_dimensions), dtype=np.float32) num_cells_per_cluster = np.empty(num_clusters, dtype=np.uint32) with threadpool_limits(num_threads): iterations_until_convergence, num_clusters = kmeans( X=other_Harmony_embeddings, cluster_labels=cluster_labels, centroids=centroids, centroids_new=centroids_new, num_cells_per_cluster=num_cells_per_cluster, min_distances=min_distances, cell_norms=cell_norms, kmeans_barbar=kmeans_barbar, num_init_iterations=num_init_iterations, num_kmeans_iterations=num_kmeans_iterations, tolerance=kmeans_tolerance, oversampling_factor=oversampling_factor, seed=seed, chunk_size=chunk_size_kmeans, num_threads=num_threads) if iterations_until_convergence & 1: centroids = centroids_new del centroids_new, min_distances centroids = centroids[:num_clusters] num_cells_per_cluster = num_cells_per_cluster[:num_clusters] # As an optimization, renumber clusters so that cluster `N` tends to be # near cluster `N + 1` in PC space, by sorting the centroids by PC1. # This reduces cache misses during the nearest-neighbor search. centroid_order = np.argsort(centroids[:, 0]).astype(np.uint32) centroids = np.ascontiguousarray(centroids[centroid_order]) num_cells_per_cluster = num_cells_per_cluster[centroid_order] inverse_centroid_order = np.empty_like(centroid_order) inverse_centroid_order[centroid_order] = np.arange(len(centroid_order), dtype=np.uint32) cluster_labels = inverse_centroid_order[cluster_labels] # As an optimization, sort by cluster to reduce cache misses in the # nearest-neighbor search. 
This also lets us skip building an explicit # inverted file index during the nearest-neighbor search. Unlike in # `neighbors()`, there is no need to pass `sorted_order` to the # k-nearest neighbors function or to create the reverse mapping, since # `self` is not being sorted, only `other`. sorted_order = np.argsort(cluster_labels).astype(np.uint32) del cluster_labels if num_threads == 1: other_Harmony_embeddings = other_Harmony_embeddings[sorted_order] original_cell_type_column = original_cell_type_column[sorted_order] cell_norms = cell_norms[sorted_order] else: other_Harmony_embeddings = parallel_subset_2d( other_Harmony_embeddings, sorted_order, num_threads) original_cell_type_column = parallel_subset_1d( original_cell_type_column, sorted_order, num_threads) cell_norms = parallel_subset_1d(cell_norms, sorted_order, num_threads) # Find the `num_neighbors` nearest neighbors in `other` of each cell in # `self`, according to `self_Harmony_embeddings` and # `other_Harmony_embeddings`. Use `threadpool_limits()` to run BLAS # single-threaded (directly in the single-threaded case, by making # everything single-threaded including BLAS; indirectly in the parallel # case, by disabling nested parallelism inside `prange()`). 
if num_threads == 1: neighbors = np.empty((num_self, num_neighbors), dtype=np.uint32) distances = np.empty((num_self, num_neighbors), dtype=np.float32) centroid_distances = np.empty((num_self, num_clusters_searched), dtype=np.float32) nearest_clusters = np.empty((num_self, num_clusters_searched), dtype=np.uint32) query_norms = np.empty(num_self, dtype=np.float32) else: neighbors = numa_zeros((num_self, num_neighbors), dtype=np.uint32) distances = numa_zeros((num_self, num_neighbors), dtype=np.float32) centroid_distances = numa_zeros((num_self, num_clusters_searched), dtype=np.float32) nearest_clusters = numa_zeros((num_self, num_clusters_searched), dtype=np.uint32) query_norms = numa_zeros(num_self, dtype=np.float32) with threadpool_limits(num_threads): knn_cross(Y=self_Harmony_embeddings, X=other_Harmony_embeddings, centroids=centroids, num_cells_per_cluster=num_cells_per_cluster, cell_norms=cell_norms, neighbors=neighbors, distances=distances, centroid_distances=centroid_distances, nearest_clusters=nearest_clusters, query_norms=query_norms, num_neighbors=num_neighbors, num_clusters_searched=num_clusters_searched, chunk_size_kmeans=chunk_size_kmeans, chunk_size_search=chunk_size_search, num_threads=num_threads) del centroid_distances, nearest_clusters, query_norms # Get the (two) most common cell type(s) for each cell in `self` among # its `num_neighbors` nearest neighbors in `other`. Pick the first # element in case of ties, which according to our encoding is the most # common cell type. Also get the cell-type confidence of each cell's # type(s), i.e. the frequency of the cell type among the cell's # `num_neighbors` nearest neighbors. 
if num_threads == 1: cell_types = np.empty(num_self, dtype=np.uint32) confidences = np.empty(num_self, dtype=np.float32) else: cell_types = numa_zeros(num_self, dtype=np.uint32) confidences = numa_zeros(num_self, dtype=np.float32) if not next_best: next_best_cell_types = np.array([], dtype=np.uint32) next_best_confidences = np.array([], dtype=np.float32) elif num_threads == 1: next_best_cell_types = np.empty(num_self, dtype=np.uint32) next_best_confidences = np.empty(num_self, dtype=np.float32) else: next_best_cell_types = numa_zeros(num_self, dtype=np.uint32) next_best_confidences = numa_zeros(num_self, dtype=np.float32) label_transfer(neighbors=neighbors, original_cell_type_column=original_cell_type_column, num_cell_types=len(most_common_cell_types), cell_types=cell_types, confidences=confidences, next_best_cell_types=next_best_cell_types, next_best_confidences=next_best_confidences, num_threads=num_threads) # Map the cell-type codes back to their labels by constructing a polars # Series from the codes, then casting it to an Enum. Also convert # cell-type confidences to Series. cell_types = pl.Series(cell_type_column, cell_types)\ .cast(pl.Enum(most_common_cell_types.to_list())) confidences = pl.Series(confidence_column, confidences) if next_best: next_best_cell_types = \ pl.Series(next_best_cell_type_column, next_best_cell_types)\ .cast(pl.Enum(most_common_cell_types.to_list())) next_best_confidences = \ pl.Series(next_best_confidence_column, next_best_confidences) columns = cell_types, confidences, next_best_cell_types, \ next_best_confidences else: columns = cell_types, confidences # Add the cell-type labels and confidences to `self.obs`. If # `QC_column` was specified, back-project from QCed cells to all cells, # filling with `null`. 
if QC_column is None: obs = self._obs.with_columns(columns) else: expand = lambda series: pl.when(QC_column.name)\ .then(pl.lit(series).gather(pl.col(QC_column.name).cum_sum() - pl.col(QC_column.name))) obs = self._obs.with_columns(map(expand, columns)) # Return a new SingleCell dataset containing the cell-type labels and # confidences return SingleCell(X=self._X, obs=obs, var=self._var, obsm=self._obsm, varm=self._varm, obsp=self._obsp, varp=self._varp, uns=self._uns, num_threads=self._num_threads)
@staticmethod def _get_rocket_r() -> 'LinearSegmentedColormap': """ Define Seaborn's rocket_r colormap using base Matplotlib. Returns: """ signal.signal(signal.SIGINT, signal.SIG_IGN) try: import matplotlib.pyplot as plt finally: signal.signal(signal.SIGINT, signal.default_int_handler) rocket_colors = [ [0.01060815, 0.01808215, 0.10018654], [0.01428972, 0.02048237, 0.10374486], [0.01831941, 0.0229766, 0.10738511], [0.02275049, 0.02554464, 0.11108639], [0.02759119, 0.02818316, 0.11483751], [0.03285175, 0.03088792, 0.11863035], [0.03853466, 0.03365771, 0.12245873], [0.04447016, 0.03648425, 0.12631831], [0.05032105, 0.03936808, 0.13020508], [0.05611171, 0.04224835, 0.13411624], [0.0618531, 0.04504866, 0.13804929], [0.06755457, 0.04778179, 0.14200206], [0.0732236, 0.05045047, 0.14597263], [0.0788708, 0.05305461, 0.14995981], [0.08450105, 0.05559631, 0.15396203], [0.09011319, 0.05808059, 0.15797687], [0.09572396, 0.06050127, 0.16200507], [0.10132312, 0.06286782, 0.16604287], [0.10692823, 0.06517224, 0.17009175], [0.1125315, 0.06742194, 0.17414848], [0.11813947, 0.06961499, 0.17821272], [0.12375803, 0.07174938, 0.18228425], [0.12938228, 0.07383015, 0.18636053], [0.13501631, 0.07585609, 0.19044109], [0.14066867, 0.0778224, 0.19452676], [0.14633406, 0.07973393, 0.1986151], [0.15201338, 0.08159108, 0.20270523], [0.15770877, 0.08339312, 0.20679668], [0.16342174, 0.0851396, 0.21088893], [0.16915387, 0.08682996, 0.21498104], [0.17489524, 0.08848235, 0.2190294], [0.18065495, 0.09009031, 0.22303512], [0.18643324, 0.09165431, 0.22699705], [0.19223028, 0.09317479, 0.23091409], [0.19804623, 0.09465217, 0.23478512], [0.20388117, 0.09608689, 0.23860907], [0.20973515, 0.09747934, 0.24238489], [0.21560818, 0.09882993, 0.24611154], [0.22150014, 0.10013944, 0.2497868], [0.22741085, 0.10140876, 0.25340813], [0.23334047, 0.10263737, 0.25697736], [0.23928891, 0.10382562, 0.2604936], [0.24525608, 0.10497384, 0.26395596], [0.25124182, 0.10608236, 0.26736359], [0.25724602, 0.10715148, 
0.27071569], [0.26326851, 0.1081815, 0.27401148], [0.26930915, 0.1091727, 0.2772502], [0.27536766, 0.11012568, 0.28043021], [0.28144375, 0.11104133, 0.2835489], [0.2875374, 0.11191896, 0.28660853], [0.29364846, 0.11275876, 0.2896085], [0.29977678, 0.11356089, 0.29254823], [0.30592213, 0.11432553, 0.29542718], [0.31208435, 0.11505284, 0.29824485], [0.31826327, 0.1157429, 0.30100076], [0.32445869, 0.11639585, 0.30369448], [0.33067031, 0.11701189, 0.30632563], [0.33689808, 0.11759095, 0.3088938], [0.34314168, 0.11813362, 0.31139721], [0.34940101, 0.11863987, 0.3138355], [0.355676, 0.11910909, 0.31620996], [0.36196644, 0.1195413, 0.31852037], [0.36827206, 0.11993653, 0.32076656], [0.37459292, 0.12029443, 0.32294825], [0.38092887, 0.12061482, 0.32506528], [0.38727975, 0.12089756, 0.3271175], [0.39364518, 0.12114272, 0.32910494], [0.40002537, 0.12134964, 0.33102734], [0.40642019, 0.12151801, 0.33288464], [0.41282936, 0.12164769, 0.33467689], [0.41925278, 0.12173833, 0.33640407], [0.42569057, 0.12178916, 0.33806605], [0.43214263, 0.12179973, 0.33966284], [0.43860848, 0.12177004, 0.34119475], [0.44508855, 0.12169883, 0.34266151], [0.45158266, 0.12158557, 0.34406324], [0.45809049, 0.12142996, 0.34540024], [0.46461238, 0.12123063, 0.34667231], [0.47114798, 0.12098721, 0.34787978], [0.47769736, 0.12069864, 0.34902273], [0.48426077, 0.12036349, 0.35010104], [0.49083761, 0.11998161, 0.35111537], [0.49742847, 0.11955087, 0.35206533], [0.50403286, 0.11907081, 0.35295152], [0.51065109, 0.11853959, 0.35377385], [0.51728314, 0.1179558, 0.35453252], [0.52392883, 0.11731817, 0.35522789], [0.53058853, 0.11662445, 0.35585982], [0.53726173, 0.11587369, 0.35642903], [0.54394898, 0.11506307, 0.35693521], [0.5506426, 0.11420757, 0.35737863], [0.55734473, 0.11330456, 0.35775059], [0.56405586, 0.11235265, 0.35804813], [0.57077365, 0.11135597, 0.35827146], [0.5774991, 0.11031233, 0.35841679], [0.58422945, 0.10922707, 0.35848469], [0.59096382, 0.10810205, 0.35847347], [0.59770215, 0.10693774, 
0.35838029], [0.60444226, 0.10573912, 0.35820487], [0.61118304, 0.10450943, 0.35794557], [0.61792306, 0.10325288, 0.35760108], [0.62466162, 0.10197244, 0.35716891], [0.63139686, 0.10067417, 0.35664819], [0.63812122, 0.09938212, 0.35603757], [0.64483795, 0.0980891, 0.35533555], [0.65154562, 0.09680192, 0.35454107], [0.65824241, 0.09552918, 0.3536529], [0.66492652, 0.09428017, 0.3526697], [0.67159578, 0.09306598, 0.35159077], [0.67824099, 0.09192342, 0.3504148], [0.684863, 0.09085633, 0.34914061], [0.69146268, 0.0898675, 0.34776864], [0.69803757, 0.08897226, 0.3462986], [0.70457834, 0.0882129, 0.34473046], [0.71108138, 0.08761223, 0.3430635], [0.7175507, 0.08716212, 0.34129974], [0.72398193, 0.08688725, 0.33943958], [0.73035829, 0.0868623, 0.33748452], [0.73669146, 0.08704683, 0.33543669], [0.74297501, 0.08747196, 0.33329799], [0.74919318, 0.08820542, 0.33107204], [0.75535825, 0.08919792, 0.32876184], [0.76145589, 0.09050716, 0.32637117], [0.76748424, 0.09213602, 0.32390525], [0.77344838, 0.09405684, 0.32136808], [0.77932641, 0.09634794, 0.31876642], [0.78513609, 0.09892473, 0.31610488], [0.79085854, 0.10184672, 0.313391], [0.7965014, 0.10506637, 0.31063031], [0.80205987, 0.10858333, 0.30783], [0.80752799, 0.11239964, 0.30499738], [0.81291606, 0.11645784, 0.30213802], [0.81820481, 0.12080606, 0.29926105], [0.82341472, 0.12535343, 0.2963705], [0.82852822, 0.13014118, 0.29347474], [0.83355779, 0.13511035, 0.29057852], [0.83850183, 0.14025098, 0.2876878], [0.84335441, 0.14556683, 0.28480819], [0.84813096, 0.15099892, 0.281943], [0.85281737, 0.15657772, 0.27909826], [0.85742602, 0.1622583, 0.27627462], [0.86196552, 0.16801239, 0.27346473], [0.86641628, 0.17387796, 0.27070818], [0.87079129, 0.17982114, 0.26797378], [0.87507281, 0.18587368, 0.26529697], [0.87925878, 0.19203259, 0.26268136], [0.8833417, 0.19830556, 0.26014181], [0.88731387, 0.20469941, 0.25769539], [0.89116859, 0.21121788, 0.2553592], [0.89490337, 0.21785614, 0.25314362], [0.8985026, 0.22463251, 
0.25108745], [0.90197527, 0.23152063, 0.24918223], [0.90530097, 0.23854541, 0.24748098], [0.90848638, 0.24568473, 0.24598324], [0.911533, 0.25292623, 0.24470258], [0.9144225, 0.26028902, 0.24369359], [0.91717106, 0.26773821, 0.24294137], [0.91978131, 0.27526191, 0.24245973], [0.92223947, 0.28287251, 0.24229568], [0.92456587, 0.29053388, 0.24242622], [0.92676657, 0.29823282, 0.24285536], [0.92882964, 0.30598085, 0.24362274], [0.93078135, 0.31373977, 0.24468803], [0.93262051, 0.3215093, 0.24606461], [0.93435067, 0.32928362, 0.24775328], [0.93599076, 0.33703942, 0.24972157], [0.93752831, 0.34479177, 0.25199928], [0.93899289, 0.35250734, 0.25452808], [0.94036561, 0.36020899, 0.25734661], [0.94167588, 0.36786594, 0.2603949], [0.94291042, 0.37549479, 0.26369821], [0.94408513, 0.3830811, 0.26722004], [0.94520419, 0.39062329, 0.27094924], [0.94625977, 0.39813168, 0.27489742], [0.94727016, 0.4055909, 0.27902322], [0.94823505, 0.41300424, 0.28332283], [0.94914549, 0.42038251, 0.28780969], [0.95001704, 0.42771398, 0.29244728], [0.95085121, 0.43500005, 0.29722817], [0.95165009, 0.44224144, 0.30214494], [0.9524044, 0.44944853, 0.3072105], [0.95312556, 0.45661389, 0.31239776], [0.95381595, 0.46373781, 0.31769923], [0.95447591, 0.47082238, 0.32310953], [0.95510255, 0.47787236, 0.32862553], [0.95569679, 0.48489115, 0.33421404], [0.95626788, 0.49187351, 0.33985601], [0.95681685, 0.49882008, 0.34555431], [0.9573439, 0.50573243, 0.35130912], [0.95784842, 0.51261283, 0.35711942], [0.95833051, 0.51946267, 0.36298589], [0.95879054, 0.52628305, 0.36890904], [0.95922872, 0.53307513, 0.3748895], [0.95964538, 0.53983991, 0.38092784], [0.96004345, 0.54657593, 0.3870292], [0.96042097, 0.55328624, 0.39319057], [0.96077819, 0.55997184, 0.39941173], [0.9611152, 0.5666337, 0.40569343], [0.96143273, 0.57327231, 0.41203603], [0.96173392, 0.57988594, 0.41844491], [0.96201757, 0.58647675, 0.42491751], [0.96228344, 0.59304598, 0.43145271], [0.96253168, 0.5995944, 0.43805131], [0.96276513, 0.60612062, 
0.44471698], [0.96298491, 0.6126247, 0.45145074], [0.96318967, 0.61910879, 0.45824902], [0.96337949, 0.6255736, 0.46511271], [0.96355923, 0.63201624, 0.47204746], [0.96372785, 0.63843852, 0.47905028], [0.96388426, 0.64484214, 0.4861196], [0.96403203, 0.65122535, 0.4932578], [0.96417332, 0.65758729, 0.50046894], [0.9643063, 0.66393045, 0.5077467], [0.96443322, 0.67025402, 0.51509334], [0.96455845, 0.67655564, 0.52251447], [0.96467922, 0.68283846, 0.53000231], [0.96479861, 0.68910113, 0.53756026], [0.96492035, 0.69534192, 0.5451917], [0.96504223, 0.7015636, 0.5528892], [0.96516917, 0.70776351, 0.5606593], [0.96530224, 0.71394212, 0.56849894], [0.96544032, 0.72010124, 0.57640375], [0.96559206, 0.72623592, 0.58438387], [0.96575293, 0.73235058, 0.59242739], [0.96592829, 0.73844258, 0.60053991], [0.96612013, 0.74451182, 0.60871954], [0.96632832, 0.75055966, 0.61696136], [0.96656022, 0.75658231, 0.62527295], [0.96681185, 0.76258381, 0.63364277], [0.96709183, 0.76855969, 0.64207921], [0.96739773, 0.77451297, 0.65057302], [0.96773482, 0.78044149, 0.65912731], [0.96810471, 0.78634563, 0.66773889], [0.96850919, 0.79222565, 0.6764046], [0.96893132, 0.79809112, 0.68512266], [0.96935926, 0.80395415, 0.69383201], [0.9698028, 0.80981139, 0.70252255], [0.97025511, 0.81566605, 0.71120296], [0.97071849, 0.82151775, 0.71987163], [0.97120159, 0.82736371, 0.72851999], [0.97169389, 0.83320847, 0.73716071], [0.97220061, 0.83905052, 0.74578903], [0.97272597, 0.84488881, 0.75440141], [0.97327085, 0.85072354, 0.76299805], [0.97383206, 0.85655639, 0.77158353], [0.97441222, 0.86238689, 0.78015619], [0.97501782, 0.86821321, 0.78871034], [0.97564391, 0.87403763, 0.79725261], [0.97628674, 0.87986189, 0.8057883], [0.97696114, 0.88568129, 0.81430324], [0.97765722, 0.89149971, 0.82280948], [0.97837585, 0.89731727, 0.83130786], [0.97912374, 0.90313207, 0.83979337], [0.979891, 0.90894778, 0.84827858], [0.98067764, 0.91476465, 0.85676611], [0.98137749, 0.92061729, 0.86536915]] return 
plt.matplotlib.colors.LinearSegmentedColormap.from_list( 'rocket_r', rocket_colors[::-1])
def plot_heatmap(self,
                 x: SingleCellColumn,
                 y: SingleCellColumn,
                 filename: str | Path | None = None,
                 /,
                 *,
                 cells_to_plot_column: SingleCellColumn | None = 'passed_QC',
                 normalize_rows: bool = False,
                 normalize_columns: bool = False,
                 ax: 'Axes' | None = None,
                 figure_kwargs: dict[str, Any] | None = None,
                 colormap: str | 'Colormap' | None = None,
                 heatmap_kwargs: dict[str, Any] | None = None,
                 label: bool = False,
                 label_format: str | None = None,
                 label_kwargs: dict[str, Any] | None = None,
                 colorbar: bool = True,
                 colorbar_kwargs: dict[str, Any] | None = None,
                 title: str | None = None,
                 title_kwargs: dict[str, Any] | None = None,
                 xlabel: str | Literal[True] | None = True,
                 xlabel_kwargs: dict[str, Any] | None = None,
                 ylabel: str | Literal[True] | None = True,
                 ylabel_kwargs: dict[str, Any] | None = None,
                 despine: bool = True,
                 savefig_kwargs: dict[str, Any] | None = None) -> None:
    """
    Plot a heatmap of the count of each combination of two categorical
    columns, `x` and `y`. If `normalize_rows` or `normalize_columns` is
    specified, plot percentages instead of counts, so that each row or
    column sums to 100%.

    Args:
        x: the first column; must be String, Enum, Categorical, or integer
        y: the second column; must be String, Enum, Categorical, or
           integer
        filename: the file to save to. If `None`, generate the plot but do
                  not save it, which allows it to be shown interactively
                  or modified further before saving.
        cells_to_plot_column: an optional Boolean column of `obs`
                              indicating which cells to plot. Can be a
                              column name, a polars expression, a polars
                              Series, a 1D NumPy array, or a function that
                              takes in this SingleCell dataset and returns
                              a polars Series or 1D NumPy array. Set to
                              `None` to plot all cells.
        normalize_rows: whether to plot percentages instead of counts, so
                        that each row sums to 100%. Mutually exclusive
                        with `normalize_columns`.
        normalize_columns: whether to plot percentages instead of counts,
                           so that each column sums to 100%. Mutually
                           exclusive with `normalize_rows`.
        ax: the Matplotlib axes to save the plot onto; if `None`, create a
            new figure with Matplotlib's constrained layout and plot onto
            it
        figure_kwargs: a dictionary of keyword arguments to be passed to
                       `plt.figure` when `ax` is `None`, such as:
                       - `figsize`: a two-element sequence of the width
                         and height of the figure in inches. The default
                         is a formula based on the number of rows and
                         columns of the heatmap, unlike Matplotlib's
                         default of `[6.4, 4.8]`.
                       - `layout`: the layout mechanism used by Matplotlib
                         to avoid overlapping plot elements. Defaults to
                         `'constrained'`, instead of Matplotlib's default
                         of `None`.
        colormap: a string or Colormap object indicating the Matplotlib
                  colormap to use in the heatmap, or `None` to use
                  Seaborn's `'rocket_r'` colormap.
        heatmap_kwargs: a dictionary of keyword arguments to be passed to
                        `ax.pcolormesh()` when generating the heatmap,
                        such as:
                        - `rasterized`: whether to convert the heatmap
                          cells to a raster (bitmap) image when saving to
                          a vector format like PDF. Defaults to `True`,
                          instead of Matplotlib's default of `False`.
                        - `norm`, `vmin`, and `vmax`: control how the
                          colormap maps counts or percentages to heatmap
                          colors
                        - `edgecolors`: the border color of each heatmap
                          cell; defaults to `'none'`, meaning no borders.
                        Specifying `cmap` will raise an error, since it
                        conflicts with the `colormap` argument.
        label: whether to label each cell of the heatmap with its count
               (or percentage, if `normalize_rows=True` or
               `normalize_columns=True`)
        label_format: a format string to apply to the label for each count
                      or percentage. If `None`, use `'{:,}'` for counts
                      and `'{:.2f}%'` for percentages. Percentages are
                      passed to the format string on a 0-100 scale. Can
                      only be specified when `label=True`.
        label_kwargs: a dictionary of keyword arguments to be passed to
                      `ax.text()` when adding labels to control the text
                      properties, such as:
                      - `color` and `size` to modify the text color/size.
                        By default, the color is dark gray for
                        light-colored cells, and white for dark-colored
                        ones.
                      - `verticalalignment` and `horizontalalignment` to
                        control vertical and horizontal alignment. By
                        default, unlike Matplotlib, these are both set to
                        `'center'`.
                      Can only be specified when `label=True`.
        colorbar: whether to add a colorbar
        colorbar_kwargs: a dictionary of keyword arguments to be passed to
                         `plt.colorbar()`, such as:
                         - `location`: `'left'`, `'right'`, `'top'`, or
                           `'bottom'`
                         - `orientation`: `'vertical'` or `'horizontal'`
                         - `fraction`: the fraction of the axes to
                           allocate to the colorbar. Defaults to 0.15.
                         - `shrink`: the fraction to multiply the size of
                           the colorbar by. Defaults to 0.5, instead of
                           Matplotlib's default of 1.
                         - `aspect`: the ratio of the colorbar's long to
                           short dimensions. Defaults to 20.
                         - `pad`: the fraction of the axes between the
                           colorbar and the rest of the figure. Defaults
                           to 0.01, instead of Matplotlib's default of
                           0.05 if vertical and 0.15 if horizontal.
                         Can only be specified when `colorbar=True`.
        title: the title of the plot, or `None` to not add a title
        title_kwargs: a dictionary of keyword arguments to be passed to
                      `ax.set_title()` to control text properties, such as
                      `color` and `size`. Can only be specified when
                      `title` is not `None`.
        xlabel: the x-axis label, `True` to use the name of `x` as the
                x-axis label, or `None` to not label the x-axis
        xlabel_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.set_xlabel()` to control text properties, such
                       as `color` and `size`. Can only be specified when
                       `xlabel` is not `None`.
        ylabel: the y-axis label, `True` to use the name of `y` as the
                y-axis label, or `None` to not label the y-axis
        ylabel_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.set_ylabel()` to control text properties, such
                       as `color` and `size`. Can only be specified when
                       `ylabel` is not `None`.
        despine: whether to remove the spines (borders of the plot area)
                 from the plot; unlike the other plotting functions in
                 this library, this also removes the left and bottom
                 spines
        savefig_kwargs: a dictionary of keyword arguments to be passed to
                        `plt.savefig()`, such as:
                        - `dpi`: defaults to 300 instead of Matplotlib's
                          default of 150
                        - `bbox_inches`: the bounding box of the portion
                          of the figure to save; defaults to `'tight'`
                          (crop out any blank borders) instead of
                          Matplotlib's default of `None` (save the entire
                          figure)
                        - `pad_inches`: the number of inches of padding to
                          add on each of the four sides of the figure when
                          saving. Defaults to `'layout'` (use the padding
                          from the constrained layout engine, when `ax` is
                          not `None`) instead of Matplotlib's default of
                          0.1.
                        - `transparent`: whether to save with a
                          transparent background; defaults to `True` if
                          saving to a PDF (i.e. when `filename` ends with
                          `'.pdf'`) and `False` otherwise, instead of
                          Matplotlib's default of always being `False`.
                        Can only be specified when `filename` is
                        specified.

    Raises:
        NotADirectoryError: if the directory part of `filename` does not
                            exist
        TypeError: if a kwargs dictionary contains a non-string key
        ValueError: if mutually exclusive or dependent arguments are
                    combined incorrectly, or `label_format` is malformed
    """
    # Import matplotlib; SIGINT is ignored for the duration of the import
    # and the default handler is restored afterwards
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        import matplotlib.pyplot as plt
    finally:
        signal.signal(signal.SIGINT, signal.default_int_handler)

    # Get `cells_to_plot_column`, if not `None`
    if cells_to_plot_column is not None:
        cells_to_plot_column = self._get_column(
            'obs', cells_to_plot_column, 'cells_to_plot_column',
            pl.Boolean,
            allow_missing=isinstance(cells_to_plot_column, str) and
                          cells_to_plot_column == 'passed_QC')

    # Get `x` and `y`, and check that they are String, Enum, Categorical,
    # or integer
    x = self._get_column('obs', x, 'x',
                         (pl.String, pl.Enum, pl.Categorical, 'integer'),
                         QC_column=cells_to_plot_column)
    y = self._get_column('obs', y, 'y',
                         (pl.String, pl.Enum, pl.Categorical, 'integer'),
                         QC_column=cells_to_plot_column)
    if cells_to_plot_column is not None:
        x = x.filter(cells_to_plot_column)
        y = y.filter(cells_to_plot_column)

    # If `filename` was specified, check that it is a string or
    # `pathlib.Path` and that its base directory exists; if `filename` is
    # `None`, make sure `savefig_kwargs` is also `None`
    if filename is not None:
        check_type(filename, 'filename', (str, Path),
                   'a string or pathlib.Path')
        directory = os.path.dirname(filename)
        if directory and not os.path.isdir(directory):
            error_message = (
                f'{filename} refers to a file in the directory '
                f'{directory!r}, but this directory does not exist')
            raise NotADirectoryError(error_message)
        filename = str(filename)
    elif savefig_kwargs is not None:
        error_message = 'savefig_kwargs must be None when filename is None'
        raise ValueError(error_message)

    # Check that `normalize_rows`, `normalize_columns`, `label`,
    # `colorbar`, and `despine` are Boolean
    check_type(normalize_rows, 'normalize_rows', bool, 'Boolean')
    check_type(normalize_columns, 'normalize_columns', bool, 'Boolean')
    check_type(label, 'label', bool, 'Boolean')
    check_type(colorbar, 'colorbar', bool, 'Boolean')
    check_type(despine, 'despine', bool, 'Boolean')

    # Check that `normalize_rows` and `normalize_columns` are mutually
    # exclusive
    if normalize_rows and normalize_columns:
        error_message = \
            'only one of normalize_rows and normalize_columns can be True'
        raise ValueError(error_message)
    normalized = normalize_rows or normalize_columns

    # If `figure_kwargs` was specified, check that `ax` is `None`
    if figure_kwargs is not None and ax is not None:
        error_message = (
            'figure_kwargs must be None when ax is not None, since a new '
            'figure does not need to be generated when plotting onto an '
            'existing axis')
        raise ValueError(error_message)

    # Check that `colormap` is a string in `plt.colormaps`, a Colormap
    # object, or `None`; if `None`, default to Seaborn's rocket_r colormap
    if colormap is None:
        colormap = SingleCell._get_rocket_r()
    else:
        check_type(colormap, 'colormap',
                   (str, plt.matplotlib.colors.Colormap),
                   'a string or matplotlib Colormap object')
        if isinstance(colormap, str):
            colormap = plt.colormaps[colormap]

    # If `label=False`, check that `label_format` and `label_kwargs` are
    # `None`.
    if not label:
        if label_format is not None:
            error_message = 'label_format must be None when label=False'
            raise ValueError(error_message)
        if label_kwargs is not None:
            error_message = 'label_kwargs must be None when label=False'
            raise ValueError(error_message)

    # If not `None`, check that `label_format` is a valid format string.
    # For simplicity, just check that it has curly braces and that all
    # braces are matched. If `label_format` is `None`, use `'{:,}'` for
    # counts or, if `normalize_rows=True` or `normalize_columns=True`,
    # `'{:.2f}%'` for percentages.
    if label_format is None:
        label_format = '{:.2f}%' if normalized else '{:,}'
    else:
        check_type(label_format, 'label_format', str, 'a string')
        open_braces = 0
        has_braces = False
        for char in label_format:
            if char == '{':
                has_braces = True
                open_braces += 1
            elif char == '}':
                open_braces -= 1
                if open_braces < 0:
                    error_message = \
                        'label_format contains mismatched curly braces'
                    raise ValueError(error_message)
        # Braces are balanced only when every '{' has been closed, i.e.
        # when the running count is back at zero after the full scan
        if open_braces != 0:
            error_message = 'label_format contains mismatched curly braces'
            raise ValueError(error_message)
        if not has_braces:
            error_message = 'label_format must contain curly braces'
            raise ValueError(error_message)

    # If `colorbar=False`, check that `colorbar_kwargs` is None
    if not colorbar and colorbar_kwargs is not None:
        error_message = 'colorbar_kwargs must be None when colorbar=False'
        raise ValueError(error_message)

    # Check that `title` is a string or `None`; if `None`, check that
    # `title_kwargs` is `None` as well.
    if title is not None:
        check_type(title, 'title', str, 'a string')
    elif title_kwargs is not None:
        error_message = 'title_kwargs must be None when title is None'
        raise ValueError(error_message)

    # Check that `xlabel` is a string, `True` (in which case set it to
    # `x.name`), or `None`; if `None`, check that `xlabel_kwargs` is `None`
    # as well. Ditto for `ylabel`.
    if xlabel is not None:
        if xlabel is True:
            xlabel = x.name
        else:
            check_type(xlabel, 'xlabel', str, 'a string')
    elif xlabel_kwargs is not None:
        error_message = 'xlabel_kwargs must be None when xlabel is None'
        raise ValueError(error_message)
    if ylabel is not None:
        if ylabel is True:
            ylabel = y.name
        else:
            check_type(ylabel, 'ylabel', str, 'a string')
    elif ylabel_kwargs is not None:
        error_message = 'ylabel_kwargs must be None when ylabel is None'
        raise ValueError(error_message)

    # For each of the kwargs arguments, if the argument was specified,
    # check that it is a dictionary and that all its keys are strings.
    for kwargs, kwargs_name in ((figure_kwargs, 'figure_kwargs'),
                                (heatmap_kwargs, 'heatmap_kwargs'),
                                (label_kwargs, 'label_kwargs'),
                                (colorbar_kwargs, 'colorbar_kwargs'),
                                (title_kwargs, 'title_kwargs'),
                                (xlabel_kwargs, 'xlabel_kwargs'),
                                (ylabel_kwargs, 'ylabel_kwargs'),
                                (savefig_kwargs, 'savefig_kwargs')):
        if kwargs is not None:
            check_type(kwargs, kwargs_name, dict, 'a dictionary')
            for key in kwargs:
                if not isinstance(key, str):
                    error_message = (
                        f'all keys of {kwargs_name} must be strings, but '
                        f'it contains a key of type '
                        f'{type(key).__name__!r}')
                    raise TypeError(error_message)

    # Override the defaults for certain values of `heatmap_kwargs`; if
    # specified, check that `heatmap_kwargs` does not contain the `cmap`
    # argument
    default_heatmap_kwargs = dict(rasterized=True)
    if heatmap_kwargs is None:
        heatmap_kwargs = default_heatmap_kwargs
    else:
        if 'cmap' in heatmap_kwargs:
            error_message = (
                f"'cmap' cannot be specified as a key in heatmap_kwargs; "
                f"specify the colormap argument instead")
            raise ValueError(error_message)
        # The user's values go on the right so they take precedence over
        # the defaults, like every other kwargs merge in this function
        heatmap_kwargs = default_heatmap_kwargs | heatmap_kwargs

    # Get the heatmap data
    count = pl.DataFrame((x, y))\
        .group_by(pl.all(), maintain_order=True)\
        .len(name='_SingleCell_count')\
        .pivot(index=y.name, columns=x.name, values='_SingleCell_count')\
        .fill_null(0)
    heatmap_data = count[:, 1:].to_numpy()

    # Normalize, if `normalize_rows=True` or `normalize_columns=True`; the
    # normalized data are fractions between 0 and 1
    if normalize_rows:
        heatmap_data = \
            heatmap_data / heatmap_data.sum(axis=1, keepdims=True)
    elif normalize_columns:
        heatmap_data = \
            heatmap_data / heatmap_data.sum(axis=0, keepdims=True)

    # If `ax` is `None`, create a new figure; otherwise, check that it is a
    # Matplotlib axis
    make_new_figure = ax is None
    try:
        num_rows, num_columns = heatmap_data.shape
        if make_new_figure:
            default_figure_kwargs = dict(layout='constrained')
            if figure_kwargs is None or 'figsize' not in figure_kwargs:
                if colorbar:
                    width_ratio = max(4, 0.2 * num_columns)
                    width = 6.4 / (1 + width_ratio) + \
                        6.4 * width_ratio / (1 + width_ratio) * \
                        max(num_columns, 5) / 20
                else:
                    width = 6.4 * max(num_columns, 5) / 20
                height = max(4.8, 1 + 3.8 * num_rows / 20)
                default_figure_kwargs['figsize'] = width, height
            figure_kwargs = default_figure_kwargs | figure_kwargs \
                if figure_kwargs is not None else default_figure_kwargs
            plt.figure(**figure_kwargs)
            ax = plt.gca()
        else:
            check_type(ax, 'ax', plt.Axes, 'a Matplotlib axis')

        # Make the heatmap
        xticks = np.arange(0.5, num_columns)
        yticks = np.arange(0.5, num_rows)
        heatmap = \
            ax.pcolormesh(heatmap_data, cmap=colormap, **heatmap_kwargs)
        ax.set_xticks(xticks, count.columns[1:], rotation=90)
        ax.set_yticks(yticks, count[:, 0].to_numpy())
        ax.set_aspect('equal')

        # Add the colorbar; override the defaults for certain keys of
        # `colorbar_kwargs`. If normalizing rows or columns, make the
        # colorbar ticks percentages.
        if colorbar:
            default_colorbar_kwargs = dict(shrink=0.5, pad=0.01)
            colorbar_kwargs = default_colorbar_kwargs | colorbar_kwargs \
                if colorbar_kwargs is not None else default_colorbar_kwargs
            cbar = plt.colorbar(heatmap, ax=ax, **colorbar_kwargs)
            cbar.outline.set_visible(False)
            if normalized:
                cbar.ax.yaxis.set_major_formatter(plt.FuncFormatter(
                    lambda value, position: f'{100 * value:.0f}%'))

        # Add labels; this code is edited from `_annotate_heatmap()` at
        # github.com/mwaskom/seaborn/blob/master/seaborn/matrix.py
        if label:
            heatmap.update_scalarmappable()
            xpos, ypos = np.meshgrid(xticks, yticks)
            # Work on a copy so the caller's dictionary is left untouched
            label_kwargs = \
                {} if label_kwargs is None else dict(label_kwargs)
            horizontal_default = label_kwargs.pop('ha', 'center')
            vertical_default = label_kwargs.pop('va', 'center')
            label_kwargs['horizontalalignment'] = label_kwargs.pop(
                'horizontalalignment', horizontal_default)
            label_kwargs['verticalalignment'] = label_kwargs.pop(
                'verticalalignment', vertical_default)
            # The normalized data are fractions, but the labels are
            # documented (and by default formatted) as percentages, so
            # rescale to 0-100, matching the colorbar tick formatter
            label_values = \
                100 * heatmap_data if normalized else heatmap_data
            if 'c' in label_kwargs or 'color' in label_kwargs:
                # Use the same color for all labels
                for x_position, y_position, value in zip(
                        xpos.ravel(), ypos.ravel(), label_values.ravel()):
                    ax.text(x_position, y_position,
                            s=label_format.format(value), **label_kwargs)
            else:
                # Use either dark gray or white for the label, depending on
                # the cell's luminance
                rgb_weights = np.array([0.2126, 0.7152, 0.0722])
                for x_position, y_position, color, value in zip(
                        xpos.ravel(), ypos.ravel(),
                        heatmap.get_facecolors(), label_values.ravel()):
                    rgb = plt.matplotlib.colors.to_rgba_array(color)[:, :3]
                    rgb = np.where(rgb <= 0.03928, rgb / 12.92,
                                   ((rgb + 0.055) / 1.055) ** 2.4)
                    lum = rgb.dot(rgb_weights).item()
                    ax.text(x_position, y_position,
                            s=label_format.format(value),
                            c='.15' if lum > .408 else 'w', **label_kwargs)

        # Add the title and axis labels
        if xlabel is not None:
            if xlabel_kwargs is None:
                xlabel_kwargs = {}
            ax.set_xlabel(xlabel, **xlabel_kwargs)
        if ylabel is not None:
            if ylabel_kwargs is None:
                ylabel_kwargs = {}
            ax.set_ylabel(ylabel, **ylabel_kwargs)
        if title is not None:
            if title_kwargs is None:
                title_kwargs = {}
            ax.set_title(title, **title_kwargs)

        # Despine, if specified
        if despine:
            spines = ax.spines
            for direction in 'top', 'bottom', 'left', 'right':
                spines[direction].set_visible(False)

        # Save; override the defaults for certain keys of `savefig_kwargs`
        if filename is not None:
            default_savefig_kwargs = \
                dict(dpi=300, bbox_inches='tight', pad_inches='layout',
                     transparent=filename.endswith('.pdf'))
            savefig_kwargs = default_savefig_kwargs | savefig_kwargs \
                if savefig_kwargs is not None else default_savefig_kwargs
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', UserWarning)
                plt.savefig(filename, **savefig_kwargs)
            if make_new_figure:
                plt.close()
    except BaseException:
        # If we made a new figure, make sure to close it if there's an
        # exception (but not if there was no error and `filename` is
        # `None`, in case the user wants to modify it further before
        # saving)
        if make_new_figure:
            plt.close()
        raise
@staticmethod
def _process_cell_types(
        cell_types: str | Iterable[str] | int | Iterable[int] | None,
        excluded_cell_types: str | Iterable[str] | int | Iterable[int] |
                             None,
        cell_type_column: pl.Series) -> pl.Series:
    """
    Process the `cell_types` and `excluded_cell_types` arguments of various
    SingleCell functions.

    Args:
        cell_types: one or more cell types to include in the calling
                    function's operation
        excluded_cell_types: one or more cell types to exclude from the
                             calling function's operation
        cell_type_column: a column containing cell-type labels

    Returns:
        The cell-type column, with cell types not present in `cell_types`
        OR present in `excluded_cell_types` set to `null`. If both
        arguments are `None`, the column is returned unchanged.

    Raises:
        ValueError: if both arguments are specified, if a specified cell
                    type is not present in `cell_type_column`, or if
                    `excluded_cell_types` excludes every cell type
    """
    if cell_types is not None and excluded_cell_types is not None:
        error_message = (
            'cell_types and excluded_cell_types cannot both be '
            'specified')
        raise ValueError(error_message)
    if cell_types is None and excluded_cell_types is None:
        return cell_type_column

    # Both arguments go through the same validation; only the argument
    # name used in error messages and the direction of the final mask
    # differ
    excluding = cell_types is None
    if excluding:
        selected, argument_name = \
            excluded_cell_types, 'excluded_cell_types'
    else:
        selected, argument_name = cell_types, 'cell_types'

    # Remember whether a single string was given, since that changes the
    # wording of the "not present" error below
    is_string = isinstance(selected, str)
    if cell_type_column.dtype.is_integer():
        selected = to_tuple_checked(selected, argument_name, int,
                                    'integers')
    else:
        selected = to_tuple_checked(selected, argument_name, str,
                                    'strings')

    # For Enum and Categorical columns, the known cell types are the
    # categories; otherwise, they are the unique values of the column
    if cell_type_column.dtype == pl.Enum or \
            cell_type_column.dtype == pl.Categorical:
        unique_cell_types = cell_type_column.cat.get_categories()
    else:
        unique_cell_types = cell_type_column.unique(maintain_order=True)
    # Build the lookup set once rather than scanning per membership test
    known_cell_types = set(unique_cell_types.to_list())

    for cell_type in selected:
        if cell_type not in known_cell_types:
            if is_string:
                error_message = (
                    f'{argument_name} is {cell_type!r}, which is not a '
                    f'cell type in cell_type_column')
            else:
                error_message = (
                    f'{argument_name} contains a cell type, '
                    f'{cell_type!r}, not present in cell_type_column')
            raise ValueError(error_message)

    is_selected = pl.first().is_in(selected)
    if excluding:
        # Compare sets rather than lengths, so that duplicates in
        # `excluded_cell_types` are not miscounted as excluding
        # everything, and a null among the unique values does not mask
        # the case where every real cell type is excluded
        if known_cell_types - {None} <= set(selected):
            error_message = \
                'all cell types were excluded by excluded_cell_types'
            raise ValueError(error_message)
        is_selected = ~is_selected

    return cell_type_column\
        .to_frame()\
        .select(pl.when(is_selected).then(pl.first()))\
        .to_series()
def find_markers(self,
                 cell_type_column: SingleCellColumn,
                 /,
                 *,
                 QC_column: SingleCellColumn | None = 'passed_QC',
                 cell_types: str | Iterable[str] | int | Iterable[int] |
                             None = None,
                 excluded_cell_types: str | Iterable[str] | int |
                                      Iterable[int] | None = None,
                 min_detection_rate: int | float | np.integer |
                                     np.floating = 0.25,
                 min_fold_change: int | float | np.integer |
                                  np.floating = 2,
                 pareto: bool = True,
                 all_genes: bool = False,
                 num_threads: int | np.integer | None = None) -> \
        pl.DataFrame:
    """
    Find "marker genes" that distinguish each cell type from all other cell
    types. This function gives the same result regardless of whether it is
    run before or after normalization.

    Marker genes are chosen via an adaptation of the strategy of Fischer
    and Gillis 2021 (ncbi.nlm.nih.gov/pmc/articles/PMC8571500).

    For a given cell type, genes are scored based on a) their "detection
    rate" in that cell type (the fraction of cells of that type that have
    non-zero count for that gene), as well as b) the fold change in
    detection rate between that cell type and every other cell type. Genes
    must also have a detection rate of at least `min_detection_rate` (25%
    by default) and a minimum fold change of at least `min_fold_change`
    (2-fold by default) to be considered as markers.

    There is an inherent tradeoff between these two metrics. For instance,
    candidate marker genes with high enough expression to be expressed in
    every cell of a given type (i.e. to have a high detection rate) tend to
    also have at least some expression in other cell types (i.e. a low fold
    change in detection rate). Thus, marker genes are selected to optimally
    trade off between these two metrics: all genes on the Pareto front of
    the two metrics (i.e. genes for which there is no other gene that does
    better on both metrics simultaneously) are selected as marker genes.

    Note that Fischer and Gillis use AUROC versus log2 fold change in
    detection rate, instead of detection rate versus fold change in
    detection rate. However, detection rate is much faster to compute than
    AUROC, and is a very accurate proxy for AUROC: as Figure 1D in their
    paper shows, AUROC is almost perfectly correlated with detection rate
    across marker genes.

    Args:
        cell_type_column: a String, Categorical, Enum, or integer column of
                          `obs` containing cell-type labels. Can be a
                          column name, a polars expression, a polars
                          Series, a 1D NumPy array, or a function that
                          takes in this SingleCell dataset and returns a
                          polars Series or 1D NumPy array.
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to `None`
                   to include all cells. Cells failing QC will be ignored.
        cell_types: one or more cell types to find markers for; by default,
                    finds markers for all cell types in `cell_type_column`.
                    Specifying `cell_types` is exactly equivalent to
                    filtering the result to these cell types, but will be
                    faster when there are many cell types and markers are
                    only desired for a few of them. Can also be used to
                    change the order in which cell types are reported, even
                    if finding markers for all cell types. Mutually
                    exclusive with `excluded_cell_types`.
        excluded_cell_types: one or more cell types to exclude from marker
                             finding. Mutually exclusive with `cell_types`.
        min_detection_rate: the minimum detection rate required to select a
                            gene as a marker gene; must be greater than 0
                            and less than or equal to 1
        min_fold_change: the minimum fold change in detection rate required
                         to select a gene as a marker gene; must be greater
                         than 1
        pareto: if `True`, include only genes on the Pareto front of
                detection rate and fold change as markers; if `False`,
                include all genes that pass the `min_detection_rate` and
                `min_fold_change` thresholds as markers
        all_genes: if `True`, include all genes in the output, not just
                   marker genes. An additional Boolean column will be
                   included to specify which genes are the marker genes.
                   Note that this option does not change which marker genes
                   are selected, only which information is returned.
        num_threads: the number of threads to use for marker-gene finding.
                     Set `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores.

    Returns:
        By default, a DataFrame with one row per marker gene, with columns:
        - `'cell_type'`: a cell-type name from `cell_type_column`
        - `'gene'`: a gene symbol from `var_names`
        - `'detection_rate'`: the gene's detection rate in that cell type
        - `'fold_change'`, the gene's fold change in detection rate between
          that cell type and all other cell types

        If `all_genes=True`, a DataFrame with one row per cell type-gene
        pair, with those four columns plus one other:
        - `'marker'`, a Boolean column listing whether the gene is a marker
          for that cell type

        If `all_genes=False`, marker genes within each cell type will be
        sorted in decreasing order of fold change.

    Raises:
        ValueError: if `X` is `None`, if `uns['QCed']` is `False`, or if
                    only one cell type is present

    Note:
        This function may give an incorrect output if the count matrix
        contains explicit zeros (i.e. if `(sc.X.data == 0).any()`): this is
        not checked for, due to speed considerations. In the unlikely event
        that your dataset contains explicit zeros, remove them by running
        `sc.X.eliminate_zeros()` (an in-place operation) first.

    Note:
        This function may give an incorrect output if the count matrix
        contains negative values: this is not checked for, due to speed
        considerations.
    """
    # Check that `X` is present
    X = self._X
    if X is None:
        error_message = 'X is None, so marker gene finding is not possible'
        raise ValueError(error_message)

    # Check that `self` is QCed
    if not self._uns['QCed']:
        error_message = (
            "uns['QCed'] is False; did you forget to run qc() before "
            "find_markers()? Set uns['QCed'] = True or run skip_qc() to "
            "bypass this check.")
        raise ValueError(error_message)

    # Get the QC column, if not `None`. Only compare against 'passed_QC'
    # when `QC_column` is a string: comparing a Series, array, or
    # expression with `==` would not give a plain Boolean.
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=isinstance(QC_column, str) and
                          QC_column == 'passed_QC')

    # Get the cell-type column
    original_cell_type_column = cell_type_column
    cell_type_column = \
        self._get_column('obs', cell_type_column, 'cell_type_column',
                         (pl.String, pl.Enum, pl.Categorical, 'integer'),
                         QC_column=QC_column)
    cell_type_column_name = cell_type_column.name

    # Check that `cell_types` and `excluded_cell_types` are not both
    # specified. If `cell_types` is specified, check it contains only cell
    # type names present in `cell_type_column`, then set non-matching cell
    # types to `null` so that they are treated as a single background cell
    # type. If `excluded_cell_types` is specified, do the opposite.
    cell_type_column = SingleCell._process_cell_types(
        cell_types, excluded_cell_types, cell_type_column)
    has_cell_type_filter = \
        cell_types is not None or excluded_cell_types is not None

    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)

    # Check that `min_detection_rate` and `min_fold_change` are numeric and
    # have the correct ranges: 0 < min_detection_rate <= 1,
    # min_fold_change > 1
    check_type(min_detection_rate, 'min_detection_rate', (int, float),
               'a positive number less than or equal to 1')
    check_bounds(min_detection_rate, 'min_detection_rate', 0, 1,
                 left_open=True)
    check_type(min_fold_change, 'min_fold_change', (int, float),
               'a number greater than 1')
    check_bounds(min_fold_change, 'min_fold_change', 1, left_open=True)

    # Check that `pareto` and `all_genes` are Boolean
    check_type(pareto, 'pareto', bool, 'Boolean')
    check_type(all_genes, 'all_genes', bool, 'Boolean')

    # Get the indices of the cells of each cell type, ignoring cells
    # failing QC when `QC_column` is present in `obs`. When a cell-type
    # filter is active, sort the `null` background group last (polars
    # sorts nulls first by default), since it is dropped from the end of
    # `groups` further down.
    cells = pl.LazyFrame(
        (cell_type_column,) if QC_column is None else
        (cell_type_column, QC_column))\
        .with_columns(
            _SingleCell_group_indices=pl.int_range(pl.len(),
                                                   dtype=pl.UInt32))
    if QC_column is not None:
        cells = cells.filter(QC_column.name)
    groups = cells\
        .group_by(cell_type_column_name, maintain_order=True)\
        .agg('_SingleCell_group_indices', _SingleCell_num_cells=pl.len())\
        .sort(cell_type_column_name, nulls_last=has_cell_type_filter)\
        .collect()

    # Check that `cell_type_column` contains at least two cell types
    num_cell_types = len(groups)
    if num_cell_types == 1:
        cell_type_column_description = \
            SingleCell._describe_column('cell_type_column',
                                        original_cell_type_column)
        error_message = (
            f'{cell_type_column_description} only contains one unique '
            f'value')
        raise ValueError(error_message)

    # If `cell_types` is not `None`, reorder `groups` to be in the same
    # order as `cell_types`, keeping the `null` background group last.
    # Order by position in `cell_types` rather than casting to an Enum, so
    # that a single (non-iterable) cell type and integer cell-type labels
    # are handled too.
    if cell_types is not None:
        requested_cell_types = (cell_types,) \
            if isinstance(cell_types, (str, int, np.integer)) \
            else tuple(cell_types)
        position = {cell_type: index for index, cell_type in
                    enumerate(requested_cell_types)}
        order = pl.Series(
            '_SingleCell_order',
            [position.get(cell_type) for cell_type in
             groups[cell_type_column_name].to_list()],
            dtype=pl.UInt32)
        groups = groups\
            .with_columns(order)\
            .sort('_SingleCell_order', nulls_last=True)\
            .drop('_SingleCell_order')

    # Get a cell-type-by-gene matrix of the number of cells of each type
    # with non-zero expression of each gene, i.e. the gene's detection
    # count in that cell type
    num_genes = X.shape[1]
    detection_count = \
        np.empty((num_cell_types, num_genes), dtype=np.uint32)
    if isinstance(X, csr_array):
        group_indices = \
            groups['_SingleCell_group_indices'].explode().to_numpy()
        group_ends = \
            groups['_SingleCell_num_cells'].cum_sum().to_numpy()
        groupby_getnnz_csr(indices=X.indices, indptr=X.indptr,
                           group_indices=group_indices,
                           group_ends=group_ends, nnz=detection_count,
                           num_threads=num_threads)
    else:
        # Map each cell to the row of `groups` it belongs to; cells that
        # are in no group (e.g. failing QC) are left `null`
        group_map = pl.int_range(X.shape[0], dtype=pl.UInt32, eager=True)\
            .to_frame('_SingleCell_group_indices')\
            .join(groups
                  .select('_SingleCell_group_indices',
                          _SingleCell_index=pl.int_range(pl.len(),
                                                         dtype=pl.Int32))
                  .explode('_SingleCell_group_indices'),
                  on='_SingleCell_group_indices', how='left')\
            ['_SingleCell_index']
        has_missing = group_map.null_count() > 0
        if has_missing:
            group_map = group_map.fill_null(-1)
        group_map = group_map.to_numpy()
        groupby_getnnz_csc(indices=X.indices, indptr=X.indptr,
                           group_map=group_map, has_missing=has_missing,
                           nnz=detection_count, num_threads=num_threads)

    # For each cell type, calculate the detection rate and the fold change
    # of the detection rate. Also, initialize the candidate set of points
    # on the Pareto front to those with detection rate of at least
    # `min_detection_rate` and fold change of at least `min_fold_change`.
    total_detection_count = detection_count.sum(axis=0, dtype=np.uint32)
    num_cells_per_cell_type = groups['_SingleCell_num_cells'].to_numpy()
    total_num_cells = num_cells_per_cell_type.sum()
    detection_rate = np.empty((num_cell_types, num_genes),
                              dtype=np.float32)
    fold_change = np.empty((num_cell_types, num_genes), dtype=np.float32)
    is_pareto = np.empty((num_cell_types, num_genes), dtype=bool)
    get_detection_rate_and_fold_change_and_pareto_candidates(
        detection_count=detection_count,
        total_detection_count=total_detection_count,
        num_cells_per_cell_type=num_cells_per_cell_type,
        total_num_cells=total_num_cells,
        min_detection_rate=min_detection_rate,
        min_fold_change=min_fold_change,
        detection_rate=detection_rate,
        fold_change=fold_change,
        is_pareto=is_pareto,
        num_threads=num_threads)

    # If `pareto=True`, find the genes on the Pareto front of the two
    # metrics (tie-breaking by higher gene index); these are the marker
    # genes. If `pareto=False`, all genes passing the `min_detection_rate`
    # and `min_fold_change` thresholds are marker genes, so no additional
    # work needed.
    if pareto:
        pareto_front(detection_rate=detection_rate,
                     fold_change=fold_change, is_pareto=is_pareto,
                     num_threads=num_threads)

    # If `cell_types` or `excluded_cell_types` is not `None`, filter out
    # the `null` background cell type we introduced earlier. It is sorted
    # last, but only exists when at least one cell was set to `null` (it
    # is absent when, for example, `cell_types` lists every cell type), so
    # check for it before dropping the final row.
    if has_cell_type_filter and \
            groups[cell_type_column_name].null_count() > 0:
        groups = groups[:-1]
        is_pareto = is_pareto[:-1]
        detection_rate = detection_rate[:-1]
        fold_change = fold_change[:-1]
        num_cell_types -= 1

    # Return a DataFrame of the selected marker genes, or all genes if
    # `all_genes=True`
    cell_type_labels = groups[cell_type_column_name].rename('cell_type')
    genes = self._var[:, 0].rename('gene')
    if all_genes:
        # Repeat each cell type once per gene, and tile the genes once per
        # cell type, to match the row-major raveling of the matrices
        cell_type_labels = \
            pl.select(pl.lit(cell_type_labels).repeat_by(num_genes))\
            .explode('cell_type')\
            .to_series()
        genes = pl.concat([genes] * num_cell_types)
        return pl.DataFrame((
            cell_type_labels, genes,
            pl.Series('marker', is_pareto.ravel()),
            pl.Series('detection_rate', detection_rate.ravel()),
            pl.Series('fold_change', fold_change.ravel())))
    else:
        cell_type_indices, gene_indices = is_pareto.nonzero()
        return pl.DataFrame((
            cell_type_labels[cell_type_indices], genes[gene_indices],
            pl.Series('detection_rate', detection_rate[
                cell_type_indices, gene_indices].ravel()),
            pl.Series('fold_change', fold_change[
                cell_type_indices, gene_indices].ravel())))\
            .select(pl.all().sort_by('fold_change', descending=True)
                    .over('cell_type'))
def plot_markers(self,
                 genes: str | Iterable[str],
                 cell_type_column: SingleCellColumn,
                 filename: str | Path | None = None,
                 /,
                 *,
                 cells_to_plot_column: SingleCellColumn | None = 'passed_QC',
                 color: Literal['expression', 'fold_change'] = 'expression',
                 cell_types: str | Iterable[str] | int | Iterable[int] |
                             None = None,
                 excluded_cell_types: str | Iterable[str] | int |
                                      Iterable[int] | None = None,
                 alphabetical_cell_types: bool = True,
                 figure_kwargs: dict[str, Any] | None = None,
                 colormap: str | 'Colormap' | None = None,
                 infinity_color: Color = 'limegreen',
                 NaN_color: Color = 'lightgray',
                 colorbar: bool = True,
                 colorbar_kwargs: dict[str, Any] | None = None,
                 swap_axes: bool = False,
                 scatter_kwargs: dict[str, Any] | None = None,
                 legend_kwargs: dict[str, Any] | None = None,
                 title: str | None = None,
                 title_kwargs: dict[str, Any] | None = None,
                 xlabel: str | None = None,
                 xlabel_kwargs: dict[str, Any] | None = None,
                 ylabel: str | None = None,
                 ylabel_kwargs: dict[str, Any] | None = None,
                 despine: bool = True,
                 savefig_kwargs: dict[str, Any] | None = None,
                 num_threads: int | np.integer | None = None) -> None:
    """
    Make a dot plot of a set of marker genes of interest across cell types.

    The size of a gene's dot represents its "detection rate" in that cell
    type: the fraction of cells of that type where the gene has non-zero
    count.

    By default (`color='expression'`), the color of a gene's dot represents
    its expression. When `color='fold_change'`, the color instead
    represents the gene's fold change in detection rate between cells of
    that cell type and cells of every other cell type; this is one of two
    metrics (along with the detection rate) that are used to select marker
    genes in `find_markers()`.

    `color='fold_change'` gives the same result whether it is run before or
    after normalization, since it only depends on the pattern of non-zero
    entries in the data, which is unaffected by normalization.

    Unlike the other plotting functions, this is a figure-level rather than
    an axis-level function, and does not take an `axis` argument.

    Args:
        genes: a list of genes to plot: for instance, marker genes found by
               `find_markers()`, or marker genes from the literature
        cell_type_column: a String, Enum, Categorical, or integer column of
                          `obs` containing cell-type labels. Can be a
                          column name, a polars expression, a polars
                          Series, a 1D NumPy array, or a function that
                          takes in this SingleCell dataset and returns a
                          polars Series or 1D NumPy array.
        filename: the file to save to. If `None`, generate the plot but do
                  not save it, which allows it to be shown interactively or
                  modified further before saving.
        cells_to_plot_column: an optional Boolean column of `obs`
                              indicating which cells to plot. Can be a
                              column name, a polars expression, a polars
                              Series, a 1D NumPy array, or a function that
                              takes in this SingleCell dataset and returns
                              a polars Series or 1D NumPy array. Set to
                              `None` to plot all cells.
        color: whether the color of a gene's dot represents its expression
               (`color='expression'`, the default) or its fold change in
               detection rate between cells of that cell type and cells of
               every other cell type (`color='fold_change'`). When
               `color='fold_change'`, the colorbar ticks show the raw fold
               changes, but on a log scale, so that a fold change of 4 is
               twice as red as a fold change of 2, and a fold change of 0.1
               is twice as blue as a fold change of 0.2.
        cell_types: one or more cell types to plot; by default, plots all
                    cell types in `cell_type_column`. Can also be used to
                    change the order in which cell types are displayed,
                    even if plotting all cell types. Mutually exclusive
                    with `excluded_cell_types`.
        excluded_cell_types: one or more cell types to exclude from the
                             plot. Mutually exclusive with `cell_types`.
        alphabetical_cell_types: whether to force the cell types to be
                                 listed in alphabetical order, even when
                                 `cell_type_column` is a Categorical or
                                 Enum column where the categories are in
                                 non-alphabetical order. If
                                 `alphabetical_cell_types=False`, cell
                                 types will appear in the order specified
                                 by the Categorical or Enum column. Has no
                                 effect when `cell_type_column` is a String
                                 column, since cell types will always be
                                 plotted in alphabetical order by default.
                                 `alphabetical_cell_types=False` can only
                                 be specified when `cell_types` is `None`,
                                 since when `cell_types` is specified, it
                                 defines the order of the cell types
                                 regardless.
        figure_kwargs: a dictionary of keyword arguments to be passed to
                       `plt.figure`, such as:
                       - `figsize`: a two-element sequence of the width and
                         height of the figure in inches. The default is a
                         complicated formula based on the number of genes
                         and cell types being plotted, unlike Matplotlib's
                         default of `[6.4, 4.8]`.
                       - `layout`: the layout mechanism used by Matplotlib
                         to avoid overlapping plot elements. Defaults to
                         `'constrained'`, instead of Matplotlib's default
                         of `None`.
        colormap: a string or Colormap object indicating the Matplotlib
                  colormap to use in `ax.scatter()` for representing
                  expression values or (if `color='fold_change'`) fold
                  changes. Defaults to `'Reds'` for expression and
                  `'RdBu_r'` for fold changes.
        infinity_color: when `color='fold_change'`, the color used to plot
                        infinite log-fold changes (i.e. where a gene is
                        only expressed in the cell type where it is a
                        marker). Can be any valid Matplotlib color, like a
                        hex string (e.g. `'#FF0000'`), a named color (e.g.
                        'red'), a 3- or 4-element RGB/RGBA tuple of
                        integers 0-255 or floats 0-1, or a single float 0-1
                        for grayscale.
        NaN_color: when `color='fold_change'`, the color used to plot NaN
                   log-fold changes (i.e. where a gene is not expressed in
                   any cell type). Can be any valid Matplotlib color, like
                   a hex string (e.g. `'#FF0000'`), a named color (e.g.
                   'red'), a 3- or 4-element RGB/RGBA tuple of integers
                   0-255 or floats 0-1, or a single float 0-1 for
                   grayscale.
        colorbar: whether to add a colorbar. When `False`, the
                  detection-rate legend takes up the full height of the
                  panel to the right of the main plot.
        colorbar_kwargs: a dictionary of keyword arguments to be passed to
                         `plt.colorbar()`, such as:
                         - `location`: `'left'`, `'right'`, `'top'`, or
                           `'bottom'`
                         - `orientation`: `'vertical'` or `'horizontal'`
                         - `fraction`: the fraction of the axes to allocate
                           to the colorbar. Defaults to 0.15.
                         - `shrink`: the fraction to multiply the size of
                           the colorbar by. Defaults to 0.5, instead of
                           Matplotlib's default of 1.
                         - `aspect`: the ratio of the colorbar's long to
                           short dimensions. Defaults to 20.
                         - `pad`: the fraction of the axes between the
                           colorbar and the rest of the figure. Defaults to
                           0.01, instead of Matplotlib's default of 0.05 if
                           vertical and 0.15 if horizontal.
                         Can only be specified when `colorbar=True`.
        swap_axes: if `True`, plot genes on the y-axis and cell types on
                   the x-axis, instead of the other way around
        scatter_kwargs: a dictionary of keyword arguments to be passed to
                        `ax.scatter()`, such as:
                        - `rasterized`: whether to convert the scatter plot
                          points to a raster (bitmap) image when saving to
                          a vector format like PDF. Defaults to `False`.
                        - `marker`: the shape to use for plotting each dot
                        - `norm`, `vmin`, and `vmax`: control how the
                          `colormap` maps expression values or fold changes
                          to colors. If neither `vmin` nor `vmax` are
                          specified, the default behavior depends on what
                          is being plotted. When `color='expression'` and
                          all expression values are positive, `vmin` will
                          be set to 0 so that the color scale includes 0.
                          When `color='fold_change'`, `vmin` will be set to
                          `-M` and `vmax` to `M` where `M` is the magnitude
                          of the largest fold change, so that the colors
                          for positive and negative fold changes are
                          symmetrical.
                        - `alpha`: the transparency of each point
                        - `linewidths` and `edgecolors`: the width and
                          color of the borders around each marker. These
                          are absent by default (`linewidths=0`,
                          `edgecolors=(0, 0, 0, 0)`), unlike Matplotlib's
                          default. Both arguments can be either single
                          values or sequences.
                        - `zorder`: the order in which the dots are
                          plotted, with higher values appearing on top of
                          lower ones.
                        Specifying `s`, `c`/`color`, or `cmap` will raise
                        an error, since the size and color of each point
                        are set automatically, and `cmap` conflicts with
                        the `colormap` argument.
        legend_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.legend()` to modify the legend, such as:
                       - `loc`, `bbox_to_anchor`, and `bbox_transform` to
                         set its location. The legend will be placed in its
                         own axis in the top right of the plot, and by
                         default, `loc` is set to `'center'`.
                       - `ncols` to set its number of columns
                       - `prop`, `fontsize`, and `labelcolor` to set its
                         font properties
                       - `facecolor` and `framealpha` to set its background
                         color and transparency
                       - `frameon=True` or `edgecolor` to add or color its
                         border. `frameon` defaults to `False`, instead of
                         Matplotlib's default of `True`.
                       - `title` to modify the legend title. Defaults to
                         `'Detection rate'`.
        title: the title of the plot, or `None` to not add a title
        title_kwargs: a dictionary of keyword arguments to be passed to
                      `ax.set_title()` to control text properties, such as
                      `color` and `size`. Can only be specified when
                      `title` is not `None`.
        xlabel: the x-axis label, or `None` to not label the x-axis
        xlabel_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.set_xlabel()` to control text properties, such
                       as `color` and `size`. Can only be specified when
                       `xlabel` is not `None`.
        ylabel: the y-axis label, or `None` to not label the y-axis
        ylabel_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.set_ylabel()` to control text properties, such
                       as `color` and `size`. Can only be specified when
                       `ylabel` is not `None`.
        despine: whether to remove the top and right spines (borders of the
                 plot area) from the plot
        savefig_kwargs: a dictionary of keyword arguments to be passed to
                        `plt.savefig()`, such as:
                        - `dpi`: defaults to 300 instead of Matplotlib's
                          default of `'figure'` (the figure's own dpi)
                        - `bbox_inches`: the bounding box of the portion of
                          the figure to save; defaults to `'tight'` (crop
                          out any blank borders) instead of Matplotlib's
                          default of `None` (save the entire figure)
                        - `pad_inches`: the number of inches of padding to
                          add on each of the four sides of the figure when
                          saving. Defaults to `'layout'` (use the padding
                          from the constrained layout engine), instead of
                          Matplotlib's default of 0.1.
                        - `transparent`: whether to save with a transparent
                          background; defaults to `True` if saving to a PDF
                          (i.e. when `filename` ends with `'.pdf'`) and
                          `False` otherwise, instead of Matplotlib's
                          default of always being `False`.
                        Can only be specified when `filename` is specified.
        num_threads: the number of threads to use when tabulating each
                     gene's detection rate and fold change. Set
                     `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores.

    Note:
        This function may give an incorrect output if the count matrix
        contains explicit zeros (i.e. if `(sc.X.data == 0).any()`): this is
        not checked for, due to speed considerations. In the unlikely event
        that your dataset contains explicit zeros, remove them by running
        `sc.X.eliminate_zeros()` (an in-place operation) first.

    Note:
        This function may give an incorrect output if the count matrix
        contains negative values: this is not checked for, due to speed
        considerations.
    """
    # Check that `X` is present
    X = self._X
    if X is None:
        error_message = \
            'X is None, so marker gene plotting is not possible'
        raise ValueError(error_message)

    # Import matplotlib; Ctrl+C is ignored for the duration of the import
    # and the default handler is restored afterwards
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        import matplotlib.pyplot as plt
    finally:
        signal.signal(signal.SIGINT, signal.default_int_handler)

    # Check that `self` is QCed
    if not self._uns['QCed']:
        error_message = (
            "uns['QCed'] is False; did you forget to run qc() before "
            "plot_markers()? Set uns['QCed'] = True or run skip_qc() to "
            "bypass this check.")
        raise ValueError(error_message)

    # Get `genes` as a polars Series of the same dtype as `var_names`;
    # uniquify; make sure all its entries are present in `var_names`
    genes = to_tuple_checked(genes, 'genes', str, 'strings')
    genes = pl.Series(genes).unique(maintain_order=True)
    var_names = self._var[:, 0]
    if not genes.is_in(var_names).all():
        if not genes.is_in(var_names).any():
            error_message = \
                'none of the specified genes were found in var_names'
            raise ValueError(error_message)
        else:
            for gene in genes:
                if gene not in var_names:
                    error_message = (
                        f'one of the specified genes, {gene!r}, was not '
                        f'found in var_names')
                    raise ValueError(error_message)
    if var_names.dtype != pl.String:
        genes = genes.cast(var_names.dtype)

    # Get `cells_to_plot_column`, if not `None`
    if cells_to_plot_column is not None:
        cells_to_plot_column = self._get_column(
            'obs', cells_to_plot_column, 'cells_to_plot_column',
            pl.Boolean,
            allow_missing=isinstance(cells_to_plot_column, str) and
                          cells_to_plot_column == 'passed_QC')

    # Check that `color` is `'expression'` or `'fold_change'`
    check_type(color, 'color', str, 'a string')
    if color != 'expression' and color != 'fold_change':
        error_message = "color must be 'expression' or 'fold_change'"
        raise ValueError(error_message)

    # Get the cell-type column
    original_cell_type_column = cell_type_column
    cell_type_column = \
        self._get_column('obs', cell_type_column, 'cell_type_column',
                         (pl.String, pl.Enum, pl.Categorical, 'integer'),
                         QC_column=cells_to_plot_column)
    cell_type_column_name = cell_type_column.name

    # Check that `cell_types` and `excluded_cell_types` are not both
    # specified. If `cell_types` is specified, check it contains only cell
    # type names present in `cell_type_column`, then set non-matching cell
    # types to `null` so that they are treated as a single background cell
    # type. If `excluded_cell_types` is specified, do the opposite.
    cell_type_column = SingleCell._process_cell_types(
        cell_types, excluded_cell_types, cell_type_column)

    # Check that `alphabetical_cell_types` is Boolean, and that if `False`,
    # `cell_types` is `None`
    check_type(alphabetical_cell_types, 'alphabetical_cell_types', bool,
               'Boolean')
    if not alphabetical_cell_types and cell_types is not None:
        error_message = (
            'alphabetical_cell_types=False can only be specified when '
            'cell_types is None')
        raise ValueError(error_message)

    # If `filename` was specified, check that it is a string or
    # `pathlib.Path` and that its base directory exists; if `filename` is
    # `None`, make sure `savefig_kwargs` is also `None`
    if filename is not None:
        check_type(filename, 'filename', (str, Path),
                   'a string or pathlib.Path')
        directory = os.path.dirname(filename)
        if directory and not os.path.isdir(directory):
            error_message = (
                f'{filename} refers to a file in the directory '
                f'{directory!r}, but this directory does not exist')
            raise NotADirectoryError(error_message)
        filename = str(filename)
    elif savefig_kwargs is not None:
        error_message = 'savefig_kwargs must be None when filename is None'
        raise ValueError(error_message)

    # Check that `colormap` is a string in `plt.colormaps`, a Colormap
    # object, or `None`. If `None`, set to `'Reds'` for expression and
    # `'RdBu_r'` for fold changes. If a string, map to a Colormap object
    # via the `plt.colormaps` map.
    if colormap is None:
        colormap = 'RdBu_r' if color == 'fold_change' else 'Reds'
    else:
        check_type(colormap, 'colormap',
                   (str, plt.matplotlib.colors.Colormap),
                   'a string or matplotlib Colormap object')
    if isinstance(colormap, str):
        colormap = plt.colormaps[colormap]

    # If `color='fold_change'`, check that `infinity_color` and `NaN_color`
    # are Matplotlib colors, and convert them to hex. Otherwise, check that
    # they have their default values, since they are unused.
    if color == 'fold_change':
        if not plt.matplotlib.colors.is_color_like(infinity_color):
            error_message = \
                'infinity_color is not a valid Matplotlib color'
            raise ValueError(error_message)
        infinity_color = plt.matplotlib.colors.to_hex(infinity_color)
        if not plt.matplotlib.colors.is_color_like(NaN_color):
            error_message = 'NaN_color is not a valid Matplotlib color'
            raise ValueError(error_message)
        NaN_color = plt.matplotlib.colors.to_hex(NaN_color)
    else:
        if not isinstance(infinity_color, str) or \
                infinity_color != 'limegreen':
            error_message = (
                "infinity_color cannot be specified unless "
                "color='fold_change'")
            raise ValueError(error_message)
        if not isinstance(NaN_color, str) or NaN_color != 'lightgray':
            error_message = \
                "NaN_color cannot be specified unless color='fold_change'"
            raise ValueError(error_message)

    # Check that `colorbar` is Boolean
    check_type(colorbar, 'colorbar', bool, 'Boolean')

    # If `colorbar=False`, check that `colorbar_kwargs` is None
    if not colorbar and colorbar_kwargs is not None:
        error_message = 'colorbar_kwargs must be None when colorbar=False'
        raise ValueError(error_message)

    # Check that `swap_axes` and `despine` are Boolean
    check_type(swap_axes, 'swap_axes', bool, 'Boolean')
    check_type(despine, 'despine', bool, 'Boolean')

    # Check that `title` is a string or `None`; if `None`, check that
    # `title_kwargs` is `None` as well. Ditto for `xlabel` and `ylabel`.
    for arg, arg_name, arg_kwargs in (
            (title, 'title', title_kwargs),
            (xlabel, 'xlabel', xlabel_kwargs),
            (ylabel, 'ylabel', ylabel_kwargs)):
        if arg is not None:
            check_type(arg, arg_name, str, 'a string')
        elif arg_kwargs is not None:
            error_message = \
                f'{arg_name}_kwargs must be None when {arg_name} is None'
            raise ValueError(error_message)

    # For each of the kwargs arguments, if the argument was specified,
    # check that it is a dictionary and that all its keys are strings.
    for kwargs, kwargs_name in ((figure_kwargs, 'figure_kwargs'),
                                (colorbar_kwargs, 'colorbar_kwargs'),
                                (scatter_kwargs, 'scatter_kwargs'),
                                (legend_kwargs, 'legend_kwargs'),
                                (title_kwargs, 'title_kwargs'),
                                (xlabel_kwargs, 'xlabel_kwargs'),
                                (ylabel_kwargs, 'ylabel_kwargs'),
                                (savefig_kwargs, 'savefig_kwargs')):
        if kwargs is not None:
            check_type(kwargs, kwargs_name, dict, 'a dictionary')
            for key in kwargs:
                if not isinstance(key, str):
                    error_message = (
                        f'all keys of {kwargs_name} must be strings, but '
                        f'it contains a key of type '
                        f'{type(key).__name__!r}')
                    raise TypeError(error_message)

    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)

    # Get the indices of the cells of each cell type, ignoring cells where
    # `cells_to_plot_column` is `False` (when it is not `None`). Also get
    # the number of cells of each type.
    groups = (pl.LazyFrame((cell_type_column,))
              .with_columns(
                  _SingleCell_group_indices=pl.int_range(
                      pl.len(), dtype=pl.UInt32))
              if cells_to_plot_column is None else
              pl.LazyFrame((cell_type_column, cells_to_plot_column))
              .with_columns(
                  _SingleCell_group_indices=pl.int_range(
                      pl.len(), dtype=pl.UInt32))
              .filter(cells_to_plot_column.name))\
        .group_by(cell_type_column_name, maintain_order=True)\
        .agg('_SingleCell_group_indices',
             _SingleCell_num_cells=pl.len())\
        .sort(cell_type_column_name, nulls_last=True)\
        .collect()

    # Check that `cell_type_column` contains at least two cell types
    num_cell_types = len(groups)
    if num_cell_types == 1:
        cell_type_column_description = \
            SingleCell._describe_column('cell_type_column',
                                        original_cell_type_column)
        error_message = (
            f'{cell_type_column_description} only contains one unique '
            f'value')
        raise ValueError(error_message)

    # If `cell_types` is not `None`, reorder `groups` to be in the same
    # order as `cell_types`. Otherwise, get the list of cell types from
    # `groups`, sorting alphabetically if `alphabetical_cell_types` and
    # `cell_type_column` is Enum or Categorical.
    if cell_types is not None:
        # NOTE(review): `cell_types` is used here (and for the tick labels
        # below) exactly as the caller passed it; presumably
        # `_process_cell_types()` has already validated it - confirm that
        # a bare string, integers or a one-shot iterator cannot reach
        # this point
        groups = groups.sort(pl.first().cast(pl.Enum(cell_types)),
                             nulls_last=True)
    else:
        if alphabetical_cell_types and (
                cell_type_column.dtype == pl.Enum or
                cell_type_column.dtype == pl.Categorical):
            groups = groups\
                .cast({cell_type_column_name: pl.String})\
                .sort(cell_type_column_name)
        cell_types = groups[cell_type_column_name]

    # Get a cell-type-by-gene matrix of the number of cells of each type
    # with non-zero expression of each gene in `genes`, i.e. the gene's
    # detection count in that cell type. If `color='expression'`, also
    # get a second cell-type-by-gene matrix of the total expression of each
    # gene in each cell type; we will normalize to get the mean later.
    num_genes = len(genes)
    detection_count = \
        np.empty((num_cell_types, num_genes), dtype=np.uint32)
    if isinstance(X, csr_array):
        group_indices = \
            groups['_SingleCell_group_indices'].explode().to_numpy()
        group_ends = \
            groups['_SingleCell_num_cells'].cum_sum().to_numpy()

        # Get an array mapping each gene in `var_names` to its position in
        # `genes` (-1 if missing from `genes`)
        gene_map = var_names\
            .to_frame()\
            .join(genes
                  .to_frame(var_names.name)
                  .with_columns(_SingleCell_index=pl.int_range(
                      pl.len(), dtype=pl.Int32)),
                  on=var_names.name, how='left')\
            .select('_SingleCell_index')\
            .to_series()
        gene_map = gene_map.fill_null(-1).to_numpy()

        if color == 'fold_change':
            groupby_getnnz_csr_for_gene_subset(
                indices=X.indices, indptr=X.indptr,
                group_indices=group_indices, group_ends=group_ends,
                gene_map=gene_map, nnz=detection_count,
                num_threads=num_threads)
        else:
            total_expression = np.empty((num_cell_types, num_genes))
            groupby_getnnz_and_total_csr_for_gene_subset(
                data=X.data, indices=X.indices, indptr=X.indptr,
                group_indices=group_indices, group_ends=group_ends,
                gene_map=gene_map, nnz=detection_count,
                total=total_expression, num_threads=num_threads)
    else:
        # Map every cell (row of `X`) to the row of `groups` it belongs
        # to; cells that belong to no group get -1
        group_map = pl.int_range(X.shape[0], dtype=pl.UInt32, eager=True)\
            .to_frame('_SingleCell_group_indices')\
            .join(groups
                  .select('_SingleCell_group_indices',
                          _SingleCell_index=pl.int_range(pl.len(),
                                                         dtype=pl.Int32))
                  .explode('_SingleCell_group_indices'),
                  on='_SingleCell_group_indices', how='left')\
            ['_SingleCell_index']
        has_missing = group_map.null_count() > 0
        if has_missing:
            group_map = group_map.fill_null(-1)
        group_map = group_map.to_numpy()

        # Get an array mapping each gene in `genes` to its position in
        # `var_names`
        gene_map = genes\
            .to_frame(var_names.name)\
            .join(var_names
                  .to_frame()
                  .with_columns(_SingleCell_index=pl.int_range(
                      pl.len(), dtype=pl.UInt32)),
                  on=var_names.name, how='left')\
            .select('_SingleCell_index')\
            .to_series()\
            .to_numpy()

        if color == 'fold_change':
            groupby_getnnz_csc_for_gene_subset(
                indices=X.indices, indptr=X.indptr, group_map=group_map,
                gene_map=gene_map, has_missing=has_missing,
                nnz=detection_count, num_threads=num_threads)
        else:
            total_expression = np.empty((num_cell_types, num_genes))
            groupby_getnnz_and_total_csc_for_gene_subset(
                data=X.data, indices=X.indices, indptr=X.indptr,
                group_map=group_map, gene_map=gene_map,
                has_missing=has_missing, nnz=detection_count,
                total=total_expression, num_threads=num_threads)

    # For each cell type, calculate the detection rate and (if
    # `color == 'fold_change'`) the fold change of the detection rate
    total_detection_count = detection_count.sum(axis=0, dtype=np.uint32)
    detection_rate = np.empty((num_cell_types, num_genes),
                              dtype=np.float32)
    num_cells_per_cell_type = groups['_SingleCell_num_cells'].to_numpy()
    total_num_cells = num_cells_per_cell_type.sum()
    if color == 'fold_change':
        fold_change = np.empty((num_cell_types, num_genes),
                               dtype=np.float32)
        get_detection_rate_and_fold_change(
            detection_count=detection_count,
            total_detection_count=total_detection_count,
            num_cells_per_cell_type=num_cells_per_cell_type,
            total_num_cells=total_num_cells,
            detection_rate=detection_rate,
            fold_change=fold_change,
            num_threads=num_threads)
    else:
        get_detection_rate(
            detection_count=detection_count,
            total_detection_count=total_detection_count,
            num_cells_per_cell_type=num_cells_per_cell_type,
            total_num_cells=total_num_cells,
            detection_rate=detection_rate,
            num_threads=num_threads)

    # If plotting expression, convert total expression to mean expression
    if color == 'expression':
        mean_expression = \
            total_expression / num_cells_per_cell_type[:, None]

    # If `cell_types` or `excluded_cell_types` were specified and we
    # introduced a `null` background cell type as a result, remove it now
    # (it is always the last row, since `groups` was sorted with
    # `nulls_last=True`)
    if groups[cell_type_column_name][-1] is None:
        detection_rate = detection_rate[:-1]
        if color == 'fold_change':
            fold_change = fold_change[:-1]
        else:
            mean_expression = mean_expression[:-1]
        if excluded_cell_types is not None:
            cell_types = cell_types.drop_nulls()
        num_cell_types -= 1

    # If `swap_axes=True`, swap cell types and genes
    if swap_axes:
        cell_types, genes = genes, cell_types
        num_cell_types, num_genes = num_genes, num_cell_types
        detection_rate = detection_rate.T
        if color == 'fold_change':
            fold_change = fold_change.T
        else:
            mean_expression = mean_expression.T

    # Calculate the range of the legend, and the multiplier to multiply
    # each point's size by
    max_detection_rate = detection_rate.max()
    interval = 0.2 if max_detection_rate > 0.5 else \
        0.1 if max_detection_rate > 0.2 else 0.05
    max_detection_rate = \
        np.ceil(max_detection_rate / interval) * interval
    legend_point_sizes = \
        np.arange(interval, max_detection_rate + interval / 2, interval)
    point_size_multiplier = 180 / max_detection_rate

    # Track the figure we create, so that the exception handler below
    # closes only that figure and never one belonging to the caller
    fig = None
    try:
        # Make the figure, including a separate column on the right for
        # the legend and (if `colorbar=True`) the colorbar. The side
        # column is always reserved, since the legend is always drawn.
        default_figure_kwargs = dict(layout='constrained')
        width_ratio = max(4, 0.2 * num_genes)
        if figure_kwargs is None or 'figsize' not in figure_kwargs:
            gene_multiplier = 0.4 \
                if num_genes < 5 else 0.05 * (num_genes - 5) + 0.4 \
                if num_genes < 10 else 0.04 * (num_genes - 10) + 0.65 \
                if num_genes < 20 else 0.03 * (num_genes - 20) + 1.05
            width = 6.4 / (1 + width_ratio) + \
                6.4 * width_ratio / (1 + width_ratio) * gene_multiplier
            height = max(4.8, 1 + 3.8 * num_cell_types / 20)
            default_figure_kwargs['figsize'] = width, height
        figure_kwargs = default_figure_kwargs | figure_kwargs \
            if figure_kwargs is not None else default_figure_kwargs
        fig = plt.figure(**figure_kwargs)
        if colorbar:
            # Side column is split in two: legend on top, colorbar below
            gs = fig.add_gridspec(2, 2, width_ratios=[width_ratio, 1],
                                  height_ratios=[1, 1])
        else:
            # No colorbar: the legend gets the whole side column. A
            # one-column grid would make `gs[0, 1]` below raise an
            # `IndexError`.
            gs = fig.add_gridspec(1, 2, width_ratios=[width_ratio, 1])

        # Plot the circles; override the defaults for certain keys of
        # `scatter_kwargs`. If neither `vmin` nor `vmax` are specified,
        # set `vmin=0` when `color='expression'` and all expression values
        # are positive to ensure the color scale includes 0, and specify
        # `vmin` and `vmax` to be centered at 0 when `color='fold_change'`
        # so that the colors are symmetrical.
        ax_main = fig.add_subplot(gs[:, 0])  # the main plot spans all rows
        point_size = detection_rate * point_size_multiplier
        x, y = np.meshgrid(range(len(genes)), range(len(cell_types)))
        default_scatter_kwargs = \
            dict(linewidths=0, edgecolors=(0, 0, 0, 0))
        scatter_kwargs = default_scatter_kwargs | scatter_kwargs \
            if scatter_kwargs is not None else default_scatter_kwargs
        if color == 'fold_change':
            with np.errstate(divide='ignore'):
                c = np.log2(fold_change)
            c_finite = c.copy()
            c_finite[~np.isfinite(c_finite)] = np.nan
            if 'vmin' not in scatter_kwargs and \
                    'vmax' not in scatter_kwargs:
                vmax = np.nanmax(np.abs(c_finite))
                vmin = -vmax
                scatter_kwargs['vmax'] = vmax
                scatter_kwargs['vmin'] = vmin
        else:
            c_finite = mean_expression
            if 'vmin' not in scatter_kwargs and \
                    'vmax' not in scatter_kwargs and \
                    mean_expression.min() > 0:
                scatter_kwargs['vmin'] = 0
        scatter = ax_main.scatter(x.ravel(), y.ravel(),
                                  s=point_size.ravel(), c=c_finite.ravel(),
                                  cmap=colormap, **scatter_kwargs)
        ax_main.set_aspect('equal')

        # If `color='fold_change'`, plot infinite and NaN log-fold changes
        # in `infinity_color` and `NaN_color`, respectively. This has to be
        # done separately since we are passing in actual colors, not
        # numbers that are being mapped to colors.
        if color == 'fold_change':
            # The color-mapping arguments do not apply to literal colors.
            # Filter them out of a copy rather than `del`-ing them, since
            # the caller may have specified only one of `vmin`/`vmax`, in
            # which case the other key does not exist.
            overlay_kwargs = {
                key: value for key, value in scatter_kwargs.items()
                if key not in ('vmin', 'vmax', 'norm')}
            infinite_indices = np.where(np.isinf(c))
            if infinite_indices[0].size > 0:
                ax_main.scatter(infinite_indices[1], infinite_indices[0],
                                s=point_size[infinite_indices],
                                c=infinity_color, **overlay_kwargs)
            NaN_indices = np.where(np.isnan(c))
            if NaN_indices[0].size > 0:
                ax_main.scatter(NaN_indices[1], NaN_indices[0],
                                s=point_size[NaN_indices],
                                c=NaN_color, **overlay_kwargs)

        # Set x and y limits
        padding = 0.6
        ax_main.set_xlim((-padding, len(genes) - 1 + padding))
        ax_main.set_ylim((-padding, len(cell_types) - 1 + padding))

        # Invert the y-axis (must be done after setting x and y limits)
        ax_main.invert_yaxis()

        # Add x and y ticks and tick labels
        ax_main.set_xticks(range(len(genes)), genes, rotation=90)
        ax_main.set_yticks(range(len(cell_types)), cell_types)
        if xlabel is not None:
            if xlabel_kwargs is None:
                ax_main.set_xlabel(xlabel)
            else:
                ax_main.set_xlabel(xlabel, **xlabel_kwargs)
        if ylabel is not None:
            if ylabel_kwargs is None:
                ax_main.set_ylabel(ylabel)
            else:
                ax_main.set_ylabel(ylabel, **ylabel_kwargs)

        # Add a legend for detection rate; markers should be at intervals
        # of `X`% (`X`%, `2X`%, `3X`%, ...) up to the maximum detection
        # rate (rounded up to the nearest `X`%). Override the defaults for
        # certain keys of `legend_kwargs`.
        ax_legend = fig.add_subplot(gs[0, 1])
        ax_legend.axis('off')
        legend_elements = [
            plt.Line2D([0], [0], label=f'{100 * size:.0f}%',
                       markersize=np.sqrt(size * point_size_multiplier),
                       marker='o', linestyle='None',
                       markerfacecolor='black', markeredgecolor='None')
            for size in legend_point_sizes]
        default_legend_kwargs = dict(title='Detection rate', loc='center',
                                     frameon=False)
        legend_kwargs = default_legend_kwargs | legend_kwargs \
            if legend_kwargs is not None else default_legend_kwargs
        if legend_elements:
            ax_legend.legend(handles=legend_elements, **legend_kwargs)

        # Add a colorbar for expression, or if `color='fold_change'`, fold
        # change with labels at powers of 2.
        if colorbar:
            default_colorbar_kwargs = dict(shrink=0.5, pad=0.01)
            colorbar_kwargs = default_colorbar_kwargs | colorbar_kwargs \
                if colorbar_kwargs is not None else \
                default_colorbar_kwargs
            ax_colorbar = fig.add_subplot(gs[1, 1])
            cbar = plt.colorbar(scatter, cax=ax_colorbar,
                                **colorbar_kwargs)
            cbar.outline.set_visible(False)
            cbar.ax.set_box_aspect(12)
            cbar.ax.set_title('Fold change of\ndetection rate'
                              if color == 'fold_change' else
                              'Mean\nexpression', size='medium')
            if color == 'fold_change':
                # The plotted values are log2 fold changes; put ticks at
                # integers and label them with the raw fold change
                cbar.ax.yaxis.set_major_locator(
                    plt.MaxNLocator(integer=True))
                cbar.ax.yaxis.set_major_formatter(plt.FuncFormatter(
                    lambda x, pos: f'{2 ** x:.4f}'.rstrip(
                        '0').rstrip('.')))

        # Add the title
        if title is not None:
            if title_kwargs is None:
                ax_main.set_title(title)
            else:
                ax_main.set_title(title, **title_kwargs)

        # Despine, if specified
        if despine:
            spines = ax_main.spines
            spines['top'].set_visible(False)
            spines['right'].set_visible(False)

        # Save, if `filename` is not `None`; override the defaults for
        # certain keys of `savefig_kwargs`
        if filename is not None:
            default_savefig_kwargs = \
                dict(dpi=300, bbox_inches='tight', pad_inches='layout',
                     transparent=filename.endswith('.pdf'))
            savefig_kwargs = default_savefig_kwargs | savefig_kwargs \
                if savefig_kwargs is not None else default_savefig_kwargs
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', UserWarning)
                fig.savefig(filename, **savefig_kwargs)
            plt.close(fig)
    except BaseException:
        # Since we made a new figure, make sure to close it if there's an
        # exception, including `KeyboardInterrupt` (but not if there was
        # no error and `filename` is `None`, in case the user wants to
        # modify it further before saving). If the failure happened before
        # the figure was created, there is nothing of ours to close.
        if fig is not None:
            plt.close(fig)
        raise
[docs] def pacmap(self, *, QC_column: SingleCellColumn | None = 'passed_QC', PC_key: str = 'pca', neighbors_key: str = 'neighbors', distances_key: str = 'distances', embedding_key: str = 'pacmap', num_neighbors: int | np.integer = 10, num_extra_neighbors: int | np.integer = 10, num_mid_near_pairs: int | np.integer = 5, num_further_pairs: int | np.integer = 20, num_iterations: int | np.integer | tuple[int | np.integer, int | np.integer, int | np.integer] = (100, 100, 250), learning_rate: int | float | np.integer | np.floating = 1, seed: int | np.integer = 0, match_parallel: bool = False, overwrite: bool = False, num_threads: int | np.integer | None = None) -> SingleCell: """ Calculate a two-dimensional embedding of this SingleCell dataset suitable for plotting with `plot_embedding()`. Uses [PaCMAP](https://arxiv.org/pdf/2012.04456), a relative of UMAP that captures global structure better. This function is intended to be run after `PCA()` and `neighbors()`. By default, it uses `obsm['pca']` and `obsm['neighbors']` as the inputs to PaCMAP, and stores the output in `obsm['pacmap']` as a `len(obs)` × 2 NumPy array. It can also be run on Harmony embeddings by running `harmonize()` and then specifying `PC_key='harmony'`. Args: QC_column: an optional Boolean column of `obs` indicating which cells passed QC. Can be a column name, a polars expression, a polars Series, a 1D NumPy array, or a function that takes in this SingleCell dataset and returns a polars Series or 1D NumPy array. Set to `None` to include all cells. Cells failing QC will be ignored and have their embeddings set to `NaN`. PC_key: the key of `obsm` containing the principal components calculated with `PCA()`, to use as an input for the embedding calculation. Can also be set to the Harmony embeddings calculated by `harmonize()`, by specifying `PC_key='harmony'`. 
neighbors_key: the key of `obsm` containing the nearest-neighbor indices for each cell, to use as an input for the embedding calculation distances_key: the key of `obsm` containing the squared Euclidean distance to each nearest neighbor in `neighbors_key`, to use as an input for the embedding calculation embedding_key: the key of `obsm` where the embeddings will be stored num_neighbors: the number of nearest neighbors in the original high-dimensional space to consider for each point. Higher values focus on preserving the broader topological structure of local neighborhoods, potentially merging close clusters. Lower values prioritize the very fine-grained local structure, which can reveal intricate patterns but may also fragment larger clusters. num_extra_neighbors: the number of extra nearest neighbors (on top of `num_neighbors`) to search for initially, before pruning to the `num_neighbors` of these `num_neighbors + num_extra_neighbors` cells with the smallest scaled distances. For a pair of cells `i` and `j`, the scaled distance between `i` and `j` is its squared Euclidean distance, divided by `i`'s average Euclidean distance to its 3rd, 4th, and 5th nearest neighbors, divided by `j`'s average Euclidean distance to its 3rd, 4th, and 5th nearest neighbors. Must be a non-negative integer. Defaults to 10, instead of PaCMAP's original default of 50. `neighbors_key` and `distances_key` must contain at least `num_neighbors + num_extra_neighbors` nearest neighbors. num_mid_near_pairs: the number of moderately close cells (not nearest neighbors) to sample for each cell, used to attract distinct local neighborhoods together. Higher values add more "scaffolding" to preserve the large-scale global structure and the relationships between clusters. Lower values reduce this effect, allowing local structures to be placed more independently of one another. 
num_further_pairs: the number of distant cells to sample for each cell, used to create repulsive forces that prevent crowding and shape the final layout. Higher values increase this repulsive force, leading to a more spread-out embedding with clearer separation between clusters. Lower values reduce the force, which can result in a more compact layout where clusters may be closer or overlap. num_iterations: the number of iterations to run PaCMAP for. Can be a length-3 tuple of the number of iterations for each of the 3 stages of optimization, or a single integer of the number of iterations for the third stage (in which case the number of iterations for the first two stages will be set to 100). learning_rate: the learning rate of the Adam optimizer for PaCMAP seed: the random seed to use for PaCMAP match_parallel: if `False`, use a different order of operations for single-threaded PaCMAP. This gives a modest (~15%) boost in single-threaded performance at the cost of no longer exactly matching the embedding produced by the multithreaded version (due to differences in floating-point error arising from the different order of operations). Must be `False` unless `num_threads=1`. overwrite: if `True`, overwrite `embedding_key` if already present in `obsm`, instead of raising an error num_threads: the number of threads to use when running PaCMAP. Set `num_threads=-1` to use all available cores, as determined by `os.cpu_count()`, or leave unset to use `self.num_threads` cores. Does not affect the returned SingleCell dataset's `num_threads`; this will always be the same as the original dataset's `num_threads`. Returns: A new SingleCell dataset with the PaCMAP embedding stored in `obsm[embedding_key]`. Note: PaCMAP's original implementation assumes generic input data, so it initializes the embedding by standardizing the input data, running PCA on it, and taking the first two PCs. 
Because our input data is already PCs (or harmonized PCs), we avoid redundant calculations by omitting this step and directly initializing the embedding with the first two columns of our input data, i.e. the first two PCs. """ # Check that `embedding_key` is a string check_type(embedding_key, 'embedding_key', str, 'a string') # Check that `overwrite` is Boolean check_type(overwrite, 'overwrite', bool, 'Boolean') # Check that `embedding_key` is not already a key in `obsm`, unless # `overwrite=True` if not overwrite and embedding_key in self._obsm: error_message = ( f'embedding_key {embedding_key!r} is already a key of obsm; ' f'did you already run pacmap()? Set overwrite=True to ' f'overwrite.') raise ValueError(error_message) # Get the QC column, if not `None` if QC_column is not None: QC_column = self._get_column( 'obs', QC_column, 'QC_column', pl.Boolean, allow_missing=QC_column == 'passed_QC') # Get PCs, and check that they are float32 and C-contiguous check_type(PC_key, 'PC_key', str, 'a string') if PC_key not in self._obsm: error_message = f'PC_key {PC_key!r} is not a key of obsm' if PC_key == 'pca': error_message += ( '; did you forget to run PCA() (and possibly neighbors()) ' 'before pacmap()?') raise ValueError(error_message) PCs = self._obsm[PC_key] if PCs.dtype != np.float32: error_message = \ f'obsm[{PC_key!r}].dtype is {PCs.dtype!r}, but must be float32' raise TypeError(error_message) if not PCs.flags['C_CONTIGUOUS']: error_message = ( f'obsm[{PC_key!r}] is not C-contiguous; make it C-contiguous ' f'with pipe_obsm_key({PC_key!r}, np.ascontiguousarray)') raise ValueError(error_message) # Get the nearest-neighbor indices and distances, and check that they # are uint32 and float32, respectively, C-contiguous, and have the same # width check_type(neighbors_key, 'neighbors_key', str, 'a string') if neighbors_key not in self._obsm: error_message = \ f'neighbors_key {neighbors_key!r} is not a key of obsm' if neighbors_key == 'neighbors': error_message += ( '; 
did you forget to run neighbors() before pacmap()?') raise ValueError(error_message) neighbors = self._obsm[neighbors_key] if neighbors.dtype != np.uint32: error_message = ( f'obsm[{neighbors_key!r}] must have uint32 data type, but ' f'has data type {str(neighbors.dtype)!r}') raise TypeError(error_message) if not neighbors.flags['C_CONTIGUOUS']: error_message = ( f'obsm[{neighbors_key!r}] is not C-contiguous; make it ' f'C-contiguous with ' f'pipe_obsm_key({neighbors_key!r}, np.ascontiguousarray)') raise ValueError(error_message) check_type(distances_key, 'distances_key', str, 'a string') if distances_key not in self._obsm: error_message = \ f'distances_key {distances_key!r} is not a key of obsm' if distances_key == 'distances': error_message += ( '; did you forget to run neighbors() before pacmap()?') raise ValueError(error_message) distances = self._obsm[distances_key] if distances.dtype != np.float32: error_message = ( f'obsm[{distances_key!r}] must have float32 data type, but ' f'has data type {str(distances.dtype)!r}') raise TypeError(error_message) if not distances.flags['C_CONTIGUOUS']: error_message = ( f'obsm[{distances_key!r}] is not C-contiguous; make it ' f'C-contiguous with ' f'pipe_obsm_key({distances_key!r}, np.ascontiguousarray)') raise ValueError(error_message) if neighbors.shape[1] != distances.shape[1]: error_message = ( f'obsm[{neighbors_key!r}] and obsm[{distances_key!r}] have ' f'different numbers of columns ({neighbors.shape[1]:,} vs ' f'{distances.shape[1]:,}') raise ValueError(error_message) # Check that `num_extra_neighbors` is ≥ 0 check_type(num_extra_neighbors, 'num_extra_neighbors', int, 'a non-negative integer') check_bounds(num_extra_neighbors, 'num_extra_neighbors', 0) # Check that `num_iterations` is a positive integer or length-3 tuple # thereof check_type(num_iterations, 'num_iterations', (int, tuple), 'a positive integer or length-3 tuple of positive ' 'integers') if isinstance(num_iterations, tuple): if len(num_iterations) != 3: 
error_message = ( f'num_iterations must be a positive integer or ' f'length-3 tuple of positive integers, but has length ' f'{len(num_iterations):,}') raise ValueError(error_message) for step, step_num_iterations in enumerate(num_iterations): check_type(step_num_iterations, f'num_iterations[{step!r}]', int, 'a positive integer') check_bounds(step_num_iterations, f'num_iterations[{step!r}]', 1) num_phase_1_iterations, num_phase_2_iterations, \ num_phase_3_iterations = num_iterations else: check_bounds(num_iterations, 'num_iterations', 1) num_phase_1_iterations = num_phase_2_iterations = 100 num_phase_3_iterations = num_iterations # Check that `learning_rate` is a positive floating-point number check_type(learning_rate, 'learning_rate', (int, float), 'a positive number') check_bounds(learning_rate, 'learning_rate', 0, left_open=True) # Check that `seed` is an integer check_type(seed, 'seed', int, 'an integer') # Check that `num_threads` is a positive integer, -1 or `None`; if # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()` num_threads = self._process_num_threads(num_threads) # Check that `match_parallel` is Boolean, and `False` unless # `num_threads=1` check_type(match_parallel, 'match_parallel', bool, 'Boolean') if match_parallel and num_threads != 1: error_message = \ 'match_parallel must be False unless num_threads is 1' raise ValueError(error_message) # Subset PCs and nearest-neighbor indices to QCed cells only, if # `QC_column` is not `None` if QC_column is not None: QC_column_NumPy = QC_column.to_numpy() if num_threads == 1: PCs = PCs[QC_column_NumPy] neighbors = neighbors[QC_column_NumPy] distances = distances[QC_column_NumPy] else: indices = np.flatnonzero(QC_column_NumPy) PCs = parallel_subset_2d(PCs, indices, num_threads) neighbors = parallel_subset_2d(neighbors, indices, num_threads) distances = parallel_subset_2d(distances, indices, num_threads) # Check that there are at least 7 cells (since # `sample_mid_near_pairs() requires 6 
other cells) num_cells = PCs.shape[0] if num_cells < 7: error_message = ( f'there are fewer than 7 cells, so the embedding cannot be ' f'calculated') raise ValueError(error_message) # Check that `num_neighbors` is between 1 and `num_cells - 1` check_type(num_neighbors, 'num_neighbors', int, 'a positive integer') if not 1 <= num_neighbors < num_cells: error_message = ( f'num_neighbors is {num_neighbors:,}, but must be ≥ 1 and ' f'less than the number of cells ({num_cells:,})') raise ValueError(error_message) # Check that `num_mid_near_pairs` and `num_further_pairs` are between 1 # and `num_cells` for variable, variable_name in ( (num_mid_near_pairs, 'num_mid_near_pairs'), (num_further_pairs, 'num_further_pairs')): check_type(variable, variable_name, int, 'a positive integer') if not 1 <= variable <= num_cells: error_message = ( f'{variable_name} is {variable:,}, but must be ≥ 1 and ≤ ' f'the number of cells ({num_cells:,})') raise ValueError(error_message) # Check that there are at least `num_neighbors + num_further_pairs + 1` # cells (since `sample_further_pairs()` requires # `num_neighbors + num_further_pairs` other cells) if num_cells < num_neighbors + num_further_pairs + 1: error_message = ( f'there are fewer than ' f'{num_neighbors + num_further_pairs + 1} (num_neighbors + ' f'num_further_pairs + 1) cells, so the embedding cannot be ' f'calculated') raise ValueError(error_message) # Define `num_total_neighbors` as `num_neighbors + num_extra_neighbors` num_total_neighbors = num_neighbors + num_extra_neighbors # Check that `num_total_neighbors` is less than `num_cells` if num_total_neighbors >= num_cells: error_message = ( f'num_neighbors ({num_neighbors:,}) + num_extra_neighbors ' f'({num_extra_neighbors:,}) is {num_total_neighbors:,}, but ' f'must be less than the number of cells ({num_cells:,})') raise ValueError(error_message) # Check that `neighbors` and `distances` contain at most # `num_total_neighbors` nearest neighbors if num_total_neighbors > 
neighbors.shape[1]: error_message = ( f'num_neighbors ({num_neighbors:,}) + num_extra_neighbors ' f'({num_extra_neighbors:,}) is {num_total_neighbors:,}, but ' f'obsm[{neighbors_key!r}] has only {neighbors.shape[1]} ' f'columns') raise ValueError(error_message) if num_total_neighbors > distances.shape[1]: error_message = ( f'num_neighbors ({num_neighbors:,}) + num_extra_neighbors ' f'({num_extra_neighbors:,}) is {num_total_neighbors:,}, but ' f'obsm[{distances_key!r}] has only {distances.shape[1]} ' f'columns') raise ValueError(error_message) # Run PaCMAP if num_threads == 1: embedding = np.empty((num_cells, 2), dtype=np.float32) momentum = np.empty((num_cells, 2), dtype=np.float32) velocity = np.empty((num_cells, 2), dtype=np.float32) gradients = np.empty((num_cells, 2), dtype=np.float32) average_distances = np.empty(num_cells, dtype=np.float32) neighbor_pairs = np.empty((num_cells, num_neighbors), dtype=np.uint32) mid_near_pairs = np.empty((num_cells, num_mid_near_pairs), dtype=np.uint32) further_pairs = np.empty((num_cells, num_further_pairs), dtype=np.uint32) if match_parallel: neighbor_pair_indices = \ np.empty(2 * num_cells * num_neighbors, dtype=np.uint32) neighbor_pair_indptr = np.empty(num_cells + 1, dtype=np.uint32) mid_near_pair_indices = np.empty( 2 * num_cells * num_mid_near_pairs, dtype=np.uint32) mid_near_pair_indptr = np.empty(num_cells + 1, dtype=np.uint32) further_pair_indices = np.empty( 2 * num_cells * num_further_pairs, dtype=np.uint32) further_pair_indptr = np.empty(num_cells + 1, dtype=np.uint32) else: neighbor_pair_indices = np.array([], dtype=np.uint32) neighbor_pair_indptr = np.array([], dtype=np.uint32) mid_near_pair_indices = np.array([], dtype=np.uint32) mid_near_pair_indptr = np.array([], dtype=np.uint32) further_pair_indices = np.array([], dtype=np.uint32) further_pair_indptr = np.array([], dtype=np.uint32) else: embedding = numa_zeros((num_cells, 2), dtype=np.float32) momentum = numa_zeros((num_cells, 2), dtype=np.float32) velocity = 
numa_zeros((num_cells, 2), dtype=np.float32) gradients = numa_zeros((num_cells, 2), dtype=np.float32) average_distances = numa_zeros(num_cells, dtype=np.float32) neighbor_pairs = numa_zeros((num_cells, num_neighbors), dtype=np.uint32) mid_near_pairs = numa_zeros((num_cells, num_mid_near_pairs), dtype=np.uint32) further_pairs = numa_zeros((num_cells, num_further_pairs), dtype=np.uint32) neighbor_pair_indices = \ numa_zeros(2 * num_cells * num_neighbors, dtype=np.uint32) neighbor_pair_indptr = numa_zeros(num_cells + 1, dtype=np.uint32) mid_near_pair_indices = numa_zeros( 2 * num_cells * num_mid_near_pairs, dtype=np.uint32) mid_near_pair_indptr = numa_zeros(num_cells + 1, dtype=np.uint32) further_pair_indices = numa_zeros( 2 * num_cells * num_further_pairs, dtype=np.uint32) further_pair_indptr = numa_zeros(num_cells + 1, dtype=np.uint32) pacmap(PCs, embedding, momentum, velocity, gradients, average_distances, neighbor_pairs, mid_near_pairs, further_pairs, neighbor_pair_indices, neighbor_pair_indptr, mid_near_pair_indices, mid_near_pair_indptr, further_pair_indices, further_pair_indptr, neighbors, distances, num_neighbors, num_extra_neighbors, num_mid_near_pairs, num_further_pairs, num_phase_1_iterations, num_phase_2_iterations, num_phase_3_iterations, learning_rate, seed, match_parallel, neighbors_key, num_threads) # If `QC_column` was specified, back-project from QCed cells to all # cells, filling with `NaN` if QC_column is not None: embedding_QCed = embedding embedding = np.full((len(self), embedding_QCed.shape[1]), np.nan, dtype=np.float32) embedding[QC_column_NumPy] = embedding_QCed return SingleCell(X=self._X, obs=self._obs, var=self._var, obsm=self._obsm | {embedding_key: embedding}, varm=self._varm, obsp=self._obsp, varp=self._varp, uns=self._uns, num_threads=self._num_threads)
def localmap(self,
             *,
             QC_column: SingleCellColumn | None = 'passed_QC',
             PC_key: str = 'pca',
             neighbors_key: str = 'neighbors',
             distances_key: str = 'distances',
             embedding_key: str = 'localmap',
             num_neighbors: int | np.integer = 10,
             num_extra_neighbors: int | np.integer = 10,
             num_mid_near_pairs: int | np.integer = 5,
             num_further_pairs: int | np.integer = 20,
             num_iterations: int | np.integer | tuple[
                 int | np.integer, int | np.integer,
                 int | np.integer] = (100, 100, 250),
             learning_rate: int | float | np.integer | np.floating = 1,
             max_distance: int | float | np.integer | np.floating = 10,
             seed: int | np.integer = 0,
             match_parallel: bool = False,
             overwrite: bool = False,
             num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Calculate a two-dimensional embedding of this SingleCell dataset
    suitable for plotting with `plot_embedding()`.

    Uses [LocalMAP](https://arxiv.org/abs/2412.15426), a relative of UMAP
    that captures global structure better.

    This function is intended to be run after `PCA()` and `neighbors()`.
    By default, it uses `obsm['pca']`, `obsm['neighbors']`, and
    `obsm['distances']` as the inputs to LocalMAP, and stores the output
    in `obsm['localmap']` as a `len(obs)` × 2 NumPy array. It can also be
    run on Harmony embeddings by running `harmonize()` and then specifying
    `PC_key='harmony'`.

    Args:
        QC_column: an optional Boolean column of `obs` indicating which
                   cells passed QC. Can be a column name, a polars
                   expression, a polars Series, a 1D NumPy array, or a
                   function that takes in this SingleCell dataset and
                   returns a polars Series or 1D NumPy array. Set to
                   `None` to include all cells. Cells failing QC will be
                   ignored and have their embeddings set to `NaN`.
        PC_key: the key of `obsm` containing the principal components
                calculated with `PCA()`, to use as an input for the
                embedding calculation. Can also be set to the Harmony
                embeddings calculated by `harmonize()`, by specifying
                `PC_key='harmony'`. Must be a float32, C-contiguous
                array.
        neighbors_key: the key of `obsm` containing the nearest-neighbor
                       indices for each cell, to use as an input for the
                       embedding calculation. Must be a uint32,
                       C-contiguous array.
        distances_key: the key of `obsm` containing the squared Euclidean
                       distance to each nearest neighbor in
                       `neighbors_key`, to use as an input for the
                       embedding calculation. Must be a float32,
                       C-contiguous array with the same number of columns
                       as `obsm[neighbors_key]`.
        embedding_key: the key of `obsm` where the embeddings will be
                       stored
        num_neighbors: the number of nearest neighbors in the original
                       high-dimensional space to consider for each point.
                       Higher values focus on preserving the broader
                       topological structure of local neighborhoods,
                       potentially merging close clusters. Lower values
                       prioritize the very fine-grained local structure,
                       which can reveal intricate patterns but may also
                       fragment larger clusters.
        num_extra_neighbors: the number of extra nearest neighbors (on top
                             of `num_neighbors`) to search for initially,
                             before pruning to the `num_neighbors` of
                             these `num_neighbors + num_extra_neighbors`
                             cells with the smallest scaled distances. For
                             a pair of cells `i` and `j`, the scaled
                             distance between `i` and `j` is its squared
                             Euclidean distance, divided by `i`'s average
                             Euclidean distance to its 3rd, 4th, and 5th
                             nearest neighbors, divided by `j`'s average
                             Euclidean distance to its 3rd, 4th, and 5th
                             nearest neighbors. Must be a non-negative
                             integer. Defaults to 10, instead of
                             LocalMAP's original default of 50.
                             `neighbors_key` and `distances_key` must
                             contain at least
                             `num_neighbors + num_extra_neighbors` nearest
                             neighbors.
        num_mid_near_pairs: the number of moderately close cells (not
                            nearest neighbors) to sample for each cell,
                            used to attract distinct local neighborhoods
                            together. Higher values add more "scaffolding"
                            to preserve the large-scale global structure
                            and the relationships between clusters. Lower
                            values reduce this effect, allowing local
                            structures to be placed more independently of
                            one another.
        num_further_pairs: the number of distant cells to sample for each
                           cell, used to create repulsive forces that
                           prevent crowding and shape the final layout.
                           Higher values increase this repulsive force,
                           leading to a more spread-out embedding with
                           clearer separation between clusters. Lower
                           values reduce the force, which can result in a
                           more compact layout where clusters may be
                           closer or overlap.
        num_iterations: the number of iterations to run LocalMAP for. Can
                        be a length-3 tuple of the number of iterations
                        for each of the 3 stages of optimization, or a
                        single integer of the number of iterations for the
                        third stage (in which case the number of
                        iterations for the first two stages will be set to
                        100).
        learning_rate: the learning rate of the Adam optimizer for
                       LocalMAP
        max_distance: the distance cutoff (in the embedding space) above
                      which cells will not be considered as further pair
                      candidates during the final stage of optimization,
                      also used to define a scaling factor for the
                      nearest-neighbor gradients during this stage. This
                      parameter must be set near its default value of 10
                      to produce a good-quality embedding.
        seed: the random seed to use for LocalMAP
        match_parallel: if `False`, use a different order of operations
                        for single-threaded LocalMAP. This gives a modest
                        (~15%) boost in single-threaded performance at the
                        cost of no longer exactly matching the embedding
                        produced by the multithreaded version (due to
                        differences in floating-point error arising from
                        the different order of operations). Must be
                        `False` unless `num_threads=1`.
        overwrite: if `True`, overwrite `embedding_key` if already present
                   in `obsm`, instead of raising an error
        num_threads: the number of threads to use when running LocalMAP.
                     Set `num_threads=-1` to use all available cores, as
                     determined by `os.cpu_count()`, or leave unset to use
                     `self.num_threads` cores. Does not affect the
                     returned SingleCell dataset's `num_threads`; this
                     will always be the same as the original dataset's
                     `num_threads`.

    Returns:
        A new SingleCell dataset with the LocalMAP embedding stored in
        `obsm[embedding_key]`.

    Raises:
        ValueError: if `embedding_key` is already a key of `obsm` and
                    `overwrite=False`; if `PC_key`, `neighbors_key` or
                    `distances_key` is not a key of `obsm`; if any of
                    these arrays is not C-contiguous; if the neighbor and
                    distance arrays have different numbers of columns; if
                    `num_iterations` is a tuple whose length is not 3; if
                    `match_parallel=True` with `num_threads != 1`; or if
                    there are too few cells or too few stored neighbors
                    for the requested `num_neighbors`,
                    `num_extra_neighbors`, `num_mid_near_pairs` and
                    `num_further_pairs`.
        TypeError: if `obsm[PC_key]` or `obsm[distances_key]` is not
                   float32, or `obsm[neighbors_key]` is not uint32.

    Note:
        LocalMAP's original implementation assumes generic input data, so
        it initializes the embedding by standardizing the input data,
        running PCA on it, and taking the first two PCs. Because our input
        data is already PCs (or harmonized PCs), we avoid redundant
        calculations by omitting this step and directly initializing the
        embedding with the first two columns of our input data, i.e. the
        first two PCs.
    """
    # Check that `embedding_key` is a string
    check_type(embedding_key, 'embedding_key', str, 'a string')

    # Check that `overwrite` is Boolean
    check_type(overwrite, 'overwrite', bool, 'Boolean')

    # Check that `embedding_key` is not already a key in `obsm`, unless
    # `overwrite=True`
    if not overwrite and embedding_key in self._obsm:
        error_message = (
            f'embedding_key {embedding_key!r} is already a key of obsm; '
            f'did you already run localmap()? Set overwrite=True to '
            f'overwrite.')
        raise ValueError(error_message)

    # Get the QC column, if not `None`. A missing column is tolerated only
    # for the default name 'passed_QC'.
    # NOTE(review): presumably `_get_column()` then returns `None`, which
    # is why `QC_column is not None` is re-tested below - confirm.
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')

    # Get PCs, and check that they are float32 and C-contiguous
    check_type(PC_key, 'PC_key', str, 'a string')
    if PC_key not in self._obsm:
        error_message = f'PC_key {PC_key!r} is not a key of obsm'
        if PC_key == 'pca':
            error_message += (
                '; did you forget to run PCA() (and possibly neighbors()) '
                'before localmap()?')
        raise ValueError(error_message)
    PCs = self._obsm[PC_key]
    if PCs.dtype != np.float32:
        error_message = \
            f'obsm[{PC_key!r}].dtype is {PCs.dtype!r}, but must be float32'
        raise TypeError(error_message)
    if not PCs.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{PC_key!r}] is not C-contiguous; make it C-contiguous '
            f'with pipe_obsm_key({PC_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)

    # Get the nearest-neighbor indices and distances, and check that they
    # are uint32 and float32, respectively, C-contiguous, and have the same
    # width
    check_type(neighbors_key, 'neighbors_key', str, 'a string')
    if neighbors_key not in self._obsm:
        error_message = \
            f'neighbors_key {neighbors_key!r} is not a key of obsm'
        if neighbors_key == 'neighbors':
            error_message += (
                '; did you forget to run neighbors() before localmap()?')
        raise ValueError(error_message)
    neighbors = self._obsm[neighbors_key]
    if neighbors.dtype != np.uint32:
        error_message = (
            f'obsm[{neighbors_key!r}] must have uint32 data type, but '
            f'has data type {str(neighbors.dtype)!r}')
        raise TypeError(error_message)
    if not neighbors.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{neighbors_key!r}] is not C-contiguous; make it '
            f'C-contiguous with '
            f'pipe_obsm_key({neighbors_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)
    check_type(distances_key, 'distances_key', str, 'a string')
    if distances_key not in self._obsm:
        error_message = \
            f'distances_key {distances_key!r} is not a key of obsm'
        if distances_key == 'distances':
            error_message += (
                '; did you forget to run neighbors() before localmap()?')
        raise ValueError(error_message)
    distances = self._obsm[distances_key]
    if distances.dtype != np.float32:
        error_message = (
            f'obsm[{distances_key!r}] must have float32 data type, but '
            f'has data type {str(distances.dtype)!r}')
        raise TypeError(error_message)
    if not distances.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{distances_key!r}] is not C-contiguous; make it '
            f'C-contiguous with '
            f'pipe_obsm_key({distances_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)
    if neighbors.shape[1] != distances.shape[1]:
        # The closing parenthesis was previously missing from this message
        error_message = (
            f'obsm[{neighbors_key!r}] and obsm[{distances_key!r}] have '
            f'different numbers of columns ({neighbors.shape[1]:,} vs '
            f'{distances.shape[1]:,})')
        raise ValueError(error_message)

    # Check that `num_extra_neighbors` is ≥ 0
    check_type(num_extra_neighbors, 'num_extra_neighbors', int,
               'a non-negative integer')
    check_bounds(num_extra_neighbors, 'num_extra_neighbors', 0)

    # Check that `num_iterations` is a positive integer or length-3 tuple
    # thereof
    check_type(num_iterations, 'num_iterations', (int, tuple),
               'a positive integer or length-3 tuple of positive '
               'integers')
    if isinstance(num_iterations, tuple):
        if len(num_iterations) != 3:
            error_message = (
                f'num_iterations must be a positive integer or '
                f'length-3 tuple of positive integers, but has length '
                f'{len(num_iterations):,}')
            raise ValueError(error_message)
        for step, step_num_iterations in enumerate(num_iterations):
            check_type(step_num_iterations, f'num_iterations[{step!r}]',
                       int, 'a positive integer')
            check_bounds(step_num_iterations,
                         f'num_iterations[{step!r}]', 1)
        num_phase_1_iterations, num_phase_2_iterations, \
            num_phase_3_iterations = num_iterations
    else:
        # A single integer only sets the third phase; the first two phases
        # run for 100 iterations each
        check_bounds(num_iterations, 'num_iterations', 1)
        num_phase_1_iterations = num_phase_2_iterations = 100
        num_phase_3_iterations = num_iterations

    # Check that `learning_rate` and `max_distance` are positive numbers
    check_type(learning_rate, 'learning_rate', (int, float),
               'a positive number')
    check_bounds(learning_rate, 'learning_rate', 0, left_open=True)
    check_type(max_distance, 'max_distance', (int, float),
               'a positive number')
    check_bounds(max_distance, 'max_distance', 0, left_open=True)

    # Check that `seed` is an integer
    check_type(seed, 'seed', int, 'an integer')

    # Check that `num_threads` is a positive integer, -1 or `None`; if
    # `None`, set to `self.num_threads`, and if -1, set to `os.cpu_count()`
    num_threads = self._process_num_threads(num_threads)

    # Check that `match_parallel` is Boolean, and `False` unless
    # `num_threads=1`
    check_type(match_parallel, 'match_parallel', bool, 'Boolean')
    if match_parallel and num_threads != 1:
        error_message = \
            'match_parallel must be False unless num_threads is 1'
        raise ValueError(error_message)

    # Subset PCs, nearest-neighbor indices and distances to QCed cells
    # only, if `QC_column` is not `None`
    if QC_column is not None:
        QC_column_NumPy = QC_column.to_numpy()
        if num_threads == 1:
            PCs = PCs[QC_column_NumPy]
            neighbors = neighbors[QC_column_NumPy]
            distances = distances[QC_column_NumPy]
        else:
            indices = np.flatnonzero(QC_column_NumPy)
            PCs = parallel_subset_2d(PCs, indices, num_threads)
            neighbors = parallel_subset_2d(neighbors, indices, num_threads)
            distances = parallel_subset_2d(distances, indices, num_threads)

    # Check that there are at least 7 cells (since
    # `sample_mid_near_pairs()` requires 6 other cells)
    num_cells = PCs.shape[0]
    if num_cells < 7:
        error_message = (
            f'there are fewer than 7 cells, so the embedding cannot be '
            f'calculated')
        raise ValueError(error_message)

    # Check that `num_neighbors` is between 1 and `num_cells - 1`
    check_type(num_neighbors, 'num_neighbors', int, 'a positive integer')
    if not 1 <= num_neighbors < num_cells:
        error_message = (
            f'num_neighbors is {num_neighbors:,}, but must be ≥ 1 and '
            f'less than the number of cells ({num_cells:,})')
        raise ValueError(error_message)

    # Check that `num_mid_near_pairs` and `num_further_pairs` are between 1
    # and `num_cells`
    for variable, variable_name in (
            (num_mid_near_pairs, 'num_mid_near_pairs'),
            (num_further_pairs, 'num_further_pairs')):
        check_type(variable, variable_name, int, 'a positive integer')
        if not 1 <= variable <= num_cells:
            error_message = (
                f'{variable_name} is {variable:,}, but must be ≥ 1 and ≤ '
                f'the number of cells ({num_cells:,})')
            raise ValueError(error_message)

    # Check that there are at least `num_neighbors + num_further_pairs + 1`
    # cells (since `sample_further_pairs()` requires
    # `num_neighbors + num_further_pairs` other cells)
    if num_cells < num_neighbors + num_further_pairs + 1:
        error_message = (
            f'there are fewer than '
            f'{num_neighbors + num_further_pairs + 1} (num_neighbors + '
            f'num_further_pairs + 1) cells, so the embedding cannot be '
            f'calculated')
        raise ValueError(error_message)

    # Define `num_total_neighbors` as `num_neighbors + num_extra_neighbors`
    num_total_neighbors = num_neighbors + num_extra_neighbors

    # Check that `num_total_neighbors` is less than `num_cells`
    if num_total_neighbors >= num_cells:
        error_message = (
            f'num_neighbors ({num_neighbors:,}) + num_extra_neighbors '
            f'({num_extra_neighbors:,}) is {num_total_neighbors:,}, but '
            f'must be less than the number of cells ({num_cells:,})')
        raise ValueError(error_message)

    # Check that `neighbors` and `distances` contain at least
    # `num_total_neighbors` nearest neighbors
    if num_total_neighbors > neighbors.shape[1]:
        error_message = (
            f'num_neighbors ({num_neighbors:,}) + num_extra_neighbors '
            f'({num_extra_neighbors:,}) is {num_total_neighbors:,}, but '
            f'obsm[{neighbors_key!r}] has only {neighbors.shape[1]:,} '
            f'columns')
        raise ValueError(error_message)
    if num_total_neighbors > distances.shape[1]:
        error_message = (
            f'num_neighbors ({num_neighbors:,}) + num_extra_neighbors '
            f'({num_extra_neighbors:,}) is {num_total_neighbors:,}, but '
            f'obsm[{distances_key!r}] has only {distances.shape[1]:,} '
            f'columns')
        raise ValueError(error_message)

    # Run LocalMAP. Work buffers are allocated here and handed to the
    # compiled routine: plain uninitialized NumPy arrays when
    # single-threaded, `numa_zeros()` arrays when multithreaded.
    if num_threads == 1:
        embedding = np.empty((num_cells, 2), dtype=np.float32)
        momentum = np.empty((num_cells, 2), dtype=np.float32)
        velocity = np.empty((num_cells, 2), dtype=np.float32)
        gradients = np.empty((num_cells, 2), dtype=np.float32)
        average_distances = np.empty(num_cells, dtype=np.float32)
        neighbor_pairs = np.empty((num_cells, num_neighbors),
                                  dtype=np.uint32)
        mid_near_pairs = np.empty((num_cells, num_mid_near_pairs),
                                  dtype=np.uint32)
        further_pairs = np.empty((num_cells, num_further_pairs),
                                 dtype=np.uint32)
        if match_parallel:
            neighbor_pair_indices = \
                np.empty(2 * num_cells * num_neighbors, dtype=np.uint32)
            neighbor_pair_indptr = np.empty(num_cells + 1, dtype=np.uint32)
            mid_near_pair_indices = np.empty(
                2 * num_cells * num_mid_near_pairs, dtype=np.uint32)
            mid_near_pair_indptr = np.empty(num_cells + 1, dtype=np.uint32)
            further_pair_indices = np.empty(
                2 * num_cells * num_further_pairs, dtype=np.uint32)
            further_pair_indptr = np.empty(num_cells + 1, dtype=np.uint32)
        else:
            # The pair index/indptr buffers are passed as empty
            # placeholders when not matching the parallel order of
            # operations
            neighbor_pair_indices = np.array([], dtype=np.uint32)
            neighbor_pair_indptr = np.array([], dtype=np.uint32)
            mid_near_pair_indices = np.array([], dtype=np.uint32)
            mid_near_pair_indptr = np.array([], dtype=np.uint32)
            further_pair_indices = np.array([], dtype=np.uint32)
            further_pair_indptr = np.array([], dtype=np.uint32)
    else:
        embedding = numa_zeros((num_cells, 2), dtype=np.float32)
        momentum = numa_zeros((num_cells, 2), dtype=np.float32)
        velocity = numa_zeros((num_cells, 2), dtype=np.float32)
        gradients = numa_zeros((num_cells, 2), dtype=np.float32)
        average_distances = numa_zeros(num_cells, dtype=np.float32)
        neighbor_pairs = numa_zeros((num_cells, num_neighbors),
                                    dtype=np.uint32)
        mid_near_pairs = numa_zeros((num_cells, num_mid_near_pairs),
                                    dtype=np.uint32)
        further_pairs = numa_zeros((num_cells, num_further_pairs),
                                   dtype=np.uint32)
        neighbor_pair_indices = \
            numa_zeros(2 * num_cells * num_neighbors, dtype=np.uint32)
        neighbor_pair_indptr = numa_zeros(num_cells + 1, dtype=np.uint32)
        mid_near_pair_indices = numa_zeros(
            2 * num_cells * num_mid_near_pairs, dtype=np.uint32)
        mid_near_pair_indptr = numa_zeros(num_cells + 1, dtype=np.uint32)
        further_pair_indices = numa_zeros(
            2 * num_cells * num_further_pairs, dtype=np.uint32)
        further_pair_indptr = numa_zeros(num_cells + 1, dtype=np.uint32)
    # The bare name `localmap` resolves to the module-level routine, not to
    # this method: class attributes are not in scope inside method bodies.
    # NOTE(review): presumably the compiled function registered through
    # `import_cython()` under 'embed' - confirm.
    localmap(PCs, embedding, momentum, velocity, gradients,
             average_distances, neighbor_pairs, mid_near_pairs,
             further_pairs, neighbor_pair_indices, neighbor_pair_indptr,
             mid_near_pair_indices, mid_near_pair_indptr,
             further_pair_indices, further_pair_indptr, neighbors,
             distances, num_neighbors, num_extra_neighbors,
             num_mid_near_pairs, num_further_pairs,
             num_phase_1_iterations, num_phase_2_iterations,
             num_phase_3_iterations, learning_rate, max_distance, seed,
             match_parallel, neighbors_key, num_threads)

    # If `QC_column` was specified, back-project from QCed cells to all
    # cells, filling with `NaN`
    if QC_column is not None:
        embedding_QCed = embedding
        embedding = np.full((len(self), embedding_QCed.shape[1]), np.nan,
                            dtype=np.float32)
        embedding[QC_column_NumPy] = embedding_QCed
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm | {embedding_key: embedding},
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def umap(self,
         *,
         QC_column: SingleCellColumn | None = 'passed_QC',
         PC_key: str = 'pca',
         neighbors_key: str = 'neighbors',
         distances_key: str = 'distances',
         embedding_key: str = 'umap',
         num_iterations: int | np.integer = 200,
         alpha: int | float | np.integer | np.floating = 1,
         gamma: int | float | np.integer | np.floating = 1,
         negative_sample_rate: int = 5,
         a: int | float | np.integer | np.floating | None = None,
         b: int | float | np.integer | np.floating | None = None,
         spread: float = 1,
         min_dist: float = 0.5,
         seed: int | np.integer = 0,
         overwrite: bool = False,
         hogwild: bool = False,
         num_threads: int | np.integer | None = None) -> SingleCell:
    """
    Calculate a two-dimensional embedding of this SingleCell dataset with
    UMAP (Uniform Manifold Approximation and Projection), suitable for
    plotting with `plot_embedding()`.

    Use `hogwild=True` to run in parallel. Results will not be
    reproducible!

    This function is intended to be run after `PCA()` and `neighbors()`.
    By default, it uses `obsm['pca']`, `obsm['neighbors']`, and
    `obsm['distances']` as the inputs to UMAP, and stores the output in
    `obsm['umap']` as a `len(obs)` × 2 NumPy array. It can also be run on
    Harmony embeddings by running `harmonize()` and then specifying
    `PC_key='harmony'`.

    Args:
        QC_column: an optional Boolean column of `obs` indicating which
            cells passed QC. Can be a column name, a polars expression, a
            polars Series, a 1D NumPy array, or a function that takes in
            this SingleCell dataset and returns a polars Series or 1D
            NumPy array. Set to `None` to include all cells. Cells failing
            QC will be ignored and have their embeddings set to `NaN`.
        PC_key: the key of `obsm` containing the principal components
            calculated with `PCA()`, to use as an input for the embedding
            calculation
        neighbors_key: the key of `obsm` containing the nearest-neighbor
            indices for each cell, to use as an input for the embedding
            calculation
        distances_key: the key of `obsm` containing the squared Euclidean
            distance to each nearest neighbor in `neighbors_key`, to use
            as an input for the embedding calculation
        embedding_key: the key of `obsm` where the embeddings will be
            stored
        num_iterations: the number of optimization iterations. In
            umap-learn, this defaulted to 500 for datasets of 10,000
            elements or less, and 200 for datasets larger than 10,000
            elements.
        alpha: the initial learning rate for optimization
        gamma: the weight applied to negative samples during optimization
        negative_sample_rate: the number of negative samples per positive
            sample
        a: UMAP curve parameter; if `None`, will be fit based on the
            values of `spread` and `min_dist`. Either both or neither of
            `a` and `b` must be `None`.
        b: UMAP curve parameter; if `None`, will be fit based on the
            values of `spread` and `min_dist`. Either both or neither of
            `a` and `b` must be `None`.
        spread: the effective scale of embedded points. Only used when `a`
            and `b` are `None`.
        min_dist: the minimum distance between points in the embedding.
            Only used when `a` and `b` are `None`.
        seed: the random seed to use for UMAP
        overwrite: if `True`, overwrite `embedding_key` if already present
            in `obsm`, instead of raising an error
        hogwild: if `True`, go
            [Hogwild!](https://arxiv.org/abs/1106.5730) and optimize the
            embedding in parallel. Results will not be reproducible!
        num_threads: the number of threads to use when running `umap()`.
            Cannot be specified when `hogwild=False`. When `hogwild=True`,
            must be explicitly specified and greater than 1 unless
            `self.num_threads` is greater than 1, in which case it can be
            left unset. Set `num_threads=-1` to use all available cores,
            as determined by `os.cpu_count()`, or leave unset to use
            `self.num_threads` cores when `hogwild=True` and one core when
            `hogwild=False`. Does not affect the returned SingleCell
            dataset's `num_threads`; this will always be the same as the
            original dataset's `num_threads`.

    Returns:
        A new SingleCell dataset with the UMAP embedding stored in
        `obsm[embedding_key]`.

    Raises:
        ValueError: if `PC_key`, `neighbors_key` or `distances_key` is not
            a key of `obsm`; if any of those arrays is not C-contiguous;
            if the neighbor and distance arrays have different numbers of
            columns; if `embedding_key` is already a key of `obsm` and
            `overwrite=False`; if exactly one of `a` and `b` is `None`; if
            `num_threads` is inconsistent with `hogwild`; or if there are
            fewer than 2 cells, before or after applying `QC_column`.
        TypeError: if the principal components or distances are not
            float32, or the neighbor indices are not uint32.
    """
    # Check that `QC_column` is valid, if not `None`
    if QC_column is not None:
        QC_column = self._get_column(
            'obs', QC_column, 'QC_column', pl.Boolean,
            allow_missing=QC_column == 'passed_QC')

    # Check that `PC_key` is a string, and get PCs
    check_type(PC_key, 'PC_key', str, 'a string')
    if PC_key not in self._obsm:
        error_message = f'PC_key {PC_key!r} is not a key of obsm'
        if PC_key == 'pca':
            error_message += (
                '; did you forget to run PCA() (and possibly neighbors()) '
                'before umap()?')
        raise ValueError(error_message)
    PCs = self._obsm[PC_key]
    if PCs.dtype != np.float32:
        error_message = \
            f'obsm[{PC_key!r}].dtype is {PCs.dtype!r}, but must be float32'
        raise TypeError(error_message)
    if not PCs.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{PC_key!r}] is not C-contiguous; make it C-contiguous '
            f'with pipe_obsm_key({PC_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)

    # Check that `neighbors_key` is a string, and get nearest-neighbor
    # indices
    check_type(neighbors_key, 'neighbors_key', str, 'a string')
    if neighbors_key not in self._obsm:
        error_message = \
            f'neighbors_key {neighbors_key!r} is not a key of obsm'
        if neighbors_key == 'neighbors':
            error_message += (
                '; did you forget to run neighbors() before '
                'umap()?')
        raise ValueError(error_message)
    neighbors = self._obsm[neighbors_key]
    if neighbors.dtype != np.uint32:
        error_message = (
            f'obsm[{neighbors_key!r}] must have uint32 data type, but '
            f'has data type {str(neighbors.dtype)!r}')
        raise TypeError(error_message)
    if not neighbors.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{neighbors_key!r}] is not C-contiguous; make it '
            f'C-contiguous with '
            f'pipe_obsm_key({neighbors_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)

    # Check that `distances_key` is a string, and get nearest-neighbor
    # distances
    check_type(distances_key, 'distances_key', str, 'a string')
    if distances_key not in self._obsm:
        error_message = \
            f'distances_key {distances_key!r} is not a key of obsm'
        if distances_key == 'distances':
            error_message += (
                '; did you forget to run neighbors() before '
                'umap()?')
        raise ValueError(error_message)
    distances = self._obsm[distances_key]
    if distances.dtype != np.float32:
        error_message = (
            f'obsm[{distances_key!r}] must have float32 data type, but '
            f'has data type {str(distances.dtype)!r}')
        raise TypeError(error_message)
    if not distances.flags['C_CONTIGUOUS']:
        error_message = (
            f'obsm[{distances_key!r}] is not C-contiguous; make it '
            f'C-contiguous with '
            f'pipe_obsm_key({distances_key!r}, np.ascontiguousarray)')
        raise ValueError(error_message)
    if neighbors.shape[1] != distances.shape[1]:
        error_message = (
            f'obsm[{neighbors_key!r}] and obsm[{distances_key!r}] have '
            f'different numbers of columns ({neighbors.shape[1]:,} vs '
            f'{distances.shape[1]:,})')
        raise ValueError(error_message)

    # Check that `embedding_key` is a string
    check_type(embedding_key, 'embedding_key', str, 'a string')

    # Check that `embedding_key` is not already a key in `obsm`, unless
    # `overwrite=True`
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    if not overwrite and embedding_key in self._obsm:
        error_message = (
            f'embedding_key {embedding_key!r} is already a key of obsm; '
            f'did you already run umap()? Set overwrite=True to '
            f'overwrite.')
        raise ValueError(error_message)

    # Check that `num_iterations` is a positive integer
    check_type(num_iterations, 'num_iterations', int, 'a positive integer')
    check_bounds(num_iterations, 'num_iterations', 1)

    # Check that `alpha` and `gamma` are positive numbers
    for variable, variable_name in (alpha, 'alpha'), (gamma, 'gamma'):
        check_type(variable, variable_name, (int, float),
                   'a positive number')
        check_bounds(variable, variable_name, 0, left_open=True)

    # Check that `negative_sample_rate` is a positive integer
    check_type(negative_sample_rate, 'negative_sample_rate', int,
               'a positive integer')
    check_bounds(negative_sample_rate, 'negative_sample_rate', 1)

    # Check that either both or neither of `a` and `b` are `None`
    if (a is None) != (b is None):
        error_message = 'either both or neither of a and b must be None'
        raise ValueError(error_message)

    if a is None and b is None:
        # If `a` and `b` are `None`, check `spread` and `min_dist`
        check_type(spread, 'spread', (int, float), 'a positive number')
        check_bounds(spread, 'spread', 0, left_open=True)
        check_type(min_dist, 'min_dist', (int, float),
                   'a non-negative number')
        check_bounds(min_dist, 'min_dist', 0)
    else:
        # Otherwise, check that the user-supplied `a` and `b` are numbers,
        # since they are handed to the optimizer as-is
        for variable, variable_name in (a, 'a'), (b, 'b'):
            check_type(variable, variable_name, (int, float), 'a number')

    # Check that `seed` is an integer
    check_type(seed, 'seed', int, 'an integer')

    # Check that `hogwild` is Boolean
    check_type(hogwild, 'hogwild', bool, 'Boolean')

    # Check that `num_threads` is valid. When `hogwild=False`,
    # `num_threads` cannot be specified and defaults to 1. When
    # `hogwild=True`, `num_threads` must be explicitly specified and
    # greater than 1 unless `self.num_threads` is greater than 1, in which
    # case it can be left unset. If `num_threads` is -1, use all available
    # cores.
    if not hogwild:
        if num_threads is not None:
            error_message = \
                'num_threads cannot be specified when hogwild=False'
            raise ValueError(error_message)
        num_threads = 1
    elif num_threads is None:
        if self._num_threads <= 1:
            error_message = (
                'when hogwild=True and self.num_threads is 1, num_threads '
                'must be explicitly specified and greater than 1')
            raise ValueError(error_message)
        num_threads = self._num_threads
    else:
        num_threads = self._process_num_threads(num_threads)
        if num_threads == 1:
            error_message = \
                'num_threads must be greater than 1 when hogwild=True'
            raise ValueError(error_message)

    # Check that there are at least 2 cells
    num_cells = PCs.shape[0]
    if num_cells < 2:
        error_message = (
            'there are fewer than 2 cells, so the embedding cannot be '
            'calculated')
        raise ValueError(error_message)

    # Subset PCs and nearest-neighbor indices to QCed cells only, if
    # `QC_column` is not `None`
    if QC_column is not None:
        QC_column_NumPy = QC_column.to_numpy()
        PCs = PCs[QC_column_NumPy]
        neighbors = neighbors[QC_column_NumPy]
        distances = distances[QC_column_NumPy]

        # The check above counted every cell, including those failing QC,
        # so repeat it on the cells that will actually be embedded
        if PCs.shape[0] < 2:
            error_message = (
                'fewer than 2 cells pass QC, so the embedding cannot be '
                'calculated')
            raise ValueError(error_message)

    # If `a` and `b` are `None`, select them based on `spread` and
    # `min_dist`
    if a is None and b is None:
        from scipy.optimize import curve_fit
        # Target curve: 1 up to `min_dist`, then exponential decay with
        # scale `spread`; `a` and `b` are fit so that
        # `1 / (1 + a * x ** (2 * b))` approximates it
        xv = np.linspace(0, spread * 3, 300)
        yv = np.zeros(xv.shape)
        yv[xv < min_dist] = 1
        yv[xv >= min_dist] = \
            np.exp(-(xv[xv >= min_dist] - min_dist) / spread)
        a, b = curve_fit(
            lambda x, a, b: 1.0 / (1.0 + a * x ** (2 * b)), xv, yv)[0]

    def fuzzy_simplicial_set(neighbors, distances, num_threads):
        # Build an N × N CSR graph with exactly K entries per row (one
        # per nearest neighbor), with weights written into `data` by
        # `umap_fuzzy_weights()`
        N, K = neighbors.shape
        if num_threads == 1:
            data = np.empty(neighbors.size, dtype=np.float32)
        else:
            data = numa_zeros(neighbors.size, dtype=np.float32)
        # Reinterpret the uint32 neighbor indices as int32 without copying
        indices = neighbors.ravel().view(np.int32)
        indptr = np.arange(0, (N + 1) * K, K, dtype=np.int32)
        umap_fuzzy_weights(distances, data, num_threads)
        result = csr_array((data, indices, indptr), shape=(N, N))
        # Symmetrize: `w + w.T - w * w.T` (presumably an elementwise
        # product, as for SciPy sparse arrays - confirm for `csr_array`)
        result = result + result.T - result * result.T
        return result

    graph = fuzzy_simplicial_set(neighbors, distances, num_threads)

    def spectral_layout(data, graph, random_state):
        from scipy.linalg import eigh
        from scipy.sparse.csgraph import connected_components
        from scipy.sparse.linalg import eigsh
        from scipy.spatial.distance import cdist
        n_connected_components, component_labels = \
            connected_components(graph)
        if n_connected_components > 1:
            # Multi-component layout
            result = np.empty((graph.shape[0], 2), dtype=np.float32)
            if n_connected_components > 4:
                component_centroids = \
                    np.empty((n_connected_components, data.shape[1]))
                for label in range(n_connected_components):
                    component_centroids[label] = \
                        data[component_labels == label].mean(axis=0)
                distance_matrix = \
                    cdist(component_centroids, component_centroids)
                affinity_matrix = np.exp(-(distance_matrix ** 2))
                # Spectral embedding of the affinity matrix
                inv_sqrt_degree = \
                    affinity_matrix.sum(axis=1).ravel() ** -0.5
                normalized_affinity_matrix = \
                    affinity_matrix * inv_sqrt_degree[:, None] * \
                    inv_sqrt_degree[None, :]
                eigenvectors = eigh(normalized_affinity_matrix)[1]
                # skip trivial largest eigenvector, reorder largest to
                # smallest, just take the next-largest two
                meta_embedding = eigenvectors[:, [-2, -3]]
                meta_embedding /= np.abs(meta_embedding).max()
            else:
                # With 2-4 components, place them at fixed positions:
                # unit vectors along the axes and their negations
                k = int(np.ceil(n_connected_components / 2))
                base = np.hstack([np.eye(k), np.zeros((k, 2 - k))])
                meta_embedding = \
                    np.vstack([base, -base])[:n_connected_components]
            for label in range(n_connected_components):
                mask = component_labels == label
                component_graph = graph[np.ix_(mask, mask)]
                distances = cdist(meta_embedding[label:label + 1],
                                  meta_embedding)
                # Half the distance to the nearest other component
                data_range = distances[distances > 0].min() / 2
                if component_graph.shape[0] < 4:
                    # Too few cells for a spectral layout: scatter them
                    # uniformly at random around the component's position
                    result[component_labels == label] = \
                        random_state.uniform(
                            low=-data_range, high=data_range,
                            size=(component_graph.shape[0], 2)) + \
                        meta_embedding[label]
                else:
                    # `component_graph` is a single connected component,
                    # so this recursive call takes the single-component
                    # path below, which never reads `data`
                    component_embedding = spectral_layout(
                        data=None, graph=component_graph,
                        random_state=random_state)
                    expansion = \
                        data_range / np.abs(component_embedding).max()
                    component_embedding *= expansion
                    result[component_labels == label] = \
                        component_embedding + meta_embedding[label]
            return result
        # UMAP originally took the smallest-magnitude eigenvalues
        # (`which='SM'`) of `I - D @ graph @ D` where
        # `D = diag(inv_sqrt_degree)`, but this is equivalent to taking the
        # most positive eigenvalues (`which='LA'`) of `D @ graph @ D`,
        # which converges faster. Similarly, use the trivial eigenvector as
        # an initial guess via `v0=sqrt_degree` to speed up convergence,
        # rather than the original's `v0=np.ones(L.shape[0])`.
        sqrt_degree = np.sqrt(graph.sum(axis=1).ravel())
        inv_sqrt_degree = 1 / sqrt_degree
        normalized_graph = graph\
            .multiply(inv_sqrt_degree[:, None])\
            .multiply(inv_sqrt_degree[None, :])
        with threadpool_limits(1):
            eigenvalues, eigenvectors = eigsh(normalized_graph, k=3,
                                              which='LA', tol=1e-4,
                                              v0=sqrt_degree)
        # skip trivial largest eigenvector, reorder largest to smallest
        return np.ascontiguousarray(eigenvectors[:, [1, 0]])

    embedding = spectral_layout(
        PCs, graph, random_state=np.random.RandomState(seed))

    # Create the graph in COO format, equivalent to
    # `head = graph.tocoo().row` and `tail = graph.tocoo().col`
    head = np.repeat(np.arange(graph.shape[0], dtype=graph.indices.dtype),
                     repeats=np.diff(graph.indptr))
    tail = graph.indices

    # Optimize the embedding via stochastic gradient descent, in parallel
    # without locks if `hogwild=True`
    umap_optimize(
        embedding=embedding, head=head, tail=tail, weights=graph.data,
        num_iterations=num_iterations, a=a, b=b, gamma=gamma,
        initial_alpha=alpha, negative_sample_rate=negative_sample_rate,
        seed=seed, num_threads=num_threads)

    # If `QC_column` was specified, back-project from QCed cells to all
    # cells, filling with `NaN`
    if QC_column is not None:
        embedding_QCed = embedding
        embedding = np.full((len(self), embedding_QCed.shape[1]), np.nan,
                            dtype=np.float32)
        embedding[QC_column_NumPy] = embedding_QCed
    return SingleCell(X=self._X, obs=self._obs, var=self._var,
                      obsm=self._obsm | {embedding_key: embedding},
                      varm=self._varm, obsp=self._obsp, varp=self._varp,
                      uns=self._uns, num_threads=self._num_threads)
def plot_pacmap(
        self,
        color_column: SingleCellColumn | None,
        filename: str | Path | None = None,
        /,
        *,
        cells_to_plot_column: SingleCellColumn | None = 'passed_QC',
        ax: 'Axes' | None = None,
        figure_kwargs: dict[str, Any] | None = None,
        point_size: int | float | np.integer | np.floating | str |
                    None = None,
        sort_by_frequency: bool = False,
        colormap: str | 'Colormap' | dict[Any, Color] | None = None,
        lightness_range: tuple[float | np.floating, float | np.floating] |
                         None = (100 / 3, 200 / 3),
        chroma_range: tuple[float | np.floating, float | np.floating] |
                      None = (50, 100),
        hue_range: tuple[float | np.floating, float | np.floating] |
                   None = None,
        first_color: Color = '#008cb9',
        stride: int | np.integer = 5,
        default_color: Color = 'lightgray',
        scatter_kwargs: dict[str, Any] | None = None,
        label: bool = False,
        label_kwargs: dict[str, Any] | None = None,
        legend: bool = True,
        legend_kwargs: dict[str, Any] | None = None,
        colorbar: bool = True,
        colorbar_kwargs: dict[str, Any] | None = None,
        title: str | None = None,
        title_kwargs: dict[str, Any] | None = None,
        xlabel: str | None = 'Component 1',
        xlabel_kwargs: dict[str, Any] | None = None,
        ylabel: str | None = 'Component 2',
        ylabel_kwargs: dict[str, Any] | None = None,
        xlim: tuple[int | float | np.integer | np.floating,
                    int | float | np.integer | np.floating] | None = None,
        ylim: tuple[int | float | np.integer | np.floating,
                    int | float | np.integer | np.floating] | None = None,
        despine: bool = True,
        savefig_kwargs: dict[str, Any] | None = None) -> None:
    """
    Plot a PaCMAP embedding created with `pacmap()`.

    Syntactic sugar for `plot_embedding('pacmap', ...)`: every argument is
    forwarded unchanged.

    Args:
        color_column: an optional column of `obs` indicating how to color
            each cell in the plot. Can be a column name, a polars
            expression, a polars Series, a 1D NumPy array, or a function
            that takes in this SingleCell dataset and returns a polars
            Series or 1D NumPy array. Can be discrete (e.g. cell-type
            labels), specified as a String/Enum/Categorical column, or
            quantitative (e.g. the number of UMIs per cell), specified as
            an integer/floating-point column. Missing (`null`) cells will
            be plotted with the color `default_color`. Set to `None` to
            use `default_color` for all cells.
        filename: the file to save to. If `None`, generate the plot but do
            not save it, which allows it to be shown interactively or
            modified further before saving.
        cells_to_plot_column: an optional Boolean column of `obs`
            indicating which cells to plot. Can be a column name, a polars
            expression, a polars Series, a 1D NumPy array, or a function
            that takes in this SingleCell dataset and returns a polars
            Series or 1D NumPy array. Set to `None` to plot all cells
            passing QC.
        ax: the Matplotlib axes to save the plot onto; if `None`, create a
            new figure with Matplotlib's constrained layout and plot onto
            it
        figure_kwargs: a dictionary of keyword arguments to be passed to
            `plt.figure` when `ax` is `None`, such as:
            - `figsize`: a two-element sequence of the width and height of
              the figure in inches. Defaults to `[8, 6]`, 25% larger than
              Matplotlib's default of `[6.4, 4.8]`.
            - `layout`: the layout mechanism used by Matplotlib to avoid
              overlapping plot elements. Defaults to `'constrained'`,
              instead of Matplotlib's default of `None`.
        point_size: the size of the points for each cell; defaults to
            20,000 divided by the number of cells. Can be a single number,
            or the name of a column of `obs` to make each point a
            different size.
        sort_by_frequency: if `True`, assign colors and sort the legend in
            order of decreasing frequency; if `False` (the default), use
            [natural sort order](
            https://en.wikipedia.org/wiki/Natural_sort_order). Cannot be
            `True` unless `colormap` is `None` and `color_column` is
            discrete; if `colormap` is not `None`, the plot order is
            determined by the order of the keys in `colormap`.
        colormap: a string or Colormap object indicating the Matplotlib
            colormap to use; or, if `color_column` is discrete, a
            dictionary mapping values in `color_column` to Matplotlib
            colors (cells with values of `color_column` that are not in
            the dictionary will be plotted in the color `default_color`).
            Defaults to `plt.rcParams['image.cmap']` (`'viridis'` by
            default) if `color_column` is continuous, or the colors from a
            maximally perceptually distinct colormap if `color_column` is
            discrete (with colors assigned in decreasing order of
            frequency). Cannot be specified if `color_column` is `None`.
        lightness_range: a two-element tuple with the lightness range of
            colors to generate, or `None` to take the full range:
            `[0, 100]`. Can only be specified when `color_column` is
            discrete and `colormap` is `None`.
        chroma_range: a two-element tuple with the chroma range of colors
            to generate, or `None` to take the full range: `[0, 100]`.
            Grays have low chroma, and vivid colors have high chroma. Can
            only be specified when `color_column` is discrete and
            `colormap` is `None`.
        hue_range: a two-element tuple with the hue range of colors to
            generate, or `None` to take the full range: `[0, 360]`. Red is
            at 0°, green at 120°, and blue at 240°. Because it wraps
            around, the first element of the tuple can be greater than the
            second, unlike for `lightness_range` and `chroma_range`. Can
            only be specified when `color_column` is discrete and
            `colormap` is `None`.
        first_color: the first color of the palette. Can be any valid
            Matplotlib color, like a hex string (e.g. `'#FF0000'`), a
            named color (e.g. `'red'`), a 3- or 4-element RGB/RGBA tuple
            of integers 0-255 or floats 0-1, or a single float 0-1 for
            grayscale.
        stride: as an optimization, consider only RGB colors where R, G,
            and B are all multiples of this value. Must be a small divisor
            of 255: 1, 3, 5, 15, or 17. Set to 1 for the best possible
            solution, at orders of magnitude more computational cost.
        default_color: the default color to plot cells in when
            `color_column` is `None`, or when certain cells have missing
            (`null`) values for `color_column`, or when `colormap` is a
            dictionary and some cells have values of `color_column` that
            are not in the dictionary. Can be any valid Matplotlib color,
            like a hex string (e.g. `'#FF0000'`), a named color (e.g.
            `'red'`), a 3- or 4-element RGB/RGBA tuple of integers 0-255
            or floats 0-1, or a single float 0-1 for grayscale.
        scatter_kwargs: a dictionary of keyword arguments to be passed to
            `ax.scatter()`, such as:
            - `rasterized`: whether to convert the scatter plot points to
              a raster (bitmap) image when saving to a vector format like
              PDF. Defaults to `True`, instead of Matplotlib's default of
              `False`.
            - `marker`: the shape to use for plotting each cell
            - `norm`, `vmin`, and `vmax`: control how the `colormap` maps
              the numbers in `color_column` to colors, if `color_column`
              is numeric
            - `alpha`: the transparency of each point
            - `linewidths` and `edgecolors`: the width and color of the
              borders around each marker. These are absent by default
              (`linewidths=0`, `edgecolors=(0, 0, 0, 0)`), unlike
              Matplotlib's default. Both arguments can be either single
              values or sequences.
            - `zorder`: the order in which the cells are plotted, with
              higher values appearing on top of lower ones.
            Specifying `s`, `c`/`color`, or `cmap` will raise an error,
            since these arguments conflict with the `point_size`,
            `color_column`, and `colormap` arguments, respectively.
        label: whether to label cells with each distinct value of
            `color_column`. Labels will be placed at the median x and y
            position of the points with that color. Can only be `True`
            when `color_column` is discrete. When set to `True`, you may
            also want to set `legend=False` to avoid redundancy.
        label_kwargs: a dictionary of keyword arguments to be passed to
            `ax.text()` when adding labels to control the text properties,
            such as:
            - `color` and `size` to modify the text color/size
            - `verticalalignment` and `horizontalalignment` to control
              vertical and horizontal alignment. By default, unlike
              Matplotlib, these are both set to `'center'`.
            - `path_effects` to set properties for the border around the
              text. By default, set to
              `matplotlib.patheffects.withStroke(linewidth=3,
              foreground='white', alpha=0.75)` instead of Matplotlib's
              default of `None`, to put a semi-transparent white border
              around the labels for better contrast.
            Can only be specified when `label=True`.
        legend: whether to add a legend for each value in `color_column`.
            Ignored unless `color_column` is discrete.
        legend_kwargs: a dictionary of keyword arguments to be passed to
            `ax.legend()` to modify the legend, such as:
            - `loc`, `bbox_to_anchor`, and `bbox_transform` to set its
              location. By default, `loc` is set to `'center left'` and
              `bbox_to_anchor` to `(1, 0.5)` to put the legend to the
              right of the plot, anchored at the middle.
            - `ncols` to set its number of columns. By default, set to
              `ceil(obs[color_column].n_unique() / 24)` to have at most 24
              items per column.
            - `prop`, `fontsize`, and `labelcolor` to set its font
              properties
            - `facecolor` and `framealpha` to set its background color and
              transparency
            - `frameon=True` or `edgecolor` to add or color its border.
              `frameon` defaults to `False`, instead of Matplotlib's
              default of `True`.
            - `title` to add a legend title
            Can only be specified when `color_column` is discrete and
            `legend=True`.
        colorbar: whether to add a colorbar. Ignored unless `color_column`
            is quantitative.
        colorbar_kwargs: a dictionary of keyword arguments to be passed to
            `plt.colorbar()`, such as:
            - `location`: `'left'`, `'right'`, `'top'`, or `'bottom'`
            - `orientation`: `'vertical'` or `'horizontal'`
            - `fraction`: the fraction of the axes to allocate to the
              colorbar. Defaults to 0.15.
            - `shrink`: the fraction to multiply the size of the colorbar
              by. Defaults to 0.5, instead of Matplotlib's default of 1.
            - `aspect`: the ratio of the colorbar's long to short
              dimensions. Defaults to 20.
            - `pad`: the fraction of the axes between the colorbar and the
              rest of the figure. Defaults to 0.01, instead of
              Matplotlib's default of 0.05 if vertical and 0.15 if
              horizontal.
            Can only be specified when `color_column` is quantitative and
            `colorbar=True`.
        title: the title of the plot, or `None` to not add a title
        title_kwargs: a dictionary of keyword arguments to be passed to
            `ax.set_title()` to control text properties, such as `color`
            and `size`. Can only be specified when `title` is not `None`.
        xlabel: the x-axis label, or `None` to not label the x-axis
        xlabel_kwargs: a dictionary of keyword arguments to be passed to
            `ax.set_xlabel()` to control text properties, such as `color`
            and `size`. Can only be specified when `xlabel` is not `None`.
        ylabel: the y-axis label, or `None` to not label the y-axis
        ylabel_kwargs: a dictionary of keyword arguments to be passed to
            `ax.set_ylabel()` to control text properties, such as `color`
            and `size`. Can only be specified when `ylabel` is not `None`.
        xlim: a length-2 tuple of the left and right x-axis limits, or
            `None` to set the limits based on the data
        ylim: a length-2 tuple of the bottom and top y-axis limits, or
            `None` to set the limits based on the data
        despine: whether to remove the top and right spines (borders of
            the plot area) from the plot
        savefig_kwargs: a dictionary of keyword arguments to be passed to
            `plt.savefig()`, such as:
            - `dpi`: defaults to 300 instead of Matplotlib's default of
              `'figure'` (the figure's own dpi, which is 100 by default)
            - `bbox_inches`: the bounding box of the portion of the figure
              to save; defaults to `'tight'` (crop out any blank borders)
              instead of Matplotlib's default of `None` (save the entire
              figure)
            - `pad_inches`: the number of inches of padding to add on each
              of the four sides of the figure when saving. Defaults to
              `'layout'` (use the padding from the constrained layout
              engine, when `ax` is not `None`) instead of Matplotlib's
              default of 0.1.
            - `transparent`: whether to save with a transparent
              background; defaults to `True` if saving to a PDF (i.e. when
              `filename` ends with `'.pdf'`) and `False` otherwise,
              instead of Matplotlib's default of always being `False`.
            Can only be specified when `filename` is specified.
    """
    # All validation and plotting is delegated to `plot_embedding()`; this
    # wrapper only fixes the embedding key to `'pacmap'`
    self.plot_embedding('pacmap', color_column, filename,
                        cells_to_plot_column=cells_to_plot_column, ax=ax,
                        figure_kwargs=figure_kwargs, point_size=point_size,
                        sort_by_frequency=sort_by_frequency,
                        colormap=colormap,
                        lightness_range=lightness_range,
                        chroma_range=chroma_range, hue_range=hue_range,
                        first_color=first_color, stride=stride,
                        default_color=default_color,
                        scatter_kwargs=scatter_kwargs, label=label,
                        label_kwargs=label_kwargs, legend=legend,
                        legend_kwargs=legend_kwargs, colorbar=colorbar,
                        colorbar_kwargs=colorbar_kwargs, title=title,
                        title_kwargs=title_kwargs, xlabel=xlabel,
                        xlabel_kwargs=xlabel_kwargs, ylabel=ylabel,
                        ylabel_kwargs=ylabel_kwargs, xlim=xlim, ylim=ylim,
                        despine=despine, savefig_kwargs=savefig_kwargs)
def plot_localmap(
        self,
        color_column: SingleCellColumn | None,
        filename: str | Path | None = None,
        /,
        *,
        cells_to_plot_column: SingleCellColumn | None = 'passed_QC',
        ax: 'Axes' | None = None,
        figure_kwargs: dict[str, Any] | None = None,
        point_size: int | float | np.integer | np.floating | str |
                    None = None,
        sort_by_frequency: bool = False,
        colormap: str | 'Colormap' | dict[Any, Color] | None = None,
        lightness_range: tuple[float | np.floating, float | np.floating] |
                         None = (100 / 3, 200 / 3),
        chroma_range: tuple[float | np.floating, float | np.floating] |
                      None = (50, 100),
        hue_range: tuple[float | np.floating, float | np.floating] |
                   None = None,
        first_color: Color = '#008cb9',
        stride: int | np.integer = 5,
        default_color: Color = 'lightgray',
        scatter_kwargs: dict[str, Any] | None = None,
        label: bool = False,
        label_kwargs: dict[str, Any] | None = None,
        legend: bool = True,
        legend_kwargs: dict[str, Any] | None = None,
        colorbar: bool = True,
        colorbar_kwargs: dict[str, Any] | None = None,
        title: str | None = None,
        title_kwargs: dict[str, Any] | None = None,
        xlabel: str | None = 'Component 1',
        xlabel_kwargs: dict[str, Any] | None = None,
        ylabel: str | None = 'Component 2',
        ylabel_kwargs: dict[str, Any] | None = None,
        xlim: tuple[int | float | np.integer | np.floating,
                    int | float | np.integer | np.floating] | None = None,
        ylim: tuple[int | float | np.integer | np.floating,
                    int | float | np.integer | np.floating] | None = None,
        despine: bool = True,
        savefig_kwargs: dict[str, Any] | None = None) -> None:
    """
    Plot a LocalMAP embedding created with `localmap()`.

    Syntactic sugar for `plot_embedding('localmap', ...)`: every argument
    is forwarded unchanged.

    Args:
        color_column: an optional column of `obs` indicating how to color
            each cell in the plot. Can be a column name, a polars
            expression, a polars Series, a 1D NumPy array, or a function
            that takes in this SingleCell dataset and returns a polars
            Series or 1D NumPy array. Can be discrete (e.g. cell-type
            labels), specified as a String/Enum/Categorical column, or
            quantitative (e.g. the number of UMIs per cell), specified as
            an integer/floating-point column. Missing (`null`) cells will
            be plotted with the color `default_color`. Set to `None` to
            use `default_color` for all cells.
        filename: the file to save to. If `None`, generate the plot but do
            not save it, which allows it to be shown interactively or
            modified further before saving.
        cells_to_plot_column: an optional Boolean column of `obs`
            indicating which cells to plot. Can be a column name, a polars
            expression, a polars Series, a 1D NumPy array, or a function
            that takes in this SingleCell dataset and returns a polars
            Series or 1D NumPy array. Set to `None` to plot all cells
            passing QC.
        ax: the Matplotlib axes to save the plot onto; if `None`, create a
            new figure with Matplotlib's constrained layout and plot onto
            it
        figure_kwargs: a dictionary of keyword arguments to be passed to
            `plt.figure` when `ax` is `None`, such as:
            - `figsize`: a two-element sequence of the width and height of
              the figure in inches. Defaults to `[8, 6]`, 25% larger than
              Matplotlib's default of `[6.4, 4.8]`.
            - `layout`: the layout mechanism used by Matplotlib to avoid
              overlapping plot elements. Defaults to `'constrained'`,
              instead of Matplotlib's default of `None`.
        point_size: the size of the points for each cell; defaults to
            20,000 divided by the number of cells. Can be a single number,
            or the name of a column of `obs` to make each point a
            different size.
        sort_by_frequency: if `True`, assign colors and sort the legend in
            order of decreasing frequency; if `False` (the default), use
            [natural sort order](
            https://en.wikipedia.org/wiki/Natural_sort_order). Cannot be
            `True` unless `colormap` is `None` and `color_column` is
            discrete; if `colormap` is not `None`, the plot order is
            determined by the order of the keys in `colormap`.
        colormap: a string or Colormap object indicating the Matplotlib
            colormap to use; or, if `color_column` is discrete, a
            dictionary mapping values in `color_column` to Matplotlib
            colors (cells with values of `color_column` that are not in
            the dictionary will be plotted in the color `default_color`).
            Defaults to `plt.rcParams['image.cmap']` (`'viridis'` by
            default) if `color_column` is continuous, or the colors from a
            maximally perceptually distinct colormap if `color_column` is
            discrete (with colors assigned in decreasing order of
            frequency). Cannot be specified if `color_column` is `None`.
        lightness_range: a two-element tuple with the lightness range of
            colors to generate, or `None` to take the full range:
            `[0, 100]`. Can only be specified when `color_column` is
            discrete and `colormap` is `None`.
        chroma_range: a two-element tuple with the chroma range of colors
            to generate, or `None` to take the full range: `[0, 100]`.
            Grays have low chroma, and vivid colors have high chroma. Can
            only be specified when `color_column` is discrete and
            `colormap` is `None`.
        hue_range: a two-element tuple with the hue range of colors to
            generate, or `None` to take the full range: `[0, 360]`. Red is
            at 0°, green at 120°, and blue at 240°. Because it wraps
            around, the first element of the tuple can be greater than the
            second, unlike for `lightness_range` and `chroma_range`. Can
            only be specified when `color_column` is discrete and
            `colormap` is `None`.
        first_color: the first color of the palette. Can be any valid
            Matplotlib color, like a hex string (e.g. `'#FF0000'`), a
            named color (e.g. `'red'`), a 3- or 4-element RGB/RGBA tuple
            of integers 0-255 or floats 0-1, or a single float 0-1 for
            grayscale.
        stride: as an optimization, consider only RGB colors where R, G,
            and B are all multiples of this value. Must be a small divisor
            of 255: 1, 3, 5, 15, or 17. Set to 1 for the best possible
            solution, at orders of magnitude more computational cost.
        default_color: the default color to plot cells in when
            `color_column` is `None`, or when certain cells have missing
            (`null`) values for `color_column`, or when `colormap` is a
            dictionary and some cells have values of `color_column` that
            are not in the dictionary. Can be any valid Matplotlib color,
            like a hex string (e.g. `'#FF0000'`), a named color (e.g.
            `'red'`), a 3- or 4-element RGB/RGBA tuple of integers 0-255
            or floats 0-1, or a single float 0-1 for grayscale.
        scatter_kwargs: a dictionary of keyword arguments to be passed to
            `ax.scatter()`, such as:
            - `rasterized`: whether to convert the scatter plot points to
              a raster (bitmap) image when saving to a vector format like
              PDF. Defaults to `True`, instead of Matplotlib's default of
              `False`.
            - `marker`: the shape to use for plotting each cell
            - `norm`, `vmin`, and `vmax`: control how the `colormap` maps
              the numbers in `color_column` to colors, if `color_column`
              is numeric
            - `alpha`: the transparency of each point
            - `linewidths` and `edgecolors`: the width and color of the
              borders around each marker. These are absent by default
              (`linewidths=0`, `edgecolors=(0, 0, 0, 0)`), unlike
              Matplotlib's default. Both arguments can be either single
              values or sequences.
            - `zorder`: the order in which the cells are plotted, with
              higher values appearing on top of lower ones.
            Specifying `s`, `c`/`color`, or `cmap` will raise an error,
            since these arguments conflict with the `point_size`,
            `color_column`, and `colormap` arguments, respectively.
        label: whether to label cells with each distinct value of
            `color_column`. Labels will be placed at the median x and y
            position of the points with that color. Can only be `True`
            when `color_column` is discrete. When set to `True`, you may
            also want to set `legend=False` to avoid redundancy.
        label_kwargs: a dictionary of keyword arguments to be passed to
            `ax.text()` when adding labels to control the text properties,
            such as:
            - `color` and `size` to modify the text color/size
            - `verticalalignment` and `horizontalalignment` to control
              vertical and horizontal alignment. By default, unlike
              Matplotlib, these are both set to `'center'`.
            - `path_effects` to set properties for the border around the
              text. By default, set to
              `matplotlib.patheffects.withStroke(linewidth=3,
              foreground='white', alpha=0.75)` instead of Matplotlib's
              default of `None`, to put a semi-transparent white border
              around the labels for better contrast.
            Can only be specified when `label=True`.
        legend: whether to add a legend for each value in `color_column`.
            Ignored unless `color_column` is discrete.
        legend_kwargs: a dictionary of keyword arguments to be passed to
            `ax.legend()` to modify the legend, such as:
            - `loc`, `bbox_to_anchor`, and `bbox_transform` to set its
              location. By default, `loc` is set to `'center left'` and
              `bbox_to_anchor` to `(1, 0.5)` to put the legend to the
              right of the plot, anchored at the middle.
            - `ncols` to set its number of columns. By default, set to
              `ceil(obs[color_column].n_unique() / 24)` to have at most 24
              items per column.
            - `prop`, `fontsize`, and `labelcolor` to set its font
              properties
            - `facecolor` and `framealpha` to set its background color and
              transparency
            - `frameon=True` or `edgecolor` to add or color its border.
              `frameon` defaults to `False`, instead of Matplotlib's
              default of `True`.
            - `title` to add a legend title
            Can only be specified when `color_column` is discrete and
            `legend=True`.
        colorbar: whether to add a colorbar. Ignored unless `color_column`
            is quantitative.
        colorbar_kwargs: a dictionary of keyword arguments to be passed to
            `plt.colorbar()`, such as:
            - `location`: `'left'`, `'right'`, `'top'`, or `'bottom'`
            - `orientation`: `'vertical'` or `'horizontal'`
            - `fraction`: the fraction of the axes to allocate to the
              colorbar. Defaults to 0.15.
            - `shrink`: the fraction to multiply the size of the colorbar
              by. Defaults to 0.5, instead of Matplotlib's default of 1.
            - `aspect`: the ratio of the colorbar's long to short
              dimensions. Defaults to 20.
            - `pad`: the fraction of the axes between the colorbar and the
              rest of the figure. Defaults to 0.01, instead of
              Matplotlib's default of 0.05 if vertical and 0.15 if
              horizontal.
            Can only be specified when `color_column` is quantitative and
            `colorbar=True`.
        title: the title of the plot, or `None` to not add a title
        title_kwargs: a dictionary of keyword arguments to be passed to
            `ax.set_title()` to control text properties, such as `color`
            and `size`. Can only be specified when `title` is not `None`.
        xlabel: the x-axis label, or `None` to not label the x-axis
        xlabel_kwargs: a dictionary of keyword arguments to be passed to
            `ax.set_xlabel()` to control text properties, such as `color`
            and `size`. Can only be specified when `xlabel` is not `None`.
        ylabel: the y-axis label, or `None` to not label the y-axis
        ylabel_kwargs: a dictionary of keyword arguments to be passed to
            `ax.set_ylabel()` to control text properties, such as `color`
            and `size`. Can only be specified when `ylabel` is not `None`.
        xlim: a length-2 tuple of the left and right x-axis limits, or
            `None` to set the limits based on the data
        ylim: a length-2 tuple of the bottom and top y-axis limits, or
            `None` to set the limits based on the data
        despine: whether to remove the top and right spines (borders of
            the plot area) from the plot
        savefig_kwargs: a dictionary of keyword arguments to be passed to
            `plt.savefig()`, such as:
            - `dpi`: defaults to 300 instead of Matplotlib's default of
              `'figure'` (the figure's own dpi, which is 100 by default)
            - `bbox_inches`: the bounding box of the portion of the figure
              to save; defaults to `'tight'` (crop out any blank borders)
              instead of Matplotlib's default of `None` (save the entire
              figure)
            - `pad_inches`: the number of inches of padding to add on each
              of the four sides of the figure when saving. Defaults to
              `'layout'` (use the padding from the constrained layout
              engine, when `ax` is not `None`) instead of Matplotlib's
              default of 0.1.
            - `transparent`: whether to save with a transparent
              background; defaults to `True` if saving to a PDF (i.e. when
              `filename` ends with `'.pdf'`) and `False` otherwise,
              instead of Matplotlib's default of always being `False`.
            Can only be specified when `filename` is specified.
    """
    # All validation and plotting is delegated to `plot_embedding()`; this
    # wrapper only fixes the embedding key to `'localmap'`
    self.plot_embedding('localmap', color_column, filename,
                        cells_to_plot_column=cells_to_plot_column, ax=ax,
                        figure_kwargs=figure_kwargs, point_size=point_size,
                        sort_by_frequency=sort_by_frequency,
                        colormap=colormap,
                        lightness_range=lightness_range,
                        chroma_range=chroma_range, hue_range=hue_range,
                        first_color=first_color, stride=stride,
                        default_color=default_color,
                        scatter_kwargs=scatter_kwargs, label=label,
                        label_kwargs=label_kwargs, legend=legend,
                        legend_kwargs=legend_kwargs, colorbar=colorbar,
                        colorbar_kwargs=colorbar_kwargs, title=title,
                        title_kwargs=title_kwargs, xlabel=xlabel,
                        xlabel_kwargs=xlabel_kwargs, ylabel=ylabel,
                        ylabel_kwargs=ylabel_kwargs, xlim=xlim, ylim=ylim,
                        despine=despine, savefig_kwargs=savefig_kwargs)
def plot_umap(
        self,
        color_column: SingleCellColumn | None,
        filename: str | Path | None = None,
        /,
        *,
        cells_to_plot_column: SingleCellColumn | None = 'passed_QC',
        ax: 'Axes' | None = None,
        figure_kwargs: dict[str, Any] | None = None,
        point_size: int | float | np.integer | np.floating | str |
                    None = None,
        sort_by_frequency: bool = False,
        colormap: str | 'Colormap' | dict[Any, Color] = None,
        lightness_range: tuple[float | np.floating,
                               float | np.floating] | None =
            (100 / 3, 200 / 3),
        chroma_range: tuple[float | np.floating,
                            float | np.floating] | None = (50, 100),
        hue_range: tuple[float | np.floating,
                         float | np.floating] | None = None,
        first_color: Color = '#008cb9',
        stride: int | np.integer = 5,
        default_color: Color = 'lightgray',
        scatter_kwargs: dict[str, Any] | None = None,
        label: bool = False,
        label_kwargs: dict[str, Any] | None = None,
        legend: bool = True,
        legend_kwargs: dict[str, Any] | None = None,
        colorbar: bool = True,
        colorbar_kwargs: dict[str, Any] | None = None,
        title: str | None = None,
        title_kwargs: dict[str, Any] | None = None,
        xlabel: str | None = 'Component 1',
        xlabel_kwargs: dict[str, Any] | None = None,
        ylabel: str | None = 'Component 2',
        ylabel_kwargs: dict[str, Any] | None = None,
        xlim: tuple[int | float | np.integer | np.floating,
                    int | float | np.integer | np.floating] | None = None,
        ylim: tuple[int | float | np.integer | np.floating,
                    int | float | np.integer | np.floating] | None = None,
        despine: bool = True,
        savefig_kwargs: dict[str, Any] | None = None) -> None:
    """
    Plot the UMAP embedding generated by `umap()`.

    This is a convenience wrapper: every argument is handed, unchanged, to
    `plot_embedding('umap', ...)`.

    Args:
        color_column: an optional column of `obs` that decides the color of
                      each cell. May be given as a column name, a polars
                      expression, a polars Series, a 1D NumPy array, or a
                      function that takes in this SingleCell dataset and
                      returns a polars Series or 1D NumPy array. A
                      String/Enum/Categorical column is treated as discrete
                      (e.g. cell-type labels); an integer/floating-point
                      column is treated as quantitative (e.g. the number of
                      UMIs per cell). Cells with missing (`null`) values
                      are drawn in `default_color`. Pass `None` to draw
                      every cell in `default_color`.
        filename: the file to save the plot to. When `None`, the plot is
                  generated but not saved, so it can be shown interactively
                  or modified further before saving.
        cells_to_plot_column: an optional Boolean column of `obs` selecting
                              the cells to plot. May be given as a column
                              name, a polars expression, a polars Series, a
                              1D NumPy array, or a function that takes in
                              this SingleCell dataset and returns a polars
                              Series or 1D NumPy array. Set to `None` to
                              plot every cell, with no subsetting.
        ax: the Matplotlib axes to draw the plot onto; when `None`, a new
            figure using Matplotlib's constrained layout is created and
            plotted onto
        figure_kwargs: a dictionary of keyword arguments to be passed to
                       `plt.figure` when `ax` is `None`, such as:
                       - `figsize`: a two-element sequence of the width and
                         height of the figure in inches. Defaults to
                         `[8, 6]`, 25% larger than Matplotlib's default of
                         `[6.4, 4.8]`.
                       - `layout`: the layout mechanism used by Matplotlib
                         to avoid overlapping plot elements. Defaults to
                         `'constrained'`, instead of Matplotlib's default
                         of `None`.
                       Must be `None` when `ax` is not `None`.
        point_size: the size of the point drawn for each cell; defaults to
                    20,000 divided by the number of cells. Can be a single
                    positive number, or the name of a numeric column of
                    `obs` (with all-positive values) to give each point its
                    own size.
        sort_by_frequency: if `True`, assign colors and order the legend by
                           decreasing frequency; if `False` (the default),
                           use [natural sort order](
                           en.wikipedia.org/wiki/Natural_sort_order).
                           Cannot be `True` unless `colormap` is `None` and
                           `color_column` is discrete; if `colormap` is not
                           `None`, the plot order is determined by the
                           order of the keys in `colormap`.
        colormap: a string or Colormap object naming the Matplotlib
                  colormap to use; or, if `color_column` is discrete, a
                  dictionary mapping values in `color_column` to Matplotlib
                  colors (cells whose value of `color_column` is absent
                  from the dictionary are drawn in `default_color`).
                  Defaults to `plt.rcParams['image.cmap']` (`'viridis'` by
                  default) if `color_column` is continuous, or to the
                  colors from a maximally perceptually distinct colormap if
                  `color_column` is discrete. Cannot be specified if
                  `color_column` is `None`.
        lightness_range: a two-element tuple with the lightness range of
                         colors to generate, or `None` to take the full
                         range: `[0, 100]`. Can only be specified when
                         `color_column` is discrete and `colormap` is
                         `None`.
        chroma_range: a two-element tuple with the chroma range of colors
                      to generate, or `None` to take the full range:
                      `[0, 100]`. Grays have low chroma, and vivid colors
                      have high chroma. Can only be specified when
                      `color_column` is discrete and `colormap` is `None`.
        hue_range: a two-element tuple with the hue range of colors to
                   generate, or `None` to take the full range: `[0, 360]`.
                   Red is at 0°, green at 120°, and blue at 240°. Because
                   hue wraps around, the first element of the tuple can be
                   greater than the second, unlike for `lightness_range`
                   and `chroma_range`. Can only be specified when
                   `color_column` is discrete and `colormap` is `None`.
        first_color: the first color of the palette. Can be any valid
                     Matplotlib color, like a hex string (e.g.
                     `'#FF0000'`), a named color (e.g. 'red'), a 3- or
                     4-element RGB/RGBA tuple of integers 0-255 or floats
                     0-1, or a single float 0-1 for grayscale.
        stride: as an optimization, consider only RGB colors where R, G,
                and B are all multiples of this value. Must be a small
                divisor of 255: 1, 3, 5, 15, or 17. Set to 1 for the best
                possible solution, at orders of magnitude more
                computational cost.
        default_color: the color used for cells when `color_column` is
                       `None`, for cells with missing (`null`) values of
                       `color_column`, and, when `colormap` is a
                       dictionary, for cells whose value of `color_column`
                       is not in the dictionary. Can be any valid
                       Matplotlib color, like a hex string (e.g.
                       `'#FF0000'`), a named color (e.g. 'red'), a 3- or
                       4-element RGB/RGBA tuple of integers 0-255 or floats
                       0-1, or a single float 0-1 for grayscale.
        scatter_kwargs: a dictionary of keyword arguments to be passed to
                        `ax.scatter()`, such as:
                        - `rasterized`: whether to convert the scatter plot
                          points to a raster (bitmap) image when saving to
                          a vector format like PDF. Defaults to `True`,
                          instead of Matplotlib's default of `False`.
                        - `marker`: the shape to use for plotting each cell
                        - `norm`, `vmin`, and `vmax`: control how the
                          `colormap` maps the numbers in `color_column` to
                          colors, if `color_column` is numeric
                        - `alpha`: the transparency of each point
                        - `linewidths` and `edgecolors`: the width and
                          color of the borders around each marker. These
                          are absent by default (`linewidths=0`,
                          `edgecolors=(0, 0, 0, 0)`), unlike Matplotlib's
                          default. Both arguments can be either single
                          values or sequences.
                        - `zorder`: the order in which the cells are
                          plotted, with higher values appearing on top of
                          lower ones.
                        Specifying `s`, `c`/`color`, or `cmap` will raise
                        an error, since these arguments conflict with the
                        `point_size`, `color_column`, and `colormap`
                        arguments, respectively.
        label: whether to label cells with each distinct value of
               `color_column`. Labels are placed at the median x and y
               position of the points with that color. Can only be `True`
               when `color_column` is discrete. When set to `True`, you
               may also want to set `legend=False` to avoid redundancy.
        label_kwargs: a dictionary of keyword arguments to be passed to
                      `ax.text()` when adding labels to control the text
                      properties, such as:
                      - `color` and `size` to modify the text color/size
                      - `verticalalignment` and `horizontalalignment` to
                        control vertical and horizontal alignment. By
                        default, unlike Matplotlib, these are both set to
                        `'center'`.
                      - `path_effects` to set properties for the border
                        around the text. By default, set to
                        `matplotlib.patheffects.withStroke(linewidth=3,
                        foreground='white', alpha=0.75)` instead of
                        Matplotlib's default of `None`, to put a
                        semi-transparent white border around the labels
                        for better contrast.
                      Can only be specified when `label=True`.
        legend: whether to add a legend for each value in `color_column`.
                Ignored unless `color_column` is discrete.
        legend_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.legend()` to modify the legend, such as:
                       - `loc`, `bbox_to_anchor`, and `bbox_transform` to
                         set its location. By default, `loc` is set to
                         `'center left'` and `bbox_to_anchor` to
                         `(1, 0.5)` to put the legend to the right of the
                         plot, anchored at the middle.
                       - `ncols` to set its number of columns. By default,
                         set to `ceil(obs[color_column].n_unique() / 24)`
                         to have at most 24 items per column.
                       - `prop`, `fontsize`, and `labelcolor` to set its
                         font properties
                       - `facecolor` and `framealpha` to set its
                         background color and transparency
                       - `frameon=True` or `edgecolor` to add or color its
                         border. `frameon` defaults to `False`, instead of
                         Matplotlib's default of `True`.
                       - `title` to add a legend title
                       Can only be specified when `color_column` is
                       discrete and `legend=True`.
        colorbar: whether to add a colorbar. Ignored unless `color_column`
                  is quantitative.
        colorbar_kwargs: a dictionary of keyword arguments to be passed to
                         `plt.colorbar()`, such as:
                         - `location`: `'left'`, `'right'`, `'top'`, or
                           `'bottom'`
                         - `orientation`: `'vertical'` or `'horizontal'`
                         - `fraction`: the fraction of the axes to
                           allocate to the colorbar. Defaults to 0.15.
                         - `shrink`: the fraction to multiply the size of
                           the colorbar by. Defaults to 0.5, instead of
                           Matplotlib's default of 1.
                         - `aspect`: the ratio of the colorbar's long to
                           short dimensions. Defaults to 20.
                         - `pad`: the fraction of the axes between the
                           colorbar and the rest of the figure. Defaults
                           to 0.01, instead of Matplotlib's default of
                           0.05 if vertical and 0.15 if horizontal.
                         Can only be specified when `color_column` is
                         quantitative and `colorbar=True`.
        title: the title of the plot, or `None` to not add a title
        title_kwargs: a dictionary of keyword arguments to be passed to
                      `ax.set_title()` to control text properties, such as
                      `color` and `size`. Can only be specified when
                      `title` is not `None`.
        xlabel: the x-axis label, or `None` to not label the x-axis
        xlabel_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.set_xlabel()` to control text properties, such
                       as `color` and `size`. Can only be specified when
                       `xlabel` is not `None`.
        ylabel: the y-axis label, or `None` to not label the y-axis
        ylabel_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.set_ylabel()` to control text properties, such
                       as `color` and `size`. Can only be specified when
                       `ylabel` is not `None`.
        xlim: a length-2 tuple of the left and right x-axis limits, or
              `None` to set the limits based on the data
        ylim: a length-2 tuple of the bottom and top y-axis limits, or
              `None` to set the limits based on the data
        despine: whether to remove the top and right spines (borders of
                 the plot area) from the plot
        savefig_kwargs: a dictionary of keyword arguments to be passed to
                        `plt.savefig()`, such as:
                        - `dpi`: defaults to 300 rather than Matplotlib's
                          own default
                        - `bbox_inches`: the bounding box of the portion
                          of the figure to save; defaults to `'tight'`
                          (crop out any blank borders) instead of
                          Matplotlib's default of `None` (save the entire
                          figure)
                        - `pad_inches`: the number of inches of padding to
                          add on each of the four sides of the figure when
                          saving. Defaults to `'layout'` (use the padding
                          from the constrained layout engine) instead of
                          Matplotlib's default of 0.1.
                        - `transparent`: whether to save with a
                          transparent background; defaults to `True` if
                          saving to a PDF (i.e. when `filename` ends with
                          `'.pdf'`) and `False` otherwise, instead of
                          Matplotlib's default of always being `False`.
                        Can only be specified when `filename` is
                        specified.
    """
    # Gather each keyword-only option under its own name, in signature
    # order, so they can all be handed to `plot_embedding()` in one go.
    # No validation happens here; `plot_embedding()` checks everything.
    forwarded_options = {
        'cells_to_plot_column': cells_to_plot_column,
        'ax': ax,
        'figure_kwargs': figure_kwargs,
        'point_size': point_size,
        'sort_by_frequency': sort_by_frequency,
        'colormap': colormap,
        'lightness_range': lightness_range,
        'chroma_range': chroma_range,
        'hue_range': hue_range,
        'first_color': first_color,
        'stride': stride,
        'default_color': default_color,
        'scatter_kwargs': scatter_kwargs,
        'label': label,
        'label_kwargs': label_kwargs,
        'legend': legend,
        'legend_kwargs': legend_kwargs,
        'colorbar': colorbar,
        'colorbar_kwargs': colorbar_kwargs,
        'title': title,
        'title_kwargs': title_kwargs,
        'xlabel': xlabel,
        'xlabel_kwargs': xlabel_kwargs,
        'ylabel': ylabel,
        'ylabel_kwargs': ylabel_kwargs,
        'xlim': xlim,
        'ylim': ylim,
        'despine': despine,
        'savefig_kwargs': savefig_kwargs,
    }
    # The embedding key is fixed to 'umap'; `color_column` and `filename`
    # stay positional, matching the signature of `plot_embedding()`
    self.plot_embedding('umap', color_column, filename, **forwarded_options)
[docs] def plot_embedding( self, embedding_key: str, color_column: SingleCellColumn | None, filename: str | Path | None = None, /, *, cells_to_plot_column: SingleCellColumn | None = 'passed_QC', ax: 'Axes' | None = None, figure_kwargs: dict[str, Any] | None = None, point_size: int | float | np.integer | np.floating | str | None = None, sort_by_frequency: bool = False, colormap: str | 'Colormap' | dict[Any, Color] = None, lightness_range: tuple[float | np.floating, float | np.floating] | None = (100 / 3, 200 / 3), chroma_range: tuple[float | np.floating, float | np.floating] | None = (50, 100), hue_range: tuple[float | np.floating, float | np.floating] | None = None, first_color: Color = '#008cb9', stride: int | np.integer = 5, default_color: Color = 'lightgray', scatter_kwargs: dict[str, Any] | None = None, label: bool = False, label_kwargs: dict[str, Any] | None = None, legend: bool = True, legend_kwargs: dict[str, Any] | None = None, colorbar: bool = True, colorbar_kwargs: dict[str, Any] | None = None, title: str | None = None, title_kwargs: dict[str, Any] | None = None, xlabel: str | None = 'Component 1', xlabel_kwargs: dict[str, Any] | None = None, ylabel: str | None = 'Component 2', ylabel_kwargs: dict[str, Any] | None = None, xlim: tuple[int | float | np.integer | np.floating, int | float | np.integer | np.floating] | None = None, ylim: tuple[int | float | np.integer | np.floating, int | float | np.integer | np.floating] | None = None, despine: bool = True, savefig_kwargs: dict[str, Any] | None = None) -> None: """ Plot the specified 2D embedding. Args: embedding_key: the key of `obsm` containing the embedding to plot color_column: an optional column of `obs` indicating how to color each cell in the plot. Can be a column name, a polars expression, a polars Series, a 1D NumPy array, or a function that takes in this SingleCell dataset and returns a polars Series or 1D NumPy array. Can be discrete (e.g. 
cell-type labels), specified as a String/Enum/Categorical column, or quantitative (e.g. the number of UMIs per cell), specified as an integer/floating-point column. Missing (`null`) cells will be plotted with the color `default_color`. Set to `None` to use `default_color` for all cells. filename: the file to save to. If `None`, generate the plot but do not save it, which allows it to be shown interactively or modified further before saving. cells_to_plot_column: an optional Boolean column of `obs` indicating which cells to plot. Can be a column name, a polars expression, a polars Series, a 1D NumPy array, or a function that takes in this SingleCell dataset and returns a polars Series or 1D NumPy array. Set to `None` to plot all cells passing QC. ax: the Matplotlib axes to save the plot onto; if `None`, create a new figure with Matpotlib's constrained layout and plot onto it figure_kwargs: a dictionary of keyword arguments to be passed to `plt.figure` when `ax` is `None`, such as: - `figsize`: a two-element sequence of the width and height of the figure in inches. Defaults to `[8, 6]`, 25% larger than Matplotlib's default of `[6.4, 4.8]`. - `layout`: the layout mechanism used by Matplotlib to avoid overlapping plot elements. Defaults to `'constrained'`, instead of Matplotlib's default of `None`. point_size: the size of the points for each cell; defaults to 20,000 divided by the number of cells. Can be a single number, or the name of a column of `obs` to make each point a different size. sort_by_frequency: if `True`, assign colors and sort the legend in order of decreasing frequency; if `False` (the default), use [natural sort order](en.wikipedia.org/wiki/Natural_sort_order). Cannot be `True` unless `colormap` is `None` and `color_column` is discrete; if `colormap` is not `None`, the plot order is determined by the order of the keys in `colormap`. 
colormap: a string or Colormap object indicating the Matplotlib colormap to use; or, if `color_column` is discrete, a dictionary mapping values in `color_column` to Matplotlib colors (cells with values of `color_column` that are not in the dictionary will be plotted in the color `default_color`). Defaults to `plt.rcParams['image.cmap']` (`'viridis'` by default) if `color_column` is continous, or the colors from a maximally perceptually distinct colormap if `color_column` is discrete (with colors assigned in decreasing order of frequency). Cannot be specified if `color_column` is `None`. lightness_range: a two-element tuple with the lightness range of colors to generate, or `None` to take the full range: `[0, 100]`. Can only be specified when `color_column` is discrete and `colormap` is `None`. chroma_range: a two-element tuple with the chroma range of colors to generate, or `None` to take the full range: `[0, 100]`. Grays have low chroma, and vivid colors have high chroma. Can only be specified when `color_column` is discrete and `colormap` is `None`. hue_range: a two-element tuple with the hue range of colors to generate, or `None` to take the full range: `[0, 360]`. Red is at 0°, green at 120°, and blue at 240°. Because it wraps around, the first element of the tuple can be greater than the second, unlike for `lightness_range` and `chroma_range`. Can only be specified when `color_column` is discrete and `colormap` is `None`. first_color: the first color of the palette. Can be any valid Matplotlib color, like a hex string (e.g. `'#FF0000'`), a named color (e.g. 'red'), a 3- or 4-element RGB/RGBA tuple of integers 0-255 or floats 0-1, or a single float 0-1 for grayscale. stride: as an optimization, consider only RGB colors where R, G, and B are all multiples of this value. Must be a small divisor of 255: 1, 3, 5, 15, or 17. Set to 1 for the best possible solution, at orders of magnitude more computational cost. 
default_color: the default color to plot cells in when `color_column` is `None`, or when certain cells have missing (`null`) values for `color_column`, or when `colormap` is a dictionary and some cells have values of `color_column` that are not in the dictionary. Can be any valid Matplotlib color, like a hex string (e.g. `'#FF0000'`), a named color (e.g. 'red'), a 3- or 4-element RGB/RGBA tuple of integers 0-255 or floats 0-1, or a single float 0-1 for grayscale. scatter_kwargs: a dictionary of keyword arguments to be passed to `ax.scatter()`, such as: - `rasterized`: whether to convert the scatter plot points to a raster (bitmap) image when saving to a vector format like PDF. Defaults to `True`, instead of Matplotlib's default of `False`. - `marker`: the shape to use for plotting each cell - `norm`, `vmin`, and `vmax`: control how the `colormap` maps the numbers in `color_column` to colors, if `color_column` is numeric - `alpha`: the transparency of each point - `linewidths` and `edgecolors`: the width and color of the borders around each marker. These are absent by default (`linewidths=0`, `edgecolors=(0, 0, 0, 0)`), unlike Matplotlib's default. Both arguments can be either single values or sequences. - `zorder`: the order in which the cells are plotted, with higher values appearing on top of lower ones. Specifying `s`, `c`/`color`, or `cmap` will raise an error, since these arguments conflict with the `point_size`, `color_column`, and `colormap` arguments, respectively. label: whether to label cells with each distinct value of `color_column`. Labels will be placed at the median x and y position of the points with that color. Can only be `True` when `color_column` is discrete. When set to `True`, you may also want to set `legend=False` to avoid redundancy. 
label_kwargs: a dictionary of keyword arguments to be passed to `ax.text()` when adding labels to control the text properties, such as: - `color` and `size` to modify the text color/size - `verticalalignment` and `horizontalalignment` to control vertical and horizontal alignment. By default, unlike Matplotlib, these are both set to `'center'`. - `path_effects` to set properties for the border around the text. By default, set to `matplotlib.patheffects.withStroke(linewidth=3, foreground='white', alpha=0.75)` instead of Matplotlib's default of `None`, to put a semi-transparent white border around the labels for better contrast. Can only be specified when `label=True`. legend: whether to add a legend for each value in `color_column`. Ignored unless `color_column` is discrete. legend_kwargs: a dictionary of keyword arguments to be passed to `ax.legend()` to modify the legend, such as: - `loc`, `bbox_to_anchor`, and `bbox_transform` to set its location. By default, `loc` is set to `'center left'` and `bbox_to_anchor` to `(1, 0.5)` to put the legend to the right of the plot, anchored at the middle. - `ncols` to set its number of columns. By default, set to `ceil(obs[color_column].n_unique() / 24)` to have at most 24 items per column. - `prop`, `fontsize`, and `labelcolor` to set its font properties - `facecolor` and `framealpha` to set its background color and transparency - `frameon=True` or `edgecolor` to add or color its border. `frameon` defaults to `False`, instead of Matplotlib's default of `True`. - `title` to add a legend title Can only be specified when `color_column` is discrete and `legend=True`. colorbar: whether to add a colorbar. Ignored unless `color_column` is quantitative. colorbar_kwargs: a dictionary of keyword arguments to be passed to `plt.colorbar()`, such as: - `location`: `'left'`, `'right'`, `'top'`, or `'bottom'` - `orientation`: `'vertical'` or `'horizontal'` - `fraction`: the fraction of the axes to allocate to the colorbar. Defaults to 0.15. 
- `shrink`: the fraction to multiply the size of the colorbar by. Defaults to 0.5, instead of Matplotlib's default of 1. - `aspect`: the ratio of the colorbar's long to short dimensions. Defaults to 20. - `pad`: the fraction of the axes between the colorbar and the rest of the figure. Defaults to 0.01, instead of Matplotlib's default of 0.05 if vertical and 0.15 if horizontal. Can only be specified when `color_column` is quantitative and `colorbar=True`. title: the title of the plot, or `None` to not add a title title_kwargs: a dictionary of keyword arguments to be passed to `ax.set_title()` to control text properties, such as `color` and `size`. Can only be specified when `title` is not `None`. xlabel: the x-axis label, or `None` to not label the x-axis xlabel_kwargs: a dictionary of keyword arguments to be passed to `ax.set_xlabel()` to control text properties, such as `color` and `size`. Can only be specified when `xlabel` is not `None`. ylabel: the y-axis label, or `None` to not label the y-axis ylabel_kwargs: a dictionary of keyword arguments to be passed to `ax.set_ylabel()` to control text properties, such as `color` and `size`. Can only be specified when `ylabel` is not `None`. xlim: a length-2 tuple of the left and right x-axis limits, or `None` to set the limits based on the data ylim: a length-2 tuple of the bottom and top y-axis limits, or `None` to set the limits based on the data despine: whether to remove the top and right spines (borders of the plot area) from the plot savefig_kwargs: a dictionary of keyword arguments to be passed to `plt.savefig()`, such as: - `dpi`: defaults to 300 instead of Matplotlib's default of 150 - `bbox_inches`: the bounding box of the portion of the figure to save; defaults to `'tight'` (crop out any blank borders) instead of Matplotlib's default of `None` (save the entire figure) - `pad_inches`: the number of inches of padding to add on each of the four sides of the figure when saving. 
Defaults to `'layout'` (use the padding from the constrained layout engine, when `ax` is not `None`) instead of Matplotlib's default of 0.1. - `transparent`: whether to save with a transparent background; defaults to `True` if saving to a PDF (i.e. when `filename` ends with `'.pdf'`) and `False` otherwise, instead of Matplotlib's default of always being `False`. Can only be specified when `filename` is specified. """ # Import matplotlib signal.signal(signal.SIGINT, signal.SIG_IGN) try: import matplotlib.pyplot as plt finally: signal.signal(signal.SIGINT, signal.default_int_handler) # Get `cells_to_plot_column`, if not `None` original_cells_to_plot_column = cells_to_plot_column if cells_to_plot_column is not None: cells_to_plot_column = self._get_column( 'obs', cells_to_plot_column, 'cells_to_plot_column', pl.Boolean, allow_missing=isinstance(cells_to_plot_column, str) and cells_to_plot_column == 'passed_QC') # If `color_column` was specified, check that it either discrete # (String, Enum, or Categorical) or quantitative (integer or # floating-point). If discrete, check that `color_column` has at least # two distinct values. 
original_color_column = color_column if color_column is not None: color_column = self._get_column( 'obs', color_column, 'color_column', (pl.String, pl.Enum, pl.Categorical, 'integer', 'floating-point'), allow_null=True, QC_column=cells_to_plot_column) unique_color_column = color_column.unique().drop_nulls() dtype = color_column.dtype discrete = dtype in (pl.String, pl.Enum, pl.Categorical) if discrete and len(unique_color_column) == 1: color_column_description = \ SingleCell._describe_column('color_column', original_color_column) error_message = ( f'{color_column_description} must have at least two ' f'distinct values when its data ' f'type is {dtype.base_type()!r}') raise ValueError(error_message) # If `filename` was specified, check that it is a string or # `pathlib.Path` and that its base directory exists; if `filename` is # `None`, make sure `savefig_kwargs` is also `None` if filename is not None: check_type(filename, 'filename', (str, Path), 'a string or pathlib.Path') directory = os.path.dirname(filename) if directory and not os.path.isdir(directory): error_message = ( f'{filename} refers to a file in the directory ' f'{directory!r}, but this directory does not exist') raise NotADirectoryError(error_message) filename = str(filename) elif savefig_kwargs is not None: error_message = 'savefig_kwargs must be None when filename is None' raise ValueError(error_message) # Check that `embedding_key` is the name of a key in `obsm` check_type(embedding_key, 'embedding_key', str, 'a string') if embedding_key not in self._obsm: error_message = \ f'embedding_key {embedding_key!r} is not a key of obsm' if embedding_key in ('umap', 'pacmap', 'localmap'): error_message += ( f'; did you forget to run {embedding_key}()?') raise ValueError(error_message) # Check that the embedding `embedding_key` references is 2D. 
embedding = self._obsm[embedding_key] if embedding.shape[1] != 2: error_message = ( f'the embedding at obsm[{embedding_key!r}] is ' f'{embedding.shape[1]:,}-dimensional, but must be ' f'2-dimensional to be plotted') raise ValueError(error_message) # If `cells_to_plot_column` was specified, subset to these cells if cells_to_plot_column is not None: embedding = embedding[cells_to_plot_column.to_numpy()] if color_column is not None: color_column = color_column.filter(cells_to_plot_column) unique_color_column = color_column.unique().drop_nulls() # Check that the embedding does not contain NaNs if np.isnan(embedding).any(): error_message = \ f'the embedding at obsm[{embedding_key!r}] contains NaNs; ' if cells_to_plot_column is None and 'passed_QC' in self._obs: error_message += ( 'did you forget to set QC_column to None when creating the ' 'embedding, to match the fact that you set ' 'cells_to_plot_column to None here?') else: cells_to_plot_column_description = \ SingleCell._describe_column('cells_to_plot_column', original_cells_to_plot_column) error_message += ( f'does your {cells_to_plot_column_description} contain ' f'cells that were excluded by the QC_column used when ' f'creating the embedding?') raise ValueError(error_message) # For each of the kwargs arguments, if the argument was specified, # check that it is a dictionary and that all its keys are strings. 
for kwargs, kwargs_name in ((figure_kwargs, 'figure_kwargs'), (scatter_kwargs, 'scatter_kwargs'), (label_kwargs, 'label_kwargs'), (legend_kwargs, 'legend_kwargs'), (colorbar_kwargs, 'colorbar_kwargs'), (title_kwargs, 'title_kwargs'), (xlabel_kwargs, 'xlabel_kwargs'), (ylabel_kwargs, 'ylabel_kwargs'), (savefig_kwargs, 'savefig_kwargs')): if kwargs is not None: check_type(kwargs, kwargs_name, dict, 'a dictionary') for key in kwargs: if not isinstance(key, str): error_message = ( f'all keys of {kwargs_name} must be strings, but ' f'it contains a key of type ' f'{type(key).__name__!r}') raise TypeError(error_message) # If `figure_kwargs` was specified, check that `ax` is `None` if figure_kwargs is not None and ax is not None: error_message = ( 'figure_kwargs must be None when ax is not None, since a new ' 'figure does not need to be generated when plotting onto an ' 'existing axis') raise ValueError(error_message) # If `point_size` is `None`, default to 20,000 / num_cells; otherwise, # check that it is a positive number or the name of a numeric column of # `obs` with all-positive numbers num_cells = \ len(self) if cells_to_plot_column is None else len(embedding) if point_size is None: point_size = 20_000 / num_cells else: check_type(point_size, 'point_size', (int, float, str), 'a positive number or string') if isinstance(point_size, (int, float)): check_bounds(point_size, 'point_size', 0, left_open=True) else: if point_size not in self._obs: error_message = \ f'point_size {point_size!r} is not a column of obs' raise ValueError(error_message) point_size = self._obs[point_size] if not (point_size.dtype.is_integer() or point_size.dtype.is_float()): error_message = ( f'the point_size column, obs[{point_size!r}], must ' f'have an integer or floating-point data type, but ' f'has data type {point_size.dtype.base_type()!r}') raise TypeError(error_message) if point_size.min() <= 0: error_message = ( f'the point_size column, obs[{point_size!r}], does ' f'not have all-positive 
elements') raise ValueError(error_message) # If `sort_by_frequency=True`, ensure `colormap` is `None` and # `color_column` is discrete check_type(sort_by_frequency, 'sort_by_frequency', bool, 'Boolean') if sort_by_frequency: if colormap is not None: error_message = ( f'sort_by_frequency must be False when colormap is ' f'specified') raise ValueError(error_message) if color_column is None: error_message = \ 'sort_by_frequency must be False when color_column is None' raise ValueError(error_message) if not discrete: color_column_description = \ SingleCell._describe_column('color_column', original_color_column) error_message = ( f'sort_by_frequency must be False when ' f'{color_column_description} is continuous') raise ValueError(error_message) # Handle coloring based on the values of `colormap` and `color_column` if colormap is not None: # If `colormap` was specified, check that it is a string in # `plt.colormaps`, Colormap object, or dictionary where all keys # are in `color_column` and all values are valid Matplotlib colors. # Normalize the color(s) to RGBA. Make sure `color_column` is not # `None` and `lightness_range`, `chroma_range`, `hue_range`, # `first_color`, and `stride` have their default values. 
check_type(colormap, 'colormap', (str, plt.matplotlib.colors.Colormap, dict), 'a string, matplotlib Colormap object, or dictionary') if color_column is None: error_message = \ 'colormap must be None when color_column is None' raise ValueError(error_message) if not (isinstance(lightness_range, tuple) and len(lightness_range) == 2 and isinstance(lightness_range[0], float) and lightness_range[0] == 100 / 3 and isinstance(lightness_range[1], float) and lightness_range[1] == 200 / 3): error_message = ( f'lightness_range cannot be specified when colormap is ' f'specified') raise ValueError(error_message) if not (isinstance(chroma_range, tuple) and len(chroma_range) == 2 and isinstance(chroma_range[0], int) and chroma_range[0] == 50 and isinstance(chroma_range[1], int) and chroma_range[1] == 100): error_message = ( f'chroma_range cannot be specified when colormap is ' f'specified') raise ValueError(error_message) for arg, arg_name in ((hue_range, 'hue_range'), (first_color, 'first_color'), (stride, 'stride')): if arg is not None: error_message = \ f'{arg_name} must be None when colormap is specified' raise ValueError(error_message) if isinstance(colormap, str): colormap = plt.colormaps[colormap] elif isinstance(colormap, dict): if not discrete: color_column_description = \ SingleCell._describe_column('color_column', original_color_column) error_message = ( f'colormap cannot be a dictionary when ' f'{color_column_description} is continuous') raise ValueError(error_message) for key, value in colormap.items(): if not isinstance(key, str): error_message = ( f'all keys of colormap must be strings, but it ' f'contains a key of type {type(key).__name__!r}') raise TypeError(error_message) if key not in unique_color_column: error_message = ( f'colormap is a dictionary containing the key ' f'{key!r}, which is not one of the values in ' f'obs[{color_column!r}]') raise ValueError(error_message) if not plt.matplotlib.colors.is_color_like(value): error_message = ( f'colormap[{key!r}] 
is not a valid Matplotlib ' f'color') raise ValueError(error_message) colormap[key] = plt.matplotlib.colors.to_rgba(value) else: if color_column is not None and discrete: # `color_column` is discrete and `colormap` was not specified; # generate a maximally perceptually distinct colormap. Assign # colors in natural sort order, or decreasing order of # frequency if `sort_by_frequency=True`. color_order = color_column\ .value_counts(sort=True)\ .to_series()\ .drop_nulls() if sort_by_frequency else \ sorted(unique_color_column, key=lambda color_label: [ int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', color_label)]) colormap = generate_palette(num_colors=len(color_order), lightness_range=lightness_range, chroma_range=chroma_range, hue_range=hue_range, first_color=first_color, stride=stride) colormap = np.c_[colormap, np.ones(len(colormap))] # add alpha colormap = dict(zip(color_order, colormap)) else: # `color_column` is `None` or continuous, so make sure # `lightness_range`, `chroma_range`, `hue_range`, # `first_color`, and `stride` have their default values for arg, arg_name in ((lightness_range, 'lightness_range'), (chroma_range, 'chroma_range'), (hue_range, 'hue_range'), (first_color, 'first_color'), (stride, 'stride')): if arg is lightness_range: if isinstance(lightness_range, tuple) and \ len(lightness_range) == 2 and \ isinstance(lightness_range[0], float) and \ lightness_range[0] == 100 / 3 and \ isinstance(lightness_range[1], float) and \ lightness_range[1] == 200 / 3: continue elif arg is chroma_range: if isinstance(chroma_range, tuple) and \ len(chroma_range) == 2 and \ isinstance(chroma_range[0], int) and \ chroma_range[0] == 50 and \ isinstance(chroma_range[1], int) and \ chroma_range[1] == 100: continue elif arg is None: continue if color_column is None: error_message = ( f'{arg_name} must be None when color_column is ' f'None') raise ValueError(error_message) else: color_column_description = \ SingleCell._describe_column( 
'color_column', original_color_column) error_message = ( f'{arg_name} must be None when ' f'{color_column_description} is continuous') raise ValueError(error_message) # Check that `default_color` is a valid Matplotlib color, and convert # it to RGBA if not plt.matplotlib.colors.is_color_like(default_color): error_message = 'default_color is not a valid Matplotlib color' raise ValueError(error_message) default_color = plt.matplotlib.colors.to_rgba(default_color) # Override the defaults for certain keys of `scatter_kwargs` default_scatter_kwargs = dict(rasterized=True, linewidths=0, edgecolors=(0, 0, 0, 0)) scatter_kwargs = default_scatter_kwargs | scatter_kwargs \ if scatter_kwargs is not None else default_scatter_kwargs # Check that `scatter_kwargs` does not contain the `s`, `c`/`color`, or # `cmap` keys if 's' in scatter_kwargs: error_message = ( "'s' cannot be specified as a key in scatter_kwargs; specify " "the point_size argument instead") raise ValueError(error_message) for key in 'c', 'color', 'cmap': if key in scatter_kwargs: error_message = ( f'{key!r} cannot be specified as a key in scatter_kwargs; ' f'specify the color_column, colormap, lightness_range, ' f'chroma_range, hue_range, first_color, stride, and/or ' f'default_color arguments instead') raise ValueError(error_message) # If `label=True`, check that `color_column` is discrete. # If `label=False`, check that `label_kwargs` is `None`. 
check_type(label, 'label', bool, 'Boolean') if label: if color_column is None: error_message = 'color_column cannot be None when label=True' raise ValueError(error_message) if not discrete: color_column_description = \ SingleCell._describe_column('color_column', original_color_column) error_message = ( f'{color_column_description} cannot be continuous when ' f'label=True') raise ValueError(error_message) elif label_kwargs is not None: error_message = 'label_kwargs must be None when label=False' raise ValueError(error_message) # Only add a legend if `legend=True` and `color_column` is discrete. # If not adding a legend, check that `legend_kwargs` is `None`. check_type(legend, 'legend', bool, 'Boolean') add_legend = legend and color_column is not None and discrete if not add_legend and legend_kwargs is not None: if color_column is None: error_message = \ 'legend_kwargs must be None when color_column is None' raise ValueError(error_message) else: color_column_description = SingleCell._describe_column( 'color_column', original_color_column) error_message = ( f'legend_kwargs must be None when ' f'{color_column_description} is continuous') raise ValueError(error_message) # Only add a colorbar if `colorbar=True` and `color_column` is # continuous. If not adding a colorbar, check that `colorbar_kwargs` is # `None`. check_type(colorbar, 'colorbar', bool, 'Boolean') add_colorbar = colorbar and color_column is not None and not discrete if not add_colorbar and colorbar_kwargs is not None: if color_column is None: error_message = \ 'colorbar_kwargs must be None when color_column is None' raise ValueError(error_message) else: color_column_description = SingleCell._describe_column( 'color_column', original_color_column) error_message = ( f'colorbar_kwargs must be None when ' f'{color_column_description} is discrete') raise ValueError(error_message) # Check that `title` is a string or `None`; if `None`, check that # `title_kwargs` is `None` as well. 
Ditto for `xlabel` and `ylabel`. for arg, arg_name, arg_kwargs in ( (title, 'title', title_kwargs), (xlabel, 'xlabel', xlabel_kwargs), (ylabel, 'ylabel', ylabel_kwargs)): if arg is not None: check_type(arg, arg_name, str, 'a string') elif arg_kwargs is not None: error_message = \ f'{arg_name}_kwargs must be None when {arg_name} is None' raise ValueError(error_message) # Check that `xlim` and `ylim` are be length-2 tuples or `None`, with # the first element less than the second for arg, arg_name in (xlim, 'xlim'), (ylim, 'ylim'): if arg is not None: check_type(arg, arg_name, tuple, 'a length-2 tuple') if len(arg) != 2: error_message = ( f'{arg_name} must be a length-2 tuple, but has length ' f'{len(arg):,}') raise ValueError(error_message) if arg[0] >= arg[1]: error_message = \ f'{arg_name}[0] must be less than {arg_name}[1]' raise ValueError(error_message) # If `color_column` is `None`, plot all cells in `default_color`. If # `colormap` is a dictionary, generate an explicit list of colors to # plot each cell in. If `colormap` is a Colormap, just pass it as the # cmap` argument. 
If `colormap` is missing and `color_column` is # continuous, set it to `plt.rcParams['image.cmap']` ('viridis' by # default) if color_column is None: c = default_color cmap = None elif isinstance(colormap, dict): # Fill both missing values and values missing from `colormap` with # `default_color` if color_column.dtype == pl.String: color_column = color_column.cast(pl.Categorical) categories = color_column.cat.get_categories() lookup = np.array([colormap.get(cat, default_color) for cat in categories], dtype=np.float32) c = lookup[color_column.to_physical().fill_null(0)] c[color_column.is_null()] = default_color cmap = None else: # Need to `copy()` because `set_bad()` is in-place c = color_column.to_numpy() if colormap is not None: cmap = colormap.copy() cmap.set_bad(default_color) else: # `color_column` is continuous cmap = plt.rcParams['image.cmap'] # Check that `despine` is Boolean check_type(despine, 'despine', bool, 'Boolean') # If `ax` is `None`, create a new figure; otherwise, check that it is a # Matplotlib axis make_new_figure = ax is None try: if make_new_figure: default_figure_kwargs = \ dict(figsize=(8, 6), layout='constrained') figure_kwargs = default_figure_kwargs | figure_kwargs \ if figure_kwargs is not None else default_figure_kwargs plt.figure(**figure_kwargs) ax = plt.gca() else: check_type(ax, 'ax', plt.Axes, 'a Matplotlib axis') # Make a scatter plot of the embedding with equal x-y aspect # ratios. If `color_column` is discrete (and so `colormap` is a # dictionary), plot one color at a time, in order of decreasing # frequency, so rarer cell types end up on top of more common # ones. This also lets each `ax.scatter()` call use Matplotlib's # fast path for when the color is scalar (i.e. all points being # plotted are the same color). Cells with missing (`null`) values, # or with values not in `colormap`, are plotted first (at the # bottom) in `default_color`. 
if color_column is not None and discrete and \ isinstance(colormap, dict): color_values = color_column.to_numpy() point_size_per_cell = isinstance(point_size, pl.Series) if point_size_per_cell: point_size = point_size.to_numpy() unknown_mask = ~np.isin(color_values, tuple(colormap)) if unknown_mask.any(): scatter = ax.scatter( embedding[unknown_mask, 0], embedding[unknown_mask, 1], s=point_size[unknown_mask] if point_size_per_cell else point_size, color=default_color, **scatter_kwargs) ordered_labels = color_column\ .value_counts(sort=True)\ .to_series()\ .drop_nulls() for color_label in ordered_labels: if color_label not in colormap: continue mask = color_values == color_label scatter = ax.scatter( embedding[mask, 0], embedding[mask, 1], s=point_size[mask] if point_size_per_cell else point_size, color=colormap[color_label], **scatter_kwargs) else: scatter = ax.scatter(embedding[:, 0], embedding[:, 1], s=point_size, c=c, cmap=cmap, **scatter_kwargs) ax.set_aspect('equal') # Add the title, axis labels and axis limits if title is not None: if title_kwargs is None: ax.set_title(title) else: ax.set_title(title, **title_kwargs) if xlabel is not None: if xlabel_kwargs is None: ax.set_xlabel(xlabel) else: ax.set_xlabel(xlabel, **xlabel_kwargs) if ylabel is not None: if ylabel_kwargs is None: ax.set_ylabel(ylabel) else: ax.set_ylabel(ylabel, **ylabel_kwargs) if xlim is not None: ax.set_xlim(*xlim) if ylim is not None: ax.set_ylim(*ylim) # Add the legend; override the defaults for certain values of # `legend_kwargs` if add_legend: default_legend_kwargs = dict( loc='center left', bbox_to_anchor=(1, 0.5), frameon=False, ncols=(len(unique_color_column) + 23) // 24) legend_kwargs = default_legend_kwargs | legend_kwargs \ if legend_kwargs is not None else default_legend_kwargs if isinstance(colormap, dict): for color_label, color in colormap.items(): ax.add_artist(plt.Line2D([], [], color=color, label=color_label, marker='o', markersize=4, linewidth=0)) 
plt.legend(**legend_kwargs) else: plt.legend(*scatter.legend_elements(), **legend_kwargs) # Add the colorbar; override the defaults for certain keys of # `colorbar_kwargs` if add_colorbar: default_colorbar_kwargs = dict(shrink=0.5, pad=0.01) colorbar_kwargs = default_colorbar_kwargs | colorbar_kwargs \ if colorbar_kwargs is not None else default_colorbar_kwargs cbar = plt.colorbar(scatter, ax=ax, **colorbar_kwargs) cbar.outline.set_visible(False) # Label cells; override the defaults for certain keys of # `label_kwargs` if label: from matplotlib.patheffects import withStroke if label_kwargs is None: label_kwargs = {} label_kwargs |= dict( horizontalalignment=label_kwargs.pop( 'horizontalalignment', label_kwargs.pop('ha', 'center')), verticalalignment=label_kwargs.pop( 'verticalalignment', label_kwargs.pop('va', 'center')), path_effects=[withStroke(linewidth=3, foreground='white', alpha=0.75)]) median_coordinates = pl.DataFrame({ 'x': embedding[:, 0], 'y': embedding[:, 1], 'color_column': color_column})\ .group_by('color_column')\ .agg(pl.median('x', 'y'))\ .drop_nulls()\ .sort('color_column') for color_label, median_x, median_y in \ median_coordinates.iter_rows(): ax.text(median_x, median_y, color_label, **label_kwargs) # Despine, if specified if despine: spines = ax.spines spines['top'].set_visible(False) spines['right'].set_visible(False) # Save, if `filename` is not `None`; override the defaults for # certain keys of `savefig_kwargs` if filename is not None: default_savefig_kwargs = \ dict(dpi=300, bbox_inches='tight', pad_inches='layout', transparent=filename is not None and filename.endswith('.pdf')) savefig_kwargs = default_savefig_kwargs | savefig_kwargs \ if savefig_kwargs is not None else default_savefig_kwargs with warnings.catch_warnings(): warnings.simplefilter('ignore', UserWarning) plt.savefig(filename, **savefig_kwargs) if make_new_figure: plt.close() except: # If we made a new figure, make sure to close it if there's an # exception (but not if there 
was no error and `filename` is # `None`, in case the user wants to modify it further before # saving) if make_new_figure: plt.close() raise