From 9c000e33d96cb2b2b4eb41b2210ff73f6565fe8a Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Wed, 31 May 2023 13:36:58 +0200
Subject: [PATCH] Control exhaustive enumeration in chain mapping from compare
 actions

---
 actions/ost-compare-ligand-structures   | 11 +++++++++++
 actions/ost-compare-structures          | 13 ++++++++++++-
 modules/mol/alg/pymod/ligand_scoring.py | 12 ++++++++++--
 modules/mol/alg/pymod/scoring.py        | 12 ++++++++++--
 4 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/actions/ost-compare-ligand-structures b/actions/ost-compare-ligand-structures
index 2af7694d8..b51a6c454 100644
--- a/actions/ost-compare-ligand-structures
+++ b/actions/ost-compare-ligand-structures
@@ -226,6 +226,16 @@ def _ParseArgs():
         default=3,
         help="Set verbosity level. Defaults to 3 (INFO).")
 
+    parser.add_argument(
+        "--n-max-naive",
+        dest="n_max_naive",
+        required=False,
+        default=12,
+        type=int,
+        help=("If number of chains in model and reference are below or equal "
+              "that number, the global chain mapping will naively enumerate "
+              "all possible mappings. A heuristic is used otherwise."))
+
     return parser.parse_args()
 
 
@@ -334,6 +344,7 @@ def _Process(model, model_ligands, reference, reference_ligands, args):
         radius=args.radius,
         lddt_pli_radius=args.lddt_pli_radius,
         lddt_lp_radius=args.lddt_lp_radius,
+        n_max_naive=args.n_max_naive
     )
 
     out = dict()
diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures
index cc145d048..48d33839e 100644
--- a/actions/ost-compare-structures
+++ b/actions/ost-compare-structures
@@ -384,6 +384,16 @@ def _ParseArgs():
         action="store_true",
         help=("Disable stereochecks for lDDT computation"))
 
+    parser.add_argument(
+        "--n-max-naive",
+        dest="n_max_naive",
+        required=False,
+        default=12,
+        type=int,
+        help=("If number of chains in model and reference are below or equal "
+              "that number, the chain mapping will naively enumerate all "
+              "possible mappings. A heuristic is used otherwise."))
+
     return parser.parse_args()
 
 def _Rename(ent):
@@ -541,7 +551,8 @@ def _Process(model, reference, args):
                             cad_score_exec = args.cad_exec,
                             custom_mapping = mapping,
                             usalign_exec = args.usalign_exec,
-                            lddt_no_stereochecks = args.lddt_no_stereochecks)
+                            lddt_no_stereochecks = args.lddt_no_stereochecks,
+                            n_max_naive = args.n_max_naive)
 
     ir = _GetInconsistentResidues(scorer.aln)
     if len(ir) > 0 and args.enforce_consistency:
diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py
index 3d9467bf6..ed1e95522 100644
--- a/modules/mol/alg/pymod/ligand_scoring.py
+++ b/modules/mol/alg/pymod/ligand_scoring.py
@@ -209,6 +209,12 @@ class LigandScorer:
                             (False) is to use a combination of lDDT-PLI and
                             RMSD for the assignment.
     :type rmsd_assignment: :class:`bool`
+    :param n_max_naive: Parameter for global chain mapping. If *model* and
+                        *target* have less or equal that number of chains,
+                        the full
+                        mapping solution space is enumerated to find the
+                        the optimum. A heuristic is used otherwise.
+    :type n_max_naive: :class:`int`
     """
     def __init__(self, model, target, model_ligands=None, target_ligands=None,
                  resnum_alignments=False, check_resnames=True,
@@ -216,7 +222,7 @@ class LigandScorer:
                  chain_mapper=None, substructure_match=False,
                  radius=4.0, lddt_pli_radius=6.0, lddt_lp_radius=10.0,
                  binding_sites_topn=100000, global_chain_mapping=False,
-                 rmsd_assignment=False):
+                 rmsd_assignment=False, n_max_naive=12):
 
         if isinstance(model, mol.EntityView):
             self.model = mol.CreateEntityFromView(model, False)
@@ -263,6 +269,7 @@ class LigandScorer:
         self.binding_sites_topn = binding_sites_topn
         self.global_chain_mapping = global_chain_mapping
         self.rmsd_assignment = rmsd_assignment
+        self.n_max_naive = n_max_naive
 
         # scoring matrices
         self._rmsd_matrix = None
@@ -296,7 +303,8 @@ class LigandScorer:
     def _model_mapping(self):
         """Get the global chain mapping for the model."""
         if self.__model_mapping is None:
-            self.__model_mapping = self.chain_mapper.GetMapping(self.model)
+            self.__model_mapping = self.chain_mapper.GetMapping(self.model,
+                                                                n_max_naive=self.n_max_naive)
         return self.__model_mapping
 
     @staticmethod
diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index 7b6b802b3..26633216a 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -128,11 +128,16 @@ class Scorer:
     :param lddt_no_stereochecks: Whether to compute lDDT without stereochemistry
                                 checks
     :type lddt_no_stereochecks: :class:`bool`
+    :param n_max_naive: Parameter for chain mapping. If *model* and *target*
+                        have less or equal that number of chains, the full
+                        mapping solution space is enumerated to find the
+                        the optimum. A heuristic is used otherwise.
+    :type n_max_naive: :class:`int`
     """
     def __init__(self, model, target, resnum_alignments=False,
                  molck_settings = None, cad_score_exec = None,
                  custom_mapping=None, usalign_exec = None,
-                 lddt_no_stereochecks=False):
+                 lddt_no_stereochecks=False, n_max_naive=12):
 
         if isinstance(model, mol.EntityView):
             model = mol.CreateEntityFromView(model, False)
@@ -200,6 +205,7 @@ class Scorer:
         self.cad_score_exec = cad_score_exec
         self.usalign_exec = usalign_exec
         self.lddt_no_stereochecks = lddt_no_stereochecks
+        self.n_max_naive = n_max_naive
 
         # lazily evaluated attributes
         self._stereochecked_model = None
@@ -419,7 +425,9 @@ class Scorer:
         :type: :class:`ost.mol.alg.chain_mapping.MappingResult` 
         """
         if self._mapping is None:
-            self._mapping = self.chain_mapper.GetMapping(self.model)
+            self._mapping = \
+            self.chain_mapper.GetMapping(self.model,
+                                         n_max_naive = self.n_max_naive)
         return self._mapping
 
     @property
-- 
GitLab