diff --git a/modules/seq/alg/pymod/wrap_seq_alg.cc b/modules/seq/alg/pymod/wrap_seq_alg.cc index 784b3af9b7950ad6c397d1eaa07a29052e574562..a913e3bbc74c00942c1f9ae9b8117d3d126d42d7 100644 --- a/modules/seq/alg/pymod/wrap_seq_alg.cc +++ b/modules/seq/alg/pymod/wrap_seq_alg.cc @@ -79,12 +79,13 @@ list DistToMeanGetData(const Dist2MeanPtr d2m) { return GetList(*d2m, d2m->GetNumResidues(), d2m->GetNumStructures()); } -void AAPseudoCountsSimple(ProfileHandle& profile) { - AddAAPseudoCounts(profile); +void AAPseudoCountsSimple(ProfileHandle& profile, Real a, Real b, Real c) { + AddAAPseudoCounts(profile, a, b, c); } -void AAPseudoCountsAngermueller(ProfileHandle& profile, const ContextProfileDB& db) { - AddAAPseudoCounts(profile, db); +void AAPseudoCountsAngermueller(ProfileHandle& profile, const ContextProfileDB& db, + Real a, Real b, Real c) { + AddAAPseudoCounts(profile, db, a, b, c); } } // anon ns @@ -257,8 +258,15 @@ void export_hmm_algorithms() { .def("AddProfile", &ContextProfileDB::AddProfile, (arg("profile"))) ; - def("AddAAPseudoCounts", &AAPseudoCountsSimple, (arg("profile"))); - def("AddAAPseudoCounts", &AAPseudoCountsAngermueller, (arg("profile"), arg("context_profile_db"))); + def("AddAAPseudoCounts", &AAPseudoCountsSimple, (arg("profile"), + arg("a")=1.0, + arg("b")=1.5, + arg("c")=1.0)); + def("AddAAPseudoCounts", &AAPseudoCountsAngermueller, (arg("profile"), + arg("context_profile_db"), + arg("a")=0.9, + arg("b")=4.0, + arg("c")=1.0)); def("AddTransitionPseudoCounts", &AddTransitionPseudoCounts, (arg("profile"))); def("AddNullPseudoCounts", &AddNullPseudoCounts, (arg("profile"))); def("HMMScore", &HMMScore, (arg("profile_0"), arg("profile_1"), arg("alignment"), diff --git a/modules/seq/alg/src/hmm_pseudo_counts.cc b/modules/seq/alg/src/hmm_pseudo_counts.cc index 7372e2a4c53b540ba98780a240306e83682a24b5..6c031e868fbc52d92610f9e6fb4528902c8d56a8 100644 --- a/modules/seq/alg/src/hmm_pseudo_counts.cc +++ b/modules/seq/alg/src/hmm_pseudo_counts.cc @@ -460,7 +460,8 @@ void AddTransitionPseudoCounts(ost::seq::ProfileHandle& profile) { } -void AddAAPseudoCounts(ost::seq::ProfileHandle& profile) { +void AddAAPseudoCounts(ost::seq::ProfileHandle& profile, + Real a, Real b, Real c) { Real full_admixture [20]; for(size_t col_idx = 0; col_idx < profile.size(); ++col_idx) { @@ -476,7 +477,7 @@ void AddAAPseudoCounts(ost::seq::ProfileHandle& profile) { // this is the equation they write in HHblits when you display the help // (Neff[i]-1) got rid of the -1. Well, that's how HHblits implements it Real neff = profile[col_idx].GetHMMData()->GetNeff(); - Real tau = std::min(1.0, 1.0 / (1.0 + (neff) / 1.5)); + Real tau = std::min(1.0, a / (1.0 + std::pow((neff) / b, c))); for (int i = 0; i < 20; ++i) { col_freq[i] = (1. - tau) * col_freq[i] + tau * full_admixture[i]; } @@ -485,7 +486,8 @@ void AddAAPseudoCounts(ost::seq::ProfileHandle& profile) { void AddAAPseudoCounts(ost::seq::ProfileHandle& profile, - const ContextProfileDB& db) { + const ContextProfileDB& db, + Real a, Real b, Real c) { std::vector<Real> cp_scores(db.size(), 0.0); int cp_length = db.profile_length(); @@ -566,7 +568,7 @@ void AddAAPseudoCounts(ost::seq::ProfileHandle& profile, // this is the equation they write in HHblits when you display the help // (Neff[i]-1) got rid of the -1. Well, that's how HHblits implements it Real neff = profile[col_idx].GetHMMData()->GetNeff(); - Real tau = std::min(1.0, 0.9 / (1.0 + (neff) / 4.0)); + Real tau = std::min(1.0, a / (1.0 + std::pow((neff) / b, c))); Real* col_freq = profile[col_idx].freqs_begin(); const std::vector<Real>& counts = count_profile[col_idx]; const std::vector<Real>& context = context_profile[col_idx]; diff --git a/modules/seq/alg/src/hmm_pseudo_counts.hh b/modules/seq/alg/src/hmm_pseudo_counts.hh index bbc48e89a5299b93df5abda1d82a5e0b20720552..573f7638520f2790a362031253070b0595a73584 100644 --- a/modules/seq/alg/src/hmm_pseudo_counts.hh +++ b/modules/seq/alg/src/hmm_pseudo_counts.hh @@ -160,10 +160,12 @@ std::vector<ContextProfile> profiles_; void AddTransitionPseudoCounts(ost::seq::ProfileHandle& profile); -void AddAAPseudoCounts(ost::seq::ProfileHandle& profile); +void AddAAPseudoCounts(ost::seq::ProfileHandle& profile, + Real a = 1.0, Real b = 1.5, Real c = 1.0); void AddAAPseudoCounts(ost::seq::ProfileHandle& profile, - const ContextProfileDB& db); + const ContextProfileDB& db, + Real a = 0.9, Real b = 4.0, Real c = 1.0); void AddNullPseudoCounts(ost::seq::ProfileHandle& profile);