17
17
#include " polynomial/polynomial_math.cuh"
18
18
#include " types/complex/operations.cuh"
19
19
20
- template <typename Torus, class params , sharedMemDegree SMD>
20
+ template <typename Torus, class params , sharedMemDegree SMD, bool first_iter >
21
21
__global__ void __launch_bounds__ (params::degree / params::opt)
22
22
device_programmable_bootstrap_step_one(
23
23
const Torus *__restrict__ lut_vector,
24
24
const Torus *__restrict__ lut_vector_indexes,
25
25
const Torus *__restrict__ lwe_array_in,
26
- const Torus *__restrict__ lwe_input_indexes,
27
- const double2 *__restrict__ bootstrapping_key,
28
- Torus *global_accumulator, double2 *global_join_buffer,
29
- uint32_t lwe_iteration, uint32_t lwe_dimension,
30
- uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
31
- int8_t *device_mem, uint64_t device_memory_size_per_block) {
26
+ const Torus *__restrict__ lwe_input_indexes, Torus *global_accumulator,
27
+ double2 *global_join_buffer, uint32_t lwe_iteration,
28
+ uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log,
29
+ uint32_t level_count, int8_t *device_mem,
30
+ uint64_t device_memory_size_per_block) {
32
31
33
32
// We use shared memory for the polynomials that are used often during the
34
33
// bootstrap, since shared memory is kept in L1 cache and accessing it is
@@ -71,7 +70,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
71
70
blockIdx .x * level_count * (glwe_dimension + 1 )) *
72
71
(polynomial_size / 2 );
73
72
74
- if (lwe_iteration == 0 ) {
73
+ if constexpr (first_iter ) {
75
74
// First iteration
76
75
// Put "b" in [0, 2N[
77
76
Torus b_hat = 0 ;
@@ -131,7 +130,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
131
130
}
132
131
}
133
132
134
- template <typename Torus, class params , sharedMemDegree SMD>
133
+ template <typename Torus, class params , sharedMemDegree SMD, bool last_iter >
135
134
__global__ void __launch_bounds__ (params::degree / params::opt)
136
135
device_programmable_bootstrap_step_two(
137
136
Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
@@ -205,7 +204,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
205
204
NSMFFT_inverse<HalfDegree<params>>(accumulator_fft);
206
205
add_to_torus<Torus, params>(accumulator_fft, accumulator);
207
206
208
- if (lwe_iteration + 1 == lwe_dimension ) {
207
+ if constexpr (last_iter ) {
209
208
// Last iteration
210
209
auto block_lwe_array_out =
211
210
&lwe_array_out[lwe_output_indexes[blockIdx .x ] *
@@ -321,37 +320,61 @@ __host__ void scratch_programmable_bootstrap(
321
320
// Configure step one
322
321
if (max_shared_memory >= partial_sm && max_shared_memory < full_sm_step_one) {
323
322
check_cuda_error (cudaFuncSetAttribute (
324
- device_programmable_bootstrap_step_one<Torus, params, PARTIALSM>,
323
+ device_programmable_bootstrap_step_one<Torus, params, PARTIALSM, true >,
324
+ cudaFuncAttributeMaxDynamicSharedMemorySize, partial_sm));
325
+ cudaFuncSetCacheConfig (
326
+ device_programmable_bootstrap_step_one<Torus, params, PARTIALSM, true >,
327
+ cudaFuncCachePreferShared);
328
+ check_cuda_error (cudaFuncSetAttribute (
329
+ device_programmable_bootstrap_step_one<Torus, params, PARTIALSM, false >,
325
330
cudaFuncAttributeMaxDynamicSharedMemorySize, partial_sm));
326
331
cudaFuncSetCacheConfig (
327
- device_programmable_bootstrap_step_one<Torus, params, PARTIALSM>,
332
+ device_programmable_bootstrap_step_one<Torus, params, PARTIALSM, false >,
328
333
cudaFuncCachePreferShared);
329
334
check_cuda_error (cudaGetLastError ());
330
335
} else if (max_shared_memory >= partial_sm) {
331
336
check_cuda_error (cudaFuncSetAttribute (
332
- device_programmable_bootstrap_step_one<Torus, params, FULLSM>,
337
+ device_programmable_bootstrap_step_one<Torus, params, FULLSM, true >,
333
338
cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_step_one));
334
339
cudaFuncSetCacheConfig (
335
- device_programmable_bootstrap_step_one<Torus, params, FULLSM>,
340
+ device_programmable_bootstrap_step_one<Torus, params, FULLSM, true >,
341
+ cudaFuncCachePreferShared);
342
+ check_cuda_error (cudaFuncSetAttribute (
343
+ device_programmable_bootstrap_step_one<Torus, params, FULLSM, false >,
344
+ cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_step_one));
345
+ cudaFuncSetCacheConfig (
346
+ device_programmable_bootstrap_step_one<Torus, params, FULLSM, false >,
336
347
cudaFuncCachePreferShared);
337
348
check_cuda_error (cudaGetLastError ());
338
349
}
339
350
340
351
// Configure step two
341
352
if (max_shared_memory >= partial_sm && max_shared_memory < full_sm_step_two) {
342
353
check_cuda_error (cudaFuncSetAttribute (
343
- device_programmable_bootstrap_step_two<Torus, params, PARTIALSM>,
354
+ device_programmable_bootstrap_step_two<Torus, params, PARTIALSM, true >,
355
+ cudaFuncAttributeMaxDynamicSharedMemorySize, partial_sm));
356
+ cudaFuncSetCacheConfig (
357
+ device_programmable_bootstrap_step_two<Torus, params, PARTIALSM, true >,
358
+ cudaFuncCachePreferShared);
359
+ check_cuda_error (cudaFuncSetAttribute (
360
+ device_programmable_bootstrap_step_two<Torus, params, PARTIALSM, false >,
344
361
cudaFuncAttributeMaxDynamicSharedMemorySize, partial_sm));
345
362
cudaFuncSetCacheConfig (
346
- device_programmable_bootstrap_step_two<Torus, params, PARTIALSM>,
363
+ device_programmable_bootstrap_step_two<Torus, params, PARTIALSM, false >,
347
364
cudaFuncCachePreferShared);
348
365
check_cuda_error (cudaGetLastError ());
349
366
} else if (max_shared_memory >= partial_sm) {
350
367
check_cuda_error (cudaFuncSetAttribute (
351
- device_programmable_bootstrap_step_two<Torus, params, FULLSM>,
368
+ device_programmable_bootstrap_step_two<Torus, params, FULLSM, true >,
369
+ cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_step_two));
370
+ cudaFuncSetCacheConfig (
371
+ device_programmable_bootstrap_step_two<Torus, params, FULLSM, true >,
372
+ cudaFuncCachePreferShared);
373
+ check_cuda_error (cudaFuncSetAttribute (
374
+ device_programmable_bootstrap_step_two<Torus, params, FULLSM, false >,
352
375
cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm_step_two));
353
376
cudaFuncSetCacheConfig (
354
- device_programmable_bootstrap_step_two<Torus, params, FULLSM>,
377
+ device_programmable_bootstrap_step_two<Torus, params, FULLSM, false >,
355
378
cudaFuncCachePreferShared);
356
379
check_cuda_error (cudaGetLastError ());
357
380
}
@@ -361,7 +384,7 @@ __host__ void scratch_programmable_bootstrap(
361
384
input_lwe_ciphertext_count, PBS_VARIANT::DEFAULT, allocate_gpu_memory);
362
385
}
363
386
364
- template <typename Torus, class params >
387
+ template <typename Torus, class params , bool first_iter >
365
388
__host__ void execute_step_one (
366
389
cudaStream_t stream, uint32_t gpu_index, Torus const *lut_vector,
367
390
Torus const *lut_vector_indexes, Torus const *lwe_array_in,
@@ -378,31 +401,30 @@ __host__ void execute_step_one(
378
401
dim3 grid (input_lwe_ciphertext_count, glwe_dimension + 1 , level_count);
379
402
380
403
if (max_shared_memory < partial_sm) {
381
- device_programmable_bootstrap_step_one<Torus, params, NOSM>
404
+ device_programmable_bootstrap_step_one<Torus, params, NOSM, first_iter >
382
405
<<<grid, thds, 0 , stream>>> (
383
406
lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes,
384
- bootstrapping_key, global_accumulator, global_join_buffer,
385
- lwe_iteration, lwe_dimension, polynomial_size, base_log,
386
- level_count, d_mem, full_dm);
407
+ global_accumulator, global_join_buffer, lwe_iteration ,
408
+ lwe_dimension, polynomial_size, base_log, level_count, d_mem ,
409
+ full_dm);
387
410
} else if (max_shared_memory < full_sm) {
388
- device_programmable_bootstrap_step_one<Torus, params, PARTIALSM>
411
+ device_programmable_bootstrap_step_one<Torus, params, PARTIALSM, first_iter >
389
412
<<<grid, thds, partial_sm, stream>>> (
390
413
lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes,
391
- bootstrapping_key, global_accumulator, global_join_buffer,
392
- lwe_iteration, lwe_dimension, polynomial_size, base_log,
393
- level_count, d_mem, partial_dm);
414
+ global_accumulator, global_join_buffer, lwe_iteration ,
415
+ lwe_dimension, polynomial_size, base_log, level_count, d_mem ,
416
+ partial_dm);
394
417
} else {
395
- device_programmable_bootstrap_step_one<Torus, params, FULLSM>
418
+ device_programmable_bootstrap_step_one<Torus, params, FULLSM, first_iter >
396
419
<<<grid, thds, full_sm, stream>>> (
397
420
lut_vector, lut_vector_indexes, lwe_array_in, lwe_input_indexes,
398
- bootstrapping_key, global_accumulator, global_join_buffer,
399
- lwe_iteration, lwe_dimension, polynomial_size, base_log,
400
- level_count, d_mem, 0 );
421
+ global_accumulator, global_join_buffer, lwe_iteration,
422
+ lwe_dimension, polynomial_size, base_log, level_count, d_mem, 0 );
401
423
}
402
424
check_cuda_error (cudaGetLastError ());
403
425
}
404
426
405
- template <typename Torus, class params >
427
+ template <typename Torus, class params , bool last_iter >
406
428
__host__ void execute_step_two (
407
429
cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
408
430
Torus const *lwe_output_indexes, Torus const *lut_vector,
@@ -420,21 +442,21 @@ __host__ void execute_step_two(
420
442
dim3 grid (input_lwe_ciphertext_count, glwe_dimension + 1 );
421
443
422
444
if (max_shared_memory < partial_sm) {
423
- device_programmable_bootstrap_step_two<Torus, params, NOSM>
445
+ device_programmable_bootstrap_step_two<Torus, params, NOSM, last_iter >
424
446
<<<grid, thds, 0 , stream>>> (
425
447
lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
426
448
bootstrapping_key, global_accumulator, global_join_buffer,
427
449
lwe_iteration, lwe_dimension, polynomial_size, base_log,
428
450
level_count, d_mem, full_dm, num_many_lut, lut_stride);
429
451
} else if (max_shared_memory < full_sm) {
430
- device_programmable_bootstrap_step_two<Torus, params, PARTIALSM>
452
+ device_programmable_bootstrap_step_two<Torus, params, PARTIALSM, last_iter >
431
453
<<<grid, thds, partial_sm, stream>>> (
432
454
lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
433
455
bootstrapping_key, global_accumulator, global_join_buffer,
434
456
lwe_iteration, lwe_dimension, polynomial_size, base_log,
435
457
level_count, d_mem, partial_dm, num_many_lut, lut_stride);
436
458
} else {
437
- device_programmable_bootstrap_step_two<Torus, params, FULLSM>
459
+ device_programmable_bootstrap_step_two<Torus, params, FULLSM, last_iter >
438
460
<<<grid, thds, full_sm, stream>>> (
439
461
lwe_array_out, lwe_output_indexes, lut_vector, lut_vector_indexes,
440
462
bootstrapping_key, global_accumulator, global_join_buffer,
@@ -480,19 +502,38 @@ __host__ void host_programmable_bootstrap(
480
502
int8_t *d_mem = pbs_buffer->d_mem ;
481
503
482
504
for (int i = 0 ; i < lwe_dimension; i++) {
483
- execute_step_one<Torus, params>(
484
- stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
485
- lwe_input_indexes, bootstrapping_key, global_accumulator,
486
- global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
487
- glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
488
- partial_sm, partial_dm_step_one, full_sm_step_one, full_dm_step_one);
489
- execute_step_two<Torus, params>(
490
- stream, gpu_index, lwe_array_out, lwe_output_indexes, lut_vector,
491
- lut_vector_indexes, bootstrapping_key, global_accumulator,
492
- global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
493
- glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
494
- partial_sm, partial_dm_step_two, full_sm_step_two, full_dm_step_two,
495
- num_many_lut, lut_stride);
505
+ if (i == 0 ) {
506
+ execute_step_one<Torus, params, true >(
507
+ stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
508
+ lwe_input_indexes, bootstrapping_key, global_accumulator,
509
+ global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
510
+ glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
511
+ partial_sm, partial_dm_step_one, full_sm_step_one, full_dm_step_one);
512
+ } else {
513
+ execute_step_one<Torus, params, false >(
514
+ stream, gpu_index, lut_vector, lut_vector_indexes, lwe_array_in,
515
+ lwe_input_indexes, bootstrapping_key, global_accumulator,
516
+ global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
517
+ glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
518
+ partial_sm, partial_dm_step_one, full_sm_step_one, full_dm_step_one);
519
+ }
520
+ if (i == lwe_dimension - 1 ) {
521
+ execute_step_two<Torus, params, true >(
522
+ stream, gpu_index, lwe_array_out, lwe_output_indexes, lut_vector,
523
+ lut_vector_indexes, bootstrapping_key, global_accumulator,
524
+ global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
525
+ glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
526
+ partial_sm, partial_dm_step_two, full_sm_step_two, full_dm_step_two,
527
+ num_many_lut, lut_stride);
528
+ } else {
529
+ execute_step_two<Torus, params, false >(
530
+ stream, gpu_index, lwe_array_out, lwe_output_indexes, lut_vector,
531
+ lut_vector_indexes, bootstrapping_key, global_accumulator,
532
+ global_join_buffer, input_lwe_ciphertext_count, lwe_dimension,
533
+ glwe_dimension, polynomial_size, base_log, level_count, d_mem, i,
534
+ partial_sm, partial_dm_step_two, full_sm_step_two, full_dm_step_two,
535
+ num_many_lut, lut_stride);
536
+ }
496
537
}
497
538
}
498
539
0 commit comments