@@ -194,37 +194,6 @@ earth_engine_args <-
194
194
component_id = " engine"
195
195
)
196
196
197
- brulee_mlp_engine_args <-
198
- tibble :: tribble(
199
- ~ name , ~ call_info ,
200
- " momentum" , list (pkg = " dials" , fun = " momentum" , range = c(0.5 , 0.95 )),
201
- " batch_size" , list (pkg = " dials" , fun = " batch_size" , range = c(3 , 10 )),
202
- " stop_iter" , list (pkg = " dials" , fun = " stop_iter" ),
203
- " class_weights" , list (pkg = " dials" , fun = " class_weights" ),
204
- " decay" , list (pkg = " dials" , fun = " rate_decay" ),
205
- " initial" , list (pkg = " dials" , fun = " rate_initial" ),
206
- " largest" , list (pkg = " dials" , fun = " rate_largest" ),
207
- " rate_schedule" , list (pkg = " dials" , fun = " rate_schedule" ),
208
- " step_size" , list (pkg = " dials" , fun = " rate_step_size" ),
209
- " mixture" , list (pkg = " dials" , fun = " mixture" )
210
- ) %> %
211
- dplyr :: mutate(source = " model_spec" ,
212
- component = " mlp" ,
213
- component_id = " engine"
214
- )
215
-
216
- brulee_linear_engine_args <-
217
- brulee_mlp_engine_args %> %
218
- dplyr :: filter(name %in% c(" momentum" , " batch_size" , " stop_iter" ))
219
-
220
- brulee_logistic_engine_args <-
221
- brulee_mlp_engine_args %> %
222
- dplyr :: filter(name %in% c(" momentum" , " batch_size" , " stop_iter" , " class_weights" ))
223
-
224
- brulee_multinomial_engine_args <-
225
- brulee_mlp_engine_args %> %
226
- dplyr :: filter(name %in% c(" momentum" , " batch_size" , " stop_iter" , " class_weights" ))
227
-
228
197
flexsurvspline_engine_args <-
229
198
tibble :: tibble(
230
199
name = c(" k" ),
@@ -236,6 +205,42 @@ flexsurvspline_engine_args <-
236
205
component_id = " engine"
237
206
)
238
207
208
+ # ------------------------------------------------------------------------------
209
+ # used for brulee engines:
210
+
211
+ tune_activations <- c(" relu" , " tanh" , " elu" , " log_sigmoid" , " tanhshrink" )
212
+ tune_sched <- c(" none" , " decay_time" , " decay_expo" , " cyclic" , " step" )
213
+
214
+ brulee_mlp_args <-
215
+ tibble :: tibble(
216
+ name = c(' epochs' , ' hidden_units' , ' hidden_units_2' , ' activation' , ' activation_2' ,
217
+ ' penalty' , ' mixture' , ' dropout' , ' learn_rate' , ' momentum' , ' batch_size' ,
218
+ ' class_weights' , ' stop_iter' , ' rate_schedule' ),
219
+ call_info = list (
220
+ list (pkg = " dials" , fun = " epochs" , range = c(5L , 500L )),
221
+ list (pkg = " dials" , fun = " hidden_units" , range = c(2L , 50L )),
222
+ list (pkg = " dials" , fun = " hidden_units_2" , range = c(2L , 50L )),
223
+ list (pkg = " dials" , fun = " activation" , values = tune_activations ),
224
+ list (pkg = " dials" , fun = " activation_2" , values = tune_activations ),
225
+ list (pkg = " dials" , fun = " penalty" ),
226
+ list (pkg = " dials" , fun = " mixture" ),
227
+ list (pkg = " dials" , fun = " dropout" ),
228
+ list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 5 )),
229
+ list (pkg = " dials" , fun = " momentum" , range = c(0.50 , 0.95 )),
230
+ list (pkg = " dials" , fun = " batch_size" ),
231
+ list (pkg = " dials" , fun = " stop_iter" ),
232
+ list (pkg = " dials" , fun = " class_weights" ),
233
+ list (pkg = " dials" , fun = " rate_schedule" , values = tune_sched )
234
+ )
235
+ ) %> %
236
+ dplyr :: mutate(source = " model_spec" )
237
+
238
+ brulee_mlp_only_args <-
239
+ tibble :: tibble(
240
+ name =
241
+ c(' hidden_units' , ' hidden_units_2' , ' activation' , ' activation_2' , ' dropout' )
242
+ )
243
+
239
244
# ------------------------------------------------------------------------------
240
245
241
246
# ' @export
@@ -245,31 +250,55 @@ tunable.linear_reg <- function(x, ...) {
245
250
res $ call_info [res $ name == " mixture" ] <-
246
251
list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
247
252
} else if (x $ engine == " brulee" ) {
248
- res <- add_engine_parameters(res , brulee_linear_engine_args )
253
+ res <-
254
+ brulee_mlp_args %> %
255
+ dplyr :: anti_join(brulee_mlp_only_args , by = " name" ) %> %
256
+ dplyr :: filter(name != " class_weights" ) %> %
257
+ dplyr :: mutate(
258
+ component = " linear_reg" ,
259
+ component_id = ifelse(name %in% names(formals(" linear_reg" )), " main" , " engine" )
260
+ ) %> %
261
+ dplyr :: select(name , call_info , source , component , component_id )
249
262
}
250
263
res
251
264
}
252
265
266
+ # ' @export
267
+
253
268
# ' @export
254
269
tunable.logistic_reg <- function (x , ... ) {
255
270
res <- NextMethod()
256
271
if (x $ engine == " glmnet" ) {
257
272
res $ call_info [res $ name == " mixture" ] <-
258
273
list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
259
274
} else if (x $ engine == " brulee" ) {
260
- res <- add_engine_parameters(res , brulee_logistic_engine_args )
275
+ res <-
276
+ brulee_mlp_args %> %
277
+ dplyr :: anti_join(brulee_mlp_only_args , by = " name" ) %> %
278
+ dplyr :: mutate(
279
+ component = " logistic_reg" ,
280
+ component_id = ifelse(name %in% names(formals(" logistic_reg" )), " main" , " engine" )
281
+ ) %> %
282
+ dplyr :: select(name , call_info , source , component , component_id )
261
283
}
262
284
res
263
285
}
264
286
265
287
# ' @export
266
- tunable.multinomial_reg <- function (x , ... ) {
288
+ tunable.multinom_reg <- function (x , ... ) {
267
289
res <- NextMethod()
268
290
if (x $ engine == " glmnet" ) {
269
291
res $ call_info [res $ name == " mixture" ] <-
270
292
list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
271
293
} else if (x $ engine == " brulee" ) {
272
- res <- add_engine_parameters(res , brulee_multinomial_engine_args )
294
+ res <-
295
+ brulee_mlp_args %> %
296
+ dplyr :: anti_join(brulee_mlp_only_args , by = " name" ) %> %
297
+ dplyr :: mutate(
298
+ component = " multinom_reg" ,
299
+ component_id = ifelse(name %in% names(formals(" multinom_reg" )), " main" , " engine" )
300
+ ) %> %
301
+ dplyr :: select(name , call_info , source , component , component_id )
273
302
}
274
303
res
275
304
}
@@ -345,28 +374,23 @@ tunable.svm_poly <- function(x, ...) {
345
374
res
346
375
}
347
376
348
-
349
377
# ' @export
350
378
tunable.mlp <- function (x , ... ) {
351
379
res <- NextMethod()
352
- if (x $ engine == " brulee" ) {
353
- res <- add_engine_parameters(res , brulee_mlp_engine_args )
354
- res $ call_info [res $ name == " learn_rate" ] <-
355
- list (list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 2 )))
356
- res $ call_info [res $ name == " epochs" ] <-
357
- list (list (pkg = " dials" , fun = " epochs" , range = c(5L , 500L )))
358
- activation_values <- rlang :: eval_tidy(
359
- rlang :: call2(" brulee_activations" , .ns = " brulee" )
360
- )
361
- res $ call_info [res $ name == " activation" ] <-
362
- list (list (pkg = " dials" , fun = " activation" , values = activation_values ))
363
- } else if (x $ engine == " keras" ) {
364
- activation_values <- parsnip :: keras_activations()
365
- res $ call_info [res $ name == " activation" ] <-
366
- list (list (pkg = " dials" , fun = " activation" , values = activation_values ))
380
+ if (grepl(" brulee" , x $ engine )) {
381
+ res <-
382
+ brulee_mlp_args %> %
383
+ dplyr :: mutate(
384
+ component = " mlp" ,
385
+ component_id = ifelse(name %in% names(formals(" mlp" )), " main" , " engine" )
386
+ ) %> %
387
+ dplyr :: select(name , call_info , source , component , component_id )
388
+ if (x $ engine == " brulee" ) {
389
+ res <- res [! grepl(" _2" , res $ name ),]
390
+ }
367
391
}
368
392
res
369
- }
393
+ }
370
394
371
395
# ' @export
372
396
tunable.survival_reg <- function (x , ... ) {
0 commit comments