diff --git a/modules/img/alg/src/stat_accumulator.hh b/modules/img/alg/src/stat_accumulator.hh index 53399f8be99c0cb4e1a445b38aabfd7252bb4c41..a744c5e7144f264f85a0d11476412c418e1589f1 100644 --- a/modules/img/alg/src/stat_accumulator.hh +++ b/modules/img/alg/src/stat_accumulator.hh @@ -20,54 +20,47 @@ #ifndef OST_STAT_ACCUMULATOR_HH #define OST_STAT_ACCUMULATOR_HH +#include <boost/math/special_functions/binomial.hpp> #include <ost/base.hh> #include <ost/message.hh> #include <ost/img/alg/module_config.hh> namespace ost { namespace img { namespace alg { + template<unsigned int MAX_MOMENT=4> class StatAccumulator { public: StatAccumulator(): - mean_(0.0), sum_(0.0), sum2_(0.0), - m2_(0.0), - m3_(0.0), - m4_(0.0), - n_(1) - {} - - void operator()(Real val) + m_(), + w_(0.0), + n_(0) { - Real delta,delta_n,delta_n2,term; - if(MAX_MOMENT>0){ - delta = val - mean_; - delta_n = delta / n_; - } - if(MAX_MOMENT>3){ - delta_n2 = delta_n * delta_n; + for(unsigned int i=0;i<MAX_MOMENT;++i){ + m_[i]=0.0; } - if(MAX_MOMENT>1){ - term = delta * delta_n * (n_-1); - } - if(MAX_MOMENT>3){ - m4_ += term * delta_n2 * (n_*n_ - 3*n_ + 3) + 6 * delta_n2 * m2_ - 4 * delta_n * m3_; - } - if(MAX_MOMENT>2){ - m3_ += term * delta_n * (n_ - 2) - 3 * delta_n * m2_; - } - if(MAX_MOMENT>1){ - m2_ += term; - } - if(MAX_MOMENT>0){ - mean_ += delta_n; + + } + + StatAccumulator(Real val, Real w=1.0): + sum_(val), + sum2_(val*val), + m_(), + w_(w), + n_(1) + { + m_[0]=val; + for(unsigned int i=1;i<MAX_MOMENT;++i){ + m_[i]=0.0; } - n_+=1; - sum_+=val; - sum2_+=val*val; + } + + void operator()(Real val, Real w=1.0) + { + *this+=StatAccumulator(val,w); } StatAccumulator<MAX_MOMENT> operator+(const StatAccumulator<MAX_MOMENT>& acc2) const @@ -79,53 +72,78 @@ public: StatAccumulator<MAX_MOMENT>& operator+=(const StatAccumulator<MAX_MOMENT>& acc) { - if(acc.n_==1){ + if(0.0>=w_){ + *this=acc; return *this; } - if(n_==1){ - mean_=acc.mean_; - sum_=acc.sum_; - sum2_=acc.sum2_; - m2_=acc.m2_; - m3_=acc.m3_; - m4_=acc.m4_; - n_=acc.n_; + if(0.0>=acc.w_){ return *this; } - Real delta,delta_n,delta_n2,na,nanb; - Real nb=acc.n_-1; - if(MAX_MOMENT>0){ - na=n_-1; - delta = acc.mean_ - mean_; - delta_n = delta / (na+nb); - } - if(MAX_MOMENT>1){ - nanb=na*nb; - } - if(MAX_MOMENT>2){ - delta_n2 = delta_n * delta_n; - } - if(MAX_MOMENT>3){ - m4_+=acc.m4_+delta*delta_n*delta_n2*nanb*(na*na-nanb+nb*nb)+6.0*delta_n2*(na*na*acc.m2_+nb*nb*m2_)+4.0*delta_n*(na*acc.m3_-nb*m3_); - } - if(MAX_MOMENT>2){ - m3_+=acc.m3_+delta*delta_n2*nanb*(na-nb)+3.0*delta_n*(na*acc.m2_-nb*m2_); - } - if(MAX_MOMENT>1){ - m2_ += acc.m2_+delta*delta_n*nanb; - } - if(MAX_MOMENT>0){ - mean_ += nb*delta_n; - } + n_+=acc.n_; + Real wn=w_+acc.w_; + sum_+=acc.sum_; sum2_+=acc.sum2_; - n_+=nb; + + if(MAX_MOMENT>0){ + Real delta=acc.m_[0]-m_[0]; + Real delta_w=delta/wn; + if(MAX_MOMENT>2){ + Real w1w2_delta_w=w_*acc.w_*delta_w; + Real w1w2_delta_wp=w1w2_delta_w*w1w2_delta_w; + Real iw2=1.0/acc.w_; + Real iw2pm1=iw2; + Real miw1=-1.0/w_; + Real miw1pm1=miw1; + Real mn[MAX_MOMENT]; // only MAX_MOMENT-2 values needed but overdimensioned to allow compilation for MAX_MOMENT==0 (even though it gets kicked out in the end by the dead code elimination) + for(unsigned int p=3;p<=MAX_MOMENT;++p){ + w1w2_delta_wp*=w1w2_delta_w; + iw2pm1*=iw2; + miw1pm1*=miw1; + Real delta_wk=1.0; + Real s=0.0; + Real mw2k=1.0; + Real w1k=1.0; + for(unsigned int k=1;k<=p-2;++k){ + w1k*=w_; + mw2k*=-acc.w_; + delta_wk*=delta_w; + s+=boost::math::binomial_coefficient<Real>(p, k)*(mw2k*m_[p-k-1]+w1k*acc.m_[p-k-1])*delta_wk; + } + mn[p-3]=acc.m_[p-1]+s+w1w2_delta_wp*(iw2pm1-miw1pm1); + } + for(unsigned int p=3;p<=MAX_MOMENT;++p){ + m_[p-1]+=mn[p-3]; + } + } + if(MAX_MOMENT>1){ + m_[1]+=acc.m_[1]+delta_w*delta*acc.w_*w_; + } + m_[0]+=delta_w*acc.w_; + w_=wn; + } return *this; } Real GetNumSamples() const { - return n_-1.0; + return n_; + } + + Real GetWeight() const + { + return w_; + } + + Real GetMoment(unsigned int i) const + { + if(1>i){ + throw Error("Invalid moment."); + } + if(MAX_MOMENT<i){ + throw Error("Moment was not calculated."); + } + return m_[i-1]; } Real GetMean() const @@ -133,7 +151,7 @@ public: if(MAX_MOMENT<1){ throw Error("Mean was not calculated."); } - return mean_; + return m_[0]; } Real GetSum() const @@ -146,10 +164,10 @@ public: if(MAX_MOMENT<2){ throw Error("Variance was not calculated."); } - if(n_==1.0){ + if(n_==0.0){ return 0.0; } - return m2_/(n_-1); + return m_[1]/(w_); } Real GetStandardDeviation() const @@ -159,10 +177,10 @@ public: Real GetRootMeanSquare() const { - if(n_==1.0){ + if(n_==0.0){ return 0.0; } - return sqrt(sum2_/(n_-1)); + return sqrt(sum2_/(w_)); } Real GetSkewness() const @@ -170,10 +188,10 @@ public: if(MAX_MOMENT<3){ throw Error("Skewness was not calculated."); } - if(m2_<1e-20){ + if(m_[1]<1e-20){ return 0.0; }else{ - return m3_/sqrt(m2_*m2_*m2_); + return m_[2]/sqrt(m_[1]*m_[1]*m_[1]); } } @@ -182,21 +200,19 @@ public: if(MAX_MOMENT<4){ throw Error("Kurtosis was not calculated."); } - if(m2_<1e-20){ + if(m_[1]<1e-20){ return 0.0; }else{ - return ((n_-1)*m4_) / (m2_*m2_); + return w_*m_[3] / (m_[1]*m_[1]); } } private: - Real mean_; Real sum_; Real sum2_; - Real m2_; - Real m3_; - Real m4_; - Real n_; + Real m_[MAX_MOMENT]; + Real w_; + unsigned int n_; }; }}} //ns diff --git a/modules/img/alg/tests/test_stat.cc b/modules/img/alg/tests/test_stat.cc index b736c6bd93d30aebbcc2eb05616e70c4d0d19f7e..7e202193f37dcc42a6929a777dc9e37a26c4fca0 100644 --- a/modules/img/alg/tests/test_stat.cc +++ b/modules/img/alg/tests/test_stat.cc @@ -86,6 +86,51 @@ void test() { BOOST_CHECK_CLOSE(acc_c.GetStandardDeviation(),Real(2.58198889747),Real(0.0001)); BOOST_CHECK_CLOSE(acc_c.GetSkewness()+Real(0.5),Real(0.5),Real(0.0001)); BOOST_CHECK_CLOSE(acc_c.GetKurtosis(),Real(1.77),Real(0.0001)); + + // check accumulator template for restriction to lower order moments + StatAccumulator<3> acc4; + for(int u=0;u<3;++u) { + for(int v=0;v<3;++v) { + acc4(val[u][v]); + } + } + BOOST_CHECK_CLOSE(acc4.GetMean(),Real(5.0),Real(0.0001)); + BOOST_CHECK_CLOSE(acc4.GetStandardDeviation(),Real(2.58198889747),Real(0.0001)); + BOOST_CHECK_CLOSE(acc4.GetSkewness()+Real(0.5),Real(0.5),Real(0.0001)); + BOOST_CHECK_THROW(acc4.GetKurtosis(),ost::Error); + + StatAccumulator<2> acc5; + for(int u=0;u<3;++u) { + for(int v=0;v<3;++v) { + acc5(val[u][v]); + } + } + BOOST_CHECK_CLOSE(acc5.GetMean(),Real(5.0),Real(0.0001)); + BOOST_CHECK_CLOSE(acc5.GetStandardDeviation(),Real(2.58198889747),Real(0.0001)); + BOOST_CHECK_THROW(acc5.GetSkewness(),ost::Error); + BOOST_CHECK_THROW(acc5.GetKurtosis(),ost::Error); + + StatAccumulator<1> acc6; + for(int u=0;u<3;++u) { + for(int v=0;v<3;++v) { + acc6(val[u][v]); + } + } + BOOST_CHECK_CLOSE(acc6.GetMean(),Real(5.0),Real(0.0001)); + BOOST_CHECK_THROW(acc6.GetStandardDeviation(),ost::Error); + BOOST_CHECK_THROW(acc6.GetSkewness(),ost::Error); + BOOST_CHECK_THROW(acc6.GetKurtosis(),ost::Error); + + StatAccumulator<0> acc7; + for(int u=0;u<3;++u) { + for(int v=0;v<3;++v) { + acc7(val[u][v]); + } + } + BOOST_CHECK_THROW(acc7.GetMean(),ost::Error); + BOOST_CHECK_THROW(acc7.GetStandardDeviation(),ost::Error); + BOOST_CHECK_THROW(acc7.GetSkewness(),ost::Error); + BOOST_CHECK_THROW(acc7.GetKurtosis(),ost::Error); } } // namespace