From b888a131474db287ade0aa8fbb2361d3df7ffeb9 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Thu, 19 Jan 2017 10:27:09 +0100
Subject: [PATCH] make subrotamer optimization available in
 SidechainReconstructor

---
 sidechain/doc/reconstruct.rst                 |  8 ++++-
 .../pymod/export_sidechain_reconstructor.cc   |  3 +-
 sidechain/src/sidechain_reconstructor.cc      | 24 +++++++++++++
 sidechain/src/sidechain_reconstructor.hh      |  6 +++-
 sidechain/src/subrotamer_optimizer.hh         |  2 +-
 sidechain/tests/test_sidechain.py             | 34 ++++++++++++-------
 6 files changed, 60 insertions(+), 17 deletions(-)

diff --git a/sidechain/doc/reconstruct.rst b/sidechain/doc/reconstruct.rst
index d5361b7b..69ba7155 100644
--- a/sidechain/doc/reconstruct.rst
+++ b/sidechain/doc/reconstruct.rst
@@ -20,7 +20,8 @@ SidechainReconstructor Class
 --------------------------------------------------------------------------------
 
 .. class:: SidechainReconstructor(keep_sidechains=True, build_disulfids=True, \
-                                  cutoff=20, graph_max_complexity=100000000, \
+                                  optimize_subrotamers=False, cutoff=20, \
+                                  graph_max_complexity=100000000, \
                                   graph_intial_epsilon=0.02, \
                                   disulfid_score_thresh=45)
 
@@ -41,6 +42,11 @@ SidechainReconstructor Class
                           the result.
   :type build_disulfids: :class:`bool`
 
+  :param optimize_subrotamers: Flag, whether the :func:`SubrotamerOptimizer`
+                               with default parametrization should be called 
+                               if we're dealing with FRM rotamers.
+  :type optimize_subrotamers:  :class:`bool`
+
   :param cutoff: Cutoff used to search relevant residues surrounding the loop.
   :type cutoff:  :class:`float`
 
diff --git a/sidechain/pymod/export_sidechain_reconstructor.cc b/sidechain/pymod/export_sidechain_reconstructor.cc
index 4f97d55a..b550dd26 100644
--- a/sidechain/pymod/export_sidechain_reconstructor.cc
+++ b/sidechain/pymod/export_sidechain_reconstructor.cc
@@ -91,8 +91,9 @@ void export_SidechainReconstructor() {
 
   class_<SidechainReconstructor, SidechainReconstructorPtr>
     ("SidechainReconstructor", no_init)
-    .def(init<bool, bool, Real, uint64_t, Real, Real>(
+    .def(init<bool, bool, bool, Real, uint64_t, Real, Real>(
          (arg("keep_sidechains")=true, arg("build_disulfids")=true,
+          arg("optimize_subrotamers")=false,
           arg("cutoff")=20, arg("graph_max_complexity")=100000000,
           arg("graph_intial_epsilon")=0.02,
           arg("disulfid_score_thresh")=45)))
diff --git a/sidechain/src/sidechain_reconstructor.cc b/sidechain/src/sidechain_reconstructor.cc
index b82b7ce6..e1f9dd10 100644
--- a/sidechain/src/sidechain_reconstructor.cc
+++ b/sidechain/src/sidechain_reconstructor.cc
@@ -1,6 +1,7 @@
 #include <promod3/sidechain/sidechain_reconstructor.hh>
 #include <promod3/sidechain/disulfid.hh>
 #include <promod3/sidechain/rotamer_graph.hh>
+#include <promod3/sidechain/subrotamer_optimizer.hh>
 #include <promod3/core/message.hh>
 #include <promod3/core/runtime_profiling.hh>
 
@@ -321,6 +322,24 @@ void SidechainReconstructor::CollectRotamerGroups_(
   }
 }
 
+void ApplySubrotamerOptimization(std::vector<FRMRotamerGroupPtr>& rotamer_groups,
+                                 const std::vector<int>& solution) {
+
+  uint num_rotamer_groups = rotamer_groups.size();
+
+  std::vector<FRMRotamerPtr> rotamers(rotamer_groups.size());
+  for(uint i = 0; i < num_rotamer_groups; ++i){
+    rotamers[i] = (*rotamer_groups[i])[solution[i]];
+  }
+  SubrotamerOptimizer(rotamers);
+}
+
+void ApplySubrotamerOptimization(std::vector<RRMRotamerGroupPtr>& rotamer_groups,
+                                 const std::vector<int>& solution) {
+  //there's nothing to do... 
+}
+
+
 template<typename RotamerGroup>
 void SidechainReconstructor::SolveSystem_(
                        SidechainReconstructionDataPtr res) const {
@@ -377,6 +396,11 @@ void SidechainReconstructor::SolveSystem_(
   std::pair<std::vector<int>,Real> solution = 
     graph->TreeSolve(graph_max_complexity_, graph_intial_epsilon_);
 
+  // do subrotamer optimization if required
+  if(optimize_subrotamers_){
+    ApplySubrotamerOptimization(rotamer_groups, solution.first);
+  }
+
   // apply solution to subset of data
   for (uint i = 0; i < res->rotamer_res_indices.size(); ++i) {
     const uint res_idx = res->rotamer_res_indices[i];
diff --git a/sidechain/src/sidechain_reconstructor.hh b/sidechain/src/sidechain_reconstructor.hh
index 2e052b38..bd405ec3 100644
--- a/sidechain/src/sidechain_reconstructor.hh
+++ b/sidechain/src/sidechain_reconstructor.hh
@@ -38,12 +38,15 @@ class SidechainReconstructor {
 public:
 
   SidechainReconstructor(bool keep_sidechains = true,
-                         bool build_disulfids = true, Real cutoff = 20,
+                         bool build_disulfids = true, 
+                         bool optimize_subrotamers = false,
+                         Real cutoff = 20,
                          uint64_t graph_max_complexity = 100000000,
                          Real graph_intial_epsilon = 0.02,
                          Real disulfid_score_thresh = 45)
                          : keep_sidechains_(keep_sidechains)
                          , build_disulfids_(build_disulfids)
+                         , optimize_subrotamers_(optimize_subrotamers)
                          , cutoff_(cutoff)
                          , graph_max_complexity_(graph_max_complexity)
                          , graph_intial_epsilon_(graph_intial_epsilon)
@@ -104,6 +107,7 @@ private:
   // reconstruction parameters
   bool keep_sidechains_;
   bool build_disulfids_;
+  bool optimize_subrotamers_;
   Real cutoff_;
   uint64_t graph_max_complexity_;
   Real graph_intial_epsilon_;
diff --git a/sidechain/src/subrotamer_optimizer.hh b/sidechain/src/subrotamer_optimizer.hh
index 594d3ffd..df415f00 100644
--- a/sidechain/src/subrotamer_optimizer.hh
+++ b/sidechain/src/subrotamer_optimizer.hh
@@ -9,7 +9,7 @@
 namespace promod3{ namespace sidechain{
 
 void SubrotamerOptimizer(std::vector<FRMRotamerPtr>& rotamers,
-                         Real active_internal_energy = -0.5,
+                         Real active_internal_energy = -2.0,
                          Real inactive_internal_energy = 0.0,
                          uint max_complexity = 100000000,
                          Real initial_epsilon = 0.02);
diff --git a/sidechain/tests/test_sidechain.py b/sidechain/tests/test_sidechain.py
index 5ddc769a..fad4df49 100644
--- a/sidechain/tests/test_sidechain.py
+++ b/sidechain/tests/test_sidechain.py
@@ -23,17 +23,20 @@ class SidechainTests(unittest.TestCase):
             self.assertLessEqual(geom.Length(a.pos - a_ref.pos), max_dist)
 
     def CheckEnvVsPy(self, ent, env, keep_sidechains, build_disulfids,
+                     optimize_subrotamers,
                      rotamer_model, rotamer_library):
         # reconstruct sidechains for full OST entity
         ent_py = ent.Copy()
         sidechain.Reconstruct(ent_py, keep_sidechains=keep_sidechains,
                               build_disulfids=build_disulfids,
+                              optimize_subrotamers=optimize_subrotamers,
                               rotamer_model=rotamer_model,
                               rotamer_library=rotamer_library)
 
         # same with SidechainReconstructor
         sc_rec = sidechain.SidechainReconstructor(keep_sidechains=keep_sidechains,
-                                                  build_disulfids=build_disulfids)
+                                                  build_disulfids=build_disulfids,
+                                                  optimize_subrotamers=optimize_subrotamers)
         sc_rec.AttachEnvironment(env, use_frm=(rotamer_model=="frm"),
                                  rotamer_library=rotamer_library)
         res = sc_rec.Reconstruct(1, ent.residue_count)
@@ -63,28 +66,33 @@ class SidechainTests(unittest.TestCase):
         env = loop.AllAtomEnv(seqres_str)
         env.SetInitialEnvironment(ent)
         self.CheckEnvVsPy(ent, env, keep_sidechains=False,
-                          build_disulfids=False, rotamer_model="rrm",
-                          rotamer_library=self.rotamer_library)
+                          build_disulfids=False, optimize_subrotamers=False,
+                          rotamer_model="rrm", rotamer_library=self.rotamer_library)
         # reuse env with keep_sidechains=True
         self.CheckEnvVsPy(ent, env, keep_sidechains=True,
-                          build_disulfids=False, rotamer_model="rrm",
-                          rotamer_library=self.rotamer_library)
+                          build_disulfids=False, optimize_subrotamers=False,
+                          rotamer_model="rrm", rotamer_library=self.rotamer_library)
         # vary one by one (need to reset env to get new stuff)
         env = loop.AllAtomEnv(seqres_str)
         env.SetInitialEnvironment(ent)
         self.CheckEnvVsPy(ent, env, keep_sidechains=True,
-                          build_disulfids=False, rotamer_model="frm",
-                          rotamer_library=self.rotamer_library)
+                          build_disulfids=False, optimize_subrotamers=False,
+                          rotamer_model="frm", rotamer_library=self.rotamer_library)
         env = loop.AllAtomEnv(seqres_str)
         env.SetInitialEnvironment(ent)
         self.CheckEnvVsPy(ent, env, keep_sidechains=True,
-                          build_disulfids=False, rotamer_model="rrm",
-                          rotamer_library=self.rotamer_library)
+                          build_disulfids=False, optimize_subrotamers=False,
+                          rotamer_model="rrm", rotamer_library=self.rotamer_library)
         env = loop.AllAtomEnv(seqres_str)
         env.SetInitialEnvironment(ent)
         self.CheckEnvVsPy(ent, env, keep_sidechains=True,
-                          build_disulfids=False, rotamer_model="rrm",
-                          rotamer_library=self.bbdep_rotamer_library)
+                          build_disulfids=False, optimize_subrotamers=False,
+                          rotamer_model="rrm", rotamer_library=self.bbdep_rotamer_library)
+        env = loop.AllAtomEnv(seqres_str)
+        env.SetInitialEnvironment(ent)
+        self.CheckEnvVsPy(ent, env, keep_sidechains=False,
+                          build_disulfids=True, optimize_subrotamers=True,
+                          rotamer_model="frm", rotamer_library=self.bbdep_rotamer_library)
         
         # crn needed to check for disulfid bridges
         ent = io.LoadPDB(os.path.join('data', '1crn_sc_test.pdb'))
@@ -92,8 +100,8 @@ class SidechainTests(unittest.TestCase):
         env = loop.AllAtomEnv(seqres_str)
         env.SetInitialEnvironment(ent)
         self.CheckEnvVsPy(ent, env, keep_sidechains=True,
-                          build_disulfids=True, rotamer_model="rrm",
-                          rotamer_library=self.rotamer_library)
+                          build_disulfids=True, optimize_subrotamers=False,
+                          rotamer_model="rrm", rotamer_library=self.rotamer_library)
 
 if __name__ == "__main__":
     from ost import testutils
-- 
GitLab