diff --git a/modules/mol/alg/pymod/CMakeLists.txt b/modules/mol/alg/pymod/CMakeLists.txt index 2cae612018c804361e2e9af94da9d45c83e46379..1d0b60f899728d0835343acfabf636dbdd1c3220 100644 --- a/modules/mol/alg/pymod/CMakeLists.txt +++ b/modules/mol/alg/pymod/CMakeLists.txt @@ -34,6 +34,7 @@ set(OST_MOL_ALG_PYMOD_MODULES contact_score.py ligand_scoring_base.py ligand_scoring_scrmsd.py + ligand_scoring_lddtpli.py ) if (NOT ENABLE_STATIC) diff --git a/modules/mol/alg/pymod/ligand_scoring_base.py b/modules/mol/alg/pymod/ligand_scoring_base.py index 4d1f5af050d13f227bb98b8210b858e63336a8ed..396c748e1955d61e1758816794b82376ca8aa0a6 100644 --- a/modules/mol/alg/pymod/ligand_scoring_base.py +++ b/modules/mol/alg/pymod/ligand_scoring_base.py @@ -57,10 +57,10 @@ class LigandScorer: # lazily computed attributes self._chain_mapper = None - # keep track of error states + # keep track of states # simple integers instead of enums - documentation of property describes # encoding - self._error_states = None + self._states = None # score matrices self._score_matrix = None @@ -68,16 +68,15 @@ class LigandScorer: self._aux_data = None @property - def error_states(self): - """ Encodes error states of ligand pairs + def states(self): + """ Encodes states of ligand pairs - Not only critical things, but also things like: a pair of ligands - simply doesn't match. Target ligands are in rows, model ligands in - columns. States are encoded as integers <= 9. Larger numbers encode - errors for child classes. + Expect a valid score if respective location in this matrix is 0. + Target ligands are in rows, model ligands in columns. States are encoded + as integers <= 9. Larger numbers encode errors for child classes. * -1: Unknown Error - cannot be matched - * 0: Ligand pair has valid symmetries - can be matched. + * 0: Ligand pair can be matched and valid score is computed. * 1: Ligand pair has no valid symmetry - cannot be matched. * 2: Ligand pair has too many symmetries - cannot be matched. 
You might be able to get a match by increasing *max_symmetries*. @@ -90,9 +89,9 @@ class LigandScorer: :rtype: :class:`~numpy.ndarray` """ - if self._error_states is None: + if self._states is None: self._compute_scores() - return self._error_states + return self._states @property def score_matrix(self): @@ -102,7 +101,7 @@ class LigandScorer: NaN values indicate that no value could be computed (i.e. different ligands). In other words: values are only valid if respective location - :attr:`~error_states` is 0. + :attr:`~states` is 0. :rtype: :class:`~numpy.ndarray` """ @@ -118,7 +117,7 @@ class LigandScorer: NaN values indicate that no value could be computed (i.e. different ligands). In other words: values are only valid if respective location - :attr:`~error_states` is 0. If `substructure_match=False`, only full + :attr:`~states` is 0. If `substructure_match=False`, only full match isomorphisms are considered, and therefore only values of 1.0 can be observed. @@ -138,7 +137,7 @@ class LigandScorer: class to provide additional information for a scored ligand pair. empty dictionaries indicate that no value could be computed (i.e. different ligands). In other words: values are only valid if - respective location :attr:`~error_states` is 0. + respective location :attr:`~states` is 0. 
:rtype: :class:`~numpy.ndarray` """ @@ -301,7 +300,7 @@ class LigandScorer: shape = (len(self.target_ligands), len(self.model_ligands)) self._score_matrix = np.full(shape, np.nan, dtype=np.float32) self._coverage_matrix = np.full(shape, np.nan, dtype=np.float32) - self._error_states = np.full(shape, -1, dtype=np.int32) + self._states = np.full(shape, -1, dtype=np.int32) self._aux_data = np.empty(shape, dtype=dict) for target_id, target_ligand in enumerate(self.target_ligands): @@ -324,42 +323,45 @@ class LigandScorer: # Ligands are different - skip LogVerbose("No symmetry between %s and %s" % ( str(model_ligand), str(target_ligand))) - self._error_states[target_id, model_id] = 1 + self._states[target_id, model_id] = 1 continue except TooManySymmetriesError: # Ligands are too symmetrical - skip LogVerbose("Too many symmetries between %s and %s" % ( str(model_ligand), str(target_ligand))) - self._error_states[target_id, model_id] = 2 + self._states[target_id, model_id] = 2 continue except NoIsomorphicSymmetryError: # Ligands are different - skip LogVerbose("No isomorphic symmetry between %s and %s" % ( str(model_ligand), str(target_ligand))) - self._error_states[target_id, model_id] = 3 + self._states[target_id, model_id] = 3 continue except DisconnectedGraphError: LogVerbose("Disconnected graph observed for %s and %s" % ( str(model_ligand), str(target_ligand))) - self._error_states[target_id, model_id] = 4 + self._states[target_id, model_id] = 4 continue ##################################################### # Compute score by calling the child class _compute # ##################################################### - score, error_state, aux = self._compute(symmetries, target_ligand, - model_ligand) + score, state, aux = self._compute(symmetries, target_ligand, + model_ligand) ############ # Finalize # ############ - if error_state != 0: + if state != 0: # non-zero error states up to 4 are reserved for base class - if error_state <= 9: + if state <= 9: raise 
RuntimeError("Child returned reserved err. state") - self._error_states[target_id, model_id] = error_state - if error_state == 0: + self._states[target_id, model_id] = state + if state == 0: + if score is None or np.isnan(score): + raise RuntimeError("LigandScorer returned invalid " + "score despite 0 error state") # it's a valid score! self._score_matrix[target_id, model_id] = score cvg = len(symmetries[0][0]) / len(model_ligand.atoms) @@ -386,12 +388,13 @@ class LigandScorer: :class:`ost.mol.ResidueView` :returns: A :class:`tuple` with three elements: 1) a score - (:class:`float`) 2) error state (:class:`int`). + (:class:`float`) 2) state (:class:`int`). 3) auxiliary data for this ligand pair (:class:`dict`). - If error state is 0, the score and auxiliary data will be + If state is 0, the score and auxiliary data will be added to :attr:`~score_matrix` and :attr:`~aux_data` as well as the respective value in :attr:`~coverage_matrix`. - Child specific non-zero states MUST be >= 10. + Returned score must be valid in this case (not None/NaN). + Child specific non-zero states must be >= 10. 
""" raise NotImplementedError("_compute must be implemented by child class") diff --git a/modules/mol/alg/pymod/ligand_scoring_lddtpli.py b/modules/mol/alg/pymod/ligand_scoring_lddtpli.py new file mode 100644 index 0000000000000000000000000000000000000000..6f4d88898b2cba09849131d62b6d41e4b124a310 --- /dev/null +++ b/modules/mol/alg/pymod/ligand_scoring_lddtpli.py @@ -0,0 +1,834 @@ +import numpy as np + +from ost import LogWarning +from ost import geom +from ost import mol +from ost import seq + +from ost.mol.alg import lddt +from ost.mol.alg import chain_mapping +from ost.mol.alg import ligand_scoring_base + +class LDDTPLIScorer(ligand_scoring_base.LigandScorer): + + def __init__(self, model, target, model_ligands=None, target_ligands=None, + resnum_alignments=False, rename_ligand_chain=False, + substructure_match=False, coverage_delta=0.2, + max_symmetries=1e5, check_resnames=True, lddt_pli_radius=6.0, + add_mdl_contacts=True, + lddt_pli_thresholds = [0.5, 1.0, 2.0, 4.0], + lddt_pli_binding_site_radius=None): + + super().__init__(model, target, model_ligands = model_ligands, + target_ligands = target_ligands, + resnum_alignments = resnum_alignments, + rename_ligand_chain = rename_ligand_chain, + substructure_match = substructure_match, + coverage_delta = coverage_delta, + max_symmetries = max_symmetries) + + self.check_resnames = check_resnames + self.lddt_pli_radius = lddt_pli_radius + self.add_mdl_contacts = add_mdl_contacts + self.lddt_pli_thresholds = lddt_pli_thresholds + self.lddt_pli_binding_site_radius = lddt_pli_binding_site_radius + + # lazily precomputed variables to speedup lddt-pli computation + self._lddt_pli_target_data = dict() + self._lddt_pli_model_data = dict() + self.__mappable_atoms = None + self.__chem_mapping = None + self.__chem_group_alns = None + self.__ref_mdl_alns = None + self.__chain_mapping_mdl = None + + + def _compute(self, symmetries, target_ligand, model_ligand): + + + if self.add_mdl_contacts: + result = 
self._compute_lddt_pli_add_mdl_contacts(symmetries, + target_ligand, + model_ligand) + else: + result = self._compute_lddt_pli_classic(symmetries, + target_ligand, + model_ligand) + + state = 0 + score = result["lddt_pli"] + + if score is None: + if result["lddt_pli_n_contacts"] == 0: + # it's a space ship! + state = 10 + else: + # unknown error state + state = 11 + + return (score, state, result) + + def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand, + model_ligand): + + ############################### + # Get stuff from model/target # + ############################### + + trg_residues, trg_bs, trg_chains, trg_ligand_chain, \ + trg_ligand_res, scorer, chem_groups = \ + self._lddt_pli_get_trg_data(target_ligand) + + # Copy to make sure that we don't change anything on underlying + # references + # This is not strictly necessary in the current implementation but + # hey, maybe it avoids hard to debug errors when someone changes things + ref_indices = [a.copy() for a in scorer.ref_indices_ic] + ref_distances = [a.copy() for a in scorer.ref_distances_ic] + + # distance hacking... remove any interchain distance except the ones + # with the ligand + ligand_start_idx = scorer.chain_start_indices[-1] + for at_idx in range(ligand_start_idx): + mask = ref_indices[at_idx] >= ligand_start_idx + ref_indices[at_idx] = ref_indices[at_idx][mask] + ref_distances[at_idx] = ref_distances[at_idx][mask] + + mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res, \ + chem_mapping = self._lddt_pli_get_mdl_data(model_ligand) + + if len(mdl_chains) == 0 or len(trg_chains) == 0: + # It's a spaceship! 
+ return {"lddt_pli": None, + "lddt_pli_n_contacts": 0, + "target_ligand": target_ligand, + "model_ligand": model_ligand, + "bs_ref_res": trg_residues, + "bs_mdl_res": mdl_residues} + + #################### + # Setup alignments # + #################### + + # ref_mdl_alns refers to full chain mapper trg and mdl structures + # => need to adapt mdl sequence that only contain residues in contact + # with ligand + cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups, + chem_mapping, + mdl_bs, trg_bs) + + ######################################## + # Setup cache for added model contacts # + ######################################## + + # get each chain mapping that we ever observe in scoring + chain_mappings = list(chain_mapping._ChainMappings(chem_groups, + chem_mapping)) + + # for each mdl ligand atom, we collect all trg ligand atoms that are + # ever mapped onto it given *symmetries* + ligand_atom_mappings = [set() for a in mdl_ligand_res.atoms] + for (trg_sym, mdl_sym) in symmetries: + for trg_i, mdl_i in zip(trg_sym, mdl_sym): + ligand_atom_mappings[mdl_i].add(trg_i) + + mdl_ligand_pos = np.zeros((mdl_ligand_res.GetAtomCount(), 3)) + for a_idx, a in enumerate(mdl_ligand_res.atoms): + p = a.GetPos() + mdl_ligand_pos[a_idx, 0] = p[0] + mdl_ligand_pos[a_idx, 1] = p[1] + mdl_ligand_pos[a_idx, 2] = p[2] + + trg_ligand_pos = np.zeros((trg_ligand_res.GetAtomCount(), 3)) + for a_idx, a in enumerate(trg_ligand_res.atoms): + p = a.GetPos() + trg_ligand_pos[a_idx, 0] = p[0] + trg_ligand_pos[a_idx, 1] = p[1] + trg_ligand_pos[a_idx, 2] = p[2] + + mdl_lig_hashes = {a.hash_code for a in mdl_ligand_res.atoms} + + symmetric_atoms = np.asarray(sorted(scorer.symmetric_atoms), + dtype=np.int64) + + # two caches to cache things for each chain mapping => lists + # of len(chain_mappings) + # + # In principle we're caching for each trg/mdl ligand atom pair all + # information to update ref_indices/ref_distances and resolving the + # symmetries of the binding site. 
+ # in detail: each list entry in *scoring_cache* is a dict with + # key: (mdl_lig_at_idx, trg_lig_at_idx) + # value: tuple with 4 elements - 1: indices of atoms representing added + # contacts relative to overall indexing scheme in scorer 2: the + # respective distances 3: the same but only containing indices towards + # atoms of the binding site that are considered symmetric 4: the + # respective distances. + # each list entry in *penalty_cache* is a list of len N mdl lig atoms. + # For each mdl lig at it contains a penalty for this mdl lig at. That + # means the number of contacts in the mdl binding site that can + # directly be mapped to the target given the local chain mapping but + # are not present in the target binding site, i.e. interacting atoms are + # too far away. + scoring_cache = list() + penalty_cache = list() + + for mapping in chain_mappings: + + # flat mapping with mdl chain names as key + flat_mapping = dict() + for trg_chem_group, mdl_chem_group in zip(chem_groups, mapping): + for a,b in zip(trg_chem_group, mdl_chem_group): + if a is not None and b is not None: + flat_mapping[b] = a + + # for each mdl bs atom (as atom hash), the trg bs atoms (as index in scorer) + bs_atom_mapping = dict() + for mdl_cname, ref_cname in flat_mapping.items(): + aln = cut_ref_mdl_alns[(ref_cname, mdl_cname)] + ref_ch = trg_bs.Select(f"cname={mol.QueryQuoteName(ref_cname)}") + mdl_ch = mdl_bs.Select(f"cname={mol.QueryQuoteName(mdl_cname)}") + aln.AttachView(0, ref_ch) + aln.AttachView(1, mdl_ch) + for col in aln: + ref_r = col.GetResidue(0) + mdl_r = col.GetResidue(1) + if ref_r.IsValid() and mdl_r.IsValid(): + for mdl_a in mdl_r.atoms: + ref_a = ref_r.FindAtom(mdl_a.GetName()) + if ref_a.IsValid(): + ref_h = ref_a.handle.hash_code + if ref_h in scorer.atom_indices: + mdl_h = mdl_a.handle.hash_code + bs_atom_mapping[mdl_h] = \ + scorer.atom_indices[ref_h] + + cache = dict() + n_penalties = list() + + for mdl_a_idx, mdl_a in enumerate(mdl_ligand_res.atoms): + n_penalty 
= 0 + trg_bs_indices = list() + close_a = mdl_bs.FindWithin(mdl_a.GetPos(), + self.lddt_pli_radius) + for a in close_a: + mdl_a_hash_code = a.hash_code + if mdl_a_hash_code in bs_atom_mapping: + trg_bs_indices.append(bs_atom_mapping[mdl_a_hash_code]) + elif mdl_a_hash_code not in mdl_lig_hashes: + if a.GetChain().GetName() in flat_mapping: + # Its in a mapped chain + at_key = (a.GetResidue().GetNumber(), a.name) + cname = a.GetChain().name + cname_key = (flat_mapping[cname], cname) + if at_key in self._mappable_atoms[cname_key]: + # Its a contact in the model but not part of + # trg_bs. It can still be mapped using the + # global mdl_ch/ref_ch alignment + # d in ref > self.lddt_pli_radius + max_thresh + # => guaranteed to be non-fulfilled contact + n_penalty += 1 + + n_penalties.append(n_penalty) + + trg_bs_indices = np.asarray(sorted(trg_bs_indices)) + + for trg_a_idx in ligand_atom_mappings[mdl_a_idx]: + # mask selects entries in trg_bs_indices that are not yet + # part of classic lDDT ref_indices for atom at trg_a_idx + # => added mdl contacts + mask = np.isin(trg_bs_indices, ref_indices[ligand_start_idx + trg_a_idx], + assume_unique=True, invert=True) + added_indices = np.asarray([], dtype=np.int64) + added_distances = np.asarray([], dtype=np.float64) + if np.sum(mask) > 0: + # compute ref distances on reference positions + added_indices = trg_bs_indices[mask] + tmp = scorer.positions.take(added_indices, axis=0) + np.subtract(tmp, trg_ligand_pos[trg_a_idx][None, :], out=tmp) + np.square(tmp, out=tmp) + tmp = tmp.sum(axis=1) + np.sqrt(tmp, out=tmp) # distances against all relevant atoms + added_distances = tmp + + # extract the distances towards bs atoms that are symmetric + sym_mask = np.isin(added_indices, symmetric_atoms, + assume_unique=True) + + cache[(mdl_a_idx, trg_a_idx)] = (added_indices, added_distances, + added_indices[sym_mask], + added_distances[sym_mask]) + + scoring_cache.append(cache) + penalty_cache.append(n_penalties) + + # cache for model 
contacts towards non mapped trg chains - this is + # relevant for self._lddt_pli_unmapped_chain_penalty + # key: tuple in form (trg_ch, mdl_ch) + # value: yet another dict with + # key: ligand_atom_hash + # value: n contacts towards respective trg chain that can be mapped + non_mapped_cache = dict() + + ############################################################### + # compute lDDT for all possible chain mappings and symmetries # + ############################################################### + + best_score = -1.0 + best_result = {"lddt_pli": None, + "lddt_pli_n_contacts": 0} + + # dummy alignment for ligand chains which is needed as input later on + ligand_aln = seq.CreateAlignment() + trg_s = seq.CreateSequence(trg_ligand_chain.name, + trg_ligand_res.GetOneLetterCode()) + mdl_s = seq.CreateSequence(mdl_ligand_chain.name, + mdl_ligand_res.GetOneLetterCode()) + ligand_aln.AddSequence(trg_s) + ligand_aln.AddSequence(mdl_s) + ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms)) + + sym_idx_collector = [None] * scorer.n_atoms + sym_dist_collector = [None] * scorer.n_atoms + + for mapping, s_cache, p_cache in zip(chain_mappings, scoring_cache, penalty_cache): + + lddt_chain_mapping = dict() + lddt_alns = dict() + for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping): + for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group): + # some mdl chains can be None + if mdl_ch is not None: + lddt_chain_mapping[mdl_ch] = ref_ch + lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)] + + # add ligand to lddt_chain_mapping/lddt_alns + lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name + lddt_alns[mdl_ligand_chain.name] = ligand_aln + + # already process model, positions will be manually hacked for each + # symmetry - small overhead for variables that are thrown away here + pos, _, _, _, _, _, lddt_symmetries = \ + scorer._ProcessModel(mdl_bs, lddt_chain_mapping, + residue_mapping = lddt_alns, + thresholds = self.lddt_pli_thresholds, + 
check_resnames = self.check_resnames) + + # estimate a penalty for unsatisfied model contacts from chains + # that are not in the local trg binding site, but can be mapped in + # the target. + # We're using the trg chain with the closest geometric center that + # can be mapped to the mdl chain according the chem mapping. + # An alternative would be to search for the target chain with + # the minimal number of additional contacts. + # There is not good solution for this problem... + unmapped_chains = list() + for mdl_ch in mdl_chains: + if mdl_ch not in lddt_chain_mapping: + # check which chain in trg is closest + chem_group_idx = None + for i, m in enumerate(self._chem_mapping): + if mdl_ch in m: + chem_group_idx = i + break + if chem_group_idx is None: + raise RuntimeError("This should never happen... " + "ask Gabriel...") + mdl_ch_view = self._chain_mapping_mdl.FindChain(mdl_ch) + mdl_center = mdl_ch_view.geometric_center + closest_ch = None + closest_dist = None + for trg_ch in self.chain_mapper.chem_groups[chem_group_idx]: + if trg_ch not in lddt_chain_mapping.values(): + c = self.chain_mapper.target.FindChain(trg_ch).geometric_center + d = geom.Distance(mdl_center, c) + if closest_dist is None or d < closest_dist: + closest_dist = d + closest_ch = trg_ch + if closest_ch is not None: + unmapped_chains.append((closest_ch, mdl_ch)) + + for (trg_sym, mdl_sym) in symmetries: + + # update positions + for mdl_i, trg_i in zip(mdl_sym, trg_sym): + pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :] + + # start new ref_indices/ref_distances from original values + funky_ref_indices = [np.copy(a) for a in ref_indices] + funky_ref_distances = [np.copy(a) for a in ref_distances] + + # The only distances from the binding site towards the ligand + # we care about are the ones from the symmetric atoms to + # correctly compute scorer._ResolveSymmetries. 
+ # We collect them while updating distances from added mdl + # contacts + for idx in symmetric_atoms: + sym_idx_collector[idx] = list() + sym_dist_collector[idx] = list() + + # add data from added mdl contacts cache + added_penalty = 0 + for mdl_i, trg_i in zip(mdl_sym, trg_sym): + added_penalty += p_cache[mdl_i] + cache = s_cache[mdl_i, trg_i] + full_trg_i = ligand_start_idx + trg_i + funky_ref_indices[full_trg_i] = \ + np.append(funky_ref_indices[full_trg_i], cache[0]) + funky_ref_distances[full_trg_i] = \ + np.append(funky_ref_distances[full_trg_i], cache[1]) + for idx, d in zip(cache[2], cache[3]): + sym_idx_collector[idx].append(full_trg_i) + sym_dist_collector[idx].append(d) + + for idx in symmetric_atoms: + funky_ref_indices[idx] = \ + np.append(funky_ref_indices[idx], + np.asarray(sym_idx_collector[idx], + dtype=np.int64)) + funky_ref_distances[idx] = \ + np.append(funky_ref_distances[idx], + np.asarray(sym_dist_collector[idx], + dtype=np.float64)) + + # we can pass funky_ref_indices/funky_ref_distances as + # sym_ref_indices/sym_ref_distances in + # scorer._ResolveSymmetries as we only have distances of the bs + # to the ligand and ligand atoms are "non-symmetric" + scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds, + lddt_symmetries, + funky_ref_indices, + funky_ref_distances) + + N = sum([len(funky_ref_indices[i]) for i in ligand_at_indices]) + N += added_penalty + + # collect number of expected contacts which can be mapped + if len(unmapped_chains) > 0: + N += self._lddt_pli_unmapped_chain_penalty(unmapped_chains, + non_mapped_cache, + mdl_bs, + mdl_ligand_res, + mdl_sym) + + conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices, + self.lddt_pli_thresholds, + funky_ref_indices, + funky_ref_distances), axis=0) + score = None + if N > 0: + score = np.mean(conserved/N) + + if score is not None and score > best_score: + best_score = score + best_result = {"lddt_pli": score, + "lddt_pli_n_contacts": N} + + # fill misc info to result object + 
best_result["target_ligand"] = target_ligand + best_result["model_ligand"] = model_ligand + best_result["bs_ref_res"] = trg_residues + best_result["bs_mdl_res"] = mdl_residues + + return best_result + + + def _compute_lddt_pli_classic(self, symmetries, target_ligand, + model_ligand): + + ############################### + # Get stuff from model/target # + ############################### + + max_r = None + if self.lddt_pli_binding_site_radius: + max_r = self.lddt_pli_binding_site_radius + + trg_residues, trg_bs, trg_chains, trg_ligand_chain, \ + trg_ligand_res, scorer, chem_groups = \ + self._lddt_pli_get_trg_data(target_ligand, max_r = max_r) + + # Copy to make sure that we don't change anything on underlying + # references + # This is not strictly necessary in the current implementation but + # hey, maybe it avoids hard to debug errors when someone changes things + ref_indices = [a.copy() for a in scorer.ref_indices_ic] + ref_distances = [a.copy() for a in scorer.ref_distances_ic] + + # no matter what mapping/symmetries, the number of expected + # contacts stays the same + ligand_start_idx = scorer.chain_start_indices[-1] + ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms)) + n_exp = sum([len(ref_indices[i]) for i in ligand_at_indices]) + + mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res, \ + chem_mapping = self._lddt_pli_get_mdl_data(model_ligand) + + if n_exp == 0: + # no contacts... nothing to compute... + return {"lddt_pli": None, + "lddt_pli_n_contacts": 0, + "target_ligand": target_ligand, + "model_ligand": model_ligand, + "bs_ref_res": trg_residues, + "bs_mdl_res": mdl_residues} + + # Distance hacking... 
remove any interchain distance except the ones + # with the ligand + for at_idx in range(ligand_start_idx): + mask = ref_indices[at_idx] >= ligand_start_idx + ref_indices[at_idx] = ref_indices[at_idx][mask] + ref_distances[at_idx] = ref_distances[at_idx][mask] + + #################### + # Setup alignments # + #################### + + # ref_mdl_alns refers to full chain mapper trg and mdl structures + # => need to adapt mdl sequence that only contain residues in contact + # with ligand + cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups, + chem_mapping, + mdl_bs, trg_bs) + + ############################################################### + # compute lDDT for all possible chain mappings and symmetries # + ############################################################### + + best_score = -1.0 + + # dummy alignment for ligand chains which is needed as input later on + l_aln = seq.CreateAlignment() + l_aln.AddSequence(seq.CreateSequence(trg_ligand_chain.name, + trg_ligand_res.GetOneLetterCode())) + l_aln.AddSequence(seq.CreateSequence(mdl_ligand_chain.name, + mdl_ligand_res.GetOneLetterCode())) + + mdl_ligand_pos = np.zeros((model_ligand.GetAtomCount(), 3)) + for a_idx, a in enumerate(model_ligand.atoms): + p = a.GetPos() + mdl_ligand_pos[a_idx, 0] = p[0] + mdl_ligand_pos[a_idx, 1] = p[1] + mdl_ligand_pos[a_idx, 2] = p[2] + + for mapping in chain_mapping._ChainMappings(chem_groups, chem_mapping): + + lddt_chain_mapping = dict() + lddt_alns = dict() + for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping): + for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group): + # some mdl chains can be None + if mdl_ch is not None: + lddt_chain_mapping[mdl_ch] = ref_ch + lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)] + + # add ligand to lddt_chain_mapping/lddt_alns + lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name + lddt_alns[mdl_ligand_chain.name] = l_aln + + # already process model, positions will be manually hacked for each + # 
symmetry - small overhead for variables that are thrown away here + pos, _, _, _, _, _, lddt_symmetries = \ + scorer._ProcessModel(mdl_bs, lddt_chain_mapping, + residue_mapping = lddt_alns, + thresholds = self.lddt_pli_thresholds, + check_resnames = self.check_resnames) + + for (trg_sym, mdl_sym) in symmetries: + for mdl_i, trg_i in zip(mdl_sym, trg_sym): + pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :] + # we can pass ref_indices/ref_distances as + # sym_ref_indices/sym_ref_distances in + # scorer._ResolveSymmetries as we only have distances of the bs + # to the ligand and ligand atoms are "non-symmetric" + scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds, + lddt_symmetries, + ref_indices, + ref_distances) + # compute number of conserved distances for ligand atoms + conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices, + self.lddt_pli_thresholds, + ref_indices, + ref_distances), axis=0) + score = np.mean(conserved/n_exp) + + if score > best_score: + best_score = score + + # fill misc info to result object + best_result = {"lddt_pli": best_score, + "lddt_pli_n_contacts": n_exp, + "target_ligand": target_ligand, + "model_ligand": model_ligand, + "bs_ref_res": trg_residues, + "bs_mdl_res": mdl_residues} + + return best_result + + def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains, + non_mapped_cache, + mdl_bs, + mdl_ligand_res, + mdl_sym): + + n_exp = 0 + for ch_tuple in unmapped_chains: + if ch_tuple not in non_mapped_cache: + # for each ligand atom, we count the number of mappable atoms + # within lddt_pli_radius + counts = dict() + # the select statement also excludes the ligand in mdl_bs + # as it resides in a separate chain + mdl_cname = ch_tuple[1] + mdl_bs_ch = mdl_bs.Select(f"cname={mol.QueryQuoteName(mdl_cname)}") + for a in mdl_ligand_res.atoms: + close_atoms = \ + mdl_bs_ch.FindWithin(a.GetPos(), self.lddt_pli_radius) + N = 0 + for close_a in close_atoms: + at_key = (close_a.GetResidue().GetNumber(), + close_a.GetName()) 
+ if at_key in self._mappable_atoms[ch_tuple]: + N += 1 + counts[a.hash_code] = N + + # fill cache + non_mapped_cache[ch_tuple] = counts + + # add number of mdl contacts which can be mapped to target + # as non-fulfilled contacts + counts = non_mapped_cache[ch_tuple] + lig_hash_codes = [a.hash_code for a in mdl_ligand_res.atoms] + for i in mdl_sym: + n_exp += counts[lig_hash_codes[i]] + + return n_exp + + + def _lddt_pli_get_mdl_data(self, model_ligand): + if model_ligand not in self._lddt_pli_model_data: + + mdl = self._chain_mapping_mdl + + mdl_residues = set() + for at in model_ligand.atoms: + close_atoms = mdl.FindWithin(at.GetPos(), self.lddt_pli_radius) + for close_at in close_atoms: + mdl_residues.add(close_at.GetResidue()) + + max_r = self.lddt_pli_radius + max(self.lddt_pli_thresholds) + for r in mdl.residues: + r.SetIntProp("bs", 0) + for at in model_ligand.atoms: + close_atoms = mdl.FindWithin(at.GetPos(), max_r) + for close_at in close_atoms: + close_at.GetResidue().SetIntProp("bs", 1) + + mdl_bs = mol.CreateEntityFromView(mdl.Select("grbs:0=1"), True) + mdl_chains = set([ch.name for ch in mdl_bs.chains]) + + mdl_editor = mdl_bs.EditXCS(mol.BUFFERED_EDIT) + mdl_ligand_chain = None + for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]: + try: + # I'm pretty sure, one of these chain names is not there... 
+ mdl_ligand_chain = mdl_editor.InsertChain(cname) + break + except Exception: + pass + if mdl_ligand_chain is None: + raise RuntimeError("Could not insert ligand chain: fallback names taken") + mdl_ligand_res = mdl_editor.AppendResidue(mdl_ligand_chain, + model_ligand, + deep=True) + mdl_editor.RenameResidue(mdl_ligand_res, "LIG") + mdl_editor.SetResidueNumber(mdl_ligand_res, mol.ResNum(1)) + + chem_mapping = list() + for m in self._chem_mapping: + chem_mapping.append([x for x in m if x in mdl_chains]) + + self._lddt_pli_model_data[model_ligand] = (mdl_residues, + mdl_bs, + mdl_chains, + mdl_ligand_chain, + mdl_ligand_res, + chem_mapping) + + return self._lddt_pli_model_data[model_ligand] + + + def _lddt_pli_get_trg_data(self, target_ligand, max_r = None): + if target_ligand not in self._lddt_pli_target_data: + + trg = self.chain_mapper.target + + if max_r is None: + max_r = self.lddt_pli_radius + max(self.lddt_pli_thresholds) + + trg_residues = set() + for at in target_ligand.atoms: + close_atoms = trg.FindWithin(at.GetPos(), max_r) + for close_at in close_atoms: + trg_residues.add(close_at.GetResidue()) + + for r in trg.residues: + r.SetIntProp("bs", 0) + + for r in trg_residues: + r.SetIntProp("bs", 1) + + trg_bs = mol.CreateEntityFromView(trg.Select("grbs:0=1"), True) + trg_chains = set([ch.name for ch in trg_bs.chains]) + + trg_editor = trg_bs.EditXCS(mol.BUFFERED_EDIT) + trg_ligand_chain = None + for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]: + try: + # I'm pretty sure, one of these chain names is not there yet + trg_ligand_chain = trg_editor.InsertChain(cname) + break + except Exception: + pass + if trg_ligand_chain is None: + raise RuntimeError("Could not insert ligand chain: fallback names taken") + + trg_ligand_res = trg_editor.AppendResidue(trg_ligand_chain, + target_ligand, + deep=True) + trg_editor.RenameResidue(trg_ligand_res, "LIG") + trg_editor.SetResidueNumber(trg_ligand_res, mol.ResNum(1)) + + compound_name = trg_ligand_res.name + compound = lddt.CustomCompound.FromResidue(trg_ligand_res) + 
def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs,
                               ref_bs):
    """ Cut reference/model chain alignments down to binding site residues

    Iterates *chem_groups* and *chem_mapping* in lockstep. For every
    reference chain name in a group and every model chain name in the
    group at the same position, the alignment stored in
    ``self._ref_mdl_alns`` is reproduced column by column. The one letter
    code of a column is only kept if the residue attached to that column
    is found, by residue number, in the respective chain of *ref_bs* /
    *mdl_bs*. Everything else becomes a gap ('-').

    As a side effect, views on the full target and model chains are
    attached to the alignments stored in ``self._ref_mdl_alns``.

    :param chem_groups: Groups of reference chain names
    :param chem_mapping: Groups of model chain names, one group per entry
                         in *chem_groups*
    :param mdl_bs: Model binding site, queried with FindChain/FindResidue
    :param ref_bs: Reference binding site, queried with
                   FindChain/FindResidue
    :returns: :class:`dict` with key (ref_chain_name, mdl_chain_name) and
              the cut alignment as value
    """
    cut_ref_mdl_alns = dict()
    for ref_chem_group, mdl_chem_group in zip(chem_groups, chem_mapping):
        for ref_ch in ref_chem_group:

            ref_bs_chain = ref_bs.FindChain(ref_ch)
            query = "cname=" + mol.QueryQuoteName(ref_ch)
            ref_view = self.chain_mapper.target.Select(query)

            for mdl_ch in mdl_chem_group:
                aln = self._ref_mdl_alns[(ref_ch, mdl_ch)]

                # views are required to access residues via col.GetResidue
                aln.AttachView(0, ref_view)

                mdl_bs_chain = mdl_bs.FindChain(mdl_ch)
                query = "cname=" + mol.QueryQuoteName(mdl_ch)
                aln.AttachView(1, self._chain_mapping_mdl.Select(query))

                # start from all gaps and only fill in binding site residues
                cut_mdl_seq = ['-'] * aln.GetLength()
                cut_ref_seq = ['-'] * aln.GetLength()
                for i, col in enumerate(aln):

                    # check ref residue
                    r = col.GetResidue(0)
                    if r.IsValid():
                        bs_r = ref_bs_chain.FindResidue(r.GetNumber())
                        if bs_r.IsValid():
                            cut_ref_seq[i] = col[0]

                    # check mdl residue
                    r = col.GetResidue(1)
                    if r.IsValid():
                        bs_r = mdl_bs_chain.FindResidue(r.GetNumber())
                        if bs_r.IsValid():
                            cut_mdl_seq[i] = col[1]

                cut_aln = seq.CreateAlignment()
                cut_aln.AddSequence(seq.CreateSequence(ref_ch,
                                                       ''.join(cut_ref_seq)))
                cut_aln.AddSequence(seq.CreateSequence(mdl_ch,
                                                       ''.join(cut_mdl_seq)))
                cut_ref_mdl_alns[(ref_ch, mdl_ch)] = cut_aln
    return cut_ref_mdl_alns

@property
def _mappable_atoms(self):
    """ Stores mappable atoms given a chain mapping

    Store for each ref_ch,mdl_ch pair all mdl atoms that can be
    mapped. Don't store mappable atoms as hashes but rather as tuple
    (mdl_r.GetNumber(), mdl_a.GetName()). Reason for that is that one might
    operate on Copied EntityHandle objects without corresponding hashes.
    Given a tuple defining c_pair: (ref_cname, mdl_cname), one
    can check if a certain atom is mappable by evaluating:
    if (mdl_r.GetNumber(), mdl_a.GetName()) in self._mappable_atoms[c_pair]

    A model atom counts as mappable if its residue is aligned to a valid
    reference residue that contains an atom of the same name. Every chain
    pair in ``self._ref_mdl_alns`` gets an entry, potentially an empty set.

    :rtype: :class:`dict` with (ref_cname, mdl_cname) as key and a
            :class:`set` of (residue number, atom name) tuples as value
    """
    if self.__mappable_atoms is None:
        # Build in a local dict and only publish the finished result. The
        # cache is therefore never observable in a half-filled state, not
        # even if one of the calls below raises, and the getter does not
        # have to re-enter itself while it is still initializing.
        mappable_atoms = dict()
        for (ref_cname, mdl_cname), aln in self._ref_mdl_alns.items():
            c_key = (ref_cname, mdl_cname)
            at_keys = set()
            mappable_atoms[c_key] = at_keys
            ref_ch = self.chain_mapper.target.Select(f"cname={mol.QueryQuoteName(ref_cname)}")
            mdl_ch = self._chain_mapping_mdl.Select(f"cname={mol.QueryQuoteName(mdl_cname)}")
            aln.AttachView(0, ref_ch)
            aln.AttachView(1, mdl_ch)
            for col in aln:
                ref_r = col.GetResidue(0)
                mdl_r = col.GetResidue(1)
                if ref_r.IsValid() and mdl_r.IsValid():
                    for mdl_a in mdl_r.atoms:
                        if ref_r.FindAtom(mdl_a.name).IsValid():
                            at_keys.add((mdl_r.GetNumber(), mdl_a.name))
        self.__mappable_atoms = mappable_atoms

    return self.__mappable_atoms

def _compute_chem_mapping(self):
    """ Runs GetChemMapping on the model and caches all three results

    A single call to ``self.chain_mapper.GetChemMapping(self.model)``
    provides the values behind :attr:`_chem_mapping`,
    :attr:`_chem_group_alns` and :attr:`_chain_mapping_mdl`.
    """
    self.__chem_mapping, self.__chem_group_alns, \
    self.__chain_mapping_mdl = \
    self.chain_mapper.GetChemMapping(self.model)

@property
def _chem_mapping(self):
    """ First return value of GetChemMapping for the model, lazily computed
    """
    if self.__chem_mapping is None:
        self._compute_chem_mapping()
    return self.__chem_mapping

@property
def _chem_group_alns(self):
    """ Second return value of GetChemMapping for the model, lazily computed
    """
    if self.__chem_group_alns is None:
        self._compute_chem_mapping()
    return self.__chem_group_alns

@property
def _ref_mdl_alns(self):
    """ Reference/model chain alignments, lazily computed

    Result of ``chain_mapping._GetRefMdlAlns`` given chem groups and chem
    group alignments of the chain mapper as well as :attr:`_chem_mapping`
    and :attr:`_chem_group_alns`. Other code in this class accesses it as a
    :class:`dict` with key (ref_chain_name, mdl_chain_name).
    """
    if self.__ref_mdl_alns is None:
        self.__ref_mdl_alns = \
        chain_mapping._GetRefMdlAlns(self.chain_mapper.chem_groups,
                                     self.chain_mapper.chem_group_alignments,
                                     self._chem_mapping,
                                     self._chem_group_alns)
    return self.__ref_mdl_alns

@property
def _chain_mapping_mdl(self):
    """ Third return value of GetChemMapping for the model, lazily computed

    Used in this class to select model chains by name.
    """
    if self.__chain_mapping_mdl is None:
        self._compute_chem_mapping()
    return self.__chain_mapping_mdl
def test_compute_lddtpli_scores(self):
    """Score matrix of one model ligand against all target ligands.
    """
    target = _LoadMMCIF("1r8q.cif.gz")
    model = _LoadMMCIF("P84080_model_02.cif.gz")
    ligand_path = os.path.join('testfiles', "P84080_model_02_ligand_0.sdf")
    model_ligand = io.LoadEntity(ligand_path)
    scorer = ligand_scoring_lddtpli.LDDTPLIScorer(
        model, target, [model_ligand], None,
        add_mdl_contacts = False,
        lddt_pli_binding_site_radius = 4.0)

    self.assertEqual(scorer.score_matrix.shape, (7, 1))

    # only target ligands 1 and 5 are expected to give a score
    self.assertTrue(np.isnan(scorer.score_matrix[0, 0]))
    self.assertAlmostEqual(scorer.score_matrix[1, 0], 0.99843, 5)
    for row in (2, 3, 4):
        self.assertTrue(np.isnan(scorer.score_matrix[row, 0]))
    self.assertAlmostEqual(scorer.score_matrix[5, 0], 1.0)
    self.assertTrue(np.isnan(scorer.score_matrix[6, 0]))

def test_check_resnames(self):
    """Test that the check_resname argument works.

    With check_resnames=True, scoring is expected to raise if a residue
    in the model representation of the binding site is named differently
    than in the reference. A residue is renamed by hand to provoke that.
    This is only relevant for the LDDTPLIScorer
    """
    trg_4c0a = _LoadMMCIF("4c0a.cif.gz")
    trg = trg_4c0a.Select("cname=C or cname=I")

    # THR => TPO for residue C.15 of 4C0A. The residue is part of the
    # binding site, so the name check is expected to trip over it.
    mdl = ost.mol.CreateEntityFromView(trg, include_exlusive_atoms=False)
    editor = mdl.EditICS()
    editor.RenameResidue(mdl.FindResidue("C", 15), "TPO")
    editor.UpdateXCS()

    def _score(check_resnames):
        # ligand lookup, scorer setup and scoring in one go
        scorer = ligand_scoring_lddtpli.LDDTPLIScorer(
            mdl, trg, [mdl.FindResidue("I", 1)], [trg.FindResidue("I", 1)],
            check_resnames=check_resnames)
        scorer._compute_scores()

    with self.assertRaises(RuntimeError):
        _score(True)

    # must run through if names are not checked
    _score(False)

def test_added_mdl_contacts(self):
    """Penalties from contacts that are only observed in the model.
    """

    # binding site for ligand in chain G consists of chains A and B
    prot = _LoadMMCIF("1r8q.cif.gz").Copy()

    # model has the full binding site
    mdl = mol.CreateEntityFromView(prot.Select("cname=A,B,G"), True)

    # chain C has same sequence as chain A but is not in contact
    # with ligand in chain G
    # target has thus incomplete binding site only from chain B
    trg = mol.CreateEntityFromView(prot.Select("cname=B,C,G"), True)

    # without added model contacts, the incomplete binding site (chain B
    # only) is perfectly reproduced by the model which also has chain B
    sc = ligand_scoring_lddtpli.LDDTPLIScorer(mdl, trg, add_mdl_contacts=False)
    self.assertAlmostEqual(sc.score_matrix[0,0], 1.0, 5)

    # with added model contacts, contributions from chain B are still
    # perfectly reproduced but all contacts of the ligand towards chain A
    # are added as penalty
    sc = ligand_scoring_lddtpli.LDDTPLIScorer(mdl, trg, add_mdl_contacts=True)

    lig = prot.Select("cname=G")

    # count model atoms close to the ligand per protein chain, atoms of
    # the ligand itself (chain G) are not counted
    n_close = {"A": 0, "B": 0}
    for lig_atom in lig.atoms:
        for close_atom in mdl.FindWithin(lig_atom.GetPos(),
                                         sc.lddt_pli_radius):
            chain_name = close_atom.GetChain().GetName()
            if chain_name in n_close:
                n_close[chain_name] += 1
    A_count = n_close["A"]
    B_count = n_close["B"]

    self.assertAlmostEqual(sc.score_matrix[0,0],
                           B_count/(A_count + B_count), 5)

    # Same as before but additionally we remove residue TRP.66
    # from chain C in the target to test mapping magic...
    # Chain C is NOT in contact with the ligand but we only
    # add contacts from chain A as penalty that are mappable
    # to the closest chain with same sequence. That would be
    # chain C
    query = "cname=B,G or (cname=C and rnum!=66)"
    trg = mol.CreateEntityFromView(prot.Select(query), True)
    sc = ligand_scoring_lddtpli.LDDTPLIScorer(mdl, trg, add_mdl_contacts=True)

    TRP66_count = 0
    for lig_atom in lig.atoms:
        for close_atom in mdl.FindWithin(lig_atom.GetPos(),
                                         sc.lddt_pli_radius):
            if close_atom.GetChain().GetName() != "A":
                continue
            if close_atom.GetResidue().GetNumber().GetNum() == 66:
                TRP66_count += 1

    self.assertEqual(TRP66_count, 134)

    # remove TRP66_count from original penalty
    expected = B_count/(A_count + B_count - TRP66_count)
    self.assertAlmostEqual(sc.score_matrix[0,0], expected, 5)

    # Move a random atom in the model from chain B towards the ligand center
    # chain B is also present in the target and interacts with the ligand,
    # but that atom would be far away and thus adds to the penalty. Since
    # the ligand is small enough, the number of added contacts should be
    # exactly the number of ligand atoms.
    mdl_ed = mdl.EditXCS()
    at = mdl.FindResidue("B", mol.ResNum(8)).FindAtom("NZ")
    mdl_ed.SetAtomPos(at, lig.geometric_center)
    sc = ligand_scoring_lddtpli.LDDTPLIScorer(mdl, trg, add_mdl_contacts=True)

    # compared to the last assertAlmostEqual, we add the number of ligand
    # atoms as additional penalties
    n_penalized = A_count + B_count - TRP66_count + lig.GetAtomCount()
    self.assertAlmostEqual(sc.score_matrix[0,0], B_count/n_penalized, 5)