From 003af288909e556888825da037afb5052d7acf9d Mon Sep 17 00:00:00 2001
From: Gerardo Tauriello <gerardo.tauriello@unibas.ch>
Date: Fri, 8 Sep 2017 11:58:15 +0200
Subject: [PATCH] SCHWED-1740: extended MmSystemCreator to work with multiple
 loops.

---
 loop/pymod/export_mm_system_creator.cc |  84 ++++++++++++---
 loop/src/mm_system_creator.cc          | 137 +++++++++++++++++--------
 loop/src/mm_system_creator.hh          |  36 ++++++-
 loop/tests/test_mm_system_creator.cc   |   4 +-
 4 files changed, 200 insertions(+), 61 deletions(-)

diff --git a/loop/pymod/export_mm_system_creator.cc b/loop/pymod/export_mm_system_creator.cc
index 51daaaa9..a4f316c0 100644
--- a/loop/pymod/export_mm_system_creator.cc
+++ b/loop/pymod/export_mm_system_creator.cc
@@ -25,22 +25,11 @@ WrapGetDisulfidBridges(const MmSystemCreator& mm_sys,
   return return_list;
 }
 
-void WrapSetupSystem(MmSystemCreator& mm_sys,
-                     const AllAtomPositions& all_pos,
-                     const boost::python::list& res_idx_list,
-                     uint loop_length, const boost::python::list& is_n_ter_list,
-                     const boost::python::list& is_c_ter_list,
-                     const boost::python::list& bridge_list) {
-  // get vectors
-  std::vector<uint> res_indices;
-  std::vector<bool> is_n_ter;
-  std::vector<bool> is_c_ter;
-  core::ConvertListToVector(res_idx_list, res_indices);
-  core::ConvertListToVector(is_n_ter_list, is_n_ter);
-  core::ConvertListToVector(is_c_ter_list, is_c_ter);
-  // special for bridges
+void ConvertBridges(const boost::python::list& bridge_list,
+                    MmSystemCreator::DisulfidBridgeVector& bridges) {
+  // special conversion for bridges (list of tuples)
   const uint num_bridges = boost::python::len(bridge_list);
-  MmSystemCreator::DisulfidBridgeVector bridges(num_bridges);
+  bridges.resize(num_bridges);
   for (uint i = 0; i < num_bridges; ++i) {
     boost::python::extract<boost::python::tuple> extract_tuple(bridge_list[i]);
     if (!extract_tuple.check()) {
@@ -53,11 +42,54 @@ void WrapSetupSystem(MmSystemCreator& mm_sys,
     bridges[i].first = boost::python::extract<uint>(bridge_list[i][0]);
     bridges[i].second = boost::python::extract<uint>(bridge_list[i][1]);
   }
+}
+
+void WrapSetupSystem(MmSystemCreator& mm_sys,
+                     const AllAtomPositions& all_pos,
+                     const boost::python::list& res_idx_list,
+                     uint loop_length, const boost::python::list& is_n_ter_list,
+                     const boost::python::list& is_c_ter_list,
+                     const boost::python::list& bridge_list) {
+  // get vectors
+  std::vector<uint> res_indices;
+  std::vector<bool> is_n_ter;
+  std::vector<bool> is_c_ter;
+  MmSystemCreator::DisulfidBridgeVector bridges;
+  core::ConvertListToVector(res_idx_list, res_indices);
+  core::ConvertListToVector(is_n_ter_list, is_n_ter);
+  core::ConvertListToVector(is_c_ter_list, is_c_ter);
+  ConvertBridges(bridge_list, bridges);
   // finally setup system
   mm_sys.SetupSystem(all_pos, res_indices, loop_length, is_n_ter, is_c_ter,
                      bridges);
 }
 
+void WrapSetupSystemMulti(MmSystemCreator& mm_sys,
+                          const AllAtomPositions& all_pos,
+                          const boost::python::list& res_idx_list,
+                          const boost::python::list& loop_start_indices_list,
+                          const boost::python::list& loop_lengths_list,
+                          const boost::python::list& is_n_ter_list,
+                          const boost::python::list& is_c_ter_list,
+                          const boost::python::list& bridge_list) {
+  // get vectors
+  std::vector<uint> res_indices;
+  std::vector<uint> loop_start_indices;
+  std::vector<uint> loop_lengths;
+  std::vector<bool> is_n_ter;
+  std::vector<bool> is_c_ter;
+  MmSystemCreator::DisulfidBridgeVector bridges;
+  core::ConvertListToVector(res_idx_list, res_indices);
+  core::ConvertListToVector(loop_start_indices_list, loop_start_indices);
+  core::ConvertListToVector(loop_lengths_list, loop_lengths);
+  core::ConvertListToVector(is_n_ter_list, is_n_ter);
+  core::ConvertListToVector(is_c_ter_list, is_c_ter);
+  ConvertBridges(bridge_list, bridges);
+  // finally setup system
+  mm_sys.SetupSystem(all_pos, res_indices, loop_start_indices, loop_lengths,
+                     is_n_ter, is_c_ter, bridges);
+}
+
 void WrapUpdatePositions(MmSystemCreator& mm_sys,
                          const AllAtomPositions& all_pos,
                          const boost::python::list& res_idx_list) {
@@ -83,6 +115,20 @@ void WrapExtractLoopPositionsIndexed(MmSystemCreator& mm_sys,
   mm_sys.ExtractLoopPositions(out_pos, res_indices);
 }
 
+boost::python::list WrapGetLoopStartIndices(const MmSystemCreator& mm_sys) {
+  // convert into list
+  boost::python::list return_list;
+  core::AppendVectorToList(mm_sys.GetLoopStartIndices(), return_list);
+  return return_list;
+}
+
+boost::python::list WrapGetLoopLengths(const MmSystemCreator& mm_sys) {
+  // convert into list
+  boost::python::list return_list;
+  core::AppendVectorToList(mm_sys.GetLoopLengths(), return_list);
+  return return_list;
+}
+
 boost::python::list WrapGetFfAa(const MmSystemCreator& mm_sys) {
   // convert into list
   boost::python::list return_list;
@@ -111,12 +157,20 @@ void export_MmSystemCreator() {
     .def("SetupSystem", WrapSetupSystem,
          (arg("all_pos"), arg("res_indices"), arg("loop_length"),
           arg("is_n_ter"), arg("is_c_ter"), arg("disulfid_bridges")))
+    .def("SetupSystem", WrapSetupSystemMulti,
+         (arg("all_pos"), arg("res_indices"), arg("loop_start_indices"),
+          arg("loop_lengths"), arg("is_n_ter"), arg("is_c_ter"),
+          arg("disulfid_bridges")))
     .def("UpdatePositions", WrapUpdatePositions,
          (arg("all_pos"), arg("res_indices")))
     .def("ExtractLoopPositions", WrapExtractLoopPositions, (arg("loop_pos")))
     .def("ExtractLoopPositions", WrapExtractLoopPositionsIndexed,
          (arg("out_pos"), arg("res_indices")))
     .def("GetSimulation", &MmSystemCreator::GetSimulation)
+    .def("GetNumResidues", &MmSystemCreator::GetNumResidues)
+    .def("GetNumLoopResidues", &MmSystemCreator::GetNumLoopResidues)
+    .def("GetLoopStartIndices", WrapGetLoopStartIndices)
+    .def("GetLoopLengths", WrapGetLoopLengths)
     .def("GetForcefieldAminoAcids", WrapGetFfAa)
     .def("GetIndexing", WrapGetIndexing)
     .def("GetCpuPlatformSupport", &MmSystemCreator::GetCpuPlatformSupport)
diff --git a/loop/src/mm_system_creator.cc b/loop/src/mm_system_creator.cc
index 817cd2d4..a2643e5c 100644
--- a/loop/src/mm_system_creator.cc
+++ b/loop/src/mm_system_creator.cc
@@ -254,14 +254,33 @@ void MmSystemCreator::SetupSystem(const AllAtomPositions& all_pos,
                                   const std::vector<bool>& is_c_ter,
                                   const DisulfidBridgeVector& disulfid_bridges)
 {
+  // call other one
+  std::vector<uint> loop_start_indices(1, 0);
+  std::vector<uint> loop_lengths(1, loop_length);
+  SetupSystem(all_pos, res_indices, loop_start_indices, loop_lengths,
+              is_n_ter, is_c_ter, disulfid_bridges);
+}
 
+void MmSystemCreator::SetupSystem(const AllAtomPositions& all_pos,
+                                  const std::vector<uint>& res_indices,
+                                  const std::vector<uint>& loop_start_indices,
+                                  const std::vector<uint>& loop_lengths,
+                                  const std::vector<bool>& is_n_ter,
+                                  const std::vector<bool>& is_c_ter,
+                                  const DisulfidBridgeVector& disulfid_bridges)
+{
   promod3::core::ScopedTimerPtr prof = core::StaticRuntimeProfiler::StartScoped(
                                 "MmSystemCreator::SetupSystem", 2);
 
   // check data consistency
-  if (loop_length > res_indices.size()) {
-    throw promod3::Error("Loop length cannnot be longer than res_indices "
-                         "in MmSystemCreator::SetupSystem!");
+  const uint num_loops = loop_start_indices.size();
+  for (uint i_loop = 0; i_loop < num_loops; ++i_loop) {
+    // loop from start_idx to start_idx + loop_lengths[i_loop] - 1
+    const uint start_idx = loop_start_indices[i_loop];
+    if (start_idx + loop_lengths[i_loop] > res_indices.size()) {
+      throw promod3::Error("Loop indices out of bounds compared to res_indices "
+                           "in MmSystemCreator::SetupSystem!");
+    }
   }
   if (is_n_ter.size() != res_indices.size()) {
     throw promod3::Error("Sizes of res_indices and is_n_ter must match in "
@@ -290,8 +309,13 @@ void MmSystemCreator::SetupSystem(const AllAtomPositions& all_pos,
     }
   }
 
-  // setup internal data (loop_length_, is_X_ter_, ff_aa_, first_idx_)
-  loop_length_ = loop_length;
+  // setup internal data (loop_..., is_X_ter_, ff_aa_, first_idx_)
+  num_loop_residues_ = 0;
+  for (uint i_loop = 0; i_loop < num_loops; ++i_loop) {
+    num_loop_residues_ += loop_lengths[i_loop];
+  }
+  loop_start_indices_ = loop_start_indices;
+  loop_lengths_ = loop_lengths;
   is_n_ter_ = is_n_ter;
   is_c_ter_ = is_c_ter;
   SetupNextIRes_(res_indices);
@@ -340,29 +364,44 @@ void MmSystemCreator::ExtractLoopPositions(AllAtomPositions& loop_pos) {
                                 "MmSystemCreator::ExtractLoopPositions", 2);
 
   // check data consistency
-  if (loop_pos.GetNumResidues() < loop_length_) {
+  if (loop_pos.GetNumResidues() < num_loop_residues_) {
     throw promod3::Error("Output storage too small in "
                          "MmSystemCreator::ExtractLoopPositions!");
   }
-  for (uint i_res = 0; i_res < loop_length_; ++i_res) {
-    if (loop_pos.GetAA(i_res) != ff_lookup_->GetAA(ff_aa_[i_res])) {
-      throw promod3::Error("Inconsistent amino acid types observed in "
-                           "MmSystemCreator::ExtractLoopPositions!");
+  // check AA for all loops
+  const uint num_loops = loop_start_indices_.size();
+  uint loop_pos_idx = 0;
+  for (uint i_loop = 0; i_loop < num_loops; ++i_loop) {
+    // loop from start_idx to start_idx + loop_lengths_[i_loop] - 1
+    const uint start_idx = loop_start_indices_[i_loop];
+    for (uint loop_idx = 0; loop_idx < loop_lengths_[i_loop]; ++loop_idx) {
+      const uint i_res = start_idx + loop_idx;
+      if (loop_pos.GetAA(loop_pos_idx) != ff_lookup_->GetAA(ff_aa_[i_res])) {
+        throw promod3::Error("Inconsistent amino acid types observed in "
+                             "MmSystemCreator::ExtractLoopPositions!");
+      }
+      ++loop_pos_idx;
     }
   }
 
   // get from simulation
   positions_ = simulation_->GetPositions();
 
-  // fill heavy atoms
-  for (uint i_res = 0; i_res < loop_length_; ++i_res) {
-    const ForcefieldAminoAcid ff_aa = ff_aa_[i_res];
-    const uint first_idx = first_idx_[i_res];
-    const ost::conop::AminoAcid aa = ff_lookup_->GetAA(ff_aa);
-    // get heavy atoms
-    for (uint i = 0; i < aa_lookup_.GetNumAtoms(aa); ++i) {
-      const uint idx = ff_lookup_->GetHeavyIndex(ff_aa, i);
-      loop_pos.SetPos(i_res, i, positions_[first_idx + idx]);
+  // fill heavy atoms from each loop
+  loop_pos_idx = 0;
+  for (uint i_loop = 0; i_loop < num_loops; ++i_loop) {
+    const uint start_idx = loop_start_indices_[i_loop];
+    for (uint loop_idx = 0; loop_idx < loop_lengths_[i_loop]; ++loop_idx) {
+      const uint i_res = start_idx + loop_idx;
+      const ForcefieldAminoAcid ff_aa = ff_aa_[i_res];
+      const uint first_idx = first_idx_[i_res];
+      const ost::conop::AminoAcid aa = ff_lookup_->GetAA(ff_aa);
+      // get heavy atoms
+      for (uint i = 0; i < aa_lookup_.GetNumAtoms(aa); ++i) {
+        const uint idx = ff_lookup_->GetHeavyIndex(ff_aa, i);
+        loop_pos.SetPos(loop_pos_idx, i, positions_[first_idx + idx]);
+      }
+      ++loop_pos_idx;
     }
   }
 }
@@ -374,32 +413,44 @@ void MmSystemCreator::ExtractLoopPositions(AllAtomPositions& out_pos,
                                 "MmSystemCreator::ExtractLoopPositions", 2);
 
   // check data consistency
-  if (res_indices.size() < loop_length_) {
-    throw promod3::Error("Too few residue indices passed in "
-                         "MmSystemCreator::ExtractLoopPositions!");
-  }
-  for (uint i_res = 0; i_res < loop_length_; ++i_res) {
-    const uint res_idx = res_indices[i_res];
-    out_pos.CheckResidueIndex(res_idx);
-    if (out_pos.GetAA(res_idx) != ff_lookup_->GetAA(ff_aa_[i_res])) {
-      throw promod3::Error("Inconsistent amino acid types observed in "
-                           "MmSystemCreator::ExtractLoopPositions!");
+  const uint num_loops = loop_start_indices_.size();
+  for (uint i_loop = 0; i_loop < num_loops; ++i_loop) {
+    // loop from start_idx to start_idx + loop_lengths_[i_loop] - 1
+    const uint start_idx = loop_start_indices_[i_loop];
+    // check res_indices
+    if (start_idx + loop_lengths_[i_loop] > res_indices.size()) {
+        throw promod3::Error("Too few residue indices passed in "
+                             "MmSystemCreator::ExtractLoopPositions!");
+    }
+    // check AA
+    for (uint loop_idx = 0; loop_idx < loop_lengths_[i_loop]; ++loop_idx) {
+      const uint i_res = start_idx + loop_idx;
+      const uint res_idx = res_indices[i_res];
+      out_pos.CheckResidueIndex(res_idx);
+      if (out_pos.GetAA(res_idx) != ff_lookup_->GetAA(ff_aa_[i_res])) {
+        throw promod3::Error("Inconsistent amino acid types observed in "
+                             "MmSystemCreator::ExtractLoopPositions!");
+      }
     }
   }
 
   // get from simulation
   positions_ = simulation_->GetPositions();
 
-  // fill heavy atoms
-  for (uint i_res = 0; i_res < loop_length_; ++i_res) {
-    const uint res_idx = res_indices[i_res];
-    const ForcefieldAminoAcid ff_aa = ff_aa_[i_res];
-    const uint first_idx = first_idx_[i_res];
-    const ost::conop::AminoAcid aa = ff_lookup_->GetAA(ff_aa);
-    // get heavy atoms
-    for (uint i = 0; i < aa_lookup_.GetNumAtoms(aa); ++i) {
-      const uint idx = ff_lookup_->GetHeavyIndex(ff_aa, i);
-      out_pos.SetPos(res_idx, i, positions_[first_idx + idx]);
+  // fill heavy atoms from each loop
+  for (uint i_loop = 0; i_loop < num_loops; ++i_loop) {
+    const uint start_idx = loop_start_indices_[i_loop];
+    for (uint loop_idx = 0; loop_idx < loop_lengths_[i_loop]; ++loop_idx) {
+      const uint i_res = start_idx + loop_idx;
+      const uint res_idx = res_indices[i_res];
+      const ForcefieldAminoAcid ff_aa = ff_aa_[i_res];
+      const uint first_idx = first_idx_[i_res];
+      const ost::conop::AminoAcid aa = ff_lookup_->GetAA(ff_aa);
+      // get heavy atoms
+      for (uint i = 0; i < aa_lookup_.GetNumAtoms(aa); ++i) {
+        const uint idx = ff_lookup_->GetHeavyIndex(ff_aa, i);
+        out_pos.SetPos(res_idx, i, positions_[first_idx + idx]);
+      }
     }
   }
 }
@@ -495,7 +546,9 @@ MmSystemCreator::SetupTopology_(const DisulfidBridgeVector& disulfid_bridges) {
     const std::vector<Real>&
     masses = ff_lookup_->GetMasses(ff_aa, is_nter, is_cter);
     std::copy(masses.begin(), masses.end(), &atom_masses[first_idx]);
-    if (i_res == 0 && !is_nter) {
+    // loop check
+    const IndexLocation_ i_loc = GetIndexLocation_(i_res);
+    if (i_loc == ON_N_STEM && !is_nter) {
       // fix N-stem (only fix N, CA, CB)
       const uint idx_N = ff_lookup_->GetHeavyIndex(ff_aa, BB_N_INDEX);
       atom_masses[first_idx + idx_N] = 0;
@@ -505,7 +558,7 @@ MmSystemCreator::SetupTopology_(const DisulfidBridgeVector& disulfid_bridges) {
         const uint idx_CB = ff_lookup_->GetHeavyIndex(ff_aa, BB_CB_INDEX);
         atom_masses[first_idx + idx_CB] = 0;
       }
-    } else if (i_res == loop_length_-1 && !is_cter) {
+    } else if (i_loc == ON_C_STEM && !is_cter) {
       // fix C-stem (only fix CA, CB, C, O)
       const uint idx_CA = ff_lookup_->GetHeavyIndex(ff_aa, BB_CA_INDEX);
       atom_masses[first_idx + idx_CA] = 0;
@@ -517,7 +570,7 @@ MmSystemCreator::SetupTopology_(const DisulfidBridgeVector& disulfid_bridges) {
         const uint idx_CB = ff_lookup_->GetHeavyIndex(ff_aa, BB_CB_INDEX);
         atom_masses[first_idx + idx_CB] = 0;
       }
-    } else if (i_res >= loop_length_) {
+    } else if (i_loc == OUT_OF_LOOP) {
       // fix surrounding (either all or only heavy atoms)
       if (fix_surrounding_hydrogens_) {
         std::fill_n(&atom_masses[first_idx], masses.size(), Real(0));
diff --git a/loop/src/mm_system_creator.hh b/loop/src/mm_system_creator.hh
index 2ac4cff6..9ee72384 100644
--- a/loop/src/mm_system_creator.hh
+++ b/loop/src/mm_system_creator.hh
@@ -43,11 +43,20 @@ public:
                      const std::vector<uint>& res_indices) const;
 
   // setup system (overwrites old one completely!)
+  // single loop
   void SetupSystem(const AllAtomPositions& all_pos,
                    const std::vector<uint>& res_indices, uint loop_length,
                    const std::vector<bool>& is_n_ter,
                    const std::vector<bool>& is_c_ter,
                    const DisulfidBridgeVector& disulfid_bridges);
+  // multi loop
+  void SetupSystem(const AllAtomPositions& all_pos,
+                   const std::vector<uint>& res_indices,
+                   const std::vector<uint>& loop_start_indices,
+                   const std::vector<uint>& loop_lengths,
+                   const std::vector<bool>& is_n_ter,
+                   const std::vector<bool>& is_c_ter,
+                   const DisulfidBridgeVector& disulfid_bridges);
 
   // overwrite positions (input must be compatible with last SetupSystem call)
   void UpdatePositions(const AllAtomPositions& all_pos,
@@ -72,7 +81,12 @@ public:
   void SetCpuPlatformSupport(bool cpu_platform_support);
 
   // getters for data
-  uint GetLoopLength() const { return loop_length_; }
+  uint GetNumResidues() const { return ff_aa_.size(); }
+  uint GetNumLoopResidues() const { return num_loop_residues_; }
+  const std::vector<uint>& GetLoopStartIndices() const {
+    return loop_start_indices_;
+  }
+  const std::vector<uint>& GetLoopLengths() const { return loop_lengths_; }
   const std::vector<ForcefieldAminoAcid>& GetForcefieldAminoAcids() const {
     return ff_aa_;
   }
@@ -96,6 +110,22 @@ protected:
                        const std::vector<uint>& res_indices);
   void SetupSimulation_();
 
+  // dealing with loops
+  enum IndexLocation_ {
+    ON_N_STEM, ON_C_STEM, IN_LOOP, OUT_OF_LOOP
+  };
+  IndexLocation_ GetIndexLocation_(uint i_res) {
+    // check all loops
+    const uint num_loops = loop_start_indices_.size();
+    for (uint i_loop = 0; i_loop < num_loops; ++i_loop) {
+      const uint tst = i_res - loop_start_indices_[i_loop];
+      if (tst == 0) return ON_N_STEM;
+      if (tst == loop_lengths_[i_loop]-1) return ON_C_STEM;
+      if (tst < loop_lengths_[i_loop]-1) return IN_LOOP;
+    }
+    return OUT_OF_LOOP;
+  }
+
   // settings
   ForcefieldLookupPtr ff_lookup_;
   bool fix_surrounding_hydrogens_;
@@ -105,7 +135,9 @@ protected:
   bool cpu_platform_support_;
 
   // input data
-  uint loop_length_;
+  uint num_loop_residues_;                 // sum of loop lengths
+  std::vector<uint> loop_start_indices_;   // index in [0, num_residues-1]
+  std::vector<uint> loop_lengths_;         // len = len(loop_start_indices_)
   std::vector<bool> is_n_ter_;             // len = num_residues
   std::vector<bool> is_c_ter_;             // len = num_residues
   // internal res. idx. of next residue or -1 if no next res. in system
diff --git a/loop/tests/test_mm_system_creator.cc b/loop/tests/test_mm_system_creator.cc
index a4d8555d..0ba27a2c 100644
--- a/loop/tests/test_mm_system_creator.cc
+++ b/loop/tests/test_mm_system_creator.cc
@@ -290,7 +290,7 @@ void CheckLoops(const AllAtomPositions& all_pos, bool fix_surrounding_hydrogens,
   Real pot_ref = mm_sim.GetSimulation()->GetPotentialEnergy();
   SetupLoop(all_pos, 0, all_pos.GetNumResidues()-1, mm_sim);
   BOOST_CHECK_EQUAL(mm_sim.GetSimulation()->GetPotentialEnergy(), pot_ref);
-  BOOST_CHECK_EQUAL(mm_sim.GetLoopLength(), all_pos.GetNumResidues());
+  BOOST_CHECK_EQUAL(mm_sim.GetNumLoopResidues(), all_pos.GetNumResidues());
   // try some loops
   const uint loop_length = 5;
   SetupLoop(all_pos, 0, loop_length - 1, mm_sim);
@@ -305,7 +305,7 @@ void CheckLoops(const AllAtomPositions& all_pos, bool fix_surrounding_hydrogens,
   }
   for (uint i = 0; i <= all_pos.GetNumResidues() - loop_length; i += 10) {
     SetupLoop(all_pos, i, i + loop_length - 1, mm_sim);
-    BOOST_CHECK_EQUAL(mm_sim.GetLoopLength(), loop_length);
+    BOOST_CHECK_EQUAL(mm_sim.GetNumLoopResidues(), loop_length);
     if (fix_surrounding_hydrogens && kill_es) {
       BOOST_CHECK_CLOSE(mm_sim.GetSimulation()->GetPotentialEnergy(),
                         pot_ref_loop, 40);
-- 
GitLab