# Quick Start Guide: Heuristic Survival Learners

In this example, we use survival learners from Heuristics on the Monoclonal Gammopathy dataset to predict patient survival time (refer to the data preparation guide to learn more about the data format for survival problems). First, we load the data and split it into training and test datasets:

```julia
using CSV, DataFrames
df = CSV.read("mgus2.csv", DataFrame, missingstring="NA", pool=true)
```
```
1384×11 DataFrame
  Row │ Column1  id     age    sex     hgb       creat     mspike    ptime  ps ⋯
      │ Int64    Int64  Int64  String  Float64?  Float64?  Float64?  Int64  In ⋯
──────┼─────────────────────────────────────────────────────────────────────────
    1 │       1      1     88  F           13.1       1.3       0.5     30     ⋯
    2 │       2      2     78  F           11.5       1.2       2.0     25
    3 │       3      3     94  M           10.5       1.5       2.6     46
    4 │       4      4     68  M           15.2       1.2       1.2     92
    5 │       5      5     90  F           10.7       0.8       1.0      8     ⋯
    6 │       6      6     90  M           12.9       1.0       0.5      4
    7 │       7      7     89  F           10.5       0.9       1.3    151
    8 │       8      8     87  F           12.3       1.2       1.6      2
  ⋮   │    ⋮       ⋮      ⋮      ⋮        ⋮         ⋮         ⋮        ⋮      ⋱
 1378 │    1378   1378     56  M           16.1       0.8       0.5     59     ⋯
 1379 │    1379   1379     73  M           15.6       1.1       0.5     48
 1380 │    1380   1380     69  M           15.0       0.8       0.0     22
 1381 │    1381   1381     78  M           14.1       1.3       1.9     35
 1382 │    1382   1382     66  M           12.1       2.0       0.0     31     ⋯
 1383 │    1383   1383     82  F           11.5       1.1       2.3     38
 1384 │    1384   1384     79  M            9.6       1.1       1.7      6
                                                 3 columns and 1369 rows omitted
```
```julia
X = df[:, 3:(end - 4)]
died = df.death .== 1
times = df.futime
(train_X, train_died, train_times), (test_X, test_died, test_times) =
    IAI.split_data(:survival, X, died, times, seed=12345)
```
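As a quick sanity check on the split, we can count the points assigned to each set (a small illustrative snippet using only the variables defined above; the holdout fraction of split_data is left at its default here):

```julia
# Number of observations in the training and test sets
length(train_times), length(test_times)
```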

## Random Forest Survival Learner

We will use a GridSearch to fit a RandomForestSurvivalLearner with some basic parameter validation:

```julia
grid = IAI.GridSearch(
    IAI.RandomForestSurvivalLearner(
        missingdatamode=:separate_class,
        random_seed=1,
    ),
    max_depth=5:10,
)
IAI.fit!(grid, train_X, train_died, train_times)
```
```
All Grid Results:

 Row │ max_depth  train_score  valid_score  rank_valid_score
     │ Int64      Float64      Float64      Int64
─────┼───────────────────────────────────────────────────────
   1 │         5     0.291592     0.216862                 1
   2 │         6     0.310289     0.212261                 2
   3 │         7     0.322973     0.206279                 3
   4 │         8     0.329134     0.202891                 4
   5 │         9     0.331598     0.20266                  5
   6 │        10     0.332294     0.202477                 6

Best Params:
  max_depth => 5

Best Model - Fitted RandomForestSurvivalLearner
```

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

```julia
pred_curves = IAI.predict(grid, test_X)
t = 10
[c[t] for c in pred_curves]
```
```
415-element Vector{Float64}:
 0.187396493728
 0.212154980723
 0.092591201757
 0.112061137082
 0.215569821689
 0.065290983536
 0.212558861798
 0.183942502458
 0.135830805771
 0.059289362626
 ⋮
 0.044168066296
 0.210253885428
 0.087656625803
 0.253140933592
 0.016125002538
 0.102207823668
 0.14985600766
 0.082458746894
 0.175596007719
```
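Beyond evaluating a curve at a single time, it can be useful to inspect the curve itself. As a hedged sketch, assuming the API exposes the fitted curve breakpoints through a `get_survival_curve_data`-style accessor (this accessor name is our assumption, not shown in the output above):

```julia
# Sketch: extract the underlying times/probabilities of the first predicted
# curve so it can be plotted or post-processed (accessor name is assumed)
curve_data = IAI.get_survival_curve_data(pred_curves[1])
```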

Alternatively, we can get this mortality probability for any given time t by providing t as a keyword argument directly:

```julia
IAI.predict(grid, test_X, t=10)
```
```
415-element Vector{Float64}:
 0.187396493728
 0.212154980723
 0.092591201757
 0.112061137082
 0.215569821689
 0.065290983536
 0.212558861798
 0.183942502458
 0.135830805771
 0.059289362626
 ⋮
 0.044168066296
 0.210253885428
 0.087656625803
 0.253140933592
 0.016125002538
 0.102207823668
 0.14985600766
 0.082458746894
 0.175596007719
```

We can also estimate the survival time for each point:

```julia
IAI.predict_expected_survival_time(grid, test_X)
```
```
415-element Vector{Float64}:
  62.960069798779
  60.451911795869
 128.122472190415
  80.415385085702
  58.226319438976
 135.239279269937
 109.304740210212
  68.303468814585
 133.166596426567
 115.056024736284
   ⋮
 182.832081869042
  59.377709561932
 127.219213341943
  86.951384261595
 313.631215413167
  77.03670255963
  56.40262151127
  86.157655986924
 116.585871130823
```
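These point estimates can be summarized with standard Julia tooling to get a feel for their spread; a minimal sketch using only the Statistics standard library:

```julia
using Statistics

# Distribution summary of the predicted expected survival times (in months)
pred_times = IAI.predict_expected_survival_time(grid, test_X)
extrema(pred_times), median(pred_times)
```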

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

```julia
IAI.score(grid, train_X, train_died, train_times,
          criterion=:harrell_c_statistic)
```
```
0.728648073106186
```

Or on the test set:

```julia
IAI.score(grid, test_X, test_died, test_times, criterion=:harrell_c_statistic)
```
```
0.7130325263974928
```

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

```julia
IAI.score(grid, test_X, test_died, test_times, criterion=:auc,
          evaluation_time=10)
```
```
0.7750750750750736
```
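Because the evaluation time is just a keyword argument, the same call can be repeated over several horizons to see how discrimination changes over time; a minimal sketch (the horizons chosen here are arbitrary):

```julia
# AUC of the survival predictions at a few evaluation horizons (months)
[IAI.score(grid, test_X, test_died, test_times, criterion=:auc,
           evaluation_time=t) for t in (10, 50, 100)]
```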

We can also look at the variable importance:

```julia
IAI.variable_importance(IAI.get_learner(grid))
```
```
5×2 DataFrame
 Row │ Feature  Importance
     │ Symbol   Float64
─────┼─────────────────────
   1 │ age       0.62681
   2 │ hgb       0.187153
   3 │ creat     0.109443
   4 │ mspike    0.0489006
   5 │ sex       0.0276938
```

## XGBoost Survival Learner

We will use a GridSearch to fit an XGBoostSurvivalLearner with some basic parameter validation:

```julia
grid = IAI.GridSearch(
    IAI.XGBoostSurvivalLearner(
        random_seed=1,
    ),
    max_depth=2:5,
    num_round=[20, 50, 100],
)
IAI.fit!(grid, train_X, train_died, train_times)
```
```
All Grid Results:

 Row │ num_round  max_depth  train_score  valid_score  rank_valid_score
     │ Int64      Int64      Float64      Float64      Int64
─────┼──────────────────────────────────────────────────────────────────
   1 │        20          2    0.215788     0.15041                   1
   2 │        20          3    0.194945    -0.0670844                 3
   3 │        20          4    0.183779    -0.121177                  4
   4 │        20          5    0.0916702   -0.401915                  7
   5 │        50          2    0.210932     0.0262287                 2
   6 │        50          3    0.153306    -0.217315                  6
   7 │        50          4    0.0281122   -0.513705                  9
   8 │        50          5   -0.180287    -0.87479                  10
   9 │       100          2    0.182667    -0.132212                  5
  10 │       100          3    0.0250015   -0.512682                  8
  11 │       100          4   -0.345342    -1.03887                  11
  12 │       100          5   -0.907852    -1.75593                  12

Best Params:
  num_round => 20
  max_depth => 2

Best Model - Fitted XGBoostSurvivalLearner
```

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

```julia
pred_curves = IAI.predict(grid, test_X)
t = 10
[c[t] for c in pred_curves]
```
```
415-element Vector{Float64}:
 0.14218
 0.19365
 0.10101
 0.21149
 0.21305
 0.06571
 0.11277
 0.13989
 0.09729
 0.06921
 ⋮
 0.04348
 0.14986
 0.08444
 0.14099
 0.01246
 0.1504
 0.21825
 0.12885
 0.22829
```

Alternatively, we can get this mortality probability for any given time t by providing t as a keyword argument directly:

```julia
IAI.predict(grid, test_X, t=10)
```
```
415-element Vector{Float64}:
 0.14218
 0.19365
 0.10101
 0.21149
 0.21305
 0.06571
 0.11277
 0.13989
 0.09729
 0.06921
 ⋮
 0.04348
 0.14986
 0.08444
 0.14099
 0.01246
 0.1504
 0.21825
 0.12885
 0.22829
```

We can also estimate the survival time for each point:

```julia
IAI.predict_expected_survival_time(grid, test_X)
```
```
415-element Vector{Float64}:
 103.14
  73.356
 143.604
  66.184
  65.612
 200.331
 129.869
 104.9
 148.39
 193.413
   ⋮
 253.328
  97.563
 166.9
 104.048
 363.12
  97.19
  63.766
 114.09
  60.43
```

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

```julia
IAI.score(grid, train_X, train_died, train_times,
          criterion=:harrell_c_statistic)
```
```
0.71899034
```

Or on the test set:

```julia
IAI.score(grid, test_X, test_died, test_times, criterion=:harrell_c_statistic)
```
```
0.70912965
```

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

```julia
IAI.score(grid, test_X, test_died, test_times, criterion=:auc,
          evaluation_time=10)
```
```
0.68960961
```

We can also look at the variable importance:

```julia
IAI.variable_importance(IAI.get_learner(grid))
```
```
5×2 DataFrame
 Row │ Feature  Importance
     │ Symbol   Float64
─────┼─────────────────────
   1 │ age       0.653471
   2 │ hgb       0.154132
   3 │ creat     0.0930732
   4 │ mspike    0.0610053
   5 │ sex=M     0.0383188
```

We can calculate the SHAP values:

```julia
s = IAI.predict_shap(grid, test_X)
s[:shap_values]
```
```
415×5 Matrix{Float64}:
  0.61058   0.30582  -0.04894  -0.01826  -0.14742
  1.07476   0.16585  -0.04894  -0.00356  -0.14742
  0.23345   0.18522   0.05296  -0.01366  -0.12105
  1.4023   -0.05959  -0.05695  -0.01366  -0.13249
  0.69856   0.2627    0.04495  -0.02039   0.16211
  0.18722  -0.14855  -0.04093  -0.00356  -0.10612
 -0.18245   0.42472   0.12438  -0.01668   0.10355
  0.77792   0.12435  -0.05695  -0.01366  -0.14742
  0.13403   0.33627  -0.03813  -0.01366  -0.12105
  0.24808  -0.14301  -0.04354  -0.01366  -0.10612
  ⋮
 -0.38698  -0.15637  -0.05912  -0.01668   0.08266
  0.63852   0.2364    0.04495  -0.01366  -0.14742
  0.15582  -0.16502   0.05296  -0.00568   0.11068
 -0.39859   0.50849   0.43877   0.32741  -0.18338
 -1.78685  -0.19225   0.18891  -0.08385   0.07211
  0.73717  -0.14956   0.03954  -0.00568   0.14121
  1.07879  -0.07867   0.03954  -0.00568   0.14121
  0.6234   -0.16209   0.03954  -0.01578   0.11068
 -0.26166   0.29492   0.19171   0.81775   0.18365
```
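A common way to summarize these values is to rank features by their mean absolute SHAP value; a minimal sketch using only the keys shown above and the Statistics standard library:

```julia
using Statistics

# Mean absolute SHAP value per feature, sorted most to least influential
mean_abs = vec(mean(abs.(s[:shap_values]), dims=1))
sort(collect(zip(names(s[:features]), mean_abs)), by=last, rev=true)
```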

We can then use the SHAP library in Python to visualize these results in whichever way we prefer. For example, the following code creates a summary plot:

```julia
using PyCall
shap = pyimport("shap")
shap.summary_plot(s[:shap_values], Matrix(s[:features]), names(s[:features]))
```

## GLMNetCV Survival Learner

We will fit a GLMNetCVSurvivalLearner directly. Since GLMNetCVLearner does not support missing data, we first subset to complete cases:

```julia
train_cc = completecases(train_X)
train_X_cc = train_X[train_cc, :]
train_died_cc = train_died[train_cc]
train_times_cc = train_times[train_cc]

test_cc = completecases(test_X)
test_X_cc = test_X[test_cc, :]
test_died_cc = test_died[test_cc]
test_times_cc = test_times[test_cc]
```
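It can be worth recording how much data this filtering discards; a small illustrative snippet using the DataFrames nrow function:

```julia
# Rows dropped from each set due to missing values
nrow(train_X) - nrow(train_X_cc), nrow(test_X) - nrow(test_X_cc)
```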

We can now proceed with fitting the learner:

```julia
lnr = IAI.GLMNetCVSurvivalLearner(random_seed=1)
IAI.fit!(lnr, train_X_cc, train_died_cc, train_times_cc)
```
```
Fitted GLMNetCVSurvivalLearner:
  Constant: -3.319
  Weights:
    age:    0.0578244
    creat:  0.18159
    hgb:   -0.0889831
    sex=M:  0.300142
```
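Since the GLMNet survival learner fits a regularized Cox-style model, the weights above can be read as log hazard ratios (this interpretation is our assumption about the model form, not stated in the output). Exponentiating a weight gives its multiplicative effect on the hazard:

```julia
# Hazard ratio per one-unit increase in each feature (interpretation sketch)
exp(0.0578244)  # age: hazard ratio per additional year
exp(0.300142)   # sex=M: hazard relative to the baseline level
```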

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

```julia
pred_curves = IAI.predict(lnr, test_X_cc)
t = 10
[c[t] for c in pred_curves]
```
```
401-element Vector{Float64}:
 0.200607896718
 0.370165075078
 0.122585629449
 0.258150249738
 0.263183570494
 0.097649876452
 0.130211918883
 0.165323978367
 0.142078644844
 0.115488054643
 ⋮
 0.288921901855
 0.077102775305
 0.225010337685
 0.132377147815
 0.032757577054
 0.310343253654
 0.295608518778
 0.191789512261
 0.123660862616
```

Alternatively, we can get this mortality probability for any given time t by providing t as a keyword argument directly:

```julia
IAI.predict(lnr, test_X_cc, t=10)
```
```
401-element Vector{Float64}:
 0.200607896718
 0.370165075078
 0.122585629449
 0.258150249738
 0.263183570494
 0.097649876452
 0.130211918883
 0.165323978367
 0.142078644844
 0.115488054643
 ⋮
 0.288921901855
 0.077102775305
 0.225010337685
 0.132377147815
 0.032757577054
 0.310343253654
 0.295608518778
 0.191789512261
 0.123660862616
```

We can also estimate the survival time for each point:

```julia
IAI.predict_expected_survival_time(lnr, test_X_cc)
```
```
401-element Vector{Float64}:
  69.71314448
  32.21543141
 118.46926932
  51.54736223
  50.33089515
 146.17629908
 111.58709637
  86.76397257
 102.05616823
 125.48250448
   ⋮
  44.76791753
 176.94796956
  60.89939196
 109.74632954
 283.63478532
  40.83477821
  43.47980691
  73.42616111
 117.45965404
```

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

```julia
IAI.score(lnr, train_X_cc, train_died_cc, train_times_cc,
          criterion=:harrell_c_statistic)
```
```
0.6837433757104808
```

Or on the test set:

```julia
IAI.score(lnr, test_X_cc, test_died_cc, test_times_cc,
          criterion=:harrell_c_statistic)
```
```
0.7027583847744624
```

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

```julia
IAI.score(lnr, test_X_cc, test_died_cc, test_times_cc, criterion=:auc,
          evaluation_time=10)
```
```
0.6823274764451254
```

We can also look at the variable importance:

```julia
IAI.variable_importance(lnr)
```
```
5×2 DataFrame
 Row │ Feature  Importance
     │ Symbol   Float64
─────┼─────────────────────
   1 │ age        0.59208
   2 │ hgb        0.153895
   3 │ sex=M      0.127414
   4 │ creat      0.12661
   5 │ mspike     0.0
```