From 55fd227a56b77c61f26bf2da5ec5dad6d451ef74 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Wed, 22 Feb 2023 13:48:56 +0100
Subject: [PATCH] chain mapping: Introduce new rigid mapping strategy:
 greedy_single_rmsd

---
 modules/mol/alg/pymod/chain_mapping.py      | 76 ++++++++++++++++++---
 modules/mol/alg/tests/test_chain_mapping.py |  3 +
 2 files changed, 71 insertions(+), 8 deletions(-)

diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py
index aa314e7df..2318f6a2a 100644
--- a/modules/mol/alg/pymod/chain_mapping.py
+++ b/modules/mol/alg/pymod/chain_mapping.py
@@ -1014,7 +1014,7 @@ class ChainMapper:
         are estimated using all possible combinations of target and model chains
         within the same chem groups and build the basis for further extension.
 
-        There are three extension strategies:
+        There are four extension strategies:
 
         * **greedy_single_gdtts**: Iteratively add the model/target chain pair
           that adds the most conserved contacts based on the GDT-TS metric
@@ -1022,12 +1022,19 @@ class ChainMapper:
           with highest GDT-TS score is returned. However, that mapping is not
           guaranteed to be complete (see *single_chain_gdtts_thresh*).
 
-        * **greedy_iterative_gdtts**: Same as single except that the
-          transformation gets updated with each added chain pair.
+        * **greedy_iterative_gdtts**: Same as greedy_single_gdtts except that
+          the transformation gets updated with each added chain pair.
 
-        * **greedy_iterative_rmsd**: Same as iterative, i.e. the transformation
-          gets updated with each added chain pair. However,
-          **single_chain_gdtts_thresh** is only applied to derive the initial
+        * **greedy_single_rmsd**: Conceptually similar to greedy_single_gdtts
+          but the added chain pairs are the ones with lowest RMSD.
+          The mapping with lowest overall RMSD gets returned.
+          *single_chain_gdtts_thresh* is only applied to derive the initial
+          transformations. After that, the minimal RMSD chain pair gets
+          iteratively added without applying any threshold.
+
+        * **greedy_iterative_rmsd**: Same as greedy_single_rmsd except that
+          the transformation gets updated with each added chain pair.
+          *single_chain_gdtts_thresh* is only applied to derive the initial
           transformations. After that, the minimal RMSD chain pair gets
           iteratively added without applying any threshold.
 
@@ -1067,7 +1074,7 @@ class ChainMapper:
         """
 
         strategies = ["greedy_single_gdtts", "greedy_iterative_gdtts",
-                      "greedy_iterative_rmsd"]
+                      "greedy_single_rmsd", "greedy_iterative_rmsd"]
         if strategy not in strategies:
             raise RuntimeError(f"strategy must be {strategies}")
 
@@ -1138,6 +1145,13 @@ class ChainMapper:
                                            len(self.target.chains),
                                            len(mdl.chains))
 
+        elif strategy == "greedy_single_rmsd":
+            mapping = _SingleRigidRMSD(initial_transforms, initial_mappings,
+                                       self.chem_groups, chem_mapping,
+                                       trg_group_pos, mdl_group_pos,
+                                       iterative_superposition)
+
+
         elif strategy == "greedy_iterative_rmsd":
             mapping = _IterativeRigidRMSD(initial_transforms, initial_mappings,
                                           self.chem_groups, chem_mapping,
@@ -2973,7 +2987,7 @@ def _IterativeRigidGDTTS(initial_transforms, initial_mappings, chem_groups,
     for mdl_pos, mdl_chains in zip(mdl_group_pos, chem_mapping):
         for m_pos, m in zip(mdl_pos, mdl_chains):
             mdl_pos_dict[m] = m_pos
-        
+
     best_mapping = dict()
     best_gdt = 0
     for initial_transform, initial_mapping in zip(initial_transforms,
@@ -3045,6 +3059,52 @@ def _IterativeRigidGDTTS(initial_transforms, initial_mappings, chem_groups,
 
     return best_mapping
 
+def _SingleRigidRMSD(initial_transforms, initial_mappings, chem_groups,
+                     chem_mapping, trg_group_pos, mdl_group_pos,
+                     iterative_superposition):
+    """
+    Takes initial transforms and, per chem group, greedily adds the chain
+    pairs with lowest summed squared distance (SSD) after transforming the
+    model positions. Returns the mapping (dict, target chain -> model chain)
+    from the transform with lowest total SSD, or an empty dict without
+    transforms. *initial_mappings*/*iterative_superposition* are not used.
+    """
+    best_mapping = dict()
+    # we're actually going for summed squared distances: since all positions
+    # have same lengths and we do a full mapping, lowest SSD has a guarantee
+    # of also being lowest RMSD
+    best_ssd = float("inf")
+    for transform in initial_transforms:
+        mapping = dict()
+        mapped_mdl_chains = set()
+        total_ssd = 0.0
+        for trg_chains, mdl_chains, trg_pos, mdl_pos in zip(
+                chem_groups, chem_mapping, trg_group_pos, mdl_group_pos):
+            if len(trg_pos) == 0 or len(mdl_pos) == 0:
+                continue # cannot compute valid rmsd
+            t_mdl_pos = list()
+            for m_pos in mdl_pos:
+                t_m_pos = geom.Vec3List(m_pos)
+                t_m_pos.ApplyTransform(transform)
+                t_mdl_pos.append(t_m_pos)
+            ssds = list()
+            for t_pos, t in zip(trg_pos, trg_chains):
+                for t_m_pos, m in zip(t_mdl_pos, mdl_chains):
+                    # per-pair value; must stay distinct from total_ssd
+                    pair_ssd = t_pos.GetSummedSquaredDistances(t_m_pos)
+                    ssds.append((pair_ssd, (t, m)))
+            ssds.sort()
+            for pair_ssd, (t, m) in ssds:
+                if t not in mapping and m not in mapped_mdl_chains:
+                    mapping[t] = m
+                    mapped_mdl_chains.add(m)
+                    total_ssd += pair_ssd
+
+        if total_ssd < best_ssd:
+            best_ssd = total_ssd
+            best_mapping = mapping
+
+    return best_mapping
 
 def _IterativeRigidRMSD(initial_transforms, initial_mappings, chem_groups,
                         chem_mapping, trg_group_pos, mdl_group_pos,
diff --git a/modules/mol/alg/tests/test_chain_mapping.py b/modules/mol/alg/tests/test_chain_mapping.py
index d012f1e65..36a633b09 100644
--- a/modules/mol/alg/tests/test_chain_mapping.py
+++ b/modules/mol/alg/tests/test_chain_mapping.py
@@ -283,6 +283,9 @@ class TestChainMapper(unittest.TestCase):
     greedy_rigid_res = mapper.GetRigidMapping(mdl, strategy="greedy_iterative_gdtts")
     self.assertEqual(greedy_rigid_res.mapping, [['X', 'Y'],[None],['Z']])
 
+    greedy_rigid_res = mapper.GetRigidMapping(mdl, strategy="greedy_single_rmsd")
+    self.assertEqual(greedy_rigid_res.mapping, [['X', 'Y'],[None],['Z']])
+
     greedy_rigid_res = mapper.GetRigidMapping(mdl, strategy="greedy_iterative_rmsd")
     self.assertEqual(greedy_rigid_res.mapping, [['X', 'Y'],[None],['Z']])
 
-- 
GitLab