From 6ecb119a4dd0e561c81339de407d32d97721de95 Mon Sep 17 00:00:00 2001
From: Xavier Robin <xavalias-github@xavier.robin.name>
Date: Thu, 25 May 2023 16:21:56 +0200
Subject: [PATCH] feat: SCHWED-5540 enable global chain mapping optionally

A global chain mapping can now be enabled with --global-chain-mapping as
a complement to the existing local chain mapping. This can be evaluated
for the scoring paper.
---
 actions/ost-compare-ligand-structures        |  9 ++++++++
 modules/mol/alg/pymod/ligand_scoring.py      | 22 ++++++++++++++++----
 modules/mol/alg/tests/test_ligand_scoring.py | 20 ++++++++++++++++++
 3 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/actions/ost-compare-ligand-structures b/actions/ost-compare-ligand-structures
index 74faddae2..5ab31eb07 100644
--- a/actions/ost-compare-ligand-structures
+++ b/actions/ost-compare-ligand-structures
@@ -168,6 +168,14 @@ def _ParseArgs():
         action="store_true",
         help=("Allow incomplete target ligands."))
 
+    parser.add_argument(
+        "-gcm",
+        "--global-chain-mapping",
+        dest="global_chain_mapping",
+        default=False,
+        action="store_true",
+        help=("Use a global chain mapping."))
+
     parser.add_argument(
         "--lddt-pli",
         dest="lddt_pli",
@@ -303,6 +311,7 @@ def _Process(model, model_ligands, reference, reference_ligands, args):
         check_resnames=args.enforce_consistency,
         rename_ligand_chain=True,
         substructure_match=args.substructure_match,
+        global_chain_mapping=args.global_chain_mapping,
         radius=args.radius,
         lddt_pli_radius=args.lddt_pli_radius,
         lddt_lp_radius=args.lddt_lp_radius
diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py
index 991c7d27c..992bb6ea6 100644
--- a/modules/mol/alg/pymod/ligand_scoring.py
+++ b/modules/mol/alg/pymod/ligand_scoring.py
@@ -174,14 +174,21 @@ class LigandScorer:
     :type lddt_lp_radius: :class:`float`
     :param binding_sites_topn: maximum number of target binding site
                                representations to assess, per target ligand.
+                               Ignored if `global_chain_mapping` is True.
     :type binding_sites_topn: :class:`int`
+    :param global_chain_mapping: set to True to use a global chain mapping for
+                                 the polymer (protein, nucleotide) chains.
+                                 Defaults to False, in which case only local
+                                 chain mappings are allowed (where different
+                                 ligand may be scored against different chain
+                                 mappings).
     """
     def __init__(self, model, target, model_ligands=None, target_ligands=None,
                  resnum_alignments=False, check_resnames=True,
                  rename_ligand_chain=False,
                  chain_mapper=None, substructure_match=False,
                  radius=4.0, lddt_pli_radius=6.0, lddt_lp_radius=10.0,
-                 binding_sites_topn=100000):
+                 binding_sites_topn=100000, global_chain_mapping=False):
 
         if isinstance(model, mol.EntityView):
             self.model = mol.CreateEntityFromView(model, False)
@@ -226,6 +233,7 @@ class LigandScorer:
         self.lddt_pli_radius = lddt_pli_radius
         self.lddt_lp_radius = lddt_lp_radius
         self.binding_sites_topn = binding_sites_topn
+        self.global_chain_mapping = global_chain_mapping
 
         # scoring matrices
         self._rmsd_matrix = None
@@ -415,9 +423,15 @@ class LigandScorer:
                         ref_bs.AddResidue(r, mol.ViewAddFlag.INCLUDE_ALL)
 
             # Find the representations
-            self._binding_sites[ligand.hash_code] = self.chain_mapper.GetRepr(
-                ref_bs, self.model, inclusion_radius=self.lddt_lp_radius,
-                topn=self.binding_sites_topn)
+            if self.global_chain_mapping:
+                mapping_res = self.chain_mapper.GetMapping(self.model)
+                self._binding_sites[ligand.hash_code] = self.chain_mapper.GetRepr(
+                    ref_bs, self.model, inclusion_radius=self.lddt_lp_radius,
+                    global_mapping = mapping_res)
+            else:
+                self._binding_sites[ligand.hash_code] = self.chain_mapper.GetRepr(
+                    ref_bs, self.model, inclusion_radius=self.lddt_lp_radius,
+                    topn=self.binding_sites_topn)
         return self._binding_sites[ligand.hash_code]
 
     @staticmethod
diff --git a/modules/mol/alg/tests/test_ligand_scoring.py b/modules/mol/alg/tests/test_ligand_scoring.py
index a1de22761..c2b2a762f 100644
--- a/modules/mol/alg/tests/test_ligand_scoring.py
+++ b/modules/mol/alg/tests/test_ligand_scoring.py
@@ -338,6 +338,26 @@ class TestLigandScoring(unittest.TestCase):
         self.assertEqual(sc.lddt_pli_details["F"][mol.ResNum(1)]["target_ligand"].qualified_name, 'K.G3D1')
         self.assertEqual(sc.lddt_pli_details["F"][mol.ResNum(1)]["model_ligand"].qualified_name, 'F.G3D1')
 
+    def test_global_chain_mapping(self):
+        """Test that both the global and the local chain mappings work.
+
+        For RMSD, A: A results in a better chain mapping. However, C: A is a
+        better global chain mapping from an lDDT perspective (and lDDT-PLI).
+        """
+        trg, _ = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True)
+        mdl, _ = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True)
+        mdl_lig = io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf"))
+
+        # Local mapping is the default: RMSD and lDDT-PLI are expected to differ
+        sc = LigandScorer(mdl, trg, [mdl_lig], None)
+        self.assertEqual(sc.rmsd_details["00001_L_2"][1]["chain_mapping"], {'A': 'A'})
+        self.assertEqual(sc.lddt_pli_details["00001_L_2"][1]["chain_mapping"], {'C': 'A'})
+
+        # Global mapping: both scores are expected to report the same mapping
+        sc = LigandScorer(mdl, trg, [mdl_lig], None, global_chain_mapping=True)
+        self.assertEqual(sc.rmsd_details["00001_L_2"][1]["chain_mapping"], {'C': 'A'})
+        self.assertEqual(sc.lddt_pli_details["00001_L_2"][1]["chain_mapping"], {'C': 'A'})
+
 
 if __name__ == "__main__":
     from ost import testutils
-- 
GitLab