From ce41f4222a69f5589740fe6a533e69654b8400fc Mon Sep 17 00:00:00 2001 From: Adam Howes Date: Tue, 3 Dec 2024 13:22:46 +0000 Subject: [PATCH] Issue #221: Add marginal model (#426) * Add cohort model template * Fix to previous commit (name as cohort_model) * Generate simulated cohort data * Add unweighted and weighted direct models * Thinking about custom family for pcd function * functions component of stanvars * Add transformed parameters for cohort model * Progress on implementing PCD model * Set q as vector * Get rid of params input * This would work, apart from it's the CDF. For the PMF need to import many primarycensored functions.. * Almost working with "import all functions" strategy * Wrap up on attempt * Rename to marginal model * Move towards single wrap function with others imported * Just running into some C++ errors.. * This doesn't change anything * Move to marginal model name and lint * Create aggregate data inside model conversion function for now * First draft on moving marginal_model into functions * Run document * Tests working up to valid Stan code * Regex version of marginal model * Use prep_marginal_obs * Improve assert for marginal model * Add pkgdown and document * Clean up scratch implementation * remove scratch file * update data format, formula, and family * update stan code * basic working version * add transformm data methods * start exploring the ebola example * add a helper to find meaningful relative_obs_times * get the full ebola vignette working with new variable requirements * improve return messages * update approx vignette * get ebole vignette passing by checking pp and related inputs * add marginal model integration tests * expand post processing tests * add marginal model * use the right transform data keyword * add ... pass through to make constructors correct * fix .summarise_n_by_formula test so error message is as expected * drop not required .row_id * check using ... properly * make the progress messages prettier for reducing data complexit: * check post process tests again * add a test for the specific transform data method * add some tests for the generic transform data method * put transform data tests in the correct folder * add a news update * change vignette language to talk about marginal model * update the FAQ to use the marginal variables * call it transformed_data not trans_data * change the error message to make it clear its a epidist limitation * update stan docs * Update NEWS.md * Update NEWS.md Co-authored-by: Adam Howes * Update NEWS.md Co-authored-by: Adam Howes * Update inst/stan/latent_model/functions.stan Co-authored-by: Adam Howes * Update setup.R --------- Co-authored-by: Sam --- .Rbuildignore | 1 + NAMESPACE | 16 ++ NEWS.md | 11 +- R/epidist.R | 10 +- R/latent_model.R | 9 +- R/marginal_model.R | 270 ++++++++++++++++++ R/transform_data.R | 37 +++ R/utils.R | 40 +++ _pkgdown.yml | 8 + inst/stan/latent_model/functions.stan | 8 +- inst/stan/marginal_model/functions.stan | 29 ++ man/as_epidist_latent_model.Rd | 4 +- ...dist_latent_model.epidist_linelist_data.Rd | 4 +- man/as_epidist_marginal_model.Rd | 25 ++ ...st_marginal_model.epidist_linelist_data.Rd | 31 ++ man/dot-extract_dpar_terms.Rd | 19 ++ man/dot-summarise_n_by_formula.Rd | 24 ++ ...ist_family_model.epidist_marginal_model.Rd | 28 ++ ...st_formula_model.epidist_marginal_model.Rd | 31 ++ man/epidist_transform_data.Rd | 30 ++ man/epidist_transform_data_model.Rd | 27 ++ man/epidist_transform_data_model.default.Rd | 27 ++ man/is_epidist_marginal_model.Rd | 23 ++ man/new_epidist_latent_model.Rd | 4 +- man/new_epidist_marginal_model.Rd | 26 ++ tests/testthat/setup.R | 23 +- tests/testthat/test-gen.R | 190 ++++++------ tests/testthat/test-int-marginal_model.R | 82 ++++++ tests/testthat/test-marginal_model.R | 107 +++++++ tests/testthat/test-postprocess.R | 105 ++++--- tests/testthat/test-transform_data.R | 39 +++ tests/testthat/test-utils.R | 55 ++++ vignettes/approx-inference.Rmd | 6 +- vignettes/ebola.Rmd | 60 ++-- vignettes/epidist.Rmd | 4 +- vignettes/faq.Rmd | 22 +- 36 files changed, 1248 insertions(+), 187 deletions(-) create mode 100644 R/marginal_model.R create mode 100644 R/transform_data.R create mode 100644 inst/stan/marginal_model/functions.stan create mode 100644 man/as_epidist_marginal_model.Rd create mode 100644 man/as_epidist_marginal_model.epidist_linelist_data.Rd create mode 100644 man/dot-extract_dpar_terms.Rd create mode 100644 man/dot-summarise_n_by_formula.Rd create mode 100644 man/epidist_family_model.epidist_marginal_model.Rd create mode 100644 man/epidist_formula_model.epidist_marginal_model.Rd create mode 100644 man/epidist_transform_data.Rd create mode 100644 man/epidist_transform_data_model.Rd create mode 100644 man/epidist_transform_data_model.default.Rd create mode 100644 man/is_epidist_marginal_model.Rd create mode 100644 man/new_epidist_marginal_model.Rd create mode 100644 tests/testthat/test-int-marginal_model.R create mode 100644 tests/testthat/test-marginal_model.R create mode 100644 tests/testthat/test-transform_data.R diff --git a/.Rbuildignore b/.Rbuildignore index be0e79bdb..151bca3d6 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -17,4 +17,5 @@ ^pkgdown$ ^vignettes/approx-inference\.Rmd$ ^vignettes/ebola\.Rmd$ +^vignettes/faq\.Rmd$ ^\.lintr$ diff --git a/NAMESPACE b/NAMESPACE index 93ba3ff12..8faf96881 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,26 +6,34 @@ S3method(add_mean_sd,lognormal_samples) S3method(as_epidist_latent_model,epidist_linelist_data) S3method(as_epidist_linelist_data,data.frame) S3method(as_epidist_linelist_data,default) +S3method(as_epidist_marginal_model,epidist_linelist_data) S3method(as_epidist_naive_model,epidist_linelist_data) S3method(assert_epidist,default) S3method(assert_epidist,epidist_latent_model) S3method(assert_epidist,epidist_linelist_data) +S3method(assert_epidist,epidist_marginal_model) S3method(assert_epidist,epidist_naive_model) S3method(epidist_family_model,default) S3method(epidist_family_model,epidist_latent_model) +S3method(epidist_family_model,epidist_marginal_model) S3method(epidist_family_param,default) S3method(epidist_family_prior,default) S3method(epidist_family_prior,lognormal) S3method(epidist_formula_model,default) S3method(epidist_formula_model,epidist_latent_model) +S3method(epidist_formula_model,epidist_marginal_model) S3method(epidist_model_prior,default) S3method(epidist_model_prior,epidist_latent_model) S3method(epidist_stancode,default) S3method(epidist_stancode,epidist_latent_model) +S3method(epidist_stancode,epidist_marginal_model) +S3method(epidist_transform_data_model,default) +S3method(epidist_transform_data_model,epidist_marginal_model) export(Gamma) export(add_mean_sd) export(as_epidist_latent_model) export(as_epidist_linelist_data) +export(as_epidist_marginal_model) export(as_epidist_naive_model) export(assert_epidist) export(bf) @@ -42,12 +50,16 @@ export(epidist_gen_posterior_predict) export(epidist_model_prior) export(epidist_prior) export(epidist_stancode) +export(epidist_transform_data) +export(epidist_transform_data_model) export(is_epidist_latent_model) export(is_epidist_linelist_data) +export(is_epidist_marginal_model) export(is_epidist_naive_model) export(lognormal) export(new_epidist_latent_model) export(new_epidist_linelist_data) +export(new_epidist_marginal_model) export(new_epidist_naive_model) export(predict_delay_parameters) export(predict_dpar) @@ -78,14 +90,18 @@ importFrom(cli,cli_abort) importFrom(cli,cli_alert_info) importFrom(cli,cli_inform) importFrom(cli,cli_warn) +importFrom(dplyr,across) importFrom(dplyr,bind_cols) importFrom(dplyr,bind_rows) importFrom(dplyr,filter) importFrom(dplyr,full_join) +importFrom(dplyr,group_by) importFrom(dplyr,mutate) importFrom(dplyr,select) +importFrom(dplyr,summarise) importFrom(lubridate,days) importFrom(lubridate,is.timepoint) +importFrom(purrr,map_chr) importFrom(purrr,map_dbl) importFrom(stats,Gamma) importFrom(stats,as.formula) diff --git a/NEWS.md b/NEWS.md index 63d0ca7c8..14395a9ad 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,19 +2,26 @@ Development version of `epidist`. +## Models + +- Added a marginalised likelihood model based on `primarycensored`. This can be specified using `as_epidist_marginal_model()`. This is currently limited to Weibull, log-normal, and gamma distributions with uniform primary censoring but this will be generalised in future releases. See #426. +- Added user settable primary event priors to the latent model. See #474. +- Added a marginalised likelihood to the latent model. See #474. + ## Package - Remove the default method for `epidist()`. See #473. - Added `enforce_presence` argument to `epidist_prior()` to allow for priors to be specified if they do not match existing parameters. See #474. - Added a `merge` argument to `epidist_prior()` to allow for not merging user and package priors. See #474. -- Added user settable primary event priors to the latent model. See #474. -- Added a marginalised likelihood to the latent model. See #474. - Generalised the stan reparametrisation feature to work across all distributions without manual specification by generating stan code with `brms` and then extracting the reparameterisation. See #474. +- Added a `transform_data` S3 method to allow for data to be transformed for specific models. This is specifically useful for the marginal model at the moment as it allows reducing the data to its unique strata. See #474. ## Documentation - Brings the README into line with `epinowcast` standards. See #467. +- Switched over to using the marginal model as default in documentation. See #426. +- Added helper functions for new variables to avoid code duplication in vignettes. See #426. # epidist 0.1.0 diff --git a/R/epidist.R b/R/epidist.R index 78db416a8..a4718c229 100644 --- a/R/epidist.R +++ b/R/epidist.R @@ -38,16 +38,20 @@ epidist <- function(data, formula = mu ~ 1, epidist_formula <- epidist_formula( data = data, family = epidist_family, formula = formula ) + transformed_data <- epidist_transform_data( + data, epidist_family, epidist_formula + ) epidist_prior <- epidist_prior( - data = data, family = epidist_family, formula = epidist_formula, prior, + data = transformed_data, family = epidist_family, + formula = epidist_formula, prior, merge = merge_priors ) epidist_stancode <- epidist_stancode( - data = data, family = epidist_family, formula = epidist_formula + data = transformed_data, family = epidist_family, formula = epidist_formula ) fit <- fn( formula = epidist_formula, family = epidist_family, prior = epidist_prior, - stanvars = epidist_stancode, data = data, ... + stanvars = epidist_stancode, data = transformed_data, ... ) class(fit) <- c(class(fit), "epidist_fit") return(fit) diff --git a/R/latent_model.R b/R/latent_model.R index 3d5c72dbb..de21666aa 100644 --- a/R/latent_model.R +++ b/R/latent_model.R @@ -1,9 +1,10 @@ #' Convert an object to an `epidist_latent_model` object #' #' @param data An object to be converted to the class `epidist_latent_model` +#' @param ... Additional arguments passed to methods. #' @family latent_model #' @export -as_epidist_latent_model <- function(data) { +as_epidist_latent_model <- function(data, ...) { UseMethod("as_epidist_latent_model") } @@ -11,11 +12,12 @@ as_epidist_latent_model <- function(data) { #' The latent model method for `epidist_linelist_data` objects #' #' @param data An `epidist_linelist_data` object +#' @param ... Not used in this method. #' @method as_epidist_latent_model epidist_linelist_data #' @family latent_model #' @autoglobal #' @export -as_epidist_latent_model.epidist_linelist_data <- function(data) { +as_epidist_latent_model.epidist_linelist_data <- function(data, ...) { assert_epidist(data) data <- data |> mutate( @@ -38,10 +40,11 @@ as_epidist_latent_model.epidist_linelist_data <- function(data) { #' Class constructor for `epidist_latent_model` objects #' #' @param data An object to be set with the class `epidist_latent_model` +#' @param ... Additional arguments passed to methods. #' @returns An object of class `epidist_latent_model` #' @family latent_model #' @export -new_epidist_latent_model <- function(data) { +new_epidist_latent_model <- function(data, ...) { class(data) <- c("epidist_latent_model", class(data)) return(data) } diff --git a/R/marginal_model.R b/R/marginal_model.R new file mode 100644 index 000000000..21837d506 --- /dev/null +++ b/R/marginal_model.R @@ -0,0 +1,270 @@ +#' Prepare marginal model to pass through to `brms` +#' +#' @param data A `data.frame` containing line list data +#' @param ... Additional arguments passed to methods. +#' @family marginal_model +#' @export +as_epidist_marginal_model <- function(data, ...) { + UseMethod("as_epidist_marginal_model") +} + +#' The marginal model method for `epidist_linelist_data` objects +#' +#' @param data An `epidist_linelist_data` object +#' @param obs_time_threshold Ratio used to determine threshold for setting +#' relative observation times to Inf. Observation times greater than +#' `obs_time_threshold` times the maximum delay will be set to Inf to improve +#' model efficiency by reducing the number of unique observation times. +#' Default is 2. +#' @param ... Not used in this method. +#' @method as_epidist_marginal_model epidist_linelist_data +#' @family marginal_model +#' @autoglobal +#' @export +as_epidist_marginal_model.epidist_linelist_data <- function( + data, obs_time_threshold = 2, ...) { + assert_epidist(data) + + data <- data |> + mutate( + pwindow = .data$ptime_upr - .data$ptime_lwr, + swindow = .data$stime_upr - .data$stime_lwr, + relative_obs_time = .data$obs_time - .data$ptime_lwr, + orig_relative_obs_time = .data$obs_time - .data$ptime_lwr, + delay_lwr = .data$stime_lwr - .data$ptime_lwr, + delay_upr = .data$stime_upr - .data$ptime_lwr, + n = 1 + ) + + # Calculate maximum delay + max_delay <- max(data$delay_upr, na.rm = TRUE) + threshold <- max_delay * obs_time_threshold + + # Count observations beyond threshold + n_beyond <- sum(data$relative_obs_time > threshold, na.rm = TRUE) + + if (n_beyond > 0) { + cli::cli_inform(c( + "!" = paste0( + "Setting {n_beyond} observation time{?s} beyond ", + "{threshold} (={obs_time_threshold}x max delay) to Inf. ", + "This improves model efficiency by reducing unique observation times ", + "while maintaining model accuracy as these times should have ", + "negligible impact." + ) + )) + data$relative_obs_time[data$relative_obs_time > threshold] <- Inf + } + + data <- new_epidist_marginal_model(data) + assert_epidist(data) + return(data) +} + +#' Class constructor for `epidist_marginal_model` objects +#' +#' @param data A data.frame to convert +#' @returns An object of class `epidist_marginal_model` +#' @family marginal_model +#' @export +new_epidist_marginal_model <- function(data) { + class(data) <- c("epidist_marginal_model", class(data)) + return(data) +} + +#' @method assert_epidist epidist_marginal_model +#' @family marginal_model +#' @export +assert_epidist.epidist_marginal_model <- function(data, ...) { + assert_data_frame(data) + assert_names(names(data), must.include = c( + "pwindow", "swindow", "delay_lwr", "delay_upr", "n", + "relative_obs_time" + )) + assert_numeric(data$pwindow, lower = 0) + assert_numeric(data$swindow, lower = 0) + assert_integerish(data$delay_lwr) + assert_integerish(data$delay_upr) + assert_numeric(data$relative_obs_time) + if (!all(abs(data$delay_upr - (data$delay_lwr + data$swindow)) < 1e-10)) { + cli::cli_abort( + "delay_upr must equal delay_lwr + swindow" + ) + } + if (!all(data$relative_obs_time >= data$delay_upr)) { + cli::cli_abort( + "relative_obs_time must be greater than or equal to delay_upr" + ) + } + assert_numeric(data$n, lower = 1) +} + +#' Check if data has the `epidist_marginal_model` class +#' +#' @param data A `data.frame` containing line list data +#' @family marginal_model +#' @export +is_epidist_marginal_model <- function(data) { + inherits(data, "epidist_marginal_model") +} + +#' Create the model-specific component of an `epidist` custom family +#' +#' @inheritParams epidist_family_model +#' @param ... Additional arguments passed to method. +#' @method epidist_family_model epidist_marginal_model +#' @family marginal_model +#' @export +epidist_family_model.epidist_marginal_model <- function( + data, family, ...) { + custom_family <- brms::custom_family( + paste0("marginal_", family$family), + dpars = family$dpars, + links = c(family$link, family$other_links), + lb = c(NA, as.numeric(lapply(family$other_bounds, "[[", "lb"))), + ub = c(NA, as.numeric(lapply(family$other_bounds, "[[", "ub"))), + type = "int", + vars = c( + "vreal1[n]", "vreal2[n]", "vreal3[n]", "vreal4[n]", "primary_params" + ), + loop = TRUE, + log_lik = epidist_gen_log_lik(family), + posterior_predict = epidist_gen_posterior_predict(family), + posterior_epred = epidist_gen_posterior_epred(family) + ) + return(custom_family) +} + +#' Define the model-specific component of an `epidist` custom formula +#' +#' @inheritParams epidist_formula_model +#' @param ... Additional arguments passed to method. +#' @method epidist_formula_model epidist_marginal_model +#' @family marginal_model +#' @export +epidist_formula_model.epidist_marginal_model <- function( + data, formula, ...) { + # data is only used to dispatch on + formula <- stats::update( + formula, delay_lwr | weights(n) + + vreal(relative_obs_time, pwindow, swindow, delay_upr) ~ . + ) + return(formula) +} + +#' @method epidist_transform_data_model epidist_marginal_model +#' @family marginal_model +#' @importFrom purrr map_chr +#' @export +epidist_transform_data_model.epidist_marginal_model <- function( + data, family, formula, ...) { + required_cols <- c( + "delay_lwr", "delay_upr", "relative_obs_time", "pwindow", "swindow" + ) + n_rows_before <- nrow(data) + + trans_data <- data |> + .summarise_n_by_formula(by = required_cols, formula = formula) |> + new_epidist_marginal_model() + n_rows_after <- nrow(trans_data) + if (n_rows_before > n_rows_after) { + cli::cli_inform(c( + "i" = "Data summarised by unique combinations of:" # nolint + )) + + formula_vars <- setdiff(names(trans_data), c(required_cols, "n")) + if (length(formula_vars) > 0) { + cli::cli_inform(c( + "*" = "Formula variables: {.code {formula_vars}}" + )) + } + + cli::cli_inform(paste0( + "* Model variables: delay bounds, observation time, ", + "and primary censoring window" + )) + + cli::cli_inform(c( + "!" = paste("Reduced from", n_rows_before, "to", n_rows_after, "rows."), + "i" = "This should improve model efficiency with no loss of information." # nolint + )) + } + + return(trans_data) +} + +#' @method epidist_stancode epidist_marginal_model +#' @importFrom brms stanvar +#' @family marginal_model +#' @autoglobal +#' @export +epidist_stancode.epidist_marginal_model <- function( + data, + family = epidist_family(data), + formula = epidist_formula(data), ...) { + assert_epidist(data) + + stanvars_version <- .version_stanvar() + + stanvars_functions <- brms::stanvar( + block = "functions", + scode = .stan_chunk(file.path("marginal_model", "functions.stan")) + ) + + family_name <- gsub("marginal_", "", family$name, fixed = TRUE) + + stanvars_functions[[1]]$scode <- gsub( + "family", family_name, stanvars_functions[[1]]$scode, + fixed = TRUE + ) + + if (family_name == "lognormal") { + dist_id <- 1 + } else if (family_name == "gamma") { + dist_id <- 2 + } else if (family_name == "weibell") { + dist_id <- 3 + } else { + cli_abort(c( + "!" = "epidist does not currently support this family for the marginal model" # nolint + )) + } + + # Replace the dist_id passed to primarycensored + stanvars_functions[[1]]$scode <- gsub( + "dist_id", dist_id, stanvars_functions[[1]]$scode, + fixed = TRUE + ) + + stanvars_functions[[1]]$scode <- gsub( + "dpars_A", + toString(paste0("real ", family$dpars)), + stanvars_functions[[1]]$scode, + fixed = TRUE + ) + + stanvars_functions[[1]]$scode <- gsub( + "dpars_B", family$param, stanvars_functions[[1]]$scode, + fixed = TRUE + ) + + stanvars_functions[[1]]$scode <- gsub( + "primary_id", "1", stanvars_functions[[1]]$scode, + fixed = TRUE + ) + + stanvars_parameters <- brms::stanvar( + block = "parameters", + scode = "array[0] real primary_params;" + ) + + pcd_stanvars_functions <- brms::stanvar( + block = "functions", + scode = primarycensored::pcd_load_stan_functions() + ) + + stanvars_all <- stanvars_version + stanvars_functions + + pcd_stanvars_functions + stanvars_parameters + + return(stanvars_all) +} diff --git a/R/transform_data.R b/R/transform_data.R new file mode 100644 index 000000000..4cbc205f1 --- /dev/null +++ b/R/transform_data.R @@ -0,0 +1,37 @@ +#' Transform data for an epidist model +#' +#' This function is used within [epidist()] to transform data before passing to +#' `brms`. It is unlikely that as a user you will need this function, but we +#' export it nonetheless to be transparent about what happens inside of a call +#' to [epidist()]. +#' +#' @inheritParams epidist +#' @param family A description of the response distribution and link function to +#' be used in the model created using [epidist_family()]. +#' @param formula A formula object created using [epidist_formula()]. +#' @family transform_data +#' @export +epidist_transform_data <- function(data, family, formula, ...) { + assert_epidist(data) + data <- epidist_transform_data_model(data, family, formula) + return(data) +} + +#' The model-specific parts of an `epidist_transform_data()` call +#' +#' @inheritParams epidist_transform_data +#' @rdname epidist_transform_data_model +#' @family transform_data +#' @export +epidist_transform_data_model <- function(data, family, formula, ...) { + UseMethod("epidist_transform_data_model") +} + +#' Default method for transforming data for a model +#' +#' @inheritParams epidist_transform_data_model +#' @family transform_data +#' @export +epidist_transform_data_model.default <- function(data, family, formula, ...) { + return(data) +} diff --git a/R/utils.R b/R/utils.R index 826beb3e9..51d6b7da7 100644 --- a/R/utils.R +++ b/R/utils.R @@ -204,6 +204,46 @@ return(formula) } +#' Extract distributional parameter terms from a brms formula +#' +#' This function extracts all unique terms from the right-hand side of all +#' distributional parameters in a brms formula. +#' +#' @param formula A `brms formula object +#' @return A character vector of unique terms +#' @keywords internal +.extract_dpar_terms <- function(formula) { + terms <- brms::brmsterms(formula) + # Extract all terms from the right hand side of all dpars + dpar_terms <- purrr::map(terms$dpars, \(x) all.vars(x$allvars)) + dpar_terms <- unique(unlist(dpar_terms)) + return(dpar_terms) +} + +#' Summarise data by grouping variables and count occurrences +#' +#' @param data A `data.frame` to summarise which must contain a `n` column +#' which is a count of occurrences. +#' @param by Character vector of column names to group by. +#' @param formula Optional `brms` formula object to extract additional grouping +#' terms from. +#' @return A `data.frame` summarised by the grouping variables with counts +#' @keywords internal +#' @importFrom dplyr group_by summarise across +.summarise_n_by_formula <- function(data, by = character(), formula = NULL) { + if (!is.null(formula)) { + formula_terms <- .extract_dpar_terms(formula) + by <- c(by, formula_terms) + } + # Remove duplicates + by <- unique(by) + + data |> + tibble::as_tibble() |> + summarise(n = sum(.data$n), .by = dplyr::all_of(by)) +} + + #' Rename the columns of a `data.frame` #' #' @param df ... diff --git a/_pkgdown.yml b/_pkgdown.yml index 51ef0149f..5a1cff9c7 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -45,6 +45,10 @@ reference: desc: Specific methods for the latent model contents: - has_concept("latent_model") +- title: Marginal model + desc: Specific methods for the marginal model + contents: + - has_concept("marginal_model") - title: Postprocess desc: Functions for postprocessing model output contents: @@ -64,6 +68,10 @@ reference: desc: Functions related to specifying custom `brms` formula contents: - has_concept("formula") +- title: Transform data + desc: Transform data using the formula and family information + contents: + - has_concept("transform_data") - title: Prior distributions desc: Functions for specifying prior distributions contents: diff --git a/inst/stan/latent_model/functions.stan b/inst/stan/latent_model/functions.stan index 3134f7f95..fba796829 100644 --- a/inst/stan/latent_model/functions.stan +++ b/inst/stan/latent_model/functions.stan @@ -3,11 +3,11 @@ * * This function is designed to be read into R where: * - 'family' is replaced with the target distribution (e.g., 'lognormal') - * - 'dpars_A' is replaced with multiple parameters in the format + * - 'dpars_A' is replaced with multiple distribution parameters in the format * "vector|real paramname1, vector|real paramname2, ..." depending on whether - * each parameter has a model. This includes distribution parameters. - * - 'dpars_B' is replaced with the same parameters as dpars_A but with window - * indices removed. + * each parameter has a model. + * - 'dpars_B' is replaced with the same parameters as dpars_A but + * reparameterised according to the brms parameterisation for Stan. * * @param y Vector of observed values (delays) * @param dpars_A Distribution parameters (replaced via regex) diff --git a/inst/stan/marginal_model/functions.stan b/inst/stan/marginal_model/functions.stan new file mode 100644 index 000000000..ee55dae2c --- /dev/null +++ b/inst/stan/marginal_model/functions.stan @@ -0,0 +1,29 @@ +/** + * Compute the log probability mass function for a marginal model with censoring + * + * This function is designed to be read into R where: + * - 'family' is replaced with the target distribution (e.g., 'lognormal') + * - 'dpars_A' is replaced with multiple distribution parameters in the format + * "real paramname1, real paramname2, ...". + * - 'dpars_B' is replaced with the same parameters as dpars_A but + * reparameterised according to the brms parameterisation for Stan. + * + * @param y Real value of observed delay + * @param dpars_A Distribution parameters (replaced via regex) + * @param relative_obs_t Observation time relative to primary window start + * @param pwindow_width Primary window width (actual time scale) + * @param swindow_width Secondary window width (actual time scale) + * @param y_upper Upper bound of delay interval + * @param primary_params Array of parameters for primary distribution + * + * @return Log probability mass with censoring adjustment for marginal model + */ + real marginal_family_lpmf(data int y, dpars_A, data real relative_obs_t, + data real pwindow_width, data real swindow_width, + data real y_upper, array[] real primary_params) { + + return primarycensored_lpmf( + y | dist_id, {dpars_B}, pwindow_width, y_upper, relative_obs_t, + primary_id, primary_params + ); +} diff --git a/man/as_epidist_latent_model.Rd b/man/as_epidist_latent_model.Rd index 9f9a419aa..f63e925c3 100644 --- a/man/as_epidist_latent_model.Rd +++ b/man/as_epidist_latent_model.Rd @@ -4,10 +4,12 @@ \alias{as_epidist_latent_model} \title{Convert an object to an \code{epidist_latent_model} object} \usage{ -as_epidist_latent_model(data) +as_epidist_latent_model(data, ...) } \arguments{ \item{data}{An object to be converted to the class \code{epidist_latent_model}} + +\item{...}{Additional arguments passed to methods.} } \description{ Convert an object to an \code{epidist_latent_model} object diff --git a/man/as_epidist_latent_model.epidist_linelist_data.Rd b/man/as_epidist_latent_model.epidist_linelist_data.Rd index 3b91956df..5e8e86af8 100644 --- a/man/as_epidist_latent_model.epidist_linelist_data.Rd +++ b/man/as_epidist_latent_model.epidist_linelist_data.Rd @@ -4,10 +4,12 @@ \alias{as_epidist_latent_model.epidist_linelist_data} \title{The latent model method for \code{epidist_linelist_data} objects} \usage{ -\method{as_epidist_latent_model}{epidist_linelist_data}(data) +\method{as_epidist_latent_model}{epidist_linelist_data}(data, ...) } \arguments{ \item{data}{An \code{epidist_linelist_data} object} + +\item{...}{Not used in this method.} } \description{ The latent model method for \code{epidist_linelist_data} objects diff --git a/man/as_epidist_marginal_model.Rd b/man/as_epidist_marginal_model.Rd new file mode 100644 index 000000000..0b8432f1b --- /dev/null +++ b/man/as_epidist_marginal_model.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/marginal_model.R +\name{as_epidist_marginal_model} +\alias{as_epidist_marginal_model} +\title{Prepare marginal model to pass through to \code{brms}} +\usage{ +as_epidist_marginal_model(data, ...) +} +\arguments{ +\item{data}{A \code{data.frame} containing line list data} + +\item{...}{Additional arguments passed to methods.} +} +\description{ +Prepare marginal model to pass through to \code{brms} +} +\seealso{ +Other marginal_model: +\code{\link{as_epidist_marginal_model.epidist_linelist_data}()}, +\code{\link{epidist_family_model.epidist_marginal_model}()}, +\code{\link{epidist_formula_model.epidist_marginal_model}()}, +\code{\link{is_epidist_marginal_model}()}, +\code{\link{new_epidist_marginal_model}()} +} +\concept{marginal_model} diff --git a/man/as_epidist_marginal_model.epidist_linelist_data.Rd b/man/as_epidist_marginal_model.epidist_linelist_data.Rd new file mode 100644 index 000000000..2277bbe09 --- /dev/null +++ b/man/as_epidist_marginal_model.epidist_linelist_data.Rd @@ -0,0 +1,31 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/marginal_model.R +\name{as_epidist_marginal_model.epidist_linelist_data} +\alias{as_epidist_marginal_model.epidist_linelist_data} +\title{The marginal model method for \code{epidist_linelist_data} objects} +\usage{ +\method{as_epidist_marginal_model}{epidist_linelist_data}(data, obs_time_threshold = 2, ...) +} +\arguments{ +\item{data}{An \code{epidist_linelist_data} object} + +\item{obs_time_threshold}{Ratio used to determine threshold for setting +relative observation times to Inf. Observation times greater than +\code{obs_time_threshold} times the maximum delay will be set to Inf to improve +model efficiency by reducing the number of unique observation times. +Default is 2.} + +\item{...}{Not used in this method.} +} +\description{ +The marginal model method for \code{epidist_linelist_data} objects +} +\seealso{ +Other marginal_model: +\code{\link{as_epidist_marginal_model}()}, +\code{\link{epidist_family_model.epidist_marginal_model}()}, +\code{\link{epidist_formula_model.epidist_marginal_model}()}, +\code{\link{is_epidist_marginal_model}()}, +\code{\link{new_epidist_marginal_model}()} +} +\concept{marginal_model} diff --git a/man/dot-extract_dpar_terms.Rd b/man/dot-extract_dpar_terms.Rd new file mode 100644 index 000000000..3ae238bec --- /dev/null +++ b/man/dot-extract_dpar_terms.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{.extract_dpar_terms} +\alias{.extract_dpar_terms} +\title{Extract distributional parameter terms from a brms formula} +\usage{ +.extract_dpar_terms(formula) +} +\arguments{ +\item{formula}{A `brms formula object} +} +\value{ +A character vector of unique terms +} +\description{ +This function extracts all unique terms from the right-hand side of all +distributional parameters in a brms formula. +} +\keyword{internal} diff --git a/man/dot-summarise_n_by_formula.Rd b/man/dot-summarise_n_by_formula.Rd new file mode 100644 index 000000000..f68b462b3 --- /dev/null +++ b/man/dot-summarise_n_by_formula.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{.summarise_n_by_formula} +\alias{.summarise_n_by_formula} +\title{Summarise data by grouping variables and count occurrences} +\usage{ +.summarise_n_by_formula(data, by = character(), formula = NULL) +} +\arguments{ +\item{data}{A \code{data.frame} to summarise which must contain a \code{n} column +which is a count of occurrences.} + +\item{by}{Character vector of column names to group by.} + +\item{formula}{Optional \code{brms} formula object to extract additional grouping +terms from.} +} +\value{ +A \code{data.frame} summarised by the grouping variables with counts +} +\description{ +Summarise data by grouping variables and count occurrences +} +\keyword{internal} diff --git a/man/epidist_family_model.epidist_marginal_model.Rd b/man/epidist_family_model.epidist_marginal_model.Rd new file mode 100644 index 000000000..ea6746ee5 --- /dev/null +++ b/man/epidist_family_model.epidist_marginal_model.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/marginal_model.R +\name{epidist_family_model.epidist_marginal_model} +\alias{epidist_family_model.epidist_marginal_model} +\title{Create the model-specific component of an \code{epidist} custom family} +\usage{ +\method{epidist_family_model}{epidist_marginal_model}(data, family, ...) +} +\arguments{ +\item{data}{An object with class corresponding to an implemented model.} + +\item{family}{Output of a call to \code{brms::brmsfamily()} with additional +information as provided by \code{.add_dpar_info()}} + +\item{...}{Additional arguments passed to method.} +} +\description{ +Create the model-specific component of an \code{epidist} custom family +} +\seealso{ +Other marginal_model: +\code{\link{as_epidist_marginal_model}()}, +\code{\link{as_epidist_marginal_model.epidist_linelist_data}()}, +\code{\link{epidist_formula_model.epidist_marginal_model}()}, +\code{\link{is_epidist_marginal_model}()}, +\code{\link{new_epidist_marginal_model}()} +} +\concept{marginal_model} diff --git a/man/epidist_formula_model.epidist_marginal_model.Rd b/man/epidist_formula_model.epidist_marginal_model.Rd new file mode 100644 index 000000000..94806ff88 --- /dev/null +++ b/man/epidist_formula_model.epidist_marginal_model.Rd @@ -0,0 +1,31 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/marginal_model.R +\name{epidist_formula_model.epidist_marginal_model} +\alias{epidist_formula_model.epidist_marginal_model} +\title{Define the model-specific component of an \code{epidist} custom formula} +\usage{ +\method{epidist_formula_model}{epidist_marginal_model}(data, formula, ...) +} +\arguments{ +\item{data}{An object with class corresponding to an implemented model.} + +\item{formula}{An object of class \link[stats:formula]{stats::formula} or \link[brms:brmsformula]{brms::brmsformula} +(or one that can be coerced to those classes). A symbolic description of the +model to be fitted. A formula must be provided for the distributional +parameter \code{mu}, and may optionally be provided for other distributional +parameters.} + +\item{...}{Additional arguments passed to method.} +} +\description{ +Define the model-specific component of an \code{epidist} custom formula +} +\seealso{ +Other marginal_model: +\code{\link{as_epidist_marginal_model}()}, +\code{\link{as_epidist_marginal_model.epidist_linelist_data}()}, +\code{\link{epidist_family_model.epidist_marginal_model}()}, +\code{\link{is_epidist_marginal_model}()}, +\code{\link{new_epidist_marginal_model}()} +} +\concept{marginal_model} diff --git a/man/epidist_transform_data.Rd b/man/epidist_transform_data.Rd new file mode 100644 index 000000000..00c80c823 --- /dev/null +++ b/man/epidist_transform_data.Rd @@ -0,0 +1,30 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/transform_data.R +\name{epidist_transform_data} +\alias{epidist_transform_data} +\title{Transform data for an epidist model} +\usage{ +epidist_transform_data(data, family, formula, ...) +} +\arguments{ +\item{data}{An object with class corresponding to an implemented model.} + +\item{family}{A description of the response distribution and link function to +be used in the model created using \code{\link[=epidist_family]{epidist_family()}}.} + +\item{formula}{A formula object created using \code{\link[=epidist_formula]{epidist_formula()}}.} + +\item{...}{Additional arguments passed to \code{fn} method.} +} +\description{ +This function is used within \code{\link[=epidist]{epidist()}} to transform data before passing to +\code{brms}. It is unlikely that as a user you will need this function, but we +export it nonetheless to be transparent about what happens inside of a call +to \code{\link[=epidist]{epidist()}}. +} +\seealso{ +Other transform_data: +\code{\link{epidist_transform_data_model}()}, +\code{\link{epidist_transform_data_model.default}()} +} +\concept{transform_data} diff --git a/man/epidist_transform_data_model.Rd b/man/epidist_transform_data_model.Rd new file mode 100644 index 000000000..f8daf119e --- /dev/null +++ b/man/epidist_transform_data_model.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/transform_data.R +\name{epidist_transform_data_model} +\alias{epidist_transform_data_model} +\title{The model-specific parts of an \code{epidist_transform_data()} call} +\usage{ +epidist_transform_data_model(data, family, formula, ...) +} +\arguments{ +\item{data}{An object with class corresponding to an implemented model.} + +\item{family}{A description of the response distribution and link function to +be used in the model created using \code{\link[=epidist_family]{epidist_family()}}.} + +\item{formula}{A formula object created using \code{\link[=epidist_formula]{epidist_formula()}}.} + +\item{...}{Additional arguments passed to \code{fn} method.} +} +\description{ +The model-specific parts of an \code{epidist_transform_data()} call +} +\seealso{ +Other transform_data: +\code{\link{epidist_transform_data}()}, +\code{\link{epidist_transform_data_model.default}()} +} +\concept{transform_data} diff --git a/man/epidist_transform_data_model.default.Rd b/man/epidist_transform_data_model.default.Rd new file mode 100644 index 000000000..ca56c63a1 --- /dev/null +++ b/man/epidist_transform_data_model.default.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/transform_data.R +\name{epidist_transform_data_model.default} +\alias{epidist_transform_data_model.default} +\title{Default method for transforming data for a model} +\usage{ +\method{epidist_transform_data_model}{default}(data, family, formula, ...) +} +\arguments{ +\item{data}{An object with class corresponding to an implemented model.} + +\item{family}{A description of the response distribution and link function to +be used in the model created using \code{\link[=epidist_family]{epidist_family()}}.} + +\item{formula}{A formula object created using \code{\link[=epidist_formula]{epidist_formula()}}.} + +\item{...}{Additional arguments passed to \code{fn} method.} +} +\description{ +Default method for transforming data for a model +} +\seealso{ +Other transform_data: +\code{\link{epidist_transform_data}()}, +\code{\link{epidist_transform_data_model}()} +} +\concept{transform_data} diff --git a/man/is_epidist_marginal_model.Rd b/man/is_epidist_marginal_model.Rd new file mode 100644 index 000000000..5585c2f78 --- /dev/null +++ b/man/is_epidist_marginal_model.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/marginal_model.R +\name{is_epidist_marginal_model} +\alias{is_epidist_marginal_model} +\title{Check if data has the \code{epidist_marginal_model} class} +\usage{ +is_epidist_marginal_model(data) +} +\arguments{ +\item{data}{A \code{data.frame} containing line list data} +} +\description{ +Check if data has the \code{epidist_marginal_model} class +} +\seealso{ +Other marginal_model: +\code{\link{as_epidist_marginal_model}()}, +\code{\link{as_epidist_marginal_model.epidist_linelist_data}()}, +\code{\link{epidist_family_model.epidist_marginal_model}()}, +\code{\link{epidist_formula_model.epidist_marginal_model}()}, +\code{\link{new_epidist_marginal_model}()} +} +\concept{marginal_model} diff --git a/man/new_epidist_latent_model.Rd b/man/new_epidist_latent_model.Rd index 8658161c4..0816107d5 100644 --- a/man/new_epidist_latent_model.Rd +++ b/man/new_epidist_latent_model.Rd @@ -4,10 +4,12 @@ \alias{new_epidist_latent_model} \title{Class constructor for \code{epidist_latent_model} objects} \usage{ -new_epidist_latent_model(data) +new_epidist_latent_model(data, ...) } \arguments{ \item{data}{An object to be set with the class \code{epidist_latent_model}} + +\item{...}{Additional arguments passed to methods.} } \value{ An object of class \code{epidist_latent_model} diff --git a/man/new_epidist_marginal_model.Rd b/man/new_epidist_marginal_model.Rd new file mode 100644 index 000000000..f2abe7e32 --- /dev/null +++ b/man/new_epidist_marginal_model.Rd @@ -0,0 +1,26 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/marginal_model.R +\name{new_epidist_marginal_model} +\alias{new_epidist_marginal_model} +\title{Class constructor for \code{epidist_marginal_model} objects} +\usage{ +new_epidist_marginal_model(data) +} +\arguments{ +\item{data}{A data.frame to convert} +} +\value{ +An object of class \code{epidist_marginal_model} +} +\description{ +Class constructor for \code{epidist_marginal_model} objects +} +\seealso{ +Other marginal_model: +\code{\link{as_epidist_marginal_model}()}, +\code{\link{as_epidist_marginal_model.epidist_linelist_data}()}, +\code{\link{epidist_family_model.epidist_marginal_model}()}, +\code{\link{epidist_formula_model.epidist_marginal_model}()}, +\code{\link{is_epidist_marginal_model}()} +} +\concept{marginal_model} diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 4cbb14949..d6d056826 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -113,12 +113,16 @@ sim_obs_sex <- as_epidist_linelist_data( sim_obs_sex$obs_time, sex = sim_obs_sex$sex ) - prep_obs <- as_epidist_latent_model(sim_obs) prep_naive_obs <- as_epidist_naive_model(sim_obs) +prep_marginal_obs <- as_epidist_marginal_model(sim_obs) prep_obs_gamma <- as_epidist_latent_model(sim_obs_gamma) prep_obs_sex <- as_epidist_latent_model(sim_obs_sex) +prep_marginal_obs <- as_epidist_marginal_model(sim_obs) +prep_marginal_obs_gamma <- as_epidist_marginal_model(sim_obs_gamma) +prep_marginal_obs_sex <- as_epidist_marginal_model(sim_obs_sex) + if (not_on_cran()) { set.seed(1) fit <- epidist( @@ -129,6 +133,10 @@ if (not_on_cran()) { fit_rstan <- epidist( data = prep_obs, seed = 1, chains = 2, cores = 2, silent = 2, refresh = 0 ) + fit_marginal <- suppressMessages(epidist( + data = prep_marginal_obs, seed = 1, chains = 2, cores = 2, silent = 2, + refresh = 0, backend = "cmdstanr" + )) fit_gamma <- epidist( data = prep_obs_gamma, family = Gamma(link = "log"), @@ -136,10 +144,23 @@ if (not_on_cran()) { backend = "cmdstanr" ) + fit_marginal_gamma <- suppressMessages(epidist( + data = prep_marginal_obs_gamma, family = Gamma(link = "log"), + seed = 1, chains = 2, cores = 2, silent = 2, refresh = 0, + backend = "cmdstanr" + )) + fit_sex <- epidist( data = prep_obs_sex, formula = bf(mu ~ 1 + sex, sigma ~ 1 + sex), seed = 1, silent = 2, refresh = 0, cores = 2, chains = 2, backend = "cmdstanr" ) + + fit_marginal_sex <- suppressMessages(epidist( + data = prep_marginal_obs_sex, + formula = bf(mu ~ 1 + sex, sigma ~ 1 + sex), + seed = 1, silent = 2, refresh = 0, + cores = 2, chains = 2, backend = "cmdstanr" + )) } diff --git a/tests/testthat/test-gen.R b/tests/testthat/test-gen.R index 37b7d7bf3..93d22a8d4 100644 --- a/tests/testthat/test-gen.R +++ b/tests/testthat/test-gen.R @@ -1,119 +1,125 @@ test_that("epidist_gen_posterior_predict returns a function that outputs positive integers with length equal to draws", { # nolint: line_length_linter. skip_on_cran() - # Test lognormal - prep <- brms::prepare_predictions(fit) - i <- 1 - predict_fn <- epidist_gen_posterior_predict(lognormal()) - pred_i <- predict_fn(i = i, prep) - expect_identical(floor(pred_i), pred_i) - expect_length(pred_i, prep$ndraws) - expect_gte(min(pred_i), 0) - # Test gamma - prep_gamma <- brms::prepare_predictions(fit_gamma) - predict_fn_gamma <- epidist_gen_posterior_predict(Gamma()) - pred_i_gamma <- predict_fn_gamma(i = i, prep_gamma) - expect_identical(floor(pred_i_gamma), pred_i_gamma) - expect_length(pred_i_gamma, prep_gamma$ndraws) - expect_gte(min(pred_i_gamma), 0) + # Helper function to test predictions + test_predictions <- function(fit, family) { + prep <- brms::prepare_predictions(fit) + i <- 1 + predict_fn <- epidist_gen_posterior_predict(family) + pred_i <- predict_fn(i = i, prep) + expect_identical(floor(pred_i), pred_i) + expect_length(pred_i, prep$ndraws) + expect_gte(min(pred_i), 0) + } + + # Test lognormal - latent and marginal + test_predictions(fit, lognormal()) + test_predictions(fit_marginal, lognormal()) + + # Test gamma - latent and marginal + test_predictions(fit_gamma, Gamma()) + test_predictions(fit_marginal_gamma, Gamma()) }) test_that("epidist_gen_posterior_predict returns a function that errors for i out of bounds", { # nolint: line_length_linter. skip_on_cran() - # Test lognormal - prep <- brms::prepare_predictions(fit) - i_out_of_bounds <- length(prep$data$Y) + 1 - predict_fn <- epidist_gen_posterior_predict(lognormal()) - expect_warning( - expect_error( - predict_fn(i = i_out_of_bounds, prep) + + # Helper function to test out of bounds errors + test_out_of_bounds <- function(fit, family) { + prep <- brms::prepare_predictions(fit) + i_out_of_bounds <- length(prep$data$Y) + 1 + predict_fn <- epidist_gen_posterior_predict(family) + expect_warning( + expect_error( + predict_fn(i = i_out_of_bounds, prep) + ) ) - ) + } - # Test gamma - prep_gamma <- brms::prepare_predictions(fit_gamma) - i_out_of_bounds_gamma <- length(prep_gamma$data$Y) + 1 - predict_fn_gamma <- epidist_gen_posterior_predict(Gamma()) - expect_warning( - expect_error(predict_fn_gamma(i = i_out_of_bounds_gamma, prep_gamma)) - ) + # Test lognormal - latent and marginal + test_out_of_bounds(fit, lognormal()) + test_out_of_bounds(fit_marginal, lognormal()) + + # Test gamma - latent and marginal + test_out_of_bounds(fit_gamma, Gamma()) + test_out_of_bounds(fit_marginal_gamma, Gamma()) }) test_that("epidist_gen_posterior_predict returns a function that can generate predictions with no censoring", { # nolint: line_length_linter. skip_on_cran() - # Test lognormal - predict_fn <- epidist_gen_posterior_predict(lognormal()) - draws <- data.frame(relative_obs_time = 1000, pwindow = 0, swindow = 0) |> - tidybayes::add_predicted_draws(fit, ndraws = 100) - expect_identical(draws$.draw, 1:100) - pred <- draws$.prediction - expect_gte(min(pred), 0) - expect_true(all(abs(pred - round(pred)) > .Machine$double.eps^0.5)) - # Test gamma - predict_fn_gamma <- epidist_gen_posterior_predict(Gamma()) - draws_gamma <- data.frame( - relative_obs_time = 1000, pwindow = 0, swindow = 0 - ) |> - tidybayes::add_predicted_draws(fit_gamma, ndraws = 100) - expect_identical(draws_gamma$.draw, 1:100) - pred_gamma <- draws_gamma$.prediction - expect_gte(min(pred_gamma), 0) - expect_true( - all(abs(pred_gamma - round(pred_gamma)) > .Machine$double.eps^0.5) - ) + # Helper function to test uncensored predictions + test_uncensored <- function(fit, family) { + predict_fn <- epidist_gen_posterior_predict(family) + draws <- data.frame( + relative_obs_time = Inf, pwindow = 0, swindow = 0, delay_upr = NA + ) |> + tidybayes::add_predicted_draws(fit, ndraws = 100) + expect_identical(draws$.draw, 1:100) + pred <- draws$.prediction + expect_gte(min(pred), 0) + expect_true(all(abs(pred - round(pred)) > .Machine$double.eps^0.5)) + } + + # Test lognormal - latent and marginal + test_uncensored(fit, lognormal()) + test_uncensored(fit_marginal, lognormal()) + + # Test gamma - latent and marginal + test_uncensored(fit_gamma, Gamma()) + test_uncensored(fit_marginal_gamma, Gamma()) }) test_that("epidist_gen_posterior_predict returns a function that predicts delays in the 95% credible interval", { # nolint: line_length_linter. skip_on_cran() - # Test lognormal - prep <- brms::prepare_predictions(fit) - prep$ndraws <- 1000 # Down from the 4000 for time saving - predict_fn <- epidist_gen_posterior_predict(lognormal()) - q <- purrr::map_vec(seq_along(prep$data$Y), function(i) { - y <- predict_fn(i, prep) - ecdf <- ecdf(y) - q <- ecdf(prep$data$Y[i]) - return(q) - }) - expect_lt(quantile(q, 0.1), 0.3) - expect_gt(quantile(q, 0.9), 0.7) - expect_lt(min(q), 0.1) - expect_gt(max(q), 0.9) - expect_lt(mean(q), 0.65) - expect_gt(mean(q), 0.35) - # Test gamma - prep_gamma <- brms::prepare_predictions(fit_gamma) - prep_gamma$ndraws <- 1000 - predict_fn_gamma <- epidist_gen_posterior_predict(Gamma()) - q_gamma <- purrr::map_vec(seq_along(prep_gamma$data$Y), function(i) { - y <- predict_fn_gamma(i, prep_gamma) - ecdf <- ecdf(y) - q <- ecdf(prep_gamma$data$Y[i]) - return(q) - }) - expect_lt(quantile(q_gamma, 0.1), 0.3) - expect_gt(quantile(q_gamma, 0.9), 0.7) - expect_lt(min(q_gamma), 0.1) - expect_gt(max(q_gamma), 0.9) - expect_lt(mean(q_gamma), 0.65) - expect_gt(mean(q_gamma), 0.35) + # Helper function to test credible intervals + test_credible_intervals <- function(fit, family) { + prep <- brms::prepare_predictions(fit) + prep$ndraws <- 1000 # Down from the 4000 for time saving + predict_fn <- epidist_gen_posterior_predict(family) + q <- purrr::map_vec(seq_along(prep$data$Y), function(i) { + y <- predict_fn(i, prep) + ecdf <- ecdf(y) + q <- ecdf(prep$data$Y[i]) + return(q) + }) + expect_lt(quantile(q, 0.1), 0.3) + expect_gt(quantile(q, 0.9), 0.7) + expect_lt(min(q), 0.1) + expect_gt(max(q), 0.9) + expect_lt(mean(q), 0.65) + expect_gt(mean(q), 0.35) + } + + # Test lognormal - latent and marginal + test_credible_intervals(fit, lognormal()) + test_credible_intervals(fit_marginal, lognormal()) + + # Test gamma - latent and marginal + test_credible_intervals(fit_gamma, Gamma()) + test_credible_intervals(fit_marginal_gamma, Gamma()) }) test_that("epidist_gen_posterior_epred returns a function that creates arrays with correct dimensions", { # nolint: line_length_linter. skip_on_cran() - # Test lognormal - epred <- prep_obs |> - tidybayes::add_epred_draws(fit) - expect_equal(mean(epred$.epred), 5.97, tolerance = 0.1) - expect_gte(min(epred$.epred), 0) - # Test gamma - epred_gamma <- prep_obs |> - tidybayes::add_epred_draws(fit_gamma) - expect_equal(mean(epred_gamma$.epred), 6.56, tolerance = 0.1) - expect_gte(min(epred_gamma$.epred), 0) + # Helper function to test epred + test_epred <- function(fit, expected_mean) { + epred <- prep_obs |> + mutate(delay_upr = NA) |> + tidybayes::add_epred_draws(fit) + expect_equal(mean(epred$.epred), expected_mean, tolerance = 0.1) + expect_gte(min(epred$.epred), 0) + } + + # Test lognormal - latent and marginal + test_epred(fit, 5.97) + test_epred(fit_marginal, 5.97) + + # Test gamma - latent and marginal + test_epred(fit_gamma, 6.56) + test_epred(fit_marginal_gamma, 6.56) }) test_that("epidist_gen_log_lik returns a function that produces valid log likelihoods", { # nolint: line_length_linter. diff --git a/tests/testthat/test-int-marginal_model.R b/tests/testthat/test-int-marginal_model.R new file mode 100644 index 000000000..d44131b90 --- /dev/null +++ b/tests/testthat/test-int-marginal_model.R @@ -0,0 +1,82 @@ +# Note: some tests in this script are stochastic. As such, test failure may be +# bad luck rather than indicate an issue with the code. However, as these tests +# are reproducible, the distribution of test failures may be investigated by +# varying the input seed. Test failure at an unusually high rate does suggest +# a potential code issue. + +test_that("epidist.epidist_marginal_model Stan code has no syntax errors in the default case", { # nolint: line_length_linter. + skip_on_cran() + stancode <- suppressMessages(epidist( + data = prep_marginal_obs, + fn = brms::make_stancode + )) + mod <- cmdstanr::cmdstan_model( + stan_file = cmdstanr::write_stan_file(stancode), compile = FALSE + ) + expect_true(mod$check_syntax()) +}) + +test_that("epidist.epidist_marginal_model fits and the MCMC converges in the default case", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + expect_s3_class(fit_marginal, "brmsfit") + expect_s3_class(fit_marginal, "epidist_fit") + expect_convergence(fit_marginal) +}) + +test_that("epidist.epidist_marginal_model recovers the simulation settings for the delay distribution in the default case", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + set.seed(1) + pred <- predict_delay_parameters(fit_marginal) + expect_equal(mean(pred$mu), meanlog, tolerance = 0.1) + expect_equal(mean(pred$sigma), sdlog, tolerance = 0.1) +}) + +test_that("epidist.epidist_marginal_model fits and the MCMC converges in the gamma delay case", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + set.seed(1) + expect_s3_class(fit_marginal_gamma, "brmsfit") + expect_s3_class(fit_marginal_gamma, "epidist_fit") + expect_convergence(fit_marginal_gamma) +}) + +test_that("epidist.epidist_marginal_model recovers the simulation settings for the delay distribution in the gamma delay case", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + set.seed(1) + draws_gamma <- posterior::as_draws_df(fit_marginal_gamma$fit) + draws_gamma_mu <- exp(draws_gamma$Intercept) + draws_gamma_shape <- exp(draws_gamma$Intercept_shape) + draws_gamma_mu_ecdf <- ecdf(draws_gamma_mu) + draws_gamma_shape_ecdf <- ecdf(draws_gamma_shape) + quantile_mu <- draws_gamma_mu_ecdf(mu) + quantile_shape <- draws_gamma_shape_ecdf(shape) + expect_gte(quantile_mu, 0.025) + expect_lte(quantile_mu, 0.975) + expect_gte(quantile_shape, 0.025) + expect_lte(quantile_shape, 0.975) +}) + +test_that("epidist.epidist_marginal_model fits and recovers a sex effect", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + set.seed(1) + expect_s3_class(fit_marginal_sex, "brmsfit") + expect_s3_class(fit_marginal_sex, "epidist_fit") + expect_convergence(fit_marginal_sex) + + draws <- posterior::as_draws_df(fit_marginal_sex$fit) + expect_equal(mean(draws$b_Intercept), meanlog_m, tolerance = 0.3) + expect_equal( + mean(draws$b_Intercept + draws$b_sex), meanlog_f, + tolerance = 0.3 + ) + expect_equal(mean(exp(draws$b_sigma_Intercept)), sdlog_m, tolerance = 0.3) + expect_equal( + mean(exp(draws$b_sigma_Intercept + draws$b_sigma_sex)), + sdlog_f, + tolerance = 0.3 + ) +}) diff --git a/tests/testthat/test-marginal_model.R b/tests/testthat/test-marginal_model.R new file mode 100644 index 000000000..aeaf43c30 --- /dev/null +++ b/tests/testthat/test-marginal_model.R @@ -0,0 +1,107 @@ +test_that("as_epidist_marginal_model.epidist_linelist_data with default settings an object with the correct classes", { # nolint: line_length_linter. + prep_marginal_obs <- as_epidist_marginal_model(sim_obs) + expect_s3_class(prep_marginal_obs, "data.frame") + expect_s3_class(prep_marginal_obs, "epidist_marginal_model") +}) + +test_that("as_epidist_marginal_model.epidist_linelist_data errors when passed incorrect inputs", { # nolint: line_length_linter. + expect_error(as_epidist_marginal_model(list())) + expect_error(as_epidist_marginal_model(sim_obs[, 1])) +}) + +# Make this data available for other tests +family_lognormal <- epidist_family(prep_marginal_obs, family = lognormal()) + +test_that("is_epidist_marginal_model returns TRUE for correct input", { # nolint: line_length_linter. + expect_true(is_epidist_marginal_model(prep_marginal_obs)) + expect_true({ + x <- list() + class(x) <- "epidist_marginal_model" + is_epidist_marginal_model(x) + }) +}) + +test_that("is_epidist_marginal_model returns FALSE for incorrect input", { # nolint: line_length_linter. + expect_false(is_epidist_marginal_model(list())) + expect_false({ + x <- list() + class(x) <- "epidist_marginal_model_extension" + is_epidist_marginal_model(x) + }) +}) + +test_that("assert_epidist.epidist_marginal_model doesn't produce an error for correct input", { # nolint: line_length_linter. + expect_no_error(assert_epidist(prep_marginal_obs)) +}) + +test_that("assert_epidist.epidist_marginal_model returns FALSE for incorrect input", { # nolint: line_length_linter. + expect_error(assert_epidist(list())) + expect_error(assert_epidist(prep_marginal_obs[, 1])) + expect_error({ + x <- list() + class(x) <- "epidist_marginal_model" + assert_epidist(x) + }) +}) + +test_that("epidist_stancode.epidist_marginal_model produces valid stanvars", { # nolint: line_length_linter. + epidist_family <- epidist_family(prep_marginal_obs) + epidist_formula <- epidist_formula( + prep_marginal_obs, epidist_family, + formula = bf(mu ~ 1) + ) + stancode <- epidist_stancode( + prep_marginal_obs, + family = epidist_family, formula = epidist_formula + ) + expect_s3_class(stancode, "stanvars") +}) + +test_that("epidist_transform_data_model.epidist_marginal_model correctly transforms data and messages", { # nolint: line_length_linter. + family <- epidist_family(prep_marginal_obs, family = lognormal()) + formula <- epidist_formula( + prep_marginal_obs, + formula = bf(mu ~ 1), + family = family + ) + expect_no_message( + expect_message( + expect_message( + expect_message( + epidist_transform_data_model( + prep_marginal_obs, + family = family, + formula = formula + ), + "Reduced from 500 to 144 rows." + ), + "Data summarised by unique combinations of:" + ), + "Model variables" + ) + ) + + family <- epidist_family(prep_marginal_obs, family = lognormal()) + formula <- epidist_formula( + prep_marginal_obs, + formula = bf(mu ~ 1 + ptime_lwr), + family = family + ) + expect_message( + expect_message( + expect_message( + expect_message( + epidist_transform_data_model( + prep_marginal_obs, + family = family, + formula = formula + ), + "Reduced from 500 to 144 rows." + ), + "Data summarised by unique combinations of:" + ), + "Model variables" + ), + "ptime_lwr" + ) +}) diff --git a/tests/testthat/test-postprocess.R b/tests/testthat/test-postprocess.R index 798d591a2..3ccfe1d19 100644 --- a/tests/testthat/test-postprocess.R +++ b/tests/testthat/test-postprocess.R @@ -1,52 +1,73 @@ -test_that("predict_delay_parameters works with NULL newdata and the latent lognormal model", { # nolint: line_length_linter. - skip_on_cran() - set.seed(1) - pred <- predict_delay_parameters(fit) - expect_s3_class(pred, "lognormal_samples") - expect_s3_class(pred, "data.frame") - expect_named(pred, c("draw", "index", "mu", "sigma", "mean", "sd")) - expect_true(all(pred$mean > 0)) - expect_true(all(pred$sd > 0)) - expect_length(unique(pred$index), nrow(prep_obs)) - expect_length(unique(pred$draw), summary(fit)$total_ndraws) -}) +test_that( + "predict_delay_parameters works with NULL newdata and the latent and marginal lognormal model", # nolint: line_length_linter. + { + skip_on_cran() + + # Helper function to test predictions + test_predictions <- function(fit, expected_rows = nrow(prep_obs)) { + set.seed(1) + pred <- predict_delay_parameters(fit) + expect_s3_class(pred, "lognormal_samples") + expect_s3_class(pred, "data.frame") + expect_named(pred, c("draw", "index", "mu", "sigma", "mean", "sd")) + expect_true(all(pred$mean > 0)) + expect_true(all(pred$sd > 0)) + expect_length(unique(pred$index), expected_rows) + expect_length(unique(pred$draw), summary(fit)$total_ndraws) + } + + # Test latent and marginal models + test_predictions(fit) + test_predictions(fit_marginal, expected_rows = 144) + } +) test_that("predict_delay_parameters accepts newdata arguments and prediction by sex recovers underlying parameters", { # nolint: line_length_linter. skip_on_cran() - set.seed(1) - pred_sex <- predict_delay_parameters(fit_sex, prep_obs_sex) - expect_s3_class(pred_sex, "lognormal_samples") - expect_s3_class(pred_sex, "data.frame") - expect_named(pred_sex, c("draw", "index", "mu", "sigma", "mean", "sd")) - expect_true(all(pred_sex$mean > 0)) - expect_true(all(pred_sex$sd > 0)) - expect_length(unique(pred_sex$index), nrow(prep_obs_sex)) - expect_length(unique(pred_sex$draw), summary(fit_sex)$total_ndraws) - pred_sex_summary <- pred_sex |> - dplyr::left_join( - dplyr::select(prep_obs_sex, index = .row_id, sex), - by = "index" - ) |> - dplyr::group_by(sex) |> - dplyr::summarise( - mu = mean(mu), - sigma = mean(sigma) + # Helper function to test sex predictions + test_sex_predictions <- function(fit, prep = prep_obs_sex) { + set.seed(1) + prep <- prep |> + dplyr::mutate(.row_id = dplyr::row_number()) + pred_sex <- predict_delay_parameters(fit, prep) + expect_s3_class(pred_sex, "lognormal_samples") + expect_s3_class(pred_sex, "data.frame") + expect_named(pred_sex, c("draw", "index", "mu", "sigma", "mean", "sd")) + expect_true(all(pred_sex$mean > 0)) + expect_true(all(pred_sex$sd > 0)) + expect_length(unique(pred_sex$index), nrow(prep)) + expect_length(unique(pred_sex$draw), summary(fit)$total_ndraws) + + pred_sex_summary <- pred_sex |> + dplyr::left_join( + dplyr::select(prep, index = .row_id, sex), + by = "index" + ) |> + dplyr::group_by(sex) |> + dplyr::summarise( + mu = mean(mu), + sigma = mean(sigma) + ) + + # Correct predictions of M + expect_equal( + as.numeric(pred_sex_summary[1, c("mu", "sigma")]), + c(meanlog_m, sdlog_m), + tolerance = 0.1 ) - # Correct predictions of M - expect_equal( - as.numeric(pred_sex_summary[1, c("mu", "sigma")]), - c(meanlog_m, sdlog_m), - tolerance = 0.1 - ) + # Correction predictions of F + expect_equal( + as.numeric(pred_sex_summary[2, c("mu", "sigma")]), + c(meanlog_f, sdlog_f), + tolerance = 0.1 + ) + } - # Correction predictions of F - expect_equal( - as.numeric(pred_sex_summary[2, c("mu", "sigma")]), - c(meanlog_f, sdlog_f), - tolerance = 0.1 - ) + # Test latent and marginal models + test_sex_predictions(fit_sex) + test_sex_predictions(fit_marginal_sex, prep_marginal_obs_sex) }) test_that("add_mean_sd.lognormal_samples works with simulated lognormal distribution parameter data", { # nolint: line_length_linter. diff --git a/tests/testthat/test-transform_data.R b/tests/testthat/test-transform_data.R new file mode 100644 index 000000000..a970d8652 --- /dev/null +++ b/tests/testthat/test-transform_data.R @@ -0,0 +1,39 @@ +test_that( + "epidist_transform_data with default settings returns data unchanged", + { + family <- epidist_family(prep_obs, family = lognormal()) + formula <- epidist_formula(prep_obs, family = family, formula = bf(mu ~ 1)) + + transformed <- epidist_transform_data(prep_obs, family, formula) + expect_identical(transformed, prep_obs) + } +) + +test_that("epidist_transform_data errors when passed incorrect inputs", { + family <- epidist_family(prep_obs, family = lognormal()) + formula <- epidist_formula(prep_obs, family = family, formula = bf(mu ~ 1)) + + expect_error(epidist_transform_data(list(), family, formula)) +}) + +test_that("epidist_transform_data_model.default returns data unchanged", { + family <- epidist_family(prep_obs, family = lognormal()) + formula <- epidist_formula(prep_obs, family = family, formula = bf(mu ~ 1)) + + transformed <- epidist_transform_data_model(prep_obs, family, formula) + expect_identical(transformed, prep_obs) +}) + +test_that("epidist_transform_data works with different model types", { + family <- epidist_family(prep_obs, family = lognormal()) + formula <- epidist_formula(prep_obs, family = family, formula = bf(mu ~ 1)) + + expect_identical( + epidist_transform_data(prep_naive_obs, family, formula), + prep_naive_obs + ) + expect_identical( + epidist_transform_data(prep_obs_gamma, family, formula), + prep_obs_gamma + ) +}) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 1fb3add13..f31b452fa 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -98,3 +98,58 @@ test_that(".make_intercepts_explicit does not add an intercept if the distributi expect_identical(formula$pforms$mu, formula_updated$pforms$mu) expect_identical(formula$pforms$sigma, formula_updated$pforms$sigma) }) + +test_that( + ".summarise_n_by_formula correctly summarizes counts by grouping variables", + { + df <- tibble::tibble( + x = c(1, 1, 2, 2), + y = c("a", "b", "a", "b"), + n = c(2, 3, 4, 1) + ) + + # Test grouping by single variable + result <- .summarise_n_by_formula(df, by = "x") + expect_identical(result$x, c(1, 2)) + expect_identical(result$n, c(5, 5)) + + # Test grouping by multiple variable + result <- .summarise_n_by_formula(df, by = c("x", "y")) + expect_identical(result$x, c(1, 1, 2, 2)) + expect_identical(result$y, c("a", "b", "a", "b")) + expect_identical(result$n, c(2, 3, 4, 1)) + + # Test with formula + formula <- bf(mu ~ x + y) + result <- .summarise_n_by_formula(df, formula = formula) + expect_identical(result$x, c(1, 1, 2, 2)) + expect_identical(result$y, c("a", "b", "a", "b")) + expect_identical(result$n, c(2, 3, 4, 1)) + + # Test with both by and formula + formula <- bf(mu ~ y) + result <- .summarise_n_by_formula(df, by = "x", formula = formula) + expect_identical(result$x, c(1, 1, 2, 2)) + expect_identical(result$y, c("a", "b", "a", "b")) + expect_identical(result$n, c(2, 3, 4, 1)) + } +) + +test_that( + ".summarise_n_by_formula handles missing grouping variables appropriately", + { + df <- data.frame(x = 1:2, n = c(1, 2)) + expect_error( + .summarise_n_by_formula(df, by = "missing"), + "Can't subset elements that don't exist" + ) + } +) + +test_that(".summarise_n_by_formula requires n column in data", { + df <- data.frame(x = 1:2) + expect_error( + .summarise_n_by_formula(df, by = "x"), + "Column `n` not found in `.data`." + ) +}) diff --git a/vignettes/approx-inference.Rmd b/vignettes/approx-inference.Rmd index 819992611..526ca64d2 100644 --- a/vignettes/approx-inference.Rmd +++ b/vignettes/approx-inference.Rmd @@ -135,7 +135,7 @@ obs_cens_trunc_samp <- simulate_gillespie(seed = 101) |> slice_sample(n = sample_size, replace = FALSE) ``` -We now prepare the data for fitting with the latent individual model, and perform inference with HMC: +We now prepare the data for fitting with the marginal model, and perform inference with HMC: ```{r results='hide'} linelist_data <- as_epidist_linelist_data( @@ -146,7 +146,7 @@ linelist_data <- as_epidist_linelist_data( obs_time = obs_cens_trunc_samp$obs_time ) -data <- as_epidist_latent_model(linelist_data) +data <- as_epidist_marginal_model(linelist_data) t <- proc.time() fit_hmc <- epidist(data = data, algorithm = "sampling", backend = "cmdstanr") @@ -155,7 +155,7 @@ time_hmc <- proc.time() - t Note that for clarity above we specify `algorithm = "sampling"`, but if you were to call `epidist(data = data)` the result would be the same since `"sampling"` (i.e. HMC) is the default value for the `algorithm` argument. -Now, we fit^[Note that in this section, and above for the MCMC, the output of the call is hidden, but if you were to call these functions yourself they would display information about the fitting procedure as it occurs] the same latent individual model using each method in Section \@ref(other). +Now, we fit^[Note that in this section, and above for the MCMC, the output of the call is hidden, but if you were to call these functions yourself they would display information about the fitting procedure as it occurs] the same marginal model using each method in Section \@ref(other). To match the four Markov chains of length 1000 in HMC above, we then draw 4000 samples from each approximate posterior. ```{r results='hide'} diff --git a/vignettes/ebola.Rmd b/vignettes/ebola.Rmd index bd1c199fe..9e1b176c5 100644 --- a/vignettes/ebola.Rmd +++ b/vignettes/ebola.Rmd @@ -216,14 +216,14 @@ Second, because we also did not supply an observation time column (`obs_date`), ## Model fitting -To prepare the data for use with the latent individual model, we define the data as being a `epidist_latent_model` model object: +To prepare the data for use with the marginal model, we define the data as being a `epidist_marginal_model` model object: ```{r} -obs_prep <- as_epidist_latent_model(linelist_data) +obs_prep <- as_epidist_marginal_model(linelist_data) head(obs_prep) ``` -Now we are ready to fit the latent individual model. +Now we are ready to fit the marginal model. ### Intercept-only model @@ -241,7 +241,7 @@ fit <- epidist( algorithm = "sampling", chains = 2, cores = 2, - refresh = as.integer(interactive()), + refresh = ifelse(interactive(), 250, 0), seed = 1, backend = "cmdstanr" ) @@ -267,7 +267,7 @@ fit_sex <- epidist( algorithm = "sampling", chains = 2, cores = 2, - refresh = as.integer(interactive()), + refresh = ifelse(interactive(), 250, 0), seed = 1, backend = "cmdstanr" ) @@ -298,7 +298,7 @@ fit_sex_district <- epidist( algorithm = "sampling", chains = 2, cores = 2, - refresh = as.integer(interactive()), + refresh = ifelse(interactive(), 250, 0), seed = 1, backend = "cmdstanr" ) @@ -321,9 +321,15 @@ In Figure \@ref(fig:epred) we show the posterior expectation of the delay distri Figure \@ref(fig:epred)B illustrates the higher mean of men as compared with women. ```{r} +# add dummmy variables +add_marginal_dummy_vars <- function(data) { + data |> + mutate(relative_obs_time = NA, pwindow = NA, delay_upr = NA, swindow = NA) +} + epred_draws <- obs_prep |> data_grid(NA) |> - mutate(relative_obs_time = NA, pwindow = NA, swindow = NA) |> + add_marginal_dummy_vars() |> add_epred_draws(fit, dpar = TRUE) epred_base_figure <- epred_draws |> @@ -334,7 +340,7 @@ epred_base_figure <- epred_draws |> epred_draws_sex <- obs_prep |> data_grid(sex) |> - mutate(relative_obs_time = NA, pwindow = NA, swindow = NA) |> + add_marginal_dummy_vars() |> add_epred_draws(fit_sex, dpar = TRUE) epred_sex_figure <- epred_draws_sex |> @@ -345,7 +351,7 @@ epred_sex_figure <- epred_draws_sex |> epred_draws_sex_district <- obs_prep |> data_grid(sex, district) |> - mutate(relative_obs_time = NA, pwindow = NA, swindow = NA) |> + add_marginal_dummy_vars() |> add_epred_draws(fit_sex_district, dpar = TRUE) epred_sex_district_figure <- epred_draws_sex_district |> @@ -376,7 +382,7 @@ For example, for the `mu` parameter in the sex-district stratified model (Figure linpred_draws_sex_district <- obs_prep |> as.data.frame() |> data_grid(sex, district) |> - mutate(relative_obs_time = NA, pwindow = NA, swindow = NA) |> + add_marginal_dummy_vars() |> add_linpred_draws(fit_sex_district, dpar = TRUE) ``` @@ -400,13 +406,19 @@ In this section, we demonstrate how to produce either a discrete probability mas ### Discrete probability mass function To generate a discrete probability mass function (PMF) we predict the delay distribution that would be observed with daily censoring and no right truncation. -To do this, we set each of `pwindow` and `swindow` to 1 for daily censoring, and `relative_obs_time` to 1000 for no censoring. +To do this, we set each of `pwindow` and `swindow` to 1 for daily censoring, and `relative_obs_time` to `Inf` for no censoring. Figure \@ref(fig:pmf) shows the result, where the few delays greater than 30 are omitted from the figure. ```{r} +add_marginal_pmf_vars <- function(data) { + data |> + mutate( + relative_obs_time = Inf, pwindow = 1, swindow = 1, delay_upr = NA + ) +} + draws_pmf <- obs_prep |> - as.data.frame() |> - mutate(relative_obs_time = 1000, pwindow = 1, swindow = 1) |> + add_marginal_pmf_vars() |> add_predicted_draws(fit, ndraws = 1000) pmf_base_figure <- ggplot(draws_pmf, aes(x = .prediction)) + @@ -416,9 +428,8 @@ pmf_base_figure <- ggplot(draws_pmf, aes(x = .prediction)) + theme_minimal() draws_sex_pmf <- obs_prep |> - as.data.frame() |> data_grid(sex) |> - mutate(relative_obs_time = 1000, pwindow = 1, swindow = 1) |> + add_marginal_pmf_vars() |> add_predicted_draws(fit_sex, ndraws = 1000) pmf_sex_figure <- draws_sex_pmf |> @@ -430,9 +441,8 @@ pmf_sex_figure <- draws_sex_pmf |> theme_minimal() draws_sex_district_pmf <- obs_prep |> - as.data.frame() |> data_grid(sex, district) |> - mutate(relative_obs_time = 1000, pwindow = 1, swindow = 1) |> + add_marginal_pmf_vars() |> add_predicted_draws(fit_sex_district, ndraws = 1000) pmf_sex_district_figure <- draws_sex_district_pmf |> @@ -468,9 +478,15 @@ The posterior predictive distribution under no truncation and no censoring. That is to produce continuous delay times (Figure \@ref(fig:pdf)): ```{r} +add_marginal_pdf_vars <- function(data) { + data |> + mutate( + relative_obs_time = Inf, pwindow = 0, swindow = 0, delay_upr = NA + ) +} + draws_pdf <- obs_prep |> - as.data.frame() |> - mutate(relative_obs_time = 1000, pwindow = 0, swindow = 0) |> + add_marginal_pdf_vars() |> add_predicted_draws(fit, ndraws = 1000) pdf_base_figure <- ggplot(draws_pdf, aes(x = .prediction)) + @@ -480,9 +496,8 @@ pdf_base_figure <- ggplot(draws_pdf, aes(x = .prediction)) + theme_minimal() draws_sex_pdf <- obs_prep |> - as.data.frame() |> data_grid(sex) |> - mutate(relative_obs_time = 1000, pwindow = 0, swindow = 0) |> + add_marginal_pdf_vars() |> add_predicted_draws(fit_sex, ndraws = 1000) pdf_sex_figure <- draws_sex_pdf |> @@ -494,9 +509,8 @@ pdf_sex_figure <- draws_sex_pdf |> theme_minimal() draws_sex_district_pdf <- obs_prep |> - as.data.frame() |> data_grid(sex, district) |> - mutate(relative_obs_time = 1000, pwindow = 0, swindow = 0) |> + add_marginal_pdf_vars() |> add_predicted_draws(fit_sex_district, ndraws = 1000) pdf_sex_district_figure <- draws_sex_district_pdf |> diff --git a/vignettes/epidist.Rmd b/vignettes/epidist.Rmd index e7c8e807c..4198b4618 100644 --- a/vignettes/epidist.Rmd +++ b/vignettes/epidist.Rmd @@ -274,7 +274,7 @@ linelist_data <- as_epidist_linelist_data( obs_time = obs_cens_trunc_samp$obs_time ) -data <- as_epidist_latent_model(linelist_data) +data <- as_epidist_marginal_model(linelist_data) class(data) ``` @@ -285,7 +285,7 @@ In particular, we use the the No-U-Turn Sampler (NUTS) Markov chain Monte Carlo ```{r} fit <- epidist( - data = data, chains = 2, cores = 2, refresh = as.integer(interactive()) + data = data, chains = 2, cores = 2, refresh = ifelse(interactive(), 250, 0) ) ``` diff --git a/vignettes/faq.Rmd b/vignettes/faq.Rmd index 3343eaf24..a24aef918 100644 --- a/vignettes/faq.Rmd +++ b/vignettes/faq.Rmd @@ -61,12 +61,16 @@ linelist_data <- as_epidist_linelist_data( obs_cens_trunc_samp$stime_upr, obs_time = obs_cens_trunc_samp$obs_time ) -data <- as_epidist_latent_model(linelist_data) +data <- as_epidist_marginal_model(linelist_data) fit <- epidist( data, formula = mu ~ 1, - seed = 1 + seed = 1, + chains = 2, + cores = 2, + refresh = ifelse(interactive(), 250, 0), + backend = "cmdstanr" ) ``` @@ -151,7 +155,8 @@ fit_ppc <- epidist( formula = mu ~ 1, family = lognormal(), sample_prior = "only", - seed = 1 + seed = 1, + backend = "cmdstanr" ) ``` @@ -217,8 +222,10 @@ To see these functions demonstrated in a vignette, see ["Advanced features with As a short example, to generate 4000 predictions (equal to the number of draws) of the delay that would be observed with a double censored observation process (in which the primary and secondary censoring windows are both one) then: ```{r} -draws_pmf <- data.frame(relative_obs_time = 1000, pwindow = 1, swindow = 1) |> - add_predicted_draws(fit, ndraws = 4000) +draws_pmf <- tibble::tibble( + relative_obs_time = Inf, pwindow = 1, swindow = 1, delay_upr = NA +) |> + add_predicted_draws(fit, ndraws = 2000) ggplot(draws_pmf, aes(x = .prediction)) + geom_bar(aes(y = after_stat(count / sum(count)))) + @@ -227,10 +234,7 @@ ggplot(draws_pmf, aes(x = .prediction)) + theme_minimal() ``` -Importantly, this functionality is only available for `epidist` models using custom `brms` families that have `posterior_predict` and `posterior_epred` methods implemented. -For example, for the `epidist_latent_model` model, currently methods are implemented for the [lognormal](https://github.com/epinowcast/epidist/blob/main/R/latent_lognormal.R) and [gamma](https://github.com/epinowcast/epidist/blob/main/R/latent_gamma.R) families. -If you are using another family, consider [submitting a pull request](https://github.com/epinowcast/epidist/pulls) to implement these methods! -In doing so, you may find it useful to use the [`primarycensored`](https://primarycensored.epinowcast.org/) package. +Importantly, this functionality is only available for `epidist` models using `brms` families that have a `log_lik_censor` method implemented internally in `brms`. If you are using another family, consider [submitting a pull request](https://github.com/epinowcast/epidist/pulls) to implement these methods! # How can I use the `cmdstanr` backend?