Source code for drugforge.ml.viz

import pandas
import seaborn as sns


def plot_split_losses(
    pred_tracker_dict,
    out_fn=None,
    splits=None,
    loss_label="Loss",
    legend_title="label",
    label_trans=None,
    for_fig=False,
    **kwargs,
):
    """
    Plot overall losses per split by training epoch.

    Parameters
    ----------
    pred_tracker_dict : dict[str, TrainingPredictionTracker]
        Dict mapping labels to pred trackers
    out_fn : Path, optional
        Path to save plot to
    splits : list[str], optional
        Which splits to actually plot. Defaults to ``["train", "val", "test"]``
    loss_label : str, default="Loss"
        What to label the y-axis of the plot
    legend_title : str, default="label"
        Column name for the dict keys, which will be used as the Legend title by
        default
    label_trans : callable, optional
        Function that should take a string as input and return a dict mapping
        str -> str. This function will be applied to each label, and each key in
        the output will be added as a column to the DataFrame with its
        corresponding value as the entry for that row in DF
    for_fig : bool, default=False
        Plotting for a figure rather than just visualization. Labels and split
        names are title-cased to make them look a bit more professional
    kwargs : dict
        Anything else to pass directly to relplot. If neither ``hue`` nor
        ``style`` is given, they are chosen automatically (an explicit
        ``hue_order`` / ``style_order`` still takes precedence). ``aspect``
        defaults to 1.5

    Returns
    -------
    seaborn.FacetGrid
        The grid returned by ``seaborn.relplot``

    Raises
    ------
    ValueError
        If ``pred_tracker_dict`` is empty
    """
    splits = ["train", "val", "test"] if splits is None else list(splits)
    if not pred_tracker_dict:
        raise ValueError("pred_tracker_dict must contain at least one entry")

    def _display(text):
        # Title-case only when producing a publication figure; the same function
        # is used for both column values and hue_order so they always agree
        return text.title() if for_fig else text

    # Build overall DF
    all_dfs = []
    for lab, pred_tracker in pred_tracker_dict.items():
        df = pred_tracker.to_plot_df(agg_compounds=True, agg_losses=True)
        df[legend_title] = _display(lab)
        # Apply label transform and add any new columns
        if callable(label_trans):
            for k, v in label_trans(lab).items():
                df[k] = v
        all_dfs.append(df)
    all_dfs = pandas.concat(all_dfs, ignore_index=True)

    # Subset to the requested splits; copy so adding columns below does not
    # write into a view of the concatenated frame
    all_dfs = all_dfs.loc[all_dfs["split"].isin(splits), :].copy()

    # Column/order used when splits are mapped to hue or style
    split_col = "split"
    split_order = splits
    if for_fig:
        all_dfs["Split"] = [s.title() for s in all_dfs["split"]]
        split_col = "Split"
        split_order = [s.title() for s in splits]

    if ("hue" not in kwargs) and ("style" not in kwargs):
        # Figure out styles
        if len(pred_tracker_dict) > 1:
            # More than one different experiment, so use color for experiment and
            # style for split
            hue = legend_title
            hue_order = [_display(lab) for lab in pred_tracker_dict]
            if len(splits) > 1:
                style, style_order = split_col, split_order
            else:
                style, style_order = None, None
        else:
            hue, hue_order = split_col, split_order
            style, style_order = None, None
        # Caller-supplied orders win; popping them also keeps relplot from
        # receiving the same keyword twice via **kwargs
        hue_order = kwargs.pop("hue_order", hue_order)
        style_order = kwargs.pop("style_order", style_order)
    else:
        # Pull from kwargs
        hue = kwargs.pop("hue", None)
        hue_order = kwargs.pop("hue_order", None)
        style = kwargs.pop("style", None)
        style_order = kwargs.pop("style_order", None)

    # Other various kwargs
    aspect = kwargs.pop("aspect", 1.5)

    # Make plot
    fg = sns.relplot(
        all_dfs,
        x="epoch",
        y="loss",
        hue=hue,
        style=style,
        hue_order=hue_order,
        style_order=style_order,
        kind="line",
        aspect=aspect,
        **kwargs,
    )

    # Set axes
    fg.set_axis_labels("Training Epoch", loss_label)

    if out_fn:
        fg.savefig(out_fn, bbox_inches="tight", dpi=200)

    return fg
# (key in the per-split stats dict, display label) pairs, in display order
_PRED_STATS = (
    ("mae", "MAE"),
    ("rmse", "RMSE"),
    ("sp_r", "Spearman's $\\rho$"),
    ("tau", "Kendall's $\\tau$"),
)


def _format_pred_stat(stat_entry):
    """Format a stat entry as its value with the 95% CI as LaTeX sub/superscripts."""
    return (
        f"{stat_entry['value']:0.2f}"
        f"$_{{{stat_entry['95ci_low']:0.2f}}}"
        f"^{{{stat_entry['95ci_high']:0.2f}}}$"
    )


def _draw_pred_reference_lines(ax, min_val, max_val):
    """Draw the dashed y=x line plus shaded +/-0.5 and +/-1 bands on ``ax``."""
    ax.plot([min_val, max_val], [min_val, max_val], color="black", ls="--")
    for width in (0.5, 1):
        ax.fill_between(
            [min_val, max_val],
            [min_val - width, max_val - width],
            [min_val + width, max_val + width],
            color="gray",
            alpha=0.2,
        )


def plot_model_preds_scatter(
    pred_tracker_dict,
    stats_dict,
    out_fn=None,
    split="test",
    use_epoch=-1,
    label_trans=None,
    plot_stats=True,
    table_stats=False,
    **kwargs,
):
    """
    Plot a scatterplot of experimental vs predicted values.

    Parameters
    ----------
    pred_tracker_dict : dict[str, TrainingPredictionTracker]
        Dict mapping labels to pred trackers
    stats_dict : dict
        Dict mapping lab -> pred stats (generated by
        pred_tracker.calculate_pred_statistics). Read as
        ``stats_dict[lab][split][stat]`` with ``"value"``, ``"95ci_low"`` and
        ``"95ci_high"`` entries for stat in mae/rmse/sp_r/tau. Only accessed when
        ``plot_stats`` is True. NOTE: ``lab`` is the facet (``col``) value, so
        keys must match that column's values if ``col`` is overridden
    out_fn : Path, optional
        Path to save plot to
    split : str, default="test"
        Which split to plot (also selects which split's stats are shown and the
        figure title)
    use_epoch : int, default=-1
        Which epoch of training to take predictions from. Set to -1 to use final
        epoch
    label_trans : callable, optional
        Function that should take a string as input and return a dict mapping
        str -> str. This function will be applied to each label, and each key in
        the output will be added as a column to the DataFrame with its
        corresponding value as the entry for that row in DF
    plot_stats : bool, default=True
        Show MAE, RMSE, Spearman's rho and Kendall's tau (with 95% CIs) for each
        panel
    table_stats : bool, default=False
        If True (and ``plot_stats`` is True), gather the stats into a table drawn
        in an extra blank panel at the end instead of as text on each panel
    kwargs : dict
        Anything else to pass directly to relplot. ``facet_kws`` is merged over
        the defaults ``{"sharex": False, "sharey": False}``; ``col`` defaults to
        ``"label"`` and ``col_order`` to the order values appear in the data

    Returns
    -------
    seaborn.FacetGrid
        The grid returned by ``seaborn.relplot``

    Raises
    ------
    ValueError
        If ``pred_tracker_dict`` is empty or no predictions exist for ``split``
    """
    if not pred_tracker_dict:
        raise ValueError("pred_tracker_dict must contain at least one entry")

    # Build overall DF
    all_dfs = []
    for lab, pred_tracker in pred_tracker_dict.items():
        df = pred_tracker.to_plot_df(agg_losses=True)
        df["label"] = lab
        # Apply label transform and add any new columns
        if callable(label_trans):
            for k, v in label_trans(lab).items():
                df[k] = v
        all_dfs.append(df)
    all_dfs = pandas.concat(all_dfs, ignore_index=True)

    # Keep one epoch per label (its last epoch when use_epoch < 0). groupby visits
    # labels in sorted order, which also determines the default panel order
    epoch_idx = []
    for _, g in all_dfs.groupby("label"):
        cur_use_epoch = g["epoch"].max() if use_epoch < 0 else use_epoch
        epoch_idx.extend(g.index[g["epoch"] == cur_use_epoch])
    all_dfs = all_dfs.loc[epoch_idx, :]
    # Copy so the column added below does not write into a view
    all_dfs = all_dfs.loc[all_dfs["split"] == split, :].copy()
    if all_dfs.empty:
        raise ValueError(f"No predictions found for split {split!r}")

    # Set so the legend looks nicer
    legend_text_mapper = {
        -1: "Below Assay Range",
        0: "In Assay Range",
        1: "Above Assay Range",
    }
    all_dfs["Assay Range"] = list(map(legend_text_mapper.get, all_dfs["in_range"]))

    # If any facet_kws are passed in kwargs, update the defaults
    facet_kws = {"sharex": False, "sharey": False} | (kwargs.pop("facet_kws", None) or {})
    col = kwargs.pop("col", "label")
    col_order = kwargs.pop("col_order", None)
    # Always work on a fresh list so a caller-supplied col_order is never mutated
    col_order = list(all_dfs[col].unique() if col_order is None else col_order)
    use_table = table_stats and plot_stats
    if use_table:
        # Extra empty facet that will hold the stats table
        col_order = col_order + ["blank"]

    fg = sns.relplot(
        data=all_dfs,
        x="target",
        y="pred",
        col=col,
        col_order=col_order,
        style="Assay Range",
        markers={
            "Below Assay Range": "<",
            "In Assay Range": "o",
            "Above Assay Range": ">",
        },
        style_order=["Below Assay Range", "In Assay Range", "Above Assay Range"],
        facet_kws=facet_kws,
        **kwargs,
    )

    # Figure title
    fg.figure.subplots_adjust(top=0.8)
    fg.figure.suptitle(f"{split.title()} Set Predictions", fontweight="bold")

    # Axes bounds (pandas max skips NaN, unlike a raw numpy max)
    min_val = -0.5
    max_val = all_dfs[["target", "pred"]].max().max() + 0.5

    # Axis labels: x on the bottom panels, y on the left panels (also correct
    # when col_wrap produces a 1-D axes array)
    fg.set_axis_labels(
        r"Experimental $\mathrm{pIC}_{50}$", r"Predicted $\mathrm{pIC}_{50}$"
    )
    sns.move_legend(fg, loc="upper center", bbox_to_anchor=(0.5, 0), ncols=3)

    if use_table:
        # Header row, then one row per statistic; a column is appended per panel
        stats_table_text = [[""]] + [[stat_label] for _, stat_label in _PRED_STATS]

    for lab, ax in fg.axes_dict.items():
        if lab == "blank":
            ax.set_title("")
            continue

        # Set title
        ax.set_title(lab, fontweight="bold")
        _draw_pred_reference_lines(ax, min_val, max_val)

        # Stats labels
        if plot_stats:
            lab_stats = stats_dict[lab][split]
            if use_table:
                stats_table_text[0].append(lab)
                for row, (stat, _) in zip(stats_table_text[1:], _PRED_STATS):
                    row.append(_format_pred_stat(lab_stats[stat]))
            else:
                stats_text = "\n".join(
                    f"{stat_label}: {_format_pred_stat(lab_stats[stat])}"
                    for stat, stat_label in _PRED_STATS
                )
                ax.text(
                    0.7,
                    0,
                    stats_text,
                    transform=ax.transAxes,
                    va="bottom",
                    linespacing=0.8,
                )

        # Make it a square
        ax.set_aspect("equal", "box")
        ax.set_xlim((min_val, max_val))
        ax.set_ylim((min_val, max_val))

    if use_table:
        # The "blank" facet is last in col_order, so it is the last axis
        ax = fg.axes.flatten()[-1]
        ax.set_axis_off()
        ax.table(
            cellText=stats_table_text, cellLoc="center", loc="center", edges="open"
        )

    if out_fn:
        fg.savefig(out_fn, bbox_inches="tight", dpi=200)

    return fg