<- "https://raw.githubusercontent.com/sebastiansauer/Lehre/main/data/tmdb-box-office-prediction/train.csv"
d_train_path <- "https://raw.githubusercontent.com/sebastiansauer/Lehre/main/data/tmdb-box-office-prediction/test.csv" d_test_path
tmdb02
Aufgabe
Wir bearbeiten hier die Fallstudie TMDB Box Office Prediction - Can you predict a movie’s worldwide box office revenue?, ein Kaggle-Prognosewettbewerb.
Ziel ist es, genaue Vorhersagen zu machen, in diesem Fall für Filme.
Die Daten können Sie von der Kaggle-Projektseite beziehen oder so:
Aufgabe
Reichen Sie bei Kaggle eine Submission für die Fallstudie ein! Berichten Sie den Kaggle-Score.
Hinweise:
- Sie müssen sich bei Kaggle ein Konto anlegen (kostenlos und anonym möglich); alternativ können Sie sich mit einem Google-Konto anmelden.
- Berechnen Sie einen Entscheidungsbaum und einen Random-Forest.
- Tunen Sie nach Bedarf; verwenden Sie aber Default-Werte.
- Verwenden Sie Tidymodels.
Lösung
Vorbereitung
library(tidyverse)
library(tidymodels)
library(tictoc)     # timing of code sections
library(doParallel) # use multiple CPUs
library(finetune)   # tune_race_anova()

# Read train and test sets from the URLs defined above
d_train <- read_csv(d_train_path)
d_test <- read_csv(d_test_path)
glimpse(d_train)
Rows: 3,000
Columns: 23
$ id <dbl> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 1…
$ belongs_to_collection <chr> "[{'id': 313576, 'name': 'Hot Tub Time Machine C…
$ budget <dbl> 1.40e+07, 4.00e+07, 3.30e+06, 1.20e+06, 0.00e+00…
$ genres <chr> "[{'id': 35, 'name': 'Comedy'}]", "[{'id': 35, '…
$ homepage <chr> NA, NA, "http://sonyclassics.com/whiplash/", "ht…
$ imdb_id <chr> "tt2637294", "tt0368933", "tt2582802", "tt182148…
$ original_language <chr> "en", "en", "en", "hi", "ko", "en", "en", "en", …
$ original_title <chr> "Hot Tub Time Machine 2", "The Princess Diaries …
$ overview <chr> "When Lou, who has become the \"father of the In…
$ popularity <dbl> 6.575393, 8.248895, 64.299990, 3.174936, 1.14807…
$ poster_path <chr> "/tQtWuwvMf0hCc2QR2tkolwl7c3c.jpg", "/w9Z7A0GHEh…
$ production_companies <chr> "[{'name': 'Paramount Pictures', 'id': 4}, {'nam…
$ production_countries <chr> "[{'iso_3166_1': 'US', 'name': 'United States of…
$ release_date <chr> "2/20/15", "8/6/04", "10/10/14", "3/9/12", "2/5/…
$ runtime <dbl> 93, 113, 105, 122, 118, 83, 92, 84, 100, 91, 119…
$ spoken_languages <chr> "[{'iso_639_1': 'en', 'name': 'English'}]", "[{'…
$ status <chr> "Released", "Released", "Released", "Released", …
$ tagline <chr> "The Laws of Space and Time are About to be Viol…
$ title <chr> "Hot Tub Time Machine 2", "The Princess Diaries …
$ Keywords <chr> "[{'id': 4379, 'name': 'time travel'}, {'id': 96…
$ cast <chr> "[{'cast_id': 4, 'character': 'Lou', 'credit_id'…
$ crew <chr> "[{'credit_id': '59ac067c92514107af02c8c8', 'dep…
$ revenue <dbl> 12314651, 95149435, 13092000, 16000000, 3923970,…
glimpse(d_test)
Rows: 4,398
Columns: 22
$ id <dbl> 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, …
$ belongs_to_collection <chr> "[{'id': 34055, 'name': 'Pokémon Collection', 'p…
$ budget <dbl> 0.00e+00, 8.80e+04, 0.00e+00, 6.80e+06, 2.00e+06…
$ genres <chr> "[{'id': 12, 'name': 'Adventure'}, {'id': 16, 'n…
$ homepage <chr> "http://www.pokemon.com/us/movies/movie-pokemon-…
$ imdb_id <chr> "tt1226251", "tt0051380", "tt0118556", "tt125595…
$ original_language <chr> "ja", "en", "en", "fr", "en", "en", "de", "en", …
$ original_title <chr> "ディアルガVSパルキアVSダークライ", "Attack of t…
$ overview <chr> "Ash and friends (this time accompanied by newco…
$ popularity <dbl> 3.851534, 3.559789, 8.085194, 8.596012, 3.217680…
$ poster_path <chr> "/tnftmLMemPLduW6MRyZE0ZUD19z.jpg", "/9MgBNBqlH1…
$ production_companies <chr> NA, "[{'name': 'Woolner Brothers Pictures Inc.',…
$ production_countries <chr> "[{'iso_3166_1': 'JP', 'name': 'Japan'}, {'iso_3…
$ release_date <chr> "7/14/07", "5/19/58", "5/23/97", "9/4/10", "2/11…
$ runtime <dbl> 90, 65, 100, 130, 92, 121, 119, 77, 120, 92, 88,…
$ spoken_languages <chr> "[{'iso_639_1': 'en', 'name': 'English'}, {'iso_…
$ status <chr> "Released", "Released", "Released", "Released", …
$ tagline <chr> "Somewhere Between Time & Space... A Legend Is B…
$ title <chr> "Pokémon: The Rise of Darkrai", "Attack of the 5…
$ Keywords <chr> "[{'id': 11451, 'name': 'pok√©mon'}, {'id': 1155…
$ cast <chr> "[{'cast_id': 3, 'character': 'Tonio', 'credit_i…
$ crew <chr> "[{'credit_id': '52fe44e7c3a368484e03d683', 'dep…
Rezept
Rezept definieren
# Recipe: only budget, popularity, runtime are predictors; revenue is the
# outcome; all remaining columns are demoted to "id" (ignored by the model).
rec1 <-
  recipe(revenue ~ ., data = d_train) %>%
  update_role(all_predictors(), new_role = "id") %>%
  update_role(popularity, runtime, revenue, budget) %>%  # back to default role "predictor"
  update_role(revenue, new_role = "outcome") %>%
  step_mutate(budget = ifelse(budget < 10, 10, budget)) %>%  # floor so log() stays finite
  step_log(budget) %>%
  step_impute_knn(all_predictors())

rec1
Check das Rezept
# Estimate the recipe's parameters on the training data
rec1_prepped <-
  prep(rec1, verbose = TRUE)
oper 1 step mutate [training]
oper 2 step log [training]
oper 3 step impute knn [training]
The retained training set is ~ 28.71 Mb in memory.
rec1_prepped
# Apply the prepped recipe to the training data (new_data = NULL)
d_train_baked <-
  rec1_prepped %>%
  bake(new_data = NULL)

head(d_train_baked)
# A tibble: 6 × 23
id belongs_to_collection budget genres homepage imdb_id original_language
<dbl> <fct> <dbl> <fct> <fct> <fct> <fct>
1 1 [{'id': 313576, 'name'… 16.5 [{'id… <NA> tt2637… en
2 2 [{'id': 107674, 'name'… 17.5 [{'id… <NA> tt0368… en
3 3 <NA> 15.0 [{'id… http://… tt2582… en
4 4 <NA> 14.0 [{'id… http://… tt1821… hi
5 5 <NA> 2.30 [{'id… <NA> tt1380… ko
6 6 <NA> 15.9 [{'id… <NA> tt0093… en
# ℹ 16 more variables: original_title <fct>, overview <fct>, popularity <dbl>,
# poster_path <fct>, production_companies <fct>, production_countries <fct>,
# release_date <fct>, runtime <dbl>, spoken_languages <fct>, status <fct>,
# tagline <fct>, title <fct>, Keywords <fct>, cast <fct>, crew <fct>,
# revenue <dbl>
Die AV-Spalte sollte leer sein:
bake(rec1_prepped, new_data = head(d_test), all_outcomes())
# A tibble: 6 × 0
# Count missing values per column of the baked training data
d_train_baked %>%
  map_df(~ sum(is.na(.)))
# A tibble: 1 × 23
id belongs_to_collection budget genres homepage imdb_id original_language
<int> <int> <int> <int> <int> <int> <int>
1 0 2396 0 7 2054 0 0
# ℹ 16 more variables: original_title <int>, overview <int>, popularity <int>,
# poster_path <int>, production_companies <int>, production_countries <int>,
# release_date <int>, runtime <int>, spoken_languages <int>, status <int>,
# tagline <int>, title <int>, Keywords <int>, cast <int>, crew <int>,
# revenue <int>
Keine fehlenden Werte mehr in den Prädiktoren.
Nach fehlenden Werten könnte man z.B. auch so suchen:
datawizard::describe_distribution(d_train_baked)
Variable | Mean | SD | IQR | Range | Skewness | Kurtosis | n | n_Missing
---------------------------------------------------------------------------------------------------------
id | 1500.50 | 866.17 | 1500.50 | [1.00, 3000.00] | 0.00 | -1.20 | 3000 | 0
budget | 12.51 | 6.44 | 14.88 | [2.30, 19.76] | -0.87 | -1.09 | 3000 | 0
popularity | 8.46 | 12.10 | 6.88 | [1.00e-06, 294.34] | 14.38 | 280.10 | 3000 | 0
runtime | 107.85 | 22.08 | 24.00 | [0.00, 338.00] | 1.02 | 8.20 | 3000 | 0
revenue | 6.67e+07 | 1.38e+08 | 6.66e+07 | [1.00, 1.52e+09] | 4.54 | 27.78 | 3000 | 0
So bekommt man gleich noch ein paar Infos über die Verteilung der Variablen. Praktische Sache.
Das Test-Sample backen wir auch mal:
# Apply the prepped recipe to the test data as well
d_test_baked <-
  bake(rec1_prepped, new_data = d_test)

d_test_baked %>%
  head()
# A tibble: 6 × 22
id belongs_to_collection budget genres homepage imdb_id original_language
<dbl> <fct> <dbl> <fct> <fct> <fct> <fct>
1 3001 [{'id': 34055, 'name':… 2.30 [{'id… <NA> <NA> ja
2 3002 <NA> 11.4 [{'id… <NA> <NA> en
3 3003 <NA> 2.30 [{'id… <NA> <NA> en
4 3004 <NA> 15.7 <NA> <NA> <NA> fr
5 3005 <NA> 14.5 [{'id… <NA> <NA> en
6 3006 <NA> 2.30 [{'id… <NA> <NA> en
# ℹ 15 more variables: original_title <fct>, overview <fct>, popularity <dbl>,
# poster_path <fct>, production_companies <fct>, production_countries <fct>,
# release_date <fct>, runtime <dbl>, spoken_languages <fct>, status <fct>,
# tagline <fct>, title <fct>, Keywords <fct>, cast <fct>, crew <fct>
Kreuzvalidierung
# 5-fold cross-validation, no repeats
cv_scheme <- vfold_cv(d_train,
                      v = 5,
                      repeats = 1)
Modelle
Baum
# Decision tree, tuning cost complexity and tree depth
mod_tree <-
  decision_tree(cost_complexity = tune(),
                tree_depth = tune(),
                mode = "regression")
Random Forest
# Random forest (ranger engine), tuning mtry and min_n; 1000 trees
mod_rf <-
  rand_forest(mtry = tune(),
              min_n = tune(),
              trees = 1000,
              mode = "regression") %>%
  set_engine("ranger", num.threads = 4)  # ranger parallelizes internally
Workflows
# One workflow per model, both sharing the same recipe
wf_tree <-
  workflow() %>%
  add_model(mod_tree) %>%
  add_recipe(rec1)

wf_rf <-
  workflow() %>%
  add_model(mod_rf) %>%
  add_recipe(rec1)
Fitten und tunen
Um Rechenzeit zu sparen, kann man den Parameter grid
bei tune_grid()
auf einen kleinen Wert setzen. Der Default ist 10. Um gute Vorhersagen zu erzielen, sollte man den Wert tendenziell noch über 10 erhöhen.
Tree
Parallele Verarbeitung starten:
cl <- makePSOCKcluster(4) # Create 4 clusters
registerDoParallel(cl)

tic()
# Racing (ANOVA) tuning: drops clearly inferior candidates early
tree_fit <-
  wf_tree %>%
  tune_race_anova(
    resamples = cv_scheme,
    #grid = 2
  )
toc()
37.736 sec elapsed
Hilfe zu tune_grid()
bekommt man hier.
tree_fit
# Tuning results
# 5-fold cross-validation
# A tibble: 5 × 5
splits id .order .metrics .notes
<list> <chr> <int> <list> <list>
1 <split [2400/600]> Fold1 3 <tibble [20 × 6]> <tibble [0 × 3]>
2 <split [2400/600]> Fold2 1 <tibble [20 × 6]> <tibble [0 × 3]>
3 <split [2400/600]> Fold3 2 <tibble [20 × 6]> <tibble [0 × 3]>
4 <split [2400/600]> Fold5 4 <tibble [16 × 6]> <tibble [0 × 3]>
5 <split [2400/600]> Fold4 5 <tibble [14 × 6]> <tibble [0 × 3]>
Steht was in den .notes
?
".notes"]][[2]] tree_fit[[
# A tibble: 0 × 3
# ℹ 3 variables: location <chr>, type <chr>, note <chr>
Nein.
collect_metrics(tree_fit)
# A tibble: 14 × 8
cost_complexity tree_depth .metric .estimator mean n std_err .config
<dbl> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
1 1.56e- 5 14 rmse standard 8.95e+7 5 4.65e+6 Prepro…
2 1.56e- 5 14 rsq standard 5.82e-1 5 3.16e-2 Prepro…
3 9.32e- 5 10 rmse standard 8.91e+7 5 4.66e+6 Prepro…
4 9.32e- 5 10 rsq standard 5.85e-1 5 3.11e-2 Prepro…
5 2.36e-10 5 rmse standard 8.80e+7 5 4.57e+6 Prepro…
6 2.36e-10 5 rsq standard 5.92e-1 5 3.20e-2 Prepro…
7 2.29e- 8 11 rmse standard 8.93e+7 5 4.67e+6 Prepro…
8 2.29e- 8 11 rsq standard 5.83e-1 5 3.10e-2 Prepro…
9 9.60e- 4 9 rmse standard 8.84e+7 5 5.00e+6 Prepro…
10 9.60e- 4 9 rsq standard 5.90e-1 5 3.22e-2 Prepro…
11 1.94e- 9 12 rmse standard 8.95e+7 5 4.64e+6 Prepro…
12 1.94e- 9 12 rsq standard 5.82e-1 5 3.10e-2 Prepro…
13 5.72e- 7 7 rmse standard 8.83e+7 5 4.73e+6 Prepro…
14 5.72e- 7 7 rsq standard 5.91e-1 5 3.38e-2 Prepro…
show_best(tree_fit)
Warning: No value of `metric` was given; metric 'rmse' will be used.
# A tibble: 5 × 8
cost_complexity tree_depth .metric .estimator mean n std_err .config
<dbl> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
1 2.36e-10 5 rmse standard 88038619. 5 4572618. Prepro…
2 5.72e- 7 7 rmse standard 88262344. 5 4734314. Prepro…
3 9.60e- 4 9 rmse standard 88397994. 5 5003102. Prepro…
4 9.32e- 5 10 rmse standard 89140111. 5 4663576. Prepro…
5 2.29e- 8 11 rmse standard 89330466. 5 4668641. Prepro…
Finalisieren
# Plug the best hyperparameters (by RMSE) into the tree workflow
best_tree_wf <-
  wf_tree %>%
  finalize_workflow(select_best(tree_fit))
Warning: No value of `metric` was given; metric 'rmse' will be used.
best_tree_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()
── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps
• step_mutate()
• step_log()
• step_impute_knn()
── Model ───────────────────────────────────────────────────────────────────────
Decision Tree Model Specification (regression)
Main Arguments:
cost_complexity = 2.36005153743282e-10
tree_depth = 5
Computational engine: rpart
# Refit the finalized tree workflow on the full training data
tree_last_fit <-
  fit(best_tree_wf, data = d_train)

tree_last_fit
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()
── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps
• step_mutate()
• step_log()
• step_impute_knn()
── Model ───────────────────────────────────────────────────────────────────────
n= 3000
node), split, n, deviance, yval
* denotes terminal node
1) root 3000 5.672651e+19 66725850
2) budget< 18.32631 2845 1.958584e+19 46935270
4) budget< 17.19976 2252 5.443953e+18 25901120
8) popularity< 9.734966 1745 1.665118e+18 17076460
16) popularity< 5.761331 1019 3.184962e+17 8793730
32) budget< 15.44456 782 1.408243e+17 6074563 *
33) budget>=15.44456 237 1.528117e+17 17765830 *
17) popularity>=5.761331 726 1.178595e+18 28701940
34) budget< 16.15249 484 6.504138e+17 21093220 *
35) budget>=16.15249 242 4.441208e+17 43919380 *
9) popularity>=9.734966 507 3.175231e+18 56273980
18) budget< 15.36217 186 3.092335e+17 24880850
36) popularity< 14.04031 151 1.743659e+17 20728170 *
37) popularity>=14.04031 35 1.210294e+17 42796710 *
19) budget>=15.36217 321 2.576473e+18 74464390
38) popularity< 19.64394 300 2.025184e+18 68010500 *
39) popularity>=19.64394 21 3.602808e+17 166662900 *
5) budget>=17.19976 593 9.361685e+18 126815400
10) popularity< 19.63372 570 6.590372e+18 117422100
20) budget< 17.86726 374 2.692151e+18 94469490
40) popularity< 8.444193 149 6.363495e+17 68256660 *
41) popularity>=8.444193 225 1.885623e+18 111828200 *
21) budget>=17.86726 196 3.325222e+18 161219400
42) popularity< 11.60513 126 1.693483e+18 136587100 *
43) popularity>=11.60513 70 1.417677e+18 205557600 *
11) popularity>=19.63372 23 1.474624e+18 359605200
22) runtime>=109.5 16 9.882757e+17 299077200 *
23) runtime< 109.5 7 2.937458e+17 497955000 *
3) budget>=18.32631 155 1.557371e+19 429978800
6) popularity< 17.26579 101 4.711450e+18 299997300
12) budget< 18.73897 67 1.671489e+18 230290900
24) popularity< 12.66146 40 5.426991e+17 174328700
48) budget< 18.44536 18 1.099070e+17 134734600 *
49) budget>=18.44536 22 3.814856e+17 206724000 *
25) popularity>=12.66146 27 8.179336e+17 313197700
50) budget< 18.52944 13 1.273606e+17 234797100 *
51) budget>=18.52944 14 5.364675e+17 385998300 *
13) budget>=18.73897 34 2.072879e+18 437360100
26) runtime< 132.5 26 1.123840e+18 391271100
52) popularity< 11.34182 9 9.729505e+16 248614500 *
53) popularity>=11.34182 17 7.464210e+17 466795200 *
27) runtime>=132.5 8 7.143147e+17 587149400 *
7) popularity>=17.26579 54 5.964228e+18 673092200
14) budget< 18.99438 33 2.082469e+18 534404700
28) popularity< 25.35778 19 5.425201e+17 416871200 *
...
and 4 more lines.
Vorhersage Test-Sample
predict(tree_last_fit, new_data = d_test)
# A tibble: 4,398 × 1
.pred
<dbl>
1 6074563.
2 6074563.
3 21093221.
4 21093221.
5 6074563.
6 21093221.
7 6074563.
8 68256659.
9 43919378.
10 205557624.
# ℹ 4,388 more rows
RF
Fitten und Tunen
Um Rechenzeit zu sparen, kann man das Objekt, wenn einmal berechnet, abspeichern unter result_obj_path
auf der Festplatte und beim nächsten Mal importieren, das geht schneller als neu berechnen.
Das könnte dann z.B. so aussehen:
# Load the tuning result from disk if it exists, otherwise compute it.
# (Caution: a stale cached object will silently give outdated results.)
if (file.exists(result_obj_path)) {
  rf_fit <- read_rds(result_obj_path)
} else {
  tic()
  rf_fit <-
    wf_rf %>%
    tune_grid(
      resamples = cv_scheme)
  toc()
}
Achtung Ein Ergebnisobjekt von der Festplatte zu laden ist gefährlich. Wenn Sie Ihr Modell verändern, aber vergessen, das Objekt auf der Festplatte zu aktualisieren, werden Ihre Ergebnisse falsch sein (da auf dem veralteten Objekt beruhend), ohne dass Sie durch eine Fehlermeldung von R gewarnt würden!
So kann man das Ergebnisobjekt auf die Festplatte schreiben:
#write_rds(rf_fit, file = "objects/tmbd_rf_fit1.rds")
Aber wir berechnen lieber neu:
tic()
# Tune the random forest over the default grid (10 candidates)
rf_fit <-
  wf_rf %>%
  tune_grid(
    resamples = cv_scheme
    #grid = 2
  )
toc()
34.282 sec elapsed
collect_metrics(rf_fit)
# A tibble: 20 × 8
mtry min_n .metric .estimator mean n std_err .config
<int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
1 3 26 rmse standard 81496992. 5 4420334. Preprocessor1…
2 3 26 rsq standard 0.647 5 0.0319 Preprocessor1…
3 1 8 rmse standard 81104914. 5 4249148. Preprocessor1…
4 1 8 rsq standard 0.651 5 0.0270 Preprocessor1…
5 3 13 rmse standard 82253761. 5 4204371. Preprocessor1…
6 3 13 rsq standard 0.639 5 0.0316 Preprocessor1…
7 2 16 rmse standard 81466291. 5 4103501. Preprocessor1…
8 2 16 rsq standard 0.646 5 0.0298 Preprocessor1…
9 2 36 rmse standard 81355080. 5 4051776. Preprocessor1…
10 2 36 rsq standard 0.649 5 0.0281 Preprocessor1…
11 3 5 rmse standard 84125788. 5 4113181. Preprocessor1…
12 3 5 rsq standard 0.623 5 0.0347 Preprocessor1…
13 1 32 rmse standard 82381636. 5 4069505. Preprocessor1…
14 1 32 rsq standard 0.645 5 0.0230 Preprocessor1…
15 1 33 rmse standard 82130106. 5 3978566. Preprocessor1…
16 1 33 rsq standard 0.647 5 0.0231 Preprocessor1…
17 2 20 rmse standard 81547269. 5 4189669. Preprocessor1…
18 2 20 rsq standard 0.647 5 0.0294 Preprocessor1…
19 2 23 rmse standard 81351141. 5 4073682. Preprocessor1…
20 2 23 rsq standard 0.648 5 0.0285 Preprocessor1…
select_best(rf_fit)
Warning: No value of `metric` was given; metric 'rmse' will be used.
# A tibble: 1 × 3
mtry min_n .config
<int> <int> <chr>
1 1 8 Preprocessor1_Model02
Finalisieren
# Finalize the RF workflow with the best hyperparameters (by RMSE)
final_wf <-
  wf_rf %>%
  finalize_workflow(select_best(rf_fit))
Warning: No value of `metric` was given; metric 'rmse' will be used.
# Refit on the full training data, predict the test set,
# and shape the result into Kaggle's submission format (id, revenue)
final_fit <-
  fit(final_wf, data = d_train)

final_preds <-
  final_fit %>%
  predict(new_data = d_test) %>%
  bind_cols(d_test)

submission <-
  final_preds %>%
  select(id, revenue = .pred)
Abspeichern und einreichen:
write_csv(submission, file = "submission.csv")
Kaggle Score
Diese Submission erzielte einen Score von 2.7664 (RMSLE).
sol <- 2.7664  # Kaggle score (RMSLE) of this submission
Categories:
- ds1
- tidymodels
- statlearning
- tmdb
- trees
- num