diff --git a/R/add_forecast_results.R b/R/add_forecast_results.R index b397610..3b3e48f 100644 --- a/R/add_forecast_results.R +++ b/R/add_forecast_results.R @@ -57,7 +57,6 @@ add_forecast_results <- function(predictions, is_correct <- function(x, y) { - if (x$predicted_class == x$class) { x <- x |> dplyr::mutate(correct = TRUE) @@ -66,6 +65,17 @@ add_forecast_results <- function(predictions, dplyr::mutate(correct=FALSE) } } + + + is_cl_correct <- function(x,y) { + if ((x$predicted_class == 3 & x$class == 3) || (x$predicted_class != 3 & x$class != 3)) { + x <- x |> + dplyr::mutate(cl_correct = TRUE) + } else { + x <- x |> + dplyr::mutate(cl_correct=FALSE) + } + } results <- predictions |> @@ -78,6 +88,9 @@ add_forecast_results <- function(predictions, tidyr::drop_na("class") |> dplyr::rowwise() |> dplyr::group_map(is_correct, .keep=TRUE) |> + dplyr::bind_rows() |> + dplyr::rowwise() |> + dplyr::group_map(is_cl_correct, .keep=TRUE) |> dplyr::bind_rows() return(forecast_w_results) diff --git a/inst/forecastdb/seasonal_results/psp_forecast_results_2021.csv.gz b/inst/forecastdb/seasonal_results/psp_forecast_results_2021.csv.gz index 13fc789..e66d70c 100644 Binary files a/inst/forecastdb/seasonal_results/psp_forecast_results_2021.csv.gz and b/inst/forecastdb/seasonal_results/psp_forecast_results_2021.csv.gz differ diff --git a/inst/forecastdb/seasonal_results/psp_forecast_results_2022.csv.gz b/inst/forecastdb/seasonal_results/psp_forecast_results_2022.csv.gz index c20d817..dffbd48 100644 Binary files a/inst/forecastdb/seasonal_results/psp_forecast_results_2022.csv.gz and b/inst/forecastdb/seasonal_results/psp_forecast_results_2022.csv.gz differ diff --git a/inst/forecastdb/seasonal_results/psp_forecast_results_2023.csv.gz b/inst/forecastdb/seasonal_results/psp_forecast_results_2023.csv.gz index c215029..a503490 100644 Binary files a/inst/forecastdb/seasonal_results/psp_forecast_results_2023.csv.gz and b/inst/forecastdb/seasonal_results/psp_forecast_results_2023.csv.gz differ diff --git a/inst/forecastdb/seasonal_results/psp_forecast_results_2024.csv.gz b/inst/forecastdb/seasonal_results/psp_forecast_results_2024.csv.gz index f64b4ae..7e6ccfb 100644 Binary files a/inst/forecastdb/seasonal_results/psp_forecast_results_2024.csv.gz and b/inst/forecastdb/seasonal_results/psp_forecast_results_2024.csv.gz differ diff --git a/inst/manuscript/confusion_matrix_allyears.R b/inst/manuscript/confusion_matrix_allyears.R index 4e177a4..935b975 100644 --- a/inst/manuscript/confusion_matrix_allyears.R +++ b/inst/manuscript/confusion_matrix_allyears.R @@ -33,6 +33,10 @@ plot1 <- ggplot2::ggplot(data = cm, ggplot2::aes(x=.data$predicted, y=.data$actu ggplot2::geom_rect(aes(xmin=0.5, xmax=3.5, ymin=0.5, ymax=3.5), alpha=0) + ggplot2::geom_rect(aes(xmin=3.5, xmax=4.5, ymin=3.5, ymax=4.5), alpha=0) +plot1 + +# Save plot + ggsave(filename = "inst/manuscript/cm_allyears.jpeg", plot=plot1, width=12, height=8) diff --git a/inst/manuscript/scatter_allyears.R b/inst/manuscript/scatter_allyears.R index 37623cf..54658c6 100644 --- a/inst/manuscript/scatter_allyears.R +++ b/inst/manuscript/scatter_allyears.R @@ -5,15 +5,24 @@ library(ggplot2) pred_w_results <- read_all_results() +ggplot2::ggplot(data = pred_w_results, ggplot2::aes(x=.data$p_3, y=.data$toxicity, colour = cl_correct)) + + ggplot2::geom_point(alpha=0.7, size=3) + + ggplot2::facet_grid(cols=vars(.data$year)) + + ggplot2::labs(x = "Predicted Probability of Closure-level Toxicity (%)", + y = "Measured Toxicity (μg STX eq/ 100 g shellfish)") + + ggplot2::geom_hline(yintercept=80, linetype="dashed") + + ggplot2::theme_bw() plot2 <- ggplot2::ggplot(data = pred_w_results, ggplot2::aes(x=.data$p_3, y=.data$toxicity, colour = correct)) + ggplot2::geom_point(alpha=0.7, size=3) + ggplot2::facet_grid(cols=vars(.data$year)) + - ggplot2::labs(x = "Predicted Probability of Closure-level Toxicity", - y = "Measured Toxicity") + + ggplot2::labs(x = "Predicted Probability of Closure-level Toxicity (%)", + y = "Measured Toxicity (μg STX eq/ 100 g shellfish)") + ggplot2::geom_hline(yintercept=80, linetype="dashed") + ggplot2::theme_bw() plot2 -ggsave(filename = "inst/manuscript/scatter_allyears.jpeg", plot=plot2, width=12, height=9) +# Save plot + +ggsave(filename = "inst/manuscript/scatter_allyears.jpeg", plot=plot2, width=12, height=8) diff --git a/inst/manuscript/station_metrics.R b/inst/manuscript/station_metrics.R index 7e6ad39..a44a7c5 100644 --- a/inst/manuscript/station_metrics.R +++ b/inst/manuscript/station_metrics.R @@ -8,7 +8,7 @@ find_station_metrics <- function(results = read_all_results()) { dplyr::tibble(location = key$location[1], lat = tbl$lat[1], lon = tbl$lon[1], - accuracy = accuracy_vec(truth = factor(tbl$class, levels = c(0,1,2,3)), estimate=factor(tbl$predicted_class, , levels = c(0,1,2,3))), + accuracy = yardstick::accuracy_vec(truth = factor(tbl$class, levels = c(0,1,2,3)), estimate=factor(tbl$predicted_class, , levels = c(0,1,2,3))), predictions = nrow(tbl)) } @@ -41,7 +41,8 @@ plot_station_metrics <- function(st_metrics) { ggplot2::geom_point(data = st_metrics, ggplot2::aes(x = .data$lon, y = .data$lat, colour=.data$accuracy), size=1) + - ggplot2::scale_color_gradient(low="black", high="red") + #ggplot2::scale_color_gradient(low="black", high="red") + + ggplot2::scale_color_viridis_b() p } @@ -49,4 +50,9 @@ plot_station_metrics <- function(st_metrics) { st_metrics <- find_station_metrics() -plot_station_metrics(st_metrics) \ No newline at end of file +plot3 <- plot_station_metrics(st_metrics) + + +# Save plot + +ggsave(filename = "inst/manuscript/station_metrics_allyears.jpeg", plot=plot3, width=6, height=4) diff --git a/inst/scripts/get_results.R b/inst/scripts/get_results.R index cfc067f..ae9ab04 100644 --- a/inst/scripts/get_results.R +++ b/inst/scripts/get_results.R @@ -4,6 +4,7 @@ library(pspforecast) library(pspdata) library(readr) +library(dplyr) psp <- read_psp_data() |> @@ -15,13 +16,16 @@ psp <- read_psp_data() |> ## 2021 Season -predictions21 <- read_forecast(year = "2021") +predictions21 <- read_forecast(year = "2021") |> + rename(p_0=prob_0, + p_1=prob_1, + p_2=prob_2, + p_3=prob_3) x <- add_forecast_results(predictions21, toxin_measurements = psp) summary(x) -x |> - write_csv("inst/forecastdb/seasonal_results/psp_forecast_results_2021.csv.gz") +write_csv(x, "inst/forecastdb/seasonal_results/psp_forecast_results_2021.csv.gz") ## 2022 Season @@ -30,8 +34,7 @@ xx <- add_forecast_results(predictions22, toxin_measurements = psp) summary(xx) -xx |> - write_csv("inst/forecastdb/seasonal_results/psp_forecast_results_2022.csv.gz") +write_csv(xx, "inst/forecastdb/seasonal_results/psp_forecast_results_2022.csv.gz") ## 2023 @@ -40,8 +43,7 @@ xx <- add_forecast_results(predictions23, toxin_measurements = psp) summary(xx) -xx |> - write_csv("inst/forecastdb/seasonal_results/psp_forecast_results_2023.csv.gz") +write_csv(xx, "inst/forecastdb/seasonal_results/psp_forecast_results_2023.csv.gz") ## 2024 @@ -50,6 +52,5 @@ xx <- add_forecast_results(predictions24, toxin_measurements = psp) summary(xx) -xx |> - write_csv("inst/forecastdb/seasonal_results/psp_forecast_results_2024.csv.gz") +write_csv(xx, "inst/forecastdb/seasonal_results/psp_forecast_results_2024.csv.gz")