Skip to content

Commit

Permalink
Revert changes to vec_group_id and move implementation to a new `ve…
Browse files Browse the repository at this point in the history
…c_group_id_and_loc`. Added tests for `vec_deduplicate`
  • Loading branch information
orgadish committed Oct 9, 2023
1 parent 4aec39d commit 660522f
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 25 deletions.
17 changes: 13 additions & 4 deletions R/deduplicate.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#' Modify a function to act on a deduplicated vector input
#'
#' @description
#'
#' `r lifecycle::badge("experimental")`
#'
#' The deduplicated function calls `f` on only the unique values of its first
#' input `x` and then expands the result back to the size of `x`. Provided `f`
#' acts element-wise on `x`, the return value is equivalent to `f(x)` but can be
#' significantly faster for inputs with significant duplication.
Expand All @@ -14,7 +17,7 @@
#' x <- sample(LETTERS, 10)
#' x
#'
#' large_x <- sample(rep(x, 100))
#' large_x <- sample(rep(x, 10))
#' length(large_x)
#'
#' long_func <- function(x) for(i in x) {Sys.sleep(0.001)}
Expand All @@ -24,8 +27,14 @@
#' all(y == y2)
vec_deduplicate <- function(f) {
  # Return a wrapper around `f` that evaluates `f` once per distinct value of
  # its first argument and then re-expands the result to the original size.
  # Any further arguments in `...` are forwarded to `f` unchanged.
  function(x, ...) {
    # One pass over `x` yields both the group id of every element and, as the
    # `unique_loc` attribute, the location of the first occurrence of each
    # group.
    res <- vec_group_id_and_loc(x)
    group_id <- unclass(res)
    unique_loc <- attr(res, "unique_loc")

    # Call `f` exactly once, on the deduplicated input only. Group ids are
    # assigned in order of first appearance, so they index directly into the
    # result computed on the unique values.
    unique_x <- vec_slice(x, unique_loc)
    f(unique_x, ...)[group_id]
  }
}

# Internal helper: thin R wrapper around the native routine
# `vctrs_group_id_and_loc`.
#
# The native routine returns an integer vector of group ids for `x`, carrying
# two attributes: `n` (the number of groups) and `unique_loc` (the 1-based
# location of the first occurrence of each group).
vec_group_id_and_loc <- function(x) {
.Call(vctrs_group_id_and_loc, x)
}
4 changes: 1 addition & 3 deletions R/group.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
#'
#' * `vec_group_id()` returns an identifier for the group that each element of
#' `x` falls in, constructed in the order that they appear. The number of
#' groups is also returned as an attribute, `n`. The locations of unique values
#' (as would be returned by `vec_unique_loc`) is also returned as an attribute,
#' `unique_loc`.
#' groups is also returned as an attribute, `n`.
#'
#' * `vec_group_loc()` returns a data frame containing a `key` column with the
#' unique groups, and a `loc` column with the locations of each group in `x`.
Expand Down
4 changes: 3 additions & 1 deletion man/vec_deduplicate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions man/vec_group.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

68 changes: 54 additions & 14 deletions src/group.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,16 @@ SEXP vctrs_group_id(SEXP x) {
SEXP out = PROTECT_N(Rf_allocVector(INTSXP, n), &nprot);
int* p_out = INTEGER(out);

R_len_t g_id = 1;

struct growable g_unq = new_growable(INTSXP, 256);
PROTECT_GROWABLE(&g_unq, &nprot);
R_len_t g = 1;

for (int i = 0; i < n; ++i) {
uint32_t hash = dict_hash_scalar(d, i);
R_len_t key = d->key[hash];

if (key == DICT_EMPTY) {
dict_put(d, hash, i);
// Record group id
p_out[i] = g_id;
++g_id;

// Record unique value
growable_push_int(&g_unq, i + 1);

p_out[i] = g;
++g;
} else {
p_out[i] = p_out[key];
}
Expand All @@ -42,9 +34,6 @@ SEXP vctrs_group_id(SEXP x) {
SEXP n_groups = PROTECT_N(Rf_ScalarInteger(d->used), &nprot);
Rf_setAttrib(out, syms_n, n_groups);

SEXP unq_vals = growable_values(&g_unq);
Rf_setAttrib(out, Rf_install("unique_loc"), unq_vals);

UNPROTECT(nprot);
return out;
}
Expand Down Expand Up @@ -244,3 +233,54 @@ SEXP vec_group_loc(SEXP x) {
UNPROTECT(nprot);
return out;
}


// -----------------------------------------------------------------------------

// [[ register() ]]
/*
 * Compute group ids and first-occurrence locations of `x` in a single pass.
 *
 * Returns an integer vector with one group id per element of `x`; ids start
 * at 1 and are handed out in order of first appearance. Two attributes are
 * attached to the result:
 *   - `n`:          the number of groups (taken from `d->used`)
 *   - `unique_loc`: the 1-based location of the first element of each group
 */
SEXP vctrs_group_id_and_loc(SEXP x) {
  int nprot = 0;

  R_len_t n = vec_size(x);

  x = PROTECT_N(vec_proxy_equal(x), &nprot);
  x = PROTECT_N(vec_normalize_encoding(x), &nprot);

  struct dictionary* d = new_dictionary(x);
  PROTECT_DICT(d, &nprot);

  SEXP out = PROTECT_N(Rf_allocVector(INTSXP, n), &nprot);
  int* p_out = INTEGER(out);

  // Next group id to hand out
  R_len_t g_id = 1;

  // Collects the locations of first occurrences as they are discovered
  struct growable g_unq = new_growable(INTSXP, 256);
  PROTECT_GROWABLE(&g_unq, &nprot);

  for (int i = 0; i < n; ++i) {
    uint32_t hash = dict_hash_scalar(d, i);
    R_len_t key = d->key[hash];

    if (key == DICT_EMPTY) {
      // First time this value is seen: register it and open a new group
      dict_put(d, hash, i);

      // Record group id
      p_out[i] = g_id;
      ++g_id;

      // Record unique locs (1-based for the R side)
      growable_push_int(&g_unq, i + 1);
    } else {
      // Value seen before: reuse the id stored at the location held in the
      // dictionary slot
      p_out[i] = p_out[key];
    }
  }

  SEXP n_groups = PROTECT_N(Rf_ScalarInteger(d->used), &nprot);
  Rf_setAttrib(out, syms_n, n_groups);

  // Protect the extracted values before calling `Rf_install()`: installing a
  // symbol can allocate, and `growable_values()` may return a freshly
  // allocated vector that nothing else keeps alive (its definition is not
  // visible here, so this is the conservative choice).
  SEXP unq_vals = PROTECT_N(growable_values(&g_unq), &nprot);
  Rf_setAttrib(out, Rf_install("unique_loc"), unq_vals);

  UNPROTECT(nprot);
  return out;
}
2 changes: 2 additions & 0 deletions src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ extern SEXP vec_split(SEXP, SEXP);
extern SEXP vctrs_group_id(SEXP);
extern SEXP vctrs_group_rle(SEXP);
extern SEXP vec_group_loc(SEXP);
extern SEXP vctrs_group_id_and_loc(SEXP);
extern SEXP vctrs_equal(SEXP, SEXP, SEXP);
extern r_obj* ffi_vec_detect_missing(r_obj*);
extern r_obj* ffi_vec_any_missing(r_obj* x);
Expand Down Expand Up @@ -211,6 +212,7 @@ static const R_CallMethodDef CallEntries[] = {
{"vctrs_group_id", (DL_FUNC) &vctrs_group_id, 1},
{"vctrs_group_rle", (DL_FUNC) &vctrs_group_rle, 1},
{"vctrs_group_loc", (DL_FUNC) &vec_group_loc, 1},
{"vctrs_group_id_and_loc", (DL_FUNC) &vctrs_group_id_and_loc, 1},
{"ffi_size", (DL_FUNC) &ffi_size, 2},
{"ffi_list_sizes", (DL_FUNC) &ffi_list_sizes, 2},
{"vctrs_dim", (DL_FUNC) &vctrs_dim, 1},
Expand Down
73 changes: 73 additions & 0 deletions tests/testthat/test-deduplicate.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# group_id_and_loc ----------------------------------------------------------------

# Expectation helper: the combined routine must agree with the two dedicated
# functions. Its group ids are compared against `vec_group_id(x)` and its
# `unique_loc` attribute against `vec_unique_loc(x)`.
expect_matches_separate_calls <- function(x) {
  combined <- vec_group_id_and_loc(x)

  # `as.numeric()` drops attributes so only the ids themselves are compared
  expect_equal(as.numeric(combined), as.numeric(vec_group_id(x)))
  expect_equal(attr(combined, "unique_loc"), vec_unique_loc(x))
}

test_that("vec_group_id_and_loc matches vec_group_id and vec_unique_loc", {
  expect_matches_separate_calls(c(2, 4, 2, 1, 4))
})

test_that("vec_group_id_and_loc works for size 0 input", {
  # Empty input: no ids, zero groups, no unique locations
  empty_result <- structure(integer(), n = 0L, unique_loc = integer())

  expect_equal(vec_group_id_and_loc(NULL), empty_result)
  expect_equal(vec_group_id_and_loc(numeric()), empty_result)
})

test_that("vec_group_id_and_loc works on base S3 objects", {
  expect_matches_separate_calls(factor(c("x", "y", "x")))
  expect_matches_separate_calls(new_date(c(0, 1, 0)))
})

test_that("vec_group_id_and_loc works on data frames", {
  # Rows 1 and 3 are duplicates of each other
  tbl <- data.frame(
    x = c(1, 2, 1, 1),
    y = c(2, 3, 2, 3)
  )
  expect_matches_separate_calls(tbl)
})

test_that("vec_group_id_and_loc works on arrays", {
  # 3 x 2 matrix whose first and third rows are identical
  mat <- array(c(1, 1, 1, 2, 4, 2), dim = c(3, 2))
  expect_matches_separate_calls(mat)
})

# The description names the function actually under test
# (`vec_group_id_and_loc`), matching the other tests in this section.
test_that("vec_group_id_and_loc takes the equality proxy", {
  local_comparable_tuple()
  x <- tuple(c(1, 2, 1, 1), c(1, 1, 1, 2))
  # Compares on only the first field
  expect_matches_separate_calls(x)
})

# The description names the function actually under test
# (`vec_group_id_and_loc`), matching the other tests in this section.
test_that("vec_group_id_and_loc takes the equality proxy recursively", {
  local_comparable_tuple()

  # The tuple is nested inside a data frame column
  x <- tuple(c(1, 2, 1, 1), 1:4)
  df <- data_frame(x = x)
  expect_matches_separate_calls(df)
})


# vec_deduplicate ---------------------------------------------------------

test_that("vec_deduplicate(f) runs only on deduplicated values", {
  # The counter lives in the test's own environment. `f` is defined here too,
  # so its `<<-` updates this binding rather than creating a global `ncalls`
  # that would leak into other tests.
  ncalls <- 0
  f <- function(ii) for (i in ii) ncalls <<- ncalls + 1

  # Three distinct values among five elements: `f` sees three elements
  x <- c(1, 1, 1, 2, 3)
  vec_deduplicate(f)(x)
  expect_equal(ncalls, 3)

  # No duplicates: `f` sees every element
  ncalls <- 0
  x <- 1:5
  vec_deduplicate(f)(x)
  expect_equal(ncalls, 5)
})

0 comments on commit 660522f

Please sign in to comment.