From 77efa561590181329ede13d847a7c000492c2e6c Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Tue, 21 Feb 2023 15:50:20 +0100
Subject: [PATCH] scoring: Adapt Scorer to reduced restrictions on rnums in
 chain mapping

---
 modules/mol/alg/pymod/scoring.py            | 44 +++++++++++++--------
 modules/mol/alg/tests/test_chain_mapping.py |  6 ++-
 2 files changed, 32 insertions(+), 18 deletions(-)

diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index 70f2dd5b2..d42390eba 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -144,27 +144,39 @@ class Scorer:
         for ch in self._model.chains:
             if ch.GetName().strip() == "":
                 raise RuntimeError("Model chains must have valid chain names")
-
-        # catch models with residue numbers that are not strictly increasing
-        # (requirement of ChainMapper)
-        for ch in self._model.chains:
-            nums = [r.GetNumber().GetNum() for r in ch.residues]
-            if not all(i < j for i, j in zip(nums, nums[1:])):
-                raise RuntimeError("Residue numbers in each model chain must "
-                                   "be strictly increasing")
-
+        
         # catch targets which have empty chain names
         for ch in self._target.chains:
             if ch.GetName().strip() == "":
                 raise RuntimeError("Target chains must have valid chain names")
 
-        # catch targets with residue numbers that are not strictly increasing
-        # (requirement of ChainMapper)
-        for ch in self._target.chains:
-            nums = [r.GetNumber().GetNum() for r in ch.residues]
-            if not all(i < j for i, j in zip(nums, nums[1:])):
-                raise RuntimeError("Residue numbers in each target chain must "
-                                   "be strictly increasing")
+        if resnum_alignments:
+            # In case of resnum_alignments, we have some requirements on 
+            # residue numbers in the chain mapping: 1) no ins codes 2) strictly
+            # increasing residue numbers.
+            for ch in self._model.chains:
+                ins_codes = [r.GetNumber().GetInsCode() for r in ch.residues]
+                if len(set(ins_codes)) != 1 or ins_codes[0] != '\0':
+                    raise RuntimeError("Residue numbers in each model chain "
+                                       "must not contain insertion codes if "
+                                       "resnum_alignments are enabled")
+                nums = [r.GetNumber().GetNum() for r in ch.residues]
+                if not all(i < j for i, j in zip(nums, nums[1:])):
+                    raise RuntimeError("Residue numbers in each model chain "
+                                       "must be strictly increasing if "
+                                       "resnum_alignments are enabled")
+
+            for ch in self._target.chains:
+                ins_codes = [r.GetNumber().GetInsCode() for r in ch.residues]
+                if len(set(ins_codes)) != 1 or ins_codes[0] != '\0':
+                    raise RuntimeError("Residue numbers in each target chain "
+                                       "must not contain insertion codes if "
+                                       "resnum_alignments are enabled")
+                nums = [r.GetNumber().GetNum() for r in ch.residues]
+                if not all(i < j for i, j in zip(nums, nums[1:])):
+                    raise RuntimeError("Residue numbers in each target chain "
+                                       "must be strictly increasing if "
+                                       "resnum_alignments are enabled")
 
         if molck_settings is None:
             molck_settings = MolckSettings(rm_unk_atoms=True,
diff --git a/modules/mol/alg/tests/test_chain_mapping.py b/modules/mol/alg/tests/test_chain_mapping.py
index 106f849a3..d012f1e65 100644
--- a/modules/mol/alg/tests/test_chain_mapping.py
+++ b/modules/mol/alg/tests/test_chain_mapping.py
@@ -118,18 +118,20 @@ class TestChainMapper(unittest.TestCase):
     self.assertTrue(_CompareViews(mapper.chem_group_alignments[2].GetSequence(0).GetAttachedView(), nuc_view_two))
 
     # ensure that error is triggered if there are insertion codes
+    # and resnum_alignments are enabled
     tmp_ent = ent.Copy()
     ed = tmp_ent.EditXCS()
     r = tmp_ent.residues[0]
     ed.SetResidueNumber(r, mol.ResNum(r.GetNumber().GetNum(), 'A'))
-    self.assertRaises(Exception, ChainMapper, tmp_ent)
+    self.assertRaises(Exception, ChainMapper, tmp_ent, resnum_alignments=True)
 
     # ensure that error is triggered if resnums are not strictly increasing
+    # and resnum_alignments are enabled
     tmp_ent = ent.Copy()
     ed = tmp_ent.EditXCS()
     r = tmp_ent.residues[0]
     ed.SetResidueNumber(r, mol.ResNum(r.GetNumber().GetNum() + 42))
-    self.assertRaises(Exception, ChainMapper, tmp_ent)
+    self.assertRaises(Exception, ChainMapper, tmp_ent, resnum_alignments=True)
 
     # chain B has a missing Valine... set pep_gap_thr to 0.0 should give an
     # additional chem group
-- 
GitLab