Skip to content
Snippets Groups Projects
Commit 18da8627 authored by Studer Gabriel's avatar Studer Gabriel
Browse files

ligand scoring: dissect scoring into separate classes/files

Only initial work done for a base scoring class that takes care
of ligand prep, computing ligand symmetries and calling score computation
functionality from child classes. Child class for SCRMSD is already there
with basic testing.
parent 68cf375a
Branches
Tags
No related merge requests found
...@@ -32,6 +32,8 @@ set(OST_MOL_ALG_PYMOD_MODULES ...@@ -32,6 +32,8 @@ set(OST_MOL_ALG_PYMOD_MODULES
ligand_scoring.py ligand_scoring.py
dockq.py dockq.py
contact_score.py contact_score.py
ligand_scoring_base.py
ligand_scoring_scrmsd.py
) )
if (NOT ENABLE_STATIC) if (NOT ENABLE_STATIC)
......
This diff is collapsed.
import numpy as np
from ost import LogWarning
from ost import geom
from ost import mol
from ost.mol.alg import ligand_scoring_base
class SCRMSDScorer(ligand_scoring_base.LigandScorer):
    """Ligand scorer based on symmetry-corrected RMSD (SCRMSD).

    The parent class is expected to take care of ligand preparation, ligand
    symmetry computation and of calling :meth:`_compute` for target/model
    ligand pairs. This class maps the binding site of a target ligand onto
    the model (``chain_mapper.GetRepr``) and reports the lowest
    symmetry-corrected RMSD over all returned binding site representations.

    The attributes ``chain_mapper``, ``model``, ``target`` and
    ``_unassigned_target_ligands_reason`` are used but not set here;
    presumably they are provided by the parent class.

    :param model: Forwarded to the parent class.
    :param target: Forwarded to the parent class.
    :param model_ligands: Forwarded to the parent class.
    :param target_ligands: Forwarded to the parent class.
    :param resnum_alignments: Forwarded to the parent class.
    :param rename_ligand_chain: Forwarded to the parent class.
    :param substructure_match: Forwarded to the parent class.
    :param coverage_delta: Forwarded to the parent class.
    :param max_symmetries: Forwarded to the parent class.
    :param bs_radius: Radius around target ligand atoms used to collect the
                      residues of the reference binding site.
    :param lddt_lp_radius: Passed as ``inclusion_radius`` to
                           ``chain_mapper.GetRepr``.
    :param model_bs_radius: Radius around model ligand atoms used to select
                            the model chains considered in the localized
                            ``GetRepr`` search.
    :param binding_sites_topn: Passed as ``topn`` to
                               ``chain_mapper.GetRepr``.
    :param full_bs_search: If True, binding site representations are
                           searched in the full model, independent of the
                           model ligand. Otherwise the search is restricted
                           to model chains close to the model ligand.
    """

    def __init__(self, model, target, model_ligands=None, target_ligands=None,
                 resnum_alignments=False, rename_ligand_chain=False,
                 substructure_match=False, coverage_delta=0.2,
                 max_symmetries=1e5, bs_radius=4.0, lddt_lp_radius=15.0,
                 model_bs_radius=25, binding_sites_topn=100000,
                 full_bs_search=False):
        # forward the user supplied max_symmetries instead of a fixed value
        super().__init__(model, target, model_ligands=model_ligands,
                         target_ligands=target_ligands,
                         resnum_alignments=resnum_alignments,
                         rename_ligand_chain=rename_ligand_chain,
                         substructure_match=substructure_match,
                         coverage_delta=coverage_delta,
                         max_symmetries=max_symmetries)

        self.bs_radius = bs_radius
        self.lddt_lp_radius = lddt_lp_radius
        self.model_bs_radius = model_bs_radius
        self.binding_sites_topn = binding_sites_topn
        self.full_bs_search = full_bs_search

        # Residues that are in contact with a ligand => binding site
        # defined as all residues with at least one atom within
        # self.bs_radius
        # key: ligand.handle.hash_code, value: EntityView of whatever
        # entity ligand belongs to
        self._binding_sites = dict()

        # cache for GetRepr chain mapping calls
        self._repr = dict()

        # lazily precomputed variables to speedup GetRepr chain mapping calls
        # for localized GetRepr searches
        self.__chem_mapping = None
        self.__chem_group_alns = None
        self.__ref_mdl_alns = None
        self.__chain_mapping_mdl = None

        # key: model_ligand.handle.hash_code
        # value: tuple (model view restricted to close chains, chem mapping
        #               restricted to close chains)
        self._get_repr_input = dict()

    def _compute(self, symmetries, target_ligand, model_ligand):
        """Compute the best SCRMSD for one target/model ligand pair.

        :param symmetries: Pre-computed ligand symmetries, passed on to
                           ``_SCRMSD_symmetries``.
        :param target_ligand: The target ligand
        :param model_ligand: The model ligand
        :returns: Tuple ``(rmsd, error_state, details)``. If no binding site
                  representation is available, *rmsd* is NaN and
                  *error_state* is 10. Otherwise *rmsd* is the lowest RMSD
                  over all representations and *error_state* is 0.
                  *details* is a dict describing the best representation
                  (default/empty values if there is none).
        """
        # set default to invalid scores
        best_rmsd_result = {"rmsd": None,
                            "lddt_lp": None,
                            "bs_ref_res": list(),
                            "bs_ref_res_mapped": list(),
                            "bs_mdl_res_mapped": list(),
                            "bb_rmsd": None,
                            "target_ligand": target_ligand,
                            "model_ligand": model_ligand,
                            "chain_mapping": dict(),
                            "transform": geom.Mat4(),
                            "inconsistent_residues": list()}

        for r in self._get_repr(target_ligand, model_ligand):
            # superpose model ligand with the binding site transformation
            # of this representation before measuring the RMSD
            rmsd = _SCRMSD_symmetries(symmetries, model_ligand,
                                      target_ligand,
                                      transformation=r.transform)

            if best_rmsd_result["rmsd"] is None or \
               rmsd < best_rmsd_result["rmsd"]:
                best_rmsd_result = {"rmsd": rmsd,
                                    "lddt_lp": r.lDDT,
                                    "bs_ref_res": r.substructure.residues,
                                    "bs_ref_res_mapped": r.ref_residues,
                                    "bs_mdl_res_mapped": r.mdl_residues,
                                    "bb_rmsd": r.bb_rmsd,
                                    "target_ligand": target_ligand,
                                    "model_ligand": model_ligand,
                                    "chain_mapping": r.GetFlatChainMapping(),
                                    "transform": r.transform,
                                    "inconsistent_residues":
                                        r.inconsistent_residues}

        # set default to error
        best_rmsd = np.nan
        error_state = 10
        if best_rmsd_result["rmsd"] is not None:
            # but here we save the day
            best_rmsd = best_rmsd_result["rmsd"]
            error_state = 0

        return (best_rmsd, error_state, best_rmsd_result)

    def _get_repr(self, target_ligand, model_ligand):
        """Return (cached) binding site representations in the model.

        With ``full_bs_search`` the result only depends on the target
        ligand, otherwise on the target/model ligand pair. If no
        representation is found, a reason is registered in
        ``_unassigned_target_ligands_reason`` unless one is already set.
        """
        if self.full_bs_search:
            # all possible binding sites, independent from actual model ligand
            key = (target_ligand.handle.hash_code, 0)
        else:
            key = (target_ligand.handle.hash_code,
                   model_ligand.handle.hash_code)

        if key not in self._repr:
            ref_bs = self._get_target_binding_site(target_ligand)
            if self.full_bs_search:
                reprs = self.chain_mapper.GetRepr(
                    ref_bs, self.model, inclusion_radius=self.lddt_lp_radius,
                    topn=self.binding_sites_topn)
            else:
                reprs = self.chain_mapper.GetRepr(
                    ref_bs, self.model, inclusion_radius=self.lddt_lp_radius,
                    topn=self.binding_sites_topn,
                    chem_mapping_result=self._get_get_repr_input(model_ligand))
            self._repr[key] = reprs

            if len(reprs) == 0:
                # whatever is in there already has precedence
                if target_ligand not in self._unassigned_target_ligands_reason:
                    self._unassigned_target_ligands_reason[target_ligand] = (
                        "model_representation",
                        "No representation of the reference binding site was "
                        "found in the model")

        return self._repr[key]

    def _get_target_binding_site(self, target_ligand):
        """Return (cached) view of the target residues close to the ligand.

        Residues with at least one atom within ``bs_radius`` of any ligand
        atom are collected. Only residues that are part of
        ``chain_mapper.target`` are added. Other ligands and other unmapped
        residues are skipped with a warning, waters are skipped silently.
        If nothing is found, an empty view is returned and a reason is
        registered in ``_unassigned_target_ligands_reason``.

        :raises: :class:`RuntimeError` if residues were found but none could
                 be added to the binding site view.
        """
        bs_key = target_ligand.handle.hash_code
        if bs_key not in self._binding_sites:

            # create view of reference binding site
            ref_residues_hashes = set()  # helper to keep track of added residues
            # handle hash code, consistent with the hashes compared below
            ignored_residue_hashes = {target_ligand.handle.hash_code}
            for ligand_at in target_ligand.atoms:
                close_atoms = self.target.FindWithin(ligand_at.GetPos(),
                                                     self.bs_radius)
                for close_at in close_atoms:
                    # Skip any residue not in the chain mapping target
                    ref_res = close_at.GetResidue()
                    h = ref_res.handle.GetHashCode()
                    if h in ref_residues_hashes or \
                       h in ignored_residue_hashes:
                        continue
                    if self.chain_mapper.target.ViewForHandle(ref_res).IsValid():
                        ref_residues_hashes.add(h)
                    elif ref_res.is_ligand:
                        LogWarning("Ignoring ligand %s in binding site of %s" % (
                            ref_res.qualified_name,
                            target_ligand.qualified_name))
                        ignored_residue_hashes.add(h)
                    elif ref_res.chem_type == mol.ChemType.WATERS:
                        pass  # That's ok, no need to warn
                    else:
                        LogWarning("Ignoring residue %s in binding site of %s" % (
                            ref_res.qualified_name,
                            target_ligand.qualified_name))
                        ignored_residue_hashes.add(h)

            ref_bs = self.target.CreateEmptyView()
            if ref_residues_hashes:
                # reason for doing that separately is to guarantee same
                # ordering of residues as in underlying entity. (Reorder by
                # ResNum seems only available on ChainHandles)
                for ch in self.target.chains:
                    for r in ch.residues:
                        if r.handle.GetHashCode() in ref_residues_hashes:
                            ref_bs.AddResidue(r, mol.ViewAddFlag.INCLUDE_ALL)
                if len(ref_bs.residues) == 0:
                    raise RuntimeError("Failed to add proximity residues to "
                                       "the reference binding site entity")
            else:
                # Flag missing binding site
                self._unassigned_target_ligands_reason[target_ligand] = (
                    "binding_site",
                    "No residue in proximity of the target ligand")

            self._binding_sites[bs_key] = ref_bs

        return self._binding_sites[bs_key]

    def _init_chem_mapping(self):
        """Call ``chain_mapper.GetChemMapping`` on the model and cache its
        three return values."""
        (self.__chem_mapping,
         self.__chem_group_alns,
         self.__chain_mapping_mdl) = \
            self.chain_mapper.GetChemMapping(self.model)

    @property
    def _chem_mapping(self):
        """First return value of ``chain_mapper.GetChemMapping(model)``,
        lazily computed."""
        if self.__chem_mapping is None:
            self._init_chem_mapping()
        return self.__chem_mapping

    @property
    def _chem_group_alns(self):
        """Second return value of ``chain_mapper.GetChemMapping(model)``,
        lazily computed."""
        if self.__chem_group_alns is None:
            self._init_chem_mapping()
        return self.__chem_group_alns

    @property
    def _ref_mdl_alns(self):
        """Result of ``chain_mapping._GetRefMdlAlns``, lazily computed."""
        if self.__ref_mdl_alns is None:
            # local import: chain_mapping is not imported at module level
            from ost.mol.alg import chain_mapping
            self.__ref_mdl_alns = \
                chain_mapping._GetRefMdlAlns(
                    self.chain_mapper.chem_groups,
                    self.chain_mapper.chem_group_alignments,
                    self._chem_mapping,
                    self._chem_group_alns)
        return self.__ref_mdl_alns

    @property
    def _chain_mapping_mdl(self):
        """Third return value of ``chain_mapper.GetChemMapping(model)``,
        lazily computed. Used here as the model structure to search and
        select chains from."""
        if self.__chain_mapping_mdl is None:
            self._init_chem_mapping()
        return self.__chain_mapping_mdl

    def _get_get_repr_input(self, mdl_ligand):
        """Return the ``chem_mapping_result`` input for a localized
        ``chain_mapper.GetRepr`` call.

        Only model chains with at least one atom within ``model_bs_radius``
        of any atom of *mdl_ligand* are retained. Results are cached per
        model ligand.

        :returns: Tuple ``(chem_mapping, chem_group_alns, mdl)`` where
                  *chem_mapping* and *mdl* are reduced to the close chains.
        """
        # one key for storing and reading the cache entry
        key = mdl_ligand.handle.hash_code
        if key not in self._get_repr_input:

            # figure out what chains in the model are in contact with the
            # ligand that may give a non-zero contribution to lDDT in
            # chain_mapper.GetRepr
            radius = self.model_bs_radius
            chains = set()
            for at in mdl_ligand.atoms:
                close_atoms = self._chain_mapping_mdl.FindWithin(at.GetPos(),
                                                                 radius)
                for close_at in close_atoms:
                    chains.add(close_at.GetChain().GetName())

            if len(chains) > 0:
                # the chain mapping model which only contains close chains
                query = "cname="
                query += ','.join([mol.QueryQuoteName(x) for x in chains])
                mdl = self._chain_mapping_mdl.Select(query)

                # chem mapping which is reduced to the respective chains
                chem_mapping = [[x for x in m if x in chains]
                                for m in self._chem_mapping]

                self._get_repr_input[key] = (mdl, chem_mapping)
            else:
                self._get_repr_input[key] = \
                    (self._chain_mapping_mdl.CreateEmptyView(),
                     [list() for _ in self._chem_mapping])

        mdl, chem_mapping = self._get_repr_input[key]
        return (chem_mapping, self._chem_group_alns, mdl)
def SCRMSD(model_ligand, target_ligand, transformation=geom.Mat4(),
           substructure_match=False, max_symmetries=1e6):
    """Symmetry-corrected RMSD between a model and a target ligand.

    The ligand symmetries are computed by atom index and the lowest RMSD
    over all of them is returned. No superposition is performed here: a
    binding site superposition must be computed separately and handed in as
    `transformation`.

    :param model_ligand: The model ligand
    :type model_ligand: :class:`ost.mol.ResidueHandle` or
                        :class:`ost.mol.ResidueView`
    :param target_ligand: The target ligand
    :type target_ligand: :class:`ost.mol.ResidueHandle` or
                         :class:`ost.mol.ResidueView`
    :param transformation: Optional transformation to apply on each atom
                           position of model_ligand.
    :type transformation: :class:`ost.geom.Mat4`
    :param substructure_match: Set this to True to allow partial target
                               ligand.
    :type substructure_match: :class:`bool`
    :param max_symmetries: If more than that many isomorphisms exist, raise
      a :class:`TooManySymmetriesError`. This can only be assessed by
      generating at least that many isomorphisms and can take some time.
    :type max_symmetries: :class:`int`
    :rtype: :class:`float`
    :raises: :class:`NoSymmetryError` when no symmetry can be found,
             :class:`DisconnectedGraphError` when ligand graph is
             disconnected, :class:`TooManySymmetriesError` when more than
             `max_symmetries` isomorphisms are found.
    """
    return _SCRMSD_symmetries(
        ligand_scoring_base.ComputeSymmetries(
            model_ligand, target_ligand,
            substructure_match=substructure_match,
            by_atom_index=True,
            max_symmetries=max_symmetries),
        model_ligand, target_ligand, transformation)
def _SCRMSD_symmetries(symmetries, model_ligand, target_ligand,
                       transformation):
    """Lowest RMSD over pre-computed symmetries. Internal.

    :param symmetries: Iterable of ``(target_indices, model_indices)``
                       pairs; indices refer to positions in the ``atoms``
                       lists of the respective ligand.
    :param model_ligand: Model ligand; its atom positions are transformed
                         with *transformation* before comparison.
    :param target_ligand: Target ligand
    :param transformation: 4x4 matrix, read element-wise via ``[i, j]``.
    :returns: The lowest RMSD found, ``np.inf`` if *symmetries* is empty.
    """
    n_trg_atoms = target_ligand.GetAtomCount()

    # model positions in homogeneous coordinates (last column stays 1)
    # so that the 4x4 transformation can be applied with one product
    mdl_xyzw = np.ones((model_ligand.GetAtomCount(), 4))
    for row, atom in enumerate(model_ligand.atoms):
        pos = atom.GetPos()
        mdl_xyzw[row, :3] = (pos[0], pos[1], pos[2])

    # element-wise copy of the transformation into a numpy matrix
    mat = np.zeros((4, 4))
    for row in range(4):
        mat[row, :] = [transformation[row, col] for col in range(4)]

    mdl_xyz = (mdl_xyzw @ mat.T)[:, :3]

    # target positions
    trg_xyz = np.zeros((n_trg_atoms, 3))
    for row, atom in enumerate(target_ligand.atoms):
        pos = atom.GetPos()
        trg_xyz[row, :] = (pos[0], pos[1], pos[2])

    # Buffers that are refilled for every symmetry. There is a guarantee
    # that target_ligand.GetAtomCount() <= model_ligand.GetAtomCount()
    # and that each target ligand atom is part of every symmetry
    # => target_ligand.GetAtomCount() is size of both buffers
    mdl_buf = np.zeros((n_trg_atoms, 3))
    trg_buf = np.zeros((n_trg_atoms, 3))

    # keep the lowest RMSD over all symmetries
    lowest = np.inf
    for trg_sym, mdl_sym in symmetries:
        for slot, (mdl_idx, trg_idx) in enumerate(zip(mdl_sym, trg_sym)):
            mdl_buf[slot, :] = mdl_xyz[mdl_idx, :]
            trg_buf[slot, :] = trg_xyz[trg_idx, :]
        sq_dist = np.sum(np.square(mdl_buf - trg_buf), axis=-1)
        lowest = min(lowest, np.sqrt(np.mean(sq_dist)))
    return lowest
...@@ -20,7 +20,8 @@ if (COMPOUND_LIB) ...@@ -20,7 +20,8 @@ if (COMPOUND_LIB)
list(APPEND OST_MOL_ALG_UNIT_TESTS test_qsscoring.py list(APPEND OST_MOL_ALG_UNIT_TESTS test_qsscoring.py
test_nonstandard.py test_nonstandard.py
test_chain_mapping.py test_chain_mapping.py
test_ligand_scoring.py) test_ligand_scoring.py
test_ligand_scoring_fancy.py)
endif() endif()
ost_unittest(MODULE mol_alg SOURCES "${OST_MOL_ALG_UNIT_TESTS}" LINK ost_io) ost_unittest(MODULE mol_alg SOURCES "${OST_MOL_ALG_UNIT_TESTS}" LINK ost_io)
import unittest, os, sys
from functools import lru_cache
import numpy as np
import ost
from ost import io, mol, geom
# Check if we can import: fails if numpy, scipy or networkx is not available.
# In that case the tests are skipped by exiting with success status.
try:
    # star import provides LigandScorer and the exception classes
    # (e.g. NoSymmetryError) that the tests reference unqualified
    from ost.mol.alg.ligand_scoring_base import *
    from ost.mol.alg import ligand_scoring_base
    from ost.mol.alg import ligand_scoring_scrmsd
except ImportError:
    print("Failed to import ligand_scoring_base.py or "
          "ligand_scoring_scrmsd.py. Happens when numpy, scipy or "
          "networkx is missing. Ignoring test_ligand_scoring_fancy.py tests.")
    sys.exit(0)
def _GetTestfilePath(filename):
    """Return *filename* prefixed with the relative 'testfiles' directory."""
    testfile_dir = 'testfiles'
    return os.path.join(testfile_dir, filename)
@lru_cache(maxsize=None)
def _LoadMMCIF(filename):
    """Load an mmCIF test file with io.LoadMMCIF; cached per filename, so
    repeated calls return the same object."""
    return io.LoadMMCIF(_GetTestfilePath(filename))
@lru_cache(maxsize=None)
def _LoadPDB(filename):
    """Load a PDB test file with io.LoadPDB; cached per filename, so
    repeated calls return the same object."""
    return io.LoadPDB(_GetTestfilePath(filename))
@lru_cache(maxsize=None)
def _LoadEntity(filename):
    """Load a test file with io.LoadEntity; cached per filename, so
    repeated calls return the same object."""
    return io.LoadEntity(_GetTestfilePath(filename))
class TestLigandScoringFancy(unittest.TestCase):
    """Tests for the base ligand scorer (ligand_scoring_base) and the SCRMSD
    scoring functionality (ligand_scoring_scrmsd).

    All structures are loaded from the relative 'testfiles' directory.
    """

    def setUp(self):
        # Silence expected warnings about ignoring of ligands in binding site
        ost.PushVerbosityLevel(ost.LogLevel.Error)

    def tearDown(self):
        # Undo the verbosity change from setUp
        ost.PopVerbosityLevel()

    def test_extract_ligands_mmCIF(self):
        """Test that we can extract ligands from mmCIF files.
        """
        trg = _LoadMMCIF("1r8q.cif.gz")
        mdl = _LoadMMCIF("P84080_model_02.cif.gz")

        # No ligands given explicitly (None, None)
        sc = LigandScorer(mdl, trg, None, None)

        self.assertEqual(len(sc.target_ligands), 7)
        self.assertEqual(len(sc.model_ligands), 1)
        self.assertEqual(len([r for r in sc.target.residues if r.is_ligand]), 7)
        self.assertEqual(len([r for r in sc.model.residues if r.is_ligand]), 1)

    def test_extract_ligands_PDB(self):
        """Test that we can extract ligands from PDB files containing HET records.
        """
        trg = _LoadPDB("1R8Q.pdb")
        mdl = _LoadMMCIF("P84080_model_02.cif.gz")

        # No ligands given explicitly (None, None)
        sc = LigandScorer(mdl, trg, None, None)

        self.assertEqual(len(sc.target_ligands), 7)
        self.assertEqual(len(sc.model_ligands), 1)
        self.assertEqual(len([r for r in sc.target.residues if r.is_ligand]), 7)
        self.assertEqual(len([r for r in sc.model.residues if r.is_ligand]), 1)

    def test_init_given_ligands(self):
        """Test that we can instantiate the scorer with ligands contained in
        the target and model entity and given in a list.
        """
        trg = _LoadMMCIF("1r8q.cif.gz")
        mdl = _LoadMMCIF("P84080_model_02.cif.gz")

        # Pass entity views
        trg_lig = [trg.Select("rname=MG"), trg.Select("rname=G3D")]
        mdl_lig = [mdl.Select("rname=G3D")]
        sc = LigandScorer(mdl, trg, mdl_lig, trg_lig)

        self.assertEqual(len(sc.target_ligands), 4)
        self.assertEqual(len(sc.model_ligands), 1)
        # IsLigand flag should still be set even on not selected ligands
        self.assertEqual(len([r for r in sc.target.residues if r.is_ligand]), 7)
        self.assertEqual(len([r for r in sc.model.residues if r.is_ligand]), 1)

        # Ensure the residues are not copied
        self.assertEqual(len(sc.target.Select("rname=MG").residues), 2)
        self.assertEqual(len(sc.target.Select("rname=G3D").residues), 2)
        self.assertEqual(len(sc.model.Select("rname=G3D").residues), 1)

        # Pass residue handles
        trg_lig = [trg.FindResidue("F", 1), trg.FindResidue("H", 1)]
        mdl_lig = [mdl.FindResidue("L_2", 1)]
        sc = LigandScorer(mdl, trg, mdl_lig, trg_lig)

        self.assertEqual(len(sc.target_ligands), 2)
        self.assertEqual(len(sc.model_ligands), 1)

        # Ensure the residues are not copied
        self.assertEqual(len(sc.target.Select("rname=ZN").residues), 1)
        self.assertEqual(len(sc.target.Select("rname=G3D").residues), 2)
        self.assertEqual(len(sc.model.Select("rname=G3D").residues), 1)

    def test_init_sdf_ligands(self):
        """Test that we can instantiate the scorer with ligands from separate SDF files.

        In order to setup the ligand SDF files, the following code was used:

        for prefix in [os.path.join('testfiles', x) for x in ["1r8q", "P84080_model_02"]]:
            trg = io.LoadMMCIF("%s.cif.gz" % prefix)
            trg_prot = trg.Select("protein=True")
            io.SavePDB(trg_prot, "%s_protein.pdb.gz" % prefix)
            lig_num = 0
            for chain in trg.chains:
                if chain.chain_type == mol.ChainType.CHAINTYPE_NON_POLY:
                    lig_sel = trg.Select("cname=%s" % chain.name)
                    lig_ent = mol.CreateEntityFromView(lig_sel, False)
                    io.SaveEntity(lig_ent, "%s_ligand_%d.sdf" % (prefix, lig_num))
                    lig_num += 1
        """
        mdl = _LoadPDB("P84080_model_02_nolig.pdb")
        mdl_ligs = [_LoadEntity("P84080_model_02_ligand_0.sdf")]
        trg = _LoadPDB("1r8q_protein.pdb.gz")
        trg_ligs = [_LoadEntity("1r8q_ligand_%d.sdf" % i) for i in range(7)]

        # Pass entities
        sc = LigandScorer(mdl, trg, mdl_ligs, trg_ligs)

        self.assertEqual(len(sc.target_ligands), 7)
        self.assertEqual(len(sc.model_ligands), 1)
        # Ensure we set the is_ligand flag
        self.assertEqual(len([r for r in sc.target.residues if r.is_ligand]), 7)
        self.assertEqual(len([r for r in sc.model.residues if r.is_ligand]), 1)

        # Pass residues
        mdl_ligs_res = [mdl_ligs[0].residues[0]]
        trg_ligs_res = [res for ent in trg_ligs for res in ent.residues]

        sc = LigandScorer(mdl, trg, mdl_ligs_res, trg_ligs_res)

        self.assertEqual(len(sc.target_ligands), 7)
        self.assertEqual(len(sc.model_ligands), 1)

    def test_init_reject_duplicate_ligands(self):
        """Test that we reject input if multiple ligands with the same chain
        name/residue number are given.
        """
        mdl = _LoadPDB("P84080_model_02_nolig.pdb")
        mdl_ligs = [_LoadEntity("P84080_model_02_ligand_0.sdf")]
        trg = _LoadPDB("1r8q_protein.pdb.gz")
        trg_ligs = [_LoadEntity("1r8q_ligand_%d.sdf" % i) for i in range(7)]

        # Reject identical model ligands
        with self.assertRaises(RuntimeError):
            sc = LigandScorer(mdl, trg, [mdl_ligs[0], mdl_ligs[0]], trg_ligs)

        # Reject identical target ligands
        # Work on copies: the loaded entities are cached by _LoadEntity and
        # shared with other tests. lig1 gets the chain name and residue
        # number of lig0 to provoke the clash.
        lig0 = trg_ligs[0].Copy()
        lig1 = trg_ligs[1].Copy()
        ed1 = lig1.EditXCS()
        ed1.RenameChain(lig1.chains[0], lig0.chains[0].name)
        ed1.SetResidueNumber(lig1.residues[0], lig0.residues[0].number)
        with self.assertRaises(RuntimeError):
            sc = LigandScorer(mdl, trg, mdl_ligs, [lig0, lig1])

    def test__ResidueToGraph(self):
        """Test that _ResidueToGraph works as expected
        """
        mdl_lig = _LoadEntity("P84080_model_02_ligand_0.sdf")

        # Default call: nodes are addressed with string keys
        graph = ligand_scoring_base._ResidueToGraph(mdl_lig.residues[0])
        self.assertEqual(len(graph.edges), 34)
        self.assertEqual(len(graph.nodes), 32)
        # Check an arbitrary node
        self.assertEqual([a for a in graph.adj["14"].keys()], ["13", "29"])

        # by_atom_index=True: nodes are addressed with integer keys
        graph = ligand_scoring_base._ResidueToGraph(mdl_lig.residues[0], by_atom_index=True)
        self.assertEqual(len(graph.edges), 34)
        self.assertEqual(len(graph.nodes), 32)
        # Check an arbitrary node
        self.assertEqual([a for a in graph.adj[13].keys()], [12, 28])

    def test__ComputeSymmetries(self):
        """Test that ComputeSymmetries works.
        """
        trg = _LoadMMCIF("1r8q.cif.gz")
        mdl = _LoadMMCIF("P84080_model_02.cif.gz")

        trg_mg1 = trg.FindResidue("E", 1)
        trg_g3d1 = trg.FindResidue("F", 1)
        trg_afb1 = trg.FindResidue("G", 1)
        trg_g3d2 = trg.FindResidue("J", 1)
        mdl_g3d = mdl.FindResidue("L_2", 1)

        sym = ligand_scoring_base.ComputeSymmetries(mdl_g3d, trg_g3d1)
        self.assertEqual(len(sym), 72)

        sym = ligand_scoring_base.ComputeSymmetries(mdl_g3d, trg_g3d1, by_atom_index=True)
        self.assertEqual(len(sym), 72)

        # Test that we can match ions read from SDF
        sdf_lig = _LoadEntity("1r8q_ligand_0.sdf")
        sym = ligand_scoring_base.ComputeSymmetries(trg_mg1, sdf_lig.residues[0], by_atom_index=True)
        self.assertEqual(len(sym), 1)

        # Test that it works with views and only consider atoms in the view
        # Skip PA, PB and O[1-3]A and O[1-3]B in target and model
        # We assume atom index are fixed and won't change
        trg_g3d1_sub_ent = trg_g3d1.Select("aindex>6019")
        trg_g3d1_sub = trg_g3d1_sub_ent.residues[0]
        mdl_g3d_sub_ent = mdl_g3d.Select("aindex>1447")
        mdl_g3d_sub = mdl_g3d_sub_ent.residues[0]

        sym = ligand_scoring_base.ComputeSymmetries(mdl_g3d_sub, trg_g3d1_sub)
        self.assertEqual(len(sym), 6)

        sym = ligand_scoring_base.ComputeSymmetries(mdl_g3d_sub, trg_g3d1_sub, by_atom_index=True)
        self.assertEqual(len(sym), 6)

        # Substructure matches
        sym = ligand_scoring_base.ComputeSymmetries(mdl_g3d, trg_g3d1_sub, substructure_match=True)
        self.assertEqual(len(sym), 6)

        # Missing atoms only allowed in target, not in model
        with self.assertRaises(NoSymmetryError):
            ligand_scoring_base.ComputeSymmetries(mdl_g3d_sub, trg_g3d1, substructure_match=True)

    def test_SCRMSD(self):
        """Test that SCRMSD works.
        """
        trg = _LoadMMCIF("1r8q.cif.gz")
        mdl = _LoadMMCIF("P84080_model_02.cif.gz")

        trg_mg1 = trg.FindResidue("E", 1)
        trg_g3d1 = trg.FindResidue("F", 1)
        trg_afb1 = trg.FindResidue("G", 1)
        trg_g3d2 = trg.FindResidue("J", 1)
        mdl_g3d = mdl.FindResidue("L_2", 1)

        # No transformation given: default transformation of SCRMSD is used
        rmsd = ligand_scoring_scrmsd.SCRMSD(mdl_g3d, trg_g3d1)
        self.assertAlmostEqual(rmsd, 2.21341e-06, 10)
        rmsd = ligand_scoring_scrmsd.SCRMSD(mdl_g3d, trg_g3d2)
        self.assertAlmostEqual(rmsd, 61.21325, 4)

        # Ensure we raise a NoSymmetryError if the ligand is wrong
        with self.assertRaises(NoSymmetryError):
            ligand_scoring_scrmsd.SCRMSD(mdl_g3d, trg_mg1)
        with self.assertRaises(NoSymmetryError):
            ligand_scoring_scrmsd.SCRMSD(mdl_g3d, trg_afb1)

        # Assert that transform works
        trans = geom.Mat4(-0.999256, 0.00788487, -0.0377333, -15.4397,
                          0.0380652, 0.0473315, -0.998154, 29.9477,
                          -0.00608426, -0.998848, -0.0475963, 28.8251,
                          0, 0, 0, 1)
        rmsd = ligand_scoring_scrmsd.SCRMSD(mdl_g3d, trg_g3d2, transformation=trans)
        self.assertAlmostEqual(rmsd, 0.293972, 5)

        # Assert that substructure matches work
        trg_g3d1_sub = trg_g3d1.Select("aindex>6019").residues[0] # Skip PA, PB and O[1-3]A and O[1-3]B.
        # mdl_g3d_sub = mdl_g3d.Select("aindex>1447").residues[0] # Skip PA, PB and O[1-3]A and O[1-3]B.
        with self.assertRaises(NoIsomorphicSymmetryError):
            ligand_scoring_scrmsd.SCRMSD(mdl_g3d, trg_g3d1_sub) # no full match

        # But partial match is OK
        rmsd = ligand_scoring_scrmsd.SCRMSD(mdl_g3d, trg_g3d1_sub, substructure_match=True)
        self.assertAlmostEqual(rmsd, 2.2376232209353475e-06, 8)

        # Ensure it doesn't work the other way around - ie incomplete model is invalid
        with self.assertRaises(NoSymmetryError):
            ligand_scoring_scrmsd.SCRMSD(trg_g3d1_sub, mdl_g3d) # no full match

    def test_compute_rmsd_scores(self):
        """Test that the score matrix of SCRMSDScorer is computed as expected.
        """
        trg = _LoadMMCIF("1r8q.cif.gz")
        mdl = _LoadMMCIF("P84080_model_02.cif.gz")
        # Loaded directly with io.LoadEntity, i.e. bypassing the _LoadEntity
        # cache
        mdl_lig = io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf"))
        sc = ligand_scoring_scrmsd.SCRMSDScorer(mdl, trg, [mdl_lig], None)

        # Note: expect warning about Binding site of H.ZN1 not mapped to the model
        # One row per target ligand (7), one column per model ligand (1);
        # NaN where no score is available
        self.assertEqual(sc.score_matrix.shape, (7, 1))
        np.testing.assert_almost_equal(sc.score_matrix, np.array(
            [[np.nan],
             [0.04244993],
             [np.nan],
             [np.nan],
             [np.nan],
             [0.29399303],
             [np.nan]]), decimal=5)
if __name__ == "__main__":
    # The tests are only run if a default compound library is set
    # (checked with testutils.DefaultCompoundLibIsSet).
    from ost import testutils
    if testutils.DefaultCompoundLibIsSet():
        testutils.RunTests()
    else:
        print('No compound lib available. Ignoring '
              'test_ligand_scoring_fancy.py tests.')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment