diff --git a/R/use.R b/R/use.R index 878d0b4..cf2f0c2 100644 --- a/R/use.R +++ b/R/use.R @@ -154,18 +154,21 @@ use_xgboost <- function(formula, data, prefix = "xgboost", verbose = FALSE, prm <- rlang::exprs( trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), - loss_reduction = tune(), sample_size = tune() + loss_reduction = tune(), sample_size = tune(), mtry = tune() ) + mod_syntax <- + paste0(prefix, "_spec") %>% + assign_value(!!rlang::call2("boost_tree", !!!prm)) %>% + pipe_value(set_mode(!!model_mode(rec))) %>% + pipe_value(set_engine("xgboost", count = TRUE)) } else { - prm <- NULL + mod_syntax <- + paste0(prefix, "_spec") %>% + assign_value(!!rlang::call2("boost_tree")) %>% + pipe_value(set_mode(!!model_mode(rec))) %>% + pipe_value(set_engine("xgboost")) } - mod_syntax <- - paste0(prefix, "_spec") %>% - assign_value(!!rlang::call2("boost_tree", !!!prm)) %>% - pipe_value(set_mode(!!model_mode(rec))) %>% - pipe_value(set_engine("xgboost")) - route(rec_syntax, path = pth) route(mod_syntax, path = pth) route(template_workflow(prefix), path = pth) diff --git a/tests/testthat/_snaps/templates.md b/tests/testthat/_snaps/templates.md index abd555d..dc7037a 100644 --- a/tests/testthat/_snaps/templates.md +++ b/tests/testthat/_snaps/templates.md @@ -1511,9 +1511,9 @@ test_config_31_dummies_spec <- boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), - loss_reduction = tune(), sample_size = tune()) %>% + loss_reduction = tune(), sample_size = tune(), mtry = tune()) %>% set_mode("regression") %>% - set_engine("xgboost") + set_engine("xgboost", count = TRUE) test_config_31_dummies_workflow <- workflow() %>% @@ -1545,9 +1545,9 @@ test_config_31_no_dummies_spec <- boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), - loss_reduction = tune(), sample_size = tune()) %>% + loss_reduction = tune(), sample_size = tune(), mtry = tune()) %>% set_mode("classification") %>% - set_engine("xgboost") + set_engine("xgboost", count = TRUE) test_config_31_no_dummies_workflow <- workflow() %>% @@ -2963,9 +2963,9 @@ test_config_63_dummies_spec <- boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), - loss_reduction = tune(), sample_size = tune()) %>% + loss_reduction = tune(), sample_size = tune(), mtry = tune()) %>% set_mode("regression") %>% - set_engine("xgboost") + set_engine("xgboost", count = TRUE) test_config_63_dummies_workflow <- workflow() %>% @@ -2991,9 +2991,9 @@ test_config_63_no_dummies_spec <- boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), - loss_reduction = tune(), sample_size = tune()) %>% + loss_reduction = tune(), sample_size = tune(), mtry = tune()) %>% set_mode("classification") %>% - set_engine("xgboost") + set_engine("xgboost", count = TRUE) test_config_63_no_dummies_workflow <- workflow() %>%