-
Notifications
You must be signed in to change notification settings - Fork 988
Add approx_distinct_count #20735
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
PointKernel
wants to merge
2
commits into
rapidsai:main
Choose a base branch
from
PointKernel:approx-distinct-count
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+467
−1
Draft
Add approx_distinct_count #20735
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| #include "stream_compaction_common.cuh" | ||
|
|
||
| #include <cudf/column/column_device_view.cuh> | ||
| #include <cudf/column/column_view.hpp> | ||
| #include <cudf/detail/iterator.cuh> | ||
| #include <cudf/detail/null_mask.hpp> | ||
| #include <cudf/detail/nvtx/ranges.hpp> | ||
| #include <cudf/detail/row_operator/hashing.cuh> | ||
| #include <cudf/detail/stream_compaction.hpp> | ||
| #include <cudf/stream_compaction.hpp> | ||
| #include <cudf/table/table_view.hpp> | ||
| #include <cudf/utilities/default_stream.hpp> | ||
| #include <cudf/utilities/type_checks.hpp> | ||
|
|
||
| #include <rmm/cuda_stream_view.hpp> | ||
| #include <rmm/exec_policy.hpp> | ||
| #include <rmm/mr/polymorphic_allocator.hpp> | ||
|
|
||
| #include <cuco/hyperloglog.cuh> | ||
| #include <thrust/copy.h> | ||
| #include <thrust/execution_policy.h> | ||
| #include <thrust/iterator/counting_iterator.h> | ||
| #include <thrust/transform.h> | ||
|
|
||
| #include <algorithm> | ||
|
|
||
| namespace cudf { | ||
| namespace detail { | ||
|
|
||
| // Internal implementation function | ||
| cudf::size_type approx_distinct_count_impl(table_view const& input, | ||
| int precision, | ||
| null_policy null_handling, | ||
| nan_policy nan_handling, | ||
| rmm::cuda_stream_view stream) | ||
| { | ||
| auto const num_rows = input.num_rows(); | ||
| if (num_rows == 0) { return 0; } | ||
|
|
||
| // Clamp precision to valid range for HyperLogLog | ||
| precision = std::max(4, std::min(18, precision)); | ||
|
|
||
| auto const has_nulls = nullate::DYNAMIC{cudf::has_nested_nulls(input)}; | ||
| auto const preprocessed_input = | ||
| cudf::detail::row::hash::preprocessed_table::create(input, stream); | ||
| auto const row_hasher = cudf::detail::row::hash::row_hasher(preprocessed_input); | ||
| auto const hash_key = row_hasher.device_hasher(has_nulls); | ||
|
|
||
| auto hll = cuco::hyperloglog<cudf::hash_value_type, | ||
| cuda::thread_scope_device, | ||
| cuco::xxhash_64<cudf::hash_value_type>, | ||
| rmm::mr::polymorphic_allocator<cuda::std::byte>>{ | ||
| cuco::sketch_size_kb{static_cast<double>(4 * (1ull << precision) / 1024.0)}, | ||
| cuco::xxhash_64<cudf::hash_value_type>{}, | ||
| rmm::mr::polymorphic_allocator<cuda::std::byte>{}, | ||
| cuda::stream_ref{stream.value()}}; | ||
|
|
||
| auto const iter = thrust::counting_iterator<cudf::size_type>(0); | ||
|
|
||
| rmm::device_uvector<cudf::hash_value_type> hash_values(num_rows, stream); | ||
| thrust::transform( | ||
| rmm::exec_policy_nosync(stream), iter, iter + num_rows, hash_values.begin(), hash_key); | ||
|
|
||
| // Create a temporary table for distinct processing if needed | ||
| if (nan_handling == nan_policy::NAN_IS_NULL || null_handling == null_policy::EXCLUDE) { | ||
| if (num_rows < 10000) { | ||
| if (input.num_columns() == 1) { | ||
| return cudf::distinct_count(input.column(0), null_handling, nan_handling); | ||
| } else { | ||
| return cudf::distinct_count(input, cudf::null_equality::EQUAL); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if (null_handling == null_policy::EXCLUDE && has_nulls) { | ||
| auto const [row_bitmask, null_count] = | ||
| cudf::detail::bitmask_or(input, stream, cudf::get_current_device_resource_ref()); | ||
|
|
||
| if (null_count > 0) { | ||
| row_validity pred{static_cast<bitmask_type const*>(row_bitmask.data())}; | ||
| auto counting_iter = thrust::counting_iterator<size_type>(0); | ||
|
|
||
| rmm::device_uvector<cudf::hash_value_type> filtered_hashes(num_rows - null_count, stream); | ||
| auto end_iter = thrust::copy_if(rmm::exec_policy(stream), | ||
| hash_values.begin(), | ||
| hash_values.end(), | ||
| counting_iter, | ||
| filtered_hashes.begin(), | ||
| pred); | ||
|
|
||
| auto actual_count = std::distance(filtered_hashes.begin(), end_iter); | ||
| if (actual_count > 0) { | ||
| hll.add(filtered_hashes.begin(), | ||
| filtered_hashes.begin() + actual_count, | ||
| cuda::stream_ref{stream.value()}); | ||
| } | ||
| return static_cast<cudf::size_type>(hll.estimate(cuda::stream_ref{stream.value()})); | ||
| } | ||
| } | ||
|
|
||
| hll.add(hash_values.begin(), hash_values.end(), cuda::stream_ref{stream.value()}); | ||
| return static_cast<cudf::size_type>(hll.estimate(cuda::stream_ref{stream.value()})); | ||
| } | ||
|
|
||
| cudf::size_type approx_distinct_count(table_view const& input, | ||
| int precision, | ||
| null_policy null_handling, | ||
| nan_policy nan_handling, | ||
| rmm::cuda_stream_view stream) | ||
| { | ||
| return approx_distinct_count_impl(input, precision, null_handling, nan_handling, stream); | ||
| } | ||
|
|
||
| cudf::size_type approx_distinct_count(column_view const& input, | ||
| int precision, | ||
| null_policy null_handling, | ||
| nan_policy nan_handling, | ||
| rmm::cuda_stream_view stream) | ||
| { | ||
| // Convert column to single-column table and use unified implementation | ||
| cudf::table_view single_col_table({input}); | ||
| return approx_distinct_count_impl( | ||
| single_col_table, precision, null_handling, nan_handling, stream); | ||
| } | ||
|
|
||
| } // namespace detail | ||
|
|
||
| // Public API implementations | ||
| cudf::size_type approx_distinct_count(column_view const& input, | ||
| int precision, | ||
| null_policy null_handling, | ||
| nan_policy nan_handling, | ||
| rmm::cuda_stream_view stream) | ||
| { | ||
| CUDF_FUNC_RANGE(); | ||
| return detail::approx_distinct_count(input, precision, null_handling, nan_handling, stream); | ||
| } | ||
|
|
||
| cudf::size_type approx_distinct_count(table_view const& input, | ||
| int precision, | ||
| null_policy null_handling, | ||
| nan_policy nan_handling, | ||
| rmm::cuda_stream_view stream) | ||
| { | ||
| CUDF_FUNC_RANGE(); | ||
| return detail::approx_distinct_count(input, precision, null_handling, nan_handling, stream); | ||
| } | ||
|
|
||
| } // namespace cudf | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thinking about multi-gpu approx distinct count, I believe that two sketches can be combined by some binary operator, and that commutes through the
estimatefunction. i.e.(hll(A) + hll(B)).estimate() == hll(A + B).estimate().To produce a global approx distinct count from the GPU-local ones, I need to do this merge.
Can you provide an interface to return the
hll.sketch_bytes()as an object that I can then combine with another sketch that was constructed using the same hashing scheme and approximation size?Perhaps, spitballing:
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We do have
mergeAPIs in HLL that address this need: https://github.com/NVIDIA/cuCollections/blob/d36905c69ce02d74abdd31dc864ce3e1ffc5a7db/include/cuco/hyperloglog.cuh#L159-L221. The question is really about how to surface this capability in libcudf. One idea I had is to expose a class likeapprox_estimatorin libcudf so users can perform custom operations such as merge. However, that class would essentially just wrapcuco::hyperloglog, meaning that for multi-GPU scenarios users could simply usecuco::hyperloglogdirectly without needing any cudf abstraction. Does that sound reasonable, or is there something I’m overlooking?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On second thought, exposing an object-oriented estimator instead of the current free function is likely the better approach. It offers significantly more flexibility, and given the complexity involved with row operations and null/nan handling, relying on users to manage those aspects themselves would be fairly complex.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I think nan/null handling should be provided by us, rather than the end user. I've not yet looked as well at all the row_hasher apis, do we expose those in the public interface?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a very good point. The row operators reside in the
detailnamespace.