Skip to content
Snippets Groups Projects
Commit 8be12ac0 authored by Studer Gabriel's avatar Studer Gabriel
Browse files

ligand scoring: add dict-style access of scoring result in refactored classes

parent 241466d9
Branches
Tags
No related merge requests found
......@@ -221,24 +221,26 @@ class LigandScorer:
# keep track of states
# simple integers instead of enums - documentation of property describes
# encoding
self._states = None
self._states_matrix = None
# score matrices
self._score_matrix = None
self._coverage_matrix = None
self._aux_data = None
self._aux_matrix = None
# assignment and derived data
self._assignment = None
self._score_dict = None
self._aux_dict = None
@property
def states(self):
def states_matrix(self):
""" Encodes states of ligand pairs
Expect a valid score if respective location in this matrix is 0.
Target ligands are in rows, model ligands in columns. States are encoded
as integers <= 9. Larger numbers encode errors for child classes.
* -1: Unknown Error - cannot be matched
* 0: Ligand pair can be matched and valid score is computed.
* 1: Ligand pair has no valid symmetry - cannot be matched.
* 2: Ligand pair has too many symmetries - cannot be matched.
......@@ -249,12 +251,13 @@ class LigandScorer:
0 if this flag is enabled.
* 4: Disconnected graph error - cannot be matched.
Either target ligand or model ligand has disconnected graph.
* 9: Unknown Error - cannot be matched
:rtype: :class:`~numpy.ndarray`
"""
if self._states is None:
if self._states_matrix is None:
self._compute_scores()
return self._states
return self._states_matrix
@property
def score_matrix(self):
......@@ -291,7 +294,7 @@ class LigandScorer:
return self._coverage_matrix
@property
def aux_data(self):
def aux_matrix(self):
""" Get the matrix of scorer specific auxiliary data.
Target ligands are in rows, model ligands in columns.
......@@ -315,7 +318,7 @@ class LigandScorer:
Implements a greedy algorithm to assign target and model ligands
with each other. Starts from each valid ligand pair as indicated
by a state of 0 in :attr:`states`. Each iteration first selects
by a state of 0 in :attr:`states_matrix`. Each iteration first selects
high coverage pairs. Given max_coverage defined as the highest
coverage observed in the available pairs, all pairs with coverage
in [max_coverage-*coverage_delta*, max_coverage] are selected.
......@@ -333,7 +336,7 @@ class LigandScorer:
tmp = list()
for trg_idx in range(self.score_matrix.shape[0]):
for mdl_idx in range(self.score_matrix.shape[1]):
if self.states[trg_idx, mdl_idx] == 0:
if self.states_matrix[trg_idx, mdl_idx] == 0:
tmp.append((self.score_matrix[trg_idx, mdl_idx],
self.coverage_matrix[trg_idx, mdl_idx],
trg_idx, mdl_idx))
......@@ -362,6 +365,50 @@ class LigandScorer:
return self._assignment
@property
def score(self):
"""Get a dictionary of score values, keyed by model ligand
Extract score with something like:
`scorer.score[lig.GetChain().GetName()][lig.GetNumber()]`.
The returned scores are based on :attr:`~assignment`.
:rtype: :class:`dict`
"""
if self._score_dict is None:
self._score_dict = dict()
for (trg_lig_idx, mdl_lig_idx) in self.assignment:
mdl_lig = self.model_ligands[mdl_lig_idx]
cname = mdl_lig.GetChain().GetName()
rnum = mdl_lig.GetNumber()
if cname not in self._score_dict:
self._score_dict[cname] = dict()
score = self.score_matrix[trg_lig_idx, mdl_lig_idx]
self._score_dict[cname][rnum] = score
return self._score_dict
@property
def aux(self):
"""Get a dictionary of score details, keyed by model ligand
Extract dict with something like:
`scorer.score[lig.GetChain().GetName()][lig.GetNumber()]`.
The returned info dicts are based on :attr:`~assignment`. The content is
documented in the respective child class.
:rtype: :class:`dict`
"""
if self._aux_dict is None:
self._aux_dict = dict()
for (trg_lig_idx, mdl_lig_idx) in self.assignment:
mdl_lig = self.model_ligands[mdl_lig_idx]
cname = mdl_lig.GetChain().GetName()
rnum = mdl_lig.GetNumber()
if cname not in self._aux_dict:
self._aux_dict[cname] = dict()
d = self.aux_matrix[trg_lig_idx, mdl_lig_idx]
self._aux_dict[cname][rnum] = d
return self._aux_dict
@property
def _chain_mapper(self):
......@@ -518,8 +565,8 @@ class LigandScorer:
shape = (len(self.target_ligands), len(self.model_ligands))
self._score_matrix = np.full(shape, np.nan, dtype=np.float32)
self._coverage_matrix = np.full(shape, np.nan, dtype=np.float32)
self._states = np.full(shape, -1, dtype=np.int32)
self._aux_data = np.empty(shape, dtype=dict)
self._states_matrix = np.full(shape, -1, dtype=np.int32)
self._aux_matrix = np.empty(shape, dtype=dict)
for target_id, target_ligand in enumerate(self.target_ligands):
LogVerbose("Analyzing target ligand %s" % target_ligand)
......@@ -541,24 +588,24 @@ class LigandScorer:
# Ligands are different - skip
LogVerbose("No symmetry between %s and %s" % (
str(model_ligand), str(target_ligand)))
self._states[target_id, model_id] = 1
self._states_matrix[target_id, model_id] = 1
continue
except TooManySymmetriesError:
# Ligands are too symmetrical - skip
LogVerbose("Too many symmetries between %s and %s" % (
str(model_ligand), str(target_ligand)))
self._states[target_id, model_id] = 2
self._states_matrix[target_id, model_id] = 2
continue
except NoIsomorphicSymmetryError:
# Ligands are different - skip
LogVerbose("No isomorphic symmetry between %s and %s" % (
str(model_ligand), str(target_ligand)))
self._states[target_id, model_id] = 3
self._states_matrix[target_id, model_id] = 3
continue
except DisconnectedGraphError:
LogVerbose("Disconnected graph observed for %s and %s" % (
str(model_ligand), str(target_ligand)))
self._states[target_id, model_id] = 4
self._states_matrix[target_id, model_id] = 4
continue
#####################################################
......@@ -571,11 +618,11 @@ class LigandScorer:
# Finalize #
############
if state != 0:
# non-zero error states up to 4 are reserved for base class
# non-zero error states up to 9 are reserved for base class
if state <= 9:
raise RuntimeError("Child returned reserved err. state")
self._states[target_id, model_id] = state
self._states_matrix[target_id, model_id] = state
if state == 0:
if score is None or np.isnan(score):
raise RuntimeError("LigandScorer returned invalid "
......@@ -584,7 +631,7 @@ class LigandScorer:
self._score_matrix[target_id, model_id] = score
cvg = len(symmetries[0][0]) / len(model_ligand.atoms)
self._coverage_matrix[target_id, model_id] = cvg
self._aux_data[target_id, model_id] = aux
self._aux_matrix[target_id, model_id] = aux
def _compute(self, symmetries, target_ligand, model_ligand):
""" Compute score for specified ligand pair - defined by child class
......@@ -609,7 +656,7 @@ class LigandScorer:
(:class:`float`) 2) state (:class:`int`).
3) auxiliary data for this ligand pair (:class:`dict`).
If state is 0, the score and auxiliary data will be
added to :attr:`~score_matrix` and :attr:`~aux_data` as well
added to :attr:`~score_matrix` and :attr:`~aux_matrix` as well
as the respective value in :attr:`~coverage_matrix`.
Returned score must be valid in this case (not None/NaN).
Child specific non-zero states must be >= 10.
......
......@@ -192,7 +192,7 @@ class TestLigandScoringFancy(unittest.TestCase):
# Check an arbitrary node
self.assertEqual([a for a in graph.adj[13].keys()], [12, 28])
def test__ComputeSymmetries(self):
def test_ComputeSymmetries(self):
"""Test that _ComputeSymmetries works.
"""
trg = _LoadMMCIF("1r8q.cif.gz")
......@@ -433,7 +433,88 @@ class TestLigandScoringFancy(unittest.TestCase):
sc = ligand_scoring_lddtpli.LDDTPLIScorer(mdl, trg)
self.assertEqual(sc.assignment, [(5, 0)])
def test_dict_results_rmsd(self):
"""Test that the scores are computed correctly
"""
# 4C0A has more ligands
trg = _LoadMMCIF("1r8q.cif.gz")
trg_4c0a = _LoadMMCIF("4c0a.cif.gz")
sc = ligand_scoring_scrmsd.SCRMSDScorer(trg, trg_4c0a, None, None)
expected_keys = {"J", "F"}
self.assertFalse(expected_keys.symmetric_difference(sc.score.keys()))
self.assertFalse(expected_keys.symmetric_difference(sc.aux.keys()))
# rmsd
self.assertAlmostEqual(sc.score["J"][mol.ResNum(1)], 0.8016608357429504, 5)
self.assertAlmostEqual(sc.score["F"][mol.ResNum(1)], 0.9286373257637024, 5)
# rmsd_details
self.assertEqual(sc.aux["J"][mol.ResNum(1)]["chain_mapping"], {'F': 'D', 'C': 'C'})
self.assertEqual(len(sc.aux["J"][mol.ResNum(1)]["bs_ref_res"]), 15)
self.assertEqual(len(sc.aux["J"][mol.ResNum(1)]["bs_ref_res_mapped"]), 15)
self.assertEqual(len(sc.aux["J"][mol.ResNum(1)]["bs_mdl_res_mapped"]), 15)
self.assertEqual(sc.aux["J"][mol.ResNum(1)]["target_ligand"].qualified_name, 'I.G3D1')
self.assertEqual(sc.aux["J"][mol.ResNum(1)]["model_ligand"].qualified_name, 'J.G3D1')
self.assertEqual(sc.aux["F"][mol.ResNum(1)]["chain_mapping"], {'B': 'B', 'G': 'A'})
self.assertEqual(len(sc.aux["F"][mol.ResNum(1)]["bs_ref_res"]), 15)
self.assertEqual(len(sc.aux["F"][mol.ResNum(1)]["bs_ref_res_mapped"]), 15)
self.assertEqual(len(sc.aux["F"][mol.ResNum(1)]["bs_mdl_res_mapped"]), 15)
self.assertEqual(sc.aux["F"][mol.ResNum(1)]["target_ligand"].qualified_name, 'K.G3D1')
self.assertEqual(sc.aux["F"][mol.ResNum(1)]["model_ligand"].qualified_name, 'F.G3D1')
def test_dict_results_lddtpli(self):
"""Test that the scores are computed correctly
"""
# 4C0A has more ligands
trg = _LoadMMCIF("1r8q.cif.gz")
trg_4c0a = _LoadMMCIF("4c0a.cif.gz")
sc = ligand_scoring_lddtpli.LDDTPLIScorer(trg, trg_4c0a, None, None,
check_resnames=False,
add_mdl_contacts=False,
lddt_pli_binding_site_radius = 4.0)
expected_keys = {"J", "F"}
self.assertFalse(expected_keys.symmetric_difference(sc.score.keys()))
self.assertFalse(expected_keys.symmetric_difference(sc.aux.keys()))
# lddt_pli
self.assertAlmostEqual(sc.score["J"][mol.ResNum(1)], 0.9127105666156202, 5)
self.assertAlmostEqual(sc.score["F"][mol.ResNum(1)], 0.915929203539823, 5)
# lddt_pli_details
self.assertEqual(sc.aux["J"][mol.ResNum(1)]["lddt_pli_n_contacts"], 653)
self.assertEqual(len(sc.aux["J"][mol.ResNum(1)]["bs_ref_res"]), 15)
self.assertEqual(sc.aux["J"][mol.ResNum(1)]["target_ligand"].qualified_name, 'I.G3D1')
self.assertEqual(sc.aux["J"][mol.ResNum(1)]["model_ligand"].qualified_name, 'J.G3D1')
self.assertEqual(sc.aux["F"][mol.ResNum(1)]["lddt_pli_n_contacts"], 678)
self.assertEqual(len(sc.aux["F"][mol.ResNum(1)]["bs_ref_res"]), 15)
self.assertEqual(sc.aux["F"][mol.ResNum(1)]["target_ligand"].qualified_name, 'K.G3D1')
self.assertEqual(sc.aux["F"][mol.ResNum(1)]["model_ligand"].qualified_name, 'F.G3D1')
# lddt_pli with added mdl contacts
sc = ligand_scoring_lddtpli.LDDTPLIScorer(trg, trg_4c0a, None, None,
check_resnames=False,
add_mdl_contacts=True)
self.assertAlmostEqual(sc.score["J"][mol.ResNum(1)], 0.8988340192043895, 5)
self.assertAlmostEqual(sc.score["F"][mol.ResNum(1)], 0.9039735099337749, 5)
# lddt_pli_details
self.assertEqual(sc.aux["J"][mol.ResNum(1)]["lddt_pli_n_contacts"], 729)
self.assertEqual(len(sc.aux["J"][mol.ResNum(1)]["bs_ref_res"]), 63)
self.assertEqual(sc.aux["J"][mol.ResNum(1)]["target_ligand"].qualified_name, 'I.G3D1')
self.assertEqual(sc.aux["J"][mol.ResNum(1)]["model_ligand"].qualified_name, 'J.G3D1')
self.assertEqual(sc.aux["F"][mol.ResNum(1)]["lddt_pli_n_contacts"], 755)
self.assertEqual(len(sc.aux["F"][mol.ResNum(1)]["bs_ref_res"]), 62)
self.assertEqual(sc.aux["F"][mol.ResNum(1)]["target_ligand"].qualified_name, 'K.G3D1')
self.assertEqual(sc.aux["F"][mol.ResNum(1)]["model_ligand"].qualified_name, 'F.G3D1')
def test_ignore_binding_site(self):
"""Test that we ignore non polymer stuff in the binding site.
NOTE: we should consider changing this behavior in the future and take
other ligands, peptides and short oligomers into account for superposition.
When that's the case this test should be adapter
"""
trg = _LoadMMCIF("1SSP.cif.gz")
sc = ligand_scoring_scrmsd.SCRMSDScorer(trg, trg, None, None)
expected_bs_ref_res = ['C.GLY62', 'C.GLN63', 'C.ASP64', 'C.PRO65', 'C.TYR66', 'C.CYS76', 'C.PHE77', 'C.ASN123', 'C.HIS187']
ost.PushVerbosityLevel(ost.LogLevel.Error)
self.assertEqual([str(r) for r in sc.aux["D"][1]["bs_ref_res"]], expected_bs_ref_res)
ost.PopVerbosityLevel()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment