Quick Start Guide: Multi-task Optimal Classification Trees
In this example we will use Optimal Classification Trees (OCT) on the acute inflammations dataset to solve a multi-task classification problem.
This guide assumes you are familiar with OCTs and focuses on aspects that are unique to the multi-task setting. For a general introduction to OCTs, please refer to the OCT quickstart guide.
First, we load the data:
# File is UTF-16 encoded so we use the StringEncodings package to read it
using CSV, DataFrames, StringEncodings
df = CSV.read(open("diagnosis.data", enc"utf-16"), DataFrame,
    header=[:temp, :nausea, :lumbar_pain, :urine_pushing, :micturition_pains,
            :burning, :inflammation, :nephritis],
    delim='\t',
    decimal=',',
)
120×8 DataFrame
Row │ temp nausea lumbar_pain urine_pushing micturition_pains burnin ⋯
│ Float64 String3 String3 String3 String3 String ⋯
─────┼──────────────────────────────────────────────────────────────────────────
1 │ 35.5 no yes no no no ⋯
2 │ 35.9 no no yes yes yes
3 │ 35.9 no yes no no no
4 │ 36.0 no no yes yes yes
5 │ 36.0 no yes no no no ⋯
6 │ 36.0 no yes no no no
7 │ 36.2 no no yes yes yes
8 │ 36.2 no yes no no no
⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱
114 │ 41.2 no yes yes no yes ⋯
115 │ 41.3 yes yes yes yes no
116 │ 41.4 no yes yes no yes
117 │ 41.5 no no no no no
118 │ 41.5 yes yes no yes no ⋯
119 │ 41.5 no yes yes no yes
120 │ 41.5 no yes yes no yes
3 columns and 105 rows omitted
The goal is to predict two diseases of the urinary system: acute inflammation of the urinary bladder and acute nephritis. We therefore separate these two targets from the rest of the features, and split the data into training and test sets:
targets = [:inflammation, :nephritis]
X = select(df, Not(targets))
y = select(df, targets)
(train_X, train_y), (test_X, test_y) = IAI.split_data(:multi_classification,
X, y, seed=1)
Multi-task Optimal Classification Trees
We will use a GridSearch to fit an OptimalTreeMultiClassifier:
grid = IAI.GridSearch(
    IAI.OptimalTreeMultiClassifier(
        random_seed=1,
    ),
    max_depth=1:5,
)
IAI.fit!(grid, train_X, train_y)
IAI.get_learner(grid)
We can make predictions on new data using predict:
IAI.predict(grid, test_X)
OrderedCollections.OrderedDict{Symbol, Vector{String3}} with 2 entries:
:inflammation => ["no", "no", "yes", "no", "yes", "no", "yes", "yes", "no", "…
:nephritis => ["no", "no", "no", "no", "no", "no", "no", "no", "no", "no" …
This returns a dictionary mapping each task to its predictions, which can easily be converted to a DataFrame:
DataFrame(IAI.predict(grid, test_X))
36×2 DataFrame
Row │ inflammation nephritis
│ String3 String3
─────┼─────────────────────────
1 │ no no
2 │ no no
3 │ yes no
4 │ no no
5 │ yes no
6 │ no no
7 │ yes no
8 │ yes no
⋮ │ ⋮ ⋮
30 │ yes yes
31 │ no yes
32 │ no yes
33 │ yes yes
34 │ no yes
35 │ yes yes
36 │ no yes
21 rows omitted
We can also generate the predictions for a specific task by passing the task label:
IAI.predict(grid, test_X, :inflammation)
36-element Vector{String3}:
"no"
"no"
"yes"
"no"
"yes"
"no"
"yes"
"yes"
"no"
"yes"
⋮
"no"
"yes"
"yes"
"no"
"no"
"yes"
"no"
"yes"
"no"
We can evaluate the quality of the tree using score with any of the supported loss functions. For multi-task problems, the returned score is the average of the scores of the individual tasks:
IAI.score(grid, test_X, test_y, criterion=:misclassification)
1.0
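To make the averaging concrete, here is a minimal plain-Julia sketch that does not call IAI at all. It assumes, as the 1.0 above suggests, that the :misclassification score is reported as accuracy (the proportion of correct predictions, so higher is better), and that the average over tasks is an unweighted mean. The prediction and label vectors are small hypothetical examples, not the acute inflammations test set.

```julia
using Statistics: mean

# Hypothetical predictions for two tasks, shaped like the dictionary
# returned by IAI.predict (task => vector of predicted labels)
pred = Dict(
    :inflammation => ["no", "yes", "yes", "no"],
    :nephritis    => ["no", "no",  "yes", "yes"],
)

# Hypothetical true labels for the same four points
truth = Dict(
    :inflammation => ["no", "yes", "no",  "no"],
    :nephritis    => ["no", "no",  "yes", "yes"],
)

# Accuracy for one task: the fraction of points where prediction == truth,
# i.e. 1 - misclassification rate
task_accuracy(p, t) = mean(p .== t)

# Score each task separately
per_task = Dict(task => task_accuracy(pred[task], truth[task]) for task in keys(pred))

# Multi-task score, assumed here to be the unweighted mean over tasks
overall = mean(values(per_task))

println("per-task scores: ", per_task)
println("overall score: ", overall)
```

In this toy example the inflammation task scores 0.75 (one of four points is wrong) and the nephritis task scores 1.0, so the overall score is 0.875.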
We can also calculate the score for a single task by specifying that task:
IAI.score(grid, test_X, test_y, :nephritis, criterion=:auc)
1.0
The other standard API functions (e.g. predict_proba, ROCCurve) can be called as normal. As above, by default they will generate output for all tasks, and a task can be specified to return information for a single task.
Extensions
The standard OCT extensions (e.g. hyperplane splits, logistic regression) are also available in the multi-task setting and controlled in the usual way.
For instance, we can use Optimal Classification Trees with hyperplane splits:
grid = IAI.GridSearch(
    IAI.OptimalTreeMultiClassifier(
        random_seed=1,
        max_depth=2,
        hyperplane_config=(sparsity=:all,)
    ),
)
IAI.fit!(grid, train_X, train_y)
IAI.get_learner(grid)