tidymodels2

ds1
tidymodels
prediction
yacsda
statlearning
error
string
mtcars
Published

May 17, 2023

Aufgabe

Ein merkwürdiger Fehler bzw. eine merkwürdige Fehlermeldung in Tidymodels - das untersuchen wir hier genauer und versuchen das Phänomen zu erklären.

Aufgabe

Erläutern Sie die Ursachen des Fehlers! Schalten Sie den Fehler an und ab, um zu zeigen, dass Sie Ihn verstehen.

Startup

library(tidyverse)
library(tidymodels)

Data import

data("mtcars")

d_train <- mtcars %>% slice_head(n = 20)
d_test <- mtcars %>% slice(21:n())

Recipe

preds_chosen <- c("hp", "disp", "am")
rec1 <- 
  recipe( ~ ., data = d_train) %>% 
  update_role(all_predictors(), new_role = "id") %>% 
  update_role(all_of(preds_chosen), new_role = "predictor") %>% 
  update_role(mpg, new_role = "outcome")
rec1
d_train_baked <-
  rec1 %>% 
  prep() %>% 
  bake(new_data = NULL)

glimpse(d_train_baked)
Rows: 20
Columns: 11
$ mpg  <dbl> 21.0, 21.0, 22.8, 21.4, 18.7, 18.1, 14.3, 24.4, 22.8, 19.2, 17.8,…
$ cyl  <dbl> 6, 6, 4, 6, 8, 6, 8, 4, 4, 6, 6, 8, 8, 8, 8, 8, 8, 4, 4, 4
$ disp <dbl> 160.0, 160.0, 108.0, 258.0, 360.0, 225.0, 360.0, 146.7, 140.8, 16…
$ hp   <dbl> 110, 110, 93, 110, 175, 105, 245, 62, 95, 123, 123, 180, 180, 180…
$ drat <dbl> 3.90, 3.90, 3.85, 3.08, 3.15, 2.76, 3.21, 3.69, 3.92, 3.92, 3.92,…
$ wt   <dbl> 2.620, 2.875, 2.320, 3.215, 3.440, 3.460, 3.570, 3.190, 3.150, 3.…
$ qsec <dbl> 16.46, 17.02, 18.61, 19.44, 17.02, 20.22, 15.84, 20.00, 22.90, 18…
$ vs   <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1
$ am   <dbl> 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
$ gear <dbl> 4, 4, 4, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 4, 4, 4
$ carb <dbl> 4, 4, 1, 1, 2, 1, 4, 2, 2, 4, 4, 3, 3, 3, 4, 4, 4, 1, 2, 1

Model 1

model_lm <- linear_reg()

Workflow 1

wf1 <-
  workflow() %>% 
  add_model(model_lm) %>% 
  add_recipe(rec1)

Fit

lm_fit1 <-
  wf1 %>% 
  fit(d_train)
preds <-
  lm_fit1 %>% 
  predict(d_test)

head(preds)
.pred
22.63594
17.24780
17.44343
12.09935
14.86481
28.15949

Aus Gründen der Reproduzierbarkeit bietet es sich an, eine SessionInfo anzugeben:

sessionInfo()
R version 4.4.1 (2024-06-14)
Platform: x86_64-apple-darwin20
Running under: macOS 15.6.1

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.4-x86_64/Resources/lib/libRblas.0.dylib 
LAPACK: /Library/Frameworks/R.framework/Versions/4.4-x86_64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.0

locale:
[1] de_DE.UTF-8/de_DE.UTF-8/de_DE.UTF-8/C/de_DE.UTF-8/en_US.UTF-8

time zone: Europe/Berlin
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] yardstick_1.3.2    workflowsets_1.1.1 workflows_1.3.0    tune_2.0.0        
 [5] tailor_0.1.0       rsample_1.3.1      recipes_1.3.1      parsnip_1.3.3     
 [9] modeldata_1.5.1    infer_1.0.9        dials_1.4.2        scales_1.4.0      
[13] broom_1.0.10       tidymodels_1.4.1   lubridate_1.9.4    forcats_1.0.0     
[17] stringr_1.5.2      dplyr_1.1.4        purrr_1.1.0        readr_2.1.5       
[21] tidyr_1.3.1        tibble_3.3.0       ggplot2_4.0.0      tidyverse_2.0.0   
[25] colorout_1.3-2    

loaded via a namespace (and not attached):
 [1] tidyselect_1.2.1    timeDate_4041.110   farver_2.1.2       
 [4] S7_0.2.0            fastmap_1.2.0       digest_0.6.37      
 [7] rpart_4.1.23        timechange_0.3.0    lifecycle_1.0.4    
[10] survival_3.6-4      magrittr_2.0.4      compiler_4.4.1     
[13] rlang_1.1.6         tools_4.4.1         yaml_2.3.10        
[16] data.table_1.17.8   knitr_1.50          htmlwidgets_1.6.4  
[19] DiceDesign_1.10     RColorBrewer_1.1-3  withr_3.0.2        
[22] nnet_7.3-19         grid_4.4.1          sparsevctrs_0.3.3  
[25] future_1.58.0       globals_0.18.0      MASS_7.3-65        
[28] cli_3.6.5           rmarkdown_2.28      generics_0.1.4     
[31] rstudioapi_0.17.1   future.apply_1.20.0 tzdb_0.4.0         
[34] splines_4.4.1       parallel_4.4.1      vctrs_0.6.5        
[37] hardhat_1.4.2       Matrix_1.7-0        jsonlite_1.8.8     
[40] hms_1.1.3           listenv_0.9.1       gower_1.0.2        
[43] glue_1.8.0          parallelly_1.44.0   codetools_0.2-20   
[46] stringi_1.8.7       gtable_0.3.6        GPfit_1.0-8        
[49] pillar_1.11.1       furrr_0.3.1         htmltools_0.5.8.1  
[52] ipred_0.9-15        lava_1.8.0          R6_2.5.1           
[55] lhs_1.2.0           evaluate_1.0.3      lattice_0.22-6     
[58] backports_1.5.0     class_7.3-22        Rcpp_1.1.0         
[61] prodlim_2024.06.25  xfun_0.52           pkgconfig_2.0.3    











Lösung

Definiert man das Rezept so:

rec2 <- recipe(mpg ~ hp + disp + am, data = d_train)

Dann läuft predict() brav durch.

Auch dieser Code funktioniert:

rec3 <- 
  recipe(mpg ~ ., data = d_train) %>% 
  update_role(all_predictors(), new_role = "id") %>% 
  update_role(all_of(preds_chosen), new_role = "predictor") %>% 
  update_role(mpg, new_role = "outcome")

Das Problem von rec1 scheint darin zu legen, dass die Rollen der Variablen nicht richtig gelöscht werden, was predict() verwirrt:

rec1 <- 
  recipe(mpg ~ ., data = d_train) %>% 
  update_role(all_predictors(), new_role = "id") %>% 
  update_role(all_of(preds_chosen), new_role = "predictor") %>% 
  update_role(mpg, new_role = "outcome")
rec1

Daher läuft das Rezept rec3 durch, wenn man zunächst alle Prädiktoren in ID-Variablen umwandelt: Damit sind alle Rollen wieder sauber.


Categories:

  • ds1
  • tidymodels
  • prediction
  • yacsda
  • statlearning
  • error
  • string