Quick Start Guide: Optimal Survival Trees

This is an R version of the corresponding OptimalTrees quick start guide.

In this example we will use Optimal Survival Trees (OST) on the Monoclonal Gammopathy dataset. First, we load the data and split it into training and test datasets:

df <- read.table("mgus2.csv", sep = ",", header = TRUE)
df
  X id age sex  hgb creat mspike ptime pstat futime death
1 1  1  88   F 13.1   1.3    0.5    30     0     30     1
2 2  2  78   F 11.5   1.2    2.0    25     0     25     1
3 3  3  94   M 10.5   1.5    2.6    46     0     46     1
4 4  4  68   M 15.2   1.2    1.2    92     0     92     1
5 5  5  90   F 10.7   0.8    1.0     8     0      8     1
 [ reached 'max' / getOption("max.print") -- omitted 1379 rows ]
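
The mgus2 data also ships with the R survival package, so if mgus2.csv is not already on disk, a similar file can be regenerated. This is a sketch assuming the survival package is installed; note that the packaged data includes a dxyr column that the file above appears to omit:

# Sketch: regenerate mgus2.csv from the survival package.
# Assumptions: the guide's file drops the dxyr column, and write.csv
# keeps row names by default, which produces the X column seen above.
library(survival)
write.csv(mgus2[, names(mgus2) != "dxyr"], "mgus2.csv")
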
X <- df[, 3:7]
died <- df$death == 1
times <- df$futime
split <- iai::split_data("survival", X, died, times, seed = 1)
train_X <- split$train$X
train_died <- split$train$deaths
train_times <- split$train$times
test_X <- split$test$X
test_died <- split$test$deaths
test_times <- split$test$times
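
As a quick sanity check, we can confirm that the split produced aligned features and labels (plain base R; the exact train/test proportion follows split_data's defaults, so consult the iai documentation if you need to control it):

# Check split sizes and that labels line up with features.
nrow(train_X)
nrow(test_X)
length(train_died) == nrow(train_X)  # labels align with features
length(test_times) == nrow(test_X)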

Optimal Survival Trees

We will use a grid_search to fit an optimal_tree_survivor:

grid <- iai::grid_search(
    iai::optimal_tree_survivor(
        random_seed = 1,
        missingdatamode = "separate_class",
        criterion = "localfulllikelihood"
    ),
    max_depth = 1:5
)
iai::fit(grid, train_X, train_died, train_times,
         validation_criterion = "integratedbrier")
iai::get_learner(grid)
[Optimal Trees Visualization: interactive display of the fitted survival tree]
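
To keep a copy of this visualization outside the console, one option is to write it to a standalone HTML file. This is a sketch assuming iai::write_html is available in your installed version:

# Sketch: save the interactive tree visualization to disk.
# Assumption: iai::write_html exists in the installed iai version.
lnr <- iai::get_learner(grid)
iai::write_html("survival_tree.html", lnr)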

We can make predictions on new data using predict. For survival trees, this returns the appropriate survival curve for each point, which we can then use to query the survival probability at any time t (in this case for t = 10):

pred_curves <- iai::predict(grid, test_X)
t <- 10
pred_curves[[1]][t]
[1] 0.2262443
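
If you want to inspect the full curve rather than a single evaluation, the underlying data can be extracted. This is a sketch assuming iai::get_survival_curve_data is available in your installed version:

# Sketch: pull the raw curve data behind the first prediction.
# Assumption: get_survival_curve_data exists in the installed iai version.
curve_data <- iai::get_survival_curve_data(pred_curves[[1]])
str(curve_data)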

You can also query the probability for all of the points together by applying the [ function to each curve:

sapply(pred_curves, `[`, t)
 [1] 0.2262443 0.1007752 0.2262443 0.1007752 0.1007752 0.2262443 0.1007752
 [8] 0.1007752 0.2262443 0.2262443 0.1007752 0.1007752 0.2262443 0.2262443
[15] 0.2262443 0.2262443 0.1007752 0.1007752 0.1007752 0.1007752 0.1007752
[22] 0.1007752 0.2262443 0.1007752 0.1007752 0.0401662 0.1007752 0.2262443
[29] 0.1007752 0.1007752 0.1007752 0.0401662 0.1007752 0.0401662 0.1007752
[36] 0.1007752 0.2262443 0.2262443 0.0401662 0.0401662 0.2262443 0.1007752
[43] 0.1007752 0.0401662 0.1007752 0.1007752 0.0401662 0.1007752 0.1007752
[50] 0.1007752 0.0401662 0.0401662 0.0401662 0.1007752 0.0401662 0.1007752
[57] 0.1007752 0.0401662 0.1007752 0.0401662
 [ reached getOption("max.print") -- omitted 355 entries ]
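
Since each query returns a single number, a type-safe alternative to sapply is vapply, which raises an error if any curve yields anything other than one numeric value:

# Equivalent to the sapply call above, with an explicit output type.
probs <- vapply(pred_curves, `[`, numeric(1), t)
head(probs)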

We can evaluate the quality of the tree using score with any of the supported loss functions. For example, the integrated Brier score on the training set:

iai::score(grid, train_X, train_died, train_times,
           criterion = "integratedbrier")
[1] 0.1794539

Or the local full likelihood on the test set:

iai::score(grid, test_X, test_died, test_times,
           criterion = "localfulllikelihood")
[1] 0.1147257
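
To compare criteria side by side, we can loop over them with plain R around the score calls shown above:

# Evaluate both criteria used in this guide on the test set.
for (crit in c("integratedbrier", "localfulllikelihood")) {
  s <- iai::score(grid, test_X, test_died, test_times, criterion = crit)
  cat(sprintf("%s: %.4f\n", crit, s))
}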