drugforge.ml.loss.PoseCrossEntropyLoss
- class drugforge.ml.loss.PoseCrossEntropyLoss(*args: Any, **kwargs: Any)
Bases: CrossEntropyLoss
- __init__()
Class for calculating a cross-entropy loss over per-pose delta G predictions (in kT units), using as the label the pose closest to the experimental structure.
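One plausible reading of this description assumes that pose probabilities follow a Boltzmann-weighted softmax of the kT-unit predictions. That reading gives the per-compound loss below, where $\Delta G_i$ is the predicted delta G of pose $i$ out of $N$ poses and $t$ is the index of the pose closest to the experimental structure. This is an interpretation of the description, not a formula stated by the library:

```latex
p_i = \frac{\exp(-\Delta G_i)}{\sum_{j=1}^{N} \exp(-\Delta G_j)}, \qquad
\mathcal{L} = -\log p_t = \Delta G_t + \log \sum_{j=1}^{N} \exp(-\Delta G_j)
```

Because the predictions are already in kT, no explicit $1/k_BT$ factor appears in the exponent. A lower predicted delta G for the labeled pose lowers the loss.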
Methods
__init__() – Class for calculating a cross-entropy loss over per-pose delta G predictions (in kT units), using as the label the pose closest to the experimental structure.
forward(pred, pose_preds, target, in_range, ...) – Calculate cross-entropy loss for per-pose delta G predictions.
- forward(pred, pose_preds, target, in_range, uncertainty)
Calculate the cross-entropy loss for per-pose delta G predictions. The predictions are assumed to be in implicit kT units, since that is the convention in mtenn.
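The predictions are expected in kT rather than kcal/mol. Values from a source that reports kcal/mol would therefore need dividing by RT first. The sketch below shows that rescaling. It assumes T = 298.15 K and uses the CODATA molar gas constant. The function names are hypothetical and are not part of drugforge or mtenn.

```python
# Molar gas constant in kcal/(mol*K): 8.314462618 J/(mol*K) divided by 4184 J/kcal
R_KCAL_PER_MOL_K = 8.314462618 / 4184.0


def kcal_per_mol_to_kt(delta_g_kcal: float, temperature_k: float = 298.15) -> float:
    """Hypothetical helper: express a free energy in kcal/mol in units of kT (divide by RT)."""
    return delta_g_kcal / (R_KCAL_PER_MOL_K * temperature_k)


def kt_to_kcal_per_mol(delta_g_kt: float, temperature_k: float = 298.15) -> float:
    """Hypothetical helper: inverse conversion, kT units back to kcal/mol."""
    return delta_g_kt * R_KCAL_PER_MOL_K * temperature_k


if __name__ == "__main__":
    # One kT at 298.15 K is roughly 0.5925 kcal/mol
    print(round(kt_to_kcal_per_mol(1.0), 4))  # 0.5925
```

At room temperature, a binding free energy of about -1.19 kcal/mol therefore corresponds to roughly -2 kT.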
- Parameters:
pred (torch.Tensor) – Model prediction
pose_preds (torch.Tensor) – Delta G predictions for each pose, in kT units
target (torch.Tensor) – Label identifying the pose closest to the experimental structure
in_range (torch.Tensor) – Whether the target lies within the assay's dynamic range: a value < 0 means the target is below the lower bound, > 0 means it is above the upper bound, and 0 or None means it is inside the range
uncertainty (torch.Tensor) – Uncertainty in the target measurements
- Returns:
Calculated loss
- Return type:
torch.Tensor
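The following minimal pure-Python sketch shows the computation for a single compound. It rests on assumptions that this page does not confirm. First, `pose_preds` are delta G values in kT. Second, a lower delta G means a more probable pose, so the logits are the negated predictions. Third, `target` is the integer index of the pose closest to the experimental structure. `pose_cross_entropy` is a hypothetical name. `pred`, `in_range`, and `uncertainty` are left out because their role in this particular loss is not described here.

```python
import math
from typing import Sequence


def pose_cross_entropy(pose_preds_kt: Sequence[float], target: int) -> float:
    """Hypothetical sketch: -log p(target) under a softmax of negated kT predictions.

    Assumption: lower delta G (in kT) means a more probable pose, so logits = -delta G.
    """
    if not pose_preds_kt:
        raise ValueError("need at least one pose prediction")
    if not 0 <= target < len(pose_preds_kt):
        raise IndexError("target must index one of the poses")
    # Assumed logits: negated per-pose delta G values
    logits = [-g for g in pose_preds_kt]
    # Numerically stable log-sum-exp: subtract the largest logit before exponentiating
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    # -log p_target = log Z - logit_target
    return log_z - logits[target]


if __name__ == "__main__":
    # Four equally scored poses: the loss is log(4) whichever pose is the label
    print(pose_cross_entropy([0.0, 0.0, 0.0, 0.0], 2))
```

With four equally scored poses the loss is log 4 ≈ 1.386. If the labeled pose has a much lower predicted delta G than the others, the loss approaches 0. In PyTorch terms, this roughly corresponds to passing the negated pose predictions as logits to `torch.nn.functional.cross_entropy`.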