# Quick Start Guide: Heuristic Survival Learners

In this example, we use survival learners from Heuristics on the Monoclonal Gammopathy dataset to predict patient survival time (refer to the data preparation guide to learn more about the data format for survival problems). First we load the data and split it into training and test datasets:

```julia
using CSV, DataFrames
df = CSV.read("mgus2.csv", DataFrame, missingstring="NA", pool=true)
```
```
1384×11 DataFrame
  Row │ Column1  id     age    sex     hgb       creat     mspike    ptime  ps ⋯
      │ Int64    Int64  Int64  String  Float64?  Float64?  Float64?  Int64  In ⋯
──────┼─────────────────────────────────────────────────────────────────────────
    1 │       1      1     88  F           13.1       1.3       0.5     30     ⋯
    2 │       2      2     78  F           11.5       1.2       2.0     25
    3 │       3      3     94  M           10.5       1.5       2.6     46
    4 │       4      4     68  M           15.2       1.2       1.2     92
    5 │       5      5     90  F           10.7       0.8       1.0      8     ⋯
    6 │       6      6     90  M           12.9       1.0       0.5      4
    7 │       7      7     89  F           10.5       0.9       1.3    151
    8 │       8      8     87  F           12.3       1.2       1.6      2
  ⋮   │    ⋮       ⋮      ⋮      ⋮        ⋮         ⋮         ⋮        ⋮      ⋱
 1378 │    1378   1378     56  M           16.1       0.8       0.5     59     ⋯
 1379 │    1379   1379     73  M           15.6       1.1       0.5     48
 1380 │    1380   1380     69  M           15.0       0.8       0.0     22
 1381 │    1381   1381     78  M           14.1       1.3       1.9     35
 1382 │    1382   1382     66  M           12.1       2.0       0.0     31     ⋯
 1383 │    1383   1383     82  F           11.5       1.1       2.3     38
 1384 │    1384   1384     79  M            9.6       1.1       1.7      6
                                                 3 columns and 1369 rows omitted
```
```julia
X = df[:, 3:(end - 4)]
died = df.death .== 1
times = df.futime
(train_X, train_died, train_times), (test_X, test_died, test_times) =
    IAI.split_data(:survival, X, died, times, seed=12345)
```

## Random Forest Survival Learner

We will use a GridSearch to fit a RandomForestSurvivalLearner with some basic parameter validation:

```julia
grid = IAI.GridSearch(
    IAI.RandomForestSurvivalLearner(
        missingdatamode=:separate_class,
        random_seed=1,
    ),
    max_depth=5:10,
)
IAI.fit!(grid, train_X, train_died, train_times)
```
```
All Grid Results:

 Row │ max_depth  train_score  valid_score  rank_valid_score
     │ Int64      Float64      Float64      Int64
─────┼───────────────────────────────────────────────────────
   1 │         5     0.291592     0.216862                 1
   2 │         6     0.310289     0.212261                 2
   3 │         7     0.322973     0.206279                 3
   4 │         8     0.329134     0.202891                 4
   5 │         9     0.331598     0.20266                  5
   6 │        10     0.332294     0.202477                 6

Best Params:
  max_depth => 5

Best Model - Fitted RandomForestSurvivalLearner
```

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

```julia
pred_curves = IAI.predict(grid, test_X)
t = 10
[c[t] for c in pred_curves]
```
```
415-element Vector{Float64}:
 0.187396493728
 0.212154980723
 0.092591201757
 0.112061137082
 0.215569821689
 0.065290983536
 0.212558861798
 0.183942502458
 0.135830805771
 0.059289362626
 ⋮
 0.044168066296
 0.210253885428
 0.087656625803
 0.253140933592
 0.016125002538
 0.102207823668
 0.14985600766
 0.082458746894
 0.175596007719
```

Alternatively, we can get this mortality probability for any given time t by providing t as a keyword argument directly:

```julia
IAI.predict(grid, test_X, t=10)
```
```
415-element Vector{Float64}:
 0.187396493728
 0.212154980723
 0.092591201757
 0.112061137082
 0.215569821689
 0.065290983536
 0.212558861798
 0.183942502458
 0.135830805771
 0.059289362626
 ⋮
 0.044168066296
 0.210253885428
 0.087656625803
 0.253140933592
 0.016125002538
 0.102207823668
 0.14985600766
 0.082458746894
 0.175596007719
```
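Each SurvivalCurve is essentially a step function of survival probability over time, and the mortality probability at time t is 1 − S(t). As a toy illustration of that lookup (the struct and all values here are hypothetical, not IAI's internal curve representation):

```julia
# Toy step-function survival curve: S(t) drops at each event time.
# The times and survival probabilities below are made up for illustration.
struct ToyCurve
    times::Vector{Float64}      # event times, sorted ascending
    survival::Vector{Float64}   # S(t) just after each event time
end

# Mortality probability by time t is 1 - S(t), where S(t) is the survival
# probability at the last event time not exceeding t (1.0 before any event).
function mortality(c::ToyCurve, t)
    i = searchsortedlast(c.times, t)
    s = i == 0 ? 1.0 : c.survival[i]
    return 1 - s
end

curve = ToyCurve([2.0, 5.0, 8.0, 12.0], [0.95, 0.90, 0.82, 0.75])
mortality(curve, 10)   # ≈ 0.18, since S(10) = 0.82
```

Indexing a curve with `c[t]` as above performs this same kind of lookup on the fitted curve.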

We can also estimate the survival time for each point:

```julia
IAI.predict_expected_survival_time(grid, test_X)
```
```
415-element Vector{Float64}:
  62.960069798779
  60.451911795869
 128.122472190415
  80.415385085702
  58.226319438976
 135.239279269937
 109.304740210212
  68.303468814585
 133.166596426567
 115.056024736284
   ⋮
 182.832081869042
  59.377709561932
 127.219213341943
  86.951384261595
 313.631215413167
  77.03670255963
  56.40262151127
  86.157655986924
 116.585871130823
```
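The expected survival time is, in effect, the area under the predicted survival curve. A minimal sketch of that idea on a hypothetical step-function curve (IAI's exact integration scheme may differ, for instance in how the tail beyond the last observed time is handled):

```julia
# Area under a step-function survival curve, truncated at the last
# recorded time. Times and survival probabilities are made up.
times = [0.0, 2.0, 5.0, 8.0, 12.0]
surv  = [1.0, 0.95, 0.90, 0.82, 0.75]   # S(t) on [times[i], times[i+1])

# Sum S(t) over each interval between successive time points.
expected_time = sum(surv[i] * (times[i+1] - times[i]) for i in 1:length(times)-1)
# 1.0*2 + 0.95*3 + 0.90*3 + 0.82*4 ≈ 10.83
```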

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

```julia
IAI.score(grid, train_X, train_died, train_times,
          criterion=:harrell_c_statistic)
```
```
0.728648073106186
```

Or on the test set:

```julia
IAI.score(grid, test_X, test_died, test_times, criterion=:harrell_c_statistic)
```
```
0.7130325263974928
```
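Harrell's c-statistic is the fraction of comparable patient pairs whose predicted risk ordering agrees with their observed survival ordering; a pair is comparable only when the earlier of the two follow-up times ends in an observed death. A from-scratch sketch on hypothetical data (IAI's implementation may handle ties and tied times differently):

```julia
# Harrell's c-statistic from predicted risks, event indicators, and times.
# Pair (i, j) is comparable if patient i dies strictly before j's
# follow-up ends; it is concordant if the shorter survivor has higher risk.
function harrell_c(risk, died, times)
    concordant = 0.0
    comparable = 0
    n = length(risk)
    for i in 1:n, j in 1:n
        if died[i] && times[i] < times[j]
            comparable += 1
            if risk[i] > risk[j]
                concordant += 1.0
            elseif risk[i] == risk[j]
                concordant += 0.5   # ties in predicted risk count as half
            end
        end
    end
    return concordant / comparable
end

risk  = [2.1, 0.3, 1.5, 0.2]
died  = [true, false, true, true]
times = [5.0, 20.0, 8.0, 12.0]
harrell_c(risk, died, times)   # 5 of 6 comparable pairs concordant ≈ 0.833
```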

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

```julia
IAI.score(grid, test_X, test_died, test_times, criterion=:auc,
          evaluation_time=10)
```
```
0.7750750750750736
```
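Scoring with a classification criterion at evaluation_time=10 amounts to labeling each patient by their known status at 10 months and comparing the predicted 10-month mortality probabilities against those labels. A hedged sketch of the labeling step (the function name is hypothetical, and IAI may additionally apply censoring adjustments):

```julia
# Binary outcome at evaluation time t: died by t => 1, survived past
# t => 0, censored before t without an event => status unknown (missing).
function labels_at(died, times, t)
    map(zip(died, times)) do (d, ti)
        if d && ti <= t
            1          # death observed by time t
        elseif ti > t
            0          # known to have survived past t
        else
            missing    # censored before t: status at t unknown
        end
    end
end

died  = [true, false, true, false]
times = [4.0, 15.0, 12.0, 7.0]
labels_at(died, times, 10)   # [1, 0, 0, missing]
```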

We can also look at the variable importance:

```julia
IAI.variable_importance(IAI.get_learner(grid))
```
```
5×2 DataFrame
 Row │ Feature  Importance
     │ Symbol   Float64
─────┼─────────────────────
   1 │ age       0.62681
   2 │ hgb       0.187153
   3 │ creat     0.109443
   4 │ mspike    0.0489006
   5 │ sex       0.0276938
```

## XGBoost Survival Learner

We will use a GridSearch to fit an XGBoostSurvivalLearner with some basic parameter validation:

```julia
grid = IAI.GridSearch(
    IAI.XGBoostSurvivalLearner(
        random_seed=1,
    ),
    max_depth=2:5,
    num_round=[20, 50, 100],
)
IAI.fit!(grid, train_X, train_died, train_times)
```
```
All Grid Results:

 Row │ num_round  max_depth  train_score  valid_score  rank_valid_score
     │ Int64      Int64      Float64      Float64      Int64
─────┼──────────────────────────────────────────────────────────────────
   1 │        20          2    0.221512    0.164622                   1
   2 │        20          3    0.175214   -0.134411                   4
   3 │        20          4    0.143079   -0.120157                   3
   4 │        20          5    0.0489262  -0.474636                   7
   5 │        50          2    0.196386   -0.00737584                 2
   6 │        50          3    0.138169   -0.307373                   6
   7 │        50          4   -0.096126   -0.571589                   8
   8 │        50          5   -0.455181   -1.27913                   11
   9 │       100          2    0.11895    -0.226264                   5
  10 │       100          3   -0.122828   -0.851117                   9
  11 │       100          4   -0.558718   -1.20169                   10
  12 │       100          5   -1.32397    -2.35255                   12

Best Params:
  num_round => 20
  max_depth => 2

Best Model - Fitted XGBoostSurvivalLearner
```

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

```julia
pred_curves = IAI.predict(grid, test_X)
t = 10
[c[t] for c in pred_curves]
```
```
415-element Vector{Float64}:
 0.146602006944
 0.181972807304
 0.089343189025
 0.233761179821
 0.210810046477
 0.068637260518
 0.111577430497
 0.141948041866
 0.092365925364
 0.071650634104
 ⋮
 0.041039651399
 0.146602006944
 0.081671053483
 0.104403941284
 0.011575818595
 0.144190738587
 0.171362167869
 0.136016020377
 0.19172853442
```

Alternatively, we can get this mortality probability for any given time t by providing t as a keyword argument directly:

```julia
IAI.predict(grid, test_X, t=10)
```
```
415-element Vector{Float64}:
 0.146602006944
 0.181972807304
 0.089343189025
 0.233761179821
 0.210810046477
 0.068637260518
 0.111577430497
 0.141948041866
 0.092365925364
 0.071650634104
 ⋮
 0.041039651399
 0.146602006944
 0.081671053483
 0.104403941284
 0.011575818595
 0.144190738587
 0.171362167869
 0.136016020377
 0.19172853442
```

We can also estimate the survival time for each point:

```julia
IAI.predict_expected_survival_time(grid, test_X)
```
```
415-element Vector{Float64}:
  99.870825314231
  78.7679563508
 159.460151822213
  58.731723642278
  66.436568115034
 194.523130853628
 131.166829046173
 103.316586760024
 155.11382620326
 188.780862231699
   ⋮
 260.262232232197
  99.870825314231
 171.314554719091
 139.421343481108
 367.082586212797
 101.633026339217
  84.279003600824
 107.98765306547
  74.202136506095
```

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

```julia
IAI.score(grid, train_X, train_died, train_times,
          criterion=:harrell_c_statistic)
```
```
0.7193754134965267
```

Or on the test set:

```julia
IAI.score(grid, test_X, test_died, test_times, criterion=:harrell_c_statistic)
```
```
0.7093859379347714
```

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

```julia
IAI.score(grid, test_X, test_died, test_times, criterion=:auc,
          evaluation_time=10)
```
```
0.689069069069069
```

We can also look at the variable importance:

```julia
IAI.variable_importance(IAI.get_learner(grid))
```
```
5×2 DataFrame
 Row │ Feature  Importance
     │ Symbol   Float64
─────┼─────────────────────
   1 │ age       0.655875
   2 │ hgb       0.148613
   3 │ creat     0.100927
   4 │ mspike    0.0557582
   5 │ sex=M     0.0388273
```

We can calculate the SHAP values:

```julia
s = IAI.predict_shap(grid, test_X)
s[:shap_values]
```
```
415×5 Matrix{Float64}:
  0.640098   0.297386   -0.0352981   -0.022563    -0.145025
  0.979197   0.176634   -0.0352981   -0.00424254  -0.145025
  0.212314   0.136468   -0.00737888  -0.014659    -0.119171
  1.37775    0.0550153  -0.0434586   -0.0128254   -0.123347
  0.68526    0.275957    0.0332857   -0.0252807    0.166434
  0.201457  -0.130303   -0.0384152   -0.0024089   -0.0974934
 -0.165168   0.449972    0.0632479   -0.0148529    0.108749
  0.763955   0.138876   -0.0434586   -0.014659    -0.145025
  0.119493   0.291889   -0.0350694   -0.014659    -0.119171
  0.255614  -0.127896   -0.0399963   -0.0128254   -0.0974934
  ⋮
 -0.427469  -0.173282   -0.057895    -0.0130192    0.0757421
  0.662603   0.23594    -0.00426181  -0.014659    -0.145025
  0.173659  -0.14895    -0.00977322  -0.00512661   0.103846
 -0.487831   0.518108    0.456193     0.0483653   -0.163283
 -1.74569   -0.185214    0.0623039   -0.0837539    0.0757421
  0.713682  -0.152875    0.0275326   -0.00512661   0.133427
  0.878975  -0.129856    0.0275326   -0.00512661   0.133427
  0.64915   -0.152875    0.0690641   -0.0155431    0.103846
 -0.353826   0.312178    0.220955     0.68207      0.167903
```

We can then use the SHAP library in Python to visualize these results in whichever way we prefer. For example, the following code creates a summary plot:

```julia
using PyCall
shap = pyimport("shap")
shap.summary_plot(s[:shap_values], Matrix(s[:features]), names(s[:features]))
```

## GLMNetCV Survival Learner

We will fit a GLMNetCVSurvivalLearner directly. Since GLMNetCVLearner does not support missing data, we first subset to complete cases:

```julia
lnr = IAI.GLMNetCVSurvivalLearner(random_seed=1)
cc = completecases(train_X)
train_X_cc = train_X[cc, :]
train_died_cc = train_died[cc]
train_times_cc = train_times[cc]
IAI.fit!(lnr, train_X_cc, train_died_cc, train_times_cc)
```
```
Fitted GLMNetCVSurvivalLearner:
  Constant: -3.319
  Weights:
    age:    0.0578244
    creat:  0.18159
    hgb:   -0.0889831
    sex=M:  0.300142
```
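Since GLMNetCV fits a regularized Cox-style model, these weights form a linear predictor on the log-hazard scale, so the relative risk between two patients is the exponential of the difference in their linear predictors. An illustrative sketch using the coefficients printed above (treating them as plain Cox coefficients; the baseline hazard is internal to the learner, and the patients below are hypothetical):

```julia
# Coefficients copied from the fitted learner output above.
β = (age = 0.0578244, creat = 0.18159, hgb = -0.0889831, sexM = 0.300142)

# Linear predictor on the log-hazard scale for one patient.
linpred(age, creat, hgb, male) =
    β.age * age + β.creat * creat + β.hgb * hgb + β.sexM * (male ? 1 : 0)

# Hazard ratio: an 80-year-old man vs a 70-year-old woman,
# other covariates held equal.
hr = exp(linpred(80, 1.2, 13.0, true) - linpred(70, 1.2, 13.0, false))
# ≈ 2.41: roughly 2.4 times the hazard
```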

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

```julia
pred_curves = IAI.predict(lnr, test_X)
t = 10
[c[t] for c in pred_curves]
```
```
415-element Vector{Float64}:
 0.200607896718
 0.370165075078
 0.122585629449
 0.258150249738
 0.263183570494
 0.097649876452
 0.130211918883
 0.165323978367
 0.142078644844
 0.115488054643
 ⋮
 0.077102775305
 0.225010337685
 0.132377147815
 0.13317384091
 0.032757577054
 0.310343253654
 0.295608518778
 0.191789512261
 0.123660862616
```

Alternatively, we can get this mortality probability for any given time t by providing t as a keyword argument directly:

```julia
IAI.predict(lnr, test_X, t=10)
```
```
415-element Vector{Float64}:
 0.200607896718
 0.370165075078
 0.122585629449
 0.258150249738
 0.263183570494
 0.097649876452
 0.130211918883
 0.165323978367
 0.142078644844
 0.115488054643
 ⋮
 0.077102775305
 0.225010337685
 0.132377147815
 0.13317384091
 0.032757577054
 0.310343253654
 0.295608518778
 0.191789512261
 0.123660862616
```

We can also estimate the survival time for each point:

```julia
IAI.predict_expected_survival_time(lnr, test_X)
```
```
415-element Vector{Float64}:
  69.7131444765
  32.2154314147
 118.4692693208
  51.547362234
  50.3308951517
 146.1762990845
 111.587096372
  86.7639725725
 102.0561682313
 125.482504477
   ⋮
 176.9479695605
  60.8993919601
 109.7463295443
 109.0809614168
 283.634785317
  40.8347782051
  43.4798069142
  73.426161112
 117.4596540356
```

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

```julia
IAI.score(lnr, train_X_cc, train_died_cc, train_times_cc,
          criterion=:harrell_c_statistic)
```
```
0.6837433757104808
```

Or on the test set:

```julia
cc = completecases(test_X)
test_X_cc = test_X[cc, :]
test_died_cc = test_died[cc]
test_times_cc = test_times[cc]
IAI.score(lnr, test_X_cc, test_died_cc, test_times_cc,
          criterion=:harrell_c_statistic)
```
```
0.6790416355461828
```

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

```julia
IAI.score(lnr, test_X_cc, test_died_cc, test_times_cc, criterion=:auc,
          evaluation_time=10)
```
```
0.6823274764451254
```

We can also look at the variable importance:

```julia
IAI.variable_importance(lnr)
```
```
5×2 DataFrame
 Row │ Feature  Importance
     │ Symbol   Float64
─────┼─────────────────────
   1 │ age        0.59208
   2 │ hgb        0.153895
   3 │ sex=M      0.127414
   4 │ creat      0.12661
   5 │ mspike     0.0
```