Skip to content
Snippets Groups Projects
Unverified Commit 01791c49 authored by Xavier Robin
Browse files

test: simplify data loading

parent df07afbb
Branches
Tags
No related merge requests found
import unittest, os, sys
from functools import lru_cache
import numpy as np
import ost
from ost import io, mol, geom
# check if we can import: fails if numpy or scipy not available
try:
......@@ -13,22 +13,42 @@ except ImportError:
"networkx is missing. Ignoring test_ligand_scoring.py tests.")
sys.exit(0)
#ost.PushVerbosityLevel(ost.LogLevel.Debug)
def _GetTestfilePath(filename):
"""Get the path to the test file given filename"""
return os.path.join('testfiles', filename)
@lru_cache(maxsize=None)
def _LoadMMCIF(filename):
    """Load the entity of an mmCIF test file, cached per filename.

    The file is read with ``seqres=True``; only the entity is returned and
    the SEQRES part of the result is discarded. Because of the cache,
    repeated calls with the same filename return the same object.
    """
    entity, _seqres = io.LoadMMCIF(_GetTestfilePath(filename), seqres=True)
    return entity
@lru_cache(maxsize=None)
def _LoadPDB(filename):
    """Load a PDB test file, cached per filename.

    Because of the cache, repeated calls with the same filename return the
    same object.
    """
    return io.LoadPDB(_GetTestfilePath(filename))
@lru_cache(maxsize=None)
def _LoadEntity(filename):
    """Load a test file through ``io.LoadEntity``, cached per filename.

    Because of the cache, repeated calls with the same filename return the
    same object.
    """
    return io.LoadEntity(_GetTestfilePath(filename))
class TestLigandScoring(unittest.TestCase):
def test_extract_ligands_mmCIF(self):
"""Test that we can extract ligands from mmCIF files.
"""
trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True)
mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True)
trg = _LoadMMCIF("1r8q.cif.gz")
mdl = _LoadMMCIF("P84080_model_02.cif.gz")
sc = LigandScorer(mdl, trg, None, None)
# import ipdb; ipdb.set_trace()
# import ost.mol.alg.scoring
# scr = ost.mol.alg.scoring.Scorer(sc.model, sc.target)
# scr.lddt
# scr.local_lddt
assert len(sc.target_ligands) == 7
assert len(sc.model_ligands) == 1
......@@ -39,8 +59,8 @@ class TestLigandScoring(unittest.TestCase):
"""Test that we can instantiate the scorer with ligands contained in
the target and model entity and given in a list.
"""
trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True)
mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True)
trg = _LoadMMCIF("1r8q.cif.gz")
mdl = _LoadMMCIF("P84080_model_02.cif.gz")
# Pass entity views
trg_lig = [trg.Select("rname=MG"), trg.Select("rname=G3D")]
......@@ -87,10 +107,10 @@ class TestLigandScoring(unittest.TestCase):
io.SaveEntity(lig_ent, "%s_ligand_%d.sdf" % (prefix, lig_num))
lig_num += 1
"""
mdl = io.LoadPDB(os.path.join('testfiles', "P84080_model_02_nolig.pdb"))
mdl_ligs = [io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf"))]
trg = io.LoadPDB(os.path.join('testfiles', "1r8q_protein.pdb.gz"))
trg_ligs = [io.LoadEntity(os.path.join('testfiles', "1r8q_ligand_%d.sdf" % i)) for i in range(7)]
mdl = _LoadPDB("P84080_model_02_nolig.pdb")
mdl_ligs = [_LoadEntity("P84080_model_02_ligand_0.sdf")]
trg = _LoadPDB("1r8q_protein.pdb.gz")
trg_ligs = [_LoadEntity("1r8q_ligand_%d.sdf" % i) for i in range(7)]
# Pass entities
sc = LigandScorer(mdl, trg, mdl_ligs, trg_ligs)
......@@ -114,18 +134,18 @@ class TestLigandScoring(unittest.TestCase):
"""Test that we reject input if multiple ligands with the same chain
name/residue number are given.
"""
mdl = io.LoadPDB(os.path.join('testfiles', "P84080_model_02_nolig.pdb"))
mdl_ligs = [io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf"))]
trg = io.LoadPDB(os.path.join('testfiles', "1r8q_protein.pdb.gz"))
trg_ligs = [io.LoadEntity(os.path.join('testfiles', "1r8q_ligand_%d.sdf" % i)) for i in range(7)]
mdl = _LoadPDB("P84080_model_02_nolig.pdb")
mdl_ligs = [_LoadEntity("P84080_model_02_ligand_0.sdf")]
trg = _LoadPDB("1r8q_protein.pdb.gz")
trg_ligs = [_LoadEntity("1r8q_ligand_%d.sdf" % i) for i in range(7)]
# Reject identical model ligands
with self.assertRaises(RuntimeError):
sc = LigandScorer(mdl, trg, [mdl_ligs[0], mdl_ligs[0]], trg_ligs)
# Reject identical target ligands
lig0 = trg_ligs[0]
lig1 = trg_ligs[1]
lig0 = trg_ligs[0].Copy()
lig1 = trg_ligs[1].Copy()
ed1 = lig1.EditXCS()
ed1.RenameChain(lig1.chains[0], lig0.chains[0].name)
ed1.SetResidueNumber(lig1.residues[0], lig0.residues[0].number)
......@@ -135,7 +155,7 @@ class TestLigandScoring(unittest.TestCase):
def test__ResidueToGraph(self):
"""Test that _ResidueToGraph works as expected
"""
mdl_lig = io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf"))
mdl_lig = _LoadEntity("P84080_model_02_ligand_0.sdf")
graph = ligand_scoring._ResidueToGraph(mdl_lig.residues[0])
assert len(graph.edges) == 34
......@@ -152,8 +172,8 @@ class TestLigandScoring(unittest.TestCase):
def test__ComputeSymmetries(self):
"""Test that _ComputeSymmetries works.
"""
trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True)
mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True)
trg = _LoadMMCIF("1r8q.cif.gz")
mdl = _LoadMMCIF("P84080_model_02.cif.gz")
trg_mg1 = trg.FindResidue("E", 1)
trg_g3d1 = trg.FindResidue("F", 1)
......@@ -168,7 +188,7 @@ class TestLigandScoring(unittest.TestCase):
assert len(sym) == 72
# Test that we can match ions read from SDF
sdf_lig = io.LoadEntity(os.path.join('testfiles', "1r8q_ligand_0.sdf"))
sdf_lig = _LoadEntity("1r8q_ligand_0.sdf")
sym = ligand_scoring._ComputeSymmetries(trg_mg1, sdf_lig.residues[0], by_atom_index=True)
assert len(sym) == 1
......@@ -195,8 +215,8 @@ class TestLigandScoring(unittest.TestCase):
def test_SCRMSD(self):
"""Test that SCRMSD works.
"""
trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True)
mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True)
trg = _LoadMMCIF("1r8q.cif.gz")
mdl = _LoadMMCIF("P84080_model_02.cif.gz")
trg_mg1 = trg.FindResidue("E", 1)
trg_g3d1 = trg.FindResidue("F", 1)
......@@ -240,8 +260,8 @@ class TestLigandScoring(unittest.TestCase):
def test__compute_scores(self):
"""Test that _compute_scores works.
"""
trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True)
mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True)
trg = _LoadMMCIF("1r8q.cif.gz")
mdl = _LoadMMCIF("P84080_model_02.cif.gz")
mdl_lig = io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf"))
sc = LigandScorer(mdl, trg, [mdl_lig], None)
......@@ -273,8 +293,8 @@ class TestLigandScoring(unittest.TestCase):
"""Test check_resname argument works
"""
# 4C0A has mismatching sequence and fails with check_resnames=True
mdl_1r8q, _ = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True)
trg_4c0a, _ = io.LoadMMCIF(os.path.join('testfiles', "4c0a.cif.gz"), seqres=True)
mdl_1r8q = _LoadMMCIF("1r8q.cif.gz")
trg_4c0a = _LoadMMCIF("4c0a.cif.gz")
mdl = mdl_1r8q.Select("cname=D or cname=F")
trg = trg_4c0a.Select("cname=C or cname=I")
......@@ -290,8 +310,8 @@ class TestLigandScoring(unittest.TestCase):
"""Test that the scores are computed correctly
"""
# 4C0A has more ligands
trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True)
trg_4c0a, _ = io.LoadMMCIF(os.path.join('testfiles', "4c0a.cif.gz"), seqres=True)
trg = _LoadMMCIF("1r8q.cif.gz")
trg_4c0a = _LoadMMCIF("4c0a.cif.gz")
sc = LigandScorer(trg, trg_4c0a, None, None, check_resnames=False)
expected_keys = {"J", "F"}
......@@ -344,19 +364,18 @@ class TestLigandScoring(unittest.TestCase):
For RMSD, A: A results in a better chain mapping. However, C: A is a
better global chain mapping from an lDDT perspective (and lDDT-PLI).
"""
trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True)
mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True)
mdl_lig = io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf"))
trg = _LoadMMCIF("1r8q.cif.gz")
mdl = _LoadMMCIF("P84080_model_02.cif.gz")
# Local by default
sc = LigandScorer(mdl, trg, [mdl_lig], None)
assert sc.rmsd_details["00001_L_2"][1]["chain_mapping"] == {'A': 'A'}
assert sc.lddt_pli_details["00001_L_2"][1]["chain_mapping"] == {'C': 'A'}
sc = LigandScorer(mdl, trg, None, None)
assert sc.rmsd_details["L_2"][1]["chain_mapping"] == {'A': 'A'}
assert sc.lddt_pli_details["L_2"][1]["chain_mapping"] == {'C': 'A'}
# Global
sc = LigandScorer(mdl, trg, [mdl_lig], None, global_chain_mapping=True)
assert sc.rmsd_details["00001_L_2"][1]["chain_mapping"] == {'C': 'A'}
assert sc.lddt_pli_details["00001_L_2"][1]["chain_mapping"] == {'C': 'A'}
sc = LigandScorer(mdl, trg, None, None, global_chain_mapping=True)
assert sc.rmsd_details["L_2"][1]["chain_mapping"] == {'C': 'A'}
assert sc.lddt_pli_details["L_2"][1]["chain_mapping"] == {'C': 'A'}
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment