Quick Start Guide: Optimal Survival Trees

In this example we will use Optimal Survival Trees (OST) on the Monoclonal Gammopathy dataset to predict patient survival time (refer to the data preparation guide to learn more about the data format for survival problems). First, we load the data and split it into training and test datasets:

using CSV, DataFrames
df = CSV.read("mgus2.csv", DataFrame, missingstring="NA", pool=true)
1384×11 DataFrame
  Row │ Column1  id     age    sex     hgb       creat     mspike    ptime  ps ⋯
      │ Int64    Int64  Int64  String  Float64?  Float64?  Float64?  Int64  In ⋯
──────┼─────────────────────────────────────────────────────────────────────────
    1 │       1      1     88  F           13.1       1.3       0.5     30     ⋯
    2 │       2      2     78  F           11.5       1.2       2.0     25
    3 │       3      3     94  M           10.5       1.5       2.6     46
    4 │       4      4     68  M           15.2       1.2       1.2     92
    5 │       5      5     90  F           10.7       0.8       1.0      8     ⋯
    6 │       6      6     90  M           12.9       1.0       0.5      4
    7 │       7      7     89  F           10.5       0.9       1.3    151
    8 │       8      8     87  F           12.3       1.2       1.6      2
  ⋮   │    ⋮       ⋮      ⋮      ⋮        ⋮         ⋮         ⋮        ⋮       ⋱
 1378 │    1378   1378     56  M           16.1       0.8       0.5     59     ⋯
 1379 │    1379   1379     73  M           15.6       1.1       0.5     48
 1380 │    1380   1380     69  M           15.0       0.8       0.0     22
 1381 │    1381   1381     78  M           14.1       1.3       1.9     35
 1382 │    1382   1382     66  M           12.1       2.0       0.0     31     ⋯
 1383 │    1383   1383     82  F           11.5       1.1       2.3     38
 1384 │    1384   1384     79  M            9.6       1.1       1.7      6
                                                 3 columns and 1369 rows omitted
X = df[:, 3:(end - 4)]
died = df.death .== 1
times = df.futime
(train_X, train_died, train_times), (test_X, test_died, test_times) =
    IAI.split_data(:survival, X, died, times, seed=12345)
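Here, died is a Bool vector of event indicators (true when the death was observed rather than censored) and times holds each patient's observed follow-up time; IAI.split_data keeps both aligned with the features. For intuition only, a random index split could be sketched as follows (a non-stratified sketch; the 75/25 fraction is an assumption for illustration, not necessarily the IAI.split_data default):

```julia
using Random

# Illustrative random train/test split over row indices.
# Assumptions: a 75/25 fraction and no stratification by event status
function simple_split(n; train_frac=0.75, seed=12345)
    idx = shuffle(MersenneTwister(seed), collect(1:n))
    ntrain = round(Int, train_frac * n)
    idx[1:ntrain], idx[(ntrain + 1):end]
end

train_idx, test_idx = simple_split(size(df, 1))
```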

Optimal Survival Trees

We will use a GridSearch to fit an OptimalTreeSurvivalLearner:

grid = IAI.GridSearch(
    IAI.OptimalTreeSurvivalLearner(
        random_seed=1,
        missingdatamode=:separate_class,
        minbucket=15,
    ),
    max_depth=1:2,
)
IAI.fit!(grid, train_X, train_died, train_times,
         validation_criterion=:harrell_c_statistic)
IAI.get_learner(grid)
[Optimal Trees Visualization]

The survival tree shows the Kaplan-Meier survival curve for the population in each node as a solid red line.

In each split node:

  • the dotted green line shows the survival curve of the population in the lower/left child
  • the dotted blue line shows the survival curve of the population in the upper/right child

This means that the distance between the green and blue lines indicates how well the split separates the two groups.

In each leaf node, the dotted black line shows the survival curve of the entire population, which allows us to easily see how the survival outlook for this subpopulation differs from the entire population.

In this example, age and hgb partition the population into four subgroups with distinct survival patterns:

  • Node 3: when age < 67.5 and hgb < 12.25, survival is similar to the overall average
  • Node 4: when age < 67.5 and hgb > 12.25, survival is significantly better than average
  • Node 6: when 67.5 < age < 77.5, survival is similar to the overall average
  • Node 7: when age > 77.5, survival is significantly worse than average
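The solid red curves in each node are Kaplan-Meier estimates built from the (died, times) pairs reaching that node. As a rough sketch of how such a curve arises (not the internal IAI implementation): at each observed death time, the running survival probability is multiplied by one minus the fraction of at-risk patients who died at that time, and censored patients simply leave the risk set without contributing a death:

```julia
# Minimal Kaplan-Meier sketch: S(t) is the product over death times t_k <= t
# of (1 - d_k / n_k), with d_k deaths at t_k and n_k patients still at risk
function kaplan_meier(died, times)
    order = sortperm(times)
    d_sorted, t_sorted = died[order], times[order]
    n = length(t_sorted)
    event_times, survival = eltype(t_sorted)[], Float64[]
    s, i = 1.0, 1
    while i <= n
        t = t_sorted[i]
        at_risk = n - i + 1              # everyone with observed time >= t
        deaths = 0
        while i <= n && t_sorted[i] == t
            deaths += d_sorted[i]        # censored rows add nothing here
            i += 1
        end
        if deaths > 0
            s *= 1 - deaths / at_risk
            push!(event_times, t)
            push!(survival, s)
        end
    end
    event_times, survival
end
```

For example, kaplan_meier(train_died, train_times) returns the observed death times alongside the estimated probability of surviving past each one.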

We can make predictions on new data using predict. For survival trees, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

pred_curves = IAI.predict(grid, test_X)
t = 10
[c[t] for c in pred_curves]
415-element Vector{Float64}:
 0.16553847096299912
 0.16553847096299912
 0.09713375796178347
 0.16553847096299912
 0.16553847096299912
 0.09713375796178347
 0.24537037037037024
 0.16553847096299912
 0.09713375796178347
 0.09713375796178347
 ⋮
 0.03283166851646313
 0.16553847096299912
 0.09713375796178347
 0.24537037037037024
 0.03283166851646313
 0.16553847096299912
 0.16553847096299912
 0.16553847096299912
 0.24537037037037024

Alternatively, you can get this mortality probability for any given time t by providing t as a keyword argument directly:

IAI.predict(grid, test_X, t=10)
415-element Vector{Float64}:
 0.16553847096299912
 0.16553847096299912
 0.09713375796178347
 0.16553847096299912
 0.16553847096299912
 0.09713375796178347
 0.24537037037037024
 0.16553847096299912
 0.09713375796178347
 0.09713375796178347
 ⋮
 0.03283166851646313
 0.16553847096299912
 0.09713375796178347
 0.24537037037037024
 0.03283166851646313
 0.16553847096299912
 0.16553847096299912
 0.16553847096299912
 0.24537037037037024

We can evaluate the quality of the tree using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

IAI.score(grid, train_X, train_died, train_times,
          criterion=:harrell_c_statistic)
0.6608294740324181

Or on the test set:

IAI.score(grid, test_X, test_died, test_times, criterion=:harrell_c_statistic)
0.6668716957368598
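Harrell's c-statistic is the fraction of comparable patient pairs that the model ranks correctly: whenever one patient's observed death precedes another patient's time, the earlier death should be assigned the higher risk. A minimal O(n²) sketch of this idea (for intuition only; risk stands for any model-derived risk score, e.g. predicted mortality at a fixed time):

```julia
# Harrell's c-statistic sketch: a pair (i, j) is comparable when i's death
# is observed before j's time; tied risk scores count as half-concordant
function harrell_c(died, times, risk)
    concordant, comparable = 0.0, 0
    n = length(times)
    for i in 1:n, j in 1:n
        if died[i] && times[i] < times[j]
            comparable += 1
            if risk[i] > risk[j]
                concordant += 1
            elseif risk[i] == risk[j]
                concordant += 0.5
            end
        end
    end
    concordant / comparable
end
```

A value of 0.5 corresponds to random ranking and 1.0 to perfect concordance, so the values around 0.66 above indicate a moderately discriminative tree.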

We can also evaluate the performance of the tree at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

IAI.score(grid, test_X, test_died, test_times, criterion=:auc,
          evaluation_time=10)
0.6793093093093092
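Scoring with :auc at evaluation_time=10 treats the problem as binary classification: did each patient die within 10 months? A simplified sketch of the reduction (patients censored before t have unknown status and are simply dropped here, whereas the actual criterion may weight or adjust for such censoring):

```julia
# Binarize outcomes at horizon t; status is known if the patient was
# observed beyond t (label 0) or died at or before t (label 1)
function labels_at(died, times, t)
    known = (times .> t) .| (died .& (times .<= t))
    (died .& (times .<= t))[known], known
end

# Rank-based AUC sketch: fraction of (death, survivor) pairs where the
# death receives the higher predicted mortality; ties count as half
function simple_auc(labels, scores)
    pos, neg = scores[labels], scores[.!labels]
    hits = sum(p > q for p in pos, q in neg) + 0.5 * sum(p == q for p in pos, q in neg)
    hits / (length(pos) * length(neg))
end
```

With labels, known = labels_at(test_died, test_times, 10), the scores would be the predicted 10-month mortalities restricted to the known rows.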