Quick Start Guide: Optimal Survival Trees
In this example we will use Optimal Survival Trees (OST) on the Monoclonal Gammopathy dataset to predict patient survival time (refer to the data preparation guide to learn more about the data format for survival problems). First, we load the data and split it into training and test datasets:
using CSV, DataFrames
# Treat "NA" entries as missing; pool=true stores string columns (such as sex) as pooled arrays
df = CSV.read("mgus2.csv", DataFrame, missingstring="NA", pool=true)
1384×11 DataFrame
Row │ Column1 id age sex hgb creat mspike ptime ps ⋯
│ Int64 Int64 Int64 String Float64? Float64? Float64? Int64 In ⋯
──────┼─────────────────────────────────────────────────────────────────────────
1 │ 1 1 88 F 13.1 1.3 0.5 30 0 ⋯
2 │ 2 2 78 F 11.5 1.2 2.0 25 0
3 │ 3 3 94 M 10.5 1.5 2.6 46 0
4 │ 4 4 68 M 15.2 1.2 1.2 92 0
5 │ 5 5 90 F 10.7 0.8 1.0 8 0 ⋯
6 │ 6 6 90 M 12.9 1.0 0.5 4 0
7 │ 7 7 89 F 10.5 0.9 1.3 151 0
8 │ 8 8 87 F 12.3 1.2 1.6 2 0
⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱
1378 │ 1378 1378 56 M 16.1 0.8 0.5 59 0 ⋯
1379 │ 1379 1379 73 M 15.6 1.1 0.5 48 0
1380 │ 1380 1380 69 M 15.0 0.8 0.0 22 0
1381 │ 1381 1381 78 M 14.1 1.3 1.9 35 0
1382 │ 1382 1382 66 M 12.1 2.0 0.0 31 0 ⋯
1383 │ 1383 1383 82 F 11.5 1.1 2.3 38 1
1384 │ 1384 1384 79 M 9.6 1.1 1.7 6 0
3 columns and 1369 rows omitted
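# Features: columns 3 through end-4 (age, sex, hgb, creat, mspike)
X = df[:, 3:(end - 4)]
# Event indicator: true if the patient died, false if the observation is censored
died = df.death .== 1
# Follow-up time (months) until death or last contact
times = df.futime
(train_X, train_died, train_times), (test_X, test_died, test_times) =
    IAI.split_data(:survival, X, died, times, seed=1)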
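Together, `died` and `times` encode right-censored survival data: a `false` entry in `died` means the patient was still alive at their last contact, so the matching entry in `times` is only a lower bound on their survival time. To illustrate how such pairs are used, the sketch below computes a Kaplan-Meier survival estimate in base Julia. `kaplan_meier` is a hypothetical helper written for this guide, not an IAI function.

```julia
# Minimal Kaplan-Meier estimator for right-censored data (base Julia only).
# `died[i]` is true if subject i's death was observed, false if censored;
# `times[i]` is the follow-up time for subject i.
# Hypothetical helper for illustration; not part of the IAI API.
function kaplan_meier(died::AbstractVector{Bool}, times::AbstractVector{<:Real})
    @assert length(died) == length(times)
    # Distinct times at which at least one death was observed
    event_times = sort(unique(times[died]))
    surv = Float64[]
    s = 1.0
    for t in event_times
        # Subjects still under observation just before time t
        at_risk = count(>=(t), times)
        # Observed deaths exactly at time t
        deaths = count(i -> died[i] && times[i] == t, eachindex(times))
        s *= 1 - deaths / at_risk
        push!(surv, s)
    end
    return event_times, surv
end
```

Assuming `train_died` and `train_times` come back from `IAI.split_data` as plain vectors, `kaplan_meier(train_died, train_times)` would give the survival curve of the whole training set, which is a useful baseline before fitting a tree.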
Optimal Survival Trees
We will use a GridSearch to fit an OptimalTreeSurvivalLearner:
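grid = IAI.GridSearch(
    IAI.OptimalTreeSurvivalLearner(
        random_seed=1,
        # Treat missing values (present in hgb, creat and mspike) as a separate class
        missingdatamode=:separate_class,
    ),
    # Candidate tree depths to compare
    max_depth=1:2,
)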
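The grid search compares the candidate values of `max_depth` against a validation criterion and keeps the best one. The internals of `IAI.GridSearch` are not shown in this guide, so the following is only a generic sketch of holdout grid search in base Julia. `holdout_grid_search`, `fit_fn` and `score_fn` are hypothetical names, not part of the IAI API, and IAI's actual procedure may differ.

```julia
using Random

# Generic holdout grid search over one hyperparameter (illustrative sketch).
# `fit_fn(train_idx, value)` returns a model trained on the rows in train_idx;
# `score_fn(model, valid_idx)` scores it on held-out rows (higher is better).
# Both callbacks are hypothetical and supplied by the caller.
function holdout_grid_search(fit_fn, score_fn, n::Integer, values;
                             valid_frac=0.3, seed=1)
    rng = MersenneTwister(seed)
    perm = randperm(rng, n)
    n_valid = round(Int, valid_frac * n)
    @assert 0 < n_valid < n "need non-empty training and validation sets"
    valid_idx = perm[1:n_valid]
    train_idx = perm[(n_valid + 1):end]
    # Train on the training rows and score on the validation rows for each value
    scores = [score_fn(fit_fn(train_idx, v), valid_idx) for v in values]
    best = values[argmax(scores)]
    # Refit on all rows with the winning value
    return best, fit_fn(collect(1:n), best), scores
end
```

Refitting on all rows after selection is a common design choice so that the final model uses every available observation; whether and how IAI does this is configured by the library, not by this sketch.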
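# Choose max_depth using Harrell's c-statistic as the validation criterion
IAI.fit!(grid, train_X, train_died, train_times,
         validation_criterion=:harrell_c_statistic)
# Retrieve the learner fitted with the best parameters
IAI.get_learner(grid)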
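The `validation_criterion=:harrell_c_statistic` option scores candidate trees by Harrell's concordance statistic: among pairs of patients where the one with the shorter time is known to have died, it is the fraction of pairs in which the model assigns that patient the higher risk. A simplified base-Julia version is sketched below, assuming a vector of risk scores is already available. `harrell_c` is a hypothetical helper; IAI's own implementation, including how it derives risk scores from the tree and how it treats tied times, may differ.

```julia
# Simplified Harrell's concordance index for right-censored data (base Julia only).
# `risk[i]` is a predicted risk score: higher means shorter expected survival.
# A pair (i, j) is comparable when times[i] < times[j] and subject i's death was
# observed; it is concordant when risk[i] > risk[j]. Tied risks count as 1/2.
# Pairs with tied times are skipped in this simplified version.
function harrell_c(risk::AbstractVector{<:Real}, died::AbstractVector{Bool},
                   times::AbstractVector{<:Real})
    n = length(risk)
    @assert length(died) == n && length(times) == n
    concordant = 0.0
    comparable = 0
    for i in 1:n, j in 1:n
        if died[i] && times[i] < times[j]
            comparable += 1
            if risk[i] > risk[j]
                concordant += 1
            elseif risk[i] == risk[j]
                concordant += 0.5
            end
        end
    end
    # No comparable pairs means the statistic is undefined
    return comparable == 0 ? NaN : concordant / comparable
end
```

A value of 0.5 corresponds to a random ordering of patients and 1.0 to a perfect ranking, which is why a higher c-statistic on the validation data indicates a better tree.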