Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions r/R/arrow-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
supported_dplyr_methods <- list(
select = NULL,
filter = NULL,
filter_out = NULL,
collect = NULL,
summarise = c(
"window functions not currently supported;",
Expand Down
121 changes: 100 additions & 21 deletions r/R/dplyr-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,61 @@

# The following S3 methods are registered on load if dplyr is present

filter.arrow_dplyr_query <- function(.data, ..., .by = NULL, .preserve = FALSE) {
try_arrow_dplyr({
# TODO something with the .preserve argument
out <- as_adq(.data)
apply_filter_impl <- function(
.data,
...,
.by = NULL,
.preserve = FALSE,
negate = FALSE
) {
# TODO something with the .preserve argument
out <- as_adq(.data)

by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data")
by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data")

if (by$from_by) {
out$group_by_vars <- by$names
}
if (by$from_by) {
out$group_by_vars <- by$names
}

expanded_filters <- expand_across(out, quos(...))
if (length(expanded_filters) == 0) {
# Nothing to do
return(as_adq(.data))
}

# tidy-eval the filter expressions inside an Arrow data_mask
mask <- arrow_mask(out)

if (isTRUE(negate)) {
# filter_out(): combine all predicates with &, then negate
combined <- NULL

for (expr in expanded_filters) {
filt <- arrow_eval(expr, mask)

expanded_filters <- expand_across(out, quos(...))
if (length(expanded_filters) == 0) {
# Nothing to do
return(as_adq(.data))
if (length(mask$.aggregations)) {
# dplyr lets you filter on e.g. x < mean(x), but we haven't implemented it.
# But we could, the same way it works in mutate() via join, if someone asks.
# Until then, just error.
arrow_not_supported(
.actual_msg = "Expression not supported in filter_out() in Arrow",
call = expr
)
}

if (is_list_of(filt, "Expression")) {
filt <- Reduce("&", filt)
}

combined <- if (is.null(combined)) filt else (combined & filt)
}

# tidy-eval the filter expressions inside an Arrow data_mask
mask <- arrow_mask(out)
out <- set_filters(out, combined, negate = TRUE)
} else {
# filter(): apply each predicate sequentially
for (expr in expanded_filters) {
filt <- arrow_eval(expr, mask)

if (length(mask$.aggregations)) {
# dplyr lets you filter on e.g. x < mean(x), but we haven't implemented it.
# But we could, the same way it works in mutate() via join, if someone asks.
Expand All @@ -47,27 +81,72 @@ filter.arrow_dplyr_query <- function(.data, ..., .by = NULL, .preserve = FALSE)
call = expr
)
}
out <- set_filters(out, filt)
}

if (by$from_by) {
out$group_by_vars <- character()
out <- set_filters(out, filt, negate = FALSE)
}
}

if (by$from_by) {
out$group_by_vars <- character()
}

out
out
}

# filter() method for arrow_dplyr_query objects.
#
# Thin wrapper: all of the work is done by apply_filter_impl(), called here
# with negate = FALSE so that each predicate in `...` is applied as a regular
# (keep-matching-rows) filter.
#
# Arguments:
#   .data     the object being filtered; forwarded unchanged.
#   ...       filter expressions; forwarded unchanged.
#   .by       optional per-call grouping; forwarded with the embrace operator.
#   .preserve forwarded unchanged (apply_filter_impl() carries a TODO for it).
filter.arrow_dplyr_query <- function(
.data,
...,
.by = NULL,
.preserve = FALSE
) {
# NOTE(review): try_arrow_dplyr() is defined elsewhere; presumably it handles
# the "not supported in Arrow" errors raised by the implementation -- confirm.
try_arrow_dplyr({
apply_filter_impl(
.data,
...,
# Embraced so the caller's unevaluated `.by` expression is passed along;
# apply_filter_impl() embraces it again when handing it to compute_by().
.by = {{ .by }},
.preserve = .preserve,
negate = FALSE
)
})
}
# The same function is reused as the filter() method for Dataset, ArrowTabular
# and RecordBatchReader inputs.
filter.Dataset <- filter.ArrowTabular <- filter.RecordBatchReader <- filter.arrow_dplyr_query

set_filters <- function(.data, expressions) {
# filter_out() method for arrow_dplyr_query objects.
#
# Thin wrapper: delegates to apply_filter_impl() with negate = TRUE. In that
# mode the implementation combines all predicates in `...` with `&` and then
# negates the result, so rows where the combined predicate is TRUE are dropped
# and rows where it is FALSE or NA are kept.
#
# Arguments:
#   .data     the object being filtered; forwarded unchanged.
#   ...       predicates describing the rows to drop; forwarded unchanged.
#   .by       optional per-call grouping; forwarded with the embrace operator.
#   .preserve forwarded unchanged (apply_filter_impl() carries a TODO for it).
filter_out.arrow_dplyr_query <- function(
.data,
...,
.by = NULL,
.preserve = FALSE
) {
# NOTE(review): try_arrow_dplyr() is defined elsewhere; presumably it handles
# the "not supported in Arrow" errors raised by the implementation -- confirm.
try_arrow_dplyr({
apply_filter_impl(
.data,
...,
# Embraced so the caller's unevaluated `.by` expression is passed along;
# apply_filter_impl() embraces it again when handing it to compute_by().
.by = {{ .by }},
.preserve = .preserve,
negate = TRUE
)
})
}
# The same function is reused as the filter_out() method for Dataset,
# ArrowTabular and RecordBatchReader inputs.
filter_out.Dataset <- filter_out.ArrowTabular <- filter_out.RecordBatchReader <- filter_out.arrow_dplyr_query

set_filters <- function(.data, expressions, negate = FALSE) {
if (length(expressions)) {
if (is_list_of(expressions, "Expression")) {
# expressions is a list of Expressions. AND them together and set them on .data
new_filter <- Reduce("&", expressions)
} else if (inherits(expressions, "Expression")) {
new_filter <- expressions
} else {
stop("filter expressions must be either an expression or a list of expressions", call. = FALSE)
stop(
"filter expressions must be either an expression or a list of expressions",
call. = FALSE
)
}

if (isTRUE(negate)) {
# dplyr::filter_out() semantics: drop rows where predicate is TRUE;
# keep rows where predicate is FALSE or NA.
new_filter <- (!new_filter) | is.na(new_filter)
}

if (isTRUE(.data$filtered_rows)) {
Expand Down
3 changes: 2 additions & 1 deletion r/R/dplyr-funcs-doc.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

#' Functions available in Arrow dplyr queries
#'
#' The `arrow` package contains methods for 37 `dplyr` table functions, many of
#' The `arrow` package contains methods for 38 `dplyr` table functions, many of
#' which are "verbs" that do transformations to one or more tables.
#' The package also has mappings of 224 R functions to the corresponding
#' functions in the Arrow compute library. These allow you to write code inside
Expand All @@ -45,6 +45,7 @@
#' * [`distinct()`][dplyr::distinct()]: `.keep_all = TRUE` returns a non-missing value if present, only returning missing values if all are missing.
#' * [`explain()`][dplyr::explain()]
#' * [`filter()`][dplyr::filter()]
#' * [`filter_out()`][dplyr::filter_out()]
#' * [`full_join()`][dplyr::full_join()]: the `copy` argument is ignored
#' * [`glimpse()`][dplyr::glimpse()]
#' * [`group_by()`][dplyr::group_by()]
Expand Down
5 changes: 3 additions & 2 deletions r/man/acero.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 48 additions & 0 deletions r/tests/testthat/test-dplyr-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -498,3 +498,51 @@ test_that("filter() with aggregation expressions errors", {
"not supported in filter"
)
})

# Smoke test: a single equality predicate passed to filter_out(), followed by
# a column selection.
# NOTE(review): compare_dplyr_binding() and `tbl` are defined elsewhere;
# presumably the helper runs the pipeline on both a dplyr and an Arrow version
# of `tbl` (substituted for `.input`) and compares the results -- confirm.
test_that("filter_out() basic", {
compare_dplyr_binding(
.input |>
# Drops the rows where chr == "b" is TRUE.
filter_out(chr == "b") |>
select(chr, int, lgl) |>
collect(),
tbl
)
})

# Uses the logical column `lgl` directly as the predicate. filter_out() drops
# rows where the predicate is TRUE and keeps rows where it is FALSE or NA, so
# rows with a missing `lgl` must survive.
# NOTE(review): assumes the `tbl` fixture (defined elsewhere) has NA values in
# `lgl`; otherwise this test does not exercise the NA path -- confirm.
test_that("filter_out() keeps NA values in predicate result", {
compare_dplyr_binding(
.input |>
filter_out(lgl) |>
select(chr, int, lgl) |>
collect(),
tbl
)
})

# Several predicates in one filter_out() call. The Arrow implementation
# combines them with `&` before negating, so a row is dropped only when every
# condition is TRUE for it; the result must match dplyr's.
test_that("filter_out() with multiple conditions", {
compare_dplyr_binding(
.input |>
filter_out(dbl > 2, chr %in% c("d", "f")) |>
collect(),
tbl
)
})

# filter_out() combined with other verbs, plus a negated is.na() predicate.
test_that("More complex select/filter_out", {
# Case 1: a multi-predicate filter_out() followed by select(), a regular
# filter() and a second select() that reorders the remaining columns.
compare_dplyr_binding(
.input |>
filter_out(dbl > 2, chr == "d" | chr == "f") |>
select(chr, int, lgl) |>
filter(int < 5) |>
select(int, chr) |>
collect(),
tbl
)

# Case 2: the predicate is TRUE for every non-missing `int`, so filter_out()
# drops those rows and only rows where `int` is NA remain.
compare_dplyr_binding(
.input |>
filter_out(!is.na(int)) |>
collect(),
tbl
)
})
Loading