[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_ridge_split.hxx
1//
2// C++ Interface: rf_ridge_split
3//
4// Description:
5//
6//
7// Author: Nico Splitthoff <splitthoff@zg00103>, (C) 2009
8//
9// Copyright: See COPYING file that comes with this distribution
10//
11//
12#ifndef VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H
13#define VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H
14//#include "rf_sampling.hxx"
15#include "../sampling.hxx"
16#include "rf_split.hxx"
17#include "rf_nodeproxy.hxx"
18#include "../regression.hxx"
19
// Debug-print helpers: write the spelling of the argument expression, a colon
// and its value to std::cout. outm ends the line (and flushes, via std::endl);
// outm2 appends ", " so that several values can share one line.
// Both expansions already end in a semicolon, so each use is a complete
// statement. std::cout must be declared wherever the macros are expanded.
#define outm(v) std::cout << (#v) << ": " << (v) << std::endl;
#define outm2(v) std::cout << (#v) << ": " << (v) << ", ";
22
23namespace vigra
24{
25
26/*template<>
27class Node<i_RegrNode>
28: public NodeBase
29{
30public:
31 typedef NodeBase BT;
32
33
34 Node( BT::T_Container_type & topology,
35 BT::P_Container_type & param,
36 int nNumCols)
37 : BT(5+nNumCols,2+nNumCols,topology, param)
38 {
39 BT::typeID() = i_RegrNode;
40 }
41
42 Node( BT::T_Container_type & topology,
43 BT::P_Container_type & param,
44 INT n )
45 : BT(5,2,topology, param, n)
46 {}
47
48 Node( BT & node_)
49 : BT(5, 2, node_)
50 {}
51
52 double& threshold()
53 {
54 return BT::parameters_begin()[1];
55 }
56
57 BT::INT& column()
58 {
59 return BT::column_data()[0];
60 }
61
62 template<class U, class C>
63 BT::INT& next(MultiArrayView<2,U,C> const & feature)
64 {
65 return (feature(0, column()) < threshold())? child(0):child(1);
66 }
67};*/
68
69
70template<class ColumnDecisionFunctor, class Tag = ClassificationTag>
71class RidgeSplit: public SplitBase<Tag>
72{
73 public:
74
75
76 typedef SplitBase<Tag> SB;
77
78 ArrayVector<Int32> splitColumns;
79 ColumnDecisionFunctor bgfunc;
80
81 double region_gini_;
82 ArrayVector<double> min_gini_;
83 ArrayVector<std::ptrdiff_t> min_indices_;
84 ArrayVector<double> min_thresholds_;
85
86 int bestSplitIndex;
87
88 //dns
89 bool m_bDoScalingInTraining;
90 bool m_bDoBestLambdaBasedOnGini;
91
92 RidgeSplit()
93 :m_bDoScalingInTraining(true),
94 m_bDoBestLambdaBasedOnGini(true)
95 {
96 }
97
98 double minGini() const
99 {
100 return min_gini_[bestSplitIndex];
101 }
102
103 int bestSplitColumn() const
104 {
105 return splitColumns[bestSplitIndex];
106 }
107
108 bool& doScalingInTraining()
109 { return m_bDoScalingInTraining; }
110
111 bool& doBestLambdaBasedOnGini()
112 { return m_bDoBestLambdaBasedOnGini; }
113
114 template<class T>
115 void set_external_parameters(ProblemSpec<T> const & in)
116 {
118 bgfunc.set_external_parameters(in);
119 int featureCount_ = in.column_count_;
120 splitColumns.resize(featureCount_);
121 for(int k=0; k<featureCount_; ++k)
122 splitColumns[k] = k;
123 min_gini_.resize(featureCount_);
124 min_indices_.resize(featureCount_);
125 min_thresholds_.resize(featureCount_);
126 }
127
128
129 template<class T, class C, class T2, class C2, class Region, class Random>
130 int findBestSplit(MultiArrayView<2, T, C> features,
131 MultiArrayView<2, T2, C2> multiClassLabels,
132 Region & region,
133 ArrayVector<Region>& childRegions,
134 Random & randint)
135 {
136
137 //std::cerr << "Split called" << std::endl;
138 typedef typename MultiArrayView <2, T, C>::difference_type fShape;
139 typedef typename MultiArrayView <2, T2, C2>::difference_type lShape;
140 typedef typename MultiArrayView <2, double>::difference_type dShape;
141
142 // calculate things that haven't been calculated yet.
143// std::cout << "start" << std::endl;
144 if(std::accumulate(region.classCounts().begin(),
145 region.classCounts().end(), 0) != region.size())
146 {
147 RandomForestClassCounter< MultiArrayView<2,T2, C2>,
148 ArrayVector<double> >
149 counter(multiClassLabels, region.classCounts());
150 std::for_each( region.begin(), region.end(), counter);
151 region.classCountsIsValid = true;
152 }
153
154
155 // Is the region pure already?
156 region_gini_ = GiniCriterion::impurity(region.classCounts(),
157 region.size());
158 if(region_gini_ == 0 || region.size() < SB::ext_param_.actual_mtry_ || region.oob_size() < 2)
159 return SB::makeTerminalNode(features, multiClassLabels, region, randint);
160
161 // select columns to be tried.
162 for(int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
163 std::swap(splitColumns[ii],
164 splitColumns[ii+ randint(features.shape(1) - ii)]);
165
166 //do implicit binary case
167 MultiArray<2, T2> labels(lShape(multiClassLabels.shape(0),1));
168 //number of classes should be >1, otherwise makeTerminalNode would have been called
169 int nNumClasses=0;
170 for(int n=0; n<static_cast<int>(region.classCounts().size()); n++)
171 nNumClasses+=((region.classCounts()[n]>0) ? 1:0);
172
173 //convert to binary case
174 if(nNumClasses>2)
175 {
176 int nMaxClass=0;
177 int nMaxClassCounts=0;
178 for(int n=0; n<static_cast<int>(region.classCounts().size()); n++)
179 {
180 //this should occur in any case:
181 //we had more than two non-zero classes in order to get here
182 if(region.classCounts()[n]>nMaxClassCounts)
183 {
184 nMaxClassCounts=region.classCounts()[n];
185 nMaxClass=n;
186 }
187 }
188
189 //create binary labels
190 for(int n=0; n<multiClassLabels.shape(0); n++)
191 labels(n,0)=((multiClassLabels(n,0)==nMaxClass) ? 1:0);
192 }
193 else
194 labels=multiClassLabels;
195
196 //_do implicit binary case
197
198 //uncomment this for some debugging
199/* int nNumCases=features.shape(0);
200
201 typedef typename MultiArrayView <2, int>::difference_type nShape;
202 MultiArray<2, int> elementCounterArray(nShape(nNumCases,1),(int)0);
203 int nUniqueElements=0;
204 for(int n=0; n<region.size(); n++)
205 elementCounterArray[region[n]]++;
206
207 for(int n=0; n<nNumCases; n++)
208 nUniqueElements+=((elementCounterArray[n]>0) ? 1:0);
209
210 outm(nUniqueElements);
211 nUniqueElements=0;
212 MultiArray<2, int> elementCounterArray_oob(nShape(nNumCases,1),(int)0);
213 for(int n=0; n<region.oob_size(); n++)
214 elementCounterArray_oob[region.oob_begin()[n]]++;
215 for(int n=0; n<nNumCases; n++)
216 nUniqueElements+=((elementCounterArray_oob[n]>0) ? 1:0);
217 outm(nUniqueElements);
218
219 int notUniqueElements=0;
220 for(int n=0; n<nNumCases; n++)
221 notUniqueElements+=(((elementCounterArray_oob[n]>0) && (elementCounterArray[n]>0)) ? 1:0);
222 outm(notUniqueElements);*/
223
224 //outm(SB::ext_param_.actual_mtry_);
225
226
227//select submatrix of features for regression calculation
228 MultiArrayView<2, T, C> cVector;
229 MultiArray<2, T> xtrain(fShape(region.size(),SB::ext_param_.actual_mtry_));
230 //we only want -1 and 1 for this
231 MultiArray<2, double> regrLabels(dShape(region.size(),1));
232
233 //copy data into a vigra data structure and centre and scale while doing so
234 MultiArray<2, double> meanMatrix(dShape(SB::ext_param_.actual_mtry_,1));
235 MultiArray<2, double> stdMatrix(dShape(SB::ext_param_.actual_mtry_,1));
236 for(int m=0; m<SB::ext_param_.actual_mtry_; m++)
237 {
238 cVector=columnVector(features, splitColumns[m]);
239
240 //centre and scale the data
241 double dCurrFeatureColumnMean=0.0;
242 double dCurrFeatureColumnStd=1.0; //default value
243
244 //calc mean on bootstrap data
245 for(int n=0; n<region.size(); n++)
246 dCurrFeatureColumnMean+=cVector[region[n]];
247 dCurrFeatureColumnMean/=region.size();
248 //calc scaling
249 if(m_bDoScalingInTraining)
250 {
251 for(int n=0; n<region.size(); n++)
252 {
253 dCurrFeatureColumnStd+=
254 (cVector[region[n]]-dCurrFeatureColumnMean)*(cVector[region[n]]-dCurrFeatureColumnMean);
255 }
256 //unbiased std estimator:
257 dCurrFeatureColumnStd=sqrt(dCurrFeatureColumnStd/(region.size()-1));
258 }
259 //dCurrFeatureColumnStd is still 1.0 if we didn't want scaling
260 stdMatrix(m,0)=dCurrFeatureColumnStd;
261
262 meanMatrix(m,0)=dCurrFeatureColumnMean;
263
264 //get feature matrix, i.e. A (note that weighting is done automatically
265 //since rows can occur multiple times -> bagging)
266 for(int n=0; n<region.size(); n++)
267 xtrain(n,m)=(cVector[region[n]]-dCurrFeatureColumnMean)/dCurrFeatureColumnStd;
268 }
269
270// std::cout << "middle" << std::endl;
271 //get label vector (i.e. b)
272 for(int n=0; n<region.size(); n++)
273 {
274 //we checked for/built binary case further up.
275 //class labels should thus be either 0 or 1
276 //-> convert to -1 and 1 for regression
277 regrLabels(n,0)=((labels[region[n]]==0) ? -1:1);
278 }
279
280 MultiArray<2, double> dLambdas(dShape(11,1));
281 int nCounter=0;
282 for(int nLambda=-5; nLambda<=5; nLambda++)
283 dLambdas[nCounter++]=pow(10.0,nLambda);
284 //destination vector for regression coefficients; use same type as for xtrain
285 MultiArray<2, double> regrCoef(dShape(SB::ext_param_.actual_mtry_,11));
286 ridgeRegressionSeries(xtrain,regrLabels,regrCoef,dLambdas);
287
288 double dMaxRidgeSum=NumericTraits<double>::min();
289 double dCurrRidgeSum;
290 int nMaxRidgeSumAtLambdaInd=0;
291
292 for(int nLambdaInd=0; nLambdaInd<11; nLambdaInd++)
293 {
294 //just sum up the correct answers
295 //(correct means >=intercept for class 1, <intercept for class 0)
296 //(intercept=0 or intercept=threshold based on gini)
297 dCurrRidgeSum=0.0;
298
299 //assemble projection vector
300 MultiArray<2, double> dDistanceFromHyperplane(dShape(features.shape(0),1));
301
302 for(int n=0; n<region.oob_size(); n++)
303 {
304 dDistanceFromHyperplane(region.oob_begin()[n],0)=0.0;
305 for (int m=0; m<SB::ext_param_.actual_mtry_; m++)
306 {
307 dDistanceFromHyperplane(region.oob_begin()[n],0)+=
308 features(region.oob_begin()[n],splitColumns[m])*regrCoef(m,nLambdaInd);
309 }
310 }
311
312 double dCurrIntercept=0.0;
313 if(m_bDoBestLambdaBasedOnGini)
314 {
315 //calculate gini index
316 bgfunc(dDistanceFromHyperplane,
317 labels,
318 region.oob_begin(), region.oob_end(),
319 region.classCounts());
320 dCurrIntercept=bgfunc.min_threshold_;
321 }
322 else
323 {
324 for (int m=0; m<SB::ext_param_.actual_mtry_; m++)
325 dCurrIntercept+=meanMatrix(m,0)*regrCoef(m,nLambdaInd);
326 }
327
328 for(int n=0; n<region.oob_size(); n++)
329 {
330 //check what lambda performs best on oob data
331 int nClassPrediction=((dDistanceFromHyperplane(region.oob_begin()[n],0) >=dCurrIntercept) ? 1:0);
332 dCurrRidgeSum+=((nClassPrediction == labels(region.oob_begin()[n],0)) ? 1:0);
333 }
334 if(dCurrRidgeSum>dMaxRidgeSum)
335 {
336 dMaxRidgeSum=dCurrRidgeSum;
337 nMaxRidgeSumAtLambdaInd=nLambdaInd;
338 }
339 }
340
341// std::cout << "middle2" << std::endl;
342 //create a Node for output
343 Node<i_HyperplaneNode> node(SB::ext_param_.actual_mtry_, SB::t_data, SB::p_data);
344
345 //normalise coeffs
346 //data was scaled (by 1.0 or by std) -> take into account
347 MultiArray<2, double> dCoeffVector(dShape(SB::ext_param_.actual_mtry_,1));
348 for(int n=0; n<SB::ext_param_.actual_mtry_; n++)
349 dCoeffVector(n,0)=regrCoef(n,nMaxRidgeSumAtLambdaInd)*stdMatrix(n,0);
350
351 //calc norm
352 double dVnorm=columnVector(regrCoef,nMaxRidgeSumAtLambdaInd).norm();
353
354 for(int n=0; n<SB::ext_param_.actual_mtry_; n++)
355 node.weights()[n]=dCoeffVector(n,0)/dVnorm;
356 //_normalise coeffs
357
358 //save the columns
359 node.column_data()[0]=SB::ext_param_.actual_mtry_;
360 for(int n=0; n<SB::ext_param_.actual_mtry_; n++)
361 node.column_data()[n+1]=splitColumns[n];
362
363 //assemble projection vector
364 //careful here: "region" is a pointer to indices...
365 //all the indices in "region" need to have valid data
366 //convert from "region" space to original "feature" space
367 MultiArray<2, double> dDistanceFromHyperplane(dShape(features.shape(0),1));
368
369 for(int n=0; n<region.size(); n++)
370 {
371 dDistanceFromHyperplane(region[n],0)=0.0;
372 for (int m=0; m<SB::ext_param_.actual_mtry_; m++)
373 {
374 dDistanceFromHyperplane(region[n],0)+=
375 features(region[n],m)*node.weights()[m];
376 }
377 }
378 for(int n=0; n<region.oob_size(); n++)
379 {
380 dDistanceFromHyperplane(region.oob_begin()[n],0)=0.0;
381 for (int m=0; m<SB::ext_param_.actual_mtry_; m++)
382 {
383 dDistanceFromHyperplane(region.oob_begin()[n],0)+=
384 features(region.oob_begin()[n],m)*node.weights()[m];
385 }
386 }
387
388 //calculate gini index
389 bgfunc(dDistanceFromHyperplane,
390 labels,
391 region.begin(), region.end(),
392 region.classCounts());
393
394 // did not find any suitable split
395 if(closeAtTolerance(bgfunc.min_gini_, NumericTraits<double>::max()))
396 return SB::makeTerminalNode(features, multiClassLabels, region, randint);
397
398 //take gini threshold here due to scaling, normalisation, etc. of the coefficients
399 node.intercept() = bgfunc.min_threshold_;
400 SB::node_ = node;
401
402 childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
403 childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
404 childRegions[0].classCountsIsValid = true;
405 childRegions[1].classCountsIsValid = true;
406
407 // Save the ranges of the child stack entries.
408 childRegions[0].setRange( region.begin() , region.begin() + bgfunc.min_index_ );
409 childRegions[0].rule = region.rule;
410 childRegions[0].rule.push_back(std::make_pair(1, 1.0));
411 childRegions[1].setRange( region.begin() + bgfunc.min_index_ , region.end() );
412 childRegions[1].rule = region.rule;
413 childRegions[1].rule.push_back(std::make_pair(1, 1.0));
414
415 //adjust oob ranges
416// std::cout << "adjust oob" << std::endl;
417 //sort the oobs
418 std::sort(region.oob_begin(), region.oob_end(),
419 SortSamplesByDimensions< MultiArray<2, double> > (dDistanceFromHyperplane, 0));
420
421 //find split index
422 int nOOBindx;
423 for(nOOBindx=0; nOOBindx<region.oob_size(); nOOBindx++)
424 {
425 if(dDistanceFromHyperplane(region.oob_begin()[nOOBindx],0)>=node.intercept())
426 break;
427 }
428
429 childRegions[0].set_oob_range( region.oob_begin() , region.oob_begin() + nOOBindx );
430 childRegions[1].set_oob_range( region.oob_begin() + nOOBindx , region.oob_end() );
431
432// std::cout << "end" << std::endl;
433// outm2(region.oob_begin());outm2(nOOBindx);outm(region.oob_begin() + nOOBindx);
434 //_adjust oob ranges
435
436 return i_HyperplaneNode;
437 }
438};
439
440/** Standard ridge regression split
441 */
442typedef RidgeSplit<BestGiniOfColumn<GiniCriterion> > GiniRidgeSplit;
443
444
445} //namespace vigra
446#endif // VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H
static double impurity(Array const &hist, double total)
Definition rf_split.hxx:443
MultiArrayShape< actual_dimension >::type difference_type
Definition multi_array.hxx:739
int makeTerminalNode(MultiArrayView< 2, T, C >, MultiArrayView< 2, T2, C2 >, Region &region, Random)
Definition rf_split.hxx:168
void set_external_parameters(ProblemSpec< T > const &in)
Definition rf_split.hxx:112
RidgeSplit< BestGiniOfColumn< GiniCriterion > > GiniRidgeSplit
Definition rf_ridge_split.hxx:442
SquareRootTraits< FixedPoint< IntBits, FracBits > >::SquareRootResult sqrt(FixedPoint< IntBits, FracBits > v)
square root.
Definition fixedpoint.hxx:616
bool closeAtTolerance(T1 l, T2 r, typename PromoteTraits< T1, T2 >::Promote epsilon)
Tolerance based floating-point equality.
Definition mathutil.hxx:1638

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.12.2