Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

934: fix detection of gaps for forecasts #935

Merged
merged 4 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 2 additions & 36 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -411,35 +411,6 @@ create_obs_model <- function(obs = obs_opts(), dates) {
return(data)
}

##' Create forecast settings
##'
##' @param forecast A list of options as generated by [forecast_opts()] defining
##' the forecast opitions. Defaults to [forecast_opts()]. If NULL then no
##' forecasting will be done.
##' @inheritParams create_stan_data
##' @return A list of settings ready to be passed to stan defining
##' the Observation Model
##' @keywords internal
create_forecast_data <- function(forecast = forecast_opts(), data) {
if (forecast$infer_accumulate && any(data$accumulate)) {
accumulation_times <- which(!data$accumulate)
gaps <- unique(diff(accumulation_times))
if (length(gaps) == 1 && gaps > 1) { ## all gaps are the same
forecast$accumulate <- gaps
cli_inform(c(
"i" = "Forecasts accumulated every {gaps} days, same as accumulation
used in the likelihood. To change this behaviour or silence this
message set {.var accumulate} explicitly in {.fn forecast_opts}."
))
}
}
data <- list(
horizon = forecast$horizon,
future_accumulate = forecast$accumulate
)
return(data)
}

#' Create Stan Data Required for estimate_infections
#'
#' @description`r lifecycle::badge("stable")`
Expand All @@ -458,7 +429,6 @@ create_forecast_data <- function(forecast = forecast_opts(), data) {
#' @inheritParams create_obs_model
#' @inheritParams create_rt_data
#' @inheritParams create_backcalc_data
#' @inheritParams create_forecast_data
#' @importFrom stats lm
#' @importFrom purrr safely
#' @return A list of stan data
Expand Down Expand Up @@ -490,12 +460,8 @@ create_stan_data <- function(data, seeding_time, rt, gp, obs, backcalc,
shifted_cases = shifted_cases,
t = length(data$date),
burn_in = 0,
seeding_time = seeding_time
)
# add forecast data
stan_data <- c(
stan_data,
create_forecast_data(forecast, cases)
seeding_time = seeding_time,
horizon = forecast$horizon
)
# add Rt data
stan_data <- c(
Expand Down
14 changes: 11 additions & 3 deletions R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
#' used as the `truncation` argument here, thereby propagating the uncertainty
#' in the estimate.
#'
#' @param forecast A list of options as generated by [forecast_opts()] defining
#' the forecast opitions. Defaults to [forecast_opts()]. If NULL then no
#' forecasting will be done.
#'
#' @param horizon Deprecated; use `forecast` instead to specify the predictive
#' horizon
#'
Expand All @@ -65,7 +69,6 @@
#' [estimate_truncation()]
#' @inheritParams create_stan_args
#' @inheritParams create_stan_data
#' @inheritParams create_forecast_data
#' @inheritParams create_gp_data
#' @inheritParams fit_model_with_nuts
#' @inheritParams create_clean_reported_cases
Expand Down Expand Up @@ -225,9 +228,14 @@ estimate_infections <- function(data,

## add forecast horizon if forecasting is required
if (forecast$horizon > 0) {
reported_cases <- add_horizon(
reported_cases, forecast$horizon, forecast$accumulate
horizon_args <- list(
data = reported_cases,
horizon = forecast$horizon
)
if (!is.null(forecast$accumulate)) {
horizon_args$accumulate <- forecast$accumulate
}
reported_cases <- do.call(add_horizon, horizon_args)
}

# Create clean and complete cases
Expand Down
8 changes: 3 additions & 5 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -1113,13 +1113,11 @@ stan_opts <- function(object = NULL,
#' forecast_opts(horizon = 28, accumulate = 7)
forecast_opts <- function(horizon = 7, accumulate) {
opts <- list(
horizon = horizon,
infer_accumulate = missing(accumulate)
horizon = horizon
)
if (missing(accumulate)) {
accumulate <- 1
if (!missing(accumulate)) {
opts$accumulate <- accumulate
}
opts$accumulate <- accumulate
attr(opts, "class") <- c("forecast_opts", class(opts))
return(opts)
}
Expand Down
18 changes: 16 additions & 2 deletions R/preprocessing.R
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ default_fill_missing_obs <- function(data, obs, obs_column) {
##' [estimate_secondary()]. See the documentation there for the expected
##' format.
##' @param accumulate The number of days to accumulate when generating posterior
##' prediction, e.g. 7 for weekly accumulated forecasts.
##' prediction, e.g. 7 for weekly accumulated forecasts. If this is not set an
##' attempt will be made to detect the accumulation frequency in the data.
##' @inheritParams fill_missing
##' @inheritParams estimate_infections
##' @importFrom data.table copy merge.data.table setDT
Expand All @@ -210,8 +211,21 @@ add_horizon <- function(data, horizon, accumulate = 1L,
.(date = seq(max(date) + 1, max(date) + horizon, by = "days")),
by = by
]
## if we accumulate add the column
## detect accumulation
if (missing(accumulate) && "accumulate" %in% colnames(data)) {
accumulation_times <- which(!data$accumulate)
gaps <- unique(diff(accumulation_times))
if (length(gaps) == 1 && gaps > 1) { ## all gaps are the same
accumulate <- gaps
cli_inform(c(
"i" = "Forecasts accumulated every {gaps} days, same as accumulation
used in the likelihood. To change this behaviour or silence this
message set {.var accumulate} explicitly in {.fn forecast_opts}."
))
}
}
if (accumulate > 1 || "accumulate" %in% colnames(data)) {
## if we accumulate add the column
initial_future_accumulate <- sum(cumsum(rev(!data$accumulate)) == 0)
reported_cases_future[, accumulate := TRUE]
## set accumulation to FALSE where appropriate
Expand Down
3 changes: 2 additions & 1 deletion man/add_horizon.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 0 additions & 30 deletions man/create_forecast_data.Rd

This file was deleted.

2 changes: 1 addition & 1 deletion man/create_stan_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/epinow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/estimate_infections.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/forecast_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/regional_epinow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 0 additions & 27 deletions tests/testthat/test-create_forecast_data.R

This file was deleted.

5 changes: 2 additions & 3 deletions tests/testthat/test-forecast-opts.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
test_that("forecast_opts returns correct default values", {
forecast <- forecast_opts()
expect_equal(forecast$horizon, 7)
expect_equal(forecast$accumulate, 1)
expect_equal(forecast$infer_accumulate, TRUE)
expect_null(forecast$accumulate)
})

test_that("forecast_opts sets infer_accumulate to FALSE if accumulate is given", {
forecast <- forecast_opts(accumulate = 7)
expect_equal(forecast$infer_accumulate, FALSE)
expect_equal(forecast$accumulate, 7)
})
20 changes: 20 additions & 0 deletions tests/testthat/test-preprocessing.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,23 @@ test_that("add_horizon works", {
"Initial data point not marked as accumulated"
)
})

test_that("add_horizon identifies gaps correctly", {
filled <- fill_missing(cases, missing_dates = "accumulate", initial_accumulate = 7)
expect_message(
result <- add_horizon(filled, horizon = 7),
"Forecasts accumulated every 7 days"
)
result <- add_horizon(filled, horizon = 7, accumulate = 7)
expect_true(all(result[seq(.N - 6, .N - 1), accumulate]))
expect_false(result[.N, accumulate])
})

test_that("add_horizon doesn't try to identify non-equally spaced gaps", {
reported_irregular <- example_confirmed[c(seq(1, 43, by = 7), 45)]
filled <- suppressWarnings(
fill_missing(reported_irregular, missing_dates = "accumulate")
)
result <- add_horizon(filled, horizon = 7)
expect_false(any(result[seq(.N - 6, .N), accumulate]))
})
Loading