
Commit 6e158cd

chore(gpu): use template for first/last iter in split classical PBS
1 parent: cdcf00a


2 files changed: +181 -80 lines changed

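This commit replaces the runtime first/last-iteration checks in the split classical PBS kernels with compile-time bool template parameters: device_programmable_bootstrap_step_one gains first_iter, device_programmable_bootstrap_step_two gains last_iter, and the host loop picks the matching instantiation per iteration. The step-one kernel also drops the bootstrapping_key parameter from its signature and call sites. A minimal, self-contained sketch of the pattern (toy kernel and names, not the actual backend code):

```cuda
// Minimal sketch of the commit's technique (toy kernel, hypothetical names;
// not the actual tfhe-cuda-backend code). The first-iteration check becomes
// a template parameter, so each instantiation compiles to a branch-free
// kernel. Requires -std=c++17 for if constexpr.
template <bool first_iter>
__global__ void step_kernel(float *acc, const float *in, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= n)
    return;
  if constexpr (first_iter) {
    acc[idx] = in[idx];   // first iteration: initialize the accumulator
  } else {
    acc[idx] += in[idx];  // later iterations: accumulate
  }
}

// Host side: the runtime loop index selects one of the two compiled kernels,
// mirroring the if/else added to host_programmable_bootstrap below.
void run_steps(float *acc, const float *in, int n, int iterations) {
  int threads = 256, blocks = (n + threads - 1) / threads;
  for (int i = 0; i < iterations; i++) {
    if (i == 0)
      step_kernel<true><<<blocks, threads>>>(acc, in, n);
    else
      step_kernel<false><<<blocks, threads>>>(acc, in, n);
  }
}
```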

backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_classic.cuh (+89 -48)
@@ -17,18 +17,17 @@
 #include "polynomial/polynomial_math.cuh"
 #include "types/complex/operations.cuh"
 
-template <typename Torus, class params, sharedMemDegree SMD>
+template <typename Torus, class params, sharedMemDegree SMD, bool first_iter>
 __global__ void __launch_bounds__(params::degree / params::opt)
     device_programmable_bootstrap_step_one(
         const Torus *__restrict__ lut_vector,
         const Torus *__restrict__ lut_vector_indexes,
         const Torus *__restrict__ lwe_array_in,
-        const Torus *__restrict__ lwe_input_indexes,
-        const double2 *__restrict__ bootstrapping_key,
-        Torus *global_accumulator, double2 *global_join_buffer,
-        uint32_t lwe_iteration, uint32_t lwe_dimension,
-        uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
-        int8_t *device_mem, uint64_t device_memory_size_per_block) {
+        const Torus *__restrict__ lwe_input_indexes, Torus *global_accumulator,
+        double2 *global_join_buffer, uint32_t lwe_iteration,
+        uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log,
+        uint32_t level_count, int8_t *device_mem,
+        uint64_t device_memory_size_per_block) {
 
   // We use shared memory for the polynomials that are used often during the
   // bootstrap, since shared memory is kept in L1 cache and accessing it is
@@ -71,7 +70,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
                  blockIdx.x * level_count * (glwe_dimension + 1)) *
                     (polynomial_size / 2);
 
-  if (lwe_iteration == 0) {
+  if constexpr (first_iter) {
     // First iteration
     // Put "b" in [0, 2N[
     Torus b_hat = 0;
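Because first_iter is now a template parameter, if constexpr discards the untaken branch at compile time: each instantiation contains exactly one code path, and no per-thread comparison against lwe_iteration remains. A toy way to observe this (hypothetical kernels, not the backend's):

```cuda
// Compile with: nvcc -std=c++17 -ptx branch_demo.cu
// Each templated instantiation emits exactly one path in the PTX, while
// the runtime version keeps the comparison in every thread.
template <bool first>
__global__ void templated_kernel(int *out) {
  if constexpr (first)
    out[threadIdx.x] = 0;   // only this path survives in <true>
  else
    out[threadIdx.x] += 1;  // only this path survives in <false>
}

__global__ void runtime_kernel(int *out, int iter) {
  if (iter == 0)            // evaluated at runtime by every thread
    out[threadIdx.x] = 0;
  else
    out[threadIdx.x] += 1;
}

// Force both instantiations to be emitted into the PTX.
template __global__ void templated_kernel<true>(int *);
template __global__ void templated_kernel<false>(int *);
```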
@@ -131,7 +130,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
   }
 }
 
-template <typename Torus, class params, sharedMemDegree SMD>
+template <typename Torus, class params, sharedMemDegree SMD, bool last_iter>
 __global__ void __launch_bounds__(params::degree / params::opt)
     device_programmable_bootstrap_step_two(
         Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
@@ -205,7 +204,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
   NSMFFT_inverse<HalfDegree<params>>(accumulator_fft);
   add_to_torus<Torus, params>(accumulator_fft, accumulator);
 
-  if (lwe_iteration + 1 == lwe_dimension) {
+  if constexpr (last_iter) {
     // Last iteration
     auto block_lwe_array_out =
         &lwe_array_out[lwe_output_indexes[blockIdx.x] *
@@ -321,37 +320,61 @@ __host__ void scratch_programmable_bootstrap(
   // Configure step one
   if (max_shared_memory >= partial_sm && max_shared_memory < full_sm_step_one) {
     check_cuda_error(cudaFuncSetAttribute(
-        device_programmable_bootstrap_step_one<Torus, params, PARTIALSM>,
+        device_programmable_bootstrap_step_one<Torus, params, PARTIALSM, true>,
+        cudaFuncAttributeMaxDynamicSharedMemorySize, partial_sm));
+    cudaFuncSetCacheConfig(
+        device_programmable_bootstrap_step_one<Torus, params, PARTIALSM, true>,
+        cudaFuncCachePreferShared);
+    check_cuda_error(cudaFuncSetAttribute(
+        device_programmable_bootstrap_step_one<Torus, params, PARTIALSM, false>,
         cudaFuncAttributeMaxDynamicSharedMemorySize, partial_sm));
     cudaFuncSetCacheConfig(
-        device_programmable_bootstrap_step_one<Torus, params, PARTIALSM>,
+        device_programmable_bootstrap_step_one<Torus, params, PARTIALSM, false>,
         cudaFuncCachePreferShared);
     check_cuda_error(cudaGetLastError());
   } else if (max_shared_memory >= partial_sm) {
     check_cuda_error(cudaFuncSetAttribute(
-        device_programmable_bootstrap_step_one<Torus, params, FULLSM>,
+        device_programmable_bootstrap_step_one<Torus, params, FULLSM, true>,
         cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_step_one));
     cudaFuncSetCacheConfig(
-        device_programmable_bootstrap_step_one<Torus, params, FULLSM>,
+        device_programmable_bootstrap_step_one<Torus, params, FULLSM, true>,
+        cudaFuncCachePreferShared);
+    check_cuda_error(cudaFuncSetAttribute(
+        device_programmable_bootstrap_step_one<Torus, params, FULLSM, false>,
+        cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_step_one));
+    cudaFuncSetCacheConfig(
+        device_programmable_bootstrap_step_one<Torus, params, FULLSM, false>,
        cudaFuncCachePreferShared);
     check_cuda_error(cudaGetLastError());
   }
 
   // Configure step two
   if (max_shared_memory >= partial_sm && max_shared_memory < full_sm_step_two) {
     check_cuda_error(cudaFuncSetAttribute(
-        device_programmable_bootstrap_step_two<Torus, params, PARTIALSM>,
+        device_programmable_bootstrap_step_two<Torus, params, PARTIALSM, true>,
+        cudaFuncAttributeMaxDynamicSharedMemorySize, partial_sm));
+    cudaFuncSetCacheConfig(
+        device_programmable_bootstrap_step_two<Torus, params, PARTIALSM, true>,
+        cudaFuncCachePreferShared);
+    check_cuda_error(cudaFuncSetAttribute(
+        device_programmable_bootstrap_step_two<Torus, params, PARTIALSM, false>,
         cudaFuncAttributeMaxDynamicSharedMemorySize, partial_sm));
     cudaFuncSetCacheConfig(
-        device_programmable_bootstrap_step_two<Torus, params, PARTIALSM>,
+        device_programmable_bootstrap_step_two<Torus, params, PARTIALSM, false>,
         cudaFuncCachePreferShared);
     check_cuda_error(cudaGetLastError());
   } else if (max_shared_memory >= partial_sm) {
     check_cuda_error(cudaFuncSetAttribute(
-        device_programmable_bootstrap_step_two<Torus, params, FULLSM>,
+        device_programmable_bootstrap_step_two<Torus, params, FULLSM, true>,
+        cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_step_two));
+    cudaFuncSetCacheConfig(
+        device_programmable_bootstrap_step_two<Torus, params, FULLSM, true>,
+        cudaFuncCachePreferShared);
+    check_cuda_error(cudaFuncSetAttribute(
+        device_programmable_bootstrap_step_two<Torus, params, FULLSM, false>,
         cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_step_two));
     cudaFuncSetCacheConfig(
-        device_programmable_bootstrap_step_two<Torus, params, FULLSM>,
+        device_programmable_bootstrap_step_two<Torus, params, FULLSM, false>,
         cudaFuncCachePreferShared);
     check_cuda_error(cudaGetLastError());
   }
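Each bool value is a distinct kernel, so the scratch function now has to raise the dynamic shared-memory limit and set the cache preference on both the <true> and <false> instantiations; configuring only one would leave the other launching with the default dynamic shared-memory limit (48 KB on most architectures). If this duplication grows further, it could be folded into a small helper along these lines (a sketch, not the backend's code):

```cuda
#include <cuda_runtime.h>

// Hypothetical helper (not in this commit): apply the same dynamic
// shared-memory limit and cache preference to one kernel instantiation,
// so each true/false pair above could be configured in two short calls.
template <typename KernelT>
inline cudaError_t configure_pbs_kernel(KernelT kernel, int shared_bytes) {
  cudaError_t err = cudaFuncSetAttribute(
      kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_bytes);
  if (err != cudaSuccess)
    return err;
  return cudaFuncSetCacheConfig(kernel, cudaFuncCachePreferShared);
}

// e.g.:
//   configure_pbs_kernel(
//       device_programmable_bootstrap_step_one<Torus, params, PARTIALSM, true>,
//       partial_sm);
```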
@@ -361,7 +384,7 @@ __host__ void scratch_programmable_bootstrap(
       input_lwe_ciphertext_count, PBS_VARIANT::DEFAULT, allocate_gpu_memory);
 }
 
-template <typename Torus, class params>
+template <typename Torus, class params, bool first_iter>
 __host__ void execute_step_one(
     cudaStream_t stream, uint32_t gpu_index, Torus const *lut_vector,
     Torus const *lut_vector_indexes, Torus const *lwe_array_in,
@@ -378,31 +401,30 @@ __host__ void execute_step_one(
   dim3 grid(input_lwe_ciphertext_count, glwe_dimension + 1, level_count);
 
   if (max_shared_memory < partial_sm) {
-    device_programmable_bootstrap_step_one<Torus, params, NOSM>
+    device_programmable_bootstrap_step_one<Torus, params, NOSM, first_iter>
         <<<grid, thds, 0, stream>>>(
             lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes,
-            bootstrapping_key, global_accumulator, global_join_buffer,
-            lwe_iteration, lwe_dimension, polynomial_size, base_log,
-            level_count, d_mem, full_dm);
+            global_accumulator, global_join_buffer, lwe_iteration,
+            lwe_dimension, polynomial_size, base_log, level_count, d_mem,
+            full_dm);
   } else if (max_shared_memory < full_sm) {
-    device_programmable_bootstrap_step_one<Torus, params, PARTIALSM>
+    device_programmable_bootstrap_step_one<Torus, params, PARTIALSM, first_iter>
         <<<grid, thds, partial_sm, stream>>>(
             lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes,
-            bootstrapping_key, global_accumulator, global_join_buffer,
-            lwe_iteration, lwe_dimension, polynomial_size, base_log,
-            level_count, d_mem, partial_dm);
+            global_accumulator, global_join_buffer, lwe_iteration,
+            lwe_dimension, polynomial_size, base_log, level_count, d_mem,
+            partial_dm);
   } else {
-    device_programmable_bootstrap_step_one<Torus, params, FULLSM>
+    device_programmable_bootstrap_step_one<Torus, params, FULLSM, first_iter>
         <<<grid, thds, full_sm, stream>>>(
             lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes,
-            bootstrapping_key, global_accumulator, global_join_buffer,
-            lwe_iteration, lwe_dimension, polynomial_size, base_log,
-            level_count, d_mem, 0);
+            global_accumulator, global_join_buffer, lwe_iteration,
+            lwe_dimension, polynomial_size, base_log, level_count, d_mem, 0);
   }
   check_cuda_error(cudaGetLastError());
 }
 
-template <typename Torus, class params>
+template <typename Torus, class params, bool last_iter>
 __host__ void execute_step_two(
     cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
     Torus const *lwe_output_indexes, Torus const *lut_vector,
@@ -420,21 +442,21 @@ __host__ void execute_step_two(
   dim3 grid(input_lwe_ciphertext_count, glwe_dimension + 1);
 
   if (max_shared_memory < partial_sm) {
-    device_programmable_bootstrap_step_two<Torus, params, NOSM>
+    device_programmable_bootstrap_step_two<Torus, params, NOSM, last_iter>
         <<<grid, thds, 0, stream>>>(
             lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
             bootstrapping_key, global_accumulator, global_join_buffer,
             lwe_iteration, lwe_dimension, polynomial_size, base_log,
             level_count, d_mem, full_dm, num_many_lut, lut_stride);
   } else if (max_shared_memory < full_sm) {
-    device_programmable_bootstrap_step_two<Torus, params, PARTIALSM>
+    device_programmable_bootstrap_step_two<Torus, params, PARTIALSM, last_iter>
         <<<grid, thds, partial_sm, stream>>>(
             lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
             bootstrapping_key, global_accumulator, global_join_buffer,
             lwe_iteration, lwe_dimension, polynomial_size, base_log,
             level_count, d_mem, partial_dm, num_many_lut, lut_stride);
   } else {
-    device_programmable_bootstrap_step_two<Torus, params, FULLSM>
+    device_programmable_bootstrap_step_two<Torus, params, FULLSM, last_iter>
         <<<grid, thds, full_sm, stream>>>(
             lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
             bootstrapping_key, global_accumulator, global_join_buffer,
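Both host wrappers keep the existing three-tier shared-memory dispatch: the NOSM, PARTIALSM, and FULLSM instantiations are selected from the device's shared-memory budget, and the launch passes the matching dynamic allocation. A condensed sketch of the idiom (toy kernel, hypothetical sizes; the backend obtains max_shared_memory through its own helpers):

```cuda
#include <cuda_runtime.h>

// Condensed sketch of the three-tier dispatch (toy kernel, hypothetical
// sizes): query the device's opt-in shared-memory budget, then launch the
// NOSM / PARTIALSM / FULLSM instantiation with the matching allocation.
enum sharedMemDegree { NOSM, PARTIALSM, FULLSM };

template <sharedMemDegree SMD>
__global__ void toy_kernel(float *data) { /* ... */ }

void launch_toy(float *data, int gpu_index, size_t partial_sm, size_t full_sm) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, gpu_index);
  size_t max_shared = prop.sharedMemPerBlockOptin;

  if (max_shared < partial_sm)
    toy_kernel<NOSM><<<1, 256, 0>>>(data);  // all data via global memory
  else if (max_shared < full_sm)
    toy_kernel<PARTIALSM><<<1, 256, partial_sm>>>(data);
  else
    toy_kernel<FULLSM><<<1, 256, full_sm>>>(data);
}
```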
@@ -480,19 +502,38 @@ __host__ void host_programmable_bootstrap(
   int8_t *d_mem = pbs_buffer->d_mem;
 
   for (int i = 0; i < lwe_dimension; i++) {
-    execute_step_one<Torus, params>(
-        stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
-        lwe_input_indexes, bootstrapping_key, global_accumulator,
-        global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
-        glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
-        partial_sm, partial_dm_step_one, full_sm_step_one, full_dm_step_one);
-    execute_step_two<Torus, params>(
-        stream, gpu_index, lwe_array_out, lwe_output_indexes, lut_vector,
-        lut_vector_indexes, bootstrapping_key, global_accumulator,
-        global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
-        glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
-        partial_sm, partial_dm_step_two, full_sm_step_two, full_dm_step_two,
-        num_many_lut, lut_stride);
+    if (i == 0) {
+      execute_step_one<Torus, params, true>(
+          stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
+          lwe_input_indexes, bootstrapping_key, global_accumulator,
+          global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
+          glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
+          partial_sm, partial_dm_step_one, full_sm_step_one, full_dm_step_one);
+    } else {
+      execute_step_one<Torus, params, false>(
+          stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
+          lwe_input_indexes, bootstrapping_key, global_accumulator,
+          global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
+          glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
+          partial_sm, partial_dm_step_one, full_sm_step_one, full_dm_step_one);
+    }
+    if (i == lwe_dimension - 1) {
+      execute_step_two<Torus, params, true>(
+          stream, gpu_index, lwe_array_out, lwe_output_indexes, lut_vector,
+          lut_vector_indexes, bootstrapping_key, global_accumulator,
+          global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
+          glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
+          partial_sm, partial_dm_step_two, full_sm_step_two, full_dm_step_two,
+          num_many_lut, lut_stride);
+    } else {
+      execute_step_two<Torus, params, false>(
+          stream, gpu_index, lwe_array_out, lwe_output_indexes, lut_vector,
+          lut_vector_indexes, bootstrapping_key, global_accumulator,
+          global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
+          glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
+          partial_sm, partial_dm_step_two, full_sm_step_two, full_dm_step_two,
+          num_many_lut, lut_stride);
+    }
   }
 }
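The host loop maps the runtime index onto the compile-time flags with explicit if/else, which duplicates each argument list once. A hypothetical alternative that keeps a single call site is to select the instantiation through a function pointer first (a sketch, not part of this commit):

```cuda
// Hypothetical alternative (not part of this commit): select the
// instantiation through a function pointer so each argument list is
// written only once. execute_step is a stand-in for execute_step_one.
template <bool first_iter>
void execute_step(int iteration) {
  // ... launch the <first_iter> kernel variant ...
}

void host_loop(int lwe_dimension) {
  for (int i = 0; i < lwe_dimension; i++) {
    auto step = (i == 0) ? &execute_step<true> : &execute_step<false>;
    step(i);  // single call site instead of duplicated branches
  }
}
```

With only two call sites per step, the explicit if/else used in the commit is arguably the clearer choice.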
