Skip to content

Commit

Permalink
Issue 442: Generalise to all brms distributions (#459)
Browse files Browse the repository at this point in the history
* update gitignore to ignore local build paths for testing

* breaking rename

* start work on primarycensored port in and generalised predict and epred

* check new implementatons against ebola vignette

* refine posterior prediction and make vignettes faster

* rename and move around

* refactor tests

* check tests

* refine tess

* add note about direct usage failing but tidybayes working

Former-commit-id: 95b48ca
Former-commit-id: cf73454e616950a2f6debdc2d039d443fe5537ab
Former-commit-id: efd8befdd0b9b9e8f3f7abaefe9f81590bc3ad35 [formerly f61ea80]
Former-commit-id: aff38d5ec25ac9be287314158163bfafd6d8fa3a
  • Loading branch information
seabbs authored Nov 21, 2024
1 parent 926f178 commit dc5779a
Show file tree
Hide file tree
Showing 32 changed files with 509 additions and 437 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ data/models/*reparam
docs
/doc/
/Meta/
vignettes/**_cache/
*.pdf
.vscode/
vignettes/figures/
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ Imports:
rstan (>= 2.26.0),
dplyr,
tibble,
lubridate
lubridate,
primarycensored
Suggests:
bookdown,
testthat (>= 3.0.0),
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ export(epidist_family_prior)
export(epidist_family_reparam)
export(epidist_formula)
export(epidist_formula_model)
export(epidist_gen_posterior_epred)
export(epidist_gen_posterior_predict)
export(epidist_model_prior)
export(epidist_prior)
export(epidist_stancode)
Expand Down
73 changes: 73 additions & 0 deletions R/gen.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#' Create a function to draw from the posterior predictive distribution for a
#' latent model
#'
#' This function creates a function that draws from the posterior predictive
#' distribution for a latent model using [primarycensored::rpcens()] to handle
#' censoring and truncation. The returned function takes a `prep` argument from
#' `brms` and returns posterior predictions. This is used internally by
#' [brms::posterior_predict()] to generate predictions for latent models.
#'
#' @inheritParams epidist_family
#'
#' @return A function that takes a `prep` argument from brms and returns a
#' matrix of posterior predictions, with one row per posterior draw and one
#' column per observation. The `prep` object must have the following variables:
#' * `vreal1`: relative observation time
#' * `vreal2`: primary event window
#' * `vreal3`: secondary event window
#'
#' @seealso [brms::posterior_predict()] for details on how this is used within
#' `brms`, [primarycensored::rpcens()] for details on the censoring approach
#' @autoglobal
#' @family gen
#' @export
epidist_gen_posterior_predict <- function(family) {
dist_fn <- .get_brms_fn("posterior_predict", family)

rdist <- function(n, i, prep, ...) {
prep$ndraws <- n
do.call(dist_fn, list(i = i, prep = prep))
}

.predict <- function(i, prep, ...) {
relative_obs_time <- prep$data$vreal1[i]
pwindow <- prep$data$vreal2[i]
swindow <- prep$data$vreal3[i]

as.matrix(primarycensored::rpcens(
n = prep$ndraws,
rdist = rdist,
rprimary = stats::runif,
pwindow = prep$data$vreal2[i],
swindow = prep$data$vreal3[i],
D = prep$data$vreal1[i],
i = i,
prep = prep
))
}
return(.predict)
}

#' Create a function to draw from the expected value of the posterior predictive
#' distribution for a latent model
#'
#' This function creates a function that calculates the expected value of the
#' posterior predictive distribution for a latent model. The returned function
#' takes a `prep` argument (from brms) and returns posterior expected values.
#' This is used internally by [brms::posterior_epred()] to calculate expected
#' values for latent models.
#'
#' @inheritParams epidist_family
#'
#' @return A function that takes a prep argument from brms and returns a matrix
#' of posterior expected values, with one row per posterior draw and one column
#' per observation.
#'
#' @seealso [brms::posterior_epred()] for details on how this is used within
#' `brms`.
#' @autoglobal
#' @family gen
#' @export
epidist_gen_posterior_epred <- function(family) {
.get_brms_fn("posterior_epred", family)
}
7 changes: 0 additions & 7 deletions R/globals.R
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
# Generated by roxyglobals: do not edit by hand

utils::globalVariables(c(
".data", # <epidist_diagnostics>
"samples", # <epidist_diagnostics>
".data", # <as_epidist_latent_model.epidist_linelist_data>
"woverlap", # <epidist_stancode.epidist_latent_model>
".data", # <as_epidist_naive_model.epidist_linelist_data>
".data", # <add_mean_sd.lognormal_samples>
".data", # <add_mean_sd.gamma_samples>
"rlnorm", # <simulate_secondary>
".data", # <simulate_secondary>
".data", # <.replace_prior>
"prior_new", # <.replace_prior>
"source_new", # <.replace_prior>
NULL
Expand Down
84 changes: 0 additions & 84 deletions R/latent_gamma.R

This file was deleted.

87 changes: 0 additions & 87 deletions R/latent_lognormal.R

This file was deleted.

66 changes: 65 additions & 1 deletion R/latent_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,76 @@ epidist_family_model.epidist_latent_model <- function(
ub = c(NA, as.numeric(lapply(family$other_bounds, "[[", "ub"))),
type = family$type,
vars = c("pwindow", "swindow", "vreal1"),
loop = FALSE
loop = FALSE,
log_lik = epidist_gen_log_lik_latent(family),
posterior_predict = epidist_gen_posterior_predict(family),
posterior_epred = epidist_gen_posterior_epred(family)
)
custom_family$reparm <- family$reparm
return(custom_family)
}

#' Create a function to calculate the pointwise log likelihood of the latent
#' model
#'
#' This function creates a log likelihood function that accounts for the latent
#' variables in the model, including primary and secondary event windows and
#' their overlap. The returned function calculates the log likelihood for a
#' single observation by augmenting the data with the latent variables and
#' using the underlying brms log likelihood function.
#'
#' @seealso [brms::log_lik()] for details on the brms log likelihood interface.
#'
#' @inheritParams epidist_family
#'
#' @return A function that calculates the log likelihood for a single
#' observation. The prep object must have the following variables:
#' * `vreal1`: relative observation time
#' * `vreal2`: primary event window
#' * `vreal3`: secondary event window
#'
#' @family latent_model
#' @autoglobal
epidist_gen_log_lik_latent <- function(family) {
# Get internal brms log_lik function
log_lik_brms <- .get_brms_fn("log_lik", family)

.log_lik <- function(i, prep) {
y <- prep$data$Y[i]
relative_obs_time <- prep$data$vreal1[i]
pwindow <- prep$data$vreal2[i]
swindow <- prep$data$vreal3[i]

# Generates values of the swindow_raw and pwindow_raw, but really these
# should be extracted from prep or the fitted raws somehow. See:
# https://github.com/epinowcast/epidist/issues/267
swindow_raw <- stats::runif(prep$ndraws)
pwindow_raw <- stats::runif(prep$ndraws)

swindow <- swindow_raw * swindow

# For no overlap calculate as usual, for overlap ensure pwindow < swindow
if (i %in% prep$data$noverlap) {
pwindow <- pwindow_raw * pwindow
} else {
pwindow <- pwindow_raw * swindow
}

d <- y - pwindow + swindow
obs_time <- relative_obs_time - pwindow
# Create brms truncation upper bound
prep$data$ub <- rep(obs_time, length(prep$data$Y))
# Update augmented data
prep$data$Y <- rep(d, length(prep$data$Y))

# Call internal brms log_lik function with augmented data
lpdf <- log_lik_brms(i, prep)
return(lpdf)
}

return(.log_lik)
}

#' Define the model-specific component of an `epidist` custom formula
#'
#' @inheritParams epidist_formula_model
Expand Down
19 changes: 19 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,22 @@

return(df)
}

#' Get a brms function by prefix and family
#'
#' Helper function to get internal brms functions by constructing their name
#' from a prefix and family. Used to get functions like `log_lik_*`,
#' `posterior_predict_*` etc.
#'
#' @param prefix Character string prefix of the brms function to get (e.g.
#' "log_lik")
#'
#' @inheritParams epidist_family
#' @return The requested brms function
#' @keywords internal
.get_brms_fn <- function(prefix, family) {
get(
paste0(prefix, "_", tolower(family$family)),
asNamespace("brms")
)
}
Loading

0 comments on commit dc5779a

Please sign in to comment.