Decision Tree Structure

This page contains a guide to the structure of trees created by IAI algorithms, and defines a number of terms that we use when referring to the tree structure.

A tree is composed of nodes, each of which is either a split or a leaf:

  • A split node has two children, which we refer to as the lower child and the upper child, and contains a split rule that decides whether a given point goes to the lower or upper branch.
  • A leaf node contains a prediction rule for making predictions for a given point; the type of prediction depends on the problem task (classification, regression, etc.).

The tree makes predictions by sending each point down the tree, starting from the root node. At a split node, the point is sent to the lower or upper child based on the split rule in the node. When it reaches a leaf node, the prediction rule is applied to generate a prediction for this point.
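To make the traversal concrete, here is a minimal Julia sketch of sending a point down a tree. The node types TreeNode, SplitNode and LeafNode and the function predict_point are hypothetical illustrations, not the IAI types; each split stores its rule as a function returning true when the point should go to the lower child.

```julia
# Hypothetical node types for illustration only; these are not the IAI types.
abstract type TreeNode end

struct LeafNode <: TreeNode
    prediction::Float64        # a constant prediction rule
end

struct SplitNode <: TreeNode
    goes_lower::Function       # split rule: true sends the point to the lower child
    lower::Int                 # index of the lower child
    upper::Int                 # index of the upper child
end

# Send point `x` down the tree from the root (node 1) until it reaches a leaf
function predict_point(nodes::Vector{TreeNode}, x)
    t = 1
    while nodes[t] isa SplitNode
        node = nodes[t]::SplitNode
        t = node.goes_lower(x) ? node.lower : node.upper
    end
    return (nodes[t]::LeafNode).prediction
end

# A depth-1 tree: node 1 splits on x[1] < 0.5, nodes 2 and 3 are leaves
nodes = TreeNode[
    SplitNode(x -> x[1] < 0.5, 2, 3),
    LeafNode(10.0),
    LeafNode(20.0)
]
predict_point(nodes, [0.2])  # 10.0 (lower child)
predict_point(nodes, [0.9])  # 20.0 (upper child)
```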

The number of nodes in the tree can be found with get_num_nodes.

The following information is available for all types of nodes:

  • is_leaf checks whether a node is a leaf
  • get_depth returns the depth of the node in the tree
  • get_parent returns the index of this node's parent in the tree
  • get_num_samples returns the number of training samples that reached this node
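The depth of a node and whether it is a leaf both follow from the parent relationships. The Julia sketch below assumes a hypothetical vector parents, where parents[t] is the index of node t's parent and 0 marks the root; this convention and the helper names node_depth and is_leaf_node are illustrative assumptions, not the IAI representation.

```julia
# Hypothetical representation: parents[t] is the index of node t's parent,
# with 0 used for the root (an assumed convention, not necessarily IAI's).
function node_depth(parents::Vector{Int}, t::Int)
    depth = 0
    while parents[t] != 0
        t = parents[t]      # step up to the parent
        depth += 1
    end
    return depth
end

# A node is a leaf if no other node lists it as its parent
is_leaf_node(parents::Vector{Int}, t::Int) = !(t in parents)

# Node 1 is the root; nodes 2 and 3 are its children; nodes 4 and 5 are children of node 3
parents = [0, 1, 1, 3, 3]
node_depth(parents, 4)    # 2
is_leaf_node(parents, 2)  # true
is_leaf_node(parents, 3)  # false
```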

Split Nodes

At a split node, you can retrieve the lower and upper children using get_lower_child and get_upper_child. In visualizations, the lower child is displayed as either the left or top child, while the upper child is displayed as the right or bottom child. All nodes are also labeled with their index in the visualization for easy reference.

There are a number of different types of split rules that can be applied by a split node. You can use the following functions to determine the type of split:

Parallel splits

A parallel split is the split typically used by other decision tree methods like CART. It specifies a single numeric feature $j$ in the data, and a threshold value $b$. A point $\mathbf{x}$ follows the lower branch if its value in this feature is lower than the threshold:

\[x_j < b\]
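As a small illustration, a parallel split rule reduces to a single comparison. The function parallel_goes_lower below is a hypothetical helper that indexes the feature by position; note that a value exactly equal to the threshold is sent to the upper branch, since the rule uses a strict inequality.

```julia
# Parallel split on feature index `j` with threshold `b`: lower branch iff x[j] < b
parallel_goes_lower(x::AbstractVector{<:Real}, j::Int, b::Real) = x[j] < b

x = [1.5, 3.0]
parallel_goes_lower(x, 1, 2.5)  # true:  1.5 < 2.5
parallel_goes_lower(x, 2, 2.5)  # false: 3.0 is not below 2.5
parallel_goes_lower(x, 2, 3.0)  # false: a value equal to the threshold goes upper
```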

This information can be queried using:

Hyperplane splits

A hyperplane split applies a split that involves many features of the data at once. It specifies a set of features $\mathcal{J}$, with corresponding weights $a_j$ for each feature $j$ in $\mathcal{J}$, as well as a threshold value $b$. Categoric features can be included by specifying individual weights for any or all of the categoric levels. A point $\mathbf{x}$ follows the lower branch if the weighted sum of these features is lower than the threshold:

\[\sum_{j \in \mathcal{J}} a_j x_j < b\]
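The weighted sum can be sketched in Julia as follows, assuming a hypothetical encoding where numeric weights are keyed by feature name and categoric weights are keyed by feature name and then by level. A categoric feature effectively contributes the weight of its observed level, with levels that have no weight contributing zero.

```julia
# Hypothetical encoding: numeric weights a_j keyed by feature name; for a categoric
# feature, a weight per level (levels without a weight contribute 0).
function hyperplane_goes_lower(x::Dict{Symbol,Any},
                               numeric_weights::Dict{Symbol,Float64},
                               categoric_weights::Dict{Symbol,Dict{String,Float64}},
                               b::Float64)
    total = 0.0
    for (feature, a) in numeric_weights
        total += a * x[feature]
    end
    for (feature, level_weights) in categoric_weights
        total += get(level_weights, x[feature], 0.0)
    end
    return total < b
end

x = Dict{Symbol,Any}(:age => 40.0, :bmi => 24.0, :smoker => "yes")
numeric = Dict(:age => 0.5, :bmi => 0.25)
categoric = Dict(:smoker => Dict("yes" => 3.0))
hyperplane_goes_lower(x, numeric, categoric, 30.0)  # true:  20 + 6 + 3 = 29 < 30
hyperplane_goes_lower(x, numeric, categoric, 29.0)  # false: 29 is not below 29
```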

This information can be queried using:

Categoric splits

A categoric split applies a split to a single categoric feature $j$ in the data, and a subset $\mathcal{L}$ of the categoric levels of this categoric feature. A point $\mathbf{x}$ follows the lower branch if its value in this feature belongs to the subset:

\[x_j \in \mathcal{L}\]
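A categoric split is a set-membership test. In this hypothetical Julia sketch the levels are represented as strings and the subset as a Set; unlike a parallel split, no ordering of the levels is involved.

```julia
# Categoric split: lower branch iff the level x_j belongs to the subset L
categoric_goes_lower(xj::AbstractString, L::Set{String}) = xj in L

L = Set(["red", "green"])
categoric_goes_lower("green", L)  # true
categoric_goes_lower("blue", L)   # false
```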

This information can be queried using:

Ordinal splits

An ordinal split is similar to a categoric split in that it applies a split to a single ordinal feature $j$ in the data and a subset $\mathcal{L}$ of the levels of this feature. The key difference from a categoric split is that all levels in the subset $\mathcal{L}$ are less than all levels not in $\mathcal{L}$, meaning the split respects the ordering of the ordinal feature:

\[x_j \in \mathcal{L}\]
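The ordering restriction means the subset must be a prefix of the levels sorted in increasing order. The hypothetical Julia helper respects_ordering below checks this condition for levels represented as strings.

```julia
# `levels` lists the ordinal levels in increasing order. L respects the ordering
# iff every level in L precedes every level not in L, i.e. L is a prefix of `levels`.
function respects_ordering(levels::Vector{String}, L::Set{String})
    k = length(L)
    return k <= length(levels) && Set(levels[1:k]) == L
end

levels = ["low", "medium", "high"]
respects_ordering(levels, Set(["low", "medium"]))  # true
respects_ordering(levels, Set(["low", "high"]))    # false: "medium" lies between them
```

Because a valid subset is always a prefix, an ordinal split can equivalently be read as a threshold on the rank of the level.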

This information can be queried using:

Mixed splits

A mixed split is a split rule that applies to a feature containing mixed data. Depending on the feature, the rule combines a parallel split with a categoric split, or an ordinal split with a categoric split.

Mixed parallel splits

A mixed parallel split applies a split for a single mixed numeric/categoric feature $j$ in the data. It selects a threshold $b$ and a subset of the categoric levels $\mathcal{L}$. A point $\mathbf{x}$ follows the lower branch if its value in this feature is numeric and is lower than the threshold, or if the value is categoric and belongs to the subset:

\[\begin{cases} x_j < b, & \text{if } x_j \text{ is numeric}\\ x_j \in \mathcal{L}, & \text{otherwise} \end{cases}\]
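The two cases map naturally onto Julia's multiple dispatch. In this hypothetical sketch, numeric values are any Real and categoric levels are represented as strings; the function name mixed_parallel_goes_lower is illustrative.

```julia
# Mixed parallel split: numeric values are compared with the threshold b,
# categoric levels (represented here as strings) are checked against L.
mixed_parallel_goes_lower(xj::Real, b::Real, L::Set{String}) = xj < b
mixed_parallel_goes_lower(xj::AbstractString, b::Real, L::Set{String}) = xj in L

L = Set(["not measured"])
mixed_parallel_goes_lower(3.2, 5.0, L)             # true: numeric and below 5.0
mixed_parallel_goes_lower(7.0, 5.0, L)             # false
mixed_parallel_goes_lower("not measured", 5.0, L)  # true: level in L
mixed_parallel_goes_lower("refused", 5.0, L)       # false
```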

This information can be queried using:

Mixed ordinal splits

A mixed ordinal split applies a split for a single mixed ordinal/categoric feature $j$ in the data. It specifies a subset of the categoric and ordinal levels $\mathcal{L}$, with the restriction that all ordinal levels in the subset must be less than all ordinal levels not in the subset. A point $\mathbf{x}$ follows the lower branch if the value in this feature belongs to the subset:

\[x_j \in \mathcal{L}\]
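Only the ordinal levels are constrained here: the categoric levels may be included in or excluded from the subset freely. The hypothetical Julia helper below checks that the ordinal part of a subset forms a prefix of the ordered ordinal levels, treating any level not listed in ordinal_levels as categoric.

```julia
# `ordinal_levels` lists the ordinal levels in increasing order; any other level
# is categoric. Only the ordinal levels in L are constrained to form a prefix.
function valid_mixed_ordinal_subset(ordinal_levels::Vector{String}, L::Set{String})
    in_L = [lvl in L for lvl in ordinal_levels]
    # position of the first ordinal level excluded from L (or one past the end)
    k = something(findfirst(!, in_L), length(in_L) + 1)
    # no larger ordinal level may be included after that point
    return !any(in_L[k:end])
end

ordinal_levels = ["low", "medium", "high"]
valid_mixed_ordinal_subset(ordinal_levels, Set(["low", "unknown"]))     # true
valid_mixed_ordinal_subset(ordinal_levels, Set(["medium", "unknown"]))  # false
```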

This information can be queried using:

Missing data

Regardless of the type of split rule, each split node also contains a separate rule for missing data. If any of the feature values needed to evaluate the split rule for a given point is missing, this rule determines whether the point is sent to the lower or upper child.
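The fallback can be sketched in Julia as a wrapper around any split rule. Here values is a hypothetical vector of every feature value the rule needs, so a hyperplane split falls back to the missing-data rule if any one of its features is missing; the function and argument names are illustrative assumptions.

```julia
# `values` holds every feature value the split rule needs; `rule` evaluates the
# split on observed values; `missing_lower` is the node's rule for missing data.
function goes_lower_with_missing(values::AbstractVector, rule::Function, missing_lower::Bool)
    any(ismissing, values) && return missing_lower
    return rule(values)
end

rule = v -> 2v[1] + v[2] < 10          # a two-feature hyperplane rule
goes_lower_with_missing([1.0, 2.0], rule, false)      # true: 4.0 < 10
goes_lower_with_missing([1.0, missing], rule, false)  # false: uses the missing-data rule
```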

This information can be queried using:

Leaf Nodes

Each leaf node in the tree contains one or more prediction rules that specify how it generates predictions for new points that fall into the leaf. The prediction rules depend on the type of leaf node.

Classification leaves

A classification leaf predicts a label for points that fall into this leaf. It also predicts the probability that a point falling into this leaf takes each of the possible labels.
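One simple way a leaf could produce these predictions is sketched below in Julia. It assumes, for illustration only, that the probabilities are the label frequencies among the training points in the leaf and that the predicted label is the most probable one; IAI's exact estimates may differ.

```julia
# Assumes the leaf's probabilities are the label frequencies among its training points
function leaf_probabilities(labels::Vector{String})
    counts = Dict{String,Int}()
    for y in labels
        counts[y] = get(counts, y, 0) + 1
    end
    n = length(labels)
    return Dict(y => c / n for (y, c) in counts)
end

# Predicted label: the key with the highest probability
leaf_label(probs::Dict{String,Float64}) = argmax(probs)

probs = leaf_probabilities(["yes", "yes", "no"])
probs["yes"]       # 2/3
leaf_label(probs)  # "yes"
```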

This information can be queried using:

If the leaf has a logistic regression fit, you can also query the weights and the constant using:

Regression leaves

A regression leaf predicts a continuous outcome for each point falling into the leaf using a linear regression equation. It specifies a set of features $\mathcal{J}$, with corresponding weights $\beta_j$ for each feature $j$ in $\mathcal{J}$, as well as a constant value $\beta_0$. Categoric features can be included by specifying individual weights for any or all of the categoric levels. The prediction for a point $\mathbf{x}$ is given by:

\[\sum_{j \in \mathcal{J}} \beta_j x_j + \beta_0\]

Note that the regression weights are often all zero, in which case the leaf simply makes a single constant prediction for all points.
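The prediction equation can be sketched in Julia for numeric features, using a hypothetical representation with weights keyed by feature name plus the constant $\beta_0$. With no nonzero weights, the prediction is simply the constant.

```julia
# Hypothetical representation: numeric weights β_j keyed by feature name, plus β0.
function regression_prediction(x::Dict{Symbol,Float64},
                               weights::Dict{Symbol,Float64},
                               constant::Float64)
    pred = constant
    for (feature, β) in weights
        pred += β * x[feature]
    end
    return pred
end

x = Dict(:age => 30.0, :income => 50.0)
regression_prediction(x, Dict(:age => 0.5), 10.0)          # 10 + 0.5 * 30 = 25.0
regression_prediction(x, Dict{Symbol,Float64}(), 10.0)     # no weights: constant 10.0
```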

This information can be queried using:

Survival leaves

A survival leaf predicts the survival probability over time for points that fall into the leaf, using a Kaplan-Meier curve. From this curve, it can also predict an expected survival time.

Each survival leaf also predicts a hazard ratio, giving an estimate of the risk of this leaf relative to the baseline risk (a hazard ratio above 1 indicates increased risk, and below 1 indicates decreased risk).
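For reference, the standard Kaplan-Meier estimator multiplies $(1 - d_i / n_i)$ over the event times, where $d_i$ is the number of events and $n_i$ the number of points still at risk at time $t_i$. The Julia sketch below computes this curve, then computes an expected time as the area under the step curve up to the last event time (a restricted mean). This choice of expected time is an assumption for illustration; IAI may compute its expected survival time differently.

```julia
# times: observed times; events: true if the event occurred, false if censored.
# Returns the distinct event times and the survival probability just after each.
function kaplan_meier(times::Vector{Float64}, events::Vector{Bool})
    event_times = sort(unique(times[events]))
    surv = Float64[]
    s = 1.0
    for t in event_times
        at_risk = count(>=(t), times)                                      # n_i
        deaths = count(i -> times[i] == t && events[i], eachindex(times))  # d_i
        s *= 1 - deaths / at_risk
        push!(surv, s)
    end
    return event_times, surv
end

# Area under the step curve from 0 up to the last event time (restricted mean)
function restricted_mean(event_times::Vector{Float64}, surv::Vector{Float64})
    area, prev_t, prev_s = 0.0, 0.0, 1.0
    for (t, s) in zip(event_times, surv)
        area += prev_s * (t - prev_t)
        prev_t, prev_s = t, s
    end
    return area
end

times = [1.0, 2.0, 2.0, 3.0, 4.0]
events = [true, true, false, true, false]
ts, S = kaplan_meier(times, events)  # ts = [1.0, 2.0, 3.0], S ≈ [0.8, 0.6, 0.3]
restricted_mean(ts, S)               # ≈ 2.4
```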

This information can be queried using:

If the leaf has a Cox regression fit, you can also query the weights and the constant using:

Prescription leaves

A prescription leaf ranks the possible treatments in order from most effective to least effective. When generating a new prescription, it prescribes the most effective treatment that is permissible for a given point. The leaf also contains a predictive regression model for each treatment in the leaf, of the same form as regression leaves. For points falling into the leaf, it predicts the outcome under a given treatment using the corresponding predictive model.
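The two parts of a prescription leaf can be sketched separately in Julia: prescribing the highest-ranked permissible treatment, and predicting the outcome under a chosen treatment with that treatment's linear model. The TreatmentModel type, the functions prescribe and predict_outcome, and the treatment names are hypothetical illustrations.

```julia
# Hypothetical per-treatment linear model: constant plus numeric weights.
struct TreatmentModel
    constant::Float64
    weights::Dict{Symbol,Float64}
end

function predict_outcome(m::TreatmentModel, x::Dict{Symbol,Float64})
    y = m.constant
    for (feature, w) in m.weights
        y += w * x[feature]
    end
    return y
end

# `ranking` lists treatments from most to least effective; prescribe the
# first one that is permissible for this point (or `nothing` if none is).
function prescribe(ranking::Vector{String}, permissible::Set{String})
    i = findfirst(t -> t in permissible, ranking)
    return i === nothing ? nothing : ranking[i]
end

models = Dict("A" => TreatmentModel(5.0, Dict(:dose => 1.0)),
              "B" => TreatmentModel(8.0, Dict{Symbol,Float64}()))
x = Dict(:dose => 2.0)
prescribe(["B", "A"], Set(["A"]))  # "A": B is ranked higher but not permissible
predict_outcome(models["A"], x)    # 5.0 + 1.0 * 2.0 = 7.0
```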

This information can be queried using:

Policy leaves

A policy leaf ranks the possible treatments in order from most effective to least effective. Similar to a prescription leaf, when generating a new prescription, it prescribes the most effective treatment that is permissible for a given point.
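A minimal Julia sketch of this behavior, assuming the ranking is obtained by sorting treatments by a per-treatment effectiveness value where larger means more effective (an assumption; the direction depends on the problem). The names rank_treatments and prescribe_policy are hypothetical.

```julia
# Hypothetical: effectiveness of each treatment in the leaf, where a larger
# value is assumed to mean more effective.
function rank_treatments(effectiveness::Dict{String,Float64})
    return sort(collect(keys(effectiveness)); by = t -> effectiveness[t], rev = true)
end

# Prescribe the most effective treatment that is permissible (or `nothing`)
function prescribe_policy(effectiveness::Dict{String,Float64}, permissible::Set{String})
    for t in rank_treatments(effectiveness)
        t in permissible && return t
    end
    return nothing
end

eff = Dict("A" => 0.2, "B" => 0.7, "C" => 0.5)
rank_treatments(eff)                    # ["B", "C", "A"]
prescribe_policy(eff, Set(["A", "C"]))  # "C"
```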

The information on ranking among treatments can be queried using:

The effectiveness of each treatment, as used to determine the ranking, can be queried using:

Examples of querying tree structure

Re-implementing IAI.apply

The following example demonstrates how we could use the IAITrees API to re-implement IAI.apply, which returns the index of the leaf that each point falls into:

using DataFrames

# Return the index of the leaf that each row of X falls into
function my_apply(lnr::IAI.TreeLearner, X::DataFrame)
  [my_apply(lnr, X, i) for i = 1:nrow(X)]
end

# Return the index of the leaf that row i of X falls into
function my_apply(lnr::IAI.TreeLearner, X::DataFrame, i::Int)
  t = 1  # start at the root node
  while true
    if IAI.is_leaf(lnr, t)
      return t
    end

    if IAI.is_hyperplane_split(lnr, t)
      numeric_weights, categoric_weights = IAI.get_split_weights(lnr, t)

      split_value = 0.0
      value_missing = false

      # Numeric features contribute weight * value
      for (feature, weight) in numeric_weights
        x = X[i, feature]
        if ismissing(x)
          value_missing = true
        else
          split_value += weight * x
        end
      end
      # Categoric features contribute the weight of the observed level
      for (feature, level_weights) in categoric_weights
        x = X[i, feature]
        if ismissing(x)
          value_missing = true
          continue
        end
        for (level, weight) in level_weights
          if x == level
            split_value += weight
          end
        end
      end

      if value_missing
        # Apply the node's separate rule for missing data
        goes_lower = IAI.missing_goes_lower(lnr, t)
      else
        threshold = IAI.get_split_threshold(lnr, t)
        goes_lower = split_value < threshold
      end

    else
      # All other split types use a single feature
      feature = IAI.get_split_feature(lnr, t)
      x = X[i, feature]

      if ismissing(x)
        # Apply the node's separate rule for missing data
        goes_lower = IAI.missing_goes_lower(lnr, t)

      elseif IAI.is_ordinal_split(lnr, t) ||
             IAI.is_categoric_split(lnr, t) ||
             IAI.is_mixed_ordinal_split(lnr, t)
        # categories maps each level to whether it is sent to the lower child
        categories = IAI.get_split_categories(lnr, t)
        goes_lower = categories[x]

      elseif IAI.is_parallel_split(lnr, t)
        threshold = IAI.get_split_threshold(lnr, t)
        goes_lower = x < threshold

      elseif IAI.is_mixed_parallel_split(lnr, t)
        # Numeric values use the threshold, categoric levels use the subset
        threshold = IAI.get_split_threshold(lnr, t)
        categories = IAI.get_split_categories(lnr, t)
        goes_lower = isa(x, Real) ? x < threshold : categories[x]

      else
        error("Unsupported split type at node $t")
      end
    end

    # Move to the chosen child and continue down the tree
    t = goes_lower ? IAI.get_lower_child(lnr, t) : IAI.get_upper_child(lnr, t)
  end
end