"""Untilities for generating plots."""
from typing import Iterable, List, Optional, Tuple, Union
from carabiner import cast
from carabiner.mpl import add_legend, colorblind_palette, grid
import pandas as pd
from pandas import DataFrame
import matplotlib.pyplot as plt
from matplotlib import axes, figure
from matplotlib.container import ErrorbarContainer
from matplotlib.collections import PathCollection
from matplotlib.colors import LogNorm
import numpy as np
from scipy import stats
# Return type of the ``plot_*`` functions annotated with it. They all end in
# ``return fig, ax`` (or ``fig, axes``), so the Figure comes first.
PlotTuple = Tuple[figure.Figure, axes.Axes]

# Named matplotlib colors: the dark grey is used for the control series in
# ``plot_dose_response``; the light grey is used for guide lines.
_DARK_GREY = "dimgrey"
_LIGHT_GREY = "lightgrey"

# Base marker size; passed as ``markersize`` to errorbar plots and (scaled)
# as ``s`` to scatter plots.
_MARKER_SIZE = 3.

# Default value of the ``panel_size`` parameters.
# NOTE(review): presumably inches, per the plot_dose_response docstring.
_PANEL_SIZE = 2.5
def _plot_errbars(
    ax: axes.Axes,
    data: DataFrame,
    x: str,
    y: str,
    ci: float = .95,
    **kwargs
) -> Tuple[axes.Axes, ErrorbarContainer]:
    """Plot the mean of ``y`` at each value of ``x`` with error bars.

    Rows are grouped by ``x``; the group means are drawn as connected
    markers (``fmt='o-'``). The error bar half-width is the group standard
    error of the mean multiplied by the upper bound of the standard normal
    interval at confidence level ``ci``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    data : pandas.DataFrame
        Data containing columns ``x`` and ``y``.
    x : str
        Column to group by and plot on the x-axis.
    y : str
        Column to summarize and plot on the y-axis.
    ci : float, optional
        Confidence level passed to ``scipy.stats.norm.interval``.
        Default: 0.95.
    **kwargs
        Passed through to ``ax.errorbar``.

    Returns
    -------
    tuple
        The axes and the container returned by ``ax.errorbar``.

    """
    per_level = data.groupby([x])[[y]]
    means = per_level.mean().reset_index()
    sems = per_level.sem().reset_index()
    # Upper bound of the standard normal interval, e.g. ~1.96 for ci=0.95.
    z_score = stats.norm.interval(ci)[-1]

    if np.all(np.isnan(sems[y])):
        # No standard error is available for any level, so draw
        # zero-length bars rather than passing NaNs on.
        half_widths = 0.
    else:
        half_widths = sems[y] * z_score

    container = ax.errorbar(
        means[x],
        means[y],
        yerr=half_widths,
        fmt='o-',
        **kwargs
    )
    return ax, container
def _plot_scatter_open_circles(
    ax: axes.Axes,
    data: DataFrame,
    x: str,
    y: str,
    color: str,
    **kwargs
) -> Tuple[axes.Axes, PathCollection]:
    """Scatter ``y`` against ``x`` as unfilled markers outlined in ``color``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    data : pandas.DataFrame
        Data containing columns ``x`` and ``y``.
    x : str
        Column for the x-axis.
    y : str
        Column for the y-axis.
    color : str
        Edge color of the markers; the faces are left unfilled.
    **kwargs
        Passed through to ``ax.scatter``.

    Returns
    -------
    tuple
        The axes and the collection returned by ``ax.scatter``.

    """
    # Column names are resolved against ``data`` by matplotlib.
    open_markers = ax.scatter(
        x, y,
        data=data,
        facecolors='none',
        edgecolors=color,
        **kwargs
    )
    return ax, open_markers
def _plot_mean_and_scatter(
    ax: axes.Axes,
    data: DataFrame,
    x: str,
    y: str,
    color: str,
    ci: float = .95,
    s: float = _MARKER_SIZE,
    zorder: int = 0,
    **kwargs
) -> Tuple[axes.Axes, PathCollection, ErrorbarContainer]:
    """Overlay a mean-with-error-bars line and the individual observations.

    The summary line is drawn with ``_plot_errbars`` and the raw points with
    ``_plot_scatter_open_circles``, both in ``color`` and at the same
    ``zorder``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    data : pandas.DataFrame
        Data containing columns ``x`` and ``y``.
    x : str
        Column for the x-axis.
    y : str
        Column for the y-axis.
    color : str
        Color of the line, filled mean markers and open-circle outlines.
    ci : float, optional
        Confidence level for the error bars. Default: 0.95.
    s : float, optional
        Marker size of the mean markers; raw points use three times this
        value as their scatter size.
    zorder : int, optional
        Drawing order for both layers. Default: 0.
    **kwargs
        Passed through to the error bar layer only.

    Returns
    -------
    tuple
        The axes, the scatter collection and the error bar container.

    """
    summary_style = dict(
        color=color,
        markerfacecolor=color,
        markersize=s,
        zorder=zorder,
    )
    ax, errbars = _plot_errbars(ax, data, x, y, ci=ci, **summary_style, **kwargs)

    # The leading underscore in the label keeps the raw points out of the legend.
    ax, scatter = _plot_scatter_open_circles(
        ax, data, x, y,
        color=color,
        s=s * 3.,
        zorder=zorder,
        label='_default'
    )
    return ax, scatter, errbars
[docs]
def plot_dose_response(
    data: DataFrame,
    x: str,
    y: str,
    file_prefix: str,
    color: Optional[str] = None,
    color_control: Optional[str] = None,
    facet: Optional[str] = None,
    files: Optional[str] = None,
    hlines: Optional[Iterable[float]] = None,
    panel_size: float = _PANEL_SIZE,
    format: str = 'pdf',
    sharey: bool = False,
    sharex: bool = False,
    x_log: bool = False,
    y_log: bool = False
) -> List[str]:
    """Plot dose response curves, optionally splitting data across files, facets and colors.

    Each series is drawn as a mean line with confidence-interval error bars
    overlaid on the individual observations. Series colors are taken from the
    active matplotlib color cycle ("C0", "C1", ...), and a given level of the
    color column gets the same color in every facet and file. Rows whose
    color, facet or files value equals ``"blank"`` are dropped before
    plotting. The input ``data`` is not modified, and each figure is closed
    once it has been saved.

    Parameters
    ----------
    data : pandas.DataFrame
        Input data in columnar format.
    x : str
        Column to use as x-axis.
    y : str
        Column to use as y-axis.
    file_prefix : str
        Prefix to use in output filenames.
    color : str, optional
        If provided, use this column to split data into separate colored lines.
    color_control : str, optional
        If provided, plot this value from the color column as a dark grey.
    facet : str, optional
        If provided, split plots into separate facets (panels) based on this column.
    files : str, optional
        If provided, split plots into separate files based on this column.
    hlines : iterable of float, optional
        Plot horizontal guidelines at these y-intercepts. Default: no guidelines.
    panel_size : float, optional
        Size of a single panel (facet), passed to the grid constructor.
        Default: 2.5.
    format : str, optional
        File format to save plots, "pdf" or "png". Default: "pdf".
    sharex : bool, optional
        Whether to have shared x-axis ranges. Default: False.
    sharey : bool, optional
        Whether to have shared y-axis ranges. Default: False.
    x_log : bool, optional
        Whether to make x-axis log scale. Default: False.
    y_log : bool, optional
        Whether to make y-axis log scale. Default: False.

    Returns
    -------
    list
        Filenames in which plots were saved. Empty if no rows remain after
        dropping blanks.

    Raises
    ------
    NotImplementedError
        If ``format`` is not "pdf" or "png".

    """
    if format.lower() not in ('pdf', 'png'):
        raise NotImplementedError(f"Saving as {format=} not implemented")

    DEFAULT = '__default__'
    filenames = []
    capsize, markersize = 2., _MARKER_SIZE

    # assign() returns a new frame, so the placeholder column never leaks
    # into the caller's DataFrame.
    data = data.assign(**{DEFAULT: DEFAULT})
    color = color or DEFAULT
    facet = facet or DEFAULT
    files = files or DEFAULT
    # Materialize once: the guidelines are re-drawn on every facet, which
    # would exhaust a one-shot iterator after the first panel.
    hlines = list(hlines) if hlines is not None else []

    # Boolean masks work for any column name or value; a string-built
    # query() expression does not.
    for dim in {color, facet, files}:
        data = data[data[dim] != "blank"]

    n_facets = data[facet].unique().size
    if n_facets == 0:
        # Nothing left to plot; also avoids dividing by a zero row count.
        return filenames
    n_facet_rows = int(np.ceil(np.sqrt(n_facets)))
    n_facet_cols = int(np.ceil(n_facets / n_facet_rows))
    assert (n_facet_cols * n_facet_rows) >= n_facets, f"{n_facets} != {n_facet_rows} x {n_facet_cols}"

    # Index each color level over the whole dataset so a level keeps its
    # color even in facets where other levels are absent.
    color_index = {
        name: i for i, (name, _) in enumerate(data.groupby(color))
    }

    for file_name, file_data in data.sort_values(by=x).groupby(files):
        if files == DEFAULT:
            filename = f'{file_prefix}.{format}'
        else:
            filename = f'{file_prefix}_{files}={file_name}.{format}'

        fig, _ = grid(
            nrow=n_facet_rows,
            ncol=n_facet_cols,
            panel_size=panel_size,
            aspect_ratio=1.1,
            sharey=sharey,
            sharex=sharex,
        )
        try:
            for ax, (facet_name, facet_data) in zip(fig.axes, file_data.groupby(facet)):
                ax.set(
                    title=facet_name if facet_name != DEFAULT else '',
                    xlabel=x,
                    ylabel=y,
                    xscale='log' if x_log else "linear",
                    yscale='log' if y_log else "linear",
                )
                for y_intercept in hlines:
                    ax.axhline(y_intercept, c=_LIGHT_GREY, zorder=-5)

                if color_control is not None:
                    control_data = facet_data[facet_data[color] == color_control]
                    if not control_data.empty:
                        # The control is drawn in grey above the other series.
                        _plot_mean_and_scatter(
                            ax,
                            control_data,
                            x,
                            y,
                            color=_DARK_GREY,
                            s=markersize,
                            capsize=capsize,
                            zorder=5,
                        )

                for color_name, color_data in facet_data.groupby(color):
                    if color_name == color_control:
                        continue
                    if color_name != DEFAULT:
                        series_color = f"C{color_index[color_name]}"
                        series_label = color_name
                    else:
                        series_color = "C1"
                        series_label = ''
                    _plot_mean_and_scatter(
                        ax,
                        color_data,
                        x,
                        y,
                        color=series_color,
                        s=markersize,
                        capsize=capsize,
                        label=series_label,
                    )
                add_legend(ax)

            fig.savefig(
                filename,
                dpi=300 if format.lower() == 'png' else 'figure',
                bbox_inches='tight',
            )
        finally:
            # The figure is not returned, so release it here.
            plt.close(fig)
        filenames.append(filename)

    return filenames
[docs]
def plot_mean_sd(
    data: DataFrame,
    x: Union[str, Iterable[str]],
    y: str,
    panel_size: float = _PANEL_SIZE
) -> PlotTuple:
    """Plot positive- and negative-control means with a +/- 1 SD band per batch.

    Reads the columns ``calc_<y>_pos_mean``, ``calc_<y>_pos_sd``,
    ``calc_<y>_neg_mean`` and ``calc_<y>_neg_sd`` from ``data``. The x-axis
    is a label made by joining the ``x`` columns with ":"; rows are sorted
    by that label.

    Parameters
    ----------
    data : pandas.DataFrame
        Data containing the ``x`` columns and the four ``calc_`` columns.
    x : str or iterable of str
        Column(s) identifying a batch. Joined with the pandas ``.str``
        accessor, so they must hold strings.
    y : str
        Measurement name used to build the ``calc_`` column names and as the
        y-axis label.
    panel_size : float, optional
        Panel size; divided by the 0.7 aspect ratio before being passed to
        the grid constructor.

    Returns
    -------
    tuple
        The figure and its axes.

    """
    x = cast(x, to=list)
    stem = 'calc_' + y
    y_pos_mean, y_pos_sd = stem + '_pos_mean', stem + '_pos_sd'
    y_neg_mean, y_neg_sd = stem + '_neg_mean', stem + '_neg_sd'

    data_to_plot = data[x + [y_pos_mean, y_pos_sd, y_neg_mean, y_neg_sd]].copy()
    data_to_plot['grouping'] = data_to_plot[x[0]].str.cat(data_to_plot[x[1:]], sep=':')
    data_to_plot = data_to_plot.sort_values('grouping')

    fig, ax = grid(
        panel_size=panel_size / .7,
        aspect_ratio=.7,
    )
    series = (
        (y_pos_mean, y_pos_sd, "Positive"),
        (y_neg_mean, y_neg_sd, "Negative"),
    )
    for mean_col, sd_col, label in series:
        lower = data_to_plot[mean_col] - data_to_plot[sd_col]
        upper = data_to_plot[mean_col] + data_to_plot[sd_col]
        ax.plot(
            'grouping',
            mean_col,
            data=data_to_plot,
            label=label,
        )
        # Only the line carries the label: labelling the band as well would
        # list every series twice in the legend.
        ax.fill_between(
            'grouping',
            lower,
            upper,
            data=data_to_plot,
            alpha=.7,
        )
    ax.tick_params(
        axis='x',
        labelrotation=90,
        labelsize='small',
    )
    ax.set(
        xlabel='Batch',
        ylabel=y,
    )
    add_legend(ax)
    return fig, ax
[docs]
def plot_zprime(
    data: DataFrame,
    x: Union[str, Iterable[str]],
    y: str,
    panel_size: float = _PANEL_SIZE * 1.8
) -> PlotTuple:
    """Plot the ``calc_<y>_zprime`` column per batch as a line with markers.

    The x-axis is a label made by joining batch columns with ":". Columns of
    ``x`` whose name ends in "wavelength" are left out of the label unless
    their name contains ``y``. Light grey guide lines are drawn at 0, 0.5
    and 1.

    Parameters
    ----------
    data : pandas.DataFrame
        Data containing the ``x`` columns and ``calc_<y>_zprime``.
    x : str or iterable of str
        Column(s) identifying a batch. Joined with the pandas ``.str``
        accessor, so they must hold strings.
    y : str
        Measurement name used to build the column name and the y-axis label.
    panel_size : float, optional
        Panel size passed to the grid constructor.

    Returns
    -------
    tuple
        The figure and its axes.

    """
    x = cast(x, to=list)
    fig, ax = grid(
        panel_size=panel_size,
        aspect_ratio=.8,
    )
    point_size = _MARKER_SIZE * 2.
    zprime_col = 'calc_' + y + '_zprime'
    data_to_plot = data[x + [zprime_col]].copy()

    # Keep only the columns that should appear in the batch label.
    label_cols = []
    for col in x:
        if not col.endswith('wavelength') or y in col:
            label_cols.append(col)

    first_col, other_cols = label_cols[0], label_cols[1:]
    data_to_plot['grouping'] = data_to_plot[first_col].str.cat(
        data_to_plot[other_cols],
        sep=':',
    )
    data_to_plot = data_to_plot.sort_values('grouping')

    ax.plot('grouping', zprime_col, data=data_to_plot)
    ax.scatter('grouping', zprime_col, data=data_to_plot, s=point_size)

    for guide in np.linspace(0., 1., num=3):
        ax.axhline(guide, color='lightgray')

    ax.tick_params(axis='x', labelrotation=90, labelsize='small')
    ax.set(xlabel='Batch', ylabel=y + '_zprime')
    return fig, ax
[docs]
def plot_heatmap(
    data: DataFrame,
    x: Union[str, Iterable[str]],
    y: str,
    panel_size: float = _PANEL_SIZE
) -> PlotTuple:
    """Draw one heatmap of ``y`` per plate, arranged on a near-square grid.

    Plates are the distinct combinations of the ``x`` columns. For each
    plate, ``y`` is pivoted with ``row_id`` as rows and ``column_id`` as
    columns and shown as an image with its own colorbar. Panels left over
    after all plates are placed are hidden.

    Parameters
    ----------
    data : pandas.DataFrame
        Data containing the ``x`` columns, ``row_id``, ``column_id`` and ``y``.
    x : str or iterable of str
        Column(s) identifying a plate.
    y : str
        Column whose values are shown.
    panel_size : float, optional
        Size of a single panel, passed to the grid constructor.

    Returns
    -------
    tuple
        The figure and the axes returned by the grid constructor.

    Raises
    ------
    ValueError
        If the computed grid has fewer panels than there are plates.

    """
    x = cast(x, to=list)
    plates = data[x].drop_duplicates()
    n_plates = plates.shape[0]
    # At least a 1 x 1 grid, so that empty input gives an empty figure
    # instead of a division by zero.
    n_cols = max(1, int(np.ceil(np.sqrt(n_plates))))
    n_rows = max(1, int(np.ceil(n_plates / n_cols)))
    n_panels = n_cols * n_rows
    if not n_panels >= n_plates:
        ## This should never happen!
        raise ValueError(
            f"ERROR: Number of heatmap panels ({n_panels}) is less than the number of plates ({n_plates})!")

    fig, axes = grid(
        nrow=n_rows,
        ncol=n_cols,
        panel_size=panel_size,
    )
    # Snapshot the panel axes now: each colorbar added below appends its own
    # axes to ``fig.axes``.
    panel_axes = list(fig.axes)
    for ax in panel_axes[n_plates:]:
        ax.set_visible(False)

    for ax, (plate_name, plate_data) in zip(panel_axes, data.groupby(x)):
        this_plate_data = pd.pivot_table(
            plate_data,
            index='row_id',
            columns='column_id',
            values=y,
        )
        axes_image = ax.imshow(
            this_plate_data,
            cmap='cividis',
        )
        cbar = fig.colorbar(
            axes_image,
            ax=ax,
            shrink=.5,
            orientation='vertical',
        )
        cbar.ax.tick_params(labelsize='xx-small')

        row_labels = this_plate_data.index.values
        ax.set_yticks(
            np.arange(row_labels.size),
            labels=row_labels,
            size='x-small',
        )
        # Label every second column by its 1-based position.
        # NOTE(review): assumes column_id runs 1..N without gaps - confirm.
        x_locs = list(range(0, this_plate_data.columns.values.size, 2))
        ax.set_xticks(
            x_locs,
            labels=[str(loc + 1) for loc in x_locs],
            size='x-small',
        )
        # The group key may be a scalar rather than a tuple (pandas-version
        # dependent for a one-column grouping) and need not hold strings.
        name_parts = plate_name if isinstance(plate_name, tuple) else (plate_name,)
        ax.set_title(
            ':'.join(map(str, name_parts)),
            fontdict={'fontsize': 8.},
        )
    return fig, axes
[docs]
def plot_histogram(
    data: DataFrame,
    x: str,
    control_col: str,
    negative: str,
    positive: str,
    panel_size: float = _PANEL_SIZE
) -> PlotTuple:
    """Plot histograms of ``x`` for controls and experimental samples.

    One row of four panels is drawn per value of the read-type column, whose
    name is built from the second and third "_"-separated parts of ``x``
    plus "_wavelength". Within a row:

    * panel 0: positive and negative controls, equal-width bins;
    * panel 1: all other samples, equal-width bins;
    * panel 2: controls, geometrically spaced bins on a log x-axis;
    * panel 3: all other samples, geometrically spaced bins on a log x-axis.

    Rows with an empty read type, missing values or infinite ``x`` are
    dropped first. Log-spaced bins are computed from the strictly positive
    values only; if there are none, that histogram is skipped.

    Parameters
    ----------
    data : pandas.DataFrame
        Data containing ``x``, ``control_col`` and the read-type column.
    x : str
        Column to histogram.
    control_col : str
        Column holding the control labels.
    negative : str
        Value of ``control_col`` marking negative controls.
    positive : str
        Value of ``control_col`` marking positive controls.
    panel_size : float, optional
        Size of a single panel, passed to the grid constructor.

    Returns
    -------
    tuple
        The figure and the 2-D array of axes from the grid constructor.

    Raises
    ------
    KeyError
        If the read-type column is missing from ``data``.

    """
    read_type = '_'.join(x.split('_')[1:3]) + '_wavelength'
    # Validate before filtering, otherwise the filter itself fails first and
    # this error is never reached.
    if read_type not in data.columns:
        raise KeyError(f"The column {read_type} is missing from the input data.")
    data = data[data[read_type] != '']
    n_read_types = data[read_type].unique().size
    data = data[[control_col, read_type, x]].dropna().copy()
    data = data[~np.isinf(data[x])].copy()

    n_cols = 4
    fig, axes = grid(
        nrow=n_read_types,
        ncol=n_cols,
        panel_size=panel_size,
        squeeze=False,
    )
    for row, (wv, wv_df) in zip(axes, data.groupby(read_type)):
        for i, control in enumerate((positive, negative, None)):
            if control is not None:
                selected = wv_df[control_col] == control
                linear_ax, log_ax = row[0], row[2]
                title = 'Controls'
            else:
                selected = ~wv_df[control_col].isin([negative, positive])
                linear_ax, log_ax = row[1], row[3]
                title = 'Experiment'
            this_data = wv_df[selected]
            n_bins = 5 + this_data.shape[0] // 20

            # Geometric bin edges need strictly positive end points.
            values = this_data[x]
            gt_zero = values[values > 0]
            if gt_zero.empty:
                bins_log = None
            else:
                bins_log = np.geomspace(
                    gt_zero.min(),
                    gt_zero.max(),
                    num=n_bins,
                )

            hist_kwargs = dict(
                data=this_data,
                density=False,
                histtype='stepfilled',
                zorder=3 - i,
            )
            if control is not None:
                # Name the control explicitly; otherwise matplotlib labels
                # every histogram with the column name.
                hist_kwargs['label'] = control

            for ax, bins in ((linear_ax, n_bins), (log_ax, bins_log)):
                if bins is not None:
                    ax.hist(x, bins=bins, **hist_kwargs)
                if x.endswith('_norm'):
                    for m in (0., 1.):
                        ax.axvline(
                            m,
                            color='lightgray',
                            zorder=0,
                        )
                ax.set(
                    xlabel='_'.join(x.split('_')[1:]) + f': {wv}',
                    ylabel='Frequency',
                    title=title,
                )
            log_ax.set_xscale('log')
            if control is not None:
                add_legend(log_ax)
    return fig, axes
[docs]
def plot_replicates(
    data: DataFrame,
    x: str,
    grouping: Union[str, Iterable[str]],
    control_col: str,
    negative: str,
    positive: str,
    panel_size: float = _PANEL_SIZE
) -> PlotTuple:
    """Scatter replicate 1 against replicate 2 of ``x`` for each read type.

    One row of two panels is drawn per value of the read-type column, whose
    name is built from the second and third "_"-separated parts of ``x``
    plus "_wavelength". The first panel uses linear axes and the second
    log-log axes. Rows with ``replicate < 3`` are pivoted to one column per
    replicate (``rep_<n>``), indexed by ``grouping``. Positive controls,
    negative controls and remaining samples are drawn as separate layers.

    Each title reports the Pearson correlation between ``rep_1`` and
    ``rep_2``, computed over rows where both are present. The log panel
    uses log-transformed values of rows where both replicates are positive,
    and also reports the percentage of non-positive values of ``x``.

    Parameters
    ----------
    data : pandas.DataFrame
        Data containing ``x``, ``replicate``, the read-type column and the
        ``grouping`` columns.
    x : str
        Column holding the measurement.
    grouping : str or iterable of str
        Column(s) identifying a sample across replicates. Must include
        ``control_col``, which is looked up in the pivoted index.
    control_col : str
        Column holding the control labels.
    negative : str
        Value of ``control_col`` marking negative controls.
    positive : str
        Value of ``control_col`` marking positive controls.
    panel_size : float, optional
        Size of a single panel, passed to the grid constructor.

    Returns
    -------
    tuple
        The figure and the 2-D array of axes from the grid constructor.

    Raises
    ------
    KeyError
        If the read-type column is missing from ``data``.

    """
    grouping = cast(grouping, to=list)
    read_type = '_'.join(x.split('_')[1:3]) + '_wavelength'
    if read_type not in data.columns:
        raise KeyError(f"The column {read_type} is missing from the input data.")
    n_read_types = data[read_type].unique().size

    # The log-log panel is always drawn.
    do_log = True
    n_cols = 2 if do_log else 1
    n_rows = n_read_types
    fig, axes = grid(
        nrow=n_rows,
        ncol=n_cols,
        panel_size=panel_size,
        squeeze=False,
    )
    for row, (wv, wv_df) in zip(axes, data.groupby(read_type)):
        this_title = '_'.join(x.split('_')[1:]) + f': {wv}'
        n_gt_zero = np.mean(wv_df[x] > 0)
        sub_df = (
            wv_df
            .query('replicate < 3')
            .assign(replicate=lambda df: 'rep_' + df['replicate'].astype(str))
        )
        df_wide = pd.pivot_table(
            sub_df,
            index=grouping,
            columns='replicate',
            values=x,
        )
        # Series.corr only uses rows where both replicates are present, so a
        # single missing measurement no longer turns the coefficient into NaN.
        pearson_corr = df_wide['rep_1'].corr(df_wide['rep_2'])

        for ax in row:
            for i, control in enumerate((positive, negative, None)):
                if control is not None:
                    q = f'{control_col} == "{control}"'
                else:
                    q = f'{control_col} not in ["{negative}", "{positive}"]'
                this_data = df_wide.query(q)
                ax.scatter(
                    'rep_1', 'rep_2',
                    data=this_data,
                    # Controls get larger markers than the other samples.
                    s=_MARKER_SIZE + (6. if control is not None else 0.),
                    zorder=3 - i,
                )
            # Identity line spanning the current x-range.
            ax.plot(
                ax.get_xlim(), ax.get_xlim(),
                color='lightgray',
                zorder=0,
            )
            ax.set(
                xlabel='Replicate 1',
                ylabel='Replicate 2',
            )
            ax.set_title(
                this_title + f'\nr = {pearson_corr:.2f}',
                fontdict={'fontsize': 8.},
            )

        if do_log:
            # The last panel of the row is the log-log view.
            log_ax = row[-1]
            df_gt_zero = df_wide.query('rep_1 > 0 and rep_2 > 0')
            log_corr = np.log(df_gt_zero['rep_1']).corr(np.log(df_gt_zero['rep_2']))
            log_ax.set(
                xscale='log',
                yscale='log',
            )
            log_ax.set_title(
                this_title + f'\nr = {log_corr:.2f} | hidden points: {100. * (1. - n_gt_zero):.1f} %',
                fontdict={'fontsize': 8.},
            )
    return fig, axes
[docs]
def plot_scatter(
    data: DataFrame,
    measurement_col: str,
    x: str,
    y: str,
    color: Optional[str] = None,
    log_color: bool = False,
    hlines: Optional[Iterable[float]] = None,
    vlines: Optional[Iterable[float]] = None,
    x_log: bool = False,
    y_log: bool = False,
    panel_size: float = _PANEL_SIZE,
    **kwargs
) -> PlotTuple:
    """Scatter two calculated quantities of one measurement against each other.

    The plotted columns are named ``'calc_' + measurement_col + suffix``,
    where the suffix is ``x``, ``y`` or ``color`` respectively. Axis and
    colorbar labels are the last "_"-separated part of the suffix or column
    name.

    Parameters
    ----------
    data : pandas.DataFrame
        Data containing the calculated columns.
    measurement_col : str
        Measurement name; also used as the plot title.
    x : str
        Column-name suffix for the x-axis.
    y : str
        Column-name suffix for the y-axis.
    color : str, optional
        Column-name suffix used to color the points with the "cividis"
        colormap; a colorbar is added when given.
    log_color : bool, optional
        Whether to use a log color normalization. Default: False.
    hlines : iterable of float, optional
        Horizontal guide lines at these y-values.
    vlines : iterable of float, optional
        Vertical guide lines at these x-values.
    x_log : bool, optional
        Whether to make x-axis log scale. Default: False.
    y_log : bool, optional
        Whether to make y-axis log scale. Default: False.
    panel_size : float, optional
        Panel size passed to the grid constructor.
    **kwargs
        Passed through to the grid constructor.

    Returns
    -------
    tuple
        The figure and its axes.

    """
    hlines = hlines or []
    vlines = vlines or []
    data = data.copy()
    fig, ax = grid(
        panel_size=panel_size,
        **kwargs
    )

    def _calc_column(suffix):
        # An absent dimension stays None rather than becoming a column name.
        if suffix is None:
            return None
        return 'calc_' + measurement_col + suffix

    x_col, y_col, c_col = _calc_column(x), _calc_column(y), _calc_column(color)
    has_color = c_col is not None

    points = ax.scatter(
        x_col,
        y_col,
        c=data[c_col] if has_color else None,
        cmap="cividis" if has_color else None,
        norm="log" if log_color else None,
        data=data,
        s=_MARKER_SIZE,
        zorder=3,
    )
    if has_color:
        colorbar = fig.colorbar(points)
        colorbar.set_label(c_col.split("_")[-1])

    guide_style = dict(color=_LIGHT_GREY, zorder=0)
    for y_intercept in hlines:
        ax.axhline(y_intercept, **guide_style)
    for x_intercept in vlines:
        ax.axvline(x_intercept, **guide_style)

    ax.set(
        xscale="log" if x_log else "linear",
        yscale="log" if y_log else "linear",
        xlabel=x.split("_")[-1],
        ylabel=y.split("_")[-1],
        title=measurement_col,
    )
    return fig, ax