1#ifndef RF_EARLY_STOPPING_P_HXX
2#define RF_EARLY_STOPPING_P_HXX
5#include "rf_common.hxx"
14 T
power(T
const & in,
int n)
16 T result = NumericTraits<T>::one();
17 for(
int ii = 0; ii < n ;++ii)
31 bool is_weighted_ =
false;
35 void set_external_parameters(
ProblemSpec<T> const &prob,
int tree_count = 0,
bool is_weighted =
false)
38 is_weighted_ = is_weighted;
39 tree_count_ = tree_count;
49 template<
class WeightIter,
class T,
class C>
52 template<
class WeightIter,
class T,
class C>
79 void set_external_parameters(
ProblemSpec<T> const &prob,
int tree_count = 0,
bool is_weighted =
false)
81 max_tree_ =
ceil(max_tree_p * tree_count);
82 SB::set_external_parameters(prob, tree_count, is_weighted);
85 template<
class WeightIter,
class T,
class C>
86 bool after_prediction(WeightIter,
int k, MultiArrayView<2, T, C>
const & ,
double )
88 if(k == SB::tree_count_ -1)
90 depths.push_back(
double(k+1)/
double(SB::tree_count_));
95 depths.push_back(
double(k+1)/
double(SB::tree_count_));
117 proportion_(proportion)
120 template<
class WeightIter,
class T,
class C>
123 if(k == SB::tree_count_ -1)
125 depths.push_back(
double(k+1)/
double(SB::tree_count_));
132 if(prob[
argMax(prob)] > proportion_ *SB::ext_param_.actual_msample_* SB::tree_count_)
134 depths.push_back(
double(k+1)/
double(SB::tree_count_));
140 if(prob[
argMax(prob)] > proportion_ * SB::tree_count_)
142 depths.push_back(
double(k+1)/
double(SB::tree_count_));
175 void set_external_parameters(
ProblemSpec<T> const &prob,
int tree_count = 0,
bool is_weighted =
false)
179 SB::set_external_parameters(prob, tree_count, is_weighted);
181 template<
class WeightIter,
class T,
class C>
182 bool after_prediction(WeightIter,
int k, MultiArrayView<2, T, C>
const & prob,
double)
184 if(k == SB::tree_count_ -1)
186 depths.push_back(
double(k+1)/
double(SB::tree_count_));
192 last_/= last_.norm(1);
198 cur_ /= cur_.norm(1);
200 double nrm = last_.norm();
203 depths.push_back(
double(k+1)/
double(SB::tree_count_));
233 proportion_(proportion)
236 template<
class WeightIter,
class T,
class C>
239 if(k == SB::tree_count_ -1)
241 depths.push_back(
double(k+1)/
double(SB::tree_count_));
245 double a = prob[
argMax(prob)];
247 double b = prob[
argMax(prob)];
249 double margin = a - b;
252 if(margin > proportion_ *SB::ext_param_.actual_msample_ * SB::tree_count_)
254 depths.push_back(
double(k+1)/
double(SB::tree_count_));
260 if(prob[
argMax(prob)] > proportion_ * SB::tree_count_)
262 depths.push_back(
double(k+1)/
double(SB::tree_count_));
300 double binomial(
int N,
int k,
double p)
303 return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
306 template<
class WeightIter,
class T,
class C>
307 bool after_prediction(WeightIter,
int k,
310 if(k == SB::tree_count_ -1)
312 depths.push_back(
double(k+1)/
double(SB::tree_count_));
320 int n_a = prob[index];
321 int n_b = prob[(index+1)%2];
322 int n_tilde = (SB::tree_count_ - n_a + n_b);
323 double p_a = double(n_b - n_a + n_tilde)/double(2* n_tilde);
324 vigra_precondition(p_a <= 1,
"probability should be smaller than 1");
330 for(
int ii = 0; ii <= n_b + n_a;++ii)
333 cum_val += binomial(n_b + n_a, ii, p_a);
334 if(cum_val >= 1 -alpha_)
343 depths.push_back(
double(k+1)/
double(SB::tree_count_));
381 double binomial(
int N,
int k,
double p)
384 return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
387 template<
class WeightIter,
class T,
class C>
390 if(k == SB::tree_count_ -1)
392 depths.push_back(
double(k+1)/
double(SB::tree_count_));
400 int n_a = prob[index];
401 int n_b = prob[(index+1)%2];
402 int n_needed =
ceil(
double(SB::tree_count_)/2.0)-n_a;
403 int n_tilde = SB::tree_count_ - (n_a +n_b);
404 if(n_tilde <= 0) n_tilde = 0;
405 if(n_needed <= 0) n_needed = 0;
407 for(
int ii = n_needed; ii < n_tilde; ++ii)
408 p += binomial(n_tilde, ii, 0.5);
412 depths.push_back(
double(k+1)/
double(SB::tree_count_));
421class DepthAndSizeStopping:
public StopBase
427 int max_depth_reached = 0;
429 DepthAndSizeStopping()
430 : max_depth_(NumericTraits<int>::max()), min_size_(0)
439 DepthAndSizeStopping(
int depth,
int size) :
440 max_depth_(depth <= 0 ? NumericTraits<int>::max() : depth),
445 void set_external_parameters(ProblemSpec<T>
const &,
446 int = 0,
bool =
false)
449 template<
class Region>
450 bool operator()(Region& region)
452 if (region.depth() > max_depth_)
453 throw std::runtime_error(
"violation in the stopping criterion");
455 return (region.depth() >= max_depth_) || (region.size() < min_size_) ;
458 template<
class WeightIter,
class T,
class C>
459 bool after_prediction(WeightIter,
int ,
460 MultiArrayView<2, T, C>
const &,
double )
Definition array_vector.hxx:514
Base class for, and view to, vigra::MultiArray.
Definition multi_fwd.hxx:127
Main MultiArray class containing the memory management.
Definition multi_fwd.hxx:131
void reshape(const difference_type &shape)
Definition multi_array.hxx:2863
problem specification class for the random forest.
Definition rf_common.hxx:539
Definition rf_earlystopping.hxx:62
StopAfterTree(double max_tree)
Definition rf_earlystopping.hxx:73
Definition rf_earlystopping.hxx:106
StopAfterVoteCount(double proportion)
Definition rf_earlystopping.hxx:115
Definition rf_earlystopping.hxx:27
Definition rf_earlystopping.hxx:280
StopIfBinTest(double alpha, MultiArrayView< 2, double > nck_)
Definition rf_earlystopping.hxx:289
ArrayVector< double > depths
Definition rf_earlystopping.hxx:298
Definition rf_earlystopping.hxx:155
StopIfConverging(double thresh, int num=10)
Definition rf_earlystopping.hxx:168
Definition rf_earlystopping.hxx:222
StopIfMargin(double proportion)
Definition rf_earlystopping.hxx:231
Definition rf_earlystopping.hxx:360
StopIfProb(double alpha, MultiArrayView< 2, double > nck_)
Definition rf_earlystopping.hxx:371
ArrayVector< double > depths
Definition rf_earlystopping.hxx:379
Class for fixed size vectors.
Definition tinyvector.hxx:1008
V power(const V &x)
Exponentiation to a positive integer power by squaring.
Definition mathutil.hxx:427
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition algorithm.hxx:96
int ceil(FixedPoint< IntBits, FracBits > v)
rounding up.
Definition fixedpoint.hxx:675