Skip to content

Grf qr engine #360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.0.18
Version: 0.0.19
Authors@R: c(
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand Down Expand Up @@ -36,10 +36,8 @@ Imports:
glue,
hardhat (>= 1.3.0),
magrittr,
quantreg,
recipes (>= 1.0.4),
rlang (>= 1.0.0),
smoothqr,
stats,
tibble,
tidyr,
Expand All @@ -52,13 +50,16 @@ Suggests:
data.table,
epidatr (>= 1.0.0),
fs,
grf,
knitr,
lubridate,
poissonreg,
purrr,
quantreg,
ranger,
RcppRoll,
rmarkdown,
smoothqr,
testthat (>= 3.0.0),
usethis,
xgboost
Expand Down
3 changes: 1 addition & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,6 @@ importFrom(ggplot2,autoplot)
importFrom(hardhat,refresh_blueprint)
importFrom(hardhat,run_mold)
importFrom(magrittr,"%>%")
importFrom(quantreg,rq)
importFrom(recipes,bake)
importFrom(recipes,prep)
importFrom(rlang,"!!!")
Expand All @@ -253,13 +252,13 @@ importFrom(rlang,as_function)
importFrom(rlang,caller_env)
importFrom(rlang,enquo)
importFrom(rlang,enquos)
importFrom(rlang,expr)
importFrom(rlang,global_env)
importFrom(rlang,inject)
importFrom(rlang,is_logical)
importFrom(rlang,is_null)
importFrom(rlang,is_true)
importFrom(rlang,set_names)
importFrom(smoothqr,smooth_qr)
importFrom(stats,as.formula)
importFrom(stats,family)
importFrom(stats,lm)
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,5 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
`...` args intended for `predict.model_fit()`
- `bake.epi_recipe()` will now re-infer the geo and time type in case baking the
steps has changed the appropriate values
- Add `step_epi_slide` to produce generic sliding computations over an `epi_df`
- Add `step_epi_slide` to produce generic sliding computations over an `epi_df`
- Add quantile random forests (via `{grf}`) as a parsnip engine
2 changes: 1 addition & 1 deletion R/epipredict-package.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## usethis namespace: start
#' @importFrom tibble tibble
#' @importFrom rlang := !! %||% as_function global_env set_names !!!
#' @importFrom rlang is_logical is_true inject enquo enquos
#' @importFrom rlang is_logical is_true inject enquo enquos expr
#' @importFrom stats poly predict lm residuals quantile
#' @importFrom cli cli_abort
#' @importFrom checkmate assert assert_character assert_int assert_scalar
Expand Down
11 changes: 8 additions & 3 deletions R/layer_quantile_distn.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
#' Returns predictive quantiles
#'
#' This function calculates quantiles when the prediction was _distributional_.
#' Currently, the only distributional engine is `quantile_reg()`.
#' If this engine is used, then this layer will grab out estimated (or extrapolated)
#' quantiles at the requested quantile values.
#'
#' Currently, the only distributional modes/engines are
#' * `quantile_reg()`
#' * `smooth_quantile_reg()`
#' * `rand_forest(mode = "regression") %>% set_engine("grf_quantiles")`
#'
#' If these engines were used, then this layer will grab out estimated
#' (or extrapolated) quantiles at the requested quantile values.
#'
#' @param frosting a `frosting` postprocessor
#' @param ... Unused, include for consistency with other layers.
Expand Down
193 changes: 193 additions & 0 deletions R/make_grf_quantiles.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
#' Random quantile forests via grf
#'
#' [grf::quantile_forest()] fits random forests in a way that makes it easy
#' to calculate _quantile_ forests. Currently, this is the only engine
#' provided here, since quantile regression is the typical use-case.
#'
#' @section Tuning Parameters:
#'
#' This model has 3 tuning parameters:
#'
#' - `mtry`: # Randomly Selected Predictors (type: integer, default: see below)
#' - `trees`: # Trees (type: integer, default: 2000L)
#' - `min_n`: Minimal Node Size (type: integer, default: 5)
#'
#' `mtry` depends on the number of columns in the design matrix.
#' The default in [grf::quantile_forest()] is `min(ceiling(sqrt(ncol(X)) + 20), ncol(X))`.
#'
#' For categorical predictors, a one-hot encoding is always used. This makes
#' splitting efficient, but has implications for the `mtry` choice. A factor
#' with many levels will become a large number of columns in the design matrix
#' which means that some of these may be selected frequently for potential splits.
#' This is different than in other implementations of random forest. For more
#' details, see [the `grf` discussion](https://grf-labs.github.io/grf/articles/categorical_inputs.html).
#'
#' @section Translation from parsnip to the original package:
#'
#' ```{r, translate-engine}
#' rand_forest(
#' mode = "regression", # you must specify the `mode = regression`
#' mtry = integer(1),
#' trees = integer(1),
#' min_n = integer(1)
#' ) %>%
#' set_engine("grf_quantiles") %>%
#' translate()
#' ```
#'
#' @section Case weights:
#'
#' Case weights are not supported.
#'
#' @examples
#' library(grf)
#' tib <- data.frame(
#' y = rnorm(100), x = rnorm(100), z = rnorm(100),
#' f = factor(sample(letters[1:3], 100, replace = TRUE))
#' )
#' spec <- rand_forest(engine = "grf_quantiles", mode = "regression")
#' out <- fit(spec, formula = y ~ x + z, data = tib)
#' predict(out, new_data = tib[1:5, ]) %>%
#' pivot_quantiles_wider(.pred)
#'
#' # -- adjusting the desired quantiles
#'
#' spec <- rand_forest(mode = "regression") %>%
#' set_engine(engine = "grf_quantiles", quantiles = c(1:9 / 10))
#' out <- fit(spec, formula = y ~ x + z, data = tib)
#' predict(out, new_data = tib[1:5, ]) %>%
#' pivot_quantiles_wider(.pred)
#'
#' # -- a more complicated task
#'
#' library(dplyr)
#' dat <- case_death_rate_subset %>%
#' filter(time_value > as.Date("2021-10-01"))
#' rec <- epi_recipe(dat) %>%
#' step_epi_lag(case_rate, death_rate, lag = c(0, 7, 14)) %>%
#' step_epi_ahead(death_rate, ahead = 7) %>%
#' step_epi_naomit()
#' frost <- frosting() %>%
#' layer_predict() %>%
#' layer_threshold(.pred)
#' spec <- rand_forest(mode = "regression") %>%
#' set_engine(engine = "grf_quantiles", quantiles = c(.25, .5, .75))
#'
#' ewf <- epi_workflow(rec, spec, frost) %>%
#' fit(dat) %>%
#' forecast()
#' ewf %>%
#' rename(forecast_date = time_value) %>%
#' mutate(target_date = forecast_date + 7L) %>%
#' pivot_quantiles_wider(.pred)
#'
#' @name grf_quantiles
NULL



make_grf_quantiles <- function() {
parsnip::set_model_engine(
model = "rand_forest", mode = "regression", eng = "grf_quantiles"
)
parsnip::set_dependency(
model = "rand_forest", eng = "grf_quantiles", pkg = "grf",
mode = "regression"
)


# These are the arguments to the parsnip::rand_forest() that must be
# translated from grf::quantile_forest
parsnip::set_model_arg(
model = "rand_forest",
eng = "grf_quantiles",
parsnip = "mtry",
original = "mtry",
func = list(pkg = "dials", fun = "mtry"),
has_submodel = FALSE
)
parsnip::set_model_arg(
model = "rand_forest",
eng = "grf_quantiles",
parsnip = "trees",
original = "num.trees",
func = list(pkg = "dials", fun = "trees"),
has_submodel = FALSE
)
parsnip::set_model_arg(
model = "rand_forest",
eng = "grf_quantiles",
parsnip = "min_n",
original = "min.node.size",
func = list(pkg = "dials", fun = "min_n"),
has_submodel = FALSE
)

# the `value` list describes how grf::quantile_forest expects to receive
# arguments. In particular, it needs X and Y to be passed in as a matrices.
# But the matrix interface in parsnip calls these x and y. So the data
# slot translates them
#
# protect - prevents the user from passing X and Y arguments themselves
# defaults - engine specific arguments (not model specific) that we allow
# the user to change
parsnip::set_fit(
model = "rand_forest",
eng = "grf_quantiles",
mode = "regression",
value = list(
interface = "matrix",
protect = c("X", "Y"),
data = c(x = "X", y = "Y"),
func = c(pkg = "grf", fun = "quantile_forest"),
defaults = list(
quantiles = c(0.1, 0.5, 0.9),
num.threads = 1L,
seed = rlang::expr(stats::runif(1, 0, .Machine$integer.max))
)
)
)

parsnip::set_encoding(
model = "rand_forest",
eng = "grf_quantiles",
mode = "regression",
options = list(
# one hot is the closest to typical factor handling in randomForest
# (1 vs all splitting), though since we aren't bagging,
# factors with many levels could be visited frequently
predictor_indicators = "one_hot",
compute_intercept = FALSE,
remove_intercept = FALSE,
allow_sparse_x = FALSE
)
)

# turn the predictions into a tibble with a dist_quantiles column
process_qrf_preds <- function(x, object) {
quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig
x <- x$predictions
out <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x)))
out <- dist_quantiles(out, list(quantile_levels))
return(dplyr::tibble(.pred = out))
}

parsnip::set_pred(
model = "rand_forest",
eng = "grf_quantiles",
mode = "regression",
type = "numeric",
value = list(
pre = NULL,
post = process_qrf_preds,
func = c(fun = "predict"),
# map between parsnip::predict args and grf::quantile_forest args
args = list(
object = quote(object$fit),
newdata = quote(new_data),
seed = rlang::expr(sample.int(10^5, 1)),
verbose = FALSE
)
)
)
}
13 changes: 8 additions & 5 deletions R/make_quantile_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,24 @@
#' @description
#' `quantile_reg()` generates a quantile regression model _specification_ for
#' the [tidymodels](https://www.tidymodels.org/) framework. Currently, the
#' only supported engine is "rq" which uses [quantreg::rq()].
#' only supported engines are "rq", which uses [quantreg::rq()].
#' Quantile regression is also possible by combining [parsnip::rand_forest()]
#' with the `grf` engine. See [grf_quantiles].
#'
#' @param mode A single character string for the type of model.
#' The only possible value for this model is "regression".
#' @param engine Character string naming the fitting function. Currently, only
#' "rq" is supported.
#' "rq" and "grf" are supported.
#' @param quantile_levels A scalar or vector of values in (0, 1) to determine which
#' quantiles to estimate (default is 0.5).
#'
#' @export
#'
#' @seealso [fit.model_spec()], [set_engine()]
#'
#' @importFrom quantreg rq
#'
#' @examples
#' library(quantreg)
#' tib <- data.frame(y = rnorm(100), x1 = rnorm(100), x2 = rnorm(100))
#' rq_spec <- quantile_reg(quantile_levels = c(.2, .8)) %>% set_engine("rq")
#' ff <- rq_spec %>% fit(y ~ ., data = tib)
Expand Down Expand Up @@ -106,15 +109,15 @@ make_quantile_reg <- function() {
out <- switch(type,
rq = dist_quantiles(unname(as.list(x)), object$quantile_levels), # one quantile
rqs = {
x <- lapply(unname(split(x, seq(nrow(x)))), function(q) sort(q))
x <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x)))
dist_quantiles(x, list(object$tau))
},
cli_abort(c(
"Prediction is not implemented for this `rq` type.",
i = "See {.fun quantreg::rq}."
))
)
return(data.frame(.pred = out))
return(dplyr::tibble(.pred = out))
}


Expand Down
23 changes: 11 additions & 12 deletions R/make_smooth_quantile_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
#'
#' @seealso [fit.model_spec()], [set_engine()]
#'
#' @importFrom smoothqr smooth_qr
#' @examples
#' library(smoothqr)
#' tib <- data.frame(
#' y1 = rnorm(100), y2 = rnorm(100), y3 = rnorm(100),
#' y4 = rnorm(100), y5 = rnorm(100), y6 = rnorm(100),
Expand Down Expand Up @@ -62,17 +62,16 @@
#' lines(pl$x, pl$`0.8`, col = "blue")
#' lines(pl$x, pl$`0.5`, col = "red")
#'
#' if (require("ggplot2")) {
#' ggplot(data.frame(x = x, y = y), aes(x)) +
#' geom_ribbon(data = pl, aes(ymin = `0.2`, ymax = `0.8`), fill = "lightblue") +
#' geom_point(aes(y = y), colour = "grey") + # observed data
#' geom_function(fun = sin, colour = "black") + # truth
#' geom_vline(xintercept = fd, linetype = "dashed") + # end of training data
#' geom_line(data = pl, aes(y = `0.5`), colour = "red") + # median prediction
#' theme_bw() +
#' coord_cartesian(xlim = c(0, NA)) +
#' ylab("y")
#' }
#' library(ggplot2)
#' ggplot(data.frame(x = x, y = y), aes(x)) +
#' geom_ribbon(data = pl, aes(ymin = `0.2`, ymax = `0.8`), fill = "lightblue") +
#' geom_point(aes(y = y), colour = "grey") + # observed data
#' geom_function(fun = sin, colour = "black") + # truth
#' geom_vline(xintercept = fd, linetype = "dashed") + # end of training data
#' geom_line(data = pl, aes(y = `0.5`), colour = "red") + # median prediction
#' theme_bw() +
#' coord_cartesian(xlim = c(0, NA)) +
#' ylab("y")
smooth_quantile_reg <- function(
mode = "regression",
engine = "smoothqr",
Expand Down
1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
make_flatline_reg()
make_quantile_reg()
make_smooth_quantile_reg()
make_grf_quantiles()
}
Loading
Loading