rf-finalize3

tidymodels
statlearning
template
string
Published

May 17, 2023

Task

Fit a predictive model (random forest) with the following model equation:

body_mass_g ~ . (data set: palmerpenguins::penguins).

Show which values tidymodels sets for mtry by default!

Hints:
- Tune all tuning parameters with 3 values each.
- Use cross-validation.
- Use default settings unless stated otherwise.
- Set the random number seed to 42.











Solution

Standard start

First, the standard workflow:

# Setup:
library(tidymodels)
library(tidyverse)
library(tictoc)  # Zeitmessung
set.seed(42)


# Data:
d_path <- "https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv"
d <- read_csv(d_path)

# rm NA in the dependent variable:
d <- d %>% 
  drop_na(body_mass_g)


set.seed(42)
d_split <- initial_split(d)
d_train <- training(d_split)
d_test <- testing(d_split)


# model:
mod_rf <-
  rand_forest(mode = "regression",
              mtry = tune(),
              min_n = tune(),
              trees = tune())


# cv:
set.seed(42)
rsmpl <- vfold_cv(d_train)


# recipe:
rec_plain <- 
  recipe(body_mass_g ~  ., data = d_train) %>% 
  step_impute_bag(all_predictors())


# workflow:
wf1 <-
  workflow() %>% 
  add_model(mod_rf) %>% 
  add_recipe(rec_plain)

Tuning grid

Which tuning parameters does our workflow have?

wf1_params_unclear <- 
  extract_parameter_set_dials(wf1)
wf1_params_unclear
name  id    source     component   component_id object
mtry  mtry  model_spec rand_forest main         integer, 1, unknown(), # Randomly Selected Predictors
trees trees model_spec rand_forest main         integer, 1, 2000, # Trees
min_n min_n model_spec rand_forest main         integer, 2, 40, Minimal Node Size
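As an aside, the default ranges can also be inspected directly on the dials parameter objects; a small sketch (output not shown): mtry() reports an unknown upper bound, whereas trees() and min_n() come with fixed default ranges.

# Sketch: inspect the default dials parameter objects directly.
# mtry() has an unknown upper bound; trees() and min_n() do not.
mtry()
trees()
min_n()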

The task asked for 3 tuning-parameter values per parameter:

my_grid <- grid_latin_hypercube(wf1_params_unclear, levels = 3)
Error in `grid_latin_hypercube()`:
! `levels` is not an argument to `grid_latin_hypercube()`.
ℹ Did you mean `size`?
my_grid
Error: object 'my_grid' not found

Two things go wrong here. First, `levels` is an argument of grid_regular(); grid_latin_hypercube() expects `size` instead. Second, even with the correct argument tidymodels does not know which values to use for mtry, because this value depends on the number of columns of the data set and therefore cannot be derived from the model specification alone.

The unknown() upper bound for mtry in the output above tells us that tidymodels could not resolve the range of this tuning parameter, since it is data-dependent.

So let's tell tidymodels about this range:

wf1_params <- 
  wf1 %>% 
  extract_parameter_set_dials() %>% 
  update(mtry = finalize(mtry(), d_train))

wf1_params
name  id    source     component   component_id object
mtry  mtry  model_spec rand_forest main         integer, 1, 9, # Randomly Selected Predictors
trees trees model_spec rand_forest main         integer, 1, 2000, # Trees
min_n min_n model_spec rand_forest main         integer, 2, 40, Minimal Node Size

Now tidymodels knows which values can be used for mtry.
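Note that we finalized mtry on the full training data, which still contains the outcome column; the upper bound (9 here) is therefore the total number of columns of d_train. A small sketch of the stricter variant, finalizing on the predictor columns only so the upper bound equals the number of predictors:

# Sketch: finalize mtry on the predictors only (outcome dropped),
# so its upper bound equals the number of predictor columns:
wf1_params_pred <- 
  wf1 %>% 
  extract_parameter_set_dials() %>% 
  update(mtry = finalize(mtry(), d_train %>% select(-body_mass_g)))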

We can now create the tuning grid (this is handled by the dials package):

my_grid <- grid_latin_hypercube(wf1_params, size = 125)
my_grid %>% head()
mtry trees min_n
1 105 11
5 1036 21
3 325 16
4 1375 28
6 1405 21
7 304 15

How many distinct values are there in the tuning grid?

Let's take a look.

my_grid %>% 
  ggplot(aes(x = trees, y = mtry)) +
  geom_point()
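One can also count the distinct candidate values per tuning parameter directly instead of eyeballing the plot; a quick sketch:

# Sketch: number of distinct candidate values per tuning parameter:
my_grid %>% 
  summarise(across(everything(), n_distinct))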

We can also build the tuning grid ourselves:

my_grid <-
  grid_latin_hypercube(mtry(range = c(1, ncol(d_train)-1)),
                       trees(),
                       min_n(),
                       size = 60)
dim(my_grid)
[1] 60  3
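Since the task asked for exactly 3 values per tuning parameter, a regular grid is the more literal answer; a sketch using the finalized wf1_params from above (3 levels per parameter, i.e. 27 candidate combinations):

# Sketch: a regular grid with exactly 3 levels per tuning parameter:
my_grid_regular <- grid_regular(wf1_params, levels = 3)
dim(my_grid_regular)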

Tuning/Fitting

# tuning:
tic()
wf1_fit <-
  wf1 %>% 
  tune_grid(
    grid = my_grid,
    resamples = rsmpl)
toc()
135.811 sec elapsed
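As an aside, tuning 60 candidates on 10 folds takes a while. A sketch for speeding this up (assuming the doParallel package is installed; newer tune versions alternatively use the future package):

# Sketch: register a parallel backend before calling tune_grid():
library(doParallel)
cl <- makePSOCKcluster(parallel::detectCores(logical = FALSE))
registerDoParallel(cl)
# ... run tune_grid() as above ...
stopCluster(cl)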

Then let's look at the result object from the tuning.

wf1_fit %>% 
  collect_metrics() %>% 
  filter(.metric == "rmse") %>% 
  arrange(mtry)
mtry trees min_n .metric .estimator mean n std_err .config
1 510 29 rmse standard 327.0506 10 14.102972 pre0_mod01_post0
1 826 3 rmse standard 310.0977 10 13.017354 pre0_mod02_post0
1 1742 14 rmse standard 315.6416 10 13.105640 pre0_mod03_post0
1 1835 33 rmse standard 332.4336 10 13.724682 pre0_mod04_post0
2 51 37 rmse standard 298.2026 10 11.349008 pre0_mod05_post0
2 81 29 rmse standard 292.0088 10 12.269233 pre0_mod06_post0
2 147 22 rmse standard 287.0004 10 10.698344 pre0_mod07_post0
2 359 6 rmse standard 282.6882 10 11.064381 pre0_mod08_post0
2 386 23 rmse standard 288.0263 10 11.926175 pre0_mod09_post0
2 672 15 rmse standard 283.0290 10 11.154272 pre0_mod10_post0
2 782 9 rmse standard 282.1710 10 10.670844 pre0_mod11_post0
2 1590 35 rmse standard 295.8688 10 12.409382 pre0_mod12_post0
2 1927 7 rmse standard 282.5915 10 11.442756 pre0_mod13_post0
3 641 8 rmse standard 281.3474 10 10.738309 pre0_mod14_post0
3 747 24 rmse standard 282.9493 10 11.059727 pre0_mod15_post0
3 1093 36 rmse standard 289.0472 10 11.810703 pre0_mod16_post0
3 1219 3 rmse standard 282.7862 10 11.065238 pre0_mod17_post0
3 1408 13 rmse standard 280.9360 10 10.869681 pre0_mod18_post0
3 1697 36 rmse standard 288.2570 10 11.634923 pre0_mod19_post0
3 1889 4 rmse standard 282.0771 10 10.655292 pre0_mod20_post0
3 1967 39 rmse standard 290.0862 10 11.890039 pre0_mod21_post0
4 261 31 rmse standard 285.9761 10 10.342291 pre0_mod22_post0
4 292 24 rmse standard 283.3860 10 10.744992 pre0_mod23_post0
4 330 21 rmse standard 282.0808 10 10.291713 pre0_mod24_post0
4 458 18 rmse standard 281.6246 10 9.928690 pre0_mod25_post0
4 709 14 rmse standard 280.7476 10 9.676739 pre0_mod26_post0
4 1062 35 rmse standard 286.0341 10 11.135346 pre0_mod27_post0
4 1360 8 rmse standard 282.1273 10 10.147874 pre0_mod28_post0
4 1543 25 rmse standard 281.9846 10 10.757734 pre0_mod29_post0
4 1617 12 rmse standard 280.4863 10 10.414125 pre0_mod30_post0
5 588 34 rmse standard 284.0172 10 10.505647 pre0_mod31_post0
5 900 38 rmse standard 285.5616 10 10.776293 pre0_mod32_post0
5 909 12 rmse standard 281.1531 10 10.183557 pre0_mod33_post0
5 1029 30 rmse standard 283.3368 10 10.506628 pre0_mod34_post0
5 1274 19 rmse standard 280.4960 10 10.069676 pre0_mod35_post0
5 1322 16 rmse standard 280.5548 10 10.308200 pre0_mod36_post0
5 1444 39 rmse standard 285.4752 10 10.843783 pre0_mod37_post0
5 1655 26 rmse standard 281.1615 10 10.492565 pre0_mod38_post0
5 1994 11 rmse standard 281.3936 10 10.143697 pre0_mod39_post0
6 481 2 rmse standard 284.1143 10 10.174460 pre0_mod40_post0
6 629 27 rmse standard 281.5415 10 10.115595 pre0_mod41_post0
6 858 32 rmse standard 282.5570 10 10.631752 pre0_mod42_post0
6 996 17 rmse standard 280.9367 10 9.847399 pre0_mod43_post0
6 1155 31 rmse standard 282.1929 10 10.186091 pre0_mod44_post0
6 1178 39 rmse standard 284.8779 10 10.436088 pre0_mod45_post0
6 1375 10 rmse standard 282.1038 10 10.032399 pre0_mod46_post0
6 1532 20 rmse standard 280.5911 10 10.113653 pre0_mod47_post0
7 7 20 rmse standard 295.2057 10 7.571570 pre0_mod48_post0
7 104 17 rmse standard 280.6154 10 10.137987 pre0_mod49_post0
7 422 27 rmse standard 280.7607 10 10.022320 pre0_mod50_post0
7 546 11 rmse standard 280.1788 10 10.373537 pre0_mod51_post0
7 953 23 rmse standard 280.3642 10 9.769144 pre0_mod52_post0
7 1111 29 rmse standard 280.3470 10 10.131759 pre0_mod53_post0
7 1254 26 rmse standard 280.9250 10 9.864827 pre0_mod54_post0
7 1713 6 rmse standard 283.5527 10 10.192602 pre0_mod55_post0
7 1787 33 rmse standard 281.5877 10 10.133599 pre0_mod56_post0
8 185 20 rmse standard 279.5526 10 9.209935 pre0_mod57_post0
8 213 9 rmse standard 281.9499 10 10.313192 pre0_mod58_post0
8 1479 5 rmse standard 283.1197 10 10.354868 pre0_mod59_post0
8 1821 16 rmse standard 279.8859 10 9.900455 pre0_mod60_post0
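To read off the best candidates from this table without scanning it by hand, show_best() and select_best() can be used; a quick sketch:

# Sketch: best candidates by RMSE, and the single best combination:
wf1_fit %>% 
  show_best(metric = "rmse", n = 5)

best_params <- select_best(wf1_fit, metric = "rmse")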

The help file says:

In some cases, the tuning parameter values depend on the dimensions of the data. For example, mtry in random forest models depends on the number of predictors. In this case, the default tuning parameter object requires an upper range. dials::finalize() can be used to derive the data-dependent parameters. Otherwise, a parameter set can be created (via dials::parameters()) and the dials update() function can be used to change the values. This updated parameter set can be passed to the function via the param_info argument.
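As a sketch of the param_info route mentioned in that help text (reusing wf1, rsmpl, and the finalized wf1_params from above; an alternative to building the grid by hand):

# Sketch: pass the finalized parameter set via param_info and let
# tune_grid() generate the candidate grid itself (10 combinations):
set.seed(42)
wf1_fit_alt <-
  wf1 %>% 
  tune_grid(
    resamples = rsmpl,
    param_info = wf1_params,
    grid = 10)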

Note: step_impute_knn seems to run into problems when there are character variables (which is why step_impute_bag is used in the recipe above).
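A possible workaround (an untested sketch, not what was run above): convert the character predictors to factors before the knn imputation step:

# Sketch (assumption): convert strings to factors first, then impute:
rec_knn <- 
  recipe(body_mass_g ~ ., data = d_train) %>% 
  step_string2factor(all_nominal_predictors()) %>% 
  step_impute_knn(all_predictors())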

Conveniently, tidymodels figures out the upper limit for mtry by itself if you do not define a tuning grid. You can tell because tidymodels prints the following message during tuning/fitting:

i Creating pre-processing data to finalize unknown parameter: mtry.
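A sketch of that variant (no grid argument at all; tune_grid() then defaults to 10 candidate combinations and finalizes mtry from the data):

# Sketch: no grid supplied -- tune_grid() finalizes mtry itself:
set.seed(42)
wf1_fit_auto <-
  wf1 %>% 
  tune_grid(resamples = rsmpl)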

