scallops.visualize.distribution.comparative_effect_scatter
- scallops.visualize.distribution.comparative_effect_scatter(df_x, x_label, df_y, y_label, effect_size_column, highlight_mode='query', highlight_query=5, regression_params=None, legend_cols=4, match_axes=True, transform=None, colors=('purple', 'orange', 'blue'), top_right_spines=False, xeqy=None, axes_fontsize=20, ax=None, grouping='group', randomseed=42, **kwargs)
Plots and compares the effect sizes of two treatments with advanced highlighting and regression.
This function creates a scatter plot of effect sizes from two dataframes (e.g., two different treatments). It provides two main modes for highlighting points of interest:
query: Highlights points based on independent criteria applied to each dataframe, such as the top/bottom N values or a pandas query string.
regression: Fits a linear regression to the data and highlights points based on their relationship to the trend. Outliers are defined as points outside a prediction interval (PI), while other notable points (e.g., those at the extremes of the trend) are also highlighted.
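The outlier rule described for the regression mode can be illustrated with a small standalone sketch. It is not the library's implementation: it assumes an ordinary least squares fit and the textbook prediction interval for simple linear regression, and the helper name `prediction_interval_outliers` is hypothetical.

```python
import numpy as np
from scipy import stats


def prediction_interval_outliers(x, y, pi=0.95):
    """Return a boolean mask of points outside the OLS prediction interval.

    Hypothetical helper for illustration only; not part of scallops.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = x.size

    # Ordinary least squares fit: y ~ intercept + slope * x
    slope, intercept = np.polyfit(x, y, deg=1)
    resid = y - (intercept + slope * x)

    # Residual standard error with n - 2 degrees of freedom
    s = np.sqrt(np.sum(resid**2) / (n - 2))

    # Half-width of the two-sided prediction interval at each observed x
    sxx = np.sum((x - x.mean()) ** 2)
    t_crit = stats.t.ppf(0.5 + pi / 2.0, df=n - 2)
    half_width = t_crit * s * np.sqrt(1.0 + 1.0 / n + (x - x.mean()) ** 2 / sxx)

    return np.abs(resid) > half_width


rng = np.random.default_rng(0)
x = rng.normal(size=50)
y = 2.0 * x + rng.normal(scale=0.1, size=50)
y[0] += 5.0  # inject one obvious outlier

mask = prediction_interval_outliers(x, y, pi=0.95)
```

Here `pi` plays the same role as the 'pi' key of highlight_query: a wider interval (for example 0.99) flags fewer points as outliers.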
A key feature is the consolidated regression_params dictionary, which provides a clean, extensible interface for controlling the regression analysis and its visual representation.
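To see why a robust method may be preferable to OLS, the following sketch compares the two on synthetic data using scikit-learn directly. It assumes that the 'ransac' method corresponds to scikit-learn's `RANSACRegressor` (suggested by the return type, but not confirmed here); the data and the threshold value are made up for illustration.

```python
import numpy as np
from sklearn.linear_model import LinearRegression, RANSACRegressor

rng = np.random.default_rng(0)
x = rng.normal(size=60)
y = 1.5 * x + rng.normal(scale=0.2, size=60)
y[:5] += 8.0  # a handful of gross outliers

X = x.reshape(-1, 1)  # scikit-learn expects a 2-D feature matrix

# Plain least squares uses every point, outliers included
ols = LinearRegression().fit(X, y)

# RANSAC fits on random subsets and keeps the model with the most inliers;
# points with an absolute residual above residual_threshold are not inliers
ransac = RANSACRegressor(residual_threshold=0.6, random_state=42).fit(X, y)

ols_slope = ols.coef_[0]
ransac_slope = ransac.estimator_.coef_[0]
inliers = ransac.inlier_mask_  # boolean mask, True for consensus points
```

A larger `residual_threshold` admits more points as inliers, which matches the description of 'residual_threshold' as making the fit less strict. Fixing `random_state` makes the random subset selection reproducible.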
- Parameters:
df_x (DataFrame) – DataFrame for the first treatment (x-axis).
x_label (str) – Label for the x-axis.
df_y (DataFrame) – DataFrame for the second treatment (y-axis).
y_label (str) – Label for the y-axis.
effect_size_column (str) – The name of the column representing the effect size.
highlight_mode (str) – Method for highlighting. Either 'query' or 'regression'.
highlight_query (str | int | dict) –
Defines what to highlight.
- If highlight_mode='query': An int for the top/bottom N points, or a str query.
- If highlight_mode='regression': A dict, e.g., {'pi': 0.95, 'n_closest': 10}. 'pi' sets the prediction interval, and 'n_closest' sets the number of points to highlight within the interval.
regression_params (dict | None) –
A dictionary to control regression fitting and plotting. If None, no regression is performed. Expected keys:
- 'method' (str): 'ols' (Ordinary Least Squares, default) or 'ransac' (for robust regression).
- 'residual_threshold' (float): For RANSAC, the maximum residual for a point to be considered an inlier. A larger value makes the fit less strict.
- 'ci_style' (str): 'fill' or 'line' to draw the prediction interval.
- 'line_kws' (dict): Keyword arguments for styling the regression line (e.g., {'color': 'red'}).
- 'ci_kws' (dict): Keyword arguments for styling the PI (e.g., {'alpha': 0.1}).
legend_cols (int) – Number of columns in the legend.
match_axes (bool) – If True, sets x and y axis limits to be the same.
transform (Callable | None) – An optional function to apply to the effect size column.
colors (Sequence[str]) – A list of three colors for highlighting: [Both, X-only, Y-only].
top_right_spines (bool) – If False, hides the top and right plot borders.
xeqy (dict | None) – A dictionary of keywords for plotting a y=x line (e.g., {'color': 'grey'}).
axes_fontsize (int) – Font size for the x and y axis labels.
ax (Axes | None) – An existing Matplotlib Axes object to plot on. If None, one is created.
grouping (str) – The column name that contains the labels for the points (e.g., “Gene”).
randomseed (int) – The random seed for RANSAC regression.
kwargs – Additional keyword arguments passed to adjust_text for label placement.
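As a rough illustration of how the query mode and the three highlight colors relate, the sketch below selects points from each dataframe independently and then sorts the labels into "both", "X-only", and "Y-only" sets. It uses plain pandas and is not the library's code: the helper `select_highlights` is hypothetical, and it assumes that an int N means the N largest plus the N smallest values.

```python
import pandas as pd


def select_highlights(df, column, query):
    """Return the rows to highlight for one dataframe.

    Hypothetical helper: an int selects the `query` largest and `query`
    smallest values of `column`; a str is passed to DataFrame.query.
    """
    if isinstance(query, int):
        top = df.nlargest(query, column)
        bottom = df.nsmallest(query, column)
        return pd.concat([top, bottom]).drop_duplicates()
    return df.query(query)


df_x = pd.DataFrame(
    {"group": list("abcdef"), "effect": [3.0, -2.5, 0.1, 0.2, -0.1, 1.0]}
)
df_y = pd.DataFrame(
    {"group": list("abcdef"), "effect": [2.8, 0.3, -0.2, 0.1, -3.0, 0.9]}
)

# Criteria are applied to each dataframe independently
hx = set(select_highlights(df_x, "effect", 1)["group"])
hy = set(select_highlights(df_y, "effect", "effect > 2 or effect < -2")["group"])

# The three sets map onto the three entries of `colors`
both = hx & hy
x_only = hx - hy
y_only = hy - hx
```

With the default colors, points in `both` would be drawn in purple, `x_only` in orange, and `y_only` in blue.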
- Returns:
A tuple containing the Matplotlib Figure (may be None), the Axes, and the RANSACRegressor (may be None), matching the return type listed below.
- Example:
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    from scallops.visualize.distribution import comparative_effect_scatter

    # Creating sample data
    df_x = pd.DataFrame(
        {
            "Gene": [f"Gene{i}" for i in range(1, 11)],
            "Effect_Size": np.random.randn(10),
        }
    )
    df_y = pd.DataFrame(
        {
            "Gene": [f"Gene{i}" for i in range(1, 11)],
            "Effect_Size": df_x["Effect_Size"] + np.random.randn(10) * 0.5,
        }
    )

    # Define regression parameters
    regression_config = {
        "method": "ransac",
        # The column is named "Effect_Size", so index it by that name
        "residual_threshold": df_y["Effect_Size"].std(),
        "ci_style": "line",
        "line_kws": {"color": "black", "linestyle": "--", "linewidth": 1.5},
    }

    # Plotting; the documented return type has three elements
    fig, ax, regressor = comparative_effect_scatter(
        df_x=df_x,
        x_label="Treatment A",
        df_y=df_y,
        y_label="Treatment B",
        effect_size_column="Effect_Size",
        grouping="Gene",
        highlight_mode="regression",
        highlight_query={"pi": 0.95, "n_closest": 4},
        regression_params=regression_config,
        match_axes=True,
        xeqy={"color": "grey", "linestyle": ":"},
    )
    plt.show()
- Return type:
tuple[Figure | None, Axes, RANSACRegressor | None]