""" Chain mapping aims to identify a one-to-one relationship between chains in a reference structure and a model. """ import itertools import copy import numpy as np from scipy.special import factorial from scipy.special import binom # as of Python 3.8, the math module implements # comb, i.e. n choose k import ost from ost import seq from ost import mol from ost import geom from ost.mol.alg import lddt from ost.mol.alg import qsscore def _CSel(ent, cnames): """ Returns view with specified chains Ensures that quotation marks are around chain names to not confuse OST query language with weird special characters. """ query = "cname=" + ','.join([mol.QueryQuoteName(cname) for cname in cnames]) return ent.Select(query) class MappingResult: """ Result object for the chain mapping functions in :class:`ChainMapper` Constructor is directly called within the functions, no need to construct such objects yourself. """ def __init__(self, target, model, chem_groups, chem_mapping, mapping, alns, opt_score=None): self._target = target self._model = model self._chem_groups = chem_groups self._chem_mapping = chem_mapping self._mapping = mapping self._alns = alns self._opt_score = opt_score @property def target(self): """ Target/reference structure, i.e. :attr:`ChainMapper.target` :type: :class:`ost.mol.EntityView` """ return self._target @property def model(self): """ Model structure that gets mapped onto :attr:`~target` Underwent same processing as :attr:`ChainMapper.target`, i.e. only contains peptide/nucleotide chains of sufficient size. :type: :class:`ost.mol.EntityView` """ return self._model @property def chem_groups(self): """ Groups of chemically equivalent chains in :attr:`~target` Same as :attr:`ChainMapper.chem_group` :class:`list` of :class:`list` of :class:`str` (chain names) """ return self._chem_groups @property def chem_mapping(self): """ Assigns chains in :attr:`~model` to :attr:`~chem_groups`. 
:class:`list` of :class:`list` of :class:`str` (chain names) """ return self._chem_mapping @property def mapping(self): """ Mapping of :attr:`~model` chains onto :attr:`~target` Exact same shape as :attr:`~chem_groups` but containing the names of the mapped chains in :attr:`~model`. May contain None for :attr:`~target` chains that are not covered. No guarantee that all chains in :attr:`~model` are mapped. :class:`list` of :class:`list` of :class:`str` (chain names) """ return self._mapping @property def alns(self): """ Alignments of mapped chains in :attr:`~target` and :attr:`~model` Each alignment is accessible with ``alns[(t_chain,m_chain)]``. First sequence is the sequence of :attr:`target` chain, second sequence the one from :attr:`~model`. The respective :class:`ost.mol.EntityView` are attached with :func:`ost.seq.ConstSequenceHandle.AttachView`. :type: :class:`dict` with key: :class:`tuple` of :class:`str`, value: :class:`ost.seq.AlignmentHandle` """ return self._alns @property def opt_score(self): """ Placeholder property without any guarantee of being set Different scores get optimized in the various chain mapping algorithms. Some of them may set their final optimal score in that property. Consult the documentation of the respective chain mapping algorithm for more information. Won't be in the return dict of :func:`JSONSummary`. """ return self._opt_score def GetFlatMapping(self, mdl_as_key=False): """ Returns flat mapping as :class:`dict` for all mapable chains :param mdl_as_key: Default is target chain name as key and model chain name as value. This can be reversed with this flag. 
:returns: :class:`dict` with :class:`str` as key/value that describe one-to-one mapping """ flat_mapping = dict() for trg_chem_group, mdl_chem_group in zip(self.chem_groups, self.mapping): for a,b in zip(trg_chem_group, mdl_chem_group): if a is not None and b is not None: if mdl_as_key: flat_mapping[b] = a else: flat_mapping[a] = b return flat_mapping def JSONSummary(self): """ Returns JSON serializable summary of results """ json_dict = dict() json_dict["chem_groups"] = self.chem_groups json_dict["mapping"] = self.mapping json_dict["flat_mapping"] = self.GetFlatMapping() json_dict["alns"] = list() for aln in self.alns.values(): trg_seq = aln.GetSequence(0) mdl_seq = aln.GetSequence(1) aln_dict = {"trg_ch": trg_seq.GetName(), "trg_seq": str(trg_seq), "mdl_ch": mdl_seq.GetName(), "mdl_seq": str(mdl_seq)} json_dict["alns"].append(aln_dict) return json_dict class ReprResult: """ Result object for :func:`ChainMapper.GetRepr` Constructor is directly called within the function, no need to construct such objects yourself. :param lDDT: lDDT for this mapping. Depends on how you call :func:`ChainMapper.GetRepr` whether this is backbone only or full atom lDDT. 
:type lDDT: :class:`float` :param substructure: The full substructure for which we searched for a representation :type substructure: :class:`ost.mol.EntityView` :param ref_view: View pointing to the same underlying entity as *substructure* but only contains the stuff that is mapped :type ref_view: :class:`mol.EntityView` :param mdl_view: The matching counterpart in model :type mdl_view: :class:`mol.EntityView` """ def __init__(self, lDDT, substructure, ref_view, mdl_view): self._lDDT = lDDT self._substructure = substructure assert(len(ref_view.residues) == len(mdl_view.residues)) self._ref_view = ref_view self._mdl_view = mdl_view # lazily evaluated attributes self._ref_bb_pos = None self._mdl_bb_pos = None self._ref_full_bb_pos = None self._mdl_full_bb_pos = None self._transform = None self._superposed_mdl_bb_pos = None self._bb_rmsd = None self._gdt_8 = None self._gdt_4 = None self._gdt_2 = None self._gdt_1 = None self._ost_query = None self._flat_mapping = None self._inconsistent_residues = None @property def lDDT(self): """ lDDT of representation result Depends on how you call :func:`ChainMapper.GetRepr` whether this is backbone only or full atom lDDT. 
:type: :class:`float` """ return self._lDDT @property def substructure(self): """ The full substructure for which we searched for a representation :type: :class:`ost.mol.EntityView` """ return self._substructure @property def ref_view(self): """ View which contains the mapped subset of :attr:`substructure` :type: :class:`ost.mol.EntityView` """ return self._ref_view @property def mdl_view(self): """ The :attr:`ref_view` representation in the model :type: :class:`ost.mol.EntityView` """ return self._mdl_view @property def ref_residues(self): """ The reference residues :type: class:`mol.ResidueViewList` """ return self.ref_view.residues @property def mdl_residues(self): """ The model residues :type: :class:`mol.ResidueViewList` """ return self.mdl_view.residues @property def inconsistent_residues(self): """ A list of mapped residue whose names do not match (eg. ALA in the reference and LEU in the model). The mismatches are reported as a tuple of :class:`~ost.mol.ResidueView` (reference, model), or as an empty list if all the residue names match. :type: :class:`list` """ if self._inconsistent_residues is None: self._inconsistent_residues = self._GetInconsistentResidues( self.ref_residues, self.mdl_residues) return self._inconsistent_residues @property def ref_bb_pos(self): """ Representative backbone positions for reference residues. Thats CA positions for peptides and C3' positions for Nucleotides. :type: :class:`geom.Vec3List` """ if self._ref_bb_pos is None: self._ref_bb_pos = self._GetBBPos(self.ref_residues) return self._ref_bb_pos @property def mdl_bb_pos(self): """ Representative backbone positions for model residues. Thats CA positions for peptides and C3' positions for Nucleotides. :type: :class:`geom.Vec3List` """ if self._mdl_bb_pos is None: self._mdl_bb_pos = self._GetBBPos(self.mdl_residues) return self._mdl_bb_pos @property def ref_full_bb_pos(self): """ Representative backbone positions for reference residues. 
Thats N, CA and C positions for peptides and O5', C5', C4', C3', O3' positions for Nucleotides. :type: :class:`geom.Vec3List` """ if self._ref_full_bb_pos is None: self._ref_full_bb_pos = self._GetFullBBPos(self.ref_residues) return self._ref_full_bb_pos @property def mdl_full_bb_pos(self): """ Representative backbone positions for reference residues. Thats N, CA and C positions for peptides and O5', C5', C4', C3', O3' positions for Nucleotides. :type: :class:`geom.Vec3List` """ if self._mdl_full_bb_pos is None: self._mdl_full_bb_pos = self._GetFullBBPos(self.mdl_residues) return self._mdl_full_bb_pos @property def transform(self): """ Transformation to superpose mdl residues onto ref residues Superposition computed as minimal RMSD superposition on :attr:`ref_bb_pos` and :attr:`mdl_bb_pos`. If number of positions is smaller 3, the full_bb_pos equivalents are used instead. :type: :class:`ost.geom.Mat4` """ if self._transform is None: if len(self.mdl_bb_pos) < 3: self._transform = _GetTransform(self.mdl_full_bb_pos, self.ref_full_bb_pos, False) else: self._transform = _GetTransform(self.mdl_bb_pos, self.ref_bb_pos, False) return self._transform @property def superposed_mdl_bb_pos(self): """ :attr:`mdl_bb_pos` with :attr:`transform applied` :type: :class:`geom.Vec3List` """ if self._superposed_mdl_bb_pos is None: self._superposed_mdl_bb_pos = geom.Vec3List(self.mdl_bb_pos) self._superposed_mdl_bb_pos.ApplyTransform(self.transform) return self._superposed_mdl_bb_pos @property def bb_rmsd(self): """ RMSD between :attr:`ref_bb_pos` and :attr:`superposed_mdl_bb_pos` :type: :class:`float` """ if self._bb_rmsd is None: self._bb_rmsd = self.ref_bb_pos.GetRMSD(self.superposed_mdl_bb_pos) return self._bb_rmsd @property def gdt_8(self): """ GDT with one single threshold: 8.0 :type: :class:`float` """ if self._gdt_8 is None: self._gdt_8 = self.ref_bb_pos.GetGDT(self.superposed_mdl_bb_pos, 8.0) return self._gdt_8 @property def gdt_4(self): """ GDT with one single threshold: 4.0 
:type: :class:`float` """ if self._gdt_4 is None: self._gdt_4 = self.ref_bb_pos.GetGDT(self.superposed_mdl_bb_pos, 4.0) return self._gdt_4 @property def gdt_2(self): """ GDT with one single threshold: 2.0 :type: :class:`float` """ if self._gdt_2 is None: self._gdt_2 = self.ref_bb_pos.GetGDT(self.superposed_mdl_bb_pos, 2.0) return self._gdt_2 @property def gdt_1(self): """ GDT with one single threshold: 1.0 :type: :class:`float` """ if self._gdt_1 is None: self._gdt_1 = self.ref_bb_pos.GetGDT(self.superposed_mdl_bb_pos, 1.0) return self._gdt_1 @property def ost_query(self): """ query for mdl residues in OpenStructure query language Repr can be selected as ``full_mdl.Select(ost_query)`` Returns invalid query if residue numbers have insertion codes. :type: :class:`str` """ if self._ost_query is None: chain_rnums = dict() for r in self.mdl_residues: chname = r.GetChain().GetName() rnum = r.GetNumber().GetNum() if chname not in chain_rnums: chain_rnums[chname] = list() chain_rnums[chname].append(str(rnum)) chain_queries = list() for k,v in chain_rnums.items(): q = f"(cname={mol.QueryQuoteName(k)} and " q += f"rnum={','.join(v)})" chain_queries.append(q) self._ost_query = " or ".join(chain_queries) return self._ost_query def JSONSummary(self): """ Returns JSON serializable summary of results """ json_dict = dict() json_dict["lDDT"] = self.lDDT json_dict["ref_residues"] = [r.GetQualifiedName() for r in \ self.ref_residues] json_dict["mdl_residues"] = [r.GetQualifiedName() for r in \ self.mdl_residues] json_dict["transform"] = list(self.transform.data) json_dict["bb_rmsd"] = self.bb_rmsd json_dict["gdt_8"] = self.gdt_8 json_dict["gdt_4"] = self.gdt_4 json_dict["gdt_2"] = self.gdt_2 json_dict["gdt_1"] = self.gdt_1 json_dict["ost_query"] = self.ost_query json_dict["flat_mapping"] = self.GetFlatChainMapping() return json_dict def GetFlatChainMapping(self, mdl_as_key=False): """ Returns flat mapping of all chains in the representation :param mdl_as_key: Default is target chain 
name as key and model chain name as value. This can be reversed with this flag. :returns: :class:`dict` with :class:`str` as key/value that describe one-to-one mapping """ flat_mapping = dict() for trg_res, mdl_res in zip(self.ref_residues, self.mdl_residues): if mdl_as_key: flat_mapping[mdl_res.chain.name] = trg_res.chain.name else: flat_mapping[trg_res.chain.name] = mdl_res.chain.name return flat_mapping def _GetFullBBPos(self, residues): """ Helper to extract full backbone positions """ exp_pep_atoms = ["N", "CA", "C"] exp_nuc_atoms = ["\"O5'\"", "\"C5'\"", "\"C4'\"", "\"C3'\"", "\"O3'\""] bb_pos = geom.Vec3List() for r in residues: if r.GetChemType() == mol.ChemType.NUCLEOTIDES: exp_atoms = exp_nuc_atoms elif r.GetChemType() == mol.ChemType.AMINOACIDS: exp_atoms = exp_pep_atoms else: raise RuntimeError("Something terrible happened... RUN...") for aname in exp_atoms: a = r.FindAtom(aname) if not a.IsValid(): raise RuntimeError("Something terrible happened... " "RUN...") bb_pos.append(a.GetPos()) return bb_pos def _GetBBPos(self, residues): """ Helper to extract single representative position for each residue """ bb_pos = geom.Vec3List() for r in residues: at = r.FindAtom("CA") if not at.IsValid(): at = r.FindAtom("C3'") if not at.IsValid(): raise RuntimeError("Something terrible happened... RUN...") bb_pos.append(at.GetPos()) return bb_pos def _GetInconsistentResidues(self, ref_residues, mdl_residues): """ Helper to extract a list of inconsistent residues. """ if len(ref_residues) != len(mdl_residues): raise ValueError("Something terrible happened... Reference and " "model lengths differ... 
RUN...") inconsistent_residues = list() for ref_residue, mdl_residue in zip(ref_residues, mdl_residues): if ref_residue.name != mdl_residue.name: inconsistent_residues.append((ref_residue, mdl_residue)) return inconsistent_residues class ChainMapper: """ Class to compute chain mappings All algorithms are performed on processed structures which fulfill criteria as given in constructor arguments (*min_pep_length*, "min_nuc_length") and only contain residues which have all required backbone atoms. for peptide residues thats N, CA, C and CB (no CB for GLY), for nucleotide residues thats O5', C5', C4', C3' and O3'. Chain mapping is a three step process: * Group chemically identical chains in *target* using pairwise alignments that are either computed with Needleman-Wunsch (NW) or simply derived from residue numbers (*resnum_alignments* flag). In case of NW, *pep_subst_mat*, *pep_gap_open* and *pep_gap_ext* and their nucleotide equivalents are relevant. Two chains are considered identical if they fulfill the thresholds given by *pep_seqid_thr*, *pep_gap_thr*, their nucleotide equivalents respectively. The grouping information is available as attributes of this class. * Map chains in an input model to these groups. Generating alignments and the similarity criteria are the same as above. You can either get the group mapping with :func:`GetChemMapping` or directly call one of the full fletched one-to-one chain mapping functions which execute that step internally. * Obtain one-to-one mapping for chains in an input model and *target* with one of the available mapping functions. Just to get an idea of complexity. If *target* and *model* are octamers, there are ``8! = 40320`` possible chain mappings. :param target: Target structure onto which models are mapped. Computations happen on a selection only containing polypeptides and polynucleotides. 
:type target: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` :param resnum_alignments: Use residue numbers instead of Needleman-Wunsch to compute pairwise alignments. Relevant for :attr:`~chem_groups` and related attributes. :type resnum_alignments: :class:`bool` :param pep_seqid_thr: Threshold used to decide when two chains are identical. 95 percent tolerates the few mutations crystallographers like to do. :type pep_seqid_thr: :class:`float` :param pep_gap_thr: Additional threshold to avoid gappy alignments with high seqid. By default this is disabled (set to 1.0). This threshold checks for a maximum allowed fraction of gaps in any of the two sequences after stripping terminal gaps. The reason for not just normalizing seqid by the longer sequence is that one sequence might be a perfect subsequence of the other but only cover half of it. :type pep_gap_thr: :class:`float` :param nuc_seqid_thr: Nucleotide equivalent for *pep_seqid_thr* :type nuc_seqid_thr: :class:`float` :param nuc_gap_thr: Nucleotide equivalent for *nuc_gap_thr* :type nuc_gap_thr: :class:`float` :param pep_subst_mat: Substitution matrix to align peptide sequences, irrelevant if *resnum_alignments* is True, defaults to seq.alg.BLOSUM62 :type pep_subst_mat: :class:`ost.seq.alg.SubstWeightMatrix` :param pep_gap_open: Gap open penalty to align peptide sequences, irrelevant if *resnum_alignments* is True :type pep_gap_open: :class:`int` :param pep_gap_ext: Gap extension penalty to align peptide sequences, irrelevant if *resnum_alignments* is True :type pep_gap_ext: :class:`int` :param nuc_subst_mat: Nucleotide equivalent for *pep_subst_mat*, defaults to seq.alg.NUC44 :type nuc_subst_mat: :class:`ost.seq.alg.SubstWeightMatrix` :param nuc_gap_open: Nucleotide equivalent for *pep_gap_open* :type nuc_gap_open: :class:`int` :param nuc_gap_ext: Nucleotide equivalent for *pep_gap_ext* :type nuc_gap_ext: :class:`int` :param min_pep_length: Minimal number of residues for a peptide chain to be 
considered in target and in models. :type min_pep_length: :class:`int` :param min_nuc_length: Minimal number of residues for a nucleotide chain to be considered in target and in models. :type min_nuc_length: :class:`int` :param n_max_naive: Max possible chain mappings that are enumerated in :func:`~GetNaivelDDTMapping` / :func:`~GetDecomposerlDDTMapping`. A :class:`RuntimeError` is raised in case of bigger complexity. :type n_max_naive: :class:`int` """ def __init__(self, target, resnum_alignments=False, pep_seqid_thr = 95., pep_gap_thr = 1.0, nuc_seqid_thr = 95., nuc_gap_thr = 1.0, pep_subst_mat = seq.alg.BLOSUM62, pep_gap_open = -11, pep_gap_ext = -1, nuc_subst_mat = seq.alg.NUC44, nuc_gap_open = -4, nuc_gap_ext = -4, min_pep_length = 10, min_nuc_length = 4, n_max_naive = 1e8): # attributes self.resnum_alignments = resnum_alignments self.pep_seqid_thr = pep_seqid_thr self.pep_gap_thr = pep_gap_thr self.nuc_seqid_thr = nuc_seqid_thr self.nuc_gap_thr = nuc_gap_thr self.min_pep_length = min_pep_length self.min_nuc_length = min_nuc_length self.n_max_naive = n_max_naive # lazy computed attributes self._chem_groups = None self._chem_group_alignments = None self._chem_group_ref_seqs = None self._chem_group_types = None # helper class to generate pairwise alignments self.aligner = _Aligner(resnum_aln = resnum_alignments, pep_subst_mat = pep_subst_mat, pep_gap_open = pep_gap_open, pep_gap_ext = pep_gap_ext, nuc_subst_mat = nuc_subst_mat, nuc_gap_open = nuc_gap_open, nuc_gap_ext = nuc_gap_ext) # target structure preprocessing self._target, self._polypep_seqs, self._polynuc_seqs = \ self.ProcessStructure(target) @property def target(self): """Target structure that only contains peptides/nucleotides Contains only residues that have the backbone representatives (CA for peptide and C3' for nucleotides) to avoid ATOMSEQ alignment inconsistencies when switching between all atom and backbone only representations. 
:type: :class:`ost.mol.EntityView` """ return self._target @property def polypep_seqs(self): """Sequences of peptide chains in :attr:`~target` Respective :class:`EntityView` from *target* for each sequence s are available as ``s.GetAttachedView()`` :type: :class:`ost.seq.SequenceList` """ return self._polypep_seqs @property def polynuc_seqs(self): """Sequences of nucleotide chains in :attr:`~target` Respective :class:`EntityView` from *target* for each sequence s are available as ``s.GetAttachedView()`` :type: :class:`ost.seq.SequenceList` """ return self._polynuc_seqs @property def chem_groups(self): """Groups of chemically equivalent chains in :attr:`~target` First chain in group is the one with longest sequence. :getter: Computed on first use (cached) :type: :class:`list` of :class:`list` of :class:`str` (chain names) """ if self._chem_groups is None: self._chem_groups = list() for a in self.chem_group_alignments: self._chem_groups.append([s.GetName() for s in a.sequences]) return self._chem_groups @property def chem_group_alignments(self): """MSA for each group in :attr:`~chem_groups` Sequences in MSAs exhibit same order as in :attr:`~chem_groups` and have the respective :class:`ost.mol.EntityView` from *target* attached. 
:getter: Computed on first use (cached) :type: :class:`ost.seq.AlignmentList` """ if self._chem_group_alignments is None: self._chem_group_alignments, self._chem_group_types = \ _GetChemGroupAlignments(self.polypep_seqs, self.polynuc_seqs, self.aligner, pep_seqid_thr=self.pep_seqid_thr, pep_gap_thr=self.pep_gap_thr, nuc_seqid_thr=self.nuc_seqid_thr, nuc_gap_thr=self.nuc_gap_thr) return self._chem_group_alignments @property def chem_group_ref_seqs(self): """Reference (longest) sequence for each group in :attr:`~chem_groups` Respective :class:`EntityView` from *target* for each sequence s are available as ``s.GetAttachedView()`` :getter: Computed on first use (cached) :type: :class:`ost.seq.SequenceList` """ if self._chem_group_ref_seqs is None: self._chem_group_ref_seqs = seq.CreateSequenceList() for a in self.chem_group_alignments: s = seq.CreateSequence(a.GetSequence(0).GetName(), a.GetSequence(0).GetGaplessString()) s.AttachView(a.GetSequence(0).GetAttachedView()) self._chem_group_ref_seqs.AddSequence(s) return self._chem_group_ref_seqs @property def chem_group_types(self): """ChemType of each group in :attr:`~chem_groups` Specifying if groups are poly-peptides/nucleotides, i.e. :class:`ost.mol.ChemType.AMINOACIDS` or :class:`ost.mol.ChemType.NUCLEOTIDES` :getter: Computed on first use (cached) :type: :class:`list` of :class:`ost.mol.ChemType` """ if self._chem_group_types is None: self._chem_group_alignments, self._chem_group_types = \ _GetChemGroupAlignments(self.polypep_seqs, self.polynuc_seqs, self.aligner, pep_seqid_thr=self.pep_seqid_thr, pep_gap_thr=self.pep_gap_thr, nuc_seqid_thr=self.nuc_seqid_thr, nuc_gap_thr=self.nuc_gap_thr) return self._chem_group_types def GetChemMapping(self, model): """Maps sequences in *model* to chem_groups of target :param model: Model from which to extract sequences, a selection that only includes peptides and nucleotides is performed and returned along other results. 
:type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` :returns: Tuple with two lists of length `len(self.chem_groups)` and an :class:`ost.mol.EntityView` representing *model*: 1) Each element is a :class:`list` with mdl chain names that map to the chem group at that position. 2) Each element is a :class:`ost.seq.AlignmentList` aligning these mdl chain sequences to the chem group ref sequences. 3) A selection of *model* that only contains polypeptides and polynucleotides whose ATOMSEQ exactly matches the sequence info in the returned alignments. """ mdl, mdl_pep_seqs, mdl_nuc_seqs = self.ProcessStructure(model) mapping = [list() for x in self.chem_groups] alns = [seq.AlignmentList() for x in self.chem_groups] for s in mdl_pep_seqs: idx, aln = _MapSequence(self.chem_group_ref_seqs, self.chem_group_types, s, mol.ChemType.AMINOACIDS, self.aligner) if idx is not None: mapping[idx].append(s.GetName()) alns[idx].append(aln) for s in mdl_nuc_seqs: idx, aln = _MapSequence(self.chem_group_ref_seqs, self.chem_group_types, s, mol.ChemType.NUCLEOTIDES, self.aligner) if idx is not None: mapping[idx].append(s.GetName()) alns[idx].append(aln) return (mapping, alns, mdl) def GetlDDTMapping(self, model, inclusion_radius=15.0, thresholds=[0.5, 1.0, 2.0, 4.0], strategy="naive", steep_opt_rate = None, block_seed_size = 5, block_blocks_per_chem_group = 5, chem_mapping_result = None): """ Identify chain mapping by optimizing lDDT score Maps *model* chain sequences to :attr:`~chem_groups` and find mapping based on backbone only lDDT score (CA for amino acids C3' for Nucleotides). Either performs a naive search, i.e. enumerate all possible mappings or executes a greedy strategy that tries to identify a (close to) optimal mapping in an iterative way by starting from a start mapping (seed). 
In each iteration, the one-to-one mapping that leads to highest increase in number of conserved contacts is added with the additional requirement that this added mapping must have non-zero interface counts towards the already mapped chains. So basically we're "growing" the mapped structure by only adding connected stuff. The available strategies: * **naive**: Enumerates all possible mappings and returns best * **greedy_fast**: perform all vs. all single chain lDDTs within the respective ref/mdl chem groups. The mapping with highest number of conserved contacts is selected as seed for greedy extension * **greedy_full**: try multiple seeds for greedy extension, i.e. try all ref/mdl chain combinations within the respective chem groups and retain the mapping leading to the best lDDT. * **greedy_block**: try multiple seeds for greedy extension, i.e. try all ref/mdl chain combinations within the respective chem groups and extend them to *block_seed_size*. *block_blocks_per_chem_group* for each chem group are selected for exhaustive extension. Sets :attr:`MappingResult.opt_score` in case of no trivial one-to-one mapping. :param model: Model to map :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` :param inclusion_radius: Inclusion radius for lDDT :type inclusion_radius: :class:`float` :param thresholds: Thresholds for lDDT :type thresholds: :class:`list` of :class:`float` :param strategy: Strategy to find mapping. Must be in ["naive", "greedy_fast", "greedy_full", "greedy_block"] :type strategy: :class:`str` :param steep_opt_rate: Only relevant for greedy strategies. If set, every *steep_opt_rate* mappings, a simple optimization is executed with the goal of avoiding local minima. The optimization iteratively checks all possible swaps of mappings within their respective chem groups and accepts swaps that improve lDDT score. Iteration stops as soon as no improvement can be achieved anymore. 
:type steep_opt_rate: :class:`int` :param block_seed_size: Param for *greedy_block* strategy - Initial seeds are extended by that number of chains. :type block_seed_size: :class:`int` :param block_blocks_per_chem_group: Param for *greedy_block* strategy - Number of blocks per chem group that are extended in an initial search for high scoring local solutions. :type block_blocks_per_chem_group: :class:`int` :param chem_mapping_result: Pro param. The result of :func:`~GetChemMapping` where you provided *model*. If set, *model* parameter is not used. :type chem_mapping_result: :class:`tuple` :returns: A :class:`MappingResult` """ strategies = ["naive", "greedy_fast", "greedy_full", "greedy_block"] if strategy not in strategies: raise RuntimeError(f"Strategy must be in {strategies}") if chem_mapping_result is None: chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model) else: chem_mapping, chem_group_alns, mdl = chem_mapping_result ref_mdl_alns = _GetRefMdlAlns(self.chem_groups, self.chem_group_alignments, chem_mapping, chem_group_alns) # check for the simplest case one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping) if one_to_one is not None: alns = dict() for ref_group, mdl_group in zip(self.chem_groups, one_to_one): for ref_ch, mdl_ch in zip(ref_group, mdl_group): if ref_ch is not None and mdl_ch is not None: aln = ref_mdl_alns[(ref_ch, mdl_ch)] aln.AttachView(0, _CSel(self.target, [ref_ch])) aln.AttachView(1, _CSel(mdl, [mdl_ch])) alns[(ref_ch, mdl_ch)] = aln return MappingResult(self.target, mdl, self.chem_groups, chem_mapping, one_to_one, alns) mapping = None opt_lddt = None if strategy == "naive": mapping, opt_lddt = _lDDTNaive(self.target, mdl, inclusion_radius, thresholds, self.chem_groups, chem_mapping, ref_mdl_alns, self.n_max_naive) else: # its one of the greedy strategies - setup greedy searcher the_greed = _lDDTGreedySearcher(self.target, mdl, self.chem_groups, chem_mapping, ref_mdl_alns, inclusion_radius=inclusion_radius, 
thresholds=thresholds, steep_opt_rate=steep_opt_rate) if strategy == "greedy_fast": mapping = _lDDTGreedyFast(the_greed) elif strategy == "greedy_full": mapping = _lDDTGreedyFull(the_greed) elif strategy == "greedy_block": mapping = _lDDTGreedyBlock(the_greed, block_seed_size, block_blocks_per_chem_group) # cached => lDDT computation is fast here opt_lddt = the_greed.lDDT(self.chem_groups, mapping) alns = dict() for ref_group, mdl_group in zip(self.chem_groups, mapping): for ref_ch, mdl_ch in zip(ref_group, mdl_group): if ref_ch is not None and mdl_ch is not None: aln = ref_mdl_alns[(ref_ch, mdl_ch)] aln.AttachView(0, _CSel(self.target, [ref_ch])) aln.AttachView(1, _CSel(mdl, [mdl_ch])) alns[(ref_ch, mdl_ch)] = aln return MappingResult(self.target, mdl, self.chem_groups, chem_mapping, mapping, alns, opt_score = opt_lddt) def GetQSScoreMapping(self, model, contact_d = 12.0, strategy = "naive", block_seed_size = 5, block_blocks_per_chem_group = 5, steep_opt_rate = None, chem_mapping_result = None, greedy_prune_contact_map = False): """ Identify chain mapping based on QSScore Scoring is based on CA/C3' positions which are present in all chains of a :attr:`chem_groups` as well as the *model* chains which are mapped to that respective chem group. The following strategies are available: * **naive**: Naively iterate all possible mappings and return best based on QS score. * **greedy_fast**: perform all vs. all single chain lDDTs within the respective ref/mdl chem groups. The mapping with highest number of conserved contacts is selected as seed for greedy extension. Extension is based on QS-score. * **greedy_full**: try multiple seeds for greedy extension, i.e. try all ref/mdl chain combinations within the respective chem groups and retain the mapping leading to the best QS-score. * **greedy_block**: try multiple seeds for greedy extension, i.e. try all ref/mdl chain combinations within the respective chem groups and extend them to *block_seed_size*. 
*block_blocks_per_chem_group* for each chem group are selected for exhaustive extension. Sets :attr:`MappingResult.opt_score` in case of no trivial one-to-one mapping. :param model: Model to map :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` :param contact_d: Max distance between two residues to be considered as contact in qs scoring :type contact_d: :class:`float` :param strategy: Strategy for sampling, must be in ["naive"] :type strategy: :class:`str` :param chem_mapping_result: Pro param. The result of :func:`~GetChemMapping` where you provided *model*. If set, *model* parameter is not used. :type chem_mapping_result: :class:`tuple` :param greedy_prune_contact_map: Relevant for all strategies that use greedy extensions. If True, only chains with at least 3 contacts (8A CB distance) towards already mapped chains in trg/mdl are considered for extension. All chains that give a potential non-zero QS-score increase are used otherwise (at least one contact within 12A). The consequence is reduced runtime and usually no real reduction in accuracy. 
:returns: A :class:`MappingResult` """ strategies = ["naive", "greedy_fast", "greedy_full", "greedy_block"] if strategy not in strategies: raise RuntimeError(f"strategy must be {strategies}") if chem_mapping_result is None: chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model) else: chem_mapping, chem_group_alns, mdl = chem_mapping_result ref_mdl_alns = _GetRefMdlAlns(self.chem_groups, self.chem_group_alignments, chem_mapping, chem_group_alns) # check for the simplest case one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping) if one_to_one is not None: alns = dict() for ref_group, mdl_group in zip(self.chem_groups, one_to_one): for ref_ch, mdl_ch in zip(ref_group, mdl_group): if ref_ch is not None and mdl_ch is not None: aln = ref_mdl_alns[(ref_ch, mdl_ch)] aln.AttachView(0, _CSel(self.target, [ref_ch])) aln.AttachView(1, _CSel(mdl, [mdl_ch])) alns[(ref_ch, mdl_ch)] = aln return MappingResult(self.target, mdl, self.chem_groups, chem_mapping, one_to_one, alns) mapping = None opt_qsscore = None if strategy == "naive": mapping, opt_qsscore = _QSScoreNaive(self.target, mdl, self.chem_groups, chem_mapping, ref_mdl_alns, contact_d, self.n_max_naive) else: # its one of the greedy strategies - setup greedy searcher the_greed = _QSScoreGreedySearcher(self.target, mdl, self.chem_groups, chem_mapping, ref_mdl_alns, contact_d = contact_d, steep_opt_rate=steep_opt_rate, greedy_prune_contact_map = greedy_prune_contact_map) if strategy == "greedy_fast": mapping = _QSScoreGreedyFast(the_greed) elif strategy == "greedy_full": mapping = _QSScoreGreedyFull(the_greed) elif strategy == "greedy_block": mapping = _QSScoreGreedyBlock(the_greed, block_seed_size, block_blocks_per_chem_group) # cached => QSScore computation is fast here opt_qsscore = the_greed.Score(mapping, check=False) alns = dict() for ref_group, mdl_group in zip(self.chem_groups, mapping): for ref_ch, mdl_ch in zip(ref_group, mdl_group): if ref_ch is not None and mdl_ch is not None: aln = 
ref_mdl_alns[(ref_ch, mdl_ch)] aln.AttachView(0, _CSel(self.target, [ref_ch])) aln.AttachView(1, _CSel(mdl, [mdl_ch])) alns[(ref_ch, mdl_ch)] = aln return MappingResult(self.target, mdl, self.chem_groups, chem_mapping, mapping, alns, opt_score = opt_qsscore) def GetRigidMapping(self, model, strategy = "greedy_single_gdtts", single_chain_gdtts_thresh=0.4, subsampling=None, first_complete=False, iterative_superposition=False, chem_mapping_result = None): """Identify chain mapping based on rigid superposition Superposition and scoring is based on CA/C3' positions which are present in all chains of a :attr:`chem_groups` as well as the *model* chains which are mapped to that respective chem group. Transformations to superpose *model* onto :attr:`ChainMapper.target` are estimated using all possible combinations of target and model chains within the same chem groups and build the basis for further extension. There are four extension strategies: * **greedy_single_gdtts**: Iteratively add the model/target chain pair that adds the most conserved contacts based on the GDT-TS metric (Number of CA/C3' atoms within [8, 4, 2, 1] Angstrom). The mapping with highest GDT-TS score is returned. However, that mapping is not guaranteed to be complete (see *single_chain_gdtts_thresh*). * **greedy_iterative_gdtts**: Same as greedy_single_gdtts except that the transformation gets updated with each added chain pair. * **greedy_single_rmsd**: Conceptually similar to greedy_single_gdtts but the added chain pairs are the ones with lowest RMSD. The mapping with lowest overall RMSD gets returned. *single_chain_gdtts_thresh* is only applied to derive the initial transformations. After that, the minimal RMSD chain pair gets iteratively added without applying any threshold. * **greedy_iterative_rmsd**: Same as greedy_single_rmsd exept that the transformation gets updated with each added chain pair. *single_chain_gdtts_thresh* is only applied to derive the initial transformations. 
After that, the minimal RMSD chain pair gets iteratively added without applying any threshold. :param model: Model to map :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` :param strategy: Strategy to extend mappings from initial transforms, see description above. Must be in ["greedy_single", "greedy_iterative", "greedy_iterative_rmsd"] :type strategy: :class:`str` :param single_chain_gdtts_thresh: Minimal GDT-TS score for model/target chain pair to be added to mapping. Mapping extension for a given transform stops when no pair fulfills this threshold, potentially leading to an incomplete mapping. :type single_chain_gdtts_thresh: :class:`float` :param subsampling: If given, only use an equally distributed subset of all CA/C3' positions for superposition/scoring. :type subsampling: :class:`int` :param first_complete: Avoid full enumeration and return first found mapping that covers all model chains or all target chains. Has no effect on greedy_iterative_rmsd strategy. :type first_complete: :class:`bool` :param iterative_superposition: Whether to compute inital transformations with :func:`ost.mol.alg.IterativeSuperposeSVD` as oposed to :func:`ost.mol.alg.SuperposeSVD` :type iterative_superposition: :class:`bool` :param chem_mapping_result: Pro param. The result of :func:`~GetChemMapping` where you provided *model*. If set, *model* parameter is not used. 
:type chem_mapping_result: :class:`tuple` :returns: A :class:`MappingResult` """ strategies = ["greedy_single_gdtts", "greedy_iterative_gdtts", "greedy_single_rmsd", "greedy_iterative_rmsd"] if strategy not in strategies: raise RuntimeError(f"strategy must be {strategies}") if chem_mapping_result is None: chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model) else: chem_mapping, chem_group_alns, mdl = chem_mapping_result ref_mdl_alns = _GetRefMdlAlns(self.chem_groups, self.chem_group_alignments, chem_mapping, chem_group_alns) # check for the simplest case one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping) if one_to_one is not None: alns = dict() for ref_group, mdl_group in zip(self.chem_groups, one_to_one): for ref_ch, mdl_ch in zip(ref_group, mdl_group): if ref_ch is not None and mdl_ch is not None: aln = ref_mdl_alns[(ref_ch, mdl_ch)] aln.AttachView(0, _CSel(self.target, [ref_ch])) aln.AttachView(1, _CSel(mdl, [mdl_ch])) alns[(ref_ch, mdl_ch)] = aln return MappingResult(self.target, mdl, self.chem_groups, chem_mapping, one_to_one, alns) trg_group_pos, mdl_group_pos = _GetRefPos(self.target, mdl, self.chem_group_alignments, chem_group_alns, max_pos = subsampling) # get transforms of any mdl chain onto any trg chain in same chem group # that fulfills gdtts threshold initial_transforms = list() initial_mappings = list() for trg_pos, trg_chains, mdl_pos, mdl_chains in zip(trg_group_pos, self.chem_groups, mdl_group_pos, chem_mapping): for t_pos, t in zip(trg_pos, trg_chains): for m_pos, m in zip(mdl_pos, mdl_chains): if len(t_pos) >= 3 and len(m_pos) >= 3: transform = _GetTransform(m_pos, t_pos, iterative_superposition) t_m_pos = geom.Vec3List(m_pos) t_m_pos.ApplyTransform(transform) gdt = t_pos.GetGDTTS(t_m_pos) if gdt >= single_chain_gdtts_thresh: initial_transforms.append(transform) initial_mappings.append((t,m)) if strategy == "greedy_single_gdtts": mapping = _SingleRigidGDTTS(initial_transforms, initial_mappings, self.chem_groups, 
chem_mapping, trg_group_pos, mdl_group_pos, single_chain_gdtts_thresh, iterative_superposition, first_complete, len(self.target.chains), len(mdl.chains)) elif strategy == "greedy_iterative_gdtts": mapping = _IterativeRigidGDTTS(initial_transforms, initial_mappings, self.chem_groups, chem_mapping, trg_group_pos, mdl_group_pos, single_chain_gdtts_thresh, iterative_superposition, first_complete, len(self.target.chains), len(mdl.chains)) elif strategy == "greedy_single_rmsd": mapping = _SingleRigidRMSD(initial_transforms, initial_mappings, self.chem_groups, chem_mapping, trg_group_pos, mdl_group_pos, iterative_superposition) elif strategy == "greedy_iterative_rmsd": mapping = _IterativeRigidRMSD(initial_transforms, initial_mappings, self.chem_groups, chem_mapping, trg_group_pos, mdl_group_pos, iterative_superposition) # translate mapping format and return final_mapping = list() for ref_chains in self.chem_groups: mapped_mdl_chains = list() for ref_ch in ref_chains: if ref_ch in mapping: mapped_mdl_chains.append(mapping[ref_ch]) else: mapped_mdl_chains.append(None) final_mapping.append(mapped_mdl_chains) alns = dict() for ref_group, mdl_group in zip(self.chem_groups, final_mapping): for ref_ch, mdl_ch in zip(ref_group, mdl_group): if ref_ch is not None and mdl_ch is not None: aln = ref_mdl_alns[(ref_ch, mdl_ch)] aln.AttachView(0, _CSel(self.target, [ref_ch])) aln.AttachView(1, _CSel(mdl, [mdl_ch])) alns[(ref_ch, mdl_ch)] = aln return MappingResult(self.target, mdl, self.chem_groups, chem_mapping, final_mapping, alns) def GetMapping(self, model, n_max_naive = 40320): """ Convenience function to get mapping with currently preferred method If number of possible chain mappings is <= *n_max_naive*, a naive QS-score mapping is performed and optimal QS-score is guaranteed. For anything else, a QS-score mapping with the greedy_full strategy is performed (greedy_prune_contact_map = True). The default for *n_max_naive* of 40320 corresponds to an octamer (8!=40320). 
A structure with stoichiometry A6B2 would be 6!*2!=1440 etc. """ chem_mapping_res = self.GetChemMapping(model) if _NMappingsWithin(self.chem_groups, chem_mapping_res[0], n_max_naive): return self.GetQSScoreMapping(model, strategy="naive", chem_mapping_result=chem_mapping_res) else: return self.GetQSScoreMapping(model, strategy="greedy_full", greedy_prune_contact_map=True, chem_mapping_result=chem_mapping_res) def GetRepr(self, substructure, model, topn=1, inclusion_radius=15.0, thresholds=[0.5, 1.0, 2.0, 4.0], bb_only=False, only_interchain=False, chem_mapping_result = None, global_mapping = None): """ Identify *topn* representations of *substructure* in *model* *substructure* defines a subset of :attr:`~target` for which one wants the *topn* representations in *model*. Representations are scored and sorted by lDDT. :param substructure: A :class:`ost.mol.EntityView` which is a subset of :attr:`~target`. Should be selected with the OpenStructure query language. Example: if you're interested in residues with number 42,43 and 85 in chain A: ``substructure=mapper.target.Select("cname=A and rnum=42,43,85")`` A :class:`RuntimeError` is raised if *substructure* does not refer to the same underlying :class:`ost.mol.EntityHandle` as :attr:`~target`. :type substructure: :class:`ost.mol.EntityView` :param model: Structure in which one wants to find representations for *substructure* :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` :param topn: Max number of representations that are returned :type topn: :class:`int` :param inclusion_radius: Inclusion radius for lDDT :type inclusion_radius: :class:`float` :param thresholds: Thresholds for lDDT :type thresholds: :class:`list` of :class:`float` :param bb_only: Only consider backbone atoms in lDDT computation :type bb_only: :class:`bool` :param only_interchain: Only score interchain contacts in lDDT. Useful if you want to identify interface patches. 
:type only_interchain: :class:`bool` :param chem_mapping_result: Pro param. The result of :func:`~GetChemMapping` where you provided *model*. If set, *model* parameter is not used. :type chem_mapping_result: :class:`tuple` :param global_mapping: Pro param. Specify a global mapping result. This fully defines the desired representation in the model but extracts it and enriches it with all the nice attributes of :class:`ReprResult`. The target attribute in *global_mapping* must be of the same entity as self.target and the model attribute of *global_mapping* must be of the same entity as *model*. :type global_mapping: :class:`MappingResult` :returns: :class:`list` of :class:`ReprResult` """ if topn < 1: raise RuntimeError("topn must be >= 1") if global_mapping is not None: # ensure that this mapping is derived from the same structures if global_mapping.target.handle.GetHashCode() != \ self.target.handle.GetHashCode(): raise RuntimeError("global_mapping.target must be the same " "entity as self.target") if global_mapping.model.handle.GetHashCode() != \ model.handle.GetHashCode(): raise RuntimeError("global_mapping.model must be the same " "entity as model param") # check whether substructure really is a subset of self.target for r in substructure.residues: ch_name = r.GetChain().GetName() rnum = r.GetNumber() target_r = self.target.FindResidue(ch_name, rnum) if not target_r.IsValid(): raise RuntimeError(f"substructure has residue " f"{r.GetQualifiedName()} which is not in " f"self.target") if target_r.handle.GetHashCode() != r.handle.GetHashCode(): raise RuntimeError(f"substructure has residue " f"{r.GetQualifiedName()} which has an " f"equivalent in self.target but it does " f"not refer to the same underlying " f"EntityHandle") for a in r.atoms: target_a = target_r.FindAtom(a.GetName()) if not target_a.IsValid(): raise RuntimeError(f"substructure has atom " f"{a.GetQualifiedName()} which is not " f"in self.target") if a.handle.GetHashCode() != 
target_a.handle.GetHashCode(): raise RuntimeError(f"substructure has atom " f"{a.GetQualifiedName()} which has an " f"equivalent in self.target but it does " f"not refer to the same underlying " f"EntityHandle") # check whether it contains either CA or C3' ca = r.FindAtom("CA") c3 = r.FindAtom("C3'") # FindAtom with prime in string is tested # and works if not ca.IsValid() and not c3.IsValid(): raise RuntimeError("All residues in substructure must contain " "a backbone atom named CA or C3\'") # perform mapping and alignments on full structures if chem_mapping_result is None: chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model) else: chem_mapping, chem_group_alns, mdl = chem_mapping_result ref_mdl_alns = _GetRefMdlAlns(self.chem_groups, self.chem_group_alignments, chem_mapping, chem_group_alns) # Get residue indices relative to full target chain substructure_res_indices = dict() for ch in substructure.chains: full_ch = self.target.FindChain(ch.GetName()) idx = [full_ch.GetResidueIndex(r.GetNumber()) for r in ch.residues] substructure_res_indices[ch.GetName()] = idx # strip down variables to make them specific to substructure # keep only chem_groups which are present in substructure substructure_chem_groups = list() substructure_chem_mapping = list() chnames = set([ch.GetName() for ch in substructure.chains]) for chem_group, mapping in zip(self.chem_groups, chem_mapping): substructure_chem_group = [ch for ch in chem_group if ch in chnames] if len(substructure_chem_group) > 0: substructure_chem_groups.append(substructure_chem_group) substructure_chem_mapping.append(mapping) # early stopping if no mdl chain can be mapped to substructure n_mapped_mdl_chains = sum([len(m) for m in substructure_chem_mapping]) if n_mapped_mdl_chains == 0: return list() # strip the reference sequence in alignments to only contain # sequence from substructure substructure_ref_mdl_alns = dict() mdl_views = dict() for ch in mdl.chains: mdl_views[ch.GetName()] = _CSel(mdl, 
[ch.GetName()]) for chem_group, mapping in zip(substructure_chem_groups, substructure_chem_mapping): for ref_ch in chem_group: for mdl_ch in mapping: full_aln = ref_mdl_alns[(ref_ch, mdl_ch)] ref_seq = full_aln.GetSequence(0) # the ref sequence is tricky... we start with a gap only # sequence and only add olcs as defined by the residue # indices that we extracted before... tmp = ['-'] * len(full_aln) for idx in substructure_res_indices[ref_ch]: idx_in_seq = ref_seq.GetPos(idx) tmp[idx_in_seq] = ref_seq[idx_in_seq] ref_seq = seq.CreateSequence(ref_ch, ''.join(tmp)) ref_seq.AttachView(_CSel(substructure, [ref_ch])) mdl_seq = full_aln.GetSequence(1) mdl_seq = seq.CreateSequence(mdl_seq.GetName(), mdl_seq.GetString()) mdl_seq.AttachView(mdl_views[mdl_ch]) aln = seq.CreateAlignment() aln.AddSequence(ref_seq) aln.AddSequence(mdl_seq) substructure_ref_mdl_alns[(ref_ch, mdl_ch)] = aln lddt_scorer = lddt.lDDTScorer(substructure, inclusion_radius = inclusion_radius, bb_only = bb_only) scored_mappings = list() if global_mapping: # construct mapping of substructure from global mapping flat_mapping = global_mapping.GetFlatMapping() mapping = list() for chem_group, chem_mapping in zip(substructure_chem_groups, substructure_chem_mapping): chem_group_mapping = list() for ch in chem_group: if ch in flat_mapping: mdl_ch = flat_mapping[ch] if mdl_ch in chem_mapping: chem_group_mapping.append(mdl_ch) else: chem_group_mapping.append(None) else: chem_group_mapping.append(None) mapping.append(chem_group_mapping) mappings = [mapping] else: mappings = list(_ChainMappings(substructure_chem_groups, substructure_chem_mapping, self.n_max_naive)) for mapping in mappings: # chain_mapping and alns as input for lDDT computation lddt_chain_mapping = dict() lddt_alns = dict() n_res_aln = 0 for ref_chem_group, mdl_chem_group in zip(substructure_chem_groups, mapping): for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group): # some mdl chains can be None if mdl_ch is not None: 
lddt_chain_mapping[mdl_ch] = ref_ch aln = substructure_ref_mdl_alns[(ref_ch, mdl_ch)] lddt_alns[mdl_ch] = aln tmp = [int(c[0] != '-' and c[1] != '-') for c in aln] n_res_aln += sum(tmp) # don't compute lDDT if no single residue in mdl and ref is aligned if n_res_aln == 0: continue lDDT, _ = lddt_scorer.lDDT(mdl, thresholds=thresholds, chain_mapping=lddt_chain_mapping, residue_mapping = lddt_alns, check_resnames = False, no_intrachain = only_interchain) if lDDT is None: ost.LogVerbose("No valid contacts in the reference") lDDT = 0.0 # that means, that we have not a single valid contact # in lDDT. For the code below to work, we just set it # to a terrible score => 0.0 if len(scored_mappings) == 0: scored_mappings.append((lDDT, mapping)) elif len(scored_mappings) < topn: scored_mappings.append((lDDT, mapping)) scored_mappings.sort(reverse=True, key=lambda x: x[0]) elif lDDT > scored_mappings[-1][0]: scored_mappings.append((lDDT, mapping)) scored_mappings.sort(reverse=True, key=lambda x: x[0]) scored_mappings = scored_mappings[:topn] # finalize and return results = list() for scored_mapping in scored_mappings: ref_view = substructure.handle.CreateEmptyView() mdl_view = mdl.handle.CreateEmptyView() for ref_ch_group, mdl_ch_group in zip(substructure_chem_groups, scored_mapping[1]): for ref_ch, mdl_ch in zip(ref_ch_group, mdl_ch_group): if ref_ch is not None and mdl_ch is not None: aln = substructure_ref_mdl_alns[(ref_ch, mdl_ch)] for col in aln: if col[0] != '-' and col[1] != '-': ref_view.AddResidue(col.GetResidue(0), mol.ViewAddFlag.INCLUDE_ALL) mdl_view.AddResidue(col.GetResidue(1), mol.ViewAddFlag.INCLUDE_ALL) results.append(ReprResult(scored_mapping[0], substructure, ref_view, mdl_view)) return results def GetNMappings(self, model): """ Returns number of possible mappings :param model: Model with chains that are mapped onto :attr:`chem_groups` :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` """ chem_mapping, chem_group_alns, mdl = 
self.GetChemMapping(model) return _NMappings(self.chem_groups, chem_mapping) def ProcessStructure(self, ent): """ Entity processing for chain mapping * Selects view containing peptide and nucleotide residues which have required backbone atoms present - for peptide residues thats N, CA, C and CB (no CB for GLY), for nucleotide residues thats O5', C5', C4', C3' and O3'. * filters view by chain lengths, see *min_pep_length* and *min_nuc_length* in constructor * Extracts atom sequences for each chain in that view * Attaches corresponding :class:`ost.mol.EntityView` to each sequence * If residue number alignments are used, strictly increasing residue numbers without insertion codes are ensured in each chain :param ent: Entity to process :type ent: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` :returns: Tuple with 3 elements: 1) :class:`ost.mol.EntityView` containing peptide and nucleotide residues 2) :class:`ost.seq.SequenceList` containing ATOMSEQ sequences for each polypeptide chain in returned view, sequences have :class:`ost.mol.EntityView` of according chains attached 3) same for polynucleotide chains """ view = ent.CreateEmptyView() exp_pep_atoms = ["N", "CA", "C", "CB"] exp_nuc_atoms = ["\"O5'\"", "\"C5'\"", "\"C4'\"", "\"C3'\"", "\"O3'\""] pep_query = "peptide=true and aname=" + ','.join(exp_pep_atoms) nuc_query = "nucleotide=true and aname=" + ','.join(exp_nuc_atoms) pep_sel = ent.Select(pep_query) for r in pep_sel.residues: if len(r.atoms) == 4: view.AddResidue(r.handle, mol.INCLUDE_ALL) elif r.name == "GLY" and len(r.atoms) == 3: atom_names = [a.GetName() for a in r.atoms] if sorted(atom_names) == ["C", "CA", "N"]: view.AddResidue(r.handle, mol.INCLUDE_ALL) nuc_sel = ent.Select(nuc_query) for r in nuc_sel.residues: if len(r.atoms) == 5: view.AddResidue(r.handle, mol.INCLUDE_ALL) polypep_seqs = seq.CreateSequenceList() polynuc_seqs = seq.CreateSequenceList() if len(view.residues) == 0: # no residues survived => return return (view, polypep_seqs, 
polynuc_seqs) for ch in view.chains: n_res = len(ch.residues) n_pep = sum([r.IsPeptideLinking() for r in ch.residues]) n_nuc = sum([r.IsNucleotideLinking() for r in ch.residues]) # guarantee that we have either pep or nuc (no mix of the two) if n_pep > 0 and n_nuc > 0: raise RuntimeError(f"Must not mix peptide and nucleotide linking " f"residues in same chain ({ch.GetName()})") if (n_pep + n_nuc) != n_res: raise RuntimeError("All residues must either be peptide_linking " "or nucleotide_linking") # filter out short chains if n_pep > 0 and n_pep < self.min_pep_length: continue if n_nuc > 0 and n_nuc < self.min_nuc_length: continue # the superfast residue number based alignment adds some # restrictions on the numbers themselves: # 1) no insertion codes 2) strictly increasing if self.resnum_alignments: # check if no insertion codes are present in residue numbers ins_codes = [r.GetNumber().GetInsCode() for r in ch.residues] if len(set(ins_codes)) != 1 or ins_codes[0] != '\0': raise RuntimeError("Residue numbers in input structures must not " "contain insertion codes") # check if residue numbers are strictly increasing nums = [r.GetNumber().GetNum() for r in ch.residues] if not all(i < j for i, j in zip(nums, nums[1:])): raise RuntimeError("Residue numbers in input structures must be " "strictly increasing for each chain") s = ''.join([r.one_letter_code for r in ch.residues]) s = seq.CreateSequence(ch.GetName(), s) s.AttachView(_CSel(view, [ch.GetName()])) if n_pep == n_res: polypep_seqs.AddSequence(s) elif n_nuc == n_res: polynuc_seqs.AddSequence(s) else: raise RuntimeError("This shouldnt happen") if len(polypep_seqs) == 0 and len(polynuc_seqs) == 0: raise RuntimeError(f"No chain fulfilled minimum length requirement " f"to be considered in chain mapping " f"({self.min_pep_length} for peptide chains, " f"{self.min_nuc_length} for nucleotide chains) " f"- mapping failed") # select for chains for which we actually extracted the sequence chain_names = 
[s.GetAttachedView().chains[0].name for s in polypep_seqs] chain_names += [s.GetAttachedView().chains[0].name for s in polynuc_seqs] view = _CSel(view, chain_names) return (view, polypep_seqs, polynuc_seqs) def Align(self, s1, s2, stype): """ Access to internal sequence alignment functionality Alignment parameterization is setup at ChainMapper construction :param s1: First sequence to align - must have view attached in case of resnum_alignments :type s1: :class:`ost.seq.SequenceHandle` :param s2: Second sequence to align - must have view attached in case of resnum_alignments :type s2: :class:`ost.seq.SequenceHandle` :param stype: Type of sequences to align, must be in [:class:`ost.mol.ChemType.AMINOACIDS`, :class:`ost.mol.ChemType.NUCLEOTIDES`] :returns: Pairwise alignment of s1 and s2 """ if stype not in [mol.ChemType.AMINOACIDS, mol.ChemType.NUCLEOTIDES]: raise RuntimeError("stype must be ost.mol.ChemType.AMINOACIDS or " "ost.mol.ChemType.NUCLEOTIDES") return self.aligner.Align(s1, s2, chem_type = stype) # INTERNAL HELPERS ################## class _Aligner: def __init__(self, pep_subst_mat = seq.alg.BLOSUM62, pep_gap_open = -5, pep_gap_ext = -2, nuc_subst_mat = seq.alg.NUC44, nuc_gap_open = -4, nuc_gap_ext = -4, resnum_aln = False): """ Helper class to compute alignments Sets default values for substitution matrix, gap open and gap extension penalties. They are only used in default mode (Needleman-Wunsch aln). If *resnum_aln* is True, only residue numbers of views that are attached to input sequences are considered. 
""" self.pep_subst_mat = pep_subst_mat self.pep_gap_open = pep_gap_open self.pep_gap_ext = pep_gap_ext self.nuc_subst_mat = nuc_subst_mat self.nuc_gap_open = nuc_gap_open self.nuc_gap_ext = nuc_gap_ext self.resnum_aln = resnum_aln def Align(self, s1, s2, chem_type=None): if self.resnum_aln: return self.ResNumAlign(s1, s2) else: if chem_type is None: raise RuntimeError("Must specify chem_type for NW alignment") return self.NWAlign(s1, s2, chem_type) def NWAlign(self, s1, s2, chem_type): """ Returns pairwise alignment using Needleman-Wunsch algorithm :param s1: First sequence to align :type s1: :class:`ost.seq.SequenceHandle` :param s2: Second sequence to align :type s2: :class:`ost.seq.SequenceHandle` :param chem_type: Must be in [:class:`ost.mol.ChemType.AMINOACIDS`, :class:`ost.mol.ChemType.NUCLEOTIDES`], determines substitution matrix and gap open/extension penalties :type chem_type: :class:`ost.mol.ChemType` :returns: Alignment with s1 as first and s2 as second sequence """ if chem_type == mol.ChemType.AMINOACIDS: return seq.alg.GlobalAlign(s1, s2, self.pep_subst_mat, gap_open=self.pep_gap_open, gap_ext=self.pep_gap_ext)[0] elif chem_type == mol.ChemType.NUCLEOTIDES: return seq.alg.GlobalAlign(s1, s2, self.nuc_subst_mat, gap_open=self.nuc_gap_open, gap_ext=self.nuc_gap_ext)[0] else: raise RuntimeError("Invalid ChemType") return aln def ResNumAlign(self, s1, s2): """ Returns pairwise alignment using residue numbers of attached views Assumes that there are no insertion codes (alignment only on numerical component) and that resnums are strictly increasing (fast min/max identification). These requirements are assured if a structure has been processed by :class:`ChainMapper`. 
:param s1: First sequence to align, must have :class:`ost.mol.EntityView` attached :type s1: :class:`ost.seq.SequenceHandle` :param s2: Second sequence to align, must have :class:`ost.mol.EntityView` attached :type s2: :class:`ost.seq.SequenceHandle` """ assert(s1.HasAttachedView()) assert(s2.HasAttachedView()) v1 = s1.GetAttachedView() rnums1 = [r.GetNumber().GetNum() for r in v1.residues] v2 = s2.GetAttachedView() rnums2 = [r.GetNumber().GetNum() for r in v2.residues] min_num = min(rnums1[0], rnums2[0]) max_num = max(rnums1[-1], rnums2[-1]) aln_length = max_num - min_num + 1 aln_s1 = ['-'] * aln_length for r, rnum in zip(v1.residues, rnums1): aln_s1[rnum-min_num] = r.one_letter_code aln_s2 = ['-'] * aln_length for r, rnum in zip(v2.residues, rnums2): aln_s2[rnum-min_num] = r.one_letter_code aln = seq.CreateAlignment() aln.AddSequence(seq.CreateSequence(s1.GetName(), ''.join(aln_s1))) aln.AddSequence(seq.CreateSequence(s2.GetName(), ''.join(aln_s2))) return aln def _GetAlnPropsTwo(aln): """Returns basic properties of *aln* version two... :param aln: Alignment to compute properties :type aln: :class:`seq.AlignmentHandle` :returns: Tuple with 2 elements. 1) sequence identify in range [0, 100] considering aligned columns 2) Fraction of non-gap characters in first sequence that are covered by non-gap characters in second sequence. """ assert(aln.GetCount() == 2) n_tot = sum([1 for col in aln if col[0] != '-']) n_aligned = sum([1 for col in aln if (col[0] != '-' and col[1] != '-')]) return (seq.alg.SequenceIdentity(aln), float(n_aligned)/n_tot) def _GetAlnPropsOne(aln): """Returns basic properties of *aln* version one... :param aln: Alignment to compute properties :type aln: :class:`seq.AlignmentHandle` :returns: Tuple with 3 elements. 1) sequence identify in range [0, 100] considering aligned columns 2) Fraction of gaps between first and last aligned column in s1 3) same for s2. 
""" assert(aln.GetCount() == 2) n_gaps_1 = str(aln.GetSequence(0)).strip('-').count('-') n_gaps_2 = str(aln.GetSequence(1)).strip('-').count('-') gap_frac_1 = float(n_gaps_1)/len(aln.GetSequence(0).GetGaplessString()) gap_frac_2 = float(n_gaps_2)/len(aln.GetSequence(1).GetGaplessString()) return (seq.alg.SequenceIdentity(aln), gap_frac_1, gap_frac_2) def _GetChemGroupAlignments(pep_seqs, nuc_seqs, aligner, pep_seqid_thr=95., pep_gap_thr=0.1, nuc_seqid_thr=95., nuc_gap_thr=0.1): """Returns alignments with groups of chemically equivalent chains :param pep_seqs: List of polypeptide sequences :type pep_seqs: :class:`seq.SequenceList` :param nuc_seqs: List of polynucleotide sequences :type nuc_seqs: :class:`seq.SequenceList` :param aligner: Helper class to generate pairwise alignments :type aligner: :class:`_Aligner` :param pep_seqid_thr: Threshold used to decide when two peptide chains are identical. 95 percent tolerates the few mutations crystallographers like to do. :type pep_seqid_thr: :class:`float` :param pep_gap_thr: Additional threshold to avoid gappy alignments with high seqid. The reason for not just normalizing seqid by the longer sequence is that one sequence might be a perfect subsequence of the other but only cover half of it. This threshold checks for a maximum allowed fraction of gaps in any of the two sequences after stripping terminal gaps. :type pep_gap_thr: :class:`float` :param nuc_seqid_thr: Nucleotide equivalent of *pep_seqid_thr* :type nuc_seqid_thr: :class:`float` :param nuc_gap_thr: Nucleotide equivalent of *nuc_gap_thr* :type nuc_gap_thr: :class:`float` :returns: Tuple with first element being an AlignmentList. Each alignment represents a group of chemically equivalent chains and the first sequence is the longest. Second element is a list of equivalent length specifying the types of the groups. 
List elements are in [:class:`ost.ChemType.AMINOACIDS`, :class:`ost.ChemType.NUCLEOTIDES`] """ pep_groups = _GroupSequences(pep_seqs, pep_seqid_thr, pep_gap_thr, aligner, mol.ChemType.AMINOACIDS) nuc_groups = _GroupSequences(nuc_seqs, nuc_seqid_thr, nuc_gap_thr, aligner, mol.ChemType.NUCLEOTIDES) group_types = [mol.ChemType.AMINOACIDS] * len(pep_groups) group_types += [mol.ChemType.NUCLEOTIDES] * len(nuc_groups) groups = pep_groups groups.extend(nuc_groups) return (groups, group_types) def _GroupSequences(seqs, seqid_thr, gap_thr, aligner, chem_type): """Get list of alignments representing groups of equivalent sequences :param seqid_thr: Threshold used to decide when two chains are identical. :type seqid_thr: :class:`float` :param gap_thr: Additional threshold to avoid gappy alignments with high seqid. The reason for not just normalizing seqid by the longer sequence is that one sequence might be a perfect subsequence of the other but only cover half of it. This threshold checks for a maximum allowed fraction of gaps in any of the two sequences after stripping terminal gaps. :type gap_thr: :class:`float` :param aligner: Helper class to generate pairwise alignments :type aligner: :class:`_Aligner` :param chem_type: ChemType of seqs which is passed to *aligner*, must be in [:class:`ost.mol.ChemType.AMINOACIDS`, :class:`ost.mol.ChemType.NUCLEOTIDES`] :type chem_type: :class:`ost.mol.ChemType` :returns: A list of alignments, one alignment for each group with longest sequence (reference) as first sequence. 
:rtype: :class:`ost.seq.AlignmentList` """ groups = list() for s_idx in range(len(seqs)): matching_group = None for g_idx in range(len(groups)): for g_s_idx in range(len(groups[g_idx])): aln = aligner.Align(seqs[s_idx], seqs[groups[g_idx][g_s_idx]], chem_type) sid, frac_i, frac_j = _GetAlnPropsOne(aln) if sid >= seqid_thr and frac_i < gap_thr and frac_j < gap_thr: matching_group = g_idx break if matching_group is not None: break if matching_group is None: groups.append([s_idx]) else: groups[matching_group].append(s_idx) # sort based on sequence length sorted_groups = list() for g in groups: if len(g) > 1: tmp = sorted([[len(seqs[i]), i] for i in g], reverse=True) sorted_groups.append([x[1] for x in tmp]) else: sorted_groups.append(g) # translate from indices back to sequences and directly generate alignments # of the groups with the longest (first) sequence as reference aln_list = seq.AlignmentList() for g in sorted_groups: if len(g) == 1: # aln with one single sequence aln_list.append(seq.CreateAlignment(seqs[g[0]])) else: # obtain pairwise aln of first sequence (reference) to all others alns = seq.AlignmentList() i = g[0] for j in g[1:]: alns.append(aligner.Align(seqs[i], seqs[j], chem_type)) # and merge aln_list.append(seq.alg.MergePairwiseAlignments(alns, seqs[i])) # transfer attached views seq_dict = {s.GetName(): s for s in seqs} for aln_idx in range(len(aln_list)): for aln_s_idx in range(aln_list[aln_idx].GetCount()): s_name = aln_list[aln_idx].GetSequence(aln_s_idx).GetName() s = seq_dict[s_name] aln_list[aln_idx].AttachView(aln_s_idx, s.GetAttachedView()) return aln_list def _MapSequence(ref_seqs, ref_types, s, s_type, aligner): """Tries top map *s* onto any of the sequences in *ref_seqs* Computes alignments of *s* to each of the reference sequences of equal type and sorts them by seqid*fraction_covered (seqid: sequence identity of aligned columns in alignment, fraction_covered: Fraction of non-gap characters in reference sequence that are covered by 
non-gap characters in *s*). Best scoring mapping is returned. :param ref_seqs: Reference sequences :type ref_seqs: :class:`ost.seq.SequenceList` :param ref_types: Types of reference sequences, e.g. ost.mol.ChemType.AminoAcids :type ref_types: :class:`list` of :class:`ost.mol.ChemType` :param s: Sequence to map :type s: :class:`ost.seq.SequenceHandle` :param s_type: Type of *s*, only try mapping to sequences in *ref_seqs* with equal type as defined in *ref_types* :param aligner: Helper class to generate pairwise alignments :type aligner: :class:`_Aligner` :returns: Tuple with two elements. 1) index of sequence in *ref_seqs* to which *s* can be mapped 2) Pairwise sequence alignment with sequence from *ref_seqs* as first sequence. Both elements are None if no mapping can be found. :raises: :class:`RuntimeError` if mapping is ambiguous, i.e. *s* successfully maps to more than one sequence in *ref_seqs* """ scored_alns = list() for ref_idx, ref_seq in enumerate(ref_seqs): if ref_types[ref_idx] == s_type: aln = aligner.Align(ref_seq, s, s_type) seqid, fraction_covered = _GetAlnPropsTwo(aln) score = seqid * fraction_covered scored_alns.append((score, ref_idx, aln)) if len(scored_alns) == 0: return (None, None) # no mapping possible... scored_alns = sorted(scored_alns, key=lambda x: x[0], reverse=True) return (scored_alns[0][1], scored_alns[0][2]) def _GetRefMdlAlns(ref_chem_groups, ref_chem_group_msas, mdl_chem_groups, mdl_chem_group_alns, pairs=None): """ Get all possible ref/mdl chain alignments given chem group mapping :param ref_chem_groups: :attr:`ChainMapper.chem_groups` :type ref_chem_groups: :class:`list` of :class:`list` of :class:`str` :param ref_chem_group_msas: :attr:`ChainMapper.chem_group_alignments` :type ref_chem_group_msas: :class:`ost.seq.AlignmentList` :param mdl_chem_groups: Groups of model chains that are mapped to *ref_chem_groups*. Return value of :func:`ChainMapper.GetChemMapping`. 
:type mdl_chem_groups: :class:`list` of :class:`list` of :class:`str` :param mdl_chem_group_alns: A pairwise sequence alignment for every chain in *mdl_chem_groups* that aligns these sequences to the respective reference sequence. Return values of :func:`ChainMapper.GetChemMapping`. :type mdl_chem_group_alns: :class:`list` of :class:`ost.seq.AlignmentList` :param pairs: Pro param - restrict return dict to specified pairs. A set of tuples in form (<trg_ch>, <mdl_ch>) :type pairs: :class:`set` :returns: A dictionary holding all possible ref/mdl chain alignments. Keys in that dictionary are tuples of the form (ref_ch, mdl_ch) and values are the respective pairwise alignments with first sequence being from ref, the second from mdl. """ # alignment of each model chain to chem_group reference sequence mdl_alns = dict() for alns in mdl_chem_group_alns: for aln in alns: mdl_chain_name = aln.GetSequence(1).GetName() mdl_alns[mdl_chain_name] = aln # generate all alignments between ref/mdl chain atomseqs that we will # ever observe ref_mdl_alns = dict() for ref_chains, mdl_chains, ref_aln in zip(ref_chem_groups, mdl_chem_groups, ref_chem_group_msas): for ref_ch in ref_chains: for mdl_ch in mdl_chains: if pairs is not None and (ref_ch, mdl_ch) not in pairs: continue # obtain alignments of mdl and ref chains towards chem # group ref sequence and merge them aln_list = seq.AlignmentList() # do ref aln s1 = ref_aln.GetSequence(0) s2 = ref_aln.GetSequence(ref_chains.index(ref_ch)) aln_list.append(seq.CreateAlignment(s1, s2)) # do mdl aln aln_list.append(mdl_alns[mdl_ch]) # merge ref_seq = seq.CreateSequence(s1.GetName(), s1.GetGaplessString()) merged_aln = seq.alg.MergePairwiseAlignments(aln_list, ref_seq) # merged_aln: # seq1: ref seq of chem group # seq2: seq of ref chain # seq3: seq of mdl chain # => we need the alignment between seq2 and seq3 s2 = merged_aln.GetSequence(1) s3 = merged_aln.GetSequence(2) # cut leading and trailing gap columns a = 0 # number of leading gap 
columns for idx in range(len(s2)): if s2[idx] != '-' or s3[idx] != '-': break a += 1 b = 0 # number of trailing gap columns for idx in reversed(range(len(s2))): if s2[idx] != '-' or s3[idx] != '-': break b += 1 s2 = seq.CreateSequence(s2.GetName(), s2[a: len(s2)-b]) s3 = seq.CreateSequence(s3.GetName(), s3[a: len(s3)-b]) ref_mdl_alns[(ref_ch, mdl_ch)] = seq.CreateAlignment(s2, s3) return ref_mdl_alns def _CheckOneToOneMapping(ref_chains, mdl_chains): """ Checks whether we already have a perfect one to one mapping That means each list in *ref_chains* has exactly one element and each list in *mdl_chains* has either one element (it's mapped) or is empty (ref chain has no mapped mdl chain). Returns None if no such mapping can be found. :param ref_chains: corresponds to :attr:`ChainMapper.chem_groups` :type ref_chains: :class:`list` of :class:`list` of :class:`str` :param mdl_chains: mdl chains mapped to chem groups in *ref_chains*, i.e. the return value of :func:`ChainMapper.GetChemMapping` :type mdl_chains: class:`list` of :class:`list` of :class:`str` :returns: A :class:`list` of :class:`list` if a one to one mapping is found, None otherwise """ only_one_to_one = True one_to_one = list() for ref, mdl in zip(ref_chains, mdl_chains): if len(ref) == 1 and len(mdl) == 1: one_to_one.append(mdl) elif len(ref) == 1 and len(mdl) == 0: one_to_one.append([None]) else: only_one_to_one = False break if only_one_to_one: return one_to_one else: return None class _lDDTDecomposer: def __init__(self, ref, mdl, ref_mdl_alns, inclusion_radius = 15.0, thresholds = [0.5, 1.0, 2.0, 4.0]): """ Compute backbone only lDDT scores for ref/mdl Uses the pairwise decomposable property of backbone only lDDT and implements a caching mechanism to efficiently enumerate different chain mappings. 
""" self.ref = ref self.mdl = mdl self.ref_mdl_alns = ref_mdl_alns self.inclusion_radius = inclusion_radius self.thresholds = thresholds # keep track of single chains and interfaces in ref self.ref_chains = list() # e.g. ['A', 'B', 'C'] self.ref_interfaces = list() # e.g. [('A', 'B'), ('A', 'C')] # holds lDDT scorer for each chain in ref # key: chain name, value: scorer self.single_chain_scorer = dict() # cache for single chain conserved contacts # key: tuple (ref_ch, mdl_ch) value: number of conserved contacts self.single_chain_cache = dict() # holds lDDT scorer for each pairwise interface in target # key: tuple (ref_ch1, ref_ch2), value: scorer self.interface_scorer = dict() # cache for interface conserved contacts # key: tuple of tuple ((ref_ch1, ref_ch2),((mdl_ch1, mdl_ch2)) # value: number of conserved contacts self.interface_cache = dict() self.n = 0 self._SetupScorer() def _SetupScorer(self): for ch in self.ref.chains: # Select everything close to that chain query = f"{self.inclusion_radius} <> " query += f"[cname={mol.QueryQuoteName(ch.GetName())}] " query += f"and cname!={mol.QueryQuoteName(ch.GetName())}" for close_ch in self.ref.Select(query).chains: k1 = (ch.GetName(), close_ch.GetName()) k2 = (close_ch.GetName(), ch.GetName()) if k1 not in self.interface_scorer and \ k2 not in self.interface_scorer: dimer_ref = _CSel(self.ref, [k1[0], k1[1]]) s = lddt.lDDTScorer(dimer_ref, bb_only=True) self.interface_scorer[k1] = s self.interface_scorer[k2] = s self.n += self.interface_scorer[k1].n_distances_ic self.ref_interfaces.append(k1) # single chain scorer are actually interface scorers to save # some distance calculations if ch.GetName() not in self.single_chain_scorer: self.single_chain_scorer[ch.GetName()] = s self.n += s.GetNChainContacts(ch.GetName(), no_interchain=True) self.ref_chains.append(ch.GetName()) if close_ch.GetName() not in self.single_chain_scorer: self.single_chain_scorer[close_ch.GetName()] = s self.n += 
s.GetNChainContacts(close_ch.GetName(), no_interchain=True) self.ref_chains.append(close_ch.GetName()) # add any missing single chain scorer for ch in self.ref.chains: if ch.GetName() not in self.single_chain_scorer: single_chain_ref = _CSel(self.ref, [ch.GetName()]) self.single_chain_scorer[ch.GetName()] = \ lddt.lDDTScorer(single_chain_ref, bb_only = True) self.n += self.single_chain_scorer[ch.GetName()].n_distances self.ref_chains.append(ch.GetName()) def lDDT(self, ref_chain_groups, mdl_chain_groups): flat_map = dict() for ref_chains, mdl_chains in zip(ref_chain_groups, mdl_chain_groups): for ref_ch, mdl_ch in zip(ref_chains, mdl_chains): flat_map[ref_ch] = mdl_ch return self.lDDTFromFlatMap(flat_map) def lDDTFromFlatMap(self, flat_map): conserved = 0 # do single chain scores for ref_ch in self.ref_chains: if ref_ch in flat_map and flat_map[ref_ch] is not None: conserved += self.SCCounts(ref_ch, flat_map[ref_ch]) # do interfaces for ref_ch1, ref_ch2 in self.ref_interfaces: if ref_ch1 in flat_map and ref_ch2 in flat_map: mdl_ch1 = flat_map[ref_ch1] mdl_ch2 = flat_map[ref_ch2] if mdl_ch1 is not None and mdl_ch2 is not None: conserved += self.IntCounts(ref_ch1, ref_ch2, mdl_ch1, mdl_ch2) return conserved / (len(self.thresholds) * self.n) def SCCounts(self, ref_ch, mdl_ch): if not (ref_ch, mdl_ch) in self.single_chain_cache: alns = dict() alns[mdl_ch] = self.ref_mdl_alns[(ref_ch, mdl_ch)] mdl_sel = _CSel(self.mdl, [mdl_ch]) s = self.single_chain_scorer[ref_ch] _,_,_,conserved,_,_,_ = s.lDDT(mdl_sel, residue_mapping=alns, return_dist_test=True, no_interchain=True, chain_mapping={mdl_ch: ref_ch}, check_resnames=False) self.single_chain_cache[(ref_ch, mdl_ch)] = conserved return self.single_chain_cache[(ref_ch, mdl_ch)] def IntCounts(self, ref_ch1, ref_ch2, mdl_ch1, mdl_ch2): k1 = ((ref_ch1, ref_ch2),(mdl_ch1, mdl_ch2)) k2 = ((ref_ch2, ref_ch1),(mdl_ch2, mdl_ch1)) if k1 not in self.interface_cache and k2 not in self.interface_cache: alns = dict() alns[mdl_ch1] = 
self.ref_mdl_alns[(ref_ch1, mdl_ch1)] alns[mdl_ch2] = self.ref_mdl_alns[(ref_ch2, mdl_ch2)] mdl_sel = _CSel(self.mdl, [mdl_ch1, mdl_ch2]) s = self.interface_scorer[(ref_ch1, ref_ch2)] _,_,_,conserved,_,_,_ = s.lDDT(mdl_sel, residue_mapping=alns, return_dist_test=True, no_intrachain=True, chain_mapping={mdl_ch1: ref_ch1, mdl_ch2: ref_ch2}, check_resnames=False) self.interface_cache[k1] = conserved self.interface_cache[k2] = conserved return self.interface_cache[k1] class _lDDTGreedySearcher(_lDDTDecomposer): def __init__(self, ref, mdl, ref_chem_groups, mdl_chem_groups, ref_mdl_alns, inclusion_radius = 15.0, thresholds = [0.5, 1.0, 2.0, 4.0], steep_opt_rate = None): """ Greedy extension of already existing but incomplete chain mappings """ super().__init__(ref, mdl, ref_mdl_alns, inclusion_radius = inclusion_radius, thresholds = thresholds) self.steep_opt_rate = steep_opt_rate self.neighbors = {k: set() for k in self.ref_chains} for k in self.interface_scorer.keys(): self.neighbors[k[0]].add(k[1]) self.neighbors[k[1]].add(k[0]) assert(len(ref_chem_groups) == len(mdl_chem_groups)) self.ref_chem_groups = ref_chem_groups self.mdl_chem_groups = mdl_chem_groups self.ref_ch_group_mapper = dict() self.mdl_ch_group_mapper = dict() for g_idx, (ref_g, mdl_g) in enumerate(zip(ref_chem_groups, mdl_chem_groups)): for ch in ref_g: self.ref_ch_group_mapper[ch] = g_idx for ch in mdl_g: self.mdl_ch_group_mapper[ch] = g_idx # keep track of mdl chains that potentially give lDDT contributions, # i.e. 
they have locations within inclusion_radius + max(thresholds) self.mdl_neighbors = dict() d = self.inclusion_radius + max(self.thresholds) for ch in self.mdl.chains: ch_name = ch.GetName() self.mdl_neighbors[ch_name] = set() query = f"{d} <> [cname={mol.QueryQuoteName(ch_name)}]" query += f" and cname !={mol.QueryQuoteName(ch_name)}" for close_ch in self.mdl.Select(query).chains: self.mdl_neighbors[ch_name].add(close_ch.GetName()) def ExtendMapping(self, mapping, max_ext = None): if len(mapping) == 0: raise RuntimError("Mapping must contain a starting point") # Ref chains onto which we can map. The algorithm starts with a mapping # on ref_ch. From there we can start to expand to connected neighbors. # All neighbors that we can reach from the already mapped chains are # stored in this set which will be updated during runtime map_targets = set() for ref_ch in mapping.keys(): map_targets.update(self.neighbors[ref_ch]) # remove the already mapped chains for ref_ch in mapping.keys(): map_targets.discard(ref_ch) if len(map_targets) == 0: return mapping # nothing to extend # keep track of what model chains are not yet mapped for each chem group free_mdl_chains = list() for chem_group in self.mdl_chem_groups: tmp = [x for x in chem_group if x not in mapping.values()] free_mdl_chains.append(set(tmp)) # keep track of what ref chains got a mapping newly_mapped_ref_chains = list() something_happened = True while something_happened: something_happened=False if self.steep_opt_rate is not None: n_chains = len(newly_mapped_ref_chains) if n_chains > 0 and n_chains % self.steep_opt_rate == 0: mapping = self._SteepOpt(mapping, newly_mapped_ref_chains) if max_ext is not None and len(newly_mapped_ref_chains) >= max_ext: break max_n = 0 max_mapping = None for ref_ch in map_targets: chem_group_idx = self.ref_ch_group_mapper[ref_ch] for mdl_ch in free_mdl_chains[chem_group_idx]: # single chain score n_single = self.SCCounts(ref_ch, mdl_ch) # scores towards neighbors that are already 
mapped n_inter = 0 for neighbor in self.neighbors[ref_ch]: if neighbor in mapping and mapping[neighbor] in \ self.mdl_neighbors[mdl_ch]: n_inter += self.IntCounts(ref_ch, neighbor, mdl_ch, mapping[neighbor]) n = n_single + n_inter if n_inter > 0 and n > max_n: # Only accept a new solution if its actually connected # i.e. n_inter > 0. Otherwise we could just map a big # fat mdl chain sitting somewhere in Nirvana max_n = n max_mapping = (ref_ch, mdl_ch) if max_n > 0: something_happened = True # assign new found mapping mapping[max_mapping[0]] = max_mapping[1] # add all neighboring chains to map targets as they are now # reachable for neighbor in self.neighbors[max_mapping[0]]: if neighbor not in mapping: map_targets.add(neighbor) # remove the ref chain from map targets map_targets.remove(max_mapping[0]) # remove the mdl chain from free_mdl_chains - its taken... chem_group_idx = self.ref_ch_group_mapper[max_mapping[0]] free_mdl_chains[chem_group_idx].remove(max_mapping[1]) # keep track of what ref chains got a mapping newly_mapped_ref_chains.append(max_mapping[0]) return mapping def _SteepOpt(self, mapping, chains_to_optimize=None): # just optimize ALL ref chains if nothing specified if chains_to_optimize is None: chains_to_optimize = mapping.keys() # make sure that we only have ref chains which are actually mapped ref_chains = [x for x in chains_to_optimize if mapping[x] is not None] # group ref chains to be optimized into chem groups tmp = dict() for ch in ref_chains: chem_group_idx = self.ref_ch_group_mapper[ch] if chem_group_idx in tmp: tmp[chem_group_idx].append(ch) else: tmp[chem_group_idx] = [ch] chem_groups = list(tmp.values()) # try all possible mapping swaps. 
Swaps that improve the score are # immediately accepted and we start all over again current_lddt = self.lDDTFromFlatMap(mapping) something_happened = True while something_happened: something_happened = False for chem_group in chem_groups: if something_happened: break for ch1, ch2 in itertools.combinations(chem_group, 2): swapped_mapping = dict(mapping) swapped_mapping[ch1] = mapping[ch2] swapped_mapping[ch2] = mapping[ch1] score = self.lDDTFromFlatMap(swapped_mapping) if score > current_lddt: something_happened = True mapping = swapped_mapping current_lddt = score break return mapping def _lDDTNaive(trg, mdl, inclusion_radius, thresholds, chem_groups, chem_mapping, ref_mdl_alns, n_max_naive): """ Naively iterates all possible chain mappings and returns the best """ best_mapping = None best_lddt = -1.0 # Benchmarks on homo-oligomers indicate that full blown lDDT # computation is faster up to tetramers => 4!=24 possible mappings. # For stuff bigger than that, the decomposer approach should be used if _NMappingsWithin(chem_groups, chem_mapping, 24): # Setup scoring lddt_scorer = lddt.lDDTScorer(trg, bb_only = True) for mapping in _ChainMappings(chem_groups, chem_mapping, n_max_naive): # chain_mapping and alns as input for lDDT computation lddt_chain_mapping = dict() lddt_alns = dict() for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping): for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group): # some mdl chains can be None if mdl_ch is not None: lddt_chain_mapping[mdl_ch] = ref_ch lddt_alns[mdl_ch] = ref_mdl_alns[(ref_ch, mdl_ch)] lDDT, _ = lddt_scorer.lDDT(mdl, thresholds=thresholds, chain_mapping=lddt_chain_mapping, residue_mapping = lddt_alns, check_resnames = False) if lDDT > best_lddt: best_mapping = mapping best_lddt = lDDT else: # Setup scoring lddt_scorer = _lDDTDecomposer(trg, mdl, ref_mdl_alns, inclusion_radius=inclusion_radius, thresholds = thresholds) for mapping in _ChainMappings(chem_groups, chem_mapping, n_max_naive): lDDT = 
lddt_scorer.lDDT(chem_groups, mapping) if lDDT > best_lddt: best_mapping = mapping best_lddt = lDDT return (best_mapping, best_lddt) def _GetSeeds(ref_chem_groups, mdl_chem_groups, mapped_ref_chains = set(), mapped_mdl_chains = set()): seeds = list() for ref_chains, mdl_chains in zip(ref_chem_groups, mdl_chem_groups): for ref_ch in ref_chains: if ref_ch not in mapped_ref_chains: for mdl_ch in mdl_chains: if mdl_ch not in mapped_mdl_chains: seeds.append((ref_ch, mdl_ch)) return seeds def _lDDTGreedyFast(the_greed): something_happened = True mapping = dict() while something_happened: something_happened = False seeds = _GetSeeds(the_greed.ref_chem_groups, the_greed.mdl_chem_groups, mapped_ref_chains = set(mapping.keys()), mapped_mdl_chains = set(mapping.values())) # search for best scoring starting point n_best = 0 best_seed = None for seed in seeds: n = the_greed.SCCounts(seed[0], seed[1]) if n > n_best: n_best = n best_seed = seed if n_best == 0: break # no proper seed found anymore... # add seed to mapping and start the greed mapping[best_seed[0]] = best_seed[1] mapping = the_greed.ExtendMapping(mapping) something_happened = True # translate mapping format and return final_mapping = list() for ref_chains in the_greed.ref_chem_groups: mapped_mdl_chains = list() for ref_ch in ref_chains: if ref_ch in mapping: mapped_mdl_chains.append(mapping[ref_ch]) else: mapped_mdl_chains.append(None) final_mapping.append(mapped_mdl_chains) return final_mapping def _lDDTGreedyFull(the_greed): """ Uses each reference chain as starting point for expansion """ seeds = _GetSeeds(the_greed.ref_chem_groups, the_greed.mdl_chem_groups) best_overall_score = -1.0 best_overall_mapping = dict() for seed in seeds: # do initial extension mapping = the_greed.ExtendMapping({seed[0]: seed[1]}) # repeat the process until we have a full mapping something_happened = True while something_happened: something_happened = False remnant_seeds = _GetSeeds(the_greed.ref_chem_groups, the_greed.mdl_chem_groups, 
mapped_ref_chains = set(mapping.keys()), mapped_mdl_chains = set(mapping.values())) if len(remnant_seeds) > 0: # still more mapping to be done best_score = -1.0 best_mapping = None for remnant_seed in remnant_seeds: tmp_mapping = dict(mapping) tmp_mapping[remnant_seed[0]] = remnant_seed[1] tmp_mapping = the_greed.ExtendMapping(tmp_mapping) score = the_greed.lDDTFromFlatMap(tmp_mapping) if score > best_score: best_score = score best_mapping = tmp_mapping if best_mapping is not None: something_happened = True mapping = best_mapping score = the_greed.lDDTFromFlatMap(mapping) if score > best_overall_score: best_overall_score = score best_overall_mapping = mapping mapping = best_overall_mapping # translate mapping format and return final_mapping = list() for ref_chains in the_greed.ref_chem_groups: mapped_mdl_chains = list() for ref_ch in ref_chains: if ref_ch in mapping: mapped_mdl_chains.append(mapping[ref_ch]) else: mapped_mdl_chains.append(None) final_mapping.append(mapped_mdl_chains) return final_mapping def _lDDTGreedyBlock(the_greed, seed_size, blocks_per_chem_group): """ try multiple seeds, i.e. try all ref/mdl chain combinations within the respective chem groups and compute single chain lDDTs. The *blocks_per_chem_group* best scoring ones are extend by *seed_size* chains and the best scoring one is exhaustively extended. 
""" if seed_size is None or seed_size < 1: raise RuntimeError(f"seed_size must be an int >= 1 (got {seed_size})") if blocks_per_chem_group is None or blocks_per_chem_group < 1: raise RuntimeError(f"blocks_per_chem_group must be an int >= 1 " f"(got {blocks_per_chem_group})") max_ext = seed_size - 1 # -1 => start seed already has size 1 ref_chem_groups = copy.deepcopy(the_greed.ref_chem_groups) mdl_chem_groups = copy.deepcopy(the_greed.mdl_chem_groups) mapping = dict() something_happened = True while something_happened: something_happened = False starting_blocks = list() for ref_chains, mdl_chains in zip(ref_chem_groups, mdl_chem_groups): if len(mdl_chains) == 0: continue # nothing to map ref_chains_copy = list(ref_chains) for i in range(blocks_per_chem_group): if len(ref_chains_copy) == 0: break seeds = list() for ref_ch in ref_chains_copy: seeds += [(ref_ch, mdl_ch) for mdl_ch in mdl_chains] # extend starting seeds to *seed_size* and retain best scoring # block for further extension best_score = -1.0 best_mapping = None best_seed = None for s in seeds: seed = dict(mapping) seed.update({s[0]: s[1]}) seed = the_greed.ExtendMapping(seed, max_ext = max_ext) seed_lddt = the_greed.lDDTFromFlatMap(seed) if seed_lddt > best_score: best_score = seed_lddt best_mapping = seed best_seed = s if best_mapping != None: starting_blocks.append(best_mapping) if best_seed[0] in ref_chains_copy: # remove that ref chain to enforce diversity ref_chains_copy.remove(best_seed[0]) # fully expand initial starting blocks best_lddt = 0.0 best_mapping = None for seed in starting_blocks: seed = the_greed.ExtendMapping(seed) seed_lddt = the_greed.lDDTFromFlatMap(seed) if seed_lddt > best_lddt: best_lddt = seed_lddt best_mapping = seed if best_lddt == 0.0: break # no proper mapping found anymore something_happened = True mapping.update(best_mapping) for ref_ch, mdl_ch in best_mapping.items(): for group_idx in range(len(ref_chem_groups)): if ref_ch in ref_chem_groups[group_idx]: 
ref_chem_groups[group_idx].remove(ref_ch) if mdl_ch in mdl_chem_groups[group_idx]: mdl_chem_groups[group_idx].remove(mdl_ch) # translate mapping format and return final_mapping = list() for ref_chains in the_greed.ref_chem_groups: mapped_mdl_chains = list() for ref_ch in ref_chains: if ref_ch in mapping: mapped_mdl_chains.append(mapping[ref_ch]) else: mapped_mdl_chains.append(None) final_mapping.append(mapped_mdl_chains) return final_mapping class _QSScoreGreedySearcher(qsscore.QSScorer): def __init__(self, ref, mdl, ref_chem_groups, mdl_chem_groups, ref_mdl_alns, contact_d = 12.0, steep_opt_rate = None, greedy_prune_contact_map=False): """ Greedy extension of already existing but incomplete chain mappings """ super().__init__(ref, ref_chem_groups, mdl, ref_mdl_alns, contact_d = contact_d) self.ref = ref self.mdl = mdl self.ref_mdl_alns = ref_mdl_alns self.steep_opt_rate = steep_opt_rate if greedy_prune_contact_map: self.neighbors = {k: set() for k in self.qsent1.chain_names} for p in self.qsent1.interacting_chains: if np.count_nonzero(self.qsent1.PairDist(p[0], p[1])<=8) >= 3: self.neighbors[p[0]].add(p[1]) self.neighbors[p[1]].add(p[0]) self.mdl_neighbors = {k: set() for k in self.qsent2.chain_names} for p in self.qsent2.interacting_chains: if np.count_nonzero(self.qsent2.PairDist(p[0], p[1])<=8) >= 3: self.mdl_neighbors[p[0]].add(p[1]) self.mdl_neighbors[p[1]].add(p[0]) else: self.neighbors = {k: set() for k in self.qsent1.chain_names} for p in self.qsent1.interacting_chains: self.neighbors[p[0]].add(p[1]) self.neighbors[p[1]].add(p[0]) self.mdl_neighbors = {k: set() for k in self.qsent2.chain_names} for p in self.qsent2.interacting_chains: self.mdl_neighbors[p[0]].add(p[1]) self.mdl_neighbors[p[1]].add(p[0]) assert(len(ref_chem_groups) == len(mdl_chem_groups)) self.ref_chem_groups = ref_chem_groups self.mdl_chem_groups = mdl_chem_groups self.ref_ch_group_mapper = dict() self.mdl_ch_group_mapper = dict() for g_idx, (ref_g, mdl_g) in 
enumerate(zip(ref_chem_groups, mdl_chem_groups)): for ch in ref_g: self.ref_ch_group_mapper[ch] = g_idx for ch in mdl_g: self.mdl_ch_group_mapper[ch] = g_idx # cache for lDDT based single chain conserved contacts # used to identify starting points for further extension by QS score # key: tuple (ref_ch, mdl_ch) value: number of conserved contacts self.single_chain_scorer = dict() self.single_chain_cache = dict() for ch in self.ref.chains: single_chain_ref = _CSel(self.ref, [ch.GetName()]) self.single_chain_scorer[ch.GetName()] = \ lddt.lDDTScorer(single_chain_ref, bb_only = True) def SCCounts(self, ref_ch, mdl_ch): if not (ref_ch, mdl_ch) in self.single_chain_cache: alns = dict() alns[mdl_ch] = self.ref_mdl_alns[(ref_ch, mdl_ch)] mdl_sel = _CSel(self.mdl, [mdl_ch]) s = self.single_chain_scorer[ref_ch] _,_,_,conserved,_,_,_ = s.lDDT(mdl_sel, residue_mapping=alns, return_dist_test=True, no_interchain=True, chain_mapping={mdl_ch: ref_ch}, check_resnames=False) self.single_chain_cache[(ref_ch, mdl_ch)] = conserved return self.single_chain_cache[(ref_ch, mdl_ch)] def ExtendMapping(self, mapping, max_ext = None): if len(mapping) == 0: raise RuntimError("Mapping must contain a starting point") # Ref chains onto which we can map. The algorithm starts with a mapping # on ref_ch. From there we can start to expand to connected neighbors. 
# All neighbors that we can reach from the already mapped chains are # stored in this set which will be updated during runtime map_targets = set() for ref_ch in mapping.keys(): map_targets.update(self.neighbors[ref_ch]) # remove the already mapped chains for ref_ch in mapping.keys(): map_targets.discard(ref_ch) if len(map_targets) == 0: return mapping # nothing to extend # keep track of what model chains are not yet mapped for each chem group free_mdl_chains = list() for chem_group in self.mdl_chem_groups: tmp = [x for x in chem_group if x not in mapping.values()] free_mdl_chains.append(set(tmp)) # keep track of what ref chains got a mapping newly_mapped_ref_chains = list() something_happened = True while something_happened: something_happened=False if self.steep_opt_rate is not None: n_chains = len(newly_mapped_ref_chains) if n_chains > 0 and n_chains % self.steep_opt_rate == 0: mapping = self._SteepOpt(mapping, newly_mapped_ref_chains) if max_ext is not None and len(newly_mapped_ref_chains) >= max_ext: break score_result = self.FromFlatMapping(mapping) old_score = score_result.QS_global nominator = score_result.weighted_scores denominator = score_result.weight_sum + score_result.weight_extra_all max_diff = 0.0 max_mapping = None for ref_ch in map_targets: chem_group_idx = self.ref_ch_group_mapper[ref_ch] for mdl_ch in free_mdl_chains[chem_group_idx]: # we're not computing full QS-score here, we directly hack # into the QS-score formula to compute a diff nominator_diff = 0.0 denominator_diff = 0.0 for neighbor in self.neighbors[ref_ch]: if neighbor in mapping and mapping[neighbor] in \ self.mdl_neighbors[mdl_ch]: # it's a newly added interface if (ref_ch, mdl_ch) # are added to mapping int1 = (ref_ch, neighbor) int2 = (mdl_ch, mapping[neighbor]) a, b, c, d = self._MappedInterfaceScores(int1, int2) nominator_diff += a # weighted_scores denominator_diff += b # weight_sum denominator_diff += d # weight_extra_all # the respective interface penalties are subtracted # 
from denominator denominator_diff -= self._InterfacePenalty1(int1) denominator_diff -= self._InterfacePenalty2(int2) if nominator_diff > 0: # Only accept a new solution if its actually connected # i.e. nominator_diff > 0. new_nominator = nominator + nominator_diff new_denominator = denominator + denominator_diff new_score = 0.0 if new_denominator != 0.0: new_score = new_nominator/new_denominator diff = new_score - old_score if diff > max_diff: max_diff = diff max_mapping = (ref_ch, mdl_ch) if max_mapping is not None: something_happened = True # assign new found mapping mapping[max_mapping[0]] = max_mapping[1] # add all neighboring chains to map targets as they are now # reachable for neighbor in self.neighbors[max_mapping[0]]: if neighbor not in mapping: map_targets.add(neighbor) # remove the ref chain from map targets map_targets.remove(max_mapping[0]) # remove the mdl chain from free_mdl_chains - its taken... chem_group_idx = self.ref_ch_group_mapper[max_mapping[0]] free_mdl_chains[chem_group_idx].remove(max_mapping[1]) # keep track of what ref chains got a mapping newly_mapped_ref_chains.append(max_mapping[0]) return mapping def _SteepOpt(self, mapping, chains_to_optimize=None): # just optimize ALL ref chains if nothing specified if chains_to_optimize is None: chains_to_optimize = mapping.keys() # make sure that we only have ref chains which are actually mapped ref_chains = [x for x in chains_to_optimize if mapping[x] is not None] # group ref chains to be optimized into chem groups tmp = dict() for ch in ref_chains: chem_group_idx = self.ref_ch_group_mapper[ch] if chem_group_idx in tmp: tmp[chem_group_idx].append(ch) else: tmp[chem_group_idx] = [ch] chem_groups = list(tmp.values()) # try all possible mapping swaps. 
Swaps that improve the score are # immediately accepted and we start all over again score_result = self.FromFlatMapping(mapping) current_score = score_result.QS_global something_happened = True while something_happened: something_happened = False for chem_group in chem_groups: if something_happened: break for ch1, ch2 in itertools.combinations(chem_group, 2): swapped_mapping = dict(mapping) swapped_mapping[ch1] = mapping[ch2] swapped_mapping[ch2] = mapping[ch1] score_result = self.FromFlatMapping(swapped_mapping) if score_result.QS_global > current_score: something_happened = True mapping = swapped_mapping current_score = score_result.QS_global break return mapping def _QSScoreNaive(trg, mdl, chem_groups, chem_mapping, ref_mdl_alns, contact_d, n_max_naive): best_mapping = None best_score = -1.0 # qs_scorer implements caching, score calculation is thus as fast as it gets # you'll just hit a wall when the number of possible mappings becomes large qs_scorer = qsscore.QSScorer(trg, chem_groups, mdl, ref_mdl_alns) for mapping in _ChainMappings(chem_groups, chem_mapping, n_max_naive): score_result = qs_scorer.Score(mapping, check=False) if score_result.QS_global > best_score: best_mapping = mapping best_score = score_result.QS_global return (best_mapping, best_score) def _QSScoreGreedyFast(the_greed): something_happened = True mapping = dict() while something_happened: something_happened = False # search for best scoring starting point, we're using lDDT here n_best = 0 best_seed = None seeds = _GetSeeds(the_greed.ref_chem_groups, the_greed.mdl_chem_groups, mapped_ref_chains = set(mapping.keys()), mapped_mdl_chains = set(mapping.values())) for seed in seeds: n = the_greed.SCCounts(seed[0], seed[1]) if n > n_best: n_best = n best_seed = seed if n_best == 0: break # no proper seed found anymore... 
# add seed to mapping and start the greed mapping[best_seed[0]] = best_seed[1] mapping = the_greed.ExtendMapping(mapping) something_happened = True # translate mapping format and return final_mapping = list() for ref_chains in the_greed.ref_chem_groups: mapped_mdl_chains = list() for ref_ch in ref_chains: if ref_ch in mapping: mapped_mdl_chains.append(mapping[ref_ch]) else: mapped_mdl_chains.append(None) final_mapping.append(mapped_mdl_chains) return final_mapping


def _QSScoreGreedyFull(the_greed):
    """ Uses each reference chain as starting point for expansion

    Every seed (ref chain / mdl chain pair) returned by :func:`_GetSeeds` is
    extended with ``the_greed.ExtendMapping``. As long as unmapped seeds
    remain, each of them is tried as next extension point and the one that
    gives the highest ``QS_global`` is kept. Out of all initial seeds, the
    final mapping with highest ``QS_global`` is returned.

    :param the_greed: Helper object providing ``ref_chem_groups``,
                      ``mdl_chem_groups``, ``ExtendMapping(mapping)`` and
                      ``FromFlatMapping(mapping)`` (the latter returning an
                      object with a ``QS_global`` attribute)
    :returns: :class:`list` of :class:`list` with same shape as
              ``the_greed.ref_chem_groups`` containing the mapped model
              chain names, None for reference chains without mapping.
    """
    seeds = _GetSeeds(the_greed.ref_chem_groups, the_greed.mdl_chem_groups)
    best_overall_score = -1.0
    best_overall_mapping = dict()

    for seed in seeds:

        # do initial extension
        mapping = the_greed.ExtendMapping({seed[0]: seed[1]})

        # repeat the process until no seed is left to extend the mapping
        while True:
            remnant_seeds = _GetSeeds(the_greed.ref_chem_groups,
                                      the_greed.mdl_chem_groups,
                                      mapped_ref_chains = set(mapping.keys()),
                                      mapped_mdl_chains = set(mapping.values()))
            if len(remnant_seeds) == 0:
                break # mapping is as complete as it gets

            best_score = -1.0
            best_mapping = None
            for remnant_seed in remnant_seeds:
                tmp_mapping = dict(mapping)
                tmp_mapping[remnant_seed[0]] = remnant_seed[1]
                tmp_mapping = the_greed.ExtendMapping(tmp_mapping)
                score_result = the_greed.FromFlatMapping(tmp_mapping)
                if score_result.QS_global > best_score:
                    best_score = score_result.QS_global
                    best_mapping = tmp_mapping

            if best_mapping is None:
                break
            mapping = best_mapping

        score_result = the_greed.FromFlatMapping(mapping)
        if score_result.QS_global > best_overall_score:
            best_overall_score = score_result.QS_global
            best_overall_mapping = mapping

    # translate flat mapping into chem group format, None => not mapped
    return [[best_overall_mapping.get(ref_ch) for ref_ch in ref_chains]
            for ref_chains in the_greed.ref_chem_groups]


def _QSScoreGreedyBlock(the_greed, seed_size, blocks_per_chem_group):
    """ Greedy QS-score optimization based on blocks

    In each round, all ref/mdl chain combinations within the respective chem
    groups are tried as seeds on top of the current mapping. Each seed is
    extended by at most *seed_size* - 1 chains and scored by ``QS_global``.
    Per chem group, the *blocks_per_chem_group* best scoring blocks (each
    starting from a different reference chain) are retained. These blocks
    are then fully extended and the one with best ``QS_global`` is added to
    the mapping. Rounds are repeated as long as the mapping grows.

    :param the_greed: Helper object providing ``ref_chem_groups``,
                      ``mdl_chem_groups``, ``ExtendMapping`` and
                      ``FromFlatMapping``
    :param seed_size: Max size of the initial blocks, must be >= 1
    :param blocks_per_chem_group: Number of blocks retained per chem group,
                                  must be >= 1
    :returns: :class:`list` of :class:`list` with same shape as
              ``the_greed.ref_chem_groups`` containing the mapped model
              chain names, None for reference chains without mapping.
    :raises: :class:`RuntimeError` if *seed_size* or *blocks_per_chem_group*
             is None or < 1
    """
    if seed_size is None or seed_size < 1:
        raise RuntimeError(f"seed_size must be an int >= 1 (got {seed_size})")

    if blocks_per_chem_group is None or blocks_per_chem_group < 1:
        raise RuntimeError(f"blocks_per_chem_group must be an int >= 1 "
                           f"(got {blocks_per_chem_group})")

    max_ext = seed_size - 1 # -1 => start seed already has size 1

    # chains that are still available for mapping, reduced in each round
    ref_chem_groups = copy.deepcopy(the_greed.ref_chem_groups)
    mdl_chem_groups = copy.deepcopy(the_greed.mdl_chem_groups)

    mapping = dict()
    something_happened = True
    while something_happened:
        something_happened = False
        starting_blocks = list()
        for ref_chains, mdl_chains in zip(ref_chem_groups, mdl_chem_groups):
            if len(mdl_chains) == 0:
                continue # nothing to map
            ref_chains_copy = list(ref_chains)
            for _ in range(blocks_per_chem_group):
                if len(ref_chains_copy) == 0:
                    break
                seeds = [(ref_ch, mdl_ch) for ref_ch in ref_chains_copy
                         for mdl_ch in mdl_chains]
                # extend starting seeds to *seed_size* and retain best
                # scoring block for further extension
                best_score = -1.0
                best_mapping = None
                best_seed = None
                for s in seeds:
                    seed = dict(mapping)
                    seed[s[0]] = s[1]
                    seed = the_greed.ExtendMapping(seed, max_ext = max_ext)
                    score_result = the_greed.FromFlatMapping(seed)
                    if score_result.QS_global > best_score:
                        best_score = score_result.QS_global
                        best_mapping = seed
                        best_seed = s
                if best_mapping is not None:
                    starting_blocks.append(best_mapping)
                    if best_seed[0] in ref_chains_copy:
                        # remove selected ref chain to enforce diversity
                        ref_chains_copy.remove(best_seed[0])

        # fully expand initial starting blocks
        best_score = -1.0
        best_mapping = None
        for seed in starting_blocks:
            seed = the_greed.ExtendMapping(seed)
            score_result = the_greed.FromFlatMapping(seed)
            if score_result.QS_global > best_score:
                best_score = score_result.QS_global
                best_mapping = seed

        if best_mapping is not None and len(best_mapping) > len(mapping):
            # this even accepts extensions that lead to no increase in QS-score
            # at least they make sense from an lDDT perspective
            something_happened = True
            mapping.update(best_mapping)
            # mapped chains are no longer available in subsequent rounds
            for ref_ch, mdl_ch in best_mapping.items():
                for ref_group, mdl_group in zip(ref_chem_groups,
                                                mdl_chem_groups):
                    if ref_ch in ref_group:
                        ref_group.remove(ref_ch)
                    if mdl_ch in mdl_group:
                        mdl_group.remove(mdl_ch)

    # translate flat mapping into chem group format, None => not mapped
    return [[mapping.get(ref_ch) for ref_ch in ref_chains]
            for ref_chains in the_greed.ref_chem_groups]


def _SingleRigidGDTTS(initial_transforms, initial_mappings, chem_groups,
                      chem_mapping, trg_group_pos, mdl_group_pos,
                      single_chain_gdtts_thresh, iterative_superposition,
                      first_complete, n_trg_chains, n_mdl_chains):
    """ Takes initial transforms and sequentially adds chain pairs with
    best scoring gdtts that fulfill single_chain_gdtts_thresh. The mapping
    from the transform that leads to best overall gdtts score is returned.
    Optionally, the first complete mapping, i.e. a mapping that covers all
    target chains or all model chains, is returned.

    For each transform, pairs are assigned greedily per chem group in order
    of decreasing single chain gdtts. The overall score is the sum of
    single chain gdtts scaled by ``4 * n_positions`` (non-normalized gdt).

    *initial_mappings* and *iterative_superposition* are not used; they are
    kept for a signature compatible with :func:`_IterativeRigidGDTTS`.

    :returns: :class:`dict` with target chain names as keys and model chain
              names as values
    """
    best_mapping = dict()
    best_gdt = 0
    for transform in initial_transforms:
        mapping = dict()
        mapped_mdl_chains = set()
        gdt = 0.0 # overall (non-normalized) gdt for this transform
        for trg_chains, mdl_chains, trg_pos, mdl_pos, in zip(chem_groups,
                                                             chem_mapping,
                                                             trg_group_pos,
                                                             mdl_group_pos):
            if len(trg_pos) == 0 or len(mdl_pos) == 0:
                continue # cannot compute valid gdt
            gdt_scores = list()
            t_mdl_pos = list()
            for m_pos in mdl_pos:
                t_m_pos = geom.Vec3List(m_pos)
                t_m_pos.ApplyTransform(transform)
                t_mdl_pos.append(t_m_pos)
            for t_pos, t in zip(trg_pos, trg_chains):
                for t_m_pos, m in zip(t_mdl_pos, mdl_chains):
                    # separate variable - must not clobber the accumulator
                    sc_gdt = t_pos.GetGDTTS(t_m_pos)
                    if sc_gdt >= single_chain_gdtts_thresh:
                        gdt_scores.append((sc_gdt, (t,m)))
            n_gdt_contacts = 4 * len(trg_pos[0])
            gdt_scores.sort(reverse=True)
            for sc_gdt, (t, m) in gdt_scores:
                if t not in mapping and m not in mapped_mdl_chains:
                    mapping[t] = m
                    mapped_mdl_chains.add(m)
                    gdt += sc_gdt * n_gdt_contacts
        if gdt > best_gdt:
            best_gdt = gdt
            best_mapping = mapping
            if first_complete:
                n = len(mapping)
                if n == n_mdl_chains or n == n_trg_chains:
                    break
    return best_mapping


def _IterativeRigidGDTTS(initial_transforms, initial_mappings, chem_groups,
                         chem_mapping, trg_group_pos, mdl_group_pos,
                         single_chain_gdtts_thresh, iterative_superposition,
                         first_complete, n_trg_chains, n_mdl_chains):
    """ Takes initial transforms and sequentially adds chain pairs with
    best scoring gdtts that fulfill single_chain_gdtts_thresh. With each
    added chain pair, the transform gets updated. Thus the naming iterative.
    The mapping from the initial transform that leads to best overall gdtts
    score is returned. Optionally, the first complete mapping, i.e. a mapping
    that covers all target chains or all model chains, is returned.

    Available chains are kept in lists, so ties in single chain gdtts are
    resolved by chain order in *chem_groups* / *chem_mapping* and do not
    depend on string hash randomization.

    :returns: :class:`dict` with target chain names as keys and model chain
              names as values
    """
    # to directly retrieve positions using chain names
    trg_pos_dict = dict()
    for trg_pos, trg_chains in zip(trg_group_pos, chem_groups):
        for t_pos, t in zip(trg_pos, trg_chains):
            trg_pos_dict[t] = t_pos
    mdl_pos_dict = dict()
    for mdl_pos, mdl_chains in zip(mdl_group_pos, chem_mapping):
        for m_pos, m in zip(mdl_pos, mdl_chains):
            mdl_pos_dict[m] = m_pos

    best_mapping = dict()
    best_gdt = 0
    for initial_transform, initial_mapping in zip(initial_transforms,
                                                  initial_mappings):
        mapping = {initial_mapping[0]: initial_mapping[1]}
        transform = geom.Mat4(initial_transform)
        mapped_trg_pos = geom.Vec3List(trg_pos_dict[initial_mapping[0]])
        mapped_mdl_pos = geom.Vec3List(mdl_pos_dict[initial_mapping[1]])

        # the following variables contain the chains which are
        # available for mapping
        trg_chain_groups = [list(group) for group in chem_groups]
        mdl_chain_groups = [list(group) for group in chem_mapping]

        # search and kick out initial mapping
        for group in trg_chain_groups:
            if initial_mapping[0] in group:
                group.remove(initial_mapping[0])
                break
        for group in mdl_chain_groups:
            if initial_mapping[1] in group:
                group.remove(initial_mapping[1])
                break

        while True:
            # search for best mapping given current transform
            best_sc_mapping = None
            best_sc_group_idx = None
            best_sc_gdt = 0.0
            for group_idx, (trg_chains, mdl_chains) in \
                enumerate(zip(trg_chain_groups, mdl_chain_groups)):
                for t in trg_chains:
                    t_pos = trg_pos_dict[t]
                    for m in mdl_chains:
                        t_m_pos = geom.Vec3List(mdl_pos_dict[m])
                        t_m_pos.ApplyTransform(transform)
                        gdt = t_pos.GetGDTTS(t_m_pos)
                        # NOTE(review): strict '>' here, whereas
                        # _SingleRigidGDTTS accepts gdt == threshold
                        if gdt > single_chain_gdtts_thresh and \
                           gdt > best_sc_gdt:
                            best_sc_gdt = gdt
                            best_sc_mapping = (t,m)
                            best_sc_group_idx = group_idx

            if best_sc_mapping is None:
                break # no remaining pair fulfills the threshold

            t, m = best_sc_mapping
            mapping[t] = m
            mapped_trg_pos.extend(trg_pos_dict[t])
            mapped_mdl_pos.extend(mdl_pos_dict[m])
            trg_chain_groups[best_sc_group_idx].remove(t)
            mdl_chain_groups[best_sc_group_idx].remove(m)
            transform = _GetTransform(mapped_mdl_pos, mapped_trg_pos,
                                      iterative_superposition)

        # compute overall gdt for current transform (non-normalized gdt!!!)
        mapped_mdl_pos.ApplyTransform(transform)
        gdt = mapped_trg_pos.GetGDTTS(mapped_mdl_pos, norm=False)
        if gdt > best_gdt:
            best_gdt = gdt
            best_mapping = mapping
            if first_complete:
                n = len(mapping)
                if n == n_mdl_chains or n == n_trg_chains:
                    break
    return best_mapping


def _SingleRigidRMSD(initial_transforms, initial_mappings, chem_groups,
                     chem_mapping, trg_group_pos, mdl_group_pos,
                     iterative_superposition):
    """ Takes initial transforms and sequentially adds chain pairs with
    lowest RMSD. The mapping from the transform that leads to lowest overall
    RMSD is returned.

    For each transform, pairs are assigned greedily per chem group in order
    of increasing summed squared distance (SSD). The transform with lowest
    total SSD wins.

    *initial_mappings* and *iterative_superposition* are not used; they are
    kept for a signature compatible with :func:`_IterativeRigidRMSD`.

    :returns: :class:`dict` with target chain names as keys and model chain
              names as values
    """
    best_mapping = dict()
    best_ssd = float("inf") # we're actually going for summed squared distances
                            # Since all positions have same lengths and we do a
                            # full mapping, lowest SSD has a guarantee of also
                            # being lowest RMSD
    for transform in initial_transforms:
        mapping = dict()
        mapped_mdl_chains = set()
        ssd = 0.0 # total SSD for this transform
        for trg_chains, mdl_chains, trg_pos, mdl_pos, in zip(chem_groups,
                                                             chem_mapping,
                                                             trg_group_pos,
                                                             mdl_group_pos):
            if len(trg_pos) == 0 or len(mdl_pos) == 0:
                continue # cannot compute valid rmsd
            ssds = list()
            t_mdl_pos = list()
            for m_pos in mdl_pos:
                t_m_pos = geom.Vec3List(m_pos)
                t_m_pos.ApplyTransform(transform)
                t_mdl_pos.append(t_m_pos)
            for t_pos, t in zip(trg_pos, trg_chains):
                for t_m_pos, m in zip(t_mdl_pos, mdl_chains):
                    # separate variable - must not clobber the accumulator
                    pair_ssd = t_pos.GetSummedSquaredDistances(t_m_pos)
                    ssds.append((pair_ssd, (t,m)))
            ssds.sort()
            for pair_ssd, (t, m) in ssds:
                if t not in mapping and m not in mapped_mdl_chains:
                    mapping[t] = m
                    mapped_mdl_chains.add(m)
                    ssd += pair_ssd
        if ssd < best_ssd:
            best_ssd = ssd
            best_mapping = mapping
    return best_mapping


def _IterativeRigidRMSD(initial_transforms, initial_mappings, chem_groups,
                        chem_mapping, trg_group_pos, mdl_group_pos,
                        iterative_superposition):
    """ Takes initial transforms and sequentially adds chain pairs with
    lowest RMSD. With each added chain pair, the transform gets updated.
    Thus the naming iterative. The mapping from the initial transform that
    leads to best overall RMSD score is returned.

    Available chains are kept in lists, so ties in single chain RMSD are
    resolved by chain order in *chem_groups* / *chem_mapping* and do not
    depend on string hash randomization.

    :returns: :class:`dict` with target chain names as keys and model chain
              names as values
    """
    # to directly retrieve positions using chain names
    trg_pos_dict = dict()
    for trg_pos, trg_chains in zip(trg_group_pos, chem_groups):
        for t_pos, t in zip(trg_pos, trg_chains):
            trg_pos_dict[t] = t_pos
    mdl_pos_dict = dict()
    for mdl_pos, mdl_chains in zip(mdl_group_pos, chem_mapping):
        for m_pos, m in zip(mdl_pos, mdl_chains):
            mdl_pos_dict[m] = m_pos

    best_mapping = dict()
    best_rmsd = float("inf")
    for initial_transform, initial_mapping in zip(initial_transforms,
                                                  initial_mappings):
        mapping = {initial_mapping[0]: initial_mapping[1]}
        transform = geom.Mat4(initial_transform)
        mapped_trg_pos = geom.Vec3List(trg_pos_dict[initial_mapping[0]])
        mapped_mdl_pos = geom.Vec3List(mdl_pos_dict[initial_mapping[1]])

        # the following variables contain the chains which are
        # available for mapping
        trg_chain_groups = [list(group) for group in chem_groups]
        mdl_chain_groups = [list(group) for group in chem_mapping]

        # search and kick out initial mapping
        for group in trg_chain_groups:
            if initial_mapping[0] in group:
                group.remove(initial_mapping[0])
                break
        for group in mdl_chain_groups:
            if initial_mapping[1] in group:
                group.remove(initial_mapping[1])
                break

        while True:
            # search for best mapping given current transform
            best_sc_mapping = None
            best_sc_group_idx = None
            best_sc_rmsd = float("inf")
            for group_idx, (trg_chains, mdl_chains) in \
                enumerate(zip(trg_chain_groups, mdl_chain_groups)):
                for t in trg_chains:
                    t_pos = trg_pos_dict[t]
                    for m in mdl_chains:
                        t_m_pos = geom.Vec3List(mdl_pos_dict[m])
                        t_m_pos.ApplyTransform(transform)
                        rmsd = t_pos.GetRMSD(t_m_pos)
                        if rmsd < best_sc_rmsd:
                            best_sc_rmsd = rmsd
                            best_sc_mapping = (t,m)
                            best_sc_group_idx = group_idx

            if best_sc_mapping is None:
                break # no chain pair left to map

            t, m = best_sc_mapping
            mapping[t] = m
            mapped_trg_pos.extend(trg_pos_dict[t])
            mapped_mdl_pos.extend(mdl_pos_dict[m])
            trg_chain_groups[best_sc_group_idx].remove(t)
            mdl_chain_groups[best_sc_group_idx].remove(m)
            transform = _GetTransform(mapped_mdl_pos, mapped_trg_pos,
                                      iterative_superposition)

        # compute overall RMSD for current transform
        mapped_mdl_pos.ApplyTransform(transform)
        rmsd = mapped_trg_pos.GetRMSD(mapped_mdl_pos)
        if rmsd < best_rmsd:
            best_rmsd = rmsd
            best_mapping = mapping
    return best_mapping


def _GetRefPos(trg, mdl, trg_msas, mdl_alns, max_pos = None):
    """ Extracts reference positions which are present in trg and mdl

    Only MSA columns that are fully covered in both the target MSA and the
    MSA built from the model alignments are considered. If *max_pos* is
    given and more columns are available, exactly *max_pos* evenly spaced
    columns are selected.

    :param trg: Target structure
    :param mdl: Model structure
    :param trg_msas: One MSA per chem group, first sequence being the
                     reference sequence
    :param mdl_alns: One list of pairwise alignments per chem group (may be
                     empty), first sequence of each being the reference
    :param max_pos: Max number of positions per chain, None => no limit
    :returns: Tuple (trg_pos, mdl_pos). Each is a list (one entry per chem
              group) of lists of :class:`geom.Vec3List`. For a chem group
              without model alignments, mdl_pos gets an empty list and
              trg_pos one empty :class:`geom.Vec3List` per target sequence.
    """
    # select only backbone atoms, makes processing simpler later on
    # (just select res.atoms[0].GetPos() as ref pos)
    bb_trg = trg.Select("aname=\"CA\",\"C3'\"")
    bb_mdl = mdl.Select("aname=\"CA\",\"C3'\"")

    # mdl_alns are pairwise, let's construct MSAs
    mdl_msas = list()
    for aln_list in mdl_alns:
        if len(aln_list) > 0:
            tmp = aln_list[0].GetSequence(0)
            ref_seq = seq.CreateSequence(tmp.GetName(), tmp.GetGaplessString())
            mdl_msas.append(seq.alg.MergePairwiseAlignments(aln_list, ref_seq))
        else:
            mdl_msas.append(seq.CreateAlignment())

    trg_pos = list()
    mdl_pos = list()

    for trg_msa, mdl_msa in zip(trg_msas, mdl_msas):

        # An empty mdl_msa means that no model chain maps to the chem group
        # represented by trg_msa. No covered indices can be found in that
        # case and the position lists stay empty (see returns).
        if mdl_msa.GetCount() > 0:
            # make sure they have the same ref sequence (should be a given...)
            assert(trg_msa.GetSequence(0).GetGaplessString() == \
                   mdl_msa.GetSequence(0).GetGaplessString())

        # check which columns in MSAs are fully covered (indices relative to
        # first sequence)
        trg_indices = _GetFullyCoveredIndices(trg_msa)
        mdl_indices = _GetFullyCoveredIndices(mdl_msa)

        # get indices where both, mdl and trg, are fully covered
        indices = sorted(trg_indices.intersection(mdl_indices))

        # subsample if necessary: exactly max_pos evenly spaced indices
        # (integer arithmetic => no float rounding, all picks distinct as
        # len(indices) > max_pos)
        if max_pos is not None and len(indices) > max_pos:
            n = len(indices)
            indices = [indices[(i * n) // max_pos] for i in range(max_pos)]

        # translate to column indices in the respective MSAs
        trg_cols = _RefIndicesToColumnIndices(trg_msa, indices)
        mdl_cols = _RefIndicesToColumnIndices(mdl_msa, indices)

        # extract positions
        trg_pos.append([_ExtractMSAPos(trg_msa, s_idx, trg_cols, bb_trg)
                        for s_idx in range(trg_msa.GetCount())])
        # first seq in mdl_msa is ref sequence in trg and does not belong to mdl
        mdl_pos.append([_ExtractMSAPos(mdl_msa, s_idx, mdl_cols, bb_mdl)
                        for s_idx in range(1, mdl_msa.GetCount())])

    return (trg_pos, mdl_pos)


def _GetFullyCoveredIndices(msa):
    """ Helper for _GetRefPos

    Returns a set containing the indices relative to first sequence in msa
    which are fully covered in all other sequences

    --AA-A-A
    -BBBB-BB
    CCCC-C-C

    => (0,1,3)
    """
    indices = set()
    ref_idx = 0
    for col in msa:
        if sum(1 for olc in col if olc != '-') == col.GetRowCount():
            indices.add(ref_idx)
        if col[0] != '-':
            ref_idx += 1
    return indices


def _RefIndicesToColumnIndices(msa, indices):
    """ Helper for _GetRefPos

    Returns a list of mapped indices. indices refer to non-gap one letter
    codes in the first msa sequence. The returned mapped indices are
    translated to the according msa column indices
    """
    ref_idx = 0
    mapping = dict()
    for col_idx, col in enumerate(msa):
        if col[0] != '-':
            mapping[ref_idx] = col_idx
            ref_idx += 1
    return [mapping[i] for i in indices]


def _ExtractMSAPos(msa, s_idx, indices, view):
    """ Helper for _GetRefPos

    Returns a geom.Vec3List containing positions refering to given msa
    sequence.
    => Chain with corresponding name is mapped onto sequence and the position
    of the first atom of each residue specified in indices is extracted.
    Indices refers to column indices in msa!
    """
    s = msa.GetSequence(s_idx)
    s_v = _CSel(view, [s.GetName()])

    # sanity check
    assert(len(s.GetGaplessString()) == len(s_v.residues))

    residue_idx = [s.GetResidueIndex(i) for i in indices]
    return geom.Vec3List([s_v.residues[i].atoms[0].pos for i in residue_idx])


def _NChemGroupMappings(ref_chains, mdl_chains):
    """ Number of mappings within one chem group

    Choose which chains of the larger set participate, then permute them:
    binom(n_large, n_small) * n_small!

    :param ref_chains: Reference chains
    :type ref_chains: :class:`list` of :class:`str`
    :param mdl_chains: Model chains that are mapped onto *ref_chains*
    :type mdl_chains: :class:`list` of :class:`str`
    :returns: Number of possible mappings of *mdl_chains* onto *ref_chains*
              (floating point, as computed by scipy.special)
    """
    n_ref = len(ref_chains)
    n_mdl = len(mdl_chains)
    n_small = min(n_ref, n_mdl)
    n_large = max(n_ref, n_mdl)
    # for n_ref == n_mdl, binom evaluates to exactly 1.0 => n_ref!
    return binom(n_large, n_small) * factorial(n_small)


def _NMappings(ref_chains, mdl_chains):
    """ Number of mappings for a full chem mapping

    :param ref_chains: Chem groups of reference
    :type ref_chains: :class:`list` of :class:`list` of :class:`str`
    :param mdl_chains: Model chains that map onto those chem groups
    :type mdl_chains: :class:`list` of :class:`list` of :class:`str`
    :returns: Number of possible mappings of *mdl_chains* onto *ref_chains*
    """
    assert(len(ref_chains) == len(mdl_chains))
    n = 1
    for a,b in zip(ref_chains, mdl_chains):
        n *= _NChemGroupMappings(a,b)
    return n


def _NMappingsWithin(ref_chains, mdl_chains, max_mappings):
    """ Check whether total number of mappings is smaller than given maximum

    In principle the same as :func:`_NMappings` but it stops as soon as the
    maximum is hit.

    :param ref_chains: Chem groups of reference
    :type ref_chains: :class:`list` of :class:`list` of :class:`str`
    :param mdl_chains: Model chains that map onto those chem groups
    :type mdl_chains: :class:`list` of :class:`list` of :class:`str`
    :param max_mappings: Number of max allowed mappings
    :returns: Whether number of possible mappings of *mdl_chains* onto
              *ref_chains* is below or equal *max_mappings*.
    """
    assert(len(ref_chains) == len(mdl_chains))
    n = 1
    for a,b in zip(ref_chains, mdl_chains):
        n *= _NChemGroupMappings(a,b)
        if n > max_mappings:
            return False
    return True


def _RefSmallerGenerator(ref_chains, mdl_chains):
    """ Returns all possible ways to map mdl_chains onto ref_chains

    Specific for the case where len(ref_chains) < len(mdl_chains)
    """
    for c in itertools.combinations(mdl_chains, len(ref_chains)):
        for p in itertools.permutations(c):
            yield list(p)


def _RefLargerGenerator(ref_chains, mdl_chains):
    """ Returns all possible ways to map mdl_chains onto ref_chains

    Specific for the case where len(ref_chains) > len(mdl_chains)
    Ref chains without mapped mdl chain are assigned None
    """
    n_ref = len(ref_chains)
    n_mdl = len(mdl_chains)
    for c in itertools.combinations(range(n_ref), n_mdl):
        for p in itertools.permutations(mdl_chains):
            ret_list = [None] * n_ref
            for idx, ch in zip(c, p):
                ret_list[idx] = ch
            yield ret_list


def _RefEqualGenerator(ref_chains, mdl_chains):
    """ Returns all possible ways to map mdl_chains onto ref_chains

    Specific for the case where len(ref_chains) == len(mdl_chains)
    """
    for p in itertools.permutations(mdl_chains):
        yield list(p)


def _ConcatIterators(iterators):
    """ Yields the cartesian product of *iterators*, each item as a list """
    for item in itertools.product(*iterators):
        yield list(item)


def _ChainMappings(ref_chains, mdl_chains, n_max=None):
    """Returns all possible ways to map *mdl_chains* onto fixed *ref_chains*

    :param ref_chains: List of list of chemically equivalent chains in reference
    :type ref_chains: :class:`list` of :class:`list`
    :param mdl_chains: Equally long list of list of chemically equivalent chains
                       in model that map on those ref chains.
    :type mdl_chains: :class:`list` of :class:`list`
    :param n_max: Aborts and raises :class:`RuntimeError` if max number of
                  mappings is above this threshold.
    :type n_max: :class:`int`
    :returns: Iterator over all possible mappings of *mdl_chains* onto fixed
              *ref_chains*. Potentially contains None as padding when number of
              model chains for a certain mapping is smaller than the according
              reference chains.
              Example: _ChainMappings([['A', 'B', 'C'], ['D', 'E']],
                                      [['x', 'y'], ['i', 'j']])
              gives an iterator over: [[['x', 'y', None], ['i', 'j']],
                                       [['x', 'y', None], ['j', 'i']],
                                       [['y', 'x', None], ['i', 'j']],
                                       [['y', 'x', None], ['j', 'i']],
                                       [['x', None, 'y'], ['i', 'j']],
                                       [['x', None, 'y'], ['j', 'i']],
                                       [['y', None, 'x'], ['i', 'j']],
                                       [['y', None, 'x'], ['j', 'i']],
                                       [[None, 'x', 'y'], ['i', 'j']],
                                       [[None, 'x', 'y'], ['j', 'i']],
                                       [[None, 'y', 'x'], ['i', 'j']],
                                       [[None, 'y', 'x'], ['j', 'i']]]
    :raises: :class:`RuntimeError` if the number of mappings exceeds *n_max*
             or if any reference chem group is empty
    """
    assert(len(ref_chains) == len(mdl_chains))

    if n_max is not None:
        if not _NMappingsWithin(ref_chains, mdl_chains, n_max):
            raise RuntimeError(f"Too many mappings. Max allowed: {n_max}")

    # one iterator per mapping representing all mdl combinations relative to
    # reference
    iterators = list()
    for ref, mdl in zip(ref_chains, mdl_chains):
        if len(ref) == 0:
            raise RuntimeError("Expect at least one chain in ref chem group")
        if len(ref) == len(mdl):
            iterators.append(_RefEqualGenerator(ref, mdl))
        elif len(ref) < len(mdl):
            iterators.append(_RefSmallerGenerator(ref, mdl))
        else:
            iterators.append(_RefLargerGenerator(ref, mdl))

    return _ConcatIterators(iterators)


def _GetTransform(pos_one, pos_two, iterative):
    """ Computes minimal RMSD superposition for pos_one onto pos_two

    :param pos_one: Positions that should be superposed onto *pos_two*
    :type pos_one: :class:`geom.Vec3List`
    :param pos_two: Reference positions
    :type pos_two: :class:`geom.Vec3List`
    :iterative: Whether iterative superposition should be used. Iterative
                potentially raises, uses standard superposition as fallback.
    :type iterative: :class:`bool`
    :returns: Transformation matrix to superpose *pos_one* onto *pos_two*
    :rtype: :class:`geom.Mat4`
    """
    res = None
    if iterative:
        try:
            res = mol.alg.IterativeSuperposeSVD(pos_one, pos_two)
        except Exception:
            # deliberate best-effort: any failure of the iterative variant
            # triggers the standard superposition below. Exception (not a
            # bare except) so KeyboardInterrupt/SystemExit still propagate.
            res = None
    if res is None:
        res = mol.alg.SuperposeSVD(pos_one, pos_two)
    return res.transformation

# specify public interface
__all__ = ('ChainMapper', 'ReprResult', 'MappingResult')