Quick Start Guide: Heuristic Survival Learners
This is a Python version of the corresponding Heuristics quick start guide.
In this example we will use survival learners from Heuristics on the Monoclonal Gammopathy dataset to predict patient survival time (refer to the data preparation guide to learn more about the data format for survival problems). First, we load the data and split it into training and test datasets:
import pandas as pd
df = pd.read_csv("mgus2.csv")
df['sex'] = df['sex'].astype('category')
df
Unnamed: 0 id age sex hgb ... mspike ptime pstat futime death
0 1 1 88 F 13.1 ... 0.5 30 0 30 1
1 2 2 78 F 11.5 ... 2.0 25 0 25 1
2 3 3 94 M 10.5 ... 2.6 46 0 46 1
3 4 4 68 M 15.2 ... 1.2 92 0 92 1
4 5 5 90 F 10.7 ... 1.0 8 0 8 1
5 6 6 90 M 12.9 ... 0.5 4 0 4 1
6 7 7 89 F 10.5 ... 1.3 151 0 151 1
... ... ... ... .. ... ... ... ... ... ... ...
1377 1378 1378 56 M 16.1 ... 0.5 59 0 59 0
1378 1379 1379 73 M 15.6 ... 0.5 48 0 48 0
1379 1380 1380 69 M 15.0 ... 0.0 22 0 22 1
1380 1381 1381 78 M 14.1 ... 1.9 35 0 35 0
1381 1382 1382 66 M 12.1 ... 0.0 31 0 31 1
1382 1383 1383 82 F 11.5 ... 2.3 38 1 61 0
1383 1384 1384 79 M 9.6 ... 1.7 6 0 6 1
[1384 rows x 11 columns]
from interpretableai import iai
X = df.iloc[:, 2:-4]
died = df.death == 1
times = df.futime
(train_X, train_died, train_times), (test_X, test_died, test_times) = (
    iai.split_data('survival', X, died, times, seed=12345))
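By default the split puts roughly 70% of the points in the training set; we believe split_data accepts a train_proportion keyword argument to control this ratio (check the API reference). A quick sanity check of the split sizes:
len(train_X), len(test_X)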
Random Forest Survival Learner
We will use a GridSearch to fit a RandomForestSurvivalLearner with some basic parameter validation:
grid = iai.GridSearch(
    iai.RandomForestSurvivalLearner(
        missingdatamode='separate_class',
        random_seed=1,
    ),
    max_depth=range(5, 11),
)
grid.fit(train_X, train_died, train_times)
All Grid Results:
Row │ max_depth train_score valid_score rank_valid_score
│ Int64 Float64 Float64 Int64
─────┼───────────────────────────────────────────────────────
1 │ 5 0.291592 0.216862 1
2 │ 6 0.310289 0.212261 2
3 │ 7 0.322973 0.206279 3
4 │ 8 0.329134 0.202891 4
5 │ 9 0.331598 0.20266 5
6 │ 10 0.332294 0.202477 6
Best Params:
max_depth => 5
Best Model - Fitted RandomForestSurvivalLearner
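The winning parameter combination and the refitted learner can also be retrieved programmatically:
best_params = grid.get_best_params()  # {'max_depth': 5}, matching the output above
lnr_rf = grid.get_learner()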
We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):
pred_curves = grid.predict(test_X)
t = 10
import numpy as np
np.array([c[t] for c in pred_curves])
array([0.18739649, 0.21215498, 0.0925912 , ..., 0.14985601, 0.08245875,
0.17559601])
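Each curve can be queried at any number of time points, so we can trace out the predicted mortality profile of a single patient, for example:
curve = pred_curves[0]
[curve[t] for t in (5, 10, 20, 50)]  # mortality probabilities at 5, 10, 20, 50 months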
Alternatively, you can get this mortality probability for any given time t by providing t as a keyword argument directly:
grid.predict(test_X, t=10)
array([0.18739649, 0.21215498, 0.0925912 , ..., 0.14985601, 0.08245875,
0.17559601])
We can also estimate the survival time for each point using predict_expected_survival_time:
grid.predict_expected_survival_time(test_X)
array([ 62.9600698 , 60.4519118 , 128.12247219, ..., 56.40262151,
86.15765599, 116.58587113])
We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:
grid.score(train_X, train_died, train_times, criterion='harrell_c_statistic')
0.728648073106186
Or on the test set:
grid.score(test_X, test_died, test_times, criterion='harrell_c_statistic')
0.7130325263974928
We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:
grid.score(test_X, test_died, test_times, criterion='auc', evaluation_time=10)
0.7750750750750736
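To make the connection to standard classification metrics concrete, here is a rough cross-check of this value with scikit-learn: we label each patient by whether they died within 10 months, drop those censored before 10 months (whose status at that time is unknown), and score the predicted 10-month mortality probabilities. This simplifies how the criterion handles censoring, so the numbers will agree only approximately:
from sklearn.metrics import roc_auc_score
t = 10
known = (test_died | (test_times >= t)).values  # 10-month status is observable
actual = (test_died & (test_times < t)).values  # died within 10 months
pred = grid.predict(test_X, t=t)                # predicted mortality at t = 10
roc_auc_score(actual[known], pred[known])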
We can also look at the variable importance:
grid.get_learner().variable_importance()
Feature Importance
0 age 0.626810
1 hgb 0.187153
2 creat 0.109443
3 mspike 0.048901
4 sex 0.027694
XGBoost Survival Learner
We will use a GridSearch to fit an XGBoostSurvivalLearner with some basic parameter validation:
grid = iai.GridSearch(
    iai.XGBoostSurvivalLearner(
        random_seed=1,
    ),
    max_depth=range(2, 6),
    num_round=[20, 50, 100],
)
grid.fit(train_X, train_died, train_times)
All Grid Results:
Row │ num_round max_depth train_score valid_score rank_valid_score
│ Int64 Int64 Float64 Float64 Int64
─────┼──────────────────────────────────────────────────────────────────
1 │ 20 2 0.215788 0.15041 1
2 │ 20 3 0.194945 -0.0670846 3
3 │ 20 4 0.183779 -0.121177 4
4 │ 20 5 0.0916703 -0.401915 7
5 │ 50 2 0.210932 0.0262287 2
6 │ 50 3 0.153306 -0.217315 6
7 │ 50 4 0.0281118 -0.513705 9
8 │ 50 5 -0.180287 -0.874789 10
9 │ 100 2 0.182667 -0.132212 5
10 │ 100 3 0.0250015 -0.512681 8
11 │ 100 4 -0.345342 -1.03887 11
12 │ 100 5 -0.907832 -1.75591 12
Best Params:
num_round => 20
max_depth => 2
Best Model - Fitted XGBoostSurvivalLearner
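Since this learner wraps a trained XGBoost model internally, the underlying booster can be saved to file for use outside the package; we believe write_booster does this (the filename here is hypothetical; check the API reference for the supported formats):
grid.get_learner().write_booster('xgb_survival_booster.json')  # hypothetical filename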
We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):
pred_curves = grid.predict(test_X)
t = 10
import numpy as np
np.array([c[t] for c in pred_curves])
array([0.14218, 0.19365, 0.10101, ..., 0.21825, 0.12885, 0.22829])
Alternatively, you can get this mortality probability for any given time t by providing t as a keyword argument directly:
grid.predict(test_X, t=10)
array([0.14218, 0.19365, 0.10101, ..., 0.21825, 0.12885, 0.22829])
We can also estimate the survival time for each point using predict_expected_survival_time:
grid.predict_expected_survival_time(test_X)
array([103.14 , 73.356, 143.604, ..., 63.766, 114.09 , 60.43 ])
We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:
grid.score(train_X, train_died, train_times, criterion='harrell_c_statistic')
0.71899034
Or on the test set:
grid.score(test_X, test_died, test_times, criterion='harrell_c_statistic')
0.70912965
We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:
grid.score(test_X, test_died, test_times, criterion='auc', evaluation_time=10)
0.68960961
We can also look at the variable importance:
grid.get_learner().variable_importance()
Feature Importance
0 age 0.427724
1 creat 0.200167
2 hgb 0.154691
3 sex=M 0.115374
4 mspike 0.102045
We can calculate the SHAP values:
s = grid.predict_shap(test_X)
s['shap_values']
array([[ 0.61058, 0.30582, -0.04894, -0.01826, -0.14742],
[ 1.07476, 0.16585, -0.04894, -0.00356, -0.14742],
[ 0.23345, 0.18522, 0.05296, -0.01366, -0.12105],
...,
[ 1.07879, -0.07867, 0.03954, -0.00568, 0.14121],
[ 0.6234 , -0.16209, 0.03954, -0.01578, 0.11068],
[-0.26166, 0.29492, 0.19171, 0.81775, 0.18365]])
We can then use the SHAP library to visualize these results in whichever way we prefer.
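For instance, a minimal sketch of a summary plot with the shap package, assuming the dict returned by predict_shap also carries the corresponding feature matrix under a 'features' key (check the returned dict for its exact contents):
import shap
# Beeswarm summary of per-feature contributions; s['features'] is assumed
# to be the feature matrix aligned column-for-column with s['shap_values']
shap.summary_plot(s['shap_values'], s['features'])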
GLMNetCV Survival Learner
We will fit a GLMNetCVSurvivalLearner directly. Since GLMNetCVLearner does not support missing data, we first subset to complete cases:
# Keep only the training rows with no missing values
train_cc = ~train_X.isnull().any(axis=1)
train_X_cc = train_X[train_cc]
train_died_cc = train_died[train_cc]
train_times_cc = train_times[train_cc]
# Apply the same filtering to the test set
test_cc = ~test_X.isnull().any(axis=1)
test_X_cc = test_X[test_cc]
test_died_cc = test_died[test_cc]
test_times_cc = test_times[test_cc]
We can now proceed with fitting the learner:
lnr = iai.GLMNetCVSurvivalLearner(random_seed=1)
lnr.fit(train_X_cc, train_died_cc, train_times_cc)
Fitted GLMNetCVSurvivalLearner:
Constant: -3.319
Weights:
age: 0.0578244
creat: 0.18159
hgb: -0.0889831
sex=M: 0.300142
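The fitted coefficients shown above can also be retrieved programmatically; as with other linear learners in the package, we expect get_prediction_weights and get_prediction_constant to work here (check the API reference):
numeric_weights, categoric_weights = lnr.get_prediction_weights()
constant = lnr.get_prediction_constant()  # the intercept, -3.319 above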
We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):
pred_curves = lnr.predict(test_X_cc)
t = 10
np.array([c[t] for c in pred_curves])
array([0.2006079 , 0.37016508, 0.12258563, ..., 0.29560852, 0.19178951,
0.12366086])
Alternatively, we can get this mortality probability for any given time t by providing t as a keyword argument directly:
lnr.predict(test_X_cc, t=10)
array([0.2006079 , 0.37016508, 0.12258563, ..., 0.29560852, 0.19178951,
0.12366086])
We can also estimate the survival time for each point using predict_expected_survival_time:
lnr.predict_expected_survival_time(test_X_cc)
array([ 69.71314448, 32.21543141, 118.46926932, ..., 43.47980691,
73.42616111, 117.45965404])
We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:
lnr.score(train_X_cc, train_died_cc, train_times_cc,
          criterion='harrell_c_statistic')
0.6837433757104808
Or on the test set:
lnr.score(test_X_cc, test_died_cc, test_times_cc,
          criterion='harrell_c_statistic')
0.7027583847744624
We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:
lnr.score(test_X_cc, test_died_cc, test_times_cc, criterion='auc', evaluation_time=10)
0.6823274764451254
We can also look at the variable importance:
lnr.variable_importance()
Feature Importance
0 age 0.592080
1 hgb 0.153895
2 sex=M 0.127414
3 creat 0.126610
4 mspike 0.000000
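Finally, we can gather the out-of-sample results from the three learners side by side. Note that the GLMNetCV numbers were computed on complete cases only, so the comparison is indicative rather than exact:
import pandas as pd
# Test-set scores collected from the evaluations above
pd.DataFrame({
    'learner': ['random_forest', 'xgboost', 'glmnetcv'],
    'harrell_c': [0.7130, 0.7091, 0.7028],
    'auc_at_10': [0.7751, 0.6896, 0.6823],
})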