From d4d24587ce7ef6e016fb5bc80e6bf122eefaf935 Mon Sep 17 00:00:00 2001 From: Gabriel Studer <gabriel.studer@unibas.ch> Date: Fri, 17 May 2024 18:24:39 +0200 Subject: [PATCH] lddt-pli: first implementation of added mdl contacts with heavy caching Can be considered backup commit and still contains debug output --- modules/mol/alg/pymod/ligand_scoring.py | 488 +++++++++++++++++------- 1 file changed, 348 insertions(+), 140 deletions(-) diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py index 538920534..cd014db30 100644 --- a/modules/mol/alg/pymod/ligand_scoring.py +++ b/modules/mol/alg/pymod/ligand_scoring.py @@ -11,6 +11,7 @@ from ost import seq from ost import LogError, LogWarning, LogScript, LogInfo, LogVerbose, LogDebug from ost.mol.alg import chain_mapping from ost.mol.alg import lddt +import time class LigandScorer: @@ -286,7 +287,7 @@ class LigandScorer: binding_sites_topn=100000, global_chain_mapping=False, rmsd_assignment=False, n_max_naive=12, max_symmetries=1e5, custom_mapping=None, unassigned=False, full_bs_search=False, - add_mdl_contacts=False, + add_mdl_contacts=True, lddt_pli_thresholds = [0.5, 1.0, 2.0, 4.0]): if isinstance(model, mol.EntityView): @@ -739,18 +740,30 @@ class LigandScorer: model_ligand) + def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand, model_ligand): + t0 = time.time() - ########################################################## - # Get stuff from model/target from lazily computed cache # - ########################################################## + ############################### + # Get stuff from model/target # + ############################### trg_residues, trg_bs, trg_chains, trg_ligand_chain, \ trg_ligand_res, scorer, chem_groups = \ self._lddt_pli_get_trg_data(target_ligand) + # distance hacking... remove any interchain distance except the ones + # with the ligand + ref_indices = scorer.ref_indices_ic + ref_distances = scorer.ref_distances_ic + ligand_start_idx = scorer.chain_start_indices[-1] + for at_idx in range(ligand_start_idx): + mask = ref_indices[at_idx] >= ligand_start_idx + ref_indices[at_idx] = ref_indices[at_idx][mask] + ref_distances[at_idx] = ref_distances[at_idx][mask] + mdl_residues, mdl_bs, mdl_chains, mdl_editor, mdl_ligand_chain,\ mdl_ligand_res, chem_mapping = self._lddt_pli_get_mdl_data(model_ligand) @@ -772,14 +785,149 @@ class LigandScorer: # with ligand cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups, chem_mapping, - mdl_bs) - - ############################################################### - # compute lDDT for all possible chain mappings and symmetries # - ############################################################### - - best_score = -1.0 - best_result = None + mdl_bs, trg_bs) + + ######################## + # Setup model contacts # + ######################## + + # THATS A GLOBAL THING AND CAN BE LAZILY COMPUTED!!! + # store for each ref_ch,mdl_ch pair all mdl atoms that can be + # mapped. Don't store mappable atoms as hashes but rather as tuple + # (mdl_r.GetNumber(), mdl_a.GetName()). Reason for that is that + # the full model and the mdl_bs are different entity handles without + # corresponding hashes. + mappable_atoms = dict() + for (ref_cname, mdl_cname), aln in self.ref_mdl_alns.items(): + mappable_atoms[(ref_cname, mdl_cname)] = set() + ref_ch = self.chain_mapper.target.Select(f"cname={ref_cname}") + mdl_ch = self.chain_mapping_mdl.Select(f"cname={mdl_cname}") + aln.AttachView(0, ref_ch) + aln.AttachView(1, mdl_ch) + for col in aln: + ref_r = col.GetResidue(0) + mdl_r = col.GetResidue(1) + if ref_r.IsValid() and mdl_r.IsValid(): + for mdl_a in mdl_r.atoms: + if ref_r.FindAtom(mdl_a.name).IsValid(): + at_key = (mdl_r.GetNumber(), mdl_a.name) + mappable_atoms[(ref_cname, mdl_cname)].add(at_key) + + # get each chain mapping that we ever observe in scoring + chain_mappings = list(chain_mapping._ChainMappings(chem_groups, chem_mapping)) + + # for each mdl ligand atom, we collect all trg ligand atoms that are + # ever mapped onto it + ligand_atom_mappings = [set() for a in mdl_ligand_res.atoms] + for (trg_sym, mdl_sym) in symmetries: + for trg_i, mdl_i in zip(trg_sym, mdl_sym): + ligand_atom_mappings[mdl_i].add(trg_i) + + mdl_ligand_pos = np.zeros((mdl_ligand_res.GetAtomCount(), 3)) + for a_idx, a in enumerate(mdl_ligand_res.atoms): + p = a.GetPos() + mdl_ligand_pos[a_idx, 0] = p[0] + mdl_ligand_pos[a_idx, 1] = p[1] + mdl_ligand_pos[a_idx, 2] = p[2] + + trg_ligand_pos = np.zeros((trg_ligand_res.GetAtomCount(), 3)) + for a_idx, a in enumerate(trg_ligand_res.atoms): + p = a.GetPos() + trg_ligand_pos[a_idx, 0] = p[0] + trg_ligand_pos[a_idx, 1] = p[1] + trg_ligand_pos[a_idx, 2] = p[2] + + mdl_lig_hashes = [a.hash_code for a in mdl_ligand_res.atoms] + + symmetric_atoms = np.asarray(sorted(list(scorer.symmetric_atoms)), dtype=np.int64) + + scoring_cache = list() + penalty_cache = list() + + for mapping in chain_mappings: + + # flat mapping with mdl chain names as key + flat_mapping = dict() + for trg_chem_group, mdl_chem_group in zip(chem_groups, mapping): + for a,b in zip(trg_chem_group, mdl_chem_group): + if a is not None and b is not None: + flat_mapping[b] = a + + # for each mdl bs atom (as atom hash), the trg bs atoms (as index in scorer) + # some caching could help here => same mdl_ch/ref_ch combination could occur + # in several mappings... + bs_atom_mapping = dict() + yolo_mapping = dict() + for mdl_cname, ref_cname in flat_mapping.items(): + aln = cut_ref_mdl_alns[(ref_cname, mdl_cname)] + ref_ch = trg_bs.Select(f"cname={ref_cname}") + mdl_ch = mdl_bs.Select(f"cname={mdl_cname}") + aln.AttachView(0, ref_ch) + aln.AttachView(1, mdl_ch) + for col in aln: + ref_r = col.GetResidue(0) + mdl_r = col.GetResidue(1) + if ref_r.IsValid() and mdl_r.IsValid(): + for mdl_a in mdl_r.atoms: + ref_a = ref_r.FindAtom(mdl_a.GetName()) + if ref_a.IsValid(): + ref_h = ref_a.handle.hash_code + if ref_h in scorer.atom_indices: + mdl_h = mdl_a.handle.hash_code + bs_atom_mapping[mdl_h] = scorer.atom_indices[ref_h] + yolo_mapping[mdl_h] = ref_a + + cache = dict() + n_penalties = list() + + for mdl_a_idx, mdl_a in enumerate(mdl_ligand_res.atoms): + n_penalty = 0 + trg_bs_indices = list() + close_a = mdl_bs.FindWithin(mdl_a.GetPos(), self.lddt_pli_radius) + for a in close_a: + mdl_a_hash_code = a.hash_code + if mdl_a_hash_code in bs_atom_mapping: + trg_bs_indices.append(bs_atom_mapping[mdl_a_hash_code]) + elif mdl_a_hash_code not in mdl_lig_hashes: + at_key = (a.GetResidue().GetNumber(), a.name) + cname = a.GetChain().name + cname_key = (flat_mapping[cname], cname) + if at_key in mappable_atoms[cname_key]: + # Its a contact in the model but not part of trg_bs. + # It can still be mapped using the global + # mdl_ch/ref_ch alignment + # dist in ref > self.lddt_pli_radius + max_thresh + # => guaranteed to be non-fulfilled contact + n_penalty += 1 + + n_penalties.append(n_penalty) + + trg_bs_indices = np.asarray(sorted(trg_bs_indices)) + for trg_a_idx in ligand_atom_mappings[mdl_a_idx]: + mask = np.isin(trg_bs_indices, ref_indices[ligand_start_idx + trg_a_idx], + assume_unique=True, invert=True) + added_indices = np.asarray([], dtype=np.int64) + added_distances = np.asarray([], dtype=np.float64) + if np.sum(mask) > 0: + # compute ref distances on reference positions + added_indices = trg_bs_indices[mask] + tmp = scorer.positions.take(added_indices, axis=0) + np.subtract(tmp, trg_ligand_pos[trg_a_idx][None, :], out=tmp) + np.square(tmp, out=tmp) + tmp = tmp.sum(axis=1) + np.sqrt(tmp, out=tmp) # distances against all relevant atoms + added_distances = tmp + + # extract the distances towards bs atoms that are symmetric + sym_mask = np.isin(added_indices, symmetric_atoms, + assume_unique=True) + + cache[(mdl_a_idx, trg_a_idx)] = (added_indices, added_distances, + added_indices[sym_mask], + added_distances[sym_mask]) + + scoring_cache.append(cache) + penalty_cache.append(n_penalties) # cache for model contacts towards non mapped trg chains # key: tuple in form (mdl_ch, trg_ch) @@ -793,7 +941,27 @@ class LigandScorer: # value: list of mdl atom handles that are within self.lddt_pli_radius close_atom_cache = dict() - for mapping in chain_mapping._ChainMappings(chem_groups, chem_mapping): + ############################################################### + # compute lDDT for all possible chain mappings and symmetries # + ############################################################### + + best_score = -1.0 + best_result = None + + # dummy alignment for ligand chains which is needed as input later on + ligand_aln = seq.CreateAlignment() + trg_s = seq.CreateSequence(trg_ligand_chain.name, + trg_ligand_res.GetOneLetterCode()) + mdl_s = seq.CreateSequence(mdl_ligand_chain.name, + mdl_ligand_res.GetOneLetterCode()) + ligand_aln.AddSequence(trg_s) + ligand_aln.AddSequence(mdl_s) + ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms)) + + sym_idx_collector = [None] * scorer.n_atoms + sym_dist_collector = [None] * scorer.n_atoms + + for mapping, s_cache, p_cache in zip(chain_mappings, scoring_cache, penalty_cache): lddt_chain_mapping = dict() lddt_alns = dict() @@ -806,15 +974,16 @@ class LigandScorer: # add ligand to lddt_chain_mapping/lddt_alns lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name - ligand_aln = seq.CreateAlignment() - trg_s = seq.CreateSequence(trg_ligand_chain.name, - trg_ligand_res.GetOneLetterCode()) - mdl_s = seq.CreateSequence(mdl_ligand_chain.name, - mdl_ligand_res.GetOneLetterCode()) - ligand_aln.AddSequence(trg_s) - ligand_aln.AddSequence(mdl_s) lddt_alns[mdl_ligand_chain.name] = ligand_aln + # already process model, positions will be manually hacked for each + # symmetry - small overhead for variables that are thrown away here + pos, _, _, _, _, _, lddt_symmetries = \ + scorer._ProcessModel(mdl_bs, lddt_chain_mapping, + residue_mapping = lddt_alns, + thresholds = self.lddt_pli_thresholds, + check_resnames = self.check_resnames) + # estimate a penalty for unsatisfied model contacts from chains # that are not in the local trg binding site, but can be mapped in # the target. @@ -848,58 +1017,55 @@ class LigandScorer: if closest_ch is not None: unmapped_chains.append((mdl_ch, closest_ch)) - for i, (trg_sym, mdl_sym) in enumerate(symmetries): - # remove assert after proper testing - testing assumption made - # during development - assert(sorted(trg_sym)==list(range(len(trg_ligand_res.atoms)))) - for a in mdl_ligand_res.atoms: - mdl_editor.RenameAtom(a, "asdf") - for mdl_anum, trg_anum in zip(mdl_sym, trg_sym): - # Rename model atoms according to symmetry - trg_atom = trg_ligand_res.atoms[trg_anum] - mdl_atom = mdl_ligand_res.atoms[mdl_anum] - mdl_editor.RenameAtom(mdl_atom, trg_atom.name) - - pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \ - res_indices, ref_res_indices, lddt_symmetries = \ - scorer._ProcessModel(mdl_bs, lddt_chain_mapping, - residue_mapping = lddt_alns, - thresholds = self.lddt_pli_thresholds, - check_resnames = self.check_resnames) - ref_indices, ref_distances = \ - scorer._AddMdlContacts(mdl_bs, res_atom_indices, - res_atom_hashes, - scorer.ref_indices_ic, - scorer.ref_distances_ic, - False, True) - - # distance hacking... remove any interchain distance except the - # ones with the ligand - ligand_start_idx = scorer.chain_start_indices[-1] - for at_idx in range(ligand_start_idx): - mask = ref_indices[at_idx] >= ligand_start_idx - ref_indices[at_idx] = ref_indices[at_idx][mask] - ref_distances[at_idx] = ref_distances[at_idx][mask] - - # compute lddt symmetry related indices/distances - sym_ref_indices, sym_ref_distances = \ - lddt.lDDTScorer._NonSymDistances(scorer.n_atoms, - scorer.symmetric_atoms, - ref_indices, ref_distances) - + for (trg_sym, mdl_sym) in symmetries: + # update positions + + t0_sym = time.time() + + for mdl_i, trg_i in zip(mdl_sym, trg_sym): + pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :] + + # start new ref_indices/ref_distances from original values + funky_ref_indices = [np.copy(a) for a in ref_indices] + funky_ref_distances = [np.copy(a) for a in ref_distances] + + # The only distances from the binding site towards the ligand + # we care about are the ones from the symmetric atoms. + # We collect them while updating distances from added mdl + # contacts + for idx in symmetric_atoms: + sym_idx_collector[idx] = list() + sym_dist_collector[idx] = list() + + # and add the ones from added mdl contacts + added_penalty = 0 + for mdl_i, trg_i in zip(mdl_sym, trg_sym): + added_penalty += p_cache[mdl_i] + cache = s_cache[mdl_i, trg_i] + full_trg_i = ligand_start_idx + trg_i + funky_ref_indices[full_trg_i] = np.append(funky_ref_indices[full_trg_i], cache[0]) + funky_ref_distances[full_trg_i] = np.append(funky_ref_distances[full_trg_i], cache[1]) + for idx, d in zip(cache[2], cache[3]): + sym_idx_collector[idx].append(full_trg_i) + sym_dist_collector[idx].append(d) + + for idx in symmetric_atoms: + funky_ref_indices[idx] = np.append(funky_ref_indices[idx], + np.asarray(sym_idx_collector[idx], dtype=np.int64)) + funky_ref_distances[idx] = np.append(funky_ref_distances[idx], + np.asarray(sym_dist_collector[idx], dtype=np.float64)) + + # we can pass funky_ref_indices/funky_ref_distances as + # sym_ref_indices/sym_ref_distances in + # scorer._ResolveSymmetries as we only have distances of the bs + # to the ligand and ligand atoms are "non-symmetric" scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds, lddt_symmetries, - sym_ref_indices, - sym_ref_distances) + funky_ref_indices, + funky_ref_distances) - # only compute lDDT on ligand residue - n_exp = \ - sum([len(ref_indices[i]) for i in range(ligand_start_idx, - scorer.n_atoms)]) - conserved = np.sum(scorer._EvalAtoms(pos, res_atom_indices[-1], - self.lddt_pli_thresholds, - ref_indices,ref_distances), - axis=0) + n_exp = sum([len(funky_ref_indices[i]) for i in ligand_at_indices]) + n_exp += added_penalty # collect number of expected contacts which can be mapped if len(unmapped_chains) > 0: @@ -910,15 +1076,20 @@ class LigandScorer: mdl_bs, mdl_ligand_res, mdl_sym) - + + conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices, + self.lddt_pli_thresholds, + funky_ref_indices, + funky_ref_distances), axis=0) + print(conserved) + print(n_exp, added_penalty) score = np.mean(conserved/n_exp) if score > best_score: best_score = score - # do not yet add actual bs_ref_res_mapped and bs_mdl_res_mapped - # do this at the very end... best_result = {"lddt_pli": score} + print("sym time:", time.time() - t0_sym) # fill misc info to result object best_result["target_ligand"] = target_ligand @@ -927,21 +1098,44 @@ class LigandScorer: best_result["bs_mdl_res"] = mdl_residues best_result["inconsistent_residues"] = list() + print("full time", time.time() - t0) + return best_result def _compute_lddt_pli_classic(self, symmetries, target_ligand, model_ligand): + ############################### + # Get stuff from model/target # + ############################### - ########################################################## - # Get stuff from model/target from lazily computed cache # - ########################################################## + t0 = time.time() trg_residues, trg_bs, trg_chains, trg_ligand_chain, \ trg_ligand_res, scorer, chem_groups = \ self._lddt_pli_get_trg_data(target_ligand) + # distance hacking... remove any interchain distance except the ones + # with the ligand + ref_indices = scorer.ref_indices_ic + ref_distances = scorer.ref_distances_ic + ligand_start_idx = scorer.chain_start_indices[-1] + for at_idx in range(ligand_start_idx): + mask = ref_indices[at_idx] >= ligand_start_idx + ref_indices[at_idx] = ref_indices[at_idx][mask] + ref_distances[at_idx] = ref_distances[at_idx][mask] + + # no matter what mapping/symmetries, the number of expected + # contacts stays the same + ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms)) + n_exp = sum([len(ref_indices[i]) for i in ligand_at_indices]) + + # compute lddt symmetry related indices/distances + sym_ref_indices, sym_ref_distances = \ + lddt.lDDTScorer._NonSymDistances(scorer.n_atoms, scorer.symmetric_atoms, + ref_indices, ref_distances) + mdl_residues, mdl_bs, mdl_chains, mdl_editor, mdl_ligand_chain,\ mdl_ligand_res, chem_mapping = self._lddt_pli_get_mdl_data(model_ligand) @@ -961,8 +1155,9 @@ class LigandScorer: # ref_mdl_alns refers to full chain mapper trg and mdl structures # => need to adapt mdl sequence that only contain residues in contact # with ligand - cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups, chem_mapping, - mdl_bs) + cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups, + chem_mapping, + mdl_bs, trg_bs) ############################################################### # compute lDDT for all possible chain mappings and symmetries # @@ -971,6 +1166,22 @@ class LigandScorer: best_score = -1.0 best_result = None + # dummy alignment for ligand chains which is needed as input later on + ligand_aln = seq.CreateAlignment() + trg_s = seq.CreateSequence(trg_ligand_chain.name, + trg_ligand_res.GetOneLetterCode()) + mdl_s = seq.CreateSequence(mdl_ligand_chain.name, + mdl_ligand_res.GetOneLetterCode()) + ligand_aln.AddSequence(trg_s) + ligand_aln.AddSequence(mdl_s) + + mdl_ligand_pos = np.zeros((model_ligand.GetAtomCount(), 3)) + for a_idx, a in enumerate(model_ligand.atoms): + p = a.GetPos() + mdl_ligand_pos[a_idx, 0] = p[0] + mdl_ligand_pos[a_idx, 1] = p[1] + mdl_ligand_pos[a_idx, 2] = p[2] + for mapping in chain_mapping._ChainMappings(chem_groups, chem_mapping): lddt_chain_mapping = dict() @@ -984,62 +1195,34 @@ class LigandScorer: # add ligand to lddt_chain_mapping/lddt_alns lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name - ligand_aln = seq.CreateAlignment() - trg_s = seq.CreateSequence(trg_ligand_chain.name, - trg_ligand_res.GetOneLetterCode()) - mdl_s = seq.CreateSequence(mdl_ligand_chain.name, - mdl_ligand_res.GetOneLetterCode()) - ligand_aln.AddSequence(trg_s) - ligand_aln.AddSequence(mdl_s) lddt_alns[mdl_ligand_chain.name] = ligand_aln - ref_indices = scorer.ref_indices_ic - ref_distances = scorer.ref_distances_ic - - # distance hacking... remove any interchain distance except the ones - # with the ligand - ligand_start_idx = scorer.chain_start_indices[-1] - for at_idx in range(ligand_start_idx): - mask = ref_indices[at_idx] >= ligand_start_idx - ref_indices[at_idx] = ref_indices[at_idx][mask] - ref_distances[at_idx] = ref_distances[at_idx][mask] - - # compute lddt symmetry related indices/distances - sym_ref_indices, sym_ref_distances = \ - lddt.lDDTScorer._NonSymDistances(scorer.n_atoms, scorer.symmetric_atoms, - ref_indices, ref_distances) - - for i, (trg_sym, mdl_sym) in enumerate(symmetries): - # remove assert after proper testing - testing assumption made during development - assert(sorted(trg_sym) == list(range(len(trg_ligand_res.atoms)))) - for a in mdl_ligand_res.atoms: - mdl_editor.RenameAtom(a, "asdf") - for mdl_anum, trg_anum in zip(mdl_sym, trg_sym): - # Rename model atoms according to symmetry - trg_atom = trg_ligand_res.atoms[trg_anum] - mdl_atom = mdl_ligand_res.atoms[mdl_anum] - mdl_editor.RenameAtom(mdl_atom, trg_atom.name) - - pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \ - res_indices, ref_res_indices, lddt_symmetries = \ - scorer._ProcessModel(mdl_bs, lddt_chain_mapping, - residue_mapping = lddt_alns, - thresholds = self.lddt_pli_thresholds, - check_resnames = self.check_resnames) - - scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds, lddt_symmetries, - sym_ref_indices, sym_ref_distances) - - # only compute lDDT on ligand residue - n_exp = sum([len(ref_indices[i]) for i in range(ligand_start_idx, scorer.n_atoms)]) - conserved = np.sum(scorer._EvalAtoms(pos, res_atom_indices[-1], self.lddt_pli_thresholds, - ref_indices, ref_distances), axis=0) - + # already process model, positions will be manually hacked for each + # symmetry - small overhead of variables that are thrown away here + pos, _, _, _, _, _, lddt_symmetries = \ + scorer._ProcessModel(mdl_bs, lddt_chain_mapping, + residue_mapping = lddt_alns, + thresholds = self.lddt_pli_thresholds, + check_resnames = self.check_resnames) + + for (trg_sym, mdl_sym) in symmetries: + t0_sym = time.time() + for mdl_i, trg_i in zip(mdl_sym, trg_sym): + pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :] + scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds, + lddt_symmetries, + sym_ref_indices, + sym_ref_distances) + conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices, + self.lddt_pli_thresholds, + ref_indices, + ref_distances), axis=0) score = np.mean(conserved/n_exp) if score > best_score: best_score = score best_result = {"lddt_pli": score} + print("sym time:", time.time() - t0_sym) # fill misc info to result object best_result["target_ligand"] = target_ligand @@ -1048,6 +1231,8 @@ class LigandScorer: best_result["bs_mdl_res"] = mdl_residues best_result["inconsistent_residues"] = list() + print("full time:", time.time() - t0) + return best_result def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains, @@ -1175,23 +1360,23 @@ class LigandScorer: trg = self.chain_mapper.target + max_r = self.lddt_pli_radius + max(self.lddt_pli_thresholds) + trg_residues = set() for at in target_ligand.atoms: - close_atoms = trg.FindWithin(at.GetPos(), self.lddt_pli_radius) + close_atoms = trg.FindWithin(at.GetPos(), max_r) for close_at in close_atoms: trg_residues.add(close_at.GetResidue()) - max_r = self.lddt_pli_radius + max(self.lddt_pli_thresholds) + for r in trg.residues: + r.SetIntProp("bs", 0) - trg_chains = set() - for at in target_ligand.atoms: - close_atoms = trg.FindWithin(at.GetPos(), max_r) - for close_at in close_atoms: - trg_chains.add(close_at.GetChain().GetName()) + for r in trg_residues: + r.SetIntProp("bs", 1) + + trg_bs = mol.CreateEntityFromView(trg.Select("grbs:0=1"), True) + trg_chains = set([ch.name for ch in trg_bs.chains]) - query = "cname=" - query += ','.join([mol.QueryQuoteName(x) for x in trg_chains]) - trg_bs = mol.CreateEntityFromView(trg.Select(query), True) trg_editor = trg_bs.EditXCS(mol.BUFFERED_EDIT) trg_ligand_chain = None for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]: @@ -1229,29 +1414,52 @@ class LigandScorer: return self._lddt_pli_target_data[target_ligand] - def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs): + + def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs, + ref_bs): cut_ref_mdl_alns = dict() for ref_chem_group, mdl_chem_group in zip(chem_groups, chem_mapping): for ref_ch in ref_chem_group: + + ref_bs_chain = ref_bs.FindChain(ref_ch) + query = "cname=" + mol.QueryQuoteName(ref_ch) + ref_view = self.chain_mapper.target.Select(query) + for mdl_ch in mdl_chem_group: aln = self.ref_mdl_alns[(ref_ch, mdl_ch)] + + aln.AttachView(0, ref_view) + mdl_bs_chain = mdl_bs.FindChain(mdl_ch) query = "cname=" + mol.QueryQuoteName(mdl_ch) aln.AttachView(1, self.chain_mapping_mdl.Select(query)) + cut_mdl_seq = ['-'] * aln.GetLength() + cut_ref_seq = ['-'] * aln.GetLength() for i, col in enumerate(aln): + + # check ref residue + r = col.GetResidue(0) + if r.IsValid(): + bs_r = ref_bs_chain.FindResidue(r.GetNumber()) + if bs_r.IsValid(): + cut_ref_seq[i] = col[1] + + # check mdl residue r = col.GetResidue(1) if r.IsValid(): bs_r = mdl_bs_chain.FindResidue(r.GetNumber()) if bs_r.IsValid(): cut_mdl_seq[i] = col[1] + + cut_ref_seq = ''.join(cut_ref_seq) + cut_mdl_seq = ''.join(cut_mdl_seq) cut_aln = seq.CreateAlignment() - cut_aln.AddSequence(aln.GetSequence(0)) - cut_aln.AddSequence(seq.CreateSequence(mdl_ch, ''.join(cut_mdl_seq))) + cut_aln.AddSequence(seq.CreateSequence(ref_ch, cut_ref_seq)) + cut_aln.AddSequence(seq.CreateSequence(mdl_ch, cut_mdl_seq)) cut_ref_mdl_alns[(ref_ch, mdl_ch)] = cut_aln return cut_ref_mdl_alns - @staticmethod def _find_ligand_assignment(mat1, mat2=None, coverage=None, coverage_delta=None): """ Find the ligand assignment based on mat1. If mat2 is provided, it -- GitLab