Quick Start Guide: Heuristic Survival Learners

In this example we will use survival learners from Heuristics on the Monoclonal Gammopathy dataset to predict patient survival time (refer to the data preparation guide to learn more about the data format for survival problems). First we load in the data and split into training and test datasets:

using CSV, DataFrames
df = CSV.read("mgus2.csv", DataFrame, missingstring="NA", pool=true)
1384×11 DataFrame
  Row │ Column1  id     age    sex     hgb       creat     mspike    ptime  ps ⋯
      │ Int64    Int64  Int64  String  Float64?  Float64?  Float64?  Int64  In ⋯
──────┼─────────────────────────────────────────────────────────────────────────
    1 │       1      1     88  F           13.1       1.3       0.5     30     ⋯
    2 │       2      2     78  F           11.5       1.2       2.0     25
    3 │       3      3     94  M           10.5       1.5       2.6     46
    4 │       4      4     68  M           15.2       1.2       1.2     92
    5 │       5      5     90  F           10.7       0.8       1.0      8     ⋯
    6 │       6      6     90  M           12.9       1.0       0.5      4
    7 │       7      7     89  F           10.5       0.9       1.3    151
    8 │       8      8     87  F           12.3       1.2       1.6      2
  ⋮   │    ⋮       ⋮      ⋮      ⋮        ⋮         ⋮         ⋮        ⋮       ⋱
 1378 │    1378   1378     56  M           16.1       0.8       0.5     59     ⋯
 1379 │    1379   1379     73  M           15.6       1.1       0.5     48
 1380 │    1380   1380     69  M           15.0       0.8       0.0     22
 1381 │    1381   1381     78  M           14.1       1.3       1.9     35
 1382 │    1382   1382     66  M           12.1       2.0       0.0     31     ⋯
 1383 │    1383   1383     82  F           11.5       1.1       2.3     38
 1384 │    1384   1384     79  M            9.6       1.1       1.7      6
                                                 3 columns and 1369 rows omitted
X = df[:, 3:(end - 4)]
died = df.death .== 1
times = df.futime
(train_X, train_died, train_times), (test_X, test_died, test_times) =
    IAI.split_data(:survival, X, died, times, seed=12345)
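The printed table is truncated, so the slice `3:(end - 4)` is easier to follow with the full column list. The sketch below assumes the standard `mgus2` column order: the CSV's own index column, then `id`, `age`, `sex`, `hgb`, `creat`, `mspike`, `ptime`, `pstat`, `futime`, `death`. It checks which names the slice keeps. The identifiers and the progression and death outcome columns are left out of the features.

```julia
# Assumed column order of mgus2.csv: Column1 is the row index written out
# with the CSV, followed by the ten mgus2 variables.
cols = ["Column1", "id", "age", "sex", "hgb", "creat", "mspike",
        "ptime", "pstat", "futime", "death"]

# 3:(end - 4) drops the two identifier columns at the front and the four
# outcome-related columns (ptime, pstat, futime, death) at the back.
features = cols[3:(end - 4)]
println(features)
```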

Random Forest Survival Learner

We will use a GridSearch to fit a RandomForestSurvivalLearner with some basic parameter validation:

grid = IAI.GridSearch(
    IAI.RandomForestSurvivalLearner(
        missingdatamode=:separate_class,
        random_seed=1,
    ),
    max_depth=5:10,
)
IAI.fit!(grid, train_X, train_died, train_times)
All Grid Results:

 Row │ max_depth  train_score  valid_score  rank_valid_score
     │ Int64      Float64      Float64      Int64
─────┼───────────────────────────────────────────────────────
   1 │         5     0.291592     0.216862                 1
   2 │         6     0.310289     0.212261                 2
   3 │         7     0.322973     0.206279                 3
   4 │         8     0.329134     0.202891                 4
   5 │         9     0.331598     0.20266                  5
   6 │        10     0.332294     0.202477                 6

Best Params:
  max_depth => 5

Best Model - Fitted RandomForestSurvivalLearner
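To illustrate how the best parameters are chosen, the sketch below takes the validation scores from the table and picks the `max_depth` with the highest one. It assumes that a higher validation score is better, as the `rank_valid_score` column suggests. Only the selection step is reproduced, not the fitting.

```julia
# Validation scores copied from the grid results above, one per max_depth.
depths = 5:10
valid_scores = [0.216862, 0.212261, 0.206279, 0.202891, 0.20266, 0.202477]

# Keep the parameter value with the highest validation score.
best_depth = depths[argmax(valid_scores)]

# Rank 1 = best validation score, matching the rank_valid_score column.
ranks = invperm(sortperm(valid_scores, rev=true))
println("best max_depth = ", best_depth, ", ranks = ", ranks)
```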

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

pred_curves = IAI.predict(grid, test_X)
t = 10
[c[t] for c in pred_curves]
415-element Vector{Float64}:
 0.187396493728
 0.212154980723
 0.092591201757
 0.112061137082
 0.215569821689
 0.065290983536
 0.212558861798
 0.183942502458
 0.135830805771
 0.059289362626
 ⋮
 0.044168066296
 0.210253885428
 0.087656625803
 0.253140933592
 0.016125002538
 0.102207823668
 0.14985600766
 0.082458746894
 0.175596007719
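Conceptually, each predicted curve behaves like a step function that can be indexed by time. The sketch below uses a hypothetical `StepCurve` type. It is not the IAI `SurvivalCurve` type, whose internals are not shown here. It stores the cumulative mortality at each event time, so `c[t]` returns the mortality probability at the last event time at or before `t`. The curves and values are made up.

```julia
# Hypothetical stand-in for a survival curve: a step function holding the
# cumulative mortality probability (1 - survival) at each event time.
struct StepCurve
    times::Vector{Float64}      # sorted event times
    mortality::Vector{Float64}  # cumulative mortality from each event time on
end

# Indexing with a time t returns the mortality at the last event time <= t,
# or 0.0 if t falls before the first event.
function Base.getindex(c::StepCurve, t::Real)
    i = searchsortedlast(c.times, t)
    return i == 0 ? 0.0 : c.mortality[i]
end

curves = [
    StepCurve([2.0, 8.0, 15.0], [0.05, 0.12, 0.30]),
    StepCurve([5.0, 12.0], [0.02, 0.10]),
]

t = 10
probs = [c[t] for c in curves]   # mortality probability at t for each curve
println(probs)
```

Passing `t=10` to `predict`, as shown next, gives the same values as indexing every curve at that time.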

Alternatively, we can get this mortality probability for any given time t by providing t as a keyword argument directly:

IAI.predict(grid, test_X, t=10)
415-element Vector{Float64}:
 0.187396493728
 0.212154980723
 0.092591201757
 0.112061137082
 0.215569821689
 0.065290983536
 0.212558861798
 0.183942502458
 0.135830805771
 0.059289362626
 ⋮
 0.044168066296
 0.210253885428
 0.087656625803
 0.253140933592
 0.016125002538
 0.102207823668
 0.14985600766
 0.082458746894
 0.175596007719

We can also estimate the expected survival time for each point:

IAI.predict_expected_survival_time(grid, test_X)
415-element Vector{Float64}:
  62.960069798779
  60.451911795869
 128.122472190415
  80.415385085702
  58.226319438976
 135.239279269937
 109.304740210212
  68.303468814585
 133.166596426567
 115.056024736284
   ⋮
 182.832081869042
  59.377709561932
 127.219213341943
  86.951384261595
 313.631215413167
  77.03670255963
  56.40262151127
  86.157655986924
 116.585871130823
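One common way to turn a survival curve into a single time estimate is the area under the survival function S(t) = 1 - mortality(t). If the curve does not fall to zero within the observed follow-up, the area is usually truncated at a horizon, which gives a restricted mean survival time. The sketch below computes that restricted mean for a made-up step curve. It is an assumption that `predict_expected_survival_time` works this way, and its handling of the tail of the curve may differ.

```julia
# Restricted mean survival time: area under a step survival curve S(t)
# from 0 up to `horizon`. S(t) = 1 before the first event time.
function restricted_mean(times::Vector{Float64}, surv::Vector{Float64}, horizon::Real)
    area = 0.0
    prev_t, prev_s = 0.0, 1.0
    for (t, s) in zip(times, surv)
        t >= horizon && break
        area += prev_s * (t - prev_t)   # S is flat at prev_s on [prev_t, t)
        prev_t, prev_s = t, s
    end
    area += prev_s * (horizon - prev_t)  # last flat piece up to the horizon
    return area
end

times = [2.0, 8.0, 15.0]   # made-up event times (months)
surv  = [0.95, 0.88, 0.70] # survival probability after each event time
println(restricted_mean(times, surv, 20))
```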

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

IAI.score(grid, train_X, train_died, train_times,
          criterion=:harrell_c_statistic)
0.728648073106186

Or on the test set:

IAI.score(grid, test_X, test_died, test_times, criterion=:harrell_c_statistic)
0.7130325263974928
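Harrell's c-statistic is the fraction of comparable pairs that the model orders correctly. A pair is comparable when one patient is observed to die before the other's recorded time. It is concordant when that patient received the higher predicted risk. A value of 0.5 corresponds to random ordering and 1.0 to perfect ordering. Below is a minimal sketch with made-up data. It assumes ties in predicted risk count as one half and ignores ties in time. The IAI implementation's tie handling is not shown here.

```julia
# Harrell's concordance index. `risk` is any score where higher means
# higher predicted risk (e.g. mortality probability at some time).
function harrell_c(risk::Vector{Float64}, died::Vector{Bool}, times::Vector{Float64})
    concordant = 0.0
    comparable = 0
    n = length(risk)
    for i in 1:n, j in 1:n
        # The pair is comparable when i died before j's observed time.
        if died[i] && times[i] < times[j]
            comparable += 1
            if risk[i] > risk[j]
                concordant += 1.0
            elseif risk[i] == risk[j]
                concordant += 0.5   # tied predictions count as half
            end
        end
    end
    return concordant / comparable
end

risk  = [0.9, 0.7, 0.6, 0.1]
died  = [true, false, true, false]
times = [5.0, 20.0, 12.0, 30.0]
println(harrell_c(risk, died, times))
```

Here four of the five comparable pairs are ordered correctly. The discordant pair is patient 3 (died at 12) against patient 2 (censored at 20, higher risk).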

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

IAI.score(grid, test_X, test_died, test_times, criterion=:auc,
          evaluation_time=10)
0.7750750750750736
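Scoring with `criterion=:auc` and `evaluation_time=10` treats the problem as a classification problem at month 10. The sketch below uses one common cumulative/dynamic definition, which is an assumption about what IAI does. Patients who died by month 10 are positives. Patients still under observation after month 10 are negatives. Patients censored before month 10 are dropped because their status at month 10 is unknown. IAI may also reweight for censoring. The data are made up.

```julia
# AUC at a fixed evaluation time t (cumulative/dynamic definition):
# cases died at or before t, controls were still observed after t,
# and subjects censored before t are excluded.
function auc_at(pred::Vector{Float64}, died::Vector{Bool}, times::Vector{Float64}, t::Real)
    cases    = [pred[i] for i in eachindex(pred) if died[i] && times[i] <= t]
    controls = [pred[i] for i in eachindex(pred) if times[i] > t]
    wins = 0.0
    for p in cases, q in controls
        # A case should have higher predicted mortality than a control.
        wins += (p > q ? 1.0 : (p == q ? 0.5 : 0.0))
    end
    return wins / (length(cases) * length(controls))
end

pred  = [0.40, 0.10, 0.35, 0.05, 0.38]   # predicted mortality at t = 10
died  = [true, false, true, true, false]
times = [4.0, 3.0, 9.0, 25.0, 18.0]
println(auc_at(pred, died, times, 10))
```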

We can also look at the variable importance:

IAI.variable_importance(IAI.get_learner(grid))
5×2 DataFrame
 Row │ Feature  Importance
     │ Symbol   Float64
─────┼─────────────────────
   1 │ age       0.62681
   2 │ hgb       0.187153
   3 │ creat     0.109443
   4 │ mspike    0.0489006
   5 │ sex       0.0276938

XGBoost Survival Learner

We will use a GridSearch to fit an XGBoostSurvivalLearner with some basic parameter validation:

grid = IAI.GridSearch(
    IAI.XGBoostSurvivalLearner(
        random_seed=1,
    ),
    max_depth=2:5,
    num_round=[20, 50, 100],
)
IAI.fit!(grid, train_X, train_died, train_times)
All Grid Results:

 Row │ num_round  max_depth  train_score  valid_score  rank_valid_score
     │ Int64      Int64      Float64      Float64      Int64
─────┼──────────────────────────────────────────────────────────────────
   1 │        20          2    0.215788     0.15041                   1
   2 │        20          3    0.194945    -0.0670844                 3
   3 │        20          4    0.183779    -0.121177                  4
   4 │        20          5    0.0916702   -0.401915                  7
   5 │        50          2    0.210932     0.0262287                 2
   6 │        50          3    0.153306    -0.217315                  6
   7 │        50          4    0.0281122   -0.513705                  9
   8 │        50          5   -0.180287    -0.87479                  10
   9 │       100          2    0.182667    -0.132212                  5
  10 │       100          3    0.0250015   -0.512682                  8
  11 │       100          4   -0.345342    -1.03887                  11
  12 │       100          5   -0.907852    -1.75593                  12

Best Params:
  num_round => 20
  max_depth => 2

Best Model - Fitted XGBoostSurvivalLearner
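The grid is the Cartesian product of the two parameter ranges. That gives 4 × 3 = 12 combinations, one per row of the results table. The sketch below enumerates them in the table's order and selects the combination with the highest validation score. The scores are copied from the table above, and higher is assumed to be better.

```julia
max_depths = 2:5
num_rounds = [20, 50, 100]

# All combinations, num_round in the outer loop as in the results table.
combos = [(nr, md) for nr in num_rounds for md in max_depths]

# Validation scores from the results table, in the same row order.
valid = [0.15041, -0.0670844, -0.121177, -0.401915,
         0.0262287, -0.217315, -0.513705, -0.87479,
         -0.132212, -0.512682, -1.03887, -1.75593]

best = combos[argmax(valid)]   # (num_round, max_depth)
println(length(combos), " combinations, best = ", best)
```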

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

pred_curves = IAI.predict(grid, test_X)
t = 10
[c[t] for c in pred_curves]
415-element Vector{Float64}:
 0.14218
 0.19365
 0.10101
 0.21149
 0.21305
 0.06571
 0.11277
 0.13989
 0.09729
 0.06921
 ⋮
 0.04348
 0.14986
 0.08444
 0.14099
 0.01246
 0.1504
 0.21825
 0.12885
 0.22829

Alternatively, we can get this mortality probability for any given time t by providing t as a keyword argument directly:

IAI.predict(grid, test_X, t=10)
415-element Vector{Float64}:
 0.14218
 0.19365
 0.10101
 0.21149
 0.21305
 0.06571
 0.11277
 0.13989
 0.09729
 0.06921
 ⋮
 0.04348
 0.14986
 0.08444
 0.14099
 0.01246
 0.1504
 0.21825
 0.12885
 0.22829

We can also estimate the expected survival time for each point:

IAI.predict_expected_survival_time(grid, test_X)
415-element Vector{Float64}:
 103.14
  73.356
 143.604
  66.184
  65.612
 200.331
 129.869
 104.9
 148.39
 193.413
   ⋮
 253.328
  97.563
 166.9
 104.048
 363.12
  97.19
  63.766
 114.09
  60.43

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

IAI.score(grid, train_X, train_died, train_times,
          criterion=:harrell_c_statistic)
0.71899034

Or on the test set:

IAI.score(grid, test_X, test_died, test_times, criterion=:harrell_c_statistic)
0.70912965

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

IAI.score(grid, test_X, test_died, test_times, criterion=:auc,
          evaluation_time=10)
0.68960961

We can also look at the variable importance:

IAI.variable_importance(IAI.get_learner(grid))
5×2 DataFrame
 Row │ Feature  Importance
     │ Symbol   Float64
─────┼─────────────────────
   1 │ age        0.427724
   2 │ creat      0.200167
   3 │ hgb        0.154691
   4 │ sex=M      0.115374
   5 │ mspike     0.102045

We can calculate the SHAP values:

s = IAI.predict_shap(grid, test_X)
s[:shap_values]
415×5 Matrix{Float64}:
  0.61058   0.30582  -0.04894  -0.01826  -0.14742
  1.07476   0.16585  -0.04894  -0.00356  -0.14742
  0.23345   0.18522   0.05296  -0.01366  -0.12105
  1.4023   -0.05959  -0.05695  -0.01366  -0.13249
  0.69856   0.2627    0.04495  -0.02039   0.16211
  0.18722  -0.14855  -0.04093  -0.00356  -0.10612
 -0.18245   0.42472   0.12438  -0.01668   0.10355
  0.77792   0.12435  -0.05695  -0.01366  -0.14742
  0.13403   0.33627  -0.03813  -0.01366  -0.12105
  0.24808  -0.14301  -0.04354  -0.01366  -0.10612
  ⋮
 -0.38698  -0.15637  -0.05912  -0.01668   0.08266
  0.63852   0.2364    0.04495  -0.01366  -0.14742
  0.15582  -0.16502   0.05296  -0.00568   0.11068
 -0.39859   0.50849   0.43877   0.32741  -0.18338
 -1.78685  -0.19225   0.18891  -0.08385   0.07211
  0.73717  -0.14956   0.03954  -0.00568   0.14121
  1.07879  -0.07867   0.03954  -0.00568   0.14121
  0.6234   -0.16209   0.03954  -0.01578   0.11068
 -0.26166   0.29492   0.19171   0.81775   0.18365
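SHAP values are additive. For each point, a baseline (the average model output) plus that point's SHAP values sums to the model's output for the point. For the XGBoost learner, this output is presumably on the raw (margin) scale rather than a probability. The property is easiest to check for a linear model. Assuming independent features, the exact SHAP value of feature j is w_j * (x_j - mean(x_j)). The weights and data below are hypothetical.

```julia
using Statistics

# Hypothetical linear model f(x) = b + w'x.
b = -0.5
w = [0.06, 0.3]
X = [70.0 1.0; 55.0 0.0; 82.0 1.0]   # rows = points, columns = features

f(x) = b + sum(w .* x)
mu = vec(mean(X, dims=1))            # per-feature means
phi = (X .- mu') .* w'               # SHAP values, one row per point

# Additivity: the baseline (mean prediction) plus a row's SHAP values
# recovers the model output for that row.
baseline = f(mu)
for i in axes(X, 1)
    println(round(baseline + sum(phi[i, :]), digits=6), " vs ", round(f(X[i, :]), digits=6))
end
```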

We can then use the SHAP library in Python to visualize these results in whichever way we prefer. For example, the following code creates a summary plot:

using PyCall
shap = pyimport("shap")
shap.summary_plot(s[:shap_values], Matrix(s[:features]), names(s[:features]))
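A common global summary of a SHAP matrix is the mean absolute SHAP value of each column. The bar variant of the SHAP summary plot (`plot_type="bar"`) displays this quantity. The sketch below computes it for only the first three rows printed above, so the numbers are illustrative rather than a summary of the full test set. It assumes the columns follow the order of `s[:features]`.

```julia
using Statistics

# First three rows of the SHAP matrix printed above.
phi = [ 0.61058  0.30582 -0.04894 -0.01826 -0.14742;
        1.07476  0.16585 -0.04894 -0.00356 -0.14742;
        0.23345  0.18522  0.05296 -0.01366 -0.12105]

# Mean absolute SHAP value per feature (column): a global importance measure.
global_importance = vec(mean(abs.(phi), dims=1))
println(round.(global_importance, digits=4))
```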

GLMNetCV Survival Learner

We will fit a GLMNetCVSurvivalLearner directly. Since GLMNetCVSurvivalLearner does not support missing data, we first subset both the training and test sets to complete cases:

train_cc = completecases(train_X)
train_X_cc = train_X[train_cc, :]
train_died_cc = train_died[train_cc]
train_times_cc = train_times[train_cc]

test_cc = completecases(test_X)
test_X_cc = test_X[test_cc, :]
test_died_cc = test_died[test_cc]
test_times_cc = test_times[test_cc]
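`completecases` is a DataFrames.jl function that returns a Bool vector marking rows with no missing values. The sketch below builds the same kind of mask with Base Julia on a small hypothetical dataset. It shows why the mask is applied to both the features and the outcome vectors: doing so keeps them aligned row for row.

```julia
# Hypothetical feature columns containing missing entries.
hgb   = [12.0, missing, 11.0, 14.5]
creat = [1.1, 1.0, missing, 0.9]
age   = [80, 72, 91, 65]
died  = [true, false, true, false]

# A row is a complete case when none of its feature values is missing.
cc = .!(ismissing.(hgb) .| ismissing.(creat) .| ismissing.(age))

# Apply the same mask to features and outcomes so the rows stay aligned.
age_cc, died_cc = age[cc], died[cc]
println(cc)
println(age_cc, " ", died_cc)
```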

We can now proceed with fitting the learner:

lnr = IAI.GLMNetCVSurvivalLearner(random_seed=1)
IAI.fit!(lnr, train_X_cc, train_died_cc, train_times_cc)
Fitted GLMNetCVSurvivalLearner:
  Constant: -3.319
  Weights:
    age:    0.0578244
    creat:  0.18159
    hgb:   -0.0889831
    sex=M:  0.300142
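The fitted model is linear in the features. The output above does not state the model family. Assuming it is a Cox-type proportional-hazards model (glmnet's survival family is Cox), the weights act on the log-hazard scale. A patient's linear predictor is the constant plus the weighted sum of the features. exp(weight) is the hazard ratio for a one-unit increase in that feature. The sketch below applies the printed weights to a hypothetical patient. `mspike` is omitted because its weight is zero.

```julia
# Constant and weights printed for the fitted learner above.
constant = -3.319
weights = Dict("age" => 0.0578244, "creat" => 0.18159,
               "hgb" => -0.0889831, "sex=M" => 0.300142)

# Hypothetical patient: 70-year-old man, creatinine 1.2, haemoglobin 13.0.
x = Dict("age" => 70.0, "creat" => 1.2, "hgb" => 13.0, "sex=M" => 1.0)

# Linear predictor: the constant plus the weighted sum of the features.
lp = constant + sum(weights[k] * x[k] for k in keys(weights))
println("linear predictor: ", round(lp, digits=4))

# Under the proportional-hazards assumption, exp(weight) is a hazard ratio,
# e.g. men versus women with all other features held fixed.
println("hazard ratio for sex=M: ", round(exp(weights["sex=M"]), digits=3))
```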

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

pred_curves = IAI.predict(lnr, test_X_cc)
t = 10
[c[t] for c in pred_curves]
401-element Vector{Float64}:
 0.200607896718
 0.370165075078
 0.122585629449
 0.258150249738
 0.263183570494
 0.097649876452
 0.130211918883
 0.165323978367
 0.142078644844
 0.115488054643
 ⋮
 0.288921901855
 0.077102775305
 0.225010337685
 0.132377147815
 0.032757577054
 0.310343253654
 0.295608518778
 0.191789512261
 0.123660862616

Alternatively, we can get this mortality probability for any given time t by providing t as a keyword argument directly:

IAI.predict(lnr, test_X_cc, t=10)
401-element Vector{Float64}:
 0.200607896718
 0.370165075078
 0.122585629449
 0.258150249738
 0.263183570494
 0.097649876452
 0.130211918883
 0.165323978367
 0.142078644844
 0.115488054643
 ⋮
 0.288921901855
 0.077102775305
 0.225010337685
 0.132377147815
 0.032757577054
 0.310343253654
 0.295608518778
 0.191789512261
 0.123660862616

We can also estimate the expected survival time for each point:

IAI.predict_expected_survival_time(lnr, test_X_cc)
401-element Vector{Float64}:
  69.71314448
  32.21543141
 118.46926932
  51.54736223
  50.33089515
 146.17629908
 111.58709637
  86.76397257
 102.05616823
 125.48250448
   ⋮
  44.76791753
 176.94796956
  60.89939196
 109.74632954
 283.63478532
  40.83477821
  43.47980691
  73.42616111
 117.45965404

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

IAI.score(lnr, train_X_cc, train_died_cc, train_times_cc,
          criterion=:harrell_c_statistic)
0.6837433757104808

Or on the test set:

IAI.score(lnr, test_X_cc, test_died_cc, test_times_cc,
          criterion=:harrell_c_statistic)
0.7027583847744624

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

IAI.score(lnr, test_X_cc, test_died_cc, test_times_cc, criterion=:auc,
          evaluation_time=10)
0.6823274764451254

We can also look at the variable importance:

IAI.variable_importance(lnr)
5×2 DataFrame
 Row │ Feature  Importance
     │ Symbol   Float64
─────┼─────────────────────
   1 │ age        0.59208
   2 │ hgb        0.153895
   3 │ sex=M      0.127414
   4 │ creat      0.12661
   5 │ mspike     0.0