Skip to content
Snippets Groups Projects
Commit 8f0d268d authored by Studer Gabriel's avatar Studer Gabriel
Browse files

Improve compression in OMF

Adds several features that can all be enabled separately:
- DEFAULT_PEPLIB: OMF stores a ResidueDefinition object for each observed
                  unique compound. No need to dump residue definitions for
                  standard amino acids.
- LOSSY: Reduce coordinate accuracy to 0.1
- AVG_BFACTORS: Optimization specific for models, where all atoms of a residue
                have the same bfactor => store one bfactor per residue instead
                of one bfactor per atom
- ROUND_BFACTORS: Round bfactors, i.e. 42.42 => 42.0
- SKIP_SS: Don't dump secondary structure, loaded objects assign COIL to
           each residue
- INFER_PEP_BONDS: Intra-residue bonds are stored in ResidueDefinition objects.
                   Inter-residue bonds need to be stored separately. No need
                   to do that for peptide bonds as this can be inferred on the
                   fly when loading
parent faa63233
No related branches found
No related tags found
No related merge requests found
...@@ -39,9 +39,19 @@ namespace{ ...@@ -39,9 +39,19 @@ namespace{
} }
void export_omf_io() { void export_omf_io() {
enum_<OMF::OMFOption>("OMFOption")
.value("DEFAULT_PEPLIB", OMF::DEFAULT_PEPLIB)
.value("LOSSY", OMF::LOSSY)
.value("AVG_BFACTORS", OMF::AVG_BFACTORS)
.value("ROUND_BFACTORS", OMF::ROUND_BFACTORS)
.value("SKIP_SS", OMF::SKIP_SS)
.value("INFER_PEP_BONDS", OMF::INFER_PEP_BONDS)
;
class_<OMF, OMFPtr>("OMF",no_init) class_<OMF, OMFPtr>("OMF",no_init)
.def("FromEntity", &OMF::FromEntity).staticmethod("FromEntity") .def("FromEntity", &OMF::FromEntity, (arg("ent"), arg("options")=0)).staticmethod("FromEntity")
.def("FromMMCIF", &OMF::FromMMCIF).staticmethod("FromMMCIF") .def("FromMMCIF", &OMF::FromMMCIF, (arg("ent"), arg("mmcif_info"), arg("options")=0)).staticmethod("FromMMCIF")
.def("FromFile", &OMF::FromFile).staticmethod("FromFile") .def("FromFile", &OMF::FromFile).staticmethod("FromFile")
.def("FromBytes", &wrap_from_bytes).staticmethod("FromBytes") .def("FromBytes", &wrap_from_bytes).staticmethod("FromBytes")
.def("ToFile", &OMF::ToFile) .def("ToFile", &OMF::ToFile)
......
This diff is collapsed.
...@@ -103,9 +103,15 @@ struct ChainData { ...@@ -103,9 +103,15 @@ struct ChainData {
const std::vector<int>& inter_residue_bond_orders, const std::vector<int>& inter_residue_bond_orders,
std::unordered_map<long, int>& atom_idx_mapper); std::unordered_map<long, int>& atom_idx_mapper);
void ToStream(std::ostream& stream) const; void ToStream(std::ostream& stream,
const std::vector<ResidueDefinition>& res_def,
bool lossy, bool avg_bfactors, bool round_bfactors,
bool skip_ss) const;
void FromStream(std::istream& stream); void FromStream(std::istream& stream,
const std::vector<ResidueDefinition>& res_def,
bool lossy, bool avg_bfactors, bool round_bfactors,
bool skip_ss);
// chain features // chain features
String ch_name; String ch_name;
...@@ -127,14 +133,39 @@ struct ChainData { ...@@ -127,14 +133,39 @@ struct ChainData {
std::vector<int> bond_orders; std::vector<int> bond_orders;
}; };
class DefaultPepLib{
public:
static DefaultPepLib& Instance() {
static DefaultPepLib instance;
return instance;
}
std::vector<ResidueDefinition> residue_definitions;
private:
DefaultPepLib();
DefaultPepLib(DefaultPepLib const& copy);
DefaultPepLib& operator=(DefaultPepLib const& copy);
};
class OMF { class OMF {
public: public:
static OMFPtr FromEntity(const ost::mol::EntityHandle& ent); enum OMFOption {DEFAULT_PEPLIB = 1, LOSSY = 2, AVG_BFACTORS = 4,
ROUND_BFACTORS = 8, SKIP_SS = 16, INFER_PEP_BONDS = 32};
bool OptionSet(OMFOption opt) const {
return (opt & options_) == opt;
}
static OMFPtr FromEntity(const ost::mol::EntityHandle& ent,
uint8_t options = 0);
static OMFPtr FromMMCIF(const ost::mol::EntityHandle& ent, static OMFPtr FromMMCIF(const ost::mol::EntityHandle& ent,
const MMCifInfo& info); const MMCifInfo& info,
uint8_t options = 0);
static OMFPtr FromFile(const String& fn); static OMFPtr FromFile(const String& fn);
...@@ -152,7 +183,7 @@ public: ...@@ -152,7 +183,7 @@ public:
private: private:
// only construct with static functions // only construct with static functions
OMF() { } OMF(): options_(0) { }
void ToStream(std::ostream& stream) const; void ToStream(std::ostream& stream) const;
...@@ -174,6 +205,9 @@ private: ...@@ -174,6 +205,9 @@ private:
std::vector<String> bond_chain_names_; std::vector<String> bond_chain_names_;
std::vector<int> bond_atoms_; std::vector<int> bond_atoms_;
std::vector<int> bond_orders_; std::vector<int> bond_orders_;
// bitfield with options
uint8_t options_;
}; };
}} //ns }} //ns
......
import unittest import unittest
import math
from ost import geom from ost import geom
from ost import io from ost import io
def compare_atoms(a1, a2): def compare_atoms(a1, a2, occupancy_thresh = 0.01, bfactor_thresh = 0.01,
if abs(a1.occupancy - a2.occupancy) > 0.01: dist_thresh = 0.001):
if abs(a1.occupancy - a2.occupancy) > occupancy_thresh:
return False return False
if abs(a1.b_factor - a2.b_factor) > 0.01: if abs(a1.b_factor - a2.b_factor) > bfactor_thresh:
return False return False
if geom.Distance(a1.GetPos(), a2.GetPos()) > 0.001: if geom.Distance(a1.GetPos(), a2.GetPos()) > dist_thresh:
return False return False
if a1.is_hetatom != a2.is_hetatom: if a1.is_hetatom != a2.is_hetatom:
return False return False
...@@ -15,13 +18,16 @@ def compare_atoms(a1, a2): ...@@ -15,13 +18,16 @@ def compare_atoms(a1, a2):
return False return False
return True return True
def compare_residues(r1, r2): def compare_residues(r1, r2, at_occupancy_thresh = 0.01,
at_bfactor_thresh = 0.01, at_dist_thresh = 0.001,
skip_ss = False):
if r1.GetName() != r2.GetName(): if r1.GetName() != r2.GetName():
return False return False
if r1.GetNumber() != r2.GetNumber(): if r1.GetNumber() != r2.GetNumber():
return False return False
if str(r1.GetSecStructure()) != str(r2.GetSecStructure()): if skip_ss is False:
return False if str(r1.GetSecStructure()) != str(r2.GetSecStructure()):
return False
if r1.one_letter_code != r2.one_letter_code: if r1.one_letter_code != r2.one_letter_code:
return False return False
if r1.chem_type != r2.chem_type: if r1.chem_type != r2.chem_type:
...@@ -36,15 +42,24 @@ def compare_residues(r1, r2): ...@@ -36,15 +42,24 @@ def compare_residues(r1, r2):
for aname in anames: for aname in anames:
a1 = r1.FindAtom(aname) a1 = r1.FindAtom(aname)
a2 = r2.FindAtom(aname) a2 = r2.FindAtom(aname)
if not compare_atoms(a1, a2): if not compare_atoms(a1, a2,
occupancy_thresh = at_occupancy_thresh,
bfactor_thresh = at_bfactor_thresh,
dist_thresh = at_dist_thresh):
return False return False
return True return True
def compare_chains(ch1, ch2): def compare_chains(ch1, ch2, at_occupancy_thresh = 0.01,
at_bfactor_thresh = 0.01, at_dist_thresh = 0.001,
skip_ss=False):
if len(ch1.residues) != len(ch2.residues): if len(ch1.residues) != len(ch2.residues):
return False return False
for r1, r2 in zip(ch1.residues, ch2.residues): for r1, r2 in zip(ch1.residues, ch2.residues):
if not compare_residues(r1, r2): if not compare_residues(r1, r2,
at_occupancy_thresh = at_occupancy_thresh,
at_bfactor_thresh = at_bfactor_thresh,
at_dist_thresh = at_dist_thresh,
skip_ss = skip_ss):
return False return False
return True return True
...@@ -59,7 +74,9 @@ def compare_bonds(ent1, ent2): ...@@ -59,7 +74,9 @@ def compare_bonds(ent1, ent2):
bonds2.append([min(bond_partners), max(bond_partners), b.bond_order]) bonds2.append([min(bond_partners), max(bond_partners), b.bond_order])
return sorted(bonds1) == sorted(bonds2) return sorted(bonds1) == sorted(bonds2)
def compare_ent(ent1, ent2): def compare_ent(ent1, ent2, at_occupancy_thresh = 0.01,
at_bfactor_thresh = 0.01, at_dist_thresh = 0.001,
skip_ss=False):
chain_names_one = [ch.GetName() for ch in ent1.chains] chain_names_one = [ch.GetName() for ch in ent1.chains]
chain_names_two = [ch.GetName() for ch in ent2.chains] chain_names_two = [ch.GetName() for ch in ent2.chains]
if not sorted(chain_names_one) == sorted(chain_names_two): if not sorted(chain_names_one) == sorted(chain_names_two):
...@@ -68,23 +85,124 @@ def compare_ent(ent1, ent2): ...@@ -68,23 +85,124 @@ def compare_ent(ent1, ent2):
for chain_name in chain_names: for chain_name in chain_names:
ch1 = ent1.FindChain(chain_name) ch1 = ent1.FindChain(chain_name)
ch2 = ent2.FindChain(chain_name) ch2 = ent2.FindChain(chain_name)
if not compare_chains(ch1, ch2): if not compare_chains(ch1, ch2,
at_occupancy_thresh = at_occupancy_thresh,
at_bfactor_thresh = at_bfactor_thresh,
at_dist_thresh = at_dist_thresh,
skip_ss=skip_ss):
return False return False
if not compare_bonds(ent1, ent2): if not compare_bonds(ent1, ent2):
return False return False
return True return True
class TestOMF(unittest.TestCase): class TestOMF(unittest.TestCase):
def test_AU(self):
def setUp(self):
ent, seqres, info = io.LoadMMCIF("testfiles/mmcif/3T6C.cif.gz", ent, seqres, info = io.LoadMMCIF("testfiles/mmcif/3T6C.cif.gz",
seqres=True, seqres=True,
info=True) info=True)
omf = io.OMF.FromMMCIF(ent, info) self.ent = ent
self.seqres = seqres
self.info = info
def test_AU(self):
omf = io.OMF.FromMMCIF(self.ent, self.info)
omf_bytes = omf.ToBytes() omf_bytes = omf.ToBytes()
loaded_omf = io.OMF.FromBytes(omf_bytes) loaded_omf = io.OMF.FromBytes(omf_bytes)
loaded_ent = loaded_omf.GetAU() loaded_ent = loaded_omf.GetAU()
self.assertTrue(compare_ent(ent, loaded_ent)) self.assertTrue(compare_ent(self.ent, loaded_ent))
def test_default_peplib(self):
omf = io.OMF.FromMMCIF(self.ent, self.info)
omf_bytes = omf.ToBytes()
omf_def_pep = io.OMF.FromMMCIF(self.ent, self.info,
io.OMFOption.DEFAULT_PEPLIB)
omf_def_pep_bytes = omf_def_pep.ToBytes()
loaded_omf_def_pep = io.OMF.FromBytes(omf_def_pep_bytes)
loaded_ent = loaded_omf_def_pep.GetAU()
self.assertTrue(len(omf_def_pep_bytes) < len(omf_bytes))
self.assertTrue(compare_ent(self.ent, loaded_ent))
def test_lossy(self):
omf = io.OMF.FromMMCIF(self.ent, self.info)
omf_bytes = omf.ToBytes()
omf_lossy = io.OMF.FromMMCIF(self.ent, self.info,
io.OMFOption.LOSSY)
omf_lossy_bytes = omf_lossy.ToBytes()
loaded_omf_lossy = io.OMF.FromBytes(omf_lossy_bytes)
loaded_ent = loaded_omf_lossy.GetAU()
self.assertTrue(len(omf_lossy_bytes) < len(omf_bytes))
self.assertFalse(compare_ent(self.ent, loaded_ent))
max_dist = math.sqrt(3*0.05*0.05)
self.assertTrue(compare_ent(self.ent, loaded_ent,
at_dist_thresh=max_dist))
def test_avg_bfactors(self):
omf = io.OMF.FromMMCIF(self.ent, self.info)
omf_bytes = omf.ToBytes()
omf_avg_bfac = io.OMF.FromMMCIF(self.ent, self.info,
io.OMFOption.AVG_BFACTORS)
omf_avg_bfac_bytes = omf_avg_bfac.ToBytes()
loaded_omf_avg_bfac = io.OMF.FromBytes(omf_avg_bfac_bytes)
loaded_ent = loaded_omf_avg_bfac.GetAU()
self.assertTrue(len(omf_avg_bfac_bytes) < len(omf_bytes))
self.assertFalse(compare_ent(self.ent, loaded_ent))
# just give a huge slack for bfactors and check averaging manually
self.assertTrue(compare_ent(self.ent, loaded_ent,
at_bfactor_thresh=1000))
self.assertEqual(len(self.ent.residues), len(loaded_ent.residues))
for r_ref, r in zip(self.ent.residues, loaded_ent.residues):
exp_bfac = sum([a.b_factor for a in r_ref.atoms])
exp_bfac /= r_ref.atom_count
for a in r.atoms:
self.assertTrue(abs(a.b_factor - exp_bfac) < 0.008)
def test_round_bfactors(self):
omf = io.OMF.FromMMCIF(self.ent, self.info)
omf_bytes = omf.ToBytes()
omf_round_bfac = io.OMF.FromMMCIF(self.ent, self.info,
io.OMFOption.ROUND_BFACTORS)
omf_round_bfac_bytes = omf_round_bfac.ToBytes()
loaded_omf_round_bfac = io.OMF.FromBytes(omf_round_bfac_bytes)
loaded_ent = loaded_omf_round_bfac.GetAU()
self.assertTrue(len(omf_round_bfac_bytes) < len(omf_bytes))
self.assertFalse(compare_ent(self.ent, loaded_ent))
self.assertTrue(compare_ent(self.ent, loaded_ent,
at_bfactor_thresh=0.5))
def test_skip_ss(self):
omf = io.OMF.FromMMCIF(self.ent, self.info)
omf_bytes = omf.ToBytes()
omf_skip_ss = io.OMF.FromMMCIF(self.ent, self.info,
io.OMFOption.SKIP_SS)
omf_skip_ss_bytes = omf_skip_ss.ToBytes()
loaded_omf_skip_ss = io.OMF.FromBytes(omf_skip_ss_bytes)
loaded_ent = loaded_omf_skip_ss.GetAU()
self.assertTrue(len(omf_skip_ss_bytes) < len(omf_bytes))
self.assertFalse(compare_ent(self.ent, loaded_ent))
self.assertTrue(compare_ent(self.ent, loaded_ent, skip_ss=True))
def test_infer_pep_bonds(self):
omf = io.OMF.FromMMCIF(self.ent, self.info)
omf_bytes = omf.ToBytes()
omf_infer_pep_bonds = io.OMF.FromMMCIF(self.ent, self.info,
io.OMFOption.INFER_PEP_BONDS)
omf_infer_pep_bonds_bytes = omf_infer_pep_bonds.ToBytes()
loaded_omf_infer_pep_bonds = io.OMF.FromBytes(omf_infer_pep_bonds_bytes)
loaded_ent = loaded_omf_infer_pep_bonds.GetAU()
self.assertTrue(len(omf_infer_pep_bonds_bytes) < len(omf_bytes))
self.assertTrue(compare_ent(self.ent, loaded_ent))
if __name__== '__main__': if __name__== '__main__':
from ost import testutils from ost import testutils
testutils.RunTests() if testutils.SetDefaultCompoundLib():
testutils.RunTests()
else:
print('No compound library available. Ignoring test_stereochemistry.py tests.')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment