Quick Start Guide: Heuristic Survival Learners
This is a Python version of the corresponding Heuristics quick start guide.
In this example we will use survival learners from Heuristics on the Monoclonal Gammopathy dataset to predict patient survival time (refer to the data preparation guide to learn more about the data format for survival problems). First, we load in the data and split it into training and test datasets:
import pandas as pd
df = pd.read_csv("mgus2.csv")
df.sex = df.sex.astype('category')
Unnamed: 0 id age sex hgb ... mspike ptime pstat futime death
0 1 1 88 F 13.1 ... 0.5 30 0 30 1
1 2 2 78 F 11.5 ... 2.0 25 0 25 1
2 3 3 94 M 10.5 ... 2.6 46 0 46 1
3 4 4 68 M 15.2 ... 1.2 92 0 92 1
4 5 5 90 F 10.7 ... 1.0 8 0 8 1
5 6 6 90 M 12.9 ... 0.5 4 0 4 1
6 7 7 89 F 10.5 ... 1.3 151 0 151 1
... ... ... ... .. ... ... ... ... ... ... ...
1377 1378 1378 56 M 16.1 ... 0.5 59 0 59 0
1378 1379 1379 73 M 15.6 ... 0.5 48 0 48 0
1379 1380 1380 69 M 15.0 ... 0.0 22 0 22 1
1380 1381 1381 78 M 14.1 ... 1.9 35 0 35 0
1381 1382 1382 66 M 12.1 ... 0.0 31 0 31 1
1382 1383 1383 82 F 11.5 ... 2.3 38 1 61 0
1383 1384 1384 79 M 9.6 ... 1.7 6 0 6 1
[1384 rows x 11 columns]
from interpretableai import iai
X = df.iloc[:, 2:-4]
died = df.death == 1
times = df.futime
(train_X, train_died, train_times), (test_X, test_died, test_times) = (
iai.split_data('survival', X, died, times, seed=12345))
Random Forest Survival Learner
We will use a GridSearch to fit a RandomForestSurvivalLearner with some basic parameter validation:
grid = iai.GridSearch(
iai.RandomForestSurvivalLearner(
missingdatamode='separate_class',
random_seed=1,
),
max_depth=range(5, 11),
)
grid.fit(train_X, train_died, train_times)
All Grid Results:
Row │ max_depth train_score valid_score rank_valid_score
│ Int64 Float64 Float64 Int64
─────┼───────────────────────────────────────────────────────
1 │ 5 0.291592 0.216862 1
2 │ 6 0.310289 0.212261 2
3 │ 7 0.322973 0.206279 3
4 │ 8 0.329134 0.202891 4
5 │ 9 0.331598 0.20266 5
6 │ 10 0.332294 0.202477 6
Best Params:
max_depth => 5
Best Model - Fitted RandomForestSurvivalLearner
We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):
pred_curves = grid.predict(test_X)
t = 10
import numpy as np
np.array([c[t] for c in pred_curves])
array([0.18739649, 0.21215498, 0.0925912 , ..., 0.14985601, 0.08245875,
       0.17559601])
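These curves can be queried at multiple horizons in the same way; for example, a minimal sketch (using the same c[t] lookup as above) that evaluates each test point at 5, 10, and 20 months:
# Illustrative sketch: mortality probabilities at several horizons
horizons = [5, 10, 20]
np.array([[c[t] for t in horizons] for c in pred_curves])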
Alternatively, you can get this mortality probability for any given time t by providing t as a keyword argument directly:
grid.predict(test_X, t=10)
array([0.18739649, 0.21215498, 0.0925912 , ..., 0.14985601, 0.08245875,
       0.17559601])
We can also estimate the survival time for each point using predict_expected_survival_time:
grid.predict_expected_survival_time(test_X)
array([ 62.9600698 , 60.4519118 , 128.12247219, ..., 56.40262151,
        86.15765599, 116.58587113])
We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:
grid.score(train_X, train_died, train_times, criterion='harrell_c_statistic')
0.728648073106186
Or on the test set:
grid.score(test_X, test_died, test_times, criterion='harrell_c_statistic')
0.7130325263974928
We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:
grid.score(test_X, test_died, test_times, criterion='auc', evaluation_time=10)
0.7750750750750736
We can also look at the variable importance:
grid.get_learner().variable_importance()
  Feature  Importance
0     age    0.626810
1     hgb    0.187153
2   creat    0.109443
3  mspike    0.048901
4     sex    0.027694
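If a visual summary is preferred, the importances can also be plotted; the following is a minimal sketch assuming variable_importance() returns a pandas DataFrame with Feature and Importance columns, as the display above suggests (pandas plotting requires matplotlib):
imp = grid.get_learner().variable_importance()
# Horizontal bar chart of the feature importances
imp.plot.barh(x='Feature', y='Importance', legend=False)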
XGBoost Survival Learner
We will use a GridSearch to fit an XGBoostSurvivalLearner with some basic parameter validation:
grid = iai.GridSearch(
iai.XGBoostSurvivalLearner(
random_seed=1,
),
max_depth=range(2, 6),
num_round=[20, 50, 100],
)
grid.fit(train_X, train_died, train_times)
All Grid Results:
Row │ num_round max_depth train_score valid_score rank_valid_score
│ Int64 Int64 Float64 Float64 Int64
─────┼──────────────────────────────────────────────────────────────────
1 │ 20 2 0.215788 0.15041 1
2 │ 20 3 0.194945 -0.0670846 3
3 │ 20 4 0.183779 -0.121177 4
4 │ 20 5 0.0916703 -0.401915 7
5 │ 50 2 0.210932 0.0262287 2
6 │ 50 3 0.153306 -0.217315 6
7 │ 50 4 0.0281118 -0.513705 9
8 │ 50 5 -0.180287 -0.874789 10
9 │ 100 2 0.182667 -0.132212 5
10 │ 100 3 0.0250015 -0.512681 8
11 │ 100 4 -0.345342 -1.03887 11
12 │ 100 5 -0.907832 -1.75591 12
Best Params:
num_round => 20
max_depth => 2
Best Model - Fitted XGBoostSurvivalLearner
We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):
pred_curves = grid.predict(test_X)
t = 10
import numpy as np
np.array([c[t] for c in pred_curves])
array([0.14218, 0.19365, 0.10101, ..., 0.21825, 0.12885, 0.22829])
Alternatively, you can get this mortality probability for any given time t by providing t as a keyword argument directly:
grid.predict(test_X, t=10)
array([0.14218, 0.19365, 0.10101, ..., 0.21825, 0.12885, 0.22829])
We can also estimate the survival time for each point using predict_expected_survival_time:
grid.predict_expected_survival_time(test_X)
array([103.14 ,  73.356, 143.604, ...,  63.766, 114.09 ,  60.43 ])
We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:
grid.score(train_X, train_died, train_times, criterion='harrell_c_statistic')
0.71899034
Or on the test set:
grid.score(test_X, test_died, test_times, criterion='harrell_c_statistic')
0.70912965
We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:
grid.score(test_X, test_died, test_times, criterion='auc', evaluation_time=10)
0.68960961
We can also look at the variable importance:
grid.get_learner().variable_importance()
  Feature  Importance
0     age    0.427724
1   creat    0.200167
2     hgb    0.154691
3   sex=M    0.115374
4  mspike    0.102045
We can calculate the SHAP values:
s = grid.predict_shap(test_X)
s['shap_values']
array([[ 0.61058, 0.30582, -0.04894, -0.01826, -0.14742],
[ 1.07476, 0.16585, -0.04894, -0.00356, -0.14742],
[ 0.23345, 0.18522, 0.05296, -0.01366, -0.12105],
...,
[ 1.07879, -0.07867, 0.03954, -0.00568, 0.14121],
[ 0.6234 , -0.16209, 0.03954, -0.01578, 0.11068],
       [-0.26166,  0.29492,  0.19171,  0.81775,  0.18365]])
We can then use the SHAP library to visualize these results in whichever way we prefer.
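For instance, a minimal sketch using shap.summary_plot; note that the feature name order passed below is an assumption and should be checked against the column order of the SHAP matrix for your data:
import shap
# Beeswarm-style summary of the SHAP values computed above
# (the feature_names order here is assumed, not taken from the learner)
shap.summary_plot(s['shap_values'],
                  feature_names=['age', 'hgb', 'creat', 'mspike', 'sex=M'])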
GLMNetCV Survival Learner
We will fit a GLMNetCVSurvivalLearner directly. Since GLMNetCVLearner does not support missing data, we first subset to complete cases:
train_cc = ~train_X.isnull().any(axis=1)
train_X_cc = train_X[train_cc]
train_died_cc = train_died[train_cc]
train_times_cc = train_times[train_cc]
test_cc = ~test_X.isnull().any(axis=1)
test_X_cc = test_X[test_cc]
test_died_cc = test_died[test_cc]
test_times_cc = test_times[test_cc]
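As an optional sanity check, we can count how many complete cases remain in each split:
# Optional check: rows kept by the complete-case filter
print(train_cc.sum(), "of", len(train_cc), "training rows are complete cases")
print(test_cc.sum(), "of", len(test_cc), "test rows are complete cases")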
We can now proceed with fitting the learner:
lnr = iai.GLMNetCVSurvivalLearner(random_seed=1)
lnr.fit(train_X_cc, train_died_cc, train_times_cc)
Fitted GLMNetCVSurvivalLearner:
Constant: -3.319
Weights:
age: 0.0578244
creat: 0.18159
hgb: -0.0889831
sex=M: 0.300142
We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):
pred_curves = lnr.predict(test_X_cc)
t = 10
np.array([c[t] for c in pred_curves])
array([0.2006079 , 0.37016508, 0.12258563, ..., 0.29560852, 0.19178951,
       0.12366086])
Alternatively, we can get this mortality probability for any given time t by providing t as a keyword argument directly:
lnr.predict(test_X_cc, t=10)
array([0.2006079 , 0.37016508, 0.12258563, ..., 0.29560852, 0.19178951,
       0.12366086])
We can also estimate the survival time for each point:
lnr.predict_expected_survival_time(test_X_cc)
array([ 69.71314448, 32.21543141, 118.46926932, ..., 43.47980691,
        73.42616111, 117.45965404])
We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:
lnr.score(train_X_cc, train_died_cc, train_times_cc,
criterion='harrell_c_statistic')
0.6837433757104808
Or on the test set:
lnr.score(test_X_cc, test_died_cc, test_times_cc,
criterion='harrell_c_statistic')
0.7027583847744624
We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:
lnr.score(test_X_cc, test_died_cc, test_times_cc, criterion='auc', evaluation_time=10)
0.6823274764451254
We can also look at the variable importance:
lnr.variable_importance()
  Feature  Importance
0     age    0.592080
1     hgb    0.153895
2   sex=M    0.127414
3   creat    0.126610
4  mspike    0.000000