Quick Start Guide: Heuristic Survival Learners

This is an R version of the corresponding Heuristics quick start guide.

In this example we will use survival learners from Heuristics on the Monoclonal Gammopathy dataset to predict patient survival time (refer to the data preparation guide to learn more about the data format for survival problems). First, we load the data and split it into training and test sets:

df <- read.table("mgus2.csv", sep = ",", header = TRUE, stringsAsFactors = TRUE)
df
  X id age sex  hgb creat mspike ptime pstat futime death
1 1  1  88   F 13.1   1.3    0.5    30     0     30     1
2 2  2  78   F 11.5   1.2    2.0    25     0     25     1
3 3  3  94   M 10.5   1.5    2.6    46     0     46     1
4 4  4  68   M 15.2   1.2    1.2    92     0     92     1
5 5  5  90   F 10.7   0.8    1.0     8     0      8     1
 [ reached 'max' / getOption("max.print") -- omitted 1379 rows ]
X <- df[, 3:7]
died <- df$death == 1
times <- df$futime
split <- iai::split_data("survival", X, died, times, seed = 12345)
train_X <- split$train$X
train_died <- split$train$deaths
train_times <- split$train$times
test_X <- split$test$X
test_died <- split$test$deaths
test_times <- split$test$times
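To show what a seeded train/test split of survival data involves, here is a minimal base-R sketch. It assumes a plain 70/30 random split and uses a hypothetical helper, simple_split, that only mimics the shape of the list returned by iai::split_data above. It is not how iai::split_data is implemented, which may differ, for example in its default proportion or in any stratification.

```r
# Hypothetical helper (not part of iai): seeded random split of aligned survival data
simple_split <- function(X, died, times, train_frac = 0.7, seed = 12345) {
  set.seed(seed)
  n <- nrow(X)
  train_idx <- sample.int(n, size = round(train_frac * n))
  # Index X, died and times with the same rows so they stay aligned
  list(
    train = list(X = X[train_idx, , drop = FALSE],
                 deaths = died[train_idx], times = times[train_idx]),
    test = list(X = X[-train_idx, , drop = FALSE],
                deaths = died[-train_idx], times = times[-train_idx])
  )
}

# Toy data: 10 patients, follow-up in months
X <- data.frame(age = 60:69, hgb = seq(10, 14.5, by = 0.5))
died <- rep(c(TRUE, FALSE), 5)
times <- (1:10) * 12

split <- simple_split(X, died, times)
nrow(split$train$X)  # 7
nrow(split$test$X)   # 3
```

The key design point is that features, death indicators and times are indexed with the same row positions, so each patient's record stays intact after the split.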

Random Forest Survival Learner

We will use a grid_search to fit a random_forest_survival_learner:

grid <- iai::grid_search(
    iai::random_forest_survival_learner(
        missingdatamode = "separate_class",
        random_seed = 1
    ),
    max_depth = 5:10
)
iai::fit(grid, train_X, train_died, train_times)
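The missingdatamode = "separate_class" setting matters because some rows in this dataset have missing feature values. A quick base-R way to see how much is missing is sketched below on a small hypothetical data frame rather than the real mgus2 columns:

```r
# Toy stand-in for a feature table with gaps
X <- data.frame(
  age   = c(88, 78, 94, 68),
  hgb   = c(13.1, NA, 10.5, NA),
  creat = c(1.3, 1.2, NA, 1.2)
)

# Number of missing values in each column
n_missing <- colSums(is.na(X))
print(n_missing)

# Fraction of rows with no missing values at all
frac_complete <- mean(complete.cases(X))
print(frac_complete)
```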

We can make predictions on new data using predict. For survival learners, this returns the appropriate survival curve for each point, which we can then use to query the mortality probability (the probability of death by time t) for any given time t (in this case for t = 10):

pred_curves <- iai::predict(grid, test_X)
t <- 10
pred_curves[[1]][t]
[1] 0.1873965

You can also query the probability for all of the points together by applying the [ function to each curve:

sapply(pred_curves, `[`, t)
 [1] 0.18739649 0.21215498 0.09259120 0.11206114 0.21556982 0.06529098
 [7] 0.21255886 0.18394250 0.13583081 0.05928936 0.06964021 0.05761070
[13] 0.17592700 0.09695835 0.05083491 0.21279057 0.21020701 0.13996345
[19] 0.04926693 0.07267342 0.21694837 0.18411062 0.07030906 0.19655744
[25] 0.08155705 0.19667612 0.16373136 0.09083397 0.07260804 0.03996479
[31] 0.03565278 0.06140443 0.04793617 0.19591522 0.07673138 0.19419976
[37] 0.06210086 0.01889528 0.05808891 0.07176119 0.02248062 0.08093347
[43] 0.06311591 0.21353333 0.04997100 0.08071766 0.06164327 0.05396804
[49] 0.05687817 0.08088472 0.01339789 0.09761082 0.04075485 0.05751192
[55] 0.06222023 0.21530971 0.14798813 0.23725987 0.05334866 0.15878501
 [ reached getOption("max.print") -- omitted 355 entries ]
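The sapply call above simply invokes the indexing function [ on each list element, passing t as the extra argument. The sketch below shows the same pattern on plain numeric vectors, which are only stand-ins: position k in each toy vector plays the role of time k, whereas the iai curve objects are queried by time. vapply is shown as a type-checked alternative that guarantees exactly one number per curve.

```r
# Toy "curves": element k stands in for the mortality probability at time k
curves <- list(
  c(0.01, 0.03, 0.06),
  c(0.02, 0.05, 0.11),
  c(0.00, 0.01, 0.04)
)
t <- 2

# Apply `[` to every curve, passing t as the index
at_t <- sapply(curves, `[`, t)

# Same result, but vapply errors if any call does not return a single number
at_t_checked <- vapply(curves, `[`, numeric(1), t)
```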

Alternatively, you can get this mortality probability for any given time t by passing t directly as a named argument:

iai::predict(grid, test_X, t = 10)
 [1] 0.18739649 0.21215498 0.09259120 0.11206114 0.21556982 0.06529098
 [7] 0.21255886 0.18394250 0.13583081 0.05928936 0.06964021 0.05761070
[13] 0.17592700 0.09695835 0.05083491 0.21279057 0.21020701 0.13996345
[19] 0.04926693 0.07267342 0.21694837 0.18411062 0.07030906 0.19655744
[25] 0.08155705 0.19667612 0.16373136 0.09083397 0.07260804 0.03996479
[31] 0.03565278 0.06140443 0.04793617 0.19591522 0.07673138 0.19419976
[37] 0.06210086 0.01889528 0.05808891 0.07176119 0.02248062 0.08093347
[43] 0.06311591 0.21353333 0.04997100 0.08071766 0.06164327 0.05396804
[49] 0.05687817 0.08088472 0.01339789 0.09761082 0.04075485 0.05751192
[55] 0.06222023 0.21530971 0.14798813 0.23725987 0.05334866 0.15878501
 [ reached getOption("max.print") -- omitted 355 entries ]
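The values returned for a given t are mortality probabilities, i.e. the probability of death by time t, which equals one minus the survival probability S(t). The sketch below makes this concrete for a hypothetical step-shaped survival curve stored as breakpoints and interval values; this representation and the helpers survival_at and mortality_at are illustrative assumptions, not the internal structure of the iai curve objects.

```r
# Hypothetical step survival curve:
# S(t) = surv[k] for breaks[k] <= t < breaks[k + 1], and S(t) = 0 for t >= 30
breaks <- c(0, 10, 20, 30)
surv   <- c(1.0, 0.8, 0.5)

survival_at <- function(t) {
  stopifnot(t >= 0)
  k <- findInterval(t, breaks)  # index of the interval containing t
  if (k >= length(breaks)) 0 else surv[k]
}

# Mortality probability = probability of death by time t = 1 - S(t)
mortality_at <- function(t) 1 - survival_at(t)

mortality_at(15)  # 0.2
```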

We can also estimate the expected survival time for each point using predict_expected_survival_time:

iai::predict_expected_survival_time(grid, test_X)
 [1]  62.96007  60.45191 128.12247  80.41539  58.22632 135.23928 109.30474
 [8]  68.30347 133.16660 115.05602 126.94926  95.13604  62.95190  74.78489
[15] 135.56135  61.75748  60.32962  80.63109  96.98121 144.10354  58.41931
[22]  53.16931 121.19475  62.46134 126.53887 103.24951  54.78605 114.83856
[29] 177.45912 200.43259 208.13383 104.56134 205.81946 153.49698  95.57145
[36]  64.08807 142.50420 331.02697  94.14627  88.54103 345.24707  90.90750
[43] 147.78680  60.17856 181.83168 147.00933 156.06318 170.31862 116.44109
[50]  95.37139 354.85519 123.30760 279.52693 156.52047 145.26378  61.02193
[57] 107.51267 159.36992 176.68766 128.12151
 [ reached getOption("max.print") -- omitted 355 entries ]
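An expected survival time can be read as the area under the survival curve, E[T] = integral from 0 to infinity of S(t) dt. For a step-shaped curve that integral reduces to a sum of interval widths times the survival value on each interval. This minimal sketch uses the textbook definition on a hypothetical curve that drops to zero at its last breakpoint (an assumption); how iai treats the part of a curve beyond the last observed time is not covered in this guide.

```r
# Hypothetical step survival curve that falls to 0 at t = 30 (an assumption)
breaks <- c(0, 10, 20, 30)
surv   <- c(1.0, 0.8, 0.5)  # S(t) on [0, 10), [10, 20), [20, 30)

# Area under a step function: sum of (interval width) * (survival on that interval)
expected_survival_time <- function(breaks, surv) {
  sum(diff(breaks) * surv)
}

expected_survival_time(breaks, surv)  # 23

# Continuous check: for S(t) = exp(-t / 100) the mean survival time is 100
integrate(function(t) exp(-t / 100), lower = 0, upper = Inf)$value
```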

We can evaluate the quality of the model using score with any of the supported criteria. For example, Harrell's c-statistic on the training set:

iai::score(grid, train_X, train_died, train_times,
           criterion = "harrell_c_statistic")
[1] 0.7286481

Or on the test set:

iai::score(grid, test_X, test_died, test_times,
           criterion = "harrell_c_statistic")
[1] 0.7130325
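Harrell's c-statistic is the fraction of comparable patient pairs that the model orders correctly: a pair is comparable when the patient with the shorter follow-up time died, and it is concordant when that patient was also assigned the higher risk. Below is a minimal base-R sketch of that standard definition with two simplifications stated as assumptions: pairs with tied follow-up times are skipped, and tied risk scores count as one half. The function name harrell_c is hypothetical, and iai's own tie handling may differ.

```r
# Hypothetical helper: Harrell's c-statistic for right-censored data.
# risk: higher value = predicted to die sooner. Pairs with tied times are skipped.
harrell_c <- function(times, died, risk) {
  concordant <- 0
  comparable <- 0
  n <- length(times)
  for (i in seq_len(n)) {
    if (!died[i]) next  # a censored patient cannot be the earlier event
    for (j in seq_len(n)) {
      if (times[j] > times[i]) {  # j was still under follow-up when i died
        comparable <- comparable + 1
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1    # correctly ranked
        } else if (risk[i] == risk[j]) {
          concordant <- concordant + 0.5  # tied prediction
        }
      }
    }
  }
  concordant / comparable
}

times <- c(5, 12, 20, 31)
died  <- c(TRUE, FALSE, TRUE, TRUE)

harrell_c(times, died, risk = c(0.9, 0.6, 0.4, 0.1))  # 1
# With predicted survival times, use their negative as the risk score
harrell_c(times, died, risk = -c(40, 30, 80, 60))     # 0.5
```

A value of 0.5 corresponds to random ordering, which is why the scores of roughly 0.7 reported above indicate a useful but imperfect ranking of patients.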

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

iai::score(grid, test_X, test_died, test_times, criterion = "auc",
           evaluation_time = 10)
[1] 0.7750751
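Scoring at a fixed evaluation time turns the problem into binary classification. One common construction is sketched here as an assumption rather than as iai's documented procedure: patients who died by the evaluation time are positives, patients still under follow-up after it are negatives, and patients censored before it are dropped because their status is unknown. The AUC is then the probability that a random positive receives a higher predicted mortality than a random negative, with ties counted as one half. auc_at_time is a hypothetical name.

```r
# Hypothetical helper: AUC of predicted mortality at a fixed evaluation time t
auc_at_time <- function(times, died, mortality, t) {
  pos <- died & times <= t  # died by time t
  neg <- times > t          # still under follow-up after t
  # Patients censored at or before t are in neither group and are dropped
  p <- mortality[pos]
  q <- mortality[neg]
  # Share of (positive, negative) pairs ranked correctly, ties counted as 1/2
  mean(outer(p, q, ">") + 0.5 * outer(p, q, "=="))
}

times <- c(5, 8, 12, 15, 3)
died  <- c(TRUE, TRUE, FALSE, TRUE, FALSE)

auc_at_time(times, died, mortality = c(0.90, 0.70, 0.20, 0.10, 0.50), t = 10)  # 1
auc_at_time(times, died, mortality = c(0.15, 0.70, 0.20, 0.10, 0.50), t = 10)  # 0.75
```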

We can also look at the variable importance:

iai::variable_importance(iai::get_learner(grid))
  Feature Importance
1     age 0.62680953
2     hgb 0.18715335
3   creat 0.10944276
4  mspike 0.04890058
5     sex 0.02769378
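The importances in this table sum to one, so a cumulative sum shows how concentrated the model is on a few features. Using the values printed above, age and hgb together account for about 81% of the total:

```r
# Variable importances copied from the random forest output above
# (with live output: imp <- setNames(vi$Importance, vi$Feature),
#  where vi <- iai::variable_importance(iai::get_learner(grid)))
imp <- c(age = 0.62680953, hgb = 0.18715335, creat = 0.10944276,
         mspike = 0.04890058, sex = 0.02769378)

imp <- sort(imp, decreasing = TRUE)
cumulative <- cumsum(imp)
round(cumulative, 3)
```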

XGBoost Survival Learner

We will use a grid_search to fit an xgboost_survival_learner:

grid <- iai::grid_search(
    iai::xgboost_survival_learner(
        random_seed = 1
    ),
    max_depth = 2:5,
    num_round = c(20, 50, 100)
)
iai::fit(grid, train_X, train_died, train_times)
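This grid crosses four max_depth values with three num_round values, so 4 x 3 = 12 parameter combinations are compared. expand.grid reproduces that cross product in base R, which is a quick way to check the size of a grid before fitting. The order in which iai evaluates the combinations, and its validation scheme, are not described in this guide.

```r
# Cross product of the tuning values passed to iai::grid_search above
param_grid <- expand.grid(max_depth = 2:5, num_round = c(20, 50, 100))
nrow(param_grid)  # 12
head(param_grid, 4)
```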

We can make predictions on new data using predict. For survival learners, this returns the appropriate survival curve for each point, which we can then use to query the mortality probability (the probability of death by time t) for any given time t (in this case for t = 10):

pred_curves <- iai::predict(grid, test_X)
t <- 10
pred_curves[[1]][t]
[1] 0.1421811

You can also query the probability for all of the points together by applying the [ function to each curve:

sapply(pred_curves, `[`, t)
 [1] 0.14218 0.19365 0.10101 0.21149 0.21305 0.06571 0.11277 0.13989 0.09729
[10] 0.06921 0.06571 0.11189 0.27597 0.14269 0.06921 0.19365 0.19365 0.12828
[19] 0.10603 0.07439 0.19939 0.23700 0.06943 0.14218 0.07995 0.15077 0.27848
[28] 0.07995 0.04525 0.03647 0.03647 0.08860 0.03309 0.06607 0.10751 0.28532
[37] 0.06571 0.02038 0.10603 0.16009 0.00833 0.11189 0.05234 0.14986 0.05390
[46] 0.06739 0.05234 0.06252 0.06921 0.11344 0.00790 0.08444 0.01892 0.05306
[55] 0.06571 0.14986 0.14440 0.07789 0.05390 0.09612
 [ reached getOption("max.print") -- omitted 355 entries ]

Alternatively, you can get this mortality probability for any given time t by passing t directly as a named argument:

iai::predict(grid, test_X, t = 10)
 [1] 0.14218 0.19365 0.10101 0.21149 0.21305 0.06571 0.11277 0.13989 0.09729
[10] 0.06921 0.06571 0.11189 0.27597 0.14269 0.06921 0.19365 0.19365 0.12828
[19] 0.10603 0.07439 0.19939 0.23700 0.06943 0.14218 0.07995 0.15077 0.27848
[28] 0.07995 0.04525 0.03647 0.03647 0.08860 0.03309 0.06607 0.10751 0.28532
[37] 0.06571 0.02038 0.10603 0.16009 0.00833 0.11189 0.05234 0.14986 0.05390
[46] 0.06739 0.05234 0.06252 0.06921 0.11344 0.00790 0.08444 0.01892 0.05306
[55] 0.06571 0.14986 0.14440 0.07789 0.05390 0.09612
 [ reached getOption("max.print") -- omitted 355 entries ]

We can also estimate the expected survival time for each point using predict_expected_survival_time:

iai::predict_expected_survival_time(grid, test_X)
 [1] 103.140  73.356 143.604  66.184  65.612 200.331 129.869 104.900 148.390
[10] 193.413 200.331 130.828  47.879 102.751 193.413  73.356  73.356 114.597
[19] 137.486 183.769  70.912  57.762 192.987 103.140 174.146  96.940  47.335
[28] 174.146 248.469 273.908 273.908 160.553 284.628 199.604 135.759  45.907
[37] 200.331 330.071 137.486  90.852 382.014 130.828 230.173  97.563 226.385
[46] 196.969 230.173 206.961 193.413 129.143 384.077 166.900 335.877 228.409
[55] 200.331  97.563 101.480 177.629 226.385 149.958
 [ reached getOption("max.print") -- omitted 355 entries ]

We can evaluate the quality of the model using score with any of the supported criteria. For example, Harrell's c-statistic on the training set:

iai::score(grid, train_X, train_died, train_times,
           criterion = "harrell_c_statistic")
[1] 0.7189903

Or on the test set:

iai::score(grid, test_X, test_died, test_times,
           criterion = "harrell_c_statistic")
[1] 0.7091297

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

iai::score(grid, test_X, test_died, test_times, criterion = "auc",
           evaluation_time = 10)
[1] 0.6896096

We can also look at the variable importance:

iai::variable_importance(iai::get_learner(grid))
  Feature Importance
1     age   0.427724
2   creat   0.200167
3     hgb   0.154691
4   sex=M   0.115374
5  mspike   0.102045

We can calculate the SHAP values:

iai::predict_shap(grid, test_X)
$expected_value
[1] 0.04607966

$features
   age  hgb creat mspike sex=M
1   79  9.4  1.10    2.3     0
2   94 11.0  1.10    0.7     0
3   74 12.5  1.25    0.9     0
4   92 14.0  0.80    1.6     0
5   81 11.4  1.50    1.9     1
6   74 14.7  1.00    0.5     0
7   67 10.7  1.30    1.2     1
8   81 12.7  0.90    1.3     0
9   74  9.8  0.80    1.4     0
10  76 13.8  0.90    1.5     0
11  72 14.1  1.10    1.7     0
12  84 14.3  1.30    0.4     0
 [ reached 'max' / getOption("max.print") -- omitted 403 rows ]

$shap_values
           [,1]     [,2]     [,3]     [,4]     [,5]
  [1,]  0.61058  0.30582 -0.04894 -0.01826 -0.14742
  [2,]  1.07476  0.16585 -0.04894 -0.00356 -0.14742
  [3,]  0.23345  0.18522  0.05296 -0.01366 -0.12105
  [4,]  1.40230 -0.05959 -0.05695 -0.01366 -0.13249
  [5,]  0.69856  0.26270  0.04495 -0.02039  0.16211
  [6,]  0.18722 -0.14855 -0.04093 -0.00356 -0.10612
  [7,] -0.18245  0.42472  0.12438 -0.01668  0.10355
  [8,]  0.77792  0.12435 -0.05695 -0.01366 -0.14742
  [9,]  0.13403  0.33627 -0.03813 -0.01366 -0.12105
 [10,]  0.24808 -0.14301 -0.04354 -0.01366 -0.10612
 [11,]  0.20854 -0.15517 -0.04093 -0.01826 -0.10612
 [12,]  0.67713 -0.13543  0.03954 -0.00356 -0.13249
 [ reached getOption("max.print") -- omitted 403 rows ]

We can then use the SHAP library to visualize these results in whichever way we prefer.
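SHAP values are additive: for each row, expected_value plus the sum of that row's SHAP values should reproduce the model's raw output (for XGBoost, the untransformed margin). The sketch checks this on the first two rows printed above and also computes the mean absolute SHAP value per feature, a common global summary. Two assumptions apply: the unnamed columns of $shap_values follow the column order of $features, and the raw output is on the model's internal scale, which this guide does not specify.

```r
# First two rows of $shap_values and the $expected_value printed above.
# Assumption: columns follow the order of the $features table.
expected_value <- 0.04607966
shap_values <- rbind(
  c(0.61058, 0.30582, -0.04894, -0.01826, -0.14742),
  c(1.07476, 0.16585, -0.04894, -0.00356, -0.14742)
)
colnames(shap_values) <- c("age", "hgb", "creat", "mspike", "sex=M")

# Local accuracy: base value + row sum = raw model output for that row
raw_output <- expected_value + rowSums(shap_values)
raw_output  # about 0.748 and 1.087

# Mean absolute SHAP value per feature, a common global importance summary
mean_abs_shap <- colMeans(abs(shap_values))
sort(mean_abs_shap, decreasing = TRUE)
```

With live output the same check would be shap$expected_value + rowSums(shap$shap_values), where shap <- iai::predict_shap(grid, test_X).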

GLMNetCV Survival Learner

We will fit a glmnetcv_survival_learner directly. Since GLMNet does not support missing data, we first subset to complete cases:

train_cc <- complete.cases(train_X)
train_X_cc <- train_X[train_cc, ]
train_died_cc <- train_died[train_cc]
train_times_cc <- train_times[train_cc]

test_cc <- complete.cases(test_X)
test_X_cc <- test_X[test_cc, ]
test_died_cc <- test_died[test_cc]
test_times_cc <- test_times[test_cc]
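The same three subsetting steps are applied to both the training and test data, so they can be wrapped in a small helper that keeps features, death indicators and times aligned. drop_incomplete is a hypothetical name, shown here on a toy data frame:

```r
# Hypothetical helper: keep only rows of X with no missing values, plus the
# matching entries of the death indicator and follow-up times
drop_incomplete <- function(X, died, times) {
  keep <- complete.cases(X)
  list(X = X[keep, , drop = FALSE], died = died[keep], times = times[keep])
}

X <- data.frame(age = c(70, 81, 65, 77), hgb = c(12.0, NA, 14.2, 11.8))
died <- c(TRUE, FALSE, TRUE, FALSE)
times <- c(30, 25, 46, 92)

cc <- drop_incomplete(X, died, times)
nrow(cc$X)  # 3
```

Applied to this guide's data it would be called as drop_incomplete(train_X, train_died, train_times); drop = FALSE keeps a data frame even if X has a single column.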

We can now proceed with fitting the learner:

lnr <- iai::glmnetcv_survival_learner(random_seed = 1)
iai::fit(lnr, train_X_cc, train_died_cc, train_times_cc)
Julia Object of type GLMNetCVSurvivalLearner.
Fitted GLMNetCVSurvivalLearner:
  Constant: -3.319
  Weights:
    age:    0.0578244
    creat:  0.18159
    hgb:   -0.0889831
    sex=M:  0.300142
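The printed constant and weights define a linear risk score for each patient: the constant plus the sum of each weight times its feature value, with sex=M acting as a 0/1 indicator and mspike dropped because it received no weight. Assuming this learner is a Cox-type model (glmnet's survival family is the Cox model), the difference in scores between two patients is a log hazard ratio, so the constant cancels when patients are compared. The patient values below are hypothetical.

```r
# Coefficients copied from the fitted GLMNetCVSurvivalLearner output above
constant <- -3.319
weights <- c(age = 0.0578244, creat = 0.18159, hgb = -0.0889831, "sex=M" = 0.300142)

# Linear risk score: constant + sum of weight * feature value
risk_score <- function(patient) {
  constant + sum(weights[names(patient)] * patient)
}

# Two hypothetical patients that differ only in sex
patient_f <- c(age = 79, creat = 1.10, hgb = 9.4, "sex=M" = 0)
patient_m <- c(age = 79, creat = 1.10, hgb = 9.4, "sex=M" = 1)

risk_score(patient_f)                               # about 0.612
exp(risk_score(patient_m) - risk_score(patient_f))  # about 1.35 (hazard ratio)
```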

We can make predictions on new data using predict. For survival learners, this returns the appropriate survival curve for each point, which we can then use to query the mortality probability (the probability of death by time t) for any given time t (in this case for t = 10):

pred_curves <- iai::predict(lnr, test_X_cc)
t <- 10
pred_curves[[1]][t]
[1] 0.2006079

You can also query the probability for all of the points together by applying the [ function to each curve:

sapply(pred_curves, `[`, t)
 [1] 0.20060790 0.37016508 0.12258563 0.25815025 0.26318357 0.09764988
 [7] 0.13021192 0.16532398 0.14207864 0.11548805 0.09364096 0.18164974
[13] 0.19646153 0.22090925 0.11633882 0.32794891 0.35621026 0.13689713
[19] 0.18496651 0.11288777 0.29856028 0.32199141 0.11759813 0.23101774
[25] 0.14958820 0.22753554 0.36432456 0.14973809 0.04427567 0.06253306
[31] 0.05351038 0.13143858 0.04433383 0.14841218 0.17601814 0.10221974
[37] 0.02128404 0.13806728 0.13504396 0.03009027 0.15786415 0.08168353
[43] 0.27166587 0.07241175 0.11989440 0.08235410 0.08848366 0.12053551
[49] 0.18218162 0.01961504 0.12975188 0.03141674 0.08366623 0.09236169
[55] 0.23485295 0.15931575 0.06634960 0.06877971 0.11722330 0.10916430
 [ reached getOption("max.print") -- omitted 341 entries ]

Alternatively, you can get this mortality probability for any given time t by passing t directly as a named argument:

iai::predict(lnr, test_X_cc, t = 10)
 [1] 0.20060790 0.37016508 0.12258563 0.25815025 0.26318357 0.09764988
 [7] 0.13021192 0.16532398 0.14207864 0.11548805 0.09364096 0.18164974
[13] 0.19646153 0.22090925 0.11633882 0.32794891 0.35621026 0.13689713
[19] 0.18496651 0.11288777 0.29856028 0.32199141 0.11759813 0.23101774
[25] 0.14958820 0.22753554 0.36432456 0.14973809 0.04427567 0.06253306
[31] 0.05351038 0.13143858 0.04433383 0.14841218 0.17601814 0.10221974
[37] 0.02128404 0.13806728 0.13504396 0.03009027 0.15786415 0.08168353
[43] 0.27166587 0.07241175 0.11989440 0.08235410 0.08848366 0.12053551
[49] 0.18218162 0.01961504 0.12975188 0.03141674 0.08366623 0.09236169
[55] 0.23485295 0.15931575 0.06634960 0.06877971 0.11722330 0.10916430
 [ reached getOption("max.print") -- omitted 341 entries ]

We can also estimate the expected survival time for each point using predict_expected_survival_time:

iai::predict_expected_survival_time(lnr, test_X_cc)
 [1]  69.71314  32.21543 118.46927  51.54736  50.33090 146.17630 111.58710
 [8]  86.76397 102.05617 125.48250 151.52896  78.11343  71.41961  62.24844
[15] 124.60867  37.98087  33.97530 106.05376  76.52748 128.21215  42.92932
[22]  38.91239 123.33223  59.00796  96.67249  60.09240  32.93626  96.56967
[29] 248.91474 204.73586 225.09247 110.53838 248.75425  97.48532  80.93234
[36] 140.41078 324.85608 105.12962 107.54374 292.57871  91.24689 169.32196
[43]  48.38208 185.27838 121.05584 168.24554 158.85087 120.43181  77.85550
[50] 331.48611 111.98441 288.08520 166.16775 153.29812  57.85016  90.34568
[57] 196.89239 192.11650 123.71006 132.28116
 [ reached getOption("max.print") -- omitted 341 entries ]

We can evaluate the quality of the model using score with any of the supported criteria. For example, Harrell's c-statistic on the training set:

iai::score(lnr, train_X_cc, train_died_cc, train_times_cc,
           criterion = "harrell_c_statistic")
[1] 0.6837434

Or on the test set:

iai::score(lnr, test_X_cc, test_died_cc, test_times_cc,
           criterion = "harrell_c_statistic")
[1] 0.7027584

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

iai::score(lnr, test_X_cc, test_died_cc, test_times_cc, criterion = "auc",
           evaluation_time = 10)
[1] 0.6823275

We can also look at the variable importance:

iai::variable_importance(lnr)
  Feature Importance
1     age  0.5920800
2     hgb  0.1538955
3   sex=M  0.1274143
4   creat  0.1266102
5  mspike  0.0000000