drugforge.ml.loss.PoseCrossEntropyLoss

class drugforge.ml.loss.PoseCrossEntropyLoss(*args: Any, **kwargs: Any)

Bases: CrossEntropyLoss

__init__()

Class for calculating a cross-entropy loss between per-pose delta G predictions (in kT units) and labels identifying the pose closest to the experimental structure.
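The description does not say how per-pose delta G values are turned into a distribution over poses. The following is a minimal sketch under an assumption the source does not confirm: a Boltzmann weighting in which each pose's logit is -delta G (in kT), so a lower delta G means a more probable pose, and the loss is the negative log-probability of the labelled closest pose. `pose_cross_entropy` and its arguments are hypothetical names, not part of the drugforge API.

```python
import math
from typing import Sequence


# Hypothetical helper, not part of drugforge. Assumes logit_i = -dG_i (kT units).
def pose_cross_entropy(pose_dg_kt: Sequence[float], closest_pose: int) -> float:
    """Return -log p(closest_pose), with p_i proportional to exp(-dG_i)."""
    logits = [-dg for dg in pose_dg_kt]
    # Stable log-sum-exp: shift by the largest logit before exponentiating.
    shift = max(logits)
    log_norm = shift + math.log(sum(math.exp(z - shift) for z in logits))
    # Cross-entropy against a one-hot label is log Z minus the label's logit.
    return log_norm - logits[closest_pose]


# Three poses; pose 0 has the most favourable (lowest) predicted delta G.
print(pose_cross_entropy([-2.0, 0.0, 1.0], closest_pose=0))
```

If every pose has the same predicted delta G, the sketch returns log(n_poses), because the model expresses no preference between poses.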


forward(pred, pose_preds, target, in_range, uncertainty)

Calculate the cross-entropy loss for per-pose delta G predictions. These predictions are assumed to be in implicit kT units, the standard convention in mtenn.
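Because the inputs are treated as implicit kT units, free energies expressed in kcal/mol would need to be divided by RT before being passed in. The sketch below shows that conversion, assuming a default temperature of 298.15 K; whether a given model already outputs kT is something to check, and the helper names are hypothetical, not part of drugforge or mtenn.

```python
# Molar gas constant R in kcal/(mol*K): 8.314462618 J/(mol*K) / 4184 J/kcal.
R_KCAL_PER_MOL_K = 8.314462618 / 4184.0


# Hypothetical helpers, not part of drugforge or mtenn.
def kcal_per_mol_to_kt(dg_kcal: float, temperature_k: float = 298.15) -> float:
    """Divide by RT to express a molar free energy in units of kT."""
    return dg_kcal / (R_KCAL_PER_MOL_K * temperature_k)


def kt_to_kcal_per_mol(dg_kt: float, temperature_k: float = 298.15) -> float:
    """Multiply by RT to convert a value in kT back to kcal/mol."""
    return dg_kt * R_KCAL_PER_MOL_K * temperature_k


# RT at 298.15 K is roughly 0.59 kcal/mol.
print(kt_to_kcal_per_mol(1.0))
```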

Parameters:
  • pred (torch.Tensor) – Model prediction

  • pose_preds (torch.Tensor) – Predictions for each pose

  • target (torch.Tensor) – Prediction target

  • in_range (torch.Tensor) – Whether the target lies within the dynamic range of the assay. Use a value < 0 if the target is below the lower bound, > 0 if it is above the upper bound, and 0 or None if it is inside the range

  • uncertainty (torch.Tensor) – Uncertainty in target measurements

Returns:

Calculated loss

Return type:

torch.Tensor
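The `in_range` sign convention can be made concrete with a small decoder. This is a sketch only: the docstring does not state how, or whether, this loss uses `in_range`, and `RangeStatus` and `classify_in_range` are hypothetical names, not part of drugforge.

```python
from enum import Enum
from typing import Optional


# Hypothetical helper types, not part of drugforge.
class RangeStatus(Enum):
    BELOW = "below lower bound"
    INSIDE = "inside range"
    ABOVE = "above upper bound"


def classify_in_range(value: Optional[float]) -> RangeStatus:
    """Decode one in_range entry: < 0 below, > 0 above, 0 or None inside."""
    if value is None or value == 0:
        return RangeStatus.INSIDE
    return RangeStatus.BELOW if value < 0 else RangeStatus.ABOVE


print([classify_in_range(v) for v in (-1.0, 0.0, 1.0, None)])
```

For a tensor of flags, the same sign test would be applied to each element.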