Skip to content

Commit

Permalink
weierstrass, edwards: use global WeakMap to store precomputes. Make p…
Browse files Browse the repository at this point in the history
…oints immutable.
  • Loading branch information
paulmillr committed Aug 4, 2024
1 parent f3ef397 commit 3a2e4a5
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 127 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ import { bls12_381 } from '@noble/curves/bls12-381';
```

See [abstract/bls](#bls-barreto-lynn-scott-curves).
For example usage, check out [the implementation of EVM precompiles](https://github.com/ethereumjs/ethereumjs-monorepo/blob/361f4edbc239e795a411ac2da7e5567298b9e7e5/packages/evm/src/precompiles/bls12_381/noble.ts).

#### bn254 aka alt_bn128

Expand Down
74 changes: 35 additions & 39 deletions src/abstract/bls.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
// BLS (Barreto-Lynn-Scott) family of pairing-friendly curves.
// TODO: import { AffinePoint } from './curve.js';
import { IField, getMinHashLength, mapHashToField } from './modular.js';
import { Hex, PrivKey, CHash, ensureBytes } from './utils.js';
import { Hex, PrivKey, CHash, ensureBytes, cached } from './utils.js';
// prettier-ignore
import {
MapToCurve, Opts as HTFOpts, H2CPointConstructor, htfBasicOpts,
Expand Down Expand Up @@ -201,7 +203,7 @@ export function bls(CURVE: CurveType): CurveFn {
})
);
type G1 = typeof G1.ProjectivePoint.BASE;
type G2 = typeof G2.ProjectivePoint.BASE & { _PPRECOMPUTES?: Precompute };
type G2 = typeof G2.ProjectivePoint.BASE;

// Applies sparse multiplication as line function
let lineFunction: (c0: Fp2, c1: Fp2, c2: Fp2, f: Fp12, Px: Fp, Py: Fp) => Fp12;
Expand All @@ -216,7 +218,7 @@ export function bls(CURVE: CurveType): CurveFn {
} else throw new Error('bls: unknown twist type');

const Fp2div2 = Fp2.div(Fp2.ONE, Fp2.mul(Fp2.ONE, _2n));
const pointDouble = (ell: PrecomputeSingle, Rx: Fp2, Ry: Fp2, Rz: Fp2) => {
function pointDouble(ell: PrecomputeSingle, Rx: Fp2, Ry: Fp2, Rz: Fp2) {
const t0 = Fp2.sqr(Ry); // Ry²
const t1 = Fp2.sqr(Rz); // Rz²
const t2 = Fp2.mulByB(Fp2.mul(t1, _3n)); // 3 * T1 * B
Expand All @@ -232,14 +234,14 @@ export function bls(CURVE: CurveType): CurveFn {
Ry = Fp2.sub(Fp2.sqr(Fp2.mul(Fp2.add(t0, t3), Fp2div2)), Fp2.mul(Fp2.sqr(t2), _3n)); // ((T0 + T3) / 2)² - 3 * T2²
Rz = Fp2.mul(t0, t4); // T0 * T4
return { Rx, Ry, Rz };
};
}
function pointAdd(ell: PrecomputeSingle, Rx: Fp2, Ry: Fp2, Rz: Fp2, Qx: Fp2, Qy: Fp2) {
// Addition
const t0 = Fp2.sub(Ry, Fp2.mul(Qy, Rz)); // Ry - Qy * Rz
const t1 = Fp2.sub(Rx, Fp2.mul(Qx, Rz)); // Rx - Qx * Rz
const c0 = Fp2.sub(Fp2.mul(t0, Qx), Fp2.mul(t1, Qy)); // T0 * Qx - T1 * Qy (i)
const c1 = Fp2.neg(t0); // -T0
const c2 = t1;
const c0 = Fp2.sub(Fp2.mul(t0, Qx), Fp2.mul(t1, Qy)); // T0 * Qx - T1 * Qy == Ry * Qx - Rx * Qy
const c1 = Fp2.neg(t0); // -T0 == Qy * Rz - Ry
const c2 = t1; // == Rx - Qx * Rz

ell.push([c0, c1, c2]);

Expand All @@ -258,12 +260,9 @@ export function bls(CURVE: CurveType): CurveFn {
// pointAdd happens only if bit set, so wNAF is reasonable. Unfortunately we cannot combine
// add + double in windowed precomputes here, otherwise it would be single op (since X is static)
const ATE_NAF = NAfDecomposition(CURVE.params.ateLoopSize);
// TODO: Fp.sqr can-be re-used in batch
// TODO: we can combine two lineFunc multipl in one, but applying result can be slower
// For G1 we can only convert it into affine, no other precomputes possible
function calcPairingPrecomputes(point: G2) {

const calcPairingPrecomputes = cached((point: G2) => {
const p = point;
if (p._PPRECOMPUTES) return p._PPRECOMPUTES;
const { x, y } = p.toAffine();
// prettier-ignore
const Qx = x, Qy = y, negQy = Fp2.neg(y);
Expand All @@ -280,9 +279,9 @@ export function bls(CURVE: CurveType): CurveFn {
const last = ell[ell.length - 1];
CURVE.postPrecompute(Rx, Ry, Rz, Qx, Qy, pointAdd.bind(null, last));
}
p._PPRECOMPUTES = ell;
return ell;
}
});

// Main pairing logic is here. Computes product of miller loops + final exponentiate
// Applies calculated precomputes
type MillerInput = [Precompute, Fp, Fp][];
Expand All @@ -301,26 +300,15 @@ export function bls(CURVE: CurveType): CurveFn {
if (BLS_X_IS_NEGATIVE) f12 = Fp12.conjugate(f12);
return withFinalExponent ? Fp12.finalExponentiate(f12) : f12;
}
/*
TODO: revisit using precomputes and maybe remove them.
pairing x 84 ops/sec @ 11ms/op
pairing10 x 18 ops/sec @ 54ms/op (raw expected to be 110ms/op, so ~2x faster?)
pairing10 x 19 ops/sec @ 52ms/op with disabled precomputes
verifyBatch can be faster, but we don't bench it.
assertValidity/toAffine is very slow.
*/
type PairingInput = { g1: G1; g2: G2 };
// Calculates product of multiple pairings
// This up to x2 faster than just `map(({g1, g2})=>pairing({g1,g2}))`
function pairingBatch(pairs: PairingInput[], withFinalExponent: boolean = true) {
const res: MillerInput = [];
// Pairings use affine coordinates inside
const g1Norm = G1.ProjectivePoint.normalizeZ(pairs.map(({ g1 }) => g1));
const g2Norm = G2.ProjectivePoint.normalizeZ(pairs.map(({ g2 }) => g2));
for (let i = 0; i < g1Norm.length; i++) {
const g1 = g1Norm[i];
const g2 = g2Norm[i];
// This cache precomputed toAffine for all points
G1.ProjectivePoint.normalizeZ(pairs.map(({ g1 }) => g1));
G2.ProjectivePoint.normalizeZ(pairs.map(({ g2 }) => g2));
for (const { g1, g2 } of pairs) {
if (g1.equals(G1.ProjectivePoint.ZERO) || g2.equals(G2.ProjectivePoint.ZERO))
throw new Error('pairing is not available for ZERO point');
// This uses toAffine inside
Expand Down Expand Up @@ -494,6 +482,7 @@ export function bls(CURVE: CurveType): CurveFn {
// e(G, S) = e(G, SUM(n)(Si)) = MUL(n)(e(G, Si))
function verifyBatch(
signature: G2Hex,
// TODO: maybe `{message: G2Hex, publicKey: G1Hex}[]` instead?
messages: G2Hex[],
publicKeys: G1Hex[],
htfOpts?: htfBasicOpts
Expand All @@ -504,16 +493,23 @@ export function bls(CURVE: CurveType): CurveFn {
const sig = normP2(signature);
const nMessages = messages.map((i) => normP2Hash(i, htfOpts));
const nPublicKeys = publicKeys.map(normP1);
// NOTE: this works only for exact same object
const messagePubKeyMap = new Map<G2, G1[]>();
for (let i = 0; i < nPublicKeys.length; i++) {
const pub = nPublicKeys[i];
const msg = nMessages[i];
let keys = messagePubKeyMap.get(msg);
if (keys === undefined) {
keys = [];
messagePubKeyMap.set(msg, keys);
}
keys.push(pub);
}
const paired = [];
try {
const paired = [];
for (const message of new Set(nMessages)) {
// TODO: seems broken
// nMessages is set of objects -> same message is different object. '===' won't work here.
const groupPublicKey = nMessages.reduce(
(acc, subMessage, i) => (subMessage === message ? acc.add(nPublicKeys[i]) : acc),
G1.ProjectivePoint.ZERO
);
paired.push({ g1: groupPublicKey, g2: message });
for (const [msg, keys] of messagePubKeyMap) {
const groupPublicKey = keys.reduce((acc, msg) => acc.add(msg));
paired.push({ g1: groupPublicKey, g2: msg });
}
paired.push({ g1: G1.ProjectivePoint.BASE.negate(), g2: sig });
return Fp12.eql(pairingBatch(paired), Fp12.ONE);
Expand Down
30 changes: 23 additions & 7 deletions src/abstract/curve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ export type GroupConstructor<T> = {
};
export type Mapper<T> = (i: T[]) => T[];

// Since points in different groups cannot be equal (different object constructor),
// we can have single place to store precomputes
const pointPrecomputes = new WeakMap<any, any[]>();
const pointWindowSizes = new WeakMap<any, number>(); // This allows use make points immutable (nothing changes inside)

// Elliptic curve multiplication of Point by scalar. Fragile.
// Scalars should always be less than curve order: this should be checked inside of a curve itself.
// Creates precomputation tables for fast multiplication:
Expand All @@ -41,7 +46,12 @@ export function wNAF<T extends Group<T>>(c: GroupConstructor<T>, bits: number) {
const neg = item.negate();
return condition ? neg : item;
};
const validateW = (W: number) => {
if (!Number.isSafeInteger(W) || W <= 0 || W > bits)
throw new Error(`Wrong window size=${W}, should be [1..${bits}]`);
};
const opts = (W: number) => {
validateW(W);
const windows = Math.ceil(bits / W) + 1; // +1, because
const windowSize = 2 ** (W - 1); // -1 because we skip zero
return { windows, windowSize };
Expand Down Expand Up @@ -149,19 +159,25 @@ export function wNAF<T extends Group<T>>(c: GroupConstructor<T>, bits: number) {
return { p, f };
},

wNAFCached(P: T, precomputesMap: Map<T, T[]>, n: bigint, transform: Mapper<T>): { p: T; f: T } {
// @ts-ignore
const W: number = P._WINDOW_SIZE || 1;
wNAFCached(P: T, n: bigint, transform: Mapper<T>): { p: T; f: T } {
const W: number = pointWindowSizes.get(P) || 1;
// Calculate precomputes on a first run, reuse them after
let comp = precomputesMap.get(P);
let comp = pointPrecomputes.get(P);
if (!comp) {
comp = this.precomputeWindow(P, W) as T[];
if (W !== 1) {
precomputesMap.set(P, transform(comp));
}
if (W !== 1) pointPrecomputes.set(P, transform(comp));
}
return this.wNAF(W, comp, n);
},
// We calculate precomputes for elliptic curve point multiplication
// using windowed method. This specifies window size and
// stores precomputed values. Usually only base point would be precomputed.

setWindowSize(P: T, W: number) {
validateW(W);
pointWindowSizes.set(P, W);
pointPrecomputes.delete(P);
},
};
}

Expand Down
85 changes: 46 additions & 39 deletions src/abstract/edwards.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import { AffinePoint, BasicCurve, Group, GroupConstructor, validateBasic, wNAF } from './curve.js';
import { mod } from './modular.js';
import * as ut from './utils.js';
import { ensureBytes, FHash, Hex } from './utils.js';
import { ensureBytes, FHash, Hex, cached, abool } from './utils.js';

// Be friendly to bad ECMAScript parsers by not using bigint literals
// prettier-ignore
Expand Down Expand Up @@ -134,6 +134,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
const domain =
CURVE.domain ||
((data: Uint8Array, ctx: Uint8Array, phflag: boolean) => {
abool('phflag', phflag);
if (ctx.length || phflag) throw new Error('Contexts/pre-hash are not supported');
return data;
}); // NOOP
Expand All @@ -142,10 +143,43 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
function assertCoordinate(title: string, n: bigint) {
assertInRange('coordinate ' + title, n, _0n, MASK);
}
const pointPrecomputes = new Map<Point, Point[]>();

function assertPoint(other: unknown) {
if (!(other instanceof Point)) throw new Error('ExtendedPoint expected');
}
// Converts Extended point to default (x, y) coordinates.
// Can accept precomputed Z^-1 - for example, from invertBatch.
const cachedAffine = cached((p: Point, iz?: bigint): AffinePoint<bigint> => {
const { ex: x, ey: y, ez: z } = p;
const is0 = p.is0();
if (iz == null) iz = is0 ? _8n : (Fp.inv(z) as bigint); // 8 was chosen arbitrarily
const ax = modP(x * iz);
const ay = modP(y * iz);
const zz = modP(z * iz);
if (is0) return { x: _0n, y: _1n };
if (zz !== _1n) throw new Error('invZ was invalid');
return { x: ax, y: ay };
});
const cachedValidity = cached((p: Point) => {
const { a, d } = CURVE;
if (p.is0()) throw new Error('bad point: ZERO'); // TODO: optimize, with vars below?
// Equation in affine coordinates: ax² + y² = 1 + dx²y²
// Equation in projective coordinates (X/Z, Y/Z, Z): (aX² + Y²)Z² = Z⁴ + dX²Y²
const { ex: X, ey: Y, ez: Z, et: T } = p;
const X2 = modP(X * X); // X²
const Y2 = modP(Y * Y); // Y²
const Z2 = modP(Z * Z); // Z²
const Z4 = modP(Z2 * Z2); // Z⁴
const aX2 = modP(X2 * a); // aX²
const left = modP(Z2 * modP(aX2 + Y2)); // (aX² + Y²)Z²
const right = modP(Z4 + modP(d * modP(X2 * Y2))); // Z⁴ + dX²Y²
if (left !== right) throw new Error('bad point: equation left != right (1)');
// In Extended coordinates we also have T, which is x*y=T/Z: check X*Y == Z*T
const XY = modP(X * Y);
const ZT = modP(Z * T);
if (XY !== ZT) throw new Error('bad point: equation left != right (2)');
return true;
});
// Extended Point works in extended coordinates: (x, y, z, t) ∋ (x=x/z, y=y/z, t=xy).
// https://en.wikipedia.org/wiki/Twisted_Edwards_curve#Extended_coordinates
class Point implements ExtPointType {
Expand All @@ -162,6 +196,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
assertCoordinate('y', ey);
assertCoordinate('z', ez);
assertCoordinate('t', et);
Object.freeze(this);
}

get x(): bigint {
Expand All @@ -183,36 +218,14 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
return points.map((p, i) => p.toAffine(toInv[i])).map(Point.fromAffine);
}

// We calculate precomputes for elliptic curve point multiplication
// using windowed method. This specifies window size and
// stores precomputed values. Usually only base point would be precomputed.
_WINDOW_SIZE?: number;

// "Private method", don't use it directly
_setWindowSize(windowSize: number) {
this._WINDOW_SIZE = windowSize;
pointPrecomputes.delete(this);
wnaf.setWindowSize(this, windowSize);
}
// Not required for fromHex(), which always creates valid points.
// Could be useful for fromAffine().
assertValidity(): void {
const { a, d } = CURVE;
if (this.is0()) throw new Error('bad point: ZERO'); // TODO: optimize, with vars below?
// Equation in affine coordinates: ax² + y² = 1 + dx²y²
// Equation in projective coordinates (X/Z, Y/Z, Z): (aX² + Y²)Z² = Z⁴ + dX²Y²
const { ex: X, ey: Y, ez: Z, et: T } = this;
const X2 = modP(X * X); // X²
const Y2 = modP(Y * Y); // Y²
const Z2 = modP(Z * Z); // Z²
const Z4 = modP(Z2 * Z2); // Z⁴
const aX2 = modP(X2 * a); // aX²
const left = modP(Z2 * modP(aX2 + Y2)); // (aX² + Y²)Z²
const right = modP(Z4 + modP(d * modP(X2 * Y2))); // Z⁴ + dX²Y²
if (left !== right) throw new Error('bad point: equation left != right (1)');
// In Extended coordinates we also have T, which is x*y=T/Z: check X*Y == Z*T
const XY = modP(X * Y);
const ZT = modP(Z * T);
if (XY !== ZT) throw new Error('bad point: equation left != right (2)');
cachedValidity(this);
}

// Compare one point to another.
Expand All @@ -227,7 +240,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
return X1Z2 === X2Z1 && Y1Z2 === Y2Z1;
}

protected is0(): boolean {
is0(): boolean {
return this.equals(Point.ZERO);
}

Expand Down Expand Up @@ -307,12 +320,12 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
}

private wNAF(n: bigint): { p: Point; f: Point } {
return wnaf.wNAFCached(this, pointPrecomputes, n, Point.normalizeZ);
return wnaf.wNAFCached(this, n, Point.normalizeZ);
}

// Constant-time multiplication.
multiply(scalar: bigint): Point {
let n = scalar;
const n = scalar;
assertInRange('scalar', n, _1n, CURVE_ORDER);
const { p, f } = this.wNAF(n);
return Point.normalizeZ([p, f])[0];
Expand All @@ -323,7 +336,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
// an exposed private key e.g. sig verification.
// Does NOT allow scalars higher than CURVE.n.
multiplyUnsafe(scalar: bigint): Point {
let n = scalar;
const n = scalar;
assertInRange('scalar', n, _0n, CURVE_ORDER); // 0 <= scalar < l
if (n === _0n) return I;
if (this.equals(I) || n === _1n) return this;
Expand All @@ -348,15 +361,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
// Converts Extended point to default (x, y) coordinates.
// Can accept precomputed Z^-1 - for example, from invertBatch.
toAffine(iz?: bigint): AffinePoint<bigint> {
const { ex: x, ey: y, ez: z } = this;
const is0 = this.is0();
if (iz == null) iz = is0 ? _8n : (Fp.inv(z) as bigint); // 8 was chosen arbitrarily
const ax = modP(x * iz);
const ay = modP(y * iz);
const zz = modP(z * iz);
if (is0) return { x: _0n, y: _1n };
if (zz !== _1n) throw new Error('invZ was invalid');
return { x: ax, y: ay };
return cachedAffine(this, iz);
}

clearCofactor(): Point {
Expand All @@ -371,6 +376,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
const { d, a } = CURVE;
const len = Fp.BYTES;
hex = ensureBytes('pointHex', hex, len); // copy hex to a new array
abool('zip215', zip215);
const normed = hex.slice(); // copy again, we'll manipulate it
const lastByte = hex[len - 1]; // select last byte
normed[len - 1] = lastByte & ~0x80; // clear last bit
Expand Down Expand Up @@ -467,6 +473,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
const len = Fp.BYTES; // Verifies EdDSA signature against message and public key. RFC8032 5.1.7.
sig = ensureBytes('signature', sig, 2 * len); // An extended group equation is checked.
msg = ensureBytes('message', msg);
if (zip215 !== undefined) abool('zip215', zip215);
if (prehash) msg = prehash(msg); // for ed25519ph, etc

const s = ut.bytesToNumberLE(sig.slice(len, 2 * len));
Expand Down
3 changes: 3 additions & 0 deletions src/abstract/modular.ts
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ type FpField = IField<bigint> & Required<Pick<IField<bigint>, 'isOdd'>>;
* * a) denormalized operations like mulN instead of mul
* * b) same object shape: never add or remove keys
* * c) Object.freeze
* NOTE: operations don't check 'isValid' for all elements for performance reasons,
* it is caller responsibility to check this.
* This is low-level code, please make sure you know what you doing.
* @param ORDER prime positive bigint
* @param bitLen how many bits the field consumes
* @param isLE (def: false) if encoding / decoding should be in little-endian
Expand Down
Loading

0 comments on commit 3a2e4a5

Please sign in to comment.