# Replication materials for the paper:
# ----------------------------------------------------------------------------
# dynamite: An R Package for Dynamic Multivariate Panel Models
# ----------------------------------------------------------------------------
# By Santtu Tikka and Jouni Helske
#
# The results of the paper were obtained under the following configuration:
#
# R version 4.4.2 (2024-10-31 ucrt)
# Platform: x86_64-w64-mingw32/x64
# Running under: Windows 11 x64 (build 26100)
#
# Matrix products: default
#
#
# locale:
# [1] LC_COLLATE=English_Finland.utf8  LC_CTYPE=English_Finland.utf8
# [3] LC_MONETARY=English_Finland.utf8 LC_NUMERIC=C
# [5] LC_TIME=English_Finland.utf8
#
# time zone: Europe/Helsinki
# tzcode source: internal
#
# attached base packages:
# [1] stats     graphics  grDevices utils     datasets  methods   base
#
# other attached packages:
# [1] pryr_0.1.6     dplyr_1.1.4    ggplot2_3.5.1  dynamite_1.5.2
#
# loaded via a namespace (and not attached):
#  [1] tensorA_0.36.2.1     utf8_1.2.4           generics_0.1.3
#  [4] stringi_1.8.4        magrittr_2.0.3       evaluate_1.0.1
#  [7] grid_4.4.2           lobstr_1.1.2         jsonlite_1.8.9
# [10] pkgbuild_1.4.5       backports_1.5.0      gridExtra_2.3
# [13] fansi_1.0.6          QuickJSR_1.4.0       scales_1.3.0
# [16] tweenr_2.0.3         codetools_0.2-20     abind_1.4-8
# [19] cli_3.6.3            rlang_1.1.4          polyclip_1.10-7
# [22] splines_4.4.2        munsell_0.5.1        withr_3.0.2
# [25] StanHeaders_2.32.10  tools_4.4.2          rstan_2.32.6
# [28] inline_0.3.20        parallel_4.4.2       checkmate_2.3.2
# [31] colorspace_2.1-1     curl_6.0.0           vctrs_0.6.5
# [34] posterior_1.6.0      R6_2.5.1             matrixStats_1.4.1
# [37] stats4_4.4.2         lifecycle_1.0.4      stringr_1.5.1
# [40] V8_6.0.0             MASS_7.3-61          pkgconfig_2.0.3
# [43] RcppParallel_5.1.9   pillar_1.9.0         gtable_0.3.6
# [46] loo_2.8.0            data.table_1.16.2    glue_1.8.0
# [49] Rcpp_1.0.13-1        ggforce_0.4.2        xfun_0.49
# [52] tibble_3.2.1         tidyselect_1.2.1     rstudioapi_0.17.1
# [55] knitr_1.49           farver_2.1.2         patchwork_1.3.0
# [58] labeling_0.4.3       compiler_4.4.2       prettyunits_1.2.0
# [61] distributional_0.5.0
# --------------------------------------------------------------------------

# Install dynamite and other packages used
# install.packages(c("dynamite", "dplyr", "ggplot2", "pder", "pryr"))

suppressPackageStartupMessages({
  library("dynamite")
  library("ggplot2")
  library("dplyr")
  library("pryr")
})
theme_set(theme_bw())
set.seed(0)

# Remove unnecessary messages from dplyr::summarize()
options(dplyr.summarise.inform = FALSE)

# Please note that because dynamite uses Stan, exact reproducibility
# is not always possible. For further information, see:
# https://mc-stan.org/docs/reference-manual/reproducibility.html

# Section 3.1 -------------------------------------------------------------

# Seat belt data
data("SeatBelt", package = "pder")

# Seat belt data wrangling
# `law` collapses the three indicator columns (dp, dsp, ds) into a single
# ordered set of levels; either of dp/dsp being 1 is coded as "primary".
seatbelt <- SeatBelt |>
  mutate(
    miles = (vmturban + vmtrural) / 10000,
    log_miles = log(miles),
    fatalities = farsocc,
    income10000 = percapin / 10000,
    law = factor(
      case_when(
        dp == 1 ~ "primary",
        dsp == 1 ~ "primary",
        ds == 1 & dsp == 0 ~ "secondary",
        TRUE ~ "no_law"
      ),
      levels = c("no_law", "secondary", "primary")
    )
  )

# Seat belt model formula
# Two channels: `usage` (beta) and `fatalities` (negative binomial, with
# log_miles as an offset), plus the spline definition shared by the formula.
seatbelt_formula <-
  obs(usage ~ -1 + law + random(~1) + varying(~1), family = "beta") +
  obs(
    fatalities ~ usage + densurb + densrur +
      bac08 + mlda21 + lim65 + lim70p + income10000 + unemp + fueltax +
      random(~1) + offset(log_miles),
    family = "negbin"
  ) +
  splines(df = 10)

# Seat belt model fit
fit <- dynamite(
  dformula = seatbelt_formula,
  data = seatbelt,
  time = "year",
  group = "state",
  chains = 4,
  cores = 4,
  seed = 0,
  refresh = 0
)

# Seat belt model summary
summary(fit, types = "beta", response = "usage") |>
  select(parameter, mean, sd, q5, q95)

# Seat belt model fitted values
# Each scenario overwrites `law` for every row; the `[]` assignment replaces
# the values in place so the column stays a factor with the same levels.
seatbelt_new <- seatbelt
seatbelt_new$law[] <- "no_law"
pnl <- fitted(fit, newdata = seatbelt_new)
seatbelt_new$law[] <- "secondary"
psl <- fitted(fit, newdata = seatbelt_new)
seatbelt_new$law[] <- "primary"
ppl <- fitted(fit, newdata = seatbelt_new)

# Average the fitted usage within each (scenario, draw) first, then
# summarize those per-draw averages across draws for each scenario.
bind_rows(no_law = pnl, secondary = psl, primary = ppl, .id = "law") |>
  mutate(
    law = factor(law, levels = c("no_law", "secondary", "primary"))
  ) |>
  group_by(law, .draw) |>
  summarize(mm = mean(usage_fitted)) |>
  group_by(law) |>
  summarize(
    mean = mean(mm),
    q5 = quantile(mm, 0.05),
    q95 = quantile(mm, 0.95)
  )

# Seat belt model fitted fatalities by usage
seatbelt_new <- seatbelt
seatbelt_new$usage[] <- 0.68
p68 <- fitted(fit, newdata = seatbelt_new)
seatbelt_new$usage[] <- 0.90
p90 <- fitted(fit, newdata = seatbelt_new)

# The `.id` column is named `usage`; the comparisons below rely on it holding
# the scenario label ("low" / "high") rather than the numeric usage rate.
# Per (year, draw): summed low-minus-high difference in fitted fatalities;
# then averaged over years within each draw and summarized across draws.
bind_rows(low = p68, high = p90, .id = "usage") |>
  group_by(year, .draw) |>
  summarize(
    s = sum(
      fatalities_fitted[usage == "low"] -
        fatalities_fitted[usage == "high"]
    )
  ) |>
  group_by(.draw) |>
  summarize(m = mean(s)) |>
  summarize(
    mean = mean(m),
    q5 = quantile(m, 0.05),
    q95 = quantile(m, 0.95)
  )

# Section 3.2 -------------------------------------------------------------

# Multichannel model example
head(multichannel_example)

# Multichannel model formula
# `logp` is an auxiliary channel defined as log(p + 1) with initial value 0.
multi_formula <-
  obs(g ~ lag(g) + lag(logp), family = "gaussian") +
  obs(p ~ lag(g) + lag(logp) + lag(b), family = "poisson") +
  obs(b ~ lag(b) * lag(logp) + lag(b) * lag(g), family = "bernoulli") +
  aux(numeric(logp) ~ log(p + 1) | init(0))

# Multichannel model fit
multichannel_fit <- dynamite(
  dformula = multi_formula,
  data = multichannel_example,
  time = "time",
  group = "id",
  chains = 4,
  cores = 4,
  seed = 0,
  refresh = 0
)

# Multichannel model fit betas
plot(multichannel_fit, types = "beta") +
  labs(title = "")

# Multichannel model newdata for predict
# Blank out the columns g through b after time 5 so that those values are
# missing in the data passed to predict().
multichannel_newdata <- multichannel_example |>
  mutate(across(g:b, ~ ifelse(time > 5, NA, .x)))

# Multichannel model predictions
# The two scenarios differ only in the value of b at time 5 (0 versus 1).
new0 <- multichannel_newdata |>
  mutate(b = ifelse(time == 5, 0, b))
pred0 <- predict(multichannel_fit, newdata = new0, type = "mean")
new1 <- multichannel_newdata |>
  mutate(b = ifelse(time == 5, 1, b))
pred1 <- predict(multichannel_fit, newdata = new1, type = "mean")

# Prediction output
head(pred0, n = 10) |>
  round(3)

# Multichannel model summarized predictions
# Average g_mean over rows within each (case, draw, time), then summarize
# those averages across draws for each (case, time).
sumr <- list(b0 = pred0, b1 = pred1) |>
  bind_rows(.id = "case") |>
  group_by(case, .draw, time) |>
  summarize(mean_t = mean(g_mean)) |>
  group_by(case, time) |>
  summarize(
    mean = mean(mean_t),
    q5 = quantile(mean_t, 0.05, na.rm = TRUE),
    q95 = quantile(mean_t, 0.95, na.rm = TRUE)
  )

# Multichannel model summarized predictions using funs
# Same summary as above, but the per-draw mean is requested from predict()
# through `funs`; the result is read from the `simulated` element.
pred0b <- predict(
  multichannel_fit,
  newdata = new0,
  type = "mean",
  funs = list(g = list(mean_t = mean))
)$simulated
pred1b <- predict(
  multichannel_fit,
  newdata = new1,
  type = "mean",
  funs = list(g = list(mean_t = mean))
)$simulated
sumrb <- list(b0 = pred0b, b1 = pred1b) |>
  bind_rows(.id = "case") |>
  group_by(case, time) |>
  summarize(
    mean = mean(mean_t_g),
    q5 = quantile(mean_t_g, 0.05, na.rm = TRUE),
    q95 = quantile(mean_t_g, 0.95, na.rm = TRUE)
  )

# Plot of the predictions
ggplot(sumr, aes(time, mean)) +
  geom_ribbon(aes(ymin = q5, ymax = q95), alpha = 0.5) +
  geom_line(na.rm = TRUE) +
  scale_x_continuous(n.breaks = 10) +
  facet_wrap(~case)

# Difference of the interventions
# Within each (draw, time): mean of the b1-minus-b0 difference in g_mean.
sumr_diff <- list(b0 = pred0, b1 = pred1) |>
  bind_rows(.id = "case") |>
  group_by(.draw, time) |>
  summarize(
    mean_t = mean(g_mean[case == "b1"] - g_mean[case == "b0"])
  ) |>
  group_by(time) |>
  summarize(
    mean = mean(mean_t),
    q5 = quantile(mean_t, 0.05, na.rm = TRUE),
    q95 = quantile(mean_t, 0.95, na.rm = TRUE)
  )

# Plot of the difference of the expected causal effects
ggplot(sumr_diff, aes(time, mean)) +
  geom_ribbon(aes(ymin = q5, ymax = q95), alpha = 0.5) +
  geom_line(na.rm = TRUE) +
  scale_x_continuous(n.breaks = 10)

# Section 4.1 -------------------------------------------------------------

# Formula construction and print method
dform <- obs(y ~ lag(x), family = "gaussian") +
  obs(x ~ z, family = "poisson")
print(dform)

# Section 4.3 -------------------------------------------------------------

# Multiple intercepts results in a warning:
# this warning shows that when both intercepts are defined
# in the model formula, the model will default to a time-varying intercept
obs(y ~ 1 + varying(~1), family = "gaussian")

# Section 4.8 -------------------------------------------------------------

# Missing auxiliary channel type declaration results in a warning:
# If the type declaration is missing, a `numeric` type is assumed by default
aux(log1x ~ log(1 + x) | init(0))

# Section 4.9 -------------------------------------------------------------

# DAGs of the multichannel model
plot(multi_formula)
plot(multi_formula, show_auxiliary = FALSE)

# Section 5.1 -------------------------------------------------------------

# Priors of the gaussian_example model
get_priors(gaussian_example_fit)

# Section 5.2 -------------------------------------------------------------

# Example model outputs
print(gaussian_example_fit)
mcmc_diagnostics(gaussian_example_fit)
as.data.frame(
  gaussian_example_fit,
  responses = "y",
  types = "beta",
  summary = TRUE
)
cat(get_code(gaussian_example_fit, blocks = "parameters"))

# Section 5.3 -------------------------------------------------------------

# Default plot for time-varying parameters
plot(
  gaussian_example_fit,
  types = c("alpha", "delta"),
  scales = "free"
) +
  labs(title = "")

# Marginal posterior density and traceplot with plot_type = "trace"
# for time-invariant parameter beta_y_z
plot(gaussian_example_fit, plot_type = "trace", types = "beta")

# Section 6 ---------------------------------------------------------------

# Predictions for the gaussian_example model
# One line per draw (y_new), with the y column overlaid in a second colour.
pred <- predict(gaussian_example_fit, n_draws = 50)
pred |>
  dplyr::filter(id < 5) |>
  ggplot(aes(time, y_new, group = .draw)) +
  geom_line(alpha = 0.5) +
  geom_line(aes(y = y), colour = "tomato") +
  facet_wrap(~id)

# Section 6.1 -------------------------------------------------------------

# Summarized predictions using the funs argument
pred_funs <- predict(
  gaussian_example_fit,
  funs = list(y = list(mean = mean, sd = sd))
)
head(pred_funs$simulated)

# Difference in memory consumption
pred_full <- predict(gaussian_example_fit)
object_size(pred_full)
object_size(pred_funs)

# Aggregated predictions using type = "mean" and the funs argument
pred_funs_mean <- predict(
  gaussian_example_fit,
  type = "mean",
  funs = list(y = list(mean = mean, sd = sd))
)
head(pred_funs_mean$simulated)

# Session information -----------------------------------------------------

sessionInfo()