From f8e75963d776ff41e52502658329bd13ad70cb6d Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Fri, 13 Sep 2024 17:20:18 +0200
Subject: [PATCH] ics/ips: make per-interface scores for trimmed variant
 available

---
 actions/ost-compare-structures   |  36 +++++++--
 modules/mol/alg/pymod/scoring.py | 135 ++++++++++++++++++++++++++++++-
 2 files changed, 161 insertions(+), 10 deletions(-)

diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures
index 57ed24b90..b38bb4b7e 100644
--- a/actions/ost-compare-structures
+++ b/actions/ost-compare-structures
@@ -444,10 +444,20 @@ def _ParseArgs():
               "counterpart are removed. As a consequence, model contacts for "
               "which we have no experimental evidence do not affect the score. "
               "The effect of these added model contacts without mapping to "
-              "target are decreased precision and thus lower ics. Recall is "
+              "target would be decreased precision and thus lower ics. Recall is "
               "not affected. Enabling this flag adds the following keys: "
               "\"ics_trimmed\", \"ics_precision_trimmed\", "
-              "\"ics_recall_trimmed\" and \"model_contacts_trimmed\""))
+              "\"ics_recall_trimmed\", \"model_contacts_trimmed\". "
+              "The reference contacts and reference interfaces are the same "
+              "as for ics and available as keys: \"reference_contacts\", "
+              "\"contact_reference_interfaces\". "
+              "All these measures are also available on a per-interface basis "
+              "for each interface in the reference structure that are defined "
+              "as chain pairs with at least one contact (available as key "
+              " \"contact_reference_interfaces\"). The respective metrics are "
+              "available as keys \"per_interface_ics_precision_trimmed\", "
+              "\"per_interface_ics_recall_trimmed\" and "
+              "\"per_interface_ics_trimmed\"."))
 
     parser.add_argument(
         "--ips",
@@ -906,8 +916,13 @@ def _Process(model, reference, args, model_format, reference_format):
         [_RoundOrNone(x) for x in scorer.per_interface_qs_best]
 
     if args.ics or args.ips:
-        out["reference_contacts"] = scorer.native_contacts
         out["model_contacts"] = scorer.model_contacts
+
+    if args.ics_trimmed or args.ips_trimmed:
+        out["model_contacts_trimmed"] = scorer.trimmed_model_contacts
+
+    if args.ics or args.ips or args.ics_trimmed or args.ips_trimmed:
+        out["reference_contacts"] = scorer.native_contacts
         out["contact_reference_interfaces"] = scorer.contact_target_interfaces
 
     if args.ics:
@@ -932,18 +947,27 @@ def _Process(model, reference, args, model_format, reference_format):
         out["per_interface_ips"] = \
         [_RoundOrNone(x) for x in scorer.per_interface_ips]
 
-    if args.ics_trimmed or args.ips_trimmed:
-        out["model_contacts_trimmed"] = scorer.trimmed_model_contacts
-
     if args.ics_trimmed:
         out["ics_trimmed"] = _RoundOrNone(scorer.ics_trimmed)
         out["ics_precision_trimmed"] = _RoundOrNone(scorer.ics_precision_trimmed)
         out["ics_recall_trimmed"] = _RoundOrNone(scorer.ics_recall_trimmed)
+        out["per_interface_ics_precision_trimmed"] = \
+        [_RoundOrNone(x) for x in scorer.per_interface_ics_precision_trimmed]
+        out["per_interface_ics_recall_trimmed"] = \
+        [_RoundOrNone(x) for x in scorer.per_interface_ics_recall_trimmed]
+        out["per_interface_ics_trimmed"] = \
+        [_RoundOrNone(x) for x in scorer.per_interface_ics_trimmed]
 
     if args.ips_trimmed:
         out["ips_trimmed"] = _RoundOrNone(scorer.ips_trimmed)
         out["ips_precision_trimmed"] = _RoundOrNone(scorer.ips_precision_trimmed)
         out["ips_recall_trimmed"] = _RoundOrNone(scorer.ips_recall_trimmed)
+        out["per_interface_ips_precision_trimmed"] = \
+        [_RoundOrNone(x) for x in scorer.per_interface_ips_precision_trimmed]
+        out["per_interface_ips_recall_trimmed"] = \
+        [_RoundOrNone(x) for x in scorer.per_interface_ips_recall_trimmed]
+        out["per_interface_ips_trimmed"] = \
+        [_RoundOrNone(x) for x in scorer.per_interface_ips_trimmed]
 
     if args.dockq:
         out["dockq_reference_interfaces"] = scorer.dockq_target_interfaces
diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index d1ec0705d..710223b68 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -374,9 +374,9 @@ class Scorer:
         self._ips_precision = None
         self._ips_recall = None
         self._ips = None
-        self._per_interface_ics_precision = None
-        self._per_interface_ics_recall = None
-        self._per_interface_ics = None
+        self._per_interface_ips_precision = None
+        self._per_interface_ips_recall = None
+        self._per_interface_ips = None
 
         # subset of contact scores that operate on trimmed model
         # i.e. no contacts from model residues that are not present in
@@ -384,9 +384,15 @@ class Scorer:
         self._ics_trimmed = None
         self._ics_precision_trimmed = None
         self._ics_recall_trimmed = None
+        self._per_interface_ics_precision_trimmed = None
+        self._per_interface_ics_recall_trimmed = None
+        self._per_interface_ics_trimmed = None
         self._ips_trimmed = None
         self._ips_precision_trimmed = None
         self._ips_recall_trimmed = None
+        self._per_interface_ips_precision_trimmed = None
+        self._per_interface_ips_recall_trimmed = None
+        self._per_interface_ips_trimmed = None
 
         self._dockq_target_interfaces = None
         self._dockq_interfaces = None
@@ -1096,7 +1102,6 @@ class Scorer:
         if self._per_interface_ics is None:
             self._compute_ics_scores()
         return self._per_interface_ics
-    
 
     @property
     def ips_precision(self):
@@ -1175,6 +1180,47 @@ class Scorer:
             self._compute_ics_scores_trimmed()
         return self._ics_recall_trimmed
 
+    @property
+    def per_interface_ics_precision_trimmed(self):
+        """ Same as :attr:`per_interface_ics_precision` but with :attr:`trimmed_model`
+
+        :attr:`~ics_precision_trimmed` for each interface in
+        :attr:`~contact_target_interfaces`
+
+        :type: :class:`list` of :class:`float`
+        """
+        if self._per_interface_ics_precision_trimmed is None:
+            self._compute_ics_scores_trimmed()
+        return self._per_interface_ics_precision_trimmed
+
+
+    @property
+    def per_interface_ics_recall_trimmed(self):
+        """ Same as :attr:`per_interface_ics_recall` but with :attr:`trimmed_model`
+
+        :attr:`~ics_recall_trimmed` for each interface in
+        :attr:`~contact_target_interfaces`
+
+        :type: :class:`list` of :class:`float`
+        """
+        if self._per_interface_ics_recall_trimmed is None:
+            self._compute_ics_scores_trimmed()
+        return self._per_interface_ics_recall_trimmed
+
+    @property
+    def per_interface_ics_trimmed(self):
+        """ Same as :attr:`per_interface_ics` but with :attr:`trimmed_model`
+
+        :attr:`~ics` for each interface in 
+        :attr:`~contact_target_interfaces`
+
+        :type: :class:`float`
+        """
+
+        if self._per_interface_ics_trimmed is None:
+            self._compute_ics_scores_trimmed()
+        return self._per_interface_ics_trimmed
+
     @property
     def ips_trimmed(self):
         """ Same as :attribute:`ips` but with trimmed model
@@ -1257,6 +1303,47 @@ class Scorer:
             self._compute_ips_scores()
         return self._per_interface_ips
 
+    @property
+    def per_interface_ips_precision_trimmed(self):
+        """ Same as :attr:`per_interface_ips_precision` but with :attr:`trimmed_model`
+
+        :attr:`~ips_precision_trimmed` for each interface in
+        :attr:`~contact_target_interfaces`
+
+        :type: :class:`list` of :class:`float`
+        """
+        if self._per_interface_ips_precision_trimmed is None:
+            self._compute_ips_scores_trimmed()
+        return self._per_interface_ips_precision_trimmed
+
+
+    @property
+    def per_interface_ips_recall_trimmed(self):
+        """ Same as :attr:`per_interface_ips_recall` but with :attr:`trimmed_model`
+
+        :attr:`~ics_recall_trimmed` for each interface in
+        :attr:`~contact_target_interfaces`
+
+        :type: :class:`list` of :class:`float`
+        """
+        if self._per_interface_ips_recall_trimmed is None:
+            self._compute_ips_scores_trimmed()
+        return self._per_interface_ips_recall_trimmed
+
+    @property
+    def per_interface_ips_trimmed(self):
+        """ Same as :attr:`per_interface_ips` but with :attr:`trimmed_model`
+
+        :attr:`~ics` for each interface in 
+        :attr:`~contact_target_interfaces`
+
+        :type: :class:`float`
+        """
+
+        if self._per_interface_ips_trimmed is None:
+            self._compute_ips_scores_trimmed()
+        return self._per_interface_ips_trimmed
+
     @property
     def dockq_target_interfaces(self):
         """ Interfaces in :attr:`target` that are relevant for DockQ
@@ -2266,6 +2353,26 @@ class Scorer:
         self._ics_precision_trimmed = contact_scorer_res.precision
         self._ics_recall_trimmed = contact_scorer_res.recall
 
+        self._per_interface_ics_precision_trimmed = list()
+        self._per_interface_ics_recall_trimmed = list()
+        self._per_interface_ics_trimmed = list()
+        flat_mapping = self.mapping.GetFlatMapping()
+        for trg_int in self.contact_target_interfaces:
+            trg_ch1 = trg_int[0]
+            trg_ch2 = trg_int[1]
+            if trg_ch1 in flat_mapping and trg_ch2 in flat_mapping:
+                mdl_ch1 = flat_mapping[trg_ch1]
+                mdl_ch2 = flat_mapping[trg_ch2]
+                res = self.trimmed_contact_scorer.ScoreICSInterface(trg_ch1, trg_ch2,
+                                                                    mdl_ch1, mdl_ch2)
+                self._per_interface_ics_precision_trimmed.append(res.precision)
+                self._per_interface_ics_recall_trimmed.append(res.recall)
+                self._per_interface_ics_trimmed.append(res.ics)
+            else:
+                self._per_interface_ics_precision_trimmed.append(None)
+                self._per_interface_ics_recall_trimmed.append(None)
+                self._per_interface_ics_trimmed.append(None)
+
     def _compute_ips_scores(self):
         LogScript("Computing IPS scores")
         contact_scorer_res = self.contact_scorer.ScoreIPS(self.mapping.mapping)
@@ -2304,6 +2411,26 @@ class Scorer:
         self._ips_recall_trimmed = contact_scorer_res.recall
         self._ips_trimmed = contact_scorer_res.ips
 
+        self._per_interface_ips_precision_trimmed = list()
+        self._per_interface_ips_recall_trimmed = list()
+        self._per_interface_ips_trimmed = list()
+        flat_mapping = self.mapping.GetFlatMapping()
+        for trg_int in self.contact_target_interfaces:
+            trg_ch1 = trg_int[0]
+            trg_ch2 = trg_int[1]
+            if trg_ch1 in flat_mapping and trg_ch2 in flat_mapping:
+                mdl_ch1 = flat_mapping[trg_ch1]
+                mdl_ch2 = flat_mapping[trg_ch2]
+                res = self.trimmed_contact_scorer.ScoreIPSInterface(trg_ch1, trg_ch2,
+                                                                    mdl_ch1, mdl_ch2)
+                self._per_interface_ips_precision_trimmed.append(res.precision)
+                self._per_interface_ips_recall_trimmed.append(res.recall)
+                self._per_interface_ips_trimmed.append(res.ips)
+            else:
+                self._per_interface_ips_precision_trimmed.append(None)
+                self._per_interface_ips_recall_trimmed.append(None)
+                self._per_interface_ips_trimmed.append(None)
+
     def _compute_dockq_scores(self):
         LogScript("Computing DockQ")
 
-- 
GitLab