Quick Start Guide: Heuristic Survival Learners

This is a Python version of the corresponding Heuristics quick start guide.

In this example we will use survival learners from Heuristics on the Monoclonal Gammopathy dataset to predict patient survival time (refer to the data preparation guide to learn more about the data format for survival problems). First, we load the data and split it into training and testing datasets:

import pandas as pd
df = pd.read_csv("mgus2.csv")
df.sex = df.sex.astype('category')
df
      Unnamed: 0    id  age sex   hgb  ...  mspike  ptime  pstat  futime  death
0              1     1   88   F  13.1  ...     0.5     30      0      30      1
1              2     2   78   F  11.5  ...     2.0     25      0      25      1
2              3     3   94   M  10.5  ...     2.6     46      0      46      1
3              4     4   68   M  15.2  ...     1.2     92      0      92      1
4              5     5   90   F  10.7  ...     1.0      8      0       8      1
5              6     6   90   M  12.9  ...     0.5      4      0       4      1
6              7     7   89   F  10.5  ...     1.3    151      0     151      1
...          ...   ...  ...  ..   ...  ...     ...    ...    ...     ...    ...
1377        1378  1378   56   M  16.1  ...     0.5     59      0      59      0
1378        1379  1379   73   M  15.6  ...     0.5     48      0      48      0
1379        1380  1380   69   M  15.0  ...     0.0     22      0      22      1
1380        1381  1381   78   M  14.1  ...     1.9     35      0      35      0
1381        1382  1382   66   M  12.1  ...     0.0     31      0      31      1
1382        1383  1383   82   F  11.5  ...     2.3     38      1      61      0
1383        1384  1384   79   M   9.6  ...     1.7      6      0       6      1

[1384 rows x 11 columns]
from interpretableai import iai
X = df.iloc[:, 2:-4]
died = df.death == 1
times = df.futime
(train_X, train_died, train_times), (test_X, test_died, test_times) = (
    iai.split_data('survival', X, died, times, seed=12345))
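Conceptually, split_data partitions the three aligned inputs (features, event indicators, observed times) with one shared random index, so each patient's censoring status and time stay paired with their features. A minimal numpy sketch of the idea (the function name and the 70/30 split fraction below are illustrative, not part of the Heuristics API):

```python
import numpy as np

def split_survival(X, died, times, train_frac=0.7, seed=12345):
    """Simplified stand-in for a survival-aware train/test split:
    shuffle row indices once and apply the same split to features,
    event indicators, and observed times so all three stay aligned."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(times))
    cut = int(train_frac * len(times))
    tr, te = idx[:cut], idx[cut:]
    return ((X[tr], died[tr], times[tr]),
            (X[te], died[te], times[te]))

# Toy data: 10 patients, 2 features each
X_toy = np.arange(20.0).reshape(10, 2)
died_toy = np.array([1, 0, 1, 1, 0, 0, 1, 0, 1, 1], dtype=bool)
times_toy = np.array([30, 25, 46, 92, 8, 4, 151, 59, 48, 22])

train_toy, test_toy = split_survival(X_toy, died_toy, times_toy)
```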

Random Forest Survival Learner

We will use a GridSearch to fit a RandomForestSurvivalLearner with some basic parameter validation:

grid = iai.GridSearch(
    iai.RandomForestSurvivalLearner(
        missingdatamode='separate_class',
        random_seed=1,
    ),
    max_depth=range(5, 11),
)
grid.fit(train_X, train_died, train_times)
All Grid Results:

 Row │ max_depth  train_score  valid_score  rank_valid_score
     │ Int64      Float64      Float64      Int64
─────┼───────────────────────────────────────────────────────
   1 │         5     0.291592     0.216862                 1
   2 │         6     0.310289     0.212261                 2
   3 │         7     0.322973     0.206279                 3
   4 │         8     0.329134     0.202891                 4
   5 │         9     0.331598     0.20266                  5
   6 │        10     0.332294     0.202477                 6

Best Params:
  max_depth => 5

Best Model - Fitted RandomForestSurvivalLearner

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

pred_curves = grid.predict(test_X)
t = 10
import numpy as np
np.array([c[t] for c in pred_curves])
array([0.18739649, 0.21215498, 0.0925912 , ..., 0.14985601, 0.08245875,
       0.17559601])

Alternatively, you can get this mortality probability for any given time t by providing t as a keyword argument directly:

grid.predict(test_X, t=10)
array([0.18739649, 0.21215498, 0.0925912 , ..., 0.14985601, 0.08245875,
       0.17559601])
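Conceptually, each fitted curve is a step function over the observed event times, and indexing it at t looks up the cumulative mortality probability at that time. A toy numpy sketch of such a curve (the class below is illustrative, not part of the Heuristics API):

```python
import numpy as np

class StepSurvivalCurve:
    """Toy stand-in for a fitted survival curve: a right-continuous
    step function over event times. Indexing at t returns the
    mortality probability P(death <= t)."""

    def __init__(self, event_times, survival_probs):
        # survival_probs[i] is S(t) for event_times[i] <= t < event_times[i+1]
        self.event_times = np.asarray(event_times, dtype=float)
        self.survival_probs = np.asarray(survival_probs, dtype=float)

    def __getitem__(self, t):
        # Find the last event time at or before t; before the first
        # event, survival is 1 (no deaths have occurred yet).
        i = np.searchsorted(self.event_times, t, side='right') - 1
        surv = 1.0 if i < 0 else self.survival_probs[i]
        return 1.0 - surv  # mortality = 1 - survival

curve = StepSurvivalCurve([5, 10, 20], [0.9, 0.8, 0.6])
curve[10]  # mortality probability at t = 10
```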

We can also estimate the survival time for each point using predict_expected_survival_time:

grid.predict_expected_survival_time(test_X)
array([ 62.9600698 ,  60.4519118 , 128.12247219, ...,  56.40262151,
        86.15765599, 116.58587113])
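The expected survival time corresponds to the area under the survival curve. A minimal sketch of that integration for a step curve (illustrative only; the learner's own estimate may differ, for example in how the tail beyond the last event time is handled):

```python
import numpy as np

def expected_survival_time(event_times, survival_probs):
    """Approximate E[T] as the area under a step survival curve:
    E[T] = integral of S(t) dt, truncated at the last event time.
    Survival is 1 before the first observed event."""
    times = np.concatenate([[0.0], np.asarray(event_times, dtype=float)])
    probs = np.concatenate([[1.0], np.asarray(survival_probs, dtype=float)])
    widths = np.diff(times)               # duration of each step
    return float(np.sum(probs[:-1] * widths))

expected_survival_time([5, 10, 20], [0.9, 0.8, 0.6])
```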

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

grid.score(train_X, train_died, train_times, criterion='harrell_c_statistic')
0.728648073106186

Or on the test set:

grid.score(test_X, test_died, test_times, criterion='harrell_c_statistic')
0.7130325263974928
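For reference, Harrell's c-statistic is the fraction of comparable patient pairs that the model ranks correctly by risk, where a pair is comparable only if the patient with the shorter observed time actually died (otherwise we cannot tell who failed first). A naive O(n^2) numpy sketch (illustrative, not the implementation used by score):

```python
import numpy as np

def harrell_c(died, times, risk_scores):
    """Naive O(n^2) Harrell's c-statistic. A pair (i, j) is comparable
    when the subject with the shorter observed time died; the pair is
    concordant when that subject also has the higher predicted risk,
    and ties in risk count as half-concordant."""
    died = np.asarray(died, dtype=bool)
    times = np.asarray(times, dtype=float)
    risk = np.asarray(risk_scores, dtype=float)
    concordant, comparable = 0.0, 0
    for i in range(len(times)):
        for j in range(len(times)):
            if times[i] < times[j] and died[i]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1
                elif risk[i] == risk[j]:
                    concordant += 0.5
    return concordant / comparable

# Perfect ranking: earlier deaths receive strictly higher risk scores
harrell_c([1, 1, 0, 1], [2, 4, 6, 8], [4, 3, 2, 1])
```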

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

grid.score(test_X, test_died, test_times, criterion='auc', evaluation_time=10)
0.7750750750750736
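This time-dependent AUC treats the problem as binary classification at the evaluation time: patients who died by t are positives, patients still under observation past t are negatives, and patients censored before t are dropped because their status at t is unknown. A sketch of that reduction (illustrative; it ignores risk-score ties and is not the implementation used by score):

```python
import numpy as np

def survival_auc_at(died, times, risk_scores, t):
    """AUC of risk scores for the binary outcome "died by time t".
    Subjects censored before t carry no label at t and are dropped."""
    died = np.asarray(died, dtype=bool)
    times = np.asarray(times, dtype=float)
    risk = np.asarray(risk_scores, dtype=float)
    pos = died & (times <= t)      # death observed at or before t
    neg = times > t                # still under observation past t
    keep = pos | neg               # drop those censored before t
    pos, risk = pos[keep], risk[keep]
    # Mann-Whitney form of the AUC via the rank sum of the positives
    order = np.argsort(risk, kind='stable')
    ranks = np.empty(len(risk))
    ranks[order] = np.arange(1, len(risk) + 1)
    n_pos, n_neg = pos.sum(), (~pos).sum()
    return (ranks[pos].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

survival_auc_at([1, 1, 0, 0, 0], [3, 8, 5, 20, 30], [5, 4, 3, 2, 1], t=10)
```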

We can also look at the variable importance:

grid.get_learner().variable_importance()
  Feature  Importance
0     age    0.626810
1     hgb    0.187153
2   creat    0.109443
3  mspike    0.048901
4     sex    0.027694

XGBoost Survival Learner

We will use a GridSearch to fit an XGBoostSurvivalLearner with some basic parameter validation:

grid = iai.GridSearch(
    iai.XGBoostSurvivalLearner(
        random_seed=1,
    ),
    max_depth=range(2, 6),
    num_round=[20, 50, 100],
)
grid.fit(train_X, train_died, train_times)
All Grid Results:

 Row │ num_round  max_depth  train_score  valid_score  rank_valid_score
     │ Int64      Int64      Float64      Float64      Int64
─────┼──────────────────────────────────────────────────────────────────
   1 │        20          2    0.221512    0.164622                   1
   2 │        20          3    0.175214   -0.134411                   4
   3 │        20          4    0.143079   -0.120157                   3
   4 │        20          5    0.0489262  -0.474636                   7
   5 │        50          2    0.196386   -0.00737584                 2
   6 │        50          3    0.138169   -0.307373                   6
   7 │        50          4   -0.096126   -0.571589                   8
   8 │        50          5   -0.455181   -1.27913                   11
   9 │       100          2    0.11895    -0.226264                   5
  10 │       100          3   -0.122828   -0.851117                   9
  11 │       100          4   -0.558718   -1.20169                   10
  12 │       100          5   -1.32397    -2.35255                   12

Best Params:
  num_round => 20
  max_depth => 2

Best Model - Fitted XGBoostSurvivalLearner

We can make predictions on new data using predict. For survival learners, this returns the appropriate SurvivalCurve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):

pred_curves = grid.predict(test_X)
t = 10
import numpy as np
np.array([c[t] for c in pred_curves])
array([0.14660201, 0.18197281, 0.08934319, ..., 0.17136217, 0.13601602,
       0.19172853])

Alternatively, you can get this mortality probability for any given time t by providing t as a keyword argument directly:

grid.predict(test_X, t=10)
array([0.14660201, 0.18197281, 0.08934319, ..., 0.17136217, 0.13601602,
       0.19172853])

We can also estimate the survival time for each point using predict_expected_survival_time:

grid.predict_expected_survival_time(test_X)
array([ 99.87082531,  78.76795635, 159.46015182, ...,  84.2790036 ,
       107.98765307,  74.20213651])

We can evaluate the quality of the model using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:

grid.score(train_X, train_died, train_times, criterion='harrell_c_statistic')
0.7193754134965267

Or on the test set:

grid.score(test_X, test_died, test_times, criterion='harrell_c_statistic')
0.7093859379347714

We can also evaluate the performance of the model at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:

grid.score(test_X, test_died, test_times, criterion='auc', evaluation_time=10)
0.689069069069069

We can also look at the variable importance (note that the categorical sex feature appears here as the indicator sex_M):

grid.get_learner().variable_importance()
  Feature  Importance
0     age    0.655875
1     hgb    0.148613
2   creat    0.100927
3  mspike    0.055758
4   sex_M    0.038827

We can calculate the SHAP values:

s = grid.predict_shap(test_X)
s['shap_values']
array([[ 0.64009756,  0.29738566, -0.03529813, -0.02256296, -0.14502457],
       [ 0.97919667,  0.17663382, -0.03529813, -0.00424254, -0.14502457],
       [ 0.21231405,  0.13646832, -0.00737888, -0.014659  , -0.11917135],
       ...,
       [ 0.87897521, -0.12985606,  0.02753258, -0.00512661,  0.13342674],
       [ 0.64914954, -0.15287514,  0.06906411, -0.01554307,  0.10384575],
       [-0.35382563,  0.31217802,  0.22095543,  0.68206966,  0.16790301]])
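A useful sanity check on SHAP output is the additivity property: each row of SHAP values plus the base value should reconstruct that point's raw prediction margin. A toy numpy sketch of the check (the dict below is a stand-in for the result of predict_shap; the 'expected_value' key is an assumption following common SHAP conventions, so check the keys returned by your version):

```python
import numpy as np

# Toy stand-in for the dict returned by predict_shap. The key names
# follow common SHAP conventions ('expected_value' is an assumption
# here, not confirmed from the output above).
s_toy = {
    'shap_values': np.array([[ 0.6,  0.3, -0.1],
                             [-0.2,  0.1,  0.4]]),
    'expected_value': 0.5,
}

# Additivity: base value plus the per-feature contributions of each
# row reconstructs that point's raw prediction margin.
margins = s_toy['expected_value'] + s_toy['shap_values'].sum(axis=1)
```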

We can then use the SHAP library to visualize these results in whichever way we prefer.