Skip to content

Commit

Permalink
Merge pull request #591 from taikiy/query_config
Browse files Browse the repository at this point in the history
Reduce the number of parameters to the ipa protocol
  • Loading branch information
akoshelev authored Apr 12, 2023
2 parents 672f6cd + 1744cda commit cc48eba
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 118 deletions.
13 changes: 9 additions & 4 deletions benches/oneshot/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use clap::Parser;
use ipa::{
error::Error,
ff::Fp32BitPrime,
helpers::GatewayConfig,
helpers::{query::IpaQueryConfig, GatewayConfig},
test_fixture::{
ipa::{
generate_random_user_records_in_reverse_chronological_order, test_ipa,
Expand Down Expand Up @@ -59,6 +59,8 @@ impl Args {
async fn main() -> Result<(), Error> {
type BenchField = Fp32BitPrime;

const NUM_MULTI_BITS: u32 = 3;

let args = Args::parse();

let prep_time = Instant::now();
Expand Down Expand Up @@ -106,9 +108,12 @@ async fn main() -> Result<(), Error> {
&world,
&raw_data,
&expected_results,
args.per_user_cap,
args.breakdown_keys,
args.attribution_window,
IpaQueryConfig::new(
args.per_user_cap,
args.breakdown_keys,
args.attribution_window,
NUM_MULTI_BITS,
),
IpaSecurityModel::Malicious,
)
.await;
Expand Down
19 changes: 18 additions & 1 deletion src/helpers/transport/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,31 @@ pub struct IpaQueryConfig {
impl Default for IpaQueryConfig {
fn default() -> Self {
Self {
per_user_credit_cap: 1,
per_user_credit_cap: 3,
max_breakdown_key: 64,
attribution_window_seconds: 0,
num_multi_bits: 3,
}
}
}

impl IpaQueryConfig {
#[must_use]
pub fn new(
per_user_credit_cap: u32,
max_breakdown_key: u32,
attribution_window_seconds: u32,
num_multi_bits: u32,
) -> Self {
Self {
per_user_credit_cap,
max_breakdown_key,
attribution_window_seconds,
num_multi_bits,
}
}
}

impl From<IpaQueryConfig> for QueryType {
fn from(value: IpaQueryConfig) -> Self {
QueryType::Ipa(value)
Expand Down
15 changes: 6 additions & 9 deletions src/protocol/attribution/malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use super::{
use crate::{
error::Error,
ff::{GaloisField, Gf2, PrimeField, Serializable},
helpers::query::IpaQueryConfig,
protocol::{
context::{Context, SemiHonestContext},
ipa::IPAModulusConvertedInputRow,
Expand All @@ -26,17 +27,13 @@ use std::iter::zip;
///
/// # Errors
/// propagates errors from multiplications
#[allow(clippy::too_many_arguments)]
pub async fn secure_attribution<'a, F, BK>(
sh_ctx: SemiHonestContext<'a>,
malicious_validator: MaliciousValidator<'a, F>,
binary_malicious_validator: MaliciousValidator<'a, Gf2>,
sorted_match_keys: Vec<Vec<AdditiveShare<Gf2>>>,
sorted_rows: Vec<IPAModulusConvertedInputRow<F, AdditiveShare<F>>>,
per_user_credit_cap: u32,
max_breakdown_key: u32,
_attribution_window_seconds: u32, // TODO(taikiy): compute the output with the attribution window
num_multi_bits: u32,
config: IpaQueryConfig,
) -> Result<Vec<MCAggregateCreditOutputRow<F, SemiHonestAdditiveShare<F>, BK>>, Error>
where
F: PrimeField + ExtendableField,
Expand Down Expand Up @@ -69,23 +66,23 @@ where
let accumulated_credits = accumulate_credit(
m_ctx.narrow(&Step::AccumulateCredit),
&attribution_input_rows,
per_user_credit_cap,
config.per_user_credit_cap,
)
.await?;

let user_capped_credits = credit_capping(
m_ctx.narrow(&Step::PerformUserCapping),
&accumulated_credits,
per_user_credit_cap,
config.per_user_credit_cap,
)
.await?;

let (malicious_validator, output) = malicious_aggregate_credit::<F, BK>(
malicious_validator,
sh_ctx,
user_capped_credits.into_iter(),
max_breakdown_key,
num_multi_bits,
config.max_breakdown_key,
config.num_multi_bits,
)
.await?;

Expand Down
14 changes: 6 additions & 8 deletions src/protocol/attribution/semi_honest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use super::{
use crate::{
error::Error,
ff::{GaloisField, Gf2, PrimeField, Serializable},
helpers::query::IpaQueryConfig,
protocol::{
context::{Context, SemiHonestContext},
ipa::IPAModulusConvertedInputRow,
Expand All @@ -26,10 +27,7 @@ pub async fn secure_attribution<F, BK>(
ctx: SemiHonestContext<'_>,
sorted_match_keys: Vec<Vec<AdditiveShare<Gf2>>>,
sorted_rows: Vec<IPAModulusConvertedInputRow<F, AdditiveShare<F>>>,
per_user_credit_cap: u32,
max_breakdown_key: u32,
_attribution_window_seconds: u32, // TODO(taikiy): compute the output with the attribution window
num_multi_bits: u32,
config: IpaQueryConfig,
) -> Result<Vec<MCAggregateCreditOutputRow<F, AdditiveShare<F>, BK>>, Error>
where
F: PrimeField,
Expand All @@ -56,22 +54,22 @@ where
let accumulated_credits = accumulate_credit(
ctx.narrow(&Step::AccumulateCredit),
&attribution_input_rows,
per_user_credit_cap,
config.per_user_credit_cap,
)
.await?;

let user_capped_credits = credit_capping(
ctx.narrow(&Step::PerformUserCapping),
&accumulated_credits,
per_user_credit_cap,
config.per_user_credit_cap,
)
.await?;

aggregate_credit::<F, BK>(
ctx.narrow(&Step::AggregateCredit),
user_capped_credits.into_iter(),
max_breakdown_key,
num_multi_bits,
config.max_breakdown_key,
config.num_multi_bits,
)
.await
}
Expand Down
104 changes: 51 additions & 53 deletions src/protocol/ipa/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
error::Error,
ff::{Field, GaloisField, Gf2, PrimeField, Serializable},
helpers::Role,
helpers::{query::IpaQueryConfig, Role},
protocol::{
attribution::{input::MCAggregateCreditOutputRow, malicious, semi_honest},
basics::Reshare,
Expand Down Expand Up @@ -292,10 +292,7 @@ where
pub async fn ipa<F, MK, BK>(
ctx: SemiHonestContext<'_>,
input_rows: &[IPAInputRow<F, MK, BK>],
per_user_credit_cap: u32,
max_breakdown_key: u32,
attribution_window_seconds: u32,
num_multi_bits: u32,
config: IpaQueryConfig,
) -> Result<Vec<MCAggregateCreditOutputRow<F, Replicated<F>, BK>>, Error>
where
F: PrimeField,
Expand Down Expand Up @@ -327,7 +324,7 @@ where
&ctx.narrow(&Step::ModulusConversionForMatchKeys),
&convert_all_bits_local::<F, MK>(ctx.role(), mk_shares.into_iter()),
MK::BITS,
num_multi_bits,
config.num_multi_bits,
)
.await
.unwrap();
Expand Down Expand Up @@ -370,16 +367,7 @@ where
.await
.unwrap();

semi_honest::secure_attribution(
ctx,
sorted_match_keys,
sorted_rows,
per_user_credit_cap,
max_breakdown_key,
attribution_window_seconds,
num_multi_bits,
)
.await
semi_honest::secure_attribution(ctx, sorted_match_keys, sorted_rows, config).await
}

/// Malicious IPA
Expand All @@ -392,10 +380,7 @@ where
pub async fn ipa_malicious<'a, F, MK, BK>(
sh_ctx: SemiHonestContext<'a>,
input_rows: &[IPAInputRow<F, MK, BK>],
per_user_credit_cap: u32,
max_breakdown_key: u32,
attribution_window_seconds: u32,
num_multi_bits: u32,
config: IpaQueryConfig,
) -> Result<Vec<MCAggregateCreditOutputRow<F, Replicated<F>, BK>>, Error>
where
F: PrimeField + ExtendableField,
Expand All @@ -419,7 +404,7 @@ where
.upgrade(convert_all_bits_local(m_ctx.role(), mk_shares.into_iter()))
.await?,
MK::BITS,
num_multi_bits,
config.num_multi_bits,
)
.await
.unwrap();
Expand Down Expand Up @@ -509,10 +494,7 @@ where
binary_validator,
sorted_match_keys,
sorted_rows,
per_user_credit_cap,
max_breakdown_key,
attribution_window_seconds,
num_multi_bits,
config,
)
.await
}
Expand Down Expand Up @@ -545,7 +527,7 @@ pub mod tests {
use super::{ipa, ipa_malicious, IPAInputRow};
use crate::{
ff::{Field, Fp31, Fp32BitPrime, GaloisField, Serializable},
helpers::GatewayConfig,
helpers::{query::IpaQueryConfig, GatewayConfig},
ipa_test_input,
protocol::{BreakdownKey, MatchKey},
secret_sharing::IntoShares,
Expand Down Expand Up @@ -608,10 +590,12 @@ pub mod tests {
ipa::<Fp31, MatchKey, BreakdownKey>(
ctx,
&input_rows,
PER_USER_CAP,
MAX_BREAKDOWN_KEY,
ATTRIBUTION_WINDOW_SECONDS,
NUM_MULTI_BITS,
IpaQueryConfig::new(
PER_USER_CAP,
MAX_BREAKDOWN_KEY,
ATTRIBUTION_WINDOW_SECONDS,
NUM_MULTI_BITS,
),
)
.await
.unwrap()
Expand Down Expand Up @@ -659,10 +643,12 @@ pub mod tests {
ipa_malicious::<_, MatchKey, BreakdownKey>(
ctx,
&input_rows,
PER_USER_CAP,
MAX_BREAKDOWN_KEY,
ATTRIBUTION_WINDOW_SECONDS,
NUM_MULTI_BITS,
IpaQueryConfig::new(
PER_USER_CAP,
MAX_BREAKDOWN_KEY,
ATTRIBUTION_WINDOW_SECONDS,
NUM_MULTI_BITS,
),
)
.await
.unwrap()
Expand Down Expand Up @@ -720,10 +706,12 @@ pub mod tests {
ipa::<Fp31, MatchKey, BreakdownKey>(
ctx,
&input_rows,
PER_USER_CAP,
MAX_BREAKDOWN_KEY,
ATTRIBUTION_WINDOW_SECONDS,
NUM_MULTI_BITS,
IpaQueryConfig::new(
PER_USER_CAP,
MAX_BREAKDOWN_KEY,
ATTRIBUTION_WINDOW_SECONDS,
NUM_MULTI_BITS,
),
)
.await
.unwrap()
Expand All @@ -748,10 +736,12 @@ pub mod tests {
ipa_malicious::<Fp31, MatchKey, BreakdownKey>(
ctx,
&input_rows,
PER_USER_CAP,
MAX_BREAKDOWN_KEY,
ATTRIBUTION_WINDOW_SECONDS,
NUM_MULTI_BITS,
IpaQueryConfig::new(
PER_USER_CAP,
MAX_BREAKDOWN_KEY,
ATTRIBUTION_WINDOW_SECONDS,
NUM_MULTI_BITS,
),
)
.await
.unwrap()
Expand Down Expand Up @@ -825,9 +815,12 @@ pub mod tests {
&world,
&raw_data,
&expected_results,
per_user_cap,
MAX_BREAKDOWN_KEY,
ATTRIBUTION_WINDOW_SECONDS,
IpaQueryConfig::new(
per_user_cap,
MAX_BREAKDOWN_KEY,
ATTRIBUTION_WINDOW_SECONDS,
NUM_MULTI_BITS,
),
IpaSecurityModel::SemiHonest,
)
.await;
Expand Down Expand Up @@ -882,6 +875,7 @@ pub mod tests {
/// It is possible to increase the number too if there is a good reason for it. This is a
/// "catch all" type of test to make sure we don't miss an accidental regression.
#[tokio::test]
#[allow(clippy::too_many_lines)]
pub async fn communication_baseline() {
const MAX_BREAKDOWN_KEY: u32 = 3;
const ATTRIBUTION_WINDOW_SECONDS: u32 = 0;
Expand Down Expand Up @@ -927,10 +921,12 @@ pub mod tests {
ipa::<Fp32BitPrime, MatchKey, BreakdownKey>(
ctx,
&input_rows,
per_user_cap,
MAX_BREAKDOWN_KEY,
ATTRIBUTION_WINDOW_SECONDS,
NUM_MULTI_BITS,
IpaQueryConfig::new(
per_user_cap,
MAX_BREAKDOWN_KEY,
ATTRIBUTION_WINDOW_SECONDS,
NUM_MULTI_BITS,
),
)
.await
.unwrap()
Expand Down Expand Up @@ -966,10 +962,12 @@ pub mod tests {
ipa_malicious::<Fp32BitPrime, MatchKey, BreakdownKey>(
ctx,
&input_rows,
per_user_cap,
MAX_BREAKDOWN_KEY,
ATTRIBUTION_WINDOW_SECONDS,
NUM_MULTI_BITS,
IpaQueryConfig::new(
per_user_cap,
MAX_BREAKDOWN_KEY,
ATTRIBUTION_WINDOW_SECONDS,
NUM_MULTI_BITS,
),
)
.await
.unwrap()
Expand Down
Loading

0 comments on commit cc48eba

Please sign in to comment.