Skip to content

Commit a9889f0

Browse files
authored
predict ordinal factors from ordinal regression models (#1217)
1 parent a4f9811 commit a9889f0

File tree

7 files changed

+60
-6
lines changed

7 files changed

+60
-6
lines changed

NEWS.md

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
## Bug Fixes
2727

28+
* Make sure that parsnip does not convert ordered factor predictions to be unordered (#1217).
29+
2830
* Ensure that `knit_engine_docs()` has the required packages installed (#1156).
2931

3032
* Fixed bug where some models fit using `fit_xy()` couldn't predict (#1166).

R/fit.R

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
#' \itemize{
8888
#' \item \code{lvl}: If the outcome is a factor, this contains
8989
#' the factor levels at the time of model fitting.
90+
#' \item \code{ordered}: If the outcome is a factor, was it an ordered factor?
9091
#' \item \code{spec}: The model specification object
9192
#' (\code{object} in the call to \code{fit})
9293
#' \item \code{fit}: when the model is executed without error,

R/fit_helpers.R

+3-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ form_form <-
4040
fit_call <- make_form_call(object, env = env)
4141

4242
res <- list(
43-
lvl = y_levels,
43+
lvl = y_levels$lvl,
44+
ordered = y_levels$ordered,
4445
spec = object
4546
)
4647

@@ -98,7 +99,7 @@ xy_xy <- function(object,
9899

99100
fit_call <- make_xy_call(object, target, env, call)
100101

101-
res <- list(lvl = levels(env$y), spec = object)
102+
res <- list(lvl = levels(env$y), ordered = is.ordered(env$y), spec = object)
102103

103104
time <- proc.time()
104105
res$fit <- eval_mod(

R/misc.R

+5-2
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,12 @@ convert_arg <- function(x) {
260260

261261
# Collect outcome level information for a formula-based fit.
#
# `f` is the model formula and `dat` is the data its left-hand side is
# evaluated against. Returns a list describing the outcome:
#   * `lvls`:    the result of `levels()` on the evaluated outcome
#   * `ordered`: `TRUE` when the evaluated outcome is an ordered factor
levels_from_formula <- function(f, dat) {
  # Spark tables short-circuit with an "unknown levels, not ordered" answer
  # (presumably because the outcome cannot be evaluated locally -- confirm).
  if (inherits(dat, "tbl_spark")) {
    return(list(lvls = NULL, ordered = FALSE))
  }

  outcome <- eval_tidy(rlang::f_lhs(f), dat)

  info <- list()
  # Element-wise assignment is deliberate: when `levels()` returns NULL
  # (non-factor outcome) no `lvls` element is created at all.
  info$lvls <- levels(outcome)
  info$ordered <- is.ordered(outcome)
  info
}

R/predict_class.R

+4-2
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,16 @@ predict_class.model_fit <- function(object, new_data, ...) {
4141

4242
# coerce levels to those in `object`
4343
if (is.vector(res) || is.factor(res)) {
44-
res <- factor(as.character(res), levels = object$lvl)
44+
res <- factor(as.character(res), levels = object$lvl, ordered = object$ordered)
4545
} else {
4646
if (!inherits(res, "tbl_spark")) {
4747
# Now case where a parsnip model generated `res`
4848
if (is.data.frame(res) && ncol(res) == 1 && is.factor(res[[1]])) {
4949
res <- res[[1]]
5050
} else {
51-
res$values <- factor(as.character(res$values), levels = object$lvl)
51+
res$values <- factor(as.character(res$values),
52+
levels = object$lvl,
53+
ordered = object$ordered)
5254
}
5355
}
5456
}

man/fit.Rd

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-predict_formats.R

+44
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,50 @@ test_that('classification predictions', {
4343
c(".pred_high", ".pred_low"))
4444
})
4545

46+
47+
test_that('ordinal classification predictions', {
  skip_if_not_installed("modeldata")
  skip_if_not_installed("rpart")

  # Simulated multiclass data whose outcome is converted to an ordered factor,
  # so both fitting interfaces see an ordinal response.
  set.seed(382)
  dat_tr <-
    modeldata::sim_multinomial(
      200,
      ~ -0.5 + 0.6 * abs(A),
      ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2),
      ~ -0.6 * A + 0.50 * B - A * B) %>%
    dplyr::mutate(class = as.ordered(class))
  dat_te <-
    modeldata::sim_multinomial(
      5,
      ~ -0.5 + 0.6 * abs(A),
      ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2),
      ~ -0.6 * A + 0.50 * B - A * B) %>%
    dplyr::mutate(class = as.ordered(class))

  ### Formula interface

  mod_f_fit <-
    decision_tree() %>%
    set_mode("classification") %>%
    fit(class ~ ., data = dat_tr)
  expect_true("ordered" %in% names(mod_f_fit))
  mod_f_pred <- predict(mod_f_fit, dat_te)
  expect_true(is.ordered(mod_f_pred$.pred_class))

  ### x/y interface

  mod_xy_fit <-
    decision_tree() %>%
    set_mode("classification") %>%
    fit_xy(x = dat_tr %>% dplyr::select(-class), y = dat_tr$class)

  expect_true("ordered" %in% names(mod_xy_fit))
  mod_xy_pred <- predict(mod_xy_fit, dat_te)
  # Check the predictions from the x/y fit; the formula-fit predictions were
  # already verified above.
  expect_true(is.ordered(mod_xy_pred$.pred_class))
})
88+
89+
4690
test_that('non-standard levels', {
4791
expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5,-1])))
4892
expect_true(is.factor(parsnip:::predict_class.model_fit(lr_fit, new_data = class_dat[1:5,-1])))

0 commit comments

Comments
 (0)