From ac4ffe27ded452078fdf29cd09b8bc3de5c8f618 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Fri, 5 Jan 2024 14:57:02 +0100
Subject: [PATCH] mmcif writer: enable ost::mol::EntityHandle in MMCifWriter

---
 modules/io/pymod/export_mmcif_io.cc |  15 +-
 modules/io/src/mol/mmcif_writer.cc  | 244 +++++++++++++++++-----------
 modules/io/src/mol/mmcif_writer.hh  |  13 ++
 3 files changed, 172 insertions(+), 100 deletions(-)

diff --git a/modules/io/pymod/export_mmcif_io.cc b/modules/io/pymod/export_mmcif_io.cc
index de498d382..15c0beebe 100644
--- a/modules/io/pymod/export_mmcif_io.cc
+++ b/modules/io/pymod/export_mmcif_io.cc
@@ -70,6 +70,18 @@ void WrapStarWriterWrite(StarWriter& writer, const String& data_name,
   writer.Write(data_name, filename);
 }
 
+void WrapSetStructureHandle(MMCifWriter& writer,
+                            const ost::mol::EntityHandle& ent,
+                            bool mmcif_conform) {
+  writer.SetStructure(ent, mmcif_conform);
+}
+
+void WrapSetStructureView(MMCifWriter& writer,
+                          const ost::mol::EntityView& ent,
+                          bool mmcif_conform) {
+  writer.SetStructure(ent, mmcif_conform);
+}
+
 void export_mmcif_io()
 {
   class_<MMCifReader, boost::noncopyable>("MMCifReader", init<const String&, EntityHandle&, const IOProfile&>())
@@ -124,7 +136,8 @@ void export_mmcif_io()
   ;
 
   class_<MMCifWriter, bases<StarWriter> >("MMCifWriter", init<>())
-    .def("SetStructure", &MMCifWriter::SetStructure, (arg("ent"), arg("mmcif_conform")=true))
+    .def("SetStructure", &WrapSetStructureHandle, (arg("ent"), arg("mmcif_conform")=true))
+    .def("SetStructure", &WrapSetStructureView, (arg("ent"), arg("mmcif_conform")=true))
   ;
 
   enum_<MMCifInfoCitation::MMCifInfoCType>("MMCifInfoCType")
diff --git a/modules/io/src/mol/mmcif_writer.cc b/modules/io/src/mol/mmcif_writer.cc
index 35dd15c6b..2c5490ea4 100644
--- a/modules/io/src/mol/mmcif_writer.cc
+++ b/modules/io/src/mol/mmcif_writer.cc
@@ -68,7 +68,9 @@ namespace {
     std::vector<int> indices;
   };
 
-  String GuessEntityPolyType(const ost::mol::ResidueHandleList& res_list) {
+  // template to allow ost::mol::ResidueHandleList and ost::mol::ResidueViewList
+  template<class T>
+  String GuessEntityPolyType(const T& res_list) {
 
     // guesses _entity_poly.type based on residue chem classes
 
@@ -152,7 +154,9 @@ namespace {
     return "other";
   }
 
-  String GuessEntityType(const ost::mol::ResidueHandleList& res_list) {
+  // template to allow ost::mol::ResidueHandleList and ost::mol::ResidueViewList
+  template<class T>
+  String GuessEntityType(const T& res_list) {
 
     // guesses _entity.type based on residue chem classes
 
@@ -188,11 +192,6 @@ namespace {
     return "polymer";
   }
 
-  // internal object with all info to fill chem_comp_ category
-  struct CompInfo {
-    String type;
-  };
-
   inline String ChemClassToChemCompType(char chem_class) {
     String type = "";
     switch(chem_class) {
@@ -435,8 +434,10 @@ namespace {
     return "(" + mon_id + ")";
   }
 
-  void SetupChemComp(const ost::mol::ResidueHandleList& res_list,
-                     std::map<String, CompInfo>& comp_infos) {
+  // template to allow ost::mol::ResidueHandleList and ost::mol::ResidueViewList
+  template<class T>
+  void SetupChemComp(const T& res_list,
+                     std::map<String, ost::io::MMCifWriterComp>& comp_infos) {
     for(auto res: res_list) {
       String res_name = res.GetName();
       String type = ChemClassToChemCompType(res.GetChemClass());
@@ -450,7 +451,7 @@ namespace {
           if(type == "OTHER") {
             continue; 
           } else if (it->second.type == "OTHER") {
-            CompInfo info;
+            ost::io::MMCifWriterComp info;
             info.type = type;
             comp_infos[res_name] = info;
           } else {
@@ -463,14 +464,16 @@ namespace {
           }
         }
       } else {
-        CompInfo info;
+        ost::io::MMCifWriterComp info;
         info.type = type;
         comp_infos[res_name] = info;
       }
     }
   }
 
-  bool MatchEntity(const ost::mol::ResidueHandleList& res_list,
+  // template to allow ost::mol::ResidueHandleList and ost::mol::ResidueViewList
+  template<class T>
+  bool MatchEntity(const T& res_list,
                    const ost::io::MMCifWriterEntity& info) {
     // checks if the residue names in res_list are an exact match
     // with mon_ids in info
@@ -489,7 +492,9 @@ namespace {
     info.asym_alns.push_back(info.mon_ids);
   }
 
-  bool MatchEntityResnum(const ost::mol::ResidueHandleList& res_list,
+  // template to allow ost::mol::ResidueHandleList and ost::mol::ResidueViewList
+  template<class T>
+  bool MatchEntityResnum(const T& res_list,
                          const ost::io::MMCifWriterEntity& info,
                          Real beyond_frac = 0.05) {
     // Checks if res_list matches SEQRES given in info.mon_ids
@@ -540,8 +545,10 @@ namespace {
     }
   }
 
+  // template to allow ost::mol::ResidueHandleList and ost::mol::ResidueViewList
+  template<class T>
   void AddAsymResnum(const String& asym_chain_name,
-                     const ost::mol::ResidueHandleList& res_list,
+                     const T& res_list,
                      ost::io::MMCifWriterEntity& info) {
 
     if(!info.is_poly) {
@@ -584,10 +591,9 @@ namespace {
     for(size_t i = 0; i < resnums.size(); ++i) {
       if(info.mon_ids[resnums[i]-1] == "-") {
         info.mon_ids[resnums[i]-1] = mon_ids[i];
-        const ost::mol::ResidueHandle& res = res_list[i];
-        info.seq_olcs[resnums[i]-1] = MonIDToOLC(res.GetChemClass(),
+        info.seq_olcs[resnums[i]-1] = MonIDToOLC(res_list[i].GetChemClass(),
                                                  mon_ids[i]);
-        char olc = res.GetOneLetterCode();
+        char olc = res_list[i].GetOneLetterCode();
         if(olc < 'A' || olc > 'Z') {
           info.seq_can_olcs[resnums[i]-1] = "X";
         } else {
@@ -601,10 +607,12 @@ namespace {
     info.asym_alns.push_back(aln_mon_ids);
   }
 
+  // template to allow ost::mol::ResidueHandleList and ost::mol::ResidueViewList
+  template<class T>
   int SetupEntity(const String& asym_chain_name,
                   const String& type,
                   const String& poly_type,
-                  const ost::mol::ResidueHandleList& res_list,
+                  const T& res_list,
                   bool resnum_alignment,
                   std::vector<ost::io::MMCifWriterEntity>& entity_infos) {
 
@@ -678,10 +686,9 @@ namespace {
       seq_can.assign(max_resnum, "-");
       for(size_t i = 0; i < res_mon_ids.size(); ++i) {
         mon_ids[resnums[i]-1] = res_mon_ids[i];
-        const ost::mol::ResidueHandle& res = res_list[i];
-        seq[resnums[i]-1] = MonIDToOLC(res.GetChemClass(),
+        seq[resnums[i]-1] = MonIDToOLC(res_list[i].GetChemClass(),
                                        mon_ids[resnums[i]-1]);
-        char olc = res.GetOneLetterCode();
+        char olc = res_list[i].GetOneLetterCode();
         if(olc < 'A' || olc > 'Z') {
           seq_can[resnums[i]-1] = "X";
         } else {
@@ -725,8 +732,10 @@ namespace {
     return entity_idx;
   }
 
+  // template to allow ost::mol::ResidueHandleList and ost::mol::ResidueViewList
+  template<class T>
   int SetupEntity(const String& asym_chain_name,
-                  const ost::mol::ResidueHandleList& res_list,
+                  const T& res_list,
                   bool resnum_alignment,
                   std::vector<ost::io::MMCifWriterEntity>& entity_infos) {
     // use chem types in res_list to determine _entity.type and
@@ -741,9 +750,11 @@ namespace {
                        resnum_alignment, entity_infos);
   }
 
+  // template to allow ost::mol::ResidueHandleList and ost::mol::ResidueViewList
+  template<class T>
   int SetupEntity(const String& asym_chain_name,
                   ost::mol::ChainType chain_type,
-                  const ost::mol::ResidueHandleList& res_list,
+                  const T& res_list,
                   bool resnum_alignment, 
                   std::vector<ost::io::MMCifWriterEntity>& entity_infos) {
     // use chain_type info attached to chain to determine
@@ -864,11 +875,13 @@ namespace {
     }
   }
 
+  // template to allow ost::mol::ResidueHandleList and ost::mol::ResidueViewList
+  template<class T>
   void Feed_pdbx_poly_seq_scheme(ost::io::StarWriterLoopPtr pdbx_poly_seq_scheme_ptr,
                                  const String& label_asym_id,
                                  int label_entity_id,
                                  ost::io::MMCifWriterEntity& entity_info,
-                                 const ost::mol::ResidueHandleList& res_list) {
+                                 const T& res_list) {
 
     std::vector<ost::io::StarWriterValue> data;
     data.push_back(ost::io::StarWriterValue::FromString(label_asym_id));
@@ -916,7 +929,7 @@ namespace {
         if(ins_code == '\0') {
           data[6] = ost::io::StarWriterValue::FromString("");
         } else {
-          data[6] = ost::io::StarWriterValue::FromString(String(1, ' '));
+          data[6] = ost::io::StarWriterValue::FromString(String(1, ins_code));
         }      
       }
       pdbx_poly_seq_scheme_ptr->AddData(data);
@@ -924,11 +937,13 @@ namespace {
     }
   }
 
+  // template to allow ost::mol::ResidueHandleList and ost::mol::ResidueViewList
+  template<class T>
   void Feed_atom_site(ost::io::StarWriterLoopPtr atom_site_ptr,
                       const String& label_asym_id,
                       int label_entity_id,
                       const ost::io::MMCifWriterEntity& entity_info,
-                      const ost::mol::ResidueHandleList& res_list) {
+                      const T& res_list) {
 
     int asym_idx = entity_info.GetAsymIdx(label_asym_id);
     const std::vector<String>& aln = entity_info.asym_alns[asym_idx];
@@ -937,7 +952,7 @@ namespace {
     for(auto res: res_list) {
       String comp_id = res.GetName();
 
-      ost::mol::AtomHandleList at_list = res.GetAtomList();
+      auto at_list = res.GetAtomList();
       String auth_asym_id = res.GetChain().GetName();
       if(res.HasProp("pdb_auth_chain_name")) {
         auth_asym_id = res.GetStringProp("pdb_auth_chain_name");
@@ -1086,7 +1101,7 @@ namespace {
   }
 
   void Feed_chem_comp(ost::io::StarWriterLoopPtr chem_comp_ptr,
-                      const std::map<String, CompInfo>& comp_infos) {
+                      const std::map<String, ost::io::MMCifWriterComp>& comp_infos) {
     std::vector<ost::io::StarWriterValue> comp_data;
     comp_data.push_back(ost::io::StarWriterValue::FromString("ALA"));
     comp_data.push_back(ost::io::StarWriterValue::FromString("L-PEPTIDE LINKING"));
@@ -1097,54 +1112,27 @@ namespace {
     }
   }
 
-  void ProcessEnt(const ost::mol::EntityHandle& ent,
-                  std::map<String, CompInfo>& comp_infos,
-                  std::vector<ost::io::MMCifWriterEntity>& entity_info,
-                  ost::io::StarWriterLoopPtr atom_site,
-                  ost::io::StarWriterLoopPtr pdbx_poly_seq_scheme) {
-    ost::mol::ChainHandleList chain_list = ent.GetChainList();
-    for(auto ch: chain_list) {
-
-      ost::mol::ResidueHandleList res_list = ch.GetResidueList();
-
-      SetupChemComp(res_list, comp_infos);
-      String chain_name = ch.GetName();
-      int entity_id = SetupEntity(chain_name,
-                                  ch.GetType(),
-                                  res_list,
-                                  true,
-                                  entity_info);
-      Feed_atom_site(atom_site, chain_name, entity_id, entity_info[entity_id],
-                     res_list);
-      if(entity_info[entity_id].is_poly) {
-        Feed_pdbx_poly_seq_scheme(pdbx_poly_seq_scheme, chain_name,
-                                  entity_id, entity_info[entity_id], res_list);
-      }
-    }
-  }
-
-  void ProcessEntmmCIFify(const ost::mol::EntityHandle& ent,
-                          std::map<String, CompInfo>& comp_infos,
+  // template to allow ost::mol::ResidueHandleList and ost::mol::ResidueViewList
+  template<class T>
+  void ProcessEntmmCIFify(const std::vector<T>& res_lists,
+                          std::map<String, ost::io::MMCifWriterComp>& comp_infos,
                           std::vector<ost::io::MMCifWriterEntity>& entity_info,
                           ost::io::StarWriterLoopPtr atom_site,
                           ost::io::StarWriterLoopPtr pdbx_poly_seq_scheme) {
 
     ChainNameGenerator chain_name_gen;
 
-    ost::mol::ChainHandleList chain_list = ent.GetChainList();
-    for(auto ch: chain_list) {
-
-      ost::mol::ResidueHandleList res_list = ch.GetResidueList();
+    for(auto res_list: res_lists) {
 
       SetupChemComp(res_list, comp_infos);
 
-      std::vector<ost::mol::ResidueHandle> L_chain;
-      std::vector<ost::mol::ResidueHandle> D_chain;
-      std::vector<ost::mol::ResidueHandle> P_chain;
-      std::vector<ost::mol::ResidueHandle> R_chain;
-      std::vector<ost::mol::ResidueHandle> S_chain;
-      std::vector<ost::mol::ResidueHandle> Z_chain;
-      std::vector<ost::mol::ResidueHandle> W_chain;
+      T L_chain;
+      T D_chain;
+      T P_chain;
+      T R_chain;
+      T S_chain;
+      T Z_chain;
+      T W_chain;
 
       // first scan only concerning peptides...
       // Avoid mix of both in same chain: L-peptide linking, D-peptide linking
@@ -1193,7 +1181,7 @@ namespace {
         } else if(res.GetChemClass() == ost::mol::ChemClass::NON_POLYMER ||
                   res.GetChemClass() == ost::mol::ChemClass::UNKNOWN) {
           // already process non-poly and unknown
-          ost::mol::ResidueHandleList tmp;
+          T tmp;
           tmp.push_back(res);
           String chain_name = chain_name_gen.Get();
           int entity_id = SetupEntity(chain_name,
@@ -1212,11 +1200,7 @@ namespace {
       }
 
       // process poly chains
-      std::vector<ost::mol::ResidueHandle>* poly_chains[5] = {&L_chain,
-                                                              &D_chain,
-                                                              &P_chain,
-                                                              &R_chain,
-                                                              &S_chain};
+      T* poly_chains[5] = {&L_chain, &D_chain, &P_chain, &R_chain, &S_chain};
       for(int i = 0; i < 5; ++i) {
         if(!poly_chains[i]->empty()) {
           String chain_name = chain_name_gen.Get();
@@ -1259,8 +1243,70 @@ namespace {
     }
   }
 
+  void ProcessEntmmCIFify(const ost::mol::EntityHandle& ent,
+                          std::map<String, ost::io::MMCifWriterComp>& comp_infos,
+                          std::vector<ost::io::MMCifWriterEntity>& entity_info,
+                          ost::io::StarWriterLoopPtr atom_site,
+                          ost::io::StarWriterLoopPtr pdbx_poly_seq_scheme) {
+    std::vector<ost::mol::ResidueHandleList> res_lists;
+    ost::mol::ChainHandleList chain_list = ent.GetChainList();
+    for(auto ch: chain_list) {
+      res_lists.push_back(ch.GetResidueList());
+    }
+    ProcessEntmmCIFify(res_lists, comp_infos, entity_info,
+                       atom_site, pdbx_poly_seq_scheme);
+  }
+
+  void ProcessEntmmCIFify(const ost::mol::EntityView& ent,
+                          std::map<String, ost::io::MMCifWriterComp>& comp_infos,
+                          std::vector<ost::io::MMCifWriterEntity>& entity_info,
+                          ost::io::StarWriterLoopPtr atom_site,
+                          ost::io::StarWriterLoopPtr pdbx_poly_seq_scheme) {
+    std::vector<ost::mol::ResidueViewList> res_lists;
+    ost::mol::ChainViewList chain_list = ent.GetChainList();
+    for(auto ch: chain_list) {
+      res_lists.push_back(ch.GetResidueList());
+    }
+    ProcessEntmmCIFify(res_lists, comp_infos, entity_info,
+                       atom_site, pdbx_poly_seq_scheme);
+  }
+
+  // template to allow ost::mol::EntityHandle and ost::mol::EntityView
+  template<class T>
+  void ProcessEnt(const T& ent,
+                  bool mmcif_conform,
+                  std::map<String, ost::io::MMCifWriterComp>& comp_infos,
+                  std::vector<ost::io::MMCifWriterEntity>& entity_info,
+                  ost::io::StarWriterLoopPtr atom_site,
+                  ost::io::StarWriterLoopPtr pdbx_poly_seq_scheme) {
+
+    if(mmcif_conform) {
+      auto chain_list = ent.GetChainList();
+      for(auto ch: chain_list) {
+        auto res_list = ch.GetResidueList();
+        SetupChemComp(res_list, comp_infos);
+        String chain_name = ch.GetName();
+        int entity_id = SetupEntity(chain_name,
+                                    ch.GetType(),
+                                    res_list,
+                                    true,
+                                    entity_info);
+        Feed_atom_site(atom_site, chain_name, entity_id, entity_info[entity_id],
+                       res_list);
+        if(entity_info[entity_id].is_poly) {
+          Feed_pdbx_poly_seq_scheme(pdbx_poly_seq_scheme, chain_name,
+                                    entity_id, entity_info[entity_id], res_list);
+        }
+      }
+    } else {
+      // delegate to more complex ProcessEntmmCIFify
+      ProcessEntmmCIFify(ent, comp_infos, entity_info, atom_site,
+                         pdbx_poly_seq_scheme);
+    }
+  }
+
   void ProcessUnknowns(std::vector<ost::io::MMCifWriterEntity>& entity_infos,
-                       std::map<String, CompInfo>& comp_infos) {
+                       std::map<String, ost::io::MMCifWriterComp>& comp_infos) {
 
     for(size_t entity_idx = 0; entity_idx < entity_infos.size(); ++entity_idx) {
       if(entity_infos[entity_idx].is_poly) {
@@ -1278,7 +1324,7 @@ namespace {
               entity_infos[entity_idx].seq_olcs[mon_id_idx] = "(UNK)"; 
               entity_infos[entity_idx].seq_can_olcs[mon_id_idx] = "X";
               if(comp_infos.find("UNK") == comp_infos.end()) {
-                CompInfo info;
+                ost::io::MMCifWriterComp info;
                 info.type = "L-PEPTIDE LINKING";
                 comp_infos["UNK"] = info;
               }
@@ -1289,7 +1335,7 @@ namespace {
               entity_infos[entity_idx].seq_olcs[mon_id_idx] = "(DN)"; 
               entity_infos[entity_idx].seq_can_olcs[mon_id_idx] = "N";
               if(comp_infos.find("DN") == comp_infos.end()) {
-                CompInfo info;
+                ost::io::MMCifWriterComp info;
                 info.type = "DNA LINKING";
                 comp_infos["DN"] = info;
               }
@@ -1301,7 +1347,7 @@ namespace {
               entity_infos[entity_idx].seq_olcs[mon_id_idx] = "N"; 
               entity_infos[entity_idx].seq_can_olcs[mon_id_idx] = "N";
               if(comp_infos.find("N") == comp_infos.end()) {
-                CompInfo info;
+                ost::io::MMCifWriterComp info;
                 info.type = "RNA LINKING";
                 comp_infos["N"] = info;
               }
@@ -1343,6 +1389,22 @@ int MMCifWriterEntity::GetAsymIdx(const String& asym_id) const {
 void MMCifWriter::SetStructure(const ost::mol::EntityHandle& ent,
                                bool mmcif_conform) {
 
+  this->Setup();
+  ProcessEnt(ent, mmcif_conform, comp_info_, entity_info_, atom_site_,
+             pdbx_poly_seq_scheme_);
+  this->Finalize();
+}
+
+void MMCifWriter::SetStructure(const ost::mol::EntityView& ent,
+                               bool mmcif_conform) {
+
+  this->Setup();
+  ProcessEnt(ent, mmcif_conform, comp_info_, entity_info_, atom_site_,
+             pdbx_poly_seq_scheme_);
+  this->Finalize();
+}
+
+void MMCifWriter::Setup() {
   if(structure_set_) {
     throw ost::io::IOException("SetStructure can be called only once on a "
                                "given MMCifWriter instance");
@@ -1356,37 +1418,21 @@ void MMCifWriter::SetStructure(const ost::mol::EntityHandle& ent,
   entity_poly_ = Setup_entity_poly_ptr();
   entity_poly_seq_ = Setup_entity_poly_seq_ptr();
   chem_comp_ = Setup_chem_comp_ptr();
+}
 
-  std::map<String, CompInfo> comp_infos;
-
-  // The ProcessEnt functions fill comp_info and entity_info_, i.e. gather
-  // info on all unique compounds and entities observed in ent and relate the
-  // entities with asym chain names that are directly written to
-  // atom_site_/pdbx_poly_seq_scheme_.
-  if(mmcif_conform) {
-    // chains are assumed to be mmCIF conform - that means water in separate
-    // chains, ligands in separate chains etc. Chain types are inferred from
-    // chain type property set to the chains in ent.
-    ProcessEnt(ent, comp_infos, entity_info_,
-               atom_site_, pdbx_poly_seq_scheme_);
-  } else {
-    // rule based splitting of chains into mmCIF conform chains
-    ProcessEntmmCIFify(ent, comp_infos, entity_info_,
-                       atom_site_, pdbx_poly_seq_scheme_);
-  } 
-
-  // depending on the strategy above, there might be gaps in the entities
-  // mon_ids/ seq_olcs/ seq_can_olcs
+void MMCifWriter::Finalize() {
+  // depending on the strategy (mmcif_conform), there might be gaps in the
+  // entities mon_ids/ seq_olcs/ seq_can_olcs
   // The following function adds valid stuff depending on chain type
   // (e.g. UNK if we're having a peptide linking chain and then adds
   // that UNK directly to comp_info)
-  ProcessUnknowns(entity_info_, comp_infos);
+  ProcessUnknowns(entity_info_, comp_info_);
 
   Feed_entity(entity_, entity_info_);
   Feed_struct_asym(struct_asym_, entity_info_);
   Feed_entity_poly(entity_poly_, entity_info_);
   Feed_entity_poly_seq(entity_poly_seq_, entity_info_);
-  Feed_chem_comp(chem_comp_, comp_infos);
+  Feed_chem_comp(chem_comp_, comp_info_);
   Feed_atom_type(atom_type_, atom_site_); 
 
   // finalize
diff --git a/modules/io/src/mol/mmcif_writer.hh b/modules/io/src/mol/mmcif_writer.hh
index cc44706df..c176ac74f 100644
--- a/modules/io/src/mol/mmcif_writer.hh
+++ b/modules/io/src/mol/mmcif_writer.hh
@@ -66,6 +66,11 @@ struct MMCifWriterEntity {
 };
 
 
+struct MMCifWriterComp {
+  String type;
+};
+
+
 class DLLEXPORT_OST_IO MMCifWriter : public StarWriter {
 public:
 
@@ -75,8 +80,16 @@ public:
 
   void SetStructure(const ost::mol::EntityHandle& ent, bool mmcif_conform=true);
 
+  void SetStructure(const ost::mol::EntityView& ent, bool mmcif_conform=true);
+  
 private:
+
+  void Setup();
+
+  void Finalize();
+
   std::vector<MMCifWriterEntity> entity_info_;
+  std::map<String, MMCifWriterComp> comp_info_;
   StarWriterLoopPtr atom_type_;
   StarWriterLoopPtr atom_site_;
   StarWriterLoopPtr pdbx_poly_seq_scheme_;
-- 
GitLab