Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions crates/fhe/src/bfv/ops/dot_product.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::cmp::min;

use fhe_math::rq::{Ntt, Poly, dot_product as poly_dot_product, traits::TryConvertFrom};
use itertools::{Itertools, izip};
use ndarray::{Array, Array2};
Expand Down Expand Up @@ -58,25 +56,20 @@ where
I: Iterator<Item = &'a Ciphertext> + Clone,
J: Iterator<Item = &'a Plaintext> + Clone,
{
let count = min(ct.clone().count(), pt.clone().count());
if count == 0 {
let inputs = izip!(ct, pt).collect_vec();
if inputs.is_empty() {
return Err(Error::DefaultError(
"At least one iterator is empty".to_string(),
));
}
let ct_first = ct.clone().next().unwrap();
let (ct_first, _) = inputs[0];
let ctx = ct_first[0].ctx();

if izip!(ct.clone(), pt.clone()).any(|(cti, pti)| {
if inputs.iter().any(|(cti, pti)| {
cti.par != ct_first.par || pti.par != ct_first.par || cti.len() != ct_first.len()
}) {
Comment on lines +68 to 70

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Validate all ciphertexts, not just zipped pairs

This validation now only inspects ciphertexts that have a corresponding plaintext in inputs. If the ciphertext iterator is longer than the plaintext iterator, any extra ciphertexts are ignored by izip! and will no longer be checked for parameter/part-length consistency. That is a regression from the previous behavior (which walked all ciphertexts) and violates the function’s own contract (“ciphertexts have different number of parts”) for those trailing ciphertexts. The bug manifests when callers accidentally pass mismatched iterators (e.g., extra ciphertexts), because the function will silently accept ciphertexts with different part counts/params instead of erroring.

Useful? React with 👍 / 👎.

return Err(Error::DefaultError("Mismatched parameters".to_string()));
}
if ct.clone().any(|cti| cti.len() != ct_first.len()) {
return Err(Error::DefaultError(
"Mismatched number of parts in the ciphertexts".to_string(),
));
}

let max_acc = ctx
.moduli()
Expand All @@ -85,14 +78,16 @@ where
.collect_vec();
let min_of_max = max_acc.iter().min().unwrap();

if count as u128 > *min_of_max {
if inputs.len() as u128 > *min_of_max {
// Too many ciphertexts for the optimized method, instead, we call
// `poly_dot_product`.
let c = (0..ct_first.len())
.map(|i| {
poly_dot_product(
ct.clone().map(|cti| unsafe { cti.get_unchecked(i) }),
pt.clone().map(|pti| &pti.poly_ntt),
inputs
.iter()
.map(|(cti, _)| unsafe { cti.get_unchecked(i) }),
inputs.iter().map(|(_, pti)| &pti.poly_ntt),
)
.map_err(Error::MathError)
})
Expand All @@ -106,7 +101,7 @@ where
})
} else {
let mut acc = Array::zeros((ct_first.len(), ctx.moduli().len(), ct_first.par.degree()));
for (ciphertext, plaintext) in izip!(ct, pt) {
for (ciphertext, plaintext) in inputs {
let pt_coefficients = plaintext.poly_ntt.coefficients();
for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.iter()) {
let ci_coefficients = ci.coefficients();
Expand Down
Loading