scallops.visualize.distribution.comparative_effect_scatter

scallops.visualize.distribution.comparative_effect_scatter(df_x, x_label, df_y, y_label, effect_size_column, highlight_mode='query', highlight_query=5, regression_params=None, legend_cols=4, match_axes=True, transform=None, colors=('purple', 'orange', 'blue'), top_right_spines=False, xeqy=None, axes_fontsize=20, ax=None, grouping='group', randomseed=42, **kwargs)

Plots and compares the effect sizes of two treatments, with optional highlighting of points of interest and linear regression.

This function creates a scatter plot of effect sizes from two dataframes (e.g., two different treatments). It provides two main modes for highlighting points of interest:

  1. query: Highlights points based on independent criteria applied to each dataframe, such as the top/bottom N values or a pandas query string.

  2. regression: Fits a linear regression to the data and highlights points based on their relationship to the trend. Outliers are defined as points outside a prediction interval (PI), while other notable points (e.g., those at the extremes of the trend) are also highlighted.
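The query mode above can be sketched as follows. This is a minimal illustration, not the scallops implementation. The helper name classify_query_highlights is hypothetical, and pairing the rows of df_x and df_y through the grouping column is an assumption. Each dataframe is filtered independently (top/bottom N, or a pandas query string), and each point is then labeled by which side(s) selected it.

```python
import pandas as pd


def classify_query_highlights(df_x, df_y, effect_size_column, grouping, query):
    """Label each group 'Both', 'X-only', 'Y-only', or None (not highlighted).

    `query` is either an int N (top and bottom N effect sizes per dataframe)
    or a pandas query string evaluated separately on each dataframe.
    Hypothetical helper for illustration only.
    """

    def selected(df):
        if isinstance(query, str):
            # Pandas query string applied to this dataframe alone.
            return set(df.query(query)[grouping])
        col = df[effect_size_column]
        top = df.loc[col.nlargest(query).index, grouping]
        bottom = df.loc[col.nsmallest(query).index, grouping]
        return set(top) | set(bottom)

    sel_x, sel_y = selected(df_x), selected(df_y)

    # Assumption: the two dataframes are paired via the grouping column.
    merged = df_x[[grouping, effect_size_column]].merge(
        df_y[[grouping, effect_size_column]], on=grouping, suffixes=("_x", "_y")
    )

    def label(name):
        if name in sel_x and name in sel_y:
            return "Both"
        if name in sel_x:
            return "X-only"
        if name in sel_y:
            return "Y-only"
        return None

    merged["highlight"] = merged[grouping].map(label)
    return merged
```

With the default colors, points labeled Both, X-only, and Y-only would be drawn in 'purple', 'orange', and 'blue', respectively.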

Both the regression fit and its visual representation are configured through a single regression_params dictionary, described under Parameters.

Parameters:
  • df_x (DataFrame) – DataFrame for the first treatment (x-axis).

  • x_label (str) – Label for the x-axis.

  • df_y (DataFrame) – DataFrame for the second treatment (y-axis).

  • y_label (str) – Label for the y-axis.

  • effect_size_column (str) – The name of the column representing the effect size.

  • highlight_mode (str) – Method for highlighting: either ‘query’ (default) or ‘regression’.

  • highlight_query (str | int | dict) – Defines what to highlight.

    • If highlight_mode=‘query’: an int for the top/bottom N points, or a str pandas query.

    • If highlight_mode=‘regression’: a dict, e.g., {‘pi’: 0.95, ‘n_closest’: 10}. ‘pi’ sets the prediction interval, and ‘n_closest’ sets the number of points to highlight within the interval.

  • regression_params (dict | None) – A dictionary to control regression fitting and plotting. If None, no regression is performed. Expected keys:

    • ‘method’ (str): ‘ols’ (Ordinary Least Squares, default) or ‘ransac’ (robust regression).

    • ‘residual_threshold’ (float): For RANSAC, the maximum residual for a point to be considered an inlier. A larger value makes the fit less strict.

    • ‘ci_style’ (str): ‘fill’ or ‘line’, how the prediction interval is drawn.

    • ‘line_kws’ (dict): Keyword arguments for styling the regression line (e.g., {‘color’: ‘red’}).

    • ‘ci_kws’ (dict): Keyword arguments for styling the PI (e.g., {‘alpha’: 0.1}).

  • legend_cols (int) – Number of columns in the legend.

  • match_axes (bool) – If True, sets x and y axis limits to be the same.

  • transform (Callable | None) – An optional function to apply to the effect size column.

  • colors (Sequence[str]) – A sequence of three colors for highlighted points, in the order (Both, X-only, Y-only).

  • top_right_spines (bool) – If False, hides the top and right plot borders.

  • xeqy (dict | None) – A dictionary of keyword arguments for plotting a y=x line (e.g., {‘color’: ‘grey’}).

  • axes_fontsize (int) – Font size for the x and y axis labels.

  • ax (Axes | None) – An existing Matplotlib Axes object to plot on. If None, one is created.

  • grouping (str) – The column name that contains the labels for the points (e.g., “Gene”).

  • randomseed (int) – The random seed for RANSAC regression.

  • kwargs – Additional keyword arguments passed to adjust_text for label placement.
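The prediction-interval rule behind highlight_mode=‘regression’ with method ‘ols’ can be illustrated with NumPy. This is a sketch under stated assumptions, not the scallops implementation. It uses the standard OLS prediction-interval formula but substitutes a normal quantile for the Student t quantile. It also omits the ‘n_closest’ selection, whose exact rule is not specified here. The function name ols_prediction_interval is hypothetical.

```python
from statistics import NormalDist

import numpy as np


def ols_prediction_interval(x, y, pi=0.95):
    """Fit y = intercept + slope * x by OLS and flag points outside the PI.

    Simplification: uses a normal quantile instead of the Student t quantile.
    Returns (slope, intercept, lower, upper, is_outlier).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = x.size
    slope, intercept = np.polyfit(x, y, deg=1)  # highest degree first
    fitted = intercept + slope * x
    resid = y - fitted
    s = np.sqrt(np.sum(resid**2) / (n - 2))  # residual standard error
    sxx = np.sum((x - x.mean()) ** 2)
    # Standard error for predicting a new observation at each x.
    se_pred = s * np.sqrt(1 + 1 / n + (x - x.mean()) ** 2 / sxx)
    z = NormalDist().inv_cdf(0.5 + pi / 2)  # two-sided quantile
    lower, upper = fitted - z * se_pred, fitted + z * se_pred
    is_outlier = (y < lower) | (y > upper)
    return slope, intercept, lower, upper, is_outlier
```

RANSAC, by contrast, fits on the inlier subset chosen with residual_threshold, so a single large outlier pulls its line far less than it pulls this OLS fit.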
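match_axes and xeqy amount to using one shared range for both axes and drawing the diagonal across it. The sketch below computes such a range, optionally after applying a transform. The helper shared_limits and its 5% padding are assumptions for illustration only.

```python
import numpy as np


def shared_limits(x, y, transform=None, pad_frac=0.05):
    """Return one (low, high) range covering both x and y.

    Optionally applies `transform` to both arrays first, mirroring the
    documented `transform` parameter. The padding fraction is an assumption.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if transform is not None:
        x, y = transform(x), transform(y)
    low = min(np.nanmin(x), np.nanmin(y))
    high = max(np.nanmax(x), np.nanmax(y))
    pad = (high - low) * pad_frac
    return low - pad, high + pad
```

On a Matplotlib Axes, the resulting lims could be passed to ax.set_xlim(lims), ax.set_ylim(lims), and ax.plot(lims, lims, **xeqy).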

Returns:

A tuple containing the Matplotlib Figure (or None), the Axes, and the fitted RANSACRegressor (or None).

Example:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from scallops.visualize.distribution import comparative_effect_scatter

# Creating sample data
df_x = pd.DataFrame(
    {
        "Gene": [f"Gene{i}" for i in range(1, 11)],
        "Effect_Size": np.random.randn(10),
    }
)
df_y = pd.DataFrame(
    {
        "Gene": [f"Gene{i}" for i in range(1, 11)],
        "Effect_Size": df_x["Effect_Size"] + np.random.randn(10) * 0.5,
    }
)

# Define regression parameters
regression_config = {
    "method": "ransac",
    # The column is "Effect_Size"; attribute access df_y.effect_size would fail
    "residual_threshold": df_y["Effect_Size"].std(),
    "ci_style": "line",
    "line_kws": {"color": "black", "linestyle": "--", "linewidth": 1.5},
}

# Plotting: the function returns (figure, axes, fitted RANSAC model or None)
fig, ax, model = comparative_effect_scatter(
    df_x=df_x,
    x_label="Treatment A",
    df_y=df_y,
    y_label="Treatment B",
    effect_size_column="Effect_Size",
    grouping="Gene",
    highlight_mode="regression",
    highlight_query={"pi": 0.95, "n_closest": 4},
    regression_params=regression_config,
    match_axes=True,
    xeqy={"color": "grey", "linestyle": ":"},
)
plt.show()

Return type:

tuple[Figure | None, Axes, RANSACRegressor | None]