From b59448d2109b874a9870f30a3889ef2ea75d19da Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Thu, 13 Feb 2020 08:59:42 +0100
Subject: [PATCH] allow adapting parameter defaults in HMMScore

---
 modules/seq/alg/pymod/wrap_seq_alg.cc |  8 ++++-
 modules/seq/alg/src/hmm_score.cc      | 47 +++++++++++++++++++--------
 modules/seq/alg/src/hmm_score.hh      |  8 ++++-
 3 files changed, 48 insertions(+), 15 deletions(-)

diff --git a/modules/seq/alg/pymod/wrap_seq_alg.cc b/modules/seq/alg/pymod/wrap_seq_alg.cc
index 0f6c77e66..784b3af9b 100644
--- a/modules/seq/alg/pymod/wrap_seq_alg.cc
+++ b/modules/seq/alg/pymod/wrap_seq_alg.cc
@@ -262,7 +262,13 @@ void export_hmm_algorithms() {
   def("AddTransitionPseudoCounts", &AddTransitionPseudoCounts, (arg("profile")));
   def("AddNullPseudoCounts", &AddNullPseudoCounts, (arg("profile")));
   def("HMMScore", &HMMScore, (arg("profile_0"), arg("profile_1"), arg("alignment"),
-                              arg("s_0_idx"), arg("s_1_idx")));
+                              arg("s_0_idx"), arg("s_1_idx"), 
+                              arg("match_score_offset")=-0.03,
+                              arg("correl_score_weight")=0.1,
+                              arg("del_start_penalty_factor")=0.6,
+                              arg("del_extend_penalty_factor")=0.6,
+                              arg("ins_start_penalty_factor")=0.6,
+                              arg("ins_extend_penalty_factor")=0.6));
 }
 
 BOOST_PYTHON_MODULE(_ost_seq_alg)
diff --git a/modules/seq/alg/src/hmm_score.cc b/modules/seq/alg/src/hmm_score.cc
index 03dc372a9..f6eddc924 100644
--- a/modules/seq/alg/src/hmm_score.cc
+++ b/modules/seq/alg/src/hmm_score.cc
@@ -31,7 +31,11 @@ void SetupSequence(const ost::seq::ProfileHandle& profile,
 Real InsertionTransitionScore(const ost::seq::ProfileHandle& prof_0, 
                               const ost::seq::ProfileHandle& prof_1,
                               int before_0, int before_1,
-                              int after_0, int after_1) {
+                              int after_0, int after_1,
+                              Real del_start_penalty_factor,
+                              Real del_extend_penalty_factor,
+                              Real ins_start_penalty_factor,
+                              Real ins_extend_penalty_factor) {
 
   // option 1:
   // s0 switched to insertion state after the beginning of the gap
@@ -42,10 +46,12 @@ Real InsertionTransitionScore(const ost::seq::ProfileHandle& prof_0,
 
   // s0
   ins_score += 
-  std::log2(prof_0[before_0].GetTransProb(ost::seq::HMM_M2I)) * 0.6;
+  std::log2(prof_0[before_0].GetTransProb(ost::seq::HMM_M2I)) * 
+  ins_start_penalty_factor;
   int l = after_1 - before_1 - 1;
   ins_score += 
-  (l - 1) * std::log2(prof_0[before_0].GetTransProb(ost::seq::HMM_I2I)) * 0.6;    
+  (l - 1) * std::log2(prof_0[before_0].GetTransProb(ost::seq::HMM_I2I)) * 
+  ins_extend_penalty_factor;    
   ins_score += std::log2(prof_0[before_0].GetTransProb(ost::seq::HMM_I2M));
 
   // s1
@@ -65,9 +71,11 @@ Real InsertionTransitionScore(const ost::seq::ProfileHandle& prof_0,
 
   // s1
   del_score += 
-  std::log2(prof_1[before_1].GetTransProb(ost::seq::HMM_M2D)) * 0.6;
+  std::log2(prof_1[before_1].GetTransProb(ost::seq::HMM_M2D)) * 
+  del_start_penalty_factor;
   for(int i = before_1 + 1; i < after_1 - 1; ++i ) {
-  	del_score += std::log2(prof_1[i].GetTransProb(ost::seq::HMM_D2D)) * 0.6;
+  	del_score += std::log2(prof_1[i].GetTransProb(ost::seq::HMM_D2D)) * 
+    del_extend_penalty_factor;
   }
   del_score += std::log2(prof_1[after_1-1].GetTransProb(ost::seq::HMM_D2M));
 
@@ -79,9 +87,15 @@ Real InsertionTransitionScore(const ost::seq::ProfileHandle& prof_0,
 namespace ost{ namespace seq{ namespace alg{
 
 Real HMMScore(const ost::seq::ProfileHandle& prof_0, 
-	          const ost::seq::ProfileHandle& prof_1,
-	          const ost::seq::AlignmentHandle& aln,
-	          int seq_0_idx, int seq_1_idx) {
+	            const ost::seq::ProfileHandle& prof_1,
+	            const ost::seq::AlignmentHandle& aln,
+	            int seq_0_idx, int seq_1_idx,
+              Real match_score_offset,
+              Real correl_score_weight, 
+              Real del_start_penalty_factor,
+              Real del_extend_penalty_factor,
+              Real ins_start_penalty_factor,
+              Real ins_extend_penalty_factor) {
 
   String s_0;
   int s_0_o;
@@ -108,7 +122,6 @@ Real HMMScore(const ost::seq::ProfileHandle& prof_0,
   int state = 0; // 0: match
                  // 1: insertion s0
                  // 2: insertion s1
-  Real offset = -0.03; // described in 2005 hmm-hmm comparison paper by soding
 
   // sum up column scores and all MM->MM transition scores
   for(uint idx = 0; idx < s_0.size(); ++idx) {
@@ -170,14 +183,22 @@ Real HMMScore(const ost::seq::ProfileHandle& prof_0,
         transition_score += InsertionTransitionScore(prof_0, prof_1,  
                                                      ins_0_before_col_0+s_0_o, 
                                                      ins_0_before_col_1+s_1_o, 
-                                                     col_0+s_0_o, col_1+s_1_o);
+                                                     col_0+s_0_o, col_1+s_1_o,
+                                                     del_start_penalty_factor,
+                                                     del_extend_penalty_factor,
+                                                     ins_start_penalty_factor,
+                                                     ins_extend_penalty_factor);
       }
       if(state == 2 && ins_1_before_col_1 >= 0) {
         // an insertion in s0 ends here
         transition_score += InsertionTransitionScore(prof_1, prof_0, 
                                                      ins_1_before_col_1+s_1_o, 
                                                      ins_1_before_col_0+s_0_o, 
-                                                     col_1+s_1_o, col_0+s_0_o);
+                                                     col_1+s_1_o, col_0+s_0_o,
+                                                     del_start_penalty_factor,
+                                                     del_extend_penalty_factor,
+                                                     ins_start_penalty_factor,
+                                                     ins_extend_penalty_factor);
       }
       ++col_0;
       ++col_1;
@@ -204,8 +225,8 @@ Real HMMScore(const ost::seq::ProfileHandle& prof_0,
   Real score = std::accumulate(matching_scores.begin(), 
                                matching_scores.end(), 0.0);
 
-  score += n_matches * offset;
-  score += 0.1 * correl_score;
+  score += n_matches * match_score_offset;
+  score += correl_score_weight * correl_score;
   score += transition_score;
 
   return score;
diff --git a/modules/seq/alg/src/hmm_score.hh b/modules/seq/alg/src/hmm_score.hh
index c84f74d47..9b5682441 100644
--- a/modules/seq/alg/src/hmm_score.hh
+++ b/modules/seq/alg/src/hmm_score.hh
@@ -28,7 +28,13 @@ namespace ost{ namespace seq{ namespace alg{
 Real HMMScore(const ost::seq::ProfileHandle& profile_0, 
 	          const ost::seq::ProfileHandle& profile_1,
 	          const ost::seq::AlignmentHandle& aln,
-	          int seq_0_idx, int seq_1_idx);
+	          int seq_0_idx, int seq_1_idx, 
+	          Real match_score_offset = -0.03,
+	          Real correl_score_weight = 0.1,
+	          Real del_start_penalty_factor=0.6,
+	          Real del_extend_penalty_factor=0.6,
+	          Real ins_start_penalty_factor=0.6,
+	          Real ins_extend_penalty_factor=0.6);
 
 }}} // ns
 
-- 
GitLab