[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_earlystopping.hxx
1#ifndef RF_EARLY_STOPPING_P_HXX
2#define RF_EARLY_STOPPING_P_HXX
3#include <cmath>
4#include <stdexcept>
5#include "rf_common.hxx"
6
7namespace vigra
8{
9
10#if 0
11namespace es_detail
12{
13 template<class T>
14 T power(T const & in, int n)
15 {
16 T result = NumericTraits<T>::one();
17 for(int ii = 0; ii < n ;++ii)
18 result *= in;
19 return result;
20 }
21}
22#endif
23
24/**Base class from which all EarlyStopping Functors derive.
25 */
27{
28protected:
29 ProblemSpec<> ext_param_;
30 int tree_count_ = 0;
31 bool is_weighted_ = false;
32
33public:
34 template<class T>
35 void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false)
36 {
37 ext_param_ = prob;
38 is_weighted_ = is_weighted;
39 tree_count_ = tree_count;
40 }
41
42#ifdef DOXYGEN
43 /** called after the prediction of a tree was added to the total prediction
44 * \param weightIter Iterator to the weights delivered by current tree.
45 * \param k after kth tree
46 * \param prob Total probability array
47 * \param totalCt sum of probability array.
48 */
49 template<class WeightIter, class T, class C>
50 bool after_prediction(WeightIter weightIter, int k, MultiArrayView<2, T, C> const & prob , double totalCt)
51#else
52 template<class WeightIter, class T, class C>
53 bool after_prediction(WeightIter, int /* k */, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */)
54 {return false;}
55#endif //DOXYGEN
56};
57
58
59/**Stop predicting after a set number of trees
60 */
61class StopAfterTree : public StopBase
62{
63public:
64 double max_tree_p;
65 int max_tree_;
66 typedef StopBase SB;
67
69
70 /** Constructor
71 * \param max_tree number of trees to be used for prediction
72 */
73 StopAfterTree(double max_tree)
74 :
75 max_tree_p(max_tree)
76 {}
77
78 template<class T>
79 void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false)
80 {
81 max_tree_ = ceil(max_tree_p * tree_count);
82 SB::set_external_parameters(prob, tree_count, is_weighted);
83 }
84
85 template<class WeightIter, class T, class C>
86 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */)
87 {
88 if(k == SB::tree_count_ -1)
89 {
90 depths.push_back(double(k+1)/double(SB::tree_count_));
91 return false;
92 }
93 if(k < max_tree_)
94 return false;
95 depths.push_back(double(k+1)/double(SB::tree_count_));
96 return true;
97 }
98};
99
100/** Stop predicting after a certain amount of votes exceed certain proportion.
101 * case unweighted voting: stop if the leading class exceeds proportion * SB::tree_count_
102 * case weighted voting: stop if the leading class exceeds proportion * msample_ * SB::tree_count_ ;
103 * (maximal number of votes possible in both cases)
104 */
106{
107public:
108 double proportion_;
109 typedef StopBase SB;
110 ArrayVector<double> depths;
111
112 /** Constructor
113 * \param proportion specify proportion to be used.
114 */
115 StopAfterVoteCount(double proportion)
116 :
117 proportion_(proportion)
118 {}
119
120 template<class WeightIter, class T, class C>
121 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & prob, double /* totalCt */)
122 {
123 if(k == SB::tree_count_ -1)
124 {
125 depths.push_back(double(k+1)/double(SB::tree_count_));
126 return false;
127 }
128
129
130 if(SB::is_weighted_)
131 {
132 if(prob[argMax(prob)] > proportion_ *SB::ext_param_.actual_msample_* SB::tree_count_)
133 {
134 depths.push_back(double(k+1)/double(SB::tree_count_));
135 return true;
136 }
137 }
138 else
139 {
140 if(prob[argMax(prob)] > proportion_ * SB::tree_count_)
141 {
142 depths.push_back(double(k+1)/double(SB::tree_count_));
143 return true;
144 }
145 }
146 return false;
147 }
148
149};
150
151
152/** Stop predicting if the 2norm of the probabilities does not change*/
154
155{
156public:
157 double thresh_;
158 int num_;
161 ArrayVector<double> depths;
162 typedef StopBase SB;
163
164 /** Constructor
165 * \param thresh: If the two norm of the probabilities changes less then thresh then stop
166 * \param num : look at atleast num trees before stopping
167 */
168 StopIfConverging(double thresh, int num = 10)
169 :
170 thresh_(thresh),
171 num_(num)
172 {}
173
174 template<class T>
175 void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false)
176 {
177 last_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0);
178 cur_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0);
179 SB::set_external_parameters(prob, tree_count, is_weighted);
180 }
181 template<class WeightIter, class T, class C>
182 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & prob, double)
183 {
184 if(k == SB::tree_count_ -1)
185 {
186 depths.push_back(double(k+1)/double(SB::tree_count_));
187 return false;
188 }
189 if(k <= num_)
190 {
191 last_ = prob;
192 last_/= last_.norm(1);
193 return false;
194 }
195 else
196 {
197 cur_ = prob;
198 cur_ /= cur_.norm(1);
199 last_ -= cur_;
200 double nrm = last_.norm();
201 if(nrm < thresh_)
202 {
203 depths.push_back(double(k+1)/double(SB::tree_count_));
204 return true;
205 }
206 else
207 {
208 last_ = cur_;
209 }
210 }
211 return false;
212 }
213};
214
215
216/** Stop predicting if the margin prob(leading class) - prob(second class) exceeds a proportion
217 * case unweighted voting: stop if margin exceeds proportion * SB::tree_count_
218 * case weighted voting: stop if margin exceeds proportion * msample_ * SB::tree_count_ ;
219 * (maximal number of votes possible in both cases)
220 */
221class StopIfMargin : public StopBase
222{
223public:
224 double proportion_;
225 typedef StopBase SB;
226 ArrayVector<double> depths;
227
228 /** Constructor
229 * \param proportion specify proportion to be used.
230 */
231 StopIfMargin(double proportion)
232 :
233 proportion_(proportion)
234 {}
235
236 template<class WeightIter, class T, class C>
237 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> prob, double /* totalCt */)
238 {
239 if(k == SB::tree_count_ -1)
240 {
241 depths.push_back(double(k+1)/double(SB::tree_count_));
242 return false;
243 }
244 int index = argMax(prob);
245 double a = prob[argMax(prob)];
246 prob[argMax(prob)] = 0;
247 double b = prob[argMax(prob)];
248 prob[index] = a;
249 double margin = a - b;
250 if(SB::is_weighted_)
251 {
252 if(margin > proportion_ *SB::ext_param_.actual_msample_ * SB::tree_count_)
253 {
254 depths.push_back(double(k+1)/double(SB::tree_count_));
255 return true;
256 }
257 }
258 else
259 {
260 if(prob[argMax(prob)] > proportion_ * SB::tree_count_)
261 {
262 depths.push_back(double(k+1)/double(SB::tree_count_));
263 return true;
264 }
265 }
266 return false;
267 }
268};
269
270
271/**Probabilistic Stopping criterion (binomial test)
272 *
273 * Can only be used in a two class setting
274 *
275 * Stop if the Parameters estimated for the underlying binomial distribution
276 * can be estimated with certainty over 1-alpha.
277 * (Thesis, Rahul Nair Page 80 onwards: called the "binomial" criterion
278 */
280{
281public:
282 double alpha_;
283 MultiArrayView<2, double> n_choose_k;
284 /** Constructor
285 * \param alpha specify alpha (=proportion) value for binomial test.
286 * \param nck_ Matrix with precomputed values for n choose k
287 * nck_(n, k) is n choose k.
288 */
290 :
291 alpha_(alpha),
292 n_choose_k(nck_)
293 {}
294 typedef StopBase SB;
295
296 /**ArrayVector that will contain the fraction of trees that was visited before terminating
297 */
299
300 double binomial(int N, int k, double p)
301 {
302// return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k);
303 return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
304 }
305
306 template<class WeightIter, class T, class C>
307 bool after_prediction(WeightIter, int k,
308 MultiArrayView<2, T, C> const &prob, double)
309 {
310 if(k == SB::tree_count_ -1)
311 {
312 depths.push_back(double(k+1)/double(SB::tree_count_));
313 return false;
314 }
315 if(k < 10)
316 {
317 return false;
318 }
319 int index = argMax(prob);
320 int n_a = prob[index];
321 int n_b = prob[(index+1)%2];
322 int n_tilde = (SB::tree_count_ - n_a + n_b);
323 double p_a = double(n_b - n_a + n_tilde)/double(2* n_tilde);
324 vigra_precondition(p_a <= 1, "probability should be smaller than 1");
325 double cum_val = 0;
326 int c = 0;
327 // std::cerr << "prob: " << p_a << std::endl;
328 if(n_a <= 0)n_a = 0;
329 if(n_b <= 0)n_b = 0;
330 for(int ii = 0; ii <= n_b + n_a;++ii)
331 {
332// std::cerr << "nb +ba " << n_b + n_a << " " << ii <<std::endl;
333 cum_val += binomial(n_b + n_a, ii, p_a);
334 if(cum_val >= 1 -alpha_)
335 {
336 c = ii;
337 break;
338 }
339 }
340// std::cerr << c << " " << n_a << " " << n_b << " " << p_a << alpha_ << std::endl;
341 if(c < n_a)
342 {
343 depths.push_back(double(k+1)/double(SB::tree_count_));
344 return true;
345 }
346
347 return false;
348 }
349};
350
351/**Probabilistic Stopping criteria. (toChange)
352 *
353 * Can only be used in a two class setting
354 *
355 * Stop if the probability that the decision will change after seeing all trees falls under
356 * a specified value alpha.
357 * (Thesis, Rahul Nair Page 80 onwards: called the "toChange" criterion
358 */
359class StopIfProb : public StopBase
360{
361public:
362 double alpha_;
363 MultiArrayView<2, double> n_choose_k;
364
365
366 /** Constructor
367 * \param alpha specify alpha (=proportion) value
368 * \param nck_ Matrix with precomputed values for n choose k
369 * nck_(n, k) is n choose k.
370 */
372 :
373 alpha_(alpha),
374 n_choose_k(nck_)
375 {}
376 typedef StopBase SB;
377 /**ArrayVector that will contain the fraction of trees that was visited before terminating
378 */
380
381 double binomial(int N, int k, double p)
382 {
383// return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k);
384 return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
385 }
386
387 template<class WeightIter, class T, class C>
388 bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> prob, double)
389 {
390 if(k == SB::tree_count_ -1)
391 {
392 depths.push_back(double(k+1)/double(SB::tree_count_));
393 return false;
394 }
395 if(k <= 10)
396 {
397 return false;
398 }
399 int index = argMax(prob);
400 int n_a = prob[index];
401 int n_b = prob[(index+1)%2];
402 int n_needed = ceil(double(SB::tree_count_)/2.0)-n_a;
403 int n_tilde = SB::tree_count_ - (n_a +n_b);
404 if(n_tilde <= 0) n_tilde = 0;
405 if(n_needed <= 0) n_needed = 0;
406 double p = 0;
407 for(int ii = n_needed; ii < n_tilde; ++ii)
408 p += binomial(n_tilde, ii, 0.5);
409
410 if(p >= 1-alpha_)
411 {
412 depths.push_back(double(k+1)/double(SB::tree_count_));
413 return true;
414 }
415
416 return false;
417 }
418};
419
420
421class DepthAndSizeStopping: public StopBase
422{
423public:
424 int max_depth_;
425 int min_size_;
426
427 int max_depth_reached = 0; //for debug maximum reached depth
428
429 DepthAndSizeStopping()
430 : max_depth_(NumericTraits<int>::max()), min_size_(0)
431 {}
432
433 /** Constructor DepthAndSize Criterion
434 * Stop growing the tree if a certain depth or size is reached or make a
435 * leaf if the node is smaller than a certain size. Note this is checked
436 * before the split so it is still possible that smaller leafs are created
437 */
438
439 DepthAndSizeStopping(int depth, int size) :
440 max_depth_(depth <= 0 ? NumericTraits<int>::max() : depth),
441 min_size_(size)
442 {}
443
444 template<class T>
445 void set_external_parameters(ProblemSpec<T> const &,
446 int /*tree_count*/ = 0, bool /* is_weighted_ */= false)
447 {}
448
449 template<class Region>
450 bool operator()(Region& region)
451 {
452 if (region.depth() > max_depth_)
453 throw std::runtime_error("violation in the stopping criterion");
454
455 return (region.depth() >= max_depth_) || (region.size() < min_size_) ;
456 }
457
458 template<class WeightIter, class T, class C>
459 bool after_prediction(WeightIter, int /* k */,
460 MultiArrayView<2, T, C> const &/* prob */, double /* totalCt */)
461 {
462 return true;
463 }
464};
465
466} //namespace vigra;
467#endif //RF_EARLY_STOPPING_P_HXX
Definition array_vector.hxx:514
Base class for, and view to, vigra::MultiArray.
Definition multi_fwd.hxx:127
Main MultiArray class containing the memory management.
Definition multi_fwd.hxx:131
void reshape(const difference_type &shape)
Definition multi_array.hxx:2863
problem specification class for the random forest.
Definition rf_common.hxx:539
Definition rf_earlystopping.hxx:62
StopAfterTree(double max_tree)
Definition rf_earlystopping.hxx:73
Definition rf_earlystopping.hxx:106
StopAfterVoteCount(double proportion)
Definition rf_earlystopping.hxx:115
Definition rf_earlystopping.hxx:27
Definition rf_earlystopping.hxx:280
StopIfBinTest(double alpha, MultiArrayView< 2, double > nck_)
Definition rf_earlystopping.hxx:289
ArrayVector< double > depths
Definition rf_earlystopping.hxx:298
Definition rf_earlystopping.hxx:155
StopIfConverging(double thresh, int num=10)
Definition rf_earlystopping.hxx:168
Definition rf_earlystopping.hxx:222
StopIfMargin(double proportion)
Definition rf_earlystopping.hxx:231
Definition rf_earlystopping.hxx:360
StopIfProb(double alpha, MultiArrayView< 2, double > nck_)
Definition rf_earlystopping.hxx:371
ArrayVector< double > depths
Definition rf_earlystopping.hxx:379
Class for fixed size vectors.
Definition tinyvector.hxx:1008
V power(const V &x)
Exponentiation to a positive integer power by squaring.
Definition mathutil.hxx:427
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition algorithm.hxx:96
int ceil(FixedPoint< IntBits, FracBits > v)
rounding up.
Definition fixedpoint.hxx:675

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.12.2