Optimal Survival Trees

Quick Start Guide: Optimal Survival Trees

In this example we will use Optimal Survival Trees (OST) on the Monoclonal Gammopathy dataset (`mgus2` from the R `survival` package). First we load the data and split it into training and test sets:

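using CSV
# read the data, treating "NA" entries as missing values
# (this call uses an older CSV.jl API; recent versions also require a sink, e.g. CSV.read(path, DataFrame; ...))
df = CSV.read("mgus2.csv", missingstring="NA", categorical=true, copycols=true)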
1384×11 DataFrames.DataFrame. Omitted printing of 4 columns
│ Row  │       │ id    │ age   │ sex          │ hgb      │ creat    │ mspike   │
│      │ Int64 │ Int64 │ Int64 │ Categorical… │ Float64⍰ │ Float64⍰ │ Float64⍰ │
├──────┼───────┼───────┼───────┼──────────────┼──────────┼──────────┼──────────┤
│ 1    │ 1     │ 1     │ 88    │ F            │ 13.1     │ 1.3      │ 0.5      │
│ 2    │ 2     │ 2     │ 78    │ F            │ 11.5     │ 1.2      │ 2.0      │
│ 3    │ 3     │ 3     │ 94    │ M            │ 10.5     │ 1.5      │ 2.6      │
│ 4    │ 4     │ 4     │ 68    │ M            │ 15.2     │ 1.2      │ 1.2      │
│ 5    │ 5     │ 5     │ 90    │ F            │ 10.7     │ 0.8      │ 1.0      │
│ 6    │ 6     │ 6     │ 90    │ M            │ 12.9     │ 1.0      │ 0.5      │
│ 7    │ 7     │ 7     │ 89    │ F            │ 10.5     │ 0.9      │ 1.3      │
⋮
│ 1377 │ 1377  │ 1377  │ 81    │ F            │ 11.3     │ 2.9      │ 1.9      │
│ 1378 │ 1378  │ 1378  │ 56    │ M            │ 16.1     │ 0.8      │ 0.5      │
│ 1379 │ 1379  │ 1379  │ 73    │ M            │ 15.6     │ 1.1      │ 0.5      │
│ 1380 │ 1380  │ 1380  │ 69    │ M            │ 15.0     │ 0.8      │ 0.0      │
│ 1381 │ 1381  │ 1381  │ 78    │ M            │ 14.1     │ 1.3      │ 1.9      │
│ 1382 │ 1382  │ 1382  │ 66    │ M            │ 12.1     │ 2.0      │ 0.0      │
│ 1383 │ 1383  │ 1383  │ 82    │ F            │ 11.5     │ 1.1      │ 2.3      │
│ 1384 │ 1384  │ 1384  │ 79    │ M            │ 9.6      │ 1.1      │ 1.7      │
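
# features: columns 3 through end-4 (age, sex, hgb, creat, mspike); the last four columns hold follow-up and outcome data
X = df[:, 3:(end - 4)]
# true if the death was observed, false if the subject was censored
died = df.death .== 1
# follow-up time until death or censoring
times = df.futime
(train_X, train_died, train_times), (test_X, test_died, test_times) =
    IAI.split_data(:survival, X, died, times, seed=1)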
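Each observation's outcome is the pair (`died`, `times`): either a death observed at that time, or censoring (the subject was last seen alive) at that time. As a rough illustration of how such right-censored data is summarized, here is a minimal Kaplan-Meier estimator in plain Julia. `kaplan_meier` is a hypothetical helper written for this sketch, not part of the IAI API, and it uses the common convention that subjects censored at an event time still count as at risk.

```julia
# Hypothetical helper (not part of IAI): Kaplan-Meier estimate of S(t) from
# right-censored data. died[i] is true if subject i's death was observed at
# times[i], false if subject i was censored at times[i].
function kaplan_meier(died::AbstractVector{Bool}, times::AbstractVector)
    event_times = sort(unique(times[died]))  # distinct times with at least one death
    surv = Float64[]
    s = 1.0
    for t in event_times
        at_risk = count(>=(t), times)  # subjects still under observation at t
        deaths = count(i -> died[i] && times[i] == t, eachindex(times))
        s *= 1 - deaths / at_risk      # conditional probability of surviving past t
        push!(surv, s)
    end
    return event_times, surv
end

# toy example: five subjects, two of them censored
ts, s = kaplan_meier([true, false, true, true, false], [2, 3, 3, 5, 7])
```

Here the estimated survival drops to roughly 0.8, 0.6 and 0.3 at times 2, 3 and 5. Applied to `train_died` and `train_times`, the same helper would give a population-level survival curve for the training set, assuming `train_times` contains no missing entries.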

Optimal Survival Trees

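We will use a GridSearch to fit an OptimalTreeSurvivor, searching over tree depths from 1 to 5 and choosing the depth with the best integrated Brier score: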

grid = IAI.GridSearch(
    IAI.OptimalTreeSurvivor(
        random_seed=1,
        missingdatamode=:separate_class,  # treat missing values as their own class
        criterion=:localfulllikelihood,
    ),
    max_depth=1:5,  # candidate tree depths to search over
)
IAI.fit!(grid, train_X, train_died, train_times,
         validation_criterion=:integratedbrier)
IAI.get_learner(grid)  # the learner fitted with the best max_depth
[Interactive visualization of the fitted Optimal Survival Tree]
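The grid search selects `max_depth` using `validation_criterion=:integratedbrier`. To make that criterion concrete, here is a minimal sketch of the integrated Brier score with inverse-probability-of-censoring weights (IPCW), assuming the model's predicted survival probabilities are available as an n×m matrix evaluated on a sorted time grid. The functions `km_at`, `brier_score` and `integrated_brier_score` are hypothetical names for this illustration, and IAI may compute the criterion differently (for example, with its own time grid or normalization). Lower scores are better.

```julia
# Hypothetical helper: Kaplan-Meier survival estimate for `event`, evaluated at t.
# With left=true it returns the left limit (only events strictly before t count).
function km_at(event::AbstractVector{Bool}, times::AbstractVector, t::Real; left::Bool=false)
    s = 1.0
    for u in sort(unique(times[event]))
        (left ? u >= t : u > t) && break
        at_risk = count(>=(u), times)
        d = count(i -> event[i] && times[i] == u, eachindex(times))
        s *= 1 - d / at_risk
    end
    return s
end

# IPCW Brier score at time t. surv_at_t[i] is the predicted probability that
# subject i survives past t; G is the Kaplan-Meier estimate of the censoring distribution.
function brier_score(died::AbstractVector{Bool}, times::AbstractVector,
                     surv_at_t::AbstractVector, t::Real)
    censored = .!died  # the "events" of the censoring distribution
    total = 0.0
    for i in eachindex(times)
        if died[i] && times[i] <= t
            # death observed by t: ideal prediction is 0, weighted by 1/G(T_i-)
            total += surv_at_t[i]^2 / km_at(censored, times, times[i]; left=true)
        elseif times[i] > t
            # still under observation at t: ideal prediction is 1, weighted by 1/G(t)
            total += (1 - surv_at_t[i])^2 / km_at(censored, times, t)
        end
        # subjects censored at or before t contribute only through the weights
    end
    return total / length(times)
end

# Integrated Brier score: trapezoidal average of the Brier score over a sorted grid
# (at least two points). surv_pred[i, j] is subject i's predicted survival at grid[j].
function integrated_brier_score(died, times, surv_pred::AbstractMatrix, grid::AbstractVector)
    bs = [brier_score(died, times, view(surv_pred, :, j), grid[j]) for j in eachindex(grid)]
    area = sum((grid[j + 1] - grid[j]) * (bs[j] + bs[j + 1]) / 2 for j in 1:length(grid) - 1)
    return area / (grid[end] - grid[1])
end
```

As a sanity check, a model that predicts a survival probability of 0.5 for everyone scores exactly 0.25 when no one is censored, so a useful model should score below that.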