Skip to content
Snippets Groups Projects
Unverified Commit 01791c49 authored by Xavier Robin's avatar Xavier Robin
Browse files

test: simplify data loading

parent df07afbb
No related branches found
No related tags found
No related merge requests found
import unittest, os, sys import unittest, os, sys
from functools import lru_cache
import numpy as np import numpy as np
import ost
from ost import io, mol, geom from ost import io, mol, geom
# check if we can import: fails if numpy or scipy not available # check if we can import: fails if numpy or scipy not available
try: try:
...@@ -13,22 +13,42 @@ except ImportError: ...@@ -13,22 +13,42 @@ except ImportError:
"networkx is missing. Ignoring test_ligand_scoring.py tests.") "networkx is missing. Ignoring test_ligand_scoring.py tests.")
sys.exit(0) sys.exit(0)
#ost.PushVerbosityLevel(ost.LogLevel.Debug)
def _GetTestfilePath(filename):
"""Get the path to the test file given filename"""
return os.path.join('testfiles', filename)
@lru_cache(maxsize=None)
def _LoadMMCIF(filename):
path = _GetTestfilePath(filename)
ent, seqres = io.LoadMMCIF(path, seqres=True)
return ent
@lru_cache(maxsize=None)
def _LoadPDB(filename):
path = _GetTestfilePath(filename)
ent = io.LoadPDB(path)
return ent
@lru_cache(maxsize=None)
def _LoadEntity(filename):
path = _GetTestfilePath(filename)
ent = io.LoadEntity(path)
return ent
class TestLigandScoring(unittest.TestCase): class TestLigandScoring(unittest.TestCase):
def test_extract_ligands_mmCIF(self): def test_extract_ligands_mmCIF(self):
"""Test that we can extract ligands from mmCIF files. """Test that we can extract ligands from mmCIF files.
""" """
trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True) trg = _LoadMMCIF("1r8q.cif.gz")
mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True) mdl = _LoadMMCIF("P84080_model_02.cif.gz")
sc = LigandScorer(mdl, trg, None, None) sc = LigandScorer(mdl, trg, None, None)
# import ipdb; ipdb.set_trace()
# import ost.mol.alg.scoring
# scr = ost.mol.alg.scoring.Scorer(sc.model, sc.target)
# scr.lddt
# scr.local_lddt
assert len(sc.target_ligands) == 7 assert len(sc.target_ligands) == 7
assert len(sc.model_ligands) == 1 assert len(sc.model_ligands) == 1
...@@ -39,8 +59,8 @@ class TestLigandScoring(unittest.TestCase): ...@@ -39,8 +59,8 @@ class TestLigandScoring(unittest.TestCase):
"""Test that we can instantiate the scorer with ligands contained in """Test that we can instantiate the scorer with ligands contained in
the target and model entity and given in a list. the target and model entity and given in a list.
""" """
trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True) trg = _LoadMMCIF("1r8q.cif.gz")
mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True) mdl = _LoadMMCIF("P84080_model_02.cif.gz")
# Pass entity views # Pass entity views
trg_lig = [trg.Select("rname=MG"), trg.Select("rname=G3D")] trg_lig = [trg.Select("rname=MG"), trg.Select("rname=G3D")]
...@@ -87,10 +107,10 @@ class TestLigandScoring(unittest.TestCase): ...@@ -87,10 +107,10 @@ class TestLigandScoring(unittest.TestCase):
io.SaveEntity(lig_ent, "%s_ligand_%d.sdf" % (prefix, lig_num)) io.SaveEntity(lig_ent, "%s_ligand_%d.sdf" % (prefix, lig_num))
lig_num += 1 lig_num += 1
""" """
mdl = io.LoadPDB(os.path.join('testfiles', "P84080_model_02_nolig.pdb")) mdl = _LoadPDB("P84080_model_02_nolig.pdb")
mdl_ligs = [io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf"))] mdl_ligs = [_LoadEntity("P84080_model_02_ligand_0.sdf")]
trg = io.LoadPDB(os.path.join('testfiles', "1r8q_protein.pdb.gz")) trg = _LoadPDB("1r8q_protein.pdb.gz")
trg_ligs = [io.LoadEntity(os.path.join('testfiles', "1r8q_ligand_%d.sdf" % i)) for i in range(7)] trg_ligs = [_LoadEntity("1r8q_ligand_%d.sdf" % i) for i in range(7)]
# Pass entities # Pass entities
sc = LigandScorer(mdl, trg, mdl_ligs, trg_ligs) sc = LigandScorer(mdl, trg, mdl_ligs, trg_ligs)
...@@ -114,18 +134,18 @@ class TestLigandScoring(unittest.TestCase): ...@@ -114,18 +134,18 @@ class TestLigandScoring(unittest.TestCase):
"""Test that we reject input if multiple ligands with the same chain """Test that we reject input if multiple ligands with the same chain
name/residue number are given. name/residue number are given.
""" """
mdl = io.LoadPDB(os.path.join('testfiles', "P84080_model_02_nolig.pdb")) mdl = _LoadPDB("P84080_model_02_nolig.pdb")
mdl_ligs = [io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf"))] mdl_ligs = [_LoadEntity("P84080_model_02_ligand_0.sdf")]
trg = io.LoadPDB(os.path.join('testfiles', "1r8q_protein.pdb.gz")) trg = _LoadPDB("1r8q_protein.pdb.gz")
trg_ligs = [io.LoadEntity(os.path.join('testfiles', "1r8q_ligand_%d.sdf" % i)) for i in range(7)] trg_ligs = [_LoadEntity("1r8q_ligand_%d.sdf" % i) for i in range(7)]
# Reject identical model ligands # Reject identical model ligands
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
sc = LigandScorer(mdl, trg, [mdl_ligs[0], mdl_ligs[0]], trg_ligs) sc = LigandScorer(mdl, trg, [mdl_ligs[0], mdl_ligs[0]], trg_ligs)
# Reject identical target ligands # Reject identical target ligands
lig0 = trg_ligs[0] lig0 = trg_ligs[0].Copy()
lig1 = trg_ligs[1] lig1 = trg_ligs[1].Copy()
ed1 = lig1.EditXCS() ed1 = lig1.EditXCS()
ed1.RenameChain(lig1.chains[0], lig0.chains[0].name) ed1.RenameChain(lig1.chains[0], lig0.chains[0].name)
ed1.SetResidueNumber(lig1.residues[0], lig0.residues[0].number) ed1.SetResidueNumber(lig1.residues[0], lig0.residues[0].number)
...@@ -135,7 +155,7 @@ class TestLigandScoring(unittest.TestCase): ...@@ -135,7 +155,7 @@ class TestLigandScoring(unittest.TestCase):
def test__ResidueToGraph(self): def test__ResidueToGraph(self):
"""Test that _ResidueToGraph works as expected """Test that _ResidueToGraph works as expected
""" """
mdl_lig = io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf")) mdl_lig = _LoadEntity("P84080_model_02_ligand_0.sdf")
graph = ligand_scoring._ResidueToGraph(mdl_lig.residues[0]) graph = ligand_scoring._ResidueToGraph(mdl_lig.residues[0])
assert len(graph.edges) == 34 assert len(graph.edges) == 34
...@@ -152,8 +172,8 @@ class TestLigandScoring(unittest.TestCase): ...@@ -152,8 +172,8 @@ class TestLigandScoring(unittest.TestCase):
def test__ComputeSymmetries(self): def test__ComputeSymmetries(self):
"""Test that _ComputeSymmetries works. """Test that _ComputeSymmetries works.
""" """
trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True) trg = _LoadMMCIF("1r8q.cif.gz")
mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True) mdl = _LoadMMCIF("P84080_model_02.cif.gz")
trg_mg1 = trg.FindResidue("E", 1) trg_mg1 = trg.FindResidue("E", 1)
trg_g3d1 = trg.FindResidue("F", 1) trg_g3d1 = trg.FindResidue("F", 1)
...@@ -168,7 +188,7 @@ class TestLigandScoring(unittest.TestCase): ...@@ -168,7 +188,7 @@ class TestLigandScoring(unittest.TestCase):
assert len(sym) == 72 assert len(sym) == 72
# Test that we can match ions read from SDF # Test that we can match ions read from SDF
sdf_lig = io.LoadEntity(os.path.join('testfiles', "1r8q_ligand_0.sdf")) sdf_lig = _LoadEntity("1r8q_ligand_0.sdf")
sym = ligand_scoring._ComputeSymmetries(trg_mg1, sdf_lig.residues[0], by_atom_index=True) sym = ligand_scoring._ComputeSymmetries(trg_mg1, sdf_lig.residues[0], by_atom_index=True)
assert len(sym) == 1 assert len(sym) == 1
...@@ -195,8 +215,8 @@ class TestLigandScoring(unittest.TestCase): ...@@ -195,8 +215,8 @@ class TestLigandScoring(unittest.TestCase):
def test_SCRMSD(self): def test_SCRMSD(self):
"""Test that SCRMSD works. """Test that SCRMSD works.
""" """
trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True) trg = _LoadMMCIF("1r8q.cif.gz")
mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True) mdl = _LoadMMCIF("P84080_model_02.cif.gz")
trg_mg1 = trg.FindResidue("E", 1) trg_mg1 = trg.FindResidue("E", 1)
trg_g3d1 = trg.FindResidue("F", 1) trg_g3d1 = trg.FindResidue("F", 1)
...@@ -240,8 +260,8 @@ class TestLigandScoring(unittest.TestCase): ...@@ -240,8 +260,8 @@ class TestLigandScoring(unittest.TestCase):
def test__compute_scores(self): def test__compute_scores(self):
"""Test that _compute_scores works. """Test that _compute_scores works.
""" """
trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True) trg = _LoadMMCIF("1r8q.cif.gz")
mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True) mdl = _LoadMMCIF("P84080_model_02.cif.gz")
mdl_lig = io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf")) mdl_lig = io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf"))
sc = LigandScorer(mdl, trg, [mdl_lig], None) sc = LigandScorer(mdl, trg, [mdl_lig], None)
...@@ -273,8 +293,8 @@ class TestLigandScoring(unittest.TestCase): ...@@ -273,8 +293,8 @@ class TestLigandScoring(unittest.TestCase):
"""Test check_resname argument works """Test check_resname argument works
""" """
# 4C0A has mismatching sequence and fails with check_resnames=True # 4C0A has mismatching sequence and fails with check_resnames=True
mdl_1r8q, _ = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True) mdl_1r8q = _LoadMMCIF("1r8q.cif.gz")
trg_4c0a, _ = io.LoadMMCIF(os.path.join('testfiles', "4c0a.cif.gz"), seqres=True) trg_4c0a = _LoadMMCIF("4c0a.cif.gz")
mdl = mdl_1r8q.Select("cname=D or cname=F") mdl = mdl_1r8q.Select("cname=D or cname=F")
trg = trg_4c0a.Select("cname=C or cname=I") trg = trg_4c0a.Select("cname=C or cname=I")
...@@ -290,8 +310,8 @@ class TestLigandScoring(unittest.TestCase): ...@@ -290,8 +310,8 @@ class TestLigandScoring(unittest.TestCase):
"""Test that the scores are computed correctly """Test that the scores are computed correctly
""" """
# 4C0A has more ligands # 4C0A has more ligands
trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True) trg = _LoadMMCIF("1r8q.cif.gz")
trg_4c0a, _ = io.LoadMMCIF(os.path.join('testfiles', "4c0a.cif.gz"), seqres=True) trg_4c0a = _LoadMMCIF("4c0a.cif.gz")
sc = LigandScorer(trg, trg_4c0a, None, None, check_resnames=False) sc = LigandScorer(trg, trg_4c0a, None, None, check_resnames=False)
expected_keys = {"J", "F"} expected_keys = {"J", "F"}
...@@ -344,19 +364,18 @@ class TestLigandScoring(unittest.TestCase): ...@@ -344,19 +364,18 @@ class TestLigandScoring(unittest.TestCase):
For RMSD, A: A results in a better chain mapping. However, C: A is a For RMSD, A: A results in a better chain mapping. However, C: A is a
better global chain mapping from an lDDT perspective (and lDDT-PLI). better global chain mapping from an lDDT perspective (and lDDT-PLI).
""" """
trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True) trg = _LoadMMCIF("1r8q.cif.gz")
mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True) mdl = _LoadMMCIF("P84080_model_02.cif.gz")
mdl_lig = io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf"))
# Local by default # Local by default
sc = LigandScorer(mdl, trg, [mdl_lig], None) sc = LigandScorer(mdl, trg, None, None)
assert sc.rmsd_details["00001_L_2"][1]["chain_mapping"] == {'A': 'A'} assert sc.rmsd_details["L_2"][1]["chain_mapping"] == {'A': 'A'}
assert sc.lddt_pli_details["00001_L_2"][1]["chain_mapping"] == {'C': 'A'} assert sc.lddt_pli_details["L_2"][1]["chain_mapping"] == {'C': 'A'}
# Global # Global
sc = LigandScorer(mdl, trg, [mdl_lig], None, global_chain_mapping=True) sc = LigandScorer(mdl, trg, None, None, global_chain_mapping=True)
assert sc.rmsd_details["00001_L_2"][1]["chain_mapping"] == {'C': 'A'} assert sc.rmsd_details["L_2"][1]["chain_mapping"] == {'C': 'A'}
assert sc.lddt_pli_details["00001_L_2"][1]["chain_mapping"] == {'C': 'A'} assert sc.lddt_pli_details["L_2"][1]["chain_mapping"] == {'C': 'A'}
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment