From b7163962dfbd0f8939bcd8c2ee187d2294e4beb0 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Wed, 24 May 2023 14:15:43 +0200
Subject: [PATCH] flag to disable lddt stereochemistry checks in
 ost.mol.alg.Scorer

---
 modules/mol/alg/pymod/scoring.py | 109 ++++++++++++++++++++-----------
 1 file changed, 71 insertions(+), 38 deletions(-)

diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index 1a11304c8..f9e7ec864 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -131,11 +131,14 @@ class Scorer:
                          TM-score. If not given, TM-score will be computed
                          with OpenStructure internal copy of USalign code.
     :type usalign_exec: :class:`str`
+    :param lddt_no_stereochecks: If True, skip stereochemistry checks when
+                                 computing lDDT (default: False)
+    :type lddt_no_stereochecks: :class:`bool`
     """
     def __init__(self, model, target, resnum_alignments=False,
                  molck_settings = None, naive_chain_mapping_thresh=12,
                  cad_score_exec = None, custom_mapping=None,
-                 usalign_exec = None):
+                 usalign_exec = None, lddt_no_stereochecks=False):
 
         if isinstance(model, mol.EntityView):
             model = mol.CreateEntityFromView(model, False)
@@ -203,6 +206,7 @@ class Scorer:
         self.naive_chain_mapping_thresh = naive_chain_mapping_thresh
         self.cad_score_exec = cad_score_exec
         self.usalign_exec = usalign_exec
+        self.lddt_no_stereochecks = lddt_no_stereochecks
 
         # lazily evaluated attributes
         self._stereochecked_model = None
@@ -474,7 +478,10 @@ class Scorer:
         :type: :class:`ost.mol.alg.lddt.lDDTScorer`
         """
         if self._lddt_scorer is None:
-            self._lddt_scorer = lDDTScorer(self.stereochecked_target)
+            if self.lddt_no_stereochecks:
+                self._lddt_scorer = lDDTScorer(self.target)
+            else:
+                self._lddt_scorer = lDDTScorer(self.stereochecked_target)
         return self._lddt_scorer
 
     @property
@@ -1062,48 +1069,74 @@ class Scorer:
             mdl_seq = aln.GetSequence(1)
             alns[mdl_seq.name] = aln
 
-        lddt_chain_mapping = dict()
-        for mdl_ch, trg_ch in flat_mapping.items():
-            if mdl_ch in stereochecked_alns:
-                lddt_chain_mapping[mdl_ch] = trg_ch
-
-        lddt_score = self.lddt_scorer.lDDT(self.stereochecked_model,
-                                           chain_mapping = lddt_chain_mapping,
-                                           residue_mapping = stereochecked_alns,
-                                           check_resnames=False,
-                                           local_lddt_prop="lddt")[0]
-        local_lddt = dict()
-        for r in self.model.residues:
-            cname = r.GetChain().GetName()
-            if cname not in local_lddt:
-                local_lddt[cname] = dict()
-            if r.HasProp("lddt"):
-                score = round(r.GetFloatProp("lddt"), 3)
-                local_lddt[cname][r.GetNumber()] = score
-            else:
-                # rsc => residue stereo checked...
-                mdl_res = self.stereochecked_model.FindResidue(cname, r.GetNumber())
-                if mdl_res.IsValid():
+        # score variables to be set
+        lddt_score = None
+        local_lddt = None
+
+        if self.lddt_no_stereochecks:
+            lddt_chain_mapping = dict()
+            for mdl_ch, trg_ch in flat_mapping.items():
+                if mdl_ch in alns:
+                    lddt_chain_mapping[mdl_ch] = trg_ch
+            lddt_score = self.lddt_scorer.lDDT(self.model,
+                                               chain_mapping = lddt_chain_mapping,
+                                               residue_mapping = alns,
+                                               check_resnames=False,
+                                               local_lddt_prop="lddt")[0]
+            local_lddt = dict()
+            for r in self.model.residues:
+                cname = r.GetChain().GetName()
+                if cname not in local_lddt:
+                    local_lddt[cname] = dict()
+                if r.HasProp("lddt"):
+                    score = round(r.GetFloatProp("lddt"), 3)
+                    local_lddt[cname][r.GetNumber()] = score
+                else:
                     # not covered by trg or skipped in chain mapping procedure
                     # the latter happens if its part of a super short chain
                     local_lddt[cname][r.GetNumber()] = None
+
+        else:
+            lddt_chain_mapping = dict()
+            for mdl_ch, trg_ch in flat_mapping.items():
+                if mdl_ch in stereochecked_alns:
+                    lddt_chain_mapping[mdl_ch] = trg_ch
+            lddt_score = self.lddt_scorer.lDDT(self.stereochecked_model,
+                                               chain_mapping = lddt_chain_mapping,
+                                               residue_mapping = stereochecked_alns,
+                                               check_resnames=False,
+                                               local_lddt_prop="lddt")[0]
+            local_lddt = dict()
+            for r in self.model.residues:
+                cname = r.GetChain().GetName()
+                if cname not in local_lddt:
+                    local_lddt[cname] = dict()
+                if r.HasProp("lddt"):
+                    score = round(r.GetFloatProp("lddt"), 3)
+                    local_lddt[cname][r.GetNumber()] = score
                 else:
-                    # opt 1: removed by stereochecks => assign 0.0
-                    # opt 2: removed by stereochecks AND not covered by ref
-                    #        => assign None
-
-                    # fetch trg residue from non-stereochecked aln
-                    trg_r = None
-                    if cname in flat_mapping:
-                        for col in alns[cname]:
-                            if col[0] != '-' and col[1] != '-':
-                                if col.GetResidue(1).number == r.number:
-                                    trg_r = col.GetResidue(0)
-                                    break
-                    if trg_r is None:
+                    # look up the residue in the stereochecked model
+                    mdl_res = self.stereochecked_model.FindResidue(cname, r.GetNumber())
+                    if mdl_res.IsValid():
+                        # not covered by trg or skipped in chain mapping procedure
+                        # the latter happens if it's part of a super short chain
                         local_lddt[cname][r.GetNumber()] = None
                     else:
-                        local_lddt[cname][r.GetNumber()] = 0.0
+                        # opt 1: removed by stereochecks => assign 0.0
+                        # opt 2: removed by stereochecks AND not covered by ref
+                        #        => assign None
+                        # fetch trg residue from non-stereochecked aln
+                        trg_r = None
+                        if cname in flat_mapping:
+                            for col in alns[cname]:
+                                if col[0] != '-' and col[1] != '-':
+                                    if col.GetResidue(1).number == r.number:
+                                        trg_r = col.GetResidue(0)
+                                        break
+                        if trg_r is None:
+                            local_lddt[cname][r.GetNumber()] = None
+                        else:
+                            local_lddt[cname][r.GetNumber()] = 0.0
 
         self._lddt = lddt_score
         self._local_lddt = local_lddt
-- 
GitLab