Simplify C++ inference with auto-instantiated RAFT handle #102
base: release/26.06
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -308,6 +308,57 @@ struct forest_model { | |
| predict(handle, out_buffer, in_buffer, predict_type, specified_chunk_size); | ||
| } | ||
|
|
||
| /** | ||
| * Perform inference on given input using an internally managed RAFT handle. | ||
| * This function is blocking and synchronizes the handle before returning. | ||
| * | ||
| * @param[out] output Pointer to the memory location where output should end | ||
| * up | ||
| * @param[in] input Pointer to the input data | ||
| * @param[in] num_rows Number of rows in input | ||
| * @param[in] out_mem_type The memory type (device/host) of the output | ||
| * buffer | ||
| * @param[in] in_mem_type The memory type (device/host) of the input buffer | ||
| * @param[in] predict_type Type of inference to perform. Defaults to summing | ||
| * the outputs of all trees and producing one output per row. If set to | ||
| * "per_tree", we will instead output all outputs of individual trees. | ||
| * If set to "leaf_id", we will output the integer ID of the leaf node | ||
| * for each tree. | ||
| * @param[in] specified_chunk_size Specifies the mini-batch size for | ||
| * processing. This has different meanings on CPU and GPU, but on GPU it | ||
| * corresponds to the number of rows evaluated per inference iteration | ||
| * on a single block. It can take on any power of 2 from 1 to 32, and | ||
| * runtime performance is quite sensitive to the value chosen. In general, | ||
| * larger batches benefit from higher values, but it is hard to predict the | ||
| * optimal value a priori. If omitted, a heuristic will be used to select a | ||
| * reasonable value. On CPU, this argument can generally just be omitted. | ||
| */ | ||
| template <typename io_t> | ||
| void predict(io_t* output, | ||
| io_t* input, | ||
| std::size_t num_rows, | ||
| raft_proto::device_type out_mem_type, | ||
| raft_proto::device_type in_mem_type, | ||
| infer_kind predict_type = infer_kind::default_kind, | ||
| std::optional<index_type> specified_chunk_size = std::nullopt) | ||
|
hcho3 marked this conversation as resolved.
|
||
| { | ||
| #ifdef NVFOREST_ENABLE_GPU | ||
| auto raft_handle = raft::handle_t{}; | ||
|
Member
Question on raft::handle_t{} here: with no arguments, the default constructor uses rmm::cuda_stream_per_thread and a null stream pool, so handle.get_stream_pool_size() is 0 and get_usable_stream_count() is 1, right? That means the chunking/partitioning loop in the buffer overload above (around lines ~200‑248 of this same file) collapses to single-stream, sequential copies whenever out_mem_type != in_mem_type or either differs from memory_type(). In other words, a C++ user who picks the simple path and passes host input and device output (or vice versa) silently gets less parallelism than the same call via the explicit-handle path with a populated pool. The Python API only does cpu/cpu or gpu/gpu, but this is now a primary C++ API too, so could we give it a small default pool here so that the convenience overload doesn't quietly perform worse in the heterogeneous-memory case? Could we also point this out in the doc?
Contributor
Given that the stream pool size affects the partition size, and is therefore a runtime optimization parameter, we probably need to run some benchmarks before choosing a default. |
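To make the concern in this thread concrete, here is a minimal sketch of how the number of usable streams can bound the number of partitions. It is an illustration only: `partition_rows` is a hypothetical helper, not code from nvforest or RAFT, and the even-split rule is an assumption rather than the logic of the buffer overload discussed above.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper (not part of nvforest or RAFT): split num_rows into
// contiguous [begin, end) row ranges, at most one per usable stream.
// ASSUMPTION: an even split; the real partitioning rule may differ.
std::vector<std::pair<std::size_t, std::size_t>> partition_rows(
  std::size_t num_rows, std::size_t usable_stream_count)
{
  // Treat a count of zero as a single stream so the division is well defined.
  auto const stream_count = std::max<std::size_t>(usable_stream_count, 1);
  // Ceiling division so that every row lands in some partition.
  auto const rows_per_partition = (num_rows + stream_count - 1) / stream_count;

  auto partitions = std::vector<std::pair<std::size_t, std::size_t>>{};
  for (std::size_t begin = 0; begin < num_rows; begin += rows_per_partition) {
    partitions.emplace_back(begin, std::min(begin + rows_per_partition, num_rows));
  }
  return partitions;
}
```

Under this assumed rule, one usable stream yields a single partition covering every row, so the copies run back to back. Four usable streams yield up to four partitions that could overlap. Whether more partitions actually help is the benchmarking question raised in the reply.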
||
| auto handle = raft_proto::handle_t{raft_handle}; | ||
| #else | ||
| auto handle = raft_proto::handle_t{}; | ||
| #endif | ||
| predict(handle, | ||
| output, | ||
| input, | ||
| num_rows, | ||
| out_mem_type, | ||
| in_mem_type, | ||
| predict_type, | ||
| specified_chunk_size); | ||
| handle.synchronize(); | ||
| } | ||
|
|
||
| private: | ||
| decision_forest_variant decision_forest_; | ||
| }; | ||
|
|
||
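The overload added in this diff follows a simple pattern: construct a handle internally, delegate to the explicit-handle overload, then synchronize before returning. The sketch below reproduces that pattern with stand-in types so it compiles without RAFT or CUDA. `mock_handle` and `mock_model` are hypothetical names, and the doubling loop is a placeholder for real inference.

```cpp
#include <cstddef>

// Stand-in for raft_proto::handle_t (hypothetical): tracks whether work has
// been enqueued but not yet synchronized.
struct mock_handle {
  bool pending = false;
  void synchronize() { pending = false; }
};

// Stand-in for forest_model (hypothetical).
struct mock_model {
  // Explicit-handle overload: the caller owns the handle and decides when to
  // synchronize, so work may still be pending when this returns.
  template <typename io_t>
  void predict(mock_handle& handle, io_t* output, io_t const* input, std::size_t num_rows)
  {
    for (std::size_t i = 0; i < num_rows; ++i) {
      output[i] = input[i] * io_t{2};  // placeholder for real inference
    }
    handle.pending = true;
  }

  // Convenience overload: owns a short-lived handle and blocks until the
  // work is complete, mirroring the structure of the overload in the diff.
  template <typename io_t>
  void predict(io_t* output, io_t const* input, std::size_t num_rows)
  {
    auto handle = mock_handle{};
    predict(handle, output, input, num_rows);
    handle.synchronize();
  }
};
```

With the real API, a call through the new overload would pass `output`, `input`, `num_rows`, `out_mem_type`, and `in_mem_type`, leaving `predict_type` and `specified_chunk_size` at their defaults. The trade-off is that the caller gives up control over the stream and over when synchronization happens.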