Quick Start Guide: Heuristic Survival Learners

This is an R version of the corresponding Heuristics quick start guide.

In this example we will use survival learners from Heuristics on the Monoclonal Gammopathy dataset to predict patient survival time (refer to the data preparation guide to learn more about the data format for survival problems). First we load in the data and split into training and test datasets:

df <- read.table("mgus2.csv", sep = ",", header = TRUE, stringsAsFactors = TRUE)
df
  X id age sex  hgb creat mspike ptime pstat futime death
1 1  1  88   F 13.1   1.3    0.5    30     0     30     1
2 2  2  78   F 11.5   1.2    2.0    25     0     25     1
3 3  3  94   M 10.5   1.5    2.6    46     0     46     1
4 4  4  68   M 15.2   1.2    1.2    92     0     92     1
5 5  5  90   F 10.7   0.8    1.0     8     0      8     1
 [ reached 'max' / getOption("max.print") -- omitted 1379 rows ]
X <- df[, 3:7]
died <- df$death == 1
times <- df$futime
split <- iai::split_data("survival", X, died, times, seed = 12345)
train_X <- split$train$X
train_died <- split$train$deaths
train_times <- split$train$times
test_X <- split$test$X
test_died <- split$test$deaths
test_times <- split$test$times
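For intuition about what the split returns, here is a minimal base R sketch of a seeded random train/test split that produces the same nested structure ($train$X, $train$deaths, $train$times, and likewise for $test). The helper simple_survival_split and its 70/30 ratio are hypothetical and only for illustration. iai::split_data may use a different default proportion and its own splitting strategy, so the rows it selects will not match this sketch.

```r
# Hypothetical helper (not part of iai): seeded random split of survival data.
# Assumes a 70/30 train/test ratio; iai::split_data may use other defaults.
# Note: set.seed() changes the global RNG state as a side effect.
simple_survival_split <- function(X, deaths, times, train_proportion = 0.7,
                                  seed = 12345) {
  stopifnot(nrow(X) == length(deaths), length(deaths) == length(times))
  set.seed(seed)
  n <- nrow(X)
  train_idx <- sample(n, size = round(train_proportion * n))
  test_idx <- setdiff(seq_len(n), train_idx)
  # Subset features, death indicators, and times with the same indices
  part <- function(idx) {
    list(X = X[idx, , drop = FALSE], deaths = deaths[idx], times = times[idx])
  }
  list(train = part(train_idx), test = part(test_idx))
}

# Made-up toy data with the same shape as the mgus2 inputs
toy_X <- data.frame(age = c(88, 78, 94, 68, 90, 72, 81, 65, 70, 77),
                    hgb = c(13.1, 11.5, 10.5, 15.2, 10.7,
                            12.0, 13.5, 14.1, 9.8, 12.2))
toy_died <- c(TRUE, TRUE, TRUE, TRUE, TRUE, FALSE, TRUE, FALSE, FALSE, TRUE)
toy_times <- c(30, 25, 46, 92, 8, 120, 60, 150, 200, 40)

toy_split <- simple_survival_split(toy_X, toy_died, toy_times)
nrow(toy_split$train$X)  # 7
nrow(toy_split$test$X)   # 3
```

Keeping the deaths and times in the same list as the features makes it hard to accidentally misalign them when subsetting later.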

Random Forest Survival Learner

We will use a grid_search to fit a random_forest_survival_learner:

grid <- iai::grid_search(
    iai::random_forest_survival_learner(
        missingdatamode = "separate_class",
        random_seed = 1
    ),
    max_depth = 5:10
)
iai::fit(grid, train_X, train_died, train_times)
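Conceptually, a grid search fits one model per parameter value, scores each on held-out validation data, keeps the best value, and refits with it. The base R sketch below shows that loop using a swapped-in model (polynomial linear regression, tuning the degree) on simulated data. It is only an illustration of the technique and not how iai::grid_search is implemented; for example, iai may choose its validation data and refitting strategy differently.

```r
# Illustrative grid search: choose a polynomial degree by validation error.
# lm() and the simulated data are stand-ins for the survival learner above.
set.seed(1)
n <- 200
x <- runif(n, -2, 2)
y <- 1 + 2 * x - 1.5 * x^2 + rnorm(n, sd = 0.5)
dat <- data.frame(x = x, y = y)

# Hold out 30% of the rows for validation
val_idx <- sample(n, size = round(0.3 * n))
train_dat <- dat[-val_idx, ]
val_dat <- dat[val_idx, ]

# The "grid", analogous to max_depth = 5:10
degrees <- 1:5
val_mse <- sapply(degrees, function(d) {
  fit <- lm(y ~ poly(x, d), data = train_dat)
  mean((predict(fit, newdata = val_dat) - val_dat$y)^2)
})
best_degree <- degrees[which.min(val_mse)]

# Refit on all of the tuning data with the selected parameter
final_fit <- lm(y ~ poly(x, best_degree), data = dat)
best_degree
```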

We can make predictions on new data using predict. For survival learners, this returns the appropriate survival curve for each point, which we can then use to query the mortality probability (the probability of death by time t) for any given time t (in this case for t = 10):

pred_curves <- iai::predict(grid, test_X)
t <- 10
pred_curves[[1]][t]
[1] 0.1873965
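To make "querying a curve at time t" concrete, here is a base R sketch of a step-function curve that stores breakpoint times and the cumulative mortality probability at each, and returns the value in effect at time t. The constructor make_mortality_curve and the step-function representation are assumptions for illustration; this guide does not describe how iai represents its survival curves internally.

```r
# Hypothetical step-function curve: F(t) = P(death by time t).
# Not iai's internal representation; for illustration only.
make_mortality_curve <- function(times, mortality) {
  stopifnot(!is.unsorted(times), length(times) == length(mortality))
  function(t) {
    # Position of the last breakpoint at or before t (0 if t is earlier)
    i <- findInterval(t, times)
    ifelse(i == 0, 0, mortality[pmax(i, 1)])
  }
}

mort <- make_mortality_curve(times = c(5, 12, 30),
                             mortality = c(0.10, 0.25, 0.60))
mort(10)            # 0.1: the value set at time 5 still applies at t = 10
1 - mort(10)        # 0.9: the matching survival probability
mort(c(3, 12, 40))  # 0.00 0.25 0.60
```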

You can also query the probability for all of the points together by applying the [ function to each curve:

sapply(pred_curves, `[`, t)
 [1] 0.18739649 0.21215498 0.09259120 0.11206114 0.21556982 0.06529098
 [7] 0.21255886 0.18394250 0.13583081 0.05928936 0.06964021 0.05761070
[13] 0.17592700 0.09695835 0.05083491 0.21279057 0.21020701 0.13996345
[19] 0.04926693 0.07267342 0.21694837 0.18411062 0.07030906 0.19655744
[25] 0.08155705 0.19667612 0.16373136 0.09083397 0.07260804 0.03996479
[31] 0.03565278 0.06140443 0.04793617 0.19591522 0.07673138 0.19419976
[37] 0.06210086 0.01889528 0.05808891 0.07176119 0.02248062 0.08093347
[43] 0.06311591 0.21353333 0.04997100 0.08071766 0.06164327 0.05396804
[49] 0.05687817 0.08088472 0.01339789 0.09761082 0.04075485 0.05751192
[55] 0.06222023 0.21530971 0.14798813 0.23725987 0.05334866 0.15878501
 [ reached getOption("max.print") -- omitted 355 entries ]
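The sapply idiom above works because [ is an ordinary R function: sapply calls it on each list element, passing t as the extra argument, and simplifies the results into a vector. A base R illustration with plain numeric vectors standing in for the curves:

```r
# Plain numeric vectors standing in for curves
curves <- list(c(0.1, 0.2, 0.3), c(0.4, 0.5, 0.6), c(0.7, 0.8, 0.9))

# sapply calls `[`(element, 2) on each element and simplifies to a vector
sapply(curves, `[`, 2)
# [1] 0.2 0.5 0.8

# Equivalent, with an explicit anonymous function
sapply(curves, function(v) v[2])
```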

Alternatively, you can get this mortality probability for any given time t by passing t directly to predict as a named argument:

iai::predict(grid, test_X, t = 10)
 [1] 0.18739649 0.21215498 0.09259120 0.11206114 0.21556982 0.06529098
 [7] 0.21255886 0.18394250 0.13583081 0.05928936 0.06964021 0.05761070
[13] 0.17592700 0.09695835 0.05083491 0.21279057 0.21020701 0.13996345
[19] 0.04926693 0.07267342 0.21694837 0.18411062 0.07030906 0.19655744
[25] 0.08155705 0.19667612 0.16373136 0.09083397 0.07260804 0.03996479
[31] 0.03565278 0.06140443 0.04793617 0.19591522 0.07673138 0.19419976
[37] 0.06210086 0.01889528 0.05808891 0.07176119 0.02248062 0.08093347
[43] 0.06311591 0.21353333 0.04997100 0.08071766 0.06164327 0.05396804
[49] 0.05687817 0.08088472 0.01339789 0.09761082 0.04075485 0.05751192
[55] 0.06222023 0.21530971 0.14798813 0.23725987 0.05334866 0.15878501
 [ reached getOption("max.print") -- omitted 355 entries ]

We can also estimate the expected survival time for each point using predict_expected_survival_time:

iai::predict_expected_survival_time(grid, test_X)
 [1]  62.96007  60.45191 128.12247  80.41539  58.22632 135.23928 109.30474
 [8]  68.30347 133.16660 115.05602 126.94926  95.13604  62.95190  74.78489
[15] 135.56135  61.75748  60.32962  80.63109  96.98121 144.10354  58.41931
[22]  53.16931 121.19475  62.46134 126.53887 103.24951  54.78605 114.83856
[29] 177.45912 200.43259 208.13383 104.56134 205.81946 153.49698  95.57145
[36]  64.08807 142.50420 331.02697  94.14627  88.54103 345.24707  90.90750
[43] 147.78680  60.17856 181.83168 147.00933 156.06318 170.31862 116.44109
[50]  95.37139 354.85519 123.30760 279.52693 156.52047 145.26378  61.02193
[57] 107.51267 159.36992 176.68766 128.12151
 [ reached getOption("max.print") -- omitted 355 entries ]
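For a nonnegative survival time T, the expected value equals the area under the survival curve S(t) = 1 - F(t), and for a step-function curve that area is a sum of rectangles. The sketch below computes it in base R, assuming the curve is simply truncated at its last breakpoint, which underestimates the mean whenever some survival probability remains there. How predict_expected_survival_time handles the tail beyond the last observed time is not described in this guide, so this is not a reimplementation of it.

```r
# Area under a step survival curve S(t) = 1 - F(t), truncated at max(times).
# `mortality[k]` is F(t) from times[k] until the next breakpoint.
expected_time_truncated <- function(times, mortality) {
  stopifnot(!is.unsorted(times), all(times >= 0),
            length(times) == length(mortality))
  starts <- c(0, times[-length(times)])
  ends <- times
  # Survival on each interval; S = 1 before the first breakpoint
  surv <- c(1, 1 - mortality[-length(mortality)])
  sum(surv * (ends - starts))
}

expected_time_truncated(times = c(5, 12, 30), mortality = c(0.10, 0.25, 0.60))
# 1 * 5 + 0.90 * 7 + 0.75 * 18 = 24.8
```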

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

iai::score(grid, train_X, train_died, train_times,
           criterion = "harrell_c_statistic")
[1] 0.7286481

Or on the test set:

iai::score(grid, test_X, test_died, test_times,
           criterion = "harrell_c_statistic")
[1] 0.7130325
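Harrell's c-statistic is the fraction of comparable pairs that the model orders correctly: a pair is comparable when the point with the shorter follow-up time died, and concordant when that point was also assigned the higher risk. The O(n^2) base R sketch below assumes higher scores mean higher risk (for example, a mortality probability), counts tied scores as half-concordant, and skips pairs with tied times. Tie handling varies between implementations, so this may differ slightly from iai::score.

```r
# Harrell's C: concordant pairs / comparable pairs (O(n^2) sketch)
harrell_c <- function(risk, times, died) {
  concordant <- 0
  comparable <- 0
  for (i in seq_along(risk)) {
    if (!died[i]) next  # only an observed death can anchor a pair
    for (j in seq_along(risk)) {
      if (times[j] > times[i]) {  # j was still alive when i died
        comparable <- comparable + 1
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1
        } else if (risk[i] == risk[j]) {
          concordant <- concordant + 0.5
        }
      }
    }
  }
  concordant / comparable
}

risk <- c(0.9, 0.8, 0.7, 0.2)
times <- c(2, 5, 3, 8)
died <- c(TRUE, FALSE, TRUE, TRUE)
harrell_c(risk, times, died)  # 4 of 5 comparable pairs concordant: 0.8
```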

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

iai::score(grid, test_X, test_died, test_times, criterion = "auc",
           evaluation_time = 10)
[1] 0.7750751
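One common way to score survival predictions at a fixed time with a classification metric is to treat points that died by the evaluation time as positives, points still under observation after it as negatives, and drop points censored before it because their status at that time is unknown. The base R sketch below does this and computes the AUC with the rank (Mann-Whitney) formula. The censoring rule and the helper name auc_at_time are assumptions; this guide does not specify how iai::score treats censored points for criterion = "auc".

```r
# AUC of mortality-by-time-t predictions, dropping points censored before t
auc_at_time <- function(mortality_prob, times, died, t) {
  positive <- died & times <= t  # died by time t
  negative <- times > t          # known to be alive at time t
  keep <- positive | negative    # censored before t: status unknown, dropped
  scores <- mortality_prob[keep]
  labels <- positive[keep]
  n_pos <- sum(labels)
  n_neg <- sum(!labels)
  # Mann-Whitney: P(score of a positive > score of a negative), ties count 1/2
  r <- rank(scores)
  (sum(r[labels]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

prob <- c(0.30, 0.10, 0.25, 0.05, 0.28)
times <- c(4, 20, 8, 6, 15)
died <- c(TRUE, TRUE, TRUE, FALSE, FALSE)
auc_at_time(prob, times, died, t = 10)  # 3 of 4 pairs ranked correctly: 0.75
```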

We can also look at the variable importance:

iai::variable_importance(iai::get_learner(grid))
  Feature Importance
1     age 0.62680953
2     hgb 0.18715335
3   creat 0.10944276
4  mspike 0.04890058
5     sex 0.02769378
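The importances in this output sum to 1 (up to rounding), so each value can be read as a share of the total importance. A quick base R check using the values printed above:

```r
# Values copied from the random forest variable importance output above
imp <- data.frame(
  Feature = c("age", "hgb", "creat", "mspike", "sex"),
  Importance = c(0.62680953, 0.18715335, 0.10944276, 0.04890058, 0.02769378)
)
sum(imp$Importance)  # 1: the importances are shares of the total
imp$Cumulative <- cumsum(imp$Importance)
imp                  # age and hgb together account for about 81%
```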

XGBoost Survival Learner

We will use a grid_search to fit an xgboost_survival_learner:

grid <- iai::grid_search(
    iai::xgboost_survival_learner(
        random_seed = 1
    ),
    max_depth = 2:5,
    num_round = c(20, 50, 100)
)
iai::fit(grid, train_X, train_died, train_times)
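A grid search tries every combination of the listed parameter values, so 4 values of max_depth times 3 values of num_round gives 12 candidate models. Base R's expand.grid enumerates the same combinations, which is a quick way to see how the cost of a grid grows as parameters are added:

```r
# Enumerate the parameter combinations in the XGBoost grid above
param_grid <- expand.grid(max_depth = 2:5, num_round = c(20, 50, 100))
nrow(param_grid)   # 12
head(param_grid, 3)
```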

We can make predictions on new data using predict. For survival learners, this returns the appropriate survival curve for each point, which we can then use to query the mortality probability (the probability of death by time t) for any given time t (in this case for t = 10):

pred_curves <- iai::predict(grid, test_X)
t <- 10
pred_curves[[1]][t]
[1] 0.146602

You can also query the probability for all of the points together by applying the [ function to each curve:

sapply(pred_curves, `[`, t)
 [1] 0.146602007 0.181972807 0.089343189 0.233761180 0.210810046 0.068637261
 [7] 0.111577430 0.141948042 0.092365925 0.071650634 0.068637261 0.109804344
[13] 0.260697440 0.144190739 0.071650634 0.181972807 0.181972807 0.118847887
[19] 0.109804344 0.097148408 0.204374072 0.191598696 0.068637261 0.146602007
[25] 0.081671053 0.160305987 0.299835538 0.081671053 0.044706129 0.034422041
[31] 0.034422041 0.091654244 0.027530256 0.051927370 0.108796970 0.252996733
[37] 0.068637261 0.029483093 0.109804344 0.208894603 0.009834575 0.109804344
[43] 0.056779203 0.146602007 0.048169590 0.074690356 0.056779203 0.055883834
[49] 0.071650634 0.119846399 0.009117891 0.090106636 0.017593932 0.056297044
[55] 0.068637261 0.146602007 0.144436267 0.062169577 0.048169590 0.095545787
 [ reached getOption("max.print") -- omitted 355 entries ]

Alternatively, you can get this mortality probability for any given time t by passing t directly to predict as a named argument:

iai::predict(grid, test_X, t = 10)
 [1] 0.146602007 0.181972807 0.089343189 0.233761180 0.210810046 0.068637261
 [7] 0.111577430 0.141948042 0.092365925 0.071650634 0.068637261 0.109804344
[13] 0.260697440 0.144190739 0.071650634 0.181972807 0.181972807 0.118847887
[19] 0.109804344 0.097148408 0.204374072 0.191598696 0.068637261 0.146602007
[25] 0.081671053 0.160305987 0.299835538 0.081671053 0.044706129 0.034422041
[31] 0.034422041 0.091654244 0.027530256 0.051927370 0.108796970 0.252996733
[37] 0.068637261 0.029483093 0.109804344 0.208894603 0.009834575 0.109804344
[43] 0.056779203 0.146602007 0.048169590 0.074690356 0.056779203 0.055883834
[49] 0.071650634 0.119846399 0.009117891 0.090106636 0.017593932 0.056297044
[55] 0.068637261 0.146602007 0.144436267 0.062169577 0.048169590 0.095545787
 [ reached getOption("max.print") -- omitted 355 entries ]

We can also estimate the expected survival time for each point using predict_expected_survival_time:

iai::predict_expected_survival_time(grid, test_X)
 [1]  99.87083  78.76796 159.46015  58.73172  66.43657 194.52313 131.16683
 [8] 103.31659 155.11383 188.78086 194.52313 133.13753  51.40148 101.63303
[15] 188.78086  78.76796  78.76796 123.53015 133.13753 148.58246  68.90000
[22]  74.26003 194.52313  99.87083 171.31455  90.71969  43.08736 171.31455
[29] 249.95242 280.33522 280.33522 156.12149 303.44139 231.18593 134.27708
[36]  53.33863 194.52313 296.65134 133.13753  67.15436 375.00975 133.13753
[43] 219.63029  99.87083 240.70658 183.22626 219.63029 221.70286 188.78086
[50] 122.53435 378.33451 158.34586 341.25306 220.74308 194.52313  99.87083
[57] 101.45135 207.69545 240.70658 150.72550
 [ reached getOption("max.print") -- omitted 355 entries ]

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

iai::score(grid, train_X, train_died, train_times,
           criterion = "harrell_c_statistic")
[1] 0.7193754

Or on the test set:

iai::score(grid, test_X, test_died, test_times,
           criterion = "harrell_c_statistic")
[1] 0.7093859

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

iai::score(grid, test_X, test_died, test_times, criterion = "auc",
           evaluation_time = 10)
[1] 0.6890691

We can also look at the variable importance:

iai::variable_importance(iai::get_learner(grid))
  Feature Importance
1     age 0.65587483
2     hgb 0.14861300
3   creat 0.10092666
4  mspike 0.05575816
5   sex=M 0.03882735

We can calculate the SHAP values:

iai::predict_shap(grid, test_X)
$expected_value
[1] -0.6530125

$features
   age  hgb creat mspike sex=M
1   79  9.4  1.10    2.3     0
2   94 11.0  1.10    0.7     0
3   74 12.5  1.25    0.9     0
4   92 14.0  0.80    1.6     0
5   81 11.4  1.50    1.9     1
6   74 14.7  1.00    0.5     0
7   67 10.7  1.30    1.2     1
8   81 12.7  0.90    1.3     0
9   74  9.8  0.80    1.4     0
10  76 13.8  0.90    1.5     0
11  72 14.1  1.10    1.7     0
12  84 14.3  1.30    0.4     0
 [ reached 'max' / getOption("max.print") -- omitted 403 rows ]

$shap_values
                [,1]         [,2]         [,3]          [,4]        [,5]
  [1,]  0.6400975585  0.297385663 -0.035298131 -0.0225629620 -0.14502457
  [2,]  0.9791966677  0.176633820 -0.035298131 -0.0042425417 -0.14502457
  [3,]  0.2123140544  0.136468321 -0.007378877 -0.0146590024 -0.11917135
  [4,]  1.3777524233  0.055015273 -0.043458562 -0.0128253633 -0.12334661
  [5,]  0.6852596402  0.275956571  0.033285711 -0.0252806693  0.16643402
  [6,]  0.2014569342 -0.130302861 -0.038415197 -0.0024089026 -0.09749340
  [7,] -0.1651683450  0.449971855  0.063247912 -0.0148528777  0.10874942
  [8,]  0.7639548779  0.138875619 -0.043458562 -0.0146590024 -0.14502457
  [9,]  0.1194929481  0.291889012 -0.035069373 -0.0146590024 -0.11917135
 [10,]  0.2556143105 -0.127895564 -0.039996311 -0.0128253633 -0.09749340
 [11,]  0.2220155001 -0.132540986 -0.038415197 -0.0207293220 -0.09749340
 [12,]  0.6910252571 -0.130302861 -0.010014940 -0.0024089026 -0.12334661
 [ reached getOption("max.print") -- omitted 403 rows ]
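SHAP values are additive: for each point, expected_value plus the sum of that point's SHAP values recovers the model's raw (untransformed) prediction. The base R check below uses the first row printed above and assumes the columns of $shap_values follow the same order as the columns of $features (age, hgb, creat, mspike, sex=M); that ordering and the raw-output scale are assumptions, not something this guide states.

```r
# First row of $shap_values and $expected_value, copied from the output above
expected_value <- -0.6530125
shap_row1 <- c(age = 0.6400975585, hgb = 0.297385663, creat = -0.035298131,
               mspike = -0.0225629620, "sex=M" = -0.14502457)

# Additivity: base value + sum of contributions = raw model output
raw_pred_row1 <- expected_value + sum(shap_row1)
raw_pred_row1  # about 0.0816

# Contributions ranked by magnitude: age moves this prediction the most
sort(abs(shap_row1), decreasing = TRUE)
```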

We can then use the SHAP library to visualize these results in whichever way we prefer.

GLMNetCV Survival Learner

We will fit a glmnetcv_survival_learner directly. Since GLMNet does not support missing data, we first subset to complete cases:

lnr <- iai::glmnetcv_survival_learner(random_seed = 1)
cc <- complete.cases(train_X)
train_X_cc <- train_X[cc, ]
train_died_cc <- train_died[cc]
train_times_cc <- train_times[cc]
iai::fit(lnr, train_X_cc, train_died_cc, train_times_cc)
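complete.cases returns one logical value per row, TRUE when the row has no missing values, so the same index vector can subset the features, deaths, and times together and keep them aligned. A small base R illustration with made-up values:

```r
# Made-up toy data with some missing feature values
toy_X <- data.frame(age = c(79, 94, 74, 92),
                    hgb = c(9.4, NA, 12.5, 14.0),
                    creat = c(1.1, 1.1, NA, 0.8))
toy_died <- c(TRUE, FALSE, TRUE, TRUE)
toy_times <- c(30, 25, 46, 92)

cc <- complete.cases(toy_X)
cc                        # TRUE FALSE FALSE TRUE
toy_X_cc <- toy_X[cc, ]
toy_died_cc <- toy_died[cc]
toy_times_cc <- toy_times[cc]
nrow(toy_X_cc)            # 2
```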

We can make predictions on new data using predict. For survival learners, this returns the appropriate survival curve for each point, which we can then use to query the mortality probability (the probability of death by time t) for any given time t (in this case for t = 10):

pred_curves <- iai::predict(lnr, test_X)
t <- 10
pred_curves[[1]][t]
[1] 0.2006079

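You can also query the probability for all of the points together by applying the [ function to each curve. Because this model was fit on complete cases only, test points with missing feature values get NaN: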

sapply(pred_curves, `[`, t)
 [1] 0.20060790 0.37016508 0.12258563 0.25815025 0.26318357 0.09764988
 [7] 0.13021192 0.16532398 0.14207864 0.11548805 0.09364096 0.18164974
[13] 0.19646153 0.22090925 0.11633882 0.32794891 0.35621026 0.13689713
[19] 0.18496651 0.11288777 0.29856028 0.32199141 0.11759813 0.23101774
[25] 0.14958820 0.22753554 0.36432456 0.14973809 0.04427567 0.06253306
[31] 0.05351038 0.13143858        NaN 0.04433383 0.14841218 0.17601814
[37] 0.10221974 0.02128404 0.13806728 0.13504396 0.03009027 0.15786415
[43] 0.08168353 0.27166587 0.07241175 0.11989440 0.08235410 0.08848366
[49] 0.12053551 0.18218162 0.01961504 0.12975188 0.03141674 0.08366623
[55] 0.09236169 0.23485295 0.15931575 0.06634960 0.06877971 0.11722330
 [ reached getOption("max.print") -- omitted 355 entries ]

Alternatively, you can get this mortality probability for any given time t by passing t directly to predict as a named argument:

iai::predict(lnr, test_X, t = 10)
 [1] 0.20060790 0.37016508 0.12258563 0.25815025 0.26318357 0.09764988
 [7] 0.13021192 0.16532398 0.14207864 0.11548805 0.09364096 0.18164974
[13] 0.19646153 0.22090925 0.11633882 0.32794891 0.35621026 0.13689713
[19] 0.18496651 0.11288777 0.29856028 0.32199141 0.11759813 0.23101774
[25] 0.14958820 0.22753554 0.36432456 0.14973809 0.04427567 0.06253306
[31] 0.05351038 0.13143858        NaN 0.04433383 0.14841218 0.17601814
[37] 0.10221974 0.02128404 0.13806728 0.13504396 0.03009027 0.15786415
[43] 0.08168353 0.27166587 0.07241175 0.11989440 0.08235410 0.08848366
[49] 0.12053551 0.18218162 0.01961504 0.12975188 0.03141674 0.08366623
[55] 0.09236169 0.23485295 0.15931575 0.06634960 0.06877971 0.11722330
 [ reached getOption("max.print") -- omitted 355 entries ]

We can also estimate the expected survival time for each point using predict_expected_survival_time:

iai::predict_expected_survival_time(lnr, test_X)
 [1]  69.71314  32.21543 118.46927  51.54736  50.33090 146.17630 111.58710
 [8]  86.76397 102.05617 125.48250 151.52896  78.11343  71.41961  62.24844
[15] 124.60867  37.98087  33.97530 106.05376  76.52748 128.21215  42.92932
[22]  38.91239 123.33223  59.00796  96.67249  60.09240  32.93626  96.56967
[29] 248.91474 204.73586 225.09247 110.53838       NaN 248.75425  97.48532
[36]  80.93234 140.41078 324.85608 105.12962 107.54374 292.57871  91.24689
[43] 169.32196  48.38208 185.27838 121.05584 168.24554 158.85087 120.43181
[50]  77.85550 331.48611 111.98441 288.08520 166.16775 153.29812  57.85016
[57]  90.34568 196.89239 192.11650 123.71006
 [ reached getOption("max.print") -- omitted 355 entries ]

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

iai::score(lnr, train_X_cc, train_died_cc, train_times_cc,
           criterion = "harrell_c_statistic")
[1] 0.6837434

Or on the test set, after subsetting it to complete cases in the same way:

cc <- complete.cases(test_X)
test_X_cc <- test_X[cc, ]
test_died_cc <- test_died[cc]
test_times_cc <- test_times[cc]
iai::score(lnr, test_X_cc, test_died_cc, test_times_cc,
           criterion = "harrell_c_statistic")
[1] 0.7027584

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

iai::score(lnr, test_X_cc, test_died_cc, test_times_cc, criterion = "auc",
           evaluation_time = 10)
[1] 0.6823275

We can also look at the variable importance:

iai::variable_importance(lnr)
  Feature Importance
1     age  0.5920800
2     hgb  0.1538955
3   sex=M  0.1274143
4   creat  0.1266102
5  mspike  0.0000000