Quick Start Guide: Optimal Policy Trees with Survival Outcomes

This is an R version of the corresponding OptimalTrees quick start guide.

In this guide we will give a demonstration of how to use Optimal Policy Trees with survival outcomes. For this example, we will use the AIDS Clinical Trials Group Study 175 dataset, which was a randomized clinical trial examining the effects of four treatments on the survival of patients with HIV.

Note: this case is not intended to serve as a practical application of policy trees, but rather to serve as an illustration of the training and evaluation process.

First we load in the data:

df <- read.table("ACTG175.txt", header = TRUE)
df
  pidnum age    wtkg hemo homo drugs karnof oprior z30 zprior preanti race
1  10056  48 89.8128    0    0     0    100      0   0      1       0    0
2  10059  61 49.4424    0    0     0     90      0   1      1     895    0
  gender str2 strat symptom treat offtrt cd40 cd420 cd496 r cd80 cd820 cens
1      0    0     1       0     1      0  422   477   660 1  566   324    0
2      0    1     3       0     1      0  162   218    NA 0  392   564    1
  days arms
1  948    2
2 1002    3
 [ reached 'max' / getOption("max.print") -- omitted 2137 rows ]

Policy trees are trained using a features matrix/dataframe X, as usual, together with a rewards matrix that has one column per potential treatment, where each column contains the outcome for each sample under that treatment.

There are two ways to get this rewards matrix:

  • in rare cases, the problem may have full information about the outcome associated with each treatment for each sample
  • more commonly, we have observational data, and use this partial data to train models to estimate the outcome associated with each treatment
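
As a toy illustration (made-up numbers and hypothetical treatment names "A", "B", "C", not from ACTG 175), the difference between the two cases is that with partial data each row of the rewards matrix has only one known entry, the outcome under the treatment actually received. The remaining entries are missing, and these are what reward estimation has to fill in:

```r
# Hypothetical toy data: 4 patients, each observed under exactly one treatment
toy_arms <- c("A", "B", "C")
toy_received <- c("A", "C", "B", "A")
toy_outcome <- c(900, 1100, 750, 1020)

# Start from an all-missing rewards matrix:
# one row per patient, one column per potential treatment
toy_rewards <- matrix(NA_real_, nrow = length(toy_received), ncol = length(toy_arms),
                      dimnames = list(NULL, toy_arms))

# Fill in only the entries we actually observed (one per row)
toy_rewards[cbind(seq_along(toy_received), match(toy_received, toy_arms))] <- toy_outcome
toy_rewards
```

With full information every cell would already be filled; here three quarters of the matrix are NA.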

Refer to the documentation on data preparation for more information on the data format.

In this case, as with observational data, we only observe each patient's outcome under the single treatment they actually received, and so we will use RewardEstimation to estimate our rewards matrix.

Reward Estimation

First, we separate the dataset into the various pieces:

  • the features (X)
  • the treatments observed in the data (treatments)
  • whether the patient was known to have died (died)
  • the time of last contact with the patient (times) - this is the survival time for patients that died, and a lower bound on the survival time otherwise
X <- df[, c("age", "wtkg", "karnof", "cd40", "cd420", "cd80", "cd820", "gender",
            "homo", "race", "symptom", "drugs", "hemo", "str2")]

treatment_map <- c(
    "zidovudine",
    "zidovudine and didanosine",
    "zidovudine and zalcitabine",
    "didanosine"
)
treatments <- treatment_map[df$arms + 1]

died <- as.logical(df$cens)

times <- df$days
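
The `+ 1` in the treatment lookup is needed because the arms column is coded 0 to 3 while R vectors are indexed from 1. A quick check using the arms and cens values from the first two rows of df printed above (plus one extra arm code, 0, added for illustration):

```r
# ACTG 175 codes the arms 0-3; R indexing starts at 1, hence the + 1
treatment_map <- c(
    "zidovudine",
    "zidovudine and didanosine",
    "zidovudine and zalcitabine",
    "didanosine"
)
toy_arms <- c(2, 3, 0)   # rows 1 and 2 of df have arms 2 and 3; 0 is added for illustration
treatment_map[toy_arms + 1]

# cens is 0/1; as.logical turns it into the died indicator
toy_cens <- c(0, 1)      # cens values of rows 1 and 2 of df
as.logical(toy_cens)     # row 1 is censored at day 948, row 2 had the event at day 1002
```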

Next, we split the data into training and test sets:

split <- iai::split_data("prescription_maximize", X, treatments, died, times,
                         seed = 2345, train_proportion = 0.5)
train_X <- split$train$X
train_treatments <- split$train$treatments
train_died <- split$train$deaths
train_times <- split$train$times
test_X <- split$test$X
test_treatments <- split$test$treatments
test_died <- split$test$deaths
test_times <- split$test$times

Note that we have used a training/test split of 50%/50%, so that we save more data for testing to ensure high-quality reward estimation on the test set.
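
To make the `train_proportion = 0.5` setting concrete, here is a hypothetical plain random split of the 2139 rows in base R. It is only a sketch of the idea and does not reproduce the rows chosen by `iai::split_data`, although the resulting sizes agree with the row counts of the training and test rewards printed in this guide (1069 and 1070):

```r
# Hypothetical sketch of a plain random 50/50 split (NOT iai::split_data's algorithm)
set.seed(2345)
n_rows <- 2139                                  # rows in ACTG175.txt
train_idx <- sample(n_rows, size = floor(0.5 * n_rows))
test_idx <- setdiff(seq_len(n_rows), train_idx) # everything not used for training
c(train = length(train_idx), test = length(test_idx))
```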

The treatment is a categorical variable with four choices, so we follow the process for estimating rewards with categorical treatments.

Our outcome is the survival time of the patient, so we use a categorical_survival_reward_estimator to estimate the expected survival under each treatment option with a doubly-robust reward estimation method, using random forests to estimate both propensity scores and outcomes:

reward_lnr <- iai::categorical_survival_reward_estimator(
    propensity_estimator = iai::random_forest_classifier(),
    outcome_estimator = iai::random_forest_survival_learner(),
    reward_estimator = "doubly_robust",
    random_seed = 1
)

train_rewards <- iai::fit_predict(
    reward_lnr, train_X, train_treatments, train_died, train_times,
    propensity_score_criterion = "misclassification",
    outcome_score_criterion = "harrell_c_statistic")
train_rewards$rewards
   didanosine zidovudine zidovudine and didanosine zidovudine and zalcitabine
1    914.7447   880.6740                 1041.0396                   998.1612
2   1705.3515   803.2756                 1047.7031                   877.2287
3   1079.1379  1003.6351                 1082.8073                  1116.1556
4    986.4404   380.1726                  937.6825                  1004.8328
5  -2805.6857   526.3117                  837.3645                   778.4508
6    752.1108  -748.2132                  836.1725                   784.4516
7   1204.9082  1143.1498                 1091.9074                   207.6754
8    881.7123   446.2232                  887.0838                 -1583.7952
9    811.2369   850.6208                -2278.5258                   825.1151
10   982.1498   960.2611                 1042.8775                  1724.6217
11  1185.6566  1093.9143                 1179.8668                  1159.8659
12   766.9350   603.5744                -1790.4708                   864.6945
13  -133.7400   860.9913                 1025.1836                  1012.4676
14  1156.9335  1041.1249                 -521.1263                  1151.5515
15  1067.2024   983.1436                 1053.9369                  1691.5854
 [ reached 'max' / getOption("max.print") -- omitted 1054 rows ]
train_rewards$score$propensity
[1] 0.2694025
train_rewards$score$outcome
$`zidovudine and zalcitabine`
[1] 0.6879994

$zidovudine
[1] 0.7026574

$didanosine
[1] 0.6867158

$`zidovudine and didanosine`
[1] 0.7127891

We can see that the internal outcome estimation models have c-statistics around 0.7, which gives us confidence that the survival time estimates are of reasonable quality and a sound basis for training. The accuracy of the propensity model is about the same as random guessing among four treatments (25%), which is to be expected as the underlying data come from a randomized trial. Given this, we may not expect the doubly-robust estimation method to perform significantly differently from the direct method, as there is unlikely to be much treatment assignment bias to correct for.
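
For intuition about what the doubly-robust method corrects, the sketch below uses the textbook doubly-robust (AIPW) formula for fully observed outcomes: the reward for sample i under treatment t is the outcome-model prediction mu_t(x_i), plus, for the treatment actually received only, the prediction error divided by the propensity score p_t(x_i). This is an assumption-laden simplification: IAI's survival estimator must additionally handle censoring, and all numbers and treatment names here are made up. It does show why large negative entries, such as -2805.6857 in row 5 of train_rewards$rewards above, can arise: a poor observed outcome on a treatment with low propensity is amplified by 1 / p.

```r
# Toy doubly-robust (AIPW) rewards for fully observed outcomes.
# Illustration only: IAI's survival estimator must also handle censoring.
toy_treatments <- c("A", "B", "A")
toy_y <- c(400, 1200, 1300)            # observed outcomes

# Outcome-model predictions mu_t(x_i): one row per sample, one column per treatment
mu <- matrix(c(1000, 950, 1100,        # column A
               900, 1000, 1050),       # column B
             nrow = 3, dimnames = list(NULL, c("A", "B")))

# Propensity scores p_t(x_i), the estimated probability of receiving t given x_i
p <- matrix(c(0.25, 0.5, 0.5,
              0.75, 0.5, 0.5),
            nrow = 3, dimnames = list(NULL, c("A", "B")))

# 1 where sample i actually received treatment t, 0 otherwise
received <- outer(toy_treatments, colnames(mu), "==") * 1

# Gamma_it = mu_t(x_i) + 1{T_i = t} * (y_i - mu_t(x_i)) / p_t(x_i)
dr_rewards <- mu + received * (toy_y - mu) / p
dr_rewards
```

In row 1 the observed outcome (400) is far below the prediction for treatment A (1000), and dividing the gap by p = 0.25 turns it into a reward of -1400. When propensities are all close to 1/4, as in a four-arm randomized trial, the correction mainly adds noise rather than removing bias.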

Optimal Policy Trees

Now that we have a complete rewards matrix, we can train a tree to learn an optimal prescription policy that maximizes survival time. Note that we exclude two features from our prescription policy (cd420 and cd820) as these are observed after the treatment assignment is decided. We will use a grid_search to fit an optimal_tree_policy_maximizer:

grid <- iai::grid_search(
    iai::optimal_tree_policy_maximizer(
        random_seed = 1,
        minbucket = 10
    ),
    max_depth = 1:5
)
iai::fit(grid, subset(train_X, select = -c(cd420, cd820)),
         train_rewards$rewards)
iai::get_learner(grid)
Fitted OptimalTreePolicyMaximizer:
  1) Split: age < 35.5
    2) Prescribe: zidovudine and zalcitabine, 595 points, error 4377.8
    3) Prescribe: zidovudine and didanosine, 474 points, error 4396.1

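The resulting tree recommends different treatments based solely on the age of the patient: zidovudine and zalcitabine for patients younger than 35.5, and zidovudine and didanosine otherwise.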
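
Because the tree has a single split, its policy can be written out by hand. The sketch below is a hypothetical helper (`prescribe` is not part of the iai package) that hard-codes the split printed above, assuming, as in the printed tree, that node 2 is the branch where `age < 35.5` holds. In practice you would use `iai::predict` on the fitted learner instead.

```r
# Hypothetical hand-coded version of the fitted depth-1 tree (illustration only)
prescribe <- function(age) {
  ifelse(age < 35.5,
         "zidovudine and zalcitabine",   # node 2: age < 35.5
         "zidovudine and didanosine")    # node 3: age >= 35.5
}

prescribe(c(48, 61, 30))   # 48 and 61 are the ages in the first two rows of df
```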

We can make treatment prescriptions using predict:

iai::predict(grid, train_X)
 [1] "zidovudine and didanosine"  "zidovudine and didanosine"
 [3] "zidovudine and didanosine"  "zidovudine and zalcitabine"
 [5] "zidovudine and didanosine"  "zidovudine and zalcitabine"
 [7] "zidovudine and zalcitabine" "zidovudine and didanosine"
 [9] "zidovudine and didanosine"  "zidovudine and zalcitabine"
[11] "zidovudine and didanosine"  "zidovudine and zalcitabine"
[13] "zidovudine and didanosine"  "zidovudine and didanosine"
[15] "zidovudine and zalcitabine" "zidovudine and zalcitabine"
[17] "zidovudine and zalcitabine" "zidovudine and didanosine"
[19] "zidovudine and zalcitabine" "zidovudine and zalcitabine"
[21] "zidovudine and zalcitabine" "zidovudine and didanosine"
[23] "zidovudine and zalcitabine" "zidovudine and zalcitabine"
[25] "zidovudine and zalcitabine" "zidovudine and didanosine"
[27] "zidovudine and zalcitabine" "zidovudine and didanosine"
[29] "zidovudine and didanosine"  "zidovudine and zalcitabine"
[31] "zidovudine and didanosine"  "zidovudine and zalcitabine"
[33] "zidovudine and zalcitabine" "zidovudine and zalcitabine"
[35] "zidovudine and zalcitabine" "zidovudine and zalcitabine"
[37] "zidovudine and zalcitabine" "zidovudine and zalcitabine"
[39] "zidovudine and didanosine"  "zidovudine and didanosine"
[41] "zidovudine and didanosine"  "zidovudine and zalcitabine"
[43] "zidovudine and didanosine"  "zidovudine and didanosine"
[45] "zidovudine and didanosine"  "zidovudine and didanosine"
[47] "zidovudine and zalcitabine" "zidovudine and didanosine"
[49] "zidovudine and didanosine"  "zidovudine and zalcitabine"
[51] "zidovudine and zalcitabine" "zidovudine and didanosine"
[53] "zidovudine and zalcitabine" "zidovudine and didanosine"
[55] "zidovudine and zalcitabine" "zidovudine and zalcitabine"
[57] "zidovudine and zalcitabine" "zidovudine and didanosine"
[59] "zidovudine and zalcitabine" "zidovudine and zalcitabine"
 [ reached getOption("max.print") -- omitted 1009 entries ]

If we want more information about the relative performance of treatments for these points, we can predict the full treatment ranking with predict_treatment_rank:

iai::predict_treatment_rank(grid, train_X)
        [,1]                         [,2]         [,3]
   [1,] "zidovudine and didanosine"  "didanosine" "zidovudine and zalcitabine"
   [2,] "zidovudine and didanosine"  "didanosine" "zidovudine and zalcitabine"
   [3,] "zidovudine and didanosine"  "didanosine" "zidovudine and zalcitabine"
   [4,] "zidovudine and zalcitabine" "didanosine" "zidovudine and didanosine"
   [5,] "zidovudine and didanosine"  "didanosine" "zidovudine and zalcitabine"
   [6,] "zidovudine and zalcitabine" "didanosine" "zidovudine and didanosine"
   [7,] "zidovudine and zalcitabine" "didanosine" "zidovudine and didanosine"
   [8,] "zidovudine and didanosine"  "didanosine" "zidovudine and zalcitabine"
   [9,] "zidovudine and didanosine"  "didanosine" "zidovudine and zalcitabine"
  [10,] "zidovudine and zalcitabine" "didanosine" "zidovudine and didanosine"
  [11,] "zidovudine and didanosine"  "didanosine" "zidovudine and zalcitabine"
  [12,] "zidovudine and zalcitabine" "didanosine" "zidovudine and didanosine"
  [13,] "zidovudine and didanosine"  "didanosine" "zidovudine and zalcitabine"
  [14,] "zidovudine and didanosine"  "didanosine" "zidovudine and zalcitabine"
  [15,] "zidovudine and zalcitabine" "didanosine" "zidovudine and didanosine"
        [,4]
   [1,] "zidovudine"
   [2,] "zidovudine"
   [3,] "zidovudine"
   [4,] "zidovudine"
   [5,] "zidovudine"
   [6,] "zidovudine"
   [7,] "zidovudine"
   [8,] "zidovudine"
   [9,] "zidovudine"
  [10,] "zidovudine"
  [11,] "zidovudine"
  [12,] "zidovudine"
  [13,] "zidovudine"
  [14,] "zidovudine"
  [15,] "zidovudine"
 [ reached getOption("max.print") -- omitted 1054 rows ]

To quantify the difference in performance behind the treatment rankings, we can use predict_treatment_outcome to extract the estimated quality of each treatment for each point:

iai::predict_treatment_outcome(grid, train_X)
   didanosine zidovudine zidovudine and didanosine zidovudine and zalcitabine
1    1041.331   932.0212                  1087.372                   1007.868
2    1041.331   932.0212                  1087.372                   1007.868
3    1041.331   932.0212                  1087.372                   1007.868
4    1071.780   996.8385                  1052.419                   1105.696
5    1041.331   932.0212                  1087.372                   1007.868
6    1071.780   996.8385                  1052.419                   1105.696
7    1071.780   996.8385                  1052.419                   1105.696
8    1041.331   932.0212                  1087.372                   1007.868
9    1041.331   932.0212                  1087.372                   1007.868
10   1071.780   996.8385                  1052.419                   1105.696
11   1041.331   932.0212                  1087.372                   1007.868
12   1071.780   996.8385                  1052.419                   1105.696
13   1041.331   932.0212                  1087.372                   1007.868
14   1041.331   932.0212                  1087.372                   1007.868
15   1071.780   996.8385                  1052.419                   1105.696
 [ reached 'max' / getOption("max.print") -- omitted 1054 rows ]
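
Every point in the same leaf receives the same predicted outcomes, so the treatment rankings shown earlier appear to be these rows sorted from highest to lowest predicted survival (an inference from the printed outputs, since we are maximizing). The sketch below sorts the two distinct rows copied from the output above and reproduces the rankings printed for rows 1 and 4:

```r
# The two distinct rows of predicted outcomes, copied from the output above
outcomes <- rbind(
  c(1041.331, 932.0212, 1087.372, 1007.868),   # leaf for age >= 35.5 (e.g. row 1)
  c(1071.780, 996.8385, 1052.419, 1105.696)    # leaf for age < 35.5 (e.g. row 4)
)
colnames(outcomes) <- c("didanosine", "zidovudine",
                        "zidovudine and didanosine", "zidovudine and zalcitabine")

# Rank treatments within each row from highest to lowest predicted survival
ranking <- t(apply(outcomes, 1, function(row) names(sort(row, decreasing = TRUE))))
ranking
```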

Evaluating Optimal Policy Trees

It is critical for a fair evaluation that we do not evaluate the quality of the policy using rewards from our existing reward estimator trained on the training set. This is to avoid any information from the training set leaking through to the out-of-sample evaluation.

Instead, we need to estimate a new set of rewards using only the test set, and evaluate the policy against these rewards:

test_rewards <- iai::fit_predict(
    reward_lnr, test_X, test_treatments, test_died, test_times,
    propensity_score_criterion = "misclassification",
    outcome_score_criterion = "harrell_c_statistic")
test_rewards$rewards
   didanosine zidovudine zidovudine and didanosine zidovudine and zalcitabine
1   1194.0187  1135.6776                 1188.1092                  1412.7220
2   1334.0375  1107.7136                 1094.2021                  1135.7435
3   1113.5301  1565.1480                 1177.0666                  1158.4334
4   1045.8976  1415.2554                 1178.1970                  1187.0529
5   -678.6386   951.8971                  872.0108                   947.5150
6   1178.7468  1075.7631                 1166.3991                   940.8757
7    771.3231   625.2559                 3263.3268                   751.2474
8    750.5452   162.1744                  750.0566                   782.4900
9   1152.3730  1803.8306                 1126.5204                  1206.2716
10  1104.0607  1080.2692                 1182.4904                  1767.5082
11  1743.5536  1114.7534                 1164.9641                  1154.5840
12   608.2099  1160.5157                 1179.8565                  1186.7581
13  -790.8765   709.1615                  838.7242                   979.8873
14  1070.0974   947.2137                 1148.4886                  1433.5791
15 -2130.8484  1090.4045                 1179.1558                  1015.1111
 [ reached 'max' / getOption("max.print") -- omitted 1055 rows ]
test_rewards$score$propensity
[1] 0.262685
test_rewards$score$outcome
$`zidovudine and zalcitabine`
[1] 0.7223475

$zidovudine
[1] 0.718315

$didanosine
[1] 0.7463456

$`zidovudine and didanosine`
[1] 0.738453

We see the scores are similar to those on the training set (the outcome c-statistics are in fact slightly higher), giving us confidence that the estimated rewards are a reasonable reflection of reality and will serve as a good basis for evaluation.
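
The c-statistics reported for the outcome models are Harrell's C: among comparable pairs of patients (pairs where the one with the shorter follow-up time actually had the event), the fraction for which the model predicts shorter survival for that patient. A value of 0.5 corresponds to random predictions and 1 to a perfect ordering. Below is a minimal sketch of the standard definition; `harrell_c` is a hypothetical helper, not the iai implementation, it ignores pairs with tied times, and it counts tied predictions as half-concordant.

```r
# Minimal Harrell's C for right-censored data (hypothetical helper, not iai's code)
harrell_c <- function(times, event, predicted_survival) {
  concordant <- 0
  comparable <- 0
  n <- length(times)
  for (i in seq_len(n)) {
    if (!event[i]) next                 # the earlier time must be an observed event
    for (j in seq_len(n)) {
      if (times[j] > times[i]) {        # j outlived i, so the pair is comparable
        comparable <- comparable + 1
        if (predicted_survival[i] < predicted_survival[j]) {
          concordant <- concordant + 1
        } else if (predicted_survival[i] == predicted_survival[j]) {
          concordant <- concordant + 0.5
        }
      }
    }
  }
  concordant / comparable
}

# Toy data: the third patient is censored; the predictions misorder one pair
toy_times <- c(100, 200, 300, 400)
toy_event <- c(TRUE, TRUE, FALSE, TRUE)
toy_pred <- c(150, 250, 200, 500)
c_toy <- harrell_c(toy_times, toy_event, toy_pred)   # 4 of 5 comparable pairs concordant
c_toy
```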

We can now evaluate the quality of the policy using these new estimated rewards. First, we calculate the average estimated survival time on the test set under the treatments prescribed by the tree. To do this, we use predict_outcomes, which uses the model to make prescriptions and then looks up the corresponding outcomes in the rewards matrix:

policy_outcomes <- iai::predict_outcomes(grid, test_X, test_rewards$rewards)
policy_outcomes
 [1]  1188.10917  1094.20212  1177.06658  1178.19702   872.01081   940.87565
 [7]   751.24744   750.05663  1206.27161  1182.49039  1164.96405  1179.85646
[13]   838.72417  1148.48859  1015.11106  1107.43456  1719.67019  1341.00882
[19]   -33.96301   252.85546  1348.36051  1140.90320 -1392.41932   910.74883
[25]  1116.11398  1392.66871  1168.83016   832.27184  1108.43842  1063.42210
[31]  1327.02281  1449.81945  1158.35389  1307.46239  1122.13401  1338.69251
[37]  1170.61667  1105.72337  1177.45346  1174.34497  1164.05590  1144.58269
[43]   903.44999  1113.02317  1137.75844  1138.55335  1112.06960  1005.16941
[49]  1472.20035  1199.33044  1075.19811  1163.71621  1178.63241  1179.94737
[55]  1204.13715  1154.38423  1119.77749  1251.45116  1202.65741  1186.48096
 [ reached getOption("max.print") -- omitted 1010 entries ]
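
The lookup step can be illustrated with base R matrix indexing. The sketch below copies rows 1, 4 and 6 of test_rewards$rewards from the output above. The prescriptions are an assumption inferred by matching those rows against the printed policy outcomes (1188.10917, 1178.19702 and 940.87565); they were not taken from `iai::predict`.

```r
# Rows 1, 4 and 6 of test_rewards$rewards, copied from the output above
rewards_rows <- rbind(
  c(1194.0187, 1135.6776, 1188.1092, 1412.7220),  # test row 1
  c(1045.8976, 1415.2554, 1178.1970, 1187.0529),  # test row 4
  c(1178.7468, 1075.7631, 1166.3991,  940.8757)   # test row 6
)
colnames(rewards_rows) <- c("didanosine", "zidovudine",
                            "zidovudine and didanosine", "zidovudine and zalcitabine")

# Prescriptions inferred from the printed policy outcomes (assumption)
prescribed <- c("zidovudine and didanosine",
                "zidovudine and didanosine",
                "zidovudine and zalcitabine")

# Pick one entry per row: (row i, column of the prescribed treatment)
looked_up <- rewards_rows[cbind(seq_along(prescribed),
                                match(prescribed, colnames(rewards_rows)))]
looked_up
```

Note that in test row 1 the prescribed treatment is not the one with the largest estimated reward; the tree applies one prescription to every patient in a leaf.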

We can then get the average estimated survival time under our treatment policy:

mean(policy_outcomes)
[1] 1060.484

We can compare this number to the average estimated survival time under the treatment assignments that were actually observed:

observed <- match(test_treatments, colnames(test_rewards$rewards))
mean(as.matrix(test_rewards$rewards)[cbind(seq_along(test_treatments), observed)])
[1] 931.9281

We see that our policy leads to a sizeable improvement in estimated survival time compared to the randomized assignments: about 1060 versus 932 days on average, a gain of roughly 129 days (around 14%).