From d6debb7c888fc7ede2d5dcc208664734fb814b9a Mon Sep 17 00:00:00 2001 From: Gabriel Studer <gabriel.studer@unibas.ch> Date: Tue, 24 Oct 2023 17:06:54 +0200 Subject: [PATCH] Contact scores - Enable per-interface IPS scoring requested by Andriy. Computes all IPS metrics on a per-interface basis --- actions/ost-compare-structures | 14 ++- modules/doc/actions.rst | 9 +- modules/mol/alg/pymod/contact_score.py | 123 ++++++++++++++++++++----- modules/mol/alg/pymod/scoring.py | 64 +++++++++++++ 4 files changed, 183 insertions(+), 27 deletions(-) diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures index 0692c610e..15527423a 100644 --- a/actions/ost-compare-structures +++ b/actions/ost-compare-structures @@ -386,7 +386,14 @@ def _ParseArgs(): "also interface residues in the model. " "The IPS score (Interface Patch Similarity) available as key " "\"ips\" is the Jaccard coefficient between interface residues " - "in reference and model.")) + "in reference and model. " + "All these measures are also available on a per-interface basis " + "for each interface in the reference structure that are defined " + "as chain pairs with at least one contact (available as key " + " \"contact_reference_interfaces\"). The respective metrics are " + "available as keys \"per_interface_ips_precision\", " + "\"per_interface_ips_recall\" and \"per_interface_ips\".")) + parser.add_argument( "--rigid-scores", @@ -658,12 +665,12 @@ def _Process(model, reference, args): if args.ics or args.ips: out["reference_contacts"] = scorer.native_contacts out["model_contacts"] = scorer.model_contacts + out["contact_reference_interfaces"] = scorer.contact_target_interfaces if args.ics: out["ics_precision"] = scorer.ics_precision out["ics_recall"] = scorer.ics_recall out["ics"] = scorer.ics - out["contact_reference_interfaces"] = scorer.contact_target_interfaces out["per_interface_ics_precision"] = scorer.per_interface_ics_precision out["per_interface_ics_recall"] = scorer.per_interface_ics_recall out["per_interface_ics"] = scorer.per_interface_ics @@ -672,6 +679,9 @@ def _Process(model, reference, args): out["ips_precision"] = scorer.ips_precision out["ips_recall"] = scorer.ips_recall out["ips"] = scorer.ips + out["per_interface_ips_precision"] = scorer.per_interface_ips_precision + out["per_interface_ips_recall"] = scorer.per_interface_ips_recall + out["per_interface_ips"] = scorer.per_interface_ips if args.dockq: out["dockq_reference_interfaces"] = scorer.dockq_target_interfaces diff --git a/modules/doc/actions.rst b/modules/doc/actions.rst index eb315c341..bd5785bfd 100644 --- a/modules/doc/actions.rst +++ b/modules/doc/actions.rst @@ -293,7 +293,14 @@ Details on the usage (output of ``ost compare-structures --help``): also interface residues in the model. The IPS score (Interface Patch Similarity) available as key "ips" is the Jaccard coefficient between interface residues in - reference and model. + reference and model. All these measures are also + available on a per-interface basis for each interface + in the reference structure that are defined as chain + pairs with at least one contact (available as key + "contact_reference_interfaces"). The respective + metrics are available as keys + "per_interface_ips_precision", + "per_interface_ips_recall" and "per_interface_ips". --rigid-scores Computes rigid superposition based scores. They're based on a Kabsch superposition of all mapped CA positions (C3' for nucleotides). Makes the following diff --git a/modules/mol/alg/pymod/contact_score.py b/modules/mol/alg/pymod/contact_score.py index ef77a183d..b5ddeb10b 100644 --- a/modules/mol/alg/pymod/contact_score.py +++ b/modules/mol/alg/pymod/contact_score.py @@ -405,21 +405,25 @@ class ContactScorer: self._alns = alns # cache for mapped interface scores - # relevant to compute ICS # key: tuple of tuple ((qsent1_ch1, qsent1_ch2), # ((qsent2_ch1, qsent2_ch2)) - # value: tuple with two numbers required for computation of ICS - # 1: n_union - # 2: n_intersection - self._mapped_cache_ics = dict() - - # cache for mapped scores - # relevant to compute IPS + # value: tuple with four numbers required for computation of + # per-interface scores. + # The first two are relevant for ICS, the others for per + # interface IPS. + # 1: n_union_contacts + # 2: n_intersection_contacts + # 3: n_union_interface_residues + # 4: n_intersection_interface_residues + self._mapped_cache_interface = dict() + + # cache for mapped single chain scores + # for interface residues of single chains # key: tuple: (qsent1_ch, qsent2_ch) # value: tuple with two numbers required for computation of IPS # 1: n_union # 2: n_intersection - self._mapped_cache_ips = dict() + self._mapped_cache_sc = dict() @staticmethod def FromMappingResult(mapping_result, contact_mode="aa", contact_d = 5.0): @@ -565,7 +569,7 @@ class ContactScorer: else: n_mdl = 0 - n_union, n_intersection = self._MappedInterfaceScores(trg_int, mdl_int) + n_union, n_intersection, _, _ = self._MappedInterfaceScores(trg_int, mdl_int) return ContactScorerResultICS(n_trg, n_mdl, n_union, n_intersection) def ICSFromFlatMapping(self, flat_mapping): @@ -585,7 +589,7 @@ class ContactScorer: for int1 in self.cent1.interacting_chains: if int1[0] in flat_mapping and int1[1] in flat_mapping: int2 = (flat_mapping[int1[0]], flat_mapping[int1[1]]) - a, b = self._MappedInterfaceScores(int1, int2) + a, b, _, _ = self._MappedInterfaceScores(int1, int2) n_union += a n_intersection += b processed_cent2_interfaces.add((min(int2), max(int2))) @@ -596,7 +600,7 @@ class ContactScorer: if int2 not in processed_cent2_interfaces: if int2[0] in r_flat_mapping and int2[1] in r_flat_mapping: int1 = (r_flat_mapping[int2[0]], r_flat_mapping[int2[1]]) - a, b = self._MappedInterfaceScores(int1, int2) + a, b, _, _ = self._MappedInterfaceScores(int1, int2) n_union += a n_intersection += b @@ -640,6 +644,59 @@ class ContactScorer: return self.IPSFromFlatMapping(flat_mapping) + def ScoreIPSInterface(self, trg_ch1, trg_ch2, mdl_ch1, mdl_ch2): + """ Computes IPS scores only considering one interface + + This only works for interfaces that are computed in :func:`Score`, i.e. + interfaces for which the alignments are set up correctly. + + :param trg_ch1: Name of first interface chain in target + :type trg_ch1: :class:`str` + :param trg_ch2: Name of second interface chain in target + :type trg_ch2: :class:`str` + :param mdl_ch1: Name of first interface chain in model + :type mdl_ch1: :class:`str` + :param mdl_ch2: Name of second interface chain in model + :type mdl_ch2: :class:`str` + :returns: Result object of type :class:`ContactScorerResultIPS` + :raises: :class:`RuntimeError` if no aln for trg_ch1/mdl_ch1 or + trg_ch2/mdl_ch2 is available. + """ + if (trg_ch1, mdl_ch1) not in self.alns: + raise RuntimeError(f"No aln between trg_ch1 ({trg_ch1}) and " + f"mdl_ch1 ({mdl_ch1}) available. Did you " + f"construct the QSScorer object from a " + f"MappingResult and are trg_ch1 and mdl_ch1 " + f"mapped to each other?") + if (trg_ch2, mdl_ch2) not in self.alns: + raise RuntimeError(f"No aln between trg_ch1 ({trg_ch1}) and " + f"mdl_ch1 ({mdl_ch1}) available. Did you " + f"construct the QSScorer object from a " + f"MappingResult and are trg_ch1 and mdl_ch1 " + f"mapped to each other?") + trg_int = (trg_ch1, trg_ch2) + mdl_int = (mdl_ch1, mdl_ch2) + trg_int_r = (trg_ch2, trg_ch1) + mdl_int_r = (mdl_ch2, mdl_ch1) + + if trg_int in self.cent1.contacts: + n_trg = len(self.cent1.contacts[trg_int]) + elif trg_int_r in self.cent1.contacts: + n_trg = len(self.cent1.contacts[trg_int_r]) + else: + n_trg = 0 + + if mdl_int in self.cent2.contacts: + n_mdl = len(self.cent2.contacts[mdl_int]) + elif mdl_int_r in self.cent2.contacts: + n_mdl = len(self.cent2.contacts[mdl_int_r]) + else: + n_mdl = 0 + + _, _, n_union, n_intersection = self._MappedInterfaceScores(trg_int, mdl_int) + return ContactScorerResultIPS(n_trg, n_mdl, n_union, n_intersection) + + def IPSFromFlatMapping(self, flat_mapping): """ Same as :func:`ScoreIPS` but with flat mapping @@ -673,15 +730,15 @@ class ContactScorer: def _MappedInterfaceScores(self, int1, int2): key_one = (int1, int2) - if key_one in self._mapped_cache_ics: - return self._mapped_cache_ics[key_one] + if key_one in self._mapped_cache_interface: + return self._mapped_cache_interface[key_one] key_two = ((int1[1], int1[0]), (int2[1], int2[0])) - if key_two in self._mapped_cache_ics: - return self._mapped_cache_ics[key_two] + if key_two in self._mapped_cache_interface: + return self._mapped_cache_interface[key_two] - n_union, n_intersection = self._InterfaceScores(int1, int2) - self._mapped_cache_ics[key_one] = (n_union, n_intersection) - return (n_union, n_intersection) + a, b, c, d = self._InterfaceScores(int1, int2) + self._mapped_cache_interface[key_one] = (a, b, c, d) + return (a, b, c, d) def _InterfaceScores(self, int1, int2): if int1 in self.cent1.contacts: @@ -715,14 +772,32 @@ class ContactScorer: mapped_c = (ch1_aln.GetPos(1, c[0]), ch2_aln.GetPos(1, c[1])) mapped_mdl_contacts.add(mapped_c) - return (len(mapped_ref_contacts.union(mapped_mdl_contacts)), - len(mapped_ref_contacts.intersection(mapped_mdl_contacts))) + contact_union = len(mapped_ref_contacts.union(mapped_mdl_contacts)) + contact_intersection = len(mapped_ref_contacts.intersection(mapped_mdl_contacts)) + + # above, we computed the union and intersection on actual + # contacts. Here, we do the same on interface residues + + # process interface residues of chain one in interface + tmp_ref = set([x[0] for x in mapped_ref_contacts]) + tmp_mdl = set([x[0] for x in mapped_mdl_contacts]) + intres_union = len(tmp_ref.union(tmp_mdl)) + intres_intersection = len(tmp_ref.intersection(tmp_mdl)) + + # process interface residues of chain two in interface + tmp_ref = set([x[1] for x in mapped_ref_contacts]) + tmp_mdl = set([x[1] for x in mapped_mdl_contacts]) + intres_union += len(tmp_ref.union(tmp_mdl)) + intres_intersection += len(tmp_ref.intersection(tmp_mdl)) + + return (contact_union, contact_intersection, + intres_union, intres_intersection) def _MappedSCScores(self, ref_ch, mdl_ch): - if (ref_ch, mdl_ch) in self._mapped_cache_ips: - return self._mapped_cache_ips[(ref_ch, mdl_ch)] + if (ref_ch, mdl_ch) in self._mapped_cache_sc: + return self._mapped_cache_sc[(ref_ch, mdl_ch)] n_union, n_intersection = self._SCScores(ref_ch, mdl_ch) - self._mapped_cache_ips[(ref_ch, mdl_ch)] = (n_union, n_intersection) + self._mapped_cache_sc[(ref_ch, mdl_ch)] = (n_union, n_intersection) return (n_union, n_intersection) def _SCScores(self, ch1, ch2): diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py index 50a1e7faa..b570e5080 100644 --- a/modules/mol/alg/pymod/scoring.py +++ b/modules/mol/alg/pymod/scoring.py @@ -257,6 +257,9 @@ class Scorer: self._ips_precision = None self._ips_recall = None self._ips = None + self._per_interface_ics_precision = None + self._per_interface_ics_recall = None + self._per_interface_ics = None self._dockq_target_interfaces = None self._dockq_interfaces = None @@ -829,6 +832,47 @@ class Scorer: self._compute_ips_scores() return self._ips + @property + def per_interface_ips_precision(self): + """ Per-interface IPS precision + + :attr:`~ips_precision` for each interface in + :attr:`~contact_target_interfaces` + + :type: :class:`list` of :class:`float` + """ + if self._per_interface_ips_precision is None: + self._compute_ips_scores() + return self._per_interface_ips_precision + + + @property + def per_interface_ips_recall(self): + """ Per-interface IPS recall + + :attr:`~ips_recall` for each interface in + :attr:`~contact_target_interfaces` + + :type: :class:`list` of :class:`float` + """ + if self._per_interface_ics_recall is None: + self._compute_ips_scores() + return self._per_interface_ips_recall + + @property + def per_interface_ips(self): + """ Per-interface IPS (Interface Patch Similarity) score + + :attr:`~ips` for each interface in + :attr:`~contact_target_interfaces` + + :type: :class:`list` of :class:`float` + """ + + if self._per_interface_ips is None: + self._compute_ips_scores() + return self._per_interface_ips + @property def dockq_target_interfaces(self): """ Interfaces in :attr:`target` that are relevant for DockQ @@ -1413,6 +1457,26 @@ class Scorer: self._ips_recall = contact_scorer_res.recall self._ips = contact_scorer_res.ips + self._per_interface_ips_precision = list() + self._per_interface_ips_recall = list() + self._per_interface_ips = list() + flat_mapping = self.mapping.GetFlatMapping() + for trg_int in self.contact_target_interfaces: + trg_ch1 = trg_int[0] + trg_ch2 = trg_int[1] + if trg_ch1 in flat_mapping and trg_ch2 in flat_mapping: + mdl_ch1 = flat_mapping[trg_ch1] + mdl_ch2 = flat_mapping[trg_ch2] + res = self.contact_scorer.ScoreIPSInterface(trg_ch1, trg_ch2, + mdl_ch1, mdl_ch2) + self._per_interface_ips_precision.append(res.precision) + self._per_interface_ips_recall.append(res.recall) + self._per_interface_ips.append(res.ips) + else: + self._per_interface_ips_precision.append(None) + self._per_interface_ips_recall.append(None) + self._per_interface_ips.append(None) + def _compute_dockq_scores(self): # lists with values in contact_target_interfaces self._dockq_scores = list() -- GitLab