From 9a4bf7c81882fee7847f57efaad806ea0bc8b5b0 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Fri, 14 Feb 2020 08:26:26 +0100
Subject: [PATCH] allow to adapt parameter defaults when adding transition
 pseudo counts

---
 modules/seq/alg/pymod/wrap_seq_alg.cc    |  5 ++++-
 modules/seq/alg/src/hmm_pseudo_counts.cc | 23 ++++++++++++-----------
 modules/seq/alg/src/hmm_pseudo_counts.hh |  4 +++-
 3 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/modules/seq/alg/pymod/wrap_seq_alg.cc b/modules/seq/alg/pymod/wrap_seq_alg.cc
index b9935d13d..40a858f60 100644
--- a/modules/seq/alg/pymod/wrap_seq_alg.cc
+++ b/modules/seq/alg/pymod/wrap_seq_alg.cc
@@ -267,7 +267,10 @@ void export_hmm_algorithms() {
                                                          arg("a")=0.9,
                                                          arg("b")=4.0,
                                                          arg("c")=1.0));
-  def("AddTransitionPseudoCounts", &AddTransitionPseudoCounts, (arg("profile")));
+  def("AddTransitionPseudoCounts", &AddTransitionPseudoCounts, (arg("profile"),
+                                                                arg("gapb")=1.0,
+                                                                arg("gapd")=0.15,
+                                                                arg("gape")=1.0));
   def("AddNullPseudoCounts", &AddNullPseudoCounts, (arg("profile")));
   def("HMMScore", &HMMScore, (arg("profile_0"), arg("profile_1"), arg("alignment"),
                               arg("s_0_idx"), arg("s_1_idx"), 
diff --git a/modules/seq/alg/src/hmm_pseudo_counts.cc b/modules/seq/alg/src/hmm_pseudo_counts.cc
index 6c031e868..37d2d490b 100644
--- a/modules/seq/alg/src/hmm_pseudo_counts.cc
+++ b/modules/seq/alg/src/hmm_pseudo_counts.cc
@@ -416,15 +416,16 @@ ContextProfileDBPtr ContextProfileDB::FromCRF(const String& filename) {
 }
 
 
-void AddTransitionPseudoCounts(ost::seq::ProfileHandle& profile) {
+void AddTransitionPseudoCounts(ost::seq::ProfileHandle& profile,
+                               Real gapb, Real gapd, Real gape) {
 
   // a priori probabilities estimated with default values of HHblits
-  Real pM2D = 0.15 * 0.0286;
+  Real pM2D = gapd * 0.0286;
   Real pM2I = pM2D; 
   Real pM2M = 1 - pM2D - pM2I;
-  Real pI2I = 0.75;
+  Real pI2I = 1.0 * gape / (gape - 1 + 1.0 / 0.75);;
   Real pI2M = 1 - pI2I;
-  Real pD2D = 0.75;
+  Real pD2D = pI2I;
   Real pD2M = 1 - pD2D;
 
   for (size_t col_idx = 0; col_idx < profile.size(); ++col_idx) {
@@ -433,9 +434,9 @@ void AddTransitionPseudoCounts(ost::seq::ProfileHandle& profile) {
 
     // Transitions from M state
     Real neff = data->GetNeff();
-    Real p0 = (neff - 1) * data->GetProb(ost::seq::HMM_M2M) + pM2M;
-    Real p1 = (neff - 1) * data->GetProb(ost::seq::HMM_M2D) + pM2D;
-    Real p2 = (neff - 1) * data->GetProb(ost::seq::HMM_M2I) + pM2I;
+    Real p0 = (neff - 1) * data->GetProb(ost::seq::HMM_M2M) + gapb*pM2M;
+    Real p1 = (neff - 1) * data->GetProb(ost::seq::HMM_M2D) + gapb*pM2D;
+    Real p2 = (neff - 1) * data->GetProb(ost::seq::HMM_M2I) + gapb*pM2I;
     Real sum = p0 + p1 + p2;
     data->SetProb(ost::seq::HMM_M2M, p0/sum);
     data->SetProb(ost::seq::HMM_M2D, p1/sum);
@@ -443,16 +444,16 @@ void AddTransitionPseudoCounts(ost::seq::ProfileHandle& profile) {
 
     // Transitions from I state
     Real neff_i = data->GetNeff_I();
-    p0 = neff_i * data->GetProb(ost::seq::HMM_I2M) + pI2M;
-    p1 = neff_i * data->GetProb(ost::seq::HMM_I2I) + pI2I;
+    p0 = neff_i * data->GetProb(ost::seq::HMM_I2M) + gapb*pI2M;
+    p1 = neff_i * data->GetProb(ost::seq::HMM_I2I) + gapb*pI2I;
     sum = p0 + p1;
     data->SetProb(ost::seq::HMM_I2M, p0/sum);
     data->SetProb(ost::seq::HMM_I2I, p1/sum);
 
     // Transitions from D state
     Real neff_d = data->GetNeff_D();
-    p0 = neff_d * data->GetProb(ost::seq::HMM_D2M) + pD2M;
-    p1 = neff_d * data->GetProb(ost::seq::HMM_D2D) + pD2D;
+    p0 = neff_d * data->GetProb(ost::seq::HMM_D2M) + gapb*pD2M;
+    p1 = neff_d * data->GetProb(ost::seq::HMM_D2D) + gapb*pD2D;
     sum = p0 + p1;
     data->SetProb(ost::seq::HMM_D2M, p0/sum);
     data->SetProb(ost::seq::HMM_D2D, p1/sum);
diff --git a/modules/seq/alg/src/hmm_pseudo_counts.hh b/modules/seq/alg/src/hmm_pseudo_counts.hh
index 573f76385..b967d9e6a 100644
--- a/modules/seq/alg/src/hmm_pseudo_counts.hh
+++ b/modules/seq/alg/src/hmm_pseudo_counts.hh
@@ -158,7 +158,9 @@ private:
 std::vector<ContextProfile> profiles_;
 };
 
-void AddTransitionPseudoCounts(ost::seq::ProfileHandle& profile);
+void AddTransitionPseudoCounts(ost::seq::ProfileHandle& profile,
+                               Real gapb = 1.0, Real gabd = 0.15, 
+                               Real gape = 1.0);
 
 void AddAAPseudoCounts(ost::seq::ProfileHandle& profile,
                        Real a = 1.0, Real b = 1.5, Real c = 1.0);
-- 
GitLab