From d6debb7c888fc7ede2d5dcc208664734fb814b9a Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Tue, 24 Oct 2023 17:06:54 +0200
Subject: [PATCH] Contact scores - Enable per-interface IPS scoring

requested by Andriy. Computes all IPS metrics on a per-interface basis
---
 actions/ost-compare-structures         |  14 ++-
 modules/doc/actions.rst                |   9 +-
 modules/mol/alg/pymod/contact_score.py | 123 ++++++++++++++++++++-----
 modules/mol/alg/pymod/scoring.py       |  64 +++++++++++++
 4 files changed, 183 insertions(+), 27 deletions(-)

diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures
index 0692c610e..15527423a 100644
--- a/actions/ost-compare-structures
+++ b/actions/ost-compare-structures
@@ -386,7 +386,14 @@ def _ParseArgs():
               "also interface residues in the model. "
               "The IPS score (Interface Patch Similarity) available as key "
               "\"ips\" is the Jaccard coefficient between interface residues "
-              "in reference and model."))
+              "in reference and model. "
+              "All these measures are also available on a per-interface basis "
+              "for each interface in the reference structure that are defined "
+              "as chain pairs with at least one contact (available as key "
+              " \"contact_reference_interfaces\"). The respective metrics are "
+              "available as keys \"per_interface_ips_precision\", "
+              "\"per_interface_ips_recall\" and \"per_interface_ips\"."))
+
 
     parser.add_argument(
         "--rigid-scores",
@@ -658,12 +665,12 @@ def _Process(model, reference, args):
     if args.ics or args.ips:
         out["reference_contacts"] = scorer.native_contacts
         out["model_contacts"] = scorer.model_contacts
+        out["contact_reference_interfaces"] = scorer.contact_target_interfaces
 
     if args.ics:
         out["ics_precision"] = scorer.ics_precision
         out["ics_recall"] = scorer.ics_recall
         out["ics"] = scorer.ics
-        out["contact_reference_interfaces"] = scorer.contact_target_interfaces
         out["per_interface_ics_precision"] = scorer.per_interface_ics_precision
         out["per_interface_ics_recall"] = scorer.per_interface_ics_recall
         out["per_interface_ics"] = scorer.per_interface_ics
@@ -672,6 +679,9 @@ def _Process(model, reference, args):
         out["ips_precision"] = scorer.ips_precision
         out["ips_recall"] = scorer.ips_recall
         out["ips"] = scorer.ips
+        out["per_interface_ips_precision"] = scorer.per_interface_ips_precision
+        out["per_interface_ips_recall"] = scorer.per_interface_ips_recall
+        out["per_interface_ips"] = scorer.per_interface_ips
 
     if args.dockq:
         out["dockq_reference_interfaces"] = scorer.dockq_target_interfaces
diff --git a/modules/doc/actions.rst b/modules/doc/actions.rst
index eb315c341..bd5785bfd 100644
--- a/modules/doc/actions.rst
+++ b/modules/doc/actions.rst
@@ -293,7 +293,14 @@ Details on the usage (output of ``ost compare-structures --help``):
                           also interface residues in the model. The IPS score
                           (Interface Patch Similarity) available as key "ips" is
                           the Jaccard coefficient between interface residues in
-                          reference and model.
+                          reference and model. All these measures are also
+                          available on a per-interface basis for each interface
+                          in the reference structure that are defined as chain
+                          pairs with at least one contact (available as key
+                          "contact_reference_interfaces"). The respective
+                          metrics are available as keys
+                          "per_interface_ips_precision",
+                          "per_interface_ips_recall" and "per_interface_ips".
     --rigid-scores        Computes rigid superposition based scores. They're
                           based on a Kabsch superposition of all mapped CA
                           positions (C3' for nucleotides). Makes the following
diff --git a/modules/mol/alg/pymod/contact_score.py b/modules/mol/alg/pymod/contact_score.py
index ef77a183d..b5ddeb10b 100644
--- a/modules/mol/alg/pymod/contact_score.py
+++ b/modules/mol/alg/pymod/contact_score.py
@@ -405,21 +405,25 @@ class ContactScorer:
         self._alns = alns
 
         # cache for mapped interface scores
-        # relevant to compute ICS
         # key: tuple of tuple ((qsent1_ch1, qsent1_ch2),
         #                     ((qsent2_ch1, qsent2_ch2))
-        # value: tuple with two numbers required for computation of ICS
-        #        1: n_union
-        #        2: n_intersection
-        self._mapped_cache_ics = dict()
-
-        # cache for mapped scores
-        # relevant to compute IPS
+        # value: tuple with four numbers required for computation of
+        #        per-interface scores.
+        #        The first two are relevant for ICS, the others for per
+        #        interface IPS.
+        #        1: n_union_contacts
+        #        2: n_intersection_contacts
+        #        3: n_union_interface_residues
+        #        4: n_intersection_interface_residues
+        self._mapped_cache_interface = dict()
+
+        # cache for mapped single chain scores
+        # for interface residues of single chains
         # key: tuple: (qsent1_ch, qsent2_ch)
         # value: tuple with two numbers required for computation of IPS
         #        1: n_union
         #        2: n_intersection
-        self._mapped_cache_ips = dict()
+        self._mapped_cache_sc = dict()
 
     @staticmethod
     def FromMappingResult(mapping_result, contact_mode="aa", contact_d = 5.0):
@@ -565,7 +569,7 @@ class ContactScorer:
         else:
             n_mdl = 0
 
-        n_union, n_intersection = self._MappedInterfaceScores(trg_int, mdl_int)
+        n_union, n_intersection, _, _ = self._MappedInterfaceScores(trg_int, mdl_int)
         return ContactScorerResultICS(n_trg, n_mdl, n_union, n_intersection)
 
     def ICSFromFlatMapping(self, flat_mapping):
@@ -585,7 +589,7 @@ class ContactScorer:
         for int1 in self.cent1.interacting_chains:
             if int1[0] in flat_mapping and int1[1] in flat_mapping:
                 int2 = (flat_mapping[int1[0]], flat_mapping[int1[1]])
-                a, b = self._MappedInterfaceScores(int1, int2)
+                a, b, _, _ = self._MappedInterfaceScores(int1, int2)
                 n_union += a
                 n_intersection += b
                 processed_cent2_interfaces.add((min(int2), max(int2)))
@@ -596,7 +600,7 @@ class ContactScorer:
             if int2 not in processed_cent2_interfaces:
                 if int2[0] in r_flat_mapping and int2[1] in r_flat_mapping:
                     int1 = (r_flat_mapping[int2[0]], r_flat_mapping[int2[1]])
-                    a, b = self._MappedInterfaceScores(int1, int2)
+                    a, b, _, _ = self._MappedInterfaceScores(int1, int2)
                     n_union += a
                     n_intersection += b
 
@@ -640,6 +644,59 @@ class ContactScorer:
 
         return self.IPSFromFlatMapping(flat_mapping)
 
+    def ScoreIPSInterface(self, trg_ch1, trg_ch2, mdl_ch1, mdl_ch2):
+        """ Computes IPS scores only considering one interface
+
+        This only works for interfaces that are computed in :func:`Score`, i.e.
+        interfaces for which the alignments are set up correctly.
+
+        :param trg_ch1: Name of first interface chain in target
+        :type trg_ch1: :class:`str`
+        :param trg_ch2: Name of second interface chain in target
+        :type trg_ch2: :class:`str`
+        :param mdl_ch1: Name of first interface chain in model
+        :type mdl_ch1: :class:`str`
+        :param mdl_ch2: Name of second interface chain in model
+        :type mdl_ch2: :class:`str`
+        :returns: Result object of type :class:`ContactScorerResultIPS`
+        :raises: :class:`RuntimeError` if no aln for trg_ch1/mdl_ch1 or
+                 trg_ch2/mdl_ch2 is available.
+        """
+        if (trg_ch1, mdl_ch1) not in self.alns:
+            raise RuntimeError(f"No aln between trg_ch1 ({trg_ch1}) and "
+                               f"mdl_ch1 ({mdl_ch1}) available. Did you "
+                               f"construct the QSScorer object from a "
+                               f"MappingResult and are trg_ch1 and mdl_ch1 "
+                               f"mapped to each other?")
+        if (trg_ch2, mdl_ch2) not in self.alns:
+            raise RuntimeError(f"No aln between trg_ch1 ({trg_ch1}) and "
+                               f"mdl_ch1 ({mdl_ch1}) available. Did you "
+                               f"construct the QSScorer object from a "
+                               f"MappingResult and are trg_ch1 and mdl_ch1 "
+                               f"mapped to each other?")
+        trg_int = (trg_ch1, trg_ch2)
+        mdl_int = (mdl_ch1, mdl_ch2)
+        trg_int_r = (trg_ch2, trg_ch1)
+        mdl_int_r = (mdl_ch2, mdl_ch1)
+
+        if trg_int in self.cent1.contacts:
+            n_trg = len(self.cent1.contacts[trg_int])
+        elif trg_int_r in self.cent1.contacts:
+            n_trg = len(self.cent1.contacts[trg_int_r])
+        else:
+            n_trg = 0
+
+        if mdl_int in self.cent2.contacts:
+            n_mdl = len(self.cent2.contacts[mdl_int])
+        elif mdl_int_r in self.cent2.contacts:
+            n_mdl = len(self.cent2.contacts[mdl_int_r])
+        else:
+            n_mdl = 0
+
+        _, _, n_union, n_intersection = self._MappedInterfaceScores(trg_int, mdl_int)
+        return ContactScorerResultIPS(n_trg, n_mdl, n_union, n_intersection)
+
+
     def IPSFromFlatMapping(self, flat_mapping):
         """ Same as :func:`ScoreIPS` but with flat mapping
 
@@ -673,15 +730,15 @@ class ContactScorer:
 
     def _MappedInterfaceScores(self, int1, int2):
         key_one = (int1, int2)
-        if key_one in self._mapped_cache_ics:
-            return self._mapped_cache_ics[key_one]
+        if key_one in self._mapped_cache_interface:
+            return self._mapped_cache_interface[key_one]
         key_two = ((int1[1], int1[0]), (int2[1], int2[0]))
-        if key_two in self._mapped_cache_ics:
-            return self._mapped_cache_ics[key_two]
+        if key_two in self._mapped_cache_interface:
+            return self._mapped_cache_interface[key_two]
 
-        n_union, n_intersection = self._InterfaceScores(int1, int2)
-        self._mapped_cache_ics[key_one] = (n_union, n_intersection)
-        return (n_union, n_intersection)
+        a, b, c, d = self._InterfaceScores(int1, int2)
+        self._mapped_cache_interface[key_one] = (a, b, c, d)
+        return (a, b, c, d)
 
     def _InterfaceScores(self, int1, int2):
         if int1 in self.cent1.contacts:
@@ -715,14 +772,32 @@ class ContactScorer:
             mapped_c = (ch1_aln.GetPos(1, c[0]), ch2_aln.GetPos(1, c[1]))
             mapped_mdl_contacts.add(mapped_c)
 
-        return (len(mapped_ref_contacts.union(mapped_mdl_contacts)),
-                len(mapped_ref_contacts.intersection(mapped_mdl_contacts)))
+        contact_union = len(mapped_ref_contacts.union(mapped_mdl_contacts))
+        contact_intersection = len(mapped_ref_contacts.intersection(mapped_mdl_contacts))
+
+        # above, we computed the union and intersection on actual
+        # contacts. Here, we do the same on interface residues
+
+        # process interface residues of chain one in interface
+        tmp_ref = set([x[0] for x in mapped_ref_contacts])
+        tmp_mdl = set([x[0] for x in mapped_mdl_contacts])
+        intres_union = len(tmp_ref.union(tmp_mdl))
+        intres_intersection = len(tmp_ref.intersection(tmp_mdl))
+
+        # process interface residues of chain two in interface
+        tmp_ref = set([x[1] for x in mapped_ref_contacts])
+        tmp_mdl = set([x[1] for x in mapped_mdl_contacts])
+        intres_union += len(tmp_ref.union(tmp_mdl))
+        intres_intersection += len(tmp_ref.intersection(tmp_mdl))
+
+        return (contact_union, contact_intersection,
+                intres_union, intres_intersection)
 
     def _MappedSCScores(self, ref_ch, mdl_ch):
-        if (ref_ch, mdl_ch) in self._mapped_cache_ips:
-            return self._mapped_cache_ips[(ref_ch, mdl_ch)]
+        if (ref_ch, mdl_ch) in self._mapped_cache_sc:
+            return self._mapped_cache_sc[(ref_ch, mdl_ch)]
         n_union, n_intersection = self._SCScores(ref_ch, mdl_ch)
-        self._mapped_cache_ips[(ref_ch, mdl_ch)] = (n_union, n_intersection)
+        self._mapped_cache_sc[(ref_ch, mdl_ch)] = (n_union, n_intersection)
         return (n_union, n_intersection)
 
     def _SCScores(self, ch1, ch2):
diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index 50a1e7faa..b570e5080 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -257,6 +257,9 @@ class Scorer:
         self._ips_precision = None
         self._ips_recall = None
         self._ips = None
+        self._per_interface_ics_precision = None
+        self._per_interface_ics_recall = None
+        self._per_interface_ics = None
 
         self._dockq_target_interfaces = None
         self._dockq_interfaces = None
@@ -829,6 +832,47 @@ class Scorer:
             self._compute_ips_scores()
         return self._ips
 
+    @property
+    def per_interface_ips_precision(self):
+        """ Per-interface IPS precision
+
+        :attr:`~ips_precision` for each interface in
+        :attr:`~contact_target_interfaces`
+
+        :type: :class:`list` of :class:`float`
+        """
+        if self._per_interface_ips_precision is None:
+            self._compute_ips_scores()
+        return self._per_interface_ips_precision
+
+
+    @property
+    def per_interface_ips_recall(self):
+        """ Per-interface IPS recall
+
+        :attr:`~ips_recall` for each interface in
+        :attr:`~contact_target_interfaces`
+
+        :type: :class:`list` of :class:`float`
+        """
+        if self._per_interface_ics_recall is None:
+            self._compute_ips_scores()
+        return self._per_interface_ips_recall
+
+    @property
+    def per_interface_ips(self):
+        """ Per-interface IPS (Interface Patch Similarity) score
+
+        :attr:`~ips` for each interface in 
+        :attr:`~contact_target_interfaces`
+
+        :type: :class:`list` of :class:`float`
+        """
+
+        if self._per_interface_ips is None:
+            self._compute_ips_scores()
+        return self._per_interface_ips
+
     @property
     def dockq_target_interfaces(self):
         """ Interfaces in :attr:`target` that are relevant for DockQ
@@ -1413,6 +1457,26 @@ class Scorer:
         self._ips_recall = contact_scorer_res.recall
         self._ips = contact_scorer_res.ips
 
+        self._per_interface_ips_precision = list()
+        self._per_interface_ips_recall = list()
+        self._per_interface_ips = list()
+        flat_mapping = self.mapping.GetFlatMapping()
+        for trg_int in self.contact_target_interfaces:
+            trg_ch1 = trg_int[0]
+            trg_ch2 = trg_int[1]
+            if trg_ch1 in flat_mapping and trg_ch2 in flat_mapping:
+                mdl_ch1 = flat_mapping[trg_ch1]
+                mdl_ch2 = flat_mapping[trg_ch2]
+                res = self.contact_scorer.ScoreIPSInterface(trg_ch1, trg_ch2,
+                                                            mdl_ch1, mdl_ch2)
+                self._per_interface_ips_precision.append(res.precision)
+                self._per_interface_ips_recall.append(res.recall)
+                self._per_interface_ips.append(res.ips)
+            else:
+                self._per_interface_ips_precision.append(None)
+                self._per_interface_ips_recall.append(None)
+                self._per_interface_ips.append(None)
+
     def _compute_dockq_scores(self):
         # lists with values in contact_target_interfaces
         self._dockq_scores = list()
-- 
GitLab