From 9b144237890bdacbac72adcade142b588f166082 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Fri, 21 Jul 2023 00:29:35 +0200
Subject: [PATCH] omf: update rotamer/oxygen compression

Only perform compression if reconstruction error is < 0.5A. Info whether
compression is performed is stored in a bitfield.
---
 modules/io/pymod/export_omf_io.cc |   1 -
 modules/io/src/mol/omf.cc         | 274 ++++++++++++++++++++++--------
 modules/io/src/mol/omf.hh         |   7 +-
 3 files changed, 206 insertions(+), 76 deletions(-)

diff --git a/modules/io/pymod/export_omf_io.cc b/modules/io/pymod/export_omf_io.cc
index d2df0e475..43355a4a9 100644
--- a/modules/io/pymod/export_omf_io.cc
+++ b/modules/io/pymod/export_omf_io.cc
@@ -61,7 +61,6 @@ void export_omf_io() {
     .value("ROUND_BFACTORS", OMF::ROUND_BFACTORS)
     .value("SKIP_SS", OMF::SKIP_SS)
     .value("INFER_PEP_BONDS", OMF::INFER_PEP_BONDS)
-    .value("INFER_AA_POS", OMF::INFER_AA_POS)
   ;
 
   class_<OMF, OMFPtr>("OMF",no_init)
diff --git a/modules/io/src/mol/omf.cc b/modules/io/src/mol/omf.cc
index afe2b9a71..647f4f96f 100644
--- a/modules/io/src/mol/omf.cc
+++ b/modules/io/src/mol/omf.cc
@@ -388,14 +388,14 @@ namespace{
             std::map<String, ost::io::ChainDataPtr>& map,
             const std::vector<ost::io::ResidueDefinition>& res_def,
             int version, bool lossy, bool avg_bfactors, bool round_bfactors,
-            bool skip_ss, bool infer_aa_pos) {
+            bool skip_ss) {
     uint32_t size;
     stream.read(reinterpret_cast<char*>(&size), sizeof(uint32_t));
     map.clear();
     for(uint i = 0; i < size; ++i) {
       ost::io::ChainDataPtr p(new ost::io::ChainData);
       p->FromStream(stream, res_def, version, lossy, avg_bfactors,
-                    round_bfactors, skip_ss, infer_aa_pos);
+                    round_bfactors, skip_ss);
       map[p->ch_name] = p;
     }
   }
@@ -404,14 +404,14 @@ namespace{
             const std::map<String, ost::io::ChainDataPtr>& map,
             const std::vector<ost::io::ResidueDefinition>& res_def,
             bool lossy, bool avg_bfactors, bool round_bfactors,
-            bool skip_ss, bool infer_aa_pos) {
+            bool skip_ss) {
     uint32_t size = map.size();
     stream.write(reinterpret_cast<char*>(&size), sizeof(uint32_t));
     for(auto it = map.begin(); it != map.end(); ++it) {
         // we don't dump the key (chain name), that's an attribute of the
         // chain itself anyway
       it->second->ToStream(stream, res_def, lossy, avg_bfactors,
-                           round_bfactors, skip_ss, infer_aa_pos); 
+                           round_bfactors, skip_ss); 
     }
   }
 
@@ -548,6 +548,33 @@ namespace{
     }
   }
 
+  void Dump(std::ostream& stream,
+            const std::vector<bool>& vec) {
+    uint32_t size = vec.size();
+    uint32_t n_bytes = std::ceil(static_cast<Real>(size)/8);
+    std::vector<uint8_t> bit_vector(n_bytes, 0);
+    for(size_t i = 0; i < size; ++i) {
+      if(vec[i]) bit_vector[i/8] += (1 << (i%8));
+    }
+    stream.write(reinterpret_cast<char*>(&size), sizeof(uint32_t));
+    stream.write(reinterpret_cast<char*>(&bit_vector[0]),
+                                         n_bytes * sizeof(uint8_t));
+  }
+
+  void Load(std::istream& stream,
+            std::vector<bool>& vec) {
+    uint32_t size;
+    stream.read(reinterpret_cast<char*>(&size), sizeof(uint32_t));
+    uint32_t n_bytes = std::ceil(static_cast<Real>(size)/8);
+    std::vector<uint8_t> bit_vector(n_bytes);
+    stream.read(reinterpret_cast<char*>(&bit_vector[0]),
+                                        n_bytes * sizeof(uint8_t));
+    vec.resize(size);
+    for(uint i = 0; i < size; ++i) {
+      vec[i] = static_cast<bool>(bit_vector[i/8] & (1 << (i%8)));
+    }
+  }
+
   void Dump(std::ostream& stream,
             const std::vector<ost::io::BioUnitDefinition>& vec) {
     uint32_t size = vec.size();
@@ -1185,8 +1212,7 @@ ChainData::ChainData(const ost::mol::ChainHandle& chain,
 void ChainData::ToStream(std::ostream& stream,
                          const std::vector<ResidueDefinition>& res_def,
                          bool lossy, bool avg_bfactors,
-                         bool round_bfactors, bool skip_ss,
-                         bool infer_aa_pos) const {
+                         bool round_bfactors, bool skip_ss) const {
   Dump(stream, ch_name);
   if(chain_type > std::numeric_limits<int8_t>::max()) {
     throw ost::Error("ChainType out of bounds");
@@ -1218,52 +1244,142 @@ void ChainData::ToStream(std::ostream& stream,
     DumpBFactors(stream, bfactors, round_bfactors);
   }
 
-  if(infer_aa_pos) {
+  if(lossy) {
     geom::Vec3List positions_to_dump;
-    std::vector<Real> chi_angles;
+    std::vector<Real> pep_chi_angles;
+    std::vector<bool> pep_oxygen_compression;
+    std::vector<bool> pep_rotamer_compression;
     positions_to_dump.reserve(positions.size());
     int res_start_idx = 0;
     int n_res = res_def_indices.size();
     for(int res_idx = 0; res_idx < n_res; ++res_idx) {
       const ResidueDefinition& def = res_def[res_def_indices[res_idx]];
-      std::set<int> skip_indices = def.GetRotamericAtoms();
-      int o_idx = -1;
+      std::set<int> skip_indices;
+      int res_n_atoms = def.anames.size();
       if(def.chem_type == 'A') {
+        pep_oxygen_compression.push_back(false);
+        pep_rotamer_compression.push_back(false);
+
         // can reconstruct O if there is CA, C, no OXT, its not the last
-        // residue (res_idx < res_def_indices.size()-1) and the next residue is
+        // residue (res_idx < res_def_indices.size()-1), the next residue is
         // an amino acid too and has N.
-        if(def.GetIdx("CA") != -1 && def.GetIdx("C") != -1 &&
-           def.GetIdx("OXT") == -1 && res_idx < n_res-1) {
+        int ca_idx = def.GetIdx("CA");
+        int c_idx = def.GetIdx("C");
+        int o_idx = def.GetIdx("O");
+        int oxt_idx = def.GetIdx("OXT");
+        if(ca_idx != -1 && c_idx != -1 && oxt_idx == -1 && res_idx < n_res-1 &&
+           o_idx != -1) {
           const ResidueDefinition next_def = res_def[res_def_indices[res_idx+1]];
-          if(next_def.chem_type == 'A' && next_def.GetIdx("N") != -1) {
-            o_idx = def.GetIdx("O");
+          int n_idx_next = next_def.GetIdx("N");
+          if(next_def.chem_type == 'A' && n_idx_next != -1) {
+            // compute error when anchor atoms have reduced accuracy
+            geom::Vec3 ca_pos = positions[res_start_idx + ca_idx];
+            geom::Vec3 c_pos = positions[res_start_idx + c_idx];
+            geom::Vec3 n_pos = positions[res_start_idx + res_n_atoms +
+                                         n_idx_next];
+            geom::Vec3 o_pos = positions[res_start_idx + o_idx];
+            ca_pos[0] = 0.1*std::round(ca_pos[0]*10);
+            ca_pos[1] = 0.1*std::round(ca_pos[1]*10);
+            ca_pos[2] = 0.1*std::round(ca_pos[2]*10);
+            c_pos[0] = 0.1*std::round(c_pos[0]*10);
+            c_pos[1] = 0.1*std::round(c_pos[1]*10);
+            c_pos[2] = 0.1*std::round(c_pos[2]*10);
+            n_pos[0] = 0.1*std::round(n_pos[0]*10);
+            n_pos[1] = 0.1*std::round(n_pos[1]*10);
+            n_pos[2] = 0.1*std::round(n_pos[2]*10);
+            geom::Vec3 reconstructed_o_pos;
+            ConstructOPos(ca_pos, c_pos, n_pos, reconstructed_o_pos);
+            if(geom::Length2(reconstructed_o_pos - o_pos) <= Real(0.25)) {
+              pep_oxygen_compression.back() = true;
+              skip_indices.insert(o_idx);
+            }
+          }
+        }
+        if(!def.GetRotamericAtoms().empty()) {
+          std::vector<geom::Vec3> res_pos(positions.begin() + res_start_idx,
+                                          positions.begin() + res_start_idx +
+                                          res_n_atoms);
+          std::vector<geom::Vec3> comp_res_pos;
+          for(auto it = res_pos.begin(); it != res_pos.end(); ++it) {
+            int x = std::round((*it)[0]*10);
+            int y = std::round((*it)[1]*10);
+            int z = std::round((*it)[2]*10);
+            comp_res_pos.push_back(geom::Vec3(0.1*x, 0.1*y, 0.1*z));
+          }
+
+          std::vector<Real> angles;
+          const std::vector<ChiDefinition>& chi_defs = def.GetChiDefinitions();
+          for(auto it = chi_defs.begin(); it != chi_defs.end(); ++it) {
+            angles.push_back(geom::DihedralAngle(res_pos[it->idx_one],
+                                                 res_pos[it->idx_two],
+                                                 res_pos[it->idx_three],
+                                                 res_pos[it->idx_four]));
+          }
+          std::vector<Real> comp_angles;
+          for(auto it = angles.begin(); it != angles.end(); ++it) {
+            int tmp = std::round((*it + M_PI)/(2*M_PI)*255);
+            comp_angles.push_back(static_cast<Real>(tmp)/255*2*M_PI-M_PI);
+          }
+
+          const std::vector<SidechainAtomRule>& at_rules =
+          def.GetSidechainAtomRules();
+          for(auto it = at_rules.begin(); it != at_rules.end(); ++it) {
+            Real dihedral = it->base_dihedral;
+            if(it->dihedral_idx != 4) {
+              dihedral += comp_angles[it->dihedral_idx];
+            }
+            ConstructAtomPos(comp_res_pos[it->anchor_idx[0]],
+                             comp_res_pos[it->anchor_idx[1]],
+                             comp_res_pos[it->anchor_idx[2]],
+                             it->bond_length, it->angle, dihedral,
+                             comp_res_pos[it->sidechain_atom_idx]);
+          }
+          Real max_d = 0.0;
+          for(size_t i = 0; i < res_pos.size(); ++i) {
+            max_d = std::max(max_d, geom::Length2(res_pos[i]-
+                                                  comp_res_pos[i]));
+          }
+          if(std::sqrt(max_d) <= Real(0.5)) {
+            pep_rotamer_compression.back() = true;
+            for(auto it = at_rules.begin(); it != at_rules.end(); ++it) {
+              skip_indices.insert(it->sidechain_atom_idx);
+            }
+            pep_chi_angles.insert(pep_chi_angles.end(), angles.begin(),
+                                  angles.end());
           }
         }
       }
-      if(o_idx != -1) {
-        skip_indices.insert(o_idx);
-      }
-      int n_atoms = res_def[res_def_indices[res_idx]].anames.size();
-      for(int a_idx = 0; a_idx < n_atoms; ++a_idx) {
+      for(int a_idx = 0; a_idx < res_n_atoms; ++a_idx) {
         // skips atoms in skip_indices
         if(skip_indices.find(a_idx) == skip_indices.end()) {
           positions_to_dump.push_back(positions[res_start_idx+a_idx]);
         }
       }
-      const std::vector<ChiDefinition>& chi_vec = def.GetChiDefinitions();
-      for(auto it = chi_vec.begin(); it != chi_vec.end(); ++it) {
-        Real a = geom::DihedralAngle(positions[res_start_idx + it->idx_one],
-                                     positions[res_start_idx + it->idx_two],
-                                     positions[res_start_idx + it->idx_three],
-                                     positions[res_start_idx + it->idx_four]);
-        chi_angles.push_back(a);
-      }
-      res_start_idx += n_atoms;
+      res_start_idx += res_n_atoms;
+    }
+    int8_t flags = 0;
+    if(!pep_chi_angles.empty()) {
+      flags += 1;
+    }
+    if(!pep_oxygen_compression.empty()) {
+      flags += 2;
+    }
+    if(!pep_rotamer_compression.empty()) {
+      flags += 4;
+    }
+    stream.write(reinterpret_cast<char*>(&flags), sizeof(uint8_t));
+    if(!pep_chi_angles.empty()) {
+      DumpDihedrals(stream, pep_chi_angles);
     }
-    DumpPositions(stream, positions_to_dump, lossy);
-    DumpDihedrals(stream, chi_angles);
+    if(!pep_oxygen_compression.empty()) {
+      Dump(stream, pep_oxygen_compression);
+    }
+    if(!pep_rotamer_compression.empty()) {
+      Dump(stream, pep_rotamer_compression);
+    }
+    DumpPositions(stream, positions_to_dump, true);
   } else {
-    DumpPositions(stream, positions, lossy);
+    DumpPositions(stream, positions, false);
   }
   DumpBonds(stream, bonds);
   DumpBondOrders(stream, bond_orders);
@@ -1275,8 +1391,7 @@ void ChainData::ToStream(std::ostream& stream,
 void ChainData::FromStream(std::istream& stream,
                            const std::vector<ResidueDefinition>& res_def,
                            int version, bool lossy, bool avg_bfactors,
-                           bool round_bfactors, bool skip_ss,
-                           bool infer_aa_pos) {
+                           bool round_bfactors, bool skip_ss) {
   
   Load(stream, ch_name);
   if(version >= 2) {
@@ -1299,10 +1414,30 @@ void ChainData::FromStream(std::istream& stream,
   } else {
     LoadBFactors(stream, bfactors, round_bfactors);
   }
-  LoadPositions(stream, positions, lossy);
-  if(infer_aa_pos) {
-    std::vector<Real> chi_angles;
-    LoadDihedrals(stream, chi_angles);
+  
+  if(lossy) {
+    std::vector<Real> pep_chi_angles;
+    std::vector<bool> pep_oxygen_compression;
+    std::vector<bool> pep_rotamer_compression;
+    int8_t flags = 0;
+    stream.read(reinterpret_cast<char*>(&flags), sizeof(uint8_t));
+    if(flags & 1) {
+      LoadDihedrals(stream, pep_chi_angles);
+    }
+    if(flags & 2) {
+      Load(stream, pep_oxygen_compression);
+    }
+    if(flags & 4) {
+      Load(stream, pep_rotamer_compression);
+    }
+
+    if(!pep_oxygen_compression.empty()) {
+      flags += 2;
+    }
+    if(!pep_rotamer_compression.empty()) {
+      flags += 4;
+    }
+    LoadPositions(stream, positions, true);
 
     int n_res = res_def_indices.size();
     int n_at = 0;
@@ -1310,52 +1445,49 @@ void ChainData::FromStream(std::istream& stream,
       n_at += res_def[*it].anames.size();
     }
     geom::Vec3List full_positions(n_at);
+    std::vector<bool> infer_pep_oxygen(n_res, false);
+    std::vector<bool> infer_pep_rotamer(n_res, false);
 
     int pos_idx = 0;
     int full_pos_idx = 0;
-    std::vector<bool> infer_rotamer(n_res, false);
-    std::vector<bool> infer_oxygen(n_res, false);
+    int pep_oxygen_compression_idx = 0;
+    int pep_rotamer_compression_idx = 0;
     for(int res_idx = 0; res_idx < n_res; ++res_idx) {
       const ResidueDefinition& def = res_def[res_def_indices[res_idx]];
       int n_res_at = def.anames.size();
-      std::set<int> inferred_indices = def.GetRotamericAtoms();
-      if(!inferred_indices.empty()) {
-        infer_rotamer[res_idx] = true;
-      }
       if(def.chem_type == 'A') {
-        // can reconstruct O if there is CA, C no OXT, its not the last residue
-        // (res_idx < res_def.size()-1) and the next residue is an amino acid
-        // too and has N.
-        int ca_idx = def.GetIdx("CA");
-        int c_idx = def.GetIdx("C");
-        int o_idx = def.GetIdx("O");
-        int oxt_idx = def.GetIdx("OXT");
-        if(o_idx != -1 && ca_idx != -1 && c_idx != -1 && oxt_idx == -1 &&
-           res_idx < n_res-1) {
-          const ResidueDefinition& next = res_def[res_def_indices[res_idx+1]];
-          if(next.chem_type == 'A' && next.GetIdx("N") != -1) {
-            inferred_indices.insert(o_idx);
-            infer_oxygen[res_idx] = true;
+        std::set<int> inferred_indices;
+        if(pep_oxygen_compression[pep_oxygen_compression_idx++]) {
+          inferred_indices.insert(def.GetIdx("O"));
+          infer_pep_oxygen[res_idx] = true;
+        }
+        if(pep_rotamer_compression[pep_rotamer_compression_idx++]) {
+          inferred_indices.insert(def.rotameric_atoms.begin(),
+                                  def.rotameric_atoms.end());
+          infer_pep_rotamer[res_idx] = true;
+        }
+        for(int i = 0; i < n_res_at; ++i) {
+          if(inferred_indices.find(i) == inferred_indices.end()) {
+            full_positions[full_pos_idx++] = positions[pos_idx++];
+          } else {
+            ++full_pos_idx; // skip
           }
         }
-      }
-      // transfer positions
-      for(int i = 0; i < n_res_at; ++i) {
-        if(inferred_indices.find(i) == inferred_indices.end()) {
+      } else {
+        // transfer all positions
+        for(int i = 0; i < n_res_at; ++i) {
           full_positions[full_pos_idx++] = positions[pos_idx++];
-        } else {
-          ++full_pos_idx; // skip
         }
       }
     }
 
     // infer
     int start_idx = 0;
-    int chi_idx = 0;
+    int pep_chi_angles_idx = 0;
     for(int res_idx = 0; res_idx < n_res; ++res_idx) {
       const ResidueDefinition& def = res_def[res_def_indices[res_idx]];
       int n_res_atoms = def.anames.size();
-      if(infer_oxygen[res_idx]) {
+      if(infer_pep_oxygen[res_idx]) {
         const ResidueDefinition& next_def = res_def[res_def_indices[res_idx+1]];
         int ca_idx = start_idx + def.GetIdx("CA");
         int c_idx = start_idx + def.GetIdx("C");
@@ -1364,12 +1496,12 @@ void ChainData::FromStream(std::istream& stream,
         ConstructOPos(full_positions[ca_idx], full_positions[c_idx],
                       full_positions[n_next_idx], full_positions[o_idx]);
       }
-      if(infer_rotamer[res_idx]) {
+      if(infer_pep_rotamer[res_idx]) {
         const std::vector<SidechainAtomRule>& at_rules =
         def.GetSidechainAtomRules();
         std::vector<Real> dihedral_angles;
         for(int i = 0; i < def.GetNChiAngles(); ++i) {
-          dihedral_angles.push_back(chi_angles[chi_idx++]);
+          dihedral_angles.push_back(pep_chi_angles[pep_chi_angles_idx++]);
         }
         for(auto it = at_rules.begin(); it != at_rules.end(); ++it) {
           Real dihedral = it->base_dihedral;
@@ -1386,6 +1518,8 @@ void ChainData::FromStream(std::istream& stream,
       start_idx += n_res_atoms;
     }
     std::swap(positions, full_positions);
+  } else {
+    LoadPositions(stream, positions, false);
   }
   LoadBonds(stream, bonds);
   LoadBondOrders(stream, bond_orders);
@@ -4116,8 +4250,7 @@ void OMF::ToStream(std::ostream& stream) const {
 
   Dump(stream, biounit_definitions_);
   Dump(stream, chain_data_, residue_definitions_, OptionSet(LOSSY),
-       OptionSet(AVG_BFACTORS), OptionSet(ROUND_BFACTORS), OptionSet(SKIP_SS),
-       OptionSet(INFER_AA_POS));
+       OptionSet(AVG_BFACTORS), OptionSet(ROUND_BFACTORS), OptionSet(SKIP_SS));
   Dump(stream, bond_chain_names_);
   Dump(stream, bond_atoms_);
   Dump(stream, bond_orders_);
@@ -4160,8 +4293,7 @@ void OMF::FromStream(std::istream& stream) {
 
   Load(stream, biounit_definitions_);
   Load(stream, chain_data_, residue_definitions_, version_, OptionSet(LOSSY),
-       OptionSet(AVG_BFACTORS), OptionSet(ROUND_BFACTORS), OptionSet(SKIP_SS),
-       OptionSet(INFER_AA_POS));
+       OptionSet(AVG_BFACTORS), OptionSet(ROUND_BFACTORS), OptionSet(SKIP_SS));
   Load(stream, bond_chain_names_);
   Load(stream, bond_atoms_);
   Load(stream, bond_orders_);
diff --git a/modules/io/src/mol/omf.hh b/modules/io/src/mol/omf.hh
index 7cd4e81dd..ca40ab16f 100644
--- a/modules/io/src/mol/omf.hh
+++ b/modules/io/src/mol/omf.hh
@@ -152,12 +152,12 @@ struct ChainData {
   void ToStream(std::ostream& stream,
                 const std::vector<ResidueDefinition>& res_def,
                 bool lossy, bool avg_bfactors, bool round_bfactors,
-                bool skip_ss, bool infer_aa_pos) const;
+                bool skip_ss) const;
 
   void FromStream(std::istream& stream,
                   const std::vector<ResidueDefinition>& res_def,
                   int version, bool lossy, bool avg_bfactors,
-                  bool round_bfactors, bool skip_ss, bool infer_aa_pos);
+                  bool round_bfactors, bool skip_ss);
 
   // chain features
   String ch_name;
@@ -200,8 +200,7 @@ class OMF {
 public:
 
   enum OMFOption {DEFAULT_PEPLIB = 1, LOSSY = 2, AVG_BFACTORS = 4,
-                  ROUND_BFACTORS = 8, SKIP_SS = 16, INFER_PEP_BONDS = 32,
-                  INFER_AA_POS = 64};
+                  ROUND_BFACTORS = 8, SKIP_SS = 16, INFER_PEP_BONDS = 32};
 
   bool OptionSet(OMFOption opt) const {
     return (opt & options_) == opt;
-- 
GitLab