Source code for drugforge.ml.loss

import numpy as np
import torch
from torch.nn import CrossEntropyLoss as TorchCrossEntropyLoss
from torch.nn import GaussianNLLLoss as TorchGaussianNLLLoss
from torch.nn import L1Loss as TorchL1Loss
from torch.nn import MSELoss as TorchMSELoss
from torch.nn import SmoothL1Loss as TorchSmoothL1Loss


class L1Loss(TorchL1Loss):
    def __init__(self, loss_type=None):
        """
        Class for calculating L1 (MAE) loss, with various options

        Parameters
        ----------
        loss_type : str, optional
            Which type of loss to use (case-insensitive):
             * None: vanilla MAE
             * "step": step MAE loss
             * "uncertainty": MAE loss with added uncertainty. This class does
               not define an ``uncertainty_loss`` method itself, so this option
               only works for subclasses that provide one

        Raises
        ------
        ValueError
            If ``loss_type`` is not one of the options above
        NotImplementedError
            If ``loss_type`` is "uncertainty" and no ``uncertainty_loss``
            method is available
        """
        # No reduction so we can apply whatever adjustment to each sample
        super().__init__(reduction="none")

        if loss_type is None:
            self.loss_function = super().forward
        else:
            loss_type = loss_type.lower()
            if loss_type == "step":
                self.loss_function = self.step_loss
            elif loss_type == "uncertainty":
                # Fail with an explicit message instead of an opaque
                # AttributeError when no implementation exists
                uncertainty_loss = getattr(self, "uncertainty_loss", None)
                if uncertainty_loss is None:
                    raise NotImplementedError(
                        'loss_type "uncertainty" is not implemented for '
                        f"{type(self).__name__}; use GaussianNLLLoss to "
                        "incorporate measurement uncertainty"
                    )
                self.loss_function = uncertainty_loss
            else:
                raise ValueError(f'Unknown loss_type "{loss_type}"')

        self.loss_type = loss_type

    def forward(self, pred, pose_preds, target, in_range, uncertainty):
        """
        Dispatch method for calculating loss.

        All arguments should be passed regardless of the actual loss function,
        to keep an identical signature across the loss classes. The relevant
        arguments are passed on to ``self.loss_function``.

        Parameters
        ----------
        pred : torch.Tensor
            Model prediction
        pose_preds : torch.Tensor
            Predictions for each pose (not used by this loss)
        target : torch.Tensor
            Prediction target
        in_range : torch.Tensor
            `target`'s presence in the dynamic range of the assay (only used
            by the "step" loss)
        uncertainty : torch.Tensor
            Uncertainty in `target` measurements (only used by the
            "uncertainty" loss)

        Returns
        -------
        torch.Tensor
            Calculated loss
        """
        if self.loss_type is None:
            # Elementwise loss is unreduced, so the mean gives MAE
            return self.loss_function(pred, target).mean()
        if self.loss_type == "step":
            return self.loss_function(pred, target, in_range)
        if self.loss_type == "uncertainty":
            return self.loss_function(pred, target, uncertainty)

    def step_loss(self, pred, target, in_range=None):
        """
        Step loss calculation.

        For `in_range` < 0, loss is returned as 0 if `pred` < `target`,
        otherwise MAE is calculated as normal. For `in_range` > 0, loss is
        returned as 0 if `pred` > `target`, otherwise MAE is calculated as
        normal. For `in_range` == 0, MAE is calculated as normal. The result is
        averaged over the samples that are not masked out.

        Parameters
        ----------
        pred : torch.Tensor
            Model prediction
        target : torch.Tensor
            Prediction target
        in_range : torch.Tensor, optional
            `target`'s presence in the dynamic range of the assay. Give a value
            of < 0 for `target` below lower bound, > 0 for `target` above upper
            bound, and 0 or None for inside range

        Returns
        -------
        torch.Tensor
            Calculated loss (0-d tensor)

        Raises
        ------
        ValueError
            If `in_range` is given and `pred`, `target`, and `in_range` do not
            all have the same number of elements as the elementwise loss
        """
        # Unreduced elementwise loss
        loss = super().forward(pred, target)

        if in_range is None:
            # No semiquantitative information: every sample counts
            mask = torch.ones_like(loss)
        else:
            in_range = torch.as_tensor(in_range, device=loss.device)
            n = loss.numel()
            if not (pred.numel() == target.numel() == in_range.numel() == n):
                raise ValueError(
                    "pred, target, and in_range must all have the same number "
                    f"of elements as the elementwise loss ({n}); got "
                    f"{pred.numel()}, {target.numel()}, and {in_range.numel()}"
                )
            # Mask is built from detached values so no gradient flows through it
            pred_flat = pred.detach().reshape(-1)
            target_flat = target.detach().reshape(-1).to(pred_flat.device)
            range_flat = in_range.detach().reshape(-1)
            pred_above = target_flat < pred_flat
            # r == 0 -> inside assay range, always count
            # r < 0 -> measurement is below thresh, count only if pred > target
            # r > 0 -> measurement is above thresh, count only if pred <= target
            keep = torch.where(
                range_flat == 0,
                torch.ones_like(pred_above),
                (range_flat < 0) == pred_above,
            )
            # Match the loss's shape so e.g. (B, 1) inputs don't broadcast to
            # (B, B)
            mask = keep.to(loss.dtype).reshape(loss.shape)

        # Clamp the denominator so a fully-masked batch yields 0 instead of
        # dividing by zero
        return (loss * mask).sum() / mask.sum().clamp(min=1)
class MSELoss(TorchMSELoss):
    def __init__(self, loss_type=None):
        """
        Class for calculating MSE loss, with various options

        Parameters
        ----------
        loss_type : str, optional
            Which type of loss to use (case-insensitive):
             * None: vanilla MSE
             * "step": step MSE loss
             * "uncertainty": MSE loss with added uncertainty. This class does
               not define an ``uncertainty_loss`` method itself, so this option
               only works for subclasses that provide one

        Raises
        ------
        ValueError
            If ``loss_type`` is not one of the options above
        NotImplementedError
            If ``loss_type`` is "uncertainty" and no ``uncertainty_loss``
            method is available
        """
        # No reduction so we can apply whatever adjustment to each sample
        super().__init__(reduction="none")

        if loss_type is None:
            self.loss_function = super().forward
        else:
            loss_type = loss_type.lower()
            if loss_type == "step":
                self.loss_function = self.step_loss
            elif loss_type == "uncertainty":
                # Fail with an explicit message instead of an opaque
                # AttributeError when no implementation exists
                uncertainty_loss = getattr(self, "uncertainty_loss", None)
                if uncertainty_loss is None:
                    raise NotImplementedError(
                        'loss_type "uncertainty" is not implemented for '
                        f"{type(self).__name__}; use GaussianNLLLoss to "
                        "incorporate measurement uncertainty"
                    )
                self.loss_function = uncertainty_loss
            else:
                raise ValueError(f'Unknown loss_type "{loss_type}"')

        self.loss_type = loss_type

    def forward(self, pred, pose_preds, target, in_range, uncertainty):
        """
        Dispatch method for calculating loss.

        All arguments should be passed regardless of the actual loss function,
        to keep an identical signature across the loss classes. The relevant
        arguments are passed on to ``self.loss_function``.

        Parameters
        ----------
        pred : torch.Tensor
            Model prediction
        pose_preds : torch.Tensor
            Predictions for each pose (not used by this loss)
        target : torch.Tensor
            Prediction target
        in_range : torch.Tensor
            `target`'s presence in the dynamic range of the assay (only used
            by the "step" loss)
        uncertainty : torch.Tensor
            Uncertainty in `target` measurements (only used by the
            "uncertainty" loss)

        Returns
        -------
        torch.Tensor
            Calculated loss
        """
        if self.loss_type is None:
            # Elementwise loss is unreduced, so the mean gives MSE
            return self.loss_function(pred, target).mean()
        if self.loss_type == "step":
            return self.loss_function(pred, target, in_range)
        if self.loss_type == "uncertainty":
            return self.loss_function(pred, target, uncertainty)

    def step_loss(self, pred, target, in_range=None):
        """
        Step loss calculation.

        For `in_range` < 0, loss is returned as 0 if `pred` < `target`,
        otherwise MSE is calculated as normal. For `in_range` > 0, loss is
        returned as 0 if `pred` > `target`, otherwise MSE is calculated as
        normal. For `in_range` == 0, MSE is calculated as normal. The result is
        averaged over the samples that are not masked out.

        Parameters
        ----------
        pred : torch.Tensor
            Model prediction
        target : torch.Tensor
            Prediction target
        in_range : torch.Tensor, optional
            `target`'s presence in the dynamic range of the assay. Give a value
            of < 0 for `target` below lower bound, > 0 for `target` above upper
            bound, and 0 or None for inside range

        Returns
        -------
        torch.Tensor
            Calculated loss (0-d tensor)

        Raises
        ------
        ValueError
            If `in_range` is given and `pred`, `target`, and `in_range` do not
            all have the same number of elements as the elementwise loss
        """
        # Unreduced elementwise loss
        loss = super().forward(pred, target)

        if in_range is None:
            # No semiquantitative information: every sample counts
            mask = torch.ones_like(loss)
        else:
            in_range = torch.as_tensor(in_range, device=loss.device)
            n = loss.numel()
            if not (pred.numel() == target.numel() == in_range.numel() == n):
                raise ValueError(
                    "pred, target, and in_range must all have the same number "
                    f"of elements as the elementwise loss ({n}); got "
                    f"{pred.numel()}, {target.numel()}, and {in_range.numel()}"
                )
            # Mask is built from detached values so no gradient flows through it
            pred_flat = pred.detach().reshape(-1)
            target_flat = target.detach().reshape(-1).to(pred_flat.device)
            range_flat = in_range.detach().reshape(-1)
            pred_above = target_flat < pred_flat
            # r == 0 -> inside assay range, always count
            # r < 0 -> measurement is below thresh, count only if pred > target
            # r > 0 -> measurement is above thresh, count only if pred <= target
            keep = torch.where(
                range_flat == 0,
                torch.ones_like(pred_above),
                (range_flat < 0) == pred_above,
            )
            # Match the loss's shape so e.g. (B, 1) inputs don't broadcast to
            # (B, B)
            mask = keep.to(loss.dtype).reshape(loss.shape)

        # Clamp the denominator so a fully-masked batch yields 0 instead of
        # dividing by zero
        return (loss * mask).sum() / mask.sum().clamp(min=1)
class SmoothL1Loss(TorchSmoothL1Loss):
    def __init__(self, loss_type=None):
        """
        Class for calculating smooth L1 loss, with various options

        Parameters
        ----------
        loss_type : str, optional
            Which type of loss to use (case-insensitive):
             * None: vanilla smooth L1
             * "step": step smooth L1 loss
             * "uncertainty": smooth L1 loss with added uncertainty. This class
               does not define an ``uncertainty_loss`` method itself, so this
               option only works for subclasses that provide one

        Raises
        ------
        ValueError
            If ``loss_type`` is not one of the options above
        NotImplementedError
            If ``loss_type`` is "uncertainty" and no ``uncertainty_loss``
            method is available
        """
        # No reduction so we can apply whatever adjustment to each sample
        super().__init__(reduction="none")

        if loss_type is None:
            self.loss_function = super().forward
        else:
            loss_type = loss_type.lower()
            if loss_type == "step":
                self.loss_function = self.step_loss
            elif loss_type == "uncertainty":
                # Fail with an explicit message instead of an opaque
                # AttributeError when no implementation exists
                uncertainty_loss = getattr(self, "uncertainty_loss", None)
                if uncertainty_loss is None:
                    raise NotImplementedError(
                        'loss_type "uncertainty" is not implemented for '
                        f"{type(self).__name__}; use GaussianNLLLoss to "
                        "incorporate measurement uncertainty"
                    )
                self.loss_function = uncertainty_loss
            else:
                raise ValueError(f'Unknown loss_type "{loss_type}"')

        self.loss_type = loss_type

    def forward(self, pred, pose_preds, target, in_range, uncertainty):
        """
        Dispatch method for calculating loss.

        All arguments should be passed regardless of the actual loss function,
        to keep an identical signature across the loss classes. The relevant
        arguments are passed on to ``self.loss_function``.

        Parameters
        ----------
        pred : torch.Tensor
            Model prediction
        pose_preds : torch.Tensor
            Predictions for each pose (not used by this loss)
        target : torch.Tensor
            Prediction target
        in_range : torch.Tensor
            `target`'s presence in the dynamic range of the assay (only used
            by the "step" loss)
        uncertainty : torch.Tensor
            Uncertainty in `target` measurements (only used by the
            "uncertainty" loss)

        Returns
        -------
        torch.Tensor
            Calculated loss
        """
        if self.loss_type is None:
            # Elementwise loss is unreduced, so take the mean
            return self.loss_function(pred, target).mean()
        if self.loss_type == "step":
            return self.loss_function(pred, target, in_range)
        if self.loss_type == "uncertainty":
            return self.loss_function(pred, target, uncertainty)

    def step_loss(self, pred, target, in_range=None):
        """
        Step loss calculation.

        For `in_range` < 0, loss is returned as 0 if `pred` < `target`,
        otherwise loss is calculated as normal. For `in_range` > 0, loss is
        returned as 0 if `pred` > `target`, otherwise loss is calculated as
        normal. For `in_range` == 0, loss is calculated as normal. The result is
        averaged over the samples that are not masked out.

        Parameters
        ----------
        pred : torch.Tensor
            Model prediction
        target : torch.Tensor
            Prediction target
        in_range : torch.Tensor, optional
            `target`'s presence in the dynamic range of the assay. Give a value
            of < 0 for `target` below lower bound, > 0 for `target` above upper
            bound, and 0 or None for inside range

        Returns
        -------
        torch.Tensor
            Calculated loss (0-d tensor)

        Raises
        ------
        ValueError
            If `in_range` is given and `pred`, `target`, and `in_range` do not
            all have the same number of elements as the elementwise loss
        """
        # Unreduced elementwise loss
        loss = super().forward(pred, target)

        if in_range is None:
            # No semiquantitative information: every sample counts
            mask = torch.ones_like(loss)
        else:
            in_range = torch.as_tensor(in_range, device=loss.device)
            n = loss.numel()
            if not (pred.numel() == target.numel() == in_range.numel() == n):
                raise ValueError(
                    "pred, target, and in_range must all have the same number "
                    f"of elements as the elementwise loss ({n}); got "
                    f"{pred.numel()}, {target.numel()}, and {in_range.numel()}"
                )
            # Mask is built from detached values so no gradient flows through it
            pred_flat = pred.detach().reshape(-1)
            target_flat = target.detach().reshape(-1).to(pred_flat.device)
            range_flat = in_range.detach().reshape(-1)
            pred_above = target_flat < pred_flat
            # r == 0 -> inside assay range, always count
            # r < 0 -> measurement is below thresh, count only if pred > target
            # r > 0 -> measurement is above thresh, count only if pred <= target
            keep = torch.where(
                range_flat == 0,
                torch.ones_like(pred_above),
                (range_flat < 0) == pred_above,
            )
            # Match the loss's shape so e.g. (B, 1) inputs don't broadcast to
            # (B, B)
            mask = keep.to(loss.dtype).reshape(loss.shape)

        # Clamp the denominator so a fully-masked batch yields 0 instead of
        # dividing by zero
        return (loss * mask).sum() / mask.sum().clamp(min=1)
class GaussianNLLLoss(TorchGaussianNLLLoss):
    def __init__(self, include_semiquant=True, fill_value=None):
        """
        Class for calculating Gaussian NLL loss, with various options.

        Parameters
        ----------
        include_semiquant : bool, default=True
            Whether to include semi-quantitative samples in the loss
        fill_value : float, optional
            If provided, use this value as the uncertainty for all semiquant
            predictions
        """
        # No reduction so we can apply masking if desired
        super().__init__(reduction="none")

        self.include_semiquant = include_semiquant
        self.fill_value = fill_value

    @staticmethod
    def _align_flags(flags, like):
        """Reshape boolean per-sample ``flags`` so they broadcast against ``like``."""
        flags = flags.to(like.device)
        if flags.numel() == like.numel():
            return flags.reshape(like.shape)
        # Otherwise treat the flags as one entry per leading (sample) dimension
        return flags.reshape(-1, *([1] * (like.dim() - 1)))

    def forward(self, pred, pose_preds, target, in_range, uncertainty):
        """
        Loss calculation

        Parameters
        ----------
        pred : torch.Tensor
            Model prediction
        pose_preds : torch.Tensor
            Predictions for each pose (not used by this loss)
        target : torch.Tensor
            Prediction target
        in_range : torch.Tensor
            `target`'s presence in the dynamic range of the assay. Give a value
            of < 0 for `target` below lower bound, > 0 for `target` above upper
            bound, and 0 or None for inside range
        uncertainty : torch.Tensor
            Uncertainty in `target` measurements, treated as a standard
            deviation (it is squared before being passed as the variance)

        Returns
        -------
        torch.Tensor
            Sum of the elementwise Gaussian NLL losses
        """
        # Boolean tensor marking semiquant samples; None means all in range
        semiquant = None if in_range is None else torch.as_tensor(in_range) != 0

        # Fill in semiquant uncertainties. torch.where builds a new tensor, so
        # the caller's uncertainty tensor is never modified
        if (
            self.include_semiquant
            and (self.fill_value is not None)
            and (semiquant is not None)
        ):
            fill_mask = self._align_flags(semiquant, uncertainty)
            uncertainty = torch.where(
                fill_mask, torch.full_like(uncertainty, self.fill_value), uncertainty
            )

        # Calculate loss (need to square uncertainty to convert to variance)
        loss = super().forward(pred, target, uncertainty**2)

        # Mask out losses for all semiquant measurements
        if (not self.include_semiquant) and (semiquant is not None):
            keep = ~self._align_flags(semiquant, loss)
            loss = loss * keep.to(loss.dtype)

        return loss.sum()
class RangeLoss(torch.nn.Module):
    def __init__(self, lower_lim, upper_lim):
        """
        Class for calculating a loss to penalize predictions outside of the
        given range. Current implementation uses a squared difference penalty.

        Parameters
        ----------
        lower_lim : float
            Bottom limit of acceptable range
        upper_lim : float
            Upper limit of acceptable range
        """
        super().__init__()

        self.lower_lim = lower_lim
        self.upper_lim = upper_lim

    def forward(self, pred, pose_preds, target, in_range, uncertainty):
        """
        No loss for predictions within self range, otherwise calculate squared
        distance to closest bound. Applied elementwise, so `pred` may contain
        any number of predictions.

        Only `pred` is used; the other arguments are accepted to keep the same
        signature as the other loss classes.

        Parameters
        ----------
        pred : torch.Tensor
            Model prediction
        pose_preds : torch.Tensor
            Predictions for each pose (unused)
        target : torch.Tensor
            Prediction target (unused)
        in_range : torch.Tensor
            `target`'s presence in the dynamic range of the assay (unused)
        uncertainty : torch.Tensor
            Uncertainty in `target` measurements (unused)

        Returns
        -------
        torch.Tensor
            Unreduced elementwise penalty, with the same shape as `pred`
        """
        # Elementwise selection: a Python `if` on a multi-element tensor raises
        # "Boolean value of Tensor ... is ambiguous"
        return torch.where(
            pred < self.lower_lim,
            torch.pow(pred - self.lower_lim, 2),
            torch.where(
                pred > self.upper_lim,
                torch.pow(pred - self.upper_lim, 2),
                pred * 0,
            ),
        )
class PoseCrossEntropyLoss(TorchCrossEntropyLoss):
    def __init__(self):
        """
        Class for calculating a cross entropy loss for per-pose delta G
        predictions in kT units compared to labels for pose closest to
        experimental structure.
        """
        super().__init__()

    def forward(self, pred, pose_preds, target, in_range, uncertainty):
        """
        Calculate cross-entropy loss for per-pose delta G predictions. These
        predictions are assumed to be in implicit kT units, as that is the
        standard in mtenn.

        Parameters
        ----------
        pred : torch.Tensor
            Model prediction (unused)
        pose_preds : torch.Tensor or sequence of torch.Tensor
            Predictions for each pose. A sequence may contain tensors of any
            shape, including 0-d tensors; they are flattened and concatenated
        target : torch.Tensor
            Prediction target. It is flattened and cast to the predictions'
            dtype, so CrossEntropyLoss treats it as per-pose class
            probabilities; it must have one entry per pose
        in_range : torch.Tensor
            `target`'s presence in the dynamic range of the assay (unused)
        uncertainty : torch.Tensor
            Uncertainty in `target` measurements (unused)

        Returns
        -------
        torch.Tensor
            Calculated loss
        """
        if isinstance(pose_preds, torch.Tensor):
            pose_free_energies = pose_preds.flatten()
        else:
            # Flatten each entry first: torch.cat rejects 0-d tensors
            pose_free_energies = torch.cat([p.reshape(-1) for p in pose_preds])

        # Negate so lower free energies become larger logits
        return super().forward(
            -pose_free_energies,
            target.flatten().to(
                device=pose_free_energies.device, dtype=pose_free_energies.dtype
            ),
        )