From 2a6bb66ae8602b5f99df820a78a84a9c7bd14fec Mon Sep 17 00:00:00 2001
From: Rafal Gumienny <guma44@gmail.com>
Date: Fri, 2 Mar 2018 09:19:26 +0100
Subject: [PATCH] feat: SCHWED-3127 More convenience functions for lDDT
 calculation and printing

---
 modules/mol/alg/src/lddt.cc                 | 155 +----------------
 modules/mol/alg/src/local_dist_diff_test.cc | 183 +++++++++++++++++++-
 modules/mol/alg/src/local_dist_diff_test.hh |  25 ++-
 3 files changed, 206 insertions(+), 157 deletions(-)

diff --git a/modules/mol/alg/src/lddt.cc b/modules/mol/alg/src/lddt.cc
index 42072cdc5..2cff1c815 100644
--- a/modules/mol/alg/src/lddt.cc
+++ b/modules/mol/alg/src/lddt.cc
@@ -102,30 +102,6 @@ void usage()
 
 }
 
-// computes coverage
-std::pair<int,int> compute_coverage (const EntityView& v,const GlobalRDMap& glob_dist_list)
-{
-  int second=0;
-  int first=0;
-  if (v.GetResidueList().size()==0) {
-    if (glob_dist_list.size()==0) {
-      return std::make_pair(0,-1);
-    } else {    
-      return std::make_pair(0,glob_dist_list.size());
-    }  
-  }
-  ChainView vchain=v.GetChainList()[0];
-  for (GlobalRDMap::const_iterator i=glob_dist_list.begin();i!=glob_dist_list.end();++i)
-  {
-    ResNum rnum = (*i).first;
-    second++;
-    if (vchain.FindResidue(rnum)) {
-      first++;
-    }
-  }
-  return std::make_pair(first,second);
-}
-
 CompoundLibPtr load_compound_lib(const String& custom_path)
 {
   if (custom_path!="") {
@@ -175,16 +151,6 @@ CompoundLibPtr load_compound_lib(const String& custom_path)
   }
   return CompoundLibPtr();
 }
-bool is_resnum_in_globalrdmap(const ResNum& resnum, const GlobalRDMap& glob_dist_list)
-{
-  for (GlobalRDMap::const_iterator i=glob_dist_list.begin(), e=glob_dist_list.end(); i!=e; ++i) {
-    ResNum rn = i->first;
-    if (rn==resnum) {
-      return true;
-    }
-  }
-  return false;
-}
 
 int main (int argc, char **argv)
 {
@@ -317,18 +283,8 @@ int main (int argc, char **argv)
 
   // prints out parameters used in the lddt calculation
   std::cout << "Verbosity level: " << verbosity_level << std::endl;
+  settings.PrintParameters();
   if (settings.structural_checks) {
-    std::cout << "Stereo-chemical and steric clash checks: On " << std::endl;
-  } else {
-    std::cout << "Stereo-chemical and steric clash checks: Off " << std::endl;
-  }
-  std::cout << "Inclusion Radius: " << settings.radius << std::endl;
-
-  std::cout << "Sequence separation: " << settings.sequence_separation << std::endl;
-  if (settings.structural_checks) {
-    std::cout << "Parameter filename: " << settings.parameter_file_path << std::endl;
-    std::cout << "Tolerance in stddevs for bonds: " << settings.bond_tolerance << std::endl;
-    std::cout << "Tolerance in stddevs for angles: " << settings.angle_tolerance << std::endl;
     LOG_INFO("Log entries format:");
     LOG_INFO("BOND INFO FORMAT:  Chain  Residue  ResNum  Bond  Min  Max  Observed  Z-score  Status");
     LOG_INFO("ANGLE INFO FORMAT:  Chain  Residue  ResNum  Angle  Min  Max  Observed  Z-score  Status");
@@ -351,31 +307,11 @@ int main (int argc, char **argv)
       }
       continue;
     }
-    EntityView v=model.GetChainList()[0].Select("peptide=true");
-    EntityView outv=model.GetChainList()[0].Select("peptide=true");
-    for (std::vector<EntityView>::const_iterator ref_list_it = ref_list.begin();
-         ref_list_it != ref_list.end(); ++ref_list_it) {
-      bool cons_check = ResidueNamesMatch(v,*ref_list_it,settings.consistency_checks);
-      if (cons_check==false) {
-        if (settings.consistency_checks==true) {
-          LOG_ERROR("Residue names in model: " << files[i] << " and in reference structure(s) are inconsistent.");
-          exit(-1);            
-        } else {
-          LOG_WARNING("Residue names in model: " << files[i] << " and in reference structure(s) are inconsistent.");
-        }   
-      } 
-    }
+    EntityView model_view = model.GetChainList()[0].Select("peptide=true");
 
     boost::filesystem::path pathstring(files[i]);
-
     String filestring=BFPathToString(pathstring);
-    std::cout << "File: " << files[i] << std::endl; 
-    std::pair<int,int> cov = compute_coverage(v,glob_dist_list);
-    if (cov.second == -1) {
-      std::cout << "Coverage: 0 (0 out of 0 residues)" << std::endl;
-    } else {
-      std::cout << "Coverage: " << (float(cov.first)/float(cov.second)) << " (" << cov.first << " out of " << cov.second << " residues)" << std::endl;
-    }
+    std::cout << "File: " << files[i] << std::endl;
 
     if (settings.structural_checks) {
       StereoChemicalParamsReader stereochemical_params(settings.parameter_file_path);
@@ -387,7 +323,7 @@ int main (int argc, char **argv)
       }
       
       try {
-        CheckStructure(v,
+        CheckStructure(model_view,
                        stereochemical_params.bond_table,
                        stereochemical_params.angle_table,
                        stereochemical_params.nonbonded_table,
@@ -399,88 +335,11 @@ int main (int argc, char **argv)
       }
     }
 
-    if (cov.first==0) {
-      std::cout << "Global LDDT score: 0.0" << std::endl;
-      return 0;
-    }
-
     // computes the lddt score   
-    String label="localldt";
-    std::pair<int,int> total_ov=alg::LocalDistDiffTest(v, glob_dist_list, settings.cutoffs, settings.sequence_separation, label);
-    Real lddt = static_cast<Real>(total_ov.first)/(static_cast<Real>(total_ov.second) ? static_cast<Real>(total_ov.second) : 1);
-    std::cout << "Global LDDT score: " << std::setprecision(4) << lddt << std::endl;
-    std::cout << "(" << std::fixed << total_ov.first << " conserved distances out of " << total_ov.second
-              << " checked, over " << settings.cutoffs.size() << " thresholds)" << std::endl;
+    Real lddt = LocalDistDiffTest(model_view, ref_list, glob_dist_list, settings);
 
-    // prints the residue-by-residue statistics  
-    if (settings.structural_checks) {
-      std::cout << "Local LDDT Scores:" << std::endl;
-      std::cout << "(A 'Yes' in the 'Quality Problems' column stands for problems" << std::endl;
-      std::cout << "in the side-chain of a residue, while a 'Yes+' for problems" << std::endl;
-      std::cout << "in the backbone)" << std::endl;
-    } else {
-      std::cout << "Local LDDT Scores:" << std::endl;
-    }
-    if (settings.structural_checks) {
-      std::cout << "Chain\tResName\tResNum\tAsses.\tQ.Prob.\tScore\t(Conserved/Total, over " << settings.cutoffs.size() << " thresholds)" << std::endl;
-    } else {
-      std::cout << "Chain\tResName\tResNum\tAsses.\tScore\t(Conserved/Total, over " << settings.cutoffs.size() << " thresholds)" << std::endl;
-    }
-    for (ChainViewList::const_iterator ci = outv.GetChainList().begin(),
-         ce = outv.GetChainList().end(); ci != ce; ++ci) {
-      for (ResidueViewList::const_iterator rit = ci->GetResidueList().begin(),
-           re = ci->GetResidueList().end(); rit != re; ++rit) {
-     
-        ResidueView ritv=*rit;
-        ResNum rnum = ritv.GetNumber();
-        bool assessed = false;
-        String assessed_string="No";
-        String quality_problems_string="No";
-        Real lddt_local = -1;
-        String lddt_local_string="-";
-        int conserved_dist = -1;
-        int total_dist = -1;
-        String dist_string = "-";
-        if (is_resnum_in_globalrdmap(rnum,glob_dist_list)) {
-          assessed = true;
-          assessed_string="Yes";
-        }
-        if (ritv.HasProp("stereo_chemical_violation_sidechain") || 
-            ritv.HasProp("steric_clash_sidechain")) {
-          quality_problems_string="Yes";
-        }
-        if (ritv.HasProp("stereo_chemical_violation_backbone") || 
-            ritv.HasProp("steric_clash_backbone")) {
-          quality_problems_string="Yes+";
-        }
-
-        if (assessed==true) {
-          if (ritv.HasProp(label)) {
-            lddt_local=ritv.GetFloatProp(label);
-            std::stringstream stkeylddt;
-            stkeylddt <<  std::fixed << std::setprecision(4) << lddt_local;
-            lddt_local_string=stkeylddt.str();
-            conserved_dist=ritv.GetIntProp(label+"_conserved");
-            total_dist=ritv.GetIntProp(label+"_total");
-            std::stringstream stkeydist;
-            stkeydist << "("<< conserved_dist << "/" << total_dist << ")";
-            dist_string=stkeydist.str();
-          } else {
-            lddt_local = 0;
-            lddt_local_string="0.0000";
-            conserved_dist = 0;
-            total_dist = 0;
-            dist_string="(0/0)";
-          }
-        }
-        if (settings.structural_checks) {
-          std::cout << ritv.GetChain() << "\t" << ritv.GetName() << "\t" << ritv.GetNumber() << '\t' << assessed_string  << '\t' << quality_problems_string << '\t' << lddt_local_string << "\t" << dist_string << std::endl;
-        } else {
-          std::cout << ritv.GetChain() << "\t" << ritv.GetName() << "\t" << ritv.GetNumber() << '\t' << assessed_string  << '\t' << lddt_local_string << "\t" << dist_string << std::endl;
-        }
-      }
-    }
-    std::cout << std::endl;
+    // prints the residue-by-residue statistics
+    PrintlDDTPerResidueStats(model, glob_dist_list, settings);
   }
   return 0;
 }
diff --git a/modules/mol/alg/src/local_dist_diff_test.cc b/modules/mol/alg/src/local_dist_diff_test.cc
index 4f525571f..79eb285a5 100644
--- a/modules/mol/alg/src/local_dist_diff_test.cc
+++ b/modules/mol/alg/src/local_dist_diff_test.cc
@@ -1,9 +1,11 @@
 #include <iomanip>
+#include <sstream>
 #include <ost/log.hh>
 #include <ost/mol/mol.hh>
 #include <ost/platform.hh>
 #include "local_dist_diff_test.hh"
 #include <boost/concept_check.hpp>
+#include <ost/mol/alg/consistency_checks.hh>
 
 namespace ost { namespace mol { namespace alg {
 
@@ -309,6 +311,42 @@ void merge_distance_lists(GlobalRDMap& ref_dist_map, const GlobalRDMap& new_dist
 
 }
 
+// Computes coverage: (reference residues found in the first chain of v, reference residues)
+std::pair<int,int> ComputeCoverage(const EntityView& v,const GlobalRDMap& glob_dist_list)
+{
+  // Empty model view: nothing is covered; a total of -1 flags an empty reference map.
+  if (v.GetResidueList().size()==0) {
+    if (glob_dist_list.size()==0) {
+      return std::make_pair(0,-1);
+    }
+    return std::make_pair(0,glob_dist_list.size());
+  }
+  ChainView first_chain=v.GetChainList()[0];
+  int covered=0;
+  int total=0;
+  for (GlobalRDMap::const_iterator entry=glob_dist_list.begin();
+       entry!=glob_dist_list.end(); ++entry) {
+    ResNum rnum = entry->first;
+    ++total;
+    if (first_chain.FindResidue(rnum)) {
+      ++covered;
+    }
+  }
+  return std::make_pair(covered,total);
+}
+
+// Tells whether resnum is one of the keys of glob_dist_list.
+//
+// Uses the container's keyed lookup instead of walking every entry; this runs
+// once per residue in PrintlDDTPerResidueStats.
+// NOTE(review): assumes GlobalRDMap is a std::map keyed by ResNum whose ordering
+// agrees with ResNum::operator== - confirm against the typedef.
+bool IsResnumInGlobalRDMap(const ResNum& resnum, const GlobalRDMap& glob_dist_list)
+{
+  return glob_dist_list.find(resnum)!=glob_dist_list.end();
+}
+
+
 // helper function
 bool IsStandardResidue(String rn)
 {
@@ -345,7 +383,9 @@ lDDTSettings::lDDTSettings(): bond_tolerance(12.0),
                               sequence_separation(0),
                               sel(""),
                               structural_checks(false),
-                              consistency_checks(true) {
+                              consistency_checks(true),
+                              print_stats(true),
+                              label("localldt") {
     cutoffs.push_back(0.5);
     cutoffs.push_back(1.0);
     cutoffs.push_back(2.0);
@@ -368,7 +408,9 @@ lDDTSettings::lDDTSettings(Real init_bond_tolerance,
                            String init_parameter_file_path,
                            bool init_structural_checks,
                            bool init_consistency_checks,
-                           std::vector<Real> init_cutoffs): 
+                           std::vector<Real>& init_cutoffs,
+                           bool init_print_stats,
+                           String init_label): 
                     bond_tolerance(init_bond_tolerance),
                     angle_tolerance(init_angle_tolerance),
                     radius(init_radius), 
@@ -376,7 +418,33 @@ lDDTSettings::lDDTSettings(Real init_bond_tolerance,
                     sel(init_sel),
                     parameter_file_path(init_parameter_file_path),
                     structural_checks(init_structural_checks),
-                    consistency_checks(init_consistency_checks) {}
+                    consistency_checks(init_consistency_checks),
+                    print_stats(init_print_stats),
+                    label(init_label) {}
+
+// Renders the settings as text, one "name: value" entry per line; the
+// parameter file and tolerances are listed only when structural checks are on.
+std::string lDDTSettings::ToString() {
+  std::ostringstream out;
+  out << "Stereo-chemical and steric clash checks: "
+      << (structural_checks ? "On \n" : "Off \n");
+  out << "Inclusion Radius: " << radius << "\n";
+  out << "Sequence separation: " << sequence_separation << "\n";
+
+  if (structural_checks) {
+    out << "Parameter filename: " << parameter_file_path << "\n";
+    out << "Tolerance in stddevs for bonds: " << bond_tolerance << "\n";
+    out << "Tolerance in stddevs for angles: " << angle_tolerance << "\n";
+  }
+
+  out << "Residue properties label: " << label << "\n";
+
+  return out.str();
+}
+
+void lDDTSettings::PrintParameters() {
+  std::cout << this->ToString();  // writes the settings report to standard output
+}
 
 
 GlobalRDMap CreateDistanceList(const EntityView& ref,Real max_dist)
@@ -514,6 +582,42 @@ Real LocalDistDiffTest(const EntityView& mdl, const EntityView& target, Real cut
    return static_cast<Real>(total_ov.first)/(static_cast<Real>(total_ov.second) ? static_cast<Real>(total_ov.second) : 1);
 }
 
+// Scores v against the reference distance map, printing coverage and global score.
+Real LocalDistDiffTest(const EntityView& v,
+                       std::vector<EntityView>& ref_list,
+                       const GlobalRDMap& glob_dist_list,
+                       lDDTSettings& settings) {
+  // A name mismatch throws only when consistency checks are on; else it is logged.
+  for (std::vector<EntityView>::const_iterator ref = ref_list.begin(),
+       ref_end = ref_list.end(); ref != ref_end; ++ref) {
+    if (ResidueNamesMatch(v,*ref,settings.consistency_checks)) {
+      continue;
+    }
+    if (settings.consistency_checks) {
+      throw std::runtime_error("Residue names in model and in reference structure(s) are inconsistent.");
+    }
+    LOG_WARNING("Residue names in model and in reference structure(s) are inconsistent.");
+  }
+
+  const std::pair<int,int> cov = ComputeCoverage(v,glob_dist_list);
+  if (cov.second == -1) {
+    std::cout << "Coverage: 0 (0 out of 0 residues)" << std::endl;
+  } else {
+    std::cout << "Coverage: " << (float(cov.first)/float(cov.second)) << " (" << cov.first << " out of " << cov.second << " residues)" << std::endl;
+  }
+  if (cov.first==0) {
+    std::cout << "Global LDDT score: 0.0" << std::endl;
+    return 0.0;
+  }
+
+  const std::pair<int,int> total_ov=alg::LocalDistDiffTest(v, glob_dist_list, settings.cutoffs, settings.sequence_separation, settings.label);
+  const Real checked = total_ov.second ? static_cast<Real>(total_ov.second) : 1;
+  const Real lddt = static_cast<Real>(total_ov.first)/checked;
+  std::cout << "Global LDDT score: " << std::setprecision(4) << lddt << std::endl;
+  std::cout << "(" << std::fixed << total_ov.first << " conserved distances out of " << total_ov.second
+            << " checked, over " << settings.cutoffs.size() << " thresholds)" << std::endl;
+  return lddt;
+}
 
 Real LocalDistDiffTest(const ost::seq::AlignmentHandle& aln,
                    Real cutoff, Real max_dist, int ref_index, int mdl_index)
@@ -650,6 +754,79 @@ void CheckStructure(EntityView& ent,
   std::cout << "Distances shorter than tolerance are on average shorter by: " << std::fixed << std::setprecision(5) << clash_info.GetAverageOffset() << std::endl;
 }
 
+// Prints the per-residue lDDT table to standard output.
+//
+// Rows are the residues of the "peptide=true" selection of the model's first
+// chain. A residue counts as assessed when its number is a key of
+// glob_dist_list; its score and conserved/total distance counts are read from
+// the residue properties named after settings.label. The 'Q.Prob.' column is
+// printed only when settings.structural_checks is set.
+void PrintlDDTPerResidueStats(EntityHandle& model, GlobalRDMap& glob_dist_list, lDDTSettings& settings){
+  EntityView outv = model.GetChainList()[0].Select("peptide=true");
+  std::cout << "Local LDDT Scores:" << std::endl;
+  if (settings.structural_checks) {
+    std::cout << "(A 'Yes' in the 'Quality Problems' column stands for problems" << std::endl;
+    std::cout << "in the side-chain of a residue, while a 'Yes+' for problems" << std::endl;
+    std::cout << "in the backbone)" << std::endl;
+  }
+  if (settings.structural_checks) {
+    std::cout << "Chain\tResName\tResNum\tAsses.\tQ.Prob.\tScore\t(Conserved/Total, over " << settings.cutoffs.size() << " thresholds)" << std::endl;
+  } else {
+    std::cout << "Chain\tResName\tResNum\tAsses.\tScore\t(Conserved/Total, over " << settings.cutoffs.size() << " thresholds)" << std::endl;
+  }
+  for (ChainViewList::const_iterator ci = outv.GetChainList().begin(),
+       ce = outv.GetChainList().end(); ci != ce; ++ci) {
+    for (ResidueViewList::const_iterator rit = ci->GetResidueList().begin(),
+         re = ci->GetResidueList().end(); rit != re; ++rit) {
+
+      ResidueView ritv=*rit;
+      ResNum rnum = ritv.GetNumber();
+      bool assessed = false;
+      String assessed_string="No";
+      String quality_problems_string="No";
+      String lddt_local_string="-";
+      String dist_string = "-";
+      if (IsResnumInGlobalRDMap(rnum,glob_dist_list)) {
+        assessed = true;
+        assessed_string="Yes";
+      }
+      if (ritv.HasProp("stereo_chemical_violation_sidechain") ||
+          ritv.HasProp("steric_clash_sidechain")) {
+        quality_problems_string="Yes";
+      }
+      // a backbone problem ('Yes+') overrides a side-chain one ('Yes')
+      if (ritv.HasProp("stereo_chemical_violation_backbone") ||
+          ritv.HasProp("steric_clash_backbone")) {
+        quality_problems_string="Yes+";
+      }
+
+      if (assessed==true) {
+        if (ritv.HasProp(settings.label)) {
+          Real lddt_local=ritv.GetFloatProp(settings.label);
+          std::stringstream stkeylddt;
+          stkeylddt <<  std::fixed << std::setprecision(4) << lddt_local;
+          lddt_local_string=stkeylddt.str();
+          int conserved_dist=ritv.GetIntProp(settings.label+"_conserved");
+          int total_dist=ritv.GetIntProp(settings.label+"_total");
+          std::stringstream stkeydist;
+          stkeydist << "("<< conserved_dist << "/" << total_dist << ")";
+          dist_string=stkeydist.str();
+        } else {
+          // assessed residue without a stored score: report zeros, print nothing extra
+          lddt_local_string="0.0000";
+          dist_string="(0/0)";
+        }
+      }
+      if (settings.structural_checks) {
+        std::cout << ritv.GetChain() << "\t" << ritv.GetName() << "\t" << ritv.GetNumber() << '\t' << assessed_string  << '\t' << quality_problems_string << '\t' << lddt_local_string << "\t" << dist_string << std::endl;
+      } else {
+        std::cout << ritv.GetChain() << "\t" << ritv.GetName() << "\t" << ritv.GetNumber() << '\t' << assessed_string  << '\t' << lddt_local_string << "\t" << dist_string << std::endl;
+      }
+    }
+  }
+  std::cout << std::endl;
+}
+
 // debugging code
 /*
 Real OldStyleLDDTHA(EntityView& v, const GlobalRDMap& global_dist_list)
diff --git a/modules/mol/alg/src/local_dist_diff_test.hh b/modules/mol/alg/src/local_dist_diff_test.hh
index 8d29e79ac..775944724 100644
--- a/modules/mol/alg/src/local_dist_diff_test.hh
+++ b/modules/mol/alg/src/local_dist_diff_test.hh
@@ -26,7 +26,6 @@
 
 namespace ost { namespace mol { namespace alg {
 
-
 struct lDDTSettings {
   Real bond_tolerance;
   Real angle_tolerance;
@@ -37,6 +36,8 @@ struct lDDTSettings {
   bool structural_checks;
   bool consistency_checks;
   std::vector<Real> cutoffs;
+  bool print_stats;
+  String label;
 
   lDDTSettings();
   lDDTSettings(Real init_bond_tolerance,
@@ -47,8 +48,16 @@ struct lDDTSettings {
                String init_parameter_file_path,
                bool init_structural_checks,
                bool init_consistency_checks,
-               std::vector<Real> init_cutoffs);
+               std::vector<Real>& init_cutoffs,
+               bool init_print_stats,
+               String init_label);
+  void PrintParameters();
+  std::string ToString();
 };
+/// \brief Returns (reference residues found in the first chain of v, reference residues); (0,-1) if both are empty.
+std::pair<int,int> DLLEXPORT_OST_MOL_ALG ComputeCoverage(const EntityView& v,const GlobalRDMap& glob_dist_list);
+/// \brief Returns true if resnum is one of the keys of glob_dist_list.
+bool DLLEXPORT_OST_MOL_ALG IsResnumInGlobalRDMap(const ResNum& resnum, const GlobalRDMap& glob_dist_list);
   
 /// \brief Calculates number of distances conserved in a model, given a list of distances to check and a model
 ///
@@ -99,10 +108,10 @@ Real DLLEXPORT_OST_MOL_ALG LocalDistDiffTest(const EntityView& mdl,
                                          Real max_dist,
                                          const String& local_ldt_property_string="");
 /// TODO document me
-Real DLLEXPORT_OST_MOL_ALG LocalDistDiffTest(const EntityView& mdl,
-                                         const std::vector<EntityView>& targets,
-                                         const lDDTSettings& settings,
-                                         const String& local_ldt_property_string="");
+Real DLLEXPORT_OST_MOL_ALG LocalDistDiffTest(const EntityView& v, ///< model view to score
+                       std::vector<EntityView>& ref_list,  ///< references used for the residue-name check
+                       const GlobalRDMap& glob_dist_list,  ///< reference distance map used for coverage and scoring
+                       lDDTSettings& settings);            ///< supplies consistency_checks, cutoffs, sequence_separation, label
 
 /// \brief Calculates the Local Distance Difference Test score for a given model starting from an alignment between a reference structure and the model. 
 ///
@@ -182,6 +191,10 @@ void DLLEXPORT_OST_MOL_ALG CheckStructure(EntityView& ent,
                                           Real bond_tolerance,
                                           Real angle_tolerance);
 
+void DLLEXPORT_OST_MOL_ALG PrintlDDTPerResidueStats(EntityHandle& model, ///< "peptide=true" residues of its first chain are listed
+                              GlobalRDMap& glob_dist_list,  ///< residues whose number is a key here count as assessed
+                              lDDTSettings& settings);      ///< supplies label, cutoffs, structural_checks
+
 }}}
 
 #endif
-- 
GitLab