Quick Start Guide: Heuristic Survival Learners
This is an R version of the corresponding Heuristics quick start guide.
In this example, we use survival learners from Heuristics on the Monoclonal Gammopathy (mgus2) dataset to predict patient survival time (refer to the data preparation guide for more about the data format for survival problems). First, we load the data and split it into training and test sets:
df <- read.table("mgus2.csv", sep = ",", header = TRUE, stringsAsFactors = TRUE)
X id age sex hgb creat mspike ptime pstat futime death
1 1 1 88 F 13.1 1.3 0.5 30 0 30 1
2 2 2 78 F 11.5 1.2 2.0 25 0 25 1
3 3 3 94 M 10.5 1.5 2.6 46 0 46 1
4 4 4 68 M 15.2 1.2 1.2 92 0 92 1
5 5 5 90 F 10.7 0.8 1.0 8 0 8 1
[ reached 'max' / getOption("max.print") -- omitted 1379 rows ]
X <- df[, 3:7]
died <- df$death == 1
times <- df$futime
split <- iai::split_data("survival", X, died, times, seed = 12345)
train_X <- split$train$X
train_died <- split$train$deaths
train_times <- split$train$times
test_X <- split$test$X
test_died <- split$test$deaths
test_times <- split$test$times
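The split keeps the features, death indicators and times aligned by row. As a rough illustration of that bookkeeping only (not of how split_data chooses its rows, which is not described here), the hypothetical base-R helper simple_survival_split below performs a plain random split:

```r
# Hypothetical helper, NOT iai::split_data: a plain random split that
# applies the same row indices to X, died and times so they stay aligned.
simple_survival_split <- function(X, died, times, train_proportion = 0.7,
                                  seed = NULL) {
  stopifnot(nrow(X) == length(died), length(died) == length(times))
  if (!is.null(seed)) set.seed(seed)  # note: resets the global RNG state
  n <- nrow(X)
  train_idx <- sort(sample.int(n, size = round(train_proportion * n)))
  test_idx <- setdiff(seq_len(n), train_idx)
  list(
    train = list(X = X[train_idx, , drop = FALSE],
                 deaths = died[train_idx], times = times[train_idx]),
    test = list(X = X[test_idx, , drop = FALSE],
                deaths = died[test_idx], times = times[test_idx])
  )
}

# Toy stand-in for the mgus2 features and outcomes
toy_X <- data.frame(
  age = c(88, 78, 94, 68, 90, 70, 81, 75, 66, 84),
  hgb = c(13.1, 11.5, 10.5, 15.2, 10.7, 12.0, 9.9, 14.1, 13.3, 11.0)
)
toy_died <- c(TRUE, TRUE, TRUE, TRUE, TRUE, FALSE, TRUE, FALSE, FALSE, TRUE)
toy_times <- c(30, 25, 46, 92, 8, 120, 60, 150, 200, 40)

toy_split <- simple_survival_split(toy_X, toy_died, toy_times, seed = 12345)
nrow(toy_split$train$X)  # 7
nrow(toy_split$test$X)   # 3
```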
Random Forest Survival Learner
We will use a grid_search to fit a random_forest_survival_learner:
grid <- iai::grid_search(
  iai::random_forest_survival_learner(
    missingdatamode = "separate_class",
    random_seed = 1
  ),
  max_depth = 5:10
)
iai::fit(grid, train_X, train_died, train_times)
We can make predictions on new data using predict. For survival learners, this returns the appropriate survival curve for each point, which we can then use to query the mortality probability (the probability of death by time t) for any given time t (in this case, t = 10 months):
pred_curves <- iai::predict(grid, test_X)
t <- 10
pred_curves[[1]][t]
[1] 0.1873965
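Each predicted curve gives the probability of death by time t. To make that quantity concrete, the sketch below computes the same kind of number for a whole population with a minimal Kaplan-Meier estimator written in base R. km_mortality is a hypothetical helper; the learners above instead produce one curve per point, conditioned on its features:

```r
# Minimal Kaplan-Meier estimator (a population-level sketch, not iai's
# per-point curves). Returns the mortality probability P(T <= t) = 1 - S(t).
km_mortality <- function(times, died, t) {
  event_times <- sort(unique(times[died & times <= t]))
  surv <- 1
  for (u in event_times) {
    at_risk <- sum(times >= u)         # subjects still observed just before u
    deaths <- sum(died & times == u)   # deaths recorded exactly at u
    surv <- surv * (1 - deaths / at_risk)
  }
  1 - surv
}

km_times <- c(2, 3, 3, 5, 8, 10, 12)
km_died <- c(TRUE, TRUE, FALSE, TRUE, FALSE, TRUE, FALSE)
km_mortality(km_times, km_died, t = 10)  # 41/56, about 0.732
```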
You can also query the probability for all of the points together by applying the [ function to each curve:
sapply(pred_curves, `[`, t)
[1] 0.18739649 0.21215498 0.09259120 0.11206114 0.21556982 0.06529098
[7] 0.21255886 0.18394250 0.13583081 0.05928936 0.06964021 0.05761070
[13] 0.17592700 0.09695835 0.05083491 0.21279057 0.21020701 0.13996345
[19] 0.04926693 0.07267342 0.21694837 0.18411062 0.07030906 0.19655744
[25] 0.08155705 0.19667612 0.16373136 0.09083397 0.07260804 0.03996479
[31] 0.03565278 0.06140443 0.04793617 0.19591522 0.07673138 0.19419976
[37] 0.06210086 0.01889528 0.05808891 0.07176119 0.02248062 0.08093347
[43] 0.06311591 0.21353333 0.04997100 0.08071766 0.06164327 0.05396804
[49] 0.05687817 0.08088472 0.01339789 0.09761082 0.04075485 0.05751192
[55] 0.06222023 0.21530971 0.14798813 0.23725987 0.05334866 0.15878501
[ reached getOption("max.print") -- omitted 355 entries ]
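The sapply call works because sapply passes each curve to [ together with t, and [ dispatches on the curve's class. The hypothetical step_curve class below mimics that behavior in base R; it is not the class that iai returns:

```r
# Hypothetical S3 class standing in for a predicted curve
make_step_curve <- function(times, mortality) {
  structure(list(times = times, mortality = mortality), class = "step_curve")
}

# `[` is an internal generic, so curve[t] dispatches to this method
`[.step_curve` <- function(x, t) {
  # mortality at the last breakpoint <= t (0 before the first breakpoint)
  i <- findInterval(t, x$times)
  if (i == 0) 0 else x$mortality[i]
}

curves <- list(
  make_step_curve(c(5, 10, 20), c(0.05, 0.18, 0.40)),
  make_step_curve(c(5, 10, 20), c(0.02, 0.09, 0.21))
)
t <- 10
curves[[1]][t]           # 0.18
sapply(curves, `[`, t)   # 0.18 0.09
```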
Alternatively, you can get this mortality probability for any given time t by providing t as a keyword argument directly:
iai::predict(grid, test_X, t = 10)
[1] 0.18739649 0.21215498 0.09259120 0.11206114 0.21556982 0.06529098
[7] 0.21255886 0.18394250 0.13583081 0.05928936 0.06964021 0.05761070
[13] 0.17592700 0.09695835 0.05083491 0.21279057 0.21020701 0.13996345
[19] 0.04926693 0.07267342 0.21694837 0.18411062 0.07030906 0.19655744
[25] 0.08155705 0.19667612 0.16373136 0.09083397 0.07260804 0.03996479
[31] 0.03565278 0.06140443 0.04793617 0.19591522 0.07673138 0.19419976
[37] 0.06210086 0.01889528 0.05808891 0.07176119 0.02248062 0.08093347
[43] 0.06311591 0.21353333 0.04997100 0.08071766 0.06164327 0.05396804
[49] 0.05687817 0.08088472 0.01339789 0.09761082 0.04075485 0.05751192
[55] 0.06222023 0.21530971 0.14798813 0.23725987 0.05334866 0.15878501
[ reached getOption("max.print") -- omitted 355 entries ]
We can also estimate the expected survival time for each point using predict_expected_survival_time:
iai::predict_expected_survival_time(grid, test_X)
[1] 62.96007 60.45191 128.12247 80.41539 58.22632 135.23928 109.30474
[8] 68.30347 133.16660 115.05602 126.94926 95.13604 62.95190 74.78489
[15] 135.56135 61.75748 60.32962 80.63109 96.98121 144.10354 58.41931
[22] 53.16931 121.19475 62.46134 126.53887 103.24951 54.78605 114.83856
[29] 177.45912 200.43259 208.13383 104.56134 205.81946 153.49698 95.57145
[36] 64.08807 142.50420 331.02697 94.14627 88.54103 345.24707 90.90750
[43] 147.78680 60.17856 181.83168 147.00933 156.06318 170.31862 116.44109
[50] 95.37139 354.85519 123.30760 279.52693 156.52047 145.26378 61.02193
[57] 107.51267 159.36992 176.68766 128.12151
[ reached getOption("max.print") -- omitted 355 entries ]
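The expected survival time is the mean of the predicted survival distribution, which for a survival curve S(t) equals the area under it (the integral of S(t) over t >= 0). A minimal sketch for a step curve follows. expected_time_from_steps is hypothetical, and how iai treats a curve that has not reached zero by its last time point is not covered here; this sketch simply truncates the area there:

```r
# Area under a right-continuous step survival curve.
# times: increasing time points where S(t) drops; surv: S(t) from each
# time point onward; S(t) = 1 before times[1]. If surv does not end at 0,
# the area is truncated at the last time point (a restricted mean).
expected_time_from_steps <- function(times, surv) {
  widths <- diff(c(0, times))       # length of each interval
  levels <- c(1, head(surv, -1))    # S(t) on each interval
  sum(widths * levels)
}

step_times <- c(10, 20, 30)
step_surv <- c(0.8, 0.5, 0)         # e.g. 1 - predicted mortality probability
expected_time_from_steps(step_times, step_surv)  # 23

# Same answer as the mean of the implied discrete distribution of T
sum(step_times * diff(c(0, 1 - step_surv)))     # 23
```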
We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:
iai::score(grid, train_X, train_died, train_times,
criterion = "harrell_c_statistic")
[1] 0.7286481
Or on the test set:
iai::score(grid, test_X, test_died, test_times,
criterion = "harrell_c_statistic")
[1] 0.7130325
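Harrell's c-statistic is the fraction of comparable pairs whose predicted risks are ordered correctly: a pair is comparable when the subject with the shorter observed time died, and concordant when that subject also has the higher predicted risk. A quadratic-time sketch in base R follows. harrell_c is hypothetical; it skips pairs with tied times and counts tied risks as one half, which may not match iai's tie handling exactly. With survival curves, a natural risk score is the mortality probability at a fixed time:

```r
harrell_c <- function(risk, died, times) {
  concordant <- 0
  comparable <- 0
  for (i in seq_along(times)) {
    if (!died[i]) next                  # only deaths anchor comparable pairs
    for (j in seq_along(times)) {
      if (times[j] > times[i]) {        # j was still alive when i died
        comparable <- comparable + 1
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1
        } else if (risk[i] == risk[j]) {
          concordant <- concordant + 0.5
        }
      }
    }
  }
  concordant / comparable
}

c_risk <- c(0.9, 0.5, 0.7, 0.1)        # higher = predicted to die sooner
c_died <- c(TRUE, TRUE, FALSE, TRUE)
c_times <- c(2, 5, 6, 9)
harrell_c(c_risk, c_died, c_times)     # 0.8 (4 of 5 comparable pairs)
```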
We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:
iai::score(grid, test_X, test_died, test_times, criterion = "auc",
evaluation_time = 10)
[1] 0.7750751
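One common way to define AUC at an evaluation time t for censored data is to treat subjects who died by t as cases, subjects still under observation after t as controls, and to drop subjects censored before t. The sketch below uses that definition with the Mann-Whitney form of AUC; auc_at_time is hypothetical, and iai's exact handling of censoring may differ:

```r
# AUC at a fixed time t (assumed case/control definition, see above).
# mortality_t: predicted probability of death by t, used as the score.
auc_at_time <- function(mortality_t, died, times, t) {
  case <- died & times <= t            # died by t
  control <- times > t                 # known to be alive at t
  pos <- mortality_t[case]
  neg <- mortality_t[control]
  # P(random case scores higher than random control), ties count 1/2
  wins <- outer(pos, neg, ">") + 0.5 * outer(pos, neg, "==")
  mean(wins)
}

a_scores <- c(0.6, 0.3, 0.4, 0.35, 0.1)
a_died <- c(TRUE, TRUE, FALSE, TRUE, FALSE)
a_times <- c(4, 8, 5, 20, 30)          # subject 3 is censored before t
auc_at_time(a_scores, a_died, a_times, t = 10)  # 0.75
```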
We can also look at the variable importance:
iai::variable_importance(iai::get_learner(grid))
Feature Importance
1 age 0.62680953
2 hgb 0.18715335
3 creat 0.10944276
4 mspike 0.04890058
5 sex 0.02769378
XGBoost Survival Learner
We will use a grid_search to fit an xgboost_survival_learner:
grid <- iai::grid_search(
  iai::xgboost_survival_learner(
    random_seed = 1
  ),
  max_depth = 2:5,
  num_round = c(20, 50, 100)
)
iai::fit(grid, train_X, train_died, train_times)
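Assuming grid_search tries every combination of the supplied values, this grid has 4 x 3 = 12 candidate settings, compared with 6 for the random forest grid above. A quick base-R way to list them:

```r
# Cartesian product of the two hyperparameter ranges used above
candidates <- expand.grid(max_depth = 2:5, num_round = c(20, 50, 100))
nrow(candidates)    # 12
head(candidates, 3)
```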
We can make predictions on new data using predict. For survival learners, this returns the appropriate survival curve for each point, which we can then use to query the mortality probability (the probability of death by time t) for any given time t (in this case, t = 10 months):
pred_curves <- iai::predict(grid, test_X)
t <- 10
pred_curves[[1]][t]
[1] 0.1421811
You can also query the probability for all of the points together by applying the [ function to each curve:
sapply(pred_curves, `[`, t)
[1] 0.14218 0.19365 0.10101 0.21149 0.21305 0.06571 0.11277 0.13989 0.09729
[10] 0.06921 0.06571 0.11189 0.27597 0.14269 0.06921 0.19365 0.19365 0.12828
[19] 0.10603 0.07439 0.19939 0.23700 0.06943 0.14218 0.07995 0.15077 0.27848
[28] 0.07995 0.04525 0.03647 0.03647 0.08860 0.03309 0.06607 0.10751 0.28532
[37] 0.06571 0.02038 0.10603 0.16009 0.00833 0.11189 0.05234 0.14986 0.05390
[46] 0.06739 0.05234 0.06252 0.06921 0.11344 0.00790 0.08444 0.01892 0.05306
[55] 0.06571 0.14986 0.14440 0.07789 0.05390 0.09612
[ reached getOption("max.print") -- omitted 355 entries ]
Alternatively, you can get this mortality probability for any given time t by providing t as a keyword argument directly:
iai::predict(grid, test_X, t = 10)
[1] 0.14218 0.19365 0.10101 0.21149 0.21305 0.06571 0.11277 0.13989 0.09729
[10] 0.06921 0.06571 0.11189 0.27597 0.14269 0.06921 0.19365 0.19365 0.12828
[19] 0.10603 0.07439 0.19939 0.23700 0.06943 0.14218 0.07995 0.15077 0.27848
[28] 0.07995 0.04525 0.03647 0.03647 0.08860 0.03309 0.06607 0.10751 0.28532
[37] 0.06571 0.02038 0.10603 0.16009 0.00833 0.11189 0.05234 0.14986 0.05390
[46] 0.06739 0.05234 0.06252 0.06921 0.11344 0.00790 0.08444 0.01892 0.05306
[55] 0.06571 0.14986 0.14440 0.07789 0.05390 0.09612
[ reached getOption("max.print") -- omitted 355 entries ]
We can also estimate the expected survival time for each point using predict_expected_survival_time:
iai::predict_expected_survival_time(grid, test_X)
[1] 103.140 73.356 143.604 66.184 65.612 200.331 129.869 104.900 148.390
[10] 193.413 200.331 130.828 47.879 102.751 193.413 73.356 73.356 114.597
[19] 137.486 183.769 70.912 57.762 192.987 103.140 174.146 96.940 47.335
[28] 174.146 248.469 273.908 273.908 160.553 284.628 199.604 135.759 45.907
[37] 200.331 330.071 137.486 90.852 382.014 130.828 230.173 97.563 226.385
[46] 196.969 230.173 206.961 193.413 129.143 384.077 166.900 335.877 228.409
[55] 200.331 97.563 101.480 177.629 226.385 149.958
[ reached getOption("max.print") -- omitted 355 entries ]
We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:
iai::score(grid, train_X, train_died, train_times,
criterion = "harrell_c_statistic")
[1] 0.7189903
Or on the test set:
iai::score(grid, test_X, test_died, test_times,
criterion = "harrell_c_statistic")
[1] 0.7091297
We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:
iai::score(grid, test_X, test_died, test_times, criterion = "auc",
evaluation_time = 10)
[1] 0.6896096
We can also look at the variable importance:
iai::variable_importance(iai::get_learner(grid))
Feature Importance
1 age 0.427724
2 creat 0.200167
3 hgb 0.154691
4 sex=M 0.115374
5 mspike 0.102045
We can calculate the SHAP values:
iai::predict_shap(grid, test_X)
$expected_value
[1] 0.04607966
$features
age hgb creat mspike sex=M
1 79 9.4 1.10 2.3 0
2 94 11.0 1.10 0.7 0
3 74 12.5 1.25 0.9 0
4 92 14.0 0.80 1.6 0
5 81 11.4 1.50 1.9 1
6 74 14.7 1.00 0.5 0
7 67 10.7 1.30 1.2 1
8 81 12.7 0.90 1.3 0
9 74 9.8 0.80 1.4 0
10 76 13.8 0.90 1.5 0
11 72 14.1 1.10 1.7 0
12 84 14.3 1.30 0.4 0
[ reached 'max' / getOption("max.print") -- omitted 403 rows ]
$shap_values
[,1] [,2] [,3] [,4] [,5]
[1,] 0.61058 0.30582 -0.04894 -0.01826 -0.14742
[2,] 1.07476 0.16585 -0.04894 -0.00356 -0.14742
[3,] 0.23345 0.18522 0.05296 -0.01366 -0.12105
[4,] 1.40230 -0.05959 -0.05695 -0.01366 -0.13249
[5,] 0.69856 0.26270 0.04495 -0.02039 0.16211
[6,] 0.18722 -0.14855 -0.04093 -0.00356 -0.10612
[7,] -0.18245 0.42472 0.12438 -0.01668 0.10355
[8,] 0.77792 0.12435 -0.05695 -0.01366 -0.14742
[9,] 0.13403 0.33627 -0.03813 -0.01366 -0.12105
[10,] 0.24808 -0.14301 -0.04354 -0.01366 -0.10612
[11,] 0.20854 -0.15517 -0.04093 -0.01826 -0.10612
[12,] 0.67713 -0.13543 0.03954 -0.00356 -0.13249
[ reached getOption("max.print") -- omitted 403 rows ]
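SHAP values are additive: for each row, expected_value plus the sum of that row's SHAP values equals the model output being explained. For a linear model with independent features, the exact SHAP value of feature j is w_j * (x_j - mean(x_j)), which makes the property easy to check by hand. The toy model below has hypothetical coefficients and illustrates the property only, not how predict_shap computes its values:

```r
# Toy linear model (hypothetical coefficients, not fitted to mgus2)
b0 <- -3
w <- c(age = 0.05, hgb = -0.1)
shap_X <- as.matrix(data.frame(age = c(79, 94, 74, 92),
                               hgb = c(9.4, 11.0, 12.5, 14.0)))

pred <- b0 + drop(shap_X %*% w)      # model output for each row
expected_value <- mean(pred)         # base value = average prediction

# Exact linear SHAP values: w_j * (x_j - mean(x_j))
centered <- sweep(shap_X, 2, colMeans(shap_X))
shap_values <- sweep(centered, 2, w, `*`)

# Additivity: base value + row sums reproduces every prediction
all.equal(expected_value + rowSums(shap_values), pred)  # TRUE
```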
We can then use the SHAP library to visualize these results in whichever way we prefer.
GLMNetCV Survival Learner
We will fit a glmnetcv_survival_learner directly. Since GLMNet does not support missing data, we first subset to complete cases:
train_cc <- complete.cases(train_X)
train_X_cc <- train_X[train_cc, ]
train_died_cc <- train_died[train_cc]
train_times_cc <- train_times[train_cc]
test_cc <- complete.cases(test_X)
test_X_cc <- test_X[test_cc, ]
test_died_cc <- test_died[test_cc]
test_times_cc <- test_times[test_cc]
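The assignments above apply the same row mask to three objects per split. A hypothetical helper, subset_complete, bundles that step so the outcomes cannot drift out of alignment with the features:

```r
# Hypothetical helper (not part of iai): drop incomplete feature rows and
# keep the death indicators and times aligned with the remaining rows.
subset_complete <- function(X, died, times) {
  keep <- complete.cases(X)
  list(X = X[keep, , drop = FALSE], died = died[keep], times = times[keep])
}

cc_toy <- data.frame(age = c(88, 78, NA, 68), creat = c(1.3, NA, 1.5, 1.2))
cc <- subset_complete(cc_toy, c(TRUE, TRUE, FALSE, TRUE), c(30, 25, 46, 92))
nrow(cc$X)   # 2
cc$times     # 30 92
```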
We can now proceed with fitting the learner:
lnr <- iai::glmnetcv_survival_learner(random_seed = 1)
iai::fit(lnr, train_X_cc, train_died_cc, train_times_cc)
Julia Object of type GLMNetCVSurvivalLearner.
Fitted GLMNetCVSurvivalLearner:
Constant: -3.319
Weights:
age: 0.0578244
creat: 0.18159
hgb: -0.0889831
sex=M: 0.300142
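Assuming the GLMNetCV survival learner fits a Cox proportional hazards model, so that the weights act on the log-hazard scale, exponentiating a weight gives the hazard ratio for a one-unit increase in that feature, and the constant cancels when comparing two patients. A sketch using the printed weights:

```r
# Weights copied from the fitted model printed above
weights <- c(age = 0.0578244, creat = 0.18159, hgb = -0.0889831,
             `sex=M` = 0.300142)

# Hazard ratio per one-unit increase (assumes a Cox log-linear hazard)
hr_per_unit <- exp(weights)

# Hazard ratio for a 10-year age difference, other features held fixed
hr_age_10y <- exp(10 * weights[["age"]])
round(hr_age_10y, 2)  # 1.78
```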
We can make predictions on new data using predict. For survival learners, this returns the appropriate survival curve for each point, which we can then use to query the mortality probability (the probability of death by time t) for any given time t (in this case, t = 10 months):
pred_curves <- iai::predict(lnr, test_X_cc)
t <- 10
pred_curves[[1]][t]
[1] 0.2006079
You can also query the probability for all of the points together by applying the [ function to each curve:
sapply(pred_curves, `[`, t)
[1] 0.20060790 0.37016508 0.12258563 0.25815025 0.26318357 0.09764988
[7] 0.13021192 0.16532398 0.14207864 0.11548805 0.09364096 0.18164974
[13] 0.19646153 0.22090925 0.11633882 0.32794891 0.35621026 0.13689713
[19] 0.18496651 0.11288777 0.29856028 0.32199141 0.11759813 0.23101774
[25] 0.14958820 0.22753554 0.36432456 0.14973809 0.04427567 0.06253306
[31] 0.05351038 0.13143858 0.04433383 0.14841218 0.17601814 0.10221974
[37] 0.02128404 0.13806728 0.13504396 0.03009027 0.15786415 0.08168353
[43] 0.27166587 0.07241175 0.11989440 0.08235410 0.08848366 0.12053551
[49] 0.18218162 0.01961504 0.12975188 0.03141674 0.08366623 0.09236169
[55] 0.23485295 0.15931575 0.06634960 0.06877971 0.11722330 0.10916430
[ reached getOption("max.print") -- omitted 341 entries ]
Alternatively, you can get this mortality probability for any given time t by providing t as a keyword argument directly:
iai::predict(lnr, test_X_cc, t = 10)
[1] 0.20060790 0.37016508 0.12258563 0.25815025 0.26318357 0.09764988
[7] 0.13021192 0.16532398 0.14207864 0.11548805 0.09364096 0.18164974
[13] 0.19646153 0.22090925 0.11633882 0.32794891 0.35621026 0.13689713
[19] 0.18496651 0.11288777 0.29856028 0.32199141 0.11759813 0.23101774
[25] 0.14958820 0.22753554 0.36432456 0.14973809 0.04427567 0.06253306
[31] 0.05351038 0.13143858 0.04433383 0.14841218 0.17601814 0.10221974
[37] 0.02128404 0.13806728 0.13504396 0.03009027 0.15786415 0.08168353
[43] 0.27166587 0.07241175 0.11989440 0.08235410 0.08848366 0.12053551
[49] 0.18218162 0.01961504 0.12975188 0.03141674 0.08366623 0.09236169
[55] 0.23485295 0.15931575 0.06634960 0.06877971 0.11722330 0.10916430
[ reached getOption("max.print") -- omitted 341 entries ]
We can also estimate the expected survival time for each point using predict_expected_survival_time:
iai::predict_expected_survival_time(lnr, test_X_cc)
[1] 69.71314 32.21543 118.46927 51.54736 50.33090 146.17630 111.58710
[8] 86.76397 102.05617 125.48250 151.52896 78.11343 71.41961 62.24844
[15] 124.60867 37.98087 33.97530 106.05376 76.52748 128.21215 42.92932
[22] 38.91239 123.33223 59.00796 96.67249 60.09240 32.93626 96.56967
[29] 248.91474 204.73586 225.09247 110.53838 248.75425 97.48532 80.93234
[36] 140.41078 324.85608 105.12962 107.54374 292.57871 91.24689 169.32196
[43] 48.38208 185.27838 121.05584 168.24554 158.85087 120.43181 77.85550
[50] 331.48611 111.98441 288.08520 166.16775 153.29812 57.85016 90.34568
[57] 196.89239 192.11650 123.71006 132.28116
[ reached getOption("max.print") -- omitted 341 entries ]
We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:
iai::score(lnr, train_X_cc, train_died_cc, train_times_cc,
criterion = "harrell_c_statistic")
[1] 0.6837434
Or on the test set:
iai::score(lnr, test_X_cc, test_died_cc, test_times_cc,
criterion = "harrell_c_statistic")
[1] 0.7027584
We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:
iai::score(lnr, test_X_cc, test_died_cc, test_times_cc, criterion = "auc",
evaluation_time = 10)
[1] 0.6823275
We can also look at the variable importance:
iai::variable_importance(lnr)
Feature Importance
1 age 0.5920800
2 hgb 0.1538955
3 sex=M 0.1274143
4 creat 0.1266102
5 mspike 0.0000000
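The zero importance for mspike matches its absence from the fitted weights. With an L1 (lasso) penalty, coefficients can be shrunk exactly to zero, which the soft-thresholding operator used in coordinate-descent lasso solvers illustrates. This is a sketch of that operator only; the penalty mix and value chosen by cross-validation here are not stated in this guide, and the inputs below are made up:

```r
# Soft-thresholding: S(z, lambda) = sign(z) * max(|z| - lambda, 0).
# In lasso coordinate descent each coefficient update passes through this
# operator, so weak effects (|z| <= lambda) become exactly 0.
soft_threshold <- function(z, lambda) {
  sign(z) * pmax(abs(z) - lambda, 0)
}

soft_threshold(c(0.30, -0.09, 0.02), lambda = 0.05)
# 0.25 -0.04 0.00 (up to floating-point rounding)
```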