From a2492881e8f8c73ce667d7cd70f7ff7b9323bfa1 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Thu, 25 May 2023 11:17:28 +0200
Subject: [PATCH] Add convenience function for chain mapping

Calls the currently preferred chain mapping method that should
deal with most chain mapping cases in a reasonable runtime.
---
 modules/mol/alg/pymod/chain_mapping.py | 19 ++++++++++++++++
 modules/mol/alg/pymod/scoring.py       | 30 ++++----------------------
 2 files changed, 23 insertions(+), 26 deletions(-)

diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py
index 692788bb8..896f77257 100644
--- a/modules/mol/alg/pymod/chain_mapping.py
+++ b/modules/mol/alg/pymod/chain_mapping.py
@@ -1228,6 +1228,25 @@ class ChainMapper:
         return MappingResult(self.target, mdl, self.chem_groups, chem_mapping,
                              final_mapping, alns)
 
+    def GetMapping(self, model):
+        """ Convenience function to get mapping with currently preferred method
+
+        If number of chains in model and target are <= 12, a naive QS-score
+        mapping is performed. For anything else, a QS-score mapping with the
+        greedy_block strategy is performed (steep_opt_rate = 3,
+        block_seed_size = 5, block_blocks_per_chem_group = 6).
+        """
+        n_trg_chains = len(self.target.chains)
+        res = self.GetChemMapping(model)
+        n_mdl_chains = len(res[2].chains)
+        if n_trg_chains <= 12 and n_mdl_chains <= 12:
+            return self.GetQSScoreMapping(model, strategy="naive",
+                                          chem_mapping_result=res)
+        else:
+            return self.GetQSScoreMapping(model, strategy="greedy_block",
+                                          steep_opt_rate=3, block_seed_size=5,
+                                          block_blocks_per_chem_group=6,
+                                          chem_mapping_result=res)
 
     def GetRepr(self, substructure, model, topn=1, inclusion_radius=15.0,
                 thresholds=[0.5, 1.0, 2.0, 4.0], bb_only=False,
diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index f9e7ec864..7b6b802b3 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -111,12 +111,6 @@ class Scorer:
                            colored to True in
                            :class:`ost.mol.alg.MolckSettings` constructor.
     :type molck_settings: :class:`ost.mol.alg.MolckSettings`
-    :param naive_chain_mapping_thresh: Chain mappings for targets/models up to
-                                       that number of chains will be fully
-                                       enumerated to optimize for QS-score.
-                                       Everything above is treated with a
-                                       heuristic.
-    :type naive_chain_mapping_thresh: :class:`int` 
     :param cad_score_exec: Explicit path to voronota-cadscore executable from
                            voronota installation from 
                            https://github.com/kliment-olechnovic/voronota. If
@@ -136,9 +130,9 @@ class Scorer:
     :type lddt_no_stereochecks: :class:`bool`
     """
     def __init__(self, model, target, resnum_alignments=False,
-                 molck_settings = None, naive_chain_mapping_thresh=12,
-                 cad_score_exec = None, custom_mapping=None,
-                 usalign_exec = None, lddt_no_stereochecks=False):
+                 molck_settings = None, cad_score_exec = None,
+                 custom_mapping=None, usalign_exec = None,
+                 lddt_no_stereochecks=False):
 
         if isinstance(model, mol.EntityView):
             model = mol.CreateEntityFromView(model, False)
@@ -203,7 +197,6 @@ class Scorer:
                                        "resnum_alignments are enabled")
 
         self.resnum_alignments = resnum_alignments
-        self.naive_chain_mapping_thresh = naive_chain_mapping_thresh
         self.cad_score_exec = cad_score_exec
         self.usalign_exec = usalign_exec
         self.lddt_no_stereochecks = lddt_no_stereochecks
@@ -426,22 +419,7 @@ class Scorer:
         :type: :class:`ost.mol.alg.chain_mapping.MappingResult` 
         """
         if self._mapping is None:
-            n_trg_chains = len(self.chain_mapper.target.chains)
-            res = self.chain_mapper.GetChemMapping(self.model)
-            n_mdl_chains = len(res[2].chains)
-            thresh = self.naive_chain_mapping_thresh
-            if n_trg_chains <= thresh and n_mdl_chains <= thresh:
-                m = self.chain_mapper.GetQSScoreMapping(self.model,
-                                                        strategy="naive",
-                                                        chem_mapping_result=res)
-            else:
-                m = self.chain_mapper.GetQSScoreMapping(self.model,
-                                                        strategy="greedy_block",
-                                                        steep_opt_rate=3,
-                                                        block_seed_size=5,
-                                                        block_blocks_per_chem_group=6,
-                                                        chem_mapping_result=res)
-            self._mapping = m
+            self._mapping = self.chain_mapper.GetMapping(self.model)
         return self._mapping
 
     @property
-- 
GitLab