# Setup:
library(tidymodels)
library(tidyverse)
library(tictoc) # Zeitmessung
set.seed(42)
# Data: read the penguins data set from the Rdatasets CSV mirror.
d_path <- "https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv"
d <- read_csv(d_path)

# Remove rows with a missing value in the dependent variable:
d <- d %>%
  drop_na(body_mass_g)

# Train/test split using the defaults of initial_split(); the seed is fixed
# so that the split is reproducible.
set.seed(42)
d_split <- initial_split(d)
d_train <- training(d_split)
d_test <- testing(d_split)
# model: random forest in regression mode; mtry, min_n and trees are all
# marked with tune(), so their values are determined during tuning.
mod_rf <-
  rand_forest(
    mode = "regression",
    mtry = tune(),
    min_n = tune(),
    trees = tune()
  )
# cv: cross-validation resamples of the training data, built with the
# defaults of vfold_cv(); the seed is fixed for reproducible folds.
set.seed(42)
rsmpl <- vfold_cv(d_train)
# recipe: model body_mass_g on all remaining columns of d_train and
# apply step_impute_bag() to all predictors.
rec_plain <-
  recipe(body_mass_g ~ ., data = d_train) %>%
  step_impute_bag(all_predictors())
# workflow: bundle the model specification and the recipe into one object.
wf1 <-
  workflow() %>%
  add_model(mod_rf) %>%
  add_recipe(rec_plain)
rf-finalize3
Aufgabe
Berechnen Sie ein prädiktives Modell (Random Forest) mit dieser Modellgleichung:
body_mass_g ~ .
(Datensatz: palmerpenguins::penguins).
Zeigen Sie, welche Werte für mtry im Default von Tidymodels gesetzt werden!
Hinweise:
- Tunen Sie alle Tuningparameter mit jeweils 3 Werten.
- Verwenden Sie Kreuzvalidierung.
- Verwenden Sie Standardwerte, wo nicht anders angegeben.
- Fixieren Sie Zufallszahlen auf den Startwert 42.
Lösung
Standard-Start
Zuerst der Standardablauf:
Tuninggrid
Welche Tuningparameter hat unser Workflow?
# Extract the tuning parameters of the workflow as they are, i.e. before
# any data-dependent range has been filled in, and print them.
wf1_params_unclear <-
  extract_parameter_set_dials(wf1)
wf1_params_unclear
name | id | source | component | component_id | object |
---|---|---|---|---|---|
mtry | mtry | model_spec | rand_forest | main | integer, 1, unknown(), TRUE, TRUE, # Randomly Selected Predictors, function (object, x, log_vals = FALSE, ...) , {, check_param(object), rngs <- range_get(object, original = FALSE), if (!is_unknown(rngs$upper)) {, return(object), }, x_dims <- dim(x), if (is.null(x_dims)) {, cli::cli_abort("Cannot determine number of columns. Is {.arg x} a 2D data object?"), }, if (log_vals) {, rngs[2] <- log10(x_dims[2]), }, else {, rngs[2] <- x_dims[2], }, if (object$type == "integer" & is.null(object$trans)) {, rngs <- as.integer(rngs), }, range_set(object, rngs), } |
trees | trees | model_spec | rand_forest | main | integer, 1 , 2000 , TRUE , TRUE , # Trees |
min_n | min_n | model_spec | rand_forest | main | integer, 2, 40, TRUE, TRUE, Minimal Node Size, function (object, x, log_vals = FALSE, frac = 1/3, ...) , {, check_param(object), rngs <- range_get(object, original = FALSE), if (!is_unknown(rngs$upper)) {, return(object), }, x_dims <- dim(x), if (is.null(x_dims)) {, cli::cli_abort("Cannot determine number of columns. Is {.arg x} a 2D data object?"), }, n_frac <- floor(x_dims[1] * frac), if (log_vals) {, rngs[2] <- log10(n_frac), }, else {, rngs[2] <- n_frac, }, if (object$type == "integer" & is.null(object$trans) & !log_vals) {, rngs <- as.integer(rngs), }, range_set(object, rngs), } |
Verlangt waren 3 Tuningparameterwerte pro Parameter:
# First attempt at a grid with 3 values per parameter. This call fails:
# grid_latin_hypercube() has no `levels` argument (it takes `size`).
my_grid <- grid_latin_hypercube(wf1_params_unclear, levels = 3)
Error in `grid_latin_hypercube()`:
! `levels` is not an argument to `grid_latin_hypercube()`.
ℹ Did you mean `size`?
# Printing fails as well: the assignment above errored, so my_grid was
# never created.
my_grid
Error: object 'my_grid' not found
Tidymodels weiß nicht, welche Werte für mtry
benutzt werden sollen, da dieser Wert abhängig ist von der Anzahl der Spalten des Datensatzes, und damit unabhängig vom Modell.
Die Ausgabe nparam[?]
oben sagt uns, dass Tidymodels den Wertebereich des Tuningparameters nicht klären konnte, da er datenabhängig ist.
Informieren wir also Tidymodels zu diesem Wertebereich:
# Fill in the data-dependent upper bound of mtry by finalizing the
# parameter against the training data, then print the updated set.
# NOTE(review): d_train still contains the outcome body_mass_g, so the
# column count used here is one higher than ncol(d_train) - 1 -- confirm
# whether the outcome column should be dropped before finalizing.
wf1_params <-
  wf1 %>%
  extract_parameter_set_dials() %>%
  update(mtry = finalize(mtry(), d_train))
wf1_params
name | id | source | component | component_id | object |
---|---|---|---|---|---|
mtry | mtry | model_spec | rand_forest | main | integer , 1 , 9 , TRUE , TRUE , # Randomly Selected Predictors |
trees | trees | model_spec | rand_forest | main | integer, 1 , 2000 , TRUE , TRUE , # Trees |
min_n | min_n | model_spec | rand_forest | main | integer, 2, 40, TRUE, TRUE, Minimal Node Size, function (object, x, log_vals = FALSE, frac = 1/3, ...) , {, check_param(object), rngs <- range_get(object, original = FALSE), if (!is_unknown(rngs$upper)) {, return(object), }, x_dims <- dim(x), if (is.null(x_dims)) {, cli::cli_abort("Cannot determine number of columns. Is {.arg x} a 2D data object?"), }, n_frac <- floor(x_dims[1] * frac), if (log_vals) {, rngs[2] <- log10(n_frac), }, else {, rngs[2] <- n_frac, }, if (object$type == "integer" & is.null(object$trans) & !log_vals) {, rngs <- as.integer(rngs), }, range_set(object, rngs), } |
So, jetzt weiß Tidymodels, wie viele Werte für mtry
benutzt werden können.
Wir können jetzt das Tuninggitter erstellen (das macht das Paket dials
):
# Build a Latin hypercube tuning grid with 125 candidates from the
# finalized parameter set and show its first rows.
my_grid <- grid_latin_hypercube(wf1_params, size = 125)
my_grid %>% head()
mtry | trees | min_n |
---|---|---|
1 | 105 | 11 |
5 | 1036 | 21 |
3 | 325 | 16 |
4 | 1375 | 28 |
6 | 1405 | 21 |
7 | 304 | 15 |
Wie viele verschiedene Werte gibt es in dem Tuninggitter?
Schauen wir es uns mal an.
# Scatter plot of the grid candidates: trees on the x axis, mtry on the y axis.
my_grid %>%
  ggplot(aes(x = trees, y = mtry)) +
  geom_point()
Wir können das Tuninggitter auch selber erstellen:
# Build the grid by hand: the mtry range is set explicitly to
# 1 .. ncol(d_train) - 1; trees() and min_n() keep their default ranges.
my_grid <-
  grid_latin_hypercube(
    mtry(range = c(1, ncol(d_train) - 1)),
    trees(),
    min_n(),
    size = 60
  )
# Check the grid dimensions (rows = candidates, columns = parameters).
dim(my_grid)
[1] 60 3
Tuning/Fitting
# tuning: evaluate every candidate in my_grid on the resamples in rsmpl;
# tic()/toc() measure the elapsed time of the tuning run.
tic()
wf1_fit <-
  wf1 %>%
  tune_grid(
    grid = my_grid,
    resamples = rsmpl
  )
toc()
135.811 sec elapsed
Dann schauen wir uns das Ergebnisobjekt vom Tuning an.
# Collect the resampling metrics, keep only the RMSE rows and sort them by mtry.
wf1_fit %>%
  collect_metrics() %>%
  filter(.metric == "rmse") %>%
  arrange(mtry)
mtry | trees | min_n | .metric | .estimator | mean | n | std_err | .config |
---|---|---|---|---|---|---|---|---|
1 | 510 | 29 | rmse | standard | 327.0506 | 10 | 14.102972 | pre0_mod01_post0 |
1 | 826 | 3 | rmse | standard | 310.0977 | 10 | 13.017354 | pre0_mod02_post0 |
1 | 1742 | 14 | rmse | standard | 315.6416 | 10 | 13.105640 | pre0_mod03_post0 |
1 | 1835 | 33 | rmse | standard | 332.4336 | 10 | 13.724682 | pre0_mod04_post0 |
2 | 51 | 37 | rmse | standard | 298.2026 | 10 | 11.349008 | pre0_mod05_post0 |
2 | 81 | 29 | rmse | standard | 292.0088 | 10 | 12.269233 | pre0_mod06_post0 |
2 | 147 | 22 | rmse | standard | 287.0004 | 10 | 10.698344 | pre0_mod07_post0 |
2 | 359 | 6 | rmse | standard | 282.6882 | 10 | 11.064381 | pre0_mod08_post0 |
2 | 386 | 23 | rmse | standard | 288.0263 | 10 | 11.926175 | pre0_mod09_post0 |
2 | 672 | 15 | rmse | standard | 283.0290 | 10 | 11.154272 | pre0_mod10_post0 |
2 | 782 | 9 | rmse | standard | 282.1710 | 10 | 10.670844 | pre0_mod11_post0 |
2 | 1590 | 35 | rmse | standard | 295.8688 | 10 | 12.409382 | pre0_mod12_post0 |
2 | 1927 | 7 | rmse | standard | 282.5915 | 10 | 11.442756 | pre0_mod13_post0 |
3 | 641 | 8 | rmse | standard | 281.3474 | 10 | 10.738309 | pre0_mod14_post0 |
3 | 747 | 24 | rmse | standard | 282.9493 | 10 | 11.059727 | pre0_mod15_post0 |
3 | 1093 | 36 | rmse | standard | 289.0472 | 10 | 11.810703 | pre0_mod16_post0 |
3 | 1219 | 3 | rmse | standard | 282.7862 | 10 | 11.065238 | pre0_mod17_post0 |
3 | 1408 | 13 | rmse | standard | 280.9360 | 10 | 10.869681 | pre0_mod18_post0 |
3 | 1697 | 36 | rmse | standard | 288.2570 | 10 | 11.634923 | pre0_mod19_post0 |
3 | 1889 | 4 | rmse | standard | 282.0771 | 10 | 10.655292 | pre0_mod20_post0 |
3 | 1967 | 39 | rmse | standard | 290.0862 | 10 | 11.890039 | pre0_mod21_post0 |
4 | 261 | 31 | rmse | standard | 285.9761 | 10 | 10.342291 | pre0_mod22_post0 |
4 | 292 | 24 | rmse | standard | 283.3860 | 10 | 10.744992 | pre0_mod23_post0 |
4 | 330 | 21 | rmse | standard | 282.0808 | 10 | 10.291713 | pre0_mod24_post0 |
4 | 458 | 18 | rmse | standard | 281.6246 | 10 | 9.928690 | pre0_mod25_post0 |
4 | 709 | 14 | rmse | standard | 280.7476 | 10 | 9.676739 | pre0_mod26_post0 |
4 | 1062 | 35 | rmse | standard | 286.0341 | 10 | 11.135346 | pre0_mod27_post0 |
4 | 1360 | 8 | rmse | standard | 282.1273 | 10 | 10.147874 | pre0_mod28_post0 |
4 | 1543 | 25 | rmse | standard | 281.9846 | 10 | 10.757734 | pre0_mod29_post0 |
4 | 1617 | 12 | rmse | standard | 280.4863 | 10 | 10.414125 | pre0_mod30_post0 |
5 | 588 | 34 | rmse | standard | 284.0172 | 10 | 10.505647 | pre0_mod31_post0 |
5 | 900 | 38 | rmse | standard | 285.5616 | 10 | 10.776293 | pre0_mod32_post0 |
5 | 909 | 12 | rmse | standard | 281.1531 | 10 | 10.183557 | pre0_mod33_post0 |
5 | 1029 | 30 | rmse | standard | 283.3368 | 10 | 10.506628 | pre0_mod34_post0 |
5 | 1274 | 19 | rmse | standard | 280.4960 | 10 | 10.069676 | pre0_mod35_post0 |
5 | 1322 | 16 | rmse | standard | 280.5548 | 10 | 10.308200 | pre0_mod36_post0 |
5 | 1444 | 39 | rmse | standard | 285.4752 | 10 | 10.843783 | pre0_mod37_post0 |
5 | 1655 | 26 | rmse | standard | 281.1615 | 10 | 10.492565 | pre0_mod38_post0 |
5 | 1994 | 11 | rmse | standard | 281.3936 | 10 | 10.143697 | pre0_mod39_post0 |
6 | 481 | 2 | rmse | standard | 284.1143 | 10 | 10.174460 | pre0_mod40_post0 |
6 | 629 | 27 | rmse | standard | 281.5415 | 10 | 10.115595 | pre0_mod41_post0 |
6 | 858 | 32 | rmse | standard | 282.5570 | 10 | 10.631752 | pre0_mod42_post0 |
6 | 996 | 17 | rmse | standard | 280.9367 | 10 | 9.847399 | pre0_mod43_post0 |
6 | 1155 | 31 | rmse | standard | 282.1929 | 10 | 10.186091 | pre0_mod44_post0 |
6 | 1178 | 39 | rmse | standard | 284.8779 | 10 | 10.436088 | pre0_mod45_post0 |
6 | 1375 | 10 | rmse | standard | 282.1038 | 10 | 10.032399 | pre0_mod46_post0 |
6 | 1532 | 20 | rmse | standard | 280.5911 | 10 | 10.113653 | pre0_mod47_post0 |
7 | 7 | 20 | rmse | standard | 295.2057 | 10 | 7.571570 | pre0_mod48_post0 |
7 | 104 | 17 | rmse | standard | 280.6154 | 10 | 10.137987 | pre0_mod49_post0 |
7 | 422 | 27 | rmse | standard | 280.7607 | 10 | 10.022320 | pre0_mod50_post0 |
7 | 546 | 11 | rmse | standard | 280.1788 | 10 | 10.373537 | pre0_mod51_post0 |
7 | 953 | 23 | rmse | standard | 280.3642 | 10 | 9.769144 | pre0_mod52_post0 |
7 | 1111 | 29 | rmse | standard | 280.3470 | 10 | 10.131759 | pre0_mod53_post0 |
7 | 1254 | 26 | rmse | standard | 280.9250 | 10 | 9.864827 | pre0_mod54_post0 |
7 | 1713 | 6 | rmse | standard | 283.5527 | 10 | 10.192602 | pre0_mod55_post0 |
7 | 1787 | 33 | rmse | standard | 281.5877 | 10 | 10.133599 | pre0_mod56_post0 |
8 | 185 | 20 | rmse | standard | 279.5526 | 10 | 9.209935 | pre0_mod57_post0 |
8 | 213 | 9 | rmse | standard | 281.9499 | 10 | 10.313192 | pre0_mod58_post0 |
8 | 1479 | 5 | rmse | standard | 283.1197 | 10 | 10.354868 | pre0_mod59_post0 |
8 | 1821 | 16 | rmse | standard | 279.8859 | 10 | 9.900455 | pre0_mod60_post0 |
In der Hilfe ist zu lesen:
In some cases, the tuning parameter values depend on the dimensions of the data. For example, mtry in random forest models depends on the number of predictors. In this case, the default tuning parameter object requires an upper range. dials::finalize() can be used to derive the data-dependent parameters. Otherwise, a parameter set can be created (via dials::parameters()) and the dials update() function can be used to change the values. This updated parameter set can be passed to the function via the param_info argument.
Achtung: step_impute_knn
scheint Probleme zu haben, wenn es Charakter-Variablen gibt.
Praktischerweise findet Tidymodels die Begrenzung von mtry
selber heraus, wenn Sie kein Tuninggrid definieren. Das erkennen Sie daran, dass Tidymodels beim Tuning/Fitten die folgende Ausgabe zeigt:
i Creating pre-processing data to finalize unknown parameter: mtry
.
Categories:
- tidymodels
- statlearning
- template
- string