Quick Start Guide: Heuristic Classifiers

This is an R version of the corresponding Heuristics quick start guide.

In this example we will use classifiers from Heuristics on the banknote authentication dataset. First we load in the data and split into training and test datasets:

# Load the banknote authentication data: four image-derived features
# plus a 0/1 class label. ("curtosis" is the dataset's own spelling.)
feature_names <- c("variance", "skewness", "curtosis", "entropy",
                   "class")
df <- read.table("data_banknote_authentication.txt", sep = ",",
                 col.names = feature_names)
   variance skewness curtosis  entropy class
1   3.62160   8.6661 -2.80730 -0.44699     0
2   4.54590   8.1674 -2.45860 -1.46210     0
3   3.86600  -2.6383  1.92420  0.10645     0
4   3.45660   9.5228 -4.01120 -3.59440     0
5   0.32924  -4.4552  4.57180 -0.98880     0
6   4.36840   9.6718 -3.96060 -3.16250     0
7   3.59120   3.0129  0.72888  0.56421     0
8   2.09220  -6.8100  8.46360 -0.60216     0
9   3.20320   5.7588 -0.75345 -0.61251     0
10  1.53560   9.1772 -2.27180 -0.73535     0
11  1.22470   8.7779 -2.21350 -0.80647     0
12  3.98990  -2.7066  2.39460  0.86291     0
 [ reached 'max' / getOption("max.print") -- omitted 1360 rows ]
# Separate the features from the target, then split into training and
# test sets with a fixed seed for reproducibility.
X <- df[, c("variance", "skewness", "curtosis", "entropy")]
y <- df[["class"]]
split <- iai::split_data("classification", X, y, seed = 1)
train_X <- split$train$X
train_y <- split$train$y
test_X <- split$test$X
test_y <- split$test$y

Random Forest Classifier

We will use a grid_search to fit a random_forest_classifier with some basic parameter validation:

# Grid search over tree depth for a random forest learner with a fixed
# seed. Note: the trailing commas after `random_seed = 1` and
# `max_depth = 5:10` in the original created empty arguments, which can
# raise "argument is empty" errors at evaluation time in R; removed.
grid <- iai::grid_search(
    iai::random_forest_classifier(
        random_seed = 1
    ),
    max_depth = 5:10
)
iai::fit(grid, train_X, train_y)

We can make predictions on new data using predict:

# Predict class labels for the held-out test observations
iai::predict(grid, test_X)
 [1] 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
[39] 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 [ reached getOption("max.print") -- omitted 352 entries ]

We can evaluate the quality of the model using score with any of the supported loss functions. For example, the misclassification on the training set:

# Score on the training set with the "misclassification" criterion; the
# printed value of 1 suggests higher-is-better (accuracy-like) scoring
# rather than an error rate -- confirm against the iai::score docs
iai::score(grid, train_X, train_y, criterion = "misclassification")
[1] 1

Or the AUC on the test set:

# Area under the ROC curve on the held-out test set
iai::score(grid, test_X, test_y, criterion = "auc")
[1] 0.9999761

We can also look at the variable importance:

# Pull the fitted learner out of the grid, then rank the features by
# importance (the final expression auto-prints, as in the original).
lnr <- iai::get_learner(grid)
iai::variable_importance(lnr)
   Feature Importance
1 variance 0.54946411
2 skewness 0.25180303
3 curtosis 0.14398502
4  entropy 0.05474785

XGBoost Classifier

We will use a grid_search to fit an xgboost_classifier with some basic parameter validation:

# Grid search over tree depth and boosting rounds for an XGBoost
# learner with a fixed seed. Note: the trailing commas after
# `random_seed = 1` and `num_round = c(20, 50, 100)` in the original
# created empty arguments, which can raise "argument is empty" errors
# at evaluation time in R; removed.
grid <- iai::grid_search(
    iai::xgboost_classifier(
        random_seed = 1
    ),
    max_depth = 2:5,
    num_round = c(20, 50, 100)
)
iai::fit(grid, train_X, train_y)

We can make predictions on new data using predict:

# Predict class labels for the held-out test observations
iai::predict(grid, test_X)
 [1] 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
[39] 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 [ reached getOption("max.print") -- omitted 352 entries ]

We can evaluate the quality of the model using score with any of the supported loss functions. For example, the misclassification on the training set:

# Score on the training set with the "misclassification" criterion; the
# printed value of 1 suggests higher-is-better (accuracy-like) scoring
# rather than an error rate -- confirm against the iai::score docs
iai::score(grid, train_X, train_y, criterion = "misclassification")
[1] 1

Or the AUC on the test set:

# Area under the ROC curve on the held-out test set
iai::score(grid, test_X, test_y, criterion = "auc")
[1] 0.9999284

We can also look at the variable importance:

# Pull the fitted learner out of the grid, then rank the features by
# importance (the final expression auto-prints, as in the original).
lnr <- iai::get_learner(grid)
iai::variable_importance(lnr)
   Feature Importance
1 variance 0.59786180
2 skewness 0.23441181
3 curtosis 0.15546901
4  entropy 0.01225738

We can calculate the SHAP values:

# SHAP values on the test set: the result is a list containing the
# expected values, class labels, the feature matrix, and one matrix of
# SHAP values per label (as shown in the printed output below)
iai::predict_shap(grid, test_X)
$expected_value
[1] 0.7067032 0.2932982

$labels
[1] 0 1

$features
   variance skewness curtosis  entropy
1   4.54590  8.16740 -2.45860 -1.46210
2   0.32924 -4.45520  4.57180 -0.98880
3   4.36840  9.67180 -3.96060 -3.16250
4   2.09220 -6.81000  8.46360 -0.60216
5   1.22470  8.77790 -2.21350 -0.80647
6   3.98990 -2.70660  2.39460  0.86291
7   4.67650 -3.38950  3.48960  1.47710
8   2.67190  3.06460  0.37158  0.58619
9   6.56330  9.81870 -4.41130 -3.22580
10 -0.24811 -0.17797  4.90680  0.15429
11  1.48840  3.62740  3.30800  0.48921
12  4.29690  7.61700 -2.38740 -0.96164
13 -1.61620  0.80908  8.16280  0.60817
14  2.68810  6.01950 -0.46641 -0.69268
15  3.48050  9.70080 -3.75410 -3.43790
 [ reached 'max' / getOption("max.print") -- omitted 397 rows ]

$shap_values
$shap_values[[1]]
              [,1]         [,2]        [,3]         [,4]
  [1,]  3.54497743  3.229129553 -2.00503373  0.394500077
  [2,]  0.72894478 -1.465527058  2.35883331  0.735524356
  [3,]  3.39880562  3.291855574 -2.13043642  0.300303429
  [4,]  2.95135808 -2.379413366  2.52936983  0.648338735
  [5,]  1.10210097  3.350145102 -1.98984838  0.507956326
  [6,]  4.35069799 -1.748496413  0.93079853 -0.612113357
  [7,]  4.09464073 -1.624990225  1.86801839 -0.612113357
  [8,]  4.34744787 -0.151716962 -0.12539595 -0.313408345
  [9,]  3.29869652  3.238628149 -2.67316556  0.300303429
 [10,] -0.15058978 -1.102718830  2.54167652 -0.638054609
 [11,]  1.74040437  0.355728924  2.70026445 -0.495974690
 [12,]  3.63787985  3.118540049 -2.00503373  0.394500077
 [13,] -2.00483656  0.452732146  3.64206815 -0.435520321
 [14,]  4.06534529  2.214010239 -0.83388287  0.394500077
 [15,]  3.39880562  3.291855574 -2.13043642  0.300303429
 [ reached getOption("max.print") -- omitted 397 rows ]

$shap_values[[2]]
              [,1]         [,2]        [,3]         [,4]
  [1,] -3.54497743 -3.229130507  2.00503469 -0.394499958
  [2,] -0.72894454  1.465527177 -2.35883355 -0.735524237
  [3,] -3.39880610 -3.291855812  2.13043690 -0.300303698
  [4,] -2.95135856  2.379413605 -2.52936959 -0.648338556
  [5,] -1.10210025 -3.350145817  1.98984885 -0.507956207
  [6,] -4.35069895  1.748496056 -0.93079925  0.612113357
  [7,] -4.09464169  1.624990225 -1.86801851  0.612113357
  [8,] -4.34744835  0.151716903  0.12539557  0.313408434
  [9,] -3.29869652 -3.238628387  2.67316651 -0.300303698
 [10,]  0.15059002  1.102719069 -2.54167676  0.638054788
 [11,] -1.74040437 -0.355729282 -2.70026493  0.495974541
 [12,] -3.63787985 -3.118540049  2.00503469 -0.394499958
 [13,]  2.00483584 -0.452732295 -3.64206791  0.435520381
 [14,] -4.06534576 -2.214010715  0.83388287 -0.394499958
 [15,] -3.39880610 -3.291855812  2.13043690 -0.300303698
 [ reached getOption("max.print") -- omitted 397 rows ]

We can then use the SHAP library to visualize these results in whichever way we prefer.