Decision Tree Structure
This page contains a guide to the structure of trees created by IAI algorithms, and defines a number of terms that we use when referring to the tree structure.
A tree is composed of nodes, each of which is either a split or a leaf:
- A split node has two children, which we refer to as the lower child and the upper child. It contains a split rule that decides whether a given point goes to the lower or upper child.
- A leaf node contains a prediction rule for making predictions for points that reach it. The type of prediction depends on the problem task (classification, regression, etc.).
The tree makes predictions by sending each point down the tree, starting from the root node. At a split node, the point is sent to the lower or upper child based on the split rule in the node. When it reaches a leaf node, the prediction rule is applied to generate a prediction for this point.
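As a rough illustration of this traversal, the plain-Julia sketch below builds a tiny tree and sends points down it. The types ToySplit and ToyLeaf and the function toy_predict are hypothetical and are not part of the IAI API. For simplicity, every split uses a rule of the form x[j] < b and every leaf makes a constant prediction.

```julia
# Hypothetical node types for illustration only; they are not IAI types
abstract type ToyNode end

struct ToySplit <: ToyNode
    feature::Int        # index j of the feature used in the split rule
    threshold::Float64  # threshold b; points with x[j] < b go to the lower child
    lower::ToyNode
    upper::ToyNode
end

struct ToyLeaf <: ToyNode
    prediction::Float64  # constant prediction for points reaching this leaf
end

# Start at the root, follow the split rules until a leaf is reached, then predict
function toy_predict(root::ToyNode, x::AbstractVector{<:Real})
    node = root
    while node isa ToySplit
        node = x[node.feature] < node.threshold ? node.lower : node.upper
    end
    return node.prediction
end

# The root splits on feature 1 at 0.5; its lower child splits on feature 2 at 2.0
tree = ToySplit(1, 0.5,
                ToySplit(2, 2.0, ToyLeaf(1.0), ToyLeaf(2.0)),
                ToyLeaf(3.0))

toy_predict(tree, [0.2, 1.0])  # 1.0 (lower, then lower)
toy_predict(tree, [0.2, 5.0])  # 2.0 (lower, then upper)
toy_predict(tree, [0.9, 1.0])  # 3.0 (upper)
```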
The number of nodes in the tree can be found with get_num_nodes.
The following information is available for all types of nodes:
- is_leaf: checks whether a node is a leaf
- get_depth: returns the depth of the node in the tree
- get_parent: returns the index of this node's parent in the tree
- get_num_samples: returns the number of training samples that reached this node
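The relationship between a node's parent and its depth can be sketched in a few lines. Following parent indices up to the root and counting the steps gives the depth. The parents vector and node_depth function below are hypothetical. Using 0 to mark the root's missing parent and giving the root depth 0 are assumptions made for this illustration.

```julia
# Hypothetical parent indices for a 5-node tree: node 1 is the root (marked
# with parent 0), nodes 2 and 3 are its children, and nodes 4 and 5 are the
# children of node 2
parents = [0, 1, 1, 2, 2]

# Count the steps from node `t` up to the root, assuming the root has depth 0
function node_depth(parents::Vector{Int}, t::Int)
    depth = 0
    while parents[t] != 0
        t = parents[t]
        depth += 1
    end
    return depth
end

[node_depth(parents, t) for t in 1:length(parents)]  # [0, 1, 1, 2, 2]
```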
Split Nodes
At a split node, you can retrieve the lower and upper children using get_lower_child and get_upper_child. In visualizations, the lower child is displayed as either the left or top child, while the upper child is displayed as the right or bottom child. All nodes are also labeled with their index in the visualization for easy reference.
There are a number of different types of split rules that can be applied by a split node. You can use the following functions to determine the type of split:
- is_parallel_split
- is_hyperplane_split
- is_categoric_split
- is_ordinal_split
- is_mixed_parallel_split
- is_mixed_ordinal_split
Parallel splits
A parallel split is the split typically used by other decision tree methods like CART. It specifies a single numeric feature $j$ in the data, and a threshold value $b$. A point $\mathbf{x}$ follows the lower branch if its value in this feature is lower than the threshold:
\[x_j < b\]
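Here is a minimal sketch of this rule in plain Julia. For illustration, the point is stored as a numeric vector and the feature is identified by its position. Because the inequality is strict, a point whose value equals the threshold follows the upper branch.

```julia
# Lower-branch test for a parallel split on feature `j` with threshold `b`
parallel_goes_lower(x::AbstractVector{<:Real}, j::Int, b::Real) = x[j] < b

x = [1.5, 7.0, -2.0]
parallel_goes_lower(x, 2, 5.0)  # false: 7.0 is not less than 5.0
parallel_goes_lower(x, 3, 0.0)  # true: -2.0 < 0.0
parallel_goes_lower(x, 1, 1.5)  # false: a value equal to b follows the upper branch
```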
This information can be queried using:
Hyperplane splits
A hyperplane split applies a split that involves many features of the data at once. It specifies a set of features $\mathcal{J}$, with corresponding weights $a_j$ for each feature $j$ in $\mathcal{J}$, as well as a threshold value $b$. Categoric features can be included by specifying individual weights for any or all of the categoric levels. A point $\mathbf{x}$ follows the lower branch if the weighted sum of these features is lower than the threshold:
\[\sum_{j \in \mathcal{J}} a_j x_j < b\]
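The weighted sum can be sketched as follows. The representation is an assumption made for this illustration. Numeric weights are stored as a map from feature to weight. Categoric weights are stored as a map from feature to a map from level to weight, so a categoric feature contributes the weight of the level the point takes, or zero if that level has no weight. The point, features, and weights are all hypothetical.

```julia
# Weighted sum for a hyperplane split: numeric features contribute a_j * x_j,
# and each categoric feature contributes the weight of the level the point
# takes (zero if that level was not given a weight)
function hyperplane_value(x::Dict{Symbol,Any},
                          numeric_weights::Dict{Symbol,Float64},
                          categoric_weights::Dict{Symbol,Dict{String,Float64}})
    value = 0.0
    for (feature, a) in numeric_weights
        value += a * x[feature]
    end
    for (feature, level_weights) in categoric_weights
        value += get(level_weights, x[feature], 0.0)
    end
    return value
end

x = Dict{Symbol,Any}(:age => 40.0, :income => 3.0, :region => "north")
numeric_weights = Dict(:age => 0.25, :income => -2.0)
categoric_weights = Dict(:region => Dict("north" => 0.5, "south" => -0.5))
b = 5.0

s = hyperplane_value(x, numeric_weights, categoric_weights)  # 10.0 - 6.0 + 0.5 = 4.5
s < b  # true, so the point follows the lower branch
```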
This information can be queried using:
Categoric splits
A categoric split applies to a single categoric feature $j$ in the data and specifies a subset $\mathcal{L}$ of the levels of this feature. A point $\mathbf{x}$ follows the lower branch if its value in this feature belongs to the subset:
\[x_j \in \mathcal{L}\]
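Here is a minimal sketch of the categoric rule. The feature, its levels, and the subset are hypothetical. The example highlights that the subset can be any collection of levels, with no ordering involved.

```julia
# A hypothetical categoric feature with four levels, and a chosen subset L
colors = ["red", "green", "blue", "yellow"]
L = Set(["red", "blue"])

# A point follows the lower branch exactly when its level belongs to L
categoric_goes_lower(level::AbstractString, L::AbstractSet) = level in L

# L can be any subset of the levels; no ordering is involved
[categoric_goes_lower(level, L) for level in colors]  # [true, false, true, false]
```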
This information can be queried using:
Ordinal splits
An ordinal split is similar to a categoric split in that it applies to a single ordinal feature $j$ in the data and specifies a subset $\mathcal{L}$ of the levels of this feature. The key difference from a categoric split is that all levels in the subset $\mathcal{L}$ are less than all levels not in $\mathcal{L}$, meaning the split respects the ordering of the ordinal feature:
\[x_j \in \mathcal{L}\]
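The ordering restriction can be checked with a short sketch. If we list the levels from lowest to highest and mark each one by whether it is in $\mathcal{L}$, a valid ordinal subset gives a run of trues followed by falses. The function respects_ordering and the example levels are hypothetical.

```julia
# Check whether a subset L respects the ordering of an ordinal feature: reading
# the levels from lowest to highest, every level in L must come before every
# level not in L, so membership must be a run of `true` followed by `false`
function respects_ordering(ordered_levels::AbstractVector, L::AbstractSet)
    in_L = [level in L for level in ordered_levels]
    return issorted(in_L, rev = true)
end

sizes = ["small", "medium", "large"]  # hypothetical levels, lowest to highest
respects_ordering(sizes, Set(["small", "medium"]))  # true: a valid ordinal split
respects_ordering(sizes, Set(["small", "large"]))   # false: only valid as a categoric split
```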
This information can be queried using:
Mixed splits
A mixed split is a split rule that applies to mixed data. The split rule combines the split rules of either parallel and categoric, or ordinal and categoric as appropriate.
Mixed parallel splits
A mixed parallel split applies a split for a single mixed numeric/categoric feature $j$ in the data. It selects a threshold $b$ and a subset of the categoric levels $\mathcal{L}$. A point $\mathbf{x}$ follows the lower branch if its value in this feature is numeric and is lower than the threshold, or if the value is categoric and belongs to the subset:
\[\begin{cases} x_j < b, & \text{if } x_j \text{ is numeric}\\ x_j \in \mathcal{L}, & \text{otherwise} \end{cases}\]
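Here is a sketch of the mixed parallel rule. It treats any Real value as numeric and anything else as a categoric level, which mirrors the isa(x, Real) check in the reimplementation of apply at the end of this page. The feature, threshold, and levels are hypothetical.

```julia
# Lower-branch test for a mixed parallel split: numeric values are compared
# with the threshold b, and categoric values are checked against the subset L
mixed_parallel_goes_lower(x, b::Real, L::AbstractSet) =
    x isa Real ? x < b : x in L

# Hypothetical lab measurement recorded either as a number or as a text code
b = 10.0
L = Set(["below range"])

mixed_parallel_goes_lower(4.0, b, L)            # true: numeric and below b
mixed_parallel_goes_lower(12.0, b, L)           # false: numeric and not below b
mixed_parallel_goes_lower("below range", b, L)  # true: categoric and in L
mixed_parallel_goes_lower("above range", b, L)  # false: categoric and not in L
```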
This information can be queried using:
Mixed ordinal splits
A mixed ordinal split applies a split for a single mixed ordinal/categoric feature $j$ in the data. It specifies a subset of the categoric and ordinal levels $\mathcal{L}$, with the restriction that all ordinal levels in the subset must be less than all ordinal levels not in the subset. A point $\mathbf{x}$ follows the lower branch if the value in this feature belongs to the subset:
\[x_j \in \mathcal{L}\]
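Here is a sketch of the mixed ordinal rule and its restriction. Only the ordinal levels have to respect the ordering, and the categoric levels may be placed in the subset freely. The feature, its levels, and the function names are hypothetical.

```julia
# Hypothetical mixed feature: ordinal levels listed from lowest to highest,
# plus a categoric level "not applicable" that has no place in the ordering
ordinal_levels = ["low", "medium", "high"]

# L is valid if its ordinal levels all come before the ordinal levels outside
# L; categoric levels may be included in L or left out freely
function valid_mixed_ordinal_subset(ordinal_levels::AbstractVector, L::AbstractSet)
    in_L = [level in L for level in ordinal_levels]
    return issorted(in_L, rev = true)
end

# A point follows the lower branch if its level (of either kind) belongs to L
mixed_ordinal_goes_lower(level, L::AbstractSet) = level in L

L = Set(["low", "not applicable"])
valid_mixed_ordinal_subset(ordinal_levels, L)                                  # true
valid_mixed_ordinal_subset(ordinal_levels, Set(["medium", "not applicable"]))  # false
mixed_ordinal_goes_lower("not applicable", L)  # true
mixed_ordinal_goes_lower("high", L)            # false
```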
This information can be queried using:
Missing data
Regardless of the type of split rule, the split node also contains a separate rule for missing data. For a given point, if any of the information required to evaluate the split rule is missing, the missing-data rule at this split specifies whether to send the point to the lower or upper child.
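Here is a minimal sketch of how the missing-data rule overrides the split rule, shown for a parallel split. For a hyperplane split, the same override would apply if any feature in the weighted sum were missing. The function goes_lower_with_missing and its missing_lower flag are hypothetical stand-ins for the rule stored in the node.

```julia
# Hypothetical parallel split with a missing-data rule: when the value the split
# needs is missing, `missing_lower` decides the branch instead of x[j] < b
function goes_lower_with_missing(x::AbstractVector, j::Int, b::Real, missing_lower::Bool)
    ismissing(x[j]) && return missing_lower
    return x[j] < b
end

x = [3.0, missing]                         # the second feature is missing
goes_lower_with_missing(x, 1, 5.0, false)  # true: 3.0 < 5.0
goes_lower_with_missing(x, 2, 5.0, false)  # false: missing, so sent to the upper child
goes_lower_with_missing(x, 2, 5.0, true)   # true: missing, so sent to the lower child
```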
This information can be queried using:
Leaf Nodes
Each leaf node in the tree contains one or more prediction rules that specify how it generates predictions for new points that fall into the leaf. The prediction rules depend on the type of leaf node.
Classification leaves
A classification leaf predicts a label for points that fall into this leaf. It also predicts the probability that a point falling into this leaf takes each of the possible labels.
This information can be queried using:
If the leaf has a logistic regression fit, you can also query the weights and the constant using:
Regression leaves
A regression leaf predicts a continuous outcome for each point falling into the leaf using a linear regression equation. It specifies a set of features $\mathcal{J}$, with corresponding weights $\beta_j$ for each feature $j$ in $\mathcal{J}$, as well as a constant value $\beta_0$. Categoric features can be included by specifying individual weights for any or all of the categoric levels. The prediction for a point $\mathbf{x}$ is given by:
\[\sum_{j \in \mathcal{J}} \beta_j x_j + \beta_0\]
Note that the regression weights are often all zero, in which case the leaf simply makes a single constant prediction for all points.
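Here is a sketch of the regression prediction rule. It uses a hypothetical representation of the weights as maps from feature (and, for categoric features, from level) to weight. The second call shows the common case where no weights are present, so the leaf returns the constant $\beta_0$ for every point.

```julia
# Prediction from a regression leaf: beta_0 plus beta_j * x_j for each numeric
# feature, plus the weight of the level taken by each categoric feature
function regression_leaf_predict(x::Dict{Symbol,Any},
                                 beta0::Float64,
                                 numeric_weights::Dict{Symbol,Float64},
                                 categoric_weights::Dict{Symbol,Dict{String,Float64}})
    prediction = beta0
    for (feature, beta) in numeric_weights
        prediction += beta * x[feature]
    end
    for (feature, level_weights) in categoric_weights
        prediction += get(level_weights, x[feature], 0.0)
    end
    return prediction
end

x = Dict{Symbol,Any}(:age => 30.0, :smoker => "yes")

# Linear model: 2.0 + 0.5 * 30.0 + 4.0 = 21.0
linear = regression_leaf_predict(x, 2.0, Dict(:age => 0.5),
                                 Dict(:smoker => Dict("yes" => 4.0)))

# All weights zero (here, absent): the leaf predicts the constant 2.0
constant = regression_leaf_predict(x, 2.0, Dict{Symbol,Float64}(),
                                   Dict{Symbol,Dict{String,Float64}}())
```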
This information can be queried using:
Survival leaves
A survival leaf predicts the survival probability over time for points that fall into the leaf, using a Kaplan-Meier curve. From this curve, it can also predict an expected survival time.
Each survival leaf also predicts a hazard ratio, giving an estimate of the risk in this leaf relative to the baseline risk (a hazard ratio above 1 indicates increased risk, and below 1 indicates decreased risk).
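The Kaplan-Meier curve is given by the standard product formula $S(t) = \prod_{t_i \le t} (1 - d_i / n_i)$. Here, $d_i$ is the number of events at event time $t_i$ and $n_i$ is the number of points still at risk at $t_i$. The sketch below is a generic implementation of this estimator for the training points in a leaf, not IAI's internal code. The function kaplan_meier and the example data are hypothetical.

```julia
# Kaplan-Meier estimate from the training points in a leaf. `times[i]` is the
# observed time for point i, and `events[i]` is true if the event occurred at
# that time (false if the point was censored). Returns the distinct event
# times and the estimated probability of surviving past each of them.
function kaplan_meier(times::Vector{Float64}, events::Vector{Bool})
    event_times = sort(unique(times[events]))
    survival = Float64[]
    s = 1.0
    for t in event_times
        at_risk = count(>=(t), times)  # n_i: points still under observation at t
        deaths = count(i -> events[i] && times[i] == t, eachindex(times))  # d_i
        s *= 1 - deaths / at_risk
        push!(survival, s)
    end
    return event_times, survival
end

# Four hypothetical points: events at times 1 and 3, censoring at times 2 and 3
times = [1.0, 2.0, 3.0, 3.0]
events = [true, false, true, false]
event_times, survival = kaplan_meier(times, events)
# event_times == [1.0, 3.0], survival == [0.75, 0.375]
```

A point censored at the same time as an event is counted as still at risk at that time, which is the usual convention.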
This information can be queried using:
If the leaf has a Cox regression fit, you can also query the weights and the constant using:
Prescription leaves
A prescription leaf ranks the possible treatments in order from most effective to least effective. When generating a new prescription, it prescribes the most effective treatment that is permissible for a given point. The leaf also contains a predictive regression model for each treatment in the leaf, of the same form as regression leaves. For points falling into the leaf, it predicts the outcome under a given treatment using the corresponding predictive model.
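Here is a sketch of the two pieces of a prescription leaf: choosing the highest-ranked permissible treatment, and predicting the outcome under a given treatment with that treatment's linear model. The ranking, the models, and the function names are hypothetical. For simplicity, the models use numeric features only.

```julia
# Hypothetical contents of a prescription leaf: a ranking of treatments from
# most to least effective, and a linear model (constant, weights) per treatment
ranking = ["B", "A", "C"]
models = Dict("A" => (1.0, [0.5, -1.0]),
              "B" => (3.0, [0.0, 2.0]),
              "C" => (0.0, [1.0, 1.0]))

# Prescribe the highest-ranked treatment that is permissible for this point
# (throws an error if no treatment is permissible)
prescribe(ranking::Vector{String}, permissible::AbstractSet) =
    first(t for t in ranking if t in permissible)

# Predicted outcome under `treatment`: constant plus weighted sum of features
predict_outcome(models::Dict, treatment::String, x::Vector{Float64}) =
    models[treatment][1] + sum(models[treatment][2] .* x)

x = [2.0, 1.0]
prescribe(ranking, Set(["A", "B", "C"]))  # "B"
prescribe(ranking, Set(["A", "C"]))       # "A"
predict_outcome(models, "B", x)           # 3.0 + 0.0 * 2.0 + 2.0 * 1.0 = 5.0
```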
This information can be queried using:
Policy leaves
A policy leaf ranks the possible treatments in order from most effective to least effective. Similar to a prescription leaf, when generating a new prescription, it prescribes the most effective treatment that is permissible for a given point.
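Here is a sketch of how a ranking could follow from per-treatment effectiveness values, and how the top permissible treatment is then chosen. The effectiveness values are hypothetical. The sketch assumes that larger values are better; if the outcome should be minimized, the sort would be ascending instead.

```julia
# Hypothetical effectiveness of each treatment in a policy leaf; this sketch
# assumes larger values are better (sort ascending if smaller is better)
effectiveness = Dict("A" => 0.4, "B" => 0.9, "C" => 0.1)

# Rank the treatments from most to least effective
ranking = sort(collect(keys(effectiveness)), by = t -> effectiveness[t], rev = true)

# Prescribe the most effective treatment that is permissible for this point
permissible = Set(["A", "C"])
choice = first(t for t in ranking if t in permissible)

ranking  # ["B", "A", "C"]
choice   # "A"
```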
The information on ranking among treatments can be queried using:
The effectiveness of each treatment, as used to determine the ranking, can be queried using:
Examples of querying tree structure
Re-implementing IAI.apply
The following example demonstrates how we could use the IAITrees API to reimplement the apply function:
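using DataFrames

# Return the index of the leaf that each row of `X` falls into
function my_apply(lnr::IAI.TreeLearner, X::DataFrame)
    [my_apply(lnr, X, i) for i = 1:nrow(X)]
end

# Send row `i` of `X` down the tree, starting from the root node (index 1)
function my_apply(lnr::IAI.TreeLearner, X::DataFrame, i::Int)
    t = 1
    while true
        if IAI.is_leaf(lnr, t)
            return t
        end

        if IAI.is_hyperplane_split(lnr, t)
            numeric_weights, categoric_weights = IAI.get_split_weights(lnr, t)
            split_value = 0.0
            value_missing = false
            # Each numeric feature adds `weight * x` to the weighted sum
            for (feature, weight) in numeric_weights
                x = X[i, feature]
                value_missing = value_missing || ismissing(x)
                split_value += weight * x
            end
            # Each categoric feature adds the weight of the level taken by the point
            for (feature, level_weights) in categoric_weights
                x = X[i, feature]
                value_missing = value_missing || ismissing(x)
                for (level, weight) in level_weights
                    # `isequal` returns `false` for a missing `x`, whereas `==`
                    # would return `missing` and cause an error in the `if`
                    if isequal(x, level)
                        split_value += weight
                    end
                end
            end
            if !value_missing
                threshold = IAI.get_split_threshold(lnr, t)
                goes_lower = split_value < threshold
            end
        else
            feature = IAI.get_split_feature(lnr, t)
            x = X[i, feature]
            value_missing = ismissing(x)
            # Only evaluate the split rule if the value is present, since
            # looking up `missing` in the split categories would fail
            if !value_missing
                if IAI.is_ordinal_split(lnr, t) ||
                   IAI.is_categoric_split(lnr, t) ||
                   IAI.is_mixed_ordinal_split(lnr, t)
                    # Maps each level to whether it follows the lower branch
                    categories = IAI.get_split_categories(lnr, t)
                    goes_lower = categories[x]
                elseif IAI.is_parallel_split(lnr, t)
                    threshold = IAI.get_split_threshold(lnr, t)
                    goes_lower = x < threshold
                elseif IAI.is_mixed_parallel_split(lnr, t)
                    threshold = IAI.get_split_threshold(lnr, t)
                    categories = IAI.get_split_categories(lnr, t)
                    # Numeric values use the threshold, categoric values the levels
                    goes_lower = isa(x, Real) ? x < threshold : categories[x]
                end
            end
        end

        # If any value needed by the split rule is missing, use the missing-data rule
        if value_missing
            goes_lower = IAI.missing_goes_lower(lnr, t)
        end
        t = goes_lower ? IAI.get_lower_child(lnr, t) : IAI.get_upper_child(lnr, t)
    end
end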