Quick Start Guide: Optimal Prescriptive Trees

This is a Python version of the corresponding OptimalTrees quick start guide.

In this example we demonstrate how to use Optimal Prescriptive Trees (OPT). We will examine the impact of job training on annual earnings using the Lalonde sample from the National Supported Work Demonstration dataset.

First we load in the data:

import pandas as pd
colnames = ['treatment', 'age', 'education', 'black', 'hispanic', 'married',
            'nodegree', 'earnings_1975', 'earnings_1978']
# The data files are whitespace-delimited and have no header row
df_control = pd.read_csv('nsw_control.txt', names=colnames, sep=r'\s+')
df_treated = pd.read_csv('nsw_treated.txt', names=colnames, sep=r'\s+')
df = pd.concat([df_control, df_treated])
df
     treatment   age  education  ...  nodegree  earnings_1975  earnings_1978
0          0.0  23.0       10.0  ...       1.0          0.000          0.000
1          0.0  26.0       12.0  ...       0.0          0.000      12383.680
2          0.0  22.0        9.0  ...       1.0          0.000          0.000
3          0.0  34.0        9.0  ...       1.0       4368.413      14051.160
4          0.0  18.0        9.0  ...       1.0          0.000      10740.080
5          0.0  45.0       11.0  ...       1.0          0.000      11796.470
6          0.0  18.0        9.0  ...       1.0          0.000       9227.052
..         ...   ...        ...  ...       ...            ...            ...
290        1.0  25.0       14.0  ...       0.0      11536.570      36646.950
291        1.0  26.0       10.0  ...       1.0          0.000          0.000
292        1.0  20.0        9.0  ...       1.0          0.000       8881.665
293        1.0  31.0        4.0  ...       1.0       4023.211       7382.549
294        1.0  24.0       10.0  ...       1.0       4078.152          0.000
295        1.0  33.0       11.0  ...       1.0      25142.240       4181.942
296        1.0  33.0       12.0  ...       0.0      10941.350      15952.600

[722 rows x 9 columns]

Data for prescriptive problems

Prescriptive trees are trained on observational data and require three distinct types of data:

  • X: the features for each observation that can be used as the splits in the tree - this can be a matrix or a dataframe as for classification or regression problems
  • treatments: the treatment applied to each observation - this is a vector of the treatment labels similar to the target in a classification problem
  • outcomes: the outcome for each observation under the applied treatment - this is a vector of numeric values similar to the target in a regression problem

Refer to the documentation on data preparation for more information on the data format.

In this case, the treatment is whether or not the subject received job training, and the outcome is their 1978 earnings (which we are trying to maximize):

X = df.iloc[:, 1:-1]
treatments = df.treatment.map({1: "training", 0: "no training"})
outcomes = df.earnings_1978
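As a quick sanity check of the slicing and mapping above, the following sketch applies the same three lines to a two-row frame with the same column layout (the values are invented for illustration, not taken from the dataset). It shows that `iloc[:, 1:-1]` drops the first column (`treatment`) and the last column (`earnings_1978`), and that the float treatment codes become string labels.

```python
import pandas as pd

colnames = ['treatment', 'age', 'education', 'black', 'hispanic', 'married',
            'nodegree', 'earnings_1975', 'earnings_1978']
# Two invented rows with the same layout as the NSW files
toy = pd.DataFrame([[0.0, 23.0, 10.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
                    [1.0, 25.0, 14.0, 1.0, 0.0, 0.0, 0.0, 11536.57, 36646.95]],
                   columns=colnames)

X = toy.iloc[:, 1:-1]  # every column except the first (treatment) and last (outcome)
treatments = toy.treatment.map({1: "training", 0: "no training"})
outcomes = toy.earnings_1978

print(list(X.columns))
print(treatments.tolist())  # ['no training', 'training']
```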

We can now split into training and test datasets:

from interpretableai import iai
(train_X, train_treatments, train_outcomes), (test_X, test_treatments, test_outcomes) = (
    iai.split_data('prescription_maximize', X, treatments, outcomes, seed=2))

Note that we have used the default 70%/30% split, but in many prescriptive problems it is desirable to save more data for testing to ensure high-quality reward estimation on the test set.
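Holding out more data for testing amounts to randomly partitioning the rows with a smaller training proportion. The sketch below illustrates that idea with a hypothetical NumPy helper, `random_split_indices`; it is not the scheme `iai.split_data` uses internally (which may differ, for example in rounding or in how it balances treatments), so refer to the `split_data` documentation to change the proportion in practice.

```python
import numpy as np

def random_split_indices(n, train_proportion, seed=None):
    """Hypothetical helper: randomly partition row positions 0..n-1 into train/test."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n)
    n_train = int(round(train_proportion * n))
    return np.sort(order[:n_train]), np.sort(order[n_train:])

# Keep only half of 722 rows for training, leaving more for reward estimation
train_idx, test_idx = random_split_indices(722, 0.5, seed=2)
print(len(train_idx), len(test_idx))  # 361 361
```

The resulting positions could then be used with `df.iloc[train_idx]` and `df.iloc[test_idx]`.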

Fitting Optimal Prescriptive Trees

We will use a GridSearch to fit an OptimalTreePrescriptionMaximizer (note that if we were trying to minimize the outcomes, we would use OptimalTreePrescriptionMinimizer):

grid = iai.GridSearch(
    iai.OptimalTreePrescriptionMaximizer(
        prescription_factor=1,
        treatment_minbucket=20,
        random_seed=234,
    ),
    max_depth=range(1, 6),
)
grid.fit(train_X, train_treatments, train_outcomes)
grid.get_learner()
[Figure: Optimal Trees Visualization of the fitted prescriptive tree]

Here, we have set prescription_factor=1 to focus the trees on maximizing the outcome, and treatment_minbucket=20 so that the tree can only prescribe a treatment in a leaf if at least 20 subjects in that leaf received this treatment. This ensures that we have sufficient data on how the treatment affects subjects in the leaf before we prescribe it.
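The `treatment_minbucket` rule can be illustrated in a few lines of plain Python. This is a sketch of the constraint as described above, not Interpretable AI's implementation; `eligible_treatments` is a hypothetical helper that returns the treatments a leaf would be allowed to prescribe, given the treatments its training subjects received.

```python
from collections import Counter

def eligible_treatments(leaf_treatments, treatment_minbucket):
    """Hypothetical helper: treatments received by at least `treatment_minbucket` subjects."""
    counts = Counter(leaf_treatments)
    return sorted(t for t, n in counts.items() if n >= treatment_minbucket)

# A leaf where 25 subjects received training and 12 did not
leaf = ["training"] * 25 + ["no training"] * 12
print(eligible_treatments(leaf, 20))  # ['training']
print(eligible_treatments(leaf, 10))  # ['no training', 'training']
```

With a threshold of 20, "no training" cannot be prescribed in this leaf even if its subjects happened to have higher outcomes, because 12 observations are judged too few to trust.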

In the resulting tree, the color in each leaf indicates which treatment is deemed to be stronger in this leaf, and the color intensity indicates the size of the difference. The tree contains some interesting insights about the effect of training, for example:
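To make the idea of a "stronger" treatment concrete, the sketch below assumes that a leaf's estimate for each treatment is the mean observed outcome of the subjects in that leaf who received it, and takes the gap between the best and worst treatment as the size of the difference. This is a simplifying assumption for illustration only; the exact quantities behind the colors in the visualization are not specified here, and the data are invented.

```python
import pandas as pd

def leaf_summary(leaf):
    """Hypothetical helper: (best treatment, gap between best and worst mean outcome)."""
    means = leaf.groupby("treatment")["outcome"].mean()
    return means.idxmax(), float(means.max() - means.min())

# Invented training subjects that fall into a single leaf
leaf = pd.DataFrame({
    "treatment": ["training", "training", "no training", "no training"],
    "outcome": [10000.0, 20000.0, 5000.0, 7000.0],
})
print(leaf_summary(leaf))  # ('training', 9000.0)
```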

  • Node 17 is where the training had the weakest effect, which is for older subjects with high earnings in 1975. This seems to make sense, as these people are likely the least in need of training.
  • Node 10 shows that those with low 1975 earnings, at least 9 years of education, and at least 28 years old benefited greatly from the training.
  • Nodes 6 through 8 show that for those with at least 9 years of education and 1975 earnings below $1103, the effectiveness of the training was highly linked to the age of the subject, with older subjects benefiting much more.

We can make predictions on new data using predict:

pred_treatments, pred_outcomes = grid.predict(test_X)

This returns the treatment prescribed for each subject as well as the outcome predicted for each subject under the prescribed treatment:

pred_treatments
['training', 'no training', 'no training', 'training', 'training', 'training', 'no training', 'training', 'no training', 'training', 'training', 'no training', 'training', 'training', 'training', 'training', 'no training', 'no training', 'training', 'no training', 'training', 'no training', 'training', 'no training', 'training', 'no training', 'training', 'training', 'training', 'training', 'training', 'no training', 'no training', 'no training', 'no training', 'training', 'no training', 'no training', 'training', 'training', 'training', 'no training', 'no training', 'no training', 'no training', 'no training', 'training', 'training', 'training', 'training', 'training', 'no training', 'training', 'no training', 'training', 'no training', 'no training', 'no training', 'training', 'no training', 'no training', 'training', 'training', 'training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'training', 'training', 'training', 'training', 'no training', 'no training', 'no training', 'no training', 'training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'training', 'training', 'no training', 'no training', 'training', 'training', 'training', 'training', 'training', 'training', 'training', 'training', 'no training', 'training', 'no training', 'training', 'no training', 'training', 'no training', 'training', 'training', 'no training', 'no training', 'no training', 'training', 'no training', 'training', 'training', 'no training', 'training', 'no training', 'no training', 'no training', 'no training', 'training', 'training', 'training', 'training', 'training', 'no training', 'no training', 'training', 'training', 'training', 'no training', 'training', 'no training', 'training', 'training', 'training', 'no training', 'training', 'no training', 'no training', 'no training', 'training', 'no training', 'training', 'training', 'training', 'training', 'training', 'training', 
'no training', 'training', 'training', 'training', 'no training', 'training', 'no training', 'no training', 'no training', 'no training', 'training', 'training', 'training', 'training', 'training', 'no training', 'no training', 'no training', 'training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'training', 'training', 'training', 'no training', 'training', 'training', 'training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'no training', 'training', 'no training', 'no training', 'training', 'no training', 'no training', 'training', 'training', 'no training', 'no training', 'training']
pred_outcomes
[6837.618458181823, 6779.621136363639, 4697.901148148165, 11123.96342608695, 5669.543748484844, 6837.618458181823, 7402.351424999986, 11123.96342608695, 3288.4266500000012, 11123.96342608695, 6837.618458181823, 4697.901148148165, 5669.543748484844, 6837.618458181823, 5669.543748484844, 6837.618458181823, 6779.621136363639, 5314.847373076918, 6837.618458181823, 9646.777769696972, 11123.96342608695, 9646.777769696972, 6837.618458181823, 9646.777769696972, 6837.618458181823, 4697.901148148165, 6837.618458181823, 6837.618458181823, 6837.618458181823, 6837.618458181823, 6837.618458181823, 9646.777769696972, 3288.4266500000012, 7402.351424999986, 6779.621136363639, 6837.618458181823, 9646.777769696972, 4697.901148148165, 6837.618458181823, 5669.543748484844, 6837.618458181823, 4697.901148148165, 6779.621136363639, 4697.901148148165, 6779.621136363639, 3288.4266500000012, 11123.96342608695, 6837.618458181823, 6837.618458181823, 11123.96342608695, 6837.618458181823, 3288.4266500000012, 5669.543748484844, 6779.621136363639, 6837.618458181823, 4697.901148148165, 9646.777769696972, 9646.777769696972, 5669.543748484844, 9646.777769696972, 9646.777769696972, 6837.618458181823, 11123.96342608695, 6837.618458181823, 9646.777769696972, 3288.4266500000012, 9646.777769696972, 7402.351424999986, 3288.4266500000012, 3288.4266500000012, 5669.543748484844, 6837.618458181823, 11123.96342608695, 11123.96342608695, 6779.621136363639, 3288.4266500000012, 4697.901148148165, 6779.621136363639, 5669.543748484844, 4697.901148148165, 3288.4266500000012, 3288.4266500000012, 6779.621136363639, 7402.351424999986, 6779.621136363639, 9646.777769696972, 5669.543748484844, 6837.618458181823, 9646.777769696972, 4697.901148148165, 5669.543748484844, 5669.543748484844, 5669.543748484844, 5669.543748484844, 5669.543748484844, 5669.543748484844, 6837.618458181823, 6837.618458181823, 9646.777769696972, 5669.543748484844, 5314.847373076918, 6837.618458181823, 5314.847373076918, 5669.543748484844, 
5314.847373076918, 5669.543748484844, 5669.543748484844, 5314.847373076918, 5314.847373076918, 5314.847373076918, 5669.543748484844, 4697.901148148165, 5669.543748484844, 6837.618458181823, 5314.847373076918, 5669.543748484844, 5314.847373076918, 7402.351424999986, 5314.847373076918, 5314.847373076918, 6837.618458181823, 5669.543748484844, 6837.618458181823, 6837.618458181823, 5669.543748484844, 5314.847373076918, 5314.847373076918, 11123.96342608695, 11123.96342608695, 6837.618458181823, 4697.901148148165, 6837.618458181823, 3288.4266500000012, 6837.618458181823, 5669.543748484844, 6837.618458181823, 4697.901148148165, 6837.618458181823, 4697.901148148165, 5314.847373076918, 4697.901148148165, 6837.618458181823, 4697.901148148165, 11123.96342608695, 6837.618458181823, 6837.618458181823, 6837.618458181823, 11123.96342608695, 11123.96342608695, 6779.621136363639, 11123.96342608695, 6837.618458181823, 6837.618458181823, 7402.351424999986, 6837.618458181823, 3288.4266500000012, 5314.847373076918, 7402.351424999986, 9646.777769696972, 5669.543748484844, 6837.618458181823, 6837.618458181823, 6837.618458181823, 6837.618458181823, 5314.847373076918, 7402.351424999986, 7402.351424999986, 5669.543748484844, 6779.621136363639, 6779.621136363639, 6779.621136363639, 9646.777769696972, 3288.4266500000012, 9646.777769696972, 9646.777769696972, 9646.777769696972, 9646.777769696972, 6837.618458181823, 5669.543748484844, 5669.543748484844, 6779.621136363639, 5669.543748484844, 5669.543748484844, 5669.543748484844, 9646.777769696972, 9646.777769696972, 9646.777769696972, 3288.4266500000012, 5314.847373076918, 4697.901148148165, 4697.901148148165, 7402.351424999986, 6779.621136363639, 5669.543748484844, 6779.621136363639, 3288.4266500000012, 3288.4266500000012, 3288.4266500000012, 9646.777769696972, 9646.777769696972, 9646.777769696972, 9646.777769696972, 7402.351424999986, 5314.847373076918, 9646.777769696972, 5669.543748484844, 5314.847373076918, 9646.777769696972, 
6837.618458181823, 9646.777769696972, 6779.621136363639, 6837.618458181823, 5669.543748484844, 9646.777769696972, 9646.777769696972, 6837.618458181823]

You can also use predict_outcomes to get the predicted outcomes for all treatments:

grid.predict_outcomes(test_X)
     no training      training
0    3106.895687   6837.618458
1    6779.621136   2578.706261
2    4697.901148   2769.287639
3    4754.784521  11123.963426
4    3351.408983   5669.543748
5    3106.895687   6837.618458
6    7402.351425   3963.328157
..           ...           ...
209  9646.777770   6320.100182
210  6779.621136   2578.706261
211  3106.895687   6837.618458
212  3351.408983   5669.543748
213  9646.777770   6320.100182
214  9646.777770   6320.100182
215  3106.895687   6837.618458

[216 rows x 2 columns]
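For the maximizer fitted here, the rows shown above line up with the predict output: each prescribed treatment is the column with the larger predicted outcome, and the predicted outcome is that larger value. The sketch below checks this on the first three rows copied from the output. Treat it as an observation about this example rather than a guarantee, since constraints such as `treatment_minbucket` can restrict which treatments a leaf may prescribe.

```python
import pandas as pd

# First three rows of grid.predict_outcomes(test_X), copied from the output above
predicted = pd.DataFrame({
    "no training": [3106.895687, 6779.621136, 4697.901148],
    "training": [6837.618458, 2578.706261, 2769.287639],
})

prescribed = predicted.idxmax(axis=1)  # column with the largest predicted outcome per row
best_outcome = predicted.max(axis=1)   # predicted outcome under that treatment
print(prescribed.tolist())  # ['training', 'no training', 'no training']
```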

Evaluating Optimal Prescriptive Trees

In prescription problems, it is complicated to evaluate the quality of a prescription policy because our data only contains the outcome for the treatment that was received. Because we don't know the outcomes for the treatments that were not received (known as the counterfactuals), we cannot simply evaluate our prescriptions against the test set as we normally do.
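A toy example makes the problem concrete. In the sketch below (invented data), each subject has an outcome only in the column of the treatment they actually received, with NaN marking the unobserved counterfactual. A prescription can be scored directly only where it matches the observed treatment.

```python
import numpy as np
import pandas as pd

observed_treatment = ["training", "no training", "training"]
observed_outcome = [9000.0, 4000.0, 0.0]

# One row per subject, one column per treatment; counterfactuals stay unknown (NaN)
outcome_table = pd.DataFrame(np.nan, index=range(3), columns=["no training", "training"])
for i, (t, y) in enumerate(zip(observed_treatment, observed_outcome)):
    outcome_table.at[i, t] = y

prescribed = ["no training", "no training", "training"]
scorable = [bool(pd.notna(outcome_table.at[i, t])) for i, t in enumerate(prescribed)]
print(scorable)  # [False, True, True]
```

The first prescription cannot be evaluated at all, because nobody observed what this subject would have earned without training.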

A common approach to resolve this problem is reward estimation, where so-called rewards are estimated for each treatment for each observation. These rewards indicate the relative credit a model should be given for prescribing each treatment to each observation, and thus can be used to evaluate the quality of the prescription policy. For more details on how the reward estimation procedure is conducted, refer to the reward estimation documentation.
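As one concrete instance, the textbook doubly robust (augmented inverse propensity weighting) reward for observation i and treatment t is Γ_i(t) = μ_t(x_i) + 1[T_i = t] · (y_i − μ_t(x_i)) / p_t(x_i), where μ_t is an estimated outcome model for treatment t and p_t the estimated probability of receiving t. The sketch below computes this from given predictions and propensities. It is a minimal illustration of the standard formula, assuming those estimates are already available; it is not the `RewardEstimator` implementation, which also fits the outcome and propensity models and may differ in details. The function name and inputs are hypothetical.

```python
import numpy as np
import pandas as pd

def doubly_robust_rewards(treatments, outcomes, predicted_outcomes, propensities):
    """Hypothetical helper: doubly robust rewards, one column per treatment.

    predicted_outcomes[t][i]: outcome-model estimate of y_i under treatment t
    propensities[t][i]: estimated probability that observation i receives t
    """
    treatments = np.asarray(treatments)
    outcomes = np.asarray(outcomes, dtype=float)
    rewards = {}
    for t, mu in predicted_outcomes.items():
        mu = np.asarray(mu, dtype=float)
        p = np.asarray(propensities[t], dtype=float)
        received = treatments == t
        # Residual correction weighted by 1/propensity, only where t was actually received
        correction = np.zeros_like(mu)
        correction[received] = (outcomes[received] - mu[received]) / p[received]
        rewards[t] = mu + correction
    return pd.DataFrame(rewards)

rewards = doubly_robust_rewards(
    treatments=["no training", "training"],
    outcomes=[10.0, 20.0],
    predicted_outcomes={"no training": [8.0, 5.0], "training": [12.0, 18.0]},
    propensities={"no training": [0.5, 0.5], "training": [0.5, 0.5]},
)
print(rewards)
```

Because the correction divides by the propensity, estimated rewards can fall outside the range of observed outcomes and can be negative even when outcomes, such as earnings, cannot.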

We will use a RewardEstimator to estimate the rewards (note that we are passing in the test data rather than the training data to ensure we get a fair out-of-sample evaluation):

reward_lnr = iai.RewardEstimator(
    propensity_estimation_method='random_forest',
    outcome_estimation_method='random_forest',
    reward_estimation_method='doubly_robust',
    random_seed=1,
)
rewards = reward_lnr.fit_predict(test_X, test_treatments, test_outcomes)
      no training      training
0    13466.146182   5840.713692
1     2950.447056   5562.140020
2    10780.625029   4377.646614
3    17391.485616   5864.284600
4     -599.990036   5225.244544
5     6346.626635  15981.109548
6     -847.699124   3445.858808
..            ...           ...
209   9441.837867  13602.548048
210   2275.944765  10463.247238
211   5524.173342  -1373.690742
212   4849.778366   4586.343908
213   3339.218108   -366.204516
214   9404.629238  42383.012909
215   3758.271583  -4818.630926

[216 rows x 2 columns]

We can now use these reward values to evaluate the prescription in many ways. For example, we might like to see the mean reward achieved across all prescriptions on the test set:

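def evaluate_reward_mean(treatments, rewards):
    # Average, over observations, the estimated reward of the treatment assigned to each one
    treatments = list(treatments)  # positional access for lists, arrays and Series alike
    total = 0.0
    for i, treatment in enumerate(treatments):
        # .iloc selects by position, so this works whatever index `rewards` has
        total += rewards[treatment].iloc[i]
    return total / len(treatments)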

evaluate_reward_mean(pred_treatments, rewards)
5963.83953539401
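The loop above is easy to read; an equivalent vectorized version looks up each treatment's column position once and then picks one entry per row with NumPy fancy indexing. This sketch assumes `rewards` is a DataFrame with one column per treatment label and rows in the same order as `treatments`; the function name and the small reward table used to exercise it are invented.

```python
import numpy as np
import pandas as pd

def evaluate_reward_mean_vectorized(treatments, rewards):
    """Hypothetical helper: mean of rewards[treatments[i]] at row position i."""
    cols = rewards.columns.get_indexer(list(treatments))  # -1 marks an unknown label
    if (cols < 0).any():
        raise ValueError("a treatment label is missing from the rewards columns")
    values = rewards.to_numpy()[np.arange(len(cols)), cols]
    return float(values.mean())

toy_rewards = pd.DataFrame({"no training": [1.0, 2.0, 3.0],
                            "training": [10.0, 20.0, 30.0]})
print(evaluate_reward_mean_vectorized(["training", "no training", "training"], toy_rewards))
# 14.0
```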

For comparison, we can look at the mean reward achieved under the actual treatment assignments observed in the data:

evaluate_reward_mean(test_treatments, rewards)
5494.877142584506

We can see that the prescriptive tree policy achieves a higher mean estimated reward on the test set (about 5964 versus 5495) than the treatment assignments actually observed in the data.
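To put the gap in numbers, the short calculation below uses only the two means printed above. It gives the absolute and relative difference, and does not quantify the uncertainty of these reward estimates.

```python
policy_mean = 5963.83953539401     # mean reward of the tree's prescriptions (printed above)
observed_mean = 5494.877142584506  # mean reward of the observed assignments (printed above)

improvement = policy_mean - observed_mean
relative = improvement / observed_mean
print(f"{improvement:.2f} ({relative:.1%})")  # 468.96 (8.5%)
```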