Skip to content

Commit a11584f

Browse files
committed
refactor(integer): provide recompose_unsigned, recompose_signed functions
- to decrypt values and make it flexible wrt input type for the recomposition, e.g. using u128 for noise squashed primitives
1 parent f326aaf commit a11584f

File tree

11 files changed

+151
-127
lines changed

11 files changed

+151
-127
lines changed

tfhe/src/high_level_api/array/cpu/integers.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ use crate::array::traits::{
1212
};
1313
use crate::high_level_api::global_state;
1414
use crate::high_level_api::integers::{FheIntId, FheUintId};
15-
use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom};
16-
use crate::integer::client_key::RecomposableSignedInteger;
15+
use crate::integer::block_decomposition::{
16+
DecomposableInto, RecomposableFrom, RecomposableSignedInteger,
17+
};
1718
use crate::integer::server_key::radix_parallel::scalar_div_mod::SignedReciprocable;
1819
use crate::integer::server_key::{Reciprocable, ScalarMultiplier};
1920
use crate::integer::{IntegerRadixCiphertext, RadixCiphertext, SignedRadixCiphertext};

tfhe/src/high_level_api/array/dynamic/signed.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ use crate::high_level_api::array::{
1111
};
1212
use crate::high_level_api::global_state;
1313
use crate::high_level_api::integers::FheIntId;
14-
use crate::integer::block_decomposition::DecomposableInto;
15-
use crate::integer::client_key::RecomposableSignedInteger;
14+
use crate::integer::block_decomposition::{DecomposableInto, RecomposableSignedInteger};
1615
use crate::integer::SignedRadixCiphertext;
1716
use crate::prelude::{FheDecrypt, FheTryEncrypt};
1817
use crate::{ClientKey, Device, Error};

tfhe/src/high_level_api/array/gpu/integers.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@ use crate::core_crypto::gpu::CudaStreams;
1515
use crate::high_level_api::global_state;
1616
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
1717
use crate::high_level_api::integers::{FheIntId, FheUintId};
18-
use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom};
19-
use crate::integer::client_key::RecomposableSignedInteger;
18+
use crate::integer::block_decomposition::{
19+
DecomposableInto, RecomposableFrom, RecomposableSignedInteger,
20+
};
2021
use crate::integer::gpu::ciphertext::{
2122
CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
2223
};

tfhe/src/high_level_api/integers/signed/base.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::high_level_api::global_state;
77
use crate::high_level_api::integers::{FheUint, FheUintId, IntegerId};
88
use crate::high_level_api::keys::InternalServerKey;
99
use crate::high_level_api::traits::Tagged;
10-
use crate::integer::client_key::RecomposableSignedInteger;
10+
use crate::integer::block_decomposition::RecomposableSignedInteger;
1111
use crate::integer::parameters::RadixCiphertextConformanceParams;
1212
use crate::named::Named;
1313
use crate::prelude::CastFrom;

tfhe/src/high_level_api/integers/signed/encrypt.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ use crate::high_level_api::global_state;
44
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
55
use crate::high_level_api::integers::FheIntId;
66
use crate::high_level_api::keys::InternalServerKey;
7-
use crate::integer::block_decomposition::DecomposableInto;
8-
use crate::integer::client_key::RecomposableSignedInteger;
7+
use crate::integer::block_decomposition::{DecomposableInto, RecomposableSignedInteger};
98
#[cfg(feature = "gpu")]
109
use crate::integer::gpu::ciphertext::CudaSignedRadixCiphertext;
1110
use crate::prelude::{FheDecrypt, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt};

tfhe/src/integer/block_decomposition.rs

+94-25
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::core_crypto::prelude::{CastFrom, CastInto, Numeric};
1+
use crate::core_crypto::prelude::{CastFrom, CastInto, Numeric, SignedNumeric};
22
use crate::integer::bigint::static_signed::StaticSignedBigInt;
33
use crate::integer::bigint::static_unsigned::StaticUnsignedBigInt;
44
use core::ops::{AddAssign, BitAnd, ShlAssign, ShrAssign};
@@ -88,6 +88,57 @@ impl<const N: usize> RecomposableFrom<u8> for StaticUnsignedBigInt<N> {}
8888
impl<const N: usize> DecomposableInto<u64> for StaticUnsignedBigInt<N> {}
8989
impl<const N: usize> DecomposableInto<u8> for StaticUnsignedBigInt<N> {}
9090

91+
pub trait RecomposableSignedInteger:
92+
RecomposableFrom<u64>
93+
+ std::ops::Neg<Output = Self>
94+
+ std::ops::Shr<u32, Output = Self>
95+
+ std::ops::BitOrAssign<Self>
96+
+ std::ops::BitOr<Self, Output = Self>
97+
+ std::ops::Mul<Self, Output = Self>
98+
+ CastFrom<Self>
99+
+ SignedNumeric
100+
{
101+
}
102+
103+
impl RecomposableSignedInteger for i8 {}
104+
impl RecomposableSignedInteger for i16 {}
105+
impl RecomposableSignedInteger for i32 {}
106+
impl RecomposableSignedInteger for i64 {}
107+
impl RecomposableSignedInteger for i128 {}
108+
109+
impl<const N: usize> RecomposableSignedInteger for StaticSignedBigInt<N> {}
110+
111+
pub trait SignExtendable:
112+
std::ops::Shl<u32, Output = Self> + std::ops::Shr<u32, Output = Self> + SignedNumeric
113+
{
114+
}
115+
116+
impl<T> SignExtendable for T where T: RecomposableSignedInteger {}
117+
118+
/// This function takes a signed integer of type `T` for which `num_bits_set`
119+
/// have been set.
120+
///
121+
/// It will set the most significant bits to the value of the bit
122+
/// at pos `num_bits_set - 1`.
123+
///
124+
/// This is used to correctly decrypt a signed radix ciphertext into a clear type
125+
/// that has more bits than the original ciphertext.
126+
///
127+
/// This is like doing i8 as i16, i16 as i64, i16 as i8, etc
128+
pub(in crate::integer) fn sign_extend_partial_number<T>(unpadded_value: T, num_bits_set: u32) -> T
129+
where
130+
T: SignExtendable,
131+
{
132+
if num_bits_set >= T::BITS as u32 {
133+
return unpadded_value;
134+
}
135+
136+
// Shift to put the last set bit in the position of the sign bit of T
137+
// When right shifting this will do the sign extend automatically
138+
let shift = T::BITS as u32 - num_bits_set;
139+
(unpadded_value << shift) >> shift
140+
}
141+
91142
#[derive(Copy, Clone)]
92143
#[repr(u32)]
93144
pub enum PaddingBitValue {
@@ -256,25 +307,6 @@ pub struct BlockRecomposer<T> {
256307
bit_pos: u32,
257308
}
258309

259-
impl<T> BlockRecomposer<T>
260-
where
261-
T: Recomposable,
262-
{
263-
pub fn value(&self) -> T {
264-
let is_signed = (T::ONE << (T::BITS as u32 - 1)) < T::ZERO;
265-
if self.bit_pos >= (T::BITS as u32 - u32::from(is_signed)) {
266-
self.data
267-
} else {
268-
let valid_mask = (T::ONE << self.bit_pos) - T::ONE;
269-
self.data & valid_mask
270-
}
271-
}
272-
273-
pub fn unmasked_value(&self) -> T {
274-
self.data
275-
}
276-
}
277-
278310
impl<T> BlockRecomposer<T>
279311
where
280312
T: Recomposable,
@@ -292,12 +324,21 @@ where
292324
bit_pos,
293325
}
294326
}
295-
}
296327

297-
impl<T> BlockRecomposer<T>
298-
where
299-
T: Recomposable,
300-
{
328+
pub fn value(&self) -> T {
329+
let is_signed = (T::ONE << (T::BITS as u32 - 1)) < T::ZERO;
330+
if self.bit_pos >= (T::BITS as u32 - u32::from(is_signed)) {
331+
self.data
332+
} else {
333+
let valid_mask = (T::ONE << self.bit_pos) - T::ONE;
334+
self.data & valid_mask
335+
}
336+
}
337+
338+
pub fn unmasked_value(&self) -> T {
339+
self.data
340+
}
341+
301342
pub fn add_unmasked<V>(&mut self, block: V) -> bool
302343
where
303344
T: CastFrom<V>,
@@ -328,6 +369,34 @@ where
328369

329370
true
330371
}
372+
373+
pub(crate) fn recompose_unsigned<U>(input: impl Iterator<Item = U>, bits_in_block: u32) -> T
374+
where
375+
T: RecomposableFrom<U>,
376+
{
377+
let mut recomposer = Self::new(bits_in_block);
378+
for limb in input {
379+
if !recomposer.add_unmasked(limb) {
380+
break;
381+
}
382+
}
383+
384+
recomposer.value()
385+
}
386+
387+
pub(crate) fn recompose_signed<U>(input: impl Iterator<Item = U>, bits_in_block: u32) -> T
388+
where
389+
T: RecomposableFrom<U> + SignExtendable,
390+
{
391+
let mut recomposer = Self::new(bits_in_block);
392+
for limb in input {
393+
if !recomposer.add_unmasked(limb) {
394+
break;
395+
}
396+
}
397+
398+
sign_extend_partial_number(recomposer.value(), recomposer.bit_pos)
399+
}
331400
}
332401

333402
#[cfg(test)]

tfhe/src/integer/ciphertext/base.rs

+27-16
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ use crate::core_crypto::prelude::UnsignedNumeric;
44
use crate::integer::backward_compatibility::ciphertext::{
55
BaseCrtCiphertextVersions, BaseRadixCiphertextVersions, BaseSignedRadixCiphertextVersions,
66
};
7-
use crate::integer::block_decomposition::{BlockRecomposer, RecomposableFrom};
8-
use crate::integer::client_key::{sign_extend_partial_number, RecomposableSignedInteger};
7+
use crate::integer::block_decomposition::{
8+
BlockRecomposer, RecomposableFrom, RecomposableSignedInteger,
9+
};
910
use crate::shortint::ciphertext::NotTrivialCiphertextError;
1011
use crate::shortint::parameters::CiphertextConformanceParams;
1112
use crate::shortint::Ciphertext;
@@ -102,15 +103,21 @@ impl RadixCiphertext {
102103
where
103104
Clear: UnsignedNumeric + RecomposableFrom<u64>,
104105
{
106+
if !self.blocks.iter().all(|b| b.is_trivial()) {
107+
return Err(NotTrivialCiphertextError);
108+
}
109+
105110
let bits_in_block = self.blocks[0].message_modulus.0.ilog2();
106-
let mut recomposer = BlockRecomposer::<Clear>::new(bits_in_block);
107111

108-
for encrypted_block in &self.blocks {
109-
let decrypted_block = encrypted_block.decrypt_trivial_message_and_carry()?;
110-
recomposer.add_unmasked(decrypted_block);
111-
}
112+
let decrypted_block_iter = self
113+
.blocks
114+
.iter()
115+
.map(|block| block.decrypt_trivial_message_and_carry().unwrap());
112116

113-
Ok(recomposer.value())
117+
Ok(BlockRecomposer::recompose_unsigned(
118+
decrypted_block_iter,
119+
bits_in_block,
120+
))
114121
}
115122
}
116123

@@ -204,17 +211,21 @@ impl SignedRadixCiphertext {
204211
where
205212
Clear: RecomposableSignedInteger,
206213
{
214+
if !self.blocks.iter().all(|b| b.is_trivial()) {
215+
return Err(NotTrivialCiphertextError);
216+
}
217+
207218
let bits_in_block = self.blocks[0].message_modulus.0.ilog2();
208-
let mut recomposer = BlockRecomposer::<Clear>::new(bits_in_block);
209219

210-
for encrypted_block in &self.blocks {
211-
let decrypted_block = encrypted_block.decrypt_trivial_message_and_carry()?;
212-
recomposer.add_unmasked(decrypted_block);
213-
}
220+
let decrypted_block_iter = self
221+
.blocks
222+
.iter()
223+
.map(|block| block.decrypt_trivial_message_and_carry().unwrap());
214224

215-
let num_bits_in_ctxt = bits_in_block * self.blocks.len() as u32;
216-
let unpadded_value = recomposer.value();
217-
Ok(sign_extend_partial_number(unpadded_value, num_bits_in_ctxt))
225+
Ok(BlockRecomposer::recompose_signed(
226+
decrypted_block_iter,
227+
bits_in_block,
228+
))
218229
}
219230
}
220231

tfhe/src/integer/client_key/mod.rs

+15-71
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ pub(crate) mod secret_encryption_key;
99
pub(crate) mod utils;
1010

1111
use super::backward_compatibility::client_key::ClientKeyVersions;
12-
use super::block_decomposition::{DecomposableInto, RecomposableFrom};
1312
use super::ciphertext::{
1413
CompressedRadixCiphertext, CompressedSignedRadixCiphertext, RadixCiphertext,
1514
SignedRadixCiphertext,
1615
};
17-
use crate::core_crypto::prelude::{CastFrom, SignedNumeric, UnsignedNumeric};
18-
use crate::integer::bigint::static_signed::StaticSignedBigInt;
19-
use crate::integer::block_decomposition::BlockRecomposer;
16+
use crate::core_crypto::prelude::{SignedNumeric, UnsignedNumeric};
17+
use crate::integer::block_decomposition::{
18+
BlockRecomposer, DecomposableInto, RecomposableFrom, RecomposableSignedInteger,
19+
};
2020
use crate::integer::ciphertext::boolean_value::BooleanBlock;
2121
use crate::integer::ciphertext::{CompressedCrtCiphertext, CrtCiphertext};
2222
use crate::integer::client_key::utils::i_crt;
@@ -33,53 +33,6 @@ use secret_encryption_key::SecretEncryptionKeyView;
3333
use serde::{Deserialize, Serialize};
3434
use tfhe_versionable::Versionize;
3535

36-
pub trait RecomposableSignedInteger:
37-
RecomposableFrom<u64>
38-
+ std::ops::Neg<Output = Self>
39-
+ std::ops::Shr<u32, Output = Self>
40-
+ std::ops::BitOrAssign<Self>
41-
+ std::ops::BitOr<Self, Output = Self>
42-
+ std::ops::Mul<Self, Output = Self>
43-
+ CastFrom<Self>
44-
+ SignedNumeric
45-
{
46-
}
47-
48-
impl RecomposableSignedInteger for i8 {}
49-
impl RecomposableSignedInteger for i16 {}
50-
impl RecomposableSignedInteger for i32 {}
51-
impl RecomposableSignedInteger for i64 {}
52-
impl RecomposableSignedInteger for i128 {}
53-
54-
impl<const N: usize> RecomposableSignedInteger for StaticSignedBigInt<N> {}
55-
56-
/// This function takes a signed integer of type `T` for which `num_bits_set`
57-
/// have been set.
58-
///
59-
/// It will set the most significant bits to the value of the bit
60-
/// at pos `num_bits_set - 1`.
61-
///
62-
/// This is used to correctly decrypt a signed radix ciphertext into a clear type
63-
/// that has more bits than the original ciphertext.
64-
///
65-
/// This is like doing i8 as i16, i16 as i64, i6 as i8, etc
66-
pub(in crate::integer) fn sign_extend_partial_number<T>(unpadded_value: T, num_bits_set: u32) -> T
67-
where
68-
T: RecomposableSignedInteger,
69-
{
70-
if num_bits_set >= T::BITS as u32 {
71-
return unpadded_value;
72-
}
73-
let sign_bit_pos = num_bits_set - 1;
74-
let sign_bit = (unpadded_value >> sign_bit_pos) & T::ONE;
75-
76-
// Creates a padding mask
77-
// where bits above num_bits_set
78-
// are 1s if sign bit is `1` else `0`
79-
let padding = (T::MAX * sign_bit) << num_bits_set;
80-
padding | unpadded_value
81-
}
82-
8336
/// A structure containing the client key, which must be kept secret.
8437
///
8538
/// This key can be used to encrypt both in Radix and CRT
@@ -381,18 +334,8 @@ impl ClientKey {
381334
}
382335

383336
let bits_in_block = self.key.parameters.message_modulus().0.ilog2();
384-
let mut recomposer = BlockRecomposer::<T>::new(bits_in_block);
385-
386-
for encrypted_block in blocks {
387-
let decrypted_block = decrypt_block(&self.key, encrypted_block);
388-
if !recomposer.add_unmasked(decrypted_block) {
389-
// End of T::BITS reached no need to try more
390-
// recomposition
391-
break;
392-
}
393-
}
394-
395-
recomposer.value()
337+
let decrypted_block_iter = blocks.iter().map(|block| decrypt_block(&self.key, block));
338+
BlockRecomposer::recompose_unsigned(decrypted_block_iter, bits_in_block)
396339
}
397340

398341
pub fn encrypt_signed_radix<T>(&self, message: T, num_blocks: usize) -> SignedRadixCiphertext
@@ -470,15 +413,16 @@ impl ClientKey {
470413
let message_modulus = self.parameters().message_modulus().0;
471414
assert!(message_modulus.is_power_of_two());
472415

473-
// Decrypting a signed value is the same as decrypting an unsigned value
474-
// but, in the signed case,
475-
// we have to take care of the case when the clear type T has more bits
476-
// than what the ciphertext encrypts.
477-
let unpadded_value = self.decrypt_radix_impl(&ctxt.blocks, decrypt_block);
416+
if ctxt.blocks.is_empty() {
417+
return T::ZERO;
418+
}
478419

479-
let num_bits_in_message = message_modulus.ilog2();
480-
let num_bits_in_ctxt = num_bits_in_message * ctxt.blocks.len() as u32;
481-
sign_extend_partial_number(unpadded_value, num_bits_in_ctxt)
420+
let bits_in_block = self.key.parameters.message_modulus().0.ilog2();
421+
let decrypted_block_iter = ctxt
422+
.blocks
423+
.iter()
424+
.map(|block| decrypt_block(&self.key, block));
425+
BlockRecomposer::recompose_signed(decrypted_block_iter, bits_in_block)
482426
}
483427

484428
/// Encrypts one block.

0 commit comments

Comments
 (0)