diff --git a/modules/mol/alg/pymod/qsscoring.py b/modules/mol/alg/pymod/qsscoring.py index aa55203a930c5ee20372df794a38a2bfc2a9af55..937c40910ec1a57d8c4301e47ede0f6aedb815e6 100644 --- a/modules/mol/alg/pymod/qsscoring.py +++ b/modules/mol/alg/pymod/qsscoring.py @@ -449,12 +449,19 @@ class QSscorer: self._clustalw_bin = settings.Locate(('clustalw', 'clustalw2')) return self._clustalw_bin - def GetOligoLDDTScorer(self, settings): + def GetOligoLDDTScorer(self, settings, penalize_extra_chains=True): """ :return: :class:`OligoLDDTScorer` object, setup for this QS scoring problem. """ - return OligoLDDTScorer(self.qs_ent_1.ent, self.qs_ent_2.ent, - self.alignments, self.calpha_only, settings) + # TODO: DOCUMENT PARAMS + if penalize_extra_chains: + return OligoLDDTScorer(self.qs_ent_1.ent, self.qs_ent_2.ent, + self.alignments, self.calpha_only, settings, + True, self.chem_mapping) + else: + return OligoLDDTScorer(self.qs_ent_1.ent, self.qs_ent_2.ent, + self.alignments, self.calpha_only, settings, False) + ############################################################################## # Class internal helpers (anything that doesnt easily work without this class) @@ -841,18 +848,40 @@ class OligoLDDTScorer(object): """A simple class to calculate oligomeric lDDT score.""" # TODO: DOCUMENT - - def __init__(self, ref, mdl, alignments, calpha_only, settings): - if mdl.chain_count > ref.chain_count: - LogWarning('MODEL contains more chains than REFERENCE, ' - 'lDDT is not considering them') + # -> make sure to mention assumption on sequence naming in alignments + + # TODO: one could also allow computation of both penalized and unpenalized + # in same object -> must regenerate lddt_ref / lddt_mdl though + + def __init__(self, ref, mdl, alignments, calpha_only, settings, + penalize_extra_chains=False, chem_mapping=None): + # sanity checks + if chem_mapping is None and penalize_extra_chains: + raise RuntimeError("Must provide chem_mapping when requesting penalty " + "for extra chains!") + if not penalize_extra_chains: + # warn for unmapped model chains + unmapped_mdl_chains = self._GetUnmappedMdlChains(mdl, alignments) + if unmapped_mdl_chains: + LogWarning('MODEL contains chains unmapped to REFERENCE, ' + 'lDDT is not considering MODEL chains %s' \ + % str(list(unmapped_mdl_chains))) + # warn for unmapped reference chains + ref_chains = set(ch.name for ch in ref.chains) + mapped_ref_chains = set(aln.GetSequence(0).GetName() for aln in alignments) + unmapped_ref_chains = (ref_chains - mapped_ref_chains) + if unmapped_ref_chains: + LogWarning('REFERENCE contains chains unmapped to MODEL, ' + 'lDDT is not considering REFERENCE chains %s' \ + % str(list(unmapped_ref_chains))) # prepare fields self.ref = ref self.mdl = mdl self.alignments = alignments self.calpha_only = calpha_only self.settings = settings - self._old_number_label = "old_num" + self.penalize_extra_chains = penalize_extra_chains + self.chem_mapping = chem_mapping self._sc_lddt = None self._oligo_lddt = None self._weighted_lddt = None @@ -860,6 +889,8 @@ class OligoLDDTScorer(object): self._lddt_mdl = None self._oligo_lddt_scorer = None self._mapped_lddt_scorers = None + self._ref_scorers = None + self._model_penalty = None @property def lddt_ref(self): @@ -880,8 +911,18 @@ class OligoLDDTScorer(object): LogInfo('Reference %s has: %s chains' % (self.ref.GetName(), self.ref.chain_count)) LogInfo('Model %s has: %s chains' % (self.mdl.GetName(), self.mdl.chain_count)) - # score them (mdl and ref changed) and keep results - self._oligo_lddt = self.oligo_lddt_scorer.global_score + # score with or w/o extra-chain penalty + if self.penalize_extra_chains: + denominator = self.oligo_lddt_scorer.total_contacts + denominator += self._GetModelPenalty() + if denominator > 0: + oligo_lddt = self.oligo_lddt_scorer.conserved_contacts \ + / float(denominator) + else: + oligo_lddt = 0.0 + else: + oligo_lddt = self.oligo_lddt_scorer.global_score + self._oligo_lddt = oligo_lddt return self._oligo_lddt @property @@ -925,7 +966,12 @@ class OligoLDDTScorer(object): scores = [s.global_score for s in self.sc_lddt_scorers] weights = [s.total_contacts for s in self.sc_lddt_scorers] nominator = sum([s * w for s, w in zip(scores, weights)]) - denominator = sum(weights) + if self.penalize_extra_chains: + ref_scorers = self._GetRefScorers() + denominator = sum(s.total_contacts for s in ref_scorers.values()) + denominator += self._GetModelPenalty() + else: + denominator = sum(weights) if denominator > 0: self._weighted_lddt = nominator / float(denominator) else: @@ -938,10 +984,80 @@ class OligoLDDTScorer(object): def _PrepareOligoEntities(self): # simple wrapper to avoid code duplication - self._lddt_ref, self._lddt_mdl = _MergeAlignedChains(self.alignments, - self.ref, - self.mdl, - self.calpha_only) + self._lddt_ref, self._lddt_mdl = _MergeAlignedChains( + self.alignments, self.ref, self.mdl, self.calpha_only, + self.penalize_extra_chains) + + def _GetUnmappedMdlChains(self, mdl, alignments): + # TODO: maybe move out of class since it doesn't need self + # assume model is second sequence in alignment and is named by chain + mdl_chains = set(ch.name for ch in mdl.chains) + mapped_mdl_chains = set(aln.GetSequence(1).GetName() for aln in alignments) + return (mdl_chains - mapped_mdl_chains) + + def _GetRefScorers(self): + # single chain lddt scorers for each reference chain (key = chain name) + if self._ref_scorers is None: + # collect from mapped_lddt_scorers + ref_scorers = dict() + for mapped_lddt_scorer in self.mapped_lddt_scorers: + ref_ch_name = mapped_lddt_scorer.reference_chain_name + ref_scorers[ref_ch_name] = mapped_lddt_scorer.lddt_scorer + # add new ones where needed + for ch in self.ref.chains: + if ch.name not in ref_scorers: + if self.calpha_only: + ref_chain = ch.Select('aname=CA') + else: + ref_chain = ch.Select('') + ref_scorers[ch.name] = lDDTScorer( + references=[ref_chain], + model=ref_chain, + settings=self.settings) + # store in cache + self._ref_scorers = ref_scorers + # fetch from cache + return self._ref_scorers + + def _GetModelPenalty(self): + # extra value to add to total number of distances for extra model chains + # -> estimated from chem-mapped reference chains + if self._model_penalty is None: + # sanity check + if self.chem_mapping is None: + raise RuntimeError("Must provide chem_mapping when requesting penalty " + "for extra model chains!") + # get cached ref_scorers + ref_scorers = self._GetRefScorers() + # get unmapped model chains + unmapped_mdl_chains = self._GetUnmappedMdlChains(self.mdl, self.alignments) + # map extra chains to ref. chains + model_penalty = 0 + for ch_name_mdl in sorted(unmapped_mdl_chains): + # get penalty for chain + cur_penalty = None + for cm_ref, cm_mdl in self.chem_mapping.iteritems(): + if ch_name_mdl in cm_mdl: + # penalize by an average of the chem. mapped ref. chains + cur_penalty = 0 + for ch_name_ref in cm_ref: + # assumes that total_contacts is cached (for speed) + cur_penalty += ref_scorers[ch_name_ref].total_contacts + cur_penalty /= float(len(cm_ref)) + break + # report penalty + if cur_penalty is None: + LogWarning('Extra MODEL chain %s could not be chemically mapped to ' + 'any chain in REFERENCE, lDDT cannot consider it!' \ + % ch_name_mdl) + else: + LogScript('Extra MODEL chain %s added to lDDT score by considering ' + 'chemically mapped chains in REFERENCE.' % ch_name_mdl) + model_penalty += cur_penalty + # store in cache + self._model_penalty = model_penalty + # fetch from cache + return self._model_penalty class MappedLDDTScorer(object): @@ -2521,7 +2637,7 @@ def _AddResidue(edi, res, rnum, chain, calpha_only): for atom in res.atoms: edi.InsertAtom(new_res, atom.name, atom.pos) -def _MergeAlignedChains(alns, ent_1, ent_2, calpha_only): +def _MergeAlignedChains(alns, ent_1, ent_2, calpha_only, penalize_extra_chains): """ Create two new entities (based on the alignments attached views) where all residues have same numbering (when they're aligned) and they are all pushed to @@ -2541,6 +2657,9 @@ def _MergeAlignedChains(alns, ent_1, ent_2, calpha_only): :type ent_2: :class:`~ost.mol.EntityHandle` :param calpha_only: If True, we only include CA atoms instead of all. :type calpha_only: :class:`bool` + :param penalize_extra_chains: If True, extra chains are added to model and + reference. Otherwise, only mapped ones. + :type penalize_extra_chains: :class:`bool` :return: Tuple of two single chain entities (from *ent_1* and from *ent_2*) :rtype: :class:`tuple` of :class:`~ost.mol.EntityHandle` @@ -2569,19 +2688,20 @@ def _MergeAlignedChains(alns, ent_1, ent_2, calpha_only): res_2 = col.GetResidue(1) if res_2.IsValid(): _AddResidue(ed_2, res_2, rnum, new_chain_2, calpha_only) - # extra chains - for chain in ent_1.chains: - if chain.name in chain_done_1: - continue - for res in chain.residues: - rnum += 1 - _AddResidue(ed_1, res, rnum, new_chain_1, calpha_only) - for chain in ent_2.chains: - if chain.name in chain_done_2: - continue - for res in chain.residues: - rnum += 1 - _AddResidue(ed_2, res, rnum, new_chain_2, calpha_only) + # extra chains? + if penalize_extra_chains: + for chain in ent_1.chains: + if chain.name in chain_done_1: + continue + for res in chain.residues: + rnum += 1 + _AddResidue(ed_1, res, rnum, new_chain_1, calpha_only) + for chain in ent_2.chains: + if chain.name in chain_done_2: + continue + for res in chain.residues: + rnum += 1 + _AddResidue(ed_2, res, rnum, new_chain_2, calpha_only) # get entity names ent_ren_1.SetName(aln.GetSequence(0).GetAttachedView().GetName()) ent_ren_2.SetName(aln.GetSequence(1).GetAttachedView().GetName()) diff --git a/modules/mol/alg/tests/test_qsscoring.py b/modules/mol/alg/tests/test_qsscoring.py index 840e0077f9606326908ac0c149278862b9699c67..c241eea49dea794a7f620014561a482019b3507a 100644 --- a/modules/mol/alg/tests/test_qsscoring.py +++ b/modules/mol/alg/tests/test_qsscoring.py @@ -377,23 +377,30 @@ class TestQSscore(unittest.TestCase): # TEST EXTRA SCORES def test_lDDT(self): - # lDDT is not symmetrical and does not account for overprediction! + # check for penalized and unpenalized oligo lDDT ref = _LoadFile('4br6.1.pdb').Select('cname=A,B') mdl = _LoadFile('4br6.1.pdb') lddt_settings = lDDTSettings() qs_scorer = QSscorer(ref, mdl) - lddt_oligo_scorer = qs_scorer.GetOligoLDDTScorer(lddt_settings) + lddt_oligo_scorer = qs_scorer.GetOligoLDDTScorer(lddt_settings, False) self.assertAlmostEqual(qs_scorer.global_score, 0.171, 2) self.assertAlmostEqual(qs_scorer.best_score, 1.00, 2) self.assertAlmostEqual(lddt_oligo_scorer.oligo_lddt, 1.00, 2) + # with penalty we account for extra model chains + lddt_oligo_scorer_pen = qs_scorer.GetOligoLDDTScorer(lddt_settings, True) + self.assertAlmostEqual(lddt_oligo_scorer_pen.oligo_lddt, 0.5213, 2) # flip them (use QSscoreEntity to go faster) qs_scorer2 = QSscorer(qs_scorer.qs_ent_2, qs_scorer.qs_ent_1, res_num_alignment=True) - lddt_oligo_scorer2 = qs_scorer2.GetOligoLDDTScorer(lddt_settings) + lddt_oligo_scorer2 = qs_scorer2.GetOligoLDDTScorer(lddt_settings, False) self.assertAlmostEqual(qs_scorer2.global_score, 0.171, 2) self.assertAlmostEqual(qs_scorer2.best_score, 1.00, 2) - self.assertAlmostEqual(lddt_oligo_scorer2.oligo_lddt, 0.4496, 2) + # without penalty we don't see extra chains + self.assertAlmostEqual(lddt_oligo_scorer2.oligo_lddt, 1.00, 2) + # with penalty we account for extra reference chains + lddt_oligo_scorer2_pen = qs_scorer2.GetOligoLDDTScorer(lddt_settings, True) + self.assertAlmostEqual(lddt_oligo_scorer2_pen.oligo_lddt, 0.4496, 2) # check properties self.assertFalse(qs_scorer.calpha_only) self.assertEqual(qs_scorer.chem_mapping, {('B', 'A'): ('B', 'C', 'D', 'A')})