diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures index 8a3f3b0cc1b347ee509be1c66efddc2644a77842..7aaaf691414c365299a9d22dbd0f03a0fac59a89 100644 --- a/actions/ost-compare-structures +++ b/actions/ost-compare-structures @@ -67,10 +67,12 @@ import ost from ost.io import (LoadPDB, LoadMMCIF, SavePDB, MMCifInfoBioUnit, MMCifInfo, MMCifInfoTransOp, ReadStereoChemicalPropsFile, profiles) from ost import PushVerbosityLevel -from ost.mol.alg import (qsscoring, lddt, Molck, MolckSettings, lDDTSettings, - CheckStructure) +from ost.mol.alg import (qsscoring, Molck, MolckSettings, lDDTSettings, + CheckStructure, ResidueNamesMatch) from ost.conop import (CompoundLib, SetDefaultLib, GetDefaultLib, RuleBasedProcessor) +from ost.seq.alg.renumber import Renumber + def _GetDefaultShareFilePath(filename): """Look for filename in working directory and OST shared data path. @@ -518,19 +520,10 @@ def _RevertChainNames(ent): def _CheckConsistency(alignments, log_error): is_cons = True for alignment in alignments: - for col in alignment: - r1 = col.GetResidue(0) - r2 = col.GetResidue(1) - if r1.IsValid() and r2.IsValid(): - if r1.GetName() != r2.GetName(): - is_cons = False - msg = f"Name mismatch for model residue {r2}: in the " - msg += f"reference structure(s) is {r1.GetName()} ({r1}), " - msg += f"in the model {r2.GetName()}" - if log_error: - ost.LogError(msg) - else: - ost.LogWarning(msg) + ref_chain = Renumber(alignment.GetSequence(0)).CreateFullView() + mdl_chain = Renumber(alignment.GetSequence(1)).CreateFullView() + new_is_cons = ResidueNamesMatch(mdl_chain, ref_chain, log_error) + is_cons = is_cons and new_is_cons return is_cons @@ -844,24 +837,25 @@ def _Main(): "best_score": 0.0} # Calculate lDDT if opts.lddt: - - lddt_settings = lDDTSettings(radius=opts.inclusion_radius, - sequence_separation=opts.sequence_separation, label="lddt") - ost.LogInfo("-" * 80) ost.LogInfo("Computing lDDT scores") - ost.LogInfo("lDDT settings: ") - ost.LogInfo(str(lddt_settings).rstrip()) - ost.LogInfo("===") - lddt_results = { "single_chain_lddt": list() } - - if skip_score: - for aln in qs_scorer.alignments: - reference_chain = aln.GetSequence(0).GetName() - model_chain = aln.GetSequence(1).GetName() + lddt_settings = lDDTSettings( + radius=opts.inclusion_radius, + sequence_separation=opts.sequence_separation, + label="lddt") + ost.LogInfo("lDDT settings: ") + ost.LogInfo(str(lddt_settings).rstrip()) + ost.LogInfo("===") + oligo_lddt_scorer = qs_scorer.GetOligoLDDTScorer(lddt_settings) + for mapped_lddt_scorer in oligo_lddt_scorer.mapped_lddt_scorers: + # Get data + lddt_scorer = mapped_lddt_scorer.lddt_scorer + model_chain = mapped_lddt_scorer.model_chain_name + reference_chain = mapped_lddt_scorer.reference_chain_name + if skip_score: ost.LogInfo( " --> Skipping single chain lDDT because " "consistency check failed") @@ -873,97 +867,107 @@ def _Main(): "global_score": 0.0, "conserved_contacts": 0.0, "total_contacts": 0.0}) - + else: + try: + ost.LogInfo((" --> Computing lDDT between model " + "chain %s and reference chain %s") % ( + model_chain, + reference_chain)) + ost.LogInfo("Global LDDT score: %.4f" % + lddt_scorer.global_score) + ost.LogInfo( + "(%i conserved distances out of %i checked, over " + "%i thresholds)" % (lddt_scorer.conserved_contacts, + lddt_scorer.total_contacts, + len(lddt_settings.cutoffs))) + sc_lddt_scores = { + "status": "SUCCESS", + "error": "", + "model_chain": model_chain, + "reference_chain": reference_chain, + "global_score": lddt_scorer.global_score, + "conserved_contacts": + lddt_scorer.conserved_contacts, + "total_contacts": lddt_scorer.total_contacts} + if opts.save_per_residue_scores: + per_residue_sc = \ + mapped_lddt_scorer.GetPerResidueScores() + ost.LogInfo("Per residue local lDDT (reference):") + ost.LogInfo("Chain\tResidue Number\tResidue Name" + "\tlDDT\tConserved Contacts\tTotal " + "Contacts") + for prs_scores in per_residue_sc: + ost.LogInfo("%s\t%i\t%s\t%.4f\t%i\t%i" % ( + reference_chain, + prs_scores["residue_number"], + prs_scores["residue_name"], + prs_scores["lddt"], + prs_scores["conserved_contacts"], + prs_scores["total_contacts"])) + sc_lddt_scores["per_residue_scores"] = \ + per_residue_sc + lddt_results["single_chain_lddt"].append( + sc_lddt_scores) + except Exception as ex: + ost.LogError('Single chain lDDT failed:', str(ex)) + lddt_results["single_chain_lddt"].append({ + "status": "FAILURE", + "error": str(ex), + "model_chain": model_chain, + "reference_chain": reference_chain, + "global_score": 0.0, + "conserved_contacts": 0.0, + "total_contacts": 0.0}) + # perform oligo lddt scoring + if skip_score: ost.LogInfo( " --> Skipping oligomeric lDDT because consistency " "check failed") lddt_results["oligo_lddt"] = { "status": "FAILURE", "error": "Consistency check failed.", - "global_score": 0.0, - "conserved_contacts": 0, - "total_contacts": 0} - + "global_score": 0.0} else: - lddt_scorer = qs_scorer.GetOligoLDDTScorer(lddt_settings, - penalize_extra_chains=True) - - # do single chain lDDT - scores = lddt_scorer.sc_lddt - per_res = lddt_scorer.sc_lddt_per_res - tot = lddt_scorer.sc_lddt_tot - cons = lddt_scorer.sc_lddt_cons - alns = qs_scorer.alignments - - for idx in range(len(scores)): - ref_chain = alns[idx].GetSequence(0).GetName() - model_chain = alns[idx].GetSequence(1).GetName() - - ost.LogInfo((" --> Computing lDDT between model " - "chain %s and reference chain %s") % ( - model_chain, - ref_chain)) - ost.LogInfo("Global LDDT score: %.4f" %scores[idx]) - ost.LogInfo("(%i conserved distances out of %i checked, over " - "%i thresholds)" % (cons[idx], tot[idx], 4)) - - sc_lddt_scores = { + try: + ost.LogInfo(' --> Computing oligomeric lDDT score') + lddt_results["oligo_lddt"] = { "status": "SUCCESS", "error": "", - "model_chain": model_chain, - "reference_chain": ref_chain, - "global_score": scores[idx], - "conserved_contacts": cons[idx], - "total_contacts": tot[idx]} - - if opts.save_per_residue_scores: - tmp = [x for x in per_res if x["chain"]==ref_chain] - ost.LogInfo("Per residue local lDDT (reference):") - ost.LogInfo("Chain\tResidue Number\tResidue Name" - "\tlDDT\tConserved Contacts\tTotal " - "Contacts") - for x in tmp: - ost.LogInfo("%s\t%i\t%s\t%.4f\t%i\t%i" % ( - x["chain"], - x["residue_number"], - x["residue_name"], - x["lddt"], - x["conserved_contacts"], - x["total_contacts"])) - sc_lddt_scores["per_residue_scores"] = tmp - - lddt_results["single_chain_lddt"].append( - sc_lddt_scores) - - # perform oligo lddt scoring - ost.LogInfo(' --> Computing oligomeric lDDT score') - lddt_results["oligo_lddt"] = { - "status": "SUCCESS", - "error": "", - "global_score": lddt_scorer.oligo_lddt, - "conserved_contacts": lddt_scorer.oligo_lddt_cons, - "total_contacts": lddt_scorer.oligo_lddt_tot} + "global_score": oligo_lddt_scorer.oligo_lddt} + ost.LogInfo( + "Oligo lDDT score: %.4f" % + oligo_lddt_scorer.oligo_lddt) + except Exception as ex: + ost.LogError('Oligo lDDT failed:', str(ex)) + lddt_results["oligo_lddt"] = { + "status": "FAILURE", + "error": str(ex), + "global_score": 0.0} + if skip_score: ost.LogInfo( - "Oligo lDDT score: %.4f" % lddt_scorer.oligo_lddt) - - if opts.save_per_residue_scores: - tmp = lddt_scorer.oligo_lddt_per_res - ost.LogInfo("Per residue local oligo lDDT (reference):") - ost.LogInfo("Chain\tResidue Number\tResidue Name" - "\tlDDT\tConserved Contacts\tTotal " - "Contacts") - for x in tmp: - ost.LogInfo("%s\t%i\t%s\t%.4f\t%i\t%i" % ( - x["chain"], - x["residue_number"], - x["residue_name"], - x["lddt"], - x["conserved_contacts"], - x["total_contacts"])) - lddt_results["oligo_lddt"]["per_residue_scores"] = tmp - + " --> Skipping weighted lDDT because consistency " + "check failed") + lddt_results["weighted_lddt"] = { + "status": "FAILURE", + "error": "Consistency check failed.", + "global_score": 0.0} + else: + try: + ost.LogInfo(' --> Computing weighted lDDT score') + lddt_results["weighted_lddt"] = { + "status": "SUCCESS", + "error": "", + "global_score": oligo_lddt_scorer.weighted_lddt} + ost.LogInfo( + "Weighted lDDT score: %.4f" % + oligo_lddt_scorer.weighted_lddt) + except Exception as ex: + ost.LogError('Weighted lDDT failed:', str(ex)) + lddt_results["weighted_lddt"] = { + "status": "FAILURE", + "error": str(ex), + "global_score": 0.0} reference_results["lddt"] = lddt_results - model_results[reference_name] = reference_results if opts.dump_structures: ost.LogInfo("-" * 80) diff --git a/modules/mol/alg/pymod/qsscoring.py b/modules/mol/alg/pymod/qsscoring.py index e5ab8d1d76c2a30ac614f81cbead719276ad875d..b47aa4e76e1544ba28caa18482253833b57106c9 100644 --- a/modules/mol/alg/pymod/qsscoring.py +++ b/modules/mol/alg/pymod/qsscoring.py @@ -19,7 +19,7 @@ by `Bertoni et al. <https://dx.doi.org/10.1038/s41598-017-09654-8>`_. from ost import mol, geom, conop, seq, settings, PushVerbosityLevel from ost import LogError, LogWarning, LogScript, LogInfo, LogVerbose, LogDebug from ost.bindings.clustalw import ClustalW -from ost.mol.alg import lddt +from ost.mol.alg import lDDTScorer from ost.seq.alg.renumber import Renumber import numpy as np from scipy.special import factorial @@ -558,9 +558,14 @@ class QSscorer: :param settings: Passed to :class:`OligoLDDTScorer` constructor. :param penalize_extra_chains: Passed to :class:`OligoLDDTScorer` constructor. """ - return OligoLDDTScorer(self.qs_ent_1.ent, self.qs_ent_2.ent, - self.alignments, self.calpha_only, settings, - penalize_extra_chains = penalize_extra_chains) + if penalize_extra_chains: + return OligoLDDTScorer(self.qs_ent_1.ent, self.qs_ent_2.ent, + self.alignments, self.calpha_only, settings, + True, self.chem_mapping) + else: + return OligoLDDTScorer(self.qs_ent_1.ent, self.qs_ent_2.ent, + self.alignments, self.calpha_only, settings, False) + ############################################################################## # Class internal helpers (anything that doesnt easily work without this class) @@ -954,7 +959,7 @@ class OligoLDDTScorer(object): By construction, lDDT scores are not symmetric and hence it matters which structure is the reference (:attr:`ref`) and which one is the model (:attr:`mdl`). Extra residues in the model are generally not considered. - Extra chains in the model can be considered by setting the + Extra chains in both model and reference can be considered by setting the :attr:`penalize_extra_chains` flag to True. :param ref: Sets :attr:`ref` @@ -963,15 +968,15 @@ class OligoLDDTScorer(object): :param calpha_only: Sets :attr:`calpha_only` :param settings: Sets :attr:`settings` :param penalize_extra_chains: Sets :attr:`penalize_extra_chains` + :param chem_mapping: Sets :attr:`chem_mapping`. Must be given if + *penalize_extra_chains* is True. .. attribute:: ref mdl Full reference/model entity to be scored. The entity must contain all chains - mapped in :attr:`alignments`. Additional chains in ref automatically impact - the lDDT score as the according contacts are not conserved. - However, punishing for extra chains in mdl must be explicitely activated by - setting :attr:`penalize_extra_chains` to True. + mapped in :attr:`alignments` and may also contain additional ones which are + considered if :attr:`penalize_extra_chains` is True. :type: :class:`~ost.mol.EntityHandle` @@ -998,9 +1003,27 @@ class OligoLDDTScorer(object): .. attribute:: penalize_extra_chains - If True, extra chains in :attr:`mdl` will penalize the lDDT scores. + If True, extra chains in both :attr:`ref` and :attr:`mdl` will penalize the + lDDT scores. :type: :class:`bool` + + .. attribute:: chem_mapping + + Inter-complex mapping of chemical groups as defined in + :attr:`QSscorer.chem_mapping`. Used to find "chem-mapped" chains in + :attr:`ref` for unmapped chains in :attr:`mdl` when penalizing scores. + Each unmapped model chain can add extra reference-contacts according to the + average total contacts of each single "chem-mapped" reference chain. If + there is no "chem-mapped" reference chain, a warning is shown and the model + chain is ignored. + + + Only relevant if :attr:`penalize_extra_chains` is True. + + :type: :class:`dict` with key = :class:`tuple` of chain names in + :attr:`ref` and value = :class:`tuple` of chain names in + :attr:`mdl`. """ # NOTE: one could also allow computation of both penalized and unpenalized @@ -1008,6 +1031,10 @@ class OligoLDDTScorer(object): def __init__(self, ref, mdl, alignments, calpha_only, settings, penalize_extra_chains=False, chem_mapping=None): + # sanity checks + if chem_mapping is None and penalize_extra_chains: + raise RuntimeError("Must provide chem_mapping when requesting penalty " + "for extra chains!") if not penalize_extra_chains: # warn for unmapped model chains unmapped_mdl_chains = self._GetUnmappedMdlChains(mdl, alignments) @@ -1021,7 +1048,7 @@ class OligoLDDTScorer(object): unmapped_ref_chains = (ref_chains - mapped_ref_chains) if unmapped_ref_chains: LogWarning('REFERENCE contains chains unmapped to MODEL, ' - 'lDDT penalizes these non-satisfied contacts %s' \ + 'lDDT is not considering REFERENCE chains %s' \ % str(list(unmapped_ref_chains))) # prepare fields self.ref = ref @@ -1030,25 +1057,27 @@ class OligoLDDTScorer(object): self.calpha_only = calpha_only self.settings = settings self.penalize_extra_chains = penalize_extra_chains - self._lddt_scorer = None - self._oligo_lddt = None - self._oligo_lddt_tot = None - self._oligo_lddt_cons = None - self._oligo_lddt_per_res = None + self.chem_mapping = chem_mapping self._sc_lddt = None - self._sc_lddt_tot = None - self._sc_lddt_cons = None - self._sc_lddt_per_res = None + self._oligo_lddt = None self._weighted_lddt = None - self._chain_mapping = None + self._lddt_ref = None + self._lddt_mdl = None + self._oligo_lddt_scorer = None + self._mapped_lddt_scorers = None + self._ref_scorers = None + self._model_penalty = None @property def oligo_lddt(self): """Oligomeric lDDT score. - lDDT using the full complex as reference/model structure. If - :attr:`penalize_extra_chains` is True, the contacts from additional - non-mapped model chains are added to the reference's total. + The score is computed as conserved contacts divided by the total contacts + in the reference using the :attr:`oligo_lddt_scorer`, which uses the full + complex as reference/model structure. If :attr:`penalize_extra_chains` is + True, the reference/model complexes contain all chains (otherwise only the + mapped ones) and additional contacts are added to the reference's total + contacts for unmapped model chains according to the :attr:`chem_mapping`. The main difference with :attr:`weighted_lddt` is that the lDDT scorer "sees" the full complex here (incl. inter-chain contacts), while the @@ -1058,236 +1087,159 @@ class OligoLDDTScorer(object): :type: :class:`float` """ if self._oligo_lddt is None: - lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, a, b, c = \ - self.lddt_scorer.lDDT(self.mdl, - thresholds = self.settings.cutoffs, - chain_mapping = self.chain_mapping, - no_interchain=False, - penalize_extra_chains=self.penalize_extra_chains, - residue_mapping=self.alignments, - local_lddt_prop="oligo_lddt", - local_contact_prop="oligo_contact", - return_dist_test=True) - self._oligo_lddt = lDDT - self._oligo_lddt_tot = lDDT_tot - self._oligo_lddt_cons = lDDT_cons + LogInfo('Reference %s has: %s chains' \ + % (self.ref.GetName(), self.ref.chain_count)) + LogInfo('Model %s has: %s chains' \ + % (self.mdl.GetName(), self.mdl.chain_count)) + + # score with or w/o extra-chain penalty + if self.penalize_extra_chains: + denominator = self.oligo_lddt_scorer.total_contacts + denominator += self._GetModelPenalty() + if denominator > 0: + oligo_lddt = self.oligo_lddt_scorer.conserved_contacts \ + / float(denominator) + else: + oligo_lddt = 0.0 + else: + oligo_lddt = self.oligo_lddt_scorer.global_score + self._oligo_lddt = oligo_lddt return self._oligo_lddt @property - def oligo_lddt_tot(self): - """Number of total contacts used for oligo_lddt + def weighted_lddt(self): + """Weighted average of single chain lDDT scores. + + The score is computed as a weighted average of single chain lDDT scores + (see :attr:`sc_lddt_scorers`) using the total contacts of each single + reference chain as weights. If :attr:`penalize_extra_chains` is True, + unmapped chains are added with a 0 score and total contacts taken from + the actual reference chains or (for unmapped model chains) using the + :attr:`chem_mapping`. - Potentially includes penalty contacts from non-mapped model chains + See :attr:`oligo_lddt` for a comparison of the two scores. :getter: Computed on first use (cached) - :type: :class:`int` + :type: :class:`float` """ - if self._oligo_lddt_tot is None: - yolo = self.oligo_lddt - assert(self._oligo_lddt_tot is not None) - return self._oligo_lddt_tot + if self._weighted_lddt is None: + scores = [s.global_score for s in self.sc_lddt_scorers] + weights = [s.total_contacts for s in self.sc_lddt_scorers] + nominator = sum([s * w for s, w in zip(scores, weights)]) + if self.penalize_extra_chains: + ref_scorers = self._GetRefScorers() + denominator = sum(s.total_contacts for s in list(ref_scorers.values())) + denominator += self._GetModelPenalty() + else: + denominator = sum(weights) + if denominator > 0: + self._weighted_lddt = nominator / float(denominator) + else: + self._weighted_lddt = 0.0 + return self._weighted_lddt @property - def oligo_lddt_cons(self): - """Number of conserved contacts used for oligo_lddt + def lddt_ref(self): + """The reference entity used for oligomeric lDDT scoring + (:attr:`oligo_lddt` / :attr:`oligo_lddt_scorer`). + + Since the lDDT computation requires a single chain with mapped residue + numbering, all chains of :attr:`ref` are appended into a single chain X with + unique residue numbers according to the column-index in the alignment. The + alignments are in the same order as they appear in :attr:`alignments`. + Additional residues are appended at the end of the chain with unique residue + numbers. Unmapped chains are only added if :attr:`penalize_extra_chains` is + True. Only CA atoms are considered if :attr:`calpha_only` is True. :getter: Computed on first use (cached) - :type: :class:`int` + :type: :class:`~ost.mol.EntityHandle` """ - if self._oligo_lddt_cons is None: - yolo = self.oligo_lddt - assert(self._oligo_lddt_cons is not None) - return self._oligo_lddt_cons - + if self._lddt_ref is None: + self._PrepareOligoEntities() + return self._lddt_ref + @property - def oligo_lddt_per_res(self): - """Per residue scores based on oligo_lddt + def lddt_mdl(self): + """The model entity used for oligomeric lDDT scoring + (:attr:`oligo_lddt` / :attr:`oligo_lddt_scorer`). - Each scored residue gets a dict with keys: - ["residue_number", "residue_name", "chain", "lddt", "conserved_contacts", - "total_contacts"] - The first three uniquely identify the residue and refer to the residue in - self.ref. + Like :attr:`lddt_ref`, this is a single chain X containing all chains of + :attr:`mdl`. The residue numbers match the ones in :attr:`lddt_ref` where + aligned and have unique numbers for additional residues. :getter: Computed on first use (cached) - :type: :class:`list` of :class:`dict` + :type: :class:`~ost.mol.EntityHandle` """ - if self._oligo_lddt_per_res is None: - yolo = self.oligo_lddt # trigger oligo_lddt computation to assign scores - # and contacts as generic properties on residues - self._oligo_lddt_per_res = self._GetPerResidueScores(self.alignments, - "oligo_lddt", - "oligo_contact_cons", - "oligo_contact_exp") - return self._oligo_lddt_per_res + if self._lddt_mdl is None: + self._PrepareOligoEntities() + return self._lddt_mdl @property - def sc_lddt(self): - """List of global lDDT score for each chain mapping in self.alignments. + def oligo_lddt_scorer(self): + """lDDT Scorer object for :attr:`lddt_ref` and :attr:`lddt_mdl`. :getter: Computed on first use (cached) - :type: :class:`list` of :class:`float` + :type: :class:`~ost.mol.alg.lDDTScorer` """ - if self._sc_lddt is None: - yolo = self.weighted_lddt # sc_lddt is computed as a side product - assert(self._sc_lddt is not None) - assert(self._sc_lddt_tot is not None) - assert(self._sc_lddt_cons is not None) - return self._sc_lddt + if self._oligo_lddt_scorer is None: + self._oligo_lddt_scorer = lDDTScorer( + references=[self.lddt_ref.Select("")], + model=self.lddt_mdl.Select(""), + settings=self.settings) + return self._oligo_lddt_scorer @property - def sc_lddt_tot(self): - """Number of total contacts for each chain mapping in self.alignments + def mapped_lddt_scorers(self): + """List of scorer objects for each chain mapped in :attr:`alignments`. :getter: Computed on first use (cached) - :type: :class:`list` of :class:`int` + :type: :class:`list` of :class:`MappedLDDTScorer` """ - if self._sc_lddt_tot is None: - yolo = self.sc_lddt # sc_lddt_tot is computed as a sideproduct - assert(self._sc_lddt_tot is not None) - return self._sc_lddt_tot - - @property - def sc_lddt_cons(self): - """Number of conserved contacts for each chain mapping in self.alignments - - :getter: Computed on first use (cached) - :type: :class:`list` of :class:`int` - """ - if self._sc_lddt_cons is None: - yolo = self.sc_lddt # sc_lddt_tot is computed as a sideproduct - assert(self._sc_lddt_cons is not None) - return self._sc_lddt_cons + if self._mapped_lddt_scorers is None: + self._mapped_lddt_scorers = list() + for aln in self.alignments: + mapped_lddt_scorer = MappedLDDTScorer(aln, self.calpha_only, + self.settings) + self._mapped_lddt_scorers.append(mapped_lddt_scorer) + return self._mapped_lddt_scorers @property - def sc_lddt_per_res(self): - """Per residue scores based on sc_lddt + def sc_lddt_scorers(self): + """List of lDDT scorer objects extracted from :attr:`mapped_lddt_scorers`. - Each scored residue gets a dict with keys: - ["residue_number", "residue_name", "chain", "lddt", "conserved_contacts", - "total_contacts"] - The first three uniquely identify the residue and refer to the residue in - self.ref. - - :getter: Computed on first use (cached) - :type: :class:`list` of :class:`dict` + :type: :class:`list` of :class:`~ost.mol.alg.lDDTScorer` """ - if self._sc_lddt_per_res is None: - yolo = self.sc_lddt # trigger sc_lddt computation to assign scores - # and contacts as generic properties on residues - self._sc_lddt_per_res = self._GetPerResidueScores(self.alignments, - "sc_lddt", - "sc_contact_cons", - "sc_contact_exp") - return self._sc_lddt_per_res + return [mls.lddt_scorer for mls in self.mapped_lddt_scorers] @property - def weighted_lddt(self): - """Weighted average of single chain lDDT scores. + def sc_lddt(self): + """List of global scores extracted from :attr:`sc_lddt_scorers`. - The score is computed as a weighted average of single chain lDDT scores. - In principle thats oligo_lddt without inter-chain contacts. - (see :attr:`sc_lddt_scorers`). Chains in ref which are not mapped - penalize the overall score as their contacts are not conserved. - Chains in mdl which are not mapped only penalize the score if - :attr:`penalize_extra_chains` is True. + If scoring for a mapped chain fails, an error is displayed and a score of 0 + is assigned. :getter: Computed on first use (cached) - :type: :class:`float` + :type: :class:`list` of :class:`float` """ - if self._weighted_lddt is None: - lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, per_res_exp, \ - per_res_conserved = \ - self.lddt_scorer.lDDT(self.mdl, - thresholds = self.settings.cutoffs, - chain_mapping = self.chain_mapping, - no_interchain=True, - penalize_extra_chains=self.penalize_extra_chains, - residue_mapping=self.alignments, - local_lddt_prop="sc_lddt", - local_contact_prop="sc_contact", - return_dist_test=True) - self._weighted_lddt = lDDT - - # we directly use the results from above to also compute the - # single chain lDDTs manually + if self._sc_lddt is None: self._sc_lddt = list() - self._sc_lddt_tot = list() - self._sc_lddt_cons = list() - chain_res_indices = dict() - residues = self.mdl.residues - for i, r_idx in enumerate(res_indices): - r = residues[r_idx] - ch = r.GetChain().GetName() - if ch not in chain_res_indices: - chain_res_indices[ch] = list() - chain_res_indices[ch].append(i) - - n_thresholds = len(self.settings.cutoffs) - for aln in self.alignments: - ch = aln.GetSequence(0).GetName() - cons = int(np.sum(per_res_conserved.take(chain_res_indices[ch], axis=0))) - tot = self.lddt_scorer.GetNChainContacts(ch, no_interchain=True) - tot*=n_thresholds - if tot > 0: - self._sc_lddt.append(float(cons)/tot) - else: - self._sc_lddt.append(0) - self._sc_lddt_tot.append(tot) - self._sc_lddt_cons.append(cons) - return self._weighted_lddt - - @property - def lddt_scorer(self): - if self._lddt_scorer is None: - if not conop.GetDefaultLib(): - raise RuntimeError("OligolDDT computation requires a compound library!") - r = self.settings.radius - seq_sep = self.settings.sequence_separation - self._lddt_scorer = lddt.lDDTScorer(self.ref, - inclusion_radius = r, - sequence_separation = seq_sep) - return self._lddt_scorer - - @property - def chain_mapping(self): - if self._chain_mapping is None: - # chain mapping as required by lddt_scorer - # key: model chain, value: reference chain - self._chain_mapping = dict() - for aln in self.alignments: - ref_seq = aln.GetSequence(0) - mdl_seq = aln.GetSequence(1) - self._chain_mapping[mdl_seq.GetName()] = ref_seq.GetName() - return self._chain_mapping + for lddt_scorer in self.sc_lddt_scorers: + try: + self._sc_lddt.append(lddt_scorer.global_score) + except Exception as ex: + LogError('Single chain lDDT failed:', str(ex)) + self._sc_lddt.append(0.0) + return self._sc_lddt ############################################################################## # Class internal helpers ############################################################################## - @staticmethod - def _GetPerResidueScores(alignments, lddt_prop, cons_prop, tot_prop): - per_residue_sc = list() - for aln in alignments: - reference_chain = aln.GetSequence(0).GetName() - for col in aln: - ref_res = col.GetResidue(0) - mdl_res = col.GetResidue(1) - if ref_res.IsValid() and mdl_res.IsValid(): - if mdl_res.HasProp(lddt_prop) and mdl_res.HasProp(cons_prop) and \ - mdl_res.HasProp(tot_prop): - num = ref_res.GetNumber().GetNum() - name = ref_res.GetName() - score = mdl_res.GetFloatProp(lddt_prop) - cons = mdl_res.GetIntProp(cons_prop) - tot = mdl_res.GetIntProp(tot_prop) - per_residue_sc.append({"residue_number": num, - "residue_name": name, - "chain": reference_chain, - "lddt": score, - "conserved_contacts": cons, - "total_contacts": tot}) - return per_residue_sc + def _PrepareOligoEntities(self): + # simple wrapper to avoid code duplication + self._lddt_ref, self._lddt_mdl = _MergeAlignedChains( + self.alignments, self.ref, self.mdl, self.calpha_only, + self.penalize_extra_chains) @staticmethod def _GetUnmappedMdlChains(mdl, alignments): @@ -1296,6 +1248,224 @@ class OligoLDDTScorer(object): mapped_mdl_chains = set(aln.GetSequence(1).GetName() for aln in alignments) return (mdl_chains - mapped_mdl_chains) + def _GetRefScorers(self): + # single chain lddt scorers for each reference chain (key = chain name) + if self._ref_scorers is None: + # collect from mapped_lddt_scorers + ref_scorers = dict() + for mapped_lddt_scorer in self.mapped_lddt_scorers: + ref_ch_name = mapped_lddt_scorer.reference_chain_name + ref_scorers[ref_ch_name] = mapped_lddt_scorer.lddt_scorer + # add new ones where needed + for ch in self.ref.chains: + if ch.name not in ref_scorers: + if self.calpha_only: + ref_chain = ch.Select('aname=CA') + else: + ref_chain = ch.Select('') + ref_scorers[ch.name] = lDDTScorer( + references=[ref_chain], + model=ref_chain, + settings=self.settings) + # store in cache + self._ref_scorers = ref_scorers + # fetch from cache + return self._ref_scorers + + def _GetModelPenalty(self): + # extra value to add to total number of distances for extra model chains + # -> estimated from chem-mapped reference chains + if self._model_penalty is None: + # sanity check + if self.chem_mapping is None: + raise RuntimeError("Must provide chem_mapping when requesting penalty " + "for extra model chains!") + # get cached ref_scorers + ref_scorers = self._GetRefScorers() + # get unmapped model chains + unmapped_mdl_chains = self._GetUnmappedMdlChains(self.mdl, self.alignments) + # map extra chains to ref. chains + model_penalty = 0 + for ch_name_mdl in sorted(unmapped_mdl_chains): + # get penalty for chain + cur_penalty = None + for cm_ref, cm_mdl in self.chem_mapping.items(): + if ch_name_mdl in cm_mdl: + # penalize by an average of the chem. mapped ref. chains + cur_penalty = 0 + for ch_name_ref in cm_ref: + # assumes that total_contacts is cached (for speed) + cur_penalty += ref_scorers[ch_name_ref].total_contacts + cur_penalty /= float(len(cm_ref)) + break + # report penalty + if cur_penalty is None: + LogWarning('Extra MODEL chain %s could not be chemically mapped to ' + 'any chain in REFERENCE, lDDT cannot consider it!' \ + % ch_name_mdl) + else: + LogScript('Extra MODEL chain %s added to lDDT score by considering ' + 'chemically mapped chains in REFERENCE.' % ch_name_mdl) + model_penalty += cur_penalty + # store in cache + self._model_penalty = model_penalty + # fetch from cache + return self._model_penalty + + +class MappedLDDTScorer(object): + """A simple class to calculate a single-chain lDDT score on a given chain to + chain mapping as extracted from :class:`OligoLDDTScorer`. + + :param alignment: Sets :attr:`alignment` + :param calpha_only: Sets :attr:`calpha_only` + :param settings: Sets :attr:`settings` + + .. attribute:: alignment + + Alignment with two sequences named according to the mapped chains and with + views attached to both sequences (e.g. one of the items of + :attr:`QSscorer.alignments`). + + The first sequence is assumed to be the reference and the second one the + model. Since the lDDT score is not symmetric (extra residues in model are + ignored), the order is important. + + :type: :class:`~ost.seq.AlignmentHandle` + + .. attribute:: calpha_only + + If True, restricts lDDT score to CA only. + + :type: :class:`bool` + + .. attribute:: settings + + Settings to use for lDDT scoring. + + :type: :class:`~ost.mol.alg.lDDTSettings` + + .. attribute:: lddt_scorer + + lDDT Scorer object for the given chains. + + :type: :class:`~ost.mol.alg.lDDTScorer` + + .. attribute:: reference_chain_name + + Chain name of the reference. + + :type: :class:`str` + + .. attribute:: model_chain_name + + Chain name of the model. + + :type: :class:`str` + """ + def __init__(self, alignment, calpha_only, settings): + # prepare fields + self.alignment = alignment + self.calpha_only = calpha_only + self.settings = settings + self.lddt_scorer = None # set in _InitScorer + self.reference_chain_name = alignment.sequences[0].name + self.model_chain_name = alignment.sequences[1].name + self._old_number_label = "old_num" + self._extended_alignment = None # set in _InitScorer + # initialize lDDT scorer + self._InitScorer() + + def GetPerResidueScores(self): + """ + :return: Scores for each residue + :rtype: :class:`list` of :class:`dict` with one item for each residue + existing in model and reference: + + - "residue_number": Residue number in reference chain + - "residue_name": Residue name in reference chain + - "lddt": local lDDT + - "conserved_contacts": number of conserved contacts + - "total_contacts": total number of contacts + """ + scores = list() + assigned_residues = list() + # Make sure the score is calculated + self.lddt_scorer.global_score + for col in self._extended_alignment: + if col[0] != "-" and col.GetResidue(3).IsValid(): + ref_res = col.GetResidue(0) + mdl_res = col.GetResidue(1) + ref_res_renum = col.GetResidue(2) + mdl_res_renum = col.GetResidue(3) + if ref_res.one_letter_code != ref_res_renum.one_letter_code: + raise RuntimeError("Reference residue name mapping inconsistent: %s != %s" % + (ref_res.one_letter_code, + ref_res_renum.one_letter_code)) + if mdl_res.one_letter_code != mdl_res_renum.one_letter_code: + raise RuntimeError("Model residue name mapping inconsistent: %s != %s" % + (mdl_res.one_letter_code, + mdl_res_renum.one_letter_code)) + if ref_res.GetNumber().num != ref_res_renum.GetIntProp(self._old_number_label): + raise RuntimeError("Reference residue number mapping inconsistent: %s != %s" % + (ref_res.GetNumber().num, + ref_res_renum.GetIntProp(self._old_number_label))) + if mdl_res.GetNumber().num != mdl_res_renum.GetIntProp(self._old_number_label): + raise RuntimeError("Model residue number mapping inconsistent: %s != %s" % + (mdl_res.GetNumber().num, + mdl_res_renum.GetIntProp(self._old_number_label))) + if ref_res.qualified_name in assigned_residues: + raise RuntimeError("Duplicated residue in reference: " % + (ref_res.qualified_name)) + else: + assigned_residues.append(ref_res.qualified_name) + # check if property there (may be missing for CA-only) + if mdl_res_renum.HasProp(self.settings.label): + scores.append({ + "residue_number": ref_res.GetNumber().num, + "residue_name": ref_res.name, + "lddt": mdl_res_renum.GetFloatProp(self.settings.label), + "conserved_contacts": mdl_res_renum.GetFloatProp(self.settings.label + "_conserved"), + "total_contacts": mdl_res_renum.GetFloatProp(self.settings.label + "_total")}) + return scores + + ############################################################################## + # Class internal helpers (anything that doesnt easily work without this class) + ############################################################################## + + def _InitScorer(self): + # Use copy of alignment (extended by 2 extra sequences for renumbering) + aln = self.alignment.Copy() + # Get chains and renumber according to alignment (for lDDT) + reference = Renumber( + aln.GetSequence(0), + old_number_label=self._old_number_label).CreateFullView() + refseq = seq.CreateSequence( + "reference_renumbered", + aln.GetSequence(0).GetString()) + refseq.AttachView(reference) + aln.AddSequence(refseq) + model = Renumber( + aln.GetSequence(1), + old_number_label=self._old_number_label).CreateFullView() + modelseq = seq.CreateSequence( + "model_renumbered", + aln.GetSequence(1).GetString()) + modelseq.AttachView(model) + aln.AddSequence(modelseq) + # Filter to CA-only if desired (done after AttachView to not mess it up) + if self.calpha_only: + self.lddt_scorer = lDDTScorer( + references=[reference.Select('aname=CA')], + model=model.Select('aname=CA'), + settings=self.settings) + else: + self.lddt_scorer = lDDTScorer( + references=[reference], + model=model, + settings=self.settings) + # Store alignment for later + self._extended_alignment = aln ############################################################################### # HELPERS @@ -2720,6 +2890,101 @@ def _GetQsSuperposition(alns): res = mol.alg.SuperposeSVD(view_1, view_2, apply_transform=False) return res + +def _AddResidue(edi, res, rnum, chain, calpha_only): + """ + Add residue *res* with res. num. *run* to given *chain* using editor *edi*. + Either all atoms added or (if *calpha_only*) only CA. + """ + if calpha_only: + ca_atom = res.FindAtom('CA') + if ca_atom.IsValid(): + new_res = edi.AppendResidue(chain, res.name, rnum) + edi.InsertAtom(new_res, ca_atom.name, ca_atom.pos) + else: + new_res = edi.AppendResidue(chain, res.name, rnum) + for atom in res.atoms: + edi.InsertAtom(new_res, atom.name, atom.pos) + +def _MergeAlignedChains(alns, ent_1, ent_2, calpha_only, penalize_extra_chains): + """ + Create two new entities (based on the alignments attached views) where all + residues have same numbering (when they're aligned) and they are all pushed to + a single chain X. Also append extra chains contained in *ent_1* and *ent_2* + but not contained in *alns*. + + Used for :attr:`QSscorer.lddt_ref` and :attr:`QSscorer.lddt_mdl` + + :param alns: List of alignments with attached views (first sequence: *ent_1*, + second: *ent_2*). Residue number in single chain is column index + of current alignment + sum of lengths of all previous alignments + (order of alignments as in input list). + :type alns: See :attr:`QSscorer.alignments` + :param ent_1: First entity to process. + :type ent_1: :class:`~ost.mol.EntityHandle` + :param ent_2: Second entity to process. + :type ent_2: :class:`~ost.mol.EntityHandle` + :param calpha_only: If True, we only include CA atoms instead of all. + :type calpha_only: :class:`bool` + :param penalize_extra_chains: If True, extra chains are added to model and + reference. Otherwise, only mapped ones. + :type penalize_extra_chains: :class:`bool` + + :return: Tuple of two single chain entities (from *ent_1* and from *ent_2*) + :rtype: :class:`tuple` of :class:`~ost.mol.EntityHandle` + """ + # first new entity + ent_ren_1 = mol.CreateEntity() + ed_1 = ent_ren_1.EditXCS() + new_chain_1 = ed_1.InsertChain('X') + # second one + ent_ren_2 = mol.CreateEntity() + ed_2 = ent_ren_2.EditXCS() + new_chain_2 = ed_2.InsertChain('X') + # the alignment already contains sorted chains + rnum = 0 + chain_done_1 = set() + chain_done_2 = set() + for aln in alns: + chain_done_1.add(aln.GetSequence(0).name) + chain_done_2.add(aln.GetSequence(1).name) + for col in aln: + rnum += 1 + # add valid residues to single chain entities + res_1 = col.GetResidue(0) + if res_1.IsValid(): + _AddResidue(ed_1, res_1, rnum, new_chain_1, calpha_only) + res_2 = col.GetResidue(1) + if res_2.IsValid(): + _AddResidue(ed_2, res_2, rnum, new_chain_2, calpha_only) + # extra chains? + if penalize_extra_chains: + for chain in ent_1.chains: + if chain.name in chain_done_1: + continue + for res in chain.residues: + rnum += 1 + _AddResidue(ed_1, res, rnum, new_chain_1, calpha_only) + for chain in ent_2.chains: + if chain.name in chain_done_2: + continue + for res in chain.residues: + rnum += 1 + _AddResidue(ed_2, res, rnum, new_chain_2, calpha_only) + # get entity names + ent_ren_1.SetName(aln.GetSequence(0).GetAttachedView().GetName()) + ent_ren_2.SetName(aln.GetSequence(1).GetAttachedView().GetName()) + # connect atoms + if not conop.GetDefaultLib(): + raise RuntimeError("QSscore computation requires a compound library!") + pr = conop.RuleBasedProcessor(conop.GetDefaultLib()) + pr.Process(ent_ren_1) + ed_1.UpdateICS() + pr.Process(ent_ren_2) + ed_2.UpdateICS() + return ent_ren_1, ent_ren_2 + + # specify public interface __all__ = ('QSscoreError', 'QSscorer', 'QSscoreEntity', 'FilterContacts', - 'GetContacts', 'OligoLDDTScorer') + 'GetContacts', 'OligoLDDTScorer', 'MappedLDDTScorer') diff --git a/modules/mol/alg/pymod/wrap_mol_alg.cc b/modules/mol/alg/pymod/wrap_mol_alg.cc index 8d57c57dfc06b634d82705c4a879085c89edc3a9..e1b80398615764e98b67b912ae0e75758b24c136 100644 --- a/modules/mol/alg/pymod/wrap_mol_alg.cc +++ b/modules/mol/alg/pymod/wrap_mol_alg.cc @@ -172,9 +172,6 @@ object lDDTSettingsInitWrapper(tuple args, dict kwargs){ label); } -/* -lDDTScorer is commented out to not collide with the new lDDTScorer class -that lives in Python object lDDTScorerInitWrapper(tuple args, dict kwargs){ object self = args[0]; @@ -220,7 +217,7 @@ object lDDTScorerInitWrapper(tuple args, dict kwargs){ model, settings); } -*/ + void clean_lddt_references_wrapper(const list& reference_list) { @@ -268,10 +265,6 @@ list get_lddt_per_residue_stats_wrapper(mol::EntityView& model, return local_scores_list; } -/* -lDDTScorer is commented out to not collide with the new lDDTScorer class -that lives in Python - list get_local_scores_wrapper(mol::alg::lDDTScorer& scorer) { std::vector<mol::alg::lDDTLocalScore> scores = scorer.GetLocalScores(); list local_scores_list; @@ -289,7 +282,7 @@ list get_references_wrapper(mol::alg::lDDTScorer& scorer) { } return local_references_list; } -*/ + void print_lddt_per_residue_stats_wrapper(list& scores, bool structural_checks, int cutoffs_size){ int scores_length = boost::python::extract<int>(scores.attr("__len__")()); @@ -402,9 +395,6 @@ BOOST_PYTHON_MODULE(_ost_mol_alg) .def_readwrite("conserved_dist", &mol::alg::lDDTLocalScore::conserved_dist) .def_readwrite("total_dist", &mol::alg::lDDTLocalScore::total_dist); - /* - lDDTScorer is commented out to not collide with the new lDDTScorer class - that lives in Python class_<mol::alg::lDDTScorer>("lDDTScorer", no_init) .def("__init__", raw_function(lDDTScorerInitWrapper)) .def(init<std::vector<mol::EntityView>&, mol::EntityView&, mol::alg::lDDTSettings&>()) @@ -416,7 +406,6 @@ BOOST_PYTHON_MODULE(_ost_mol_alg) .def_readonly("model", &mol::alg::lDDTScorer::model_view) .add_property("references", &get_references_wrapper) .add_property("is_valid", &mol::alg::lDDTScorer::IsValid); - */ class_<mol::alg::StereoChemicalProps>("StereoChemicalProps", init<mol::alg::StereoChemicalParams&, diff --git a/modules/mol/alg/tests/test_qsscoring.py b/modules/mol/alg/tests/test_qsscoring.py index 1810ad054ba1a0924475f4dcd1fefe3abb80c9f7..31daffe565ed5ac3d3cb8cf59ecb76dee7e62534 100644 --- a/modules/mol/alg/tests/test_qsscoring.py +++ b/modules/mol/alg/tests/test_qsscoring.py @@ -424,12 +424,13 @@ class TestQSscore(unittest.TestCase): lddt_oligo_scorer2 = qs_scorer2.GetOligoLDDTScorer(lddt_settings, False) self.assertAlmostEqual(qs_scorer2.global_score, 0.171, 2) self.assertAlmostEqual(qs_scorer2.best_score, 1.00, 2) - self.assertAlmostEqual(lddt_oligo_scorer2.weighted_lddt, 0.4996, 2) + self.assertAlmostEqual(lddt_oligo_scorer2.weighted_lddt, 1.00, 2) self.assertEqual(len(lddt_oligo_scorer2.sc_lddt), 2) self.assertAlmostEqual(lddt_oligo_scorer2.sc_lddt[0], 1.00, 2) self.assertAlmostEqual(lddt_oligo_scorer2.sc_lddt[1], 1.00, 2) - self.assertAlmostEqual(lddt_oligo_scorer2.oligo_lddt, 0.4496, 2) - # penalty only affects additional model chains, scores are thus the same + # without penalty we don't see extra chains + self.assertAlmostEqual(lddt_oligo_scorer2.oligo_lddt, 1.00, 2) + # with penalty we account for extra reference chains lddt_oligo_scorer2_pen = qs_scorer2.GetOligoLDDTScorer(lddt_settings, True) self.assertAlmostEqual(lddt_oligo_scorer2_pen.oligo_lddt, 0.4496, 2) self.assertAlmostEqual(lddt_oligo_scorer2_pen.weighted_lddt, 0.50, 2)