From 97a859ada4145fda33527cb9963585335d0ba32d Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Tue, 28 Feb 2023 14:52:43 +0100
Subject: [PATCH] scoring: make alignments available as attributes of Scorer
 object

---
 modules/mol/alg/pymod/scoring.py | 104 ++++++++++++++++++++-----------
 1 file changed, 66 insertions(+), 38 deletions(-)

diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index d42390eba..645a02a01 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -206,6 +206,8 @@ class Scorer:
         self._mapping = None
         self._model_interface_residues = None
         self._target_interface_residues = None
+        self._aln = None
+        self._stereochecked_aln = None
 
         # lazily constructed scorer objects
         self._lddt_scorer = None
@@ -266,6 +268,31 @@ class Scorer:
         """
         return self._target
 
+    @property
+    def aln(self):
+        """ Alignments of :attr:`model`/:attr:`target` chains
+
+        Alignments for each pair of chains mapped in :attr:`mapping`.
+        First sequence is target sequence, second sequence the model sequence.
+
+        :type: :class:`list` of :class:`ost.seq.AlignmentHandle`
+        """
+        if self._aln is None:
+            self._compute_aln()
+        return self._aln
+
+    @property
+    def stereochecked_aln(self):
+        """ Stereochecked equivalent of :attr:`aln`
+
+        The alignments may differ, as stereochecks potentially remove residues
+
+        :type: :class:``
+        """
+        if self._aln is None:
+            self._compute_stereochecked_aln()
+        return self._aln
+
     @property
     def stereochecked_model(self):
         """ View of :attr:`~model` that has stereochemistry checks applied
@@ -882,35 +909,31 @@ class Scorer:
             self._compute_patchdockq_scores()
         return self._patch_dockq
 
-    def _compute_lddt(self):
-        # lDDT requires a flat mapping with mdl_ch as key and trg_ch as value
-        flat_mapping = self.mapping.GetFlatMapping(mdl_as_key=True)
-
+    def _aln_helper(self, target, model):
         # perform required alignments - cannot take the alignments from the
-        # mapping results as we potentially remove stuff in the stereocheck
-        # process.
+        # mapping results as we potentially remove stuff there as compared
+        # to self.model and self.target
         trg_seqs = dict()
-        for ch in self.stereochecked_target.chains:
+        for ch in target.chains:
             cname = ch.GetName()
             s = ''.join([r.one_letter_code for r in ch.residues])
             s = seq.CreateSequence(ch.GetName(), s)
-            s.AttachView(self.stereochecked_target.Select(f"cname={cname}"))
+            s.AttachView(target.Select(f"cname={cname}"))
             trg_seqs[ch.GetName()] = s
         mdl_seqs = dict()
-        for ch in self.stereochecked_model.chains:
+        for ch in model.chains:
             cname = ch.GetName()
             s = ''.join([r.one_letter_code for r in ch.residues])
             s = seq.CreateSequence(cname, s)
-            s.AttachView(self.stereochecked_model.Select(f"cname={cname}"))
+            s.AttachView(model.Select(f"cname={cname}"))
             mdl_seqs[ch.GetName()] = s
 
+        alns = list()
         trg_pep_chains = [s.GetName() for s in self.chain_mapper.polypep_seqs]
         trg_nuc_chains = [s.GetName() for s in self.chain_mapper.polynuc_seqs]
         trg_pep_chains = set(trg_pep_chains)
         trg_nuc_chains = set(trg_nuc_chains)
-        lddt_alns = dict()
-        lddt_chain_mapping = dict()
-        for mdl_ch, trg_ch in flat_mapping.items():
+        for mdl_ch, trg_ch in self.mapping.GetFlatMapping().items():
             if mdl_ch in mdl_seqs and trg_ch in trg_seqs:
                 if trg_ch in trg_pep_chains:
                     stype = mol.ChemType.AMINOACIDS
@@ -919,9 +942,32 @@ class Scorer:
                 else:
                     raise RuntimeError("Chain name inconsistency... ask "
                                        "Gabriel")
-                lddt_alns[mdl_ch] = self.chain_mapper.Align(trg_seqs[trg_ch],
-                                                            mdl_seqs[mdl_ch],
-                                                            stype)
+                alns.append(self.chain_mapper.Align(trg_seqs[trg_ch],
+                                                    mdl_seqs[mdl_ch],
+                                                    stype))
+                alns[-1].AttachView(0, trg_seqs[trg_ch].GetAttachedView())
+                alns[-1].AttachView(1, mdl_seqs[mdl_ch].GetAttachedView())
+        return alns
+
+    def _compute_aln(self):
+        self._aln = self._aln_helper(self.target, self.model)
+
+    def _compute_stereochecked_aln(self):
+        self._stereochecked_aln = self._aln_helper(self.stereochecked_target,
+                                                   self.stereochecked_model)
+
+    def _compute_lddt(self):
+        # lDDT requires a flat mapping with mdl_ch as key and trg_ch as value
+        flat_mapping = self.mapping.GetFlatMapping(mdl_as_key=True)
+
+        lddt_alns = dict()
+        for aln in self.stereochecked_aln:
+            mdl_seq = aln.GetSequence(1)
+            lddt_alns[mdl_seq.name] = aln
+
+        lddt_chain_mapping = dict()
+        for mdl_ch, trg_ch in flat_mapping.items():
+            if mdl_ch in lddt_alns:
                 lddt_chain_mapping[mdl_ch] = trg_ch
 
         lddt_score = self.lddt_scorer.lDDT(self.stereochecked_model,
@@ -996,30 +1042,12 @@ class Scorer:
         flat_mapping = self.mapping.GetFlatMapping()
         pep_seqs = set([s.GetName() for s in self.chain_mapper.polypep_seqs])
 
-        # perform required alignments - cannot take the alignments from the
-        # mapping results as we potentially remove stuff there as compared
-        # to self.model and self.target
-        trg_seqs = dict()
-        for ch in self.target.chains:
-            cname = ch.GetName()
-            s = ''.join([r.one_letter_code for r in ch.residues])
-            s = seq.CreateSequence(ch.GetName(), s)
-            s.AttachView(self.target.Select(f"cname={cname}"))
-            trg_seqs[ch.GetName()] = s
-        mdl_seqs = dict()
-        for ch in self.model.chains:
-            cname = ch.GetName()
-            s = ''.join([r.one_letter_code for r in ch.residues])
-            s = seq.CreateSequence(cname, s)
-            s.AttachView(self.model.Select(f"cname={cname}"))
-            mdl_seqs[ch.GetName()] = s
-
         dockq_alns = dict()
-        for trg_ch, mdl_ch in flat_mapping.items():
+        for aln in self.aln:
+            trg_ch = aln.GetSequence(0).name
             if trg_ch in pep_seqs:
-                dockq_alns[(trg_ch, mdl_ch)] = \
-                self.chain_mapper.Align(trg_seqs[trg_ch], mdl_seqs[mdl_ch],
-                                        mol.ChemType.AMINOACIDS)
+                mdl_ch = aln.GetSequence(1).name
+                dockq_alns[(trg_ch, mdl_ch)] = aln
 
         for trg_int in self.qs_scorer.qsent1.interacting_chains:
             trg_ch1 = trg_int[0]
-- 
GitLab