Skip to content

Commit eb526fa

Browse files
authored
clarify case weight support in show_model_info() (#1102)
1 parent e5c7f92 commit eb526fa

File tree

3 files changed

+113
-20
lines changed

3 files changed

+113
-20
lines changed

R/aaa_models.R

+10-4
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,7 @@ show_model_info <- function(model) {
991991
) %>%
992992
dplyr::select(engine, mode, has_wts)
993993

994-
engines %>%
994+
engine_weight_info <- engines %>%
995995
dplyr::left_join(weight_info, by = c("engine", "mode")) %>%
996996
dplyr::mutate(
997997
engine = paste0(engine, has_wts),
@@ -1005,9 +1005,15 @@ show_model_info <- function(model) {
10051005
lab = paste0(" ", mode, engine, "\n")
10061006
) %>%
10071007
dplyr::ungroup() %>%
1008-
dplyr::pull(lab) %>%
1009-
cat(sep = "")
1010-
cat("\n", cli::symbol$sup_1, "The model can use case weights.\n\n", sep = "")
1008+
dplyr::pull(lab)
1009+
1010+
cat(engine_weight_info, sep = "")
1011+
1012+
if (!all(weight_info$has_wts == "")) {
1013+
cat("\n", cli::symbol$sup_1, "The model can use case weights.", sep = "")
1014+
}
1015+
1016+
cat("\n\n")
10111017
} else {
10121018
cat(" no registered engines.\n\n")
10131019
}

tests/testthat/_snaps/registration.md

+98
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,101 @@
66
Error in `check_mode_for_new_engine()`:
77
! "regression" is not a known mode for model `sponge()`.
88

9+
# showing model info
10+
11+
Code
12+
show_model_info("rand_forest")
13+
Output
14+
Information for `rand_forest`
15+
modes: unknown, classification, regression, censored regression
16+
17+
engines:
18+
classification: randomForest, ranger1, spark
19+
regression: randomForest, ranger1, spark
20+
21+
1The model can use case weights.
22+
23+
arguments:
24+
ranger:
25+
mtry --> mtry
26+
trees --> num.trees
27+
min_n --> min.node.size
28+
randomForest:
29+
mtry --> mtry
30+
trees --> ntree
31+
min_n --> nodesize
32+
spark:
33+
mtry --> feature_subset_strategy
34+
trees --> num_trees
35+
min_n --> min_instances_per_node
36+
37+
fit modules:
38+
engine mode
39+
ranger classification
40+
ranger regression
41+
randomForest classification
42+
randomForest regression
43+
spark classification
44+
spark regression
45+
46+
prediction modules:
47+
mode engine methods
48+
classification randomForest class, prob, raw
49+
classification ranger class, conf_int, prob, raw
50+
classification spark class, prob
51+
regression randomForest numeric, raw
52+
regression ranger conf_int, numeric, raw
53+
regression spark numeric
54+
55+
56+
---
57+
58+
Code
59+
show_model_info("mlp")
60+
Output
61+
Information for `mlp`
62+
modes: unknown, classification, regression
63+
64+
engines:
65+
classification: brulee, keras, nnet
66+
regression: brulee, keras, nnet
67+
68+
69+
arguments:
70+
keras:
71+
hidden_units --> hidden_units
72+
penalty --> penalty
73+
dropout --> dropout
74+
epochs --> epochs
75+
activation --> activation
76+
nnet:
77+
hidden_units --> size
78+
penalty --> decay
79+
epochs --> maxit
80+
brulee:
81+
hidden_units --> hidden_units
82+
penalty --> penalty
83+
epochs --> epochs
84+
dropout --> dropout
85+
learn_rate --> learn_rate
86+
activation --> activation
87+
88+
fit modules:
89+
engine mode
90+
keras regression
91+
keras classification
92+
nnet regression
93+
nnet classification
94+
brulee regression
95+
brulee classification
96+
97+
prediction modules:
98+
mode engine methods
99+
classification brulee class, prob
100+
classification keras class, prob, raw
101+
classification nnet class, prob, raw
102+
regression brulee numeric
103+
regression keras numeric, raw
104+
regression nnet numeric, raw
105+
106+

tests/testthat/test_registration.R

+5-16
Original file line numberDiff line numberDiff line change
@@ -496,21 +496,10 @@ test_that('adding a new predict method', {
496496

497497

498498
test_that('showing model info', {
499-
expect_output(
500-
show_model_info("rand_forest"),
501-
"Information for `rand_forest`"
502-
)
503-
expect_output(
504-
show_model_info("rand_forest"),
505-
"trees --> ntree"
506-
)
507-
expect_output(
508-
show_model_info("rand_forest"),
509-
"fit modules:"
510-
)
511-
expect_output(
512-
show_model_info("rand_forest"),
513-
"prediction modules:"
514-
)
499+
expect_snapshot(show_model_info("rand_forest"))
500+
501+
# ensure that we don't mention case weight support when the
502+
# notation would be ambiguous (#1000)
503+
expect_snapshot(show_model_info("mlp"))
515504
})
516505

0 commit comments

Comments
 (0)