Quick Start Guide: Optimal Policy Trees

This is a Python version of the corresponding OptimalTrees quick start guide.

In this example we will give a demonstration of how to use Optimal Policy Trees. We will examine the impact of job training on annual earnings using the Lalonde sample from the National Supported Work Demonstration dataset. This is the same dataset used in the quick start guide for Optimal Prescriptive Trees but showcases a different approach to prescription.

First we load in the data:

import pandas as pd
colnames = ['treatment', 'age', 'education', 'black', 'hispanic', 'married',
            'nodegree', 'earnings_1975', 'earnings_1978']
df_control = pd.read_csv('nsw_control.txt', names=colnames, sep=r'\s+')
df_treated = pd.read_csv('nsw_treated.txt', names=colnames, sep=r'\s+')
df = pd.concat([df_control, df_treated])
     treatment   age  education  ...  nodegree  earnings_1975  earnings_1978
0          0.0  23.0       10.0  ...       1.0          0.000          0.000
1          0.0  26.0       12.0  ...       0.0          0.000      12383.680
2          0.0  22.0        9.0  ...       1.0          0.000          0.000
3          0.0  34.0        9.0  ...       1.0       4368.413      14051.160
4          0.0  18.0        9.0  ...       1.0          0.000      10740.080
5          0.0  45.0       11.0  ...       1.0          0.000      11796.470
6          0.0  18.0        9.0  ...       1.0          0.000       9227.052
..         ...   ...        ...  ...       ...            ...            ...
290        1.0  25.0       14.0  ...       0.0      11536.570      36646.950
291        1.0  26.0       10.0  ...       1.0          0.000          0.000
292        1.0  20.0        9.0  ...       1.0          0.000       8881.665
293        1.0  31.0        4.0  ...       1.0       4023.211       7382.549
294        1.0  24.0       10.0  ...       1.0       4078.152          0.000
295        1.0  33.0       11.0  ...       1.0      25142.240       4181.942
296        1.0  33.0       12.0  ...       0.0      10941.350      15952.600

[722 rows x 9 columns]

Policy trees are trained using a features matrix/dataframe X as usual, along with a rewards matrix that has one column per potential treatment, giving the outcome for each sample under that treatment.

There are two ways to get this rewards matrix:

  • in rare cases, the problem may have full information about the outcome associated with each treatment for each sample
  • more commonly, we have observational data, and use this partial data to train models to estimate the outcome associated with each treatment

Refer to the documentation on data preparation for more information on the data format.
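To make the format concrete, here is a toy rewards matrix built by hand. All numbers are made up for illustration; only the treatment names mirror those used later in this guide:

```python
import pandas as pd

# Hypothetical rewards matrix: one row per sample, one column per treatment,
# containing the (estimated) outcome for that sample under that treatment.
# All values are made up for illustration.
rewards = pd.DataFrame({
    'no training': [5100.0, 4200.0, 6900.0],
    'training':    [5400.0, 3900.0, 7800.0],
})

# Naively, the best per-sample prescription is the row-wise argmax;
# a policy tree instead learns an interpretable policy from these rewards
best = rewards.idxmax(axis=1)
print(best.tolist())  # ['training', 'no training', 'training']
```

A policy tree goes beyond this per-sample argmax by grouping samples with interpretable feature splits, trading a little per-sample reward for a policy a human can inspect.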

In this case, the dataset is observational, and so we will use RewardEstimation to estimate our rewards matrix.

Reward Estimation

The treatment is whether or not the subject received job training, and the outcome is their 1978 earnings (which we are trying to maximize).

First, we extract this information and split into training and testing:

from interpretableai import iai

X = df.iloc[:, 1:-1]
treatments = df.treatment.map({1: "training", 0: "no training"})
outcomes = df.earnings_1978

(train_X, train_treatments, train_outcomes), (test_X, test_treatments, test_outcomes) = (
    iai.split_data('prescription_maximize', X, treatments, outcomes, seed=2))

Note that we have used the default 70%/30% split, but in many prescriptive problems it is desirable to save more data for testing to ensure high-quality reward estimation on the test set.
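As a sketch of what a larger test share means in practice, the following shows a plain 50/50 shuffle split with numpy on a toy frame. The splitting logic here is purely illustrative; in practice split_data handles this for you:

```python
import numpy as np
import pandas as pd

# Toy stand-in for the feature matrix X (values are illustrative)
X = pd.DataFrame({'age': [23, 26, 22, 34, 18, 45, 18, 25]})

rng = np.random.default_rng(2)
perm = rng.permutation(len(X))

# Keep 50% of the samples for testing instead of the default 30%
n_train = len(X) // 2
train_X = X.iloc[perm[:n_train]]
test_X = X.iloc[perm[n_train:]]
print(len(train_X), len(test_X))  # 4 4
```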

We will use a RewardEstimator to estimate the 1978 earnings for each participant in the study under each treatment option, including the treatment they did not actually receive:

reward_lnr = iai.RewardEstimator(
    propensity_estimation_method='random_forest',
    outcome_estimation_method='random_forest',
    reward_estimation_method='direct_method',
    random_seed=123,
)
train_rewards = reward_lnr.fit_predict(train_X, train_treatments,
                                       train_outcomes)
      no training     training
0     2285.855381  5478.406659
1     5110.322947  5288.380815
2    10112.899233  3752.631732
3     5144.866533  5475.134910
4     8585.753646  3806.614232
5     5144.866533  5475.134910
6     4325.421920  9195.874487
..            ...          ...
499   4015.435727  4941.580227
500   4510.511407  6869.504678
501   4542.340423  6353.019829
502   7281.530896  5726.990139
503   3136.035717  4247.089859
504   8103.522340  6597.851004
505   5805.665336  8545.415283

[506 rows x 2 columns]

Optimal Policy Trees

Now that we have a complete rewards matrix, we can train a tree to learn an optimal prescription policy that maximizes 1978 earnings. We will use a GridSearch to fit an OptimalTreePolicyMaximizer (note that if we were trying to minimize the outcomes, we would use OptimalTreePolicyMinimizer):

grid = iai.GridSearch(
    iai.OptimalTreePolicyMaximizer(
        random_seed=1,
        minbucket=10,
    ),
    max_depth=range(1, 6),
)
grid.fit(train_X, train_rewards, train_proportion=0.5)
grid.get_learner()
[Optimal Trees visualization of the fitted policy tree]

The resulting tree recommends training based on three variables: age, education and 1975 earnings. The intensity of the color in each leaf shows the strength of the treatment effect. We make the following observations:

  • Training is least effective in node 6, which contains older people with low 1975 earnings. This is understandable, as unskilled people at this age may have trouble picking up new skills and benefiting from the training.
  • Training is most effective in node 5, which contains younger, educated people with low 1975 earnings. In contrast to node 6, these people seem to benefit from the training due to their youth and education.
  • There is also a training benefit in node 11, which contains those with high 1975 earnings and age over 26. It seems the training also benefits those that are slightly older and have a higher baseline level of earnings.

We can make treatment prescriptions using predict:

grid.predict(train_X)

We can also use our estimated rewards to look up the predicted outcomes under these prescriptions with predict_outcomes:

grid.predict_outcomes(train_X, train_rewards)
[5478.40665904, 5288.3808148, 3752.63173195, 5475.13491035, 8585.75364634,
 5475.13491035, 9195.87448716, ..., 6353.01982895, 7281.53089567,
 4247.0898589, 6597.85100371, 8545.41528283]

Evaluating Optimal Policy Trees

It is critical for a fair evaluation that we do not evaluate the quality of the policy using rewards from our existing reward estimator trained on the training set. This is to avoid any information from the training set leaking through to the out-of-sample evaluation.

Instead, what we need to do is to estimate a new set of rewards using only the test set, and evaluate the policy against these rewards:

test_reward_lnr = iai.RewardEstimator(
    propensity_estimation_method='random_forest',
    outcome_estimation_method='random_forest',
    reward_estimation_method='doubly_robust',
    random_seed=1,
)
test_rewards = test_reward_lnr.fit_predict(test_X, test_treatments,
                                           test_outcomes)
      no training      training
0    13466.146182   5840.713692
1     2950.447056   5562.140020
2    10780.625029   4377.646614
3    17391.485616   5864.284600
4     -599.990036   5225.244544
5     6346.626635  15981.109548
6     -847.699124   3445.858808
..            ...           ...
209   9441.837867  13602.548048
210   2275.944765  10463.247238
211   5524.173342  -1373.690742
212   4849.778366   4586.343908
213   3339.218108   -366.204516
214   9404.629238  42383.012909
215   3758.271583  -4818.630926

[216 rows x 2 columns]
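The doubly robust method combines the outcome model with inverse-propensity weighting. As a rough sketch of the standard doubly robust (AIPW) formula for a binary treatment (the function name is ours, and this is not IAI's internal implementation):

```python
import numpy as np

def doubly_robust_rewards(y, treated, mu_hat, p_hat):
    """Sketch of the standard doubly robust (AIPW) reward formula.

    y        observed outcomes, shape (n,)
    treated  1 if the sample received the treatment, else 0, shape (n,)
    mu_hat   outcome-model predictions, shape (n, 2):
             column 0 = predicted outcome under control,
             column 1 = predicted outcome under treatment
    p_hat    estimated propensity of receiving the treatment, shape (n,)
    """
    y = np.asarray(y, dtype=float)
    treated = np.asarray(treated, dtype=float)
    mu_hat = np.asarray(mu_hat, dtype=float)
    p_hat = np.asarray(p_hat, dtype=float)
    # Reward under treatment: outcome-model prediction, corrected by the
    # inverse-propensity-weighted residual for samples actually treated
    r_treat = mu_hat[:, 1] + treated * (y - mu_hat[:, 1]) / p_hat
    # Reward under control: same correction using the control propensity
    r_control = mu_hat[:, 0] + (1 - treated) * (y - mu_hat[:, 0]) / (1 - p_hat)
    return np.column_stack([r_control, r_treat])
```

When the propensity estimates are close to 0 or 1, the correction term can become large, which is one way rewards like the negative values in the output above can arise.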

We can then evaluate the quality using these new estimated rewards. First, we will calculate the average predicted 1978 earnings under the treatments prescribed by the tree for the test set:

def evaluate_reward_mean(treatments, rewards):
    # Look up the estimated reward under the prescribed treatment for each
    # sample, and average over all samples
    total = 0.0
    for i in range(len(treatments)):
        total += rewards[treatments[i]][i]
    return total / len(treatments)

evaluate_reward_mean(grid.predict(test_X), test_rewards)
6047.52316269659
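The same mean can be computed without an explicit Python loop by indexing the rewards array with the prescribed treatments. A vectorized sketch (the function name is ours; it assumes the rewards are a DataFrame with a default integer index, as above):

```python
import numpy as np
import pandas as pd

def evaluate_reward_mean_vectorized(treatments, rewards):
    # Column position of each sample's prescribed treatment
    cols = rewards.columns.get_indexer(treatments)
    # Select reward[i, prescribed column of sample i] for every row
    return rewards.to_numpy()[np.arange(len(rewards)), cols].mean()

# Toy check: row 0 is prescribed 'training' (3.0), row 1 'no training' (2.0)
rewards = pd.DataFrame({'no training': [1.0, 2.0], 'training': [3.0, 4.0]})
print(evaluate_reward_mean_vectorized(['training', 'no training'], rewards))  # 2.5
```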

We can compare this number to the average earnings under the treatment assignments that were actually observed:

evaluate_reward_mean(test_treatments, test_rewards)
5494.877142584506

We see a substantial improvement over the treatment assignments that were actually observed: the tree's prescriptions increase average estimated 1978 earnings by roughly $550, or about 10%.