
Quick Start Guide: Optimal Survival Trees

This is a Python version of the corresponding OptimalTrees quick start guide.

In this example we will use Optimal Survival Trees (OST) on the Monoclonal Gammopathy dataset (mgus2). First we load the data and split it into training and test sets:

import pandas as pd

# Load the dataset and mark sex as a categorical feature
df = pd.read_csv("mgus2.csv")
df['sex'] = df['sex'].astype('category')
      Unnamed: 0    id  age sex   hgb  ...  mspike  ptime  pstat  futime  death
0              1     1   88   F  13.1  ...     0.5     30      0      30      1
1              2     2   78   F  11.5  ...     2.0     25      0      25      1
2              3     3   94   M  10.5  ...     2.6     46      0      46      1
3              4     4   68   M  15.2  ...     1.2     92      0      92      1
4              5     5   90   F  10.7  ...     1.0      8      0       8      1
5              6     6   90   M  12.9  ...     0.5      4      0       4      1
6              7     7   89   F  10.5  ...     1.3    151      0     151      1
...          ...   ...  ...  ..   ...  ...     ...    ...    ...     ...    ...
1377        1378  1378   56   M  16.1  ...     0.5     59      0      59      0
1378        1379  1379   73   M  15.6  ...     0.5     48      0      48      0
1379        1380  1380   69   M  15.0  ...     0.0     22      0      22      1
1380        1381  1381   78   M  14.1  ...     1.9     35      0      35      0
1381        1382  1382   66   M  12.1  ...     0.0     31      0      31      1
1382        1383  1383   82   F  11.5  ...     2.3     38      1      61      0
1383        1384  1384   79   M   9.6  ...     1.7      6      0       6      1

[1384 rows x 11 columns]
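from interpretableai import iai

# Features: skip the first two columns (row index and id) and the last four
# (ptime, pstat, futime, death)
X = df.iloc[:, 2:-4]
# Outcome: whether death was observed, and the follow-up time
died = df.death == 1
times = df.futime
# Split into training and test sets with a fixed seed for reproducibility
(train_X, train_died, train_times), (test_X, test_died, test_times) = (
    iai.split_data('survival', X, died, times, seed=1))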
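
To make the roles of `died` and `times` more concrete, the following is a minimal pure-Python sketch of the Kaplan-Meier estimator, a standard way to summarize censored survival data. It is not part of interpretableai and is not how the package computes anything internally. The function name `kaplan_meier` and the toy data are made up for illustration and are not taken from mgus2.

```python
from collections import Counter


def kaplan_meier(times, died):
    """Return [(t, S(t))] at each distinct time where a death was observed."""
    deaths = Counter(t for t, d in zip(times, died) if d)
    exits = Counter(times)  # each subject leaves the risk set at their own time
    at_risk = len(times)
    survival = 1.0
    curve = []
    for t in sorted(exits):
        if deaths[t]:
            # Multiply by the fraction of at-risk subjects surviving past t
            survival *= 1 - deaths[t] / at_risk
            curve.append((t, survival))
        # Deaths and censored subjects at t both leave the risk set
        at_risk -= exits[t]
    return curve


# Toy data (hypothetical): follow-up time and whether death was observed.
# False means the subject was censored, i.e. still alive at last follow-up.
toy_times = [5, 8, 8, 12, 20]
toy_died = [True, False, True, True, False]

curve = kaplan_meier(toy_times, toy_died)
for t, s in curve:
    print(f"t={t}: estimated survival {s:.2f}")
```

Censored subjects never lower the curve directly, but they shrink the risk set for later times. This is why both the event indicator and the time are passed to `split_data` and `fit` rather than the time alone.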

Optimal Survival Trees

We will use a GridSearch to fit an OptimalTreeSurvivor, searching over max_depth values from 1 to 5 and validating with the integratedbrier criterion:

grid = iai.GridSearch(
    iai.OptimalTreeSurvivor(
        random_seed=1,
        missingdatamode='separate_class',
        criterion='localfulllikelihood',
    ),
    max_depth=range(1, 6),
)
grid.fit(train_X, train_died, train_times,
         validation_criterion='integratedbrier')
grid.get_learner()
Optimal Trees Visualization
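
Conceptually, the grid search above tries each candidate `max_depth`, scores the resulting model on validation data, and keeps the best one. The sketch below shows that idea in plain Python. It is an illustration under stated assumptions, not the implementation in interpretableai: `grid_search`, `toy_fit`, and `toy_score` are hypothetical names, the "model" is a stand-in dictionary, and the sketch assumes that a higher score is better.

```python
def grid_search(candidates, fit, score):
    """Fit one model per candidate and return (value, score, model) of the best.

    `fit` and `score` are caller-supplied callables (hypothetical here).
    A higher score is treated as better; ties keep the earliest candidate.
    """
    best = None
    for value in candidates:
        model = fit(value)
        s = score(model)
        if best is None or s > best[1]:
            best = (value, s, model)
    return best


# range(1, 6) yields depths 1 through 5; the upper bound 6 is excluded
depths = list(range(1, 6))


def toy_fit(depth):
    # Stand-in for training a tree of the given depth
    return {"max_depth": depth}


def toy_score(model):
    # Invented validation score that peaks at depth 3
    return -(model["max_depth"] - 3) ** 2


best_depth, best_score, best_model = grid_search(depths, toy_fit, toy_score)
print(best_depth, best_score)
```

In the real workflow, `grid.get_learner()` plays the role of `best_model`: it returns the learner fitted with the selected parameters.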