Skip to content
Snippets Groups Projects
Unverified Commit 9ddd09a8 authored by Xavier Robin's avatar Xavier Robin
Browse files

feat: SCHWED-5783 compute lDDT-PLI and RMSD

parent 26a7e701
No related branches found
No related tags found
No related merge requests found
......@@ -91,10 +91,21 @@ class LigandScorer:
If None (default), a chain mapper will be initialized
lazily as required.
:type chain_mapper: :class:`ost.mol.alg.chain_mapping.ChainMapper`
:param substructure_match: Set this to True to allow partial target ligand.
:type substructure_match: :class:`bool`
:param radius: Inclusion radius for the binding site. Any residue with
atoms within this distance of the ligand will be included
in the binding site.
:type radius: :class:`float`
:param lddt_pli_radius: lDDT inclusion radius for lDDT-PLI.
:type lddt_pli_radius: :class:`float`
:param lddt_bs_radius: lDDT inclusion radius for lDDT-BS.
:type lddt_bs_radius: :class:`float`
"""
def __init__(self, model, target, model_ligands=None, target_ligands=None,
resnum_alignments=False, chain_mapper=None):
resnum_alignments=False, chain_mapper=None,
substructure_match=False, radius=4.0, lddt_pli_radius=6.0,
lddt_bs_radius=10.0):
if isinstance(model, mol.EntityView):
self.model = mol.CreateEntityFromView(model, False)
......@@ -124,12 +135,19 @@ class LigandScorer:
self._chain_mapper = chain_mapper
self.resnum_alignments = resnum_alignments
self.substructure_match = substructure_match
self.radius = radius
self.lddt_pli_radius = lddt_pli_radius
self.lddt_bs_radius = lddt_bs_radius
# lazily computed scores
self._lddt_pli = None
self._rmsd = None
self._lddt_bs = None
# lazily precomputed variables
self._binding_sites = {}
@property
def chain_mapper(self):
""" Chain mapper object for given :attr:`target`
......@@ -288,24 +306,163 @@ class LigandScorer:
topn=topn)
return self._binding_sites[ligand.hash_code]
def _compute_rmsd(self):
@staticmethod
def _build_binding_site_entity(ligand, residues, extra_residues=None):
    """ Build an entity with all the binding site residues in chain A
    and the ligand in chain _. Residues are renumbered consecutively from
    1. The ligand is assigned residue number 1 and residue name LIG.
    Residues in `extra_residues` that are not in `residues` are added at
    the end of chain A, continuing the numbering.

    :param ligand: the Residue Handle of the ligand
    :type ligand: :class:`~ost.mol.ResidueHandle`
    :param residues: a list of binding site residues
    :type residues: :class:`list` of :class:`~ost.mol.ResidueHandle`
    :param extra_residues: an optional list with additional binding site
                           residues. Residues in this list which are not
                           in `residues` will be added at the end of chain
                           A. This allows for instance adding unmapped
                           residues missing from the model into the
                           reference binding site. None (default) and an
                           empty list both mean "no extra residues".
    :type extra_residues: :class:`list` of :class:`~ost.mol.ResidueHandle`
    :rtype: :class:`~ost.mol.EntityHandle`
    """
    # Avoid a shared mutable default argument.
    if extra_residues is None:
        extra_residues = ()

    bs_ent = mol.CreateEntity()
    ed = bs_ent.EditXCS()
    bs_chain = ed.InsertChain("A")

    # Qualified names already copied into chain A; a set keeps the
    # duplicate check below O(1) per residue.
    seen_res_qn = set()

    # Pre-bind the counter so that numbering of extra residues still starts
    # at 1 when `residues` is empty (the loop below would leave it unbound).
    resnum = 0
    for resnum, old_res in enumerate(residues, 1):
        seen_res_qn.add(old_res.qualified_name)
        new_res = ed.AppendResidue(bs_chain, old_res.handle, deep=True)
        ed.SetResidueNumber(new_res, mol.ResNum(resnum))

    # Add extra residues at the end, skipping the ones already present.
    for extra_res in extra_residues:
        if extra_res.qualified_name in seen_res_qn:
            continue
        resnum += 1
        seen_res_qn.add(extra_res.qualified_name)
        new_res = ed.AppendResidue(bs_chain, extra_res.handle, deep=True)
        ed.SetResidueNumber(new_res, mol.ResNum(resnum))

    # Add the ligand in chain _
    ligand_chain = ed.InsertChain("_")
    ligand_res = ed.AppendResidue(ligand_chain, ligand, deep=True)
    ed.RenameResidue(ligand_res, "LIG")
    ed.SetResidueNumber(ligand_res, mol.ResNum(1))
    ed.UpdateICS()
    return bs_ent
def _compute_scores(self):
    """ Compute the RMSD and lDDT-PLI for every target/model ligand pair.

    Results are stored on the instance:

    * ``self._rmsd_matrix``: float array of shape
      (number of target ligands, number of model ligands). Entries default
      to infinity and are replaced by the value returned by ``SCRMSD`` for
      every pair that could be scored.
    * ``self._lddt_pli_matrix``: object array of the same shape. Entries
      stay None for pairs that were not scored; scored pairs hold a dict
      with the keys ``lddt_pli``, ``lddt_local``, ``lddt_pli_n_contacts``,
      ``rmsd``, ``chain_mapping``, ``lddt_bs``, ``bb_rmsd``, ``bs_num_res``
      and ``bs_num_overlap_res``.

    A pair is skipped when the binding site has no residues, when it is not
    mapped to the model, or when ``_ComputeSymmetries`` raises
    ``NoSymmetryError`` for the two ligands.
    """
    # Create the matrices
    n_target = len(self.target_ligands)
    n_model = len(self.model_ligands)
    self._rmsd_matrix = np.full((n_target, n_model), float("inf"),
                                dtype=float)
    self._lddt_pli_matrix = np.empty((n_target, n_model), dtype=dict)

    for target_i, target_ligand in enumerate(self.target_ligands):
        LogDebug("Compute RMSD for target ligand %s" % target_ligand)

        # NOTE(review): entries are written per (target, model) pair inside
        # this loop, so if more than one binding site is returned, later
        # sites overwrite earlier ones - confirm this is intended.
        for binding_site in self._get_binding_sites(target_ligand):
            if len(binding_site.substructure.residues) == 0:
                LogWarning("No residue in proximity of target ligand "
                           "%s" % str(target_ligand))
                continue  # next binding site
            elif len(binding_site.ref_residues) == 0:
                LogWarning("Binding site of %s not mapped to the model " %
                           str(target_ligand))
                continue  # next binding site

            # Reference side: mapped residues first, then the remaining
            # binding site residues, plus the ligand in chain "_".
            ref_bs_ent = self._build_binding_site_entity(
                target_ligand, binding_site.ref_residues,
                binding_site.substructure.residues)
            ref_bs_ent_ligand = ref_bs_ent.FindResidue("_", 1)  # by definition

            custom_compounds = {
                ref_bs_ent_ligand.name:
                    mol.alg.lddt.CustomCompound.FromResidue(
                        ref_bs_ent_ligand)}
            lddt_scorer = mol.alg.lddt.lDDTScorer(
                ref_bs_ent,
                custom_compounds=custom_compounds,
                inclusion_radius=self.lddt_pli_radius)

            for model_i, model_ligand in enumerate(self.model_ligands):
                try:
                    symmetries = _ComputeSymmetries(
                        model_ligand, target_ligand,
                        substructure_match=self.substructure_match,
                        by_atom_index=True)
                except NoSymmetryError:
                    # Ligands are different - skip
                    LogDebug("No symmetry between %s and %s" % (
                        str(model_ligand), str(target_ligand)))
                    continue
                LogDebug("Compute RMSD for model ligand %s" % model_ligand)

                rmsd = SCRMSD(model_ligand, target_ligand,
                              transformation=binding_site.transform,
                              substructure_match=self.substructure_match)
                self._rmsd_matrix[target_i, model_i] = rmsd

                mdl_bs_ent = self._build_binding_site_entity(
                    model_ligand, binding_site.mdl_residues, [])
                mdl_bs_ent_ligand = mdl_bs_ent.FindResidue("_", 1)  # by definition

                # Prepare to save the data for this target/model mapping.
                # TODO: confirm that binding_site.bb_rmsd cannot raise here;
                # an earlier revision guarded it with try/except and fell
                # back to infinity.
                scores = {
                    "lddt_pli": 0,
                    "lddt_local": None,
                    "lddt_pli_n_contacts": None,
                    "rmsd": rmsd,
                    "chain_mapping": binding_site.GetFlatChainMapping(),
                    "lddt_bs": binding_site.lDDT,
                    "bb_rmsd": binding_site.bb_rmsd,
                    "bs_num_res": len(binding_site.substructure.residues),
                    "bs_num_overlap_res": len(binding_site.ref_residues),
                }
                self._lddt_pli_matrix[target_i, model_i] = scores

                # Now for each symmetry, loop and rename atoms according
                # to ref, then keep the best scoring symmetry.
                mdl_editor = mdl_bs_ent.EditXCS()
                for trg_sym, mdl_sym in symmetries:
                    # Rename model atoms according to symmetry
                    for mdl_anum, trg_anum in zip(mdl_sym, trg_sym):
                        trg_atom = ref_bs_ent_ligand.atoms[trg_anum]
                        mdl_atom = mdl_bs_ent_ligand.atoms[mdl_anum]
                        mdl_editor.RenameAtom(mdl_atom, trg_atom.name)
                    mdl_editor.UpdateICS()

                    global_lddt, local_lddt, lddt_tot, lddt_cons, n_res, \
                        n_cont, n_cons = lddt_scorer.lDDT(
                            mdl_bs_ent, chain_mapping={"A": "A", "_": "_"},
                            no_intrachain=True,
                            return_dist_test=True)

                    # Only a strictly better lDDT replaces the stored one,
                    # so a pair whose best lDDT is 0 keeps "lddt_local" and
                    # "lddt_pli_n_contacts" at None.
                    if global_lddt > scores["lddt_pli"]:
                        scores.update({
                            "lddt_pli": global_lddt,
                            "lddt_local": local_lddt,
                            # NOTE(review): the divisor 8 is not explained
                            # in this file - presumably it undoes counting
                            # per distance threshold and per direction;
                            # confirm against lDDTScorer.
                            "lddt_pli_n_contacts": lddt_tot / 8,
                        })
def ResidueToGraph(residue, by_atom_index=False):
......
import unittest, os, sys
import numpy as np
import ost
from ost import io, mol, geom
# check if we can import: fails if numpy or scipy not available
try:
......@@ -10,6 +13,7 @@ except ImportError:
"networkx is missing. Ignoring test_ligand_scoring.py tests.")
sys.exit(0)
#ost.PushVerbosityLevel(ost.LogLevel.Debug)
class TestLigandScoring(unittest.TestCase):
......@@ -43,7 +47,7 @@ class TestLigandScoring(unittest.TestCase):
# IsLigand flag should still be set even on not selected ligands
assert len([r for r in sc.target.residues if r.is_ligand]) == 7
assert len([r for r in sc.model.residues if r.is_ligand]) == 1
# Ensure the residues are not copied
assert len(sc.target.Select("rname=MG").residues) == 2
assert len(sc.target.Select("rname=G3D").residues) == 2
......@@ -218,6 +222,47 @@ class TestLigandScoring(unittest.TestCase):
SCRMSD(mdl_g3d, trg_g3d1_sub) # no full match
rmsd = SCRMSD(mdl_g3d, trg_g3d1_sub, substructure_match=True)
def test__compute_scores(self):
    """Check the RMSD and lDDT-PLI matrices filled in by _compute_scores.

    Scores the ligand of model P84080_model_02 against the ligands of
    target 1r8q and compares the matrices with reference values.
    """
    trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True)
    mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True)

    sc = LigandScorer(mdl, trg, None, None)
    # Note: expect warning about Binding site of H.ZN1 not mapped to the model
    sc._compute_scores()

    # Check RMSD: only target ligands 1 and 5 are scored, the others
    # keep the infinity default.
    expected_rmsd = np.array(
        [[np.inf],
         [0.04244993],
         [np.inf],
         [np.inf],
         [np.inf],
         [0.29399303],
         [np.inf]])
    assert sc._rmsd_matrix.shape == (7, 1)
    np.testing.assert_almost_equal(sc._rmsd_matrix, expected_rmsd)

    # Check lDDT-PLI: unscored pairs are left as None.
    pli = sc._lddt_pli_matrix
    assert pli.shape == (7, 1)
    for unscored_row in (0, 2, 3, 4, 6):
        assert pli[unscored_row, 0] is None

    first_hit = pli[1, 0]
    second_hit = pli[5, 0]

    # Values that must match exactly: (key, expected row 1, expected row 5)
    exact_values = (
        ("lddt_pli_n_contacts", 638, 636),
        ("chain_mapping", {'A': 'A'}, {'C': 'A'}),
    )
    for key, expected_first, expected_second in exact_values:
        assert first_hit[key] == expected_first
        assert second_hit[key] == expected_second

    # Floating point scores
    self.assertAlmostEqual(first_hit["lddt_pli"], 0.99843, 5)
    self.assertAlmostEqual(second_hit["lddt_pli"], 1.0)
    self.assertAlmostEqual(first_hit["rmsd"], 0.04244993)
    self.assertAlmostEqual(second_hit["rmsd"], 0.29399303)

    # Binding site sizes
    for key in ("bs_num_res", "bs_num_overlap_res"):
        assert first_hit[key] == 15
        assert second_hit[key] == 15
if __name__ == "__main__":
from ost import testutils
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment