diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py index b5c62d2763b1fa9954d10c47187225a18916dccb..405383321c645dd052b32065456dde1ba2c47118 100644 --- a/modules/mol/alg/pymod/ligand_scoring.py +++ b/modules/mol/alg/pymod/ligand_scoring.py @@ -91,10 +91,21 @@ class LigandScorer: If None (default), a chain mapper will be initialized lazily as required. :type chain_mapper: :class:`ost.mol.alg.chain_mapping.ChainMapper` - + :param substructure_match: Set this to True to allow partial target ligand. + :type substructure_match: :class:`bool` + :param radius: Inclusion radius for the binding site. Any residue with + atoms within this distance of the ligand will be included + in the binding site. + :type radius: :class:`float` + :param lddt_pli_radius: lDDT inclusion radius for lDDT-PLI. + :type lddt_pli_radius: :class:`float` + :param lddt_bs_radius: lDDT inclusion radius for lDDT-BS. + :type lddt_bs_radius: :class:`float` """ def __init__(self, model, target, model_ligands=None, target_ligands=None, - resnum_alignments=False, chain_mapper=None): + resnum_alignments=False, chain_mapper=None, + substructure_match=False, radius=4.0, lddt_pli_radius=6.0, + lddt_bs_radius=10.0): if isinstance(model, mol.EntityView): self.model = mol.CreateEntityFromView(model, False) @@ -124,12 +135,19 @@ class LigandScorer: self._chain_mapper = chain_mapper self.resnum_alignments = resnum_alignments + self.substructure_match = substructure_match + self.radius = radius + self.lddt_pli_radius = lddt_pli_radius + self.lddt_bs_radius = lddt_bs_radius # lazily computed scores self._lddt_pli = None self._rmsd = None self._lddt_bs = None + # lazily precomputed variables + self._binding_sites = {} + @property def chain_mapper(self): """ Chain mapper object for given :attr:`target` @@ -288,24 +306,163 @@ class LigandScorer: topn=topn) return self._binding_sites[ligand.hash_code] - def _compute_rmsd(self): + @staticmethod + def 
_build_binding_site_entity(ligand, residues, extra_residues=[]): + """ Build an entity with all the binding site residues in chain A + and the ligand in chain _. Residues are renumbered consecutively from + 1. The ligand is assigned residue number 1 and residue name LIG. + Residues in extra_residues not in `residues` in the model are added + at the end of chain A. + + :param ligand: the Residue Handle of the ligand + :type ligand: :class:`~ost.mol.ResidueHandle` + :param residues: a list of binding site residues + :type residues: :class:`list` of :class:`~ost.mol.ResidueHandle` + :param extra_residues: an optional list with additional binding site + residues. Residues in this list which are not + in `residues` will be added at the end of chain + A. This allows for instance adding unmapped + residues missing from the model into the + reference binding site. + :type extra_residues: :class:`list` of :class:`~ost.mol.ResidueHandle` + :rtype: :class:`~ost.mol.EntityHandle` + """ + bs_ent = mol.CreateEntity() + ed = bs_ent.EditXCS() + bs_chain = ed.InsertChain("A") + seen_res_qn = [] + for resnum, old_res in enumerate(residues, 1): + seen_res_qn.append(old_res.qualified_name) + new_res = ed.AppendResidue(bs_chain, old_res.handle, + deep=True) + ed.SetResidueNumber(new_res, mol.ResNum(resnum)) + + # Add extra residues at the end. 
+ for extra_res in extra_residues: + if extra_res.qualified_name not in seen_res_qn: + resnum += 1 + seen_res_qn.append(extra_res.qualified_name) + new_res = ed.AppendResidue(bs_chain, + extra_res.handle, + deep=True) + ed.SetResidueNumber(new_res, mol.ResNum(resnum)) + # Add the ligand in chain _ + ligand_chain = ed.InsertChain("_") + ligand_res = ed.AppendResidue(ligand_chain, ligand, + deep=True) + ed.RenameResidue(ligand_res, "LIG") + ed.SetResidueNumber(ligand_res, mol.ResNum(1)) + ed.UpdateICS() + + return bs_ent + + + def _compute_scores(self): """""" # Create the matrix - self._rmsd_matrix = np.empty((len(self.target_ligands), - len(self.model_ligands)), dtype=float) - for trg_ligand in self.target_ligands: - LogDebug("Compute RMSD for target ligand %s" % trg_ligand) - - for binding_site in self._get_binding_sites(trg_ligand): - import ipdb; ipdb.set_trace() - - for mdl_ligand in self.model_ligands: - LogDebug("Compute RMSD for model ligand %s" % mdl_ligand) - - # Get symmetry graphs - model_graph = model_ligand.spyrmsd_mol.to_graph() - target_graph = target_ligand.struct_spyrmsd_mol.to_graph() - pass + self._rmsd_matrix = np.full((len(self.target_ligands), + len(self.model_ligands)), + float("inf"), dtype=float) + self._lddt_pli_matrix = np.empty((len(self.target_ligands), + len(self.model_ligands)), dtype=dict) + for target_i, target_ligand in enumerate(self.target_ligands): + LogDebug("Compute RMSD for target ligand %s" % target_ligand) + + for binding_site in self._get_binding_sites(target_ligand): + if len(binding_site.substructure.residues) == 0: + LogWarning("No residue in proximity of target ligand " + "%s" % str(target_ligand)) + continue # next binding site + elif len(binding_site.ref_residues) == 0: + LogWarning("Binding site of %s not mapped to the model " % + str(target_ligand)) + continue # next binding site + + ref_bs_ent = self._build_binding_site_entity( + target_ligand, binding_site.ref_residues, + binding_site.substructure.residues) + 
ref_bs_ent_ligand = ref_bs_ent.FindResidue("_", 1) # by definition + + custom_compounds = { + ref_bs_ent_ligand.name: + mol.alg.lddt.CustomCompound.FromResidue( + ref_bs_ent_ligand)} + lddt_scorer = mol.alg.lddt.lDDTScorer( + ref_bs_ent, + custom_compounds=custom_compounds, + inclusion_radius=self.lddt_pli_radius) + + for model_i, model_ligand in enumerate(self.model_ligands): + try: + symmetries = _ComputeSymmetries( + model_ligand, target_ligand, + substructure_match=self.substructure_match, + by_atom_index=True) + except NoSymmetryError: + # Ligands are different - skip + LogDebug("No symmetry between %s and %s" % ( + str(model_ligand), str(target_ligand))) + continue + + LogDebug("Compute RMSD for model ligand %s" % model_ligand) + + rmsd = SCRMSD(model_ligand, target_ligand, + transformation=binding_site.transform, + substructure_match=self.substructure_match) + self._rmsd_matrix[target_i, model_i] = rmsd + + mdl_bs_ent = self._build_binding_site_entity( + model_ligand, binding_site.mdl_residues, []) + mdl_bs_ent_ligand = mdl_bs_ent.FindResidue("_", 1) # by definition + + # Prepare to save the data for this target/model mapping + # TODO: figure out if this try/except is still needed + # try: + # bb_rmsd = binding_site.bb_rmsd + # except Exception as err: + # # TODO: switch to whole backbone superposition - and drop try/except + # LogWarning("Can't calculate backbone RMSD: %s" + # " - setting to Infinity" % str(err)) + # bb_rmsd = float("inf") + self._lddt_pli_matrix[target_i, model_i] = { + "lddt_pli": 0, + "lddt_local": None, + "lddt_pli_n_contacts": None, + "rmsd": rmsd, + # "symmetry_number": i, + "chain_mapping": binding_site.GetFlatChainMapping(), + "lddt_bs": binding_site.lDDT, + "bb_rmsd": binding_site.bb_rmsd, + "bs_num_res": len(binding_site.substructure.residues), + "bs_num_overlap_res": len(binding_site.ref_residues), + } + + # Now for each symmetry, loop and rename atoms according + # to ref. 
+ mdl_editor = mdl_bs_ent.EditXCS() + for i, (trg_sym, mdl_sym) in enumerate(symmetries): + # Prepare Entities for RMSD + for mdl_anum, trg_anum in zip(mdl_sym, trg_sym): + # Rename model atoms according to symmetry + trg_atom = ref_bs_ent_ligand.atoms[trg_anum] + mdl_atom = mdl_bs_ent_ligand.atoms[mdl_anum] + mdl_editor.RenameAtom(mdl_atom, trg_atom.name) + mdl_editor.UpdateICS() + + global_lddt, local_lddt, lddt_tot, lddt_cons, n_res, \ + n_cont, n_cons = lddt_scorer.lDDT( + mdl_bs_ent, chain_mapping={"A": "A", "_": "_"}, + no_intrachain=True, + return_dist_test=True) + + # Save results? + best_lddt = self._lddt_pli_matrix[target_i, model_i]["lddt_pli"] + if global_lddt > best_lddt: + self._lddt_pli_matrix[target_i, model_i].update({ + "lddt_pli": global_lddt, + "lddt_local": local_lddt, + "lddt_pli_n_contacts": lddt_tot / 8, + }) def ResidueToGraph(residue, by_atom_index=False): diff --git a/modules/mol/alg/tests/test_ligand_scoring.py b/modules/mol/alg/tests/test_ligand_scoring.py index 8ee493364fb41ea512fe0ecc18d647c79c1b3819..8e3eb94a7f3258c63cf066f9e9c4c4a539327fa6 100644 --- a/modules/mol/alg/tests/test_ligand_scoring.py +++ b/modules/mol/alg/tests/test_ligand_scoring.py @@ -1,5 +1,8 @@ import unittest, os, sys +import numpy as np + +import ost from ost import io, mol, geom # check if we can import: fails if numpy or scipy not available try: @@ -10,6 +13,7 @@ except ImportError: "networkx is missing. 
Ignoring test_ligand_scoring.py tests.") sys.exit(0) +#ost.PushVerbosityLevel(ost.LogLevel.Debug) class TestLigandScoring(unittest.TestCase): @@ -43,7 +47,7 @@ class TestLigandScoring(unittest.TestCase): # IsLigand flag should still be set even on not selected ligands assert len([r for r in sc.target.residues if r.is_ligand]) == 7 assert len([r for r in sc.model.residues if r.is_ligand]) == 1 - + # Ensure the residues are not copied assert len(sc.target.Select("rname=MG").residues) == 2 assert len(sc.target.Select("rname=G3D").residues) == 2 @@ -218,6 +222,47 @@ class TestLigandScoring(unittest.TestCase): SCRMSD(mdl_g3d, trg_g3d1_sub) # no full match rmsd = SCRMSD(mdl_g3d, trg_g3d1_sub, substructure_match=True) + def test__compute_scores(self): + """Test that _compute_scores works. + """ + trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True) + mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True) + sc = LigandScorer(mdl, trg, None, None) + + # Note: expect warning about Binding site of H.ZN1 not mapped to the model + sc._compute_scores() + + # Check RMSD + assert sc._rmsd_matrix.shape == (7, 1) + np.testing.assert_almost_equal(sc._rmsd_matrix, np.array( + [[np.inf], + [0.04244993], + [np.inf], + [np.inf], + [np.inf], + [0.29399303], + [np.inf]])) + + # Check lDDT-PLI + assert sc._lddt_pli_matrix.shape == (7, 1) + assert sc._lddt_pli_matrix[0, 0] is None + assert sc._lddt_pli_matrix[2, 0] is None + assert sc._lddt_pli_matrix[3, 0] is None + assert sc._lddt_pli_matrix[4, 0] is None + assert sc._lddt_pli_matrix[6, 0] is None + assert sc._lddt_pli_matrix[1, 0]["lddt_pli_n_contacts"] == 638 + assert sc._lddt_pli_matrix[5, 0]["lddt_pli_n_contacts"] == 636 + assert sc._lddt_pli_matrix[1, 0]["chain_mapping"] == {'A': 'A'} + assert sc._lddt_pli_matrix[5, 0]["chain_mapping"] == {'C': 'A'} + self.assertAlmostEqual(sc._lddt_pli_matrix[1, 0]["lddt_pli"], 0.99843, 5) + 
self.assertAlmostEqual(sc._lddt_pli_matrix[5, 0]["lddt_pli"], 1.0) + self.assertAlmostEqual(sc._lddt_pli_matrix[1, 0]["rmsd"], 0.04244993) + self.assertAlmostEqual(sc._lddt_pli_matrix[5, 0]["rmsd"], 0.29399303) + assert sc._lddt_pli_matrix[1, 0]["bs_num_res"] == 15 + assert sc._lddt_pli_matrix[5, 0]["bs_num_res"] == 15 + assert sc._lddt_pli_matrix[1, 0]["bs_num_overlap_res"] == 15 + assert sc._lddt_pli_matrix[5, 0]["bs_num_overlap_res"] == 15 + if __name__ == "__main__": from ost import testutils