drugforge.ml.viz.plot_model_preds_scatter

drugforge.ml.viz.plot_model_preds_scatter(pred_tracker_dict, stats_dict, out_fn=None, split='test', use_epoch=-1, label_trans=None, plot_stats=True, table_stats=False, **kwargs)

Plot a scatterplot of experimental vs predicted values.

Parameters:
  • pred_tracker_dict (dict[str, TrainingPredictionTracker]) – Dict mapping labels to prediction trackers.

  • stats_dict (dict) – Dict mapping each label to its prediction statistics (as generated by pred_tracker.calculate_pred_statistics).

  • out_fn (Path, optional) – Path to save the plot to.

  • split (str, default="test") – Which data split to plot.

  • use_epoch (int, default=-1) – Which training epoch to take predictions from. Set to -1 to use the final epoch.

  • label_trans (callable, optional) – Function that takes a label string and returns a dict mapping str -> str. It is applied to each label, and each key in the returned dict is added as a column to the DataFrame, with the corresponding value as that row's entry.

  • kwargs (dict) – Any additional keyword arguments, passed directly to relplot.
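The label_trans contract (a string in, a dict of str -> str out, one DataFrame column per key) can be illustrated with a small standalone function. This is a minimal sketch. The label naming scheme ("model_lrVALUE_seedVALUE") and the helper name parse_model_label are hypothetical, and the loop that builds rows only imitates the column expansion described above. It is not drugforge's internal code.

```python
import re


def parse_model_label(label: str) -> dict[str, str]:
    """Hypothetical label_trans: turn 'gat_lr0.001_seed1' into
    {'model': 'gat', 'lr': '0.001', 'seed': '1'}."""
    # Assumed convention: first token is the model name, later tokens are
    # a lowercase key immediately followed by its value.
    model, *rest = label.split("_")
    out = {"model": model}
    for part in rest:
        m = re.fullmatch(r"([a-z]+)(.+)", part)
        if m:
            out[m.group(1)] = m.group(2)
    return out


# Imitate how each returned key becomes an extra column for that label's rows.
rows = []
for label in ["gat_lr0.001_seed1", "gin_lr0.01_seed2"]:
    row = {"label": label}
    row.update(parse_model_label(label))
    rows.append(row)

for row in rows:
    print(row)
```

Passing parse_model_label as label_trans would, under these assumptions, give the plotted DataFrame model, lr and seed columns, which can then be referenced in kwargs forwarded to relplot (for example, to color or facet points).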