Quick Start Guide: Optimal Survival Trees

This is a Python version of the corresponding OptimalTrees quick start guide.

In this example we will use Optimal Survival Trees (OST) on the Monoclonal Gammopathy dataset. First we load the data and split it into training and test sets:

import pandas as pd

# Load the data and store sex as a categorical feature
df = pd.read_csv("mgus2.csv")
df['sex'] = df['sex'].astype('category')
df
      Unnamed: 0    id  age sex   hgb  ...  mspike  ptime  pstat  futime  death
0              1     1   88   F  13.1  ...     0.5     30      0      30      1
1              2     2   78   F  11.5  ...     2.0     25      0      25      1
2              3     3   94   M  10.5  ...     2.6     46      0      46      1
3              4     4   68   M  15.2  ...     1.2     92      0      92      1
4              5     5   90   F  10.7  ...     1.0      8      0       8      1
5              6     6   90   M  12.9  ...     0.5      4      0       4      1
6              7     7   89   F  10.5  ...     1.3    151      0     151      1
...          ...   ...  ...  ..   ...  ...     ...    ...    ...     ...    ...
1377        1378  1378   56   M  16.1  ...     0.5     59      0      59      0
1378        1379  1379   73   M  15.6  ...     0.5     48      0      48      0
1379        1380  1380   69   M  15.0  ...     0.0     22      0      22      1
1380        1381  1381   78   M  14.1  ...     1.9     35      0      35      0
1381        1382  1382   66   M  12.1  ...     0.0     31      0      31      1
1382        1383  1383   82   F  11.5  ...     2.3     38      1      61      0
1383        1384  1384   79   M   9.6  ...     1.7      6      0       6      1

[1384 rows x 11 columns]
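from interpretableai import iai

# Features: drop the two leading index/id columns and the four trailing
# outcome columns (ptime, pstat, futime, death)
X = df.iloc[:, 2:-4]
# Event indicator: True if death was observed, False if the record is censored
died = df.death == 1
# Follow-up time until death or last contact
times = df.futime
(train_X, train_died, train_times), (test_X, test_died, test_times) = (
    iai.split_data('survival', X, died, times, seed=1))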
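
To make the idea of a reproducible train/test split concrete, here is a minimal sketch in plain Python. It is only an illustration of seeding a random split and is not the procedure used by iai.split_data; the helper name `seeded_split` and the 70/30 proportion are assumptions chosen for the example.

```python
import random
from typing import List, Tuple


def seeded_split(n_rows: int, train_fraction: float = 0.7,
                 seed: int = 1) -> Tuple[List[int], List[int]]:
    """Return (train_indices, test_indices) from a reproducible shuffle.

    Hypothetical helper for illustration; not part of interpretableai.
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError("train_fraction must be strictly between 0 and 1")
    indices = list(range(n_rows))
    # A dedicated Random instance keeps the split independent of global state
    random.Random(seed).shuffle(indices)
    n_train = round(n_rows * train_fraction)
    return sorted(indices[:n_train]), sorted(indices[n_train:])


# 1384 matches the number of rows in the data frame shown above
train_idx, test_idx = seeded_split(1384)
print(len(train_idx), len(test_idx))
```

With a pandas data frame, row positions like these could be applied with `X.iloc[train_idx]` and `X.iloc[test_idx]`; using the same seed always yields the same partition.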

Optimal Survival Trees

We will use a GridSearch to fit an OptimalTreeSurvivor:

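grid = iai.GridSearch(
    iai.OptimalTreeSurvivor(
        random_seed=1,
        # Treat missing values as their own category when splitting
        missingdatamode='separate_class',
        # Training criterion used to fit each tree
        criterion='localfulllikelihood',
    ),
    # Candidate tree depths 1 through 5
    max_depth=range(1, 6),
)
# Choose among the candidates by integrated Brier score on validation data
grid.fit(train_X, train_died, train_times,
         validation_criterion='integratedbrier')
grid.get_learner()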
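
As a rough picture of what a grid search over max_depth does, the simplified sketch below tries each candidate value, scores the resulting model on held-out data, and keeps the best one. The names `select_best`, `toy_fit`, and `toy_score` are hypothetical and are not part of the interpretableai API, the toy score merely stands in for a validation criterion, and the real GridSearch may do more than this.

```python
from typing import Callable, Iterable, Optional, Tuple, TypeVar

Model = TypeVar("Model")


def select_best(
    candidates: Iterable[int],
    fit: Callable[[int], Model],
    score: Callable[[Model], float],
) -> Tuple[int, float]:
    """Return (best_param, best_score), where higher scores are better.

    Hypothetical helper for illustration only.
    """
    best_param: Optional[int] = None
    best_score = float("-inf")
    for param in candidates:
        model = fit(param)        # train with this hyperparameter value
        current = score(model)    # evaluate on held-out validation data
        if current > best_score:  # ties keep the earlier candidate
            best_param, best_score = param, current
    if best_param is None:
        raise ValueError("no candidates supplied")
    return best_param, best_score


# Toy stand-ins: the "model" is just its depth, and the score peaks at depth 3
def toy_fit(depth: int) -> int:
    return depth


def toy_score(model: int) -> float:
    return -abs(model - 3)


best_depth, best_value = select_best(range(1, 6), toy_fit, toy_score)
print(best_depth, best_value)
```

Keeping the first candidate on ties favors the shallower tree here, which is one reasonable design choice when two depths validate equally well.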

We can make predictions on new data using predict. For survival trees, this returns a SurvivalCurve for each point, which we can then index by a time t to get the estimated mortality probability, i.e. the probability of death by time t (here t = 10):

pred_curves = grid.predict(test_X)
t = 10
[c[t] for c in pred_curves]
[0.2262443438914027, 0.10077519379844957, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.2262443438914027, 0.2262443438914027, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.2262443438914027, 0.040166204986149756, 0.040166204986149756, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.2262443438914027, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.2262443438914027, 0.040166204986149756, 0.040166204986149756, 0.2262443438914027, 0.2262443438914027, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.2262443438914027, 0.10077519379844957, 
0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.2262443438914027, 0.10077519379844957, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.2262443438914027, 0.2262443438914027, 0.10077519379844957, 0.040166204986149756, 0.2262443438914027, 0.2262443438914027, 0.2262443438914027, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.2262443438914027, 0.040166204986149756, 0.040166204986149756, 0.040166204986149756, 0.040166204986149756, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.2262443438914027, 0.10077519379844957, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.2262443438914027, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 
0.10077519379844957, 0.2262443438914027, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.040166204986149756, 0.2262443438914027, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.2262443438914027, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.2262443438914027, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.2262443438914027, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.2262443438914027, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.2262443438914027, 0.10077519379844957, 0.040166204986149756, 0.2262443438914027, 0.040166204986149756, 0.2262443438914027, 0.10077519379844957, 0.2262443438914027, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.2262443438914027, 0.040166204986149756, 0.2262443438914027, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.10077519379844957, 
0.040166204986149756, 0.040166204986149756, 0.040166204986149756, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.040166204986149756, 0.10077519379844957, 0.2262443438914027, 0.040166204986149756, 0.040166204986149756, 0.2262443438914027, 0.2262443438914027, 0.040166204986149756, 0.040166204986149756, 0.2262443438914027, 0.040166204986149756, 0.040166204986149756, 0.10077519379844957, 0.10077519379844957, 0.2262443438914027, 0.10077519379844957, 0.10077519379844957, 0.0401662
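
The predictions above take only a few distinct values, which is consistent with every point in the same leaf sharing one curve. To illustrate how a curve built from censored data can be queried at a time t, here is a minimal sketch of a Kaplan-Meier estimate in plain Python. It is an illustration only and is not claimed to be how SurvivalCurve is implemented; the function names and the toy data are assumptions.

```python
from typing import List, Sequence, Tuple


def kaplan_meier(times: Sequence[float],
                 died: Sequence[bool]) -> List[Tuple[float, float]]:
    """Return [(event_time, survival_probability)] for a step function S(t)."""
    observations = sorted(zip(times, died))
    n_at_risk = len(observations)
    survival = 1.0
    steps: List[Tuple[float, float]] = []
    i = 0
    while i < len(observations):
        t = observations[i][0]
        deaths = 0
        removed = 0
        # Group all observations that share the same time
        while i < len(observations) and observations[i][0] == t:
            deaths += int(observations[i][1])
            removed += 1
            i += 1
        if deaths > 0:
            # The curve only drops at times where a death was observed
            survival *= 1.0 - deaths / n_at_risk
            steps.append((t, survival))
        # Deaths and censored records both leave the risk set
        n_at_risk -= removed
    return steps


def mortality_at(steps: Sequence[Tuple[float, float]], t: float) -> float:
    """Estimated probability of death by time t, computed as 1 - S(t)."""
    survival = 1.0
    for event_time, value in steps:
        if event_time > t:
            break
        survival = value
    return 1.0 - survival


# Toy data (made up for this example): False marks a censored record
toy_times = [2, 3, 3, 5, 8]
toy_died = [True, True, False, True, False]

toy_steps = kaplan_meier(toy_times, toy_died)
print(mortality_at(toy_steps, 4))
```

Censored records never lower the curve themselves, but they shrink the risk set, so later deaths produce larger drops.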