From 071a5a6ef62354fb06c6560c35fcee42218d2e34 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Tue, 28 May 2024 14:53:11 +0200
Subject: [PATCH] ligand scoring: refactor ligand assignment functionality

---
 modules/mol/alg/pymod/ligand_scoring_base.py  | 65 +++++++++++++++++++
 .../mol/alg/pymod/ligand_scoring_lddtpli.py   |  3 +
 .../mol/alg/pymod/ligand_scoring_scrmsd.py    |  3 +
 .../alg/tests/test_ligand_scoring_fancy.py    | 13 ++++
 4 files changed, 84 insertions(+)

diff --git a/modules/mol/alg/pymod/ligand_scoring_base.py b/modules/mol/alg/pymod/ligand_scoring_base.py
index 396c748e1..ac7c9b18c 100644
--- a/modules/mol/alg/pymod/ligand_scoring_base.py
+++ b/modules/mol/alg/pymod/ligand_scoring_base.py
@@ -67,6 +67,8 @@ class LigandScorer:
         self._coverage_matrix = None
         self._aux_data = None
 
+        self._assignment = None
+
     @property
     def states(self):
         """ Encodes states of ligand pairs
@@ -145,6 +147,60 @@ class LigandScorer:
             self._compute_scores()
         return self._aux_matrix
 
+    @property
+    def assignment(self):
+        """ Ligand assignment based on computed scores
+
+        Implements a greedy algorithm to assign target and model ligands
+        with each other. Starts from each valid ligand pair as indicated
+        by a state of 0 in :attr:`states`. Each iteration first selects
+        high coverage pairs. Given max_coverage defined as the highest
+        coverage observed in the available pairs, all pairs with coverage
+        in [max_coverage-*coverage_delta*, max_coverage] are selected.
+        The best scoring pair among those is added to the assignment
+        and the whole process is repeated until there are no ligands to
+        assign anymore. The result is computed once and then cached.
+
+        :rtype: :class:`list` of :class:`tuple` (trg_lig_idx, mdl_lig_idx)
+        :raises: :class:`RuntimeError` if :meth:`_score_dir` returns
+                 something else than '+' or '-'
+        """
+        if self._assignment is None:
+            # check score direction before any work; '+' => higher is better
+            score_dir = self._score_dir()
+            if score_dir not in ('+', '-'):
+                raise RuntimeError("LigandScorer._score_dir must return one "
+                                   "in ['+', '-']")
+
+            # Build working array that contains tuples for all mdl/trg ligand
+            # pairs with valid score as indicated by a state of 0:
+            # (score, coverage, trg_ligand_idx, mdl_ligand_idx)
+            tmp = list()
+            for trg_idx in range(self.score_matrix.shape[0]):
+                for mdl_idx in range(self.score_matrix.shape[1]):
+                    if self.states[trg_idx, mdl_idx] == 0:
+                        tmp.append((self.score_matrix[trg_idx, mdl_idx],
+                                    self.coverage_matrix[trg_idx, mdl_idx],
+                                    trg_idx, mdl_idx))
+
+            # sort by score, such that best scoring item is in front
+            tmp.sort(reverse=(score_dir == '+'))
+            assignment = list()
+            while len(tmp) > 0:
+                # select high coverage ligand pairs in working array
+                coverage_thresh = max(x[1] for x in tmp) - self.coverage_delta
+                top_coverage = [x for x in tmp if x[1] >= coverage_thresh]
+                # working array is sorted by score => just pick first one
+                _, _, a, b = top_coverage[0] # selected trg/mdl ligand idx
+                assignment.append((a, b))
+                # kick out remaining pairs involving these ligands
+                tmp = [x for x in tmp if (x[2] != a and x[3] != b)]
+            # store only when complete: errors must not cache a partial result
+            self._assignment = assignment
+
+        return self._assignment
+
+
     @property
     def chain_mapper(self):
         """ Chain mapper object for the given :attr:`target`.
@@ -398,6 +454,15 @@ class LigandScorer:
         """
         raise NotImplementedError("_compute must be implemented by child class")
 
+    def _score_dir(self):
+        """ Score direction used for ligand assignment - set by child class
+
+        Must return the string '+' if higher scores are better (lddt etc.)
+        or the string '-' if lower scores are better (rmsd etc.).
+        """
+        msg = "_score_dir must be implemented by child class"
+        raise NotImplementedError(msg)
+
 
 def _ResidueToGraph(residue, by_atom_index=False):
     """Return a NetworkX graph representation of the residue.
diff --git a/modules/mol/alg/pymod/ligand_scoring_lddtpli.py b/modules/mol/alg/pymod/ligand_scoring_lddtpli.py
index 6f4d88898..5457a06c2 100644
--- a/modules/mol/alg/pymod/ligand_scoring_lddtpli.py
+++ b/modules/mol/alg/pymod/ligand_scoring_lddtpli.py
@@ -68,6 +68,9 @@ class LDDTPLIScorer(ligand_scoring_base.LigandScorer):
 
         return (score, state, result)
 
+    def _score_dir(self):
+        return '+'  # higher score is better => best pair sorted to front
+
     def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand,
                                            model_ligand):
 
diff --git a/modules/mol/alg/pymod/ligand_scoring_scrmsd.py b/modules/mol/alg/pymod/ligand_scoring_scrmsd.py
index daa46bc8f..134821de9 100644
--- a/modules/mol/alg/pymod/ligand_scoring_scrmsd.py
+++ b/modules/mol/alg/pymod/ligand_scoring_scrmsd.py
@@ -90,6 +90,9 @@ class SCRMSDScorer(ligand_scoring_base.LigandScorer):
 
         return (best_rmsd, error_state, best_rmsd_result)
 
+    def _score_dir(self):
+        return '-'  # lower score (rmsd) is better
+
     def _get_repr(self, target_ligand, model_ligand):
 
         key = None
diff --git a/modules/mol/alg/tests/test_ligand_scoring_fancy.py b/modules/mol/alg/tests/test_ligand_scoring_fancy.py
index a13a37be2..8e2a14343 100644
--- a/modules/mol/alg/tests/test_ligand_scoring_fancy.py
+++ b/modules/mol/alg/tests/test_ligand_scoring_fancy.py
@@ -424,6 +424,19 @@ class TestLigandScoringFancy(unittest.TestCase):
                                B_count/(A_count + B_count - TRP66_count + \
                                lig.GetAtomCount()), 5)
 
+    def test_assignment(self):
+        # Check the ligand assignment of both scorers for the same
+        # model/target pair. Expected values are lists of
+        # (trg_lig_idx, mdl_lig_idx) tuples.
+        trg = _LoadMMCIF("1r8q.cif.gz")
+        mdl = _LoadMMCIF("P84080_model_02.cif.gz")
+        sc = ligand_scoring_scrmsd.SCRMSDScorer(mdl, trg)
+        self.assertEqual(sc.assignment, [(1, 0)])
+
+        sc = ligand_scoring_lddtpli.LDDTPLIScorer(mdl, trg)
+        self.assertEqual(sc.assignment, [(5, 0)])
+
+
 
 if __name__ == "__main__":
     from ost import testutils
-- 
GitLab