diff --git a/actions/ost-compare-ligand-structures b/actions/ost-compare-ligand-structures index 58eeb6eb3da64daa581b0ce701ed58b454322a80..42128cb1b849bfc3245c71217d62979d46c5cf43 100644 --- a/actions/ost-compare-ligand-structures +++ b/actions/ost-compare-ligand-structures @@ -561,16 +561,10 @@ def _Process(model, model_ligands, reference, reference_ligands, args): lddt_pli["reference_ligand"] = reference_ligands_map[ lddt_pli.pop("target_ligand").hash_code] lddt_pli["model_ligand"] = model_key - transform_data = lddt_pli["transform"].data - lddt_pli["transform"] = [transform_data[i:i + 4] - for i in range(0, len(transform_data), - 4)] lddt_pli["bs_ref_res"] = [_QualifiedResidueNotation(r) for r in lddt_pli["bs_ref_res"]] - lddt_pli["bs_ref_res_mapped"] = [_QualifiedResidueNotation(r) for r in - lddt_pli["bs_ref_res_mapped"]] - lddt_pli["bs_mdl_res_mapped"] = [_QualifiedResidueNotation(r) for r in - lddt_pli["bs_mdl_res_mapped"]] + lddt_pli["bs_mdl_res"] = [_QualifiedResidueNotation(r) for r in + lddt_pli["bs_mdl_res"]] lddt_pli["inconsistent_residues"] = ["%s-%s" %( _QualifiedResidueNotation(x), _QualifiedResidueNotation(y)) for x,y in lddt_pli[ "inconsistent_residues"]] @@ -610,7 +604,7 @@ def _Main(): args = _ParseArgs() ost.PushVerbosityLevel(args.verbosity) if args.verbosity < 4: - sys.tracebacklimit = 0 + sys.tracebacklimit = 100 _CheckCompoundLib() try: # Load structures diff --git a/modules/mol/alg/pymod/lddt.py b/modules/mol/alg/pymod/lddt.py index adc0a8901971657d39f8eb3f692f82cf615efaa8..cc9d960867ad4bcaf6335fd6d9a8ad9b6c1b857c 100644 --- a/modules/mol/alg/pymod/lddt.py +++ b/modules/mol/alg/pymod/lddt.py @@ -430,7 +430,8 @@ class lDDTScorer: chain_mapping=None, no_interchain=False, no_intrachain=False, penalize_extra_chains=False, residue_mapping=None, return_dist_test=False, - check_resnames=True, add_mdl_contacts=False): + check_resnames=True, add_mdl_contacts=False, + process_model_out=None, interaction_data=None): """Computes lDDT of *model* - 
globally and per-residue :param model: Model to be scored - models are preferably scored upon @@ -520,6 +521,10 @@ class lDDTScorer: be added if the respective atom pair is not resolved in the target. :type add_mdl_contacts: :class:`bool` + :param process_model_out: Pro param - don't use + :type process_model_out: :class:`tuple` + :param interaction_data: Pro param - don't use + :type interaction_data: :class:`tuple` :returns: global and per-residue lDDT scores as a tuple - first element is global lDDT score (None if *target* has no @@ -550,38 +555,56 @@ class lDDTScorer: # data objects defining model data - see _ProcessModel for rough # description - pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \ - res_indices, symmetries = self._ProcessModel(model, chain_mapping, - residue_mapping = residue_mapping, - thresholds = thresholds, - check_resnames = check_resnames) + if process_model_out is None: + pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \ + res_indices, ref_res_indices, symmetries = \ + self._ProcessModel(model, chain_mapping, + residue_mapping = residue_mapping, + thresholds = thresholds, + check_resnames = check_resnames) + else: + pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \ + res_indices, symmetries = process_model_out if no_interchain and no_intrachain: raise RuntimeError("no_interchain and no_intrachain flags are " "mutually exclusive") - if no_interchain: - sym_ref_indices = self.sym_ref_indices_sc - sym_ref_distances = self.sym_ref_distances_sc - ref_indices = self.ref_indices_sc - ref_distances = self.ref_distances_sc - elif no_intrachain: - sym_ref_indices = self.sym_ref_indices_ic - sym_ref_distances = self.sym_ref_distances_ic - ref_indices = self.ref_indices_ic - ref_distances = self.ref_distances_ic + + sym_ref_indices = None + sym_ref_distances = None + ref_indices = None + ref_distances = None + + if interaction_data is None: + if no_interchain: + sym_ref_indices = self.sym_ref_indices_sc + 
sym_ref_distances = self.sym_ref_distances_sc + ref_indices = self.ref_indices_sc + ref_distances = self.ref_distances_sc + elif no_intrachain: + sym_ref_indices = self.sym_ref_indices_ic + sym_ref_distances = self.sym_ref_distances_ic + ref_indices = self.ref_indices_ic + ref_distances = self.ref_distances_ic + else: + sym_ref_indices = self.sym_ref_indices + sym_ref_distances = self.sym_ref_distances + ref_indices = self.ref_indices + ref_distances = self.ref_distances + + if add_mdl_contacts: + ref_indices, ref_distances = \ + self._AddMdlContacts(model, res_atom_indices, res_atom_hashes, + ref_indices, ref_distances, + no_interchain, no_intrachain) + # recompute symmetry related indices/distances + sym_ref_indices, sym_ref_distances = \ + lDDTScorer._NonSymDistances(self.n_atoms, self.symmetric_atoms, + ref_indices, ref_distances) else: - sym_ref_indices = self.sym_ref_indices - sym_ref_distances = self.sym_ref_distances - ref_indices = self.ref_indices - ref_distances = self.ref_distances - - if add_mdl_contacts: - ref_indices, ref_distances, \ - sym_ref_indices, sym_ref_distances = \ - self._AddMdlContacts(model, res_atom_indices, res_atom_hashes, - ref_indices, ref_distances, - no_interchain, no_intrachain) + sym_ref_indices, sym_ref_distances, ref_indices, ref_distances = \ + interaction_data self._ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices, sym_ref_distances) @@ -689,6 +712,9 @@ class lDDTScorer: # indices of the scored residues res_indices = list() + # respective residue indices in reference + ref_res_indices = list() + # Will contain one element per symmetry group symmetries = list() @@ -724,6 +750,7 @@ class lDDTScorer: res_atom_indices.append(list()) res_atom_hashes.append(list()) res_indices.append(current_model_res_idx) + ref_res_indices.append(r_idx) for a_idx, a in enumerate(atoms): if a.IsValid(): p = a.GetPos() @@ -748,7 +775,7 @@ class lDDTScorer: symmetries.append(sym_indices) return (pos, res_ref_atom_indices, 
res_atom_indices, res_atom_hashes, - res_indices, symmetries) + res_indices, ref_res_indices, symmetries) def _GetExtraModelChainPenalty(self, model, chain_mapping): @@ -1006,12 +1033,7 @@ class lDDTScorer: np.sqrt(tmp, out=tmp) # distances against all relevant atoms ref_distances[i] = np.append(ref_distances[i], tmp) - # recompute symmetry related indices/distances - sym_ref_indices, sym_ref_distances = \ - lDDTScorer._NonSymDistances(self.n_atoms, self.symmetric_atoms, - ref_indices, ref_distances) - - return (ref_indices, ref_distances, sym_ref_indices, sym_ref_distances) + return (ref_indices, ref_distances) diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py index ce98690ddbf60f5d7511bee9dfb0492e5e5b2232..c5637ca928cb968c28fe49bed0266991d11df21a 100644 --- a/modules/mol/alg/pymod/ligand_scoring.py +++ b/modules/mol/alg/pymod/ligand_scoring.py @@ -4,10 +4,13 @@ import numpy as np import numpy.ma as np_ma import networkx +from ost import io from ost import mol from ost import geom +from ost import seq from ost import LogError, LogWarning, LogScript, LogInfo, LogVerbose, LogDebug from ost.mol.alg import chain_mapping +from ost.mol.alg import lddt class LigandScorer: @@ -282,7 +285,8 @@ class LigandScorer: lddt_pli_radius=6.0, lddt_lp_radius=10.0, model_bs_radius=20, binding_sites_topn=100000, global_chain_mapping=False, rmsd_assignment=False, n_max_naive=12, max_symmetries=1e5, - custom_mapping=None, unassigned=False, full_bs_search=False): + custom_mapping=None, unassigned=False, full_bs_search=False, + add_mdl_contacts=False): if isinstance(model, mol.EntityView): self.model = mol.CreateEntityFromView(model, False) @@ -337,6 +341,7 @@ class LigandScorer: self.unassigned = unassigned self.coverage_delta = coverage_delta self.full_bs_search = full_bs_search + self.add_mdl_contacts = add_mdl_contacts # scoring matrices self._rmsd_matrix = None @@ -363,6 +368,7 @@ class LigandScorer: # for localized GetRepr searches 
self._chem_mapping = None self._chem_group_alns = None + self._ref_mdl_alns = None self._chain_mapping_mdl = None self._get_repr_input = dict() @@ -596,56 +602,6 @@ class LigandScorer: new_editor.UpdateICS() return extracted_ligands - @staticmethod - def _build_binding_site_entity(ligand, residues, extra_residues=[]): - """ Build an entity with all the binding site residues in chain A - and the ligand in chain _. Residues are renumbered consecutively from - 1. The ligand is assigned residue number 1 and residue name LIG. - Residues in extra_residues not in `residues` in the model are added - at the end of chain A. - - :param ligand: the Residue Handle of the ligand - :type ligand: :class:`~ost.mol.ResidueHandle` - :param residues: a list of binding site residues - :type residues: :class:`list` of :class:`~ost.mol.ResidueHandle` - :param extra_residues: an optional list with addition binding site - residues. Residues in this list which are not - in `residues` will be added at the end of chain - A. This allows for instance adding unmapped - residues missing from the model into the - reference binding site. - :type extra_residues: :class:`list` of :class:`~ost.mol.ResidueHandle` - :rtype: :class:`~ost.mol.EntityHandle` - """ - bs_ent = mol.CreateEntity() - ed = bs_ent.EditXCS() - bs_chain = ed.InsertChain("A") - seen_res_qn = [] - for resnum, old_res in enumerate(residues, 1): - seen_res_qn.append(old_res.qualified_name) - new_res = ed.AppendResidue(bs_chain, old_res.handle, - deep=True) - ed.SetResidueNumber(new_res, mol.ResNum(resnum)) - - # Add extra residues at the end. 
- for extra_res in extra_residues: - if extra_res.qualified_name not in seen_res_qn: - resnum += 1 - seen_res_qn.append(extra_res.qualified_name) - new_res = ed.AppendResidue(bs_chain, - extra_res.handle, - deep=True) - ed.SetResidueNumber(new_res, mol.ResNum(resnum)) - # Add the ligand in chain _ - ligand_chain = ed.InsertChain("_") - ligand_res = ed.AppendResidue(ligand_chain, ligand.handle, - deep=True) - ed.RenameResidue(ligand_res, "LIG") - ed.SetResidueNumber(ligand_res, mol.ResNum(1)) - ed.UpdateICS() - - return bs_ent - def _compute_scores(self): """ Compute the RMSD and lDDT-PLI scores for every possible target-model @@ -712,7 +668,7 @@ class LigandScorer: # Ligand assignment makes assumptions here, and is likely # to not work properly if this differs. There is no reason # it would ever do, so let's just check it - raise Exception("Ligand scoring bug: discrepency between " + raise Exception("Ligand scoring bug: discrepancy between " "RMSD and lDDT-PLI definition.") if rmsd_result is not None: # Now we assume both rmsd_result and lddt_pli_result are defined @@ -765,80 +721,380 @@ class LigandScorer: return best_rmsd_result - def _compute_lddtpli(self, symmetries, target_ligand, model_ligand): - ref_bs = self.get_target_binding_site(target_ligand) - best_lddt_result = None - for r_i, r in enumerate(self.get_repr(target_ligand, model_ligand)): - ref_bs_ent = self._build_binding_site_entity( - target_ligand, r.ref_residues, - r.substructure.residues) - ref_bs_ent_ligand = ref_bs_ent.FindResidue("_", 1) # by definition - - custom_compounds = { - ref_bs_ent_ligand.name: - mol.alg.lddt.CustomCompound.FromResidue( - ref_bs_ent_ligand)} - lddt_scorer = mol.alg.lddt.lDDTScorer( - ref_bs_ent, - custom_compounds=custom_compounds, - inclusion_radius=self.lddt_pli_radius) - - lddt_tot = 4 * sum([len(x) for x in lddt_scorer.ref_indices_ic]) - if lddt_tot == 0: - # it's a space ship in the reference! 
class LigandScorer:
    # Excerpt of LigandScorer: only the members that this change adds or
    # rewrites are reproduced (the lDDT-PLI computation, its private helpers
    # and the ref_mdl_alns property it relies on).

    # Names tried, in this order, for the temporary chain that receives the
    # copy of a ligand residue.
    _LDDTPLI_LIGAND_CHAIN_NAMES = ("hugo_the_cat_terminator",
                                   "ida_the_cheese_monster")

    def _compute_lddtpli(self, symmetries, target_ligand, model_ligand,
                         thresholds=None):
        """Compute lDDT-PLI for one target ligand / model ligand pair.

        Every chain mapping enumerated by
        :func:`chain_mapping._ChainMappings` for the chains around the two
        ligands is combined with every entry of *symmetries*. The highest
        score found is reported. Only distances that involve a ligand atom
        are evaluated.

        :param symmetries: Ligand atom index mappings, one
                           ``(trg_sym, mdl_sym)`` pair per symmetry. Atom
                           ``mdl_sym[k]`` of *model_ligand* is matched with
                           atom ``trg_sym[k]`` of *target_ligand*.
        :param target_ligand: Ligand residue of the target
        :param model_ligand: Ligand residue of the model
        :param thresholds: Distance difference thresholds. Defaults to
                           ``[0.5, 1.0, 2.0, 4.0]`` if None.
        :type thresholds: :class:`list` of :class:`float`
        :returns: :class:`dict` with keys "lddt_pli", "target_ligand",
                  "model_ligand", "bs_ref_res", "bs_mdl_res" and
                  "inconsistent_residues". "lddt_pli" is 0.0 if no chain is
                  found around one of the ligands or if no mapping/symmetry
                  yields any expected contact.
        :raises: :class:`RuntimeError` if no temporary ligand chain can be
                 inserted or if a model chain is missing in the chem mapping
        """
        if thresholds is None:
            thresholds = [0.5, 1.0, 2.0, 4.0]

        trg = self.chain_mapper.target
        mdl = self.chain_mapping_mdl
        radius = self.lddt_pli_radius

        # residues with at least one atom within radius of a ligand atom
        trg_residues = set()
        for lig_at in target_ligand.atoms:
            for close_at in trg.FindWithin(lig_at.GetPos(), radius):
                trg_residues.add(close_at.GetResidue())

        mdl_residues = set()
        for lig_at in model_ligand.atoms:
            for close_at in mdl.FindWithin(lig_at.GetPos(), radius):
                mdl_residues.add(close_at.GetResidue())

        #####################
        # setup lDDT scorer #
        #####################

        # max dist for peptide/nucleotide atom towards ligand for which
        # non-zero contribution is possible
        max_r = radius + max(thresholds)

        trg_chains = set()
        for lig_at in target_ligand.atoms:
            for close_at in trg.FindWithin(lig_at.GetPos(), max_r):
                trg_chains.add(close_at.GetChain().GetName())

        if not trg_chains:
            # nothing around the target ligand
            return self._lddtpli_result(0.0, target_ligand, model_ligand,
                                        trg_residues, mdl_residues)

        chem_groups = [[ch for ch in grp if ch in trg_chains]
                       for grp in self.chain_mapper.chem_groups]

        query = "cname=" + ','.join(mol.QueryQuoteName(ch)
                                    for ch in trg_chains)
        trg_bs = mol.CreateEntityFromView(trg.Select(query), True)
        # the editor is kept referenced for the rest of this method, as in
        # the original implementation
        trg_editor, trg_ligand_chain, trg_ligand_res = \
            self._lddtpli_add_ligand_chain(trg_bs, target_ligand)

        custom_compounds = {
            trg_ligand_res.name:
            lddt.CustomCompound.FromResidue(trg_ligand_res)}
        scorer = lddt.lDDTScorer(trg_bs,
                                 custom_compounds=custom_compounds,
                                 inclusion_radius=radius)

        ###############
        # setup model #
        ###############

        # flag residues within max_r of the model ligand with int prop "bs"
        for res in mdl.residues:
            res.SetIntProp("bs", 0)
        for lig_at in model_ligand.atoms:
            for close_at in mdl.FindWithin(lig_at.GetPos(), max_r):
                close_at.GetResidue().SetIntProp("bs", 1)

        mdl_bs = mol.CreateEntityFromView(mdl.Select("grbs:0=1"), True)
        # collected before the ligand chain is inserted => protein/nucleotide
        # chains only
        mdl_chains = {ch.name for ch in mdl_bs.chains}

        if not mdl_chains:
            # nothing around the model ligand
            return self._lddtpli_result(0.0, target_ligand, model_ligand,
                                        trg_residues, mdl_residues)

        mdl_editor, mdl_ligand_chain, mdl_ligand_res = \
            self._lddtpli_add_ligand_chain(mdl_bs, model_ligand)

        ####################
        # Setup alignments #
        ####################
        chem_mapping = [[ch for ch in grp if ch in mdl_chains]
                        for grp in self.chem_mapping]

        # ref_mdl_alns refers to full chain mapper trg and mdl structures
        # => need to adapt mdl sequence that only contain residues in contact
        #    with ligand
        cut_ref_mdl_alns = dict()
        for ref_chem_group, mdl_chem_group in zip(chem_groups, chem_mapping):
            for ref_ch in ref_chem_group:
                for mdl_ch in mdl_chem_group:
                    cut_ref_mdl_alns[(ref_ch, mdl_ch)] = \
                        self._lddtpli_cut_aln(ref_ch, mdl_ch, mdl, mdl_bs)

        ###############################################################
        # compute lDDT for all possible chain mappings and symmetries #
        ###############################################################
        best_score = None

        # cache for model contacts towards non mapped trg chains
        # key: tuple in form (mdl_ch, trg_ch)
        # value: yet another dict with
        #   key: ligand_atom_hash
        #   value: n contacts towards respective trg chain that can be mapped
        non_mapped_cache = dict()

        # cache as helper to compute non_mapped_cache
        # key: ligand_atom_hash
        # value: list of mdl atom handles that are within self.lddt_pli_radius
        close_atom_cache = dict()

        # the ligand chain is the one that was appended last
        ligand_start_idx = scorer.chain_start_indices[-1]

        if not self.add_mdl_contacts:
            # reference distances do not depend on mapping/symmetry in this
            # case => prepare them once
            ref_indices = scorer.ref_indices_ic
            ref_distances = scorer.ref_distances_ic
            self._lddtpli_keep_ligand_contacts(ref_indices, ref_distances,
                                               ligand_start_idx)
            sym_ref_indices, sym_ref_distances = \
                lddt.lDDTScorer._NonSymDistances(scorer.n_atoms,
                                                 scorer.symmetric_atoms,
                                                 ref_indices, ref_distances)

        for mapping in chain_mapping._ChainMappings(chem_groups,
                                                    chem_mapping):

            lddt_chain_mapping = dict()
            lddt_alns = dict()
            for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
                for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
                    # some mdl chains can be None
                    if mdl_ch is not None:
                        lddt_chain_mapping[mdl_ch] = ref_ch
                        lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)]

            # add ligand to lddt_chain_mapping/lddt_alns
            lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
            ligand_aln = seq.CreateAlignment()
            ligand_aln.AddSequence(
                seq.CreateSequence(trg_ligand_chain.name,
                                   trg_ligand_res.GetOneLetterCode()))
            ligand_aln.AddSequence(
                seq.CreateSequence(mdl_ligand_chain.name,
                                   mdl_ligand_res.GetOneLetterCode()))
            lddt_alns[mdl_ligand_chain.name] = ligand_aln

            unmapped_chains = list()
            if self.add_mdl_contacts:
                unmapped_chains = self._lddtpli_unmapped_chains(
                    mdl_chains, lddt_chain_mapping, trg, mdl)

            for trg_sym, mdl_sym in symmetries:
                self._lddtpli_rename_ligand_atoms(mdl_editor, mdl_ligand_res,
                                                  trg_ligand_res,
                                                  trg_sym, mdl_sym)

                pos, _, res_atom_indices, res_atom_hashes, _, _, \
                lddt_symmetries = \
                    scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
                                         residue_mapping=lddt_alns,
                                         thresholds=thresholds,
                                         check_resnames=self.check_resnames)

                if self.add_mdl_contacts:
                    # model contacts depend on the current atom naming
                    # => reference distances are derived per symmetry
                    ref_indices, ref_distances = \
                        scorer._AddMdlContacts(mdl_bs, res_atom_indices,
                                               res_atom_hashes,
                                               scorer.ref_indices_ic,
                                               scorer.ref_distances_ic,
                                               False, True)
                    self._lddtpli_keep_ligand_contacts(ref_indices,
                                                       ref_distances,
                                                       ligand_start_idx)
                    sym_ref_indices, sym_ref_distances = \
                        lddt.lDDTScorer._NonSymDistances(
                            scorer.n_atoms, scorer.symmetric_atoms,
                            ref_indices, ref_distances)

                scorer._ResolveSymmetries(pos, thresholds, lddt_symmetries,
                                          sym_ref_indices, sym_ref_distances)

                # only compute lDDT on ligand residue (last processed residue)
                n_exp = sum(len(ref_indices[i])
                            for i in range(ligand_start_idx, scorer.n_atoms))
                conserved = np.sum(scorer._EvalAtoms(pos,
                                                     res_atom_indices[-1],
                                                     thresholds,
                                                     ref_indices,
                                                     ref_distances), axis=0)

                # add number of mdl contacts which can be mapped to target
                # as non-fulfilled contacts
                for ch_tuple in unmapped_chains:
                    if ch_tuple not in non_mapped_cache:
                        non_mapped_cache[ch_tuple] = \
                            self._lddtpli_mappable_contact_counts(
                                ch_tuple[0], ch_tuple[1], trg, mdl, mdl_bs,
                                mdl_ligand_res, close_atom_cache)
                    counts = non_mapped_cache[ch_tuple]
                    mdl_lig_atoms = mdl_ligand_res.atoms
                    for a_idx in mdl_sym:
                        n_exp += counts[mdl_lig_atoms[a_idx].hash_code]

                if n_exp == 0:
                    # no expected contact => nothing to score, avoids 0/0
                    continue

                score = np.mean(conserved / n_exp)
                if best_score is None or score > best_score:
                    best_score = score

        if best_score is None:
            # no mapping/symmetry could be scored - report it the same way
            # as the "nothing around the ligand" cases above
            best_score = 0.0

        return self._lddtpli_result(best_score, target_ligand, model_ligand,
                                    trg_residues, mdl_residues)

    @staticmethod
    def _lddtpli_result(score, target_ligand, model_ligand,
                        trg_residues, mdl_residues):
        """Assemble the result dict returned by _compute_lddtpli."""
        return {"lddt_pli": score,
                "target_ligand": target_ligand,
                "model_ligand": model_ligand,
                "bs_ref_res": trg_residues,
                "bs_mdl_res": mdl_residues,
                "inconsistent_residues": list()}

    @classmethod
    def _lddtpli_add_ligand_chain(cls, ent, ligand):
        """Append a deep copy of *ligand* to *ent* in a new chain.

        :returns: :class:`tuple` (editor, ligand chain, ligand residue)
        :raises: :class:`RuntimeError` if none of the reserved chain names
                 can be inserted
        """
        editor = ent.EditXCS(mol.BUFFERED_EDIT)
        chain = None
        for cname in cls._LDDTPLI_LIGAND_CHAIN_NAMES:
            try:
                chain = editor.InsertChain(cname)
            except Exception:
                # presumably the name is taken already - try the next one
                continue
            break
        if chain is None:
            raise RuntimeError("Failed to insert a temporary ligand chain, "
                               "tried chain names: " +
                               ", ".join(cls._LDDTPLI_LIGAND_CHAIN_NAMES))
        res = editor.AppendResidue(chain, ligand, deep=True)
        return editor, chain, res

    @staticmethod
    def _lddtpli_rename_ligand_atoms(mdl_editor, mdl_ligand_res,
                                     trg_ligand_res, trg_sym, mdl_sym):
        """Rename model ligand atoms to the target names given a symmetry."""
        # remove assert after proper testing - testing assumption made during
        # development
        assert sorted(trg_sym) == list(range(len(trg_ligand_res.atoms)))
        # all atoms get a placeholder name first - presumably to avoid
        # transient name clashes while renaming
        for atom in mdl_ligand_res.atoms:
            mdl_editor.RenameAtom(atom, "asdf")
        trg_atoms = trg_ligand_res.atoms
        mdl_atoms = mdl_ligand_res.atoms
        for mdl_anum, trg_anum in zip(mdl_sym, trg_sym):
            mdl_editor.RenameAtom(mdl_atoms[mdl_anum],
                                  trg_atoms[trg_anum].name)

    @staticmethod
    def _lddtpli_keep_ligand_contacts(ref_indices, ref_distances,
                                      ligand_start_idx):
        """Drop, in place, all contacts between two non-ligand atoms.

        Atoms with index >= *ligand_start_idx* are ligand atoms; their
        entries are left untouched.
        """
        for at_idx in range(ligand_start_idx):
            mask = ref_indices[at_idx] >= ligand_start_idx
            ref_indices[at_idx] = ref_indices[at_idx][mask]
            ref_distances[at_idx] = ref_distances[at_idx][mask]

    def _lddtpli_cut_aln(self, ref_ch, mdl_ch, mdl, mdl_bs):
        """Copy the ref/mdl alignment, gapping mdl residues not in *mdl_bs*."""
        aln = self.ref_mdl_alns[(ref_ch, mdl_ch)]
        mdl_bs_chain = mdl_bs.FindChain(mdl_ch)
        aln.AttachView(1, mdl.Select("cname=" + mol.QueryQuoteName(mdl_ch)))
        cut_mdl_seq = ['-'] * aln.GetLength()
        for col_idx, col in enumerate(aln):
            res = col.GetResidue(1)
            if not res.IsValid():
                continue
            if mdl_bs_chain.FindResidue(res.GetNumber()).IsValid():
                cut_mdl_seq[col_idx] = col[1]
        cut_aln = seq.CreateAlignment()
        cut_aln.AddSequence(aln.GetSequence(0))
        cut_aln.AddSequence(seq.CreateSequence(mdl_ch, ''.join(cut_mdl_seq)))
        return cut_aln

    def _lddtpli_unmapped_chains(self, mdl_chains, lddt_chain_mapping,
                                 trg, mdl):
        """Pair binding site model chains that are not mapped with a trg chain.

        Used to estimate a penalty for unsatisfied model contacts from chains
        that are not in the local trg binding site, but can be mapped in the
        target. The trg chain with the closest geometric center that can be
        mapped to the mdl chain according to the chem mapping is used. An
        alternative would be to search for the target chain with the minimal
        number of additional contacts. There is no good solution for this
        problem...

        :returns: :class:`list` of :class:`tuple` (mdl_ch, trg_ch)
        :raises: :class:`RuntimeError` if a model chain is not part of any
                 group in the chem mapping
        """
        mapped_trg_chains = set(lddt_chain_mapping.values())
        unmapped_chains = list()
        for mdl_ch in mdl_chains:
            if mdl_ch in lddt_chain_mapping:
                continue
            chem_group_idx = None
            for grp_idx, grp in enumerate(self.chem_mapping):
                if mdl_ch in grp:
                    chem_group_idx = grp_idx
                    break
            if chem_group_idx is None:
                raise RuntimeError("This should never happen... "
                                   "ask Gabriel...")
            mdl_center = mdl.FindChain(mdl_ch).geometric_center
            closest_trg_ch = None
            closest_trg_ch_dist = None
            # full (unfiltered) chem groups of the chain mapper - candidates
            # are exactly the chains outside of the current mapping
            for trg_ch in self.chain_mapper.chem_groups[chem_group_idx]:
                if trg_ch in mapped_trg_chains:
                    continue
                trg_center = trg.FindChain(trg_ch).geometric_center
                dist = geom.Distance(mdl_center, trg_center)
                if closest_trg_ch_dist is None or dist < closest_trg_ch_dist:
                    closest_trg_ch_dist = dist
                    closest_trg_ch = trg_ch
            if closest_trg_ch is not None:
                unmapped_chains.append((mdl_ch, closest_trg_ch))
        return unmapped_chains

    def _lddtpli_mappable_contact_counts(self, mdl_ch, trg_ch, trg, mdl,
                                         mdl_bs, mdl_ligand_res,
                                         close_atom_cache):
        """Count, per model ligand atom, close atoms mappable to *trg_ch*.

        An atom of chain *mdl_ch* in *mdl_bs* is mappable if its residue is
        aligned to a residue of *trg_ch* which has an atom of the same name.

        :param close_atom_cache: :class:`dict` with ligand atom hash as key
                                 and the non-ligand atoms of *mdl_bs* within
                                 self.lddt_pli_radius as value. Missing
                                 entries are added.
        :returns: :class:`dict` with ligand atom hash as key and the number
                  of close mappable atoms as value
        """
        mappable_atoms = set()
        aln = self.ref_mdl_alns[(trg_ch, mdl_ch)]
        mdl_bs_chain = mdl_bs.FindChain(mdl_ch)
        aln.AttachView(0, trg.Select("cname=" + mol.QueryQuoteName(trg_ch)))
        aln.AttachView(1, mdl.Select("cname=" + mol.QueryQuoteName(mdl_ch)))
        for col in aln:
            mdl_r = col.GetResidue(1)
            if not mdl_r.IsValid():
                continue
            bs_r = mdl_bs_chain.FindResidue(mdl_r.GetNumber())
            if not bs_r.IsValid():
                continue
            trg_r = col.GetResidue(0)
            if not trg_r.IsValid():
                continue
            for bs_a in bs_r.atoms:
                if trg_r.FindAtom(bs_a.GetName()).IsValid():
                    mappable_atoms.add(bs_a.hash_code)

        lig_hash = mdl_ligand_res.hash_code
        counts = dict()
        for lig_a in mdl_ligand_res.atoms:
            if lig_a.hash_code not in close_atom_cache:
                within = mdl_bs.FindWithin(lig_a.GetPos(),
                                           self.lddt_pli_radius)
                close_atom_cache[lig_a.hash_code] = \
                    [a for a in within
                     if a.GetResidue().GetHashCode() != lig_hash]
            close_atoms = close_atom_cache[lig_a.hash_code]
            counts[lig_a.hash_code] = sum(1 for a in close_atoms
                                          if a.hash_code in mappable_atoms)
        return counts

    @property
    def ref_mdl_alns(self):
        """Reference/model chain alignments, computed on first access.

        Looked up with ``(ref_chain_name, mdl_chain_name)`` tuples.
        """
        if self._ref_mdl_alns is None:
            self._ref_mdl_alns = chain_mapping._GetRefMdlAlns(
                self.chain_mapper.chem_groups,
                self.chain_mapper.chem_group_alignments,
                self.chem_mapping,
                self.chem_group_alns)
        return self._ref_mdl_alns