From 21f231df8d385a69b0e60fd22056d1e918e908b1 Mon Sep 17 00:00:00 2001 From: Gabriel Studer <gabriel.studer@unibas.ch> Date: Tue, 21 May 2024 18:08:28 +0200 Subject: [PATCH] lddt-pli: cleanup --- modules/mol/alg/pymod/ligand_scoring.py | 267 ++++++++++++------------ 1 file changed, 135 insertions(+), 132 deletions(-) diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py index cd014db30..f46cb2c63 100644 --- a/modules/mol/alg/pymod/ligand_scoring.py +++ b/modules/mol/alg/pymod/ligand_scoring.py @@ -383,6 +383,7 @@ class LigandScorer: # lazily precomputed variables to speedup lddt-pli computation self._lddt_pli_target_data = dict() self._lddt_pli_model_data = dict() + self._mappable_atoms = None # cache for rmsd values # rmsd is used as tie breaker in lddt-pli, we therefore need access to @@ -423,6 +424,7 @@ class LigandScorer: resnum_alignments=self.resnum_alignments) return self._chain_mapper + def get_target_binding_site(self, target_ligand): if target_ligand.handle.hash_code not in self._binding_sites: @@ -754,18 +756,23 @@ class LigandScorer: trg_ligand_res, scorer, chem_groups = \ self._lddt_pli_get_trg_data(target_ligand) + # Copy to make sure that we don't change anything on underlying + # references + # This is not strictly necessary in the current implementation but + # hey, maybe it avoids hard to debug errors when someone changes things + ref_indices = [a.copy() for a in scorer.ref_indices_ic] + ref_distances = [a.copy() for a in scorer.ref_distances_ic] + # distance hacking... remove any interchain distance except the ones # with the ligand - ref_indices = scorer.ref_indices_ic - ref_distances = scorer.ref_distances_ic ligand_start_idx = scorer.chain_start_indices[-1] for at_idx in range(ligand_start_idx): mask = ref_indices[at_idx] >= ligand_start_idx ref_indices[at_idx] = ref_indices[at_idx][mask] ref_distances[at_idx] = ref_distances[at_idx][mask] - mdl_residues, mdl_bs, mdl_chains, mdl_editor, mdl_ligand_chain,\ - mdl_ligand_res, chem_mapping = self._lddt_pli_get_mdl_data(model_ligand) + mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res, \ + chem_mapping = self._lddt_pli_get_mdl_data(model_ligand) if len(mdl_chains) == 0 or len(trg_chains) == 0: # It's a spaceship! @@ -787,37 +794,16 @@ class LigandScorer: chem_mapping, mdl_bs, trg_bs) - ######################## - # Setup model contacts # - ######################## - - # THATS A GLOBAL THING AND CAN BE LAZILY COMPUTED!!! - # store for each ref_ch,mdl_ch pair all mdl atoms that can be - # mapped. Don't store mappable atoms as hashes but rather as tuple - # (mdl_r.GetNumber(), mdl_a.GetName()). Reason for that is that - # the full model and the mdl_bs are different entity handles without - # corresponding hashes. - mappable_atoms = dict() - for (ref_cname, mdl_cname), aln in self.ref_mdl_alns.items(): - mappable_atoms[(ref_cname, mdl_cname)] = set() - ref_ch = self.chain_mapper.target.Select(f"cname={ref_cname}") - mdl_ch = self.chain_mapping_mdl.Select(f"cname={mdl_cname}") - aln.AttachView(0, ref_ch) - aln.AttachView(1, mdl_ch) - for col in aln: - ref_r = col.GetResidue(0) - mdl_r = col.GetResidue(1) - if ref_r.IsValid() and mdl_r.IsValid(): - for mdl_a in mdl_r.atoms: - if ref_r.FindAtom(mdl_a.name).IsValid(): - at_key = (mdl_r.GetNumber(), mdl_a.name) - mappable_atoms[(ref_cname, mdl_cname)].add(at_key) + ######################################## + # Setup cache for added model contacts # + ######################################## # get each chain mapping that we ever observe in scoring - chain_mappings = list(chain_mapping._ChainMappings(chem_groups, chem_mapping)) + chain_mappings = list(chain_mapping._ChainMappings(chem_groups, + chem_mapping)) # for each mdl ligand atom, we collect all trg ligand atoms that are - # ever mapped onto it + # ever mapped onto it given *symmetries* ligand_atom_mappings = [set() for a in mdl_ligand_res.atoms] for (trg_sym, mdl_sym) in symmetries: for trg_i, mdl_i in zip(trg_sym, mdl_sym): @@ -839,8 +825,15 @@ class LigandScorer: mdl_lig_hashes = [a.hash_code for a in mdl_ligand_res.atoms] - symmetric_atoms = np.asarray(sorted(list(scorer.symmetric_atoms)), dtype=np.int64) + symmetric_atoms = np.asarray(sorted(list(scorer.symmetric_atoms)), + dtype=np.int64) + # two caches to cache things for each chain mapping => lists + # of len(chain_mappings) + # + # In principle we're caching for each trg/mdl ligand atom pair all + # information to update ref_indices/ref_distances and resolving the + # symmetries of the binding site. scoring_cache = list() penalty_cache = list() @@ -854,10 +847,7 @@ class LigandScorer: flat_mapping[b] = a # for each mdl bs atom (as atom hash), the trg bs atoms (as index in scorer) - # some caching could help here => same mdl_ch/ref_ch combination could occur - # in several mappings... bs_atom_mapping = dict() - yolo_mapping = dict() for mdl_cname, ref_cname in flat_mapping.items(): aln = cut_ref_mdl_alns[(ref_cname, mdl_cname)] ref_ch = trg_bs.Select(f"cname={ref_cname}") @@ -874,8 +864,8 @@ class LigandScorer: ref_h = ref_a.handle.hash_code if ref_h in scorer.atom_indices: mdl_h = mdl_a.handle.hash_code - bs_atom_mapping[mdl_h] = scorer.atom_indices[ref_h] - yolo_mapping[mdl_h] = ref_a + bs_atom_mapping[mdl_h] = \ + scorer.atom_indices[ref_h] cache = dict() n_penalties = list() @@ -883,7 +873,8 @@ class LigandScorer: for mdl_a_idx, mdl_a in enumerate(mdl_ligand_res.atoms): n_penalty = 0 trg_bs_indices = list() - close_a = mdl_bs.FindWithin(mdl_a.GetPos(), self.lddt_pli_radius) + close_a = mdl_bs.FindWithin(mdl_a.GetPos(), + self.lddt_pli_radius) for a in close_a: mdl_a_hash_code = a.hash_code if mdl_a_hash_code in bs_atom_mapping: @@ -892,7 +883,7 @@ class LigandScorer: at_key = (a.GetResidue().GetNumber(), a.name) cname = a.GetChain().name cname_key = (flat_mapping[cname], cname) - if at_key in mappable_atoms[cname_key]: + if at_key in self.mappable_atoms[cname_key]: # Its a contact in the model but not part of trg_bs. # It can still be mapped using the global # mdl_ch/ref_ch alignment @@ -903,7 +894,11 @@ class LigandScorer: n_penalties.append(n_penalty) trg_bs_indices = np.asarray(sorted(trg_bs_indices)) + for trg_a_idx in ligand_atom_mappings[mdl_a_idx]: + # mask selects entries in trg_bs_indices that are not yet + # part of classic lDDT ref_indices for atom at trg_a_idx + # => added mdl contacts mask = np.isin(trg_bs_indices, ref_indices[ligand_start_idx + trg_a_idx], assume_unique=True, invert=True) added_indices = np.asarray([], dtype=np.int64) @@ -930,17 +925,12 @@ class LigandScorer: penalty_cache.append(n_penalties) # cache for model contacts towards non mapped trg chains - # key: tuple in form (mdl_ch, trg_ch) + # key: tuple in form (trg_ch, mdl_ch) # value: yet another dict with # key: ligand_atom_hash # value: n contacts towards respective trg chain that can be mapped non_mapped_cache = dict() - # cache as helper to compute non_mapped_cache - # key: ligand_atom_hash - # value: list of mdl atom handles that are within self.lddt_pli_radius - close_atom_cache = dict() - ############################################################### # compute lDDT for all possible chain mappings and symmetries # ############################################################### @@ -1004,7 +994,8 @@ class LigandScorer: if chem_group_idx is None: raise RuntimeError("This should never happen... " "ask Gabriel...") - mdl_center = mdl.FindChain(mdl_ch).geometric_center + mdl_ch = self.chain_mapping_mdl.FindChain(mdl_ch) + mdl_center = mdl_ch.geometric_center closest_ch = None closest_dist = None for trg_ch in self.chem_groups[chem_group_idx]: @@ -1015,13 +1006,12 @@ class LigandScorer: closest_dist = d closest_ch = trg_ch if closest_ch is not None: - unmapped_chains.append((mdl_ch, closest_ch)) + unmapped_chains.append((closest_ch, mdl_ch)) for (trg_sym, mdl_sym) in symmetries: - # update positions - t0_sym = time.time() + # update positions for mdl_i, trg_i in zip(mdl_sym, trg_sym): pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :] @@ -1030,30 +1020,37 @@ class LigandScorer: funky_ref_distances = [np.copy(a) for a in ref_distances] # The only distances from the binding site towards the ligand - # we care about are the ones from the symmetric atoms. + # we care about are the ones from the symmetric atoms to + # correctly compute scorer._ResolveSymmetries. # We collect them while updating distances from added mdl # contacts for idx in symmetric_atoms: sym_idx_collector[idx] = list() sym_dist_collector[idx] = list() - # and add the ones from added mdl contacts + # add data from added mdl contacts cache added_penalty = 0 for mdl_i, trg_i in zip(mdl_sym, trg_sym): added_penalty += p_cache[mdl_i] cache = s_cache[mdl_i, trg_i] full_trg_i = ligand_start_idx + trg_i - funky_ref_indices[full_trg_i] = np.append(funky_ref_indices[full_trg_i], cache[0]) - funky_ref_distances[full_trg_i] = np.append(funky_ref_distances[full_trg_i], cache[1]) + funky_ref_indices[full_trg_i] = \ + np.append(funky_ref_indices[full_trg_i], cache[0]) + funky_ref_distances[full_trg_i] = \ + np.append(funky_ref_distances[full_trg_i], cache[1]) for idx, d in zip(cache[2], cache[3]): sym_idx_collector[idx].append(full_trg_i) sym_dist_collector[idx].append(d) for idx in symmetric_atoms: - funky_ref_indices[idx] = np.append(funky_ref_indices[idx], - np.asarray(sym_idx_collector[idx], dtype=np.int64)) - funky_ref_distances[idx] = np.append(funky_ref_distances[idx], - np.asarray(sym_dist_collector[idx], dtype=np.float64)) + funky_ref_indices[idx] = \ + np.append(funky_ref_indices[idx], + np.asarray(sym_idx_collector[idx], + dtype=np.int64)) + funky_ref_distances[idx] = \ + np.append(funky_ref_distances[idx], + np.asarray(sym_dist_collector[idx], + dtype=np.float64)) # we can pass funky_ref_indices/funky_ref_distances as # sym_ref_indices/sym_ref_distances in @@ -1064,26 +1061,26 @@ class LigandScorer: funky_ref_indices, funky_ref_distances) - n_exp = sum([len(funky_ref_indices[i]) for i in ligand_at_indices]) - n_exp += added_penalty + N = sum([len(funky_ref_indices[i]) for i in ligand_at_indices]) + N += added_penalty # collect number of expected contacts which can be mapped if len(unmapped_chains) > 0: - n_exp += \ - self._lddt_pli_unmapped_chain_penalty(unmapped_chains, - non_mapped_cache, - close_atom_cache, - mdl_bs, - mdl_ligand_res, - mdl_sym) + N += self._lddt_pli_unmapped_chain_penalty(unmapped_chains, + non_mapped_cache, + mdl_bs, + mdl_ligand_res, + mdl_sym) conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices, self.lddt_pli_thresholds, funky_ref_indices, funky_ref_distances), axis=0) print(conserved) - print(n_exp, added_penalty) - score = np.mean(conserved/n_exp) + print(N, added_penalty) + score = 0.0 + if N > 0: + score = np.mean(conserved/N) if score > best_score: best_score = score @@ -1116,10 +1113,15 @@ class LigandScorer: trg_ligand_res, scorer, chem_groups = \ self._lddt_pli_get_trg_data(target_ligand) - # distance hacking... remove any interchain distance except the ones + # Copy to make sure that we don't change anything on underlying + # references + # This is not strictly necessary in the current implementation but + # hey, maybe it avoids hard to debug errors when someone changes things + ref_indices = [a.copy() for a in scorer.ref_indices_ic] + ref_distances = [a.copy() for a in scorer.ref_distances_ic] + + # Distance hacking... remove any interchain distance except the ones # with the ligand - ref_indices = scorer.ref_indices_ic - ref_distances = scorer.ref_distances_ic ligand_start_idx = scorer.chain_start_indices[-1] for at_idx in range(ligand_start_idx): mask = ref_indices[at_idx] >= ligand_start_idx @@ -1131,13 +1133,8 @@ class LigandScorer: ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms)) n_exp = sum([len(ref_indices[i]) for i in ligand_at_indices]) - # compute lddt symmetry related indices/distances - sym_ref_indices, sym_ref_distances = \ - lddt.lDDTScorer._NonSymDistances(scorer.n_atoms, scorer.symmetric_atoms, - ref_indices, ref_distances) - - mdl_residues, mdl_bs, mdl_chains, mdl_editor, mdl_ligand_chain,\ - mdl_ligand_res, chem_mapping = self._lddt_pli_get_mdl_data(model_ligand) + mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res, \ + chem_mapping = self._lddt_pli_get_mdl_data(model_ligand) if len(mdl_chains) == 0 or len(trg_chains) == 0: # It's a spaceship! @@ -1167,13 +1164,11 @@ class LigandScorer: best_result = None # dummy alignment for ligand chains which is needed as input later on - ligand_aln = seq.CreateAlignment() - trg_s = seq.CreateSequence(trg_ligand_chain.name, - trg_ligand_res.GetOneLetterCode()) - mdl_s = seq.CreateSequence(mdl_ligand_chain.name, - mdl_ligand_res.GetOneLetterCode()) - ligand_aln.AddSequence(trg_s) - ligand_aln.AddSequence(mdl_s) + l_aln = seq.CreateAlignment() + l_aln.AddSequence(seq.CreateSequence(trg_ligand_chain.name, + trg_ligand_res.GetOneLetterCode())) + l_aln.AddSequence(seq.CreateSequence(mdl_ligand_chain.name, + mdl_ligand_res.GetOneLetterCode())) mdl_ligand_pos = np.zeros((model_ligand.GetAtomCount(), 3)) for a_idx, a in enumerate(model_ligand.atoms): @@ -1195,10 +1190,10 @@ class LigandScorer: # add ligand to lddt_chain_mapping/lddt_alns lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name - lddt_alns[mdl_ligand_chain.name] = ligand_aln + lddt_alns[mdl_ligand_chain.name] = l_aln # already process model, positions will be manually hacked for each - # symmetry - small overhead of variables that are thrown away here + # symmetry - small overhead for variables that are thrown away here pos, _, _, _, _, _, lddt_symmetries = \ scorer._ProcessModel(mdl_bs, lddt_chain_mapping, residue_mapping = lddt_alns, @@ -1209,10 +1204,15 @@ class LigandScorer: t0_sym = time.time() for mdl_i, trg_i in zip(mdl_sym, trg_sym): pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :] + # we can pass ref_indices/ref_distances as + # sym_ref_indices/sym_ref_distances in + # scorer._ResolveSymmetries as we only have distances of the bs + # to the ligand and ligand atoms are "non-symmetric" scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds, lddt_symmetries, - sym_ref_indices, - sym_ref_distances) + ref_indices, + ref_distances) + # compute number of conserved distances for ligand atoms conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices, self.lddt_pli_thresholds, ref_indices, @@ -1237,7 +1237,6 @@ class LigandScorer: def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains, non_mapped_cache, - close_atom_cache, mdl_bs, mdl_ligand_res, mdl_sym): @@ -1245,50 +1244,23 @@ class LigandScorer: n_exp = 0 for ch_tuple in unmapped_chains: if ch_tuple not in non_mapped_cache: - - # identify each atom in given mdl chain from mdl_bs which can be - # mapped to a trg atom in given trg chain - mappable_atoms = set() - aln = self.ref_mdl_alns[(ch_tuple[1], ch_tuple[0])] - mdl_bs_chain = mdl_bs.FindChain(ch_tuple[0]) - trg_query = "cname=" + mol.QueryQuoteName(ch_tuple[1]) - trg_view = self.chain_mapper.target.Select(trg_query) - aln.AttachView(0, trg_view) - mdl_query = "cname=" + mol.QueryQuoteName(ch_tuple[0]) - mdl_view = self.chain_mapping_mdl.Select(mdl_query) - aln.AttachView(1, mdl_view) - for i, col in enumerate(aln): - r = col.GetResidue(1) - if r.IsValid(): - bs_r = mdl_bs_chain.FindResidue(r.GetNumber()) - if bs_r.IsValid(): - trg_r = col.GetResidue(0) - if trg_r.IsValid(): - for a in bs_r.atoms: - trg_a = trg_r.FindAtom(bs_a.GetName()) - if trg_a.IsValid(): - mappable_atoms.add(a.handle.hash_code) - # for each ligand atom, we count the number of mappable atoms # within lddt_pli_radius counts = dict() - for lig_a in mdl_ligand_res.atoms: - close_atoms = None - if lig_a.hash_code not in close_atom_cache: - tmp = mdl_bs.FindWithin(lig_a.GetPos(), - self.lddt_pli_radius) - h = mdl_ligand_res.hash_code - tmp = [x for x in tmp if x.GetResidue().hash_code != h] - close_atom_cache[lig_a.hash_code] = tmp - else: - close_atoms = close_atom_cache[lig_a.hash_code] - + # the select statement also excludes the ligand in mdl_bs + # as it resides in a separate chain + mdl_cname = ch_tuple[1] + mdl_bs_ch = mdl_bs.Select(f"cname={mdl_cname}") + for a in mdl_ligand_res.atoms: + close_atoms = \ + mdl_bs_ch.FindWithin(a.GetPos(), self.lddt_pli_radius) N = 0 for close_a in close_atoms: - if close_a.handle.hash_code in mappable_atoms: + at_key = (close_a.GetResidue().GetNumber(), + close_a.GetName()) + if at_key in self.mappable_atoms[ch_tuple]: N += 1 - - counts[lig_a.hash_code] = N + counts[a.hash_code] = N # fill cache non_mapped_cache[ch_tuple] = counts @@ -1296,9 +1268,9 @@ class LigandScorer: # add number of mdl contacts which can be mapped to target # as non-fulfilled contacts counts = non_mapped_cache[ch_tuple] - mdl_ligand_res_atoms = mdl_ligand_res.atoms + lig_hash_codes = [a.hash_code for a in mdl_ligand_res.atoms] for i in mdl_sym: - n_exp += counts[mdl_ligand_res_atoms[i].hash_code] + n_exp += counts[lig_hash_codes[i]] return n_exp @@ -1347,7 +1319,6 @@ class LigandScorer: self._lddt_pli_model_data[model_ligand] = (mdl_residues, mdl_bs, mdl_chains, - mdl_editor, mdl_ligand_chain, mdl_ligand_res, chem_mapping) @@ -1443,7 +1414,7 @@ class LigandScorer: if r.IsValid(): bs_r = ref_bs_chain.FindResidue(r.GetNumber()) if bs_r.IsValid(): - cut_ref_seq[i] = col[1] + cut_ref_seq[i] = col[0] # check mdl residue r = col.GetResidue(1) @@ -2141,6 +2112,38 @@ class LigandScorer: _ = self.unassigned_model_ligands # assigned there return self._unassigned_model_ligand_descriptions + @property + def mappable_atoms(self): + """ Stores mappable atoms given a chain mapping + + Store for each ref_ch,mdl_ch pair all mdl atoms that can be + mapped. Don't store mappable atoms as hashes but rather as tuple + (mdl_r.GetNumber(), mdl_a.GetName()). Reason for that is that one might + operate on Copied EntityHandle objects without corresponding hashes. + Given a tuple defining c_pair: (ref_cname, mdl_cname), one + can check if a certain atom is mappable by evaluating: + if (mdl_r.GetNumber(), mdl_a.GetName()) in self.mappable_atoms(c_pair) + """ + if self._mappable_atoms is None: + self._mappable_atoms = dict() + for (ref_cname, mdl_cname), aln in self.ref_mdl_alns.items(): + self._mappable_atoms[(ref_cname, mdl_cname)] = set() + ref_ch = self.chain_mapper.target.Select(f"cname={ref_cname}") + mdl_ch = self.chain_mapping_mdl.Select(f"cname={mdl_cname}") + aln.AttachView(0, ref_ch) + aln.AttachView(1, mdl_ch) + for col in aln: + ref_r = col.GetResidue(0) + mdl_r = col.GetResidue(1) + if ref_r.IsValid() and mdl_r.IsValid(): + for mdl_a in mdl_r.atoms: + if ref_r.FindAtom(mdl_a.name).IsValid(): + c_key = (ref_cname, mdl_cname) + at_key = (mdl_r.GetNumber(), mdl_a.name) + self._mappable_atoms[c_key].add(at_key) + + return self._mappable_atoms + def _set_custom_mapping(self, mapping): """ sets self.__model_mapping with a full blown MappingResult object -- GitLab