Skip to content
Snippets Groups Projects
Commit f0604436 authored by Studer Gabriel
Browse files

ligand scoring: add lddt pli to dissected scoring classes

parent 18da8627
No related branches found
No related tags found
No related merge requests found
...@@ -34,6 +34,7 @@ set(OST_MOL_ALG_PYMOD_MODULES ...@@ -34,6 +34,7 @@ set(OST_MOL_ALG_PYMOD_MODULES
contact_score.py contact_score.py
ligand_scoring_base.py ligand_scoring_base.py
ligand_scoring_scrmsd.py ligand_scoring_scrmsd.py
ligand_scoring_lddtpli.py
) )
if (NOT ENABLE_STATIC) if (NOT ENABLE_STATIC)
......
...@@ -57,10 +57,10 @@ class LigandScorer: ...@@ -57,10 +57,10 @@ class LigandScorer:
# lazily computed attributes # lazily computed attributes
self._chain_mapper = None self._chain_mapper = None
# keep track of error states # keep track of states
# simple integers instead of enums - documentation of property describes # simple integers instead of enums - documentation of property describes
# encoding # encoding
self._error_states = None self._states = None
# score matrices # score matrices
self._score_matrix = None self._score_matrix = None
...@@ -68,16 +68,15 @@ class LigandScorer: ...@@ -68,16 +68,15 @@ class LigandScorer:
self._aux_data = None self._aux_data = None
@property @property
def error_states(self): def states(self):
""" Encodes error states of ligand pairs """ Encodes states of ligand pairs
Not only critical things, but also things like: a pair of ligands Expect a valid score if respective location in this matrix is 0.
simply doesn't match. Target ligands are in rows, model ligands in Target ligands are in rows, model ligands in columns. States are encoded
columns. States are encoded as integers <= 9. Larger numbers encode as integers <= 9. Larger numbers encode errors for child classes.
errors for child classes.
* -1: Unknown Error - cannot be matched * -1: Unknown Error - cannot be matched
* 0: Ligand pair has valid symmetries - can be matched. * 0: Ligand pair can be matched and valid score is computed.
* 1: Ligand pair has no valid symmetry - cannot be matched. * 1: Ligand pair has no valid symmetry - cannot be matched.
* 2: Ligand pair has too many symmetries - cannot be matched. * 2: Ligand pair has too many symmetries - cannot be matched.
You might be able to get a match by increasing *max_symmetries*. You might be able to get a match by increasing *max_symmetries*.
...@@ -90,9 +89,9 @@ class LigandScorer: ...@@ -90,9 +89,9 @@ class LigandScorer:
:rtype: :class:`~numpy.ndarray` :rtype: :class:`~numpy.ndarray`
""" """
if self._error_states is None: if self._states is None:
self._compute_scores() self._compute_scores()
return self._error_states return self._states
@property @property
def score_matrix(self): def score_matrix(self):
...@@ -102,7 +101,7 @@ class LigandScorer: ...@@ -102,7 +101,7 @@ class LigandScorer:
NaN values indicate that no value could be computed (i.e. different NaN values indicate that no value could be computed (i.e. different
ligands). In other words: values are only valid if respective location ligands). In other words: values are only valid if respective location
:attr:`~error_states` is 0. :attr:`~states` is 0.
:rtype: :class:`~numpy.ndarray` :rtype: :class:`~numpy.ndarray`
""" """
...@@ -118,7 +117,7 @@ class LigandScorer: ...@@ -118,7 +117,7 @@ class LigandScorer:
NaN values indicate that no value could be computed (i.e. different NaN values indicate that no value could be computed (i.e. different
ligands). In other words: values are only valid if respective location ligands). In other words: values are only valid if respective location
:attr:`~error_states` is 0. If `substructure_match=False`, only full :attr:`~states` is 0. If `substructure_match=False`, only full
match isomorphisms are considered, and therefore only values of 1.0 match isomorphisms are considered, and therefore only values of 1.0
can be observed. can be observed.
...@@ -138,7 +137,7 @@ class LigandScorer: ...@@ -138,7 +137,7 @@ class LigandScorer:
class to provide additional information for a scored ligand pair. class to provide additional information for a scored ligand pair.
empty dictionaries indicate that no value could be computed empty dictionaries indicate that no value could be computed
(i.e. different ligands). In other words: values are only valid if (i.e. different ligands). In other words: values are only valid if
respective location :attr:`~error_states` is 0. respective location :attr:`~states` is 0.
:rtype: :class:`~numpy.ndarray` :rtype: :class:`~numpy.ndarray`
""" """
...@@ -301,7 +300,7 @@ class LigandScorer: ...@@ -301,7 +300,7 @@ class LigandScorer:
shape = (len(self.target_ligands), len(self.model_ligands)) shape = (len(self.target_ligands), len(self.model_ligands))
self._score_matrix = np.full(shape, np.nan, dtype=np.float32) self._score_matrix = np.full(shape, np.nan, dtype=np.float32)
self._coverage_matrix = np.full(shape, np.nan, dtype=np.float32) self._coverage_matrix = np.full(shape, np.nan, dtype=np.float32)
self._error_states = np.full(shape, -1, dtype=np.int32) self._states = np.full(shape, -1, dtype=np.int32)
self._aux_data = np.empty(shape, dtype=dict) self._aux_data = np.empty(shape, dtype=dict)
for target_id, target_ligand in enumerate(self.target_ligands): for target_id, target_ligand in enumerate(self.target_ligands):
...@@ -324,42 +323,45 @@ class LigandScorer: ...@@ -324,42 +323,45 @@ class LigandScorer:
# Ligands are different - skip # Ligands are different - skip
LogVerbose("No symmetry between %s and %s" % ( LogVerbose("No symmetry between %s and %s" % (
str(model_ligand), str(target_ligand))) str(model_ligand), str(target_ligand)))
self._error_states[target_id, model_id] = 1 self._states[target_id, model_id] = 1
continue continue
except TooManySymmetriesError: except TooManySymmetriesError:
# Ligands are too symmetrical - skip # Ligands are too symmetrical - skip
LogVerbose("Too many symmetries between %s and %s" % ( LogVerbose("Too many symmetries between %s and %s" % (
str(model_ligand), str(target_ligand))) str(model_ligand), str(target_ligand)))
self._error_states[target_id, model_id] = 2 self._states[target_id, model_id] = 2
continue continue
except NoIsomorphicSymmetryError: except NoIsomorphicSymmetryError:
# Ligands are different - skip # Ligands are different - skip
LogVerbose("No isomorphic symmetry between %s and %s" % ( LogVerbose("No isomorphic symmetry between %s and %s" % (
str(model_ligand), str(target_ligand))) str(model_ligand), str(target_ligand)))
self._error_states[target_id, model_id] = 3 self._states[target_id, model_id] = 3
continue continue
except DisconnectedGraphError: except DisconnectedGraphError:
LogVerbose("Disconnected graph observed for %s and %s" % ( LogVerbose("Disconnected graph observed for %s and %s" % (
str(model_ligand), str(target_ligand))) str(model_ligand), str(target_ligand)))
self._error_states[target_id, model_id] = 4 self._states[target_id, model_id] = 4
continue continue
##################################################### #####################################################
# Compute score by calling the child class _compute # # Compute score by calling the child class _compute #
##################################################### #####################################################
score, error_state, aux = self._compute(symmetries, target_ligand, score, state, aux = self._compute(symmetries, target_ligand,
model_ligand) model_ligand)
############ ############
# Finalize # # Finalize #
############ ############
if error_state != 0: if state != 0:
# non-zero states up to 9 are reserved for base class # non-zero states up to 9 are reserved for base class
if error_state <= 9: if state <= 9:
raise RuntimeError("Child returned reserved err. state") raise RuntimeError("Child returned reserved err. state")
self._error_states[target_id, model_id] = error_state self._states[target_id, model_id] = state
if error_state == 0: if state == 0:
if score is None or np.isnan(score):
raise RuntimeError("LigandScorer returned invalid "
"score despite 0 error state")
# it's a valid score! # it's a valid score!
self._score_matrix[target_id, model_id] = score self._score_matrix[target_id, model_id] = score
cvg = len(symmetries[0][0]) / len(model_ligand.atoms) cvg = len(symmetries[0][0]) / len(model_ligand.atoms)
...@@ -386,12 +388,13 @@ class LigandScorer: ...@@ -386,12 +388,13 @@ class LigandScorer:
:class:`ost.mol.ResidueView` :class:`ost.mol.ResidueView`
:returns: A :class:`tuple` with three elements: 1) a score :returns: A :class:`tuple` with three elements: 1) a score
(:class:`float`) 2) error state (:class:`int`). (:class:`float`) 2) state (:class:`int`).
3) auxiliary data for this ligand pair (:class:`dict`). 3) auxiliary data for this ligand pair (:class:`dict`).
If error state is 0, the score and auxiliary data will be If state is 0, the score and auxiliary data will be
added to :attr:`~score_matrix` and :attr:`~aux_data` as well added to :attr:`~score_matrix` and :attr:`~aux_data` as well
as the respective value in :attr:`~coverage_matrix`. as the respective value in :attr:`~coverage_matrix`.
Child specific non-zero states MUST be >= 10. Returned score must be valid in this case (not None/NaN).
Child specific non-zero states must be >= 10.
""" """
raise NotImplementedError("_compute must be implemented by child class") raise NotImplementedError("_compute must be implemented by child class")
......
This diff is collapsed.
...@@ -9,11 +9,11 @@ from ost.mol.alg import ligand_scoring_base ...@@ -9,11 +9,11 @@ from ost.mol.alg import ligand_scoring_base
class SCRMSDScorer(ligand_scoring_base.LigandScorer): class SCRMSDScorer(ligand_scoring_base.LigandScorer):
def __init__(self, model, target, model_ligands=None, target_ligands=None, def __init__(self, model, target, model_ligands=None, target_ligands=None,
resnum_alignments=False, rename_ligand_chain=False, resnum_alignments=False, rename_ligand_chain=False,
substructure_match=False, coverage_delta=0.2, substructure_match=False, coverage_delta=0.2,
max_symmetries=1e5, bs_radius=4.0, lddt_lp_radius=15.0, max_symmetries=1e5, bs_radius=4.0, lddt_lp_radius=15.0,
model_bs_radius=25, binding_sites_topn=100000, model_bs_radius=25, binding_sites_topn=100000,
full_bs_search=False): full_bs_search=False):
super().__init__(model, target, model_ligands = model_ligands, super().__init__(model, target, model_ligands = model_ligands,
......
...@@ -10,6 +10,7 @@ try: ...@@ -10,6 +10,7 @@ try:
from ost.mol.alg.ligand_scoring_base import * from ost.mol.alg.ligand_scoring_base import *
from ost.mol.alg import ligand_scoring_base from ost.mol.alg import ligand_scoring_base
from ost.mol.alg import ligand_scoring_scrmsd from ost.mol.alg import ligand_scoring_scrmsd
from ost.mol.alg import ligand_scoring_lddtpli
except ImportError: except ImportError:
print("Failed to import ligand_scoring.py. Happens when numpy, scipy or " print("Failed to import ligand_scoring.py. Happens when numpy, scipy or "
"networkx is missing. Ignoring test_ligand_scoring.py tests.") "networkx is missing. Ignoring test_ligand_scoring.py tests.")
...@@ -281,6 +282,7 @@ class TestLigandScoringFancy(unittest.TestCase): ...@@ -281,6 +282,7 @@ class TestLigandScoringFancy(unittest.TestCase):
with self.assertRaises(NoSymmetryError): with self.assertRaises(NoSymmetryError):
ligand_scoring_scrmsd.SCRMSD(trg_g3d1_sub, mdl_g3d) # no full match ligand_scoring_scrmsd.SCRMSD(trg_g3d1_sub, mdl_g3d) # no full match
def test_compute_rmsd_scores(self): def test_compute_rmsd_scores(self):
"""Test that _compute_scores works. """Test that _compute_scores works.
""" """
...@@ -300,6 +302,128 @@ class TestLigandScoringFancy(unittest.TestCase): ...@@ -300,6 +302,128 @@ class TestLigandScoringFancy(unittest.TestCase):
[0.29399303], [0.29399303],
[np.nan]]), decimal=5) [np.nan]]), decimal=5)
def test_compute_lddtpli_scores(self):
    """Check the lDDT-PLI score matrix obtained for a single model ligand.

    The model ligand is loaded from an SDF file and passed explicitly,
    whereas *None* is passed for the target ligands. The resulting score
    matrix is expected to have shape (7, 1), with valid scores in rows 1
    and 5 and NaN in every other row.
    """
    target = _LoadMMCIF("1r8q.cif.gz")
    model = _LoadMMCIF("P84080_model_02.cif.gz")
    ligand_path = os.path.join('testfiles', "P84080_model_02_ligand_0.sdf")
    model_ligand = io.LoadEntity(ligand_path)
    scorer = ligand_scoring_lddtpli.LDDTPLIScorer(
        model, target, [model_ligand], None,
        add_mdl_contacts=False,
        lddt_pli_binding_site_radius=4.0)

    scores = scorer.score_matrix
    self.assertEqual(scores.shape, (7, 1))

    # One entry per row: None if NaN is expected, otherwise a tuple
    # (reference value, places). places=None selects the default
    # precision of assertAlmostEqual.
    expectations = [None, (0.99843, 5), None, None, None, (1.0, None), None]
    for row, expected in enumerate(expectations):
        observed = scores[row, 0]
        if expected is None:
            self.assertTrue(np.isnan(observed))
        else:
            reference, places = expected
            self.assertAlmostEqual(observed, reference, places)
def test_check_resnames(self):
    """Test that the check_resnames argument works.

    When set to True, scoring should raise an error if any residue in the
    representation of the binding site in the model has a different name
    than in the reference. A residue is renamed by hand here to provoke
    that situation. This is only relevant for the LDDTPLIScorer.
    """
    full_4c0a = _LoadMMCIF("4c0a.cif.gz")
    target = full_4c0a.Select("cname=C or cname=I")

    # Build the model as a copy of the target and rename one residue of
    # 4C0A (THR => TPO in C.15). This residue is in the binding site and
    # should trigger the error.
    model = ost.mol.CreateEntityFromView(target,
                                         include_exlusive_atoms=False)
    editor = model.EditICS()
    editor.RenameResidue(model.FindResidue("C", 15), "TPO")
    editor.UpdateXCS()

    def build_scorer(check_resnames):
        """Create a scorer for ligand I.1 of model and target."""
        return ligand_scoring_lddtpli.LDDTPLIScorer(
            model, target,
            [model.FindResidue("I", 1)],
            [target.FindResidue("I", 1)],
            check_resnames=check_resnames)

    # Construction and scoring both happen inside the context manager, so
    # a RuntimeError from either step satisfies the assertion.
    with self.assertRaises(RuntimeError):
        strict_scorer = build_scorer(True)
        strict_scorer._compute_scores()

    # Without the residue name check the same setup must run through.
    lenient_scorer = build_scorer(False)
    lenient_scorer._compute_scores()
def test_added_mdl_contacts(self):
    """Check how *add_mdl_contacts* penalizes contacts only seen in the model.

    Four scenarios are scored for the ligand in chain G of 1r8q:

    1. Added model contacts disabled: a score of 1.0 is expected.
    2. Added model contacts enabled: contacts towards chain A count as
       penalty.
    3. As 2, but residue 66 is removed from chain C in the target: its
       contacts are no longer part of the penalty.
    4. As 3, but one atom of chain B is moved onto the ligand center in
       the model: one extra penalty per ligand atom is expected.
    """
    # binding site for ligand in chain G consists of chains A and B
    structure = _LoadMMCIF("1r8q.cif.gz").Copy()

    # model has the full binding site
    model = mol.CreateEntityFromView(structure.Select("cname=A,B,G"), True)

    # chain C has same sequence as chain A but is not in contact
    # with ligand in chain G
    # target has thus incomplete binding site only from chain B
    target = mol.CreateEntityFromView(structure.Select("cname=B,C,G"), True)

    # if added model contacts are not considered, the incomplete binding
    # site only from chain B is perfectly reproduced by model which also has
    # chain B
    scorer = ligand_scoring_lddtpli.LDDTPLIScorer(model, target,
                                                  add_mdl_contacts=False)
    self.assertAlmostEqual(scorer.score_matrix[0, 0], 1.0, 5)

    # if added model contacts are considered, contributions from chain B are
    # perfectly reproduced but all contacts of ligand towards chain A are
    # added as penalty
    scorer = ligand_scoring_lddtpli.LDDTPLIScorer(model, target,
                                                  add_mdl_contacts=True)
    ligand = structure.Select("cname=G")

    def atoms_near_ligand(radius):
        """For each ligand atom, yield the model atoms found within *radius*."""
        for ligand_atom in ligand.atoms:
            for close_atom in model.FindWithin(ligand_atom.GetPos(), radius):
                yield close_atom

    contacts = {"A": 0, "B": 0}
    for close_atom in atoms_near_ligand(scorer.lddt_pli_radius):
        chain_name = close_atom.GetChain().GetName()
        # atoms from chain G belong to the ligand itself and are not counted
        if chain_name in contacts:
            contacts[chain_name] += 1
    n_chain_a = contacts["A"]
    n_chain_b = contacts["B"]

    expected = n_chain_b / (n_chain_a + n_chain_b)
    self.assertAlmostEqual(scorer.score_matrix[0, 0], expected, 5)

    # Same as before but additionally we remove residue TRP.66
    # from chain C in the target to test mapping magic...
    # Chain C is NOT in contact with the ligand but we only
    # add contacts from chain A as penalty that are mappable
    # to the closest chain with same sequence. That would be
    # chain C
    query = "cname=B,G or (cname=C and rnum!=66)"
    target = mol.CreateEntityFromView(structure.Select(query), True)
    scorer = ligand_scoring_lddtpli.LDDTPLIScorer(model, target,
                                                  add_mdl_contacts=True)

    n_trp66 = sum(
        1 for close_atom in atoms_near_ligand(scorer.lddt_pli_radius)
        if close_atom.GetChain().GetName() == "A"
        and close_atom.GetResidue().GetNumber().GetNum() == 66)
    self.assertEqual(n_trp66, 134)

    # remove the TRP.66 contacts from the original penalty
    expected = n_chain_b / (n_chain_a + n_chain_b - n_trp66)
    self.assertAlmostEqual(scorer.score_matrix[0, 0], expected, 5)

    # Move a random atom in the model from chain B towards the ligand center
    # chain B is also present in the target and interacts with the ligand,
    # but that atom would be far away and thus adds to the penalty. Since
    # the ligand is small enough, the number of added contacts should be
    # exactly the number of ligand atoms.
    model_editor = model.EditXCS()
    moved_atom = model.FindResidue("B", mol.ResNum(8)).FindAtom("NZ")
    model_editor.SetAtomPos(moved_atom, ligand.geometric_center)
    scorer = ligand_scoring_lddtpli.LDDTPLIScorer(model, target,
                                                  add_mdl_contacts=True)

    # compared to the last assertAlmostEqual, we add the number of ligand
    # atoms as additional penalties
    expected = n_chain_b / (n_chain_a + n_chain_b - n_trp66
                            + ligand.GetAtomCount())
    self.assertAlmostEqual(scorer.score_matrix[0, 0], expected, 5)
if __name__ == "__main__": if __name__ == "__main__":
from ost import testutils from ost import testutils
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment