tidymodels-tree2

statlearning
trees
tidymodels
speed
string
Published

November 8, 2023

Aufgabe

Berechnen Sie folgendes einfache Modell:

  1. Entscheidungsbaum

Modellformel: am ~ . (Datensatz mtcars)

Hier geht es darum, die Geschwindigkeit (und den Ressourcenverbrauch) beim Fitten zu verringern. Benutzen Sie dazu folgende Methoden

  • Verwenden mehrerer Prozesskerne

Hinweise:

  • Tunen Sie alle Parameter (die der Engine anbietet).
  • Verwenden Sie Defaults, wo nicht anders angegeben.
  • Führen Sie eine \(v=2\)-fache Kreuzvalidierung durch (weil die Stichprobe so klein ist).
  • Beachten Sie die üblichen Hinweise.











Lösung

Setup

library(tidymodels)
data(mtcars)
library(tictoc)  # Zeitmessung
library(doParallel)  # Nutzen mehrerer Kerne

Für Klassifikation verlangt Tidymodels eine nominale AV, keine numerische:

mtcars <-
  mtcars %>% 
  mutate(am = factor(am))

Daten teilen

set.seed(42)
d_split <- initial_split(mtcars)
d_train <- training(d_split)
d_test <- testing(d_split)

Modell(e)

mod_tree <-
  decision_tree(mode = "classification",
                cost_complexity = tune(),
                tree_depth = tune(),
                min_n = tune())

Rezept(e)

rec_plain <- 
  recipe(am ~ ., data = d_train)

Resampling

set.seed(42)
rsmpl <- vfold_cv(d_train, v = 2)

Workflows

wf_tree <-
  workflow() %>%  
  add_recipe(rec_plain) %>% 
  add_model(mod_tree)

Tuning/Fitting

Tuninggrid:

tune_grid <- grid_regular(extract_parameter_set_dials(mod_tree), levels = 5)
tune_grid

Ohne Parallelisierung

tic()
fit_tree <-
  tune_grid(object = wf_tree,
            grid = tune_grid,
            metrics = metric_set(roc_auc),
            resamples = rsmpl)
toc()

ca. 45 sec. auf meinem Rechner (4-Kerne-MacBook Pro 2020).

Mit Parallelisierung

Wie viele CPUs hat mein Computer?

parallel::detectCores(logical = FALSE)

Parallele Verarbeitung starten:

cl <- makePSOCKcluster(4)  # Create 4 clusters
registerDoParallel(cl)
tic()
fit_tree2 <-
  tune_grid(object = wf_tree,
            grid = tune_grid,
            metrics = metric_set(roc_auc),
            resamples = rsmpl)
toc()

ca. 17 Sekunden - deutlich schneller!


Categories:

  • statlearning
  • trees
  • tidymodels
  • speed
  • string