From c260b0d6e1bfba9f415f23dbcc4645015bf69acd Mon Sep 17 00:00:00 2001 From: Gabriel Studer <gabriel.studer@unibas.ch> Date: Tue, 23 Apr 2024 15:15:54 +0200 Subject: [PATCH] ligand scoring: refactor - separate scrmsd and lddtpli computations in separate functions - optionally reduce the binding site search space in the model --- modules/mol/alg/pymod/ligand_scoring.py | 493 +++++++++++++----------- 1 file changed, 271 insertions(+), 222 deletions(-) diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py index 70d8f08f0..c630598ae 100644 --- a/modules/mol/alg/pymod/ligand_scoring.py +++ b/modules/mol/alg/pymod/ligand_scoring.py @@ -254,6 +254,11 @@ class LigandScorer: of None, and reason for not being assigned in the \\*_details matrix. Defaults to False. :type unassigned: :class:`bool` + :param full_bs_search: If True, all potential binding sites in the model + are searched for each target binding site. If False, + the search space in the model is reduced to regions + around model ligands. + :type full_bs_search: :class:`bool` """ def __init__(self, model, target, model_ligands=None, target_ligands=None, resnum_alignments=False, check_resnames=True, @@ -263,7 +268,7 @@ class LigandScorer: radius=4.0, lddt_pli_radius=6.0, lddt_lp_radius=10.0, binding_sites_topn=100000, global_chain_mapping=False, rmsd_assignment=False, n_max_naive=12, max_symmetries=1e5, - custom_mapping=None, unassigned=False): + custom_mapping=None, unassigned=False, full_bs_search=True): if isinstance(model, mol.EntityView): self.model = mol.CreateEntityFromView(model, False) @@ -316,6 +321,7 @@ class LigandScorer: self.max_symmetries = max_symmetries self.unassigned = unassigned self.coverage_delta = coverage_delta + self.full_bs_search = full_bs_search # scoring matrices self._rmsd_matrix = None @@ -330,15 +336,34 @@ class LigandScorer: self._lddt_pli_details = None # lazily precomputed variables - self._binding_sites = {} self.__model_mapping = None + # Residues that are in contact with a ligand => binding site + # defined as all residues with at least one atom within self.radius + # key: ligand.handle.hash_code, value: EntityView of whatever + # entity ligand belongs to + self._binding_sites = dict() + # lazily precomputed variables to speedup GetRepr chain mapping calls + # for localized GetRepr searches self._chem_mapping = None self._chem_group_alns = None self._chain_mapping_mdl = None self._get_repr_input = dict() + # the actual representations as returned by ChainMapper.GetRepr + # key: (target_ligand.handle.hash_code, model_ligand.handle.hash_code) + # value: list of repr results + self._repr = dict() + + # cache for rmsd values + # rmsd is used as tie breaker in lddt-pli, we therefore need access to + # the rmsd for each target_ligand/model_ligand/repr combination + # key: (target_ligand.handle.hash_code, model_ligand.handle.hash_code, + # repr_id) + # value: rmsd + self._rmsd_cache = dict() + # Bookkeeping of unassigned ligands self._unassigned_target_ligands = None self._unassigned_model_ligands = None @@ -370,6 +395,56 @@ class LigandScorer: resnum_alignments=self.resnum_alignments) return self._chain_mapper + def get_target_binding_site(self, target_ligand): + + if target_ligand.handle.hash_code not in self._binding_sites: + + # create view of reference binding site + ref_residues_hashes = set() # helper to keep track of added residues + ignored_residue_hashes = {target_ligand.hash_code} + for ligand_at in target_ligand.atoms: + close_atoms = self.target.FindWithin(ligand_at.GetPos(), self.radius) + for close_at in close_atoms: + # Skip any residue not in the chain mapping target + ref_res = close_at.GetResidue() + h = ref_res.handle.GetHashCode() + if h not in ref_residues_hashes and \ + h not in ignored_residue_hashes: + if self.chain_mapper.target.ViewForHandle(ref_res).IsValid(): + h = ref_res.handle.GetHashCode() + ref_residues_hashes.add(h) + elif ref_res.is_ligand: + LogWarning("Ignoring ligand %s in binding site of %s" % ( + ref_res.qualified_name, target_ligand.qualified_name)) + ignored_residue_hashes.add(h) + elif ref_res.chem_type == mol.ChemType.WATERS: + pass # That's ok, no need to warn + else: + LogWarning("Ignoring residue %s in binding site of %s" % ( + ref_res.qualified_name, target_ligand.qualified_name)) + ignored_residue_hashes.add(h) + + ref_bs = self.target.CreateEmptyView() + if ref_residues_hashes: + # reason for doing that separately is to guarantee same ordering of + # residues as in underlying entity. (Reorder by ResNum seems only + # available on ChainHandles) + for ch in self.target.chains: + for r in ch.residues: + if r.handle.GetHashCode() in ref_residues_hashes: + ref_bs.AddResidue(r, mol.ViewAddFlag.INCLUDE_ALL) + if len(ref_bs.residues) == 0: + raise RuntimeError("Failed to add proximity residues to " + "the reference binding site entity") + else: + # Flag missing binding site + self._unassigned_target_ligands_reason[target_ligand] = ("binding_site", + "No residue in proximity of the target ligand") + + self._binding_sites[target_ligand.handle.hash_code] = ref_bs + + return self._binding_sites[target_ligand.handle.hash_code] + @property def _model_mapping(self): """Get the global chain mapping for the model.""" @@ -506,91 +581,6 @@ class LigandScorer: new_editor.UpdateICS() return extracted_ligands - def _get_binding_sites(self, ligand): - """Find representations of the binding site of *ligand* in the model. - - Only consider protein and nucleic acid chains that pass the criteria - for the :class:`ost.mol.alg.chain_mapping`. This means ignoring other - ligands, waters, short polymers as well as any incorrectly connected - chain that may be in proximity. - - :param ligand: Defines the binding site to identify. - :type ligand: :class:`~ost.mol.ResidueHandle` - """ - if ligand.hash_code not in self._binding_sites: - - # create view of reference binding site - ref_residues_hashes = set() # helper to keep track of added residues - ignored_residue_hashes = {ligand.hash_code} - for ligand_at in ligand.atoms: - close_atoms = self.target.FindWithin(ligand_at.GetPos(), self.radius) - for close_at in close_atoms: - # Skip any residue not in the chain mapping target - ref_res = close_at.GetResidue() - h = ref_res.handle.GetHashCode() - if h not in ref_residues_hashes and \ - h not in ignored_residue_hashes: - if self.chain_mapper.target.ViewForHandle(ref_res).IsValid(): - h = ref_res.handle.GetHashCode() - ref_residues_hashes.add(h) - elif ref_res.is_ligand: - LogWarning("Ignoring ligand %s in binding site of %s" % ( - ref_res.qualified_name, ligand.qualified_name)) - ignored_residue_hashes.add(h) - elif ref_res.chem_type == mol.ChemType.WATERS: - pass # That's ok, no need to warn - else: - LogWarning("Ignoring residue %s in binding site of %s" % ( - ref_res.qualified_name, ligand.qualified_name)) - ignored_residue_hashes.add(h) - - if ref_residues_hashes: - # reason for doing that separately is to guarantee same ordering of - # residues as in underlying entity. (Reorder by ResNum seems only - # available on ChainHandles) - ref_bs = self.target.CreateEmptyView() - for ch in self.target.chains: - for r in ch.residues: - if r.handle.GetHashCode() in ref_residues_hashes: - ref_bs.AddResidue(r, mol.ViewAddFlag.INCLUDE_ALL) - if len(ref_bs.residues) == 0: - raise RuntimeError("Failed to add proximity residues to " - "the reference binding site entity") - - reprs = list() - for model_ligand in self.model_ligands: - # Find the representations - if self.global_chain_mapping: - reprs.extend(self.chain_mapper.GetRepr(ref_bs, self.model, - inclusion_radius=self.lddt_lp_radius, - chem_mapping_result = self.get_get_repr_input(model_ligand), - global_mapping = self._model_mapping)) - else: - reprs.extend(self.chain_mapper.GetRepr(ref_bs, self.model, - inclusion_radius=self.lddt_lp_radius, - topn=self.binding_sites_topn, - chem_mapping_result = self.get_get_repr_input(model_ligand))) - - ################################################ - # TODO: sort by lDDT and ensure unique results # - ################################################ - - # Flag empty representation - self._binding_sites[ligand.hash_code] = reprs - if not self._binding_sites[ligand.hash_code]: - self._unassigned_target_ligands_reason[ligand] = ( - "model_representation", - "No representation of the reference binding site was " - "found in the model") - - else: # if ref_residues_hashes - # Flag missing binding site - self._unassigned_target_ligands_reason[ligand] = ("binding_site", - "No residue in proximity of the target ligand") - self._binding_sites[ligand.hash_code] = [] - - return self._binding_sites[ligand.hash_code] - @staticmethod def _build_binding_site_entity(ligand, residues, extra_residues=[]): """ Build an entity with all the binding site residues in chain A @@ -646,20 +636,26 @@ class LigandScorer: Compute the RMSD and lDDT-PLI scores for every possible target-model ligand pair and store the result in internal matrices. """ - # Create the result matrices - rmsd_full_matrix = np.empty( + ############################## + # Create the result matrices # + ############################## + self._rmsd_full_matrix = np.empty( (len(self.target_ligands), len(self.model_ligands)), dtype=dict) - lddt_pli_full_matrix = np.empty( + self._lddt_pli_full_matrix = np.empty( (len(self.target_ligands), len(self.model_ligands)), dtype=dict) self._assignment_isomorphisms = np.full( (len(self.target_ligands), len(self.model_ligands)), fill_value=np.nan) self._assignment_match_coverage = np.zeros( (len(self.target_ligands), len(self.model_ligands))) - for target_i, target_ligand in enumerate(self.target_ligands): + for target_id, target_ligand in enumerate(self.target_ligands): LogVerbose("Analyzing target ligand %s" % target_ligand) + for model_id, model_ligand in enumerate(self.model_ligands): + LogVerbose("Compare to model ligand %s" % model_ligand) - for model_i, model_ligand in enumerate(self.model_ligands): + ######################################################### + # Compute symmetries for given target/model ligand pair # + ######################################################### try: symmetries = _ComputeSymmetries( model_ligand, target_ligand, @@ -672,147 +668,156 @@ class LigandScorer: # Ligands are different - skip LogVerbose("No symmetry between %s and %s" % ( str(model_ligand), str(target_ligand))) - self._assignment_isomorphisms[target_i, model_i] = 0. + self._assignment_isomorphisms[target_id, model_id] = 0. continue except TooManySymmetriesError: # Ligands are too symmetrical - skip LogVerbose("Too many symmetries between %s and %s" % ( str(model_ligand), str(target_ligand))) - self._assignment_isomorphisms[target_i, model_i] = -1. + self._assignment_isomorphisms[target_id, model_id] = -1. continue except DisconnectedGraphError: # Disconnected graph is handled elsewhere continue - # for binding_site in _get_binding_site_matches(target_ligand, model_ligand): - for binding_site in self._get_binding_sites(target_ligand): - LogVerbose("Found binding site with chain mapping %s" % ( - binding_site.GetFlatChainMapping())) - - # Build the reference binding site and scorer - # Note: this could be refactored to avoid building the - # binding site entity and lDDT scorer repeatedly - ref_bs_ent = self._build_binding_site_entity( - target_ligand, binding_site.ref_residues, - binding_site.substructure.residues) - ref_bs_ent_ligand = ref_bs_ent.FindResidue("_", 1) # by definition - - custom_compounds = { - ref_bs_ent_ligand.name: - mol.alg.lddt.CustomCompound.FromResidue( - ref_bs_ent_ligand)} - lddt_scorer = mol.alg.lddt.lDDTScorer( - ref_bs_ent, - custom_compounds=custom_compounds, - inclusion_radius=self.lddt_pli_radius) - - substructure_match = len(symmetries[0][0]) != len( - model_ligand.atoms) - coverage = len(symmetries[0][0]) / len(model_ligand.atoms) - self._assignment_match_coverage[target_i, model_i] = coverage - self._assignment_isomorphisms[target_i, model_i] = 1. - - rmsd = _SCRMSD_symmetries(symmetries, model_ligand, - target_ligand, transformation=binding_site.transform) - LogDebug("RMSD: %.4f" % rmsd) - - # Save results? - if not rmsd_full_matrix[target_i, model_i] or \ - rmsd_full_matrix[target_i, model_i]["rmsd"] > rmsd: - rmsd_full_matrix[target_i, model_i] = { - "rmsd": rmsd, - "lddt_lp": binding_site.lDDT, - "bs_ref_res": binding_site.substructure.residues, - "bs_ref_res_mapped": binding_site.ref_residues, - "bs_mdl_res_mapped": binding_site.mdl_residues, - "bb_rmsd": binding_site.bb_rmsd, - "target_ligand": target_ligand, - "model_ligand": model_ligand, - "chain_mapping": binding_site.GetFlatChainMapping(), - "transform": binding_site.transform, - "substructure_match": substructure_match, - "coverage": coverage, - "inconsistent_residues": binding_site.inconsistent_residues, - } - if self.unassigned: - rmsd_full_matrix[target_i, model_i][ - "unassigned"] = False - LogDebug("Saved RMSD") - - mdl_bs_ent = self._build_binding_site_entity( - model_ligand, binding_site.mdl_residues, []) - mdl_bs_ent_ligand = mdl_bs_ent.FindResidue("_", 1) # by definition - - # Now for each symmetry, loop and rename atoms according - # to ref. - mdl_editor = mdl_bs_ent.EditXCS() - for i, (trg_sym, mdl_sym) in enumerate(symmetries): - # Prepare Entities for RMSD - for mdl_anum, trg_anum in zip(mdl_sym, trg_sym): - # Rename model atoms according to symmetry - trg_atom = ref_bs_ent_ligand.atoms[trg_anum] - mdl_atom = mdl_bs_ent_ligand.atoms[mdl_anum] - mdl_editor.RenameAtom(mdl_atom, trg_atom.name) - mdl_editor.UpdateICS() - - global_lddt, local_lddt, lddt_tot, lddt_cons, n_res, \ - n_cont, n_cons = lddt_scorer.lDDT( - mdl_bs_ent, chain_mapping={"A": "A", "_": "_"}, - no_intrachain=True, - return_dist_test=True, - check_resnames=self.check_resnames) - if lddt_tot == 0: - LogDebug("lDDT-PLI is undefined for %d" % i) - self._unassigned_target_ligands_reason[ - target_ligand] = ("no_contact", - "No lDDT-PLI contacts in the" - " reference structure") - else: - LogDebug("lDDT-PLI for symmetry %d: %.4f" % (i, global_lddt)) + substructure_match = len(symmetries[0][0]) != len(model_ligand.atoms) + coverage = len(symmetries[0][0]) / len(model_ligand.atoms) + self._assignment_match_coverage[target_id, model_id] = coverage + self._assignment_isomorphisms[target_id, model_id] = 1. + + ################################################################ + # Compute best rmsd/lddt-pli by naively enumerating symmetries # + ################################################################ + # rmsds MUST be computed first, as lDDT uses them as tiebreaker + # and expects the values to be in self._rmsd_cache + rmsd_result = self._compute_rmsd(symmetries, target_ligand, + model_ligand) + lddt_pli_result = self._compute_lddtpli(symmetries, target_ligand, + model_ligand) + + ########################################### + # Extend results by symmetry related info # + ########################################### + if rmsd_result is not None: + rmsd_result["substructure_match"] = substructure_match + rmsd_result["coverage"] = coverage + if self.unassigned: + rmsd_result["unassigned"] = False + + if lddt_pli_result is not None: + lddt_pli_result["substructure_match"] = substructure_match + lddt_pli_result["coverage"] = coverage + if self.unassigned: + lddt_pli_result["unassigned"] = False + + ############ + # Finalize # + ############ + self._lddt_pli_full_matrix[target_id, model_id] = lddt_pli_result + self._rmsd_full_matrix[target_id, model_id] = rmsd_result + + + def _compute_rmsd(self, symmetries, target_ligand, model_ligand): + best_rmsd_result = None + for r_i, r in enumerate(self.get_repr(target_ligand, model_ligand)): + rmsd = _SCRMSD_symmetries(symmetries, model_ligand, + target_ligand, transformation=r.transform) + + cache_key = (target_ligand.handle.hash_code, + model_ligand.handle.hash_code, r_i) + self._rmsd_cache[cache_key] = rmsd + + if best_rmsd_result is None or rmsd < best_rmsd_result["rmsd"]: + best_rmsd_result = {"rmsd": rmsd, + "lddt_lp": r.lDDT, + "bs_ref_res": r.substructure.residues, + "bs_ref_res_mapped": r.ref_residues, + "bs_mdl_res_mapped": r.mdl_residues, + "bb_rmsd": r.bb_rmsd, + "target_ligand": target_ligand, + "model_ligand": model_ligand, + "chain_mapping": r.GetFlatChainMapping(), + "transform": r.transform, + "inconsistent_residues": r.inconsistent_residues} + + return best_rmsd_result + + + def _compute_lddtpli(self, symmetries, target_ligand, model_ligand): + ref_bs = self.get_target_binding_site(target_ligand) + best_lddt_result = None + for r_i, r in enumerate(self.get_repr(target_ligand, model_ligand)): + ref_bs_ent = self._build_binding_site_entity( + target_ligand, r.ref_residues, + r.substructure.residues) + ref_bs_ent_ligand = ref_bs_ent.FindResidue("_", 1) # by definition + + custom_compounds = { + ref_bs_ent_ligand.name: + mol.alg.lddt.CustomCompound.FromResidue( + ref_bs_ent_ligand)} + lddt_scorer = mol.alg.lddt.lDDTScorer( + ref_bs_ent, + custom_compounds=custom_compounds, + inclusion_radius=self.lddt_pli_radius) + + lddt_tot = 4 * sum([len(x) for x in lddt_scorer.ref_indices_ic]) + if lddt_tot == 0: + # it's a space ship in the reference! + self._unassigned_target_ligands_reason[ + target_ligand] = ("no_contact", + "No lDDT-PLI contacts in the" + " reference structure") + #continue + mdl_bs_ent = self._build_binding_site_entity( + model_ligand, r.mdl_residues, []) + mdl_bs_ent_ligand = mdl_bs_ent.FindResidue("_", 1) # by definition + # Now for each symmetry, loop and rename atoms according + # to ref. + mdl_editor = mdl_bs_ent.EditXCS() + for i, (trg_sym, mdl_sym) in enumerate(symmetries): + for mdl_anum, trg_anum in zip(mdl_sym, trg_sym): + # Rename model atoms according to symmetry + trg_atom = ref_bs_ent_ligand.atoms[trg_anum] + mdl_atom = mdl_bs_ent_ligand.atoms[mdl_anum] + mdl_editor.RenameAtom(mdl_atom, trg_atom.name) + mdl_editor.UpdateICS() + + global_lddt, local_lddt = lddt_scorer.lDDT( + mdl_bs_ent, chain_mapping={"A": "A", "_": "_"}, + no_intrachain=True, + check_resnames=self.check_resnames) + + its_awesome = (best_lddt_result is None) or \ + (global_lddt > best_lddt_result["lddt_pli"]) + + # additionally consider rmsd as tiebreaker + if (not its_awesome) and (global_lddt == best_lddt_result["lddt_pli"]): + rmsd_cache_key = (target_ligand.handle.hash_code, + model_ligand.handle.hash_code, r_i) + rmsd = self._rmsd_cache[rmsd_cache_key] + if rmsd < best_lddt_result["rmsd"]: + its_awesome = True + + if its_awesome: + rmsd_cache_key = (target_ligand.handle.hash_code, + model_ligand.handle.hash_code, r_i) + best_lddt_result = {"lddt_pli": global_lddt, + "lddt_lp": r.lDDT, + "lddt_pli_n_contacts": lddt_tot, + "rmsd": self._rmsd_cache[rmsd_cache_key], + "bs_ref_res": r.substructure.residues, + "bs_ref_res_mapped": r.ref_residues, + "bs_mdl_res_mapped": r.mdl_residues, + "bb_rmsd": r.bb_rmsd, + "target_ligand": target_ligand, + "model_ligand": model_ligand, + "chain_mapping": r.GetFlatChainMapping(), + "transform": r.transform, + "inconsistent_residues": r.inconsistent_residues} + + return best_lddt_result - # Save results? - if not lddt_pli_full_matrix[target_i, model_i]: - # First iteration - save_lddt = True - else: - last_best_lddt = lddt_pli_full_matrix[ - target_i, model_i]["lddt_pli"] - last_best_rmsd = lddt_pli_full_matrix[ - target_i, model_i]["rmsd"] - if global_lddt > last_best_lddt: - # Better lDDT-PLI - save_lddt = True - elif global_lddt == last_best_lddt and \ - rmsd < last_best_rmsd: - # Same lDDT-PLI, better RMSD - save_lddt = True - else: - save_lddt = False - if save_lddt: - lddt_pli_full_matrix[target_i, model_i] = { - "lddt_pli": global_lddt, - "rmsd": rmsd, - "lddt_lp": binding_site.lDDT, - "lddt_pli_n_contacts": lddt_tot, - "bs_ref_res": binding_site.substructure.residues, - "bs_ref_res_mapped": binding_site.ref_residues, - "bs_mdl_res_mapped": binding_site.mdl_residues, - "bb_rmsd": binding_site.bb_rmsd, - "target_ligand": target_ligand, - "model_ligand": model_ligand, - "chain_mapping": binding_site.GetFlatChainMapping(), - "transform": binding_site.transform, - "substructure_match": substructure_match, - "coverage": coverage, - "inconsistent_residues": binding_site.inconsistent_residues, - } - if self.unassigned: - lddt_pli_full_matrix[target_i, model_i][ - "unassigned"] = False - LogDebug("Saved lDDT-PLI") - - self._rmsd_full_matrix = rmsd_full_matrix - self._lddt_pli_full_matrix = lddt_pli_full_matrix @staticmethod def _find_ligand_assignment(mat1, mat2=None, coverage=None, coverage_delta=None): @@ -1781,6 +1786,50 @@ class LigandScorer: self._get_repr_input[mdl_ligand.hash_code][0]) + def get_repr(self, target_ligand, model_ligand): + + key = None + if self.full_bs_search: + # all possible binding sites, independent from actual model ligand + key = (target_ligand.handle.hash_code, 0) + else: + key = (target_ligand.handle.hash_code, model_ligand.handle.hash_code) + + if key not in self._repr: + reprs = list() + ref_bs = self.get_target_binding_site(target_ligand) + if self.full_bs_search: + if self.global_chain_mapping: + reprs = self.chain_mapper.GetRepr( + ref_bs, self.model, inclusion_radius=self.lddt_lp_radius, + global_mapping = self._model_mapping) + else: + reprs = self.chain_mapper.GetRepr( + ref_bs, self.model, inclusion_radius=self.lddt_lp_radius, + topn=self.binding_sites_topn) + else: + if self.global_chain_mapping: + reprs = self.chain_mapper.GetRepr(ref_bs, self.model, + inclusion_radius=self.lddt_lp_radius, + chem_mapping_result = self.get_get_repr_input(model_ligand), + global_mapping = self._model_mapping) + else: + reprs = self.chain_mapper.GetRepr(ref_bs, self.model, + inclusion_radius=self.lddt_lp_radius, + topn=self.binding_sites_topn, + chem_mapping_result = self.get_get_repr_input(model_ligand)) + self._repr[key] = reprs + if len(reprs) == 0: + # whatever is in there already has precedence + if target_ligand not in self._unassigned_target_ligands_reason: + self._unassigned_target_ligands_reason[target_ligand] = ( + "model_representation", + "No representation of the reference binding site was " + "found in the model") + + return self._repr[key] + + def _ResidueToGraph(residue, by_atom_index=False): """Return a NetworkX graph representation of the residue. -- GitLab