diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures
index 595c18b2e320726f97e1c994f049cb69f7497916..8453cf6871fcd931ee3ed60fc7fff383f3150daa 100644
--- a/actions/ost-compare-structures
+++ b/actions/ost-compare-structures
@@ -342,24 +342,46 @@ def _ParseArgs():
               "present in the reference but not in the model."))
 
     parser.add_argument(
-        "--contact-scores",
-        dest="contact_scores",
+        "--ics",
+        dest="ics",
         default=False,
         action="store_true",
-        help=("Computes interface contact based scores. A contact between two "
-              "residues of different chains is defined as having at least one "
-              "heavy atom within 5A. Contacts in reference structure are "
-              "available as key \"reference_contacts\". Each contact specifies "
-              "the interacting residues in format "
-              "\"<cname>.<rnum>.<ins_code>\". Model contacts are available as "
-              "key \"model_contacts\". The precision which is available as key "
-              "\"contact_precision\" reports the fraction of model contacts "
-              "that are also present in the reference. The recall which is "
-              "available as key \"contact_recall\" reports the fraction of "
-              "reference contacts that are correctly reproduced in the model. "
+        help=("Computes interface contact similarity (ICS) related scores. "
+              "A contact between two residues of different chains is defined "
+              "as having at least one heavy atom within 5A. Contacts in "
+              "reference structure are available as key "
+              "\"reference_contacts\". Each contact specifies the interacting "
+              "residues in format \"<cname>.<rnum>.<ins_code>\". Model "
+              "contacts are available as key \"model_contacts\". The precision "
+              "which is available as key \"ics_precision\" reports the "
+              "fraction of model contacts that are also present in the "
+              "reference. The recall which is available as key \"ics_recall\" "
+              "reports the fraction of reference contacts that are correctly "
+              "reproduced in the model. "
               "The ICS score (Interface Contact Similarity) available as key "
               "\"ics\" combines precision and recall using the F1-measure."))
 
+    parser.add_argument(
+        "--ips",
+        dest="ips",
+        default=False,
+        action="store_true",
+        help=("Computes interface patch similarity (IPS) related scores. "
+              "They focus on interface residues. They are defined as having "
+              "at least one contact to a residue from any other chain. "
+              "In short: if they show up in the contact lists used to compute "
+              "ICS. If ips is enabled, these contacts get reported too and are "
+              "available as keys \"reference_contacts\" and \"model_contacts\". "
+              "The precision which is available as key \"ips_precision\" "
+              "reports the fraction of model interface residues, that are also "
+              "interface residues in the reference. "
+              "The recall which is available as key \"ips_recall\" "
+              "reports the fraction of reference interface residues that are "
+              "also interface residues in the model. "
+              "The IPS score (Interface Patch Similarity) available as key "
+              "\"ips\" is the Jaccard coefficient between interface residues "
+              "in reference and model."))
+
     parser.add_argument(
         "--rigid-scores",
         dest="rigid_scores",
@@ -627,13 +649,20 @@ def _Process(model, reference, args):
         out["per_interface_qs_global"] = scorer.per_interface_qs_global
         out["per_interface_qs_best"] = scorer.per_interface_qs_best
 
-    if args.contact_scores:
+    if args.ics or args.ips:
         out["reference_contacts"] = scorer.native_contacts
         out["model_contacts"] = scorer.model_contacts
-        out["contact_precision"] = scorer.contact_precision
-        out["contact_recall"] = scorer.contact_recall
+
+    if args.ics:
+        out["ics_precision"] = scorer.ics_precision
+        out["ics_recall"] = scorer.ics_recall
         out["ics"] = scorer.ics
 
+    if args.ips:
+        out["ips_precision"] = scorer.ips_precision
+        out["ips_recall"] = scorer.ips_recall
+        out["ips"] = scorer.ips
+
     if args.dockq:
         out["dockq_reference_interfaces"] = scorer.dockq_target_interfaces
         out["dockq_interfaces"] = scorer.dockq_interfaces
@@ -646,10 +675,6 @@ def _Process(model, reference, args):
         out["dockq_ave_full"] = scorer.dockq_ave_full
         out["dockq_wave_full"] = scorer.dockq_wave_full
 
-    if args.contact_scores:
-        out["reference_contacts"] = scorer.native_contacts
-        out["model_contacts"] = scorer.model_contacts
-
     if args.rigid_scores:
         out["oligo_gdtts"] = scorer.gdtts
         out["oligo_gdtha"] = scorer.gdtha
diff --git a/modules/doc/actions.rst b/modules/doc/actions.rst
index 01b02f0c8aa6b9eb4c3c80b5d919324b8947fcd9..3faf8c62faea9ab5999cf6283739c6cc87b8cea4 100644
--- a/modules/doc/actions.rst
+++ b/modules/doc/actions.rst
@@ -255,22 +255,38 @@ Details on the usage (output of ``ost compare-structures --help``):
                           and "dockq_wave_full" add zeros in the average
                           computation for each interface that is only present in
                           the reference but not in the model.
-    --contact-scores      Computes interface contact based scores. A contact
-                          between two residues of different chains is defined as
-                          having at least one heavy atom within 5A. Contacts in
-                          reference structure are available as key
-                          "reference_contacts". Each contact specifies the
-                          interacting residues in format
+    --ics                 Computes interface contact similarity (ICS) related
+                          scores. A contact between two residues of different
+                          chains is defined as having at least one heavy atom
+                          within 5A. Contacts in reference structure are
+                          available as key "reference_contacts". Each contact
+                          specifies the interacting residues in format
                           "<cname>.<rnum>.<ins_code>". Model contacts are
                           available as key "model_contacts". The precision which
-                          is available as key "contact_precision" reports the
+                          is available as key "ics_precision" reports the
                           fraction of model contacts that are also present in
                           the reference. The recall which is available as key
-                          "contact_recall" reports the fraction of reference
+                          "ics_recall" reports the fraction of reference
                           contacts that are correctly reproduced in the model.
                           The ICS score (Interface Contact Similarity) available
                           as key "ics" combines precision and recall using the
                           F1-measure.
+    --ips                 Computes interface patch similarity (IPS) related
+                          scores. They focus on interface residues. They are
+                          defined as having at least one contact to a residue
+                          from any other chain. In short: if they show up in the
+                          contact lists used to compute ICS. If ips is enabled,
+                          these contacts get reported too and are available as
+                          keys "reference_contacts" and "model_contacts". The
+                          precision which is available as key "ips_precision"
+                          reports the fraction of model interface residues, that
+                          are also interface residues in the reference. The
+                          recall which is available as key "ips_recall" reports
+                          the fraction of reference interface residues that are
+                          also interface residues in the model. The IPS score
+                          (Interface Patch Similarity) available as key "ips" is
+                          the Jaccard coefficient between interface residues in
+                          reference and model.
     --rigid-scores        Computes rigid superposition based scores. They're
                           based on a Kabsch superposition of all mapped CA
                           positions (C3' for nucleotides). Makes the following
diff --git a/modules/mol/alg/pymod/contact_score.py b/modules/mol/alg/pymod/contact_score.py
index 500e9043f48759cf63ad5136295990884dab991d..256ff4d8faf390442843cc3a722cf605f238e907 100644
--- a/modules/mol/alg/pymod/contact_score.py
+++ b/modules/mol/alg/pymod/contact_score.py
@@ -55,6 +55,8 @@ class ContactEntity:
         self._sequence = dict()
         self._contacts = None
         self._hr_contacts = None
+        self._interface_residues = None
+        self._hr_interface_residues = None
 
     @property
     def view(self):
@@ -135,6 +137,30 @@ class ContactEntity:
         if self._hr_contacts is None:
             self._SetupContacts()
         return self._hr_contacts
+
+    @property
+    def interface_residues(self):
+        """ Interface residues
+
+        Residues in each chain that are in contact with any other chain.
+        :class:`dict` with one key per chain name (also for chains without
+        any contact) and the respective residue indices as :class:`set`.
+        """
+        if self._interface_residues is None:
+            self._SetupInterfaceResidues()
+        return self._interface_residues
+
+    @property
+    def hr_interface_residues(self):
+        """ Human readable interface residues
+
+        Human readable version of :attr:`interface_residues`. :class:`list` of
+        strings specifying the interface residues in format:
+        <cname>.<rnum>.<ins_code>
+        """
+        if self._hr_interface_residues is None:
+            self._SetupHRInterfaceResidues()
+        return self._hr_interface_residues
 
     def GetChain(self, chain_name):
         """ Get chain by name
@@ -199,9 +225,24 @@ class ContactEntity:
                                 self._hr_contacts.append((hr1.strip("\u0000"),
                                                           hr2.strip("\u0000")))
 
-class ContactScorerResult:
+    def _SetupInterfaceResidues(self):
+        self._interface_residues = {cn: set() for cn in self.chain_names}
+        for k,v in self.contacts.items():
+            for item in v:
+                self._interface_residues[k[0]].add(item[0])
+                self._interface_residues[k[1]].add(item[1])
+
+    def _SetupHRInterfaceResidues(self):
+        interface_residues = set()
+        for item in self.hr_contacts:
+            interface_residues.add(item[0])
+            interface_residues.add(item[1])
+        self._hr_interface_residues = list(interface_residues)
+
+
+class ContactScorerResultICS:
     """
-    Holds data relevant to compute contact scores
+    Holds data relevant to compute ics
     """
     def __init__(self, n_trg_contacts, n_mdl_contacts, n_union, n_intersection):
         self._n_trg_contacts = n_trg_contacts
@@ -257,6 +298,71 @@ class ContactScorerResult:
         r = self.recall
         return 2*p*r/(p+r)
 
+class ContactScorerResultIPS:
+    """
+    Holds data relevant to compute ips
+    """
+    def __init__(self, n_trg_int_res, n_mdl_int_res, n_union, n_intersection):
+        self._n_trg_int_res = n_trg_int_res
+        self._n_mdl_int_res = n_mdl_int_res
+        self._n_union = n_union
+        self._n_intersection = n_intersection
+
+    @property
+    def n_trg_int_res(self):
+        """ Number of interface residues in target
+
+        :type: :class:`int`
+        """
+        return self._n_trg_int_res
+
+    @property
+    def n_mdl_int_res(self):
+        """ Number of interface residues in model
+
+        :type: :class:`int`
+        """
+        return self._n_mdl_int_res
+
+    @property
+    def precision(self):
+        """ Precision of model interface residues
+
+        The fraction of model interface residues that are also interface
+        residues in target, 0.0 if model has no interface residues
+
+        :type: :class:`float`
+        """
+        if self._n_mdl_int_res > 0:
+            return self._n_intersection / self._n_mdl_int_res
+        return 0.0
+
+    @property
+    def recall(self):
+        """ Recall of model interface residues
+
+        The fraction of target interface residues that are also interface
+        residues in model, 0.0 if target has no interface residues
+
+        :type: :class:`float`
+        """
+        if self._n_trg_int_res > 0:
+            return self._n_intersection / self._n_trg_int_res
+        return 0.0
+
+    @property
+    def ips(self):
+        """ The Interface Patch Similarity score (IPS)
+
+        Jaccard coefficient of interface residues in model/target.
+        Technically thats :attr:`intersection`/:attr:`union`
+
+        :type: :class:`float`
+        """
+        if(self._n_union > 0):
+            return self._n_intersection/self._n_union
+        return 0.0
+
 class ContactScorer:
     """
     Helper object to compute Contact scores
@@ -286,13 +392,21 @@ class ContactScorer:
         self._alns = alns
 
         # cache for mapped interface scores
+        # relevant to compute ICS
         # key: tuple of tuple ((qsent1_ch1, qsent1_ch2),
         #                     ((qsent2_ch1, qsent2_ch2))
-        # value: tuple with two numbers required for computation of common
-        #        contact based scores
+        # value: tuple with two numbers required for computation of ICS
+        #        1: n_union
+        #        2: n_intersection
+        self._mapped_cache_ics = dict()
+
+        # cache for mapped scores
+        # relevant to compute IPS
+        # key: tuple: (qsent1_ch, qsent2_ch)
+        # value: tuple with two numbers required for computation of IPS
         #        1: n_union
         #        2: n_intersection
-        self._mapped_cache = dict()
+        self._mapped_cache_ips = dict()
 
     @staticmethod
     def FromMappingResult(mapping_result, contact_mode="aa", contact_d = 5.0):
@@ -352,8 +466,8 @@ class ContactScorer:
         """
         return self._alns
 
-    def Score(self, mapping, check=True):
-        """ Computes contact scores given chain mapping
+    def ScoreICS(self, mapping, check=True):
+        """ Computes ICS given chain mapping
 
         Again, the preferred way is to get *mapping* is from an object
         of type :class:`ost.mol.alg.chain_mapping.MappingResult`.
@@ -364,7 +478,7 @@ class ContactScorer:
         :param check: Perform input checks, can be disabled for speed purposes
                       if you know what you're doing.
         :type check: :class:`bool`
-        :returns: Result object of type :class:`ContactScorerResult`
+        :returns: Result object of type :class:`ContactScorerResultICS`
         """
 
         if check:
@@ -387,10 +501,10 @@ class ContactScorer:
         for a, b in zip(self.chem_groups, mapping):
             flat_mapping.update({x: y for x, y in zip(a, b) if y is not None})
 
-        return self.FromFlatMapping(flat_mapping)
+        return self.ICSFromFlatMapping(flat_mapping)
 
     def ScoreInterface(self, trg_ch1, trg_ch2, mdl_ch1, mdl_ch2):
-        """ Computes contact scores only considering one interface
+        """ Computes ICS scores only considering one interface
 
-        This only works for interfaces that are computed in :func:`Score`, i.e.
-        interfaces for which the alignments are set up correctly.
+        This only works for interfaces that are computed in :func:`ScoreICS`,
+        i.e. interfaces for which the alignments are set up correctly.
@@ -403,7 +517,7 @@ class ContactScorer:
         :type mdl_ch1: :class:`str`
         :param mdl_ch2: Name of second interface chain in model
         :type mdl_ch2: :class:`str`
-        :returns: Result object of type :class:`ContactScorerResult`
+        :returns: Result object of type :class:`ContactScorerResultICS`
         :raises: :class:`RuntimeError` if no aln for trg_ch1/mdl_ch1 or
                  trg_ch2/mdl_ch2 is available.
         """
@@ -439,15 +553,15 @@ class ContactScorer:
             n_mdl = 0
 
         n_union, n_intersection = self._MappedInterfaceScores(trg_int, mdl_int)
-        return ContactScorerResult(n_trg, n_mdl, n_union, n_intersection)
+        return ContactScorerResultICS(n_trg, n_mdl, n_union, n_intersection)
 
-    def FromFlatMapping(self, flat_mapping):
-        """ Same as :func:`Score` but with flat mapping
+    def ICSFromFlatMapping(self, flat_mapping):
+        """ Same as :func:`ScoreICS` but with flat mapping
 
         :param flat_mapping: Dictionary with target chain names as keys and
                              the mapped model chain names as value
         :type flat_mapping: :class:`dict` with :class:`str` as key and value
-        :returns: Result object of type :class:`ContactScorerResult`
+        :returns: Result object of type :class:`ContactScorerResultICS`
         """
         n_trg = sum([len(x) for x in self.cent1.contacts.values()])
         n_mdl = sum([len(x) for x in self.cent2.contacts.values()])
@@ -473,19 +587,87 @@ class ContactScorer:
                 n_union += a
                 n_intersection += b
 
-        return ContactScorerResult(n_trg, n_mdl,
-                                   n_union, n_intersection)
+        return ContactScorerResultICS(n_trg, n_mdl,
+                                      n_union, n_intersection)
+
+    def ScoreIPS(self, mapping, check=True):
+        """ Computes IPS given chain mapping
+
+        Again, the preferred way is to get *mapping* is from an object
+        of type :class:`ost.mol.alg.chain_mapping.MappingResult`.
+
+        :param mapping: see
+                        :attr:`ost.mol.alg.chain_mapping.MappingResult.mapping`
+        :type mapping: :class:`list` of :class:`list` of :class:`str`
+        :param check: Perform input checks, can be disabled for speed purposes
+                      if you know what you're doing.
+        :type check: :class:`bool`
+        :returns: Result object of type :class:`ContactScorerResultIPS`
+        """
+
+        if check:
+            # ensure that dimensionality of mapping matches self.chem_groups
+            if len(self.chem_groups) != len(mapping):
+                raise RuntimeError("Dimensions of self.chem_groups and mapping "
+                                   "must match")
+            for a,b in zip(self.chem_groups, mapping):
+                if len(a) != len(b):
+                    raise RuntimeError("Dimensions of self.chem_groups and "
+                                       "mapping must match")
+            # ensure that chain names in mapping are all present in cent2
+            for name in itertools.chain.from_iterable(mapping):
+                if name is not None and name not in self.cent2.chain_names:
+                    raise RuntimeError(f"Each chain in mapping must be present "
+                                       f"in self.cent2. No match for "
+                                       f"\"{name}\"")
+
+        flat_mapping = dict()
+        for a, b in zip(self.chem_groups, mapping):
+            flat_mapping.update({x: y for x, y in zip(a, b) if y is not None})
+
+        return self.IPSFromFlatMapping(flat_mapping)
+
+    def IPSFromFlatMapping(self, flat_mapping):
+        """ Same as :func:`ScoreIPS` but with flat mapping
+
+        :param flat_mapping: Dictionary with target chain names as keys and
+                             the mapped model chain names as value
+        :type flat_mapping: :class:`dict` with :class:`str` as key and value
+        :returns: Result object of type :class:`ContactScorerResultIPS`
+        """
+        n_trg = sum([len(x) for x in self.cent1.interface_residues.values()])
+        n_mdl = sum([len(x) for x in self.cent2.interface_residues.values()])
+        n_union = 0
+        n_intersection = 0
+
+        processed_cent2_chains = set()
+        for trg_ch in self.cent1.chain_names:
+            if trg_ch in flat_mapping:
+                a, b = self._MappedSCScores(trg_ch, flat_mapping[trg_ch])
+                n_union += a
+                n_intersection += b
+                processed_cent2_chains.add(flat_mapping[trg_ch])
+            else:
+                n_union += len(self.cent1.interface_residues[trg_ch])
+
+        for mdl_ch in self.cent2.chain_names:
+            if mdl_ch not in processed_cent2_chains:
+                n_union += len(self.cent2.interface_residues[mdl_ch])
+
+        return ContactScorerResultIPS(n_trg, n_mdl,
+                                      n_union, n_intersection)
+
 
     def _MappedInterfaceScores(self, int1, int2):
         key_one = (int1, int2)
-        if key_one in self._mapped_cache:
-            return self._mapped_cache[key_one]
+        if key_one in self._mapped_cache_ics:
+            return self._mapped_cache_ics[key_one]
         key_two = ((int1[1], int1[0]), (int2[1], int2[0]))
-        if key_two in self._mapped_cache:
-            return self._mapped_cache[key_two]
+        if key_two in self._mapped_cache_ics:
+            return self._mapped_cache_ics[key_two]
 
         n_union, n_intersection = self._InterfaceScores(int1, int2)
-        self._mapped_cache[key_one] = (n_union, n_intersection)
+        self._mapped_cache_ics[key_one] = (n_union, n_intersection)
         return (n_union, n_intersection)
 
     def _InterfaceScores(self, int1, int2):
@@ -523,5 +705,25 @@ class ContactScorer:
         return (len(mapped_ref_contacts.union(mapped_mdl_contacts)),
                 len(mapped_ref_contacts.intersection(mapped_mdl_contacts)))
 
+    def _MappedSCScores(self, ref_ch, mdl_ch):
+        if (ref_ch, mdl_ch) in self._mapped_cache_ips:
+            return self._mapped_cache_ips[(ref_ch, mdl_ch)]
+        n_union, n_intersection = self._SCScores(ref_ch, mdl_ch)
+        self._mapped_cache_ips[(ref_ch, mdl_ch)] = (n_union, n_intersection)
+        return (n_union, n_intersection)
+
+    def _SCScores(self, ch1, ch2):
+        ref_int_res = self.cent1.interface_residues[ch1]
+        mdl_int_res = self.cent2.interface_residues[ch2]
+        aln = self.alns[(ch1, ch2)]
+        mapped_ref_int_res = set()
+        mapped_mdl_int_res = set()
+        for r_idx in ref_int_res:
+            mapped_ref_int_res.add(aln.GetPos(0, r_idx))
+        for r_idx in mdl_int_res:
+            mapped_mdl_int_res.add(aln.GetPos(1, r_idx))
+        return(len(mapped_ref_int_res.union(mapped_mdl_int_res)),
+               len(mapped_ref_int_res.intersection(mapped_mdl_int_res)))
+
 # specify public interface
-__all__ = ('ContactEntity', 'ContactScorerResult', 'ContactScorer')
+__all__ = ('ContactEntity', 'ContactScorerResultICS', 'ContactScorerResultIPS', 'ContactScorer')
diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index 2d58dfff665e8929c844c206e4c2baac26f09fa5..e73f9d7e9b91e02aa36be054411ef812f457c9e4 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -248,9 +248,12 @@ class Scorer:
         self._contact_model_interfaces = None
         self._native_contacts = None
         self._model_contacts = None
-        self._contact_precision = None
-        self._contact_recall = None
+        self._ics_precision = None
+        self._ics_recall = None
         self._ics = None
+        self._ips_precision = None
+        self._ips_recall = None
+        self._ips = None
 
         self._dockq_target_interfaces = None
         self._dockq_interfaces = None
@@ -714,38 +717,73 @@ class Scorer:
         return self._contact_model_interfaces
 
     @property
-    def contact_precision(self):
+    def ics_precision(self):
         """ Fraction of model contacts that are also present in target
 
         :type: :class:`float`
         """
-        if self._contact_precision is None:
-            self._compute_contact_scores()
-        return self._contact_precision
+        if self._ics_precision is None:
+            self._compute_ics_scores()
+        return self._ics_precision
 
     @property
-    def contact_recall(self):
+    def ics_recall(self):
         """ Fraction of target contacts that are correctly reproduced in model
 
         :type: :class:`float`
         """
-        if self._contact_recall is None:
-            self._compute_contact_scores()
-        return self._contact_recall
+        if self._ics_recall is None:
+            self._compute_ics_scores()
+        return self._ics_recall
 
     @property
     def ics(self):
         """ ICS (Interface Contact Similarity) score
 
-        Combination of :attr:`~contact_precision` and :attr:`~contact_recall`
+        Combination of :attr:`~ics_precision` and :attr:`~ics_recall`
         using the F1-measure
 
         :type: :class:`float`
         """
         if self._ics is None:
-            self._compute_contact_scores()
+            self._compute_ics_scores()
         return self._ics
 
+    @property
+    def ips_precision(self):
+        """ Fraction of model interface residues that are also interface
+        residues in target
+
+        :type: :class:`float`
+        """
+        if self._ips_precision is None:
+            self._compute_ips_scores()
+        return self._ips_precision
+
+    @property
+    def ips_recall(self):
+        """ Fraction of target interface residues that are also interface
+        residues in model
+
+        :type: :class:`float`
+        """
+        if self._ips_recall is None:
+            self._compute_ips_scores()
+        return self._ips_recall
+
+    @property
+    def ips(self):
+        """ IPS (Interface Patch Similarity) score
+
+        Jaccard coefficient of interface residues in target and their mapped
+        counterparts in model
+
+        :type: :class:`float`
+        """
+        if self._ips is None:
+            self._compute_ips_scores()
+        return self._ips
+
     @property
     def dockq_target_interfaces(self):
         """ Interfaces in :attr:`target` that are relevant for DockQ
@@ -1299,14 +1337,19 @@ class Scorer:
             self._per_interface_qs_best.append(qs_res.QS_best)
             self._per_interface_qs_global.append(qs_res.QS_global)
 
-    def _compute_contact_scores(self):
-        contact_scorer_res = self.contact_scorer.Score(self.mapping.mapping)
-        self._contact_precision = contact_scorer_res.precision
-        self._contact_recall = contact_scorer_res.recall
+    def _compute_ics_scores(self):
+        contact_scorer_res = self.contact_scorer.ScoreICS(self.mapping.mapping)
+        self._ics_precision = contact_scorer_res.precision
+        self._ics_recall = contact_scorer_res.recall
         self._ics = contact_scorer_res.ics
 
-    def _compute_dockq_scores(self):
+    def _compute_ips_scores(self):
+        contact_scorer_res = self.contact_scorer.ScoreIPS(self.mapping.mapping)
+        self._ips_precision = contact_scorer_res.precision
+        self._ips_recall = contact_scorer_res.recall
+        self._ips = contact_scorer_res.ips
 
+    def _compute_dockq_scores(self):
         # lists with values in contact_target_interfaces
         self._dockq_scores = list()
         self._fnat = list()
diff --git a/modules/mol/alg/tests/test_contact_score.py b/modules/mol/alg/tests/test_contact_score.py
index 940f14eaf50898b3bd06835c932178e438a6b4ad..1a3a53e529d4ec4d090f82d74417cbde426ddae4 100644
--- a/modules/mol/alg/tests/test_contact_score.py
+++ b/modules/mol/alg/tests/test_contact_score.py
@@ -56,11 +56,18 @@ class TestContactScore(unittest.TestCase):
         mapper = ChainMapper(target)
         res = mapper.GetRigidMapping(model, strategy="greedy_iterative_rmsd")
         contact_scorer = ContactScorer.FromMappingResult(res)
-        score_result = contact_scorer.Score(res.mapping)
+        score_result = contact_scorer.ScoreICS(res.mapping)
         self.assertAlmostEqual(score_result.precision, 0.583, places=2)
         self.assertAlmostEqual(score_result.recall, 0.288, places=2)
         self.assertAlmostEqual(score_result.ics, 0.386, places=2)
 
+        score_result = contact_scorer.ScoreIPS(res.mapping)
+        self.assertAlmostEqual(score_result.precision, 0.779, places=2)
+        self.assertAlmostEqual(score_result.recall, 0.493, places=2)
+        self.assertAlmostEqual(score_result.ips, 0.432, places=2)
+        self.assertGreater(score_result.n_trg_int_res, 0)
+        self.assertGreater(score_result.n_mdl_int_res, 0)
+
 if __name__ == "__main__":
     from ost import testutils
     if testutils.DefaultCompoundLibIsSet():