From 94618420e88851c49a919f53f42bb1ca6287f8f1 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Fri, 12 May 2023 00:09:45 +0200
Subject: [PATCH] add TM-score to Scorer object

---
 modules/mol/alg/pymod/scoring.py | 61 +++++++++++++++++++++++++++++++-
 1 file changed, 60 insertions(+), 1 deletion(-)

diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index 79fd4a29f..1a11304c8 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -13,7 +13,9 @@ from ost.mol.alg import dockq
 from ost.mol.alg.lddt import lDDTScorer
 from ost.mol.alg.qsscore import QSScorer
 from ost.mol.alg import Molck, MolckSettings
+from ost import bindings
 from ost.bindings import cadscore
+from ost.bindings import tmtools
 import numpy as np
 
 class lDDTBSScorer:
@@ -125,10 +127,15 @@ class Scorer:
                            *target*. Dictionary with target chain names as key
                            and model chain names as value.
     :type custom_mapping: :class:`dict`
+    :param usalign_exec: Explicit path to USalign executable used to compute
+                         TM-score. If not given, TM-score will be computed
+                         with OpenStructure internal copy of USalign code.
+    :type usalign_exec: :class:`str`
     """
     def __init__(self, model, target, resnum_alignments=False,
                  molck_settings = None, naive_chain_mapping_thresh=12,
-                 cad_score_exec = None, custom_mapping=None):
+                 cad_score_exec = None, custom_mapping=None,
+                 usalign_exec = None):
 
         if isinstance(model, mol.EntityView):
             model = mol.CreateEntityFromView(model, False)
@@ -195,6 +202,7 @@ class Scorer:
         self.resnum_alignments = resnum_alignments
         self.naive_chain_mapping_thresh = naive_chain_mapping_thresh
         self.cad_score_exec = cad_score_exec
+        self.usalign_exec = usalign_exec
 
         # lazily evaluated attributes
         self._stereochecked_model = None
@@ -258,6 +266,9 @@ class Scorer:
         self._patch_qs = None
         self._patch_dockq = None
 
+        self._tm_score = None
+        self._usalign_mapping = None
+
         if custom_mapping is not None:
             self._set_custom_mapping(custom_mapping)
 
@@ -961,6 +972,35 @@ class Scorer:
             self._compute_patchdockq_scores()
         return self._patch_dockq
 
+    @property
+    def tm_score(self):
+        """ TM-score computed with USalign
+
+        USalign executable can be specified with usalign_exec kwarg at Scorer
+        construction, an OpenStructure internal copy of the USalign code is
+        used otherwise.
+
+        :type: :class:`float`
+        """
+        if self._tm_score is None:
+            self._compute_tmscore()
+        return self._tm_score
+
+    @property
+    def usalign_mapping(self):
+        """ Mapping computed with USalign
+
+        Dictionary with target chain names as key and model chain names as
+        values. No guarantee that all chains are mapped. USalign executable
+        can be specified with usalign_exec kwarg at Scorer construction, an
+        OpenStructure internal copy of the USalign code is used otherwise.
+
+        :type: :class:`dict`
+        """
+        if self._usalign_mapping is None:
+            self._compute_tmscore()
+        return self._usalign_mapping
+
     def _aln_helper(self, target, model):
         # perform required alignments - cannot take the alignments from the
         # mapping results as we potentially remove stuff there as compared
@@ -1644,3 +1684,22 @@ class Scorer:
         self._mapping = chain_mapping.MappingResult(chain_mapper.target, mdl,
                                                     chain_mapper.chem_groups,
                                                     final_mapping, alns)
+
+    def _compute_tmscore(self):
+        res = None
+        if self.usalign_exec is not None:
+            if not os.path.exists(self.usalign_exec):
+                raise RuntimeError(f"USalign exec ({self.usalign_exec}) "
+                                   f"not found")
+            if not os.access(self.usalign_exec, os.X_OK):
+                raise RuntimeError(f"USalign exec ({self.usalign_exec}) "
+                                   f"is not executable")
+            res = tmtools.USAlign(self.model, self.target,
+                                  usalign = self.usalign_exec)
+        else:
+            res = bindings.WrappedMMAlign(self.model, self.target)
+
+        self._tm_score = res.tm_score
+        self._usalign_mapping = dict()
+        for a,b in zip(res.ent1_mapped_chains, res.ent2_mapped_chains):
+            self._usalign_mapping[b] = a
-- 
GitLab