Grf qr engine #360
Merged
18 commits
055637d runs (dajmcdon)
e62b203 don't execute examples conditionally, move engine pkgs to Suggests (dajmcdon)
585456b checks pass (dajmcdon)
3028bb9 draft tests (dajmcdon)
59c7764 found the bug in pivot_quantiles_wider (dajmcdon)
1c9c6a7 Merge branch '356-pivot-quantiles-bug' into grf-qr-engine (dajmcdon)
615e112 working, documented implementation (dajmcdon)
0ba22b8 update documentation (dajmcdon)
68baf07 slightly adjust test (dajmcdon)
e51d470 bump version (dajmcdon)
7eef2f4 styler (dajmcdon)
2abbe35 no tibble in examples (dajmcdon)
a6f9cee no tibble in examples (dajmcdon)
6837da2 Merge branch 'grf-qr-engine' of https://github.com/cmu-delphi/epipred… (dajmcdon)
8c3cf4d doc: add a link and revise quantile_reg (dajmcdon)
47cdfd6 merge dev (dajmcdon)
b4ae464 missing exports (dajmcdon)
4dda39a Merge branch 'dev' into grf-qr-engine (dajmcdon)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
Package: epipredict | ||
Title: Basic epidemiology forecasting methods | ||
Version: 0.0.18 | ||
Version: 0.0.19 | ||
Authors@R: c( | ||
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")), | ||
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"), | ||
|
@@ -36,10 +36,8 @@ Imports: | |
glue, | ||
hardhat (>= 1.3.0), | ||
magrittr, | ||
quantreg, | ||
recipes (>= 1.0.4), | ||
rlang (>= 1.0.0), | ||
smoothqr, | ||
stats, | ||
tibble, | ||
tidyr, | ||
|
@@ -52,13 +50,16 @@ Suggests: | |
data.table, | ||
epidatr (>= 1.0.0), | ||
fs, | ||
grf, | ||
knitr, | ||
lubridate, | ||
poissonreg, | ||
purrr, | ||
quantreg, | ||
ranger, | ||
RcppRoll, | ||
rmarkdown, | ||
smoothqr, | ||
testthat (>= 3.0.0), | ||
usethis, | ||
xgboost | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
#' Random quantile forests via grf | ||
#' | ||
#' [grf::quantile_forest()] fits random forests in a way that makes it easy | ||
#' to calculate _quantile_ forests. Currently, this is the only engine | ||
#' provided here, since quantile regression is the typical use-case. | ||
#' | ||
#' @section Tuning Parameters: | ||
#' | ||
#' This model has 3 tuning parameters: | ||
#' | ||
#' - `mtry`: # Randomly Selected Predictors (type: integer, default: see below) | ||
#' - `trees`: # Trees (type: integer, default: 2000L) | ||
#' - `min_n`: Minimal Node Size (type: integer, default: 5) | ||
#' | ||
#' `mtry` depends on the number of columns in the design matrix. | ||
#' The default in [grf::quantile_forest()] is `min(ceiling(sqrt(ncol(X)) + 20), ncol(X))`. | ||
#' | ||
#' For categorical predictors, a one-hot encoding is always used. This makes | ||
#' splitting efficient, but has implications for the `mtry` choice. A factor | ||
#' with many levels will become a large number of columns in the design matrix | ||
#' which means that some of these may be selected frequently for potential splits. | ||
#' This differs from other implementations of random forest. For more | ||
#' details, see [the `grf` discussion](https://grf-labs.github.io/grf/articles/categorical_inputs.html). | ||
#' | ||
#' @section Translation from parsnip to the original package: | ||
#' | ||
#' ```{r, translate-engine} | ||
#' rand_forest( | ||
#' mode = "regression", # you must specify `mode = "regression"` | ||
#' mtry = integer(1), | ||
#' trees = integer(1), | ||
#' min_n = integer(1) | ||
#' ) %>% | ||
#' set_engine("grf_quantiles") %>% | ||
#' translate() | ||
#' ``` | ||
#' | ||
#' @section Case weights: | ||
#' | ||
#' Case weights are not supported. | ||
#' | ||
#' @examples | ||
#' library(grf) | ||
#' tib <- data.frame( | ||
#' y = rnorm(100), x = rnorm(100), z = rnorm(100), | ||
#' f = factor(sample(letters[1:3], 100, replace = TRUE)) | ||
#' ) | ||
#' spec <- rand_forest(engine = "grf_quantiles", mode = "regression") | ||
#' out <- fit(spec, formula = y ~ x + z, data = tib) | ||
#' predict(out, new_data = tib[1:5, ]) %>% | ||
#' pivot_quantiles_wider(.pred) | ||
#' | ||
#' # -- adjusting the desired quantiles | ||
#' | ||
#' spec <- rand_forest(mode = "regression") %>% | ||
#' set_engine(engine = "grf_quantiles", quantiles = c(1:9 / 10)) | ||
#' out <- fit(spec, formula = y ~ x + z, data = tib) | ||
#' predict(out, new_data = tib[1:5, ]) %>% | ||
#' pivot_quantiles_wider(.pred) | ||
#' | ||
#' # -- a more complicated task | ||
#' | ||
#' library(dplyr) | ||
#' dat <- case_death_rate_subset %>% | ||
#' filter(time_value > as.Date("2021-10-01")) | ||
#' rec <- epi_recipe(dat) %>% | ||
#' step_epi_lag(case_rate, death_rate, lag = c(0, 7, 14)) %>% | ||
#' step_epi_ahead(death_rate, ahead = 7) %>% | ||
#' step_epi_naomit() | ||
#' frost <- frosting() %>% | ||
#' layer_predict() %>% | ||
#' layer_threshold(.pred) | ||
#' spec <- rand_forest(mode = "regression") %>% | ||
#' set_engine(engine = "grf_quantiles", quantiles = c(.25, .5, .75)) | ||
#' | ||
#' ewf <- epi_workflow(rec, spec, frost) %>% | ||
#' fit(dat) %>% | ||
#' forecast() | ||
#' ewf %>% | ||
#' rename(forecast_date = time_value) %>% | ||
#' mutate(target_date = forecast_date + 7L) %>% | ||
#' pivot_quantiles_wider(.pred) | ||
#' | ||
#' @name grf_quantiles | ||
NULL | ||
|
||
|
||
|
||
|
||
make_grf_quantiles <- function() { | ||
parsnip::set_model_engine( | ||
model = "rand_forest", mode = "regression", eng = "grf_quantiles" | ||
) | ||
parsnip::set_dependency( | ||
model = "rand_forest", eng = "grf_quantiles", pkg = "grf", | ||
mode = "regression" | ||
) | ||
|
||
|
||
# These are the arguments to parsnip::rand_forest() that must be | ||
# translated to their grf::quantile_forest() names | ||
parsnip::set_model_arg( | ||
model = "rand_forest", | ||
eng = "grf_quantiles", | ||
parsnip = "mtry", | ||
original = "mtry", | ||
func = list(pkg = "dials", fun = "mtry"), | ||
has_submodel = FALSE | ||
) | ||
parsnip::set_model_arg( | ||
model = "rand_forest", | ||
eng = "grf_quantiles", | ||
parsnip = "trees", | ||
original = "num.trees", | ||
func = list(pkg = "dials", fun = "trees"), | ||
has_submodel = FALSE | ||
) | ||
parsnip::set_model_arg( | ||
model = "rand_forest", | ||
eng = "grf_quantiles", | ||
parsnip = "min_n", | ||
original = "min.node.size", | ||
func = list(pkg = "dials", fun = "min_n"), | ||
has_submodel = FALSE | ||
) | ||
|
||
# the `value` list describes how grf::quantile_forest expects to receive | ||
# arguments. In particular, it needs X and Y to be passed in as matrices. | ||
# But the matrix interface in parsnip calls these x and y. So the data | ||
# slot translates them | ||
# | ||
# protect - prevents the user from passing X and Y arguments themselves | ||
# defaults - engine specific arguments (not model specific) that we allow | ||
# the user to change | ||
parsnip::set_fit( | ||
model = "rand_forest", | ||
eng = "grf_quantiles", | ||
mode = "regression", | ||
value = list( | ||
interface = "matrix", | ||
protect = c("X", "Y"), | ||
data = c(x = "X", y = "Y"), | ||
func = c(pkg = "grf", fun = "quantile_forest"), | ||
defaults = list( | ||
quantiles = c(0.1, 0.5, 0.9), | ||
num.threads = 1L, | ||
seed = rlang::expr(stats::runif(1, 0, .Machine$integer.max)) | ||
) | ||
) | ||
) | ||
|
||
parsnip::set_encoding( | ||
model = "rand_forest", | ||
eng = "grf_quantiles", | ||
mode = "regression", | ||
options = list( | ||
# one hot is the closest to typical factor handling in randomForest | ||
# (1 vs all splitting), though since we aren't bagging, | ||
# factors with many levels could be visited frequently | ||
predictor_indicators = "one_hot", | ||
compute_intercept = FALSE, | ||
remove_intercept = FALSE, | ||
allow_sparse_x = FALSE | ||
) | ||
) | ||
|
||
# turn the predictions into a tibble with a dist_quantiles column | ||
process_qrf_preds <- function(x, object) { | ||
quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig | ||
x <- x$predictions | ||
out <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x))) | ||
out <- dist_quantiles(out, list(quantile_levels)) | ||
return(dplyr::tibble(.pred = out)) | ||
} | ||
|
||
parsnip::set_pred( | ||
model = "rand_forest", | ||
eng = "grf_quantiles", | ||
mode = "regression", | ||
type = "numeric", | ||
value = list( | ||
pre = NULL, | ||
post = process_qrf_preds, | ||
func = c(fun = "predict"), | ||
# map between parsnip::predict args and grf::quantile_forest args | ||
args = list( | ||
object = quote(object$fit), | ||
newdata = quote(new_data), | ||
seed = rlang::expr(sample.int(10^5, 1)), | ||
verbose = FALSE | ||
) | ||
) | ||
) | ||
} |
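A note on the `mtry` default described in the roxygen block of this file: the sketch below evaluates the quoted formula, `min(ceiling(sqrt(ncol(X)) + 20), ncol(X))`, for two design matrices, one with a 3-level factor and one with a 500-level factor, both one-hot encoded as the documentation says. The helper names `default_mtry` and `n_design_cols` are hypothetical and exist only for this illustration; they are not part of grf or epipredict, and the formula is taken from the documentation text above rather than checked against the grf source.

```r
# Default `mtry` as quoted in the documentation in this diff:
#   min(ceiling(sqrt(ncol(X)) + 20), ncol(X))
# `default_mtry` is a hypothetical helper, not a grf or epipredict function.
default_mtry <- function(p) {
  min(ceiling(sqrt(p) + 20), p)
}

# Number of design-matrix columns, assuming numeric predictors are kept as-is
# and every factor is one-hot encoded (one column per level).
# `n_design_cols` is also a hypothetical helper.
n_design_cols <- function(n_numeric, factor_levels) {
  n_numeric + sum(factor_levels)
}

# Two numeric predictors plus one factor with 3 levels.
p_small <- n_design_cols(n_numeric = 2, factor_levels = 3)

# Two numeric predictors plus one factor with 500 levels.
p_large <- n_design_cols(n_numeric = 2, factor_levels = 500)

mtry_small <- default_mtry(p_small)
mtry_large <- default_mtry(p_large)

# Share of candidate columns that come from the factor in each case.
share_small <- 3 / p_small
share_large <- 500 / p_large
```

Under this formula, narrow design matrices have every column considered at each split, while in the wide case almost all candidate columns are indicator columns from the single factor. This matches the caution in the documentation about factors with many levels.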
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,4 +8,5 @@ | |
make_flatline_reg() | ||
make_quantile_reg() | ||
make_smooth_quantile_reg() | ||
make_grf_quantiles() | ||
} |