Advanced

Classification Trees with Logistic Regression

Optimal Regression Trees can fit linear regression models in each leaf as part of the tree training process, which can substantially increase the predictive power of the tree. The equivalent is not available for Optimal Classification Trees, as fitting logistic regression models during the training process is computationally infeasible.

However, there are a number of ways to produce classification trees with logistic regression in the leaves, which for certain problems can lead to improved results. We will demonstrate these approaches on the following synthetic dataset:

using DataFrames, Statistics, StableRNGs
rng = StableRNG(1)  # for consistent output across Julia versions
X = DataFrame(rand(rng, 200, 5), :auto)
betax = (((X.x1 .< 0.5) .* (X.x2 .+ X.x3)) .+
         ((X.x1 .> 0.5) .* (0.2 * X.x4 .+ X.x5)))
y = betax .> mean(betax)
(X_train, y_train), (X_test, y_test) = IAI.split_data(:classification, X, y, seed=1)
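To make the ground truth explicit, the generating rule above can be read as a depth-1 tree that splits on x1 at 0.5 and uses a two-variable linear score in each leaf. The sketch below writes that rule out in plain Julia. `true_score` is a hypothetical helper, not part of IAI, and the sketch uses a `MersenneTwister` stream instead of the `StableRNG` above, so its sampled points differ from the documented dataset.

```julia
using Random, Statistics

# Hypothetical helper: the generating rule written as a depth-1 tree with linear leaves.
function true_score(x::AbstractVector)
    if x[1] < 0.5
        return x[2] + x[3]          # left leaf: linear in x2 and x3
    elseif x[1] > 0.5
        return 0.2 * x[4] + x[5]    # right leaf: linear in x4 and x5
    else
        return 0.0                  # x1 == 0.5 exactly: both indicator terms are zero
    end
end

rng = MersenneTwister(1)            # illustrative seed, not the StableRNG stream above
X = rand(rng, 200, 5)
betax = [true_score(X[i, :]) for i in axes(X, 1)]
y = betax .> mean(betax)            # same thresholding as the documented example
```

Within each half of the x1 range the label is a linear threshold, which a tree with constant leaves can only approximate with a staircase of axis-aligned splits.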

1. Refitting a trained classification tree

We start by training a normal Optimal Classification Tree:

grid = IAI.GridSearch(
    IAI.OptimalTreeClassifier(
        random_seed=1,
        minbucket=20,
    ),
    max_depth=1:3,
)
IAI.fit!(grid, X_train, y_train)
Optimal Trees Visualization
IAI.score(grid, X_test, y_test)
0.7833333333333333

We can see the results are not very good. This is likely because the underlying data is piecewise linear in nature, while the tree can only predict a constant in each leaf. We can refit the leaves of the tree with refit_leaves!, using the refit_learner parameter to specify that the model in each leaf should be an L1-regularized logistic regression fitted with glmnet (via GLMNetCVClassifier):

lnr = IAI.get_learner(grid)
IAI.refit_leaves!(lnr, X_train, y_train,
    refit_learner=IAI.GLMNetCVClassifier(n_folds=2),
)
Optimal Trees Visualization
IAI.score(lnr, X_test, y_test)
0.9

The out-of-sample score has improved from 0.783 to 0.9 as a result of adding linearity to the leaves.
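Conceptually, refitting the leaves keeps the existing splits and fits a separate logistic regression to the training points routed to each leaf. The sketch below illustrates only that idea in plain Julia, under two assumptions: the tree is a hypothetical single split on x1 at 0.5, and each leaf model is an unregularized logistic regression fitted by gradient descent rather than the L1-regularized glmnet model used by GLMNetCVClassifier. None of the helper names are IAI functions.

```julia
using Random, Statistics

sigmoid(z) = 1 / (1 + exp(-z))

# Unregularized logistic regression with an intercept, fitted by plain gradient descent
# (a simple stand-in for glmnet, which adds L1 regularization and cross-validation).
function fit_logistic(X::AbstractMatrix, y::AbstractVector; iters=3000, lr=0.5)
    A = hcat(ones(size(X, 1)), X)
    w = zeros(size(A, 2))
    for _ in 1:iters
        w .-= lr .* (A' * (sigmoid.(A * w) .- y)) ./ length(y)
    end
    return w
end

predict_logistic(w, X) = sigmoid.(hcat(ones(size(X, 1)), X) * w) .>= 0.5

# Hypothetical fixed tree: one split on the first feature at 0.5 (leaf 1 = left, leaf 2 = right).
leaf_of(X) = ifelse.(X[:, 1] .< 0.5, 1, 2)

# Refit: keep the splits, fit one logistic model per leaf on the points routed there.
function refit_leaves(X, y)
    leaves = leaf_of(X)
    return Dict(l => fit_logistic(X[leaves .== l, :], y[leaves .== l]) for l in unique(leaves))
end

# Route each point to its leaf and predict with that leaf's logistic model.
function predict_refit(models, X)
    leaves = leaf_of(X)
    yhat = falses(size(X, 1))
    for (l, w) in models
        idx = leaves .== l
        yhat[idx] = predict_logistic(w, X[idx, :])
    end
    return yhat
end

rng = MersenneTwister(1)
X = rand(rng, 200, 5)
betax = ((X[:, 1] .< 0.5) .* (X[:, 2] .+ X[:, 3])) .+
        ((X[:, 1] .> 0.5) .* (0.2 .* X[:, 4] .+ X[:, 5]))
y = betax .> mean(betax)
models = refit_leaves(X[1:140, :], y[1:140])
test_acc = mean(predict_refit(models, X[141:200, :]) .== y[141:200])
```

Because `leaf_of` is fixed before any logistic model is fitted, the split structure never reacts to how well the leaf models perform, which is exactly the limitation discussed next.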

A limitation of this approach is that the refitting process leaves the tree structure unchanged, so the new tree may not be pruned optimally with respect to the logistic regression models in its leaves. We might expect that adding logistic regression models to the leaves would let us use a smaller tree and still reach the same performance, but simply refitting the leaves cannot achieve this, as neither the fitting nor the pruning of the tree was conducted with any knowledge of the final linear leaf models.

2. Refitting a regression tree with linear predictions

One way to resolve the limitation of this first approach is to train the tree with some ability to fit linear predictions, as this might lead it to make splitting and pruning choices that are more suited to the linear nature of the final logistic regression models. We do this by first training a regression tree with linear predictions in the leaves, with our binary y as the target:

grid = IAI.GridSearch(
    IAI.OptimalTreeRegressor(
        random_seed=1,
        minbucket=20,
        regression_sparsity=:all,
    ),
    max_depth=1:2,
)
IAI.fit!(grid, X_train, y_train)
Optimal Trees Visualization

We can then take this trained regression tree and refit the leaf models with logistic regression instead of linear regression using copy_splits_and_refit_leaves!, using glmnet to fit the logistic regression models:

old_lnr = IAI.get_learner(grid)
new_lnr = IAI.OptimalTreeClassifier(random_seed=1)
IAI.copy_splits_and_refit_leaves!(new_lnr, old_lnr, X_train, y_train,
    refit_learner=IAI.GLMNetCVClassifier(n_folds=2),
)
Optimal Trees Visualization
IAI.score(new_lnr, X_test, y_test)
1.0

This results in a classification tree with the same split structure as the regression tree, but with logistic regression models in each leaf where appropriate, and the out-of-sample score improves again.
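One reason to replace the linear leaf models rather than keep them is that a least-squares fit to a 0/1 target is not constrained to [0, 1], whereas a logistic model passes its linear score through the sigmoid. The sketch below shows this for a single leaf of synthetic uniform data with a linear boundary; it is plain Julia, uses gradient descent as a stand-in for glmnet, and does not call IAI.

```julia
using Random, Statistics

sigmoid(z) = 1 / (1 + exp(-z))

# Plain gradient-descent logistic regression on a design matrix A (stand-in for glmnet).
function logistic_gd(A, y; iters=3000, lr=0.5)
    w = zeros(size(A, 2))
    for _ in 1:iters
        w .-= lr .* (A' * (sigmoid.(A * w) .- y)) ./ length(y)
    end
    return w
end

rng = MersenneTwister(1)
X = rand(rng, 100, 2)
y = (X[:, 1] .+ X[:, 2]) .> 1.0          # binary target with a linear boundary
A = hcat(ones(size(X, 1)), X)            # intercept column plus features

# Linear-prediction leaf: least squares on the 0/1 target.
lin_pred = A * (A \ Float64.(y))         # unbounded, can fall outside [0, 1]

# Logistic leaf: the same linear score passed through the sigmoid.
prob = sigmoid.(A * logistic_gd(A, y))   # always within [0, 1]

extrema(lin_pred), extrema(prob)
```

In the population limit the least-squares fit here is x1 + x2 - 0.5, so points near the corners of the unit square receive "probabilities" below 0 or above 1, which the logistic refit avoids.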

3. Refitting a classification tree during training

We can also overcome the limitation of the first approach by using the refit_learner parameter to incorporate the logistic regression refitting directly into the pruning process. This ensures the pruning of the tree will take into account the logistic regression models in the leaves, making it much easier to find the right-size tree.
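The effect of refitting inside pruning can be sketched with a simplified rule for a single split: keep the split only if its penalized cost, with a logistic model refitted in each child, is lower than that of one logistic model for the collapsed leaf. This is a hypothetical illustration in plain Julia, assuming a cost of training misclassification rate plus a penalty `cp` per leaf and gradient-descent logistic regression in place of glmnet; it does not reproduce IAI's actual pruning procedure.

```julia
using Random, Statistics

sigmoid(z) = 1 / (1 + exp(-z))

# Unregularized logistic regression with intercept via gradient descent (stand-in for glmnet).
function fit_logistic(X, y; iters=3000, lr=0.5)
    A = hcat(ones(size(X, 1)), X)
    w = zeros(size(A, 2))
    for _ in 1:iters
        w .-= lr .* (A' * (sigmoid.(A * w) .- y)) ./ length(y)
    end
    return w
end

# Number of training points misclassified after refitting a logistic model on (X, y).
function logistic_errors(X, y)
    w = fit_logistic(X, y)
    yhat = sigmoid.(hcat(ones(size(X, 1)), X) * w) .>= 0.5
    return count(yhat .!= y)
end

# Hypothetical pruning check: keep a split only if, with logistic models refitted
# in its children, error rate + cp * (number of leaves) beats collapsing it to one leaf.
function keep_split(X, y, goes_left; cp=0.01)
    n = length(y)
    cost_leaf = logistic_errors(X, y) / n + cp * 1
    cost_split = (logistic_errors(X[goes_left, :], y[goes_left]) +
                  logistic_errors(X[.!goes_left, :], y[.!goes_left])) / n + cp * 2
    return cost_split < cost_leaf
end

rng = MersenneTwister(1)
X = rand(rng, 200, 5)
betax = ((X[:, 1] .< 0.5) .* (X[:, 2] .+ X[:, 3])) .+
        ((X[:, 1] .> 0.5) .* (0.2 .* X[:, 4] .+ X[:, 5]))
y = betax .> mean(betax)
keep_split(X, y, X[:, 1] .< 0.5)
```

Raising `cp` makes collapsing more attractive; with `cp=1.0` the split can never be kept, since its cost is at least 2 while the single leaf's cost is at most 2.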

Here, we again fit the leaf models as logistic regressions with glmnet:

grid = IAI.GridSearch(
    IAI.OptimalTreeClassifier(
        random_seed=1,
        minbucket=20,
        refit_learner=IAI.GLMNetCVClassifier(),
    ),
    max_depth=1:2,
)
IAI.fit!(grid, X_train, y_train)
Optimal Trees Visualization
IAI.score(grid, X_test, y_test)
0.9333333333333333

We see that we now recover the correct tree structure and achieve good performance. However, the ground truth models use fewer variables than our logistic regressions. To address this, we can try using Optimal Feature Selection instead of glmnet to construct the logistic regression models. We will use a GridSearch that validates the sparsity of an OptimalFeatureSelectionClassifier as the refit_learner:

grid = IAI.GridSearch(
    IAI.OptimalTreeClassifier(
        random_seed=1,
        minbucket=20,
        refit_learner=IAI.GridSearch(
            IAI.OptimalFeatureSelectionClassifier(),
            sparsity=1:2,
        ),
    ),
    max_depth=1:2,
)
IAI.fit!(grid, X_train, y_train)
Optimal Trees Visualization
IAI.score(grid, X_test, y_test)
0.9

We achieve similar performance to before, but now the logistic regression models are much sparser.

In the example above, the refit_learner was a GridSearch, which means that the sparsity parameter will be validated separately in each leaf when fitting the logistic regression models. Another approach is to use the same sparsity in all leaves of the tree, and validate this shared sparsity parameter as part of the overall validation procedure. We can do this by including refit_learner_sparsity in the outermost GridSearch as a parameter to be validated:

grid = IAI.GridSearch(
    IAI.OptimalTreeClassifier(
        random_seed=1,
        minbucket=20,
        refit_learner=IAI.OptimalFeatureSelectionClassifier(),
    ),
    max_depth=1:2,
    refit_learner_sparsity=1:2,
)
IAI.fit!(grid, X_train, y_train)
Optimal Trees Visualization
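The difference between the two validation schemes can be sketched as two selection rules over a table of validation errors, where `val_err[l, k]` is assumed to hold the number of misclassified validation points in leaf `l` when its model uses sparsity `ks[k]`. This is a hypothetical simplification in plain Julia: IAI performs these validations internally, and its validation data and scoring differ in detail.

```julia
# val_err[l, k]: misclassified validation points in leaf l when fitted with sparsity ks[k]
# (hypothetical input; IAI computes the equivalent internally).

# refit_learner = GridSearch(...): each leaf independently picks its best sparsity.
per_leaf_sparsity(val_err::AbstractMatrix, ks) =
    [ks[argmin(val_err[l, :])] for l in axes(val_err, 1)]

# refit_learner_sparsity in the outer GridSearch: one sparsity shared by all leaves,
# chosen to minimize the total validation error summed over the leaves.
shared_sparsity(val_err::AbstractMatrix, ks) = ks[argmin(vec(sum(val_err, dims=1)))]
```

The shared rule has fewer degrees of freedom, so it may pick a sparsity that is not the best for some individual leaf, but it judges that choice by the error of the tree as a whole rather than leaf by leaf.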