Skip to content
Snippets Groups Projects
Commit 6bdd21ee authored by Studer Gabriel's avatar Studer Gabriel
Browse files

add pseudo counts according to Angermueller et al. 2012

parent 12692e83
No related branches found
No related tags found
No related merge requests found
......@@ -78,6 +78,15 @@ list VarMapGetData(const VarianceMapPtr v_map) {
// Converts the content of a Dist2Mean object into a list by handing the
// dereferenced object together with its residue and structure counts to
// GetList.
// NOTE(review): the exact layout of the returned list is defined by GetList,
// which is not visible here.
list DistToMeanGetData(const Dist2MeanPtr d2m) {
  const auto n_residues = d2m->GetNumResidues();
  const auto n_structures = d2m->GetNumStructures();
  return GetList(*d2m, n_residues, n_structures);
}
// Forwards to the AddAAPseudoCounts overload that only takes a profile.
// The wrapper is what gets registered under the Python name
// "AddAAPseudoCounts" (single-argument variant) in export_hmm_algorithms.
// NOTE(review): presumably needed because taking the address of the
// overloaded AddAAPseudoCounts directly would be ambiguous - confirm.
void AAPseudoCountsSimple(ProfileHandle& profile) {
AddAAPseudoCounts(profile);
}
// Forwards to the AddAAPseudoCounts overload that additionally takes a
// ContextProfileDB. The wrapper is what gets registered under the Python name
// "AddAAPseudoCounts" (profile + context_profile_db variant) in
// export_hmm_algorithms.
// NOTE(review): presumably needed because taking the address of the
// overloaded AddAAPseudoCounts directly would be ambiguous - confirm.
void AAPseudoCountsAngermueller(ProfileHandle& profile, const ContextProfileDB& db) {
AddAAPseudoCounts(profile, db);
}
} // anon ns
////////////////////////////////////////////////////////////////////
......@@ -228,7 +237,28 @@ void export_distance_analysis()
////////////////////////////////////////////////////////////////////
// algorithms involving hmms
void export_hmm_algorithms() {
def("AddAAPseudoCounts", &AddAAPseudoCounts, (arg("profile")));
class_<ContextProfile>("ContextProfile", init<int>())
.def("SetWeight",&ContextProfile::SetWeight, (arg("pos"), arg("olc"), arg("weight")))
.def("SetPseudoCount",&ContextProfile::SetPseudoCount, (arg("olc"), arg("count")))
.def("SetBias",&ContextProfile::SetBias, (arg("bias")))
.def("GetWeight", &ContextProfile::GetWeight, (arg("pos"), arg("olc")))
.def("GetPseudoCount", &ContextProfile::GetPseudoCount,(arg("olc")))
.def("GetBias", &ContextProfile::GetBias)
.def("GetLength", &ContextProfile::GetLength)
;
class_<ContextProfileDB, ContextProfileDBPtr>("ContextProfileDB", init<>())
.def("__len__",&ContextProfileDB::size)
.def("__getitem__",&ContextProfileDB::at,return_value_policy<reference_existing_object>(), (arg("idx")))
.def("Save", &ContextProfileDB::Save, (arg("filename")))
.def("Load", &ContextProfileDB::Load, (arg("filename"))).staticmethod("Load")
.def("FromCRF", &ContextProfileDB::FromCRF, (arg("filename"))).staticmethod("FromCRF")
.def("AddProfile", &ContextProfileDB::AddProfile, (arg("profile")))
;
def("AddAAPseudoCounts", &AAPseudoCountsSimple, (arg("profile")));
def("AddAAPseudoCounts", &AAPseudoCountsAngermueller, (arg("profile"), arg("context_profile_db")));
def("AddTransitionPseudoCounts", &AddTransitionPseudoCounts, (arg("profile")));
def("HMMScore", &HMMScore, (arg("profile_0"), arg("profile_1"), arg("alignment"),
arg("s_0_idx"), arg("s_1_idx")));
......
#include <ost/seq/alg/hmm_pseudo_counts.hh>
#include <boost/algorithm/string/predicate.hpp>
#include <boost/filesystem/convenience.hpp>
#include <boost/filesystem/fstream.hpp>
#include <boost/iostreams/filter/gzip.hpp>
#include <boost/iostreams/filtering_stream.hpp>
namespace {
#include <limits>
#include <cmath>
#include <ost/string_ref.hh>
namespace {
// to mimic the HHblits behaviour, R is based on the Gonnet substitution matrix
// every entry R[a][b] corresponds to P(a|b)
const Real R[20][20] = {
......@@ -34,6 +44,378 @@ namespace {
namespace ost{ namespace seq{ namespace alg{
void ContextProfileDB::Save(const String& filename) const {
if(profiles_.empty()) {
throw Error("Cannot save empty ContextProfileDB");
}
std::ofstream out_stream(filename.c_str(), std::ios::binary);
if (!out_stream){
std::stringstream ss;
ss << "the file '" << filename << "' could not be opened.";
throw Error(ss.str());
}
std::vector<Real> data;
for(auto it = profiles_.begin(); it != profiles_.end(); ++it) {
const std::vector<Real>& p_data = it->GetData();
data.insert(data.end(), p_data.begin(), p_data.end());
}
uint32_t magic_number = 424242;
out_stream.write(reinterpret_cast<char*>(&magic_number), sizeof(uint32_t));
uint8_t version = 1;
out_stream.write(reinterpret_cast<char*>(&version), sizeof(uint8_t));
uint32_t length = profiles_[0].GetLength();
out_stream.write(reinterpret_cast<char*>(&length),sizeof(uint32_t));
uint32_t n_profiles = profiles_.size();
out_stream.write(reinterpret_cast<char*>(&n_profiles),sizeof(uint32_t));
uint32_t data_size = data.size();
out_stream.write(reinterpret_cast<char*>(&data_size),sizeof(uint32_t));
out_stream.write(reinterpret_cast<char*>(&data[0]), data_size*sizeof(Real));
}
// Reads a database from a binary file as written by ContextProfileDB::Save.
//
// Expected layout: uint32 magic number (424242), uint8 version (1),
// uint32 profile length, uint32 number of profiles, uint32 number of Real
// values, followed by that many Real values.
//
// Throws Error if the file cannot be opened, if magic number or version do
// not match, if the file ends prematurely or if the sizes stored in the
// header are inconsistent with each other.
ContextProfileDBPtr ContextProfileDB::Load(const String& filename) {

  std::ifstream in_stream(filename.c_str(), std::ios::binary);
  if(!in_stream) {
    std::stringstream ss;
    ss << "the file '" << filename << "' could not be opened.";
    throw Error(ss.str());
  }

  // a failed read leaves the variable untouched, so the stream state is
  // checked together with the value
  uint32_t magic_number = 0;
  in_stream.read(reinterpret_cast<char*>(&magic_number), sizeof(uint32_t));
  if(!in_stream || magic_number != 424242) {
    std::stringstream ss;
    ss << "Could not read magic number in " << filename<<". Either the file ";
    ss << "is corrupt or does not contain a ContextProfileDB.";
    throw Error(ss.str());
  }

  uint8_t version = 0;
  in_stream.read(reinterpret_cast<char*>(&version), sizeof(uint8_t));
  if(!in_stream) {
    throw Error("Unexpected end of file while reading header of " + filename);
  }
  if(version != 1) {
    std::stringstream ss;
    // uint8_t is a character type for streams, cast to get the number printed
    ss << "ContextProfileDB in " << filename << " is of version ";
    ss << static_cast<int>(version);
    ss << " but only version 1 can be read.";
    throw Error(ss.str());
  }

  uint32_t length = 0;
  in_stream.read(reinterpret_cast<char*>(&length), sizeof(uint32_t));
  uint32_t n_profiles = 0;
  in_stream.read(reinterpret_cast<char*>(&n_profiles), sizeof(uint32_t));
  uint32_t data_size = 0;
  in_stream.read(reinterpret_cast<char*>(&data_size), sizeof(uint32_t));
  if(!in_stream) {
    throw Error("Unexpected end of file while reading header of " + filename);
  }

  // The three sizes must agree, otherwise the profiles constructed below
  // would read beyond the end of the data buffer.
  // The limit on length guarantees that DataSize does not overflow int.
  const uint32_t max_length = (std::numeric_limits<int>::max() - 21) / 20;
  if(length > max_length) {
    throw Error("Inconsistent sizes in header of " + filename);
  }
  const size_t profile_data_size =
  ContextProfile::DataSize(static_cast<int>(length));
  const uint64_t expected_size =
  static_cast<uint64_t>(profile_data_size) * n_profiles;
  if(expected_size != data_size) {
    throw Error("Inconsistent sizes in header of " + filename);
  }

  ContextProfileDBPtr db(new ContextProfileDB);
  if(data_size == 0) {
    // implies n_profiles == 0, nothing to read
    return db;
  }

  std::vector<Real> data(data_size, 0.0);
  in_stream.read(reinterpret_cast<char*>(&data[0]), data_size*sizeof(Real));
  if(!in_stream) {
    throw Error("Unexpected end of file while reading data of " + filename);
  }

  db->profiles_.reserve(n_profiles);
  size_t data_loc = 0;
  for(uint32_t i = 0; i < n_profiles; ++i, data_loc += profile_data_size) {
    db->profiles_.push_back(ContextProfile(static_cast<int>(length),
                                           &data[data_loc]));
  }

  return db;
}
// Builds a ContextProfileDB from a text file. If the filename has the
// extension ".gz" (case insensitive), the content is gzip-decompressed while
// reading.
//
// Format as processed by this parser:
//  - a header providing "SIZE <n>" (number of entries) and "LENG <l>"
//    (length of every entry); both must appear before the first entry
//  - entries starting with a line beginning with "CrfState" and ending with
//    a line beginning with "//". Within an entry the following lines are
//    processed:
//      BIAS <float>
//      LENG <int>               must equal LENG from the header
//      ALPH <int>               must be 20
//      WEIGHTS <olc> x ALPH     followed by LENG lines: <pos> <value> x ALPH
//                               with pos in [1, LENG]
//      PC <value> x ALPH
//    Lines in an entry that match none of these keys are ignored.
//
// All WEIGHTS and PC values are multiplied with 0.001 when read.
// NOTE(review): presumably the file stores them scaled by 1000 - confirm.
//
// Throws Error if the file cannot be opened, if it is malformed or if the
// number of entries does not match SIZE from the header.
ContextProfileDBPtr ContextProfileDB::FromCRF(const String& filename) {

  // open it up
  boost::iostreams::filtering_stream<boost::iostreams::input> in;
  boost::filesystem::ifstream stream(filename);
  if(!stream) {
    throw Error("Could not open " + filename);
  }
  // add unzip if necessary
  if(boost::iequals(".gz", boost::filesystem::extension(filename))) {
    in.push(boost::iostreams::gzip_decompressor());
  }
  in.push(stream);

  // tmp. storage
  std::string line;
  ost::StringRef sline;
  std::vector<ost::StringRef> chunks;

  ////////////////////////////////////////
  // header: read until SIZE and LENG   //
  // are both available                 //
  ////////////////////////////////////////
  int size = -1;
  int length = -1;
  while(std::getline(in, line)) {
    sline = ost::StringRef(line.c_str(), line.length());
    if(sline.length()>=8 &&
       sline.substr(0, 8) == ost::StringRef("CrfState", 8)) {
      throw Error("Require to read SIZE and LENG before first ContextProfile");
    }
    if(sline.length()>4 &&
       sline.substr(0, 4) == ost::StringRef("SIZE", 4)) {
      chunks = sline.split();
      if(chunks.size() != 2) {
        throw Error("Badly formatted line: " + line);
      }
      std::pair<bool, int> s = chunks[1].to_int();
      if (!s.first) {
        throw Error("Badly formatted line: " + line);
      }
      size = s.second;
    }
    if(sline.length()>4 &&
       sline.substr(0, 4) == ost::StringRef("LENG", 4)) {
      chunks = sline.split();
      if(chunks.size() != 2) {
        throw Error("Badly formatted line: " + line);
      }
      std::pair<bool, int> l = chunks[1].to_int();
      if (!l.first) {
        throw Error("Badly formatted line: " + line);
      }
      length = l.second;
    }
    if(size != -1 && length != -1) {
      break;
    }
  }

  // Catches files that end before SIZE/LENG were found as well as negative
  // values. A negative length would lead to a negative data size when
  // constructing a ContextProfile.
  if(size < 0 || length < 0) {
    throw Error("Require valid, non-negative SIZE and LENG in file header");
  }

  ////////////////////////////////////////
  // entries                            //
  ////////////////////////////////////////
  ContextProfileDBPtr db(new ContextProfileDB);
  bool in_crf_state=false;

  // stuff that we read for every context profile, NaN / -1 / empty mark
  // values that have not been read yet
  Real cp_bias = std::numeric_limits<Real>::quiet_NaN();
  int cp_leng = -1;
  int cp_alph = -1;
  std::vector<char> cp_olcs;
  std::vector<std::vector<Real> > cp_weights;
  std::vector<Real> cp_pc;

  while(std::getline(in, line)) {
    sline = ost::StringRef(line.c_str(), line.length());

    // skip everything until the start of the next entry
    if(!in_crf_state) {
      if(sline.length()>=8 &&
         sline.substr(0, 8) == ost::StringRef("CrfState", 8)) {
        in_crf_state = true;
      }
      continue;
    }

    if(sline.length()>4 &&
       sline.substr(0, 4) == ost::StringRef("BIAS", 4)) {
      chunks = sline.split();
      if(chunks.size() != 2) {
        throw Error("Badly formatted line: " + line);
      }
      std::pair<bool, Real> b = chunks[1].to_float();
      if (!b.first) {
        throw Error("Badly formatted line: " + line);
      }
      cp_bias = b.second;
      continue;
    }

    if(sline.length()>4 &&
       sline.substr(0, 4) == ost::StringRef("LENG", 4)) {
      chunks = sline.split();
      if(chunks.size() != 2) {
        throw Error("Badly formatted line: " + line);
      }
      std::pair<bool, int> l = chunks[1].to_int();
      if (!l.first) {
        throw Error("Badly formatted line: " + line);
      }
      cp_leng = l.second;
      continue;
    }

    if(sline.length()>4 &&
       sline.substr(0, 4) == ost::StringRef("ALPH", 4)) {
      chunks = sline.split();
      if(chunks.size() != 2) {
        throw Error("Badly formatted line: " + line);
      }
      std::pair<bool, int> a = chunks[1].to_int();
      if (!a.first) {
        throw Error("Badly formatted line: " + line);
      }
      cp_alph = a.second;
      continue;
    }

    if(sline.length()>7 &&
       sline.substr(0, 7) == ost::StringRef("WEIGHTS", 7)) {
      if(cp_alph == -1 || cp_leng == -1) {
        throw Error("Require LENG and ALPH before reading WEIGHTS in entry");
      }
      // cp_leng is used to allocate the weight storage right below. Check it
      // here already to avoid allocations of negative or arbitrary size.
      if(cp_leng != length) {
        throw Error("Require all entries to be of same length as specified in file header");
      }
      chunks = sline.split();
      if(chunks.size() != static_cast<uint>(cp_alph+1)) {
        throw Error("Badly formatted line: " + line);
      }
      // the first character of every item is the one letter code of the
      // according column
      for(int i = 0; i < cp_alph; ++i) {
        cp_olcs.push_back(chunks[i+1][0]);
      }
      cp_weights = std::vector<std::vector<Real> >(cp_leng, std::vector<Real>());

      // the next cp_leng lines should be the weights
      for(int i = 0; i < cp_leng; ++i) {
        if(!std::getline(in, line)) {
          throw Error("Failed to load all WEIGHTS in entry");
        }
        sline = ost::StringRef(line.c_str(), line.length());
        if(sline.length()>=2 &&
           sline.substr(0, 2) == ost::StringRef("//", 2)) {
          throw Error("Arrived at end of entry before reading all weights");
        }
        chunks = sline.split();
        if(chunks.size() != static_cast<uint>(cp_alph + 1)) {
          throw Error("Failed to load all WEIGHTS in entry");
        }
        // read the position of the weights. if it can't be parsed to int,
        // something is fishy (e.g. already another key word)
        std::pair<bool, int> p = chunks[0].to_int();
        if(!p.first) {
          throw Error("Badly formatted line: " + line);
        }
        // positions in the file start at one
        if(p.second < 1 || p.second > cp_leng) {
          throw Error("Badly formatted line: " + line);
        }
        for(int j = 0; j < cp_alph; ++j) {
          std::pair<bool, Real> w = chunks[j+1].to_float();
          if(!w.first) {
            throw Error("Badly formatted line: " + line);
          }
          cp_weights[p.second-1].push_back(0.001 * w.second);
        }
      }
      continue;
    }

    if(sline.length()>2 &&
       sline.substr(0, 2) == ost::StringRef("PC", 2)) {
      chunks = sline.split();
      if(chunks.size() != static_cast<uint>(cp_alph+1)) {
        throw Error("Badly formatted line: " + line);
      }
      for(int i = 0; i < cp_alph; ++i) {
        std::pair<bool, Real> w = chunks[i+1].to_float();
        if(!w.first) {
          throw Error("Badly formatted line: " + line);
        }
        cp_pc.push_back(0.001*w.second);
      }
    }

    // end of entry: validate what has been collected and add it to the db
    if(sline.length()>=2 &&
       sline.substr(0, 2) == ost::StringRef("//", 2)) {

      //check if data is OK
      if(std::isnan(cp_bias)) {
        throw Error("Observed entry without BIAS value");
      }
      if(cp_leng == -1) {
        throw Error("Observed entry without LENG value");
      }
      if(cp_leng != length) {
        throw Error("Require all entries to be of same length as specified in file header");
      }
      if(cp_alph == -1) {
        throw Error("Observed Entry without ALPH value");
      }
      if(cp_alph != 20) {
        throw Error("Expect ALPH to be 20 for all entries");
      }
      if(cp_olcs.empty()) {
        throw Error("Observed Entry without WEIGHTS value");
      }
      if(cp_olcs.size() != 20) {
        throw Error("Expect exactly 20 items after WEIGHTS key word");
      }
      // also catches positions that were listed twice or not at all
      for(int i = 0; i < length; ++i) {
        if(cp_weights[i].size() != 20) {
          throw Error("Observed Entry with != 20 weight values for each pos");
        }
      }
      if(cp_pc.size() != 20) {
        throw Error("Observed Entry with != 20 PC values");
      }

      // fill data in context profile
      ContextProfile current_context_profile(length);
      current_context_profile.SetBias(cp_bias);
      for(int i = 0; i < length; ++i) {
        for(int j = 0; j < 20; ++j) {
          current_context_profile.SetWeight(i, cp_olcs[j], cp_weights[i][j]);
        }
      }

      // magic rescaling of context profiles as implemented in hhblits
      // pc_i = exp(x_i) / sum_j(exp(x_j)), the maximum is subtracted before
      // applying exp to avoid overflows
      Real max = -std::numeric_limits<Real>::max();
      for(int i = 0; i < 20; ++i) {
        max = std::max(max, cp_pc[i]);
      }
      Real sum = 0.0;
      for(int i = 0; i < 20; ++i) {
        sum += std::exp(cp_pc[i]-max);
      }
      Real tmp = max + std::log(sum);
      for(int i = 0; i < 20; ++i) {
        current_context_profile.SetPseudoCount(cp_olcs[i], std::exp(cp_pc[i]-tmp));
      }

      db->AddProfile(current_context_profile);

      // invalidate all variables for a next profile
      cp_bias = std::numeric_limits<Real>::quiet_NaN();
      cp_leng = -1;
      cp_alph = -1;
      cp_olcs = std::vector<char>();
      cp_weights = std::vector<std::vector<Real> >();
      cp_pc = std::vector<Real>();
      in_crf_state = false;
    }
  }

  if(db->size() != static_cast<size_t>(size)) {
    throw Error("Number of read entries does not correspond to what was promised in the header");
  }

  return db;
}
void AddTransitionPseudoCounts(ost::seq::ProfileHandle& profile) {
// a priori probabilities estimated with default values of HHblits
......@@ -137,4 +519,134 @@ void AddAAPseudoCounts(ost::seq::ProfileHandle& profile) {
profile.SetNullModel(new_null_model);
}
// Adds context specific amino acid pseudo counts to profile, in place.
//
// For every column:
//  1. every context profile in db gets a score: its bias plus the sum over
//     all covered positions of weight * count, where count is the column
//     frequency multiplied with the Neff of the column. The window is
//     centered on the column and truncated at the profile termini.
//  2. the scores are turned into weights with a softmax and the pseudo
//     counts of all context profiles are averaged with these weights and
//     normalized to sum up to one.
//  3. the new column frequencies are a mix of this context distribution and
//     the original frequencies: tau*context + (1-tau)*original with
//     tau = min(1.0, 0.9 / (1.0 + Neff / 4.0)).
// Finally, a new null model is derived from the old null model and the
// updated column frequencies and set in profile.
//
// profile: must have HMM data attached to every column, frequencies and null
//          model get overwritten
// db:      context profiles, must not be empty and the profiles must have
//          odd length
//
// Throws Error if db is empty (from ContextProfileDB::profile_length) or if
// the length of the profiles in db is not odd.
void AddAAPseudoCounts(ost::seq::ProfileHandle& profile,
                       const ContextProfileDB& db) {

  ////////////////////
  // do frequencies //
  ////////////////////
  std::vector<Real> cp_scores(db.size(), 0.0);
  int cp_length = db.profile_length();
  if(cp_length % 2 != 1) {
    throw Error("Length of profiles in db must be an odd number");
  }

  // extension from center to both directions
  int cp_ext = (cp_length - 1) / 2;

  std::vector<std::vector<Real> >
  count_profile(profile.size(), std::vector<Real>(20, 0.0));
  std::vector<std::vector<Real> >
  context_profile(profile.size(), std::vector<Real>(20, 0.0));

  // fill counts profile
  for(size_t col_idx = 0; col_idx < profile.size(); ++col_idx) {
    HMMDataPtr hmm_data = profile[col_idx].GetHMMData();
    Real neff = hmm_data->GetNeff();
    Real* col_freq = profile[col_idx].freqs_begin();
    std::vector<Real>& counts = count_profile[col_idx];
    for(int i = 0; i < 20; ++i) {
      counts[i] = col_freq[i] * neff;
    }
  }

  // process columns
  for(size_t col_idx = 0; col_idx < count_profile.size(); ++col_idx) {

    // [min, max]: range of profile columns covered by a context profile that
    // is centered at col_idx, truncated at the termini
    // cp_min: position in the context profile that corresponds to min
    int min = std::max(0, static_cast<int>(col_idx) - cp_ext);
    int max = std::min(col_idx + cp_ext, count_profile.size() - 1);
    int cp_min = static_cast<int>(col_idx) < cp_ext ?
                 std::abs(static_cast<int>(col_idx) - cp_ext) : 0;

    Real max_score = -std::numeric_limits<Real>::max();

    // estimate score for each context profile in db
    for(size_t cp_idx = 0; cp_idx < db.size(); ++cp_idx) {
      const ContextProfile& cp = db[cp_idx];
      Real score = cp.GetBias();
      for(int i = min, j = cp_min; i<=max; ++i, ++j) {
        const Real* weights = cp.GetWeights(j);
        const std::vector<Real>& counts = count_profile[i];
        for(int k = 0; k < 20; ++k) {
          score += weights[k] * counts[k];
        }
      }
      max_score = std::max(max_score, score);
      cp_scores[cp_idx] = score;
    }

    // same in hhblits code: log-sum-exp trick to avoid overflows
    // log(sum_i(exp(s_i))) = max + log(sum_i(exp(s_i - max)))
    Real summed_exp_score = 0.0;
    for(size_t i = 0; i < cp_scores.size(); ++i) {
      summed_exp_score += std::exp(cp_scores[i]-max_score);
    }
    Real tmp = max_score + std::log(summed_exp_score);

    std::vector<Real>& col_freq = context_profile[col_idx];
    for(size_t cp_idx = 0; cp_idx < db.size(); ++cp_idx) {
      // softmax weight of the context profile, all weights sum up to one
      Real w = std::exp(cp_scores[cp_idx]-tmp);
      const Real* cp_pc = db[cp_idx].GetPseudoCounts();
      for(int i = 0; i < 20; ++i) {
        col_freq[i] += w*cp_pc[i];
      }
    }

    // normalize
    Real sum = 0.0;
    for(int i = 0; i < 20; ++i) {
      sum += col_freq[i];
    }
    Real norm_factor = 1.0 / sum;
    for(int i = 0; i < 20; ++i) {
      col_freq[i] *= norm_factor;
    }
  }

  // mix together original and context profile to get final frequencies
  for(size_t col_idx = 0; col_idx < profile.size(); ++col_idx) {
    // tau estimated as in hhblits in diversity dependent mode:
    // tau = a/(1+((Neff[i]-1)/b)^c) with default values a=0.9, b=4.0, c=1.0
    // this is the equation they write in HHblits when you display the help
    // (Neff[i]-1) got rid of the -1. Well, that's how HHblits implements it
    Real neff = profile[col_idx].GetHMMData()->GetNeff();
    Real tau = std::min(1.0, 0.9 / (1.0 + (neff) / 4.0));
    Real* col_freq = profile[col_idx].freqs_begin();
    const std::vector<Real>& context = context_profile[col_idx];
    for(int i = 0; i < 20; ++i) {
      // The original frequency is used directly. Deriving it from the counts
      // (counts / neff) would give NaN for columns with a Neff of zero.
      col_freq[i] = tau*context[i] + (1.-tau)*col_freq[i];
    }
  }

  /////////////////////////
  // do null_frequencies //
  /////////////////////////

  // the current null model enters with a weight of 100 / Neff, every column
  // of the profile with a weight of one
  Real mixing_factor = 100.0 / profile.GetNeff();
  const Real* current_null_freq = profile.GetNullModel().freqs_begin();
  Real null_freq[20];
  for(int i = 0; i < 20; ++i) {
    null_freq[i] = current_null_freq[i] * mixing_factor;
  }

  for(size_t i = 0; i < profile.size(); ++i) {
    const Real* freq = profile[i].freqs_begin();
    for(int j = 0; j < 20; ++j) {
      null_freq[j] += freq[j];
    }
  }

  // normalize
  Real summed_p = 0.0;
  for(int i = 0; i < 20; ++i) {
    summed_p += null_freq[i];
  }
  Real factor = 1.0/summed_p;
  for(int i = 0; i < 20; ++i) {
    null_freq[i] *= factor;
  }

  // create new nullmodel and set it
  ost::seq::ProfileColumn new_null_model;
  Real* new_null_freq = new_null_model.freqs_begin();
  for(int i = 0; i < 20; ++i) {
    new_null_freq[i] = null_freq[i];
  }
  profile.SetNullModel(new_null_model);
}
}}} // ns
......@@ -24,10 +24,147 @@
namespace ost{ namespace seq{ namespace alg{
class ContextProfileDB;
typedef boost::shared_ptr<ContextProfileDB> ContextProfileDBPtr;
class ContextProfile{
public:
ContextProfile(int length): length_(length),
data_(ContextProfile::DataSize(length), 0.0) { }
ContextProfile(int length, Real* data): length_(length),
data_(ContextProfile::DataSize(length), 0.0) {
memcpy(&data_[0], data, data_.size() * sizeof(Real));
}
void SetWeight(int pos, char olc, Real weight) {
if(pos >= length_) {
throw Error("Tried to access invalid pos in ContextProfile");
}
int olc_idx = ProfileColumn::GetIndex(olc);
if(olc_idx != -1) {
data_[pos*20 + olc_idx] = weight;
} else {
throw Error("Invalid one letter code in ContextProfile");
}
}
void SetPseudoCount(char olc, Real count) {
int olc_idx = ProfileColumn::GetIndex(olc);
if(olc_idx != -1) {
data_[length_*20 + olc_idx] = count;
} else {
throw Error("Invalid one letter code in ContextProfile");
}
}
void SetBias(Real bias) { data_.back() = bias; }
const Real* GetWeights(int pos) const{
if(pos >= length_) {
throw Error("Tried to access invalid pos in ContextProfile");
}
return &data_[pos*20];
}
Real GetWeight(int pos, char olc) {
if(pos >= length_) {
throw Error("Tried to access invalid pos in ContextProfile");
}
int olc_idx = ProfileColumn::GetIndex(olc);
if(olc_idx != -1) {
return data_[pos*20 + olc_idx];
} else {
throw Error("Invalid one letter code in ContextProfile");
}
}
const Real* GetPseudoCounts() const { return &data_[length_*20]; }
Real GetPseudoCount(char olc) {
int olc_idx = ProfileColumn::GetIndex(olc);
if(olc_idx != -1) {
return data_[length_*20 + olc_idx];
} else {
throw Error("Invalid one letter code in ContextProfile");
}
}
Real GetBias() const { return data_.back(); }
const std::vector<Real>& GetData() const { return data_; }
int GetLength() const { return length_; }
static int DataSize(int length) { return (length+1)*20+1; }
private:
int length_;
// data organisation:
// context weights in chunks of 20 (length_ chunks)
// followed by 20 elements representing the context pseudo counts
// last element is the bias
std::vector<Real> data_;
};
class ContextProfileDB {
public:
ContextProfileDB() { }
void Save(const String& filename) const;
static ContextProfileDBPtr Load(const String& filename);
static ContextProfileDBPtr FromCRF(const String& filename);
void AddProfile(const ContextProfile& profile){
// enforce same length for all profiles
if(!profiles_.empty()) {
if(profile.GetLength() != profiles_[0].GetLength()) {
throw Error("Require all profiles to be of same length");
}
}
profiles_.push_back(profile);
}
const ContextProfile& operator [](int idx) const {
return profiles_[idx];
}
const ContextProfile& at(int idx) const {
return profiles_.at(idx);
}
size_t size() const {
return profiles_.size();
}
size_t profile_length() const {
if(profiles_.empty()) {
throw Error("DB must contain profiles to get profile length");
}
return profiles_[0].GetLength();
}
private:
std::vector<ContextProfile> profiles_;
};
void AddTransitionPseudoCounts(ost::seq::ProfileHandle& profile);
void AddAAPseudoCounts(ost::seq::ProfileHandle& profile);
void AddAAPseudoCounts(ost::seq::ProfileHandle& profile,
const ContextProfileDB& db);
}}} // ns
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment