From 44e1ff4c8a067188f115db25717dfb81ad0bc8e7 Mon Sep 17 00:00:00 2001
From: Marco Biasini <marco.biasini@unibas.ch>
Date: Tue, 12 Oct 2010 17:59:46 +0200
Subject: [PATCH] added on-demand loader for CHARMM trajectory data

Also, use the CHARMM dialect to load structure files by default.
Fixes tons of warnings with structures produced by CHARMM.
---
 modules/conop/data/charmm.cif       |   2 +-
 modules/io/pymod/__init__.py        |  33 +++-
 modules/io/pymod/wrap_io.cc         |  11 +-
 modules/io/src/mol/dcd_io.cc        | 272 ++++++++++++++++++----------
 modules/io/src/mol/dcd_io.hh        |  16 +-
 modules/mol/base/src/coord_group.hh |   7 +-
 6 files changed, 220 insertions(+), 121 deletions(-)

diff --git a/modules/conop/data/charmm.cif b/modules/conop/data/charmm.cif
index ced6449d3..536843430 100644
--- a/modules/conop/data/charmm.cif
+++ b/modules/conop/data/charmm.cif
@@ -1040,7 +1040,7 @@ ILE O O O 0 N Y N O ILE 4
 ILE CB CB C 0 N N N CB ILE 5
 ILE CG2 CG2 C 0 N N N CG2 ILE 6
 ILE CG1 CG1 C 0 N N N CG1 ILE 7
-ILE CD CD C 0 N N N CD ILE 8
+ILE CD CD1 C 0 N N N CD ILE 8
 ILE HA HA H 0 N N N HA ILE 9
 ILE HB HB H 0 N N N HB ILE 10
 ILE HG21 HG21 H 0 N N N HG21 ILE 11
diff --git a/modules/io/pymod/__init__.py b/modules/io/pymod/__init__.py
index 2cf6900a1..5534c540d 100644
--- a/modules/io/pymod/__init__.py
+++ b/modules/io/pymod/__init__.py
@@ -187,11 +187,38 @@ def LoadImageList (files):
 
 LoadMapList=LoadImageList
 
+def LoadCHARMMTraj(crd, dcd_file=None, lazy_load=False, stride=1, 
+                   dialect='CHARMM'):
+  """
+  Load CHARMM trajectory file.
+  
+  :param crd: EntityHandle or filename of the CRD (PDB) file containing the
+      structure. The structure must have the same number of atoms as the 
+      trajectory
+  :param dcd_file: The filename of the DCD file. If not set, and crd is a string, 
+      the filename is derived from crd by replacing its extension with .dcd
+  :param lazy_load: Whether the trajectory should be loaded on demand. Instead of 
+      loading the complete trajectory into memory, the trajectory frames are 
+      loaded from disk when requested.
+  :param stride: The spacing of the frames to load. When set to 2, for example, 
+      every second frame is loaded from the trajectory
+  """
+  if not isinstance(crd, mol.EntityHandle):
+    if dcd_file==None:
+      dcd_file='%s.dcd' % os.path.splitext(crd)[0]    
+    crd=LoadPDB(crd, dialect=dialect)
+
+  else:
+    if not dcd_file:
+      raise ValueError("No DCD filename given")
+  return LoadCHARMMTraj_(crd, dcd_file, stride, lazy_load)
+
 ## \example fft_li.py
 #
-# This scripts loads one or more images and shows their Fourier Transforms on the screen. A viewer
-# is opened for each loaded image. The Fourier Transform honors the origin of the reference system,
-# which is assumed to be at the center of the image.
+# This script loads one or more images and shows their Fourier Transforms on 
+# the screen. A viewer is opened for each loaded image. The Fourier Transform 
+# honors the origin of the reference system, which is assumed to be at the 
+# center of the image.
 #
 # Usage:
 #
diff --git a/modules/io/pymod/wrap_io.cc b/modules/io/pymod/wrap_io.cc
index dc5ce987c..6eefe1728 100644
--- a/modules/io/pymod/wrap_io.cc
+++ b/modules/io/pymod/wrap_io.cc
@@ -69,11 +69,6 @@ BOOST_PYTHON_FUNCTION_OVERLOADS(save_entity_view_ov,
 BOOST_PYTHON_FUNCTION_OVERLOADS(save_charmm_trj_ov,
                                 SaveCHARMMTraj, 3, 4)
 
-mol::CoordGroupHandle load_dcd1(const String& c, const String& t) {return LoadCHARMMTraj(c,t);}
-mol::CoordGroupHandle load_dcd2(const String& c, const String& t, unsigned int s) {return LoadCHARMMTraj(c,t,s);}
-mol::CoordGroupHandle load_dcd3(const mol::EntityHandle& e, const String& t) {return LoadCHARMMTraj(e,t);}
-mol::CoordGroupHandle load_dcd4(const mol::EntityHandle& e, const String& t, unsigned int s) {return LoadCHARMMTraj(e,t,s);}
-
 }
 
 void export_pdb_io();
@@ -115,10 +110,8 @@ BOOST_PYTHON_MODULE(_io)
   def("LoadSDF", &LoadSDF);
 
   def("LoadCRD", &LoadCRD);
-  def("LoadCHARMMTraj",load_dcd1);
-  def("LoadCHARMMTraj",load_dcd2);
-  def("LoadCHARMMTraj",load_dcd3);
-  def("LoadCHARMMTraj",load_dcd4);
+  def("LoadCHARMMTraj_", &LoadCHARMMTraj, (arg("ent"), arg("trj_filename"), 
+      arg("stride")=1, arg("lazy_load")=false));
   def("SaveCHARMMTraj",SaveCHARMMTraj,save_charmm_trj_ov());
 
   def("LoadMAE", &LoadMAE);
diff --git a/modules/io/src/mol/dcd_io.cc b/modules/io/src/mol/dcd_io.cc
index 177cd57af..86a59ed42 100644
--- a/modules/io/src/mol/dcd_io.cc
+++ b/modules/io/src/mol/dcd_io.cc
@@ -66,42 +66,30 @@ bool less_index(const mol::AtomHandle& a1, const mol::AtomHandle& a2)
   return a1.GetIndex()<a2.GetIndex();
 }
 
-mol::CoordGroupHandle load_dcd(const mol::AtomHandleList& alist2,
-			       const String& trj_fn,
-			       unsigned int stride)
+bool read_dcd_header(std::istream& istream, DCDHeader& header, bool& swap_flag, 
+                     bool& skip_flag, bool& gap_flag)
 {
-  Profile profile_load("LoadCHARMMTraj");
-
-  mol::AtomHandleList alist(alist2);
-  std::sort(alist.begin(),alist.end(),less_index);
-
-  bool gap_flag = true;
-
-  boost::filesystem::path trj_f(trj_fn);
-  boost::filesystem::ifstream ff(trj_f, std::ios::binary);
-  
-  DCDHeader header;
-  char dummy[4];
-  bool swap_flag=false;
-
-  LOG_INFO("importing trajectory data");
-
-  if(gap_flag) ff.read(dummy,sizeof(dummy));
-  ff.read(header.hdrr,sizeof(char)*4);
-  ff.read(reinterpret_cast<char*>(header.icntrl),sizeof(int)*20);
+  if (!istream) {
+    return false;
+  }
+  char dummy[4];  
+  gap_flag=true;
+  swap_flag=false;
+  skip_flag=false;
+  if(gap_flag) istream.read(dummy,sizeof(dummy));
+  istream.read(header.hdrr,sizeof(char)*4);
+  istream.read(reinterpret_cast<char*>(header.icntrl),sizeof(int)*20);
   if(header.icntrl[1]<0 || header.icntrl[1]>1e8) {
     // nonsense atom count, try swapping
     swap_int(header.icntrl,20);
     if(header.icntrl[1]<0 || header.icntrl[1]>1e8) {
       throw(IOException("LoadCHARMMTraj: nonsense atom count in header"));
     } else {
-      LOG_INFO("LoadCHARMMTraj: byte-swapping");
+      LOG_VERBOSE("LoadCHARMMTraj: byte-swapping");
       swap_flag=true;
     }
   }
 
-  bool skip_flag=false;
-
   if(header.icntrl[19]!=0) { // CHARMM format
     skip_flag=(header.icntrl[10]!=0);
     if(skip_flag) {
@@ -113,24 +101,102 @@ mol::CoordGroupHandle load_dcd(const mol::AtomHandleList& alist2,
     // XPLOR format
     LOG_VERBOSE("LoadCHARMMTraj: using XPLOR format");
   }
-
-  if(gap_flag) ff.read(dummy,sizeof(dummy));
-  ff.read(reinterpret_cast<char*>(&header.ntitle),sizeof(int));
+  if(gap_flag) istream.read(dummy,sizeof(dummy));
+  if (!istream) {
+    return false;
+  }
+  istream.read(reinterpret_cast<char*>(&header.ntitle),sizeof(int));
+  if (!istream) {
+    return false;
+  }
   if(swap_flag) swap_int(&header.ntitle,1);
-  if(gap_flag) ff.read(dummy,sizeof(dummy));
-  ff.read(header.title,sizeof(char)*header.ntitle);
+  if(gap_flag) istream.read(dummy,sizeof(dummy));
+
+  istream.read(header.title,sizeof(char)*header.ntitle);
   header.title[header.ntitle]='\0';
-  if(gap_flag) ff.read(dummy,sizeof(dummy));
-  ff.read(reinterpret_cast<char*>(&header.t_atom_count),sizeof(int));
+  if(gap_flag) istream.read(dummy,sizeof(dummy));
+  istream.read(reinterpret_cast<char*>(&header.t_atom_count),sizeof(int));
   if(swap_flag) swap_int(&header.t_atom_count,1);
-  if(gap_flag) ff.read(dummy,sizeof(dummy));
+  if(gap_flag) istream.read(dummy,sizeof(dummy));
   header.num=header.icntrl[0];
   header.istep=header.icntrl[1];
   header.freq=header.icntrl[2];
   header.nstep=header.icntrl[3];
   header.f_atom_count=header.icntrl[8];
   header.atom_count=header.t_atom_count-header.f_atom_count;
+  return true;
+}
+
+
+size_t calc_frame_size(bool skip_flag, bool gap_flag, size_t num_atoms)
+{
+  size_t frame_size=0;
+  if (skip_flag) {
+    frame_size+=14*sizeof(int);
+  }
+  if (gap_flag) {
+    frame_size+=6*sizeof(int);
+  }
+  frame_size+=3*sizeof(float)*num_atoms;
+  return frame_size;
+}
+
+bool read_frame(std::istream& istream, const DCDHeader& header, 
+                size_t frame_size, bool skip_flag, bool gap_flag, 
+                bool swap_flag, std::vector<float>& xlist,
+                std::vector<geom::Vec3>& frame)
+{
+  char dummy[4];
+  if(skip_flag) istream.seekg(14*4,std::ios_base::cur);
+  // read each frame
+  if(!istream) {
+    /* premature EOF */
+    LOG_ERROR("LoadCHARMMTraj: premature end of file, frames read");
+    return false;
+  }
+  // x coord
+  if(gap_flag) istream.read(dummy,sizeof(dummy));
+  istream.read(reinterpret_cast<char*>(&xlist[0]),sizeof(float)*xlist.size());
+  if(swap_flag) swap_float(&xlist[0],xlist.size());
+  if(gap_flag) istream.read(dummy,sizeof(dummy));
+  for(uint j=0;j<frame.size();++j) {
+    frame[j].x=xlist[j];
+  }
+
+  // y coord
+  if(gap_flag) istream.read(dummy,sizeof(dummy));
+  istream.read(reinterpret_cast<char*>(&xlist[0]),sizeof(float)*xlist.size());
+  if(swap_flag) swap_float(&xlist[0],xlist.size());
+  if(gap_flag) istream.read(dummy,sizeof(dummy));
+  for(uint j=0;j<frame.size();++j) {
+    frame[j].y=xlist[j];
+  }
+
+  // z coord
+  if(gap_flag) istream.read(dummy,sizeof(dummy));
+  istream.read(reinterpret_cast<char*>(&xlist[0]),sizeof(float)*xlist.size());
+  if(swap_flag) swap_float(&xlist[0],xlist.size());
+  if(gap_flag) istream.read(dummy,sizeof(dummy));
+  for(uint j=0;j<frame.size();++j) {
+    frame[j].z=xlist[j];
+  }
+  return true;
+}
+
+
+mol::CoordGroupHandle load_dcd(const mol::AtomHandleList& alist2,
+                               const String& trj_fn,
+                               unsigned int stride)
+{
+  Profile profile_load("LoadCHARMMTraj");
 
+  mol::AtomHandleList alist(alist2);
+  std::sort(alist.begin(),alist.end(),less_index);
+  
+  std::ifstream istream(trj_fn.c_str(), std::ios::binary);
+  DCDHeader header; 
+  bool swap_flag=false, skip_flag=false, gap_flag=false;
+  read_dcd_header(istream, header, swap_flag, skip_flag, gap_flag);
   LOG_DEBUG("LoadCHARMMTraj: " << header.num << " trajectories with " 
                << header.atom_count << " atoms (" << header.f_atom_count 
                << " fixed) each");
@@ -145,63 +211,23 @@ mol::CoordGroupHandle load_dcd(const mol::AtomHandleList& alist2,
   mol::CoordGroupHandle cg=CreateCoordGroup(alist);
   std::vector<geom::Vec3> clist(header.t_atom_count);
   std::vector<float> xlist(header.t_atom_count);
-
-  size_t frame_size=0;
-  if (skip_flag) {
-    frame_size+=14*4;
-  }
-  if (gap_flag) {
-    frame_size+=6*sizeof(dummy);
-  }
-  frame_size+=3*sizeof(float)*xlist.size();
-
+  size_t frame_size=calc_frame_size(skip_flag, gap_flag, xlist.size());
   int i=0;
   for(;i<header.num;i+=stride) {
-    if(skip_flag) ff.seekg(14*4,std::ios_base::cur);
-    // read each frame
-    if(!ff) {
-      /* premature EOF */
-      LOG_ERROR("LoadCHARMMTraj: premature end of file, " << i 
-                 << " frames read");
+    if (!read_frame(istream, header, frame_size, skip_flag, gap_flag, 
+                    swap_flag, xlist, clist)) {
       break;
     }
-    // x coord
-    if(gap_flag) ff.read(dummy,sizeof(dummy));
-    ff.read(reinterpret_cast<char*>(&xlist[0]),sizeof(float)*xlist.size());
-    if(swap_flag) swap_float(&xlist[0],xlist.size());
-    if(gap_flag) ff.read(dummy,sizeof(dummy));
-    for(uint j=0;j<clist.size();++j) {
-      clist[j].x=xlist[j];
-    }
-
-    // y coord
-    if(gap_flag) ff.read(dummy,sizeof(dummy));
-    ff.read(reinterpret_cast<char*>(&xlist[0]),sizeof(float)*xlist.size());
-    if(swap_flag) swap_float(&xlist[0],xlist.size());
-    if(gap_flag) ff.read(dummy,sizeof(dummy));
-    for(uint j=0;j<clist.size();++j) {
-      clist[j].y=xlist[j];
-    }
-
-    // z coord
-    if(gap_flag) ff.read(dummy,sizeof(dummy));
-    ff.read(reinterpret_cast<char*>(&xlist[0]),sizeof(float)*xlist.size());
-    if(swap_flag) swap_float(&xlist[0],xlist.size());
-    if(gap_flag) ff.read(dummy,sizeof(dummy));
-    for(uint j=0;j<clist.size();++j) {
-      clist[j].z=xlist[j];
-    }
-
     cg.AddFrame(clist);
 
     // skip frames (defined by stride)
-    if(stride>1) ff.seekg(frame_size*(stride-1),std::ios_base::cur);
+    if(stride>1) istream.seekg(frame_size*(stride-1),std::ios_base::cur);
   }
 
-  ff.get();
-  if(!ff.eof()) {
+  istream.get();
+  if(!istream.eof()) {
     LOG_VERBOSE("LoadCHARMMTraj: unexpected trailing file data, bytes read: " 
-                 << ff.tellg());
+                 << istream.tellg());
   }
 
   LOG_VERBOSE("Loaded " << cg.GetFrameCount() << " frames with " << cg.GetAtomCount() << " atoms each");
@@ -209,25 +235,89 @@ mol::CoordGroupHandle load_dcd(const mol::AtomHandleList& alist2,
   return cg;
 }
 
-} // anon ns
+class  DCDCoordSource : public mol::CoordSource {
+public:
+  DCDCoordSource(const mol::AtomHandleList& atoms, const String& filename, 
+                 uint stride): 
+    mol::CoordSource(atoms), filename_(filename), 
+    stream_(filename.c_str(), std::ios::binary), loaded_(false), stride_(stride)
+  {
+    this->SetMutable(false);
+    frame_=mol::CoordFramePtr(new mol::CoordFrame(atoms.size()));
+  }
+    
+  
+  virtual uint GetFrameCount() 
+  { 
+    if (!frame_count_)
+      const_cast<DCDCoordSource*>(this)->FetchFrame(0);
+    return frame_count_; 
+  }
+  
+  virtual mol::CoordFramePtr GetFrame(uint frame_id) const {
+    const_cast<DCDCoordSource*>(this)->FetchFrame(frame_id);
+    return frame_;
+  }
 
-mol::CoordGroupHandle LoadCHARMMTraj(const String& crd_fn,
-                                     const String& trj_fn,
-                                     unsigned int stride)
+  virtual void AddFrame(const std::vector<geom::Vec3>& coords) {}
+  virtual void InsertFrame(int pos, const std::vector<geom::Vec3>& coords) {}
+private:
+  
+  void FetchFrame(uint frame);
+  String               filename_;
+  DCDHeader            header_;
+  bool                 skip_flag_;
+  bool                 swap_flag_;
+  bool                 gap_flag_;
+  std::ifstream        stream_;
+  bool                 loaded_;
+  uint                 frame_count_;
+  uint                 curr_frame_;
+  uint                 stride_;
+  size_t               frame_start_;
+  mol::CoordFramePtr   frame_;
+};
+
+
+void DCDCoordSource::FetchFrame(uint frame)
 {
-  mol::EntityHandle ent=LoadEntity(crd_fn);
-  return load_dcd(ent.GetAtomList(),trj_fn,stride);
+  if (!loaded_) {
+    read_dcd_header(stream_, header_, swap_flag_, skip_flag_, gap_flag_);
+    frame_start_=stream_.tellg();
+    loaded_=true;
+    frame_count_=header_.num;
+  }
+  size_t frame_size=calc_frame_size(skip_flag_, gap_flag_, 
+                                    header_.t_atom_count);  
+  size_t pos=frame_start_+frame_size*frame*stride_;
+  stream_.seekg(pos,std::ios_base::beg);
+  std::vector<float> xlist(header_.t_atom_count);
+  if (!read_frame(stream_, header_, frame_size, skip_flag_, gap_flag_, 
+                  swap_flag_, xlist, *frame_.get())) {
+  }  
 }
 
+typedef boost::shared_ptr<DCDCoordSource> DCDCoordSourcePtr;
+
+
+} // anon ns
+
 
-mol::CoordGroupHandle DLLEXPORT_OST_IO LoadCHARMMTraj(const mol::EntityHandle& e,
-						      const String& trj_fn,
-						      unsigned int stride)
+mol::CoordGroupHandle LoadCHARMMTraj(const mol::EntityHandle& ent,
+                                     const String& trj_fn,
+                                     unsigned int stride, bool lazy_load)
 {
-  return load_dcd(e.GetAtomList(),trj_fn,stride);
+  mol::AtomHandleList alist(ent.GetAtomList());
+  std::sort(alist.begin(),alist.end(),less_index);
+  if (lazy_load) {
+    LOG_INFO("Importing CHARMM trajectory with lazy_load=true");
+    DCDCoordSource* source=new DCDCoordSource(alist, trj_fn, stride);
+    return mol::CoordGroupHandle(DCDCoordSourcePtr(source));
+  }
+    LOG_INFO("Importing CHARMM trajectory with lazy_load=false");  
+  return load_dcd(alist, trj_fn, stride);
 }
 
-
 namespace {
 
 void write_dcd_hdr(std::ofstream& out,
diff --git a/modules/io/src/mol/dcd_io.hh b/modules/io/src/mol/dcd_io.hh
index b20a3fb0f..24d189e63 100644
--- a/modules/io/src/mol/dcd_io.hh
+++ b/modules/io/src/mol/dcd_io.hh
@@ -29,22 +29,14 @@
 
 namespace ost { namespace io {
 
-/*! \brief import a CHARMM trajectory in dcd format
-    requires the coordinate and the trajectory file; the format
-    of the coordinate file will be automatically deduced from the extension
-    the optional stride parameter will cause only every nth frame to be loaded
-*/
-mol::CoordGroupHandle DLLEXPORT_OST_IO LoadCHARMMTraj(const String& coord,
-						      const String& trj,
-						      unsigned int stride=1);
-
 /*! \brief import a CHARMM trajectory in dcd format with an existing entity
     requires the existing entity and the trajectory file - obviously the
     atom layout of the entity must match the trajectory file
 */
-mol::CoordGroupHandle DLLEXPORT_OST_IO LoadCHARMMTraj(const mol::EntityHandle& e,
-						      const String& trj,
-						      unsigned int stride=1);
+mol::CoordGroupHandle DLLEXPORT_OST_IO LoadCHARMMTraj(const mol::EntityHandle& ent,
+                                                       const String& trj_filename,
+                                                       unsigned int stride=1,
+                                                       bool lazy_load=false);
 
 
 /*! \brief export coord group as PDB file and DCD trajectory
diff --git a/modules/mol/base/src/coord_group.hh b/modules/mol/base/src/coord_group.hh
index 38fa81d2e..70844a26b 100644
--- a/modules/mol/base/src/coord_group.hh
+++ b/modules/mol/base/src/coord_group.hh
@@ -33,9 +33,6 @@ namespace ost { namespace mol {
 
 /// \brief coordinate group, for trajectories and such
 class DLLEXPORT_OST_MOL CoordGroupHandle {
-  friend DLLEXPORT_OST_MOL 
-  CoordGroupHandle CreateCoordGroup(const std::vector<AtomHandle>&);
-
 public:
   /// \brief create empty, invalid handle
   CoordGroupHandle();
@@ -84,10 +81,10 @@ public:
   
   AtomHandleList GetAtomList() const;
   CoordFramePtr GetFrame(uint frame) const;
-private:
-  void CheckValidity() const;
   
   CoordGroupHandle(CoordSourcePtr source);
+private:
+  void CheckValidity() const;
 
   CoordSourcePtr source_;
 };
-- 
GitLab