Quick Start Guide: Heuristic Survival Learners

In this example we will use survival learners from Heuristics on the Monoclonal Gammopathy dataset to predict patient survival time (refer to the data preparation guide to learn more about the data format for survival problems). First, we load the data and split it into training and test datasets:

using IAI, CSV, DataFrames
df = CSV.read("mgus2.csv", DataFrame, missingstring="NA", pool=true)
1384×11 DataFrame
  Row │ Column1  id     age    sex     hgb       creat     mspike    ptime  ps ⋯
      │ Int64    Int64  Int64  String  Float64?  Float64?  Float64?  Int64  In ⋯
──────┼─────────────────────────────────────────────────────────────────────────
    1 │       1      1     88  F           13.1       1.3       0.5     30     ⋯
    2 │       2      2     78  F           11.5       1.2       2.0     25
    3 │       3      3     94  M           10.5       1.5       2.6     46
    4 │       4      4     68  M           15.2       1.2       1.2     92
    5 │       5      5     90  F           10.7       0.8       1.0      8     ⋯
    6 │       6      6     90  M           12.9       1.0       0.5      4
    7 │       7      7     89  F           10.5       0.9       1.3    151
    8 │       8      8     87  F           12.3       1.2       1.6      2
  ⋮   │    ⋮       ⋮      ⋮      ⋮        ⋮         ⋮         ⋮        ⋮       ⋱
 1378 │    1378   1378     56  M           16.1       0.8       0.5     59     ⋯
 1379 │    1379   1379     73  M           15.6       1.1       0.5     48
 1380 │    1380   1380     69  M           15.0       0.8       0.0     22
 1381 │    1381   1381     78  M           14.1       1.3       1.9     35
 1382 │    1382   1382     66  M           12.1       2.0       0.0     31     ⋯
 1383 │    1383   1383     82  F           11.5       1.1       2.3     38
 1384 │    1384   1384     79  M            9.6       1.1       1.7      6
                                                 3 columns and 1369 rows omitted
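
Note the Float64? column types: hgb, creat and mspike contain missing values (read in via missingstring="NA"), which is why we set missingdatamode=:separate_class for the random forest below. As a quick check (not part of the original example), we can count the missing entries in those columns:

# Count missing entries in each feature column with missing support
[count(ismissing, df[!, c]) for c in ["hgb", "creat", "mspike"]]
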
X = df[:, 3:(end - 4)]
died = df.death .== 1
times = df.futime
(train_X, train_died, train_times), (test_X, test_died, test_times) =
    IAI.split_data(:survival, X, died, times, seed=12345)
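
As a sanity check on the split (the counts follow from the 1384 rows above and the 415-element prediction vectors shown below), we can compare the resulting sizes; by default split_data reserves roughly 30% of the points for testing:

size(train_X, 1), size(test_X, 1)
(969, 415)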

Random Forest Survival Learner

We will use a GridSearch to fit a RandomForestSurvivalLearner with some basic parameter validation:

grid = IAI.GridSearch(
    IAI.RandomForestSurvivalLearner(
        missingdatamode=:separate_class,
        random_seed=1,
    ),
    max_depth=5:10,
)
IAI.fit!(grid, train_X, train_died, train_times)
All Grid Results:

 Row │ max_depth  train_score  valid_score  rank_valid_score
     │ Int64      Float64      Float64      Int64
─────┼───────────────────────────────────────────────────────
   1 │         5     0.291592     0.216862                 1
   2 │         6     0.310289     0.212261                 2
   3 │         7     0.322973     0.206279                 3
   4 │         8     0.329134     0.202891                 4
   5 │         9     0.331598     0.20266                  5
   6 │        10     0.332294     0.202477                 6

Best Params:
  max_depth => 5

Best Model - Fitted RandomForestSurvivalLearner
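
The validated parameters can also be retrieved programmatically with get_best_params, and the underlying fitted learner with get_learner (used later for variable importance):

# Returns the best parameter combination found by the grid search
IAI.get_best_params(grid)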

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

pred_curves = IAI.predict(grid, test_X)
t = 10
[c[t] for c in pred_curves]
415-element Vector{Float64}:
 0.187396493728
 0.212154980723
 0.092591201757
 0.112061137082
 0.215569821689
 0.065290983536
 0.212558861798
 0.183942502458
 0.135830805771
 0.059289362626
 ⋮
 0.044168066296
 0.210253885428
 0.087656625803
 0.253140933592
 0.016125002538
 0.102207823668
 0.14985600766
 0.082458746894
 0.175596007719

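Since each element of pred_curves is a full SurvivalCurve, a single curve can also be queried at several horizons at once (the times below are arbitrary illustrations):

curve = pred_curves[1]
# Mortality probabilities at 10, 50 and 100 months for the first test point
[curve[t] for t in (10, 50, 100)]
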
Alternatively, we can get this mortality probability for any given time t by providing t as a keyword argument directly:

IAI.predict(grid, test_X, t=10)
415-element Vector{Float64}:
 0.187396493728
 0.212154980723
 0.092591201757
 0.112061137082
 0.215569821689
 0.065290983536
 0.212558861798
 0.183942502458
 0.135830805771
 0.059289362626
 ⋮
 0.044168066296
 0.210253885428
 0.087656625803
 0.253140933592
 0.016125002538
 0.102207823668
 0.14985600766
 0.082458746894
 0.175596007719

We can also estimate the expected survival time for each point:

IAI.predict_expected_survival_time(grid, test_X)
415-element Vector{Float64}:
  62.960069798779
  60.451911795869
 128.122472190415
  80.415385085702
  58.226319438976
 135.239279269937
 109.304740210212
  68.303468814585
 133.166596426567
 115.056024736284
   ⋮
 182.832081869042
  59.377709561932
 127.219213341943
  86.951384261595
 313.631215413167
  77.03670255963
  56.40262151127
  86.157655986924
 116.585871130823

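Survival learners additionally expose predict_hazard, which returns a relative hazard estimate for each point, with higher values corresponding to shorter predicted survival (assuming this function is available in your IAI version):

# Relative hazard estimate per point; higher means higher predicted risk
IAI.predict_hazard(grid, test_X)
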
We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

IAI.score(grid, train_X, train_died, train_times,
          criterion=:harrell_c_statistic)
0.728648073106186

Or on the test set:

IAI.score(grid, test_X, test_died, test_times, criterion=:harrell_c_statistic)
0.7130325263974928

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

IAI.score(grid, test_X, test_died, test_times, criterion=:auc,
          evaluation_time=10)
0.7750750750750736
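
To see how the discrimination of the model varies across the follow-up period, we can sweep the evaluation time (the times here are arbitrary illustrations):

# Test-set AUC at 10, 50 and 100 months
[IAI.score(grid, test_X, test_died, test_times, criterion=:auc,
           evaluation_time=t) for t in (10, 50, 100)]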

We can also look at the variable importance:

IAI.variable_importance(IAI.get_learner(grid))
5×2 DataFrame
 Row │ Feature  Importance
     │ Symbol   Float64
─────┼─────────────────────
   1 │ age       0.62681
   2 │ hgb       0.187153
   3 │ creat     0.109443
   4 │ mspike    0.0489006
   5 │ sex       0.0276938
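
Since variable_importance returns an ordinary DataFrame, it can be post-processed as usual, for example keeping only features above an (arbitrary) importance threshold:

imp = IAI.variable_importance(IAI.get_learner(grid))
# Keep features contributing more than 10% of total importance
imp[imp.Importance .> 0.1, :]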

XGBoost Survival Learner

We will use a GridSearch to fit an XGBoostSurvivalLearner with some basic parameter validation:

grid = IAI.GridSearch(
    IAI.XGBoostSurvivalLearner(
        random_seed=1,
    ),
    max_depth=2:5,
    num_round=[20, 50, 100],
)
IAI.fit!(grid, train_X, train_died, train_times)
All Grid Results:

 Row │ num_round  max_depth  train_score  valid_score  rank_valid_score
     │ Int64      Int64      Float64      Float64      Int64
─────┼──────────────────────────────────────────────────────────────────
   1 │        20          2    0.221512    0.164622                   1
   2 │        20          3    0.175214   -0.134411                   4
   3 │        20          4    0.143079   -0.120157                   3
   4 │        20          5    0.0489262  -0.474636                   7
   5 │        50          2    0.196386   -0.00737584                 2
   6 │        50          3    0.138169   -0.307373                   6
   7 │        50          4   -0.096126   -0.571589                   8
   8 │        50          5   -0.455181   -1.27913                   11
   9 │       100          2    0.11895    -0.226264                   5
  10 │       100          3   -0.122828   -0.851117                   9
  11 │       100          4   -0.558718   -1.20169                   10
  12 │       100          5   -1.32397    -2.35255                   12

Best Params:
  num_round => 20
  max_depth => 2

Best Model - Fitted XGBoostSurvivalLearner

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

pred_curves = IAI.predict(grid, test_X)
t = 10
[c[t] for c in pred_curves]
415-element Vector{Float64}:
 0.146602006944
 0.181972807304
 0.089343189025
 0.233761179821
 0.210810046477
 0.068637260518
 0.111577430497
 0.141948041866
 0.092365925364
 0.071650634104
 ⋮
 0.041039651399
 0.146602006944
 0.081671053483
 0.104403941284
 0.011575818595
 0.144190738587
 0.171362167869
 0.136016020377
 0.19172853442

Alternatively, we can get this mortality probability for any given time t by providing t as a keyword argument directly:

IAI.predict(grid, test_X, t=10)
415-element Vector{Float64}:
 0.146602006944
 0.181972807304
 0.089343189025
 0.233761179821
 0.210810046477
 0.068637260518
 0.111577430497
 0.141948041866
 0.092365925364
 0.071650634104
 ⋮
 0.041039651399
 0.146602006944
 0.081671053483
 0.104403941284
 0.011575818595
 0.144190738587
 0.171362167869
 0.136016020377
 0.19172853442

We can also estimate the expected survival time for each point:

IAI.predict_expected_survival_time(grid, test_X)
415-element Vector{Float64}:
  99.870825314231
  78.7679563508
 159.460151822213
  58.731723642278
  66.436568115034
 194.523130853628
 131.166829046173
 103.316586760024
 155.11382620326
 188.780862231699
   ⋮
 260.262232232197
  99.870825314231
 171.314554719091
 139.421343481108
 367.082586212797
 101.633026339217
  84.279003600824
 107.98765306547
  74.202136506095

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

IAI.score(grid, train_X, train_died, train_times,
          criterion=:harrell_c_statistic)
0.7193754134965267

Or on the test set:

IAI.score(grid, test_X, test_died, test_times, criterion=:harrell_c_statistic)
0.7093859379347714

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

IAI.score(grid, test_X, test_died, test_times, criterion=:auc,
          evaluation_time=10)
0.689069069069069

We can also look at the variable importance. Note that sex now appears as sex_M, since the XGBoost learner works with a one-hot numeric encoding of categorical features:

IAI.variable_importance(IAI.get_learner(grid))
5×2 DataFrame
 Row │ Feature  Importance
     │ Symbol   Float64
─────┼─────────────────────
   1 │ age       0.655875
   2 │ hgb       0.148613
   3 │ creat     0.100927
   4 │ mspike    0.0557582
   5 │ sex_M     0.0388273

We can calculate the SHAP values:

s = IAI.predict_shap(grid, test_X)
s[:shap_values]
415×5 Matrix{Float64}:
  0.640098   0.297386   -0.0352981   -0.022563    -0.145025
  0.979197   0.176634   -0.0352981   -0.00424254  -0.145025
  0.212314   0.136468   -0.00737888  -0.014659    -0.119171
  1.37775    0.0550153  -0.0434586   -0.0128254   -0.123347
  0.68526    0.275957    0.0332857   -0.0252807    0.166434
  0.201457  -0.130303   -0.0384152   -0.0024089   -0.0974934
 -0.165168   0.449972    0.0632479   -0.0148529    0.108749
  0.763955   0.138876   -0.0434586   -0.014659    -0.145025
  0.119493   0.291889   -0.0350694   -0.014659    -0.119171
  0.255614  -0.127896   -0.0399963   -0.0128254   -0.0974934
  ⋮
 -0.427469  -0.173282   -0.057895    -0.0130192    0.0757421
  0.662603   0.23594    -0.00426181  -0.014659    -0.145025
  0.173659  -0.14895    -0.00977322  -0.00512661   0.103846
 -0.487831   0.518108    0.456193     0.0483653   -0.163283
 -1.74569   -0.185214    0.0623039   -0.0837539    0.0757421
  0.713682  -0.152875    0.0275326   -0.00512661   0.133427
  0.878975  -0.129856    0.0275326   -0.00512661   0.133427
  0.64915   -0.152875    0.0690641   -0.0155431    0.103846
 -0.353826   0.312178    0.220955     0.68207      0.167903

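If your IAI version also returns the baseline under s[:expected_value], the decomposition can be sanity-checked by summing each row of contributions; under the usual XGBoost SHAP convention this reconstructs the raw margin prediction for each point:

# Baseline plus per-feature contributions for each test point
s[:expected_value] .+ vec(sum(s[:shap_values], dims=2))
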
We can then use the SHAP library in Python to visualize these results in whichever way we prefer. For example, the following code creates a summary plot:

using PyCall
shap = pyimport("shap")
shap.summary_plot(s[:shap_values], Matrix(s[:features]), names(s[:features]))
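
When running non-interactively, the plot can be saved to disk rather than displayed; this sketch assumes matplotlib is available alongside shap, and the output filename is a hypothetical choice:

plt = pyimport("matplotlib.pyplot")
# show=false suppresses the interactive display so the figure can be saved
shap.summary_plot(s[:shap_values], Matrix(s[:features]), names(s[:features]),
                  show=false)
plt.savefig("shap_summary.png", bbox_inches="tight")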