I’ve seen lots on how decision trees work, but often the descriptions are textual. I’m a visual person, so let’s do a visual exploration.

The code for this little project is available in my GitHub here: https://github.com/jamesdeluk/data-projects/tree/main/visualising-trees

I’m not going to talk about how the splits are done mathematically (i.e. how they minimise MAE/MSE for regression models, or how they calculate Gini impurity or entropy for classification models); this is more from the use-case perspective, understanding the impact of different hyperparameters on getting the result you want.

Setting up

I’m going to use scikit-learn’s DecisionTreeRegressor. From the hyperparameter perspective, classification decision trees work similarly to regression ones, so I won’t discuss them separately.

I’m going to use the California housing dataset, available through scikit-learn. It looks something like this:

| MedInc | HouseAge | AveRooms | AveBedrms | Population | AveOccup | Latitude | Longitude | MedHouseVal |
|---|---|---|---|---|---|---|---|---|
| 8.3252 | 41 | 6.98412698 | 1.02380952 | 322 | 2.55555556 | 37.88 | -122.23 | 4.526 |
| 8.3014 | 21 | 6.23813708 | 0.97188049 | 2401 | 2.10984183 | 37.86 | -122.22 | 3.585 |
| 7.2574 | 52 | 8.28813559 | 1.07344633 | 496 | 2.80225989 | 37.85 | -122.24 | 3.521 |
| 5.6431 | 52 | 5.8173516 | 1.07305936 | 558 | 2.54794521 | 37.85 | -122.25 | 3.413 |
| 3.8462 | 52 | 6.28185328 | 1.08108108 | 565 | 2.18146718 | 37.85 | -122.25 | 3.422 |

Each row is a “block group”, a geographical area. The columns are, in order: median income, median house age, average number of rooms, average number of bedrooms, population, average number of occupants, latitude, longitude, and the median house value (the target). The target values range from 0.15 to 5.00, with a mean of 2.1.

I set aside the last item to use as my own personal tester:

| MedInc | HouseAge | AveRooms | AveBedrms | Population | AveOccup | Latitude | Longitude | MedHouseVal |
|---|---|---|---|---|---|---|---|---|
| 2.3886 | 16 | 5.25471698 | 1.16226415 | 1387 | 2.61698113 | 39.37 | -121.24 | 0.894 |

I’ll use train_test_split to create training and testing data, which I’ll use to compare the trees.
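
The author’s full code is in the repo linked above; a minimal sketch of this setup might look like the following (the random_state and variable names are my own):

```python
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

# Load the California housing dataset as a DataFrame
housing = fetch_california_housing(as_frame=True)
df = housing.frame

# Set aside the last block group as a personal tester
tester = df.iloc[[-1]]
df = df.iloc[:-1]

X = df.drop(columns="MedHouseVal")
y = df["MedHouseVal"]
X_tester = tester.drop(columns="MedHouseVal")

# Split the remaining data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
```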

Tree depth

Shallow

I’ll start with a small tree, with max_depth of 3. I’ll use timeit to record how long it takes to fit and predict. To get a more accurate timing, I took the mean of 10 fit-and-predicts.
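
As a sketch, the timing and scoring might look something like this (reusing the split above; the exact timing code in the repo may differ):

```python
import timeit

import numpy as np
from sklearn.metrics import (mean_absolute_error, mean_absolute_percentage_error,
                             mean_squared_error, r2_score)
from sklearn.tree import DecisionTreeRegressor

tree_shallow = DecisionTreeRegressor(max_depth=3, random_state=42)

# Mean of 10 fits and 10 predicts for more stable timings
fit_time = timeit.timeit(lambda: tree_shallow.fit(X_train, y_train), number=10) / 10
predict_time = timeit.timeit(lambda: tree_shallow.predict(X_test), number=10) / 10

y_pred = tree_shallow.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f"Fit: {fit_time:.4f}s, predict: {predict_time:.4f}s")
print(f"MAE: {mean_absolute_error(y_test, y_pred):.2f}")
print(f"MAPE: {mean_absolute_percentage_error(y_test, y_pred):.2f}")
print(f"MSE: {mse:.2f}, RMSE: {np.sqrt(mse):.2f}")
print(f"R²: {r2_score(y_test, y_pred):.2f}")
```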

It took 0.024s to fit, 0.0002s to predict, and resulted in a mean absolute error (MAE) of 0.6, a mean absolute percentage error (MAPE) of 0.38 (i.e. 38%), a mean squared error (MSE) of 0.65, a root mean squared error (RMSE) of 0.80, and an R² of 0.50. Note that for R², unlike the previous error stats, the higher the better. For my chosen block, it predicted 1.183, vs 0.894 actual. Overall, not great.

This is the tree itself, using plot_tree:

shallow tree
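
The plot itself can be produced with something like this (continuing the sketch above; the figure size is an arbitrary choice):

```python
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

fig, ax = plt.subplots(figsize=(20, 10))
plot_tree(tree_shallow, feature_names=list(X_train.columns), filled=True, ax=ax)
plt.show()
```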

You can see it only uses the MedInc, AveRooms, and AveOccup features - in other words, removing HouseAge, AveBedrms, Population, Latitude, and Longitude from the dataset would give the same predictions.

Deep

Let’s go to max_depth of None, i.e. unlimited.

It took 0.09s to fit (~4x longer), 0.0007s to predict (~4x longer), and resulted in an MAE of 0.47, an MAPE of 0.26, an MSE of 0.53, an RMSE of 0.73, and an R² of 0.60. For my chosen block, it predicted 0.819, vs 0.894 actual. Much better.

The tree:

deep tree

Wow. It has 34 levels (.get_depth()), 29,749 nodes (.tree_.node_count), and 14,875 individual branches (.get_n_leaves()) - in other words, up to 14,875 different final values for MedHouseVal.

Using some custom code, I can plot one of the branches:

single branch

This branch alone uses six of the eight features, so it’s likely that, across all ~15,000 branches, all features are represented.
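
The custom plotting code isn’t shown here, but one way to trace a single branch is to follow the decision path of one sample; a rough sketch, assuming the unlimited-depth tree is called tree_deep and X_tester is the held-out block from the setup sketch:

```python
# Walk the decision path of the held-out block through the deep tree
node_indicator = tree_deep.decision_path(X_tester)
leaf_id = tree_deep.apply(X_tester)[0]

features = tree_deep.tree_.feature
thresholds = tree_deep.tree_.threshold
feature_names = list(X_train.columns)

for node_id in node_indicator.indices:
    if node_id == leaf_id:
        print(f"Leaf {node_id}: predict {tree_deep.tree_.value[node_id][0][0]:.3f}")
    else:
        name = feature_names[features[node_id]]
        value = X_tester.iloc[0, features[node_id]]
        sign = "<=" if value <= thresholds[node_id] else ">"
        print(f"Node {node_id}: {name} = {value:.3f} {sign} {thresholds[node_id]:.3f}")
```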

However, a tree this complex can lead to overfitting, as it can split into very small groups and capture noise.

Pruned

The ccp_alpha parameter (ccp = cost-complexity pruning) can prune a tree after it’s built. Adding in a value of 0.005 to the unlimited depth tree results in an MAE of 0.53, an MAPE of 0.33, an MSE of 0.52, an RMSE of 0.72, and an R² of 0.60 - so it performed between the deep and shallow trees. For my chosen block, it predicted 1.279, so in this case, worse than the shallow one. It took 0.64s to fit (>6x longer than the deep tree) and 0.0002 to predict (the same as the shallow tree) - so, it’s slow to fit, but fast to predict.
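
As a sketch, reusing the earlier setup (the random_state is again my own addition):

```python
from sklearn.tree import DecisionTreeRegressor

# Build an unlimited-depth tree, then prune it back with cost-complexity pruning
tree_pruned = DecisionTreeRegressor(max_depth=None, ccp_alpha=0.005, random_state=42)
tree_pruned.fit(X_train, y_train)

print(tree_pruned.get_depth(), tree_pruned.tree_.node_count, tree_pruned.get_n_leaves())
```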

This tree looks like:

pruned tree

Cross validating

What if we mix up the data? Within a loop, I used train_test_split without a random state (to get new data each time), and fitted and predicted each tree based on the new data. Every loop I recorded the MAE/MAPE/MSE/RMSE/R², and then found the mean and standard deviation for each. I did 1000 loops. This helps (as the name suggests) validate our results - a single high or low error result could simply be a fluke, so taking the mean gives a better idea of the typical error on new data, and the standard deviation helps understand how stable/reliable a model is.
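
A simplified sketch of the loop for a single tree and a single metric, using the X and y from the setup sketch (the full version loops over all three trees and all five metrics):

```python
import numpy as np
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

maes = []
for _ in range(1000):
    # No random_state, so each loop gets a different split
    X_tr, X_te, y_tr, y_te = train_test_split(X, y)
    model = DecisionTreeRegressor(max_depth=3).fit(X_tr, y_tr)
    maes.append(mean_absolute_error(y_te, model.predict(X_te)))
    # ...and similarly for MAPE, MSE, RMSE, and R²

print(f"MAE mean: {np.mean(maes):.3f}, std: {np.std(maes):.3f}")
```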

It’s worth noting that sklearn has some built-in tools for this form of validation, namely cross_validate with ShuffleSplit or RepeatedKFold, and they’re typically much faster; I just did it manually to make it clearer what was going on, and to emphasise the time difference.
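
For comparison, a rough equivalent using the built-in tools might look like this (here using ShuffleSplit; note that cross_validate reports the error metrics as negated scores):

```python
from sklearn.model_selection import ShuffleSplit, cross_validate
from sklearn.tree import DecisionTreeRegressor

cv = ShuffleSplit(n_splits=100, test_size=0.25)
results = cross_validate(
    DecisionTreeRegressor(max_depth=3),
    X, y,
    cv=cv,
    scoring=["neg_mean_absolute_error", "neg_root_mean_squared_error", "r2"],
    n_jobs=-1,  # run the splits in parallel
)
print(-results["test_neg_mean_absolute_error"].mean(),
      results["test_neg_mean_absolute_error"].std())
```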

max_depth=3 (time: 22.1s)

| Metric | Mean | Std |
|---|---|---|
| MAE | 0.597 | 0.007 |
| MAPE | 0.378 | 0.008 |
| MSE | 0.633 | 0.015 |
| RMSE | 0.795 | 0.009 |
| R² | 0.524 | 0.011 |

max_depth=None (time: 100.0s)

| Metric | Mean | Std |
|---|---|---|
| MAE | 0.463 | 0.010 |
| MAPE | 0.253 | 0.008 |
| MSE | 0.524 | 0.023 |
| RMSE | 0.724 | 0.016 |
| R² | 0.606 | 0.018 |

max_depth=None, ccp_alpha=0.005 (time: 650.2s)

| Metric | Mean | Std |
|---|---|---|
| MAE | 0.531 | 0.012 |
| MAPE | 0.325 | 0.012 |
| MSE | 0.521 | 0.021 |
| RMSE | 0.722 | 0.015 |
| R² | 0.609 | 0.016 |

Compared with the deep tree, across all error stats, the shallow tree has higher errors (higher bias), but lower standard deviations (lower variance). In more casual terminology, there’s a trade-off between precision (all predictions being close together) and accuracy (all predictions being near the true value). The pruned deep tree generally performed between the two, but took far longer to fit.

We can visualise all these stats with box plots:

box plots of errors for trees
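
A minimal sketch of how such box plots can be produced, assuming the per-loop MAEs for each tree were collected into lists:

```python
import matplotlib.pyplot as plt

# maes_shallow, maes_deep, maes_pruned: the 1000 per-loop MAEs for each tree
fig, ax = plt.subplots()
ax.boxplot([maes_shallow, maes_deep, maes_pruned])
ax.set_xticks([1, 2, 3])
ax.set_xticklabels(["Shallow", "Deep", "Pruned"])
ax.set_ylabel("MAE")
plt.show()
```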

We can see the deep trees (green boxes) typically have lower errors (smaller y-axis value) but larger variations (larger gap between the lines) than the shallow tree (blue boxes). Normalising the means (so they’re all 0), we can see the variation more clearly; for example, for the MAEs:

normalised mae box plots

Histograms can also be interesting. Again for the MAEs:

histograms of errors for trees

The green (deep) has lower errors, but the blue (shallow) has a narrower band. Interestingly, the pruned tree’s results are less normally distributed than the other two - although this is not typical behaviour.

Other hyperparameters

What are the other hyperparameters we can tweak? The full list can be found in the docs: https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html

Minimum samples to split

This is the minimum number of samples a node must contain before it’s allowed to split. It can be an absolute number or a percentage of the total samples (implemented as a float between 0 and 1). It helps avoid overfitting by ensuring each branch contains a decent number of samples, rather than splitting into smaller and smaller branches based on only a few samples.
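
Both forms can be passed straight to the constructor; a quick sketch:

```python
from sklearn.tree import DecisionTreeRegressor

# A node can only be split if it contains at least 20% of the training samples
tree_frac = DecisionTreeRegressor(max_depth=10, min_samples_split=0.2)

# A node can only be split if it contains at least 10 samples
tree_abs = DecisionTreeRegressor(max_depth=10, min_samples_split=10)
```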

For example, max_depth=10, which I’ll use as a reference, looks like:

deep tree

| Metric | Mean | Std |
|---|---|---|
| MAE | 0.426 | 0.010 |
| MAPE | 0.240 | 0.008 |
| MSE | 0.413 | 0.018 |
| RMSE | 0.643 | 0.014 |
| R² | 0.690 | 0.014 |

That’s 1563 nodes and 782 leaves.

Whereas max_depth=10, min_samples_split=0.2 looks like:

min_samples_split tree

| Metric | Mean | Std |
|---|---|---|
| MAE | 0.605 | 0.013 |
| MAPE | 0.367 | 0.007 |
| MSE | 0.652 | 0.027 |
| RMSE | 0.807 | 0.016 |
| R² | 0.510 | 0.019 |

Because it can’t split any node containing fewer than 20% (0.2) of the total samples (as you can see from the sample percentages in the leaves), it’s limited to a depth of 4, with only 15 nodes and 8 leaves.

For the tree with depth 10, many of the leaves contained a single sample. Having so many leaves with so few samples can be a sign of overfitting. For the constrained tree, the smallest leaf contains over 1000 samples.
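
One way to check this directly is to count the samples in each leaf; a rough sketch, assuming the depth-10 reference tree is called tree_ref:

```python
# Leaves are nodes with no children (children_left == -1 marks a leaf)
t = tree_ref.tree_
leaf_sizes = t.n_node_samples[t.children_left == -1]

print(f"{(leaf_sizes == 1).sum()} of {len(leaf_sizes)} leaves contain a single sample")
print(f"Smallest leaf: {leaf_sizes.min()} samples")
```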

In this case, the constrained tree is worse than the unconstrained tree on all counts; however, setting min_samples_split to 10 (i.e. 10 samples, not 10%) improved the results:

| Metric | Mean | Std |
|---|---|---|
| MAE | 0.425 | 0.009 |
| MAPE | 0.240 | 0.008 |
| MSE | 0.407 | 0.017 |
| RMSE | 0.638 | 0.013 |
| R² | 0.695 | 0.013 |

This one was back to depth 10, with 1133 nodes and 567 leaves (so about 1/3 less than the unconstrained tree). Many of these leaves also contain a single sample.

Minimum samples per leaf

Another way of constraining a tree is by setting a minimum number of samples a leaf can have. Again, this can be a number or a percentage.

With max_depth=10, min_samples_leaf=0.1:

min_samples_leaf tree

Similar to the first min_samples_split tree, it has a depth of 4, 15 nodes, and 8 leaves. However, notice the nodes and leaves are different; for example, the right-most leaf in the min_samples_split tree contained 5.8% of the samples, whereas here the “same” leaf contains 10% (that’s the 0.1).

The stats are also similar to that tree’s:

| Metric | Mean | Std |
|---|---|---|
| MAE | 0.609 | 0.010 |
| MAPE | 0.367 | 0.007 |
| MSE | 0.659 | 0.023 |
| RMSE | 0.811 | 0.014 |
| R² | 0.505 | 0.016 |

Requiring modestly “larger” leaves can improve results. min_samples_leaf=10 gives a tree of depth 10, with 961 nodes and 481 leaves - similar to min_samples_split=10. It gives our best results so far, suggesting that limiting the number of single-sample leaves has indeed reduced overfitting.

| Metric | Mean | Std |
|---|---|---|
| MAE | 0.417 | 0.010 |
| MAPE | 0.235 | 0.008 |
| MSE | 0.380 | 0.017 |
| RMSE | 0.616 | 0.014 |
| R² | 0.714 | 0.013 |

Maximum leaf nodes

Another way to stop having too many leaves with too few samples is to limit the number of leaves directly with max_leaf_nodes (technically it could still result in a single-sample leaf, but it’s less likely). The trees above varied from 8 to almost 800 leaves. With max_depth=10, max_leaf_nodes=100:

max_leaf_nodes tree

This has a depth of 10 again, with 199 nodes and 100 leaves. In this case, there was only one leaf with a single sample, and only nine of them had fewer than ten samples. The results were decent too:

| Metric | Mean | Std |
|---|---|---|
| MAE | 0.450 | 0.010 |
| MAPE | 0.264 | 0.010 |
| MSE | 0.414 | 0.018 |
| RMSE | 0.644 | 0.014 |
| R² | 0.689 | 0.013 |

Bayes searching

Finally, what’s the “perfect” tree for this data? Sure, it’s possible to use trial-and-error with the above hyperparameters, but it’s much easier to use something like BayesSearchCV (assuming you have the time to let it run). In 20 minutes it performed 200 iterations (i.e. hyperparameter combinations) with five cross-validation splits (similar to five train_test_splits) each.
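
BayesSearchCV comes from the scikit-optimize (skopt) package; a sketch of how such a search might be set up (the search space below is my guess, not necessarily the one actually used):

```python
from skopt import BayesSearchCV
from skopt.space import Categorical, Integer, Real
from sklearn.tree import DecisionTreeRegressor

search_space = {
    "criterion": Categorical(["squared_error", "friedman_mse", "absolute_error"]),
    "max_depth": Integer(2, 100),
    "max_features": Real(0.1, 1.0),
    "min_samples_split": Integer(2, 50),
    "min_samples_leaf": Integer(1, 50),
    "ccp_alpha": Real(0.0, 0.01),
}

search = BayesSearchCV(
    DecisionTreeRegressor(),
    search_space,
    n_iter=200,  # 200 hyperparameter combinations
    cv=5,        # five cross-validation splits per combination
    n_jobs=-1,
)
search.fit(X_train, y_train)
print(search.best_params_)
```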

The hyperparameters it found: {'ccp_alpha': 0.0, 'criterion': 'squared_error', 'max_depth': 100, 'max_features': 0.9193546958301854, 'min_samples_leaf': 15, 'min_samples_split': 24}.

The tree was depth 20, with 798 leaves and 1595 nodes, so significantly smaller than the fully deep tree. This clearly demonstrates how increasing the min_samples_ parameters can help; while the numbers of leaves and nodes are similar to the depth-10 tree, having “larger” leaves with a deeper tree has improved the results.

For my single block it predicted 0.81632, so pretty close to the true value.

After putting it through the 1000 loops (which took just over 60 seconds - showing that pruning is by far the slowest part of fitting a tree), the final scores:

| Metric | Mean | Std |
|---|---|---|
| MAE | 0.393 | 0.007 |
| MAPE | 0.216 | 0.006 |
| MSE | 0.351 | 0.013 |
| RMSE | 0.592 | 0.011 |
| R² | 0.736 | 0.010 |

Adding these to the box plots:

box plots of all errors inc bayessearch

Lower errors, lower variances, and higher R². Excellent.

Next, random forests!