From 3c41cef62e09ec2f35c524e9ae7cfcf669c9f969 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Mon, 23 Jan 2023 17:57:26 +0100
Subject: [PATCH] include number of model chains in chain mapping heuristic

So far, the Scorer object only relied on the number of chains in
the target to decide whether to enumerate the full chain mapping
solution space or apply a heuristic. The number of chains in the
model is equally important.
---
 modules/mol/alg/pymod/chain_mapping.py | 51 +++++++++++++++++++++-----
 modules/mol/alg/pymod/scoring.py       | 20 ++++++----
 2 files changed, 55 insertions(+), 16 deletions(-)

diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py
index db28156de..29acb6654 100644
--- a/modules/mol/alg/pymod/chain_mapping.py
+++ b/modules/mol/alg/pymod/chain_mapping.py
@@ -705,7 +705,8 @@ class ChainMapper:
     def GetlDDTMapping(self, model, inclusion_radius=15.0,
                        thresholds=[0.5, 1.0, 2.0, 4.0], strategy="naive",
                        steep_opt_rate = None, full_n_mdl_chains = None,
-                       block_seed_size = 5, block_blocks_per_chem_group = 5):
+                       block_seed_size = 5, block_blocks_per_chem_group = 5,
+                       chem_mapping_result = None):
         """ Identify chain mapping by optimizing lDDT score
 
         Maps *model* chain sequences to :attr:`~chem_groups` and find mapping
@@ -771,6 +772,11 @@ class ChainMapper:
                                             are extended in an initial search
                                             for high scoring local solutions.
         :type block_blocks_per_chem_group: :class:`int`
+        :param chem_mapping_result: Pro param. The result of
+                                    :func:`~GetChemMapping` where you provided
+                                    *model*. If set, *model* parameter is not
+                                    used.
+        :type chem_mapping_result: :class:`tuple`
         :returns: A :class:`MappingResult`
         """
 
@@ -778,7 +784,10 @@ class ChainMapper:
         if strategy not in strategies:
             raise RuntimeError(f"Strategy must be in {strategies}")
 
-        chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
+        if chem_mapping_result is None:
+            chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
+        else:
+            chem_mapping, chem_group_alns, mdl = chem_mapping_result
 
         ref_mdl_alns =  _GetRefMdlAlns(self.chem_groups,
                                        self.chem_group_alignments,
@@ -836,7 +845,7 @@ class ChainMapper:
     def GetQSScoreMapping(self, model, contact_d = 12.0, strategy = "naive",
                           full_n_mdl_chains = None, block_seed_size = 5,
                           block_blocks_per_chem_group = 5,
-                          steep_opt_rate = None):
+                          steep_opt_rate = None, chem_mapping_result = None):
         """ Identify chain mapping based on QSScore
 
         Scoring is based on CA/C3' positions which are present in all chains of
@@ -874,6 +883,11 @@ class ChainMapper:
         :type contact_d: :class:`float` 
         :param strategy: Strategy for sampling, must be in ["naive"]
         :type strategy: :class:`str`
+        :param chem_mapping_result: Pro param. The result of
+                                    :func:`~GetChemMapping` where you provided
+                                    *model*. If set, *model* parameter is not
+                                    used.
+        :type chem_mapping_result: :class:`tuple`
         :returns: A :class:`MappingResult`
         """
 
@@ -881,7 +895,10 @@ class ChainMapper:
         if strategy not in strategies:
             raise RuntimeError(f"strategy must be {strategies}")
 
-        chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
+        if chem_mapping_result is None:
+            chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
+        else:
+            chem_mapping, chem_group_alns, mdl = chem_mapping_result
         ref_mdl_alns =  _GetRefMdlAlns(self.chem_groups,
                                        self.chem_group_alignments,
                                        chem_mapping,
@@ -934,7 +951,8 @@ class ChainMapper:
 
     def GetRigidMapping(self, model, strategy = "greedy_single_gdtts",
                         single_chain_gdtts_thresh=0.4, subsampling=None,
-                        first_complete=False, iterative_superposition=False):
+                        first_complete=False, iterative_superposition=False,
+                        chem_mapping_result = None):
         """Identify chain mapping based on rigid superposition
 
         Superposition and scoring is based on CA/C3' positions which are present
@@ -989,6 +1007,11 @@ class ChainMapper:
                                         as oposed to
                                         :func:`ost.mol.alg.SuperposeSVD`
         :type iterative_superposition: :class:`bool`
+        :param chem_mapping_result: Pro param. The result of
+                                    :func:`~GetChemMapping` where you provided
+                                    *model*. If set, *model* parameter is not
+                                    used.
+        :type chem_mapping_result: :class:`tuple`
         :returns: A :class:`MappingResult`
         """
 
@@ -997,7 +1020,10 @@ class ChainMapper:
         if strategy not in strategies:
             raise RuntimeError(f"strategy must be {strategies}")
 
-        chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
+        if chem_mapping_result is None:
+            chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
+        else:
+            chem_mapping, chem_group_alns, mdl = chem_mapping_result
         ref_mdl_alns =  _GetRefMdlAlns(self.chem_groups,
                                        self.chem_group_alignments,
                                        chem_mapping,
@@ -1093,7 +1119,7 @@ class ChainMapper:
 
     def GetRepr(self, substructure, model, topn=1, inclusion_radius=15.0,
                 thresholds=[0.5, 1.0, 2.0, 4.0], bb_only=False,
-                only_interchain=False):
+                only_interchain=False, chem_mapping_result = None):
         """ Identify *topn* representations of *substructure* in *model*
 
         *substructure* defines a subset of :attr:`~target` for which one
@@ -1124,7 +1150,11 @@ class ChainMapper:
         :param only_interchain: Only score interchain contacts in lDDT. Useful
                                 if you want to identify interface patches.
         :type only_interchain: :class:`bool`
-
+        :param chem_mapping_result: Pro param. The result of
+                                    :func:`~GetChemMapping` where you provided
+                                    *model*. If set, *model* parameter is not
+                                    used.
+        :type chem_mapping_result: :class:`tuple`
         :returns: :class:`list` of :class:`ReprResult`
         """
 
@@ -1168,7 +1198,10 @@ class ChainMapper:
                                    "a backbone atom named CA or C3\'")
 
         # perform mapping and alignments on full structures
-        chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
+        if chem_mapping_result is None:
+            chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
+        else:
+            chem_mapping, chem_group_alns, mdl = chem_mapping_result
         ref_mdl_alns =  _GetRefMdlAlns(self.chem_groups,
                                        self.chem_group_alignments,
                                        chem_mapping,
diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index 6d5b345ec..b6635c240 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -109,10 +109,11 @@ class Scorer:
                            colored to True in
                            :class:`ost.mol.alg.MolckSettings` constructor.
     :type molck_settings: :class:`ost.mol.alg.MolckSettings`
-    :param naive_chain_mapping_thresh: Chain mappings for targets up to that
-                                       number of chains will be fully enumerated
-                                       to optimize for QS-score. Everything
-                                       above is treated with a heuristic.
+    :param naive_chain_mapping_thresh: Chain mappings for targets/models up to
+                                       that number of chains will be fully
+                                       enumerated to optimize for QS-score.
+                                       Everything above is treated with a
+                                       heuristic.
     :type naive_chain_mapping_thresh: :class:`int` 
     :param cad_score_exec: Explicit path to voronota-cadscore executable from
                            voronota installation from 
@@ -356,15 +357,20 @@ class Scorer:
         """
         if self._mapping is None:
             n_trg_chains = len(self.chain_mapper.target.chains)
-            if n_trg_chains <= self.naive_chain_mapping_thresh:
+            res = self.chain_mapper.GetChemMapping(self.model)
+            n_mdl_chains = len(res[2].chains)
+            thresh = self.naive_chain_mapping_thresh
+            if n_trg_chains <= thresh and n_mdl_chains <= thresh:
                 m = self.chain_mapper.GetQSScoreMapping(self.model,
-                                                        strategy="naive")
+                                                        strategy="naive",
+                                                        chem_mapping_result=res)
             else:
                 m = self.chain_mapper.GetQSScoreMapping(self.model,
                                                         strategy="greedy_block",
                                                         steep_opt_rate=3,
                                                         block_seed_size=5,
-                                                        block_blocks_per_chem_group=6)
+                                                        block_blocks_per_chem_group=6,
+                                                        chem_mapping_result=res)
             self._mapping = m
         return self._mapping
 
-- 
GitLab