Skip to content
Snippets Groups Projects
Commit 7e3744b6 authored by Andreas Schenk's avatar Andreas Schenk
Browse files

refactored stat_accumulator to handle weighted values and arbitrary order moments

parent 889a9f16
No related branches found
No related tags found
No related merge requests found
......@@ -20,54 +20,47 @@
#ifndef OST_STAT_ACCUMULATOR_HH
#define OST_STAT_ACCUMULATOR_HH
#include <boost/math/special_functions/binomial.hpp>
#include <ost/base.hh>
#include <ost/message.hh>
#include <ost/img/alg/module_config.hh>
namespace ost { namespace img { namespace alg {
template<unsigned int MAX_MOMENT=4>
class StatAccumulator
{
public:
StatAccumulator():
mean_(0.0),
sum_(0.0),
sum2_(0.0),
m2_(0.0),
m3_(0.0),
m4_(0.0),
n_(1)
{}
void operator()(Real val)
m_(),
w_(0.0),
n_(0)
{
Real delta,delta_n,delta_n2,term;
if(MAX_MOMENT>0){
delta = val - mean_;
delta_n = delta / n_;
}
if(MAX_MOMENT>3){
delta_n2 = delta_n * delta_n;
for(unsigned int i=0;i<MAX_MOMENT;++i){
m_[i]=0.0;
}
if(MAX_MOMENT>1){
term = delta * delta_n * (n_-1);
}
if(MAX_MOMENT>3){
m4_ += term * delta_n2 * (n_*n_ - 3*n_ + 3) + 6 * delta_n2 * m2_ - 4 * delta_n * m3_;
}
if(MAX_MOMENT>2){
m3_ += term * delta_n * (n_ - 2) - 3 * delta_n * m2_;
}
if(MAX_MOMENT>1){
m2_ += term;
}
if(MAX_MOMENT>0){
mean_ += delta_n;
}
StatAccumulator(Real val, Real w=1.0):
sum_(val),
sum2_(val*val),
m_(),
w_(w),
n_(1)
{
m_[0]=val;
for(unsigned int i=1;i<MAX_MOMENT;++i){
m_[i]=0.0;
}
n_+=1;
sum_+=val;
sum2_+=val*val;
}
void operator()(Real val, Real w=1.0)
{
*this+=StatAccumulator(val,w);
}
StatAccumulator<MAX_MOMENT> operator+(const StatAccumulator<MAX_MOMENT>& acc2) const
......@@ -79,53 +72,78 @@ public:
StatAccumulator<MAX_MOMENT>& operator+=(const StatAccumulator<MAX_MOMENT>& acc)
{
if(acc.n_==1){
if(0.0>=w_){
*this=acc;
return *this;
}
if(n_==1){
mean_=acc.mean_;
sum_=acc.sum_;
sum2_=acc.sum2_;
m2_=acc.m2_;
m3_=acc.m3_;
m4_=acc.m4_;
n_=acc.n_;
if(0.0>=acc.w_){
return *this;
}
Real delta,delta_n,delta_n2,na,nanb;
Real nb=acc.n_-1;
if(MAX_MOMENT>0){
na=n_-1;
delta = acc.mean_ - mean_;
delta_n = delta / (na+nb);
}
if(MAX_MOMENT>1){
nanb=na*nb;
}
if(MAX_MOMENT>2){
delta_n2 = delta_n * delta_n;
}
if(MAX_MOMENT>3){
m4_+=acc.m4_+delta*delta_n*delta_n2*nanb*(na*na-nanb+nb*nb)+6.0*delta_n2*(na*na*acc.m2_+nb*nb*m2_)+4.0*delta_n*(na*acc.m3_-nb*m3_);
}
if(MAX_MOMENT>2){
m3_+=acc.m3_+delta*delta_n2*nanb*(na-nb)+3.0*delta_n*(na*acc.m2_-nb*m2_);
}
if(MAX_MOMENT>1){
m2_ += acc.m2_+delta*delta_n*nanb;
}
if(MAX_MOMENT>0){
mean_ += nb*delta_n;
}
n_+=acc.n_;
Real wn=w_+acc.w_;
sum_+=acc.sum_;
sum2_+=acc.sum2_;
n_+=nb;
if(MAX_MOMENT>0){
Real delta=acc.m_[0]-m_[0];
Real delta_w=delta/wn;
if(MAX_MOMENT>2){
Real w1w2_delta_w=w_*acc.w_*delta_w;
Real w1w2_delta_wp=w1w2_delta_w*w1w2_delta_w;
Real iw2=1.0/acc.w_;
Real iw2pm1=iw2;
Real miw1=-1.0/w_;
Real miw1pm1=miw1;
Real mn[MAX_MOMENT]; // only MAX_MOMENT-2 values needed but overdimensioned to allow compilation for MAX_MOMENT==0 (even though it gets kicked out in the end by the dead code elimination)
for(unsigned int p=3;p<=MAX_MOMENT;++p){
w1w2_delta_wp*=w1w2_delta_w;
iw2pm1*=iw2;
miw1pm1*=miw1;
Real delta_wk=1.0;
Real s=0.0;
Real mw2k=1.0;
Real w1k=1.0;
for(unsigned int k=1;k<=p-2;++k){
w1k*=w_;
mw2k*=-acc.w_;
delta_wk*=delta_w;
s+=boost::math::binomial_coefficient<Real>(p, k)*(mw2k*m_[p-k-1]+w1k*acc.m_[p-k-1])*delta_wk;
}
mn[p-3]=acc.m_[p-1]+s+w1w2_delta_wp*(iw2pm1-miw1pm1);
}
for(unsigned int p=3;p<=MAX_MOMENT;++p){
m_[p-1]+=mn[p-3];
}
}
if(MAX_MOMENT>1){
m_[1]+=acc.m_[1]+delta_w*delta*acc.w_*w_;
}
m_[0]+=delta_w*acc.w_;
w_=wn;
}
return *this;
}
Real GetNumSamples() const
{
return n_-1.0;
return n_;
}
Real GetWeight() const
{
return w_;
}
Real GetMoment(unsigned int i) const
{
if(1>i){
throw Error("Invalid moment.");
}
if(MAX_MOMENT<i){
throw Error("Moment was not calculated.");
}
return m_[i-1];
}
Real GetMean() const
......@@ -133,7 +151,7 @@ public:
if(MAX_MOMENT<1){
throw Error("Mean was not calculated.");
}
return mean_;
return m_[0];
}
Real GetSum() const
......@@ -146,10 +164,10 @@ public:
if(MAX_MOMENT<2){
throw Error("Variance was not calculated.");
}
if(n_==1.0){
if(n_==0.0){
return 0.0;
}
return m2_/(n_-1);
return m_[1]/(w_);
}
Real GetStandardDeviation() const
......@@ -159,10 +177,10 @@ public:
Real GetRootMeanSquare() const
{
if(n_==1.0){
if(n_==0.0){
return 0.0;
}
return sqrt(sum2_/(n_-1));
return sqrt(sum2_/(w_));
}
Real GetSkewness() const
......@@ -170,10 +188,10 @@ public:
if(MAX_MOMENT<3){
throw Error("Skewness was not calculated.");
}
if(m2_<1e-20){
if(m_[1]<1e-20){
return 0.0;
}else{
return m3_/sqrt(m2_*m2_*m2_);
return m_[2]/sqrt(m_[1]*m_[1]*m_[1]);
}
}
......@@ -182,21 +200,19 @@ public:
if(MAX_MOMENT<4){
throw Error("Kurtosis was not calculated.");
}
if(m2_<1e-20){
if(m_[1]<1e-20){
return 0.0;
}else{
return ((n_-1)*m4_) / (m2_*m2_);
return w_*m_[3] / (m_[1]*m_[1]);
}
}
private:
Real mean_;
Real sum_;
Real sum2_;
Real m2_;
Real m3_;
Real m4_;
Real n_;
Real m_[MAX_MOMENT];
Real w_;
unsigned int n_;
};
}}} //ns
......
......@@ -86,6 +86,51 @@ void test() {
BOOST_CHECK_CLOSE(acc_c.GetStandardDeviation(),Real(2.58198889747),Real(0.0001));
BOOST_CHECK_CLOSE(acc_c.GetSkewness()+Real(0.5),Real(0.5),Real(0.0001));
BOOST_CHECK_CLOSE(acc_c.GetKurtosis(),Real(1.77),Real(0.0001));
// check accumulator template for restriction to lower order moments
StatAccumulator<3> acc4;
for(int u=0;u<3;++u) {
for(int v=0;v<3;++v) {
acc4(val[u][v]);
}
}
BOOST_CHECK_CLOSE(acc4.GetMean(),Real(5.0),Real(0.0001));
BOOST_CHECK_CLOSE(acc4.GetStandardDeviation(),Real(2.58198889747),Real(0.0001));
BOOST_CHECK_CLOSE(acc4.GetSkewness()+Real(0.5),Real(0.5),Real(0.0001));
BOOST_CHECK_THROW(acc4.GetKurtosis(),ost::Error);
StatAccumulator<2> acc5;
for(int u=0;u<3;++u) {
for(int v=0;v<3;++v) {
acc5(val[u][v]);
}
}
BOOST_CHECK_CLOSE(acc5.GetMean(),Real(5.0),Real(0.0001));
BOOST_CHECK_CLOSE(acc5.GetStandardDeviation(),Real(2.58198889747),Real(0.0001));
BOOST_CHECK_THROW(acc5.GetSkewness(),ost::Error);
BOOST_CHECK_THROW(acc5.GetKurtosis(),ost::Error);
StatAccumulator<1> acc6;
for(int u=0;u<3;++u) {
for(int v=0;v<3;++v) {
acc6(val[u][v]);
}
}
BOOST_CHECK_CLOSE(acc6.GetMean(),Real(5.0),Real(0.0001));
BOOST_CHECK_THROW(acc6.GetStandardDeviation(),ost::Error);
BOOST_CHECK_THROW(acc6.GetSkewness(),ost::Error);
BOOST_CHECK_THROW(acc6.GetKurtosis(),ost::Error);
StatAccumulator<0> acc7;
for(int u=0;u<3;++u) {
for(int v=0;v<3;++v) {
acc7(val[u][v]);
}
}
BOOST_CHECK_THROW(acc7.GetMean(),ost::Error);
BOOST_CHECK_THROW(acc7.GetStandardDeviation(),ost::Error);
BOOST_CHECK_THROW(acc7.GetSkewness(),ost::Error);
BOOST_CHECK_THROW(acc7.GetKurtosis(),ost::Error);
}
} // namespace
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment