From 7d956e81ca5248dac52be9cad28a64661550fb2b Mon Sep 17 00:00:00 2001
From: Xavier Robin <xavier.robin@unibas.ch>
Date: Thu, 19 Jan 2023 18:16:35 +0100
Subject: [PATCH] feat: SCHWED-5783 add check_resname

---
 modules/mol/alg/pymod/ligand_scoring.py      | 13 +++++++++----
 modules/mol/alg/tests/test_ligand_scoring.py | 14 ++++++++++++++
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py
index c44ec74c0..75738e34f 100644
--- a/modules/mol/alg/pymod/ligand_scoring.py
+++ b/modules/mol/alg/pymod/ligand_scoring.py
@@ -87,6 +87,9 @@ class LigandScorer:
                               based on residue numbers. This can be assumed in
                               benchmarking setups such as CAMEO/CASP.
     :type resnum_alignments: :class:`bool`
+    :param check_resnames:  On by default. Enforces residue name matches
+                            between mapped model and target residues.
+    :type check_resnames: :class:`bool`
     :param chain_mapper: a chain mapper initialized for the target structure.
                          If None (default), a chain mapper will be initialized
                          lazily as required.
@@ -103,9 +106,9 @@ class LigandScorer:
     :param lddt_bs_radius: :class:`float`
     """
     def __init__(self, model, target, model_ligands=None, target_ligands=None,
-                 resnum_alignments=False, chain_mapper=None,
-                 substructure_match=False, radius=4.0, lddt_pli_radius=6.0,
-                 lddt_bs_radius=10.0):
+                 resnum_alignments=False, check_resnames=True,
+                 chain_mapper=None, substructure_match=False,
+                 radius=4.0, lddt_pli_radius=6.0, lddt_bs_radius=10.0):
 
         if isinstance(model, mol.EntityView):
             self.model = mol.CreateEntityFromView(model, False)
@@ -135,6 +138,7 @@ class LigandScorer:
 
         self._chain_mapper = chain_mapper
         self.resnum_alignments = resnum_alignments
+        self.check_resnames = check_resnames
         self.substructure_match = substructure_match
         self.radius = radius
         self.lddt_pli_radius = lddt_pli_radius
@@ -453,7 +457,8 @@ class LigandScorer:
                             n_cont, n_cons = lddt_scorer.lDDT(
                                 mdl_bs_ent, chain_mapping={"A": "A", "_": "_"},
                                 no_intrachain=True,
-                                return_dist_test=True)
+                                return_dist_test=True,
+                                check_resnames = self.check_resnames)
 
                         # Save results?
                         best_lddt = self._lddt_pli_matrix[target_i, model_i]["lddt_pli"]
diff --git a/modules/mol/alg/tests/test_ligand_scoring.py b/modules/mol/alg/tests/test_ligand_scoring.py
index 8e3eb94a7..a09d350fb 100644
--- a/modules/mol/alg/tests/test_ligand_scoring.py
+++ b/modules/mol/alg/tests/test_ligand_scoring.py
@@ -263,6 +263,20 @@ class TestLigandScoring(unittest.TestCase):
         assert sc._lddt_pli_matrix[1, 0]["bs_num_overlap_res"] == 15
         assert sc._lddt_pli_matrix[5, 0]["bs_num_overlap_res"] == 15
 
+    def test_check_resnames(self):
+        """Test check_resname argument works
+        """
+        # 4C0A has mismatching sequence and fails with check_resnames=True
+        trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True)
+        trg_4c0a, _ = io.LoadMMCIF(os.path.join('testfiles', "4c0a.cif.gz"), seqres=True)
+
+        with self.assertRaises(RuntimeError):
+            sc = LigandScorer(trg, trg_4c0a, None, None, check_resnames=True)
+            sc._compute_scores()
+
+        sc = LigandScorer(trg, trg_4c0a, None, None, check_resnames=False)
+        sc._compute_scores()
+
 
 if __name__ == "__main__":
     from ost import testutils
-- 
GitLab