diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures index cd2fc76aa739ad56a6d680fbcc4644e03ddf6da5..649522d99f41c1a22223de46ec12be2fcd4565a6 100644 --- a/actions/ost-compare-structures +++ b/actions/ost-compare-structures @@ -781,22 +781,48 @@ def _AlnToFastaStr(aln): s2 = aln.GetSequence(1) return f">reference:{s1.name}\n{str(s1)}\n>model:{s2.name}\n{str(s2)}" -def _GetInconsistentResidues(alns): +def _GetInconsistentResidues(alns, ref, mdl): lst = list() for aln in alns: + + ref_ch = ref.FindChain(aln.GetSequence(0).GetName()) + mdl_ch = mdl.FindChain(aln.GetSequence(1).GetName()) + + if not ref_ch.IsValid(): + raise RuntimeError("ref lacks requested chain in _GetInconsistentResidues") + + if not mdl_ch.IsValid(): + raise RuntimeError("mdl lacks requested chain in _GetInconsistentResidues") + + if len(aln.GetSequence(0).GetGaplessString()) != ref_ch.GetResidueCount(): + raise RuntimeError("aln/chain mismatch in _GetAlignedResidues") + if len(aln.GetSequence(1).GetGaplessString()) != mdl_ch.GetResidueCount(): + raise RuntimeError("aln/chain mismatch in _GetAlignedResidues") + + ref_res = ref_ch.residues + mdl_res = mdl_ch.residues + + ref_res_idx = 0 + mdl_res_idx = 0 + for col in aln: - r1 = col.GetResidue(0) - r2 = col.GetResidue(1) - if r1.IsValid() and r2.IsValid() and r1.GetName() != r2.GetName(): - ch_1 = r1.GetChain().name - num_1 = r1.number.num - ins_code_1 = r1.number.ins_code.strip("\u0000") - id_1 = f"{ch_1}.{num_1}.{ins_code_1}" - ch_2 = r2.GetChain().name - num_2 = r2.number.num - ins_code_2 = r2.number.ins_code.strip("\u0000") - id_2 = f"{ch_2}.{num_2}.{ins_code_2}" - lst.append(f"{id_1}-{id_2}") + if col[0] != '-' and col[1] != '-': + r1 = ref_res[ref_res_idx] + r2 = mdl_res[mdl_res_idx] + if r1.IsValid() and r2.IsValid() and r1.GetName() != r2.GetName(): + ch_1 = r1.GetChain().name + num_1 = r1.number.num + ins_code_1 = r1.number.ins_code.strip("\u0000") + id_1 = f"{ch_1}.{num_1}.{ins_code_1}" + ch_2 = r2.GetChain().name + 
num_2 = r2.number.num + ins_code_2 = r2.number.ins_code.strip("\u0000") + id_2 = f"{ch_2}.{num_2}.{ins_code_2}" + lst.append(f"{id_1}-{id_2}") + if col[0] != '-': + ref_res_idx += 1 + if col[1] != '-': + mdl_res_idx += 1 return lst def _LocalScoresToJSONDict(score_dict): @@ -842,29 +868,52 @@ def _PatchScoresToJSONList(interface_dict, score_dict): json_list.append(_RoundOrNone(item)) return json_list -def _GetAlignedResidues(aln): +def _GetAlignedResidues(aln, ref, mdl): aligned_residues = list() for a in aln: mdl_lst = list() ref_lst = list() + + ref_ch = ref.FindChain(a.GetSequence(0).GetName()) + mdl_ch = mdl.FindChain(a.GetSequence(1).GetName()) + + if not ref_ch.IsValid(): + raise RuntimeError("ref lacks requested chain in _GetAlignedResidues") + + if not mdl_ch.IsValid(): + raise RuntimeError("mdl lacks requested chain in _GetAlignedResidues") + + if len(a.GetSequence(0).GetGaplessString()) != ref_ch.GetResidueCount(): + raise RuntimeError("aln/chain mismatch in _GetAlignedResidues") + if len(a.GetSequence(1).GetGaplessString()) != mdl_ch.GetResidueCount(): + raise RuntimeError("aln/chain mismatch in _GetAlignedResidues") + + ref_res = ref_ch.residues + mdl_res = mdl_ch.residues + + ref_res_idx = 0 + mdl_res_idx = 0 + for c in a: - mdl_r = c.GetResidue(1) - ref_r = c.GetResidue(0) - if mdl_r.IsValid(): + if c[1] != '-': + mdl_r = mdl_res[mdl_res_idx] olc = mdl_r.one_letter_code num = mdl_r.GetNumber().num ins_code = mdl_r.GetNumber().ins_code.strip("\u0000") mdl_lst.append({"olc": olc, "num": f"{num}.{ins_code}"}) + mdl_res_idx += 1 else: mdl_lst.append(None) - if ref_r.IsValid(): + if c[0] != '-': + ref_r = ref_res[ref_res_idx] olc = ref_r.one_letter_code num = ref_r.GetNumber().num ins_code = ref_r.GetNumber().ins_code.strip("\u0000") ref_lst.append({"olc": olc, "num": f"{num}.{ins_code}"}) + ref_res_idx += 1 else: ref_lst.append(None) @@ -901,7 +950,7 @@ def _Process(model, reference, args, model_format, reference_format): mdl_map_pep_seqid_thr = 
args.chem_map_seqid_thresh, mdl_map_nuc_seqid_thr = args.chem_map_seqid_thresh) - ir = _GetInconsistentResidues(scorer.aln) + ir = _GetInconsistentResidues(scorer.aln, scorer.target, scorer.model) if len(ir) > 0 and args.enforce_consistency: raise RuntimeError(f"Inconsistent residues observed: {' '.join(ir)}") @@ -916,13 +965,16 @@ def _Process(model, reference, args, model_format, reference_format): out["inconsistent_residues"] = ir if args.dump_aligned_residues: - out["aligned_residues"] = _GetAlignedResidues(scorer.aln) + out["aligned_residues"] = _GetAlignedResidues(scorer.aln, scorer.target, + scorer.model) if args.dump_pepnuc_alns: out["pepnuc_aln"] = [_AlnToFastaStr(aln) for aln in scorer.pepnuc_aln] if args.dump_pepnuc_aligned_residues: - out["pepnuc_aligned_residues"] = _GetAlignedResidues(scorer.pepnuc_aln) + out["pepnuc_aligned_residues"] = _GetAlignedResidues(scorer.pepnuc_aln, + scorer.pepnuc_target, + scorer.pepnuc_model) if args.lddt: out["lddt"] = _RoundOrNone(scorer.lddt) diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py index 9347d31618512aeb8630832767c22c30ca07eedf..04d2787456ce730517b1f652f4fddd1af42740c3 100644 --- a/modules/mol/alg/pymod/chain_mapping.py +++ b/modules/mol/alg/pymod/chain_mapping.py @@ -113,8 +113,7 @@ class MappingResult: Each alignment is accessible with ``alns[(t_chain,m_chain)]``. First sequence is the sequence of :attr:`target` chain, second sequence the - one from :attr:`~model`. The respective :class:`ost.mol.EntityView` are - attached with :func:`ost.seq.ConstSequenceHandle.AttachView`. + one from :attr:`~model`. 
:type: :class:`dict` with key: :class:`tuple` of :class:`str`, value: :class:`ost.seq.AlignmentHandle` @@ -588,6 +587,12 @@ class ChainMapper: self.n_max_naive = n_max_naive self.mdl_map_pep_seqid_thr = mdl_map_pep_seqid_thr self.mdl_map_nuc_seqid_thr = mdl_map_nuc_seqid_thr + self.pep_subst_mat = pep_subst_mat + self.pep_gap_open = pep_gap_open + self.pep_gap_ext = pep_gap_ext + self.nuc_subst_mat = nuc_subst_mat + self.nuc_gap_open = nuc_gap_open + self.nuc_gap_ext = nuc_gap_ext # lazy computed attributes self._chem_groups = None @@ -595,15 +600,6 @@ class ChainMapper: self._chem_group_ref_seqs = None self._chem_group_types = None - # helper class to generate pairwise alignments - self.aligner = _Aligner(resnum_aln = resnum_alignments, - pep_subst_mat = pep_subst_mat, - pep_gap_open = pep_gap_open, - pep_gap_ext = pep_gap_ext, - nuc_subst_mat = nuc_subst_mat, - nuc_gap_open = nuc_gap_open, - nuc_gap_ext = nuc_gap_ext) - # target structure preprocessing self._target, self._polypep_seqs, self._polynuc_seqs = \ self.ProcessStructure(target) @@ -669,14 +665,7 @@ class ChainMapper: :type: :class:`ost.seq.AlignmentList` """ if self._chem_group_alignments is None: - self._chem_group_alignments, self._chem_group_types = \ - _GetChemGroupAlignments(self.polypep_seqs, self.polynuc_seqs, - self.aligner, - pep_seqid_thr=self.pep_seqid_thr, - min_pep_length=self.min_pep_length, - nuc_seqid_thr=self.nuc_seqid_thr, - min_nuc_length=self.min_nuc_length) - + self._SetChemGroupAlignments() return self._chem_group_alignments @property @@ -694,7 +683,6 @@ class ChainMapper: for a in self.chem_group_alignments: s = seq.CreateSequence(a.GetSequence(0).GetName(), a.GetSequence(0).GetGaplessString()) - s.AttachView(a.GetSequence(0).GetAttachedView()) self._chem_group_ref_seqs.AddSequence(s) return self._chem_group_ref_seqs @@ -710,14 +698,7 @@ class ChainMapper: :type: :class:`list` of :class:`ost.mol.ChemType` """ if self._chem_group_types is None: - self._chem_group_alignments, 
self._chem_group_types = \ - _GetChemGroupAlignments(self.polypep_seqs, self.polynuc_seqs, - self.aligner, - pep_seqid_thr=self.pep_seqid_thr, - min_pep_length=self.min_pep_length, - nuc_seqid_thr=self.nuc_seqid_thr, - min_nuc_length=self.min_nuc_length) - + self._SetChemGroupAlignments() return self._chem_group_types def GetChemMapping(self, model): @@ -744,12 +725,11 @@ class ChainMapper: alns = [seq.AlignmentList() for x in self.chem_groups] for s in mdl_pep_seqs: - idx, aln = _MapSequence(self.chem_group_ref_seqs, - self.chem_group_types, - s, mol.ChemType.AMINOACIDS, - self.aligner, - seq_id_thr = self.mdl_map_pep_seqid_thr, - min_aln_length = self.min_pep_length) + idx, aln = self._MapSequence(self.chem_group_ref_seqs, + self.chem_group_types, + s, mol.ChemType.AMINOACIDS, mdl, + seq_id_thr = self.mdl_map_pep_seqid_thr, + min_aln_length = self.min_pep_length) if idx is None: unmapped_mdl_chains.append(s.GetName()) else: @@ -757,12 +737,11 @@ class ChainMapper: alns[idx].append(aln) for s in mdl_nuc_seqs: - idx, aln = _MapSequence(self.chem_group_ref_seqs, - self.chem_group_types, - s, mol.ChemType.NUCLEOTIDES, - self.aligner, - seq_id_thr = self.mdl_map_nuc_seqid_thr, - min_aln_length = self.min_nuc_length) + idx, aln = self._MapSequence(self.chem_group_ref_seqs, + self.chem_group_types, + s, mol.ChemType.NUCLEOTIDES, mdl, + seq_id_thr = self.mdl_map_nuc_seqid_thr, + min_aln_length = self.min_nuc_length) if idx is None: unmapped_mdl_chains.append(s.GetName()) else: @@ -878,8 +857,6 @@ class ChainMapper: for ref_ch, mdl_ch in zip(ref_group, mdl_group): if ref_ch is not None and mdl_ch is not None: aln = ref_mdl_alns[(ref_ch, mdl_ch)] - aln.AttachView(0, _CSel(self.target, [ref_ch])) - aln.AttachView(1, _CSel(mdl, [mdl_ch])) alns[(ref_ch, mdl_ch)] = aln return MappingResult(self.target, mdl, self.chem_groups, chem_mapping, unmapped_mdl_chains, one_to_one, alns) @@ -921,8 +898,6 @@ class ChainMapper: for ref_ch, mdl_ch in zip(ref_group, mdl_group): if ref_ch is 
not None and mdl_ch is not None: aln = ref_mdl_alns[(ref_ch, mdl_ch)] - aln.AttachView(0, _CSel(self.target, [ref_ch])) - aln.AttachView(1, _CSel(mdl, [mdl_ch])) alns[(ref_ch, mdl_ch)] = aln return MappingResult(self.target, mdl, self.chem_groups, chem_mapping, @@ -1018,8 +993,6 @@ class ChainMapper: for ref_ch, mdl_ch in zip(ref_group, mdl_group): if ref_ch is not None and mdl_ch is not None: aln = ref_mdl_alns[(ref_ch, mdl_ch)] - aln.AttachView(0, _CSel(self.target, [ref_ch])) - aln.AttachView(1, _CSel(mdl, [mdl_ch])) alns[(ref_ch, mdl_ch)] = aln return MappingResult(self.target, mdl, self.chem_groups, chem_mapping, unmapped_mdl_chains, one_to_one, alns) @@ -1062,8 +1035,6 @@ class ChainMapper: for ref_ch, mdl_ch in zip(ref_group, mdl_group): if ref_ch is not None and mdl_ch is not None: aln = ref_mdl_alns[(ref_ch, mdl_ch)] - aln.AttachView(0, _CSel(self.target, [ref_ch])) - aln.AttachView(1, _CSel(mdl, [mdl_ch])) alns[(ref_ch, mdl_ch)] = aln return MappingResult(self.target, mdl, self.chem_groups, chem_mapping, @@ -1138,8 +1109,6 @@ class ChainMapper: for ref_ch, mdl_ch in zip(ref_group, mdl_group): if ref_ch is not None and mdl_ch is not None: aln = ref_mdl_alns[(ref_ch, mdl_ch)] - aln.AttachView(0, _CSel(self.target, [ref_ch])) - aln.AttachView(1, _CSel(mdl, [mdl_ch])) alns[(ref_ch, mdl_ch)] = aln return MappingResult(self.target, mdl, self.chem_groups, chem_mapping, unmapped_mdl_chains, one_to_one, alns) @@ -1210,8 +1179,6 @@ class ChainMapper: for ref_ch, mdl_ch in zip(ref_group, mdl_group): if ref_ch is not None and mdl_ch is not None: aln = ref_mdl_alns[(ref_ch, mdl_ch)] - aln.AttachView(0, _CSel(self.target, [ref_ch])) - aln.AttachView(1, _CSel(mdl, [mdl_ch])) alns[(ref_ch, mdl_ch)] = aln return MappingResult(self.target, mdl, self.chem_groups, chem_mapping, @@ -1600,7 +1567,6 @@ class ChainMapper: s = ''.join([r.one_letter_code for r in ch.residues]) s = seq.CreateSequence(ch.GetName(), s) - s.AttachView(_CSel(view, [ch.GetName()])) if n_pep == n_res: 
polypep_seqs.AddSequence(s) elif n_nuc == n_res: @@ -1616,16 +1582,17 @@ class ChainMapper: f"- mapping failed") # select for chains for which we actually extracted the sequence - chain_names = [s.GetAttachedView().chains[0].name for s in polypep_seqs] - chain_names += [s.GetAttachedView().chains[0].name for s in polynuc_seqs] + chain_names = [s.name for s in polypep_seqs] + chain_names += [s.name for s in polynuc_seqs] view = _CSel(view, chain_names) return (view, polypep_seqs, polynuc_seqs) - def Align(self, s1, s2, stype): + def NWAlign(self, s1, s2, stype): """ Access to internal sequence alignment functionality - Alignment parameterization is setup at ChainMapper construction + Performs Needleman-Wunsch alignment with parameterization + setup at ChainMapper construction :param s1: First sequence to align - must have view attached in case of resnum_alignments @@ -1638,59 +1605,11 @@ class ChainMapper: :class:`ost.mol.ChemType.NUCLEOTIDES`] :returns: Pairwise alignment of s1 and s2 """ - if stype not in [mol.ChemType.AMINOACIDS, mol.ChemType.NUCLEOTIDES]: - raise RuntimeError("stype must be ost.mol.ChemType.AMINOACIDS or " - "ost.mol.ChemType.NUCLEOTIDES") - return self.aligner.Align(s1, s2, chem_type = stype) - - -# INTERNAL HELPERS -################## -class _Aligner: - def __init__(self, pep_subst_mat = seq.alg.BLOSUM62, pep_gap_open = -5, - pep_gap_ext = -2, nuc_subst_mat = seq.alg.NUC44, - nuc_gap_open = -4, nuc_gap_ext = -4, resnum_aln = False): - """ Helper class to compute alignments - - Sets default values for substitution matrix, gap open and gap extension - penalties. They are only used in default mode (Needleman-Wunsch aln). - If *resnum_aln* is True, only residue numbers of views that are attached - to input sequences are considered. 
- """ - self.pep_subst_mat = pep_subst_mat - self.pep_gap_open = pep_gap_open - self.pep_gap_ext = pep_gap_ext - self.nuc_subst_mat = nuc_subst_mat - self.nuc_gap_open = nuc_gap_open - self.nuc_gap_ext = nuc_gap_ext - self.resnum_aln = resnum_aln - - def Align(self, s1, s2, chem_type=None): - if self.resnum_aln: - return self.ResNumAlign(s1, s2) - else: - if chem_type is None: - raise RuntimeError("Must specify chem_type for NW alignment") - return self.NWAlign(s1, s2, chem_type) - - def NWAlign(self, s1, s2, chem_type): - """ Returns pairwise alignment using Needleman-Wunsch algorithm - - :param s1: First sequence to align - :type s1: :class:`ost.seq.SequenceHandle` - :param s2: Second sequence to align - :type s2: :class:`ost.seq.SequenceHandle` - :param chem_type: Must be in [:class:`ost.mol.ChemType.AMINOACIDS`, - :class:`ost.mol.ChemType.NUCLEOTIDES`], determines - substitution matrix and gap open/extension penalties - :type chem_type: :class:`ost.mol.ChemType` - :returns: Alignment with s1 as first and s2 as second sequence - """ - if chem_type == mol.ChemType.AMINOACIDS: + if stype == mol.ChemType.AMINOACIDS: return seq.alg.SemiGlobalAlign(s1, s2, self.pep_subst_mat, gap_open=self.pep_gap_open, gap_ext=self.pep_gap_ext)[0] - elif chem_type == mol.ChemType.NUCLEOTIDES: + elif stype == mol.ChemType.NUCLEOTIDES: return seq.alg.SemiGlobalAlign(s1, s2, self.nuc_subst_mat, gap_open=self.nuc_gap_open, gap_ext=self.nuc_gap_ext)[0] @@ -1698,45 +1617,214 @@ class _Aligner: raise RuntimeError("Invalid ChemType") return aln - def ResNumAlign(self, s1, s2): - """ Returns pairwise alignment using residue numbers of attached views - - Assumes that there are no insertion codes (alignment only on numerical - component) and that resnums are strictly increasing (fast min/max - identification). These requirements are assured if a structure has been - processed by :class:`ChainMapper`. 
+ def ResNumAlign(self, s1, s2, s1_ent, s2_ent): + """ Access to internal sequence alignment functionality + + Performs residue number based alignment. Residue numbers are extracted + from *s1_ent*/*s2_ent*. - :param s1: First sequence to align, must have :class:`ost.mol.EntityView` - attached + :param s1: First sequence to align :type s1: :class:`ost.seq.SequenceHandle` - :param s2: Second sequence to align, must have :class:`ost.mol.EntityView` - attached + :param s2: Second sequence to align :type s2: :class:`ost.seq.SequenceHandle` + :param s1_ent: Structure as source for residue numbers in *s1*. Must + contain a chain named after sequence name in *s1*. This + chain must have the exact same number of residues as + characters in s1. + :type s1_ent: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` + :param s2_ent: Same for *s2*. + :type s2_ent: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` + + :returns: Pairwise alignment of s1 and s2 """ - assert(s1.HasAttachedView()) - assert(s2.HasAttachedView()) - v1 = s1.GetAttachedView() - rnums1 = [r.GetNumber().GetNum() for r in v1.residues] - v2 = s2.GetAttachedView() - rnums2 = [r.GetNumber().GetNum() for r in v2.residues] + ch1 = s1_ent.FindChain(s1.name) + ch2 = s2_ent.FindChain(s2.name) + + if not ch1.IsValid(): + raise RuntimeError("s1_ent lacks requested chain in ResNumAlign") + + if not ch2.IsValid(): + raise RuntimeError("s2_ent lacks requested chain in ResNumAlign") + + if len(s1) != ch1.GetResidueCount(): + raise RuntimeError("Sequence/structure mismatch in ResNumAlign") + + if len(s2) != ch2.GetResidueCount(): + raise RuntimeError("Sequence/structure mismatch in ResNumAlign") + + rnums1 = [r.GetNumber().GetNum() for r in ch1.residues] + rnums2 = [r.GetNumber().GetNum() for r in ch2.residues] min_num = min(rnums1[0], rnums2[0]) max_num = max(rnums1[-1], rnums2[-1]) aln_length = max_num - min_num + 1 aln_s1 = ['-'] * aln_length - for r, rnum in zip(v1.residues, rnums1): - 
aln_s1[rnum-min_num] = r.one_letter_code + for olc, rnum in zip(s1, rnums1): + aln_s1[rnum-min_num] = olc aln_s2 = ['-'] * aln_length - for r, rnum in zip(v2.residues, rnums2): - aln_s2[rnum-min_num] = r.one_letter_code + for olc, rnum in zip(s2, rnums2): + aln_s2[rnum-min_num] = olc aln = seq.CreateAlignment() aln.AddSequence(seq.CreateSequence(s1.GetName(), ''.join(aln_s1))) aln.AddSequence(seq.CreateSequence(s2.GetName(), ''.join(aln_s2))) return aln + def _SetChemGroupAlignments(self): + """Sets self._chem_group_alignments and self._chem_group_types + """ + pep_groups = self._GroupSequences(self.polypep_seqs, self.pep_seqid_thr, + self.min_pep_length, + mol.ChemType.AMINOACIDS) + nuc_groups = self._GroupSequences(self.polynuc_seqs, self.nuc_seqid_thr, + self.min_nuc_length, + mol.ChemType.NUCLEOTIDES) + group_types = [mol.ChemType.AMINOACIDS] * len(pep_groups) + group_types += [mol.ChemType.NUCLEOTIDES] * len(nuc_groups) + groups = pep_groups + groups.extend(nuc_groups) + self._chem_group_alignments = groups + self._chem_group_types = group_types + + def _GroupSequences(self, seqs, seqid_thr, min_length, chem_type): + """Get list of alignments representing groups of equivalent sequences + + :param seqid_thr: Threshold used to decide when two chains are identical. + :type seqid_thr: :class:`float` + :param min_length: Additional threshold to avoid gappy alignments with high + seqid. Number of aligned columns must be at least this + number. + :type min_length: :class:`int` + :param seqs: Sequences to group + :type seqs: :class:`ost.seq.SequenceList` + :param chem_type: ChemType of seqs, must be in + [:class:`ost.mol.ChemType.AMINOACIDS`, + :class:`ost.mol.ChemType.NUCLEOTIDES`] + :type chem_type: :class:`ost.mol.ChemType` + :returns: A list of alignments, one alignment for each group + with longest sequence (reference) as first sequence. 
+ """ + groups = list() + for s_idx in range(len(seqs)): + matching_group = None + for g_idx in range(len(groups)): + for g_s_idx in range(len(groups[g_idx])): + if self.resnum_alignments: + aln = self.ResNumAlign(seqs[s_idx], + seqs[groups[g_idx][g_s_idx]], + self.target, self.target) + else: + aln = self.NWAlign(seqs[s_idx], + seqs[groups[g_idx][g_s_idx]], + chem_type) + sid, n_aligned = _GetAlnPropsOne(aln) + if sid >= seqid_thr and n_aligned >= min_length: + matching_group = g_idx + break + if matching_group is not None: + break + + if matching_group is None: + groups.append([s_idx]) + else: + groups[matching_group].append(s_idx) + + # sort based on sequence length + sorted_groups = list() + for g in groups: + if len(g) > 1: + tmp = sorted([[len(seqs[i]), i] for i in g], reverse=True) + sorted_groups.append([x[1] for x in tmp]) + else: + sorted_groups.append(g) + + # translate from indices back to sequences and directly generate alignments + # of the groups with the longest (first) sequence as reference + aln_list = seq.AlignmentList() + for g in sorted_groups: + if len(g) == 1: + # aln with one single sequence + aln_list.append(seq.CreateAlignment(seqs[g[0]])) + else: + # obtain pairwise aln of first sequence (reference) to all others + alns = seq.AlignmentList() + i = g[0] + for j in g[1:]: + if self.resnum_alignments: + aln = self.ResNumAlign(seqs[i], seqs[j], + self.target, self.target) + else: + aln = self.NWAlign(seqs[i], seqs[j], chem_type) + alns.append(aln) + # and merge + aln_list.append(seq.alg.MergePairwiseAlignments(alns, seqs[i])) + + return aln_list + + def _MapSequence(self, ref_seqs, ref_types, s, s_type, s_ent, + seq_id_thr=0.0, min_aln_length=0): + """Tries to map *s* onto any of the sequences in *ref_seqs* + + Computes alignments of *s* to each of the reference sequences of equal type + and sorts them by seqid*fraction_covered (seqid: sequence identity of + aligned columns in alignment, 
fraction_covered: Fraction of non-gap + characters in reference sequence that are covered by non-gap characters in + *s*). Best scoring mapping is returned. Optionally, *seq_id*/ + *min_aln_length* thresholds can be enabled to avoid non-sensical mappings. + However, *min_aln_length* only has an effect if *seq_id_thr* > 0!!! + + :param ref_seqs: Reference sequences + :type ref_seqs: :class:`ost.seq.SequenceList` + :param ref_types: Types of reference sequences, e.g. + ost.mol.ChemType.AminoAcids + :type ref_types: :class:`list` of :class:`ost.mol.ChemType` + :param s: Sequence to map + :type s: :class:`ost.seq.SequenceHandle` + :param s_type: Type of *s*, only try mapping to sequences in *ref_seqs* + with equal type as defined in *ref_types* + :param s_ent: Entity which represents *s*. Only relevant in case of + residue number alignments. + :type s_ent: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView` + :param seq_id_thr: Minimum sequence identity to be considered as match + :type seq_id_thr: :class:`float` + :param min_aln_length: Minimum number of aligned columns to be considered + as match. Only has an effect if *seq_id_thr* > 0! + :type min_aln_length: :class:`int` + :returns: Tuple with two elements. 1) index of sequence in *ref_seqs* to + which *s* can be mapped 2) Pairwise sequence alignment with + sequence from *ref_seqs* as first sequence. Both elements are + None if no mapping can be found or if thresholds are not + fulfilled for any alignment. + :raises: :class:`RuntimeError` if mapping is ambiguous, i.e. 
*s* + successfully maps to more than one sequence in *ref_seqs* + """ + scored_alns = list() + for ref_idx, ref_seq in enumerate(ref_seqs): + if ref_types[ref_idx] == s_type: + if self.resnum_alignments: + aln = self.ResNumAlign(ref_seq, s, self.target, s_ent) + else: + aln = self.NWAlign(ref_seq, s, s_type) + seqid, n_tot, n_aligned = _GetAlnPropsTwo(aln) + if seq_id_thr > 0: + if seqid >= seq_id_thr and n_aligned >= min_aln_length: + fraction_covered = float(n_aligned)/n_tot + score = seqid * fraction_covered + scored_alns.append((score, ref_idx, aln)) + else: + fraction_covered = float(n_aligned)/n_tot + score = seqid * fraction_covered + scored_alns.append((score, ref_idx, aln)) + + if len(scored_alns) == 0: + return (None, None) # no mapping possible... + + scored_alns = sorted(scored_alns, key=lambda x: x[0], reverse=True) + return (scored_alns[0][1], scored_alns[0][2]) + def _GetAlnPropsTwo(aln): """Returns basic properties of *aln* version two... @@ -1765,176 +1853,6 @@ def _GetAlnPropsOne(aln): n_aligned = sum([1 for col in aln if (col[0] != '-' and col[1] != '-')]) return (seqid, n_aligned) -def _GetChemGroupAlignments(pep_seqs, nuc_seqs, aligner, pep_seqid_thr=95., - min_pep_length=6, nuc_seqid_thr=95., - min_nuc_length=4): - """Returns alignments with groups of chemically equivalent chains - - :param pep_seqs: List of polypeptide sequences - :type pep_seqs: :class:`seq.SequenceList` - :param nuc_seqs: List of polynucleotide sequences - :type nuc_seqs: :class:`seq.SequenceList` - :param aligner: Helper class to generate pairwise alignments - :type aligner: :class:`_Aligner` - :param pep_seqid_thr: Threshold used to decide when two peptide chains are - identical. 95 percent tolerates the few mutations - crystallographers like to do. - :type pep_seqid_thr: :class:`float` - :param min_pep_length: Additional threshold to avoid gappy alignments with high - seqid. Number of aligned columns must be at least this - number. 
- :type min_pep_length: :class:`int` - :param nuc_seqid_thr: Nucleotide equivalent of *pep_seqid_thr* - :type nuc_seqid_thr: :class:`float` - :param min_nuc_length: Nucleotide equivalent of *min_pep_length* - :type min_nuc_length: :class:`int` - :returns: Tuple with first element being an AlignmentList. Each alignment - represents a group of chemically equivalent chains and the first - sequence is the longest. Second element is a list of equivalent - length specifying the types of the groups. List elements are in - [:class:`ost.ChemType.AMINOACIDS`, - :class:`ost.ChemType.NUCLEOTIDES`] - """ - pep_groups = _GroupSequences(pep_seqs, pep_seqid_thr, min_pep_length, aligner, - mol.ChemType.AMINOACIDS) - nuc_groups = _GroupSequences(nuc_seqs, nuc_seqid_thr, min_nuc_length, aligner, - mol.ChemType.NUCLEOTIDES) - group_types = [mol.ChemType.AMINOACIDS] * len(pep_groups) - group_types += [mol.ChemType.NUCLEOTIDES] * len(nuc_groups) - groups = pep_groups - groups.extend(nuc_groups) - return (groups, group_types) - -def _GroupSequences(seqs, seqid_thr, min_length, aligner, chem_type): - """Get list of alignments representing groups of equivalent sequences - - :param seqid_thr: Threshold used to decide when two chains are identical. - :type seqid_thr: :class:`float` - :param gap_thr: Additional threshold to avoid gappy alignments with high - seqid. Number of aligned columns must be at least this - number. - :type gap_thr: :class:`int` - :param aligner: Helper class to generate pairwise alignments - :type aligner: :class:`_Aligner` - :param chem_type: ChemType of seqs which is passed to *aligner*, must be in - [:class:`ost.mol.ChemType.AMINOACIDS`, - :class:`ost.mol.ChemType.NUCLEOTIDES`] - :type chem_type: :class:`ost.mol.ChemType` - :returns: A list of alignments, one alignment for each group - with longest sequence (reference) as first sequence. 
- :rtype: :class:`ost.seq.AlignmentList` - """ - groups = list() - for s_idx in range(len(seqs)): - matching_group = None - for g_idx in range(len(groups)): - for g_s_idx in range(len(groups[g_idx])): - aln = aligner.Align(seqs[s_idx], seqs[groups[g_idx][g_s_idx]], - chem_type) - sid, n_aligned = _GetAlnPropsOne(aln) - if sid >= seqid_thr and n_aligned >= min_length: - matching_group = g_idx - break - if matching_group is not None: - break - - if matching_group is None: - groups.append([s_idx]) - else: - groups[matching_group].append(s_idx) - - # sort based on sequence length - sorted_groups = list() - for g in groups: - if len(g) > 1: - tmp = sorted([[len(seqs[i]), i] for i in g], reverse=True) - sorted_groups.append([x[1] for x in tmp]) - else: - sorted_groups.append(g) - - # translate from indices back to sequences and directly generate alignments - # of the groups with the longest (first) sequence as reference - aln_list = seq.AlignmentList() - for g in sorted_groups: - if len(g) == 1: - # aln with one single sequence - aln_list.append(seq.CreateAlignment(seqs[g[0]])) - else: - # obtain pairwise aln of first sequence (reference) to all others - alns = seq.AlignmentList() - i = g[0] - for j in g[1:]: - alns.append(aligner.Align(seqs[i], seqs[j], chem_type)) - # and merge - aln_list.append(seq.alg.MergePairwiseAlignments(alns, seqs[i])) - - # transfer attached views - seq_dict = {s.GetName(): s for s in seqs} - for aln_idx in range(len(aln_list)): - for aln_s_idx in range(aln_list[aln_idx].GetCount()): - s_name = aln_list[aln_idx].GetSequence(aln_s_idx).GetName() - s = seq_dict[s_name] - aln_list[aln_idx].AttachView(aln_s_idx, s.GetAttachedView()) - - return aln_list - -def _MapSequence(ref_seqs, ref_types, s, s_type, aligner, - seq_id_thr=0.0, min_aln_length=0): - """Tries top map *s* onto any of the sequences in *ref_seqs* - - Computes alignments of *s* to each of the reference sequences of equal type - and sorts them by seqid*fraction_covered (seqid: sequence 
identity of - aligned columns in alignment, fraction_covered: Fraction of non-gap - characters in reference sequence that are covered by non-gap characters in - *s*). Best scoring mapping is returned. Optionally, *seq_id*/ - *min_aln_length* thresholds can be enabled to avoid non-sensical mappings. - However, *min_aln_length* only has an effect if *seq_id_thr* > 0!!! - - :param ref_seqs: Reference sequences - :type ref_seqs: :class:`ost.seq.SequenceList` - :param ref_types: Types of reference sequences, e.g. - ost.mol.ChemType.AminoAcids - :type ref_types: :class:`list` of :class:`ost.mol.ChemType` - :param s: Sequence to map - :type s: :class:`ost.seq.SequenceHandle` - :param s_type: Type of *s*, only try mapping to sequences in *ref_seqs* - with equal type as defined in *ref_types* - :param aligner: Helper class to generate pairwise alignments - :type aligner: :class:`_Aligner` - :param seq_id_thr: Minimum sequence identity to be considered as match - :type seq_id_thr: :class:`float` - :param min_aln_length: Minimum number of aligned columns to be considered - as match. Only has an effect if *seq_id_thr* > 0! - :type min_aln_length: :class:`int` - :returns: Tuple with two elements. 1) index of sequence in *ref_seqs* to - which *s* can be mapped 2) Pairwise sequence alignment with - sequence from *ref_seqs* as first sequence. Both elements are - None if no mapping can be found or if thresholds are not - fulfilled for any alignment. - :raises: :class:`RuntimeError` if mapping is ambiguous, i.e. 
*s* - successfully maps to more than one sequence in *ref_seqs* - """ - scored_alns = list() - for ref_idx, ref_seq in enumerate(ref_seqs): - if ref_types[ref_idx] == s_type: - aln = aligner.Align(ref_seq, s, s_type) - seqid, n_tot, n_aligned = _GetAlnPropsTwo(aln) - if seq_id_thr > 0: - if seqid >= seq_id_thr and n_aligned >= min_aln_length: - fraction_covered = float(n_aligned)/n_tot - score = seqid * fraction_covered - scored_alns.append((score, ref_idx, aln)) - else: - fraction_covered = float(n_aligned)/n_tot - score = seqid * fraction_covered - scored_alns.append((score, ref_idx, aln)) - - if len(scored_alns) == 0: - return (None, None) # no mapping possible... - - scored_alns = sorted(scored_alns, key=lambda x: x[0], reverse=True) - return (scored_alns[0][1], scored_alns[0][2]) - def _GetRefMdlAlns(ref_chem_groups, ref_chem_group_msas, mdl_chem_groups, mdl_chem_group_alns, pairs=None): """ Get all possible ref/mdl chain alignments given chem group mapping diff --git a/modules/mol/alg/pymod/contact_score.py b/modules/mol/alg/pymod/contact_score.py index 845f620e0da0032deccc5184ee8a3bac482f12c4..0af2d467ac17093a908aa6f60ee564e28178fda1 100644 --- a/modules/mol/alg/pymod/contact_score.py +++ b/modules/mol/alg/pymod/contact_score.py @@ -384,7 +384,7 @@ class ContactScorerResultIPS: :type: :class:`int` """ - return self._n_trg_contacts + return self._n_trg_int_res @property def n_mdl_int_res(self): @@ -738,19 +738,29 @@ class ContactScorer: trg_int_r = (trg_ch2, trg_ch1) mdl_int_r = (mdl_ch2, mdl_ch1) + trg_contacts = None if trg_int in self.cent1.contacts: - n_trg = len(self.cent1.contacts[trg_int]) + trg_contacts = self.cent1.contacts[trg_int] elif trg_int_r in self.cent1.contacts: - n_trg = len(self.cent1.contacts[trg_int_r]) - else: + trg_contacts = self.cent1.contacts[trg_int_r] + + if trg_contacts is None: n_trg = 0 + else: + n_trg = len(set([x[0] for x in trg_contacts])) + n_trg += len(set([x[1] for x in trg_contacts])) + mdl_contacts = None if mdl_int in 
self.cent2.contacts: - n_mdl = len(self.cent2.contacts[mdl_int]) + mdl_contacts = self.cent2.contacts[mdl_int] elif mdl_int_r in self.cent2.contacts: - n_mdl = len(self.cent2.contacts[mdl_int_r]) - else: + mdl_contacts = self.cent2.contacts[mdl_int_r] + + if mdl_contacts is None: n_mdl = 0 + else: + n_mdl = len(set([x[0] for x in mdl_contacts])) + n_mdl += len(set([x[1] for x in mdl_contacts])) _, _, n_union, n_intersection = self._MappedInterfaceScores(trg_int, mdl_int) return ContactScorerResultIPS(n_trg, n_mdl, n_union, n_intersection) diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py index cea7db31802b65a744035675b8344556f054c363..f1d164a8eddc501ed9e368cdbb4ea784b0034bb5 100644 --- a/modules/mol/alg/pymod/scoring.py +++ b/modules/mol/alg/pymod/scoring.py @@ -22,6 +22,48 @@ from ost.bindings import cadscore from ost.bindings import tmtools import numpy as np +def _GetAlignedResidues(aln, s1_ent, s2_ent): + """ Yields aligned residues + + :param aln: The alignment with 2 sequences defining a residue-by-residue + relationship. + :type aln: :class:`ost.seq.AlignmentHandle` + :param s1_ent: Structure representing first sequence in *aln*. + One chain must be named after the first sequence and the + number of residues must match the number of non-gap + characters. + :type s1_ent: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView` + :param s2_ent: Same for second sequence in *aln*. 
+ :type s2_ent: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView` + """ + s1_ch = s1_ent.FindChain(aln.GetSequence(0).GetName()) + s2_ch = s2_ent.FindChain(aln.GetSequence(1).GetName()) + + if not s1_ch.IsValid(): + raise RuntimeError("s1_ent lacks required chain in _GetAlignedResidues") + + if not s2_ch.IsValid(): + raise RuntimeError("s2_ent lacks required chain in _GetAlignedResidues") + + if len(aln.GetSequence(0).GetGaplessString()) != s1_ch.GetResidueCount(): + raise RuntimeError("aln/chain mismatch in _GetAlignedResidues") + if len(aln.GetSequence(1).GetGaplessString()) != s2_ch.GetResidueCount(): + raise RuntimeError("aln/chain mismatch in _GetAlignedResidues") + + s1_res = s1_ch.residues + s2_res = s2_ch.residues + + s1_res_idx = 0 + s2_res_idx = 0 + + for col in aln: + if col[0] != '-' and col[1] != '-': + yield (s1_res[s1_res_idx], s2_res[s2_res_idx]) + if col[0] != '-': + s1_res_idx += 1 + if col[1] != '-': + s2_res_idx += 1 + class lDDTBSScorer: """Scorer specific for a reference/model pair @@ -237,6 +279,10 @@ class Scorer: self._target_orig = target self._model_orig = model + # lazily computed versions of target_orig and model_orig + self._pepnuc_target = None + self._pepnuc_model = None + if isinstance(self._model_orig, mol.EntityView): self._model = mol.CreateEntityFromView(self._model_orig, False) else: @@ -493,6 +539,19 @@ class Scorer: """ return self._model_orig + @property + def pepnuc_model(self): + """ A selection of :attr:`~model_orig` + + Only contains peptide and nucleotide residues + + :type: :class:`ost.mol.EntityView` + """ + if self._pepnuc_model is None: + query = "peptide=true or nucleotide=true" + self._pepnuc_model = self.model_orig.Select(query) + return self._pepnuc_model + @property def target(self): """ Target with Molck cleanup @@ -509,6 +568,19 @@ class Scorer: """ return self._target_orig + @property + def pepnuc_target(self): + """ A selection of :attr:`~target_orig` + + Only contains peptide and nucleotide 
residues + + :type: :class:`ost.mol.EntityView` + """ + if self._pepnuc_target is None: + query = "peptide=true or nucleotide=true" + self._pepnuc_target = self.target_orig.Select(query) + return self._pepnuc_target + @property def aln(self): """ Alignments of :attr:`~model`/:attr:`~target` chains @@ -2044,14 +2116,12 @@ class Scorer: cname = ch.GetName() s = ''.join([r.one_letter_code for r in ch.residues]) s = seq.CreateSequence(ch.GetName(), s) - s.AttachView(target.Select(f"cname={mol.QueryQuoteName(cname)}")) trg_seqs[ch.GetName()] = s mdl_seqs = dict() for ch in model.chains: cname = ch.GetName() s = ''.join([r.one_letter_code for r in ch.residues]) s = seq.CreateSequence(cname, s) - s.AttachView(model.Select(f"cname={mol.QueryQuoteName(cname)}")) mdl_seqs[ch.GetName()] = s alns = list() @@ -2068,25 +2138,73 @@ class Scorer: else: raise RuntimeError("Chain name inconsistency... ask " "Gabriel") - alns.append(self.chain_mapper.Align(trg_seqs[trg_ch], + if self.resnum_alignments: + aln = self.chain_mapper.ResNumAlign(trg_seqs[trg_ch], + mdl_seqs[mdl_ch], + target, model) + else: + aln = self.chain_mapper.NWAlign(trg_seqs[trg_ch], mdl_seqs[mdl_ch], - stype)) - alns[-1].AttachView(0, trg_seqs[trg_ch].GetAttachedView()) - alns[-1].AttachView(1, mdl_seqs[mdl_ch].GetAttachedView()) - return alns + stype) - def _compute_pepnuc_aln(self): - query = "peptide=true or nucleotide=true" - pep_nuc_target = self.target_orig.Select(query) - pep_nuc_model = self.model_orig.Select(query) - self._pepnuc_aln = self._aln_helper(pep_nuc_target, pep_nuc_model) + alns.append(aln) + return alns def _compute_aln(self): self._aln = self._aln_helper(self.target, self.model) def _compute_stereochecked_aln(self): - self._stereochecked_aln = self._aln_helper(self.stereochecked_target, - self.stereochecked_model) + # lets not redo the alignment and derive it from self.aln + alns = list() + for a in self.aln: + trg_s = a.GetSequence(0) + mdl_s = a.GetSequence(1) + trg_ch = 
self.target.FindChain(trg_s.name) + mdl_ch = self.model.FindChain(mdl_s.name) + + sc_trg_olc = ['-'] * len(trg_s) + sc_mdl_olc = ['-'] * len(mdl_s) + + sc_trg_ch = self.stereochecked_target.FindChain(trg_s.name) + if sc_trg_ch.IsValid(): + # there is the theoretical possibility that the full chain + # has been removed in stereochemistry checks... + trg_residues = trg_ch.residues + res_idx = 0 + for olc_idx, olc in enumerate(trg_s): + if olc != '-': + r = trg_residues[res_idx] + sc_r = sc_trg_ch.FindResidue(r.GetNumber()) + if sc_r.IsValid(): + sc_trg_olc[olc_idx] = sc_r.one_letter_code + res_idx += 1 + + sc_mdl_ch = self.stereochecked_model.FindChain(mdl_s.name) + if sc_mdl_ch.IsValid(): + # there is the theoretical possibility that the full chain + # has been removed in stereochemistry checks... + mdl_residues = mdl_ch.residues + res_idx = 0 + for olc_idx, olc in enumerate(mdl_s): + if olc != '-': + r = mdl_residues[res_idx] + sc_r = sc_mdl_ch.FindResidue(r.GetNumber()) + if sc_r.IsValid(): + sc_mdl_olc[olc_idx] = sc_r.one_letter_code + res_idx += 1 + + sc_trg_s = seq.CreateSequence(trg_s.name, ''.join(sc_trg_olc)) + sc_mdl_s = seq.CreateSequence(mdl_s.name, ''.join(sc_mdl_olc)) + new_a = seq.CreateAlignment() + new_a.AddSequence(sc_trg_s) + new_a.AddSequence(sc_mdl_s) + alns.append(new_a) + + self._stereochecked_aln = alns + + def _compute_pepnuc_aln(self): + self._pepnuc_aln = self._aln_helper(self.pepnuc_target, + self.pepnuc_model) def _compute_lddt(self): LogScript("Computing all-atom LDDT") @@ -2167,10 +2285,12 @@ class Scorer: local_lddt = dict() aa_local_lddt = dict() for r in self.model.residues: + cname = r.GetChain().GetName() if cname not in local_lddt: local_lddt[cname] = dict() aa_local_lddt[cname] = dict() + rnum = r.GetNumber() if rnum not in aa_local_lddt[cname]: aa_local_lddt[cname][rnum] = dict() @@ -2179,8 +2299,8 @@ class Scorer: score = round(r.GetFloatProp("lddt"), 3) local_lddt[cname][rnum] = score - trg_r = None - mdl_r = None + trg_r = 
None # represents stereochecked trg res + mdl_r = None # represents stereochecked mdl res for a in r.atoms: if a.HasProp("lddt"): @@ -2195,18 +2315,24 @@ class Scorer: # stereochecks but is there in stereochecked # target => 0.0 if trg_r is None: + # let's first see if we find that target residue + # in the non-stereochecked target + tmp = None if cname in flat_mapping: - for col in alns[cname]: - if col[0] != '-' and col[1] != '-': - if col.GetResidue(1).number == r.number: - trg_r = col.GetResidue(0) - break - if trg_r is not None: - trg_cname = trg_r.GetChain().GetName() - trg_rnum = trg_r.GetNumber() - tmp = self.stereochecked_target.FindResidue(trg_cname, - trg_rnum) + for x, y in _GetAlignedResidues(alns[cname], + self.target, + self.model): + if y.number == r.number: + tmp = x + break + if tmp is not None: + # we have it in the non-stereochecked target! + tmp_cname = tmp.GetChain().GetName() + tmp_rnum = tmp.GetNumber() + tmp = self.stereochecked_target.FindResidue(tmp_cname, + tmp_rnum) if tmp.IsValid(): + # And it's there in the stereochecked target too! 
trg_r = tmp if mdl_r is None: @@ -2214,12 +2340,23 @@ class Scorer: if tmp.IsValid(): mdl_r = tmp - if trg_r is not None and not trg_r.FindAtom(a.GetName()).IsValid(): - # opt 1 + if trg_r is None: + # opt 1 - the whole target residue is not there + # this is actually an impossibility, as we have + # a score for the full mdl residue set + aa_local_lddt[cname][rnum][a.GetName()] = None + elif trg_r is not None and not trg_r.FindAtom(a.GetName()).IsValid(): + # opt 1 - the target residue is there but not the atom aa_local_lddt[cname][rnum][a.GetName()] = None + elif trg_r is not None and trg_r.FindAtom(a.GetName()).IsValid() and \ + mdl_r is None: + # opt 2 - trg atom is there but full model residue is removed + # this is actuall an impossibility, as we have + # a score for the full mdl residue set + aa_local_lddt[cname][rnum][a.GetName()] = 0.0 elif trg_r is not None and trg_r.FindAtom(a.GetName()).IsValid() and \ mdl_r is not None and not mdl_r.FindAtom(a.GetName()).IsValid(): - # opt 2 + # opt 2 - trg atom is there but model atom is removed aa_local_lddt[cname][rnum][a.GetName()] = 0.0 else: # unknown issue @@ -2237,20 +2374,25 @@ class Scorer: # opt 1: removed by stereochecks => assign 0.0 # opt 2: removed by stereochecks AND not covered by ref # => assign None + # fetch trg residue from non-stereochecked aln trg_r = None if cname in flat_mapping: - for col in alns[cname]: - if col[0] != '-' and col[1] != '-': - if col.GetResidue(1).number == r.number: - trg_r = col.GetResidue(0) - break - if trg_r is not None: - trg_cname = trg_r.GetChain().GetName() - trg_rnum = trg_r.GetNumber() - tmp = self.stereochecked_target.FindResidue(trg_cname, - trg_rnum) + tmp = None + for x, y in _GetAlignedResidues(alns[cname], + self.target, + self.model): + if y.number == r.number: + tmp = x + break + if tmp is not None: + # we have it in the non-stereochecked target! 
+ tmp_cname = tmp.GetChain().GetName() + tmp_rnum = tmp.GetNumber() + tmp = self.stereochecked_target.FindResidue(tmp_cname, + tmp_rnum) if tmp.IsValid(): + # And it's there in the stereochecked target too! trg_r = tmp if trg_r is None: @@ -2617,30 +2759,30 @@ class Scorer: for trg_ch, mdl_ch in self.mapping.GetFlatMapping().items(): processed_trg_chains.add(trg_ch) aln = self.mapping.alns[(trg_ch, mdl_ch)] - for col in aln: - if col[0] != '-' and col[1] != '-': - trg_res = col.GetResidue(0) - mdl_res = col.GetResidue(1) - trg_at = trg_res.FindAtom("CA") - mdl_at = mdl_res.FindAtom("CA") - if not trg_at.IsValid(): - trg_at = trg_res.FindAtom("C3'") - if not mdl_at.IsValid(): - mdl_at = mdl_res.FindAtom("C3'") - self._mapped_target_pos.append(trg_at.GetPos()) - self._mapped_model_pos.append(mdl_at.GetPos()) - elif col[0] != '-': - self._n_target_not_mapped += 1 + n_mapped = 0 + for trg_res, mdl_res in _GetAlignedResidues(aln, + self.mapping.target, + self.mapping.model): + trg_at = trg_res.FindAtom("CA") + mdl_at = mdl_res.FindAtom("CA") + if not trg_at.IsValid(): + trg_at = trg_res.FindAtom("C3'") + if not mdl_at.IsValid(): + mdl_at = mdl_res.FindAtom("C3'") + self._mapped_target_pos.append(trg_at.GetPos()) + self._mapped_model_pos.append(mdl_at.GetPos()) + n_mapped += 1 + self._n_target_not_mapped += (len(aln.GetSequence(0).GetGaplessString())-n_mapped) # count number of trg residues from non-mapped chains for ch in self.mapping.target.chains: if ch.GetName() not in processed_trg_chains: - self._n_target_not_mapped += len(ch.residues) + self._n_target_not_mapped += ch.GetResidueCount() def _extract_mapped_pos_full_bb(self): self._mapped_target_pos_full_bb = geom.Vec3List() self._mapped_model_pos_full_bb = geom.Vec3List() exp_pep_atoms = ["N", "CA", "C"] - exp_nuc_atoms = ["\"O5'\"", "\"C5'\"", "\"C4'\"", "\"C3'\"", "\"O3'\""] + exp_nuc_atoms = ["O5'", "C5'", "C4'", "C3'", "O3'"] trg_pep_chains = [s.GetName() for s in self.chain_mapper.polypep_seqs] trg_nuc_chains 
= [s.GetName() for s in self.chain_mapper.polynuc_seqs] for trg_ch, mdl_ch in self.mapping.GetFlatMapping().items(): @@ -2653,19 +2795,18 @@ class Scorer: else: # this should be guaranteed by the chain mapper raise RuntimeError("Unexpected error - contact OST developer") - for col in aln: - if col[0] != '-' and col[1] != '-': - trg_res = col.GetResidue(0) - mdl_res = col.GetResidue(1) - for aname in exp_atoms: - trg_at = trg_res.FindAtom(aname) - mdl_at = mdl_res.FindAtom(aname) - if not (trg_at.IsValid() and mdl_at.IsValid()): - # this should be guaranteed by the chain mapper - raise RuntimeError("Unexpected error - contact OST " - "developer") - self._mapped_target_pos_full_bb.append(trg_at.GetPos()) - self._mapped_model_pos_full_bb.append(mdl_at.GetPos()) + for trg_res, mdl_res in _GetAlignedResidues(aln, + self.mapping.target, + self.mapping.model): + for aname in exp_atoms: + trg_at = trg_res.FindAtom(aname) + mdl_at = mdl_res.FindAtom(aname) + if not (trg_at.IsValid() and mdl_at.IsValid()): + # this should be guaranteed by the chain mapper + raise RuntimeError("Unexpected error - contact OST " + "developer") + self._mapped_target_pos_full_bb.append(trg_at.GetPos()) + self._mapped_model_pos_full_bb.append(mdl_at.GetPos()) def _extract_rigid_mapped_pos(self): @@ -2676,34 +2817,35 @@ class Scorer: for trg_ch, mdl_ch in self.rigid_mapping.GetFlatMapping().items(): processed_trg_chains.add(trg_ch) aln = self.rigid_mapping.alns[(trg_ch, mdl_ch)] - for col in aln: - if col[0] != '-' and col[1] != '-': - trg_res = col.GetResidue(0) - mdl_res = col.GetResidue(1) - trg_at = trg_res.FindAtom("CA") - mdl_at = mdl_res.FindAtom("CA") - if not trg_at.IsValid(): - trg_at = trg_res.FindAtom("C3'") - if not mdl_at.IsValid(): - mdl_at = mdl_res.FindAtom("C3'") - self._rigid_mapped_target_pos.append(trg_at.GetPos()) - self._rigid_mapped_model_pos.append(mdl_at.GetPos()) - elif col[0] != '-': - self._rigid_n_target_not_mapped += 1 + n_mapped = 0 + for trg_res, mdl_res in 
_GetAlignedResidues(aln, + self.rigid_mapping.target, + self.rigid_mapping.model): + trg_at = trg_res.FindAtom("CA") + mdl_at = mdl_res.FindAtom("CA") + if not trg_at.IsValid(): + trg_at = trg_res.FindAtom("C3'") + if not mdl_at.IsValid(): + mdl_at = mdl_res.FindAtom("C3'") + self._rigid_mapped_target_pos.append(trg_at.GetPos()) + self._rigid_mapped_model_pos.append(mdl_at.GetPos()) + n_mapped += 1 + + self._rigid_n_target_not_mapped += (len(aln.GetSequence(0).GetGaplessString())-n_mapped) # count number of trg residues from non-mapped chains for ch in self.rigid_mapping.target.chains: if ch.GetName() not in processed_trg_chains: - self._rigid_n_target_not_mapped += len(ch.residues) + self._rigid_n_target_not_mapped += ch.GetResidueCount() def _extract_rigid_mapped_pos_full_bb(self): self._rigid_mapped_target_pos_full_bb = geom.Vec3List() self._rigid_mapped_model_pos_full_bb = geom.Vec3List() exp_pep_atoms = ["N", "CA", "C"] - exp_nuc_atoms = ["\"O5'\"", "\"C5'\"", "\"C4'\"", "\"C3'\"", "\"O3'\""] + exp_nuc_atoms = ["O5'", "C5'", "C4'", "C3'", "O3'"] trg_pep_chains = [s.GetName() for s in self.chain_mapper.polypep_seqs] trg_nuc_chains = [s.GetName() for s in self.chain_mapper.polynuc_seqs] for trg_ch, mdl_ch in self.rigid_mapping.GetFlatMapping().items(): - aln = self.mapping.alns[(trg_ch, mdl_ch)] + aln = self.rigid_mapping.alns[(trg_ch, mdl_ch)] trg_ch = aln.GetSequence(0).GetName() if trg_ch in trg_pep_chains: exp_atoms = exp_pep_atoms @@ -2712,19 +2854,18 @@ class Scorer: else: # this should be guaranteed by the chain mapper raise RuntimeError("Unexpected error - contact OST developer") - for col in aln: - if col[0] != '-' and col[1] != '-': - trg_res = col.GetResidue(0) - mdl_res = col.GetResidue(1) - for aname in exp_atoms: - trg_at = trg_res.FindAtom(aname) - mdl_at = mdl_res.FindAtom(aname) - if not (trg_at.IsValid() and mdl_at.IsValid()): - # this should be guaranteed by the chain mapper - raise RuntimeError("Unexpected error - contact OST " - "developer") 
- self._rigid_mapped_target_pos_full_bb.append(trg_at.GetPos()) - self._rigid_mapped_model_pos_full_bb.append(mdl_at.GetPos()) + for trg_res, mdl_res in _GetAlignedResidues(aln, + self.rigid_mapping.target, + self.rigid_mapping.model): + for aname in exp_atoms: + trg_at = trg_res.FindAtom(aname) + mdl_at = mdl_res.FindAtom(aname) + if not (trg_at.IsValid() and mdl_at.IsValid()): + # this should be guaranteed by the chain mapper + raise RuntimeError("Unexpected error - contact OST " + "developer") + self._rigid_mapped_target_pos_full_bb.append(trg_at.GetPos()) + self._rigid_mapped_model_pos_full_bb.append(mdl_at.GetPos()) def _compute_cad_score(self): if not self.resnum_alignments: @@ -2813,12 +2954,12 @@ class Scorer: a, b, c, d = stereochemistry.StereoCheck(self.target, stereo_data = data, stereo_link_data = l_data) + self._stereochecked_target = a self._target_clashes = b self._target_bad_bonds = c self._target_bad_angles = d - def _get_interface_patches(self, mdl_ch, mdl_rnum): """ Select interface patches representative for specified residue @@ -2897,34 +3038,38 @@ class Scorer: # transfer mdl residues to trg flat_mapping = self.mapping.GetFlatMapping(mdl_as_key=True) full_trg_coverage = True - trg_patch_one = self.target.CreateEmptyView() + trg_patch_one = self.mapping.target.CreateEmptyView() for r in mdl_patch_one.residues: trg_r = None mdl_cname = r.GetChain().GetName() if mdl_cname in flat_mapping: aln = self.mapping.alns[(flat_mapping[mdl_cname], mdl_cname)] - for col in aln: - if col[0] != '-' and col[1] != '-': - if col.GetResidue(1).GetNumber() == r.GetNumber(): - trg_r = col.GetResidue(0) - break + for x,y in _GetAlignedResidues(aln, + self.mapping.target, + self.mapping.model): + if y.GetNumber() == r.GetNumber(): + trg_r = x + break + if trg_r is not None: trg_patch_one.AddResidue(trg_r.handle, mol.ViewAddFlag.INCLUDE_ALL) else: full_trg_coverage = False - trg_patch_two = self.target.CreateEmptyView() + trg_patch_two = 
self.mapping.target.CreateEmptyView() for r in mdl_patch_two.residues: trg_r = None mdl_cname = r.GetChain().GetName() if mdl_cname in flat_mapping: aln = self.mapping.alns[(flat_mapping[mdl_cname], mdl_cname)] - for col in aln: - if col[0] != '-' and col[1] != '-': - if col.GetResidue(1).GetNumber() == r.GetNumber(): - trg_r = col.GetResidue(0) - break + for x,y in _GetAlignedResidues(aln, + self.mapping.target, + self.mapping.model): + if y.GetNumber() == r.GetNumber(): + trg_r = x + break + if trg_r is not None: trg_patch_two.AddResidue(trg_r.handle, mol.ViewAddFlag.INCLUDE_ALL) diff --git a/modules/mol/alg/tests/CMakeLists.txt b/modules/mol/alg/tests/CMakeLists.txt index e21f0746812554f0b3a6e33145efbd7eb3d9e254..1c030fec3041690dfaa756a8931d9d97f7e21eba 100644 --- a/modules/mol/alg/tests/CMakeLists.txt +++ b/modules/mol/alg/tests/CMakeLists.txt @@ -21,7 +21,8 @@ if (COMPOUND_LIB) list(APPEND OST_MOL_ALG_UNIT_TESTS test_qsscoring.py test_nonstandard.py test_chain_mapping.py - test_ligand_scoring.py) + test_ligand_scoring.py + test_scoring.py) endif() ost_unittest(MODULE mol_alg SOURCES "${OST_MOL_ALG_UNIT_TESTS}" LINK ost_io) diff --git a/modules/mol/alg/tests/test_chain_mapping.py b/modules/mol/alg/tests/test_chain_mapping.py index d7ec18c40cb9aea95f9c34ba0f217dc62f7b1c62..3183ed7c810efd380c11de8ed6e95872191adad2 100644 --- a/modules/mol/alg/tests/test_chain_mapping.py +++ b/modules/mol/alg/tests/test_chain_mapping.py @@ -63,15 +63,6 @@ class TestChainMapper(unittest.TestCase): self.assertEqual(str(mapper.polynuc_seqs[0]), str(nuc_s_one)) self.assertEqual(str(mapper.polynuc_seqs[1]), str(nuc_s_two)) - for s in mapper.polypep_seqs: - self.assertTrue(s.HasAttachedView()) - for s in mapper.polynuc_seqs: - self.assertTrue(s.HasAttachedView()) - self.assertTrue(_CompareViews(mapper.polypep_seqs[0].GetAttachedView(), pep_view_one)) - self.assertTrue(_CompareViews(mapper.polypep_seqs[1].GetAttachedView(), pep_view_two)) - 
self.assertTrue(_CompareViews(mapper.polynuc_seqs[0].GetAttachedView(), nuc_view_one)) - self.assertTrue(_CompareViews(mapper.polynuc_seqs[1].GetAttachedView(), nuc_view_two)) - # peptide sequences should be in the same group, the nucleotides not self.assertEqual(len(mapper.chem_group_alignments), 3) self.assertEqual(len(mapper.chem_groups), 3) @@ -89,11 +80,6 @@ class TestChainMapper(unittest.TestCase): self.assertEqual(str(mapper.chem_group_ref_seqs[0]), str(pep_s_one)) self.assertEqual(str(mapper.chem_group_ref_seqs[1]), str(nuc_s_one)) self.assertEqual(str(mapper.chem_group_ref_seqs[2]), str(nuc_s_two)) - for s in mapper.chem_group_ref_seqs: - self.assertTrue(s.HasAttachedView()) - self.assertTrue(_CompareViews(mapper.chem_group_ref_seqs[0].GetAttachedView(), pep_view_one)) - self.assertTrue(_CompareViews(mapper.chem_group_ref_seqs[1].GetAttachedView(), nuc_view_one)) - self.assertTrue(_CompareViews(mapper.chem_group_ref_seqs[2].GetAttachedView(), nuc_view_two)) # check chem_group_alignments attribute self.assertEqual(len(mapper.chem_group_alignments), 3) @@ -108,14 +94,6 @@ class TestChainMapper(unittest.TestCase): self.assertEqual(s0.GetGaplessString(), str(nuc_s_one)) s0 = mapper.chem_group_alignments[2].GetSequence(0) self.assertEqual(s0.GetGaplessString(), str(nuc_s_two)) - self.assertTrue(mapper.chem_group_alignments[0].GetSequence(0).HasAttachedView()) - self.assertTrue(mapper.chem_group_alignments[0].GetSequence(1).HasAttachedView()) - self.assertTrue(mapper.chem_group_alignments[1].GetSequence(0).HasAttachedView()) - self.assertTrue(mapper.chem_group_alignments[2].GetSequence(0).HasAttachedView()) - self.assertTrue(_CompareViews(mapper.chem_group_alignments[0].GetSequence(0).GetAttachedView(), pep_view_one)) - self.assertTrue(_CompareViews(mapper.chem_group_alignments[0].GetSequence(1).GetAttachedView(), pep_view_two)) - self.assertTrue(_CompareViews(mapper.chem_group_alignments[1].GetSequence(0).GetAttachedView(), nuc_view_one)) - 
self.assertTrue(_CompareViews(mapper.chem_group_alignments[2].GetSequence(0).GetAttachedView(), nuc_view_two)) # ensure that error is triggered if there are insertion codes # and resnum_alignments are enabled @@ -218,19 +196,6 @@ class TestChainMapper(unittest.TestCase): self.assertEqual(alns[2][0].GetSequence(0).GetGaplessString(), str(ref_nuc_s_two)) self.assertEqual(alns[2][0].GetSequence(1).GetGaplessString(), str(mdl_nuc_s_one)) - self.assertTrue(alns[0][0].GetSequence(0).HasAttachedView()) - self.assertTrue(alns[0][0].GetSequence(1).HasAttachedView()) - self.assertTrue(alns[0][1].GetSequence(0).HasAttachedView()) - self.assertTrue(alns[0][1].GetSequence(1).HasAttachedView()) - self.assertTrue(alns[2][0].GetSequence(0).HasAttachedView()) - self.assertTrue(alns[2][0].GetSequence(1).HasAttachedView()) - self.assertTrue(_CompareViews(alns[0][0].GetSequence(0).GetAttachedView(),ref_pep_view_one)) - self.assertTrue(_CompareViews(alns[0][0].GetSequence(1).GetAttachedView(),mdl_pep_view_one)) - self.assertTrue(_CompareViews(alns[0][1].GetSequence(0).GetAttachedView(),ref_pep_view_one)) - self.assertTrue(_CompareViews(alns[0][1].GetSequence(1).GetAttachedView(),mdl_pep_view_two)) - self.assertTrue(_CompareViews(alns[2][0].GetSequence(0).GetAttachedView(),ref_nuc_view_two)) - self.assertTrue(_CompareViews(alns[2][0].GetSequence(1).GetAttachedView(),mdl_nuc_view_one)) - # test for unmapped mdl chains, i.e. 
chains in the mdl that are not # present in ref mdl = _LoadFile("mdl_different_chain_mdl.pdb") @@ -355,7 +320,7 @@ class TestChainMapper(unittest.TestCase): self.assertTrue(mdl_s is not None) self.assertEqual(ref_s_type, mdl_s_type) - aln = mapper.Align(ref_s, mdl_s, ref_s_type) + aln = mapper.NWAlign(ref_s, mdl_s, ref_s_type) self.assertEqual(ref_aln.GetSequence(0).GetName(), aln.GetSequence(0).GetName()) self.assertEqual(ref_aln.GetSequence(1).GetName(), @@ -415,6 +380,18 @@ class TestChainMapper(unittest.TestCase): for ref_ch, mdl_ch in repr_mapping.items(): self.assertEqual(mdl_ch, flat_mapping[ref_ch]) + def test_resnum_aln_mapping(self): + ref = _LoadFile("3l1p.1.pdb") + mdl = _LoadFile("3l1p.1_model.pdb") + mapper = ChainMapper(ref, resnum_alignments=True) + + # Again, no in depth testing. Its only about checking if residue number + # alignments run through + + # lDDT based chain mappings + naive_lddt_res = mapper.GetlDDTMapping(mdl, strategy="naive") + self.assertEqual(naive_lddt_res.mapping, [['X', 'Y'],[None],['Z']]) + def test_misc(self): # check for triggered error when no chain fulfills length threshold diff --git a/modules/mol/alg/tests/test_scoring.py b/modules/mol/alg/tests/test_scoring.py new file mode 100644 index 0000000000000000000000000000000000000000..162808afce53669e169df279c1a05e4efcd28ef1 --- /dev/null +++ b/modules/mol/alg/tests/test_scoring.py @@ -0,0 +1,242 @@ +import unittest, os, sys +import ost +from ost import io, mol, geom +# check if we can import: fails if numpy or scipy not available +try: + from ost.mol.alg.scoring import * +except ImportError: + print("Failed to import scoring.py. Happens when numpy or scipy "\ + "missing. 
Ignoring test_scoring.py tests.") + sys.exit(0) + +def _LoadFile(file_name): + """Helper to avoid repeating input path over and over.""" + return io.LoadPDB(os.path.join('testfiles', file_name)) + +class TestScorer(unittest.TestCase): + + # compare to hardcoded values - no in depth testing + # this should be sufficient to flag issues when changes are introduced + + def test_scorer_lddt(self): + + mdl = _LoadFile("1eud_mdl_partial-dimer.pdb") + trg = _LoadFile("1eud_ref.pdb") + + scorer = Scorer(mdl, trg) + + # check global lDDT values + self.assertAlmostEqual(scorer.lddt, 0.539, 3) + self.assertAlmostEqual(scorer.bb_lddt, 0.622, 3) + self.assertAlmostEqual(scorer.ilddt, 0.282, 3) + + for ch in scorer.model.chains: + self.assertTrue(ch.name in scorer.local_lddt) + self.assertEqual(len(scorer.local_lddt[ch.name]), ch.GetResidueCount()) + self.assertTrue(ch.name in scorer.bb_local_lddt) + self.assertEqual(len(scorer.bb_local_lddt[ch.name]), ch.GetResidueCount()) + + for ch in scorer.model.chains: + self.assertTrue(ch.name in scorer.aa_local_lddt) + self.assertEqual(len(scorer.aa_local_lddt[ch.name]), ch.GetResidueCount()) + + # check some random per-residue/per-atom scores + self.assertEqual(scorer.local_lddt["B"][mol.ResNum(42)], 0.659, 3) + self.assertEqual(scorer.local_lddt["A"][mol.ResNum(142)], 0.849, 3) + self.assertEqual(scorer.bb_local_lddt["B"][mol.ResNum(42)], 0.782, 3) + self.assertEqual(scorer.bb_local_lddt["A"][mol.ResNum(142)], 0.910, 3) + self.assertEqual(scorer.aa_local_lddt["B"][mol.ResNum(42)]["CA"], 0.718, 3) + self.assertEqual(scorer.aa_local_lddt["A"][mol.ResNum(142)]["CB"], 0.837, 3) + + # test stereochemistry checks related behaviour + + # stereochemistry issue in mdl sidechain + bad_mdl = _LoadFile("1eud_mdl_partial-dimer.pdb") + ed = bad_mdl.EditXCS() + at = bad_mdl.FindResidue("B", mol.ResNum(42)).FindAtom("CD") + pos = at.GetPos() + new_pos = geom.Vec3(pos[0], pos[1], pos[2] + 1.0) + ed.SetAtomPos(at, new_pos) + scorer = Scorer(bad_mdl, trg) 
+ # original score without stereochemistry issues: 0.659 + # now it should be much worse but above zero + penalized_score = scorer.local_lddt["B"][mol.ResNum(42)] + self.assertTrue(penalized_score < 0.5 and penalized_score > 0.1) + + # let's make it really bad, i.e. involve a backbone atom + at = bad_mdl.FindResidue("B", mol.ResNum(42)).FindAtom("CA") + pos = at.GetPos() + new_pos = geom.Vec3(pos[0], pos[1], pos[2] + 1.0) + ed.SetAtomPos(at, new_pos) + scorer = Scorer(bad_mdl, trg) + # original score without stereochemistry issues: 0.659 + # now it should be 0.0 + penalized_score = scorer.local_lddt["B"][mol.ResNum(42)] + self.assertEqual(penalized_score, 0.0) + + # let's fiddle around in trg stereochemistry + bad_trg = _LoadFile("1eud_ref.pdb") + ed = bad_trg.EditXCS() + # thats the atom that should map to B.42 in mdl + at = bad_trg.FindResidue("B", mol.ResNum(5)).FindAtom("CD") + pos = at.GetPos() + new_pos = geom.Vec3(pos[0], pos[1], pos[2] + 1.0) + ed.SetAtomPos(at, new_pos) + scorer = Scorer(mdl, bad_trg) + # there is no reference info anymore on the whole sidechain + # The scores of the sidechain atoms should thus be None + self.assertTrue(scorer.aa_local_lddt["B"][mol.ResNum(42)]["CD"] is None) + # but not the sidechain atom scores from backbone atoms! 
+ self.assertFalse(scorer.aa_local_lddt["B"][mol.ResNum(42)]["CA"] is None) + # also the full per-residue score is still a valid number + self.assertFalse(scorer.local_lddt["B"][mol.ResNum(42)] is None) + + bad_trg = _LoadFile("1eud_ref.pdb") + ed = bad_trg.EditXCS() + at = bad_trg.FindResidue("B", mol.ResNum(5)).FindAtom("CA") + pos = at.GetPos() + new_pos = geom.Vec3(pos[0], pos[1], pos[2] + 1.0) + + ed.SetAtomPos(at, new_pos) + scorer = Scorer(mdl, bad_trg) + + # all scores should be None now for this residue + self.assertTrue(scorer.aa_local_lddt["B"][mol.ResNum(42)]["CD"] is None) + self.assertTrue(scorer.aa_local_lddt["B"][mol.ResNum(42)]["CA"] is None) + self.assertTrue(scorer.local_lddt["B"][mol.ResNum(42)] is None) + + + def test_scorer_qsscore(self): + + mdl = _LoadFile("1eud_mdl_partial-dimer.pdb") + trg = _LoadFile("1eud_ref.pdb") + + scorer = Scorer(mdl, trg) + + # check qs-score related values + self.assertAlmostEqual(scorer.qs_global, 0.321, 3) + self.assertAlmostEqual(scorer.qs_best, 0.932, 3) + self.assertEqual(len(scorer.qs_interfaces), 1) + self.assertEqual(scorer.qs_interfaces[0], ("A", "B", "A", "B")) + + # should be equal global scores since we're only dealing with + # a single interface + self.assertAlmostEqual(scorer.per_interface_qs_global[0], 0.321, 3) + self.assertAlmostEqual(scorer.per_interface_qs_best[0], 0.932, 3) + + def test_scorer_rigid_scores(self): + mdl = _LoadFile("1eud_mdl_partial-dimer.pdb") + trg = _LoadFile("1eud_ref.pdb") + scorer = Scorer(mdl, trg) + self.assertAlmostEqual(scorer.gdtts, 0.616, 3) + self.assertAlmostEqual(scorer.gdtha, 0.473, 3) + self.assertAlmostEqual(scorer.rmsd, 2.944, 3) + + def test_scorer_contacts(self): + mdl = _LoadFile("1eud_mdl_partial-dimer.pdb") + trg = _LoadFile("1eud_ref.pdb") + scorer = Scorer(mdl, trg) + + self.assertEqual(len(scorer.model_contacts), 48) + self.assertEqual(len(scorer.native_contacts), 140) + + self.assertAlmostEqual(scorer.ics, 0.415, 3) + 
self.assertAlmostEqual(scorer.ics_precision, 0.812, 3) + self.assertAlmostEqual(scorer.ics_recall, 0.279, 3) + self.assertAlmostEqual(scorer.ips, 0.342, 3) + self.assertAlmostEqual(scorer.ips_precision, 0.891, 3) + self.assertAlmostEqual(scorer.ips_recall, 0.357, 3) + + + # per interface scores should be equal since we're only dealing with one + # interface + self.assertEqual(len(scorer.per_interface_ics), 1) + self.assertAlmostEqual(scorer.per_interface_ics[0], 0.415, 3) + self.assertAlmostEqual(scorer.per_interface_ics_precision[0], 0.812, 3) + self.assertAlmostEqual(scorer.per_interface_ics_recall[0], 0.279, 3) + + self.assertEqual(len(scorer.per_interface_ips), 1) + self.assertAlmostEqual(scorer.per_interface_ips[0], 0.342, 3) + self.assertAlmostEqual(scorer.per_interface_ips_precision[0], 0.891, 3) + self.assertAlmostEqual(scorer.per_interface_ips_recall[0], 0.357, 3) + + def test_scorer_dockq(self): + mdl = _LoadFile("1eud_mdl_partial-dimer.pdb") + trg = _LoadFile("1eud_ref.pdb") + scorer = Scorer(mdl, trg) + self.assertEqual(scorer.dockq_interfaces, [("A", "B", "A", "B")]) + self.assertAlmostEqual(scorer.dockq_scores[0], 0.559, 3) + self.assertAlmostEqual(scorer.fnat[0], 0.279, 3) + self.assertAlmostEqual(scorer.fnonnat[0], 0.188, 2) + self.assertAlmostEqual(scorer.irmsd[0], 0.988, 3) + self.assertAlmostEqual(scorer.lrmsd[0], 5.533, 3) + self.assertEqual(scorer.nnat[0], 140) + self.assertEqual(scorer.nmdl[0], 48) + + # ave and wave values are the same as scorer.dockq_scores[0] + self.assertAlmostEqual(scorer.dockq_wave, scorer.dockq_scores[0], 7) + self.assertAlmostEqual(scorer.dockq_ave, scorer.dockq_scores[0], 7) + + def test_scorer_patch_scores(self): + mdl = _LoadFile("1eud_mdl_partial-dimer.pdb") + trg = _LoadFile("1eud_ref.pdb") + scorer = Scorer(mdl, trg) + + patch_qs = scorer.patch_qs + patch_dockq = scorer.patch_dockq + + # check some random values + self.assertAlmostEqual(patch_qs["B"][5], 0.925, 3) + self.assertAlmostEqual(patch_qs["B"][17], 
0.966, 3) + self.assertAlmostEqual(patch_qs["A"][4], 0.973, 3) + + self.assertAlmostEqual(patch_dockq["B"][5], 0.858, 3) + self.assertAlmostEqual(patch_dockq["B"][17], 0.965, 3) + self.assertAlmostEqual(patch_dockq["A"][4], 0.973, 3) + + def test_scorer_trimmed_contacts(self): + mdl = _LoadFile("1eud_mdl_partial-dimer.pdb") + trg = _LoadFile("1eud_mdl_partial-dimer.pdb") + + scorer = Scorer(mdl, trg) + + # mdl and trg are the same + self.assertEqual(scorer.ics, 1.0) + self.assertEqual(scorer.ips, 1.0) + + # let's cut a critical interface loop in the target + trg = trg.Select("cname=A or (cname=B and rnum!=157:168)") + scorer = Scorer(mdl, trg) + + # mdl is now longer than trg which lowers ICS/IPS + self.assertTrue(scorer.ics < 0.75) + self.assertTrue(scorer.ips < 0.75) + + # but if we use the trimmed versions, it should go up to 1.0 + # again + self.assertEqual(scorer.ics_trimmed, 1.0) + self.assertEqual(scorer.ips_trimmed, 1.0) + + # lets see if the trimmed model has the right + # residues missing + for r in scorer.model.residues: + cname = r.GetChain().GetName() + rnum = r.GetNumber() + trimmed_r = scorer.trimmed_model.FindResidue(cname, rnum) + if cname == "B" and (rnum.num >= 157 and rnum.num <= 168): + self.assertFalse(trimmed_r.IsValid()) + else: + self.assertTrue(trimmed_r.IsValid()) + + def test_scorer_tmscore(self): + mdl = _LoadFile("1eud_mdl_partial-dimer.pdb") + trg = _LoadFile("1eud_ref.pdb") + scorer = Scorer(mdl, trg) + self.assertAlmostEqual(scorer.tm_score, 0.711, 3) + +if __name__ == "__main__": + from ost import testutils + if testutils.DefaultCompoundLibIsSet(): + testutils.RunTests() + else: + print('No compound lib available. Ignoring test_scoring.py tests.')