diff --git a/R/create.R b/R/create.R index 5eac4a5dc..07dda74eb 100644 --- a/R/create.R +++ b/R/create.R @@ -411,35 +411,6 @@ create_obs_model <- function(obs = obs_opts(), dates) { return(data) } -##' Create forecast settings -##' -##' @param forecast A list of options as generated by [forecast_opts()] defining -##' the forecast opitions. Defaults to [forecast_opts()]. If NULL then no -##' forecasting will be done. -##' @inheritParams create_stan_data -##' @return A list of settings ready to be passed to stan defining -##' the Observation Model -##' @keywords internal -create_forecast_data <- function(forecast = forecast_opts(), data) { - if (forecast$infer_accumulate && any(data$accumulate)) { - accumulation_times <- which(!data$accumulate) - gaps <- unique(diff(accumulation_times)) - if (length(gaps) == 1 && gaps > 1) { ## all gaps are the same - forecast$accumulate <- gaps - cli_inform(c( - "i" = "Forecasts accumulated every {gaps} days, same as accumulation - used in the likelihood. To change this behaviour or silence this - message set {.var accumulate} explicitly in {.fn forecast_opts}." - )) - } - } - data <- list( - horizon = forecast$horizon, - future_accumulate = forecast$accumulate - ) - return(data) -} - #' Create Stan Data Required for estimate_infections #' #' @description`r lifecycle::badge("stable")` @@ -458,7 +429,6 @@ create_forecast_data <- function(forecast = forecast_opts(), data) { #' @inheritParams create_obs_model #' @inheritParams create_rt_data #' @inheritParams create_backcalc_data -#' @inheritParams create_forecast_data #' @importFrom stats lm #' @importFrom purrr safely #' @return A list of stan data @@ -490,12 +460,8 @@ create_stan_data <- function(data, seeding_time, rt, gp, obs, backcalc, shifted_cases = shifted_cases, t = length(data$date), burn_in = 0, - seeding_time = seeding_time - ) - # add forecast data - stan_data <- c( - stan_data, - create_forecast_data(forecast, cases) + seeding_time = seeding_time, + horizon = forecast$horizon ) # add Rt data stan_data <- c( diff --git a/R/estimate_infections.R b/R/estimate_infections.R index 95c875f3c..f3e7b7af6 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -43,6 +43,10 @@ #' used as the `truncation` argument here, thereby propagating the uncertainty #' in the estimate. #' +#' @param forecast A list of options as generated by [forecast_opts()] defining +#' the forecast opitions. Defaults to [forecast_opts()]. If NULL then no +#' forecasting will be done. +#' #' @param horizon Deprecated; use `forecast` instead to specify the predictive #' horizon #' @@ -65,7 +69,6 @@ #' [estimate_truncation()] #' @inheritParams create_stan_args #' @inheritParams create_stan_data -#' @inheritParams create_forecast_data #' @inheritParams create_gp_data #' @inheritParams fit_model_with_nuts #' @inheritParams create_clean_reported_cases @@ -225,9 +228,14 @@ estimate_infections <- function(data, ## add forecast horizon if forecasting is required if (forecast$horizon > 0) { - reported_cases <- add_horizon( - reported_cases, forecast$horizon, forecast$accumulate + horizon_args <- list( + data = reported_cases, + horizon = forecast$horizon ) + if (!is.null(forecast$accumulate)) { + horizon_args$accumulate <- forecast$accumulate + } + reported_cases <- do.call(add_horizon, horizon_args) } # Create clean and complete cases diff --git a/R/opts.R b/R/opts.R index 95e5c96fa..45cca597b 100644 --- a/R/opts.R +++ b/R/opts.R @@ -1113,13 +1113,11 @@ stan_opts <- function(object = NULL, #' forecast_opts(horizon = 28, accumulate = 7) forecast_opts <- function(horizon = 7, accumulate) { opts <- list( - horizon = horizon, - infer_accumulate = missing(accumulate) + horizon = horizon ) - if (missing(accumulate)) { - accumulate <- 1 + if (!missing(accumulate)) { + opts$accumulate <- accumulate } - opts$accumulate <- accumulate attr(opts, "class") <- c("forecast_opts", class(opts)) return(opts) } diff --git a/R/preprocessing.R b/R/preprocessing.R index a7aa7ac71..6056f8022 100644 --- a/R/preprocessing.R +++ b/R/preprocessing.R @@ -185,7 +185,8 @@ default_fill_missing_obs <- function(data, obs, obs_column) { ##' [estimate_secondary()]. See the documentation there for the expected ##' format. ##' @param accumulate The number of days to accumulate when generating posterior -##' prediction, e.g. 7 for weekly accumulated forecasts. +##' prediction, e.g. 7 for weekly accumulated forecasts. If this is not set an +##' attempt will be made to detect the accumulation frequency in the data. ##' @inheritParams fill_missing ##' @inheritParams estimate_infections ##' @importFrom data.table copy merge.data.table setDT @@ -210,8 +211,21 @@ add_horizon <- function(data, horizon, accumulate = 1L, .(date = seq(max(date) + 1, max(date) + horizon, by = "days")), by = by ] - ## if we accumulate add the column + ## detect accumulation + if (missing(accumulate) && "accumulate" %in% colnames(data)) { + accumulation_times <- which(!data$accumulate) + gaps <- unique(diff(accumulation_times)) + if (length(gaps) == 1 && gaps > 1) { ## all gaps are the same + accumulate <- gaps + cli_inform(c( + "i" = "Forecasts accumulated every {gaps} days, same as accumulation + used in the likelihood. To change this behaviour or silence this + message set {.var accumulate} explicitly in {.fn forecast_opts}." + )) + } + } if (accumulate > 1 || "accumulate" %in% colnames(data)) { + ## if we accumulate add the column initial_future_accumulate <- sum(cumsum(rev(!data$accumulate)) == 0) reported_cases_future[, accumulate := TRUE] ## set accumulation to FALSE where appropriate diff --git a/man/add_horizon.Rd b/man/add_horizon.Rd index 72b5ad18d..f3590cdef 100644 --- a/man/add_horizon.Rd +++ b/man/add_horizon.Rd @@ -16,7 +16,8 @@ format.} horizon} \item{accumulate}{The number of days to accumulate when generating posterior -prediction, e.g. 7 for weekly accumulated forecasts.} +prediction, e.g. 7 for weekly accumulated forecasts. If this is not set an +attempt will be made to detect the accumulation frequency in the data.} \item{obs_column}{Character (default: "confirm"). If given, only the column specified here will be used for checking missingness. This is useful if diff --git a/man/create_forecast_data.Rd b/man/create_forecast_data.Rd deleted file mode 100644 index 2e2ae5bb0..000000000 --- a/man/create_forecast_data.Rd +++ /dev/null @@ -1,30 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/create.R -\name{create_forecast_data} -\alias{create_forecast_data} -\title{Create forecast settings} -\usage{ -create_forecast_data(forecast = forecast_opts(), data) -} -\arguments{ -\item{forecast}{A list of options as generated by \code{\link[=forecast_opts]{forecast_opts()}} defining -the forecast opitions. Defaults to \code{\link[=forecast_opts]{forecast_opts()}}. If NULL then no -forecasting will be done.} - -\item{data}{A \verb{} of disease reports (confirm) by date (date). -\code{confirm} must be numeric and \code{date} must be in date format. Optionally -this can also have a logical \code{accumulate} column which indicates whether -data should be added to the next data point. This is useful when modelling -e.g. weekly incidence data. See also the \code{\link[=fill_missing]{fill_missing()}} function which -helps add the \code{accumulate} column with the desired properties when dealing -with non-daily data. If any accumulation is done this happens after -truncation as specified by the \code{truncation} argument.} -} -\value{ -A list of settings ready to be passed to stan defining -the Observation Model -} -\description{ -Create forecast settings -} -\keyword{internal} diff --git a/tests/testthat/test-create_forecast_data.R b/tests/testthat/test-create_forecast_data.R deleted file mode 100644 index 31acda247..000000000 --- a/tests/testthat/test-create_forecast_data.R +++ /dev/null @@ -1,27 +0,0 @@ -test_that("create_forecast_data returns expected default values", { - result <- create_forecast_data(data = example_confirmed) - - expect_type(result, "list") - expect_equal(result$horizon, 7) - expect_equal(result$future_accumulate, 1) -}) - -test_that("create_rt_data identifies gaps correctly", { - reported_weekly <- suppressWarnings(fill_missing( - example_confirmed[seq(1, 60, by = 7)], - missing_dates = "accumulate" - )) - expect_message( - result <- create_forecast_data(data = reported_weekly), - "same as accumulation used in the likelihood" - ) - expect_equal(result$future_accumulate, 7) -}) - -test_that("create_rt_data doesn't try to identify non-equally spaced gaps", { - reported_irregular <- example_confirmed[c(seq(1, 43, by = 7), 45)] - expect_no_message( - result <- create_forecast_data(data = reported_irregular) - ) - expect_equal(result$future_accumulate, 1) -}) diff --git a/tests/testthat/test-forecast-opts.R b/tests/testthat/test-forecast-opts.R index 6e5ac5ac5..005c3cb70 100644 --- a/tests/testthat/test-forecast-opts.R +++ b/tests/testthat/test-forecast-opts.R @@ -1,11 +1,10 @@ test_that("forecast_opts returns correct default values", { forecast <- forecast_opts() expect_equal(forecast$horizon, 7) - expect_equal(forecast$accumulate, 1) - expect_equal(forecast$infer_accumulate, TRUE) + expect_null(forecast$accumulate) }) test_that("forecast_opts sets infer_accumulate to FALSE if accumulate is given", { forecast <- forecast_opts(accumulate = 7) - expect_equal(forecast$infer_accumulate, FALSE) + expect_equal(forecast$accumulate, 7) }) diff --git a/tests/testthat/test-preprocessing.R b/tests/testthat/test-preprocessing.R index 02b1621c9..667ae5f6f 100644 --- a/tests/testthat/test-preprocessing.R +++ b/tests/testthat/test-preprocessing.R @@ -40,3 +40,23 @@ test_that("add_horizon works", { "Initial data point not marked as accumulated" ) }) + +test_that("add_horizon identifies gaps correctly", { + filled <- fill_missing(cases, missing_dates = "accumulate", initial_accumulate = 7) + expect_message( + result <- add_horizon(filled, horizon = 7), + "Forecasts accumulated every 7 days" + ) + result <- add_horizon(filled, horizon = 7, accumulate = 7) + expect_true(all(result[seq(.N - 6, .N - 1), accumulate])) + expect_false(result[.N, accumulate]) +}) + +test_that("add_horizon doesn't try to identify non-equally spaced gaps", { + reported_irregular <- example_confirmed[c(seq(1, 43, by = 7), 45)] + filled <- suppressWarnings( + fill_missing(reported_irregular, missing_dates = "accumulate") + ) + result <- add_horizon(filled, horizon = 7) + expect_false(any(result[seq(.N - 6, .N), accumulate])) +})