35#ifndef RF_VISITORS_HXX
36#define RF_VISITORS_HXX
39# include "vigra/hdf5impex.hxx"
41#include <vigra/windows.h>
46#include <vigra/metaprogramming.hxx>
47#include <vigra/multi_pointoperators.hxx>
142 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
148 Feature_t & features,
151 ignore_argument(tree,split,parent,leftChild,rightChild,features,labels);
163 template<
class RF,
class PR,
class SM,
class ST>
166 ignore_argument(rf,pr,sm,st,index);
175 template<
class RF,
class PR>
178 ignore_argument(rf,pr);
187 template<
class RF,
class PR>
190 ignore_argument(rf,pr);
205 template<
class TR,
class IntT,
class TopT,
class Feat>
208 ignore_argument(tr,index,node_t,features);
215 template<
class TR,
class IntT,
class TopT,
class Feat>
254template <
class Visitor,
class Next = StopVisiting>
264 next_(next), visitor_(visitor)
269 next_(stop_), visitor_(visitor)
272 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
273 void visit_after_split( Tree & tree,
278 Feature_t & features,
281 if(visitor_.is_active())
282 visitor_.visit_after_split(tree, split,
283 parent, leftChild, rightChild,
285 next_.visit_after_split(tree, split, parent, leftChild, rightChild,
// Chain dispatch of the "after tree" event.
// This node's own visitor_ is notified only while visitor_.is_active() is true.
// As far as the visible lines show, the event is then passed on to next_
// regardless of whether visitor_ was active (the enclosing braces are not
// part of this view).
// Arguments rf, pr, sm, st and index are handed through unchanged.
289 template<
class RF,
class PR,
class SM,
class ST>
290 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st,
int index)
292 if(visitor_.is_active())
293 visitor_.visit_after_tree(rf, pr, sm, st, index);
294 next_.visit_after_tree(rf, pr, sm, st, index);
// Chain dispatch of the "at beginning" event: notify visitor_ when it
// reports is_active(), then hand the same (rf, pr) pair to next_.
297 template<
class RF,
class PR>
298 void visit_at_beginning(RF & rf, PR & pr)
300 if(visitor_.is_active())
301 visitor_.visit_at_beginning(rf, pr);
302 next_.visit_at_beginning(rf, pr);
// Chain dispatch of the "at end" event: notify visitor_ when it reports
// is_active(), then hand the same (rf, pr) pair to next_.
304 template<
class RF,
class PR>
305 void visit_at_end(RF & rf, PR & pr)
307 if(visitor_.is_active())
308 visitor_.visit_at_end(rf, pr);
309 next_.visit_at_end(rf, pr);
// Chain dispatch of the "external node" event: visitor_ is called only when
// it reports is_active(); next_ then receives the identical arguments
// (tr, index, node_t, features).
312 template<
class TR,
class IntT,
class TopT,
class Feat>
313 void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
315 if(visitor_.is_active())
316 visitor_.visit_external_node(tr, index, node_t,features);
317 next_.visit_external_node(tr, index, node_t,features);
// Chain dispatch of the "internal node" event: visitor_ is called only when
// it reports is_active(); next_ then receives the identical arguments
// (tr, index, node_t, features).
319 template<
class TR,
class IntT,
class TopT,
class Feat>
320 void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
322 if(visitor_.is_active())
323 visitor_.visit_internal_node(tr, index, node_t,features);
324 next_.visit_internal_node(tr, index, node_t,features);
// Result selection along the chain: use this node's visitor_ when it is both
// active and reports has_value(); otherwise defer to next_.return_val().
// (The enclosing function signature is not part of this view.)
329 if(visitor_.is_active() && visitor_.has_value())
330 return visitor_.return_val();
331 return next_.return_val();
355template<
class A,
class B>
356detail::VisitorNode<A, detail::VisitorNode<B> >
369template<
class A,
class B,
class C>
370detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
385template<
class A,
class B,
class C,
class D>
386detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
387 detail::VisitorNode<D> > > >
404template<
class A,
class B,
class C,
class D,
class E>
405detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
406 detail::VisitorNode<D, detail::VisitorNode<E> > > > >
426template<
class A,
class B,
class C,
class D,
class E,
428detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
429 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
451template<
class A,
class B,
class C,
class D,
class E,
453detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
454 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
455 detail::VisitorNode<G> > > > > > >
457 D & d, E & e, F & f, G & g)
479template<
class A,
class B,
class C,
class D,
class E,
480 class F,
class G,
class H>
481detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
482 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
483 detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
510template<
class A,
class B,
class C,
class D,
class E,
511 class F,
class G,
class H,
class I>
512detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
513 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
514 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
542template<
class A,
class B,
class C,
class D,
class E,
543 class F,
class G,
class H,
class I,
class J>
544detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
545 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
546 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
547 detail::VisitorNode<J> > > > > > > > > >
588 bool adjust_thresholds;
598 adjust_thresholds(
false), tree_id(0), last_node_id(0), current_label(0)
600 struct MarginalDistribution
603 Int32 leftTotalCounts;
605 Int32 rightTotalCounts;
612 struct TreeOnlineInformation
614 std::vector<MarginalDistribution> mag_distributions;
615 std::vector<IndexList> index_lists;
617 std::map<int,int> interior_to_index;
619 std::map<int,int> exterior_to_index;
623 std::vector<TreeOnlineInformation> trees_online_information;
627 template<
class RF,
class PR>
631 trees_online_information.resize(rf.options_.tree_count_);
638 trees_online_information[tree_id].mag_distributions.clear();
639 trees_online_information[tree_id].index_lists.clear();
640 trees_online_information[tree_id].interior_to_index.clear();
641 trees_online_information[tree_id].exterior_to_index.clear();
646 template<
class RF,
class PR,
class SM,
class ST>
652 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
653 void visit_after_split( Tree & tree,
658 Feature_t & features,
662 int addr=tree.topology_.size();
663 if(split.createNode().typeID() == i_ThresholdNode)
665 if(adjust_thresholds)
668 linear_index=trees_online_information[tree_id].mag_distributions.size();
669 trees_online_information[tree_id].interior_to_index[addr]=linear_index;
670 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
672 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
673 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
675 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
676 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
678 double gap_left,gap_right;
680 gap_left=features(leftChild[0],split.bestSplitColumn());
681 for(i=1;i<leftChild.size();++i)
682 if(features(leftChild[i],split.bestSplitColumn())>gap_left)
683 gap_left=features(leftChild[i],split.bestSplitColumn());
684 gap_right=features(rightChild[0],split.bestSplitColumn());
685 for(i=1;i<rightChild.size();++i)
686 if(features(rightChild[i],split.bestSplitColumn())<gap_right)
687 gap_right=features(rightChild[i],split.bestSplitColumn());
688 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
689 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
695 linear_index=trees_online_information[tree_id].index_lists.size();
696 trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
698 trees_online_information[tree_id].index_lists.push_back(IndexList());
700 trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
701 std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
// Append sample `index` to the index list that belongs to exterior node
// `node` of tree `tree`.
//   tree  - position in trees_online_information
//   node  - key looked up in that tree's exterior_to_index map
//   index - value pushed onto the selected index list
// NOTE(review): exterior_to_index is a std::map<int,int>, so operator[]
// default-inserts 0 for an unknown `node`; an unregistered node would then
// silently append to index_lists[0]. `tree` is not range-checked here.
704 void add_to_index_list(
int tree,
int node,
int index)
708 TreeOnlineInformation &ti=trees_online_information[tree];
709 ti.index_lists[ti.exterior_to_index[node]].push_back(index);
// Re-key an exterior-node entry: the linear index stored under
// (src_tree, src_index) is copied to (dst_tree, dst_index), then the source
// key is erased from the source tree's exterior_to_index map.
// Only the map entry moves; the index list it refers to is not touched here.
// NOTE(review): when source and destination are the same tree and key, the
// erase below removes the entry that was just written.
711 void move_exterior_node(
int src_tree,
int src_index,
int dst_tree,
int dst_index)
715 trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
716 trees_online_information[src_tree].exterior_to_index.erase(src_index);
// Online update at an internal node (several original lines, including the
// signature and all braces, are missing from this view).
723 template<
class TR,
class IntT,
class TopT,
class Feat>
// Nothing below is visible outside this adjust_thresholds test.
727 if(adjust_thresholds)
729 vigra_assert(node_t==i_ThresholdNode,
"We can only visit threshold nodes");
// Feature value of row 0 in the column this threshold node refers to.
731 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
// Statistics recorded for this node of the tree selected by tree_id.
732 TreeOnlineInformation &ti=trees_online_information[tree_id];
733 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
// The value lies strictly inside the recorded (gap_left, gap_right) interval.
734 if(value>m.gap_left && value<m.gap_right)
// Compare the relative frequency of current_label on the left and right side.
// NOTE(review): no guard against leftTotalCounts / rightTotalCounts being 0
// is visible here.
737 if(m.leftCounts[current_label]/
double(m.leftTotalCounts)>m.rightCounts[current_label]/
double(m.rightTotalCounts))
// Threshold is set to the midpoint of the gap (the statements between the
// comparison above and this line are not visible).
747 Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
// Book-keeping: count the sample on the side selected by the threshold.
750 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
752 ++m.rightTotalCounts;
753 ++m.rightCounts[current_label];
// NOTE(review): the lines between 753 and 758 are not visible. If this
// statement is the other arm of the threshold test above, it increments
// rightCounts a second time where leftCounts would be expected - confirm
// against the full source.
758 ++m.rightCounts[current_label];
806 template<
class RF,
class PR,
class SM,
class ST>
810 if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
812 oobCount.resize(rf.ext_param_.row_count_, 0);
813 oobErrorCount.resize(rf.ext_param_.row_count_, 0);
816 for(
int l = 0; l < rf.ext_param_.row_count_; ++l)
823 .predictLabel(rowVector(pr.features(), l))
824 != pr.response()(l,0))
// Final accumulation of the out-of-bag error: for each row l the ratio
// oobErrorCount[l] / oobCount[l] is added to oobError.
// NOTE(review): the lines between the loop header and the accumulation are
// not visible; confirm that rows with oobCount[l] == 0 are skipped, otherwise
// the division below yields inf/NaN.
835 template<
class RF,
class PR>
839 for(
int l=0; l < static_cast<int>(rf.ext_param_.row_count_); ++l)
843 oobError += double(oobErrorCount[l]) / oobCount[l];
// Store this visitor's result in an HDF5 file.
//   filen - file name passed to writeHDF5
//   pathn - group path inside the file; the dataset name "breiman_error"
//           is appended to it
// NOTE(review): *(pathn.end()-1) is undefined behaviour for an empty pathn;
// a check such as !pathn.empty() && pathn.back() != '/' would be safe.
// The statement taken when the trailing '/' is missing and the definition
// of `temp` are not visible in this view.
881 void save(std::string filen, std::string pathn)
883 if(*(pathn.end()-1) !=
'/')
885 const char* filename = filen.c_str();
888 writeHDF5(filename, (pathn +
"breiman_error").c_str(), temp);
894 template<
class RF,
class PR>
895 void visit_at_beginning(RF & rf, PR &)
897 class_count = rf.class_count();
898 tmp_prob.
reshape(Shp(1, class_count), 0);
899 prob_oob.
reshape(Shp(rf.ext_param().row_count_,class_count), 0);
900 is_weighted = rf.options().predict_weighted_;
901 indices.resize(rf.ext_param().row_count_);
902 if(
int(oobCount.size()) != rf.ext_param_.row_count_)
904 oobCount.
reshape(Shp(rf.ext_param_.row_count_, 1), 0);
906 for(
int ii = 0; ii < rf.ext_param().row_count_; ++ii)
912 template<
class RF,
class PR,
class SM,
class ST>
913 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST &,
int index)
919 if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
921 ArrayVector<int> oob_indices;
922 ArrayVector<int> cts(class_count, 0);
923 std::random_device rd;
924 std::mt19937 g(rd());
925 std::shuffle(indices.
begin(), indices.
end(), g);
926 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
928 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
930 oob_indices.push_back(indices[ii]);
931 ++cts[pr.response()(indices[ii], 0)];
934 for(
unsigned int ll = 0; ll < oob_indices.size(); ++ll)
937 ++oobCount[oob_indices[ll]];
940 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll]));
941 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
942 rf.tree(index).parameters_,
945 for(
int ii = 0; ii < class_count; ++ii)
947 tmp_prob[ii] = node.prob_begin()[ii];
951 for(
int ii = 0; ii < class_count; ++ii)
952 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
954 rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
959 for(
int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
962 if(!sm.is_used()[ll])
968 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
969 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
970 rf.tree(index).parameters_,
973 for(
int ii = 0; ii < class_count; ++ii)
975 tmp_prob[ii] = node.prob_begin()[ii];
979 for(
int ii = 0; ii < class_count; ++ii)
980 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
982 rowVector(prob_oob, ll) += tmp_prob;
991 template<
class RF,
class PR>
995 int totalOobCount =0;
996 int breimanstyle = 0;
997 for(
int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1001 if(
argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
// Store this visitor's results in an HDF5 file.
//   filen - file name passed to writeHDF5
//   pathn - group path inside the file; four dataset names are appended:
//           "per_tree_error", "per_tree_error_std", "breiman_error",
//           "ulli_error"
// Each dataset is written from `temp`; the assignments to `temp` between the
// calls are not visible in this view.
// NOTE(review): *(pathn.end()-1) is undefined behaviour for an empty pathn;
// a check such as !pathn.empty() && pathn.back() != '/' would be safe.
1072 void save(std::string filen, std::string pathn)
1074 if(*(pathn.end()-1) !=
'/')
1076 const char* filename = filen.c_str();
1082 writeHDF5(filename, (pathn +
"per_tree_error").c_str(), temp);
1084 writeHDF5(filename, (pathn +
"per_tree_error_std").c_str(), temp);
1086 writeHDF5(filename, (pathn +
"breiman_error").c_str(), temp);
1088 writeHDF5(filename, (pathn +
"ulli_error").c_str(), temp);
1094 template<
class RF,
class PR>
1095 void visit_at_beginning(RF & rf, PR &)
1097 class_count = rf.class_count();
1098 if(class_count == 2)
1102 tmp_prob.
reshape(Shp(1, class_count), 0);
1103 prob_oob.
reshape(Shp(rf.ext_param().row_count_,class_count), 0);
1104 is_weighted = rf.options().predict_weighted_;
1108 if(
int(oobCount.size()) != rf.ext_param_.row_count_)
1110 oobCount.
reshape(Shp(rf.ext_param_.row_count_, 1), 0);
1111 oobErrorCount.
reshape(Shp(rf.ext_param_.row_count_,1), 0);
1115 template<
class RF,
class PR,
class SM,
class ST>
1116 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST &,
int index)
1121 for(
int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
1124 if(!sm.is_used()[ll])
1132 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
1133 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
1134 rf.tree(index).parameters_,
1137 for(
int ii = 0; ii < class_count; ++ii)
1139 tmp_prob[ii] = node.prob_begin()[ii];
1143 for(
int ii = 0; ii < class_count; ++ii)
1144 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
1146 rowVector(prob_oob, ll) += tmp_prob;
1147 int label =
argMax(tmp_prob);
1149 if(label != pr.response()(ll, 0))
1154 ++oobErrorCount[ll];
1158 int breimanstyle = 0;
1159 int totalOobCount = 0;
1160 for(
int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1164 if(
argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1169 oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++;
1177 MultiArrayView<3, double> current_roc
1179 for(
int gg = 0; gg < current_roc.shape(2); ++gg)
1181 for(
int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1185 int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
1187 current_roc(pr.response()(ll, 0), pred, gg)+= 1;
1190 current_roc.bindOuter(gg)/= totalOobCount;
1194 oob_per_tree[index] = double(wrong_oob)/double(total_oob);
1200 template<
class RF,
class PR>
1205 int totalOobCount =0;
1206 int breimanstyle = 0;
1207 for(
int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1211 if(
argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1258 int repetition_count_;
1262 void save(std::string filename, std::string prefix)
1264 prefix =
"variable_importance_" + prefix;
1277 : repetition_count_(rep_cnt)
1284 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
1295 Int32 const class_count = tree.ext_param_.class_count_;
1296 Int32 const column_count = tree.ext_param_.column_count_;
1305 if(split.createNode().typeID() == i_ThresholdNode)
1307 Node<i_ThresholdNode> node(split.createNode());
1309 += split.region_gini_ - split.minGini();
1319 template<
class RF,
class PR,
class SM,
class ST>
1323 Int32 column_count = rf.ext_param_.column_count_;
1324 Int32 class_count = rf.ext_param_.class_count_;
1334 typedef typename PR::FeatureWithMemory_t FeatureArray;
1335 typedef typename FeatureArray::value_type FeatureValue;
1337 FeatureArray features = pr.features();
1341 ArrayVector<Int32>::iterator
1343 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1344 if(!sm.is_used()[ii])
1345 oob_indices.push_back(ii);
1351#ifdef CLASSIFIER_TEST
1362 oob_right(Shp_t(1, class_count + 1));
1364 perm_oob_right (Shp_t(1, class_count + 1));
1368 for(iter = oob_indices.
begin();
1369 iter != oob_indices.
end();
1373 .predictLabel(rowVector(features, *iter))
1374 == pr.response()(*iter, 0))
1377 ++oob_right[pr.response()(*iter,0)];
1379 ++oob_right[class_count];
1383 for(
int ii = 0; ii < column_count; ++ii)
1385 perm_oob_right.
init(0.0);
1387 backup_column.clear();
1388 for(iter = oob_indices.
begin();
1389 iter != oob_indices.
end();
1392 backup_column.push_back(features(*iter,ii));
1396 for(
int rr = 0; rr < repetition_count_; ++rr)
1399 int n = oob_indices.
size();
1400 for(
int jj = n-1; jj >= 1; --jj)
1401 std::swap(features(oob_indices[jj], ii),
1402 features(oob_indices[randint(jj+1)], ii));
1405 for(iter = oob_indices.
begin();
1406 iter != oob_indices.
end();
1410 .predictLabel(rowVector(features, *iter))
1411 == pr.response()(*iter, 0))
1414 ++perm_oob_right[pr.response()(*iter, 0)];
1416 ++perm_oob_right[class_count];
1423 perm_oob_right /= repetition_count_;
1424 perm_oob_right -=oob_right;
1425 perm_oob_right *= -1;
1426 perm_oob_right /= oob_indices.
size();
1428 .subarray(Shp_t(ii,0),
1429 Shp_t(ii+1,class_count+1)) += perm_oob_right;
1431 for(
int jj = 0; jj < int(oob_indices.
size()); ++jj)
1432 features(oob_indices[jj], ii) = backup_column[jj];
1441 template<
class RF,
class PR,
class SM,
class ST>
1449 template<
class RF,
class PR>
// Progress report on std::cout after each tree.
// For every tree except the last (index != tree_count_-1) the current line is
// overwritten via '\r' with the percentage done and "(k of N) done", then
// flushed without a newline.
// The second output statement prints 100% followed by std::endl; presumably
// it is the branch for the last tree (the lines between are not visible).
// Only rf and index are used; the other arguments are unnamed.
1462 template<
class RF,
class PR,
class SM,
class ST>
1463 void visit_after_tree(RF& rf, PR &, SM &, ST &,
int index){
1464 if(index != rf.options().tree_count_-1) {
1465 std::cout <<
"\r[" << std::setw(10) << (index+1)/
static_cast<double>(rf.options().tree_count_)*100 <<
"%]"
1466 <<
" (" << index+1 <<
" of " << rf.options().tree_count_ <<
") done" << std::flush;
1469 std::cout <<
"\r[" << std::setw(10) << 100.0 <<
"%]" << std::endl;
// Summary line at the end of training: prints the number of trees
// (rf.options().tree_count_) and the string produced by the TOCS macro.
// NOTE(review): TOCS presumably reports the time elapsed since a matching
// TIC; that TIC is not visible in this view.
1473 template<
class RF,
class PR>
1474 void visit_at_end(RF
const & rf, PR
const &) {
1475 std::string a =
TOCS;
1476 std::cout <<
"all " << rf.options().tree_count_ <<
" trees have been learned in " << a << std::endl;
// Announcement at the start of training: prints how many trees
// (rf.options().tree_count_) the forest will have.
// The statement between the signature and the output is not visible here.
1479 template<
class RF,
class PR>
1480 void visit_at_beginning(RF
const & rf, PR
const &) {
1482 std::cout <<
"growing random forest, which will have " << rf.options().tree_count_ <<
" trees" << std::endl;
1530 void save(std::string, std::string)
1548 template<
class RF,
class PR>
1549 void visit_at_beginning(RF
const & rf, PR & pr)
1552 int n = rf.ext_param_.column_count_;
1555 corr_l.
reshape(Shp(n +1, 10));
1558 noise_l.
reshape(Shp(pr.features().shape(0), 10));
1560 for(
int ii = 0; ii <
noise.size(); ++ii)
1562 noise[ii] = random.uniform53();
1563 noise_l[ii] = random.uniform53() > 0.5;
1565 bgfunc = ColumnDecisionFunctor( rf.ext_param_);
1566 tmp_labels.
reshape(pr.response().shape());
1571 template<
class RF,
class PR>
1572 void visit_at_end(RF
const &, PR
const &)
1581 for(
int jj = 0; jj < rC-1; ++jj)
1584 rowVector(
similarity, jj) -= mean_noise(jj, 0);
1586 for(
int jj = 0; jj < rC; ++jj)
1590 rowVector(
similarity, rC - 1) -= mean_noise(rC-1, 0);
1592 FindMinMax<double> minmax;
1595 for(
int jj = 0; jj < rC; ++jj)
1598 similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))
1599 +=
similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose();
1600 similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2;
1602 for(
int jj = 0; jj < rC; ++jj)
1605 FindMinMax<double> minmax2;
1607 for(
int jj = 0; jj < rC; ++jj)
1613 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
1614 void visit_after_split( Tree &,
1619 Feature_t & features,
1622 if(split.createNode().typeID() == i_ThresholdNode)
1626 for(
int ii = 0; ii < parent.size(); ++ii)
1628 tmp_labels[parent[ii]]
1629 = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
1630 ++tmp_cc[tmp_labels[parent[ii]]];
1632 double region_gini = bgfunc.loss_of_region(tmp_labels,
1637 int n = split.bestSplitColumn();
1641 for(
int k = 0; k < features.shape(1); ++k)
1643 bgfunc(columnVector(features, k),
1645 parent.begin(), parent.end(),
1647 wgini = (region_gini - bgfunc.min_gini_);
1651 for(
int k = 0; k < 10; ++k)
1653 bgfunc(columnVector(
noise, k),
1655 parent.begin(), parent.end(),
1657 wgini = (region_gini - bgfunc.min_gini_);
1662 for(
int k = 0; k < 10; ++k)
1664 bgfunc(columnVector(noise_l, k),
1666 parent.begin(), parent.end(),
1668 wgini = (region_gini - bgfunc.min_gini_);
1672 bgfunc(labels, tmp_labels, parent.begin(), parent.end(),tmp_cc);
1673 wgini = (region_gini - bgfunc.min_gini_);
1677 region_gini = split.region_gini_;
1679 Node<i_ThresholdNode> node(split.createNode());
1682 +=split.region_gini_ - split.minGini();
1684 for(
int k = 0; k < 10; ++k)
1686 split.bgfunc(columnVector(
noise, k),
1688 parent.begin(), parent.end(),
1689 parent.classCounts());
1695 for(
int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
1697 wgini = region_gini - split.min_gini_[k];
1700 split.splitColumns[k])
1704 for(
int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
1706 split.bgfunc(columnVector(features, split.splitColumns[k]),
1708 parent.begin(), parent.end(),
1709 parent.classCounts());
1710 wgini = region_gini - split.bgfunc.min_gini_;
1712 split.splitColumns[k]) += wgini;
1719 SortSamplesByDimensions<Feature_t>
1720 sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
1721 std::partition(parent.begin(), parent.end(), sorter);
const_iterator begin() const
Definition array_vector.hxx:223
const_pointer data() const
Definition array_vector.hxx:209
size_type size() const
Definition array_vector.hxx:358
void init(U const &initial)
Definition array_vector.hxx:146
const_iterator end() const
Definition array_vector.hxx:237
Definition array_vector.hxx:514
Definition rf_split.hxx:832
Base class for, and view to, vigra::MultiArray.
Definition multi_fwd.hxx:127
Main MultiArray class containing the memory management.
Definition multi_fwd.hxx:131
void reshape(const difference_type &shape)
Definition multi_array.hxx:2863
MultiArray & init(const U &init)
Definition multi_array.hxx:2853
Definition random.hxx:346
Class for fixed size vectors.
Definition tinyvector.hxx:1008
Definition rf_visitors.hxx:1014
double oob_per_tree2
Definition rf_visitors.hxx:1043
MultiArray< 2, double > breiman_per_tree
Definition rf_visitors.hxx:1048
double oob_mean
Definition rf_visitors.hxx:1026
double oob_breiman
Definition rf_visitors.hxx:1036
MultiArray< 2, double > oob_per_tree
Definition rf_visitors.hxx:1023
void visit_at_end(RF &rf, PR &pr)
Definition rf_visitors.hxx:1201
MultiArray< 4, double > oobroc_per_tree
Definition rf_visitors.hxx:1065
double oob_std
Definition rf_visitors.hxx:1029
Definition rf_visitors.hxx:1494
MultiArray< 2, double > distance
Definition rf_visitors.hxx:1522
MultiArray< 2, double > corr_noise
Definition rf_visitors.hxx:1507
MultiArray< 2, double > gini_missc
Definition rf_visitors.hxx:1499
MultiArray< 2, double > similarity
Definition rf_visitors.hxx:1519
ArrayVector< int > numChoices
Definition rf_visitors.hxx:1527
MultiArray< 2, double > noise
Definition rf_visitors.hxx:1503
Definition rf_visitors.hxx:865
double oob_breiman
Definition rf_visitors.hxx:875
void visit_at_end(RF &rf, PR &pr)
Definition rf_visitors.hxx:992
Definition rf_visitors.hxx:784
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition rf_visitors.hxx:807
double oobError
Definition rf_visitors.hxx:788
void visit_at_end(RF &rf, PR &)
Definition rf_visitors.hxx:836
Definition rf_visitors.hxx:585
void visit_internal_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition rf_visitors.hxx:724
void reset_tree(int tree_id)
Definition rf_visitors.hxx:636
void visit_after_tree(RF &, PR &, SM &, ST &, int)
Definition rf_visitors.hxx:647
void visit_at_beginning(RF &rf, const PR &)
Definition rf_visitors.hxx:628
Definition rf_visitors.hxx:1458
Definition rf_visitors.hxx:236
Definition rf_visitors.hxx:1229
void visit_after_split(Tree &tree, Split &split, Region &, Region &, Region &, Feature_t &, Label_t &)
Definition rf_visitors.hxx:1285
void visit_at_end(RF &rf, PR &)
Definition rf_visitors.hxx:1450
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition rf_visitors.hxx:1442
void after_tree_ip_impl(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition rf_visitors.hxx:1320
MultiArray< 2, double > variable_importance_
Definition rf_visitors.hxx:1257
Definition rf_visitors.hxx:103
void visit_at_beginning(RF const &rf, PR const &pr)
Definition rf_visitors.hxx:188
void visit_external_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition rf_visitors.hxx:206
void visit_after_split(Tree &tree, Split &split, Region &parent, Region &leftChild, Region &rightChild, Feature_t &features, Label_t &labels)
Definition rf_visitors.hxx:143
void visit_internal_node(TR &, IntT, TopT, Feat &)
Definition rf_visitors.hxx:216
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition rf_visitors.hxx:164
void visit_at_end(RF const &rf, PR const &pr)
Definition rf_visitors.hxx:176
double return_val()
Definition rf_visitors.hxx:226
Definition rf_visitors.hxx:256
detail::VisitorNode< A > create_visitor(A &a)
Definition rf_visitors.hxx:345
void writeHDF5(...)
Store array data in an HDF5 file.
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition algorithm.hxx:96
void inspectMultiArray(...)
Call an analyzing functor at every element of a multi-dimensional array.
FFTWComplex< R >::NormType abs(const FFTWComplex< R > &a)
absolute value (= magnitude)
Definition fftw3.hxx:1002
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition sized_int.hxx:175
#define TIC
Definition timing.hxx:321
#define TOCS
Definition timing.hxx:324