From 55082fcb8693c7f8be045d8b5ac42da7b5752d5d Mon Sep 17 00:00:00 2001
From: Rafal Gumienny <r.gumienny@unibas.ch>
Date: Wed, 18 Apr 2018 14:52:08 +0200
Subject: [PATCH] feat: SCHWED-3121 Output local lDDT scores

---
 actions/ost-compare-structures     | 45 +++++++++++++++++-----
 modules/mol/alg/pymod/qsscoring.py | 60 +++++++++++++++++++++++++++++-
 modules/seq/alg/pymod/renumber.py  | 14 ++++---
 3 files changed, 102 insertions(+), 17 deletions(-)

diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures
index 0c1dc991a..6aa18d7ba 100644
--- a/actions/ost-compare-structures
+++ b/actions/ost-compare-structures
@@ -293,6 +293,13 @@ def _ParseArgs():
         default=False,
         action="store_true",
         help=("Residue name consistency checks."))
+    parser.add_argument(
+        "-spr",
+        "--save-per-residue-scores",
+        dest="save_per_residue_scores",
+        default=False,
+        action="store_true",
+        help=(""))
     #
     # Molck parameters
     #
@@ -609,7 +616,7 @@ def _Main():
                     qs_scorer.alignments,
                     qs_scorer.calpha_only,
                     lddt_settings)
-                for lddt_scorer in oligo_lddt_scorer.sc_lddt_scorers:
+                for scorer_index, lddt_scorer in enumerate(oligo_lddt_scorer.sc_lddt_scorers):
                     # Get chains and renumber according to alignment (for lDDT)
                     try:
                         model_chain = lddt_scorer.model.chains[0].GetName()
@@ -619,15 +626,6 @@ def _Main():
                                      "chain %s and reference chain %s") % (
                                          model_chain,
                                          reference_chain))
-                        lddt_results["single_chain_lddt"].append({
-                            "status": "SUCCESS",
-                            "error": "",
-                            "model_chain": model_chain,
-                            "reference_chain": reference_chain,
-                            "global_score": lddt_scorer.global_score,
-                            "conserved_contacts":
-                                lddt_scorer.conserved_contacts,
-                            "total_contacts": lddt_scorer.total_contacts})
                         ost.LogInfo("Global LDDT score: %.4f" %
                                     lddt_scorer.global_score)
                         ost.LogInfo(
@@ -635,6 +633,33 @@ def _Main():
                             "%i thresholds)" % (lddt_scorer.conserved_contacts,
                                                 lddt_scorer.total_contacts,
                                                 len(lddt_settings.cutoffs)))
+                        sc_lddt_scores = {
+                            "status": "SUCCESS",
+                            "error": "",
+                            "model_chain": model_chain,
+                            "reference_chain": reference_chain,
+                            "global_score": lddt_scorer.global_score,
+                            "conserved_contacts":
+                                lddt_scorer.conserved_contacts,
+                            "total_contacts": lddt_scorer.total_contacts}
+                        if opts.save_per_residue_scores:
+                            per_residue_sc = oligo_lddt_scorer.GetPerResidueScores(
+                                scorer_index)
+                            ost.LogInfo("Per residue local lDDT (reference):")
+                            ost.LogInfo("Chain\tResidue Number\tResidue Name"
+                                         "\tlDDT\tConserved Contacts\tTotal "
+                                         "Contacts")
+                            for prs_scores in per_residue_sc:
+                                ost.LogInfo("%s\t%i\t%s\t%.4f\t%i\t%i" % (
+                                    reference_chain,
+                                    prs_scores["residue_number"],
+                                    prs_scores["residue_name"],
+                                    prs_scores["lddt"],
+                                    prs_scores["conserved_contacts"],
+                                    prs_scores["total_contacts"]))
+                            sc_lddt_scores["per_residue_scores"] = per_residue_sc
+                        lddt_results["single_chain_lddt"].append(
+                            sc_lddt_scores)
                     except Exception as ex:
                         ost.LogError('Single chain lDDT failed:', str(ex))
                         lddt_results["single_chain_lddt"].append({
diff --git a/modules/mol/alg/pymod/qsscoring.py b/modules/mol/alg/pymod/qsscoring.py
index 16e4357c8..f33b21707 100644
--- a/modules/mol/alg/pymod/qsscoring.py
+++ b/modules/mol/alg/pymod/qsscoring.py
@@ -839,6 +839,7 @@ class OligoLDDTScorer(object):
     self.alignments = alignments
     self.calpha_only = calpha_only
     self.settings = settings
+    self._old_number_label = "old_num"
     self._sc_lddt = None
     self._oligo_lddt = None
     self._weighted_lddt = None
@@ -892,15 +893,70 @@ class OligoLDDTScorer(object):
       self._sc_lddt_scorers = list()
       for aln in self.alignments:
         # Get chains and renumber according to alignment (for lDDT)
-        reference = Renumber(aln.GetSequence(0)).CreateFullView()
-        model = Renumber(aln.GetSequence(1)).CreateFullView()
+        reference = Renumber(
+          aln.GetSequence(0),
+          old_number_label=self._old_number_label).CreateFullView()
+        refseq = seq.CreateSequence(
+          "reference_renumbered",
+          aln.GetSequence(0).GetString())
+        refseq.AttachView(reference)
+        aln.AddSequence(refseq)
+        model = Renumber(
+          aln.GetSequence(1),
+          old_number_label=self._old_number_label).CreateFullView()
+        modelseq = seq.CreateSequence(
+          "model_renumbered",
+          aln.GetSequence(1).GetString())
+        modelseq.AttachView(model)
+        aln.AddSequence(modelseq)
         lddt_scorer = lDDTScorer(
           references=[reference],
           model=model,
           settings=self.settings)
+        lddt_scorer.alignment = aln  # a bit of a hack
         self._sc_lddt_scorers.append(lddt_scorer)
     return self._sc_lddt_scorers
 
+  def GetPerResidueScores(self, scorer_index):
+    scores = list()
+    assigned_residues = list()
+    # Make sure the score is calculated
+    self.sc_lddt_scorers[scorer_index].global_score
+    for col in self.sc_lddt_scorers[scorer_index].alignment:
+      if col[0] != "-" and col.GetResidue(3).IsValid():
+        ref_res = col.GetResidue(0)
+        mdl_res = col.GetResidue(1)
+        ref_res_renum = col.GetResidue(2)
+        mdl_res_renum = col.GetResidue(3)
+        if ref_res.one_letter_code != ref_res_renum.one_letter_code:
+          raise RuntimeError("Reference residue name mapping inconsistent: %s != %s" %
+                             (ref_res.one_letter_code,
+                              ref_res_renum.one_letter_code))
+        if mdl_res.one_letter_code != mdl_res_renum.one_letter_code:
+          raise RuntimeError("Model residue name mapping inconsistent: %s != %s" %
+                             (mdl_res.one_letter_code,
+                              mdl_res_renum.one_letter_code))
+        if ref_res.GetNumber().num != ref_res_renum.GetIntProp(self._old_number_label):
+          raise RuntimeError("Reference residue number mapping inconsistent: %s != %s" %
+                             (ref_res.GetNumber().num,
+                              ref_res_renum.GetIntProp(self._old_number_label)))
+        if mdl_res.GetNumber().num != mdl_res_renum.GetIntProp(self._old_number_label):
+          raise RuntimeError("Model residue number mapping inconsistent: %s != %s" %
+                             (mdl_res.GetNumber().num,
+                              mdl_res_renum.GetIntProp(self._old_number_label)))
+        if ref_res.qualified_name in assigned_residues:
+          raise RuntimeError("Duplicated residue in reference: " %
+                             (ref_res.qualified_name))
+        else:
+          assigned_residues.append(ref_res.qualified_name)
+        scores.append({
+          "residue_number": ref_res.GetNumber().num,
+          "residue_name": ref_res.name,
+          "lddt": mdl_res_renum.GetFloatProp(self.settings.label),
+          "conserved_contacts": mdl_res_renum.GetFloatProp(self.settings.label + "_conserved"),
+          "total_contacts": mdl_res_renum.GetFloatProp(self.settings.label + "_total")})
+    return scores
+
   @property
   def sc_lddt(self):
     if self._sc_lddt is None:
diff --git a/modules/seq/alg/pymod/renumber.py b/modules/seq/alg/pymod/renumber.py
index 9f6dd02d5..434732dc8 100644
--- a/modules/seq/alg/pymod/renumber.py
+++ b/modules/seq/alg/pymod/renumber.py
@@ -1,6 +1,6 @@
 from ost import seq, mol
 
-def _RenumberSeq(seq_handle):
+def _RenumberSeq(seq_handle, old_number_label=None):
   if not seq_handle.HasAttachedView():
     raise RuntimeError("Sequence Handle has no attached view")
   ev = seq_handle.attached_view.CreateEmptyView()
@@ -11,11 +11,13 @@ def _RenumberSeq(seq_handle):
       if r.IsValid():
           ev.AddResidue(r, mol.INCLUDE_ALL)
           new_numbers.append(pos+1)
+          if old_number_label is not None:
+            r.SetIntProp(old_number_label, r.number.GetNum())
       else:
         raise RuntimeError('Error: renumbering failed at position %s' % pos)
   return ev, new_numbers
 
-def _RenumberAln(aln, seq_index):
+def _RenumberAln(aln, seq_index, old_number_label=None):
   if not aln.sequences[seq_index].HasAttachedView():
     raise RuntimeError("Sequence Handle has no attached view")
   counter=0
@@ -34,11 +36,13 @@ def _RenumberAln(aln, seq_index):
                            % (counter))
       ev.AddResidue(r, mol.INCLUDE_ALL)
       new_numbers.append(counter+1)
+      if old_number_label is not None:
+        r.SetIntProp(old_number_label, r.number.GetNum())
     counter += 1
   return ev, new_numbers
 
 
-def Renumber(seq_handle, sequence_number_with_attached_view=1):
+def Renumber(seq_handle, sequence_number_with_attached_view=1, old_number_label=None):
   """
   Function to renumber an entity according to an alignment between the model
   sequence and the full-length target sequence. The aligned model sequence or
@@ -70,9 +74,9 @@ def Renumber(seq_handle, sequence_number_with_attached_view=1):
   """
   if isinstance(seq_handle, seq.SequenceHandle) \
      or isinstance(seq_handle, seq.ConstSequenceHandle):
-    ev, new_numbers = _RenumberSeq(seq_handle)
+    ev, new_numbers = _RenumberSeq(seq_handle, old_number_label)
   elif isinstance(seq_handle, seq.AlignmentHandle):
-    ev, new_numbers = _RenumberAln(seq_handle, sequence_number_with_attached_view)
+    ev, new_numbers = _RenumberAln(seq_handle, sequence_number_with_attached_view, old_number_label)
   else:
     raise RuntimeError("Unknown input type " + str(type(seq_handle)))
 
-- 
GitLab