wfsets1

R
statlearning
tidymodels
wfsets
template
Published

November 30, 2023

Aufgabe

Berechnen Sie die Vorhersagegüte (RMSE) für folgende Lernalgorithmen mittels tidymodels:

  • lineares Modell

Modellgleichung: body_mass_g ~ bill_length_mm, data = d_train.

Nutzen Sie minimale Vorverarbeitung im Rahmen zweier Rezepte.

Nutzen Sie ein Workflowset.











Lösung

Setup

library(tidymodels)
library(tictoc)  # Zeitmessung
data(penguins, package = "palmerpenguins")

Daten

# Fix the RNG state so that the random train/test split below (and the
# random steps that follow it in this script) is reproducible across runs.
set.seed(42)

# Keep only rows without any missing values.
d <-
  penguins %>% 
  drop_na()

# Random train/test split using the defaults of initial_split().
d_split <- initial_split(d)
d_train <- training(d_split)
d_test <- testing(d_split)

Modelle

Lineares Modell:

# Linear regression specification with parsnip defaults.
mod_lin <-
  linear_reg()

# Nearest-neighbour regression specification; `neighbors` is marked for tuning.
mod_knn <-
  nearest_neighbor(
    neighbors = tune(),
    mode = "regression"
  )

Rezepte

# Plain recipe: model formula only, no preprocessing steps.
rec_plain <-
  recipe(body_mass_g ~ bill_length_mm, data = d_train)

# Same formula, plus normalization of all predictors.
rec_basic <-
  rec_plain %>%
  step_normalize(all_predictors())

rec_basic

Resampling

# 5-fold cross-validation folds built from the training data.
rsmpls <-
  d_train %>%
  vfold_cv(v = 5)

Workflow Set

# Combine every recipe with every model (cross = TRUE is the default);
# the list names become the components of each workflow id.
wf_set <- workflow_set(
  preproc = list(
    rec_simple = rec_basic,
    rec_plain = rec_plain
  ),
  models = list(mod_lm = mod_lin),
  cross = TRUE
)

wf_set
# A workflow set/tibble: 2 × 4
  wflow_id          info             option    result    
  <chr>             <list>           <list>    <list>    
1 rec_simple_mod_lm <tibble [1 × 4]> <opts[0]> <list [0]>
2 rec_plain_mod_lm  <tibble [1 × 4]> <opts[0]> <list [0]>

Fitten

# Time the resampled fitting of all workflows in the set.
tic()
# "tune_grid" is the default function applied by workflow_map().
wf_fit <- workflow_map(
  wf_set,
  fn = "tune_grid",
  resamples = rsmpls
)
toc()
1.261 sec elapsed
# Show the workflow set with its resampling results attached.
print(wf_fit)
# A workflow set/tibble: 2 × 4
  wflow_id          info             option    result   
  <chr>             <list>           <list>    <list>   
1 rec_simple_mod_lm <tibble [1 × 4]> <opts[1]> <rsmp[+]>
2 rec_plain_mod_lm  <tibble [1 × 4]> <opts[1]> <rsmp[+]>

Check:

# Inspect the list column holding the per-workflow resampling results.
wf_fit[["result"]]
[[1]]
# Resampling results
# 5-fold cross-validation 
# A tibble: 5 × 4
  splits           id    .metrics         .notes          
  <list>           <chr> <list>           <list>          
1 <split [199/50]> Fold1 <tibble [2 × 4]> <tibble [0 × 3]>
2 <split [199/50]> Fold2 <tibble [2 × 4]> <tibble [0 × 3]>
3 <split [199/50]> Fold3 <tibble [2 × 4]> <tibble [0 × 3]>
4 <split [199/50]> Fold4 <tibble [2 × 4]> <tibble [0 × 3]>
5 <split [200/49]> Fold5 <tibble [2 × 4]> <tibble [0 × 3]>

[[2]]
# Resampling results
# 5-fold cross-validation 
# A tibble: 5 × 4
  splits           id    .metrics         .notes          
  <list>           <chr> <list>           <list>          
1 <split [199/50]> Fold1 <tibble [2 × 4]> <tibble [0 × 3]>
2 <split [199/50]> Fold2 <tibble [2 × 4]> <tibble [0 × 3]>
3 <split [199/50]> Fold3 <tibble [2 × 4]> <tibble [0 × 3]>
4 <split [199/50]> Fold4 <tibble [2 × 4]> <tibble [0 × 3]>
5 <split [200/49]> Fold5 <tibble [2 × 4]> <tibble [0 × 3]>

Bester Kandidat

# Plot the performance of all candidates.
wf_fit %>%
  autoplot()

# Plot only the best candidate per workflow.
wf_fit %>%
  autoplot(select_best = TRUE)

# Resampled metrics for each workflow.
wf_fit %>%
  collect_metrics()
# A tibble: 4 × 9
  wflow_id        .config preproc model .metric .estimator    mean     n std_err
  <chr>           <chr>   <chr>   <chr> <chr>   <chr>        <dbl> <int>   <dbl>
1 rec_simple_mod… Prepro… recipe  line… rmse    standard   655.        5 23.0   
2 rec_simple_mod… Prepro… recipe  line… rsq     standard     0.357     5  0.0336
3 rec_plain_mod_… Prepro… recipe  line… rmse    standard   655.        5 23.0   
4 rec_plain_mod_… Prepro… recipe  line… rsq     standard     0.357     5  0.0336
# Rank the workflows by RMSE and keep only the RMSE rows.
wf_fit %>%
  rank_results(rank_metric = "rmse") %>%
  filter(.metric == "rmse")
# A tibble: 2 × 9
  wflow_id          .config .metric  mean std_err     n preprocessor model  rank
  <chr>             <chr>   <chr>   <dbl>   <dbl> <int> <chr>        <chr> <int>
1 rec_simple_mod_lm Prepro… rmse     655.    23.0     5 recipe       line…     1
2 rec_plain_mod_lm  Prepro… rmse     655.    23.0     5 recipe       line…     2

Last Fit

# Pull the chosen workflow out of the fitted set by its workflow id.
best_wf <- extract_workflow(wf_fit, id = "rec_simple_mod_lm")

Finalisieren müssen wir diesen Workflow nicht, da er keine Tuningparameter hatte.

# Fit on the training portion of the split and evaluate on the test portion.
fit_final <- last_fit(best_wf, split = d_split)

Modellgüte im Test-Set

# Test-set performance of the final fit.
fit_final %>%
  collect_metrics()
# A tibble: 2 × 4
  .metric .estimator .estimate .config             
  <chr>   <chr>          <dbl> <chr>               
1 rmse    standard     653.    Preprocessor1_Model1
2 rsq     standard       0.341 Preprocessor1_Model1