Quick Start Guide: Optimal Survival Trees
This is an R version of the corresponding OptimalTrees quick start guide.
In this example we will use Optimal Survival Trees (OST) on the Monoclonal Gammopathy dataset to predict patient survival time (refer to the data preparation guide to learn more about the data format for survival problems). First we load in the data and split into training and test datasets:
df <- read.table("mgus2.csv", sep = ",", header = T, stringsAsFactors = T)
  X id age sex  hgb creat mspike ptime pstat futime death
1 1  1  88   F 13.1   1.3    0.5    30     0     30     1
2 2  2  78   F 11.5   1.2    2.0    25     0     25     1
3 3  3  94   M 10.5   1.5    2.6    46     0     46     1
4 4  4  68   M 15.2   1.2    1.2    92     0     92     1
5 5  5  90   F 10.7   0.8    1.0     8     0      8     1
[ reached 'max' / getOption("max.print") -- omitted 1379 rows ]
X <- df[, 3:7]
died <- df$death == 1
times <- df$futime
split <- iai::split_data("survival", X, died, times, seed = 12345)
train_X <- split$train$X
train_died <- split$train$deaths
train_times <- split$train$times
test_X <- split$test$X
test_died <- split$test$deaths
test_times <- split$test$times
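For readers without the iai package, the same kind of split can be sketched in plain base R. The 70/30 proportion and seed below are illustrative assumptions, not necessarily iai::split_data's defaults, and split_data may apply additional stratification that this sketch does not:

```r
# Sketch: a plain random 70/30 train/test split in base R.
# The 0.7 proportion is an illustrative assumption, not iai's default.
set.seed(12345)
train_idx <- sample(length(times), size = floor(0.7 * length(times)))

train_X     <- X[train_idx, ]
train_died  <- died[train_idx]
train_times <- times[train_idx]
test_X      <- X[-train_idx, ]
test_died   <- died[-train_idx]
test_times  <- times[-train_idx]
```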
Optimal Survival Trees
We will use a grid_search to fit an optimal_tree_survival_learner:
grid <- iai::grid_search(
  iai::optimal_tree_survival_learner(
    random_seed = 1,
    missingdatamode = "separate_class",
    minbucket = 15
  ),
  max_depth = 1:2
)
iai::fit(grid, train_X, train_died, train_times,
validation_criterion = "harrell_c_statistic")
iai::get_learner(grid)
The survival tree shows the Kaplan-Meier survival curve for the population in each node as a solid red line.
In each split node:
- the dotted green line shows the survival curve of the population in the lower/left child
- the dotted blue line shows the survival curve of the population in the upper/right child
This means that the distance between the green and blue lines gives an indication of how well the split separates the two groups.
In each leaf node, the dotted black line shows the survival curve of the entire population, which allows us to easily see how the survival outlook for this subpopulation differs from the entire population.
In this example, age and hgb partition the population into subgroups with three distinct survival patterns:
- Node 3: when age < 67.5 and hgb < 12.25, the population has a survival similar to the overall average
- Node 4: when age < 67.5 and hgb > 12.25, the survival is significantly better than average
- Node 6: when 67.5 < age < 77.5, the survival is similar to the overall average
- Node 7: when age > 77.5, the survival is significantly worse than average
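The curves drawn in each node are standard Kaplan-Meier estimates, so you can reproduce the curve for any subgroup yourself with the survival package that ships with R. The sketch below does this for the Node 7 subgroup, using the split condition read off the tree above:

```r
library(survival)

# Kaplan-Meier estimate for the Node 7 subgroup (age > 77.5), as a sketch
node7 <- df$age > 77.5
km <- survfit(Surv(futime, death) ~ 1, data = df[node7, ])

# Survival probabilities at each observed event time
summary(km)
```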
We can make predictions on new data using predict. For survival trees, this returns the appropriate survival curve for each point, which we can then use to query the mortality probability for any given time t (in this case for t = 10):
pred_curves <- iai::predict(grid, test_X)
t <- 10
pred_curves[[1]][t]
[1] 0.1655385
You can also query the probability for all of the points together by applying the `[` function to each curve:
sapply(pred_curves, `[`, t)
[1] 0.16553847 0.16553847 0.09713376 0.16553847 0.16553847 0.09713376
[7] 0.24537037 0.16553847 0.09713376 0.09713376 0.09713376 0.16553847
[13] 0.16553847 0.16553847 0.09713376 0.16553847 0.16553847 0.09713376
[19] 0.16553847 0.09713376 0.16553847 0.16553847 0.09713376 0.16553847
[25] 0.09713376 0.24537037 0.16553847 0.09713376 0.03283167 0.03283167
[31] 0.03283167 0.09713376 0.03283167 0.24537037 0.09713376 0.16553847
[37] 0.09713376 0.03283167 0.16553847 0.16553847 0.03283167 0.16553847
[43] 0.09713376 0.16553847 0.03283167 0.09713376 0.09713376 0.03283167
[49] 0.09713376 0.09713376 0.03283167 0.09713376 0.03283167 0.09713376
[55] 0.09713376 0.16553847 0.09713376 0.24537037 0.03283167 0.09713376
[ reached getOption("max.print") -- omitted 355 entries ]
Alternatively, you can get this mortality probability for any given time t by providing t as a keyword argument directly:
iai::predict(grid, test_X, t = 10)
[1] 0.16553847 0.16553847 0.09713376 0.16553847 0.16553847 0.09713376
[7] 0.24537037 0.16553847 0.09713376 0.09713376 0.09713376 0.16553847
[13] 0.16553847 0.16553847 0.09713376 0.16553847 0.16553847 0.09713376
[19] 0.16553847 0.09713376 0.16553847 0.16553847 0.09713376 0.16553847
[25] 0.09713376 0.24537037 0.16553847 0.09713376 0.03283167 0.03283167
[31] 0.03283167 0.09713376 0.03283167 0.24537037 0.09713376 0.16553847
[37] 0.09713376 0.03283167 0.16553847 0.16553847 0.03283167 0.16553847
[43] 0.09713376 0.16553847 0.03283167 0.09713376 0.09713376 0.03283167
[49] 0.09713376 0.09713376 0.03283167 0.09713376 0.03283167 0.09713376
[55] 0.09713376 0.16553847 0.09713376 0.24537037 0.03283167 0.09713376
[ reached getOption("max.print") -- omitted 355 entries ]
We can evaluate the quality of the tree using score with any of the supported loss functions. For example, Harrell's c-statistic on the training set:
iai::score(grid, train_X, train_died, train_times,
criterion = "harrell_c_statistic")
[1] 0.6608295
Or on the test set:
iai::score(grid, test_X, test_died, test_times,
criterion = "harrell_c_statistic")
[1] 0.6668717
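Harrell's c-statistic is the fraction of comparable patient pairs that the model orders correctly: a pair is comparable when the patient with the shorter observed time actually died, and concordant when that patient was also assigned the higher risk. The base-R sketch below shows this pairwise counting; `harrell_c` is a hypothetical helper, not part of iai, and for simplicity it skips pairs with tied observed times:

```r
# Sketch of Harrell's c-statistic for right-censored data.
# risk: higher value means higher predicted risk of death.
# Simplification: pairs with tied observed times are skipped.
harrell_c <- function(times, died, risk) {
  concordant <- 0
  comparable <- 0
  n <- length(times)
  for (i in seq_len(n - 1)) {
    for (j in (i + 1):n) {
      a <- i; b <- j
      if (times[b] < times[a]) { a <- j; b <- i }  # a has the shorter time
      if (times[a] < times[b] && died[a]) {
        comparable <- comparable + 1
        if (risk[a] > risk[b]) {
          concordant <- concordant + 1
        } else if (risk[a] == risk[b]) {
          concordant <- concordant + 0.5  # tied risks get half credit
        }
      }
    }
  }
  concordant / comparable
}
```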
We can also evaluate the performance of the tree at a particular point in time using classification criteria. For instance, we can evaluate the AUC of the 10-month survival predictions on the test set:
iai::score(grid, test_X, test_died, test_times, criterion = "auc",
evaluation_time = 10)
[1] 0.6793093
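To see what such a time-horizon evaluation involves, the AUC at horizon t can be sketched in base R by treating death by time t as a binary label. The sketch simply drops patients censored before t, whereas iai::score applies a proper censoring adjustment, so the numbers will not match exactly; `auc_at` is a hypothetical helper, not part of iai:

```r
# Naive sketch: AUC of predicted mortality at horizon t.
# Patients censored before t (alive when last seen, before time t) are
# dropped; iai::score handles them with a censoring adjustment instead.
auc_at <- function(pred, died, times, t) {
  known <- died | times >= t            # status at time t is known
  y <- died[known] & times[known] <= t  # died by time t
  p <- pred[known]
  # Mann-Whitney (rank) form of the AUC
  (mean(rank(p)[y]) - (sum(y) + 1) / 2) / sum(!y)
}

auc_at(iai::predict(grid, test_X, t = 10), test_died, test_times, t = 10)
```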