Quick Start Guide: Optimal Survival Trees

In this example we will use Optimal Survival Trees (OST) on the Monoclonal Gammopathy dataset to predict patient survival time. Refer to the data preparation guide to learn more about the data format for survival problems. First we load the data and split it into training and test datasets:

using CSV, DataFrames
df = CSV.read("mgus2.csv", DataFrame, missingstring="NA", pool=true)
1384×11 DataFrame
  Row │ Column1  id     age    sex     hgb       creat     mspike    ptime  ps ⋯
      │ Int64    Int64  Int64  String  Float64?  Float64?  Float64?  Int64  In ⋯
──────┼─────────────────────────────────────────────────────────────────────────
    1 │       1      1     88  F           13.1       1.3       0.5     30  0  ⋯
    2 │       2      2     78  F           11.5       1.2       2.0     25  0
    3 │       3      3     94  M           10.5       1.5       2.6     46  0
    4 │       4      4     68  M           15.2       1.2       1.2     92  0
    5 │       5      5     90  F           10.7       0.8       1.0      8  0  ⋯
    6 │       6      6     90  M           12.9       1.0       0.5      4  0
    7 │       7      7     89  F           10.5       0.9       1.3    151  0
    8 │       8      8     87  F           12.3       1.2       1.6      2  0
  ⋮   │    ⋮       ⋮      ⋮      ⋮        ⋮         ⋮         ⋮        ⋮       ⋱
 1378 │    1378   1378     56  M           16.1       0.8       0.5     59  0  ⋯
 1379 │    1379   1379     73  M           15.6       1.1       0.5     48  0
 1380 │    1380   1380     69  M           15.0       0.8       0.0     22  0
 1381 │    1381   1381     78  M           14.1       1.3       1.9     35  0
 1382 │    1382   1382     66  M           12.1       2.0       0.0     31  0  ⋯
 1383 │    1383   1383     82  F           11.5       1.1       2.3     38  1
 1384 │    1384   1384     79  M            9.6       1.1       1.7      6  0
                                                 3 columns and 1369 rows omitted
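# Features: columns 3 through 7 (age, sex, hgb, creat, mspike)
X = df[:, 3:(end - 4)]
# Event indicator: true if the patient died, false if censored
died = df.death .== 1
# Time until death or last contact
times = df.futime
(train_X, train_died, train_times), (test_X, test_died, test_times) =
    IAI.split_data(:survival, X, died, times, seed=1)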
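
To illustrate what a seeded train/test split does conceptually, here is a minimal sketch in plain Julia. It is not how IAI.split_data is implemented, and IAI.split_data may behave differently (for example in how it balances the two sets). The helper name simple_split and the 0.7 training proportion are assumptions made for this illustration only.

```julia
using Random

# Hypothetical helper (not part of IAI): shuffle the row indices with a
# fixed seed, then cut the shuffled list at `train_proportion`.
function simple_split(n; train_proportion=0.7, seed=1)
    idx = shuffle(MersenneTwister(seed), 1:n)
    n_train = round(Int, train_proportion * n)
    return idx[1:n_train], idx[(n_train + 1):end]
end

# 1384 rows, as in the DataFrame shown above
train_idx, test_idx = simple_split(1384)
```

Because the seed is fixed, calling simple_split again with the same arguments returns the same partition, which is the role the seed argument plays in a reproducible split.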

Optimal Survival Trees

We will use a GridSearch to fit an OptimalTreeSurvivalLearner, searching over max_depth values of 1 and 2 and validating with Harrell's c-statistic:

grid = IAI.GridSearch(
    IAI.OptimalTreeSurvivalLearner(
        random_seed=1,
        missingdatamode=:separate_class,
    ),
    max_depth=1:2,
)
IAI.fit!(grid, train_X, train_died, train_times,
         validation_criterion=:harrell_c_statistic)
IAI.get_learner(grid)
Optimal Trees Visualization
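
The grid search above validates with Harrell's c-statistic. As a rough illustration of what this criterion measures, the following plain-Julia sketch computes it from its standard textbook definition: among all pairs of samples where one is observed to die before the other's recorded time, it is the fraction in which the earlier death was assigned the higher risk, with ties in risk counted as half. This is an assumption-based sketch, not IAI's implementation. The function name harrell_c and the risk scores are hypothetical, and the toy data is made up for the example.

```julia
# Harrell's c-statistic from the standard definition (illustrative only).
# `risk` is a hypothetical per-sample score: higher means earlier expected death.
function harrell_c(times, died, risk)
    concordant = 0.0
    comparable = 0
    n = length(times)
    for i in 1:n, j in 1:n
        # A pair is comparable only if sample i is observed to die
        # strictly before the recorded time of sample j.
        if died[i] && times[i] < times[j]
            comparable += 1
            if risk[i] > risk[j]
                concordant += 1.0   # ranked in the right order
            elseif risk[i] == risk[j]
                concordant += 0.5   # tie in risk counts as half
            end
        end
    end
    return comparable == 0 ? NaN : concordant / comparable
end

# Toy data, invented for this example
toy_times = [5, 10, 15, 20]
toy_died = [true, false, true, true]
toy_risk = [0.9, 0.4, 0.1, 0.6]
c = harrell_c(toy_times, toy_died, toy_risk)
```

In this toy data there are four comparable pairs and three are ranked correctly, so c is 0.75. A value of 0.5 corresponds to random ranking and 1.0 to perfect ranking.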