Quick Start Guide: Optimal Policy Trees with Survival Outcomes

In this guide we demonstrate how to use Optimal Policy Trees with survival outcomes. For this example, we use the AIDS Clinical Trials Group Study 175 (ACTG 175) dataset, which comes from a randomized clinical trial examining the effects of four treatments on the survival of patients with HIV.

Note: this example is not intended as a practical application of policy trees; it serves only to illustrate the training and evaluation process.

First we load in the data:

using CSV, DataFrames
df = CSV.read("ACTG175.txt", DataFrame)
2139×27 DataFrame
  Row │ pidnum  age    wtkg      hemo   homo   drugs  karnof  oprior  z30    z ⋯
      │ Int64   Int64  Float64   Int64  Int64  Int64  Int64   Int64   Int64  I ⋯
──────┼─────────────────────────────────────────────────────────────────────────
    1 │  10056     48   89.8128      0      0      0     100       0      0    ⋯
    2 │  10059     61   49.4424      0      0      0      90       0      1
    3 │  10089     45   88.452       0      1      1      90       0      1
    4 │  10093     47   85.2768      0      1      0     100       0      1
    5 │  10124     43   66.6792      0      1      0     100       0      1    ⋯
    6 │  10140     46   88.9056      0      1      1     100       0      1
    7 │  10165     31   73.0296      0      1      0     100       0      1
    8 │  10190     41   66.2256      0      1      1     100       0      1
  ⋮   │   ⋮       ⋮       ⋮        ⋮      ⋮      ⋮      ⋮       ⋮       ⋮      ⋱
 2133 │ 990018     27   80.2872      1      0      0      70       0      1    ⋯
 2134 │ 990019     39   64.8648      1      0      0      90       0      1
 2135 │ 990021     21   53.298       1      0      0     100       0      1
 2136 │ 990026     17  102.967       1      0      0     100       0      1
 2137 │ 990030     53   69.8544      1      1      0      90       0      1    ⋯
 2138 │ 990071     14   60.0         1      0      0     100       0      0
 2139 │ 990077     45   77.3         1      0      0     100       0      0
                                                18 columns and 2124 rows omitted

Policy trees are trained using a features matrix or dataframe X, as usual, and a rewards matrix with one column for each potential treatment, where each entry is the outcome for that sample under that treatment.

There are two ways to get this rewards matrix:

  • in rare cases, the problem may have full information about the outcome associated with each treatment for each sample
  • more commonly, we have observational data, and use this partial data to train models to estimate the outcome associated with each treatment
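As a small illustration of the two cases, the sketch below builds a rewards matrix as a plain Julia matrix with one column per treatment. The patients, the treatment names "A" and "B", and all values are hypothetical. With full information every entry is known; with partial information only the entry for the treatment each patient actually received is observed, and the remaining entries (stored as missing) have to be estimated.

```julia
# Hypothetical toy example: 3 patients, 2 treatments named "A" and "B"
treatment_names = ["A", "B"]

# Full information: the outcome under every treatment is known (rare in practice)
full_rewards = [
    120.0  95.0
     80.0 110.0
    150.0 140.0
]

# Partial information: only the outcome under the received treatment is observed
received = ["A", "B", "A"]
observed_outcome = [120.0, 110.0, 150.0]

partial_rewards = Matrix{Union{Float64, Missing}}(missing, length(received), length(treatment_names))
for (i, t) in enumerate(received)
    j = findfirst(==(t), treatment_names)   # column of the received treatment
    partial_rewards[i, j] = observed_outcome[i]
end

# One observed entry per row; the remaining entries need to be estimated
println("missing entries: ", count(ismissing, partial_rewards), " of ", length(partial_rewards))
```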

Refer to the documentation on data preparation for more information on the data format.

In this case, we only observe each patient's outcome under the treatment they actually received, so we will use RewardEstimation to estimate our rewards matrix.

Reward Estimation

First, we separate the dataset into the various pieces:

  • the features (X)
  • the treatments observed in the data (treatments)
  • whether the patient was known to have died (died)
  • the time of last contact with the patient (times): the survival time for patients who died, and a lower bound on the survival time otherwise

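# Features describing each patient
X = select(df, [:age, :wtkg, :karnof, :cd40, :cd420, :cd80, :cd820, :gender,
                :homo, :race, :symptom, :drugs, :hemo, :str2])

# Map the numeric treatment arm codes to readable treatment names
treatment_map = Dict(
  0 => "zidovudine",
  1 => "zidovudine and didanosine",
  2 => "zidovudine and zalcitabine",
  3 => "didanosine"
)
treatments = map(t -> treatment_map[t], df.arms)

# true if the patient is known to have died, false if the observation is censored
died = Bool.(df.cens)

# Days until death or last contact
times = df.days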

Next, we split into training and testing:

(train_X, train_treatments, train_died, train_times), (test_X, test_treatments, test_died, test_times) =
    IAI.split_data(:prescription_maximize, X, treatments, died, times, seed=2345, train_proportion=0.5)

Note that we have used a training/test split of 50%/50%, so that we save more data for testing to ensure high-quality reward estimation on the test set.

The treatment is a categorical variable with four options, so we follow the process for estimating rewards with categorical treatments.

Our outcome is the survival time of the patient, so we use a CategoricalSurvivalRewardEstimator to estimate the expected survival under each treatment option with a doubly-robust reward estimation method, using random forests to estimate both propensity scores and outcomes:

reward_lnr = IAI.CategoricalSurvivalRewardEstimator(
    propensity_estimator=IAI.RandomForestClassifier(),
    outcome_estimator=IAI.RandomForestSurvivalLearner(),
    reward_estimator=:doubly_robust,
    random_seed=1,
)

train_predictions, train_reward_score = IAI.fit_predict!(
    reward_lnr, train_X, train_treatments, train_died, train_times,
    propensity_score_criterion=:misclassification,
    outcome_score_criterion=:harrell_c_statistic)
train_rewards = train_predictions[:reward]
1069×4 DataFrame
  Row │ didanosine  zidovudine  zidovudine and didanosine  zidovudine and zalc ⋯
      │ Float64     Float64     Float64                    Float64             ⋯
──────┼─────────────────────────────────────────────────────────────────────────
    1 │    914.745     880.674                   1041.04                       ⋯
    2 │   1705.35      803.276                   1047.7
    3 │   1079.14     1003.64                    1082.81                     1
    4 │    986.44      380.173                    937.683                    1
    5 │  -2805.69      526.312                    837.364                      ⋯
    6 │    752.111    -748.213                    836.173
    7 │   1204.91     1143.15                    1091.91
    8 │    881.712     446.223                    887.084                   -1
  ⋮   │     ⋮           ⋮                   ⋮                          ⋮       ⋱
 1063 │   1196.44     1147.62                    1187.29                     1 ⋯
 1064 │   1194.02     1102.16                    1370.71                     1
 1065 │   1337.76      805.064                    956.125
 1066 │   1067.24     1139.94                    1193.82
 1067 │    971.123    1131.22                    1036.43                     1 ⋯
 1068 │    947.833    -479.767                    887.853                    1
 1069 │   1446.01     1044.59                    1197.51                     1
                                                  1 column and 1054 rows omitted
train_reward_score
Dict{Symbol, Any} with 3 entries:
  :propensity => 0.269403
  :censoring  => Dict("zidovudine and zalcitabine"=>0.569372, "zidovudine"=>0.5…
  :outcome    => Dict("zidovudine and zalcitabine"=>0.687999, "zidovudine"=>0.7…

We can see that the internal outcome estimation models have c-statistics of around 0.7, which suggests that the survival time estimates are of reasonable quality and a sound basis for training. The accuracy of the propensity model (about 27%) is close to that of random guessing among four treatments (25%), which is expected because the underlying data come from a randomized trial. Given this, we would not expect the doubly-robust method to perform very differently from the direct method, as there is likely little treatment assignment bias to correct for. Note also that some individual rewards are negative (for example, -2805.69 in row 5 above); individual doubly-robust rewards can fall outside the range of plausible outcomes and are intended to be used in aggregate, such as averaged over many patients.
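To see why the two methods can differ, and how negative rewards arise, here is a simplified sketch of the standard doubly-robust (augmented inverse propensity weighted) reward for one treatment t: Gamma_it = mu_t(x_i) + 1{T_i = t} * (y_i - mu_t(x_i)) / p_t(x_i), where mu_t is the outcome model, p_t the propensity model and y_i the observed outcome. This is an illustrative assumption, not IAI's implementation: it ignores censoring, which the survival estimator additionally corrects for using the censoring models scored above, and all numbers are made up.

```julia
# Simplified doubly-robust (AIPW) rewards for ONE treatment t, ignoring censoring.
# All numbers are hypothetical toy values, not ACTG 175 data.

mu = [1000.0, 950.0, 1100.0, 900.0]      # outcome model: predicted survival days under t
p = fill(0.25, 4)                        # propensity of receiving t (about 1/4 with 4 arms)
received_t = [true, false, true, false]  # did patient i actually receive t?
y = [1200.0, 800.0, 400.0, 1000.0]       # observed outcome under the treatment received

# Direct method: use the outcome model's prediction as the reward
direct = copy(mu)

# Doubly robust: add an inverse-propensity-weighted residual for patients who received t
dr = [mu[i] + (received_t[i] ? (y[i] - mu[i]) / p[i] : 0.0) for i in eachindex(mu)]

println("direct: ", direct)
println("doubly robust: ", dr)   # [1800.0, 950.0, -1700.0, 900.0]
```

Patient 3 received t but did far worse than predicted, so the weighted correction (-2800 days) outweighs the model's prediction and the reward becomes negative. These corrections are meant to cancel out errors in the outcome model on average, not to be read individually.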

Optimal Policy Trees

Now that we have a complete rewards matrix, we can train a tree to learn an optimal prescription policy that maximizes survival time. Note that we exclude two features from our prescription policy (:cd420 and :cd820) as these are observed after the treatment assignment is decided. We will use a GridSearch to fit an OptimalTreePolicyMaximizer:

grid = IAI.GridSearch(
    IAI.OptimalTreePolicyMaximizer(
        random_seed=1,
        minbucket=10,
    ),
    max_depth=1:5,
)
IAI.fit!(grid, select(train_X, Not([:cd420, :cd820])), train_rewards)
Fitted OptimalTreePolicyMaximizer:
  1) Split: age < 35.5
    2) Prescribe: zidovudine and zalcitabine, 595 points, error 4377.8
    3) Prescribe: zidovudine and didanosine, 474 points, error 4396.1

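The resulting tree prescribes treatments based on age alone: patients younger than 35.5 are prescribed zidovudine and zalcitabine, while older patients are prescribed zidovudine and didanosine.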
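The printed tree is small enough to transcribe by hand, which can help when sanity-checking prescriptions. The helper below is hypothetical and not part of the IAI API; it assumes the usual convention that the first child listed (node 2) is the branch where the split condition holds. In practice, prescriptions should come from IAI.predict.

```julia
# Hypothetical helper transcribing the fitted depth-1 tree printed above
# (not an IAI function):
#   1) Split: age < 35.5
#     2) Prescribe: zidovudine and zalcitabine
#     3) Prescribe: zidovudine and didanosine
prescribe(age::Real) = age < 35.5 ? "zidovudine and zalcitabine" : "zidovudine and didanosine"

# Example ages (hypothetical patients)
ages = [48, 31, 17]
println(prescribe.(ages))
# ["zidovudine and didanosine", "zidovudine and zalcitabine", "zidovudine and zalcitabine"]
```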

We can make treatment prescriptions using predict:

IAI.predict(grid, train_X)
1069-element Vector{String}:
 "zidovudine and didanosine"
 "zidovudine and didanosine"
 "zidovudine and didanosine"
 "zidovudine and zalcitabine"
 "zidovudine and didanosine"
 "zidovudine and zalcitabine"
 "zidovudine and zalcitabine"
 "zidovudine and didanosine"
 "zidovudine and didanosine"
 "zidovudine and zalcitabine"
 ⋮
 "zidovudine and zalcitabine"
 "zidovudine and zalcitabine"
 "zidovudine and zalcitabine"
 "zidovudine and zalcitabine"
 "zidovudine and zalcitabine"
 "zidovudine and didanosine"
 "zidovudine and zalcitabine"
 "zidovudine and zalcitabine"
 "zidovudine and didanosine"

If we want more information about the relative performance of treatments for these points, we can predict the full treatment ranking with predict_treatment_rank:

IAI.predict_treatment_rank(grid, train_X)
1069×4 Matrix{String}:
 "zidovudine and didanosine"   "didanosine"  …  "zidovudine"
 "zidovudine and didanosine"   "didanosine"     "zidovudine"
 "zidovudine and didanosine"   "didanosine"     "zidovudine"
 "zidovudine and zalcitabine"  "didanosine"     "zidovudine"
 "zidovudine and didanosine"   "didanosine"     "zidovudine"
 "zidovudine and zalcitabine"  "didanosine"  …  "zidovudine"
 "zidovudine and zalcitabine"  "didanosine"     "zidovudine"
 "zidovudine and didanosine"   "didanosine"     "zidovudine"
 "zidovudine and didanosine"   "didanosine"     "zidovudine"
 "zidovudine and zalcitabine"  "didanosine"     "zidovudine"
 ⋮                                           ⋱
 "zidovudine and zalcitabine"  "didanosine"  …  "zidovudine"
 "zidovudine and zalcitabine"  "didanosine"     "zidovudine"
 "zidovudine and zalcitabine"  "didanosine"     "zidovudine"
 "zidovudine and zalcitabine"  "didanosine"     "zidovudine"
 "zidovudine and zalcitabine"  "didanosine"     "zidovudine"
 "zidovudine and didanosine"   "didanosine"  …  "zidovudine"
 "zidovudine and zalcitabine"  "didanosine"     "zidovudine"
 "zidovudine and zalcitabine"  "didanosine"     "zidovudine"
 "zidovudine and didanosine"   "didanosine"     "zidovudine"

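The ranking orders treatments from highest to lowest estimated outcome. As a sketch of that relationship (our reading of the output, not IAI's implementation), sorting the first point's estimates as reported by predict_treatment_outcome below, for the three treatments whose columns are fully visible there, reproduces the visible part of the first row of the ranking:

```julia
# Estimated outcomes (days) for the first training point, for the three treatments
# whose columns are fully visible in the predict_treatment_outcome output;
# the "zidovudine and zalcitabine" column is truncated there, so it is left out.
treatment_names = ["didanosine", "zidovudine", "zidovudine and didanosine"]
estimates = [1041.33, 932.021, 1087.37]

# Rank treatments from best (largest estimated survival) to worst
ranking = treatment_names[sortperm(estimates, rev=true)]
println(ranking)
# ["zidovudine and didanosine", "didanosine", "zidovudine"]
```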
To quantify the difference in performance behind the treatment rankings, we can use predict_treatment_outcome to extract the estimated quality of each treatment for each point:

IAI.predict_treatment_outcome(grid, train_X)
1069×4 DataFrame
  Row │ didanosine  zidovudine  zidovudine and didanosine  zidovudine and zalc ⋯
      │ Float64     Float64     Float64                    Float64             ⋯
──────┼─────────────────────────────────────────────────────────────────────────
    1 │    1041.33     932.021                    1087.37                      ⋯
    2 │    1041.33     932.021                    1087.37
    3 │    1041.33     932.021                    1087.37
    4 │    1071.78     996.838                    1052.42
    5 │    1041.33     932.021                    1087.37                      ⋯
    6 │    1071.78     996.838                    1052.42
    7 │    1071.78     996.838                    1052.42
    8 │    1041.33     932.021                    1087.37
  ⋮   │     ⋮           ⋮                   ⋮                          ⋮       ⋱
 1063 │    1071.78     996.838                    1052.42                      ⋯
 1064 │    1071.78     996.838                    1052.42
 1065 │    1071.78     996.838                    1052.42
 1066 │    1041.33     932.021                    1087.37
 1067 │    1071.78     996.838                    1052.42                      ⋯
 1068 │    1071.78     996.838                    1052.42
 1069 │    1041.33     932.021                    1087.37
                                                  1 column and 1054 rows omitted
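Each column takes only two distinct values, one per leaf of the depth-1 tree. This is consistent with each leaf's estimate being the average of the training rewards of the points in that leaf, and with each leaf prescribing the treatment with the highest average. The sketch below illustrates that reading with hypothetical toy data; it is our assumption about the computation, not IAI's internal code.

```julia
using Statistics

# Hypothetical toy rewards: 5 points (rows) x 2 treatments (columns)
treatment_names = ["A", "B"]
rewards = [
     900.0 1100.0
    1000.0 1200.0
    1300.0  800.0
    1100.0  700.0
    1200.0  900.0
]
# Hypothetical leaf each point falls into (e.g. nodes 2 and 3 of a depth-1 tree)
leaf = [2, 2, 3, 3, 3]

# Average reward of each treatment over the points in each leaf
leaf_means = Dict(l => vec(mean(rewards[leaf .== l, :], dims=1)) for l in unique(leaf))

# Each leaf prescribes the treatment with the highest average reward
leaf_prescription = Dict(l => treatment_names[argmax(m)] for (l, m) in leaf_means)

# Every point inherits its leaf's estimates, so rows are identical within a leaf
point_estimates = [leaf_means[l] for l in leaf]
println(leaf_means)
println(leaf_prescription)
```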

We can also extract the standard errors of the outcome estimates with predict_treatment_outcome_standard_error:

IAI.predict_treatment_outcome_standard_error(grid, train_X)
1069×4 DataFrame
  Row │ didanosine  zidovudine  zidovudine and didanosine  zidovudine and zalc ⋯
      │ Float64     Float64     Float64                    Float64             ⋯
──────┼─────────────────────────────────────────────────────────────────────────
    1 │    30.5984     34.6914                    30.6807                      ⋯
    2 │    30.5984     34.6914                    30.6807
    3 │    30.5984     34.6914                    30.6807
    4 │    22.8591     27.0675                    27.801
    5 │    30.5984     34.6914                    30.6807                      ⋯
    6 │    22.8591     27.0675                    27.801
    7 │    22.8591     27.0675                    27.801
    8 │    30.5984     34.6914                    30.6807
  ⋮   │     ⋮           ⋮                   ⋮                          ⋮       ⋱
 1063 │    22.8591     27.0675                    27.801                       ⋯
 1064 │    22.8591     27.0675                    27.801
 1065 │    22.8591     27.0675                    27.801
 1066 │    30.5984     34.6914                    30.6807
 1067 │    22.8591     27.0675                    27.801                       ⋯
 1068 │    22.8591     27.0675                    27.801
 1069 │    30.5984     34.6914                    30.6807
                                                  1 column and 1054 rows omitted

These standard errors can be combined with the outcome estimates to construct confidence intervals in the usual way.
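For example, assuming the estimates are approximately normally distributed, an approximate 95% interval is the estimate plus or minus 1.96 standard errors. The sketch below applies this to the first-row values printed above for the two treatments whose columns are fully visible:

```julia
# Approximate 95% confidence intervals: estimate ± z * standard error.
# Values are the first-row estimates and standard errors printed above.
z = 1.96  # approximate 97.5% quantile of the standard normal distribution

names_shown = ["didanosine", "zidovudine and didanosine"]
estimate = [1041.33, 1087.37]
std_error = [30.5984, 30.6807]

lower = estimate .- z .* std_error
upper = estimate .+ z .* std_error

for (name, lo, hi) in zip(names_shown, lower, upper)
    println(name, ": ", round(lo, digits=1), " to ", round(hi, digits=1))
end
# didanosine: 981.4 to 1101.3
# zidovudine and didanosine: 1027.2 to 1147.5
```

The two intervals overlap, a reminder that the preferred treatment within a leaf can carry considerable uncertainty.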

Evaluating Optimal Policy Trees

It is critical for a fair evaluation that we do not assess the policy using rewards from the reward estimator that was trained on the training set, as this would let information from the training set leak into the out-of-sample evaluation.

Instead, what we need to do is to estimate a new set of rewards using only the test set, and evaluate the policy against these rewards:

test_predictions, test_reward_score = IAI.fit_predict!(
    reward_lnr, test_X, test_treatments, test_died, test_times,
    propensity_score_criterion=:misclassification,
    outcome_score_criterion=:harrell_c_statistic)
test_rewards = test_predictions[:reward]
1070×4 DataFrame
  Row │ didanosine  zidovudine  zidovudine and didanosine  zidovudine and zalc ⋯
      │ Float64     Float64     Float64                    Float64             ⋯
──────┼─────────────────────────────────────────────────────────────────────────
    1 │ 1194.02       1135.68                    1188.11                     1 ⋯
    2 │ 1334.04       1107.71                    1094.2                      1
    3 │ 1113.53       1565.15                    1177.07                     1
    4 │ 1045.9        1415.26                    1178.2                      1
    5 │ -678.639       951.897                    872.011                      ⋯
    6 │ 1178.75       1075.76                    1166.4
    7 │  771.323       625.256                   3263.33
    8 │  750.545       162.174                    750.057
  ⋮   │     ⋮           ⋮                   ⋮                          ⋮       ⋱
 1064 │ 1158.64       1094.91                    1277.68                     1 ⋯
 1065 │ 1186.69       1149.25                   -2412.4                      1
 1066 │  860.972       918.337                    987.892                    -
 1067 │  791.351       797.424                    763.251                    1
 1068 │ 1727.16       1003.28                    1019.87                       ⋯
 1069 │ 1708.73        422.735                    846.498
 1070 │ 1049.67       1168.49                    1116.67                     1
                                                  1 column and 1055 rows omitted
test_reward_score
Dict{Symbol, Any} with 3 entries:
  :propensity => 0.262685
  :censoring  => Dict("zidovudine and zalcitabine"=>0.60256, "zidovudine"=>0.56…
  :outcome    => Dict("zidovudine and zalcitabine"=>0.722348, "zidovudine"=>0.7…

The scores are similar to those on the training set, which gives us some confidence that the estimated rewards are reasonable and can serve as a basis for evaluation.

We can now evaluate the policy using these new estimated rewards. First, we calculate the average survival time under the treatments prescribed by the tree for the test set. To do this, we use predict_outcomes, which uses the model to make prescriptions and then looks up the estimated outcome under each prescription in the supplied rewards:

policy_outcomes = IAI.predict_outcomes(grid, test_X, test_rewards)
1070-element Vector{Float64}:
  1188.1091666380607
  1094.2021180331067
  1177.0665783819982
  1178.1970183357666
   872.0108062376389
   940.8756504538612
   751.2474377711033
   750.0566285121313
  1206.271608150635
  1182.4903899207156
     ⋮
   939.1709118743534
  1088.9109132175718
  1277.6827120478747
 -2412.3977857700565
  -383.83853694591016
  1089.782113429851
   898.050691670649
   935.9659869903758
  1116.6731914729753
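Conceptually, this lookup amounts to indexing the rewards matrix by each point's prescribed treatment. The sketch below reproduces that idea on a toy matrix with hypothetical treatments "A", "B" and "C"; it is an illustration, not IAI's implementation of predict_outcomes.

```julia
# Toy rewards: 2 points x 3 hypothetical treatments
treatment_names = ["A", "B", "C"]
toy_rewards = [
    10.0 20.0 30.0
    40.0 50.0 60.0
]
# Hypothetical prescriptions made by some policy for the two points
toy_prescriptions = ["C", "A"]

# For each point, pick the reward in the column of its prescribed treatment
toy_policy_outcomes = [toy_rewards[i, findfirst(==(toy_prescriptions[i]), treatment_names)]
                       for i in eachindex(toy_prescriptions)]
println(toy_policy_outcomes)  # [30.0, 40.0]
```

In the output above, for example, the first value (1188.109) matches the first test point's "zidovudine and didanosine" reward of 1188.11, so the tree prescribed zidovudine and didanosine for that patient.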

We can then get the average estimated survival time under our treatment policy:

using Statistics
mean(policy_outcomes)
1060.4842813

We can compare this number to the average estimated survival time under the treatment assignments that were actually observed:

mean([test_rewards[i, test_treatments[i]] for i in 1:length(test_treatments)])
931.92805498

We see that our policy leads to a sizeable improvement in estimated survival time compared to the randomized assignments: roughly 1060.5 days versus 931.9 days, a difference of about 129 days (around 14%).
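The comparison above looks only at the two averages. Because both averages are computed on the same test patients, one rough way to gauge the uncertainty, sketched here as an assumption rather than part of the original workflow, is to take paired per-patient differences and their standard error. The numbers below are hypothetical; in this guide, policy would correspond to policy_outcomes and observed to the test rewards of the observed treatments. This treats the estimated rewards as fixed and ignores the uncertainty from fitting the reward models, so the interval is only indicative.

```julia
using Statistics

# Hypothetical toy values for five test patients (not the guide's data):
# estimated outcome under the policy's prescription, and under the treatment
# the patient actually received, both read from the same test rewards matrix
policy = [1200.0, 950.0, 1100.0, 800.0, 1000.0]
observed = [1000.0, 900.0, 1150.0, 700.0, 950.0]

# Both numbers refer to the same patient, so we work with paired differences
d = policy .- observed
improvement = mean(d)
se = std(d) / sqrt(length(d))   # rough standard error of the mean difference
ci = (improvement - 1.96 * se, improvement + 1.96 * se)
println("mean improvement: ", improvement, ", approx. 95% CI: ", ci)
```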

Survival Probability as Outcome

In the example above, we addressed the problem with the goal of assigning treatments to maximize the expected survival time of each patient. As an alternative way of looking at the problem, we can also try to assign treatments that maximize each patient's probability of survival at a given point in time.

In this case, we will try to assign treatments to maximize the probability of surviving at least 2.5 years. The only change to the previous workflow is that we need to use the evaluation_time parameter on the reward estimator to specify the time of interest:

reward_lnr = IAI.CategoricalSurvivalRewardEstimator(
    propensity_estimator=IAI.RandomForestClassifier(),
    outcome_estimator=IAI.RandomForestSurvivalLearner(),
    reward_estimator=:doubly_robust,
    random_seed=1,
    evaluation_time=365 * 2.5,  # 2.5 years, as times are recorded in days
)
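For intuition about the target quantity, the probability of surviving past evaluation_time can be estimated for a group of patients with the Kaplan-Meier estimator, which accounts for right-censoring by only counting patients still at risk at each death time. The sketch below uses toy data in the same died/times format as this guide; it is a swapped-in illustration and not how CategoricalSurvivalRewardEstimator estimates rewards, which are patient- and treatment-specific. Since times are recorded in days, 2.5 years corresponds to 365 * 2.5 = 912.5 days.

```julia
# Kaplan-Meier estimate of S(t) = P(survival time > t) from right-censored data.
function kaplan_meier(died::AbstractVector{Bool}, times::AbstractVector{<:Real}, t::Real)
    s = 1.0
    for u in sort(unique(times[died]))   # distinct times at which deaths occurred
        u > t && break                   # only death times up to t matter
        at_risk = count(>=(u), times)    # patients still at risk just before u
        deaths = count(i -> died[i] && times[i] == u, eachindex(times))
        s *= 1 - deaths / at_risk
    end
    return s
end

# Toy data in the same format as the guide (hypothetical, not ACTG 175)
toy_died = [true, false, true, false, true, false]
toy_times = [200.0, 400.0, 700.0, 1000.0, 1200.0, 1500.0]

evaluation_time = 365 * 2.5   # 912.5 days
println(kaplan_meier(toy_died, toy_times, evaluation_time))
```

Here deaths at days 200 and 700 fall before the evaluation time, while the patient censored at day 400 drops out of the risk set, giving (5/6) * (3/4) = 0.625.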

We then proceed to estimate rewards on the training set and train an Optimal Policy Tree as before:

train_predictions, train_reward_score = IAI.fit_predict!(
    reward_lnr, train_X, train_treatments, train_died, train_times,
    propensity_score_criterion=:misclassification,
    outcome_score_criterion=:harrell_c_statistic)
train_rewards = train_predictions[:reward]

grid = IAI.GridSearch(
    IAI.OptimalTreePolicyMaximizer(
        random_seed=1,
        minbucket=10,
    ),
    max_depth=1:5,
)
IAI.fit!(grid, select(train_X, Not([:cd420, :cd820])), train_rewards)
Optimal Trees Visualization

We see that the resulting tree finds leaves that have very different responses to the treatments. In particular, Node 7 prescribes zidovudine and didanosine with an estimated 97% 2.5-year survival rate, whereas in the adjacent Node 6, the estimated survival rate under this treatment is only 51%.

We can also proceed to estimate the rewards on the test set in the same way as before:

test_predictions, test_reward_score = IAI.fit_predict!(
    reward_lnr, test_X, test_treatments, test_died, test_times,
    propensity_score_criterion=:misclassification,
    outcome_score_criterion=:harrell_c_statistic)
test_rewards = test_predictions[:reward]

We can then get the average estimated 2.5-year survival probability under our treatment policy:

policy_outcomes = IAI.predict_outcomes(grid, test_X, test_rewards)
mean(policy_outcomes)
0.7423757

We can compare this number to the average estimated 2.5-year survival probability under the treatment assignments that were actually observed:

mean([test_rewards[i, test_treatments[i]] for i in 1:length(test_treatments)])
0.56421586

We see that our policy leads to a sizeable improvement in estimated 2.5-year survival probability compared to the randomized assignments, from about 56% to about 74%, an increase of roughly 18 percentage points.