diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures index 6121d05103fce84047d159a4d485b8a07d1e5a57..136e952a1ee7376f92ecab3eee6ca7fe878b7e2d 100644 --- a/actions/ost-compare-structures +++ b/actions/ost-compare-structures @@ -17,7 +17,6 @@ from ost import PushVerbosityLevel from ost.mol.alg import (qsscoring, Molck, MolckSettings, lDDTSettings, lDDTScorer, CheckStructure) from ost.conop import CompoundLib -from ost.seq.alg.renumber import Renumber class _DefaultStereochemicalParamAction(argparse.Action): @@ -504,27 +503,21 @@ def _Main(): label="lddt") if opts.verbosity > 3: lddt_settings.PrintParameters() - # Perform single chain scoring - # Get chains from mapped alignments - lddt_scorers = list() - for aln in qs_scorer.alignments: + + oligo_lddt_scorer = qsscoring.OligoLDDTScorer( + qs_scorer.qs_ent_1.ent, + qs_scorer.qs_ent_2.ent, + qs_scorer.alignments, + qs_scorer.calpha_only, + lddt_settings) + for lddt_scorer in oligo_lddt_scorer.sc_lddt_scorers: # Get chains and renumber according to alignment (for lDDT) - ch_ref = aln.GetSequence(0).GetName() - reference = Renumber(aln.GetSequence(0)).CreateFullView() - ch_mdl = aln.GetSequence(1).GetName() - model = Renumber(aln.GetSequence(1)).CreateFullView() - ost.LogInfo(("Computing lDDT between model chain %s and " - "reference chain %s") % (ch_mdl, ch_ref)) - lddt_scorer = lDDTScorer( - references=[reference], - model=model, - settings=lddt_settings) try: lddt_results["single_chain_lddt"].append({ "status": "SUCCESS", "error": "", - "model_chain": ch_mdl, - "reference_chain": ch_ref, + "model_chain": lddt_scorer.model.chains[0].GetName(), + "reference_chain": lddt_scorer.references[0].chains[0].GetName(), "global_score": lddt_scorer.global_score, "conserved_contacts": lddt_scorer.conserved_contacts, "total_contacts": lddt_scorer.total_contacts}) @@ -533,24 +526,17 @@ def _Main(): lddt_results["single_chain_lddt"].append({ "status": "FAILURE", "error": str(ex), - "model_chain": ch_mdl, - 
"reference_chain": ch_ref, + "model_chain": lddt_scorer.model.chains[0].GetName(), + "reference_chain": lddt_scorer.references[0].chains[0].GetName(), "global_score": 0.0, "conserved_contacts": 0.0, "total_contacts": 0.0}) - lddt_scorers.append(lddt_scorer) # perform oligo lddt scoring try: - oligo_lddt_scorer = qsscoring.OligoLDDTScorer( - qs_scorer.qs_ent_1.ent, - qs_scorer.qs_ent_2.ent, - qs_scorer.alignments, - qs_scorer.calpha_only, - lddt_settings) lddt_results["oligo_lddt"] = { "status": "SUCCESS", "error": "", - "global_score": oligo_lddt_scorer.lddt} + "global_score": oligo_lddt_scorer.oligo_lddt} except Exception as ex: ost.LogError('Oligo lDDT failed:', str(ex)) lddt_results["oligo_lddt"] = { @@ -558,11 +544,10 @@ def _Main(): "error": str(ex), "global_score": 0.0} try: - weighted_lddt = _AveragelDDT(lddt_scorers) lddt_results["weighted_lddt"] = { "status": "SUCCESS", "error": "", - "global_score": weighted_lddt} + "global_score": oligo_lddt_scorer.weighted_lddt} except Exception as ex: lddt_results["weighted_lddt"] = { "status": "FAILURE", diff --git a/modules/mol/alg/pymod/qsscoring.py b/modules/mol/alg/pymod/qsscoring.py index 094d8b3aa701e6554d9a732afa91f2b7d883c7fc..c32a456432f27ef03ebc632a848a4d6112368002 100644 --- a/modules/mol/alg/pymod/qsscoring.py +++ b/modules/mol/alg/pymod/qsscoring.py @@ -20,6 +20,7 @@ from ost import mol, geom, conop, seq, settings from ost import LogError, LogWarning, LogScript, LogInfo, LogVerbose, LogDebug from ost.bindings.clustalw import ClustalW from ost.mol.alg import lDDTScorer +from ost.seq.alg.renumber import Renumber import numpy as np from scipy.misc import factorial from scipy.special import binom @@ -835,16 +836,17 @@ class OligoLDDTScorer(object): # get single chain reference and model self.ref = ref self.mdl = mdl - self.calpha_only = calpha_only self.alignments = alignments + self.calpha_only = calpha_only self.settings = settings - self._lddt = None + self._sc_lddt = None + self._oligo_lddt = None + 
@property
def sc_lddt_scorers(self):
  """List of single chain lDDT scorers, one per entry in self.alignments.

  For every alignment, sequence 0 is taken as the reference chain and
  sequence 1 as the model chain. Both are passed through Renumber (so that
  residue numbers follow the alignment, as needed for lDDT) and a full view
  of each is handed to an lDDTScorer together with self.settings.

  The list is built on first access and cached in self._sc_lddt_scorers.
  It has the same order as self.alignments and is empty (not None) if there
  are no alignments.
  """
  if self._sc_lddt_scorers is None:
    # The list must be created once, *before* the loop: creating it inside
    # the loop would throw away all scorers but the last one, and would
    # leave the cache at None when there are no alignments at all.
    scorers = list()
    for aln in self.alignments:
      # Get chains and renumber according to alignment (for lDDT)
      ch_ref = aln.GetSequence(0).GetName()
      reference = Renumber(aln.GetSequence(0)).CreateFullView()
      ch_mdl = aln.GetSequence(1).GetName()
      model = Renumber(aln.GetSequence(1)).CreateFullView()
      LogInfo(("Computing lDDT between model chain %s and "
               "reference chain %s") % (ch_mdl, ch_ref))
      lddt_scorer = lDDTScorer(
        references=[reference],
        model=model,
        settings=self.settings)
      scorers.append(lddt_scorer)
    # Only cache once every scorer was set up, so that a failure half-way
    # through does not leave a truncated list behind.
    self._sc_lddt_scorers = scorers
  return self._sc_lddt_scorers
@property
def weighted_lddt(self):
  """Average of the single chain lDDT scores, weighted by contact counts.

  Every scorer in self.sc_lddt_scorers contributes its global_score with a
  weight equal to its total_contacts. The result is 0.0 when the summed
  weights are not positive. Computed on first access and cached in
  self._weighted_lddt.
  """
  if self._weighted_lddt is not None:
    return self._weighted_lddt
  chain_scorers = self.sc_lddt_scorers
  # collect all scores first, then all weights (same order as the scorers)
  chain_scores = [scorer.global_score for scorer in chain_scorers]
  contact_counts = [scorer.total_contacts for scorer in chain_scorers]
  weighted_sum = sum(score * count
                     for score, count in zip(chain_scores, contact_counts))
  total_weight = sum(contact_counts)
  if total_weight > 0:
    self._weighted_lddt = weighted_sum / float(total_weight)
  else:
    # nothing to weight by: report a score of zero
    self._weighted_lddt = 0.0
  return self._weighted_lddt
&mol::alg::lDDTScorer::model_view) + .add_property("references", &get_references_wrapper) .add_property("is_valid", &mol::alg::lDDTScorer::IsValid); class_<mol::alg::StereoChemicalProps>("StereoChemicalProps", diff --git a/modules/mol/alg/src/local_dist_diff_test.cc b/modules/mol/alg/src/local_dist_diff_test.cc index 676c955e10651a7b7bc70165b9da24d1c133e095..321f27e28a60a485c407021df3274d8db4ecd319 100644 --- a/modules/mol/alg/src/local_dist_diff_test.cc +++ b/modules/mol/alg/src/local_dist_diff_test.cc @@ -605,6 +605,10 @@ void lDDTScorer::PrintPerResidueStats(){ settings.cutoffs.size()); } +std::vector<EntityView> lDDTScorer::GetReferences(){ + return references_view; +} + void lDDTScorer::_PrepareReferences(std::vector<EntityHandle>& references){ for (unsigned int i = 0; i < references.size(); i++) { if (settings.sel != ""){ diff --git a/modules/mol/alg/src/local_dist_diff_test.hh b/modules/mol/alg/src/local_dist_diff_test.hh index e155f749cacad95353d7362164b71c5b245b3317..c48627203fc247b56ed8526d634ddd8e6bc1ae49 100644 --- a/modules/mol/alg/src/local_dist_diff_test.hh +++ b/modules/mol/alg/src/local_dist_diff_test.hh @@ -112,6 +112,7 @@ class lDDTScorer std::vector<lDDTLocalScore> GetLocalScores(); int GetNumConservedContacts(); // number of conserved distances in the model int GetNumTotalContacts(); // the number of total distances in the reference structure + std::vector<EntityView> GetReferences(); void PrintPerResidueStats(); bool IsValid();