From bd3682cbe0d1e5a538f4c7b86218759b2690ac3c Mon Sep 17 00:00:00 2001
From: Gerardo Tauriello <gerardo.tauriello@unibas.ch>
Date: Wed, 28 Oct 2020 18:45:13 +0100
Subject: [PATCH] Added subsampling for VarianceMap/Dist2Mean/MeanlDDT.

---
 modules/seq/alg/doc/seqalg.rst        | 18 ++++++++++
 modules/seq/alg/pymod/wrap_seq_alg.cc | 49 ++++++++++++++++++++++++++-
 2 files changed, 66 insertions(+), 1 deletion(-)

diff --git a/modules/seq/alg/doc/seqalg.rst b/modules/seq/alg/doc/seqalg.rst
index 96328e1af..9a5470fa4 100644
--- a/modules/seq/alg/doc/seqalg.rst
+++ b/modules/seq/alg/doc/seqalg.rst
@@ -586,6 +586,15 @@ differences between the structures.
     :returns: A list of :meth:`GetSize` lists with :meth:`GetSize` variances.
     :rtype:   :class:`list` of :class:`list` of :class:`float`
 
+  .. method:: GetSubData(num_res_to_avg)
+
+    Gets subset of data in this map by averaging neighboring values for
+    *num_res_to_avg* residues.
+
+    :returns: A list of ceil(:meth:`GetSize`/*num_res_to_avg*) lists with
+              ceil(:meth:`GetSize`/*num_res_to_avg*) variances.
+    :rtype:   :class:`list` of :class:`list` of :class:`float`
+
 .. class:: Dist2Mean
 
   Container returned by :func:`CreateDist2Mean`.
@@ -645,6 +654,15 @@ differences between the structures.
               :meth:`GetNumStructures` distances.
     :rtype:   :class:`list` of :class:`list` of :class:`float`
 
+  .. method:: GetSubData(num_res_to_avg)
+
+    Gets subset of data in this map by averaging neighboring values for
+    *num_res_to_avg* residues.
+
+    :returns: A list of ceil(:meth:`GetNumResidues`/*num_res_to_avg*) lists with
+              :meth:`GetNumStructures` distances.
+    :rtype:   :class:`list` of :class:`list` of :class:`float`
+
 
 .. class:: MeanlDDT
 
diff --git a/modules/seq/alg/pymod/wrap_seq_alg.cc b/modules/seq/alg/pymod/wrap_seq_alg.cc
index 79873278e..f9562d5b2 100644
--- a/modules/seq/alg/pymod/wrap_seq_alg.cc
+++ b/modules/seq/alg/pymod/wrap_seq_alg.cc
@@ -38,6 +38,8 @@
 #include <ost/seq/alg/hmm_pseudo_counts.hh>
 #include <ost/seq/alg/hmm_score.hh>
 
+#include <algorithm>
+
 using namespace boost::python;
 using namespace ost::seq;
 using namespace ost::seq::alg;
@@ -79,10 +81,51 @@ list DistToMeanGetData(const Dist2MeanPtr d2m) {
   return GetList(*d2m, d2m->GetNumResidues(), d2m->GetNumStructures());
 }
 
-list  MeanlDDTGetData(const MeanlDDTPtr ld) {
+list MeanlDDTGetData(const MeanlDDTPtr ld) {
   return GetList(*ld, ld->GetNumResidues(), ld->GetNumStructures());
 }
 
+template <typename T>
+list GetSubList(const T& data, uint num_rows, uint num_cols, uint rows_to_avg,
+                uint cols_to_avg) {
+  if (rows_to_avg < 1 || cols_to_avg < 1) {
+    throw ost::Error("Invalid number of data to average!");
+  }
+  list ret;
+  Real n_to_avg = rows_to_avg * cols_to_avg;
+  for (uint row = 0; row < num_rows; row += rows_to_avg) {
+    list my_row;
+    for (uint col = 0; col < num_cols; col += cols_to_avg) {
+      Real avg_data = 0;
+      const uint max_row = std::min(num_rows, row + rows_to_avg);
+      const uint max_col = std::min(num_cols, col + cols_to_avg);
+      for (uint sub_row = row; sub_row < max_row; ++sub_row) {
+        for (uint sub_col = col; sub_col < max_col; ++sub_col) {
+          avg_data += data(sub_row, sub_col);
+        }
+      }
+      my_row.append(avg_data / n_to_avg);
+    }
+    ret.append(my_row);
+  }
+  return ret;
+}
+
+list VarMapGetSubData(const VarianceMapPtr v_map, uint num_res_to_avg) {
+  return GetSubList(*v_map, v_map->GetSize(), v_map->GetSize(),
+                    num_res_to_avg, num_res_to_avg);
+}
+
+list DistToMeanGetSubData(const Dist2MeanPtr d2m, uint num_res_to_avg) {
+  return GetSubList(*d2m, d2m->GetNumResidues(), d2m->GetNumStructures(),
+                    num_res_to_avg, 1);
+}
+
+list MeanlDDTGetSubData(const MeanlDDTPtr ld, uint num_res_to_avg) {
+  return GetSubList(*ld, ld->GetNumResidues(), ld->GetNumStructures(),
+                    num_res_to_avg, 1);
+}
+
 void AAPseudoCountsSimple(ProfileHandle& profile, Real a, Real b, Real c) {
   AddAAPseudoCounts(profile, a, b, c);
 }
@@ -225,6 +268,7 @@ void export_distance_analysis()
     .def("ExportJson", &VarianceMap::ExportJson, (arg("file_name")))
     .def("GetJsonString", &VarianceMap::GetJsonString)
     .def("GetData", &VarMapGetData)
+    .def("GetSubData", &VarMapGetSubData, (arg("num_res_to_avg")))
   ;
 
   class_<Dist2Mean, Dist2MeanPtr,
@@ -237,7 +281,9 @@ void export_distance_analysis()
     .def("ExportJson", &Dist2Mean::ExportJson, (arg("file_name")))
     .def("GetJsonString", &Dist2Mean::GetJsonString)
     .def("GetData", &DistToMeanGetData)
+    .def("GetSubData", &DistToMeanGetSubData, (arg("num_res_to_avg")))
   ;
+
   class_<MeanlDDT, MeanlDDTPtr,
          boost::noncopyable>("MeanlDDT", no_init)
     .def("Get", &MeanlDDT::Get, (arg("i_res"), arg("i_str")))
@@ -248,6 +294,7 @@ void export_distance_analysis()
     .def("ExportJson", &MeanlDDT::ExportJson, (arg("file_name")))
     .def("GetJsonString", &MeanlDDT::GetJsonString)
     .def("GetData", &MeanlDDTGetData)
+    .def("GetSubData", &MeanlDDTGetSubData, (arg("num_res_to_avg")))
   ;
 }
 
-- 
GitLab