diff --git a/crates/fhe/src/bfv/ops/dot_product.rs b/crates/fhe/src/bfv/ops/dot_product.rs index f51aed13..63c2f8a3 100644 --- a/crates/fhe/src/bfv/ops/dot_product.rs +++ b/crates/fhe/src/bfv/ops/dot_product.rs @@ -58,25 +58,24 @@ where I: Iterator + Clone, J: Iterator + Clone, { - let count = min(ct.clone().count(), pt.clone().count()); - if count == 0 { + // Collect the zipped iterators to avoid multiple traversals and to stop at the + // length of the shorter iterator, preventing O(N) memory usage if one iterator is huge. + let pairs: Vec<(&Ciphertext, &Plaintext)> = izip!(ct, pt).collect(); + + if pairs.is_empty() { return Err(Error::DefaultError( "At least one iterator is empty".to_string(), )); } - let ct_first = ct.clone().next().unwrap(); + let count = pairs.len(); + let ct_first = pairs[0].0; let ctx = ct_first[0].ctx(); - if izip!(ct.clone(), pt.clone()).any(|(cti, pti)| { + if pairs.iter().any(|(cti, pti)| { cti.par != ct_first.par || pti.par != ct_first.par || cti.len() != ct_first.len() }) { return Err(Error::DefaultError("Mismatched parameters".to_string())); } - if ct.clone().any(|cti| cti.len() != ct_first.len()) { - return Err(Error::DefaultError( - "Mismatched number of parts in the ciphertexts".to_string(), - )); - } let max_acc = ctx .moduli() @@ -91,8 +90,8 @@ where let c = (0..ct_first.len()) .map(|i| { poly_dot_product( - ct.clone().map(|cti| unsafe { cti.get_unchecked(i) }), - pt.clone().map(|pti| &pti.poly_ntt), + pairs.iter().map(|(cti, _)| unsafe { cti.get_unchecked(i) }), + pairs.iter().map(|(_, pti)| &pti.poly_ntt), ) .map_err(Error::MathError) }) @@ -106,7 +105,7 @@ where }) } else { let mut acc = Array::zeros((ct_first.len(), ctx.moduli().len(), ct_first.par.degree())); - for (ciphertext, plaintext) in izip!(ct, pt) { + for (ciphertext, plaintext) in pairs { let pt_coefficients = plaintext.poly_ntt.coefficients(); for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.iter()) { let ci_coefficients = ci.coefficients();