Skip to content

Commit b6e5848

Browse files
topepo‘topepo’hfrick
authored
Brulee neural networks with two hidden layers (#1187)
* engine work for two-layer mlp models * temp change for tunable * enable parsnip to work with functions with parameterized labels * Revert "enable parsnip to work with functions with parameterized labels" This reverts commit bec0423. * update tunable method for brulee mlps * update for other brulee engines * add rate_schedule to tunables * doc updates * typo fix * update test * fix two bugs in brulee tunable methods * Update man/rmd/mlp_brulee_two_layer.Rmd Co-authored-by: Hannah Frick <[email protected]> * remove () from argument names * add mode to set_dependency() * update tunables and tests * redoc --------- Co-authored-by: ‘topepo’ <‘[email protected]’> Co-authored-by: Hannah Frick <[email protected]>
1 parent 1880a48 commit b6e5848

21 files changed

+1103
-109
lines changed

Diff for: NAMESPACE

+1-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ S3method(tunable,logistic_reg)
132132
S3method(tunable,mars)
133133
S3method(tunable,mlp)
134134
S3method(tunable,model_spec)
135-
S3method(tunable,multinomial_reg)
135+
S3method(tunable,multinom_reg)
136136
S3method(tunable,rand_forest)
137137
S3method(tunable,survival_reg)
138138
S3method(tunable,svm_poly)

Diff for: R/mlp_brulee_two_layer.R

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#' Multilayer perceptron via brulee with two hidden layers
2+
#'
3+
#' [brulee::brulee_mlp_two_layer()] fits a neural network with two hidden layers (requires version 0.3.0.9000 or higher of brulee).
4+
#'
5+
#' @includeRmd man/rmd/mlp_brulee_two_layer.md details
6+
#'
7+
#' @name details_mlp_brulee_two_layer
8+
#' @keywords internal
9+
NULL
10+
11+
# See inst/README-DOCS.md for a description of how these files are processed

Diff for: R/mlp_data.R

+165-1
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,8 @@ set_pred(
368368

369369
set_model_engine("mlp", "classification", "brulee")
370370
set_model_engine("mlp", "regression", "brulee")
371-
set_dependency("mlp", "brulee", "brulee")
371+
set_dependency("mlp", "brulee", "brulee", mode = "classification")
372+
set_dependency("mlp", "brulee", "brulee", mode = "regression")
372373

373374
set_model_arg(
374375
model = "mlp",
@@ -527,3 +528,166 @@ set_pred(
527528
)
528529
)
529530

531+
532+
set_model_engine("mlp", "classification", "brulee_two_layer")
533+
set_model_engine("mlp", "regression", "brulee_two_layer")
534+
set_dependency("mlp", "brulee_two_layer", "brulee", mode = "classification")
535+
set_dependency("mlp", "brulee_two_layer", "brulee", mode = "regression")
536+
537+
set_model_arg(
538+
model = "mlp",
539+
eng = "brulee_two_layer",
540+
parsnip = "hidden_units",
541+
original = "hidden_units",
542+
func = list(pkg = "dials", fun = "hidden_units"),
543+
has_submodel = FALSE
544+
)
545+
546+
set_model_arg(
547+
model = "mlp",
548+
eng = "brulee_two_layer",
549+
parsnip = "penalty",
550+
original = "penalty",
551+
func = list(pkg = "dials", fun = "penalty"),
552+
has_submodel = FALSE
553+
)
554+
555+
set_model_arg(
556+
model = "mlp",
557+
eng = "brulee_two_layer",
558+
parsnip = "epochs",
559+
original = "epochs",
560+
func = list(pkg = "dials", fun = "epochs"),
561+
has_submodel = FALSE
562+
)
563+
564+
set_model_arg(
565+
model = "mlp",
566+
eng = "brulee_two_layer",
567+
parsnip = "dropout",
568+
original = "dropout",
569+
func = list(pkg = "dials", fun = "dropout"),
570+
has_submodel = FALSE
571+
)
572+
573+
set_model_arg(
574+
model = "mlp",
575+
eng = "brulee_two_layer",
576+
parsnip = "learn_rate",
577+
original = "learn_rate",
578+
func = list(pkg = "dials", fun = "learn_rate", range = c(-2.5, -0.5)),
579+
has_submodel = FALSE
580+
)
581+
582+
set_model_arg(
583+
model = "mlp",
584+
eng = "brulee_two_layer",
585+
parsnip = "activation",
586+
original = "activation",
587+
func = list(pkg = "dials", fun = "activation", values = c('relu', 'elu', 'tanh')),
588+
has_submodel = FALSE
589+
)
590+
591+
592+
set_fit(
593+
model = "mlp",
594+
eng = "brulee_two_layer",
595+
mode = "regression",
596+
value = list(
597+
interface = "data.frame",
598+
protect = c("x", "y"),
599+
func = c(pkg = "brulee", fun = "brulee_mlp_two_layer"),
600+
defaults = list()
601+
)
602+
)
603+
604+
set_encoding(
605+
model = "mlp",
606+
eng = "brulee_two_layer",
607+
mode = "regression",
608+
options = list(
609+
predictor_indicators = "none",
610+
compute_intercept = FALSE,
611+
remove_intercept = FALSE,
612+
allow_sparse_x = FALSE
613+
)
614+
)
615+
616+
set_fit(
617+
model = "mlp",
618+
eng = "brulee_two_layer",
619+
mode = "classification",
620+
value = list(
621+
interface = "data.frame",
622+
protect = c("x", "y"),
623+
func = c(pkg = "brulee", fun = "brulee_mlp_two_layer"),
624+
defaults = list()
625+
)
626+
)
627+
628+
set_encoding(
629+
model = "mlp",
630+
eng = "brulee_two_layer",
631+
mode = "classification",
632+
options = list(
633+
predictor_indicators = "none",
634+
compute_intercept = FALSE,
635+
remove_intercept = FALSE,
636+
allow_sparse_x = FALSE
637+
)
638+
)
639+
640+
set_pred(
641+
model = "mlp",
642+
eng = "brulee_two_layer",
643+
mode = "regression",
644+
type = "numeric",
645+
value = list(
646+
pre = NULL,
647+
post = reformat_torch_num,
648+
func = c(fun = "predict"),
649+
args =
650+
list(
651+
object = quote(object$fit),
652+
new_data = quote(new_data),
653+
type = "numeric"
654+
)
655+
)
656+
)
657+
658+
set_pred(
659+
model = "mlp",
660+
eng = "brulee_two_layer",
661+
mode = "classification",
662+
type = "class",
663+
value = list(
664+
pre = NULL,
665+
post = NULL,
666+
func = c(fun = "predict"),
667+
args =
668+
list(
669+
object = quote(object$fit),
670+
new_data = quote(new_data),
671+
type = "class"
672+
)
673+
)
674+
)
675+
676+
set_pred(
677+
model = "mlp",
678+
eng = "brulee_two_layer",
679+
mode = "classification",
680+
type = "prob",
681+
value = list(
682+
pre = NULL,
683+
post = NULL,
684+
func = c(fun = "predict"),
685+
args =
686+
list(
687+
object = quote(object$fit),
688+
new_data = quote(new_data),
689+
type = "prob"
690+
)
691+
)
692+
)
693+

Diff for: R/tunable.R

+76-52
Original file line numberDiff line numberDiff line change
@@ -194,37 +194,6 @@ earth_engine_args <-
194194
component_id = "engine"
195195
)
196196

197-
brulee_mlp_engine_args <-
198-
tibble::tribble(
199-
~name, ~call_info,
200-
"momentum", list(pkg = "dials", fun = "momentum", range = c(0.5, 0.95)),
201-
"batch_size", list(pkg = "dials", fun = "batch_size", range = c(3, 10)),
202-
"stop_iter", list(pkg = "dials", fun = "stop_iter"),
203-
"class_weights", list(pkg = "dials", fun = "class_weights"),
204-
"decay", list(pkg = "dials", fun = "rate_decay"),
205-
"initial", list(pkg = "dials", fun = "rate_initial"),
206-
"largest", list(pkg = "dials", fun = "rate_largest"),
207-
"rate_schedule", list(pkg = "dials", fun = "rate_schedule"),
208-
"step_size", list(pkg = "dials", fun = "rate_step_size"),
209-
"mixture", list(pkg = "dials", fun = "mixture")
210-
) %>%
211-
dplyr::mutate(source = "model_spec",
212-
component = "mlp",
213-
component_id = "engine"
214-
)
215-
216-
brulee_linear_engine_args <-
217-
brulee_mlp_engine_args %>%
218-
dplyr::filter(name %in% c("momentum", "batch_size", "stop_iter"))
219-
220-
brulee_logistic_engine_args <-
221-
brulee_mlp_engine_args %>%
222-
dplyr::filter(name %in% c("momentum", "batch_size", "stop_iter", "class_weights"))
223-
224-
brulee_multinomial_engine_args <-
225-
brulee_mlp_engine_args %>%
226-
dplyr::filter(name %in% c("momentum", "batch_size", "stop_iter", "class_weights"))
227-
228197
flexsurvspline_engine_args <-
229198
tibble::tibble(
230199
name = c("k"),
@@ -236,6 +205,42 @@ flexsurvspline_engine_args <-
236205
component_id = "engine"
237206
)
238207

208+
# ------------------------------------------------------------------------------
209+
# used for brulee engines:
210+
211+
tune_activations <- c("relu", "tanh", "elu", "log_sigmoid", "tanhshrink")
212+
tune_sched <- c("none", "decay_time", "decay_expo", "cyclic", "step")
213+
214+
# Tunable-parameter table used by the brulee engines.
#
# Each row pairs a parameter name with the dials function (plus any custom
# `range`/`values`) recorded as its `call_info`. The table is written row-wise
# with tribble() so that every name sits next to its own call: the earlier
# parallel `name` / `call_info` vectors had `class_weights` and `stop_iter`
# lined up with each other's dials functions.
#
# The `tunable()` methods in this file add `component` and `component_id`
# and drop the rows that do not apply to their model type.
brulee_mlp_args <-
  tibble::tribble(
    ~name,            ~call_info,
    "epochs",         list(pkg = "dials", fun = "epochs", range = c(5L, 500L)),
    "hidden_units",   list(pkg = "dials", fun = "hidden_units", range = c(2L, 50L)),
    "hidden_units_2", list(pkg = "dials", fun = "hidden_units_2", range = c(2L, 50L)),
    "activation",     list(pkg = "dials", fun = "activation", values = tune_activations),
    "activation_2",   list(pkg = "dials", fun = "activation_2", values = tune_activations),
    "penalty",        list(pkg = "dials", fun = "penalty"),
    "mixture",        list(pkg = "dials", fun = "mixture"),
    "dropout",        list(pkg = "dials", fun = "dropout"),
    "learn_rate",     list(pkg = "dials", fun = "learn_rate", range = c(-3, -1/5)),
    "momentum",       list(pkg = "dials", fun = "momentum", range = c(0.50, 0.95)),
    "batch_size",     list(pkg = "dials", fun = "batch_size"),
    "class_weights",  list(pkg = "dials", fun = "class_weights"),
    "stop_iter",      list(pkg = "dials", fun = "stop_iter"),
    "rate_schedule",  list(pkg = "dials", fun = "rate_schedule", values = tune_sched)
  ) %>%
  dplyr::mutate(source = "model_spec")
237+
238+
brulee_mlp_only_args <-
239+
tibble::tibble(
240+
name =
241+
c('hidden_units', 'hidden_units_2', 'activation', 'activation_2', 'dropout')
242+
)
243+
239244
# ------------------------------------------------------------------------------
240245

241246
#' @export
@@ -245,31 +250,55 @@ tunable.linear_reg <- function(x, ...) {
245250
res$call_info[res$name == "mixture"] <-
246251
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
247252
} else if (x$engine == "brulee") {
248-
res <- add_engine_parameters(res, brulee_linear_engine_args)
253+
res <-
254+
brulee_mlp_args %>%
255+
dplyr::anti_join(brulee_mlp_only_args, by = "name") %>%
256+
dplyr::filter(name != "class_weights") %>%
257+
dplyr::mutate(
258+
component = "linear_reg",
259+
component_id = ifelse(name %in% names(formals("linear_reg")), "main", "engine")
260+
) %>%
261+
dplyr::select(name, call_info, source, component, component_id)
249262
}
250263
res
251264
}
252265

266+
#' @export
267+
253268
#' @export
254269
tunable.logistic_reg <- function(x, ...) {
255270
res <- NextMethod()
256271
if (x$engine == "glmnet") {
257272
res$call_info[res$name == "mixture"] <-
258273
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
259274
} else if (x$engine == "brulee") {
260-
res <- add_engine_parameters(res, brulee_logistic_engine_args)
275+
res <-
276+
brulee_mlp_args %>%
277+
dplyr::anti_join(brulee_mlp_only_args, by = "name") %>%
278+
dplyr::mutate(
279+
component = "logistic_reg",
280+
component_id = ifelse(name %in% names(formals("logistic_reg")), "main", "engine")
281+
) %>%
282+
dplyr::select(name, call_info, source, component, component_id)
261283
}
262284
res
263285
}
264286

265287
#' @export
266-
tunable.multinomial_reg <- function(x, ...) {
288+
tunable.multinom_reg <- function(x, ...) {
267289
res <- NextMethod()
268290
if (x$engine == "glmnet") {
269291
res$call_info[res$name == "mixture"] <-
270292
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
271293
} else if (x$engine == "brulee") {
272-
res <- add_engine_parameters(res, brulee_multinomial_engine_args)
294+
res <-
295+
brulee_mlp_args %>%
296+
dplyr::anti_join(brulee_mlp_only_args, by = "name") %>%
297+
dplyr::mutate(
298+
component = "multinom_reg",
299+
component_id = ifelse(name %in% names(formals("multinom_reg")), "main", "engine")
300+
) %>%
301+
dplyr::select(name, call_info, source, component, component_id)
273302
}
274303
res
275304
}
@@ -345,28 +374,23 @@ tunable.svm_poly <- function(x, ...) {
345374
res
346375
}
347376

348-
349377
#' @export
350378
tunable.mlp <- function(x, ...) {
351379
res <- NextMethod()
352-
if (x$engine == "brulee") {
353-
res <- add_engine_parameters(res, brulee_mlp_engine_args)
354-
res$call_info[res$name == "learn_rate"] <-
355-
list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1/2)))
356-
res$call_info[res$name == "epochs"] <-
357-
list(list(pkg = "dials", fun = "epochs", range = c(5L, 500L)))
358-
activation_values <- rlang::eval_tidy(
359-
rlang::call2("brulee_activations", .ns = "brulee")
360-
)
361-
res$call_info[res$name == "activation"] <-
362-
list(list(pkg = "dials", fun = "activation", values = activation_values))
363-
} else if (x$engine == "keras") {
364-
activation_values <- parsnip::keras_activations()
365-
res$call_info[res$name == "activation"] <-
366-
list(list(pkg = "dials", fun = "activation", values = activation_values))
380+
if (grepl("brulee", x$engine)) {
381+
res <-
382+
brulee_mlp_args %>%
383+
dplyr::mutate(
384+
component = "mlp",
385+
component_id = ifelse(name %in% names(formals("mlp")), "main", "engine")
386+
) %>%
387+
dplyr::select(name, call_info, source, component, component_id)
388+
if (x$engine == "brulee") {
389+
res <- res[!grepl("_2", res$name),]
390+
}
367391
}
368392
res
369-
}
393+
}
370394

371395
#' @export
372396
tunable.survival_reg <- function(x, ...) {

0 commit comments

Comments
 (0)