@@ -43,6 +43,50 @@ test_that('classification predictions', {
43
43
c(" .pred_high" , " .pred_low" ))
44
44
})
45
45
46
+
47
+ test_that(' ordinal classification predictions' , {
48
+ skip_if_not_installed(" modeldata" )
49
+ skip_if_not_installed(" rpart" )
50
+
51
+ set.seed(382 )
52
+ dat_tr <-
53
+ modeldata :: sim_multinomial(
54
+ 200 ,
55
+ ~ - 0.5 + 0.6 * abs(A ),
56
+ ~ ifelse(A > 0 & B > 0 , 1.0 + 0.2 * A / B , - 2 ),
57
+ ~ - 0.6 * A + 0.50 * B - A * B ) %> %
58
+ dplyr :: mutate(class = as.ordered(class ))
59
+ dat_te <-
60
+ modeldata :: sim_multinomial(
61
+ 5 ,
62
+ ~ - 0.5 + 0.6 * abs(A ),
63
+ ~ ifelse(A > 0 & B > 0 , 1.0 + 0.2 * A / B , - 2 ),
64
+ ~ - 0.6 * A + 0.50 * B - A * B ) %> %
65
+ dplyr :: mutate(class = as.ordered(class ))
66
+
67
+ # ##
68
+
69
+ mod_f_fit <-
70
+ decision_tree() %> %
71
+ set_mode(" classification" ) %> %
72
+ fit(class ~ . , data = dat_tr )
73
+ expect_true(" ordered" %in% names(mod_f_fit ))
74
+ mod_f_pred <- predict(mod_f_fit , dat_te )
75
+ expect_true(is.ordered(mod_f_pred $ .pred_class ))
76
+
77
+ # ##
78
+
79
+ mod_xy_fit <-
80
+ decision_tree() %> %
81
+ set_mode(" classification" ) %> %
82
+ fit_xy(x = dat_tr %> % dplyr :: select(- class ), dat_tr $ class )
83
+
84
+ expect_true(" ordered" %in% names(mod_xy_fit ))
85
+ mod_xy_pred <- predict(mod_xy_fit , dat_te )
86
+ expect_true(is.ordered(mod_f_pred $ .pred_class ))
87
+ })
88
+
89
+
46
90
test_that(' non-standard levels' , {
47
91
expect_true(is_tibble(predict(lr_fit , new_data = class_dat [1 : 5 ,- 1 ])))
48
92
expect_true(is.factor(parsnip ::: predict_class.model_fit(lr_fit , new_data = class_dat [1 : 5 ,- 1 ])))
0 commit comments