Source code for brisc.de

from __future__ import annotations
import numpy as np
import os
import polars as pl
import signal
import warnings
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Callable
from .utils import bonferroni, check_bounds, check_dtype, check_type, \
    fdr, plural, to_tuple_checked


[docs] class DE: """ Differential expression results returned by `Pseudobulk.DE()`. """
def __init__(self,
             source: str | Path | None = None,
             /, *,
             table: pl.DataFrame | None = None,
             voom_weights: dict[str, pl.DataFrame] | None = None,
             voom_plot_data: dict[str, pl.DataFrame] | None = None) -> \
        None:
    """
    Initialize the DE object, either by loading it from a directory
    written by `save()` or from an in-memory table (plus optional voom
    info).

    Args:
        source: a directory containing a DE object saved with `save()`.
                Mutually exclusive with `table`, `voom_weights`, and
                `voom_plot_data`.
        table: a polars DataFrame containing the DE results, with columns:
               - cell_type: the cell type in which DE was tested
               - coefficient: the coefficient (or contrast) for which DE
                 was tested
               - gene: the gene for which DE was tested
               - logFC: the log2 fold change of the gene, i.e. its effect
                 size
               - SE: the standard error of the effect size
               - LCI: the lower 95% confidence interval of the effect size
               - UCI: the upper 95% confidence interval of the effect size
               - AveExpr: the gene's average expression in this cell type,
                 in log CPM
               - p: the DE p-value
               - Bonferroni: the Bonferroni-corrected DE p-value
               - FDR: the FDR q-value for the DE
               Mutually exclusive with `source`; exactly one of `source`
               and `table` must be specified.
        voom_weights: an optional {cell_type: DataFrame} dictionary of voom
                      weights, where rows are genes and columns are
                      samples. The first column of each cell type's
                      DataFrame, 'gene', contains the gene names. Must be
                      specified together with `voom_plot_data`. Mutually
                      exclusive with `source`.
        voom_plot_data: an optional {cell_type: DataFrame} dictionary of
                        info necessary to construct a voom plot with
                        `DE.plot_voom()`. Must be specified together with
                        `voom_weights`. Mutually exclusive with `source`.

    Raises:
        ValueError: if neither or both of `source` and `table` are
                    specified, if voom info is combined with `source`, if
                    only one of `voom_weights` and `voom_plot_data` is
                    specified, or if their keys differ.
        TypeError: if a key of `voom_weights` is not a string.
        NotADirectoryError: if `source` is a file.
        FileNotFoundError: if `source` does not exist.
    """
    if source is not None and table is not None:
        error_message = 'only one of source and table can be specified'
        raise ValueError(error_message)
    if source is not None:
        check_type(source, 'source', (str, Path),
                   'a string or pathlib.Path')
        if voom_plot_data is not None:
            error_message = (
                'voom_plot_data cannot be specified when source is '
                'specified')
            raise ValueError(error_message)
        if voom_weights is not None:
            error_message = (
                'voom_weights cannot be specified when source is '
                'specified')
            raise ValueError(error_message)
        source = str(source)
        if not os.path.isdir(source):
            if os.path.isfile(source):
                error_message = \
                    f'{source!r} must be a directory, not a file'
                raise NotADirectoryError(error_message)
            else:
                error_message = \
                    f'DE results directory {source!r} does not exist'
                raise FileNotFoundError(error_message)
        # `save()` only writes cell_types.txt (and the per-cell-type voom
        # files) when the DE object has voom info, so a missing
        # cell_types.txt means "no voom info" rather than a corrupt
        # directory; leave both voom attributes as `None` in that case.
        cell_types_file = os.path.join(source, 'cell_types.txt')
        if os.path.isfile(cell_types_file):
            # Use a context manager so the file handle is closed promptly
            with open(cell_types_file) as f:
                cell_types = [line.rstrip('\n') for line in f]
            # `save()` replaces '/' with '-' in cell-type names to get
            # valid filenames; mirror that here
            voom_weights = {cell_type: pl.read_parquet(
                os.path.join(source, f'{cell_type.replace("/", "-")}.'
                                     f'voom_weights.parquet'))
                for cell_type in cell_types}
            voom_plot_data = {cell_type: pl.read_parquet(
                os.path.join(source, f'{cell_type.replace("/", "-")}.'
                                     f'voom_plot_data.parquet'))
                for cell_type in cell_types}
        table = pl.read_parquet(os.path.join(source, 'table.parquet'))
    elif table is not None:
        check_type(table, 'table', pl.DataFrame, 'a polars DataFrame')
        if voom_weights is not None or voom_plot_data is not None:
            if voom_plot_data is None:
                error_message = (
                    'voom_plot_data must be specified when voom_weights '
                    'is specified')
                raise ValueError(error_message)
            if voom_weights is None:
                error_message = (
                    'voom_weights must be specified when voom_plot_data '
                    'is specified')
                raise ValueError(error_message)
            # Type-check both dictionaries before touching `.keys()`, so
            # a non-dict argument gives a clear error rather than an
            # AttributeError
            check_type(voom_weights, 'voom_weights', dict,
                       'a dictionary')
            check_type(voom_plot_data, 'voom_plot_data', dict,
                       'a dictionary')
            if voom_weights.keys() != voom_plot_data.keys():
                error_message = (
                    'voom_weights and voom_plot_data must have matching '
                    'cell types (keys)')
                raise ValueError(error_message)
            for key in voom_weights:
                if not isinstance(key, str):
                    error_message = (
                        f'all keys of voom_weights and voom_plot_data '
                        f'must be strings (cell types), but they contain '
                        f'a key of type {type(key).__name__!r}')
                    raise TypeError(error_message)
    else:
        error_message = 'either source or table must be specified'
        raise ValueError(error_message)
    self.table = table
    self.voom_weights = voom_weights
    self.voom_plot_data = voom_plot_data
def __repr__(self) -> str: """ Get a string representation of this DE object. Returns: A string summarizing the object. """ num_cell_types = self.table['cell_type'].n_unique() descr = ( f'DE object with {len(self.table):,} ' f'{"entries" if len(self.table) != 1 else "entry"} across ' f'{num_cell_types:,} {plural("cell type", num_cell_types)}:\n' f'{self.table}') return descr def __eq__(self, other: DE) -> bool: """ Test for equality with another DE object. Args: other: the other DE object to test for equality with Returns: Whether the two DE objects are identical. """ if not isinstance(other, DE): error_message = ( f'the left-hand operand of `==` is a DE object, but ' f'the right-hand operand has type {type(other).__name__!r}') raise TypeError(error_message) return self.table.equals(other.table) and \ (other.voom_weights is None if self.voom_weights is None else self.voom_weights.keys() == other.voom_weights.keys() and all(self.voom_weights[cell_type].equals( other.voom_weights[cell_type]) and self.voom_plot_data[cell_type].equals( other.voom_plot_data[cell_type]) for cell_type in self.voom_weights)) @property def groups(self) -> dict[str, tuple[str, ...] | None] | None: """ The groups used by `voomByGroup` for each cell type: a dictionary mapping cell type names to group names used by voomByGroup for that cell type, or `None` if voomByGroup was not used for that cell type. If `Pseudobulk.DE()` was called with `return_voom_info=False`, `groups` will be `None` instead of a dictionary. """ return {cell_type: None if 'xy_x' in data.columns else tuple(column[5:] for column in data.columns if column[:4] == 'xy_x') for cell_type, data in self.voom_plot_data.items()} \ if self.voom_plot_data is not None else None
def save(self,
         directory: str | Path,
         /, *,
         overwrite: bool = False) -> None:
    """
    Save a DE object to `directory` (which must not exist unless
    `overwrite=True`, and will be created) with the table at
    `table.parquet`.

    If the DE object contains voom info (i.e. was created with
    `return_voom_info=True` in `Pseudobulk.DE()`, the default), also saves
    each cell type's voom weights and voom plot data to
    f'{cell_type}.voom_weights.parquet' and
    f'{cell_type}.voom_plot_data.parquet' (with any '/' in the cell type
    replaced by '-'), as well as a text file, cell_types.txt, containing
    the cell types, one per line.

    If the DE object does not contain voom info, any cell_types.txt left
    in `directory` by a previous save is removed, so that the directory
    does not pair the new table with stale voom info.

    Args:
        directory: the directory to save the DE object to
        overwrite: if `False`, raises an error if the directory exists; if
                   `True`, overwrites files inside it as necessary

    Raises:
        FileExistsError: if `directory` exists and `overwrite=False`.
    """
    check_type(directory, 'directory', (str, Path),
               'a string or pathlib.Path')
    directory = str(directory)
    check_type(overwrite, 'overwrite', bool, 'Boolean')
    if not overwrite and os.path.exists(directory):
        if os.path.isfile(directory):
            error_message = (
                f'cannot save to the directory {directory!r} because it '
                f'already exists as a file')
            raise FileExistsError(error_message)
        else:
            error_message = (
                f'directory {directory!r} already exists; set '
                f'overwrite=True to overwrite')
            raise FileExistsError(error_message)
    os.makedirs(directory, exist_ok=overwrite)
    self.table.write_parquet(os.path.join(directory, 'table.parquet'))
    cell_types_file = os.path.join(directory, 'cell_types.txt')
    if self.voom_weights is None:
        # cell_types.txt is what marks a saved DE object as having voom
        # info; when overwriting an older save that had voom info, remove
        # it so the stale voom files are not associated with this table
        try:
            os.remove(cell_types_file)
        except FileNotFoundError:
            pass
        return
    with open(cell_types_file, 'w') as f:
        print('\n'.join(self.voom_weights), file=f)
    for cell_type, weights in self.voom_weights.items():
        # '/' cannot appear in a filename, so replace it with '-'
        escaped_cell_type = cell_type.replace('/', '-')
        weights.write_parquet(
            os.path.join(directory, f'{escaped_cell_type}.'
                                    f'voom_weights.parquet'))
        self.voom_plot_data[cell_type].write_parquet(
            os.path.join(directory, f'{escaped_cell_type}.'
                                    f'voom_plot_data.parquet'))
def get_hits(self,
             *,
             significance_column: str = 'FDR',
             threshold: int | float | np.integer | np.floating = 0.05,
             num_top_hits: int | np.integer | None = None) -> pl.DataFrame:
    """
    Return the rows of `self.table` that pass a significance threshold,
    optionally keeping only the first few per cell type.

    Args:
        significance_column: the name of a floating-point column of
                             `self.table` to determine significance from
        threshold: the significance threshold; rows with
                   `significance_column` strictly below it are kept
        num_top_hits: if not `None`, keep only the first `num_top_hits`
                      passing rows of each cell type, in the order they
                      appear in `self.table`; if `None`, keep all of them

    Returns:
        The `table` attribute of this DE object, subset to (top) DE hits.

    Raises:
        ValueError: if `significance_column` is not a column of
                    `self.table`.
    """
    # Validate the arguments
    check_type(significance_column, 'significance_column', str,
               'a string')
    if significance_column not in self.table:
        error_message = (
            f'significance_column ({significance_column!r}) is not a '
            f'column of self.table')
        raise ValueError(error_message)
    check_dtype(self.table[significance_column],
                f'self.table[{significance_column!r}]', 'floating-point')
    check_type(threshold, 'threshold', (int, float),
               'a number > 0 and ≤ 1')
    check_bounds(threshold, 'threshold', 0, 1, left_open=True)
    if num_top_hits is not None:
        check_type(num_top_hits, 'num_top_hits', int,
                   'a positive integer')
        check_bounds(num_top_hits, 'num_top_hits', 1)

    # Keep the significant rows, then optionally truncate per cell type
    hits = self.table.filter(pl.col(significance_column) < threshold)
    if num_top_hits is None:
        return hits
    per_cell_type = hits.group_by('cell_type', maintain_order=True)
    return per_cell_type.head(num_top_hits)
def get_num_hits(self,
                 *,
                 significance_column: str = 'FDR',
                 threshold: int | float | np.integer | np.floating = 0.05) \
        -> pl.DataFrame:
    """
    Get the number of differentially expressed genes in each cell type.

    Args:
        significance_column: the name of a floating-point column of
                             `self.table` to determine significance from
        threshold: the significance threshold corresponding to
                   `significance_column`; rows strictly below it count as
                   hits

    Returns:
        A DataFrame with one row per cell type in `self.table` (including
        cell types with zero hits) and two columns: 'cell_type' and
        'num_hits'. Rows are sorted by cell type.

    Raises:
        ValueError: if `significance_column` is not a column of
                    `self.table`.
    """
    check_type(significance_column, 'significance_column', str,
               'a string')
    if significance_column not in self.table:
        error_message = (
            f'significance_column ({significance_column!r}) is not a '
            f'column of self.table')
        raise ValueError(error_message)
    check_dtype(self.table[significance_column],
                f'self.table[{significance_column!r}]', 'floating-point')
    check_type(threshold, 'threshold', (int, float),
               'a number > 0 and ≤ 1')
    check_bounds(threshold, 'threshold', 0, 1, left_open=True)
    # Count hits inside the aggregation instead of filtering first:
    # filtering before grouping would drop cell types with no hits
    # entirely, whereas summing the Boolean mask reports them with
    # num_hits == 0 (null significance values are ignored by the sum).
    is_hit = pl.col(significance_column) < threshold
    return self.table\
        .lazy()\
        .group_by('cell_type')\
        .agg(num_hits=is_hit.sum())\
        .sort('cell_type')\
        .collect()
def plot_voom(self,
              cell_type: str,
              filename: str | Path | None = None,
              /, *,
              ax: 'Axes' | None = None,
              figure_kwargs: dict[str, Any] | None = None,
              point_color: Color | dict[str, Color] | None = None,
              point_size: int | float | np.integer | np.floating |
                          dict[str, int | float | np.integer |
                                    np.floating] = 1,
              line_color: Color | dict[str, Color] | None = None,
              line_width: int | float | np.integer | np.floating |
                          dict[str, int | float | np.integer |
                                    np.floating] = 1.5,
              scatter_kwargs: dict[str, Any] | None |
                              dict[str, dict[str, Any] | None] = None,
              plot_kwargs: dict[str, Any] | None |
                           dict[str, dict[str, Any] | None] = None,
              legend: bool = True,
              legend_kwargs: dict[str, Any] | None = None,
              title: str | None = None,
              title_kwargs: dict[str, Any] | None = None,
              xlabel: str | None = 'Average log2(count + 0.5)',
              xlabel_kwargs: dict[str, Any] | None = None,
              ylabel: str | None = 'sqrt(standard deviation)',
              ylabel_kwargs: dict[str, Any] | None = None,
              despine: bool = True,
              savefig_kwargs: dict[str, Any] | None = None) -> None:
    """
    Generate a voom plot for a cell type that differential expression was
    calculated for.

    Voom plots consist of a scatter plot with one point per gene. They
    visualize how the mean expression of each gene across samples (x)
    relates to the gene's variation in expression across samples (y). The
    plot also includes a LOESS (also called LOWESS) curve, a type of
    non-linear curve fit, of the mean-variance (x-y) trend.

    Specifically, the x position of a gene's point is the average, across
    samples, of the base-2 logarithm of the gene's count in each sample,
    plus a pseudocount of 0.5: in other words, mean(log2(count + 0.5)). The
    y position is the square root of the standard deviation, across
    samples, of the gene's log counts per million after regressing out,
    across samples, the differential expression design matrix.

    When running differential expression with voomByGroup, voom is run
    separately within each group, so the voom plot will show a separate
    LOESS trendline for each group, with the points and trendlines for each
    group shown in distinct colors. Many arguments to this function can be
    either a single value or a dictionary mapping group names to values.
    The group names can be viewed with `self.groups[cell_type]`.

    Args:
        cell_type: the cell type to generate the voom plot for
        filename: the file to save to. If `None`, generate the plot but do
                  not save it, which allows it to be shown interactively or
                  modified further before saving.
        ax: the Matplotlib axes to save the plot onto; if `None`, create a
            new figure with Matplotlib's constrained layout and plot onto
            it
        figure_kwargs: a dictionary of keyword arguments to be passed to
                       `plt.figure` when `ax` is `None`, such as:
                       - `figsize`: a two-element sequence of the width and
                         height of the figure in inches. Defaults to
                         `[6.4, 4.8]`.
                       - `layout`: the layout mechanism used by Matplotlib
                         to avoid overlapping plot elements. Defaults to
                         `'constrained'`, instead of Matplotlib's default
                         of `None`.
        point_color: the color of the points in the voom plot. Can be a
                     single color or a dictionary mapping each of the group
                     names in `self.groups[cell_type]` to colors. When not
                     using voomByGroup, defaults to `'#666666'` (gray).
                     When using voomByGroup with two groups, defaults to
                     `'#666666'` for the first group in
                     `self.groups[cell_type]` and `'#FF6666'` (red) for the
                     second. When using voomByGroup with more than two
                     groups, must be specified manually. Can be any valid
                     Matplotlib color, like a hex string (e.g.
                     `'#FF0000'`), a named color (e.g. 'red'), a 3- or
                     4-element RGB/RGBA tuple of integers 0-255 or floats
                     0-1, or a single float 0-1 for grayscale.
        point_size: the size of the points in the voom plot. Can be a
                    single number or a dictionary mapping each of the group
                    names in `self.groups[cell_type]` to numbers.
        line_color: the color of the LOESS trendline. Can be a single color
                    or a dictionary mapping each of the group names in
                    `self.groups[cell_type]` to colors. When not using
                    voomByGroup, defaults to `'#000000'` (black). When
                    using voomByGroup with two groups, defaults to
                    `'#000000'` for the first group and `'#FF0000'` (red)
                    for the second. When using voomByGroup with more than
                    two groups, must be specified manually. Can be any
                    valid Matplotlib color, like a hex string (e.g.
                    `'#FF0000'`), a named color (e.g. 'red'), a 3- or
                    4-element RGB/RGBA tuple of integers 0-255 or floats
                    0-1, or a single float 0-1 for grayscale.
        line_width: the width of the LOESS trendline. Can be a single
                    number or a dictionary mapping each of the group names
                    in `self.groups[cell_type]` to numbers.
        scatter_kwargs: a dictionary (or dictionary mapping each of the
                        group names in `self.groups[cell_type]` to
                        dictionaries) of keyword arguments to be passed to
                        `ax.scatter()`, such as:
                        - `rasterized`: whether to convert the scatter plot
                          points to a raster (bitmap) image when saving to
                          a vector format like PDF. Defaults to `True`,
                          instead of Matplotlib's default of `False`.
                        - `marker`: the shape to use for plotting each
                          point
                        - `alpha`: the transparency of each point
                        - `linewidths` and `edgecolors`: the width and
                          color of the borders around each marker. These
                          are absent by default (`linewidths=0`,
                          `edgecolors=(0, 0, 0, 0)`), unlike Matplotlib's
                          default. Both arguments can be either single
                          values or sequences.
                        - `zorder`: the order in which the points are
                          plotted, with higher values appearing on top of
                          lower ones.
                        Specifying `s` or `c`/`color`/`norm`/`vmin`/`vmax`
                        will raise an error, since these arguments conflict
                        with the `point_size` and `point_color` arguments,
                        respectively. The dictionary passed in is not
                        modified.
        plot_kwargs: a dictionary (or dictionary mapping each of the group
                     names in `self.groups[cell_type]` to dictionaries) of
                     keyword arguments to be passed to `ax.plot()` when
                     plotting the trendlines, such as `linestyle='--'` for
                     dashed trendlines. Specifying
                     `c`/`color`/`norm`/`vmin`/`vmax` or `linewidth` will
                     raise an error, since these arguments conflict with
                     the `line_color` and `line_width` arguments,
                     respectively. The dictionary passed in is not
                     modified.
        legend: whether to add a legend with the colors for each group when
                using voomByGroup. Only `legend=False` has an effect, and
                it can only be specified when using voomByGroup. Without
                groups, there will never be a legend, so specifying
                `legend=False` would be redundant.
        legend_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.legend()` to modify the legend, such as:
                       - `loc`, `bbox_to_anchor`, and `bbox_transform` to
                         set its location.
                       - `prop`, `fontsize`, and `labelcolor` to set its
                         font properties
                       - `facecolor` and `framealpha` to set its background
                         color and transparency
                       - `frameon=True` or `edgecolor` to add or color
                         its border. `frameon` defaults to `False`, instead
                         of Matplotlib's default of `True`.
                       - `title` to add a legend title
                       Can only be specified when using voomByGroup with
                       `legend=True`.
        title: the title of the plot, or `None` to not add a title
        title_kwargs: a dictionary of keyword arguments to be passed to
                      `ax.set_title()` to control text properties, such as
                      `color` and `size`. Can only be specified when
                      `title` is not `None`.
        xlabel: the x-axis label, or `None` to not label the x-axis
        xlabel_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.set_xlabel()` to control text properties, such
                       as `color` and `size`. Can only be specified when
                       `xlabel` is not `None`.
        ylabel: the y-axis label, or `None` to not label the y-axis
        ylabel_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.set_ylabel()` to control text properties, such
                       as `color` and `size`. Can only be specified when
                       `ylabel` is not `None`.
        despine: whether to remove the top and right spines (borders of the
                 plot area) from the voom plot
        savefig_kwargs: a dictionary of keyword arguments to be passed to
                        `plt.savefig()`, such as:
                        - `dpi`: defaults to 300
                        - `bbox_inches`: the bounding box of the portion of
                          the figure to save; defaults to 'tight' (crop out
                          any blank borders) instead of Matplotlib's
                          default of `None` (save the entire figure)
                        - `pad_inches`: the number of inches of padding to
                          add on each of the four sides of the figure when
                          saving. Defaults to 'layout' (use the padding
                          from the constrained layout engine), instead of
                          Matplotlib's default of 0.1.
                        - `transparent`: whether to save with a transparent
                          background; defaults to `True` if `filename` ends
                          with '.pdf' and `False` otherwise, instead of
                          Matplotlib's default of always being `False`.
                        Can only be specified when `filename` is specified.

    Raises:
        AttributeError: if this DE object has no `voom_plot_data`.
        ValueError: if `cell_type` is not in this DE object, if a color is
                    invalid or cannot be defaulted, or if arguments are
                    specified in an inconsistent combination.
        TypeError: if a kwargs dictionary has a non-string key.
        NotADirectoryError: if the directory of `filename` does not exist.
    """
    # Import matplotlib, ignoring Ctrl + C for the duration of the import
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        import matplotlib.pyplot as plt
    finally:
        signal.signal(signal.SIGINT, signal.default_int_handler)

    # Check that this DE object contains `voom_plot_data`, the data
    # necessary to generate the voom plot from (scatter-plot points and
    # LOESS trendlines)
    if self.voom_plot_data is None:
        error_message = (
            'this DE object does not contain the voom_plot_data '
            'attribute, which is necessary to generate voom plots; re-run '
            'Pseudobulk.DE() with return_voom_info=True to include this '
            'attribute')
        raise AttributeError(error_message)

    # Check that `cell_type` is a cell type in this DE object
    check_type(cell_type, 'cell_type', str, 'a string')
    if cell_type not in self.voom_plot_data:
        error_message = \
            f'cell_type {cell_type!r} is not a cell type in this DE object'
        raise ValueError(error_message)

    # Get the voom plot data for this cell type
    voom_plot_data = self.voom_plot_data[cell_type]

    # Get the voomByGroup groups for this cell type (`None` if voomByGroup
    # was not used)
    groups = None if 'xy_x' in voom_plot_data.columns else \
        tuple(column[5:] for column in voom_plot_data.columns
              if column[:4] == 'xy_x')

    # If `filename` was specified, check that it is a string or
    # `pathlib.Path` and that its base directory exists; if `filename` is
    # `None`, make sure `savefig_kwargs` is also `None`
    if filename is not None:
        check_type(filename, 'filename', (str, Path),
                   'a string or pathlib.Path')
        directory = os.path.dirname(filename)
        if directory and not os.path.isdir(directory):
            error_message = (
                f'{filename} refers to a file in the directory '
                f'{directory!r}, but this directory does not exist')
            raise NotADirectoryError(error_message)
        filename = str(filename)
    elif savefig_kwargs is not None:
        error_message = 'savefig_kwargs must be None when filename is None'
        raise ValueError(error_message)

    # If `figure_kwargs` was specified, check that `ax` is `None`
    if figure_kwargs is not None and ax is not None:
        error_message = (
            'figure_kwargs must be None when ax is not None, since a new '
            'figure does not need to be generated when plotting onto an '
            'existing axis')
        raise ValueError(error_message)

    # Check that `point_color` and `line_color` are valid Matplotlib colors
    # or dictionaries thereof, and convert them to hex. Or, if `None`, set
    # to their default value if there are no groups or exactly two groups.
    def resolve_color(color, color_name, defaults):
        # `defaults` holds the default colors of the first and second
        # group; the first is also the default when there are no groups
        if color is None:
            if groups is None:
                return defaults[0]
            if len(groups) == 2:
                return dict(zip(groups, defaults))
            error_message = (
                f'{color_name} must be specified manually when there are '
                f'three or more groups; here, there are {len(groups)!r}')
            raise ValueError(error_message)
        is_color_like = plt.matplotlib.colors.is_color_like
        to_hex = plt.matplotlib.colors.to_hex
        if isinstance(color, dict):
            for group, group_color in color.items():
                if not is_color_like(group_color):
                    error_message = (
                        f'{color_name}[{group!r}] is not a valid '
                        f'Matplotlib color or sequence of valid colors')
                    raise ValueError(error_message)
            return {group: to_hex(group_color)
                    for group, group_color in color.items()}
        if not is_color_like(color):
            error_message = (
                f'{color_name} is not a valid Matplotlib color or '
                f'sequence of valid colors')
            raise ValueError(error_message)
        return to_hex(color)

    point_color = resolve_color(point_color, 'point_color',
                                ('#666666', '#FF6666'))
    point_color_is_dict = isinstance(point_color, dict)
    line_color = resolve_color(line_color, 'line_color',
                               ('#000000', '#FF0000'))
    line_color_is_dict = isinstance(line_color, dict)

    # Check that `point_size` and `line_width` are positive numbers or
    # dicts thereof
    point_size_is_dict = isinstance(point_size, dict)
    line_width_is_dict = isinstance(line_width, dict)
    for number, number_name, is_dict in (
            (point_size, 'point_size', point_size_is_dict),
            (line_width, 'line_width', line_width_is_dict)):
        if is_dict:
            for group, group_number in number.items():
                check_type(group_number, f'{number_name}[{group!r}]',
                           (int, float), 'a positive number')
                check_bounds(group_number, f'{number_name}[{group!r}]',
                             0, left_open=True)
        else:
            check_type(number, number_name, (int, float),
                       'a positive number')
            check_bounds(number, number_name, 0, left_open=True)

    # For each of the kwargs arguments, if the argument was specified,
    # check that it is a dictionary and that all its keys are strings.
    for kwargs, kwargs_name in ((figure_kwargs, 'figure_kwargs'),
                                (scatter_kwargs, 'scatter_kwargs'),
                                (plot_kwargs, 'plot_kwargs'),
                                (legend_kwargs, 'legend_kwargs'),
                                (xlabel_kwargs, 'xlabel_kwargs'),
                                (ylabel_kwargs, 'ylabel_kwargs'),
                                (title_kwargs, 'title_kwargs')):
        if kwargs is not None:
            check_type(kwargs, kwargs_name, dict, 'a dictionary')
            for key in kwargs:
                if not isinstance(key, str):
                    error_message = (
                        f'all keys of {kwargs_name} must be strings, but '
                        f'it contains a key of type '
                        f'{type(key).__name__!r}')
                    raise TypeError(error_message)

    # If using voomByGroup, for each of `scatter_kwargs` and `plot_kwargs`,
    # if the kwarg was specified, check that either all keys are group
    # names and in the correct order, or that no keys are group names. If
    # all keys are group names, check that all values are either `None` or
    # dictionaries with all-string keys, and make note that the kwargs is a
    # nested dict.
    scatter_kwargs_is_nested_dict = False
    plot_kwargs_is_nested_dict = False
    if groups is not None:
        for kwargs, kwargs_name in ((scatter_kwargs, 'scatter_kwargs'),
                                    (plot_kwargs, 'plot_kwargs')):
            if kwargs is None:
                continue
            if tuple(kwargs) == groups:
                # The kwargs's keys exactly match the group names
                for key, value in kwargs.items():
                    if value is not None:
                        check_type(value, f'{kwargs_name}[{key!r}]', dict,
                                   'a dictionary')
                        for inner_key in value:
                            if not isinstance(inner_key, str):
                                error_message = (
                                    f'all keys of '
                                    f'{kwargs_name}[{key!r}] must be '
                                    f'strings, but it contains a key '
                                    f'of type '
                                    f'{type(inner_key).__name__!r}')
                                raise TypeError(error_message)
                if kwargs_name == 'scatter_kwargs':
                    scatter_kwargs_is_nested_dict = True
                else:
                    plot_kwargs_is_nested_dict = True
            else:
                # Check that none of the kwargs's keys are group names
                for group in groups:
                    if group in kwargs:
                        if set(groups) == set(kwargs):
                            error_message = (
                                f'{kwargs_name}.keys() does have the '
                                f'same groups as '
                                f'self.groups[{cell_type!r}], but '
                                f'they are in a different order')
                            raise ValueError(error_message)
                        else:
                            error_message = (
                                f'some keys of {kwargs_name}.keys() '
                                f'are groups in '
                                f'self.groups[{cell_type!r}], but '
                                f'others are not')
                            raise ValueError(error_message)

    # Override the defaults for certain keys of `scatter_kwargs`. Build new
    # dictionaries rather than assigning into the caller's dictionary, so
    # the argument the user passed in is left untouched.
    default_scatter_kwargs = dict(rasterized=True, linewidths=0,
                                  edgecolors=(0, 0, 0, 0))
    if scatter_kwargs_is_nested_dict:
        scatter_kwargs = {
            group: default_scatter_kwargs | group_kwargs
            if group_kwargs is not None else default_scatter_kwargs
            for group, group_kwargs in scatter_kwargs.items()}
    else:
        scatter_kwargs = default_scatter_kwargs | scatter_kwargs \
            if scatter_kwargs is not None else default_scatter_kwargs

    # Set `plot_kwargs` to `{}` if it is `None`, or set the `None` values
    # of `plot_kwargs` to `{}` if `plot_kwargs` is a nested dict (again
    # without modifying the caller's dictionary)
    if plot_kwargs is None:
        plot_kwargs = {}
    elif plot_kwargs_is_nested_dict:
        plot_kwargs = {
            group: group_kwargs if group_kwargs is not None else {}
            for group, group_kwargs in plot_kwargs.items()}

    # Check that `scatter_kwargs` does not contain the `s` or
    # `c`/`color`/`norm`/`vmin`/`vmax` keys and that `plot_kwargs` does
    # not contain the `c`/`color`/`norm`/`vmin`/`vmax` or `linewidth` keys,
    # or that their non-`None` values do not contain these keys if a nested
    # dict. Each error points the user to the argument that replaces the
    # forbidden key: `point_*` for the scatter plot, `line_*` for the
    # trendline.
    for kwargs, kwargs_name, size_key, size_argument, color_argument, \
            is_nested_dict in (
                (scatter_kwargs, 'scatter_kwargs', 's', 'point_size',
                 'point_color', scatter_kwargs_is_nested_dict),
                (plot_kwargs, 'plot_kwargs', 'linewidth', 'line_width',
                 'line_color', plot_kwargs_is_nested_dict)):
        bad_keys = ((size_key, size_argument),
                    ('c', color_argument),
                    ('color', color_argument),
                    ('norm', color_argument),
                    ('vmin', color_argument),
                    ('vmax', color_argument))
        if is_nested_dict:
            for key, value in kwargs.items():
                if value is not None:
                    for bad_key, alternate_argument in bad_keys:
                        if bad_key in value:
                            error_message = (
                                f'{bad_key!r} cannot be specified as a '
                                f'key in {kwargs_name}[{key!r}]; specify '
                                f'the {alternate_argument} argument '
                                f'instead')
                            raise ValueError(error_message)
        elif kwargs is not None:
            for bad_key, alternate_argument in bad_keys:
                if bad_key in kwargs:
                    error_message = (
                        f'{bad_key!r} cannot be specified as a key in '
                        f'{kwargs_name}; specify the {alternate_argument} '
                        f'argument instead')
                    raise ValueError(error_message)

    # Check that `legend` is Boolean. If not using voomByGroup, check that
    # the user did not specify `legend=False`.
    check_type(legend, 'legend', bool, 'Boolean')
    if groups is None:
        if not legend:
            error_message = (
                'legend=False cannot be specified when there are no '
                'groups, since it would be redundant: without groups, '
                'there will never be a legend')
            raise ValueError(error_message)

    # Override the defaults for certain values of `legend_kwargs`; check
    # that it is `None` when not using a legend
    default_legend_kwargs = dict(frameon=False)
    if legend_kwargs is not None:
        if groups is None:
            error_message = (
                'legend_kwargs cannot be specified when there are no '
                'groups, since there will not be a legend')
            raise ValueError(error_message)
        if not legend:
            error_message = \
                'legend_kwargs cannot be specified when legend=False'
            raise ValueError(error_message)
        legend_kwargs = default_legend_kwargs | legend_kwargs
    else:
        legend_kwargs = default_legend_kwargs

    # If `title` was specified, check that it is a string
    if title is not None:
        check_type(title, 'title', str, 'a string')

    # Check that `title_kwargs` is `None` when `title` is `None`
    if title is None and title_kwargs is not None:
        error_message = 'title_kwargs cannot be specified when title=None'
        raise ValueError(error_message)

    # Check that `xlabel` is a string or `None`; if `None`, check that
    # `xlabel_kwargs` is `None` as well. Ditto for `ylabel`.
    for arg, arg_name, arg_kwargs in (
            (xlabel, 'xlabel', xlabel_kwargs),
            (ylabel, 'ylabel', ylabel_kwargs)):
        if arg is not None:
            check_type(arg, arg_name, str, 'a string')
        elif arg_kwargs is not None:
            error_message = \
                f'{arg_name}_kwargs must be None when {arg_name} is None'
            raise ValueError(error_message)

    # Check that `despine` is Boolean
    check_type(despine, 'despine', bool, 'Boolean')

    # Override the defaults for certain values of `savefig_kwargs`
    default_savefig_kwargs = \
        dict(dpi=300, bbox_inches='tight', pad_inches='layout',
             transparent=filename is not None and
                         filename.endswith('.pdf'))
    savefig_kwargs = default_savefig_kwargs | savefig_kwargs \
        if savefig_kwargs is not None else default_savefig_kwargs

    # If `ax` is `None`, create a new figure with
    # `constrained_layout=True`; otherwise, check that it is a Matplotlib
    # axis
    make_new_figure = ax is None
    try:
        if make_new_figure:
            default_figure_kwargs = dict(layout='constrained')
            figure_kwargs = default_figure_kwargs | figure_kwargs \
                if figure_kwargs is not None else default_figure_kwargs
            plt.figure(**figure_kwargs)
            ax = plt.gca()
        else:
            check_type(ax, 'ax', plt.Axes, 'a Matplotlib axis')
        if groups is not None:
            if legend:
                legend_patches = []
            for group in groups:
                # Get this group's point size, point color, line color,
                # line width, plot kwargs, and scatter kwargs
                group_point_size = point_size[group] \
                    if point_size_is_dict else point_size
                group_point_color = point_color[group] \
                    if point_color_is_dict else point_color
                group_line_color = line_color[group] \
                    if line_color_is_dict else line_color
                group_line_width = line_width[group] \
                    if line_width_is_dict else line_width
                group_scatter_kwargs = scatter_kwargs[group] \
                    if scatter_kwargs_is_nested_dict else scatter_kwargs
                group_plot_kwargs = plot_kwargs[group] \
                    if plot_kwargs_is_nested_dict else plot_kwargs
                # Plot the scatter plot for this group
                ax.scatter(voom_plot_data[f'xy_x_{group}'].drop_nulls(),
                           voom_plot_data[f'xy_y_{group}'].drop_nulls(),
                           s=group_point_size, c=group_point_color,
                           **group_scatter_kwargs)
                # Plot the LOESS trendline for this group
                ax.plot(voom_plot_data[f'line_x_{group}'].drop_nulls(),
                        voom_plot_data[f'line_y_{group}'].drop_nulls(),
                        c=group_line_color, linewidth=group_line_width,
                        **group_plot_kwargs)
                # Create a rectangle for the legend for this group, where
                # the border matches the color of the trendline and the
                # fill matches the color of the scatter plot points
                if legend:
                    legend_patches.append(plt.matplotlib.patches.Patch(
                        facecolor=group_point_color,
                        edgecolor=group_line_color,
                        linewidth=group_line_width, label=group))
            # Add the legend
            if legend:
                ax.legend(handles=legend_patches, **legend_kwargs)
        else:
            # Plot the scatter plot
            ax.scatter(voom_plot_data['xy_x'], voom_plot_data['xy_y'],
                       s=point_size, c=point_color, **scatter_kwargs)
            # Plot the LOESS trendline
            ax.plot(voom_plot_data['line_x'], voom_plot_data['line_y'],
                    c=line_color, linewidth=line_width, **plot_kwargs)
        # Add the title and axis labels
        if xlabel is not None:
            if xlabel_kwargs is None:
                xlabel_kwargs = {}
            ax.set_xlabel(xlabel, **xlabel_kwargs)
        if ylabel is not None:
            if ylabel_kwargs is None:
                ylabel_kwargs = {}
            ax.set_ylabel(ylabel, **ylabel_kwargs)
        # `title` was validated above to be a string or `None`; `None`
        # means "no title", as documented
        if title is not None:
            if title_kwargs is None:
                title_kwargs = {}
            ax.set_title(title, **title_kwargs)
        # Despine, if specified
        if despine:
            spines = ax.spines
            spines['top'].set_visible(False)
            spines['right'].set_visible(False)
        # Save
        if filename is not None:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', UserWarning)
                plt.savefig(filename, **savefig_kwargs)
            if make_new_figure:
                plt.close()
    except BaseException:
        # If we made a new figure, make sure to close it if there's an
        # exception (but not if there was no error and `filename` is
        # `None`, in case the user wants to modify it further before
        # saving)
        if make_new_figure:
            plt.close()
        raise
def plot_volcano(self,
                 cell_type: str,
                 filename: str | Path | None = None,
                 /,
                 *,
                 ax: 'Axes' | None = None,
                 figure_kwargs: dict[str, Any] | None = None,
                 significance_column: str = 'p',
                 threshold: int | float | np.integer | np.floating = 0.05,
                 genes_to_label: int | np.integer | str | Iterable[str] |
                                 None = 10,
                 label_kwargs: dict[str, Any] | None = None,
                 upregulated_size: int | float | np.integer |
                                   np.floating = 6,
                 downregulated_size: int | float | np.integer |
                                     np.floating = 6,
                 non_significant_size: int | float | np.integer |
                                       np.floating = 4,
                 upregulated_color: Color = '#FC4E07',
                 downregulated_color: Color = '#00AFBB',
                 non_significant_color: Color = 'lightgray',
                 upregulated_scatter_kwargs: dict[str, Any] | None = None,
                 downregulated_scatter_kwargs: dict[str, Any] | None = None,
                 non_significant_scatter_kwargs: dict[str, Any] |
                                                 None = None,
                 legend: bool = True,
                 legend_kwargs: dict[str, Any] | None = None,
                 title: str | None = None,
                 title_kwargs: dict[str, Any] | None = None,
                 xlabel: str | None = '$log_2(FC)$',
                 xlabel_kwargs: dict[str, Any] | None = None,
                 ylabel: str | None = '$-log_{10}(p)$',
                 ylabel_kwargs: dict[str, Any] | None = None,
                 despine: bool = True,
                 savefig_kwargs: dict[str, Any] | None = None) -> None:
    """
    Generate a volcano plot of the DE hits, with negative log p-values (or
    another `significance_column`) on the y-axis plotted against log fold
    changes on the x-axis. Upregulated, downregulated and non-significant
    genes are plotted in three different colors.

    Args:
        cell_type: the cell type to generate the volcano plot for
        filename: the file to save to. If `None`, generate the plot but do
                  not save it, which allows it to be shown interactively
                  or modified further before saving.
        ax: the Matplotlib axes to save the plot onto; if `None`, create a
            new figure with Matplotlib's constrained layout and plot onto
            it
        figure_kwargs: a dictionary of keyword arguments to be passed to
                       `plt.figure` when `ax` is `None`, such as:
                       - `figsize`: a two-element sequence of the width
                         and height of the figure in inches. Defaults to
                         Matplotlib's default of `[6.4, 4.8]`.
                       - `layout`: the layout mechanism used by Matplotlib
                         to avoid overlapping plot elements. Defaults to
                         `'constrained'`, instead of Matplotlib's default
                         of `None`.
        significance_column: the name of a numeric column of `self.table`
                             to determine significance from
        threshold: the significance threshold corresponding to
                   `significance_column`
        genes_to_label: a non-negative integer number of top DE genes to
                        label, a name or sequence of names of genes to
                        label, or `None` to not add labels. If an integer,
                        only DE genes will be labeled, even if
                        `genes_to_label` is larger than the number of DE
                        genes.
        label_kwargs: a dictionary of keyword arguments to be passed to
                      `textalloc.allocate()` when adding gene labels to
                      control the text properties, such as:
                      - `textcolor`/`textsize`: the text color and size
                      - `x_scatter`/`y_scatter`: the x/y coordinates of
                        points in the scatter plot, to repel labels away
                        from. Defaults to all points in the plot.
                      - `min_distance`/`max_distance`: the minimum and
                        maximum distances from each point to its label, as
                        a proportion of the width of the x-axis. Defaults
                        to 0 and 0.02, instead of textalloc's defaults of
                        0.015 and 0.2
                      - `draw_lines`: whether to draw lines between each
                        label and its corresponding point. Defaults to
                        `False`, instead of textalloc's default of `True`.
                      See
                      [here](https://github.com/ckjellson/textalloc#parameters)
                      for the full list of possible arguments. Can only be
                      specified when `genes_to_label` is neither 0 nor
                      `None`.
        upregulated_size: the size of each upregulated gene's point
        downregulated_size: the size of each downregulated gene's point
        non_significant_size: the size of each non-significant gene's
                              point
        upregulated_color: the color of each upregulated gene's point. Can
                           be any valid Matplotlib color, like a hex
                           string (e.g. `'#FF0000'`), a named color (e.g.
                           'red'), a 3- or 4-element RGB/RGBA tuple of
                           integers 0-255 or floats 0-1, or a single float
                           0-1 for grayscale.
        downregulated_color: the color of each downregulated gene's point.
                             Can be any valid Matplotlib color, like a hex
                             string (e.g. `'#FF0000'`), a named color
                             (e.g. 'red'), a 3- or 4-element RGB/RGBA
                             tuple of integers 0-255 or floats 0-1, or a
                             single float 0-1 for grayscale.
        non_significant_color: the color of each non-significant gene's
                               point. Can be any valid Matplotlib color,
                               like a hex string (e.g. `'#FF0000'`), a
                               named color (e.g. 'red'), a 3- or 4-element
                               RGB/RGBA tuple of integers 0-255 or floats
                               0-1, or a single float 0-1 for grayscale.
        upregulated_scatter_kwargs: a dictionary of keyword arguments to
                                    be passed to `ax.scatter()` for
                                    upregulated genes, such as:
            - `rasterized`: whether to convert the scatter plot points to
              a raster (bitmap) image when saving to a vector format like
              PDF. Defaults to `True`, instead of Matplotlib's default of
              `False`.
            - `marker`: the shape to use for plotting each gene
            - `alpha`: the transparency of each point
            - `linewidths` and `edgecolors`: the width and color of the
              borders around each marker. These are absent by default
              (`linewidths=0`, `edgecolors=(0, 0, 0, 0)`), unlike
              Matplotlib's default. Both arguments can be either single
              values or sequences.
            - `zorder`: the order in which the genes are plotted, with
              higher values appearing on top of lower ones.
            Specifying `s` or `c`/`color`/`cmap`/`norm`/`vmin`/`vmax` will
            raise an error, since these arguments conflict with the
            `upregulated_size` and `upregulated_color` arguments,
            respectively.
        downregulated_scatter_kwargs: a dictionary of keyword arguments to
                                      be passed to `ax.scatter()` for
                                      downregulated genes; see the
                                      documentation of the
                                      `upregulated_scatter_kwargs`
                                      argument for details
        non_significant_scatter_kwargs: a dictionary of keyword arguments
                                        to be passed to `ax.scatter()` for
                                        non-significant genes; see the
                                        documentation of the
                                        `upregulated_scatter_kwargs`
                                        argument for details
        legend: whether to add a legend showing the marker style for
                upregulated, downregulated, and non-significant points
        legend_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.legend()` to modify the legend, such as:
                       - `loc`, `bbox_to_anchor`, and `bbox_transform` to
                         set its location.
                       - `prop`, `fontsize`, and `labelcolor` to set its
                         font properties
                       - `facecolor` and `framealpha` to set its
                         background color and transparency
                       - `frameon=True` or `edgecolor` to add or color its
                         border. `frameon` defaults to `False`, instead of
                         Matplotlib's default of `True`.
                       - `title` to add a legend title
                       Can only be specified when `legend=True`.
        title: the title of the plot, or `None` to not add a title
        title_kwargs: a dictionary of keyword arguments to be passed to
                      `ax.set_title()` to control text properties, such as
                      `color` and `size`. Can only be specified when
                      `title` is not `None`.
        xlabel: the x-axis label, or `None` to not label the x-axis
        xlabel_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.set_xlabel()` to control text properties, such
                       as `color` and `size`. Can only be specified when
                       `xlabel` is not `None`.
        ylabel: the y-axis label, or `None` to not label the y-axis
        ylabel_kwargs: a dictionary of keyword arguments to be passed to
                       `ax.set_ylabel()` to control text properties, such
                       as `color` and `size`. Can only be specified when
                       `ylabel` is not `None`.
        despine: whether to remove the top and right spines (borders of
                 the plot area) from the volcano plot
        savefig_kwargs: a dictionary of keyword arguments to be passed to
                        `plt.savefig()`, such as:
                        - `dpi`: defaults to 300 instead of Matplotlib's
                          default of `'figure'` (the figure's own dpi)
                        - `bbox_inches`: the bounding box of the portion
                          of the figure to save; defaults to 'tight' (crop
                          out any blank borders) instead of Matplotlib's
                          default of `None` (save the entire figure)
                        - `pad_inches`: the number of inches of padding to
                          add on each of the four sides of the figure when
                          saving. Defaults to 'layout' (use the padding
                          from the constrained layout engine), instead of
                          Matplotlib's default of 0.1.
                        - `transparent`: whether to save with a
                          transparent background; defaults to `True` if
                          `filename` ends with `'.pdf'` and `False`
                          otherwise, instead of Matplotlib's default of
                          always being `False`.
                        Can only be specified when `filename` is
                        specified.

    Raises:
        ValueError: if `cell_type` or `significance_column` is not present
                    in `self.table`, if a requested gene is not found, if
                    `genes_to_label` is a negative integer, if a color is
                    not a valid Matplotlib color, if a `*_scatter_kwargs`
                    dictionary contains a size or color key, or if a
                    `*_kwargs` argument is given when the feature it
                    configures is disabled.
        NotADirectoryError: if the directory part of `filename` does not
                            exist.
        TypeError: if a `*_kwargs` dictionary has a non-string key.
    """
    # Import matplotlib
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        import matplotlib.pyplot as plt
    finally:
        signal.signal(signal.SIGINT, signal.default_int_handler)

    # Check that `cell_type` is a cell type in this DE object
    check_type(cell_type, 'cell_type', str, 'a string')
    if cell_type not in self.table['cell_type']:
        error_message = \
            f'cell_type {cell_type!r} is not a cell type in this DE object'
        raise ValueError(error_message)

    # If `filename` was specified, check that it is a string or
    # `pathlib.Path` and that its base directory exists; if `filename` is
    # `None`, make sure `savefig_kwargs` is also `None`
    if filename is not None:
        check_type(filename, 'filename', (str, Path),
                   'a string or pathlib.Path')
        directory = os.path.dirname(filename)
        if directory and not os.path.isdir(directory):
            error_message = (
                f'{filename} refers to a file in the directory '
                f'{directory!r}, but this directory does not exist')
            raise NotADirectoryError(error_message)
        filename = str(filename)
    elif savefig_kwargs is not None:
        error_message = 'savefig_kwargs must be None when filename is None'
        raise ValueError(error_message)

    # If `figure_kwargs` was specified, check that `ax` is `None`
    if figure_kwargs is not None and ax is not None:
        error_message = (
            'figure_kwargs must be None when ax is not None, since a new '
            'figure does not need to be generated when plotting onto an '
            'existing axis')
        raise ValueError(error_message)

    # For each of the kwargs arguments, if the argument was specified,
    # check that it is a dictionary and that all its keys are strings.
    # This runs before any of the dictionaries are inspected or merged
    # with defaults, so that a non-dictionary argument produces this
    # error rather than an opaque failure from `in` or `|` further down.
    for kwargs, kwargs_name in (
            (figure_kwargs, 'figure_kwargs'),
            (label_kwargs, 'label_kwargs'),
            (upregulated_scatter_kwargs, 'upregulated_scatter_kwargs'),
            (downregulated_scatter_kwargs,
             'downregulated_scatter_kwargs'),
            (non_significant_scatter_kwargs,
             'non_significant_scatter_kwargs'),
            (legend_kwargs, 'legend_kwargs'),
            (title_kwargs, 'title_kwargs'),
            (xlabel_kwargs, 'xlabel_kwargs'),
            (ylabel_kwargs, 'ylabel_kwargs'),
            (savefig_kwargs, 'savefig_kwargs')):
        if kwargs is not None:
            check_type(kwargs, kwargs_name, dict, 'a dictionary')
            for key in kwargs:
                if not isinstance(key, str):
                    error_message = (
                        f'all keys of {kwargs_name} must be strings, but '
                        f'it contains a key of type '
                        f'{type(key).__name__!r}')
                    raise TypeError(error_message)

    # Check that `significance_column` is the name of a floating-point
    # column in `self.table`
    check_type(significance_column, 'significance_column', str, 'a string')
    if significance_column not in self.table:
        error_message = (
            f'significance_column ({significance_column!r}) is not a '
            f'column of self.table')
        raise ValueError(error_message)
    check_dtype(self.table[significance_column],
                f'self.table[{significance_column!r}]', 'floating-point')

    # Check that `threshold` is greater than 0 and less than or equal to 1
    check_type(threshold, 'threshold', (int, float),
               'a number > 0 and ≤ 1')
    check_bounds(threshold, 'threshold', 0, 1, left_open=True)

    # Subset `self.table` to the selected cell type, and log-transform the
    # significance column and the threshold. After this point, larger
    # values of the significance column are more significant.
    table = self.table.filter(cell_type=cell_type)\
        .with_columns(-pl.col(significance_column).log10())
    threshold = -np.log10(threshold)

    # Work out which genes (if any) to label. `label` records whether
    # `textalloc` needs to be called at all; `x_to_label`, `y_to_label`
    # and `genes_to_label` are only defined when `label` is `True`.
    if genes_to_label is None:
        # `None` means "no labels"
        label = False
        if label_kwargs is not None:
            error_message = (
                'label_kwargs cannot be specified when genes_to_label is '
                'None, since no genes are being labeled')
            raise ValueError(error_message)
    elif isinstance(genes_to_label, (int, np.integer)):
        if genes_to_label < 0:
            error_message = (
                f'genes_to_label is {genes_to_label}, but must be a '
                f'non-negative integer when it is an integer')
            raise ValueError(error_message)
        if genes_to_label == 0:
            label = False
            if label_kwargs is not None:
                error_message = (
                    'label_kwargs cannot be specified when '
                    'genes_to_label=0, since no genes are being labeled')
                raise ValueError(error_message)
        else:
            # Take that many gene names, up to the number of DE genes
            num_DE_genes = (table[significance_column] >= threshold).sum()
            k = int(min(genes_to_label, num_DE_genes))
            # With no significant genes there is nothing to label, so
            # skip the textalloc call rather than invoke it on empty input
            label = k > 0
            if label:
                top_DE_genes = table.top_k(k, by=significance_column)
                x_to_label = top_DE_genes['logFC']
                y_to_label = top_DE_genes[significance_column]
                genes_to_label = top_DE_genes['gene']
    else:
        label = True
        genes_to_label = to_tuple_checked(
            genes_to_label, 'genes_to_label', str, 'strings')
        # Left-join so that genes missing from the table show up as nulls
        genes_to_label = pl.DataFrame({'gene': genes_to_label})\
            .join(table.select('gene', 'logFC', significance_column),
                  how='left', on='gene')
        num_missing = genes_to_label['logFC'].null_count()
        if num_missing == len(genes_to_label):
            error_message = (
                "none of the specified genes were found in "
                "table['gene']")
            raise ValueError(error_message)
        elif num_missing > 0:
            gene = genes_to_label\
                .filter(pl.col.logFC.is_null())['gene'][0]
            error_message = (
                f"one of the specified genes, {gene!r}, was not found "
                f"in table['gene']")
            raise ValueError(error_message)
        x_to_label = genes_to_label['logFC']
        y_to_label = genes_to_label[significance_column]
        genes_to_label = genes_to_label['gene']
    if label:
        import textalloc

    # Check that `upregulated_size`, `downregulated_size`, and
    # `non_significant_size` are positive numbers
    for size, size_name in (upregulated_size, 'upregulated_size'), \
            (downregulated_size, 'downregulated_size'), \
            (non_significant_size, 'non_significant_size'):
        check_type(size, size_name, (int, float), 'a positive number')
        check_bounds(size, size_name, 0, left_open=True)

    # Check that `upregulated_color`, `downregulated_color`, and
    # `non_significant_color` are valid Matplotlib colors, and convert them
    # to hex
    if not plt.matplotlib.colors.is_color_like(upregulated_color):
        error_message = 'upregulated_color is not a valid Matplotlib color'
        raise ValueError(error_message)
    upregulated_color = plt.matplotlib.colors.to_hex(upregulated_color)
    if not plt.matplotlib.colors.is_color_like(downregulated_color):
        error_message = \
            'downregulated_color is not a valid Matplotlib color'
        raise ValueError(error_message)
    downregulated_color = plt.matplotlib.colors.to_hex(downregulated_color)
    if not plt.matplotlib.colors.is_color_like(non_significant_color):
        error_message = \
            'non_significant_color is not a valid Matplotlib color'
        raise ValueError(error_message)
    non_significant_color = \
        plt.matplotlib.colors.to_hex(non_significant_color)

    # Check that the three `scatter_kwargs` arguments do not contain
    # the `s` or `c`/`color`/`cmap`/`norm`/`vmin`/`vmax` keys
    for kwargs, kwargs_prefix in (
            (upregulated_scatter_kwargs, 'upregulated'),
            (downregulated_scatter_kwargs, 'downregulated'),
            (non_significant_scatter_kwargs, 'non_significant')):
        if kwargs is None:
            continue
        if 's' in kwargs:
            error_message = (
                f"'s' cannot be specified as a key in "
                f"{kwargs_prefix}_scatter_kwargs; specify the "
                f"{kwargs_prefix}_size argument instead")
            raise ValueError(error_message)
        for key in 'c', 'color', 'cmap', 'norm', 'vmin', 'vmax':
            if key in kwargs:
                error_message = (
                    f'{key!r} cannot be specified as a key in '
                    f'{kwargs_prefix}_scatter_kwargs; specify the '
                    f'{kwargs_prefix}_color argument instead')
                raise ValueError(error_message)

    # Override the defaults for certain values of the three
    # `scatter_kwargs` arguments
    default_scatter_kwargs = dict(rasterized=True, linewidths=0,
                                  edgecolors=(0, 0, 0, 0))
    upregulated_scatter_kwargs = \
        default_scatter_kwargs | upregulated_scatter_kwargs \
        if upregulated_scatter_kwargs is not None else \
        default_scatter_kwargs
    downregulated_scatter_kwargs = \
        default_scatter_kwargs | downregulated_scatter_kwargs \
        if downregulated_scatter_kwargs is not None else \
        default_scatter_kwargs
    non_significant_scatter_kwargs = \
        default_scatter_kwargs | non_significant_scatter_kwargs \
        if non_significant_scatter_kwargs is not None else \
        default_scatter_kwargs

    # Check that `title` is a string or `None`; if `None`, check that
    # `title_kwargs` is `None` as well. Ditto for `xlabel` and `ylabel`.
    for arg, arg_name, arg_kwargs in (
            (title, 'title', title_kwargs),
            (xlabel, 'xlabel', xlabel_kwargs),
            (ylabel, 'ylabel', ylabel_kwargs)):
        if arg is not None:
            check_type(arg, arg_name, str, 'a string')
        elif arg_kwargs is not None:
            error_message = \
                f'{arg_name}_kwargs must be None when {arg_name} is None'
            raise ValueError(error_message)

    # Check that `legend` and `despine` are Boolean
    check_type(legend, 'legend', bool, 'Boolean')
    check_type(despine, 'despine', bool, 'Boolean')

    # If `ax` is `None`, create a new figure; otherwise, check that it is a
    # Matplotlib axis
    make_new_figure = ax is None
    try:
        if make_new_figure:
            default_figure_kwargs = dict(layout='constrained')
            figure_kwargs = default_figure_kwargs | figure_kwargs \
                if figure_kwargs is not None else default_figure_kwargs
            plt.figure(**figure_kwargs)
            ax = plt.gca()
        else:
            check_type(ax, 'ax', plt.Axes, 'a Matplotlib axis')

        # Make the volcano plot: one scatter call per category, so that
        # each gets its own size, color and legend entry
        ax.scatter(*table
                   .select('logFC', significance_column)
                   .filter(pl.col(significance_column) >= threshold,
                           pl.col.logFC > 0)
                   .to_numpy()
                   .T,
                   s=upregulated_size, c=upregulated_color,
                   label='Upregulated', **upregulated_scatter_kwargs)
        ax.scatter(*table
                   .select('logFC', significance_column)
                   .filter(pl.col(significance_column) >= threshold,
                           pl.col.logFC < 0)
                   .to_numpy()
                   .T,
                   s=downregulated_size, c=downregulated_color,
                   label='Downregulated', **downregulated_scatter_kwargs)
        ax.scatter(*table
                   .select('logFC', significance_column)
                   .filter(pl.col(significance_column) < threshold)
                   .to_numpy()
                   .T,
                   s=non_significant_size, c=non_significant_color,
                   label='Non-significant',
                   **non_significant_scatter_kwargs)
        ax.set_ylim(bottom=0)

        # Add labels, using textalloc to avoid overlap
        if label:
            default_label_kwargs = dict(
                ax=ax, x=x_to_label, y=y_to_label,
                text_list=genes_to_label,
                x_scatter=table['logFC'].to_numpy(),
                y_scatter=table[significance_column].to_numpy(),
                min_distance=0, max_distance=0.02, draw_lines=False)
            label_kwargs = default_label_kwargs | label_kwargs \
                if label_kwargs is not None else default_label_kwargs
            textalloc.allocate(**label_kwargs)

        # Add the legend; override the defaults for certain values of
        # `legend_kwargs`
        if legend:
            default_legend_kwargs = dict(frameon=False)
            legend_kwargs = default_legend_kwargs | legend_kwargs \
                if legend_kwargs is not None else default_legend_kwargs
            ax.legend(**legend_kwargs)

        # Add the title and axis labels
        if xlabel is not None:
            if xlabel_kwargs is None:
                xlabel_kwargs = {}
            ax.set_xlabel(xlabel, **xlabel_kwargs)
        if ylabel is not None:
            if ylabel_kwargs is None:
                ylabel_kwargs = {}
            ax.set_ylabel(ylabel, **ylabel_kwargs)
        if title is not None:
            if title_kwargs is None:
                title_kwargs = {}
            ax.set_title(title, **title_kwargs)

        # Despine, if specified
        if despine:
            spines = ax.spines
            spines['top'].set_visible(False)
            spines['right'].set_visible(False)

        # Save; override the defaults for certain keys of `savefig_kwargs`
        if filename is not None:
            default_savefig_kwargs = \
                dict(dpi=300, bbox_inches='tight', pad_inches='layout',
                     transparent=filename.endswith('.pdf'))
            savefig_kwargs = default_savefig_kwargs | savefig_kwargs \
                if savefig_kwargs is not None else default_savefig_kwargs
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', UserWarning)
                plt.savefig(filename, **savefig_kwargs)
            if make_new_figure:
                plt.close()
    except:
        # If we made a new figure, make sure to close it if there's an
        # exception (but not if there was no error and `filename` is
        # `None`, in case the user wants to modify it further before
        # saving)
        if make_new_figure:
            plt.close()
        raise