class LigandScorer:
    """ Base class to score all target/model ligand pairs.

    Copies of *model* and *target* are stored as :attr:`model` and
    :attr:`target`. Ligands are either extracted from these copies (residues
    with the ``is_ligand`` flag set) or taken from the user supplied
    *model_ligands* / *target_ligands*.

    Scores are computed lazily: the first access to :attr:`score_matrix`,
    :attr:`coverage_matrix`, :attr:`error_states` or :attr:`aux_data`
    triggers :meth:`_compute_scores`, which computes symmetries for every
    target/model ligand pair and delegates the actual scoring to
    :meth:`_compute`. The latter must be implemented by child classes.

    :param model: Model structure
    :type model: :class:`ost.mol.EntityHandle` or :class:`ost.mol.EntityView`
    :param target: Target (reference) structure
    :type target: :class:`ost.mol.EntityHandle` or
                  :class:`ost.mol.EntityView`
    :param model_ligands: Model ligands as Entity/Residue handles or views.
                          If None, ligands are extracted from *model*.
    :param target_ligands: Target ligands as Entity/Residue handles or
                           views. If None, ligands are extracted from
                           *target*.
    :param resnum_alignments: Passed on to the chain mapper constructed in
                              :attr:`chain_mapper`.
    :type resnum_alignments: :class:`bool`
    :param rename_ligand_chain: If a user supplied ligand must be copied
                                into the entity and its residue number is
                                already taken in the respective chain, move
                                it to a new chain instead of raising.
    :type rename_ligand_chain: :class:`bool`
    :param substructure_match: Passed on to :func:`ComputeSymmetries`.
    :type substructure_match: :class:`bool`
    :param coverage_delta: Stored as attribute; not used by this base class.
    :type coverage_delta: :class:`float`
    :param max_symmetries: Passed on to :func:`ComputeSymmetries`.
    :type max_symmetries: :class:`int`
    :raises: :class:`RuntimeError` if *model* / *target* are of wrong type,
             :class:`ValueError` if neither model nor target contain any
             ligand.
    """

    def __init__(self, model, target, model_ligands=None, target_ligands=None,
                 resnum_alignments=False, rename_ligand_chain=False,
                 substructure_match=False, coverage_delta=0.2,
                 max_symmetries=1e5):

        # Work on copies so the input structures are never modified
        if isinstance(model, mol.EntityView):
            self.model = mol.CreateEntityFromView(model, False)
        elif isinstance(model, mol.EntityHandle):
            self.model = model.Copy()
        else:
            raise RuntimeError("model must be of type EntityView/EntityHandle")

        if isinstance(target, mol.EntityView):
            self.target = mol.CreateEntityFromView(target, False)
        elif isinstance(target, mol.EntityHandle):
            self.target = target.Copy()
        else:
            raise RuntimeError("target must be of type EntityView/EntityHandle")

        # Extract ligands from target
        if target_ligands is None:
            self.target_ligands = self._extract_ligands(self.target)
        else:
            self.target_ligands = self._prepare_ligands(self.target, target,
                                                        target_ligands,
                                                        rename_ligand_chain)
        if len(self.target_ligands) == 0:
            LogWarning("No ligands in the target")

        # Extract ligands from model
        if model_ligands is None:
            self.model_ligands = self._extract_ligands(self.model)
        else:
            self.model_ligands = self._prepare_ligands(self.model, model,
                                                       model_ligands,
                                                       rename_ligand_chain)
        if len(self.model_ligands) == 0:
            LogWarning("No ligands in the model")
            if len(self.target_ligands) == 0:
                raise ValueError("No ligand in the model and in the target")

        self.resnum_alignments = resnum_alignments
        self.rename_ligand_chain = rename_ligand_chain
        self.substructure_match = substructure_match
        self.coverage_delta = coverage_delta
        self.max_symmetries = max_symmetries

        # lazily computed attributes
        self._chain_mapper = None

        # keep track of error states
        # simple integers instead of enums - documentation of property
        # describes encoding
        self._error_states = None

        # score matrices
        self._score_matrix = None
        self._coverage_matrix = None
        self._aux_data = None

    @property
    def error_states(self):
        """ Encodes error states of ligand pairs

        Not only critical things, but also things like: a pair of ligands
        simply doesn't match. Target ligands are in rows, model ligands in
        columns. States are encoded as integers <= 9. Larger numbers encode
        errors for child classes.

        * -1: Unknown Error - cannot be matched
        * 0: Ligand pair has valid symmetries - can be matched.
        * 1: Ligand pair has no valid symmetry - cannot be matched.
        * 2: Ligand pair has too many symmetries - cannot be matched.
          You might be able to get a match by increasing *max_symmetries*.
        * 3: Ligand pair has no isomorphic symmetries - cannot be matched.
          Target ligand is subgraph of model ligand. This error only occurs
          if *substructure_match* is False. These cases will likely become
          0 if this flag is enabled.
        * 4: Disconnected graph error - cannot be matched.
          Either target ligand or model ligand has disconnected graph.

        :rtype: :class:`~numpy.ndarray`
        """
        if self._error_states is None:
            self._compute_scores()
        return self._error_states

    @property
    def score_matrix(self):
        """ Get the matrix of scores.

        Target ligands are in rows, model ligands in columns.

        NaN values indicate that no value could be computed (i.e. different
        ligands). In other words: values are only valid if respective location
        :attr:`~error_states` is 0.

        :rtype: :class:`~numpy.ndarray`
        """
        if self._score_matrix is None:
            self._compute_scores()
        return self._score_matrix

    @property
    def coverage_matrix(self):
        """ Get the matrix of model ligand atom coverage in the target.

        Target ligands are in rows, model ligands in columns.

        NaN values indicate that no value could be computed (i.e. different
        ligands). In other words: values are only valid if respective location
        :attr:`~error_states` is 0. If `substructure_match=False`, only full
        match isomorphisms are considered, and therefore only values of 1.0
        can be observed.

        :rtype: :class:`~numpy.ndarray`
        """
        if self._coverage_matrix is None:
            self._compute_scores()
        return self._coverage_matrix

    @property
    def aux_data(self):
        """ Get the matrix of scorer specific auxiliary data.

        Target ligands are in rows, model ligands in columns.

        Auxiliary data consists of arbitrary data dicts which allow a child
        class to provide additional information for a scored ligand pair.
        Entries are None if no value could be computed (i.e. different
        ligands). In other words: values are only valid if respective
        location :attr:`~error_states` is 0.

        :rtype: :class:`~numpy.ndarray`
        """
        # The backing attribute is _aux_data - that's what __init__ and
        # _compute_scores set.
        if self._aux_data is None:
            self._compute_scores()
        return self._aux_data

    @property
    def chain_mapper(self):
        """ Chain mapper object for the given :attr:`target`.

        Can be used by child classes if needed, constructed with
        *resnum_alignments* flag

        :type: :class:`ost.mol.alg.chain_mapping.ChainMapper`
        """
        if self._chain_mapper is None:
            self._chain_mapper = \
            chain_mapping.ChainMapper(self.target,
                                      n_max_naive=1e9,
                                      resnum_alignments=self.resnum_alignments)
        return self._chain_mapper

    @staticmethod
    def _extract_ligands(entity):
        """Extract ligands from entity. Return a list of residues.

        Assumes that ligands have the :attr:`~ost.mol.ResidueHandle.is_ligand`
        flag set. This is typically the case for entities loaded from mmCIF
        (tested with mmCIF files from the PDB and SWISS-MODEL).
        Legacy PDB files must contain `HET` headers (which is usually the
        case for files downloaded from the PDB but not elsewhere).

        This function performs a basic check to ensure that a ligand residue
        is not forming a polymer bond with the residue following it
        (ie peptide/nucleotide ligands) and will raise a RuntimeError if this
        assumption is broken.

        :param entity: the entity to extract ligands from
        :type entity: :class:`~ost.mol.EntityHandle`
        :rtype: :class:`list` of :class:`~ost.mol.ResidueHandle`
        :raises: :class:`RuntimeError` if a ligand residue is connected to
                 its successor in the polymer sequence.
        """
        extracted_ligands = []
        for residue in entity.residues:
            if residue.is_ligand:
                if mol.InSequence(residue, residue.next):
                    # Two placeholders need two values: report both residues
                    # that form the polymer bond.
                    raise RuntimeError("Residue %s connected in polymer sequen"
                                       "ce with %s" % (
                                           residue.qualified_name,
                                           residue.next.qualified_name))
                extracted_ligands.append(residue)
                LogVerbose("Detected residue %s as ligand" % residue)
        return extracted_ligands

    @staticmethod
    def _prepare_ligands(new_entity, old_entity, ligands, rename_chain):
        """Prepare the ligands given into a list of ResidueHandles which are
        part of the copied entity, suitable for the model_ligands and
        target_ligands properties.

        This function takes a list of ligands as (Entity|Residue)(Handle|View).
        Entities can contain multiple ligands, which will be considered as
        separate ligands.

        Ligands which are part of the entity are simply fetched in the new
        copied entity. Otherwise, they are copied over to the copied entity.

        :param new_entity: The copied entity in which the returned residues
                           live.
        :param old_entity: The entity originally passed by the user, used to
                           decide whether a ligand is already part of it.
        :param ligands: Iterable of ligands to process.
        :param rename_chain: If a copied residue clashes with an existing
                             residue number in its chain, put it into a new
                             chain (True) or raise (False).
        :type rename_chain: :class:`bool`
        :rtype: :class:`list` of :class:`~ost.mol.ResidueHandle`
        :raises: :class:`RuntimeError` on residue number clash with
                 *rename_chain* disabled or if an item in *ligands* is
                 neither an Entity nor a Residue.
        """
        extracted_ligands = []

        next_chain_num = 1
        new_editor = None

        def _copy_residue(residue, rename_chain):
            """ Copy the residue into the new chain.
            Return the new residue handle."""
            nonlocal next_chain_num, new_editor

            # Instantiate the editor only when a copy is actually needed
            if new_editor is None:
                new_editor = new_entity.EditXCS()

            new_chain = new_entity.FindChain(residue.chain.name)
            if not new_chain.IsValid():
                new_chain = new_editor.InsertChain(residue.chain.name)
            else:
                # Does a residue with the same number already exist?
                already_exists = new_chain.FindResidue(residue.number).IsValid()
                if already_exists:
                    if rename_chain:
                        chain_ext = 2  # Extend the chain name by this
                        # Search for the first free "<name>_<n>" chain name
                        while True:
                            new_chain_name = residue.chain.name + "_" + str(chain_ext)
                            new_chain = new_entity.FindChain(new_chain_name)
                            if new_chain.IsValid():
                                chain_ext += 1
                                continue
                            else:
                                new_chain = new_editor.InsertChain(new_chain_name)
                                break
                        LogScript("Moved ligand residue %s to new chain %s" % (
                            residue.qualified_name, new_chain.name))
                    else:
                        msg = "A residue number %s already exists in chain %s" % (
                            residue.number, residue.chain.name)
                        raise RuntimeError(msg)

            # Add the residue with its original residue number
            new_res = new_editor.AppendResidue(new_chain, residue.name, residue.number)
            # Add atoms
            for old_atom in residue.atoms:
                new_editor.InsertAtom(new_res, old_atom.name, old_atom.pos,
                    element=old_atom.element, occupancy=old_atom.occupancy,
                    b_factor=old_atom.b_factor, is_hetatm=old_atom.is_hetatom)
            # Add bonds
            for old_atom in residue.atoms:
                for old_bond in old_atom.bonds:
                    new_first = new_res.FindAtom(old_bond.first.name)
                    new_second = new_res.FindAtom(old_bond.second.name)
                    new_editor.Connect(new_first, new_second)
            return new_res

        def _process_ligand_residue(res, rename_chain):
            """Copy or fetch the residue. Return the residue handle."""
            new_res = None
            if res.entity.handle == old_entity.handle:
                # Residue is part of the old_entity handle.
                # However, it may not be in the copied one, for instance it
                # may have been a view
                # We try to grab it first, otherwise we copy it
                new_res = new_entity.FindResidue(res.chain.name, res.number)
            if new_res and new_res.valid:
                LogVerbose("Ligand residue %s already in entity" % res.handle.qualified_name)
            else:
                # Residue is not part of the entity, need to copy it first
                new_res = _copy_residue(res, rename_chain)
                LogVerbose("Copied ligand residue %s" % res.handle.qualified_name)
            new_res.SetIsLigand(True)
            return new_res

        for ligand in ligands:
            if isinstance(ligand, mol.EntityHandle) or isinstance(ligand, mol.EntityView):
                for residue in ligand.residues:
                    new_residue = _process_ligand_residue(residue, rename_chain)
                    extracted_ligands.append(new_residue)
            elif isinstance(ligand, mol.ResidueHandle) or isinstance(ligand, mol.ResidueView):
                new_residue = _process_ligand_residue(ligand, rename_chain)
                extracted_ligands.append(new_residue)
            else:
                raise RuntimeError("Ligands should be given as Entity or Residue")

        if new_editor is not None:
            new_editor.UpdateICS()
        return extracted_ligands

    def _compute_scores(self):
        """
        Compute score for every possible target-model ligand pair and store the
        result in internal matrices.

        Fills :attr:`score_matrix`, :attr:`coverage_matrix`,
        :attr:`error_states` and :attr:`aux_data`.

        :raises: :class:`RuntimeError` if the child class returns a non-zero
                 error state that is reserved for the base class (<= 9).
        """
        ##############################
        # Create the result matrices #
        ##############################
        shape = (len(self.target_ligands), len(self.model_ligands))
        self._score_matrix = np.full(shape, np.nan, dtype=np.float32)
        self._coverage_matrix = np.full(shape, np.nan, dtype=np.float32)
        self._error_states = np.full(shape, -1, dtype=np.int32)
        self._aux_data = np.empty(shape, dtype=dict)

        for target_id, target_ligand in enumerate(self.target_ligands):
            LogVerbose("Analyzing target ligand %s" % target_ligand)
            for model_id, model_ligand in enumerate(self.model_ligands):
                LogVerbose("Compare to model ligand %s" % model_ligand)

                #########################################################
                # Compute symmetries for given target/model ligand pair #
                #########################################################
                try:
                    symmetries = ComputeSymmetries(
                        model_ligand, target_ligand,
                        substructure_match=self.substructure_match,
                        by_atom_index=True,
                        max_symmetries=self.max_symmetries)
                    LogVerbose("Ligands %s and %s symmetry match" % (
                        str(model_ligand), str(target_ligand)))
                except NoSymmetryError:
                    # Ligands are different - skip
                    LogVerbose("No symmetry between %s and %s" % (
                        str(model_ligand), str(target_ligand)))
                    self._error_states[target_id, model_id] = 1
                    continue
                except TooManySymmetriesError:
                    # Ligands are too symmetrical - skip
                    LogVerbose("Too many symmetries between %s and %s" % (
                        str(model_ligand), str(target_ligand)))
                    self._error_states[target_id, model_id] = 2
                    continue
                except NoIsomorphicSymmetryError:
                    # Ligands are different - skip
                    LogVerbose("No isomorphic symmetry between %s and %s" % (
                        str(model_ligand), str(target_ligand)))
                    self._error_states[target_id, model_id] = 3
                    continue
                except DisconnectedGraphError:
                    LogVerbose("Disconnected graph observed for %s and %s" % (
                        str(model_ligand), str(target_ligand)))
                    self._error_states[target_id, model_id] = 4
                    continue

                #####################################################
                # Compute score by calling the child class _compute #
                #####################################################
                score, error_state, aux = self._compute(symmetries, target_ligand,
                                                        model_ligand)

                ############
                # Finalize #
                ############
                if error_state != 0:
                    # non-zero error states up to 9 are reserved for base class
                    if error_state <= 9:
                        raise RuntimeError("Child returned reserved err. state")

                self._error_states[target_id, model_id] = error_state
                if error_state == 0:
                    # it's a valid score!
                    self._score_matrix[target_id, model_id] = score
                    # coverage: fraction of model ligand atoms that are
                    # covered by the (first) symmetry
                    cvg = len(symmetries[0][0]) / len(model_ligand.atoms)
                    self._coverage_matrix[target_id, model_id] = cvg
                    self._aux_data[target_id, model_id] = aux

    def _compute(self, symmetries, target_ligand, model_ligand):
        """ Compute score for specified ligand pair - defined by child class

        Raises :class:`NotImplementedError` if not implemented by child class.

        :param symmetries: Defines symmetries between *target_ligand* and
                           *model_ligand*. Return value of
                           :func:`ComputeSymmetries`
        :type symmetries: :class:`list` of :class:`tuple` with two elements
                          each: 1) :class:`list` of atom indices in
                          *target_ligand* 2) :class:`list` of respective atom
                          indices in *model_ligand*
        :param target_ligand: The target ligand
        :type target_ligand: :class:`ost.mol.ResidueHandle` or
                             :class:`ost.mol.ResidueView`
        :param model_ligand: The model ligand
        :type model_ligand: :class:`ost.mol.ResidueHandle` or
                            :class:`ost.mol.ResidueView`

        :returns: A :class:`tuple` with three elements: 1) a score
                  (:class:`float`) 2) error state (:class:`int`).
                  3) auxiliary data for this ligand pair (:class:`dict`).
                  If error state is 0, the score and auxiliary data will be
                  added to :attr:`~score_matrix` and :attr:`~aux_data` as well
                  as the respective value in :attr:`~coverage_matrix`.
                  Child specific non-zero states MUST be >= 10.
        """
        raise NotImplementedError("_compute must be implemented by child class")
def _ResidueToGraph(residue, by_atom_index=False):
    """Build a NetworkX graph describing the bonded topology of a residue.

    Every atom of *residue* becomes a node carrying an ``element`` attribute
    (the uppercase :attr:`~ost.mol.AtomHandle.element`), every bond reported
    by the atoms becomes an edge.

    :param residue: the residue from which to derive the graph
    :type residue: :class:`ost.mol.ResidueHandle` or
                   :class:`ost.mol.ResidueView`
    :param by_atom_index: If True, nodes are labeled by the atom index
                          within the residue. If False, nodes are labeled
                          by atom names.
    :type by_atom_index: :class:`bool`
    :rtype: :class:`~networkx.classes.graph.Graph`
    """
    graph = networkx.Graph()

    for atom in residue.atoms:
        graph.add_node(atom.name, element=atom.element.upper())

    # Each bond is visited from both of its atoms. Adding an existing edge
    # again is a no-op in NetworkX, so no deduplication is needed.
    # NOTE(review): a bond to an atom outside of this residue would add a
    # node without "element" attribute - confirm callers never pass such
    # residues.
    bonded_pairs = []
    for atom in residue.atoms:
        for bond in atom.GetBondList():
            bonded_pairs.append((bond.first.name, bond.second.name))
    graph.add_edges_from(bonded_pairs)

    if not by_atom_index:
        return graph

    # Translate atom names to their position in residue.atoms
    name_to_index = {}
    for index, atom in enumerate(residue.atoms):
        name_to_index[atom.name] = index
    return networkx.relabel_nodes(graph, name_to_index, True)
def ComputeSymmetries(model_ligand, target_ligand, substructure_match=False,
                      by_atom_index=False, return_symmetries=True,
                      max_symmetries=1e6):
    """Return a list of symmetries (isomorphisms) of the model onto the target
    residues.

    :param model_ligand: The model ligand
    :type model_ligand: :class:`ost.mol.ResidueHandle` or
                        :class:`ost.mol.ResidueView`
    :param target_ligand: The target ligand
    :type target_ligand: :class:`ost.mol.ResidueHandle` or
                         :class:`ost.mol.ResidueView`
    :param substructure_match: Set this to True to allow partial ligands
                               in the reference.
    :type substructure_match: :class:`bool`
    :param by_atom_index: Set this parameter to True if you need the symmetries
                          to refer to atom index (within the residue).
                          Otherwise, if False, the symmetries refer to atom
                          names.
    :type by_atom_index: :class:`bool`
    :param return_symmetries: If Truthy, return the mappings, otherwise simply
                              return True if a mapping is found (and raise if
                              no mapping is found). This is useful to quickly
                              find out if a mapping exist without the expensive
                              step to find all the mappings.
    :type return_symmetries: :class:`bool`
    :param max_symmetries: If more than that many isomorphisms exist, raise
      a :class:`TooManySymmetriesError`. This can only be assessed by
      generating at least that many isomorphisms and can take some time.
    :type max_symmetries: :class:`int`
    :returns: True if *return_symmetries* is falsy. Otherwise a
              :class:`list` of :class:`tuple`, each with two elements:
              1) :class:`list` of target ligand atoms 2) :class:`list` of
              the respective model ligand atoms (indices or names,
              depending on *by_atom_index*).
    :raises: :class:`NoSymmetryError` when no symmetry can be found;
             :class:`NoIsomorphicSymmetryError` in case of isomorphic
             subgraph but *substructure_match* is False.
             :class:`TooManySymmetriesError` when more than `max_symmetries`
             isomorphisms are found.
             :class:`DisconnectedGraphError` if the graph of either ligand
             is disconnected.
    """

    # Get the Graphs of the ligands
    model_graph = _ResidueToGraph(model_ligand, by_atom_index=by_atom_index)
    target_graph = _ResidueToGraph(target_ligand, by_atom_index=by_atom_index)

    if not networkx.is_connected(model_graph):
        raise DisconnectedGraphError("Disconnected graph for model ligand %s" % model_ligand)
    if not networkx.is_connected(target_graph):
        raise DisconnectedGraphError("Disconnected graph for target ligand %s" % target_ligand)

    # Note the argument order (model, target) which differs from spyrmsd.
    # This is because a subgraph of model is isomorphic to target - but not the opposite
    # as we only consider partial ligands in the reference.
    # Make sure to generate the symmetries correctly in the end
    gm = networkx.algorithms.isomorphism.GraphMatcher(
        model_graph, target_graph, node_match=lambda x, y:
        x["element"] == y["element"])

    # Decide once which kind of match we have. The (potentially expensive)
    # subgraph test is only evaluated a single time.
    if gm.is_isomorphic():
        is_substructure = False
        isomorphisms_iter = gm.isomorphisms_iter
    elif gm.subgraph_is_isomorphic():
        if not substructure_match:
            LogDebug("Found subgraph isomorphisms (symmetries), but"
                     " ignoring because substructure_match=False")
            raise NoIsomorphicSymmetryError("No symmetry between %s and %s" % (
                str(model_ligand), str(target_ligand)))
        is_substructure = True
        isomorphisms_iter = gm.subgraph_isomorphisms_iter
    else:
        LogDebug("Found no isomorphic mappings (symmetries)")
        raise NoSymmetryError("No symmetry between %s and %s" % (
            str(model_ligand), str(target_ligand)))

    if not return_symmetries:
        return True

    symmetries = []
    for i, isomorphism in enumerate(isomorphisms_iter()):
        if i >= max_symmetries:
            raise TooManySymmetriesError(
                "Too many symmetries between %s and %s" % (
                    str(model_ligand), str(target_ligand)))
        # keys are nodes of the first graph (model), values nodes of the
        # second graph (target) => store as (target, model)
        symmetries.append((list(isomorphism.values()), list(isomorphism.keys())))
    assert len(symmetries) > 0

    if is_substructure:
        # Assert that all the atoms in the target are part of the substructure
        assert len(symmetries[0][0]) == len(target_ligand.atoms)
        LogDebug("Found %s subgraph isomorphisms (symmetries)" % len(symmetries))
    else:
        LogDebug("Found %s isomorphic mappings (symmetries)" % len(symmetries))

    return symmetries
class SCRMSDScorer(ligand_scoring_base.LigandScorer):
    """ :class:`ligand_scoring_base.LigandScorer` implementing
    symmetry-corrected RMSD.

    For a target/model ligand pair, representations of the target ligand
    binding site are searched in the model with the chain mapper. The
    transformation of each representation is applied on the model ligand
    and the lowest symmetry-corrected RMSD is reported as score.

    Parameters up to and including *max_symmetries* are passed on to the
    base class.

    :param bs_radius: Residues of the target with at least one atom within
                      this distance of a target ligand atom define the
                      binding site.
    :type bs_radius: :class:`float`
    :param lddt_lp_radius: Passed as *inclusion_radius* to the chain mapper
                           *GetRepr* call.
    :type lddt_lp_radius: :class:`float`
    :param model_bs_radius: Model chains with at least one atom within this
                            distance of a model ligand atom are considered
                            in the binding site search. Only used if
                            *full_bs_search* is False.
    :type model_bs_radius: :class:`float`
    :param binding_sites_topn: Passed as *topn* to the chain mapper
                               *GetRepr* call.
    :type binding_sites_topn: :class:`int`
    :param full_bs_search: If True, binding site representations are
                           searched in the full model, independent of the
                           model ligand location.
    :type full_bs_search: :class:`bool`
    """

    def __init__(self, model, target, model_ligands=None, target_ligands=None,
                 resnum_alignments=False, rename_ligand_chain=False,
                 substructure_match=False, coverage_delta=0.2,
                 max_symmetries=1e5, bs_radius=4.0, lddt_lp_radius=15.0,
                 model_bs_radius=25, binding_sites_topn=100000,
                 full_bs_search=False):

        # Forward the user supplied max_symmetries, not a hard-coded value
        super().__init__(model, target, model_ligands = model_ligands,
                         target_ligands = target_ligands,
                         resnum_alignments = resnum_alignments,
                         rename_ligand_chain = rename_ligand_chain,
                         substructure_match = substructure_match,
                         coverage_delta = coverage_delta,
                         max_symmetries = max_symmetries)

        self.bs_radius = bs_radius
        self.lddt_lp_radius = lddt_lp_radius
        self.model_bs_radius = model_bs_radius
        self.binding_sites_topn = binding_sites_topn
        self.full_bs_search = full_bs_search

        # Residues that are in contact with a ligand => binding site
        # defined as all residues with at least one atom within
        # self.bs_radius
        # key: ligand.handle.hash_code, value: EntityView of whatever
        # entity ligand belongs to
        self._binding_sites = dict()

        # cache for GetRepr chain mapping calls
        self._repr = dict()

        # Reasons why a target ligand could not be processed
        # key: target ligand, value: tuple (short reason, description)
        # Filled in _get_repr and _get_target_binding_site
        self._unassigned_target_ligands_reason = dict()

        # lazily precomputed variables to speedup GetRepr chain mapping calls
        # for localized GetRepr searches
        self.__chem_mapping = None
        self.__chem_group_alns = None
        self.__ref_mdl_alns = None
        self.__chain_mapping_mdl = None
        self._get_repr_input = dict()

    def _compute(self, symmetries, target_ligand, model_ligand):
        """ Implements interface from parent

        Iterates all binding site representations for the ligand pair and
        keeps the one giving the lowest symmetry-corrected RMSD.

        :returns: :class:`tuple` of 1) best RMSD (NaN if no representation
                  was found) 2) error state (0 on success, 10 if no
                  representation was found) 3) :class:`dict` with auxiliary
                  data describing the best representation.
        """

        # set default to invalid scores
        best_rmsd_result = {"rmsd": None,
                            "lddt_lp": None,
                            "bs_ref_res": list(),
                            "bs_ref_res_mapped": list(),
                            "bs_mdl_res_mapped": list(),
                            "bb_rmsd": None,
                            "target_ligand": target_ligand,
                            "model_ligand": model_ligand,
                            "chain_mapping": dict(),
                            "transform": geom.Mat4(),
                            "inconsistent_residues": list()}

        for r in self._get_repr(target_ligand, model_ligand):
            rmsd = _SCRMSD_symmetries(symmetries, model_ligand,
                                      target_ligand, transformation=r.transform)

            if best_rmsd_result["rmsd"] is None or rmsd < best_rmsd_result["rmsd"]:
                best_rmsd_result = {"rmsd": rmsd,
                                    "lddt_lp": r.lDDT,
                                    "bs_ref_res": r.substructure.residues,
                                    "bs_ref_res_mapped": r.ref_residues,
                                    "bs_mdl_res_mapped": r.mdl_residues,
                                    "bb_rmsd": r.bb_rmsd,
                                    "target_ligand": target_ligand,
                                    "model_ligand": model_ligand,
                                    "chain_mapping": r.GetFlatChainMapping(),
                                    "transform": r.transform,
                                    "inconsistent_residues": r.inconsistent_residues}

        # set default to error
        best_rmsd = np.nan
        error_state = 10

        if best_rmsd_result["rmsd"] is not None:
            # but here we save the day
            best_rmsd = best_rmsd_result["rmsd"]
            error_state = 0

        return (best_rmsd, error_state, best_rmsd_result)

    def _get_repr(self, target_ligand, model_ligand):
        """ Get (cached) binding site representations for a ligand pair

        With *full_bs_search* enabled, representations only depend on the
        target ligand and are shared between all model ligands.
        """

        key = None
        if self.full_bs_search:
            # all possible binding sites, independent from actual model ligand
            key = (target_ligand.handle.hash_code, 0)
        else:
            key = (target_ligand.handle.hash_code, model_ligand.handle.hash_code)

        if key not in self._repr:
            ref_bs = self._get_target_binding_site(target_ligand)
            if self.full_bs_search:
                reprs = self.chain_mapper.GetRepr(
                    ref_bs, self.model, inclusion_radius=self.lddt_lp_radius,
                    topn=self.binding_sites_topn)
            else:
                reprs = self.chain_mapper.GetRepr(ref_bs, self.model,
                                                  inclusion_radius=self.lddt_lp_radius,
                                                  topn=self.binding_sites_topn,
                                                  chem_mapping_result = self._get_get_repr_input(model_ligand))
            self._repr[key] = reprs
            if len(reprs) == 0:
                # whatever is in there already has precedence
                if target_ligand not in self._unassigned_target_ligands_reason:
                    self._unassigned_target_ligands_reason[target_ligand] = (
                        "model_representation",
                        "No representation of the reference binding site was "
                        "found in the model")

        return self._repr[key]

    def _get_target_binding_site(self, target_ligand):
        """ Get (cached) view on the target with all residues in proximity
        of *target_ligand*.

        Only residues that are part of the chain mapper target are added.
        The returned view is empty if there is no such residue.
        """

        bs_key = target_ligand.handle.hash_code
        if bs_key not in self._binding_sites:

            # create view of reference binding site
            ref_residues_hashes = set() # helper to keep track of added residues
            # hashes in here are compared with handle hashes below
            # => use the handle hash of the ligand itself
            ignored_residue_hashes = {target_ligand.handle.hash_code}
            for ligand_at in target_ligand.atoms:
                close_atoms = self.target.FindWithin(ligand_at.GetPos(), self.bs_radius)
                for close_at in close_atoms:
                    # Skip any residue not in the chain mapping target
                    ref_res = close_at.GetResidue()
                    h = ref_res.handle.GetHashCode()
                    if h not in ref_residues_hashes and \
                            h not in ignored_residue_hashes:
                        if self.chain_mapper.target.ViewForHandle(ref_res).IsValid():
                            ref_residues_hashes.add(h)
                        elif ref_res.is_ligand:
                            LogWarning("Ignoring ligand %s in binding site of %s" % (
                                ref_res.qualified_name, target_ligand.qualified_name))
                            ignored_residue_hashes.add(h)
                        elif ref_res.chem_type == mol.ChemType.WATERS:
                            pass # That's ok, no need to warn
                        else:
                            LogWarning("Ignoring residue %s in binding site of %s" % (
                                ref_res.qualified_name, target_ligand.qualified_name))
                            ignored_residue_hashes.add(h)

            ref_bs = self.target.CreateEmptyView()
            if ref_residues_hashes:
                # reason for doing that separately is to guarantee same ordering of
                # residues as in underlying entity. (Reorder by ResNum seems only
                # available on ChainHandles)
                for ch in self.target.chains:
                    for r in ch.residues:
                        if r.handle.GetHashCode() in ref_residues_hashes:
                            ref_bs.AddResidue(r, mol.ViewAddFlag.INCLUDE_ALL)
                if len(ref_bs.residues) == 0:
                    raise RuntimeError("Failed to add proximity residues to "
                                       "the reference binding site entity")
            else:
                # Flag missing binding site
                self._unassigned_target_ligands_reason[target_ligand] = ("binding_site",
                    "No residue in proximity of the target ligand")

            self._binding_sites[bs_key] = ref_bs

        return self._binding_sites[bs_key]

    def _compute_chem_mapping(self):
        """ Run chain mapper chem mapping on the model and cache all three
        return values at once."""
        self.__chem_mapping, self.__chem_group_alns, \
        self.__chain_mapping_mdl = \
        self.chain_mapper.GetChemMapping(self.model)

    @property
    def _chem_mapping(self):
        """ Lazily computed first return value of *GetChemMapping*"""
        if self.__chem_mapping is None:
            self._compute_chem_mapping()
        return self.__chem_mapping

    @property
    def _chem_group_alns(self):
        """ Lazily computed second return value of *GetChemMapping*"""
        if self.__chem_group_alns is None:
            self._compute_chem_mapping()
        return self.__chem_group_alns

    @property
    def _ref_mdl_alns(self):
        """ Lazily computed return value of *chain_mapping._GetRefMdlAlns*"""
        if self.__ref_mdl_alns is None:
            # local import: chain_mapping is not imported at module level
            from ost.mol.alg import chain_mapping
            self.__ref_mdl_alns = \
            chain_mapping._GetRefMdlAlns(self.chain_mapper.chem_groups,
                                         self.chain_mapper.chem_group_alignments,
                                         self._chem_mapping,
                                         self._chem_group_alns)
        return self.__ref_mdl_alns

    @property
    def _chain_mapping_mdl(self):
        """ Lazily computed third return value of *GetChemMapping*"""
        if self.__chain_mapping_mdl is None:
            self._compute_chem_mapping()
        return self.__chain_mapping_mdl

    def _get_get_repr_input(self, mdl_ligand):
        """ Get (cached) input for a localized *GetRepr* call

        :returns: :class:`tuple` of 1) chem mapping reduced to the model
                  chains close to *mdl_ligand* 2) chem group alignments
                  3) view of the chain mapping model that only contains
                  the close chains.
        """
        # Use one single key for storing and for lookup
        lig_key = mdl_ligand.handle.hash_code
        if lig_key not in self._get_repr_input:

            # figure out what chains in the model are in contact with the ligand
            # that may give a non-zero contribution to lDDT in
            # chain_mapper.GetRepr
            radius = self.model_bs_radius
            chains = set()
            for at in mdl_ligand.atoms:
                close_atoms = self._chain_mapping_mdl.FindWithin(at.GetPos(),
                                                                 radius)
                for close_at in close_atoms:
                    chains.add(close_at.GetChain().GetName())

            if len(chains) > 0:

                # the chain mapping model which only contains close chains
                query = "cname="
                query += ','.join([mol.QueryQuoteName(x) for x in chains])
                mdl = self._chain_mapping_mdl.Select(query)

                # chem mapping which is reduced to the respective chains
                chem_mapping = list()
                for m in self._chem_mapping:
                    chem_mapping.append([x for x in m if x in chains])

                self._get_repr_input[lig_key] = (mdl, chem_mapping)

            else:
                self._get_repr_input[lig_key] = \
                (self._chain_mapping_mdl.CreateEmptyView(),
                 [list() for _ in self._chem_mapping])

        return (self._get_repr_input[lig_key][1],
                self._chem_group_alns,
                self._get_repr_input[lig_key][0])
def _SCRMSD_symmetries(symmetries, model_ligand, target_ligand,
                       transformation):
    """Return the lowest RMSD over all pre-computed symmetries. Internal.

    *transformation* is applied on the model ligand atom positions before
    they are compared to the target ligand atom positions. Each item in
    *symmetries* is a pair (target atom indices, model atom indices).
    Returns inf if *symmetries* is empty.
    """

    # 4x4 transformation as numpy matrix
    transform_np = np.zeros((4, 4))
    for row in range(4):
        for col in range(4):
            transform_np[row, col] = transformation[row, col]

    # model ligand positions in homogeneous coordinates (w = 1), transformed
    mdl_xyz = np.ones((model_ligand.GetAtomCount(), 4))
    for atom_idx, atom in enumerate(model_ligand.atoms):
        pos = atom.GetPos()
        mdl_xyz[atom_idx, :3] = (pos[0], pos[1], pos[2])
    mdl_xyz = mdl_xyz.dot(transform_np.T)[:, :3]

    # target ligand positions, used as they are
    n_target_atoms = target_ligand.GetAtomCount()
    trg_xyz = np.zeros((n_target_atoms, 3))
    for atom_idx, atom in enumerate(target_ligand.atoms):
        pos = atom.GetPos()
        trg_xyz[atom_idx, :] = (pos[0], pos[1], pos[2])

    # Reusable buffers holding the paired positions of one symmetry.
    # It is guaranteed that the target ligand has at most as many atoms as
    # the model ligand and that each target ligand atom is part of every
    # symmetry => the target atom count is the size of both buffers.
    paired_mdl = np.zeros((n_target_atoms, 3))
    paired_trg = np.zeros((n_target_atoms, 3))

    lowest_rmsd = np.inf
    for trg_sym, mdl_sym in symmetries:
        n_pairs = min(len(mdl_sym), len(trg_sym))
        paired_mdl[:n_pairs] = mdl_xyz[list(mdl_sym)[:n_pairs]]
        paired_trg[:n_pairs] = trg_xyz[list(trg_sym)[:n_pairs]]
        squared_dist = ((paired_mdl - paired_trg)**2).sum(-1)
        rmsd = np.sqrt(squared_dist.mean())
        lowest_rmsd = min(lowest_rmsd, rmsd)

    return lowest_rmsd
def _GetTestfilePath(filename):
    """Get the path to the test file given filename"""
    return os.path.join('testfiles', filename)


@lru_cache(maxsize=None)
def _LoadMMCIF(filename):
    """Load an mmCIF test file; the result is cached per filename."""
    return io.LoadMMCIF(_GetTestfilePath(filename))


@lru_cache(maxsize=None)
def _LoadPDB(filename):
    """Load a PDB test file; the result is cached per filename."""
    return io.LoadPDB(_GetTestfilePath(filename))


@lru_cache(maxsize=None)
def _LoadEntity(filename):
    """Load a test file with io.LoadEntity; cached per filename."""
    return io.LoadEntity(_GetTestfilePath(filename))


def _CountLigandResidues(ent):
    """Return the number of residues in *ent* with the is_ligand flag set."""
    return len([r for r in ent.residues if r.is_ligand])


class TestLigandScoringFancy(unittest.TestCase):
    """Tests for ligand_scoring_base and ligand_scoring_scrmsd."""

    def setUp(self):
        # Silence expected warnings about ignoring of ligands in binding site
        ost.PushVerbosityLevel(ost.LogLevel.Error)

    def tearDown(self):
        ost.PopVerbosityLevel()

    def test_extract_ligands_mmCIF(self):
        """Test that we can extract ligands from mmCIF files.
        """
        trg = _LoadMMCIF("1r8q.cif.gz")
        mdl = _LoadMMCIF("P84080_model_02.cif.gz")

        sc = LigandScorer(mdl, trg, None, None)

        self.assertEqual(len(sc.target_ligands), 7)
        self.assertEqual(len(sc.model_ligands), 1)
        self.assertEqual(_CountLigandResidues(sc.target), 7)
        self.assertEqual(_CountLigandResidues(sc.model), 1)

    def test_extract_ligands_PDB(self):
        """Test that we can extract ligands from PDB files containing HET
        records.
        """
        trg = _LoadPDB("1R8Q.pdb")
        mdl = _LoadMMCIF("P84080_model_02.cif.gz")

        sc = LigandScorer(mdl, trg, None, None)

        self.assertEqual(len(sc.target_ligands), 7)
        self.assertEqual(len(sc.model_ligands), 1)
        self.assertEqual(_CountLigandResidues(sc.target), 7)
        self.assertEqual(_CountLigandResidues(sc.model), 1)

    def test_init_given_ligands(self):
        """Test that we can instantiate the scorer with ligands contained in
        the target and model entity and given in a list.
        """
        trg = _LoadMMCIF("1r8q.cif.gz")
        mdl = _LoadMMCIF("P84080_model_02.cif.gz")

        # Pass entity views
        trg_lig = [trg.Select("rname=MG"), trg.Select("rname=G3D")]
        mdl_lig = [mdl.Select("rname=G3D")]
        sc = LigandScorer(mdl, trg, mdl_lig, trg_lig)

        self.assertEqual(len(sc.target_ligands), 4)
        self.assertEqual(len(sc.model_ligands), 1)
        # IsLigand flag should still be set even on not selected ligands
        self.assertEqual(_CountLigandResidues(sc.target), 7)
        self.assertEqual(_CountLigandResidues(sc.model), 1)

        # Ensure the residues are not copied
        self.assertEqual(len(sc.target.Select("rname=MG").residues), 2)
        self.assertEqual(len(sc.target.Select("rname=G3D").residues), 2)
        self.assertEqual(len(sc.model.Select("rname=G3D").residues), 1)

        # Pass residue handles
        trg_lig = [trg.FindResidue("F", 1), trg.FindResidue("H", 1)]
        mdl_lig = [mdl.FindResidue("L_2", 1)]
        sc = LigandScorer(mdl, trg, mdl_lig, trg_lig)

        self.assertEqual(len(sc.target_ligands), 2)
        self.assertEqual(len(sc.model_ligands), 1)

        # Ensure the residues are not copied
        self.assertEqual(len(sc.target.Select("rname=ZN").residues), 1)
        self.assertEqual(len(sc.target.Select("rname=G3D").residues), 2)
        self.assertEqual(len(sc.model.Select("rname=G3D").residues), 1)

    def test_init_sdf_ligands(self):
        """Test that we can instantiate the scorer with ligands from separate
        SDF files.

        In order to setup the ligand SDF files, the following code was used:
        for prefix in [os.path.join('testfiles', x) for x in ["1r8q", "P84080_model_02"]]:
            trg = io.LoadMMCIF("%s.cif.gz" % prefix)
            trg_prot = trg.Select("protein=True")
            io.SavePDB(trg_prot, "%s_protein.pdb.gz" % prefix)
            lig_num = 0
            for chain in trg.chains:
                if chain.chain_type == mol.ChainType.CHAINTYPE_NON_POLY:
                    lig_sel = trg.Select("cname=%s" % chain.name)
                    lig_ent = mol.CreateEntityFromView(lig_sel, False)
                    io.SaveEntity(lig_ent, "%s_ligand_%d.sdf" % (prefix, lig_num))
                    lig_num += 1
        """
        mdl = _LoadPDB("P84080_model_02_nolig.pdb")
        mdl_ligs = [_LoadEntity("P84080_model_02_ligand_0.sdf")]
        trg = _LoadPDB("1r8q_protein.pdb.gz")
        trg_ligs = [_LoadEntity("1r8q_ligand_%d.sdf" % i) for i in range(7)]

        # Pass entities
        sc = LigandScorer(mdl, trg, mdl_ligs, trg_ligs)

        self.assertEqual(len(sc.target_ligands), 7)
        self.assertEqual(len(sc.model_ligands), 1)
        # Ensure we set the is_ligand flag
        self.assertEqual(_CountLigandResidues(sc.target), 7)
        self.assertEqual(_CountLigandResidues(sc.model), 1)

        # Pass residues
        mdl_ligs_res = [mdl_ligs[0].residues[0]]
        trg_ligs_res = [res for ent in trg_ligs for res in ent.residues]

        sc = LigandScorer(mdl, trg, mdl_ligs_res, trg_ligs_res)

        self.assertEqual(len(sc.target_ligands), 7)
        self.assertEqual(len(sc.model_ligands), 1)

    def test_init_reject_duplicate_ligands(self):
        """Test that we reject input if multiple ligands with the same chain
        name/residue number are given.
        """
        mdl = _LoadPDB("P84080_model_02_nolig.pdb")
        mdl_ligs = [_LoadEntity("P84080_model_02_ligand_0.sdf")]
        trg = _LoadPDB("1r8q_protein.pdb.gz")
        trg_ligs = [_LoadEntity("1r8q_ligand_%d.sdf" % i) for i in range(7)]

        # Reject identical model ligands
        with self.assertRaises(RuntimeError):
            LigandScorer(mdl, trg, [mdl_ligs[0], mdl_ligs[0]], trg_ligs)

        # Reject identical target ligands. Work on copies so that the cached
        # entities returned by _LoadEntity are not modified.
        lig0 = trg_ligs[0].Copy()
        lig1 = trg_ligs[1].Copy()
        ed1 = lig1.EditXCS()
        ed1.RenameChain(lig1.chains[0], lig0.chains[0].name)
        ed1.SetResidueNumber(lig1.residues[0], lig0.residues[0].number)
        with self.assertRaises(RuntimeError):
            LigandScorer(mdl, trg, mdl_ligs, [lig0, lig1])

    def test__ResidueToGraph(self):
        """Test that _ResidueToGraph works as expected
        """
        mdl_lig = _LoadEntity("P84080_model_02_ligand_0.sdf")

        graph = ligand_scoring_base._ResidueToGraph(mdl_lig.residues[0])
        self.assertEqual(len(graph.edges), 34)
        self.assertEqual(len(graph.nodes), 32)
        # Check an arbitrary node
        self.assertEqual(list(graph.adj["14"]), ["13", "29"])

        graph = ligand_scoring_base._ResidueToGraph(mdl_lig.residues[0],
                                                    by_atom_index=True)
        self.assertEqual(len(graph.edges), 34)
        self.assertEqual(len(graph.nodes), 32)
        # Check an arbitrary node
        self.assertEqual(list(graph.adj[13]), [12, 28])

    def test__ComputeSymmetries(self):
        """Test that ComputeSymmetries works.
        """
        trg = _LoadMMCIF("1r8q.cif.gz")
        mdl = _LoadMMCIF("P84080_model_02.cif.gz")

        trg_mg1 = trg.FindResidue("E", 1)
        trg_g3d1 = trg.FindResidue("F", 1)
        mdl_g3d = mdl.FindResidue("L_2", 1)

        sym = ligand_scoring_base.ComputeSymmetries(mdl_g3d, trg_g3d1)
        self.assertEqual(len(sym), 72)

        sym = ligand_scoring_base.ComputeSymmetries(mdl_g3d, trg_g3d1,
                                                    by_atom_index=True)
        self.assertEqual(len(sym), 72)

        # Test that we can match ions read from SDF
        sdf_lig = _LoadEntity("1r8q_ligand_0.sdf")
        sym = ligand_scoring_base.ComputeSymmetries(trg_mg1,
                                                    sdf_lig.residues[0],
                                                    by_atom_index=True)
        self.assertEqual(len(sym), 1)

        # Test that it works with views and only consider atoms in the view
        # Skip PA, PB and O[1-3]A and O[1-3]B in target and model
        # We assume atom index are fixed and won't change
        trg_g3d1_sub_ent = trg_g3d1.Select("aindex>6019")
        trg_g3d1_sub = trg_g3d1_sub_ent.residues[0]
        mdl_g3d_sub_ent = mdl_g3d.Select("aindex>1447")
        mdl_g3d_sub = mdl_g3d_sub_ent.residues[0]

        sym = ligand_scoring_base.ComputeSymmetries(mdl_g3d_sub, trg_g3d1_sub)
        self.assertEqual(len(sym), 6)

        sym = ligand_scoring_base.ComputeSymmetries(mdl_g3d_sub, trg_g3d1_sub,
                                                    by_atom_index=True)
        self.assertEqual(len(sym), 6)

        # Substructure matches
        sym = ligand_scoring_base.ComputeSymmetries(mdl_g3d, trg_g3d1_sub,
                                                    substructure_match=True)
        self.assertEqual(len(sym), 6)

        # Missing atoms only allowed in target, not in model
        with self.assertRaises(NoSymmetryError):
            ligand_scoring_base.ComputeSymmetries(mdl_g3d_sub, trg_g3d1,
                                                  substructure_match=True)

    def test_SCRMSD(self):
        """Test that SCRMSD works.
        """
        trg = _LoadMMCIF("1r8q.cif.gz")
        mdl = _LoadMMCIF("P84080_model_02.cif.gz")

        trg_mg1 = trg.FindResidue("E", 1)
        trg_g3d1 = trg.FindResidue("F", 1)
        trg_afb1 = trg.FindResidue("G", 1)
        trg_g3d2 = trg.FindResidue("J", 1)
        mdl_g3d = mdl.FindResidue("L_2", 1)

        rmsd = ligand_scoring_scrmsd.SCRMSD(mdl_g3d, trg_g3d1)
        self.assertAlmostEqual(rmsd, 2.21341e-06, 10)
        rmsd = ligand_scoring_scrmsd.SCRMSD(mdl_g3d, trg_g3d2)
        self.assertAlmostEqual(rmsd, 61.21325, 4)

        # Ensure we raise a NoSymmetryError if the ligand is wrong
        with self.assertRaises(NoSymmetryError):
            ligand_scoring_scrmsd.SCRMSD(mdl_g3d, trg_mg1)
        with self.assertRaises(NoSymmetryError):
            ligand_scoring_scrmsd.SCRMSD(mdl_g3d, trg_afb1)

        # Assert that transform works
        trans = geom.Mat4(-0.999256, 0.00788487, -0.0377333, -15.4397,
                          0.0380652, 0.0473315, -0.998154, 29.9477,
                          -0.00608426, -0.998848, -0.0475963, 28.8251,
                          0, 0, 0, 1)
        rmsd = ligand_scoring_scrmsd.SCRMSD(mdl_g3d, trg_g3d2,
                                            transformation=trans)
        self.assertAlmostEqual(rmsd, 0.293972, 5)

        # Assert that substructure matches work
        # Skip PA, PB and O[1-3]A and O[1-3]B.
        trg_g3d1_sub = trg_g3d1.Select("aindex>6019").residues[0]
        with self.assertRaises(NoIsomorphicSymmetryError):
            ligand_scoring_scrmsd.SCRMSD(mdl_g3d, trg_g3d1_sub)  # no full match

        # But partial match is OK
        rmsd = ligand_scoring_scrmsd.SCRMSD(mdl_g3d, trg_g3d1_sub,
                                            substructure_match=True)
        self.assertAlmostEqual(rmsd, 2.2376232209353475e-06, 8)

        # Ensure it doesn't work the other way around - ie incomplete model
        # is invalid
        with self.assertRaises(NoSymmetryError):
            ligand_scoring_scrmsd.SCRMSD(trg_g3d1_sub, mdl_g3d)  # no full match

    def test_compute_rmsd_scores(self):
        """Test that the score matrix of SCRMSDScorer is as expected.
        """
        trg = _LoadMMCIF("1r8q.cif.gz")
        mdl = _LoadMMCIF("P84080_model_02.cif.gz")
        # Loaded without the cache on purpose of keeping the original
        # behaviour: this test gets its own ligand entity.
        mdl_lig = io.LoadEntity(_GetTestfilePath("P84080_model_02_ligand_0.sdf"))
        sc = ligand_scoring_scrmsd.SCRMSDScorer(mdl, trg, [mdl_lig], None)

        # Note: expect warning about Binding site of H.ZN1 not mapped to the
        # model
        self.assertEqual(sc.score_matrix.shape, (7, 1))
        np.testing.assert_almost_equal(sc.score_matrix, np.array(
            [[np.nan],
             [0.04244993],
             [np.nan],
             [np.nan],
             [np.nan],
             [0.29399303],
             [np.nan]]), decimal=5)


if __name__ == "__main__":
    from ost import testutils
    if testutils.DefaultCompoundLibIsSet():
        testutils.RunTests()
    else:
        print('No compound lib available. '
              'Ignoring test_ligand_scoring_fancy.py tests.')