Skip to content
Snippets Groups Projects
Commit 8f0d268d authored by Studer Gabriel's avatar Studer Gabriel
Browse files

Improve compression in OMF

Adds several features that can all be enabled separately:
- DEFAULT_PEPLIB: OMF stores a ResidueDefinition object for each observed
                  unique compound. No need to dump residue definitions for
                  standard amino acids.
- LOSSY: Reduce coordinate accuracy to 0.1
- AVG_BFACTORS: Optimization specific for models, where all atoms of a residue
                have the same bfactor => store one bfactor per residue instead
                of one bfactor per atom
- ROUND_BFACTORS: Round bfactors to the nearest integer, e.g. 42.42 => 42.0
- SKIP_SS: Don't dump secondary structure; loaded objects assign COIL to
           each residue
- INFER_PEP_BONDS: Intra-residue bonds are stored in ResidueDefinition objects.
                   Inter-residue bonds need to be stored separately. There is
                   no need to do that for peptide bonds, as these can be
                   inferred on the fly when loading.
parent faa63233
No related branches found
No related tags found
No related merge requests found
......@@ -39,9 +39,19 @@ namespace{
}
void export_omf_io() {
enum_<OMF::OMFOption>("OMFOption")
.value("DEFAULT_PEPLIB", OMF::DEFAULT_PEPLIB)
.value("LOSSY", OMF::LOSSY)
.value("AVG_BFACTORS", OMF::AVG_BFACTORS)
.value("ROUND_BFACTORS", OMF::ROUND_BFACTORS)
.value("SKIP_SS", OMF::SKIP_SS)
.value("INFER_PEP_BONDS", OMF::INFER_PEP_BONDS)
;
class_<OMF, OMFPtr>("OMF",no_init)
.def("FromEntity", &OMF::FromEntity).staticmethod("FromEntity")
.def("FromMMCIF", &OMF::FromMMCIF).staticmethod("FromMMCIF")
.def("FromEntity", &OMF::FromEntity, (arg("ent"), arg("options")=0)).staticmethod("FromEntity")
.def("FromMMCIF", &OMF::FromMMCIF, (arg("ent"), arg("mmcif_info"), arg("options")=0)).staticmethod("FromMMCIF")
.def("FromFile", &OMF::FromFile).staticmethod("FromFile")
.def("FromBytes", &wrap_from_bytes).staticmethod("FromBytes")
.def("ToFile", &OMF::ToFile)
......
This diff is collapsed.
......@@ -103,9 +103,15 @@ struct ChainData {
const std::vector<int>& inter_residue_bond_orders,
std::unordered_map<long, int>& atom_idx_mapper);
void ToStream(std::ostream& stream) const;
void ToStream(std::ostream& stream,
const std::vector<ResidueDefinition>& res_def,
bool lossy, bool avg_bfactors, bool round_bfactors,
bool skip_ss) const;
void FromStream(std::istream& stream);
void FromStream(std::istream& stream,
const std::vector<ResidueDefinition>& res_def,
bool lossy, bool avg_bfactors, bool round_bfactors,
bool skip_ss);
// chain features
String ch_name;
......@@ -127,14 +133,39 @@ struct ChainData {
std::vector<int> bond_orders;
};
// Singleton giving access to a shared list of ResidueDefinition objects.
// NOTE(review): presumably these are the definitions of the standard amino
// acids that the DEFAULT_PEPLIB option avoids dumping - confirm against the
// constructor, which is defined outside of this block.
class DefaultPepLib{
public:
  // Returns the one shared instance; it is created on the first call.
  static DefaultPepLib& Instance() {
    static DefaultPepLib instance;
    return instance;
  }

  // residue definitions provided by the library
  std::vector<ResidueDefinition> residue_definitions;

private:
  // only construct through Instance()
  DefaultPepLib();
  // Copying is explicitly deleted so that any attempt to duplicate the
  // singleton is rejected at compile time instead of at link time.
  DefaultPepLib(DefaultPepLib const& copy) = delete;
  DefaultPepLib& operator=(DefaultPepLib const& copy) = delete;
};
class OMF {
public:
static OMFPtr FromEntity(const ost::mol::EntityHandle& ent);
enum OMFOption {DEFAULT_PEPLIB = 1, LOSSY = 2, AVG_BFACTORS = 4,
ROUND_BFACTORS = 8, SKIP_SS = 16, INFER_PEP_BONDS = 32};
bool OptionSet(OMFOption opt) const {
return (opt & options_) == opt;
}
static OMFPtr FromEntity(const ost::mol::EntityHandle& ent,
uint8_t options = 0);
static OMFPtr FromMMCIF(const ost::mol::EntityHandle& ent,
const MMCifInfo& info);
const MMCifInfo& info,
uint8_t options = 0);
static OMFPtr FromFile(const String& fn);
......@@ -152,7 +183,7 @@ public:
private:
// only construct with static functions
OMF() { }
OMF(): options_(0) { }
void ToStream(std::ostream& stream) const;
......@@ -174,6 +205,9 @@ private:
std::vector<String> bond_chain_names_;
std::vector<int> bond_atoms_;
std::vector<int> bond_orders_;
// bitfield with options
uint8_t options_;
};
}} //ns
......
import unittest
import math
from ost import geom
from ost import io
def compare_atoms(a1, a2):
if abs(a1.occupancy - a2.occupancy) > 0.01:
def compare_atoms(a1, a2, occupancy_thresh = 0.01, bfactor_thresh = 0.01,
dist_thresh = 0.001):
if abs(a1.occupancy - a2.occupancy) > occupancy_thresh:
return False
if abs(a1.b_factor - a2.b_factor) > 0.01:
if abs(a1.b_factor - a2.b_factor) > bfactor_thresh:
return False
if geom.Distance(a1.GetPos(), a2.GetPos()) > 0.001:
if geom.Distance(a1.GetPos(), a2.GetPos()) > dist_thresh:
return False
if a1.is_hetatom != a2.is_hetatom:
return False
......@@ -15,13 +18,16 @@ def compare_atoms(a1, a2):
return False
return True
def compare_residues(r1, r2):
def compare_residues(r1, r2, at_occupancy_thresh = 0.01,
at_bfactor_thresh = 0.01, at_dist_thresh = 0.001,
skip_ss = False):
if r1.GetName() != r2.GetName():
return False
if r1.GetNumber() != r2.GetNumber():
return False
if str(r1.GetSecStructure()) != str(r2.GetSecStructure()):
return False
if skip_ss is False:
if str(r1.GetSecStructure()) != str(r2.GetSecStructure()):
return False
if r1.one_letter_code != r2.one_letter_code:
return False
if r1.chem_type != r2.chem_type:
......@@ -36,15 +42,24 @@ def compare_residues(r1, r2):
for aname in anames:
a1 = r1.FindAtom(aname)
a2 = r2.FindAtom(aname)
if not compare_atoms(a1, a2):
if not compare_atoms(a1, a2,
occupancy_thresh = at_occupancy_thresh,
bfactor_thresh = at_bfactor_thresh,
dist_thresh = at_dist_thresh):
return False
return True
def compare_chains(ch1, ch2):
def compare_chains(ch1, ch2, at_occupancy_thresh = 0.01,
at_bfactor_thresh = 0.01, at_dist_thresh = 0.001,
skip_ss=False):
if len(ch1.residues) != len(ch2.residues):
return False
for r1, r2 in zip(ch1.residues, ch2.residues):
if not compare_residues(r1, r2):
if not compare_residues(r1, r2,
at_occupancy_thresh = at_occupancy_thresh,
at_bfactor_thresh = at_bfactor_thresh,
at_dist_thresh = at_dist_thresh,
skip_ss = skip_ss):
return False
return True
......@@ -59,7 +74,9 @@ def compare_bonds(ent1, ent2):
bonds2.append([min(bond_partners), max(bond_partners), b.bond_order])
return sorted(bonds1) == sorted(bonds2)
def compare_ent(ent1, ent2):
def compare_ent(ent1, ent2, at_occupancy_thresh = 0.01,
at_bfactor_thresh = 0.01, at_dist_thresh = 0.001,
skip_ss=False):
chain_names_one = [ch.GetName() for ch in ent1.chains]
chain_names_two = [ch.GetName() for ch in ent2.chains]
if not sorted(chain_names_one) == sorted(chain_names_two):
......@@ -68,23 +85,124 @@ def compare_ent(ent1, ent2):
for chain_name in chain_names:
ch1 = ent1.FindChain(chain_name)
ch2 = ent2.FindChain(chain_name)
if not compare_chains(ch1, ch2):
if not compare_chains(ch1, ch2,
at_occupancy_thresh = at_occupancy_thresh,
at_bfactor_thresh = at_bfactor_thresh,
at_dist_thresh = at_dist_thresh,
skip_ss=skip_ss):
return False
if not compare_bonds(ent1, ent2):
return False
return True
class TestOMF(unittest.TestCase):
def test_AU(self):
def setUp(self):
ent, seqres, info = io.LoadMMCIF("testfiles/mmcif/3T6C.cif.gz",
seqres=True,
info=True)
omf = io.OMF.FromMMCIF(ent, info)
self.ent = ent
self.seqres = seqres
self.info = info
def test_AU(self):
omf = io.OMF.FromMMCIF(self.ent, self.info)
omf_bytes = omf.ToBytes()
loaded_omf = io.OMF.FromBytes(omf_bytes)
loaded_ent = loaded_omf.GetAU()
self.assertTrue(compare_ent(ent, loaded_ent))
self.assertTrue(compare_ent(self.ent, loaded_ent))
def test_default_peplib(self):
  """DEFAULT_PEPLIB must give a smaller byte dump than the default settings
  while the entity loaded from it still compares equal to the input."""
  omf = io.OMF.FromMMCIF(self.ent, self.info)
  omf_bytes = omf.ToBytes()
  omf_def_pep = io.OMF.FromMMCIF(self.ent, self.info,
                                 io.OMFOption.DEFAULT_PEPLIB)
  omf_def_pep_bytes = omf_def_pep.ToBytes()
  loaded_omf_def_pep = io.OMF.FromBytes(omf_def_pep_bytes)
  loaded_ent = loaded_omf_def_pep.GetAU()
  # assertLess reports both sizes if the check fails
  self.assertLess(len(omf_def_pep_bytes), len(omf_bytes))
  self.assertTrue(compare_ent(self.ent, loaded_ent))
def test_lossy(self):
  """LOSSY must give a smaller byte dump; the loaded coordinates differ from
  the input but stay within the expected rounding error."""
  omf = io.OMF.FromMMCIF(self.ent, self.info)
  omf_bytes = omf.ToBytes()
  omf_lossy = io.OMF.FromMMCIF(self.ent, self.info,
                               io.OMFOption.LOSSY)
  omf_lossy_bytes = omf_lossy.ToBytes()
  loaded_omf_lossy = io.OMF.FromBytes(omf_lossy_bytes)
  loaded_ent = loaded_omf_lossy.GetAU()
  # assertLess reports both sizes if the check fails
  self.assertLess(len(omf_lossy_bytes), len(omf_bytes))
  # with the default distance threshold the rounding must be noticeable
  self.assertFalse(compare_ent(self.ent, loaded_ent))
  # an accuracy of 0.1 shifts each coordinate by at most 0.05, so an atom
  # moves by at most the diagonal of a cube with edge length 0.05
  max_dist = math.sqrt(3*0.05*0.05)
  self.assertTrue(compare_ent(self.ent, loaded_ent,
                              at_dist_thresh=max_dist))
def test_avg_bfactors(self):
  """AVG_BFACTORS must give a smaller byte dump and assign every atom the
  mean bfactor of the atoms of its residue."""
  omf = io.OMF.FromMMCIF(self.ent, self.info)
  omf_bytes = omf.ToBytes()
  omf_avg_bfac = io.OMF.FromMMCIF(self.ent, self.info,
                                  io.OMFOption.AVG_BFACTORS)
  omf_avg_bfac_bytes = omf_avg_bfac.ToBytes()
  loaded_omf_avg_bfac = io.OMF.FromBytes(omf_avg_bfac_bytes)
  loaded_ent = loaded_omf_avg_bfac.GetAU()
  # assertLess reports both values if a check fails
  self.assertLess(len(omf_avg_bfac_bytes), len(omf_bytes))
  self.assertFalse(compare_ent(self.ent, loaded_ent))
  # just give a huge slack for bfactors and check averaging manually
  self.assertTrue(compare_ent(self.ent, loaded_ent,
                              at_bfactor_thresh=1000))
  self.assertEqual(len(self.ent.residues), len(loaded_ent.residues))
  for r_ref, r in zip(self.ent.residues, loaded_ent.residues):
    # expected value: mean bfactor over the atoms of the reference residue
    exp_bfac = sum(a.b_factor for a in r_ref.atoms)
    exp_bfac /= r_ref.atom_count
    for a in r.atoms:
      self.assertLess(abs(a.b_factor - exp_bfac), 0.008)
def test_round_bfactors(self):
  """ROUND_BFACTORS must give a smaller byte dump; the loaded bfactors differ
  from the input by no more than 0.5."""
  omf = io.OMF.FromMMCIF(self.ent, self.info)
  omf_bytes = omf.ToBytes()
  omf_round_bfac = io.OMF.FromMMCIF(self.ent, self.info,
                                    io.OMFOption.ROUND_BFACTORS)
  omf_round_bfac_bytes = omf_round_bfac.ToBytes()
  loaded_omf_round_bfac = io.OMF.FromBytes(omf_round_bfac_bytes)
  loaded_ent = loaded_omf_round_bfac.GetAU()
  # assertLess reports both sizes if the check fails
  self.assertLess(len(omf_round_bfac_bytes), len(omf_bytes))
  # with the default bfactor threshold the rounding must be noticeable
  self.assertFalse(compare_ent(self.ent, loaded_ent))
  # rounding changes a bfactor by at most 0.5
  self.assertTrue(compare_ent(self.ent, loaded_ent,
                              at_bfactor_thresh=0.5))
def test_skip_ss(self):
  """SKIP_SS must give a smaller byte dump; the loaded entity only matches
  the input if secondary structure is excluded from the comparison."""
  omf = io.OMF.FromMMCIF(self.ent, self.info)
  omf_bytes = omf.ToBytes()
  omf_skip_ss = io.OMF.FromMMCIF(self.ent, self.info,
                                 io.OMFOption.SKIP_SS)
  omf_skip_ss_bytes = omf_skip_ss.ToBytes()
  loaded_omf_skip_ss = io.OMF.FromBytes(omf_skip_ss_bytes)
  loaded_ent = loaded_omf_skip_ss.GetAU()
  # assertLess reports both sizes if the check fails
  self.assertLess(len(omf_skip_ss_bytes), len(omf_bytes))
  # secondary structure is not stored, so the full comparison must fail
  self.assertFalse(compare_ent(self.ent, loaded_ent))
  self.assertTrue(compare_ent(self.ent, loaded_ent, skip_ss=True))
def test_infer_pep_bonds(self):
  """INFER_PEP_BONDS must give a smaller byte dump while the entity loaded
  from it, including its bonds, still compares equal to the input."""
  omf = io.OMF.FromMMCIF(self.ent, self.info)
  omf_bytes = omf.ToBytes()
  omf_infer_pep_bonds = io.OMF.FromMMCIF(self.ent, self.info,
                                         io.OMFOption.INFER_PEP_BONDS)
  omf_infer_pep_bonds_bytes = omf_infer_pep_bonds.ToBytes()
  loaded_omf_infer_pep_bonds = io.OMF.FromBytes(omf_infer_pep_bonds_bytes)
  loaded_ent = loaded_omf_infer_pep_bonds.GetAU()
  # assertLess reports both sizes if the check fails
  self.assertLess(len(omf_infer_pep_bonds_bytes), len(omf_bytes))
  self.assertTrue(compare_ent(self.ent, loaded_ent))
if __name__== '__main__':
from ost import testutils
testutils.RunTests()
from ost import testutils
# only run the tests if a default compound library could be set
if testutils.SetDefaultCompoundLib():
  testutils.RunTests()
else:
  # the previous message named test_stereochemistry.py, which is not this file
  print('No compound library available. Ignoring OMF tests.')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment