From 0cac598dbd6becf42c28bc974fdd0742078bf482 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Sun, 12 Mar 2023 22:31:10 +0100
Subject: [PATCH] Update Motif Finder

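The motif finder algorithm was designed to search a database of many small
motifs against a single target structure. This change adapts it to the
opposite use case: searching for a single motif in many larger structures.
A new swap_thresh flag makes hash_thresh and refine_thresh relative to the
number of target positions instead of the query size.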
---
 modelling/doc/algorithms.rst           |  6 ++-
 modelling/pymod/export_motif_finder.cc |  9 ++--
 modelling/src/motif_finder.cc          | 60 +++++++++++++++++---------
 modelling/src/motif_finder.hh          |  3 +-
 4 files changed, 53 insertions(+), 25 deletions(-)

diff --git a/modelling/doc/algorithms.rst b/modelling/doc/algorithms.rst
index c9db60f1..5545a9ac 100644
--- a/modelling/doc/algorithms.rst
+++ b/modelling/doc/algorithms.rst
@@ -308,7 +308,7 @@ iteration.
 
 .. method:: FindMotifs(query, target_positions, hash_tresh=0.4, \
                        distance_thresh=1.0, refine_thresh=0.7, \
-                       flags=list())
+                       flags=list(), swap_thresh=False)
 
   Performs the detection and refinement stages of the geometric hashing 
   algorithm. 
@@ -322,6 +322,10 @@ iteration.
                         constructor. If you didn't provide anything there,
                         this can be ignored. Only the actual coordinates
                         matter in this case.
+  :param swap_thresh: By default, *hash_thresh* and *refine_thresh* are
+                      fractions of covered positions in *query*. If set to
+                      True, they are fractions of covered positions in
+                      *target_positions* instead.
 
   :returns:             All found matches
 
diff --git a/modelling/pymod/export_motif_finder.cc b/modelling/pymod/export_motif_finder.cc
index e966af44..c711e12a 100644
--- a/modelling/pymod/export_motif_finder.cc
+++ b/modelling/pymod/export_motif_finder.cc
@@ -69,7 +69,8 @@ boost::python::list WrapFindMotifs(const MotifQuery& query,
                                    Real hash_thresh,
                                    Real distance_thresh,
                                    Real refine_thresh,
-                                   const boost::python::list& flags) {
+                                   const boost::python::list& flags,
+                                   bool swap_thresh) {
 
   std::vector<int> v_flags;
   promod3::core::ConvertListToVector(flags, v_flags);
@@ -78,7 +79,8 @@ boost::python::list WrapFindMotifs(const MotifQuery& query,
                                                 hash_thresh,
                                                 distance_thresh,
                                                 refine_thresh,
-                                                v_flags);
+                                                v_flags,
+                                                swap_thresh);
   list return_list;
   for(std::vector<MotifMatch>::iterator it = v_result.begin();
       it != v_result.end(); ++it) {
@@ -153,5 +155,6 @@ void export_motif_finder() {
                                       arg("hash_thresh")=0.4,
                                       arg("distance_thresh")=1.0,
                                       arg("refine_thresh")=0.7,
-                                      arg("flags")=boost::python::list()));
+                                      arg("flags")=boost::python::list(),
+                                      arg("swap_thresh")=false));
 }
diff --git a/modelling/src/motif_finder.cc b/modelling/src/motif_finder.cc
index 61c62f07..2ec8d756 100644
--- a/modelling/src/motif_finder.cc
+++ b/modelling/src/motif_finder.cc
@@ -172,20 +172,29 @@ struct InitialHits{
 struct Accumulator{
 
   Accumulator(const promod3::modelling::MotifQuery& query, 
-              Real coverage_thresh) {
+              Real coverage_thresh, int fix_thresh) {
 
     int n = 0;
     for(uint i = 0; i < query.GetN(); ++i) {
       range_start.push_back(n);
-      uint32_t query_size = query.GetQuerySize(i);
-      uint32_t n_triangles = query.GetNTriangles(i);
-      uint16_t thresh = std::ceil((query_size-3) * coverage_thresh);
-      for(uint j = 0; j < n_triangles; ++j) {
-        thresholds.push_back(thresh);
-      }
-      n += n_triangles;
+      n += query.GetNTriangles(i);
     }
     accumulator.assign(n, 0);
+
+    if(fix_thresh < 0) {
+      // default: thresholds depend on sizes of query
+      for(uint i = 0; i < query.GetN(); ++i) {
+        uint32_t query_size = query.GetQuerySize(i);
+        uint32_t n_triangles = query.GetNTriangles(i);
+        uint16_t thresh = std::ceil((query_size-3) * coverage_thresh);
+        for(uint j = 0; j < n_triangles; ++j) {
+          thresholds.push_back(thresh);
+        }
+      }
+    } else {
+      // assign the same fixed threshold to every triangle
+      thresholds.assign(n, fix_thresh);
+    }
   }
 
 
@@ -416,9 +425,9 @@ void GetInitialAlignment(const promod3::core::EMat3X& query_pos,
 bool RefineInitialHit(const promod3::modelling::MotifQuery& query,  
                       const promod3::core::EMat3X& target_pos, 
                       const std::vector<int>& target_flags,
-                      Real dist_thresh, Real refine_thresh, int query_idx, 
-                      int query_triangle_idx, int target_p1, int target_p2, 
-                      int target_p3, geom::Mat4& mat, 
+                      Real dist_thresh, Real refine_thresh, bool swap_thresh,
+                      int query_idx, int query_triangle_idx, int target_p1,
+                      int target_p2, int target_p3, geom::Mat4& mat, 
                       std::vector<std::pair<int, int> >& alignment) {
 
   // the query is only available as geom::Vec3List. This makes sense from a 
@@ -527,8 +536,14 @@ bool RefineInitialHit(const promod3::modelling::MotifQuery& query,
   }
 
   // check whether enough positions are superposed
-  if(static_cast<Real>(alignment.size()) / query_n < refine_thresh) {
-    return false;
+  if(swap_thresh) {
+    if(static_cast<Real>(alignment.size()) / target_pos.cols() < refine_thresh) {
+      return false;
+    }
+  } else {
+    if(static_cast<Real>(alignment.size()) / query_n < refine_thresh) {
+      return false;
+    }
   }
 
   // chain together the final transformation matrix
@@ -546,7 +561,7 @@ void RefineInitialHits(const InitialHits& initial_hits,
                        const promod3::modelling::MotifQuery& query,
                        const promod3::core::EMat3X& target_pos, 
                        const std::vector<int>& flags, Real dist_thresh,
-                       Real refine_thresh,
+                       Real refine_thresh, bool swap_thresh,
                        std::vector<promod3::modelling::MotifMatch>& results) {
 
 
@@ -562,9 +577,9 @@ void RefineInitialHits(const InitialHits& initial_hits,
     for(auto it = initial_hits.initial_hits[query_idx].begin();
         it != initial_hits.initial_hits[query_idx].end(); ++it) {
       if(RefineInitialHit(query, target_pos, flags, dist_thresh, 
-                          refine_thresh, query_idx, it->triangle_idx, 
-                          it->target_p1, it->target_p2, it->target_p3, 
-                          mat, aln)) {
+                          refine_thresh, swap_thresh, query_idx,
+                          it->triangle_idx,  it->target_p1, it->target_p2,
+                          it->target_p3, mat, aln)) {
         // only add the result if its unique
         bool already_there = false;
         for(uint res_idx = 0; res_idx < query_results.size(); ++res_idx) {
@@ -1302,7 +1317,8 @@ std::vector<MotifMatch> FindMotifs(const MotifQuery& query,
                                    Real hash_thresh, 
                                    Real distance_thresh,
                                    Real refine_thresh,
-                                   const std::vector<int>& flags) {
+                                   const std::vector<int>& flags,
+                                   bool swap_thresh) {
 
   promod3::core::ScopedTimerPtr prof = promod3::core::StaticRuntimeProfiler::StartScoped(
                                 "FindMotifs::FindMotifs", 2);
@@ -1345,7 +1361,11 @@ std::vector<MotifMatch> FindMotifs(const MotifQuery& query,
 
   // fetch hash map and setup accumulator
   const MotifHasherMap& map = query.data_->map;
-  Accumulator accumulator(query, hash_thresh);
+  int acc_thresh = -1;
+  if(swap_thresh) {
+    acc_thresh = std::ceil((n_target-3) * hash_thresh);
+  }
+  Accumulator accumulator(query, hash_thresh, acc_thresh);
     
   for(int p1 = 0; p1 < n_target; ++p1) {
     for(int p2 = p1+1; p2 < n_target; ++p2) {
@@ -1421,7 +1441,7 @@ std::vector<MotifMatch> FindMotifs(const MotifQuery& query,
   }
 
   RefineInitialHits(initial_hits, query, eigen_positions, flags, 
-                    distance_thresh, refine_thresh, results);
+                    distance_thresh, refine_thresh, swap_thresh, results);
 
   return results;
 }
diff --git a/modelling/src/motif_finder.hh b/modelling/src/motif_finder.hh
index 72991120..10806f14 100644
--- a/modelling/src/motif_finder.hh
+++ b/modelling/src/motif_finder.hh
@@ -105,7 +105,8 @@ std::vector<MotifMatch> FindMotifs(const MotifQuery& query,
                                    Real hash_thresh = 0.4,
                                    Real distance_thresh = 1.0,
                                    Real refine_thresh = 0.7,
-                                   const std::vector<int>& flags = std::vector<int>());
+                                   const std::vector<int>& flags = std::vector<int>(),
+                                   bool swap_thresh=false);
 
 
 }} // ns
-- 
GitLab