From 8f53cd735e0167575e4344b62a8cc3ec9ca07255 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Mon, 8 Aug 2022 16:33:12 +0200
Subject: [PATCH] lDDT: refactor block strategy in greedy chain mapping

---
 modules/mol/alg/pymod/chain_mapping.py | 106 ++++++++++++-------------
 1 file changed, 53 insertions(+), 53 deletions(-)

diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py
index a375aaca5..43d7ac143 100644
--- a/modules/mol/alg/pymod/chain_mapping.py
+++ b/modules/mol/alg/pymod/chain_mapping.py
@@ -401,8 +401,8 @@ class ChainMapper:
     def GetGreedylDDTMapping(self, model, inclusion_radius=15.0,
                              thresholds=[0.5, 1.0, 2.0, 4.0],
                              seed_strategy="fast", steep_opt_rate = None,
-                             full_n_mdl_chains = None, block_seed_size = None,
-                             block_n_mdl_chains = None):
+                             full_n_mdl_chains = None, block_seed_size = 5,
+                             block_blocks_per_chem_group = 5):
         """Heuristic to lower the complexity of naive iteration
 
         Maps *model* chain sequences to :attr:`~chem_groups` and extends these
@@ -425,10 +425,10 @@ class ChainMapper:
           per ref chain to the *full_n_mdl_chains* best scoring ones.
 
         * block: try multiple seeds, i.e. try all ref/mdl chain combinations
-          within the respective chem groups but only extend these seeds by
-          *block_seed_size* chains. The highest scoring block for every ref
-          chain is extended exhaustively to identify the best scoring initial
-          block.
+          within the respective chem groups and compute single chain lDDTs.
+          The *block_blocks_per_chem_group* best scoring ones are extend by
+          *block_seed_size* chains and the best scoring one is exhaustively
+          extended.
 
         :param model: Model to map
         :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
@@ -445,19 +445,19 @@ class ChainMapper:
                                within their respective chem groups and accepts
                                swaps that improve lDDT score. Iteration stops as
                                soon as no improvement can be achieved anymore.
-        :type stepp_opt_rate: :class:`int`
+        :type steep_opt_rate: :class:`int`
         :param full_n_mdl_chains: Param for *full* seed strategy - Max number of
                                   mdl chains that are tried per ref chain. The
                                   default (None) tries all of them.
         :type full_n_mdl_chains: :class:`int`
         :param block_seed_size: Param for *block* seed strategy - Initial seeds
-                                are extended by that number of chains. The
-                                default (None) performs full extensions and you
-                                get equivalent behaviour as in *full* strategy.
+                                are extended by that number of chains.
         :type block_seed_size: :class:`int`
-        :param block_n_mdl_chains: Equivalent of *full_n_mdl_chains* but for
-                                   *block* seed strategy.
-        :type block_n_mdl_chains: :class:`int`
+        :param block_blocks_per_chem_group: Param for *block* seed strategy -
+                                            Number of blocks per chem group that
+                                            are extended in an initial search
+                                            for high scoring local solutions.
+        :type block_blocks_per_chem_group: :class:`int`
         :returns: A :class:`list` of :class:`list` that reflects
                   :attr:`~chem_groups` but is filled with the respective model
                   chains. Target chains without mapped model chains are set to
@@ -497,7 +497,8 @@ class ChainMapper:
         elif seed_strategy == "full":
             return _FullGreedy(the_greed, full_n_mdl_chains)
         elif seed_strategy == "block":
-            return _BlockGreedy(the_greed, block_seed_size, block_n_mdl_chains)
+            return _BlockGreedy(the_greed, block_seed_size, block_blocks_per_chem_group)
+
 
     def GetRigidMapping(self, model, single_chain_gdtts_thresh=0.4,
                         subsampling=None, first_complete=False,
@@ -1529,25 +1530,21 @@ def _FullGreedy(the_greed, n_mdl_chains):
     return final_mapping
 
 
-def _BlockGreedy(the_greed, seed_size, n_mdl_chains):
-    """ Uses each reference chain as starting point for expansion
-
-    Tries to map all mdl chains (optionally up to *n_mdl_chains* best ones)
-    to these references but initially does not perform full expansion but only
-    up to *seed_size*. The best scoring block for each reference is then used
-    for full expansion.
+def _BlockGreedy(the_greed, seed_size, blocks_per_chem_group):
+    """ try multiple seeds, i.e. try all ref/mdl chain combinations within the
+    respective chem groups and compute single chain lDDTs. The
+    *blocks_per_chem_group* best scoring ones are extend by *seed_size* chains
+    and the best scoring one is exhaustively extended.
     """
 
-    if seed_size is not None and seed_size < 1:
-        raise RuntimeError("seed_size must be None or >= 1")
+    if seed_size is None or seed_size < 1:
+        raise RuntimeError(f"seed_size must be an int >= 1 (got {seed_size})")
 
-    if n_mdl_chains is not None and n_mdl_chains < 1:
-        raise RuntimeError("n_mdl_chains must be None or >= 1")
+    if blocks_per_chem_group is None or blocks_per_chem_group < 1:
+        raise RuntimeError(f"blocks_per_chem_group must be an int >= 1 "
+                           f"(got {blocks_per_chem_group})")
 
-    max_ext = None
-    if seed_size is not None:
-        # max_ext = seed_size-1 => start seed already has size 1
-        max_ext = seed_size - 1
+    max_ext = seed_size - 1 #  -1 => start seed already has size 1
 
     ref_chem_groups = copy.deepcopy(the_greed.ref_chem_groups)
     mdl_chem_groups = copy.deepcopy(the_greed.mdl_chem_groups)
@@ -1557,36 +1554,39 @@ def _BlockGreedy(the_greed, seed_size, n_mdl_chains):
     something_happened = True
     while something_happened:
         something_happened = False
-
-        # one block per ref chain, i.e. a mapping that is extended by seed_size
-        starting_blocks = dict()
+        starting_blocks = list()
         for ref_chains, mdl_chains in zip(ref_chem_groups, mdl_chem_groups):
             if len(mdl_chains) == 0:
                 continue # nothing to map
-            for ref_ch in ref_chains:
-                best_lddt = 0.0
-                best_mapping = None
-                seeds = [(ref_ch, mdl_ch) for mdl_ch in mdl_chains]
-                if n_mdl_chains is not None and n_mdl_chains < len(seeds):
-                    counts = [the_greed.SCCounts(s[0], s[1]) for s in seeds]
-                    tmp = [(a,b) for a,b in zip(counts, seeds)]
-                    tmp.sort(reverse=True)
-                    seeds = [item[1] for item in tmp[:n_mdl_chains]]
-                for s in seeds:
-                    seed = dict(mapping)
-                    seed.update({s[0]: s[1]})
-                    
-                    seed = the_greed.ExtendMapping(seed, max_ext = max_ext)
-                    seed_lddt = the_greed.lDDTFromFlatMap(seed)
-                    if seed_lddt > best_lddt:
-                        best_lddt = seed_lddt
-                        best_mapping = seed
-                if best_mapping is not None:
-                    starting_blocks[ref_ch] = best_mapping
 
+            # Identify starting seeds for *blocks_per_chem_group* blocks
+            seeds = list()
+            for ref_ch in ref_chains:
+                seeds += [(ref_ch, mdl_ch) for mdl_ch in mdl_chains]
+            counts = [the_greed.SCCounts(s[0], s[1]) for s in seeds]
+            tmp = [(a,b) for a,b in zip(counts, seeds)]
+            tmp.sort(reverse=True)
+            seeds = [item[1] for item in tmp[:blocks_per_chem_group]]
+
+            # extend starting seeds to *seed_size* and retain best scoring block
+            # for further extension
+            best_lddt = 0.0
+            best_mapping = None
+            for s in seeds:
+                seed = dict(mapping)
+                seed.update({s[0]: s[1]})  
+                seed = the_greed.ExtendMapping(seed, max_ext = max_ext)
+                seed_lddt = the_greed.lDDTFromFlatMap(seed)
+                if seed_lddt > best_lddt:
+                    best_lddt = seed_lddt
+                    best_mapping = seed
+            if best_mapping != None:
+                starting_blocks.append(best_mapping)
+
+        # fully expand initial starting blocks
         best_lddt = 0.0
         best_mapping = None
-        for ref_ch, seed in starting_blocks.items():
+        for seed in starting_blocks:
             seed = the_greed.ExtendMapping(seed)
             seed_lddt = the_greed.lDDTFromFlatMap(seed)
             if seed_lddt > best_lddt:
-- 
GitLab