From 2170c86235a90339409aec48ea58474903defa9d Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Wed, 2 Nov 2022 13:50:21 +0100
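Subject: [PATCH] enable CAD score in Scorer object

Adds the cad_score/local_cad_score properties together with the
dockq_exec/cad_score_exec constructor arguments. Also renames the
QS_global/QS_best properties to qs_global/qs_best, requires
resnum_alignments=True for DockQ and CAD score computations and reports
None as local lDDT for residues removed by stereochecks that are not
covered by the target.
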
---
 modules/mol/alg/pymod/scoring.py | 143 ++++++++++++++++++++++++++-----
 1 file changed, 123 insertions(+), 20 deletions(-)

diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index 11cdb0693..400fb51b2 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -1,3 +1,4 @@
+import os
 from ost import mol
 from ost import seq
 from ost import io
@@ -12,6 +13,7 @@ from ost.mol.alg.qsscore import QSScorer
 from ost.io import ReadStereoChemicalPropsFile
 from ost.mol.alg import CheckStructure, Molck, MolckSettings
 from ost.bindings import dockq
+from ost.bindings import cadscore
 import numpy as np
 
 class lDDTBSScorer:
@@ -112,9 +114,21 @@ class Scorer:
                                        to optimize for QS-score. Everything
                                        above is treated with a heuristic.
     :type naive_chain_mapping_thresh: :class:`int` 
+    :param dockq_exec: Explicit path to DockQ.py script from DockQ installation
+                       from https://github.com/bjornwallner/DockQ. If not given,
+                       DockQ.py must be in PATH if any of the DockQ related
+                       attributes is requested.
+    :type dockq_exec: :class:`str`
+    :param cad_score_exec: Explicit path to the voronota-cadscore executable
+                           of a voronota installation from
+                           https://github.com/kliment-olechnovic/voronota. If
+                           not given, voronota-cadscore must be in PATH if any
+                           of the CAD score related attributes is requested.
+    :type cad_score_exec: :class:`str`
     """
     def __init__(self, model, target, resnum_alignments=False,
-                 molck_settings = None, naive_chain_mapping_thresh=12):
+                 molck_settings = None, naive_chain_mapping_thresh=12,
+                 dockq_exec = None, cad_score_exec = None):
 
         if isinstance(model, mol.EntityView):
             self._model = mol.CreateEntityFromView(model, False)
@@ -169,6 +183,8 @@ class Scorer:
         Molck(self._target, conop.GetDefaultLib(), molck_settings)
         self.resnum_alignments = resnum_alignments
         self.naive_chain_mapping_thresh = naive_chain_mapping_thresh
+        self.dockq_exec = dockq_exec
+        self.cad_score_exec = cad_score_exec
 
         # lazily evaluated attributes
         self._stereochecked_model = None
@@ -186,8 +202,8 @@ class Scorer:
         self._lddt = None
         self._local_lddt = None
 
-        self._QS_global = None
-        self._QS_best = None
+        self._qs_global = None
+        self._qs_best = None
 
         self._dockq_interfaces = None
         self._dockq_native_contacts = None
@@ -208,6 +224,9 @@ class Scorer:
         self._gdtha = None
         self._rmsd = None
 
+        self._cad_score = None
+        self._local_cad_score = None
+
     @property
     def model(self):
         """ Model with Molck cleanup
@@ -370,26 +389,26 @@ class Scorer:
         chain mapping procedure (happens for super short chains), the respective
         score is set to None. In case of oligomers, :attr:`~mapping` is used.
 
-        :type: :class:`float`
+        :type: :class:`dict`
         """
         if self._local_lddt is None:
             self._compute_lddt()
         return self._local_lddt
 
     @property
-    def QS_global(self):
+    def qs_global(self):
         """  Global QS-score
 
         Computed based on :attr:`model` using :attr:`mapping`
 
         :type: :class:`float`
         """
-        if self._QS_global is None:
+        if self._qs_global is None:
             self._compute_qs()
-        return self._QS_global
+        return self._qs_global
 
     @property
-    def QS_best(self):
+    def qs_best(self):
         """  Global QS-score - only computed on aligned residues
 
         Computed based on :attr:`model` using :attr:`mapping`. The QS-score
@@ -399,9 +418,9 @@ class Scorer:
 
         :type: :class:`float`
         """
-        if self._QS_best is None:
+        if self._qs_best is None:
             self._compute_qs()
-        return self._QS_best
+        return self._qs_best
 
     @property
     def dockq_interfaces(self):
@@ -628,6 +647,32 @@ class Scorer:
             self._rmsd = \
             self.mapped_target_pos.GetRMSD(self.transformed_mapped_model_pos)
         return self._rmsd
+
+    @property
+    def cad_score(self):
+        """ The global CAD atom-atom (AA) score
+
+        Computed based on :attr:`~model` using :attr:`~mapping`. Raises a
+        RuntimeError if resnum_alignments is False or no executable is found.
+
+        :type: :class:`float`
+        """
+        if self._cad_score is None:
+            self._compute_cad_score()
+        return self._cad_score
+
+    @property
+    def local_cad_score(self):
+        """ The per-residue CAD atom-atom (AA) scores
+
+        Computed based on :attr:`~model` using :attr:`~mapping`. Keys are chain
+        names, values are dicts mapping residue number to score (or None).
+
+        :type: :class:`dict`
+        """
+        if self._local_cad_score is None:
+            self._compute_cad_score()
+        return self._local_cad_score
     
     def _compute_lddt(self):
         # lDDT requires a flat mapping with mdl_ch as key and trg_ch as value
@@ -682,27 +727,49 @@ class Scorer:
                 local_lddt[cname][r.GetNumber().GetNum()] = score
             else:
                 # rsc => residue stereo checked...
-                rsc = self.stereochecked_model.FindResidue(cname, r.GetNumber())
-                if not rsc.IsValid():
-                    # has been removed by stereochecks => assign 0.0
-                    local_lddt[cname][r.GetNumber().GetNum()] = 0.0
-                else:
+                mdl_res = self.stereochecked_model.FindResidue(cname, r.GetNumber())
+                if mdl_res.IsValid():
                     # not covered by trg or skipped in chain mapping procedure
                     # the latter happens if its part of a super short chain
                     local_lddt[cname][r.GetNumber().GetNum()] = None
+                else:
+                    # opt 1: removed by stereochecks => assign 0.0
+                    # opt 2: removed by stereochecks AND not covered by ref
+                    #        => assign None
+
+                    # fetch trg residue from non-stereochecked aln
+                    aln = self.mapping.alns[(flat_mapping[cname], cname)]
+                    trg_r = None
+                    for col in aln:
+                        if col[0] != '-' and col[1] != '-':
+                            if col.GetResidue(1).GetNumber() == r.GetNumber():
+                                trg_r = col.GetResidue(0)
+                                break
+                    if trg_r is None:
+                        local_lddt[cname][r.GetNumber().GetNum()] = None
+                    else:
+                        local_lddt[cname][r.GetNumber().GetNum()] = 0.0
+
         self._lddt = lddt_score
         self._local_lddt = local_lddt
 
     def _compute_qs(self):
         qs_score_result = self.qs_scorer.Score(self.mapping.mapping)
-        self._QS_global = qs_score_result.QS_global
-        self._QS_best = qs_score_result.QS_best
+        self._qs_global = qs_score_result.QS_global
+        self._qs_best = qs_score_result.QS_best
 
     def _compute_dockq(self):
+        if not self.resnum_alignments:
+            raise RuntimeError("DockQ computations rely on residue numbers "
+                               "that are consistent between target and model "
+                               "chains, i.e. only work if resnum_alignments "
+                               "is True at Scorer construction.")
         try:
-            dockq_exec = settings.Locate("DockQ.py")
-        except:
-            raise RuntimeError("DockQ.py must be in PATH for DockQ scoring")
+            dockq_exec = settings.Locate("DockQ.py",
+                                         explicit_file_name=self.dockq_exec)
+        except Exception as e:
+            raise RuntimeError("DockQ.py must be in PATH for DockQ "
+                               "scoring") from e
 
         flat_mapping = self.mapping.GetFlatMapping()
         # list of [trg_ch1, trg_ch2, mdl_ch1, mdl_ch2]
@@ -799,6 +866,42 @@ class Scorer:
             if ch.GetName() not in processed_trg_chains:
                 self._n_target_not_mapped += len(ch.residues)
 
+    def _compute_cad_score(self):
+        """ Calls cadscore.CADScore in voronota mode and caches the results
+
+        Fills the global score and the per-chain/per-residue local scores.
+        Raises RuntimeError if resnum_alignments is False or if locating
+        the voronota-cadscore executable fails.
+        """
+        if not self.resnum_alignments:
+            raise RuntimeError("CAD score computations rely on residue numbers "
+                               "that are consistent between target and model "
+                               "chains, i.e. only work if resnum_alignments "
+                               "is True at Scorer construction.")
+        try:
+            exec_path = settings.Locate("voronota-cadscore",
+                                        explicit_file_name=self.cad_score_exec)
+        except Exception as e:
+            raise RuntimeError("voronota-cadscore must be in PATH for CAD "
+                               "score scoring") from e
+        # CADScore is handed the directory of the located executable
+        flat_mapping = self.mapping.GetFlatMapping(mdl_as_key=True)
+        result = cadscore.CADScore(self.model, self.target, mode="voronota",
+                                   label="localcad", old_regime=False,
+                                   cad_bin_path=os.path.dirname(exec_path),
+                                   chain_mapping=flat_mapping)
+        # scores are read back from the float property "localcad" (the label
+        # passed to CADScore) on the model residues, None if it is not set
+        per_chain = dict()
+        for res in self.model.residues:
+            ch_scores = per_chain.setdefault(res.GetChain().GetName(), dict())
+            value = None
+            if res.HasProp("localcad"):
+                value = round(res.GetFloatProp("localcad"), 3)
+            ch_scores[res.GetNumber().GetNum()] = value
+        self._cad_score = result.globalAA
+        self._local_cad_score = per_chain
+
     def _get_repr_view(self, ent):
         """ Returns view with representative atoms => CB, CA for GLY
     
-- 
GitLab