tidymodels2

ds1
tidymodels
prediction
yacsda
statlearning
error
string
mtcars
Published

May 17, 2023

Aufgabe

Ein merkwürdiger Fehler bzw. eine merkwürdige Fehlermeldung in Tidymodels - das untersuchen wir hier genauer und versuchen das Phänomen zu erklären.

Aufgabe

Erläutern Sie die Ursachen des Fehlers! Schalten Sie den Fehler an und ab, um zu zeigen, dass Sie Ihn verstehen.

Startup

library(tidyverse)
library(tidymodels)

Data import

data("mtcars")

d_train <- mtcars %>% slice_head(n = 20)
d_test <- mtcars %>% slice(21:n())

Recipe

preds_chosen <- c("hp", "disp", "am")
rec1 <- 
  recipe( ~ ., data = d_train) %>% 
  update_role(all_predictors(), new_role = "id") %>% 
  update_role(all_of(preds_chosen), new_role = "predictor") %>% 
  update_role(mpg, new_role = "outcome")
rec1
d_train_baked <-
  rec1 %>% 
  prep() %>% 
  bake(new_data = NULL)

glimpse(d_train_baked)
Rows: 20
Columns: 11
$ mpg  <dbl> 21.0, 21.0, 22.8, 21.4, 18.7, 18.1, 14.3, 24.4, 22.8, 19.2, 17.8,…
$ cyl  <dbl> 6, 6, 4, 6, 8, 6, 8, 4, 4, 6, 6, 8, 8, 8, 8, 8, 8, 4, 4, 4
$ disp <dbl> 160.0, 160.0, 108.0, 258.0, 360.0, 225.0, 360.0, 146.7, 140.8, 16…
$ hp   <dbl> 110, 110, 93, 110, 175, 105, 245, 62, 95, 123, 123, 180, 180, 180…
$ drat <dbl> 3.90, 3.90, 3.85, 3.08, 3.15, 2.76, 3.21, 3.69, 3.92, 3.92, 3.92,…
$ wt   <dbl> 2.620, 2.875, 2.320, 3.215, 3.440, 3.460, 3.570, 3.190, 3.150, 3.…
$ qsec <dbl> 16.46, 17.02, 18.61, 19.44, 17.02, 20.22, 15.84, 20.00, 22.90, 18…
$ vs   <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1
$ am   <dbl> 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
$ gear <dbl> 4, 4, 4, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 4, 4, 4
$ carb <dbl> 4, 4, 1, 1, 2, 1, 4, 2, 2, 4, 4, 3, 3, 3, 4, 4, 4, 1, 2, 1

Model 1

model_lm <- linear_reg()

Workflow 1

wf1 <-
  workflow() %>% 
  add_model(model_lm) %>% 
  add_recipe(rec1)

Fit

lm_fit1 <-
  wf1 %>% 
  fit(d_train)
preds <-
  lm_fit1 %>% 
  predict(d_test)

head(preds)
# A tibble: 6 × 1
  .pred
  <dbl>
1  22.6
2  17.2
3  17.4
4  12.1
5  14.9
6  28.2

Aus Gründen der Reproduzierbarkeit bietet es sich an, eine SessionInfo anzugeben:

sessionInfo()
R version 4.2.1 (2022-06-23)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS Big Sur ... 10.16

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] yardstick_1.3.1    workflowsets_1.1.0 workflows_1.1.4    tune_1.2.1        
 [5] rsample_1.2.1      recipes_1.1.0      parsnip_1.2.1      modeldata_1.3.0   
 [9] infer_1.0.7        dials_1.3.0        scales_1.3.0       broom_1.0.6       
[13] tidymodels_1.2.0   lubridate_1.9.3    forcats_1.0.0      stringr_1.5.1     
[17] dplyr_1.1.4        purrr_1.0.2        readr_2.1.5        tidyr_1.3.1       
[21] tibble_3.2.1       ggplot2_3.5.1      tidyverse_2.0.0   

loaded via a namespace (and not attached):
 [1] foreach_1.5.2       jsonlite_1.8.8      splines_4.2.1      
 [4] prodlim_2023.03.31  GPfit_1.0-8         yaml_2.3.8         
 [7] globals_0.16.2      ipred_0.9-14        pillar_1.9.0       
[10] backports_1.4.1     lattice_0.21-8      glue_1.6.2         
[13] digest_0.6.33       hardhat_1.4.0       colorspace_2.1-0   
[16] htmltools_0.5.7     Matrix_1.5-4.1      timeDate_4022.108  
[19] pkgconfig_2.0.3     lhs_1.1.6           DiceDesign_1.9     
[22] listenv_0.9.0       gower_1.0.1         lava_1.7.2.1       
[25] tzdb_0.4.0          timechange_0.2.0    generics_0.1.3     
[28] withr_3.0.0         furrr_0.3.1         nnet_7.3-19        
[31] cli_3.6.2           survival_3.5-5      magrittr_2.0.3     
[34] evaluate_0.23       fansi_1.0.6         future_1.33.0      
[37] parallelly_1.36.0   MASS_7.3-60         class_7.3-22       
[40] tools_4.2.1         data.table_1.15.4   hms_1.1.3          
[43] lifecycle_1.0.4     munsell_0.5.0       compiler_4.2.1     
[46] rlang_1.1.4         grid_4.2.1          iterators_1.0.14   
[49] rstudioapi_0.16.0   htmlwidgets_1.6.4   rmarkdown_2.28     
[52] gtable_0.3.4        codetools_0.2-19    R6_2.5.1           
[55] knitr_1.48          fastmap_1.1.1       future.apply_1.11.0
[58] utf8_1.2.4          stringi_1.8.3       parallel_4.2.1     
[61] Rcpp_1.0.13         vctrs_0.6.5         rpart_4.1.21       
[64] tidyselect_1.2.0    xfun_0.47          











Lösung

Definiert man das Rezept so:

rec2 <- recipe(mpg ~ hp + disp + am, data = d_train)

Dann läuft predict() brav durch.

Auch dieser Code funktioniert:

rec3 <- 
  recipe(mpg ~ ., data = d_train) %>% 
  update_role(all_predictors(), new_role = "id") %>% 
  update_role(all_of(preds_chosen), new_role = "predictor") %>% 
  update_role(mpg, new_role = "outcome")

Das Problem von rec1 scheint darin zu legen, dass die Rollen der Variablen nicht richtig gelöscht werden, was predict() verwirrt:

rec1 <- 
  recipe(mpg ~ ., data = d_train) %>% 
  update_role(all_predictors(), new_role = "id") %>% 
  update_role(all_of(preds_chosen), new_role = "predictor") %>% 
  update_role(mpg, new_role = "outcome")
rec1

Daher läuft das Rezept rec3 durch, wenn man zunächst alle Prädiktoren in ID-Variablen umwandelt: Damit sind alle Rollen wieder sauber.


Categories:

  • ds1
  • tidymodels
  • prediction
  • yacsda
  • statlearning
  • error
  • string