Advanced

Classification Trees with Logistic Regression

Optimal Regression Trees can fit linear regression models in each leaf as part of the tree training process, which can drastically increase the power of the tree. Unfortunately, fitting regularized logistic regression models during the Optimal Classification Tree training process is computationally infeasible.

However, there are a number of ways to produce classification trees with logistic regression in the leaves, which for certain problems can lead to improved results:

  1. Training with linear discriminant analysis (LDA) models in each leaf
  2. Refitting a trained classification tree with logistic regression after training
  3. Refitting a linear-prediction regression tree with logistic regression
  4. Refitting a classification tree with logistic regression during training

We will demonstrate these approaches on the following synthetic dataset:

using DataFrames, Statistics, StableRNGs
rng = StableRNG(1)  # for consistent output across Julia versions
X = DataFrame(rand(rng, 200, 5), :auto)
betax = (((X.x1 .< 0.5) .* (X.x2 .+ X.x3)) .+
         ((X.x1 .> 0.5) .* (0.2 * X.x4 .+ X.x5)))
y = betax .> mean(betax)
(X_train, y_train), (X_test, y_test) = IAI.split_data(:classification, X, y, seed=1)

We start by training a normal Optimal Classification Tree:

grid = IAI.GridSearch(
    IAI.OptimalTreeClassifier(
        random_seed=1,
        minbucket=20,
    ),
    max_depth=1:3,
)
IAI.fit!(grid, X_train, y_train)
Optimal Trees Visualization
IAI.score(grid, X_test, y_test)
0.7833333333333333

We can see the results are not very good. This is likely because the underlying relationship in the data is linear, yet the tree can only predict a constant value in each leaf.

1. Training with LDA models in each leaf

The first option we can consider is training an Optimal Classification Tree with LDA models embedded in each leaf, which we achieve using the regression_features parameter:

grid = IAI.GridSearch(
    IAI.OptimalTreeClassifier(
        random_seed=1,
        minbucket=20,
        regression_features=All(),
    ),
    max_depth=1:3,
)
IAI.fit!(grid, X_train, y_train)
Optimal Trees Visualization
IAI.score(grid, X_test, y_test)
0.8833333333333333

We can see that this has resulted in linear models in each leaf, and has increased the quality of the model. Moreover, the correct split has been identified.

The main limitation of this approach can be seen by examining the models in each leaf: they are fully dense and use every feature in the data, making them difficult to interpret and extremely prone to overfitting. Additionally, the LDA fitting process can become expensive if too many features are used, so it can be prudent to limit the features specified in regression_features.
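For example, if we have prior knowledge of which features are likely to act linearly, we can restrict the leaf models to just those features. The following is a minimal sketch, assuming regression_features also accepts a vector of feature names (the subset chosen here is purely illustrative):

grid = IAI.GridSearch(
    IAI.OptimalTreeClassifier(
        random_seed=1,
        minbucket=20,
        # Restrict the leaf LDA models to an illustrative subset of features
        regression_features=[:x2, :x3, :x4, :x5],
    ),
    max_depth=1:3,
)
IAI.fit!(grid, X_train, y_train)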

2. Refitting a trained classification tree

Another approach to introducing linearity is to refit the leaves of a trained tree with refit_leaves!, using the refit_learner parameter to specify that the model in each leaf should be an L1-regularized logistic regression fit with glmnet (via GLMNetCVClassifier):

lnr = IAI.get_learner(grid)
IAI.refit_leaves!(lnr, X_train, y_train,
    refit_learner=IAI.GLMNetCVClassifier(n_folds=2),
)
Optimal Trees Visualization
IAI.score(lnr, X_test, y_test)
0.9333333333333333

We see the performance has improved as a result of refitting the models in each leaf, and by examining the models we can see that they are indeed sparser than before.
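One way to inspect the refitted leaf models is to export the tree to an interactive visualization, where the logistic regression coefficients in each leaf can be examined directly. A minimal sketch using write_html (the output filename is arbitrary):

# Write an interactive visualization of the refitted tree to file
IAI.write_html("refit_tree.html", lnr)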

Note that we are able to refit any trained classification tree using this approach. In the example above, we refit a tree that was trained with LDA models, but we could just as easily have refit a tree trained with constant predictions in each leaf.
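For instance, a sketch of this alternative, reusing the constant-prediction tree setup from the start of this example and the same refit_learner:

grid = IAI.GridSearch(
    IAI.OptimalTreeClassifier(
        random_seed=1,
        minbucket=20,
    ),
    max_depth=1:3,
)
IAI.fit!(grid, X_train, y_train)

# Refit the constant prediction in each leaf with a logistic regression model
lnr = IAI.get_learner(grid)
IAI.refit_leaves!(lnr, X_train, y_train,
    refit_learner=IAI.GLMNetCVClassifier(n_folds=2),
)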

A limitation of this approach is that the tree structure is unchanged by the refitting process. This means that the new tree may not be pruned optimally with respect to the logistic regression models in each leaf. This is particularly troublesome when we are refitting a tree that only used constant predictions, as neither the fitting nor the pruning of the tree was conducted with any knowledge of the final linearity of the leaf models.

3. Refitting a regression tree with linear predictions

Another way to introduce linearity is to find a high-quality tree structure first, and then refit it with logistic regression models. We do this by first training a regression tree with linear predictions in the leaves, using our binary y as the target:

grid = IAI.GridSearch(
    IAI.OptimalTreeRegressor(
        random_seed=1,
        minbucket=20,
        regression_features=All(),
    ),
    max_depth=1:2,
)
IAI.fit!(grid, X_train, y_train)
Optimal Trees Visualization

We can then take this trained regression tree and use copy_splits_and_refit_leaves! to refit the leaf models with logistic regression (fit with glmnet) in place of linear regression:

old_lnr = IAI.get_learner(grid)
new_lnr = IAI.OptimalTreeClassifier(random_seed=1)
IAI.copy_splits_and_refit_leaves!(new_lnr, old_lnr, X_train, y_train,
    refit_learner=IAI.GLMNetCVClassifier(n_folds=2),
)
Optimal Trees Visualization
IAI.score(new_lnr, X_test, y_test)
1.0

This results in a classification tree with the same split structure as the regression tree, but with logistic regression models in each leaf where appropriate, and we can see the performance has improved.
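If desired, we can verify that the split structure carried over by querying the root split of both trees. This is a small sketch, assuming the tree-querying functions get_split_feature and get_split_threshold, with node index 1 referring to the root:

# Both trees should split on the same feature and threshold at the root
IAI.get_split_feature(old_lnr, 1), IAI.get_split_feature(new_lnr, 1)
IAI.get_split_threshold(old_lnr, 1), IAI.get_split_threshold(new_lnr, 1)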

4. Refitting a classification tree during training

We can also use the refit_learner parameter to incorporate the logistic regression refitting directly into the pruning process. This ensures the pruning of the tree takes into account the logistic regression models in the leaves, making it much easier to find the tree of the right size.

In this case, we again fit the logistic regression models with glmnet:

grid = IAI.GridSearch(
    IAI.OptimalTreeClassifier(
        random_seed=1,
        minbucket=20,
        refit_learner=IAI.GLMNetCVClassifier(),
    ),
    max_depth=1:2,
)
IAI.fit!(grid, X_train, y_train)
Optimal Trees Visualization
IAI.score(grid, X_test, y_test)
0.9333333333333333

We see that we now recover the correct tree structure and achieve good performance. However, the ground-truth relationship uses fewer variables than our fitted logistic regressions. To address this, we can also try using Optimal Feature Selection instead of glmnet to construct our logistic regression models. We will use a GridSearch that validates the sparsity of an OptimalFeatureSelectionClassifier as the refit_learner:

grid = IAI.GridSearch(
    IAI.OptimalTreeClassifier(
        random_seed=1,
        minbucket=20,
        refit_learner=IAI.GridSearch(
            IAI.OptimalFeatureSelectionClassifier(),
            sparsity=1:2,
        ),
    ),
    max_depth=1:2,
)
IAI.fit!(grid, X_train, y_train)
Optimal Trees Visualization