diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index a1167433c932..5a596dffe3cd 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -38,6 +38,7 @@ supported_dplyr_methods <- list( select = NULL, filter = NULL, + filter_out = NULL, collect = NULL, summarise = c( "window functions not currently supported;", diff --git a/r/R/dplyr-filter.R b/r/R/dplyr-filter.R index 18f5c929affb..26fa1bf7d5f2 100644 --- a/r/R/dplyr-filter.R +++ b/r/R/dplyr-filter.R @@ -17,27 +17,61 @@ # The following S3 methods are registered on load if dplyr is present -filter.arrow_dplyr_query <- function(.data, ..., .by = NULL, .preserve = FALSE) { - try_arrow_dplyr({ - # TODO something with the .preserve argument - out <- as_adq(.data) +apply_filter_impl <- function( + .data, + ..., + .by = NULL, + .preserve = FALSE, + negate = FALSE +) { + # TODO something with the .preserve argument + out <- as_adq(.data) - by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data") + by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data") - if (by$from_by) { - out$group_by_vars <- by$names - } + if (by$from_by) { + out$group_by_vars <- by$names + } + + expanded_filters <- expand_across(out, quos(...)) + if (length(expanded_filters) == 0) { + # Nothing to do + return(as_adq(.data)) + } + + # tidy-eval the filter expressions inside an Arrow data_mask + mask <- arrow_mask(out) + + if (isTRUE(negate)) { + # filter_out(): combine all predicates with &, then negate + combined <- NULL + + for (expr in expanded_filters) { + filt <- arrow_eval(expr, mask) - expanded_filters <- expand_across(out, quos(...)) - if (length(expanded_filters) == 0) { - # Nothing to do - return(as_adq(.data)) + if (length(mask$.aggregations)) { + # dplyr lets you filter on e.g. x < mean(x), but we haven't implemented it. + # But we could, the same way it works in mutate() via join, if someone asks. + # Until then, just error. + arrow_not_supported( + .actual_msg = "Expression not supported in filter_out() in Arrow", + call = expr + ) + } + + if (is_list_of(filt, "Expression")) { + filt <- Reduce("&", filt) + } + + combined <- if (is.null(combined)) filt else (combined & filt) } - # tidy-eval the filter expressions inside an Arrow data_mask - mask <- arrow_mask(out) + out <- set_filters(out, combined, negate = TRUE) + } else { + # filter(): apply each predicate sequentially for (expr in expanded_filters) { filt <- arrow_eval(expr, mask) + if (length(mask$.aggregations)) { # dplyr lets you filter on e.g. x < mean(x), but we haven't implemented it. # But we could, the same way it works in mutate() via join, if someone asks. @@ -47,19 +81,55 @@ filter.arrow_dplyr_query <- function(.data, ..., .by = NULL, .preserve = FALSE) call = expr ) } - out <- set_filters(out, filt) - } - if (by$from_by) { - out$group_by_vars <- character() + out <- set_filters(out, filt, negate = FALSE) } + } + + if (by$from_by) { + out$group_by_vars <- character() + } - out + out +} + +filter.arrow_dplyr_query <- function( + .data, + ..., + .by = NULL, + .preserve = FALSE +) { + try_arrow_dplyr({ + apply_filter_impl( + .data, + ..., + .by = {{ .by }}, + .preserve = .preserve, + negate = FALSE + ) }) } filter.Dataset <- filter.ArrowTabular <- filter.RecordBatchReader <- filter.arrow_dplyr_query -set_filters <- function(.data, expressions) { +filter_out.arrow_dplyr_query <- function( + .data, + ..., + .by = NULL, + .preserve = FALSE +) { + try_arrow_dplyr({ + apply_filter_impl( + .data, + ..., + .by = {{ .by }}, + .preserve = .preserve, + negate = TRUE + ) + }) +} +filter_out.Dataset <- filter_out.ArrowTabular <- filter_out.RecordBatchReader <- filter_out.arrow_dplyr_query + +set_filters <- function(.data, expressions, negate = FALSE) { if (length(expressions)) { if (is_list_of(expressions, "Expression")) { # expressions is a list of Expressions. AND them together and set them on .data @@ -67,7 +137,16 @@ set_filters <- function(.data, expressions) { } else if (inherits(expressions, "Expression")) { new_filter <- expressions } else { - stop("filter expressions must be either an expression or a list of expressions", call. = FALSE) + stop( + "filter expressions must be either an expression or a list of expressions", + call. = FALSE + ) + } + + if (isTRUE(negate)) { + # dplyr::filter_out() semantics: drop rows where predicate is TRUE; + # keep rows where predicate is FALSE or NA. + new_filter <- (!new_filter) | is.na(new_filter) } if (isTRUE(.data$filtered_rows)) { diff --git a/r/R/dplyr-funcs-doc.R b/r/R/dplyr-funcs-doc.R index bbd1c91a0213..9293d14c94c0 100644 --- a/r/R/dplyr-funcs-doc.R +++ b/r/R/dplyr-funcs-doc.R @@ -19,7 +19,7 @@ #' Functions available in Arrow dplyr queries #' -#' The `arrow` package contains methods for 37 `dplyr` table functions, many of +#' The `arrow` package contains methods for 38 `dplyr` table functions, many of #' which are "verbs" that do transformations to one or more tables. #' The package also has mappings of 224 R functions to the corresponding #' functions in the Arrow compute library. These allow you to write code inside @@ -45,6 +45,7 @@ #' * [`distinct()`][dplyr::distinct()]: `.keep_all = TRUE` returns a non-missing value if present, only returning missing values if all are missing. #' * [`explain()`][dplyr::explain()] #' * [`filter()`][dplyr::filter()] +#' * [`filter_out()`][dplyr::filter_out()] #' * [`full_join()`][dplyr::full_join()]: the `copy` argument is ignored #' * [`glimpse()`][dplyr::glimpse()] #' * [`group_by()`][dplyr::group_by()] diff --git a/r/man/acero.Rd b/r/man/acero.Rd index dcaca04d2f2c..ee156cc9129b 100644 --- a/r/man/acero.Rd +++ b/r/man/acero.Rd @@ -7,7 +7,7 @@ \alias{arrow-dplyr} \title{Functions available in Arrow dplyr queries} \description{ -The \code{arrow} package contains methods for 37 \code{dplyr} table functions, many of +The \code{arrow} package contains methods for 38 \code{dplyr} table functions, many of which are "verbs" that do transformations to one or more tables. The package also has mappings of 224 R functions to the corresponding functions in the Arrow compute library. These allow you to write code inside @@ -32,6 +32,7 @@ Table into an R \code{tibble}. \item \code{\link[dplyr:distinct]{distinct()}}: \code{.keep_all = TRUE} returns a non-missing value if present, only returning missing values if all are missing. \item \code{\link[dplyr:explain]{explain()}} \item \code{\link[dplyr:filter]{filter()}} +\item \code{\link[dplyr:filter]{filter_out()}} \item \code{\link[dplyr:mutate-joins]{full_join()}}: the \code{copy} argument is ignored \item \code{\link[dplyr:glimpse]{glimpse()}} \item \code{\link[dplyr:group_by]{group_by()}} @@ -198,7 +199,7 @@ Valid values are "s", "ms" (default), "us", "ns". \itemize{ \item \code{\link[dplyr:across]{across()}} \item \code{\link[dplyr:between]{between()}} -\item \code{\link[dplyr:case_when]{case_when()}}: \code{.ptype} and \code{.size} arguments not supported +\item \code{\link[dplyr:case-and-replace-when]{case_when()}}: \code{.ptype} and \code{.size} arguments not supported \item \code{\link[dplyr:coalesce]{coalesce()}} \item \code{\link[dplyr:desc]{desc()}} \item \code{\link[dplyr:across]{if_all()}} diff --git a/r/tests/testthat/test-dplyr-filter.R b/r/tests/testthat/test-dplyr-filter.R index d56e25fca329..3912e518ed08 100644 --- a/r/tests/testthat/test-dplyr-filter.R +++ b/r/tests/testthat/test-dplyr-filter.R @@ -498,3 +498,51 @@ test_that("filter() with aggregation expressions errors", { "not supported in filter" ) }) + +test_that("filter_out() basic", { + compare_dplyr_binding( + .input |> + filter_out(chr == "b") |> + select(chr, int, lgl) |> + collect(), + tbl + ) +}) + +test_that("filter_out() keeps NA values in predicate result", { + compare_dplyr_binding( + .input |> + filter_out(lgl) |> + select(chr, int, lgl) |> + collect(), + tbl + ) +}) + +test_that("filter_out() with multiple conditions", { + compare_dplyr_binding( + .input |> + filter_out(dbl > 2, chr %in% c("d", "f")) |> + collect(), + tbl + ) +}) + +test_that("More complex select/filter_out", { + compare_dplyr_binding( + .input |> + filter_out(dbl > 2, chr == "d" | chr == "f") |> + select(chr, int, lgl) |> + filter(int < 5) |> + select(int, chr) |> + collect(), + tbl + ) + + compare_dplyr_binding( + .input |> + filter_out(!is.na(int)) |> + collect(), + tbl + ) +})