Skip to content

Commit

Permalink
Parallelized map using mirai (#1163)
Browse files Browse the repository at this point in the history
  • Loading branch information
shikokuchuo authored Feb 6, 2025
1 parent baa4cd4 commit 60045a5
Show file tree
Hide file tree
Showing 16 changed files with 1,039 additions and 70 deletions.
5 changes: 4 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ Imports:
magrittr (>= 1.5.0),
rlang (>= 1.1.1),
vctrs (>= 0.6.3)
Suggests:
Suggests:
carrier (>= 0.1.1),
covr,
dplyr (>= 0.7.8),
httr,
knitr,
lubridate,
mirai (>= 2.0.1.9005),
rmarkdown,
testthat (>= 3.0.0),
tibble,
Expand All @@ -43,3 +45,4 @@ Config/testthat/parallel: TRUE
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
Remotes: shikokuchuo/mirai
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# purrr (development version)

* purrr gains the capacity for parallel and distributed map, powered by the
mirai package. The argument `.parallel` has been added to `map()`, `map2()`,
`pmap()` and variants to enable this. See `?parallelization` for more details
(@shikokuchuo, #1163).

# purrr 1.0.4

# purrr 1.0.3
Expand Down
4 changes: 2 additions & 2 deletions R/map-if-at.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ map_if <- function(.x, .p, .f, ..., .else = NULL) {
#' installed, you can use `vars()` and tidyselect helpers to select
#' elements.
#' @export
map_at <- function(.x, .at, .f, ..., .progress = FALSE) {
map_at <- function(.x, .at, .f, ..., .parallel = FALSE, .progress = FALSE) {
where <- where_at(.x, .at, user_env = caller_env())

out <- vector("list", length(.x))
out[where] <- map(.x[where], .f, ..., .progress = .progress)
out[where] <- map(.x[where], .f, ..., .parallel = .parallel, .progress = .progress)
out[!where] <- .x[!where]

set_names(out, names(.x))
Expand Down
129 changes: 112 additions & 17 deletions R/map.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@
#' This makes it easier to understand which arguments belong to which
#' function and will tend to yield better error messages.
#'
#' @param .parallel `r lifecycle::badge("experimental")` Whether to map in
#' parallel. Use `TRUE` to parallelize using the \CRANpkg{mirai} package.
#' * Set up parallelization in your session beforehand using
#' [mirai::daemons()].
#' * Non-package functions are auto-crated for sharing with parallel
#' processes. You may [carrier::crate()] your function explicitly if you need
#' to supply additional objects along with your function.
#' * Use of `...` is not permitted in this context, [carrier::crate()] an
#' anonymous function instead.
#'
#' See [parallelization] for more details.
#' @param .progress Whether to show a progress bar. Use `TRUE` to turn on
#' a basic progress bar, use a string to give it a name, or see
#' [progress_bars] for more details.
Expand Down Expand Up @@ -125,50 +136,69 @@
#' map(\(df) lm(mpg ~ wt, data = df)) |>
#' map(summary) |>
#' map_dbl("r.squared")
map <- function(.x, .f, ..., .progress = FALSE) {
map_("list", .x, .f, ..., .progress = .progress)
#'
#' @examplesIf interactive() && requireNamespace("mirai", quietly = TRUE) && requireNamespace("carrier", quietly = TRUE)
#' # Run in interactive sessions only as spawns additional processes
#'
#' # To use parallelized map, set daemons (number of parallel processes) first:
#' mirai::daemons(2)
#'
#' mtcars |> map_dbl(sum, .parallel = TRUE)
#'
#' 1:10 |>
#' map(function(x) stats::rnorm(10, mean = x), .parallel = TRUE) |>
#' map_dbl(mean, .parallel = TRUE)
#'
#' mirai::daemons(0)
#'
map <- function(.x, .f, ..., .parallel = FALSE, .progress = FALSE) {
map_("list", .x, .f, ..., .parallel = .parallel, .progress = .progress)
}

#' @rdname map
#' @export
map_lgl <- function(.x, .f, ..., .progress = FALSE) {
map_("logical", .x, .f, ..., .progress = .progress)
map_lgl <- function(.x, .f, ..., .parallel = FALSE, .progress = FALSE) {
map_("logical", .x, .f, ..., .parallel = .parallel, .progress = .progress)
}

#' @rdname map
#' @export
map_int <- function(.x, .f, ..., .progress = FALSE) {
map_("integer", .x, .f, ..., .progress = .progress)
map_int <- function(.x, .f, ..., .parallel = FALSE, .progress = FALSE) {
map_("integer", .x, .f, ..., .parallel = .parallel, .progress = .progress)
}

#' @rdname map
#' @export
map_dbl <- function(.x, .f, ..., .progress = FALSE) {
map_("double", .x, .f, ..., .progress = .progress)
map_dbl <- function(.x, .f, ..., .parallel = FALSE, .progress = FALSE) {
map_("double", .x, .f, ..., .parallel = .parallel, .progress = .progress)
}

#' @rdname map
#' @export
map_chr <- function(.x, .f, ..., .progress = FALSE) {
map_chr <- function(.x, .f, ..., .parallel = FALSE, .progress = FALSE) {
local_deprecation_user_env()
map_("character", .x, .f, ..., .progress = .progress)
map_("character", .x, .f, ..., .parallel = .parallel, .progress = .progress)
}

map_ <- function(.type,
.x,
.f,
...,
.parallel = FALSE,
.progress = FALSE,
.purrr_user_env = caller_env(2),
.purrr_error_call = caller_env()) {
.x <- vctrs_vec_compat(.x, .purrr_user_env)
vec_assert(.x, arg = ".x", call = .purrr_error_call)

n <- vec_size(.x)
.f <- as_mapper(.f, ...)

names <- vec_names(.x)
if (isTRUE(.parallel)) {
return(mmap_(.x, .f, .progress, .type, .purrr_error_call, ...))
}

.f <- as_mapper(.f, ...)
n <- vec_size(.x)
names <- vec_names(.x)

i <- 0L
with_indexed_errors(
Expand All @@ -179,21 +209,62 @@ map_ <- function(.type,
)
}

mmap_ <- function(.x, .f, .progress, .type, error_call, ...) {

if (is.null(the$packages_installed)) {
rlang::check_installed(c("mirai", "carrier"), reason = "for parallel map.")
the$packages_installed <- TRUE
}

if (is.null(mirai::nextget("n"))) {
cli::cli_abort(
"No daemons set - use e.g. {.run mirai::daemons(6)} to set 6 local daemons.",
call = error_call
)
}
if (...length()) {
cli::cli_abort(
"Don't use `...` with `.parallel = TRUE`.",
call = error_call
)
}

if (!isNamespace(topenv(environment(.f))) && !carrier::is_crate(.f)) {
.f <- carrier::crate(rlang::set_env(.f))
cli::cli_inform(c(
v = "Automatically crated `.f`: {format(lobstr::obj_size(.f))}"
))
}

m <- mirai::mirai_map(.x, .f)

options <- c(".stop", if (isTRUE(.progress)) ".progress")
x <- with_parallel_indexed_errors(
mirai::collect_mirai(m, options = options),
interrupt_expr = mirai::stop_mirai(m),
error_call = error_call
)
if (.type != "list") {
x <- simplify_impl(x, ptype = vector(mode = .type), error_call = error_call)
}
x

}

#' @rdname map
#' @param .ptype If `NULL`, the default, the output type is the common type
#' of the elements of the result. Otherwise, supply a "prototype" giving
#' the desired type of output.
#' @export
map_vec <- function(.x, .f, ..., .ptype = NULL, .progress = FALSE) {
out <- map(.x, .f, ..., .progress = .progress)
map_vec <- function(.x, .f, ..., .ptype = NULL, .parallel = FALSE, .progress = FALSE) {
out <- map(.x, .f, ..., .parallel = .parallel, .progress = .progress)
simplify_impl(out, ptype = .ptype)
}

#' @rdname map
#' @export
walk <- function(.x, .f, ..., .progress = FALSE) {
map(.x, .f, ..., .progress = .progress)
walk <- function(.x, .f, ..., .parallel = FALSE, .progress = FALSE) {
map(.x, .f, ..., .parallel = .parallel, .progress = .progress)
invisible(.x)
}

Expand Down Expand Up @@ -225,6 +296,30 @@ with_indexed_errors <- function(expr, i, names = NULL, error_call = caller_env()
)
}

with_parallel_indexed_errors <- function(expr, interrupt_expr = NULL, error_call = caller_env()) {
withCallingHandlers(
expr,
error = function(cnd) {
location <- cnd$location
iname <- cnd$name
cli::cli_abort(
c(
i = "In index: {location}.",
i = if (length(iname) && nzchar(iname)) "With name: {iname}."
),
location = location,
name = iname,
parent = cnd$parent,
call = error_call,
class = "purrr_error_indexed"
)
},
interrupt = function(cnd) {
interrupt_expr
}
)
}

#' Indexed errors (`purrr_error_indexed`)
#'
#' @description
Expand Down
37 changes: 23 additions & 14 deletions R/map2.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,35 +30,36 @@
#' by_cyl <- mtcars |> split(mtcars$cyl)
#' mods <- by_cyl |> map(\(df) lm(mpg ~ wt, data = df))
#' map2(mods, by_cyl, predict)
map2 <- function(.x, .y, .f, ..., .progress = FALSE) {
map2_("list", .x, .y, .f, ..., .progress = .progress)
map2 <- function(.x, .y, .f, ..., .parallel = FALSE, .progress = FALSE) {
map2_("list", .x, .y, .f, ..., .parallel = .parallel, .progress = .progress)
}
#' @export
#' @rdname map2
map2_lgl <- function(.x, .y, .f, ..., .progress = FALSE) {
map2_("logical", .x, .y, .f, ..., .progress = .progress)
map2_lgl <- function(.x, .y, .f, ..., .parallel = FALSE, .progress = FALSE) {
map2_("logical", .x, .y, .f, ..., .parallel = .parallel, .progress = .progress)
}
#' @export
#' @rdname map2
map2_int <- function(.x, .y, .f, ..., .progress = FALSE) {
map2_("integer", .x, .y, .f, ..., .progress = .progress)
map2_int <- function(.x, .y, .f, ..., .parallel = FALSE, .progress = FALSE) {
map2_("integer", .x, .y, .f, ..., .parallel = .parallel, .progress = .progress)
}
#' @export
#' @rdname map2
map2_dbl <- function(.x, .y, .f, ..., .progress = FALSE) {
map2_("double", .x, .y, .f, ..., .progress = .progress)
map2_dbl <- function(.x, .y, .f, ..., .parallel = FALSE, .progress = FALSE) {
map2_("double", .x, .y, .f, ..., .parallel = .parallel, .progress = .progress)
}
#' @export
#' @rdname map2
map2_chr <- function(.x, .y, .f, ..., .progress = FALSE) {
map2_("character", .x, .y, .f, ..., .progress = .progress)
map2_chr <- function(.x, .y, .f, ..., .parallel = FALSE, .progress = FALSE) {
map2_("character", .x, .y, .f, ..., .parallel = .parallel, .progress = .progress)
}

map2_ <- function(.type,
.x,
.y,
.f,
...,
.parallel = FALSE,
.progress = FALSE,
.purrr_user_env = caller_env(2),
.purrr_error_call = caller_env()) {
Expand All @@ -74,6 +75,14 @@ map2_ <- function(.type,

.f <- as_mapper(.f, ...)

if (isTRUE(.parallel)) {
attributes(args) <- list(
class = "data.frame",
row.names = if (is.null(names)) .set_row_names(n) else names
)
return(mmap_(args, .f, .progress, .type, .purrr_error_call, ...))
}

i <- 0L
with_indexed_errors(
i = i,
Expand All @@ -85,14 +94,14 @@ map2_ <- function(.type,

#' @rdname map2
#' @export
map2_vec <- function(.x, .y, .f, ..., .ptype = NULL, .progress = FALSE) {
out <- map2(.x, .y, .f, ..., .progress = .progress)
map2_vec <- function(.x, .y, .f, ..., .ptype = NULL, .parallel = FALSE, .progress = FALSE) {
out <- map2(.x, .y, .f, ..., .parallel = .parallel, .progress = .progress)
simplify_impl(out, ptype = .ptype)
}

#' @export
#' @rdname map2
walk2 <- function(.x, .y, .f, ..., .progress = FALSE) {
map2(.x, .y, .f, ..., .progress = .progress)
walk2 <- function(.x, .y, .f, ..., .parallel = FALSE, .progress = FALSE) {
map2(.x, .y, .f, ..., .parallel = .parallel, .progress = .progress)
invisible(.x)
}
Loading

0 comments on commit 60045a5

Please sign in to comment.