Skip to content

Commit

Permalink
Issue #388: Refactor preprocessing functionality (#390)
Browse files Browse the repository at this point in the history
* Sketching out refactor

* Some restructure for clarity

* Refactor along new date approach

* Clear tests

* Refactor lines to save col_names

* Refactor validation functionality

* Redocument

* Remove epidist_validate and move to _model and _data approach plus some linting

* Add documentation of as_epidist_linelist arguments

* Move assert_class into imports and use in place of "check" class

* Documentation for epidist_validate_data.epidist_linelist

* Clear up the direct model file a bit

* Add creating the row_id back in to as_latent_individual

* Passing test-direct_model

* Start working to make data use dates

* Add start of unit tests and bug fix for datetime class check

* Use .row_id rather than row_id

* Use as_epidist_linelist_time function so that tests work with time data

* Fixes to tests

* Group into preprocessing functions

* Update FAQ vignette to run

* Update get started vignette to run

* Update ebola vignette to run

* Update approximate inference vignette to run

* Add documentation

* Methods consistency

* Document ...

* Again on ...

* Remove comment moved to issue

* Include as_epidist_linelist_time ad-hoc

* Add test for datetime column

* Update text in vignettes and add note about the ad-hoc function being included in package soon

* Refactor .rename_columns

Former-commit-id: c573ba836b76170c03d5c493cbb378781db5fa23 [formerly bac50e38d758dfe0fdcfd98722dc50a5a98c0357]
Former-commit-id: 4e9d3ee55e7e0c90bd35366990f38aa4894b3439
Former-commit-id: 84a5299
athowes authored Nov 13, 2024

Verified

This commit was signed with the committer’s verified signature.
alexsnaps Alex Snaps
1 parent 1c490c4 commit 31fe7ae
Showing 42 changed files with 446 additions and 398 deletions.
17 changes: 11 additions & 6 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@ S3method(add_mean_sd,default)
S3method(add_mean_sd,gamma_samples)
S3method(add_mean_sd,lognormal_samples)
S3method(as_direct_model,data.frame)
S3method(as_latent_individual,data.frame)
S3method(as_latent_individual,epidist_linelist)
S3method(epidist,default)
S3method(epidist_family_model,default)
S3method(epidist_family_model,epidist_latent_individual)
@@ -17,12 +17,14 @@ S3method(epidist_formula_model,epidist_latent_individual)
S3method(epidist_model_prior,default)
S3method(epidist_stancode,default)
S3method(epidist_stancode,epidist_latent_individual)
S3method(epidist_validate,default)
S3method(epidist_validate,epidist_direct_model)
S3method(epidist_validate,epidist_latent_individual)
export(add_event_vars)
S3method(epidist_validate_data,default)
S3method(epidist_validate_data,epidist_linelist)
S3method(epidist_validate_model,default)
S3method(epidist_validate_model,epidist_direct_model)
S3method(epidist_validate_model,epidist_latent_individual)
export(add_mean_sd)
export(as_direct_model)
export(as_epidist_linelist)
export(as_latent_individual)
export(epidist)
export(epidist_diagnostics)
@@ -35,10 +37,12 @@ export(epidist_formula_model)
export(epidist_model_prior)
export(epidist_prior)
export(epidist_stancode)
export(epidist_validate)
export(epidist_validate_data)
export(epidist_validate_model)
export(filter_obs_by_obs_time)
export(filter_obs_by_ptime)
export(is_direct_model)
export(is_epidist_linelist)
export(is_latent_individual)
export(observe_process)
export(predict_delay_parameters)
@@ -50,6 +54,7 @@ export(simulate_uniform_cases)
import(ggplot2)
importFrom(brms,bf)
importFrom(brms,prior)
importFrom(checkmate,assert_class)
importFrom(checkmate,assert_data_frame)
importFrom(checkmate,assert_factor)
importFrom(checkmate,assert_integer)
26 changes: 3 additions & 23 deletions R/direct_model.R
Original file line number Diff line number Diff line change
@@ -15,17 +15,6 @@ assert_direct_model_input <- function(data) {
assert_numeric(data$stime, lower = 0)
}

#' Prepare latent individual model
#'
#' This function prepares data for use with the direct model. It does this by
#' adding columns used in the model to the `data` object provided. To do this,
#' the `data` must already have columns for the case number (integer),
#' (positive, numeric) times for the primary and secondary event times. The
#' output of this function is a `epidist_direct_model` class object, which may
#' be passed to [epidist()] to perform inference for the model.
#'
#' @param data A `data.frame` containing line list data
#' @rdname as_direct_model
#' @method as_direct_model data.frame
#' @family direct_model
#' @autoglobal
@@ -35,23 +24,14 @@ as_direct_model.data.frame <- function(data) {
class(data) <- c("epidist_direct_model", class(data))
data <- data |>
mutate(delay = .data$stime - .data$ptime)
epidist_validate(data)
epidist_validate_model(data)
return(data)
}

#' Validate direct model data
#'
#' This function checks whether the provided `data` object is suitable for
#' running the direct model. As well as making sure that
#' `is_direct_model()` is true, it also checks that `data` is a `data.frame`
#' with the correct columns.
#'
#' @param data A `data.frame` containing line list data
#' @param ... ...
#' @method epidist_validate epidist_direct_model
#' @method epidist_validate_model epidist_direct_model
#' @family direct_model
#' @export
epidist_validate.epidist_direct_model <- function(data, ...) {
epidist_validate_model.epidist_direct_model <- function(data, ...) {
assert_true(is_direct_model(data))
assert_direct_model_input(data)
assert_names(names(data), must.include = c("case", "ptime", "stime", "delay"))
2 changes: 1 addition & 1 deletion R/epidist-package.R
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
#' @importFrom dplyr filter select
#' @importFrom brms bf prior
#' @importFrom checkmate assert_data_frame assert_names assert_integer
#' assert_true assert_factor assert_numeric
#' assert_true assert_factor assert_numeric assert_class
#' @importFrom cli cli_abort cli_inform cli_abort cli_warn
#' @importFrom stats as.formula
## usethis namespace: end
2 changes: 1 addition & 1 deletion R/family.R
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
#' @family family
#' @export
epidist_family <- function(data, family = "lognormal", ...) {
epidist_validate(data)
epidist_validate_model(data)
family <- brms:::validate_family(family)
class(family) <- c(family$family, class(family))
family <- .add_dpar_info(family)
2 changes: 1 addition & 1 deletion R/fit.R
Original file line number Diff line number Diff line change
@@ -39,7 +39,7 @@ epidist <- function(data, formula, family, prior, backend, fn, ...) {
epidist.default <- function(data, formula = mu ~ 1,
family = "lognormal", prior = NULL,
backend = "cmdstanr", fn = brms::brm, ...) {
epidist_validate(data)
epidist_validate_model(data)
epidist_family <- epidist_family(data, family)
epidist_formula <- epidist_formula(
data = data, family = epidist_family, formula = formula
2 changes: 1 addition & 1 deletion R/formula.R
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
#' @family formula
#' @export
epidist_formula <- function(data, family, formula, ...) {
epidist_validate(data)
epidist_validate_model(data)
formula <- brms:::validate_formula(formula, data = data, family = family)
formula <- .make_intercepts_explicit(formula)
formula <- epidist_formula_model(data, formula)
3 changes: 0 additions & 3 deletions R/globals.R
Original file line number Diff line number Diff line change
@@ -2,9 +2,6 @@

utils::globalVariables(c(
"samples", # <epidist_diagnostics>
"stime_lwr", # <as_latent_individual.data.frame>
"stime_upr", # <as_latent_individual.data.frame>
"ptime_upr", # <as_latent_individual.data.frame>
"woverlap", # <epidist_stancode.epidist_latent_individual>
":=", # <filter_obs_by_ptime>
"rlnorm", # <simulate_secondary>
83 changes: 18 additions & 65 deletions R/latent_individual.R
Original file line number Diff line number Diff line change
@@ -7,89 +7,42 @@ as_latent_individual <- function(data) {
UseMethod("as_latent_individual")
}

assert_latent_individual_input <- function(data) {
assert_data_frame(data)
assert_names(
names(data),
must.include = c("case", "ptime_lwr", "ptime_upr",
"stime_lwr", "stime_upr", "obs_time")
)
assert_integer(data$case, lower = 0)
assert_numeric(data$ptime_lwr, lower = 0)
assert_numeric(data$ptime_upr, lower = 0)
assert_true(all(data$ptime_upr - data$ptime_lwr > 0))
assert_numeric(data$stime_lwr, lower = 0)
assert_numeric(data$stime_upr, lower = 0)
assert_true(all(data$stime_upr - data$stime_lwr > 0))
assert_numeric(data$obs_time, lower = 0)
}

#' Prepare latent individual model
#'
#' This function prepares data for use with the latent individual model. It does
#' this by adding columns used in the model to the `data` object provided. To do
#' this, the `data` must already have columns for the case number (integer),
#' (positive, numeric) upper and lower bounds for the primary and secondary
#' event times, as well as a (positive, numeric) time that observation takes
#' place. The output of this function is a `epidist_latent_individual` class
#' object, which may be passed to [epidist()] to perform inference for the
#' model.
#'
#' @param data A `data.frame` containing line list data
#' @rdname as_latent_individual
#' @method as_latent_individual data.frame
#' @method as_latent_individual epidist_linelist
#' @family latent_individual
#' @autoglobal
#' @export
as_latent_individual.data.frame <- function(data) {
assert_latent_individual_input(data)
as_latent_individual.epidist_linelist <- function(data) {
epidist_validate_data(data)
class(data) <- c("epidist_latent_individual", class(data))
data <- data |>
mutate(
relative_obs_time = .data$obs_time - .data$ptime_lwr,
pwindow = ifelse(
stime_lwr < .data$ptime_upr,
stime_upr - .data$ptime_lwr,
ptime_upr - .data$ptime_lwr
.data$stime_lwr < .data$ptime_upr,
.data$stime_upr - .data$ptime_lwr,
.data$ptime_upr - .data$ptime_lwr
),
woverlap = as.numeric(.data$stime_lwr < .data$ptime_upr),
swindow = .data$stime_upr - .data$stime_lwr,
delay = .data$stime_lwr - .data$ptime_lwr,
row_id = dplyr::row_number()
.row_id = dplyr::row_number()
)
if (nrow(data) > 1) {
data <- mutate(data, row_id = factor(.data$row_id))
}
epidist_validate(data)
epidist_validate_model(data)
return(data)
}

#' Validate latent individual model data
#'
#' This function checks whether the provided `data` object is suitable for
#' running the latent individual model. As well as making sure that
#' `is_latent_individual()` is true, it also checks that `data` is a
#' `data.frame` with the correct columns.
#'
#' @param data A `data.frame` containing line list data
#' @param ... ...
#' @method epidist_validate epidist_latent_individual
#' @method epidist_validate_model epidist_latent_individual
#' @family latent_individual
#' @export
epidist_validate.epidist_latent_individual <- function(data, ...) {
epidist_validate_model.epidist_latent_individual <- function(data, ...) {
assert_true(is_latent_individual(data))
assert_latent_individual_input(data)
assert_names(
names(data),
must.include = c("case", "ptime_lwr", "ptime_upr",
"stime_lwr", "stime_upr", "obs_time",
"relative_obs_time", "pwindow", "woverlap",
"swindow", "delay", "row_id")
col_names <- c(
"ptime_lwr", "ptime_upr", "stime_lwr", "stime_upr", "obs_time",
"relative_obs_time", "pwindow", "woverlap", "swindow", "delay", ".row_id"
)
if (nrow(data) > 1) {
assert_factor(data$row_id)
}
assert_names(names(data), must.include = col_names)
assert_numeric(data$relative_obs_time, lower = 0)
# pwindow as f(p) and swindow as f(s) checks here?
assert_numeric(data$pwindow, lower = 0)
assert_numeric(data$woverlap, lower = 0)
assert_numeric(data$swindow, lower = 0)
@@ -159,7 +112,7 @@ epidist_stancode.epidist_latent_individual <- function(data,
epidist_formula(data),
...) {

epidist_validate(data)
epidist_validate_model(data)

stanvars_version <- .version_stanvar()

@@ -202,13 +155,13 @@ epidist_stancode.epidist_latent_individual <- function(data,
brms::stanvar(
block = "data",
scode = "array[N - wN] int noverlap;",
x = filter(data, woverlap == 0)$row_id,
x = filter(data, woverlap == 0)$.row_id,
name = "noverlap"
) +
brms::stanvar(
block = "data",
scode = "array[wN] int woverlap;",
x = filter(data, woverlap > 0)$row_id,
x = filter(data, woverlap > 0)$.row_id,
name = "woverlap"
)

118 changes: 63 additions & 55 deletions R/preprocess.R
Original file line number Diff line number Diff line change
@@ -1,69 +1,77 @@
#' Add columns for interval censoring of primary and secondary events
#' Prepare date data in the `epidist_linelist` format
#'
#' @param linelist ...
#' @param ptime_lwr ...
#' @param ptime_upr ...
#' @param pwindow ...
#' @param stime_lwr ...
#' @param stime_upr ...
#' @param swindow ...
#' @param data A `data.frame` containing line list data
#' @param pdate_lwr,pdate_upr,sdate_lwr,sdate_upr Strings giving the column of
#' `data` containing the primary and secondary event upper and lower bounds.
#' These columns of `data` must be as datetime.
#' @param obs_date A string giving the column of `data` containing the
#' observation time as a datetime.
#' @family preprocess
#' @autoglobal
#' @export
add_event_vars <- function(
linelist, ptime_lwr = NULL, ptime_upr = NULL, pwindow = NULL,
stime_lwr = NULL, stime_upr = NULL, swindow = NULL
as_epidist_linelist <- function(
data, pdate_lwr = NULL, pdate_upr = NULL, sdate_lwr = NULL, sdate_upr = NULL,
obs_date = NULL
) {
linelist <- .rename_column(linelist, "ptime_lwr", ptime_lwr)
linelist <- .rename_column(linelist, "ptime_upr", ptime_upr)
linelist <- .rename_column(linelist, "stime_lwr", stime_lwr)
linelist <- .rename_column(linelist, "stime_upr", stime_upr)
linelist <- .rename_column(linelist, "pwindow", pwindow)
linelist <- .rename_column(linelist, "swindow", swindow)
class(data) <- c("epidist_linelist", class(data))

if (is.numeric(pwindow)) {
cli::cli_warn("Overwriting using numeric value(s) of pwindow provided!")
linelist$pwindow <- pwindow
}

if (is.numeric(swindow)) {
cli::cli_warn("Overwriting using numeric value(s) of swindow provided!")
linelist$swindow <- swindow
}

if (is.null(stime_upr)) {
linelist <- mutate(linelist, stime_upr = stime_lwr + swindow)
}

if (is.null(ptime_upr)) {
linelist <- mutate(linelist, ptime_upr = ptime_lwr + pwindow)
}
data <- .rename_columns(data,
new_names = c(
"pdate_lwr", "pdate_upr", "sdate_lwr", "sdate_upr", "obs_date"
),
old_names = c(pdate_lwr, pdate_upr, sdate_lwr, sdate_upr, obs_date)
)

if (is.null(swindow)) {
linelist <- mutate(linelist, pwindow = stime_upr - stime_lwr)
}
# Check for being a datetime
assert_true(any(inherits(data$pdate_lwr, c("POSIXct", "POSIXlt"))))
assert_true(any(inherits(data$pdate_upr, c("POSIXct", "POSIXlt"))))
assert_true(any(inherits(data$sdate_lwr, c("POSIXct", "POSIXlt"))))
assert_true(any(inherits(data$sdate_upr, c("POSIXct", "POSIXlt"))))
assert_true(any(inherits(data$obs_date, c("POSIXct", "POSIXlt"))))

if (is.null(pwindow)) {
linelist <- mutate(linelist, swindow = ptime_upr - ptime_lwr)
}
# Convert datetime to time
min_date <- min(data$pdate_lwr)

assert_numeric(linelist$ptime_lwr)
assert_numeric(linelist$ptime_upr)
assert_numeric(linelist$pwindow, lower = 0)
assert_true(
all(linelist$ptime_lwr + linelist$pwindow - linelist$ptime_upr < 1e-6)
data <- mutate(data,
ptime_lwr = as.numeric(.data$pdate_lwr - min_date),
ptime_upr = as.numeric(.data$pdate_upr - min_date),
stime_lwr = as.numeric(.data$sdate_lwr - min_date),
stime_upr = as.numeric(.data$sdate_upr - min_date),
obs_time = as.numeric(.data$obs_date - min_date)
)

assert_numeric(linelist$stime_lwr)
assert_numeric(linelist$stime_upr)
assert_numeric(linelist$swindow, lower = 0)
assert_true(
all(linelist$stime_lwr + linelist$swindow - linelist$stime_upr < 1e-6)
)
epidist_validate_data(data)

return(data)
}

linelist <- dplyr::relocate(
linelist, ptime_lwr, ptime_upr, pwindow, stime_lwr, stime_upr, swindow
#' Validation for the `epidist_linelist` class
#'
#' @inheritParams as_epidist_linelist
#' @param ... Additional arguments
#' @family preprocess
#' @export
epidist_validate_data.epidist_linelist <- function(data, ...) {
assert_true(is_epidist_linelist(data))
assert_data_frame(data)
col_names <- c(
"case", "ptime_lwr", "ptime_upr", "stime_lwr", "stime_upr", "obs_time"
)
assert_names(names(data), must.include = col_names)
assert_numeric(data$ptime_lwr, lower = 0)
assert_numeric(data$ptime_upr, lower = 0)
assert_true(all(data$ptime_upr - data$ptime_lwr > 0))
assert_numeric(data$stime_lwr, lower = 0)
assert_numeric(data$stime_upr, lower = 0)
assert_true(all(data$stime_upr - data$stime_lwr > 0))
assert_numeric(data$obs_time, lower = 0)
}

return(linelist)
#' Check if data has the `epidist_linelist` class
#'
#' @inheritParams as_epidist_linelist
#' @param ... Additional arguments
#' @family preprocess
#' @export
is_epidist_linelist <- function(data, ...) {
inherits(data, "epidist_linelist")
}
2 changes: 1 addition & 1 deletion R/prior.R
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@
#' @family prior
#' @export
epidist_prior <- function(data, family, formula, prior) {
epidist_validate(data)
epidist_validate_model(data)
default <- brms::default_prior(formula, data = data)
model <- epidist_model_prior(data, formula)
family <- epidist_family_prior(family, formula)
17 changes: 13 additions & 4 deletions R/utils.R
Original file line number Diff line number Diff line change
@@ -125,10 +125,19 @@
return(formula)
}

.rename_column <- function(df, new, old) {
are_char <- is.character(new) & is.character(old)
if (are_char) {
df <- dplyr::rename(df, !!new := !!old)
#' Rename the columns of a `data.frame`
#'
#' @param df ...
#' @param new_names ...
#' @param old_names ...
#' @keywords internal
.rename_columns <- function(df, new_names, old_names) {
are_char <- is.character(new_names) & is.character(old_names)
valid_new_names <- new_names[are_char]
valid_old_names <- old_names[are_char]
if (length(are_char) > 0) {
rename_map <- setNames(valid_old_names, valid_new_names)
df <- dplyr::rename(df, !!!rename_map)
}
return(df)
}
41 changes: 31 additions & 10 deletions R/validate.R
Original file line number Diff line number Diff line change
@@ -1,24 +1,45 @@
#' Validate a data object for use with [epidist()]
#' Validate data class
#'
#' This function validates that the provided `data` is suitable to run a
#' particular `epidist` model. This may include checking the class of `data`,
#' and that it contains suitable columns.
#' @inheritParams epidist
#' @param ... Additional arguments
#' @family validate
#' @export
epidist_validate_data <- function(data, ...) {
UseMethod("epidist_validate_data")
}

#' Default method for validate data class
#'
#' @inheritParams epidist
#' @param ... Additional arguments
#' @family validate
#' @export
epidist_validate_data.default <- function(data, ...) {
cli_abort(
"No epidist_validate_data method implemented for the class ", class(data),
"\n", "See methods(epidist_validate_data) for available methods"
)
}

#' Validate model class
#'
#' @inheritParams epidist
#' @param ... Additional arguments
#' @family validate
#' @export
epidist_validate <- function(data, ...) {
UseMethod("epidist_validate")
epidist_validate_model <- function(data, ...) {
UseMethod("epidist_validate_model")
}

#' Default method for data validation
#' Default method for validate model class
#'
#' @inheritParams epidist
#' @param ... Additional arguments
#' @family validate
#' @export
epidist_validate.default <- function(data, ...) {
epidist_validate_model.default <- function(data, ...) {
cli_abort(
"No epidist_validate method implemented for the class ", class(data), "\n",
"See methods(epidist_validate) for available methods"
"No epidist_validate_model method implemented for the class ", class(data),
"\n", "See methods(epidist_validate_model) for available methods"
)
}
35 changes: 0 additions & 35 deletions man/add_event_vars.Rd

This file was deleted.

17 changes: 1 addition & 16 deletions man/as_direct_model.Rd
34 changes: 34 additions & 0 deletions man/as_epidist_linelist.Rd
20 changes: 1 addition & 19 deletions man/as_latent_individual.Rd
19 changes: 19 additions & 0 deletions man/dot-rename_columns.Rd
2 changes: 1 addition & 1 deletion man/epidist-package.Rd
1 change: 0 additions & 1 deletion man/epidist_family_model.epidist_latent_individual.Rd
1 change: 0 additions & 1 deletion man/epidist_formula_model.epidist_latent_individual.Rd
23 changes: 0 additions & 23 deletions man/epidist_validate.Rd

This file was deleted.

21 changes: 0 additions & 21 deletions man/epidist_validate.default.Rd

This file was deleted.

24 changes: 0 additions & 24 deletions man/epidist_validate.epidist_direct_model.Rd

This file was deleted.

28 changes: 0 additions & 28 deletions man/epidist_validate.epidist_latent_individual.Rd

This file was deleted.

23 changes: 23 additions & 0 deletions man/epidist_validate_data.Rd
23 changes: 23 additions & 0 deletions man/epidist_validate_data.default.Rd
22 changes: 22 additions & 0 deletions man/epidist_validate_data.epidist_linelist.Rd
23 changes: 23 additions & 0 deletions man/epidist_validate_model.Rd
23 changes: 23 additions & 0 deletions man/epidist_validate_model.default.Rd
1 change: 0 additions & 1 deletion man/is_direct_model.Rd
22 changes: 22 additions & 0 deletions man/is_epidist_linelist.Rd
1 change: 0 additions & 1 deletion man/is_latent_individual.Rd
9 changes: 5 additions & 4 deletions man/predict_delay_parameters.Rd
15 changes: 15 additions & 0 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
set.seed(101)

as_epidist_linelist_time <- function(data) {
class(data) <- c("epidist_linelist", class(data))
epidist_validate_data(data)
return(data)
}

obs_time <- 25
sample_size <- 500

@@ -18,6 +24,9 @@ sim_obs <- simulate_gillespie() |>
filter_obs_by_obs_time(obs_time = obs_time) |>
dplyr::slice_sample(n = sample_size, replace = FALSE)

# Temporary solution for classing time data
sim_obs <- as_epidist_linelist_time(sim_obs)

set.seed(101)

shape <- 2
@@ -36,6 +45,9 @@ sim_obs_gamma <- simulate_gillespie() |>
filter_obs_by_obs_time(obs_time = obs_time) |>
dplyr::slice_sample(n = sample_size, replace = FALSE)

# Temporary solution for classing time data
sim_obs_gamma <- as_epidist_linelist_time(sim_obs_gamma)

# Data with a sex difference

meanlog_m <- 2.0
@@ -67,3 +79,6 @@ sim_obs_sex <- dplyr::bind_rows(sim_obs_sex_m, sim_obs_sex_f) |>
observe_process() |>
filter_obs_by_obs_time(obs_time = obs_time) |>
dplyr::slice_sample(n = sample_size, replace = FALSE)

# Temporary solution for classing time data
sim_obs_sex <- as_epidist_linelist_time(sim_obs_sex)
12 changes: 6 additions & 6 deletions tests/testthat/test-direct_model.R
Original file line number Diff line number Diff line change
@@ -35,16 +35,16 @@ test_that("is_direct_model returns FALSE for incorrect input", { # nolint: line_
})
})

test_that("epidist_validate.epidist_direct_model doesn't produce an error for correct input", { # nolint: line_length_linter.
expect_no_error(epidist_validate(prep_obs))
test_that("epidist_validate_model.epidist_direct_model doesn't produce an error for correct input", { # nolint: line_length_linter.
expect_no_error(epidist_validate_model(prep_obs))
})

test_that("epidist_validate.epidist_direct_model returns FALSE for incorrect input", { # nolint: line_length_linter.
expect_error(epidist_validate(list()))
expect_error(epidist_validate(prep_obs[, 1]))
test_that("epidist_validate_model.epidist_direct_model returns FALSE for incorrect input", { # nolint: line_length_linter.
expect_error(epidist_validate_model(list()))
expect_error(epidist_validate_model(prep_obs[, 1]))
expect_error({
x <- list()
class(x) <- "epidist_direct_model"
epidist_validate(x)
epidist_validate_model(x)
})
})
12 changes: 4 additions & 8 deletions tests/testthat/test-latent_individual.R
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
test_that("as_latent_individual.data.frame with default settings an object with the correct classes", { # nolint: line_length_linter.
test_that("as_latent_individual.epidist_linelist with default settings an object with the correct classes", { # nolint: line_length_linter.
prep_obs <- as_latent_individual(sim_obs)
expect_s3_class(prep_obs, "data.frame")
expect_s3_class(prep_obs, "epidist_latent_individual")
})

test_that("as_latent_individual.data.frame errors when passed incorrect inputs", { # nolint: line_length_linter.
test_that("as_latent_individual.epidist_linelist errors when passed incorrect inputs", { # nolint: line_length_linter.
expect_error(as_latent_individual(list()))
expect_error(as_latent_individual(sim_obs[, 1]))
expect_error({
sim_obs$case <- paste("case_", seq_len(nrow(sim_obs)))
as_latent_individual(sim_obs)
})
})

# Make this data available for other tests
@@ -35,8 +31,8 @@ test_that("is_latent_individual returns FALSE for incorrect input", { # nolint:
})
})

test_that("epidist_validate.epidist_latent_individual doesn't produce an error for correct input", { # nolint: line_length_linter.
expect_no_error(epidist_validate(prep_obs))
test_that("epidist_validate_model.epidist_latent_individual doesn't produce an error for correct input", { # nolint: line_length_linter.
expect_no_error(epidist_validate_model(prep_obs))
})

test_that("epidist_validate.epidist_latent_individual returns FALSE for incorrect input", { # nolint: line_length_linter.
3 changes: 1 addition & 2 deletions tests/testthat/test-postprocess.R
Original file line number Diff line number Diff line change
@@ -39,9 +39,8 @@ test_that("predict_delay_parameters accepts newdata arguments and prediction by
expect_equal(length(unique(pred_sex$draw)), summary(fit_sex)$total_ndraws)

pred_sex_summary <- pred_sex |>
dplyr::mutate(index = as.factor(index)) |>
dplyr::left_join(
dplyr::select(prep_obs_sex, index = row_id, sex),
dplyr::select(prep_obs_sex, index = .row_id, sex),
by = "index"
) |>
dplyr::group_by(sex) |>
72 changes: 41 additions & 31 deletions tests/testthat/test-preprocess.R
Original file line number Diff line number Diff line change
@@ -1,36 +1,46 @@
test_that("add_event_vars produces equivalent linelists in different ways", { # nolint: line_length_linter.
linelist <- tibble::tibble(
"a" = runif(100),
"b" = 1,
"c" = a + b,
"d" = runif(100, 2, 3),
"e" = 1,
"f" = d + e
test_that("as_epidist_linelist assigns epidist_linelist class to data", {
data <- data.frame(
case = 1,
pdate_lwr = as.POSIXct("2023-01-01 00:00:00"),
pdate_upr = as.POSIXct("2023-01-02 00:00:00"),
sdate_lwr = as.POSIXct("2023-01-03 00:00:00"),
sdate_upr = as.POSIXct("2023-01-04 00:00:00"),
obs_date = as.POSIXct("2023-01-05 00:00:00")
)
linelist <- as_epidist_linelist(
data, "pdate_lwr", "pdate_upr", "sdate_lwr", "sdate_upr", "obs_date"
)
expect_s3_class(linelist, "epidist_linelist")
})

ll <- linelist |>
add_event_vars(
ptime_lwr = "a", pwindow = "b", ptime_upr = "c",
stime_lwr = "d", swindow = "e", stime_upr = "f"
)

ll2 <- select(linelist, a, c, d, f) |>
add_event_vars(
ptime_lwr = "a", pwindow = 1, ptime_upr = "c",
stime_lwr = "d", swindow = 1, stime_upr = "f"
)

ll3 <- select(linelist, a, b, d, e) |>
add_event_vars(
ptime_lwr = "a", pwindow = "b", stime_lwr = "d", swindow = "e",
)
test_that("as_epidist_linelist correctly renames columns", {
data <- data.frame(
case = 1,
p_lower = as.POSIXct("2023-01-01"),
p_upper = as.POSIXct("2023-01-02"),
s_lower = as.POSIXct("2023-01-03"),
s_upper = as.POSIXct("2023-01-04"),
observation = as.POSIXct("2023-01-05")
)
linelist <- as_epidist_linelist(
data, "p_lower", "p_upper", "s_lower", "s_upper", "observation"
)
col_names <- c("pdate_lwr", "pdate_upr", "sdate_lwr", "sdate_upr", "obs_date")
expect_true(all(col_names %in% names(linelist)))
})

ll4 <- select(linelist, a, c, d, f) |>
add_event_vars(
ptime_lwr = "a", ptime_upr = "c", stime_lwr = "d", stime_upr = "f",
test_that("as_epidist_linelist gives error if columns are not datetime", {
data <- data.frame(
case = 1,
pdate_lwr = as.Date("2023-01-01"),
pdate_upr = as.Date("2023-01-02"),
sdate_lwr = as.Date("2023-01-03"),
sdate_upr = as.Date("2023-01-04"),
obs_date = as.Date("2023-01-05")
)
expect_error(
as_epidist_linelist(
data, "pdate_lwr", "pdate_upr", "sdate_lwr", "sdate_upr", "obs_date"
)

expect_equal(ll, ll2)
expect_equal(ll, ll3)
expect_equal(ll, ll4)
)
})
10 changes: 9 additions & 1 deletion vignettes/approx-inference.Rmd
Original file line number Diff line number Diff line change
@@ -123,7 +123,15 @@ obs_cens_trunc_samp <- simulate_gillespie(seed = 101) |>
We now prepare the data for fitting with the latent individual model, and perform inference with HMC:

```{r results='hide'}
data <- as_latent_individual(obs_cens_trunc_samp)
# Note: this functionality will be integrated into the package shortly
as_epidist_linelist_time <- function(data) {
class(data) <- c("epidist_linelist", class(data))
epidist_validate_data(data)
return(data)
}
linelist <- as_epidist_linelist_time(obs_cens_trunc_samp)
data <- as_latent_individual(linelist)
t <- proc.time()
fit_hmc <- epidist(data = data, algorithm = "sampling")
12 changes: 10 additions & 2 deletions vignettes/ebola.Rmd
Original file line number Diff line number Diff line change
@@ -199,10 +199,18 @@ obs_cens <- obs_cens |>

## Model fitting

To prepare the data for use with the latent individual model, we use the function `as_latent_individual()`:
To prepare the data for use with the latent individual model, we set `obs_cens` to be an `epidist_linelist` object, then use the function `as_latent_individual()`:

```{r}
obs_prep <- as_latent_individual(obs_cens)
# Note: this functionality will be integrated into the package shortly
as_epidist_linelist_time <- function(data) {
class(data) <- c("epidist_linelist", class(data))
epidist_validate_data(data)
return(data)
}
linelist <- as_epidist_linelist_time(obs_cens)
obs_prep <- as_latent_individual(linelist)
head(obs_prep)
```

10 changes: 9 additions & 1 deletion vignettes/epidist.Rmd
Original file line number Diff line number Diff line change
@@ -251,7 +251,15 @@ We will fit the model `"latent_individual"` which uses latent variables for the
To do so, we first prepare the `data` using `as_latent_individual()`:

```{r}
data <- as_latent_individual(obs_cens_trunc_samp)
# Note: this functionality will be integrated into the package shortly
as_epidist_linelist_time <- function(data) {
class(data) <- c("epidist_linelist", class(data))
epidist_validate_data(data)
return(data)
}
linelist <- as_epidist_linelist_time(obs_cens_trunc_samp)
data <- as_latent_individual(linelist)
class(data)
```

11 changes: 10 additions & 1 deletion vignettes/faq.Rmd
Original file line number Diff line number Diff line change
@@ -47,7 +47,16 @@ obs_cens_trunc_samp <- simulate_gillespie(seed = 101) |>
filter_obs_by_obs_time(obs_time = obs_time) |>
slice_sample(n = sample_size, replace = FALSE)
data <- as_latent_individual(obs_cens_trunc_samp)
# Note: this functionality will be integrated into the package shortly
as_epidist_linelist_time <- function(data) {
class(data) <- c("epidist_linelist", class(data))
epidist_validate_data(data)
return(data)
}
linelist <- as_epidist_linelist_time(obs_cens_trunc_samp)
data <- as_latent_individual(linelist)
fit <- epidist(
data,
formula = mu ~ 1,

0 comments on commit 31fe7ae

Please sign in to comment.