From 2f86b59bf1c3e32f859c266f26482876a9bd8526 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Thu, 1 Sep 2022 12:59:56 +0200
Subject: [PATCH] lDDTBSScorer now makes use of
 chain_mapping.ChainMapper.GetRepr function

---
 modules/mol/alg/pymod/scoring.py   | 124 ++++++-----------------------
 modules/mol/alg/tests/test_lddt.py |   9 +--
 2 files changed, 29 insertions(+), 104 deletions(-)

diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index bb222b6de..a01269266 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -2,64 +2,33 @@ from ost import mol
 from ost import seq
 from ost import io
 from ost.mol.alg import lddt
-from ost.mol.alg import qsscoring
-
+from ost.mol.alg import chain_mapping
 
 class lDDTBSScorer:
     """Scorer specific for a reference/model pair
 
-    Computes lDDT score on residues that constitute a binding site and can deal
-    with oligos using a chain mapping derived from
-    :class:`ost.mol.alg.qsscoring.QSscorer.chain_mapping`
-
-    There are two options to initialize :class:`lDDTBSScorer`
-
-    * provide a *reference* and *model* structure, will be used to internally
-      setup a :class:`ost.mol.alg.qsscoring.QSscorer` from which a chain mapping
-      is derived.
-    * provide a :class:`ost.mol.alg.qsscoring.QSscorer` directly to make use of a
-      potentially cached chain mapping.
-
-    In both cases, the actually evaluated structures are derived from
-    :class:`ost.mol.alg.qsscoring.QSscorer.qs_ent_1` (reference) and
-    :class:`ost.mol.alg.qsscoring.QSscorer.qs_ent_2` (model). That means they
-    are cleaned as described in
-    :class:`ost.mol.alg.qsscoring.QSscoreEntity.ent`.
-
+    Finds the best possible representation of a reference binding site in the
+    model, as judged by lDDT score. Uses
+    :class:`ost.mol.alg.chain_mapping.ChainMapper` to deal with chain mapping.
 
     :param reference: Reference structure
     :type reference: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
     :param model: Model structure
     :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
-    :param qs_scorer: QSscorer object where you potentially already have a 
-                      chain mapping cached - *model* and *reference* will be
-                      neglected if this is provided.
-    :type qs_scorer: :class:`ost.mol.alg.qsscoring.QSscorer`
-    :param residue_number_alignment: Passed to QSscorer constructor if it needs
-                                     to be initialized with *reference* and
-                                     *model*.
+    :param residue_number_alignment: Passed to ChainMapper constructor
     :type residue_number_alignment: :class:`bool`
-    :raises: :class:`RuntimeError` if you don't provide *qs_scorer* or
-             *reference* and *model*,
-             :class:`ost.mol.alg.qsscoring.QSscoreError` if QSscorer
-             constructor raises.
     """
-    def __init__(self, reference=None, model=None,
-                 qs_scorer=None, residue_number_alignment=False):
-        if qs_scorer is not None:
-            self.qs_scorer = qs_scorer
-        elif model is not None and reference is not None:
-            self.qs_scorer = qsscoring.QSscorer(reference, model,
-                                                residue_number_alignment)
-        else:
-            raise RuntimeError("Must either provide qs_scorer or reference and "
-                               "model")
-        self.ref = self.qs_scorer.qs_ent_1.ent.Select("peptide=true")
-        self.mdl = self.qs_scorer.qs_ent_2.ent.Select("peptide=true")
-
-    def ScoreBS(self, ligand, radius = 4.0, lddt_radius=10.0,
-                return_mapping=False):
-        """Computes binding site lDDT score given *ligand*
+    def __init__(self, reference, model,
+                 residue_number_alignment=False):
+        self.chain_mapper = chain_mapping.ChainMapper(reference,
+            resnum_alignments=residue_number_alignment)
+        self.ref = self.chain_mapper.target
+        self.mdl = model
+
+    def ScoreBS(self, ligand, radius = 4.0, lddt_radius=10.0):
+        """Computes binding site lDDT score given *ligand*. The best possible
+        binding site representation is selected by lDDT, but other scores such
+        as CA-based RMSD and GDT are computed and returned as well.
 
         :param ligand: Defines the scored binding site, i.e. provides positions
                        to perform proximity search
@@ -70,13 +39,9 @@ class lDDTBSScorer:
         :param lddt_radius: Passed as *inclusion_radius* to
                             :class:`ost.mol.alg.lddt.lDDTScorer`
         :type lddt_radius: :class:`float`
-        :param return_mapping: If true, returns binding site mapping information
-                               in addition to the raw lDDTBS score, i.e. returns
-                               a tuple with 1: lDDTBS score 2: list of qualified
-                               residue names in reference 3: same for model
-        :type return_mapping: :class:`bool`
-
-        :returns: lDDTBS score or tuple if *return_mapping* is True
+        :returns: Object of type :class:`ost.mol.alg.chain_mapping.ReprResult`
+                  containing the all-atom lDDT score and mapping information,
+                  or None if no representation could be found.
         """
 
         # create view of reference binding site
@@ -98,49 +63,10 @@ class lDDTBSScorer:
                 if r.handle.GetHashCode() in ref_residues_hashes:
                     ref_bs.AddResidue(r, mol.ViewAddFlag.INCLUDE_ALL)
 
-        # create view of model binding site using residue mapping from qs_scorer
-        # build up ref to mdl alignments for each chain on the go (alns)
-        mdl_bs = self.mdl.CreateEmptyView()
-        alns = dict()
-        rmapping = self.qs_scorer.mapped_residues
-        chain_mapping = self.qs_scorer.chain_mapping
-        for ref_chain in ref_bs.chains:
-            ref_cname = ref_chain.GetName()
-            ref_olcs = list()
-            mdl_olcs = list()
-            for ref_r in ref_chain.residues:
-                ref_rnum = ref_r.GetNumber().GetNum()
-                ref_olcs.append(ref_r.one_letter_code)
-                mdl_res_found = False
-                if ref_cname in rmapping and ref_rnum in rmapping[ref_cname]:
-                    mdl_cname = chain_mapping[ref_cname]
-                    mdl_rnum = rmapping[ref_cname][ref_rnum]
-                    mdl_r = self.mdl.FindResidue(mdl_cname, mol.ResNum(mdl_rnum))
-                    if mdl_r.IsValid():
-                        mdl_res_found = True
-                        mdl_bs.AddResidue(mdl_r, mol.ViewAddFlag.INCLUDE_ALL)
-                        mdl_olcs.append(mdl_r.one_letter_code)
-                if not mdl_res_found:
-                    mdl_olcs.append('-')
-            if list(set(mdl_olcs)) != ['-']:
-                mdl_cname = chain_mapping[ref_cname]
-                a = seq.CreateAlignment()
-                a.AddSequence(seq.CreateSequence(ref_cname, ''.join(ref_olcs)))
-                a.AddSequence(seq.CreateSequence(mdl_cname, ''.join(mdl_olcs)))
-                alns[mdl_cname] = a
-
-        scorer = lddt.lDDTScorer(ref_bs, inclusion_radius = lddt_radius)
-        # lddt wants model chains mapped on target chain => invert
-        inv_chain_mapping = {v:k for k,v in chain_mapping.items()}
-        # additionally, lddt only wants chains in that mapping that
-        # actually exist in the provided structures
-        lddt_chain_mapping = {k: inv_chain_mapping[k] for k in alns.keys()}
-        score, _ = scorer.lDDT(mdl_bs, chain_mapping = lddt_chain_mapping,
-                               residue_mapping = alns)
-
-        if return_mapping:
-            trg_residues = [str(r) for r in ref_bs.residues]
-            mdl_residues = [str(r) for r in mdl_bs.residues]
-            return (score, trg_residues, mdl_residues)
+        # search the model for representations of the reference binding site
+        bs_repr = self.chain_mapper.GetRepr(ref_bs, self.mdl,
+                                            inclusion_radius = lddt_radius)
+        if len(bs_repr) >= 1:
+            return bs_repr[0]
         else:
-            return score
+            return None
diff --git a/modules/mol/alg/tests/test_lddt.py b/modules/mol/alg/tests/test_lddt.py
index b90249a68..f429b60eb 100644
--- a/modules/mol/alg/tests/test_lddt.py
+++ b/modules/mol/alg/tests/test_lddt.py
@@ -225,9 +225,8 @@ class TestlDDTBS(unittest.TestCase):
         ref = _LoadFile("lddtbs_ref_1r8q.1.pdb")
 
         lddtbs_scorer = lDDTBSScorer(reference=ref, model=mdl)
-        score, ref_residues, mdl_residues = \
-        lddtbs_scorer.ScoreBS(ref.Select("rname=AFB"), radius = 5.0,
-                              lddt_radius = 12.0, return_mapping=True)
+        bs_repr = lddtbs_scorer.ScoreBS(ref.Select("rname=AFB"), radius = 5.0,
+                                        lddt_radius = 12.0)
 
         # select residues manually from reference
         for at in ref.Select("rname=AFB").atoms:
@@ -238,7 +237,7 @@ class TestlDDTBS(unittest.TestCase):
         ref_bs = ref.Select("grasdf:0=1")
         ref_bs = ref_bs.Select("peptide=true")
         ref_bs_names = [r.GetQualifiedName() for r in ref_bs.residues]
-        self.assertEqual(sorted(ref_bs_names), sorted(ref_residues))
+        self.assertEqual(sorted(ref_bs_names), sorted(bs_repr.ref_residues))
 
 
         # everything below basically computes lDDTBS manually and
@@ -271,7 +270,7 @@ class TestlDDTBS(unittest.TestCase):
 
         # compute and compare
         lddt_scorer = lDDTScorer(sc_ref_bs, inclusion_radius=12.0)
-        self.assertAlmostEqual(score, lddt_scorer.lDDT(sc_mdl_bs)[0])
+        self.assertAlmostEqual(bs_repr.lDDT, lddt_scorer.lDDT(sc_mdl_bs)[0])
 
 
 if __name__ == "__main__":
-- 
GitLab