From c3684ce5293cf584f2dd8e599752ea1027aeab96 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Mon, 22 Apr 2024 09:29:47 +0200
Subject: [PATCH] ligand scoring: speedups (triggers a unit test, to be fixed
 in later commits)

LigandScorer uses ChainMapper.GetRepr to match binding sites between target
and model. This commit limits the search space for model binding sites to
locations where there actually is a ligand.
---
 modules/mol/alg/pymod/ligand_scoring.py | 100 +++++++++++++++++++++---
 1 file changed, 88 insertions(+), 12 deletions(-)

diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py
index 07ebe9a15..7ba49504a 100644
--- a/modules/mol/alg/pymod/ligand_scoring.py
+++ b/modules/mol/alg/pymod/ligand_scoring.py
@@ -333,6 +333,12 @@ class LigandScorer:
         self._binding_sites = {}
         self.__model_mapping = None
 
+        # lazily precomputed variables to speedup GetRepr chain mapping calls
+        self._chem_mapping = None
+        self._chem_group_alns = None
+        self._chain_mapping_mdl = None
+        self._get_repr_input = dict()
+
         # Bookkeeping of unassigned ligands
         self._unassigned_target_ligands = None
         self._unassigned_model_ligands = None
@@ -551,23 +557,32 @@ class LigandScorer:
                     raise RuntimeError("Failed to add proximity residues to "
                                        "the reference binding site entity")
 
-                # Find the representations
-                if self.global_chain_mapping:
-                    self._binding_sites[ligand.hash_code] = self.chain_mapper.GetRepr(
-                        ref_bs, self.model, inclusion_radius=self.lddt_lp_radius,
-                        global_mapping = self._model_mapping)
-                else:
-                    self._binding_sites[ligand.hash_code] = self.chain_mapper.GetRepr(
-                        ref_bs, self.model, inclusion_radius=self.lddt_lp_radius,
-                        topn=self.binding_sites_topn)
-
+                reprs = list()
+                for model_ligand in self.model_ligands:
+                    # Find the representations
+                    if self.global_chain_mapping:
+                        reprs.extend(self.chain_mapper.GetRepr(ref_bs, self.model,
+                                                               inclusion_radius=self.lddt_lp_radius,
+                                                               chem_mapping_result = self.get_get_repr_input(model_ligand),
+                                                               global_mapping = self._model_mapping))
+                    else:
+                        reprs.extend(self.chain_mapper.GetRepr(ref_bs, self.model,
+                                                               inclusion_radius=self.lddt_lp_radius,
+                                                               topn=self.binding_sites_topn,
+                                                               chem_mapping_result = self.get_get_repr_input(model_ligand)))
+
+                ################################################
+                # TODO: sort by lDDT and ensure unique results #
+                ################################################
+                
                 # Flag empty representation
+                self._binding_sites[ligand.hash_code] = reprs
                 if not self._binding_sites[ligand.hash_code]:
                     self._unassigned_target_ligands_reason[ligand] = (
                         "model_representation",
                         "No representation of the reference binding site was "
                         "found in the model")
-
+                
             else:  # if ref_residues_hashes
                 # Flag missing binding site
                 self._unassigned_target_ligands_reason[ligand] = ("binding_site",
@@ -1697,6 +1712,68 @@ class LigandScorer:
             iso = "full graph isomorphism"
         return ("identity", "Ligand was not found in the model (by %s)" % iso)
 
+    @property
+    def chem_mapping(self):
+        if self._chem_mapping is None:
+            self._chem_mapping, self._chem_group_alns, \
+            self._chain_mapping_mdl = \
+            self.chain_mapper.GetChemMapping(self.model)
+        return self._chem_mapping
+
+    @property
+    def chem_group_alns(self):
+        if self._chem_group_alns is None:   
+            self._chem_mapping, self._chem_group_alns, \
+            self._chain_mapping_mdl = \
+            self.chain_mapper.GetChemMapping(self.model)
+        return self._chem_group_alns
+  
+    @property
+    def chain_mapping_mdl(self):
+        if self._chain_mapping_mdl is None:   
+            self._chem_mapping, self._chem_group_alns, \
+            self._chain_mapping_mdl = \
+            self.chain_mapper.GetChemMapping(self.model)
+        return self._chain_mapping_mdl
+
+    def get_get_repr_input(self, mdl_ligand):
+        if mdl_ligand.handle.hash_code not in self._get_repr_input:
+
+            # figure out what chains in the model are in contact with the ligand
+            # that may give a non-zero contribution to lDDT in
+            # chain_mapper.GetRepr
+            radius = self.lddt_lp_radius + 4.0
+            chains = set()
+            for at in mdl_ligand.atoms:
+                close_atoms = self.chain_mapping_mdl.FindWithin(at.GetPos(),
+                                                                radius)
+                for close_at in close_atoms:
+                    chains.add(close_at.GetChain().GetName())
+
+            if len(chains) > 0:
+
+                # the chain mapping model which only contains close chains
+                query = "cname="
+                query += ','.join([mol.QueryQuoteName(x) for x in chains])
+                mdl = self.chain_mapping_mdl.Select(query)
+
+                # chem mapping which is reduced to the respective chains
+                chem_mapping = list()
+                for m in self.chem_mapping:
+                    chem_mapping.append([x for x in m if x in chains]) 
+
+                self._get_repr_input[mdl_ligand.handle.hash_code] = \
+                (mdl, chem_mapping)
+
+            else:
+                self._get_repr_input[mdl_ligand.handle.hash_code] = \
+                (self.chain_mapping_mdl.CreateEmptyView(),
+                 [list() for _ in self.chem_mapping])
+
+        return (self._get_repr_input[mdl_ligand.hash_code][1],
+                self.chem_group_alns,
+                self._get_repr_input[mdl_ligand.hash_code][0])
+
 
 def _ResidueToGraph(residue, by_atom_index=False):
     """Return a NetworkX graph representation of the residue.
@@ -1898,7 +1975,6 @@ def _ComputeSymmetries(model_ligand, target_ligand, substructure_match=False,
 
     return symmetries
 
-
 class NoSymmetryError(ValueError):
     """Exception raised when no symmetry can be found.
     """
-- 
GitLab