Skip to content

Commit

Permalink
Added metric function to package for analysis of experimental predict…
Browse files Browse the repository at this point in the history
…ions; updated README metrics
  • Loading branch information
jevanilla committed Oct 2, 2024
1 parent 870e2b4 commit 2b262a0
Show file tree
Hide file tree
Showing 13 changed files with 114 additions and 315 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

export(add_forecast_results)
export(find_closures)
export(forecast_metrics)
export(format_probs)
export(format_webpage_table)
export(get_recent_year)
Expand Down
46 changes: 46 additions & 0 deletions R/forecast_metrics.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#' Finds metrics from a table of forecast predictions
#' @param fc tibble of predictions
#' @param predicted_col character string of predicted classification column name
#' @param measured_col character string of measured classification column
#' @export
forecast_metrics <- function(fc,
predicted_col = "predicted_class",
measured_col = "actual_class") {
correct <- fc |>
dplyr::filter(.data[[predicted_col]] == .data[[measured_col]]) |>
nrow()
tn <- fc |>
dplyr::filter(.data[[predicted_col]] != 3 & .data[[measured_col]] != 3) |>
nrow()
tp <- fc |>
dplyr::filter(.data[[predicted_col]] == 3 & .data[[measured_col]] == 3) |>
nrow()
fp <- fc |>
dplyr::filter(.data[[predicted_col]] == 3 & .data[[measured_col]] != 3) |>
nrow()
fn <- fc |>
dplyr::filter(.data[[predicted_col]] != 3 & .data[[measured_col]] == 3) |>
nrow()

precision <- tp/(tp+fp)
recall <- tp/(tp+fn)
sensitivity <- tp/(tp+fn)
specificity <- tn/(tn+fp)

f_1 <- (2)*(precision*recall)/(precision+recall)
cl_accuracy <- (tn+tp)/nrow(fc)
accuracy <- correct/nrow(fc)

metrics_c3 <- dplyr::tibble(tp = tp,
fp = fp,
tn = tn,
fn = fn,
accuracy = accuracy,
cl_accuracy = cl_accuracy,
f_1=f_1,
precision = precision,
sensitivity = sensitivity,
specificity = specificity)

return(metrics_c3)
}
223 changes: 7 additions & 216 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -137,70 +137,20 @@ p3_v_tox_24

### Metrics

#### Season Accuracy:


```{r echo=FALSE}
correct <- pred_w_results |>
filter(predicted_class == class) |>
nrow()
metrics <- tibble(accuracy = correct/nrow(pred_w_results))
metrics
```

#### Closure-level (Class 3) Predictions

+ tp - The model predicted class 3 and the following week's measurement was class 3
+ fp - The model predicted class 3 and the following week's measurement was not class 3
+ tn - The model predicted class 0,1,2 and the following week's measurement was in class 0,1,2
+ fn - The model predicted class 0,1,2 and the following week's measurement was class 3
+ accuracy - Measure of how many correct classifications were predicted
+ cl_accuracy - Considering predictions are those that correctly predicted toxicity above or below the closure limit or not
+ precision - TP/(TP+FP)
+ sensitivity - TP/(TP+FN)
+ specificity - TN/(TN+FP)
+ f_1


```{r echo=FALSE}
tn <- pred_w_results |>
filter(predicted_class != 3 & class != 3) |>
nrow()
tp <- pred_w_results |>
dplyr::filter(.data$predicted_class == 3 & .data$class == 3) |>
nrow()
fp <- pred_w_results |>
dplyr::filter(.data$predicted_class == 3 & .data$class != 3) |>
nrow()
precision = tp/(tp+fp)
fn = pred_w_results |>
dplyr::filter(.data$predicted_class != 3 & .data$class == 3) |>
nrow()
recall = tp/(tp+fn)
f_1 = (2)*(precision*recall)/(precision+recall)
sensitivity = tp/(tp+fn)
specificity = tn/(tn+fp)
cl_accuracy = (tn+tp)/nrow(pred_w_results)
metrics_c3 <- tibble(tp = tp,
fp = fp,
tn = tn,
fn = fn,
cl_accuracy = cl_accuracy,
precision = precision,
#recall = recall,
sensitivity = sensitivity,
specificity = specificity)
metrics_c3
forecast_metrics(pred_w_results, measured_col = "class")
```


Expand Down Expand Up @@ -265,61 +215,8 @@ p3_v_tox_23

### Metrics

#### Season Accuracy:


```{r echo=FALSE}
correct <- pred_w_results |>
filter(predicted_class == class) |>
nrow()
metrics <- tibble(accuracy = correct/nrow(pred_w_results))
metrics
```

#### Closure-level (Class 3) Predictions

```{r echo=FALSE}
tn <- pred_w_results |>
filter(predicted_class != 3 & class != 3) |>
nrow()
tp <- pred_w_results |>
dplyr::filter(.data$predicted_class == 3 & .data$class == 3) |>
nrow()
fp <- pred_w_results |>
dplyr::filter(.data$predicted_class == 3 & .data$class != 3) |>
nrow()
precision = tp/(tp+fp)
fn = pred_w_results |>
dplyr::filter(.data$predicted_class != 3 & .data$class == 3) |>
nrow()
recall = tp/(tp+fn)
f_1 = (2)*(precision*recall)/(precision+recall)
sensitivity = tp/(tp+fn)
specificity = tn/(tn+fp)
cl_accuracy = (tn+tp)/nrow(pred_w_results)
metrics_c3 <- tibble(tp = tp,
fp = fp,
tn = tn,
fn = fn,
cl_accuracy = cl_accuracy,
precision = precision,
#recall = recall,
sensitivity = sensitivity,
specificity = specificity)
metrics_c3
forecast_metrics(pred_w_results, measured_col = "class")
```


Expand Down Expand Up @@ -382,61 +279,8 @@ p3_v_tox_22

### Metrics

#### Season Accuracy:

```{r acc22, echo=FALSE}
correct <- pred_w_results |>
filter(predicted_class == class) |>
nrow()
metrics <- tibble(accuracy = correct/nrow(pred_w_results))
metrics
```

#### Closure-level (Class 3) Predictions

```{r metrics22, echo=FALSE}
tn <- pred_w_results |>
filter(predicted_class != 3 & class != 3) |>
nrow()
tp <- pred_w_results |>
dplyr::filter(.data$predicted_class == 3 & .data$class == 3) |>
nrow()
fp <- pred_w_results |>
dplyr::filter(.data$predicted_class == 3 & .data$class != 3) |>
nrow()
precision = tp/(tp+fp)
fn = pred_w_results |>
dplyr::filter(.data$predicted_class != 3 & .data$class == 3) |>
nrow()
recall = tp/(tp+fn)
f_1 = (2)*(precision*recall)/(precision+recall)
sensitivity = tp/(tp+fn)
specificity = tn/(tn+fp)
cl_accuracy = (tn+tp)/nrow(pred_w_results)
metrics_c3 <- tibble(tp = tp,
fp = fp,
tn = tn,
fn = fn,
cl_accuracy = cl_accuracy,
precision = precision,
#recall = recall,
sensitivity = sensitivity,
specificity = specificity)
metrics_c3
forecast_metrics(pred_w_results, measured_col = "class")
```


Expand Down Expand Up @@ -515,61 +359,8 @@ p3_v_tox_21

### Metrics

#### Season Accuracy:

```{r acc21, echo=FALSE}
correct <- pred_w_results |>
filter(predicted_class == class) |>
nrow()
metrics <- tibble(accuracy = correct/nrow(pred_w_results))
metrics
```

#### Closure-level (Class 3) Predictions

```{r metrics21, echo=FALSE}
tn <- pred_w_results |>
filter(predicted_class != 3 & class != 3) |>
nrow()
tp <- pred_w_results |>
dplyr::filter(.data$predicted_class == 3 & .data$class == 3) |>
nrow()
fp <- pred_w_results |>
dplyr::filter(.data$predicted_class == 3 & .data$class != 3) |>
nrow()
precision = tp/(tp+fp)
fn = pred_w_results |>
dplyr::filter(.data$predicted_class != 3 & .data$class == 3) |>
nrow()
recall = tp/(tp+fn)
f_1 = (2)*(precision*recall)/(precision+recall)
sensitivity = tp/(tp+fn)
specificity = tn/(tn+fp)
cl_accuracy = (tn+tp)/nrow(pred_w_results)
metrics_c3 <- tibble(tp = tp,
fp = fp,
tn = tn,
fn = fn,
cl_accuracy = cl_accuracy,
precision = precision,
#recall = recall,
sensitivity = sensitivity,
specificity = specificity)
metrics_c3
forecast_metrics(pred_w_results, measured_col = "class")
```


Expand Down
Loading

0 comments on commit 2b262a0

Please sign in to comment.