From fa58ebaefe6bcfe9a7a99642a76a4fc810a08845 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Wed, 19 Jul 2023 08:52:15 +0200
Subject: [PATCH] Chain Mapping: QSscore mapping code cleanup and speedups for
 large structures

---
 modules/mol/alg/pymod/chain_mapping.py | 172 +++++++++++++++----------
 1 file changed, 101 insertions(+), 71 deletions(-)

diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py
index c4a017f3d..d9d3c96d9 100644
--- a/modules/mol/alg/pymod/chain_mapping.py
+++ b/modules/mol/alg/pymod/chain_mapping.py
@@ -19,6 +19,8 @@ from ost import geom
 from ost.mol.alg import lddt
 from ost.mol.alg import qsscore
 
+import time
+
 def _CSel(ent, cnames):
     """ Returns view with specified chains
 
@@ -2710,14 +2712,7 @@ class _QSScoreGreedySearcher(qsscore.QSScorer):
         for ref_ch in mapping.keys():
             map_targets.discard(ref_ch)
 
-        # same for model
-        mdl_map_targets = set()
-        for mdl_ch in mapping.values():
-            mdl_map_targets.update(self.mdl_neighbors[mdl_ch])
-        for mdl_ch in mapping.values():
-            mdl_map_targets.discard(mdl_ch)
-
-        if len(map_targets) == 0 or len(mdl_map_targets) == 0:
+        if len(map_targets) == 0:
             return mapping # nothing to extend
 
         # keep track of what model chains are not yet mapped for each chem group
@@ -2741,19 +2736,49 @@ class _QSScoreGreedySearcher(qsscore.QSScorer):
             if max_ext is not None and len(newly_mapped_ref_chains) >= max_ext:
                 break
 
+            score_result = self.FromFlatMapping(mapping)
+            old_score = score_result.QS_global
+            nominator = score_result.weighted_scores
+            denominator = score_result.weight_sum + score_result.weight_extra_all
+
+            max_diff = 0.0
             max_mapping = None
-            best_score = self.FromFlatMapping(mapping).QS_global
             for ref_ch in map_targets:
                 chem_group_idx = self.ref_ch_group_mapper[ref_ch]
                 for mdl_ch in free_mdl_chains[chem_group_idx]:
-                    if mdl_ch in mdl_map_targets:
-                        new_mapping = dict(mapping)
-                        new_mapping[ref_ch] = mdl_ch
-                        new_score = self.FromFlatMapping(new_mapping).QS_global
-                        if new_score > best_score:
-                            best_score = new_score
+                    # we're not computing full QS-score here, we directly hack
+                    # into the QS-score formula to compute a diff
+                    nominator_diff = 0.0
+                    denominator_diff = 0.0
+                    for neighbor in self.neighbors[ref_ch]:
+                        if neighbor in mapping and mapping[neighbor] in \
+                        self.mdl_neighbors[mdl_ch]:
+                            # it's a newly added interface if (ref_ch, mdl_ch)
+                            # are added to mapping
+                            int1 = (ref_ch, neighbor)
+                            int2 = (mdl_ch, mapping[neighbor])
+                            a, b, c, d = self._MappedInterfaceScores(int1, int2)
+                            nominator_diff += a # weighted_scores
+                            denominator_diff += b # weight_sum
+                            denominator_diff += d # weight_extra_all
+                            # the respective interface penalties are subtracted
+                            # from denominator
+                            denominator_diff -= self._InterfacePenalty1(int1)
+                            denominator_diff -= self._InterfacePenalty2(int2)
+
+                    if nominator_diff > 0:
+                        # Only accept a new solution if its actually connected
+                        # i.e. nominator_diff > 0.
+                        new_nominator = nominator + nominator_diff
+                        new_denominator = denominator + denominator_diff
+                        new_score = 0.0
+                        if new_denominator != 0.0:
+                            new_score = new_nominator/new_denominator
+                        diff = new_score - old_score
+                        if diff > max_diff:
+                            max_diff = diff
                             max_mapping = (ref_ch, mdl_ch)
-
+     
             if max_mapping is not None:
                 something_happened = True
                 # assign new found mapping
@@ -2765,13 +2790,8 @@ class _QSScoreGreedySearcher(qsscore.QSScorer):
                     if neighbor not in mapping:
                         map_targets.add(neighbor)
 
-                for neighbor in self.mdl_neighbors[max_mapping[1]]:
-                    if neighbor not in mapping.values():
-                        mdl_map_targets.add(neighbor)
-
-                # remove chains from map targets
+                # remove the ref chain from map targets
                 map_targets.remove(max_mapping[0])
-                mdl_map_targets.remove(max_mapping[1])
 
                 # remove the mdl chain from free_mdl_chains - its taken...
                 chem_group_idx = self.ref_ch_group_mapper[max_mapping[0]]
@@ -2839,28 +2859,37 @@ def _QSScoreNaive(trg, mdl, chem_groups, chem_mapping, ref_mdl_alns, contact_d,
     return (best_mapping, best_score)
 
 
+def _GetSeeds(ref_chem_groups, mdl_chem_groups, mapped_ref_chains = set(),
+              mapped_mdl_chains = set()):
+    seeds = list()
+    for ref_chains, mdl_chains in zip(ref_chem_groups,
+                                      mdl_chem_groups):
+        for ref_ch in ref_chains:
+            if ref_ch not in mapped_ref_chains:
+                for mdl_ch in mdl_chains:
+                    if mdl_ch not in mapped_mdl_chains:
+                        seeds.append((ref_ch, mdl_ch))
+    return seeds
+
+
 def _QSScoreGreedyFast(the_greed):
 
     something_happened = True
     mapping = dict()
-
     while something_happened:
         something_happened = False
         # search for best scoring starting point, we're using lDDT here
         n_best = 0
         best_seed = None
-        mapped_ref_chains = set(mapping.keys())
-        mapped_mdl_chains = set(mapping.values())
-        for ref_chains, mdl_chains in zip(the_greed.ref_chem_groups,
-                                          the_greed.mdl_chem_groups):
-            for ref_ch in ref_chains:
-                if ref_ch not in mapped_ref_chains:
-                    for mdl_ch in mdl_chains:
-                        if mdl_ch not in mapped_mdl_chains:
-                            n = the_greed.SCCounts(ref_ch, mdl_ch)
-                            if n > n_best:
-                                n_best = n
-                                best_seed = (ref_ch, mdl_ch)
+        seeds = _GetSeeds(the_greed.ref_chem_groups,
+                          the_greed.mdl_chem_groups,
+                          mapped_ref_chains = set(mapping.keys()),
+                          mapped_mdl_chains = set(mapping.values()))
+        for seed in seeds:
+            n = the_greed.SCCounts(seed[0], seed[1])
+            if n > n_best:
+                n_best = n
+                best_seed = seed
         if n_best == 0:
             break # no proper seed found anymore...
         # add seed to mapping and start the greed
@@ -2868,7 +2897,6 @@ def _QSScoreGreedyFast(the_greed):
         mapping = the_greed.ExtendMapping(mapping)
         something_happened = True
 
-
     # translate mapping format and return
     final_mapping = list()
     for ref_chains in the_greed.ref_chem_groups:
@@ -2893,43 +2921,45 @@ def _QSScoreGreedyFull(the_greed, n_mdl_chains):
     if n_mdl_chains is not None and n_mdl_chains < 1:
         raise RuntimeError("n_mdl_chains must be None or >= 1")
 
-    something_happened = True
-    mapping = dict()
+    seeds = _GetSeeds(the_greed.ref_chem_groups, the_greed.mdl_chem_groups)
+    best_overall_score = -1.0
+    best_overall_mapping = dict()
 
-    while something_happened:
-        something_happened = False
-        # Try all possible starting points and keep the one giving the best QS score
-        best_score = -1.0
-        best_mapping = None
-        mapped_ref_chains = set(mapping.keys())
-        mapped_mdl_chains = set(mapping.values())
-        for ref_chains, mdl_chains in zip(the_greed.ref_chem_groups,
-                                          the_greed.mdl_chem_groups):
-            for ref_ch in ref_chains:
-                if ref_ch not in mapped_ref_chains:
-                    seeds = list()
-                    for mdl_ch in mdl_chains:
-                        if mdl_ch not in mapped_mdl_chains:
-                            seeds.append((ref_ch, mdl_ch))
-                    if n_mdl_chains is not None and n_mdl_chains < len(seeds):
-                        counts = [the_greed.SCCounts(s[0], s[1]) for s in seeds]
-                        tmp = [(a,b) for a,b in zip(counts, seeds)]
-                        tmp.sort(reverse=True)
-                        seeds = [item[1] for item in tmp[:n_mdl_chains]]
-                    for seed in seeds:
-                        tmp_mapping = dict(mapping)
-                        tmp_mapping[seed[0]] = seed[1]
-                        tmp_mapping = the_greed.ExtendMapping(tmp_mapping)
-                        score_result = the_greed.FromFlatMapping(tmp_mapping)
-                        if score_result.QS_global > best_score:
-                            best_score = score_result.QS_global
-                            best_mapping = tmp_mapping
+    for seed in seeds:
 
-        if best_mapping is not None and len(best_mapping) > len(mapping):
-            # this even accepts extensions that lead to no increase in QS-score
-            # at least they make sense from an lDDT perspective
-            something_happened = True
-            mapping = best_mapping
+        # do initial extension
+        mapping = the_greed.ExtendMapping({seed[0]: seed[1]})
+
+        # repeat the process until we have a full mapping
+        something_happened = True
+        while something_happened:
+            something_happened = False
+            remnant_seeds = _GetSeeds(the_greed.ref_chem_groups,
+                                      the_greed.mdl_chem_groups,
+                                      mapped_ref_chains = set(mapping.keys()),
+                                      mapped_mdl_chains = set(mapping.values()))
+            if len(remnant_seeds) > 0:
+                # still more mapping to be done
+                best_score = -1.0
+                best_mapping = None
+                for remnant_seed in remnant_seeds:
+                    tmp_mapping = dict(mapping)
+                    tmp_mapping[remnant_seed[0]] = remnant_seed[1]
+                    tmp_mapping = the_greed.ExtendMapping(tmp_mapping)
+                    score_result = the_greed.FromFlatMapping(tmp_mapping)
+                    if score_result.QS_global > best_score:
+                        best_score = score_result.QS_global
+                        best_mapping = tmp_mapping
+                if best_mapping is not None:
+                    something_happened = True
+                    mapping = best_mapping
+
+        score_result = the_greed.FromFlatMapping(mapping)
+        if score_result.QS_global > best_overall_score:
+            best_overall_score = score_result.QS_global
+            best_overall_mapping = mapping
+
+    mapping = best_overall_mapping
 
     # translate mapping format and return
     final_mapping = list()
-- 
GitLab