API Reference
Documentation for the IAITrees
public interface.
Index
IAI.Questionnaire
IAI.ClassificationTreeLearner
IAI.ClassificationTreeMultiLearner
IAI.MultiTreePlot
IAI.PolicyTreeLearner
IAI.PrescriptionTreeLearner
IAI.RegressionTreeLearner
IAI.RegressionTreeMultiLearner
IAI.SimilarityComparison
IAI.StabilityAnalysis
IAI.SurvivalTreeLearner
IAI.TreeLearner
IAI.TreeMultiLearner
IAI.TreePlot
IAI.show_in_browser
IAI.show_questionnaire
IAI.variable_importance
IAI.write_html
IAI.write_questionnaire
IAI.apply
IAI.apply_nodes
IAI.compare_group_outcomes
IAI.decision_path
IAI.get_classification_label
IAI.get_classification_proba
IAI.get_cluster_assignments
IAI.get_cluster_details
IAI.get_cluster_distances
IAI.get_depth
IAI.get_lower_child
IAI.get_num_nodes
IAI.get_num_samples
IAI.get_parent
IAI.get_policy_treatment_outcome
IAI.get_policy_treatment_outcome_standard_error
IAI.get_policy_treatment_rank
IAI.get_prescription_treatment_rank
IAI.get_regression_constant
IAI.get_regression_weights
IAI.get_split_categories
IAI.get_split_feature
IAI.get_split_threshold
IAI.get_split_weights
IAI.get_stability_results
IAI.get_survival_curve
IAI.get_survival_expected_time
IAI.get_survival_hazard
IAI.get_train_errors
IAI.get_tree
IAI.get_upper_child
IAI.is_categoric_split
IAI.is_hyperplane_split
IAI.is_leaf
IAI.is_mixed_ordinal_split
IAI.is_mixed_parallel_split
IAI.is_ordinal_split
IAI.is_parallel_split
IAI.missing_goes_lower
IAI.print_path
IAI.reset_display_label!
IAI.set_display_label!
IAI.set_threshold!
IAI.variable_importance_similarity
IAI.write_dot
IAI.write_pdf
IAI.write_png
IAI.write_svg
Types
IAI.TreeLearner
— Type
Abstract type encompassing all tree-based learners.
IAI.ClassificationTreeLearner
— Type
Abstract type encompassing all tree-based learners with classification leaves.
IAI.RegressionTreeLearner
— Type
Abstract type encompassing all tree-based learners with regression leaves.
IAI.SurvivalTreeLearner
— Type
Abstract type encompassing all tree-based learners with survival leaves.
IAI.PrescriptionTreeLearner
— Type
Abstract type encompassing all tree-based learners with prescription leaves.
IAI.PolicyTreeLearner
— Type
Abstract type encompassing all tree-based learners with policy leaves.
IAI.TreeMultiLearner
— Type
Abstract type encompassing all multi-task tree-based learners.
IAI.ClassificationTreeMultiLearner
— Type
Abstract type encompassing all multi-task tree-based learners with classification leaves.
IAI.RegressionTreeMultiLearner
— Type
Abstract type encompassing all multi-task tree-based learners with regression leaves.
Tree Structure
These functions can be used to query the structure of a TreeLearner
. The examples make use of the following tree:
IAI.get_num_nodes
— Functionget_num_nodes(lnr::TreeLearner)
Return the number of nodes in the trained lnr
.
Example
IAI.get_num_nodes(lnr)
7
IAI.is_leaf
— Functionis_leaf(lnr::TreeLearner, node_index::Int)
Return true
if node node_index
in the trained lnr
is a leaf.
Example
IAI.is_leaf(lnr, 1)
false
IAI.get_depth
— Functionget_depth(lnr::TreeLearner, node_index::Int)
Return the depth of node node_index
in the trained lnr
.
Example
IAI.get_depth(lnr, 6)
2
IAI.get_num_samples
— Functionget_num_samples(lnr::TreeLearner, node_index::Int)
Return the number of training points contained in node node_index
in the trained lnr
.
Example
IAI.get_num_samples(lnr, 6)
78
IAI.get_parent
— Functionget_parent(lnr::TreeLearner, node_index::Int)
Return the index of the parent of node node_index
in the trained lnr
.
Example
IAI.get_parent(lnr, 2)
1
IAI.get_lower_child
— Functionget_lower_child(lnr::TreeLearner, node_index::Int)
Return the index of the lower child of node node_index
in the trained lnr
.
Example
IAI.get_lower_child(lnr, 1)
2
IAI.get_upper_child
— Functionget_upper_child(lnr::TreeLearner, node_index::Int)
Return the index of the upper child of node node_index
in the trained lnr
.
Example
IAI.get_upper_child(lnr, 1)
5
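The structure queries above can be combined to walk a tree. The following is a minimal sketch, not part of the IAITrees API: collect_leaves is a hypothetical helper that takes the accessors as plain functions so that it runs without a trained learner. The toy child tables stand in for lnr, and the sketch assumes the root is node 1.

```julia
# Depth-first traversal that collects leaf indices, lower branch first.
# `is_leaf`, `lower_child` and `upper_child` are callables taking a node index.
function collect_leaves(is_leaf, lower_child, upper_child, node::Int=1)
    is_leaf(node) && return [node]
    return vcat(
        collect_leaves(is_leaf, lower_child, upper_child, lower_child(node)),
        collect_leaves(is_leaf, lower_child, upper_child, upper_child(node)),
    )
end

# Hypothetical 7-node tree: node 1 splits into 2 and 5, node 2 into 3 and 4,
# and node 5 into 6 and 7. Nodes without an entry are leaves.
toy_lower = Dict(1 => 2, 2 => 3, 5 => 6)
toy_upper = Dict(1 => 5, 2 => 4, 5 => 7)

toy_leaves = collect_leaves(
    t -> !haskey(toy_lower, t),
    t -> toy_lower[t],
    t -> toy_upper[t],
)
println(toy_leaves)
```

With a trained learner, the same helper could presumably be called with closures such as t -> IAI.is_leaf(lnr, t), t -> IAI.get_lower_child(lnr, t) and t -> IAI.get_upper_child(lnr, t).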
IAI.is_parallel_split
— Functionis_parallel_split(lnr::TreeLearner, node_index::Int)
Return true
if node node_index
in the trained lnr
is a parallel split.
Example
IAI.is_parallel_split(lnr, 1)
true
IAI.is_hyperplane_split
— Functionis_hyperplane_split(lnr::TreeLearner, node_index::Int)
Return true
if node node_index
in the trained lnr
is a hyperplane split.
Example
IAI.is_hyperplane_split(lnr, 2)
true
IAI.is_categoric_split
— Functionis_categoric_split(lnr::TreeLearner, node_index::Int)
Return true
if node node_index
in the trained lnr
is a categoric split.
Example
IAI.is_categoric_split(lnr, 5)
true
IAI.is_ordinal_split
— Functionis_ordinal_split(lnr::TreeLearner, node_index::Int)
Return true
if node node_index
in the trained lnr
is an ordinal split.
Example
IAI.is_ordinal_split(lnr, 1)
false
IAI.is_mixed_parallel_split
— Functionis_mixed_parallel_split(lnr::TreeLearner, node_index::Int)
Return true
if node node_index
in the trained lnr
is a mixed categoric/parallel split.
Example
IAI.is_mixed_parallel_split(lnr, 2)
false
IAI.is_mixed_ordinal_split
— Functionis_mixed_ordinal_split(lnr::TreeLearner, node_index::Int)
Return true
if node node_index
in the trained lnr
is a mixed categoric/ordinal split.
Example
IAI.is_mixed_ordinal_split(lnr, 5)
false
IAI.missing_goes_lower
— Functionmissing_goes_lower(lnr::TreeLearner, node_index::Int)
Return true
if missing
values take the lower branch at node node_index
in the trained lnr
.
Applies to non-leaf nodes.
Example
IAI.missing_goes_lower(lnr, 1)
false
IAI.get_split_feature
— Functionget_split_feature(lnr::TreeLearner, node_index::Int)
Return the feature used in the split at node node_index
in the trained lnr
.
Applies to categoric, ordinal, parallel, categoric/ordinal, and categoric/parallel splits.
Example
IAI.get_split_feature(lnr, 1)
:score1
IAI.get_split_threshold
— Functionget_split_threshold(lnr::TreeLearner, node_index::Int)
Return the threshold used in the split at node node_index
in the trained lnr
.
Applies to hyperplane, parallel, and categoric/parallel splits.
Example
IAI.get_split_threshold(lnr, 1)
60.04421
IAI.get_split_categories
— Functionget_split_categories(lnr::TreeLearner, node_index::Int)
Return a Dict
containing the categoric/ordinal information used in the split at node node_index
in the trained lnr
, where the keys are the levels used in the split and the values are true
if that level follows the lower branch and false
if that level follows the upper branch.
Applies to categoric, ordinal, categoric/ordinal, and categoric/parallel splits.
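As a rough sketch of how the returned Dict might be consumed, the levels can be partitioned by branch. The Dict below is a hypothetical stand-in in the documented format (true means the level follows the lower branch), and the variable names are illustrative only.

```julia
# Hypothetical stand-in for the Dict returned by get_split_categories
categories = Dict{Any,Bool}(
    "A" => true, "B" => true, "C" => false, "D" => false, "E" => false,
)

# Levels that follow the lower branch (value is true)
lower_levels = sort([level for (level, goes_lower) in categories if goes_lower])
# Levels that follow the upper branch (value is false)
upper_levels = sort([level for (level, goes_lower) in categories if !goes_lower])

println("lower: ", lower_levels)
println("upper: ", upper_levels)
```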
Example
IAI.get_split_categories(lnr, 5)
Dict{Any, Bool} with 5 entries:
"B" => 1
"A" => 1
"C" => 0
"D" => 0
"E" => 0
IAI.get_split_weights
— Functionget_split_weights(lnr::TreeLearner, node_index::Int)
Return two Dict
s containing the weights for numeric and categoric features, respectively, used in the hyperplane split at node node_index
in the trained lnr
.
The numeric Dict
has key-value pairs of feature names and their corresponding weights in the hyperplane split.
The categoric Dict
has key-value pairs of feature names and a corresponding Dict
that maps the categoric levels for that feature to their weights in the hyperplane.
Any feature not included in either Dict
has zero weight in the hyperplane, and similarly, any categoric levels that are not included have zero weight.
Applies to hyperplane splits.
Example
numeric_weights, categoric_weights = IAI.get_split_weights(lnr, 2)
numeric_weights
Dict{Symbol, Float64} with 2 entries:
:score3 => 1.20415
:score2 => 0.0189015
categoric_weights
Dict{Symbol, Dict{Any, Float64}} with 1 entry:
:region => Dict("E"=>1.47922)
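To illustrate how the two Dicts describe a hyperplane, the sketch below evaluates the weighted sum for a single sample and compares it with a threshold. The helper hyperplane_value, the sample values and the threshold are hypothetical. The rule that a value below the threshold follows the lower branch is an assumption, inferred from the print_path output in this reference rather than stated by the library.

```julia
# Sum the numeric weights times the feature values, plus the weight of the
# observed level for each categoric feature (zero if the level is not listed).
function hyperplane_value(numeric_weights, categoric_weights, sample)
    total = 0.0
    for (feature, weight) in numeric_weights
        total += weight * sample[feature]
    end
    for (feature, level_weights) in categoric_weights
        total += get(level_weights, sample[feature], 0.0)
    end
    return total
end

# Weights in the format returned by get_split_weights
numeric_weights = Dict(:score3 => 1.20415, :score2 => 0.0189015)
categoric_weights = Dict(:region => Dict{Any,Float64}("E" => 1.47922))

# Hypothetical sample and threshold, for illustration only
sample = (score2 = 50.0, score3 = 1.0, region = "E")
threshold = 2.346

value = hyperplane_value(numeric_weights, categoric_weights, sample)
goes_lower = value < threshold  # assumed convention, see above
println((value, goes_lower))
```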
Classification
These functions can be used to query the structure of a ClassificationTreeLearner
. The examples make use of the following tree:
IAI.get_classification_label
— Methodget_classification_label(lnr::ClassificationTreeLearner, node_index::Int;
check_leaf::Bool=true)
Return the predicted label at node node_index
in the trained lnr
.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Example
IAI.get_classification_label(lnr, 2)
"setosa"
IAI.get_classification_proba
— Methodget_classification_proba(lnr::ClassificationTreeLearner, node_index::Int;
check_leaf::Bool=true)
Return the predicted probabilities of class membership at node node_index
in the trained lnr
.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Example
IAI.get_classification_proba(lnr, 4)
Dict{String, Float64} with 3 entries:
"virginica" => 0.0925926
"setosa" => 0.0
"versicolor" => 0.907407
IAI.get_regression_constant
— Methodget_regression_constant(lnr::ClassificationTreeLearner, node_index::Int;
check_leaf::Bool=true)
Return the constant term in the logistic regression prediction at node node_index
in the trained lnr
, or NaN
if the node does not contain a logistic regression model.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
IAI.get_regression_weights
— Methodget_regression_weights(lnr::ClassificationTreeLearner, node_index::Int;
check_leaf::Bool=true)
Return the weights for each feature in the logistic regression prediction at node node_index
in the trained lnr
. The weights are returned as two Dict
s in the same format as described for get_split_weights
.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Regression
These functions can be used to query the structure of a RegressionTreeLearner
. The examples make use of the following tree:
IAI.get_regression_constant
— Methodget_regression_constant(lnr::RegressionTreeLearner, node_index::Int;
check_leaf::Bool=true)
Return the constant term in the regression prediction at node node_index
in the trained lnr
.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Example
IAI.get_regression_constant(lnr, 2)
30.88
IAI.get_regression_constant(lnr, 3)
30.8876
IAI.get_regression_weights
— Methodget_regression_weights(lnr::RegressionTreeLearner, node_index::Int;
check_leaf::Bool=true)
Return the weights for each feature in the regression prediction at node node_index
in the trained lnr
. The weights are returned as two Dict
s in the same format as described for get_split_weights
.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Example
numeric_weights, categoric_weights = IAI.get_regression_weights(lnr, 3)
numeric_weights
Dict{Symbol, Float64} with 4 entries:
:Cyl => -0.794566
:WT => -1.64974
:Gear => 0.0585196
:HP => -0.0126672
categoric_weights
Dict{Symbol, Dict{Any, Float64}}()
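The constant and weights can be combined to reproduce a leaf's prediction by hand. This is a minimal sketch that assumes the prediction is the constant plus the weighted sum of the feature values. The helper linear_prediction and the sample values are hypothetical; the constant and weights are copied from the examples above.

```julia
# Constant plus weighted sum of numeric features, plus the weight of the
# observed level for each categoric feature (zero if the level is not listed).
function linear_prediction(constant, numeric_weights, categoric_weights, sample)
    prediction = constant
    for (feature, weight) in numeric_weights
        prediction += weight * sample[feature]
    end
    for (feature, level_weights) in categoric_weights
        prediction += get(level_weights, sample[feature], 0.0)
    end
    return prediction
end

constant = 30.8876
numeric_weights = Dict(:Cyl => -0.794566, :WT => -1.64974,
                       :Gear => 0.0585196, :HP => -0.0126672)
categoric_weights = Dict{Symbol,Dict{Any,Float64}}()

# Hypothetical feature values, for illustration only
sample = (Cyl = 4, WT = 2.0, Gear = 4, HP = 100)

prediction = linear_prediction(constant, numeric_weights, categoric_weights, sample)
println(prediction)
```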
Survival
These functions can be used to query the structure of a SurvivalTreeLearner
. The examples make use of the following tree:
IAI.get_survival_curve
— Functionget_survival_curve(lnr::SurvivalTreeLearner, node_index::Int;
check_leaf::Bool=true)
Return the SurvivalCurve
fitted at node node_index
in the trained lnr
.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Example
IAI.get_survival_curve(lnr, 2)
SurvivalCurve with 22 breakpoints
IAI.get_survival_expected_time
— Functionget_survival_expected_time(lnr::SurvivalTreeLearner, node_index::Int;
check_leaf::Bool=true)
Return the predicted expected survival time at node node_index
in the trained lnr
.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Example
IAI.get_survival_expected_time(lnr, 2)
23443.187287749995
IAI.get_survival_hazard
— Functionget_survival_hazard(lnr::SurvivalTreeLearner, node_index::Int;
check_leaf::Bool=true)
Return the predicted hazard ratio at node node_index
in the trained lnr
.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Example
IAI.get_survival_hazard(lnr, 2)
0.8880508
IAI.get_regression_constant
— Methodget_regression_constant(lnr::SurvivalTreeLearner, node_index::Int;
check_leaf::Bool=true)
Return the constant term in the Cox regression prediction at node node_index
in the trained lnr
, or NaN
if the node does not contain a Cox regression model.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
IAI.get_regression_weights
— Methodget_regression_weights(lnr::SurvivalTreeLearner, node_index::Int;
check_leaf::Bool=true)
Return the weights for each feature in the Cox regression prediction at node node_index
in the trained lnr
. The weights are returned as two Dict
s in the same format as described for get_split_weights
.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Prescription
These functions can be used to query the structure of a PrescriptionTreeLearner
. The examples make use of the following tree:
IAI.get_prescription_treatment_rank
— Functionget_prescription_treatment_rank(lnr::PrescriptionTreeLearner,
node_index::Int; check_leaf::Bool=true)
Return a Vector
containing the treatments ordered from most effective to least effective at node node_index
in the trained lnr
.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Example
IAI.get_prescription_treatment_rank(lnr, 2)
2-element Vector{String}:
"A"
"B"
IAI.get_regression_constant
— Methodget_regression_constant(lnr::PrescriptionTreeLearner, node_index::Int,
treatment::Any; check_leaf::Bool=true)
Return the constant in the regression prediction for treatment
at node node_index
in the trained lnr
.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Example
IAI.get_regression_constant(lnr, 2, "A")
28.68282
IAI.get_regression_weights
— Methodget_regression_weights(lnr::PrescriptionTreeLearner, node_index::Int,
treatment::Any; check_leaf::Bool=true)
Return the weights for each feature in the regression prediction for treatment
at node node_index
in the trained lnr
. The weights are returned as two Dict
s in the same format as described for get_split_weights
.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Example
numeric_weights, categoric_weights = IAI.get_regression_weights(lnr, 2, "A")
numeric_weights
Dict{Symbol, Float64} with 1 entry:
:SystolicBP => -1.37769
categoric_weights
Dict{Symbol, Dict{Any, Float64}}()
Policy
These functions can be used to query the structure of a PolicyTreeLearner
. The examples make use of the following tree:
IAI.get_policy_treatment_rank
— Functionget_policy_treatment_rank(lnr::PolicyTreeLearner, node_index::Int;
check_leaf::Bool=true)
Return a Vector
containing the treatments ordered from most effective to least effective at node node_index
in the trained lnr
.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Example
IAI.get_policy_treatment_rank(lnr, 3)
3-element Vector{String}:
"A"
"C"
"B"
IAI.get_policy_treatment_outcome
— Functionget_policy_treatment_outcome(lnr::PolicyTreeLearner, node_index::Int;
check_leaf::Bool=true)
Return a DataFrameRow
containing the quality of the treatments at node node_index
in the trained lnr
. These quality estimates are the values used by the model to determine the treatment ranks in get_policy_treatment_rank
and are based on aggregate statistics.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Example
outcome = IAI.get_policy_treatment_outcome(lnr, 3)
DataFrameRow
Row │ A B C
│ Float64 Float64 Float64
─────┼────────────────────────────
1 │ 0.827778 1.70248 1.09849
outcome.A
0.8277784
IAI.get_policy_treatment_outcome_standard_error
— Functionget_policy_treatment_outcome_standard_error(lnr::PolicyTreeLearner,
node_index::Int;
check_leaf::Bool=true)
Return a DataFrameRow
containing the standard error for the estimated quality of the treatments at node node_index
in the trained lnr
. These errors can be used to construct confidence intervals around results from get_policy_treatment_outcome
.
Applies to leaf nodes by default, set check_leaf=false
to enable retrieving the same information from a split node as though it was a leaf node.
Example
errors = IAI.get_policy_treatment_outcome_standard_error(lnr, 3)
DataFrameRow
Row │ A B C
│ Float64 Float64 Float64
─────┼──────────────────────────────
1 │ 0.0777876 0.083841 0.10806
errors.A
0.07778763
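As a sketch of the confidence-interval construction mentioned above, the outcome estimates and standard errors can be combined under a normal approximation. The 1.96 multiplier (an approximate 95% interval) is an assumption, not something the library prescribes. NamedTuples with the values from the examples stand in for the returned DataFrameRows.

```julia
# Stand-ins for the rows returned by get_policy_treatment_outcome and
# get_policy_treatment_outcome_standard_error
outcome = (A = 0.827778, B = 1.70248, C = 1.09849)
errors = (A = 0.0777876, B = 0.083841, C = 0.10806)

z = 1.96  # assumed normal-approximation multiplier for roughly 95% coverage

# Map each treatment to a (lower, upper) interval around its estimated outcome
intervals = Dict(
    t => (outcome[t] - z * errors[t], outcome[t] + z * errors[t])
    for t in keys(outcome)
)

for t in keys(outcome)
    println(t, ": ", intervals[t])
end
```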
Multi-task
Classification
These functions can be used to query the structure of a ClassificationTreeMultiLearner
.
IAI.get_classification_label
— Methodget_classification_label(lnr::ClassificationTreeMultiLearner,
node_index::Int; check_leaf::Bool=true)
Variant of get_classification_label
for multi-task problems that returns the information for all tasks as a dictionary.
IAI.get_classification_label
— Methodget_classification_label(lnr::ClassificationTreeMultiLearner,
node_index::Int, task_label::Symbol;
check_leaf::Bool=true)
Variant of get_classification_label
for multi-task problems that returns information for the task given by task_label
.
IAI.get_classification_proba
— Methodget_classification_proba(lnr::ClassificationTreeMultiLearner,
node_index::Int; check_leaf::Bool=true)
Variant of get_classification_proba
for multi-task problems that returns the information for all tasks as a dictionary.
IAI.get_classification_proba
— Methodget_classification_proba(lnr::ClassificationTreeMultiLearner,
node_index::Int, task_label::Symbol;
check_leaf::Bool=true)
Variant of get_classification_proba
for multi-task problems that returns information for the task given by task_label
.
IAI.get_regression_constant
— Methodget_regression_constant(lnr::ClassificationTreeMultiLearner,
node_index::Int; check_leaf::Bool=true)
Variant of get_regression_constant
for multi-task problems that returns the information for all tasks as a dictionary.
IAI.get_regression_constant
— Methodget_regression_constant(lnr::ClassificationTreeMultiLearner,
node_index::Int, task_label::Symbol;
check_leaf::Bool=true)
Variant of get_regression_constant
for multi-task problems that returns information for the task given by task_label
.
IAI.get_regression_weights
— Methodget_regression_weights(lnr::ClassificationTreeMultiLearner,
node_index::Int; check_leaf::Bool=true)
Variant of get_regression_weights
for multi-task problems that returns the information for all tasks as a dictionary.
IAI.get_regression_weights
— Methodget_regression_weights(lnr::ClassificationTreeMultiLearner,
node_index::Int, task_label::Symbol;
check_leaf::Bool=true)
Variant of get_regression_weights
for multi-task problems that returns information for the task given by task_label
.
Regression
These functions can be used to query the structure of a RegressionTreeMultiLearner
.
IAI.get_regression_constant
— Methodget_regression_constant(lnr::RegressionTreeMultiLearner,
node_index::Int; check_leaf::Bool=true)
Variant of get_regression_constant
for multi-task problems that returns the information for all tasks as a dictionary.
IAI.get_regression_constant
— Methodget_regression_constant(lnr::RegressionTreeMultiLearner,
node_index::Int, task_label::Symbol;
check_leaf::Bool=true)
Variant of get_regression_constant
for multi-task problems that returns information for the task given by task_label
.
IAI.get_regression_weights
— Methodget_regression_weights(lnr::RegressionTreeMultiLearner,
node_index::Int; check_leaf::Bool=true)
Variant of get_regression_weights
for multi-task problems that returns the information for all tasks as a dictionary.
IAI.get_regression_weights
— Methodget_regression_weights(lnr::RegressionTreeMultiLearner,
node_index::Int, task_label::Symbol;
check_leaf::Bool=true)
Variant of get_regression_weights
for multi-task problems that returns information for the task given by task_label
.
Learners
IAI.apply
— Functionapply(lnr::TreeLearner, X::FeatureInput)
Return a Vector{Int}
that contains the leaf index in lnr
into which each point in the features X
falls.
IAI.apply_nodes
— Functionapply_nodes(lnr::TreeLearner, X::FeatureInput)
Return a Vector
with one entry for each node in lnr
. The t
th element is a Vector{Int}
containing the indices of the points from the features X
that fall into node t
or its children.
IAI.decision_path
— Functiondecision_path(lnr::TreeLearner, X::FeatureInput)
Return a SparseMatrixCSC{Bool,Int64}
where entry (i, j)
is true
if the i
th point in the features X
passes through the j
th node in lnr
.
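The following sketch shows one way a matrix of this shape might be read. The matrix is a hypothetical stand-in built with the SparseArrays standard library (2 samples, 7 nodes), and the helpers nodes_visited and samples_in_node are illustrative names, not part of the API.

```julia
using SparseArrays

# Hypothetical stand-in for the result of decision_path:
# sample 1 passes through nodes 1, 2, 4 and sample 2 through nodes 1, 2, 3.
rows = [1, 1, 1, 2, 2, 2]
cols = [1, 2, 4, 1, 2, 3]
path = sparse(rows, cols, fill(true, length(rows)), 2, 7)

# Nodes that sample i passes through (row-wise view)
nodes_visited(path, i) = [j for j in 1:size(path, 2) if path[i, j]]

# Samples that pass through node j (column-wise view, comparable in spirit
# to one entry of the apply_nodes result)
samples_in_node(path, j) = [i for i in 1:size(path, 1) if path[i, j]]

println(nodes_visited(path, 1))
println(samples_in_node(path, 2))
```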
IAI.print_path
— Functionprint_path(lnr::TreeLearner, X::FeatureInput, i::Int)
Print the decision path for the i
th sample in the features X
. The output displays the value of the relevant features for the specified sample and the rules for the path that it takes through the tree.
Example
Print the path through the tree for the first sample in the features X
:
IAI.print_path(lnr, X, 1)
Rules used to predict sample 1:
1) Split: score1 (=28.9) < 60.04
2) Split: 0.0189 * score2 + 1.204 * score3 + 1.479 * region=E (=3.508) ≥ 2.346
4) Predict: true (97.50%), [2,78], 80 points, error 0.025
print_path(lnr::TreeLearner, X::FeatureInput, inds::AbstractVector{Int})
Print the decision path for the samples in the features X
indicated by inds
.
Example
Print the path through the tree for the first two samples in the features X
:
IAI.print_path(lnr, X, 1:2)
Rules used to predict sample 1:
1) Split: score1 (=28.9) < 60.04
2) Split: 0.0189 * score2 + 1.204 * score3 + 1.479 * region=E (=3.508) ≥ 2.346
4) Predict: true (97.50%), [2,78], 80 points, error 0.025
Rules used to predict sample 2:
1) Split: score1 (=29.72) < 60.04
2) Split: 0.0189 * score2 + 1.204 * score3 + 1.479 * region=E (=2.341) < 2.346
3) Predict: false (99.56%), [228,1], 229 points, error 0.004367
print_path(lnr::TreeLearner, X::FeatureInput)
Print the decision path for each sample in the features X
.
Example
Print the path through the tree for all samples in the features X
:
IAI.print_path(lnr, X)
(output omitted for brevity)
print_path(io::IO, lnr::TreeLearner, X::FeatureInput)
print_path(io::IO, lnr::TreeLearner, X::FeatureInput, i::Int)
print_path(io::IO, lnr::TreeLearner, X::FeatureInput,
inds::AbstractVector{Int})
Variants of print_path
that write to a specified io
rather than to stdout
.
Examples
Write the output of print_path
to print_path.txt
:
open("print_path.txt", "w") do f
IAI.print_path(f, lnr, X, 1)
end
Capture the output of print_path
as a String
:
sprint(IAI.print_path, lnr, X, 1)
"Rules used to predict sample 1:\n 1) Split: score1 (=28.9) < 60.04\n 2) Split: 0.0189 * score2 + 1.204 * score3 + 1.479 * region=E (=3.508) ≥ 2.346\n 4) Predict: true (97.50%), [2,78], 80 points, error 0.025\n"
IAI.variable_importance
— Methodvariable_importance(lnr::TreeLearner; keyword_arguments...)
For tree learners, the importance of each variable is measured as the total decrease in the loss function as a direct result of each split in the trees of lnr
that use this variable.
Keyword Arguments
proportion_to_use::Real
: a number between 0 and 1 indicating the proportion of trees to use when calculating importance. The default value is 0.1, indicating that the best 10% of the trees saved in lnr
should be used.
variable_importance(::TreeLearner, X::FeatureInput; keyword_arguments...)
For tree learners, calculates the variable importance of lnr
for the samples in X
, where the importance at each node is weighted by the number of samples that pass through that node.
Keyword Arguments
proportion_to_use::Real
: as above.
sample_weight::SampleWeightInput=nothing
: the weighting to give to each data point.
variable_importance(::TreeLearner, X::FeatureInput, y::TargetInput...;
keyword_arguments...)
For tree learners, calculates the variable importance of lnr
with respect to data X
and y
.
Keyword Arguments
proportion_to_use::Real
: as above.
sample_weight::SampleWeightInput=nothing
: the weighting to give to each data point.
criterion=:default
: the scoring criterion to use when evaluating the importance (refer to the documentation on scoring criteria for more information). Uses the criterion in lnr
if left as :default
.
extra keyword arguments are passed through to configure the specified scoring criterion (e.g. tweedie_variance_power
for :tweedie
)
Task-specific Functions
Classification
IAI.set_threshold!
— Functionset_threshold!(lnr::ClassificationTreeLearner, label::Any, threshold::Real,
simplify::Bool=false)
For a binary classification problem, update the predicted labels in the leaves of lnr
. After running, a leaf will predict label
only if the predicted probability for this label is at least threshold
; otherwise, the other label will be predicted.
If simplify
is true
, the tree will be simplified so that there is no split that has two leaves with the same label prediction as children. This means that if both sides of a split are leaf nodes with the same label prediction, the split will be deleted from the tree and replaced with a single leaf node. This simplification is applied recursively throughout the tree.
Refer to the documentation on setting the threshold for more information.
set_threshold!(lnr::ClassificationTreeMultiLearner, task_label::Symbol,
label::Any, threshold::Real, simplify::Bool=false)
Variant of set_threshold!
for multi-task problems that operates on the task given by task_label
.
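The labeling rule described above can be sketched on a single leaf's probabilities. This illustrates only the rule as stated (predict label when its probability is at least threshold, otherwise predict the other label). It does not modify a learner, and thresholded_label and the probabilities are hypothetical.

```julia
# Apply the binary threshold rule to one leaf's class probabilities.
function thresholded_label(probs::AbstractDict, label, threshold::Real)
    length(probs) == 2 || throw(ArgumentError("expected exactly two labels"))
    # The single remaining label in a binary problem
    other = first(k for k in keys(probs) if k != label)
    return probs[label] >= threshold ? label : other
end

# Hypothetical leaf probabilities for a binary problem with labels true/false
leaf_probs = Dict(true => 0.6, false => 0.4)

println(thresholded_label(leaf_probs, true, 0.5))
println(thresholded_label(leaf_probs, true, 0.7))
```

Raising the threshold for a label makes leaves less likely to predict it, which is the usual way to trade false positives against false negatives.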
Visualization
Interactive Visualizations
IAI.write_html
— Methodwrite_html(f, lnr::TreeLearner; keyword_arguments...)
write_html(f, grid::GridSearch; keyword_arguments...)
Write interactive browser visualization of lnr
or grid
to f
in HTML format.
Keyword Arguments
show_node_id=true
: whether to show the ID label for each node
data
: specify data to be shown in the visualization, should be passed as a Tuple
or Vector
in the same order as passed to fit!
, i.e.:
data=(X, y)
for classification and regression problems
data=(X, deaths, times)
for survival problems
data=(X, treatments, outcomes)
for prescription problems
data=(X, rewards)
for policy problems
You can also pass
data=X
to show the features without target information.
Refer to the Tree Visualization documentation for more information.
Example
Save tree to mytree.html
:
IAI.write_html("mytree.html", lnr)
IAI.show_in_browser
— Methodshow_in_browser(lnr::TreeLearner; keyword_arguments...)
show_in_browser(grid::GridSearch; keyword_arguments...)
Show interactive visualization of lnr
or grid
in default browser.
Supports the same keyword arguments as write_html
.
IAI.write_questionnaire
— Methodwrite_questionnaire(f, lnr::TreeLearner; keyword_arguments...)
write_questionnaire(f, grid::GridSearch; keyword_arguments...)
Write interactive questionnaire based on lnr
or grid
to f
in HTML format.
Keyword Arguments
include_not_sure_buttons
: a Bool
specifying whether to include "Not sure" buttons for each question. The default behavior is to include these buttons if any missing data was present in the training data.
This function also supports the same keyword arguments as write_html
.
Example
Save questionnaire to myquestionnaire.html
:
IAI.write_questionnaire("myquestionnaire.html", lnr)
IAI.show_questionnaire
— Methodshow_questionnaire(lnr::TreeLearner; keyword_arguments...)
show_questionnaire(grid::GridSearch; keyword_arguments...)
Show interactive questionnaire based on lnr
or grid
in default browser.
Supports the same keyword arguments as write_questionnaire
.
IAI.TreePlot
— MethodTreePlot(lnr::TreeLearner; keyword_arguments...)
Specifies an interactive tree visualization of lnr
.
Keyword Arguments
feature_renames
, level_renames
and label_renames
allow renaming different aspects of the data
extra_content
allows including additional output at each node in the visualization
Refer to the documentation on advanced visualization for more information on using these keyword arguments.
IAI.Questionnaire
— MethodQuestionnaire(lnr::TreeLearner; keyword_arguments...)
Specifies an interactive questionnaire based on lnr
.
Supports the same keyword arguments as TreePlot
.
IAI.MultiTreePlot
— MethodMultiTreePlot(questions::Pair; keyword_arguments...)
Specifies an interactive tree visualization of multiple tree learners as specified by questions
. Refer to the documentation on multi-learner visualizations for more details. The keyword arguments are the same as for TreePlot
.
IAI.MultiTreePlot
— MethodMultiTreePlot(grid::GridSearch; keyword_arguments...)
Constructs an interactive tree visualization containing the final fitted learner as well as the learner found for each parameter combination. The keyword arguments are the same as for TreePlot
.
Static Images
IAI.write_png
— Functionwrite_png(filename::AbstractString, lnr::TreeLearner; keyword_arguments...)
Write lnr
to filename
as a PNG image.
Before using this function, make sure that either Graphviz_jll
is loaded, or GraphViz is installed and on the system PATH
.
Keyword Arguments
feature_renames
, level_renames
and label_renames
: renaming different aspects of the data
extra_content
: including additional output at each node in the visualization
font
: the font for the text in the image. Defaults to "Arial"
show_missing_direction
: whether to include the missing data direction in the split criterion. Defaults to true if the data has missing observations, and false otherwise
show_node_id=true
: whether to show the ID label for each node
Refer to the documentation on advanced visualization for more information on using these keyword arguments.
Example
Save tree to mytree.png
:
IAI.write_png("mytree.png", lnr)
IAI.write_pdf
— Functionwrite_pdf(filename::AbstractString, lnr::TreeLearner; keyword_arguments...)
Write lnr
to filename
as a PDF image.
Supports the same keyword arguments and has the same requirements as write_png
.
IAI.write_svg
— Functionwrite_svg(filename::AbstractString, lnr::TreeLearner; keyword_arguments...)
Write lnr
to filename
as an SVG image.
Supports the same keyword arguments and has the same requirements as write_png
.
IAI.write_dot
— Functionwrite_dot(f, lnr::TreeLearner; keyword_arguments...)
Write the trained tree of lnr
into .dot format to the stream f
.
Supports the same keyword arguments as write_png
.
Example
Save tree to mytree.dot
:
IAI.write_dot("mytree.dot", lnr)
You can then convert mytree.dot
to a PNG image at the command line (requires GraphViz to be installed):
$ dot -Tpng -o mytree.png mytree.dot
Tree Stability
IAI.get_tree
— Functionget_tree(lnr::TreeLearner, index::Integer)
Return a copy of lnr
that uses the tree at index
rather than the tree with the best training objective.
Stability Analysis
IAI.StabilityAnalysis
— Type
StabilityAnalysis(lnr::TreeLearner)
StabilityAnalysis(lnr::TreeLearner, X, y...; criterion=:default)
Conduct a stability analysis of the trees in lnr. The similarity scores are calculated using the data X and y with criterion if supplied; otherwise, the data and criterion from training are reused.
For classification problems, we strongly suggest conducting the stability analysis with :gini or :entropy as the criterion, as similarity measures derived using :misclassification are not as precise.
The resulting analysis can be visualized in the browser using show_in_browser, or saved in HTML format with write_html. You can also use plot from Plots.jl to view a summary of the analysis.
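As a rough sketch of how these pieces might fit together, the helper below runs the analysis with an explicit criterion and saves the visualization. The name analyze_stability and its html_path keyword are hypothetical, and the argument order of write_html (filename first, then the analysis) is an assumption that mirrors write_png. lnr is assumed to be a trained classification learner, with X and y the data used for scoring.

```julia
# Hypothetical helper (not part of IAI): run a stability analysis with an
# explicit criterion and save the interactive visualization as HTML.
function analyze_stability(lnr, X, y; criterion=:gini, html_path="stability.html")
    # :gini or :entropy is suggested above for classification problems
    stability = IAI.StabilityAnalysis(lnr, X, y; criterion=criterion)
    # Assumed argument order: filename first, as with write_png
    IAI.write_html(html_path, stability)
    return stability
end
```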
IAI.get_stability_results
— Function
get_stability_results(s::StabilityAnalysis)
Return a DataFrame containing the trained trees in order of increasing training objective value, along with their variable importance scores for each feature.
IAI.get_cluster_distances
— Function
get_cluster_distances(s::StabilityAnalysis, num_trees::Int)
Return a Matrix containing the distances between the centroids of each pair of clusters, under the clustering of the best num_trees trees.
IAI.get_cluster_assignments
— Function
get_cluster_assignments(s::StabilityAnalysis, num_trees::Int)
Return a Vector containing the indices of the trees assigned to each cluster, under the clustering of the best num_trees trees.
IAI.get_cluster_details
— Function
get_cluster_details(s::StabilityAnalysis, num_trees::Int)
Return a DataFrame containing the centroid information for each cluster, under the clustering of the best num_trees trees.
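The three cluster accessors take the same arguments, so they can be collected in one call. This is only a sketch: summarize_clusters is a hypothetical helper, stability is assumed to be an existing StabilityAnalysis, and treating the length of the assignments vector as the number of clusters assumes it holds one entry per cluster.

```julia
# Hypothetical helper (not part of IAI): gather the clustering outputs for
# the best `num_trees` trees of an existing stability analysis.
function summarize_clusters(stability, num_trees::Int)
    assignments = IAI.get_cluster_assignments(stability, num_trees)
    return (
        results = IAI.get_stability_results(stability),
        assignments = assignments,
        # Assumption: one entry of `assignments` per cluster
        num_clusters = length(assignments),
        distances = IAI.get_cluster_distances(stability, num_trees),
        details = IAI.get_cluster_details(stability, num_trees),
    )
end
```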
Similarity Comparison
IAI.SimilarityComparison
— Type
SimilarityComparison(orig_lnr::TreeLearner, new_lnr::TreeLearner,
                     deviations::Vector{Float64})
Conduct a similarity comparison between the final tree in orig_lnr and all trees in new_lnr to consider the tradeoff between training performance and similarity to the original tree. The deviations argument is a vector containing the distances between the tree in orig_lnr and all trees in new_lnr, as calculated by, for example, variable_importance_similarity.
The resulting analysis can be visualized in the browser using show_in_browser, or saved in HTML format with write_html. You can also use plot from Plots.jl to view a summary of the analysis.
IAI.variable_importance_similarity
— Function
variable_importance_similarity(orig_lnr::TreeLearner, new_lnr::TreeLearner)
variable_importance_similarity(orig_lnr::TreeLearner, new_lnr::TreeLearner,
                               X, y...; criterion=:default)
Calculate similarity scores between the final tree in orig_lnr and all trees in new_lnr using variable importance scores. Scores are calculated using the data X and y with criterion if supplied; otherwise, the data and criterion from training are reused.
For classification problems, we strongly suggest calculating the scores with :gini or :entropy as the criterion, because similarity measures derived using :misclassification are not as precise.
IAI.get_train_errors
— Function
get_train_errors(s::SimilarityComparison)
Extract the training objective value for each candidate tree in s, where a lower value indicates a better solution.
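The tradeoff described above can be sketched as a small selection routine: among the trees in new_lnr whose deviation from the original stays within a tolerance, pick the one with the lowest training objective. The helper name pick_similar_tree and its max_deviation keyword are hypothetical, and the sketch assumes that the deviations vector and the output of get_train_errors list the candidate trees in the same order.

```julia
# Hypothetical helper (not part of IAI): choose the candidate tree with the
# best training objective among those close enough to the original tree.
function pick_similar_tree(orig_lnr, new_lnr; max_deviation::Real)
    deviations = IAI.variable_importance_similarity(orig_lnr, new_lnr)
    comparison = IAI.SimilarityComparison(orig_lnr, new_lnr, deviations)
    train_errors = IAI.get_train_errors(comparison)

    # Assumption: deviations[i] and train_errors[i] describe the same tree
    candidates = findall(d -> d <= max_deviation, deviations)
    isempty(candidates) && return nothing

    # Lower training objective is better
    best = candidates[argmin(train_errors[candidates])]
    return (index=best, deviation=deviations[best], train_error=train_errors[best])
end
```

If the returned index uses the same ordering as get_tree (an assumption worth checking on your data), IAI.get_tree(new_lnr, result.index) would give a learner that uses the selected tree.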
Miscellaneous
IAI.set_display_label!
— Function
set_display_label!(lnr::ClassificationTreeLearner, display_label::Any)
Changes which predicted probability is displayed when visualizing lnr to show the probability of display_label.
set_display_label!(lnr::ClassificationTreeMultiLearner, task_label::Symbol,
                   display_label::Any)
Variant of set_display_label! for multi-task problems that operates on the task given by task_label.
set_display_label!(grid::GridSearch{<:ClassificationTreeLearner},
                   display_label::Any)
Changes which predicted probability is displayed when visualizing grid to show the probability of display_label.
set_display_label!(grid::GridSearch{<:ClassificationTreeMultiLearner},
                   task_label::Symbol, display_label::Any)
Variant of set_display_label! for multi-task problems that operates on the task given by task_label.
IAI.reset_display_label!
— Function
reset_display_label!(lnr::ClassificationTreeLearner)
Resets the predicted probability displayed for lnr to be that of the predicted label.
reset_display_label!(lnr::ClassificationTreeMultiLearner,
                     task_label::Symbol)
Variant of reset_display_label! for multi-task problems that operates on the task given by task_label.
reset_display_label!(grid::GridSearch{<:ClassificationTreeLearner})
Resets the predicted probability displayed for grid to be that of the predicted label.
reset_display_label!(grid::GridSearch{<:ClassificationTreeMultiLearner},
                     task_label::Symbol)
Variant of reset_display_label! for multi-task problems that operates on the task given by task_label.
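Since set_display_label! modifies the learner in place, one possible pattern is to change the label only for the duration of a single visualization and then restore the default. The helper with_display_label below is hypothetical and covers only the single-task learner methods; lnr is assumed to be a trained ClassificationTreeLearner.

```julia
# Hypothetical helper (not part of IAI): temporarily display the probability
# of `display_label`, run `f(lnr)`, then restore the default display.
function with_display_label(f, lnr, display_label)
    IAI.set_display_label!(lnr, display_label)
    try
        return f(lnr)
    finally
        # Runs even if `f` throws, so the learner is left in its default state
        IAI.reset_display_label!(lnr)
    end
end
```

For example, with_display_label(l -> IAI.write_png("mytree.png", l), lnr, "A") would write an image showing the probability of a label "A", where "A" stands in for one of the labels in your data.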
IAI.compare_group_outcomes
— Function
compare_group_outcomes(lnr::TreeLearner, X::FeatureInput, y::AbstractVector,
                       group::AbstractVector; keyword_arguments...)
In each node of lnr, conduct between-group statistical comparisons of the outcomes for data X and y that fall into the node, where the groups are given by group.
Returns a Vector where each entry corresponds to a node in the tree, and is a NamedTuple with two fields:
summary
: A DataFrame summarizing the outcome by group in this node
p_value
: A Dict containing the p-values of the statistical tests conducted. There are four types of tests conducted:
  "overall"
  : A single p-value indicating whether there is an overall difference in outcomes between groups in this node (for regression, Welch's test; for classification, chi-squared test)
  "vs-mean"
  : A Dict with one p-value for each group indicating whether there is a difference between this group and the overall population (for regression, one-sample t-test; for classification, binomial test)
  "vs-rest"
  : A Dict with one p-value for each group indicating whether there is a difference between this group and all other groups (for regression, two-sample t-test; for classification, Fisher's exact test)
  "pairwise"
  : A Dict of Dicts, with one p-value for each pair of groups indicating whether there is a difference between these two groups (for regression, two-sample t-test; for classification, Fisher's exact test)
Keyword Arguments
positive_label
: For classification only, specify which label in y to treat as the positive label
approx::Bool=false
: Whether to use approximate comparisons, which can often be significantly faster. For classification, replaces Fisher's exact tests with chi-squared tests.
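To illustrate how the returned structure might be consumed, the sketch below scans the per-node results for nodes whose "overall" p-value falls at or below a significance level. The helper significant_nodes and its alpha keyword are hypothetical, and treating the position of an entry in the vector as the node index is an assumption.

```julia
# Hypothetical helper (not part of IAI): positions of the entries whose
# "overall" test is significant at level `alpha`.
function significant_nodes(results; alpha::Real=0.05)
    # Each entry is a NamedTuple with fields `summary` and `p_value`;
    # `p_value["overall"]` is documented as a single p-value.
    return [t for (t, r) in enumerate(results) if r.p_value["overall"] <= alpha]
end
```

A typical call would be results = IAI.compare_group_outcomes(lnr, X, y, group) followed by significant_nodes(results), after which results[t].summary gives the by-group summary for each flagged entry.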