import random
import numpy as np
import pandas as pd
import torch
from dgllife.utils import CanonicalAtomFeaturizer, SMILESToBigraph
from drugforge.data.backend.openeye import oechem
from drugforge.data.schema.complex import Complex
from drugforge.data.schema.ligand import Ligand
from torch.utils.data import Dataset
class DockedDataset(Dataset):
    """
    Class for loading docking results into a dataset to be used for graph
    learning.

    Each entry is a pose dict (see ``_complex_to_pose``). Entries can be looked up
    either by position or by their ``(xtal_id, compound_id)`` tuple.
    """

    def __init__(self, compounds=None, structures=None, random_iter=False):
        """
        Constructor for DockedDataset object.

        Parameters
        ----------
        compounds : dict[(str, str), list[int]], optional
            Dict mapping a compound tuple (xtal_id, compound_id) to a list of indices in
            structures that are poses for that id pair. Defaults to an empty dict
        structures : list[dict], optional
            List of pose dicts, containing at minimum tensors for atomic number, atomic
            positions, and a ligand idx. Indices in this list should match the indices
            in the lists in compounds. Defaults to an empty list
        random_iter : bool, default=False
            Iterate through the dataset randomly each time
        """
        super().__init__()
        # Build fresh containers per instance: a mutable default in the signature
        # would be shared by every dataset constructed without arguments
        self.compounds = {} if compounds is None else compounds
        self.structures = [] if structures is None else structures
        self.random_iter = random_iter

    @classmethod
    def from_complexes(
        cls, complexes: "list[Complex]", exp_dict=None, ignore_h=True, random_iter=False
    ):
        """
        Build from a list of Complex objects.

        Parameters
        ----------
        complexes : list[Complex]
            List of Complex schema objects to build into a DockedDataset object
        exp_dict : dict[str, dict[str, int | float]], optional
            Dict mapping compound_id to an experimental results dict. The dict for a
            compound will be added to the pose representation of each Complex containing
            a ligand with that compound_id
        ignore_h : bool, default=True
            Whether to remove hydrogens from the loaded structure
        random_iter : bool, default=False
            Iterate through the dataset randomly each time

        Returns
        -------
        DockedDataset
        """
        if exp_dict is None:
            exp_dict = {}

        compound_idxs = {}
        structures = []
        for comp in complexes:
            try:
                lig_data = comp.ligand.experimental_data.experimental_data
            except AttributeError:
                lig_data = {}
            # Merge into a new dict so the ligand's own experimental data is not
            # modified as a side effect of building the dataset
            comp_exp_dict = dict(lig_data or {}) | exp_dict.get(
                comp.ligand.compound_name, {}
            )

            compound = (comp.target.target_name, comp.ligand.compound_name)
            pose = cls._complex_to_pose(
                comp, compound=compound, exp_dict=comp_exp_dict, ignore_h=ignore_h
            )
            # Subclasses may return None to signal that a complex should be skipped
            if pose is None:
                continue

            # Index by position in structures (not in complexes) since some
            # complexes may have been skipped
            compound_idxs.setdefault(compound, []).append(len(structures))
            structures.append(pose)

        return cls(compound_idxs, structures, random_iter=random_iter)

    @staticmethod
    def _complex_to_pose(comp, compound=None, exp_dict=None, ignore_h=True):
        """
        Helper function to convert a Complex to a pose.

        Parameters
        ----------
        comp : Complex
            Complex to convert
        compound : tuple, optional
            Stored under the "compound" key of the pose if truthy
        exp_dict : dict, optional
            Extra entries merged into the pose. These take precedence over the
            structural entries if the keys overlap
        ignore_h : bool, default=True
            Whether to drop atoms with atomic number 1

        Returns
        -------
        dict
            Pose dict with keys "pos", "z", "lig", "x", "b", "ligand", optionally
            "compound", plus everything in exp_dict
        """
        if exp_dict is None:
            exp_dict = {}

        def atom_data(mol):
            # Collect positions, atomic numbers, and B factors for every atom
            coords = mol.GetCoords()
            pos, z, b = [], [], []
            for atom in mol.GetAtoms():
                pos.append(coords[atom.GetIdx()])
                z.append(atom.GetAtomicNum())
                b.append(oechem.OEAtomGetResidue(atom).GetBFactor())
            return pos, z, b

        target_mol = comp.target.to_oemol()
        target_pos, target_z, target_b = atom_data(target_mol)
        ligand_mol = comp.ligand.to_oemol()
        ligand_pos, ligand_z, ligand_b = atom_data(ligand_mol)

        # Combine the two, target atoms first
        all_pos = torch.tensor(target_pos + ligand_pos).float()
        all_z = torch.tensor(target_z + ligand_z)
        all_b = torch.tensor(target_b + ligand_b)
        # Boolean mask that is True for ligand atoms
        all_lig = torch.tensor(
            [False] * target_mol.NumAtoms() + [True] * ligand_mol.NumAtoms()
        )

        # One-hot encoding of (atomic number - 1) over 100 classes, for use in
        # e3nn models
        all_one_hot = torch.nn.functional.one_hot(all_z - 1, 100).float()

        # Subset to remove Hs if desired
        if ignore_h:
            keep = ~(all_z == 1)
            all_pos = all_pos[keep]
            all_z = all_z[keep]
            all_b = all_b[keep]
            all_lig = all_lig[keep]
            all_one_hot = all_one_hot[keep]

        pose = {
            "pos": all_pos,
            "z": all_z,
            "lig": all_lig,
            "x": all_one_hot,
            "b": all_b,
            "ligand": comp.ligand,
        }
        if compound:
            pose["compound"] = compound

        return pose | exp_dict

    @staticmethod
    def _complex_from_pdb(fn, compound):
        """
        Load one PDB file into a Complex.

        Defined on the class rather than inside ``from_files`` so that it can be
        pickled and sent to multiprocessing workers.
        """
        return Complex.from_pdb(
            pdb_file=fn,
            target_kwargs={"target_name": compound[0]},
            ligand_kwargs={"compound_name": compound[1]},
        )

    @classmethod
    def from_files(
        cls,
        str_fns,
        compounds,
        ignore_h=True,
        extra_dict=None,
        num_workers=1,
        random_iter=False,
    ):
        """
        Build from a list of PDB files.

        Parameters
        ----------
        str_fns : list[str]
            List of paths for the PDB files. Should correspond 1:1 with the
            names in compounds
        compounds : list[tuple[str]]
            List of (crystal structure, ligand compound id)
        ignore_h : bool, default=True
            Whether to remove hydrogens from the loaded structure
        extra_dict : dict[str, dict], optional
            Extra information to add to each structure. Keys should be
            compounds, and dicts can be anything as long as they don't have the
            keys ["z", "pos", "lig", "compound"]
        num_workers : int, default=1
            Number of cores to use to load structures
        random_iter : bool, default=False
            Iterate through the dataset randomly each time

        Returns
        -------
        DockedDataset
        """
        if extra_dict is None:
            extra_dict = {}

        mp_args = list(zip(str_fns, compounds))

        n_procs = 1
        if num_workers > 1:
            import multiprocessing as mp

            n_procs = min(num_workers, mp.cpu_count(), len(mp_args))

        # Only spin up a pool when it would actually have more than one worker;
        # this also avoids Pool(0) when there are no files to load
        if n_procs > 1:
            with mp.Pool(n_procs) as pool:
                all_complexes = pool.starmap(cls._complex_from_pdb, mp_args)
        else:
            all_complexes = [cls._complex_from_pdb(*args) for args in mp_args]

        return cls.from_complexes(
            all_complexes,
            exp_dict=extra_dict,
            ignore_h=ignore_h,
            random_iter=random_iter,
        )

    def __len__(self):
        """Number of poses in the dataset."""
        return len(self.structures)

    @staticmethod
    def _unwrap_index(idx):
        """
        Convert tensor/numpy indices into plain Python objects.

        A tensor holding exactly one element becomes a Python scalar and any other
        tensor becomes a (nested) list. The element count is checked up front
        instead of catching the error from ``item()``, since recent PyTorch
        releases raise RuntimeError rather than ValueError for multi-element
        tensors. Anything that is not a tensor or numpy object is returned as is.
        """
        if torch.is_tensor(idx):
            return idx.item() if idx.numel() == 1 else idx.tolist()
        if isinstance(idx, np.ndarray):
            # tolist gives a scalar for 0-d arrays and a list otherwise
            return idx.tolist()
        if isinstance(idx, np.generic):
            return idx.item()
        return idx

    def __getitem__(self, idx):
        """
        Parameters
        ----------
        idx : int, slice, tuple, list[tuple/int], tensor[tuple/int]
            Index into dataset. Can either be a numerical index into the
            structures or a tuple of (crystal structure, ligand compound id),
            or a list/torch.tensor/numpy.ndarray of either of those types

        Returns
        -------
        tuple or list[tuple]
            A single ``(compound, pose)`` tuple if idx was a single integer,
            otherwise a list of such tuples (one per matching pose). ``compound``
            is the pose's "compound" entry and ``pose`` is the pose dict. An empty
            index gives an empty list

        Raises
        ------
        KeyError
            If a compound tuple is not present in the dataset
        IndexError
            If an integer index is out of range
        """
        idx = self._unwrap_index(idx)

        # Work out which positions in structures are being requested, and keep
        # note of whether a single item or a list should be returned
        if isinstance(idx, (int, np.integer)):
            return_list = False
            str_idx_list = [int(idx)]
        elif isinstance(idx, slice):
            return_list = True
            str_idx_list = range(*idx.indices(len(self)))
        else:
            return_list = True
            if len(idx) == 0:
                return []
            if isinstance(idx[0], (int, np.integer)):
                str_idx_list = [int(i) for i in idx]
            else:
                # A bare (str, str) tuple is one compound, anything else is
                # treated as a collection of compounds
                if (
                    isinstance(idx, tuple)
                    and (len(idx) == 2)
                    and isinstance(idx[0], str)
                    and isinstance(idx[1], str)
                ):
                    compound_keys = [idx]
                else:
                    compound_keys = [tuple(i) for i in idx]
                # Need to find the structures that correspond to the compound(s)
                str_idx_list = [i for c in compound_keys for i in self.compounds[c]]

        found = [self.structures[i] for i in str_idx_list]
        pairs = [(s["compound"], s) for s in found]
        return pairs if return_list else pairs[0]

    def __iter__(self):
        """
        Yield ``(compound, pose)`` for every pose, in a freshly shuffled order if
        ``random_iter`` is set and in storage order otherwise.
        """
        if self.random_iter:
            order = random.sample(range(len(self.structures)), len(self.structures))
            ordered = (self.structures[i] for i in order)
        else:
            ordered = self.structures
        for s in ordered:
            yield (s["compound"], s)
class SplitDockedDataset(DockedDataset):
    """
    Variant of DockedDataset in which every entry carries three separate
    representations, stored under the keys "complex", "protein", and "ligand".
    """

    @staticmethod
    def _complex_to_pose(comp, compound=None, exp_dict=None, ignore_h=True):
        """
        Convert a Complex to a pose split into complex, protein, and ligand parts.

        Returns None (after printing a message) if the ligand SMILES can't be
        converted to a graph.
        """
        # Let the parent implementation do the structure parsing
        full_pose = DockedDataset._complex_to_pose(
            comp=comp, compound=compound, exp_dict=exp_dict, ignore_h=ignore_h
        )

        # Move the structural tensors out of the flat pose into their own dict
        complex_pose = {}
        for key in ("pos", "z", "lig", "x", "b"):
            complex_pose[key] = full_pose.pop(key)

        # Protein representation is every atom not flagged as ligand
        protein_mask = ~complex_pose["lig"]
        prot_pose = {
            key: complex_pose[key][protein_mask] for key in ("pos", "z", "x", "b")
        }

        # Ligand representation is a DGL graph built from its SMILES
        featurize = SMILESToBigraph(
            add_self_loop=True,
            node_featurizer=CanonicalAtomFeaturizer(),
            edge_featurizer=None,
        )
        graph = featurize(comp.ligand.smiles)
        if graph is None:
            print(f"{compound} ligand couldn't be converted to graph", flush=True)
            return None

        # Note that "ligand" replaces the Ligand object stored by the parent
        full_pose.update(
            {"complex": complex_pose, "protein": prot_pose, "ligand": {"g": graph}}
        )
        return full_pose
class GroupedDockedDataset(Dataset):
    """
    Version of DockedDataset where data is grouped by compound_id, so all poses
    for a given compound can be accessed at a time.
    """

    def __init__(
        self,
        compound_ids: "list[str] | None" = None,
        structures: "dict[str, dict] | None" = None,
        random_iter=False,
    ):
        """
        Constructor for GroupedDockedDataset object.

        Parameters
        ----------
        compound_ids : list[str], optional
            List of compound ids. Each entry in this list must have a corresponding
            entry in structures. Defaults to an empty list
        structures : dict[str, dict], optional
            Dict mapping compound_id to a pose dict. Defaults to an empty dict
        random_iter : bool, default=False
            Iterate through the dataset randomly each time
        """
        super().__init__()
        # Fresh containers per instance rather than shared mutable defaults
        self.compound_ids = np.asarray([] if compound_ids is None else compound_ids)
        self.structures = {} if structures is None else structures
        self.random_iter = random_iter

    @staticmethod
    def _rmsd_to_probs(pose_rmsds):
        """
        Turn per-pose RMSDs into weights that sum to 1, with lower RMSDs weighted
        higher (proportional to 1 / RMSD).

        If one or more RMSDs are exactly 0, the weight is split evenly among those
        poses and every other pose gets 0. This avoids the division by zero that
        would otherwise produce NaN for every pose.
        """
        pose_rmsds = np.asarray(pose_rmsds, dtype=float)
        exact = pose_rmsds == 0
        weights = exact.astype(float) if exact.any() else 1 / pose_rmsds
        return weights / weights.sum()

    @classmethod
    def from_complexes(
        cls, complexes: "list[Complex]", exp_dict=None, ignore_h=True, random_iter=False
    ):
        """
        Build from a list of Complex objects.

        Poses are grouped by the ligand's compound_name. For groups that have an
        "xtal_ligand" entry, the RMSD of each pose to that ligand is stored in the
        pose under "ref_rmsd", and the group gets "best_pose_label" (1 for the
        lowest-RMSD pose, 0 elsewhere) and "rmsd_probs" entries.

        Parameters
        ----------
        complexes : list[Complex]
            List of Complex schema objects to build into a DockedDataset object
        exp_dict : dict[str, dict[str, int | float]], optional
            Dict mapping compound_id to an experimental results dict. The dict for a
            compound will be added to the pose representation of each Complex containing
            a ligand with that compound_id
        ignore_h : bool, default=True
            Whether to remove hydrogens from the loaded structure
        random_iter : bool, default=False
            Iterate through the dataset randomly each time

        Returns
        -------
        GroupedDockedDataset
        """
        from drugforge.docking.analysis import calculate_rmsd_openeye

        if exp_dict is None:
            exp_dict = {}

        compound_ids = []
        structures = {}
        for comp in complexes:
            compound_id = comp.ligand.compound_name
            compound = (comp.target.target_name, compound_id)

            # Build pose dict
            try:
                lig_data = comp.ligand.experimental_data.experimental_data
            except AttributeError:
                lig_data = {}
            # Merge into a new dict so the ligand's own experimental data is not
            # modified as a side effect of building the dataset
            comp_exp_dict = dict(lig_data or {}) | exp_dict.get(compound_id, {})
            pose = DockedDataset._complex_to_pose(
                comp, compound=compound, exp_dict=comp_exp_dict, ignore_h=ignore_h
            )

            # Calculate RMSD to ref if available
            if "xtal_ligand" in pose:
                pose["ref_rmsd"] = calculate_rmsd_openeye(
                    Ligand(**pose["xtal_ligand"]).to_oemol(), pose["ligand"].to_oemol()
                )

            if compound_id in structures:
                structures[compound_id]["poses"].append(pose)
            else:
                # Take compound-level data from first pose
                exp_data = {
                    k: v
                    for k, v in pose.items()
                    if (not isinstance(v, torch.Tensor)) and (k != "ref_rmsd")
                }
                structures[compound_id] = {"poses": [pose]} | exp_data
                compound_ids.append(compound_id)

        # Calculate which pose is closest to experiment
        for data in structures.values():
            if "xtal_ligand" not in data:
                continue

            # Get all RMSDs
            pose_rmsds = np.asarray([pose["ref_rmsd"] for pose in data["poses"]])
            # Label of all zeros, except the one with the best pose (lowest ref RMSD)
            best_lab = np.zeros(len(data["poses"]))
            best_lab[np.argmin(pose_rmsds)] = 1
            data["best_pose_label"] = best_lab
            # Normalize to probability, lower RMSDs are better
            data["rmsd_probs"] = cls._rmsd_to_probs(pose_rmsds)

        return cls(
            compound_ids=compound_ids, structures=structures, random_iter=random_iter
        )

    @staticmethod
    def _complex_from_pdb(fn, compound):
        """
        Load one PDB file into a Complex.

        Defined on the class rather than inside ``from_files`` so that it can be
        pickled and sent to multiprocessing workers.
        """
        return Complex.from_pdb(
            pdb_file=fn,
            target_kwargs={"target_name": compound[0]},
            ligand_kwargs={"compound_name": compound[1]},
        )

    @classmethod
    def from_files(
        cls,
        str_fns,
        compounds,
        ignore_h=True,
        extra_dict=None,
        num_workers=1,
        random_iter=False,
    ):
        """
        Build from a list of PDB files.

        Parameters
        ----------
        str_fns : list[str]
            List of paths for the PDB files. Should correspond 1:1 with the
            names in compounds
        compounds : list[tuple[str]]
            List of (crystal structure, ligand compound id)
        ignore_h : bool, default=True
            Whether to remove hydrogens from the loaded structure
        extra_dict : dict[str, dict], optional
            Extra information to add to each structure. Keys should be
            compounds, and dicts can be anything as long as they don't have the
            keys ["z", "pos", "lig", "compound"]
        num_workers : int, default=1
            Number of cores to use to load structures
        random_iter : bool, default=False
            Iterate through the dataset randomly each time

        Returns
        -------
        GroupedDockedDataset
        """
        if extra_dict is None:
            extra_dict = {}

        mp_args = list(zip(str_fns, compounds))

        n_procs = 1
        if num_workers > 1:
            import multiprocessing as mp

            n_procs = min(num_workers, mp.cpu_count(), len(mp_args))

        # Only spin up a pool when it would actually have more than one worker;
        # this also avoids Pool(0) when there are no files to load
        if n_procs > 1:
            with mp.Pool(n_procs) as pool:
                all_complexes = pool.starmap(cls._complex_from_pdb, mp_args)
        else:
            all_complexes = [cls._complex_from_pdb(*args) for args in mp_args]

        return cls.from_complexes(
            all_complexes,
            exp_dict=extra_dict,
            ignore_h=ignore_h,
            random_iter=random_iter,
        )

    def __len__(self):
        """Number of compounds (groups) in the dataset."""
        return len(self.compound_ids)

    @staticmethod
    def _unwrap_index(idx):
        """
        Convert tensor/numpy indices into plain Python objects.

        A tensor holding exactly one element becomes a Python scalar and any other
        tensor becomes a (nested) list. The element count is checked up front
        instead of catching the error from ``item()``, since recent PyTorch
        releases raise RuntimeError rather than ValueError for multi-element
        tensors. Anything that is not a tensor or numpy object is returned as is.
        """
        if torch.is_tensor(idx):
            return idx.item() if idx.numel() == 1 else idx.tolist()
        if isinstance(idx, np.ndarray):
            # tolist gives a scalar for 0-d arrays and a list otherwise
            return idx.tolist()
        if isinstance(idx, np.generic):
            return idx.item()
        return idx

    def __getitem__(self, idx):
        """
        Parameters
        ----------
        idx : int, str, slice, list[str/int], tensor[str/int]
            Index into dataset. Can either be a numerical index into the
            compound ids or the compound id itself, or a
            list/torch.tensor/numpy.ndarray of either of those types.

        Returns
        -------
        tuple or list[tuple]
            A single ``(compound_id, group)`` tuple if idx was a single int or
            str, otherwise a list of such tuples. ``group`` is the entry of
            ``structures`` for that compound_id. An empty index gives an empty
            list

        Raises
        ------
        TypeError
            If idx (or its first element) is not an int, str, or slice
        KeyError
            If a compound id is not present in the dataset
        """
        idx = self._unwrap_index(idx)

        # Figure out the type of the index, and keep note of whether a list was
        # passed in or not. The slice check has to come before anything that
        # subscripts idx, since slices are not subscriptable
        if isinstance(idx, (int, np.integer)):
            return_list = False
            compound_id_list = self.compound_ids[[int(idx)]]
        elif isinstance(idx, str):
            return_list = False
            compound_id_list = [idx]
        elif isinstance(idx, slice):
            return_list = True
            compound_id_list = self.compound_ids[idx]
        else:
            return_list = True
            try:
                n_idx = len(idx)
            except TypeError:
                raise TypeError(f"Unknown indexing type {type(idx)}") from None
            if n_idx == 0:
                return []

            first = idx[0]
            if isinstance(first, (int, np.integer)):
                # Integral, so indexing the list of compound ids
                compound_id_list = self.compound_ids[[int(i) for i in idx]]
            elif isinstance(first, str):
                # Otherwise the compound ids were given directly
                compound_id_list = list(idx)
            else:
                raise TypeError(f"Unknown indexing type {type(first)}")

        str_list = [self.structures[compound_id] for compound_id in compound_id_list]
        if return_list:
            return list(zip(compound_id_list, str_list))
        return (compound_id_list[0], str_list[0])

    def __iter__(self):
        """
        Yield ``(compound_id, group)`` for every compound, in a freshly shuffled
        order if ``random_iter`` is set and in storage order otherwise.
        """
        if self.random_iter:
            order = random.sample(
                range(len(self.compound_ids)), len(self.compound_ids)
            )
            ordered = (self.compound_ids[i] for i in order)
        else:
            ordered = self.compound_ids
        for compound_id in ordered:
            yield compound_id, self.structures[compound_id]
class GraphDataset(Dataset):
    """
    Class for loading SMILES as graphs.
    """

    def __init__(self, compounds=None, structures=None, random_iter=False):
        """
        Constructor for GraphDataset object.

        Parameters
        ----------
        compounds : dict[(str, str), list[int]], optional
            Dict mapping a compound tuple to a list of indices in structures.
            Defaults to an empty dict
        structures : list[dict], optional
            List of per-compound dicts, each with at least a "compound" entry.
            Defaults to an empty list
        random_iter : bool, default=False
            Iterate through the dataset randomly each time
        """
        super().__init__()
        # Fresh containers per instance rather than shared mutable defaults
        self.compounds = {} if compounds is None else compounds
        self.structures = [] if structures is None else structures
        self.random_iter = random_iter

    @classmethod
    def from_ligands(
        cls,
        ligands: "list[Ligand]",
        exp_dict: dict = None,
        node_featurizer=None,
        edge_featurizer=None,
        random_iter=False,
    ):
        """
        Build from a list of Ligand objects.

        Parameters
        ----------
        ligands : list[Ligand]
            List of Ligand schema objects to build into a GraphDataset object
        exp_dict : dict[str, dict[str, int | float]], optional
            Dict mapping compound_id to an experimental results dict. The dict for a
            compound will be added to the representation of each ligand with that
            compound_id
        node_featurizer : BaseAtomFeaturizer, optional
            Featurizer for node data
        edge_featurizer : BaseBondFeaturizer, optional
            Featurizer for edges
        random_iter : bool, default=False
            Iterate through the dataset randomly each time

        Returns
        -------
        GraphDataset
        """
        if exp_dict is None:
            exp_dict = {}

        # Function for encoding SMILES to a graph
        smiles_to_g = SMILESToBigraph(
            add_self_loop=True,
            node_featurizer=node_featurizer,
            edge_featurizer=edge_featurizer,
        )

        compounds = {}
        structures = []
        for lig in ligands:
            compound_id = lig.compound_name
            smiles = lig.smiles

            # Need a tuple to match DockedDataset, but the graph objects aren't
            # attached to a protein structure at all
            compound = ("NA", compound_id)

            # Generate DGL graph
            g = smiles_to_g(smiles)

            # Gather experimental data. Work on a copy so the ligand's own
            # experimental data is not modified as a side effect
            try:
                exp_data = lig.experimental_data
                lig_exp_dict = dict(exp_data.experimental_data)
                if exp_data.date_created:
                    lig_exp_dict["date_created"] = exp_data.date_created
            except AttributeError:
                lig_exp_dict = {}
            lig_exp_dict |= exp_dict.get(compound_id, {})

            # Add data
            compounds.setdefault(compound, []).append(len(structures))
            structures.append(
                {
                    "smiles": smiles,
                    "g": g,
                    "compound": compound,
                }
                | lig_exp_dict
            )

        return cls(compounds, structures, random_iter=random_iter)

    @classmethod
    def from_exp_compounds(
        cls,
        exp_compounds,
        exp_dict: dict = None,
        node_featurizer=None,
        edge_featurizer=None,
        random_iter=False,
    ):
        """
        Build from a list of experimental compound data objects.

        Parameters
        ----------
        exp_compounds : List[schema.ExperimentalCompoundData]
            List of compounds
        exp_dict : dict[str, dict[str, int | float]], optional
            Dict mapping compound_id to an experimental results dict. The dict for a
            compound will be added to the representation of each compound with that
            compound_id
        node_featurizer : BaseAtomFeaturizer, optional
            Featurizer for node data
        edge_featurizer : BaseBondFeaturizer, optional
            Featurizer for edges
        random_iter : bool, default=False
            Iterate through the dataset randomly each time

        Returns
        -------
        GraphDataset
        """
        if exp_dict is None:
            exp_dict = {}

        # Function for encoding SMILES to a graph
        smiles_to_g = SMILESToBigraph(
            add_self_loop=True,
            node_featurizer=node_featurizer,
            edge_featurizer=edge_featurizer,
        )

        compounds = {}
        structures = []
        for exp_compound in exp_compounds:
            compound_id = exp_compound.compound_id
            smiles = exp_compound.smiles

            # Need a tuple to match DockedDataset, but the graph objects aren't
            # attached to a protein structure at all
            compound = ("NA", compound_id)

            # Generate DGL graph
            g = smiles_to_g(smiles)

            # Add data. Later dicts win on key clashes, and "date_created" is
            # always stored, whatever its value
            compounds.setdefault(compound, []).append(len(structures))
            structures.append(
                {
                    "smiles": smiles,
                    "g": g,
                    "compound": compound,
                }
                | exp_compound.experimental_data
                | exp_dict.get(compound_id, {})
                | {"date_created": exp_compound.date_created}
            )

        return cls(compounds, structures, random_iter=random_iter)

    def __len__(self):
        """Number of entries in the dataset."""
        return len(self.structures)

    @staticmethod
    def _unwrap_index(idx):
        """
        Convert tensor/numpy indices into plain Python objects.

        Boolean tensors with at least one dimension always become lists so that
        a mask is never collapsed into a scalar. Otherwise a tensor holding
        exactly one element becomes a Python scalar and any other tensor becomes
        a (nested) list. The element count is checked up front instead of
        catching the error from ``item()``, since recent PyTorch releases raise
        RuntimeError rather than ValueError for multi-element tensors. Anything
        that is not a tensor or numpy object is returned as is.
        """
        if torch.is_tensor(idx):
            if idx.dtype == torch.bool and idx.ndim > 0:
                return idx.tolist()
            return idx.item() if idx.numel() == 1 else idx.tolist()
        if isinstance(idx, np.ndarray):
            # tolist gives a scalar for 0-d arrays and a list otherwise
            return idx.tolist()
        if isinstance(idx, np.generic):
            return idx.item()
        return idx

    def __getitem__(self, idx):
        """
        Parameters
        ----------
        idx : int, slice, tuple, list[tuple/int/bool], tensor[tuple/int/bool]
            Index into dataset. Can either be a numerical index into the
            structures or a tuple of (crystal structure, ligand compound id),
            or a list/torch.tensor/numpy.ndarray of either of those types. A
            sequence of bools is treated as a mask over the structures

        Returns
        -------
        tuple or list[tuple]
            A single ``(compound, entry)`` tuple if idx was a single integer,
            otherwise a list of such tuples. ``compound`` is the entry's
            "compound" value and ``entry`` is the full dict. An empty index gives
            an empty list

        Raises
        ------
        IndexError
            If a boolean mask doesn't have one value per structure, or an integer
            index is out of range
        KeyError
            If a compound tuple is not present in the dataset
        """
        idx = self._unwrap_index(idx)

        # Work out which positions in structures are being requested, and keep
        # note of whether a single item or a list should be returned
        if isinstance(idx, (int, np.integer)):
            return_list = False
            str_idx_list = [int(idx)]
        elif isinstance(idx, slice):
            return_list = True
            str_idx_list = range(*idx.indices(len(self)))
        else:
            return_list = True
            if len(idx) == 0:
                return []

            first = idx[0]
            # bool has to be checked before int, since bool is a subclass of int
            if isinstance(first, (bool, np.bool_)):
                if len(idx) != len(self.structures):
                    raise IndexError("Index length must match number of structures.")
                str_idx_list = [i for i, keep in enumerate(idx) if keep]
            elif isinstance(first, (int, np.integer)):
                str_idx_list = [int(i) for i in idx]
            else:
                # A bare (str, str) tuple is one compound, anything else is
                # treated as a collection of compounds
                if (
                    isinstance(idx, tuple)
                    and (len(idx) == 2)
                    and isinstance(idx[0], str)
                    and isinstance(idx[1], str)
                ):
                    compound_keys = [idx]
                else:
                    compound_keys = [tuple(i) for i in idx]
                # Need to find the structures that correspond to the compound(s)
                str_idx_list = [i for c in compound_keys for i in self.compounds[c]]

        found = [self.structures[i] for i in str_idx_list]
        pairs = [(s["compound"], s) for s in found]
        return pairs if return_list else pairs[0]

    def __iter__(self):
        """
        Yield ``(compound, entry)`` for every entry, in a freshly shuffled order
        if ``random_iter`` is set and in storage order otherwise.
        """
        if self.random_iter:
            order = random.sample(range(len(self.structures)), len(self.structures))
            ordered = (self.structures[i] for i in order)
        else:
            ordered = self.structures
        for s in ordered:
            yield (s["compound"], s)
def dataset_to_dataframe(dataset):
    """
    Flatten a dataset into a DataFrame with one row per entry.

    Parameters
    ----------
    dataset : iterable
        Anything that yields ``(key, data)`` pairs, where ``data`` is a dict.
        ``key`` is either an (xtal_id, compound_id) pair, or a bare compound id
        string, in which case the xtal_id is recorded as "NA"

    Returns
    -------
    pandas.DataFrame
        One row per entry. Every value in ``data`` is stored as its string
        representation, plus "xtal_id" and "compound_id" columns taken from
        the key
    """
    all_data = []
    for k, v in dataset:
        # Add all string castable data in v to a dict
        data_dict = {}
        for key, value in v.items():
            try:
                data_dict[key] = str(value)
            except Exception:
                # Best effort: leave out values that can't be cast to str, but
                # let KeyboardInterrupt/SystemExit through
                continue

        # Add compound identifiers to dict. A plain string key would otherwise
        # be split into its first two characters
        if isinstance(k, str):
            xtal_id, compound_id = "NA", k
        else:
            xtal_id, compound_id = k[0], k[1]
        data_dict["xtal_id"] = xtal_id
        data_dict["compound_id"] = compound_id

        all_data.append(data_dict)

    return pd.DataFrame(all_data)
def dataset_to_csv(dataset, filename):
    """
    Write a dataset to a CSV file (without the index column) and return the
    filename that was passed in.
    """
    frame = dataset_to_dataframe(dataset)
    frame.to_csv(filename, index=False)
    return filename