From 660522f042ba91e73e0080351606ce165d2ba203 Mon Sep 17 00:00:00 2001 From: orgadish <48453207+orgadish@users.noreply.github.com> Date: Mon, 9 Oct 2023 00:13:16 -0700 Subject: [PATCH] Revert changes to `vec_group_id` and move implementation to a new `vec_group_id_and_loc`. Added tests for `vec_deduplicate` --- R/deduplicate.R | 17 +++++-- R/group.R | 4 +- man/vec_deduplicate.Rd | 4 +- man/vec_group.Rd | 4 +- src/group.c | 68 ++++++++++++++++++++++------ src/init.c | 2 + tests/testthat/test-deduplicate.R | 73 +++++++++++++++++++++++++++++++ 7 files changed, 147 insertions(+), 25 deletions(-) create mode 100644 tests/testthat/test-deduplicate.R diff --git a/R/deduplicate.R b/R/deduplicate.R index 77a7cca71..0cd61e1da 100644 --- a/R/deduplicate.R +++ b/R/deduplicate.R @@ -1,6 +1,9 @@ #' Modify a function to act on a deduplicated vector input #' #' @description +#' +#' `r lifecycle::badge("experimental")` +#' #' The deduplicated function acts on the unique values in the first input `x` #' and expands the output back to return. The return value is equivalent to `f(x)` #' but is significantly faster for inputs with significant duplication. @@ -14,7 +17,7 @@ #' x <- sample(LETTERS, 10) #' x #' -#' large_x <- sample(rep(x, 100)) +#' large_x <- sample(rep(x, 10)) #' length(large_x) #' #' long_func <- function(x) for(i in x) {Sys.sleep(0.001)} @@ -24,8 +27,14 @@ #' all(y == y2) vec_deduplicate <- function(f) { function(x, ...) { - x_gi <- vec_group_id(x) - x_unq <- vec_slice(x, attr(x_gi, "unique_loc")) - f(x_unq, ...)[x_gi] + res <- vec_group_id_and_loc(x) + group_id <- unclass(res) + unique_loc <- attr(res, "unique_loc") + unique_x <- vec_slice(x, unique_loc) + f(unique_x, ...)[group_id] } } + +vec_group_id_and_loc <- function(x) { + .Call(vctrs_group_id_and_loc, x) +} diff --git a/R/group.R b/R/group.R index 856250941..579cc20d0 100644 --- a/R/group.R +++ b/R/group.R @@ -6,9 +6,7 @@ #' #' * `vec_group_id()` returns an identifier for the group that each element of #' `x` falls in, constructed in the order that they appear. The number of -#' groups is also returned as an attribute, `n`. The locations of unique values -#' (as would be returned by `vec_unique_loc`) is also returned as an attribute, -#' `unique_loc`. +#' groups is also returned as an attribute, `n`. #' #' * `vec_group_loc()` returns a data frame containing a `key` column with the #' unique groups, and a `loc` column with the locations of each group in `x`. diff --git a/man/vec_deduplicate.Rd b/man/vec_deduplicate.Rd index a7df58778..d0cc83cdc 100644 --- a/man/vec_deduplicate.Rd +++ b/man/vec_deduplicate.Rd @@ -13,6 +13,8 @@ vec_deduplicate(f) A deduplicated function } \description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} + The deduplicated function acts on the unique values in the first input \code{x} and expands the output back to return. The return value is equivalent to \code{f(x)} but is significantly faster for inputs with significant duplication. @@ -21,7 +23,7 @@ but is significantly faster for inputs with significant duplication. x <- sample(LETTERS, 10) x -large_x <- sample(rep(x, 100)) +large_x <- sample(rep(x, 10)) length(large_x) long_func <- function(x) for(i in x) {Sys.sleep(0.001)} diff --git a/man/vec_group.Rd b/man/vec_group.Rd index 51e4bd004..26067271e 100644 --- a/man/vec_group.Rd +++ b/man/vec_group.Rd @@ -38,9 +38,7 @@ into a tibble to better understand the output. \itemize{ \item \code{vec_group_id()} returns an identifier for the group that each element of \code{x} falls in, constructed in the order that they appear. The number of -groups is also returned as an attribute, \code{n}. The locations of unique values -(as would be returned by \code{vec_unique_loc}) is also returned as an attribute, -\code{unique_loc}. +groups is also returned as an attribute, \code{n}. \item \code{vec_group_loc()} returns a data frame containing a \code{key} column with the unique groups, and a \code{loc} column with the locations of each group in \code{x}. \item \code{vec_group_rle()} locates groups in \code{x} and returns them run length diff --git a/src/group.c b/src/group.c index c9f179048..ebc09f7cf 100644 --- a/src/group.c +++ b/src/group.c @@ -16,10 +16,7 @@ SEXP vctrs_group_id(SEXP x) { SEXP out = PROTECT_N(Rf_allocVector(INTSXP, n), &nprot); int* p_out = INTEGER(out); - R_len_t g_id = 1; - - struct growable g_unq = new_growable(INTSXP, 256); - PROTECT_GROWABLE(&g_unq, &nprot); + R_len_t g = 1; for (int i = 0; i < n; ++i) { uint32_t hash = dict_hash_scalar(d, i); @@ -27,13 +24,8 @@ SEXP vctrs_group_id(SEXP x) { if (key == DICT_EMPTY) { dict_put(d, hash, i); - // Record group id - p_out[i] = g_id; - ++g_id; - - // Record unique value - growable_push_int(&g_unq, i + 1); - + p_out[i] = g; + ++g; } else { p_out[i] = p_out[key]; } @@ -42,9 +34,6 @@ SEXP vctrs_group_id(SEXP x) { SEXP n_groups = PROTECT_N(Rf_ScalarInteger(d->used), &nprot); Rf_setAttrib(out, syms_n, n_groups); - SEXP unq_vals = growable_values(&g_unq); - Rf_setAttrib(out, Rf_install("unique_loc"), unq_vals); - UNPROTECT(nprot); return out; } @@ -244,3 +233,54 @@ SEXP vec_group_loc(SEXP x) { UNPROTECT(nprot); return out; } + + +// ----------------------------------------------------------------------------- + +// [[ register() ]] +SEXP vctrs_group_id_and_loc(SEXP x) { + int nprot = 0; + + R_len_t n = vec_size(x); + + x = PROTECT_N(vec_proxy_equal(x), &nprot); + x = PROTECT_N(vec_normalize_encoding(x), &nprot); + + struct dictionary* d = new_dictionary(x); + PROTECT_DICT(d, &nprot); + + SEXP out = PROTECT_N(Rf_allocVector(INTSXP, n), &nprot); + int* p_out = INTEGER(out); + + R_len_t g_id = 1; + + struct growable g_unq = new_growable(INTSXP, 256); + PROTECT_GROWABLE(&g_unq, &nprot); + + for (int i = 0; i < n; ++i) { + uint32_t hash = dict_hash_scalar(d, i); + R_len_t key = d->key[hash]; + + if (key == DICT_EMPTY) { + dict_put(d, hash, i); + // Record group id + p_out[i] = g_id; + ++g_id; + + // Record unique locs + growable_push_int(&g_unq, i + 1); + + } else { + p_out[i] = p_out[key]; + } + } + + SEXP n_groups = PROTECT_N(Rf_ScalarInteger(d->used), &nprot); + Rf_setAttrib(out, syms_n, n_groups); + + SEXP unq_vals = growable_values(&g_unq); + Rf_setAttrib(out, Rf_install("unique_loc"), unq_vals); + + UNPROTECT(nprot); + return out; +} diff --git a/src/init.c b/src/init.c index 9b338e3df..90495be30 100644 --- a/src/init.c +++ b/src/init.c @@ -27,6 +27,7 @@ extern SEXP vec_split(SEXP, SEXP); extern SEXP vctrs_group_id(SEXP); extern SEXP vctrs_group_rle(SEXP); extern SEXP vec_group_loc(SEXP); +extern SEXP vctrs_group_id_and_loc(SEXP); extern SEXP vctrs_equal(SEXP, SEXP, SEXP); extern r_obj* ffi_vec_detect_missing(r_obj*); extern r_obj* ffi_vec_any_missing(r_obj* x); @@ -211,6 +212,7 @@ static const R_CallMethodDef CallEntries[] = { {"vctrs_group_id", (DL_FUNC) &vctrs_group_id, 1}, {"vctrs_group_rle", (DL_FUNC) &vctrs_group_rle, 1}, {"vctrs_group_loc", (DL_FUNC) &vec_group_loc, 1}, + {"vctrs_group_id_and_loc", (DL_FUNC) &vctrs_group_id_and_loc, 1}, {"ffi_size", (DL_FUNC) &ffi_size, 2}, {"ffi_list_sizes", (DL_FUNC) &ffi_list_sizes, 2}, {"vctrs_dim", (DL_FUNC) &vctrs_dim, 1}, diff --git a/tests/testthat/test-deduplicate.R b/tests/testthat/test-deduplicate.R new file mode 100644 index 000000000..74dcbde13 --- /dev/null +++ b/tests/testthat/test-deduplicate.R @@ -0,0 +1,73 @@ +# group_id_and_loc ---------------------------------------------------------------- + +expect_matches_separate_calls <- function(x) { + expect_equal( + as.numeric(vec_group_id_and_loc(x)), + as.numeric(vec_group_id(x)) + ) + expect_equal( + attr(vec_group_id_and_loc(x), "unique_loc"), + vec_unique_loc(x) + ) +} + +test_that("vec_group_id_and_loc matches vec_group_id and vec_unique_loc", { + x <- c(2, 4, 2, 1, 4) + expect_matches_separate_calls(x) +}) + +test_that("vec_group_id_and_loc works for size 0 input", { + expect <- structure(integer(), n = 0L, unique_loc=integer()) + expect_equal(vec_group_id_and_loc(NULL), expect) + expect_equal(vec_group_id_and_loc(numeric()), expect) +}) + +test_that("vec_group_id_and_loc works on base S3 objects", { + x <- factor(c("x", "y", "x")) + expect_matches_separate_calls(x) + + x <- new_date(c(0, 1, 0)) + expect_matches_separate_calls(x) +}) + +test_that("vec_group_id_and_loc works on data frames", { + df <- data.frame(x = c(1, 2, 1, 1), y = c(2, 3, 2, 3)) + expect_matches_separate_calls(df) +}) + +test_that("vec_group_id_and_loc works on arrays", { + x <- array(c(1, 1, 1, 2, 4, 2), c(3, 2)) + expect_matches_separate_calls(x) +}) + +test_that("vec_group_id takes the equality proxy", { + local_comparable_tuple() + x <- tuple(c(1, 2, 1, 1), c(1, 1, 1, 2)) + # Compares on only the first field + expect_matches_separate_calls(x) +}) + +test_that("vec_group_id takes the equality proxy recursively", { + local_comparable_tuple() + + x <- tuple(c(1, 2, 1, 1), 1:4) + df <- data_frame(x = x) + expect_matches_separate_calls(df) +}) + + +# vec_deduplicate --------------------------------------------------------- + +test_that("vec_deduplicate(f) runs only on deduplicated values", { + ncalls <<- 0 + f <- function(ii) for(i in ii) ncalls <<- ncalls + 1 + + x <- c(1, 1, 1, 2, 3) + vec_deduplicate(f)(x) + expect_equal(ncalls, 3) + + ncalls <<- 0 + x <- 1:5 + vec_deduplicate(f)(x) + expect_equal(ncalls, 5) +})