From d79a1f396d308ef14b9f44b2eb2b87f7539e3a02 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Wed, 15 Jan 2025 16:17:45 +0100
Subject: [PATCH] introduce GetHits method

Returns all entries that have at least a given minimum number of
kmer hits. This differs from the TopN method, which either starts
returning poor matches once no close hits remain, or misses
relevant hits when there are more of them than requested.
---
 kmatch.cpp | 42 +++++++++++++++++++++++++++++++++++-------
 1 file changed, 35 insertions(+), 7 deletions(-)

diff --git a/kmatch.cpp b/kmatch.cpp
index f2a93df..9de0877 100644
--- a/kmatch.cpp
+++ b/kmatch.cpp
@@ -306,8 +306,7 @@ public:
     meta_stream.close();  
   }
 
-  std::vector<int32_t> TopN(const std::string& sequence, int32_t top_n,
-                            bool unique) const {
+  void Accumulate(const std::string& sequence, bool unique) const {
 
     ///////////
     // SETUP //
@@ -358,11 +357,13 @@ public:
           ++accumulator_[read_buffer[i]];
         }
       }
-    }
+    } 
+  }
+
+  std::vector<int32_t> TopN(const std::string& sequence, int32_t top_n,
+                            bool unique) const {
 
-    ///////////////
-    // GET TOP N //
-    ///////////////
+    this->Accumulate(sequence, unique);
 
     // pair of numbers per element in top_n (count and index)
     // which are sorted by counts (descending, i.e. top count in front)
@@ -398,6 +399,32 @@ public:
     return best_v;
   }
 
+  std::vector<int32_t> GetHits(const std::string& sequence, int32_t min_hits,
+                               bool unique) const {
+    // Returns a flat vector of (count, entry index) pairs for every entry
+    // whose accumulated count is at least min_hits; pairs are ordered by
+    // descending count, ties broken by descending entry index.
+    this->Accumulate(sequence, unique);
+    // gather the qualifying entries as (count, index)
+    typedef std::pair<int32_t, int32_t> CountIdx;
+    std::vector<CountIdx> selected;
+    for(int32_t idx = 0; idx < N_; ++idx) {
+      if(accumulator_[idx] >= min_hits) {
+        selected.push_back(CountIdx(accumulator_[idx], idx));
+      }
+    }
+    std::sort(selected.begin(), selected.end(), std::greater<CountIdx>());
+
+    // flatten: even slots hold counts, odd slots hold entry indices
+    int32_t n_selected = selected.size();
+    std::vector<int32_t> flat(2 * n_selected);
+    for(int32_t k = 0; k < n_selected; ++k) {
+      flat[2*k] = selected[k].first;
+      flat[2*k+1] = selected[k].second;
+    }
+    return flat;
+  }
+
 private:
   bool in_mem_indexer_;
   std::vector<int64_t> pos_;
@@ -417,7 +444,8 @@ PYBIND11_MODULE(kmatch, m) {
     pybind11::class_<KMatch>(m, "KMatch")
         .def(pybind11::init<const std::string&, bool>())
         .def_static("FromFasta", &KMatch::FromFasta)
-        .def("TopN", &KMatch::TopN);
+        .def("TopN", &KMatch::TopN)
+        .def("GetHits", &KMatch::GetHits);
 }
 
 } // ns
-- 
GitLab