in Machine Learning by (17.4k points)

I've been reading up on decision trees and cross-validation, and I understand both concepts. However, I'm having trouble understanding cross-validation as it pertains to decision trees. Essentially, cross-validation lets you alternate between training and testing when your dataset is relatively small, so that every example is used for both training and testing and you get a more reliable error estimate. A very simple algorithm goes something like this:

  1. Decide on the number of folds you want (k)

  2. Subdivide your dataset into k folds

  3. Use k-1 folds as the training set to build a tree.

  4. Use the remaining fold as the testing set to estimate statistics about the error in your tree.

  5. Save your results for later

  6. Repeat steps 3-5 k times, leaving out a different fold as your test set each time.

  7. Average the errors across your iterations to estimate the overall error.
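The steps above can be sketched in plain Python. This is a minimal sketch, not a reference implementation: it assumes a single numeric feature with 0/1 labels, and `fit_stump` is a hypothetical one-split "tree" standing in for a full decision-tree learner. Notice that each fold's model is thrown away once it has been scored; only its error survives.

```python
import random
from statistics import mean


def k_fold_indices(n, k, seed=0):
    """Steps 1-2: shuffle the indices 0..n-1 and deal them into k folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]


def fit_stump(xs, ys):
    """Hypothetical stand-in for a tree learner: one threshold split on a single
    numeric feature, predicting the majority 0/1 label on each side."""
    best = None
    for t in sorted(set(xs)):
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        l_lab = int(sum(left) * 2 > len(left))    # majority label; ties/empty -> 0
        r_lab = int(sum(right) * 2 > len(right))
        errors = sum(y != l_lab for y in left) + sum(y != r_lab for y in right)
        if best is None or errors < best[0]:
            best = (errors, t, l_lab, r_lab)
    _, t, l_lab, r_lab = best
    return lambda x: l_lab if x <= t else r_lab


def cross_validate(xs, ys, k=5, seed=0):
    """Steps 3-7: return (average test error, list of per-fold test errors)."""
    fold_errors = []
    for test_idx in k_fold_indices(len(xs), k, seed):  # step 6: each fold is held out once
        held_out = set(test_idx)
        train_idx = [i for i in range(len(xs)) if i not in held_out]
        model = fit_stump([xs[i] for i in train_idx], [ys[i] for i in train_idx])  # step 3
        wrong = sum(model(xs[i]) != ys[i] for i in test_idx)                        # step 4
        fold_errors.append(wrong / len(test_idx))                                   # step 5
        # The fold's model is discarded here; only its error is kept.
    return mean(fold_errors), fold_errors                                          # step 7


xs = list(range(20))
ys = [0] * 10 + [1] * 10
avg_error, per_fold = cross_validate(xs, ys, k=5)
print(avg_error, per_fold)
```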

The problem I can't figure out is that, in the end, you'll have k decision trees that could all be slightly different because they might not split the same way, etc. Which tree do you pick? One idea I had was to pick the one with the minimal error (although that doesn't make it optimal, just that it performed best on the fold it was given; maybe stratification would help, but everything I've read says it only helps a little).

As I understand it, the point of cross-validation is to compute in-node statistics that can later be used for pruning. So each node in the tree will have statistics calculated for it based on the test set it was given. What's important are these in-node stats, yet the procedure averages the error. How do you merge these stats within each node across k trees, when each tree could differ in what it chooses to split on?

What's the point of calculating the overall error across all the iterations? That's not something that could be used during pruning.

Any help with this little wrinkle would be much appreciated.

1 Answer

by (33.2k points)

After cross-validation, the k decision trees can be quite different, because each one is trained on a different subset of the data and may therefore choose different splits.

Cross-validation is generally not used to select a particular instance of the classifier. It is mainly used to estimate performance metrics, such as the average error rate and the deviation around that average, which tell you how accurate you can expect the model to be.
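As a small illustration of those summary metrics, the sketch below reduces a list of per-fold error rates to a mean and a sample standard deviation using Python's standard `statistics` module. The five error rates are made-up numbers for illustration, not real results, and `summarize_fold_errors` is a hypothetical helper name.

```python
from statistics import mean, stdev


def summarize_fold_errors(fold_errors):
    """Return (mean error, sample standard deviation) over the k folds."""
    return mean(fold_errors), stdev(fold_errors)


fold_errors = [0.10, 0.15, 0.05, 0.20, 0.10]  # illustrative values only
avg, spread = summarize_fold_errors(fold_errors)
print(f"estimated error: {avg:.3f} +/- {spread:.3f}")
# prints: estimated error: 0.120 +/- 0.057
```

The spread is what tells you how much the error estimate moves from fold to fold; a large value suggests the single-number average should be treated with caution.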

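Once you have that estimate, you would typically train the final tree on 100% of the training data; with more data to learn from, it is likely to be a better tree than any of the k fold trees.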
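To make that workflow concrete: the k fold models only produce an error estimate, and the model you actually keep is fitted once on all the data. This is a minimal sketch assuming 0/1 labels; `evaluate_then_fit` and `fit_majority` are hypothetical names, and `fit_majority` (always predicts the majority class) is just a placeholder where a real decision-tree learner would go.

```python
import random
from statistics import mean
from typing import Callable, List, Tuple

Model = Callable[[float], int]
Learner = Callable[[List[float], List[int]], Model]


def evaluate_then_fit(fit: Learner, xs, ys, k=5, seed=0) -> Tuple[float, Model]:
    """Estimate error with k-fold CV, then fit one final model on all the data."""
    idx = list(range(len(xs)))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]
    errors = []
    for test_idx in folds:
        held_out = set(test_idx)
        train_idx = [i for i in idx if i not in held_out]
        model = fit([xs[i] for i in train_idx], [ys[i] for i in train_idx])
        errors.append(sum(model(xs[i]) != ys[i] for i in test_idx) / len(test_idx))
    # The k fold models are thrown away; they only served to estimate the error.
    final_model = fit(xs, ys)  # trained on 100% of the data
    return mean(errors), final_model


def fit_majority(xs, ys):
    """Hypothetical placeholder learner (majority class); swap in a tree learner."""
    label = int(sum(ys) * 2 > len(ys))
    return lambda x: label


xs = list(range(10))
ys = [0] * 7 + [1] * 3
estimated_error, final_model = evaluate_then_fit(fit_majority, xs, ys, k=5)
print(estimated_error, final_model(3))
```

Passing the learner in as a function keeps the evaluation loop independent of which tree algorithm you use.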

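Cross-validation can also guide pruning, though not by merging per-node statistics across the k trees. Instead, it is commonly used to choose how much to prune (for example, a maximum depth or a cost-complexity parameter) by comparing the cross-validated error of each candidate setting; the chosen setting is then applied when pruning the final tree.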
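A minimal sketch of that idea: score several candidate pruning levels by cross-validated error, keep the best one, and grow the final tree on all the data with it. Assumptions, stated loudly: the "pruning level" here is a maximum depth (pre-pruning), not the cost-complexity post-pruning used by CART; the data has one numeric feature and 0/1 labels; and `fit_tree`, `cv_error` and `choose_depth` are hypothetical helpers, not a library API.

```python
import random
from statistics import mean


def majority(ys):
    """Majority 0/1 label; ties (and empty lists) go to 0."""
    return int(sum(ys) * 2 > len(ys))


def errors(ys):
    """Errors made by predicting the majority label for every item."""
    label = majority(ys)
    return sum(y != label for y in ys)


def fit_tree(xs, ys, max_depth):
    """Tiny greedy 1-D decision tree limited to max_depth levels of splits."""
    label = majority(ys)
    thresholds = sorted(set(xs))[:-1]  # the largest value would leave the right side empty
    if max_depth == 0 or len(set(ys)) <= 1 or not thresholds:
        return lambda x: label

    def split(t):
        left = [(x, y) for x, y in zip(xs, ys) if x <= t]
        right = [(x, y) for x, y in zip(xs, ys) if x > t]
        return left, right

    # Greedy choice: the threshold whose two sides make the fewest training errors.
    best_t = min(thresholds, key=lambda t: sum(errors([y for _, y in side]) for side in split(t)))
    left, right = split(best_t)
    left_tree = fit_tree([x for x, _ in left], [y for _, y in left], max_depth - 1)
    right_tree = fit_tree([x for x, _ in right], [y for _, y in right], max_depth - 1)
    return lambda x: left_tree(x) if x <= best_t else right_tree(x)


def cv_error(xs, ys, max_depth, k=5, seed=0):
    """Average held-out error of depth-limited trees over k folds."""
    idx = list(range(len(xs)))
    random.Random(seed).shuffle(idx)
    fold_errors = []
    for f in range(k):
        test_idx = idx[f::k]
        held_out = set(test_idx)
        train_idx = [i for i in idx if i not in held_out]
        tree = fit_tree([xs[i] for i in train_idx], [ys[i] for i in train_idx], max_depth)
        fold_errors.append(sum(tree(xs[i]) != ys[i] for i in test_idx) / len(test_idx))
    return mean(fold_errors)


def choose_depth(xs, ys, depths, k=5):
    """Pick the depth with the lowest CV error; ties go to the shallower tree."""
    scores = {d: cv_error(xs, ys, d, k) for d in depths}
    best = min(depths, key=lambda d: (scores[d], d))
    return best, scores


xs = list(range(30))
ys = [0] * 10 + [1] * 15 + [0] * 5
best_depth, scores = choose_depth(xs, ys, depths=[0, 1, 2, 3])
final_tree = fit_tree(xs, ys, best_depth)  # the kept tree is refit on all the data
print(best_depth, scores)
```

Only one number per candidate setting is compared across folds, which sidesteps the problem of matching nodes between trees that split differently.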

Welcome to Intellipaat Community. Get your technical queries answered by top developers!