Skip to content

Commit

Permalink
934: fix detection of gaps for forecasts (#935)
Browse files Browse the repository at this point in the history
* fix detection of gaps for forecasts

* update docs

* update docs
  • Loading branch information
sbfnk authored Jan 30, 2025
1 parent 046857e commit a130dac
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 107 deletions.
38 changes: 2 additions & 36 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -411,35 +411,6 @@ create_obs_model <- function(obs = obs_opts(), dates) {
return(data)
}

##' Create forecast settings
##'
##' @param forecast A list of options as generated by [forecast_opts()] defining
##' the forecast opitions. Defaults to [forecast_opts()]. If NULL then no
##' forecasting will be done.
##' @inheritParams create_stan_data
##' @return A list of settings ready to be passed to stan defining
##' the Observation Model
##' @keywords internal
create_forecast_data <- function(forecast = forecast_opts(), data) {
if (forecast$infer_accumulate && any(data$accumulate)) {
accumulation_times <- which(!data$accumulate)
gaps <- unique(diff(accumulation_times))
if (length(gaps) == 1 && gaps > 1) { ## all gaps are the same
forecast$accumulate <- gaps
cli_inform(c(
"i" = "Forecasts accumulated every {gaps} days, same as accumulation
used in the likelihood. To change this behaviour or silence this
message set {.var accumulate} explicitly in {.fn forecast_opts}."
))
}
}
data <- list(
horizon = forecast$horizon,
future_accumulate = forecast$accumulate
)
return(data)
}

#' Create Stan Data Required for estimate_infections
#'
#' @description`r lifecycle::badge("stable")`
Expand All @@ -458,7 +429,6 @@ create_forecast_data <- function(forecast = forecast_opts(), data) {
#' @inheritParams create_obs_model
#' @inheritParams create_rt_data
#' @inheritParams create_backcalc_data
#' @inheritParams create_forecast_data
#' @importFrom stats lm
#' @importFrom purrr safely
#' @return A list of stan data
Expand Down Expand Up @@ -490,12 +460,8 @@ create_stan_data <- function(data, seeding_time, rt, gp, obs, backcalc,
shifted_cases = shifted_cases,
t = length(data$date),
burn_in = 0,
seeding_time = seeding_time
)
# add forecast data
stan_data <- c(
stan_data,
create_forecast_data(forecast, cases)
seeding_time = seeding_time,
horizon = forecast$horizon
)
# add Rt data
stan_data <- c(
Expand Down
14 changes: 11 additions & 3 deletions R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
#' used as the `truncation` argument here, thereby propagating the uncertainty
#' in the estimate.
#'
#' @param forecast A list of options as generated by [forecast_opts()] defining
#' the forecast opitions. Defaults to [forecast_opts()]. If NULL then no
#' forecasting will be done.
#'
#' @param horizon Deprecated; use `forecast` instead to specify the predictive
#' horizon
#'
Expand All @@ -65,7 +69,6 @@
#' [estimate_truncation()]
#' @inheritParams create_stan_args
#' @inheritParams create_stan_data
#' @inheritParams create_forecast_data
#' @inheritParams create_gp_data
#' @inheritParams fit_model_with_nuts
#' @inheritParams create_clean_reported_cases
Expand Down Expand Up @@ -225,9 +228,14 @@ estimate_infections <- function(data,

## add forecast horizon if forecasting is required
if (forecast$horizon > 0) {
reported_cases <- add_horizon(
reported_cases, forecast$horizon, forecast$accumulate
horizon_args <- list(
data = reported_cases,
horizon = forecast$horizon
)
if (!is.null(forecast$accumulate)) {
horizon_args$accumulate <- forecast$accumulate
}
reported_cases <- do.call(add_horizon, horizon_args)
}

# Create clean and complete cases
Expand Down
8 changes: 3 additions & 5 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -1113,13 +1113,11 @@ stan_opts <- function(object = NULL,
#' forecast_opts(horizon = 28, accumulate = 7)
forecast_opts <- function(horizon = 7, accumulate) {
opts <- list(
horizon = horizon,
infer_accumulate = missing(accumulate)
horizon = horizon
)
if (missing(accumulate)) {
accumulate <- 1
if (!missing(accumulate)) {
opts$accumulate <- accumulate
}
opts$accumulate <- accumulate
attr(opts, "class") <- c("forecast_opts", class(opts))
return(opts)
}
Expand Down
18 changes: 16 additions & 2 deletions R/preprocessing.R
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ default_fill_missing_obs <- function(data, obs, obs_column) {
##' [estimate_secondary()]. See the documentation there for the expected
##' format.
##' @param accumulate The number of days to accumulate when generating posterior
##' prediction, e.g. 7 for weekly accumulated forecasts.
##' prediction, e.g. 7 for weekly accumulated forecasts. If this is not set an
##' attempt will be made to detect the accumulation frequency in the data.
##' @inheritParams fill_missing
##' @inheritParams estimate_infections
##' @importFrom data.table copy merge.data.table setDT
Expand All @@ -210,8 +211,21 @@ add_horizon <- function(data, horizon, accumulate = 1L,
.(date = seq(max(date) + 1, max(date) + horizon, by = "days")),
by = by
]
## if we accumulate add the column
## detect accumulation
if (missing(accumulate) && "accumulate" %in% colnames(data)) {
accumulation_times <- which(!data$accumulate)
gaps <- unique(diff(accumulation_times))
if (length(gaps) == 1 && gaps > 1) { ## all gaps are the same
accumulate <- gaps
cli_inform(c(
"i" = "Forecasts accumulated every {gaps} days, same as accumulation
used in the likelihood. To change this behaviour or silence this
message set {.var accumulate} explicitly in {.fn forecast_opts}."
))
}
}
if (accumulate > 1 || "accumulate" %in% colnames(data)) {
## if we accumulate add the column
initial_future_accumulate <- sum(cumsum(rev(!data$accumulate)) == 0)
reported_cases_future[, accumulate := TRUE]
## set accumulation to FALSE where appropriate
Expand Down
3 changes: 2 additions & 1 deletion man/add_horizon.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 0 additions & 30 deletions man/create_forecast_data.Rd

This file was deleted.

27 changes: 0 additions & 27 deletions tests/testthat/test-create_forecast_data.R

This file was deleted.

5 changes: 2 additions & 3 deletions tests/testthat/test-forecast-opts.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
test_that("forecast_opts returns correct default values", {
forecast <- forecast_opts()
expect_equal(forecast$horizon, 7)
expect_equal(forecast$accumulate, 1)
expect_equal(forecast$infer_accumulate, TRUE)
expect_null(forecast$accumulate)
})

test_that("forecast_opts sets infer_accumulate to FALSE if accumulate is given", {
forecast <- forecast_opts(accumulate = 7)
expect_equal(forecast$infer_accumulate, FALSE)
expect_equal(forecast$accumulate, 7)
})
20 changes: 20 additions & 0 deletions tests/testthat/test-preprocessing.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,23 @@ test_that("add_horizon works", {
"Initial data point not marked as accumulated"
)
})

test_that("add_horizon identifies gaps correctly", {
filled <- fill_missing(cases, missing_dates = "accumulate", initial_accumulate = 7)
expect_message(
result <- add_horizon(filled, horizon = 7),
"Forecasts accumulated every 7 days"
)
result <- add_horizon(filled, horizon = 7, accumulate = 7)
expect_true(all(result[seq(.N - 6, .N - 1), accumulate]))
expect_false(result[.N, accumulate])
})

test_that("add_horizon doesn't try to identify non-equally spaced gaps", {
reported_irregular <- example_confirmed[c(seq(1, 43, by = 7), 45)]
filled <- suppressWarnings(
fill_missing(reported_irregular, missing_dates = "accumulate")
)
result <- add_horizon(filled, horizon = 7)
expect_false(any(result[seq(.N - 6, .N), accumulate]))
})

0 comments on commit a130dac

Please sign in to comment.