From 96805edcc81c4eb62b6ac3ad484903a0c0d32856 Mon Sep 17 00:00:00 2001 From: Gabriel Studer <gabriel.studer@unibas.ch> Date: Tue, 19 Apr 2022 09:46:34 +0200 Subject: [PATCH] refactor Oligo lDDT scoring - BEHAVIOUR CHANGES before: Scoring was only affected by mapped chains. Penalties for additional chains in reference/model could be added by enabling "penalize_extra_chains" flag. now: Non-mapped chains from the reference penalize the lDDT score. Penalties for additional chains in model can be added by enabling "penalize_extra_chains" flag. --- actions/ost-compare-structures | 220 ++++---- modules/mol/alg/pymod/qsscoring.py | 687 ++++++++---------------- modules/mol/alg/tests/test_qsscoring.py | 7 +- 3 files changed, 322 insertions(+), 592 deletions(-) diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures index 7aaaf6914..8a3f3b0cc 100644 --- a/actions/ost-compare-structures +++ b/actions/ost-compare-structures @@ -67,12 +67,10 @@ import ost from ost.io import (LoadPDB, LoadMMCIF, SavePDB, MMCifInfoBioUnit, MMCifInfo, MMCifInfoTransOp, ReadStereoChemicalPropsFile, profiles) from ost import PushVerbosityLevel -from ost.mol.alg import (qsscoring, Molck, MolckSettings, lDDTSettings, - CheckStructure, ResidueNamesMatch) +from ost.mol.alg import (qsscoring, lddt, Molck, MolckSettings, lDDTSettings, + CheckStructure) from ost.conop import (CompoundLib, SetDefaultLib, GetDefaultLib, RuleBasedProcessor) -from ost.seq.alg.renumber import Renumber - def _GetDefaultShareFilePath(filename): """Look for filename in working directory and OST shared data path. @@ -520,10 +518,19 @@ def _RevertChainNames(ent): def _CheckConsistency(alignments, log_error): is_cons = True for alignment in alignments: - ref_chain = Renumber(alignment.GetSequence(0)).CreateFullView() - mdl_chain = Renumber(alignment.GetSequence(1)).CreateFullView() - new_is_cons = ResidueNamesMatch(mdl_chain, ref_chain, log_error) - is_cons = is_cons and new_is_cons + for col in alignment: + r1 = col.GetResidue(0) + r2 = col.GetResidue(1) + if r1.IsValid() and r2.IsValid(): + if r1.GetName() != r2.GetName(): + is_cons = False + msg = f"Name mismatch for model residue {r2}: in the " + msg += f"reference structure(s) is {r1.GetName()} ({r1}), " + msg += f"in the model {r2.GetName()}" + if log_error: + ost.LogError(msg) + else: + ost.LogWarning(msg) return is_cons @@ -837,25 +844,24 @@ def _Main(): "best_score": 0.0} # Calculate lDDT if opts.lddt: + + lddt_settings = lDDTSettings(radius=opts.inclusion_radius, + sequence_separation=opts.sequence_separation, label="lddt") + ost.LogInfo("-" * 80) ost.LogInfo("Computing lDDT scores") - lddt_results = { - "single_chain_lddt": list() - } - lddt_settings = lDDTSettings( - radius=opts.inclusion_radius, - sequence_separation=opts.sequence_separation, - label="lddt") ost.LogInfo("lDDT settings: ") ost.LogInfo(str(lddt_settings).rstrip()) ost.LogInfo("===") - oligo_lddt_scorer = qs_scorer.GetOligoLDDTScorer(lddt_settings) - for mapped_lddt_scorer in oligo_lddt_scorer.mapped_lddt_scorers: - # Get data - lddt_scorer = mapped_lddt_scorer.lddt_scorer - model_chain = mapped_lddt_scorer.model_chain_name - reference_chain = mapped_lddt_scorer.reference_chain_name - if skip_score: + + lddt_results = { + "single_chain_lddt": list() + } + + if skip_score: + for aln in qs_scorer.alignments: + reference_chain = aln.GetSequence(0).GetName() + model_chain = aln.GetSequence(1).GetName() ost.LogInfo( " --> Skipping single chain lDDT because " "consistency check failed") @@ -867,107 +873,97 @@ def _Main(): "global_score": 0.0, "conserved_contacts": 0.0, "total_contacts": 0.0}) - else: - try: - ost.LogInfo((" --> Computing lDDT between model " - "chain %s and reference chain %s") % ( - model_chain, - reference_chain)) - ost.LogInfo("Global LDDT score: %.4f" % - lddt_scorer.global_score) - ost.LogInfo( - "(%i conserved distances out of %i checked, over " - "%i thresholds)" % (lddt_scorer.conserved_contacts, - lddt_scorer.total_contacts, - len(lddt_settings.cutoffs))) - sc_lddt_scores = { - "status": "SUCCESS", - "error": "", - "model_chain": model_chain, - "reference_chain": reference_chain, - "global_score": lddt_scorer.global_score, - "conserved_contacts": - lddt_scorer.conserved_contacts, - "total_contacts": lddt_scorer.total_contacts} - if opts.save_per_residue_scores: - per_residue_sc = \ - mapped_lddt_scorer.GetPerResidueScores() - ost.LogInfo("Per residue local lDDT (reference):") - ost.LogInfo("Chain\tResidue Number\tResidue Name" - "\tlDDT\tConserved Contacts\tTotal " - "Contacts") - for prs_scores in per_residue_sc: - ost.LogInfo("%s\t%i\t%s\t%.4f\t%i\t%i" % ( - reference_chain, - prs_scores["residue_number"], - prs_scores["residue_name"], - prs_scores["lddt"], - prs_scores["conserved_contacts"], - prs_scores["total_contacts"])) - sc_lddt_scores["per_residue_scores"] = \ - per_residue_sc - lddt_results["single_chain_lddt"].append( - sc_lddt_scores) - except Exception as ex: - ost.LogError('Single chain lDDT failed:', str(ex)) - lddt_results["single_chain_lddt"].append({ - "status": "FAILURE", - "error": str(ex), - "model_chain": model_chain, - "reference_chain": reference_chain, - "global_score": 0.0, - "conserved_contacts": 0.0, - "total_contacts": 0.0}) - # perform oligo lddt scoring - if skip_score: + ost.LogInfo( " --> Skipping oligomeric lDDT because consistency " "check failed") lddt_results["oligo_lddt"] = { "status": "FAILURE", "error": "Consistency check failed.", - "global_score": 0.0} + "global_score": 0.0, + "conserved_contacts": 0, + "total_contacts": 0} + else: - try: - ost.LogInfo(' --> Computing oligomeric lDDT score') - lddt_results["oligo_lddt"] = { + lddt_scorer = qs_scorer.GetOligoLDDTScorer(lddt_settings, + penalize_extra_chains=True) + + # do single chain lDDT + scores = lddt_scorer.sc_lddt + per_res = lddt_scorer.sc_lddt_per_res + tot = lddt_scorer.sc_lddt_tot + cons = lddt_scorer.sc_lddt_cons + alns = qs_scorer.alignments + + for idx in range(len(scores)): + ref_chain = alns[idx].GetSequence(0).GetName() + model_chain = alns[idx].GetSequence(1).GetName() + + ost.LogInfo((" --> Computing lDDT between model " + "chain %s and reference chain %s") % ( + model_chain, + ref_chain)) + ost.LogInfo("Global LDDT score: %.4f" %scores[idx]) + ost.LogInfo("(%i conserved distances out of %i checked, over " + "%i thresholds)" % (cons[idx], tot[idx], 4)) + + sc_lddt_scores = { "status": "SUCCESS", "error": "", - "global_score": oligo_lddt_scorer.oligo_lddt} - ost.LogInfo( - "Oligo lDDT score: %.4f" % - oligo_lddt_scorer.oligo_lddt) - except Exception as ex: - ost.LogError('Oligo lDDT failed:', str(ex)) - lddt_results["oligo_lddt"] = { - "status": "FAILURE", - "error": str(ex), - "global_score": 0.0} - if skip_score: + "model_chain": model_chain, + "reference_chain": ref_chain, + "global_score": scores[idx], + "conserved_contacts": cons[idx], + "total_contacts": tot[idx]} + + if opts.save_per_residue_scores: + tmp = [x for x in per_res if x["chain"]==ref_chain] + ost.LogInfo("Per residue local lDDT (reference):") + ost.LogInfo("Chain\tResidue Number\tResidue Name" + "\tlDDT\tConserved Contacts\tTotal " + "Contacts") + for x in tmp: + ost.LogInfo("%s\t%i\t%s\t%.4f\t%i\t%i" % ( + x["chain"], + x["residue_number"], + x["residue_name"], + x["lddt"], + x["conserved_contacts"], + x["total_contacts"])) + sc_lddt_scores["per_residue_scores"] = tmp + + lddt_results["single_chain_lddt"].append( + sc_lddt_scores) + + # perform oligo lddt scoring + ost.LogInfo(' --> Computing oligomeric lDDT score') + lddt_results["oligo_lddt"] = { + "status": "SUCCESS", + "error": "", + "global_score": lddt_scorer.oligo_lddt, + "conserved_contacts": lddt_scorer.oligo_lddt_cons, + "total_contacts": lddt_scorer.oligo_lddt_tot} ost.LogInfo( - " --> Skipping weighted lDDT because consistency " - "check failed") - lddt_results["weighted_lddt"] = { - "status": "FAILURE", - "error": "Consistency check failed.", - "global_score": 0.0} - else: - try: - ost.LogInfo(' --> Computing weighted lDDT score') - lddt_results["weighted_lddt"] = { - "status": "SUCCESS", - "error": "", - "global_score": oligo_lddt_scorer.weighted_lddt} - ost.LogInfo( - "Weighted lDDT score: %.4f" % - oligo_lddt_scorer.weighted_lddt) - except Exception as ex: - ost.LogError('Weighted lDDT failed:', str(ex)) - lddt_results["weighted_lddt"] = { - "status": "FAILURE", - "error": str(ex), - "global_score": 0.0} + "Oligo lDDT score: %.4f" % lddt_scorer.oligo_lddt) + + if opts.save_per_residue_scores: + tmp = lddt_scorer.oligo_lddt_per_res + ost.LogInfo("Per residue local oligo lDDT (reference):") + ost.LogInfo("Chain\tResidue Number\tResidue Name" + "\tlDDT\tConserved Contacts\tTotal " + "Contacts") + for x in tmp: + ost.LogInfo("%s\t%i\t%s\t%.4f\t%i\t%i" % ( + x["chain"], + x["residue_number"], + x["residue_name"], + x["lddt"], + x["conserved_contacts"], + x["total_contacts"])) + lddt_results["oligo_lddt"]["per_residue_scores"] = tmp + reference_results["lddt"] = lddt_results + model_results[reference_name] = reference_results if opts.dump_structures: ost.LogInfo("-" * 80) diff --git a/modules/mol/alg/pymod/qsscoring.py b/modules/mol/alg/pymod/qsscoring.py index b47aa4e76..38727ea1c 100644 --- a/modules/mol/alg/pymod/qsscoring.py +++ b/modules/mol/alg/pymod/qsscoring.py @@ -19,7 +19,7 @@ by `Bertoni et al. <https://dx.doi.org/10.1038/s41598-017-09654-8>`_. from ost import mol, geom, conop, seq, settings, PushVerbosityLevel from ost import LogError, LogWarning, LogScript, LogInfo, LogVerbose, LogDebug from ost.bindings.clustalw import ClustalW -from ost.mol.alg import lDDTScorer +from ost.mol.alg import lddt from ost.seq.alg.renumber import Renumber import numpy as np from scipy.special import factorial @@ -558,14 +558,9 @@ class QSscorer: :param settings: Passed to :class:`OligoLDDTScorer` constructor. :param penalize_extra_chains: Passed to :class:`OligoLDDTScorer` constructor. """ - if penalize_extra_chains: - return OligoLDDTScorer(self.qs_ent_1.ent, self.qs_ent_2.ent, - self.alignments, self.calpha_only, settings, - True, self.chem_mapping) - else: - return OligoLDDTScorer(self.qs_ent_1.ent, self.qs_ent_2.ent, - self.alignments, self.calpha_only, settings, False) - + return OligoLDDTScorer(self.qs_ent_1.ent, self.qs_ent_2.ent, + self.alignments, self.calpha_only, settings, + penalize_extra_chains = penalize_extra_chains) ############################################################################## # Class internal helpers (anything that doesnt easily work without this class) @@ -959,7 +954,7 @@ class OligoLDDTScorer(object): By construction, lDDT scores are not symmetric and hence it matters which structure is the reference (:attr:`ref`) and which one is the model (:attr:`mdl`). Extra residues in the model are generally not considered. - Extra chains in both model and reference can be considered by setting the + Extra chains in the model can be considered by setting the :attr:`penalize_extra_chains` flag to True. :param ref: Sets :attr:`ref` @@ -968,15 +963,15 @@ class OligoLDDTScorer(object): :param calpha_only: Sets :attr:`calpha_only` :param settings: Sets :attr:`settings` :param penalize_extra_chains: Sets :attr:`penalize_extra_chains` - :param chem_mapping: Sets :attr:`chem_mapping`. Must be given if - *penalize_extra_chains* is True. .. attribute:: ref mdl Full reference/model entity to be scored. The entity must contain all chains - mapped in :attr:`alignments` and may also contain additional ones which are - considered if :attr:`penalize_extra_chains` is True. + mapped in :attr:`alignments`. Additional chains in ref automatically impact + the lDDT score as the according contacts are not conserved. + However, punishing for extra chains in mdl must be explicitely activated by + setting :attr:`penalize_extra_chains` to True. :type: :class:`~ost.mol.EntityHandle` @@ -1003,27 +998,9 @@ class OligoLDDTScorer(object): .. attribute:: penalize_extra_chains - If True, extra chains in both :attr:`ref` and :attr:`mdl` will penalize the - lDDT scores. + If True, extra chains in :attr:`mdl` will penalize the lDDT scores. :type: :class:`bool` - - .. attribute:: chem_mapping - - Inter-complex mapping of chemical groups as defined in - :attr:`QSscorer.chem_mapping`. Used to find "chem-mapped" chains in - :attr:`ref` for unmapped chains in :attr:`mdl` when penalizing scores. - Each unmapped model chain can add extra reference-contacts according to the - average total contacts of each single "chem-mapped" reference chain. If - there is no "chem-mapped" reference chain, a warning is shown and the model - chain is ignored. - - - Only relevant if :attr:`penalize_extra_chains` is True. - - :type: :class:`dict` with key = :class:`tuple` of chain names in - :attr:`ref` and value = :class:`tuple` of chain names in - :attr:`mdl`. """ # NOTE: one could also allow computation of both penalized and unpenalized @@ -1031,10 +1008,6 @@ class OligoLDDTScorer(object): def __init__(self, ref, mdl, alignments, calpha_only, settings, penalize_extra_chains=False, chem_mapping=None): - # sanity checks - if chem_mapping is None and penalize_extra_chains: - raise RuntimeError("Must provide chem_mapping when requesting penalty " - "for extra chains!") if not penalize_extra_chains: # warn for unmapped model chains unmapped_mdl_chains = self._GetUnmappedMdlChains(mdl, alignments) @@ -1048,7 +1021,7 @@ class OligoLDDTScorer(object): unmapped_ref_chains = (ref_chains - mapped_ref_chains) if unmapped_ref_chains: LogWarning('REFERENCE contains chains unmapped to MODEL, ' - 'lDDT is not considering REFERENCE chains %s' \ + 'lDDT penalizes these non-satisfied contacts %s' \ % str(list(unmapped_ref_chains))) # prepare fields self.ref = ref @@ -1057,27 +1030,25 @@ class OligoLDDTScorer(object): self.calpha_only = calpha_only self.settings = settings self.penalize_extra_chains = penalize_extra_chains - self.chem_mapping = chem_mapping - self._sc_lddt = None + self._lddt_scorer = None self._oligo_lddt = None + self._oligo_lddt_tot = None + self._oligo_lddt_cons = None + self._oligo_lddt_per_res = None + self._sc_lddt = None + self._sc_lddt_tot = None + self._sc_lddt_cons = None + self._sc_lddt_per_res = None self._weighted_lddt = None - self._lddt_ref = None - self._lddt_mdl = None - self._oligo_lddt_scorer = None - self._mapped_lddt_scorers = None - self._ref_scorers = None - self._model_penalty = None + self._chain_mapping = None @property def oligo_lddt(self): """Oligomeric lDDT score. - The score is computed as conserved contacts divided by the total contacts - in the reference using the :attr:`oligo_lddt_scorer`, which uses the full - complex as reference/model structure. If :attr:`penalize_extra_chains` is - True, the reference/model complexes contain all chains (otherwise only the - mapped ones) and additional contacts are added to the reference's total - contacts for unmapped model chains according to the :attr:`chem_mapping`. + lDDT using the full complex as reference/model structure. If + :attr:`penalize_extra_chains` is True, the contacts from additional + non-mapped model chains are added to the reference's total. The main difference with :attr:`weighted_lddt` is that the lDDT scorer "sees" the full complex here (incl. inter-chain contacts), while the @@ -1087,159 +1058,236 @@ class OligoLDDTScorer(object): :type: :class:`float` """ if self._oligo_lddt is None: - LogInfo('Reference %s has: %s chains' \ - % (self.ref.GetName(), self.ref.chain_count)) - LogInfo('Model %s has: %s chains' \ - % (self.mdl.GetName(), self.mdl.chain_count)) - - # score with or w/o extra-chain penalty - if self.penalize_extra_chains: - denominator = self.oligo_lddt_scorer.total_contacts - denominator += self._GetModelPenalty() - if denominator > 0: - oligo_lddt = self.oligo_lddt_scorer.conserved_contacts \ - / float(denominator) - else: - oligo_lddt = 0.0 - else: - oligo_lddt = self.oligo_lddt_scorer.global_score - self._oligo_lddt = oligo_lddt + lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, a, b, c = \ + self.lddt_scorer.lDDT(self.mdl, + thresholds = self.settings.cutoffs, + chain_mapping = self.chain_mapping, + no_interchain=False, + penalize_extra_chains=self.penalize_extra_chains, + residue_mapping=self.alignments, + local_lddt_prop="oligo_lddt", + local_contact_prop="oligo_contact", + return_dist_test=True) + self._oligo_lddt = lDDT + self._oligo_lddt_tot = lDDT_tot + self._oligo_lddt_cons = lDDT_cons return self._oligo_lddt @property - def weighted_lddt(self): - """Weighted average of single chain lDDT scores. - - The score is computed as a weighted average of single chain lDDT scores - (see :attr:`sc_lddt_scorers`) using the total contacts of each single - reference chain as weights. If :attr:`penalize_extra_chains` is True, - unmapped chains are added with a 0 score and total contacts taken from - the actual reference chains or (for unmapped model chains) using the - :attr:`chem_mapping`. + def oligo_lddt_tot(self): + """Number of total contacts used for oligo_lddt - See :attr:`oligo_lddt` for a comparison of the two scores. + Potentially includes penalty contacts from non-mapped model chains :getter: Computed on first use (cached) - :type: :class:`float` + :type: :class:`int` """ - if self._weighted_lddt is None: - scores = [s.global_score for s in self.sc_lddt_scorers] - weights = [s.total_contacts for s in self.sc_lddt_scorers] - nominator = sum([s * w for s, w in zip(scores, weights)]) - if self.penalize_extra_chains: - ref_scorers = self._GetRefScorers() - denominator = sum(s.total_contacts for s in list(ref_scorers.values())) - denominator += self._GetModelPenalty() - else: - denominator = sum(weights) - if denominator > 0: - self._weighted_lddt = nominator / float(denominator) - else: - self._weighted_lddt = 0.0 - return self._weighted_lddt + if self._oligo_lddt_tot is None: + yolo = self.oligo_lddt + assert(self._oligo_lddt_tot is not None) + return self._oligo_lddt_tot @property - def lddt_ref(self): - """The reference entity used for oligomeric lDDT scoring - (:attr:`oligo_lddt` / :attr:`oligo_lddt_scorer`). - - Since the lDDT computation requires a single chain with mapped residue - numbering, all chains of :attr:`ref` are appended into a single chain X with - unique residue numbers according to the column-index in the alignment. The - alignments are in the same order as they appear in :attr:`alignments`. - Additional residues are appended at the end of the chain with unique residue - numbers. Unmapped chains are only added if :attr:`penalize_extra_chains` is - True. Only CA atoms are considered if :attr:`calpha_only` is True. + def oligo_lddt_cons(self): + """Number of conserved contacts used for oligo_lddt :getter: Computed on first use (cached) - :type: :class:`~ost.mol.EntityHandle` + :type: :class:`int` """ - if self._lddt_ref is None: - self._PrepareOligoEntities() - return self._lddt_ref - + if self._oligo_lddt_cons is None: + yolo = self.oligo_lddt + assert(self._oligo_lddt_cons is not None) + return self._oligo_lddt_cons + @property - def lddt_mdl(self): - """The model entity used for oligomeric lDDT scoring - (:attr:`oligo_lddt` / :attr:`oligo_lddt_scorer`). + def oligo_lddt_per_res(self): + """Per residue scores based on oligo_lddt - Like :attr:`lddt_ref`, this is a single chain X containing all chains of - :attr:`mdl`. The residue numbers match the ones in :attr:`lddt_ref` where - aligned and have unique numbers for additional residues. + Each scored residue gets a dict with keys: + ["residue_number", "residue_name", "chain", "lddt", "conserved_contacts", + "total_contacts"] + The first three uniquely identify the residue and refer to the residue in + self.ref. :getter: Computed on first use (cached) - :type: :class:`~ost.mol.EntityHandle` + :type: :class:`list` of :class:`dict` """ - if self._lddt_mdl is None: - self._PrepareOligoEntities() - return self._lddt_mdl + if self._oligo_lddt_per_res is None: + yolo = self.oligo_lddt # trigger oligo_lddt computation to assign scores + # and contacts as generic properties on residues + self._oligo_lddt_per_res = self._GetPerResidueScores(self.alignments, + "oligo_lddt", + "oligo_contact_cons", + "oligo_contact_exp") + return self._oligo_lddt_per_res @property - def oligo_lddt_scorer(self): - """lDDT Scorer object for :attr:`lddt_ref` and :attr:`lddt_mdl`. + def sc_lddt(self): + """List of global lDDT score for each chain mapping in self.alignments. :getter: Computed on first use (cached) - :type: :class:`~ost.mol.alg.lDDTScorer` + :type: :class:`list` of :class:`float` """ - if self._oligo_lddt_scorer is None: - self._oligo_lddt_scorer = lDDTScorer( - references=[self.lddt_ref.Select("")], - model=self.lddt_mdl.Select(""), - settings=self.settings) - return self._oligo_lddt_scorer + if self._sc_lddt is None: + yolo = self.weighted_lddt # sc_lddt is computed as a side product + assert(self._sc_lddt is not None) + assert(self._sc_lddt_tot is not None) + assert(self._sc_lddt_cons is not None) + return self._sc_lddt @property - def mapped_lddt_scorers(self): - """List of scorer objects for each chain mapped in :attr:`alignments`. + def sc_lddt_tot(self): + """Number of total contacts for each chain mapping in self.alignments :getter: Computed on first use (cached) - :type: :class:`list` of :class:`MappedLDDTScorer` + :type: :class:`list` of :class:`int` """ - if self._mapped_lddt_scorers is None: - self._mapped_lddt_scorers = list() - for aln in self.alignments: - mapped_lddt_scorer = MappedLDDTScorer(aln, self.calpha_only, - self.settings) - self._mapped_lddt_scorers.append(mapped_lddt_scorer) - return self._mapped_lddt_scorers + if self._sc_lddt_tot is None: + yolo = self.sc_lddt # sc_lddt_tot is computed as a sideproduct + assert(self._sc_lddt_tot is not None) + return self._sc_lddt_tot + + @property + def sc_lddt_cons(self): + """Number of conserved contacts for each chain mapping in self.alignments + + :getter: Computed on first use (cached) + :type: :class:`list` of :class:`int` + """ + if self._sc_lddt_cons is None: + yolo = self.sc_lddt # sc_lddt_tot is computed as a sideproduct + assert(self._sc_lddt_cons is not None) + return self._sc_lddt_cons @property - def sc_lddt_scorers(self): - """List of lDDT scorer objects extracted from :attr:`mapped_lddt_scorers`. + def sc_lddt_per_res(self): + """Per residue scores based on sc_lddt - :type: :class:`list` of :class:`~ost.mol.alg.lDDTScorer` + Each scored residue gets a dict with keys: + ["residue_number", "residue_name", "chain", "lddt", "conserved_contacts", + "total_contacts"] + The first three uniquely identify the residue and refer to the residue in + self.ref. + + :getter: Computed on first use (cached) + :type: :class:`list` of :class:`dict` """ - return [mls.lddt_scorer for mls in self.mapped_lddt_scorers] + if self._sc_lddt_per_res is None: + yolo = self.sc_lddt # trigger sc_lddt computation to assign scores + # and contacts as generic properties on residues + self._sc_lddt_per_res = self._GetPerResidueScores(self.alignments, + "sc_lddt", + "sc_contact_cons", + "sc_contact_exp") + return self._sc_lddt_per_res @property - def sc_lddt(self): - """List of global scores extracted from :attr:`sc_lddt_scorers`. + def weighted_lddt(self): + """Weighted average of single chain lDDT scores. - If scoring for a mapped chain fails, an error is displayed and a score of 0 - is assigned. + The score is computed as a weighted average of single chain lDDT scores. + In principle thats oligo_lddt without inter-chain contacts. + (see :attr:`sc_lddt_scorers`). Chains in ref which are not mapped + penalize the overall score as their contacts are not conserved. + Chains in mdl which are not mapped only penalize the score if + :attr:`penalize_extra_chains` is True. :getter: Computed on first use (cached) - :type: :class:`list` of :class:`float` + :type: :class:`float` """ - if self._sc_lddt is None: + if self._weighted_lddt is None: + lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, per_res_exp, \ + per_res_conserved = \ + self.lddt_scorer.lDDT(self.mdl, + thresholds = self.settings.cutoffs, + chain_mapping = self.chain_mapping, + no_interchain=True, + penalize_extra_chains=self.penalize_extra_chains, + residue_mapping=self.alignments, + local_lddt_prop="sc_lddt", + local_contact_prop="sc_contact", + return_dist_test=True) + self._weighted_lddt = lDDT + + # we directly use the results from above to also compute the + # single chain lDDTs manually self._sc_lddt = list() - for lddt_scorer in self.sc_lddt_scorers: - try: - self._sc_lddt.append(lddt_scorer.global_score) - except Exception as ex: - LogError('Single chain lDDT failed:', str(ex)) - self._sc_lddt.append(0.0) - return self._sc_lddt + self._sc_lddt_tot = list() + self._sc_lddt_cons = list() + chain_res_indices = dict() + residues = self.mdl.residues + for i, r_idx in enumerate(res_indices): + r = residues[r_idx] + ch = r.GetChain().GetName() + if ch not in chain_res_indices: + chain_res_indices[ch] = list() + chain_res_indices[ch].append(i) + + n_thresholds = len(self.settings.cutoffs) + for aln in self.alignments: + ch = aln.GetSequence(0).GetName() + cons = int(np.sum(per_res_conserved.take(chain_res_indices[ch], axis=0))) + tot = self.lddt_scorer.GetNChainContacts(ch, no_interchain=True) + tot*=n_thresholds + if tot > 0: + self._sc_lddt.append(float(cons)/tot) + else: + self._sc_lddt.append(0) + self._sc_lddt_tot.append(tot) + self._sc_lddt_cons.append(cons) + return self._weighted_lddt + + @property + def lddt_scorer(self): + if self._lddt_scorer is None: + if not conop.GetDefaultLib(): + raise RuntimeError("OligolDDT computation requires a compound library!") + r = self.settings.radius + seq_sep = self.settings.sequence_separation + self._lddt_scorer = lddt.lDDTScorer(self.ref, conop.GetDefaultLib(), + inclusion_radius = r, + sequence_separation = seq_sep) + return self._lddt_scorer + + @property + def chain_mapping(self): + if self._chain_mapping is None: + # chain mapping as required by lddt_scorer + # key: model chain, value: reference chain + self._chain_mapping = dict() + for aln in self.alignments: + ref_seq = aln.GetSequence(0) + mdl_seq = aln.GetSequence(1) + self._chain_mapping[mdl_seq.GetName()] = ref_seq.GetName() + return self._chain_mapping ############################################################################## # Class internal helpers ############################################################################## - def _PrepareOligoEntities(self): - # simple wrapper to avoid code duplication - self._lddt_ref, self._lddt_mdl = _MergeAlignedChains( - self.alignments, self.ref, self.mdl, self.calpha_only, - self.penalize_extra_chains) + @staticmethod + def _GetPerResidueScores(alignments, lddt_prop, cons_prop, tot_prop): + per_residue_sc = list() + for aln in alignments: + reference_chain = aln.GetSequence(0).GetName() + for col in aln: + ref_res = col.GetResidue(0) + mdl_res = col.GetResidue(1) + if ref_res.IsValid() and mdl_res.IsValid(): + if mdl_res.HasProp(lddt_prop) and mdl_res.HasProp(cons_prop) and \ + mdl_res.HasProp(tot_prop): + num = ref_res.GetNumber().GetNum() + name = ref_res.GetName() + score = mdl_res.GetFloatProp(lddt_prop) + cons = mdl_res.GetIntProp(cons_prop) + tot = mdl_res.GetIntProp(tot_prop) + per_residue_sc.append({"residue_number": num, + "residue_name": name, + "chain": reference_chain, + "lddt": score, + "conserved_contacts": cons, + "total_contacts": tot}) + return per_residue_sc @staticmethod def _GetUnmappedMdlChains(mdl, alignments): @@ -1248,224 +1296,6 @@ class OligoLDDTScorer(object): mapped_mdl_chains = set(aln.GetSequence(1).GetName() for aln in alignments) return (mdl_chains - mapped_mdl_chains) - def _GetRefScorers(self): - # single chain lddt scorers for each reference chain (key = chain name) - if self._ref_scorers is None: - # collect from mapped_lddt_scorers - ref_scorers = dict() - for mapped_lddt_scorer in self.mapped_lddt_scorers: - ref_ch_name = mapped_lddt_scorer.reference_chain_name - ref_scorers[ref_ch_name] = mapped_lddt_scorer.lddt_scorer - # add new ones where needed - for ch in self.ref.chains: - if ch.name not in ref_scorers: - if self.calpha_only: - ref_chain = ch.Select('aname=CA') - else: - ref_chain = ch.Select('') - ref_scorers[ch.name] = lDDTScorer( - references=[ref_chain], - model=ref_chain, - settings=self.settings) - # store in cache - self._ref_scorers = ref_scorers - # fetch from cache - return self._ref_scorers - - def _GetModelPenalty(self): - # extra value to add to total number of distances for extra model chains - # -> estimated from chem-mapped reference chains - if self._model_penalty is None: - # sanity check - if self.chem_mapping is None: - raise RuntimeError("Must provide chem_mapping when requesting penalty " - "for extra model chains!") - # get cached ref_scorers - ref_scorers = self._GetRefScorers() - # get unmapped model chains - unmapped_mdl_chains = self._GetUnmappedMdlChains(self.mdl, self.alignments) - # map extra chains to ref. chains - model_penalty = 0 - for ch_name_mdl in sorted(unmapped_mdl_chains): - # get penalty for chain - cur_penalty = None - for cm_ref, cm_mdl in self.chem_mapping.items(): - if ch_name_mdl in cm_mdl: - # penalize by an average of the chem. mapped ref. chains - cur_penalty = 0 - for ch_name_ref in cm_ref: - # assumes that total_contacts is cached (for speed) - cur_penalty += ref_scorers[ch_name_ref].total_contacts - cur_penalty /= float(len(cm_ref)) - break - # report penalty - if cur_penalty is None: - LogWarning('Extra MODEL chain %s could not be chemically mapped to ' - 'any chain in REFERENCE, lDDT cannot consider it!' \ - % ch_name_mdl) - else: - LogScript('Extra MODEL chain %s added to lDDT score by considering ' - 'chemically mapped chains in REFERENCE.' % ch_name_mdl) - model_penalty += cur_penalty - # store in cache - self._model_penalty = model_penalty - # fetch from cache - return self._model_penalty - - -class MappedLDDTScorer(object): - """A simple class to calculate a single-chain lDDT score on a given chain to - chain mapping as extracted from :class:`OligoLDDTScorer`. - - :param alignment: Sets :attr:`alignment` - :param calpha_only: Sets :attr:`calpha_only` - :param settings: Sets :attr:`settings` - - .. attribute:: alignment - - Alignment with two sequences named according to the mapped chains and with - views attached to both sequences (e.g. one of the items of - :attr:`QSscorer.alignments`). - - The first sequence is assumed to be the reference and the second one the - model. Since the lDDT score is not symmetric (extra residues in model are - ignored), the order is important. - - :type: :class:`~ost.seq.AlignmentHandle` - - .. attribute:: calpha_only - - If True, restricts lDDT score to CA only. - - :type: :class:`bool` - - .. attribute:: settings - - Settings to use for lDDT scoring. - - :type: :class:`~ost.mol.alg.lDDTSettings` - - .. attribute:: lddt_scorer - - lDDT Scorer object for the given chains. - - :type: :class:`~ost.mol.alg.lDDTScorer` - - .. attribute:: reference_chain_name - - Chain name of the reference. - - :type: :class:`str` - - .. attribute:: model_chain_name - - Chain name of the model. - - :type: :class:`str` - """ - def __init__(self, alignment, calpha_only, settings): - # prepare fields - self.alignment = alignment - self.calpha_only = calpha_only - self.settings = settings - self.lddt_scorer = None # set in _InitScorer - self.reference_chain_name = alignment.sequences[0].name - self.model_chain_name = alignment.sequences[1].name - self._old_number_label = "old_num" - self._extended_alignment = None # set in _InitScorer - # initialize lDDT scorer - self._InitScorer() - - def GetPerResidueScores(self): - """ - :return: Scores for each residue - :rtype: :class:`list` of :class:`dict` with one item for each residue - existing in model and reference: - - - "residue_number": Residue number in reference chain - - "residue_name": Residue name in reference chain - - "lddt": local lDDT - - "conserved_contacts": number of conserved contacts - - "total_contacts": total number of contacts - """ - scores = list() - assigned_residues = list() - # Make sure the score is calculated - self.lddt_scorer.global_score - for col in self._extended_alignment: - if col[0] != "-" and col.GetResidue(3).IsValid(): - ref_res = col.GetResidue(0) - mdl_res = col.GetResidue(1) - ref_res_renum = col.GetResidue(2) - mdl_res_renum = col.GetResidue(3) - if ref_res.one_letter_code != ref_res_renum.one_letter_code: - raise RuntimeError("Reference residue name mapping inconsistent: %s != %s" % - (ref_res.one_letter_code, - ref_res_renum.one_letter_code)) - if mdl_res.one_letter_code != mdl_res_renum.one_letter_code: - raise RuntimeError("Model residue name mapping inconsistent: %s != %s" % - (mdl_res.one_letter_code, - mdl_res_renum.one_letter_code)) - if ref_res.GetNumber().num != ref_res_renum.GetIntProp(self._old_number_label): - raise RuntimeError("Reference residue number mapping inconsistent: %s != %s" % - (ref_res.GetNumber().num, - ref_res_renum.GetIntProp(self._old_number_label))) - if mdl_res.GetNumber().num != mdl_res_renum.GetIntProp(self._old_number_label): - raise RuntimeError("Model residue number mapping inconsistent: %s != %s" % - (mdl_res.GetNumber().num, - mdl_res_renum.GetIntProp(self._old_number_label))) - if ref_res.qualified_name in assigned_residues: - raise RuntimeError("Duplicated residue in reference: " % - (ref_res.qualified_name)) - else: - assigned_residues.append(ref_res.qualified_name) - # check if property there (may be missing for CA-only) - if mdl_res_renum.HasProp(self.settings.label): - scores.append({ - "residue_number": ref_res.GetNumber().num, - "residue_name": ref_res.name, - "lddt": mdl_res_renum.GetFloatProp(self.settings.label), - "conserved_contacts": mdl_res_renum.GetFloatProp(self.settings.label + "_conserved"), - "total_contacts": mdl_res_renum.GetFloatProp(self.settings.label + "_total")}) - return scores - - ############################################################################## - # Class internal helpers (anything that doesnt easily work without this class) - ############################################################################## - - def _InitScorer(self): - # Use copy of alignment (extended by 2 extra sequences for renumbering) - aln = self.alignment.Copy() - # Get chains and renumber according to alignment (for lDDT) - reference = Renumber( - aln.GetSequence(0), - old_number_label=self._old_number_label).CreateFullView() - refseq = seq.CreateSequence( - "reference_renumbered", - aln.GetSequence(0).GetString()) - refseq.AttachView(reference) - aln.AddSequence(refseq) - model = Renumber( - aln.GetSequence(1), - old_number_label=self._old_number_label).CreateFullView() - modelseq = seq.CreateSequence( - "model_renumbered", - aln.GetSequence(1).GetString()) - modelseq.AttachView(model) - aln.AddSequence(modelseq) - # Filter to CA-only if desired (done after AttachView to not mess it up) - if self.calpha_only: - self.lddt_scorer = lDDTScorer( - references=[reference.Select('aname=CA')], - model=model.Select('aname=CA'), - settings=self.settings) - else: - self.lddt_scorer = lDDTScorer( - references=[reference], - model=model, - settings=self.settings) - # Store alignment for later - self._extended_alignment = aln ############################################################################### # HELPERS @@ -2890,101 +2720,6 @@ def _GetQsSuperposition(alns): res = mol.alg.SuperposeSVD(view_1, view_2, apply_transform=False) return res - -def _AddResidue(edi, res, rnum, chain, calpha_only): - """ - Add residue *res* with res. num. *run* to given *chain* using editor *edi*. - Either all atoms added or (if *calpha_only*) only CA. - """ - if calpha_only: - ca_atom = res.FindAtom('CA') - if ca_atom.IsValid(): - new_res = edi.AppendResidue(chain, res.name, rnum) - edi.InsertAtom(new_res, ca_atom.name, ca_atom.pos) - else: - new_res = edi.AppendResidue(chain, res.name, rnum) - for atom in res.atoms: - edi.InsertAtom(new_res, atom.name, atom.pos) - -def _MergeAlignedChains(alns, ent_1, ent_2, calpha_only, penalize_extra_chains): - """ - Create two new entities (based on the alignments attached views) where all - residues have same numbering (when they're aligned) and they are all pushed to - a single chain X. Also append extra chains contained in *ent_1* and *ent_2* - but not contained in *alns*. - - Used for :attr:`QSscorer.lddt_ref` and :attr:`QSscorer.lddt_mdl` - - :param alns: List of alignments with attached views (first sequence: *ent_1*, - second: *ent_2*). Residue number in single chain is column index - of current alignment + sum of lengths of all previous alignments - (order of alignments as in input list). - :type alns: See :attr:`QSscorer.alignments` - :param ent_1: First entity to process. - :type ent_1: :class:`~ost.mol.EntityHandle` - :param ent_2: Second entity to process. - :type ent_2: :class:`~ost.mol.EntityHandle` - :param calpha_only: If True, we only include CA atoms instead of all. - :type calpha_only: :class:`bool` - :param penalize_extra_chains: If True, extra chains are added to model and - reference. Otherwise, only mapped ones. - :type penalize_extra_chains: :class:`bool` - - :return: Tuple of two single chain entities (from *ent_1* and from *ent_2*) - :rtype: :class:`tuple` of :class:`~ost.mol.EntityHandle` - """ - # first new entity - ent_ren_1 = mol.CreateEntity() - ed_1 = ent_ren_1.EditXCS() - new_chain_1 = ed_1.InsertChain('X') - # second one - ent_ren_2 = mol.CreateEntity() - ed_2 = ent_ren_2.EditXCS() - new_chain_2 = ed_2.InsertChain('X') - # the alignment already contains sorted chains - rnum = 0 - chain_done_1 = set() - chain_done_2 = set() - for aln in alns: - chain_done_1.add(aln.GetSequence(0).name) - chain_done_2.add(aln.GetSequence(1).name) - for col in aln: - rnum += 1 - # add valid residues to single chain entities - res_1 = col.GetResidue(0) - if res_1.IsValid(): - _AddResidue(ed_1, res_1, rnum, new_chain_1, calpha_only) - res_2 = col.GetResidue(1) - if res_2.IsValid(): - _AddResidue(ed_2, res_2, rnum, new_chain_2, calpha_only) - # extra chains? - if penalize_extra_chains: - for chain in ent_1.chains: - if chain.name in chain_done_1: - continue - for res in chain.residues: - rnum += 1 - _AddResidue(ed_1, res, rnum, new_chain_1, calpha_only) - for chain in ent_2.chains: - if chain.name in chain_done_2: - continue - for res in chain.residues: - rnum += 1 - _AddResidue(ed_2, res, rnum, new_chain_2, calpha_only) - # get entity names - ent_ren_1.SetName(aln.GetSequence(0).GetAttachedView().GetName()) - ent_ren_2.SetName(aln.GetSequence(1).GetAttachedView().GetName()) - # connect atoms - if not conop.GetDefaultLib(): - raise RuntimeError("QSscore computation requires a compound library!") - pr = conop.RuleBasedProcessor(conop.GetDefaultLib()) - pr.Process(ent_ren_1) - ed_1.UpdateICS() - pr.Process(ent_ren_2) - ed_2.UpdateICS() - return ent_ren_1, ent_ren_2 - - # specify public interface __all__ = ('QSscoreError', 'QSscorer', 'QSscoreEntity', 'FilterContacts', - 'GetContacts', 'OligoLDDTScorer', 'MappedLDDTScorer') + 'GetContacts', 'OligoLDDTScorer') diff --git a/modules/mol/alg/tests/test_qsscoring.py b/modules/mol/alg/tests/test_qsscoring.py index 31daffe56..1810ad054 100644 --- a/modules/mol/alg/tests/test_qsscoring.py +++ b/modules/mol/alg/tests/test_qsscoring.py @@ -424,13 +424,12 @@ class TestQSscore(unittest.TestCase): lddt_oligo_scorer2 = qs_scorer2.GetOligoLDDTScorer(lddt_settings, False) self.assertAlmostEqual(qs_scorer2.global_score, 0.171, 2) self.assertAlmostEqual(qs_scorer2.best_score, 1.00, 2) - self.assertAlmostEqual(lddt_oligo_scorer2.weighted_lddt, 1.00, 2) + self.assertAlmostEqual(lddt_oligo_scorer2.weighted_lddt, 0.4996, 2) self.assertEqual(len(lddt_oligo_scorer2.sc_lddt), 2) self.assertAlmostEqual(lddt_oligo_scorer2.sc_lddt[0], 1.00, 2) self.assertAlmostEqual(lddt_oligo_scorer2.sc_lddt[1], 1.00, 2) - # without penalty we don't see extra chains - self.assertAlmostEqual(lddt_oligo_scorer2.oligo_lddt, 1.00, 2) - # with penalty we account for extra reference chains + self.assertAlmostEqual(lddt_oligo_scorer2.oligo_lddt, 0.4496, 2) + # penalty only affects additional model chains, scores are thus the same lddt_oligo_scorer2_pen = qs_scorer2.GetOligoLDDTScorer(lddt_settings, True) self.assertAlmostEqual(lddt_oligo_scorer2_pen.oligo_lddt, 0.4496, 2) self.assertAlmostEqual(lddt_oligo_scorer2_pen.weighted_lddt, 0.50, 2) -- GitLab