import abc
from typing import Literal
from drugforge.data.schema.ligand import Ligand
from pydantic import BaseModel, ConfigDict, Field
[docs]
class StateExpanderBase(abc.ABC, BaseModel):
expander_type: Literal["StateExpanderBase"] = Field(
"StateExpanderBase", description="The type of expander."
)
@abc.abstractmethod
def _expand(self, ligands: list[Ligand], unique: bool = False) -> list[Ligand]: ...
def expand(self, ligands: list[Ligand], unique: bool = True) -> list[Ligand]:
expanded_ligands = self._expand(ligands=ligands)
if unique:
return list(set(expanded_ligands))
else:
return expanded_ligands
@abc.abstractmethod
def _provenance(self) -> dict[str, str]:
"""Return the software used to perform the state expansion in the workflow."""
...
[docs]
def provenance(self) -> dict[str, str]:
"""
Get the provenance of the software and settings used to expand the molecule state.
Returns
-------
A dict of the expander and the software used to do the expansion.
"""
data = {"expander": self.model_dump()}
data.update(self._provenance())
return data
[docs]
class StateExpansion(BaseModel):
parent: Ligand = Field(..., description="The parent ligand")
children: list[Ligand] = Field(
..., description="The children ligands resulting from expansion"
)
expansion: Literal["stereo", "charge"] = Field(
...,
description="The type of state expansion, this will be used "
"to group the expansions.",
)
model_config = ConfigDict(frozen=True)
@property
def n_expanded_states(self) -> int:
return len(self.children)
[docs]
class StateExpansionSet(BaseModel):
expansions: list[StateExpansion] = Field(..., description="The set of expansions")
unassigned: list[Ligand] = Field(
..., description="Ligands that could not be assigned a parent"
)
model_config = ConfigDict(frozen=True)
@classmethod
def from_ligands(cls, ligands: list[Ligand]) -> "StateExpansionSet":
is_expansion = [
ligand for ligand in ligands if ligand.expansion_tag is not None
]
expansions = []
# keep track of children that have been assigned a parent
assigned = set()
for ligand in ligands:
inchikey = ligand.fixed_inchikey
children = [
child
for child in is_expansion
if child.expansion_tag.parent_fixed_inchikey == inchikey
]
if len(children) > 0:
# work out the type of expansion, make sure only one type links the children and parents
expansion_type = [
child.expansion_tag.provenance["expander"]["expander_type"].lower()
for child in children
]
if len(set(expansion_type)) > 1:
raise RuntimeError(
f"Multiple expansion methods link the parent {ligand.smiles} to the child molecules {[child.smiles for child in children]} this should not happen."
)
# set the type to one of the two defined types
expansion_method = (
"stereo"
if expansion_type[0].lower().find("stereo") == 0
else "charge"
)
expansion = StateExpansion(
parent=ligand, children=children, expansion=expansion_method
)
expansions.append(expansion)
assigned.update(children)
assigned.add(ligand)
# check for unassigned ligands
unassigned = [ligand for ligand in ligands if ligand not in assigned]
return StateExpansionSet(expansions=expansions, unassigned=unassigned)
def get_stereo_expansions(self) -> list[StateExpansion]:
return [
expansion
for expansion in self.expansions
if expansion.expansion == "stereo"
]
def get_charge_expansions(self) -> list[StateExpansion]:
return [
expansion
for expansion in self.expansions
if expansion.expansion == "charge"
]