Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

accumulate for observed reports #643

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
out$reported_cases <- extract_parameter(
"imputed_reports",
samples,
reported_dates
reported_dates[data$cases_time[-1]]
)
if ("estimate_r" %in% names(data)) {
if (data$estimate_r == 1) {
Expand Down
47 changes: 26 additions & 21 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ transformed parameters {
vector<lower = 0>[estimate_r > 0 ? ot_h : 0] R; // reproduction number
vector[t] infections; // latent infections
vector[ot_h] reports; // estimated reported cases
vector[ot] obs_reports; // observed estimated reported cases
vector[lt - accumulate] obs_reports; // observed estimated reported cases
vector[estimate_r * (delay_type_max[gt_id] + 1)] gt_rev_pmf;
// GP in noise - spectral densities
profile("update gp") {
Expand Down Expand Up @@ -131,22 +131,28 @@ transformed parameters {
);
}
}
// truncate near time cases to observed reports
if (trunc_id) {
vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf;
profile("truncation") {
trunc_rev_cmf = get_delay_rev_pmf(
trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist,
0, 1, 1
);
{
vector[ot] truncated_reports;
// truncate near time cases to observed reports
if (trunc_id) {
vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf;
profile("truncation") {
trunc_rev_cmf = get_delay_rev_pmf(
trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist,
0, 1, 1
);
}
profile("truncate") {
truncated_reports = truncate(reports[1:ot], trunc_rev_cmf, 0);
}
} else {
truncated_reports = reports[1:ot];
}
profile("assign") {
obs_reports = assign_reports(cases_time, truncated_reports, accumulate);
}
profile("truncate") {
obs_reports = truncate(reports[1:ot], trunc_rev_cmf, 0);
}
} else {
obs_reports = reports[1:ot];
}
}

Expand Down Expand Up @@ -185,15 +191,14 @@ model {
if (likelihood) {
profile("report lp") {
report_lp(
cases, cases_time, obs_reports, rep_phi, phi_mean, phi_sd, model_type,
obs_weight, accumulate
cases, obs_reports, rep_phi, phi_mean, phi_sd, model_type, obs_weight
);
}
}
}

generated quantities {
array[ot_h] int imputed_reports;
array[lt - accumulate] int imputed_reports;
vector[estimate_r > 0 ? 0: ot_h] gen_R;
vector[ot_h - 1] r;
vector[return_likelihood ? ot : 0] log_lik;
Expand All @@ -217,11 +222,11 @@ generated quantities {
// estimate growth from infections
r = calculate_growth(infections, seeding_time + 1);
// simulate reported cases
imputed_reports = report_rng(reports, rep_phi, model_type);
imputed_reports = report_rng(obs_reports, rep_phi, model_type);
// log likelihood of model
if (return_likelihood) {
log_lik = report_log_lik(
cases, obs_reports[cases_time], rep_phi, model_type, obs_weight
cases, obs_reports, rep_phi, model_type, obs_weight
);
}
}
Expand Down
12 changes: 8 additions & 4 deletions inst/stan/estimate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ functions {
data {
int t; // time of observations
int lt; // time of observations
array[t] int<lower = 0> obs; // observed secondary data
array[t] int<lower = 0> obs; // observed secondary data
array[lt] int obs_time; // observed secondary data
vector[t] primary; // observed primary data
int burn_in; // time period to not use for fitting
Expand All @@ -35,6 +35,7 @@ parameters{

transformed parameters {
vector<lower=0>[t] secondary;
vector<lower=0>[lt - accumulate] obs_secondary;
// calculate secondary reports from primary

{
Expand Down Expand Up @@ -79,7 +80,10 @@ transformed parameters {
0, 1, 1
);
secondary = truncate(secondary, trunc_rev_cmf, 0);
}
}
obs_secondary = assign_reports(
obs_time, secondary[(burn_in + 1):t], accumulate
);
}

model {
Expand All @@ -96,8 +100,8 @@ model {
// observed secondary reports from mean of secondary reports (update likelihood)
if (likelihood) {
report_lp(
obs[(burn_in + 1):t][obs_time], obs_time, secondary[(burn_in + 1):t],
rep_phi, phi_mean, phi_sd, model_type, 1, accumulate
obs[(burn_in + 1):t][obs_time], obs_secondary, rep_phi, phi_mean, phi_sd,
model_type, 1
);
}
}
Expand Down
55 changes: 33 additions & 22 deletions inst/stan/functions/observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -50,30 +50,18 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd,
}
}
}

// update log density for reported cases
void report_lp(array[] int cases, array[] int cases_time, vector reports,
void report_lp(array[] int cases, vector obs_reports,
array[] real rep_phi, real phi_mean, real phi_sd,
int model_type, real weight, int accumulate) {
int n = num_elements(cases_time) - accumulate; // number of observations
vector[n] obs_reports; // reports at observation time
array[n] int obs_cases; // observed cases at observation time
if (accumulate) {
int t = num_elements(reports);
int i = 0;
int current_obs = 0;
obs_reports = rep_vector(0, n);
while (i <= t && current_obs <= n) {
if (current_obs > 0) { // first observation gets ignored when accumulating
obs_reports[current_obs] += reports[i];
}
if (i == cases_time[current_obs + 1]) {
current_obs += 1;
}
i += 1;
}
obs_cases = cases[2:(n + 1)];
} else {
obs_reports = reports[cases_time];
int model_type, real weight) {
int n = num_elements(obs_reports);
int n_cases = num_elements(cases);
// if accumulating shift cases
array[n] int obs_cases;
if (n_cases > n) {
obs_cases = cases[(n_cases - n + 1):n_cases];
} else {
obs_cases = cases;
}
if (model_type) {
Expand Down Expand Up @@ -138,3 +126,26 @@ array[] int report_rng(vector reports, array[] real rep_phi, int model_type) {
}
return(sampled_reports);
}

vector assign_reports(array[] int cases_time, vector reports, int accumulate) {
int n = num_elements(cases_time) - accumulate; // number of observations
vector[n] obs_reports; // reports at observation time
if (accumulate) {
int t = num_elements(reports);
int i = 0;
int current_obs = 0;
obs_reports = rep_vector(0, n);
while (i <= t && current_obs <= n) {
if (current_obs > 0) { // first observation gets ignored when accumulating
obs_reports[current_obs] += reports[i];
}
if (i == cases_time[current_obs + 1]) {
current_obs += 1;
}
i += 1;
}
} else {
obs_reports = reports[cases_time];
}
return(obs_reports);
}
Loading