Quick Start Guide: Heuristic Survival Learners

This is a Python version of the corresponding Heuristics quick start guide.

In this example, we will use survival learners from Heuristics on the Monoclonal Gammopathy dataset to predict patient survival time (refer to the data preparation guide to learn more about the data format for survival problems). First, we load the data and split it into training and test datasets:

import pandas as pd
# Load the MGUS2 data and mark sex as a categorical feature
df = pd.read_csv("mgus2.csv")
df['sex'] = df['sex'].astype('category')
df
      Unnamed: 0    id  age sex   hgb  ...  mspike  ptime  pstat  futime  death
0              1     1   88   F  13.1  ...     0.5     30      0      30      1
1              2     2   78   F  11.5  ...     2.0     25      0      25      1
2              3     3   94   M  10.5  ...     2.6     46      0      46      1
3              4     4   68   M  15.2  ...     1.2     92      0      92      1
4              5     5   90   F  10.7  ...     1.0      8      0       8      1
5              6     6   90   M  12.9  ...     0.5      4      0       4      1
6              7     7   89   F  10.5  ...     1.3    151      0     151      1
...          ...   ...  ...  ..   ...  ...     ...    ...    ...     ...    ...
1377        1378  1378   56   M  16.1  ...     0.5     59      0      59      0
1378        1379  1379   73   M  15.6  ...     0.5     48      0      48      0
1379        1380  1380   69   M  15.0  ...     0.0     22      0      22      1
1380        1381  1381   78   M  14.1  ...     1.9     35      0      35      0
1381        1382  1382   66   M  12.1  ...     0.0     31      0      31      1
1382        1383  1383   82   F  11.5  ...     2.3     38      1      61      0
1383        1384  1384   79   M   9.6  ...     1.7      6      0       6      1

[1384 rows x 11 columns]
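from interpretableai import iai

# Features: age, sex, hgb, creat, mspike (drops the row index and id columns
# and the ptime, pstat, futime and death columns)
X = df.iloc[:, 2:-4]
# Event indicator: True if the death was observed, False if censored
died = df.death == 1
# Follow-up time in months until death or last contact
times = df.futime
(train_X, train_died, train_times), (test_X, test_died, test_times) = (
    iai.split_data('survival', X, died, times, seed=12345))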
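Each observation pairs its features with an event indicator (died) and an observed time (futime); a False in died means the patient was censored, i.e. still alive at last contact. The following minimal sketch shows a seeded random train/test split of row indices in plain Python. It is not the implementation of iai.split_data; the helper name split_indices and the 70/30 proportion are assumptions for illustration.

```python
import random


def split_indices(n, train_proportion=0.7, seed=12345):
    """Hypothetical helper: shuffle row indices with a fixed seed and cut them
    into a training part and a test part (70/30 by default, an assumed ratio)."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    n_train = int(round(train_proportion * n))
    return sorted(indices[:n_train]), sorted(indices[n_train:])


# Toy survival data: (died, time) pairs; died=False marks a censored patient
died = [True, False, True, True, False, True, False, True, True, False]
times = [30, 61, 46, 92, 8, 4, 151, 59, 48, 22]

train_idx, test_idx = split_indices(len(times))
train = [(died[i], times[i]) for i in train_idx]
test = [(died[i], times[i]) for i in test_idx]
print(len(train), len(test))  # 7 3
```

Keeping the event indicator and the time together per row is what lets both halves of the split remain valid survival datasets.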

Random Forest Survival Learner

We will use a GridSearch to fit a RandomForestSurvivalLearner with some basic parameter validation:

grid = iai.GridSearch(
    iai.RandomForestSurvivalLearner(
        missingdatamode='separate_class',
        random_seed=1,
    ),
    max_depth=range(5, 11),
)
grid.fit(train_X, train_died, train_times)
All Grid Results:

 Row │ max_depth  train_score  valid_score  rank_valid_score
     │ Int64      Float64      Float64      Int64
─────┼───────────────────────────────────────────────────────
   1 │         5     0.291592     0.216862                 1
   2 │         6     0.310289     0.212261                 2
   3 │         7     0.322973     0.206279                 3
   4 │         8     0.329134     0.202891                 4
   5 │         9     0.331598     0.20266                  5
   6 │        10     0.332294     0.202477                 6

Best Params:
  max_depth => 5

Best Model - Fitted RandomForestSurvivalLearner
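The grid search keeps the parameter value with the highest validation score. As a sketch of that selection step only, the snippet below ranks the valid_score column copied from the results table and recovers the reported best max_depth; it does not reproduce how GridSearch computes the scores.

```python
# Validation scores copied from the "All Grid Results" table above
valid_scores = {
    5: 0.216862,
    6: 0.212261,
    7: 0.206279,
    8: 0.202891,
    9: 0.20266,
    10: 0.202477,
}

# Rank parameter values from best (highest validation score) to worst
ranked = sorted(valid_scores, key=valid_scores.get, reverse=True)
ranks = {depth: position + 1 for position, depth in enumerate(ranked)}
best_depth = ranked[0]

print(best_depth)  # 5
print(ranks)       # {5: 1, 6: 2, 7: 3, 8: 4, 9: 5, 10: 6}
```

Note that the training score keeps rising with depth while the validation score falls, which is why the shallowest depth wins here.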

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

pred_curves = grid.predict(test_X)
t = 10
import numpy as np
np.array([c[t] for c in pred_curves])
array([0.18739649, 0.21215498, 0.0925912 , ..., 0.14985601, 0.08245875,
       0.17559601])
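To make the indexing c[t] concrete, here is a minimal sketch of a step-function curve that stores cumulative mortality probabilities at a few time points and returns the value in effect at any query time. The class StepMortalityCurve and its data are hypothetical; IAI's SurvivalCurve may store and look up its values differently.

```python
import bisect


class StepMortalityCurve:
    """Hypothetical step-function curve: mortality[i] is the probability of
    death by times[i], held constant until the next time point."""

    def __init__(self, times, mortality):
        self.times = list(times)          # increasing time points
        self.mortality = list(mortality)  # non-decreasing probabilities

    def __getitem__(self, t):
        # Index of the last time point that is <= t
        i = bisect.bisect_right(self.times, t) - 1
        return 0.0 if i < 0 else self.mortality[i]


curve = StepMortalityCurve(times=[5, 12, 30], mortality=[0.05, 0.20, 0.45])
print(curve[10])  # 0.05 (the step at t=5 is still in effect)
print(curve[12])  # 0.2
print(curve[3])   # 0.0 (before the first time point)
```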

Alternatively, you can get this mortality probability for any given time t by providing t as a keyword argument directly:

grid.predict(test_X, t=10)
array([0.18739649, 0.21215498, 0.0925912 , ..., 0.14985601, 0.08245875,
       0.17559601])

We can also estimate the survival time for each point using predict_expected_survival_time:

grid.predict_expected_survival_time(test_X)
array([ 62.9600698 ,  60.4519118 , 128.12247219, ...,  56.40262151,
        86.15765599, 116.58587113])
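One standard way to turn a curve into an expected survival time is to integrate the survival function S(t) = 1 - F(t), where F(t) is the mortality probability by time t. The sketch below does this for a step function truncated at its last time point, so it gives a restricted mean. How predict_expected_survival_time handles the tail beyond the last time point is not stated here, so treat this as an illustration rather than IAI's formula.

```python
def restricted_mean_survival(times, mortality):
    """Area under the survival step function S(t) = 1 - F(t) from 0 to
    times[-1], where mortality[i] = F(t) for times[i] <= t < times[i+1]."""
    area = 0.0
    previous_time = 0.0
    survival = 1.0  # nobody has died before the first time point
    for time, dead_by_time in zip(times, mortality):
        area += survival * (time - previous_time)
        survival = 1.0 - dead_by_time
        previous_time = time
    return area


# Hypothetical curve: 5% dead by t=5, 20% by t=12, 45% by t=30
print(restricted_mean_survival([5, 12, 30], [0.05, 0.20, 0.45]))
```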

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

grid.score(train_X, train_died, train_times, criterion='harrell_c_statistic')
0.728648073106186

Or on the test set:

grid.score(test_X, test_died, test_times, criterion='harrell_c_statistic')
0.7130325263974928
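Harrell's c-statistic measures how well the predicted risks order the patients: among pairs where the patient with the shorter observed time is known to have died first, it is the fraction in which that patient also received the higher predicted risk, with tied predictions counted as half. The following plain-Python sketch implements that textbook definition; the function name is hypothetical, and IAI's implementation may treat tied times differently.

```python
def harrell_c(risk, times, died):
    """Harrell's concordance index.

    risk:  predicted risk per patient (higher means death expected sooner)
    times: observed time (death or censoring)
    died:  True if the death was observed, False if censored
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not died[i]:
            continue  # a censored patient cannot anchor a comparable pair
        for j in range(n):
            if times[i] < times[j]:  # i is known to have died before j
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    return concordant / comparable


times = [1, 2, 3, 4]
died = [True, True, False, True]
print(harrell_c([4, 3, 2, 1], times, died))  # 1.0 (risks perfectly ordered)
```

A value of 0.5 corresponds to random ordering, so the scores above indicate a clear but imperfect ranking of patients by risk.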

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

grid.score(test_X, test_died, test_times, criterion='auc', evaluation_time=10)
0.7750750750750736
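Evaluating at a fixed time turns the survival problem into a binary one: patients who died by the evaluation time are positives, patients still under observation after it are negatives, and patients censored before it have unknown status. The sketch below drops those unknown cases and computes the AUC as the fraction of positive/negative pairs in which the positive case has the higher predicted mortality. Excluding early-censored patients is an assumption of this sketch; IAI may handle censoring differently (for example, with weighting).

```python
def auc_at_time(mortality, times, died, evaluation_time):
    """AUC of predicted mortality-by-evaluation_time, ignoring patients
    censored before evaluation_time (their status at that time is unknown)."""
    positives = [p for p, t, d in zip(mortality, times, died)
                 if d and t <= evaluation_time]
    negatives = [p for p, t in zip(mortality, times) if t > evaluation_time]
    wins = 0.0
    for p in positives:
        for q in negatives:
            if p > q:
                wins += 1.0
            elif p == q:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


mortality = [0.9, 0.3, 0.2, 0.4, 0.6]
times = [5, 8, 20, 30, 7]
died = [True, True, False, True, False]  # last patient censored at t=7
print(auc_at_time(mortality, times, died, evaluation_time=10))  # 0.75
```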

We can also look at the variable importance:

grid.get_learner().variable_importance()
  Feature  Importance
0     age    0.626810
1     hgb    0.187153
2   creat    0.109443
3  mspike    0.048901
4     sex    0.027694
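The importances in this table sum to 1 (up to rounding), so each value can be read as a share of the total. The sketch below shows that normalization and the descending sort applied to made-up raw scores; it says nothing about how the random forest computes its raw importances.

```python
def normalized_importance(raw):
    """Scale non-negative raw importances to sum to 1 and sort them
    from most to least important."""
    total = sum(raw.values())
    shares = {name: value / total for name, value in raw.items()}
    return sorted(shares.items(), key=lambda item: item[1], reverse=True)


# Hypothetical raw importance scores, one per feature
raw = {"sex": 0.5, "age": 12.0, "hgb": 3.5, "creat": 2.0, "mspike": 1.0}
for name, share in normalized_importance(raw):
    print(f"{name:>7} {share:.6f}")

# The shares in the table above already sum to (almost exactly) 1
table = [0.626810, 0.187153, 0.109443, 0.048901, 0.027694]
print(round(sum(table), 5))  # 1.0
```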

XGBoost Survival Learner

We will use a GridSearch to fit an XGBoostSurvivalLearner with some basic parameter validation:

grid = iai.GridSearch(
    iai.XGBoostSurvivalLearner(
        random_seed=1,
    ),
    max_depth=range(2, 6),
    num_round=[20, 50, 100],
)
grid.fit(train_X, train_died, train_times)
All Grid Results:

 Row │ num_round  max_depth  train_score  valid_score  rank_valid_score
     │ Int64      Int64      Float64      Float64      Int64
─────┼──────────────────────────────────────────────────────────────────
   1 │        20          2    0.215788     0.15041                   1
   2 │        20          3    0.194945    -0.0670844                 3
   3 │        20          4    0.183779    -0.121177                  4
   4 │        20          5    0.0916702   -0.401915                  7
   5 │        50          2    0.210932     0.0262287                 2
   6 │        50          3    0.153306    -0.217315                  6
   7 │        50          4    0.0281122   -0.513705                  9
   8 │        50          5   -0.180287    -0.87479                  10
   9 │       100          2    0.182667    -0.132212                  5
  10 │       100          3    0.0250015   -0.512682                  8
  11 │       100          4   -0.345342    -1.03887                  11
  12 │       100          5   -0.907852    -1.75593                  12

Best Params:
  num_round => 20
  max_depth => 2

Best Model - Fitted XGBoostSurvivalLearner
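With two parameters, the grid contains every combination of the listed values: 3 values of num_round times 4 values of max_depth gives the 12 rows above. The sketch below enumerates the combinations with itertools.product in the same order as the table and picks the pair with the highest validation score, using the scores copied from the table. It illustrates the bookkeeping only, not how GridSearch trains or validates each model.

```python
from itertools import product

num_rounds = [20, 50, 100]
max_depths = range(2, 6)  # 2, 3, 4, 5

# Validation scores copied from the "All Grid Results" table above,
# listed in row order (num_round varies slowest)
valid_scores = [
    0.15041, -0.0670844, -0.121177, -0.401915,
    0.0262287, -0.217315, -0.513705, -0.87479,
    -0.132212, -0.512682, -1.03887, -1.75593,
]

grid = list(product(num_rounds, max_depths))
results = dict(zip(grid, valid_scores))
best_num_round, best_max_depth = max(results, key=results.get)

print(len(grid))                       # 12
print(best_num_round, best_max_depth)  # 20 2
```

The table shows the validation score dropping as either num_round or max_depth grows, so the smallest, simplest model is selected.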

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

pred_curves = grid.predict(test_X)
t = 10
import numpy as np
np.array([c[t] for c in pred_curves])
array([0.14218, 0.19365, 0.10101, ..., 0.21825, 0.12885, 0.22829])

Alternatively, you can get this mortality probability for any given time t by providing t as a keyword argument directly:

grid.predict(test_X, t=10)
array([0.14218, 0.19365, 0.10101, ..., 0.21825, 0.12885, 0.22829])

We can also estimate the survival time for each point using predict_expected_survival_time:

grid.predict_expected_survival_time(test_X)
array([103.14 ,  73.356, 143.604, ...,  63.766, 114.09 ,  60.43 ])

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

grid.score(train_X, train_died, train_times, criterion='harrell_c_statistic')
0.71899034

Or on the test set:

grid.score(test_X, test_died, test_times, criterion='harrell_c_statistic')
0.70912965

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

grid.score(test_X, test_died, test_times, criterion='auc', evaluation_time=10)
0.68960961

We can also look at the variable importance:

grid.get_learner().variable_importance()
  Feature  Importance
0     age    0.427724
1   creat    0.200167
2     hgb    0.154691
3   sex=M    0.115374
4  mspike    0.102045
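Unlike the random forest output, this table lists sex=M rather than sex, which suggests the categorical feature enters the model as a 0/1 indicator for one level. The sketch below shows that kind of one-hot encoding in plain Python; the feature=level naming mirrors the output above, but the helper name and the choice of which level to drop (here the first in sorted order) are assumptions.

```python
def one_hot_drop_first(values, name):
    """Encode a categorical column as 0/1 indicator columns named
    'name=level', dropping the first level (in sorted order) as the baseline."""
    levels = sorted(set(values))
    return {f"{name}={level}": [1 if v == level else 0 for v in values]
            for level in levels[1:]}


sex = ["F", "F", "M", "M", "F"]
print(one_hot_drop_first(sex, "sex"))  # {'sex=M': [0, 0, 1, 1, 0]}
```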

We can calculate the SHAP values:

s = grid.predict_shap(test_X)
s['shap_values']
array([[ 0.61058,  0.30582, -0.04894, -0.01826, -0.14742],
       [ 1.07476,  0.16585, -0.04894, -0.00356, -0.14742],
       [ 0.23345,  0.18522,  0.05296, -0.01366, -0.12105],
       ...,
       [ 1.07879, -0.07867,  0.03954, -0.00568,  0.14121],
       [ 0.6234 , -0.16209,  0.03954, -0.01578,  0.11068],
       [-0.26166,  0.29492,  0.19171,  0.81775,  0.18365]])

We can then use the SHAP library to visualize these results in whichever way we prefer.
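SHAP values are additive: for each row, a base value plus the sum of that row's SHAP values recovers the model's raw output for that row (for tree-boosting models this is usually on the untransformed margin scale). A common global summary is the mean absolute SHAP value per feature. The sketch below computes both from plain lists; the base value is a made-up number, and the rows reuse the first two rows printed above purely as sample input.

```python
def reconstruct_output(base_value, shap_row):
    """Additivity: model output for one row = base value + sum of its SHAP values."""
    return base_value + sum(shap_row)


def mean_abs_shap(shap_rows):
    """Global importance summary: mean absolute SHAP value of each column."""
    n_rows = len(shap_rows)
    n_cols = len(shap_rows[0])
    return [sum(abs(row[j]) for row in shap_rows) / n_rows for j in range(n_cols)]


# Sample input: the first two rows printed above; base_value is hypothetical
rows = [
    [0.61058, 0.30582, -0.04894, -0.01826, -0.14742],
    [1.07476, 0.16585, -0.04894, -0.00356, -0.14742],
]
base_value = -0.5

print(reconstruct_output(base_value, rows[0]))
print(mean_abs_shap(rows))
```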

GLMNetCV Survival Learner

We will fit a GLMNetCVSurvivalLearner directly. Since GLMNetCVSurvivalLearner does not support missing data, we first subset to complete cases:

train_cc = ~train_X.isnull().any(axis=1)
train_X_cc = train_X[train_cc]
train_died_cc = train_died[train_cc]
train_times_cc = train_times[train_cc]

test_cc = ~test_X.isnull().any(axis=1)
test_X_cc = test_X[test_cc]
test_died_cc = test_died[test_cc]
test_times_cc = test_times[test_cc]
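The boolean mask keeps a row only if none of its features is missing, and applying the same mask to the features, event indicators and times keeps the three aligned. Here is the same idea in plain Python, with None or NaN standing in for a missing value; the helper names are hypothetical.

```python
import math


def is_missing(value):
    """Treat None and float NaN as missing, similar to pandas' isnull()."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def complete_cases(rows, died, times):
    """Keep only rows with no missing feature, filtering died/times in step."""
    keep = [not any(is_missing(v) for v in row) for row in rows]
    return ([r for r, k in zip(rows, keep) if k],
            [d for d, k in zip(died, keep) if k],
            [t for t, k in zip(times, keep) if k])


# Toy data: the second and third rows each have a missing feature
rows = [[88, 13.1, 1.3], [78, None, 1.2], [94, 10.5, float("nan")], [68, 15.2, 0.9]]
died = [True, True, True, False]
times = [30, 25, 46, 92]
print(complete_cases(rows, died, times))
# ([[88, 13.1, 1.3], [68, 15.2, 0.9]], [True, False], [30, 92])
```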

We can now proceed with fitting the learner:

lnr = iai.GLMNetCVSurvivalLearner(random_seed=1)
lnr.fit(train_X_cc, train_died_cc, train_times_cc)
Fitted GLMNetCVSurvivalLearner:
  Constant: -3.319
  Weights:
    age:    0.0578244
    creat:  0.18159
    hgb:   -0.0889831
    sex=M:  0.300142
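Reading the fitted weights as a proportional hazards model (an assumption about how this learner is parameterized), the risk score for a patient is the constant plus the weighted sum of the features, and exp(weight) is the multiplicative change in hazard per unit increase of that feature with the others held fixed. mspike does not appear in the weights, so its weight is zero. The patient below is hypothetical.

```python
import math

# Constant and weights copied from the fitted learner above (mspike has weight 0)
constant = -3.319
weights = {"age": 0.0578244, "creat": 0.18159, "hgb": -0.0889831, "sex=M": 0.300142}


def linear_predictor(patient):
    """Constant plus the weighted sum of the patient's features."""
    return constant + sum(w * patient[name] for name, w in weights.items())


# Hypothetical patient: 70-year-old man, creat 1.2, hgb 13
patient = {"age": 70, "creat": 1.2, "hgb": 13, "sex=M": 1}
print(round(linear_predictor(patient), 4))  # 0.09

# Hazard ratios: multiplicative change in hazard per unit increase
hazard_ratios = {name: math.exp(w) for name, w in weights.items()}
print(round(hazard_ratios["sex=M"], 3))  # 1.35
print(round(hazard_ratios["age"], 3))    # 1.06
```

Under this reading, each extra year of age multiplies the hazard by about 1.06, and the negative hgb weight means higher hemoglobin is associated with lower hazard.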

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

pred_curves = lnr.predict(test_X_cc)
t = 10
np.array([c[t] for c in pred_curves])
array([0.2006079 , 0.37016508, 0.12258563, ..., 0.29560852, 0.19178951,
       0.12366086])

Alternatively, we can get this mortality probability for any given time t by providing t as a keyword argument directly:

lnr.predict(test_X_cc, t=10)
array([0.2006079 , 0.37016508, 0.12258563, ..., 0.29560852, 0.19178951,
       0.12366086])
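Under the same proportional hazards reading, each patient's survival curve is a shared baseline curve raised to the power exp(risk score), so the mortality probability by time t is 1 - S0(t) ** exp(lp). The sketch below evaluates that formula; the baseline survival value S0(10) = 0.9 is invented for illustration, and how the learner estimates its baseline curve is not covered here.

```python
import math


def mortality_by_time(baseline_survival_at_t, lp):
    """Proportional hazards: S(t | x) = S0(t) ** exp(lp), so F(t | x) = 1 - S(t | x)."""
    return 1.0 - baseline_survival_at_t ** math.exp(lp)


s0_at_10 = 0.9  # hypothetical baseline survival at t = 10
for lp in (-1.0, 0.0, 1.0):
    print(lp, round(mortality_by_time(s0_at_10, lp), 4))
```

A higher risk score always yields a higher mortality probability at every time, which is why a single linear score is enough to rank patients.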

We can also estimate the survival time for each point:

lnr.predict_expected_survival_time(test_X_cc)
array([ 69.71314448,  32.21543141, 118.46926932, ...,  43.47980691,
        73.42616111, 117.45965404])

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

lnr.score(train_X_cc, train_died_cc, train_times_cc,
          criterion='harrell_c_statistic')
0.6837433757104808

Or on the test set:

lnr.score(test_X_cc, test_died_cc, test_times_cc,
          criterion='harrell_c_statistic')
0.7027583847744624

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

lnr.score(test_X_cc, test_died_cc, test_times_cc, criterion='auc', evaluation_time=10)
0.6823274764451254

We can also look at the variable importance:

lnr.variable_importance()
  Feature  Importance
0     age    0.592080
1     hgb    0.153895
2   sex=M    0.127414
3   creat    0.126610
4  mspike    0.000000