From 1dc8c46a402664ea1d3c0faa2da984ce6ea1a097 Mon Sep 17 00:00:00 2001 From: Yi-Hsiu Chen Date: Fri, 20 Feb 2026 09:07:12 -0800 Subject: [PATCH 1/3] build: remove go in .github/workflow --- .github/workflows/ci.yml | 187 ++++++++++++++++++++++++++++------- .github/workflows/codeql.yml | 7 +- 2 files changed, 155 insertions(+), 39 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ea64820e..738acc69 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,63 +14,180 @@ on: - "dev/**" - "docs/**" - ".gitignore" + workflow_dispatch: + inputs: + heavy: + description: "Run heavier checks (demos)" + type: boolean + default: false + schedule: + # Weekly “deep” checks to keep PR CI fast. + - cron: "0 7 * * 1" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + permissions: contents: read - jobs: - tests: + changes: + name: Detect changed paths + runs-on: ubuntu-latest + timeout-minutes: 10 + outputs: + install_smoke: ${{ steps.detect.outputs.install_smoke }} + demos: ${{ steps.detect.outputs.demos }} + steps: + - name: Harden the runner (Audit all outbound calls) + uses: step-security/harden-runner@0634a2670c59f64b4a01f0f96f84700a4088b9f0 # v2.12.0 + with: + egress-policy: audit + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + # Ensure we can diff reliably for push events. + fetch-depth: 0 + - id: detect + name: Classify changed paths + shell: bash + env: + BEFORE_SHA: ${{ github.event.before }} + run: |- + set -euo pipefail + + install_smoke=false + demos=false + + changed_files="" + if [ "${GITHUB_EVENT_NAME}" = "pull_request" ]; then + # `actions/checkout` checks out a merge commit by default for PRs. + # Its parents are: base (HEAD^1) and head (HEAD^2). + if git rev-parse -q --verify HEAD^2 >/dev/null 2>&1; then + base_sha="$(git rev-parse HEAD^1)" + head_sha="$(git rev-parse HEAD^2)" + changed_files="$(git diff --name-only "${base_sha}" "${head_sha}" || true)" + else + # Fallback: if the checkout isn't a merge commit, diff against the base branch ref. + git fetch --no-tags --prune --depth=1 origin "${GITHUB_BASE_REF}" + changed_files="$(git diff --name-only "origin/${GITHUB_BASE_REF}" HEAD || true)" + fi + elif [ "${GITHUB_EVENT_NAME}" = "push" ]; then + # For pushes, compare the previous and current commit SHAs. + if [ -n "${BEFORE_SHA:-}" ] && [ "${BEFORE_SHA}" != "0000000000000000000000000000000000000000" ]; then + changed_files="$(git diff --name-only "${BEFORE_SHA}" "${GITHUB_SHA}" || true)" + else + # Initial commit / force-push edge case: fall back to listing files in the commit. + changed_files="$(git show --name-only --pretty='' "${GITHUB_SHA}" || true)" + fi + else + # schedule / workflow_dispatch: no diff-based gating (heavier jobs are enabled explicitly). + changed_files="" + fi + + echo "Changed files:" + printf '%s\n' "${changed_files}" + + while IFS= read -r f; do + [ -z "${f}" ] && continue + + # “Installed library smoke” should run when we change the public API surface or the + # build/install plumbing that consumers rely on. + case "${f}" in + include/cbmpc/api/*|src/cbmpc/api/*|include/cbmpc/capi/*|src/cbmpc/capi/*|scripts/install.sh|CMakeLists.txt|cmake/*|Makefile|Dockerfile|scripts/openssl/*) + install_smoke=true + ;; + esac + + # “All demos” should run when demo code or demo plumbing changes. + case "${f}" in + demo-cpp/*|demo-api/*|scripts/run-demos.sh) + demos=true + ;; + esac + done <<< "${changed_files}" + + echo "install_smoke=${install_smoke}" >> "${GITHUB_OUTPUT}" + echo "demos=${demos}" >> "${GITHUB_OUTPUT}" + + core: + name: Core (lint + unit + sanitizers) runs-on: ubuntu-latest + timeout-minutes: 60 steps: - name: Harden the runner (Audit all outbound calls) uses: step-security/harden-runner@0634a2670c59f64b4a01f0f96f84700a4088b9f0 # v2.12.0 with: egress-policy: audit + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Build Docker image + run: |- + set -euo pipefail + make submodules + make image + - name: Lint + unit tests (Debug) + run: |- + set -euo pipefail + docker run --rm -v $(pwd):/code -t cb-mpc bash -c 'make lint && make test BUILD_TYPE=Debug TEST_LABEL=unit' + - name: Sanitizers (ASAN+UBSAN) + run: |- + set -euo pipefail + docker run --rm -v $(pwd):/code -t cb-mpc bash -c 'make sanitize' + # Heavier checks run only on schedule or manual dispatch to keep PR CI fast. + demos: + name: Heavy (all demos) + needs: changes + if: | + github.event_name == 'schedule' || + (github.event_name == 'workflow_dispatch' && inputs.heavy == 'true') || + ((github.event_name == 'pull_request' || github.event_name == 'push') && needs.changes.outputs.demos == 'true') + runs-on: ubuntu-latest + timeout-minutes: 60 + steps: + - name: Harden the runner (Audit all outbound calls) + uses: step-security/harden-runner@0634a2670c59f64b4a01f0f96f84700a4088b9f0 # v2.12.0 + with: + egress-policy: audit - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - run: |- - echo "Run tests" + set -euo pipefail make submodules make image - docker run --rm -v $(pwd):/code -t cb-mpc bash -c 'make lint && make full-test' - lib-tests: + docker run --rm -v $(pwd):/code -t cb-mpc bash -c ' + set -euo pipefail + make install-all BUILD_TYPE=Release + BUILD_TYPE=Release bash scripts/run-demos.sh --run basic_primitive + BUILD_TYPE=Release bash scripts/run-demos.sh --run zk + BUILD_TYPE=Release bash scripts/run-demos.sh --run parallel_transport + BUILD_TYPE=Release bash scripts/run-demos.sh --run-api pve + BUILD_TYPE=Release bash scripts/run-demos.sh --run-api hd_keyset_ecdsa_2p + BUILD_TYPE=Release bash scripts/run-demos.sh --run-api ecdsa_mp_pve_backup + BUILD_TYPE=Release bash scripts/run-demos.sh --run-api schnorr_2p_pve_batch_backup + ' + + install-demo-smoke: + name: Heavy (install + demo smoke) + needs: changes + if: | + (github.event_name == 'pull_request' || github.event_name == 'push') && + needs.changes.outputs.install_smoke == 'true' && + needs.changes.outputs.demos != 'true' runs-on: ubuntu-latest + timeout-minutes: 60 steps: - name: Harden the runner (Audit all outbound calls) uses: step-security/harden-runner@0634a2670c59f64b4a01f0f96f84700a4088b9f0 # v2.12.0 with: egress-policy: audit - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - run: |- - echo "Test on installed library" + set -euo pipefail make submodules make image - docker run --rm -v "$(pwd)":/code -t cb-mpc bash -lc ' + # Smoke-test “installed library” flow without running every demo. + docker run --rm -v $(pwd):/code -t cb-mpc bash -c ' set -euo pipefail - cd /code - # Never re-touch submodules inside the container; avoids git safe.directory pain. - # If you insist on touching git here, you MUST first: git config --global --add safe.directory /code - if ! command -v go >/dev/null 2>&1; then - apt-get update - apt-get install -y --no-install-recommends ca-certificates curl git - GO_VERSION="${GO_VERSION:-1.24.6}" - GO_FILENAME="go${GO_VERSION}.linux-amd64.tar.gz" - GO_SHA256="bbca37cc395c974ffa4893ee35819ad23ebb27426df87af92e93a9ec66ef8712" - curl -fsSLo "/tmp/${GO_FILENAME}" "https://go.dev/dl/${GO_FILENAME}" - printf '"'"'%s %s\n'"'"' "$GO_SHA256" "/tmp/$GO_FILENAME" | sha256sum --check --status - - tar -C /usr/local -xzf "/tmp/${GO_FILENAME}" - rm -f "/tmp/${GO_FILENAME}" - ln -sf /usr/local/go/bin/* /usr/local/bin/ - fi - export PATH="/usr/local/go/bin:$PATH" - export GOTOOLCHAIN=auto - go version - export CGO_ENABLED=1 - make build - make install - make benchmark-build - # Run demos (Go + C++) with cgo on - BUILD_TYPE=${BUILD_TYPE:-Release} bash scripts/run-demos.sh --run-all - make test-go - make test-go-race - ' \ No newline at end of file + make install-all BUILD_TYPE=Release + BUILD_TYPE=Release bash scripts/run-demos.sh --run basic_primitive + BUILD_TYPE=Release bash scripts/run-demos.sh --run-api pve + ' diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index ae9dfb92..fea229f9 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -12,7 +12,6 @@ on: - '**/*.h' - '**/*.hpp' - '**/*.c' - - '**/*.go' - '**/*.py' permissions: @@ -30,7 +29,7 @@ jobs: strategy: fail-fast: false matrix: - language: [ 'cpp', 'go', 'python' ] + language: [ 'cpp', 'python' ] steps: - name: Harden the runner (Audit all outbound calls) @@ -52,9 +51,9 @@ jobs: - name: Install language dependencies run: | sudo apt-get update - sudo apt-get install -y cmake golang-go + sudo apt-get install -y cmake - - if: matrix.language == 'go' || matrix.language == 'python' + - if: matrix.language == 'python' name: Autobuild uses: github/codeql-action/autobuild@fca7ace96b7d713c7035871441bd52efbe39e27e # v3.28.19 - if: matrix.language == 'cpp' From ae9f39114c17efce6ff5248ba71ef971a5b4954c Mon Sep 17 00:00:00 2001 From: Yi-Hsiu Chen Date: Fri, 20 Feb 2026 09:33:42 -0800 Subject: [PATCH 2/3] feat: introduce C++ and C API layers with header reorganization Add a structured public API surface for the library: - New `cbmpc/api/` C++ API layer with clean, high-level interfaces for ECDSA 2p/mp, EdDSA 2p/mp, Schnorr 2p/mp, PVE (base, batch, AC variants), HD keyset, and TDH2 - New `cbmpc/c_api/` C API layer as a language-agnostic FFI boundary, replacing the previous Go-specific FFI module (`cbmpc/ffi/`) - Reorganize headers: public headers under `include/cbmpc/`, internal headers under `include-internal/cbmpc/internal/` - Remove `demos-go/` (Go bindings now delegate to the C API instead) - Add `demo-api/` and `demo-cpp/` with expanded demos (ECDSA/Schnorr PVE with backup, HD keyset, parallel transport) - Add comprehensive unit tests for both the C++ API and C API layers Co-Authored-By: Claude --- .clang-format | 18 +- .gitignore | 1 + BUG_BOUNTY.md | 26 + CMakeLists.txt | 42 +- Dockerfile | 7 +- Makefile | 123 +- README.md | 79 +- SECURE_USAGE.md | 236 ++++ cmake/compilation_flags.cmake | 28 +- cmake/openssl.cmake | 4 +- demo-api/ecdsa_mp_pve_backup/CMakeLists.txt | 44 + demo-api/ecdsa_mp_pve_backup/main.cpp | 458 +++++++ demo-api/hd_keyset_ecdsa_2p/CMakeLists.txt | 43 + demo-api/hd_keyset_ecdsa_2p/main.cpp | 345 +++++ demo-api/pve/CMakeLists.txt | 45 + demo-api/pve/main.cpp | 747 +++++++++++ .../CMakeLists.txt | 44 + demo-api/schnorr_2p_pve_batch_backup/main.cpp | 266 ++++ demo-cpp/basic_primitive/CMakeLists.txt | 47 + .../basic_primitive/main.cpp | 11 +- .../common}/mpc_job_session.cpp | 6 +- .../common}/mpc_job_session.h | 44 +- demo-cpp/parallel_transport/CMakeLists.txt | 56 + demo-cpp/parallel_transport/main.cpp | 231 ++++ demo-cpp/zk/CMakeLists.txt | 47 + {demos-cpp => demo-cpp}/zk/demo_nizk.h | 8 +- {demos-cpp => demo-cpp}/zk/main.cpp | 5 +- demos-cpp/basic_primitive/CMakeLists.txt | 26 - demos-cpp/zk/CMakeLists.txt | 26 - demos-go/cb-mpc-go/.gitignore | 1 - demos-go/cb-mpc-go/api/curve/curve.go | 232 ---- demos-go/cb-mpc-go/api/curve/curve_test.go | 198 --- demos-go/cb-mpc-go/api/curve/doc.go | 41 - demos-go/cb-mpc-go/api/curve/point.go | 92 -- demos-go/cb-mpc-go/api/curve/scalar.go | 60 - demos-go/cb-mpc-go/api/curve/scalar_test.go | 291 ----- .../cb-mpc-go/api/mpc/access_structure.go | 266 ---- demos-go/cb-mpc-go/api/mpc/agree_random.go | 35 - .../cb-mpc-go/api/mpc/agreer_andom_test.go | 208 --- demos-go/cb-mpc-go/api/mpc/doc.go | 83 -- demos-go/cb-mpc-go/api/mpc/ecdsa_2p.go | 203 --- demos-go/cb-mpc-go/api/mpc/ecdsa_2p_test.go | 508 ------- demos-go/cb-mpc-go/api/mpc/ecdsa_mp.go | 408 ------ demos-go/cb-mpc-go/api/mpc/ecdsa_mp_test.go | 348 ----- .../api/mpc/ecdsa_mp_threshold_test.go | 372 ------ .../api/mpc/ecdsa_mp_validation_test.go | 368 ------ demos-go/cb-mpc-go/api/mpc/eddsa_mp.go | 305 ----- demos-go/cb-mpc-go/api/mpc/eddsa_mp_test.go | 215 --- .../api/mpc/eddsa_mp_threshold_test.go | 355 ----- demos-go/cb-mpc-go/api/mpc/job.go | 57 - demos-go/cb-mpc-go/api/mpc/job_mp.go | 38 - demos-go/cb-mpc-go/api/mpc/pve.go | 176 --- demos-go/cb-mpc-go/api/mpc/pve_ac.go | 257 ---- demos-go/cb-mpc-go/api/mpc/pve_ac_test.go | 220 ---- demos-go/cb-mpc-go/api/mpc/pve_instance.go | 43 - demos-go/cb-mpc-go/api/mpc/pve_test.go | 773 ----------- demos-go/cb-mpc-go/api/mpc/secure.go | 23 - demos-go/cb-mpc-go/api/transport/doc.go | 24 - demos-go/cb-mpc-go/api/transport/messenger.go | 18 - .../cb-mpc-go/api/transport/mocknet/doc.go | 16 - .../cb-mpc-go/api/transport/mocknet/runner.go | 169 --- .../api/transport/mocknet/runner_test.go | 69 - .../api/transport/mocknet/transport.go | 114 -- .../cb-mpc-go/api/transport/mtls/transport.go | 321 ----- demos-go/cb-mpc-go/api/zk/doc.go | 41 - demos-go/cb-mpc-go/api/zk/zkdl.go | 112 -- demos-go/cb-mpc-go/api/zk/zkdl_test.go | 133 -- demos-go/cb-mpc-go/go.mod | 16 - demos-go/cb-mpc-go/go.sum | 12 - demos-go/cb-mpc-go/internal/cgobinding/ac.cpp | 57 - demos-go/cb-mpc-go/internal/cgobinding/ac.go | 48 - demos-go/cb-mpc-go/internal/cgobinding/ac.h | 38 - .../internal/cgobinding/agree_random.cpp | 46 - .../internal/cgobinding/agree_random.h | 18 - .../internal/cgobinding/agreerandom.go | 23 - .../cb-mpc-go/internal/cgobinding/cblib.h | 34 - .../cb-mpc-go/internal/cgobinding/cmem.go | 171 --- .../cb-mpc-go/internal/cgobinding/curve.cpp | 177 --- .../cb-mpc-go/internal/cgobinding/curve.go | 164 --- .../cb-mpc-go/internal/cgobinding/curve.h | 57 - .../cb-mpc-go/internal/cgobinding/ecdsa2p.cpp | 113 -- .../cb-mpc-go/internal/cgobinding/ecdsa2p.go | 88 -- .../cb-mpc-go/internal/cgobinding/ecdsa2p.h | 50 - .../cb-mpc-go/internal/cgobinding/ecdsamp.cpp | 58 - .../cb-mpc-go/internal/cgobinding/ecdsamp.go | 77 -- .../cb-mpc-go/internal/cgobinding/ecdsamp.h | 18 - .../cb-mpc-go/internal/cgobinding/eckeymp.cpp | 198 --- .../cb-mpc-go/internal/cgobinding/eckeymp.go | 245 ---- .../cb-mpc-go/internal/cgobinding/eckeymp.h | 51 - .../cb-mpc-go/internal/cgobinding/eddsamp.cpp | 32 - .../cb-mpc-go/internal/cgobinding/eddsamp.go | 25 - .../cb-mpc-go/internal/cgobinding/eddsamp.h | 15 - demos-go/cb-mpc-go/internal/cgobinding/kem.h | 24 - .../cb-mpc-go/internal/cgobinding/network.cpp | 342 ----- .../cb-mpc-go/internal/cgobinding/network.go | 496 ------- .../cb-mpc-go/internal/cgobinding/network.h | 73 -- .../cb-mpc-go/internal/cgobinding/pve.cpp | 382 ------ demos-go/cb-mpc-go/internal/cgobinding/pve.go | 317 ----- demos-go/cb-mpc-go/internal/cgobinding/pve.h | 52 - demos-go/cb-mpc-go/internal/cgobinding/zk.cpp | 33 - demos-go/cb-mpc-go/internal/cgobinding/zk.go | 44 - demos-go/cb-mpc-go/internal/cgobinding/zk.h | 18 - demos-go/cb-mpc-go/internal/curvemap/nid.go | 29 - .../cb-mpc-go/internal/testutil/testutil.go | 36 - demos-go/cmd/threshold-ecdsa-web/.gitignore | 10 - demos-go/cmd/threshold-ecdsa-web/Makefile | 72 - demos-go/cmd/threshold-ecdsa-web/README.md | 31 - .../certs/party-0/openssl-0.cnf | 23 - .../certs/party-1/openssl-1.cnf | 23 - .../certs/party-2/openssl-2.cnf | 23 - .../certs/party-3/openssl-3.cnf | 23 - .../cmd/threshold-ecdsa-web/config-0.yaml | 18 - .../cmd/threshold-ecdsa-web/config-1.yaml | 18 - .../cmd/threshold-ecdsa-web/config-2.yaml | 18 - .../cmd/threshold-ecdsa-web/config-3.yaml | 18 - demos-go/cmd/threshold-ecdsa-web/go.mod | 36 - demos-go/cmd/threshold-ecdsa-web/go.sum | 71 - demos-go/cmd/threshold-ecdsa-web/handlers.go | 337 ----- demos-go/cmd/threshold-ecdsa-web/main.go | 127 -- .../templates/dkg_base.html | 96 -- .../templates/dkg_connection_success.html | 47 - .../templates/dkg_connection_waiting.html | 149 --- .../templates/dkg_result.html | 14 - .../threshold-ecdsa-web/templates/error.html | 3 - .../templates/signing_immediate_waiting.html | 228 ---- .../templates/signing_leader_interface.html | 237 ---- .../templates/signing_result.html | 33 - demos-go/cmd/threshold-ecdsa-web/web.go | 741 ----------- demos-go/examples/access-structure/go.mod | 14 - demos-go/examples/access-structure/go.sum | 10 - demos-go/examples/access-structure/main.go | 36 - demos-go/examples/agreerandom/go.mod | 14 - demos-go/examples/agreerandom/go.sum | 10 - demos-go/examples/agreerandom/main.go | 92 -- demos-go/examples/ecdsa-2pc/go.mod | 17 - demos-go/examples/ecdsa-2pc/go.sum | 18 - demos-go/examples/ecdsa-2pc/main.go | 295 ----- .../examples/ecdsa-mpc-with-backup/go.mod | 18 - .../examples/ecdsa-mpc-with-backup/go.sum | 17 - .../examples/ecdsa-mpc-with-backup/main.go | 400 ------ .../examples/ecdsa-mpc-with-backup/party.go | 183 --- demos-go/examples/zk/go.mod | 14 - demos-go/examples/zk/go.sum | 10 - demos-go/examples/zk/main.go | 48 - docs/secure-usage.pdf | Bin 131 -> 0 bytes .../cbmpc/internal}/core/convert.h | 48 +- .../cbmpc/internal}/core/extended_uint.h | 0 .../cbmpc/internal}/core/log.h | 0 .../cbmpc/internal}/core/strext.h | 10 +- .../cbmpc/internal}/core/utils.h | 54 +- .../cbmpc/internal}/crypto/base.h | 20 +- .../cbmpc/internal}/crypto/base_bn.h | 41 +- .../cbmpc/internal}/crypto/base_ec_core.h | 2 +- .../cbmpc/internal}/crypto/base_ecc.h | 3 +- .../internal}/crypto/base_ecc_secp256k1.h | 2 +- .../cbmpc/internal}/crypto/base_eddsa.h | 4 +- .../cbmpc/internal}/crypto/base_hash.h | 7 +- .../cbmpc/internal}/crypto/base_mod.h | 12 +- .../cbmpc/internal}/crypto/base_paillier.h | 8 +- .../cbmpc/internal}/crypto/base_pki.h | 101 +- .../cbmpc/internal/crypto/base_rsa.h | 136 ++ .../cbmpc/internal}/crypto/commitment.h | 2 +- .../cbmpc/internal}/crypto/ec25519_core.h | 2 +- .../cbmpc/internal}/crypto/elgamal.h | 2 +- .../cbmpc/internal}/crypto/lagrange.h | 6 +- .../cbmpc/internal}/crypto/ro.h | 2 +- .../cbmpc/internal}/crypto/scope.h | 0 .../cbmpc/internal/crypto/secret_sharing.h | 195 +++ .../cbmpc/internal}/crypto/tdh2.h | 27 +- .../cbmpc/internal}/protocol/agree_random.h | 2 +- .../internal}/protocol/committed_broadcast.h | 4 +- .../cbmpc/internal}/protocol/ec_dkg.h | 35 +- .../cbmpc/internal}/protocol/ecdsa_2p.h | 8 +- .../cbmpc/internal}/protocol/ecdsa_mp.h | 14 +- .../cbmpc/internal}/protocol/eddsa.h | 6 +- .../internal}/protocol/hd_keyset_ecdsa_2p.h | 8 +- .../internal}/protocol/hd_keyset_eddsa_2p.h | 8 +- .../cbmpc/internal}/protocol/hd_tree_bip32.h | 2 +- .../cbmpc/internal}/protocol/int_commitment.h | 6 +- .../cbmpc/internal}/protocol/mpc_job.h | 64 +- .../cbmpc/internal}/protocol/ot.h | 8 +- .../cbmpc/internal/protocol/pve.h | 45 + .../cbmpc/internal/protocol/pve_ac.h | 101 ++ .../cbmpc/internal/protocol/pve_base.h | 258 ++++ .../cbmpc/internal/protocol/pve_batch.h | 80 ++ .../cbmpc/internal}/protocol/schnorr_2p.h | 10 +- .../cbmpc/internal}/protocol/schnorr_mp.h | 14 +- .../cbmpc/internal}/protocol/sid.h | 4 +- .../cbmpc/internal}/protocol/util.h | 4 +- .../cbmpc/internal}/zk/fischlin.h | 9 +- .../cbmpc/internal}/zk/small_primes.h | 6 +- .../cbmpc/internal}/zk/zk_ec.h | 4 +- .../cbmpc/internal}/zk/zk_elgamal_com.h | 8 +- .../cbmpc/internal}/zk/zk_paillier.h | 6 +- .../cbmpc/internal}/zk/zk_pedersen.h | 6 +- .../cbmpc/internal}/zk/zk_unknown_order.h | 2 +- .../cbmpc/internal}/zk/zk_util.h | 2 +- include/cbmpc/api/curve.h | 14 + include/cbmpc/api/ecdsa_2p.h | 80 ++ include/cbmpc/api/ecdsa_mp.h | 128 ++ include/cbmpc/api/eddsa_2p.h | 78 ++ include/cbmpc/api/eddsa_mp.h | 131 ++ include/cbmpc/api/hd_keyset_ecdsa_2p.h | 41 + include/cbmpc/api/hd_keyset_eddsa_2p.h | 46 + include/cbmpc/api/pve_base_pke.h | 164 +++ include/cbmpc/api/pve_batch_ac.h | 117 ++ .../cbmpc/api/pve_batch_single_recipient.h | 72 + include/cbmpc/api/schnorr_2p.h | 92 ++ include/cbmpc/api/schnorr_mp.h | 147 +++ include/cbmpc/api/tdh2.h | 75 ++ include/cbmpc/c_api/access_structure.h | 79 ++ include/cbmpc/c_api/cmem.h | 47 + include/cbmpc/c_api/common.h | 69 + include/cbmpc/c_api/ecdsa_2p.h | 89 ++ include/cbmpc/c_api/ecdsa_mp.h | 163 +++ include/cbmpc/c_api/eddsa_2p.h | 86 ++ include/cbmpc/c_api/eddsa_mp.h | 159 +++ include/cbmpc/c_api/job.h | 75 ++ include/cbmpc/c_api/pve_base_pke.h | 220 ++++ include/cbmpc/c_api/pve_batch_ac.h | 103 ++ .../cbmpc/c_api/pve_batch_single_recipient.h | 84 ++ include/cbmpc/c_api/schnorr_2p.h | 87 ++ include/cbmpc/c_api/schnorr_mp.h | 160 +++ include/cbmpc/c_api/tdh2.h | 91 ++ include/cbmpc/core/access_structure.h | 72 + include/cbmpc/core/bip32_path.h | 16 + {src => include}/cbmpc/core/buf.h | 52 +- {src => include}/cbmpc/core/buf128.h | 3 +- {src => include}/cbmpc/core/buf256.h | 8 +- {src => include}/cbmpc/core/error.h | 4 +- include/cbmpc/core/job.h | 67 + {src => include}/cbmpc/core/macros.h | 4 +- {src => include}/cbmpc/core/precompiled.h | 2 +- scripts/auto_build_cpp.sh | 59 - scripts/go_with_cpp.sh | 40 - scripts/install.sh | 124 +- scripts/make-release.sh | 4 +- scripts/openssl/build-static-openssl-linux.sh | 14 +- .../openssl/build-static-openssl-macos-m1.sh | 33 +- scripts/openssl/build-static-openssl-macos.sh | 33 +- scripts/run-demos.sh | 163 ++- src/cbmpc/api/CMakeLists.txt | 19 + src/cbmpc/api/access_structure_util.h | 227 ++++ src/cbmpc/api/curve_util.h | 40 + src/cbmpc/api/ecdsa2pc.cpp | 302 +++++ src/cbmpc/api/ecdsa_mp.cpp | 552 ++++++++ src/cbmpc/api/eddsa2pc.cpp | 224 ++++ src/cbmpc/api/eddsa_mp.cpp | 531 ++++++++ src/cbmpc/api/hd_keyset_ecdsa_2p.cpp | 226 ++++ src/cbmpc/api/hd_keyset_eddsa_2p.cpp | 201 +++ src/cbmpc/api/hd_keyset_util.h | 32 + src/cbmpc/api/job_util.h | 65 + src/cbmpc/api/mem_util.h | 55 + src/cbmpc/api/pve_base_pke.cpp | 373 ++++++ src/cbmpc/api/pve_batch_ac.cpp | 405 ++++++ src/cbmpc/api/pve_batch_single_recipient.cpp | 288 ++++ src/cbmpc/api/pve_internal.h | 168 +++ src/cbmpc/api/schnorr2pc.cpp | 234 ++++ src/cbmpc/api/schnorr_mp.cpp | 544 ++++++++ src/cbmpc/api/tdh2.cpp | 363 +++++ src/cbmpc/c_api/CMakeLists.txt | 18 + src/cbmpc/c_api/access_structure_adapter.h | 123 ++ src/cbmpc/c_api/common.cpp | 63 + src/cbmpc/c_api/ecdsa2pc.cpp | 262 ++++ src/cbmpc/c_api/ecdsa_mp.cpp | 454 +++++++ src/cbmpc/c_api/eddsa2pc.cpp | 218 +++ src/cbmpc/c_api/eddsa_mp.cpp | 439 +++++++ src/cbmpc/c_api/pve_base_pke.cpp | 565 ++++++++ src/cbmpc/c_api/pve_batch_ac.cpp | 387 ++++++ .../c_api/pve_batch_single_recipient.cpp | 276 ++++ src/cbmpc/c_api/pve_internal.h | 111 ++ src/cbmpc/c_api/schnorr2pc.cpp | 239 ++++ src/cbmpc/c_api/schnorr_mp.cpp | 461 +++++++ src/cbmpc/c_api/tdh2.cpp | 341 +++++ src/cbmpc/c_api/transport_adapter.h | 127 ++ src/cbmpc/c_api/util.h | 206 +++ src/cbmpc/core/CMakeLists.txt | 2 +- src/cbmpc/core/buf.cpp | 127 +- src/cbmpc/core/buf128.cpp | 9 +- src/cbmpc/core/buf256.cpp | 23 +- src/cbmpc/core/cmem.h | 21 - src/cbmpc/core/convert.cpp | 64 +- src/cbmpc/core/error.cpp | 7 +- src/cbmpc/core/extended_uint.cpp | 5 +- src/cbmpc/core/strext.cpp | 13 +- src/cbmpc/core/thread.h | 42 - src/cbmpc/crypto/CMakeLists.txt | 1 - src/cbmpc/crypto/base.cpp | 18 +- src/cbmpc/crypto/base_bn.cpp | 200 ++- src/cbmpc/crypto/base_ec_core.cpp | 2 +- src/cbmpc/crypto/base_ecc.cpp | 19 +- src/cbmpc/crypto/base_ecc_secp256k1.cpp | 4 +- src/cbmpc/crypto/base_eddsa.cpp | 6 +- src/cbmpc/crypto/base_hash.cpp | 64 +- src/cbmpc/crypto/base_mod.cpp | 43 +- src/cbmpc/crypto/base_paillier.cpp | 67 +- src/cbmpc/crypto/base_pki.cpp | 76 -- src/cbmpc/crypto/base_rsa.cpp | 233 +++- src/cbmpc/crypto/base_rsa.h | 139 -- src/cbmpc/crypto/base_rsa_oaep.cpp | 32 +- src/cbmpc/crypto/drbg.cpp | 4 +- src/cbmpc/crypto/ec25519_core.cpp | 11 +- src/cbmpc/crypto/elgamal.cpp | 7 +- src/cbmpc/crypto/lagrange.cpp | 2 +- src/cbmpc/crypto/pki_ffi.h | 35 - src/cbmpc/crypto/ro.cpp | 5 +- src/cbmpc/crypto/secret_sharing.cpp | 159 ++- src/cbmpc/crypto/secret_sharing.h | 176 --- src/cbmpc/crypto/tdh2.cpp | 23 +- src/cbmpc/ffi/CMakeLists.txt | 11 - src/cbmpc/ffi/cmem_adapter.cpp | 83 -- src/cbmpc/ffi/cmem_adapter.h | 35 - src/cbmpc/ffi/pki.cpp | 15 - src/cbmpc/ffi/pki.h | 114 -- src/cbmpc/protocol/CMakeLists.txt | 3 +- src/cbmpc/protocol/agree_random.cpp | 7 +- src/cbmpc/protocol/data_transport.h | 18 - src/cbmpc/protocol/ec_dkg.cpp | 109 +- src/cbmpc/protocol/ecdsa_2p.cpp | 22 +- src/cbmpc/protocol/ecdsa_mp.cpp | 28 +- src/cbmpc/protocol/eddsa.cpp | 2 +- src/cbmpc/protocol/hd_keyset_ecdsa_2p.cpp | 17 +- src/cbmpc/protocol/hd_keyset_eddsa_2p.cpp | 19 +- src/cbmpc/protocol/hd_tree_bip32.cpp | 2 +- src/cbmpc/protocol/int_commitment.cpp | 9 +- src/cbmpc/protocol/mpc_job.cpp | 6 +- src/cbmpc/protocol/ot.cpp | 8 +- src/cbmpc/protocol/pve.cpp | 83 +- src/cbmpc/protocol/pve.h | 90 -- src/cbmpc/protocol/pve_ac.cpp | 83 +- src/cbmpc/protocol/pve_ac.h | 101 -- src/cbmpc/protocol/pve_base.cpp | 17 +- src/cbmpc/protocol/pve_base.h | 79 -- src/cbmpc/protocol/pve_batch.cpp | 70 +- src/cbmpc/protocol/pve_batch.h | 108 -- src/cbmpc/protocol/schnorr_2p.cpp | 16 +- src/cbmpc/protocol/schnorr_mp.cpp | 28 +- src/cbmpc/zk/fischlin.cpp | 5 +- src/cbmpc/zk/small_primes.cpp | 6 +- src/cbmpc/zk/zk_ec.cpp | 13 +- src/cbmpc/zk/zk_elgamal_com.cpp | 10 +- src/cbmpc/zk/zk_paillier.cpp | 12 +- src/cbmpc/zk/zk_pedersen.cpp | 9 +- src/cbmpc/zk/zk_unknown_order.cpp | 7 +- tests/dudect/dudect_util/dudect.h | 74 +- .../dudect_util/dudect_implementation.h | 2 +- tests/public_headers_smoke.cpp | 44 + tests/unit/api/test_ecdsa2pc.cpp | 827 ++++++++++++ tests/unit/api/test_ecdsa_mp.cpp | 673 ++++++++++ tests/unit/api/test_ecdsa_mp_ac.cpp | 1163 +++++++++++++++++ tests/unit/api/test_eddsa2pc.cpp | 207 +++ tests/unit/api/test_eddsa_mp.cpp | 619 +++++++++ tests/unit/api/test_eddsa_mp_ac.cpp | 869 ++++++++++++ tests/unit/api/test_hd_keyset_ecdsa_2p.cpp | 582 +++++++++ tests/unit/api/test_hd_keyset_eddsa_2p.cpp | 458 +++++++ tests/unit/api/test_pve.cpp | 741 +++++++++++ tests/unit/api/test_pve_ac.cpp | 897 +++++++++++++ tests/unit/api/test_pve_batch.cpp | 426 ++++++ tests/unit/api/test_schnorr2pc.cpp | 523 ++++++++ tests/unit/api/test_schnorr_mp.cpp | 337 +++++ tests/unit/api/test_schnorr_mp_ac.cpp | 358 +++++ tests/unit/api/test_tdh2.cpp | 570 ++++++++ tests/unit/api/test_transport_harness.h | 100 ++ tests/unit/c_api/test_curve_validation.cpp | 69 + tests/unit/c_api/test_ecdsa2pc.cpp | 861 ++++++++++++ tests/unit/c_api/test_ecdsa_mp.cpp | 996 ++++++++++++++ tests/unit/c_api/test_ecdsa_mp_ac.cpp | 896 +++++++++++++ tests/unit/c_api/test_eddsa2pc.cpp | 144 ++ tests/unit/c_api/test_eddsa_mp.cpp | 933 +++++++++++++ tests/unit/c_api/test_eddsa_mp_ac.cpp | 777 +++++++++++ tests/unit/c_api/test_eddsa_mp_threshold.cpp | 230 ++++ tests/unit/c_api/test_pve.cpp | 721 ++++++++++ tests/unit/c_api/test_pve_ac.cpp | 739 +++++++++++ tests/unit/c_api/test_pve_batch.cpp | 293 +++++ tests/unit/c_api/test_schnorr2pc.cpp | 363 +++++ tests/unit/c_api/test_schnorr_mp.cpp | 445 +++++++ tests/unit/c_api/test_schnorr_mp_ac.cpp | 493 +++++++ .../unit/c_api/test_schnorr_mp_threshold.cpp | 231 ++++ tests/unit/c_api/test_tdh2.cpp | 657 ++++++++++ tests/unit/c_api/test_transport_harness.h | 179 +++ tests/unit/core/test_buf.cpp | 32 +- tests/unit/core/test_buf128.cpp | 7 + tests/unit/core/test_buf256.cpp | 8 + tests/unit/core/test_convert.cpp | 39 +- tests/unit/core/test_error.cpp | 2 + tests/unit/core/test_util.cpp | 6 +- tests/unit/crypto/test_base.cpp | 14 +- tests/unit/crypto/test_base_bn.cpp | 120 +- tests/unit/crypto/test_base_ecc.cpp | 11 +- tests/unit/crypto/test_base_hash.cpp | 4 +- tests/unit/crypto/test_base_mod.cpp | 10 +- tests/unit/crypto/test_base_pki.cpp | 114 +- tests/unit/crypto/test_base_rsa.cpp | 7 +- tests/unit/crypto/test_commitment.cpp | 9 +- tests/unit/crypto/test_ecc.cpp | 4 +- tests/unit/crypto/test_eddsa.cpp | 10 +- tests/unit/crypto/test_elgamal.cpp | 2 +- tests/unit/crypto/test_hkdf_rfc5869.cpp | 4 +- tests/unit/crypto/test_hpke_rfc9180_json.cpp | 4 +- tests/unit/crypto/test_lagrange.cpp | 2 +- tests/unit/crypto/test_ro.cpp | 5 +- tests/unit/crypto/test_secret_sharing.cpp | 83 +- tests/unit/crypto/test_tdh2.cpp | 25 +- tests/unit/protocol/test_agree_random.cpp | 4 +- tests/unit/protocol/test_broadcast.cpp | 4 +- tests/unit/protocol/test_ec_dkg.cpp | 54 +- tests/unit/protocol/test_ecdsa_2p.cpp | 50 +- tests/unit/protocol/test_ecdsa_mp.cpp | 67 +- tests/unit/protocol/test_hdmpc_ecdsa_2p.cpp | 69 +- tests/unit/protocol/test_hdmpc_eddsa_2p.cpp | 68 +- tests/unit/protocol/test_int_commitment.cpp | 6 +- tests/unit/protocol/test_mpc_network.cpp | 150 +-- tests/unit/protocol/test_ot.cpp | 6 +- .../protocol/test_parallel_transport_oob.cpp | 96 -- tests/unit/protocol/test_pve.cpp | 357 +++-- tests/unit/protocol/test_pve_ac.cpp | 104 +- tests/unit/protocol/test_schnorr_2p.cpp | 51 +- tests/unit/protocol/test_schnorr_mp.cpp | 44 +- tests/unit/protocol/test_util.cpp | 2 +- tests/unit/zk/test_zk.cpp | 2 +- tests/utils/crypto/nizk.h | 2 +- tests/utils/data/ac.h | 2 +- tests/utils/data/mpc_data_generator.h | 4 +- tests/utils/data/sampler/base.h | 6 +- tests/utils/data/sampler/bn.h | 8 +- tests/utils/data/sampler/buf.cpp | 11 +- tests/utils/data/sampler/buf.h | 4 +- tests/utils/data/sampler/ecp.h | 6 +- tests/utils/data/sampler/elgamal.h | 4 +- tests/utils/data/sampler/paillier.h | 4 +- tests/utils/data/tdh2.h | 17 +- tests/utils/data/test_data_factory.h | 12 +- tests/utils/data/test_node.h | 2 +- tests/utils/data/zk_completeness.h | 12 +- tests/utils/data/zk_data_generator.h | 4 +- tests/utils/local_network/channel.h | 2 +- tests/utils/local_network/mpc_runner.cpp | 43 +- tests/utils/local_network/mpc_runner.h | 10 +- tests/utils/local_network/network_context.h | 4 +- tools/benchmark/CMakeLists.txt | 38 +- tools/benchmark/benchmark.makefile | 6 +- tools/benchmark/bm_agree_random.cpp | 4 +- tools/benchmark/bm_commitment.cpp | 6 +- tools/benchmark/bm_core_bn.cpp | 2 +- tools/benchmark/bm_drbg.cpp | 5 +- tools/benchmark/bm_ecdsa.cpp | 4 +- tools/benchmark/bm_eddsa.cpp | 2 +- tools/benchmark/bm_elgamal.cpp | 2 +- tools/benchmark/bm_elliptic_curve.cpp | 4 +- tools/benchmark/bm_hash.cpp | 5 +- tools/benchmark/bm_ot.cpp | 4 +- tools/benchmark/bm_paillier.cpp | 2 +- tools/benchmark/bm_pve.cpp | 311 +++-- tools/benchmark/bm_share.cpp | 4 +- tools/benchmark/bm_sid.cpp | 4 +- tools/benchmark/bm_tdh2.cpp | 2 +- tools/benchmark/bm_test.cpp | 4 +- tools/benchmark/bm_zk.cpp | 4 +- tools/benchmark/mpc_util.h | 6 +- tools/benchmark/util.h | 2 +- 460 files changed, 36889 insertions(+), 18193 deletions(-) create mode 100644 BUG_BOUNTY.md create mode 100644 SECURE_USAGE.md create mode 100644 demo-api/ecdsa_mp_pve_backup/CMakeLists.txt create mode 100644 demo-api/ecdsa_mp_pve_backup/main.cpp create mode 100644 demo-api/hd_keyset_ecdsa_2p/CMakeLists.txt create mode 100644 demo-api/hd_keyset_ecdsa_2p/main.cpp create mode 100644 demo-api/pve/CMakeLists.txt create mode 100644 demo-api/pve/main.cpp create mode 100644 demo-api/schnorr_2p_pve_batch_backup/CMakeLists.txt create mode 100644 demo-api/schnorr_2p_pve_batch_backup/main.cpp create mode 100755 demo-cpp/basic_primitive/CMakeLists.txt rename {demos-cpp => demo-cpp}/basic_primitive/main.cpp (84%) rename {src/cbmpc/protocol => demo-cpp/common}/mpc_job_session.cpp (99%) rename {src/cbmpc/protocol => demo-cpp/common}/mpc_job_session.h (74%) create mode 100644 demo-cpp/parallel_transport/CMakeLists.txt create mode 100644 demo-cpp/parallel_transport/main.cpp create mode 100755 demo-cpp/zk/CMakeLists.txt rename {demos-cpp => demo-cpp}/zk/demo_nizk.h (91%) rename {demos-cpp => demo-cpp}/zk/main.cpp (87%) delete mode 100755 demos-cpp/basic_primitive/CMakeLists.txt delete mode 100755 demos-cpp/zk/CMakeLists.txt delete mode 100644 demos-go/cb-mpc-go/.gitignore delete mode 100644 demos-go/cb-mpc-go/api/curve/curve.go delete mode 100644 demos-go/cb-mpc-go/api/curve/curve_test.go delete mode 100644 demos-go/cb-mpc-go/api/curve/doc.go delete mode 100644 demos-go/cb-mpc-go/api/curve/point.go delete mode 100644 demos-go/cb-mpc-go/api/curve/scalar.go delete mode 100644 demos-go/cb-mpc-go/api/curve/scalar_test.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/access_structure.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/agree_random.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/agreer_andom_test.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/doc.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/ecdsa_2p.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/ecdsa_2p_test.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/ecdsa_mp.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/ecdsa_mp_test.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/ecdsa_mp_threshold_test.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/ecdsa_mp_validation_test.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/eddsa_mp.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/eddsa_mp_test.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/eddsa_mp_threshold_test.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/job.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/job_mp.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/pve.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/pve_ac.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/pve_ac_test.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/pve_instance.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/pve_test.go delete mode 100644 demos-go/cb-mpc-go/api/mpc/secure.go delete mode 100644 demos-go/cb-mpc-go/api/transport/doc.go delete mode 100644 demos-go/cb-mpc-go/api/transport/messenger.go delete mode 100644 demos-go/cb-mpc-go/api/transport/mocknet/doc.go delete mode 100644 demos-go/cb-mpc-go/api/transport/mocknet/runner.go delete mode 100644 demos-go/cb-mpc-go/api/transport/mocknet/runner_test.go delete mode 100644 demos-go/cb-mpc-go/api/transport/mocknet/transport.go delete mode 100644 demos-go/cb-mpc-go/api/transport/mtls/transport.go delete mode 100644 demos-go/cb-mpc-go/api/zk/doc.go delete mode 100644 demos-go/cb-mpc-go/api/zk/zkdl.go delete mode 100644 demos-go/cb-mpc-go/api/zk/zkdl_test.go delete mode 100644 demos-go/cb-mpc-go/go.mod delete mode 100644 demos-go/cb-mpc-go/go.sum delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/ac.cpp delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/ac.go delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/ac.h delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/agree_random.cpp delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/agree_random.h delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/agreerandom.go delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/cblib.h delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/cmem.go delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/curve.cpp delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/curve.go delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/curve.h delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/ecdsa2p.cpp delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/ecdsa2p.go delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/ecdsa2p.h delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/ecdsamp.cpp delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/ecdsamp.go delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/ecdsamp.h delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/eckeymp.cpp delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/eckeymp.go delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/eckeymp.h delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/eddsamp.cpp delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/eddsamp.go delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/eddsamp.h delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/kem.h delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/network.cpp delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/network.go delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/network.h delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/pve.cpp delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/pve.go delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/pve.h delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/zk.cpp delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/zk.go delete mode 100644 demos-go/cb-mpc-go/internal/cgobinding/zk.h delete mode 100644 demos-go/cb-mpc-go/internal/curvemap/nid.go delete mode 100644 demos-go/cb-mpc-go/internal/testutil/testutil.go delete mode 100644 demos-go/cmd/threshold-ecdsa-web/.gitignore delete mode 100644 demos-go/cmd/threshold-ecdsa-web/Makefile delete mode 100644 demos-go/cmd/threshold-ecdsa-web/README.md delete mode 100644 demos-go/cmd/threshold-ecdsa-web/certs/party-0/openssl-0.cnf delete mode 100644 demos-go/cmd/threshold-ecdsa-web/certs/party-1/openssl-1.cnf delete mode 100644 demos-go/cmd/threshold-ecdsa-web/certs/party-2/openssl-2.cnf delete mode 100644 demos-go/cmd/threshold-ecdsa-web/certs/party-3/openssl-3.cnf delete mode 100644 demos-go/cmd/threshold-ecdsa-web/config-0.yaml delete mode 100644 demos-go/cmd/threshold-ecdsa-web/config-1.yaml delete mode 100644 demos-go/cmd/threshold-ecdsa-web/config-2.yaml delete mode 100644 demos-go/cmd/threshold-ecdsa-web/config-3.yaml delete mode 100644 demos-go/cmd/threshold-ecdsa-web/go.mod delete mode 100644 demos-go/cmd/threshold-ecdsa-web/go.sum delete mode 100644 demos-go/cmd/threshold-ecdsa-web/handlers.go delete mode 100644 demos-go/cmd/threshold-ecdsa-web/main.go delete mode 100644 demos-go/cmd/threshold-ecdsa-web/templates/dkg_base.html delete mode 100644 demos-go/cmd/threshold-ecdsa-web/templates/dkg_connection_success.html delete mode 100644 demos-go/cmd/threshold-ecdsa-web/templates/dkg_connection_waiting.html delete mode 100644 demos-go/cmd/threshold-ecdsa-web/templates/dkg_result.html delete mode 100644 demos-go/cmd/threshold-ecdsa-web/templates/error.html delete mode 100644 demos-go/cmd/threshold-ecdsa-web/templates/signing_immediate_waiting.html delete mode 100644 demos-go/cmd/threshold-ecdsa-web/templates/signing_leader_interface.html delete mode 100644 demos-go/cmd/threshold-ecdsa-web/templates/signing_result.html delete mode 100644 demos-go/cmd/threshold-ecdsa-web/web.go delete mode 100644 demos-go/examples/access-structure/go.mod delete mode 100644 demos-go/examples/access-structure/go.sum delete mode 100644 demos-go/examples/access-structure/main.go delete mode 100644 demos-go/examples/agreerandom/go.mod delete mode 100644 demos-go/examples/agreerandom/go.sum delete mode 100644 demos-go/examples/agreerandom/main.go delete mode 100644 demos-go/examples/ecdsa-2pc/go.mod delete mode 100644 demos-go/examples/ecdsa-2pc/go.sum delete mode 100644 demos-go/examples/ecdsa-2pc/main.go delete mode 100644 demos-go/examples/ecdsa-mpc-with-backup/go.mod delete mode 100644 demos-go/examples/ecdsa-mpc-with-backup/go.sum delete mode 100644 demos-go/examples/ecdsa-mpc-with-backup/main.go delete mode 100644 demos-go/examples/ecdsa-mpc-with-backup/party.go delete mode 100644 demos-go/examples/zk/go.mod delete mode 100644 demos-go/examples/zk/go.sum delete mode 100644 demos-go/examples/zk/main.go delete mode 100644 docs/secure-usage.pdf rename {src/cbmpc => include-internal/cbmpc/internal}/core/convert.h (87%) rename {src/cbmpc => include-internal/cbmpc/internal}/core/extended_uint.h (100%) rename {src/cbmpc => include-internal/cbmpc/internal}/core/log.h (100%) rename {src/cbmpc => include-internal/cbmpc/internal}/core/strext.h (90%) rename {src/cbmpc => include-internal/cbmpc/internal}/core/utils.h (65%) rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/base.h (93%) rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/base_bn.h (76%) mode change 100755 => 100644 rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/base_ec_core.h (99%) rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/base_ecc.h (99%) rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/base_ecc_secp256k1.h (98%) rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/base_eddsa.h (96%) rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/base_hash.h (98%) rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/base_mod.h (93%) rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/base_paillier.h (94%) rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/base_pki.h (74%) create mode 100644 include-internal/cbmpc/internal/crypto/base_rsa.h rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/commitment.h (99%) rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/ec25519_core.h (97%) rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/elgamal.h (98%) rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/lagrange.h (94%) rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/ro.h (98%) rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/scope.h (100%) create mode 100644 include-internal/cbmpc/internal/crypto/secret_sharing.h rename {src/cbmpc => include-internal/cbmpc/internal}/crypto/tdh2.h (82%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/agree_random.h (96%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/committed_broadcast.h (96%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/ec_dkg.h (73%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/ecdsa_2p.h (96%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/ecdsa_mp.h (89%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/eddsa.h (82%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/hd_keyset_ecdsa_2p.h (87%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/hd_keyset_eddsa_2p.h (86%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/hd_tree_bip32.h (97%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/int_commitment.h (86%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/mpc_job.h (88%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/ot.h (97%) create mode 100644 include-internal/cbmpc/internal/protocol/pve.h create mode 100644 include-internal/cbmpc/internal/protocol/pve_ac.h create mode 100644 include-internal/cbmpc/internal/protocol/pve_base.h create mode 100644 include-internal/cbmpc/internal/protocol/pve_batch.h rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/schnorr_2p.h (71%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/schnorr_mp.h (68%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/sid.h (94%) rename {src/cbmpc => include-internal/cbmpc/internal}/protocol/util.h (94%) rename {src/cbmpc => include-internal/cbmpc/internal}/zk/fischlin.h (88%) rename {src/cbmpc => include-internal/cbmpc/internal}/zk/small_primes.h (81%) rename {src/cbmpc => include-internal/cbmpc/internal}/zk/zk_ec.h (96%) rename {src/cbmpc => include-internal/cbmpc/internal}/zk/zk_elgamal_com.h (94%) rename {src/cbmpc => include-internal/cbmpc/internal}/zk/zk_paillier.h (98%) rename {src/cbmpc => include-internal/cbmpc/internal}/zk/zk_pedersen.h (96%) rename {src/cbmpc => include-internal/cbmpc/internal}/zk/zk_unknown_order.h (93%) rename {src/cbmpc => include-internal/cbmpc/internal}/zk/zk_util.h (97%) create mode 100644 include/cbmpc/api/curve.h create mode 100644 include/cbmpc/api/ecdsa_2p.h create mode 100644 include/cbmpc/api/ecdsa_mp.h create mode 100644 include/cbmpc/api/eddsa_2p.h create mode 100644 include/cbmpc/api/eddsa_mp.h create mode 100644 include/cbmpc/api/hd_keyset_ecdsa_2p.h create mode 100644 include/cbmpc/api/hd_keyset_eddsa_2p.h create mode 100644 include/cbmpc/api/pve_base_pke.h create mode 100644 include/cbmpc/api/pve_batch_ac.h create mode 100644 include/cbmpc/api/pve_batch_single_recipient.h create mode 100644 include/cbmpc/api/schnorr_2p.h create mode 100644 include/cbmpc/api/schnorr_mp.h create mode 100644 include/cbmpc/api/tdh2.h create mode 100644 include/cbmpc/c_api/access_structure.h create mode 100644 include/cbmpc/c_api/cmem.h create mode 100644 include/cbmpc/c_api/common.h create mode 100644 include/cbmpc/c_api/ecdsa_2p.h create mode 100644 include/cbmpc/c_api/ecdsa_mp.h create mode 100644 include/cbmpc/c_api/eddsa_2p.h create mode 100644 include/cbmpc/c_api/eddsa_mp.h create mode 100644 include/cbmpc/c_api/job.h create mode 100644 include/cbmpc/c_api/pve_base_pke.h create mode 100644 include/cbmpc/c_api/pve_batch_ac.h create mode 100644 include/cbmpc/c_api/pve_batch_single_recipient.h create mode 100644 include/cbmpc/c_api/schnorr_2p.h create mode 100644 include/cbmpc/c_api/schnorr_mp.h create mode 100644 include/cbmpc/c_api/tdh2.h create mode 100644 include/cbmpc/core/access_structure.h create mode 100644 include/cbmpc/core/bip32_path.h rename {src => include}/cbmpc/core/buf.h (80%) rename {src => include}/cbmpc/core/buf128.h (97%) rename {src => include}/cbmpc/core/buf256.h (94%) rename {src => include}/cbmpc/core/error.h (97%) create mode 100644 include/cbmpc/core/job.h rename {src => include}/cbmpc/core/macros.h (94%) rename {src => include}/cbmpc/core/precompiled.h (98%) delete mode 100644 scripts/auto_build_cpp.sh delete mode 100644 scripts/go_with_cpp.sh create mode 100644 src/cbmpc/api/CMakeLists.txt create mode 100644 src/cbmpc/api/access_structure_util.h create mode 100644 src/cbmpc/api/curve_util.h create mode 100644 src/cbmpc/api/ecdsa2pc.cpp create mode 100644 src/cbmpc/api/ecdsa_mp.cpp create mode 100644 src/cbmpc/api/eddsa2pc.cpp create mode 100644 src/cbmpc/api/eddsa_mp.cpp create mode 100644 src/cbmpc/api/hd_keyset_ecdsa_2p.cpp create mode 100644 src/cbmpc/api/hd_keyset_eddsa_2p.cpp create mode 100644 src/cbmpc/api/hd_keyset_util.h create mode 100644 src/cbmpc/api/job_util.h create mode 100644 src/cbmpc/api/mem_util.h create mode 100644 src/cbmpc/api/pve_base_pke.cpp create mode 100644 src/cbmpc/api/pve_batch_ac.cpp create mode 100644 src/cbmpc/api/pve_batch_single_recipient.cpp create mode 100644 src/cbmpc/api/pve_internal.h create mode 100644 src/cbmpc/api/schnorr2pc.cpp create mode 100644 src/cbmpc/api/schnorr_mp.cpp create mode 100644 src/cbmpc/api/tdh2.cpp create mode 100644 src/cbmpc/c_api/CMakeLists.txt create mode 100644 src/cbmpc/c_api/access_structure_adapter.h create mode 100644 src/cbmpc/c_api/common.cpp create mode 100644 src/cbmpc/c_api/ecdsa2pc.cpp create mode 100644 src/cbmpc/c_api/ecdsa_mp.cpp create mode 100644 src/cbmpc/c_api/eddsa2pc.cpp create mode 100644 src/cbmpc/c_api/eddsa_mp.cpp create mode 100644 src/cbmpc/c_api/pve_base_pke.cpp create mode 100644 src/cbmpc/c_api/pve_batch_ac.cpp create mode 100644 src/cbmpc/c_api/pve_batch_single_recipient.cpp create mode 100644 src/cbmpc/c_api/pve_internal.h create mode 100644 src/cbmpc/c_api/schnorr2pc.cpp create mode 100644 src/cbmpc/c_api/schnorr_mp.cpp create mode 100644 src/cbmpc/c_api/tdh2.cpp create mode 100644 src/cbmpc/c_api/transport_adapter.h create mode 100644 src/cbmpc/c_api/util.h mode change 100755 => 100644 src/cbmpc/core/buf256.cpp delete mode 100644 src/cbmpc/core/cmem.h mode change 100755 => 100644 src/cbmpc/core/error.cpp delete mode 100755 src/cbmpc/core/thread.h delete mode 100644 src/cbmpc/crypto/base_pki.cpp delete mode 100644 src/cbmpc/crypto/base_rsa.h delete mode 100644 src/cbmpc/crypto/pki_ffi.h delete mode 100644 src/cbmpc/crypto/secret_sharing.h delete mode 100644 src/cbmpc/ffi/CMakeLists.txt delete mode 100644 src/cbmpc/ffi/cmem_adapter.cpp delete mode 100644 src/cbmpc/ffi/cmem_adapter.h delete mode 100644 src/cbmpc/ffi/pki.cpp delete mode 100644 src/cbmpc/ffi/pki.h delete mode 100644 src/cbmpc/protocol/data_transport.h delete mode 100644 src/cbmpc/protocol/pve.h delete mode 100644 src/cbmpc/protocol/pve_ac.h delete mode 100644 src/cbmpc/protocol/pve_base.h delete mode 100644 src/cbmpc/protocol/pve_batch.h create mode 100644 tests/public_headers_smoke.cpp create mode 100644 tests/unit/api/test_ecdsa2pc.cpp create mode 100644 tests/unit/api/test_ecdsa_mp.cpp create mode 100644 tests/unit/api/test_ecdsa_mp_ac.cpp create mode 100644 tests/unit/api/test_eddsa2pc.cpp create mode 100644 tests/unit/api/test_eddsa_mp.cpp create mode 100644 tests/unit/api/test_eddsa_mp_ac.cpp create mode 100644 tests/unit/api/test_hd_keyset_ecdsa_2p.cpp create mode 100644 tests/unit/api/test_hd_keyset_eddsa_2p.cpp create mode 100644 tests/unit/api/test_pve.cpp create mode 100644 tests/unit/api/test_pve_ac.cpp create mode 100644 tests/unit/api/test_pve_batch.cpp create mode 100644 tests/unit/api/test_schnorr2pc.cpp create mode 100644 tests/unit/api/test_schnorr_mp.cpp create mode 100644 tests/unit/api/test_schnorr_mp_ac.cpp create mode 100644 tests/unit/api/test_tdh2.cpp create mode 100644 tests/unit/api/test_transport_harness.h create mode 100644 tests/unit/c_api/test_curve_validation.cpp create mode 100644 tests/unit/c_api/test_ecdsa2pc.cpp create mode 100644 tests/unit/c_api/test_ecdsa_mp.cpp create mode 100644 tests/unit/c_api/test_ecdsa_mp_ac.cpp create mode 100644 tests/unit/c_api/test_eddsa2pc.cpp create mode 100644 tests/unit/c_api/test_eddsa_mp.cpp create mode 100644 tests/unit/c_api/test_eddsa_mp_ac.cpp create mode 100644 tests/unit/c_api/test_eddsa_mp_threshold.cpp create mode 100644 tests/unit/c_api/test_pve.cpp create mode 100644 tests/unit/c_api/test_pve_ac.cpp create mode 100644 tests/unit/c_api/test_pve_batch.cpp create mode 100644 tests/unit/c_api/test_schnorr2pc.cpp create mode 100644 tests/unit/c_api/test_schnorr_mp.cpp create mode 100644 tests/unit/c_api/test_schnorr_mp_ac.cpp create mode 100644 tests/unit/c_api/test_schnorr_mp_threshold.cpp create mode 100644 tests/unit/c_api/test_tdh2.cpp create mode 100644 tests/unit/c_api/test_transport_harness.h delete mode 100644 tests/unit/protocol/test_parallel_transport_oob.cpp diff --git a/.clang-format b/.clang-format index e483f28d..655767e1 100644 --- a/.clang-format +++ b/.clang-format @@ -5,18 +5,20 @@ Standard: c++17 ColumnLimit: 120 IncludeCategories: - - Regex: "" - Priority: -2 - SortPriority: -2 - - Regex: "" - Priority: -1 - SortPriority: -1 + # Prefer a simple, stable include ordering: + # 1) System/third-party headers: <...> + # 2) Project headers: + # 3) Local/relative headers: "..." + # + # Public headers are required to be self-contained and must not rely on + # include ordering (enforced by build/tests), so we avoid special-casing + # specific headers here. - Regex: "^$" Priority: 2 SortPriority: 2 - Regex: '^".*"$' - Priority: 6 - SortPriority: 6 + Priority: 3 + SortPriority: 3 - Regex: "^<.*>$" Priority: 1 SortPriority: 1 diff --git a/.gitignore b/.gitignore index 4550c774..e4578c96 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ a.out # CMake generated files **build/ +cmake_test_discovery_*.json # ctest files **Testing/ diff --git a/BUG_BOUNTY.md b/BUG_BOUNTY.md new file mode 100644 index 00000000..d8396e9f --- /dev/null +++ b/BUG_BOUNTY.md @@ -0,0 +1,26 @@ +# CB-MPC (Coinbase Multi-Party Computation) Open Source Release + +Coinbase is proud to announce the open-sourcing of our MPC cryptography library! You can access it here: https://github.com/coinbase/cb-mpc. This significant milestone underscores our commitment to transparency, security, and promoting innovation within the cryptographic community. + +With this release, we aim to: + +* Enhance the security of the field by enabling developers to quickly deploy threshold signing/MPC for protecting cryptoassets in their applications. +* Increase transparency regarding Coinbase’s use of MPC, and encourage collaboration within the developer community. + +Note that while the code is based on Coinbase's production environment, it is not exactly the same, and it has been modified to make it useful as a general-purpose library. + +The primary focus of our bug bounty program will include identifying and addressing potential vulnerabilities in our open-source MPC implementation. Given the sensitive nature of these cryptographic protocols, it's imperative to safeguard against any exploits that could compromise cryptoassets. Responsible disclosure via the Bug Bounty Program or directly is encouraged (for direct disclosure see https://github.com/coinbase/cb-mpc/blob/master/SECURITY.md). + +Through community collaboration and vigilant security reviews, we aspire to provide an easy to use and highly secure MPC library to help developers secure cryptoassets across the entire cryptocurrency and blockchain ecosystem. + +To keep this bounty focused on issues that affect real integrations, eligible reports should target vulnerabilities reachable through the library's supported public APIs. High-level protocol entry points are exposed via the public C++ headers under `include/cbmpc/api/` (e.g., signing, DKG, TDH2). + +For **Medium** and above, submissions must include a proof-of-concept that triggers the issue through those public APIs. Reports may reference or require fixes in `include-internal/` for root cause and impact analysis, but the PoC must not use `include-internal/` as the entry point. Demo applications and sample code under `demo-*`, and the C API headers under `include/cbmpc/c_api/*`, are not in scope for this bug bounty program. + +| Vulnerability Tier | Description | Reward | +|:-------------------|:--------------------------------|:------------------------------------------| +| **Extreme** | Open Source Bugs (cb-mpc): not applicable | Up to $1,000,000 | +| **Critical** | Open Source Bugs (cb-mpc): High-severity vulnerabilities in supported high-level protocols from the public API (e.g., Signing, DKG, TDH2) that are easily exploitable and can lead to key compromise. Examples: significant disclosure of sensitive data (key material), remote code execution. Private by default; triggers new releases for all supported versions. | $50,000 | +| **High** | Open Source Bugs (cb-mpc): High-severity vulnerabilities in supported high-level protocols from the public API that are less easily exploitable. Private by default; triggers new release for all supported versions within a reasonable timeframe. | $15,000 | +| **Medium** | Open Source Bugs (cb-mpc): Vulnerabilities that are hard to exploit, limited in impact, or present in less commonly used scenarios, but are still demonstrable via the supported public APIs. Bugs in lower-level cryptographic primitives (e.g., ZKPs, commitments) are eligible when reachable from those protocols. Private until next release; released with subsequent updates. | $2,000 | +| **Low** | Open Source Bugs (cb-mpc): Non-cryptographic issues including low-level non-cryptographic APIs, crashes, or deprecated cryptographic code. Any vulnerability in code that is released under “beta” is always low. Fixed immediately in latest development versions; may be backported to older versions. | $200 | diff --git a/CMakeLists.txt b/CMakeLists.txt index 9a574cce..9ab949f9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -84,21 +84,57 @@ else() set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${BIN_DIR}/${CMAKE_BUILD_TYPE}) endif() +include_directories(${ROOT_DIR}/include) +include_directories(${ROOT_DIR}/include-internal) include_directories(${SRC_DIR}) add_subdirectory(src/cbmpc/core) add_subdirectory(src/cbmpc/crypto) add_subdirectory(src/cbmpc/zk) add_subdirectory(src/cbmpc/protocol) -add_subdirectory(src/cbmpc/ffi) +add_subdirectory(src/cbmpc/api) +add_subdirectory(src/cbmpc/c_api) add_library( cbmpc STATIC $ $ $ $ - $) + $ + $) link_openssl(cbmpc) +# ------------- Public headers smoke check --------------------------- +# +# Compile a TU that includes all public headers with include paths limited to +# `include/` (plus OpenSSL headers). This ensures public headers do not +# accidentally depend on `include-internal/`. +# +if(NOT DEFINED CBMPC_OPENSSL_ROOT) + if(DEFINED ENV{CBMPC_OPENSSL_ROOT}) + set(CBMPC_OPENSSL_ROOT $ENV{CBMPC_OPENSSL_ROOT}) + else() + set(CBMPC_OPENSSL_ROOT "/usr/local/opt/openssl@3.6.1") + endif() +endif() + +set(_cbmpc_public_headers_smoke_src "${ROOT_DIR}/tests/public_headers_smoke.cpp") +set(_cbmpc_public_headers_smoke_obj "${CMAKE_BINARY_DIR}/cbmpc_public_headers_smoke.o") + +add_custom_command( + OUTPUT "${_cbmpc_public_headers_smoke_obj}" + COMMAND "${CMAKE_CXX_COMPILER}" + -std=c++17 + "-I${ROOT_DIR}/include" + "-I${CBMPC_OPENSSL_ROOT}/include" + -c "${_cbmpc_public_headers_smoke_src}" + -o "${_cbmpc_public_headers_smoke_obj}" + DEPENDS "${_cbmpc_public_headers_smoke_src}" + COMMENT "Compiling public headers smoke TU" + VERBATIM +) + +add_custom_target(public-only-check DEPENDS "${_cbmpc_public_headers_smoke_obj}") + # ------------- Tests --------------------------- if(NOT IS_MACOS) @@ -107,5 +143,7 @@ endif() if(BUILD_TESTS) enable_testing() + add_test(NAME PublicHeadersSmoke COMMAND "${CMAKE_COMMAND}" --build "${CMAKE_BINARY_DIR}" --target public-only-check) + set_tests_properties(PublicHeadersSmoke PROPERTIES RUN_SERIAL TRUE LABELS "unit") add_subdirectory(tests) endif() diff --git a/Dockerfile b/Dockerfile index 1043808d..6fb094ae 100644 --- a/Dockerfile +++ b/Dockerfile @@ -25,6 +25,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ clang-20 \ clang-format-20 \ lld-20 \ + llvm-20 \ libfuzzer-20-dev \ libclang-rt-20-dev \ && rm -rf /var/lib/apt/lists/* \ @@ -36,9 +37,9 @@ WORKDIR /build COPY scripts/openssl/build-static-openssl-linux.sh . RUN sh build-static-openssl-linux.sh \ && mkdir -p /usr/local/lib64 /usr/local/lib /usr/local/include \ - && ln -sf /usr/local/opt/openssl@3.2.0/lib64/libcrypto.a /usr/local/lib64/libcrypto.a \ - && ln -sf /usr/local/opt/openssl@3.2.0/lib64/libcrypto.a /usr/local/lib/libcrypto.a \ - && ln -sf /usr/local/opt/openssl@3.2.0/include/openssl /usr/local/include/openssl \ + && ln -sf /usr/local/opt/openssl@3.6.1/lib64/libcrypto.a /usr/local/lib64/libcrypto.a \ + && ln -sf /usr/local/opt/openssl@3.6.1/lib64/libcrypto.a /usr/local/lib/libcrypto.a \ + && ln -sf /usr/local/opt/openssl@3.6.1/include/openssl /usr/local/include/openssl \ && rm -rf /build WORKDIR /code diff --git a/Makefile b/Makefile index bbe5421d..2e30e584 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,12 @@ DOCKER_IMAGE_NAME := cb-mpc RUN_CMD := bash -c .DEFAULT_GOAL := ghas +BUILD_TYPE ?= Release + +# Local install layout (no sudo by default) +CBMPC_INSTALL_ROOT ?= $(CURDIR)/build/install +CBMPC_PREFIX_PUBLIC ?= $(CBMPC_INSTALL_ROOT)/public +CBMPC_PREFIX_FULL ?= $(CBMPC_INSTALL_ROOT)/full CMAKE_NCORES := $(shell \ if [ -n "$${ARG_CMAKE_NCORES}" ]; then \ echo "$${ARG_CMAKE_NCORES}"; \ @@ -20,6 +26,35 @@ TEST_NCORES := $(shell \ fi) TEST_REPEAT ?= 1 +# clang-format targets +CLANG_FORMAT_PATHS := src tests include include-internal +CLANG_FORMAT_FILE_FILTER := \( -name '*.cpp' -o -name '*.h' \) +CLANG_FORMAT_FIND := find $(CLANG_FORMAT_PATHS) -type f $(CLANG_FORMAT_FILE_FILTER) + +# Sanitizer settings (mirrors the CI "sanitizers" workflow job). +# IMPORTANT: CMake caches the *absolute* source dir in CMakeCache.txt. +# If you reuse the same build dir from both host and Docker, you'll hit: +# "CMakeCache.txt directory ... is different than the directory ... where it was created" +# because the repo path differs (e.g. /Users/... vs /code). +# +# Default to a docker-specific build dir when running inside Docker so the two +# environments can coexist without `make clean`. +_CBMPC_IN_DOCKER := $(shell if [ -f /.dockerenv ]; then echo 1; else echo 0; fi) +SANITIZE_BUILD_DIR ?= $(if $(filter 1,$(_CBMPC_IN_DOCKER)),build/sanitize-docker,build/sanitize) +SANITIZE_BUILD_TYPE ?= Debug +# -fno-sanitize=enum: allows tests that intentionally pass invalid enum values via the C API. +SANITIZE_CFLAGS ?= -O1 -g -fno-omit-frame-pointer -fsanitize=address,undefined -fno-sanitize=enum +SANITIZE_CXXFLAGS ?= -O1 -g -fno-omit-frame-pointer -fsanitize=address,undefined -fno-sanitize=enum +SANITIZE_LDFLAGS ?= -fsanitize=address,undefined +# Avoid cross-build-dir clobbering: the project archives into `${ROOT_DIR}/lib/${CMAKE_BUILD_TYPE}` by default, +# which is shared between host and Docker (and between different Debug variants). For sanitizer builds, enable +# platform-dependent library output dirs so Linux vs macOS (and arm64 vs x86_64) can coexist. +SANITIZE_CMAKE_ARGS ?= -DCBMPC_PLATFORM_DEP_OUTPUT_DIR=ON +SANITIZE_TEST_LABEL ?= unit +SANITIZE_TEST_NCORES ?= 1 +SANITIZE_ASAN_OPTIONS ?= detect_leaks=0 +SANITIZE_UBSAN_OPTIONS ?= print_stacktrace=1:halt_on_error=1 + .PHONY: ghas ghas: submodules openssl-linux build @@ -84,13 +119,15 @@ image: .PHONY: lint-fix lint-fix: - find src/ -name '*.cpp' -o -name '*.h' | xargs clang-format -i - find tests/ -name '*.cpp' -o -name '*.h' | xargs clang-format -i + $(CLANG_FORMAT_FIND) -exec clang-format -i {} + .PHONY: lint lint: - find src/ -name '*.cpp' -o -name '*.h' | xargs clang-format -n 2>&1 | grep -q "^" && exit 1 || exit 0 - find tests/ -name '*.cpp' -o -name '*.h' | xargs clang-format -n 2>&1 | grep -q "^" && exit 1 || exit 0 + @output="$$( $(CLANG_FORMAT_FIND) -exec clang-format -n {} + 2>&1 )"; \ + if [ -n "$$output" ]; then \ + echo "$$output"; \ + exit 1; \ + fi .PHONY: build build: BUILD_TYPE = Release# (Release/Debug/RelWithDebInfo) @@ -138,6 +175,33 @@ dudect: -E DUDECT_VT \ $(if $(filter),-R $(filter))' +.PHONY: sanitize +sanitize: + ${RUN_CMD} \ + 'set -e; \ + src_dir="$$(pwd -P)"; \ + cache_path="$(SANITIZE_BUILD_DIR)/CMakeCache.txt"; \ + if [ -f "$$cache_path" ] && ! grep -Fq "CMAKE_HOME_DIRECTORY:INTERNAL=$$src_dir" "$$cache_path"; then \ + echo "sanitize: removing stale build dir '$(SANITIZE_BUILD_DIR)' (source dir mismatch)"; \ + rm -rf "$(SANITIZE_BUILD_DIR)"; \ + fi; \ + cmake -B "$(SANITIZE_BUILD_DIR)" -DCMAKE_BUILD_TYPE="$(SANITIZE_BUILD_TYPE)" -DBUILD_TESTS=ON -DBUILD_DUDECT=OFF \ + $(SANITIZE_CMAKE_ARGS) \ + -DCMAKE_C_FLAGS="$(SANITIZE_CFLAGS)" -DCMAKE_CXX_FLAGS="$(SANITIZE_CXXFLAGS)" \ + -DCMAKE_EXE_LINKER_FLAGS="$(SANITIZE_LDFLAGS)" -DCMAKE_SHARED_LINKER_FLAGS="$(SANITIZE_LDFLAGS)" && \ + cmake --build "$(SANITIZE_BUILD_DIR)" -- -j$(CMAKE_NCORES)' + ${RUN_CMD} \ + 'ASAN_OPTIONS="$(SANITIZE_ASAN_OPTIONS)" UBSAN_OPTIONS="$(SANITIZE_UBSAN_OPTIONS)" \ + ctest --output-on-failure --repeat until-fail:$(TEST_REPEAT) -j$(SANITIZE_TEST_NCORES) --test-dir "$(SANITIZE_BUILD_DIR)" \ + -L "$(SANITIZE_TEST_LABEL)" \ + $(if $(filter),-R $(filter))' + +.PHONY: sanitize-docker +sanitize-docker: + $(MAKE) submodules + $(MAKE) image + docker run --rm -v $(shell pwd):/code -t ${DOCKER_IMAGE_NAME} bash -c 'make sanitize' + .PHONY: clean clean: ${RUN_CMD} 'rm -rf build' @@ -146,17 +210,31 @@ clean: ### Install the C++ library to local (this is necessary for demo and benchmark) .PHONY: install install: - ${RUN_CMD} 'scripts/install.sh' - ${RUN_CMD} 'ln -sf /usr/local/lib /usr/local/lib64' + $(MAKE) build-no-test BUILD_TYPE=$(BUILD_TYPE) + ${RUN_CMD} 'scripts/install.sh --mode public --build-type $(BUILD_TYPE) --prefix "$(CBMPC_PREFIX_PUBLIC)"' + +.PHONY: install-full +install-full: + $(MAKE) build-no-test BUILD_TYPE=$(BUILD_TYPE) + ${RUN_CMD} 'scripts/install.sh --mode full --build-type $(BUILD_TYPE) --prefix "$(CBMPC_PREFIX_FULL)"' + +.PHONY: install-all +install-all: + $(MAKE) build-no-test BUILD_TYPE=$(BUILD_TYPE) + ${RUN_CMD} 'scripts/install.sh --mode public --build-type $(BUILD_TYPE) --prefix "$(CBMPC_PREFIX_PUBLIC)"' + ${RUN_CMD} 'scripts/install.sh --mode full --build-type $(BUILD_TYPE) --prefix "$(CBMPC_PREFIX_FULL)"' .PHONY: uninstall uninstall: - ${RUN_CMD} 'rm -rf /usr/local/opt/cbmpc' + ${RUN_CMD} 'rm -rf "$(CBMPC_INSTALL_ROOT)"' ### Demos +.PHONY: demo +demo: install-all + ${RUN_CMD} 'CBMPC_PREFIX_PUBLIC="$(CBMPC_PREFIX_PUBLIC)" CBMPC_PREFIX_FULL="$(CBMPC_PREFIX_FULL)" BUILD_TYPE="$(BUILD_TYPE)" bash scripts/run-demos.sh --run-all' + .PHONY: demos -demos: - ${RUN_CMD} 'BUILD_TYPE=${BUILD_TYPE:-Release} bash scripts/run-demos.sh --run-all' +demos: demo .PHONY: clean-demos clean-demos: @@ -165,6 +243,9 @@ clean-demos: ### Benchmarks include tools/benchmark/benchmark.makefile +.PHONY: benchmark-build +benchmark-build: install-full + .PHONY: bench bench: $(MAKE) benchmark-build @@ -178,32 +259,10 @@ clean-bench: sanity-check: set -e $(MAKE) build - sudo $(MAKE) install + sudo $(MAKE) install-full docker run -it --rm -v $(shell pwd):/code -t ${DOCKER_IMAGE_NAME} bash -c 'make lint' $(MAKE) demos $(MAKE) test $(MAKE) benchmark-build $(MAKE) dudect filter=NON_EXISTING_TEST -### For Go wrappers -.PHONY: test-go -test-go: - @echo "Running Go tests$(if $(filter), (filter=$(filter)),)..." - @${RUN_CMD} 'BUILD_TYPE=${BUILD_TYPE:-Release} bash scripts/go_with_cpp.sh bash -lc "\ - if [ -n \"$(filter)\" ]; then \ - go test -v -run \"$(filter)\" ./...; \ - else \ - go test -v ./...; \ - fi"' - -.PHONY: test-go-short -test-go-short: - ${RUN_CMD} 'BUILD_TYPE=${BUILD_TYPE:-Release} bash scripts/go_with_cpp.sh bash -lc "go test -short ./..."' - -.PHONY: test-go-race -test-go-race: - ${RUN_CMD} 'BUILD_TYPE=${BUILD_TYPE:-Release} bash scripts/go_with_cpp.sh bash -lc "go test -race -v ./..."' - -.PHONY: godoc -godoc: - ${RUN_CMD} 'cd demos-go/cb-mpc-go && godoc -http=:6060' diff --git a/README.md b/README.md index 37267bb7..282cac79 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ - [Internal Header Files](#internal-header-files) - [RSA OAEP Padding Modification](#rsa-oaep-padding-modification) - [Bitcoin Secp256k1 Curve implementation](#bitcoin-secp256k1-curve-implementation) +- [Go Wrappers](#go-wrappers) # Introduction @@ -44,8 +45,7 @@ Although this library is designed for general use, we have included examples sho 1. **HD-MPC**: This is the MPC version of an HD-Wallet where the keys are derived according to an HD tree. The library contains the source code for how to generate keys and also to derive keys for the tree (see [src/cbmpc/protocol/hd_keyset_ecdsa_2p.cpp](src/cbmpc/protocol/hd_keyset_ecdsa_2p.cpp)). This can be used to perform a batch ECDSA signature or sequential signatures as shown in the test file, [tests/unit/protocol/test_hdmpc_ecdsa_2p.cpp](tests/unit/protocol/test_hdmpc_ecdsa_2p.cpp). We stress that this is not BIP32-compliant, but is indistinguishable from it; more details can be found in [docs/theory/mpc-friendly-derivation-theory.pdf](docs/theory/mpc-friendly-derivation-theory.pdf). 2. **ECDSA-MPC with Threshold EC-DKG**: This example showcases how a threshold of parties (or more generally any quorum of parties according to a given access structure) can perform ECDSA-MPC. The code can be found in [src/cbmpc/protocol/ec_dkg.cpp](src/cbmpc/protocol/ec_dkg.cpp) and its usage can be found in [tests/unit/protocol/test_ecdsa_mp.cpp](tests/unit/protocol/test_ecdsa_mp.cpp). -3. **ECDSA-MPC with Threshold Backup**: This example showcases various things. First, the code is in Go, [demos-go/examples/ecdsa-mpc-with-backup/main.go](demos-go/examples/ecdsa-mpc-with-backup/main.go) and therefore showcases how the C++ core library can be used in a Go project. Second, it showcases how different protocols can be combined to create a full solution. In this case, we use PVE (publicly-verifiable encryption) as a way of creating verifiable backup of keyshares according to an access structure (e.g., a threshold of `t` out of `n` parties). The code shows how the backup can be created and restored. It also shows how the backup can be used to generate a signature. Note that the key generation can be done using the threshold EC-DKG protocol, which is showcased in the previous example. However, for simplicity a normal additive DKG is used in this example. -4. **Various other uses cases, including ZKPs**: The demo code under [demos-cpp](demos-cpp) and [demos-go](demos-go), and the tests under [tests](tests), contain various examples of how the different protocols can be used. Specifically, for the case of ZKPs, the tests can be found under [tests/unit/zk/test_zk.cpp](tests/unit/zk/test_zk.cpp). +3. **Various other uses cases, including ZKPs**: The demo code under [demo-cpp](demo-cpp) and the tests under [tests](tests) contain various examples of how the different protocols can be used. Specifically, for the case of ZKPs, the tests can be found under [tests/unit/zk/test_zk.cpp](tests/unit/zk/test_zk.cpp). The library comes with various tests and checks to increase the confidence in the code including: @@ -54,16 +54,19 @@ The library comes with various tests and checks to increase the confidence in th - Benchmarks: See `make bench` - Linting: See `make lint` +# High-level Public API vs the Full API + +The cb-mpc library contains two levels of APIs, a public one that contains the API for calling high-level MPC protocols like DKG, threshold signing and publicly-verifiable encryption for backup. These APIs are simple to use, and are recommended for those wishing to use the cb-mpc protocols as is. The second level contains the full cb-mpc API and includes all mid and low-level functions as well. These APIs are intended for users wishing to modify protocols or implement other protocols using the cb-mpc infrastructure. We stress that all functions in the public API are safe, including validation of all externally received inputs. In contrast, not all lower-level APIs include all such validations. For example, an elliptic curve point received from another party in a protocol needs to be validated once and not every time it is used. In addition, some zero-knowledge proofs may depend on other zero-knowledge proofs already being validated. This is taken care of in the public API, but the user is responsible for ensuring all of this when using the lower-level full API. Note that the Coinbase HackerOne bug bounty requires finding a bug in the public API for it to be consider Medium or above. See [BUG_BOUNTY.md](BUG_BOUNTY.md) for details. + # Directory Structure - `docs`: the pdf files that define the detailed cryptographic specification and theoretical documentation (you need to enable git-lfs to get them) -- `src`: contains the cpp library and its unit tests -- `cb-mpc-go`: contains an example of how a go wrapper for the cpp library can be written -- `demos-cpp`: a collection of examples of common use cases in c++ -- `demos-go`: examples of how the c++ library can be used in Golang - - `demos/cb-mpc-go`: Go wrapper of the cb-mpc - - `demos/mocknet`: an example of how a network infra can be implemented (for demo purposes) - - `demos/examples`: examples of some multiparty computation tasks in Golang +- `include`: public headers (installed in both `public` and `full` install modes) + - Public C++ API wrappers: `include/cbmpc/api/` (`coinbase::api`) + - Public C API (stable ABI) for wrappers written in other languages such as Go, Rust, etc.: `include/cbmpc/c_api/` (`cbmpc_*`) +- `include-internal`: internal headers (installed only in `full` mode), included as `` +- `src`: C++ implementation sources +- `demo-cpp`: a collection of examples of common use cases in c++ - `scripts`: a collection of scripts used by the Makefile - `tools/benchmark`: a collection of benchmarks for the library - `tests/{dudect,integration,unit}`: a collection of tests for the library @@ -78,7 +81,7 @@ git submodule update --init --recursive Furthermore, to obtain the documentations (in pdf form), you need to enable [git-lfs](https://git-lfs.com/) -# Building the code +# Building the Code ## Build Modes @@ -92,7 +95,7 @@ There are three build modes available: ### OpenSSL -The library depends on a **custom build of OpenSSL 3.2.0** with specific modifications (see [External Dependencies](#external-dependencies)). You must build this custom version before compiling the library. +The library depends on a **custom build of OpenSSL 3.6.1** with specific modifications (see [External Dependencies](#external-dependencies)). You must build this custom version before compiling the library. **Quick Start:** ```bash @@ -111,7 +114,7 @@ scripts/openssl/build-static-openssl-macos.sh # for x86_64 scripts/openssl/build-static-openssl-macos-m1.sh # for ARM64 ``` -**Note:** These scripts install OpenSSL to `/usr/local/opt/openssl@3.2.0` and may require `sudo` permission. +**Note:** These scripts install OpenSSL to `/usr/local/opt/openssl@3.6.1` and may require `sudo` permission. **Custom Install Location:** If you prefer a different installation path, you can set the `CBMPC_OPENSSL_ROOT` variable: @@ -148,19 +151,43 @@ To test the library, run Running demos and benchmarks: -- Go wrapper and Go demos do not require installation. They compile against the local build output under `/lib` via `scripts/go_with_cpp.sh`. - - Run Go tests: `make test-go` (or `make test-go-short`, `make test-go-race`) -- C++ demos and benchmarks still expect the library to be installed under `/usr/local/opt/cbmpc`. - - Install for C++ usage: `sudo make install` - - Run all demos (C++ + Go): `make demos` - - Run benchmarks: `make bench` +- By default, demos/benchmarks use a **repo-local install prefix** under `build/install/` (no `sudo`). + - Public install (only `include/`): `make install` (installs to `build/install/public`) + - Full install (also installs `include-internal/`): `make install-full` (installs to `build/install/full`) + - Run all demos (C++ + API + Go): `make demo` (or `make demos`) + - Run benchmarks: `make bench` (auto-runs the full install) Notes: -- If you have not run `make install`, the C++ portion of `make demos` may fail, but the Go demos will still use the local build. -- To run a single Go demo without install, for example `ecdsa-2pc`: - ```bash - BUILD_TYPE=Release bash scripts/go_with_cpp.sh --no-cd bash -lc "cd demos-go/examples/ecdsa-2pc && go run main.go" - ``` +- `make demo` takes care of building + installing the right variants (public vs full) before running each demo. + +### Install Public API vs full API + +By default the cb-mpc library only installs the public API. To install the full APIs run the following command: + +```bash +scripts/install.sh --mode full +``` + +### Install prefix (optional) + +You can install to a custom prefix: + +```bash +scripts/install.sh --mode public --prefix /path/to/prefix +# or: +CBMPC_PREFIX=/path/to/prefix scripts/install.sh --mode full +``` + +For the Makefile helpers, you can override the default repo-local layout: + +```bash +# Install under a custom root (creates /{public,full}) +make install-all CBMPC_INSTALL_ROOT=/path/to/prefix + +# Or override each prefix independently +make install CBMPC_PREFIX_PUBLIC=/path/to/public/prefix +make install-full CBMPC_PREFIX_FULL=/path/to/full/prefix +``` Our benchmark results can be found at @@ -297,8 +324,12 @@ Our implementation modifies OpenSSL's OAEP padding algorithm to support determin The security properties of OAEP remain intact as long as the provided seed maintains appropriate randomness and uniqueness requirements. For standard encryption operations, we recommend using the non-deterministic version that generates random seeds internally. -## Bitcoin Secp256k1 Curve implementation +## Bitcoin Secp256k1 Curve Implementation We used a modified version of the secp256k1 curve implementation from [coinbase/secp256k1](https://github.com/coinbase/secp256k1) which is forked from [bitcoin-core/secp256k1](https://github.com/bitcoin-core/secp256k1). The change made is to allow calling the curve operations from within our C++ codebase. Note that as indicated in their repository, the curve addition operations of `secp256k1` are not constant time. To work around this, we have devised a custom point addition operation that is constant time. Please refer to our [documentation](/docs/constant-time.pdf) for more details. + +# Go Wrappers + +There are extensive Go wrappers for this C++ library that enable the use of library natively from Go. You can find them in [coinbase/cb-mpc-go](https://github.com/coinbase/cb-mpc-go/). The repository also includes demos on how to use the library in Go. \ No newline at end of file diff --git a/SECURE_USAGE.md b/SECURE_USAGE.md new file mode 100644 index 00000000..d7bc2cd2 --- /dev/null +++ b/SECURE_USAGE.md @@ -0,0 +1,236 @@ +# The CB-MPC Open Source Library: Secure Usage Instructions + +Much of this document applies to the internal API. + +## Description + +Cryptography is hard to use correctly, and cryptographic libraries are easily misused. To use a well-known example, if a block cipher mode of operation requires a random IV and a unique nonce is used instead, then this can break security (let alone if a fixed nonce is used). Continuing with this example, it is possible to make the library very high level and not allow the user to provide an IV at all, but this then makes the library less flexible and so not usable in all applications. Finding the right balance between an API that is hard to misuse and one that is flexible is challenging in any cryptographic library, and in this one as well. + +We have designed this library with defensive measures in mind wherever possible, in order to minimize misuse. However, doing this comprehensively would result in significant complications and sacrifices in efficiency (which sometimes have been made, as described below). Although it is impossible to describe every potential misuse, in this document, we provide usage instructions for primitives and protocols that are "error prone" in that there are subtleties when it comes to implementation. Familiarity with these instructions is important for ensuring secure usage. + +Ideally, this library should be usable to a software engineer without significant cryptographic background. While we believe that this is true of the high-level API (e.g., calling multiparty sign) as long as this usage document is used, we stress that more familiarity with cryptography is needed when directly calling lower-level primitives and subprotocols. Additional restrictions and requirements appear in some of the specifications, and so these should be checked before using any primitive or protocol. + +## Usage Instructions + +### Session identifiers + +Session identifiers are used in MPC protocols in order to determine which protocol execution a message belongs to. This is especially critical in concurrent settings (e.g., in order to achieve UC security), where the use of a unique session ID (denoted `sid` throughout our specifications) prevents an adversary from attacks like taking a zero-knowledge proof generated by an honest party in one execution and sending it in another execution. When minimizing the number of rounds is not an issue, the session identifier can be generated simply by having each party sending a local random `sid_i` (of length 128 bits), and then hashing the concatenation of them all. This adds a single round of communication, which is often insignificant. + +In interactive protocols (like multiparty signing), we work according to the following two cases: + +1. If a `sid` is needed in the first round of the protocol, then the API can accept a `sid` as optional input, and works as follows: + 1. If a `sid` is passed as input, then it is used where needed and assumed to be globally unique. + 2. If a `sid` is not passed as input, then the protocol generates it internally, as described above. This adds an additional round to the protocol in order to generate the `sid`. +2. If a `sid` is needed only in the second round or later (or not at all), then the API does not accept a `sid` as input at all; if a `sid` is needed, it will be generated internally. Note that this has no efficiency impact, since the `sid` is generated in parallel to the first round of the protocol. + +In non-interactive APIs, like (most) zero-knowledge proofs, a globally unique `sid` must be passed as input (the only exception to this rule is for commitments, which are discussed next). This is the responsibility of the calling application, since it is not possible to generate a `sid` internally in a non-interactive primitive. We stress that the `sid` *cannot* be chosen singlehandedly by the prover but rather must be jointly generated, since a malicious prover can copy a `sid` from another execution, enabling it to copy a proof. We remark that when such non-interactive APIs are called by our higher-level interactive protocol, they use the `sid` that they generated in the calls to the non-interactive APIs. When developing new protocols that directly call the non-interactive APIs, it is the responsibility of the protocol to generate and use a globally unique `sid` appropriately. + +In particular, when multiple different subprotocols are called, each subprotocol must be given a unique session identifier as input. The best way to achieve this is to concatenate a different identifier to the `sid` for each subprotocol and to use the resulting string as the session identifier of the subprotocol (e.g., use `sid || 1`, `sid || 2` and so on). + +### Zero-knowledge auxiliary input + +In the zero-knowledge API, in addition to the `sid` there is an `aux` input which is used to differentiate different calls (i.e., the same `sid` is used for all zero-knowledge calls within a given protocol, but with a different `aux` each time). Therefore, there is no need to modify the `sid` for zero-knowledge calls; rather, the `sid` of the calling protocol can be used while using a different value for `aux`. + +### Session identifiers in commitments + +Our commitment scheme can work in one of two ways: either using a globally unique `sid` (as above, this cannot be generated by the committer, but must be jointly generated by the committer and receiver), or using the committer's party identifier `pid` and a locally generated `sid_i` (chosen randomly by the committer).[^1] The advantage of this latter option is that in a protocol that begins with commitments in the first round, there is no need to add an additional round in order to generate a globally unique `sid`. + +We stress, however, that in this case, the `pid` must be a true party identifier and not just a "role". In order to clarify this, consider a setting where two-party signing is used, and the first party needs to send a commitment to the second party. Then, the roles of the parties are just "1" and "2", but these do not suffice as `pid`s, since there can be many different parties playing role "1" and role "2". Therefore, in this context, the `pid` should be taken as something truly particular to a party. It could be their username in an application with unique usernames, or it could be the hash of their public key that is used to connect to the system. This is discussed in more detail in the [basic primitives theory document](docs/theory/basic-primitives-theory.pdf). + +By default, our commitment API is designed to work with a `pid` and local `sid_i`. If such a `pid` is not available, then it is possible to securely use the API by generating a globally unique `sid` and using that as the `pid`. As above, if different parties use the same `sid` (or the same party commits to multiple commitments with the same `sid`), then an index needs to be concatenated to the `sid` first. + +As will be explained below, our job/networking layer expects a globally unique `pid` for each party (as defined above), and derives these identifiers from the party names used to initialize the job/transport layer so that they are accessible by all two party and multiparty protocols. + +### Counting rounds + +In our documentation we call a single message from one party to another a "round". Thus, if party 1 sends a message to party 2 who then responds, this is called "two rounds". In the multiparty setting, a single round typically consists of all parties sending messages to all others, but it can also involve a subset of those messages. Note that the messages of a single round are typically sent with a single "send instruction" (`job.p1_to_p2`, `job.p2_to_p1` `job.mpc_broadcast`, `job.mpc_message_all_to_one`, etc.), but this is not necessarily the case. + +For example, if two subprotocols are run sequentially, where the first protocol ends with a message from `P1` to `P2`, and the second protocol begins with a message from `P1` to `P2`, then the last message of the first protocol with the first message of the second protocol is considered a *single round*. This is because `P1` sends them both one after another and does not have to wait for a reply from `P2` (thereby saving the ping time). This is especially common when generating session identifiers in protocols where a `sid` is needed in the first round. In this case, as described above, if a `sid` is not provided as input then the protocol begins by generating a `sid`. This generation requires two rounds. However, by ordering the messages so that the second message of the `sid` generation is sent by the same party sending the first message of the protocol, this adds only a *single* round to the protocol (and not two). + +### Networking + +To make this library extendible we do not provide a default networking implementation. Rather, we define a simple transport interface comprising of the following functions (see `include/cbmpc/core/job.h`): + +```cpp +error_t send(party_idx_t receiver, mem_t msg) = 0; +error_t receive(party_idx_t sender, buf_t& msg) = 0; +error_t receive_all(const std::vector& senders, std::vector& msgs) = 0; +``` + +**Design note:** `data_transport_i` is intentionally minimal. It does *not* implement security properties (authentication, confidentiality, integrity, anti-replay, etc.) for you; those must be provided by the calling application/transport (e.g., mutual-authenticated TLS). + +**Important:** The security guarantees of the MPC protocols in this library assume a secure transport. If you run protocols over an unauthenticated channel, an active attacker can tamper with protocol messages and may be able to cause key compromise or incorrect outputs. Do not treat transport security as optional. + +These low-level functions can be implemented in a variety of ways. For example, we have a C++ implementation of them in our test suite and a Go implementation of them in our Go demos; these implementations are demos only, communicating between different threads on a single machine. Any production level implementation of these interfaces must have the following properties to achieve the security needed by the library: + +- **Authenticated**: the parties must mutually authenticate each other and reject forged/tampered messages. A simple way to achieve this is to use TLS with mutual authentication and either certificate pinning or a PKI. +- **Encrypted**: the messages must be encrypted. In some of the protocols, secret data is sent over the network and we rely on encryption to keep them secret. As before, TLS provides these guarantees. +- **Anti-replay / channel binding**: the transport should prevent replay of old messages and should bind the channel to the intended peer identity (so an attacker cannot replay messages between two different sessions/peers). +- **Blocking**: the receive functions must be blocking. Our code uses the receive functions sequentially in a loop to receive messages from all parties. Our protocols assume by default that all messages in a round are received, before a party sends their next-round message. As such, the receive functions must be blocking in order to fulfil this property. +- **Globally unique `pid`s**: It is the responsibility of the caller to ensure that each party has a globally unique identifier (denoted `pid` throughout our specifications) and that all honest parties agree on these identifiers. In the library, `pid`s are derived from the party names used to initialize the job/transport layer; it therefore does not suffice to have parties merely send their identifiers to each other. If certificates are used for authentication and TLS, then a natural choice for a party's identifier would be the hash of their public key. + +*Limitations*: at the moment the library does not support more than 64 parties. This can be easily lifted with some code changes in the networking layer. + +### Transport lifetime and cancellation + +The `job_2p_t` / `job_mp_t` objects store a non-owning reference/pointer to the transport implementation provided by the caller. As a result: + +- The transport object must outlive any protocol call that uses the job (including any background threads in the caller that are still servicing receives). +- The character data backing the `std::string_view` party names must outlive the protocol call (do not construct a job from temporary strings). +- Do not pass a temporary transport object when constructing a job. +- If you add timeouts/cancellation to your transport, ensure that blocked `receive(...)` / `receive_all(...)` calls can be unblocked during shutdown/abort; otherwise, protocol threads may deadlock on errors. + +### Uniform vs non-uniform messages + +To send a message over the network, we have created some helper types, namely `uniform_msg_t` and `non_uniform_msg_t`. The first one is used for sending the "same" message to all parties, and the second one is used for sending "different" messages to different parties. It is imperative that these are not mixed up and used incorrectly or a message containing a secret that only one party is allowed to receive may end up being received by all parties. + +Specifically, when using `committed_pairwise_broadcast`, the `non_uniform_msg_t` must be used. Conversely, `committed_group_broadcast` should use `uniform_msg_t`. + +### Opaque blobs and secret material + +Many public APIs return versioned, opaque `buf_t` blobs (e.g., `key_blob`, `keyset_blob`, TDH2 `private_share`, PVE base-PKE `dk` blobs). These blobs are designed for portability across process restarts, but they often contain *secret key material* (private key shares, Paillier secret keys, etc.). + +Important implications: + +- Treat these blobs as you would treat a raw private key: do not log them, do not send them over the network, and avoid writing them to disk unless necessary. +- In particular, do not send a party's `key_blob` / `keyset_blob` to the other MPC party: these blobs are private key *shares* and often include auxiliary secrets (e.g., Paillier secret keys). Sharing them breaks the trust model. +- The library does not encrypt or authenticate these blobs for you. If you persist them, protect them with an application-managed AEAD (e.g., XChaCha20-Poly1305 or AES-256-GCM) and bind associated data such as: protocol name, curve id, blob version, and the expected party identity (role / party name). +- Prefer encrypting these blobs at rest via envelope encryption: keep the wrapping/encryption key in an HSM/secure enclave or managed KMS (outside the host) and store only AEAD-encrypted blobs on disk. Ensure crash dumps / core dumps and crash reporting cannot exfiltrate plaintext secret material. +- When providing 'mem_t' inputs to APIs, you need to validate the inputs and ensure validity: `size` must be non-negative, and if `size > 0` then `data` must be non-null. Public (high-level) API wrappers validate these invariants and return `E_BADARG` on violations; lower-level internal mem_t code does not validate values; it is the responsibility of high-level APIs to use these safely. + +### Paillier + +When Paillier encryption is used for homomorphic operations, it is critical that the ciphertext be rerandomized after the homomorphic operations are carried out. This ensures that the private-key owner cannot learn anything about the operations that we carried out, and can only learn the final result. By default, this rerandomization is carried out internally after each homomorphic operation. However, since rerandomization is costly, if multiple homomorphic operations are to be carried out, it is best to rerandomize only once at the end. + +This can be achieved by calling: + +```cpp +crypto::paillier_t::rerand_scope_t paillier_rerand(crypto::paillier_t::rerand_e::off) +``` + +and then calling `rerand()` afterwards. Alternatively, if part of the homomorphic operations include adding in a freshly encrypted value, then `rerand` is not needed at all. + +### ElGamal commitments + +Similar to Paillier ciphertexts, ElGamal commitments need to be rerandomized after the homomorphic operations. However, in the current version of the library, this is not done by default, and the protocol is responsible to ensure rerandomization where needed. + +### Zero-knowledge flags + +In some of our zero-knowledge proofs, the soundness relies on the fact that the verifier has already ascertained that the public input fulfills certain properties. For example, when proving that two Paillier ciphertexts under different keys encrypt the same value, we need to assume that both Paillier keys are valid, and that the range of the plaintext in one of the ciphertexts has previously been verified to be small. These specific properties (validity of a Paillier key and range of a plaintext) cannot be efficiently verified without additional information, and separate zero-knowledge proofs are needed for them. + +In order to ensure that these have already been verified, the zero-knowledge objects include flags for relevant properties that are set to `false` by default, and the calling application must set the flags to `true` before beginning attesting to the fact that these properties have been verified via a previous ZK proof, or verification will fail. + +There are also some cases where a property is efficiently verifiable, like a Paillier public key having no factors smaller than `2^13` or a Paillier ciphertext being valid, and yet we still include flags for them. This is done in cases where the check is theoretically efficient but still rather costly. For example, checking validity of a Paillier ciphertext requires computing `gcd` which is equivalent to the cost of an exponentiation. This is not necessarily a problem, but is wasteful if already checked. + +Our mechanism for flags is described in more detail in the [zero-knowledge specification](docs/spec/zk-proofs-spec.pdf). + +We stress that the zero-knowledge flags are declarative only, and are there to avoid mistakenly calling a zero-knowledge proof that relies on another proof that has not been run. That is, nothing prevents code from setting a flag without having ever verified the property; rather the need to set the flag serves as a notification to the developer that the property must be proven first. + +Note that within a single protocol flow, flags can be passed from one proof to another to ensure correct usage. However, if a proof is run in a different (e.g., later) protocol execution, then the flags must be manually set based on knowledge that the appropriate property has been verified in the past. + +### Ciphertext verification vs. decryption (PVE) + +The PVE APIs provide explicit verification functions (e.g., `verify`, `verify_batch`, `verify_ac`) and decryption / reconstruction functions (e.g., `decrypt`, `decrypt_batch`, `combine_ac`). + +The decryption / reconstruction functions intentionally do **not** verify ciphertexts internally. Invalid ciphertexts may cause decryption / reconstruction to fail, but are designed to not leak secret information. + +If your application needs ciphertext validation on untrusted inputs, call the appropriate `verify*` function before decrypting / reconstructing. + +### PVE callback contracts (custom base PKE / HSM) + +The PVE APIs support pluggable encryption backends (custom `base_pke_i`, custom KEM, and HSM callbacks). These interfaces are intentionally flexible, but misuse can break correctness and security. + +- If you implement a custom PVE `base_pke_i` (or custom KEM callbacks), the encryption/encapsulation step **must be deterministic given the provided `rho` seed**. Do not draw fresh randomness from an RNG in these callbacks. +- For the built-in HSM callbacks: ensure you are performing RSA-OAEP decryption with the expected parameters (SHA-256, empty OAEP label) and returning the exact shared secret that encryption produced; for ECIES(P-256), the callback must return the 32-byte affine-X coordinate of the ECDH output. +- Treat `dk_handle` as an application-controlled opaque identifier (e.g., key id/label). Do not accept it from an untrusted counterparty. + +### HD-MPC + +Hierarchical-deterministic (HD) wallets are very prevalent, since they enable parties to backup a single master key once, and to use that master key to derive many keys for different accounts, blockchains, and so on. The standard for HD wallets today is [BIP-0032](https://github.com/bitcoin/bips/blob/master/bip-0032.mediawiki). + +It is possible to compute the standard BIP-0032 key derivation in MPC, but it is relatively expensive. In addition, most wallets that implement BIP-0032 also implement [BIP-039](https://github.com/bitcoin/bips/blob/master/bip-0039.mediawiki) which uses 2048 iterations of HMAC-SHA512 in order to derive a key from a given mnemonic. This derivation is certainly not feasible in MPC, and so the advantage of being BIP-0032 compatible -- which is to enable interoperability with other wallets -- is greatly diminished. + +We therefore take a different approach, and provide an HD method that is highly efficient in MPC (i.e., it is "MPC-friendly") and indistinguishable from a standard BIP-0032 HD wallet. In particular, normal derivation is carried out in exactly the same way as in BIP-0032, and hardened derivation uses a different pseudorandom function that is algebraic and so amenable to MPC. We note that, by definition, all pseudorandom functions are indistinguishable from each other, and so an HD wallet using HD-MPC in this library looks exactly like a standard BIP-0032 HD wallet. + +For more information about the method and its security, please see the [mpc-friendly-derivation theory document](docs/theory/mpc-friendly-derivation-theory.pdf). + +### Signing protocol input + +Beyond the shares of the key, each party in an MPC signing protocol also receives the message to be signed upon. We stress that in our *ECDSA APIs* the input is an already hashed message. Thus, we do *not* hash the input again, and it is *not secure* to use the message directly. + +The reason for this is that different blockchains use different hash functions (SHA256, double SHA256, Keccak256, etc.), and therefore the protocol just works on the hashed message, enabling the application to use the appropriate hash for the appropriate blockchain. + +Similarly, our BIP340 Schnorr APIs take a 32-byte message digest (as defined in BIP340). The caller is responsible for hashing and domain separation; passing a raw, variable-length message is not the intended usage. + +In contrast, since EdDSA is a fully standardized algorithm and it requires the message to be hashed with the nonce and public key, the EdDSA API *does* receive the original message and not its hash. + +### EdDSA + +EdDSA is a signing scheme that is a variant of Schnorr over the specific Ed25519 curve. It differs from standard Schnorr since in that it essentially uses a pseudorandom function to derive the signature nonce from the message being signed, making it deterministic. This design was intended to prevent breaks that are due to the use of low-quality randomness in signing. + +Nevertheless, our implementation of EdDSA is actually just Schnorr over Ed25519 and is probabilistic. When the same message is never signed twice, it is indistinguishable from standard EdDSA, and in particular, standard EdDSA verification works for our implementation. This makes it suitable for blockchains that use EdDSA, especially if parties ensure that they never sign on the same message twice.[^2] Whether or not any such enforcement is needed depends on the specific application. + +### Multiparty ECDSA signing + +The protocol that we use for multiparty ECDSA signing uses oblivious transfer for private multiplication. In order to make this efficient, we use OT-extension, meaning that we only need to do 256 base oblivious transfers that require elliptic-curve operations, and the rest use symmetric operations only. + +However, in order to keep things simple in terms of state that needs to be stored, we run the base oblivious transfers from scratch *every time we sign*. This increases both rounds of communication and computational cost (each base oblivious transfer requires 3 elliptic curve multiplications by the receiver and 8 elliptic curve multiplications by the sender). + +An alternative, and much more efficient approach, is to run the base OTs once during key generation. Then, upon signing, assuming that `n` oblivious transfers are needed for the signing operation, the parties run OT-extension in order to obtain `n + 256` random oblivious transfers, and store the results of 256 transfers to use as "base OTs" in the next signing operation (to be used as the effective base OTs for the OT extension). The disadvantage of working in this way is that each party needs to hold these OT results as state, and this state needs to be updated in every signing operation. + +In addition to the above, the protocol that we use utilizes threshold ElGamal commitments, and we generate a new ElGamal key pair per signing operation, rather than doing this at key generation. This adds some additional cost as well, but has the advantage of keeping the state to be the key shares only, which is the minimum required. + +### Two-party ECDSA signing + +The public two-party protocol for ECDSA signing exposes a single signing API: `sign()`. This protocol is fully secure and behaves like a standard MPC protocol (no special operational caveats beyond the general guidance in this document). + +We intentionally do not expose a more efficient "global-abort" variant of two-party signing in the public API, because using it safely requires additional cryptographic and operational expertise as described below. The `sign_with_global_abort()` is secure as long as if a certain type of cheating is detected, all executions with that key are halted. This is because such a cheat can be used to learn a bit of the private key. This is insignificant for a small number of bits (as they can be guessed anyway) but can leak the entire private key over time if the attack is allowed to be carried out multiple times over many signing attempts. This also means that it isn't secure to open hundreds of signing sessions in parallel, if it isn't possible to abort them all in case cheating is detected in any one of them. Despite the above, the reason to use `sign_with_global_abort()` is that it is more efficient, and is often sufficient. For example, if the application using this protocol is a wallet between a mobile and a server, and it is possible to enforce sequentiality (or only minimal parallelism), then aborting other sessions can easily be achieved. Having said this, one still needs to determine what to do after such cheating has been discovered. Ideally, the malware is found and cleaned out, and the issue is over. However, this is not so easily achieved in practice. Therefore, our recommendation is that if `sign_with_global_abort()` is used, then after a critical cheat event is discovered, the parties will continue to use `sign()` only. This ensures that the leakage is insignificant (again, assuming sequentiality), and that the parties can continue signing securely afterwards, albeit a bit less efficiently. (Note that the verifier party returns a special error code called `E_ECDSA_2P_BIT_LEAK` in this case (the usual error that requires aborting but no other special treatment is `E_CRYPTO`).) +We stress that this is the only protocol in the library with this property. We also stress that the *application* using the library is responsible for ensuring that appropriate action is taken (locking the key, moving to `sign()`, etc.) if the `E_ECDSA_2P_BIT_LEAK` error is received. *This is **not** taken care of by the low-level library*. + +### Unknown-order Pedersen and two-party ECDSA signing + +Our two-party `sign()` includes a zero-knowledge proof (in the last round of the protocol) that one of the parties computed its message correctly. This proof uses Pedersen commitments over an unknown-order modulus; these parameters are an RSA modulus `N` (generated with safe primes) and two elements `g, h`. + +The parameters for this proof can be generated locally by the verifier, as long as it proves that it generated `g` and `h` so that `g = h^α mod N`, for some `α`. We stress that the parameters cannot be generated by the prover, since if the prover knows the factorization of `N` then it can cheat in the proof. + +One option for using this proof is to have the party who will play the verifier generate these parameters, and prove that `g, h` are generated correctly using the provided zero-knowledge proof for unknown-order discrete log. See [`zk::unknown_order_dl_t`](https://github.com/coinbase/cb-mpc/blob/master/src/cbmpc/zk/zk_unknown_order.h) for this. + +Alternatively, these parameters can be generated once and for all in a secure environment -- requiring users to trust that they were indeed correctly generated. + +An in-between option which is sometimes applicable is as follows. Consider a case where the verifying party is the organization offering a wallet service to its users, and the users are always the parties playing the prover (in our [two-party ECDSA specification](docs/spec/ecdsa-2pc-spec.pdf), the prover is `P2` and the verifier is `P1`, where `P1` is the Paillier-key owner). Then, the organization can generate the parameters and a proof that `g = h^α mod N`, and can hardwire the parameters and proof into the code (while also erasing the factorization of `N` and the value `α`). + +Then, any user can verify the legitimacy of the hardwired parameters once and for all, and no trust is needed. The parameters that are hardwired in the existing open source library were generated by Coinbase in this way, and so are suitable without any trust when Coinbase plays the verifier (`P1` in the two-party signing protocol). See [`zk::unknown_order_dl_t`](https://github.com/coinbase/cb-mpc/blob/master/src/cbmpc/protocol/int_commitment.cpp) for this. + +Anyone using this library can feel free to use these parameters (trusting that Coinbase has erased any secrets, and won't cheat as a malicious prover), can re-generate the parameters themselves (so that no trust in Coinbase is needed), or can generate them separately for each wallet during key generation as discussed above. + +More information about these parameters and how they need to be set can be found in the [basic-primitives specification](docs/spec/basic-primitives-spec.pdf) under Integer Commitments, in the [integer-commitments theory document](docs/theory/basic-primitives-integer-commitments-theory.pdf) and in the [ecdsa-2pc specification](docs/spec/ecdsa-2pc-spec.pdf) and [ecdsa-2pc theory documents](docs/theory/ecdsa-2pc-theory.pdf). + +### Checking group membership + +In many protocols, it is necessary to check that an element received from another party (who may be adversarial) is "valid". In the elliptic curve setting, this involves verifying that the value received is a point on the appropriate elliptic curve. Furthermore, in the case that the curve co-factor is greater than 1, it involves checking that the element is in the appropriate subgroup. + +The question that arises in these cases is how to determine the "real" curve that the value needs to be verified against. In signing algorithms, our approach is to use the public key for which we are signing. In zero-knowledge proofs, we take one of the elements in the input statement (with the assumption that the statement is known to both parties, and the curve for the input has already been validated). + +In general, and this needs to be considered since the library is general purpose, the application developer needs to keep this in mind and make sure that the checks are against the appropriate curve. + +### Constant-time + +By default, all computations on private information is supposed to be constant time,[^3] including low-level operations. We remark that *elliptic point addition* is *not* constant time. This is because it typically is not required to be constant time, and is therefore not supported in lower level libraries that we use. + +We remark that `muladd(a, b, H) = a·G + b·H` *is* constant time, and this is indeed needed for Pedersen and ElGamal commitments. + +We stress that constant-time operations are often slower than their analogous variable-time functions. The library therefore enables the developer to declare that a certain piece of code does not need to be constant time by calling `crypto::vartime_scope_t vartime_scope`. This is used, for example, in zero-knowledge verification which is always on public values. + +We stress that we only insist on constant time for private values. Thus, for example, when computing the Lagrange basis polynomial (see the [basic primitives specification](docs/spec/basic-primitives-spec.pdf)) the denominator is just the product of differences between party identifiers, which is public. As a result, the computation of the modular inverse of this value does not need to run in constant time. + +See [constant-time.pdf](docs/constant-time.pdf) for more discussion. + +### A final word on efficiency + +The implementations in this library are highly efficient, within the limits of not cutting any corners. However, as we have seen above, there are places where we have not opted for optimal efficiency (e.g., for OT in multiparty ECDSA, and by making the default `sign()` protocol for two-party ECDSA the less efficient one). + +In some use cases, this is a problem, and developers can opt for more efficient versions. However, in most blockchain use cases, the differences in efficiency here are insignificant. We therefore recommend to optimize only where needed, and to opt for simplicity and defense in depth in all other cases. + +[^1]: The locally generated `sid_i` needs only be unique if the committer is honest, and therefore the committer can choose it themselves. +[^2]: This can be enforced quite easily in the blockchain space where there are always timestamps. In particular, it suffices to store previously signed transactions for a certain short time period for which a timestamp is still valid, and to ensure that a new transaction to be signed isn't currently stored. This suffices since older transactions can never be repeated since their timestamp will not be valid. Furthermore, the storage overhead is small since only very recent transactions need to be stored. +[^3]: This statement is qualified by "supposed to be" since the operating system and hardware can influence whether or not it is constant time. diff --git a/cmake/compilation_flags.cmake b/cmake/compilation_flags.cmake index c70a62cf..94e265e1 100644 --- a/cmake/compilation_flags.cmake +++ b/cmake/compilation_flags.cmake @@ -5,10 +5,13 @@ endif() set_cxx_flags("-std=c++17") set_cxx_flags("-fPIC") +if(NOT IS_WASM) + set_cxx_flags("-fstack-protector-strong") +endif() set_cxx_flags("-fvisibility=hidden") set_cxx_flags("-fno-operator-names") set_cxx_flags("-Wno-attributes") -set_cxx_flags("-Wno-null-dereference") +set_cxx_flags("-Wnull-dereference") set_cxx_flags("-Wno-parentheses") set_cxx_flags("-Wno-reorder") set_cxx_flags("-Wno-missing-braces") @@ -17,10 +20,14 @@ set_cxx_flags("-Wno-switch-enum") set_cxx_flags("-Wno-sign-compare") set_cxx_flags("-Wno-strict-overflow") set_cxx_flags("-Wno-unused") -set_cxx_flags("-Wno-parentheses") set_cxx_flags("-Werror") +set_cxx_flags("-Wno-error=null-dereference") set_cxx_flags("-Wno-shorten-64-to-32") set_cxx_flags("-DNO_DEPRECATED_OPENSSL") +if(IS_LINUX AND ENABLE_O3) + # Enable fortified libc wrappers where supported (glibc, optimized builds). + set_cxx_flags("-D_FORTIFY_SOURCE=2") +endif() if(IS_CLANG) set_cxx_flags("-Wno-tautological-undefined-compare") @@ -28,7 +35,8 @@ if(IS_CLANG) set_cxx_flags("-Wno-vla-extension") set_cxx_flags("-Wno-error=deprecated-declarations") else() - set_cxx_flags("-Wno-maybe-uninitialized") + set_cxx_flags("-Wmaybe-uninitialized") + set_cxx_flags("-Wno-error=maybe-uninitialized") endif() if(IS_ARM64) @@ -41,8 +49,20 @@ if(IS_X86_64) set_cxx_flags("-mpclmul -maes -msse4.1") endif() +# Some linker hardening flags are incompatible with sanitizer runtimes. +# In particular, `-Wl,--exclude-libs,ALL` breaks ASAN+UBSAN builds (gtest crashes +# during test registration with a bad-free). Skip it for sanitizer builds so we +# can run sanitizer lanes in CI and locally. +set(_cbmpc_is_sanitized false) +if(CMAKE_CXX_FLAGS MATCHES "-fsanitize=" OR CMAKE_C_FLAGS MATCHES "-fsanitize=" OR + CMAKE_EXE_LINKER_FLAGS MATCHES "-fsanitize=" OR CMAKE_SHARED_LINKER_FLAGS MATCHES "-fsanitize=") + set(_cbmpc_is_sanitized true) +endif() + if(IS_LINUX) - set_link_flags("-Wl,--exclude-libs,ALL") + if(NOT _cbmpc_is_sanitized) + set_link_flags("-Wl,--exclude-libs,ALL") + endif() set_link_flags("-Wl,-z,defs") set_link_flags("-z noexecstack -z nodelete") link_libraries(pthread dl rt) diff --git a/cmake/openssl.cmake b/cmake/openssl.cmake index 8341c307..cef29e61 100644 --- a/cmake/openssl.cmake +++ b/cmake/openssl.cmake @@ -4,7 +4,7 @@ # The OpenSSL path can be customized via: # 1. CMake variable: -DCBMPC_OPENSSL_ROOT=/path/to/openssl # 2. Environment variable: export CBMPC_OPENSSL_ROOT=/path/to/openssl -# 3. Default: /usr/local/opt/openssl@3.2.0 +# 3. Default: /usr/local/opt/openssl@3.6.1 # # To build the custom OpenSSL, run the appropriate script: # - macOS (x86_64): scripts/openssl/build-static-openssl-macos.sh @@ -16,7 +16,7 @@ macro(link_openssl TARGET_NAME) if(DEFINED ENV{CBMPC_OPENSSL_ROOT}) set(CBMPC_OPENSSL_ROOT $ENV{CBMPC_OPENSSL_ROOT}) else() - set(CBMPC_OPENSSL_ROOT "/usr/local/opt/openssl@3.2.0") + set(CBMPC_OPENSSL_ROOT "/usr/local/opt/openssl@3.6.1") endif() endif() diff --git a/demo-api/ecdsa_mp_pve_backup/CMakeLists.txt b/demo-api/ecdsa_mp_pve_backup/CMakeLists.txt new file mode 100644 index 00000000..d2602700 --- /dev/null +++ b/demo-api/ecdsa_mp_pve_backup/CMakeLists.txt @@ -0,0 +1,44 @@ +cmake_minimum_required(VERSION 3.16) + +project(mpc-demo-api-ecdsa_mp_pve_backup LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(REPO_CMAKE_DIR ${CMAKE_CURRENT_LIST_DIR}/../../cmake) + +include(${REPO_CMAKE_DIR}/macros.cmake) +include(${REPO_CMAKE_DIR}/arch.cmake) +include(${REPO_CMAKE_DIR}/openssl.cmake) +include(${REPO_CMAKE_DIR}/compilation_flags.cmake) + +if(NOT DEFINED CBMPC_SOURCE_DIR) + if(DEFINED ENV{CBMPC_PREFIX}) + set(CBMPC_SOURCE_DIR "$ENV{CBMPC_PREFIX}") + else() + get_filename_component(_cbmpc_repo_root "${CMAKE_CURRENT_LIST_DIR}/../.." ABSOLUTE) + if(EXISTS "${_cbmpc_repo_root}/build/install/public") + set(CBMPC_SOURCE_DIR "${_cbmpc_repo_root}/build/install/public") + else() + set(CBMPC_SOURCE_DIR /usr/local/opt/cbmpc/) + endif() + endif() +endif() + +set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib") +if(EXISTS "${CBMPC_SOURCE_DIR}/lib/Release/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Release") +elseif(EXISTS "${CBMPC_SOURCE_DIR}/lib/Debug/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Debug") +endif() + +add_executable(mpc-demo-api-ecdsa_mp_pve_backup main.cpp) + +target_include_directories(mpc-demo-api-ecdsa_mp_pve_backup PRIVATE ${CBMPC_SOURCE_DIR}/include) +target_link_directories(mpc-demo-api-ecdsa_mp_pve_backup PRIVATE ${CBMPC_LIB_DIR}) +target_link_libraries(mpc-demo-api-ecdsa_mp_pve_backup PRIVATE cbmpc) + +# Important for static linking on Linux: ensure libcbmpc.a appears before +# libcrypto.a on the final link line so libcrypto symbols resolve correctly. +link_openssl(mpc-demo-api-ecdsa_mp_pve_backup) + diff --git a/demo-api/ecdsa_mp_pve_backup/main.cpp b/demo-api/ecdsa_mp_pve_backup/main.cpp new file mode 100644 index 00000000..d3d48b38 --- /dev/null +++ b/demo-api/ecdsa_mp_pve_backup/main.cpp @@ -0,0 +1,458 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +using namespace coinbase; + +[[noreturn]] void die(const std::string& msg) { + std::cerr << "ecdsa_mp_pve_backup demo failure: " << msg << "\n"; + std::exit(1); +} + +void require(bool ok, const std::string& msg) { + if (!ok) die(msg); +} + +void require_rv(error_t got, error_t want, const std::string& msg) { + if (got != want) die(msg + " (got=0x" + std::to_string(uint32_t(got)) + ")"); +} + +// ------------------------- +// Minimal in-memory MP transport +// ------------------------- + +struct channel_t { + std::mutex m; + std::condition_variable cv; + std::deque q; +}; + +struct in_memory_network_t { + const int n; + std::vector>> ch; + std::atomic aborted{false}; + + explicit in_memory_network_t(int n_) : n(n_), ch(static_cast(n_), std::vector>(static_cast(n_))) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + if (i == j) continue; + ch[static_cast(i)][static_cast(j)] = std::make_shared(); + } + } + } + + void abort() { + aborted.store(true); + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + auto& c = ch[static_cast(i)][static_cast(j)]; + if (c) c->cv.notify_all(); + } + } + } +}; + +class in_memory_transport_t final : public coinbase::api::data_transport_i { + public: + in_memory_transport_t(int self, std::shared_ptr net) : self_(self), net_(std::move(net)) {} + + error_t send(coinbase::api::party_idx_t receiver, mem_t msg) override { + if (!net_) return E_GENERAL; + if (net_->aborted.load()) return E_NET_GENERAL; + if (receiver < 0 || receiver >= net_->n || receiver == self_) return E_BADARG; + auto c = net_->ch[static_cast(self_)][static_cast(receiver)]; + if (!c) return E_GENERAL; + { + std::lock_guard lk(c->m); + c->q.emplace_back(msg); + } + c->cv.notify_one(); + return SUCCESS; + } + + error_t receive(coinbase::api::party_idx_t sender, buf_t& msg) override { + if (!net_ || sender < 0 || sender >= net_->n || sender == self_) return E_BADARG; + auto c = net_->ch[static_cast(sender)][static_cast(self_)]; + if (!c) return E_GENERAL; + std::unique_lock lk(c->m); + c->cv.wait(lk, [&] { return net_->aborted.load() || !c->q.empty(); }); + if (net_->aborted.load() && c->q.empty()) return E_NET_GENERAL; + msg = std::move(c->q.front()); + c->q.pop_front(); + return SUCCESS; + } + + error_t receive_all(const std::vector& senders, std::vector& msgs) override { + msgs.clear(); + msgs.resize(senders.size()); + for (size_t i = 0; i < senders.size(); i++) { + const error_t rv = receive(senders[i], msgs[i]); + if (rv) return rv; + } + return SUCCESS; + } + + private: + const int self_; + std::shared_ptr net_; +}; + +template +static void run_mp(const std::shared_ptr& net, int n, F&& f, std::vector& out_rv) { + out_rv.assign(static_cast(n), UNINITIALIZED_ERROR); + std::atomic aborted{false}; + std::vector threads; + threads.reserve(static_cast(n)); + + for (int i = 0; i < n; i++) { + threads.emplace_back([&, i] { + out_rv[static_cast(i)] = f(i); + if (out_rv[static_cast(i)] && !aborted.exchange(true)) { + net->abort(); + } + }); + } + for (auto& t : threads) t.join(); +} + +// ------------------------- +// OpenSSL verify helper (DER ECDSA, secp256k1/p256) +// ------------------------- + +struct ec_group_deleter_t { + void operator()(EC_GROUP* g) const { EC_GROUP_free(g); } +}; +struct ec_point_deleter_t { + void operator()(EC_POINT* p) const { EC_POINT_free(p); } +}; +struct ossl_param_bld_deleter_t { + void operator()(OSSL_PARAM_BLD* bld) const { OSSL_PARAM_BLD_free(bld); } +}; +struct ossl_param_deleter_t { + void operator()(OSSL_PARAM* p) const { OSSL_PARAM_free(p); } +}; +struct evp_pkey_ctx_deleter_t { + void operator()(EVP_PKEY_CTX* ctx) const { EVP_PKEY_CTX_free(ctx); } +}; +struct evp_pkey_deleter_t { + void operator()(EVP_PKEY* pkey) const { EVP_PKEY_free(pkey); } +}; + +static int curve_to_nid(coinbase::api::curve_id curve) { + switch (curve) { + case coinbase::api::curve_id::secp256k1: + return NID_secp256k1; + case coinbase::api::curve_id::p256: + return NID_X9_62_prime256v1; + case coinbase::api::curve_id::ed25519: + return NID_undef; + } + return NID_undef; +} + +static const char* nid_to_group_name(int nid) { + switch (nid) { + case NID_secp256k1: + return SN_secp256k1; + case NID_X9_62_prime256v1: + return SN_X9_62_prime256v1; + } + return nullptr; +} + +static bool verify_ecdsa_sig_der(coinbase::api::curve_id curve, mem_t pub_key_compressed, mem_t msg_hash, mem_t sig_der) { + const int nid = curve_to_nid(curve); + if (nid == NID_undef) return false; + const char* group_name = nid_to_group_name(nid); + if (!group_name) return false; + + std::unique_ptr group(EC_GROUP_new_by_curve_name(nid)); + if (!group) return false; + std::unique_ptr Q(EC_POINT_new(group.get())); + if (!Q) return false; + + if (EC_POINT_oct2point(group.get(), Q.get(), pub_key_compressed.data, static_cast(pub_key_compressed.size), + /*ctx=*/nullptr) != 1) { + return false; + } + + const size_t oct_len = EC_POINT_point2oct(group.get(), Q.get(), POINT_CONVERSION_UNCOMPRESSED, /*buf=*/nullptr, + /*len=*/0, /*ctx=*/nullptr); + if (oct_len == 0) return false; + std::vector pub_oct(oct_len); + if (EC_POINT_point2oct(group.get(), Q.get(), POINT_CONVERSION_UNCOMPRESSED, pub_oct.data(), pub_oct.size(), + /*ctx=*/nullptr) != oct_len) { + return false; + } + + std::unique_ptr bld(OSSL_PARAM_BLD_new()); + if (!bld) return false; + if (OSSL_PARAM_BLD_push_utf8_string(bld.get(), "group", group_name, 0) <= 0) return false; + if (OSSL_PARAM_BLD_push_octet_string(bld.get(), "pub", pub_oct.data(), pub_oct.size()) <= 0) return false; + std::unique_ptr params(OSSL_PARAM_BLD_to_param(bld.get())); + if (!params) return false; + + std::unique_ptr fromdata_ctx(EVP_PKEY_CTX_new_from_name(nullptr, "EC", nullptr)); + if (!fromdata_ctx) return false; + if (EVP_PKEY_fromdata_init(fromdata_ctx.get()) <= 0) return false; + EVP_PKEY* pkey_raw = nullptr; + if (EVP_PKEY_fromdata(fromdata_ctx.get(), &pkey_raw, EVP_PKEY_PUBLIC_KEY, params.get()) <= 0) return false; + std::unique_ptr pkey(pkey_raw); + + std::unique_ptr verify_ctx(EVP_PKEY_CTX_new(pkey.get(), nullptr)); + if (!verify_ctx) return false; + if (EVP_PKEY_verify_init(verify_ctx.get()) <= 0) return false; + const int v = EVP_PKEY_verify(verify_ctx.get(), sig_der.data, static_cast(sig_der.size), msg_hash.data, + static_cast(msg_hash.size)); + return v == 1; +} + +static buf_t make_msg_hash32(uint8_t seed) { + buf_t msg_hash(32); + for (int i = 0; i < msg_hash.size(); i++) msg_hash[i] = static_cast(seed + i); + return msg_hash; +} + +} // namespace + +int main(int /*argc*/, const char* /*argv*/[]) { + std::cout << std::boolalpha; + std::cout << "============= ECDSA-MP + PVE backup demo (api-only) =============\n"; + + const coinbase::api::curve_id curve = coinbase::api::curve_id::secp256k1; + + // Parties: p0..p4 + const int n = 5; + std::vector names = {"p0", "p1", "p2", "p3", "p4"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + // THRESHOLD[3](p0, p1, p2, p3, p4) + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 3, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + coinbase::api::access_structure_t::leaf(names[4]), + }); + + // Only 3 parties actively contribute to DKG/refresh (all 5 must be online to run). + const std::vector quorum_party_names = {names[0], names[1], names[2]}; + + // ------------------------- + // Step 1: DKG (access-structure) + // ------------------------- + std::cout << "\n[1] DKG(ac): 5 parties, threshold 3-of-5, quorum contributors: {p0,p1,p2}\n"; + + auto net_dkg = std::make_shared(n); + std::vector> transports; + transports.reserve(static_cast(n)); + for (int i = 0; i < n; i++) transports.emplace_back(std::make_unique(i, net_dkg)); + + std::vector key_blobs(n); + std::vector sids(n); + std::vector rvs; + + run_mp(net_dkg, n, [&](int i) { + coinbase::api::job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::dkg_ac(job, curve, sids[static_cast(i)], ac, quorum_party_names, + key_blobs[static_cast(i)]); + }, rvs); + for (int i = 0; i < n; i++) require_rv(rvs[static_cast(i)], SUCCESS, "dkg_ac p" + std::to_string(i)); + for (int i = 1; i < n; i++) require(sids[0] == sids[static_cast(i)], "DKG sid mismatch"); + + buf_t pub; + require_rv(coinbase::api::ecdsa_mp::get_public_key_compressed(key_blobs[0], pub), SUCCESS, "get public key"); + for (int i = 1; i < n; i++) { + buf_t pub_i; + require_rv(coinbase::api::ecdsa_mp::get_public_key_compressed(key_blobs[static_cast(i)], pub_i), SUCCESS, + "get public key"); + require(pub_i == pub, "public key mismatch across parties"); + } + std::cout << "public key bytes: " << pub.size() << "\n"; + + // ------------------------- + // Step 2: Sign, then refresh + // ------------------------- + std::cout << "\n[2] Sign(ac) with quorum {p0,p1,p2}\n"; + { + const int qn = static_cast(quorum_party_names.size()); + auto net_sign = std::make_shared(qn); + std::vector> sign_transports; + sign_transports.reserve(static_cast(qn)); + for (int i = 0; i < qn; i++) sign_transports.emplace_back(std::make_unique(i, net_sign)); + + const buf_t msg_hash = make_msg_hash32(/*seed=*/0x11); + std::vector sigs(static_cast(qn)); + + run_mp(net_sign, qn, [&](int i) { + coinbase::api::job_mp_t job{static_cast(i), quorum_party_names, + *sign_transports[static_cast(i)]}; + // Map quorum index -> original party index (p0,p1,p2 == 0,1,2) + return coinbase::api::ecdsa_mp::sign_ac(job, key_blobs[static_cast(i)], ac, msg_hash, + /*sig_receiver=*/0, sigs[static_cast(i)]); + }, rvs); + for (int i = 0; i < qn; i++) require_rv(rvs[static_cast(i)], SUCCESS, "sign_ac"); + + require(!sigs[0].empty(), "signature should be returned on receiver (p0)"); + require(verify_ecdsa_sig_der(curve, mem_t(pub.data(), pub.size()), mem_t(msg_hash.data(), msg_hash.size()), + mem_t(sigs[0].data(), sigs[0].size())), + "signature verify failed"); + std::cout << "sign/verify ok\n"; + } + + std::cout << "\n[2] Refresh(ac): 5 parties, quorum contributors: {p0,p1,p2}\n"; + std::vector refreshed(n); + std::vector refresh_sids(n); + { + auto net_refresh = std::make_shared(n); + // Reuse transport objects but point them at a fresh network by rebuilding. + transports.clear(); + for (int i = 0; i < n; i++) transports.emplace_back(std::make_unique(i, net_refresh)); + + run_mp(net_refresh, n, [&](int i) { + coinbase::api::job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::refresh_ac(job, refresh_sids[static_cast(i)], key_blobs[static_cast(i)], + ac, quorum_party_names, refreshed[static_cast(i)]); + }, rvs); + for (int i = 0; i < n; i++) require_rv(rvs[static_cast(i)], SUCCESS, "refresh_ac p" + std::to_string(i)); + for (int i = 1; i < n; i++) require(refresh_sids[0] == refresh_sids[static_cast(i)], "refresh sid mismatch"); + + for (int i = 0; i < n; i++) { + buf_t pub_i; + require_rv(coinbase::api::ecdsa_mp::get_public_key_compressed(refreshed[static_cast(i)], pub_i), SUCCESS, + "get public key refreshed"); + require(pub_i == pub, "public key changed after refresh"); + } + std::cout << "refresh ok (public key stable)\n"; + } + + // ------------------------- + // Step 3: PVE backup of private scalar shares (verifiable via Qi_self) + // ------------------------- + std::cout << "\n[3] Backup each party's private scalar x_share using PVE (verifiable)\n"; + + std::vector redacted(n); + std::vector ct(n); + std::vector ek(n); + std::vector dk(n); + std::vector Qi_selfs(n); + std::vector labels(n); + + for (int i = 0; i < n; i++) { + require_rv(coinbase::api::ecdsa_mp::get_public_share_compressed(refreshed[static_cast(i)], + Qi_selfs[static_cast(i)]), + SUCCESS, "get_public_share_compressed"); + + buf_t x_fixed; + require_rv(coinbase::api::ecdsa_mp::detach_private_scalar(refreshed[static_cast(i)], + redacted[static_cast(i)], x_fixed), + SUCCESS, "detach_private_scalar"); + + require_rv(coinbase::api::pve::generate_base_pke_rsa_keypair(ek[static_cast(i)], dk[static_cast(i)]), + SUCCESS, "generate_base_pke_rsa_keypair"); + + labels[static_cast(i)] = std::string("ecdsa-mp-demo:pve-backup:") + names[static_cast(i)]; + const mem_t label_mem(labels[static_cast(i)]); + + require_rv(coinbase::api::pve::encrypt(curve, mem_t(ek[static_cast(i)].data(), ek[static_cast(i)].size()), + label_mem, mem_t(x_fixed.data(), x_fixed.size()), + ct[static_cast(i)]), + SUCCESS, "pve::encrypt"); + + require_rv(coinbase::api::pve::verify(curve, mem_t(ek[static_cast(i)].data(), ek[static_cast(i)].size()), + mem_t(ct[static_cast(i)].data(), ct[static_cast(i)].size()), + mem_t(Qi_selfs[static_cast(i)].data(), + Qi_selfs[static_cast(i)].size()), + label_mem), + SUCCESS, "pve::verify"); + } + std::cout << "PVE backup + verification ok\n"; + + // ------------------------- + // Step 4: Simulate party p1 losing the private scalar share (restore from PVE) + // Step 5: Sign again with quorum {p0,p1,p2} + // ------------------------- + std::cout << "\n[4] Party p1 loses private scalar; restore from PVE and attach into public blob\n"; + + auto restore_party = [&](int party_idx, buf_t& out_full_key_blob) { + const mem_t label_mem(labels[static_cast(party_idx)]); + buf_t x_out; + require_rv(coinbase::api::pve::decrypt(curve, mem_t(dk[static_cast(party_idx)].data(), dk[static_cast(party_idx)].size()), + mem_t(ek[static_cast(party_idx)].data(), ek[static_cast(party_idx)].size()), + mem_t(ct[static_cast(party_idx)].data(), ct[static_cast(party_idx)].size()), + label_mem, x_out), + SUCCESS, "pve::decrypt"); + require_rv(coinbase::api::ecdsa_mp::attach_private_scalar(redacted[static_cast(party_idx)], + mem_t(x_out.data(), x_out.size()), + mem_t(Qi_selfs[static_cast(party_idx)].data(), + Qi_selfs[static_cast(party_idx)].size()), + out_full_key_blob), + SUCCESS, "attach_private_scalar"); + }; + + // Restore quorum parties (p0, p1, p2). Only p1 is "lost", but restoring all three + // keeps the demo uniform after redaction. + std::vector restored_quorum(3); + restore_party(/*p0=*/0, restored_quorum[0]); + restore_party(/*p1=*/1, restored_quorum[1]); + restore_party(/*p2=*/2, restored_quorum[2]); + + std::cout << "\n[5] Sign(ac) again with quorum {p0,p1,p2} using restored key blobs\n"; + { + const int qn = 3; + auto net_sign2 = std::make_shared(qn); + std::vector> sign_transports; + sign_transports.reserve(static_cast(qn)); + for (int i = 0; i < qn; i++) sign_transports.emplace_back(std::make_unique(i, net_sign2)); + + const buf_t msg_hash = make_msg_hash32(/*seed=*/0x33); + std::vector sigs(static_cast(qn)); + + run_mp(net_sign2, qn, [&](int i) { + coinbase::api::job_mp_t job{static_cast(i), quorum_party_names, + *sign_transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::sign_ac(job, restored_quorum[static_cast(i)], ac, msg_hash, + /*sig_receiver=*/0, sigs[static_cast(i)]); + }, rvs); + for (int i = 0; i < qn; i++) require_rv(rvs[static_cast(i)], SUCCESS, "sign_ac (restored)"); + + require(!sigs[0].empty(), "signature should be returned on receiver (p0)"); + require(verify_ecdsa_sig_der(curve, mem_t(pub.data(), pub.size()), mem_t(msg_hash.data(), msg_hash.size()), + mem_t(sigs[0].data(), sigs[0].size())), + "restored signature verify failed"); + std::cout << "sign/verify after restore ok\n"; + } + + std::cout << "\nDone.\n"; + return 0; +} + diff --git a/demo-api/hd_keyset_ecdsa_2p/CMakeLists.txt b/demo-api/hd_keyset_ecdsa_2p/CMakeLists.txt new file mode 100644 index 00000000..04c73697 --- /dev/null +++ b/demo-api/hd_keyset_ecdsa_2p/CMakeLists.txt @@ -0,0 +1,43 @@ +cmake_minimum_required(VERSION 3.16) + +project(mpc-demo-api-hd_keyset_ecdsa_2p LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(REPO_CMAKE_DIR ${CMAKE_CURRENT_LIST_DIR}/../../cmake) + +include(${REPO_CMAKE_DIR}/macros.cmake) +include(${REPO_CMAKE_DIR}/arch.cmake) +include(${REPO_CMAKE_DIR}/openssl.cmake) +include(${REPO_CMAKE_DIR}/compilation_flags.cmake) + +if(NOT DEFINED CBMPC_SOURCE_DIR) + if(DEFINED ENV{CBMPC_PREFIX}) + set(CBMPC_SOURCE_DIR "$ENV{CBMPC_PREFIX}") + else() + get_filename_component(_cbmpc_repo_root "${CMAKE_CURRENT_LIST_DIR}/../.." ABSOLUTE) + if(EXISTS "${_cbmpc_repo_root}/build/install/public") + set(CBMPC_SOURCE_DIR "${_cbmpc_repo_root}/build/install/public") + else() + set(CBMPC_SOURCE_DIR /usr/local/opt/cbmpc/) + endif() + endif() +endif() + +set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib") +if(EXISTS "${CBMPC_SOURCE_DIR}/lib/Release/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Release") +elseif(EXISTS "${CBMPC_SOURCE_DIR}/lib/Debug/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Debug") +endif() + +add_executable(mpc-demo-api-hd_keyset_ecdsa_2p main.cpp) + +target_include_directories(mpc-demo-api-hd_keyset_ecdsa_2p PRIVATE ${CBMPC_SOURCE_DIR}/include) +target_link_directories(mpc-demo-api-hd_keyset_ecdsa_2p PRIVATE ${CBMPC_LIB_DIR}) +target_link_libraries(mpc-demo-api-hd_keyset_ecdsa_2p PRIVATE cbmpc) + +# Important for static linking on Linux: ensure libcbmpc.a appears before +# libcrypto.a on the final link line so libcrypto symbols resolve correctly. +link_openssl(mpc-demo-api-hd_keyset_ecdsa_2p) diff --git a/demo-api/hd_keyset_ecdsa_2p/main.cpp b/demo-api/hd_keyset_ecdsa_2p/main.cpp new file mode 100644 index 00000000..535567bb --- /dev/null +++ b/demo-api/hd_keyset_ecdsa_2p/main.cpp @@ -0,0 +1,345 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace { + +using namespace coinbase; + +[[noreturn]] void die(const std::string& msg) { + std::cerr << "hd_keyset_ecdsa_2p demo failure: " << msg << "\n"; + std::exit(1); +} + +void require(bool ok, const std::string& msg) { + if (!ok) die(msg); +} + +void require_rv(error_t got, error_t want, const std::string& msg) { + if (got != want) die(msg + " (got=0x" + std::to_string(uint32_t(got)) + ")"); +} + +// Minimal in-memory 2-party transport. +struct channel_t { + std::mutex m; + std::condition_variable cv; + std::deque q; +}; + +struct in_memory_network_t { + std::shared_ptr ch[2][2]; + std::atomic aborted{false}; + in_memory_network_t() { + ch[0][1] = std::make_shared(); + ch[1][0] = std::make_shared(); + } + + void abort() { + aborted.store(true); + for (int i = 0; i < 2; i++) { + for (int j = 0; j < 2; j++) { + if (ch[i][j]) ch[i][j]->cv.notify_all(); + } + } + } +}; + +class in_memory_transport_t final : public coinbase::api::data_transport_i { + public: + in_memory_transport_t(int self, std::shared_ptr net) : self_(self), net_(std::move(net)) {} + + error_t send(coinbase::api::party_idx_t receiver, mem_t msg) override { + if (!net_) return E_GENERAL; + if (net_->aborted.load()) return E_NET_GENERAL; + if (receiver < 0 || receiver > 1 || receiver == self_) return E_BADARG; + auto c = net_->ch[self_][receiver]; + if (!c) return E_GENERAL; + { + std::lock_guard lk(c->m); + c->q.emplace_back(msg); + } + c->cv.notify_one(); + return SUCCESS; + } + + error_t receive(coinbase::api::party_idx_t sender, buf_t& msg) override { + if (!net_ || sender < 0 || sender > 1 || sender == self_) return E_BADARG; + auto c = net_->ch[sender][self_]; + if (!c) return E_GENERAL; + std::unique_lock lk(c->m); + c->cv.wait(lk, [&] { return net_->aborted.load() || !c->q.empty(); }); + if (net_->aborted.load() && c->q.empty()) return E_NET_GENERAL; + msg = std::move(c->q.front()); + c->q.pop_front(); + return SUCCESS; + } + + error_t receive_all(const std::vector& senders, std::vector& msgs) override { + msgs.clear(); + msgs.resize(senders.size()); + for (size_t i = 0; i < senders.size(); i++) { + error_t rv = receive(senders[i], msgs[i]); + if (rv) return rv; + } + return SUCCESS; + } + + private: + const int self_; + std::shared_ptr net_; +}; + +template +void run_2pc(in_memory_network_t* net, F1&& f1, F2&& f2, error_t& out_rv1, error_t& out_rv2) { + std::thread t1([&] { + out_rv1 = f1(); + if (out_rv1 && net) net->abort(); + }); + std::thread t2([&] { + out_rv2 = f2(); + if (out_rv2 && net) net->abort(); + }); + t1.join(); + t2.join(); +} + +static int curve_to_nid(coinbase::api::curve_id curve) { + switch (curve) { + case coinbase::api::curve_id::secp256k1: + return NID_secp256k1; + case coinbase::api::curve_id::p256: + return NID_X9_62_prime256v1; + case coinbase::api::curve_id::ed25519: + return NID_undef; + } + return NID_undef; +} + +static const char* nid_to_group_name(int nid) { + switch (nid) { + case NID_secp256k1: + return SN_secp256k1; + case NID_X9_62_prime256v1: + return SN_X9_62_prime256v1; + } + return nullptr; +} + +struct ec_group_deleter_t { + void operator()(EC_GROUP* g) const { EC_GROUP_free(g); } +}; +struct ec_point_deleter_t { + void operator()(EC_POINT* p) const { EC_POINT_free(p); } +}; +struct ossl_param_bld_deleter_t { + void operator()(OSSL_PARAM_BLD* bld) const { OSSL_PARAM_BLD_free(bld); } +}; +struct ossl_param_deleter_t { + void operator()(OSSL_PARAM* p) const { OSSL_PARAM_free(p); } +}; +struct evp_pkey_ctx_deleter_t { + void operator()(EVP_PKEY_CTX* ctx) const { EVP_PKEY_CTX_free(ctx); } +}; +struct evp_pkey_deleter_t { + void operator()(EVP_PKEY* pkey) const { EVP_PKEY_free(pkey); } +}; + +static bool verify_ecdsa_sig_der(coinbase::api::curve_id curve, mem_t pub_key_compressed, mem_t msg_hash, mem_t sig_der) { + const int nid = curve_to_nid(curve); + if (nid == NID_undef) return false; + const char* group_name = nid_to_group_name(nid); + if (!group_name) return false; + + // Decode the compressed public key into an EC_POINT, then re-encode it as an + // uncompressed point (the provider-based EVP_PKEY import expects an encoded + // point blob). + std::unique_ptr group(EC_GROUP_new_by_curve_name(nid)); + if (!group) return false; + std::unique_ptr Q(EC_POINT_new(group.get())); + if (!Q) return false; + + if (EC_POINT_oct2point(group.get(), Q.get(), pub_key_compressed.data, static_cast(pub_key_compressed.size), + /*ctx=*/nullptr) != 1) { + return false; + } + + const size_t oct_len = + EC_POINT_point2oct(group.get(), Q.get(), POINT_CONVERSION_UNCOMPRESSED, /*buf=*/nullptr, /*len=*/0, /*ctx=*/nullptr); + if (oct_len == 0) return false; + std::vector pub_oct(oct_len); + if (EC_POINT_point2oct(group.get(), Q.get(), POINT_CONVERSION_UNCOMPRESSED, pub_oct.data(), pub_oct.size(), + /*ctx=*/nullptr) != oct_len) { + return false; + } + + std::unique_ptr bld(OSSL_PARAM_BLD_new()); + if (!bld) return false; + if (OSSL_PARAM_BLD_push_utf8_string(bld.get(), "group", group_name, 0) <= 0) return false; + if (OSSL_PARAM_BLD_push_octet_string(bld.get(), "pub", pub_oct.data(), pub_oct.size()) <= 0) return false; + std::unique_ptr params(OSSL_PARAM_BLD_to_param(bld.get())); + if (!params) return false; + + std::unique_ptr fromdata_ctx(EVP_PKEY_CTX_new_from_name(nullptr, "EC", nullptr)); + if (!fromdata_ctx) return false; + if (EVP_PKEY_fromdata_init(fromdata_ctx.get()) <= 0) return false; + EVP_PKEY* pkey_raw = nullptr; + if (EVP_PKEY_fromdata(fromdata_ctx.get(), &pkey_raw, EVP_PKEY_PUBLIC_KEY, params.get()) <= 0) return false; + std::unique_ptr pkey(pkey_raw); + + std::unique_ptr verify_ctx(EVP_PKEY_CTX_new(pkey.get(), nullptr)); + if (!verify_ctx) return false; + if (EVP_PKEY_verify_init(verify_ctx.get()) <= 0) return false; + const int v = EVP_PKEY_verify(verify_ctx.get(), sig_der.data, static_cast(sig_der.size), msg_hash.data, + static_cast(msg_hash.size)); + return v == 1; +} + +void demo_curve(coinbase::api::curve_id curve) { + std::cout << "\n=== HD keyset ECDSA-2P (api) curve=" << (curve == coinbase::api::curve_id::secp256k1 ? "secp256k1" : "p256") + << " ===\n"; + + auto net = std::make_shared(); + in_memory_transport_t t1(/*self=*/0, net); + in_memory_transport_t t2(/*self=*/1, net); + + const coinbase::api::job_2p_t job1{coinbase::api::party_2p_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{coinbase::api::party_2p_t::p2, "p1", "p2", t2}; + + buf_t keyset1; + buf_t keyset2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + run_2pc(net.get(), [&] { return coinbase::api::hd_keyset_ecdsa_2p::dkg(job1, curve, keyset1); }, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::dkg(job2, curve, keyset2); }, rv1, rv2); + require_rv(rv1, SUCCESS, "dkg p1"); + require_rv(rv2, SUCCESS, "dkg p2"); + + buf_t root_pub1; + buf_t root_pub2; + require_rv(coinbase::api::hd_keyset_ecdsa_2p::extract_root_public_key_compressed(keyset1, root_pub1), SUCCESS, + "extract root pub p1"); + require_rv(coinbase::api::hd_keyset_ecdsa_2p::extract_root_public_key_compressed(keyset2, root_pub2), SUCCESS, + "extract root pub p2"); + require(root_pub1 == root_pub2, "root pub keys must match"); + std::cout << "root public key bytes: " << root_pub1.size() << "\n"; + + // Derivation paths. + coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t{{0, 0}}); + non_hard.push_back(coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t{{0, 1}}); + + std::vector derived1; + std::vector derived2; + buf_t sid1; + buf_t sid2; + run_2pc( + net.get(), + [&] { return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job1, keyset1, hard, non_hard, sid1, derived1); }, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job2, keyset2, hard, non_hard, sid2, derived2); }, + rv1, rv2); + require_rv(rv1, SUCCESS, "derive p1"); + require_rv(rv2, SUCCESS, "derive p2"); + require(sid1 == sid2, "sid must match"); + std::cout << "derived keys: " << derived1.size() << "\n"; + + for (size_t i = 0; i < derived1.size(); i++) { + buf_t pub_a; + buf_t pub_b; + require_rv(coinbase::api::ecdsa_2p::get_public_key_compressed(derived1[i], pub_a), SUCCESS, + "extract derived pub p1"); + require_rv(coinbase::api::ecdsa_2p::get_public_key_compressed(derived2[i], pub_b), SUCCESS, + "extract derived pub p2"); + require(pub_a == pub_b, "derived pubkeys must match across parties"); + } + + // Sign with derived key #0. + buf_t msg_hash(32); + for (int i = 0; i < msg_hash.size(); i++) msg_hash[i] = static_cast(0x42 + i); + + buf_t sig1; + buf_t sig2; + buf_t sid3; + buf_t sid4; + run_2pc(net.get(), [&] { return coinbase::api::ecdsa_2p::sign(job1, derived1[0], msg_hash, sid3, sig1); }, + [&] { return coinbase::api::ecdsa_2p::sign(job2, derived2[0], msg_hash, sid4, sig2); }, rv1, rv2); + require_rv(rv1, SUCCESS, "sign p1"); + require_rv(rv2, SUCCESS, "sign p2"); + require(sid3 == sid4, "sign sid mismatch"); + require(!sig1.empty(), "signature should be returned on p1"); + + buf_t derived_pub; + require_rv(coinbase::api::ecdsa_2p::get_public_key_compressed(derived1[0], derived_pub), SUCCESS, + "extract pub for verify"); + require(verify_ecdsa_sig_der(curve, derived_pub, msg_hash, sig1), "signature verification failed"); + std::cout << "sign/verify ok\n"; + + // Refresh keyset shares and derive again. + buf_t keyset1_ref; + buf_t keyset2_ref; + run_2pc(net.get(), [&] { return coinbase::api::hd_keyset_ecdsa_2p::refresh(job1, keyset1, keyset1_ref); }, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::refresh(job2, keyset2, keyset2_ref); }, rv1, rv2); + require_rv(rv1, SUCCESS, "refresh p1"); + require_rv(rv2, SUCCESS, "refresh p2"); + + std::vector derived1_ref; + std::vector derived2_ref; + buf_t sid5; + buf_t sid6; + run_2pc( + net.get(), + [&] { + return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job1, keyset1_ref, hard, non_hard, sid5, + derived1_ref); + }, + [&] { + return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job2, keyset2_ref, hard, non_hard, sid6, + derived2_ref); + }, + rv1, rv2); + require_rv(rv1, SUCCESS, "derive after refresh p1"); + require_rv(rv2, SUCCESS, "derive after refresh p2"); + + for (size_t i = 0; i < derived1.size(); i++) { + buf_t pub_old; + buf_t pub_new; + require_rv(coinbase::api::ecdsa_2p::get_public_key_compressed(derived1[i], pub_old), SUCCESS, "extract pub old"); + require_rv(coinbase::api::ecdsa_2p::get_public_key_compressed(derived1_ref[i], pub_new), SUCCESS, + "extract pub new"); + require(pub_old == pub_new, "derived pubkey changed across refresh (unexpected)"); + } + std::cout << "refresh/derive stable ok\n"; +} + +} // namespace + +int main(int /*argc*/, const char* /*argv*/[]) { + std::cout << "============= HD Keyset ECDSA 2P Demo (api-only) =============\n"; + demo_curve(coinbase::api::curve_id::secp256k1); + demo_curve(coinbase::api::curve_id::p256); + return 0; +} diff --git a/demo-api/pve/CMakeLists.txt b/demo-api/pve/CMakeLists.txt new file mode 100644 index 00000000..1ff0ac28 --- /dev/null +++ b/demo-api/pve/CMakeLists.txt @@ -0,0 +1,45 @@ +cmake_minimum_required(VERSION 3.16) + +project(mpc-demo-api-pve LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(REPO_CMAKE_DIR ${CMAKE_CURRENT_LIST_DIR}/../../cmake) + +include(${REPO_CMAKE_DIR}/macros.cmake) +include(${REPO_CMAKE_DIR}/arch.cmake) +include(${REPO_CMAKE_DIR}/openssl.cmake) +include(${REPO_CMAKE_DIR}/compilation_flags.cmake) + +if(NOT DEFINED CBMPC_SOURCE_DIR) + if(DEFINED ENV{CBMPC_PREFIX}) + set(CBMPC_SOURCE_DIR "$ENV{CBMPC_PREFIX}") + else() + get_filename_component(_cbmpc_repo_root "${CMAKE_CURRENT_LIST_DIR}/../.." ABSOLUTE) + if(EXISTS "${_cbmpc_repo_root}/build/install/public") + set(CBMPC_SOURCE_DIR "${_cbmpc_repo_root}/build/install/public") + else() + set(CBMPC_SOURCE_DIR /usr/local/opt/cbmpc/) + endif() + endif() +endif() + +set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib") +if(EXISTS "${CBMPC_SOURCE_DIR}/lib/Release/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Release") +elseif(EXISTS "${CBMPC_SOURCE_DIR}/lib/Debug/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Debug") +endif() + +add_executable(mpc-demo-api-pve main.cpp) + +link_openssl(mpc-demo-api-pve) +target_include_directories(mpc-demo-api-pve PRIVATE ${CBMPC_SOURCE_DIR}/include) +target_link_directories(mpc-demo-api-pve PRIVATE ${CBMPC_LIB_DIR}) +target_link_libraries(mpc-demo-api-pve PRIVATE cbmpc) + +if(IS_LINUX) + link_openssl(mpc-demo-api-pve) +endif() + diff --git a/demo-api/pve/main.cpp b/demo-api/pve/main.cpp new file mode 100644 index 00000000..42f739a8 --- /dev/null +++ b/demo-api/pve/main.cpp @@ -0,0 +1,747 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +using namespace coinbase; + +namespace { + +// Minimal, demo-only "base PKE" that satisfies the PVE interface contract: +// deterministic encryption given `rho`, reversible decryption. +// +// - Key format: `ek` and `dk` are the same 32-byte key. +// - Ciphertext format: `ct = rho32 || (plain XOR SHA256(key || label || rho32 || ctr)...)`. +class toy_base_pke_t : public coinbase::api::pve::base_pke_i { + public: + error_t encrypt(mem_t ek, mem_t label, mem_t plain, mem_t rho, buf_t& out_ct) const override { + if (ek.size != 32) return coinbase::error(E_BADARG, "toy_base_pke: expected 32-byte key"); + if (rho.size != 32) return coinbase::error(E_BADARG, "toy_base_pke: expected 32-byte rho"); + + out_ct = buf_t(rho.size + plain.size); + std::memmove(out_ct.data(), rho.data, static_cast(rho.size)); + + xor_keystream(ek, label, rho, /*in_out=*/mem_t(out_ct.data() + rho.size, plain.size), plain); + return SUCCESS; + } + + error_t decrypt(mem_t dk, mem_t label, mem_t ct, buf_t& out_plain) const override { + if (dk.size != 32) return coinbase::error(E_BADARG, "toy_base_pke: expected 32-byte key"); + if (ct.size < 32) return coinbase::error(E_FORMAT, "toy_base_pke: ciphertext too small"); + + const mem_t rho(ct.data, 32); + const mem_t cipher(ct.data + 32, ct.size - 32); + + out_plain = buf_t(cipher.size); + std::memmove(out_plain.data(), cipher.data, static_cast(cipher.size)); + xor_keystream(dk, label, rho, /*in_out=*/mem_t(out_plain.data(), out_plain.size()), /*plain=*/mem_t()); + return SUCCESS; + } + + private: + static void xor_keystream(mem_t key, mem_t label, mem_t rho, mem_t in_out, mem_t plain) { + // If `plain` is provided, XOR it into `in_out` while generating keystream. + // Otherwise, `in_out` is already the ciphertext and we XOR keystream in-place to decrypt. + cb_assert(key.size == 32); + cb_assert(rho.size == 32); + cb_assert(in_out.size >= 0); + cb_assert(plain.size == 0 || plain.size == in_out.size); + + uint8_t digest[32]; + EVP_MD_CTX* md = EVP_MD_CTX_new(); + cb_assert(md); + int out_off = 0; + uint32_t ctr = 0; + while (out_off < in_out.size) { + cb_assert(EVP_DigestInit_ex(md, EVP_sha256(), nullptr) == 1); + cb_assert(EVP_DigestUpdate(md, key.data, static_cast(key.size)) == 1); + cb_assert(EVP_DigestUpdate(md, label.data, static_cast(label.size)) == 1); + cb_assert(EVP_DigestUpdate(md, rho.data, static_cast(rho.size)) == 1); + cb_assert(EVP_DigestUpdate(md, &ctr, sizeof(ctr)) == 1); + unsigned int digest_len = 0; + cb_assert(EVP_DigestFinal_ex(md, digest, &digest_len) == 1); + cb_assert(digest_len == sizeof(digest)); + + const int n = std::min(static_cast(sizeof(digest)), in_out.size - out_off); + for (int i = 0; i < n; i++) { + const uint8_t ks = digest[i]; + const uint8_t src = (plain.size == 0) ? in_out[out_off + i] : plain[out_off + i]; + in_out[out_off + i] = static_cast(src ^ ks); + } + out_off += n; + ctr++; + } + EVP_MD_CTX_free(md); + secure_bzero(digest, static_cast(sizeof(digest))); + } +}; + +void demo_default_base_pke_rsa() { + std::cout << "\n=== PVE (api) + built-in RSA key blob ===\n"; + const coinbase::api::curve_id curve = coinbase::api::curve_id::secp256k1; + const mem_t label("pve-demo-label"); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(0xC0 + i); + const mem_t x(x_bytes.data(), static_cast(x_bytes.size())); + + buf_t ek_blob; + buf_t dk_blob; + cb_assert(coinbase::api::pve::generate_base_pke_rsa_keypair(ek_blob, dk_blob) == SUCCESS); + + buf_t ct; + cb_assert(coinbase::api::pve::encrypt(curve, ek_blob, label, x, ct) == SUCCESS); + + buf_t Q; + cb_assert(coinbase::api::pve::get_public_key_compressed(ct, Q) == SUCCESS); + + buf_t label_extracted; + cb_assert(coinbase::api::pve::get_Label(ct, label_extracted) == SUCCESS); + std::cout << "label extracted matches? " << (label_extracted == buf_t(label)) << "\n"; + + cb_assert(coinbase::api::pve::verify(curve, ek_blob, ct, Q, label) == SUCCESS); + + buf_t x_out; + cb_assert(coinbase::api::pve::decrypt(curve, dk_blob, ek_blob, ct, label, x_out) == SUCCESS); + std::cout << "decrypt ok? " << (x_out == buf_t(x)) << "\n"; +} + +void demo_default_base_pke_ecies() { + std::cout << "\n=== PVE (api) + built-in ECIES(P-256) key blob ===\n"; + const coinbase::api::curve_id curve = coinbase::api::curve_id::secp256k1; + const mem_t label("pve-demo-label"); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(0x44 + i); + const mem_t x(x_bytes.data(), static_cast(x_bytes.size())); + + buf_t ek_blob; + buf_t dk_blob; + cb_assert(coinbase::api::pve::generate_base_pke_ecies_p256_keypair(ek_blob, dk_blob) == SUCCESS); + + buf_t ct; + cb_assert(coinbase::api::pve::encrypt(curve, ek_blob, label, x, ct) == SUCCESS); + + buf_t Q; + cb_assert(coinbase::api::pve::get_public_key_compressed(ct, Q) == SUCCESS); + cb_assert(coinbase::api::pve::verify(curve, ek_blob, ct, Q, label) == SUCCESS); + + buf_t x_out; + cb_assert(coinbase::api::pve::decrypt(curve, dk_blob, ek_blob, ct, label, x_out) == SUCCESS); + std::cout << "decrypt ok? " << (x_out == buf_t(x)) << "\n"; +} + +struct ec_group_deleter_t { + void operator()(EC_GROUP* g) const { EC_GROUP_free(g); } +}; +struct ec_point_deleter_t { + void operator()(EC_POINT* p) const { EC_POINT_free(p); } +}; +struct ossl_param_bld_deleter_t { + void operator()(OSSL_PARAM_BLD* bld) const { OSSL_PARAM_BLD_free(bld); } +}; +struct ossl_param_deleter_t { + void operator()(OSSL_PARAM* p) const { OSSL_PARAM_free(p); } +}; +struct evp_pkey_ctx_deleter_t { + void operator()(EVP_PKEY_CTX* ctx) const { EVP_PKEY_CTX_free(ctx); } +}; +struct evp_pkey_deleter_t { + void operator()(EVP_PKEY* pkey) const { EVP_PKEY_free(pkey); } +}; + +using evp_pkey_ctx_ptr_t = std::unique_ptr; +using evp_pkey_ptr_t = std::unique_ptr; + +static error_t ensure_p256_pubkey_oct_uncompressed(mem_t pub_oct_any, buf_t& out_pub_oct_uncompressed) { + if (pub_oct_any.size == 65 && pub_oct_any.data && pub_oct_any.data[0] == 0x04) { + out_pub_oct_uncompressed = buf_t(pub_oct_any.data, pub_oct_any.size); + return SUCCESS; + } + + std::unique_ptr group(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1)); + if (!group) return coinbase::error(E_INSUFFICIENT, "EC_GROUP_new_by_curve_name failed"); + + std::unique_ptr pt(EC_POINT_new(group.get())); + if (!pt) return coinbase::error(E_INSUFFICIENT, "EC_POINT_new failed"); + + if (EC_POINT_oct2point(group.get(), pt.get(), pub_oct_any.data, static_cast(pub_oct_any.size), /*ctx=*/nullptr) != 1) { + return coinbase::error(E_FORMAT, "invalid EC(P-256) public key encoding"); + } + + uint8_t oct[65]; + const size_t written = + EC_POINT_point2oct(group.get(), pt.get(), POINT_CONVERSION_UNCOMPRESSED, oct, sizeof(oct), /*ctx=*/nullptr); + if (written != sizeof(oct)) return coinbase::error(E_GENERAL, "unexpected P-256 public key size"); + + out_pub_oct_uncompressed = buf_t(oct, static_cast(written)); + return SUCCESS; +} + +// Demo-only "HSM" that holds EC(P-256) private keys and can perform ECDH. +class fake_hsm_ecies_p256_t { + public: + error_t generate_key(std::string handle, buf_t& out_pub_key_oct_uncompressed) { + if (keys_.find(handle) != keys_.end()) return coinbase::error(E_BADARG, "duplicate HSM key handle"); + + evp_pkey_ctx_ptr_t keygen_ctx(EVP_PKEY_CTX_new_from_name(nullptr, "EC", nullptr)); + if (!keygen_ctx) return coinbase::error(E_INSUFFICIENT, "EVP_PKEY_CTX_new_from_name(EC) failed"); + if (EVP_PKEY_keygen_init(keygen_ctx.get()) <= 0) return coinbase::error(E_GENERAL, "EVP_PKEY_keygen_init failed"); + + std::unique_ptr bld(OSSL_PARAM_BLD_new()); + if (!bld) return coinbase::error(E_INSUFFICIENT, "OSSL_PARAM_BLD_new failed"); + if (OSSL_PARAM_BLD_push_utf8_string(bld.get(), "group", SN_X9_62_prime256v1, 0) <= 0) { + return coinbase::error(E_GENERAL, "OSSL_PARAM_BLD_push_utf8_string(group) failed"); + } + std::unique_ptr params(OSSL_PARAM_BLD_to_param(bld.get())); + if (!params) return coinbase::error(E_INSUFFICIENT, "OSSL_PARAM_BLD_to_param failed"); + if (EVP_PKEY_CTX_set_params(keygen_ctx.get(), params.get()) <= 0) { + return coinbase::error(E_GENERAL, "EVP_PKEY_CTX_set_params(group) failed"); + } + + EVP_PKEY* key_raw = nullptr; + if (EVP_PKEY_keygen(keygen_ctx.get(), &key_raw) <= 0 || !key_raw) { + return coinbase::error(E_GENERAL, "EVP_PKEY_keygen failed"); + } + evp_pkey_ptr_t key(key_raw); + + // Extract public key octets from the provider key and ensure uncompressed + // format (65 bytes, 0x04 || X || Y). + size_t pub_len = 0; + if (EVP_PKEY_get_octet_string_param(key.get(), "pub", /*out=*/nullptr, /*outlen=*/0, &pub_len) <= 0 || pub_len == 0) { + return coinbase::error(E_GENERAL, "EVP_PKEY_get_octet_string_param(pub) failed (size)"); + } + std::vector pub(pub_len); + if (EVP_PKEY_get_octet_string_param(key.get(), "pub", pub.data(), pub.size(), &pub_len) <= 0) { + return coinbase::error(E_GENERAL, "EVP_PKEY_get_octet_string_param(pub) failed"); + } + pub.resize(pub_len); + error_t rv = ensure_p256_pubkey_oct_uncompressed(mem_t(pub.data(), static_cast(pub.size())), out_pub_key_oct_uncompressed); + if (rv) return rv; + + keys_.emplace(std::move(handle), std::move(key)); + return SUCCESS; + } + + error_t ecdh_x32(mem_t handle, mem_t peer_pub_oct_uncompressed, buf_t& out_dh_x32) const { + const std::string h(reinterpret_cast(handle.data), static_cast(handle.size)); + const auto it = keys_.find(h); + if (it == keys_.end()) return coinbase::error(E_BADARG, "unknown HSM key handle"); + + if (peer_pub_oct_uncompressed.size != 65 || !peer_pub_oct_uncompressed.data || peer_pub_oct_uncompressed.data[0] != 0x04) { + return coinbase::error(E_FORMAT, "peer public key must be uncompressed P-256 octets"); + } + + // Import peer public key as EVP_PKEY (provider-based). + std::unique_ptr peer_bld(OSSL_PARAM_BLD_new()); + if (!peer_bld) return coinbase::error(E_INSUFFICIENT, "OSSL_PARAM_BLD_new failed"); + if (OSSL_PARAM_BLD_push_utf8_string(peer_bld.get(), "group", SN_X9_62_prime256v1, 0) <= 0) { + return coinbase::error(E_GENERAL, "OSSL_PARAM_BLD_push_utf8_string(group) failed"); + } + if (OSSL_PARAM_BLD_push_octet_string(peer_bld.get(), "pub", peer_pub_oct_uncompressed.data, + static_cast(peer_pub_oct_uncompressed.size)) <= 0) { + return coinbase::error(E_GENERAL, "OSSL_PARAM_BLD_push_octet_string(pub) failed"); + } + std::unique_ptr peer_params(OSSL_PARAM_BLD_to_param(peer_bld.get())); + if (!peer_params) return coinbase::error(E_INSUFFICIENT, "OSSL_PARAM_BLD_to_param failed"); + + evp_pkey_ctx_ptr_t peer_fromdata_ctx(EVP_PKEY_CTX_new_from_name(nullptr, "EC", nullptr)); + if (!peer_fromdata_ctx) return coinbase::error(E_INSUFFICIENT, "EVP_PKEY_CTX_new_from_name(EC) failed"); + if (EVP_PKEY_fromdata_init(peer_fromdata_ctx.get()) <= 0) return coinbase::error(E_GENERAL, "EVP_PKEY_fromdata_init failed"); + EVP_PKEY* peer_key_raw = nullptr; + if (EVP_PKEY_fromdata(peer_fromdata_ctx.get(), &peer_key_raw, EVP_PKEY_PUBLIC_KEY, peer_params.get()) <= 0 || !peer_key_raw) { + return coinbase::error(E_FORMAT, "EVP_PKEY_fromdata(peer pub) failed"); + } + evp_pkey_ptr_t peer_key(peer_key_raw); + + EVP_PKEY* key = it->second.get(); + evp_pkey_ctx_ptr_t derive_ctx(EVP_PKEY_CTX_new(key, nullptr)); + if (!derive_ctx) return coinbase::error(E_INSUFFICIENT, "EVP_PKEY_CTX_new failed"); + if (EVP_PKEY_derive_init(derive_ctx.get()) <= 0) return coinbase::error(E_GENERAL, "EVP_PKEY_derive_init failed"); + if (EVP_PKEY_derive_set_peer(derive_ctx.get(), peer_key.get()) <= 0) { + return coinbase::error(E_GENERAL, "EVP_PKEY_derive_set_peer failed"); + } + + size_t dh_len = 0; + if (EVP_PKEY_derive(derive_ctx.get(), /*key=*/nullptr, &dh_len) <= 0 || dh_len == 0) { + return coinbase::error(E_GENERAL, "EVP_PKEY_derive(size) failed"); + } + std::vector dh(dh_len); + if (EVP_PKEY_derive(derive_ctx.get(), dh.data(), &dh_len) <= 0) { + return coinbase::error(E_GENERAL, "EVP_PKEY_derive failed"); + } + dh.resize(dh_len); + if (dh.size() > 32) return coinbase::error(E_GENERAL, "unexpected ECDH output length"); + + std::array x32{}; + std::memmove(x32.data() + (32 - dh.size()), dh.data(), dh.size()); + out_dh_x32 = buf_t(x32.data(), static_cast(x32.size())); + return SUCCESS; + } + + private: + std::map keys_; +}; + +static error_t ecies_p256_hsm_ecdh_cb(void* ctx, mem_t dk_handle, mem_t kem_ct, buf_t& out_dh_x32) { + if (!ctx) return coinbase::error(E_BADARG, "missing HSM ctx"); + return static_cast(ctx)->ecdh_x32(dk_handle, kem_ct, out_dh_x32); +} + +void demo_hsm_ecies_p256() { + std::cout << "\n=== PVE (api) + ECIES(P-256) simulated HSM (ECDH callback only) ===\n"; + const coinbase::api::curve_id curve = coinbase::api::curve_id::secp256k1; + const mem_t label("pve-demo-label"); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(0x88 + i); + const mem_t x(x_bytes.data(), static_cast(x_bytes.size())); + + // 1) Simulate an HSM that generates/stores the private key and exports only the public key. + fake_hsm_ecies_p256_t hsm; + const std::string hsm_handle = "hsm-ecies-p256-key-1"; + buf_t pub_key_oct_uncompressed; + cb_assert(hsm.generate_key(hsm_handle, pub_key_oct_uncompressed) == SUCCESS); + + // 2) Wrap the exported public key into cbmpc's opaque base-PKE ek blob. + buf_t ek_blob; + cb_assert(coinbase::api::pve::base_pke_ecies_p256_ek_from_oct(pub_key_oct_uncompressed, ek_blob) == SUCCESS); + + // 3) Encrypt and verify using the software public key blob. + buf_t ct; + cb_assert(coinbase::api::pve::encrypt(curve, ek_blob, label, x, ct) == SUCCESS); + + buf_t Q; + cb_assert(coinbase::api::pve::get_public_key_compressed(ct, Q) == SUCCESS); + cb_assert(coinbase::api::pve::verify(curve, ek_blob, ct, Q, label) == SUCCESS); + + // 4) Decrypt using the HSM callback for *only* the ECDH step. + coinbase::api::pve::ecies_p256_hsm_ecdh_cb_t cb; + cb.ctx = &hsm; + cb.ecdh = ecies_p256_hsm_ecdh_cb; + + const mem_t dk_handle(hsm_handle); + buf_t x_out; + cb_assert(coinbase::api::pve::decrypt_ecies_p256_hsm(curve, dk_handle, ek_blob, ct, label, cb, x_out) == SUCCESS); + std::cout << "decrypt ok? " << (x_out == buf_t(x)) << "\n"; +} + +void demo_custom_base_pke() { + std::cout << "\n=== PVE (api) + custom base PKE ===\n"; + const coinbase::api::curve_id curve = coinbase::api::curve_id::secp256k1; + const mem_t label("pve-demo-label"); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(0x66 + i); + const mem_t x(x_bytes.data(), static_cast(x_bytes.size())); + + // Symmetric "key" used as both ek and dk in this toy base PKE. + std::array key{}; + for (int i = 0; i < 32; i++) key[static_cast(i)] = static_cast(i); + const mem_t ek(key.data(), 32); + const mem_t dk(key.data(), 32); + + toy_base_pke_t toy; + + buf_t ct; + cb_assert(coinbase::api::pve::encrypt(toy, curve, ek, label, x, ct) == SUCCESS); + + buf_t Q; + cb_assert(coinbase::api::pve::get_public_key_compressed(ct, Q) == SUCCESS); + cb_assert(coinbase::api::pve::verify(toy, curve, ek, ct, Q, label) == SUCCESS); + + buf_t x_out; + cb_assert(coinbase::api::pve::decrypt(toy, curve, dk, ek, ct, label, x_out) == SUCCESS); + std::cout << "decrypt ok? " << (x_out == buf_t(x)) << "\n"; +} + +void demo_ac_default_base_pke_rsa() { + std::cout << "\n=== PVE-AC (api) + built-in RSA key blobs (stepwise decrypt) ===\n"; + const coinbase::api::curve_id curve = coinbase::api::curve_id::secp256k1; + const mem_t label("pve-ac-demo-label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + constexpr int n = 8; + std::array, n> xs_bytes{}; + for (int i = 0; i < n; i++) { + for (int j = 0; j < 32; j++) xs_bytes[static_cast(i)][static_cast(j)] = static_cast(0x10 + i + j); + } + std::vector xs; + xs.reserve(n); + for (int i = 0; i < n; i++) xs.emplace_back(xs_bytes[static_cast(i)].data(), 32); + + std::array eks{}; + std::array dks{}; + cb_assert(coinbase::api::pve::generate_base_pke_rsa_keypair(eks[0], dks[0]) == SUCCESS); + cb_assert(coinbase::api::pve::generate_base_pke_rsa_keypair(eks[1], dks[1]) == SUCCESS); + cb_assert(coinbase::api::pve::generate_base_pke_rsa_keypair(eks[2], dks[2]) == SUCCESS); + + coinbase::api::pve::leaf_keys_t ac_pks; + cb_assert(ac_pks.emplace("p1", mem_t(eks[0].data(), eks[0].size())).second); + cb_assert(ac_pks.emplace("p2", mem_t(eks[1].data(), eks[1].size())).second); + cb_assert(ac_pks.emplace("p3", mem_t(eks[2].data(), eks[2].size())).second); + + buf_t ct; + cb_assert(coinbase::api::pve::encrypt_ac(curve, ac, ac_pks, label, xs, ct) == SUCCESS); + + int batch_count = 0; + cb_assert(coinbase::api::pve::get_ac_batch_count(ct, batch_count) == SUCCESS); + std::cout << "batch_count: " << batch_count << "\n"; + + std::vector Qs; + cb_assert(coinbase::api::pve::get_public_keys_compressed_ac(ct, Qs) == SUCCESS); + std::vector Qs_mem; + Qs_mem.reserve(Qs.size()); + for (const auto& q : Qs) Qs_mem.emplace_back(q.data(), q.size()); + cb_assert(coinbase::api::pve::verify_ac(curve, ac, ac_pks, ct, Qs_mem, label) == SUCCESS); + + const int attempt_index = 0; + buf_t share_p1; + buf_t share_p2; + cb_assert(coinbase::api::pve::partial_decrypt_ac_attempt(curve, ac, ct, attempt_index, "p1", + mem_t(dks[0].data(), dks[0].size()), label, share_p1) == SUCCESS); + cb_assert(coinbase::api::pve::partial_decrypt_ac_attempt(curve, ac, ct, attempt_index, "p2", + mem_t(dks[1].data(), dks[1].size()), label, share_p2) == SUCCESS); + + coinbase::api::pve::leaf_shares_t quorum; + cb_assert(quorum.emplace("p1", mem_t(share_p1.data(), share_p1.size())).second); + cb_assert(quorum.emplace("p2", mem_t(share_p2.data(), share_p2.size())).second); + + std::vector xs_out; + cb_assert(coinbase::api::pve::combine_ac(curve, ac, ct, attempt_index, label, quorum, xs_out) == SUCCESS); + + bool ok = (xs_out.size() == xs.size()); + for (int i = 0; ok && i < n; i++) ok = (xs_out[static_cast(i)] == buf_t(xs[static_cast(i)])); + std::cout << "recover ok? " << ok << "\n"; +} + +void demo_ac_default_base_pke_ecies() { + std::cout << "\n=== PVE-AC (api) + built-in ECIES(P-256) key blobs (stepwise decrypt) ===\n"; + const coinbase::api::curve_id curve = coinbase::api::curve_id::secp256k1; + const mem_t label("pve-ac-demo-label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + constexpr int n = 8; + std::array, n> xs_bytes{}; + for (int i = 0; i < n; i++) { + for (int j = 0; j < 32; j++) xs_bytes[static_cast(i)][static_cast(j)] = static_cast(0x40 + i + j); + } + std::vector xs; + xs.reserve(n); + for (int i = 0; i < n; i++) xs.emplace_back(xs_bytes[static_cast(i)].data(), 32); + + std::array eks{}; + std::array dks{}; + cb_assert(coinbase::api::pve::generate_base_pke_ecies_p256_keypair(eks[0], dks[0]) == SUCCESS); + cb_assert(coinbase::api::pve::generate_base_pke_ecies_p256_keypair(eks[1], dks[1]) == SUCCESS); + cb_assert(coinbase::api::pve::generate_base_pke_ecies_p256_keypair(eks[2], dks[2]) == SUCCESS); + + coinbase::api::pve::leaf_keys_t ac_pks; + cb_assert(ac_pks.emplace("p1", mem_t(eks[0].data(), eks[0].size())).second); + cb_assert(ac_pks.emplace("p2", mem_t(eks[1].data(), eks[1].size())).second); + cb_assert(ac_pks.emplace("p3", mem_t(eks[2].data(), eks[2].size())).second); + + buf_t ct; + cb_assert(coinbase::api::pve::encrypt_ac(curve, ac, ac_pks, label, xs, ct) == SUCCESS); + + std::vector Qs; + cb_assert(coinbase::api::pve::get_public_keys_compressed_ac(ct, Qs) == SUCCESS); + std::vector Qs_mem; + Qs_mem.reserve(Qs.size()); + for (const auto& q : Qs) Qs_mem.emplace_back(q.data(), q.size()); + cb_assert(coinbase::api::pve::verify_ac(curve, ac, ac_pks, ct, Qs_mem, label) == SUCCESS); + + const int attempt_index = 0; + buf_t share_p2; + buf_t share_p3; + cb_assert(coinbase::api::pve::partial_decrypt_ac_attempt(curve, ac, ct, attempt_index, "p2", + mem_t(dks[1].data(), dks[1].size()), label, share_p2) == SUCCESS); + cb_assert(coinbase::api::pve::partial_decrypt_ac_attempt(curve, ac, ct, attempt_index, "p3", + mem_t(dks[2].data(), dks[2].size()), label, share_p3) == SUCCESS); + + coinbase::api::pve::leaf_shares_t quorum; + cb_assert(quorum.emplace("p2", mem_t(share_p2.data(), share_p2.size())).second); + cb_assert(quorum.emplace("p3", mem_t(share_p3.data(), share_p3.size())).second); + + std::vector xs_out; + cb_assert(coinbase::api::pve::combine_ac(curve, ac, ct, attempt_index, label, quorum, xs_out) == SUCCESS); + + bool ok = (xs_out.size() == xs.size()); + for (int i = 0; ok && i < n; i++) ok = (xs_out[static_cast(i)] == buf_t(xs[static_cast(i)])); + std::cout << "recover ok? " << ok << "\n"; +} + +void demo_ac_custom_base_pke() { + std::cout << "\n=== PVE-AC (api) + custom base PKE (stepwise decrypt) ===\n"; + const coinbase::api::curve_id curve = coinbase::api::curve_id::secp256k1; + const mem_t label("pve-ac-demo-label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + constexpr int n = 8; + std::array, n> xs_bytes{}; + for (int i = 0; i < n; i++) { + for (int j = 0; j < 32; j++) xs_bytes[static_cast(i)][static_cast(j)] = static_cast(0x70 + i + j); + } + std::vector xs; + xs.reserve(n); + for (int i = 0; i < n; i++) xs.emplace_back(xs_bytes[static_cast(i)].data(), 32); + + // Per-leaf toy keys (ek == dk). + std::array, 3> keys{}; + for (int p = 0; p < 3; p++) { + for (int i = 0; i < 32; i++) keys[static_cast(p)][static_cast(i)] = static_cast(p * 0x11 + i); + } + + coinbase::api::pve::leaf_keys_t ac_pks; + cb_assert(ac_pks.emplace("p1", mem_t(keys[0].data(), 32)).second); + cb_assert(ac_pks.emplace("p2", mem_t(keys[1].data(), 32)).second); + cb_assert(ac_pks.emplace("p3", mem_t(keys[2].data(), 32)).second); + + toy_base_pke_t toy; + + buf_t ct; + cb_assert(coinbase::api::pve::encrypt_ac(toy, curve, ac, ac_pks, label, xs, ct) == SUCCESS); + + std::vector Qs; + cb_assert(coinbase::api::pve::get_public_keys_compressed_ac(ct, Qs) == SUCCESS); + std::vector Qs_mem; + Qs_mem.reserve(Qs.size()); + for (const auto& q : Qs) Qs_mem.emplace_back(q.data(), q.size()); + cb_assert(coinbase::api::pve::verify_ac(toy, curve, ac, ac_pks, ct, Qs_mem, label) == SUCCESS); + + const int attempt_index = 0; + buf_t share_p1; + buf_t share_p3; + cb_assert(coinbase::api::pve::partial_decrypt_ac_attempt(toy, curve, ac, ct, attempt_index, "p1", + mem_t(keys[0].data(), 32), label, share_p1) == SUCCESS); + cb_assert(coinbase::api::pve::partial_decrypt_ac_attempt(toy, curve, ac, ct, attempt_index, "p3", + mem_t(keys[2].data(), 32), label, share_p3) == SUCCESS); + + coinbase::api::pve::leaf_shares_t quorum; + cb_assert(quorum.emplace("p1", mem_t(share_p1.data(), share_p1.size())).second); + cb_assert(quorum.emplace("p3", mem_t(share_p3.data(), share_p3.size())).second); + + std::vector xs_out; + cb_assert(coinbase::api::pve::combine_ac(toy, curve, ac, ct, attempt_index, label, quorum, xs_out) == SUCCESS); + + bool ok = (xs_out.size() == xs.size()); + for (int i = 0; ok && i < n; i++) ok = (xs_out[static_cast(i)] == buf_t(xs[static_cast(i)])); + std::cout << "recover ok? " << ok << "\n"; +} + +void demo_batch_default_base_pke_rsa() { + std::cout << "\n=== PVE Batch (api) + built-in RSA key blob ===\n"; + const coinbase::api::curve_id curve = coinbase::api::curve_id::secp256k1; + const mem_t label("pve-demo-label"); + + constexpr int n = 8; + std::array, n> xs_bytes{}; + for (int i = 0; i < n; i++) { + for (int j = 0; j < 32; j++) xs_bytes[static_cast(i)][static_cast(j)] = static_cast(i + j); + } + std::vector xs; + xs.reserve(n); + for (int i = 0; i < n; i++) xs.emplace_back(xs_bytes[static_cast(i)].data(), 32); + + buf_t ek_blob; + buf_t dk_blob; + cb_assert(coinbase::api::pve::generate_base_pke_rsa_keypair(ek_blob, dk_blob) == SUCCESS); + + buf_t ct; + cb_assert(coinbase::api::pve::encrypt_batch(curve, ek_blob, label, xs, ct) == SUCCESS); + + int batch_count = 0; + cb_assert(coinbase::api::pve::get_batch_count(ct, batch_count) == SUCCESS); + std::cout << "batch_count: " << batch_count << "\n"; + + std::vector Qs; + cb_assert(coinbase::api::pve::get_public_keys_compressed_batch(ct, Qs) == SUCCESS); + std::vector Qs_mem; + Qs_mem.reserve(Qs.size()); + for (const auto& q : Qs) Qs_mem.emplace_back(q.data(), q.size()); + + buf_t label_extracted; + cb_assert(coinbase::api::pve::get_Label_batch(ct, label_extracted) == SUCCESS); + std::cout << "label extracted matches? " << (label_extracted == buf_t(label)) << "\n"; + + cb_assert(coinbase::api::pve::verify_batch(curve, ek_blob, ct, Qs_mem, label) == SUCCESS); + + std::vector xs_out; + cb_assert(coinbase::api::pve::decrypt_batch(curve, dk_blob, ek_blob, ct, label, xs_out) == SUCCESS); + + bool ok = (xs_out.size() == xs.size()); + for (int i = 0; ok && i < n; i++) ok = (xs_out[static_cast(i)] == buf_t(xs[static_cast(i)])); + std::cout << "decrypt ok? " << ok << "\n"; +} + +void demo_batch_default_base_pke_ecies() { + std::cout << "\n=== PVE Batch (api) + built-in ECIES(P-256) key blob ===\n"; + const coinbase::api::curve_id curve = coinbase::api::curve_id::secp256k1; + const mem_t label("pve-demo-label"); + + constexpr int n = 8; + std::array, n> xs_bytes{}; + for (int i = 0; i < n; i++) { + for (int j = 0; j < 32; j++) xs_bytes[static_cast(i)][static_cast(j)] = static_cast(0x55 + i + j); + } + std::vector xs; + xs.reserve(n); + for (int i = 0; i < n; i++) xs.emplace_back(xs_bytes[static_cast(i)].data(), 32); + + buf_t ek_blob; + buf_t dk_blob; + cb_assert(coinbase::api::pve::generate_base_pke_ecies_p256_keypair(ek_blob, dk_blob) == SUCCESS); + + buf_t ct; + cb_assert(coinbase::api::pve::encrypt_batch(curve, ek_blob, label, xs, ct) == SUCCESS); + + std::vector Qs; + cb_assert(coinbase::api::pve::get_public_keys_compressed_batch(ct, Qs) == SUCCESS); + std::vector Qs_mem; + Qs_mem.reserve(Qs.size()); + for (const auto& q : Qs) Qs_mem.emplace_back(q.data(), q.size()); + + cb_assert(coinbase::api::pve::verify_batch(curve, ek_blob, ct, Qs_mem, label) == SUCCESS); + + std::vector xs_out; + cb_assert(coinbase::api::pve::decrypt_batch(curve, dk_blob, ek_blob, ct, label, xs_out) == SUCCESS); + + bool ok = (xs_out.size() == xs.size()); + for (int i = 0; ok && i < n; i++) ok = (xs_out[static_cast(i)] == buf_t(xs[static_cast(i)])); + std::cout << "decrypt ok? " << ok << "\n"; +} + +void demo_batch_hsm_ecies_p256() { + std::cout << "\n=== PVE Batch (api) + ECIES(P-256) simulated HSM (ECDH callback only) ===\n"; + const coinbase::api::curve_id curve = coinbase::api::curve_id::secp256k1; + const mem_t label("pve-demo-label"); + + constexpr int n = 8; + std::array, n> xs_bytes{}; + for (int i = 0; i < n; i++) { + for (int j = 0; j < 32; j++) xs_bytes[static_cast(i)][static_cast(j)] = static_cast(0x88 + i + j); + } + std::vector xs; + xs.reserve(n); + for (int i = 0; i < n; i++) xs.emplace_back(xs_bytes[static_cast(i)].data(), 32); + + fake_hsm_ecies_p256_t hsm; + const std::string hsm_handle = "hsm-ecies-p256-key-batch-1"; + buf_t pub_key_oct_uncompressed; + cb_assert(hsm.generate_key(hsm_handle, pub_key_oct_uncompressed) == SUCCESS); + + buf_t ek_blob; + cb_assert(coinbase::api::pve::base_pke_ecies_p256_ek_from_oct(pub_key_oct_uncompressed, ek_blob) == SUCCESS); + + buf_t ct; + cb_assert(coinbase::api::pve::encrypt_batch(curve, ek_blob, label, xs, ct) == SUCCESS); + + std::vector Qs; + cb_assert(coinbase::api::pve::get_public_keys_compressed_batch(ct, Qs) == SUCCESS); + std::vector Qs_mem; + Qs_mem.reserve(Qs.size()); + for (const auto& q : Qs) Qs_mem.emplace_back(q.data(), q.size()); + + cb_assert(coinbase::api::pve::verify_batch(curve, ek_blob, ct, Qs_mem, label) == SUCCESS); + + coinbase::api::pve::ecies_p256_hsm_ecdh_cb_t cb; + cb.ctx = &hsm; + cb.ecdh = ecies_p256_hsm_ecdh_cb; + + const mem_t dk_handle(hsm_handle); + std::vector xs_out; + cb_assert(coinbase::api::pve::decrypt_batch_ecies_p256_hsm(curve, dk_handle, ek_blob, ct, label, cb, xs_out) == SUCCESS); + + bool ok = (xs_out.size() == xs.size()); + for (int i = 0; ok && i < n; i++) ok = (xs_out[static_cast(i)] == buf_t(xs[static_cast(i)])); + std::cout << "decrypt ok? " << ok << "\n"; +} + +void demo_batch_custom_base_pke() { + std::cout << "\n=== PVE Batch (api) + custom base PKE ===\n"; + const coinbase::api::curve_id curve = coinbase::api::curve_id::secp256k1; + const mem_t label("pve-demo-label"); + + constexpr int n = 8; + std::array, n> xs_bytes{}; + for (int i = 0; i < n; i++) { + for (int j = 0; j < 32; j++) xs_bytes[static_cast(i)][static_cast(j)] = static_cast(0x22 + i + j); + } + std::vector xs; + xs.reserve(n); + for (int i = 0; i < n; i++) xs.emplace_back(xs_bytes[static_cast(i)].data(), 32); + + std::array key{}; + for (int i = 0; i < 32; i++) key[static_cast(i)] = static_cast(i); + const mem_t ek(key.data(), 32); + const mem_t dk(key.data(), 32); + + toy_base_pke_t toy; + + buf_t ct; + cb_assert(coinbase::api::pve::encrypt_batch(toy, curve, ek, label, xs, ct) == SUCCESS); + + std::vector Qs; + cb_assert(coinbase::api::pve::get_public_keys_compressed_batch(ct, Qs) == SUCCESS); + std::vector Qs_mem; + Qs_mem.reserve(Qs.size()); + for (const auto& q : Qs) Qs_mem.emplace_back(q.data(), q.size()); + + cb_assert(coinbase::api::pve::verify_batch(toy, curve, ek, ct, Qs_mem, label) == SUCCESS); + + std::vector xs_out; + cb_assert(coinbase::api::pve::decrypt_batch(toy, curve, dk, ek, ct, label, xs_out) == SUCCESS); + + bool ok = (xs_out.size() == xs.size()); + for (int i = 0; ok && i < n; i++) ok = (xs_out[static_cast(i)] == buf_t(xs[static_cast(i)])); + std::cout << "decrypt ok? " << ok << "\n"; +} + +} // namespace + +int main(int /*argc*/, const char* /*argv*/[]) { + std::cout << std::boolalpha; + std::cout << "================ PVE Demo (api-only) ================\n"; + + demo_default_base_pke_rsa(); + demo_default_base_pke_ecies(); + demo_hsm_ecies_p256(); + demo_custom_base_pke(); + + demo_ac_default_base_pke_rsa(); + demo_ac_default_base_pke_ecies(); + demo_ac_custom_base_pke(); + + demo_batch_default_base_pke_rsa(); + demo_batch_default_base_pke_ecies(); + demo_batch_hsm_ecies_p256(); + demo_batch_custom_base_pke(); + + return 0; +} + diff --git a/demo-api/schnorr_2p_pve_batch_backup/CMakeLists.txt b/demo-api/schnorr_2p_pve_batch_backup/CMakeLists.txt new file mode 100644 index 00000000..93ee52fc --- /dev/null +++ b/demo-api/schnorr_2p_pve_batch_backup/CMakeLists.txt @@ -0,0 +1,44 @@ +cmake_minimum_required(VERSION 3.16) + +project(mpc-demo-api-schnorr_2p_pve_batch_backup LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(REPO_CMAKE_DIR ${CMAKE_CURRENT_LIST_DIR}/../../cmake) + +include(${REPO_CMAKE_DIR}/macros.cmake) +include(${REPO_CMAKE_DIR}/arch.cmake) +include(${REPO_CMAKE_DIR}/openssl.cmake) +include(${REPO_CMAKE_DIR}/compilation_flags.cmake) + +if(NOT DEFINED CBMPC_SOURCE_DIR) + if(DEFINED ENV{CBMPC_PREFIX}) + set(CBMPC_SOURCE_DIR "$ENV{CBMPC_PREFIX}") + else() + get_filename_component(_cbmpc_repo_root "${CMAKE_CURRENT_LIST_DIR}/../.." ABSOLUTE) + if(EXISTS "${_cbmpc_repo_root}/build/install/public") + set(CBMPC_SOURCE_DIR "${_cbmpc_repo_root}/build/install/public") + else() + set(CBMPC_SOURCE_DIR /usr/local/opt/cbmpc/) + endif() + endif() +endif() + +set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib") +if(EXISTS "${CBMPC_SOURCE_DIR}/lib/Release/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Release") +elseif(EXISTS "${CBMPC_SOURCE_DIR}/lib/Debug/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Debug") +endif() + +add_executable(mpc-demo-api-schnorr_2p_pve_batch_backup main.cpp) + +target_include_directories(mpc-demo-api-schnorr_2p_pve_batch_backup PRIVATE ${CBMPC_SOURCE_DIR}/include) +target_link_directories(mpc-demo-api-schnorr_2p_pve_batch_backup PRIVATE ${CBMPC_LIB_DIR}) +target_link_libraries(mpc-demo-api-schnorr_2p_pve_batch_backup PRIVATE cbmpc) + +# Important for static linking on Linux: ensure libcbmpc.a appears before +# libcrypto.a on the final link line so libcrypto symbols resolve correctly. +link_openssl(mpc-demo-api-schnorr_2p_pve_batch_backup) + diff --git a/demo-api/schnorr_2p_pve_batch_backup/main.cpp b/demo-api/schnorr_2p_pve_batch_backup/main.cpp new file mode 100644 index 00000000..7ca06c08 --- /dev/null +++ b/demo-api/schnorr_2p_pve_batch_backup/main.cpp @@ -0,0 +1,266 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace { + +using namespace coinbase; + +[[noreturn]] void die(const std::string& msg) { + std::cerr << "schnorr_2p_pve_batch_backup demo failure: " << msg << "\n"; + std::exit(1); +} + +void require(bool ok, const std::string& msg) { + if (!ok) die(msg); +} + +void require_rv(error_t got, error_t want, const std::string& msg) { + if (got != want) die(msg + " (got=0x" + std::to_string(uint32_t(got)) + ")"); +} + +// Minimal in-memory 2-party transport. +struct channel_t { + std::mutex m; + std::condition_variable cv; + std::deque q; +}; + +struct in_memory_network_t { + std::shared_ptr ch[2][2]; + std::atomic aborted{false}; + in_memory_network_t() { + ch[0][1] = std::make_shared(); + ch[1][0] = std::make_shared(); + } + + void abort() { + aborted.store(true); + for (int i = 0; i < 2; i++) { + for (int j = 0; j < 2; j++) { + if (ch[i][j]) ch[i][j]->cv.notify_all(); + } + } + } +}; + +class in_memory_transport_t final : public coinbase::api::data_transport_i { + public: + in_memory_transport_t(int self, std::shared_ptr net) : self_(self), net_(std::move(net)) {} + + error_t send(coinbase::api::party_idx_t receiver, mem_t msg) override { + if (!net_) return E_GENERAL; + if (net_->aborted.load()) return E_NET_GENERAL; + if (receiver < 0 || receiver > 1 || receiver == self_) return E_BADARG; + auto c = net_->ch[self_][receiver]; + if (!c) return E_GENERAL; + { + std::lock_guard lk(c->m); + c->q.emplace_back(msg); + } + c->cv.notify_one(); + return SUCCESS; + } + + error_t receive(coinbase::api::party_idx_t sender, buf_t& msg) override { + if (!net_ || sender < 0 || sender > 1 || sender == self_) return E_BADARG; + auto c = net_->ch[sender][self_]; + if (!c) return E_GENERAL; + std::unique_lock lk(c->m); + c->cv.wait(lk, [&] { return net_->aborted.load() || !c->q.empty(); }); + if (net_->aborted.load() && c->q.empty()) return E_NET_GENERAL; + msg = std::move(c->q.front()); + c->q.pop_front(); + return SUCCESS; + } + + error_t receive_all(const std::vector& senders, std::vector& msgs) override { + msgs.clear(); + msgs.resize(senders.size()); + for (size_t i = 0; i < senders.size(); i++) { + error_t rv = receive(senders[i], msgs[i]); + if (rv) return rv; + } + return SUCCESS; + } + + private: + const int self_; + std::shared_ptr net_; +}; + +template +void run_2pc(in_memory_network_t* net, F1&& f1, F2&& f2, error_t& out_rv1, error_t& out_rv2) { + std::thread t1([&] { + out_rv1 = f1(); + if (out_rv1 && net) net->abort(); + }); + std::thread t2([&] { + out_rv2 = f2(); + if (out_rv2 && net) net->abort(); + }); + t1.join(); + t2.join(); +} + +static buf_t make_msg32(uint8_t seed) { + std::array msg{}; + for (size_t i = 0; i < msg.size(); i++) msg[i] = static_cast(seed + i); + return buf_t(msg.data(), static_cast(msg.size())); +} + +} // namespace + +int main() { + using coinbase::api::curve_id; + using coinbase::api::schnorr_2p::party_t; + + std::cout << "=== Schnorr-2P (api) + PVE batch backup (5x DKG) ===\n"; + + constexpr int batch_count = 5; + const curve_id curve = curve_id::secp256k1; + const mem_t label("schnorr-2p-demo:pve-batch-backup"); + + // Base-PKE keypair (used to encrypt and decrypt the batch ciphertext). + buf_t ek; + buf_t dk; + require_rv(coinbase::api::pve::generate_base_pke_rsa_keypair(ek, dk), SUCCESS, "generate_base_pke_rsa_keypair"); + + std::vector key_p2(batch_count); + std::vector public_p1(batch_count); + std::vector x_fixed(batch_count); + std::vector Qi_p1(batch_count); + std::vector pubkeys(batch_count); + + // Run DKG 5 times and detach p1's scalar share each time. + for (int k = 0; k < batch_count; k++) { + auto net = std::make_shared(); + in_memory_transport_t t1(/*self=*/0, net); + in_memory_transport_t t2(/*self=*/1, net); + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + buf_t key1; + buf_t key2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + run_2pc(net.get(), [&] { return coinbase::api::schnorr_2p::dkg(job1, curve, key1); }, + [&] { return coinbase::api::schnorr_2p::dkg(job2, curve, key2); }, rv1, rv2); + require_rv(rv1, SUCCESS, "dkg p1 (k=" + std::to_string(k) + ")"); + require_rv(rv2, SUCCESS, "dkg p2 (k=" + std::to_string(k) + ")"); + + // Extract and sanity-check global public key. + buf_t pub1; + buf_t pub2; + require_rv(coinbase::api::schnorr_2p::get_public_key_compressed(key1, pub1), SUCCESS, "get_public_key_compressed p1"); + require_rv(coinbase::api::schnorr_2p::get_public_key_compressed(key2, pub2), SUCCESS, "get_public_key_compressed p2"); + require(pub1 == pub2, "public key mismatch (k=" + std::to_string(k) + ")"); + pubkeys[static_cast(k)] = pub1; + + // Capture p1's share public point before detaching (public blobs do not carry it). + require_rv(coinbase::api::schnorr_2p::get_public_share_compressed(key1, Qi_p1[static_cast(k)]), SUCCESS, + "get_public_share_compressed p1"); + + // Detach p1 scalar share into (public blob, scalar). + require_rv(coinbase::api::schnorr_2p::detach_private_scalar(key1, public_p1[static_cast(k)], + x_fixed[static_cast(k)]), + SUCCESS, "detach_private_scalar p1"); + require(x_fixed[static_cast(k)].size() == 32, "unexpected scalar size"); + + // Keep p2's full key blob for later signing. + key_p2[static_cast(k)] = key2; + } + + // Prepare batch scalars and corresponding share points. + std::vector xs; + xs.reserve(batch_count); + std::vector Qs; + Qs.reserve(batch_count); + for (int k = 0; k < batch_count; k++) { + xs.emplace_back(x_fixed[static_cast(k)].data(), x_fixed[static_cast(k)].size()); + Qs.emplace_back(Qi_p1[static_cast(k)].data(), Qi_p1[static_cast(k)].size()); + } + + // Encrypt + verify batch ciphertext. + buf_t ct; + require_rv(coinbase::api::pve::encrypt_batch(curve, mem_t(ek.data(), ek.size()), label, xs, ct), SUCCESS, + "pve::encrypt_batch"); + + int got_count = 0; + require_rv(coinbase::api::pve::get_batch_count(mem_t(ct.data(), ct.size()), got_count), SUCCESS, "pve::get_batch_count"); + require(got_count == batch_count, "unexpected batch_count"); + + buf_t label_out; + require_rv(coinbase::api::pve::get_Label_batch(mem_t(ct.data(), ct.size()), label_out), SUCCESS, "pve::get_Label_batch"); + require(label_out == buf_t(label), "label mismatch"); + + require_rv(coinbase::api::pve::verify_batch(curve, mem_t(ek.data(), ek.size()), mem_t(ct.data(), ct.size()), Qs, label), + SUCCESS, "pve::verify_batch"); + + // Decrypt batch ciphertext (simulate restoring from backup service). + std::vector xs_out; + require_rv(coinbase::api::pve::decrypt_batch(curve, mem_t(dk.data(), dk.size()), mem_t(ek.data(), ek.size()), + mem_t(ct.data(), ct.size()), label, xs_out), + SUCCESS, "pve::decrypt_batch"); + require(static_cast(xs_out.size()) == batch_count, "decrypt_batch returned wrong count"); + + // Restore p1 key blobs and sign once per key. + for (int k = 0; k < batch_count; k++) { + buf_t restored_p1; + require_rv(coinbase::api::schnorr_2p::attach_private_scalar(public_p1[static_cast(k)], + mem_t(xs_out[static_cast(k)].data(), + xs_out[static_cast(k)].size()), + mem_t(Qi_p1[static_cast(k)].data(), + Qi_p1[static_cast(k)].size()), + restored_p1), + SUCCESS, "attach_private_scalar p1"); + + // Sign with restored p1 + original p2. + auto net = std::make_shared(); + in_memory_transport_t t1(/*self=*/0, net); + in_memory_transport_t t2(/*self=*/1, net); + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + const buf_t msg = make_msg32(static_cast(0x11 + k)); + buf_t sig1; + buf_t sig2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + run_2pc(net.get(), [&] { return coinbase::api::schnorr_2p::sign(job1, restored_p1, msg, sig1); }, + [&] { return coinbase::api::schnorr_2p::sign(job2, key_p2[static_cast(k)], msg, sig2); }, rv1, rv2); + require_rv(rv1, SUCCESS, "sign p1 (k=" + std::to_string(k) + ")"); + require_rv(rv2, SUCCESS, "sign p2 (k=" + std::to_string(k) + ")"); + + require(sig1.size() == 64, "unexpected signature size"); + require(sig2.empty(), "signature should be returned only on p1"); + + // Public key should remain stable after restore. + buf_t pub_restored; + require_rv(coinbase::api::schnorr_2p::get_public_key_compressed(restored_p1, pub_restored), SUCCESS, + "get_public_key_compressed restored p1"); + require(pub_restored == pubkeys[static_cast(k)], "restored public key mismatch"); + } + + std::cout << "Done.\n"; + return 0; +} + diff --git a/demo-cpp/basic_primitive/CMakeLists.txt b/demo-cpp/basic_primitive/CMakeLists.txt new file mode 100755 index 00000000..5ab7f2e5 --- /dev/null +++ b/demo-cpp/basic_primitive/CMakeLists.txt @@ -0,0 +1,47 @@ +cmake_minimum_required(VERSION 3.16) + +project(mpc-demo-basic_primitive LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(REPO_CMAKE_DIR ${CMAKE_CURRENT_LIST_DIR}/../../cmake) + +include(${REPO_CMAKE_DIR}/macros.cmake) +include(${REPO_CMAKE_DIR}/arch.cmake) +include(${REPO_CMAKE_DIR}/openssl.cmake) +include(${REPO_CMAKE_DIR}/compilation_flags.cmake) + +if(NOT DEFINED CBMPC_SOURCE_DIR) + if(DEFINED ENV{CBMPC_PREFIX}) + set(CBMPC_SOURCE_DIR "$ENV{CBMPC_PREFIX}") + else() + get_filename_component(_cbmpc_repo_root "${CMAKE_CURRENT_LIST_DIR}/../.." ABSOLUTE) + if(EXISTS "${_cbmpc_repo_root}/build/install/full") + set(CBMPC_SOURCE_DIR "${_cbmpc_repo_root}/build/install/full") + else() + set(CBMPC_SOURCE_DIR /usr/local/opt/cbmpc/) + endif() + endif() +endif() + +set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib") +if(EXISTS "${CBMPC_SOURCE_DIR}/lib/Release/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Release") +elseif(EXISTS "${CBMPC_SOURCE_DIR}/lib/Debug/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Debug") +endif() + +add_executable(mpc-demo-basic_primitive main.cpp) +link_openssl(mpc-demo-basic_primitive) + +target_include_directories(mpc-demo-basic_primitive PUBLIC ${CBMPC_SOURCE_DIR}/include) +if(EXISTS "${CBMPC_SOURCE_DIR}/include-internal") + target_include_directories(mpc-demo-basic_primitive PUBLIC ${CBMPC_SOURCE_DIR}/include-internal) +endif() +target_link_directories(mpc-demo-basic_primitive PUBLIC ${CBMPC_LIB_DIR}) +target_link_libraries(mpc-demo-basic_primitive PRIVATE cbmpc) + +if(IS_LINUX) + link_openssl(mpc-demo-basic_primitive) +endif() diff --git a/demos-cpp/basic_primitive/main.cpp b/demo-cpp/basic_primitive/main.cpp similarity index 84% rename from demos-cpp/basic_primitive/main.cpp rename to demo-cpp/basic_primitive/main.cpp index 002fe3f5..7433748e 100644 --- a/demos-cpp/basic_primitive/main.cpp +++ b/demo-cpp/basic_primitive/main.cpp @@ -1,9 +1,12 @@ #include #include -#include -#include -#include +#include +#include +#include + +using namespace coinbase; +using namespace coinbase::crypto; bn_t hash_number() { @@ -41,7 +44,7 @@ error_t com() ecc_point_t G = c.generator(); buf_t sid = coinbase::crypto::gen_random(16); - pid_t pid = coinbase::crypto::pid_from_name("test"); + coinbase::crypto::mpc_pid_t pid = coinbase::crypto::pid_from_name("test"); coinbase::crypto::commitment_t com(sid, pid); com.gen(G); std::cout << bn_t(com.msg).to_string() << std::endl; diff --git a/src/cbmpc/protocol/mpc_job_session.cpp b/demo-cpp/common/mpc_job_session.cpp similarity index 99% rename from src/cbmpc/protocol/mpc_job_session.cpp rename to demo-cpp/common/mpc_job_session.cpp index 2fd8dd21..aacff16c 100644 --- a/src/cbmpc/protocol/mpc_job_session.cpp +++ b/demo-cpp/common/mpc_job_session.cpp @@ -1,7 +1,7 @@ -#include "mpc_job_session.h" - #include +#include "mpc_job_session.h" + namespace coinbase::mpc { error_t parallel_data_transport_t::send(const party_idx_t receiver, const parallel_id_t parallel_id, const mem_t msg) { @@ -54,6 +54,7 @@ error_t parallel_data_transport_t::send(const party_idx_t receiver, const parall if (is_send_active == 0) send_active_cv.notify_all(); } + (void)rv; return SUCCESS; } @@ -267,3 +268,4 @@ void parallel_data_transport_t::set_parallel(int _parallel_count) { } } // namespace coinbase::mpc + diff --git a/src/cbmpc/protocol/mpc_job_session.h b/demo-cpp/common/mpc_job_session.h similarity index 74% rename from src/cbmpc/protocol/mpc_job_session.h rename to demo-cpp/common/mpc_job_session.h index e246f06d..7e9fa733 100644 --- a/src/cbmpc/protocol/mpc_job_session.h +++ b/demo-cpp/common/mpc_job_session.h @@ -1,17 +1,30 @@ #pragma once -#include +#include +#include +#include +#include +#include +#include -#include "data_transport.h" +#include namespace coinbase::mpc { using parallel_id_t = int32_t; -class parallel_data_transport_t : public data_transport_interface_t { +// parallel_data_transport_t demonstrates that the core `job_*` types do not depend +// on any particular transport implementation. +// +// It also demonstrates one way to implement a "parallel transport": multiplexing +// several logical messages (distinguished by `parallel_id`) into a single +// underlying transport message, allowing multiple protocol sessions to share a +// single transport. +class parallel_data_transport_t : public coinbase::api::data_transport_i { public: - parallel_data_transport_t(std::shared_ptr _data_transport_ptr, int _parallel_count = 1) - : data_transport_ptr(_data_transport_ptr) { + parallel_data_transport_t(std::shared_ptr _data_transport_ptr, + int _parallel_count = 1) + : data_transport_ptr(std::move(_data_transport_ptr)) { set_parallel(_parallel_count); } @@ -19,7 +32,7 @@ class parallel_data_transport_t : public data_transport_interface_t { error_t receive(const party_idx_t sender, const parallel_id_t parallel_id, buf_t& msg); error_t receive_all(const std::vector& senders, parallel_id_t parallel_id, std::vector& msgs); - // data_transport_interface_t overrides using jsid 0 + // data_transport_i overrides using parallel_id 0 error_t send(const party_idx_t receiver, mem_t msg) override { return send(receiver, 0, msg); } error_t receive(const party_idx_t sender, buf_t& msg) override { return receive(sender, 0, msg); } error_t receive_all(const std::vector& senders, std::vector& msgs) override { @@ -29,7 +42,7 @@ class parallel_data_transport_t : public data_transport_interface_t { void set_parallel(int _parallel_count); private: - std::shared_ptr data_transport_ptr; + std::shared_ptr data_transport_ptr; int parallel_count; // For parallel send @@ -97,15 +110,14 @@ class job_parallel_mp_t : public job_mp_t { public: job_parallel_mp_t(party_idx_t index, std::vector pnames, std::shared_ptr _network_ptr, parallel_id_t _parallel_id) - : job_mp_t(index, pnames, _network_ptr), parallel_id(_parallel_id) {} + : job_mp_t(index, std::move(pnames), std::move(_network_ptr)), parallel_id(_parallel_id) {} void set_network(party_idx_t party_idx, std::shared_ptr ptr) { - set_transport(party_idx, ptr); + set_transport(party_idx, std::move(ptr)); } - job_parallel_mp_t get_parallel_job(int parallel_count, parallel_id_t id) { - return job_parallel_mp_t(party_index, names, std::static_pointer_cast(transport_ptr), - id); + job_parallel_mp_t get_parallel_job(int /*parallel_count*/, parallel_id_t id) { + return job_parallel_mp_t(party_index, names, std::static_pointer_cast(transport_ptr), id); } protected: @@ -116,13 +128,13 @@ class job_parallel_2p_t : public job_2p_t { public: job_parallel_2p_t(party_t party, crypto::pname_t pname1, crypto::pname_t pname2, std::shared_ptr ptr, parallel_id_t id = 0) - : job_2p_t(party, pname1, pname2, ptr), parallel_id(id) {}; + : job_2p_t(party, std::move(pname1), std::move(pname2), std::move(ptr)), parallel_id(id) {} void set_network(party_t party, std::shared_ptr ptr) { - set_transport(party_idx_t(party), ptr); + set_transport(party_idx_t(party), std::move(ptr)); } - job_parallel_2p_t get_parallel_job(int parallel_count, parallel_id_t id) { + job_parallel_2p_t get_parallel_job(int /*parallel_count*/, parallel_id_t id) { return job_parallel_2p_t(party_t(party_index), names[0], names[1], std::static_pointer_cast(transport_ptr), id); } @@ -132,6 +144,7 @@ class job_parallel_2p_t : public job_2p_t { if (network) network->set_parallel(parallel_count); } + protected: error_t send_impl(party_idx_t to, mem_t msg) override { auto network = std::static_pointer_cast(transport_ptr); if (!network) return E_NET_GENERAL; @@ -148,3 +161,4 @@ class job_parallel_2p_t : public job_2p_t { }; } // namespace coinbase::mpc + diff --git a/demo-cpp/parallel_transport/CMakeLists.txt b/demo-cpp/parallel_transport/CMakeLists.txt new file mode 100644 index 00000000..bad4c186 --- /dev/null +++ b/demo-cpp/parallel_transport/CMakeLists.txt @@ -0,0 +1,56 @@ +cmake_minimum_required(VERSION 3.16) + +project(mpc-demo-parallel_transport LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(REPO_CMAKE_DIR ${CMAKE_CURRENT_LIST_DIR}/../../cmake) + +include(${REPO_CMAKE_DIR}/macros.cmake) +include(${REPO_CMAKE_DIR}/arch.cmake) +include(${REPO_CMAKE_DIR}/openssl.cmake) +include(${REPO_CMAKE_DIR}/compilation_flags.cmake) + +if(NOT DEFINED CBMPC_SOURCE_DIR) + if(DEFINED ENV{CBMPC_PREFIX}) + set(CBMPC_SOURCE_DIR "$ENV{CBMPC_PREFIX}") + else() + get_filename_component(_cbmpc_repo_root "${CMAKE_CURRENT_LIST_DIR}/../.." ABSOLUTE) + if(EXISTS "${_cbmpc_repo_root}/build/install/full") + set(CBMPC_SOURCE_DIR "${_cbmpc_repo_root}/build/install/full") + else() + set(CBMPC_SOURCE_DIR /usr/local/opt/cbmpc/) + endif() + endif() +endif() + +set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib") +if(EXISTS "${CBMPC_SOURCE_DIR}/lib/Release/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Release") +elseif(EXISTS "${CBMPC_SOURCE_DIR}/lib/Debug/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Debug") +endif() + +add_executable( + mpc-demo-parallel_transport + main.cpp + ../common/mpc_job_session.cpp +) + +# The demo helper header lives in demo-cpp/common. +target_include_directories(mpc-demo-parallel_transport PRIVATE ${CMAKE_CURRENT_LIST_DIR}/../common) + +target_include_directories(mpc-demo-parallel_transport PRIVATE ${CBMPC_SOURCE_DIR}/include) +if(EXISTS "${CBMPC_SOURCE_DIR}/include-internal") + target_include_directories(mpc-demo-parallel_transport PRIVATE ${CBMPC_SOURCE_DIR}/include-internal) +endif() + +target_link_directories(mpc-demo-parallel_transport PRIVATE ${CBMPC_LIB_DIR}) +target_link_libraries(mpc-demo-parallel_transport PRIVATE cbmpc) +link_openssl(mpc-demo-parallel_transport) + +if(IS_LINUX) + link_openssl(mpc-demo-parallel_transport) +endif() + diff --git a/demo-cpp/parallel_transport/main.cpp b/demo-cpp/parallel_transport/main.cpp new file mode 100644 index 00000000..67dadc73 --- /dev/null +++ b/demo-cpp/parallel_transport/main.cpp @@ -0,0 +1,231 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "mpc_job_session.h" + +namespace { + +using namespace coinbase; +using namespace coinbase::mpc; + +[[noreturn]] void die(const std::string& msg) { + std::cerr << "parallel_transport demo failure: " << msg << std::endl; + std::exit(1); +} + +void require(bool ok, const std::string& msg) { + if (!ok) die(msg); +} + +void require_rv(error_t got, error_t want, const std::string& msg) { + if (got != want) { + die(msg + " (got=" + std::to_string(int(got)) + ", want=" + std::to_string(int(want)) + ")"); + } +} + +// A minimal in-memory transport implementation that demonstrates how users can +// bring their own transport without touching protocol/job definitions. +struct channel_t { + std::mutex m; + std::condition_variable cv; + std::deque q; +}; + +struct in_memory_network_t { + explicit in_memory_network_t(int n) : n(n), ch(n, std::vector>(n)) { + for (int from = 0; from < n; ++from) { + for (int to = 0; to < n; ++to) { + if (from == to) continue; + ch[from][to] = std::make_shared(); + } + } + } + + const int n; + // ch[from][to] is the message queue from `from` to `to`. + std::vector>> ch; +}; + +class in_memory_transport_t final : public coinbase::api::data_transport_i { + public: + in_memory_transport_t(int self, std::shared_ptr net) : self_(self), net_(std::move(net)) {} + + error_t send(party_idx_t receiver, mem_t msg) override { + if (!net_ || receiver < 0 || receiver >= net_->n || receiver == self_) return E_BADARG; + auto c = net_->ch[self_][receiver]; + if (!c) return E_GENERAL; + { + std::lock_guard lk(c->m); + c->q.emplace_back(msg); + } + c->cv.notify_one(); + return SUCCESS; + } + + error_t receive(party_idx_t sender, buf_t& msg) override { + if (!net_ || sender < 0 || sender >= net_->n || sender == self_) return E_BADARG; + auto c = net_->ch[sender][self_]; + if (!c) return E_GENERAL; + std::unique_lock lk(c->m); + c->cv.wait(lk, [&] { return !c->q.empty(); }); + msg = std::move(c->q.front()); + c->q.pop_front(); + return SUCCESS; + } + + error_t receive_all(const std::vector& senders, std::vector& msgs) override { + msgs.clear(); + msgs.resize(senders.size()); + for (size_t i = 0; i < senders.size(); ++i) { + error_t rv = receive(senders[i], msgs[i]); + if (rv) return rv; + } + return SUCCESS; + } + + private: + const int self_; + std::shared_ptr net_; +}; + +class fixed_buf_transport_t final : public coinbase::api::data_transport_i { + public: + explicit fixed_buf_transport_t(buf_t malicious) : malicious_buf_(std::move(malicious)) {} + + error_t send(party_idx_t /*receiver*/, mem_t /*msg*/) override { return SUCCESS; } + + error_t receive(party_idx_t /*sender*/, buf_t& msg) override { + msg = malicious_buf_; + return SUCCESS; + } + + error_t receive_all(const std::vector& senders, std::vector& msgs) override { + msgs.assign(senders.size(), malicious_buf_); + return SUCCESS; + } + + private: + buf_t malicious_buf_; +}; + +struct scoped_log_sink_t { + scoped_log_sink_t() : prev_(coinbase::out_log_fun) { coinbase::out_log_fun = &scoped_log_sink_t::discard; } + ~scoped_log_sink_t() { coinbase::out_log_fun = prev_; } + scoped_log_sink_t(const scoped_log_sink_t&) = delete; + scoped_log_sink_t& operator=(const scoped_log_sink_t&) = delete; + + private: + static void discard(int /*mode*/, const char* /*str*/) {} + coinbase::out_log_str_f prev_; +}; + +void demo_parallel_transport_oob_receive() { + scoped_log_sink_t logs; + + // A single byte `0x00` decodes to vector length = 0 (via convert_len). + buf_t malicious(1); + malicious[0] = 0x00; + + auto transport = std::make_shared(malicious); + parallel_data_transport_t network(transport, /*_parallel_count=*/2); + + error_t rv0 = UNINITIALIZED_ERROR; + error_t rv1 = UNINITIALIZED_ERROR; + buf_t out0, out1; + + std::thread t0([&] { rv0 = network.receive(/*sender=*/0, /*parallel_id=*/0, out0); }); + std::thread t1([&] { rv1 = network.receive(/*sender=*/0, /*parallel_id=*/1, out1); }); + t0.join(); + t1.join(); + + require_rv(rv0, E_FORMAT, "receive: expected E_FORMAT for parallel_id=0"); + require_rv(rv1, E_FORMAT, "receive: expected E_FORMAT for parallel_id=1"); + require(out0.empty(), "receive: expected empty out0 on error"); + require(out1.empty(), "receive: expected empty out1 on error"); +} + +void demo_parallel_transport_oob_receive_all() { + scoped_log_sink_t logs; + + buf_t malicious(1); + malicious[0] = 0x00; + + auto transport = std::make_shared(malicious); + parallel_data_transport_t network(transport, /*_parallel_count=*/2); + + const std::vector senders = {0, 1, 2}; + + error_t rv0 = UNINITIALIZED_ERROR; + error_t rv1 = UNINITIALIZED_ERROR; + std::vector outs0(senders.size()); + std::vector outs1(senders.size()); + + std::thread t0([&] { rv0 = network.receive_all(senders, /*parallel_id=*/0, outs0); }); + std::thread t1([&] { rv1 = network.receive_all(senders, /*parallel_id=*/1, outs1); }); + t0.join(); + t1.join(); + + require_rv(rv0, E_FORMAT, "receive_all: expected E_FORMAT for parallel_id=0"); + require_rv(rv1, E_FORMAT, "receive_all: expected E_FORMAT for parallel_id=1"); + for (const auto& m : outs0) require(m.empty(), "receive_all: expected empty output on error (outs0)"); + for (const auto& m : outs1) require(m.empty(), "receive_all: expected empty output on error (outs1)"); +} + +void demo_parallel_2pc_messaging() { + const int parallel_count = 16; + + auto net = std::make_shared(/*n=*/2); + auto t_p1 = std::make_shared(/*self=*/0, net); + auto t_p2 = std::make_shared(/*self=*/1, net); + + auto n_p1 = std::make_shared(t_p1, parallel_count); + auto n_p2 = std::make_shared(t_p2, parallel_count); + + std::atomic finished{0}; + std::vector threads; + threads.reserve(size_t(parallel_count) * 2); + + for (int th_i = 0; th_i < parallel_count; ++th_i) { + threads.emplace_back([&, th_i] { + job_parallel_2p_t job(party_t::p1, /*pname1=*/"p1", /*pname2=*/"p2", n_p1, parallel_id_t(th_i)); + buf_t data("msg:" + std::to_string(th_i)); + require_rv(job.p1_to_p2(data), SUCCESS, "2pc parallel: p1_to_p2 failed (p1)"); + finished++; + }); + } + + for (int th_i = 0; th_i < parallel_count; ++th_i) { + threads.emplace_back([&, th_i] { + job_parallel_2p_t job(party_t::p2, /*pname1=*/"p1", /*pname2=*/"p2", n_p2, parallel_id_t(th_i)); + buf_t data; + require_rv(job.p1_to_p2(data), SUCCESS, "2pc parallel: p1_to_p2 failed (p2)"); + require(data == buf_t("msg:" + std::to_string(th_i)), "2pc parallel: received message mismatch"); + finished++; + }); + } + + for (auto& t : threads) t.join(); + require(finished == parallel_count * 2, "2pc parallel: not all threads finished"); +} + +} // namespace + +int main() { + demo_parallel_transport_oob_receive(); + demo_parallel_transport_oob_receive_all(); + demo_parallel_2pc_messaging(); + + std::cout << "parallel_transport demo: OK" << std::endl; + return 0; +} + diff --git a/demo-cpp/zk/CMakeLists.txt b/demo-cpp/zk/CMakeLists.txt new file mode 100755 index 00000000..fbe6c7de --- /dev/null +++ b/demo-cpp/zk/CMakeLists.txt @@ -0,0 +1,47 @@ +cmake_minimum_required(VERSION 3.16) + +project(mpc-demo-zk LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(REPO_CMAKE_DIR ${CMAKE_CURRENT_LIST_DIR}/../../cmake) + +include(${REPO_CMAKE_DIR}/macros.cmake) +include(${REPO_CMAKE_DIR}/arch.cmake) +include(${REPO_CMAKE_DIR}/openssl.cmake) +include(${REPO_CMAKE_DIR}/compilation_flags.cmake) + +if(NOT DEFINED CBMPC_SOURCE_DIR) + if(DEFINED ENV{CBMPC_PREFIX}) + set(CBMPC_SOURCE_DIR "$ENV{CBMPC_PREFIX}") + else() + get_filename_component(_cbmpc_repo_root "${CMAKE_CURRENT_LIST_DIR}/../.." ABSOLUTE) + if(EXISTS "${_cbmpc_repo_root}/build/install/full") + set(CBMPC_SOURCE_DIR "${_cbmpc_repo_root}/build/install/full") + else() + set(CBMPC_SOURCE_DIR /usr/local/opt/cbmpc/) + endif() + endif() +endif() + +set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib") +if(EXISTS "${CBMPC_SOURCE_DIR}/lib/Release/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Release") +elseif(EXISTS "${CBMPC_SOURCE_DIR}/lib/Debug/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Debug") +endif() + +add_executable(mpc-demo-zk main.cpp) + +link_openssl(mpc-demo-zk) +target_include_directories(mpc-demo-zk PRIVATE ${CBMPC_SOURCE_DIR}/include) +if(EXISTS "${CBMPC_SOURCE_DIR}/include-internal") + target_include_directories(mpc-demo-zk PRIVATE ${CBMPC_SOURCE_DIR}/include-internal) +endif() +target_link_directories(mpc-demo-zk PRIVATE ${CBMPC_LIB_DIR}) +target_link_libraries(mpc-demo-zk PRIVATE cbmpc) + +if(IS_LINUX) + link_openssl(mpc-demo-zk) +endif() diff --git a/demos-cpp/zk/demo_nizk.h b/demo-cpp/zk/demo_nizk.h similarity index 91% rename from demos-cpp/zk/demo_nizk.h rename to demo-cpp/zk/demo_nizk.h index 3c4c6937..429e9296 100644 --- a/demos-cpp/zk/demo_nizk.h +++ b/demo-cpp/zk/demo_nizk.h @@ -1,7 +1,7 @@ -#include -#include -#include -#include +#include +#include +#include +#include struct demo_nizk_t { diff --git a/demos-cpp/zk/main.cpp b/demo-cpp/zk/main.cpp similarity index 87% rename from demos-cpp/zk/main.cpp rename to demo-cpp/zk/main.cpp index 8cdbd950..d9ef19f6 100644 --- a/demos-cpp/zk/main.cpp +++ b/demo-cpp/zk/main.cpp @@ -1,7 +1,10 @@ #include #include -#include +#include + +using namespace coinbase; +using namespace coinbase::crypto; #include "demo_nizk.h" diff --git a/demos-cpp/basic_primitive/CMakeLists.txt b/demos-cpp/basic_primitive/CMakeLists.txt deleted file mode 100755 index f34fe56b..00000000 --- a/demos-cpp/basic_primitive/CMakeLists.txt +++ /dev/null @@ -1,26 +0,0 @@ -cmake_minimum_required(VERSION 3.16) - -project(mpc-demo-basic_primitive LANGUAGES CXX) - -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) - -set(REPO_CMAKE_DIR ${CMAKE_CURRENT_LIST_DIR}/../../cmake) - -include(${REPO_CMAKE_DIR}/macros.cmake) -include(${REPO_CMAKE_DIR}/arch.cmake) -include(${REPO_CMAKE_DIR}/openssl.cmake) -include(${REPO_CMAKE_DIR}/compilation_flags.cmake) - -set(CBMPC_SOURCE_DIR /usr/local/opt/cbmpc/) - -add_executable(mpc-demo-basic_primitive main.cpp) -link_openssl(mpc-demo-basic_primitive) - -target_include_directories(mpc-demo-basic_primitive PUBLIC ${CBMPC_SOURCE_DIR}/include) -target_link_directories(mpc-demo-basic_primitive PUBLIC ${CBMPC_SOURCE_DIR}/lib) -target_link_libraries(mpc-demo-basic_primitive PRIVATE cbmpc) - -if(IS_LINUX) - link_openssl(mpc-demo-basic_primitive) -endif() diff --git a/demos-cpp/zk/CMakeLists.txt b/demos-cpp/zk/CMakeLists.txt deleted file mode 100755 index e8ab1358..00000000 --- a/demos-cpp/zk/CMakeLists.txt +++ /dev/null @@ -1,26 +0,0 @@ -cmake_minimum_required(VERSION 3.16) - -project(mpc-demo-zk LANGUAGES CXX) - -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) - -set(REPO_CMAKE_DIR ${CMAKE_CURRENT_LIST_DIR}/../../cmake) - -include(${REPO_CMAKE_DIR}/macros.cmake) -include(${REPO_CMAKE_DIR}/arch.cmake) -include(${REPO_CMAKE_DIR}/openssl.cmake) -include(${REPO_CMAKE_DIR}/compilation_flags.cmake) - -set(CBMPC_SOURCE_DIR /usr/local/opt/cbmpc/) - -add_executable(mpc-demo-zk main.cpp) - -link_openssl(mpc-demo-zk) -target_include_directories(mpc-demo-zk PRIVATE ${CBMPC_SOURCE_DIR}/include) -target_link_directories(mpc-demo-zk PRIVATE ${CBMPC_SOURCE_DIR}/lib) -target_link_libraries(mpc-demo-zk PRIVATE cbmpc) - -if(IS_LINUX) - link_openssl(mpc-demo-zk) -endif() diff --git a/demos-go/cb-mpc-go/.gitignore b/demos-go/cb-mpc-go/.gitignore deleted file mode 100644 index 9ed3b07c..00000000 --- a/demos-go/cb-mpc-go/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.test diff --git a/demos-go/cb-mpc-go/api/curve/curve.go b/demos-go/cb-mpc-go/api/curve/curve.go deleted file mode 100644 index 2441cfa4..00000000 --- a/demos-go/cb-mpc-go/api/curve/curve.go +++ /dev/null @@ -1,232 +0,0 @@ -package curve - -import ( - "fmt" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// Bridging conventions between Go and native (cgo boundary): -// -// - Curves are passed as integer codes (OpenSSL NIDs). The bindings resolve -// native curve handles internally via ECurveFind(code) when needed. -// - Points, scalars and other crypto artefacts are passed as canonical []byte -// encodings. The bindings reconstruct native objects internally as needed. -// -// Rationale: keep the public Go API free of native pointers/handles and any -// linkname plumbing, while incurring only negligible overhead to reconstruct -// native structures inside the bindings. -// -// The helpers Code(c) and NewFromCode(code) are the only public surface needed -// by higher layers to follow this convention consistently. - -// Curve is the public interface that represents an elliptic curve supported by the cb-mpc library. -// -// Concrete curves – secp256k1, P-256 and Ed25519 – implement this interface. -// Users should obtain a curve via the constructor helpers NewSecp256k1, NewP256 or NewEd25519 -// rather than dealing with numeric curve codes directly. -// -// All implementations wrap native (C++) resources. Therefore each Curve must be released -// with a call to Free once it is no longer needed to avoid memory leaks. -// Alternatively, the caller can rely on the Go GC finalizer to invoke Free automatically -// (not implemented here to avoid hidden costs). -// -// The methods mirror the functionality that was previously available on the ECurve struct. -// They remain unchanged so existing call-sites require only minimal migration. -type Curve interface { - // Generator returns the generator point of the curve. - Generator() *Point - // Order returns the (big-endian) order of the curve group. - Order() []byte - // Free releases the native resources associated with the curve. - Free() - // RandomScalar returns a uniformly random non-zero scalar in the interval - // [1, Order()-1]. The random sampling is delegated to the native C++ layer. - RandomScalar() (*Scalar, error) - // MultiplyGenerator multiplies the curve generator by the given scalar - // and returns the resulting point (k * G). - MultiplyGenerator(k *Scalar) (*Point, error) - // RandomKeyPair returns a uniformly random non-zero scalar in the interval - // [1, Order()-1] and the corresponding point (k * G). - RandomKeyPair() (*Scalar, *Point, error) - // Add returns (a + b) mod Order() as a new Scalar. - Add(a, b *Scalar) (*Scalar, error) - // String returns a human friendly identifier (implements fmt.Stringer). - fmt.Stringer -} - -// Internal numeric identifiers – matching OpenSSL NIDs – used by the native library. -const ( - secp256k1Code = 714 // OpenSSL NID_secp256k1 - p256Code = 415 // OpenSSL NID_X9_62_prime256v1 - ed25519Code = 1087 // OpenSSL NID_ED25519 -) - -// ========================= common implementation ========================= - -type baseCurve struct { - cCurve cgobinding.ECurveRef -} - -func newBaseCurve(code int) (*baseCurve, error) { - cCurve, err := cgobinding.ECurveFind(code) - if err != nil { - return nil, err - } - return &baseCurve{cCurve: cCurve}, nil -} - -func (b *baseCurve) Generator() *Point { - cPoint := cgobinding.ECurveGenerator(b.cCurve) - return &Point{cPoint: cPoint} -} - -func (b *baseCurve) Order() []byte { - return cgobinding.ECurveOrderToMem(b.cCurve) -} - -func (b *baseCurve) Free() { - b.cCurve.Free() -} - -func (b *baseCurve) RandomScalar() (*Scalar, error) { - // Delegate sampling to the native library so we stay consistent with the - // core C++ implementation. - kBytes := cgobinding.ECurveRandomScalarToMem(b.cCurve) - if len(kBytes) == 0 { - return nil, fmt.Errorf("failed to generate random scalar") - } - return &Scalar{Bytes: kBytes}, nil -} - -func (b *baseCurve) MultiplyGenerator(k *Scalar) (*Point, error) { - if k == nil { - return nil, fmt.Errorf("scalar is nil") - } - gen := b.Generator() - defer gen.Free() - return gen.Multiply(k) -} - -func (b *baseCurve) RandomKeyPair() (*Scalar, *Point, error) { - x, err := b.RandomScalar() - if err != nil { - return nil, nil, err - } - X, err := b.MultiplyGenerator(x) - if err != nil { - return nil, nil, err - } - return x, X, nil -} - -func (b *baseCurve) Add(a, c *Scalar) (*Scalar, error) { - if a == nil || c == nil { - return nil, fmt.Errorf("nil scalar operand") - } - res := cgobinding.ScalarAddModOrder(b.cCurve, a.Bytes, c.Bytes) - if len(res) == 0 { - return nil, fmt.Errorf("scalar modular addition failed") - } - return &Scalar{Bytes: res}, nil -} - -func (b *baseCurve) String() string { - switch cgobinding.ECurveGetCurveCode(b.cCurve) { - case secp256k1Code: - return "secp256k1" - case p256Code: - return "P-256" - case ed25519Code: - return "Ed25519" - default: - return fmt.Sprintf("unknown curve (%d)", cgobinding.ECurveGetCurveCode(b.cCurve)) - } -} - -// Code returns the internal numeric identifier for the curve (OpenSSL NID). -// This is the standard medium for crossing the Go↔native boundary for curves. -// Bindings will resolve a native handle on-demand using this code. -func Code(c Curve) int { return cgobinding.ECurveGetCurveCode(nativeRef(c)) } - -// NewFromCode constructs a Curve instance for the provided internal code. -// It resolves a native curve under the hood and wraps it in the appropriate -// concrete curve type when recognised. -func NewFromCode(code int) (Curve, error) { - bc, err := newBaseCurve(code) - if err != nil { - return nil, err - } - // Pick concrete wrapper for friendly String() output - switch code { - case secp256k1Code: - return &secp256k1Curve{baseCurve: bc}, nil - case p256Code: - return &p256Curve{baseCurve: bc}, nil - case ed25519Code: - return &ed25519Curve{baseCurve: bc}, nil - default: - return bc, nil - } -} - -// nativeRef exposes the underlying native curve handle for a given Curve. -// -// It is unexported so it remains invisible to application code. Internal -// packages access it via go:linkname (see api/internal/curveref) to bridge -// between the high-level Go types and the low-level C++ pointers. -func nativeRef(c Curve) cgobinding.ECurveRef { - switch v := c.(type) { - case *secp256k1Curve: - return v.cCurve - case *p256Curve: - return v.cCurve - case *ed25519Curve: - return v.cCurve - default: - panic(fmt.Sprintf("unsupported curve type %T", c)) - } -} - -// ========================= concrete curve types ========================= - -// secp256k1Curve implements Curve for the secp256k1 group used by Bitcoin. -type secp256k1Curve struct{ *baseCurve } - -// NewSecp256k1 returns a new instance of the secp256k1 curve. -func NewSecp256k1() (Curve, error) { - bc, err := newBaseCurve(secp256k1Code) - if err != nil { - return nil, err - } - return &secp256k1Curve{baseCurve: bc}, nil -} - -// p256Curve implements Curve for the NIST P-256 curve. -type p256Curve struct{ *baseCurve } - -// NewP256 returns a new instance of the P-256 curve. -func NewP256() (Curve, error) { - bc, err := newBaseCurve(p256Code) - if err != nil { - return nil, err - } - return &p256Curve{baseCurve: bc}, nil -} - -// ed25519Curve implements Curve for the Ed25519 Edwards curve. -type ed25519Curve struct{ *baseCurve } - -// NewEd25519 returns a new instance of the Ed25519 curve. -func NewEd25519() (Curve, error) { - bc, err := newBaseCurve(ed25519Code) - if err != nil { - return nil, err - } - return &ed25519Curve{baseCurve: bc}, nil -} - -// Compile-time guarantees that each concrete type satisfies the interface. -var _ Curve = (*secp256k1Curve)(nil) -var _ Curve = (*p256Curve)(nil) -var _ Curve = (*ed25519Curve)(nil) diff --git a/demos-go/cb-mpc-go/api/curve/curve_test.go b/demos-go/cb-mpc-go/api/curve/curve_test.go deleted file mode 100644 index 65d789d3..00000000 --- a/demos-go/cb-mpc-go/api/curve/curve_test.go +++ /dev/null @@ -1,198 +0,0 @@ -package curve - -import ( - "fmt" - "math/big" - "testing" -) - -// TestSupportedCurves instantiates each supported curve and performs -// a few basic sanity-checks to demonstrate the public API. -func TestSupportedCurves(t *testing.T) { - cases := []struct { - name string - newFn func() (Curve, error) - expect string - }{ - {"secp256k1", NewSecp256k1, "secp256k1"}, - {"P-256", NewP256, "P-256"}, - {"Ed25519", NewEd25519, "Ed25519"}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - curve, err := tc.newFn() - if err != nil { - t.Fatalf("failed to create %s: %v", tc.name, err) - } - defer curve.Free() - - if got := curve.String(); got != tc.expect { - t.Fatalf("String() = %q, want %q", got, tc.expect) - } - - order := curve.Order() - if len(order) == 0 { - t.Fatalf("order returned empty slice for %s", tc.name) - } - - gen := curve.Generator() - defer gen.Free() - - if gen.IsZero() { - t.Fatalf("generator reported as zero for %s", tc.name) - } - }) - } -} - -func TestRandomScalar(t *testing.T) { - curve, err := NewSecp256k1() - if err != nil { - t.Fatalf("init curve: %v", err) - } - defer curve.Free() - - scalar, err := curve.RandomScalar() - if err != nil { - t.Fatalf("RandomScalar failed: %v", err) - } - - if len(scalar.Bytes) != len(curve.Order()) { - t.Fatalf("scalar byte length mismatch: got %d want %d", len(scalar.Bytes), len(curve.Order())) - } - - order := new(big.Int).SetBytes(curve.Order()) - val := new(big.Int).SetBytes(scalar.Bytes) - if val.Sign() == 0 || val.Cmp(order) >= 0 { - t.Fatalf("scalar out of valid range") - } - - // Test MultiplyGenerator - point, err := curve.MultiplyGenerator(scalar) - if err != nil { - t.Fatalf("MultiplyGenerator failed: %v", err) - } - defer point.Free() - - gen := curve.Generator() - defer gen.Free() - expected, err := gen.Multiply(scalar) - if err != nil { - t.Fatalf("Multiply for expected failed: %v", err) - } - defer expected.Free() - - if !point.Equals(expected) { - t.Fatalf("MultiplyGenerator result mismatch") - } - - // Modular addition via Curve.Add - scalar2, err := curve.RandomScalar() - if err != nil { - t.Fatalf("RandomScalar (second) failed: %v", err) - } - - sum, err := curve.Add(scalar, scalar2) - if err != nil { - t.Fatalf("Curve.Add failed: %v", err) - } - if len(sum.Bytes) == 0 { - t.Fatalf("Curve.Add returned empty result") - } - - sum2, err := curve.Add(scalar2, scalar) - if err != nil { - t.Fatalf("Curve.Add failed: %v", err) - } - if len(sum2.Bytes) == 0 { - t.Fatalf("Curve.Add returned empty result") - } - // if !sum.Equals(sum2) { - // t.Fatalf("Curve.Add returned different result") - // } - - // Check (scalar + scalar2) mod order matches big.Int computation - // v1 := new(big.Int).SetBytes(scalar.Bytes) - // v2 := new(big.Int).SetBytes(scalar2.Bytes) - // order = new(big.Int).SetBytes(curve.Order()) - // expectedSum := new(big.Int).Add(v1, v2) - // expectedSum.Mod(expectedSum, order) - // gotSum := new(big.Int).SetBytes(sum.Bytes) - // if expectedSum.Cmp(gotSum) != 0 { - // t.Fatalf("Curve.Add incorrect modulo addition") - // } -} - -// TestCodeNewFromCodeRoundTrip verifies c -> Code(c) -> NewFromCode -> c' works for all supported curves. -func TestCodeNewFromCodeRoundTrip(t *testing.T) { - cases := []struct { - name string - newFn func() (Curve, error) - }{ - {"secp256k1", NewSecp256k1}, - {"P-256", NewP256}, - {"Ed25519", NewEd25519}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - c, err := tc.newFn() - if err != nil { - t.Fatalf("failed to construct curve %s: %v", tc.name, err) - } - defer c.Free() - - code := Code(c) - - c2, err := NewFromCode(code) - if err != nil { - t.Fatalf("NewFromCode failed for %s (code=%d): %v", tc.name, code, err) - } - defer c2.Free() - - // Both should identify as the same curve via String - if c.String() != c2.String() { - t.Fatalf("round-trip String mismatch: %q vs %q", c.String(), c2.String()) - } - - // Orders should match byte-for-byte - o1 := c.Order() - o2 := c2.Order() - if len(o1) == 0 || len(o2) == 0 || len(o1) != len(o2) { - t.Fatalf("order length mismatch or empty: len1=%d len2=%d", len(o1), len(o2)) - } - for i := range o1 { - if o1[i] != o2[i] { - t.Fatalf("order bytes differ at %d", i) - } - } - - // Generator equality check - g1 := c.Generator() - defer g1.Free() - g2 := c2.Generator() - defer g2.Free() - if !g1.Equals(g2) { - t.Fatalf("generators not equal after round-trip for %s", tc.name) - } - }) - } -} - -// TestNewFromCodeInvalid ensures invalid/unsupported code is handled. -// Depending on native behavior, NewFromCode may return an error, or a curve -// instance that reports as "unknown curve (code)". We accept either. -func TestNewFromCodeInvalid(t *testing.T) { - invalidCodes := []int{-1, 0, 999999, 123456} - for _, code := range invalidCodes { - t.Run(fmt.Sprintf("code_%d", code), func(t *testing.T) { - c, err := NewFromCode(code) - if err != nil { - return // acceptable: invalid code produced an error - } - // Otherwise, ensure we can at least free the handle without crashing. - c.Free() - }) - } -} diff --git a/demos-go/cb-mpc-go/api/curve/doc.go b/demos-go/cb-mpc-go/api/curve/doc.go deleted file mode 100644 index f156abb3..00000000 --- a/demos-go/cb-mpc-go/api/curve/doc.go +++ /dev/null @@ -1,41 +0,0 @@ -// Package curve provides idiomatic Go bindings for the elliptic-curve primitives -// implemented in the C++ `cbmpc` library. -// -// The package wraps two native C++ handle types: -// 1. `ecurve_t` – an elliptic-curve definition (currently only secp256k1) -// 2. `ecc_point_t` – a point that lives on a particular curve -// -// Because the underlying objects are allocated on the C++ heap, every value that -// is created from this package owns native resources. ALWAYS call the `Free` or -// `Close` method when you are done with a value, or use `defer` immediately after -// creation. Failing to do so will leak memory. -// -// # Quick start -// -// The snippet below shows the most common workflow – creating a curve, deriving -// the generator point, performing a scalar multiplication and reading back the -// affine coordinates: -// -// cur, err := curve.NewSecp256k1() -// if err != nil { -// log.Fatalf("initialising curve: %v", err) -// } -// defer cur.Free() -// -// G := cur.Generator() -// defer gen.Free() -// -// // Multiply the generator by a 32-byte scalar. -// scalar := curve.RandomScalar(cur) -// p := G.Mul(scalar) -// -// Features -// -// - Creation of named curves (secp256k1 for now) -// - Arithmetic on immutable `Point` values: Add, Sub, Neg, Mul -// - Constant-time, allocation-free serialization (compressed & uncompressed) -// - Helper utilities for random scalar / point generation (in tests) -// -// All heavy arithmetic is executed in constant time inside C++, guaranteeing that -// the Go bindings themselves never become a side-channel. -package curve diff --git a/demos-go/cb-mpc-go/api/curve/point.go b/demos-go/cb-mpc-go/api/curve/point.go deleted file mode 100644 index 752829f6..00000000 --- a/demos-go/cb-mpc-go/api/curve/point.go +++ /dev/null @@ -1,92 +0,0 @@ -package curve - -import ( - "fmt" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// Point represents a point on an elliptic curve -type Point struct { - cPoint cgobinding.ECCPointRef -} - -// NewPointFromBytes creates a new point from serialized bytes -func NewPointFromBytes(pointBytes []byte) (*Point, error) { - if len(pointBytes) == 0 { - return nil, fmt.Errorf("empty point bytes") - } - cPoint, err := cgobinding.ECCPointFromBytes(pointBytes) - if err != nil { - return nil, err - } - return &Point{cPoint: cPoint}, nil -} - -// Free releases the memory associated with the point -func (p *Point) Free() { - p.cPoint.Free() -} - -// Multiply multiplies the point by a scalar -func (p *Point) Multiply(scalar *Scalar) (*Point, error) { - if scalar.Bytes == nil { - return nil, fmt.Errorf("nil scalar") - } - cPoint, err := cgobinding.ECCPointMultiply(p.cPoint, scalar.Bytes) - if err != nil { - return nil, err - } - return &Point{cPoint: cPoint}, nil -} - -// Add adds two points together -func (p *Point) Add(other *Point) *Point { - cPoint := cgobinding.ECCPointAdd(p.cPoint, other.cPoint) - return &Point{cPoint: cPoint} -} - -// Subtract subtracts one point from another -func (p *Point) Subtract(other *Point) *Point { - cPoint := cgobinding.ECCPointSubtract(p.cPoint, other.cPoint) - return &Point{cPoint: cPoint} -} - -// GetX returns the x coordinate of the point as bytes -func (p *Point) GetX() []byte { - return cgobinding.ECCPointGetX(p.cPoint) -} - -// GetY returns the y coordinate of the point as bytes -func (p *Point) GetY() []byte { - return cgobinding.ECCPointGetY(p.cPoint) -} - -// IsZero checks if the point is the point at infinity (zero point) -func (p *Point) IsZero() bool { - return cgobinding.ECCPointIsZero(p.cPoint) -} - -// Equals checks if two points are equal -func (p *Point) Equals(other *Point) bool { - return cgobinding.ECCPointEquals(p.cPoint, other.cPoint) -} - -// String returns a string representation of the point -func (p *Point) String() string { - if p.IsZero() { - return "Point(∞)" - } - x := p.GetX() - y := p.GetY() - return fmt.Sprintf("Point(x: %x, y: %x)", x, y) -} - -// Bytes returns the canonical serialization of the point as produced by the -// underlying native library (SEC1 uncompressed format). -func (p *Point) Bytes() []byte { - if p == nil { - return nil - } - return cgobinding.ECCPointToBytes(p.cPoint) -} diff --git a/demos-go/cb-mpc-go/api/curve/scalar.go b/demos-go/cb-mpc-go/api/curve/scalar.go deleted file mode 100644 index 3caec5c3..00000000 --- a/demos-go/cb-mpc-go/api/curve/scalar.go +++ /dev/null @@ -1,60 +0,0 @@ -package curve - -import ( - "bytes" - "fmt" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// Scalar represents a field element (mod the curve order). -// -// Bytes is a fixed-length, big-endian encoding with the same byte length as -// the curve order. -// -// For now the type only supports random generation, but it lays the -// foundation for future arithmetic helpers. -type Scalar struct { - // Bytes holds the big-endian representation of the scalar. The length of - // the slice matches the order of the underlying curve (but the scalar - // itself no longer embeds that information). - Bytes []byte -} - -// NewScalarFromInt64 creates a new Scalar from an int64 value. -// The int64 value is converted to a big number using the native C++ layer's -// set_int64 function to ensure consistent representation. -func NewScalarFromInt64(value int64) *Scalar { - bytes := cgobinding.ScalarFromInt64(value) - return &Scalar{Bytes: bytes} -} - -// Add returns s + other as a new Scalar. The addition is performed by the -// native C++ layer (bn_t addition) to leverage its constant-time -// implementation and to stay consistent with the rest of the library. -func (s *Scalar) Add(other *Scalar) (*Scalar, error) { - if s == nil || other == nil { - return nil, fmt.Errorf("nil scalar operand") - } - res := cgobinding.ScalarAdd(s.Bytes, other.Bytes) - if len(res) == 0 { - return nil, fmt.Errorf("scalar addition failed") - } - return &Scalar{Bytes: res}, nil -} - -// Equal returns true if s and other represent the same scalar value. -// Returns false if either scalar is nil. -func (s *Scalar) Equal(other *Scalar) bool { - if s == nil || other == nil { - return false - } - return bytes.Equal(s.Bytes, other.Bytes) -} - -func (s *Scalar) String() string { - if s == nil { - return "" - } - return fmt.Sprintf("Scalar(%x)", s.Bytes) -} diff --git a/demos-go/cb-mpc-go/api/curve/scalar_test.go b/demos-go/cb-mpc-go/api/curve/scalar_test.go deleted file mode 100644 index 6ca6aeb4..00000000 --- a/demos-go/cb-mpc-go/api/curve/scalar_test.go +++ /dev/null @@ -1,291 +0,0 @@ -package curve - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestScalarEqual(t *testing.T) { - curve, err := NewSecp256k1() - require.NoError(t, err) - defer curve.Free() - - t.Run("equal_scalars", func(t *testing.T) { - scalar1, err := curve.RandomScalar() - require.NoError(t, err) - - // Create a copy with the same bytes - scalar2 := &Scalar{Bytes: make([]byte, len(scalar1.Bytes))} - copy(scalar2.Bytes, scalar1.Bytes) - - assert.True(t, scalar1.Equal(scalar2), "scalars with same bytes should be equal") - assert.True(t, scalar2.Equal(scalar1), "equality should be symmetric") - }) - - t.Run("different_scalars", func(t *testing.T) { - scalar1, err := curve.RandomScalar() - require.NoError(t, err) - - scalar2, err := curve.RandomScalar() - require.NoError(t, err) - - // Very unlikely that two random scalars are equal - assert.False(t, scalar1.Equal(scalar2), "different random scalars should not be equal") - }) - - t.Run("nil_scalars", func(t *testing.T) { - scalar, err := curve.RandomScalar() - require.NoError(t, err) - - var nilScalar *Scalar - - assert.False(t, scalar.Equal(nilScalar), "scalar should not equal nil") - assert.False(t, nilScalar.Equal(scalar), "nil should not equal scalar") - assert.False(t, nilScalar.Equal(nilScalar), "nil should not equal nil") - }) - - t.Run("self_equality", func(t *testing.T) { - scalar, err := curve.RandomScalar() - require.NoError(t, err) - - assert.True(t, scalar.Equal(scalar), "scalar should equal itself") - }) -} - -func TestScalarAddCommutativity(t *testing.T) { - curve, err := NewSecp256k1() - require.NoError(t, err) - defer curve.Free() - - t.Run("a_plus_b_equals_b_plus_a", func(t *testing.T) { - a, err := curve.RandomScalar() - require.NoError(t, err) - - b, err := curve.RandomScalar() - require.NoError(t, err) - - // Compute a + b - aPlusB, err := a.Add(b) - require.NoError(t, err) - - // Compute b + a - bPlusA, err := b.Add(a) - require.NoError(t, err) - - // They should be equal - assert.True(t, aPlusB.Equal(bPlusA), "a+b should equal b+a (commutativity)") - }) - - t.Run("multiple_random_pairs", func(t *testing.T) { - // Test commutativity with multiple random pairs to increase confidence - for i := 0; i < 10; i++ { - a, err := curve.RandomScalar() - require.NoError(t, err) - - b, err := curve.RandomScalar() - require.NoError(t, err) - - aPlusB, err := a.Add(b) - require.NoError(t, err) - - bPlusA, err := b.Add(a) - require.NoError(t, err) - - assert.True(t, aPlusB.Equal(bPlusA), "commutativity failed for pair %d", i) - } - }) -} - -func TestScalarAddErrorHandling(t *testing.T) { - curve, err := NewSecp256k1() - require.NoError(t, err) - defer curve.Free() - - t.Run("nil_operands", func(t *testing.T) { - scalar, err := curve.RandomScalar() - require.NoError(t, err) - - var nilScalar *Scalar - - // Test scalar + nil - result, err := scalar.Add(nilScalar) - assert.Error(t, err) - assert.Nil(t, result) - assert.Contains(t, err.Error(), "nil scalar operand") - - // Test nil + scalar - result, err = nilScalar.Add(scalar) - assert.Error(t, err) - assert.Nil(t, result) - assert.Contains(t, err.Error(), "nil scalar operand") - - // Test nil + nil - result, err = nilScalar.Add(nilScalar) - assert.Error(t, err) - assert.Nil(t, result) - assert.Contains(t, err.Error(), "nil scalar operand") - }) -} - -func TestScalarString(t *testing.T) { - curve, err := NewSecp256k1() - require.NoError(t, err) - defer curve.Free() - - t.Run("valid_scalar", func(t *testing.T) { - scalar, err := curve.RandomScalar() - require.NoError(t, err) - - str := scalar.String() - assert.Contains(t, str, "Scalar(") - assert.Contains(t, str, ")") - // Should contain hex representation of the bytes - assert.Greater(t, len(str), len("Scalar()")) - }) - - t.Run("nil_scalar", func(t *testing.T) { - var nilScalar *Scalar - str := nilScalar.String() - assert.Equal(t, "", str) - }) -} - -func TestScalarAddAssociativity(t *testing.T) { - curve, err := NewSecp256k1() - require.NoError(t, err) - defer curve.Free() - - t.Run("a_plus_b_plus_c_associativity", func(t *testing.T) { - a, err := curve.RandomScalar() - require.NoError(t, err) - - b, err := curve.RandomScalar() - require.NoError(t, err) - - c, err := curve.RandomScalar() - require.NoError(t, err) - - // Compute (a + b) + c - aPlusB, err := a.Add(b) - require.NoError(t, err) - - aPlusBPlusC, err := aPlusB.Add(c) - require.NoError(t, err) - - // Compute a + (b + c) - bPlusC, err := b.Add(c) - require.NoError(t, err) - - aPlusBPlusC2, err := a.Add(bPlusC) - require.NoError(t, err) - - // They should be equal - assert.True(t, aPlusBPlusC.Equal(aPlusBPlusC2), "(a+b)+c should equal a+(b+c) (associativity)") - }) -} - -func TestScalarBytesConsistency(t *testing.T) { - curve, err := NewSecp256k1() - require.NoError(t, err) - defer curve.Free() - - t.Run("bytes_length_consistency", func(t *testing.T) { - orderBytes := curve.Order() - - for i := 0; i < 5; i++ { - scalar, err := curve.RandomScalar() - require.NoError(t, err) - - assert.Equal(t, len(orderBytes), len(scalar.Bytes), - "scalar byte length should match curve order length") - } - }) - - t.Run("equal_scalars_have_equal_bytes", func(t *testing.T) { - scalar1, err := curve.RandomScalar() - require.NoError(t, err) - - // Create scalar2 with same bytes - scalar2 := &Scalar{Bytes: make([]byte, len(scalar1.Bytes))} - copy(scalar2.Bytes, scalar1.Bytes) - - assert.True(t, scalar1.Equal(scalar2)) - assert.True(t, bytes.Equal(scalar1.Bytes, scalar2.Bytes)) - }) -} - -func TestNewScalarFromInt64(t *testing.T) { - t.Run("positive_values", func(t *testing.T) { - testCases := []int64{1, 42, 100, 1000, 65536} - - for _, value := range testCases { - scalar := NewScalarFromInt64(value) - assert.NotNil(t, scalar, "scalar should not be nil for value %d", value) - assert.NotNil(t, scalar.Bytes, "scalar bytes should not be nil for value %d", value) - assert.Greater(t, len(scalar.Bytes), 0, "scalar bytes should not be empty for value %d", value) - } - }) - - t.Run("negative_values", func(t *testing.T) { - testCases := []int64{-1, -42, -100, -1000} - - for _, value := range testCases { - scalar := NewScalarFromInt64(value) - assert.NotNil(t, scalar, "scalar should not be nil for value %d", value) - assert.NotNil(t, scalar.Bytes, "scalar bytes should not be nil for value %d", value) - assert.Greater(t, len(scalar.Bytes), 0, "scalar bytes should not be empty for value %d", value) - } - }) - - t.Run("zero_value", func(t *testing.T) { - scalar := NewScalarFromInt64(0) - assert.NotNil(t, scalar) - assert.Equal(t, 0, len(scalar.Bytes)) - }) - - t.Run("equality_consistency", func(t *testing.T) { - // Same values should produce equal scalars - scalar1 := NewScalarFromInt64(42) - scalar2 := NewScalarFromInt64(42) - - assert.True(t, scalar1.Equal(scalar2), "scalars from same int64 value should be equal") - assert.True(t, bytes.Equal(scalar1.Bytes, scalar2.Bytes), "bytes should be identical for same int64 value") - }) - - t.Run("different_values_not_equal", func(t *testing.T) { - scalar1 := NewScalarFromInt64(42) - scalar2 := NewScalarFromInt64(43) - - assert.False(t, scalar1.Equal(scalar2), "scalars from different int64 values should not be equal") - }) - - t.Run("addition_with_int64_scalars", func(t *testing.T) { - // Test that scalars created from int64 work with addition - scalar1 := NewScalarFromInt64(10) - scalar2 := NewScalarFromInt64(5) - - sum, err := scalar1.Add(scalar2) - require.NoError(t, err) - assert.NotNil(t, sum) - - // The sum should be different from both operands - assert.False(t, sum.Equal(scalar1)) - assert.False(t, sum.Equal(scalar2)) - - scalar3 := NewScalarFromInt64(15) - assert.True(t, scalar3.Equal(sum)) - }) - - t.Run("string_representation", func(t *testing.T) { - scalar := NewScalarFromInt64(123) - str := scalar.String() - - assert.Contains(t, str, "Scalar(") - assert.Contains(t, str, ")") - assert.Contains(t, str, "7b") - assert.Greater(t, len(str), len("Scalar()")) - }) -} diff --git a/demos-go/cb-mpc-go/api/mpc/access_structure.go b/demos-go/cb-mpc-go/api/mpc/access_structure.go deleted file mode 100644 index 5d03e1f7..00000000 --- a/demos-go/cb-mpc-go/api/mpc/access_structure.go +++ /dev/null @@ -1,266 +0,0 @@ -// Package mpc provides data structures used by multiple MPC protocols. -// This file defines AccessNode – a minimal, logic-free representation -// of attribute-based access structures that will later be processed by -// native C++ code. -package mpc - -import ( - "fmt" - "strings" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// NodeKind tells which logical operator a node represents. -// -// The zero value corresponds to KindLeaf so that freshly allocated structs -// default to the most restrictive (leaf) kind. -type NodeKind uint8 - -const ( - // KindLeaf marks a terminal node with no children. - KindLeaf NodeKind = iota - // KindAnd represents logical conjunction – all children must be satisfied. - KindAnd - // KindOr represents logical disjunction – any child suffices. - KindOr - // KindThreshold represents an out-of-n condition where at least K of - // the children must be satisfied. - KindThreshold -) - -// AccessNode is a purely-data representation of a node in an -// attribute-based access structure. -// -// Invariants (not enforced by code): -// - The unique root has Name == "" and Parent == nil. -// - Non-leaf nodes have at least one child. -// - KindThreshold nodes use field K (0 < K ≤ len(Children)). -// - KindLeaf nodes ignore field K and must have len(Children) == 0. -// -// All fields are exported so that other packages (or the C++ bridge) can walk -// and mutate the tree freely. -// -// NOTE: This struct purposefully embeds _no_ validation logic. Consumers are -// expected to perform their own checks or rely on downstream C++ code. -type AccessNode struct { - Name string // human-readable identifier (root == "") - Kind NodeKind // AND / OR / THRESHOLD / LEAF - Parent *AccessNode // nil for root - Children []*AccessNode // nil or empty slice for leaf - K int // threshold value if Kind == KindThreshold -} - -// ================= AccessStructure wrapper ========================= - -// AccessStructure bundles an access-tree (Root) together with the elliptic -// curve on which the underlying cryptographic secret-sharing operates. -// -// The type is a thin container with helper utilities to bridge to the native -// C++ implementation. -type AccessStructure struct { - Root *AccessNode // Root of the access structure tree (must not be nil) - Curve curve.Curve // Elliptic curve used for commitments / shares -} - -// String returns a multi-line representation that starts with the curve name -// and then embeds the pretty-printed access-tree. -func (as *AccessStructure) String() string { - if as == nil { - return "" - } - var sb strings.Builder - if as.Curve != nil { - sb.WriteString(fmt.Sprintf("Curve: %s\n", as.Curve)) - } else { - sb.WriteString("Curve: \n") - } - if as.Root != nil { - sb.WriteString(as.Root.String()) - } else { - sb.WriteString("\n") - } - return sb.String() -} - -// toCryptoAC converts the AccessStructure into the native secret-sharing -// representation expected by the MPC engine and returns an opaque handle. -// -// The method panics if the AccessStructure is malformed (nil fields, unknown -// node kinds, …). Such errors typically indicate a misuse by calling code. -func (as *AccessStructure) toCryptoAC() cgobinding.C_AcPtr { - if as == nil { - panic("AccessStructure.toCryptoAC: receiver is nil") - } - if as.Root == nil { - panic("AccessStructure.toCryptoAC: Root is nil") - } - if as.Curve == nil { - panic("AccessStructure.toCryptoAC: Curve is nil") - } - - // Local helper mapping Go enum to C enum (identical to the previous - // implementation that lived on AccessNode). - kindToC := func(k NodeKind) cgobinding.NodeType { - switch k { - case KindLeaf: - return cgobinding.NodeType_LEAF - case KindAnd: - return cgobinding.NodeType_AND - case KindOr: - return cgobinding.NodeType_OR - case KindThreshold: - return cgobinding.NodeType_THRESHOLD - default: - panic(fmt.Sprintf("AccessStructure.toCryptoAC: unknown NodeKind %d", k)) - } - } - - // Recursively clone the Go tree into the C representation. - var build func(n *AccessNode) cgobinding.C_NodePtr - build = func(n *AccessNode) cgobinding.C_NodePtr { - cNode := cgobinding.NewNode(kindToC(n.Kind), n.Name, n.K) - for _, child := range n.Children { - if child == nil { - continue - } - childPtr := build(child) - cgobinding.AddChild(cNode, childPtr) - } - return cNode - } - - rootPtr := build(as.Root) - - code := curve.Code(as.Curve) - curveRef, err := cgobinding.ECurveFind(code) - if err != nil { - panic(fmt.Sprintf("AccessStructure.toCryptoAC: unsupported curve code %d", code)) - } - - ac := cgobinding.NewAccessStructure(rootPtr, curveRef) - return ac -} - -// Leaf returns a pointer to a leaf node. -func Leaf(name string) *AccessNode { - return &AccessNode{Name: name, Kind: KindLeaf} -} - -// And creates a logical AND node and wires the Parent pointers of its children. -func And(name string, kids ...*AccessNode) *AccessNode { - n := &AccessNode{Name: name, Kind: KindAnd, Children: kids} - for _, c := range kids { - if c != nil { - c.Parent = n - } - } - return n -} - -// Or creates a logical OR node and wires the Parent pointers of its children. -func Or(name string, kids ...*AccessNode) *AccessNode { - n := &AccessNode{Name: name, Kind: KindOr, Children: kids} - for _, c := range kids { - if c != nil { - c.Parent = n - } - } - return n -} - -// Threshold creates a threshold node (at least k of n) and wires the Parent pointers. -// It performs no validation of k against the number of children – that is deferred -// to the consuming C++ logic. -func Threshold(name string, k int, kids ...*AccessNode) *AccessNode { - n := &AccessNode{Name: name, Kind: KindThreshold, K: k, Children: kids} - for _, c := range kids { - if c != nil { - c.Parent = n - } - } - return n -} - -/* -Example construction (root name must be ""): - - root := And("", - Or("role", - Leaf("role:Admin"), - Leaf("dept:HR"), - ), - Threshold("sig", 2, - Leaf("sig:A"), - Leaf("sig:B"), - Leaf("sig:C"), - ), - ) - -`root` is now ready to be translated to C++. -*/ - -// String returns a human-readable, multi-line representation of the subtree -// rooted at the receiver. It recursively walks the tree and formats each node -// with two-space indentation per level. -// -// Example (matching the construction snippet below): -// -// AND -// OR role -// LEAF role:Admin -// LEAF dept:HR -// THRESHOLD sig (2/3) -// LEAF sig:A -// LEAF sig:B -// LEAF sig:C -// -// The function never returns an error – malformed trees are printed as-is. -func (n *AccessNode) String() string { - if n == nil { - return "" - } - var sb strings.Builder - n.format(&sb, 0) - return sb.String() -} - -// format writes a single node (with indentation) followed by all children. -func (n *AccessNode) format(sb *strings.Builder, level int) { - indent := strings.Repeat(" ", level) - sb.WriteString(indent) - sb.WriteString(n.Kind.String()) - if n.Name != "" { - sb.WriteByte(' ') - sb.WriteString(n.Name) - } - if n.Kind == KindThreshold { - sb.WriteString(fmt.Sprintf(" (%d/%d)", n.K, len(n.Children))) - } - sb.WriteByte('\n') - for _, child := range n.Children { - if child != nil { - child.format(sb, level+1) - } else { - sb.WriteString(strings.Repeat(" ", level+1)) - sb.WriteString("\n") - } - } -} - -// String returns the symbolic name of the NodeKind (LEAF, AND, OR, THRESHOLD). -func (k NodeKind) String() string { - switch k { - case KindLeaf: - return "LEAF" - case KindAnd: - return "AND" - case KindOr: - return "OR" - case KindThreshold: - return "THRESHOLD" - default: - return fmt.Sprintf("NodeKind(%d)", k) - } -} diff --git a/demos-go/cb-mpc-go/api/mpc/agree_random.go b/demos-go/cb-mpc-go/api/mpc/agree_random.go deleted file mode 100644 index a197cba3..00000000 --- a/demos-go/cb-mpc-go/api/mpc/agree_random.go +++ /dev/null @@ -1,35 +0,0 @@ -package mpc - -import ( - "fmt" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// AgreeRandomRequest represents the input parameters for agree random protocol -type AgreeRandomRequest struct { - BitLen int // Number of bits for the random value -} - -// AgreeRandomResponse represents the output of agree random protocol -type AgreeRandomResponse struct { - RandomValue []byte // The agreed-upon random value -} - -// AgreeRandom executes the agree random protocol between two parties. -// Both parties will agree on the same random value of the specified bit length. -func AgreeRandom(job2p *Job2P, req *AgreeRandomRequest) (*AgreeRandomResponse, error) { - if req.BitLen <= 0 { - return nil, fmt.Errorf("bit length must be positive, got %d", req.BitLen) - } - - // Execute the agree random protocol using the provided Job2P - randomValue, err := cgobinding.AgreeRandom(job2p.cgo(), req.BitLen) - if err != nil { - return nil, fmt.Errorf("agree random protocol failed: %v", err) - } - - return &AgreeRandomResponse{ - RandomValue: randomValue, - }, nil -} diff --git a/demos-go/cb-mpc-go/api/mpc/agreer_andom_test.go b/demos-go/cb-mpc-go/api/mpc/agreer_andom_test.go deleted file mode 100644 index d8c73d4d..00000000 --- a/demos-go/cb-mpc-go/api/mpc/agreer_andom_test.go +++ /dev/null @@ -1,208 +0,0 @@ -package mpc - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mocknet" -) - -// AgreeRandomWithMockNet is a test-only convenience wrapper that spins up two parties -// connected via the in-memory mock network and runs the AgreeRandom protocol. -// This helper lives in *_test.go files so it is not exposed to library users. -func AgreeRandomWithMockNet(nParties int, bitLen int) ([]*AgreeRandomResponse, error) { - if nParties != 2 { - return nil, fmt.Errorf("agree random currently only supports 2 parties, got %d", nParties) - } - - if bitLen <= 0 { - return nil, fmt.Errorf("bit length must be positive, got %d", bitLen) - } - - // Create mock network messengers - messengers := mocknet.NewMockNetwork(nParties) - - partyNames := []string{"party_0", "party_1"} - - responses := make([]*AgreeRandomResponse, nParties) - errChan := make(chan error, nParties) - respChan := make(chan struct { - index int - resp *AgreeRandomResponse - }, nParties) - - for i := 0; i < nParties; i++ { - go func(partyIndex int) { - j, err := NewJob2P(messengers[partyIndex], partyIndex, partyNames) - if err != nil { - errChan <- fmt.Errorf("party %d failed to create Job2P: %v", partyIndex, err) - return - } - defer j.Free() - - req := &AgreeRandomRequest{BitLen: bitLen} - resp, err := AgreeRandom(j, req) - if err != nil { - errChan <- fmt.Errorf("party %d failed: %v", partyIndex, err) - return - } - - respChan <- struct { - index int - resp *AgreeRandomResponse - }{partyIndex, resp} - }(i) - } - - for i := 0; i < nParties; i++ { - select { - case err := <-errChan: - return nil, err - case result := <-respChan: - responses[result.index] = result.resp - } - } - - return responses, nil -} - -func TestAgreeRandomWithMockNet(t *testing.T) { - tests := []struct { - name string - nParties int - bitLen int - wantErr bool - }{ - { - name: "valid 2-party 128-bit", - nParties: 2, - bitLen: 128, - wantErr: false, - }, - { - name: "valid 2-party 10-bit", - nParties: 2, - bitLen: 10, - wantErr: false, - }, - { - name: "valid 2-party 256-bit", - nParties: 2, - bitLen: 256, - wantErr: false, - }, - { - name: "invalid party count", - nParties: 3, - bitLen: 128, - wantErr: true, - }, - { - name: "invalid bit length zero", - nParties: 2, - bitLen: 0, - wantErr: true, - }, - { - name: "invalid bit length negative", - nParties: 2, - bitLen: -10, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - responses, err := AgreeRandomWithMockNet(tt.nParties, tt.bitLen) - - if tt.wantErr { - assert.Error(t, err) - assert.Nil(t, responses) - return - } - - require.NoError(t, err) - require.Len(t, responses, tt.nParties) - - // Verify all parties got the same random value - firstValue := responses[0].RandomValue - for i, resp := range responses { - assert.NotNil(t, resp, "response %d should not be nil", i) - assert.NotNil(t, resp.RandomValue, "random value %d should not be nil", i) - assert.Equal(t, firstValue, resp.RandomValue, - "party %d should have same random value as party 0", i) - } - - // Verify the random value has the expected length - expectedBytes := (tt.bitLen + 7) / 8 // Round up to nearest byte - assert.Len(t, firstValue, expectedBytes, - "random value should have %d bytes for %d bits", expectedBytes, tt.bitLen) - }) - } -} - -func TestAgreeRandomRequest_Validation(t *testing.T) { - tests := []struct { - name string - bitLen int - wantErr bool - }{ - {"valid small", 1, false}, - {"valid medium", 128, false}, - {"valid large", 2048, false}, - {"invalid zero", 0, true}, - {"invalid negative", -1, true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := &AgreeRandomRequest{BitLen: tt.bitLen} - - // We can't easily test AgreeRandom without a real messenger, - // so we test through AgreeRandomWithMockNet which does validation - _, err := AgreeRandomWithMockNet(2, req.BitLen) - - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestAgreeRandomResponse_Structure(t *testing.T) { - // Test that the response structure is as expected - responses, err := AgreeRandomWithMockNet(2, 64) - require.NoError(t, err) - require.Len(t, responses, 2) - - for i, resp := range responses { - assert.NotNil(t, resp, "response %d should not be nil", i) - assert.NotNil(t, resp.RandomValue, "random value %d should not be nil", i) - assert.Len(t, resp.RandomValue, 8, "64-bit value should be 8 bytes") - } -} - -func TestAgreeRandom_DeterministicAgreement(t *testing.T) { - // Run the same configuration multiple times to ensure consistency - bitLen := 32 - - for i := 0; i < 5; i++ { - responses, err := AgreeRandomWithMockNet(2, bitLen) - require.NoError(t, err, "iteration %d should succeed", i) - require.Len(t, responses, 2) - - // Both parties should agree each time - assert.Equal(t, responses[0].RandomValue, responses[1].RandomValue, - "iteration %d: parties should agree", i) - - // Values should be the correct length - expectedBytes := (bitLen + 7) / 8 - assert.Len(t, responses[0].RandomValue, expectedBytes, - "iteration %d: wrong length", i) - } -} diff --git a/demos-go/cb-mpc-go/api/mpc/doc.go b/demos-go/cb-mpc-go/api/mpc/doc.go deleted file mode 100644 index 79eaa348..00000000 --- a/demos-go/cb-mpc-go/api/mpc/doc.go +++ /dev/null @@ -1,83 +0,0 @@ -// Package mpc exposes high-level, ergonomic APIs for the multi-party -// computation (MPC) protocols implemented in the CB-MPC library. -// -// Instead of dealing with round messages, state-machines and network plumbing -// you interact with simple, synchronous request/response helpers. Under the -// hood each helper drives the native C++ engine and uses a `transport.Messenger` -// implementation to move data between parties. -// -// Highlights -// -// - Uniform Go API for 2–N-party ECDSA/EdDSA key generation, key refresh, -// signing and more. -// - Pluggable transport layer – run the same code against an in-process -// `mocknet` during unit tests and switch to a production‐grade mTLS -// transport with no changes. -// - First-class test-utilities that spin up realistic local networks in a -// single process. -// -// Quick example (random agreement between two parties): -// -// import "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/mpc" -// -// // Agree on 128 bits of randomness between two parties. -// out, err := mpc.AgreeRandomWithMockNet(2 /* parties */, 128 /* bits */) -// if err != nil { -// log.Fatalf("mpc: %v", err) -// } -// fmt.Printf("Shared random value: %x\n", out[0].Random) -// -// For production you would create a `transport.Messenger` (for example via the -// `mtls` sub-package) and then build a `Job*` value: -// -// messenger, _ := mtls.NewMTLSMessenger(cfg) -// job, _ := mpc.NewJob2P(messenger, selfIndex, []string{"alice", "bob"}) -// resp, err := mpc.AgreeRandom(job, &mpc.AgreeRandomRequest{BitLen: 256}) -// -// Every exported helper returns rich, declarative request and response structs -// making it straightforward to marshal results into JSON or protobuf. -// -// Package mpc exposes a thin, Go-idiomatic wrapper around the core C++ -// Publicly-Verifiable-Encryption (PVE) primitives found in cb-mpc. The wrapper -// is intentionally small – it forwards heavy cryptographic operations to the -// native library while letting Go take care of configuration, concurrency and -// pluggable encryption back-ends. -// -// Architecture -// -// Go (mpc.PVE) ──▶ Cgo shim (internal/cgobinding) ──▶ C++ core (src/cbmpc) -// ▲ ▲ │ -// │ │ │ -// │ per-backend ctx │ stub functions │ -// │ registry │ │ -// │ │ ▼ -// Backend impls ←── thread-local ctx ←──────── ffi_pke_t -// -// A caller supplies a custom encapsulation implementation that satisfies the -// cgobinding.KEM interface (aliased as mpc.KEM). Each implementation is -// registered once and receives an opaque *context pointer*. That pointer is -// shipped down to the C++ code on every call so the correct backend can be -// picked without any global state. -// -// The wrapper therefore supports multiple, independent encryption schemes –­ -// including non-ECIES KEM hybrids – running side-by-side inside the same Go -// process. -// -// # Concurrency Model -// -// The context pointer is stored in a thread-local variable inside the shim -// (`thread_local void* g_ctx`). Every user-visible helper first calls the -// unexported activateCtx() method which sets the variable, thereby guaranteeing -// that concurrent goroutines operating on different PVE handles never clash. -// -// Adding a New Backend -// -// 1. Implement the KEM methods. -// 2. Pass an instance via `NewPVE(Config{KEM: yourImpl})`. -// 3. Use the returned *PVE handle* for Encrypt / Verify / Decrypt. -// -// The backend is registered automatically and its required `rho` size is cached -// once at start-up for zero-alloc fast paths inside the native code. -// -// See the unit tests in pve_test.go for some example backends. -package mpc diff --git a/demos-go/cb-mpc-go/api/mpc/ecdsa_2p.go b/demos-go/cb-mpc-go/api/mpc/ecdsa_2p.go deleted file mode 100644 index 7e476220..00000000 --- a/demos-go/cb-mpc-go/api/mpc/ecdsa_2p.go +++ /dev/null @@ -1,203 +0,0 @@ -package mpc - -import ( - "fmt" - - "crypto/sha256" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/curvemap" -) - -// ECDSA2PCKey is an opaque handle to a 2-party ECDSA key share. -// -// It intentionally does **not** expose the underlying cgobinding type so that -// callers of the API do not need to import the low-level binding package. The -// only supported operation right now is an internal conversion back to the -// cgobinding representation so that the implementation can keep using the -// existing MPC primitives. Additional helper functions (e.g. serialization, -// freeing resources) can be added later. -// -// NOTE: the zero value of ECDSA2PCKey is considered invalid and can be used in -// tests to assert a key share was returned. -type ECDSA2PCKey cgobinding.Mpc_ecdsa2pc_key_ref - -// cgobindingRef converts the wrapper back to the underlying cgobinding type. -// It is unexported because callers outside this package should never rely on -// the cgobinding representation. -func (k ECDSA2PCKey) cgobindingRef() cgobinding.Mpc_ecdsa2pc_key_ref { - return cgobinding.Mpc_ecdsa2pc_key_ref(k) -} - -// RoleIndex returns which party (e.g., 0 or 1) owns this key share. -// It delegates to the underlying cgobinding implementation. -func (k ECDSA2PCKey) RoleIndex() (int, error) { - return cgobinding.KeyRoleIndex(k.cgobindingRef()) -} - -// Q returns the public key point associated with the distributed key. The -// returned Point must be freed by the caller once no longer needed. -func (k ECDSA2PCKey) Q() (*curve.Point, error) { - cPointRef, err := cgobinding.KeyQ(k.cgobindingRef()) - if err != nil { - return nil, err - } - bytes := cgobinding.ECCPointToBytes(cPointRef) - return curve.NewPointFromBytes(bytes) -} - -// Curve returns the elliptic curve associated with this key. -// The caller is responsible for freeing the returned Curve when done. -func (k ECDSA2PCKey) Curve() (curve.Curve, error) { - code, err := cgobinding.KeyCurveCode(k.cgobindingRef()) - if err != nil { - return nil, err - } - - return curvemap.CurveForCode(code) -} - -// XShare returns the scalar share x_i held by this party. -func (k ECDSA2PCKey) XShare() (*curve.Scalar, error) { - bytes, err := cgobinding.KeyXShare(k.cgobindingRef()) - if err != nil { - return nil, err - } - return &curve.Scalar{Bytes: bytes}, nil -} - -// ECDSA2PCKeyGenRequest represents the input parameters for ECDSA 2PC key generation -type ECDSA2PCKeyGenRequest struct { - Curve curve.Curve // Curve to use for key generation -} - -// ECDSA2PCKeyGenResponse represents the output of ECDSA 2PC key generation -type ECDSA2PCKeyGenResponse struct { - KeyShare ECDSA2PCKey // The party's share of the key -} - -// ECDSA2PCKeyGen executes the distributed key generation protocol between two parties. -// Both parties will generate complementary key shares that can be used together for signing. -func ECDSA2PCKeyGen(job2p *Job2P, req *ECDSA2PCKeyGenRequest) (*ECDSA2PCKeyGenResponse, error) { - if req == nil || req.Curve == nil { - return nil, fmt.Errorf("curve must be provided") - } - - // Execute the distributed key generation using the provided Job2P - keyShareRef, err := cgobinding.DistributedKeyGen(job2p.cgo(), curve.Code(req.Curve)) - if err != nil { - return nil, fmt.Errorf("ECDSA 2PC key generation failed: %v", err) - } - - return &ECDSA2PCKeyGenResponse{KeyShare: ECDSA2PCKey(keyShareRef)}, nil -} - -// ECDSA2PCSignRequest represents the input parameters for ECDSA 2PC signing -type ECDSA2PCSignRequest struct { - SessionID []byte // Session identifier for the signing operation - KeyShare ECDSA2PCKey // The party's share of the key - Message []byte // The message to sign -} - -// ECDSA2PCSignResponse represents the output of ECDSA 2PC signing -type ECDSA2PCSignResponse struct { - Signature []byte // The ECDSA signature -} - -// Verify verifies the DER-encoded signature against Q and 32-byte digest using the native crypto backend. -func (r *ECDSA2PCSignResponse) Verify(Q *curve.Point, digest []byte, c curve.Curve) error { - if len(r.Signature) == 0 { - return fmt.Errorf("empty signature") - } - if len(digest) != 32 { - return fmt.Errorf("digest must be 32 bytes, got %d", len(digest)) - } - // Build SEC1 uncompressed encoding: 0x04 || X || Y, with 32-byte padded coordinates - pad32 := func(b []byte) []byte { - if len(b) >= 32 { - if len(b) == 32 { - return b - } - // Trim if somehow longer - return b[len(b)-32:] - } - p := make([]byte, 32) - copy(p[32-len(b):], b) - return p - } - x := pad32(Q.GetX()) - y := pad32(Q.GetY()) - pubOct := make([]byte, 1+32+32) - pubOct[0] = 0x04 - copy(pubOct[1:1+32], x) - copy(pubOct[1+32:], y) - return cgobinding.ECCVerifyDER(curve.Code(c), pubOct, digest, r.Signature) -} - -// ECDSA2PCSign executes the collaborative signing protocol between two parties. -// Both parties use their key shares to jointly create a signature for the given message. -func ECDSA2PCSign(job2p *Job2P, req *ECDSA2PCSignRequest) (*ECDSA2PCSignResponse, error) { - if len(req.Message) == 0 { - return nil, fmt.Errorf("message cannot be empty") - } - - // Prepare 32-byte digest - msg := req.Message - if len(msg) != 32 { - d := sha256.Sum256(msg) - m := make([]byte, 32) - copy(m, d[:]) - msg = m - } - - // Execute the collaborative signing using batch API with a single message - sigs, err := cgobinding.Sign(job2p.cgo(), req.SessionID, req.KeyShare.cgobindingRef(), [][]byte{msg}) - if err != nil { - return nil, fmt.Errorf("ECDSA 2PC signing failed: %v", err) - } - if len(sigs) != 1 { - return nil, fmt.Errorf("unexpected batch sign result") - } - - return &ECDSA2PCSignResponse{Signature: sigs[0]}, nil -} - -// ECDSA2PCRefreshRequest represents the parameters required to refresh (re-share) -// an existing 2-party ECDSA key. -// -// The protocol produces a fresh set of secret shares (x₁′, x₂′) that satisfy -// x₁′ + x₂′ = x₁ + x₂ mod n, i.e. the joint secret – and therefore the public -// key Q – remains unchanged while the individual shares are replaced with new -// uniformly-random values. Refreshing is useful to proactively rid the system -// of potentially compromised partial secrets. -// -// Only the existing key share is required as input because the curve is -// implicitly encoded in the key itself. -type ECDSA2PCRefreshRequest struct { - KeyShare ECDSA2PCKey // The party's current key share to be refreshed -} - -// ECDSA2PCRefreshResponse encapsulates the newly generated key share that -// replaces the caller's previous share. -type ECDSA2PCRefreshResponse struct { - NewKeyShare ECDSA2PCKey // The refreshed key share for this party -} - -// ECDSA2PCRefresh executes the key-refresh (re-share) protocol for an existing -// 2-party ECDSA key. Both parties must invoke this function concurrently with -// their respective messengers and key shares. On completion each party obtains -// a new, independent share such that the public key and the combined secret -// remain unchanged. -func ECDSA2PCRefresh(job2p *Job2P, req *ECDSA2PCRefreshRequest) (*ECDSA2PCRefreshResponse, error) { - if req == nil { - return nil, fmt.Errorf("request must be provided") - } - - newKeyRef, err := cgobinding.Refresh(job2p.cgo(), req.KeyShare.cgobindingRef()) - if err != nil { - return nil, fmt.Errorf("ECDSA 2PC refresh failed: %v", err) - } - - return &ECDSA2PCRefreshResponse{NewKeyShare: ECDSA2PCKey(newKeyRef)}, nil -} diff --git a/demos-go/cb-mpc-go/api/mpc/ecdsa_2p_test.go b/demos-go/cb-mpc-go/api/mpc/ecdsa_2p_test.go deleted file mode 100644 index f2cee5a6..00000000 --- a/demos-go/cb-mpc-go/api/mpc/ecdsa_2p_test.go +++ /dev/null @@ -1,508 +0,0 @@ -package mpc - -import ( - "bytes" - "fmt" - "sync" - "testing" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mocknet" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestECDSA2PCKeyGenWithMockNet(t *testing.T) { - secp, _ := curve.NewSecp256k1() - - // Valid case - responses, err := ECDSA2PCKeyGenWithMockNet(secp) - require.NoError(t, err) - require.Len(t, responses, 2) - - for i, resp := range responses { - assert.NotNil(t, resp, "response %d should not be nil", i) - assert.NotEqual(t, resp.KeyShare, 0, "key share %d should not be zero", i) - } - - // Invalid case: nil curve - responsesNil, errNil := ECDSA2PCKeyGenWithMockNet(nil) - assert.Error(t, errNil) - assert.Nil(t, responsesNil) -} - -func TestECDSA2PCFullProtocolWithMockNet(t *testing.T) { - tests := []struct { - name string - sessionID []byte - message []byte - wantErr bool - }{ - { - name: "valid full protocol", - sessionID: []byte("test-session"), - message: []byte("Hello, world!"), - wantErr: false, - }, - { - name: "valid with empty session ID", - sessionID: []byte{}, - message: []byte("Test message"), - wantErr: false, - }, - { - name: "invalid empty message", - sessionID: []byte("test-session"), - message: []byte{}, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - secp, _ := curve.NewSecp256k1() - result, err := ECDSA2PCFullProtocolWithMockNet(secp, tt.sessionID, tt.message) - - if tt.wantErr { - assert.Error(t, err) - assert.Nil(t, result) - return - } - - require.NoError(t, err) - require.NotNil(t, result) - - // Verify key generation results - require.Len(t, result.KeyGenResponses, 2) - for i, resp := range result.KeyGenResponses { - assert.NotNil(t, resp, "key gen response %d should not be nil", i) - assert.NotEqual(t, resp.KeyShare, 0, "key share %d should not be zero", i) - } - - // Verify signing results - require.Len(t, result.SignResponses, 2) - - // In ECDSA 2PC, only Party 0 gets the final signature - assert.NotNil(t, result.SignResponses[0], "sign response 0 should not be nil") - assert.NotNil(t, result.SignResponses[0].Signature, "signature 0 should not be nil") - assert.NotEmpty(t, result.SignResponses[0].Signature, "signature 0 should not be empty") - - // Party 1 contributes to signing but doesn't receive the final signature - assert.NotNil(t, result.SignResponses[1], "sign response 1 should not be nil") - assert.NotNil(t, result.SignResponses[1].Signature, "signature 1 should not be nil") - assert.Empty(t, result.SignResponses[1].Signature, "signature 1 should be empty (expected behavior)") - }) - } -} - -func TestECDSA2PCKeyGenRequest_Validation(t *testing.T) { - secp, _ := curve.NewSecp256k1() - _, err := ECDSA2PCKeyGenWithMockNet(secp) - assert.NoError(t, err) - - // Nil curve should error - _, errNil := ECDSA2PCKeyGenWithMockNet(nil) - assert.Error(t, errNil) -} - -func TestECDSA2PCSignRequest_Validation(t *testing.T) { - // First generate valid key shares for testing - secp, _ := curve.NewSecp256k1() - keyGenResponses, err := ECDSA2PCKeyGenWithMockNet(secp) - require.NoError(t, err) - require.Len(t, keyGenResponses, 2) - - tests := []struct { - name string - message []byte - wantErr bool - }{ - {"valid message", []byte("Hello, world!"), false}, - {"invalid empty message", []byte{}, true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - sessionID := []byte("test-session") - _, err := ECDSA2PCFullProtocolWithMockNet(secp, sessionID, tt.message) - - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestECDSA2PC_DeterministicSignatures(t *testing.T) { - // Test that Party 0 consistently gets signatures and Party 1 gets empty results - secp, _ := curve.NewSecp256k1() - sessionID := []byte("deterministic-test") - message := []byte("Consistent message") - - // Run the protocol multiple times - for i := 0; i < 3; i++ { - result, err := ECDSA2PCFullProtocolWithMockNet(secp, sessionID, message) - require.NoError(t, err, "iteration %d should succeed", i) - require.NotNil(t, result) - - // Party 0 should always get a signature - assert.NotEmpty(t, result.SignResponses[0].Signature, "iteration %d: Party 0 should get signature", i) - - // Party 1 should always get an empty signature (this is expected behavior) - assert.Empty(t, result.SignResponses[1].Signature, "iteration %d: Party 1 should get empty signature", i) - } -} - -func TestECDSA2PC_DifferentMessages(t *testing.T) { - // Test that different messages produce different signatures (from Party 0) - secp, _ := curve.NewSecp256k1() - sessionID := []byte("different-messages-test") - message1 := []byte("First message") - message2 := []byte("Second message") - - // Sign first message - result1, err := ECDSA2PCFullProtocolWithMockNet(secp, sessionID, message1) - require.NoError(t, err) - require.NotNil(t, result1) - require.NotEmpty(t, result1.SignResponses[0].Signature, "first signature should not be empty") - - // Sign second message - result2, err := ECDSA2PCFullProtocolWithMockNet(secp, sessionID, message2) - require.NoError(t, err) - require.NotNil(t, result2) - require.NotEmpty(t, result2.SignResponses[0].Signature, "second signature should not be empty") - - // Different messages should produce different signatures (from Party 0) - assert.False(t, bytes.Equal(result1.SignResponses[0].Signature, result2.SignResponses[0].Signature), - "different messages should produce different signatures") -} - -func TestECDSA2PC_StructureValidation(t *testing.T) { - // Test that all response structures are properly populated - secp, _ := curve.NewSecp256k1() - result, err := ECDSA2PCFullProtocolWithMockNet(secp, []byte("structure-test"), []byte("Test message")) - require.NoError(t, err) - require.NotNil(t, result) - - // Verify ECDSA2PCResult structure - assert.NotNil(t, result.KeyGenResponses, "KeyGenResponses should not be nil") - assert.NotNil(t, result.SignResponses, "SignResponses should not be nil") - assert.Len(t, result.KeyGenResponses, 2, "should have 2 key gen responses") - assert.Len(t, result.SignResponses, 2, "should have 2 sign responses") - - // Verify KeyGenResponse structure - for i, resp := range result.KeyGenResponses { - assert.NotNil(t, resp, "key gen response %d should not be nil", i) - // KeyShare is an opaque pointer, can't easily verify its contents - } - - // Verify SignResponse structure - // Party 0 should have a signature - assert.NotNil(t, result.SignResponses[0], "sign response 0 should not be nil") - assert.NotNil(t, result.SignResponses[0].Signature, "signature 0 should not be nil") - assert.Greater(t, len(result.SignResponses[0].Signature), 0, "signature 0 should have positive length") - - // Party 1 should have empty signature (expected behavior) - assert.NotNil(t, result.SignResponses[1], "sign response 1 should not be nil") - assert.NotNil(t, result.SignResponses[1].Signature, "signature 1 should not be nil") - assert.Equal(t, 0, len(result.SignResponses[1].Signature), "signature 1 should be empty (expected)") -} - -func TestECDSA2PCKey_RoleIndex(t *testing.T) { - // Generate key shares for two parties - secp, _ := curve.NewSecp256k1() - keyGenResponses, err := ECDSA2PCKeyGenWithMockNet(secp) - require.NoError(t, err) - require.Len(t, keyGenResponses, 2) - - // Party 0 - idx0, err := keyGenResponses[0].KeyShare.RoleIndex() - require.NoError(t, err, "party 0 RoleIndex should not error") - assert.Equal(t, 0, idx0, "party 0 should have role index 0") - - // Party 1 - idx1, err := keyGenResponses[1].KeyShare.RoleIndex() - require.NoError(t, err, "party 1 RoleIndex should not error") - assert.Equal(t, 1, idx1, "party 1 should have role index 1") -} - -func TestECDSA2PCKey_QAndXShare(t *testing.T) { - // Generate key shares - secp, _ := curve.NewSecp256k1() - keyGenResponses, err := ECDSA2PCKeyGenWithMockNet(secp) - require.NoError(t, err) - require.Len(t, keyGenResponses, 2) - - // Extract curve - curveObj, err := curve.NewSecp256k1() - require.NoError(t, err) - defer curveObj.Free() - - // Q from both parties should be the same - Q0, err := keyGenResponses[0].KeyShare.Q() - require.NoError(t, err) - defer Q0.Free() - - Q1, err := keyGenResponses[1].KeyShare.Q() - require.NoError(t, err) - defer Q1.Free() - - assert.True(t, Q0.Equals(Q1), "public key points should match across parties") - - // x shares - x0, err := keyGenResponses[0].KeyShare.XShare() - require.NoError(t, err) - x1, err := keyGenResponses[1].KeyShare.XShare() - require.NoError(t, err) - - // x_sum = x0 + x1 mod order - xSum, err := curveObj.Add(x0, x1) - require.NoError(t, err) - - // G * x_sum should equal Q - GxSum, err := curveObj.MultiplyGenerator(xSum) - require.NoError(t, err) - defer GxSum.Free() - - assert.True(t, GxSum.Equals(Q0), "G * (x0 + x1) should equal Q") -} - -// ECDSA2PCResult represents the complete result of key generation and signing -type ECDSA2PCResult struct { - KeyGenResponses []*ECDSA2PCKeyGenResponse - SignResponses []*ECDSA2PCSignResponse -} - -// ECDSA2PCFullProtocolWithMockNet runs the complete ECDSA 2PC protocol (key generation + signing) -// using the mock network. This is a test-only helper that provides a convenient way to exercise -// the full protocol flow. -func ECDSA2PCFullProtocolWithMockNet(curveObj curve.Curve, sessionID []byte, message []byte) (*ECDSA2PCResult, error) { - if curveObj == nil { - return nil, fmt.Errorf("curve must be provided") - } - if len(message) == 0 { - return nil, fmt.Errorf("message cannot be empty") - } - - // Use MPCRunner for proper coordination between parties - runner := mocknet.NewMPCRunner(mocknet.GeneratePartyNames(2)...) - - // Step 1: Distributed Key Generation - keyGenOutputs, err := runner.MPCRun2P(func(job cgobinding.Job2P, input *mocknet.MPCIO) (*mocknet.MPCIO, error) { - cv := input.Opaque.(curve.Curve) - keyShareRef, err := cgobinding.DistributedKeyGen(job, curve.Code(cv)) - if err != nil { - return nil, fmt.Errorf("key generation failed: %v", err) - } - return &mocknet.MPCIO{Opaque: keyShareRef}, nil - }, []*mocknet.MPCIO{ - {Opaque: curveObj}, - {Opaque: curveObj}, - }) - if err != nil { - return nil, fmt.Errorf("key generation failed: %v", err) - } - - // Extract key shares from outputs - keyShare0 := ECDSA2PCKey(keyGenOutputs[0].Opaque.(cgobinding.Mpc_ecdsa2pc_key_ref)) - keyShare1 := ECDSA2PCKey(keyGenOutputs[1].Opaque.(cgobinding.Mpc_ecdsa2pc_key_ref)) - - // Step 2: Collaborative Signing - type signInput struct { - SessionID []byte - KeyShare ECDSA2PCKey - Message []byte - } - - signOutputs, err := runner.MPCRun2P(func(job cgobinding.Job2P, input *mocknet.MPCIO) (*mocknet.MPCIO, error) { - signInput := input.Opaque.(signInput) - messages := [][]byte{signInput.Message} - signatures, err := cgobinding.Sign(job, signInput.SessionID, signInput.KeyShare.cgobindingRef(), messages) - if err != nil { - return nil, fmt.Errorf("signing failed: %v", err) - } - if len(signatures) == 0 { - return nil, fmt.Errorf("no signature returned") - } - return &mocknet.MPCIO{Opaque: signatures[0]}, nil - }, []*mocknet.MPCIO{ - {Opaque: signInput{ - SessionID: sessionID, - KeyShare: keyShare0, - Message: message, - }}, - {Opaque: signInput{ - SessionID: sessionID, - KeyShare: keyShare1, - Message: message, - }}, - }) - if err != nil { - return nil, fmt.Errorf("signing failed: %v", err) - } - - // Build result - result := &ECDSA2PCResult{ - KeyGenResponses: []*ECDSA2PCKeyGenResponse{ - {KeyShare: keyShare0}, - {KeyShare: keyShare1}, - }, - SignResponses: []*ECDSA2PCSignResponse{ - {Signature: signOutputs[0].Opaque.([]byte)}, - {Signature: signOutputs[1].Opaque.([]byte)}, - }, - } - - return result, nil -} - -// ECDSA2PCKeyGenWithMockNet is a test-only helper that runs the distributed key -// generation protocol locally using the in-memory mock network. This mirrors -// the original implementation that lived in the production API but has been -// moved here to avoid exposing testing utilities to API consumers. -func ECDSA2PCKeyGenWithMockNet(curveObj curve.Curve) ([]*ECDSA2PCKeyGenResponse, error) { - if curveObj == nil { - return nil, fmt.Errorf("curve must be provided") - } - - // Coordinate two virtual parties using the mock network helper. - runner := mocknet.NewMPCRunner(mocknet.GeneratePartyNames(2)...) - - outputs, err := runner.MPCRun2P(func(job cgobinding.Job2P, input *mocknet.MPCIO) (*mocknet.MPCIO, error) { - cv := input.Opaque.(curve.Curve) - keyShareRef, err := cgobinding.DistributedKeyGen(job, curve.Code(cv)) - if err != nil { - return nil, fmt.Errorf("key generation failed: %v", err) - } - return &mocknet.MPCIO{Opaque: keyShareRef}, nil - }, []*mocknet.MPCIO{ - {Opaque: curveObj}, - {Opaque: curveObj}, - }) - if err != nil { - return nil, fmt.Errorf("key generation failed: %v", err) - } - - // Convert outputs into the public response structure expected by callers. - responses := []*ECDSA2PCKeyGenResponse{ - {KeyShare: ECDSA2PCKey(outputs[0].Opaque.(cgobinding.Mpc_ecdsa2pc_key_ref))}, - {KeyShare: ECDSA2PCKey(outputs[1].Opaque.(cgobinding.Mpc_ecdsa2pc_key_ref))}, - } - - return responses, nil -} - -func TestECDSA2PCKeyGen_CurveIntegrity(t *testing.T) { - // Ensure that the curve associated with each generated key share matches - // the curve requested during key generation. - secp, _ := curve.NewSecp256k1() - - keyGenResponses, err := ECDSA2PCKeyGenWithMockNet(secp) - require.NoError(t, err) - require.Len(t, keyGenResponses, 2) - - expectedCode := curve.Code(secp) - - for i, resp := range keyGenResponses { - c, err := resp.KeyShare.Curve() - require.NoError(t, err, "party %d Curve() should not error", i) - assert.NotNil(t, c, "party %d curve should not be nil", i) - assert.Equal(t, expectedCode, curve.Code(c), "party %d curve code should match", i) - c.Free() - } -} - -func TestECDSA2PC_Refresh(t *testing.T) { - // Step 0: initialise curve - curveObj, _ := curve.NewSecp256k1() - - // Step 1: Generate initial key shares - keyGenResponses, err := ECDSA2PCKeyGenWithMockNet(curveObj) - require.NoError(t, err) - require.Len(t, keyGenResponses, 2) - - // Capture original public key Q - origQ, err := keyGenResponses[0].KeyShare.Q() - require.NoError(t, err) - defer origQ.Free() - - // Create a fresh mock network for the refresh round - const nParties = 2 - messengers := mocknet.NewMockNetwork(nParties) - partyNames := []string{"party_0", "party_1"} - - type refreshResult struct { - resp *ECDSA2PCRefreshResponse - err error - } - - results := make([]refreshResult, nParties) - - var wg sync.WaitGroup - wg.Add(nParties) - for i := 0; i < nParties; i++ { - go func(i int) { - defer wg.Done() - jp, err := NewJob2P(messengers[i], i, partyNames) - if err != nil { - results[i] = refreshResult{resp: nil, err: err} - return - } - defer jp.Free() - - r, e := ECDSA2PCRefresh(jp, &ECDSA2PCRefreshRequest{KeyShare: keyGenResponses[i].KeyShare}) - results[i] = refreshResult{resp: r, err: e} - }(i) - } - wg.Wait() - - for i := 0; i < nParties; i++ { - require.NoError(t, results[i].err, "party %d refresh should succeed", i) - require.NotNil(t, results[i].resp, "party %d response should not be nil", i) - } - - newShare0 := results[0].resp.NewKeyShare - newShare1 := results[1].resp.NewKeyShare - - // ===== Curve unchanged ===== - expectedCode := curve.Code(curveObj) - - c0, err := newShare0.Curve() - require.NoError(t, err) - assert.Equal(t, expectedCode, curve.Code(c0), "party 0 curve code should remain unchanged") - c0.Free() - - c1, err := newShare1.Curve() - require.NoError(t, err) - assert.Equal(t, expectedCode, curve.Code(c1), "party 1 curve code should remain unchanged") - c1.Free() - - // ===== Public key Q unchanged ===== - newQ0, err := newShare0.Q() - require.NoError(t, err) - defer newQ0.Free() - newQ1, err := newShare1.Q() - require.NoError(t, err) - defer newQ1.Free() - - assert.True(t, origQ.Equals(newQ0), "public key should remain unchanged after refresh (party 0)") - assert.True(t, origQ.Equals(newQ1), "public key should remain unchanged after refresh (party 1)") - - // ===== Key share sum unchanged ===== - x0New, err := newShare0.XShare() - require.NoError(t, err) - x1New, err := newShare1.XShare() - require.NoError(t, err) - - sumNew, err := curveObj.Add(x0New, x1New) - require.NoError(t, err) - - GsumNew, err := curveObj.MultiplyGenerator(sumNew) - require.NoError(t, err) - defer GsumNew.Free() - - assert.True(t, GsumNew.Equals(origQ), "G * (x0' + x1') should equal original Q after refresh") -} diff --git a/demos-go/cb-mpc-go/api/mpc/ecdsa_mp.go b/demos-go/cb-mpc-go/api/mpc/ecdsa_mp.go deleted file mode 100644 index 9a1e9ee1..00000000 --- a/demos-go/cb-mpc-go/api/mpc/ecdsa_mp.go +++ /dev/null @@ -1,408 +0,0 @@ -package mpc - -import ( - "bytes" - "encoding" - "encoding/gob" - "fmt" - "runtime" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// Compile-time assertions to ensure ECDSAMPCKey implements the binary -// marshaling interfaces. These will fail to compile if the method set ever -// becomes incomplete. -var ( - _ encoding.BinaryMarshaler = (*ECDSAMPCKey)(nil) - _ encoding.BinaryUnmarshaler = (*ECDSAMPCKey)(nil) -) - -// ECDSA N-Party (Multi-Party) API -// This API provides N-party ECDSA key generation and signing operations -// for scenarios requiring more than 2 parties (e.g., 3, 4, 5+ parties). - -// ============================================================================ -// Request/Response Types -// ============================================================================ - -// ECDSAMPCKey is an opaque handle to an N-party ECDSA key share. -// -// It mirrors the design of ECDSA2PCKey and intentionally hides the -// underlying cgobinding representation from API consumers. -// -// NOTE: the zero value is considered invalid. -type ECDSAMPCKey cgobinding.Mpc_eckey_mp_ref - -// newECDSAMPCKey wraps the given cgobinding key reference in an ECDSAMPCKey -// value and installs a finalizer to automatically release native resources -// when the Go value becomes unreachable. -func newECDSAMPCKey(ref cgobinding.Mpc_eckey_mp_ref) ECDSAMPCKey { - // We purposefully avoid installing a finalizer here. Although automatic - // cleanup is convenient, the ECDSAMPCKey value is small (it merely wraps - // an unsafe pointer) and is frequently copied. Go's finalizer semantics - // make it hard to guarantee that *all* copies of the value refer to the - // same underlying storage, which can easily lead to double-free bugs. We - // therefore rely on callers to invoke (*ECDSAMPCKey).Free() explicitly - // when they are done with the key share. - return ECDSAMPCKey(ref) -} - -// Free releases the underlying native key-share object. Callers who wish to -// deterministically free resources can invoke this method. If not called -// explicitly, the finalizer registered by newECDSAMPCKey will eventually take -// care of cleanup. -func (k *ECDSAMPCKey) Free() { - if k == nil { - return - } - ref := cgobinding.Mpc_eckey_mp_ref(*k) - (&ref).Free() - // Prevent double free by clearing the opaque pointer and disabling the - // finalizer. - *k = ECDSAMPCKey(cgobinding.Mpc_eckey_mp_ref{}) - runtime.SetFinalizer(k, nil) -} - -// MarshalBinary serializes the receiver into a byte slice that implements -// encoding.BinaryMarshaler. The output is produced by the underlying native -// helper and subsequently gob-encoded to ensure a portable wire format. The -// representation is intended for short-term transport or caching and should -// not be relied upon for long-term persistence across cb-mpc versions. -func (k ECDSAMPCKey) MarshalBinary() ([]byte, error) { - parts, err := cgobinding.SerializeECDSAShare(k.cgobindingRef()) - if err != nil { - return nil, err - } - var buf bytes.Buffer - if err := gob.NewEncoder(&buf).Encode(parts); err != nil { - return nil, err - } - return buf.Bytes(), nil -} - -// UnmarshalBinary restores the key share from the byte slice generated by -// MarshalBinary. The receiver is overwritten with the newly decoded -// ECDSAMPCKey and will point to a fresh native object. Callers that wish to -// reclaim any resources referenced by the previous value should invoke -// (*ECDSAMPCKey).Free before calling this method. -func (k *ECDSAMPCKey) UnmarshalBinary(data []byte) error { - var parts [][]byte - if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&parts); err != nil { - return err - } - keyRef, err := cgobinding.DeserializeECDSAShare(parts) - if err != nil { - return err - } - *k = newECDSAMPCKey(keyRef) - return nil -} - -// PartyName returns the string identifier of the party that owns this key -// share. It is fetched from the underlying C++ `key_share_mp_t::party_name` -// field via the cgobinding helper. -func (k ECDSAMPCKey) PartyName() (string, error) { - return cgobinding.MPC_mpc_eckey_mp_get_party_name(cgobinding.Mpc_eckey_mp_ref(k)) -} - -// XShare returns the scalar secret share x_i held by this party. -func (k ECDSAMPCKey) XShare() (*curve.Scalar, error) { - bytes, err := cgobinding.MPC_mpc_eckey_mp_get_x_share(k.cgobindingRef()) - if err != nil { - return nil, err - } - return &curve.Scalar{Bytes: bytes}, nil -} - -// Q returns the aggregated public key point associated with the distributed key. -// The returned Point must be freed by the caller. -func (k ECDSAMPCKey) Q() (*curve.Point, error) { - bytes, err := cgobinding.KeyShareQBytes(k.cgobindingRef()) - if err != nil { - return nil, err - } - return curve.NewPointFromBytes(bytes) -} - -// Curve returns the elliptic curve associated with this key. -// The caller is responsible for freeing the returned Curve when done. -func (k ECDSAMPCKey) Curve() (curve.Curve, error) { - code, err := cgobinding.KeyShareCurveCode(k.cgobindingRef()) - if err != nil { - return nil, err - } - return curve.NewFromCode(code) -} - -// Qis returns the per-party public key shares Qi for all peers. -// The caller is responsible for freeing the individual Point values when no -// longer needed. -func (k ECDSAMPCKey) Qis() (map[string]*curve.Point, error) { - names, points, err := cgobinding.MPC_mpc_eckey_mp_Qis(k.cgobindingRef()) - if err != nil { - return nil, err - } - if len(names) != len(points) { - return nil, fmt.Errorf("inconsistent Qis arrays: %d names vs %d points", len(names), len(points)) - } - out := make(map[string]*curve.Point, len(names)) - for i, nameBytes := range names { - pt, err := curve.NewPointFromBytes(points[i]) - if err != nil { - return nil, fmt.Errorf("failed to decode Qi for party %s: %v", string(nameBytes), err) - } - out[string(nameBytes)] = pt - } - return out, nil -} - -// cgobindingRef unwraps the internal cgobinding key reference. It is kept -// unexported to discourage direct use outside of this package. -func (k ECDSAMPCKey) cgobindingRef() cgobinding.Mpc_eckey_mp_ref { - return cgobinding.Mpc_eckey_mp_ref(k) -} - -// ECDSAMPCKeyGenRequest represents a request for N-party ECDSA key generation. -// The caller specifies the Curve instance instead of a raw numeric identifier -// to align the API with other MPC primitives (e.g. ECDSA2PC). -type ECDSAMPCKeyGenRequest struct { - Curve curve.Curve // Elliptic curve to use (e.g., secp256k1) -} - -// ECDSAMPCKeyGenResponse represents the response from N-party ECDSA key generation -type ECDSAMPCKeyGenResponse struct { - KeyShare ECDSAMPCKey // The distributed key share for this party -} - -// ============================================================================ -// Core API Functions -// ============================================================================ - -// ECDSAMPCKeyGen performs N-party ECDSA distributed key generation -// All parties must call this function simultaneously with the same parameters -func ECDSAMPCKeyGen(jobmp *JobMP, req *ECDSAMPCKeyGenRequest) (*ECDSAMPCKeyGenResponse, error) { - if jobmp == nil { - return nil, fmt.Errorf("job must be provided") - } - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if req.Curve == nil { - return nil, fmt.Errorf("curve must be provided") - } - if jobmp.NParties() < 3 { - return nil, fmt.Errorf("n-party ECDSA requires at least 3 parties (use ECDSA2PC for 2-party)") - } - - // Perform distributed key generation using the provided JobMP and curve - keyShare, err := cgobinding.KeyShareDKGCode(jobmp.cgo(), curve.Code(req.Curve)) - if err != nil { - return nil, fmt.Errorf("ECDSA N-party key generation failed: %v", err) - } - - return &ECDSAMPCKeyGenResponse{KeyShare: newECDSAMPCKey(keyShare)}, nil -} - -// ECDSAMPCSignRequest represents a request for N-party ECDSA signing -type ECDSAMPCSignRequest struct { - KeyShare ECDSAMPCKey // The key share from key generation - Message []byte // The message to sign - SignatureReceiver int // Which party should receive the final signature (typically 0) -} - -// ECDSAMPCSignResponse represents the response from N-party ECDSA signing -type ECDSAMPCSignResponse struct { - Signature []byte // The ECDSA signature (only populated for the designated receiver) -} - -// ECDSAMPCSign performs N-party ECDSA signing -// All parties must call this function simultaneously with their respective key shares -func ECDSAMPCSign(jobmp *JobMP, req *ECDSAMPCSignRequest) (*ECDSAMPCSignResponse, error) { - if jobmp == nil { - return nil, fmt.Errorf("job must be provided") - } - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if jobmp.NParties() < 3 { - return nil, fmt.Errorf("n-party signing requires at least 3 parties") - } - if len(req.Message) == 0 { - return nil, fmt.Errorf("message cannot be empty") - } - - // Perform distributed signing using the provided JobMP - signature, err := cgobinding.MPC_ecdsampc_sign(jobmp.cgo(), req.KeyShare.cgobindingRef(), req.Message, req.SignatureReceiver) - if err != nil { - return nil, fmt.Errorf("ECDSA N-party signing failed: %v", err) - } - - // Determine current party index - roleIndex := jobmp.GetPartyIndex() - - // Only the designated receiver gets the signature - var sigBytes []byte - if roleIndex == req.SignatureReceiver { - sigBytes = signature - } - - return &ECDSAMPCSignResponse{Signature: sigBytes}, nil -} - -// ECDSAMPCRefreshRequest represents the parameters required to refresh (re-share) -// an existing N-party ECDSA key. -// -// The protocol produces a fresh set of secret shares such that the combined -// public key remains unchanged while each party obtains a new independent -// share. -// -// A unique SessionID must be supplied by the caller. It should be identical -// for all parties participating in the refresh procedure. If SessionID is -// nil or empty, a random identifier will be generated internally. -// -// NOTE: The refresh protocol is an N-party operation – all parties that -// originally participated in key generation must invoke it concurrently. -type ECDSAMPCRefreshRequest struct { - KeyShare ECDSAMPCKey // Existing key share to be refreshed - SessionID []byte // Caller-provided session identifier (optional) -} - -// ECDSAMPCRefreshResponse contains the newly generated key share that -// replaces the caller's previous share. -type ECDSAMPCRefreshResponse struct { - NewKeyShare ECDSAMPCKey // The refreshed key share for this party -} - -// ECDSAMPCRefresh executes the key-refresh protocol for an existing N-party -// ECDSA key. All parties must invoke this function concurrently with their -// current key shares and an identical SessionID. -func ECDSAMPCRefresh(jobmp *JobMP, req *ECDSAMPCRefreshRequest) (*ECDSAMPCRefreshResponse, error) { - if jobmp == nil { - return nil, fmt.Errorf("job must be provided") - } - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if jobmp.NParties() < 3 { - return nil, fmt.Errorf("n-party refresh requires at least 3 parties") - } - // Ensure a session ID is always provided to the native layer. If the caller - // did not supply one, fall back to an empty slice (the binding will handle - // conversion to an empty cmem_t). - sid := req.SessionID - - newKey, err := cgobinding.KeyShareRefresh(jobmp.cgo(), sid, req.KeyShare.cgobindingRef()) - if err != nil { - return nil, fmt.Errorf("ECDSA N-party refresh failed: %v", err) - } - - return &ECDSAMPCRefreshResponse{NewKeyShare: newECDSAMPCKey(newKey)}, nil -} - -// ECDSAMPCThresholdDKGRequest represents the input parameters for running the -// threshold Distributed Key Generation (DKG) protocol. The protocol generates -// an N-party ECDSA key that can later be used by any quorum that satisfies the -// provided access-structure policy. -// -// All parties that participate in the initial key generation (typically all -// online parties) must invoke ECDSAMPCThresholdDKG concurrently using the same -// parameters. -// -// AccessStructure describes the quorum policy that determines which subsets of -// parties are authorized to later perform signature operations. Callers should -// construct it using the high-level helpers provided in this package (see -// Leaf/And/Or/Threshold plus the AccessStructure wrapper) rather than dealing -// with low-level cgobinding functions. -type ECDSAMPCThresholdDKGRequest struct { - Curve curve.Curve // Elliptic curve to use (e.g. secp256k1) - SessionID []byte // Optional caller-supplied session identifier - AccessStructure *AccessStructure // Quorum access-structure description - QuorumRIDs []int // (Optional) Indices of parties that will form the quorum; defaults to all parties if nil/empty -} - -// ECDSAMPCThresholdDKGResponse contains the newly generated key share owned by -// the calling party. -type ECDSAMPCThresholdDKGResponse struct { - KeyShare ECDSAMPCKey -} - -// ECDSAMPCThresholdDKG executes the threshold DKG protocol and returns the -// caller's key share. Internally it delegates to cgobinding.ThresholdDKGCurve. -// -// Usage notes: -// - jobmp must represent *all* parties involved in the DKG. -// - The AccessStructure parameter must describe a quorum that can be satisfied -// by (a subset of) the parties represented by jobmp. -func ECDSAMPCThresholdDKG(jobmp *JobMP, req *ECDSAMPCThresholdDKGRequest) (*ECDSAMPCThresholdDKGResponse, error) { - if jobmp == nil { - return nil, fmt.Errorf("job must be provided") - } - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if req.Curve == nil { - return nil, fmt.Errorf("curve must be provided") - } - - // Ensure a session ID is always passed to the binding. An empty slice is - // interpreted as "let the native implementation pick a random SID". - sid := req.SessionID - - // Ensure we have a valid access-structure description. - if req.AccessStructure == nil { - return nil, fmt.Errorf("access structure must be provided") - } - - // Translate the high-level Go representation into the native C handle. - acPtr := req.AccessStructure.toCryptoAC() - - // Determine which party indices will participate in DKG. - roleIndices := req.QuorumRIDs - if len(roleIndices) == 0 { - // Fallback to all parties if caller did not supply a custom set. - roleIndices = make([]int, jobmp.NParties()) - for i := 0; i < jobmp.NParties(); i++ { - roleIndices[i] = i - } - } - - // Run the native threshold DKG using the curve reference directly to - // avoid leaking numeric NIDs into the API layer. - keyShareRef, err := cgobinding.ThresholdDKGCode(jobmp.cgo(), curve.Code(req.Curve), sid, acPtr, roleIndices) - if err != nil { - return nil, fmt.Errorf("ECDSA threshold DKG failed: %v", err) - } - - return &ECDSAMPCThresholdDKGResponse{KeyShare: newECDSAMPCKey(keyShareRef)}, nil -} - -// ToAdditiveShare converts the multiplicative secret-share representation -// embodied by the ECDSAMPCKey into an additive share that satisfies the -// provided access-structure. The resulting key share can be used by the -// threshold signing routines that expect additive shares. The aggregated -// public key is preserved by the transformation. -func (k ECDSAMPCKey) ToAdditiveShare(ac *AccessStructure, quorumPartyNames []string) (ECDSAMPCKey, error) { - // Validate inputs - if ac == nil { - return ECDSAMPCKey{}, fmt.Errorf("access structure must be provided") - } - if len(quorumPartyNames) == 0 { - return ECDSAMPCKey{}, fmt.Errorf("quorumPartyNames cannot be empty") - } - - // Translate the high-level AccessStructure into the native representation. - // The native object returned by toCryptoAC carries a finalizer that will - // release its resources once it becomes unreachable, so we do not free - // it explicitly here to avoid double-free errors. - acPtr := ac.toCryptoAC() - - // Forward to the low-level helper using the underlying cgobinding key ref. - keyRef := cgobinding.Mpc_eckey_mp_ref(k) - additiveRef, err := (&keyRef).ToAdditiveShare(acPtr, quorumPartyNames) - if err != nil { - return ECDSAMPCKey{}, err - } - return newECDSAMPCKey(additiveRef), nil -} diff --git a/demos-go/cb-mpc-go/api/mpc/ecdsa_mp_test.go b/demos-go/cb-mpc-go/api/mpc/ecdsa_mp_test.go deleted file mode 100644 index 74b43de3..00000000 --- a/demos-go/cb-mpc-go/api/mpc/ecdsa_mp_test.go +++ /dev/null @@ -1,348 +0,0 @@ -package mpc - -import ( - "fmt" - "testing" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mocknet" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// SignMPInput represents the input for multi-party signing operations -type SignMPInput struct { - Key ECDSAMPCKey - Msg []byte -} - -// ECDSAMPCWithMockNet performs complete N-party ECDSA workflow using MockNet. -// It is used in unit tests to exercise the full protocol stack without external -// networking. -func ECDSAMPCWithMockNet(nParties int, c curve.Curve, message []byte) ([]*ECDSAMPCKeyGenResponse, []*ECDSAMPCSignResponse, error) { - if nParties < 3 { - return nil, nil, fmt.Errorf("n-party ECDSA requires at least 3 parties") - } - if len(message) == 0 { - return nil, nil, fmt.Errorf("message cannot be empty") - } - - // Create MockNet runner - runner := mocknet.NewMPCRunner(mocknet.GeneratePartyNames(nParties)...) - - // --------------------------------------------------------------------- - // Step 1: Distributed Key Generation - // --------------------------------------------------------------------- - keyGenInputs := make([]*mocknet.MPCIO, nParties) - for i := 0; i < nParties; i++ { - keyGenInputs[i] = &mocknet.MPCIO{Opaque: c} - } - - keyGenOutputs, err := runner.MPCRunMP(func(job cgobinding.JobMP, input *mocknet.MPCIO) (*mocknet.MPCIO, error) { - cv := input.Opaque.(curve.Curve) - keyShare, err := cgobinding.KeyShareDKGCode(job, curve.Code(cv)) - if err != nil { - return nil, fmt.Errorf("n-party key generation failed: %v", err) - } - return &mocknet.MPCIO{Opaque: keyShare}, nil - }, keyGenInputs) - if err != nil { - return nil, nil, fmt.Errorf("key generation failed: %v", err) - } - - keyGenResponses := make([]*ECDSAMPCKeyGenResponse, nParties) - for i := 0; i < nParties; i++ { - keyGenResponses[i] = &ECDSAMPCKeyGenResponse{ - KeyShare: ECDSAMPCKey(keyGenOutputs[i].Opaque.(cgobinding.Mpc_eckey_mp_ref)), - } - } - - // --------------------------------------------------------------------- - // Step 2: Distributed Signing (no explicit public key extraction needed) - // --------------------------------------------------------------------- - signatureReceiver := 0 // Party 0 receives the signature - signInputs := make([]*mocknet.MPCIO, nParties) - for i := 0; i < nParties; i++ { - signInputs[i] = &mocknet.MPCIO{Opaque: SignMPInput{Key: keyGenResponses[i].KeyShare, Msg: message}} - } - - signOutputs, err := runner.MPCRunMP(func(job cgobinding.JobMP, input *mocknet.MPCIO) (*mocknet.MPCIO, error) { - signInput := input.Opaque.(SignMPInput) - sig, err := cgobinding.MPC_ecdsampc_sign(job, signInput.Key.cgobindingRef(), signInput.Msg, signatureReceiver) - if err != nil { - return nil, fmt.Errorf("n-party signing failed: %v", err) - } - return &mocknet.MPCIO{Opaque: sig}, nil - }, signInputs) - if err != nil { - return nil, nil, fmt.Errorf("signing failed: %v", err) - } - - signResponses := make([]*ECDSAMPCSignResponse, nParties) - for i := 0; i < nParties; i++ { - var sigBytes []byte - if i == signatureReceiver { - sigBytes = signOutputs[i].Opaque.([]byte) - } - signResponses[i] = &ECDSAMPCSignResponse{Signature: sigBytes} - } - - return keyGenResponses, signResponses, nil -} - -func TestECDSAMPCWithMockNet(t *testing.T) { - tests := []struct { - name string - nParties int - message []byte - wantErr bool - }{ - { - name: "valid_3_party_ecdsa", - nParties: 3, - message: []byte("test message for 3-party ECDSA"), - wantErr: false, - }, - { - name: "valid_4_party_ecdsa", - nParties: 4, - message: []byte("test message for 4-party ECDSA"), - wantErr: false, - }, - { - name: "valid_5_party_ecdsa", - nParties: 5, - message: []byte("test message for 5-party ECDSA"), - wantErr: false, - }, - { - name: "invalid_too_few_parties", - nParties: 2, - message: []byte("test message"), - wantErr: true, - }, - { - name: "invalid_empty_message", - nParties: 3, - message: []byte{}, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Use secp256k1 for all tests (matches previous NID 714) - secp, errCurve := curve.NewSecp256k1() - if errCurve != nil { - t.Fatalf("failed to create curve: %v", errCurve) - } - - keyGenResponses, signResponses, err := ECDSAMPCWithMockNet(tt.nParties, secp, tt.message) - - if tt.wantErr { - if err == nil { - t.Errorf("ECDSAMPCWithMockNet() expected error but got none") - } - return - } - - if err != nil { - t.Errorf("ECDSAMPCWithMockNet() error = %v, wantErr %v", err, tt.wantErr) - return - } - - // Verify key generation responses - if len(keyGenResponses) != tt.nParties { - t.Errorf("Expected %d key generation responses, got %d", tt.nParties, len(keyGenResponses)) - } - - // Verify signing responses - if len(signResponses) != tt.nParties { - t.Errorf("Expected %d signing responses, got %d", tt.nParties, len(signResponses)) - } - - // Verify that only party 0 receives the signature - signatureReceiver := 0 - for i, resp := range signResponses { - if i == signatureReceiver { - if len(resp.Signature) == 0 { - t.Errorf("Party %d should receive signature but got empty", i) - } - } else { - if len(resp.Signature) != 0 { - t.Errorf("Party %d should not receive signature but got %d bytes", i, len(resp.Signature)) - } - } - } - - // Verify public key is populated using get_Q - pt, err := keyGenResponses[0].KeyShare.Q() - if err != nil { - t.Errorf("Failed to retrieve public key: %v", err) - } else { - if len(pt.GetX()) == 0 { - t.Error("Public key X coordinate is empty") - } - if len(pt.GetY()) == 0 { - t.Error("Public key Y coordinate is empty") - } - pt.Free() - } - }) - } -} - -func TestECDSAMPCConsistency(t *testing.T) { - // Test that multiple runs produce different signatures but same public key structure - nParties := 3 - secp, errCurve := curve.NewSecp256k1() - if errCurve != nil { - t.Fatalf("failed to create curve: %v", errCurve) - } - - // Run the protocol twice with different messages - message1 := []byte("first test message") - message2 := []byte("second test message") - - keyGenResponses1, signResponses1, err := ECDSAMPCWithMockNet(nParties, secp, message1) - if err != nil { - t.Fatalf("First run failed: %v", err) - } - - keyGenResponses2, signResponses2, err := ECDSAMPCWithMockNet(nParties, secp, message2) - if err != nil { - t.Fatalf("Second run failed: %v", err) - } - - // Public keys should have the same structure (both should be non-empty) using get_Q - pt1, err := keyGenResponses1[0].KeyShare.Q() - if err != nil { - t.Fatalf("Failed to retrieve public key from first run: %v", err) - } - pt2, err := keyGenResponses2[0].KeyShare.Q() - if err != nil { - t.Fatalf("Failed to retrieve public key from second run: %v", err) - } - if len(pt1.GetX()) == 0 || len(pt1.GetY()) == 0 { - t.Error("First public key has empty coordinates") - } - if len(pt2.GetX()) == 0 || len(pt2.GetY()) == 0 { - t.Error("Second public key has empty coordinates") - } - pt1.Free() - pt2.Free() - - // Key generation responses should have the same structure - if len(keyGenResponses1) != len(keyGenResponses2) { - t.Error("Key generation response counts differ") - } - - // Signature responses should have the same structure - if len(signResponses1) != len(signResponses2) { - t.Error("Signature response counts differ") - } - - // Signatures should be different (different messages) - sig1 := signResponses1[0].Signature - sig2 := signResponses2[0].Signature - - if len(sig1) == 0 || len(sig2) == 0 { - t.Error("Signatures should not be empty") - } - - // Compare signatures byte by byte - they should be different - if len(sig1) == len(sig2) { - allSame := true - for i := 0; i < len(sig1); i++ { - if sig1[i] != sig2[i] { - allSame = false - break - } - } - if allSame { - t.Error("Signatures should be different for different messages") - } - } -} - -func TestECDSAMPCScalability(t *testing.T) { - // Test with different party counts to ensure scalability - partyCounts := []int{3, 4, 5, 6} - secp, errCurve := curve.NewSecp256k1() - if errCurve != nil { - t.Fatalf("failed to create curve: %v", errCurve) - } - message := []byte("scalability test message") - - for _, nParties := range partyCounts { - t.Run(fmt.Sprintf("parties_%d", nParties), func(t *testing.T) { - keyGenResponses, signResponses, err := ECDSAMPCWithMockNet(nParties, secp, message) - if err != nil { - t.Errorf("Failed with %d parties: %v", nParties, err) - return - } - - if len(keyGenResponses) != nParties { - t.Errorf("Expected %d key shares, got %d", nParties, len(keyGenResponses)) - } - - if len(signResponses) != nParties { - t.Errorf("Expected %d sign responses, got %d", nParties, len(signResponses)) - } - - // Verify public key is retrievable using get_Q - pt, err := keyGenResponses[0].KeyShare.Q() - if err != nil { - t.Errorf("Failed to retrieve public key: %v", err) - } else { - pt.Free() - } - - // Verify signature was generated - if len(signResponses[0].Signature) == 0 { - t.Error("Signature was not generated") - } - }) - } -} - -func TestECDSAMPCKeyAccessors_Q_and_Curve(t *testing.T) { - // Use secp256k1 curve and 3-party setup - secp, errCurve := curve.NewSecp256k1() - if errCurve != nil { - t.Fatalf("failed to create curve: %v", errCurve) - } - defer secp.Free() - - msg := []byte("accessors test message") - keyGenResponses, _, err := ECDSAMPCWithMockNet(3, secp, msg) - if err != nil { - t.Fatalf("ECDSAMPCWithMockNet failed: %v", err) - } - if len(keyGenResponses) == 0 { - t.Fatalf("no key generation responses returned") - } - - // Verify Q() returns a valid, non-empty public key point - pt, err := keyGenResponses[0].KeyShare.Q() - if err != nil { - t.Fatalf("KeyShare.Q failed: %v", err) - } - if len(pt.GetX()) == 0 || len(pt.GetY()) == 0 { - t.Fatalf("public key coordinates should be non-empty") - } - pt.Free() - - // Verify Curve() returns the same curve that was used for key generation - kc, err := keyGenResponses[0].KeyShare.Curve() - if err != nil { - t.Fatalf("KeyShare.Curve failed: %v", err) - } - defer kc.Free() - - if curve.Code(kc) != curve.Code(secp) { - t.Fatalf("curve code mismatch: got %d want %d", curve.Code(kc), curve.Code(secp)) - } - if kc.String() != secp.String() { - t.Fatalf("curve string mismatch: got %q want %q", kc.String(), secp.String()) - } -} diff --git a/demos-go/cb-mpc-go/api/mpc/ecdsa_mp_threshold_test.go b/demos-go/cb-mpc-go/api/mpc/ecdsa_mp_threshold_test.go deleted file mode 100644 index 21fc888a..00000000 --- a/demos-go/cb-mpc-go/api/mpc/ecdsa_mp_threshold_test.go +++ /dev/null @@ -1,372 +0,0 @@ -// Replace placeholder with test implementations -package mpc - -import ( - "crypto/sha256" - "testing" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mocknet" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// createThresholdAccessStructure builds an in-memory AccessStructure tree -// representing a simple "threshold-of-n" policy and returns a high-level -// Go wrapper that can be passed to the MPC APIs. -func createThresholdAccessStructure(pnames []string, threshold int, cv curve.Curve) *AccessStructure { - // Build leaf nodes for each party. - kids := make([]*AccessNode, len(pnames)) - for i, n := range pnames { - kids[i] = Leaf(n) - } - - // Root is a THRESHOLD node with K=threshold. - root := Threshold("", threshold, kids...) - - return &AccessStructure{Root: root, Curve: cv} -} - -// TestECDSAMPCThresholdDKGWithMockNet exercises the high-level -// ECDSAMPCThresholdDKG wrapper across multiple parties using the in-memory mock -// network. It validates that each participant receives a non-nil key share and -// that basic invariants (party name, curve code) hold. -func TestECDSAMPCThresholdDKGWithMockNet(t *testing.T) { - const ( - nParties = 5 - threshold = 3 - ) - - // Prepare curve instance. - cv, err := curve.NewSecp256k1() - require.NoError(t, err) - defer cv.Free() - - // Prepare mock network primitives. - pnames := mocknet.GeneratePartyNames(nParties) - messengers := mocknet.NewMockNetwork(nParties) - - // Channel to gather per-party results. - type result struct { - idx int - resp *ECDSAMPCThresholdDKGResponse - err error - } - resCh := make(chan result, nParties) - - // Launch one goroutine per party. - for i := 0; i < nParties; i++ { - go func(idx int) { - // Build JobMP wrapper for this party. - job, err := NewJobMP(messengers[idx], nParties, idx, pnames) - if err != nil { - resCh <- result{idx: idx, resp: nil, err: err} - return - } - defer job.Free() - - // Each party creates its own access-structure object. - ac := createThresholdAccessStructure(pnames, threshold, cv) - - req := &ECDSAMPCThresholdDKGRequest{ - Curve: cv, - SessionID: nil, // let native generate SID - AccessStructure: ac, - } - - r, e := ECDSAMPCThresholdDKG(job, req) - resCh <- result{idx: idx, resp: r, err: e} - }(i) - } - - // Collect results. - resp := make([]*ECDSAMPCThresholdDKGResponse, nParties) - for i := 0; i < nParties; i++ { - out := <-resCh - require.NoError(t, out.err, "party %d threshold DKG should succeed", out.idx) - require.NotNil(t, out.resp, "party %d response must not be nil", out.idx) - resp[out.idx] = out.resp - } - - // Basic validations. - expectedCurveCode := curve.Code(cv) - - for i, r := range resp { - // Key share must be non-zero. - assert.NotEqual(t, 0, r.KeyShare, "party %d key share should not be zero", i) - - // Party name matches. - pname, err := r.KeyShare.PartyName() - require.NoError(t, err) - assert.Equal(t, pnames[i], pname, "party %d pname mismatch", i) - - // Curve matches. - c, err := r.KeyShare.Curve() - require.NoError(t, err) - actual := curve.Code(c) - assert.Equal(t, expectedCurveCode, actual) - c.Free() - } - - // Convert a quorum of parties to additive shares under the same threshold policy - root := Threshold("", threshold, func() []*AccessNode { - kids := make([]*AccessNode, len(pnames)) - for i, n := range pnames { - kids[i] = Leaf(n) - } - return kids - }()...) - asQ := &AccessStructure{Root: root, Curve: cv} - quorumNames := pnames[:threshold] - - additive := make([]ECDSAMPCKey, threshold) - for i := 0; i < threshold; i++ { - as, err := resp[i].KeyShare.ToAdditiveShare(asQ, quorumNames) - require.NoError(t, err, "party %d additive share conversion failed", i) - additive[i] = as - } - - // Run an ECDSA MPC signing with only the quorum parties, then verify the DER signature - message := []byte("ecdsa threshold dkg signing") - digest := sha256.Sum256(message) - sigReceiver := 0 - - signMessengers := mocknet.NewMockNetwork(threshold) - type signResult struct { - idx int - sig []byte - err error - } - signCh := make(chan signResult, threshold) - - for i := 0; i < threshold; i++ { - go func(idx int) { - job, err := NewJobMP(signMessengers[idx], threshold, idx, quorumNames) - if err != nil { - signCh <- signResult{idx: idx, err: err} - return - } - defer job.Free() - - req := &ECDSAMPCSignRequest{KeyShare: additive[idx], Message: digest[:], SignatureReceiver: sigReceiver} - r, e := ECDSAMPCSign(job, req) - if e != nil { - signCh <- signResult{idx: idx, err: e} - return - } - signCh <- signResult{idx: idx, sig: r.Signature, err: nil} - }(i) - } - - sigs := make([][]byte, threshold) - for i := 0; i < threshold; i++ { - out := <-signCh - require.NoError(t, out.err, "party %d signing should succeed", out.idx) - sigs[out.idx] = out.sig - } - - // Only the receiver should have the signature - require.NotEmpty(t, sigs[sigReceiver]) - for i := 0; i < threshold; i++ { - if i != sigReceiver { - assert.Empty(t, sigs[i]) - } - } - - // Verify signature against Q - Q, err := resp[0].KeyShare.Q() - require.NoError(t, err) - // Build SEC1 uncompressed pubkey - pad32 := func(b []byte) []byte { - p := make([]byte, 32) - if len(b) >= 32 { - copy(p, b[len(b)-32:]) - return p - } - copy(p[32-len(b):], b) - return p - } - x := pad32(Q.GetX()) - y := pad32(Q.GetY()) - pubOct := make([]byte, 1+32+32) - pubOct[0] = 0x04 - copy(pubOct[1:33], x) - copy(pubOct[33:], y) - Q.Free() - require.NoError(t, cgobinding.ECCVerifyDER(curve.Code(cv), pubOct, digest[:], sigs[sigReceiver])) -} - -// TestECDSAMPC_ToAdditiveShare verifies that a subset of parties satisfying the -// quorum threshold can convert their threshold-DKG key share into an additive -// secret share without error. -func TestECDSAMPC_ToAdditiveShare(t *testing.T) { - const ( - nParties = 4 - threshold = 2 - ) - - cv, err := curve.NewSecp256k1() - require.NoError(t, err) - defer cv.Free() - - pnames := mocknet.GeneratePartyNames(nParties) - messengers := mocknet.NewMockNetwork(nParties) - - // First run threshold DKG to obtain key shares. - type dkgResult struct { - idx int - share ECDSAMPCKey - err error - } - dkgCh := make(chan dkgResult, nParties) - - for i := 0; i < nParties; i++ { - go func(idx int) { - job, err := NewJobMP(messengers[idx], nParties, idx, pnames) - if err != nil { - dkgCh <- dkgResult{idx: idx, err: err} - return - } - defer job.Free() - - ac := createThresholdAccessStructure(pnames, threshold, cv) - - req := &ECDSAMPCThresholdDKGRequest{Curve: cv, AccessStructure: ac} - resp, err := ECDSAMPCThresholdDKG(job, req) - if err != nil { - dkgCh <- dkgResult{idx: idx, err: err} - return - } - dkgCh <- dkgResult{idx: idx, share: resp.KeyShare, err: nil} - }(i) - } - - shares := make([]ECDSAMPCKey, nParties) - for i := 0; i < nParties; i++ { - out := <-dkgCh - require.NoError(t, out.err) - shares[out.idx] = out.share - } - - // Prepare quorum party names – pick the first `threshold` parties. - quorumPNames := pnames[:threshold] - - // Build an AccessStructure representing the same threshold policy. - root := Threshold("", threshold, func() []*AccessNode { - kids := make([]*AccessNode, len(pnames)) - for i, n := range pnames { - kids[i] = Leaf(n) - } - return kids - }()...) - - asQ := &AccessStructure{Root: root, Curve: cv} - - // Convert shares for the quorum parties and ensure success. - for i := 0; i < threshold; i++ { - additive, err := shares[i].ToAdditiveShare(asQ, quorumPNames) - require.NoError(t, err, "party %d additive share conversion failed", i) - assert.NotEqual(t, 0, additive, "party %d additive share should not be zero", i) - // Clean up native resources to avoid leaks. - ref := additive.cgobindingRef() - (&ref).Free() - } - - // Non-quorum parties can also convert to additive shares; ensure no error - for i := threshold; i < nParties; i++ { - additive, err := shares[i].ToAdditiveShare(asQ, quorumPNames) - require.NoError(t, err, "non-quorum party %d additive conversion failed", i) - assert.NotEqual(t, 0, additive, "non-quorum party %d additive share should not be zero", i) - ref := additive.cgobindingRef() - (&ref).Free() - } -} - -// TestECDSAMPCThresholdDKG_SigningFailsWithTooFewParties ensures that attempting -// to sign with fewer than 3 parties (and fewer than the 3-of-5 threshold) fails. -func TestECDSAMPCThresholdDKG_SigningFailsWithTooFewParties(t *testing.T) { - const ( - nParties = 5 - threshold = 3 - ) - - cv, err := curve.NewSecp256k1() - require.NoError(t, err) - defer cv.Free() - - pnames := mocknet.GeneratePartyNames(nParties) - messengers := mocknet.NewMockNetwork(nParties) - - // Run threshold DKG across all parties - type dkgRes struct { - idx int - resp *ECDSAMPCThresholdDKGResponse - err error - } - dkgCh := make(chan dkgRes, nParties) - for i := 0; i < nParties; i++ { - go func(idx int) { - job, err := NewJobMP(messengers[idx], nParties, idx, pnames) - if err != nil { - dkgCh <- dkgRes{idx: idx, err: err} - return - } - defer job.Free() - ac := createThresholdAccessStructure(pnames, threshold, cv) - r, e := ECDSAMPCThresholdDKG(job, &ECDSAMPCThresholdDKGRequest{Curve: cv, AccessStructure: ac}) - dkgCh <- dkgRes{idx: idx, resp: r, err: e} - }(i) - } - resp := make([]*ECDSAMPCThresholdDKGResponse, nParties) - for i := 0; i < nParties; i++ { - out := <-dkgCh - require.NoError(t, out.err) - resp[out.idx] = out.resp - } - - // Convert to additive shares for a valid 3-of-5 quorum - root := Threshold("", threshold, func() []*AccessNode { - kids := make([]*AccessNode, len(pnames)) - for i, n := range pnames { - kids[i] = Leaf(n) - } - return kids - }()...) - asQ := &AccessStructure{Root: root, Curve: cv} - quorumNames := pnames[:threshold] - additive := make([]ECDSAMPCKey, threshold) - for i := 0; i < threshold; i++ { - as, err := resp[i].KeyShare.ToAdditiveShare(asQ, quorumNames) - require.NoError(t, err) - additive[i] = as - } - - // Attempt to sign with only two parties -> should fail - signMessengers := mocknet.NewMockNetwork(2) - signPNames := quorumNames[:2] - type signResult struct { - idx int - err error - } - signCh := make(chan signResult, 2) - digest := sha256.Sum256([]byte("ecdsa threshold negative test")) - sigReceiver := 0 - for i := 0; i < 2; i++ { - go func(idx int) { - job, err := NewJobMP(signMessengers[idx], 2, idx, signPNames) - if err != nil { - signCh <- signResult{idx: idx, err: err} - return - } - defer job.Free() - _, e := ECDSAMPCSign(job, &ECDSAMPCSignRequest{KeyShare: additive[idx], Message: digest[:], SignatureReceiver: sigReceiver}) - signCh <- signResult{idx: idx, err: e} - }(i) - } - for i := 0; i < 2; i++ { - out := <-signCh - require.Error(t, out.err, "party %d signing should fail with too few parties", out.idx) - } -} diff --git a/demos-go/cb-mpc-go/api/mpc/ecdsa_mp_validation_test.go b/demos-go/cb-mpc-go/api/mpc/ecdsa_mp_validation_test.go deleted file mode 100644 index 6efb01ab..00000000 --- a/demos-go/cb-mpc-go/api/mpc/ecdsa_mp_validation_test.go +++ /dev/null @@ -1,368 +0,0 @@ -package mpc - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mocknet" -) - -// ------------------------------------------------------------ -// Helper types & functions -// ------------------------------------------------------------ - -type partyResult[T any] struct { - idx int - val T - err error -} - -// keyGenWithMockNet spins up `n` in-memory parties and runs ECDSAMPCKeyGen -// via the public API, returning the per-party responses. -func keyGenWithMockNet(n int, cv curve.Curve) ([]*ECDSAMPCKeyGenResponse, error) { - pnames := mocknet.GeneratePartyNames(n) - messengers := mocknet.NewMockNetwork(n) - - respCh := make(chan partyResult[*ECDSAMPCKeyGenResponse], n) - - for i := 0; i < n; i++ { - go func(idx int) { - j, err := NewJobMP(messengers[idx], n, idx, pnames) - if err != nil { - respCh <- partyResult[*ECDSAMPCKeyGenResponse]{idx: idx, val: nil, err: err} - return - } - defer j.Free() - - r, e := ECDSAMPCKeyGen(j, &ECDSAMPCKeyGenRequest{Curve: cv}) - respCh <- partyResult[*ECDSAMPCKeyGenResponse]{idx: idx, val: r, err: e} - }(i) - } - - res := make([]*ECDSAMPCKeyGenResponse, n) - for i := 0; i < n; i++ { - out := <-respCh - if out.err != nil { - return nil, fmt.Errorf("party %d keygen failed: %v", out.idx, out.err) - } - res[out.idx] = out.val - } - return res, nil -} - -// refreshWithMockNet performs the refresh protocol on the provided key shares. -func refreshWithMockNet(orig []*ECDSAMPCKeyGenResponse, sessionID []byte) ([]*ECDSAMPCRefreshResponse, error) { - n := len(orig) - pnames := mocknet.GeneratePartyNames(n) - messengers := mocknet.NewMockNetwork(n) - - respCh := make(chan partyResult[*ECDSAMPCRefreshResponse], n) - - for i := 0; i < n; i++ { - go func(idx int) { - j, err := NewJobMP(messengers[idx], n, idx, pnames) - if err != nil { - respCh <- partyResult[*ECDSAMPCRefreshResponse]{idx: idx, val: nil, err: err} - return - } - defer j.Free() - - req := &ECDSAMPCRefreshRequest{KeyShare: orig[idx].KeyShare, SessionID: sessionID} - r, e := ECDSAMPCRefresh(j, req) - respCh <- partyResult[*ECDSAMPCRefreshResponse]{idx: idx, val: r, err: e} - }(i) - } - - res := make([]*ECDSAMPCRefreshResponse, n) - for i := 0; i < n; i++ { - out := <-respCh - if out.err != nil { - return nil, fmt.Errorf("party %d refresh failed: %v", out.idx, out.err) - } - res[out.idx] = out.val - } - return res, nil -} - -// signWithMockNet executes a signing round over the provided key shares. -// Only party `receiver` will obtain the resulting signature. -func signWithMockNet(keyShares []ECDSAMPCKey, msg []byte, receiver int) ([]*ECDSAMPCSignResponse, error) { - n := len(keyShares) - pnames := mocknet.GeneratePartyNames(n) - messengers := mocknet.NewMockNetwork(n) - - respCh := make(chan partyResult[*ECDSAMPCSignResponse], n) - - for i := 0; i < n; i++ { - go func(idx int) { - j, err := NewJobMP(messengers[idx], n, idx, pnames) - if err != nil { - respCh <- partyResult[*ECDSAMPCSignResponse]{idx: idx, val: nil, err: err} - return - } - defer j.Free() - - req := &ECDSAMPCSignRequest{ - KeyShare: keyShares[idx], - Message: msg, - SignatureReceiver: receiver, - } - r, e := ECDSAMPCSign(j, req) - respCh <- partyResult[*ECDSAMPCSignResponse]{idx: idx, val: r, err: e} - }(i) - } - - res := make([]*ECDSAMPCSignResponse, n) - for i := 0; i < n; i++ { - out := <-respCh - if out.err != nil { - return nil, fmt.Errorf("party %d sign failed: %v", out.idx, out.err) - } - res[out.idx] = out.val - } - return res, nil -} - -// ------------------------------------------------------------ -// Tests -// ------------------------------------------------------------ - -func TestECDSAMPC_DKG_Validation(t *testing.T) { - partyCounts := []int{3, 4, 5} - - for _, n := range partyCounts { - t.Run(fmt.Sprintf("dkg_%d_parties", n), func(t *testing.T) { - cv, err := curve.NewSecp256k1() - require.NoError(t, err) - defer cv.Free() - - keyGenRes, err := keyGenWithMockNet(n, cv) - require.NoError(t, err) - require.Len(t, keyGenRes, n) - - // --- collect shared data from first party --- - firstKey := keyGenRes[0].KeyShare - - Qglobal, err := firstKey.Q() - require.NoError(t, err) - defer Qglobal.Free() - - QiMap, err := firstKey.Qis() - require.NoError(t, err) - - expectedCode := curve.Code(cv) - - // --- per-party validations --- - pnames := mocknet.GeneratePartyNames(n) - for i, resp := range keyGenRes { - ks := resp.KeyShare - - // party name correct - pname, err := ks.PartyName() - require.NoError(t, err) - assert.Equal(t, pnames[i], pname, "party %d name mismatch", i) - - // curve matches - c, err := ks.Curve() - require.NoError(t, err) - actualCode := curve.Code(c) - assert.Equal(t, expectedCode, actualCode) - c.Free() - - // Q identical - Q, err := ks.Q() - require.NoError(t, err) - assert.True(t, Q.Equals(Qglobal), "party %d Q differs", i) - Q.Free() - - // Qi maps identical - QiMapOther, err := ks.Qis() - require.NoError(t, err) - require.Equal(t, len(QiMap), len(QiMapOther)) - for k, v := range QiMap { - other := QiMapOther[k] - assert.True(t, v.Equals(other), "party %d Qi for %s differs", i, k) - other.Free() - } - - // x_share consistency - x, err := ks.XShare() - require.NoError(t, err) - QiParty := QiMap[pname] - expectedQi, err := cv.MultiplyGenerator(x) - require.NoError(t, err) - assert.True(t, expectedQi.Equals(QiParty), "party %d Qi != x_i*G", i) - expectedQi.Free() - } - - // --- Sum(Qi) == Q --- - var sumPt *curve.Point - // We need a stable iteration order; use pnames slice - for i, name := range pnames { - pt := QiMap[name] - if i == 0 { - sumPt = pt // borrow reference; do not free here - continue - } - tmp := sumPt.Add(pt) - if i != 0 { - // Free previous accumulator if it was not one of the Qi map entries - if sumPt != pt { // avoid double-free - sumPt.Free() - } - } - sumPt = tmp - } - assert.True(t, sumPt.Equals(Qglobal), "sum(Qi) != Q") - - // Free accumulator if it is not one of original Qi values - accIsQi := false - for _, pt := range QiMap { - if pt == sumPt { - accIsQi = true - break - } - } - if !accIsQi { - sumPt.Free() - } - }) - } -} - -func TestECDSAMPC_Refresh(t *testing.T) { - const nParties = 3 - cv, err := curve.NewSecp256k1() - require.NoError(t, err) - defer cv.Free() - - // --- initial keygen --- - keyGenRes, err := keyGenWithMockNet(nParties, cv) - require.NoError(t, err) - - // capture original x_shares & Q - origX := make([]*curve.Scalar, nParties) - for i, ks := range keyGenRes { - x, err := ks.KeyShare.XShare() - require.NoError(t, err) - origX[i] = x - } - Qorig, err := keyGenRes[0].KeyShare.Q() - require.NoError(t, err) - defer Qorig.Free() - - // --- refresh --- - refreshRes, err := refreshWithMockNet(keyGenRes, nil) - require.NoError(t, err) - - // validations - for i := 0; i < nParties; i++ { - newShare := refreshRes[i].NewKeyShare - - // x_share changed - newX, err := newShare.XShare() - require.NoError(t, err) - assert.False(t, newX.Equal(origX[i]), "party %d x_share should change after refresh", i) - - // Q unchanged - Qnew, err := newShare.Q() - require.NoError(t, err) - assert.True(t, Qnew.Equals(Qorig), "party %d Q changed after refresh", i) - Qnew.Free() - } -} - -func TestECDSAMPC_Sign_Refresh_Sign(t *testing.T) { - const nParties = 3 - cv, err := curve.NewSecp256k1() - require.NoError(t, err) - defer cv.Free() - - // Key generation - keyGenRes, err := keyGenWithMockNet(nParties, cv) - require.NoError(t, err) - - keyShares := make([]ECDSAMPCKey, nParties) - for i, r := range keyGenRes { - keyShares[i] = r.KeyShare - } - - msg1 := []byte("first message") - - // Sign before refresh - sigRes1, err := signWithMockNet(keyShares, msg1, 0) - require.NoError(t, err) - - // Only receiver (0) gets signature - assert.Greater(t, len(sigRes1[0].Signature), 0) - for i := 1; i < nParties; i++ { - assert.Equal(t, 0, len(sigRes1[i].Signature)) - } - - // Refresh - refreshRes, err := refreshWithMockNet(keyGenRes, nil) - require.NoError(t, err) - - newShares := make([]ECDSAMPCKey, nParties) - for i, r := range refreshRes { - newShares[i] = r.NewKeyShare - } - - // Sign after refresh (same message for simplicity) - sigRes2, err := signWithMockNet(newShares, msg1, 0) - require.NoError(t, err) - assert.Greater(t, len(sigRes2[0].Signature), 0) - - // Signatures should differ - assert.NotEqual(t, sigRes1[0].Signature, sigRes2[0].Signature) -} - -func TestECDSAMPC_SerializeDeserialize(t *testing.T) { - const nParties = 3 - - cv, err := curve.NewSecp256k1() - require.NoError(t, err) - defer cv.Free() - - keyGenRes, err := keyGenWithMockNet(nParties, cv) - require.NoError(t, err) - - deserShares := make([]ECDSAMPCKey, nParties) - for i, res := range keyGenRes { - ser, err := res.KeyShare.MarshalBinary() - require.NoError(t, err) - assert.Greater(t, len(ser), 0, "serialized data should not be empty") - - var newKey ECDSAMPCKey - err = newKey.UnmarshalBinary(ser) - require.NoError(t, err) - - // Party name should stay identical - origPName, err := res.KeyShare.PartyName() - require.NoError(t, err) - newPName, err := newKey.PartyName() - require.NoError(t, err) - assert.Equal(t, origPName, newPName, "party %d name mismatch after serde", i) - - // Public key Q must match - origQ, err := res.KeyShare.Q() - require.NoError(t, err) - newQ, err := newKey.Q() - require.NoError(t, err) - assert.True(t, origQ.Equals(newQ), "party %d Q mismatch after serde", i) - origQ.Free() - newQ.Free() - - deserShares[i] = newKey - } - - message := []byte("serde-test-message") - sigRes, err := signWithMockNet(deserShares, message, 0) - require.NoError(t, err) - - assert.Greater(t, len(sigRes[0].Signature), 0, "receiver should get a signature") -} diff --git a/demos-go/cb-mpc-go/api/mpc/eddsa_mp.go b/demos-go/cb-mpc-go/api/mpc/eddsa_mp.go deleted file mode 100644 index ec047964..00000000 --- a/demos-go/cb-mpc-go/api/mpc/eddsa_mp.go +++ /dev/null @@ -1,305 +0,0 @@ -package mpc - -import ( - "bytes" - "encoding" - "encoding/gob" - "fmt" - "runtime" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// Compile-time assertions to ensure EDDSAMPCKey implements the binary marshaling -// interfaces. -var _ encoding.BinaryMarshaler = (*EDDSAMPCKey)(nil) -var _ encoding.BinaryUnmarshaler = (*EDDSAMPCKey)(nil) - -// ============================================================================ -// Type definitions -// ============================================================================ - -// EDDSAMPCKey represents an opaque N-party EdDSA key share owned by the current -// party. Internally it wraps cgobinding.Mpc_eckey_mp_ref – the underlying representation -// is shared with other N-party key types. -// -// NOTE: The zero value is invalid. -// -// All methods are intentionally kept analogous to the existing multi-party key APIs to provide a -// consistent developer experience. -// -// ---------------------------------------------------------------------------- - -type EDDSAMPCKey cgobinding.Mpc_eckey_mp_ref - -func newEDDSAMPCKey(ref cgobinding.Mpc_eckey_mp_ref) EDDSAMPCKey { - return EDDSAMPCKey(ref) -} - -// Free releases the underlying native resources. -func (k *EDDSAMPCKey) Free() { - if k == nil { - return - } - ref := cgobinding.Mpc_eckey_mp_ref(*k) - (&ref).Free() - *k = EDDSAMPCKey(cgobinding.Mpc_eckey_mp_ref{}) - runtime.SetFinalizer(k, nil) -} - -func (k EDDSAMPCKey) cgobindingRef() cgobinding.Mpc_eckey_mp_ref { - return cgobinding.Mpc_eckey_mp_ref(k) -} - -// MarshalBinary serialises the key share into a portable wire format. -func (k EDDSAMPCKey) MarshalBinary() ([]byte, error) { - parts, err := cgobinding.SerializeKeyShare(k.cgobindingRef()) - if err != nil { - return nil, err - } - var buf bytes.Buffer - if err := gob.NewEncoder(&buf).Encode(parts); err != nil { - return nil, err - } - return buf.Bytes(), nil -} - -// UnmarshalBinary restores a key share previously produced by MarshalBinary. -func (k *EDDSAMPCKey) UnmarshalBinary(data []byte) error { - var parts [][]byte - if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&parts); err != nil { - return err - } - ref, err := cgobinding.DeserializeKeyShare(parts) - if err != nil { - return err - } - *k = newEDDSAMPCKey(ref) - return nil -} - -// Accessors --------------------------------------------------------------------------------- - -func (k EDDSAMPCKey) PartyName() (string, error) { - return cgobinding.MPC_mpc_eckey_mp_get_party_name(k.cgobindingRef()) -} - -func (k EDDSAMPCKey) XShare() (*curve.Scalar, error) { - bytes, err := cgobinding.MPC_mpc_eckey_mp_get_x_share(k.cgobindingRef()) - if err != nil { - return nil, err - } - return &curve.Scalar{Bytes: bytes}, nil -} - -func (k EDDSAMPCKey) Q() (*curve.Point, error) { - bytes, err := cgobinding.KeyShareQBytes(k.cgobindingRef()) - if err != nil { - return nil, err - } - return curve.NewPointFromBytes(bytes) -} - -func (k EDDSAMPCKey) Curve() (curve.Curve, error) { - code, err := cgobinding.KeyShareCurveCode(k.cgobindingRef()) - if err != nil { - return nil, err - } - return curve.NewFromCode(code) -} - -func (k EDDSAMPCKey) Qis() (map[string]*curve.Point, error) { - names, points, err := cgobinding.MPC_mpc_eckey_mp_Qis(k.cgobindingRef()) - if err != nil { - return nil, err - } - if len(names) != len(points) { - return nil, fmt.Errorf("inconsistent Qis arrays: %d names vs %d points", len(names), len(points)) - } - out := make(map[string]*curve.Point, len(names)) - for i, nameBytes := range names { - pt, err := curve.NewPointFromBytes(points[i]) - if err != nil { - return nil, fmt.Errorf("failed to decode Qi for party %s: %v", string(nameBytes), err) - } - out[string(nameBytes)] = pt - } - return out, nil -} - -// ============================================================================ -// Request / Response structs -// ============================================================================ - -type EDDSAMPCKeyGenRequest struct { - Curve curve.Curve -} - -type EDDSAMPCKeyGenResponse struct { - KeyShare EDDSAMPCKey -} - -type EDDSAMPCSignRequest struct { - KeyShare EDDSAMPCKey - Message []byte - SignatureReceiver int -} - -type EDDSAMPCSignResponse struct { - Signature []byte -} - -type EDDSAMPCRefreshRequest struct { - KeyShare EDDSAMPCKey - SessionID []byte -} - -type EDDSAMPCRefreshResponse struct { - NewKeyShare EDDSAMPCKey -} - -// ============================================================================ -// Core API functions -// ============================================================================ - -// EDDSAMPCKeyGen performs algorithm-agnostic distributed key generation. -func EDDSAMPCKeyGen(jobmp *JobMP, req *EDDSAMPCKeyGenRequest) (*EDDSAMPCKeyGenResponse, error) { - if jobmp == nil { - return nil, fmt.Errorf("job must be provided") - } - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if req.Curve == nil { - return nil, fmt.Errorf("curve must be provided") - } - if jobmp.NParties() < 3 { - return nil, fmt.Errorf("n-party EdDSA requires at least 3 parties") - } - - key, err := cgobinding.KeyShareDKGCode(jobmp.cgo(), curve.Code(req.Curve)) - if err != nil { - return nil, fmt.Errorf("EdDSA N-party key generation failed: %v", err) - } - return &EDDSAMPCKeyGenResponse{KeyShare: newEDDSAMPCKey(key)}, nil -} - -// EDDSAMPCSign performs N-party EdDSA signing. -func EDDSAMPCSign(jobmp *JobMP, req *EDDSAMPCSignRequest) (*EDDSAMPCSignResponse, error) { - if jobmp == nil { - return nil, fmt.Errorf("job must be provided") - } - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if jobmp.NParties() < 3 { - return nil, fmt.Errorf("n-party signing requires at least 3 parties") - } - if len(req.Message) == 0 { - return nil, fmt.Errorf("message cannot be empty") - } - - sig, err := cgobinding.MPC_eddsampc_sign(jobmp.cgo(), req.KeyShare.cgobindingRef(), req.Message, req.SignatureReceiver) - if err != nil { - return nil, fmt.Errorf("EdDSA N-party signing failed: %v", err) - } - - roleIdx := jobmp.GetPartyIndex() - var sigBytes []byte - if roleIdx == req.SignatureReceiver { - sigBytes = sig - } - return &EDDSAMPCSignResponse{Signature: sigBytes}, nil -} - -// EDDSAMPCRefresh re-shares secret without changing public key. -func EDDSAMPCRefresh(jobmp *JobMP, req *EDDSAMPCRefreshRequest) (*EDDSAMPCRefreshResponse, error) { - if jobmp == nil { - return nil, fmt.Errorf("job must be provided") - } - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if jobmp.NParties() < 3 { - return nil, fmt.Errorf("n-party refresh requires at least 3 parties") - } - sid := req.SessionID - newKey, err := cgobinding.KeyShareRefresh(jobmp.cgo(), sid, req.KeyShare.cgobindingRef()) - if err != nil { - return nil, fmt.Errorf("EdDSA N-party refresh failed: %v", err) - } - return &EDDSAMPCRefreshResponse{NewKeyShare: newEDDSAMPCKey(newKey)}, nil -} - -// The threshold-DKG and ToAdditiveShare helpers reuse the same low-level bindings -// to avoid code duplication. - -// EDDSAMPCThresholdDKGRequest holds the parameters for running the threshold-DKG -// protocol when creating an N-party EdDSA key. -type EDDSAMPCThresholdDKGRequest struct { - Curve curve.Curve // Elliptic curve to use - SessionID []byte // Optional caller-supplied session identifier - AccessStructure *AccessStructure // Quorum access-structure description - QuorumRIDs []int // (Optional) indices of parties that will form the quorum; defaults to all parties if nil/empty -} - -// EDDSAMPCThresholdDKGResponse contains the key share produced for the calling -// party by the threshold-DKG protocol. -type EDDSAMPCThresholdDKGResponse struct { - KeyShare EDDSAMPCKey -} - -// EDDSAMPCThresholdDKG executes the threshold DKG protocol for EdDSA and -// returns the caller's key share. -func EDDSAMPCThresholdDKG(jobmp *JobMP, req *EDDSAMPCThresholdDKGRequest) (*EDDSAMPCThresholdDKGResponse, error) { - if jobmp == nil { - return nil, fmt.Errorf("job must be provided") - } - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if req.Curve == nil { - return nil, fmt.Errorf("curve must be provided") - } - - sid := req.SessionID - - if req.AccessStructure == nil { - return nil, fmt.Errorf("access structure must be provided") - } - - acPtr := req.AccessStructure.toCryptoAC() - - roleIndices := req.QuorumRIDs - if len(roleIndices) == 0 { - roleIndices = make([]int, jobmp.NParties()) - for i := 0; i < jobmp.NParties(); i++ { - roleIndices[i] = i - } - } - - keyShareRef, err := cgobinding.ThresholdDKGCode(jobmp.cgo(), curve.Code(req.Curve), sid, acPtr, roleIndices) - if err != nil { - return nil, fmt.Errorf("EdDSA threshold DKG failed: %v", err) - } - - return &EDDSAMPCThresholdDKGResponse{KeyShare: newEDDSAMPCKey(keyShareRef)}, nil -} - -func (k EDDSAMPCKey) ToAdditiveShare(ac *AccessStructure, quorumPartyNames []string) (EDDSAMPCKey, error) { - if ac == nil { - return EDDSAMPCKey{}, fmt.Errorf("access structure must be provided") - } - if len(quorumPartyNames) == 0 { - return EDDSAMPCKey{}, fmt.Errorf("quorumPartyNames cannot be empty") - } - - acPtr := ac.toCryptoAC() - keyRef := cgobinding.Mpc_eckey_mp_ref(k) - addRef, err := (&keyRef).ToAdditiveShare(acPtr, quorumPartyNames) - if err != nil { - return EDDSAMPCKey{}, err - } - return newEDDSAMPCKey(addRef), nil -} diff --git a/demos-go/cb-mpc-go/api/mpc/eddsa_mp_test.go b/demos-go/cb-mpc-go/api/mpc/eddsa_mp_test.go deleted file mode 100644 index 02d00a20..00000000 --- a/demos-go/cb-mpc-go/api/mpc/eddsa_mp_test.go +++ /dev/null @@ -1,215 +0,0 @@ -package mpc - -import ( - "bytes" - "crypto/ed25519" - "fmt" - "testing" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mocknet" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// EDDSAMPCWithMockNet executes the full EdDSA N-party workflow using the in-memory -// mock network. It is intentionally lightweight compared to the exhaustive -// ECDSA test-suite – its goal is to ensure the basic API surface compiles and -// the protocol can run end-to-end. -func EDDSAMPCWithMockNet(nParties int, cv curve.Curve, message []byte) ([]*EDDSAMPCKeyGenResponse, []*EDDSAMPCSignResponse, error) { - if nParties < 3 { - return nil, nil, fmt.Errorf("EdDSA N-party requires at least 3 parties") - } - if len(message) == 0 { - return nil, nil, fmt.Errorf("message cannot be empty") - } - - runner := mocknet.NewMPCRunner(mocknet.GeneratePartyNames(nParties)...) - - // ---------------- KeyGen ---------------- - keyGenInputs := make([]*mocknet.MPCIO, nParties) - for i := 0; i < nParties; i++ { - keyGenInputs[i] = &mocknet.MPCIO{Opaque: cv} - } - keyGenOutputs, err := runner.MPCRunMP(func(job cgobinding.JobMP, input *mocknet.MPCIO) (*mocknet.MPCIO, error) { - curveObj := input.Opaque.(curve.Curve) - apiJob := &JobMP{inner: job} - resp, err := EDDSAMPCKeyGen(apiJob, &EDDSAMPCKeyGenRequest{Curve: curveObj}) - if err != nil { - return nil, err - } - return &mocknet.MPCIO{Opaque: resp.KeyShare}, nil - }, keyGenInputs) - if err != nil { - return nil, nil, err - } - - keyShares := make([]EDDSAMPCKey, nParties) - keyGenResponses := make([]*EDDSAMPCKeyGenResponse, nParties) - for i := 0; i < nParties; i++ { - keyShares[i] = keyGenOutputs[i].Opaque.(EDDSAMPCKey) - keyGenResponses[i] = &EDDSAMPCKeyGenResponse{KeyShare: keyShares[i]} - } - - // ---------------- Sign ---------------- - signInputs := make([]*mocknet.MPCIO, nParties) - for i := 0; i < nParties; i++ { - signInputs[i] = &mocknet.MPCIO{Opaque: struct { - Key EDDSAMPCKey - Msg []byte - }{Key: keyShares[i], Msg: message}} - } - - const sigReceiver = 0 - - signOutputs, err := runner.MPCRunMP(func(job cgobinding.JobMP, input *mocknet.MPCIO) (*mocknet.MPCIO, error) { - data := input.Opaque.(struct { - Key EDDSAMPCKey - Msg []byte - }) - apiJob := &JobMP{inner: job} - resp, err := EDDSAMPCSign(apiJob, &EDDSAMPCSignRequest{ - KeyShare: data.Key, - Message: data.Msg, - SignatureReceiver: sigReceiver, - }) - if err != nil { - return nil, err - } - return &mocknet.MPCIO{Opaque: resp.Signature}, nil - }, signInputs) - if err != nil { - return nil, nil, err - } - - signResponses := make([]*EDDSAMPCSignResponse, nParties) - for i := 0; i < nParties; i++ { - var sigBytes []byte - if i == sigReceiver { - sigBytes = signOutputs[i].Opaque.([]byte) - } - signResponses[i] = &EDDSAMPCSignResponse{Signature: sigBytes} - } - - return keyGenResponses, signResponses, nil -} - -func TestEDDSAMPC_EndToEnd(t *testing.T) { - ed, err := curve.NewEd25519() - if err != nil { - t.Fatalf("failed to init curve: %v", err) - } - - const nParties = 3 - message := []byte("hello eddsa") - - keyRes, signRes, err := EDDSAMPCWithMockNet(nParties, ed, message) - if err != nil { - t.Fatalf("protocol failed: %v", err) - } - - if len(keyRes) != nParties || len(signRes) != nParties { - t.Fatalf("unexpected response sizes") - } - - if len(signRes[0].Signature) == 0 { - t.Fatalf("signature receiver did not obtain signature") - } - // Non-receiver parties should have empty signatures - for i := 1; i < nParties; i++ { - if len(signRes[i].Signature) != 0 { - t.Fatalf("party %d unexpectedly received signature bytes", i) - } - } - - // Verify the signature against the aggregated public key Q using Ed25519 - qVerify, err := keyRes[0].KeyShare.Q() - if err != nil { - t.Fatalf("Q() failed for verification: %v", err) - } - pub, err := ed25519PublicKeyFromPoint(qVerify) - qVerify.Free() - if err != nil { - t.Fatalf("failed to derive Ed25519 public key: %v", err) - } - sig := signRes[0].Signature - if len(sig) != ed25519.SignatureSize { - t.Fatalf("unexpected Ed25519 signature length: got %d", len(sig)) - } - if !ed25519.Verify(ed25519.PublicKey(pub), message, sig) { - t.Fatalf("signature verification failed") - } - - // Validate EDDSAMPCKey.Curve() and EDDSAMPCKey.Q() accessors on resulting key shares - expectedCode := curve.Code(ed) - var q0 []byte - for i := 0; i < nParties; i++ { - c, err := keyRes[i].KeyShare.Curve() - if err != nil { - t.Fatalf("Curve() failed for party %d: %v", i, err) - } - if got := curve.Code(c); got != expectedCode { - t.Fatalf("Curve() returned unexpected code for party %d: got %d want %d", i, got, expectedCode) - } - - q, err := keyRes[i].KeyShare.Q() - if err != nil { - t.Fatalf("Q() failed for party %d: %v", i, err) - } - qBytes := q.Bytes() - if len(qBytes) == 0 { - t.Fatalf("Q() returned empty point for party %d", i) - } - if q.IsZero() { - t.Fatalf("Q() returned zero point for party %d", i) - } - if i == 0 { - q0 = qBytes - } else if !bytes.Equal(qBytes, q0) { - t.Fatalf("Q() mismatch across parties: party %d differs", i) - } - q.Free() - } - - // Negative checks: zero-value key should surface errors from Curve() and Q() - var zeroKey EDDSAMPCKey - if _, err := zeroKey.Curve(); err == nil { - t.Fatalf("expected Curve() to fail on zero-value key") - } - if _, err := zeroKey.Q(); err == nil { - t.Fatalf("expected Q() to fail on zero-value key") - } -} - -// ed25519PublicKeyFromPoint converts a curve point on Ed25519 to the 32-byte -// compressed public key as defined by RFC 8032: little-endian encoding of the -// y-coordinate with the most-significant bit set to the sign bit of x. -func ed25519PublicKeyFromPoint(q *curve.Point) ([]byte, error) { - if q == nil { - return nil, fmt.Errorf("nil point") - } - // Some curves may already serialize Ed25519 points in compressed 32-byte form. - if pb := q.Bytes(); len(pb) == ed25519.PublicKeySize { - return pb, nil - } - - x := q.GetX() - y := q.GetY() - if len(y) == 0 { - return nil, fmt.Errorf("empty y coordinate") - } - if len(y) > ed25519.PublicKeySize { - y = y[len(y)-ed25519.PublicKeySize:] - } - pub := make([]byte, ed25519.PublicKeySize) - copy(pub[ed25519.PublicKeySize-len(y):], y) - for i, j := 0, len(pub)-1; i < j; i, j = i+1, j-1 { - pub[i], pub[j] = pub[j], pub[i] - } - var xlsb byte - if len(x) > 0 { - xlsb = x[len(x)-1] & 1 - } - pub[31] &^= 0x80 - pub[31] |= (xlsb << 7) - return pub, nil -} diff --git a/demos-go/cb-mpc-go/api/mpc/eddsa_mp_threshold_test.go b/demos-go/cb-mpc-go/api/mpc/eddsa_mp_threshold_test.go deleted file mode 100644 index f2bc9ddb..00000000 --- a/demos-go/cb-mpc-go/api/mpc/eddsa_mp_threshold_test.go +++ /dev/null @@ -1,355 +0,0 @@ -package mpc - -import ( - "crypto/ed25519" - "testing" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mocknet" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestEDDSAMPCThresholdDKGWithMockNet exercises the high-level -// EDDSAMPCThresholdDKG wrapper across multiple parties using the in-memory mock -// network. It validates that the threshold DKG protocol works and that the -// resulting key shares can be used to sign a message. -func TestEDDSAMPCThresholdDKGWithMockNet(t *testing.T) { - const ( - nParties = 5 - threshold = 3 // 3-of-5 threshold policy - ) - - // Prepare curve instance. - cv, err := curve.NewEd25519() - require.NoError(t, err) - defer cv.Free() - - // Prepare mock network primitives. - pnames := mocknet.GeneratePartyNames(nParties) - messengers := mocknet.NewMockNetwork(nParties) - - // Channel to gather per-party results. - type result struct { - idx int - resp *EDDSAMPCThresholdDKGResponse - err error - } - resCh := make(chan result, nParties) - - // Launch one goroutine per party. - for i := 0; i < nParties; i++ { - go func(idx int) { - // Build JobMP wrapper for this party. - job, err := NewJobMP(messengers[idx], nParties, idx, pnames) - if err != nil { - resCh <- result{idx: idx, resp: nil, err: err} - return - } - defer job.Free() - - // Each party creates its own access-structure object. - ac := createThresholdAccessStructure(pnames, threshold, cv) - - req := &EDDSAMPCThresholdDKGRequest{ - Curve: cv, - SessionID: nil, // let native generate SID - AccessStructure: ac, - } - - r, e := EDDSAMPCThresholdDKG(job, req) - resCh <- result{idx: idx, resp: r, err: e} - }(i) - } - - // Collect results. - resp := make([]*EDDSAMPCThresholdDKGResponse, nParties) - for i := 0; i < nParties; i++ { - out := <-resCh - require.NoError(t, out.err, "party %d threshold DKG should succeed", out.idx) - require.NotNil(t, out.resp, "party %d response must not be nil", out.idx) - resp[out.idx] = out.resp - } - - // Basic validations. - expectedCurveCode := curve.Code(cv) - - for i, r := range resp { - // Key share must be non-zero. - assert.NotEqual(t, 0, r.KeyShare, "party %d key share should not be zero", i) - - // Party name matches. - pname, err := r.KeyShare.PartyName() - require.NoError(t, err) - assert.Equal(t, pnames[i], pname, "party %d pname mismatch", i) - - // Curve matches. - c, err := r.KeyShare.Curve() - require.NoError(t, err) - actual := curve.Code(c) - assert.Equal(t, expectedCurveCode, actual) - c.Free() - - // Note: For threshold-DKG keys, SUM(Qis) may not equal Q until converted to additive shares. - } - - // Convert to additive shares for a quorum of size `threshold` - root := Threshold("", threshold, func() []*AccessNode { - kids := make([]*AccessNode, len(pnames)) - for i, n := range pnames { - kids[i] = Leaf(n) - } - return kids - }()...) - acQ := &AccessStructure{Root: root, Curve: cv} - quorumNames := pnames[:threshold] - additive := make([]EDDSAMPCKey, threshold) - for i := 0; i < threshold; i++ { - as, err := resp[i].KeyShare.ToAdditiveShare(acQ, quorumNames) - require.NoError(t, err, "party %d additive share conversion failed", i) - additive[i] = as - } - - // Run an EdDSA MPC signing round with only the quorum parties using additive shares - message := []byte("eddsa threshold dkg signing") - sigReceiver := 0 - - // Fresh mock network for signing across quorum parties - signMessengers := mocknet.NewMockNetwork(threshold) - signPNames := quorumNames - - type signResult struct { - idx int - sig []byte - err error - } - signCh := make(chan signResult, threshold) - - for i := 0; i < threshold; i++ { - go func(idx int) { - job, err := NewJobMP(signMessengers[idx], threshold, idx, signPNames) - if err != nil { - signCh <- signResult{idx: idx, err: err} - return - } - defer job.Free() - - req := &EDDSAMPCSignRequest{ - KeyShare: additive[idx], - Message: message, - SignatureReceiver: sigReceiver, - } - r, e := EDDSAMPCSign(job, req) - if e != nil { - signCh <- signResult{idx: idx, err: e} - return - } - signCh <- signResult{idx: idx, sig: r.Signature, err: nil} - }(i) - } - - sigs := make([][]byte, threshold) - for i := 0; i < threshold; i++ { - out := <-signCh - require.NoError(t, out.err, "party %d signing should succeed", out.idx) - sigs[out.idx] = out.sig - } - - // Only the receiver should obtain the signature bytes. - require.NotEmpty(t, sigs[sigReceiver], "receiver should have signature bytes") - for i := 0; i < threshold; i++ { - if i == sigReceiver { - continue - } - assert.Empty(t, sigs[i], "non-receiver party %d should not have signature", i) - } - - // Verify the signature against the aggregated public key Q using Ed25519. - Q, err := resp[0].KeyShare.Q() - require.NoError(t, err) - pub, err := ed25519PublicKeyFromPoint(Q) - Q.Free() - require.NoError(t, err) - require.Len(t, sigs[sigReceiver], ed25519.SignatureSize) - valid := ed25519.Verify(ed25519.PublicKey(pub), message, sigs[sigReceiver]) - require.True(t, valid, "signature verification failed") -} - -// TestEDDSAMPC_ToAdditiveShare verifies that a subset of parties satisfying the -// quorum threshold can convert their threshold-DKG key share into an additive -// secret share without error. -func TestEDDSAMPC_ToAdditiveShare(t *testing.T) { - const ( - nParties = 4 - threshold = 2 - ) - - cv, err := curve.NewEd25519() - require.NoError(t, err) - defer cv.Free() - - pnames := mocknet.GeneratePartyNames(nParties) - messengers := mocknet.NewMockNetwork(nParties) - - // First run threshold DKG to obtain key shares. - type dkgResult struct { - idx int - share EDDSAMPCKey - err error - } - dkgCh := make(chan dkgResult, nParties) - - for i := 0; i < nParties; i++ { - go func(idx int) { - job, err := NewJobMP(messengers[idx], nParties, idx, pnames) - if err != nil { - dkgCh <- dkgResult{idx: idx, err: err} - return - } - defer job.Free() - - ac := createThresholdAccessStructure(pnames, threshold, cv) - - req := &EDDSAMPCThresholdDKGRequest{Curve: cv, AccessStructure: ac} - resp, err := EDDSAMPCThresholdDKG(job, req) - if err != nil { - dkgCh <- dkgResult{idx: idx, err: err} - return - } - dkgCh <- dkgResult{idx: idx, share: resp.KeyShare, err: nil} - }(i) - } - - shares := make([]EDDSAMPCKey, nParties) - for i := 0; i < nParties; i++ { - out := <-dkgCh - require.NoError(t, out.err) - shares[out.idx] = out.share - } - - // Prepare quorum party names – pick the first `threshold` parties. - quorumPNames := pnames[:threshold] - - // Build an AccessStructure representing the same threshold policy. - root := Threshold("", threshold, func() []*AccessNode { - kids := make([]*AccessNode, len(pnames)) - for i, n := range pnames { - kids[i] = Leaf(n) - } - return kids - }()...) - - asQ := &AccessStructure{Root: root, Curve: cv} - - // Convert shares for the quorum parties and ensure success. - for i := 0; i < threshold; i++ { - additive, err := shares[i].ToAdditiveShare(asQ, quorumPNames) - require.NoError(t, err, "party %d additive share conversion failed", i) - assert.NotEqual(t, 0, additive, "party %d additive share should not be zero", i) - // Clean up native resources to avoid leaks. - ref := additive.cgobindingRef() - (&ref).Free() - } - - // Non-quorum parties can also convert to additive shares; ensure no error - for i := threshold; i < nParties; i++ { - additive, err := shares[i].ToAdditiveShare(asQ, quorumPNames) - require.NoError(t, err, "non-quorum party %d additive conversion failed", i) - assert.NotEqual(t, 0, additive, "non-quorum party %d additive share should not be zero", i) - ref := additive.cgobindingRef() - (&ref).Free() - } -} - -// TestEDDSAMPCThresholdDKG_SigningFailsWithTooFewParties ensures that attempting -// to sign with fewer than 3 parties (below the protocol minimum and below the -// 3-of-5 threshold) fails as expected. -func TestEDDSAMPCThresholdDKG_SigningFailsWithTooFewParties(t *testing.T) { - const ( - nParties = 5 - threshold = 3 - ) - - cv, err := curve.NewEd25519() - require.NoError(t, err) - defer cv.Free() - - pnames := mocknet.GeneratePartyNames(nParties) - messengers := mocknet.NewMockNetwork(nParties) - - // Run threshold DKG across all parties - type dkgRes struct { - idx int - resp *EDDSAMPCThresholdDKGResponse - err error - } - dkgCh := make(chan dkgRes, nParties) - for i := 0; i < nParties; i++ { - go func(idx int) { - job, err := NewJobMP(messengers[idx], nParties, idx, pnames) - if err != nil { - dkgCh <- dkgRes{idx: idx, err: err} - return - } - defer job.Free() - ac := createThresholdAccessStructure(pnames, threshold, cv) - r, e := EDDSAMPCThresholdDKG(job, &EDDSAMPCThresholdDKGRequest{Curve: cv, AccessStructure: ac}) - dkgCh <- dkgRes{idx: idx, resp: r, err: e} - }(i) - } - resp := make([]*EDDSAMPCThresholdDKGResponse, nParties) - for i := 0; i < nParties; i++ { - out := <-dkgCh - require.NoError(t, out.err) - resp[out.idx] = out.resp - } - - // Convert to additive shares for a valid 3-of-5 quorum, but we will only - // attempt to sign with TWO parties to trigger failure. - root := Threshold("", threshold, func() []*AccessNode { - kids := make([]*AccessNode, len(pnames)) - for i, n := range pnames { - kids[i] = Leaf(n) - } - return kids - }()...) - asQ := &AccessStructure{Root: root, Curve: cv} - quorumNames := pnames[:threshold] - additive := make([]EDDSAMPCKey, threshold) - for i := 0; i < threshold; i++ { - as, err := resp[i].KeyShare.ToAdditiveShare(asQ, quorumNames) - require.NoError(t, err) - additive[i] = as - } - - // Use ONLY two parties for signing – should fail with "n-party signing requires at least 3 parties" - signMessengers := mocknet.NewMockNetwork(2) - signPNames := quorumNames[:2] - type signResult struct { - idx int - err error - } - signCh := make(chan signResult, 2) - message := []byte("eddsa threshold negative test") - sigReceiver := 0 - for i := 0; i < 2; i++ { - go func(idx int) { - job, err := NewJobMP(signMessengers[idx], 2, idx, signPNames) - if err != nil { - signCh <- signResult{idx: idx, err: err} - return - } - defer job.Free() - _, e := EDDSAMPCSign(job, &EDDSAMPCSignRequest{KeyShare: additive[idx], Message: message, SignatureReceiver: sigReceiver}) - signCh <- signResult{idx: idx, err: e} - }(i) - } - for i := 0; i < 2; i++ { - out := <-signCh - require.Error(t, out.err, "party %d signing should fail with too few parties", out.idx) - } -} - -// ed25519PublicKeyFromPoint helper is defined in eddsa_mp_test.go within the -// same package and reused here. diff --git a/demos-go/cb-mpc-go/api/mpc/job.go b/demos-go/cb-mpc-go/api/mpc/job.go deleted file mode 100644 index 2f45942b..00000000 --- a/demos-go/cb-mpc-go/api/mpc/job.go +++ /dev/null @@ -1,57 +0,0 @@ -package mpc - -import ( - "fmt" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// Job2P is an opaque handle for a 2-party MPC job. -// Users create it via NewJob2P and pass it to protocol APIs. -// Always call Free when finished. -type Job2P struct { - inner cgobinding.Job2P -} - -// NewJob2P constructs a two-party job. -// messenger – network layer implementation. -// roleIndex – 0 or 1 for the local party. -// pnames – names of the two parties (len == 2). -func NewJob2P(messenger transport.Messenger, roleIndex int, pnames []string) (*Job2P, error) { - inner, err := cgobinding.NewJob2P(messenger, roleIndex, pnames) - if err != nil { - return nil, err - } - return &Job2P{inner: inner}, nil -} - -// Free releases C-side resources. -func (j *Job2P) Free() { j.inner.Free() } - -// Close satisfies io.Closer by delegating to Free(). -func (j *Job2P) Close() error { - j.Free() - return nil -} - -// BroadcastToOthers sends the provided payload to the other party. -// For Job2P this means exactly one peer (1 - selfIndex). -func (j *Job2P) BroadcastToOthers(payload []byte) error { - sender := j.GetRoleIndex() - if sender < 0 || sender > 1 { - return fmt.Errorf("invalid role index %d", sender) - } - receiver := 1 - sender - _, err := j.inner.Message(sender, receiver, payload) - return err -} - -// IsRoleIndex returns true if the given index matches this party. -func (j *Job2P) IsRoleIndex(idx int) bool { return j.inner.IsRoleIndex(idx) } - -// GetRoleIndex returns the current party index (0 or 1). -func (j *Job2P) GetRoleIndex() int { return j.inner.GetRoleIndex() } - -// cgo exposes the underlying binding (internal). -func (j *Job2P) cgo() cgobinding.Job2P { return j.inner } diff --git a/demos-go/cb-mpc-go/api/mpc/job_mp.go b/demos-go/cb-mpc-go/api/mpc/job_mp.go deleted file mode 100644 index 99511a8e..00000000 --- a/demos-go/cb-mpc-go/api/mpc/job_mp.go +++ /dev/null @@ -1,38 +0,0 @@ -package mpc - -import ( - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// JobMP is an opaque handle for an N-party MPC job (N>2). -type JobMP struct { - inner cgobinding.JobMP -} - -// NewJobMP constructs a multi-party job. -func NewJobMP(messenger transport.Messenger, partyCount, roleIndex int, pnames []string) (*JobMP, error) { - inner, err := cgobinding.NewJobMP(messenger, partyCount, roleIndex, pnames) - if err != nil { - return nil, err - } - return &JobMP{inner: inner}, nil -} - -// Free releases resources. -func (j *JobMP) Free() { j.inner.Free() } - -// Close implements io.Closer. -func (j *JobMP) Close() error { j.Free(); return nil } - -// GetPartyIndex returns this party's index. -func (j *JobMP) GetPartyIndex() int { return j.inner.GetPartyIndex() } - -// IsParty checks if the given index matches this party. -func (j *JobMP) IsParty(idx int) bool { return j.inner.IsParty(idx) } - -// cgo exposes the underlying binding (internal). -func (j *JobMP) cgo() cgobinding.JobMP { return j.inner } - -// NParties returns the total number of parties in this MPC job. -func (j *JobMP) NParties() int { return j.inner.GetNParties() } diff --git a/demos-go/cb-mpc-go/api/mpc/pve.go b/demos-go/cb-mpc-go/api/mpc/pve.go deleted file mode 100644 index 2dc39c0e..00000000 --- a/demos-go/cb-mpc-go/api/mpc/pve.go +++ /dev/null @@ -1,176 +0,0 @@ -package mpc - -import ( - "fmt" - "runtime" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// ===================== Single-party PVE (ec_pve_t) ========================== - -// PVEEncryptRequest represents a request to back-up a single secret scalar `x`. -// -// - PublicKey: the ECIES public key used for encryption (BaseEncPublicKey). -// - PrivateValue: the secret scalar to encrypt (curve.Scalar). -// - Curve: the elliptic curve of the scalar. If nil, defaults to P-256. -// - Label: a human-readable domain-separator bound into the ciphertext. -// -// All fields are mandatory. The response contains the opaque PVECiphertext blob -// that must be persisted together with the public share (x * G) so that future -// verification and decryption operations can be performed. -type PVEEncryptRequest struct { - PublicKey BaseEncPublicKey - PrivateValue *curve.Scalar - Curve curve.Curve - Label string -} - -type PVEEncryptResponse struct { - Ciphertext PVECiphertext -} - -// Encrypt backs up a single secret scalar using the configuration bound to the -// receiving PVE handle. All semantic requirements remain identical to the -// former package-level helper that this method replaces. -func (p *PVE) Encrypt(req *PVEEncryptRequest) (*PVEEncryptResponse, error) { - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if req.PrivateValue == nil { - return nil, fmt.Errorf("private value cannot be nil") - } - if len(req.PublicKey) == 0 { - return nil, fmt.Errorf("public key cannot be empty") - } - if req.Label == "" { - return nil, fmt.Errorf("label cannot be empty") - } - if req.Curve == nil { - return nil, fmt.Errorf("curve cannot be nil") - } - - // Ensure the correct KEM instance is active in the native layer on this OS thread. - runtime.LockOSThread() - defer runtime.UnlockOSThread() - p.activateCtx() - - cipher, err := cgobinding.PVE_encrypt( - []byte(req.PublicKey), - req.PrivateValue.Bytes, - req.Label, - curve.Code(req.Curve), - ) - if err != nil { - return nil, err - } - - return &PVEEncryptResponse{Ciphertext: PVECiphertext(cipher)}, nil -} - -// ====================== Decryption ========================================== - -type PVEDecryptRequest struct { - PrivateKey BaseEncPrivateKey - Ciphertext PVECiphertext - Curve curve.Curve - Label string -} - -type PVEDecryptResponse struct { - PrivateValue *curve.Scalar -} - -// Decrypt recovers the secret scalar from a previously produced ciphertext. -// It mirrors the behaviour of the former PVEDecrypt helper. -func (p *PVE) Decrypt(req *PVEDecryptRequest) (*PVEDecryptResponse, error) { - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if len(req.PrivateKey) == 0 { - return nil, fmt.Errorf("private key cannot be empty") - } - if len(req.Ciphertext) == 0 { - return nil, fmt.Errorf("ciphertext cannot be empty") - } - if req.Label == "" { - return nil, fmt.Errorf("label cannot be empty") - } - if req.Curve == nil { - return nil, fmt.Errorf("curve cannot be nil") - } - - // Ensure the correct KEM instance is active in the native layer on this OS thread. - runtime.LockOSThread() - defer runtime.UnlockOSThread() - p.activateCtx() - - xBytes, err := cgobinding.PVE_decrypt( - []byte(req.PrivateKey), - []byte(req.Ciphertext), - req.Label, - curve.Code(req.Curve), - ) - if err != nil { - return nil, err - } - - // Ensure correct length relative to curve order - orderLen := len(req.Curve.Order()) - if len(xBytes) > orderLen { - xBytes = xBytes[len(xBytes)-orderLen:] - } - - return &PVEDecryptResponse{PrivateValue: &curve.Scalar{Bytes: xBytes}}, nil -} - -// ====================== Verification ======================================== - -type PVEVerifyRequest struct { - PublicKey BaseEncPublicKey - Ciphertext PVECiphertext - PublicShare *curve.Point - Label string -} - -type PVEVerifyResponse struct { - Valid bool -} - -// Verify checks whether the ciphertext is a valid encryption of the provided -// public share under the embedded PKE scheme. The logic is unchanged from the -// old stand-alone function. -func (p *PVE) Verify(req *PVEVerifyRequest) (*PVEVerifyResponse, error) { - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if len(req.PublicKey) == 0 { - return nil, fmt.Errorf("public key cannot be empty") - } - if len(req.Ciphertext) == 0 { - return nil, fmt.Errorf("ciphertext cannot be empty") - } - if req.PublicShare == nil { - return nil, fmt.Errorf("public share cannot be nil") - } - if req.Label == "" { - return nil, fmt.Errorf("label cannot be empty") - } - - // Ensure the correct KEM instance is active in the native layer on this OS thread. - runtime.LockOSThread() - defer runtime.UnlockOSThread() - p.activateCtx() - - err := cgobinding.PVE_verify( - []byte(req.PublicKey), - []byte(req.Ciphertext), - req.PublicShare.Bytes(), - req.Label, - ) - if err != nil { - return &PVEVerifyResponse{Valid: false}, err - } - return &PVEVerifyResponse{Valid: true}, nil -} diff --git a/demos-go/cb-mpc-go/api/mpc/pve_ac.go b/demos-go/cb-mpc-go/api/mpc/pve_ac.go deleted file mode 100644 index 1b679898..00000000 --- a/demos-go/cb-mpc-go/api/mpc/pve_ac.go +++ /dev/null @@ -1,257 +0,0 @@ -package mpc - -import ( - "fmt" - "runtime" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// BaseEncPrivateKey is an opaque byte slice holding a serialized base enc private key (KEM dk) -type BaseEncPrivateKey []byte - -// BaseEncPublicKey is an opaque byte slice holding a serialized base enc public key (KEM ek) -type BaseEncPublicKey []byte - -// PVECiphertext is an opaque byte slice holding a serialized PVE bundle returned by encryption. -type PVECiphertext []byte - -// PVEAcEncryptRequest represents a request for PVE encryption (backup) -type PVEAcEncryptRequest struct { - AccessStructure *AccessStructure // Quorum policy & curve description - PublicKeys map[string]BaseEncPublicKey // Map of leaf name -> public encryption key - Curve curve.Curve // Optional override curve (nil => derive from AccessStructure or default P-256) - PrivateValues []*curve.Scalar // Private data to backup (key shares) - Label string // Human-readable label bound to the backup -} - -type PVEAcEncryptResponse struct{ EncryptedBundle PVECiphertext } - -// AcEncrypt performs publicly verifiable encryption of private shares for backup using the active KEM backend. -func (p *PVE) AcEncrypt(req *PVEAcEncryptRequest) (*PVEAcEncryptResponse, error) { - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if len(req.PrivateValues) == 0 { - return nil, fmt.Errorf("private shares cannot be empty") - } - if len(req.PublicKeys) == 0 { - return nil, fmt.Errorf("public keys map cannot be empty") - } - if req.Label == "" { - return nil, fmt.Errorf("label cannot be empty") - } - if req.AccessStructure == nil { - return nil, fmt.Errorf("access structure cannot be nil") - } - // Determine curve - if req.Curve == nil { - if req.AccessStructure.Curve != nil { - req.Curve = req.AccessStructure.Curve - } else { - p256, err := curve.NewP256() - if err != nil { - return nil, fmt.Errorf("failed to initialise default curve: %v", err) - } - req.Curve = p256 - } - } - // Build inputs - names := make([][]byte, 0, len(req.PublicKeys)) - pubKeys := make([][]byte, 0, len(req.PublicKeys)) - for name, key := range req.PublicKeys { - names = append(names, []byte(name)) - pubKeys = append(pubKeys, []byte(key)) - } - xs := make([][]byte, len(req.PrivateValues)) - for i, s := range req.PrivateValues { - if s == nil { - return nil, fmt.Errorf("private share %d is nil", i) - } - xs[i] = s.Bytes - } - acPtr := req.AccessStructure.toCryptoAC() - defer cgobinding.FreeAccessStructure(acPtr) - // Ensure the correct KEM instance is active in the native layer and run on one OS thread. - runtime.LockOSThread() - defer runtime.UnlockOSThread() - p.activateCtx() - rawBundle, err := cgobinding.PVE_AC_encrypt(acPtr, names, pubKeys, len(pubKeys), xs, len(xs), req.Label, curve.Code(req.Curve)) - if err != nil { - return nil, fmt.Errorf("PVE encryption failed: %v", err) - } - return &PVEAcEncryptResponse{EncryptedBundle: PVECiphertext(rawBundle)}, nil -} - -type PVEAcPartyDecryptRowRequest struct { - AccessStructure *AccessStructure - Path string - PrivateKey BaseEncPrivateKey - EncryptedBundle PVECiphertext - Label string - // RowIndex selects the commitment row to use during decryption. - // Theoretically it can be any value in [0, kappa). In practice, try RowIndex = 0 first; - // decryption should succeed. If RowIndex = 0 fails, it usually indicates a mismatch - // (e.g. label, public keys, or public shares), so halting and inspecting is preferable - // to iterating over other indices. - RowIndex int -} - -type PVEAcPartyDecryptRowResponse struct{ Share []byte } - -func (p *PVE) AcPartyDecryptRow(req *PVEAcPartyDecryptRowRequest) (*PVEAcPartyDecryptRowResponse, error) { - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if req.AccessStructure == nil { - return nil, fmt.Errorf("access structure cannot be nil") - } - if req.Path == "" { - return nil, fmt.Errorf("path cannot be empty") - } - if len(req.PrivateKey) == 0 { - return nil, fmt.Errorf("private key cannot be empty") - } - if req.Label == "" { - return nil, fmt.Errorf("label cannot be empty") - } - acPtr := req.AccessStructure.toCryptoAC() - defer cgobinding.FreeAccessStructure(acPtr) - runtime.LockOSThread() - defer runtime.UnlockOSThread() - p.activateCtx() - share, err := cgobinding.PVE_AC_party_decrypt_row(acPtr, []byte(req.PrivateKey), []byte(req.EncryptedBundle), req.Label, req.Path, req.RowIndex) - if err != nil { - return nil, err - } - return &PVEAcPartyDecryptRowResponse{Share: share}, nil -} - -type PVEAcAggregateToRestoreRowRequest struct { - AccessStructure *AccessStructure - EncryptedBundle PVECiphertext - Label string - // RowIndex must match the row used to collect party shares. - // Any value in [0, kappa) is valid, but the typical (and expected) choice is 0. - // If aggregation at 0 fails, prefer investigating input correctness over trying other indices. - RowIndex int - Shares map[string][]byte // path -> share -} - -type PVEAcAggregateToRestoreRowResponse struct{ PrivateValues []*curve.Scalar } - -func (p *PVE) AcAggregateToRestoreRow(req *PVEAcAggregateToRestoreRowRequest) (*PVEAcAggregateToRestoreRowResponse, error) { - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if req.AccessStructure == nil { - return nil, fmt.Errorf("access structure cannot be nil") - } - if req.Label == "" { - return nil, fmt.Errorf("label cannot be empty") - } - if len(req.Shares) == 0 { - return nil, fmt.Errorf("shares cannot be empty") - } - acPtr := req.AccessStructure.toCryptoAC() - defer cgobinding.FreeAccessStructure(acPtr) - paths := make([][]byte, 0, len(req.Shares)) - shares := make([][]byte, 0, len(req.Shares)) - for path, sh := range req.Shares { - paths = append(paths, []byte(path)) - shares = append(shares, sh) - } - runtime.LockOSThread() - defer runtime.UnlockOSThread() - p.activateCtx() - recovered, err := cgobinding.PVE_AC_aggregate_to_restore_row(acPtr, []byte(req.EncryptedBundle), req.Label, paths, shares, req.RowIndex) - if err != nil { - return nil, err - } - orderLen := len(req.AccessStructure.Curve.Order()) - scalars := make([]*curve.Scalar, len(recovered)) - for i, s := range recovered { - if len(s) > orderLen { - s = s[len(s)-orderLen:] - } - scalars[i] = &curve.Scalar{Bytes: s} - } - return &PVEAcAggregateToRestoreRowResponse{PrivateValues: scalars}, nil -} - -type PVEAcVerifyRequest struct { - AccessStructure *AccessStructure - PublicKeys map[string]BaseEncPublicKey - EncryptedBundle PVECiphertext - PublicShares []*curve.Point - Label string -} - -type PVEAcVerifyResponse struct{ Valid bool } - -// AcVerify checks whether the provided PVE ciphertext is valid with respect to the given public information. -func (p *PVE) AcVerify(req *PVEAcVerifyRequest) (*PVEAcVerifyResponse, error) { - if req == nil { - return nil, fmt.Errorf("request cannot be nil") - } - if req.AccessStructure == nil { - return nil, fmt.Errorf("access structure cannot be nil") - } - if len(req.PublicKeys) == 0 { - return nil, fmt.Errorf("public keys cannot be empty") - } - if len(req.PublicShares) == 0 { - return nil, fmt.Errorf("public shares cannot be empty") - } - if req.Label == "" { - return nil, fmt.Errorf("label cannot be empty") - } - acPtr := req.AccessStructure.toCryptoAC() - defer cgobinding.FreeAccessStructure(acPtr) - leafNames := collectLeafNames(req.AccessStructure.Root) - names := make([][]byte, len(leafNames)) - pubBytes := make([][]byte, len(leafNames)) - for i, name := range leafNames { - pk, ok := req.PublicKeys[name] - if !ok { - return nil, fmt.Errorf("missing public key for leaf %s", name) - } - names[i] = []byte(name) - pubBytes[i] = []byte(pk) - } - xsBytes := make([][]byte, len(req.PublicShares)) - for i, pt := range req.PublicShares { - if pt == nil { - return nil, fmt.Errorf("public share %d is nil", i) - } - xsBytes[i] = pt.Bytes() - } - runtime.LockOSThread() - defer runtime.UnlockOSThread() - p.activateCtx() - if err := cgobinding.PVE_AC_verify(acPtr, names, pubBytes, len(pubBytes), []byte(req.EncryptedBundle), xsBytes, len(xsBytes), req.Label); err != nil { - return &PVEAcVerifyResponse{Valid: false}, err - } - return &PVEAcVerifyResponse{Valid: true}, nil -} - -// collectLeafNames performs a DFS traversal to return leaf names in deterministic order. -func collectLeafNames(root *AccessNode) []string { - var res []string - var walk func(n *AccessNode) - walk = func(n *AccessNode) { - if n == nil { - return - } - if n.Kind == KindLeaf { - res = append(res, n.Name) - return - } - for _, c := range n.Children { - walk(c) - } - } - walk(root) - return res -} diff --git a/demos-go/cb-mpc-go/api/mpc/pve_ac_test.go b/demos-go/cb-mpc-go/api/mpc/pve_ac_test.go deleted file mode 100644 index 58083a42..00000000 --- a/demos-go/cb-mpc-go/api/mpc/pve_ac_test.go +++ /dev/null @@ -1,220 +0,0 @@ -package mpc - -import ( - "testing" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mocknet" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/testutil" - - "github.com/stretchr/testify/require" -) - -// TestPVEAcEncryptDecrypt performs a full encrypt → decrypt round-trip on a simple threshold access structure. -func TestPVEAcEncryptDecrypt(t *testing.T) { - const ( - nParties = 5 - threshold = 3 - ) - - // Prepare curve instance (use P-256 for speed). - cv, err := curve.NewP256() - require.NoError(t, err) - defer cv.Free() - - // Party names and access structure. - pnames := mocknet.GeneratePartyNames(nParties) - ac := createThresholdAccessStructure(pnames, threshold, cv) - - // PVE handle with XOR KEM test backend (defined in pve_test.go) - pve, err := NewPVE(Config{KEM: newTestXorKEM()}) - require.NoError(t, err) - - // Generate base encryption key pairs for every leaf using the test KEM. - pubMap := make(map[string]BaseEncPublicKey, nParties) - prvMap := make(map[string]BaseEncPrivateKey, nParties) - for _, name := range pnames { - dk, ek, err := newTestXorKEM().Generate() - require.NoError(t, err) - pubMap[name] = BaseEncPublicKey(ek) - prvMap[name] = BaseEncPrivateKey(dk) - } - - // Generate random private values to back-up. - privValues := make([]*curve.Scalar, nParties) - for i := 0; i < nParties; i++ { - s, err := cv.RandomScalar() - require.NoError(t, err) - privValues[i] = s - } - - pubShares := make([]*curve.Point, nParties) - for i, s := range privValues { - pt, err := cv.MultiplyGenerator(s) - require.NoError(t, err) - pubShares[i] = pt - } - - // Encrypt - encResp, err := pve.AcEncrypt(&PVEAcEncryptRequest{ - AccessStructure: ac, - PublicKeys: pubMap, - PrivateValues: privValues, - Label: "unit-test-backup", - Curve: cv, - }) - require.NoError(t, err) - require.Greater(t, len(encResp.EncryptedBundle), 0, "ciphertext should not be empty") - - // Verify ciphertext prior to decryption - verifyResp, err := pve.AcVerify(&PVEAcVerifyRequest{ - AccessStructure: ac, - PublicKeys: pubMap, - EncryptedBundle: encResp.EncryptedBundle, - PublicShares: pubShares, - Label: "unit-test-backup", - }) - require.NoError(t, err) - require.NotNil(t, verifyResp) - require.True(t, verifyResp.Valid, "verification should succeed on authentic ciphertext") - - // Tamper with ciphertext - tampered := make([]byte, len(encResp.EncryptedBundle)) - copy(tampered, encResp.EncryptedBundle) - if len(tampered) > 0 { - tampered[0] ^= 0xFF // flip first byte - } - - testutil.TSilence(t, func(t *testing.T) { - verifyResp, err = pve.AcVerify(&PVEAcVerifyRequest{ - AccessStructure: ac, - PublicKeys: pubMap, - EncryptedBundle: PVECiphertext(tampered), - PublicShares: pubShares, - Label: "unit-test-backup", - }) - }) - require.Error(t, err) - require.NotNil(t, verifyResp) - require.False(t, verifyResp.Valid, "verification should fail on tampered ciphertext") - - // Decrypt - shares := make(map[string][]byte) - for _, name := range pnames { - resp, err := pve.AcPartyDecryptRow(&PVEAcPartyDecryptRowRequest{ - AccessStructure: ac, - Path: name, - PrivateKey: prvMap[name], - EncryptedBundle: encResp.EncryptedBundle, - Label: "unit-test-backup", - RowIndex: 0, - }) - require.NoError(t, err) - shares[name] = resp.Share - } - aggResp, err := pve.AcAggregateToRestoreRow(&PVEAcAggregateToRestoreRowRequest{ - AccessStructure: ac, - EncryptedBundle: encResp.EncryptedBundle, - Label: "unit-test-backup", - RowIndex: 0, - Shares: shares, - }) - require.NoError(t, err) - require.Equal(t, len(privValues), len(aggResp.PrivateValues)) - - // Compare recovered values with originals. - for i := 0; i < nParties; i++ { - require.Equal(t, privValues[i].Bytes, aggResp.PrivateValues[i].Bytes) - } -} - -// TestPVEAcWithRSAHSMKEM verifies quorum PVE with an HSM-like RSA KEM backend. -func TestPVEAcWithRSAHSMKEM(t *testing.T) { - const ( - nParties = 4 - threshold = 2 - ) - - cv, err := curve.NewP256() - require.NoError(t, err) - defer cv.Free() - - pnames := mocknet.GeneratePartyNames(nParties) - ac := createThresholdAccessStructure(pnames, threshold, cv) - - // Use a single HSM-like KEM instance so handles resolve correctly. - hsm := newRSAHSMKEM() - pve, err := NewPVE(Config{KEM: hsm}) - require.NoError(t, err) - - pubMap := make(map[string]BaseEncPublicKey, nParties) - prvMap := make(map[string]BaseEncPrivateKey, nParties) - for _, name := range pnames { - dk, ek, err := hsm.Generate() - require.NoError(t, err) - pubMap[name] = BaseEncPublicKey(ek) - prvMap[name] = BaseEncPrivateKey(dk) - } - - privValues := make([]*curve.Scalar, nParties) - for i := 0; i < nParties; i++ { - s, err := cv.RandomScalar() - require.NoError(t, err) - privValues[i] = s - } - - pubShares := make([]*curve.Point, nParties) - for i, s := range privValues { - pt, err := cv.MultiplyGenerator(s) - require.NoError(t, err) - pubShares[i] = pt - } - - encResp, err := pve.AcEncrypt(&PVEAcEncryptRequest{ - AccessStructure: ac, - PublicKeys: pubMap, - PrivateValues: privValues, - Label: "rsa-hsm-quorum", - Curve: cv, - }) - require.NoError(t, err) - require.NotNil(t, encResp) - require.True(t, len(encResp.EncryptedBundle) > 0) - - verResp, err := pve.AcVerify(&PVEAcVerifyRequest{ - AccessStructure: ac, - PublicKeys: pubMap, - EncryptedBundle: encResp.EncryptedBundle, - PublicShares: pubShares, - Label: "rsa-hsm-quorum", - }) - require.NoError(t, err) - require.True(t, verResp.Valid) - - // Interactive decryption for RSA-HSM KEM - shares2 := make(map[string][]byte) - for _, name := range pnames { - resp, err := pve.AcPartyDecryptRow(&PVEAcPartyDecryptRowRequest{ - AccessStructure: ac, - Path: name, - PrivateKey: prvMap[name], - EncryptedBundle: encResp.EncryptedBundle, - Label: "rsa-hsm-quorum", - RowIndex: 0, - }) - require.NoError(t, err) - shares2[name] = resp.Share - } - aggResp2, err := pve.AcAggregateToRestoreRow(&PVEAcAggregateToRestoreRowRequest{ - AccessStructure: ac, - EncryptedBundle: encResp.EncryptedBundle, - Label: "rsa-hsm-quorum", - RowIndex: 0, - Shares: shares2, - }) - require.NoError(t, err) - require.Equal(t, len(privValues), len(aggResp2.PrivateValues)) - for i := 0; i < nParties; i++ { - require.Equal(t, privValues[i].Bytes, aggResp2.PrivateValues[i].Bytes) - } -} diff --git a/demos-go/cb-mpc-go/api/mpc/pve_instance.go b/demos-go/cb-mpc-go/api/mpc/pve_instance.go deleted file mode 100644 index e4d35644..00000000 --- a/demos-go/cb-mpc-go/api/mpc/pve_instance.go +++ /dev/null @@ -1,43 +0,0 @@ -package mpc - -import ( - "fmt" - "unsafe" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// KEM is a pluggable "key encapsulation mechanism" backend. -// We alias the low-level cgobinding.KEM interface so callers only have to -// satisfy a single contract across the whole code base. -type KEM = cgobinding.KEM - -type Config struct { - KEM KEM -} - -func (c *Config) normalise() error { - if c.KEM == nil { - return fmt.Errorf("pve: Config.KEM cannot be nil") - } - return nil -} - -// PVE is an instance-level façade around the PVE helpers. -type PVE struct { - kem KEM - ctx unsafe.Pointer // context handle passed to C -} - -func NewPVE(cfg Config) (*PVE, error) { - if err := cfg.normalise(); err != nil { - return nil, err - } - ctxPtr, err := cgobinding.RegisterKEMInstance(cfg.KEM) - if err != nil { - return nil, err - } - return &PVE{kem: cfg.KEM, ctx: ctxPtr}, nil -} - -func (p *PVE) activateCtx() { cgobinding.ActivateCtx(p.ctx) } diff --git a/demos-go/cb-mpc-go/api/mpc/pve_test.go b/demos-go/cb-mpc-go/api/mpc/pve_test.go deleted file mode 100644 index d42b28e0..00000000 --- a/demos-go/cb-mpc-go/api/mpc/pve_test.go +++ /dev/null @@ -1,773 +0,0 @@ -package mpc - -import ( - "bytes" - "crypto/ecdh" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "crypto/x509" - "encoding/binary" - "fmt" - "io" - "sync" - "testing" - "unsafe" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/testutil" - "github.com/stretchr/testify/require" -) - -// Deterministic reader derived from a seed using SHA-256(counter || seed) -type ctrRand struct { - seed [32]byte - counter uint64 - buf []byte - off int -} - -func newCTRRand(seed []byte) *ctrRand { - var s [32]byte - copy(s[:], seed) - return &ctrRand{seed: s} -} - -func (r *ctrRand) refill() { - ctrBytes := make([]byte, 8) - for i := 0; i < 8; i++ { - ctrBytes[7-i] = byte(r.counter >> (8 * i)) - } - h := sha256.Sum256(append(ctrBytes, r.seed[:]...)) - r.buf = h[:] - r.off = 0 - r.counter++ -} - -func (r *ctrRand) Read(p []byte) (int, error) { - n := 0 - for n < len(p) { - if r.off >= len(r.buf) { - r.refill() - } - m := copy(p[n:], r.buf[r.off:]) - r.off += m - n += m - } - return n, nil -} - -// TestPVEEncryptDecryptSingle performs an encrypt → verify → decrypt round-trip -// for the single-party PVE helpers. -func TestPVEEncryptDecryptSingle(t *testing.T) { - // Create a fresh PVE handle bound to the dummy XOR scheme so that tests - // run side-by-side without touching global state. - xorKEM := newTestXorKEM() - pve, err := NewPVE(Config{KEM: xorKEM}) - require.NoError(t, err) - - // Prepare curve (P-256). - cv, err := curve.NewP256() - require.NoError(t, err) - defer cv.Free() - - // Generate base enc key pair. - pub, prv, err := xorKEM.Generate() - require.NoError(t, err) - - // Secret scalar x and its public share Q = x*G. - x, err := cv.RandomScalar() - require.NoError(t, err) - - Q, err := cv.MultiplyGenerator(x) - require.NoError(t, err) - - // Encrypt. - encResp, err := pve.Encrypt(&PVEEncryptRequest{ - PublicKey: pub, - PrivateValue: x, - Curve: cv, - Label: "pve-single-test", - }) - require.NoError(t, err) - require.Greater(t, len(encResp.Ciphertext), 0) - - _, _ = prv, Q - - // Verify authentic ciphertext. - verResp, err := pve.Verify(&PVEVerifyRequest{ - PublicKey: pub, - Ciphertext: encResp.Ciphertext, - PublicShare: Q, - Label: "pve-single-test", - }) - require.NoError(t, err) - require.True(t, verResp.Valid) - - // Tamper with ciphertext and expect failure (silence C/C++ stderr during this expected error). - tampered := make([]byte, len(encResp.Ciphertext)) - copy(tampered, encResp.Ciphertext) - if len(tampered) > 0 { - tampered[len(tampered)-1] ^= 0xFF - } - testutil.TSilence(t, func(t *testing.T) { - _, err = pve.Verify(&PVEVerifyRequest{ - PublicKey: pub, - Ciphertext: PVECiphertext(tampered), - PublicShare: Q, - Label: "pve-single-test", - }) - }) - require.Error(t, err) - - // Derive public key from private key and print it. - _, err = xorKEM.DerivePub(prv) - require.NoError(t, err) - - // Decrypt. - decResp, err := pve.Decrypt(&PVEDecryptRequest{ - PrivateKey: prv, - Ciphertext: encResp.Ciphertext, - Curve: cv, - Label: "pve-single-test", - }) - _ = decResp - require.NoError(t, err) - require.NotNil(t, decResp.PrivateValue) - require.Equal(t, x.Bytes, decResp.PrivateValue.Bytes) - - // Concurrency coverage for thread-safety - t.Run("concurrent encrypt operations", func(t *testing.T) { - const goroutines = 10 - var wg sync.WaitGroup - errCh := make(chan error, goroutines) - wg.Add(goroutines) - for i := 0; i < goroutines; i++ { - go func() { - defer wg.Done() - xi, err := cv.RandomScalar() - if err != nil { - errCh <- err - return - } - _, err = pve.Encrypt(&PVEEncryptRequest{ - PublicKey: pub, - PrivateValue: xi, - Curve: cv, - Label: "concurrent-test", - }) - errCh <- err - }() - } - wg.Wait() - close(errCh) - for e := range errCh { - require.NoError(t, e) - } - }) - - t.Run("concurrent verify operations", func(t *testing.T) { - const goroutines = 10 - var wg sync.WaitGroup - errCh := make(chan error, goroutines) - wg.Add(goroutines) - for i := 0; i < goroutines; i++ { - go func() { - defer wg.Done() - _, err := pve.Verify(&PVEVerifyRequest{ - PublicKey: pub, - Ciphertext: encResp.Ciphertext, - PublicShare: Q, - Label: "pve-single-test", - }) - errCh <- err - }() - } - wg.Wait() - close(errCh) - for e := range errCh { - require.NoError(t, e) - } - }) - - t.Run("concurrent decrypt operations", func(t *testing.T) { - const goroutines = 10 - var wg sync.WaitGroup - errCh := make(chan error, goroutines) - wg.Add(goroutines) - for i := 0; i < goroutines; i++ { - go func() { - defer wg.Done() - dec, err := pve.Decrypt(&PVEDecryptRequest{ - PrivateKey: prv, - Ciphertext: encResp.Ciphertext, - Curve: cv, - Label: "pve-single-test", - }) - if err == nil && dec != nil && dec.PrivateValue != nil && !bytes.Equal(dec.PrivateValue.Bytes, x.Bytes) { - err = fmt.Errorf("decrypted value mismatch") - } - errCh <- err - }() - } - wg.Wait() - close(errCh) - for e := range errCh { - require.NoError(t, e) - } - }) -} - -type testXorKEM struct{} - -func newTestXorKEM() testXorKEM { return testXorKEM{} } - -func (testXorKEM) Generate() ([]byte, []byte, error) { - key := make([]byte, 1) - rand.Read(key) // any single byte != 0 is fine - return key, key, nil -} - -func (testXorKEM) Encapsulate(ek []byte, rho [32]byte) ([]byte, []byte, error) { - // Use rho directly as the shared secret for determinism. - ss := make([]byte, 32) - copy(ss, rho[:]) - ct := make([]byte, len(ss)) - for i := range ss { - ct[i] = ss[i] ^ ek[0] - } - return ct, ss, nil -} - -func (testXorKEM) Decapsulate(skHandle unsafe.Pointer, ct []byte) ([]byte, error) { - var keyByte byte - if skHandle != nil { - // In this test, skHandle may point to a cmem_t with the bytes - type cmem_t struct { - data *byte - size int32 - } - cm := (*cmem_t)(skHandle) - if cm != nil && cm.data != nil && cm.size > 0 { - keyByte = *(*byte)(unsafe.Pointer(cm.data)) - } else { - keyByte = byte(uintptr(skHandle) & 0xFF) - } - } - out := make([]byte, len(ct)) - for i := range ct { - out[i] = ct[i] ^ keyByte - } - return out, nil -} - -func (testXorKEM) DerivePub(dk []byte) ([]byte, error) { - // For this toy scheme, public key equals private key (XOR key byte). - ek := make([]byte, len(dk)) - copy(ek, dk) - return ek, nil -} - -func TestPVERoundTrip(t *testing.T) { - // Create a fresh PVE handle bound to the XOR backend. - pve, err := NewPVE(Config{KEM: newTestXorKEM()}) - require.NoError(t, err) - - // 1) Create demo key-pair using the XOR backend. - dk, ek, err := newTestXorKEM().Generate() - if err != nil { - t.Fatalf("generate: %v", err) - } - - rho := []byte("demo‑rho 32 bytes pad pad pad pad!!")[:32] - var rhoArr [32]byte - copy(rhoArr[:], rho) - - // 2) Run through the bridges without touching any C++ code: - ct, ss, err := pve.kem.Encapsulate(ek, rhoArr) - if err != nil { - t.Fatalf("encapsulate: %v", err) - } - // Wrap dk bytes into a temporary cmem_t and pass its address as the handle - type cmem_t struct { - data *byte - size int32 - } - var cm cmem_t - if len(dk) > 0 { - cm.data = (*byte)(unsafe.Pointer(&dk[0])) - cm.size = int32(len(dk)) - } - ss2, err := pve.kem.Decapsulate(unsafe.Pointer(&cm), ct) - if err != nil { - t.Fatalf("decapsulate: %v", err) - } - if !bytes.Equal(ss2, ss) { - t.Fatal("round‑trip failed") - } -} - -// ============================================================================= -// Additional back-ends used for coexistence / HSM style tests - -type testShiftKEM struct{} - -func (testShiftKEM) Generate() (skRef, ek []byte, err error) { - b := make([]byte, 1) - rand.Read(b) - return b, b, nil // same byte represents both halves -} - -func (testShiftKEM) Encapsulate(ek []byte, rho [32]byte) ([]byte, []byte, error) { - ss := make([]byte, 32) - copy(ss, rho[:]) - shift := ek[0] + 1 - ct := make([]byte, len(ss)) - for i := range ss { - ct[i] = ss[i] ^ shift - } - return ct, ss, nil -} - -func (testShiftKEM) Decapsulate(skHandle unsafe.Pointer, ct []byte) ([]byte, error) { - // Prefer reading from cmem_t (byte-backed dk) - type cmem_t struct { - data *byte - size int32 - } - var keyByte byte - cm := (*cmem_t)(skHandle) - if cm != nil && cm.data != nil && cm.size > 0 { - keyByte = *(*byte)(unsafe.Pointer(cm.data)) - } else { - keyByte = byte(uintptr(skHandle) & 0xFF) - } - shift := keyByte + 1 - out := make([]byte, len(ct)) - for i := range ct { - out[i] = ct[i] ^ shift - } - return out, nil -} - -func (testShiftKEM) DerivePub(skRef []byte) ([]byte, error) { - ek := make([]byte, len(skRef)) - copy(ek, skRef) - return ek, nil -} - -// hsmStubKEM imitates a hardware token: Generate returns a 32-bit handle. - -type Handle = uint32 - -type hsmStubKEM struct{ store map[Handle]byte } - -func newHSMStub() *hsmStubKEM { return &hsmStubKEM{store: make(map[Handle]byte)} } - -func (h *hsmStubKEM) Generate() (skRef, ek []byte, err error) { - // 4-byte little-endian handle - handle := make([]byte, 4) - rand.Read(handle) - key := make([]byte, 1) - rand.Read(key) - h.store[Handle(binary.LittleEndian.Uint32(handle))] = key[0] - return handle, key, nil -} - -func (h *hsmStubKEM) Encapsulate(ek []byte, rho [32]byte) ([]byte, []byte, error) { - ss := make([]byte, 32) - copy(ss, rho[:]) - ct := make([]byte, len(ss)) - for i := range ss { - ct[i] = ss[i] ^ ek[0] - } - return ct, ss, nil -} - -func (h *hsmStubKEM) Decapsulate(skHandle unsafe.Pointer, ct []byte) ([]byte, error) { - // Prefer cmem_t-backed handle; parse first 4 bytes little-endian. - type cmem_t struct { - data *byte - size int32 - } - var handle Handle - cm := (*cmem_t)(skHandle) - if cm != nil && cm.data != nil && cm.size > 0 { - dk := unsafe.Slice((*byte)(unsafe.Pointer(cm.data)), int(cm.size)) - if len(dk) >= 4 { - handle = Handle(binary.LittleEndian.Uint32(dk[:4])) - } else { - handle = Handle(dk[0]) - } - } else { - handle = Handle(uint32(uintptr(skHandle) & 0xffffffff)) - } - keyByte, ok := h.store[handle] - if !ok { - return nil, fmt.Errorf("unknown handle") - } - out := make([]byte, len(ct)) - for i := range ct { - out[i] = ct[i] ^ keyByte - } - return out, nil -} - -func (h *hsmStubKEM) DerivePub(skRef []byte) ([]byte, error) { - var handle Handle - if len(skRef) >= 4 { - handle = Handle(binary.LittleEndian.Uint32(skRef[:4])) - } else if len(skRef) >= 1 { - handle = Handle(skRef[0]) - } else { - return nil, fmt.Errorf("invalid handle ref") - } - key, ok := h.store[handle] - if !ok { - return nil, fmt.Errorf("unknown handle %x", skRef) - } - return []byte{key}, nil -} - -// Tests - -func TestCoexistingBackends(t *testing.T) { - xorPVE, err := NewPVE(Config{KEM: newTestXorKEM()}) - require.NoError(t, err) - - shiftPVE, err := NewPVE(Config{KEM: testShiftKEM{}}) - require.NoError(t, err) - - cv, _ := curve.NewP256() - - // Common inputs - x, _ := cv.RandomScalar() - Q, _ := cv.MultiplyGenerator(x) - - // XOR backend - dkXor, ekXor, _ := newTestXorKEM().Generate() - encXor, _ := xorPVE.Encrypt(&PVEEncryptRequest{PublicKey: ekXor, PrivateValue: x, Curve: cv, Label: "coexist"}) - decXor, _ := xorPVE.Decrypt(&PVEDecryptRequest{PrivateKey: dkXor, Ciphertext: encXor.Ciphertext, Curve: cv, Label: "coexist"}) - require.Equal(t, x.Bytes, decXor.PrivateValue.Bytes) - - // Shift backend - skShift, ekShift, _ := testShiftKEM{}.Generate() - encShift, _ := shiftPVE.Encrypt(&PVEEncryptRequest{PublicKey: ekShift, PrivateValue: x, Curve: cv, Label: "coexist"}) - decShift, _ := shiftPVE.Decrypt(&PVEDecryptRequest{PrivateKey: skShift, Ciphertext: encShift.Ciphertext, Curve: cv, Label: "coexist"}) - require.Equal(t, x.Bytes, decShift.PrivateValue.Bytes) - - // Final sanity: XOR ciphertext should not decrypt under Shift backend and vice-versa. - testutil.TSilence(t, func(t *testing.T) { - _, err = shiftPVE.Decrypt(&PVEDecryptRequest{PrivateKey: skShift, Ciphertext: encXor.Ciphertext, Curve: cv, Label: "coexist"}) - }) - require.Error(t, err) - - testutil.TSilence(t, func(t *testing.T) { - _, err = xorPVE.Decrypt(&PVEDecryptRequest{PrivateKey: dkXor, Ciphertext: encShift.Ciphertext, Curve: cv, Label: "coexist"}) - }) - require.Error(t, err) - _ = Q -} - -func TestHSMStub(t *testing.T) { - hsm := newHSMStub() - pve, err := NewPVE(Config{KEM: hsm}) - require.NoError(t, err) - - cv, _ := curve.NewP256() - x, _ := cv.RandomScalar() - - skRef, ek, err := hsm.Generate() - require.NoError(t, err) - - encResp, err := pve.Encrypt(&PVEEncryptRequest{PublicKey: ek, PrivateValue: x, Curve: cv, Label: "hsm-demo"}) - require.NoError(t, err) - - decResp, err := pve.Decrypt(&PVEDecryptRequest{PrivateKey: skRef, Ciphertext: encResp.Ciphertext, Curve: cv, Label: "hsm-demo"}) - require.NoError(t, err) - require.Equal(t, x.Bytes, decResp.PrivateValue.Bytes) - - SecureWipe(skRef) -} - -// ============================= -// Go-defined RSA KEM (toy) -// ============================= - -type rsaGoKEM struct { - prv *rsa.PrivateKey - pub *rsa.PublicKey -} - -func newRSAGoKEM() (*rsaGoKEM, error) { - k, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, err - } - return &rsaGoKEM{prv: k, pub: &k.PublicKey}, nil -} - -func (r *rsaGoKEM) Generate() ([]byte, []byte, error) { - prvBytes := x509.MarshalPKCS1PrivateKey(r.prv) - pubBytes := x509.MarshalPKCS1PublicKey(r.pub) - return prvBytes, pubBytes, nil -} - -func (r *rsaGoKEM) Encapsulate(ek []byte, rho [32]byte) ([]byte, []byte, error) { - pub, err := x509.ParsePKCS1PublicKey(ek) - if err != nil { - return nil, nil, err - } - label := []byte("pve-rsa-go") - ctr := newCTRRand(rho[:]) - ct, err := rsa.EncryptOAEP(sha256.New(), ctr, pub, rho[:], label) - if err != nil { - return nil, nil, err - } - ss := make([]byte, 32) - copy(ss, rho[:]) - return ct, ss, nil -} - -func (r *rsaGoKEM) Decapsulate(skHandle unsafe.Pointer, ct []byte) ([]byte, error) { - // For Go RSA KEM tests we still expect raw bytes; when called directly we pass a cmem_t. - type cmem_t struct { - data *byte - size int32 - } - cm := (*cmem_t)(skHandle) - if cm == nil || cm.data == nil || cm.size <= 0 { - return nil, fmt.Errorf("invalid sk handle") - } - dk := unsafe.Slice((*byte)(unsafe.Pointer(cm.data)), int(cm.size)) - prv, err := x509.ParsePKCS1PrivateKey(dk) - if err != nil { - return nil, err - } - label := []byte("pve-rsa-go") - ss, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, prv, ct, label) - if err != nil { - return nil, err - } - out := make([]byte, 32) - copy(out, ss) - return out, nil -} - -func (r *rsaGoKEM) DerivePub(dk []byte) ([]byte, error) { - prv, err := x509.ParsePKCS1PrivateKey(dk) - if err != nil { - return nil, err - } - return x509.MarshalPKCS1PublicKey(&prv.PublicKey), nil -} - -// ============================= -// Go-defined ECDH KEM (toy, P-256 + HKDF=truncate) -// ============================= - -type ecdhGoKEM struct{} - -func newECDHGoKEM() *ecdhGoKEM { return &ecdhGoKEM{} } - -func (e *ecdhGoKEM) Generate() ([]byte, []byte, error) { - cv := ecdh.P256() - priv, err := cv.GenerateKey(rand.Reader) - if err != nil { - return nil, nil, err - } - return priv.Bytes(), priv.PublicKey().Bytes(), nil -} - -func (e *ecdhGoKEM) Encapsulate(ek []byte, rho [32]byte) ([]byte, []byte, error) { - cv := ecdh.P256() - peerPub, err := cv.NewPublicKey(ek) - if err != nil { - return nil, nil, err - } - // ephemeral key - ephPriv, err := cv.GenerateKey(rand.Reader) - if err != nil { - return nil, nil, err - } - ss, err := ephPriv.ECDH(peerPub) - if err != nil { - return nil, nil, err - } - ct := ephPriv.PublicKey().Bytes() - return ct, ss, nil -} - -func (e *ecdhGoKEM) Decapsulate(skHandle unsafe.Pointer, ct []byte) ([]byte, error) { - // Expect cmem_t pointing to private key bytes - type cmem_t struct { - data *byte - size int32 - } - cm := (*cmem_t)(skHandle) - if cm == nil || cm.data == nil || cm.size <= 0 { - return nil, fmt.Errorf("invalid sk handle") - } - dk := unsafe.Slice((*byte)(unsafe.Pointer(cm.data)), int(cm.size)) - cv := ecdh.P256() - priv, err := cv.NewPrivateKey(dk) - if err != nil { - return nil, err - } - pub, err := cv.NewPublicKey(ct) - if err != nil { - return nil, err - } - ss, err := priv.ECDH(pub) - if err != nil { - return nil, err - } - return ss, nil -} - -func (e *ecdhGoKEM) DerivePub(dk []byte) ([]byte, error) { - cv := ecdh.P256() - priv, err := cv.NewPrivateKey(dk) - if err != nil { - return nil, err - } - return priv.PublicKey().Bytes(), nil -} - -// ============================= -// RSA KEM with HSM-like handle simulation -// ============================= - -type rsaHSMKEM struct { - store map[Handle]*rsa.PrivateKey -} - -func newRSAHSMKEM() *rsaHSMKEM { return &rsaHSMKEM{store: make(map[Handle]*rsa.PrivateKey)} } - -func (h *rsaHSMKEM) Generate() (skRef, ek []byte, err error) { - k, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, nil, err - } - handle := make([]byte, 4) - rand.Read(handle) - h.store[Handle(binary.LittleEndian.Uint32(handle))] = k - pubBytes := x509.MarshalPKCS1PublicKey(&k.PublicKey) - return handle, pubBytes, nil -} - -func (h *rsaHSMKEM) Encapsulate(ek []byte, rho [32]byte) ([]byte, []byte, error) { - pub, err := x509.ParsePKCS1PublicKey(ek) - if err != nil { - return nil, nil, err - } - label := []byte("pve-rsa-hsm") - var rng io.Reader = newCTRRand(rho[:]) - ct, err := rsa.EncryptOAEP(sha256.New(), rng, pub, rho[:], label) - if err != nil { - return nil, nil, err - } - ss := make([]byte, 32) - copy(ss, rho[:]) - return ct, ss, nil -} - -func (h *rsaHSMKEM) Decapsulate(skHandle unsafe.Pointer, ct []byte) ([]byte, error) { - // Prefer cmem_t-backed handle; parse first 4 bytes little-endian. - type cmem_t struct { - data *byte - size int32 - } - var handle Handle - cm := (*cmem_t)(skHandle) - if cm != nil && cm.data != nil && cm.size > 0 { - dk := unsafe.Slice((*byte)(unsafe.Pointer(cm.data)), int(cm.size)) - if len(dk) >= 4 { - handle = Handle(binary.LittleEndian.Uint32(dk[:4])) - } else if len(dk) >= 1 { - handle = Handle(dk[0]) - } else { - return nil, fmt.Errorf("invalid handle") - } - } else { - handle = Handle(uint32(uintptr(skHandle) & 0xffffffff)) - } - k, ok := h.store[handle] - if !ok { - return nil, fmt.Errorf("unknown handle") - } - label := []byte("pve-rsa-hsm") - ss, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, k, ct, label) - if err != nil { - return nil, err - } - out := make([]byte, 32) - copy(out, ss) - return out, nil -} - -func (h *rsaHSMKEM) DerivePub(skRef []byte) ([]byte, error) { - var handle Handle - if len(skRef) >= 4 { - handle = Handle(binary.LittleEndian.Uint32(skRef[:4])) - } else if len(skRef) >= 1 { - handle = Handle(skRef[0]) - } else { - return nil, fmt.Errorf("invalid handle ref") - } - k, ok := h.store[handle] - if !ok { - return nil, fmt.Errorf("unknown handle %x", skRef) - } - return x509.MarshalPKCS1PublicKey(&k.PublicKey), nil -} - -// ============================= -// Tests for the above KEMs -// ============================= - -func TestPVEWithRSAGoKEM(t *testing.T) { - rsaK, err := newRSAGoKEM() - require.NoError(t, err) - pve, err := NewPVE(Config{KEM: rsaK}) - require.NoError(t, err) - cv, _ := curve.NewP256() - x, _ := cv.RandomScalar() - prvBytes, pubBytes, err := rsaK.Generate() - require.NoError(t, err) - enc, err := pve.Encrypt(&PVEEncryptRequest{PublicKey: pubBytes, PrivateValue: x, Curve: cv, Label: "rsa-go"}) - require.NoError(t, err) - dec, err := pve.Decrypt(&PVEDecryptRequest{PrivateKey: prvBytes, Ciphertext: enc.Ciphertext, Curve: cv, Label: "rsa-go"}) - require.NoError(t, err) - require.Equal(t, x.Bytes, dec.PrivateValue.Bytes) -} - -func TestPVEWithECDHGoKEM(t *testing.T) { - ecdhK := newECDHGoKEM() - pve, err := NewPVE(Config{KEM: ecdhK}) - require.NoError(t, err) - cv, _ := curve.NewP256() - x, _ := cv.RandomScalar() - prvBytes, pubBytes, err := ecdhK.Generate() - require.NoError(t, err) - enc, err := pve.Encrypt(&PVEEncryptRequest{PublicKey: pubBytes, PrivateValue: x, Curve: cv, Label: "ecdh-go"}) - require.NoError(t, err) - dec, err := pve.Decrypt(&PVEDecryptRequest{PrivateKey: prvBytes, Ciphertext: enc.Ciphertext, Curve: cv, Label: "ecdh-go"}) - require.NoError(t, err) - require.Equal(t, x.Bytes, dec.PrivateValue.Bytes) -} - -func TestPVEWithRSAHSMKEM(t *testing.T) { - hsm := newRSAHSMKEM() - pve, err := NewPVE(Config{KEM: hsm}) - require.NoError(t, err) - cv, _ := curve.NewP256() - x, _ := cv.RandomScalar() - skRef, ek, err := hsm.Generate() - require.NoError(t, err) - enc, err := pve.Encrypt(&PVEEncryptRequest{PublicKey: ek, PrivateValue: x, Curve: cv, Label: "rsa-hsm"}) - require.NoError(t, err) - dec, err := pve.Decrypt(&PVEDecryptRequest{PrivateKey: skRef, Ciphertext: enc.Ciphertext, Curve: cv, Label: "rsa-hsm"}) - require.NoError(t, err) - require.Equal(t, x.Bytes, dec.PrivateValue.Bytes) -} diff --git a/demos-go/cb-mpc-go/api/mpc/secure.go b/demos-go/cb-mpc-go/api/mpc/secure.go deleted file mode 100644 index c9ceed4a..00000000 --- a/demos-go/cb-mpc-go/api/mpc/secure.go +++ /dev/null @@ -1,23 +0,0 @@ -package mpc - -import ( - "crypto/subtle" - "runtime" -) - -// SecureWipe overwrites the given byte slice with zeros using a constant-time -// copy to minimise the risk of compiler optimisations removing the call. The -// slice length is left unchanged but its contents become all-zero. -// -// This is a best-effort helper – Go's garbage collector may still keep old -// copies alive until the next collection cycle. Use it immediately after you -// no longer need a secret key reference that contains raw key material. -func SecureWipe(buf []byte) { - if len(buf) == 0 { - return - } - zero := make([]byte, len(buf)) - subtle.ConstantTimeCopy(1, buf, zero) - // Keep the backing array alive until after the zeroization. - runtime.KeepAlive(&buf[0]) -} diff --git a/demos-go/cb-mpc-go/api/transport/doc.go b/demos-go/cb-mpc-go/api/transport/doc.go deleted file mode 100644 index 94b7d3f5..00000000 --- a/demos-go/cb-mpc-go/api/transport/doc.go +++ /dev/null @@ -1,24 +0,0 @@ -// Package transport defines the abstractions that glue MPC protocols to the -// underlying network. -// -// The core interface is `Messenger` which provides a minimal set of primitives -// understood by the native C++ engine: -// -// MessageSend(ctx, receiver, data) -// MessageReceive(ctx, sender) -// MessagesReceive(ctx, senders) -// -// A Messenger implementation does *not* need to care about protocol details – -// it simply delivers opaque byte slices between numbered parties. This -// deliberate design choice lets applications swap transport mechanisms without -// touching any of the cryptography. -// -// Out of the box the repository provides two implementations: -// -// - mocknet – an in-process, fully deterministic transport ideal for tests -// - mtls – a production-ready TCP transport that uses mutual-TLS for -// authentication and encryption -// -// You are encouraged to implement your own Messenger for custom deployment -// scenarios (e.g. gRPC, libp2p, message queues, …). -package transport diff --git a/demos-go/cb-mpc-go/api/transport/messenger.go b/demos-go/cb-mpc-go/api/transport/messenger.go deleted file mode 100644 index 9daad8de..00000000 --- a/demos-go/cb-mpc-go/api/transport/messenger.go +++ /dev/null @@ -1,18 +0,0 @@ -package transport - -import "context" - -// Messenger defines the interface for data transport in the CB-MPC system. -// Implementations of this interface handle message passing between MPC parties. -type Messenger interface { - // MessageSend sends a message buffer to the specified receiver party. - MessageSend(ctx context.Context, receiver int, buffer []byte) error - - // MessageReceive receives a message from the specified sender party. - MessageReceive(ctx context.Context, sender int) ([]byte, error) - - // MessagesReceive receives messages from multiple sender parties. It waits - // until all messages are ready and returns them in the same order as the - // provided senders slice. - MessagesReceive(ctx context.Context, senders []int) ([][]byte, error) -} diff --git a/demos-go/cb-mpc-go/api/transport/mocknet/doc.go b/demos-go/cb-mpc-go/api/transport/mocknet/doc.go deleted file mode 100644 index 9719c892..00000000 --- a/demos-go/cb-mpc-go/api/transport/mocknet/doc.go +++ /dev/null @@ -1,16 +0,0 @@ -// Package mocknet implements the `transport.Messenger` interface entirely in -// memory and is intended ONLY for testing or local development. -// -// A mock network is invaluable when writing unit- or integration-tests because -// it: -// - removes all external dependencies (no sockets, no certificates), -// - runs deterministically inside a single OS process, and -// - is orders of magnitude faster than loop-back TCP. -// -// Under the hood each party is backed by a pair of goroutines and a channel per -// direction which faithfully replicate the semantics of a real network while -// still sharing memory. -// -// For production deployments use the `mtls` transport or build your own -// Messenger that satisfies the `transport.Messenger` interface. -package mocknet diff --git a/demos-go/cb-mpc-go/api/transport/mocknet/runner.go b/demos-go/cb-mpc-go/api/transport/mocknet/runner.go deleted file mode 100644 index 7b690316..00000000 --- a/demos-go/cb-mpc-go/api/transport/mocknet/runner.go +++ /dev/null @@ -1,169 +0,0 @@ -package mocknet - -import ( - "fmt" - "sync" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// MPCIO represents input/output data for MPC operations -type MPCIO struct { - Opaque interface{} -} - -// MPCPeer represents a single party in the MPC protocol -type MPCPeer struct { - nParties int - roleIndex int - dataTransport *MockMessenger -} - -// MPCRunner provides utilities for running MPC protocols in a test environment -type MPCRunner struct { - nParties int - pnames []string - peers []*MPCPeer - isAbort bool -} - -// GeneratePartyNames returns the default party name list ("party_0", "party_1", ...) -// for the given number of parties. It is handy for tests and examples that do not -// require custom naming. -func GeneratePartyNames(n int) []string { - names := make([]string, n) - for i := 0; i < n; i++ { - names[i] = fmt.Sprintf("party_%d", i) - } - return names -} - -// NewMPCRunner creates a new MPCRunner with the specified party names. The caller -// should pass one name per party, e.g.: -// -// r := mocknet.NewMPCRunner("alice", "bob") -// -// For convenience, callers can generate the default names via GeneratePartyNames. -func NewMPCRunner(pnames ...string) *MPCRunner { - n := len(pnames) - if n == 0 { - panic("NewMPCRunner requires at least one party name") - } - - runner := &MPCRunner{nParties: n, pnames: pnames} - runner.peers = make([]*MPCPeer, n) - - // Create the mock network - transports := NewMockNetwork(n) - - // Create peers with their respective transports - for i := 0; i < n; i++ { - runner.peers[i] = &MPCPeer{ - nParties: n, - roleIndex: i, - dataTransport: transports[i], - } - } - return runner -} - -// MPCFunction2P represents a function for two-party MPC protocols -type MPCFunction2P func(net cgobinding.Job2P, input *MPCIO) (*MPCIO, error) - -// MPCFunctionMP represents a function for multi-party MPC protocols -type MPCFunctionMP func(net cgobinding.JobMP, input *MPCIO) (*MPCIO, error) - -// Run2P executes a two-party MPC protocol with the given function and inputs -func (runner *MPCRunner) MPCRun2P(f MPCFunction2P, inputs []*MPCIO) ([]*MPCIO, error) { - if runner.nParties != 2 { - return nil, fmt.Errorf("Run2P only supports 2 parties, got %d", runner.nParties) - } - errs := make([]error, runner.nParties) - outs := make([]*MPCIO, runner.nParties) - - runner.isAbort = false - - var wg sync.WaitGroup - wg.Add(runner.nParties) - for i := 0; i < runner.nParties; i++ { - go func(i int) { - defer wg.Done() - pnames := runner.pnames - job, err := cgobinding.NewJob2P(runner.peers[i].dataTransport, i, pnames) - if err != nil { - errs[i] = fmt.Errorf("failed to create Job2P: %w", err) - return - } - defer job.Free() - outs[i], errs[i] = f(job, inputs[i]) - if errs[i] != nil { // abort job - runner.isAbort = true - for j := 0; j < runner.nParties; j++ { - runner.peers[j].dataTransport.cond.Broadcast() - } - } - }(i) - } - wg.Wait() - - // Clean up after job - runner.cleanup() - - for _, err := range errs { - if err != nil { - return nil, err - } - } - return outs, nil -} - -// RunMP executes a multi-party MPC protocol with the given function and inputs -func (runner *MPCRunner) MPCRunMP(f MPCFunctionMP, inputs []*MPCIO) ([]*MPCIO, error) { - errs := make([]error, runner.nParties) - outs := make([]*MPCIO, runner.nParties) - - runner.isAbort = false - - var wg sync.WaitGroup - wg.Add(runner.nParties) - for i := 0; i < runner.nParties; i++ { - go func(i int) { - defer wg.Done() - // Use the configured party names directly - job, err := cgobinding.NewJobMP(runner.peers[i].dataTransport, runner.nParties, i, runner.pnames) - if err != nil { - errs[i] = fmt.Errorf("failed to create JobMP: %w", err) - return - } - defer job.Free() - outs[i], errs[i] = f(job, inputs[i]) - if errs[i] != nil { // abort job - runner.isAbort = true - for j := 0; j < runner.nParties; j++ { - runner.peers[j].dataTransport.cond.Broadcast() - } - } - }(i) - } - wg.Wait() - - // Clean up after job - runner.cleanup() - - for _, err := range errs { - if err != nil { - return nil, err - } - } - return outs, nil -} - -// cleanup resets the runner state and clears message queues -func (runner *MPCRunner) cleanup() { - runner.isAbort = false - for i := 0; i < runner.nParties; i++ { - for j := 0; j < runner.nParties; j++ { - runner.peers[i].dataTransport.queues[j].Init() - } - } -} diff --git a/demos-go/cb-mpc-go/api/transport/mocknet/runner_test.go b/demos-go/cb-mpc-go/api/transport/mocknet/runner_test.go deleted file mode 100644 index 4c7712ca..00000000 --- a/demos-go/cb-mpc-go/api/transport/mocknet/runner_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package mocknet - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNewRunner(t *testing.T) { - // Test creating a runner with 2 parties - runner := NewMPCRunner(GeneratePartyNames(2)...) - assert.NotNil(t, runner) - assert.Equal(t, 2, runner.nParties) - assert.Len(t, runner.peers, 2) - - // Verify peers are properly initialized - for i, peer := range runner.peers { - assert.Equal(t, i, peer.roleIndex) - assert.Equal(t, 2, peer.nParties) - assert.NotNil(t, peer.dataTransport) - } -} - -func TestMockMessenger(t *testing.T) { - // Create a mock network with 3 parties - messengers := NewMockNetwork(3) - require.Len(t, messengers, 3) - - // Test message sending and receiving - message := []byte("test message") - - // Party 0 sends to party 1 - err := messengers[0].MessageSend(context.Background(), 1, message) - assert.NoError(t, err) - - // Party 1 receives from party 0 - received, err := messengers[1].MessageReceive(context.Background(), 0) - assert.NoError(t, err) - assert.Equal(t, message, received) -} - -func TestMockMessengerMultipleMessages(t *testing.T) { - // Create a mock network with 2 parties - messengers := NewMockNetwork(2) - - // Test receiving messages one by one to avoid order issues - message1 := []byte("message 1") - message2 := []byte("message 2") - - // Party 1 sends first message to party 0 - err := messengers[1].MessageSend(context.Background(), 0, message1) - assert.NoError(t, err) - - // Party 0 receives first message from party 1 - received1, err := messengers[0].MessageReceive(context.Background(), 1) - assert.NoError(t, err) - assert.Equal(t, message1, received1) - - // Party 1 sends second message to party 0 - err = messengers[1].MessageSend(context.Background(), 0, message2) - assert.NoError(t, err) - - // Party 0 receives second message from party 1 - received2, err := messengers[0].MessageReceive(context.Background(), 1) - assert.NoError(t, err) - assert.Equal(t, message2, received2) -} diff --git a/demos-go/cb-mpc-go/api/transport/mocknet/transport.go b/demos-go/cb-mpc-go/api/transport/mocknet/transport.go deleted file mode 100644 index eedc1c5b..00000000 --- a/demos-go/cb-mpc-go/api/transport/mocknet/transport.go +++ /dev/null @@ -1,114 +0,0 @@ -package mocknet - -import ( - "container/list" - "context" - "errors" - "sync" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport" -) - -// MockMessenger provides a mock implementation of the Messenger interface for testing. -// It uses in-memory message queues to simulate network communication between parties. -type MockMessenger struct { - roleIndex int - outs []*MockMessenger - mutex sync.Mutex - cond *sync.Cond - queues []list.List - isAbort bool -} - -// Ensure MockMessenger implements the Messenger interface -var _ transport.Messenger = (*MockMessenger)(nil) - -// NewMockMessenger creates a new MockMessenger instance for the specified party role -func NewMockMessenger(roleIndex int) *MockMessenger { - ctx := &MockMessenger{roleIndex: roleIndex} - ctx.cond = sync.NewCond(&ctx.mutex) - ctx.isAbort = false - return ctx -} - -// setOuts configures the connections to other mock transport instances. -// This is used internally by NewMockNetwork to wire up all parties. -func (dt *MockMessenger) setOuts(dts []*MockMessenger) { - dt.outs = dts - dt.queues = make([]list.List, len(dts)) -} - -// MessageSend sends a message to the specified receiver party -func (dt *MockMessenger) MessageSend(_ context.Context, receiverIndex int, buffer []byte) error { - if receiverIndex == dt.roleIndex { - return errors.New("cannot send to self") - } - - receiverDT := dt.outs[receiverIndex] - receiverDT.mutex.Lock() - receiverDT.queues[dt.roleIndex].PushBack(buffer) - receiverDT.mutex.Unlock() - receiverDT.cond.Broadcast() - - return nil -} - -// MessageReceive receives a message from the specified sender party -func (dt *MockMessenger) MessageReceive(_ context.Context, senderIndex int) ([]byte, error) { - if senderIndex == dt.roleIndex { - return nil, errors.New("cannot receive from self") - } - - dt.mutex.Lock() - defer dt.mutex.Unlock() - - if dt.isAbort { - return nil, errors.New("aborted") - } - queue := &dt.queues[senderIndex] - for queue.Len() == 0 { - dt.cond.Wait() - if dt.isAbort { - return nil, errors.New("aborted") - } - } - front := queue.Front() - receivedMsg := front.Value.([]byte) - queue.Remove(front) - return receivedMsg, nil -} - -// MessagesReceive receives messages from multiple sender parties concurrently -func (dt *MockMessenger) MessagesReceive(ctx context.Context, senderIndices []int) (receivedMsgs [][]byte, err error) { - n := len(senderIndices) - receivedMsgs = make([][]byte, n) - - var wg sync.WaitGroup - wg.Add(n) - for i, senderIndex := range senderIndices { - go func(i int, senderIndex int) { - var e error - receivedMsgs[i], e = dt.MessageReceive(ctx, senderIndex) - if e != nil { - err = e // Note: this is not thread-safe, but sufficient for testing - } - wg.Done() - }(i, senderIndex) - } - wg.Wait() - - return receivedMsgs, nil -} - -// NewMockNetwork creates a complete mock network with the specified number of parties. -// It returns a slice of MockMessenger instances, one for each party, already wired together. -func NewMockNetwork(nParties int) []*MockMessenger { - messengers := make([]*MockMessenger, nParties) - for i := 0; i < nParties; i++ { - messengers[i] = NewMockMessenger(i) - } - for i := 0; i < nParties; i++ { - messengers[i].setOuts(messengers) - } - return messengers -} diff --git a/demos-go/cb-mpc-go/api/transport/mtls/transport.go b/demos-go/cb-mpc-go/api/transport/mtls/transport.go deleted file mode 100644 index 7a155a6b..00000000 --- a/demos-go/cb-mpc-go/api/transport/mtls/transport.go +++ /dev/null @@ -1,321 +0,0 @@ -// Package mtls provides a production-ready implementation of the Messenger interface using mutual TLS. -// This serves as a reference implementation showing how to build secure, authenticated transport -// for multi-party computation protocols. -package mtls - -import ( - "context" - "crypto/sha256" - "crypto/tls" - "crypto/x509" - "encoding/binary" - "encoding/hex" - "fmt" - "io" - "net" - "sync" - "time" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport" - "golang.org/x/sync/errgroup" -) - -// MTLSMessenger implements the Messenger interface using mutual TLS authentication. -// It provides secure, authenticated communication between MPC parties. -type MTLSMessenger struct { - connections map[int]*tls.Conn - nameToIndex map[string]int - listener net.Listener - mu sync.RWMutex - timeout time.Duration - selfIndex int -} - -// Ensure MTLSMessenger implements the Messenger interface -var _ transport.Messenger = (*MTLSMessenger)(nil) - -// PartyConfig contains the configuration for a single party -type PartyConfig struct { - // Address should include the IP/hostname and port - Address string - Cert *x509.Certificate -} - -// Config contains the configuration for setting up mutual TLS transport -type Config struct { - // Parties must include the current party as well as all other parties, - // the key is the index of the party among all possible parties - Parties map[int]PartyConfig - CertPool *x509.CertPool - TLSCert tls.Certificate - NameToIndex map[string]int - SelfIndex int -} - -// PartyNameFromCertificate extracts a unique party name from a certificate by hashing its public key -func PartyNameFromCertificate(cert *x509.Certificate) (string, error) { - pubKeyBytes, err := x509.MarshalPKIXPublicKey(cert.PublicKey) - if err != nil { - return "", fmt.Errorf("marshaling public key: %v", err) - } - hash := sha256.Sum256(pubKeyBytes) - pname := hex.EncodeToString(hash[:]) - return pname, nil -} - -// NewMTLSMessenger creates a new MTLSMessenger instance with the given configuration. -// It establishes TLS connections with all other parties according to a deterministic connection pattern. -func NewMTLSMessenger(config Config) (*MTLSMessenger, error) { - tlsConfig := &tls.Config{ - MinVersion: tls.VersionTLS13, - CipherSuites: nil, // use the safe default cipher suites - Certificates: []tls.Certificate{config.TLSCert}, - RootCAs: config.CertPool, - ClientCAs: config.CertPool, - ClientAuth: tls.RequireAndVerifyClientCert, - VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - if len(rawCerts) == 0 { - return fmt.Errorf("no server certificate provided") - } - - serverCert, err := x509.ParseCertificate(rawCerts[0]) - if err != nil { - return fmt.Errorf("parsing server certificate: %v", err) - } - - peerPname, err := PartyNameFromCertificate(serverCert) - if err != nil { - return fmt.Errorf("extracting peer name from server certificate: %v", err) - } - peerIndex, ok := config.NameToIndex[peerPname] - if !ok { - return fmt.Errorf("peer name %s not found in name to index map", peerPname) - } - if !serverCert.Equal(config.Parties[peerIndex].Cert) { - return fmt.Errorf("server certificate does not match the expected certificate: %v", peerIndex) - } - - return nil - }, - } - - expectedIncomingConnectionsCount := 0 - expectedOutgoingConnectionsCount := 0 - for i := range config.Parties { - if i < config.SelfIndex { - expectedOutgoingConnectionsCount++ - } - if i > config.SelfIndex { - expectedIncomingConnectionsCount++ - } - } - - fmt.Printf("Party %d: expected %d incoming connections and %d outgoing connections\n", config.SelfIndex, expectedIncomingConnectionsCount, expectedOutgoingConnectionsCount) - - transport := &MTLSMessenger{ - connections: make(map[int]*tls.Conn), - timeout: time.Hour * 1, - selfIndex: config.SelfIndex, - nameToIndex: config.NameToIndex, - } - - wg := sync.WaitGroup{} - - if expectedIncomingConnectionsCount != 0 { - myAddress := config.Parties[config.SelfIndex].Address - ln, err := tls.Listen("tcp", myAddress, tlsConfig) - if err != nil { - return nil, fmt.Errorf("starting server on %s: %v", myAddress, err) - } - transport.listener = ln - - for i := 0; i < expectedIncomingConnectionsCount; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - conn, err := ln.Accept() - if err != nil { - fmt.Printf("error accepting connection on %s: %v\n", myAddress, err) - return - } - c := conn.(*tls.Conn) - - // Explicitly complete the TLS handshake. This will let us access the peer certificates. - if err := c.Handshake(); err != nil { - fmt.Printf("TLS handshake failed: %v\n", err) - c.Close() - return - } - peerCerts := c.ConnectionState().PeerCertificates - if len(peerCerts) == 0 { - fmt.Printf("No peer certificates found\n") - return - } - peerName, err := PartyNameFromCertificate(peerCerts[0]) - if err != nil { - fmt.Printf("error extracting peer name from certificate: %v\n", err) - return - } - peerIndex, ok := transport.nameToIndex[peerName] - if !ok { - fmt.Printf("peer name %s not found in name to index map\n", peerName) - return - } - fmt.Printf("Party %d: peer %d connected\n", config.SelfIndex, peerIndex) - - transport.mu.Lock() - transport.connections[peerIndex] = c - transport.mu.Unlock() - }(i) - } - } - - for i, party := range config.Parties { - if i < config.SelfIndex { - // Exponential backoff with cap for connection retries - backoff := 1 * time.Second - maxBackoff := 20 * time.Second - attempts := 0 - for { - counterPartyAddress := party.Address - conn, err := tls.Dial("tcp", counterPartyAddress, tlsConfig) - if err != nil { - attempts++ - if attempts > 10 { - return nil, fmt.Errorf("connecting to %s: %v", counterPartyAddress, err) - } - time.Sleep(backoff) - backoff *= 2 - if backoff > maxBackoff { - backoff = maxBackoff - } - continue - } - transport.mu.Lock() - transport.connections[i] = conn - transport.mu.Unlock() - break - } - } - } - - // Wait for all incoming connections to be established - wg.Wait() - - return transport, nil -} - -// MessageSend sends a message to the specified receiver party -func (dt *MTLSMessenger) MessageSend(_ context.Context, receiverIndex int, buffer []byte) error { - conn, ok := dt.connections[receiverIndex] - - if !ok { - return fmt.Errorf("no connection found for receiver index %d", receiverIndex) - } - - // Send message length first (4 bytes, big endian) - messageLength := uint32(len(buffer)) - lengthBytes := make([]byte, 4) - binary.BigEndian.PutUint32(lengthBytes, messageLength) - - if _, err := conn.Write(lengthBytes); err != nil { - return fmt.Errorf("writing message length: %v", err) - } - - // Send the actual message - if _, err := conn.Write(buffer); err != nil { - return fmt.Errorf("writing message data: %v", err) - } - - return nil -} - -// MessageReceive receives a message from the specified sender party -func (dt *MTLSMessenger) MessageReceive(_ context.Context, senderIndex int) ([]byte, error) { - conn, ok := dt.connections[senderIndex] - - if !ok { - return nil, fmt.Errorf("no connection found for sender index %d", senderIndex) - } - - // Read message length first (4 bytes) - lengthBytes := make([]byte, 4) - if _, err := io.ReadFull(conn, lengthBytes); err != nil { - return nil, fmt.Errorf("reading message length: %v", err) - } - - messageLength := binary.BigEndian.Uint32(lengthBytes) - - // Validate message length to prevent excessive memory allocation - if messageLength > 10*1024*1024 { // 10MB limit - return nil, fmt.Errorf("message too large: %d bytes", messageLength) - } - - // Read the exact amount of message data - buffer := make([]byte, messageLength) - if _, err := io.ReadFull(conn, buffer); err != nil { - return nil, fmt.Errorf("reading message data: %v", err) - } - - return buffer, nil -} - -// MessagesReceive receives messages from multiple sender parties concurrently -func (dt *MTLSMessenger) MessagesReceive(ctx context.Context, senderIndices []int) ([][]byte, error) { - receivedMsgs := make([][]byte, len(senderIndices)) - - eg := errgroup.Group{} - wg := sync.WaitGroup{} - wg.Add(len(senderIndices)) - - for i, senderIndex := range senderIndices { - eg.Go(func() error { - defer wg.Done() - msg, err := dt.MessageReceive(ctx, senderIndex) - if err != nil { - return fmt.Errorf("receiving message from %d: %v", senderIndex, err) - } - receivedMsgs[i] = msg - return nil - }) - } - wg.Wait() - - if err := eg.Wait(); err != nil { - return nil, fmt.Errorf("receiving messages: %v", err) - } - return receivedMsgs, nil -} - -// Close closes all connections and cleans up resources -func (dt *MTLSMessenger) Close() error { - dt.mu.Lock() - defer dt.mu.Unlock() - - fmt.Printf("Closing MTLSMessenger for party %d\n", dt.selfIndex) - - // Close all connections - for idx, conn := range dt.connections { - if conn != nil { - fmt.Printf("Closing connection to party %d\n", idx) - conn.Close() - } - } - - // Clear the connections map - dt.connections = make(map[int]*tls.Conn) - - // Close listener if it exists - if dt.listener != nil { - fmt.Printf("Closing listener for party %d\n", dt.selfIndex) - err := dt.listener.Close() - dt.listener = nil - if err != nil { - fmt.Printf("Error closing listener: %v\n", err) - return err - } - } - - fmt.Printf("MTLSMessenger closed successfully for party %d\n", dt.selfIndex) - return nil -} diff --git a/demos-go/cb-mpc-go/api/zk/doc.go b/demos-go/cb-mpc-go/api/zk/doc.go deleted file mode 100644 index 5dedbf1a..00000000 --- a/demos-go/cb-mpc-go/api/zk/doc.go +++ /dev/null @@ -1,41 +0,0 @@ -// Package zk contains zero-knowledge protocols that can be used alongside the -// MPC primitives in CB-MPC. -// -// A zero-knowledge proof lets a prover convince a verifier that a statement is -// true without revealing *why* it is true. The proofs implemented here are -// small, non-interactive and can be transmitted over any `transport.Messenger`. -// -// Currently implemented: -// -// - ZK-DL – Proof of knowledge of a discrete-logarithm relative to a curve -// generator (i.e. possession of an ECDSA private key). -// -// The Go API follows the same request/response design used by the `mpc` -// package which makes it trivial to marshal the data into JSON or protobuf and -// to plug the proofs into higher-level protocols. -// -// Example -// -// // 1. Generate a fresh key pair. -// kp, _ := zk.ZKDLGenerateKeyPair(&zk.ZKDLKeyGenRequest{}) -// -// // 2. Produce a proof that we know the private key. -// proveResp, _ := zk.ZKUCDLProve(&zk.ZKUCDLProveRequest{ -// PublicKey: kp.PublicKey, -// Witness: kp.PrivateKey, -// SessionID: []byte("session-1"), -// Auxiliary: 42, -// }) -// -// // 3. Verify the proof. -// verifyResp, _ := zk.ZKUCDLVerify(&zk.ZKUCDLVerifyRequest{ -// PublicKey: kp.PublicKey, -// Proof: proveResp.Proof, -// SessionID: []byte("session-1"), -// Auxiliary: 42, -// }) -// -// if !verifyResp.Valid { -// log.Fatal("proof rejected") -// } -package zk diff --git a/demos-go/cb-mpc-go/api/zk/zkdl.go b/demos-go/cb-mpc-go/api/zk/zkdl.go deleted file mode 100644 index 9c3e6256..00000000 --- a/demos-go/cb-mpc-go/api/zk/zkdl.go +++ /dev/null @@ -1,112 +0,0 @@ -package zk - -import ( - "fmt" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/cgobinding" -) - -// ZKUCDLProveRequest represents the input parameters for generating a -// zero-knowledge discrete logarithm proof. -// -// - PublicKey is the point Q = w·G. It MUST be non-nil. -// - Witness is the scalar w. It MUST be non-nil. -// - SessionID is an application supplied domain-separator that -// distinguishes distinct proofs created with the same key material. It -// MAY be nil/empty. -// - Auxiliary is user defined additional data that is bound to the proof -// (e.g. a transcript hash). -// -// The request mirrors the native C++ CB-MPC interface but uses the Go -// curve package types so that call-sites cannot accidentally confuse point -// and scalar byte slices. -type ZKUCDLProveRequest struct { - PublicKey *curve.Point - Witness *curve.Scalar - SessionID []byte - Auxiliary uint64 -} - -// ZKUCDLProveResponse is returned by ZKUCDLProve. -// -// Proof holds an opaque, serialised representation of the zero-knowledge -// proof. Until the cgobinding for the real protocol is available the proof is -// a simple, deterministic mock value. -type ZKUCDLProveResponse struct { - Proof []byte -} - -// ZKUCDLProve is the Go wrapper around the native CB-MPC ZK DL prover. -// It delegates the heavy lifting to the C++ implementation exposed through -// the cgobinding package. -func ZKUCDLProve(req *ZKUCDLProveRequest) (*ZKUCDLProveResponse, error) { - if req == nil { - return nil, fmt.Errorf("nil request") - } - if req.PublicKey == nil { - return nil, fmt.Errorf("public key cannot be nil") - } - if req.Witness == nil { - return nil, fmt.Errorf("witness cannot be nil") - } - if req.SessionID == nil { - req.SessionID = []byte{} - } - - // Convert point bytes to native point ref for the binding - pRef, err := cgobinding.ECCPointFromBytes(req.PublicKey.Bytes()) - if err != nil { - return nil, fmt.Errorf("invalid public key: %v", err) - } - defer pRef.Free() - proof, err := cgobinding.ZK_DL_Prove(pRef, req.Witness.Bytes, req.SessionID, req.Auxiliary) - if err != nil { - return nil, err - } - - return &ZKUCDLProveResponse{Proof: proof}, nil -} - -// ZKUCDLVerifyRequest represents the verifier input. -// The PublicKey, Proof, SessionID and Auxiliary fields must match the values -// used at prove time. -type ZKUCDLVerifyRequest struct { - PublicKey *curve.Point - Proof []byte - SessionID []byte - Auxiliary uint64 -} - -// ZKUCDLVerifyResponse indicates whether the proof could be validated. -type ZKUCDLVerifyResponse struct { - Valid bool -} - -// ZKUCDLVerify validates a proof produced by ZKUCDLProve using the native -// implementation. -func ZKUCDLVerify(req *ZKUCDLVerifyRequest) (*ZKUCDLVerifyResponse, error) { - if req == nil { - return nil, fmt.Errorf("nil request") - } - if req.PublicKey == nil { - return nil, fmt.Errorf("public key cannot be nil") - } - if len(req.Proof) == 0 { - return nil, fmt.Errorf("proof cannot be empty") - } - if req.SessionID == nil { - req.SessionID = []byte{} - } - - pRef, err := cgobinding.ECCPointFromBytes(req.PublicKey.Bytes()) - if err != nil { - return nil, fmt.Errorf("invalid public key: %v", err) - } - defer pRef.Free() - valid, err := cgobinding.ZK_DL_Verify(pRef, req.Proof, req.SessionID, req.Auxiliary) - if err != nil { - return nil, err - } - return &ZKUCDLVerifyResponse{Valid: valid}, nil -} diff --git a/demos-go/cb-mpc-go/api/zk/zkdl_test.go b/demos-go/cb-mpc-go/api/zk/zkdl_test.go deleted file mode 100644 index a8e7cd1f..00000000 --- a/demos-go/cb-mpc-go/api/zk/zkdl_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package zk - -import ( - "testing" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/internal/testutil" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// curvesUnderTest returns instances of all curves supported by the Go wrapper. -// The caller must invoke Free on every returned curve. -func curvesUnderTest(t *testing.T) []curve.Curve { - t.Helper() - - secp, err := curve.NewSecp256k1() - require.NoError(t, err) - p256, err := curve.NewP256() - require.NoError(t, err) - ed, err := curve.NewEd25519() - require.NoError(t, err) - - return []curve.Curve{secp, p256, ed} -} - -// TestProveAndVerifySuccess ensures that a proof created with ZKUCDLProve can be -// verified with ZKUCDLVerify for every supported curve. -func TestProveAndVerifySuccess(t *testing.T) { - for _, c := range curvesUnderTest(t) { - c := c // capture for parallel sub-test safety - t.Run(c.String(), func(t *testing.T) { - defer c.Free() - - w, W, err := c.RandomKeyPair() - require.NoError(t, err) - - sessionID := []byte("session-" + c.String()) - auxiliary := uint64(42) - - pr, err := ZKUCDLProve(&ZKUCDLProveRequest{ - PublicKey: W, - Witness: w, - SessionID: sessionID, - Auxiliary: auxiliary, - }) - require.NoError(t, err) - require.NotEmpty(t, pr.Proof) - - vr, err := ZKUCDLVerify(&ZKUCDLVerifyRequest{ - PublicKey: W, - Proof: pr.Proof, - SessionID: sessionID, - Auxiliary: auxiliary, - }) - require.NoError(t, err) - assert.True(t, vr.Valid) - }) - } -} - -// TestVerifyRejectsTamperedProof modifies a valid proof and expects the -// verification to fail. -func TestVerifyRejectsTamperedProof(t *testing.T) { - c, err := curve.NewSecp256k1() - require.NoError(t, err) - defer c.Free() - - w, W, err := c.RandomKeyPair() - require.NoError(t, err) - - sessionID := []byte("tamper") - auxiliary := uint64(1) - - pr, err := ZKUCDLProve(&ZKUCDLProveRequest{PublicKey: W, Witness: w, SessionID: sessionID, Auxiliary: auxiliary}) - require.NoError(t, err) - - // Flip first byte to invalidate the proof. - corrupted := append([]byte{}, pr.Proof...) - if len(corrupted) > 0 { - corrupted[0] ^= 0xFF - } - - var vr *ZKUCDLVerifyResponse - testutil.TSilence(t, func(t *testing.T) { - var err2 error - vr, err2 = ZKUCDLVerify(&ZKUCDLVerifyRequest{PublicKey: W, Proof: corrupted, SessionID: sessionID, Auxiliary: auxiliary}) - require.NoError(t, err2) - }) - assert.False(t, vr.Valid) -} - -// TestProveRejectsMalformedPublicKey ensures that ZKUCDLProve returns an error -// when provided with an invalid public key (nil). -func TestProveRejectsMalformedPublicKey(t *testing.T) { - c, err := curve.NewSecp256k1() - require.NoError(t, err) - defer c.Free() - - // Valid non-nil witness - w := curve.NewScalarFromInt64(1) - - _, err = ZKUCDLProve(&ZKUCDLProveRequest{ - PublicKey: nil, - Witness: w, - SessionID: []byte("badpk"), - Auxiliary: 0, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "public key cannot be nil") -} - -// TestVerifyRejectsMalformedPublicKey ensures that ZKUCDLVerify returns an error -// when provided with an invalid public key (nil). -func TestVerifyRejectsMalformedPublicKey(t *testing.T) { - c, err := curve.NewSecp256k1() - require.NoError(t, err) - defer c.Free() - - // Provide any non-empty proof bytes to bypass the proof length check and - // exercise the public key validation branch first. - bogusProof := []byte{0x01} - - _, err = ZKUCDLVerify(&ZKUCDLVerifyRequest{ - PublicKey: nil, - Proof: bogusProof, - SessionID: []byte("badpk"), - Auxiliary: 0, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "public key cannot be nil") -} diff --git a/demos-go/cb-mpc-go/go.mod b/demos-go/cb-mpc-go/go.mod deleted file mode 100644 index 7475141d..00000000 --- a/demos-go/cb-mpc-go/go.mod +++ /dev/null @@ -1,16 +0,0 @@ -module github.com/coinbase/cb-mpc/demos-go/cb-mpc-go - -go 1.23.0 - -toolchain go1.24.2 - -require ( - github.com/stretchr/testify v1.9.0 - golang.org/x/sync v0.15.0 -) - -require ( - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) diff --git a/demos-go/cb-mpc-go/go.sum b/demos-go/cb-mpc-go/go.sum deleted file mode 100644 index 0f67ccbc..00000000 --- a/demos-go/cb-mpc-go/go.sum +++ /dev/null @@ -1,12 +0,0 @@ -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= -golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/demos-go/cb-mpc-go/internal/cgobinding/ac.cpp b/demos-go/cb-mpc-go/internal/cgobinding/ac.cpp deleted file mode 100644 index 204836f7..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/ac.cpp +++ /dev/null @@ -1,57 +0,0 @@ -#include "ac.h" - -#include -#include -#include - -using namespace coinbase; -using namespace coinbase::crypto; -using node_t = coinbase::crypto::ss::node_t; -using node_e = coinbase::crypto::ss::node_e; - -#ifdef __cplusplus -extern "C" { -#endif - -// ============ PVE (Access Structure) utilities ================ -crypto_ss_node_ref new_node(int node_type, cmem_t node_name, int threshold) { - std::string name = coinbase::ffi::view(node_name).to_string(); - node_t* node = new node_t(node_e(node_type), name, threshold); - return crypto_ss_node_ref{node}; -} - -void add_child(crypto_ss_node_ref* parent, crypto_ss_node_ref* child) { - node_t* p = static_cast(parent->opaque); - node_t* c = static_cast(child->opaque); - p->add_child_node(c); -} - -crypto_ss_ac_ref new_access_structure(crypto_ss_node_ref* root, ecurve_ref* curve_ref) { - crypto::ss::node_t* root_node = static_cast(root->opaque); - - // Resolve the curve reference passed from Go and obtain its generator. - crypto::ecurve_t* curve = static_cast(curve_ref->opaque); - - crypto::ss::ac_t* ac = new crypto::ss::ac_t(); - if (curve) { - ac->G = curve->generator(); - } - ac->root = root_node; - return crypto_ss_ac_ref{ac}; -} - -// ============ Memory Management ================================ - -// Frees the native access-structure instance allocated by -// new_access_structure. Calling this function more than once on the same -// object or passing an already-freed reference is undefined behaviour. -void free_crypto_ss_ac(crypto_ss_ac_ref ac) { - if (ac.opaque != nullptr) { - crypto::ss::ac_t* ptr = static_cast(ac.opaque); - delete ptr; - } -} - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/ac.go b/demos-go/cb-mpc-go/internal/cgobinding/ac.go deleted file mode 100644 index 38e0d761..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/ac.go +++ /dev/null @@ -1,48 +0,0 @@ -package cgobinding - -/* -#cgo CFLAGS: -I${SRCDIR} -#cgo CXXFLAGS: -I${SRCDIR} -#include "ac.h" -*/ -import "C" - -type C_NodePtr C.crypto_ss_node_ref -type C_AcPtr C.crypto_ss_ac_ref - -// NodeType represents the type of a node in an access-structure tree. -type NodeType int - -const ( - NodeType_NONE NodeType = iota - NodeType_LEAF - NodeType_AND - NodeType_OR - NodeType_THRESHOLD -) - -// NewNode constructs a new access-structure node by delegating to the native helper. -// The returned C_NodePtr owns the underlying C++ ss::node_t* pointer. -func NewNode(nodeType NodeType, nodeName string, threshold int) C_NodePtr { - node := C.new_node(C.int(nodeType), cmem([]byte(nodeName)), C.int(threshold)) - return C_NodePtr(node) -} - -// AddChild links |child| as a direct child of |parent| in the access structure tree. -func AddChild(parent, child C_NodePtr) { - C.add_child((*C.crypto_ss_node_ref)(&parent), (*C.crypto_ss_node_ref)(&child)) -} - -// NewAccessStructure constructs a new native access-structure object and -// returns an opaque handle managed by the caller. The returned handle must be -// released exactly once via FreeAccessStructure. -func NewAccessStructure(root C_NodePtr, curve ECurveRef) C_AcPtr { - ac := C.new_access_structure((*C.crypto_ss_node_ref)(&root), (*C.ecurve_ref)(&curve)) - return C_AcPtr(ac) -} - -// FreeAccessStructure releases the resources associated with the native -// access-structure object previously created via NewAccessStructure. -func FreeAccessStructure(ac C_AcPtr) { - C.free_crypto_ss_ac(C.crypto_ss_ac_ref(ac)) -} diff --git a/demos-go/cb-mpc-go/internal/cgobinding/ac.h b/demos-go/cb-mpc-go/internal/cgobinding/ac.h deleted file mode 100644 index 685451f4..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/ac.h +++ /dev/null @@ -1,38 +0,0 @@ -#pragma once - -#ifdef __cplusplus -extern "C" { -#endif - -#include - -#include "curve.h" - -// Stand-alone opaque pointer wrapper for access-structure trees. -// The concrete type is defined inside the C++ implementation. -typedef struct crypto_ss_ac_ref { - void* opaque; // Opaque pointer to the C++ access-structure instance -} crypto_ss_ac_ref; - -// Opaque pointer wrapper for secret-sharing tree nodes. -typedef struct crypto_ss_node_ref { - void* opaque; // Opaque pointer to the C++ node_t instance -} crypto_ss_node_ref; - -// Function prototypes for secret-sharing access structure nodes. -// Creates a new node of the given type/name/threshold. The caller owns the returned pointer. -crypto_ss_node_ref new_node(int node_type, cmem_t node_name, int threshold); -// Adds |child| as a child of |parent|. Both pointers must reference valid nodes. -void add_child(crypto_ss_node_ref* parent, crypto_ss_node_ref* child); -// Constructs and returns a new access-structure given a root node and the -// curve reference. The caller owns the returned pointer and must release it -// via free_crypto_ss_ac. -crypto_ss_ac_ref new_access_structure(crypto_ss_node_ref* root, ecurve_ref* curve); - -// Releases memory held by a native access-structure. The caller must -// invoke this once the object is no longer needed to avoid memory leaks. -void free_crypto_ss_ac(crypto_ss_ac_ref ac); - -#ifdef __cplusplus -} // extern "C" -#endif \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/agree_random.cpp b/demos-go/cb-mpc-go/internal/cgobinding/agree_random.cpp deleted file mode 100644 index 461871fe..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/agree_random.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include "agree_random.h" - -#include - -#include -#include -#include -#include - -using namespace coinbase; -using namespace coinbase::mpc; - -namespace { -constexpr int SUCCESS_CODE = 0; -constexpr int ERROR_CODE = -1; -constexpr int PARAM_ERROR_CODE = -2; -} // namespace - -#define VALIDATE_JOB_2P(job) \ - do { \ - if (!job || !job->opaque) { \ - return PARAM_ERROR_CODE; \ - } \ - } while (0) - -#define GET_JOB_2P(job) static_cast(job->opaque) - -int mpc_agree_random(job_2p_ref* job, int bit_len, cmem_t* out) { - if (!job || !job->opaque || !out) return PARAM_ERROR_CODE; - if (bit_len <= 0) return PARAM_ERROR_CODE; - - try { - job_2p_t* j = GET_JOB_2P(job); - buf_t out_buf; - error_t err = agree_random(*j, bit_len, out_buf); - - if (err) return static_cast(err); - - *out = coinbase::ffi::copy_to_cmem(out_buf); - return SUCCESS_CODE; - - } catch (const std::exception& e) { - std::cerr << "Error in mpc_agree_random: " << e.what() << std::endl; - return ERROR_CODE; - } -} \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/agree_random.h b/demos-go/cb-mpc-go/internal/cgobinding/agree_random.h deleted file mode 100644 index e1e99bf8..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/agree_random.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include -#include - -#include - -#include "network.h" - -#ifdef __cplusplus -extern "C" { -#endif - -int mpc_agree_random(job_2p_ref* job, int bit_len, cmem_t* out); - -#ifdef __cplusplus -} // extern "C" -#endif \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/agreerandom.go b/demos-go/cb-mpc-go/internal/cgobinding/agreerandom.go deleted file mode 100644 index 27bcb1c7..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/agreerandom.go +++ /dev/null @@ -1,23 +0,0 @@ -package cgobinding - -import ( - "fmt" -) - -/* -#include -#include -#include "agree_random.h" -#include "cblib.h" -*/ -import "C" - -// AgreeRandom executes the agree random protocol between two parties -func AgreeRandom(job Job2P, bitLen int) ([]byte, error) { - var out CMEM - cErr := C.mpc_agree_random(job.GetCJob(), C.int(bitLen), &out) - if cErr != 0 { - return nil, fmt.Errorf("mpc_agree_random failed, %v", cErr) - } - return CMEMGet(out), nil -} diff --git a/demos-go/cb-mpc-go/internal/cgobinding/cblib.h b/demos-go/cb-mpc-go/internal/cgobinding/cblib.h deleted file mode 100644 index dec395b6..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/cblib.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include - -#include - -#include "ac.h" -#include "curve.h" -#include "network.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// ------------------------- Type Wrappers --------------------------- -// Naming convention: -// - drop the initial 'coinbase' namespace, replace :: with _ add ptr instead of _t -// - We do this since direct usage of namespaces in C is not supported -// case case Example: -// coinbase::mpc::ecdsa2pc::key_t is represented here as mpc_ecdsa2pc_key_ptr - -// ------------------------- Function/Method Wrappers ---------------- -// For each function in the library, create a wrapper that uses the following types: -// - primitive types such as int, char, void, ... -// - PTR types defined above -// - PTR types defined in the network directory -// - cmem_t, cmems_t types defined in the library -// -// Conventions: -// - Implementing a method of a class, receives the class pointer as the first argument, called ctx - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/demos-go/cb-mpc-go/internal/cgobinding/cmem.go b/demos-go/cb-mpc-go/internal/cgobinding/cmem.go deleted file mode 100644 index 690ab045..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/cmem.go +++ /dev/null @@ -1,171 +0,0 @@ -package cgobinding - -import ( - "runtime" - "unsafe" -) - -/* -#cgo CXXFLAGS: -std=c++17 -Wno-switch -Wno-parentheses -Wno-attributes -Wno-deprecated-declarations -DNO_DEPRECATED_OPENSSL -#cgo CFLAGS: -Wno-deprecated-declarations -#cgo arm64 CXXFLAGS: -march=armv8-a+crypto -#cgo !linux LDFLAGS: -lcrypto -#cgo android LDFLAGS: -lcrypto -static-libstdc++ -#cgo LDFLAGS: -ldl -// Local headers/libs are provided via CGO_* environment variables. -// See scripts/go_with_cpp.sh for how we set: -// CGO_CFLAGS/CGO_CXXFLAGS to include /src -// CGO_LDFLAGS to include /build//lib and /lib/ -#cgo linux,!android CFLAGS: -I/usr/local/include -#cgo linux,!android CXXFLAGS: -I/usr/local/include -#cgo linux,!android LDFLAGS: /usr/local/lib64/libcrypto.a -#cgo darwin,!iossimulator,!ios CFLAGS: -I/usr/local/opt/openssl@3.2.0/include -#cgo darwin,!iossimulator,!ios CXXFLAGS: -I/usr/local/opt/openssl@3.2.0/include -#cgo darwin,!iossimulator,!ios LDFLAGS: -L/usr/local/opt/openssl@3.2.0/lib - -#cgo CFLAGS: -I${SRCDIR} -#cgo CXXFLAGS: -I${SRCDIR} -#cgo LDFLAGS: -lcbmpc -#cgo linux,!android LDFLAGS: /usr/local/lib64/libcrypto.a - -#include -#include -#include "cblib.h" -*/ -import "C" - -// Memory Management Utilities - -type CMEM = C.cmem_t - -func cmem(in []byte) CMEM { - var mem CMEM - mem.size = C.int(len(in)) - if len(in) > 0 { - mem.data = (*C.uchar)(&in[0]) - } else { - mem.data = nil - } - return mem -} - -func CMEMGet(cmem CMEM) []byte { - if cmem.data == nil { - return nil - } - out := C.GoBytes(unsafe.Pointer(cmem.data), cmem.size) - C.memset(unsafe.Pointer(cmem.data), 0, C.ulong(cmem.size)) - C.free(unsafe.Pointer(cmem.data)) - return out -} - -type CMEMS = C.cmems_t - -func cmems(in [][]byte) CMEMS { - var mems CMEMS - count := len(in) - if count > 0 { - lens := make([]int32, count) - mems.sizes = (*C.int)(&lens[0]) - mems.count = C.int(count) - var n, k int - for i := 0; i < count; i++ { - l := len(in[i]) - lens[i] = int32(l) - n += int(lens[i]) - } - if n > 0 { - data := make([]byte, n) - for i := 0; i < count; i++ { - l := len(in[i]) - if l > 0 { - copy(data[k:k+l], in[i]) - } - k += l - } - mems.data = (*C.uchar)(&data[0]) - } else { - mems.data = nil - } - } else { - mems.sizes = nil - mems.data = nil - mems.count = 0 - } - return mems -} - -func CMEMSGet(cmems CMEMS) [][]byte { - if cmems.data == nil { - return nil - } - count := int(cmems.count) - out := make([][]byte, count) - n := 0 - p := uintptr(unsafe.Pointer(cmems.data)) - for i := 0; i < count; i++ { - // Inline array access to avoid dependency on network.go - sizePtr := (*C.int)(unsafe.Pointer(uintptr(unsafe.Pointer(cmems.sizes)) + uintptr(i*int(unsafe.Sizeof(C.int(0)))))) - l := int(*sizePtr) - out[i] = C.GoBytes(unsafe.Pointer(p), C.int(l)) - p += uintptr(l) - n += l - } - C.memset(unsafe.Pointer(cmems.data), 0, C.ulong(n)) - C.free(unsafe.Pointer(cmems.data)) - C.free(unsafe.Pointer(cmems.sizes)) - return out -} - -// cmemsPin holds Go-owned backing storage for a CMEMS so the Go GC cannot -// reclaim it while a C function is executing. Always call runtime.KeepAlive -// on the returned value after the C call returns. -type cmemsPin struct { - c CMEMS - lens []int32 - data []byte -} - -// makeCmems builds a CMEMS value backed by Go slices that stay reachable via -// the returned cmemsPin. Call runtime.KeepAlive(pin) after the C call. -func makeCmems(in [][]byte) cmemsPin { - var mems CMEMS - count := len(in) - if count > 0 { - lens := make([]int32, count) - mems.sizes = (*C.int)(&lens[0]) - mems.count = C.int(count) - var n, k int - for i := 0; i < count; i++ { - l := len(in[i]) - lens[i] = int32(l) - n += int(lens[i]) - } - var data []byte - if n > 0 { - data = make([]byte, n) - for i := 0; i < count; i++ { - l := len(in[i]) - if l > 0 { - copy(data[k:k+l], in[i]) - } - k += l - } - mems.data = (*C.uchar)(&data[0]) - } else { - mems.data = nil - } - // Ensure the slices are considered live until function return - // (and later via runtime.KeepAlive in callers). - _ = lens - _ = data - return cmemsPin{c: mems, lens: lens, data: data} - } - mems.sizes = nil - mems.data = nil - mems.count = 0 - // KeepAlive on zero-value pin is harmless. - pin := cmemsPin{c: mems} - runtime.KeepAlive(pin) - return pin -} diff --git a/demos-go/cb-mpc-go/internal/cgobinding/curve.cpp b/demos-go/cb-mpc-go/internal/cgobinding/curve.cpp deleted file mode 100644 index 0b687ddd..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/curve.cpp +++ /dev/null @@ -1,177 +0,0 @@ -#include "curve.h" - -#include - -#include -#include -#include - -using namespace coinbase; -using namespace coinbase::crypto; - -// ============ Curve Operations ================ - -ecurve_ref new_ecurve(int curve_code) { - ecurve_t* curve = new ecurve_t(ecurve_t::find(curve_code)); - return ecurve_ref{curve}; -} - -void free_ecurve(ecurve_ref ref) { - if (ref.opaque) { - delete static_cast(ref.opaque); - } -} - -void free_ecc_point(ecc_point_ref ref) { - if (ref.opaque) { - delete static_cast(ref.opaque); - } -} - -ecc_point_ref ecurve_generator(ecurve_ref* curve) { - ecurve_t* curve_obj = static_cast(curve->opaque); - // Create a new point and copy the generator into it - ecc_point_t* generator = new ecc_point_t(); - const ecc_generator_point_t& gen = curve_obj->generator(); - *generator = gen; // This should copy the generator point - return ecc_point_ref{generator}; -} - -cmem_t ecurve_order(ecurve_ref* curve) { - ecurve_t* curve_obj = static_cast(curve->opaque); - bn_t order = curve_obj->order(); - buf_t order_buf = order.to_bin(); - return coinbase::ffi::copy_to_cmem(order_buf); -} - -int ecurve_get_curve_code(ecurve_ref* curve) { - ecurve_t* curve_obj = static_cast(curve->opaque); - return curve_obj->get_openssl_code(); -} - -ecc_point_ref ecc_point_from_bytes(cmem_t point_bytes) { - ecc_point_t* point = new ecc_point_t(); - error_t err = coinbase::deser(coinbase::ffi::view(point_bytes), *point); - if (err) { - delete point; - return ecc_point_ref{nullptr}; - } - return ecc_point_ref{point}; -} - -cmem_t ecc_point_to_bytes(ecc_point_ref* point) { - ecc_point_t* point_obj = static_cast(point->opaque); - buf_t point_buf = coinbase::ser(*point_obj); - return coinbase::ffi::copy_to_cmem(point_buf); -} - -ecc_point_ref ecc_point_multiply(ecc_point_ref* point, cmem_t scalar) { - ecc_point_t* point_obj = static_cast(point->opaque); - // Use from_bin to convert raw bytes to bn_t - bn_t scalar_bn = bn_t::from_bin(coinbase::ffi::view(scalar)); - - ecc_point_t* result = new ecc_point_t(scalar_bn * (*point_obj)); - return ecc_point_ref{result}; -} - -ecc_point_ref ecc_point_add(ecc_point_ref* point1, ecc_point_ref* point2) { - ecc_point_t* p1 = static_cast(point1->opaque); - ecc_point_t* p2 = static_cast(point2->opaque); - ecc_point_t* result = new ecc_point_t(*p1 + *p2); - return ecc_point_ref{result}; -} - -ecc_point_ref ecc_point_subtract(ecc_point_ref* point1, ecc_point_ref* point2) { - ecc_point_t* p1 = static_cast(point1->opaque); - ecc_point_t* p2 = static_cast(point2->opaque); - ecc_point_t* result = new ecc_point_t(*p1 - *p2); - return ecc_point_ref{result}; -} - -cmem_t ecc_point_get_x(ecc_point_ref* point) { - ecc_point_t* point_obj = static_cast(point->opaque); - buf_t x_buf = point_obj->get_x().to_bin(); - return coinbase::ffi::copy_to_cmem(x_buf); -} - -cmem_t ecc_point_get_y(ecc_point_ref* point) { - ecc_point_t* point_obj = static_cast(point->opaque); - buf_t y_buf = point_obj->get_y().to_bin(); - return coinbase::ffi::copy_to_cmem(y_buf); -} - -int ecc_point_is_zero(ecc_point_ref* point) { - ecc_point_t* point_obj = static_cast(point->opaque); - // Use the built-in infinity check method - return point_obj->is_infinity() ? 1 : 0; -} - -int ecc_point_equals(ecc_point_ref* point1, ecc_point_ref* point2) { - ecc_point_t* p1 = static_cast(point1->opaque); - ecc_point_t* p2 = static_cast(point2->opaque); - return (*p1 == *p2) ? 1 : 0; -} - -// ============ Random Scalar Generation ================ - -cmem_t ecurve_random_scalar(ecurve_ref* curve) { - ecurve_t* curve_obj = static_cast(curve->opaque); - bn_t k = curve_obj->get_random_value(); - buf_t k_buf = k.to_bin(curve_obj->order().get_bin_size()); - return coinbase::ffi::copy_to_cmem(k_buf); -} - -int ecc_verify_der(int curve_code, cmem_t pub_oct, cmem_t hash, cmem_t der_sig) { - ecurve_t curve = ecurve_t::find(curve_code); - if (!curve) return -1; - ecc_point_t Q; - if (Q.from_oct(curve, coinbase::ffi::view(pub_oct))) return -2; - ecc_pub_key_t pub(Q); - error_t rv = pub.verify(coinbase::ffi::view(hash), coinbase::ffi::view(der_sig)); - return rv ? -3 : 0; -} - -// ============ Scalar Operations ================ - -// Adds two scalars represented as byte arrays (big-endian) and returns the -// resulting scalar as bytes. The addition is performed using the bn_t -// implementation from the core library to ensure constant-time behaviour. - -cmem_t bn_add(cmem_t a, cmem_t b) { - bn_t a_bn = bn_t::from_bin(coinbase::ffi::view(a)); - bn_t b_bn = bn_t::from_bin(coinbase::ffi::view(b)); - bn_t c_bn = a_bn + b_bn; - buf_t c_buf = c_bn.to_bin(); - return coinbase::ffi::copy_to_cmem(c_buf); -} - -// Adds two scalars modulo the curve order and returns the result as bytes. -cmem_t ec_mod_add(ecurve_ref* curve, cmem_t a, cmem_t b) { - ecurve_t* curve_obj = static_cast(curve->opaque); - mod_t q = curve_obj->order(); - - bn_t a_bn = bn_t::from_bin(coinbase::ffi::view(a)); - bn_t b_bn = bn_t::from_bin(coinbase::ffi::view(b)); - - bn_t c_bn = (a_bn + b_bn) % q; - - buf_t c_buf = c_bn.to_bin(q.get_bin_size()); - return coinbase::ffi::copy_to_cmem(c_buf); -} - -// Creates a bn_t from an int64 value and returns its byte representation. -cmem_t bn_from_int64(int64_t value) { - bn_t bn; - bn.set_int64(value); - buf_t bn_buf = bn.to_bin(); - return coinbase::ffi::copy_to_cmem(bn_buf); -} - -// ============ Generator Multiply ================ - -ecc_point_ref ecurve_mul_generator(ecurve_ref* curve, cmem_t scalar) { - ecurve_t* curve_obj = static_cast(curve->opaque); - bn_t k = bn_t::from_bin(coinbase::ffi::view(scalar)); - ecc_point_t* result = new ecc_point_t(curve_obj->mul_to_generator(k)); - return ecc_point_ref{result}; -} \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/curve.go b/demos-go/cb-mpc-go/internal/cgobinding/curve.go deleted file mode 100644 index 0cac2a58..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/curve.go +++ /dev/null @@ -1,164 +0,0 @@ -package cgobinding - -import ( - "fmt" -) - -/* -#include -#include "curve.h" -*/ -import "C" - -// =========== Curve and Point Types ===================== - -// Exported aliases so that other packages can reference these types. -// They are simple type aliases, so no extra conversion cost. -type ECurveRef C.ecurve_ref -type ECCPointRef C.ecc_point_ref - -func (c *ECurveRef) Free() { - C.free_ecurve(C.ecurve_ref(*c)) -} - -func (p *ECCPointRef) Free() { - C.free_ecc_point(C.ecc_point_ref(*p)) -} - -// =========== Curve Operations ===================== - -// ECurveFind finds a curve by curve code -func ECurveFind(curveCode int) (ECurveRef, error) { - cCurve := C.new_ecurve(C.int(curveCode)) - if cCurve.opaque == nil { - return ECurveRef{}, fmt.Errorf("invalid curve code: %d", curveCode) - } - return ECurveRef(cCurve), nil -} - -// ECurveGenerator returns the generator point of the curve -func ECurveGenerator(curve ECurveRef) ECCPointRef { - cPoint := C.ecurve_generator((*C.ecurve_ref)(&curve)) - return ECCPointRef(cPoint) -} - -// ECurveOrderToMem returns the order of the curve as bytes -func ECurveOrderToMem(curve ECurveRef) []byte { - cMem := C.ecurve_order((*C.ecurve_ref)(&curve)) - return CMEMGet(cMem) -} - -// ECurveGetCurveCode returns the curve code -func ECurveGetCurveCode(curve ECurveRef) int { - code := C.ecurve_get_curve_code((*C.ecurve_ref)(&curve)) - return int(code) -} - -// ECurveRandomScalarToMem returns a random scalar modulo the curve order -func ECurveRandomScalarToMem(curve ECurveRef) []byte { - cMem := C.ecurve_random_scalar((*C.ecurve_ref)(&curve)) - return CMEMGet(cMem) -} - -// ================= Scalar Operations ==================== - -// ScalarAdd returns the byte representation of a + b where the operands are -// interpreted as big-endian scalars (bn_t in the C++ layer). -// The computation is delegated to the native C++ implementation to leverage -// its constant-time big number arithmetic. -func ScalarAdd(a, b []byte) []byte { - cMem := C.bn_add(cmem(a), cmem(b)) - return CMEMGet(cMem) -} - -// ScalarAddModOrder returns (a+b) mod order(curve). -func ScalarAddModOrder(curve ECurveRef, a, b []byte) []byte { - cMem := C.ec_mod_add((*C.ecurve_ref)(&curve), cmem(a), cmem(b)) - return CMEMGet(cMem) -} - -// ScalarFromInt64 creates a scalar from an int64 value and returns its byte representation. -func ScalarFromInt64(value int64) []byte { - cMem := C.bn_from_int64((C.int64_t)(value)) - return CMEMGet(cMem) -} - -// ECurveMulGenerator multiplies the curve generator by a scalar and returns a -// new point reference. -func ECurveMulGenerator(curve ECurveRef, scalar []byte) ECCPointRef { - cPoint := C.ecurve_mul_generator((*C.ecurve_ref)(&curve), cmem(scalar)) - return ECCPointRef(cPoint) -} - -// =========== Point Operations ===================== - -// ECCPointFromBytes creates a point from bytes -func ECCPointFromBytes(pointBytes []byte) (ECCPointRef, error) { - cPoint := C.ecc_point_from_bytes(cmem(pointBytes)) - if cPoint.opaque == nil { - return ECCPointRef{}, fmt.Errorf("failed to create point from bytes") - } - return ECCPointRef(cPoint), nil -} - -// ECCPointMultiply multiplies a point by a scalar -func ECCPointMultiply(point ECCPointRef, scalar []byte) (ECCPointRef, error) { - cPoint := C.ecc_point_multiply((*C.ecc_point_ref)(&point), cmem(scalar)) - if cPoint.opaque == nil { - return ECCPointRef{}, fmt.Errorf("failed to multiply point") - } - return ECCPointRef(cPoint), nil -} - -// ECCPointAdd adds two points -func ECCPointAdd(point1, point2 ECCPointRef) ECCPointRef { - cPoint := C.ecc_point_add((*C.ecc_point_ref)(&point1), (*C.ecc_point_ref)(&point2)) - return ECCPointRef(cPoint) -} - -// ECCPointSubtract subtracts two points -func ECCPointSubtract(point1, point2 ECCPointRef) ECCPointRef { - cPoint := C.ecc_point_subtract((*C.ecc_point_ref)(&point1), (*C.ecc_point_ref)(&point2)) - return ECCPointRef(cPoint) -} - -// ECCPointGetX returns the X coordinate of the point -func ECCPointGetX(point ECCPointRef) []byte { - cMem := C.ecc_point_get_x((*C.ecc_point_ref)(&point)) - return CMEMGet(cMem) -} - -// ECCPointGetY returns the Y coordinate of the point -func ECCPointGetY(point ECCPointRef) []byte { - cMem := C.ecc_point_get_y((*C.ecc_point_ref)(&point)) - return CMEMGet(cMem) -} - -// ECCPointIsZero checks if the point is zero -func ECCPointIsZero(point ECCPointRef) bool { - return C.ecc_point_is_zero((*C.ecc_point_ref)(&point)) != 0 -} - -// ECCPointEquals checks if two points are equal -func ECCPointEquals(point1, point2 ECCPointRef) bool { - return C.ecc_point_equals((*C.ecc_point_ref)(&point1), (*C.ecc_point_ref)(&point2)) != 0 -} - -// ECCPointToBytes serializes a point to the library's canonical byte format. -func ECCPointToBytes(point ECCPointRef) []byte { - cMem := C.ecc_point_to_bytes((*C.ecc_point_ref)(&point)) - return CMEMGet(cMem) -} - -// ECCVerifyDER verifies a DER-encoded ECDSA signature. -// curveCode: OpenSSL NID for the curve (e.g., 714 for secp256k1) -// pubOct: SEC1 uncompressed public key bytes -// hash: 32-byte digest -// derSig: DER-encoded ECDSA signature -func ECCVerifyDER(curveCode int, pubOct []byte, hash []byte, derSig []byte) error { - rv := C.ecc_verify_der(C.int(curveCode), cmem(pubOct), cmem(hash), cmem(derSig)) - if rv != 0 { - return fmt.Errorf("ecdsa verify failed (%d)", int(rv)) - } - return nil -} diff --git a/demos-go/cb-mpc-go/internal/cgobinding/curve.h b/demos-go/cb-mpc-go/internal/cgobinding/curve.h deleted file mode 100644 index 71b71c4d..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/curve.h +++ /dev/null @@ -1,57 +0,0 @@ -#pragma once - -#include - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// ============ Curve Type Definitions ============= - -typedef struct ecc_point_ref { - void* opaque; -} ecc_point_ref; - -typedef struct ecurve_ref { - void* opaque; -} ecurve_ref; - -// ============ Curve Memory Management ============= - -void free_ecc_point(ecc_point_ref ref); -void free_ecurve(ecurve_ref ref); - -// ============ Curve Operations ============= - -// Curve functions -ecurve_ref new_ecurve(int curve_code); -ecc_point_ref ecurve_generator(ecurve_ref* curve); -cmem_t ecurve_order(ecurve_ref* curve); -int ecurve_get_curve_code(ecurve_ref* curve); - -// Point functions -ecc_point_ref ecc_point_from_bytes(cmem_t point_bytes); -cmem_t ecc_point_to_bytes(ecc_point_ref* point); -ecc_point_ref ecc_point_multiply(ecc_point_ref* point, cmem_t scalar); -ecc_point_ref ecc_point_add(ecc_point_ref* point1, ecc_point_ref* point2); -ecc_point_ref ecc_point_subtract(ecc_point_ref* point1, ecc_point_ref* point2); -cmem_t ecc_point_get_x(ecc_point_ref* point); -cmem_t ecc_point_get_y(ecc_point_ref* point); -int ecc_point_is_zero(ecc_point_ref* point); -int ecc_point_equals(ecc_point_ref* point1, ecc_point_ref* point2); -cmem_t ecurve_random_scalar(ecurve_ref* curve); - -// ECDSA verification (DER-encoded signature). Returns 0 on success, non-zero on failure. -int ecc_verify_der(int curve_code, cmem_t pub_oct, cmem_t hash, cmem_t der_sig); - -// Scalar operations -cmem_t bn_add(cmem_t a, cmem_t b); -cmem_t ec_mod_add(ecurve_ref* curve, cmem_t a, cmem_t b); -cmem_t bn_from_int64(int64_t value); -ecc_point_ref ecurve_mul_generator(ecurve_ref* curve, cmem_t scalar); - -#ifdef __cplusplus -} // extern "C" -#endif \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/ecdsa2p.cpp b/demos-go/cb-mpc-go/internal/cgobinding/ecdsa2p.cpp deleted file mode 100644 index e2eee2d3..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/ecdsa2p.cpp +++ /dev/null @@ -1,113 +0,0 @@ -#include "ecdsa2p.h" - -#include - -#include -#include -#include -#include -#include - -#include "curve.h" -#include "network.h" - -using namespace coinbase; -using namespace coinbase::mpc; - -int mpc_ecdsa2p_dkg(job_2p_ref* j, int curve_code, mpc_ecdsa2pc_key_ref* k) { - job_2p_t* job = static_cast(j->opaque); - ecurve_t curve = ecurve_t::find(curve_code); - - ecdsa2pc::key_t* key = new ecdsa2pc::key_t(); - - error_t err = ecdsa2pc::dkg(*job, curve, *key); - if (err) return err; - *k = mpc_ecdsa2pc_key_ref{key}; - - return 0; -} - -int mpc_ecdsa2p_refresh(job_2p_ref* j, mpc_ecdsa2pc_key_ref* k, mpc_ecdsa2pc_key_ref* nk) { - job_2p_t* job = static_cast(j->opaque); - - ecdsa2pc::key_t* key = static_cast(k->opaque); - ecdsa2pc::key_t* new_key = new ecdsa2pc::key_t(); - - error_t err = ecdsa2pc::refresh(*job, *key, *new_key); - if (err) return err; - *nk = mpc_ecdsa2pc_key_ref{new_key}; - - return 0; -} - -int mpc_ecdsa2p_sign(job_2p_ref* j, cmem_t sid_mem, mpc_ecdsa2pc_key_ref* k, cmems_t msgs, cmems_t* sigs) { - job_2p_t* job = static_cast(j->opaque); - ecdsa2pc::key_t* key = static_cast(k->opaque); - buf_t sid = coinbase::ffi::view(sid_mem); - // Reconstruct messages from cmems_t explicitly and copy into owned buffers - int count = msgs.count; - std::vector owned_msgs; - owned_msgs.reserve(count); - const uint8_t* p = msgs.data; - for (int i = 0; i < count; i++) { - int len = msgs.sizes ? msgs.sizes[i] : 0; - buf_t b(len); - if (len > 0) memcpy(b.data(), p, len); - owned_msgs.emplace_back(std::move(b)); - p += len; - } - std::vector messages(owned_msgs.size()); - for (size_t i = 0; i < owned_msgs.size(); i++) messages[i] = owned_msgs[i]; - - std::vector signatures; - error_t err = ecdsa2pc::sign_batch(*job, sid, *key, messages, signatures); - if (err) return err; - *sigs = coinbase::ffi::copy_to_cmems(buf_t::to_mems(signatures)); - - return 0; -} - -// ============ Memory Management ================= -void free_mpc_ecdsa2p_key(mpc_ecdsa2pc_key_ref ctx) { - if (ctx.opaque) { - delete static_cast(ctx.opaque); - } -} - -// ============ Accessors ========================= - -int mpc_ecdsa2p_key_get_role_index(mpc_ecdsa2pc_key_ref* key) { - if (key == NULL || key->opaque == NULL) { - return -1; // error: invalid key - } - ecdsa2pc::key_t* k = static_cast(key->opaque); - return static_cast(k->role); -} - -ecc_point_ref mpc_ecdsa2p_key_get_Q(mpc_ecdsa2pc_key_ref* key) { - if (key == NULL || key->opaque == NULL) { - return ecc_point_ref{nullptr}; - } - ecdsa2pc::key_t* k = static_cast(key->opaque); - ecc_point_t* Q_copy = new ecc_point_t(k->Q); // deep copy - return ecc_point_ref{Q_copy}; -} - -cmem_t mpc_ecdsa2p_key_get_x_share(mpc_ecdsa2pc_key_ref* key) { - if (key == NULL || key->opaque == NULL) { - return cmem_t{nullptr, 0}; - } - ecdsa2pc::key_t* k = static_cast(key->opaque); - // Serialize bn_t to bytes (minimal length) preserving order size - int bin_size = std::max(k->x_share.get_bin_size(), k->curve.order().get_bin_size()); - buf_t x_buf = k->x_share.to_bin(bin_size); - return coinbase::ffi::copy_to_cmem(x_buf); -} - -int mpc_ecdsa2p_key_get_curve_code(mpc_ecdsa2pc_key_ref* key) { - if (key == NULL || key->opaque == NULL) { - return -1; - } - ecdsa2pc::key_t* k = static_cast(key->opaque); - return k->curve.get_openssl_code(); -} \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/ecdsa2p.go b/demos-go/cb-mpc-go/internal/cgobinding/ecdsa2p.go deleted file mode 100644 index 1c0c42fd..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/ecdsa2p.go +++ /dev/null @@ -1,88 +0,0 @@ -package cgobinding - -/* -#include "ecdsa2p.h" -*/ -import "C" - -import ( - "fmt" - "runtime" -) - -type Mpc_ecdsa2pc_key_ref C.mpc_ecdsa2pc_key_ref - -// Free releases the underlying native key structure. -func (k *Mpc_ecdsa2pc_key_ref) Free() { - C.free_mpc_ecdsa2p_key(C.mpc_ecdsa2pc_key_ref(*k)) -} - -// DistributedKeyGen performs the two-party ECDSA DKG using a numeric curve code. -func DistributedKeyGen(job Job2P, curveCode int) (Mpc_ecdsa2pc_key_ref, error) { - var key Mpc_ecdsa2pc_key_ref - cErr := C.mpc_ecdsa2p_dkg(job.GetCJob(), C.int(curveCode), (*C.mpc_ecdsa2pc_key_ref)(&key)) - if cErr != 0 { - return key, fmt.Errorf("key generation failed, %v", cErr) - } - return key, nil -} - -// Refresh re-shares an existing 2-party ECDSA key. -func Refresh(job Job2P, key Mpc_ecdsa2pc_key_ref) (Mpc_ecdsa2pc_key_ref, error) { - var newKey Mpc_ecdsa2pc_key_ref - cErr := C.mpc_ecdsa2p_refresh(job.GetCJob(), (*C.mpc_ecdsa2pc_key_ref)(&key), (*C.mpc_ecdsa2pc_key_ref)(&newKey)) - if cErr != 0 { - return newKey, fmt.Errorf("ECDSA-2p refresh failed, %v", cErr) - } - return newKey, nil -} - -// Sign produces batch signatures using the two-party ECDSA key. -func Sign(job Job2P, sid []byte, key Mpc_ecdsa2pc_key_ref, msgs [][]byte) ([][]byte, error) { - var sigs CMEMS - pin := makeCmems(msgs) - cErr := C.mpc_ecdsa2p_sign(job.GetCJob(), cmem(sid), (*C.mpc_ecdsa2pc_key_ref)(&key), pin.c, &sigs) - runtime.KeepAlive(pin) - if cErr != 0 { - return nil, fmt.Errorf("ECDSA-2p sign failed, %v", cErr) - } - return CMEMSGet(sigs), nil -} - -// KeyRoleIndex returns the role index (e.g., 0 or 1) for the provided key share. -// A negative return value indicates an error at the native layer. -func KeyRoleIndex(key Mpc_ecdsa2pc_key_ref) (int, error) { - idx := int(C.mpc_ecdsa2p_key_get_role_index((*C.mpc_ecdsa2pc_key_ref)(&key))) - if idx < 0 { - return -1, fmt.Errorf("failed to get role index: %d", idx) - } - return idx, nil -} - -// KeyQ returns a reference to the public key point Q inside the 2PC key. -// The caller must eventually free the returned ECCPointRef via its Free method. -func KeyQ(key Mpc_ecdsa2pc_key_ref) (ECCPointRef, error) { - cPoint := C.mpc_ecdsa2p_key_get_Q((*C.mpc_ecdsa2pc_key_ref)(&key)) - if cPoint.opaque == nil { - return ECCPointRef{}, fmt.Errorf("failed to retrieve Q from key") - } - return ECCPointRef(cPoint), nil -} - -// KeyXShare returns the secret scalar share x_i as raw bytes (big-endian). -func KeyXShare(key Mpc_ecdsa2pc_key_ref) ([]byte, error) { - cMem := C.mpc_ecdsa2p_key_get_x_share((*C.mpc_ecdsa2pc_key_ref)(&key)) - if cMem.data == nil || cMem.size == 0 { - return nil, fmt.Errorf("failed to retrieve x_share from key") - } - return CMEMGet(cMem), nil -} - -// KeyCurveCode returns the OpenSSL NID of the curve used by the provided key. -func KeyCurveCode(key Mpc_ecdsa2pc_key_ref) (int, error) { - code := int(C.mpc_ecdsa2p_key_get_curve_code((*C.mpc_ecdsa2pc_key_ref)(&key))) - if code < 0 { - return 0, fmt.Errorf("failed to get curve code from key") - } - return code, nil -} diff --git a/demos-go/cb-mpc-go/internal/cgobinding/ecdsa2p.h b/demos-go/cb-mpc-go/internal/cgobinding/ecdsa2p.h deleted file mode 100644 index 5723d757..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/ecdsa2p.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include - -#include - -#include "curve.h" -#include "network.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// ------------------------- Type Wrappers --------------------------- -// Wrapper for coinbase::mpc::ecdsa2pc::key_t - -typedef struct mpc_ecdsa2pc_key_ref { - void* opaque; // Opaque pointer to the C++ class instance -} mpc_ecdsa2pc_key_ref; - -// ------------------------- Memory management ----------------------- -void free_mpc_ecdsa2p_key(mpc_ecdsa2pc_key_ref ctx); - -// ------------------------- Function Wrappers ----------------------- - -int mpc_ecdsa2p_dkg(job_2p_ref* job, int curve, mpc_ecdsa2pc_key_ref* key); - -int mpc_ecdsa2p_refresh(job_2p_ref* job, mpc_ecdsa2pc_key_ref* key, mpc_ecdsa2pc_key_ref* new_key); - -int mpc_ecdsa2p_sign(job_2p_ref* job, cmem_t sid, mpc_ecdsa2pc_key_ref* key, cmems_t msgs, cmems_t* sigs); - -// Returns the role index (e.g., 0 or 1) corresponding to the given key share. -// Returns a negative value on error. -int mpc_ecdsa2p_key_get_role_index(mpc_ecdsa2pc_key_ref* key); - -// Returns a freshly allocated copy of the public key point Q. -// The caller is responsible for freeing the returned ecc_point_ref via free_ecc_point. -ecc_point_ref mpc_ecdsa2p_key_get_Q(mpc_ecdsa2pc_key_ref* key); - -// Returns the secret share x_i of the private key as a byte slice (big-endian). -// The caller is responsible for freeing the returned memory via cgo_free. -cmem_t mpc_ecdsa2p_key_get_x_share(mpc_ecdsa2pc_key_ref* key); - -// Returns the numeric OpenSSL NID identifying the curve associated with the provided key. -// A negative value indicates an error. -int mpc_ecdsa2p_key_get_curve_code(mpc_ecdsa2pc_key_ref* key); - -#ifdef __cplusplus -} // extern "C" -#endif \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/ecdsamp.cpp b/demos-go/cb-mpc-go/internal/cgobinding/ecdsamp.cpp deleted file mode 100644 index 4270cc59..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/ecdsamp.cpp +++ /dev/null @@ -1,58 +0,0 @@ -// ecdsamp.cpp – Signing-only bindings (key management moved to eckeymp.cpp) - -#include "ecdsamp.h" - -#include - -#include -#include -#include -#include - -#include "curve.h" -#include "network.h" - -using namespace coinbase; -using namespace coinbase::mpc; - -// ----------------------------------------------------------------------------- -// ECDSA-MPC signing helpers -// ----------------------------------------------------------------------------- - -int mpc_ecdsampc_sign(job_mp_ref* j, mpc_eckey_mp_ref* k, cmem_t msg_mem, int sig_receiver, cmem_t* sig_mem) { - job_mp_t* job = static_cast(j->opaque); - ecdsampc::key_t* key = static_cast(k->opaque); - - buf_t msg = coinbase::ffi::view(msg_mem); - buf_t sig; - error_t err = ecdsampc::sign(*job, *key, msg, party_idx_t(sig_receiver), sig); - if (err) return err; - *sig_mem = coinbase::ffi::copy_to_cmem(sig); - return 0; -} - -int mpc_ecdsampc_sign_with_ot_roles(job_mp_ref* j, mpc_eckey_mp_ref* k, cmem_t msg_mem, int sig_receiver, - cmems_t ot_role_map, int n_parties, cmem_t* sig_mem) { - job_mp_t* job = static_cast(j->opaque); - ecdsampc::key_t* key = static_cast(k->opaque); - - buf_t msg = coinbase::ffi::view(msg_mem); - std::vector role_bufs = coinbase::ffi::bufs_from_cmems(ot_role_map); - std::vector> ot_roles(n_parties, std::vector(n_parties)); - - for (int i = 0; i < n_parties; i++) { - if (i < role_bufs.size()) { - const uint8_t* data = role_bufs[i].data(); - for (int j = 0; j < n_parties && j * sizeof(int) < role_bufs[i].size(); j++) { - memcpy(&ot_roles[i][j], data + j * sizeof(int), sizeof(int)); - } - } - } - - buf_t sig; - error_t err = ecdsampc::sign(*job, *key, msg, party_idx_t(sig_receiver), ot_roles, sig); - if (err) return err; - - *sig_mem = coinbase::ffi::copy_to_cmem(sig); - return 0; -} \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/ecdsamp.go b/demos-go/cb-mpc-go/internal/cgobinding/ecdsamp.go deleted file mode 100644 index d2bcacd7..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/ecdsamp.go +++ /dev/null @@ -1,77 +0,0 @@ -package cgobinding - -/* -#include "ecdsamp.h" -#include "eckeymp.h" -*/ -import "C" - -import ( - "fmt" -) - -func MPC_ecdsampc_sign(job JobMP, key Mpc_eckey_mp_ref, msgMem []byte, sigReceiver int) ([]byte, error) { - var sigMem CMEM - cErr := C.mpc_ecdsampc_sign(job.GetCJob(), (*C.mpc_eckey_mp_ref)(&key), cmem(msgMem), C.int(sigReceiver), &sigMem) - if cErr != 0 { - return nil, fmt.Errorf("ECDSA-mp sign failed, %v", cErr) - } - return CMEMGet(sigMem), nil -} - -// ----------------------------------------------------------------------------- -// Signing with default OT-role map -// ----------------------------------------------------------------------------- - -func DefaultOTRoleMap(nParties int) [][]int { - const ( - OT_NO_ROLE = -1 - OT_SENDER = 0 - OT_RECEIVER = 1 - ) - - otRoleMap := make([][]int, nParties) - for i := 0; i < nParties; i++ { - otRoleMap[i] = make([]int, nParties) - otRoleMap[i][i] = OT_NO_ROLE - } - - for i := 0; i < nParties; i++ { - for j := i + 1; j < nParties; j++ { - otRoleMap[i][j] = OT_SENDER - otRoleMap[j][i] = OT_RECEIVER - } - } - - return otRoleMap -} - -func MPC_ecdsampc_sign_default_ot_roles(job JobMP, key Mpc_eckey_mp_ref, msgMem []byte, sigReceiver int, nParties int) ([]byte, error) { - otRoleMap := DefaultOTRoleMap(nParties) - // Convert OT role map to the required format (flattened byte slices) - roleData := make([][]byte, nParties) - for i := 0; i < nParties; i++ { - roleData[i] = make([]byte, nParties*4) // 4 bytes per int (little endian) - for j := 0; j < nParties && j < len(otRoleMap[i]); j++ { - role := otRoleMap[i][j] - roleData[i][j*4+0] = byte(role) - roleData[i][j*4+1] = byte(role >> 8) - roleData[i][j*4+2] = byte(role >> 16) - roleData[i][j*4+3] = byte(role >> 24) - } - } - - var sigMem CMEM - cErr := C.mpc_ecdsampc_sign_with_ot_roles( - job.GetCJob(), - (*C.mpc_eckey_mp_ref)(&key), - cmem(msgMem), - C.int(sigReceiver), - cmems(roleData), - C.int(nParties), - &sigMem) - if cErr != 0 { - return nil, fmt.Errorf("ECDSA-mp sign with OT roles failed, %v", cErr) - } - return CMEMGet(sigMem), nil -} diff --git a/demos-go/cb-mpc-go/internal/cgobinding/ecdsamp.h b/demos-go/cb-mpc-go/internal/cgobinding/ecdsamp.h deleted file mode 100644 index fbaea7c6..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/ecdsamp.h +++ /dev/null @@ -1,18 +0,0 @@ -// ecdsamp.h – Signing-only C interface (key management moved to eckeymp.h) -#pragma once - -#include "eckeymp.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// Other ECDSA-MPC protocols are in eckeymp.h, to be shared with EdDSA-MPC protocols -int mpc_ecdsampc_sign(job_mp_ref* j, mpc_eckey_mp_ref* k, cmem_t msg_mem, int sig_receiver, cmem_t* sig_mem); - -int mpc_ecdsampc_sign_with_ot_roles(job_mp_ref* j, mpc_eckey_mp_ref* k, cmem_t msg_mem, int sig_receiver, - cmems_t ot_role_map, int n_parties, cmem_t* sig_mem); - -#ifdef __cplusplus -} // extern "C" -#endif \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/eckeymp.cpp b/demos-go/cb-mpc-go/internal/cgobinding/eckeymp.cpp deleted file mode 100644 index d6693fff..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/eckeymp.cpp +++ /dev/null @@ -1,198 +0,0 @@ -#include "eckeymp.h" - -#include - -#include -#include -#include -#include -#include - -#include "curve.h" -#include "network.h" - -using namespace coinbase; -using namespace coinbase::crypto; -using namespace coinbase::mpc; - -// ------------------------- Memory helpers --------------------------- -void free_mpc_eckey_mp(mpc_eckey_mp_ref ctx) { - if (ctx.opaque) { - delete static_cast(ctx.opaque); - } -} - -// --------------------------- Field accessors ----------------------- -int mpc_eckey_mp_get_party_name(mpc_eckey_mp_ref* k, cmem_t* party_name_mem) { - if (k == nullptr || k->opaque == nullptr) { - return 1; // Invalid key reference - } - - eckey::key_share_mp_t* key = static_cast(k->opaque); - *party_name_mem = coinbase::ffi::copy_to_cmem(coinbase::mem_t(key->party_name)); - return 0; -} - -int mpc_eckey_mp_get_x_share(mpc_eckey_mp_ref* k, cmem_t* x_share_mem) { - if (k == nullptr || k->opaque == nullptr) { - return 1; // Invalid key reference - } - eckey::key_share_mp_t* key = static_cast(k->opaque); - buf_t x_buf = key->x_share.to_bin(key->curve.order().get_bin_size()); - *x_share_mem = coinbase::ffi::copy_to_cmem(x_buf); - return 0; -} - -ecc_point_ref mpc_eckey_mp_get_Q(mpc_eckey_mp_ref* k) { - if (k == nullptr || k->opaque == nullptr) { - return ecc_point_ref{nullptr}; - } - eckey::key_share_mp_t* key = static_cast(k->opaque); - ecc_point_t* Q_copy = new ecc_point_t(key->Q); - return ecc_point_ref{Q_copy}; -} - -ecurve_ref mpc_eckey_mp_get_curve(mpc_eckey_mp_ref* k) { - if (k == nullptr || k->opaque == nullptr) { - return ecurve_ref{nullptr}; - } - eckey::key_share_mp_t* key = static_cast(k->opaque); - // Allocate a copy so the caller can own it independently. - ecurve_t* curve_copy = new ecurve_t(key->curve); - return ecurve_ref{curve_copy}; -} - -int mpc_eckey_mp_get_Qis(mpc_eckey_mp_ref* k, cmems_t* party_names_mem, cmems_t* points_mem) { - if (k == nullptr || k->opaque == nullptr) { - return 1; // Invalid key reference - } - eckey::key_share_mp_t* key = static_cast(k->opaque); - - std::vector name_bufs; - std::vector point_bufs; - name_bufs.reserve(key->Qis.size()); - point_bufs.reserve(key->Qis.size()); - - for (const auto& kv : key->Qis) { - name_bufs.emplace_back(coinbase::mem_t(kv.first)); - point_bufs.push_back(coinbase::ser(kv.second)); - } - - *party_names_mem = coinbase::ffi::copy_to_cmems(buf_t::to_mems(name_bufs)); - *points_mem = coinbase::ffi::copy_to_cmems(buf_t::to_mems(point_bufs)); - return 0; -} - -// ------------------------- Protocols ----------------------------------------- -int mpc_eckey_mp_dkg(job_mp_ref* j, ecurve_ref* curve_ref, mpc_eckey_mp_ref* k) { - job_mp_t* job = static_cast(j->opaque); - ecurve_t* curve_ptr = static_cast(curve_ref->opaque); - if (curve_ptr == nullptr) { - return 1; // Invalid curve reference - } - - // Allocate key on the heap – ensure we release it on failure to avoid leaks. - std::unique_ptr key(new eckey::key_share_mp_t()); - - buf_t sid; - error_t err = eckey::key_share_mp_t::dkg(*job, *curve_ptr, *key, sid); - if (err) { - return err; // unique_ptr automatically frees memory - } - - // Transfer ownership to the caller – release smart pointer so object lives on. - *k = mpc_eckey_mp_ref{key.release()}; - return 0; -} - -int mpc_eckey_mp_refresh(job_mp_ref* j, cmem_t sid_mem, mpc_eckey_mp_ref* k, mpc_eckey_mp_ref* nk) { - job_mp_t* job = static_cast(j->opaque); - eckey::key_share_mp_t* key = static_cast(k->opaque); - - // Allocate new key with automatic cleanup on error. - std::unique_ptr new_key(new eckey::key_share_mp_t()); - - buf_t sid = coinbase::ffi::view(sid_mem); - error_t err = eckey::key_share_mp_t::refresh(*job, sid, *key, *new_key); - if (err) { - return err; // unique_ptr frees memory - } - - *nk = mpc_eckey_mp_ref{new_key.release()}; - return 0; -} - -// ------------------- Threshold / Quorum helpers -------------------- -int eckey_dkg_mp_threshold_dkg(job_mp_ref* job_ptr, ecurve_ref* curve_ref, cmem_t sid, crypto_ss_ac_ref* ac, - mpc_party_set_ref* quorum, mpc_eckey_mp_ref* key) { - job_mp_t* job = static_cast(job_ptr->opaque); - ecurve_t* curve_ptr = static_cast(curve_ref->opaque); - if (curve_ptr == nullptr) { - return 1; // Invalid curve reference - } - - buf_t sid_buf = coinbase::ffi::view(sid); - crypto::ss::ac_t* ac_obj = static_cast(ac->opaque); - party_set_t* quorum_set = static_cast(quorum->opaque); - - // Allocate key share with RAII – will auto free on early return. - std::unique_ptr key_share(new eckey::key_share_mp_t()); - error_t err = eckey::key_share_mp_t::threshold_dkg(*job, *curve_ptr, sid_buf, *ac_obj, *quorum_set, *key_share); - if (err) { - return err; // unique_ptr cleans up - } - - *key = mpc_eckey_mp_ref{key_share.release()}; - return 0; -} - -int eckey_key_share_mp_to_additive_share(mpc_eckey_mp_ref* key, crypto_ss_ac_ref* ac, cmems_t quorum_party_names, - mpc_eckey_mp_ref* additive_key) { - eckey::key_share_mp_t* key_share = static_cast(key->opaque); - crypto::ss::ac_t* ac_obj = static_cast(ac->opaque); - - std::vector name_bufs = coinbase::ffi::bufs_from_cmems(quorum_party_names); - std::set quorum_names; - for (const auto& name_buf : name_bufs) { - quorum_names.insert(name_buf.to_string()); - } - - // Allocate additive share with RAII to avoid leaks on error. - std::unique_ptr additive_share(new eckey::key_share_mp_t()); - error_t err = key_share->to_additive_share(*ac_obj, quorum_names, *additive_share); - if (err) { - return err; // unique_ptr cleans up automatically - } - - *additive_key = mpc_eckey_mp_ref{additive_share.release()}; - return 0; -} - -// --------------------------- Utilities ----------------------------- -int serialize_mpc_eckey_mp(mpc_eckey_mp_ref* k, cmems_t* ser) { - eckey::key_share_mp_t* key = static_cast(k->opaque); - - auto x = coinbase::ser(key->x_share); - auto Q = coinbase::ser(key->Q); - auto Qis = coinbase::ser(key->Qis); - auto curve = coinbase::ser(key->curve); - auto party_name = coinbase::ser(key->party_name); - - auto out = std::vector{x, Q, Qis, curve, party_name}; - *ser = coinbase::ffi::copy_to_cmems(out); - return 0; -} - -int deserialize_mpc_eckey_mp(cmems_t sers, mpc_eckey_mp_ref* k) { - std::unique_ptr key(new eckey::key_share_mp_t()); - std::vector sers_vec = coinbase::ffi::bufs_from_cmems(sers); - - if (coinbase::deser(sers_vec[0], key->x_share)) return 1; - if (coinbase::deser(sers_vec[1], key->Q)) return 1; - if (coinbase::deser(sers_vec[2], key->Qis)) return 1; - if (coinbase::deser(sers_vec[3], key->curve)) return 1; - if (coinbase::deser(sers_vec[4], key->party_name)) return 1; - - *k = mpc_eckey_mp_ref{key.release()}; - return 0; -} \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/eckeymp.go b/demos-go/cb-mpc-go/internal/cgobinding/eckeymp.go deleted file mode 100644 index d4412b27..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/eckeymp.go +++ /dev/null @@ -1,245 +0,0 @@ -package cgobinding - -/* -#cgo CFLAGS: -Werror -#include "eckeymp.h" -*/ -import "C" - -import "fmt" - -type Mpc_eckey_mp_ref C.mpc_eckey_mp_ref - -// SerializeKeyShare converts an mpc_eckey_mp_ref into a slice of byte buffers -// that fully represent the secret-share. The data is suitable for short-term -// transport or caching. It should NOT be relied upon for long-term storage -// across library versions. -func SerializeKeyShare(key Mpc_eckey_mp_ref) ([][]byte, error) { - var ser CMEMS - err := C.serialize_mpc_eckey_mp((*C.mpc_eckey_mp_ref)(&key), &ser) - if err != 0 { - return nil, fmt.Errorf("serialize_mpc_eckey_mp failed: %v", err) - } - return CMEMSGet(ser), nil -} - -// DeserializeKeyShare allocates a fresh key-share object from the byte buffers -// produced by SerializeKeyShare and returns a reference to it. -func DeserializeKeyShare(ser [][]byte) (Mpc_eckey_mp_ref, error) { - var key Mpc_eckey_mp_ref - err := C.deserialize_mpc_eckey_mp(cmems(ser), (*C.mpc_eckey_mp_ref)(&key)) - if err != 0 { - return Mpc_eckey_mp_ref{}, fmt.Errorf("deserialize_mpc_eckey_mp failed: %v", err) - } - return key, nil -} - -// ----------------------------------------------------------------------------- -// Backwards-compatibility thin wrappers (deprecated) -// ----------------------------------------------------------------------------- - -// SerializeECDSAShare is kept for historical reasons. New code should migrate -// to SerializeKeyShare. -func SerializeECDSAShare(key Mpc_eckey_mp_ref) ([][]byte, error) { return SerializeKeyShare(key) } - -// DeserializeECDSAShare is kept for historical reasons. New code should migrate -// to DeserializeKeyShare. -func DeserializeECDSAShare(ser [][]byte) (Mpc_eckey_mp_ref, error) { - return DeserializeKeyShare(ser) -} - -// KeyShareDKG performs distributed key generation and returns a key-share -// reference. It is algorithm agnostic – it only depends on the underlying -// Schnorr-style key-share representation. -func KeyShareDKG(job JobMP, curveRef ECurveRef) (Mpc_eckey_mp_ref, error) { - var key Mpc_eckey_mp_ref - cErr := C.mpc_eckey_mp_dkg(job.GetCJob(), (*C.ecurve_ref)(&curveRef), (*C.mpc_eckey_mp_ref)(&key)) - if cErr != 0 { - return key, fmt.Errorf("key-share DKG failed, %v", cErr) - } - return key, nil -} - -// KeyShareDKGCode is a convenience wrapper that takes a curve code instead of a native curve ref. -func KeyShareDKGCode(job JobMP, curveCode int) (Mpc_eckey_mp_ref, error) { - ref, err := ECurveFind(curveCode) - if err != nil { - return Mpc_eckey_mp_ref{}, err - } - return KeyShareDKG(job, ref) -} - -// KeyShareRefresh rerandomises the secret shares while keeping the aggregated -// public key unchanged. -func KeyShareRefresh(job JobMP, sid []byte, key Mpc_eckey_mp_ref) (Mpc_eckey_mp_ref, error) { - if sid == nil { - sid = make([]byte, 0) - } - var newKey Mpc_eckey_mp_ref - cErr := C.mpc_eckey_mp_refresh(job.GetCJob(), cmem(sid), (*C.mpc_eckey_mp_ref)(&key), (*C.mpc_eckey_mp_ref)(&newKey)) - if cErr != 0 { - return newKey, fmt.Errorf("key-share refresh failed, %v", cErr) - } - return newKey, nil -} - -// ThresholdDKG runs a threshold Distributed Key Generation for Schnorr-style -// keys (used by both ECDSA-MPC and EdDSA-MPC). It returns a fresh key share -// owned by the calling party. -func ThresholdDKG(job JobMP, curveRef ECurveRef, sid []byte, ac C_AcPtr, roleIndices []int) (Mpc_eckey_mp_ref, error) { - if sid == nil { - sid = make([]byte, 0) - } - - quorum := NewPartySet() - defer quorum.Free() - for _, idx := range roleIndices { - quorum.Add(idx) - } - - var key Mpc_eckey_mp_ref - cErr := C.eckey_dkg_mp_threshold_dkg( - job.GetCJob(), - (*C.ecurve_ref)(&curveRef), - cmem(sid), - (*C.crypto_ss_ac_ref)(&ac), - (*C.mpc_party_set_ref)(&quorum), - (*C.mpc_eckey_mp_ref)(&key)) - if cErr != 0 { - return key, fmt.Errorf("threshold DKG failed, %v", cErr) - } - return key, nil -} - -// ThresholdDKGCode mirrors ThresholdDKG but accepts a curve code. -func ThresholdDKGCode(job JobMP, curveCode int, sid []byte, ac C_AcPtr, roleIndices []int) (Mpc_eckey_mp_ref, error) { - ref, err := ECurveFind(curveCode) - if err != nil { - return Mpc_eckey_mp_ref{}, err - } - return ThresholdDKG(job, ref, sid, ac, roleIndices) -} - -// Back-compat synonym. -func KeyShareThresholdDKG(job JobMP, curveRef ECurveRef, sid []byte, ac C_AcPtr, roleIndices []int) (Mpc_eckey_mp_ref, error) { - return ThresholdDKG(job, curveRef, sid, ac, roleIndices) -} - -// ToAdditiveShare converts a multiplicative share into an additive one under the given access structure and quorum names. -func (key *Mpc_eckey_mp_ref) ToAdditiveShare(ac C_AcPtr, quorumPartyNames []string) (Mpc_eckey_mp_ref, error) { - var additiveKey Mpc_eckey_mp_ref - - nameBytes := make([][]byte, len(quorumPartyNames)) - for i, name := range quorumPartyNames { - nameBytes[i] = []byte(name) - } - - cErr := C.eckey_key_share_mp_to_additive_share( - (*C.mpc_eckey_mp_ref)(key), - (*C.crypto_ss_ac_ref)(&ac), - cmems(nameBytes), - (*C.mpc_eckey_mp_ref)(&additiveKey)) - if cErr != 0 { - return additiveKey, fmt.Errorf("to_additive_share failed, %v", cErr) - } - return additiveKey, nil -} - -// ----------------------------------------------------------------------------- -// Accessors (shared between ECDSA-MPC and EdDSA-MPC) -// ----------------------------------------------------------------------------- - -// KeySharePartyName returns the party identifier associated with the key share. -func KeySharePartyName(key Mpc_eckey_mp_ref) (string, error) { - var nameMem CMEM - err := C.mpc_eckey_mp_get_party_name((*C.mpc_eckey_mp_ref)(&key), &nameMem) - if err != 0 { - return "", fmt.Errorf("getting party name failed, %v", err) - } - return string(CMEMGet(nameMem)), nil -} - -// KeyShareXShare returns the secret scalar held by this party. -func KeyShareXShare(key Mpc_eckey_mp_ref) ([]byte, error) { - var xShareMem CMEM - err := C.mpc_eckey_mp_get_x_share((*C.mpc_eckey_mp_ref)(&key), &xShareMem) - if err != 0 { - return nil, fmt.Errorf("getting x_share failed, %v", err) - } - return CMEMGet(xShareMem), nil -} - -// KeyShareQ returns a reference to the aggregated public key point Q. -func KeyShareQ(key Mpc_eckey_mp_ref) (ECCPointRef, error) { - cPoint := C.mpc_eckey_mp_get_Q((*C.mpc_eckey_mp_ref)(&key)) - if cPoint.opaque == nil { - return ECCPointRef{}, fmt.Errorf("failed to retrieve Q from key") - } - return ECCPointRef(cPoint), nil -} - -// KeyShareCurve returns the curve associated with the key share. -func KeyShareCurve(key Mpc_eckey_mp_ref) (ECurveRef, error) { - cRef := C.mpc_eckey_mp_get_curve((*C.mpc_eckey_mp_ref)(&key)) - if cRef.opaque == nil { - return ECurveRef{}, fmt.Errorf("failed to get curve from key") - } - return ECurveRef(cRef), nil -} - -// KeyShareQis returns per-party public key shares. -func KeyShareQis(key Mpc_eckey_mp_ref) ([][]byte, [][]byte, error) { - var nameMems CMEMS - var pointMems CMEMS - cErr := C.mpc_eckey_mp_get_Qis((*C.mpc_eckey_mp_ref)(&key), &nameMems, &pointMems) - if cErr != 0 { - return nil, nil, fmt.Errorf("getting Qis failed, %v", cErr) - } - names := CMEMSGet(nameMems) - points := CMEMSGet(pointMems) - if len(names) != len(points) { - return nil, nil, fmt.Errorf("inconsistent Qis arrays: %d names vs %d points", len(names), len(points)) - } - return names, points, nil -} - -// KeyShareCurveCode returns the numeric curve code associated with the key share. -func KeyShareCurveCode(key Mpc_eckey_mp_ref) (int, error) { - ref, err := KeyShareCurve(key) - if err != nil { - return 0, err - } - return ECurveGetCurveCode(ref), nil -} - -// KeyShareQBytes returns the Q point as bytes. -func KeyShareQBytes(key Mpc_eckey_mp_ref) ([]byte, error) { - ref, err := KeyShareQ(key) - if err != nil { - return nil, err - } - bytes := ECCPointToBytes(ref) - (&ref).Free() - return bytes, nil -} - -// Free releases the underlying native key-share object. -func (k *Mpc_eckey_mp_ref) Free() { - C.free_mpc_eckey_mp(C.mpc_eckey_mp_ref(*k)) -} - -// ----------------------------------------------------------------------------- -// Back-compat wrappers with legacy names -// ----------------------------------------------------------------------------- - -func MPC_mpc_eckey_mp_get_party_name(key Mpc_eckey_mp_ref) (string, error) { - return KeySharePartyName(key) -} -func MPC_mpc_eckey_mp_get_x_share(key Mpc_eckey_mp_ref) ([]byte, error) { - return KeyShareXShare(key) -} -func MPC_mpc_eckey_mp_Q(key Mpc_eckey_mp_ref) (ECCPointRef, error) { return KeyShareQ(key) } -func MPC_mpc_eckey_mp_curve(key Mpc_eckey_mp_ref) (ECurveRef, error) { return KeyShareCurve(key) } -func MPC_mpc_eckey_mp_Qis(key Mpc_eckey_mp_ref) ([][]byte, [][]byte, error) { - return KeyShareQis(key) -} diff --git a/demos-go/cb-mpc-go/internal/cgobinding/eckeymp.h b/demos-go/cb-mpc-go/internal/cgobinding/eckeymp.h deleted file mode 100644 index c4a95523..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/eckeymp.h +++ /dev/null @@ -1,51 +0,0 @@ -#pragma once - -#include - -#include - -#include "ac.h" -#include "curve.h" -#include "network.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// ----------------------------------------------------------------------------- -// Common opaque key reference (same as original definition) -// ----------------------------------------------------------------------------- -typedef struct mpc_eckey_mp_ref { - void* opaque; -} mpc_eckey_mp_ref; - -// ------------------------- Memory helpers ------------------------------------ -void free_mpc_eckey_mp(mpc_eckey_mp_ref ctx); - -// --------------------------- Field accessors --------------------------------- -int mpc_eckey_mp_get_party_name(mpc_eckey_mp_ref* k, cmem_t* party_name_mem); -int mpc_eckey_mp_get_x_share(mpc_eckey_mp_ref* k, cmem_t* x_share_mem); -// Returns a newly allocated ecc_point_t copy – caller must free with free_ecc_point -// (see curve.h). -ecc_point_ref mpc_eckey_mp_get_Q(mpc_eckey_mp_ref* k); -ecurve_ref mpc_eckey_mp_get_curve(mpc_eckey_mp_ref* k); -int mpc_eckey_mp_get_Qis(mpc_eckey_mp_ref* k, cmems_t* party_names_mem, cmems_t* points_mem); - -// ------------------------- Protocols ----------------------------------------- -int mpc_eckey_mp_dkg(job_mp_ref* j, ecurve_ref* curve, mpc_eckey_mp_ref* k); -int mpc_eckey_mp_refresh(job_mp_ref* j, cmem_t sid_mem, mpc_eckey_mp_ref* k, mpc_eckey_mp_ref* new_key); - -// --------------------- Threshold / Quorum helpers --------------------------- -int eckey_dkg_mp_threshold_dkg(job_mp_ref* job, ecurve_ref* curve, cmem_t sid, crypto_ss_ac_ref* ac, - mpc_party_set_ref* quorum, mpc_eckey_mp_ref* key); - -int eckey_key_share_mp_to_additive_share(mpc_eckey_mp_ref* key, crypto_ss_ac_ref* ac, cmems_t quorum_party_names, - mpc_eckey_mp_ref* additive_key); - -// ------------------------- Utilities ----------------------------------------- -int serialize_mpc_eckey_mp(mpc_eckey_mp_ref* k, cmems_t* ser); -int deserialize_mpc_eckey_mp(cmems_t ser, mpc_eckey_mp_ref* k); - -#ifdef __cplusplus -} // extern "C" -#endif \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/eddsamp.cpp b/demos-go/cb-mpc-go/internal/cgobinding/eddsamp.cpp deleted file mode 100644 index 1bc7d996..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/eddsamp.cpp +++ /dev/null @@ -1,32 +0,0 @@ -// eddsamp.cpp – Signing-only bindings for EdDSA multi-party - -#include "eddsamp.h" - -#include - -#include -#include -#include -#include - -#include "curve.h" -#include "network.h" - -using namespace coinbase; -using namespace coinbase::mpc; - -// ----------------------------------------------------------------------------- -// EdDSA-MPC signing helper -// ----------------------------------------------------------------------------- - -int mpc_eddsampc_sign(job_mp_ref* j, mpc_eckey_mp_ref* k, cmem_t msg_mem, int sig_receiver, cmem_t* sig_mem) { - job_mp_t* job = static_cast(j->opaque); - eddsampc::key_t* key = static_cast(k->opaque); - - buf_t msg = coinbase::ffi::view(msg_mem); - buf_t sig; - error_t err = eddsampc::sign(*job, *key, msg, party_idx_t(sig_receiver), sig); - if (err) return err; - *sig_mem = coinbase::ffi::copy_to_cmem(sig); - return 0; -} \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/eddsamp.go b/demos-go/cb-mpc-go/internal/cgobinding/eddsamp.go deleted file mode 100644 index 036bded6..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/eddsamp.go +++ /dev/null @@ -1,25 +0,0 @@ -package cgobinding - -/* -#include "eddsamp.h" -*/ -import "C" - -import "fmt" - -// ----------------------------------------------------------------------------- -// EdDSA-MPC signing binding (key management shared with ECDSA) -// ----------------------------------------------------------------------------- - -type Mpc_eddsampc_key_ref = Mpc_eckey_mp_ref // underlying type is identical - -// MPC_eddsampc_sign performs the N-party EdDSA signing protocol. -// It mirrors MPC_ecdsampc_sign but uses the EdDSA Schnorr variant internally. -func MPC_eddsampc_sign(job JobMP, key Mpc_eckey_mp_ref, msgMem []byte, sigReceiver int) ([]byte, error) { - var sigMem CMEM - cErr := C.mpc_eddsampc_sign(job.GetCJob(), (*C.mpc_eckey_mp_ref)(&key), cmem(msgMem), C.int(sigReceiver), &sigMem) - if cErr != 0 { - return nil, fmt.Errorf("EdDSA-mp sign failed, %v", cErr) - } - return CMEMGet(sigMem), nil -} diff --git a/demos-go/cb-mpc-go/internal/cgobinding/eddsamp.h b/demos-go/cb-mpc-go/internal/cgobinding/eddsamp.h deleted file mode 100644 index 0b010594..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/eddsamp.h +++ /dev/null @@ -1,15 +0,0 @@ -// eddsamp.h – Signing-only C interface for EdDSA multi-party -#pragma once - -#include "eckeymp.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// EdDSA-MPC signing API (other key-management functions are shared via eckeymp.h) -int mpc_eddsampc_sign(job_mp_ref* j, mpc_eckey_mp_ref* k, cmem_t msg_mem, int sig_receiver, cmem_t* sig_mem); - -#ifdef __cplusplus -} // extern "C" -#endif \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/kem.h b/demos-go/cb-mpc-go/internal/cgobinding/kem.h deleted file mode 100644 index a5080670..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/kem.h +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once - -#include - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// KEM-specific context-based PKI callbacks -typedef int (*kem_encap_ctx_fn)(void* ctx, cmem_t /* ek_bytes */, cmem_t /* rho */, cmem_t* /* kem_ct out */, - cmem_t* /* kem_ss out */); - -// Private key is passed as an opaque handle owned by the caller. For byte-based -// keys, the handle points to a cmem_t describing the bytes for the duration of -// the call. -typedef int (*kem_decap_ctx_fn)(void* ctx, const void* /* dk_handle */, cmem_t /* kem_ct */, cmem_t* /* kem_ss out */); - -typedef int (*kem_dk_to_ek_ctx_fn)(void* ctx, const void* /* dk_handle */, cmem_t* /* out ek_bytes */); - -#ifdef __cplusplus -} // extern "C" -#endif \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/network.cpp b/demos-go/cb-mpc-go/internal/cgobinding/network.cpp deleted file mode 100644 index 2559478d..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/network.cpp +++ /dev/null @@ -1,342 +0,0 @@ -#include "network.h" - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -using namespace coinbase; -using namespace coinbase::mpc; - -namespace { -constexpr int SUCCESS_CODE = 0; -constexpr int ERROR_CODE = -1; -constexpr int PARAM_ERROR_CODE = -2; - -// Helper function to validate party names -bool validate_party_names(const char* const* pnames, int count) noexcept { - if (!pnames) return false; - for (int i = 0; i < count; ++i) { - if (!pnames[i] || std::string_view(pnames[i]).empty()) { - return false; - } - } - return true; -} - -// Helper functions to validate and dereference pointers before construction -const data_transport_callbacks_t& validate_and_deref_callbacks(const data_transport_callbacks_t* callbacks_ptr) { - if (!callbacks_ptr) { - throw std::invalid_argument("callbacks_ptr cannot be null"); - } - if (!callbacks_ptr->send_fun || !callbacks_ptr->receive_fun || !callbacks_ptr->receive_all_fun) { - throw std::invalid_argument("all callback functions must be provided"); - } - return *callbacks_ptr; -} - -void* validate_go_impl_ptr(void* go_impl_ptr) { - if (!go_impl_ptr) { - throw std::invalid_argument("go_impl_ptr cannot be null"); - } - return go_impl_ptr; -} - -// RAII wrapper for job references -template -struct JobDeleter { - void operator()(JobType* job) const noexcept { - if constexpr (std::is_same_v) { - free_job_2p(job); - } else { - free_job_mp(job); - } - } -}; - -template -using unique_job_ptr = std::unique_ptr>; -} // namespace - -void free_job_2p(job_2p_ref* ptr) { - if (!ptr) return; - - if (ptr->opaque) { - try { - delete static_cast(ptr->opaque); - } catch (const std::exception& e) { - std::cerr << "Error freeing job_2p: " << e.what() << std::endl; - } - ptr->opaque = nullptr; - } - delete ptr; -} - -void free_job_mp(job_mp_ref* ptr) { - if (!ptr) return; - - if (ptr->opaque) { - try { - delete static_cast(ptr->opaque); - } catch (const std::exception& e) { - std::cerr << "Error freeing job_mp: " << e.what() << std::endl; - } - ptr->opaque = nullptr; - } - delete ptr; -} - -class callback_data_transport_t : public data_transport_interface_t { - private: - const data_transport_callbacks_t callbacks; - void* const go_impl_ptr; - - public: - callback_data_transport_t(const data_transport_callbacks_t* callbacks_ptr, void* go_impl_ptr) - : callbacks(validate_and_deref_callbacks(callbacks_ptr)), go_impl_ptr(validate_go_impl_ptr(go_impl_ptr)) { - // Validation is now done safely in the helper functions before dereferencing - } - - error_t send(const party_idx_t receiver, mem_t msg) override { - cmem_t cmsg{msg.data, msg.size}; - int result = callbacks.send_fun(go_impl_ptr, receiver, cmsg); - return error_t(result); - } - - error_t receive(const party_idx_t sender, buf_t& msg) override { - cmem_t cmsg{nullptr, 0}; - error_t rv = UNINITIALIZED_ERROR; - if (rv = error_t(callbacks.receive_fun(go_impl_ptr, sender, &cmsg))) return rv; - msg = coinbase::ffi::copy_from_cmem_and_free(cmsg); - return SUCCESS; - } - - error_t receive_all(const std::vector& senders, std::vector& msgs) override { - const auto n = static_cast(senders.size()); - if (n == 0) { - msgs.clear(); - return SUCCESS; - } - - // Use stack allocation for small arrays, heap for larger ones - constexpr int STACK_THRESHOLD = 64; - std::vector c_senders; - c_senders.reserve(n); - - for (const auto sender : senders) { - c_senders.push_back(sender); - } - - // Ensure cmsgs is initialized so that error paths never leave us with - // uninitialized pointers/count. - cmems_t cmsgs{0, nullptr, nullptr}; - const int result = callbacks.receive_all_fun(go_impl_ptr, const_cast(c_senders.data()), n, &cmsgs); - if (error_t rv = error_t(result)) { - msgs.clear(); - return rv; - } - - // Copy results out of the cgo-owned buffers, then free them (Go side uses C.malloc). - msgs = coinbase::ffi::bufs_from_cmems(cmsgs); - cgo_free(cmsgs.data); - cgo_free(cmsgs.sizes); - return SUCCESS; - } -}; - -job_2p_ref* new_job_2p(const data_transport_callbacks_t* callbacks, void* go_impl_ptr, int index, - const char* const* pnames, int pname_count) { - // Input validation with specific error codes - if (pname_count != 2) { - std::cerr << "Error: expected exactly 2 pnames, got " << pname_count << std::endl; - return nullptr; - } - - if (!callbacks || !go_impl_ptr) { - std::cerr << "Error: null parameters passed to new_job_2p" << std::endl; - return nullptr; - } - - if (!validate_party_names(pnames, pname_count)) { - std::cerr << "Error: invalid party names" << std::endl; - return nullptr; - } - - try { - auto data_transport_ptr = std::make_shared(callbacks, go_impl_ptr); - auto job_impl = - std::make_unique(party_t(index), std::string(pnames[0]), std::string(pnames[1]), data_transport_ptr); - - auto result = std::make_unique(); - result->opaque = job_impl.release(); - return result.release(); - - } catch (const std::exception& e) { - std::cerr << "Error creating job_2p: " << e.what() << std::endl; - return nullptr; - } -} - -#define VALIDATE_JOB_2P(job) \ - do { \ - if (!job || !job->opaque) { \ - return NETWORK_INVALID_STATE; \ - } \ - } while (0) - -#define GET_JOB_2P(job) static_cast(job->opaque) - -int is_peer1(const job_2p_ref* job) { - if (!job || !job->opaque) return 0; - return static_cast(job->opaque)->is_p1() ? 1 : 0; -} - -int is_peer2(const job_2p_ref* job) { - if (!job || !job->opaque) return 0; - return static_cast(job->opaque)->is_p2() ? 1 : 0; -} - -int is_role_index(const job_2p_ref* job, int party_index) { - if (!job || !job->opaque) return 0; - return static_cast(job->opaque)->is_party_idx(party_index) ? 1 : 0; -} - -int get_role_index(const job_2p_ref* job) { - if (!job || !job->opaque) return -1; - return static_cast(static_cast(job->opaque)->get_party_idx()); -} - -int mpc_2p_send(job_2p_ref* job, int receiver, cmem_t msg) { - if (!job || !job->opaque) return NETWORK_INVALID_STATE; - if (!msg.data && msg.size > 0) return NETWORK_PARAM_ERROR; - if (msg.size < 0) return NETWORK_PARAM_ERROR; - - try { - job_2p_t* j = GET_JOB_2P(job); - buf_t msg_buf{coinbase::ffi::view(msg)}; - error_t result = j->send(party_idx_t(receiver), msg_buf); - return static_cast(result); - } catch (const std::exception& e) { - std::cerr << "Error in mpc_2p_send: " << e.what() << std::endl; - return NETWORK_ERROR; - } -} - -int mpc_2p_receive(job_2p_ref* job, int sender, cmem_t* msg) { - if (!job || !job->opaque || !msg) return NETWORK_PARAM_ERROR; - - try { - job_2p_t* j = GET_JOB_2P(job); - buf_t msg_buf; - error_t err = j->receive(party_idx_t(sender), msg_buf); - - if (err) return static_cast(err); - - msg->size = static_cast(msg_buf.size()); - if (msg->size > 0) { - msg->data = static_cast(malloc(msg->size)); - if (!msg->data) return NETWORK_MEMORY_ERROR; - memcpy(msg->data, msg_buf.data(), msg->size); - } else { - msg->data = nullptr; - } - - return NETWORK_SUCCESS; - } catch (const std::exception& e) { - std::cerr << "Error in mpc_2p_receive: " << e.what() << std::endl; - return NETWORK_ERROR; - } -} - -job_mp_ref* new_job_mp(const data_transport_callbacks_t* callbacks, void* go_impl_ptr, int party_count, int index, - const char* const* pnames, int pname_count) { - // Input validation - if (pname_count != party_count) { - std::cerr << "Error: pname_count (" << pname_count << ") does not match party_count (" << party_count << ")" - << std::endl; - return nullptr; - } - - if (party_count <= 0) { - std::cerr << "Error: party_count must be positive, got " << party_count << std::endl; - return nullptr; - } - - if (!callbacks || !go_impl_ptr) { - std::cerr << "Error: null parameters passed to new_job_mp" << std::endl; - return nullptr; - } - - if (!validate_party_names(pnames, pname_count)) { - std::cerr << "Error: invalid party names" << std::endl; - return nullptr; - } - - try { - auto data_transport_ptr = std::make_shared(callbacks, go_impl_ptr); - - std::vector pnames_vec; - pnames_vec.reserve(party_count); - for (int i = 0; i < party_count; ++i) { - pnames_vec.emplace_back(pnames[i]); - } - - auto job_impl = std::make_unique(party_idx_t(index), std::move(pnames_vec), data_transport_ptr); - - auto result = std::make_unique(); - result->opaque = job_impl.release(); - return result.release(); - - } catch (const std::exception& e) { - std::cerr << "Error creating job_mp: " << e.what() << std::endl; - return nullptr; - } -} - -#define VALIDATE_JOB_MP(job) \ - do { \ - if (!job || !job->opaque) { \ - return NETWORK_INVALID_STATE; \ - } \ - } while (0) - -#define GET_JOB_MP(job) static_cast(job->opaque) - -int is_party(const job_mp_ref* job, int party_index) { - if (!job || !job->opaque) return 0; - return static_cast(job->opaque)->is_party_idx(party_index) ? 1 : 0; -} - -int get_party_idx(const job_mp_ref* job) { - if (!job || !job->opaque) return -1; - return static_cast(static_cast(job->opaque)->get_party_idx()); -} - -int get_n_parties(const job_mp_ref* job) { - if (!job || !job->opaque) return -1; - return static_cast(static_cast(job->opaque)->get_n_parties()); -} - -mpc_party_set_ref new_party_set() { - party_set_t* set = new party_set_t(); - return mpc_party_set_ref{set}; -} - -void party_set_add(mpc_party_set_ref* set, int party_idx) { - party_set_t* party_set = static_cast(set->opaque); - party_set->add(party_idx); -} - -void free_party_set(mpc_party_set_ref ctx) { - if (ctx.opaque) { - delete static_cast(ctx.opaque); - } -} diff --git a/demos-go/cb-mpc-go/internal/cgobinding/network.go b/demos-go/cb-mpc-go/internal/cgobinding/network.go deleted file mode 100644 index 04ae878c..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/network.go +++ /dev/null @@ -1,496 +0,0 @@ -package cgobinding - -import ( - "context" - "crypto/x509" - "fmt" - "sync" - "unsafe" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mtls" -) - -/* -#include -#include -#include "network.h" - -extern int callback_send(void*, int, cmem_t); -extern int callback_receive(void*, int, cmem_t*); -extern int callback_receive_all(void*, int*, int, cmems_t*); - -static void set_callbacks(data_transport_callbacks_t* dt_callbacks) -{ - dt_callbacks->send_fun = callback_send; - dt_callbacks->receive_fun = callback_receive; - dt_callbacks->receive_all_fun = callback_receive_all; -} -*/ -import "C" - -// Error constants matching C++ definitions -const ( - NetworkSuccess = 0 - NetworkError = -1 - NetworkParamError = -2 - NetworkMemoryError = -3 - NetworkInvalidState = -4 -) - -// Transport interface aliases for backward compatibility -type IDataTransport = transport.Messenger -type MTLSDataTransport = mtls.MTLSMessenger -type PartyConfig = mtls.PartyConfig -type Config = mtls.Config - -type MpcPartySetRef C.mpc_party_set_ref - -// --------------------------------------------------------------------------- -// Party-set helpers (moved from ecdsamp.go) - -// NewPartySet allocates a new party set and returns its opaque reference. -func NewPartySet() MpcPartySetRef { - set := C.new_party_set() - return MpcPartySetRef(set) -} - -// Add inserts a party index into the set. -func (ps *MpcPartySetRef) Add(partyIdx int) { - C.party_set_add((*C.mpc_party_set_ref)(ps), C.int(partyIdx)) -} - -// Free releases the underlying C++ party_set_t instance. -func (ps *MpcPartySetRef) Free() { - C.free_party_set(C.mpc_party_set_ref(*ps)) -} - -// Data transport implementation management with better type safety -var dtImplMap = sync.Map{} - -func SetDTImpl(dtImpl any) (unsafe.Pointer, error) { - if dtImpl == nil { - return nil, fmt.Errorf("data transport implementation cannot be nil") - } - ptr := C.malloc(1) - if ptr == nil { - return nil, fmt.Errorf("failed to allocate memory for data transport pointer") - } - dtImplMap.Store(ptr, dtImpl) - return ptr, nil -} - -func FreeDTImpl(ptr unsafe.Pointer) error { - if ptr == nil { - return nil // Not an error to free a nil pointer - } - _, loaded := dtImplMap.LoadAndDelete(ptr) - if !loaded { - return fmt.Errorf("attempt to free unknown data transport pointer") - } - C.free(ptr) - return nil -} - -func GetDTImpl(ptr unsafe.Pointer) (any, error) { - if ptr == nil { - return nil, fmt.Errorf("cannot get implementation from nil pointer") - } - dtImpl, ok := dtImplMap.Load(ptr) - if !ok { - return nil, fmt.Errorf("failed to load dtImpl from pointer") - } - return dtImpl, nil -} - -// Callback functions -var callbacks C.data_transport_callbacks_t - -//export callback_send -func callback_send(ptr unsafe.Pointer, receiver C.int, message C.cmem_t) C.int { - dtImpl, err := GetDTImpl(ptr) - if err != nil { - return C.int(NetworkError) - } - - transport, ok := dtImpl.(*IDataTransport) - if !ok { - return C.int(NetworkError) - } - - var goBytes []byte - if message.size > 0 && message.data != nil { - goBytes = C.GoBytes(unsafe.Pointer(message.data), message.size) - } - - if err := (*transport).MessageSend(context.Background(), int(receiver), goBytes); err != nil { - return C.int(NetworkError) - } - - return C.int(NetworkSuccess) -} - -//export callback_receive -func callback_receive(ptr unsafe.Pointer, sender C.int, message *C.cmem_t) C.int { - dtImpl, err := GetDTImpl(ptr) - if err != nil { - return C.int(NetworkError) - } - - transport, ok := dtImpl.(*IDataTransport) - if !ok { - return C.int(NetworkError) - } - - received, err := (*transport).MessageReceive(context.Background(), int(sender)) - if err != nil { - return C.int(NetworkError) - } - - message.size = C.int(len(received)) - if len(received) > 0 { - buf := C.malloc(C.size_t(len(received))) - if buf == nil { - return C.int(NetworkMemoryError) - } - C.memcpy(buf, unsafe.Pointer(&received[0]), C.size_t(len(received))) - message.data = (*C.uint8_t)(buf) - } else { - message.data = nil - } - - return C.int(NetworkSuccess) -} - -// Array manipulation utilities - optimized for performance -var ( - cIntSize = int(unsafe.Sizeof(C.int(0))) - cPtrSize = int(unsafe.Sizeof(unsafe.Pointer(nil))) -) - -func arrGetIntC(arr unsafe.Pointer, index int) int { - ptr := (*C.int)(unsafe.Pointer(uintptr(arr) + uintptr(index*cIntSize))) - return int(*ptr) -} - -func arrSetIntC(arr unsafe.Pointer, index int, value int) { - ptr := (*C.int)(unsafe.Pointer(uintptr(arr) + uintptr(index*cIntSize))) - *ptr = C.int(value) -} - -func arrSetBytePtrC(arr unsafe.Pointer, index int, value unsafe.Pointer) { - ptr := (*unsafe.Pointer)(unsafe.Pointer(uintptr(arr) + uintptr(index*cPtrSize))) - *ptr = value -} - -//export callback_receive_all -func callback_receive_all(ptr unsafe.Pointer, senders *C.int, senderCount C.int, messages *C.cmems_t) C.int { - dtImpl, err := GetDTImpl(ptr) - if err != nil { - return C.int(NetworkError) - } - - transport, ok := dtImpl.(*IDataTransport) - if !ok { - return C.int(NetworkError) - } - - count := int(senderCount) - if count == 0 { - messages.count = 0 - messages.data = nil - messages.sizes = nil - return C.int(NetworkSuccess) - } - - sendersArray := make([]int, count) - for i := 0; i < count; i++ { - sendersArray[i] = arrGetIntC(unsafe.Pointer(senders), i) - } - - received, err := (*transport).MessagesReceive(context.Background(), sendersArray) - if err != nil { - return C.int(NetworkError) - } - - if len(received) != count { - return C.int(NetworkError) - } - - // Build flattened cmems_t - total := 0 - for i := 0; i < count; i++ { - total += len(received[i]) - } - - var dataPtr unsafe.Pointer - if total > 0 { - dataPtr = C.malloc(C.size_t(total)) - if dataPtr == nil { - return C.int(NetworkMemoryError) - } - } - - sizesPtr := C.malloc(C.size_t(count) * C.size_t(cIntSize)) - if sizesPtr == nil && count > 0 { - if dataPtr != nil { - C.free(dataPtr) - } - return C.int(NetworkMemoryError) - } - - offset := 0 - for i := 0; i < count; i++ { - arrSetIntC(sizesPtr, i, len(received[i])) - if len(received[i]) > 0 { - C.memcpy(unsafe.Pointer(uintptr(dataPtr)+uintptr(offset)), unsafe.Pointer(&received[i][0]), C.size_t(len(received[i]))) - offset += len(received[i]) - } - } - - messages.count = C.int(count) - messages.data = (*C.uint8_t)(dataPtr) - messages.sizes = (*C.int)(sizesPtr) - - return C.int(NetworkSuccess) -} - -// Helper function to create C string arrays safely -func createCStringArray(strings []string) (unsafe.Pointer, []*C.char, error) { - if len(strings) == 0 { - return nil, nil, fmt.Errorf("string array cannot be empty") - } - - cArray := C.malloc(C.size_t(len(strings)) * C.size_t(unsafe.Sizeof(uintptr(0)))) - if cArray == nil { - return nil, nil, fmt.Errorf("failed to allocate memory for string array") - } - - cSlice := (*[1 << 30]unsafe.Pointer)(cArray)[:len(strings):len(strings)] - cStrs := make([]*C.char, len(strings)) - - for i, str := range strings { - if str == "" { - C.free(cArray) - for j := 0; j < i; j++ { - C.free(unsafe.Pointer(cStrs[j])) - } - return nil, nil, fmt.Errorf("string at index %d cannot be empty", i) - } - cStrs[i] = C.CString(str) - cSlice[i] = unsafe.Pointer(cStrs[i]) - } - - return cArray, cStrs, nil -} - -func freeCStringArray(cArray unsafe.Pointer, cStrs []*C.char) { - if cArray != nil { - C.free(cArray) - } - for _, cStr := range cStrs { - if cStr != nil { - C.free(unsafe.Pointer(cStr)) - } - } -} - -// Job2P represents a 2-party job with improved resource management -type Job2P struct { - dtImplPtr unsafe.Pointer - cJob *C.job_2p_ref -} - -func (j *Job2P) GetCJob() *C.job_2p_ref { - return j.cJob -} - -func NewJob2P(dt IDataTransport, roleIndex int, pnames []string) (Job2P, error) { - if len(pnames) != 2 { - return Job2P{}, fmt.Errorf("NewJob2P requires exactly 2 pnames, got %d", len(pnames)) - } - - if dt == nil { - return Job2P{}, fmt.Errorf("data transport cannot be nil") - } - - ptr, err := SetDTImpl(&dt) - if err != nil { - return Job2P{}, fmt.Errorf("failed to set data transport implementation: %w", err) - } - - cArray, cStrs, err := createCStringArray(pnames) - if err != nil { - FreeDTImpl(ptr) - return Job2P{}, fmt.Errorf("failed to create C string array: %w", err) - } - defer freeCStringArray(cArray, cStrs) - - cJobRef := C.new_job_2p(&callbacks, ptr, C.int(roleIndex), (**C.char)(cArray), C.int(len(pnames))) - if cJobRef == nil { - FreeDTImpl(ptr) - return Job2P{}, fmt.Errorf("failed to create 2P job") - } - - return Job2P{ptr, cJobRef}, nil -} - -func (j *Job2P) Free() { - if j.cJob != nil { - C.free_job_2p(j.cJob) - j.cJob = nil - } - if j.dtImplPtr != nil { - FreeDTImpl(j.dtImplPtr) // Ignore error on cleanup - j.dtImplPtr = nil - } -} - -func (j *Job2P) IsPeer1() bool { - return j.cJob != nil && C.is_peer1(j.cJob) != 0 -} - -func (j *Job2P) IsPeer2() bool { - return j.cJob != nil && C.is_peer2(j.cJob) != 0 -} - -func (j *Job2P) IsRoleIndex(roleIndex int) bool { - return j.cJob != nil && C.is_role_index(j.cJob, C.int(roleIndex)) != 0 -} - -func (j *Job2P) GetRoleIndex() int { - if j.cJob == nil { - return -1 - } - return int(C.get_role_index(j.cJob)) -} - -func (j *Job2P) Message(sender, receiver int, msg []byte) ([]byte, error) { - if j.cJob == nil { - return nil, fmt.Errorf("job is not initialized") - } - - if j.IsRoleIndex(sender) { - var cmsg C.cmem_t - if len(msg) > 0 { - cmsg.data = (*C.uint8_t)(&msg[0]) - cmsg.size = C.int(len(msg)) - } else { - cmsg.data = nil - cmsg.size = 0 - } - cErr := C.mpc_2p_send(j.cJob, C.int(receiver), cmsg) - if cErr != NetworkSuccess { - return nil, fmt.Errorf("2p send failed: error code %d", cErr) - } - return msg, nil - } else if j.IsRoleIndex(receiver) { - var cmsg C.cmem_t - cErr := C.mpc_2p_receive(j.cJob, C.int(sender), &cmsg) - if cErr != NetworkSuccess { - return nil, fmt.Errorf("2p receive failed: error code %d", cErr) - } - - if cmsg.data == nil || cmsg.size == 0 { - return []byte{}, nil - } - - result := C.GoBytes(unsafe.Pointer(cmsg.data), cmsg.size) - C.free(unsafe.Pointer(cmsg.data)) - return result, nil - } - - return nil, fmt.Errorf("caller needs to be either sender (%d) or receiver (%d), current role is %d", - sender, receiver, j.GetRoleIndex()) -} - -// JobMP represents a multi-party job with improved resource management -type JobMP struct { - dtImplPtr unsafe.Pointer - cJob *C.job_mp_ref -} - -func (j *JobMP) GetCJob() *C.job_mp_ref { - return j.cJob -} - -func NewJobMP(dt IDataTransport, partyCount int, roleIndex int, pnames []string) (JobMP, error) { - if len(pnames) != partyCount { - return JobMP{}, fmt.Errorf("NewJobMP requires pnames array length (%d) to match partyCount (%d)", - len(pnames), partyCount) - } - - if dt == nil { - return JobMP{}, fmt.Errorf("data transport cannot be nil") - } - - if partyCount <= 0 { - return JobMP{}, fmt.Errorf("partyCount must be positive, got %d", partyCount) - } - - if roleIndex < 0 || roleIndex >= partyCount { - return JobMP{}, fmt.Errorf("roleIndex (%d) must be in range [0, %d)", roleIndex, partyCount) - } - - ptr, err := SetDTImpl(&dt) - if err != nil { - return JobMP{}, fmt.Errorf("failed to set data transport implementation: %w", err) - } - - cArray, cStrs, err := createCStringArray(pnames) - if err != nil { - FreeDTImpl(ptr) - return JobMP{}, fmt.Errorf("failed to create C string array: %w", err) - } - defer freeCStringArray(cArray, cStrs) - - cJobRef := C.new_job_mp(&callbacks, ptr, C.int(partyCount), C.int(roleIndex), (**C.char)(cArray), C.int(len(pnames))) - if cJobRef == nil { - FreeDTImpl(ptr) - return JobMP{}, fmt.Errorf("failed to create MP job") - } - - return JobMP{ptr, cJobRef}, nil -} - -func (j *JobMP) Free() { - if j.cJob != nil { - C.free_job_mp(j.cJob) - j.cJob = nil - } - if j.dtImplPtr != nil { - FreeDTImpl(j.dtImplPtr) // Ignore error on cleanup - j.dtImplPtr = nil - } -} - -func (j *JobMP) IsParty(partyIndex int) bool { - return j.cJob != nil && C.is_party(j.cJob, C.int(partyIndex)) != 0 -} - -func (j *JobMP) GetPartyIndex() int { - if j.cJob == nil { - return -1 - } - return int(C.get_party_idx(j.cJob)) -} - -func (j *JobMP) GetNParties() int { - if j.cJob == nil { - return -1 - } - return int(C.get_n_parties(j.cJob)) -} - -// Transport factory functions -func NewMTLSDataTransport(config Config) (IDataTransport, error) { - return mtls.NewMTLSMessenger(config) -} - -func PartyNameFromCertificate(cert *x509.Certificate) (string, error) { - return mtls.PartyNameFromCertificate(cert) -} - -func init() { - C.set_callbacks(&callbacks) -} diff --git a/demos-go/cb-mpc-go/internal/cgobinding/network.h b/demos-go/cb-mpc-go/internal/cgobinding/network.h deleted file mode 100644 index bf1595dc..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/network.h +++ /dev/null @@ -1,73 +0,0 @@ -#pragma once - -#include -#include - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// Error codes for consistent error handling -#define NETWORK_SUCCESS 0 -#define NETWORK_ERROR -1 -#define NETWORK_PARAM_ERROR -2 -#define NETWORK_MEMORY_ERROR -3 -#define NETWORK_INVALID_STATE -4 - -// Callback function types using cmem_t/cmems_t for cleaner interfaces -typedef int (*send_f)(void* go_impl_ptr, int receiver, cmem_t message); -typedef int (*receive_f)(void* go_impl_ptr, int sender, cmem_t* message); -typedef int (*receive_all_f)(void* go_impl_ptr, int* senders, int sender_count, cmems_t* messages); - -typedef struct data_transport_callbacks_t { - send_f send_fun; - receive_f receive_fun; - receive_all_f receive_all_fun; -} data_transport_callbacks_t; - -typedef struct job_2p_ref { - void* opaque; -} job_2p_ref; - -typedef struct job_mp_ref { - void* opaque; -} job_mp_ref; - -// --------------------------------------------------------------------------- -// Generic wrapper for a party_set_t instance used across multiple APIs. -// Moved from ecdsamp.h / cblib.h to centralize the definition and avoid -// duplication. Renamed from PARTY_SET_PTR to mpc_party_set_ref. -typedef struct mpc_party_set_ref { - void* opaque; // Opaque pointer to the C++ party_set_t instance -} mpc_party_set_ref; - -// --------------------------------------------------------------------------- -// Party-set helper C API (moved from ecdsamp / cblib headers) -mpc_party_set_ref new_party_set(); -void party_set_add(mpc_party_set_ref* set, int party_idx); -void free_party_set(mpc_party_set_ref ctx); - -// job_2p_ref Functions -job_2p_ref* new_job_2p(const data_transport_callbacks_t* callbacks, void* go_impl_ptr, int party_index, - const char* const* pnames, int pname_count); -void free_job_2p(job_2p_ref* ptr); -int is_peer1(const job_2p_ref* job); -int is_peer2(const job_2p_ref* job); -int is_role_index(const job_2p_ref* job, int party_index); -int get_role_index(const job_2p_ref* job); -int mpc_2p_send(job_2p_ref* job, int receiver, cmem_t msg); -int mpc_2p_receive(job_2p_ref* job, int sender, cmem_t* msg); - -// job_mp_ref Functions -job_mp_ref* new_job_mp(const data_transport_callbacks_t* callbacks, void* go_impl_ptr, int party_count, int party_index, - const char* const* pnames, int pname_count); -void free_job_mp(job_mp_ref* ptr); -int is_party(const job_mp_ref* job, int party_index); -int get_party_idx(const job_mp_ref* job); -int get_n_parties(const job_mp_ref* job); - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/demos-go/cb-mpc-go/internal/cgobinding/pve.cpp b/demos-go/cb-mpc-go/internal/cgobinding/pve.cpp deleted file mode 100644 index 851eb85f..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/pve.cpp +++ /dev/null @@ -1,382 +0,0 @@ -#include "pve.h" - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "curve.h" -#include "network.h" - -using namespace coinbase; -using namespace coinbase::crypto; -using namespace coinbase::mpc; -using node_t = coinbase::crypto::ss::node_t; -using node_e = coinbase::crypto::ss::node_e; - -static thread_local void* g_ctx = nullptr; - -// KEM -static kem_encap_ctx_fn g_kem_enc = nullptr; -static kem_decap_ctx_fn g_kem_dec = nullptr; -static kem_dk_to_ek_ctx_fn g_kem_derive_pub = nullptr; - -// KEM registration and stub shims (third arg kept for backward compat but ignored) -void pve_register_kem_functions(kem_encap_ctx_fn e, kem_decap_ctx_fn d, void* /*ignored*/, kem_dk_to_ek_ctx_fn dpub) { - g_kem_enc = e; - g_kem_dec = d; - g_kem_derive_pub = dpub; -} - -static int stub_kem_encapsulate(cmem_t ek, cmem_t rho, cmem_t* ct_out, cmem_t* ss_out) { - if (g_kem_enc == nullptr || g_ctx == nullptr) return 1; - return g_kem_enc(g_ctx, ek, rho, ct_out, ss_out); -} - -static int stub_kem_decapsulate(const void* dk, cmem_t ct, cmem_t* ss_out) { - if (g_kem_dec == nullptr || g_ctx == nullptr) return 1; - return g_kem_dec(g_ctx, dk, ct, ss_out); -} - -static int stub_kem_dk_to_ek(const void* dk, cmem_t* out) { - if (g_kem_derive_pub == nullptr || g_ctx == nullptr) return 1; - return g_kem_derive_pub(g_ctx, dk, out); -} - -ffi_kem_encap_fn get_ffi_kem_encap_fn(void) { return stub_kem_encapsulate; } -ffi_kem_decap_fn get_ffi_kem_decap_fn(void) { return stub_kem_decapsulate; } -ffi_kem_dk_to_ek_fn get_ffi_kem_dk_to_ek_fn(void) { return stub_kem_dk_to_ek; } - -// ============================================================================ -// PVE – single receiver, single value -// ============================================================================ - -int pve_encrypt(cmem_t pub_key_cmem, cmem_t x_cmem, const char* label_ptr, int curve_code, cmem_t* out_ptr) { - if (label_ptr == nullptr || out_ptr == nullptr) { - return coinbase::error(E_BADARG); - } - error_t rv = UNINITIALIZED_ERROR; - - // Wrap public key bytes into FFI PKI key type (opaque buffer). - ffi_kem_ek_t pub_key; - pub_key = coinbase::ffi::view(pub_key_cmem); - - // Deserialize secret scalar x - bn_t x = bn_t::from_bin(coinbase::ffi::view(x_cmem)); - - // Resolve curve - ecurve_t curve = ecurve_t::find(curve_code); - if (!curve) return coinbase::error(E_CRYPTO, "unsupported curve code"); - - // Perform encryption - ec_pve_t pve(mpc::kem_pve_base_pke()); - try { - pve.encrypt(&pub_key, std::string(label_ptr), curve, x); - } catch (const std::exception& ex) { - return coinbase::error(E_CRYPTO, ex.what()); - } - - buf_t out = coinbase::convert(pve); - *out_ptr = coinbase::ffi::copy_to_cmem(out); - return SUCCESS; -} - -int pve_decrypt(cmem_t prv_key_cmem, cmem_t pve_bundle_cmem, const char* label_ptr, int curve_code, cmem_t* out_x_ptr) { - if (label_ptr == nullptr || out_x_ptr == nullptr) { - return coinbase::error(E_BADARG); - } - error_t rv = UNINITIALIZED_ERROR; - - // The dk can be either raw bytes or a handle encoded as bytes. - // We pass it through as an opaque handle pointer by default. For pure - // byte-backed dk, we pass a pointer to the cmem_t on the stack whose - // lifetime spans the call chain. - ffi_kem_dk_t prv_key; - cmem_t dk_bytes = prv_key_cmem; - prv_key.handle = static_cast(&dk_bytes); - - // Deserialize ciphertext bundle - ec_pve_t pve(mpc::kem_pve_base_pke()); - rv = coinbase::deser(coinbase::ffi::view(pve_bundle_cmem), pve); - if (rv) return rv; - - // Resolve curve - ecurve_t curve = ecurve_t::find(curve_code); - if (!curve) return coinbase::error(E_CRYPTO, "unsupported curve code"); - - // Decrypt - bn_t x_out; - rv = pve.decrypt(&prv_key, nullptr /*unused ek*/, std::string(label_ptr), curve, x_out, /*skip_verify=*/true); - if (rv) return rv; - - buf_t x_buf = x_out.to_bin(curve.order().get_bin_size()); - *out_x_ptr = coinbase::ffi::copy_to_cmem(x_buf); - return SUCCESS; -} - -int pve_verify(cmem_t pub_key_cmem, cmem_t pve_bundle_cmem, cmem_t Q_cmem, const char* label_ptr) { - if (label_ptr == nullptr) { - return coinbase::error(E_BADARG); - } - error_t rv = UNINITIALIZED_ERROR; - - // Deserialize inputs - ffi_kem_ek_t pub_key; - pub_key = coinbase::ffi::view(pub_key_cmem); - - ecc_point_t Q; - rv = coinbase::deser(coinbase::ffi::view(Q_cmem), Q); - if (rv) return rv; - - ec_pve_t pve(mpc::kem_pve_base_pke()); - rv = coinbase::deser(coinbase::ffi::view(pve_bundle_cmem), pve); - if (rv) return rv; - - // Verify - rv = pve.verify(&pub_key, Q, std::string(label_ptr)); - if (rv) return rv; - - return SUCCESS; -} - -// No explicit template instantiation needed; ec_pve_ac_t is non-templated. - -// ============================================================================ -// PVE-AC - many receivers, many values -// ========================================================================= -int pve_ac_encrypt(crypto_ss_ac_ref* ac_ptr, cmems_t names_list_ptr, cmems_t pub_keys_list_ptr, int pub_keys_count, - cmems_t xs_list_ptr, int xs_count, const char* label_ptr, int curve_code, cmem_t* out_ptr) { - if (ac_ptr == nullptr || ac_ptr->opaque == nullptr) { - return coinbase::error(E_CRYPTO, "null access-structure pointer"); - } - - error_t rv = UNINITIALIZED_ERROR; - crypto::ss::ac_t* ac = static_cast(ac_ptr->opaque); - crypto::ss::node_t* root = const_cast(ac->root); - - // Deserialize names - std::vector name_bufs = coinbase::ffi::bufs_from_cmems(names_list_ptr); - if (name_bufs.size() != (size_t)pub_keys_count) { - return coinbase::error(E_CRYPTO, "names list and key list size mismatch"); - } - std::vector names(pub_keys_count); - for (int i = 0; i < pub_keys_count; i++) { - names[i] = std::string((const char*)name_bufs[i].data(), name_bufs[i].size()); - } - - // Deserialize public keys (opaque FFI KEM ek) - std::vector pub_bufs = coinbase::ffi::bufs_from_cmems(pub_keys_list_ptr); - std::vector pub_keys_list(pub_keys_count); - for (int i = 0; i < pub_keys_count; i++) { - pub_keys_list[i] = pub_bufs[i]; - } - - // Deserialize xs - std::vector xs_bufs = coinbase::ffi::bufs_from_cmems(xs_list_ptr); - std::vector xs(xs_count); - for (int i = 0; i < xs_count; i++) { - xs[i] = bn_t::from_bin(xs_bufs[i]); - } - - // Resolve curve - ecurve_t curve = ecurve_t::find(curve_code); - if (!curve) return coinbase::error(E_CRYPTO, "unsupported curve code"); - - // Validate inputs - if (xs.empty()) { - return coinbase::error(E_CRYPTO, "empty xs list"); - } - if (pub_keys_list.empty()) { - return coinbase::error(E_CRYPTO, "empty public keys list"); - } - - // Build access structure and get leaf names - ss::ac_owned_t ac_owned(root); - auto leaf_set = ac_owned.list_leaf_names(); - std::vector leaves(leaf_set.begin(), leaf_set.end()); - - if (names.size() != pub_keys_list.size()) { - return coinbase::error(E_CRYPTO, "names list and key list size mismatch"); - } - if (pub_keys_list.size() != leaves.size()) { - return coinbase::error(E_CRYPTO, "leaf count and key list size mismatch"); - } - - // Build the mapping leaf_name -> pub_key - std::map pub_keys; - std::vector pub_keys_storage(leaves.size()); - for (size_t i = 0; i < leaves.size(); ++i) { - pub_keys_storage[i] = pub_keys_list[i]; - pub_keys[names[i]] = pub_keys_storage[i]; - } - - // Encrypt using FFI KEM base PKE - ec_pve_ac_t pve(mpc::kem_pve_base_pke()); - std::map ac_pks; - for (size_t i = 0; i < leaves.size(); ++i) { - ac_pks[names[i]] = static_cast(&pub_keys_storage[i]); - } - pve.encrypt(ac_owned, ac_pks, std::string(label_ptr), curve, xs); - buf_t out = coinbase::convert(pve); - *out_ptr = coinbase::ffi::copy_to_cmem(out); - return SUCCESS; -} - -extern "C" int pve_ac_party_decrypt_row(crypto_ss_ac_ref* ac_ptr, cmem_t prv_key_cmem, cmem_t pve_bundle_cmem, - const char* label_ptr, const char* path_ptr, int row_index, cmem_t* out_share_ptr) { - if (ac_ptr == nullptr || ac_ptr->opaque == nullptr) { - return coinbase::error(E_CRYPTO, "null access-structure pointer"); - } - - // Deserialize PVE bundle - ec_pve_ac_t pve(mpc::kem_pve_base_pke()); - error_t rv = coinbase::deser(coinbase::ffi::view(pve_bundle_cmem), pve); - if (rv) return rv; - - // Access structure - crypto::ss::ac_t* ac = static_cast(ac_ptr->opaque); - ss::ac_owned_t ac_owned(const_cast(ac->root)); - - // Prepare DK handle wrapper for FFI KEM - ffi_kem_dk_t prv_key; - cmem_t dk_bytes = prv_key_cmem; - prv_key.handle = static_cast(&dk_bytes); - - // Compute share - bn_t share; - rv = pve.party_decrypt_row(ac_owned, row_index, std::string(path_ptr), static_cast(&prv_key), - std::string(label_ptr), share); - if (rv) return rv; - - buf_t share_buf = share.to_bin(); - *out_share_ptr = coinbase::ffi::copy_to_cmem(share_buf); - return SUCCESS; -} - -extern "C" int pve_ac_aggregate_to_restore_row(crypto_ss_ac_ref* ac_ptr, cmem_t pve_bundle_cmem, const char* label_ptr, - cmems_t paths_list_ptr, cmems_t shares_list_ptr, int quorum_count, int row_index, - cmems_t* out_values_ptr) { - if (ac_ptr == nullptr || ac_ptr->opaque == nullptr) { - return coinbase::error(E_CRYPTO, "null access-structure pointer"); - } - - // Deserialize PVE bundle - ec_pve_ac_t pve(mpc::kem_pve_base_pke()); - error_t rv = coinbase::deser(coinbase::ffi::view(pve_bundle_cmem), pve); - if (rv) return rv; - - // Access structure - crypto::ss::ac_t* ac = static_cast(ac_ptr->opaque); - ss::ac_owned_t ac_owned(const_cast(ac->root)); - - // Build quorum shares map: path -> bn share - std::vector name_bufs = coinbase::ffi::bufs_from_cmems(paths_list_ptr); - std::vector share_bufs = coinbase::ffi::bufs_from_cmems(shares_list_ptr); - if ((int)name_bufs.size() != quorum_count || (int)share_bufs.size() != quorum_count) { - return coinbase::error(E_CRYPTO, "quorum lists size mismatch"); - } - std::map quorum_decrypted; - for (int i = 0; i < quorum_count; i++) { - std::string path((const char*)name_bufs[i].data(), name_bufs[i].size()); - quorum_decrypted[path] = bn_t::from_bin(share_bufs[i]); - } - - // Recover values for the specified row - std::vector x; - rv = pve.aggregate_to_restore_row(ac_owned, row_index, std::string(label_ptr), quorum_decrypted, x, - true /*skip_verify*/); - if (rv) return rv; - - // Serialize outputs to fixed-size bins - const std::vector& Q = pve.get_Q(); - if (Q.empty()) return coinbase::error(E_CRYPTO, "empty Q"); - ecurve_t curve = Q[0].get_curve(); - int fixed_size = curve.order().get_bin_size(); - std::vector out(x.size()); - for (size_t i = 0; i < x.size(); i++) out[i] = x[i].to_bin(fixed_size); - *out_values_ptr = coinbase::ffi::copy_to_cmems(buf_t::to_mems(out)); - return SUCCESS; -} - -int pve_ac_verify(crypto_ss_ac_ref* ac_ptr, cmems_t names_list_ptr, cmems_t pub_keys_list_ptr, int pub_keys_count, - cmem_t pve_bundle_cmem, cmems_t Xs_list_ptr, int xs_count, const char* label_ptr) { - if (ac_ptr == nullptr || ac_ptr->opaque == nullptr) { - return coinbase::error(E_CRYPTO, "null access-structure pointer"); - } - - error_t rv = UNINITIALIZED_ERROR; - crypto::ss::ac_t* ac = static_cast(ac_ptr->opaque); - crypto::ss::node_t* root = const_cast(ac->root); - - // Deserialize names - std::vector name_bufs = coinbase::ffi::bufs_from_cmems(names_list_ptr); - if (name_bufs.size() != (size_t)pub_keys_count) { - return coinbase::error(E_CRYPTO, "names list and key list size mismatch"); - } - std::vector names(pub_keys_count); - for (int i = 0; i < pub_keys_count; i++) { - names[i] = std::string((const char*)name_bufs[i].data(), name_bufs[i].size()); - } - - // Deserialize public keys (opaque FFI KEM ek) - std::vector pub_bufs = coinbase::ffi::bufs_from_cmems(pub_keys_list_ptr); - std::vector pub_keys_list(pub_keys_count); - for (int i = 0; i < pub_keys_count; i++) { - pub_keys_list[i] = pub_bufs[i]; - } - - // Deserialize Xs (public shares) - std::vector Xs_bufs = coinbase::ffi::bufs_from_cmems(Xs_list_ptr); - std::vector Xs(xs_count); - for (int i = 0; i < xs_count; i++) { - rv = coinbase::deser(Xs_bufs[i], Xs[i]); - if (rv) return rv; - } - - // Deserialize the PVE bundle - ec_pve_ac_t pve(mpc::kem_pve_base_pke()); - buf_t pve_bundle = coinbase::ffi::view(pve_bundle_cmem); - rv = coinbase::deser(pve_bundle, pve); - if (rv) return rv; - - // Build leaf names from access structure - ss::ac_owned_t ac_owned(root); - auto leaf_set = ac_owned.list_leaf_names(); - std::vector leaves(leaf_set.begin(), leaf_set.end()); - if (leaves.size() != names.size()) { - return coinbase::error(E_CRYPTO, "leaf count and names list size mismatch"); - } - - // Build mapping leaf_name -> pub_key - std::map pub_keys; - for (size_t i = 0; i < leaves.size(); ++i) { - pub_keys[names[i]] = pub_keys_list[i]; - } - - // Perform verification - std::string label(label_ptr); - std::vector pub_keys_storage(leaves.size()); - std::map ac_pks; - for (size_t i = 0; i < leaves.size(); ++i) { - pub_keys_storage[i] = pub_keys[names[i]]; - ac_pks[names[i]] = static_cast(&pub_keys_storage[i]); - } - rv = pve.verify(*ac, ac_pks, Xs, label); - if (rv) return rv; - - return SUCCESS; -} - -extern "C" void pve_activate_ctx(void* ctx) { g_ctx = ctx; } diff --git a/demos-go/cb-mpc-go/internal/cgobinding/pve.go b/demos-go/cb-mpc-go/internal/cgobinding/pve.go deleted file mode 100644 index 20f5629f..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/pve.go +++ /dev/null @@ -1,317 +0,0 @@ -package cgobinding - -/* -#cgo CXXFLAGS: -std=c++17 -#include -#include -#include "pve.h" -// Forward declarations of the Go callbacks – defined further below. - -int go_pve_kem_encapsulate_bridge_ctx( - void*, // ctx - cmem_t, // ek - cmem_t, // rho - cmem_t*, // kem_ct out - cmem_t*); // kem_ss out -int go_pve_kem_decapsulate_bridge_ctx( - void*, // ctx - void*, // dk_handle - cmem_t, // kem_ct - cmem_t*); // kem_ss out - -int go_pve_derive_pub_bridge_ctx( - void*, // ctx - void*, // dk_handle - cmem_t*); // out ek - -// === single-party PVE helpers exposed by pve.h === -int pve_encrypt(cmem_t, cmem_t, const char*, int, cmem_t*); -int pve_decrypt(cmem_t, cmem_t, const char*, int, cmem_t*); -int pve_verify(cmem_t, cmem_t, cmem_t, const char*); - -// Activate context -void pve_activate_ctx(void*); -*/ -import "C" - -import ( - "errors" - "fmt" - "runtime" - "sync" - "sync/atomic" - "unsafe" -) - -// KEM describes the pluggable KEM backend used by the C++ PVE core. -// All byte slices are opaque references – the Go side decides their meaning. -// Implementations MUST be safe for concurrent use by multiple goroutines. -type KEM interface { - Generate() (skRef, ek []byte, err error) - Encapsulate(ek []byte, rho [32]byte) (ct, ss []byte, err error) - Decapsulate(skHandle unsafe.Pointer, ct []byte) (ss []byte, err error) - DerivePub(skRef []byte) ([]byte, error) -} - -var ( - // Multi-instance registry keyed by opaque context pointers coming from C. - instanceReg sync.Map // map[unsafe.Pointer]KEM - nextCtxID uint64 - // Ensures we register the C-side PKI callbacks only once per process. - registerPKIFuncOnce sync.Once - // Ensures we register the C-side KEM callbacks only once per process. - registerKEMFuncOnce sync.Once -) - -// ---------------------------------------------------------------------------- -// PVE - single receiver, single value -// ---------------------------------------------------------------------------- - -func PVE_encrypt(pubKey []byte, x []byte, label string, curveCode int) ([]byte, error) { - var out CMEM - cLabel := C.CString(label) - defer C.free(unsafe.Pointer(cLabel)) - rv := C.pve_encrypt(cmem(pubKey), cmem(x), cLabel, C.int(curveCode), (*C.cmem_t)(&out)) - if rv != 0 { - return nil, fmt.Errorf("pve encrypt failed: %v", rv) - } - return CMEMGet(out), nil -} - -func PVE_decrypt(prvKey []byte, ciphertext []byte, label string, curveCode int) ([]byte, error) { - var out CMEM - cLabel := C.CString(label) - defer C.free(unsafe.Pointer(cLabel)) - rv := C.pve_decrypt(cmem(prvKey), cmem(ciphertext), cLabel, C.int(curveCode), (*C.cmem_t)(&out)) - if rv != 0 { - return nil, fmt.Errorf("pve decrypt failed: %v", rv) - } - return CMEMGet(out), nil -} - -func PVE_verify(pubKey []byte, ciphertext []byte, Q []byte, label string) error { - cLabel := C.CString(label) - defer C.free(unsafe.Pointer(cLabel)) - - rv := C.pve_verify(cmem(pubKey), cmem(ciphertext), cmem(Q), cLabel) - if rv != 0 { - return fmt.Errorf("pve verify failed: %v", rv) - } - return nil -} - -// ---------------------------------------------------------------------------- -// PVE-AC - many receivers, many values -// ---------------------------------------------------------------------------- -func PVE_AC_encrypt(ac C_AcPtr, names [][]byte, pubKeys [][]byte, count int, xs [][]byte, xsCount int, label string, curveCode int) ([]byte, error) { - var out CMEM - cLabel := C.CString(label) - defer C.free(unsafe.Pointer(cLabel)) - - // Pin array memory during the C call - namesPin := makeCmems(names) - pubPin := makeCmems(pubKeys) - xsPin := makeCmems(xs) - - rv := C.pve_ac_encrypt((*C.crypto_ss_ac_ref)(&ac), namesPin.c, pubPin.c, C.int(count), xsPin.c, C.int(xsCount), cLabel, C.int(curveCode), &out) - // Ensure Go slices are kept alive until after the C call returns - runtime.KeepAlive(namesPin) - runtime.KeepAlive(pubPin) - runtime.KeepAlive(xsPin) - if rv != 0 { - return nil, fmt.Errorf("pve quorum encrypt (map) failed: %v", rv) - } - return CMEMGet(out), nil -} - -func PVE_AC_party_decrypt_row(ac C_AcPtr, prvKey []byte, pveBundle []byte, label string, path string, rowIndex int) ([]byte, error) { - var out CMEM - cLabel := C.CString(label) - cPath := C.CString(path) - defer C.free(unsafe.Pointer(cLabel)) - defer C.free(unsafe.Pointer(cPath)) - - rv := C.pve_ac_party_decrypt_row((*C.crypto_ss_ac_ref)(&ac), cmem(prvKey), cmem(pveBundle), cLabel, cPath, C.int(rowIndex), (*C.cmem_t)(&out)) - if rv != 0 { - return nil, fmt.Errorf("pve quorum party_decrypt_row failed: %v", rv) - } - return CMEMGet(out), nil -} - -func PVE_AC_aggregate_to_restore_row(ac C_AcPtr, pveBundle []byte, label string, paths [][]byte, shares [][]byte, rowIndex int) ([][]byte, error) { - if len(paths) != len(shares) { - return nil, fmt.Errorf("paths and shares length mismatch") - } - var out CMEMS - cLabel := C.CString(label) - defer C.free(unsafe.Pointer(cLabel)) - pathsPin := makeCmems(paths) - sharesPin := makeCmems(shares) - rv := C.pve_ac_aggregate_to_restore_row((*C.crypto_ss_ac_ref)(&ac), cmem(pveBundle), cLabel, pathsPin.c, sharesPin.c, C.int(len(paths)), C.int(rowIndex), &out) - runtime.KeepAlive(pathsPin) - runtime.KeepAlive(sharesPin) - if rv != 0 { - return nil, fmt.Errorf("pve quorum aggregate_to_restore_row failed: %v", rv) - } - return CMEMSGet(out), nil -} - -func PVE_AC_verify(ac C_AcPtr, names [][]byte, pubKeys [][]byte, count int, pveBundle []byte, Xs [][]byte, xsCount int, label string) error { - cLabel := C.CString(label) - defer C.free(unsafe.Pointer(cLabel)) - - namesPin := makeCmems(names) - pubPin := makeCmems(pubKeys) - xsPin := makeCmems(Xs) - - rv := C.pve_ac_verify((*C.crypto_ss_ac_ref)(&ac), namesPin.c, pubPin.c, C.int(count), cmem(pveBundle), xsPin.c, C.int(xsCount), cLabel) - runtime.KeepAlive(namesPin) - runtime.KeepAlive(pubPin) - runtime.KeepAlive(xsPin) - if rv != 0 { - return fmt.Errorf("pve quorum verify (map) failed: %v", rv) - } - return nil -} - -// RegisterKEMInstance registers a distinct KEM implementation and configures the C bridge. -func RegisterKEMInstance(p KEM) (unsafe.Pointer, error) { - if p == nil { - return nil, errors.New("RegisterKEMInstance: nil implementation") - } - registerKEMFuncOnce.Do(func() { - C.pve_register_kem_functions( - (C.kem_encap_ctx_fn)(unsafe.Pointer(C.go_pve_kem_encapsulate_bridge_ctx)), - (C.kem_decap_ctx_fn)(unsafe.Pointer(C.go_pve_kem_decapsulate_bridge_ctx)), - unsafe.Pointer(nil), - (C.kem_dk_to_ek_ctx_fn)(unsafe.Pointer(C.go_pve_derive_pub_bridge_ctx)), - ) - }) - - // Fast-path: if already registered, return previous ctx. - var foundCtx unsafe.Pointer - instanceReg.Range(func(key, value any) bool { - if value == p { - foundCtx = key.(unsafe.Pointer) - return false - } - return true - }) - if foundCtx != nil { - return foundCtx, nil - } - - // Allocate an opaque context pointer. - _ = atomic.AddUint64(&nextCtxID, 1) - ctx := C.malloc(1) - - instanceReg.Store(ctx, p) - return ctx, nil -} - -//export go_pve_derive_pub_bridge_ctx -func go_pve_derive_pub_bridge_ctx( - ctx unsafe.Pointer, - dkHandle unsafe.Pointer, - out *C.cmem_t, -) C.int { - v, ok := instanceReg.Load(ctx) - if !ok { - return 1 - } - impl := v.(KEM) - var dkGo []byte - if dkHandle != nil { - cm := (*C.cmem_t)(dkHandle) - if cm != nil && cm.data != nil && cm.size >= 0 { - dkGo = unsafe.Slice((*byte)(unsafe.Pointer(cm.data)), int(cm.size)) - } else { - p := uintptr(dkHandle) - b := make([]byte, unsafe.Sizeof(p)) - for i := 0; i < len(b); i++ { - b[i] = byte(p >> (8 * i)) - } - dkGo = b - } - } - ek, err := impl.DerivePub(dkGo) - if err != nil { - return 2 - } - mem := C.malloc(C.size_t(len(ek))) - if len(ek) > 0 { - C.memcpy(mem, unsafe.Pointer(&ek[0]), C.size_t(len(ek))) - } - out.data = (*C.uint8_t)(mem) - out.size = C.int(len(ek)) - return 0 -} - -//export go_pve_kem_encapsulate_bridge_ctx -func go_pve_kem_encapsulate_bridge_ctx( - ctx unsafe.Pointer, - ek C.cmem_t, - rho C.cmem_t, - ctOut *C.cmem_t, - ssOut *C.cmem_t, -) C.int { - v, ok := instanceReg.Load(ctx) - if !ok { - return 1 - } - impl := v.(KEM) - // Enforce exactly 32 bytes of seed entropy - if int(rho.size) != 32 { - return 3 - } - ekGo := unsafe.Slice((*byte)(unsafe.Pointer(ek.data)), int(ek.size)) - rhoGo := unsafe.Slice((*byte)(unsafe.Pointer(rho.data)), int(rho.size)) - var rhoArr [32]byte - copy(rhoArr[:], rhoGo) - ct, ss, err := impl.Encapsulate(ekGo, rhoArr) - if err != nil { - return 2 - } - ctMem := C.malloc(C.size_t(len(ct))) - if len(ct) > 0 { - C.memcpy(ctMem, unsafe.Pointer(&ct[0]), C.size_t(len(ct))) - } - ctOut.data = (*C.uint8_t)(ctMem) - ctOut.size = C.int(len(ct)) - ssMem := C.malloc(C.size_t(len(ss))) - if len(ss) > 0 { - C.memcpy(ssMem, unsafe.Pointer(&ss[0]), C.size_t(len(ss))) - } - ssOut.data = (*C.uint8_t)(ssMem) - ssOut.size = C.int(len(ss)) - return 0 -} - -//export go_pve_kem_decapsulate_bridge_ctx -func go_pve_kem_decapsulate_bridge_ctx( - ctx unsafe.Pointer, - dkHandle unsafe.Pointer, - ct C.cmem_t, - ssOut *C.cmem_t, -) C.int { - v, ok := instanceReg.Load(ctx) - if !ok { - return 1 - } - impl := v.(KEM) - ctGo := unsafe.Slice((*byte)(unsafe.Pointer(ct.data)), int(ct.size)) - ss, err := impl.Decapsulate(dkHandle, ctGo) - if err != nil { - return 2 - } - mem := C.malloc(C.size_t(len(ss))) - if len(ss) > 0 { - C.memcpy(mem, unsafe.Pointer(&ss[0]), C.size_t(len(ss))) - } - ssOut.data = (*C.uint8_t)(mem) - ssOut.size = C.int(len(ss)) - return 0 -} - -// ActivateCtx tells the C shim which KEM instance is about to run. -func ActivateCtx(ctx unsafe.Pointer) { C.pve_activate_ctx(ctx) } diff --git a/demos-go/cb-mpc-go/internal/cgobinding/pve.h b/demos-go/cb-mpc-go/internal/cgobinding/pve.h deleted file mode 100644 index c807cd5b..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/pve.h +++ /dev/null @@ -1,52 +0,0 @@ -#pragma once - -#include - -#include -#include - -#include "ac.h" -#include "curve.h" -#include "kem.h" -#include "network.h" - -#ifdef __cplusplus -extern "C" { -#endif - -void pve_register_kem_functions(kem_encap_ctx_fn e, kem_decap_ctx_fn d, void* /*ignored*/, kem_dk_to_ek_ctx_fn dpub); - -// Switch the currently active PKI context (used by shim wrappers). -void pve_activate_ctx(void* ctx); - -int pve_encrypt(cmem_t pub_key_cmem, cmem_t x_cmem, const char* label_ptr, int curve_code, cmem_t* out_ptr); -int pve_decrypt(cmem_t prv_key_cmem, cmem_t pve_bundle_cmem, const char* label_ptr, int curve_code, cmem_t* out_x_ptr); -int pve_verify(cmem_t pub_key_cmem, cmem_t pve_bundle_cmem, cmem_t Q_cmem, const char* label_ptr); - -// Quorum encryption / verification operating on a full access-structure pointer. -int pve_ac_encrypt(crypto_ss_ac_ref* ac_ptr, cmems_t names_list_ptr, cmems_t pub_keys_list_ptr, int pub_keys_count, - cmems_t xs_list_ptr, int xs_count, const char* label_ptr, int curve_code, cmem_t* out_ptr); -int pve_ac_verify(crypto_ss_ac_ref* ac_ptr, cmems_t names_list_ptr, cmems_t pub_keys_list_ptr, int pub_keys_count, - cmem_t pve_bundle_cmem, cmems_t Xs_list_ptr, int xs_count, const char* label_ptr); - -// Interactive quorum decryption APIs -int pve_ac_party_decrypt_row(crypto_ss_ac_ref* ac_ptr, - cmem_t prv_key_cmem, - cmem_t pve_bundle_cmem, - const char* label_ptr, - const char* path_ptr, - int row_index, - cmem_t* out_share_ptr); - -int pve_ac_aggregate_to_restore_row(crypto_ss_ac_ref* ac_ptr, - cmem_t pve_bundle_cmem, - const char* label_ptr, - cmems_t paths_list_ptr, - cmems_t shares_list_ptr, - int quorum_count, - int row_index, - cmems_t* out_values_ptr); - -#ifdef __cplusplus -} // extern "C" -#endif \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/cgobinding/zk.cpp b/demos-go/cb-mpc-go/internal/cgobinding/zk.cpp deleted file mode 100644 index b1598aa4..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/zk.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include "zk.h" - -#include -#include - -int zk_dl_prove(ecc_point_ref* Q_ref, cmem_t w_mem, cmem_t sid_mem, uint64_t aux, cmem_t* proof_mem) { - // Deserialize inputs - ecc_point_t* Q = static_cast(Q_ref->opaque); - buf_t sid = coinbase::ffi::view(sid_mem); - bn_t w = bn_t::from_bin(coinbase::ffi::view(w_mem)); - - // Prove - coinbase::zk::uc_dl_t zk; - zk.prove(*Q, w, sid, aux); - - // Serialize proof - buf_t proof = coinbase::ser(zk); - *proof_mem = coinbase::ffi::copy_to_cmem(proof); - - return SUCCESS; -} - -int zk_dl_verify(ecc_point_ref* Q_ref, cmem_t proof_mem, cmem_t sid_mem, uint64_t aux) { - // Deserialize inputs - ecc_point_t* Q = static_cast(Q_ref->opaque); - coinbase::zk::uc_dl_t zk; - buf_t sid = coinbase::ffi::view(sid_mem); - - error_t rv = coinbase::deser(coinbase::ffi::view(proof_mem), zk); - if (rv != SUCCESS) return rv; - - return zk.verify(*Q, sid, aux); -} diff --git a/demos-go/cb-mpc-go/internal/cgobinding/zk.go b/demos-go/cb-mpc-go/internal/cgobinding/zk.go deleted file mode 100644 index 1af72877..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/zk.go +++ /dev/null @@ -1,44 +0,0 @@ -package cgobinding - -import ( - "fmt" -) - -/* -#include -#include "zk.h" -*/ -import "C" - -func ZK_DL_Prove(Q ECCPointRef, w []byte, sessionID []byte, aux uint64) ([]byte, error) { - var proof C.cmem_t - - cErr := C.zk_dl_prove( - (*C.ecc_point_ref)(&Q), - cmem(w), - cmem(sessionID), - C.uint64_t(aux), - &proof, - ) - - if cErr != 0 { - return nil, fmt.Errorf("ZK DL prove failed: %d", int(cErr)) - } - - return CMEMGet(proof), nil -} - -func ZK_DL_Verify(Q ECCPointRef, proof []byte, sessionID []byte, aux uint64) (bool, error) { - cErr := C.zk_dl_verify( - (*C.ecc_point_ref)(&Q), - cmem(proof), - cmem(sessionID), - C.uint64_t(aux), - ) - - if cErr == 0 { - return true, nil - } else { - return false, nil // Not an error, just verification failed - } -} diff --git a/demos-go/cb-mpc-go/internal/cgobinding/zk.h b/demos-go/cb-mpc-go/internal/cgobinding/zk.h deleted file mode 100644 index 3c5f75da..00000000 --- a/demos-go/cb-mpc-go/internal/cgobinding/zk.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include - -#include - -#include "curve.h" - -#ifdef __cplusplus -extern "C" { -#endif - -int zk_dl_prove(ecc_point_ref* Q, cmem_t w_mem, cmem_t sid_mem, uint64_t aux, cmem_t* proof_mem); -int zk_dl_verify(ecc_point_ref* Q, cmem_t proof_mem, cmem_t sid_mem, uint64_t aux); - -#ifdef __cplusplus -} // extern "C" -#endif \ No newline at end of file diff --git a/demos-go/cb-mpc-go/internal/curvemap/nid.go b/demos-go/cb-mpc-go/internal/curvemap/nid.go deleted file mode 100644 index ea682957..00000000 --- a/demos-go/cb-mpc-go/internal/curvemap/nid.go +++ /dev/null @@ -1,29 +0,0 @@ -package curvemap - -import ( - "fmt" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" -) - -// Numeric OpenSSL NIDs for curves supported by cb-mpc-go. -const ( - Secp256k1 = 714 // NID_secp256k1 - P256 = 415 // NID_X9_62_prime256v1 - Ed25519 = 1087 // NID_ED25519 -) - -// CurveForCode converts an OpenSSL NID into a curve.Curve instance. -// Only for internal consumption. -func CurveForCode(code int) (curve.Curve, error) { - switch code { - case Secp256k1: - return curve.NewSecp256k1() - case P256: - return curve.NewP256() - case Ed25519: - return curve.NewEd25519() - default: - return nil, fmt.Errorf("unsupported curve code %d", code) - } -} diff --git a/demos-go/cb-mpc-go/internal/testutil/testutil.go b/demos-go/cb-mpc-go/internal/testutil/testutil.go deleted file mode 100644 index ff762fba..00000000 --- a/demos-go/cb-mpc-go/internal/testutil/testutil.go +++ /dev/null @@ -1,36 +0,0 @@ -package testutil - -import ( - "os" - "syscall" - "testing" -) - -// WithSilencedStderr redirects the process' stderr to /dev/null while fn runs. -func WithSilencedStderr(fn func()) { - savedFD, err := syscall.Dup(2) - if err != nil { - fn() - return - } - devnull, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0) - if err != nil { - _ = syscall.Close(savedFD) - fn() - return - } - defer devnull.Close() - _ = syscall.Dup2(int(devnull.Fd()), 2) - // Ensure stderr is restored even if fn panics - defer func() { - _ = syscall.Dup2(savedFD, 2) - _ = syscall.Close(savedFD) - }() - fn() -} - -// TSilence wraps a test section and silences stderr for its duration. -func TSilence(t *testing.T, fn func(t *testing.T)) { - t.Helper() - WithSilencedStderr(func() { fn(t) }) -} diff --git a/demos-go/cmd/threshold-ecdsa-web/.gitignore b/demos-go/cmd/threshold-ecdsa-web/.gitignore deleted file mode 100644 index 5fd6a1f4..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/.gitignore +++ /dev/null @@ -1,10 +0,0 @@ -*.pem -*.der -*.csr -*.srl -*.json -*.txt - - -demo -tmp/ \ No newline at end of file diff --git a/demos-go/cmd/threshold-ecdsa-web/Makefile b/demos-go/cmd/threshold-ecdsa-web/Makefile deleted file mode 100644 index a16435d1..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/Makefile +++ /dev/null @@ -1,72 +0,0 @@ -.PHONY: generate-ca -generate-ca: - @echo "Generating CA RSA key..." - openssl genpkey -algorithm RSA -out certs/ca-key.pem -pkeyopt rsa_keygen_bits:2048 - @echo "Generating CA certificate..." - openssl req -new -x509 -key certs/ca-key.pem -out certs/ca.pem -days 365 -subj "/C=US/ST=California/L=San Francisco/O=My Company/CN=CA Root Certificate" - -.PHONY: generate-cert -generate-cert: - @echo "Generating RSA key with index $(INDEX)..." - openssl genpkey -algorithm RSA -out certs/party-$(INDEX)/key-$(INDEX).pem -pkeyopt rsa_keygen_bits:2048 - @echo "Generating CSR with index $(INDEX)..." - openssl req -new -key certs/party-$(INDEX)/key-$(INDEX).pem -out certs/party-$(INDEX)/cert-$(INDEX).csr -config certs/party-$(INDEX)/openssl-$(INDEX).cnf - # @echo "Generating self-signed certificate with index $(INDEX)..." - # openssl x509 -req -in certs/party-$(INDEX)/cert-$(INDEX).csr -signkey certs/party-$(INDEX)/key-$(INDEX).pem -out unsigned-cert-$(INDEX).pem -days 365 - @echo "Signing certificate with CA..." - openssl x509 -req -in certs/party-$(INDEX)/cert-$(INDEX).csr -CA certs/ca.pem -CAkey certs/ca-key.pem -CAcreateserial -out certs/party-$(INDEX)/cert-$(INDEX).pem -days 365 -extensions v3_req -extfile certs/party-$(INDEX)/openssl-$(INDEX).cnf - @echo "Certificate signed and saved as certs/party-$(INDEX)/cert-$(INDEX).pem" - @echo "Converting cert-$(INDEX).pem to ASN.1 format..." - openssl x509 -in certs/party-$(INDEX)/cert-$(INDEX).pem -outform DER -out certs/party-$(INDEX)/cert-$(INDEX).der - @echo "Conversion complete. ASN.1 format saved as certs/party-$(INDEX)/cert-$(INDEX).der" - -.PHONY: certs -certs: - make clean-all - mkdir -p certs certs/party-0 certs/party-1 certs/party-2 certs/party-3 - make generate-ca - make generate-cert INDEX=0 - make generate-cert INDEX=1 - make generate-cert INDEX=2 - make generate-cert INDEX=3 - -.PHONY: clean-all -clean-all: - rm -rf certs/*/{*.pem,*.csr,*.srl,*.der} - make clean-logs - rm -f demo - rm -rf keyshare_party* - rm -f threshold.txt - rm -rf tmp - -.PHONY: clean-logs -clean-logs: - rm -rf *.log - -.PHONY: clean-processes -clean-processes: - lsof -ti tcp:8080 | xargs -r kill -TERM - lsof -ti tcp:8081 | xargs -r kill -TERM - lsof -ti tcp:8082 | xargs -r kill -TERM - lsof -ti tcp:8083 | xargs -r kill -TERM - -.PHONY: run-dkg -run-dkg: - make clean-processes - make clean-logs - go run *.go -index=0 -phase=dkg -mode=cli -participants=0,1,2,3 -threshold=3 > dkg-0.log 2>&1 & - sleep 5 && go run *.go -index=1 -phase=dkg -mode=cli -participants=0,1,2,3 -threshold=3 > dkg-1.log 2>&1 & - sleep 10 && go run *.go -index=2 -phase=dkg -mode=cli -participants=0,1,2,3 -threshold=3 > dkg-2.log 2>&1 & - sleep 15 && go run *.go -index=3 -phase=dkg -mode=cli -participants=0,1,2,3 -threshold=3 > dkg-3.log 2>&1 & - -.PHONY: run-sign -run-sign: - make clean-processes - make clean-logs - go run *.go -index=0 -phase=sign -mode=cli -participants=0,2,3 -threshold=3 > sign-0.log 2>&1 & - sleep 5 && go run *.go -index=2 -phase=sign -mode=cli -participants=0,2,3 -threshold=3 > sign-2.log 2>&1 & - sleep 10 && go run *.go -index=3 -phase=sign -mode=cli -participants=0,2,3 -threshold=3 > sign-3.log 2>&1 & - -.PHONY: run-server -run-server: - cd ../../../ && BUILD_TYPE=Release bash scripts/go_with_cpp.sh --no-cd bash -lc "cd 'demos-go/cmd/threshold-ecdsa-web' && env CGO_ENABLED=1 go run *.go -index=$(INDEX)" \ No newline at end of file diff --git a/demos-go/cmd/threshold-ecdsa-web/README.md b/demos-go/cmd/threshold-ecdsa-web/README.md deleted file mode 100644 index 7a65649f..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/README.md +++ /dev/null @@ -1,31 +0,0 @@ -# Web-based Demo - - -For the initial setup, first create the certificates: - -```bash -make clean-all # deletes all existing certificates, do it only if you want a fresh start -make certs # It will ask you some questions about ca root cert -``` - -Next, make sure that the c++ library is compiled and installed. - -```bash -cd ../../.. -make build -sudo make install -cd - -``` - -To run the demo, in four separate terminals, run the following commands: - - Terminal 1: `make run-server INDEX=0` - - Terminal 2: `make run-server INDEX=1` - - Terminal 3: `make run-server INDEX=2` - - Terminal 4: `make run-server INDEX=3` - -And go to the following urls in your browser: - - [127.0.0.1:7080/page/dkg](127.0.0.1:7080/page/dkg) - - [127.0.0.1:7081/page/dkg](127.0.0.1:7081/page/dkg) - - [127.0.0.1:7082/page/dkg](127.0.0.1:7082/page/dkg) - - [127.0.0.1:7083/page/dkg](127.0.0.1:7083/page/dkg) - diff --git a/demos-go/cmd/threshold-ecdsa-web/certs/party-0/openssl-0.cnf b/demos-go/cmd/threshold-ecdsa-web/certs/party-0/openssl-0.cnf deleted file mode 100644 index 2db6fabc..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/certs/party-0/openssl-0.cnf +++ /dev/null @@ -1,23 +0,0 @@ - [ req ] - default_bits = 2048 - distinguished_name = peer0 - req_extensions = req_ext - x509_extensions = v3_req - prompt = no - - [ peer0 ] - countryName = US - stateOrProvinceName = California - localityName = San Francisco - organizationName = My Company - commonName = peerindex0 - - [ req_ext ] - subjectAltName = @alt_names - - [ v3_req ] - subjectAltName = @alt_names - - [ alt_names ] - DNS.1 = peerindex0 - IP.1 = 127.0.0.1 \ No newline at end of file diff --git a/demos-go/cmd/threshold-ecdsa-web/certs/party-1/openssl-1.cnf b/demos-go/cmd/threshold-ecdsa-web/certs/party-1/openssl-1.cnf deleted file mode 100644 index 36154c50..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/certs/party-1/openssl-1.cnf +++ /dev/null @@ -1,23 +0,0 @@ - [ req ] - default_bits = 2048 - distinguished_name = peer1 - req_extensions = req_ext - x509_extensions = v3_req - prompt = no - - [ peer1 ] - countryName = US - stateOrProvinceName = California - localityName = San Francisco - organizationName = My Company - commonName = peerindex1 - - [ req_ext ] - subjectAltName = @alt_names - - [ v3_req ] - subjectAltName = @alt_names - - [ alt_names ] - DNS.1 = peerindex1 - IP.1 = 127.0.0.1 \ No newline at end of file diff --git a/demos-go/cmd/threshold-ecdsa-web/certs/party-2/openssl-2.cnf b/demos-go/cmd/threshold-ecdsa-web/certs/party-2/openssl-2.cnf deleted file mode 100644 index d86f1582..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/certs/party-2/openssl-2.cnf +++ /dev/null @@ -1,23 +0,0 @@ - [ req ] - default_bits = 2048 - distinguished_name = peer2 - req_extensions = req_ext - x509_extensions = v3_req - prompt = no - - [ peer2 ] - countryName = US - stateOrProvinceName = California - localityName = San Francisco - organizationName = My Company - commonName = peerindex2 - - [ req_ext ] - subjectAltName = @alt_names - - [ v3_req ] - subjectAltName = @alt_names - - [ alt_names ] - DNS.1 = peerindex2 - IP.1 = 127.0.0.1 \ No newline at end of file diff --git a/demos-go/cmd/threshold-ecdsa-web/certs/party-3/openssl-3.cnf b/demos-go/cmd/threshold-ecdsa-web/certs/party-3/openssl-3.cnf deleted file mode 100644 index 6725a7a7..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/certs/party-3/openssl-3.cnf +++ /dev/null @@ -1,23 +0,0 @@ - [ req ] - default_bits = 2048 - distinguished_name = peer3 - req_extensions = req_ext - x509_extensions = v3_req - prompt = no - - [ peer3 ] - countryName = US - stateOrProvinceName = California - localityName = San Francisco - organizationName = My Company - commonName = peerindex3 - - [ req_ext ] - subjectAltName = @alt_names - - [ v3_req ] - subjectAltName = @alt_names - - [ alt_names ] - DNS.1 = peerindex3 - IP.1 = 127.0.0.1 \ No newline at end of file diff --git a/demos-go/cmd/threshold-ecdsa-web/config-0.yaml b/demos-go/cmd/threshold-ecdsa-web/config-0.yaml deleted file mode 100644 index 733248f8..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/config-0.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The CA certificate -caFile: "certs/ca.pem" - -# The certificate key and index of the current party -certFile: "certs/party-0/cert-0.pem" -keyFile: "certs/party-0/key-0.pem" -webAddress: "127.0.0.1:7080" - -# The list of parties and their certificates -parties: - - address: "127.0.0.1:8080" - cert: "certs/party-0/cert-0.der" - - address: "127.0.0.1:8081" - cert: "certs/party-1/cert-1.der" - - address: "127.0.0.1:8082" - cert: "certs/party-2/cert-2.der" - - address: "127.0.0.1:8083" - cert: "certs/party-3/cert-3.der" diff --git a/demos-go/cmd/threshold-ecdsa-web/config-1.yaml b/demos-go/cmd/threshold-ecdsa-web/config-1.yaml deleted file mode 100644 index 3f599d43..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/config-1.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The CA certificate -caFile: "certs/ca.pem" - -# The certificate key and index of the current party -certFile: "certs/party-1/cert-1.pem" -keyFile: "certs/party-1/key-1.pem" -webAddress: "127.0.0.1:7081" - -# The list of parties and their certificates -parties: - - address: "127.0.0.1:8080" - cert: "certs/party-0/cert-0.der" - - address: "127.0.0.1:8081" - cert: "certs/party-1/cert-1.der" - - address: "127.0.0.1:8082" - cert: "certs/party-2/cert-2.der" - - address: "127.0.0.1:8083" - cert: "certs/party-3/cert-3.der" diff --git a/demos-go/cmd/threshold-ecdsa-web/config-2.yaml b/demos-go/cmd/threshold-ecdsa-web/config-2.yaml deleted file mode 100644 index 5423dd1e..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/config-2.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The CA certificate -caFile: "certs/ca.pem" - -# The certificate key and index of the current party -certFile: "certs/party-2/cert-2.pem" -keyFile: "certs/party-2/key-2.pem" -webAddress: "127.0.0.1:7082" - -# The list of parties and their certificates -parties: - - address: "127.0.0.1:8080" - cert: "certs/party-0/cert-0.der" - - address: "127.0.0.1:8081" - cert: "certs/party-1/cert-1.der" - - address: "127.0.0.1:8082" - cert: "certs/party-2/cert-2.der" - - address: "127.0.0.1:8083" - cert: "certs/party-3/cert-3.der" diff --git a/demos-go/cmd/threshold-ecdsa-web/config-3.yaml b/demos-go/cmd/threshold-ecdsa-web/config-3.yaml deleted file mode 100644 index 3ff02ce6..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/config-3.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# The CA certificate -caFile: "certs/ca.pem" - -# The certificate key and index of the current party -certFile: "certs/party-3/cert-3.pem" -keyFile: "certs/party-3/key-3.pem" -webAddress: "127.0.0.1:7083" - -# The list of parties and their certificates -parties: - - address: "127.0.0.1:8080" - cert: "certs/party-0/cert-0.der" - - address: "127.0.0.1:8081" - cert: "certs/party-1/cert-1.der" - - address: "127.0.0.1:8082" - cert: "certs/party-2/cert-2.der" - - address: "127.0.0.1:8083" - cert: "certs/party-3/cert-3.der" diff --git a/demos-go/cmd/threshold-ecdsa-web/go.mod b/demos-go/cmd/threshold-ecdsa-web/go.mod deleted file mode 100644 index edc17edb..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/go.mod +++ /dev/null @@ -1,36 +0,0 @@ -module github.com/coinbase/cb-mpc/demo-runner - -go 1.24.1 - -replace github.com/coinbase/cb-mpc/demos-go/cb-mpc-go => ../../cb-mpc-go - -require ( - github.com/coinbase/cb-mpc/demos-go/cb-mpc-go v0.0.0-20250616220207-ed1310e03545 - github.com/labstack/echo/v4 v4.13.4 - github.com/spf13/viper v1.20.1 -) - -require ( - github.com/fsnotify/fsnotify v1.8.0 // indirect - github.com/go-viper/mapstructure/v2 v2.4.0 // indirect - github.com/labstack/gommon v0.4.2 // indirect - github.com/mattn/go-colorable v0.1.14 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/pelletier/go-toml/v2 v2.2.3 // indirect - github.com/sagikazarmark/locafero v0.7.0 // indirect - github.com/sourcegraph/conc v0.3.0 // indirect - github.com/spf13/afero v1.12.0 // indirect - github.com/spf13/cast v1.7.1 // indirect - github.com/spf13/pflag v1.0.6 // indirect - github.com/subosito/gotenv v1.6.0 // indirect - github.com/valyala/bytebufferpool v1.0.0 // indirect - github.com/valyala/fasttemplate v1.2.2 // indirect - go.uber.org/atomic v1.9.0 // indirect - go.uber.org/multierr v1.9.0 // indirect - golang.org/x/crypto v0.45.0 // indirect - golang.org/x/net v0.47.0 // indirect - golang.org/x/sync v0.18.0 // indirect - golang.org/x/sys v0.38.0 // indirect - golang.org/x/text v0.31.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) diff --git a/demos-go/cmd/threshold-ecdsa-web/go.sum b/demos-go/cmd/threshold-ecdsa-web/go.sum deleted file mode 100644 index 8b134470..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/go.sum +++ /dev/null @@ -1,71 +0,0 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= -github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= -github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= -github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= -github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= -github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/labstack/echo/v4 v4.13.4 h1:oTZZW+T3s9gAu5L8vmzihV7/lkXGZuITzTQkTEhcXEA= -github.com/labstack/echo/v4 v4.13.4/go.mod h1:g63b33BZ5vZzcIUF8AtRH40DrTlXnx4UMC8rBdndmjQ= -github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= -github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= -github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= -github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= -github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo= -github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k= -github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= -github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= -github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= -github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4= -github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= -github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= -github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4= -github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= -github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= -github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= -github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= -github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= -go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= -go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= -golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= -golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= -golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= -golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= -golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= -golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= -golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= -golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/demos-go/cmd/threshold-ecdsa-web/handlers.go b/demos-go/cmd/threshold-ecdsa-web/handlers.go deleted file mode 100644 index 7b506533..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/handlers.go +++ /dev/null @@ -1,337 +0,0 @@ -package main - -import ( - "crypto/sha256" - "crypto/tls" - "crypto/x509" - "encoding/asn1" - "encoding/base64" - "fmt" - "math/big" - "os" - "path/filepath" - "strconv" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/mpc" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mtls" -) - -const LEADER_INDEX = 0 - -type KeyShareData struct { - KeyShare [][]byte `json:"keyShare"` - PartyName string `json:"partyName"` -} - -func runThresholdDKG(partyIndex int, quorumCount int, allPNameList []string, ac *mpc.AccessStructure, curve curve.Curve, messenger transport.Messenger) (*mpc.ECDSAMPCKey, error) { - totalPartyCount := len(allPNameList) // since dkg involves all parties - job, err := mpc.NewJobMP(messenger, totalPartyCount, partyIndex, allPNameList) - if err != nil { - return nil, fmt.Errorf("failed to create job: %v", err) - } - defer job.Free() - - if quorumCount < 2 { - return nil, fmt.Errorf("threshold must be at least 1") - } - if quorumCount > totalPartyCount { - return nil, fmt.Errorf("threshold must be less than or equal to the number of online parties") - } - - sid := make([]byte, 0) // empty sid means that the dkg will generate it internally - keyShare, err := mpc.ECDSAMPCThresholdDKG(job, &mpc.ECDSAMPCThresholdDKGRequest{ - Curve: curve, SessionID: sid, AccessStructure: ac, - }) - if err != nil { - return nil, fmt.Errorf("threshold DKG failed: %v", err) - } - return &keyShare.KeyShare, nil -} - -func runThresholdSign(keyShareRef *mpc.ECDSAMPCKey, ac *mpc.AccessStructure, partyIndex int, quorumCount int, quorumPNames []string, inputMessage []byte, messenger transport.Messenger) ([]byte, []byte, error) { - keyShare := *keyShareRef - - job, err := mpc.NewJobMP(messenger, quorumCount, partyIndex, quorumPNames) - if err != nil { - return nil, nil, fmt.Errorf("failed to create job: %v", err) - } - defer job.Free() - - if quorumCount != len(quorumPNames) { - return nil, nil, fmt.Errorf("quorum count does not match the number of participants") - } - - hashedMessage := sha256.Sum256(inputMessage) - message := hashedMessage[:] - - additiveShare, err := keyShare.ToAdditiveShare(ac, quorumPNames) - if err != nil { - return nil, nil, fmt.Errorf("converting to additive share: %v", err) - } - - sigResponse, err := mpc.ECDSAMPCSign(job, &mpc.ECDSAMPCSignRequest{ - KeyShare: additiveShare, - Message: message, - SignatureReceiver: LEADER_INDEX, - }) - if err != nil { - return nil, nil, fmt.Errorf("signing failed: %v", err) - } - - derSig := []byte{} - pemKey := []byte{} - if partyIndex == LEADER_INDEX { - derSig, err = createDERSignature(sigResponse.Signature) - if err != nil { - return nil, nil, fmt.Errorf("creating DER signature: %v", err) - } - - pemKey, err = createPEMPublicKey(&additiveShare) - if err != nil { - return nil, nil, fmt.Errorf("creating PEM public key: %v", err) - } - - } - return derSig, pemKey, nil -} - -func loadPartyConfig(certPEM []byte, partyAddress string) (mtls.PartyConfig, error) { - cert, err := x509.ParseCertificate(certPEM) - if err != nil { - return mtls.PartyConfig{}, fmt.Errorf("failed to parse expected server cert: %v", err) - } - return mtls.PartyConfig{ - Address: partyAddress, - Cert: cert, - }, nil -} - -func setupTransport(config Config, partyIndex int, participantsIndices map[int]bool) (int, string, []string, []string, transport.Messenger, error) { - cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile) - if err != nil { - return 0, "", nil, nil, nil, fmt.Errorf("loading key pair %s and %s: %v", config.CertFile, config.KeyFile, err) - } - - caCert, err := os.ReadFile(config.CaFile) - if err != nil { - return 0, "", nil, nil, nil, fmt.Errorf("reading CA cert %s: %v", config.CaFile, err) - } - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) - - allPNames := make([]string, 0) - participantPNames := make([]string, 0) - - myPname := "" - myNetworkPartyIndex := 0 - networkPartyIndex := 0 - networkParties := make(map[int]mtls.PartyConfig) - nameToIndex := make(map[string]int) - for i, party := range config.Parties { - certPEM, err := os.ReadFile(party.Cert) - if err != nil { - return 0, "", nil, nil, nil, fmt.Errorf("failed to read expected server cert: %v", err) - } - networkParty, err := loadPartyConfig(certPEM, party.Address) - if err != nil { - return 0, "", nil, nil, nil, fmt.Errorf("loading party config: %v", err) - } - pname, err := mtls.PartyNameFromCertificate(networkParty.Cert) - if err != nil { - return 0, "", nil, nil, nil, fmt.Errorf("extracting pname from cert: %v", err) - } - allPNames = append(allPNames, pname) - - if i == partyIndex { - myNetworkPartyIndex = networkPartyIndex - myPname = pname - } - if _, ok := participantsIndices[i]; ok { - participantPNames = append(participantPNames, pname) - networkParties[networkPartyIndex] = networkParty - nameToIndex[pname] = networkPartyIndex - networkPartyIndex++ - } - } - - transport, err := mtls.NewMTLSMessenger(mtls.Config{ - Parties: networkParties, - CertPool: caCertPool, - TLSCert: cert, - SelfIndex: myNetworkPartyIndex, - NameToIndex: nameToIndex, - }) - if err != nil { - return 0, "", nil, nil, nil, fmt.Errorf("failed to create transport: %v", err) - } - fmt.Printf("transport:\n") - fmt.Printf(" - myPname: %s\n", myPname) - fmt.Printf(" - myIndex: %d\n", myNetworkPartyIndex) - fmt.Printf(" - networkParties: %+v\n", networkParties) - - fmt.Println("MTLSDataTransport initialized successfully") - return myNetworkPartyIndex, myPname, allPNames, participantPNames, transport, nil -} - -func createThresholdAccessStructure(pnameList []string, threshold int, curve curve.Curve) mpc.AccessStructure { - root := mpc.Threshold("", threshold) - for _, pname := range pnameList { - child := mpc.Leaf(pname) - root.Children = append(root.Children, child) - } - - ac := mpc.AccessStructure{ - Root: root, - Curve: curve, - } - return ac -} - -func saveKeyShare(keyShare *mpc.ECDSAMPCKey, partyName string) error { - ser, err := keyShare.MarshalBinary() - if err != nil { - return fmt.Errorf("serializing key share: %v", err) - } - - filename := fmt.Sprintf("keyshare_party_%s.json", partyName) - if err := os.WriteFile(filename, ser, 0600); err != nil { - return fmt.Errorf("writing key file: %v", err) - } - - return nil -} - -func saveThreshold(threshold int) error { - if err := os.WriteFile("threshold.txt", fmt.Appendf(nil, "%d", threshold), 0600); err != nil { - return fmt.Errorf("writing key file: %v", err) - } - - return nil -} - -func loadThreshold() (int, error) { - data, err := os.ReadFile("threshold.txt") - if err != nil { - return 0, fmt.Errorf("reading threshold file: %v", err) - } - return strconv.Atoi(string(data)) -} - -func removeKeyShares() { - files, err := filepath.Glob("keyshare_party_*.json") - if err != nil { - fmt.Printf("finding key share files: %v\n", err) - } - for _, file := range files { - if err := os.Remove(file); err != nil { - fmt.Printf("removing key share file %s: %v\n", file, err) - } - } -} - -func loadKeyShare(partyName string) (*mpc.ECDSAMPCKey, error) { - filename := fmt.Sprintf("keyshare_party_%s.json", partyName) - data, err := os.ReadFile(filename) - if err != nil { - return nil, fmt.Errorf("reading key file: %v", err) - } - - keyShare := mpc.ECDSAMPCKey{} - if err := keyShare.UnmarshalBinary(data); err != nil { - return nil, fmt.Errorf("unmarshaling key data: %v", err) - } - - return &keyShare, nil -} - -// createOTRoleMap creates a default OT role map for a given number of parties -func createOTRoleMap(nParties int) [][]int { - const ( - OT_NO_ROLE = -1 - OT_SENDER = 0 - OT_RECEIVER = 1 - ) - - otRoleMap := make([][]int, nParties) - for i := 0; i < nParties; i++ { - otRoleMap[i] = make([]int, nParties) - otRoleMap[i][i] = OT_NO_ROLE - } - - for i := 0; i < nParties; i++ { - for j := i + 1; j < nParties; j++ { - otRoleMap[i][j] = OT_SENDER - otRoleMap[j][i] = OT_RECEIVER - } - } - - return otRoleMap -} - -func createPEMPublicKey(key *mpc.ECDSAMPCKey) ([]byte, error) { - Q, err := key.Q() - if err != nil { - return nil, fmt.Errorf("extracting public key: %v", err) - } - pubKeyX, pubKeyY := Q.GetX(), Q.GetY() - if err != nil { - return nil, fmt.Errorf("extracting public key: %v", err) - } - - // Create uncompressed public key (0x04 prefix + X + Y) - pubKeyBytes := make([]byte, 1+len(pubKeyX)+len(pubKeyY)) - pubKeyBytes[0] = 0x04 - copy(pubKeyBytes[1:], pubKeyX) - copy(pubKeyBytes[1+len(pubKeyX):], pubKeyY) - - // Create ASN.1 structure - pubKeyInfo := publicKeyInfo{ - Algorithm: algorithmIdentifier{ - Algorithm: ecPublicKeyOID, - Parameters: secp256k1OID, - }, - PublicKey: asn1.BitString{ - Bytes: pubKeyBytes, - BitLength: len(pubKeyBytes) * 8, - }, - } - - derBytes, err := asn1.Marshal(pubKeyInfo) - if err != nil { - return nil, fmt.Errorf("marshaling public key: %v", err) - } - - pemData := "-----BEGIN PUBLIC KEY-----\n" - b64Data := base64.StdEncoding.EncodeToString(derBytes) - for i := 0; i < len(b64Data); i += 64 { - end := i + 64 - if end > len(b64Data) { - end = len(b64Data) - } - pemData += b64Data[i:end] + "\n" - } - pemData += "-----END PUBLIC KEY-----\n" - - return []byte(pemData), nil -} - -func createDERSignature(signature []byte) ([]byte, error) { - // Parse signature - assuming it's already in DER format or raw r,s format - // If it's 64 bytes, it's likely raw r,s (32 bytes each) - if len(signature) == 64 { - // Raw format: first 32 bytes = r, next 32 bytes = s - r := new(big.Int).SetBytes(signature[:32]) - s := new(big.Int).SetBytes(signature[32:]) - - sig := ecdsaSignature{R: r, S: s} - derBytes, err := asn1.Marshal(sig) - if err != nil { - return nil, fmt.Errorf("marshaling signature: %v", err) - } - return derBytes, nil - } else { - return signature, nil - } -} diff --git a/demos-go/cmd/threshold-ecdsa-web/main.go b/demos-go/cmd/threshold-ecdsa-web/main.go deleted file mode 100644 index c325f145..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/main.go +++ /dev/null @@ -1,127 +0,0 @@ -package main - -import ( - "encoding/asn1" - "flag" - "fmt" - "log" - "math/big" - "strconv" - "strings" - - "github.com/spf13/viper" -) - -// ASN.1 structures for OpenSSL compatibility -type ecdsaSignature struct { - R *big.Int - S *big.Int -} - -type publicKeyInfo struct { - Algorithm algorithmIdentifier - PublicKey asn1.BitString -} - -type algorithmIdentifier struct { - Algorithm asn1.ObjectIdentifier - Parameters asn1.ObjectIdentifier `asn1:"optional"` -} - -// secp256k1 OID: 1.3.132.0.10 -var secp256k1OID = asn1.ObjectIdentifier{1, 3, 132, 0, 10} - -// ecPublicKey OID: 1.2.840.10045.2.1 -var ecPublicKeyOID = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} - -type PartyConfig struct { - Address string `yaml:"address"` - Cert string `yaml:"cert"` -} - -type Config struct { - CaFile string `yaml:"caFile"` - CertFile string `yaml:"certFile"` - WebAddress string `yaml:"webAddress"` - KeyFile string `yaml:"keyFile"` - Parties []PartyConfig `yaml:"parties"` -} - -type RunConfig struct { - Config Config - ParticipantsIndices map[int]bool - Phase string // dkg or sign - Threshold int // threshold for DKG or signing - MyIndex int // my index in the list of participants -} - -func readConfig() (*RunConfig, error) { - var configFile string - var participants string - var phase string - var threshold int - var myIndex int - - flag.StringVar(&configFile, "config", "", "path to config file") - flag.StringVar(&participants, "participants", "", "comma-separated list of participant indices") - flag.StringVar(&phase, "phase", "", "phase to run: agree-random or dkg or sign") - flag.IntVar(&threshold, "threshold", 3, "threshold for DKG or signing") - flag.IntVar(&myIndex, "index", 0, "my index in the list of participants") - flag.Parse() - - if configFile == "" { - configFile = fmt.Sprintf("config-%d.yaml", myIndex) - } - if strings.HasSuffix(configFile, ".yaml") { - configFile = configFile[:len(configFile)-5] - } - - viper.SetConfigName(configFile) - viper.SetConfigType("yaml") - viper.AddConfigPath(".") - - if err := viper.ReadInConfig(); err != nil { - return nil, fmt.Errorf("Error reading config file, %s", err) - } - - var config Config - if err := viper.Unmarshal(&config); err != nil { - return nil, fmt.Errorf("unable to decode into struct, %v", err) - } - - participantsIndices := make(map[int]bool) - if participants != "" { - tokens := strings.SplitSeq(participants, ",") - for token := range tokens { - participantIndex, err := strconv.Atoi(token) - if err != nil { - return nil, fmt.Errorf("invalid participant index: %v", err) - } - participantsIndices[participantIndex] = true - } - } else { - for i := range len(config.Parties) { - participantsIndices[i] = true - } - } - fmt.Printf("Running with %d total parties\n", len(participantsIndices)) - - return &RunConfig{ - Config: config, - ParticipantsIndices: participantsIndices, - Phase: phase, - Threshold: threshold, - MyIndex: myIndex, - }, nil -} - -func main() { - runConfig, err := readConfig() - if err != nil { - log.Fatalf("Error reading config: %v", err) - } - - if err := main_web(runConfig); err != nil { - log.Fatalf("Error running web: %v", err) - } -} diff --git a/demos-go/cmd/threshold-ecdsa-web/templates/dkg_base.html b/demos-go/cmd/threshold-ecdsa-web/templates/dkg_base.html deleted file mode 100644 index 0794ac7f..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/templates/dkg_base.html +++ /dev/null @@ -1,96 +0,0 @@ - - - - {{.Title}} - - - - - - -
- - - -
-
-

Distributed Key Generation with configurable threshold

-
- -
-

Participating Parties ({{len .Parties}})

-
- {{range $index, $party := .Parties}} - - Party {{$index}}: {{$party.Address}} - - {{end}} -
-
- -
-
-
- -
-
-
- -
-
-
-
- - - -
-

Connecting...

-

Establishing MTLS connections with all parties

- Connecting -
-
-
-
-
-
- - \ No newline at end of file diff --git a/demos-go/cmd/threshold-ecdsa-web/templates/dkg_connection_success.html b/demos-go/cmd/threshold-ecdsa-web/templates/dkg_connection_success.html deleted file mode 100644 index 3945ce3b..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/templates/dkg_connection_success.html +++ /dev/null @@ -1,47 +0,0 @@ -
-

- Connection Established -

-

Connection time: {{.ConnectionTime}}

-
- -
-
DKG Configuration
-
- -
- -
-

Number of parties required to reconstruct the key (max: {{.MaxThreshold}})

-
- -
-
- -
-
-
- -
-
-
- - - -
-

Running DKG...

-

Generating distributed keys, this may take a few moments

- Processing -
-
\ No newline at end of file diff --git a/demos-go/cmd/threshold-ecdsa-web/templates/dkg_connection_waiting.html b/demos-go/cmd/threshold-ecdsa-web/templates/dkg_connection_waiting.html deleted file mode 100644 index b27600a1..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/templates/dkg_connection_waiting.html +++ /dev/null @@ -1,149 +0,0 @@ -
-

- Connection Established -

-

Connection time: {{.ConnectionTime}}

-
- -
-
-
- - - -
-

Waiting for Party 0

-

Party 0 will initiate the DKG process. Please wait...

- Waiting - -
-

- - You will automatically participate in DKG once Party 0 starts the process. -

-
-
-
- -
-
-
- - - -
-

DKG Started!

-

Participating in distributed key generation...

- Processing -
-
- - \ No newline at end of file diff --git a/demos-go/cmd/threshold-ecdsa-web/templates/dkg_result.html b/demos-go/cmd/threshold-ecdsa-web/templates/dkg_result.html deleted file mode 100644 index 63317e94..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/templates/dkg_result.html +++ /dev/null @@ -1,14 +0,0 @@ -
- Success! DKG completed -

- {{if not .IsParty0}} - Execution Time: -
{{.ConnectionTime}}
-

- {{end}} - X Share: -
{{.XShare}}
-

- PEM Key: -
{{.PemKey}}
-
\ No newline at end of file diff --git a/demos-go/cmd/threshold-ecdsa-web/templates/error.html b/demos-go/cmd/threshold-ecdsa-web/templates/error.html deleted file mode 100644 index d2fb7edb..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/templates/error.html +++ /dev/null @@ -1,3 +0,0 @@ -
- Error: {{.Message}} -
\ No newline at end of file diff --git a/demos-go/cmd/threshold-ecdsa-web/templates/signing_immediate_waiting.html b/demos-go/cmd/threshold-ecdsa-web/templates/signing_immediate_waiting.html deleted file mode 100644 index 5026255d..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/templates/signing_immediate_waiting.html +++ /dev/null @@ -1,228 +0,0 @@ - - - - Threshold Signing - - - - - - -
- - -
-
-

- You are Party {{.CurrentParty}} -

-

Party 0 will configure the signing parameters and select participants. Please wait...

-
- -
-
-
- - - -
-

Waiting for Party 0

-

Party 0 will select participants, configure threshold, and provide a message to sign.

- Waiting - -
-

- - You will be notified if you're selected to participate in the signing process. -

-
-
-
- - - -
-
-
- - - -
-

Signing Started!

-

Establishing connections and participating in threshold signing...

- Processing -
-
- -
-
-
- - - - \ No newline at end of file diff --git a/demos-go/cmd/threshold-ecdsa-web/templates/signing_leader_interface.html b/demos-go/cmd/threshold-ecdsa-web/templates/signing_leader_interface.html deleted file mode 100644 index 4d33ef1b..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/templates/signing_leader_interface.html +++ /dev/null @@ -1,237 +0,0 @@ - - - - {{.Title}} - - - - - - -
- - -
-
-

Configure and initiate threshold signature generation

-
- -
-

- You are Party 0 (Leader) -

-

Configure the signing parameters and initiate the process. Other parties are waiting for your decisions.

-
- -
-

Available Parties ({{len .Parties}})

-
- {{range $index, $party := .Parties}} - - Party {{$index}}: {{$party.Address}} - - {{end}} -
-
- -
-
- -
- -
-
- -
- -
-
- - - - -
-
- -
- -
- -
- -
-
- -
-
- -
-
-
- -
-
-
- - - -
-

Establishing Connections & Signing...

-

Setting up secure connections with selected parties and executing threshold signature

- Processing -
-
- -
-
-
- - - - \ No newline at end of file diff --git a/demos-go/cmd/threshold-ecdsa-web/templates/signing_result.html b/demos-go/cmd/threshold-ecdsa-web/templates/signing_result.html deleted file mode 100644 index eaa8ae8d..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/templates/signing_result.html +++ /dev/null @@ -1,33 +0,0 @@ -{{if eq .PartyIndex 0}} -
-
-
Signature (DER Format)
-
-
- -
-

Base64-encoded DER signature compatible with OpenSSL

-
-
- -
-
Public Key (PEM Format)
-
-
- -
-

PEM-formatted public key for signature verification

-
-
-
- -{{else}} -
-

- Signing Complete -

-

Execution Time: {{.ConnectionTime}}

-

This party participated in the threshold signing process.

-

Signature and public key details are only displayed for Party 0.

-
-{{end}} \ No newline at end of file diff --git a/demos-go/cmd/threshold-ecdsa-web/web.go b/demos-go/cmd/threshold-ecdsa-web/web.go deleted file mode 100644 index bf8987b9..00000000 --- a/demos-go/cmd/threshold-ecdsa-web/web.go +++ /dev/null @@ -1,741 +0,0 @@ -package main - -import ( - "encoding/base64" - "fmt" - "html/template" - "log" - "net/http" - "sort" - "strconv" - "strings" - "sync" - "time" - - "slices" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mtls" - "github.com/labstack/echo/v4" -) - -type MockDB struct { - dkg *DKGCoordinationState - signing *SigningCoordinationState -} - -// Global DKG coordination state - -type DKGCoordinationState struct { - mutex sync.RWMutex - initiated bool - threshold int - initiatedAt time.Time - completionFlags map[int]bool -} - -// Global Signing coordination state - -type SigningCoordinationState struct { - mutex sync.RWMutex - initiated bool - threshold int - selectedParties []int - message string - initiatedAt time.Time - completionFlags map[int]bool -} - -type PageData struct { - Title string - Parties []PartyConfig -} - -type SigningPageData struct { - Title string - Parties []PartyConfig - CurrentParty int - Threshold int - MaxOtherParties int -} - -type ConnectionData struct { - ConnectionTime string - MaxThreshold int - Party0BaseUrl string -} - -type SigningConnectionData struct { - ConnectionTime string - TotalParties int - Threshold int - ConnectedParties []int - Party0BaseUrl string -} - -type SigningWaitingData struct { - ConnectionTime string - Party0BaseUrl string - AllParties []PartyConfig - CurrentParty int -} - -type SigningCoordinationData struct { - Initiated bool - Threshold int - SelectedParties []int - Message string - ShouldParticipate bool -} - -type DKGResultData struct { - IsParty0 bool - ConnectionTime string - XShare string - PemKey string -} - -type SigningResultData struct { - ConnectionTime string - Message string - SignatureBase64 string - PublicKey string - PartyIndex int -} - -type ErrorData struct { - Message string -} - -var templates *template.Template - -var db = &MockDB{ - dkg: &DKGCoordinationState{ - completionFlags: make(map[int]bool), - }, - signing: &SigningCoordinationState{ - completionFlags: make(map[int]bool), - }, -} - -func (d *DKGCoordinationState) Set(initiated bool, threshold int, initiatedAt time.Time) { - d.mutex.Lock() - d.initiated = initiated - d.threshold = threshold - d.initiatedAt = initiatedAt - d.completionFlags = make(map[int]bool) - d.mutex.Unlock() -} - -func (s *SigningCoordinationState) Set(initiated bool, threshold int, selectedParties []int, message string, initiatedAt time.Time) { - s.mutex.Lock() - s.initiated = initiated - s.threshold = threshold - s.selectedParties = selectedParties - s.message = message - s.initiatedAt = initiatedAt - s.completionFlags = make(map[int]bool) - s.mutex.Unlock() -} - -func (d *DKGCoordinationState) SetCompletionFlag(partyIndex int, completed bool) { - d.mutex.Lock() - d.completionFlags[partyIndex] = completed - d.mutex.Unlock() -} - -func (s *SigningCoordinationState) SetCompletionFlag(partyIndex int, completed bool) { - s.mutex.Lock() - s.completionFlags[partyIndex] = completed - s.mutex.Unlock() -} - -func closeExistingConnections(dkgTransport transport.Messenger, signingTransport transport.Messenger) { - fmt.Printf("Closing existing connections...\n") - - hadConnections := false - - if dkgTransport != nil { - hadConnections = true - if err := dkgTransport.(*mtls.MTLSMessenger).Close(); err != nil { - fmt.Printf("Error closing DKG transport: %v\n", err) - } - dkgTransport = nil - } - - if signingTransport != nil { - hadConnections = true - if err := signingTransport.(*mtls.MTLSMessenger).Close(); err != nil { - fmt.Printf("Error closing signing transport: %v\n", err) - } - signingTransport = nil - } - - if hadConnections { - fmt.Printf("Waiting briefly for ports to be released...\n") - time.Sleep(100 * time.Millisecond) - } - - fmt.Printf("Connection cleanup completed\n") -} - -func loadTemplates() error { - var err error - templates, err = template.ParseGlob("templates/*.html") - if err != nil { - return fmt.Errorf("failed to parse templates: %v", err) - } - return nil -} - -func renderTemplate(name string, data interface{}) (string, error) { - var buf strings.Builder - err := templates.ExecuteTemplate(&buf, name, data) - if err != nil { - return "", fmt.Errorf("failed to execute template %s: %v", name, err) - } - return buf.String(), nil -} - -func renderError(message string) (string, error) { - return renderTemplate("error.html", ErrorData{Message: message}) -} - -func dkg(dkgPartyIndex int, dkgPartyName string, dkgAllPNames []string, dkgTransport transport.Messenger, threshold int, curve curve.Curve) (time.Duration, []byte, []byte, error) { - ac := createThresholdAccessStructure(dkgAllPNames, threshold, curve) - - startTime := time.Now() - keyShare, err := runThresholdDKG(dkgPartyIndex, threshold, dkgAllPNames, &ac, curve, dkgTransport) - if err != nil { - return 0, nil, nil, fmt.Errorf("threshold DKG failed: %v", err) - } - duration := time.Since(startTime) - - if err := saveKeyShare(keyShare, dkgPartyName); err != nil { - return 0, nil, nil, fmt.Errorf("saving key share: %v", err) - } - - if err := saveThreshold(threshold); err != nil { - return 0, nil, nil, fmt.Errorf("saving threshold: %v", err) - } - - pemKey, err := createPEMPublicKey(keyShare) - if err != nil { - return 0, nil, nil, fmt.Errorf("creating PEM public key: %v", err) - } - xShare, err := keyShare.XShare() - if err != nil { - return 0, nil, nil, fmt.Errorf("extracting public key: %v", err) - } - return duration, pemKey, []byte(xShare.String()), nil -} - -func main_web(runConfig *RunConfig) error { - e := echo.New() - - // Enable CORS to allow cross-origin requests between party instances - e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - c.Response().Header().Set("Access-Control-Allow-Origin", "*") - c.Response().Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE") - c.Response().Header().Set("Access-Control-Allow-Headers", "*") - return next(c) - } - }) - - // These will be populated in one endpoint call and used in another. - // Proper implementation would use a session to store most of them in db. - var err error - var dkgTransport transport.Messenger - var dkgPartyIndex int - var dkgPartyName string - var dkgAllPNames []string - - var signingTransport transport.Messenger - var signingPartyIndex int - var signingPartyName string - var signingAllPNames []string - var signingParticipantPNames []string - - curve, err := curve.NewSecp256k1() - if err != nil { - return fmt.Errorf("creating curve: %v", err) - } - - defer func() { - closeExistingConnections(dkgTransport, signingTransport) - }() - - if err := loadTemplates(); err != nil { - return fmt.Errorf("failed to load templates: %v", err) - } - - // The initial dkg page - - e.GET("/page/dkg", func(c echo.Context) error { - // Clean up any existing connections when loading the DKG page - closeExistingConnections(dkgTransport, signingTransport) - - db.dkg.Set(false, 0, time.Time{}) - // removeKeyShares() - - dkgPartyIndex = 0 - dkgPartyName = "" - dkgAllPNames = nil - - signingPartyIndex = 0 - signingPartyName = "" - signingAllPNames = nil - signingParticipantPNames = nil - - pageData := PageData{ - Title: "Threshold DKG", - Parties: runConfig.Config.Parties, - } - - html, err := renderTemplate("dkg_base.html", pageData) - if err != nil { - log.Printf("Error rendering template: %v", err) - return c.String(http.StatusInternalServerError, "Template error") - } - - return c.HTML(http.StatusOK, html) - }) - - // HTMX endpoint to process connect request from the dkg page - - e.GET("/api/dkg/connect", func(c echo.Context) error { - startTime := time.Now() - - dkgPartyIndex, dkgPartyName, dkgAllPNames, _, dkgTransport, err = setupTransport(runConfig.Config, runConfig.MyIndex, runConfig.ParticipantsIndices) - if err != nil { - return fmt.Errorf("setting up transport: %v", err) - } - - duration := time.Since(startTime) - - connectionData := ConnectionData{ - ConnectionTime: duration.Round(time.Millisecond).String(), - MaxThreshold: len(dkgAllPNames), - Party0BaseUrl: "http://127.0.0.1:7080", - } - - // Party 0 gets the control interface, others get waiting interface - var templateName string - if runConfig.MyIndex == 0 { - templateName = "dkg_connection_success.html" - } else { - templateName = "dkg_connection_waiting.html" - } - - html, err := renderTemplate(templateName, connectionData) - if err != nil { - log.Printf("Error rendering connection template: %v", err) - errorHtml, _ := renderError("Template rendering failed") - return c.HTML(http.StatusInternalServerError, errorHtml) - } - - return c.HTML(http.StatusOK, html) - }) - - // HTMX endpoint to process execute request from the dkg page - - e.GET("/api/dkg/execute", func(c echo.Context) error { - // Only party 0 should be able to call this endpoint directly - if runConfig.MyIndex != 0 { - errorHtml, _ := renderError("Only party 0 can initiate DKG") - return c.HTML(http.StatusForbidden, errorHtml) - } - - threshold := c.QueryParam("threshold") - if threshold == "" { - errorHtml, _ := renderError("Threshold parameter is required") - return c.HTML(http.StatusBadRequest, errorHtml) - } - - thresholdInt, err := strconv.Atoi(threshold) - if err != nil { - errorHtml, _ := renderError("Invalid threshold value") - return c.HTML(http.StatusBadRequest, errorHtml) - } - - if thresholdInt < 1 || thresholdInt > len(dkgAllPNames) { - errorHtml, _ := renderError(fmt.Sprintf("Threshold must be between 1 and %d", len(dkgAllPNames))) - return c.HTML(http.StatusBadRequest, errorHtml) - } - - db.dkg.Set(true, thresholdInt, time.Now()) - - duration, pemKey, xShare, err := dkg(dkgPartyIndex, dkgPartyName, dkgAllPNames, dkgTransport, thresholdInt, curve) - if err != nil { - return fmt.Errorf("dkg failed: %v", err) - } - - db.dkg.SetCompletionFlag(runConfig.MyIndex, true) - - resultData := DKGResultData{ - IsParty0: runConfig.MyIndex == 0, - ConnectionTime: duration.Round(time.Millisecond).String(), - XShare: base64.StdEncoding.EncodeToString(xShare), - PemKey: string(pemKey), - } - - html, err := renderTemplate("dkg_result.html", resultData) - if err != nil { - log.Printf("Error rendering DKG result template: %v", err) - errorHtml, _ := renderError("Template rendering failed") - return c.HTML(http.StatusInternalServerError, errorHtml) - } - - return c.HTML(http.StatusOK, html) - }) - - // Polling endpoint for non-party-0 participants to check if DKG has been initiated - - e.GET("/api/dkg/poll", func(c echo.Context) error { - db.dkg.mutex.RLock() - initiated := db.dkg.initiated - threshold := db.dkg.threshold - db.dkg.mutex.RUnlock() - - if initiated { - return c.JSON(http.StatusOK, map[string]interface{}{ - "initiated": true, - "threshold": threshold, - }) - } - - return c.JSON(http.StatusOK, map[string]interface{}{ - "initiated": false, - }) - }) - - // Auto-execute endpoint for non-party-0 participants - - e.GET("/api/dkg/auto-execute", func(c echo.Context) error { - // Only non-party-0 participants should use this endpoint - if runConfig.MyIndex == 0 { - errorHtml, _ := renderError("Party 0 should use the regular execute endpoint") - return c.HTML(http.StatusForbidden, errorHtml) - } - - // Get threshold from query parameter since we already verified DKG initiation via polling - thresholdStr := c.QueryParam("threshold") - if thresholdStr == "" { - errorHtml, _ := renderError("Threshold parameter is required") - return c.HTML(http.StatusBadRequest, errorHtml) - } - - threshold, err := strconv.Atoi(thresholdStr) - if err != nil { - errorHtml, _ := renderError("Invalid threshold value") - return c.HTML(http.StatusBadRequest, errorHtml) - } - - duration, pemKey, xShare, err := dkg(dkgPartyIndex, dkgPartyName, dkgAllPNames, dkgTransport, threshold, curve) - if err != nil { - return fmt.Errorf("dkg failed: %v", err) - } - - resultData := DKGResultData{ - IsParty0: runConfig.MyIndex == 0, - ConnectionTime: duration.Round(time.Millisecond).String(), - XShare: base64.StdEncoding.EncodeToString(xShare), - PemKey: string(pemKey), - } - - html, err := renderTemplate("dkg_result.html", resultData) - if err != nil { - log.Printf("Error rendering DKG result template: %v", err) - errorHtml, _ := renderError("Template rendering failed") - return c.HTML(http.StatusInternalServerError, errorHtml) - } - - return c.HTML(http.StatusOK, html) - }) - - // The initial signing page - - e.GET("/page/sign", func(c echo.Context) error { - // Clean up any existing connections when loading the signing page - closeExistingConnections(dkgTransport, signingTransport) - - db.signing.Set(false, 0, nil, "", time.Time{}) - - dkgPartyIndex = 0 - dkgPartyName = "" - dkgAllPNames = nil - - signingPartyIndex = 0 - signingPartyName = "" - signingAllPNames = nil - signingParticipantPNames = nil - - // Party 0 gets immediate configuration interface, others get waiting interface - if runConfig.MyIndex == 0 { - threshold, err := loadThreshold() - if err != nil { - return fmt.Errorf("loading threshold: %v", err) - } - - signingPageData := SigningPageData{ - Title: "Threshold Signing", - Parties: runConfig.Config.Parties, - CurrentParty: runConfig.MyIndex, - Threshold: threshold, - MaxOtherParties: threshold - 1, - } - - html, err := renderTemplate("signing_leader_interface.html", signingPageData) - if err != nil { - log.Printf("Error rendering signing leader template: %v", err) - return c.String(http.StatusInternalServerError, "Template error") - } - return c.HTML(http.StatusOK, html) - } else { - // Other parties get immediate waiting interface - waitingData := SigningWaitingData{ - ConnectionTime: "Ready", - Party0BaseUrl: "http://127.0.0.1:7080", - AllParties: runConfig.Config.Parties, - CurrentParty: runConfig.MyIndex, - } - - html, err := renderTemplate("signing_immediate_waiting.html", waitingData) - if err != nil { - log.Printf("Error rendering signing waiting template: %v", err) - return c.String(http.StatusInternalServerError, "Template error") - } - return c.HTML(http.StatusOK, html) - } - }) - - // HTMX endpoint to process execute request from the signing page - - e.GET("/api/sign/execute", func(c echo.Context) error { - // Only party 0 should be able to call this endpoint directly - if runConfig.MyIndex != 0 { - errorHtml, _ := renderError("Only party 0 can initiate signing") - return c.HTML(http.StatusForbidden, errorHtml) - } - - selectedParties := c.QueryParams()["parties"] - thresholdStr := c.QueryParam("threshold") - message := c.QueryParam("message") - - if thresholdStr == "" { - errorHtml, _ := renderError("Threshold parameter is required") - return c.HTML(http.StatusBadRequest, errorHtml) - } - - if message == "" { - errorHtml, _ := renderError("Message parameter is required") - return c.HTML(http.StatusBadRequest, errorHtml) - } - - threshold, err := strconv.Atoi(thresholdStr) - if err != nil { - errorHtml, _ := renderError("Invalid threshold value") - return c.HTML(http.StatusBadRequest, errorHtml) - } - - // Always include party 0 in the selected parties - selectedParties = append(selectedParties, "0") - if len(selectedParties) != threshold { - errorHtml, _ := renderError(fmt.Sprintf("Must select exactly %d parties, got %d", threshold, len(selectedParties))) - return c.HTML(http.StatusBadRequest, errorHtml) - } - - // Convert selected parties to integers and create participants map - selectedPartyInts := make([]int, 0) - signingParticipantsIndices := make(map[int]bool) - - for _, partyStr := range selectedParties { - partyIdx, err := strconv.Atoi(partyStr) - if err != nil { - errorHtml, _ := renderError(fmt.Sprintf("Invalid party index: %s", partyStr)) - return c.HTML(http.StatusBadRequest, errorHtml) - } - selectedPartyInts = append(selectedPartyInts, partyIdx) - signingParticipantsIndices[partyIdx] = true - } - sort.Ints(selectedPartyInts) - - db.signing.Set(true, threshold, selectedPartyInts, message, time.Now()) - db.signing.selectedParties = selectedPartyInts - - signingPartyIndex, signingPartyName, signingAllPNames, signingParticipantPNames, signingTransport, err = setupTransport(runConfig.Config, runConfig.MyIndex, signingParticipantsIndices) - if err != nil { - return fmt.Errorf("setting up transport for signing: %v", err) - } - - // Execute signing for party 0 - keyShare, err := loadKeyShare(signingPartyName) - if err != nil { - return fmt.Errorf("loading key share: %v", err) - } - - ac := createThresholdAccessStructure(signingAllPNames, threshold, curve) - - startTime := time.Now() - signature, publicKey, err := runThresholdSign(keyShare, &ac, signingPartyIndex, threshold, signingParticipantPNames, []byte(message), signingTransport) - if err != nil { - log.Printf("Threshold signing failed: %v", err) - errorHtml, _ := renderError(fmt.Sprintf("Signing failed: %v", err)) - return c.HTML(http.StatusInternalServerError, errorHtml) - } - duration := time.Since(startTime) - - db.signing.SetCompletionFlag(runConfig.MyIndex, true) - - resultData := SigningResultData{ - ConnectionTime: duration.Round(time.Millisecond).String(), - Message: message, - SignatureBase64: base64.StdEncoding.EncodeToString(signature), - PublicKey: string(publicKey), - PartyIndex: signingPartyIndex, - } - - html, err := renderTemplate("signing_result.html", resultData) - if err != nil { - log.Printf("Error rendering signing result template: %v", err) - errorHtml, _ := renderError("Template rendering failed") - return c.HTML(http.StatusInternalServerError, errorHtml) - } - - return c.HTML(http.StatusOK, html) - }) - - // Polling endpoint for non-party-0 participants to check if signing has been initiated - e.GET("/api/sign/poll", func(c echo.Context) error { - db.signing.mutex.RLock() - initiated := db.signing.initiated - threshold := db.signing.threshold - selectedParties := db.signing.selectedParties - message := db.signing.message - db.signing.mutex.RUnlock() - - counterParty := c.QueryParam("party") - counterPartyInt, err := strconv.Atoi(counterParty) - if err != nil { - errorHtml, _ := renderError("Invalid counter party value") - return c.HTML(http.StatusBadRequest, errorHtml) - } - - if initiated { - shouldParticipate := slices.Contains(selectedParties, counterPartyInt) - return c.JSON(http.StatusOK, map[string]interface{}{ - "initiated": true, - "threshold": threshold, - "selectedParties": selectedParties, - "message": message, - "shouldParticipate": shouldParticipate, - "party": counterPartyInt, - }) - } - - return c.JSON(http.StatusOK, map[string]interface{}{ - "initiated": false, - }) - }) - - // Auto-execute endpoint for non-party-0 participants - - e.GET("/api/sign/auto-execute", func(c echo.Context) error { - // Only non-party-0 participants should use this endpoint - if runConfig.MyIndex == 0 { - errorHtml, _ := renderError("Party 0 should use the regular execute endpoint") - return c.HTML(http.StatusForbidden, errorHtml) - } - - // Get parameters from query since we already verified signing initiation via polling - thresholdStr := c.QueryParam("threshold") - selectedPartiesStr := c.QueryParam("selectedParties") - message := c.QueryParam("message") - - if thresholdStr == "" || selectedPartiesStr == "" || message == "" { - errorHtml, _ := renderError("Missing required parameters") - return c.HTML(http.StatusBadRequest, errorHtml) - } - - threshold, err := strconv.Atoi(thresholdStr) - if err != nil { - errorHtml, _ := renderError("Invalid threshold value") - return c.HTML(http.StatusBadRequest, errorHtml) - } - - // Parse selected parties - var selectedParties []int - partyStrs := strings.Split(selectedPartiesStr, ",") - signingParticipantsIndices := make(map[int]bool) - - for _, partyStr := range partyStrs { - if strings.TrimSpace(partyStr) == "" { - continue - } - partyIdx, err := strconv.Atoi(strings.TrimSpace(partyStr)) - if err != nil { - errorHtml, _ := renderError(fmt.Sprintf("Invalid party index: %s", partyStr)) - return c.HTML(http.StatusBadRequest, errorHtml) - } - selectedParties = append(selectedParties, partyIdx) - signingParticipantsIndices[partyIdx] = true - } - - counterParty := c.QueryParam("party") - counterPartyInt, err := strconv.Atoi(counterParty) - if err != nil { - errorHtml, _ := renderError("Invalid counter party value") - return c.HTML(http.StatusBadRequest, errorHtml) - } - - if !slices.Contains(selectedParties, counterPartyInt) { - errorHtml, _ := renderError(fmt.Sprintf("Party %d is not in the selected parties list", counterPartyInt)) - return c.HTML(http.StatusBadRequest, errorHtml) - } - - signingPartyIndex, signingPartyName, signingAllPNames, signingParticipantPNames, signingTransport, err = setupTransport(runConfig.Config, runConfig.MyIndex, signingParticipantsIndices) - if err != nil { - return fmt.Errorf("setting up transport for auto-signing: %v", err) - } - - keyShare, err := loadKeyShare(signingPartyName) - if err != nil { - return fmt.Errorf("loading key share: %v", err) - } - - ac := createThresholdAccessStructure(signingAllPNames, threshold, curve) - - startTime := time.Now() - signature, publicKey, err := runThresholdSign(keyShare, &ac, signingPartyIndex, threshold, signingParticipantPNames, []byte(message), signingTransport) - if err != nil { - log.Printf("Auto threshold signing failed: %v", err) - errorHtml, _ := renderError(fmt.Sprintf("Auto-signing failed: %v", err)) - return c.HTML(http.StatusInternalServerError, errorHtml) - } - duration := time.Since(startTime) - - db.signing.SetCompletionFlag(runConfig.MyIndex, true) - - resultData := SigningResultData{ - ConnectionTime: duration.Round(time.Millisecond).String(), - Message: message, - SignatureBase64: base64.StdEncoding.EncodeToString(signature), - PublicKey: string(publicKey), - PartyIndex: signingPartyIndex, - } - - html, err := renderTemplate("signing_result.html", resultData) - if err != nil { - log.Printf("Error rendering signing result template: %v", err) - errorHtml, _ := renderError("Template rendering failed") - return c.HTML(http.StatusInternalServerError, errorHtml) - } - - return c.HTML(http.StatusOK, html) - }) - - e.Logger.Fatal(e.Start(runConfig.Config.WebAddress)) - return nil -} diff --git a/demos-go/examples/access-structure/go.mod b/demos-go/examples/access-structure/go.mod deleted file mode 100644 index 4fc611a2..00000000 --- a/demos-go/examples/access-structure/go.mod +++ /dev/null @@ -1,14 +0,0 @@ -module github.com/coinbase/cb-mpc/demo-go-access-structure - -go 1.23.0 - -toolchain go1.24.2 - -require github.com/coinbase/cb-mpc/demos-go/cb-mpc-go v0.0.0-20240501131245-1eee31b51009 - -require ( - github.com/stretchr/testify v1.10.0 // indirect - golang.org/x/sync v0.15.0 // indirect -) - -replace github.com/coinbase/cb-mpc/demos-go/cb-mpc-go => ../../cb-mpc-go diff --git a/demos-go/examples/access-structure/go.sum b/demos-go/examples/access-structure/go.sum deleted file mode 100644 index c43696e9..00000000 --- a/demos-go/examples/access-structure/go.sum +++ /dev/null @@ -1,10 +0,0 @@ -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= -golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/demos-go/examples/access-structure/main.go b/demos-go/examples/access-structure/main.go deleted file mode 100644 index 25ea6e8c..00000000 --- a/demos-go/examples/access-structure/main.go +++ /dev/null @@ -1,36 +0,0 @@ -package main - -import ( - "fmt" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/mpc" -) - -func main() { - // Build the same sample tree described in the AccessNode documentation. - root := mpc.And("", - mpc.Or("role", - mpc.Leaf("role:Admin"), - mpc.Leaf("dept:HR"), - ), - mpc.Threshold("sig", 2, - mpc.Leaf("sig:A"), - mpc.Leaf("sig:B"), - mpc.Leaf("sig:C"), - ), - ) - - // Use secp256k1 curve in this example. - c, err := curve.NewSecp256k1() - if err != nil { - panic(err) - } - - as := &mpc.AccessStructure{ - Root: root, - Curve: c, - } - - fmt.Print(as) -} diff --git a/demos-go/examples/agreerandom/go.mod b/demos-go/examples/agreerandom/go.mod deleted file mode 100644 index 9ab201e3..00000000 --- a/demos-go/examples/agreerandom/go.mod +++ /dev/null @@ -1,14 +0,0 @@ -module github.com/coinbase/cb-mpc/demo-go-agreerandom - -go 1.23.0 - -toolchain go1.24.2 - -require github.com/coinbase/cb-mpc/demos-go/cb-mpc-go v0.0.0-20240501131245-1eee31b51009 - -require ( - github.com/stretchr/testify v1.10.0 // indirect - golang.org/x/sync v0.15.0 // indirect -) - -replace github.com/coinbase/cb-mpc/demos-go/cb-mpc-go => ../../cb-mpc-go diff --git a/demos-go/examples/agreerandom/go.sum b/demos-go/examples/agreerandom/go.sum deleted file mode 100644 index c43696e9..00000000 --- a/demos-go/examples/agreerandom/go.sum +++ /dev/null @@ -1,10 +0,0 @@ -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= -golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/demos-go/examples/agreerandom/main.go b/demos-go/examples/agreerandom/main.go deleted file mode 100644 index bd7830ca..00000000 --- a/demos-go/examples/agreerandom/main.go +++ /dev/null @@ -1,92 +0,0 @@ -package main - -import ( - "encoding/hex" - "fmt" - "log" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/mpc" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mocknet" -) - -// runAgreeRandomDemo runs the AgreeRandom protocol for the given bit length and -// prints the agreed-upon value for both parties. -func runAgreeRandomDemo(bitLen int) error { - const nParties = 2 - - // Create an in-memory mock network. - messengers := mocknet.NewMockNetwork(nParties) - - partyNames := []string{"party_0", "party_1"} - - responses := make([]*mpc.AgreeRandomResponse, nParties) - errChan := make(chan error, nParties) - respChan := make(chan struct { - idx int - resp *mpc.AgreeRandomResponse - }, nParties) - - for i := 0; i < nParties; i++ { - go func(partyIdx int) { - // Construct Job2P for this party. - jp, err := mpc.NewJob2P(messengers[partyIdx], partyIdx, partyNames) - if err != nil { - errChan <- err - return - } - defer jp.Free() - - req := &mpc.AgreeRandomRequest{BitLen: bitLen} - resp, err := mpc.AgreeRandom(jp, req) - if err != nil { - errChan <- err - return - } - - respChan <- struct { - idx int - resp *mpc.AgreeRandomResponse - }{partyIdx, resp} - }(i) - } - - // Collect results. - for i := 0; i < nParties; i++ { - select { - case err := <-errChan: - return err - case r := <-respChan: - responses[r.idx] = r.resp - } - } - - // Verify both parties agreed. - agreedHex := hex.EncodeToString(responses[0].RandomValue) - fmt.Printf("Party 0: agreed on randomness %s\n", agreedHex) - fmt.Printf("Party 1: agreed on randomness %s\n", hex.EncodeToString(responses[1].RandomValue)) - if agreedHex == hex.EncodeToString(responses[1].RandomValue) { - fmt.Println("✅ Both parties agreed on the same random value!") - } else { - fmt.Println("❌ Parties got different random values!") - } - - return nil -} - -func main() { - fmt.Println("\n=== CB-MPC Agree Random Example ===") - - fmt.Println("## Running 2-party AgreeRandom (128 bits)") - if err := runAgreeRandomDemo(128); err != nil { - log.Fatalf("AgreeRandom 128-bit failed: %v", err) - } - - fmt.Println() - - fmt.Println("## Running 2-party AgreeRandom (10 bits)") - if err := runAgreeRandomDemo(10); err != nil { - log.Fatalf("AgreeRandom 10-bit failed: %v", err) - } - - fmt.Println("\nAgreeRandom example completed successfully!") -} diff --git a/demos-go/examples/ecdsa-2pc/go.mod b/demos-go/examples/ecdsa-2pc/go.mod deleted file mode 100644 index 8f1d21a1..00000000 --- a/demos-go/examples/ecdsa-2pc/go.mod +++ /dev/null @@ -1,17 +0,0 @@ -module github.com/coinbase/cb-mpc/demo-go-ecdsa-2pc - -go 1.23.0 - -toolchain go1.24.2 - -require github.com/coinbase/cb-mpc/demos-go/cb-mpc-go v0.0.0-20240501131245-1eee31b51009 - -require github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect - -require ( - github.com/btcsuite/btcd/btcec/v2 v2.3.5 - github.com/stretchr/testify v1.10.0 // indirect - golang.org/x/sync v0.15.0 // indirect -) - -replace github.com/coinbase/cb-mpc/demos-go/cb-mpc-go => ../../cb-mpc-go diff --git a/demos-go/examples/ecdsa-2pc/go.sum b/demos-go/examples/ecdsa-2pc/go.sum deleted file mode 100644 index d7849e65..00000000 --- a/demos-go/examples/ecdsa-2pc/go.sum +++ /dev/null @@ -1,18 +0,0 @@ -github.com/btcsuite/btcd/btcec/v2 v2.3.5 h1:dpAlnAwmT1yIBm3exhT1/8iUSD98RDJM5vqJVQDQLiU= -github.com/btcsuite/btcd/btcec/v2 v2.3.5/go.mod h1:m22FrOAiuxl/tht9wIqAoGHcbnCCaPWyauO8y2LGGtQ= -github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1 h1:q0rUy8C/TYNBQS1+CGKw68tLOFYSNEs0TFnxxnS9+4U= -github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/decred/dcrd/crypto/blake256 v1.0.0 h1:/8DMNYp9SGi5f0w7uCm6d6M4OU2rGFK09Y2A4Xv7EE0= -github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1m5sE92cU+pd5Mcc= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= -golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/demos-go/examples/ecdsa-2pc/main.go b/demos-go/examples/ecdsa-2pc/main.go deleted file mode 100644 index 5d7f3c98..00000000 --- a/demos-go/examples/ecdsa-2pc/main.go +++ /dev/null @@ -1,295 +0,0 @@ -package main - -import ( - "crypto/ecdsa" - "crypto/sha256" - "encoding/asn1" - "encoding/hex" - "fmt" - "log" - "math/big" - - "sync" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/mpc" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mocknet" - "github.com/btcsuite/btcd/btcec/v2" - btcecEcdsa "github.com/btcsuite/btcd/btcec/v2/ecdsa" -) - -func main() { - fmt.Println("\n=== CB-MPC ECDSA 2PC Example ===") - - // Example: Complete ECDSA 2PC workflow (key generation + signing) - // Initialize the secp256k1 curve implementation. Remember to release the - // underlying native resources when it is no longer needed. - curveObj, err := curve.NewSecp256k1() - if err != nil { - log.Fatalf("failed to initialize curve: %v", err) - } - defer curveObj.Free() - - fmt.Println("## Running ECDSA 2PC key generation only") - - keyGenResponses, err := keyGenWithMockNet(curveObj) - if err != nil { - log.Fatalf("Key generation failed: %v", err) - } - - fmt.Printf("✅ Key generation completed\n") - fmt.Printf("Generated %d key shares for future use\n", len(keyGenResponses)) - - printKeyShares("Initial key shares", keyGenResponses) - - // === Signing round 1 === - message1 := []byte("Hello, CB-MPC!") - digest1 := sha256.Sum256(message1) - fmt.Println("\n## Running first collaborative signing round") - firstSigResponses, err := signWithMockNet([]byte("session-1"), digest1[:], keyGenResponses) - if err != nil { - log.Fatalf("Signing round 1 failed: %v", err) - } - printSignatures(message1, firstSigResponses) - if err := verifyExampleSignature(keyGenResponses[0].KeyShare, firstSigResponses[0].Signature, digest1[:]); err != nil { - fmt.Printf("❌ Signature verification FAILED: %v\n", err) - } else { - fmt.Println("✅ Signature verification PASSED") - } - - // === Refresh === - fmt.Println("\n## Running key refresh (re-share)") - refreshResponses, err := refreshWithMockNet(keyGenResponses) - if err != nil { - log.Fatalf("Refresh failed: %v", err) - } - fmt.Println("✅ Refresh completed – parties now hold new key shares") - printKeyShares("Refreshed key shares", refreshResponses) - - // === Signing round 2 === - message2 := []byte("Fresh signing after refresh!") - digest2 := sha256.Sum256(message2) - fmt.Println("\n## Running second collaborative signing round") - secondSigResponses, err := signWithMockNet([]byte("session-2"), digest2[:], refreshResponses) - if err != nil { - log.Fatalf("Signing round 2 failed: %v", err) - } - printSignatures(message2, secondSigResponses) - if err := verifyExampleSignature(refreshResponses[0].KeyShare, secondSigResponses[0].Signature, digest2[:]); err != nil { - fmt.Printf("❌ Signature verification FAILED: %v\n", err) - } else { - fmt.Println("✅ Signature verification PASSED") - } - - fmt.Println("🎉 ECDSA 2PC example completed successfully!") -} - -// keyGenWithMockNet is a small helper used solely by this example to run the -// distributed key-generation protocol using the in-memory mock network. It -// relies only on the public cb-mpc-go API and therefore avoids importing any -// internal packages. -func keyGenWithMockNet(curveObj curve.Curve) ([]*mpc.ECDSA2PCKeyGenResponse, error) { - if curveObj == nil { - return nil, fmt.Errorf("curve must be provided") - } - - const nParties = 2 - messengers := mocknet.NewMockNetwork(nParties) - partyNames := []string{"party_0", "party_1"} - - responses := make([]*mpc.ECDSA2PCKeyGenResponse, nParties) - var wg sync.WaitGroup - var firstErr error - - for i := 0; i < nParties; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - - // Create a Job2P for this party - jp, err := mpc.NewJob2P(messengers[idx], idx, partyNames) - if err != nil { - firstErr = err - return - } - defer jp.Free() - - resp, err := mpc.ECDSA2PCKeyGen(jp, &mpc.ECDSA2PCKeyGenRequest{Curve: curveObj}) - if err != nil { - firstErr = err - return - } - responses[idx] = resp - }(i) - } - - wg.Wait() - - if firstErr != nil { - return nil, firstErr - } - - return responses, nil -} - -// printKeyShares prints the role index, x-share, and public key Q for each party. -func printKeyShares(title string, keyGenResponses []*mpc.ECDSA2PCKeyGenResponse) { - fmt.Printf("\n### %s\n", title) - for _, resp := range keyGenResponses { - fmt.Printf("KeyShare: %+v\n", resp.KeyShare) - share := resp.KeyShare - roleIdx, _ := share.RoleIndex() - fmt.Printf("RoleIndex: %d\n", roleIdx) - - x, _ := share.XShare() - fmt.Printf("Party %d: x_i = %s\n", roleIdx, x) - Q, _ := share.Q() - fmt.Printf("Party %d: Q = %s\n", roleIdx, Q) - Q.Free() - } -} - -// printSignatures displays the signatures obtained by the parties. -func printSignatures(message []byte, signResponses []*mpc.ECDSA2PCSignResponse) { - fmt.Printf("\nSignatures for message: %q\n", message) - for i, resp := range signResponses { - if len(resp.Signature) == 0 { - fmt.Printf("Party %d: \n", i) - continue - } - fmt.Printf("Party %d signature: %s\n", i, hex.EncodeToString(resp.Signature)) - } -} - -// signWithMockNet runs the collaborative signing protocol using an in-memory -// network and returns each party's response. -func signWithMockNet(sessionID, message []byte, keyGenResponses []*mpc.ECDSA2PCKeyGenResponse) ([]*mpc.ECDSA2PCSignResponse, error) { - if len(keyGenResponses) != 2 { - return nil, fmt.Errorf("need exactly 2 key shares, got %d", len(keyGenResponses)) - } - const nParties = 2 - messengers := mocknet.NewMockNetwork(nParties) - partyNames := []string{"party_0", "party_1"} - - responses := make([]*mpc.ECDSA2PCSignResponse, nParties) - var wg sync.WaitGroup - var firstErr error - - for i := 0; i < nParties; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - - jp, err := mpc.NewJob2P(messengers[idx], idx, partyNames) - if err != nil { - firstErr = err - return - } - defer jp.Free() - - msgCopy := append([]byte(nil), message...) - resp, err := mpc.ECDSA2PCSign(jp, &mpc.ECDSA2PCSignRequest{ - SessionID: sessionID, - KeyShare: keyGenResponses[idx].KeyShare, - Message: msgCopy, - }) - if err != nil { - firstErr = err - return - } - responses[idx] = resp - }(i) - } - - wg.Wait() - - if firstErr != nil { - return nil, firstErr - } - return responses, nil -} - -// refreshWithMockNet performs the re-share protocol for the provided key -// shares and returns the refreshed shares. -func refreshWithMockNet(oldKeyGenResponses []*mpc.ECDSA2PCKeyGenResponse) ([]*mpc.ECDSA2PCKeyGenResponse, error) { - if len(oldKeyGenResponses) != 2 { - return nil, fmt.Errorf("need exactly 2 key shares, got %d", len(oldKeyGenResponses)) - } - const nParties = 2 - messengers := mocknet.NewMockNetwork(nParties) - partyNames := []string{"party_0", "party_1"} - - newResponses := make([]*mpc.ECDSA2PCKeyGenResponse, nParties) - var wg sync.WaitGroup - var firstErr error - - for i := 0; i < nParties; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - - jp, err := mpc.NewJob2P(messengers[idx], idx, partyNames) - if err != nil { - firstErr = err - return - } - defer jp.Free() - - resp, err := mpc.ECDSA2PCRefresh(jp, &mpc.ECDSA2PCRefreshRequest{ - KeyShare: oldKeyGenResponses[idx].KeyShare, - }) - if err != nil { - firstErr = err - return - } - newResponses[idx] = &mpc.ECDSA2PCKeyGenResponse{KeyShare: resp.NewKeyShare} - }(i) - } - - wg.Wait() - - if firstErr != nil { - return nil, firstErr - } - return newResponses, nil -} - -// verifyExampleSignature verifies a DER-encoded ECDSA signature using the public key Q from the key share -func verifyExampleSignature(key mpc.ECDSA2PCKey, derSig []byte, digest []byte) error { - Q, err := key.Q() - if err != nil { - return fmt.Errorf("failed to get public key Q: %v", err) - } - defer Q.Free() - - // Prefer native verification (matches signing backend) - if c, err := key.Curve(); err == nil { - resp := &mpc.ECDSA2PCSignResponse{Signature: derSig} - if err := resp.Verify(Q, digest, c); err == nil { - return nil - } - } - - x := new(big.Int).SetBytes(Q.GetX()) - y := new(big.Int).SetBytes(Q.GetY()) - pubKey := &ecdsa.PublicKey{Curve: btcec.S256(), X: x, Y: y} - - type ecdsaSignature struct{ R, S *big.Int } - var s ecdsaSignature - if _, err := asn1.Unmarshal(derSig, &s); err != nil { - return fmt.Errorf("failed to parse DER signature: %v", err) - } - // Verify with stdlib - if ecdsa.Verify(pubKey, digest, s.R, s.S) { - return nil - } - // Fallback to btcec ecdsa verification using parsed DER and SEC1-encoded pubkey - if pk, err := btcec.ParsePubKey(Q.Bytes()); err == nil { - if sig, err := btcecEcdsa.ParseDERSignature(derSig); err == nil { - if sig.Verify(digest, pk) { - return nil - } - } - } - return fmt.Errorf("invalid signature") -} diff --git a/demos-go/examples/ecdsa-mpc-with-backup/go.mod b/demos-go/examples/ecdsa-mpc-with-backup/go.mod deleted file mode 100644 index fdf3ac55..00000000 --- a/demos-go/examples/ecdsa-mpc-with-backup/go.mod +++ /dev/null @@ -1,18 +0,0 @@ -module github.com/coinbase/cb-mpc/demo-go-ecdsa-mpc-with-backup - -go 1.24.0 - -toolchain go1.24.2 - -replace github.com/coinbase/cb-mpc/demos-go/cb-mpc-go => ../../cb-mpc-go - -require ( - github.com/btcsuite/btcd/btcec/v2 v2.3.5 - github.com/coinbase/cb-mpc/demos-go/cb-mpc-go v0.0.0-20240501131245-1eee31b51009 - golang.org/x/sync v0.15.0 -) - -require ( - github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect - golang.org/x/crypto v0.45.0 -) diff --git a/demos-go/examples/ecdsa-mpc-with-backup/go.sum b/demos-go/examples/ecdsa-mpc-with-backup/go.sum deleted file mode 100644 index 7c0e09d6..00000000 --- a/demos-go/examples/ecdsa-mpc-with-backup/go.sum +++ /dev/null @@ -1,17 +0,0 @@ -github.com/btcsuite/btcd/btcec/v2 v2.3.5 h1:dpAlnAwmT1yIBm3exhT1/8iUSD98RDJM5vqJVQDQLiU= -github.com/btcsuite/btcd/btcec/v2 v2.3.5/go.mod h1:m22FrOAiuxl/tht9wIqAoGHcbnCCaPWyauO8y2LGGtQ= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1m5sE92cU+pd5Mcc= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= -golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= -golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= -golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/demos-go/examples/ecdsa-mpc-with-backup/main.go b/demos-go/examples/ecdsa-mpc-with-backup/main.go deleted file mode 100644 index 1908a022..00000000 --- a/demos-go/examples/ecdsa-mpc-with-backup/main.go +++ /dev/null @@ -1,400 +0,0 @@ -package main - -import ( - "bytes" - "crypto/ecdsa" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "crypto/x509" - "encoding/asn1" - "encoding/hex" - "fmt" - "io" - "log" - "math/big" - "runtime" - "unsafe" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/mpc" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mocknet" - "github.com/btcsuite/btcd/btcec/v2" - "golang.org/x/crypto/hkdf" - "golang.org/x/sync/errgroup" -) - -// Deterministic reader for RSA OAEP derived from rho (seed) -type ctrRand struct { - seed [32]byte - counter uint64 - buf []byte - off int -} - -func newCTRRand(seed []byte) *ctrRand { - var s [32]byte - copy(s[:], seed) - return &ctrRand{seed: s} -} - -func (r *ctrRand) refill() { - ctrBytes := make([]byte, 8) - for i := 0; i < 8; i++ { - ctrBytes[7-i] = byte(r.counter >> (8 * i)) - } - h := sha256.Sum256(append(ctrBytes, r.seed[:]...)) - r.buf = h[:] - r.off = 0 - r.counter++ -} - -func (r *ctrRand) Read(p []byte) (int, error) { - n := 0 - for n < len(p) { - if r.off >= len(r.buf) { - r.refill() - } - m := copy(p[n:], r.buf[r.off:]) - r.off += m - n += m - } - return n, nil -} - -// rsaDemoKEM implements the mpc KEM interface with RSA OAEP and deterministic encapsulation -// so that PVE verification is reproducible. -type rsaDemoKEM struct{} - -const kemLabel = "demo-rsa-kem" // wire label; do not change without versioning -const kemDS = "rsa-demo-kem:v1" // domain-sep string for KDF derivations - -func (rsaDemoKEM) Generate() ([]byte, []byte, error) { - k, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, nil, err - } - k.Precompute() - - prv := x509.MarshalPKCS1PrivateKey(k) - pub := x509.MarshalPKCS1PublicKey(&k.PublicKey) - return prv, pub, nil -} - -func (rsaDemoKEM) Encapsulate(ek []byte, rho [32]byte) ([]byte, []byte, error) { - pub, err := x509.ParsePKCS1PublicKey(ek) - if err != nil { - return nil, nil, err - } - if pub.Size() != 256 { - return nil, nil, fmt.Errorf("invalid RSA modulus size: got %d bits", pub.Size()*8) - } - if pub.E != 65537 { - return nil, nil, fmt.Errorf("unsupported RSA public exponent") - } - - // --- Derive independent materials from rho, bound to key N and DS --- - salt := pub.N.Bytes() // binds derivations to recipient key - - // Derive the OAEP seed (exactly Hash.Size() bytes for SHA-256) - var oaepSeed [32]byte - if _, err := io.ReadFull(hkdf.New(sha256.New, rho[:], salt, []byte(kemDS+"|oaep-seed")), oaepSeed[:]); err != nil { - return nil, nil, fmt.Errorf("hkdf: %w", err) - } - - // Derive the shared secret independently from the OAEP seed - ss := make([]byte, 32) - if _, err := io.ReadFull(hkdf.New(sha256.New, rho[:], salt, []byte(kemDS+"|ss")), ss); err != nil { - mpc.SecureWipe(oaepSeed[:]) - return nil, nil, fmt.Errorf("hkdf: %w", err) - } - - // Deterministic OAEP randomness: rsa.EncryptOAEP reads exactly Hash.Size() bytes. - r := bytes.NewReader(oaepSeed[:]) - ct, err := rsa.EncryptOAEP(sha256.New(), r, pub, ss, []byte(kemLabel)) - mpc.SecureWipe(oaepSeed[:]) // best-effort wipe - if err != nil { - mpc.SecureWipe(ss) - return nil, nil, err - } - - // Deterministic: (pub, rho) -> (ct, ss) - return ct, ss, nil -} - -func (rsaDemoKEM) Decapsulate(skHandle unsafe.Pointer, ct []byte) ([]byte, error) { - // Expect cmem_t pointing to private key bytes - type cmem_t struct { - data *byte - size int32 - } - cm := (*cmem_t)(skHandle) - if cm == nil || cm.data == nil || cm.size <= 0 { - return nil, fmt.Errorf("kem: decapsulation failed") // uniform error - } - // Sanity cap to avoid DoS; PKCS#1 DER for 2048-bit is ~1–2KB - if cm.size < 256 || cm.size > 8192 { - return nil, fmt.Errorf("kem: decapsulation failed") - } - - // Copy foreign memory into Go memory before parsing (safer if caller frees it) - dk := unsafe.Slice((*byte)(unsafe.Pointer(cm.data)), int(cm.size)) - dkCopy := make([]byte, len(dk)) - copy(dkCopy, dk) - // Ensure cm isn’t GC’d early - runtime.KeepAlive(cm) - - prv, err := x509.ParsePKCS1PrivateKey(dkCopy) - mpc.SecureWipe(dkCopy) // best-effort wipe of private key bytes copy - if err != nil { - return nil, fmt.Errorf("kem: decapsulation failed") - } - // Pin modulus size to 2048 - if prv.Size() != 256 { - return nil, fmt.Errorf("kem: decapsulation failed") - } - // Basic sanity on ciphertext length - if len(ct) != prv.Size() { - return nil, fmt.Errorf("kem: decapsulation failed") - } - - // Decrypt with blinding; output still deterministic wrt inputs - ss, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, prv, ct, []byte(kemLabel)) - if err != nil { - return nil, fmt.Errorf("kem: decapsulation failed") - } - if len(ss) != 32 { - // Should not occur if encaps follows the contract, but guard anyway - return nil, fmt.Errorf("kem: decapsulation failed") - } - - out := make([]byte, 32) - copy(out, ss) - mpc.SecureWipe(ss) // wipe temp - return out, nil -} - -func (rsaDemoKEM) DerivePub(dk []byte) ([]byte, error) { - prv, err := x509.ParsePKCS1PrivateKey(dk) - if err != nil { - return nil, err - } - if prv.Size() != 256 { - return nil, fmt.Errorf("invalid RSA modulus size: got %d bits", prv.Size()*8) - } - return x509.MarshalPKCS1PublicKey(&prv.PublicKey), nil -} - -func main() { - fmt.Println("=== ECDSA MPC with Backup Example ===") - fmt.Println("This example demonstrates:") - fmt.Println("1. N-party ECDSA key generation and signing") - fmt.Println("2. Secure backup and recovery of key shares using PVE") - fmt.Println() - - // Configuration - // The batch size determines the number of signing keys to generate. This is so that a batch of keys is created - // for each party and the batch backup using PVE can be properly demoed. - batchSize := 2 - nParties := 4 - messengers := mocknet.NewMockNetwork(nParties) - partyNames := make([]string, nParties) - for i := 0; i < nParties; i++ { - // In production settings, the party name should be tied to the party's identity. For example, hash of the public key. - partyNames[i] = fmt.Sprintf("p%d", i) - } - secp, err := curve.NewSecp256k1() - if err != nil { - log.Fatal(fmt.Errorf("failed to create secp256k1 curve: %v", err)) - } - parties := make([]*Party, nParties) - for i := 0; i < nParties; i++ { - parties[i] = &Party{ - Index: i, - Messenger: messengers[i], - NParties: nParties, - PartyNames: partyNames, - BatchSize: batchSize, - dkgResp: make([]*mpc.ECDSAMPCKeyGenResponse, batchSize), - signResp: make([]*mpc.ECDSAMPCSignResponse, batchSize), - } - } - - signatureReceiverId := 0 - - // Step 1: Run N-party ECDSA key generation and signing - fmt.Println("## Step 1: N-Party ECDSA Key Generation and Signing") - eg := errgroup.Group{} - for i := 0; i < nParties; i++ { - i := i - eg.Go(func() error { - for k := 0; k < batchSize; k++ { - if err := parties[i].Dkg(secp, k); err != nil { - return err - } - } - return nil - }) - } - - if err := eg.Wait(); err != nil { - log.Fatal(fmt.Errorf("ECDSA keygen failed: %v", err)) - } - - fmt.Printf("Generated %d-party ECDSA key shares\n", nParties) - // All parties will have received the same Q, using one of the as representative - for k := 0; k < batchSize; k++ { - Q := parties[signatureReceiverId].Q(k) - fmt.Printf("* Public Key[%d]: %v\n", k, Q) - } - - // Step 2: Backup the key shares using many to many PVE - fmt.Println("## Step 2: Backing Up Key Shares with PVE") - - // Step 2.1: create the access structure for backing up the keys - root := mpc.And("") - root.Children = []*mpc.AccessNode{mpc.Leaf(partyNames[0]), mpc.Threshold("th", 2)} - root.Children[1].Children = []*mpc.AccessNode{mpc.Leaf(partyNames[1]), mpc.Leaf(partyNames[2]), mpc.Leaf(partyNames[3])} - ac := mpc.AccessStructure{ - Root: root, - Curve: secp, - } - - // Step 2.2: define a KEM and PVE instance (RSA KEM here) - for i := 0; i < nParties; i++ { - parties[i].InitPVE() - } - - // Step 2.3: create encryption keys for leaves via KEM - // NOTE: in this demo, the parties are ALSO acting as backup holders but this does not have to be the case. - // If the backup holders are different, then the RSA keys should be generated by a different group of parties - // and their public keys should be communicated with the keyshare holders. - pubKeys := make(map[string]mpc.BaseEncPublicKey) - for i := 0; i < nParties; i++ { - // In a production setting, the public keys should be exchanged using PKI or some other secure mechanism - pubKeys[partyNames[i]] = parties[i].RSAKeygen() - } - - // Step 2.4: choose a human readable label. This will be cryptographically bound to the backup data - inputLabel := "demo-data" - - // Step 2.5: create a publicly verifiable backup - // Each party create a batch backup of all the dkg keys that it has generated - pveEncResps := make([]*mpc.PVEAcEncryptResponse, nParties) - for i := 0; i < nParties; i++ { - pveEncResps[i] = parties[i].Backup(inputLabel, secp, &ac, pubKeys) - } - - // Step 2.7: verify the backup - // All the parties verify all the backups that have been generated by themselves and everyone else - for i := 0; i < nParties; i++ { - err := parties[i].VerifyAllBackups(pveEncResps, &ac, inputLabel, pubKeys) - if err != nil { - log.Fatal(fmt.Errorf("failed to verify: %v", err)) - } - } - fmt.Printf("PVE verification passed\n") - - // Step 2.8: restore via interactive quorum PVE - // Each backup holder party (in this demo the same as dkg parties), uses their RSA decryption keys to - // partially decrypt each of the backups - // The partial decryptions are sent to the appropriate recipient and aggregated. - // IMPORTANT: In production setting it is extremely important that this is done correctly as to - // not send the partial decryptions to the incorrect party. - allPartialDecs := make([][]*mpc.PVEAcPartyDecryptRowResponse, nParties) - for i := 0; i < nParties; i++ { - allPartialDecs[i] = parties[i].PartialBackupDecryption(pveEncResps, inputLabel, &ac) - } - - for i := 0; i < nParties; i++ { - // Prepare all the partial decryptions made for party i by all parties - respsForParty := make(map[string]*mpc.PVEAcPartyDecryptRowResponse) - for j := 0; j < nParties; j++ { - respsForParty[partyNames[j]] = allPartialDecs[j][i] - } - // Party i uses all the partial decryptions to decrypt its own backup - backup := pveEncResps[i] - aggResp := parties[i].AggregatePartialDecryptedBackups(respsForParty, backup, inputLabel, &ac) - - // Assert restored values - for k := 0; k < batchSize; k++ { - xs, err := parties[i].dkgResp[k].KeyShare.XShare() - if err != nil { - log.Fatalf("failed to get share %v", err) - } - if !bytes.Equal(aggResp.PrivateValues[k].Bytes, xs.Bytes) { - log.Fatal("decrypted value does not match the original value") - } - } - } - fmt.Printf("PVE restore passed\n") - - // Step 4: Sign using the recovered keyshares - fmt.Println("## Step 1: N-Party ECDSA Key Generation and Signing") - - inputMessage := []byte("This is a message for ECDSA MPC with backup") - hash := sha256.Sum256(inputMessage) - eg = errgroup.Group{} - for i := 0; i < nParties; i++ { - i := i - eg.Go(func() error { - for k := 0; k < batchSize; k++ { - if err := parties[i].Sign(hash, signatureReceiverId, k); err != nil { - log.Fatalf("signing failed %v", err) - } - } - - return nil - }) - } - - if err := eg.Wait(); err != nil { - log.Fatal(fmt.Errorf("ECDSA sign failed: %v", err)) - } - - for k := 0; k < batchSize; k++ { - sig := parties[signatureReceiverId].signResp[k].Signature - fmt.Printf("* Signature[%d]: %s\n", k, hex.EncodeToString(sig)) - fmt.Println() - - // Verifying the signature - // Extract X and Y coordinates from the MPC public key - Q := parties[signatureReceiverId].Q(k) - xBytes := Q.GetX() - yBytes := Q.GetY() - - x := new(big.Int).SetBytes(xBytes) - y := new(big.Int).SetBytes(yBytes) - - goPubKey := &ecdsa.PublicKey{ - Curve: btcec.S256(), - X: x, - Y: y, - } - - // Parse DER-encoded signature - // DER format: SEQUENCE { r INTEGER, s INTEGER } - type ecdsaSignature struct { - R, S *big.Int - } - - var derSig ecdsaSignature - _, err = asn1.Unmarshal(sig, &derSig) - if err != nil { - log.Fatal(fmt.Errorf("failed to parse DER signature: %v", err)) - } - - r := derSig.R - s := derSig.S - - valid := ecdsa.Verify(goPubKey, hash[:], r, s) - - if valid { - fmt.Println("Signature verification PASSED") - fmt.Println("* The signature is valid and matches the message and public key") - } else { - fmt.Println("Signature verification FAILED") - fmt.Println("* The signature does not match the message and public key") - } - } -} diff --git a/demos-go/examples/ecdsa-mpc-with-backup/party.go b/demos-go/examples/ecdsa-mpc-with-backup/party.go deleted file mode 100644 index ac8172ac..00000000 --- a/demos-go/examples/ecdsa-mpc-with-backup/party.go +++ /dev/null @@ -1,183 +0,0 @@ -package main - -import ( - "fmt" - "log" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/mpc" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/transport/mocknet" -) - -type Party struct { - Index int - Messenger *mocknet.MockMessenger - NParties int - PartyNames []string - BatchSize int - pve *mpc.PVE - dkgResp []*mpc.ECDSAMPCKeyGenResponse - signResp []*mpc.ECDSAMPCSignResponse - rsaEk mpc.BaseEncPublicKey - rsaDk mpc.BaseEncPrivateKey -} - -func (p *Party) Dkg(c curve.Curve, batchId int) error { - jb, err := mpc.NewJobMP(p.Messenger, p.NParties, p.Index, p.PartyNames) - if err != nil { - return err - } - defer jb.Free() - - resp, err := mpc.ECDSAMPCKeyGen(jb, &mpc.ECDSAMPCKeyGenRequest{Curve: c}) - if err != nil { - return err - } - - p.dkgResp[batchId] = resp - return nil -} - -func (p *Party) Sign(hash [32]byte, signatureReceiverId int, batchId int) error { - jb, err := mpc.NewJobMP(p.Messenger, p.NParties, p.Index, p.PartyNames) - if err != nil { - return err - } - defer jb.Free() - - resp, err := mpc.ECDSAMPCSign(jb, &mpc.ECDSAMPCSignRequest{ - KeyShare: p.dkgResp[batchId].KeyShare, - Message: hash[:], - SignatureReceiver: signatureReceiverId, - }) - if err != nil { - return err - } - p.signResp[batchId] = resp - - return nil -} - -func (p *Party) InitPVE() { - var err error - p.pve, err = mpc.NewPVE(mpc.Config{KEM: rsaDemoKEM{}}) - if err != nil { - log.Fatal("failed to init PVE: %v", err) - } -} - -func (p *Party) Q(batchId int) *curve.Point { - // will panic dkgResp is nil, ok for demo purposes - Q, err := p.dkgResp[batchId].KeyShare.Q() - if err != nil { - log.Fatalf("failed to get public key: %v", err) - } - return Q -} - -func (p *Party) RSAKeygen() []byte { - dk, ek, err := rsaDemoKEM{}.Generate() - if err != nil { - log.Fatalf("failed to generate base encryption key pair: %v", err) - } - p.rsaDk = mpc.BaseEncPrivateKey(dk) - p.rsaEk = mpc.BaseEncPublicKey(ek) - return ek -} - -func (p *Party) Backup(inputLabel string, c curve.Curve, ac *mpc.AccessStructure, pubKeys map[string]mpc.BaseEncPublicKey) *mpc.PVEAcEncryptResponse { - xs := make([]*curve.Scalar, p.BatchSize) - Xs := make([]*curve.Point, p.BatchSize) - var err error - for k := 0; k < p.BatchSize; k++ { - xs[k], err = p.dkgResp[k].KeyShare.XShare() - if err != nil { - log.Fatalf("failed to get X share: %v", err) - } - Qis, err := p.dkgResp[k].KeyShare.Qis() - if err != nil { - log.Fatalf("failed to get Qis: %v", err) - } - Xs[k] = Qis[p.PartyNames[p.Index]] - } - - pveEncResp, err := p.pve.AcEncrypt(&mpc.PVEAcEncryptRequest{ - AccessStructure: ac, - PublicKeys: pubKeys, - PrivateValues: xs, - Label: inputLabel, - Curve: c, - }) - if err != nil { - log.Fatalf("failed to encrypt: %v", err) - } - return pveEncResp -} - -func (p *Party) VerifyAllBackups(backups []*mpc.PVEAcEncryptResponse, ac *mpc.AccessStructure, inputLabel string, pubKeys map[string]mpc.BaseEncPublicKey) error { - for j := 0; j < p.NParties; j++ { - // Party i creates what she thinks the public value for party j should be and verifies against that value - Xs := make([]*curve.Point, p.BatchSize) - var err error - for k := 0; k < p.BatchSize; k++ { - Qis, err := p.dkgResp[k].KeyShare.Qis() - if err != nil { - return fmt.Errorf("failed to get Qis: %v", err) - } - Xs[k] = Qis[p.PartyNames[j]] - } - - verifyResp, err := p.pve.AcVerify(&mpc.PVEAcVerifyRequest{ - AccessStructure: ac, - EncryptedBundle: backups[j].EncryptedBundle, - PublicKeys: pubKeys, - PublicShares: Xs, - Label: inputLabel, - }) - if err != nil { - return fmt.Errorf("failed to verify: %v", err) - } - if !verifyResp.Valid { - return fmt.Errorf("PVE verification failed") - } - } - return nil -} - -func (p *Party) PartialBackupDecryption(backups []*mpc.PVEAcEncryptResponse, inputLabel string, ac *mpc.AccessStructure) []*mpc.PVEAcPartyDecryptRowResponse { - resps := make([]*mpc.PVEAcPartyDecryptRowResponse, p.NParties) - var err error - for j := 0; j < p.NParties; j++ { - resps[j], err = p.pve.AcPartyDecryptRow(&mpc.PVEAcPartyDecryptRowRequest{ - AccessStructure: ac, - Path: p.PartyNames[p.Index], - PrivateKey: p.rsaDk, - EncryptedBundle: backups[j].EncryptedBundle, - Label: inputLabel, - RowIndex: 0, - }) - if err != nil { - log.Fatalf("failed to party decrypt row: %v", err) - } - } - return resps -} - -func (p *Party) AggregatePartialDecryptedBackups(partialDecryptions map[string]*mpc.PVEAcPartyDecryptRowResponse, backup *mpc.PVEAcEncryptResponse, inputLabel string, ac *mpc.AccessStructure) *mpc.PVEAcAggregateToRestoreRowResponse { - shares := make(map[string][]byte) - for pname, resp := range partialDecryptions { - shares[pname] = resp.Share - } - aggResp, err := p.pve.AcAggregateToRestoreRow(&mpc.PVEAcAggregateToRestoreRowRequest{ - AccessStructure: ac, - EncryptedBundle: backup.EncryptedBundle, - Label: inputLabel, - RowIndex: 0, - Shares: shares, - }) - if err != nil { - log.Fatalf("failed to aggregate: %v", err) - } - - return aggResp -} diff --git a/demos-go/examples/zk/go.mod b/demos-go/examples/zk/go.mod deleted file mode 100644 index d1d3d520..00000000 --- a/demos-go/examples/zk/go.mod +++ /dev/null @@ -1,14 +0,0 @@ -module github.com/coinbase/cb-mpc/demo-go-zk - -go 1.23.0 - -toolchain go1.24.2 - -require github.com/coinbase/cb-mpc/demos-go/cb-mpc-go v0.0.0-20240501131245-1eee31b51009 - -require ( - github.com/stretchr/testify v1.10.0 // indirect - golang.org/x/sync v0.15.0 // indirect -) - -replace github.com/coinbase/cb-mpc/demos-go/cb-mpc-go => ../../cb-mpc-go diff --git a/demos-go/examples/zk/go.sum b/demos-go/examples/zk/go.sum deleted file mode 100644 index c43696e9..00000000 --- a/demos-go/examples/zk/go.sum +++ /dev/null @@ -1,10 +0,0 @@ -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= -golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/demos-go/examples/zk/main.go b/demos-go/examples/zk/main.go deleted file mode 100644 index 261b207d..00000000 --- a/demos-go/examples/zk/main.go +++ /dev/null @@ -1,48 +0,0 @@ -package main - -import ( - "fmt" - "log" - - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/curve" - "github.com/coinbase/cb-mpc/demos-go/cb-mpc-go/api/zk" -) - -func main() { - fmt.Println("\n=== CB-MPC Zero-Knowledge Discrete Logarithm Example ===") - - // Use the secp256k1 curve for this demo (any supported curve works). - c, err := curve.NewSecp256k1() - if err != nil { - log.Fatalf("creating curve failed: %v", err) - } - defer c.Free() - - // Generate a random key pair (w, W = w·G) - w, W, err := c.RandomKeyPair() - if err != nil { - log.Fatalf("key generation failed: %v", err) - } - fmt.Printf("Generated key pair on %s – witness length: %d bytes\n", c.String(), len(w.Bytes)) - - // Create a proof - sessionID := []byte("example-session") - auxiliary := uint64(2025) - - pr, err := zk.ZKUCDLProve(&zk.ZKUCDLProveRequest{PublicKey: W, Witness: w, SessionID: sessionID, Auxiliary: auxiliary}) - if err != nil { - log.Fatalf("proof generation failed: %v", err) - } - fmt.Printf("Proof generated – %d bytes\n", len(pr.Proof)) - - // Verify the proof - vr, err := zk.ZKUCDLVerify(&zk.ZKUCDLVerifyRequest{PublicKey: W, Proof: pr.Proof, SessionID: sessionID, Auxiliary: auxiliary}) - if err != nil { - log.Fatalf("verification failed: %v", err) - } - if !vr.Valid { - log.Fatalf("❌ proof verification failed") - } - - fmt.Println("✅ Proof verified successfully") -} diff --git a/docs/secure-usage.pdf b/docs/secure-usage.pdf deleted file mode 100644 index ee5b8aa9d9805985f61a13c2f5547951ff4953cb..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 131 zcmWN?OA^8$3;@tQr|1PNO$a3PHhhH{m5$UdJiWfnyUKg^@mkt>9&^|F-p|{k&h>x$ z#5Iklo<~Xg0yTPGW(#QDE*2p(M}V|BCg7Yc8FC3N7YPPE`4R#ri6XpF5SyTquN5$Q NBWm_ #include -#include -#include +#include namespace coinbase { @@ -12,40 +11,6 @@ struct buf256_t; class converter_t; -class convertable_t { // interface - public: - virtual void convert(converter_t& converter) = 0; - virtual ~convertable_t() {} - - class def_t { - public: - virtual ~def_t() {} - virtual convertable_t* create() = 0; - }; - - template - class def_entry_t : public def_t { - public: - def_entry_t() { factory_t::register_type(this, code_type); } - virtual convertable_t* create() { return new type(); } - }; - - class factory_t { - private: - unordered_map_t map; - - public: - static void register_type(def_t* def, uint64_t code_type); - static convertable_t* create(mem_t data, bool convert = true); - static convertable_t* create(uint64_t code_type); - - template - class register_t : public global_init_t> {}; - }; -}; - -static global_t g_convertable_factory; - class converter_t { public: template @@ -86,6 +51,11 @@ class converter_t { // Maximum number of elements allowed when (de)serializing a std::vector. static constexpr uint32_t MAX_CONTAINER_ELEMENTS = 1 << 20; + // Maximum value allowed for `convert_len(...)` when deserializing. + // + // `convert_len` is used as a length prefix for buffers and container counts. Bounding it protects against + // attacker-controlled allocations and loops if a malicious peer supplies an oversized length prefix. + static constexpr uint32_t MAX_CONVERT_LEN = 64 * 1024 * 1024; // 64 MiB void convert(bool& value); void convert(uint8_t& value); @@ -180,6 +150,12 @@ class converter_t { if (!write) value.clear(); uint32_t count = (uint32_t)value.size(); convert_len(count); + if (!write) { + if (count > MAX_CONTAINER_ELEMENTS) { + set_error(); + return; + } + } auto v = value.begin(); for (uint32_t i = 0; i < count && !is_error(); i++) { if (write) { diff --git a/src/cbmpc/core/extended_uint.h b/include-internal/cbmpc/internal/core/extended_uint.h similarity index 100% rename from src/cbmpc/core/extended_uint.h rename to include-internal/cbmpc/internal/core/extended_uint.h diff --git a/src/cbmpc/core/log.h b/include-internal/cbmpc/internal/core/log.h similarity index 100% rename from src/cbmpc/core/log.h rename to include-internal/cbmpc/internal/core/log.h diff --git a/src/cbmpc/core/strext.h b/include-internal/cbmpc/internal/core/strext.h similarity index 90% rename from src/cbmpc/core/strext.h rename to include-internal/cbmpc/internal/core/strext.h index 35cd48cd..9d3373ed 100755 --- a/src/cbmpc/core/strext.h +++ b/include-internal/cbmpc/internal/core/strext.h @@ -16,8 +16,10 @@ class insensitive_map_t : public unordered_map_t + #include #include @@ -21,20 +23,43 @@ inline int bits_to_bytes_floor(int bits) { return bits >> 3; } inline int bits_to_bytes(int bits) { return (bits + 7) >> 3; } inline int bytes_to_bits(int bytes) { return bytes << 3; } -inline uint16_t le_get_2(const_byte_ptr src) { return *(uint16_t*)src; } -inline uint32_t le_get_4(const_byte_ptr src) { return *(uint32_t*)src; } -inline uint64_t le_get_8(const_byte_ptr src) { return *(uint64_t*)src; } -inline void le_set_2(byte_ptr dst, uint16_t value) { *(uint16_t*)dst = value; } -inline void le_set_4(byte_ptr dst, uint32_t value) { *(uint32_t*)dst = value; } -inline void le_set_8(byte_ptr dst, uint64_t value) { *(uint64_t*)dst = value; } +namespace detail { +inline uint16_t load_u16_unaligned(const_byte_ptr src) noexcept { + uint16_t v; + std::memcpy(&v, src, sizeof(v)); + return v; +} +inline uint32_t load_u32_unaligned(const_byte_ptr src) noexcept { + uint32_t v; + std::memcpy(&v, src, sizeof(v)); + return v; +} +inline uint64_t load_u64_unaligned(const_byte_ptr src) noexcept { + uint64_t v; + std::memcpy(&v, src, sizeof(v)); + return v; +} +inline void store_u16_unaligned(byte_ptr dst, uint16_t v) noexcept { std::memcpy(dst, &v, sizeof(v)); } +inline void store_u32_unaligned(byte_ptr dst, uint32_t v) noexcept { std::memcpy(dst, &v, sizeof(v)); } +inline void store_u64_unaligned(byte_ptr dst, uint64_t v) noexcept { std::memcpy(dst, &v, sizeof(v)); } +} // namespace detail + +// NOTE: These helpers must be safe for unaligned buffers and must not rely on strict-aliasing violations. +// Modern compilers typically inline fixed-size memcpy into efficient loads/stores. +inline uint16_t le_get_2(const_byte_ptr src) { return detail::load_u16_unaligned(src); } +inline uint32_t le_get_4(const_byte_ptr src) { return detail::load_u32_unaligned(src); } +inline uint64_t le_get_8(const_byte_ptr src) { return detail::load_u64_unaligned(src); } +inline void le_set_2(byte_ptr dst, uint16_t value) { detail::store_u16_unaligned(dst, value); } +inline void le_set_4(byte_ptr dst, uint32_t value) { detail::store_u32_unaligned(dst, value); } +inline void le_set_8(byte_ptr dst, uint64_t value) { detail::store_u64_unaligned(dst, value); } #if defined(__x86_64__) -inline uint16_t be_get_2(const_byte_ptr src) { return __builtin_bswap16(*(uint16_t*)src); } -inline uint32_t be_get_4(const_byte_ptr src) { return __builtin_bswap32(*(uint32_t*)src); } -inline uint64_t be_get_8(const_byte_ptr src) { return __builtin_bswap64(*(uint64_t*)src); } -inline void be_set_2(byte_ptr dst, uint16_t value) { *(uint16_t*)dst = __builtin_bswap16(value); } -inline void be_set_4(byte_ptr dst, uint32_t value) { *(uint32_t*)dst = __builtin_bswap32(value); } -inline void be_set_8(byte_ptr dst, uint64_t value) { *(uint64_t*)dst = __builtin_bswap64(value); } +inline uint16_t be_get_2(const_byte_ptr src) { return __builtin_bswap16(detail::load_u16_unaligned(src)); } +inline uint32_t be_get_4(const_byte_ptr src) { return __builtin_bswap32(detail::load_u32_unaligned(src)); } +inline uint64_t be_get_8(const_byte_ptr src) { return __builtin_bswap64(detail::load_u64_unaligned(src)); } +inline void be_set_2(byte_ptr dst, uint16_t value) { detail::store_u16_unaligned(dst, __builtin_bswap16(value)); } +inline void be_set_4(byte_ptr dst, uint32_t value) { detail::store_u32_unaligned(dst, __builtin_bswap32(value)); } +inline void be_set_8(byte_ptr dst, uint64_t value) { detail::store_u64_unaligned(dst, __builtin_bswap64(value)); } #else inline uint16_t be_get_2(const_byte_ptr src) { return (uint16_t(src[0]) << 8) | src[1]; } inline uint32_t be_get_4(const_byte_ptr src) { return (uint32_t(be_get_2(src + 0)) << 16) | be_get_2(src + 2); } @@ -75,7 +100,10 @@ template auto lookup(const std::map& map, const Key& value) { using Ref = typename std::map::value_type::second_type; auto it = map.find(value); - return std::tuple(it != map.end(), it->second); + // NOTE: Never return `it->second` by reference unconditionally here: when `it == map.end()`, + // dereferencing is undefined behavior even if the caller checks the accompanying boolean. + const Ref* ref = (it != map.end()) ? &it->second : nullptr; + return std::tuple(ref != nullptr, ref); } template diff --git a/src/cbmpc/crypto/base.h b/include-internal/cbmpc/internal/crypto/base.h similarity index 93% rename from src/cbmpc/crypto/base.h rename to include-internal/cbmpc/internal/crypto/base.h index 458e875e..ef54d37c 100644 --- a/src/cbmpc/crypto/base.h +++ b/include-internal/cbmpc/internal/crypto/base.h @@ -13,7 +13,14 @@ #define X509_get_notAfter X509_getm_notAfter #endif -#include +#include +#include +#include +#include +#include +#include +#include +#include enum { E_CRYPTO = ERRCODE(ECATEGORY_CRYPTO, 1), E_ECDSA_2P_BIT_LEAK = ERRCODE(ECATEGORY_CRYPTO, 2) }; @@ -180,17 +187,6 @@ class aes_gmac_t : public evp_cipher_ctx_t { } // namespace coinbase::crypto -// clang-format off -// Order matters here -#include "base_bn.h" -#include "base_mod.h" -#include "base_ecc.h" -#include "base_eddsa.h" -#include "base_hash.h" -#include "base_paillier.h" -#include "base_rsa.h" - -// clang-format on using coinbase::crypto::bn_t; using coinbase::crypto::ecc_point_t; using coinbase::crypto::ecurve_t; diff --git a/src/cbmpc/crypto/base_bn.h b/include-internal/cbmpc/internal/crypto/base_bn.h old mode 100755 new mode 100644 similarity index 76% rename from src/cbmpc/crypto/base_bn.h rename to include-internal/cbmpc/internal/crypto/base_bn.h index 9c1496dd..a7dafd1d --- a/src/cbmpc/crypto/base_bn.h +++ b/include-internal/cbmpc/internal/crypto/base_bn.h @@ -1,5 +1,12 @@ #pragma once +#include +#include +#include +#include + +#include + struct bignum_st { BN_ULONG* d; /* Pointer to an array of 'BN_BITS2' bit * chunks. */ @@ -121,8 +128,9 @@ class bn_t { static buf_t vector_to_bin(const std::vector& vals, int val_size); int to_bin(byte_ptr dst) const; void to_bin(byte_ptr dst, int size) const; - void to_bin(mem_t mem) const { to_bin(mem.data, mem.size); } + void to_bin(mem_t mem) const { to_bin(const_cast(mem.data), mem.size); } static bn_t from_bin(mem_t mem); + static error_t vector_from_bin(mem_t mem, int n, int size, const mod_t& q, std::vector& out); static std::vector vector_from_bin(mem_t mem, int n, int size, const mod_t& q); static bn_t from_bin_bitlen(mem_t mem, int bits); @@ -133,6 +141,10 @@ class bn_t { static bn_t from_string(const std::string& str) { return from_string(str.c_str()); } static bn_t from_hex(const_char_ptr str); + static error_t from_string(const_char_ptr str, bn_t& result); + static error_t from_string(const std::string& str, bn_t& result) { return from_string(str.c_str(), result); } + static error_t from_hex(const_char_ptr str, bn_t& result); + static int compare(const bn_t& b1, const bn_t& b2); int sign() const; @@ -159,6 +171,17 @@ class bn_t { static bool check_modulo(const mod_t& n); static void reset_modulo(const mod_t& n); + // Maximum number of bytes allowed when (de)serializing a `bn_t`. + // + // This is a defense-in-depth bound against attacker-controlled inputs causing excessive allocations or crashes + // during deserialization. Override at build time by defining `CBMPC_MAX_SERIALIZED_BIGNUM_BYTES`. + inline static constexpr uint32_t MAX_SERIALIZED_BIGNUM_BYTES = +#if defined(CBMPC_MAX_SERIALIZED_BIGNUM_BYTES) + CBMPC_MAX_SERIALIZED_BIGNUM_BYTES; +#else + 1024 * 1024; // 1 MiB +#endif + void convert(converter_t& converter); // thread local storage for BN_CTX @@ -172,6 +195,22 @@ class bn_t { void init(); }; +/** + * WARNING: `MODULO(n)` sets a thread-local modulus for all `bn_t` arithmetic in the current thread. + * + * The modulus is reset in the `for` loop update clause, so it is NOT reset if the body exits early + * via `return`, `break`, `throw`, `goto`, etc. If that happens, subsequent cryptographic operations + * on the same thread may run under an unexpected modulus (or under a modulus when they should not), + * potentially producing incorrect results or corrupted state. + * + * Additional caveats: + * - The modulus is stored as a single thread-local pointer (it is not stacked), so `MODULO(...)` + * must not be nested. + * - Avoid mixing arithmetic that assumes "no modulus" with code that can leave a previous modulus set. + * + * If early-exit behavior is required, ensure the modulus is reset before leaving the scope, or refactor + * to propagate errors without exiting the `MODULO(...) { ... }` block. + */ #define MODULO(n) \ for (coinbase::crypto::bn_t::set_modulo(n); coinbase::crypto::bn_t::check_modulo(n); \ coinbase::crypto::bn_t::reset_modulo(n)) diff --git a/src/cbmpc/crypto/base_ec_core.h b/include-internal/cbmpc/internal/crypto/base_ec_core.h similarity index 99% rename from src/cbmpc/crypto/base_ec_core.h rename to include-internal/cbmpc/internal/crypto/base_ec_core.h index c72a1e2f..f32d88eb 100644 --- a/src/cbmpc/crypto/base_ec_core.h +++ b/include-internal/cbmpc/internal/crypto/base_ec_core.h @@ -1,6 +1,6 @@ #pragma once -#include "base.h" +#include namespace coinbase::crypto { diff --git a/src/cbmpc/crypto/base_ecc.h b/include-internal/cbmpc/internal/crypto/base_ecc.h similarity index 99% rename from src/cbmpc/crypto/base_ecc.h rename to include-internal/cbmpc/internal/crypto/base_ecc.h index 0a59dfc4..85d1640b 100644 --- a/src/cbmpc/crypto/base_ecc.h +++ b/include-internal/cbmpc/internal/crypto/base_ecc.h @@ -1,6 +1,7 @@ #pragma once -#include "base_bn.h" +#include +#include #ifndef NID_ED25519 #define NID_ED25519 1087 diff --git a/src/cbmpc/crypto/base_ecc_secp256k1.h b/include-internal/cbmpc/internal/crypto/base_ecc_secp256k1.h similarity index 98% rename from src/cbmpc/crypto/base_ecc_secp256k1.h rename to include-internal/cbmpc/internal/crypto/base_ecc_secp256k1.h index a3bc6ef2..abc8c508 100644 --- a/src/cbmpc/crypto/base_ecc_secp256k1.h +++ b/include-internal/cbmpc/internal/crypto/base_ecc_secp256k1.h @@ -1,6 +1,6 @@ #pragma once -#include +#include namespace coinbase::crypto { diff --git a/src/cbmpc/crypto/base_eddsa.h b/include-internal/cbmpc/internal/crypto/base_eddsa.h similarity index 96% rename from src/cbmpc/crypto/base_eddsa.h rename to include-internal/cbmpc/internal/crypto/base_eddsa.h index 86813895..d7b29552 100755 --- a/src/cbmpc/crypto/base_eddsa.h +++ b/include-internal/cbmpc/internal/crypto/base_eddsa.h @@ -1,7 +1,7 @@ #pragma once -#include "base_ecc.h" -#include "ec25519_core.h" +#include +#include namespace coinbase::crypto { diff --git a/src/cbmpc/crypto/base_hash.h b/include-internal/cbmpc/internal/crypto/base_hash.h similarity index 98% rename from src/cbmpc/crypto/base_hash.h rename to include-internal/cbmpc/internal/crypto/base_hash.h index acbd7b3b..76d22f80 100644 --- a/src/cbmpc/crypto/base_hash.h +++ b/include-internal/cbmpc/internal/crypto/base_hash.h @@ -5,9 +5,10 @@ #include #include -#include - -#include "base_bn.h" +#include +#include +#include +#include namespace coinbase::crypto { class bn_t; diff --git a/src/cbmpc/crypto/base_mod.h b/include-internal/cbmpc/internal/crypto/base_mod.h similarity index 93% rename from src/cbmpc/crypto/base_mod.h rename to include-internal/cbmpc/internal/crypto/base_mod.h index 24350ec2..9d85eed0 100644 --- a/src/cbmpc/crypto/base_mod.h +++ b/include-internal/cbmpc/internal/crypto/base_mod.h @@ -1,5 +1,7 @@ #pragma once +#include + namespace coinbase::crypto { // This is a dangerous function and should be used only if you know what are you doing! @@ -39,7 +41,15 @@ class mod_t { bn_t pow(const bn_t& x, const bn_t& e) const { bn_t r; _pow(r, x, e); return r; } bn_t mod(const bn_t& a) const { bn_t r; _mod(r, a); return r; } - bn_t mod(int a) const { return a < 0 ? neg(bn_t(-a)) : bn_t(a); } + bn_t mod(int a) const { + // NOTE: Avoid `-INT_MIN` signed overflow / UB when `a == INT_MIN`. + if (a >= 0) return bn_t(a); + const BN_ULONG abs_a = static_cast(-static_cast(a)); + bn_t tmp; + int res = BN_set_word(tmp, abs_a); + cb_assert(res); + return neg(tmp); + } // clang-format on // only works with odd m diff --git a/src/cbmpc/crypto/base_paillier.h b/include-internal/cbmpc/internal/crypto/base_paillier.h similarity index 94% rename from src/cbmpc/crypto/base_paillier.h rename to include-internal/cbmpc/internal/crypto/base_paillier.h index ea256075..831cddd5 100644 --- a/src/cbmpc/crypto/base_paillier.h +++ b/include-internal/cbmpc/internal/crypto/base_paillier.h @@ -1,6 +1,7 @@ #pragma once -#include "base_rsa.h" +#include +#include namespace coinbase::crypto { @@ -140,7 +141,7 @@ class paillier_t { template error_t verify_ciphers(Values... ciphers) const { std::array arr{ciphers...}; - return batch_verify_ciphers(&arr[0], sizeof...(ciphers)); + return batch_verify_ciphers(arr.data(), static_cast(arr.size())); } error_t batch_verify_ciphers(const bn_t* ciphers, int n) const; @@ -152,6 +153,9 @@ class paillier_t { bn_t q; bn_t phi_N; // cached bn_t inv_phi_N; // cached + // Cached inverse of N modulo 2^bit_size for constant-time extraction of L(u) = (u-1)/N + // during Paillier decryption (avoid variable-time big integer division on secret-dependent values). + bn_t inv_N_mod_2k; struct crt_t { mod_t p, q; diff --git a/src/cbmpc/crypto/base_pki.h b/include-internal/cbmpc/internal/crypto/base_pki.h similarity index 74% rename from src/cbmpc/crypto/base_pki.h rename to include-internal/cbmpc/internal/crypto/base_pki.h index 8ccee478..ae6e1505 100644 --- a/src/cbmpc/crypto/base_pki.h +++ b/include-internal/cbmpc/internal/crypto/base_pki.h @@ -2,11 +2,10 @@ #include -#include - -#include "base.h" -#include "base_ecc.h" -#include "base_rsa.h" +#include +#include +#include +#include namespace coinbase::crypto { @@ -189,98 +188,9 @@ struct kem_policy_ecdh_p256_t { }; // --------------------------------------------------------------------------- -// C++ native unified PKE types +// PKE scheme type aliases // --------------------------------------------------------------------------- -class prv_key_t; - -typedef uint8_t key_type_t; - -enum key_type_e : uint8_t { - NONE = 0, - RSA = 1, - ECC = 2, -}; - -class pub_key_t { - friend class prv_key_t; - - public: - static pub_key_t from(const rsa_pub_key_t& rsa); - static pub_key_t from(const ecc_pub_key_t& ecc); - const rsa_pub_key_t& rsa() const { return rsa_key; } - const ecc_pub_key_t& ecc() const { return ecc_key; } - - key_type_t get_type() const { return key_type; } - - void convert(coinbase::converter_t& c) { - c.convert(key_type); - if (key_type == key_type_e::RSA) - c.convert(rsa_key); - else if (key_type == key_type_e::ECC) - c.convert(ecc_key); - else - cb_assert(false && "Invalid key type"); - } - - bool operator==(const pub_key_t& val) const { - if (key_type != val.key_type) return false; - - if (key_type == key_type_e::RSA) - return rsa() == val.rsa(); - else if (key_type == key_type_e::ECC) - return ecc() == val.ecc(); - else { - cb_assert(false && "Invalid key type"); - return false; - } - } - bool operator!=(const pub_key_t& val) const { return !(*this == val); } - - private: - key_type_t key_type = key_type_e::NONE; - rsa_pub_key_t rsa_key; - ecc_pub_key_t ecc_key; -}; - -class prv_key_t { - public: - static prv_key_t from(const rsa_prv_key_t& rsa); - static prv_key_t from(const ecc_prv_key_t& ecc); - const rsa_prv_key_t rsa() const { return rsa_key; } - const ecc_prv_key_t ecc() const { return ecc_key; } - - key_type_t get_type() const { return key_type; } - - pub_key_t pub() const; - error_t execute(mem_t in, buf_t& out) const; - - private: - key_type_t key_type = key_type_e::NONE; - rsa_prv_key_t rsa_key; - ecc_prv_key_t ecc_key; -}; - -struct ciphertext_t { - key_type_t key_type = key_type_e::NONE; - kem_aead_ciphertext_t rsa_kem; - kem_aead_ciphertext_t ecies; - - error_t encrypt(const pub_key_t& pub_key, mem_t label, mem_t plain, drbg_aes_ctr_t* drbg = nullptr); - - error_t decrypt(const prv_key_t& prv_key, mem_t label, buf_t& plain) const; - - void convert(coinbase::converter_t& c) { - c.convert(key_type); - if (key_type == key_type_e::RSA) - c.convert(rsa_kem); - else if (key_type == key_type_e::ECC) - c.convert(ecies); - else - cb_assert(false && "Invalid key type"); - } -}; - template struct hybrid_pke_t { using ek_t = EK_T; @@ -290,7 +200,6 @@ struct hybrid_pke_t { using rsa_pke_t = hybrid_pke_t>; using ecies_t = hybrid_pke_t>; -using unified_pke_t = hybrid_pke_t; template struct sign_scheme_t { diff --git a/include-internal/cbmpc/internal/crypto/base_rsa.h b/include-internal/cbmpc/internal/crypto/base_rsa.h new file mode 100644 index 00000000..8a1d9642 --- /dev/null +++ b/include-internal/cbmpc/internal/crypto/base_rsa.h @@ -0,0 +1,136 @@ +#pragma once + +#include +#include + +typedef EVP_PKEY RSA_BASE; + +namespace coinbase::crypto { + +const int RSA_KEY_LENGTH = 2048; +class rsa_pub_key_t : public scoped_ptr_t { + public: + int size() const; + + static error_t pad_oaep(int bits, mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, buf_t& out); + static error_t pad_oaep_with_seed(int bits, mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, mem_t seed, + buf_t& out); + + error_t encrypt_raw(mem_t in, buf_t& out) const; + error_t encrypt_oaep(mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, buf_t& out) const; + error_t encrypt_oaep_with_seed(mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, mem_t seed, buf_t& out) const; + error_t verify_pkcs1(mem_t data, hash_e hash_alg, mem_t signature) const; + + buf_t to_der() const; + buf_t to_der_pkcs1() const; + error_t from_der(mem_t der); + + void set(const BIGNUM* n, const BIGNUM* e) { + create(); + set(ptr, n, e); + } + + void convert(coinbase::converter_t& converter); + + bool operator==(const rsa_pub_key_t& val) const { return EVP_PKEY_eq(ptr, val.ptr); } + bool operator!=(const rsa_pub_key_t& val) const { return !EVP_PKEY_eq(ptr, val.ptr); } + + private: + struct data_t { + BIGNUM *n = nullptr, *e = nullptr; + }; + + static data_t get(const RSA_BASE* ptr); + static void set(RSA_BASE*& rsa, const BIGNUM* n, const BIGNUM* e); + + data_t get() const { return get(ptr); } + void create(); +}; + +class rsa_prv_key_t : public scoped_ptr_t { + public: + error_t execute(mem_t enc_info, buf_t& dec_info) const; + + rsa_pub_key_t pub() const; + int size() const; + + void generate(int bits); + + error_t decrypt_raw(mem_t in, buf_t& out) const; + error_t decrypt_oaep(mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, buf_t& out) const; + error_t sign_pkcs1(mem_t data, hash_e hash_alg, buf_t& sig) const; + + buf_t to_der() const; + error_t from_der(mem_t der); + + void convert(coinbase::converter_t& converter); + + bn_t get_e() const { return bn_t(get().e); } + bn_t get_n() const { return bn_t(get().n); } + bn_t get_p() const { return bn_t(get().p); } + bn_t get_q() const { return bn_t(get().q); } + + void set(const BIGNUM* n, const BIGNUM* e, const BIGNUM* d) { + create(); + set(ptr, n, e, d); + } + void set(const BIGNUM* n, const BIGNUM* e, const BIGNUM* d, const BIGNUM* p, const BIGNUM* q) { + create(); + set(ptr, n, e, d, p, q); + } + void set(const BIGNUM* n, const BIGNUM* e, const BIGNUM* d, const BIGNUM* p, const BIGNUM* q, const BIGNUM* dp, + const BIGNUM* dq, const BIGNUM* qinv) { + create(); + set(ptr, n, e, d, p, q, dp, dq, qinv); + } + error_t recover_factors(); + void set_paillier(const BIGNUM* n, const BIGNUM* p, const BIGNUM* q, const BIGNUM* dp, const BIGNUM* dq, + const BIGNUM* qinv); + + private: + struct data_t { + bn_t n, e; + bn_t p, q; + }; + static data_t get(const RSA_BASE* ptr); + static void set(RSA_BASE*& rsa, const BIGNUM* n, const BIGNUM* e, const BIGNUM* d); + static void set(RSA_BASE*& rsa, const BIGNUM* n, const BIGNUM* e, const BIGNUM* d, const BIGNUM* p, const BIGNUM* q); + static void set(RSA_BASE*& rsa, const BIGNUM* n, const BIGNUM* e, const BIGNUM* d, const BIGNUM* p, const BIGNUM* q, + const BIGNUM* dp, const BIGNUM* dq, const BIGNUM* qinv); + static void set(RSA_BASE*& rsa, const data_t& data); + + data_t get() const { return get(ptr); } + void create(); +}; + +class rsa_oaep_t { + public: + typedef error_t (*exec_t)(void* ctx, int hash_alg, int mgf_alg, mem_t label, mem_t input, buf_t& output); + + rsa_oaep_t(const rsa_prv_key_t& _key) : key(&_key), exec(nullptr), ctx(nullptr) {} + rsa_oaep_t(exec_t _exec, void* _ctx) : key(nullptr), exec(_exec), ctx(_ctx) {} + + error_t execute(hash_e hash_alg, hash_e mgf_alg, mem_t label, mem_t in, buf_t& out) const; + static error_t execute(void* ctx, int hash_alg, int mgf_alg, mem_t label, mem_t in, buf_t& out); + + private: + exec_t exec; + void* ctx; + const rsa_prv_key_t* key; +}; + +static int evp_md_size(hash_e type) { return hash_alg_t::get(type).size; } +static int evp_digest_init_ex(hash_t& ctx, hash_e type, void* impl) { + ctx.init(); + return 1; +} +static int evp_digest_update(hash_t& ctx, const void* d, size_t cnt) { + ctx.update(const_byte_ptr(d), int(cnt)); + return 1; +} +static int evp_digest_final_ex(hash_t& ctx, unsigned char* md, unsigned int* s) { + ctx.final(md); + return 1; +} + +} // namespace coinbase::crypto diff --git a/src/cbmpc/crypto/commitment.h b/include-internal/cbmpc/internal/crypto/commitment.h similarity index 99% rename from src/cbmpc/crypto/commitment.h rename to include-internal/cbmpc/internal/crypto/commitment.h index ae15c728..5bf1d63b 100644 --- a/src/cbmpc/crypto/commitment.h +++ b/include-internal/cbmpc/internal/crypto/commitment.h @@ -1,6 +1,6 @@ #pragma once -#include +#include namespace coinbase::crypto { diff --git a/src/cbmpc/crypto/ec25519_core.h b/include-internal/cbmpc/internal/crypto/ec25519_core.h similarity index 97% rename from src/cbmpc/crypto/ec25519_core.h rename to include-internal/cbmpc/internal/crypto/ec25519_core.h index cde7a6d5..65c0546c 100755 --- a/src/cbmpc/crypto/ec25519_core.h +++ b/include-internal/cbmpc/internal/crypto/ec25519_core.h @@ -1,6 +1,6 @@ #pragma once -#include +#include namespace coinbase::crypto::ec25519_core { diff --git a/src/cbmpc/crypto/elgamal.h b/include-internal/cbmpc/internal/crypto/elgamal.h similarity index 98% rename from src/cbmpc/crypto/elgamal.h rename to include-internal/cbmpc/internal/crypto/elgamal.h index 614aba1c..a346c915 100644 --- a/src/cbmpc/crypto/elgamal.h +++ b/include-internal/cbmpc/internal/crypto/elgamal.h @@ -1,6 +1,6 @@ #pragma once -#include +#include namespace coinbase::crypto { diff --git a/src/cbmpc/crypto/lagrange.h b/include-internal/cbmpc/internal/crypto/lagrange.h similarity index 94% rename from src/cbmpc/crypto/lagrange.h rename to include-internal/cbmpc/internal/crypto/lagrange.h index 6f55510d..10944466 100644 --- a/src/cbmpc/crypto/lagrange.h +++ b/include-internal/cbmpc/internal/crypto/lagrange.h @@ -1,7 +1,7 @@ #pragma once -#include -#include -#include +#include +#include +#include namespace coinbase::crypto { diff --git a/src/cbmpc/crypto/ro.h b/include-internal/cbmpc/internal/crypto/ro.h similarity index 98% rename from src/cbmpc/crypto/ro.h rename to include-internal/cbmpc/internal/crypto/ro.h index 0b0ab554..2742931b 100644 --- a/src/cbmpc/crypto/ro.h +++ b/include-internal/cbmpc/internal/crypto/ro.h @@ -1,6 +1,6 @@ #pragma once -#include +#include namespace coinbase::crypto::ro { // random oracle struct hmac_state_t { diff --git a/src/cbmpc/crypto/scope.h b/include-internal/cbmpc/internal/crypto/scope.h similarity index 100% rename from src/cbmpc/crypto/scope.h rename to include-internal/cbmpc/internal/crypto/scope.h diff --git a/include-internal/cbmpc/internal/crypto/secret_sharing.h b/include-internal/cbmpc/internal/crypto/secret_sharing.h new file mode 100644 index 00000000..8f966304 --- /dev/null +++ b/include-internal/cbmpc/internal/crypto/secret_sharing.h @@ -0,0 +1,195 @@ +#pragma once + +#include +#include + +namespace coinbase::crypto::ss { + +template +using party_map_t = std::map; + +std::vector share_and(const mod_t& q, const bn_t& x, const int n, crypto::drbg_aes_ctr_t* drbg = nullptr); +std::pair, std::vector> share_threshold(const mod_t& q, const bn_t& a, const int threshold, + const int n, const std::vector& pids, + crypto::drbg_aes_ctr_t* drbg = nullptr); + +enum class node_e { + NONE = 0, + LEAF = 1, + AND = 2, + OR = 3, + THRESHOLD = 4, +}; + +class node_t; + +typedef party_map_t ac_shares_t; +typedef party_map_t ac_internal_shares_t; +typedef party_map_t ac_pub_shares_t; +typedef party_map_t ac_internal_pub_shares_t; + +class ac_t; +class ac_owned_t; + +struct node_t { + friend class ac_t; + friend class ac_owned_t; + + node_e type; + pname_t name; + int threshold; + std::vector children; + node_t* parent = nullptr; + + node_t(node_e _type, pname_t _name, int _threshold = 0) : type(_type), name(_name), threshold(_threshold) {} + + node_t(node_e _type, pname_t _name, int _threshold, std::initializer_list nodes) + : type(_type), name(_name), threshold(_threshold), children(nodes) { + for (auto child : nodes) { + child->parent = this; + } + } + + ~node_t(); + node_t* clone() const; + + int get_n() const { return int(children.size()); } + std::string get_path() const; + + static bn_t pid_from_path(const std::string& path); + bn_t get_pid() const; + + std::vector list_leaf_paths() const; + std::set list_leaf_names() const; + const node_t* find(const pname_t& path) const; + void add_child_node(node_t* node); + void remove_and_delete(); + + error_t validate_tree() const { + std::set names; + return validate_tree(names); + } + error_t validate_tree(std::set& names) const; + bool enough_for_quorum(const std::set& names) const; + + std::vector get_sorted_children() const; + + private: + node_t() {} + void convert_node(coinbase::converter_t& c); +}; + +static std::string get_node_path(const std::string& parent_path, const node_t* node) { + if (!node->parent) return ""; + return parent_path + "/" + node->name; +} + +class ac_t { + public: + explicit ac_t() = default; + explicit ac_t(const node_t* _root) : root(_root) {} + ac_t(const node_t* _root, ecurve_t _curve) : root(_root), curve(_curve) {} + explicit ac_t(ecurve_t _curve) : curve(_curve) {} + + const node_t* get_root() const { return root; } + bool has_root() const { return root != nullptr; } + ecurve_t get_curve() const { return curve; } + bool has_curve() const { return curve.valid(); } + + error_t validate_tree() const { + if (!root) return coinbase::error(E_BADARG, "missing root"); + if (!curve.valid()) return coinbase::error(E_BADARG, "missing curve"); + return root->validate_tree(); + } + + const node_t* find(const pname_t& name) const { return root->find(name); } + std::set list_leaf_names() const { return root->list_leaf_names(); } + std::set list_pub_data_nodes() const; + int get_pub_data_size(const node_t* node) const { + if (node->type == node_e::AND) + return node->get_n(); + else if (node->type == node_e::THRESHOLD) + return node->threshold; + else + return 0; + } + + bool enough_for_quorum(const std::set names) const { return root ? root->enough_for_quorum(names) : false; } + template + bool enough_for_quorum(const party_map_t& map) const { + std::set names; + for (const auto& [name, value] : map) names.insert(name); + return root ? root->enough_for_quorum(names) : false; + } + + /** + * @specs: + * - basic-primitives-spec | ac-Share-1P + */ + ac_shares_t share(const mod_t& q, const bn_t& x, drbg_aes_ctr_t* drbg = nullptr) const; + error_t share_with_internals(const mod_t& q, const bn_t& x, ac_shares_t& shares, + ac_internal_shares_t& ac_internal_shares, + ac_internal_pub_shares_t& ac_internal_pub_shares, drbg_aes_ctr_t* drbg = nullptr) const; + error_t verify_share_against_ancestors_pub_data(const ecc_point_t& Q, const bn_t& si, + const ac_internal_pub_shares_t& pub_data, const pname_t& leaf) const; + + /** + * @specs: + * - basic-primitives-spec | ac-Reconstruct-1P + */ + error_t reconstruct(const mod_t& q, const ac_shares_t& shares, bn_t& x) const; + + /** + * @specs: + * - basic-primitives-spec | ac-Reconstruct-Exponent-1P + */ + error_t reconstruct_exponent(const ac_pub_shares_t& shares, ecc_point_t& P) const; + + const node_t* root = nullptr; + ecurve_t curve; +}; + +class ac_owned_t : public ac_t { + public: + ac_owned_t() = default; + explicit ac_owned_t(const node_t* _root) { assign(_root); } + explicit ac_owned_t(const node_t* _root, ecurve_t _curve) { assign(_root, _curve); } + explicit ac_owned_t(const ac_t& ac) { assign(ac); } + ~ac_owned_t() { delete root; } + void assign(const node_t* _root) { + delete root; + root = _root->clone(); + } + void assign(const node_t* _root, ecurve_t _curve) { + curve = _curve; + assign(_root); + } + void assign(const ac_t& ac) { + curve = ac.curve; + assign(ac.root); + } + ac_owned_t(const ac_owned_t& src) : ac_t() { assign(static_cast(src)); } + ac_owned_t(ac_owned_t&& src) : ac_t() { + root = src.root; + curve = src.curve; + src.root = nullptr; + src.curve = nullptr; + } + ac_owned_t& operator=(const ac_owned_t& src) { + if (&src != this) assign(static_cast(src)); + return *this; + } + ac_owned_t& operator=(ac_owned_t&& src) { + if (&src != this) { + delete root; + root = src.root; + curve = src.curve; + src.root = nullptr; + src.curve = nullptr; + } + return *this; + } + void convert(coinbase::converter_t& c); +}; + +} // namespace coinbase::crypto::ss \ No newline at end of file diff --git a/src/cbmpc/crypto/tdh2.h b/include-internal/cbmpc/internal/crypto/tdh2.h similarity index 82% rename from src/cbmpc/crypto/tdh2.h rename to include-internal/cbmpc/internal/crypto/tdh2.h index 46ede95d..ce2a7e5e 100644 --- a/src/cbmpc/crypto/tdh2.h +++ b/include-internal/cbmpc/internal/crypto/tdh2.h @@ -1,8 +1,8 @@ #pragma once -#include -#include -#include +#include +#include +#include namespace coinbase::crypto::tdh2 { @@ -18,7 +18,7 @@ struct ciphertext_t { bn_t e, f; buf_t L; - void convert(coinbase::converter_t& converter) { converter.convert(c, R1, R2, e, f, iv); } + void convert(coinbase::converter_t& converter) { converter.convert(c, R1, R2, e, f, iv, L); } /** * @specs: @@ -38,9 +38,12 @@ struct ciphertext_t { struct public_key_t { ecc_point_t Q, Gamma; + buf_t sid; public_key_t() {} - public_key_t(const ecc_point_t& _Q) : Q(_Q) { Gamma = ro::hash_curve(mem_t("TDH2-Gamma"), Q).curve(Q.get_curve()); } + public_key_t(const ecc_point_t& _Q, mem_t _sid) : Q(_Q), sid(_sid) { + Gamma = ro::hash_curve(mem_t("TDH2-Gamma"), Q, sid).curve(Q.get_curve()); + } /** * @specs: @@ -57,11 +60,11 @@ struct public_key_t { ciphertext_t encrypt(mem_t plain, mem_t label, const bn_t& r, const bn_t& s, mem_t iv) const; bool valid() const { return Q.valid(); } - void convert(coinbase::converter_t& converter) { converter.convert(Q, Gamma); } + void convert(coinbase::converter_t& converter) { converter.convert(Q, Gamma, sid); } buf_t to_bin() const { return coinbase::convert(*this); } error_t from_bin(mem_t bin) { return coinbase::convert(*this, bin); } - bool operator==(const public_key_t& other) const { return Q == other.Q && Gamma == other.Gamma; } - bool operator!=(const public_key_t& other) const { return Q != other.Q || Gamma != other.Gamma; } + bool operator==(const public_key_t& other) const { return Q == other.Q && Gamma == other.Gamma && sid == other.sid; } + bool operator!=(const public_key_t& other) const { return Q != other.Q || Gamma != other.Gamma || sid != other.sid; } }; struct private_key_t { @@ -74,11 +77,12 @@ struct private_key_t { }; struct partial_decryption_t { - int pid; + // Role/index id: 1..n, aligned with the public share ordering (Qi[0] corresponds to rid=1). + int rid = 0; ecc_point_t Xi; bn_t ei, fi; - void convert(coinbase::converter_t& converter) { converter.convert(pid, Xi, ei, fi); } + void convert(coinbase::converter_t& converter) { converter.convert(rid, Xi, ei, fi); } /** * @specs: @@ -92,7 +96,8 @@ struct partial_decryption_t { struct private_share_t { public_key_t pub_key; bn_t x; - int pid = 0; + // Role/index id: 1..n, aligned with the public share ordering (Qi[0] corresponds to rid=1). + int rid = 0; /** * @specs: diff --git a/src/cbmpc/protocol/agree_random.h b/include-internal/cbmpc/internal/protocol/agree_random.h similarity index 96% rename from src/cbmpc/protocol/agree_random.h rename to include-internal/cbmpc/internal/protocol/agree_random.h index e8c054e4..89ffbda8 100644 --- a/src/cbmpc/protocol/agree_random.h +++ b/include-internal/cbmpc/internal/protocol/agree_random.h @@ -1,5 +1,5 @@ #pragma once -#include +#include namespace coinbase::mpc { diff --git a/src/cbmpc/protocol/committed_broadcast.h b/include-internal/cbmpc/internal/protocol/committed_broadcast.h similarity index 96% rename from src/cbmpc/protocol/committed_broadcast.h rename to include-internal/cbmpc/internal/protocol/committed_broadcast.h index 587e6009..37b455a6 100644 --- a/src/cbmpc/protocol/committed_broadcast.h +++ b/include-internal/cbmpc/internal/protocol/committed_broadcast.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include namespace coinbase::mpc { diff --git a/src/cbmpc/protocol/ec_dkg.h b/include-internal/cbmpc/internal/protocol/ec_dkg.h similarity index 73% rename from src/cbmpc/protocol/ec_dkg.h rename to include-internal/cbmpc/internal/protocol/ec_dkg.h index 9313bed1..25751888 100644 --- a/src/cbmpc/protocol/ec_dkg.h +++ b/include-internal/cbmpc/internal/protocol/ec_dkg.h @@ -2,14 +2,13 @@ #include -#include -#include -#include -#include -#include -#include - -#include "util.h" +#include +#include +#include +#include +#include +#include +#include namespace coinbase::mpc::eckey { struct dkg_2p_t { @@ -88,29 +87,29 @@ struct key_share_mp_t { * @specs: * - ec-dkg-spec | EC-DKG-Threshold-MP * @notes: - * - This threshold DKG is not optimal in the sense that all n parties need to be connected + * - This access-structure DKG is not optimal in the sense that all n parties need to be connected * throughout, even though only t are active. In practice (and how we work in reality), it makes more sense for the t * parties to run the protocol, and then have the rest separately download the output message. This requires * additional infrastructure beyond what is in the scope of this library (like a PKI for the t parties to * encrypt-and-sign the output messages for the n parties), and therefore we implement this simpler demo DKG here. - * In the future, we are planning on adding a VSS implementation that will make it easier to implement a threshold DKG - * with only a subset of the parties online. + * In the future, we may add primitives that make it easier to implement a DKG where only a subset of the + * parties need to be online. */ - static error_t threshold_dkg(job_mp_t& job, const ecurve_t& curve, buf_t& sid, const crypto::ss::ac_t, - const party_set_t& quorum_party_set, key_share_mp_t& key); + static error_t dkg_ac(job_mp_t& job, const ecurve_t& curve, buf_t& sid, const crypto::ss::ac_t, + const party_set_t& quorum_party_set, key_share_mp_t& key); /** * @specs: * - ec-dkg-spec | EC-Refresh-Threshold-MP * @notes: * - See `dkg` for notes. */ - static error_t threshold_refresh(job_mp_t& job, const ecurve_t& curve, buf_t& sid, const crypto::ss::ac_t, - const party_set_t& quorum_party_set, key_share_mp_t& key, key_share_mp_t& new_key); + static error_t refresh_ac(job_mp_t& job, const ecurve_t& curve, buf_t& sid, const crypto::ss::ac_t, + const party_set_t& quorum_party_set, key_share_mp_t& key, key_share_mp_t& new_key); private: - static error_t threshold_dkg_or_refresh(job_mp_t& job, const ecurve_t& curve, buf_t& sid, const crypto::ss::ac_t, - const party_set_t& quorum_party_set, key_share_mp_t& key, - key_share_mp_t& new_key, bool is_refresh); + static error_t dkg_or_refresh_ac(job_mp_t& job, const ecurve_t& curve, buf_t& sid, const crypto::ss::ac_t, + const party_set_t& quorum_party_set, key_share_mp_t& key, key_share_mp_t& new_key, + bool is_refresh); }; } // namespace coinbase::mpc::eckey \ No newline at end of file diff --git a/src/cbmpc/protocol/ecdsa_2p.h b/include-internal/cbmpc/internal/protocol/ecdsa_2p.h similarity index 96% rename from src/cbmpc/protocol/ecdsa_2p.h rename to include-internal/cbmpc/internal/protocol/ecdsa_2p.h index 2e66809b..ce3a6425 100644 --- a/src/cbmpc/protocol/ecdsa_2p.h +++ b/include-internal/cbmpc/internal/protocol/ecdsa_2p.h @@ -2,10 +2,10 @@ #include -#include -#include -#include -#include +#include +#include +#include +#include namespace coinbase::mpc::ecdsa2pc { diff --git a/src/cbmpc/protocol/ecdsa_mp.h b/include-internal/cbmpc/internal/protocol/ecdsa_mp.h similarity index 89% rename from src/cbmpc/protocol/ecdsa_mp.h rename to include-internal/cbmpc/internal/protocol/ecdsa_mp.h index 0ac98871..1ab44323 100644 --- a/src/cbmpc/protocol/ecdsa_mp.h +++ b/include-internal/cbmpc/internal/protocol/ecdsa_mp.h @@ -2,9 +2,9 @@ #include -#include -#include -#include +#include +#include +#include namespace coinbase::mpc::ecdsampc { @@ -35,15 +35,15 @@ error_t refresh(job_mp_t& job, buf_t& sid, key_t& key, key_t& new_key); * @specs: * - ec-dkg-spec | EC-DKG-Threshold-MP */ -error_t threshold_dkg(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, - const party_set_t& quorum_party_set, key_t& key); +error_t dkg_ac(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, + const party_set_t& quorum_party_set, key_t& key); /** * @specs: * - ec-dkg-spec | EC-Refresh-Threshold-MP */ -error_t threshold_refresh(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, - const party_set_t& quorum_party_set, key_t& key, key_t& new_key); +error_t refresh_ac(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, + const party_set_t& quorum_party_set, key_t& key, key_t& new_key); /** * @specs: diff --git a/src/cbmpc/protocol/eddsa.h b/include-internal/cbmpc/internal/protocol/eddsa.h similarity index 82% rename from src/cbmpc/protocol/eddsa.h rename to include-internal/cbmpc/internal/protocol/eddsa.h index d02537c1..012f4fa5 100644 --- a/src/cbmpc/protocol/eddsa.h +++ b/include-internal/cbmpc/internal/protocol/eddsa.h @@ -1,8 +1,8 @@ #pragma once -#include -#include -#include +#include +#include +#include namespace coinbase::mpc::eddsa2pc { typedef schnorr2p::key_t key_t; diff --git a/src/cbmpc/protocol/hd_keyset_ecdsa_2p.h b/include-internal/cbmpc/internal/protocol/hd_keyset_ecdsa_2p.h similarity index 87% rename from src/cbmpc/protocol/hd_keyset_ecdsa_2p.h rename to include-internal/cbmpc/internal/protocol/hd_keyset_ecdsa_2p.h index 5c81a2b7..81c642f1 100644 --- a/src/cbmpc/protocol/hd_keyset_ecdsa_2p.h +++ b/include-internal/cbmpc/internal/protocol/hd_keyset_ecdsa_2p.h @@ -1,8 +1,8 @@ #pragma once -#include -#include -#include -#include +#include +#include +#include +#include namespace coinbase::mpc { diff --git a/src/cbmpc/protocol/hd_keyset_eddsa_2p.h b/include-internal/cbmpc/internal/protocol/hd_keyset_eddsa_2p.h similarity index 86% rename from src/cbmpc/protocol/hd_keyset_eddsa_2p.h rename to include-internal/cbmpc/internal/protocol/hd_keyset_eddsa_2p.h index e8598626..5f9345ce 100644 --- a/src/cbmpc/protocol/hd_keyset_eddsa_2p.h +++ b/include-internal/cbmpc/internal/protocol/hd_keyset_eddsa_2p.h @@ -1,8 +1,8 @@ #pragma once -#include -#include -#include -#include +#include +#include +#include +#include namespace coinbase::mpc { diff --git a/src/cbmpc/protocol/hd_tree_bip32.h b/include-internal/cbmpc/internal/protocol/hd_tree_bip32.h similarity index 97% rename from src/cbmpc/protocol/hd_tree_bip32.h rename to include-internal/cbmpc/internal/protocol/hd_tree_bip32.h index 915c6568..e7e87aaa 100644 --- a/src/cbmpc/protocol/hd_tree_bip32.h +++ b/include-internal/cbmpc/internal/protocol/hd_tree_bip32.h @@ -1,5 +1,5 @@ #pragma once -#include +#include namespace coinbase::mpc { diff --git a/src/cbmpc/protocol/int_commitment.h b/include-internal/cbmpc/internal/protocol/int_commitment.h similarity index 86% rename from src/cbmpc/protocol/int_commitment.h rename to include-internal/cbmpc/internal/protocol/int_commitment.h index 2b97cdd0..3920331f 100644 --- a/src/cbmpc/protocol/int_commitment.h +++ b/include-internal/cbmpc/internal/protocol/int_commitment.h @@ -1,8 +1,8 @@ #pragma once -#include -#include -#include +#include +#include +#include namespace coinbase::crypto { diff --git a/src/cbmpc/protocol/mpc_job.h b/include-internal/cbmpc/internal/protocol/mpc_job.h similarity index 88% rename from src/cbmpc/protocol/mpc_job.h rename to include-internal/cbmpc/internal/protocol/mpc_job.h index b281fc4e..6536ef7a 100644 --- a/src/cbmpc/protocol/mpc_job.h +++ b/include-internal/cbmpc/internal/protocol/mpc_job.h @@ -2,19 +2,17 @@ #include -#include -#include -#include -#include - -#include "data_transport.h" -#include "util.h" +#include +#include +#include +#include +#include +#include namespace coinbase::mpc { -typedef int32_t party_idx_t; - -enum class party_t : party_idx_t { p1 = 0, p2 = 1 }; +using party_idx_t = coinbase::api::party_idx_t; +using party_t = coinbase::api::party_2p_t; class party_set_t { public: @@ -167,11 +165,20 @@ class job_mp_t { int n_parties; std::vector pids; std::vector names; - std::shared_ptr transport_ptr; + // Transport lifetime: + // - If the job is constructed with a `shared_ptr`, it keeps the transport alive. + // - If the job is constructed with a `data_transport_i&`, it stores a non-owning + // pointer; the caller must ensure the transport outlives protocol execution. + std::shared_ptr transport_ptr; + coinbase::api::data_transport_i* transport_raw = nullptr; public: - job_mp_t(int index, std::vector pnames, std::shared_ptr tptr = nullptr) - : party_index(index), n_parties(pnames.size()), transport_ptr(tptr) { + job_mp_t(int index, std::vector pnames, + std::shared_ptr tptr = nullptr) + : party_index(index), + n_parties(pnames.size()), + transport_ptr(std::move(tptr)), + transport_raw(transport_ptr.get()) { if (party_index < 0 || party_index >= n_parties) coinbase::error(E_BADARG, "invalid party_index"); this->names = pnames; for (const auto& name : pnames) { @@ -181,19 +188,31 @@ class job_mp_t { cb_assert(pids.size() <= 64 && "at most 64 parties are supported"); } + job_mp_t(int index, std::vector pnames, coinbase::api::data_transport_i& transport) + : job_mp_t(index, std::move(pnames), /*tptr=*/nullptr) { + transport_raw = &transport; + } + virtual error_t send_impl(party_idx_t to, mem_t msg) { - if (!transport_ptr) return E_NET_GENERAL; - return transport_ptr->send(to, msg); + if (!transport_raw) return E_NET_GENERAL; + return transport_raw->send(to, msg); } virtual error_t receive_impl(party_idx_t from, buf_t& msg) { - if (!transport_ptr) return E_NET_GENERAL; - return transport_ptr->receive(from, msg); + if (!transport_raw) return E_NET_GENERAL; + return transport_raw->receive(from, msg); } virtual error_t receive_many_impl(std::vector from_set, std::vector& outs); - void set_transport(party_idx_t idx, std::shared_ptr ptr) { + void set_transport(party_idx_t idx, std::shared_ptr ptr) { party_index = idx; - transport_ptr = ptr; + transport_ptr = std::move(ptr); + transport_raw = transport_ptr.get(); + } + + void set_transport(party_idx_t idx, coinbase::api::data_transport_i& transport) { + party_index = idx; + transport_ptr.reset(); + transport_raw = &transport; } /* MPC Properties */ @@ -261,7 +280,7 @@ class job_mp_t { if (!from_set.is_empty()) { std::vector receive; if (rv = receive_from_parties(from_set, receive)) return rv; - unpack_msgs(from_set, receive, msgs...); + if (rv = unpack_msgs(from_set, receive, msgs...)) return rv; } return SUCCESS; @@ -444,9 +463,12 @@ class job_mp_t { class job_2p_t : public job_mp_t { public: job_2p_t(party_t index, crypto::pname_t pname1, crypto::pname_t pname2, - std::shared_ptr tptr = nullptr) + std::shared_ptr tptr = nullptr) : job_mp_t(party_idx_t(index), {pname1, pname2}, tptr) {} + job_2p_t(party_t index, crypto::pname_t pname1, crypto::pname_t pname2, coinbase::api::data_transport_i& transport) + : job_mp_t(party_idx_t(index), {pname1, pname2}, transport) {} + bool is_p1() const { return is_party_idx(party_idx_t(party_t::p1)); } bool is_p2() const { return is_party_idx(party_idx_t(party_t::p2)); } bool is_party(party_t party) const { return is_party_idx(party_idx_t(party)); } diff --git a/src/cbmpc/protocol/ot.h b/include-internal/cbmpc/internal/protocol/ot.h similarity index 97% rename from src/cbmpc/protocol/ot.h rename to include-internal/cbmpc/internal/protocol/ot.h index ccf8ebee..1f02e6ce 100644 --- a/src/cbmpc/protocol/ot.h +++ b/include-internal/cbmpc/internal/protocol/ot.h @@ -1,7 +1,7 @@ #pragma once -#include -#include -#include +#include +#include +#include namespace coinbase::mpc { @@ -55,7 +55,7 @@ class h_matrix_256rows_t { int rows() const { return 256; } void set_row(int index, mem_t value) { cb_assert(value.size == row_size_in_bytes()); - memmove(get_row(index).data, value.data, value.size); + memmove(const_cast(get_row(index).data), value.data, value.size); } mem_t get_row(int index) const { return mem_t(buf.data() + row_size_in_bytes() * index, row_size_in_bytes()); } diff --git a/include-internal/cbmpc/internal/protocol/pve.h b/include-internal/cbmpc/internal/protocol/pve.h new file mode 100644 index 00000000..b9a82199 --- /dev/null +++ b/include-internal/cbmpc/internal/protocol/pve.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include + +namespace coinbase::mpc { + +class ec_pve_t { + public: + ec_pve_t() = default; + + const static int kappa = SEC_P_COM; + const static int rho_size = 32; + + error_t encrypt(const pve_base_pke_i& base_pke, pve_keyref_t ek, mem_t label, ecurve_t curve, const bn_t& x); + error_t verify(const pve_base_pke_i& base_pke, pve_keyref_t ek, const ecc_point_t& Q, mem_t label) const; + error_t decrypt(const pve_base_pke_i& base_pke, pve_keyref_t dk, pve_keyref_t ek, mem_t label, ecurve_t curve, + bn_t& x, bool skip_verify = false) const; + + const ecc_point_t& get_Q() const { return Q; } + const buf_t& get_Label() const { return L; } + + void convert(coinbase::converter_t& converter) { + converter.convert(Q, L, b); + for (int i = 0; i < kappa; i++) { + converter.convert(x_rows[i]); + converter.convert(r[i]); + converter.convert(c[i]); + } + } + + private: + buf_t L; + ecc_point_t Q; + buf128_t b; + + bn_t x_rows[kappa]; + buf128_t r[kappa]; + buf_t c[kappa]; + + error_t restore_from_decrypted(int row_index, mem_t decrypted_x_buf, ecurve_t curve, bn_t& x_value) const; +}; + +} // namespace coinbase::mpc diff --git a/include-internal/cbmpc/internal/protocol/pve_ac.h b/include-internal/cbmpc/internal/protocol/pve_ac.h new file mode 100644 index 00000000..532ba942 --- /dev/null +++ b/include-internal/cbmpc/internal/protocol/pve_ac.h @@ -0,0 +1,101 @@ +#pragma once + +#include +#include +#include + +namespace coinbase::mpc { + +class ec_pve_ac_t { + public: + struct ciphertext_adapter_t { + buf_t ct_ser; + void convert(coinbase::converter_t& converter) { converter.convert(ct_ser); } + }; + + typedef std::map pks_t; // maps leaf path -> encryption key reference + typedef std::map sks_t; // maps leaf path -> decryption key reference + + static constexpr int kappa = SEC_P_COM; + static constexpr std::size_t iv_size = crypto::KEM_AEAD_IV_SIZE; + static constexpr std::size_t tag_size = crypto::KEM_AEAD_TAG_SIZE; + static constexpr std::size_t iv_bitlen = iv_size * 8; + + ec_pve_ac_t() : rows(kappa) {} + + void convert(coinbase::converter_t& converter) { + converter.convert(Q, L, b); + + for (int i = 0; i < kappa; i++) { + converter.convert(rows[i].x_bin); + converter.convert(rows[i].r); + converter.convert(rows[i].c); + converter.convert(rows[i].quorum_c); + } + } + + /** + * @specs: + * - publicly-verifiable-encryption-spec | vencrypt-batch-many-1P + */ + error_t encrypt(const pve_base_pke_i& base_pke, const crypto::ss::ac_t& ac, const pks_t& ac_pks, mem_t label, + ecurve_t curve, const std::vector& x); + + /** + * @specs: + * - publicly-verifiable-encryption-spec | vverify-batch-many-1P + */ + error_t verify(const pve_base_pke_i& base_pke, const crypto::ss::ac_t& ac, const pks_t& ac_pks, + const std::vector& Q, mem_t label) const; + + /** + * @specs: + * - publicly-verifiable-encryption-spec | vdecrypt-local-batch-many-1P + * + * @notes: + * Each party calls party_decrypt_row to produce its share for a specific row. + * Then, the caller aggregates shares using aggregate_to_restore_row to recover x. + * This is different from the spec since the decryption is not done in a loop, rather at each + * invocation, a single row is decrypted. As a result, it is the responsibility of the caller application + * to call this api multiple times if needed. + */ + error_t party_decrypt_row(const pve_base_pke_i& base_pke, const crypto::ss::ac_t& ac, int row_index, + const std::string& path, pve_keyref_t prv_key, mem_t label, bn_t& out_share) const; + + /** + * @specs: + * - publicly-verifiable-encryption-spec | vdecrypt-combine-batch-many-1P + */ + error_t aggregate_to_restore_row(const pve_base_pke_i& base_pke, const crypto::ss::ac_t& ac, int row_index, + mem_t label, const std::map& quorum_decrypted, + std::vector& x, bool skip_verify = false, + const pks_t& all_ac_pks = pks_t()) const; + const std::vector& get_Q() const { return Q; } + + private: + std::vector Q; + buf_t L; + buf128_t b; + struct row_t { + buf_t x_bin, r, c; + std::vector quorum_c; + }; + std::vector rows; + + error_t encrypt_row(const pve_base_pke_i& base_pke, const crypto::ss::ac_t& ac, const pks_t& ac_pks, mem_t label, + ecurve_t curve, mem_t seed, mem_t plain, buf_t& c, + std::vector& quorum_c) const; + + error_t encrypt_row0(const pve_base_pke_i& base_pke, const crypto::ss::ac_t& ac, const pks_t& ac_pks, mem_t label, + ecurve_t curve, mem_t r0_1, mem_t r0_2, int batch_size, std::vector& x0, buf_t& c, + std::vector& quorum_c) const; + + error_t encrypt_row1(const pve_base_pke_i& base_pke, const crypto::ss::ac_t& ac, const pks_t& ac_pks, mem_t label, + ecurve_t curve, mem_t r1, mem_t x1_bin, buf_t& c, + std::vector& quorum_c) const; + + static error_t find_quorum_ciphertext(const std::vector& sorted_leaves, const std::string& path, + const row_t& row, const ciphertext_adapter_t*& c); +}; + +} // namespace coinbase::mpc diff --git a/include-internal/cbmpc/internal/protocol/pve_base.h b/include-internal/cbmpc/internal/protocol/pve_base.h new file mode 100644 index 00000000..3f0c5d87 --- /dev/null +++ b/include-internal/cbmpc/internal/protocol/pve_base.h @@ -0,0 +1,258 @@ +#pragma once + +#include +#include + +namespace coinbase::mpc { + +namespace detail { +// A lightweight type tag for runtime-checked key type erasure. +// +// We intentionally avoid RTTI (`typeid`) so this works even if RTTI is disabled. +template +inline const void* pve_type_tag() noexcept { + static const int kTag = 0; + return &kTag; +} +} // namespace detail + +struct pve_keyref_t { + const void* ptr = nullptr; + const void* tag = nullptr; + + template + static pve_keyref_t from(const T& v) noexcept { + return pve_keyref_t{&v, detail::pve_type_tag()}; + } + + template + static pve_keyref_t from_ptr(const T* p) noexcept { + return pve_keyref_t{p, detail::pve_type_tag()}; + } + + template + const T* get() const noexcept { + if (tag != detail::pve_type_tag()) return nullptr; + return static_cast(ptr); + } +}; + +template +inline pve_keyref_t pve_keyref(const T& v) noexcept { + return pve_keyref_t::from(v); +} + +template +inline pve_keyref_t pve_keyref(const T* p) noexcept { + return pve_keyref_t::from_ptr(p); +} + +struct pve_base_pke_i { + virtual ~pve_base_pke_i() = default; + virtual error_t encrypt(pve_keyref_t ek, mem_t label, mem_t plain, mem_t rho, buf_t& out_ct) const = 0; + virtual error_t decrypt(pve_keyref_t dk, mem_t label, mem_t ct, buf_t& out_plain) const = 0; +}; + +// Generic adapter that turns any KEM policy into a PVE base PKE via kem_aead_ciphertext_t +template +struct kem_pve_base_pke_t : public pve_base_pke_i { + using EK = typename KEM_POLICY::ek_t; + using DK = typename KEM_POLICY::dk_t; + using CT = crypto::kem_aead_ciphertext_t; + + error_t encrypt(pve_keyref_t ek, mem_t label, mem_t plain, mem_t rho, buf_t& out_ct) const override { + const EK* pub_key = ek.get(); + if (!pub_key) return coinbase::error(E_BADARG, "invalid encryption key"); + crypto::drbg_aes_ctr_t drbg(rho); + CT ct; + error_t rv = ct.encrypt(*pub_key, label, plain, &drbg); + if (rv) return rv; + out_ct = ser(ct); + return SUCCESS; + } + + error_t decrypt(pve_keyref_t dk, mem_t label, mem_t ct_ser, buf_t& out_plain) const override { + const DK* prv_key = dk.get(); + if (!prv_key) return coinbase::error(E_BADARG, "invalid decryption key"); + error_t rv = UNINITIALIZED_ERROR; + CT ct; + if (rv = deser(ct_ser, ct)) return rv; + return ct.decrypt(*prv_key, label, out_plain); + } +}; + +template +inline const pve_base_pke_i& kem_pve_base_pke() { + static const kem_pve_base_pke_t pke; + return pke; +} + +// Accessors to built-in base PKE implementations for testing and convenience. +// (RSA-OAEP(2048) and ECIES(P-256), both implemented via kem_aead_ciphertext_t.) +const pve_base_pke_i& pve_base_pke_rsa(); +const pve_base_pke_i& pve_base_pke_ecies(); + +// --------------------------------------------------------------------------- +// Runtime / callback-based KEM adapters (for HSM / FFI / wrappers) +// --------------------------------------------------------------------------- + +// Generic runtime KEM callbacks. Intended for wrappers that want to provide a +// custom KEM but still reuse cbmpc's KEM/DEM transform (kem_aead_ciphertext_t). +// +// Requirements: +// - `encap` MUST be deterministic given `rho32`. +// - `decap` MUST return the same shared secret produced by `encap`. +struct pve_runtime_kem_callbacks_t { + using encap_fn_t = error_t (*)(void* ctx, mem_t ek_bytes, mem_t rho32, buf_t& kem_ct, buf_t& kem_ss); + using decap_fn_t = error_t (*)(void* ctx, const void* dk_handle, mem_t kem_ct, buf_t& kem_ss); + + void* ctx = nullptr; + encap_fn_t encap = nullptr; + decap_fn_t decap = nullptr; +}; + +struct pve_runtime_kem_ek_t { + mem_t ek_bytes; + const pve_runtime_kem_callbacks_t* callbacks = nullptr; +}; + +struct pve_runtime_kem_dk_t { + const void* dk_handle = nullptr; + const pve_runtime_kem_callbacks_t* callbacks = nullptr; +}; + +struct kem_policy_runtime_kem_t { + using ek_t = pve_runtime_kem_ek_t; + using dk_t = pve_runtime_kem_dk_t; + + static error_t encapsulate(const ek_t& pub_key, buf_t& kem_ct, buf_t& kem_ss, crypto::drbg_aes_ctr_t* drbg) { + if (!pub_key.callbacks || !pub_key.callbacks->encap) return E_BADARG; + constexpr int rho_size = 32; + buf_t rho = drbg ? drbg->gen(rho_size) : crypto::gen_random(rho_size); + return pub_key.callbacks->encap(pub_key.callbacks->ctx, pub_key.ek_bytes, rho, kem_ct, kem_ss); + } + + static error_t decapsulate(const dk_t& prv_key, mem_t kem_ct, buf_t& kem_ss) { + if (!prv_key.callbacks || !prv_key.callbacks->decap) return E_BADARG; + return prv_key.callbacks->decap(prv_key.callbacks->ctx, prv_key.dk_handle, kem_ct, kem_ss); + } +}; + +inline const pve_base_pke_i& pve_base_pke_runtime_kem() { return kem_pve_base_pke(); } + +// HSM-backed decapsulation for the built-in RSA-OAEP KEM. +// +// The callback must perform RSA-OAEP decryption (OAEP label is empty per our KEM policy) +// and return the recovered KEM shared secret. +struct pve_rsa_oaep_hsm_dk_t { + using decap_fn_t = error_t (*)(void* ctx, mem_t dk_handle, mem_t kem_ct, buf_t& kem_ss); + + mem_t dk_handle; + void* ctx = nullptr; + decap_fn_t decap = nullptr; +}; + +struct kem_policy_rsa_oaep_hsm_t { + using ek_t = crypto::rsa_pub_key_t; + using dk_t = pve_rsa_oaep_hsm_dk_t; + + static error_t encapsulate(const ek_t& pub_key, buf_t& kem_ct, buf_t& kem_ss, crypto::drbg_aes_ctr_t* drbg) { + return crypto::kem_policy_rsa_oaep_t::encapsulate(pub_key, kem_ct, kem_ss, drbg); + } + + static error_t decapsulate(const dk_t& prv_key, mem_t kem_ct, buf_t& kem_ss) { + if (!prv_key.decap) return E_BADARG; + error_t rv = prv_key.decap(prv_key.ctx, prv_key.dk_handle, kem_ct, kem_ss); + if (rv) return rv; + + // Our RSA-OAEP KEM policy uses a 32-byte shared secret (SHA-256 output size). + const int expected_ss_size = crypto::hash_alg_t::get(crypto::hash_e::sha256).size; + if (kem_ss.size() != expected_ss_size) return coinbase::error(E_CRYPTO, "invalid RSA KEM output size"); + + return SUCCESS; + } +}; + +inline const pve_base_pke_i& pve_base_pke_rsa_oaep_hsm() { return kem_pve_base_pke(); } + +// HSM-backed decapsulation for the built-in ECIES(P-256) KEM. +// +// The callback only needs to perform the ECDH step and return the raw affine-X +// coordinate as a 32-byte big-endian buffer. The library derives the final KEM +// shared secret per RFC 9180 (DHKEM(P-256, HKDF-SHA256)). +struct pve_ecies_p256_hsm_dk_t { + using ecdh_fn_t = error_t (*)(void* ctx, mem_t dk_handle, mem_t kem_ct, buf_t& dh_x32); + + mem_t dk_handle; + void* ctx = nullptr; + ecdh_fn_t ecdh = nullptr; + + // Uncompressed public key octets (for kem_context = enc || pub_key). + buf_t pub_key_oct; +}; + +struct kem_policy_ecdh_p256_hsm_t { + using ek_t = crypto::ecc_pub_key_t; + using dk_t = pve_ecies_p256_hsm_dk_t; + + static error_t encapsulate(const ek_t& pub_key, buf_t& kem_ct, buf_t& kem_ss, crypto::drbg_aes_ctr_t* drbg) { + return crypto::kem_policy_ecdh_p256_t::encapsulate(pub_key, kem_ct, kem_ss, drbg); + } + + static error_t decapsulate(const dk_t& prv_key, mem_t kem_ct, buf_t& kem_ss) { + error_t rv = UNINITIALIZED_ERROR; + if (!prv_key.ecdh) return E_BADARG; + if (prv_key.pub_key_oct.empty()) return coinbase::error(E_BADARG, "missing ECIES public key"); + + crypto::ecc_point_t E; + if (rv = E.from_oct(crypto::curve_p256, kem_ct)) return rv; + if (rv = crypto::curve_p256.check(E)) return rv; + + buf_t dh; + if (rv = prv_key.ecdh(prv_key.ctx, prv_key.dk_handle, kem_ct, dh)) return rv; + if (dh.size() != 32) return coinbase::error(E_CRYPTO, "invalid ECDH output size"); + + // kem_context = enc || pub_key + buf_t kem_context; + kem_context += kem_ct; + kem_context += prv_key.pub_key_oct; + + buf_t eae_prk = crypto::kem_policy_ecdh_p256_t::labeled_extract(mem_t("eae_prk"), dh, mem_t()); + kem_ss = crypto::kem_policy_ecdh_p256_t::labeled_expand(eae_prk, mem_t("shared_secret"), kem_context, 32); + return SUCCESS; + } +}; + +inline const pve_base_pke_i& pve_base_pke_ecies_p256_hsm() { return kem_pve_base_pke(); } + +/** + * @notes: + * - This is the underlying encryption used in PVE + */ +template +buf_t pve_base_encrypt(const typename HPKE_T::ek_t& pub_key, mem_t label, const buf_t& plaintext, mem_t rho) { + crypto::drbg_aes_ctr_t drbg(rho); + typename HPKE_T::ct_t ct; + ct.encrypt(pub_key, label, plaintext, &drbg); + return ser(ct); +} + +/** + * @notes: + * - This is the underlying decryption used in PVE + */ +template +error_t pve_base_decrypt(const typename HPKE_T::dk_t& prv_key, mem_t label, mem_t ciphertext, buf_t& plain) { + error_t rv = UNINITIALIZED_ERROR; + typename HPKE_T::ct_t ct; + if (rv = deser(ciphertext, ct)) return rv; + if (rv = ct.decrypt(prv_key, label, plain)) return rv; + return SUCCESS; +} + +template +static buf_t genPVELabelWithPoint(mem_t label, const T& Q) { + return buf_t(label) + "-" + strext::to_hex(crypto::sha256_t::hash(Q)); +} + +} // namespace coinbase::mpc diff --git a/include-internal/cbmpc/internal/protocol/pve_batch.h b/include-internal/cbmpc/internal/protocol/pve_batch.h new file mode 100644 index 00000000..a86961a8 --- /dev/null +++ b/include-internal/cbmpc/internal/protocol/pve_batch.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include + +namespace coinbase::mpc { + +class ec_pve_batch_t { + public: + explicit ec_pve_batch_t(int batch_count) : n(batch_count), rows(kappa) { + cb_assert(batch_count > 0 && batch_count <= MAX_BATCH_COUNT); + Q.resize(n); + } + + const static int kappa = SEC_P_COM; + // Upper bound to prevent integer-overflow and unbounded memory allocation when `n` is untrusted. + // This is a defensive limit; callers should treat any larger batch as invalid input. + static constexpr int MAX_BATCH_COUNT = 100000; + // We assume the base encryption scheme requires 32 bytes of randomness. If it needs more, it can be changed to use + // DRBG with 32 bytes of randomness as the seed. + const static int rho_size = 32; + + /** + * @specs: + * - publicly-verifiable-encryption-spec | vencrypt-batch-1P + */ + error_t encrypt(const pve_base_pke_i& base_pke, pve_keyref_t ek, mem_t label, ecurve_t curve, + const std::vector& x); + + /** + * @specs: + * - publicly-verifiable-encryption-spec | vverify-batch-1P + */ + error_t verify(const pve_base_pke_i& base_pke, pve_keyref_t ek, const std::vector& Q, mem_t label) const; + + /** + * @specs: + * - publicly-verifiable-encryption-spec | vdecrypt-batch-1P + */ + error_t decrypt(const pve_base_pke_i& base_pke, pve_keyref_t dk, pve_keyref_t ek, mem_t label, ecurve_t curve, + std::vector& x, bool skip_verify = false) const; + + int batch_count() const { return n; } + const std::vector& get_Qs() const { return Q; } + const buf_t& get_Label() const { return L; } + + void convert(coinbase::converter_t& converter) { + if (int(Q.size()) != n) { + converter.set_error(); + return; + } + + converter.convert(Q, L, b); + + for (int i = 0; i < kappa; i++) { + converter.convert(rows[i].x_bin); + converter.convert(rows[i].r); + converter.convert(rows[i].c); + } + } + + private: + int n; + + buf_t L; + std::vector Q; + buf128_t b; + + struct row_t { + buf_t x_bin; + buf_t r; + buf_t c; + }; + std::vector rows; + + error_t restore_from_decrypted(int row_index, mem_t decrypted_x_buf, ecurve_t curve, std::vector& xs) const; +}; + +} // namespace coinbase::mpc \ No newline at end of file diff --git a/src/cbmpc/protocol/schnorr_2p.h b/include-internal/cbmpc/internal/protocol/schnorr_2p.h similarity index 71% rename from src/cbmpc/protocol/schnorr_2p.h rename to include-internal/cbmpc/internal/protocol/schnorr_2p.h index 8bef0079..2e4ca304 100644 --- a/src/cbmpc/protocol/schnorr_2p.h +++ b/include-internal/cbmpc/internal/protocol/schnorr_2p.h @@ -1,10 +1,10 @@ #pragma once -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include namespace coinbase::mpc::schnorr2p { diff --git a/src/cbmpc/protocol/schnorr_mp.h b/include-internal/cbmpc/internal/protocol/schnorr_mp.h similarity index 68% rename from src/cbmpc/protocol/schnorr_mp.h rename to include-internal/cbmpc/internal/protocol/schnorr_mp.h index bdf02d44..47f15e93 100644 --- a/src/cbmpc/protocol/schnorr_mp.h +++ b/include-internal/cbmpc/internal/protocol/schnorr_mp.h @@ -2,9 +2,9 @@ #include -#include -#include -#include +#include +#include +#include namespace coinbase::mpc::schnorrmp { @@ -31,15 +31,15 @@ error_t refresh(job_mp_t& job, buf_t& sid, key_t& key, key_t& new_key); * @specs: * - ec-dkg-spec | EC-DKG-Threshold-MP */ -error_t threshold_dkg(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, - const party_set_t& quorum_party_set, key_t& key); +error_t dkg_ac(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, + const party_set_t& quorum_party_set, key_t& key); /** * @specs: * - ec-dkg-spec | EC-Refresh-Threshold-MP */ -error_t threshold_refresh(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, - const party_set_t& quorum_party_set, key_t& key, key_t& new_key); +error_t refresh_ac(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, + const party_set_t& quorum_party_set, key_t& key, key_t& new_key); /** * @specs: diff --git a/src/cbmpc/protocol/sid.h b/include-internal/cbmpc/internal/protocol/sid.h similarity index 94% rename from src/cbmpc/protocol/sid.h rename to include-internal/cbmpc/internal/protocol/sid.h index 8db7436f..1874ab85 100644 --- a/src/cbmpc/protocol/sid.h +++ b/include-internal/cbmpc/internal/protocol/sid.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include namespace coinbase::mpc { diff --git a/src/cbmpc/protocol/util.h b/include-internal/cbmpc/internal/protocol/util.h similarity index 94% rename from src/cbmpc/protocol/util.h rename to include-internal/cbmpc/internal/protocol/util.h index d1682c05..5fa2364e 100644 --- a/src/cbmpc/protocol/util.h +++ b/include-internal/cbmpc/internal/protocol/util.h @@ -1,6 +1,6 @@ #pragma once -#include +#include template static T SUM(T zero, int n, LAMBDA lambda) { @@ -62,7 +62,7 @@ auto map_args_to_tuple(F f, Args&&... args) { return map_args_to_tuple_impl(f, tup, std::index_sequence_for{}); } -inline bn_t curve_msg_to_bn(mem_t msg, const ecurve_t& curve) { +inline bn_t curve_msg_to_bn(coinbase::mem_t msg, const ecurve_t& curve) { if (msg.size > curve.size()) msg.size = curve.size(); return bn_t::from_bin(msg); } diff --git a/src/cbmpc/zk/fischlin.h b/include-internal/cbmpc/internal/zk/fischlin.h similarity index 88% rename from src/cbmpc/zk/fischlin.h rename to include-internal/cbmpc/internal/zk/fischlin.h index 308b0dca..68300da2 100644 --- a/src/cbmpc/zk/fischlin.h +++ b/include-internal/cbmpc/internal/zk/fischlin.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include namespace coinbase::zk { @@ -14,6 +14,11 @@ void sha256_update_zs(EVP_MD_CTX* ctx, const bn_t& first, REST&... rest) { cb_assert(first.get_bin_size() <= 256); // prevent stack overflow int len = first.to_bin(temp); + // Length-prefix each element to avoid ambiguous concatenation across multiple `bn_t`s: + // without this, tuples like (0x01, 0x02) and (0x0102, 0x00) hash identically. + byte_t len_be[4]; + coinbase::be_set_4(len_be, uint32_t(len)); + EVP_DigestUpdate(ctx, len_be, sizeof(len_be)); EVP_DigestUpdate(ctx, temp, len); sha256_update_zs(ctx, rest...); diff --git a/src/cbmpc/zk/small_primes.h b/include-internal/cbmpc/internal/zk/small_primes.h similarity index 81% rename from src/cbmpc/zk/small_primes.h rename to include-internal/cbmpc/internal/zk/small_primes.h index 01b7f030..a3660cdf 100644 --- a/src/cbmpc/zk/small_primes.h +++ b/include-internal/cbmpc/internal/zk/small_primes.h @@ -1,6 +1,8 @@ #pragma once -#include +#include + +namespace coinbase::zk { constexpr int small_primes_count = 10000; @@ -14,3 +16,5 @@ static error_t check_integer_with_small_primes(const bn_t& prime, int alpha) { } return SUCCESS; } + +} // namespace coinbase::zk diff --git a/src/cbmpc/zk/zk_ec.h b/include-internal/cbmpc/internal/zk/zk_ec.h similarity index 96% rename from src/cbmpc/zk/zk_ec.h rename to include-internal/cbmpc/internal/zk/zk_ec.h index 85051d0b..00012254 100644 --- a/src/cbmpc/zk/zk_ec.h +++ b/include-internal/cbmpc/internal/zk/zk_ec.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include namespace coinbase::zk { diff --git a/src/cbmpc/zk/zk_elgamal_com.h b/include-internal/cbmpc/internal/zk/zk_elgamal_com.h similarity index 94% rename from src/cbmpc/zk/zk_elgamal_com.h rename to include-internal/cbmpc/internal/zk/zk_elgamal_com.h index 3103dd91..e12fb744 100644 --- a/src/cbmpc/zk/zk_elgamal_com.h +++ b/include-internal/cbmpc/internal/zk/zk_elgamal_com.h @@ -1,8 +1,8 @@ #pragma once -#include -#include -#include -#include +#include +#include +#include +#include namespace coinbase::zk { diff --git a/src/cbmpc/zk/zk_paillier.h b/include-internal/cbmpc/internal/zk/zk_paillier.h similarity index 98% rename from src/cbmpc/zk/zk_paillier.h rename to include-internal/cbmpc/internal/zk/zk_paillier.h index accaa835..e2923544 100644 --- a/src/cbmpc/zk/zk_paillier.h +++ b/include-internal/cbmpc/internal/zk/zk_paillier.h @@ -1,7 +1,7 @@ #pragma once -#include -#include -#include +#include +#include +#include namespace coinbase::zk { diff --git a/src/cbmpc/zk/zk_pedersen.h b/include-internal/cbmpc/internal/zk/zk_pedersen.h similarity index 96% rename from src/cbmpc/zk/zk_pedersen.h rename to include-internal/cbmpc/internal/zk/zk_pedersen.h index cf180c92..21499b01 100644 --- a/src/cbmpc/zk/zk_pedersen.h +++ b/include-internal/cbmpc/internal/zk/zk_pedersen.h @@ -1,7 +1,7 @@ #pragma once -#include -#include -#include +#include +#include +#include namespace coinbase::zk { diff --git a/src/cbmpc/zk/zk_unknown_order.h b/include-internal/cbmpc/internal/zk/zk_unknown_order.h similarity index 93% rename from src/cbmpc/zk/zk_unknown_order.h rename to include-internal/cbmpc/internal/zk/zk_unknown_order.h index 3e2c366a..ceee3fe2 100644 --- a/src/cbmpc/zk/zk_unknown_order.h +++ b/include-internal/cbmpc/internal/zk/zk_unknown_order.h @@ -1,5 +1,5 @@ #pragma once -#include +#include namespace coinbase::zk { diff --git a/src/cbmpc/zk/zk_util.h b/include-internal/cbmpc/internal/zk/zk_util.h similarity index 97% rename from src/cbmpc/zk/zk_util.h rename to include-internal/cbmpc/internal/zk/zk_util.h index f8ee124f..3ab21686 100644 --- a/src/cbmpc/zk/zk_util.h +++ b/include-internal/cbmpc/internal/zk/zk_util.h @@ -1,6 +1,6 @@ #pragma once -#include +#include namespace coinbase::zk { diff --git a/include/cbmpc/api/curve.h b/include/cbmpc/api/curve.h new file mode 100644 index 00000000..44cb1881 --- /dev/null +++ b/include/cbmpc/api/curve.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace coinbase::api { + +// Public curve identifiers for API/FFI stability. +enum class curve_id : uint32_t { + p256 = 1, // NIST P-256 (aka prime256v1 / secp256r1) + secp256k1 = 2, // secp256k1 + ed25519 = 3, // Edwards25519 +}; + +} // namespace coinbase::api diff --git a/include/cbmpc/api/ecdsa_2p.h b/include/cbmpc/api/ecdsa_2p.h new file mode 100644 index 00000000..89ea8ed7 --- /dev/null +++ b/include/cbmpc/api/ecdsa_2p.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +namespace coinbase::api::ecdsa_2p { + +using party_t = coinbase::api::party_2p_t; + +// Run the 2-party key generation protocol. +// +// The output `key_blob` is a versioned, opaque byte string that can be persisted +// by the caller and used in other API calls. +error_t dkg(const coinbase::api::job_2p_t& job, curve_id curve, buf_t& key_blob); + +// Refresh an existing key share, producing a new share (same public key). +error_t refresh(const coinbase::api::job_2p_t& job, mem_t key_blob, buf_t& new_key_blob); + +// Sign a message hash (not the raw message) using ECDSA-2PC. +// +// `sid` is an in/out session id used by the protocol; callers may pass an empty +// buffer and let the protocol derive one. +// +// Output signature is DER encoded. +// +// Note: the underlying protocol returns the signature only on P1. On P2, +// `sig_der` may be left empty on success. +error_t sign(const coinbase::api::job_2p_t& job, mem_t key_blob, mem_t msg_hash, buf_t& sid, buf_t& sig_der); + +// Get the compressed public key from a key blob. +error_t get_public_key_compressed(mem_t key_blob, buf_t& pub_key); + +// --------------------------------------------------------------------------- +// Key blob manipulation (private scalar backup / restore) +// --------------------------------------------------------------------------- + +// Get this party's share public point (Qi) from a key blob, returning SEC1 +// compressed point encoding. +// +// This is useful for verifiable backup schemes like PVE, where a scalar x can be +// verified against its corresponding curve point Q = x*G. +// +// Notes: +// - Intended for full key blobs. If you detached a scalar share, persist the +// returned public share point separately (or call this API before detaching). +error_t get_public_share_compressed(mem_t key_blob, buf_t& out_public_share_compressed); + +// Detach the private scalar share from a key blob, producing: +// - a key blob with its private scalar removed, and +// - the private scalar x encoded as a big-endian buffer. +// +// The scalar-removed blob is not usable for signing/refresh until restored with +// `attach_private_scalar`. +// +// Note (ECDSA-2PC encoding): +// - Unlike ECDSA-MP, this scalar encoding is NOT fixed-length. ECDSA-2PC keeps +// the share as a Paillier-compatible integer representative and it may grow +// after refresh. +error_t detach_private_scalar(mem_t key_blob, buf_t& out_public_key_blob, buf_t& out_private_scalar); + +// Restore a full key blob by attaching a big-endian private scalar share x into a +// scalar-removed key blob (produced by `detach_private_scalar`). +// +// This validates that x matches the expected share point by checking: +// (x mod q)*G == public_share_compressed. +// +// Input: +// - `private_scalar` is a big-endian scalar encoding (variable-length). +// - `public_share_compressed` must be the SEC1 compressed point encoding of this +// party's share public point (Qi_self), e.g. from `get_public_share_compressed`. +error_t attach_private_scalar(mem_t public_key_blob, mem_t private_scalar, mem_t public_share_compressed, + buf_t& out_key_blob); + +} // namespace coinbase::api::ecdsa_2p diff --git a/include/cbmpc/api/ecdsa_mp.h b/include/cbmpc/api/ecdsa_mp.h new file mode 100644 index 00000000..6511f5cf --- /dev/null +++ b/include/cbmpc/api/ecdsa_mp.h @@ -0,0 +1,128 @@ +#pragma once + +#include +#include +#include +#include +#include + +// All the functions have two versions: additive and ac. Additive means that the +// sharing is additive, while ac means that the sharing is according to a given access structure. +namespace coinbase::api::ecdsa_mp { + +// Run the multi-party key generation protocol. +// +// The output `key_blob` is a versioned, opaque byte string that can be persisted +// by the caller and used in other API calls. +// +// Supported curves: `curve_id::p256`, `curve_id::secp256k1`. +error_t dkg_additive(const job_mp_t& job, curve_id curve, buf_t& key_blob, buf_t& sid); + +// Run the multi-party key generation protocol with a general access +// structure. +// +// Notes: +// - This is an n-party protocol: **all** parties in +// `job.party_names` must be online and participate. +// - Only the provided `quorum_party_names` actively contribute to the generated +// key shares. +// - The output key blob represents an access-structure key share and +// is not directly usable with `sign_additive()`. Use `sign_ac()` to sign with an online +// quorum (it derives additive shares internally). +// - `sid` is an in/out session id used by the protocol. Callers may pass an +// empty buffer to let the protocol derive one. +// +// Supported curves: `curve_id::p256`, `curve_id::secp256k1`. +error_t dkg_ac(const job_mp_t& job, curve_id curve, buf_t& sid, const access_structure_t& access_structure, + const std::vector& quorum_party_names, buf_t& key_blob); + +// Refresh an existing key share set, producing a new key share. +// +// `sid` is an in/out session id used by the refresh protocol. Callers may pass +// an empty buffer to let the protocol derive one. +error_t refresh_additive(const job_mp_t& job, buf_t& sid, mem_t key_blob, buf_t& new_key_blob); + +// Refresh an existing key share set using the access-structure refresh protocol. +// +// Notes: +// - See `dkg_ac` for protocol participation semantics. +// - The output key blob represents an access-structure key share and +// is not directly usable with `sign_additive()`. Use `sign_ac()` to sign with an online +// quorum (it derives additive shares internally). +// - `sid` is an in/out session id used by the protocol. Callers may pass an +// empty buffer to let the protocol derive one. +error_t refresh_ac(const job_mp_t& job, buf_t& sid, mem_t key_blob, const access_structure_t& access_structure, + const std::vector& quorum_party_names, buf_t& new_key_blob); + +// Sign a message with ECDSA-MP and output signature on `sig_receiver`. +// +// Output signature is DER encoded. +// +// Note: the underlying protocol returns the signature only on `sig_receiver`. On +// other parties, `sig_der` may be left empty on success. +error_t sign_additive(const job_mp_t& job, mem_t key_blob, mem_t msg, party_idx_t sig_receiver, buf_t& sig_der); + +// Sign a message with ECDSA-MP using an access-structure key share (from +// `dkg_ac` / `refresh_ac`). +// +// This API first derives an additive-share signing key for the **online** signing +// parties in `job.party_names` and then runs the normal `sign_additive()` protocol among +// those parties. +// +// Notes: +// - Unlike `dkg_ac` / `refresh_ac`, `sign_ac` only requires the parties in +// `job.party_names` to be online and participate. +// - Output semantics match `sign_additive()`: the signature is returned only on +// `sig_receiver`. On other parties, `sig_der` may be left empty on success. +error_t sign_ac(const job_mp_t& job, mem_t ac_key_blob, const access_structure_t& access_structure, mem_t msg, + party_idx_t sig_receiver, buf_t& sig_der); + +// Get the compressed public key from a key blob (SEC1 compressed point). +// +// This is the same encoding as `ecdsa_2p::get_public_key_compressed`. +error_t get_public_key_compressed(mem_t key_blob, buf_t& pub_key); + +// --------------------------------------------------------------------------- +// Key blob manipulation (private share backup / restore) +// --------------------------------------------------------------------------- + +// Get this party's share public point (Qi) from a key blob, returning SEC1 +// compressed point encoding. +// +// This is useful for verifiable backup schemes like PVE, where a scalar x can be +// verified against its corresponding curve point Q = x*G. +// +// Notes: +// - Works for both additive (v1) and access-structure (v2) ECDSA-MP key blobs. +// - This can be called on a blob produced by `detach_private_scalar` since Qi_self +// remains present in the key blob even after redaction. +error_t get_public_share_compressed(mem_t key_blob, buf_t& out_public_share_compressed); + +// Detach the private scalar share from a key blob, producing: +// - a "public" key blob with its private scalar wiped, and +// - the private scalar x encoded as a fixed-length big-endian buffer. +// +// The public blob is safe to persist as public-only material, but is not usable +// for signing/refresh until restored with `attach_private_scalar`. +// +// Output: +// - `out_private_scalar_fixed` length equals the curve order size in bytes +// (e.g., 32 bytes for secp256k1/p256). +error_t detach_private_scalar(mem_t key_blob, buf_t& out_public_key_blob, buf_t& out_private_scalar_fixed); + +// Restore a full key blob by attaching a fixed-length private scalar x into a +// public key blob (produced by `detach_private_scalar`). +// +// This validates that x matches the key blob by checking: +// - `public_share_compressed` matches the blob's Qi_self, and +// - x*G == Qi_self. +// +// Input: +// - `private_scalar_fixed` must be a fixed-length big-endian scalar encoding with +// length equal to the curve order size in bytes. +// - `public_share_compressed` must be the SEC1 compressed point encoding of +// this party's share public point (Qi_self), e.g. from `get_public_share_compressed`. +error_t attach_private_scalar(mem_t public_key_blob, mem_t private_scalar_fixed, mem_t public_share_compressed, + buf_t& out_key_blob); + +} // namespace coinbase::api::ecdsa_mp diff --git a/include/cbmpc/api/eddsa_2p.h b/include/cbmpc/api/eddsa_2p.h new file mode 100644 index 00000000..8ab7efc6 --- /dev/null +++ b/include/cbmpc/api/eddsa_2p.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include +#include + +namespace coinbase::api::eddsa_2p { + +using party_t = coinbase::api::party_2p_t; + +// Run the 2-party key generation protocol for EdDSA. +// +// The output `key_blob` is a versioned, opaque byte string that can be persisted +// by the caller and used in other API calls. +// +// Supported curves: `curve_id::ed25519`. +error_t dkg(const coinbase::api::job_2p_t& job, curve_id curve, buf_t& key_blob); + +// Refresh an existing key share, producing a new share (same public key). +error_t refresh(const coinbase::api::job_2p_t& job, mem_t key_blob, buf_t& new_key_blob); + +// Sign a message with EdDSA-2PC (Ed25519). +// +// Note: the underlying protocol returns the signature only on P1. On P2, `sig` +// may be left empty on success. +error_t sign(const coinbase::api::job_2p_t& job, mem_t key_blob, mem_t msg, buf_t& sig); + +// Get the Ed25519 public key from a key blob. +// +// Output is the standard Ed25519 32-byte compressed public key encoding. +// +// Note: Ed25519 public keys are always encoded in this compressed format; the +// `_compressed` suffix is provided for naming consistency with ECDSA APIs. +error_t get_public_key_compressed(mem_t key_blob, buf_t& pub_key); + +// --------------------------------------------------------------------------- +// Key blob manipulation (private scalar backup / restore) +// --------------------------------------------------------------------------- + +// Get this party's share public point (Qi) from a key blob, returning the +// standard Ed25519 32-byte compressed public key encoding. +// +// This is useful for verifiable backup schemes like PVE, where a scalar x can be +// verified against its corresponding curve point Q = x*G. +// +// Notes: +// - Intended for full key blobs. If you detached a scalar share, persist the +// returned public share point separately (or call this API before detaching). +error_t get_public_share_compressed(mem_t key_blob, buf_t& out_public_share_compressed); + +// Detach the private scalar share from a key blob, producing: +// - a "public" key blob with its private scalar removed, and +// - the private scalar x encoded as a fixed-length big-endian buffer. +// +// The public blob is safe to persist as public-only material, but is not usable +// for signing/refresh until restored with `attach_private_scalar`. +// +// Output: +// - `out_private_scalar_fixed` length equals the curve order size in bytes +// (32 bytes for ed25519). +error_t detach_private_scalar(mem_t key_blob, buf_t& out_public_key_blob, buf_t& out_private_scalar_fixed); + +// Restore a full key blob by attaching a fixed-length private scalar x into a +// public key blob (produced by `detach_private_scalar`). +// +// This validates that x matches the expected share point by checking: +// x*G == public_share_compressed. +// +// Input: +// - `private_scalar_fixed` must be a fixed-length big-endian scalar encoding with +// length equal to the curve order size in bytes (32 bytes for ed25519). +// - `public_share_compressed` must be the standard Ed25519 32-byte compressed point +// encoding of this party's share public point (Qi_self), e.g. from `get_public_share_compressed`. +error_t attach_private_scalar(mem_t public_key_blob, mem_t private_scalar_fixed, mem_t public_share_compressed, + buf_t& out_key_blob); + +} // namespace coinbase::api::eddsa_2p diff --git a/include/cbmpc/api/eddsa_mp.h b/include/cbmpc/api/eddsa_mp.h new file mode 100644 index 00000000..877fb9a4 --- /dev/null +++ b/include/cbmpc/api/eddsa_mp.h @@ -0,0 +1,131 @@ +#pragma once + +#include +#include +#include +#include +#include + +// All the functions have two versions: additive and ac. Additive means that the +// sharing is additive, while ac means that the sharing is according to a given access structure. +namespace coinbase::api::eddsa_mp { + +// Run the multi-party key generation protocol for EdDSA. +// +// The output `key_blob` is a versioned, opaque byte string that can be persisted +// by the caller and used in other API calls. +// +// Supported curves: `curve_id::ed25519`. +error_t dkg_additive(const job_mp_t& job, curve_id curve, buf_t& key_blob, buf_t& sid); + +// Run the multi-party key generation protocol for EdDSA with a +// general access structure. +// +// Notes: +// - This is an n-party protocol: **all** parties in +// `job.party_names` must be online and participate. +// - Only the provided `quorum_party_names` actively contribute to the generated +// key shares. +// - The output key blob represents an access-structure key share and +// is not directly usable with `sign_additive()`. Use `sign_ac()` to sign with an online +// quorum (it derives additive shares internally). +// - `sid` is an in/out session id used by the protocol. Callers may pass an +// empty buffer to let the protocol derive one. +// +// Supported curves: `curve_id::ed25519`. +error_t dkg_ac(const job_mp_t& job, curve_id curve, buf_t& sid, const access_structure_t& access_structure, + const std::vector& quorum_party_names, buf_t& key_blob); + +// Refresh an existing key share set, producing a new key share. +// +// `sid` is an in/out session id used by the refresh protocol. Callers may pass +// an empty buffer to let the protocol derive one. +error_t refresh_additive(const job_mp_t& job, buf_t& sid, mem_t key_blob, buf_t& new_key_blob); + +// Refresh an existing key share set using the access-structure refresh protocol. +// +// Notes: +// - See `dkg_ac` for protocol participation semantics. +// - The output key blob represents an access-structure key share and +// is not directly usable with `sign_additive()`. Use `sign_ac()` to sign with an online +// quorum (it derives additive shares internally). +// - `sid` is an in/out session id used by the protocol. Callers may pass an +// empty buffer to let the protocol derive one. +error_t refresh_ac(const job_mp_t& job, buf_t& sid, mem_t key_blob, const access_structure_t& access_structure, + const std::vector& quorum_party_names, buf_t& new_key_blob); + +// Sign a message with EdDSA-MP (Ed25519) and output signature on `sig_receiver`. +// +// Output signature is 64 bytes: R (32 bytes) || s (32 bytes). +// +// Note: the underlying protocol returns the signature only on `sig_receiver`. On +// other parties, `sig` may be left empty on success. +error_t sign_additive(const job_mp_t& job, mem_t key_blob, mem_t msg, party_idx_t sig_receiver, buf_t& sig); + +// Sign a message with EdDSA-MP (Ed25519) using an access-structure key share (from +// `dkg_ac` / `refresh_ac`). +// +// This API first derives an additive-share signing key for the **online** signing +// parties in `job.party_names` and then runs the normal `sign_additive()` protocol among +// those parties. +// +// Notes: +// - Unlike `dkg_ac` / `refresh_ac`, `sign_ac` only requires the parties in +// `job.party_names` to be online and participate. +// - Output semantics match `sign_additive()`: the signature is returned only on +// `sig_receiver`. On other parties, `sig` may be left empty on success. +error_t sign_ac(const job_mp_t& job, mem_t ac_key_blob, const access_structure_t& access_structure, mem_t msg, + party_idx_t sig_receiver, buf_t& sig); + +// Get the Ed25519 public key from a key blob. +// +// Output is the standard Ed25519 32-byte compressed public key encoding. +// +// Note: Ed25519 public keys are always encoded in this compressed format; the +// `_compressed` suffix is provided for naming consistency with ECDSA APIs. +error_t get_public_key_compressed(mem_t key_blob, buf_t& pub_key); + +// --------------------------------------------------------------------------- +// Key blob manipulation (private scalar backup / restore) +// --------------------------------------------------------------------------- + +// Get this party's share public point (Qi) from a key blob, returning the +// standard Ed25519 32-byte compressed public key encoding. +// +// This is useful for verifiable backup schemes like PVE, where a scalar x can be +// verified against its corresponding curve point Q = x*G. +// +// Notes: +// - Works for both additive (v1) and access-structure (v2) EdDSA-MP key blobs. +// - This can be called on a blob produced by `detach_private_scalar` since Qi_self +// remains present in the key blob even after redaction. +error_t get_public_share_compressed(mem_t key_blob, buf_t& out_public_share_compressed); + +// Detach the private scalar share from a key blob, producing: +// - a "public" key blob with its private scalar wiped, and +// - the private scalar x encoded as a fixed-length big-endian buffer. +// +// The public blob is safe to persist as public-only material, but is not usable +// for signing/refresh until restored with `attach_private_scalar`. +// +// Output: +// - `out_private_scalar_fixed` length equals the curve order size in bytes +// (32 bytes for ed25519). +error_t detach_private_scalar(mem_t key_blob, buf_t& out_public_key_blob, buf_t& out_private_scalar_fixed); + +// Restore a full key blob by attaching a fixed-length private scalar x into a +// public key blob (produced by `detach_private_scalar`). +// +// This validates that x matches the key blob by checking: +// - `public_share_compressed` matches the blob's Qi_self, and +// - x*G == Qi_self. +// +// Input: +// - `private_scalar_fixed` must be a fixed-length big-endian scalar encoding with +// length equal to the curve order size in bytes (32 bytes for ed25519). +// - `public_share_compressed` must be the standard Ed25519 32-byte compressed point +// encoding of this party's share public point (Qi_self), e.g. from `get_public_share_compressed`. +error_t attach_private_scalar(mem_t public_key_blob, mem_t private_scalar_fixed, mem_t public_share_compressed, + buf_t& out_key_blob); + +} // namespace coinbase::api::eddsa_mp diff --git a/include/cbmpc/api/hd_keyset_ecdsa_2p.h b/include/cbmpc/api/hd_keyset_ecdsa_2p.h new file mode 100644 index 00000000..78940104 --- /dev/null +++ b/include/cbmpc/api/hd_keyset_ecdsa_2p.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace coinbase::api::hd_keyset_ecdsa_2p { + +using coinbase::api::bip32_path_t; + +// Run the 2-party HD keyset initialization protocol. +// +// The output `keyset_blob` is a versioned, opaque byte string that can be persisted +// by the caller and used for refresh / derivation. +error_t dkg(const coinbase::api::job_2p_t& job, curve_id curve, buf_t& keyset_blob); + +// Refresh an existing HD keyset share, producing a new share (same public root keys). +error_t refresh(const coinbase::api::job_2p_t& job, mem_t keyset_blob, buf_t& new_keyset_blob); + +// Derive per-path ECDSA-2PC key blobs from an HD keyset. +// +// - `hardened_path` selects the hardened derivation branch (VRF-based step). +// - `non_hardened_paths` are applied to the derived public key using BIP32 non-hardened steps. +// - `sid` is an in/out session id used by the protocol; callers may pass an empty +// buffer and let the protocol derive one. +// +// Output: +// - `out_ecdsa_2p_key_blobs.size() == non_hardened_paths.size()` +// - Each element is an `ecdsa_2p` key blob compatible with `coinbase::api::ecdsa_2p::*`. +error_t derive_ecdsa_2p_keys(const coinbase::api::job_2p_t& job, mem_t keyset_blob, const bip32_path_t& hardened_path, + const std::vector& non_hardened_paths, buf_t& sid, + std::vector& out_ecdsa_2p_key_blobs); + +// Extract the compressed root ECDSA public key Q from a keyset blob. +error_t extract_root_public_key_compressed(mem_t keyset_blob, buf_t& out_Q_compressed); + +} // namespace coinbase::api::hd_keyset_ecdsa_2p diff --git a/include/cbmpc/api/hd_keyset_eddsa_2p.h b/include/cbmpc/api/hd_keyset_eddsa_2p.h new file mode 100644 index 00000000..c7a6177e --- /dev/null +++ b/include/cbmpc/api/hd_keyset_eddsa_2p.h @@ -0,0 +1,46 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace coinbase::api::hd_keyset_eddsa_2p { + +using coinbase::api::bip32_path_t; + +// Run the 2-party HD keyset initialization protocol (EdDSA / Ed25519). +// +// The output `keyset_blob` is a versioned, opaque byte string that can be persisted +// by the caller and used for refresh / derivation. +// +// Supported curves: `curve_id::ed25519`. +error_t dkg(const coinbase::api::job_2p_t& job, curve_id curve, buf_t& keyset_blob); + +// Refresh an existing HD keyset share, producing a new share (same public root keys). +error_t refresh(const coinbase::api::job_2p_t& job, mem_t keyset_blob, buf_t& new_keyset_blob); + +// Derive per-path EdDSA-2PC key blobs from an HD keyset. +// +// - `hardened_path` selects the hardened derivation branch (VRF-based step). +// - `non_hardened_paths` are applied to the derived public key using BIP32 non-hardened steps. +// - `sid` is an in/out session id used by the protocol; callers may pass an empty +// buffer and let the protocol derive one. +// +// Output: +// - `out_eddsa_2p_key_blobs.size() == non_hardened_paths.size()` +// - Each element is an `eddsa_2p` key blob compatible with `coinbase::api::eddsa_2p::*`. +error_t derive_eddsa_2p_keys(const coinbase::api::job_2p_t& job, mem_t keyset_blob, const bip32_path_t& hardened_path, + const std::vector& non_hardened_paths, buf_t& sid, + std::vector& out_eddsa_2p_key_blobs); + +// Extract the compressed root public key Q from a keyset blob. +// +// Output encoding matches `coinbase::api::eddsa_2p::get_public_key_compressed`: +// a 32-byte Ed25519 compressed public key. +error_t extract_root_public_key_compressed(mem_t keyset_blob, buf_t& out_Q_compressed); + +} // namespace coinbase::api::hd_keyset_eddsa_2p diff --git a/include/cbmpc/api/pve_base_pke.h b/include/cbmpc/api/pve_base_pke.h new file mode 100644 index 00000000..b9dbbd00 --- /dev/null +++ b/include/cbmpc/api/pve_base_pke.h @@ -0,0 +1,164 @@ +#pragma once + +#include +#include +#include + +namespace coinbase::api::pve { + +// Pluggable base public-key encryption (PKE) used by PVE. +// +// The PVE construction requires *deterministic* encryption given `rho`: +// calling `encrypt(...)` with the same inputs (including the same `rho`) must +// produce the same ciphertext. +// +// Implementations must treat `rho` as the sole source of encryption randomness. +class base_pke_i { + public: + virtual ~base_pke_i() = default; + + // Encrypt `plain` under `ek` with associated data `label`, using `rho` as + // deterministic randomness seed. + virtual error_t encrypt(mem_t ek, mem_t label, mem_t plain, mem_t rho, buf_t& out_ct) const = 0; + + // Decrypt `ct` under `dk` with associated data `label`. + virtual error_t decrypt(mem_t dk, mem_t label, mem_t ct, buf_t& out_plain) const = 0; +}; + +// Built-in base PKE implementation (default). +// +// Key format: +// - `ek` / `dk` are opaque, versioned byte strings generated by: +// - `generate_base_pke_rsa_keypair(...)` +// - `generate_base_pke_ecies_p256_keypair(...)` +// +// Notes: +// - The built-in base PKE supports RSA-OAEP(2048) and ECIES(P-256). +// - The base PKE is independent of the curve used by the outer PVE scheme. +const base_pke_i& base_pke_default(); + +// Generate a keypair for the built-in RSA-OAEP (2048-bit) base PKE. +// +// Outputs are opaque, versioned byte strings compatible with `base_pke_default()`. +error_t generate_base_pke_rsa_keypair(buf_t& out_ek, buf_t& out_dk); + +// Generate a keypair for the built-in ECIES (P-256) base PKE. +// +// Outputs are opaque, versioned byte strings compatible with `base_pke_default()`. +error_t generate_base_pke_ecies_p256_keypair(buf_t& out_ek, buf_t& out_dk); + +// Build a built-in ECIES(P-256) base PKE public key blob from an external public key. +// +// This is useful when the private key lives in an HSM (or other external system) +// and only the public key can be exported to software. +// +// Input format: +// - `pub_key_oct` must be the *uncompressed* NIST P-256 public key octet string: +// 65 bytes: 0x04 || X(32) || Y(32). +// +// Output: +// - `out_ek` is an opaque, versioned byte string compatible with: +// - `base_pke_default()` +// - `encrypt/verify(...)` +// - `decrypt_ecies_p256_hsm(...)` +error_t base_pke_ecies_p256_ek_from_oct(mem_t pub_key_oct, buf_t& out_ek); + +// Build a built-in RSA-OAEP(2048) base PKE public key blob from a raw modulus. +// +// This is useful when the private key lives in an HSM (or other external system) +// that exports only the raw modulus (e.g. YubiHSM 2). +// +// Input format: +// - `modulus` must be the big-endian RSA modulus (256 bytes for RSA-2048). +// - The public exponent is assumed to be 65537. +// +// Output: +// - `out_ek` is an opaque, versioned byte string compatible with: +// - `base_pke_default()` +// - `encrypt/verify(...)` +// - `decrypt_rsa_oaep_hsm(...)` +error_t base_pke_rsa_ek_from_modulus(mem_t modulus, buf_t& out_ek); + +// --------------------------------------------------------------------------- +// Built-in HSM support (KEM decapsulation callback) +// --------------------------------------------------------------------------- + +// RSA-OAEP decapsulation callback. +// +// This callback is responsible only for recovering the RSA KEM shared secret +// from `kem_ct` using an HSM-held private key referenced by `dk_handle`. +// The library performs HKDF + AES-GCM DEM and ciphertext parsing. +struct rsa_oaep_hsm_decap_cb_t { + void* ctx = nullptr; + error_t (*decap)(void* ctx, mem_t dk_handle, mem_t kem_ct, buf_t& out_kem_ss) = nullptr; +}; + +// ECIES(P-256) ECDH callback. +// +// This callback should perform the ECDH operation using the HSM-held private +// key referenced by `dk_handle` and the ephemeral public key encoded in `kem_ct`. +// It must output the raw affine-X coordinate as a 32-byte big-endian buffer. +// The library derives the final KEM shared secret and performs HKDF + AES-GCM DEM. +struct ecies_p256_hsm_ecdh_cb_t { + void* ctx = nullptr; + error_t (*ecdh)(void* ctx, mem_t dk_handle, mem_t kem_ct, buf_t& out_dh_x32) = nullptr; +}; + +// Encrypt a scalar `x` under base encryption key `ek`, producing a PVE ciphertext. +// +// - `curve` selects the elliptic curve used for the outer PVE relation (Q = x*G). +// - `label` is associated data that must match in `verify()` / `decrypt()`. +// - `x` is interpreted as a big-endian integer and reduced modulo the curve order. +// - Defensive limit: `x.size` must be <= the curve order size in bytes. +error_t encrypt(const base_pke_i& base_pke, curve_id curve, mem_t ek, mem_t label, mem_t x, buf_t& out_ciphertext); + +// Same as `encrypt(...)` using `base_pke_default()`. +error_t encrypt(curve_id curve, mem_t ek, mem_t label, mem_t x, buf_t& out_ciphertext); + +// Verify ciphertext validity against the expected public value Q and label. +// +// - `Q_compressed` is the compressed public point encoding for `curve` +// (33 bytes for P-256/secp256k1, 32 bytes for Ed25519). +error_t verify(const base_pke_i& base_pke, curve_id curve, mem_t ek, mem_t ciphertext, mem_t Q_compressed, mem_t label); + +// Same as `verify(...)` using `base_pke_default()`. +error_t verify(curve_id curve, mem_t ek, mem_t ciphertext, mem_t Q_compressed, mem_t label); + +// Decrypt a ciphertext, recovering the scalar x. +// +// Output: +// - On success, `out_x` is set to a fixed-length big-endian encoding of x with +// length equal to the curve order size in bytes (e.g., 32 bytes for P-256). +// +// Notes: +// - This function intentionally does not verify `ciphertext` before decryption. +// Invalid ciphertexts may cause decryption to fail, but are designed to not +// leak secret information. +// - If you need ciphertext validation, call `verify(...)` first. +error_t decrypt(const base_pke_i& base_pke, curve_id curve, mem_t dk, mem_t ek, mem_t ciphertext, mem_t label, + buf_t& out_x); + +// Same as `decrypt(...)` using `base_pke_default()`. +error_t decrypt(curve_id curve, mem_t dk, mem_t ek, mem_t ciphertext, mem_t label, buf_t& out_x); + +// Decrypt using an HSM-backed RSA private key. +// +// - `dk_handle` is an opaque handle (or identifier) understood by the caller's HSM callback. +// - `ek` is the software public key blob (used for verification and encryption-side checks). +error_t decrypt_rsa_oaep_hsm(curve_id curve, mem_t dk_handle, mem_t ek, mem_t ciphertext, mem_t label, + const rsa_oaep_hsm_decap_cb_t& cb, buf_t& out_x); + +// Decrypt using an HSM-backed ECIES(P-256) private key. +// +// - `dk_handle` is an opaque handle (or identifier) understood by the caller's HSM callback. +// - `ek` is the software public key blob (used for verification and to derive the KEM context). +error_t decrypt_ecies_p256_hsm(curve_id curve, mem_t dk_handle, mem_t ek, mem_t ciphertext, mem_t label, + const ecies_p256_hsm_ecdh_cb_t& cb, buf_t& out_x); + +// Extract the public value Q whose discrete log is encrypted as the ciphertext, returning compressed point encoding. +error_t get_public_key_compressed(mem_t ciphertext, buf_t& out_Q_compressed); + +// Extract the associated label from a ciphertext. +error_t get_Label(mem_t ciphertext, buf_t& out_label); + +} // namespace coinbase::api::pve diff --git a/include/cbmpc/api/pve_batch_ac.h b/include/cbmpc/api/pve_batch_ac.h new file mode 100644 index 00000000..9cc67300 --- /dev/null +++ b/include/cbmpc/api/pve_batch_ac.h @@ -0,0 +1,117 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace coinbase::api::pve { + +// --------------------------------------------------------------------------- +// PVE-AC (access-structure / quorum decryption) +// --------------------------------------------------------------------------- +// +// This API encrypts a *batch* of scalars {x_i} under a leaf-keyed access structure. +// +// Decryption is multi-party in principle: +// - Each party decrypts its own leaf ciphertext to produce a share (step function). +// - An application collects enough shares to satisfy the access structure and aggregates them to recover {x_i}. + +// Leaf-key maps (leaf name -> base-PKE key blob). +// +// - For the built-in base PKE (`base_pke_default()`), the key blob format is the same as used by `pve.h`: +// RSA-OAEP(2048) or ECIES(P-256) public/private blobs. +using leaf_keys_t = std::map; + +// Leaf-share map (leaf name -> decrypted share bytes). +using leaf_shares_t = std::map; + +// Encrypt a batch of scalars {x_i} under an access structure. +// +// - `ac_pks` must provide a public key for every leaf in `ac`. +// - Each `xs[i]` is interpreted as a big-endian integer and reduced modulo the curve order. +// - Defensive limit: each `xs[i].size` must be <= the curve order size in bytes. +error_t encrypt_ac(const base_pke_i& base_pke, curve_id curve, const access_structure_t& ac, const leaf_keys_t& ac_pks, + mem_t label, const std::vector& xs, buf_t& out_ciphertext); + +// Same as `encrypt_ac(...)` using `base_pke_default()`. +error_t encrypt_ac(curve_id curve, const access_structure_t& ac, const leaf_keys_t& ac_pks, mem_t label, + const std::vector& xs, buf_t& out_ciphertext); + +// Verify an access-structure ciphertext against expected Qs and label. +error_t verify_ac(const base_pke_i& base_pke, curve_id curve, const access_structure_t& ac, const leaf_keys_t& ac_pks, + mem_t ciphertext, const std::vector& Qs_compressed, mem_t label); + +// Same as `verify_ac(...)` using `base_pke_default()`. +error_t verify_ac(curve_id curve, const access_structure_t& ac, const leaf_keys_t& ac_pks, mem_t ciphertext, + const std::vector& Qs_compressed, mem_t label); + +// Step 1: each party decrypts its share for a specific attempt. +// +// Output: +// - `out_share` is a fixed-length big-endian scalar encoding with length equal +// to the outer curve order size in bytes (e.g., 32 bytes for secp256k1). +error_t partial_decrypt_ac_attempt(const base_pke_i& base_pke, curve_id curve, const access_structure_t& ac, + mem_t ciphertext, int attempt_index, std::string_view leaf_name, mem_t dk, + mem_t label, buf_t& out_share); + +// Same as `partial_decrypt_ac_attempt(...)` using `base_pke_default()`. +error_t partial_decrypt_ac_attempt(curve_id curve, const access_structure_t& ac, mem_t ciphertext, int attempt_index, + std::string_view leaf_name, mem_t dk, mem_t label, buf_t& out_share); + +// Step 1 (HSM): decrypt a single leaf share for a specific attempt using an HSM-backed +// RSA-OAEP private key (KEM decapsulation callback). +// +// - `dk_handle` is an opaque handle (or identifier) understood by the HSM callback. +// - `ek` is the leaf's built-in base PKE public key blob (used to validate key type). +error_t partial_decrypt_ac_attempt_rsa_oaep_hsm(curve_id curve, const access_structure_t& ac, mem_t ciphertext, + int attempt_index, std::string_view leaf_name, mem_t dk_handle, + mem_t ek, mem_t label, const rsa_oaep_hsm_decap_cb_t& cb, + buf_t& out_share); + +// Step 1 (HSM): decrypt a single leaf share for a specific attempt using an HSM-backed +// ECIES(P-256) private key (ECDH callback only). +// +// - `dk_handle` is an opaque handle (or identifier) understood by the HSM callback. +// - `ek` is the leaf's built-in base PKE public key blob (used to validate key type +// and derive the KEM context). +error_t partial_decrypt_ac_attempt_ecies_p256_hsm(curve_id curve, const access_structure_t& ac, mem_t ciphertext, + int attempt_index, std::string_view leaf_name, mem_t dk_handle, + mem_t ek, mem_t label, const ecies_p256_hsm_ecdh_cb_t& cb, + buf_t& out_share); + +// Step 2: Aggregate enough decrypted shares to recover {x_i} for a specific attempt. +// If combine fails, then increase the attempt_index and gather another set of +// partial decryptions and call combine again. +// +// - `quorum_shares` must satisfy the access structure `ac`. +// +// Notes: +// - This function intentionally does not verify `ciphertext` before reconstruction. +// Invalid ciphertexts may cause reconstruction to fail, but are designed to not +// leak secret information. +// - If you need ciphertext validation, call `verify_ac(...)` first. +// +// Output: +// - `out_xs[i]` is a fixed-length big-endian encoding of x_i with length equal +// to the curve order size in bytes. +error_t combine_ac(const base_pke_i& base_pke, curve_id curve, const access_structure_t& ac, mem_t ciphertext, + int attempt_index, mem_t label, const leaf_shares_t& quorum_shares, std::vector& out_xs); + +// Same as `combine_ac(...)` using `base_pke_default()`. +error_t combine_ac(curve_id curve, const access_structure_t& ac, mem_t ciphertext, int attempt_index, mem_t label, + const leaf_shares_t& quorum_shares, std::vector& out_xs); + +// Extract batch count (number of scalars) from a PVE-AC ciphertext. +error_t get_ac_batch_count(mem_t ciphertext, int& out_batch_count); + +// Extract the public values {Q_i} whose discrete logs are encrypted as the ciphertext, returning compressed point +// encodings. +error_t get_public_keys_compressed_ac(mem_t ciphertext, std::vector& out_Qs_compressed); + +} // namespace coinbase::api::pve diff --git a/include/cbmpc/api/pve_batch_single_recipient.h b/include/cbmpc/api/pve_batch_single_recipient.h new file mode 100644 index 00000000..7ae612dc --- /dev/null +++ b/include/cbmpc/api/pve_batch_single_recipient.h @@ -0,0 +1,72 @@ +#pragma once + +#include + +#include +#include +#include +#include + +namespace coinbase::api::pve { + +// --------------------------------------------------------------------------- +// Batch PVE (1P) API +// --------------------------------------------------------------------------- +// +// This API batches the core PVE algorithm for multiple scalars {x_i} in one +// ciphertext, sharing the same label and base PKE key. +// +// Notes: +// - `xs[i]` is interpreted as a big-endian integer and reduced modulo the curve order. +// - Defensive limit: each `xs[i].size` must be <= the curve order size in bytes. +// - `Qs_compressed[i]` is the compressed point encoding for the outer curve: +// 33 bytes for P-256/secp256k1, 32 bytes for Ed25519. + +error_t encrypt_batch(const base_pke_i& base_pke, curve_id curve, mem_t ek, mem_t label, const std::vector& xs, + buf_t& out_ciphertext); + +// Same as `encrypt_batch(...)` using `base_pke_default()`. +error_t encrypt_batch(curve_id curve, mem_t ek, mem_t label, const std::vector& xs, buf_t& out_ciphertext); + +error_t verify_batch(const base_pke_i& base_pke, curve_id curve, mem_t ek, mem_t ciphertext, + const std::vector& Qs_compressed, mem_t label); + +// Same as `verify_batch(...)` using `base_pke_default()`. +error_t verify_batch(curve_id curve, mem_t ek, mem_t ciphertext, const std::vector& Qs_compressed, mem_t label); + +// Decrypt a batch ciphertext, recovering the vector of scalars {x_i}. +// +// Output: +// - On success, `out_xs[i]` is a fixed-length big-endian encoding of x_i with +// length equal to the curve order size in bytes (e.g., 32 bytes for P-256). +// +// Notes: +// - This function intentionally does not verify `ciphertext` before decryption. +// Invalid ciphertexts may cause decryption to fail, but are designed to not +// leak secret information. +// - If you need ciphertext validation, call `verify_batch(...)` first. +error_t decrypt_batch(const base_pke_i& base_pke, curve_id curve, mem_t dk, mem_t ek, mem_t ciphertext, mem_t label, + std::vector& out_xs); + +// Same as `decrypt_batch(...)` using `base_pke_default()`. +error_t decrypt_batch(curve_id curve, mem_t dk, mem_t ek, mem_t ciphertext, mem_t label, std::vector& out_xs); + +// Decrypt a batch ciphertext using an HSM-backed RSA private key. +error_t decrypt_batch_rsa_oaep_hsm(curve_id curve, mem_t dk_handle, mem_t ek, mem_t ciphertext, mem_t label, + const rsa_oaep_hsm_decap_cb_t& cb, std::vector& out_xs); + +// Decrypt a batch ciphertext using an HSM-backed ECIES(P-256) private key. +error_t decrypt_batch_ecies_p256_hsm(curve_id curve, mem_t dk_handle, mem_t ek, mem_t ciphertext, mem_t label, + const ecies_p256_hsm_ecdh_cb_t& cb, std::vector& out_xs); + +// Extract batch count from a batch ciphertext. +error_t get_batch_count(mem_t ciphertext, int& out_batch_count); + +// Extract the public values {Q_i} whose discrete logs are encrypted as the ciphertext, returning compressed point +// encodings. +error_t get_public_keys_compressed_batch(mem_t ciphertext, std::vector& out_Qs_compressed); + +// Extract the associated label from a batch ciphertext. +error_t get_Label_batch(mem_t ciphertext, buf_t& out_label); + +} // namespace coinbase::api::pve diff --git a/include/cbmpc/api/schnorr_2p.h b/include/cbmpc/api/schnorr_2p.h new file mode 100644 index 00000000..4176e2c5 --- /dev/null +++ b/include/cbmpc/api/schnorr_2p.h @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include +#include + +namespace coinbase::api::schnorr_2p { + +using party_t = coinbase::api::party_2p_t; + +// Run the 2-party key generation protocol. +// +// The output `key_blob` is a versioned, opaque byte string that can be persisted +// by the caller and used in other API calls. +// +// This wrapper implements the BIP340 Schnorr signature scheme, which is defined +// for secp256k1 only. +// +// Supported curves: `curve_id::secp256k1`. +error_t dkg(const coinbase::api::job_2p_t& job, curve_id curve, buf_t& key_blob); + +// Refresh an existing key share, producing a new share (same public key). +error_t refresh(const coinbase::api::job_2p_t& job, mem_t key_blob, buf_t& new_key_blob); + +// Sign a message with Schnorr-2PC (BIP340). +// +// Input: +// - `msg` must be exactly 32 bytes (BIP340 message digest). +// +// Output signature is 64 bytes: r_x (32 bytes) || s (32 bytes). +// +// Note: the underlying protocol returns the signature only on P1. On P2, +// `sig` may be left empty on success. +error_t sign(const coinbase::api::job_2p_t& job, mem_t key_blob, mem_t msg, buf_t& sig); + +// Get the public key from a key blob. +// +// Notes: +// - For BIP340, the "standard" public key encoding is x-only (32 bytes). This +// API provides both x-only and SEC1 compressed encodings for convenience. +// - Output is deterministic and derived from the key blob contents. +// +// `pub_key_compressed` is a SEC1 compressed point encoding (33 bytes) on +// secp256k1: 0x02/0x03 || x (32 bytes). +error_t get_public_key_compressed(mem_t key_blob, buf_t& pub_key_compressed); + +// Extract the BIP340 x-only public key (32 bytes). +error_t extract_public_key_xonly(mem_t key_blob, buf_t& pub_key_xonly); + +// --------------------------------------------------------------------------- +// Key blob manipulation (private scalar backup / restore) +// --------------------------------------------------------------------------- + +// Get this party's share public point (Qi) from a key blob, returning SEC1 +// compressed point encoding (33 bytes). +// +// This is useful for verifiable backup schemes like PVE, where a scalar x can be +// verified against its corresponding curve point Q = x*G. +// +// Notes: +// - Intended for full key blobs. If you detached a scalar share, persist the +// returned public share point separately (or call this API before detaching). +error_t get_public_share_compressed(mem_t key_blob, buf_t& out_public_share_compressed); + +// Detach the private scalar share from a key blob, producing: +// - a "public" key blob with its private scalar removed, and +// - the private scalar x encoded as a fixed-length big-endian buffer. +// +// The public blob is safe to persist as public-only material, but is not usable +// for signing/refresh until restored with `attach_private_scalar`. +// +// Output: +// - `out_private_scalar_fixed` length equals the curve order size in bytes +// (32 bytes for secp256k1). +error_t detach_private_scalar(mem_t key_blob, buf_t& out_public_key_blob, buf_t& out_private_scalar_fixed); + +// Restore a full key blob by attaching a fixed-length private scalar x into a +// public key blob (produced by `detach_private_scalar`). +// +// This validates that x matches the expected share point by checking: +// x*G == public_share_compressed. +// +// Input: +// - `private_scalar_fixed` must be a fixed-length big-endian scalar encoding with +// length equal to the curve order size in bytes (32 bytes for secp256k1). +// - `public_share_compressed` must be the SEC1 compressed point encoding of this +// party's share public point (Qi_self), e.g. from `get_public_share_compressed`. +error_t attach_private_scalar(mem_t public_key_blob, mem_t private_scalar_fixed, mem_t public_share_compressed, + buf_t& out_key_blob); + +} // namespace coinbase::api::schnorr_2p diff --git a/include/cbmpc/api/schnorr_mp.h b/include/cbmpc/api/schnorr_mp.h new file mode 100644 index 00000000..ce88725f --- /dev/null +++ b/include/cbmpc/api/schnorr_mp.h @@ -0,0 +1,147 @@ +#pragma once + +#include +#include +#include +#include +#include + +// All the functions have two versions: additive and ac. Additive means that the +// sharing is additive, while ac means that the sharing is according to a given access structure. +namespace coinbase::api::schnorr_mp { + +// Run the multi-party key generation protocol. +// +// The output `key_blob` is a versioned, opaque byte string that can be persisted +// by the caller and used in other API calls. +// +// This wrapper implements the BIP340 Schnorr signature scheme, which is defined +// for secp256k1 only. +// +// Supported curves: `curve_id::secp256k1`. +error_t dkg_additive(const coinbase::api::job_mp_t& job, curve_id curve, buf_t& key_blob, buf_t& sid); + +// Run the multi-party key generation protocol with a general access +// structure. +// +// Notes: +// - This is an n-party protocol: **all** parties in +// `job.party_names` must be online and participate. +// - Only the provided `quorum_party_names` actively contribute to the generated +// key shares. +// - The output key blob represents an access-structure key share and +// is not directly usable with `sign_additive()`. Use `sign_ac()` to sign with an online +// quorum (it derives additive shares internally). +// - `sid` is an in/out session id used by the protocol. Callers may pass an +// empty buffer to let the protocol derive one. +// +// Supported curves: `curve_id::secp256k1`. +error_t dkg_ac(const coinbase::api::job_mp_t& job, curve_id curve, buf_t& sid, + const access_structure_t& access_structure, const std::vector& quorum_party_names, + buf_t& key_blob); + +// Refresh an existing key share set, producing a new key share. +// +// `sid` is an in/out session id used by the refresh protocol. Callers may pass +// an empty buffer to let the protocol derive one. +error_t refresh_additive(const coinbase::api::job_mp_t& job, buf_t& sid, mem_t key_blob, buf_t& new_key_blob); + +// Refresh an existing key share set using the access-structure refresh protocol. +// +// Notes: +// - See `dkg_ac` for protocol participation semantics. +// - The output key blob represents an access-structure key share and +// is not directly usable with `sign_additive()`. Use `sign_ac()` to sign with an online +// quorum (it derives additive shares internally). +// - `sid` is an in/out session id used by the protocol. Callers may pass an +// empty buffer to let the protocol derive one. +error_t refresh_ac(const coinbase::api::job_mp_t& job, buf_t& sid, mem_t key_blob, + const access_structure_t& access_structure, const std::vector& quorum_party_names, + buf_t& new_key_blob); + +// Sign a message with Schnorr-MP (BIP340 on secp256k1) and output signature on +// `sig_receiver`. +// +// Input: +// - `msg` must be exactly 32 bytes (BIP340 message digest). +// +// Output signature is 64 bytes: r_x (32 bytes) || s (32 bytes). +// +// Note: the underlying protocol returns the signature only on `sig_receiver`. On +// other parties, `sig` may be left empty on success. +error_t sign_additive(const coinbase::api::job_mp_t& job, mem_t key_blob, mem_t msg, party_idx_t sig_receiver, + buf_t& sig); + +// Sign a message with Schnorr-MP (BIP340) using an access-structure key share (from +// `dkg_ac` / `refresh_ac`). +// +// This API first derives an additive-share signing key for the **online** signing +// parties in `job.party_names` and then runs the normal `sign_additive()` protocol among +// those parties. +// +// Notes: +// - Unlike `dkg_ac` / `refresh_ac`, `sign_ac` only requires the parties in +// `job.party_names` to be online and participate. +// - Output semantics match `sign_additive()`: the signature is returned only on +// `sig_receiver`. On other parties, `sig` may be left empty on success. +error_t sign_ac(const coinbase::api::job_mp_t& job, mem_t ac_key_blob, const access_structure_t& access_structure, + mem_t msg, party_idx_t sig_receiver, buf_t& sig); + +// Get the public key from a key blob. +// +// Notes: +// - For BIP340, the "standard" public key encoding is x-only (32 bytes). This +// API provides both x-only and SEC1 compressed encodings for convenience. +// - Output is deterministic and derived from the key blob contents. +// +// `pub_key_compressed` is a SEC1 compressed point encoding (33 bytes) on +// secp256k1: 0x02/0x03 || x (32 bytes). +error_t get_public_key_compressed(mem_t key_blob, buf_t& pub_key_compressed); + +// Extract the BIP340 x-only public key (32 bytes). +error_t extract_public_key_xonly(mem_t key_blob, buf_t& pub_key_xonly); + +// --------------------------------------------------------------------------- +// Key blob manipulation (private scalar backup / restore) +// --------------------------------------------------------------------------- + +// Get this party's share public point (Qi) from a key blob, returning SEC1 +// compressed point encoding. +// +// This is useful for verifiable backup schemes like PVE, where a scalar x can be +// verified against its corresponding curve point Q = x*G. +// +// Notes: +// - Works for both additive (v1) and access-structure (v2) Schnorr-MP key blobs. +// - This can be called on a blob produced by `detach_private_scalar` since Qi_self +// remains present in the key blob even after redaction. +error_t get_public_share_compressed(mem_t key_blob, buf_t& out_public_share_compressed); + +// Detach the private scalar share from a key blob, producing: +// - a "public" key blob with its private scalar wiped, and +// - the private scalar x encoded as a fixed-length big-endian buffer. +// +// The public blob is safe to persist as public-only material, but is not usable +// for signing/refresh until restored with `attach_private_scalar`. +// +// Output: +// - `out_private_scalar_fixed` length equals the curve order size in bytes +// (32 bytes for secp256k1). +error_t detach_private_scalar(mem_t key_blob, buf_t& out_public_key_blob, buf_t& out_private_scalar_fixed); + +// Restore a full key blob by attaching a fixed-length private scalar x into a +// public key blob (produced by `detach_private_scalar`). +// +// This validates that x matches the key blob by checking: +// - `public_share_compressed` matches the blob's Qi_self, and +// - x*G == Qi_self. +// +// Input: +// - `private_scalar_fixed` must be a fixed-length big-endian scalar encoding with +// length equal to the curve order size in bytes (32 bytes for secp256k1). +// - `public_share_compressed` must be the SEC1 compressed point encoding of +// this party's share public point (Qi_self), e.g. from `get_public_share_compressed`. +error_t attach_private_scalar(mem_t public_key_blob, mem_t private_scalar_fixed, mem_t public_share_compressed, + buf_t& out_key_blob); + +} // namespace coinbase::api::schnorr_mp diff --git a/include/cbmpc/api/tdh2.h b/include/cbmpc/api/tdh2.h new file mode 100644 index 00000000..de6dbf84 --- /dev/null +++ b/include/cbmpc/api/tdh2.h @@ -0,0 +1,75 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +// All the functions have two versions: additive and ac. Additive means that the +// sharing is additive, while ac means that the sharing is according to a given access structure. +namespace coinbase::api::tdh2 { + +// Run the multi-party key generation protocol for TDH2 (additive shares). +// +// Outputs: +// - `sid`: session id used by the protocol (output on all parties). +// - `public_key`: TDH2 public key blob (same on all parties). +// - `public_shares`: compressed public shares Qi (same on all parties), ordered +// to match `job.party_names` (index i corresponds to `job.party_names[i]`). +// - `private_share`: opaque, versioned private share blob for this party only. +// +// Supported curves: `curve_id::p256`, `curve_id::secp256k1`. +error_t dkg_additive(const coinbase::api::job_mp_t& job, curve_id curve, buf_t& public_key, + std::vector& public_shares, buf_t& private_share, buf_t& sid); + +// Run the multi-party key generation protocol for TDH2 with a +// general access structure. +// +// Notes: +// - This is an n-party protocol (all parties in `job.party_names` +// participate), but only the provided `quorum_party_names` actively contribute +// to the generated key shares. +// - `sid` is an in/out session id used by the protocol; callers may pass an +// empty buffer and let the protocol derive one. +// +// Output semantics match `dkg_additive()`. +// +// Supported curves: `curve_id::p256`, `curve_id::secp256k1`. +error_t dkg_ac(const coinbase::api::job_mp_t& job, curve_id curve, buf_t& sid, + const access_structure_t& access_structure, const std::vector& quorum_party_names, + buf_t& public_key, std::vector& public_shares, buf_t& private_share); + +// Encrypt using a TDH2 public key. +error_t encrypt(mem_t public_key, mem_t plaintext, mem_t label, buf_t& ciphertext); + +// Verify ciphertext validity for a given public key and label. +error_t verify(mem_t public_key, mem_t ciphertext, mem_t label); + +// Locally compute a partial decryption from a private share. +error_t partial_decrypt(mem_t private_share, mem_t ciphertext, mem_t label, buf_t& partial_decryption); + +// Combine additive shares + partial decryptions to decrypt. +error_t combine_additive(mem_t public_key, const std::vector& public_shares, mem_t label, + const std::vector& partial_decryptions, mem_t ciphertext, buf_t& plaintext); + +// Combine access-structure shares + partial decryptions to decrypt. +// +// - `party_names` and `public_shares` define the mapping name -> Qi for *all* +// parties in the access structure. +// - `partial_decryption_party_names` and `partial_decryptions` provide the quorum +// subset used for decryption. +// +// Requirements: +// - `party_names.size() == public_shares.size()` +// - `partial_decryption_party_names.size() == partial_decryptions.size()` +// - The leaf set of `access_structure` must match `party_names` exactly. +error_t combine_ac(const access_structure_t& access_structure, mem_t public_key, + const std::vector& party_names, const std::vector& public_shares, + mem_t label, const std::vector& partial_decryption_party_names, + const std::vector& partial_decryptions, mem_t ciphertext, buf_t& plaintext); + +} // namespace coinbase::api::tdh2 diff --git a/include/cbmpc/c_api/access_structure.h b/include/cbmpc/c_api/access_structure.h new file mode 100644 index 00000000..805c4fd2 --- /dev/null +++ b/include/cbmpc/c_api/access_structure.h @@ -0,0 +1,79 @@ +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// C representation of an access structure (a boolean / threshold policy tree). +// +// This type is used by access-structure protocols (e.g. *_dkg_ac, *_refresh_ac) +// to describe quorum conditions over party names. +// +// Encoding: +// - The access structure is represented as a rooted tree stored in flat arrays. +// - Each node specifies its children as a slice in the global `child_indices` array. +// - Leaf nodes carry a party name (NUL-terminated string). +// +// Requirements: +// - `root_index` must be a valid index in [0, nodes_count). +// - The node graph rooted at `root_index` must be a tree: +// - no cycles +// - no node reuse (DAGs are rejected) +// - all nodes must be reachable from the root (unreachable nodes are rejected) +// - The root node must not be a leaf. +// - Leaf nodes: +// - `leaf_name` must be non-NULL and non-empty +// - `child_indices_count` must be 0 +// - `threshold_k` must be 0 +// - AND/OR nodes: +// - `leaf_name` must be NULL +// - `child_indices_count` must be > 0 +// - `threshold_k` must be 0 +// - THRESHOLD nodes: +// - `leaf_name` must be NULL +// - `child_indices_count` must be > 0 +// - `threshold_k` must satisfy 1 <= threshold_k <= child_indices_count +// +// Memory/lifetime: +// - This is a view type; it does not own any memory. +// - All pointers (including `leaf_name`) must remain valid for the duration of +// the protocol call that consumes the access structure. +typedef enum cbmpc_access_structure_node_type_e { + CBMPC_ACCESS_STRUCTURE_NODE_LEAF = 1, + CBMPC_ACCESS_STRUCTURE_NODE_AND = 2, + CBMPC_ACCESS_STRUCTURE_NODE_OR = 3, + CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD = 4, +} cbmpc_access_structure_node_type_t; + +typedef struct cbmpc_access_structure_node_t { + cbmpc_access_structure_node_type_t type; + + // Leaf-only: party name (NUL-terminated). + // Must be NULL for non-leaf nodes. + const char* leaf_name; + + // Threshold-only: k in THRESHOLD[k](...). + // Must be 0 for non-threshold nodes. + int32_t threshold_k; + + // Children slice (global indices into `cbmpc_access_structure_t::nodes`). + int32_t child_indices_offset; + int32_t child_indices_count; +} cbmpc_access_structure_node_t; + +typedef struct cbmpc_access_structure_t { + const cbmpc_access_structure_node_t* nodes; + int32_t nodes_count; + + // Flattened concatenation of all child index lists. + const int32_t* child_indices; + int32_t child_indices_count; + + int32_t root_index; +} cbmpc_access_structure_t; + +#ifdef __cplusplus +} +#endif diff --git a/include/cbmpc/c_api/cmem.h b/include/cbmpc/c_api/cmem.h new file mode 100644 index 00000000..01f1a8dd --- /dev/null +++ b/include/cbmpc/c_api/cmem.h @@ -0,0 +1,47 @@ +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// `cmem_t` is the C ABI representation of a byte slice. +// +// Invariants: +// - `size` must be non-negative. +// - If `size > 0`, then `data` must be non-NULL. +// +// Ownership: +// - A `cmem_t` may be a view into caller-owned memory *or* an owning buffer +// allocated by the library or a callback, depending on the API contract. +typedef struct tag_cmem_t { + uint8_t* data; + int size; +} cmem_t; + +// `cmems_t` is the C ABI representation of a list of byte slices, flattened into +// one contiguous `data` buffer and a parallel `sizes[]` array. +// +// Invariants: +// - `count` must be non-negative. +// - If `count > 0`, then `sizes` must be non-NULL. +// - Each `sizes[i]` must be non-negative. +// - Let `total = sum_i sizes[i]`. +// - If `total > 0`, then `data` must be non-NULL. +// - If `total == 0`, then `data` may be NULL. +// +// Layout: +// - Slice i is stored at `data + offset` with length `sizes[i]`, where +// `offset = sum_{j < i} sizes[j]`. +// +// Ownership: same caveats as `cmem_t` (API-specific). +typedef struct tag_cmems_t { + int count; + uint8_t* data; + int* sizes; +} cmems_t; + +#ifdef __cplusplus +} +#endif diff --git a/include/cbmpc/c_api/common.h b/include/cbmpc/c_api/common.h new file mode 100644 index 00000000..81551bf5 --- /dev/null +++ b/include/cbmpc/c_api/common.h @@ -0,0 +1,69 @@ +#pragma once + +#include +#include + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef int cbmpc_error_t; + +enum { CBMPC_SUCCESS = 0 }; + +// Error codes. +// +// The C API returns a subset of cbmpc's internal error codes as integers. +// We expose the common values here so C/FFI consumers can interpret failures +// without including C++ headers. +// +// Error code encoding: +// 0xff000000 | (category << 16) | code +#define CBMPC_ERRCODE(category, code) ((cbmpc_error_t)(0xff000000u | (((uint32_t)(category)) << 16) | (uint32_t)(code))) +#define CBMPC_ECATEGORY(errcode) (((uint32_t)(errcode) >> 16) & 0x00ffu) + +#define CBMPC_ECATEGORY_GENERIC 0x01u +#define CBMPC_ECATEGORY_NETWORK 0x03u +#define CBMPC_ECATEGORY_CRYPTO 0x04u +#define CBMPC_ECATEGORY_OPENSSL 0x06u +#define CBMPC_ECATEGORY_CONTROL_FLOW 0x0au + +#define CBMPC_UNINITIALIZED_ERROR CBMPC_ERRCODE(CBMPC_ECATEGORY_GENERIC, 0x0000u) +#define CBMPC_E_GENERAL CBMPC_ERRCODE(CBMPC_ECATEGORY_GENERIC, 0x0001u) +#define CBMPC_E_BADARG CBMPC_ERRCODE(CBMPC_ECATEGORY_GENERIC, 0x0002u) +#define CBMPC_E_FORMAT CBMPC_ERRCODE(CBMPC_ECATEGORY_GENERIC, 0x0003u) +#define CBMPC_E_NOT_SUPPORTED CBMPC_ERRCODE(CBMPC_ECATEGORY_GENERIC, 0x0005u) +#define CBMPC_E_NOT_FOUND CBMPC_ERRCODE(CBMPC_ECATEGORY_GENERIC, 0x0006u) +#define CBMPC_E_INSUFFICIENT CBMPC_ERRCODE(CBMPC_ECATEGORY_GENERIC, 0x000cu) +#define CBMPC_E_RANGE CBMPC_ERRCODE(CBMPC_ECATEGORY_GENERIC, 0x0012u) + +#define CBMPC_E_NET_GENERAL CBMPC_ERRCODE(CBMPC_ECATEGORY_NETWORK, 0x0001u) + +// Crypto-category errors that can surface via the API wrappers. +#define CBMPC_E_CRYPTO CBMPC_ERRCODE(CBMPC_ECATEGORY_CRYPTO, 0x0001u) +#define CBMPC_E_ECDSA_2P_BIT_LEAK CBMPC_ERRCODE(CBMPC_ECATEGORY_CRYPTO, 0x0002u) + +typedef enum cbmpc_curve_id_e { + CBMPC_CURVE_P256 = 1, + CBMPC_CURVE_SECP256K1 = 2, + CBMPC_CURVE_ED25519 = 3, +} cbmpc_curve_id_t; + +// Memory helpers for the C API. +// +// - `cbmpc_malloc`/`cbmpc_free` are provided so FFI bindings (Go/Rust/...) can +// use the same allocator for buffers passed across the ABI boundary. +// - Any `cmem_t` returned by the library must be freed with `cbmpc_cmem_free` +// (which zeroizes the buffer contents before freeing). +// - Any `cmems_t` returned by the library must be freed with `cbmpc_cmems_free` +// (which zeroizes the flat `data` buffer before freeing). +void* cbmpc_malloc(size_t size); +void cbmpc_free(void* ptr); +void cbmpc_cmem_free(cmem_t mem); +void cbmpc_cmems_free(cmems_t mems); + +#ifdef __cplusplus +} +#endif diff --git a/include/cbmpc/c_api/ecdsa_2p.h b/include/cbmpc/c_api/ecdsa_2p.h new file mode 100644 index 00000000..f6d65099 --- /dev/null +++ b/include/cbmpc/c_api/ecdsa_2p.h @@ -0,0 +1,89 @@ +#pragma once + +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// Interactive key generation. +// +// Ownership: +// - On success, `out_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_key_blob)`. +// - On failure, `*out_key_blob` is set to `{NULL, 0}`. +// +// Supported curves: `CBMPC_CURVE_P256`, `CBMPC_CURVE_SECP256K1`. +cbmpc_error_t cbmpc_ecdsa_2p_dkg(const cbmpc_2pc_job_t* job, cbmpc_curve_id_t curve, cmem_t* out_key_blob); + +// Interactive key refresh. +// +// Ownership: +// - On success, `out_new_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_new_key_blob)`. +// - On failure, `*out_new_key_blob` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_ecdsa_2p_refresh(const cbmpc_2pc_job_t* job, cmem_t key_blob, cmem_t* out_new_key_blob); + +// Sign a message hash. Outputs DER signature and the session id used. +// +// Ownership: +// - On success, `sig_der_out->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*sig_der_out)`. +// - If `sid_out` is non-NULL, then on success `sid_out->data` is allocated by +// the library and must be freed with `cbmpc_cmem_free(*sid_out)`. +// - On failure, `*sig_der_out` (and `*sid_out` when provided) are set to +// `{NULL, 0}`. +// +// Note: the underlying protocol returns the DER signature only on P1. If +// `key_blob` belongs to P2, `*sig_der_out` may be `{NULL, 0}` on success. +cbmpc_error_t cbmpc_ecdsa_2p_sign(const cbmpc_2pc_job_t* job, cmem_t key_blob, cmem_t msg_hash, cmem_t sid_in, + cmem_t* sid_out, cmem_t* sig_der_out); + +// Compute compressed public key for a key blob. +// +// Ownership: +// - On success, `out_pub_key->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_pub_key)`. +// - On failure, `*out_pub_key` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_ecdsa_2p_get_public_key_compressed(cmem_t key_blob, cmem_t* out_pub_key); + +// --------------------------------------------------------------------------- +// Key blob manipulation (private scalar backup / restore) +// --------------------------------------------------------------------------- + +// Get this party's share public point (Qi) from a key blob, returning SEC1 +// compressed point encoding. +// +// Ownership: same as `cbmpc_ecdsa_2p_get_public_key_compressed`. +cbmpc_error_t cbmpc_ecdsa_2p_get_public_share_compressed(cmem_t key_blob, cmem_t* out_public_share); + +// Detach a key blob into a scalar-removed blob + private scalar. +// +// Notes: +// - Unlike ECDSA-MP, the scalar encoding is NOT fixed-length: after refresh, +// ECDSA-2PC keeps the share as a Paillier-compatible integer representative and +// it may grow. +// +// Ownership: +// - On success, `out_public_key_blob->data` and `out_private_scalar->data` are +// allocated by the library and must be freed with `cbmpc_cmem_free(...)`. +// - On failure, outputs are set to `{NULL, 0}`. +cbmpc_error_t cbmpc_ecdsa_2p_detach_private_scalar(cmem_t key_blob, cmem_t* out_public_key_blob, + cmem_t* out_private_scalar); + +// Attach a variable-length private scalar into a scalar-removed key blob, +// validating it against the expected public share point. +// +// Ownership: +// - On success, `out_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_key_blob)`. +// - On failure, `*out_key_blob` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_ecdsa_2p_attach_private_scalar(cmem_t public_key_blob, cmem_t private_scalar, + cmem_t public_share_compressed, cmem_t* out_key_blob); + +#ifdef __cplusplus +} +#endif diff --git a/include/cbmpc/c_api/ecdsa_mp.h b/include/cbmpc/c_api/ecdsa_mp.h new file mode 100644 index 00000000..256ffa2f --- /dev/null +++ b/include/cbmpc/c_api/ecdsa_mp.h @@ -0,0 +1,163 @@ +#pragma once + +#include + +#include +#include +#include + +// All the functions have two versions: additive and ac. Additive means that the +// sharing is additive, while ac means that the sharing is according to a given access structure. +#ifdef __cplusplus +extern "C" { +#endif + +// Interactive multi-party key generation. +// +// Ownership: +// - On success, `out_key_blob->data` and `out_sid->data` are allocated by the +// library and must be freed with `cbmpc_cmem_free(...)`. +// - On failure, `*out_key_blob` and `*out_sid` are set to `{NULL, 0}`. +// +// Supported curves: `CBMPC_CURVE_P256`, `CBMPC_CURVE_SECP256K1`. +cbmpc_error_t cbmpc_ecdsa_mp_dkg_additive(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t* out_key_blob, + cmem_t* out_sid); + +// Interactive multi-party key generation with a general access structure. +// +// Notes: +// - This is an n-party protocol: **all** parties in `job->party_names` must be +// online and participate. +// - Only the provided `quorum_party_names` actively contribute to the generated +// key shares. +// - The output key blob represents an access-structure key share and +// is not directly usable with `cbmpc_ecdsa_mp_sign_additive`. Use `cbmpc_ecdsa_mp_sign_ac` +// to sign with an online quorum (it derives additive shares internally). +// - `sid_in` is the in/out session id used by the protocol; callers may pass +// `{NULL, 0}` and let the protocol derive one. +// +// Ownership: +// - On success, `out_ac_key_blob->data` and `out_sid->data` are allocated by +// the library and must be freed with `cbmpc_cmem_free(...)`. +// - On failure, `*out_ac_key_blob` and `*out_sid` are set to `{NULL, 0}`. +// +// Supported curves: `CBMPC_CURVE_P256`, `CBMPC_CURVE_SECP256K1`. +cbmpc_error_t cbmpc_ecdsa_mp_dkg_ac(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t sid_in, + const cbmpc_access_structure_t* access_structure, + const char* const* quorum_party_names, int quorum_party_names_count, + cmem_t* out_ac_key_blob, cmem_t* out_sid); + +// Interactive multi-party key refresh (same public key). +// +// `sid_in` is the in/out session id used by the refresh protocol; callers may +// pass `{NULL, 0}` and let the protocol derive one. If `sid_out` is non-NULL, it +// will receive the session id used. +// +// Ownership: +// - On success, `out_new_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_new_key_blob)`. +// - If `sid_out` is non-NULL, then on success `sid_out->data` is allocated by +// the library and must be freed with `cbmpc_cmem_free(*sid_out)`. +// - On failure, `*out_new_key_blob` (and `*sid_out` when provided) are set to +// `{NULL, 0}`. +cbmpc_error_t cbmpc_ecdsa_mp_refresh_additive(const cbmpc_mp_job_t* job, cmem_t sid_in, cmem_t key_blob, + cmem_t* sid_out, cmem_t* out_new_key_blob); + +// Interactive multi-party access-structure key refresh (same public key). +// +// Notes: +// - See `cbmpc_ecdsa_mp_dkg_ac` for protocol participation semantics. +// - The output key blob represents an access-structure key share and +// is not directly usable with `cbmpc_ecdsa_mp_sign_additive`. Use `cbmpc_ecdsa_mp_sign_ac` +// to sign with an online quorum (it derives additive shares internally). +// - `sid_in` is the in/out session id used by the refresh protocol; callers may +// pass `{NULL, 0}` and let the protocol derive one. If `sid_out` is non-NULL, it +// will receive the session id used. +// +// Ownership: +// - On success, `out_new_ac_key_blob->data` is allocated by the library and +// must be freed with `cbmpc_cmem_free(*out_new_ac_key_blob)`. +// - If `sid_out` is non-NULL, then on success `sid_out->data` is allocated by the +// library and must be freed with `cbmpc_cmem_free(*sid_out)`. +// - On failure, `*out_new_ac_key_blob` (and `*sid_out` when provided) are +// set to `{NULL, 0}`. +cbmpc_error_t cbmpc_ecdsa_mp_refresh_ac(const cbmpc_mp_job_t* job, cmem_t sid_in, cmem_t ac_key_blob, + const cbmpc_access_structure_t* access_structure, + const char* const* quorum_party_names, int quorum_party_names_count, + cmem_t* sid_out, cmem_t* out_new_ac_key_blob); + +// Sign a message hash with ECDSA-MP. Outputs DER signature on `sig_receiver`. +// +// Ownership: +// - On success, `sig_der_out->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*sig_der_out)`. +// - On failure, `*sig_der_out` is set to `{NULL, 0}`. +// +// Note: the underlying protocol returns the DER signature only on +// `sig_receiver`. On other parties, `*sig_der_out` may be `{NULL, 0}` on +// success. +cbmpc_error_t cbmpc_ecdsa_mp_sign_additive(const cbmpc_mp_job_t* job, cmem_t key_blob, cmem_t msg_hash, + int32_t sig_receiver, cmem_t* sig_der_out); + +// Sign a message hash with ECDSA-MP using an access-structure key share (from +// `cbmpc_ecdsa_mp_dkg_ac` / `cbmpc_ecdsa_mp_refresh_ac`). +// +// This API first derives an additive-share signing key for the **online** signing +// parties in `job->party_names` and then runs the normal `cbmpc_ecdsa_mp_sign_additive` +// protocol among those parties. +// +// Notes: +// - Unlike `cbmpc_ecdsa_mp_dkg_ac` / `cbmpc_ecdsa_mp_refresh_ac`, `cbmpc_ecdsa_mp_sign_ac` +// only requires the parties in `job->party_names` to be online and participate. +// - Output semantics match `cbmpc_ecdsa_mp_sign_additive`: the signature is returned only +// on `sig_receiver`. On other parties, `*sig_der_out` may be `{NULL, 0}` on +// success. +// +// Ownership: same as `cbmpc_ecdsa_mp_sign_additive`. +cbmpc_error_t cbmpc_ecdsa_mp_sign_ac(const cbmpc_mp_job_t* job, cmem_t ac_key_blob, + const cbmpc_access_structure_t* access_structure, cmem_t msg_hash, + int32_t sig_receiver, cmem_t* sig_der_out); + +// Compute compressed public key for a key blob (SEC1 compressed point). +// +// Ownership: +// - On success, `out_pub_key->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_pub_key)`. +// - On failure, `*out_pub_key` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_ecdsa_mp_get_public_key_compressed(cmem_t key_blob, cmem_t* out_pub_key); + +// --------------------------------------------------------------------------- +// Key blob manipulation (private share backup / restore) +// --------------------------------------------------------------------------- + +// Get this party's share public point (Qi) from a key blob, returning SEC1 +// compressed point encoding. +// +// Ownership: +// - On success, `out_public_share->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_public_share)`. +// - On failure, `*out_public_share` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_ecdsa_mp_get_public_share_compressed(cmem_t key_blob, cmem_t* out_public_share); + +// Detach a key blob into a public blob + fixed-length private scalar. +// +// Ownership: +// - On success, `out_public_key_blob->data` and `out_private_scalar_fixed->data` are +// allocated by the library and must be freed with `cbmpc_cmem_free(...)`. +// - On failure, outputs are set to `{NULL, 0}`. +cbmpc_error_t cbmpc_ecdsa_mp_detach_private_scalar(cmem_t key_blob, cmem_t* out_public_key_blob, + cmem_t* out_private_scalar_fixed); + +// Attach a fixed-length private scalar into a public key blob, validating it +// against the expected public share point. +// +// Ownership: +// - On success, `out_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_key_blob)`. +// - On failure, `*out_key_blob` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_ecdsa_mp_attach_private_scalar(cmem_t public_key_blob, cmem_t private_scalar_fixed, + cmem_t public_share_compressed, cmem_t* out_key_blob); + +#ifdef __cplusplus +} +#endif diff --git a/include/cbmpc/c_api/eddsa_2p.h b/include/cbmpc/c_api/eddsa_2p.h new file mode 100644 index 00000000..f5fde1ca --- /dev/null +++ b/include/cbmpc/c_api/eddsa_2p.h @@ -0,0 +1,86 @@ +#pragma once + +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// Interactive key generation for EdDSA-2PC (Ed25519). +// +// Ownership: +// - On success, `out_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_key_blob)`. +// - On failure, `*out_key_blob` is set to `{NULL, 0}`. +// +// Supported curves: `CBMPC_CURVE_ED25519`. +cbmpc_error_t cbmpc_eddsa_2p_dkg(const cbmpc_2pc_job_t* job, cbmpc_curve_id_t curve, cmem_t* out_key_blob); + +// Interactive key refresh (same public key). +// +// Ownership: +// - On success, `out_new_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_new_key_blob)`. +// - On failure, `*out_new_key_blob` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_eddsa_2p_refresh(const cbmpc_2pc_job_t* job, cmem_t key_blob, cmem_t* out_new_key_blob); + +// Sign a message with EdDSA-2PC (Ed25519). Outputs a 64-byte signature: R || s. +// +// Ownership: +// - On success, `sig_out->data` is allocated by the library and must be freed +// with `cbmpc_cmem_free(*sig_out)`. +// - On failure, `*sig_out` is set to `{NULL, 0}`. +// +// Note: the underlying protocol returns the signature only on P1. If `key_blob` +// belongs to P2, `*sig_out` may be `{NULL, 0}` on success. +cbmpc_error_t cbmpc_eddsa_2p_sign(const cbmpc_2pc_job_t* job, cmem_t key_blob, cmem_t msg, cmem_t* sig_out); + +// Get the Ed25519 public key from a key blob. +// +// Ownership: +// - On success, `out_pub_key->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_pub_key)`. +// - On failure, `*out_pub_key` is set to `{NULL, 0}`. +// +// Output is the standard Ed25519 32-byte compressed public key encoding. +// +// Note: Ed25519 public keys are always encoded in this compressed format; the +// `_compressed` suffix is provided for naming consistency with ECDSA APIs. +cbmpc_error_t cbmpc_eddsa_2p_get_public_key_compressed(cmem_t key_blob, cmem_t* out_pub_key); + +// --------------------------------------------------------------------------- +// Key blob manipulation (private scalar backup / restore) +// --------------------------------------------------------------------------- + +// Get this party's share public point (Qi) from a key blob. +// +// Output is the standard Ed25519 32-byte compressed point encoding. +// +// Ownership: same as `cbmpc_eddsa_2p_get_public_key_compressed`. +cbmpc_error_t cbmpc_eddsa_2p_get_public_share_compressed(cmem_t key_blob, cmem_t* out_public_share); + +// Detach a key blob into a public blob + fixed-length private scalar. +// +// Ownership: +// - On success, `out_public_key_blob->data` and `out_private_scalar_fixed->data` are +// allocated by the library and must be freed with `cbmpc_cmem_free(...)`. +// - On failure, outputs are set to `{NULL, 0}`. +cbmpc_error_t cbmpc_eddsa_2p_detach_private_scalar(cmem_t key_blob, cmem_t* out_public_key_blob, + cmem_t* out_private_scalar_fixed); + +// Attach a fixed-length private scalar into a public key blob, validating it +// against the expected public share point. +// +// Ownership: +// - On success, `out_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_key_blob)`. +// - On failure, `*out_key_blob` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_eddsa_2p_attach_private_scalar(cmem_t public_key_blob, cmem_t private_scalar_fixed, + cmem_t public_share_compressed, cmem_t* out_key_blob); + +#ifdef __cplusplus +} +#endif diff --git a/include/cbmpc/c_api/eddsa_mp.h b/include/cbmpc/c_api/eddsa_mp.h new file mode 100644 index 00000000..f29a817b --- /dev/null +++ b/include/cbmpc/c_api/eddsa_mp.h @@ -0,0 +1,159 @@ +#pragma once + +#include + +#include +#include +#include + +// All the functions have two versions: additive and ac. Additive means that the +// sharing is additive, while ac means that the sharing is according to a given access structure. +#ifdef __cplusplus +extern "C" { +#endif + +// Multi-party key generation for EdDSA-MP (Ed25519). +// +// Ownership: +// - On success, `out_key_blob->data` and `out_sid->data` are allocated by the +// library and must be freed with `cbmpc_cmem_free(...)`. +// - On failure, `*out_key_blob` and `*out_sid` are set to `{NULL, 0}`. +// +// Supported curves: `CBMPC_CURVE_ED25519`. +cbmpc_error_t cbmpc_eddsa_mp_dkg_additive(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t* out_key_blob, + cmem_t* out_sid); + +// Multi-party key generation for EdDSA-MP (Ed25519) with a general access +// structure. +// +// Notes: +// - This is an n-party protocol: **all** parties in `job->party_names` must be +// online and participate. +// - Only the provided `quorum_party_names` actively contribute to the generated +// key shares. +// - The output key blob represents an access-structure key share and +// is not directly usable with `cbmpc_eddsa_mp_sign_additive`. Use `cbmpc_eddsa_mp_sign_ac` +// to sign with an online quorum (it derives additive shares internally). +// - `sid_in` is the in/out session id used by the protocol; callers may pass +// `{NULL, 0}` and let the protocol derive one. +// +// Ownership: +// - On success, `out_ac_key_blob->data` and `out_sid->data` are allocated by +// the library and must be freed with `cbmpc_cmem_free(...)`. +// - On failure, `*out_ac_key_blob` and `*out_sid` are set to `{NULL, 0}`. +// +// Supported curves: `CBMPC_CURVE_ED25519`. +cbmpc_error_t cbmpc_eddsa_mp_dkg_ac(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t sid_in, + const cbmpc_access_structure_t* access_structure, + const char* const* quorum_party_names, int quorum_party_names_count, + cmem_t* out_ac_key_blob, cmem_t* out_sid); + +// Multi-party key refresh (same public key). +// +// `sid_in` is the in/out session id used by the refresh protocol; callers may +// pass `{NULL, 0}` and let the protocol derive one. If `sid_out` is non-NULL, it +// will receive the session id used. +// +// Ownership: +// - On success, `out_new_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_new_key_blob)`. +// - If `sid_out` is non-NULL, then on success `sid_out->data` is allocated by +// the library and must be freed with `cbmpc_cmem_free(*sid_out)`. +// - On failure, `*out_new_key_blob` (and `*sid_out` when provided) are set to +// `{NULL, 0}`. +cbmpc_error_t cbmpc_eddsa_mp_refresh_additive(const cbmpc_mp_job_t* job, cmem_t sid_in, cmem_t key_blob, + cmem_t* sid_out, cmem_t* out_new_key_blob); + +// Multi-party access-structure key refresh (same public key). +// +// Notes: +// - See `cbmpc_eddsa_mp_dkg_ac` for protocol participation semantics. +// - The output key blob represents an access-structure key share and +// is not directly usable with `cbmpc_eddsa_mp_sign_additive`. Use `cbmpc_eddsa_mp_sign_ac` +// to sign with an online quorum (it derives additive shares internally). +// - `sid_in` is the in/out session id used by the refresh protocol; callers may +// pass `{NULL, 0}` and let the protocol derive one. If `sid_out` is non-NULL, it +// will receive the session id used. +// +// Ownership: same as `cbmpc_eddsa_mp_refresh_additive`. +cbmpc_error_t cbmpc_eddsa_mp_refresh_ac(const cbmpc_mp_job_t* job, cmem_t sid_in, cmem_t ac_key_blob, + const cbmpc_access_structure_t* access_structure, + const char* const* quorum_party_names, int quorum_party_names_count, + cmem_t* sid_out, cmem_t* out_new_ac_key_blob); + +// Sign a message with EdDSA-MP (Ed25519). Outputs a 64-byte signature (R || s) on `sig_receiver`. +// +// Ownership: +// - On success, `sig_out->data` is allocated by the library and must be freed +// with `cbmpc_cmem_free(*sig_out)`. +// - On failure, `*sig_out` is set to `{NULL, 0}`. +// +// Note: the underlying protocol returns the signature only on `sig_receiver`. On +// other parties, `*sig_out` may be `{NULL, 0}` on success. +cbmpc_error_t cbmpc_eddsa_mp_sign_additive(const cbmpc_mp_job_t* job, cmem_t key_blob, cmem_t msg, int32_t sig_receiver, + cmem_t* sig_out); + +// Sign a message with EdDSA-MP (Ed25519) using an access-structure key share (from +// `cbmpc_eddsa_mp_dkg_ac` / `cbmpc_eddsa_mp_refresh_ac`). +// +// This API first derives an additive-share signing key for the **online** signing +// parties in `job->party_names` and then runs the normal `cbmpc_eddsa_mp_sign_additive` +// protocol among those parties. +// +// Notes: +// - Unlike `cbmpc_eddsa_mp_dkg_ac` / `cbmpc_eddsa_mp_refresh_ac`, `cbmpc_eddsa_mp_sign_ac` +// only requires the parties in `job->party_names` to be online and participate. +// - Output semantics match `cbmpc_eddsa_mp_sign_additive`: the signature is returned only +// on `sig_receiver`. On other parties, `*sig_out` may be `{NULL, 0}` on success. +// +// Ownership: same as `cbmpc_eddsa_mp_sign_additive`. +cbmpc_error_t cbmpc_eddsa_mp_sign_ac(const cbmpc_mp_job_t* job, cmem_t ac_key_blob, + const cbmpc_access_structure_t* access_structure, cmem_t msg, int32_t sig_receiver, + cmem_t* sig_out); + +// Get the Ed25519 public key from a key blob. +// +// Ownership: +// - On success, `out_pub_key->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_pub_key)`. +// - On failure, `*out_pub_key` is set to `{NULL, 0}`. +// +// Output is the standard Ed25519 32-byte compressed public key encoding. +// +// Note: Ed25519 public keys are always encoded in this compressed format; the +// `_compressed` suffix is provided for naming consistency with ECDSA APIs. +cbmpc_error_t cbmpc_eddsa_mp_get_public_key_compressed(cmem_t key_blob, cmem_t* out_pub_key); + +// --------------------------------------------------------------------------- +// Key blob manipulation (private scalar backup / restore) +// --------------------------------------------------------------------------- + +// Get this party's share public point (Qi) from a key blob. +// +// Output is the standard Ed25519 32-byte compressed point encoding. +// +// Ownership: same as `cbmpc_eddsa_mp_get_public_key_compressed`. +cbmpc_error_t cbmpc_eddsa_mp_get_public_share_compressed(cmem_t key_blob, cmem_t* out_public_share); + +// Detach a key blob into a public blob + fixed-length private scalar. +// +// Ownership: +// - On success, `out_public_key_blob->data` and `out_private_scalar_fixed->data` are +// allocated by the library and must be freed with `cbmpc_cmem_free(...)`. +// - On failure, outputs are set to `{NULL, 0}`. +cbmpc_error_t cbmpc_eddsa_mp_detach_private_scalar(cmem_t key_blob, cmem_t* out_public_key_blob, + cmem_t* out_private_scalar_fixed); + +// Attach a fixed-length private scalar into a public key blob, validating it +// against the expected public share point. +// +// Ownership: +// - On success, `out_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_key_blob)`. +// - On failure, `*out_key_blob` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_eddsa_mp_attach_private_scalar(cmem_t public_key_blob, cmem_t private_scalar_fixed, + cmem_t public_share_compressed, cmem_t* out_key_blob); + +#ifdef __cplusplus +} +#endif diff --git a/include/cbmpc/c_api/job.h b/include/cbmpc/c_api/job.h new file mode 100644 index 00000000..63c0620d --- /dev/null +++ b/include/cbmpc/c_api/job.h @@ -0,0 +1,75 @@ +#pragma once + +#include + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef cbmpc_error_t (*cbmpc_transport_send_fn)(void* ctx, int32_t receiver, const uint8_t* data, int size); +typedef cbmpc_error_t (*cbmpc_transport_receive_fn)(void* ctx, int32_t sender, cmem_t* out_msg); +typedef cbmpc_error_t (*cbmpc_transport_receive_all_fn)(void* ctx, const int32_t* senders, int senders_count, + cmems_t* out_msgs); +typedef void (*cbmpc_transport_free_fn)(void* ctx, void* ptr); + +// Transport callbacks used by interactive protocols (e.g., ECDSA-2PC, ECDSA-MP). +// +// - `send` does not take ownership of `data`; the caller retains ownership. +// - `receive`/`receive_all` must return message buffers allocated by the caller. +// - The library will free returned buffers with `free(ctx, ptr)` when provided, +// otherwise it will call `cbmpc_free(ptr)`. +// - For `receive_all`, the callback must allocate **both** `out_msgs->data` and +// `out_msgs->sizes` (typically as two independent allocations). +// - On success, `receive_all` must set `out_msgs->count == senders_count` and +// return messages in the same order as `senders[]`. +// - On non-zero return, the library may still free any non-NULL output pointers +// set by the callback as a best-effort to avoid leaks. +typedef struct cbmpc_transport_t { + void* ctx; + cbmpc_transport_send_fn send; + cbmpc_transport_receive_fn receive; + cbmpc_transport_receive_all_fn receive_all; // optional (required by MP protocols, e.g., ECDSA-MP) + cbmpc_transport_free_fn free; // optional (defaults to cbmpc_free) +} cbmpc_transport_t; + +typedef enum cbmpc_2pc_party_e { + CBMPC_2PC_P1 = 0, + CBMPC_2PC_P2 = 1, +} cbmpc_2pc_party_t; + +// Execution context for 2-party protocols. +// +// Notes: +// - This is a lightweight view type. It does not own the transport object or +// the name strings. +// - `p1_name` and `p2_name` must be NUL-terminated strings. +// - The caller must ensure referenced objects (including the name strings) +// outlive the protocol call. +typedef struct cbmpc_2pc_job_t { + cbmpc_2pc_party_t self; + const char* p1_name; + const char* p2_name; + const cbmpc_transport_t* transport; +} cbmpc_2pc_job_t; + +// Execution context for multi-party protocols. +// +// Notes: +// - This is a lightweight view type. It does not own the transport object or +// the name strings. +// - `self` is the caller party index in `party_names`. +// - Each entry in `party_names` must be a NUL-terminated string. +// - The caller must ensure referenced objects (including the name strings) +// outlive the protocol call. +typedef struct cbmpc_mp_job_t { + int32_t self; + const char* const* party_names; + int party_names_count; + const cbmpc_transport_t* transport; +} cbmpc_mp_job_t; + +#ifdef __cplusplus +} +#endif diff --git a/include/cbmpc/c_api/pve_base_pke.h b/include/cbmpc/c_api/pve_base_pke.h new file mode 100644 index 00000000..f692650a --- /dev/null +++ b/include/cbmpc/c_api/pve_base_pke.h @@ -0,0 +1,220 @@ +#pragma once + +#include +#include + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// Pluggable base PKE callbacks for PVE (C API). +// +// Requirements: +// - `encrypt` must be deterministic given `rho`. +// - Any callback that writes `cmem_t* out` MUST allocate `out->data` with +// `cbmpc_malloc(out->size)`. cbmpc will free returned buffers with +// `cbmpc_cmem_free(...)` (zeroizes before free). +// +// Key format is defined by the callback implementation. +typedef cbmpc_error_t (*cbmpc_pve_base_pke_encrypt_fn)(void* ctx, cmem_t ek, cmem_t label, cmem_t plain, cmem_t rho, + cmem_t* out_ct); + +typedef cbmpc_error_t (*cbmpc_pve_base_pke_decrypt_fn)(void* ctx, cmem_t dk, cmem_t label, cmem_t ct, + cmem_t* out_plain); + +typedef struct cbmpc_pve_base_pke_t { + void* ctx; + cbmpc_pve_base_pke_encrypt_fn encrypt; + cbmpc_pve_base_pke_decrypt_fn decrypt; +} cbmpc_pve_base_pke_t; + +// --------------------------------------------------------------------------- +// Built-in base PKE software keys +// --------------------------------------------------------------------------- + +// Generate a keypair for cbmpc's built-in RSA-OAEP (2048-bit) base PKE. +// +// Outputs are opaque, versioned byte strings compatible with passing +// `base_pke = NULL` to `cbmpc_pve_encrypt/verify/decrypt`. +// +// Ownership: same as `cbmpc_pve_encrypt`. +cbmpc_error_t cbmpc_pve_generate_base_pke_rsa_keypair(cmem_t* out_ek, cmem_t* out_dk); + +// Generate a keypair for cbmpc's built-in ECIES (P-256) base PKE. +// +// Outputs are opaque, versioned byte strings compatible with passing +// `base_pke = NULL` to `cbmpc_pve_encrypt/verify/decrypt`. +// +// Ownership: same as `cbmpc_pve_encrypt`. +cbmpc_error_t cbmpc_pve_generate_base_pke_ecies_p256_keypair(cmem_t* out_ek, cmem_t* out_dk); + +// Build a cbmpc ECIES(P-256) base PKE public key blob from an external public key. +// +// This is useful when the private key lives in an HSM (or other external system) +// and only the public key can be exported to software. +// +// Input format: +// - `pub_key_oct` must be the *uncompressed* NIST P-256 public key octet string: +// 65 bytes: 0x04 || X(32) || Y(32). +// +// Output: +// - `out_ek` is an opaque, versioned byte string compatible with passing +// `base_pke = NULL` to `cbmpc_pve_encrypt/verify/decrypt`, and with +// `cbmpc_pve_decrypt_ecies_p256_hsm`. +// +// Ownership: same as `cbmpc_pve_encrypt`. +cbmpc_error_t cbmpc_pve_base_pke_ecies_p256_ek_from_oct(cmem_t pub_key_oct, cmem_t* out_ek); + +// Build a cbmpc RSA-OAEP(2048) base PKE public key blob from a raw modulus. +// +// This is useful when the private key lives in an HSM (or other external system) +// that exports only the raw modulus (e.g. YubiHSM 2). +// +// Input format: +// - `modulus` must be the big-endian RSA modulus (256 bytes for RSA-2048). +// - The public exponent is assumed to be 65537. +// +// Output: +// - `out_ek` is an opaque, versioned byte string compatible with passing +// `base_pke = NULL` to `cbmpc_pve_encrypt/verify/decrypt`, and with +// `cbmpc_pve_decrypt_rsa_oaep_hsm`. +// +// Ownership: same as `cbmpc_pve_encrypt`. +cbmpc_error_t cbmpc_pve_base_pke_rsa_ek_from_modulus(cmem_t modulus, cmem_t* out_ek); + +// --------------------------------------------------------------------------- +// Built-in base PKE HSM support (KEM decapsulation callbacks) +// --------------------------------------------------------------------------- + +// RSA-OAEP decapsulation callback. +// +// The callback must allocate `out_kem_ss->data` with `cbmpc_malloc`. +// The output is the KEM shared secret (OAEP decrypted value). +typedef cbmpc_error_t (*cbmpc_pve_rsa_oaep_hsm_decap_fn)(void* ctx, cmem_t dk_handle, cmem_t kem_ct, + cmem_t* out_kem_ss); + +typedef struct cbmpc_pve_rsa_oaep_hsm_decap_t { + void* ctx; + cbmpc_pve_rsa_oaep_hsm_decap_fn decap; +} cbmpc_pve_rsa_oaep_hsm_decap_t; + +// ECIES(P-256) ECDH callback. +// +// The callback must allocate `out_dh_x32->data` with `cbmpc_malloc`. +// The output must be exactly 32 bytes (affine-X coordinate, big-endian). +typedef cbmpc_error_t (*cbmpc_pve_ecies_p256_hsm_ecdh_fn)(void* ctx, cmem_t dk_handle, cmem_t kem_ct, + cmem_t* out_dh_x32); + +typedef struct cbmpc_pve_ecies_p256_hsm_ecdh_t { + void* ctx; + cbmpc_pve_ecies_p256_hsm_ecdh_fn ecdh; +} cbmpc_pve_ecies_p256_hsm_ecdh_t; + +// Decrypt using an HSM-backed RSA private key (KEM decapsulation callback). +// +// - `dk_handle` is an opaque handle understood by the callback. +// - `ek` is the software public key blob (used to validate key type and for optional pre-validation via +// `cbmpc_pve_verify`). +// +// Ownership: +// - On success, `out_x->data` is allocated by the library and must be freed with +// `cbmpc_cmem_free(*out_x)`. +// - On failure, `*out_x` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_pve_decrypt_rsa_oaep_hsm(cbmpc_curve_id_t curve, cmem_t dk_handle, cmem_t ek, cmem_t ciphertext, + cmem_t label, const cbmpc_pve_rsa_oaep_hsm_decap_t* cb, cmem_t* out_x); + +// Decrypt using an HSM-backed ECIES(P-256) private key (ECDH callback). +// +// - `dk_handle` is an opaque handle understood by the callback. +// - `ek` is the software public key blob (used for KEM context and optional pre-validation via `cbmpc_pve_verify`). +// +// Ownership: same as `cbmpc_pve_decrypt_rsa_oaep_hsm`. +cbmpc_error_t cbmpc_pve_decrypt_ecies_p256_hsm(cbmpc_curve_id_t curve, cmem_t dk_handle, cmem_t ek, cmem_t ciphertext, + cmem_t label, const cbmpc_pve_ecies_p256_hsm_ecdh_t* cb, cmem_t* out_x); + +// --------------------------------------------------------------------------- +// Custom KEM (library provides KEM/DEM transform) +// --------------------------------------------------------------------------- + +// Custom KEM encapsulation callback. +// +// - Must be deterministic given `rho32`. +// - Must allocate outputs with `cbmpc_malloc`. +typedef cbmpc_error_t (*cbmpc_pve_kem_encap_fn)(void* ctx, cmem_t ek, cmem_t rho32, cmem_t* out_kem_ct, + cmem_t* out_kem_ss); + +// Custom KEM decapsulation callback. +// +// - Must allocate outputs with `cbmpc_malloc`. +typedef cbmpc_error_t (*cbmpc_pve_kem_decap_fn)(void* ctx, cmem_t dk, cmem_t kem_ct, cmem_t* out_kem_ss); + +typedef struct cbmpc_pve_base_kem_t { + void* ctx; + cbmpc_pve_kem_encap_fn encap; + cbmpc_pve_kem_decap_fn decap; +} cbmpc_pve_base_kem_t; + +// Encrypt using a custom KEM (cbmpc provides HKDF + AES-GCM DEM). +// +// Ownership: same as `cbmpc_pve_encrypt`. +cbmpc_error_t cbmpc_pve_encrypt_with_kem(const cbmpc_pve_base_kem_t* kem, cbmpc_curve_id_t curve, cmem_t ek, + cmem_t label, cmem_t x, cmem_t* out_ciphertext); + +// Verify using a custom KEM (requires `kem->encap` for deterministic recomputation). +cbmpc_error_t cbmpc_pve_verify_with_kem(const cbmpc_pve_base_kem_t* kem, cbmpc_curve_id_t curve, cmem_t ek, + cmem_t ciphertext, cmem_t Q_compressed, cmem_t label); + +// Decrypt using a custom KEM (cbmpc provides HKDF + AES-GCM DEM). +// +// Ownership: same as `cbmpc_pve_decrypt`. +cbmpc_error_t cbmpc_pve_decrypt_with_kem(const cbmpc_pve_base_kem_t* kem, cbmpc_curve_id_t curve, cmem_t dk, cmem_t ek, + cmem_t ciphertext, cmem_t label, cmem_t* out_x); + +// Encrypt a scalar `x` under base encryption key `ek`, producing a PVE ciphertext. +// +// If `base_pke` is NULL, cbmpc uses its built-in default base PKE. +// +// Ownership: +// - On success, `out_ciphertext->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_ciphertext)`. +// - On failure, `*out_ciphertext` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_pve_encrypt(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, cmem_t ek, cmem_t label, + cmem_t x, cmem_t* out_ciphertext); + +// Verify ciphertext validity against the expected Q and label. +// +// - `Q_compressed` is the compressed point encoding for `curve`. +// - If `base_pke` is NULL, cbmpc uses its built-in default base PKE. +cbmpc_error_t cbmpc_pve_verify(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, cmem_t ek, + cmem_t ciphertext, cmem_t Q_compressed, cmem_t label); + +// Decrypt a ciphertext, recovering the scalar x. +// +// Notes: +// - This function intentionally does not verify `ciphertext` before decryption. +// Invalid ciphertexts may cause decryption to fail, but are designed to not +// leak secret information. +// - If you need ciphertext validation, call `cbmpc_pve_verify(...)` (or `cbmpc_pve_verify_with_kem(...)`) first. +// +// Ownership: +// - On success, `out_x->data` is allocated by the library and must be freed with +// `cbmpc_cmem_free(*out_x)`. +// - On failure, `*out_x` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_pve_decrypt(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, cmem_t dk, cmem_t ek, + cmem_t ciphertext, cmem_t label, cmem_t* out_x); + +// Extract Q from a ciphertext (compressed point encoding). +// +// Ownership: same as `cbmpc_pve_encrypt`. +cbmpc_error_t cbmpc_pve_get_Q(cmem_t ciphertext, cmem_t* out_Q_compressed); + +// Extract label from a ciphertext. +// +// Ownership: same as `cbmpc_pve_encrypt`. +cbmpc_error_t cbmpc_pve_get_Label(cmem_t ciphertext, cmem_t* out_label); + +#ifdef __cplusplus +} +#endif diff --git a/include/cbmpc/c_api/pve_batch_ac.h b/include/cbmpc/c_api/pve_batch_ac.h new file mode 100644 index 00000000..d028312e --- /dev/null +++ b/include/cbmpc/c_api/pve_batch_ac.h @@ -0,0 +1,103 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// --------------------------------------------------------------------------- +// PVE-AC (access-structure / quorum decryption) API (C API) +// --------------------------------------------------------------------------- +// +// This API encrypts a *batch* of scalars {x_i} under a leaf-keyed access structure. +// +// Decryption is stepwise: +// - Each party calls `cbmpc_pve_ac_partial_decrypt_attempt` to produce a leaf share for a specific attempt. +// - The application collects enough shares and calls `cbmpc_pve_ac_combine` to recover {x_i}. +// +// Notes: +// - Leaf keys are passed as a mapping (parallel arrays) from leaf name to key blob. +// - If `base_pke` is NULL, cbmpc uses its built-in default base PKE. + +cbmpc_error_t cbmpc_pve_ac_encrypt(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, + const cbmpc_access_structure_t* ac, const char* const* leaf_names, + const cmem_t* leaf_eks, int leaf_count, cmem_t label, cmems_t xs, + cmem_t* out_ciphertext); + +cbmpc_error_t cbmpc_pve_ac_verify(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, + const cbmpc_access_structure_t* ac, const char* const* leaf_names, + const cmem_t* leaf_eks, int leaf_count, cmem_t ciphertext, cmems_t Qs_compressed, + cmem_t label); + +// Step 1: decrypt a single leaf share for `attempt_index`. +// +// Ownership: same as `cbmpc_pve_encrypt`. +cbmpc_error_t cbmpc_pve_ac_partial_decrypt_attempt(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, + const cbmpc_access_structure_t* ac, cmem_t ciphertext, + int attempt_index, const char* leaf_name, cmem_t dk, cmem_t label, + cmem_t* out_share); + +// Step 1 (HSM): decrypt a single leaf share for `attempt_index` using an HSM-backed +// RSA-OAEP private key (KEM decapsulation callback). +// +// - `dk_handle` is an opaque handle understood by the callback. +// - `ek` is the leaf's built-in base PKE public key blob (used to validate key type). +// +// Ownership: same as `cbmpc_pve_ac_partial_decrypt_attempt`. +cbmpc_error_t cbmpc_pve_ac_partial_decrypt_attempt_rsa_oaep_hsm(cbmpc_curve_id_t curve, + const cbmpc_access_structure_t* ac, cmem_t ciphertext, + int attempt_index, const char* leaf_name, + cmem_t dk_handle, cmem_t ek, cmem_t label, + const cbmpc_pve_rsa_oaep_hsm_decap_t* cb, + cmem_t* out_share); + +// Step 1 (HSM): decrypt a single leaf share for `attempt_index` using an HSM-backed +// ECIES(P-256) private key (ECDH callback only). +// +// - `dk_handle` is an opaque handle understood by the callback. +// - `ek` is the leaf's built-in base PKE public key blob (used to validate key type +// and derive the KEM context). +// +// Ownership: same as `cbmpc_pve_ac_partial_decrypt_attempt`. +cbmpc_error_t cbmpc_pve_ac_partial_decrypt_attempt_ecies_p256_hsm(cbmpc_curve_id_t curve, + const cbmpc_access_structure_t* ac, cmem_t ciphertext, + int attempt_index, const char* leaf_name, + cmem_t dk_handle, cmem_t ek, cmem_t label, + const cbmpc_pve_ecies_p256_hsm_ecdh_t* cb, + cmem_t* out_share); + +// Step 2: aggregate enough leaf shares to recover {x_i} for `attempt_index`. +// If combine fails, then increase the attempt_index and gather another set of +// partial decryptions and call combine again. +// +// - `quorum_leaf_names[i]` corresponds to `quorum_shares[i]`. +// +// Notes: +// - This function intentionally does not verify `ciphertext` before reconstruction. +// Invalid ciphertexts may cause reconstruction to fail, but are designed to not +// leak secret information. +// - If you need ciphertext validation, call `cbmpc_pve_ac_verify(...)` first. +// +// Ownership: same as `cbmpc_pve_batch_decrypt`. +cbmpc_error_t cbmpc_pve_ac_combine(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, + const cbmpc_access_structure_t* ac, const char* const* quorum_leaf_names, + const cmem_t* quorum_shares, int quorum_count, cmem_t ciphertext, int attempt_index, + cmem_t label, cmems_t* out_xs); + +// Extract batch count from a PVE-AC ciphertext. +cbmpc_error_t cbmpc_pve_ac_get_count(cmem_t ciphertext, int* out_batch_count); + +// Extract {Q_i} from a PVE-AC ciphertext (compressed point encodings). +// +// Ownership: same as `cbmpc_pve_batch_decrypt`. +cbmpc_error_t cbmpc_pve_ac_get_Qs(cmem_t ciphertext, cmems_t* out_Qs_compressed); + +#ifdef __cplusplus +} +#endif diff --git a/include/cbmpc/c_api/pve_batch_single_recipient.h b/include/cbmpc/c_api/pve_batch_single_recipient.h new file mode 100644 index 00000000..d4036abd --- /dev/null +++ b/include/cbmpc/c_api/pve_batch_single_recipient.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// --------------------------------------------------------------------------- +// Batch PVE (1P) API (C API) +// --------------------------------------------------------------------------- +// +// This API batches the core PVE algorithm for multiple scalars {x_i} in one ciphertext. +// +// - `xs` / `Qs_compressed` / `out_xs` use `cmems_t` (flattened segments). +// - If `base_pke` is NULL, cbmpc uses its built-in default base PKE. + +// Encrypt a batch of scalars {x_i}, producing a single batch ciphertext. +// +// Ownership: +// - On success, `out_ciphertext->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_ciphertext)`. +// - On failure, `*out_ciphertext` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_pve_batch_encrypt(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, cmem_t ek, + cmem_t label, cmems_t xs, cmem_t* out_ciphertext); + +cbmpc_error_t cbmpc_pve_batch_verify(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, cmem_t ek, + cmem_t ciphertext, cmems_t Qs_compressed, cmem_t label); + +// Decrypt a batch ciphertext, recovering {x_i}. +// +// Ownership: +// - On success, `out_xs->data` and `out_xs->sizes` are allocated by the library and must be freed with +// `cbmpc_cmems_free(*out_xs)`. +// - On failure, `*out_xs` is set to `{0, NULL, NULL}`. +// +// Notes: +// - This function intentionally does not verify `ciphertext` before decryption. +// Invalid ciphertexts may cause decryption to fail, but are designed to not +// leak secret information. +// - If you need ciphertext validation, call `cbmpc_pve_batch_verify(...)` first. +cbmpc_error_t cbmpc_pve_batch_decrypt(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, cmem_t dk, + cmem_t ek, cmem_t ciphertext, cmem_t label, cmems_t* out_xs); + +// Decrypt using an HSM-backed RSA private key (KEM decapsulation callback). +// +// Ownership: same as `cbmpc_pve_batch_decrypt`. +cbmpc_error_t cbmpc_pve_batch_decrypt_rsa_oaep_hsm(cbmpc_curve_id_t curve, cmem_t dk_handle, cmem_t ek, + cmem_t ciphertext, cmem_t label, + const cbmpc_pve_rsa_oaep_hsm_decap_t* cb, cmems_t* out_xs); + +// Decrypt using an HSM-backed ECIES(P-256) private key (ECDH callback). +// +// Ownership: same as `cbmpc_pve_batch_decrypt`. +cbmpc_error_t cbmpc_pve_batch_decrypt_ecies_p256_hsm(cbmpc_curve_id_t curve, cmem_t dk_handle, cmem_t ek, + cmem_t ciphertext, cmem_t label, + const cbmpc_pve_ecies_p256_hsm_ecdh_t* cb, cmems_t* out_xs); + +// Extract batch count from a batch ciphertext. +cbmpc_error_t cbmpc_pve_batch_get_count(cmem_t ciphertext, int* out_batch_count); + +// Extract {Q_i} from a batch ciphertext (compressed point encodings). +// +// Ownership: +// - On success, `out_Qs_compressed->data` and `out_Qs_compressed->sizes` are +// allocated by the library and must be freed with `cbmpc_cmems_free(*out_Qs_compressed)`. +// - On failure, `*out_Qs_compressed` is set to `{0, NULL, NULL}`. +cbmpc_error_t cbmpc_pve_batch_get_Qs(cmem_t ciphertext, cmems_t* out_Qs_compressed); + +// Extract label from a batch ciphertext. +// +// Ownership: +// - On success, `out_label->data` is allocated by the library and must be freed with +// `cbmpc_cmem_free(*out_label)`. +// - On failure, `*out_label` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_pve_batch_get_Label(cmem_t ciphertext, cmem_t* out_label); + +#ifdef __cplusplus +} +#endif diff --git a/include/cbmpc/c_api/schnorr_2p.h b/include/cbmpc/c_api/schnorr_2p.h new file mode 100644 index 00000000..7b07fb84 --- /dev/null +++ b/include/cbmpc/c_api/schnorr_2p.h @@ -0,0 +1,87 @@ +#pragma once + +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// Interactive key generation for Schnorr-2PC (BIP340). +// +// Ownership: +// - On success, `out_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_key_blob)`. +// - On failure, `*out_key_blob` is set to `{NULL, 0}`. +// +// Supported curves: `CBMPC_CURVE_SECP256K1`. +cbmpc_error_t cbmpc_schnorr_2p_dkg(const cbmpc_2pc_job_t* job, cbmpc_curve_id_t curve, cmem_t* out_key_blob); + +// Interactive key refresh (same public key). +// +// Ownership: +// - On success, `out_new_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_new_key_blob)`. +// - On failure, `*out_new_key_blob` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_schnorr_2p_refresh(const cbmpc_2pc_job_t* job, cmem_t key_blob, cmem_t* out_new_key_blob); + +// Sign a message with Schnorr-2PC (BIP340). Outputs a 64-byte signature. +// +// Ownership: +// - On success, `sig_out->data` is allocated by the library and must be freed +// with `cbmpc_cmem_free(*sig_out)`. +// - On failure, `*sig_out` is set to `{NULL, 0}`. +// +// Note: the underlying protocol returns the signature only on P1. If `key_blob` +// belongs to P2, `*sig_out` may be `{NULL, 0}` on success. +cbmpc_error_t cbmpc_schnorr_2p_sign(const cbmpc_2pc_job_t* job, cmem_t key_blob, cmem_t msg, cmem_t* sig_out); + +// Get the Schnorr/BIP340 public key (SEC1 compressed point encoding). +// +// Ownership: +// - On success, `out_pub_key->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_pub_key)`. +// - On failure, `*out_pub_key` is set to `{NULL, 0}`. +// +// Output size is 33 bytes for secp256k1: 0x02/0x03 || x (32 bytes). +cbmpc_error_t cbmpc_schnorr_2p_get_public_key_compressed(cmem_t key_blob, cmem_t* out_pub_key); + +// Extract the Schnorr/BIP340 x-only public key (32 bytes). +// +// Ownership: same as `cbmpc_schnorr_2p_get_public_key_compressed`. +cbmpc_error_t cbmpc_schnorr_2p_extract_public_key_xonly(cmem_t key_blob, cmem_t* out_pub_key); + +// --------------------------------------------------------------------------- +// Key blob manipulation (private scalar backup / restore) +// --------------------------------------------------------------------------- + +// Get this party's share public point (Qi) from a key blob, returning SEC1 +// compressed point encoding. +// +// Ownership: same as `cbmpc_schnorr_2p_get_public_key_compressed`. +cbmpc_error_t cbmpc_schnorr_2p_get_public_share_compressed(cmem_t key_blob, cmem_t* out_public_share); + +// Detach a key blob into a public blob + fixed-length private scalar. +// +// Ownership: +// - On success, `out_public_key_blob->data` and `out_private_scalar_fixed->data` are +// allocated by the library and must be freed with `cbmpc_cmem_free(...)`. +// - On failure, outputs are set to `{NULL, 0}`. +cbmpc_error_t cbmpc_schnorr_2p_detach_private_scalar(cmem_t key_blob, cmem_t* out_public_key_blob, + cmem_t* out_private_scalar_fixed); + +// Attach a fixed-length private scalar into a public key blob, validating it +// against the expected public share point. +// +// Ownership: +// - On success, `out_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_key_blob)`. +// - On failure, `*out_key_blob` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_schnorr_2p_attach_private_scalar(cmem_t public_key_blob, cmem_t private_scalar_fixed, + cmem_t public_share_compressed, cmem_t* out_key_blob); + +#ifdef __cplusplus +} +#endif diff --git a/include/cbmpc/c_api/schnorr_mp.h b/include/cbmpc/c_api/schnorr_mp.h new file mode 100644 index 00000000..b44cdc80 --- /dev/null +++ b/include/cbmpc/c_api/schnorr_mp.h @@ -0,0 +1,160 @@ +#pragma once + +#include + +#include +#include +#include + +// All the functions have two versions: additive and ac. Additive means that the +// sharing is additive, while ac means that the sharing is according to a given access structure. +#ifdef __cplusplus +extern "C" { +#endif + +// Interactive multi-party key generation for Schnorr-MP (BIP340). +// +// Ownership: +// - On success, `out_key_blob->data` and `out_sid->data` are allocated by the +// library and must be freed with `cbmpc_cmem_free(...)`. +// - On failure, `*out_key_blob` and `*out_sid` are set to `{NULL, 0}`. +// +// Supported curves: `CBMPC_CURVE_SECP256K1`. +cbmpc_error_t cbmpc_schnorr_mp_dkg_additive(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t* out_key_blob, + cmem_t* out_sid); + +// Interactive multi-party key generation for Schnorr-MP (BIP340) with a general +// access structure. +// +// Notes: +// - This is an n-party protocol: **all** parties in `job->party_names` must be +// online and participate. +// - Only the provided `quorum_party_names` actively contribute to the generated +// key shares. +// - The output key blob represents an access-structure key share and +// is not directly usable with `cbmpc_schnorr_mp_sign_additive`. Use `cbmpc_schnorr_mp_sign_ac` +// to sign with an online quorum (it derives additive shares internally). +// - `sid_in` is the in/out session id used by the protocol; callers may pass +// `{NULL, 0}` and let the protocol derive one. +// +// Ownership: +// - On success, `out_ac_key_blob->data` and `out_sid->data` are allocated by +// the library and must be freed with `cbmpc_cmem_free(...)`. +// - On failure, `*out_ac_key_blob` and `*out_sid` are set to `{NULL, 0}`. +// +// Supported curves: `CBMPC_CURVE_SECP256K1`. +cbmpc_error_t cbmpc_schnorr_mp_dkg_ac(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t sid_in, + const cbmpc_access_structure_t* access_structure, + const char* const* quorum_party_names, int quorum_party_names_count, + cmem_t* out_ac_key_blob, cmem_t* out_sid); + +// Interactive multi-party key refresh (same public key). +// +// `sid_in` is the in/out session id used by the refresh protocol; callers may +// pass `{NULL, 0}` and let the protocol derive one. If `sid_out` is non-NULL, it +// will receive the session id used. +// +// Ownership: +// - On success, `out_new_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_new_key_blob)`. +// - If `sid_out` is non-NULL, then on success `sid_out->data` is allocated by +// the library and must be freed with `cbmpc_cmem_free(*sid_out)`. +// - On failure, `*out_new_key_blob` (and `*sid_out` when provided) are set to +// `{NULL, 0}`. +cbmpc_error_t cbmpc_schnorr_mp_refresh_additive(const cbmpc_mp_job_t* job, cmem_t sid_in, cmem_t key_blob, + cmem_t* sid_out, cmem_t* out_new_key_blob); + +// Interactive multi-party access-structure key refresh (same public key). +// +// Notes: +// - See `cbmpc_schnorr_mp_dkg_ac` for protocol participation semantics. +// - The output key blob represents an access-structure key share and +// is not directly usable with `cbmpc_schnorr_mp_sign_additive`. Use `cbmpc_schnorr_mp_sign_ac` +// to sign with an online quorum (it derives additive shares internally). +// - `sid_in` is the in/out session id used by the refresh protocol; callers may +// pass `{NULL, 0}` and let the protocol derive one. If `sid_out` is non-NULL, it +// will receive the session id used. +// +// Ownership: same as `cbmpc_schnorr_mp_refresh_additive`. +cbmpc_error_t cbmpc_schnorr_mp_refresh_ac(const cbmpc_mp_job_t* job, cmem_t sid_in, cmem_t ac_key_blob, + const cbmpc_access_structure_t* access_structure, + const char* const* quorum_party_names, int quorum_party_names_count, + cmem_t* sid_out, cmem_t* out_new_ac_key_blob); + +// Sign a message with Schnorr-MP (BIP340). Outputs a 64-byte signature on `sig_receiver`. +// +// Ownership: +// - On success, `sig_out->data` is allocated by the library and must be freed +// with `cbmpc_cmem_free(*sig_out)`. +// - On failure, `*sig_out` is set to `{NULL, 0}`. +// +// Note: the underlying protocol returns the signature only on `sig_receiver`. On +// other parties, `*sig_out` may be `{NULL, 0}` on success. +cbmpc_error_t cbmpc_schnorr_mp_sign_additive(const cbmpc_mp_job_t* job, cmem_t key_blob, cmem_t msg, + int32_t sig_receiver, cmem_t* sig_out); + +// Sign a message with Schnorr-MP (BIP340) using an access-structure key share (from +// `cbmpc_schnorr_mp_dkg_ac` / `cbmpc_schnorr_mp_refresh_ac`). +// +// This API first derives an additive-share signing key for the **online** signing +// parties in `job->party_names` and then runs the normal `cbmpc_schnorr_mp_sign_additive` +// protocol among those parties. +// +// Notes: +// - Unlike `cbmpc_schnorr_mp_dkg_ac` / `cbmpc_schnorr_mp_refresh_ac`, `cbmpc_schnorr_mp_sign_ac` +// only requires the parties in `job->party_names` to be online and participate. +// - Output semantics match `cbmpc_schnorr_mp_sign_additive`: the signature is returned only +// on `sig_receiver`. On other parties, `*sig_out` may be `{NULL, 0}` on success. +// +// Ownership: same as `cbmpc_schnorr_mp_sign_additive`. +cbmpc_error_t cbmpc_schnorr_mp_sign_ac(const cbmpc_mp_job_t* job, cmem_t ac_key_blob, + const cbmpc_access_structure_t* access_structure, cmem_t msg, + int32_t sig_receiver, cmem_t* sig_out); + +// Get the Schnorr/BIP340 public key (SEC1 compressed point encoding) from a key blob. +// +// Ownership: +// - On success, `out_pub_key->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_pub_key)`. +// - On failure, `*out_pub_key` is set to `{NULL, 0}`. +// +// Output size is 33 bytes for secp256k1: 0x02/0x03 || x (32 bytes). +cbmpc_error_t cbmpc_schnorr_mp_get_public_key_compressed(cmem_t key_blob, cmem_t* out_pub_key); + +// Extract the Schnorr/BIP340 x-only public key (32 bytes) from a key blob. +// +// Ownership: same as `cbmpc_schnorr_mp_get_public_key_compressed`. +cbmpc_error_t cbmpc_schnorr_mp_extract_public_key_xonly(cmem_t key_blob, cmem_t* out_pub_key); + +// --------------------------------------------------------------------------- +// Key blob manipulation (private scalar backup / restore) +// --------------------------------------------------------------------------- + +// Get this party's share public point (Qi) from a key blob, returning SEC1 +// compressed point encoding. +// +// Ownership: same as `cbmpc_schnorr_mp_get_public_key_compressed`. +cbmpc_error_t cbmpc_schnorr_mp_get_public_share_compressed(cmem_t key_blob, cmem_t* out_public_share); + +// Detach a key blob into a public blob + fixed-length private scalar. +// +// Ownership: +// - On success, `out_public_key_blob->data` and `out_private_scalar_fixed->data` are +// allocated by the library and must be freed with `cbmpc_cmem_free(...)`. +// - On failure, outputs are set to `{NULL, 0}`. +cbmpc_error_t cbmpc_schnorr_mp_detach_private_scalar(cmem_t key_blob, cmem_t* out_public_key_blob, + cmem_t* out_private_scalar_fixed); + +// Attach a fixed-length private scalar into a public key blob, validating it +// against the expected public share point. +// +// Ownership: +// - On success, `out_key_blob->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_key_blob)`. +// - On failure, `*out_key_blob` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_schnorr_mp_attach_private_scalar(cmem_t public_key_blob, cmem_t private_scalar_fixed, + cmem_t public_share_compressed, cmem_t* out_key_blob); + +#ifdef __cplusplus +} +#endif diff --git a/include/cbmpc/c_api/tdh2.h b/include/cbmpc/c_api/tdh2.h new file mode 100644 index 00000000..9f824331 --- /dev/null +++ b/include/cbmpc/c_api/tdh2.h @@ -0,0 +1,91 @@ +#pragma once + +#include +#include +#include + +// All the functions have two versions: additive and ac. Additive means that the +// sharing is additive, while ac means that the sharing is according to a given access structure. +#ifdef __cplusplus +extern "C" { +#endif + +// Interactive multi-party key generation for TDH2 (additive shares). +// +// Ownership: +// - On success, outputs are allocated by the library and must be freed with: +// - `cbmpc_cmem_free(*out_public_key)` +// - `cbmpc_cmems_free(*out_public_shares)` +// - `cbmpc_cmem_free(*out_private_share)` +// - `cbmpc_cmem_free(*out_sid)` +// - On failure, output parameters are set to empty values (`{NULL, 0}` / `{0, NULL, NULL}`). +// +// Supported curves: `CBMPC_CURVE_P256`, `CBMPC_CURVE_SECP256K1`. +cbmpc_error_t cbmpc_tdh2_dkg_additive(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t* out_public_key, + cmems_t* out_public_shares, cmem_t* out_private_share, cmem_t* out_sid); + +// Interactive multi-party key generation for TDH2 with a general access structure. +// +// Notes: +// - This is an n-party protocol: **all** parties in `job->party_names` must be +// online and participate. +// - Only the provided `quorum_party_names` actively contribute to the generated +// key shares. +// - `sid_in` is the in/out session id used by the protocol; callers may pass `{NULL, 0}` and let the protocol derive +// one. +// +// Ownership: same as `cbmpc_tdh2_dkg_additive`. +cbmpc_error_t cbmpc_tdh2_dkg_ac(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t sid_in, + const cbmpc_access_structure_t* access_structure, const char* const* quorum_party_names, + int quorum_party_names_count, cmem_t* out_public_key, cmems_t* out_public_shares, + cmem_t* out_private_share, cmem_t* out_sid); + +// Encrypt a plaintext under a TDH2 public key. +// +// Ownership: +// - On success, `out_ciphertext->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_ciphertext)`. +// - On failure, `*out_ciphertext` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_tdh2_encrypt(cmem_t public_key, cmem_t plaintext, cmem_t label, cmem_t* out_ciphertext); + +cbmpc_error_t cbmpc_tdh2_verify(cmem_t public_key, cmem_t ciphertext, cmem_t label); + +// Compute a partial decryption share. +// +// Ownership: +// - On success, `out_partial_decryption->data` is allocated by the library and +// must be freed with `cbmpc_cmem_free(*out_partial_decryption)`. +// - On failure, `*out_partial_decryption` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_tdh2_partial_decrypt(cmem_t private_share, cmem_t ciphertext, cmem_t label, + cmem_t* out_partial_decryption); + +// Combine additive shares / partial decryptions to recover the plaintext. +// +// Ownership: +// - On success, `out_plaintext->data` is allocated by the library and must be +// freed with `cbmpc_cmem_free(*out_plaintext)`. +// - On failure, `*out_plaintext` is set to `{NULL, 0}`. +cbmpc_error_t cbmpc_tdh2_combine_additive(cmem_t public_key, cmems_t public_shares, cmem_t label, + cmems_t partial_decryptions, cmem_t ciphertext, cmem_t* out_plaintext); + +// Combine access-structure shares / partial decryptions to recover the plaintext. +// +// - `party_names` and `public_shares` define the mapping name -> Qi for *all* +// parties in the access structure (order-aligned). +// - `partial_decryption_party_names` and `partial_decryptions` provide the quorum subset used for decryption. +// +// Requirements: +// - `party_names_count == public_shares.count` +// - `partial_decryption_party_names_count == partial_decryptions.count` +// - The leaf set of `access_structure` must match `party_names` exactly. +// +// Ownership: same as `cbmpc_tdh2_combine_additive`. +cbmpc_error_t cbmpc_tdh2_combine_ac(const cbmpc_access_structure_t* access_structure, cmem_t public_key, + const char* const* party_names, int party_names_count, cmems_t public_shares, + cmem_t label, const char* const* partial_decryption_party_names, + int partial_decryption_party_names_count, cmems_t partial_decryptions, + cmem_t ciphertext, cmem_t* out_plaintext); + +#ifdef __cplusplus +} +#endif diff --git a/include/cbmpc/core/access_structure.h b/include/cbmpc/core/access_structure.h new file mode 100644 index 00000000..bec5af10 --- /dev/null +++ b/include/cbmpc/core/access_structure.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include +#include + +namespace coinbase::api { + +// Access structure used by threshold protocols (e.g., threshold DKG / refresh). +// +// This is a lightweight view type: +// - Leaf party names are represented as `std::string_view` and must outlive the +// protocol call that consumes the access structure. +// (For example, do not construct leaves from temporary `std::string`s.) +// - Internal node names are not exposed; the library assigns deterministic, +// unique internal names when converting to the internal representation. +// +// Notes: +// - Leaf names are expected to match the `job_mp_t::party_names` values. +// - The root node is unnamed in the internal representation; therefore the root +// of an access structure cannot be a leaf node. +struct access_structure_t { + enum class node_type : uint8_t { + leaf = 1, + and_node = 2, + or_node = 3, + threshold = 4, + }; + + node_type type = node_type::leaf; + + // Leaf party name (only meaningful when `type == node_type::leaf`). + std::string_view leaf_name; + + // Threshold parameter k (only meaningful when `type == node_type::threshold`). + // Must satisfy: 1 <= k <= children.size(). + int32_t threshold_k = 0; + + // Child nodes (only meaningful when `type != node_type::leaf`). + std::vector children; + + static access_structure_t leaf(std::string_view party_name) { + access_structure_t n; + n.type = node_type::leaf; + n.leaf_name = party_name; + return n; + } + + static access_structure_t And(std::vector ch) { + access_structure_t n; + n.type = node_type::and_node; + n.children = std::move(ch); + return n; + } + + static access_structure_t Or(std::vector ch) { + access_structure_t n; + n.type = node_type::or_node; + n.children = std::move(ch); + return n; + } + + static access_structure_t Threshold(int32_t k, std::vector ch) { + access_structure_t n; + n.type = node_type::threshold; + n.threshold_k = k; + n.children = std::move(ch); + return n; + } +}; + +} // namespace coinbase::api diff --git a/include/cbmpc/core/bip32_path.h b/include/cbmpc/core/bip32_path.h new file mode 100644 index 00000000..13eed963 --- /dev/null +++ b/include/cbmpc/core/bip32_path.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +namespace coinbase::api { + +// BIP32 derivation path: a sequence of 32-bit child indices. +// +// This is a lightweight value type shared across HD keyset APIs. +// Index interpretation (hardened vs non-hardened) is defined by the caller / spec. +struct bip32_path_t { + std::vector indices; +}; + +} // namespace coinbase::api diff --git a/src/cbmpc/core/buf.h b/include/cbmpc/core/buf.h similarity index 80% rename from src/cbmpc/core/buf.h rename to include/cbmpc/core/buf.h index 80133080..1962e755 100644 --- a/src/cbmpc/core/buf.h +++ b/include/cbmpc/core/buf.h @@ -1,15 +1,30 @@ #pragma once +#include +#include #include #include namespace coinbase { void memmove_reverse(byte_ptr dst, const_byte_ptr src, int size); -inline void bzero(byte_ptr pointer, int size) { memset(pointer, 0, size); } +inline void bzero(byte_ptr pointer, int size) { + cb_assert(size >= 0 && "bzero: negative size"); + if (size == 0) return; + cb_assert(pointer && "bzero: null pointer"); + memset(pointer, 0, static_cast(size)); +} inline void secure_bzero(byte_ptr pointer, int size) { - volatile unsigned char* p = pointer; - while (size--) *p++ = 0; + if (size <= 0) return; + cb_assert(pointer && "secure_bzero: null pointer"); +#if defined(__STDC_LIB_EXT1__) + // `memset_s` is guaranteed to perform the memory write and not be optimized away. + (void)memset_s(pointer, static_cast(size), 0, static_cast(size)); +#else + // Best-effort fallback that compilers cannot elide. + volatile byte_t* p = pointer; + for (int i = 0; i < size; i++) p[i] = 0; +#endif } template @@ -33,20 +48,22 @@ class buf_t; class converter_t; struct mem_t { - byte_ptr data; + const_byte_ptr data; int size; mem_t() noexcept(true) : data(0), size(0) {} - mem_t(const_byte_ptr the_data, int the_size) noexcept(true) : data(byte_ptr(the_data)), size(the_size) {} - mem_t(const std::string& s) noexcept(true) : data(byte_ptr(s.data())), size(int(s.size())) { + mem_t(const_byte_ptr the_data, int the_size) noexcept(true) : data(the_data), size(the_size) {} + mem_t(const std::string& s) noexcept(true) : data(const_byte_ptr(s.data())), size(int(s.size())) { cb_assert(s.size() <= INT_MAX); } template - mem_t(const char (&s)[N]) : data(byte_ptr(s)), size(N) { + mem_t(const char (&s)[N]) : data(const_byte_ptr(s)), size(N) { if (N > 0 && s[N - 1] == '\0') size--; // zero-terminated } - void bzero() { coinbase::bzero(data, size); } - void secure_bzero() { coinbase::secure_bzero(data, size); } + // NOTE: `mem_t` is often used as a read-only view. These mutating helpers are available for + // cases where the underlying memory is known to be writable. + void bzero() { coinbase::bzero(const_cast(data), size); } + void secure_bzero() { coinbase::secure_bzero(const_cast(data), size); } void reverse(); buf_t rev() const; @@ -61,7 +78,7 @@ struct mem_t { } uint8_t& operator[](int index) { cb_assert(index >= 0 && index < size); - return data[index]; + return const_cast(data)[index]; } mem_t range(int offset, int len) const { @@ -88,18 +105,17 @@ struct mem_t { } // namespace coinbase -using coinbase::mem_t; - +namespace coinbase { std::ostream& operator<<(std::ostream& os, mem_t mem); - -#include "buf128.h" -#include "buf256.h" +} // namespace coinbase namespace coinbase { class buf_t { public: buf_t() noexcept(true); + // NOTE: Allocates storage but intentionally does not initialize it. + // Callers must fully overwrite the buffer before reading from `data()`. explicit buf_t(int new_size); buf_t(const_byte_ptr src, int src_size); buf_t(mem_t mem); @@ -116,6 +132,8 @@ class buf_t { int size() const; bool empty() const; byte_ptr resize(int new_size); + // NOTE: Like `buf_t(int)`, `alloc()` leaves the returned buffer contents uninitialized. + // Callers must fully overwrite the buffer before reading from `data()`. byte_ptr alloc(int new_size); void bzero(); void secure_bzero(); @@ -271,8 +289,6 @@ class bits_t { void set(int index, bool value); void append(bool bit); - static bool equ(const bits_t& src1, const bits_t& src2); - private: void copy_from(const bits_t& src); void bzero_unused() const; @@ -282,5 +298,3 @@ class bits_t { }; } // namespace coinbase - -using coinbase::buf_t; diff --git a/src/cbmpc/core/buf128.h b/include/cbmpc/core/buf128.h similarity index 97% rename from src/cbmpc/core/buf128.h rename to include/cbmpc/core/buf128.h index 30095407..aab2cdce 100755 --- a/src/cbmpc/core/buf128.h +++ b/include/cbmpc/core/buf128.h @@ -6,6 +6,7 @@ namespace coinbase { class converter_t; +struct mem_t; #if defined(__x86_64__) typedef __m128i u128_t; @@ -40,7 +41,7 @@ struct buf128_t { static buf128_t zero() { return u128(u128_zero()); } - operator mem_t() const { return mem_t(byte_ptr(this), sizeof(buf128_t)); } + operator mem_t() const; buf128_t& operator=(std::nullptr_t); // zeroization buf128_t& operator=(mem_t); diff --git a/src/cbmpc/core/buf256.h b/include/cbmpc/core/buf256.h similarity index 94% rename from src/cbmpc/core/buf256.h rename to include/cbmpc/core/buf256.h index 63555554..4b2118ac 100755 --- a/src/cbmpc/core/buf256.h +++ b/include/cbmpc/core/buf256.h @@ -1,9 +1,13 @@ #pragma once -#include "buf128.h" +#include #define ZERO256 (buf256_t::zero()) namespace coinbase { +class buf_t; +class converter_t; +struct mem_t; + struct buf256_t { buf128_t lo, hi; @@ -13,7 +17,7 @@ struct buf256_t { operator const_byte_ptr() const { return const_byte_ptr(this); } operator byte_ptr() { return byte_ptr(this); } - operator mem_t() const { return mem_t(byte_ptr(this), sizeof(buf256_t)); } + operator mem_t() const; static buf256_t zero() { return make(ZERO128, ZERO128); } static buf256_t make(buf128_t lo, buf128_t hi = ZERO128); diff --git a/src/cbmpc/core/error.h b/include/cbmpc/core/error.h similarity index 97% rename from src/cbmpc/core/error.h rename to include/cbmpc/core/error.h index 4f588c50..435870e9 100755 --- a/src/cbmpc/core/error.h +++ b/include/cbmpc/core/error.h @@ -1,7 +1,9 @@ #pragma once #include -typedef int error_t; +namespace coinbase { +using error_t = int; +} // namespace coinbase #define ERRCODE(category, code) (0xff000000 | (uint32_t(category) << 16) | uint32_t(code)) #define ECATEGORY(code) (((code) >> 16) & 0x00ff) diff --git a/include/cbmpc/core/job.h b/include/cbmpc/core/job.h new file mode 100644 index 00000000..7fe7ff9f --- /dev/null +++ b/include/cbmpc/core/job.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace coinbase::api { + +// Party index (0..n-1) used by transports and multi-party protocols. +using party_idx_t = int32_t; + +// Two-party role used by 2PC protocols. +enum class party_2p_t : party_idx_t { + p1 = 0, + p2 = 1, +}; + +// Transport abstraction for MPC protocols. +// +// Implementations are expected to be *blocking*: `receive` should wait until a +// message from `sender` is available. +class data_transport_i { + public: + virtual ~data_transport_i() = default; + + virtual error_t send(party_idx_t receiver, mem_t msg) = 0; + virtual error_t receive(party_idx_t sender, buf_t& msg) = 0; + + // Multi-party receive. Implementations may choose any ordering contract; + // cbmpc callers should pass the same `senders` ordering on both sides. + virtual error_t receive_all(const std::vector& senders, std::vector& msgs) = 0; +}; + +// Execution context for 2-party protocols. +// +// Notes: +// - This is a lightweight view type. It does not own the transport object or +// the name strings. +// - The caller must ensure referenced objects (including the character data +// backing the `std::string_view`s) outlive the protocol call. +// (For example, do not pass temporary `std::string`s when constructing a job.) +struct job_2p_t { + party_2p_t self; + std::string_view p1_name; + std::string_view p2_name; + data_transport_i& transport; +}; + +// Execution context for multi-party protocols. +// +// Notes: +// - This is a lightweight view type. It does not own the transport object or +// the name strings. +// - `self` is the caller party index in `party_names`. +// - The caller must ensure referenced objects (including the character data +// backing `party_names`) outlive the protocol call. +// (For example, do not build `party_names` from temporary `std::string`s.) +struct job_mp_t { + party_idx_t self; + std::vector party_names; + data_transport_i& transport; +}; + +} // namespace coinbase::api diff --git a/src/cbmpc/core/macros.h b/include/cbmpc/core/macros.h similarity index 94% rename from src/cbmpc/core/macros.h rename to include/cbmpc/core/macros.h index b2ebcd50..6338ca1e 100755 --- a/src/cbmpc/core/macros.h +++ b/include/cbmpc/core/macros.h @@ -1,10 +1,10 @@ #pragma once -#include "precompiled.h" +#include #ifdef __APPLE__ -#include "TargetConditionals.h" +#include #if TARGET_OS_IOS || TARGET_OS_TV || TARGET_OS_WATCH #define TARGET_OS_IOSX 1 diff --git a/src/cbmpc/core/precompiled.h b/include/cbmpc/core/precompiled.h similarity index 98% rename from src/cbmpc/core/precompiled.h rename to include/cbmpc/core/precompiled.h index 0cf68224..dc30ad14 100755 --- a/src/cbmpc/core/precompiled.h +++ b/include/cbmpc/core/precompiled.h @@ -2,7 +2,7 @@ #define DY_PRECOMPILED_H #ifdef __APPLE__ -#include "TargetConditionals.h" +#include #endif #if defined(__x86_64__) diff --git a/scripts/auto_build_cpp.sh b/scripts/auto_build_cpp.sh deleted file mode 100644 index 069b9caa..00000000 --- a/scripts/auto_build_cpp.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -# Auto-rebuild the C++ library if sources changed since the last build. - -SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" -REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" - -BUILD_TYPE="${BUILD_TYPE:-Release}" - -SRC_DIR="${REPO_ROOT}/src" -LIB_CANDIDATES=( - "${REPO_ROOT}/build/${BUILD_TYPE}/lib/libcbmpc.a" - "${REPO_ROOT}/lib/${BUILD_TYPE}/libcbmpc.a" -) - -stat_mtime() { - if [[ "$(uname)" == "Darwin" ]]; then - stat -f "%m" "$1" - else - stat -c "%Y" "$1" - fi -} - -latest_src_mtime() { - # Consider C++ sources and headers - local latest=0 - while IFS= read -r -d '' f; do - local t - t=$(stat_mtime "$f") - if (( t > latest )); then - latest=$t - fi - done < <(find "${SRC_DIR}" -type f \( -name '*.cpp' -o -name '*.h' \) -print0) - echo "$latest" -} - -need_build=1 -for lib in "${LIB_CANDIDATES[@]}"; do - if [[ -f "$lib" ]]; then - lib_mtime=$(stat_mtime "$lib") - src_mtime=$(latest_src_mtime) - if (( src_mtime > lib_mtime )); then - need_build=1 - else - need_build=0 - fi - break - fi -done - -if (( need_build == 1 )); then - echo "[auto_build_cpp] Building C++ library (${BUILD_TYPE})..." - make -C "${REPO_ROOT}" build-no-test BUILD_TYPE="${BUILD_TYPE}" -else - echo "[auto_build_cpp] C++ library up-to-date (${BUILD_TYPE})." -fi - - diff --git a/scripts/go_with_cpp.sh b/scripts/go_with_cpp.sh deleted file mode 100644 index b705f325..00000000 --- a/scripts/go_with_cpp.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" -REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" - -BUILD_TYPE="${BUILD_TYPE:-Release}" - -DO_CD=1 -if [[ $# -gt 0 && "$1" == "--no-cd" ]]; then - DO_CD=0 - shift -fi - -INC_DIR="${REPO_ROOT}/src" -LIB_DIRS=( - "${REPO_ROOT}/build/${BUILD_TYPE}/lib" - "${REPO_ROOT}/lib/${BUILD_TYPE}" -) - -LDFLAGS_ACCUM=() -for d in "${LIB_DIRS[@]}"; do - LDFLAGS_ACCUM+=("-L${d}") -done - -export CGO_CFLAGS="-I${INC_DIR}" -export CGO_CXXFLAGS="-I${INC_DIR}" -export CGO_LDFLAGS="${LDFLAGS_ACCUM[*]}" -export BUILD_TYPE - -bash "${SCRIPT_DIR}/auto_build_cpp.sh" - -if [[ ${DO_CD} -eq 1 ]]; then - cd "${REPO_ROOT}/demos-go/cb-mpc-go" -fi - -bash "${SCRIPT_DIR}/auto_build_cpp.sh" -exec "$@" - - diff --git a/scripts/install.sh b/scripts/install.sh index 88013809..78c0e066 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -4,27 +4,129 @@ set -e SCRIPT_PATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" +usage() { + cat <<'EOF' +Usage: + scripts/install.sh [--mode public|full] [--prefix ] [--build-type ] + +Options: + --mode Install mode (default: public). Also supports $CBMPC_INSTALL_MODE. + - public: install only headers under include/ + - full: also install internal headers under include-internal/ + --prefix Install prefix (default: /build/install/). Also supports $CBMPC_PREFIX. + --build-type Build type for selecting the built library artifact to install + (default: Release). Also supports $CBMPC_BUILD_TYPE. + +Examples (flags can be in any order): + scripts/install.sh --mode public + scripts/install.sh --prefix /tmp/cbmpc --mode full + CBMPC_PREFIX=/tmp/cbmpc scripts/install.sh --mode public +EOF +} + +# Install mode: +# - public: install only curated public headers under include/ +# - full: additionally install internal headers under include-internal/ +INSTALL_MODE="${CBMPC_INSTALL_MODE:-public}" +# Build type controls which lib artifact we copy from `lib//`. +BUILD_TYPE="${CBMPC_BUILD_TYPE:-Release}" +# Install prefix can be customized to avoid requiring sudo. +# Default: /build/install/ +DST_PARENT_DIR="${CBMPC_PREFIX:-}" + +while [[ $# -gt 0 ]]; do + case "$1" in + --mode) + if [[ $# -lt 2 ]]; then + echo "Missing value for --mode" + usage + exit 1 + fi + INSTALL_MODE="$2" + shift 2 + ;; + --prefix) + if [[ $# -lt 2 ]]; then + echo "Missing value for --prefix" + usage + exit 1 + fi + DST_PARENT_DIR="$2" + shift 2 + ;; + --build-type) + if [[ $# -lt 2 ]]; then + echo "Missing value for --build-type" + usage + exit 1 + fi + BUILD_TYPE="$2" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown argument: $1" + usage + exit 1 + ;; + esac +done + # Other necessary paths ROOT_DIR="$SCRIPT_PATH/.." -DST_PARENT_DIR=/usr/local/opt/cbmpc -SRC_DIR="$ROOT_DIR/src/" +PUBLIC_INCLUDE_DIR="$ROOT_DIR/include/" +INTERNAL_INCLUDE_DIR="$ROOT_DIR/include-internal/" + +if [[ -z "$DST_PARENT_DIR" ]]; then + DST_PARENT_DIR="$ROOT_DIR/build/install/$INSTALL_MODE" +fi + DST_DIR="$DST_PARENT_DIR/include/" -LIB_SRC_DIR="$ROOT_DIR/lib/Release" +LIB_SRC_DIR="$ROOT_DIR/lib/$BUILD_TYPE" LIB_DST_DIR="$DST_PARENT_DIR/lib/" -# Check if destination directory exists. If not, create it -if [ ! -d "$DST_PARENT_DIR" ]; then - mkdir -p "$DST_DIR" - mkdir -p "$LIB_DST_DIR" +# Validate mode +case "$INSTALL_MODE" in + public|full) ;; + *) + echo "Unknown install mode: $INSTALL_MODE (expected: public|full)" + exit 1 + ;; +esac + +# Validate build type / lib path. +LIB_FILE="$LIB_SRC_DIR/libcbmpc.a" +if [[ ! -f "$LIB_FILE" ]]; then + echo "Missing built library: $LIB_FILE" + echo "Build it first, e.g.: make build-no-test BUILD_TYPE=$BUILD_TYPE" + exit 1 fi -# Find and copy header files +# Ensure destination directories exist. +mkdir -p "$LIB_DST_DIR" + +# Refresh include directory to avoid stale headers lingering across installs. +rm -rf "$DST_DIR" +mkdir -p "$DST_DIR" + +# Copy public headers rsync -avm \ - --exclude='*/build/*' \ --include='*.h' \ --include='*/' \ --exclude='*' \ - "$SRC_DIR/" "$DST_DIR/" + "$PUBLIC_INCLUDE_DIR/" "$DST_DIR/" + +# Copy internal headers (full mode only) +if [[ "$INSTALL_MODE" == "full" ]]; then + rsync -avm \ + --include='*.h' \ + --include='*/' \ + --exclude='*' \ + "$INTERNAL_INCLUDE_DIR/" "$DST_DIR/" +fi # Copy library files FILES=("libcbmpc.a") @@ -32,4 +134,4 @@ for file in "${FILES[@]}"; do rsync -av "$LIB_SRC_DIR/$file" "$LIB_DST_DIR/" done -echo "All header and library files have been copied to $DST_DIR" \ No newline at end of file +echo "Installed ($INSTALL_MODE) headers + library ($BUILD_TYPE) to $DST_PARENT_DIR" diff --git a/scripts/make-release.sh b/scripts/make-release.sh index 2723ab3c..4e5b59f7 100755 --- a/scripts/make-release.sh +++ b/scripts/make-release.sh @@ -12,7 +12,7 @@ cd $ROOT_PATH if [ "$#" -ne 1 ]; then echo "Usage: $0 " - echo " ref_name: The release tag name (e.g. v1.0.0)" + echo " ref_name: The release tag name (e.g. cb-mpc-1.0.0)" exit 1 fi @@ -22,7 +22,7 @@ make clean make clean-demos make clean-bench -tar -czf "cb-mpc-${REF_NAME}.tar.gz" \ +tar -czf "${REF_NAME}.tar.gz" \ --exclude='.git' \ --exclude='.github' \ --exclude='.buildkite' \ diff --git a/scripts/openssl/build-static-openssl-linux.sh b/scripts/openssl/build-static-openssl-linux.sh index a197fc91..31d6cbd9 100755 --- a/scripts/openssl/build-static-openssl-linux.sh +++ b/scripts/openssl/build-static-openssl-linux.sh @@ -3,13 +3,13 @@ set -e cd /tmp -curl -L https://github.com/openssl/openssl/releases/download/openssl-3.2.0/openssl-3.2.0.tar.gz --output openssl-3.2.0.tar.gz -expectedHash='14c826f07c7e433706fb5c69fa9e25dab95684844b4c962a2cf1bf183eb4690e' -fileHash=$(sha256sum openssl-3.2.0.tar.gz | cut -d " " -f 1 ) +curl -L https://github.com/openssl/openssl/releases/download/openssl-3.6.1/openssl-3.6.1.tar.gz --output openssl-3.6.1.tar.gz +expectedHash='b1bfedcd5b289ff22aee87c9d600f515767ebf45f77168cb6d64f231f518a82e' +fileHash=$(sha256sum openssl-3.6.1.tar.gz | cut -d " " -f 1 ) if [ $expectedHash != $fileHash ] then - echo 'ERROR: SHA1 DOES NOT MATCH!' + echo 'ERROR: SHA256 DOES NOT MATCH!' echo 'expected: ' $expectedHash echo 'file: ' $fileHash exit 1 @@ -18,8 +18,8 @@ fi echo "LINUX Start" uname -r -tar -xzf openssl-3.2.0.tar.gz -cd openssl-3.2.0 +tar -xzf openssl-3.6.1.tar.gz +cd openssl-3.6.1 sed -i -e 's/^static//' crypto/ec/curve25519.c @@ -29,7 +29,7 @@ sed -i -e 's/^static//' crypto/ec/curve25519.c no-gost no-http no-idea no-mdc2 no-md2 no-md4 no-module no-nextprotoneg no-ocb no-ocsp no-psk no-padlockeng no-poly1305 \ no-quic no-rc2 no-rc4 no-rc5 no-rfc3779 no-scrypt no-sctp no-seed no-siphash no-sm2 no-sm3 no-sm4 no-sock no-srtp no-srp \ no-ssl-trace no-ssl3 no-stdio no-tests no-tls no-ts no-unit-test no-uplink no-whirlpool no-zlib \ - --prefix=/usr/local/opt/openssl@3.2.0 --libdir=lib64 + --prefix=/usr/local/opt/openssl@3.6.1 --libdir=lib64 make build_generated install_sw -j4 diff --git a/scripts/openssl/build-static-openssl-macos-m1.sh b/scripts/openssl/build-static-openssl-macos-m1.sh index 7291e903..11b6c878 100755 --- a/scripts/openssl/build-static-openssl-macos-m1.sh +++ b/scripts/openssl/build-static-openssl-macos-m1.sh @@ -1,22 +1,41 @@ set -e +# The main project (when built with upstream Clang 20+) targets macOS 16.0 by default +# (e.g., via `-platform_version macos 16.0 26.x`). If OpenSSL is built with AppleClang, +# it will typically default to targeting the *current* macOS version (26.x), which then +# causes linker warnings like: +# "object file ... was built for newer 'macOS' version (26.0) than being linked (16.0)" +# +# Prefer upstream Clang if available (matches project recommendation) and default the +# deployment target to 16.0 unless the caller overrides it. +if [ -z "${CC:-}" ]; then + if [ -x "/opt/homebrew/opt/llvm/bin/clang" ]; then + export CC="/opt/homebrew/opt/llvm/bin/clang" + export CXX="/opt/homebrew/opt/llvm/bin/clang++" + elif [ -x "/usr/local/opt/llvm/bin/clang" ]; then + export CC="/usr/local/opt/llvm/bin/clang" + export CXX="/usr/local/opt/llvm/bin/clang++" + fi +fi +export MACOSX_DEPLOYMENT_TARGET="${MACOSX_DEPLOYMENT_TARGET:-16.0}" + cd /tmp -curl -L https://github.com/openssl/openssl/releases/download/openssl-3.2.0/openssl-3.2.0.tar.gz --output openssl-3.2.0.tar.gz -expectedHash='14c826f07c7e433706fb5c69fa9e25dab95684844b4c962a2cf1bf183eb4690e' -fileHash=$(sha256sum openssl-3.2.0.tar.gz | cut -d " " -f 1 ) +curl -L https://github.com/openssl/openssl/releases/download/openssl-3.6.1/openssl-3.6.1.tar.gz --output openssl-3.6.1.tar.gz +expectedHash='b1bfedcd5b289ff22aee87c9d600f515767ebf45f77168cb6d64f231f518a82e' +fileHash=$(sha256sum openssl-3.6.1.tar.gz | cut -d " " -f 1 ) if [ $expectedHash != $fileHash ] then - echo 'ERROR: SHA1 DOES NOT MATCH!' + echo 'ERROR: SHA256 DOES NOT MATCH!' echo 'expected: ' $expectedHash echo 'file: ' $fileHash exit 1 fi -tar -xzf openssl-3.2.0.tar.gz -cd openssl-3.2.0 +tar -xzf openssl-3.6.1.tar.gz +cd openssl-3.6.1 sed -i -e 's/^static//' crypto/ec/curve25519.c @@ -26,7 +45,7 @@ sed -i -e 's/^static//' crypto/ec/curve25519.c no-gost no-http no-idea no-mdc2 no-md2 no-md4 no-module no-nextprotoneg no-ocb no-ocsp no-psk no-padlockeng no-poly1305 \ no-quic no-rc2 no-rc4 no-rc5 no-rfc3779 no-scrypt no-sctp no-seed no-siphash no-sm2 no-sm3 no-sm4 no-sock no-srtp no-srp \ no-ssl-trace no-ssl3 no-stdio no-tests no-tls no-ts no-unit-test no-uplink no-whirlpool no-zlib \ - --prefix=/usr/local/opt/openssl@3.2.0 darwin64-arm64-cc + --prefix=/usr/local/opt/openssl@3.6.1 darwin64-arm64-cc make -j make install_sw diff --git a/scripts/openssl/build-static-openssl-macos.sh b/scripts/openssl/build-static-openssl-macos.sh index da06d3b4..601e3362 100755 --- a/scripts/openssl/build-static-openssl-macos.sh +++ b/scripts/openssl/build-static-openssl-macos.sh @@ -1,21 +1,40 @@ set -e +# The main project (when built with upstream Clang 20+) targets macOS 16.0 by default +# (e.g., via `-platform_version macos 16.0 26.x`). If OpenSSL is built with AppleClang, +# it will typically default to targeting the *current* macOS version (26.x), which then +# causes linker warnings like: +# "object file ... was built for newer 'macOS' version (26.0) than being linked (16.0)" +# +# Prefer upstream Clang if available (matches project recommendation) and default the +# deployment target to 16.0 unless the caller overrides it. +if [ -z "${CC:-}" ]; then + if [ -x "/opt/homebrew/opt/llvm/bin/clang" ]; then + export CC="/opt/homebrew/opt/llvm/bin/clang" + export CXX="/opt/homebrew/opt/llvm/bin/clang++" + elif [ -x "/usr/local/opt/llvm/bin/clang" ]; then + export CC="/usr/local/opt/llvm/bin/clang" + export CXX="/usr/local/opt/llvm/bin/clang++" + fi +fi +export MACOSX_DEPLOYMENT_TARGET="${MACOSX_DEPLOYMENT_TARGET:-16.0}" + cd /tmp -curl -L https://github.com/openssl/openssl/releases/download/openssl-3.2.0/openssl-3.2.0.tar.gz --output openssl-3.2.0.tar.gz -expectedHash='14c826f07c7e433706fb5c69fa9e25dab95684844b4c962a2cf1bf183eb4690e' -fileHash=$(sha256sum openssl-3.2.0.tar.gz | cut -d " " -f 1 ) +curl -L https://github.com/openssl/openssl/releases/download/openssl-3.6.1/openssl-3.6.1.tar.gz --output openssl-3.6.1.tar.gz +expectedHash='b1bfedcd5b289ff22aee87c9d600f515767ebf45f77168cb6d64f231f518a82e' +fileHash=$(sha256sum openssl-3.6.1.tar.gz | cut -d " " -f 1 ) if [ $expectedHash != $fileHash ] then - echo 'ERROR: SHA1 DOES NOT MATCH!' + echo 'ERROR: SHA256 DOES NOT MATCH!' echo 'expected: ' $expectedHash echo 'file: ' $fileHash exit 1 fi -tar -xzf openssl-3.2.0.tar.gz -cd openssl-3.2.0 +tar -xzf openssl-3.6.1.tar.gz +cd openssl-3.6.1 sed -i -e 's/^static//' crypto/ec/curve25519.c @@ -25,7 +44,7 @@ sed -i -e 's/^static//' crypto/ec/curve25519.c no-gost no-http no-idea no-mdc2 no-md2 no-md4 no-module no-nextprotoneg no-ocb no-ocsp no-psk no-padlockeng no-poly1305 \ no-quic no-rc2 no-rc4 no-rc5 no-rfc3779 no-scrypt no-sctp no-seed no-siphash no-sm2 no-sm3 no-sm4 no-sock no-srtp no-srp \ no-ssl-trace no-ssl3 no-stdio no-tests no-tls no-ts no-unit-test no-uplink no-whirlpool no-zlib \ - --prefix=/usr/local/opt/openssl@3.2.0 darwin64-x86_64-cc + --prefix=/usr/local/opt/openssl@3.6.1 darwin64-x86_64-cc make -j make install_sw diff --git a/scripts/run-demos.sh b/scripts/run-demos.sh index 9b629ee0..d9e70704 100755 --- a/scripts/run-demos.sh +++ b/scripts/run-demos.sh @@ -8,46 +8,166 @@ SCRIPT_PATH="$( )" ROOT_PATH="${SCRIPT_PATH}/.." -DEMOS_CPP_DIR="${ROOT_PATH}/demos-cpp" -DEMOS_GO_DIR="${ROOT_PATH}/demos-go/examples" +DEMOS_CPP_DIR="${ROOT_PATH}/demo-cpp" +DEMO_API_DIR="${ROOT_PATH}/demo-api" +DEMO_GO_DIR="${ROOT_PATH}/demo-go" -CPP_DEMOS=("basic_primitive" "zk") -GO_DEMOS=("access-structure" "agreerandom" "ecdsa-2pc" "ecdsa-mpc-with-backup" "zk") +BUILD_TYPE="${BUILD_TYPE:-Release}" +CBMPC_PREFIX_PUBLIC="${CBMPC_PREFIX_PUBLIC:-${ROOT_PATH}/build/install/public}" +CBMPC_PREFIX_FULL="${CBMPC_PREFIX_FULL:-${ROOT_PATH}/build/install/full}" +# OpenSSL path is used by demos (C++ via CMake; Go via CGO_LDFLAGS). Keep it +# configurable and consistent with `cmake/openssl.cmake`. +CBMPC_OPENSSL_ROOT="${CBMPC_OPENSSL_ROOT:-/usr/local/opt/openssl@3.6.1}" + +CPP_DEMOS=("basic_primitive" "zk" "parallel_transport") +API_DEMOS=("pve" "hd_keyset_ecdsa_2p" "ecdsa_mp_pve_backup" "schnorr_2p_pve_batch_backup") +GO_DEMOS=("pve" "sign" "tdh2" "eddsa_mp_pve_ac_backup") + +go_is_supported() { + if ! command -v go >/dev/null 2>&1; then + echo "Skipping Go demos (go not installed)." + return 1 + fi + + local ver + ver="$(go version 2>/dev/null || true)" + # Example: "go version go1.21.5 darwin/arm64" + if [[ "$ver" =~ go([0-9]+)\.([0-9]+) ]]; then + local major="${BASH_REMATCH[1]}" + local minor="${BASH_REMATCH[2]}" + if (( major > 1 || (major == 1 && minor >= 20) )); then + return 0 + fi + echo "Skipping Go demos (Go >= 1.20 required; found go${major}.${minor})." + return 1 + fi + + # If parsing fails, attempt to run anyway. + return 0 +} + +resolve_openssl_lib_dir() { + local root="$1" + if [[ -z "$root" ]]; then + return 1 + fi + + # Prefer lib64 when present (common on Linux) *and* contains libcrypto. + if [[ -d "${root}/lib64" ]] && compgen -G "${root}/lib64/libcrypto.*" >/dev/null; then + echo "${root}/lib64" + return 0 + fi + if [[ -d "${root}/lib" ]] && compgen -G "${root}/lib/libcrypto.*" >/dev/null; then + echo "${root}/lib" + return 0 + fi + + # Fall back to lib directories even if we can't find a concrete libcrypto + # artifact (caller will still attempt to link and may succeed via defaults). + if [[ -d "${root}/lib64" ]]; then + echo "${root}/lib64" + return 0 + fi + if [[ -d "${root}/lib" ]]; then + echo "${root}/lib" + return 0 + fi + return 1 +} clean() { for proj in ${CPP_DEMOS[@]}; do rm -rf $DEMOS_CPP_DIR/$proj/build/ done + for proj in ${API_DEMOS[@]}; do + rm -rf $DEMO_API_DIR/$proj/build/ + done } build_all_cpp() { for proj in ${CPP_DEMOS[@]}; do cd ${DEMOS_CPP_DIR}/$proj - cmake -Bbuild - cmake --build build/ + cmake -Bbuild/${BUILD_TYPE} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DCBMPC_SOURCE_DIR=${CBMPC_PREFIX_FULL} + cmake --build build/${BUILD_TYPE}/ done } run_all_cpp() { build_all_cpp for proj in ${CPP_DEMOS[@]}; do - ${DEMOS_CPP_DIR}/$proj/build/mpc-demo-$proj + ${DEMOS_CPP_DIR}/$proj/build/${BUILD_TYPE}/mpc-demo-$proj + done +} + +build_all_api() { + for proj in ${API_DEMOS[@]}; do + cd ${DEMO_API_DIR}/$proj + cmake -Bbuild/${BUILD_TYPE} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DCBMPC_SOURCE_DIR=${CBMPC_PREFIX_PUBLIC} + cmake --build build/${BUILD_TYPE}/ + done +} + +run_all_api() { + build_all_api + for proj in ${API_DEMOS[@]}; do + ${DEMO_API_DIR}/$proj/build/${BUILD_TYPE}/mpc-demo-api-$proj done } run_all_go() { - # cd $ROOT_PATH - # make install + if ! go_is_supported; then + return 0 + fi + + local openssl_lib_dir + openssl_lib_dir="$(resolve_openssl_lib_dir "${CBMPC_OPENSSL_ROOT}" || true)" + if [[ -z "${openssl_lib_dir}" ]]; then + echo "Skipping Go demos (OpenSSL not found under CBMPC_OPENSSL_ROOT='${CBMPC_OPENSSL_ROOT}')." + return 0 + fi + for proj in ${GO_DEMOS[@]}; do - run_go_demo $proj + cd ${DEMO_GO_DIR}/$proj + CGO_ENABLED=1 \ + CGO_CFLAGS="-I${CBMPC_PREFIX_PUBLIC}/include" \ + CGO_LDFLAGS="-L${CBMPC_PREFIX_PUBLIC}/lib -L${openssl_lib_dir}" \ + go run . done } +run_cpp_demo() { + local proj="$1" + cd "${DEMOS_CPP_DIR}/${proj}" + cmake -Bbuild/${BUILD_TYPE} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DCBMPC_SOURCE_DIR=${CBMPC_PREFIX_FULL} + cmake --build build/${BUILD_TYPE}/ + "${DEMOS_CPP_DIR}/${proj}/build/${BUILD_TYPE}/mpc-demo-${proj}" +} + +run_api_demo() { + local proj="$1" + cd "${DEMO_API_DIR}/${proj}" + cmake -Bbuild/${BUILD_TYPE} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DCBMPC_SOURCE_DIR=${CBMPC_PREFIX_PUBLIC} + cmake --build build/${BUILD_TYPE}/ + "${DEMO_API_DIR}/${proj}/build/${BUILD_TYPE}/mpc-demo-api-${proj}" +} + run_go_demo() { - cd $DEMOS_GO_DIR/$1 - go mod tidy - # Ensure CGO uses the locally built C++ lib and auto-rebuilds if needed - (cd "$ROOT_PATH" && BUILD_TYPE=${BUILD_TYPE:-Release} bash scripts/go_with_cpp.sh --no-cd bash -lc "cd '$DEMOS_GO_DIR/$1' && env CGO_ENABLED=1 go run *.go") + local proj="$1" + if ! go_is_supported; then + return 0 + fi + + local openssl_lib_dir + openssl_lib_dir="$(resolve_openssl_lib_dir "${CBMPC_OPENSSL_ROOT}" || true)" + if [[ -z "${openssl_lib_dir}" ]]; then + echo "Skipping Go demo '${proj}' (OpenSSL not found under CBMPC_OPENSSL_ROOT='${CBMPC_OPENSSL_ROOT}')." + return 0 + fi + cd "${DEMO_GO_DIR}/${proj}" + CGO_ENABLED=1 \ + CGO_CFLAGS="-I${CBMPC_PREFIX_PUBLIC}/include" \ + CGO_LDFLAGS="-L${CBMPC_PREFIX_PUBLIC}/lib -L${openssl_lib_dir}" \ + go run . } POSITIONAL_ARGS=() @@ -56,12 +176,25 @@ while [[ $# -gt 0 ]]; do case $1 in --run-all) run_all_cpp + run_all_api run_all_go shift # past argument ;; --run) TEST_NAME="$2" - run_go_demo $TEST_NAME + run_cpp_demo "$TEST_NAME" + shift # past argument + shift # past value + ;; + --run-api) + TEST_NAME="$2" + run_api_demo "$TEST_NAME" + shift # past argument + shift # past value + ;; + --run-go) + TEST_NAME="$2" + run_go_demo "$TEST_NAME" shift # past argument shift # past value ;; diff --git a/src/cbmpc/api/CMakeLists.txt b/src/cbmpc/api/CMakeLists.txt new file mode 100644 index 00000000..043fdc81 --- /dev/null +++ b/src/cbmpc/api/CMakeLists.txt @@ -0,0 +1,19 @@ +add_library(cbmpc_api OBJECT "") + +target_sources(cbmpc_api PRIVATE + eddsa_mp.cpp + eddsa2pc.cpp + ecdsa_mp.cpp + ecdsa2pc.cpp + hd_keyset_eddsa_2p.cpp + hd_keyset_ecdsa_2p.cpp + pve_base_pke.cpp + pve_batch_ac.cpp + pve_batch_single_recipient.cpp + schnorr_mp.cpp + schnorr2pc.cpp + tdh2.cpp +) + +target_link_libraries(cbmpc_api cbmpc_core cbmpc_crypto cbmpc_protocol) + diff --git a/src/cbmpc/api/access_structure_util.h b/src/cbmpc/api/access_structure_util.h new file mode 100644 index 00000000..da3b3488 --- /dev/null +++ b/src/cbmpc/api/access_structure_util.h @@ -0,0 +1,227 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace coinbase::api::detail { + +inline constexpr size_t MAX_ACCESS_STRUCTURE_DEPTH = 128; +inline constexpr size_t MAX_ACCESS_STRUCTURE_NODES = 4096; + +inline error_t validate_access_structure_node_impl(const access_structure_t& n, size_t depth, size_t& nodes_seen) { + if (++nodes_seen > MAX_ACCESS_STRUCTURE_NODES) return coinbase::error(E_RANGE, "access_structure: too many nodes"); + if (depth > MAX_ACCESS_STRUCTURE_DEPTH) return coinbase::error(E_RANGE, "access_structure: too deep"); + + switch (n.type) { + case access_structure_t::node_type::leaf: + if (n.leaf_name.empty()) return coinbase::error(E_BADARG, "access_structure: leaf name must be non-empty"); + if (!n.children.empty()) return coinbase::error(E_BADARG, "access_structure: leaf node must not have children"); + if (n.threshold_k != 0) return coinbase::error(E_BADARG, "access_structure: leaf node must not set threshold_k"); + return SUCCESS; + + case access_structure_t::node_type::and_node: + case access_structure_t::node_type::or_node: + if (n.children.empty()) return coinbase::error(E_BADARG, "access_structure: AND/OR node must have children"); + if (!n.leaf_name.empty()) + return coinbase::error(E_BADARG, "access_structure: internal node must not set leaf_name"); + if (n.threshold_k != 0) + return coinbase::error(E_BADARG, "access_structure: AND/OR node must not set threshold_k"); + break; + + case access_structure_t::node_type::threshold: + if (n.children.empty()) return coinbase::error(E_BADARG, "access_structure: THRESHOLD node must have children"); + if (!n.leaf_name.empty()) + return coinbase::error(E_BADARG, "access_structure: internal node must not set leaf_name"); + if (n.threshold_k < 1) return coinbase::error(E_BADARG, "access_structure: invalid threshold_k"); + if (static_cast(n.threshold_k) > n.children.size()) + return coinbase::error(E_BADARG, "access_structure: threshold_k > children.size()"); + break; + } + + for (const auto& ch : n.children) { + const error_t rv = validate_access_structure_node_impl(ch, depth + 1, nodes_seen); + if (rv) return rv; + } + return SUCCESS; +} + +inline error_t validate_access_structure_node(const access_structure_t& n) { + size_t nodes_seen = 0; + return validate_access_structure_node_impl(n, /*depth=*/0, nodes_seen); +} + +inline error_t collect_leaf_names(const access_structure_t& n, std::set& out) { + struct frame_t { + const access_structure_t* node = nullptr; + size_t depth = 0; + }; + + size_t nodes_seen = 0; + std::vector stack; + stack.reserve(64); + stack.push_back(frame_t{&n, 0}); + + while (!stack.empty()) { + const frame_t frame = stack.back(); + stack.pop_back(); + + if (!frame.node) return coinbase::error(E_BADARG, "access_structure: invalid node"); + if (++nodes_seen > MAX_ACCESS_STRUCTURE_NODES) return coinbase::error(E_RANGE, "access_structure: too many nodes"); + if (frame.depth > MAX_ACCESS_STRUCTURE_DEPTH) return coinbase::error(E_RANGE, "access_structure: too deep"); + + if (frame.node->type == access_structure_t::node_type::leaf) { + out.insert(std::string(frame.node->leaf_name)); + continue; + } + + for (const auto& ch : frame.node->children) { + stack.push_back(frame_t{&ch, frame.depth + 1}); + } + } + + return SUCCESS; +} + +inline std::string gen_internal_node_name(std::unordered_set& used, uint64_t& counter) { + // Generate deterministic unique names that do not collide with leaf names. + while (true) { + counter++; + std::string name = "__cbmpc_ac_node_" + std::to_string(counter); + if (used.insert(name).second) return name; + } +} + +inline error_t build_internal_ac_node(const access_structure_t& in, size_t depth, bool is_root, + std::unordered_set& used_names, uint64_t& name_counter, + coinbase::crypto::ss::node_t*& out) { + out = nullptr; + if (depth > MAX_ACCESS_STRUCTURE_DEPTH) return coinbase::error(E_RANGE, "access_structure: too deep"); + + using coinbase::crypto::ss::node_e; + using coinbase::crypto::ss::node_t; + + switch (in.type) { + case access_structure_t::node_type::leaf: { + // Leaves must be named. (Root leaf is not supported; root is unnamed.) + if (is_root) return coinbase::error(E_BADARG, "access_structure: root cannot be a leaf node"); + auto* node = new node_t(node_e::LEAF, std::string(in.leaf_name)); + out = node; + return SUCCESS; + } + + case access_structure_t::node_type::and_node: + case access_structure_t::node_type::or_node: + case access_structure_t::node_type::threshold: { + if (in.children.empty()) return coinbase::error(E_BADARG, "access_structure: internal node missing children"); + + const node_e t = (in.type == access_structure_t::node_type::and_node) ? node_e::AND + : (in.type == access_structure_t::node_type::or_node) ? node_e::OR + : node_e::THRESHOLD; + + const int k = (in.type == access_structure_t::node_type::threshold) ? in.threshold_k : 0; + + const std::string node_name = is_root ? std::string() : gen_internal_node_name(used_names, name_counter); + auto* node = new node_t(t, node_name, k); + + for (const auto& ch : in.children) { + node_t* child = nullptr; + const error_t rv = build_internal_ac_node(ch, depth + 1, /*is_root=*/false, used_names, name_counter, child); + if (rv) { + delete node; // deletes any already-added children + return rv; + } + node->add_child_node(child); + } + + out = node; + return SUCCESS; + } + } + + return coinbase::error(E_BADARG, "access_structure: invalid node type"); +} + +inline error_t to_internal_access_structure(const access_structure_t& in, + const std::vector& party_names, + coinbase::crypto::ecurve_t curve, coinbase::crypto::ss::ac_owned_t& out) { + // Clear any existing tree. + delete out.root; + out.root = nullptr; + + if (!curve.valid()) return coinbase::error(E_BADARG, "access_structure: invalid curve"); + + // Basic shape validation (independent of job). + error_t rv = validate_access_structure_node(in); + if (rv) return rv; + if (in.type == access_structure_t::node_type::leaf) + return coinbase::error(E_BADARG, "access_structure: root cannot be leaf"); + + // Validate that leaf set matches job.party_names exactly. + std::set leaf_names; + rv = collect_leaf_names(in, leaf_names); + if (rv) return rv; + + std::set party_set; + for (const auto& name_view : party_names) party_set.insert(std::string(name_view)); + + if (leaf_names != party_set) + return coinbase::error(E_BADARG, "access_structure: leaf names must match job.party_names exactly"); + + // Build internal node tree with generated internal node names. + std::unordered_set used; + used.reserve(leaf_names.size() * 2 + 8); + used.insert(std::string()); // root name + for (const auto& name : leaf_names) used.insert(name); + + uint64_t counter = 0; + coinbase::crypto::ss::node_t* root = nullptr; + rv = build_internal_ac_node(in, /*depth=*/0, /*is_root=*/true, used, counter, root); + if (rv) return rv; + + out.curve = curve; + out.root = root; + + rv = out.validate_tree(); + if (rv) { + delete out.root; + out.root = nullptr; + return rv; + } + + return SUCCESS; +} + +inline error_t to_internal_party_set(const std::vector& party_names, + const std::vector& quorum_party_names, + coinbase::mpc::party_set_t& out) { + out = coinbase::mpc::party_set_t::empty(); + if (quorum_party_names.empty()) return coinbase::error(E_BADARG, "quorum_party_names must be non-empty"); + + std::unordered_map index_by_name; + index_by_name.reserve(party_names.size()); + for (size_t i = 0; i < party_names.size(); i++) index_by_name.emplace(party_names[i], static_cast(i)); + + std::unordered_set seen; + seen.reserve(quorum_party_names.size()); + for (const auto& qn : quorum_party_names) { + if (!seen.insert(qn).second) return coinbase::error(E_BADARG, "duplicate quorum party name"); + const auto it = index_by_name.find(qn); + if (it == index_by_name.end()) return coinbase::error(E_BADARG, "unknown quorum party name"); + out.add(it->second); + } + + return SUCCESS; +} + +} // namespace coinbase::api::detail diff --git a/src/cbmpc/api/curve_util.h b/src/cbmpc/api/curve_util.h new file mode 100644 index 00000000..ad4443f8 --- /dev/null +++ b/src/cbmpc/api/curve_util.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include + +namespace coinbase::api::detail { + +// Map public curve identifiers to internal curve objects. +inline coinbase::crypto::ecurve_t to_internal_curve(curve_id curve) { + switch (curve) { + case curve_id::p256: + return coinbase::crypto::curve_p256; + case curve_id::secp256k1: + return coinbase::crypto::curve_secp256k1; + case curve_id::ed25519: + return coinbase::crypto::curve_ed25519; + } + return coinbase::crypto::ecurve_t(); +} + +// Map internal curves back to public curve identifiers. +// +// Note: Only curves supported by the public API wrappers are mapped here. +inline bool from_internal_curve(coinbase::crypto::ecurve_t curve, curve_id& out) { + if (curve == coinbase::crypto::curve_p256) { + out = curve_id::p256; + return true; + } + if (curve == coinbase::crypto::curve_secp256k1) { + out = curve_id::secp256k1; + return true; + } + if (curve == coinbase::crypto::curve_ed25519) { + out = curve_id::ed25519; + return true; + } + return false; +} + +} // namespace coinbase::api::detail diff --git a/src/cbmpc/api/ecdsa2pc.cpp b/src/cbmpc/api/ecdsa2pc.cpp new file mode 100644 index 00000000..f9603f34 --- /dev/null +++ b/src/cbmpc/api/ecdsa2pc.cpp @@ -0,0 +1,302 @@ +#include +#include +#include + +#include "curve_util.h" +#include "job_util.h" +#include "mem_util.h" + +namespace coinbase::api::ecdsa_2p { + +namespace { + +constexpr uint32_t key_blob_version_v1 = 1; + +using coinbase::api::detail::from_internal_curve; +using coinbase::api::detail::to_internal_curve; +using coinbase::api::detail::to_internal_job; +using coinbase::api::detail::to_internal_party; +using coinbase::api::detail::validate_job_2p; + +struct key_blob_v1_t { + uint32_t version = key_blob_version_v1; + uint32_t role = 0; // 0=p1, 1=p2 + uint32_t curve = 0; // coinbase::api::curve_id + + buf_t Q_compressed; + coinbase::crypto::bn_t x_share; + coinbase::crypto::bn_t c_key; + coinbase::crypto::paillier_t paillier; + + void convert(coinbase::converter_t& c) { c.convert(version, role, curve, Q_compressed, x_share, c_key, paillier); } +}; + +error_t blob_to_key(const key_blob_v1_t& blob, coinbase::mpc::ecdsa2pc::key_t& key) { + if (blob.role > 1) return coinbase::error(E_FORMAT, "invalid key blob role"); + + const auto cid = static_cast(blob.curve); + if (cid == curve_id::ed25519) return coinbase::error(E_FORMAT, "invalid key blob curve"); + auto curve = to_internal_curve(cid); + if (!curve.valid()) return coinbase::error(E_FORMAT, "invalid key blob curve"); + + key.role = static_cast(static_cast(blob.role)); + key.curve = curve; + + // Defensive validation at the opaque blob boundary. + // + // In our ECDSA-2PC protocol, party P1 owns the Paillier private key, and `c_key` is an encryption of P1's share under + // that key. Reject malformed / tampered blobs early. + const bool paillier_has_private = blob.paillier.has_private_key(); + if ((blob.role == 0) != paillier_has_private) return coinbase::error(E_FORMAT, "invalid key blob"); + + const auto& N = blob.paillier.get_N(); + if (N.value().get_bits_count() < coinbase::crypto::paillier_t::bit_size) { + return coinbase::error(E_FORMAT, "invalid key blob"); + } + + // Intentionally do not enforce `x_share in [0, q)` here: + // in ECDSA-2PC this share is maintained as a Paillier-compatible integer representative and can be + // unreduced after refresh, so rejecting non-reduced values would break valid refreshed key blobs. + // + // However, `x_share` must remain in Z_N so that Paillier-related operations are well-defined and to avoid + // attacker-controlled bignum blowups. + if (!N.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + + // Ensure `c_key` is a well-formed Paillier ciphertext under this key. + { + coinbase::crypto::vartime_scope_t vartime_scope; + if (blob.paillier.verify_cipher(blob.c_key)) return coinbase::error(E_FORMAT, "invalid key blob"); + } + + // If we have the private key (P1), bind the share to its Paillier encryption. + if (paillier_has_private) { + const coinbase::crypto::bn_t plain = blob.paillier.decrypt(blob.c_key); + if (plain != N.mod(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + } + + key.x_share = blob.x_share; + key.c_key = blob.c_key; + key.paillier = blob.paillier; + + if (const error_t rv = key.Q.from_bin(curve, blob.Q_compressed)) return rv; + if (curve.check(key.Q)) return coinbase::error(E_FORMAT, "invalid key blob"); + return SUCCESS; +} + +error_t serialize_key_blob(const coinbase::mpc::ecdsa2pc::key_t& key, buf_t& out) { + curve_id cid; + if (!from_internal_curve(key.curve, cid)) return coinbase::error(E_BADARG, "unsupported curve"); + if (cid == curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + + key_blob_v1_t blob; + blob.role = static_cast(key.role); + blob.curve = static_cast(cid); + blob.Q_compressed = key.Q.to_compressed_bin(); + blob.x_share = key.x_share; + blob.c_key = key.c_key; + blob.paillier = key.paillier; + out = coinbase::convert(blob); + return SUCCESS; +} + +error_t deserialize_key_blob(mem_t in, coinbase::mpc::ecdsa2pc::key_t& key) { + // Reject unsupported versions before attempting to parse variable-length fields. + // This prevents mis-parsing newer blob versions with incompatible encodings. + if (in.size < 4) return coinbase::error(E_FORMAT, "invalid key blob"); + const uint32_t version = coinbase::be_get_4(in.data); + if (version != key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported key blob version"); + + key_blob_v1_t blob; + const error_t rv = coinbase::convert(blob, in); + if (rv) return rv; + if (blob.version != key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported key blob version"); + return blob_to_key(blob, key); +} + +} // namespace + +error_t dkg(const coinbase::api::job_2p_t& job, curve_id curve, buf_t& key_blob) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (curve == curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + coinbase::mpc::ecdsa2pc::key_t key; + const error_t rv = coinbase::mpc::ecdsa2pc::dkg(mpc_job, icurve, key); + if (rv) return rv; + + return serialize_key_blob(key, key_blob); +} + +error_t refresh(const coinbase::api::job_2p_t& job, mem_t key_blob, buf_t& new_key_blob) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::ecdsa2pc::key_t key; + error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + + const auto self = to_internal_party(job.self); + if (key.role != self) return coinbase::error(E_BADARG, "job.self mismatch key blob role"); + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + coinbase::mpc::ecdsa2pc::key_t new_key; + rv = coinbase::mpc::ecdsa2pc::refresh(mpc_job, key, new_key); + if (rv) return rv; + + return serialize_key_blob(new_key, new_key_blob); +} + +using internal_sign_fn_t = error_t (*)(coinbase::mpc::job_2p_t&, buf_t& /*sid*/, const coinbase::mpc::ecdsa2pc::key_t&, + const mem_t /*msg*/, buf_t& /*sig_der*/); + +static error_t sign_common(internal_sign_fn_t fn, const coinbase::api::job_2p_t& job, mem_t key_blob, mem_t msg_hash, + buf_t& sid, buf_t& sig_der) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + msg_hash, "msg_hash", coinbase::api::detail::MAX_MESSAGE_DIGEST_SIZE)) + return rv; + coinbase::mpc::ecdsa2pc::key_t key; + error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + + const auto self = to_internal_party(job.self); + if (key.role != self) return coinbase::error(E_BADARG, "job.self mismatch key blob role"); + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + sig_der.free(); + return fn(mpc_job, sid, key, msg_hash, sig_der); +} + +error_t sign(const coinbase::api::job_2p_t& job, mem_t key_blob, mem_t msg_hash, buf_t& sid, buf_t& sig_der) { + return sign_common(&coinbase::mpc::ecdsa2pc::sign, job, key_blob, msg_hash, sid, sig_der); +} + +error_t get_public_key_compressed(mem_t key_blob, buf_t& pub_key) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::ecdsa2pc::key_t key; + const error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + pub_key = key.Q.to_compressed_bin(); + return SUCCESS; +} + +error_t get_public_share_compressed(mem_t key_blob, buf_t& out_public_share_compressed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::ecdsa2pc::key_t key; + error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + + const auto& q = key.curve.order(); + const coinbase::crypto::bn_t x_mod_q = key.x_share % q; + out_public_share_compressed = (x_mod_q * key.curve.generator()).to_compressed_bin(); + return SUCCESS; +} + +error_t detach_private_scalar(mem_t key_blob, buf_t& out_public_key_blob, buf_t& out_private_scalar) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::ecdsa2pc::key_t key; + error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + + curve_id cid; + if (!from_internal_curve(key.curve, cid)) return coinbase::error(E_BADARG, "unsupported curve"); + if (cid == curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + + // Variable-length big-endian encoding (may grow after refresh). + out_private_scalar = key.x_share.to_bin(); + + // Produce a v1-format blob with an invalid (out-of-range) scalar share so it is + // rejected by sign/refresh APIs. + key_blob_v1_t pub; + pub.role = static_cast(key.role); + pub.curve = static_cast(cid); + pub.Q_compressed = key.Q.to_compressed_bin(); + pub.x_share = key.paillier.get_N().value(); // x_share == N is out of range + pub.c_key = key.c_key; + pub.paillier = key.paillier; + out_public_key_blob = coinbase::convert(pub); + return SUCCESS; +} + +error_t attach_private_scalar(mem_t public_key_blob, mem_t private_scalar, mem_t public_share_compressed, + buf_t& out_key_blob) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(public_key_blob, "public_key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + key_blob_v1_t pub; + error_t rv = coinbase::convert(pub, public_key_blob); + if (rv) return rv; + if (pub.version != key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (pub.role > 1) return coinbase::error(E_FORMAT, "invalid key blob role"); + + const auto cid = static_cast(pub.curve); + if (cid == curve_id::ed25519) return coinbase::error(E_FORMAT, "invalid key blob curve"); + auto curve = to_internal_curve(cid); + if (!curve.valid()) return coinbase::error(E_FORMAT, "invalid key blob curve"); + + if (const error_t rvm = coinbase::api::detail::validate_mem_arg(private_scalar, "private_scalar")) return rvm; + if (private_scalar.size == 0) return coinbase::error(E_BADARG, "private_scalar must be non-empty"); + if (const error_t rvp = coinbase::api::detail::validate_mem_arg(public_share_compressed, "public_share_compressed")) + return rvp; + + // Validate Paillier material (c_key + key) is well-formed. + const bool paillier_has_private = pub.paillier.has_private_key(); + if ((pub.role == 0) != paillier_has_private) return coinbase::error(E_FORMAT, "invalid key blob"); + + const auto& N = pub.paillier.get_N(); + if (N.value().get_bits_count() < coinbase::crypto::paillier_t::bit_size) + return coinbase::error(E_FORMAT, "invalid key blob"); + { + coinbase::crypto::vartime_scope_t vartime_scope; + if (pub.paillier.verify_cipher(pub.c_key)) return coinbase::error(E_FORMAT, "invalid key blob"); + } + + // Interpret scalar and ensure it stays in Z_N (matching key blob invariants). + coinbase::crypto::bn_t x_share = coinbase::crypto::bn_t::from_bin(private_scalar); + if (!N.is_in_range(x_share)) return coinbase::error(E_FORMAT, "invalid private_scalar"); + + // If we have the private key (P1), bind the share to its Paillier encryption. + if (paillier_has_private) { + coinbase::crypto::vartime_scope_t vartime_scope; + const coinbase::crypto::bn_t plain = pub.paillier.decrypt(pub.c_key); + if (plain != N.mod(x_share)) return coinbase::error(E_FORMAT, "x_share mismatch key blob"); + } + + // Verify scalar matches the provided self-share point. + const coinbase::crypto::mod_t& q = curve.order(); + const coinbase::crypto::bn_t x_mod_q = x_share % q; + if (!q.is_in_range(x_mod_q)) return coinbase::error(E_FORMAT, "invalid private_scalar"); + + coinbase::crypto::ecc_point_t Qi_self(curve); + if (rv = Qi_self.from_bin(curve, public_share_compressed)) + return coinbase::error(rv, "invalid public_share_compressed"); + if (rv = curve.check(Qi_self)) return coinbase::error(rv, "invalid public_share_compressed"); + if (x_mod_q * curve.generator() != Qi_self) return coinbase::error(E_FORMAT, "x_share mismatch key blob"); + + // Validate and normalize global public key encoding. + coinbase::crypto::ecc_point_t Q(curve); + if (rv = Q.from_bin(curve, pub.Q_compressed)) return coinbase::error(rv, "invalid key blob"); + if (rv = curve.check(Q)) return coinbase::error(rv, "invalid key blob"); + + pub.Q_compressed = Q.to_compressed_bin(); + pub.x_share = std::move(x_share); + out_key_blob = coinbase::convert(pub); + return SUCCESS; +} + +} // namespace coinbase::api::ecdsa_2p diff --git a/src/cbmpc/api/ecdsa_mp.cpp b/src/cbmpc/api/ecdsa_mp.cpp new file mode 100644 index 00000000..04ac8c61 --- /dev/null +++ b/src/cbmpc/api/ecdsa_mp.cpp @@ -0,0 +1,552 @@ +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "access_structure_util.h" +#include "curve_util.h" +#include "job_util.h" +#include "mem_util.h" + +namespace coinbase::api::ecdsa_mp { + +namespace { + +constexpr uint32_t key_blob_version_v1 = 1; +constexpr uint32_t ac_key_blob_version_v1 = 2; + +using coinbase::api::detail::to_internal_curve; +using coinbase::api::detail::to_internal_job; +using coinbase::api::detail::validate_job_mp; + +struct key_blob_v1_t { + uint32_t version = key_blob_version_v1; + uint32_t curve = 0; // coinbase::api::curve_id + + std::string party_name; // self identity (name-bound, not index-bound) + + buf_t Q_compressed; + std::map Qis_compressed; // name -> compressed Qi + + coinbase::crypto::bn_t x_share; + + void convert(coinbase::converter_t& c) { + c.convert(version, curve, party_name, Q_compressed, Qis_compressed, x_share); + } +}; + +static error_t parse_key_blob_any_version(mem_t in, key_blob_v1_t& out_blob, coinbase::api::curve_id& out_curve_id, + coinbase::crypto::ecurve_t& out_curve) { + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(in, "key_blob", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + error_t rv = UNINITIALIZED_ERROR; + + if (rv = coinbase::convert(out_blob, in)) return rv; + if (out_blob.version != key_blob_version_v1 && out_blob.version != ac_key_blob_version_v1) + return coinbase::error(E_FORMAT, "unsupported key blob version"); + + const auto cid = static_cast(out_blob.curve); + if (cid == curve_id::ed25519) return coinbase::error(E_FORMAT, "invalid key blob curve"); + const auto curve = to_internal_curve(cid); + if (!curve.valid()) return coinbase::error(E_FORMAT, "invalid key blob curve"); + + out_curve_id = cid; + out_curve = curve; + return SUCCESS; +} + +static error_t get_self_Qi_compressed_from_blob(const key_blob_v1_t& blob, buf_t& out_Qi_self_compressed) { + if (blob.party_name.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + const auto it = blob.Qis_compressed.find(blob.party_name); + if (it == blob.Qis_compressed.end()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (it->second.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + out_Qi_self_compressed = it->second; + return SUCCESS; +} + +static error_t serialize_key_blob_for_party_names(const std::vector& party_names, + const std::string& self_name, + const coinbase::mpc::ecdsampc::key_t& key, uint32_t version, + buf_t& out) { + // Curve: only allow public wrapper-supported curves. + curve_id cid; + if (!coinbase::api::detail::from_internal_curve(key.curve, cid)) + return coinbase::error(E_BADARG, "unsupported curve"); + if (cid == curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + + const std::string_view self_sv(self_name); + bool self_in_party_names = false; + for (const auto& name_view : party_names) { + if (name_view == self_sv) { + self_in_party_names = true; + break; + } + } + if (!self_in_party_names) return coinbase::error(E_BADARG, "self_name not in party_names"); + + if (key.party_name != self_name) return coinbase::error(E_BADARG, "job.self mismatch key"); + + key_blob_v1_t blob; + blob.version = version; + blob.curve = static_cast(cid); + blob.party_name = key.party_name; + blob.Q_compressed = key.Q.to_compressed_bin(); + blob.x_share = key.x_share; + + for (const auto& name_view : party_names) { + const std::string name(name_view); + const auto it = key.Qis.find(name); + if (it == key.Qis.end()) return coinbase::error(E_FORMAT, "key missing Qi"); + blob.Qis_compressed[name] = it->second.to_compressed_bin(); + } + + out = coinbase::convert(blob); + return SUCCESS; +} + +static error_t serialize_key_blob(const coinbase::api::job_mp_t& job, const coinbase::mpc::ecdsampc::key_t& key, + buf_t& out) { + if (job.self < 0 || static_cast(job.self) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid job.self"); + const std::string self_name(job.party_names[static_cast(job.self)]); + return serialize_key_blob_for_party_names(job.party_names, self_name, key, key_blob_version_v1, out); +} + +static error_t serialize_ac_key_blob(const coinbase::api::job_mp_t& job, const coinbase::mpc::ecdsampc::key_t& key, + buf_t& out) { + if (job.self < 0 || static_cast(job.self) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid job.self"); + const std::string self_name(job.party_names[static_cast(job.self)]); + return serialize_key_blob_for_party_names(job.party_names, self_name, key, ac_key_blob_version_v1, out); +} + +static error_t deserialize_key_blob(const coinbase::api::job_mp_t& job, mem_t in, coinbase::mpc::ecdsampc::key_t& key) { + error_t rv = UNINITIALIZED_ERROR; + + if (job.self < 0 || static_cast(job.self) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid job.self"); + const std::string self_name(job.party_names[static_cast(job.self)]); + + key_blob_v1_t blob; + if (rv = coinbase::convert(blob, in)) return rv; + if (blob.version != key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (blob.party_name.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.party_name != self_name) return coinbase::error(E_BADARG, "job.self mismatch key blob"); + if (job.party_names.size() != blob.Qis_compressed.size()) return coinbase::error(E_BADARG, "invalid key blob"); + + // Ensure the party name set matches the job (order can differ). + for (const auto& name_view : job.party_names) { + const std::string name(name_view); + if (blob.Qis_compressed.find(name) == blob.Qis_compressed.end()) + return coinbase::error(E_BADARG, "job.party_names mismatch key blob"); + } + + const auto cid = static_cast(blob.curve); + if (cid == curve_id::ed25519) return coinbase::error(E_FORMAT, "invalid key blob curve"); + const auto curve = to_internal_curve(cid); + if (!curve.valid()) return coinbase::error(E_FORMAT, "invalid key blob curve"); + + const coinbase::crypto::mod_t& q = curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + + coinbase::crypto::ecc_point_t Q; + if (rv = Q.from_bin(curve, blob.Q_compressed)) return coinbase::error(rv, "invalid key blob"); + + coinbase::crypto::ss::party_map_t Qis; + for (const auto& name_view : job.party_names) { + const std::string name(name_view); + const auto it = blob.Qis_compressed.find(name); + if (it == blob.Qis_compressed.end()) return coinbase::error(E_BADARG, "job.party_names mismatch key blob"); + + coinbase::crypto::ecc_point_t Qi; + if (rv = Qi.from_bin(curve, it->second)) return coinbase::error(rv, "invalid key blob"); + Qis[name] = std::move(Qi); + } + + coinbase::crypto::ecc_point_t Q_sum = curve.infinity(); + for (const auto& kv : Qis) Q_sum += kv.second; + if (Q != Q_sum) return coinbase::error(E_FORMAT, "invalid key blob"); + + const auto& G = curve.generator(); + const auto it_self = Qis.find(blob.party_name); + if (it_self == Qis.end()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.x_share * G != it_self->second) return coinbase::error(E_FORMAT, "invalid key blob"); + + key.party_name = blob.party_name; + key.curve = curve; + key.x_share = blob.x_share; + key.Qis = std::move(Qis); + key.Q = std::move(Q); + return SUCCESS; +} + +static error_t deserialize_ac_key_blob(const coinbase::api::job_mp_t& job, mem_t in, + coinbase::mpc::ecdsampc::key_t& key) { + error_t rv = UNINITIALIZED_ERROR; + + if (job.self < 0 || static_cast(job.self) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid job.self"); + const std::string self_name(job.party_names[static_cast(job.self)]); + + key_blob_v1_t blob; + if (rv = coinbase::convert(blob, in)) return rv; + if (blob.version != ac_key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (blob.party_name.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.party_name != self_name) return coinbase::error(E_BADARG, "job.self mismatch key blob"); + if (job.party_names.size() != blob.Qis_compressed.size()) return coinbase::error(E_BADARG, "invalid key blob"); + + // Ensure the party name set matches the job (order can differ). + for (const auto& name_view : job.party_names) { + const std::string name(name_view); + if (blob.Qis_compressed.find(name) == blob.Qis_compressed.end()) + return coinbase::error(E_BADARG, "job.party_names mismatch key blob"); + } + + const auto cid = static_cast(blob.curve); + if (cid == curve_id::ed25519) return coinbase::error(E_FORMAT, "invalid key blob curve"); + const auto curve = to_internal_curve(cid); + if (!curve.valid()) return coinbase::error(E_FORMAT, "invalid key blob curve"); + + const coinbase::crypto::mod_t& q = curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + + coinbase::crypto::ecc_point_t Q; + if (rv = Q.from_bin(curve, blob.Q_compressed)) return coinbase::error(rv, "invalid key blob"); + + coinbase::crypto::ss::party_map_t Qis; + for (const auto& name_view : job.party_names) { + const std::string name(name_view); + const auto it = blob.Qis_compressed.find(name); + if (it == blob.Qis_compressed.end()) return coinbase::error(E_BADARG, "job.party_names mismatch key blob"); + + coinbase::crypto::ecc_point_t Qi; + if (rv = Qi.from_bin(curve, it->second)) return coinbase::error(rv, "invalid key blob"); + Qis[name] = std::move(Qi); + } + + // Access-structure key blobs are validated using the access structure at use sites. + // Here we only enforce the self-share binding. + const auto& G = curve.generator(); + const auto it_self = Qis.find(blob.party_name); + if (it_self == Qis.end()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.x_share * G != it_self->second) return coinbase::error(E_FORMAT, "invalid key blob"); + + key.party_name = blob.party_name; + key.curve = curve; + key.x_share = blob.x_share; + key.Qis = std::move(Qis); + key.Q = std::move(Q); + return SUCCESS; +} + +static error_t deserialize_ac_key_blob(mem_t in, coinbase::mpc::ecdsampc::key_t& key) { + error_t rv = UNINITIALIZED_ERROR; + + key_blob_v1_t blob; + if (rv = coinbase::convert(blob, in)) return rv; + if (blob.version != ac_key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (blob.party_name.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.Qis_compressed.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + + const auto cid = static_cast(blob.curve); + if (cid == curve_id::ed25519) return coinbase::error(E_FORMAT, "invalid key blob curve"); + const auto curve = to_internal_curve(cid); + if (!curve.valid()) return coinbase::error(E_FORMAT, "invalid key blob curve"); + + const coinbase::crypto::mod_t& q = curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + + coinbase::crypto::ecc_point_t Q; + if (rv = Q.from_bin(curve, blob.Q_compressed)) return coinbase::error(rv, "invalid key blob"); + + coinbase::crypto::ss::party_map_t Qis; + for (const auto& kv : blob.Qis_compressed) { + coinbase::crypto::ecc_point_t Qi; + if (rv = Qi.from_bin(curve, kv.second)) return coinbase::error(rv, "invalid key blob"); + Qis[kv.first] = std::move(Qi); + } + + const auto& G = curve.generator(); + const auto it_self = Qis.find(blob.party_name); + if (it_self == Qis.end()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.x_share * G != it_self->second) return coinbase::error(E_FORMAT, "invalid key blob"); + + key.party_name = blob.party_name; + key.curve = curve; + key.x_share = blob.x_share; + key.Qis = std::move(Qis); + key.Q = std::move(Q); + return SUCCESS; +} + +} // namespace + +error_t dkg_additive(const job_mp_t& job, curve_id curve, buf_t& key_blob, buf_t& sid) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + + if (curve == curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + coinbase::mpc::ecdsampc::key_t key; + sid.free(); + key_blob.free(); + + rv = coinbase::mpc::ecdsampc::dkg(mpc_job, icurve, key, sid); + if (rv) return rv; + + return serialize_key_blob(job, key, key_blob); +} + +error_t dkg_ac(const job_mp_t& job, curve_id curve, buf_t& sid, const access_structure_t& access_structure, + const std::vector& quorum_party_names, buf_t& key_blob) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + + if (curve == curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::crypto::ss::ac_owned_t ac; + rv = coinbase::api::detail::to_internal_access_structure(access_structure, job.party_names, icurve, ac); + if (rv) return rv; + + coinbase::mpc::party_set_t quorum_party_set; + rv = coinbase::api::detail::to_internal_party_set(job.party_names, quorum_party_names, quorum_party_set); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + coinbase::mpc::ecdsampc::key_t key; + key_blob.free(); + + rv = coinbase::mpc::ecdsampc::dkg_ac(mpc_job, icurve, sid, ac, quorum_party_set, key); + if (rv) return rv; + + return serialize_ac_key_blob(job, key, key_blob); +} + +error_t refresh_additive(const job_mp_t& job, buf_t& sid, mem_t key_blob, buf_t& new_key_blob) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + + coinbase::mpc::ecdsampc::key_t key; + rv = deserialize_key_blob(job, key_blob, key); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + coinbase::mpc::ecdsampc::key_t new_key; + new_key_blob.free(); + rv = coinbase::mpc::ecdsampc::refresh(mpc_job, sid, key, new_key); + if (rv) return rv; + + return serialize_key_blob(job, new_key, new_key_blob); +} + +error_t refresh_ac(const job_mp_t& job, buf_t& sid, mem_t key_blob, const access_structure_t& access_structure, + const std::vector& quorum_party_names, buf_t& new_key_blob) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + + coinbase::mpc::ecdsampc::key_t key; + rv = deserialize_ac_key_blob(job, key_blob, key); + if (rv) return rv; + + coinbase::crypto::ss::ac_owned_t ac; + rv = coinbase::api::detail::to_internal_access_structure(access_structure, job.party_names, key.curve, ac); + if (rv) return rv; + + coinbase::mpc::party_set_t quorum_party_set; + rv = coinbase::api::detail::to_internal_party_set(job.party_names, quorum_party_names, quorum_party_set); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + coinbase::mpc::ecdsampc::key_t new_key; + new_key_blob.free(); + rv = coinbase::mpc::ecdsampc::refresh_ac(mpc_job, key.curve, sid, ac, quorum_party_set, key, new_key); + if (rv) return rv; + + return serialize_ac_key_blob(job, new_key, new_key_blob); +} + +error_t sign_ac(const job_mp_t& job, mem_t ac_key_blob, const access_structure_t& access_structure, mem_t msg, + party_idx_t sig_receiver, buf_t& sig_der) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (rv = coinbase::api::detail::validate_mem_arg_max_size(ac_key_blob, "ac_key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (rv = coinbase::api::detail::validate_mem_arg_max_size(msg, "msg", coinbase::api::detail::MAX_MESSAGE_DIGEST_SIZE)) + return rv; + if (sig_receiver < 0 || static_cast(sig_receiver) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid sig_receiver"); + + coinbase::mpc::ecdsampc::key_t ac_key; + rv = deserialize_ac_key_blob(ac_key_blob, ac_key); + if (rv) return rv; + + // Bind the key share to the local party identity in the job. + const std::string_view self_name_sv(job.party_names[static_cast(job.self)]); + if (ac_key.party_name != self_name_sv) return coinbase::error(E_BADARG, "job.self mismatch key blob"); + + // Full party set is the key's Qis key set. + std::vector all_party_names; + all_party_names.reserve(ac_key.Qis.size()); + for (const auto& kv : ac_key.Qis) all_party_names.emplace_back(kv.first); + + // Validate that the signing party set (`job.party_names`) is a subset of the key's party set. + coinbase::mpc::party_set_t _unused; + rv = coinbase::api::detail::to_internal_party_set(all_party_names, job.party_names, _unused); + if (rv) return rv; + + // Convert access structure to internal and validate it matches the key party set. + coinbase::crypto::ss::ac_owned_t ac; + rv = coinbase::api::detail::to_internal_access_structure(access_structure, all_party_names, ac_key.curve, ac); + if (rv) return rv; + + // Convert signing party list to internal set of names. + std::set quorum_names; + for (const auto& name : job.party_names) quorum_names.insert(std::string(name)); + + coinbase::mpc::ecdsampc::key_t additive_key; + rv = ac_key.to_additive_share(ac, quorum_names, additive_key); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + sig_der.free(); + return coinbase::mpc::ecdsampc::sign(mpc_job, additive_key, msg, sig_receiver, sig_der); +} + +error_t sign_additive(const job_mp_t& job, mem_t key_blob, mem_t msg, party_idx_t sig_receiver, buf_t& sig_der) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (rv = coinbase::api::detail::validate_mem_arg_max_size(msg, "msg", coinbase::api::detail::MAX_MESSAGE_DIGEST_SIZE)) + return rv; + if (sig_receiver < 0 || static_cast(sig_receiver) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid sig_receiver"); + + coinbase::mpc::ecdsampc::key_t key; + rv = deserialize_key_blob(job, key_blob, key); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + sig_der.free(); + return coinbase::mpc::ecdsampc::sign(mpc_job, key, msg, sig_receiver, sig_der); +} + +error_t get_public_key_compressed(mem_t key_blob, buf_t& pub_key) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + key_blob_v1_t blob; + error_t rv = coinbase::convert(blob, key_blob); + if (rv) return rv; + if (blob.version != key_blob_version_v1 && blob.version != ac_key_blob_version_v1) + return coinbase::error(E_FORMAT, "unsupported key blob version"); + + const auto cid = static_cast(blob.curve); + if (cid == curve_id::ed25519) return coinbase::error(E_FORMAT, "invalid key blob curve"); + const auto curve = to_internal_curve(cid); + if (!curve.valid()) return coinbase::error(E_FORMAT, "invalid key blob curve"); + + coinbase::crypto::ecc_point_t Q(curve); + if (rv = Q.from_bin(curve, blob.Q_compressed)) return coinbase::error(rv, "invalid key blob"); + + pub_key = Q.to_compressed_bin(); + return SUCCESS; +} + +error_t get_public_share_compressed(mem_t key_blob, buf_t& out_public_share_compressed) { + key_blob_v1_t blob; + curve_id _cid; + coinbase::crypto::ecurve_t _curve; + error_t rv = parse_key_blob_any_version(key_blob, blob, _cid, _curve); + if (rv) return rv; + return get_self_Qi_compressed_from_blob(blob, out_public_share_compressed); +} + +error_t detach_private_scalar(mem_t key_blob, buf_t& out_public_key_blob, buf_t& out_private_scalar_fixed) { + key_blob_v1_t blob; + curve_id _cid; + coinbase::crypto::ecurve_t curve; + error_t rv = parse_key_blob_any_version(key_blob, blob, _cid, curve); + if (rv) return rv; + + const coinbase::crypto::mod_t& q = curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + const int order_size = q.get_bin_size(); + if (order_size <= 0) return coinbase::error(E_GENERAL, "invalid curve order size"); + + out_private_scalar_fixed = blob.x_share.to_bin(order_size); + + // Redact private scalar share. + blob.x_share = 0; + out_public_key_blob = coinbase::convert(blob); + return SUCCESS; +} + +error_t attach_private_scalar(mem_t public_key_blob, mem_t private_scalar_fixed, mem_t public_share_compressed, + buf_t& out_key_blob) { + key_blob_v1_t blob; + curve_id _cid; + coinbase::crypto::ecurve_t curve; + error_t rv = parse_key_blob_any_version(public_key_blob, blob, _cid, curve); + if (rv) return rv; + + const coinbase::crypto::mod_t& q = curve.order(); + const int order_size = q.get_bin_size(); + if (order_size <= 0) return coinbase::error(E_GENERAL, "invalid curve order size"); + + if (const error_t rvm = coinbase::api::detail::validate_mem_arg(private_scalar_fixed, "private_scalar_fixed")) + return rvm; + if (private_scalar_fixed.size != order_size) return coinbase::error(E_BADARG, "private_scalar_fixed wrong size"); + if (const error_t rvp = coinbase::api::detail::validate_mem_arg(public_share_compressed, "public_share_compressed")) + return rvp; + + // Recover self-share public point (Qi_self) from the blob. + buf_t Qi_self_compressed; + rv = get_self_Qi_compressed_from_blob(blob, Qi_self_compressed); + if (rv) return rv; + + if (public_share_compressed != mem_t(Qi_self_compressed)) + return coinbase::error(E_BADARG, "public_share_compressed mismatch key blob"); + + coinbase::crypto::ecc_point_t Qi_self(curve); + if (rv = Qi_self.from_bin(curve, Qi_self_compressed)) return coinbase::error(rv, "invalid key blob"); + if (rv = curve.check(Qi_self)) return coinbase::error(rv, "invalid key blob"); + + // Interpret scalar and reduce modulo group order (matching PVE semantics). + coinbase::crypto::bn_t x = coinbase::crypto::bn_t::from_bin(private_scalar_fixed) % q; + if (!q.is_in_range(x)) return coinbase::error(E_FORMAT, "invalid private_scalar_fixed"); + + const auto& G = curve.generator(); + if (x * G != Qi_self) return coinbase::error(E_FORMAT, "x_share mismatch key blob"); + + blob.x_share = std::move(x); + out_key_blob = coinbase::convert(blob); + return SUCCESS; +} + +} // namespace coinbase::api::ecdsa_mp diff --git a/src/cbmpc/api/eddsa2pc.cpp b/src/cbmpc/api/eddsa2pc.cpp new file mode 100644 index 00000000..b3d591d2 --- /dev/null +++ b/src/cbmpc/api/eddsa2pc.cpp @@ -0,0 +1,224 @@ +#include +#include +#include + +#include "job_util.h" +#include "mem_util.h" + +namespace coinbase::api::eddsa_2p { + +namespace { + +constexpr uint32_t key_blob_version_v1 = 1; + +using coinbase::api::detail::to_internal_job; +using coinbase::api::detail::to_internal_party; +using coinbase::api::detail::validate_job_2p; + +struct key_blob_v1_t { + uint32_t version = key_blob_version_v1; + uint32_t role = 0; // 0=p1, 1=p2 + uint32_t curve = 0; // coinbase::api::curve_id + + buf_t Q_compressed; + coinbase::crypto::bn_t x_share; + + void convert(coinbase::converter_t& c) { c.convert(version, role, curve, Q_compressed, x_share); } +}; + +static error_t blob_to_key(const key_blob_v1_t& blob, coinbase::mpc::eddsa2pc::key_t& key) { + if (blob.role > 1) return coinbase::error(E_FORMAT, "invalid key blob role"); + if (static_cast(blob.curve) != curve_id::ed25519) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + + key.role = static_cast(static_cast(blob.role)); + key.curve = coinbase::crypto::curve_ed25519; + const coinbase::crypto::mod_t& q = key.curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + key.x_share = blob.x_share; + + error_t rv = key.Q.from_bin(key.curve, blob.Q_compressed); + if (rv) return coinbase::error(rv, "invalid key blob"); + if (key.curve.check(key.Q)) return coinbase::error(E_FORMAT, "invalid key blob"); + return SUCCESS; +} + +static error_t serialize_key_blob(const coinbase::mpc::eddsa2pc::key_t& key, buf_t& out) { + if (key.curve != coinbase::crypto::curve_ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + + key_blob_v1_t blob; + blob.role = static_cast(key.role); + blob.curve = static_cast(curve_id::ed25519); + blob.Q_compressed = key.Q.to_compressed_bin(); + blob.x_share = key.x_share; + out = coinbase::convert(blob); + return SUCCESS; +} + +static error_t deserialize_key_blob(mem_t in, coinbase::mpc::eddsa2pc::key_t& key) { + key_blob_v1_t blob; + const error_t rv = coinbase::convert(blob, in); + if (rv) return rv; + if (blob.version != key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported key blob version"); + return blob_to_key(blob, key); +} + +} // namespace + +error_t dkg(const coinbase::api::job_2p_t& job, curve_id curve, buf_t& key_blob) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (curve != curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + coinbase::mpc::eddsa2pc::key_t key; + buf_t sid; // unused by this API + const error_t rv = coinbase::mpc::eckey::key_share_2p_t::dkg(mpc_job, coinbase::crypto::curve_ed25519, key, sid); + if (rv) return rv; + + return serialize_key_blob(key, key_blob); +} + +error_t refresh(const coinbase::api::job_2p_t& job, mem_t key_blob, buf_t& new_key_blob) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::eddsa2pc::key_t key; + error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + + const auto self = to_internal_party(job.self); + if (key.role != self) return coinbase::error(E_BADARG, "job.self mismatch key blob role"); + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + coinbase::mpc::eddsa2pc::key_t new_key; + new_key_blob.free(); + rv = coinbase::mpc::eckey::key_share_2p_t::refresh(mpc_job, key, new_key); + if (rv) return rv; + + return serialize_key_blob(new_key, new_key_blob); +} + +error_t sign(const coinbase::api::job_2p_t& job, mem_t key_blob, mem_t msg, buf_t& sig) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(msg, "msg")) return rv; + coinbase::mpc::eddsa2pc::key_t key; + error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + + const auto self = to_internal_party(job.self); + if (key.role != self) return coinbase::error(E_BADARG, "job.self mismatch key blob role"); + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + sig.free(); + return coinbase::mpc::eddsa2pc::sign(mpc_job, key, msg, sig); +} + +error_t get_public_key_compressed(mem_t key_blob, buf_t& pub_key) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::eddsa2pc::key_t key; + const error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + pub_key = key.Q.to_compressed_bin(); + return SUCCESS; +} + +error_t get_public_share_compressed(mem_t key_blob, buf_t& out_public_share_compressed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::eddsa2pc::key_t key; + error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + + const auto curve = coinbase::crypto::curve_ed25519; + const coinbase::crypto::mod_t& q = curve.order(); + const auto& G = curve.generator(); + const coinbase::crypto::bn_t x = key.x_share % q; + out_public_share_compressed = (x * G).to_compressed_bin(); + return SUCCESS; +} + +error_t detach_private_scalar(mem_t key_blob, buf_t& out_public_key_blob, buf_t& out_private_scalar_fixed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::eddsa2pc::key_t key; + const error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + + const auto curve = coinbase::crypto::curve_ed25519; + const coinbase::crypto::mod_t& q = curve.order(); + const int order_size = q.get_bin_size(); + if (order_size <= 0) return coinbase::error(E_GENERAL, "invalid curve order size"); + + out_private_scalar_fixed = key.x_share.to_bin(order_size); + + // Produce a v1-format blob with an invalid (out-of-range) scalar share so it is + // rejected by sign/refresh APIs. + key_blob_v1_t pub; + pub.role = static_cast(key.role); + pub.curve = static_cast(curve_id::ed25519); + pub.Q_compressed = key.Q.to_compressed_bin(); + pub.x_share = q; // x_share == q is out of range + out_public_key_blob = coinbase::convert(pub); + return SUCCESS; +} + +error_t attach_private_scalar(mem_t public_key_blob, mem_t private_scalar_fixed, mem_t public_share_compressed, + buf_t& out_key_blob) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(public_key_blob, "public_key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + key_blob_v1_t pub; + error_t rv = coinbase::convert(pub, public_key_blob); + if (rv) return rv; + if (pub.version != key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (pub.role > 1) return coinbase::error(E_FORMAT, "invalid key blob role"); + if (static_cast(pub.curve) != curve_id::ed25519) return coinbase::error(E_FORMAT, "invalid key blob curve"); + if (pub.Q_compressed.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + + const auto curve = coinbase::crypto::curve_ed25519; + const coinbase::crypto::mod_t& q = curve.order(); + const int order_size = q.get_bin_size(); + if (order_size <= 0) return coinbase::error(E_GENERAL, "invalid curve order size"); + + if (const error_t rvm = coinbase::api::detail::validate_mem_arg(private_scalar_fixed, "private_scalar_fixed")) + return rvm; + if (private_scalar_fixed.size != order_size) return coinbase::error(E_BADARG, "private_scalar_fixed wrong size"); + + if (const error_t rvp = coinbase::api::detail::validate_mem_arg(public_share_compressed, "public_share_compressed")) + return rvp; + + coinbase::crypto::ecc_point_t Qi_self(curve); + if (rv = Qi_self.from_bin(curve, public_share_compressed)) + return coinbase::error(rv, "invalid public_share_compressed"); + if (rv = curve.check(Qi_self)) return coinbase::error(rv, "invalid public_share_compressed"); + if (!Qi_self.is_in_subgroup()) return coinbase::error(E_FORMAT, "invalid public_share_compressed"); + + const coinbase::crypto::bn_t x = coinbase::crypto::bn_t::from_bin(private_scalar_fixed) % q; + if (!q.is_in_range(x)) return coinbase::error(E_FORMAT, "invalid private_scalar_fixed"); + + const auto& G = curve.generator(); + if (x * G != Qi_self) return coinbase::error(E_FORMAT, "x_share mismatch key blob"); + + // Validate and normalize global public key encoding. + coinbase::crypto::ecc_point_t Q(curve); + if (rv = Q.from_bin(curve, pub.Q_compressed)) return coinbase::error(rv, "invalid key blob"); + if (rv = curve.check(Q)) return coinbase::error(rv, "invalid key blob"); + + pub.x_share = x; + pub.Q_compressed = Q.to_compressed_bin(); + out_key_blob = coinbase::convert(pub); + return SUCCESS; +} + +} // namespace coinbase::api::eddsa_2p diff --git a/src/cbmpc/api/eddsa_mp.cpp b/src/cbmpc/api/eddsa_mp.cpp new file mode 100644 index 00000000..0c75aa35 --- /dev/null +++ b/src/cbmpc/api/eddsa_mp.cpp @@ -0,0 +1,531 @@ +#include +#include + +#include +#include +#include + +#include "access_structure_util.h" +#include "job_util.h" +#include "mem_util.h" + +namespace coinbase::api::eddsa_mp { + +namespace { + +constexpr uint32_t key_blob_version_v1 = 1; +constexpr uint32_t ac_key_blob_version_v1 = 2; + +using coinbase::api::detail::to_internal_job; +using coinbase::api::detail::validate_job_mp; + +struct key_blob_v1_t { + uint32_t version = key_blob_version_v1; + uint32_t curve = 0; // coinbase::api::curve_id + + std::string party_name; // self identity (name-bound, not index-bound) + + buf_t Q_compressed; + std::map Qis_compressed; // name -> compressed Qi + + coinbase::crypto::bn_t x_share; + + void convert(coinbase::converter_t& c) { + c.convert(version, curve, party_name, Q_compressed, Qis_compressed, x_share); + } +}; + +static error_t extract_Q_from_key_blob(mem_t in, coinbase::crypto::ecc_point_t& Q) { + key_blob_v1_t blob; + error_t rv = coinbase::convert(blob, in); + if (rv) return rv; + if (blob.version != key_blob_version_v1 && blob.version != ac_key_blob_version_v1) + return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (static_cast(blob.curve) != curve_id::ed25519) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + if (blob.Q_compressed.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + const auto curve = coinbase::crypto::curve_ed25519; + rv = Q.from_bin(curve, blob.Q_compressed); + if (rv) return coinbase::error(rv, "invalid key blob"); + if (curve.check(Q)) return coinbase::error(E_FORMAT, "invalid key blob"); + return SUCCESS; +} + +static error_t serialize_key_blob_for_party_names(const std::vector& party_names, + const std::string& self_name, + const coinbase::mpc::schnorrmp::key_t& key, uint32_t version, + buf_t& out) { + if (key.curve != coinbase::crypto::curve_ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + + const std::string_view self_sv(self_name); + bool self_in_party_names = false; + for (const auto& name_view : party_names) { + if (name_view == self_sv) { + self_in_party_names = true; + break; + } + } + if (!self_in_party_names) return coinbase::error(E_BADARG, "self_name not in party_names"); + + if (key.party_name != self_name) return coinbase::error(E_BADARG, "job.self mismatch key"); + + key_blob_v1_t blob; + blob.version = version; + blob.curve = static_cast(curve_id::ed25519); + blob.party_name = key.party_name; + blob.Q_compressed = key.Q.to_compressed_bin(); + blob.x_share = key.x_share; + + for (const auto& name_view : party_names) { + const std::string name(name_view); + const auto it = key.Qis.find(name); + if (it == key.Qis.end()) return coinbase::error(E_FORMAT, "key missing Qi"); + blob.Qis_compressed[name] = it->second.to_compressed_bin(); + } + + out = coinbase::convert(blob); + return SUCCESS; +} + +static error_t serialize_key_blob(const coinbase::api::job_mp_t& job, const coinbase::mpc::schnorrmp::key_t& key, + buf_t& out) { + if (job.self < 0 || static_cast(job.self) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid job.self"); + const std::string self_name(job.party_names[static_cast(job.self)]); + return serialize_key_blob_for_party_names(job.party_names, self_name, key, key_blob_version_v1, out); +} + +static error_t serialize_ac_key_blob(const coinbase::api::job_mp_t& job, const coinbase::mpc::schnorrmp::key_t& key, + buf_t& out) { + if (job.self < 0 || static_cast(job.self) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid job.self"); + const std::string self_name(job.party_names[static_cast(job.self)]); + return serialize_key_blob_for_party_names(job.party_names, self_name, key, ac_key_blob_version_v1, out); +} + +static error_t deserialize_key_blob(const coinbase::api::job_mp_t& job, mem_t in, + coinbase::mpc::schnorrmp::key_t& key) { + error_t rv = UNINITIALIZED_ERROR; + + if (job.self < 0 || static_cast(job.self) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid job.self"); + const std::string self_name(job.party_names[static_cast(job.self)]); + + key_blob_v1_t blob; + if (rv = coinbase::convert(blob, in)) return rv; + if (blob.version != key_blob_version_v1) + return coinbase::error(E_FORMAT, "unsupported key blob version: " + std::to_string(blob.version)); + if (static_cast(blob.curve) != curve_id::ed25519) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + if (blob.party_name.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.party_name != self_name) return coinbase::error(E_BADARG, "job.self mismatch key blob"); + if (blob.Qis_compressed.size() != job.party_names.size()) return coinbase::error(E_BADARG, "invalid key blob"); + + // Ensure the party name set matches the job (order can differ). + for (const auto& name_view : job.party_names) { + const std::string name(name_view); + if (blob.Qis_compressed.find(name) == blob.Qis_compressed.end()) + return coinbase::error(E_BADARG, "job.party_names mismatch key blob"); + } + + const auto curve = coinbase::crypto::curve_ed25519; + const coinbase::crypto::mod_t& q = curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + + coinbase::crypto::ecc_point_t Q; + if (rv = Q.from_bin(curve, blob.Q_compressed)) return coinbase::error(rv, "invalid key blob"); + if (curve.check(Q)) return coinbase::error(E_FORMAT, "invalid key blob"); + + coinbase::crypto::ss::party_map_t Qis; + for (const auto& name_view : job.party_names) { + const std::string name(name_view); + const auto it = blob.Qis_compressed.find(name); + if (it == blob.Qis_compressed.end()) return coinbase::error(E_BADARG, "job.party_names mismatch key blob"); + + coinbase::crypto::ecc_point_t Qi; + if (rv = Qi.from_bin(curve, it->second)) return coinbase::error(rv, "invalid key blob"); + if (!Qi.is_in_subgroup()) return coinbase::error(E_FORMAT, "invalid key blob"); + Qis[name] = std::move(Qi); + } + + coinbase::crypto::ecc_point_t Q_sum = curve.infinity(); + for (const auto& kv : Qis) Q_sum += kv.second; + if (Q != Q_sum) return coinbase::error(E_FORMAT, "invalid key blob"); + + const auto& G = curve.generator(); + const auto it_self = Qis.find(blob.party_name); + if (it_self == Qis.end()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.x_share * G != it_self->second) return coinbase::error(E_FORMAT, "invalid key blob"); + + key.party_name = blob.party_name; + key.curve = curve; + key.x_share = blob.x_share; + key.Qis = std::move(Qis); + key.Q = std::move(Q); + return SUCCESS; +} + +static error_t deserialize_ac_key_blob(const coinbase::api::job_mp_t& job, mem_t in, + coinbase::mpc::schnorrmp::key_t& key) { + error_t rv = UNINITIALIZED_ERROR; + + if (job.self < 0 || static_cast(job.self) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid job.self"); + const std::string self_name(job.party_names[static_cast(job.self)]); + + key_blob_v1_t blob; + if (rv = coinbase::convert(blob, in)) return rv; + if (blob.version != ac_key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (static_cast(blob.curve) != curve_id::ed25519) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + if (blob.party_name.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.party_name != self_name) return coinbase::error(E_BADARG, "job.self mismatch key blob"); + if (blob.Qis_compressed.size() != job.party_names.size()) return coinbase::error(E_BADARG, "invalid key blob"); + + // Ensure the party name set matches the job (order can differ). + for (const auto& name_view : job.party_names) { + const std::string name(name_view); + if (blob.Qis_compressed.find(name) == blob.Qis_compressed.end()) + return coinbase::error(E_BADARG, "job.party_names mismatch key blob"); + } + + const auto curve = coinbase::crypto::curve_ed25519; + const coinbase::crypto::mod_t& q = curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + + coinbase::crypto::ecc_point_t Q; + if (rv = Q.from_bin(curve, blob.Q_compressed)) return coinbase::error(rv, "invalid key blob"); + if (curve.check(Q)) return coinbase::error(E_FORMAT, "invalid key blob"); + + coinbase::crypto::ss::party_map_t Qis; + for (const auto& name_view : job.party_names) { + const std::string name(name_view); + const auto it = blob.Qis_compressed.find(name); + if (it == blob.Qis_compressed.end()) return coinbase::error(E_BADARG, "job.party_names mismatch key blob"); + + coinbase::crypto::ecc_point_t Qi; + if (rv = Qi.from_bin(curve, it->second)) return coinbase::error(rv, "invalid key blob"); + if (!Qi.is_in_subgroup()) return coinbase::error(E_FORMAT, "invalid key blob"); + Qis[name] = std::move(Qi); + } + + const auto& G = curve.generator(); + const auto it_self = Qis.find(blob.party_name); + if (it_self == Qis.end()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.x_share * G != it_self->second) return coinbase::error(E_FORMAT, "invalid key blob"); + + key.party_name = blob.party_name; + key.curve = curve; + key.x_share = blob.x_share; + key.Qis = std::move(Qis); + key.Q = std::move(Q); + return SUCCESS; +} + +static error_t deserialize_ac_key_blob(mem_t in, coinbase::mpc::schnorrmp::key_t& key) { + error_t rv = UNINITIALIZED_ERROR; + + key_blob_v1_t blob; + if (rv = coinbase::convert(blob, in)) return rv; + if (blob.version != ac_key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (static_cast(blob.curve) != curve_id::ed25519) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + if (blob.party_name.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.Qis_compressed.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + + const auto curve = coinbase::crypto::curve_ed25519; + const coinbase::crypto::mod_t& q = curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + + coinbase::crypto::ecc_point_t Q; + if (rv = Q.from_bin(curve, blob.Q_compressed)) return coinbase::error(rv, "invalid key blob"); + if (curve.check(Q)) return coinbase::error(E_FORMAT, "invalid key blob"); + + coinbase::crypto::ss::party_map_t Qis; + for (const auto& kv : blob.Qis_compressed) { + coinbase::crypto::ecc_point_t Qi; + if (rv = Qi.from_bin(curve, kv.second)) return coinbase::error(rv, "invalid key blob"); + if (!Qi.is_in_subgroup()) return coinbase::error(E_FORMAT, "invalid key blob"); + Qis[kv.first] = std::move(Qi); + } + + const auto& G = curve.generator(); + const auto it_self = Qis.find(blob.party_name); + if (it_self == Qis.end()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.x_share * G != it_self->second) return coinbase::error(E_FORMAT, "invalid key blob"); + + key.party_name = blob.party_name; + key.curve = curve; + key.x_share = blob.x_share; + key.Qis = std::move(Qis); + key.Q = std::move(Q); + return SUCCESS; +} + +} // namespace + +error_t dkg_additive(const job_mp_t& job, curve_id curve, buf_t& key_blob, buf_t& sid) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (curve != curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + coinbase::mpc::schnorrmp::key_t key; + sid.free(); + key_blob.free(); + rv = coinbase::mpc::schnorrmp::dkg(mpc_job, coinbase::crypto::curve_ed25519, key, sid); + if (rv) return rv; + + return serialize_key_blob(job, key, key_blob); +} + +error_t dkg_ac(const job_mp_t& job, curve_id curve, buf_t& sid, const access_structure_t& access_structure, + const std::vector& quorum_party_names, buf_t& key_blob) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (curve != curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::crypto::ss::ac_owned_t ac; + rv = coinbase::api::detail::to_internal_access_structure(access_structure, job.party_names, + coinbase::crypto::curve_ed25519, ac); + if (rv) return rv; + + coinbase::mpc::party_set_t quorum_party_set; + rv = coinbase::api::detail::to_internal_party_set(job.party_names, quorum_party_names, quorum_party_set); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + coinbase::mpc::schnorrmp::key_t key; + key_blob.free(); + rv = coinbase::mpc::schnorrmp::dkg_ac(mpc_job, coinbase::crypto::curve_ed25519, sid, ac, quorum_party_set, key); + if (rv) return rv; + + return serialize_ac_key_blob(job, key, key_blob); +} + +error_t refresh_additive(const job_mp_t& job, buf_t& sid, mem_t key_blob, buf_t& new_key_blob) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + + coinbase::mpc::schnorrmp::key_t key; + rv = deserialize_key_blob(job, key_blob, key); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + coinbase::mpc::schnorrmp::key_t new_key; + new_key_blob.free(); + rv = coinbase::mpc::schnorrmp::refresh(mpc_job, sid, key, new_key); + if (rv) return rv; + + return serialize_key_blob(job, new_key, new_key_blob); +} + +error_t refresh_ac(const job_mp_t& job, buf_t& sid, mem_t key_blob, const access_structure_t& access_structure, + const std::vector& quorum_party_names, buf_t& new_key_blob) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + + coinbase::mpc::schnorrmp::key_t key; + rv = deserialize_ac_key_blob(job, key_blob, key); + if (rv) return rv; + + coinbase::crypto::ss::ac_owned_t ac; + rv = coinbase::api::detail::to_internal_access_structure(access_structure, job.party_names, key.curve, ac); + if (rv) return rv; + + coinbase::mpc::party_set_t quorum_party_set; + rv = coinbase::api::detail::to_internal_party_set(job.party_names, quorum_party_names, quorum_party_set); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + coinbase::mpc::schnorrmp::key_t new_key; + new_key_blob.free(); + rv = coinbase::mpc::schnorrmp::refresh_ac(mpc_job, key.curve, sid, ac, quorum_party_set, key, new_key); + if (rv) return rv; + + return serialize_ac_key_blob(job, new_key, new_key_blob); +} + +error_t sign_ac(const job_mp_t& job, mem_t ac_key_blob, const access_structure_t& access_structure, mem_t msg, + party_idx_t sig_receiver, buf_t& sig) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (rv = coinbase::api::detail::validate_mem_arg_max_size(ac_key_blob, "ac_key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (rv = coinbase::api::detail::validate_mem_arg(msg, "msg")) return rv; + if (sig_receiver < 0 || static_cast(sig_receiver) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid sig_receiver"); + + coinbase::mpc::schnorrmp::key_t ac_key; + rv = deserialize_ac_key_blob(ac_key_blob, ac_key); + if (rv) return rv; + + // Bind the key share to the local party identity in the job. + const std::string_view self_name_sv(job.party_names[static_cast(job.self)]); + if (ac_key.party_name != self_name_sv) return coinbase::error(E_BADARG, "job.self mismatch key blob"); + + // Full party set is the key's Qis key set. + std::vector all_party_names; + all_party_names.reserve(ac_key.Qis.size()); + for (const auto& kv : ac_key.Qis) all_party_names.emplace_back(kv.first); + + // Validate that the signing party set (`job.party_names`) is a subset of the key's party set. + coinbase::mpc::party_set_t _unused; + rv = coinbase::api::detail::to_internal_party_set(all_party_names, job.party_names, _unused); + if (rv) return rv; + + // Convert access structure to internal and validate it matches the key party set. + coinbase::crypto::ss::ac_owned_t ac; + rv = coinbase::api::detail::to_internal_access_structure(access_structure, all_party_names, ac_key.curve, ac); + if (rv) return rv; + + // Convert signing party list to internal set of names. + std::set quorum_names; + for (const auto& name : job.party_names) quorum_names.insert(std::string(name)); + + coinbase::mpc::schnorrmp::key_t additive_key; + rv = ac_key.to_additive_share(ac, quorum_names, additive_key); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + sig.free(); + return coinbase::mpc::eddsampc::sign(mpc_job, additive_key, msg, sig_receiver, sig); +} + +error_t sign_additive(const job_mp_t& job, mem_t key_blob, mem_t msg, party_idx_t sig_receiver, buf_t& sig) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (rv = coinbase::api::detail::validate_mem_arg(msg, "msg")) return rv; + if (sig_receiver < 0 || static_cast(sig_receiver) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid sig_receiver"); + + coinbase::mpc::schnorrmp::key_t key; + rv = deserialize_key_blob(job, key_blob, key); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + sig.free(); + return coinbase::mpc::eddsampc::sign(mpc_job, key, msg, sig_receiver, sig); +} + +error_t get_public_key_compressed(mem_t key_blob, buf_t& pub_key) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::crypto::ecc_point_t Q(coinbase::crypto::curve_ed25519); + const error_t rv = extract_Q_from_key_blob(key_blob, Q); + if (rv) return rv; + pub_key = Q.to_compressed_bin(); + return SUCCESS; +} + +error_t get_public_share_compressed(mem_t key_blob, buf_t& out_public_share_compressed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + key_blob_v1_t blob; + error_t rv = coinbase::convert(blob, key_blob); + if (rv) return rv; + if (blob.version != key_blob_version_v1 && blob.version != ac_key_blob_version_v1) + return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (static_cast(blob.curve) != curve_id::ed25519) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + if (blob.party_name.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + + const auto it = blob.Qis_compressed.find(blob.party_name); + if (it == blob.Qis_compressed.end()) return coinbase::error(E_FORMAT, "key blob missing self Qi"); + out_public_share_compressed = it->second; + return SUCCESS; +} + +error_t detach_private_scalar(mem_t key_blob, buf_t& out_public_key_blob, buf_t& out_private_scalar_fixed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + key_blob_v1_t blob; + error_t rv = coinbase::convert(blob, key_blob); + if (rv) return rv; + if (blob.version != key_blob_version_v1 && blob.version != ac_key_blob_version_v1) + return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (static_cast(blob.curve) != curve_id::ed25519) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + + const auto curve = coinbase::crypto::curve_ed25519; + const coinbase::crypto::mod_t& q = curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + const int order_size = q.get_bin_size(); + if (order_size <= 0) return coinbase::error(E_GENERAL, "invalid curve order size"); + + out_private_scalar_fixed = blob.x_share.to_bin(order_size); + + // Wipe private scalar share. + blob.x_share = 0; + out_public_key_blob = coinbase::convert(blob); + return SUCCESS; +} + +error_t attach_private_scalar(mem_t public_key_blob, mem_t private_scalar_fixed, mem_t public_share_compressed, + buf_t& out_key_blob) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(public_key_blob, "public_key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + key_blob_v1_t blob; + error_t rv = coinbase::convert(blob, public_key_blob); + if (rv) return rv; + if (blob.version != key_blob_version_v1 && blob.version != ac_key_blob_version_v1) + return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (static_cast(blob.curve) != curve_id::ed25519) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + if (blob.party_name.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + + const auto curve = coinbase::crypto::curve_ed25519; + const coinbase::crypto::mod_t& q = curve.order(); + const int order_size = q.get_bin_size(); + if (order_size <= 0) return coinbase::error(E_GENERAL, "invalid curve order size"); + + if (const error_t rvm = coinbase::api::detail::validate_mem_arg(private_scalar_fixed, "private_scalar_fixed")) + return rvm; + if (private_scalar_fixed.size != order_size) return coinbase::error(E_BADARG, "private_scalar_fixed wrong size"); + if (const error_t rvp = coinbase::api::detail::validate_mem_arg(public_share_compressed, "public_share_compressed")) + return rvp; + + const auto it = blob.Qis_compressed.find(blob.party_name); + if (it == blob.Qis_compressed.end()) return coinbase::error(E_FORMAT, "key blob missing self Qi"); + const buf_t& Qi_self_compressed = it->second; + + if (public_share_compressed != mem_t(Qi_self_compressed)) + return coinbase::error(E_BADARG, "public_share_compressed mismatch key blob"); + + coinbase::crypto::ecc_point_t Qi_self(curve); + if (rv = Qi_self.from_bin(curve, Qi_self_compressed)) return coinbase::error(rv, "invalid key blob"); + if (rv = curve.check(Qi_self)) return coinbase::error(rv, "invalid key blob"); + if (!Qi_self.is_in_subgroup()) return coinbase::error(E_FORMAT, "invalid key blob"); + + coinbase::crypto::bn_t x = coinbase::crypto::bn_t::from_bin(private_scalar_fixed) % q; + if (!q.is_in_range(x)) return coinbase::error(E_FORMAT, "invalid private_scalar_fixed"); + + const auto& G = curve.generator(); + if (x * G != Qi_self) return coinbase::error(E_FORMAT, "x_share mismatch key blob"); + + blob.x_share = std::move(x); + out_key_blob = coinbase::convert(blob); + return SUCCESS; +} + +} // namespace coinbase::api::eddsa_mp diff --git a/src/cbmpc/api/hd_keyset_ecdsa_2p.cpp b/src/cbmpc/api/hd_keyset_ecdsa_2p.cpp new file mode 100644 index 00000000..e7f0f6e5 --- /dev/null +++ b/src/cbmpc/api/hd_keyset_ecdsa_2p.cpp @@ -0,0 +1,226 @@ +#include +#include +#include +#include +#include + +#include "curve_util.h" +#include "hd_keyset_util.h" +#include "job_util.h" +#include "mem_util.h" + +namespace coinbase::api::hd_keyset_ecdsa_2p { + +namespace { + +constexpr uint32_t keyset_blob_version_v1 = 1; + +using coinbase::api::detail::from_internal_curve; +using coinbase::api::detail::to_internal_bip32_path; +using coinbase::api::detail::to_internal_curve; +using coinbase::api::detail::to_internal_job; +using coinbase::api::detail::to_internal_party; +using coinbase::api::detail::validate_job_2p; +using coinbase::api::detail::validate_no_duplicate_bip32_paths; + +// Mirror of the `coinbase::api::ecdsa_2p` key blob format (see `src/cbmpc/api/ecdsa2pc.cpp`). +constexpr uint32_t ecdsa2pc_key_blob_version_v1 = 1; +struct ecdsa2pc_key_blob_v1_t { + uint32_t version = ecdsa2pc_key_blob_version_v1; + uint32_t role = 0; // 0=p1, 1=p2 + uint32_t curve = 0; // coinbase::api::curve_id + + buf_t Q_compressed; + coinbase::crypto::bn_t x_share; + coinbase::crypto::bn_t c_key; + coinbase::crypto::paillier_t paillier; + + void convert(coinbase::converter_t& c) { c.convert(version, role, curve, Q_compressed, x_share, c_key, paillier); } +}; + +static error_t serialize_ecdsa2pc_key_blob(const coinbase::mpc::ecdsa2pc::key_t& key, buf_t& out) { + curve_id cid; + if (!from_internal_curve(key.curve, cid)) return coinbase::error(E_BADARG, "unsupported curve"); + if (cid == curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + + ecdsa2pc_key_blob_v1_t blob; + blob.role = static_cast(key.role); + blob.curve = static_cast(cid); + blob.Q_compressed = key.Q.to_compressed_bin(); + blob.x_share = key.x_share; + blob.c_key = key.c_key; + blob.paillier = key.paillier; + out = coinbase::convert(blob); + return SUCCESS; +} + +struct keyset_blob_v1_t { + uint32_t version = keyset_blob_version_v1; + uint32_t role = 0; // 0=p1, 1=p2 + uint32_t curve = 0; // coinbase::api::curve_id + + buf_t root_Q_compressed; + buf_t root_K_compressed; + coinbase::crypto::bn_t x_share; + coinbase::crypto::bn_t k_share; + coinbase::crypto::paillier_t paillier; + coinbase::crypto::bn_t c_key; + + void convert(coinbase::converter_t& c) { + c.convert(version, role, curve, root_Q_compressed, root_K_compressed, x_share, k_share, paillier, c_key); + } +}; + +static error_t blob_to_keyset(const keyset_blob_v1_t& blob, coinbase::mpc::key_share_ecdsa_hdmpc_2p_t& keyset) { + if (blob.role > 1) return coinbase::error(E_FORMAT, "invalid keyset blob role"); + + const auto cid = static_cast(blob.curve); + if (cid == curve_id::ed25519) return coinbase::error(E_FORMAT, "invalid keyset blob curve"); + const auto icurve = to_internal_curve(cid); + if (!icurve.valid()) return coinbase::error(E_FORMAT, "invalid keyset blob curve"); + + keyset.party_index = static_cast(blob.role); + keyset.curve = icurve; + + const coinbase::crypto::mod_t& q = keyset.curve.order(); + if (!q.is_in_range(blob.k_share)) return coinbase::error(E_FORMAT, "invalid keyset blob"); + + keyset.root.x_share = blob.x_share; + keyset.root.k_share = blob.k_share; + keyset.paillier = blob.paillier; + keyset.c_key = blob.c_key; + + // For ECDSA-2PC, party p1 maintains `x_share` as an integer representative compatible with Paillier plaintext + // operations (it is intentionally not reduced modulo q after refresh). Enforce only that it fits in Z_N. + const coinbase::crypto::mod_t& N = keyset.paillier.get_N(); + if (!N.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid keyset blob"); + + error_t rv = keyset.root.Q.from_bin(icurve, blob.root_Q_compressed); + if (rv) return rv; + return keyset.root.K.from_bin(icurve, blob.root_K_compressed); +} + +static error_t deserialize_keyset_blob(mem_t in, coinbase::mpc::key_share_ecdsa_hdmpc_2p_t& out) { + keyset_blob_v1_t blob; + error_t rv = coinbase::convert(blob, in); + if (rv) return rv; + if (blob.version != keyset_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported keyset blob version"); + return blob_to_keyset(blob, out); +} + +static error_t serialize_keyset_blob(const coinbase::mpc::key_share_ecdsa_hdmpc_2p_t& keyset, buf_t& out) { + curve_id cid; + if (!from_internal_curve(keyset.curve, cid)) return coinbase::error(E_BADARG, "unsupported curve"); + if (cid == curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + + keyset_blob_v1_t blob; + blob.role = static_cast(keyset.party_index); + blob.curve = static_cast(cid); + blob.root_Q_compressed = keyset.root.Q.to_compressed_bin(); + blob.root_K_compressed = keyset.root.K.to_compressed_bin(); + blob.x_share = keyset.root.x_share; + blob.k_share = keyset.root.k_share; + blob.paillier = keyset.paillier; + blob.c_key = keyset.c_key; + out = coinbase::convert(blob); + return SUCCESS; +} + +} // namespace + +error_t dkg(const coinbase::api::job_2p_t& job, curve_id curve, buf_t& keyset_blob) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (curve == curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + coinbase::mpc::key_share_ecdsa_hdmpc_2p_t keyset; + const error_t rv = coinbase::mpc::key_share_ecdsa_hdmpc_2p_t::dkg(mpc_job, icurve, keyset); + if (rv) return rv; + + return serialize_keyset_blob(keyset, keyset_blob); +} + +error_t refresh(const coinbase::api::job_2p_t& job, mem_t keyset_blob, buf_t& new_keyset_blob) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(keyset_blob, "keyset_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::key_share_ecdsa_hdmpc_2p_t keyset; + error_t rv = deserialize_keyset_blob(keyset_blob, keyset); + if (rv) return rv; + + const auto self = to_internal_party(job.self); + if (static_cast(keyset.party_index) != static_cast(self)) + return coinbase::error(E_BADARG, "job.self mismatch keyset blob role"); + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + coinbase::mpc::key_share_ecdsa_hdmpc_2p_t new_keyset; + rv = coinbase::mpc::key_share_ecdsa_hdmpc_2p_t::refresh(mpc_job, keyset, new_keyset); + if (rv) return rv; + + return serialize_keyset_blob(new_keyset, new_keyset_blob); +} + +error_t derive_ecdsa_2p_keys(const coinbase::api::job_2p_t& job, mem_t keyset_blob, const bip32_path_t& hardened_path, + const std::vector& non_hardened_paths, buf_t& sid, + std::vector& out_ecdsa_2p_key_blobs) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(keyset_blob, "keyset_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::key_share_ecdsa_hdmpc_2p_t keyset; + error_t rv = deserialize_keyset_blob(keyset_blob, keyset); + if (rv) return rv; + + const auto self = to_internal_party(job.self); + if (static_cast(keyset.party_index) != static_cast(self)) + return coinbase::error(E_BADARG, "job.self mismatch keyset blob role"); + + rv = validate_no_duplicate_bip32_paths(non_hardened_paths); + if (rv) return rv; + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + const coinbase::mpc::bip32_path_t hardened_path_internal = to_internal_bip32_path(hardened_path); + std::vector non_hardened_paths_internal; + non_hardened_paths_internal.reserve(non_hardened_paths.size()); + for (const auto& p : non_hardened_paths) non_hardened_paths_internal.push_back(to_internal_bip32_path(p)); + + std::vector derived_keys(non_hardened_paths.size()); + rv = coinbase::mpc::key_share_ecdsa_hdmpc_2p_t::derive_keys(mpc_job, keyset, hardened_path_internal, + non_hardened_paths_internal, sid, derived_keys); + if (rv) { + out_ecdsa_2p_key_blobs.clear(); + return rv; + } + + std::vector blobs; + blobs.resize(derived_keys.size()); + for (size_t i = 0; i < derived_keys.size(); i++) { + rv = serialize_ecdsa2pc_key_blob(derived_keys[i], blobs[i]); + if (rv) { + out_ecdsa_2p_key_blobs.clear(); + return rv; + } + } + + out_ecdsa_2p_key_blobs = std::move(blobs); + return SUCCESS; +} + +error_t extract_root_public_key_compressed(mem_t keyset_blob, buf_t& out_Q_compressed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(keyset_blob, "keyset_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::key_share_ecdsa_hdmpc_2p_t keyset; + const error_t rv = deserialize_keyset_blob(keyset_blob, keyset); + if (rv) return rv; + out_Q_compressed = keyset.root.Q.to_compressed_bin(); + return SUCCESS; +} + +} // namespace coinbase::api::hd_keyset_ecdsa_2p diff --git a/src/cbmpc/api/hd_keyset_eddsa_2p.cpp b/src/cbmpc/api/hd_keyset_eddsa_2p.cpp new file mode 100644 index 00000000..2095f3bb --- /dev/null +++ b/src/cbmpc/api/hd_keyset_eddsa_2p.cpp @@ -0,0 +1,201 @@ +#include +#include +#include +#include + +#include "hd_keyset_util.h" +#include "job_util.h" +#include "mem_util.h" + +namespace coinbase::api::hd_keyset_eddsa_2p { + +namespace { + +constexpr uint32_t keyset_blob_version_v1 = 1; + +using coinbase::api::detail::to_internal_bip32_path; +using coinbase::api::detail::to_internal_job; +using coinbase::api::detail::to_internal_party; +using coinbase::api::detail::validate_job_2p; +using coinbase::api::detail::validate_no_duplicate_bip32_paths; + +// Mirror of the `coinbase::api::eddsa_2p` key blob format (see `src/cbmpc/api/eddsa2pc.cpp`). +constexpr uint32_t eddsa2pc_key_blob_version_v1 = 1; +struct eddsa2pc_key_blob_v1_t { + uint32_t version = eddsa2pc_key_blob_version_v1; + uint32_t role = 0; // 0=p1, 1=p2 + uint32_t curve = 0; // coinbase::api::curve_id + + buf_t Q_compressed; + coinbase::crypto::bn_t x_share; + + void convert(coinbase::converter_t& c) { c.convert(version, role, curve, Q_compressed, x_share); } +}; + +static error_t serialize_eddsa2pc_key_blob(const coinbase::mpc::eddsa2pc::key_t& key, buf_t& out) { + if (key.curve != coinbase::crypto::curve_ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + + eddsa2pc_key_blob_v1_t blob; + blob.role = static_cast(key.role); + blob.curve = static_cast(curve_id::ed25519); + blob.Q_compressed = key.Q.to_compressed_bin(); + blob.x_share = key.x_share; + + out = coinbase::convert(blob); + return SUCCESS; +} + +struct keyset_blob_v1_t { + uint32_t version = keyset_blob_version_v1; + uint32_t role = 0; // 0=p1, 1=p2 + uint32_t curve = 0; // coinbase::api::curve_id + + buf_t root_Q_compressed; + buf_t root_K_compressed; + coinbase::crypto::bn_t x_share; + coinbase::crypto::bn_t k_share; + + void convert(coinbase::converter_t& c) { + c.convert(version, role, curve, root_Q_compressed, root_K_compressed, x_share, k_share); + } +}; + +static error_t blob_to_keyset(const keyset_blob_v1_t& blob, coinbase::mpc::key_share_eddsa_hdmpc_2p_t& keyset) { + if (blob.role > 1) return coinbase::error(E_FORMAT, "invalid keyset blob role"); + if (static_cast(blob.curve) != curve_id::ed25519) + return coinbase::error(E_FORMAT, "invalid keyset blob curve"); + + keyset.party_index = static_cast(blob.role); + keyset.curve = coinbase::crypto::curve_ed25519; + + const coinbase::crypto::mod_t& q = keyset.curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid keyset blob"); + if (!q.is_in_range(blob.k_share)) return coinbase::error(E_FORMAT, "invalid keyset blob"); + + keyset.root.x_share = blob.x_share; + keyset.root.k_share = blob.k_share; + + error_t rv = keyset.root.Q.from_bin(keyset.curve, blob.root_Q_compressed); + if (rv) return rv; + return keyset.root.K.from_bin(keyset.curve, blob.root_K_compressed); +} + +static error_t deserialize_keyset_blob(mem_t in, coinbase::mpc::key_share_eddsa_hdmpc_2p_t& out) { + keyset_blob_v1_t blob; + error_t rv = coinbase::convert(blob, in); + if (rv) return rv; + if (blob.version != keyset_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported keyset blob version"); + return blob_to_keyset(blob, out); +} + +static error_t serialize_keyset_blob(const coinbase::mpc::key_share_eddsa_hdmpc_2p_t& keyset, buf_t& out) { + if (keyset.curve != coinbase::crypto::curve_ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + + keyset_blob_v1_t blob; + blob.role = static_cast(keyset.party_index); + blob.curve = static_cast(curve_id::ed25519); + blob.root_Q_compressed = keyset.root.Q.to_compressed_bin(); + blob.root_K_compressed = keyset.root.K.to_compressed_bin(); + blob.x_share = keyset.root.x_share; + blob.k_share = keyset.root.k_share; + + out = coinbase::convert(blob); + return SUCCESS; +} + +} // namespace + +error_t dkg(const coinbase::api::job_2p_t& job, curve_id curve, buf_t& keyset_blob) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (curve != curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + coinbase::mpc::key_share_eddsa_hdmpc_2p_t keyset; + const error_t rv = coinbase::mpc::key_share_eddsa_hdmpc_2p_t::dkg(mpc_job, coinbase::crypto::curve_ed25519, keyset); + if (rv) return rv; + + return serialize_keyset_blob(keyset, keyset_blob); +} + +error_t refresh(const coinbase::api::job_2p_t& job, mem_t keyset_blob, buf_t& new_keyset_blob) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(keyset_blob, "keyset_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::key_share_eddsa_hdmpc_2p_t keyset; + error_t rv = deserialize_keyset_blob(keyset_blob, keyset); + if (rv) return rv; + + const auto self = to_internal_party(job.self); + if (static_cast(keyset.party_index) != static_cast(self)) + return coinbase::error(E_BADARG, "job.self mismatch keyset blob role"); + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + coinbase::mpc::key_share_eddsa_hdmpc_2p_t new_keyset; + rv = coinbase::mpc::key_share_eddsa_hdmpc_2p_t::refresh(mpc_job, keyset, new_keyset); + if (rv) return rv; + + return serialize_keyset_blob(new_keyset, new_keyset_blob); +} + +error_t derive_eddsa_2p_keys(const coinbase::api::job_2p_t& job, mem_t keyset_blob, const bip32_path_t& hardened_path, + const std::vector& non_hardened_paths, buf_t& sid, + std::vector& out_eddsa_2p_key_blobs) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(keyset_blob, "keyset_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::key_share_eddsa_hdmpc_2p_t keyset; + error_t rv = deserialize_keyset_blob(keyset_blob, keyset); + if (rv) return rv; + + const auto self = to_internal_party(job.self); + if (static_cast(keyset.party_index) != static_cast(self)) + return coinbase::error(E_BADARG, "job.self mismatch keyset blob role"); + + rv = validate_no_duplicate_bip32_paths(non_hardened_paths); + if (rv) return rv; + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + const coinbase::mpc::bip32_path_t hardened_path_internal = to_internal_bip32_path(hardened_path); + std::vector non_hardened_paths_internal; + non_hardened_paths_internal.reserve(non_hardened_paths.size()); + for (const auto& p : non_hardened_paths) non_hardened_paths_internal.push_back(to_internal_bip32_path(p)); + + std::vector derived_keys(non_hardened_paths.size()); + rv = coinbase::mpc::key_share_eddsa_hdmpc_2p_t::derive_keys(mpc_job, keyset, hardened_path_internal, + non_hardened_paths_internal, sid, derived_keys); + if (rv) { + out_eddsa_2p_key_blobs.clear(); + return rv; + } + + std::vector blobs; + blobs.resize(derived_keys.size()); + for (size_t i = 0; i < derived_keys.size(); i++) { + rv = serialize_eddsa2pc_key_blob(derived_keys[i], blobs[i]); + if (rv) { + out_eddsa_2p_key_blobs.clear(); + return rv; + } + } + + out_eddsa_2p_key_blobs = std::move(blobs); + return SUCCESS; +} + +error_t extract_root_public_key_compressed(mem_t keyset_blob, buf_t& out_Q_compressed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(keyset_blob, "keyset_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::key_share_eddsa_hdmpc_2p_t keyset; + const error_t rv = deserialize_keyset_blob(keyset_blob, keyset); + if (rv) return rv; + out_Q_compressed = keyset.root.Q.to_compressed_bin(); + return SUCCESS; +} + +} // namespace coinbase::api::hd_keyset_eddsa_2p diff --git a/src/cbmpc/api/hd_keyset_util.h b/src/cbmpc/api/hd_keyset_util.h new file mode 100644 index 00000000..94569f6a --- /dev/null +++ b/src/cbmpc/api/hd_keyset_util.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +#include +#include + +namespace coinbase::api::detail { + +// Convert a public bip32 path type (must have `.indices`) to internal representation. +template +inline coinbase::mpc::bip32_path_t to_internal_bip32_path(const PublicPathT& in) { + coinbase::mpc::bip32_path_t out; + for (const uint32_t idx : in.indices) out.append(idx); + return out; +} + +// Validate that a list of public bip32 paths contains no duplicates. +// The public path type must have `.indices` (vector). +template +inline error_t validate_no_duplicate_bip32_paths(const std::vector& paths) { + // O(n^2) is fine here: typical derivation batches are small (caller-controlled). + for (size_t i = 0; i < paths.size(); i++) { + for (size_t j = i + 1; j < paths.size(); j++) { + if (paths[i].indices == paths[j].indices) return coinbase::error(E_BADARG, "duplicate non_hardened_paths"); + } + } + return SUCCESS; +} + +} // namespace coinbase::api::detail diff --git a/src/cbmpc/api/job_util.h b/src/cbmpc/api/job_util.h new file mode 100644 index 00000000..903a2d6a --- /dev/null +++ b/src/cbmpc/api/job_util.h @@ -0,0 +1,65 @@ +#pragma once + +#include +#include + +#include +#include + +namespace coinbase::api::detail { + +// Basic validation for the public 2PC job view. +// +// Shared across all 2-party public API wrappers. +inline error_t validate_job_2p(const coinbase::api::job_2p_t& job) { + if (job.self != party_2p_t::p1 && job.self != party_2p_t::p2) return coinbase::error(E_BADARG, "invalid job.self"); + if (job.p1_name.empty()) return coinbase::error(E_BADARG, "p1_name must be non-empty"); + if (job.p2_name.empty()) return coinbase::error(E_BADARG, "p2_name must be non-empty"); + if (job.p1_name == job.p2_name) return coinbase::error(E_BADARG, "party names must be distinct"); + return SUCCESS; +} + +// Map public 2PC role enum to internal protocol role enum. +inline coinbase::mpc::party_t to_internal_party(party_2p_t self) { return static_cast(self); } + +// Convert the public 2PC job view to the internal protocol job. +// +// Note: the returned internal job copies party names, but does *not* take +// ownership of the transport. It stores a non-owning pointer/reference to +// `job.transport`, which must outlive any protocol call using the returned job. +inline coinbase::mpc::job_2p_t to_internal_job(const coinbase::api::job_2p_t& job) { + return coinbase::mpc::job_2p_t(to_internal_party(job.self), std::string(job.p1_name), std::string(job.p2_name), + job.transport); +} + +// Convert the public MP job view to the internal protocol job. +// +// Note: the returned internal job copies party names, but does *not* take +// ownership of the transport. It stores a non-owning pointer/reference to +// `job.transport`, which must outlive any protocol call using the returned job. +inline coinbase::mpc::job_mp_t to_internal_job(const coinbase::api::job_mp_t& job) { + std::vector names; + names.reserve(job.party_names.size()); + for (const auto& name : job.party_names) names.emplace_back(name); + return coinbase::mpc::job_mp_t(job.self, std::move(names), job.transport); +} + +// Basic validation for the public MP job view. +// +// Shared across all multi-party public API wrappers. +inline error_t validate_job_mp(const coinbase::api::job_mp_t& job) { + const size_t n = job.party_names.size(); + if (n < 2) return coinbase::error(E_BADARG, "job must contain at least 2 parties"); + if (n > 64) return coinbase::error(E_RANGE, "at most 64 parties are supported"); + if (job.self < 0 || static_cast(job.self) >= n) return coinbase::error(E_BADARG, "invalid job.self"); + + std::unordered_set names; + names.reserve(n); + for (const auto& name : job.party_names) { + if (name.empty()) return coinbase::error(E_BADARG, "party name must be non-empty"); + if (!names.insert(name).second) return coinbase::error(E_BADARG, "duplicate party name"); + } + return SUCCESS; +} + +} // namespace coinbase::api::detail diff --git a/src/cbmpc/api/mem_util.h b/src/cbmpc/api/mem_util.h new file mode 100644 index 00000000..3fd08c9b --- /dev/null +++ b/src/cbmpc/api/mem_util.h @@ -0,0 +1,55 @@ +#pragma once + +#include +#include + +#include +#include + +namespace coinbase::api::detail { + +// Conservative size limits for untrusted byte blobs passed into the public API. +// +// These APIs frequently parse/deserialize opaque blobs into internal structures, +// which can otherwise lead to large allocations and CPU usage on attacker-controlled +// inputs (e.g., blobs received over the network). +inline constexpr int MAX_OPAQUE_BLOB_SIZE = 1 * 1024 * 1024; // 1 MiB +inline constexpr int MAX_CIPHERTEXT_BLOB_SIZE = 64 * 1024 * 1024; // 64 MiB +inline constexpr int MAX_MESSAGE_DIGEST_SIZE = 64; // 64 bytes (e.g., SHA-512 / SHA3-512) + +// Validate the basic invariants of a mem_t passed into the public API. +// +// Important: Do not use cb_assert() here. These checks are for *untrusted* +// inputs (including data that may be adversary-controlled) and must fail +// gracefully with an error code. +inline error_t validate_mem_arg(mem_t m, const char* name) { + if (m.size <= 0) return coinbase::error(E_BADARG, std::string("invalid ") + name, /*to_print_stack_trace=*/false); + if (!m.data) return coinbase::error(E_BADARG, std::string("invalid ") + name, /*to_print_stack_trace=*/false); + return SUCCESS; +} + +inline error_t validate_mem_vec_arg(const std::vector& ms, const char* name) { + for (size_t i = 0; i < ms.size(); i++) { + const error_t rv = validate_mem_arg(ms[i], name); + if (rv) return rv; + } + return SUCCESS; +} + +inline error_t validate_mem_arg_max_size(mem_t m, const char* name, int max_size) { + if (const error_t rv = validate_mem_arg(m, name)) return rv; + if (max_size < 0) return coinbase::error(E_BADARG, "invalid max_size", /*to_print_stack_trace=*/false); + if (m.size > max_size) + return coinbase::error(E_RANGE, std::string(name) + " too large", /*to_print_stack_trace=*/false); + return SUCCESS; +} + +inline error_t validate_mem_vec_arg_max_size(const std::vector& ms, const char* name, int max_size) { + for (size_t i = 0; i < ms.size(); i++) { + const error_t rv = validate_mem_arg_max_size(ms[i], name, max_size); + if (rv) return rv; + } + return SUCCESS; +} + +} // namespace coinbase::api::detail diff --git a/src/cbmpc/api/pve_base_pke.cpp b/src/cbmpc/api/pve_base_pke.cpp new file mode 100644 index 00000000..fdeb32d1 --- /dev/null +++ b/src/cbmpc/api/pve_base_pke.cpp @@ -0,0 +1,373 @@ +#include +#include +#include +#include +#include + +#include "curve_util.h" +#include "mem_util.h" +#include "pve_internal.h" + +namespace coinbase::api::pve { + +namespace { + +constexpr uint32_t pve_ciphertext_version_v1 = 1; + +using coinbase::api::detail::to_internal_curve; + +struct pve_ciphertext_blob_v1_t { + uint32_t version = pve_ciphertext_version_v1; + buf_t ct; // serialized `coinbase::mpc::ec_pve_t` + + void convert(coinbase::converter_t& c) { c.convert(version, ct); } +}; + +using detail::base_pke_bridge_t; +using detail::base_pke_dk_blob_v1_t; +using detail::base_pke_ek_blob_v1_t; +using detail::base_pke_key_type_v1; +using detail::ecies_p256_hsm_base_pke_t; +using detail::parse_dk_blob; +using detail::parse_ek_blob; +using detail::rsa_oaep_hsm_base_pke_t; + +class unified_key_blob_base_pke_t final : public base_pke_i { + public: + error_t encrypt(mem_t ek, mem_t label, mem_t plain, mem_t rho, buf_t& out_ct) const override { + base_pke_ek_blob_v1_t blob; + error_t rv = parse_ek_blob(ek, blob); + if (rv) return rv; + + switch (static_cast(blob.key_type)) { + case base_pke_key_type_v1::rsa_oaep_2048: + return coinbase::mpc::pve_base_pke_rsa().encrypt(coinbase::mpc::pve_keyref(blob.rsa_ek), label, plain, rho, + out_ct); + case base_pke_key_type_v1::ecies_p256: + if (blob.ecies_ek.get_curve() != coinbase::crypto::curve_p256) + return coinbase::error(E_BADARG, "ECIES base PKE key must be on P-256"); + return coinbase::mpc::pve_base_pke_ecies().encrypt(coinbase::mpc::pve_keyref(blob.ecies_ek), label, plain, rho, + out_ct); + default: + return coinbase::error(E_FORMAT, "unsupported base PKE key type"); + } + } + + error_t decrypt(mem_t dk, mem_t label, mem_t ct, buf_t& out_plain) const override { + base_pke_dk_blob_v1_t blob; + error_t rv = parse_dk_blob(dk, blob); + if (rv) return rv; + + switch (static_cast(blob.key_type)) { + case base_pke_key_type_v1::rsa_oaep_2048: + return coinbase::mpc::pve_base_pke_rsa().decrypt(coinbase::mpc::pve_keyref(blob.rsa_dk), label, ct, out_plain); + case base_pke_key_type_v1::ecies_p256: + if (blob.ecies_dk.get_curve() != coinbase::crypto::curve_p256) + return coinbase::error(E_BADARG, "ECIES base PKE key must be on P-256"); + return coinbase::mpc::pve_base_pke_ecies().decrypt(coinbase::mpc::pve_keyref(blob.ecies_dk), label, ct, + out_plain); + default: + return coinbase::error(E_FORMAT, "unsupported base PKE key type"); + } + } +}; + +} // namespace + +const base_pke_i& base_pke_default() { + static const unified_key_blob_base_pke_t pke; + return pke; +} + +error_t generate_base_pke_rsa_keypair(buf_t& out_ek, buf_t& out_dk) { + coinbase::crypto::rsa_prv_key_t sk; + sk.generate(coinbase::crypto::RSA_KEY_LENGTH); + coinbase::crypto::rsa_pub_key_t pk = sk.pub(); + + base_pke_ek_blob_v1_t ek_blob; + ek_blob.key_type = static_cast(base_pke_key_type_v1::rsa_oaep_2048); + ek_blob.rsa_ek = pk; + + base_pke_dk_blob_v1_t dk_blob; + dk_blob.key_type = static_cast(base_pke_key_type_v1::rsa_oaep_2048); + dk_blob.rsa_dk = std::move(sk); + + out_ek = coinbase::convert(ek_blob); + out_dk = coinbase::convert(dk_blob); + return SUCCESS; +} + +error_t generate_base_pke_ecies_p256_keypair(buf_t& out_ek, buf_t& out_dk) { + coinbase::crypto::ecc_prv_key_t sk; + sk.generate(coinbase::crypto::curve_p256); + coinbase::crypto::ecc_pub_key_t pk = sk.pub(); + + base_pke_ek_blob_v1_t ek_blob; + ek_blob.key_type = static_cast(base_pke_key_type_v1::ecies_p256); + ek_blob.ecies_ek = pk; + + base_pke_dk_blob_v1_t dk_blob; + dk_blob.key_type = static_cast(base_pke_key_type_v1::ecies_p256); + dk_blob.ecies_dk = std::move(sk); + + out_ek = coinbase::convert(ek_blob); + out_dk = coinbase::convert(dk_blob); + return SUCCESS; +} + +error_t base_pke_ecies_p256_ek_from_oct(mem_t pub_key_oct, buf_t& out_ek) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(pub_key_oct, "pub_key_oct", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + error_t rv = UNINITIALIZED_ERROR; + + coinbase::crypto::ecc_pub_key_t pk; + if (rv = pk.from_oct(coinbase::crypto::curve_p256, pub_key_oct)) return rv; + if (rv = coinbase::crypto::curve_p256.check(pk)) return rv; + + base_pke_ek_blob_v1_t ek_blob; + ek_blob.key_type = static_cast(base_pke_key_type_v1::ecies_p256); + ek_blob.ecies_ek = std::move(pk); + + out_ek = coinbase::convert(ek_blob); + return SUCCESS; +} + +error_t base_pke_rsa_ek_from_modulus(mem_t modulus, buf_t& out_ek) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(modulus, "modulus", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + + constexpr int kExpectedModulusBytes = coinbase::crypto::RSA_KEY_LENGTH / 8; + if (modulus.size != kExpectedModulusBytes) + return coinbase::error(E_BADARG, "modulus must be exactly 256 bytes (RSA-2048)"); + + constexpr unsigned long kDefaultExponent = 65537; + + BIGNUM* n_bn = BN_bin2bn(modulus.data, modulus.size, nullptr); + if (!n_bn) return coinbase::error(E_GENERAL, "BN_bin2bn(modulus) failed"); + + if (BN_is_zero(n_bn)) { + BN_free(n_bn); + return coinbase::error(E_BADARG, "modulus must not be zero"); + } + + BIGNUM* e_bn = BN_new(); + if (!e_bn) { + BN_free(n_bn); + return coinbase::error(E_GENERAL, "BN_new(e) failed"); + } + BN_set_word(e_bn, kDefaultExponent); + + coinbase::crypto::rsa_pub_key_t pk; + pk.set(n_bn, e_bn); + BN_free(n_bn); + BN_free(e_bn); + + base_pke_ek_blob_v1_t ek_blob; + ek_blob.key_type = static_cast(base_pke_key_type_v1::rsa_oaep_2048); + ek_blob.rsa_ek = std::move(pk); + + out_ek = coinbase::convert(ek_blob); + return SUCCESS; +} + +error_t encrypt(const base_pke_i& base_pke, curve_id curve, mem_t ek, mem_t label, mem_t x, buf_t& out_ciphertext) { + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(ek, "ek", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(x, "x")) return rv; + error_t rv = UNINITIALIZED_ERROR; + + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + // Defensive check: avoid large attacker-controlled allocations when converting `x` into a bignum. + const int max_x_size = icurve.order().get_bin_size(); + if (x.size > max_x_size) return coinbase::error(E_RANGE, "x too large", /*to_print_stack_trace=*/false); + + base_pke_bridge_t bridge(base_pke); + coinbase::mpc::ec_pve_t pve_ct; + + const coinbase::mem_t ek_mem(ek.data, ek.size); + const coinbase::crypto::bn_t x_bn = coinbase::crypto::bn_t::from_bin(x); + + out_ciphertext.free(); + rv = pve_ct.encrypt(bridge, coinbase::mpc::pve_keyref(ek_mem), label, icurve, x_bn); + if (rv) return rv; + + pve_ciphertext_blob_v1_t blob; + blob.ct = coinbase::convert(pve_ct); + out_ciphertext = coinbase::convert(blob); + return SUCCESS; +} + +error_t encrypt(curve_id curve, mem_t ek, mem_t label, mem_t x, buf_t& out_ciphertext) { + return encrypt(base_pke_default(), curve, ek, label, x, out_ciphertext); +} + +error_t verify(const base_pke_i& base_pke, curve_id curve, mem_t ek, mem_t ciphertext, mem_t Q_compressed, + mem_t label) { + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(ek, "ek", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(Q_compressed, "Q_compressed")) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + error_t rv = UNINITIALIZED_ERROR; + + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + pve_ciphertext_blob_v1_t blob; + if (rv = coinbase::convert(blob, ciphertext)) return rv; + if (blob.version != pve_ciphertext_version_v1) return coinbase::error(E_FORMAT, "unsupported ciphertext version"); + + base_pke_bridge_t bridge(base_pke); + coinbase::mpc::ec_pve_t pve_ct; + if (rv = coinbase::convert(pve_ct, blob.ct)) return rv; + + if (pve_ct.get_Q().get_curve() != icurve) return coinbase::error(E_BADARG, "ciphertext curve mismatch"); + + coinbase::crypto::ecc_point_t Q_expected; + if (rv = Q_expected.from_bin(icurve, Q_compressed)) return coinbase::error(rv, "invalid Q"); + if (rv = icurve.check(Q_expected)) return coinbase::error(rv, "invalid Q"); + + const coinbase::mem_t ek_mem(ek.data, ek.size); + return pve_ct.verify(bridge, coinbase::mpc::pve_keyref(ek_mem), Q_expected, label); +} + +error_t verify(curve_id curve, mem_t ek, mem_t ciphertext, mem_t Q_compressed, mem_t label) { + return verify(base_pke_default(), curve, ek, ciphertext, Q_compressed, label); +} + +error_t decrypt(const base_pke_i& base_pke, curve_id curve, mem_t dk, mem_t ek, mem_t ciphertext, mem_t label, + buf_t& out_x) { + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(dk, "dk", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(ek, "ek", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + error_t rv = UNINITIALIZED_ERROR; + + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + pve_ciphertext_blob_v1_t blob; + if (rv = coinbase::convert(blob, ciphertext)) return rv; + if (blob.version != pve_ciphertext_version_v1) return coinbase::error(E_FORMAT, "unsupported ciphertext version"); + + base_pke_bridge_t bridge(base_pke); + coinbase::mpc::ec_pve_t pve_ct; + if (rv = coinbase::convert(pve_ct, blob.ct)) return rv; + + if (pve_ct.get_Q().get_curve() != icurve) return coinbase::error(E_BADARG, "ciphertext curve mismatch"); + + const coinbase::mem_t dk_mem(dk.data, dk.size); + const coinbase::mem_t ek_mem(ek.data, ek.size); + + coinbase::crypto::bn_t x_bn; + rv = pve_ct.decrypt(bridge, coinbase::mpc::pve_keyref(dk_mem), coinbase::mpc::pve_keyref(ek_mem), label, icurve, x_bn, + /*skip_verify=*/true); + if (rv) { + out_x.free(); + return rv; + } + + out_x = x_bn.to_bin(icurve.order().get_bin_size()); + return SUCCESS; +} + +error_t decrypt(curve_id curve, mem_t dk, mem_t ek, mem_t ciphertext, mem_t label, buf_t& out_x) { + return decrypt(base_pke_default(), curve, dk, ek, ciphertext, label, out_x); +} + +error_t decrypt_rsa_oaep_hsm(curve_id curve, mem_t dk_handle, mem_t ek, mem_t ciphertext, mem_t label, + const rsa_oaep_hsm_decap_cb_t& cb, buf_t& out_x) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg(dk_handle, "dk_handle")) return rv; + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(ek, "ek", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + if (!cb.decap) return coinbase::error(E_BADARG, "missing HSM RSA decap callback"); + + base_pke_ek_blob_v1_t ek_blob; + error_t rv = parse_ek_blob(ek, ek_blob); + if (rv) return rv; + if (static_cast(ek_blob.key_type) != base_pke_key_type_v1::rsa_oaep_2048) + return coinbase::error(E_BADARG, "expected RSA base PKE public key"); + + rsa_oaep_hsm_base_pke_t base_pke(cb); + return decrypt(base_pke, curve, dk_handle, ek, ciphertext, label, out_x); +} + +error_t decrypt_ecies_p256_hsm(curve_id curve, mem_t dk_handle, mem_t ek, mem_t ciphertext, mem_t label, + const ecies_p256_hsm_ecdh_cb_t& cb, buf_t& out_x) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg(dk_handle, "dk_handle")) return rv; + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(ek, "ek", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + if (!cb.ecdh) return coinbase::error(E_BADARG, "missing HSM ECIES ECDH callback"); + + base_pke_ek_blob_v1_t ek_blob; + error_t rv = parse_ek_blob(ek, ek_blob); + if (rv) return rv; + if (static_cast(ek_blob.key_type) != base_pke_key_type_v1::ecies_p256) + return coinbase::error(E_BADARG, "expected ECIES(P-256) base PKE public key"); + if (ek_blob.ecies_ek.get_curve() != coinbase::crypto::curve_p256) + return coinbase::error(E_BADARG, "ECIES base PKE key must be on P-256"); + + ecies_p256_hsm_base_pke_t base_pke(ek_blob.ecies_ek.to_oct(), cb); + return decrypt(base_pke, curve, dk_handle, ek, ciphertext, label, out_x); +} + +error_t get_public_key_compressed(mem_t ciphertext, buf_t& out_Q_compressed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + error_t rv = UNINITIALIZED_ERROR; + + pve_ciphertext_blob_v1_t blob; + if (rv = coinbase::convert(blob, ciphertext)) return rv; + if (blob.version != pve_ciphertext_version_v1) return coinbase::error(E_FORMAT, "unsupported ciphertext version"); + + coinbase::mpc::ec_pve_t pve_ct; // base PKE not used for extraction + if (rv = coinbase::convert(pve_ct, blob.ct)) return rv; + + out_Q_compressed = pve_ct.get_Q().to_compressed_bin(); + return SUCCESS; +} + +error_t get_Label(mem_t ciphertext, buf_t& out_label) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + error_t rv = UNINITIALIZED_ERROR; + + pve_ciphertext_blob_v1_t blob; + if (rv = coinbase::convert(blob, ciphertext)) return rv; + if (blob.version != pve_ciphertext_version_v1) return coinbase::error(E_FORMAT, "unsupported ciphertext version"); + + coinbase::mpc::ec_pve_t pve_ct; // base PKE not used for extraction + if (rv = coinbase::convert(pve_ct, blob.ct)) return rv; + + out_label = pve_ct.get_Label(); + return SUCCESS; +} + +} // namespace coinbase::api::pve diff --git a/src/cbmpc/api/pve_batch_ac.cpp b/src/cbmpc/api/pve_batch_ac.cpp new file mode 100644 index 00000000..2f3a4c98 --- /dev/null +++ b/src/cbmpc/api/pve_batch_ac.cpp @@ -0,0 +1,405 @@ +#include +#include +#include +#include + +#include "access_structure_util.h" +#include "curve_util.h" +#include "mem_util.h" +#include "pve_internal.h" + +namespace coinbase::api::pve { + +namespace { + +constexpr uint32_t pve_ac_ciphertext_version_v1 = 1; +// Defensive limit for untrusted inputs. This bounds allocations/work when ciphertexts come from the network. +static constexpr int MAX_BATCH_COUNT = 100000; + +using coinbase::api::detail::to_internal_curve; + +struct pve_ac_ciphertext_blob_v1_t { + uint32_t version = pve_ac_ciphertext_version_v1; + uint32_t batch_count = 0; + buf_t ct; // serialized `coinbase::mpc::ec_pve_ac_t` + + void convert(coinbase::converter_t& c) { c.convert(version, batch_count, ct); } +}; + +static error_t parse_ac_ciphertext(mem_t ciphertext, pve_ac_ciphertext_blob_v1_t& out_blob) { + error_t rv = coinbase::convert(out_blob, ciphertext); + if (rv) return rv; + if (out_blob.version != pve_ac_ciphertext_version_v1) + return coinbase::error(E_FORMAT, "unsupported ciphertext version"); + if (out_blob.batch_count == 0) return coinbase::error(E_FORMAT, "invalid batch count"); + if (out_blob.batch_count > static_cast(MAX_BATCH_COUNT)) return coinbase::error(E_RANGE, "batch too large"); + return SUCCESS; +} + +static error_t to_internal_ac_and_leaves(const access_structure_t& ac_in, const coinbase::crypto::ecurve_t& curve, + coinbase::crypto::ss::ac_owned_t& out_ac, + std::set& out_leaf_names) { + out_leaf_names.clear(); + error_t rv = coinbase::api::detail::collect_leaf_names(ac_in, out_leaf_names); + if (rv) return rv; + if (out_leaf_names.empty()) return coinbase::error(E_BADARG, "access_structure: missing leaves"); + + std::vector party_names; + party_names.reserve(out_leaf_names.size()); + for (const auto& s : out_leaf_names) party_names.emplace_back(s); + + return coinbase::api::detail::to_internal_access_structure(ac_in, party_names, curve, out_ac); +} + +static error_t validate_leaf_keys_exact(const std::set& leaf_names, const leaf_keys_t& keys, + const char* what) { + if (leaf_names.empty()) return coinbase::error(E_BADARG, "access_structure: missing leaves"); + if (keys.size() != leaf_names.size()) return coinbase::error(E_BADARG, std::string(what) + ": key set mismatch"); + + for (const auto& [name_view, key] : keys) { + if (name_view.empty()) return coinbase::error(E_BADARG, std::string(what) + ": leaf name must be non-empty"); + if (key.size < 0) return coinbase::error(E_BADARG, std::string(what) + ": invalid key size"); + if (key.size > 0 && !key.data) return coinbase::error(E_BADARG, std::string(what) + ": missing key bytes"); + if (key.size > coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE) + return coinbase::error(E_RANGE, std::string(what) + ": key too large"); + + if (leaf_names.count(std::string(name_view)) == 0) + return coinbase::error(E_BADARG, std::string(what) + ": unknown leaf name"); + } + + for (const auto& leaf : leaf_names) { + if (keys.find(std::string_view(leaf)) == keys.end()) + return coinbase::error(E_BADARG, std::string(what) + ": missing key for leaf " + leaf); + } + + return SUCCESS; +} + +static error_t validate_quorum_shares_subset(const std::set& leaf_names, const leaf_shares_t& shares, + const char* what) { + if (shares.empty()) return coinbase::error(E_BADARG, std::string(what) + ": quorum_shares must be non-empty"); + + for (const auto& [name_view, share] : shares) { + if (name_view.empty()) return coinbase::error(E_BADARG, std::string(what) + ": leaf name must be non-empty"); + if (share.size < 0) return coinbase::error(E_BADARG, std::string(what) + ": invalid share size"); + if (share.size > 0 && !share.data) return coinbase::error(E_BADARG, std::string(what) + ": missing share bytes"); + if (share.size > coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE) + return coinbase::error(E_RANGE, std::string(what) + ": share too large"); + if (leaf_names.count(std::string(name_view)) == 0) + return coinbase::error(E_BADARG, std::string(what) + ": unknown leaf name"); + } + return SUCCESS; +} + +static void build_internal_leaf_key_ptrs(const leaf_keys_t& in, std::map& out_key_mems, + coinbase::mpc::ec_pve_ac_t::pks_t& out_ptrs) { + out_key_mems.clear(); + out_ptrs.clear(); + for (const auto& [name_view, key] : in) { + out_key_mems.emplace(std::string(name_view), coinbase::mem_t(key.data, key.size)); + } + for (auto& [name, key_mem] : out_key_mems) out_ptrs.emplace(name, coinbase::mpc::pve_keyref(key_mem)); +} + +} // namespace + +error_t encrypt_ac(const base_pke_i& base_pke, curve_id curve, const access_structure_t& ac, const leaf_keys_t& ac_pks, + mem_t label, const std::vector& xs, buf_t& out_ciphertext) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_vec_arg(xs, "xs")) return rv; + error_t rv = UNINITIALIZED_ERROR; + + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + if (xs.empty()) return coinbase::error(E_BADARG, "batch_count must be positive"); + if (xs.size() > static_cast(MAX_BATCH_COUNT)) return coinbase::error(E_RANGE, "batch too large"); + + // Defensive check: avoid large attacker-controlled allocations when converting `xs[i]` into bignums. + const int max_x_size = icurve.order().get_bin_size(); + for (const auto& x : xs) { + if (x.size > max_x_size) return coinbase::error(E_RANGE, "xs element too large", /*to_print_stack_trace=*/false); + } + + coinbase::crypto::ss::ac_owned_t ac_internal; + std::set leaf_names; + if (rv = to_internal_ac_and_leaves(ac, icurve, ac_internal, leaf_names)) return rv; + if (rv = validate_leaf_keys_exact(leaf_names, ac_pks, "ac_pks")) return rv; + + std::vector xs_bn; + xs_bn.reserve(xs.size()); + for (const auto& x : xs) xs_bn.push_back(coinbase::crypto::bn_t::from_bin(x)); + + detail::base_pke_bridge_t bridge(base_pke); + + std::map key_mems; + coinbase::mpc::ec_pve_ac_t::pks_t pk_ptrs; + build_internal_leaf_key_ptrs(ac_pks, key_mems, pk_ptrs); + + coinbase::mpc::ec_pve_ac_t pve_ct; + + out_ciphertext.free(); + rv = pve_ct.encrypt(bridge, ac_internal, pk_ptrs, label, icurve, xs_bn); + if (rv) return rv; + + pve_ac_ciphertext_blob_v1_t blob; + blob.batch_count = static_cast(xs.size()); + blob.ct = coinbase::convert(pve_ct); + + out_ciphertext = coinbase::convert(blob); + return SUCCESS; +} + +error_t encrypt_ac(curve_id curve, const access_structure_t& ac, const leaf_keys_t& ac_pks, mem_t label, + const std::vector& xs, buf_t& out_ciphertext) { + return encrypt_ac(base_pke_default(), curve, ac, ac_pks, label, xs, out_ciphertext); +} + +error_t verify_ac(const base_pke_i& base_pke, curve_id curve, const access_structure_t& ac, const leaf_keys_t& ac_pks, + mem_t ciphertext, const std::vector& Qs_compressed, mem_t label) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_vec_arg(Qs_compressed, "Qs_compressed")) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + error_t rv = UNINITIALIZED_ERROR; + + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::crypto::ss::ac_owned_t ac_internal; + std::set leaf_names; + if (rv = to_internal_ac_and_leaves(ac, icurve, ac_internal, leaf_names)) return rv; + if (rv = validate_leaf_keys_exact(leaf_names, ac_pks, "ac_pks")) return rv; + + pve_ac_ciphertext_blob_v1_t blob; + if (rv = parse_ac_ciphertext(ciphertext, blob)) return rv; + + if (Qs_compressed.size() != static_cast(blob.batch_count)) + return coinbase::error(E_BADARG, "Q count mismatch"); + + coinbase::mpc::ec_pve_ac_t pve_ct; + if (rv = coinbase::convert(pve_ct, blob.ct)) return rv; + + // Validate ciphertext curve. + for (const auto& q : pve_ct.get_Q()) { + if (q.get_curve() != icurve) return coinbase::error(E_BADARG, "ciphertext curve mismatch"); + } + + std::vector Q_expected; + Q_expected.resize(Qs_compressed.size()); + for (size_t i = 0; i < Qs_compressed.size(); i++) { + if (rv = Q_expected[i].from_bin(icurve, Qs_compressed[i])) return coinbase::error(rv, "invalid Q"); + if (rv = icurve.check(Q_expected[i])) return coinbase::error(rv, "invalid Q"); + } + + detail::base_pke_bridge_t bridge(base_pke); + + std::map key_mems; + coinbase::mpc::ec_pve_ac_t::pks_t pk_ptrs; + build_internal_leaf_key_ptrs(ac_pks, key_mems, pk_ptrs); + + return pve_ct.verify(bridge, ac_internal, pk_ptrs, Q_expected, label); +} + +error_t verify_ac(curve_id curve, const access_structure_t& ac, const leaf_keys_t& ac_pks, mem_t ciphertext, + const std::vector& Qs_compressed, mem_t label) { + return verify_ac(base_pke_default(), curve, ac, ac_pks, ciphertext, Qs_compressed, label); +} + +error_t partial_decrypt_ac_attempt(const base_pke_i& base_pke, curve_id curve, const access_structure_t& ac, + mem_t ciphertext, int attempt_index, std::string_view leaf_name, mem_t dk, + mem_t label, buf_t& out_share) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(dk, "dk", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + error_t rv = UNINITIALIZED_ERROR; + + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + if (leaf_name.empty()) return coinbase::error(E_BADARG, "leaf_name must be non-empty"); + + coinbase::crypto::ss::ac_owned_t ac_internal; + std::set leaf_names; + if (rv = to_internal_ac_and_leaves(ac, icurve, ac_internal, leaf_names)) return rv; + if (leaf_names.count(std::string(leaf_name)) == 0) return coinbase::error(E_BADARG, "unknown leaf_name"); + + pve_ac_ciphertext_blob_v1_t blob; + if (rv = parse_ac_ciphertext(ciphertext, blob)) return rv; + + coinbase::mpc::ec_pve_ac_t pve_ct; + if (rv = coinbase::convert(pve_ct, blob.ct)) return rv; + + for (const auto& q : pve_ct.get_Q()) { + if (q.get_curve() != icurve) return coinbase::error(E_BADARG, "ciphertext curve mismatch"); + } + + detail::base_pke_bridge_t bridge(base_pke); + + const coinbase::mem_t dk_mem(dk.data, dk.size); + coinbase::crypto::bn_t share_bn; + rv = pve_ct.party_decrypt_row(bridge, ac_internal, attempt_index, std::string(leaf_name), + coinbase::mpc::pve_keyref(dk_mem), label, share_bn); + if (rv) { + out_share.free(); + return rv; + } + + out_share = share_bn.to_bin(icurve.order().get_bin_size()); + return SUCCESS; +} + +error_t partial_decrypt_ac_attempt(curve_id curve, const access_structure_t& ac, mem_t ciphertext, int attempt_index, + std::string_view leaf_name, mem_t dk, mem_t label, buf_t& out_share) { + return partial_decrypt_ac_attempt(base_pke_default(), curve, ac, ciphertext, attempt_index, leaf_name, dk, label, + out_share); +} + +error_t partial_decrypt_ac_attempt_rsa_oaep_hsm(curve_id curve, const access_structure_t& ac, mem_t ciphertext, + int attempt_index, std::string_view leaf_name, mem_t dk_handle, + mem_t ek, mem_t label, const rsa_oaep_hsm_decap_cb_t& cb, + buf_t& out_share) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg(dk_handle, "dk_handle")) return rv; + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(ek, "ek", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + if (!cb.decap) return coinbase::error(E_BADARG, "missing HSM RSA decap callback"); + + detail::base_pke_ek_blob_v1_t ek_blob; + error_t rv = detail::parse_ek_blob(ek, ek_blob); + if (rv) return rv; + if (static_cast(ek_blob.key_type) != detail::base_pke_key_type_v1::rsa_oaep_2048) + return coinbase::error(E_BADARG, "expected RSA base PKE public key"); + + detail::rsa_oaep_hsm_base_pke_t base_pke(cb); + return partial_decrypt_ac_attempt(base_pke, curve, ac, ciphertext, attempt_index, leaf_name, dk_handle, label, + out_share); +} + +error_t partial_decrypt_ac_attempt_ecies_p256_hsm(curve_id curve, const access_structure_t& ac, mem_t ciphertext, + int attempt_index, std::string_view leaf_name, mem_t dk_handle, + mem_t ek, mem_t label, const ecies_p256_hsm_ecdh_cb_t& cb, + buf_t& out_share) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg(dk_handle, "dk_handle")) return rv; + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(ek, "ek", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + if (!cb.ecdh) return coinbase::error(E_BADARG, "missing HSM ECIES ECDH callback"); + + detail::base_pke_ek_blob_v1_t ek_blob; + error_t rv = detail::parse_ek_blob(ek, ek_blob); + if (rv) return rv; + if (static_cast(ek_blob.key_type) != detail::base_pke_key_type_v1::ecies_p256) + return coinbase::error(E_BADARG, "expected ECIES(P-256) base PKE public key"); + if (ek_blob.ecies_ek.get_curve() != coinbase::crypto::curve_p256) + return coinbase::error(E_BADARG, "ECIES base PKE key must be on P-256"); + + detail::ecies_p256_hsm_base_pke_t base_pke(ek_blob.ecies_ek.to_oct(), cb); + return partial_decrypt_ac_attempt(base_pke, curve, ac, ciphertext, attempt_index, leaf_name, dk_handle, label, + out_share); +} + +error_t combine_ac(const base_pke_i& base_pke, curve_id curve, const access_structure_t& ac, mem_t ciphertext, + int attempt_index, mem_t label, const leaf_shares_t& quorum_shares, std::vector& out_xs) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + error_t rv = UNINITIALIZED_ERROR; + + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::crypto::ss::ac_owned_t ac_internal; + std::set leaf_names; + if (rv = to_internal_ac_and_leaves(ac, icurve, ac_internal, leaf_names)) return rv; + if (rv = validate_quorum_shares_subset(leaf_names, quorum_shares, "quorum_shares")) return rv; + + pve_ac_ciphertext_blob_v1_t blob; + if (rv = parse_ac_ciphertext(ciphertext, blob)) return rv; + + coinbase::mpc::ec_pve_ac_t pve_ct; + if (rv = coinbase::convert(pve_ct, blob.ct)) return rv; + + for (const auto& q : pve_ct.get_Q()) { + if (q.get_curve() != icurve) return coinbase::error(E_BADARG, "ciphertext curve mismatch"); + } + + const int expected_share_size = icurve.order().get_bin_size(); + for (const auto& [name_view, share_bytes] : quorum_shares) { + if (share_bytes.size != expected_share_size) return coinbase::error(E_BADARG, "quorum_shares: invalid share size"); + } + + std::map quorum_bn; + for (const auto& [name_view, share_bytes] : quorum_shares) { + quorum_bn.emplace(std::string(name_view), coinbase::crypto::bn_t::from_bin(share_bytes)); + } + + detail::base_pke_bridge_t bridge(base_pke); + coinbase::mpc::ec_pve_ac_t::pks_t pk_ptrs; + + std::vector xs_bn; + rv = pve_ct.aggregate_to_restore_row(bridge, ac_internal, attempt_index, label, quorum_bn, xs_bn, + /*skip_verify=*/true, pk_ptrs); + if (rv) { + out_xs.clear(); + return rv; + } + + std::vector out_local; + out_local.resize(xs_bn.size()); + const int out_len = icurve.order().get_bin_size(); + for (size_t i = 0; i < xs_bn.size(); i++) out_local[i] = xs_bn[i].to_bin(out_len); + out_xs = std::move(out_local); + return SUCCESS; +} + +error_t combine_ac(curve_id curve, const access_structure_t& ac, mem_t ciphertext, int attempt_index, mem_t label, + const leaf_shares_t& quorum_shares, std::vector& out_xs) { + return combine_ac(base_pke_default(), curve, ac, ciphertext, attempt_index, label, quorum_shares, out_xs); +} + +error_t get_ac_batch_count(mem_t ciphertext, int& out_batch_count) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + pve_ac_ciphertext_blob_v1_t blob; + error_t rv = parse_ac_ciphertext(ciphertext, blob); + if (rv) return rv; + out_batch_count = static_cast(blob.batch_count); + return SUCCESS; +} + +error_t get_public_keys_compressed_ac(mem_t ciphertext, std::vector& out_Qs_compressed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + error_t rv = UNINITIALIZED_ERROR; + + pve_ac_ciphertext_blob_v1_t blob; + if (rv = parse_ac_ciphertext(ciphertext, blob)) return rv; + + coinbase::mpc::ec_pve_ac_t pve_ct; // base PKE not used for extraction + if (rv = coinbase::convert(pve_ct, blob.ct)) return rv; + + std::vector out_local; + out_local.reserve(pve_ct.get_Q().size()); + for (const auto& q : pve_ct.get_Q()) out_local.push_back(q.to_compressed_bin()); + + out_Qs_compressed = std::move(out_local); + return SUCCESS; +} + +} // namespace coinbase::api::pve diff --git a/src/cbmpc/api/pve_batch_single_recipient.cpp b/src/cbmpc/api/pve_batch_single_recipient.cpp new file mode 100644 index 00000000..2285b0fa --- /dev/null +++ b/src/cbmpc/api/pve_batch_single_recipient.cpp @@ -0,0 +1,288 @@ +#include +#include +#include +#include +#include + +#include "curve_util.h" +#include "mem_util.h" +#include "pve_internal.h" + +namespace coinbase::api::pve { + +namespace { + +constexpr uint32_t pve_batch_ciphertext_version_v1 = 1; + +using coinbase::api::detail::to_internal_curve; + +struct pve_batch_ciphertext_blob_v1_t { + uint32_t version = pve_batch_ciphertext_version_v1; + uint32_t batch_count = 0; + buf_t ct; // serialized `coinbase::mpc::ec_pve_batch_t` + + void convert(coinbase::converter_t& c) { c.convert(version, batch_count, ct); } +}; + +using detail::base_pke_bridge_t; +using detail::base_pke_ek_blob_v1_t; +using detail::base_pke_key_type_v1; +using detail::ecies_p256_hsm_base_pke_t; +using detail::parse_ek_blob; +using detail::rsa_oaep_hsm_base_pke_t; + +static error_t parse_batch_ciphertext(mem_t ciphertext, pve_batch_ciphertext_blob_v1_t& out_blob) { + error_t rv = coinbase::convert(out_blob, ciphertext); + if (rv) return rv; + if (out_blob.version != pve_batch_ciphertext_version_v1) + return coinbase::error(E_FORMAT, "unsupported ciphertext version"); + if (out_blob.batch_count == 0) return coinbase::error(E_FORMAT, "invalid batch count"); + if (out_blob.batch_count > static_cast(coinbase::mpc::ec_pve_batch_t::MAX_BATCH_COUNT)) + return coinbase::error(E_RANGE, "batch too large"); + return SUCCESS; +} + +} // namespace + +error_t encrypt_batch(const base_pke_i& base_pke, curve_id curve, mem_t ek, mem_t label, const std::vector& xs, + buf_t& out_ciphertext) { + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(ek, "ek", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_vec_arg(xs, "xs")) return rv; + error_t rv = UNINITIALIZED_ERROR; + + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + if (xs.empty()) return coinbase::error(E_BADARG, "batch_count must be positive"); + if (xs.size() > static_cast(coinbase::mpc::ec_pve_batch_t::MAX_BATCH_COUNT)) + return coinbase::error(E_RANGE, "batch too large"); + + // Defensive check: avoid large attacker-controlled allocations when converting `xs[i]` into bignums. + const int max_x_size = icurve.order().get_bin_size(); + for (const auto& x : xs) { + if (x.size > max_x_size) return coinbase::error(E_RANGE, "xs element too large", /*to_print_stack_trace=*/false); + } + + const int n = static_cast(xs.size()); + std::vector x_bn; + x_bn.reserve(static_cast(n)); + for (int i = 0; i < n; i++) x_bn.push_back(coinbase::crypto::bn_t::from_bin(xs[static_cast(i)])); + + base_pke_bridge_t bridge(base_pke); + const coinbase::mem_t ek_mem(ek.data, ek.size); + + coinbase::mpc::ec_pve_batch_t pve_ct(n); + + out_ciphertext.free(); + rv = pve_ct.encrypt(bridge, coinbase::mpc::pve_keyref(ek_mem), label, icurve, x_bn); + if (rv) return rv; + + pve_batch_ciphertext_blob_v1_t blob; + blob.batch_count = static_cast(n); + blob.ct = coinbase::convert(pve_ct); + out_ciphertext = coinbase::convert(blob); + return SUCCESS; +} + +error_t encrypt_batch(curve_id curve, mem_t ek, mem_t label, const std::vector& xs, buf_t& out_ciphertext) { + return encrypt_batch(base_pke_default(), curve, ek, label, xs, out_ciphertext); +} + +error_t verify_batch(const base_pke_i& base_pke, curve_id curve, mem_t ek, mem_t ciphertext, + const std::vector& Qs_compressed, mem_t label) { + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(ek, "ek", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_vec_arg(Qs_compressed, "Qs_compressed")) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + error_t rv = UNINITIALIZED_ERROR; + + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + pve_batch_ciphertext_blob_v1_t blob; + if (rv = parse_batch_ciphertext(ciphertext, blob)) return rv; + + const int n = static_cast(blob.batch_count); + if (static_cast(n) != Qs_compressed.size()) return coinbase::error(E_BADARG, "Q count mismatch"); + + base_pke_bridge_t bridge(base_pke); + coinbase::mpc::ec_pve_batch_t pve_ct(n); + if (rv = coinbase::convert(pve_ct, blob.ct)) return rv; + + for (const auto& q : pve_ct.get_Qs()) { + if (q.get_curve() != icurve) return coinbase::error(E_BADARG, "ciphertext curve mismatch"); + } + + std::vector Q_expected; + Q_expected.resize(static_cast(n)); + for (int i = 0; i < n; i++) { + if (rv = Q_expected[static_cast(i)].from_bin(icurve, Qs_compressed[static_cast(i)])) + return coinbase::error(rv, "invalid Q"); + if (rv = icurve.check(Q_expected[static_cast(i)])) return coinbase::error(rv, "invalid Q"); + } + + const coinbase::mem_t ek_mem(ek.data, ek.size); + return pve_ct.verify(bridge, coinbase::mpc::pve_keyref(ek_mem), Q_expected, label); +} + +error_t verify_batch(curve_id curve, mem_t ek, mem_t ciphertext, const std::vector& Qs_compressed, mem_t label) { + return verify_batch(base_pke_default(), curve, ek, ciphertext, Qs_compressed, label); +} + +error_t decrypt_batch(const base_pke_i& base_pke, curve_id curve, mem_t dk, mem_t ek, mem_t ciphertext, mem_t label, + std::vector& out_xs) { + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(dk, "dk", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(ek, "ek", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + error_t rv = UNINITIALIZED_ERROR; + + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + pve_batch_ciphertext_blob_v1_t blob; + if (rv = parse_batch_ciphertext(ciphertext, blob)) return rv; + + const int n = static_cast(blob.batch_count); + + base_pke_bridge_t bridge(base_pke); + coinbase::mpc::ec_pve_batch_t pve_ct(n); + if (rv = coinbase::convert(pve_ct, blob.ct)) return rv; + + for (const auto& q : pve_ct.get_Qs()) { + if (q.get_curve() != icurve) return coinbase::error(E_BADARG, "ciphertext curve mismatch"); + } + + const coinbase::mem_t dk_mem(dk.data, dk.size); + const coinbase::mem_t ek_mem(ek.data, ek.size); + + std::vector xs_bn; + rv = pve_ct.decrypt(bridge, coinbase::mpc::pve_keyref(dk_mem), coinbase::mpc::pve_keyref(ek_mem), label, icurve, + xs_bn, /*skip_verify=*/true); + if (rv) { + out_xs.clear(); + return rv; + } + + std::vector out_local; + out_local.resize(static_cast(n)); + const int out_len = icurve.order().get_bin_size(); + for (int i = 0; i < n; i++) out_local[static_cast(i)] = xs_bn[static_cast(i)].to_bin(out_len); + + out_xs = std::move(out_local); + return SUCCESS; +} + +error_t decrypt_batch(curve_id curve, mem_t dk, mem_t ek, mem_t ciphertext, mem_t label, std::vector& out_xs) { + return decrypt_batch(base_pke_default(), curve, dk, ek, ciphertext, label, out_xs); +} + +error_t decrypt_batch_rsa_oaep_hsm(curve_id curve, mem_t dk_handle, mem_t ek, mem_t ciphertext, mem_t label, + const rsa_oaep_hsm_decap_cb_t& cb, std::vector& out_xs) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg(dk_handle, "dk_handle")) return rv; + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(ek, "ek", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + if (!cb.decap) return coinbase::error(E_BADARG, "missing HSM RSA decap callback"); + + base_pke_ek_blob_v1_t ek_blob; + error_t rv = parse_ek_blob(ek, ek_blob); + if (rv) return rv; + if (static_cast(ek_blob.key_type) != base_pke_key_type_v1::rsa_oaep_2048) + return coinbase::error(E_BADARG, "expected RSA base PKE public key"); + + rsa_oaep_hsm_base_pke_t base_pke(cb); + return decrypt_batch(base_pke, curve, dk_handle, ek, ciphertext, label, out_xs); +} + +error_t decrypt_batch_ecies_p256_hsm(curve_id curve, mem_t dk_handle, mem_t ek, mem_t ciphertext, mem_t label, + const ecies_p256_hsm_ecdh_cb_t& cb, std::vector& out_xs) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg(dk_handle, "dk_handle")) return rv; + if (const error_t rv = + coinbase::api::detail::validate_mem_arg_max_size(ek, "ek", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + if (!cb.ecdh) return coinbase::error(E_BADARG, "missing HSM ECIES ECDH callback"); + + base_pke_ek_blob_v1_t ek_blob; + error_t rv = parse_ek_blob(ek, ek_blob); + if (rv) return rv; + if (static_cast(ek_blob.key_type) != base_pke_key_type_v1::ecies_p256) + return coinbase::error(E_BADARG, "expected ECIES(P-256) base PKE public key"); + if (ek_blob.ecies_ek.get_curve() != coinbase::crypto::curve_p256) + return coinbase::error(E_BADARG, "ECIES base PKE key must be on P-256"); + + ecies_p256_hsm_base_pke_t base_pke(ek_blob.ecies_ek.to_oct(), cb); + return decrypt_batch(base_pke, curve, dk_handle, ek, ciphertext, label, out_xs); +} + +error_t get_batch_count(mem_t ciphertext, int& out_batch_count) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + pve_batch_ciphertext_blob_v1_t blob; + error_t rv = parse_batch_ciphertext(ciphertext, blob); + if (rv) return rv; + out_batch_count = static_cast(blob.batch_count); + return SUCCESS; +} + +error_t get_public_keys_compressed_batch(mem_t ciphertext, std::vector& out_Qs_compressed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + error_t rv = UNINITIALIZED_ERROR; + + pve_batch_ciphertext_blob_v1_t blob; + if (rv = parse_batch_ciphertext(ciphertext, blob)) return rv; + + const int n = static_cast(blob.batch_count); + coinbase::mpc::ec_pve_batch_t pve_ct(n); // base PKE not used for extraction + if (rv = coinbase::convert(pve_ct, blob.ct)) return rv; + + std::vector out_local; + out_local.reserve(static_cast(n)); + for (const auto& q : pve_ct.get_Qs()) out_local.push_back(q.to_compressed_bin()); + + out_Qs_compressed = std::move(out_local); + return SUCCESS; +} + +error_t get_Label_batch(mem_t ciphertext, buf_t& out_label) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + error_t rv = UNINITIALIZED_ERROR; + + pve_batch_ciphertext_blob_v1_t blob; + if (rv = parse_batch_ciphertext(ciphertext, blob)) return rv; + + const int n = static_cast(blob.batch_count); + coinbase::mpc::ec_pve_batch_t pve_ct(n); // base PKE not used for extraction + if (rv = coinbase::convert(pve_ct, blob.ct)) return rv; + + out_label = pve_ct.get_Label(); + return SUCCESS; +} + +} // namespace coinbase::api::pve diff --git a/src/cbmpc/api/pve_internal.h b/src/cbmpc/api/pve_internal.h new file mode 100644 index 00000000..f1a6629c --- /dev/null +++ b/src/cbmpc/api/pve_internal.h @@ -0,0 +1,168 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include + +namespace coinbase::api::pve::detail { + +// Bridge from the public `coinbase::api::pve::base_pke_i` to the internal +// `coinbase::mpc::pve_base_pke_i`. +class base_pke_bridge_t final : public coinbase::mpc::pve_base_pke_i { + public: + explicit base_pke_bridge_t(const base_pke_i& base_pke) : base_pke_(base_pke) {} + + error_t encrypt(coinbase::mpc::pve_keyref_t ek, mem_t label, mem_t plain, mem_t rho, buf_t& out_ct) const override { + const auto* ek_mem = ek.get(); + if (!ek_mem) return coinbase::error(E_BADARG, "invalid encryption key"); + return base_pke_.encrypt(*ek_mem, label, plain, rho, out_ct); + } + + error_t decrypt(coinbase::mpc::pve_keyref_t dk, mem_t label, mem_t ct, buf_t& out_plain) const override { + const auto* dk_mem = dk.get(); + if (!dk_mem) return coinbase::error(E_BADARG, "invalid decryption key"); + return base_pke_.decrypt(*dk_mem, label, ct, out_plain); + } + + private: + const base_pke_i& base_pke_; +}; + +// --------------------------------------------------------------------------- +// Built-in base PKE key blob format (internal implementation detail) +// --------------------------------------------------------------------------- + +constexpr uint32_t base_pke_key_blob_version_v1 = 1; + +enum class base_pke_key_type_v1 : uint32_t { + rsa_oaep_2048 = 1, + ecies_p256 = 2, +}; + +struct base_pke_ek_blob_v1_t { + uint32_t version = base_pke_key_blob_version_v1; + uint32_t key_type = static_cast(base_pke_key_type_v1::rsa_oaep_2048); + + coinbase::crypto::rsa_pub_key_t rsa_ek; + coinbase::crypto::ecc_pub_key_t ecies_ek; + + void convert(coinbase::converter_t& c) { + c.convert(version, key_type); + switch (static_cast(key_type)) { + case base_pke_key_type_v1::rsa_oaep_2048: + c.convert(rsa_ek); + return; + case base_pke_key_type_v1::ecies_p256: + c.convert(ecies_ek); + return; + default: + c.set_error(); + return; + } + } +}; + +struct base_pke_dk_blob_v1_t { + uint32_t version = base_pke_key_blob_version_v1; + uint32_t key_type = static_cast(base_pke_key_type_v1::rsa_oaep_2048); + + coinbase::crypto::rsa_prv_key_t rsa_dk; + coinbase::crypto::ecc_prv_key_t ecies_dk; + + void convert(coinbase::converter_t& c) { + c.convert(version, key_type); + switch (static_cast(key_type)) { + case base_pke_key_type_v1::rsa_oaep_2048: + c.convert(rsa_dk); + return; + case base_pke_key_type_v1::ecies_p256: + c.convert(ecies_dk); + return; + default: + c.set_error(); + return; + } + } +}; + +inline error_t parse_ek_blob(mem_t ek, base_pke_ek_blob_v1_t& out) { + error_t rv = coinbase::convert(out, ek); + if (rv) return rv; + if (out.version != base_pke_key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported base PKE key version"); + return SUCCESS; +} + +inline error_t parse_dk_blob(mem_t dk, base_pke_dk_blob_v1_t& out) { + error_t rv = coinbase::convert(out, dk); + if (rv) return rv; + if (out.version != base_pke_key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported base PKE key version"); + return SUCCESS; +} + +// --------------------------------------------------------------------------- +// Built-in HSM base PKE adapters (KEM decapsulation callback only) +// --------------------------------------------------------------------------- + +class rsa_oaep_hsm_base_pke_t final : public base_pke_i { + public: + explicit rsa_oaep_hsm_base_pke_t(const rsa_oaep_hsm_decap_cb_t& cb) : cb_(cb) {} + + error_t encrypt(mem_t ek, mem_t label, mem_t plain, mem_t rho, buf_t& out_ct) const override { + base_pke_ek_blob_v1_t blob; + error_t rv = parse_ek_blob(ek, blob); + if (rv) return rv; + if (static_cast(blob.key_type) != base_pke_key_type_v1::rsa_oaep_2048) + return coinbase::error(E_BADARG, "RSA-OAEP HSM decrypt requires an RSA base PKE public key"); + return coinbase::mpc::pve_base_pke_rsa().encrypt(coinbase::mpc::pve_keyref(blob.rsa_ek), label, plain, rho, out_ct); + } + + error_t decrypt(mem_t dk_handle, mem_t label, mem_t ct, buf_t& out_plain) const override { + if (!cb_.decap) return coinbase::error(E_BADARG, "missing HSM RSA decap callback"); + coinbase::mpc::pve_rsa_oaep_hsm_dk_t dk; + dk.dk_handle = dk_handle; + dk.ctx = cb_.ctx; + dk.decap = cb_.decap; + return coinbase::mpc::pve_base_pke_rsa_oaep_hsm().decrypt(coinbase::mpc::pve_keyref(dk), label, ct, out_plain); + } + + private: + const rsa_oaep_hsm_decap_cb_t& cb_; +}; + +class ecies_p256_hsm_base_pke_t final : public base_pke_i { + public: + ecies_p256_hsm_base_pke_t(buf_t pub_key_oct, const ecies_p256_hsm_ecdh_cb_t& cb) + : pub_key_oct_(std::move(pub_key_oct)), cb_(cb) {} + + error_t encrypt(mem_t ek, mem_t label, mem_t plain, mem_t rho, buf_t& out_ct) const override { + base_pke_ek_blob_v1_t blob; + error_t rv = parse_ek_blob(ek, blob); + if (rv) return rv; + if (static_cast(blob.key_type) != base_pke_key_type_v1::ecies_p256) + return coinbase::error(E_BADARG, "ECIES(P-256) HSM decrypt requires an ECIES base PKE public key"); + if (blob.ecies_ek.get_curve() != coinbase::crypto::curve_p256) + return coinbase::error(E_BADARG, "ECIES base PKE key must be on P-256"); + return coinbase::mpc::pve_base_pke_ecies().encrypt(coinbase::mpc::pve_keyref(blob.ecies_ek), label, plain, rho, + out_ct); + } + + error_t decrypt(mem_t dk_handle, mem_t label, mem_t ct, buf_t& out_plain) const override { + if (!cb_.ecdh) return coinbase::error(E_BADARG, "missing HSM ECIES ECDH callback"); + coinbase::mpc::pve_ecies_p256_hsm_dk_t dk; + dk.dk_handle = dk_handle; + dk.ctx = cb_.ctx; + dk.ecdh = cb_.ecdh; + dk.pub_key_oct = pub_key_oct_; + return coinbase::mpc::pve_base_pke_ecies_p256_hsm().decrypt(coinbase::mpc::pve_keyref(dk), label, ct, out_plain); + } + + private: + buf_t pub_key_oct_; + const ecies_p256_hsm_ecdh_cb_t& cb_; +}; + +} // namespace coinbase::api::pve::detail diff --git a/src/cbmpc/api/schnorr2pc.cpp b/src/cbmpc/api/schnorr2pc.cpp new file mode 100644 index 00000000..eec3262a --- /dev/null +++ b/src/cbmpc/api/schnorr2pc.cpp @@ -0,0 +1,234 @@ +#include +#include +#include + +#include "job_util.h" +#include "mem_util.h" + +namespace coinbase::api::schnorr_2p { + +namespace { + +constexpr uint32_t key_blob_version_v1 = 1; + +using coinbase::api::detail::to_internal_job; +using coinbase::api::detail::to_internal_party; +using coinbase::api::detail::validate_job_2p; + +struct key_blob_v1_t { + uint32_t version = key_blob_version_v1; + uint32_t role = 0; // 0=p1, 1=p2 + uint32_t curve = 0; // coinbase::api::curve_id + + buf_t Q_compressed; + coinbase::crypto::bn_t x_share; + + void convert(coinbase::converter_t& c) { c.convert(version, role, curve, Q_compressed, x_share); } +}; + +static error_t blob_to_key(const key_blob_v1_t& blob, coinbase::mpc::schnorr2p::key_t& key) { + if (blob.role > 1) return coinbase::error(E_FORMAT, "invalid key blob role"); + if (static_cast(blob.curve) != curve_id::secp256k1) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + + key.role = static_cast(static_cast(blob.role)); + key.curve = coinbase::crypto::curve_secp256k1; + const auto& q = key.curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + key.x_share = blob.x_share; + + return key.Q.from_bin(key.curve, blob.Q_compressed); +} + +static error_t serialize_key_blob(const coinbase::mpc::schnorr2p::key_t& key, buf_t& out) { + if (key.curve != coinbase::crypto::curve_secp256k1) return coinbase::error(E_BADARG, "unsupported curve"); + + key_blob_v1_t blob; + blob.role = static_cast(key.role); + blob.curve = static_cast(curve_id::secp256k1); + blob.Q_compressed = key.Q.to_compressed_bin(); + blob.x_share = key.x_share; + out = coinbase::convert(blob); + return SUCCESS; +} + +static error_t deserialize_key_blob(mem_t in, coinbase::mpc::schnorr2p::key_t& key) { + key_blob_v1_t blob; + const error_t rv = coinbase::convert(blob, in); + if (rv) return rv; + if (blob.version != key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported key blob version"); + return blob_to_key(blob, key); +} + +} // namespace + +error_t dkg(const coinbase::api::job_2p_t& job, curve_id curve, buf_t& key_blob) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (curve != curve_id::secp256k1) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + coinbase::mpc::schnorr2p::key_t key; + buf_t sid; // unused by this API + const error_t rv = coinbase::mpc::eckey::key_share_2p_t::dkg(mpc_job, coinbase::crypto::curve_secp256k1, key, sid); + if (rv) return rv; + + return serialize_key_blob(key, key_blob); +} + +error_t refresh(const coinbase::api::job_2p_t& job, mem_t key_blob, buf_t& new_key_blob) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::schnorr2p::key_t key; + error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + + const auto self = to_internal_party(job.self); + if (key.role != self) return coinbase::error(E_BADARG, "job.self mismatch key blob role"); + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + coinbase::mpc::schnorr2p::key_t new_key; + new_key_blob.free(); + rv = coinbase::mpc::eckey::key_share_2p_t::refresh(mpc_job, key, new_key); + if (rv) return rv; + + return serialize_key_blob(new_key, new_key_blob); +} + +error_t sign(const coinbase::api::job_2p_t& job, mem_t key_blob, mem_t msg, buf_t& sig) { + if (const error_t rv = validate_job_2p(job)) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(msg, "msg")) return rv; + if (msg.size != 32) return coinbase::error(E_BADARG, "BIP340 requires a 32-byte message"); + + coinbase::mpc::schnorr2p::key_t key; + error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + + const auto self = to_internal_party(job.self); + if (key.role != self) return coinbase::error(E_BADARG, "job.self mismatch key blob role"); + + coinbase::mpc::job_2p_t mpc_job = to_internal_job(job); + + sig.free(); + return coinbase::mpc::schnorr2p::sign(mpc_job, key, msg, sig, coinbase::mpc::schnorr2p::variant_e::BIP340); +} + +error_t get_public_key_compressed(mem_t key_blob, buf_t& pub_key_compressed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::schnorr2p::key_t key; + const error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + pub_key_compressed = key.Q.to_compressed_bin(); + return SUCCESS; +} + +error_t extract_public_key_xonly(mem_t key_blob, buf_t& pub_key_xonly) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::schnorr2p::key_t key; + const error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + pub_key_xonly = key.Q.get_x().to_bin(/*size=*/32); + return SUCCESS; +} + +error_t get_public_share_compressed(mem_t key_blob, buf_t& out_public_share_compressed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::schnorr2p::key_t key; + error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + + const auto curve = coinbase::crypto::curve_secp256k1; + const coinbase::crypto::mod_t& q = curve.order(); + const auto& G = curve.generator(); + const coinbase::crypto::bn_t x = key.x_share % q; + out_public_share_compressed = (x * G).to_compressed_bin(); + return SUCCESS; +} + +error_t detach_private_scalar(mem_t key_blob, buf_t& out_public_key_blob, buf_t& out_private_scalar_fixed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::mpc::schnorr2p::key_t key; + const error_t rv = deserialize_key_blob(key_blob, key); + if (rv) return rv; + + const auto curve = coinbase::crypto::curve_secp256k1; + const coinbase::crypto::mod_t& q = curve.order(); + const int order_size = q.get_bin_size(); + if (order_size <= 0) return coinbase::error(E_GENERAL, "invalid curve order size"); + + out_private_scalar_fixed = key.x_share.to_bin(order_size); + + // Produce a v1-format blob with an invalid (out-of-range) scalar share so it is + // rejected by sign/refresh APIs. + key_blob_v1_t pub; + pub.role = static_cast(key.role); + pub.curve = static_cast(curve_id::secp256k1); + pub.Q_compressed = key.Q.to_compressed_bin(); + pub.x_share = q; // x_share == q is out of range + out_public_key_blob = coinbase::convert(pub); + return SUCCESS; +} + +error_t attach_private_scalar(mem_t public_key_blob, mem_t private_scalar_fixed, mem_t public_share_compressed, + buf_t& out_key_blob) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(public_key_blob, "public_key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + key_blob_v1_t pub; + error_t rv = coinbase::convert(pub, public_key_blob); + if (rv) return rv; + if (pub.version != key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (pub.role > 1) return coinbase::error(E_FORMAT, "invalid key blob role"); + if (static_cast(pub.curve) != curve_id::secp256k1) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + if (pub.Q_compressed.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + + const auto curve = coinbase::crypto::curve_secp256k1; + const coinbase::crypto::mod_t& q = curve.order(); + const int order_size = q.get_bin_size(); + if (order_size <= 0) return coinbase::error(E_GENERAL, "invalid curve order size"); + + if (const error_t rvm = coinbase::api::detail::validate_mem_arg(private_scalar_fixed, "private_scalar_fixed")) + return rvm; + if (private_scalar_fixed.size != order_size) return coinbase::error(E_BADARG, "private_scalar_fixed wrong size"); + + if (const error_t rvp = coinbase::api::detail::validate_mem_arg(public_share_compressed, "public_share_compressed")) + return rvp; + + coinbase::crypto::ecc_point_t Qi_self(curve); + if (rv = Qi_self.from_bin(curve, public_share_compressed)) + return coinbase::error(rv, "invalid public_share_compressed"); + if (rv = curve.check(Qi_self)) return coinbase::error(rv, "invalid public_share_compressed"); + + const coinbase::crypto::bn_t x = coinbase::crypto::bn_t::from_bin(private_scalar_fixed) % q; + if (!q.is_in_range(x)) return coinbase::error(E_FORMAT, "invalid private_scalar_fixed"); + + const auto& G = curve.generator(); + if (x * G != Qi_self) return coinbase::error(E_FORMAT, "x_share mismatch key blob"); + + // Validate and normalize global public key encoding. + coinbase::crypto::ecc_point_t Q(curve); + if (rv = Q.from_bin(curve, pub.Q_compressed)) return coinbase::error(rv, "invalid key blob"); + if (rv = curve.check(Q)) return coinbase::error(rv, "invalid key blob"); + + pub.x_share = x; + pub.Q_compressed = Q.to_compressed_bin(); + out_key_blob = coinbase::convert(pub); + return SUCCESS; +} + +} // namespace coinbase::api::schnorr_2p diff --git a/src/cbmpc/api/schnorr_mp.cpp b/src/cbmpc/api/schnorr_mp.cpp new file mode 100644 index 00000000..37daa4ab --- /dev/null +++ b/src/cbmpc/api/schnorr_mp.cpp @@ -0,0 +1,544 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include "access_structure_util.h" +#include "job_util.h" +#include "mem_util.h" + +namespace coinbase::api::schnorr_mp { + +namespace { + +constexpr uint32_t key_blob_version_v1 = 1; +constexpr uint32_t ac_key_blob_version_v1 = 2; + +using coinbase::api::detail::to_internal_job; +using coinbase::api::detail::validate_job_mp; + +struct key_blob_v1_t { + uint32_t version = key_blob_version_v1; + uint32_t curve = 0; // coinbase::api::curve_id + + std::string party_name; // self identity (name-bound, not index-bound) + + buf_t Q_compressed; + std::map Qis_compressed; // name -> compressed Qi + + coinbase::crypto::bn_t x_share; + + void convert(coinbase::converter_t& c) { + c.convert(version, curve, party_name, Q_compressed, Qis_compressed, x_share); + } +}; + +static error_t extract_Q_from_key_blob(mem_t in, coinbase::crypto::ecc_point_t& Q) { + key_blob_v1_t blob; + error_t rv = coinbase::convert(blob, in); + if (rv) return rv; + if (blob.version != key_blob_version_v1 && blob.version != ac_key_blob_version_v1) + return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (static_cast(blob.curve) != curve_id::secp256k1) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + if (blob.Q_compressed.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + return Q.from_bin(coinbase::crypto::curve_secp256k1, blob.Q_compressed); +} + +static error_t serialize_key_blob_for_party_names(const std::vector& party_names, + const std::string& self_name, + const coinbase::mpc::schnorrmp::key_t& key, uint32_t version, + buf_t& out) { + if (key.curve != coinbase::crypto::curve_secp256k1) return coinbase::error(E_BADARG, "unsupported curve"); + + const std::string_view self_sv(self_name); + bool self_in_party_names = false; + for (const auto& name_view : party_names) { + if (name_view == self_sv) { + self_in_party_names = true; + break; + } + } + if (!self_in_party_names) return coinbase::error(E_BADARG, "self_name not in party_names"); + if (key.party_name != self_name) return coinbase::error(E_BADARG, "job.self mismatch key"); + + key_blob_v1_t blob; + blob.version = version; + blob.curve = static_cast(curve_id::secp256k1); + blob.party_name = key.party_name; + blob.Q_compressed = key.Q.to_compressed_bin(); + blob.x_share = key.x_share; + + for (const auto& name_view : party_names) { + const std::string name(name_view); + const auto it = key.Qis.find(name); + if (it == key.Qis.end()) return coinbase::error(E_FORMAT, "key missing Qi"); + blob.Qis_compressed[name] = it->second.to_compressed_bin(); + } + + out = coinbase::convert(blob); + return SUCCESS; +} + +static error_t serialize_key_blob(const coinbase::api::job_mp_t& job, const coinbase::mpc::schnorrmp::key_t& key, + buf_t& out) { + if (job.self < 0 || static_cast(job.self) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid job.self"); + + const std::string self_name(job.party_names[static_cast(job.self)]); + return serialize_key_blob_for_party_names(job.party_names, self_name, key, key_blob_version_v1, out); +} + +static error_t serialize_ac_key_blob(const coinbase::api::job_mp_t& job, const coinbase::mpc::schnorrmp::key_t& key, + buf_t& out) { + if (job.self < 0 || static_cast(job.self) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid job.self"); + + const std::string self_name(job.party_names[static_cast(job.self)]); + return serialize_key_blob_for_party_names(job.party_names, self_name, key, ac_key_blob_version_v1, out); +} + +static error_t deserialize_key_blob(const coinbase::api::job_mp_t& job, mem_t in, + coinbase::mpc::schnorrmp::key_t& key) { + error_t rv = UNINITIALIZED_ERROR; + + if (job.self < 0 || static_cast(job.self) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid job.self"); + const std::string self_name(job.party_names[static_cast(job.self)]); + + key_blob_v1_t blob; + if (rv = coinbase::convert(blob, in)) return rv; + if (blob.version != key_blob_version_v1) + return coinbase::error(E_FORMAT, "unsupported key blob version: " + std::to_string(blob.version)); + if (static_cast(blob.curve) != curve_id::secp256k1) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + if (blob.party_name.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.party_name != self_name) return coinbase::error(E_BADARG, "job.self mismatch key blob"); + if (blob.Qis_compressed.size() != job.party_names.size()) return coinbase::error(E_BADARG, "invalid key blob"); + + // Ensure the party name set matches the job (order can differ). + for (const auto& name_view : job.party_names) { + const std::string name(name_view); + if (blob.Qis_compressed.find(name) == blob.Qis_compressed.end()) + return coinbase::error(E_BADARG, "job.party_names mismatch key blob"); + } + + const auto curve = coinbase::crypto::curve_secp256k1; + const coinbase::crypto::mod_t& q = curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + + coinbase::crypto::ecc_point_t Q; + if (rv = Q.from_bin(curve, blob.Q_compressed)) return coinbase::error(rv, "invalid key blob"); + + coinbase::crypto::ss::party_map_t Qis; + for (const auto& name_view : job.party_names) { + const std::string name(name_view); + const auto it = blob.Qis_compressed.find(name); + if (it == blob.Qis_compressed.end()) return coinbase::error(E_BADARG, "job.party_names mismatch key blob"); + + coinbase::crypto::ecc_point_t Qi; + if (rv = Qi.from_bin(curve, it->second)) return coinbase::error(rv, "invalid key blob"); + Qis[name] = std::move(Qi); + } + + coinbase::crypto::ecc_point_t Q_sum = curve.infinity(); + for (const auto& kv : Qis) Q_sum += kv.second; + if (Q != Q_sum) return coinbase::error(E_FORMAT, "invalid key blob"); + + const auto& G = curve.generator(); + const auto it_self = Qis.find(blob.party_name); + if (it_self == Qis.end()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.x_share * G != it_self->second) return coinbase::error(E_FORMAT, "invalid key blob"); + + key.party_name = blob.party_name; + key.curve = curve; + key.x_share = blob.x_share; + key.Qis = std::move(Qis); + key.Q = std::move(Q); + return SUCCESS; +} + +static error_t deserialize_ac_key_blob(const coinbase::api::job_mp_t& job, mem_t in, + coinbase::mpc::schnorrmp::key_t& key) { + error_t rv = UNINITIALIZED_ERROR; + + if (job.self < 0 || static_cast(job.self) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid job.self"); + const std::string self_name(job.party_names[static_cast(job.self)]); + + key_blob_v1_t blob; + if (rv = coinbase::convert(blob, in)) return rv; + if (blob.version != ac_key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (static_cast(blob.curve) != curve_id::secp256k1) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + if (blob.party_name.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.party_name != self_name) return coinbase::error(E_BADARG, "job.self mismatch key blob"); + if (blob.Qis_compressed.size() != job.party_names.size()) return coinbase::error(E_BADARG, "invalid key blob"); + + // Ensure the party name set matches the job (order can differ). + for (const auto& name_view : job.party_names) { + const std::string name(name_view); + if (blob.Qis_compressed.find(name) == blob.Qis_compressed.end()) + return coinbase::error(E_BADARG, "job.party_names mismatch key blob"); + } + + const auto curve = coinbase::crypto::curve_secp256k1; + const coinbase::crypto::mod_t& q = curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + + coinbase::crypto::ecc_point_t Q; + if (rv = Q.from_bin(curve, blob.Q_compressed)) return coinbase::error(rv, "invalid key blob"); + + coinbase::crypto::ss::party_map_t Qis; + for (const auto& name_view : job.party_names) { + const std::string name(name_view); + const auto it = blob.Qis_compressed.find(name); + if (it == blob.Qis_compressed.end()) return coinbase::error(E_BADARG, "job.party_names mismatch key blob"); + + coinbase::crypto::ecc_point_t Qi; + if (rv = Qi.from_bin(curve, it->second)) return coinbase::error(rv, "invalid key blob"); + Qis[name] = std::move(Qi); + } + + // Access-structure key blobs are validated using the access structure at use sites. + // Here we only enforce the self-share binding. + const auto& G = curve.generator(); + const auto it_self = Qis.find(blob.party_name); + if (it_self == Qis.end()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.x_share * G != it_self->second) return coinbase::error(E_FORMAT, "invalid key blob"); + + key.party_name = blob.party_name; + key.curve = curve; + key.x_share = blob.x_share; + key.Qis = std::move(Qis); + key.Q = std::move(Q); + return SUCCESS; +} + +static error_t deserialize_ac_key_blob(mem_t in, coinbase::mpc::schnorrmp::key_t& key) { + error_t rv = UNINITIALIZED_ERROR; + + key_blob_v1_t blob; + if (rv = coinbase::convert(blob, in)) return rv; + if (blob.version != ac_key_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (static_cast(blob.curve) != curve_id::secp256k1) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + if (blob.party_name.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.Qis_compressed.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + + const auto curve = coinbase::crypto::curve_secp256k1; + const coinbase::crypto::mod_t& q = curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + + coinbase::crypto::ecc_point_t Q; + if (rv = Q.from_bin(curve, blob.Q_compressed)) return coinbase::error(rv, "invalid key blob"); + + coinbase::crypto::ss::party_map_t Qis; + for (const auto& kv : blob.Qis_compressed) { + coinbase::crypto::ecc_point_t Qi; + if (rv = Qi.from_bin(curve, kv.second)) return coinbase::error(rv, "invalid key blob"); + Qis[kv.first] = std::move(Qi); + } + + const auto& G = curve.generator(); + const auto it_self = Qis.find(blob.party_name); + if (it_self == Qis.end()) return coinbase::error(E_FORMAT, "invalid key blob"); + if (blob.x_share * G != it_self->second) return coinbase::error(E_FORMAT, "invalid key blob"); + + key.party_name = blob.party_name; + key.curve = curve; + key.x_share = blob.x_share; + key.Qis = std::move(Qis); + key.Q = std::move(Q); + return SUCCESS; +} + +} // namespace + +error_t dkg_additive(const coinbase::api::job_mp_t& job, curve_id curve, buf_t& key_blob, buf_t& sid) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (curve != curve_id::secp256k1) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + coinbase::mpc::schnorrmp::key_t key; + sid.free(); + key_blob.free(); + rv = coinbase::mpc::schnorrmp::dkg(mpc_job, coinbase::crypto::curve_secp256k1, key, sid); + if (rv) return rv; + + return serialize_key_blob(job, key, key_blob); +} + +error_t dkg_ac(const coinbase::api::job_mp_t& job, curve_id curve, buf_t& sid, + const access_structure_t& access_structure, const std::vector& quorum_party_names, + buf_t& key_blob) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (curve != curve_id::secp256k1) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::crypto::ss::ac_owned_t ac; + rv = coinbase::api::detail::to_internal_access_structure(access_structure, job.party_names, + coinbase::crypto::curve_secp256k1, ac); + if (rv) return rv; + + coinbase::mpc::party_set_t quorum_party_set; + rv = coinbase::api::detail::to_internal_party_set(job.party_names, quorum_party_names, quorum_party_set); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + coinbase::mpc::schnorrmp::key_t key; + key_blob.free(); + + rv = coinbase::mpc::schnorrmp::dkg_ac(mpc_job, coinbase::crypto::curve_secp256k1, sid, ac, quorum_party_set, key); + if (rv) return rv; + + return serialize_ac_key_blob(job, key, key_blob); +} + +error_t refresh_additive(const coinbase::api::job_mp_t& job, buf_t& sid, mem_t key_blob, buf_t& new_key_blob) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + + coinbase::mpc::schnorrmp::key_t key; + rv = deserialize_key_blob(job, key_blob, key); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + coinbase::mpc::schnorrmp::key_t new_key; + new_key_blob.free(); + rv = coinbase::mpc::schnorrmp::refresh(mpc_job, sid, key, new_key); + if (rv) return rv; + + return serialize_key_blob(job, new_key, new_key_blob); +} + +error_t refresh_ac(const coinbase::api::job_mp_t& job, buf_t& sid, mem_t key_blob, + const access_structure_t& access_structure, const std::vector& quorum_party_names, + buf_t& new_key_blob) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + + coinbase::mpc::schnorrmp::key_t key; + rv = deserialize_ac_key_blob(job, key_blob, key); + if (rv) return rv; + + coinbase::crypto::ss::ac_owned_t ac; + rv = coinbase::api::detail::to_internal_access_structure(access_structure, job.party_names, key.curve, ac); + if (rv) return rv; + + coinbase::mpc::party_set_t quorum_party_set; + rv = coinbase::api::detail::to_internal_party_set(job.party_names, quorum_party_names, quorum_party_set); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + coinbase::mpc::schnorrmp::key_t new_key; + new_key_blob.free(); + rv = coinbase::mpc::schnorrmp::refresh_ac(mpc_job, key.curve, sid, ac, quorum_party_set, key, new_key); + if (rv) return rv; + + return serialize_ac_key_blob(job, new_key, new_key_blob); +} + +error_t sign_ac(const coinbase::api::job_mp_t& job, mem_t ac_key_blob, const access_structure_t& access_structure, + mem_t msg, party_idx_t sig_receiver, buf_t& sig) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (rv = coinbase::api::detail::validate_mem_arg_max_size(ac_key_blob, "ac_key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (rv = coinbase::api::detail::validate_mem_arg(msg, "msg")) return rv; + if (msg.size != 32) return coinbase::error(E_BADARG, "BIP340 requires a 32-byte message"); + if (sig_receiver < 0 || static_cast(sig_receiver) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid sig_receiver"); + + coinbase::mpc::schnorrmp::key_t ac_key; + rv = deserialize_ac_key_blob(ac_key_blob, ac_key); + if (rv) return rv; + + // Bind the key share to the local party identity in the job. + const std::string_view self_name_sv(job.party_names[static_cast(job.self)]); + if (ac_key.party_name != self_name_sv) return coinbase::error(E_BADARG, "job.self mismatch key blob"); + + // Full party set is the key's Qis key set. + std::vector all_party_names; + all_party_names.reserve(ac_key.Qis.size()); + for (const auto& kv : ac_key.Qis) all_party_names.emplace_back(kv.first); + + // Validate that the signing party set (`job.party_names`) is a subset of the key's party set. + coinbase::mpc::party_set_t _unused; + rv = coinbase::api::detail::to_internal_party_set(all_party_names, job.party_names, _unused); + if (rv) return rv; + + // Convert access structure to internal and validate it matches the key party set. + coinbase::crypto::ss::ac_owned_t ac; + rv = coinbase::api::detail::to_internal_access_structure(access_structure, all_party_names, ac_key.curve, ac); + if (rv) return rv; + + // Convert signing party list to internal set of names. + std::set quorum_names; + for (const auto& name : job.party_names) quorum_names.insert(std::string(name)); + + coinbase::mpc::schnorrmp::key_t additive_key; + rv = ac_key.to_additive_share(ac, quorum_names, additive_key); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + sig.free(); + return coinbase::mpc::schnorrmp::sign(mpc_job, additive_key, msg, sig_receiver, sig, + coinbase::mpc::schnorrmp::variant_e::BIP340); +} + +error_t sign_additive(const coinbase::api::job_mp_t& job, mem_t key_blob, mem_t msg, party_idx_t sig_receiver, + buf_t& sig) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + if (rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (rv = coinbase::api::detail::validate_mem_arg(msg, "msg")) return rv; + if (msg.size != 32) return coinbase::error(E_BADARG, "BIP340 requires a 32-byte message"); + if (sig_receiver < 0 || static_cast(sig_receiver) >= job.party_names.size()) + return coinbase::error(E_BADARG, "invalid sig_receiver"); + + coinbase::mpc::schnorrmp::key_t key; + rv = deserialize_key_blob(job, key_blob, key); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + + sig.free(); + return coinbase::mpc::schnorrmp::sign(mpc_job, key, msg, sig_receiver, sig, + coinbase::mpc::schnorrmp::variant_e::BIP340); +} + +error_t get_public_key_compressed(mem_t key_blob, buf_t& pub_key_compressed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::crypto::ecc_point_t Q(coinbase::crypto::curve_secp256k1); + const error_t rv = extract_Q_from_key_blob(key_blob, Q); + if (rv) return rv; + pub_key_compressed = Q.to_compressed_bin(); + return SUCCESS; +} + +error_t extract_public_key_xonly(mem_t key_blob, buf_t& pub_key_xonly) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + coinbase::crypto::ecc_point_t Q(coinbase::crypto::curve_secp256k1); + const error_t rv = extract_Q_from_key_blob(key_blob, Q); + if (rv) return rv; + pub_key_xonly = Q.get_x().to_bin(/*size=*/32); + return SUCCESS; +} + +error_t get_public_share_compressed(mem_t key_blob, buf_t& out_public_share_compressed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + key_blob_v1_t blob; + error_t rv = coinbase::convert(blob, key_blob); + if (rv) return rv; + if (blob.version != key_blob_version_v1 && blob.version != ac_key_blob_version_v1) + return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (static_cast(blob.curve) != curve_id::secp256k1) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + if (blob.party_name.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + + const auto it = blob.Qis_compressed.find(blob.party_name); + if (it == blob.Qis_compressed.end()) return coinbase::error(E_FORMAT, "key blob missing self Qi"); + out_public_share_compressed = it->second; + return SUCCESS; +} + +error_t detach_private_scalar(mem_t key_blob, buf_t& out_public_key_blob, buf_t& out_private_scalar_fixed) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(key_blob, "key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + key_blob_v1_t blob; + error_t rv = coinbase::convert(blob, key_blob); + if (rv) return rv; + if (blob.version != key_blob_version_v1 && blob.version != ac_key_blob_version_v1) + return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (static_cast(blob.curve) != curve_id::secp256k1) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + + const auto curve = coinbase::crypto::curve_secp256k1; + const coinbase::crypto::mod_t& q = curve.order(); + if (!q.is_in_range(blob.x_share)) return coinbase::error(E_FORMAT, "invalid key blob"); + const int order_size = q.get_bin_size(); + if (order_size <= 0) return coinbase::error(E_GENERAL, "invalid curve order size"); + + out_private_scalar_fixed = blob.x_share.to_bin(order_size); + + // Wipe private scalar share. + blob.x_share = 0; + out_public_key_blob = coinbase::convert(blob); + return SUCCESS; +} + +error_t attach_private_scalar(mem_t public_key_blob, mem_t private_scalar_fixed, mem_t public_share_compressed, + buf_t& out_key_blob) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(public_key_blob, "public_key_blob", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + key_blob_v1_t blob; + error_t rv = coinbase::convert(blob, public_key_blob); + if (rv) return rv; + if (blob.version != key_blob_version_v1 && blob.version != ac_key_blob_version_v1) + return coinbase::error(E_FORMAT, "unsupported key blob version"); + if (static_cast(blob.curve) != curve_id::secp256k1) + return coinbase::error(E_FORMAT, "invalid key blob curve"); + if (blob.party_name.empty()) return coinbase::error(E_FORMAT, "invalid key blob"); + + const auto curve = coinbase::crypto::curve_secp256k1; + const coinbase::crypto::mod_t& q = curve.order(); + const int order_size = q.get_bin_size(); + if (order_size <= 0) return coinbase::error(E_GENERAL, "invalid curve order size"); + + if (const error_t rvm = coinbase::api::detail::validate_mem_arg(private_scalar_fixed, "private_scalar_fixed")) + return rvm; + if (private_scalar_fixed.size != order_size) return coinbase::error(E_BADARG, "private_scalar_fixed wrong size"); + if (const error_t rvp = coinbase::api::detail::validate_mem_arg(public_share_compressed, "public_share_compressed")) + return rvp; + + const auto it = blob.Qis_compressed.find(blob.party_name); + if (it == blob.Qis_compressed.end()) return coinbase::error(E_FORMAT, "key blob missing self Qi"); + const buf_t& Qi_self_compressed = it->second; + + if (public_share_compressed != mem_t(Qi_self_compressed)) + return coinbase::error(E_BADARG, "public_share_compressed mismatch key blob"); + + coinbase::crypto::ecc_point_t Qi_self(curve); + if (rv = Qi_self.from_bin(curve, Qi_self_compressed)) return coinbase::error(rv, "invalid key blob"); + if (rv = curve.check(Qi_self)) return coinbase::error(rv, "invalid key blob"); + + coinbase::crypto::bn_t x = coinbase::crypto::bn_t::from_bin(private_scalar_fixed) % q; + if (!q.is_in_range(x)) return coinbase::error(E_FORMAT, "invalid private_scalar_fixed"); + + const auto& G = curve.generator(); + if (x * G != Qi_self) return coinbase::error(E_FORMAT, "x_share mismatch key blob"); + + blob.x_share = std::move(x); + out_key_blob = coinbase::convert(blob); + return SUCCESS; +} + +} // namespace coinbase::api::schnorr_mp diff --git a/src/cbmpc/api/tdh2.cpp b/src/cbmpc/api/tdh2.cpp new file mode 100644 index 00000000..c66b9ef6 --- /dev/null +++ b/src/cbmpc/api/tdh2.cpp @@ -0,0 +1,363 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include "access_structure_util.h" +#include "curve_util.h" +#include "job_util.h" +#include "mem_util.h" + +namespace coinbase::api::tdh2 { + +namespace { + +constexpr uint32_t private_share_blob_version_v1 = 1; + +using coinbase::api::detail::to_internal_curve; +using coinbase::api::detail::to_internal_job; +using coinbase::api::detail::validate_job_mp; + +static error_t validate_public_key(const coinbase::crypto::tdh2::public_key_t& pk) { + if (!pk.Q.valid() || !pk.Gamma.valid()) return coinbase::error(E_FORMAT, "invalid public key"); + const auto curve = pk.Q.get_curve(); + if (!curve.valid()) return coinbase::error(E_FORMAT, "invalid public key"); + if (pk.Gamma.get_curve() != curve) return coinbase::error(E_FORMAT, "invalid public key"); + + // Enforce subgroup / infinity checks (DoS resistance and misuse-hardening). + if (curve.check(pk.Q)) return coinbase::error(E_FORMAT, "invalid public key"); + if (curve.check(pk.Gamma)) return coinbase::error(E_FORMAT, "invalid public key"); + + // `Gamma` is deterministically derived from `(Q, sid)` and must match. + const auto expected_gamma = coinbase::crypto::ro::hash_curve(mem_t("TDH2-Gamma"), pk.Q, pk.sid).curve(curve); + if (pk.Gamma != expected_gamma) return coinbase::error(E_FORMAT, "invalid public key"); + return SUCCESS; +} + +struct private_share_blob_v1_t { + uint32_t version = private_share_blob_version_v1; + uint32_t curve = 0; // coinbase::api::curve_id + + // Role/index id: 1..n, aligned with `job.party_names` order. + int rid = 0; + coinbase::crypto::bn_t x; + coinbase::crypto::tdh2::public_key_t pub_key; + + void convert(coinbase::converter_t& c) { c.convert(version, curve, rid, x, pub_key); } +}; + +error_t deserialize_private_share(mem_t in, coinbase::crypto::tdh2::private_share_t& out) { + private_share_blob_v1_t blob; + error_t rv = coinbase::convert(blob, in); + if (rv) return rv; + if (blob.version != private_share_blob_version_v1) return coinbase::error(E_FORMAT, "unsupported private share blob"); + + const auto cid = static_cast(blob.curve); + if (cid == curve_id::ed25519) return coinbase::error(E_FORMAT, "invalid private share curve"); + const auto curve = to_internal_curve(cid); + if (!curve.valid()) return coinbase::error(E_FORMAT, "invalid private share curve"); + + if (blob.rid <= 0) return coinbase::error(E_FORMAT, "invalid private share blob"); + if (curve.order().is_in_range(blob.x) == false) return coinbase::error(E_FORMAT, "invalid private share blob"); + + if (const error_t pk_rv = validate_public_key(blob.pub_key)) return pk_rv; + if (blob.pub_key.Q.get_curve() != curve) return coinbase::error(E_FORMAT, "invalid private share blob"); + + out.rid = blob.rid; + out.x = blob.x; + out.pub_key = blob.pub_key; + return SUCCESS; +} + +static error_t serialize_public_outputs(const coinbase::api::job_mp_t& job, + const coinbase::mpc::eckey::key_share_mp_t& key, const buf_t& sid, + buf_t& public_key, std::vector& public_shares) { + const coinbase::crypto::tdh2::public_key_t pk(key.Q, sid); + public_key = pk.to_bin(); + + public_shares.clear(); + public_shares.reserve(job.party_names.size()); + for (const auto& name_view : job.party_names) { + const std::string name(name_view); + const auto it = key.Qis.find(name); + if (it == key.Qis.end()) return coinbase::error(E_FORMAT, "DKG output missing public share"); + public_shares.emplace_back(it->second.to_compressed_bin()); + } + return SUCCESS; +} + +static error_t serialize_private_share(curve_id curve, coinbase::api::party_idx_t self, + const coinbase::crypto::bn_t& x_share, + const coinbase::crypto::tdh2::public_key_t& pk, buf_t& private_share) { + private_share_blob_v1_t blob; + blob.curve = static_cast(curve); + blob.rid = static_cast(self) + 1; // 1..n (aligned with job.party_names order) + blob.x = x_share; + blob.pub_key = pk; + private_share = coinbase::convert(blob); + return SUCCESS; +} + +} // namespace + +error_t dkg_additive(const coinbase::api::job_mp_t& job, curve_id curve, buf_t& public_key, + std::vector& public_shares, buf_t& private_share, buf_t& sid) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + + if (curve == curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + coinbase::mpc::eckey::key_share_mp_t key; + + public_key.free(); + private_share.free(); + sid.free(); + public_shares.clear(); + + rv = coinbase::mpc::eckey::key_share_mp_t::dkg(mpc_job, icurve, key, sid); + if (rv) return rv; + + rv = serialize_public_outputs(job, key, sid, public_key, public_shares); + if (rv) return rv; + + // Deserialize the public key blob so we can embed the exact serialized form in the private share. + coinbase::crypto::tdh2::public_key_t pk; + rv = pk.from_bin(public_key); + if (rv) return rv; + + return serialize_private_share(curve, job.self, key.x_share, pk, private_share); +} + +error_t dkg_ac(const coinbase::api::job_mp_t& job, curve_id curve, buf_t& sid, + const access_structure_t& access_structure, const std::vector& quorum_party_names, + buf_t& public_key, std::vector& public_shares, buf_t& private_share) { + error_t rv = validate_job_mp(job); + if (rv) return rv; + + if (curve == curve_id::ed25519) return coinbase::error(E_BADARG, "unsupported curve"); + const auto icurve = to_internal_curve(curve); + if (!icurve.valid()) return coinbase::error(E_BADARG, "unsupported curve"); + + coinbase::crypto::ss::ac_owned_t ac; + rv = coinbase::api::detail::to_internal_access_structure(access_structure, job.party_names, icurve, ac); + if (rv) return rv; + + coinbase::mpc::party_set_t quorum_party_set; + rv = coinbase::api::detail::to_internal_party_set(job.party_names, quorum_party_names, quorum_party_set); + if (rv) return rv; + + coinbase::mpc::job_mp_t mpc_job = to_internal_job(job); + coinbase::mpc::eckey::key_share_mp_t key; + + public_key.free(); + private_share.free(); + public_shares.clear(); + + rv = coinbase::mpc::eckey::key_share_mp_t::dkg_ac(mpc_job, icurve, sid, ac, quorum_party_set, key); + if (rv) return rv; + + rv = serialize_public_outputs(job, key, sid, public_key, public_shares); + if (rv) return rv; + + coinbase::crypto::tdh2::public_key_t pk; + rv = pk.from_bin(public_key); + if (rv) return rv; + + return serialize_private_share(curve, job.self, key.x_share, pk, private_share); +} + +error_t encrypt(mem_t public_key, mem_t plaintext, mem_t label, buf_t& ciphertext) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(public_key, "public_key", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(plaintext, "plaintext")) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + + coinbase::crypto::tdh2::public_key_t pk; + error_t rv = pk.from_bin(public_key); + if (rv) return rv; + if (rv = validate_public_key(pk)) return rv; + + const auto ct = pk.encrypt(plaintext, label); + ciphertext = coinbase::convert(ct); + return SUCCESS; +} + +error_t verify(mem_t public_key, mem_t ciphertext, mem_t label) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(public_key, "public_key", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + + coinbase::crypto::tdh2::public_key_t pk; + error_t rv = pk.from_bin(public_key); + if (rv) return rv; + if (rv = validate_public_key(pk)) return rv; + + coinbase::crypto::tdh2::ciphertext_t ct; + rv = coinbase::convert(ct, ciphertext); + if (rv) return rv; + return ct.verify(pk, label); +} + +error_t partial_decrypt(mem_t private_share, mem_t ciphertext, mem_t label, buf_t& partial_decryption) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(private_share, "private_share", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + + coinbase::crypto::tdh2::private_share_t share; + error_t rv = deserialize_private_share(private_share, share); + if (rv) return rv; + + coinbase::crypto::tdh2::ciphertext_t ct; + rv = coinbase::convert(ct, ciphertext); + if (rv) return rv; + + coinbase::crypto::tdh2::partial_decryption_t partial; + rv = share.decrypt(ct, label, partial); + if (rv) return rv; + + partial_decryption = coinbase::convert(partial); + return SUCCESS; +} + +error_t combine_additive(mem_t public_key, const std::vector& public_shares, mem_t label, + const std::vector& partial_decryptions, mem_t ciphertext, buf_t& plaintext) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(public_key, "public_key", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_vec_arg_max_size( + public_shares, "public_shares", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_vec_arg_max_size( + partial_decryptions, "partial_decryptions", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + + if (public_shares.size() != partial_decryptions.size()) + return coinbase::error(E_BADARG, "public_shares and partial_decryptions size mismatch"); + + coinbase::crypto::tdh2::public_key_t pk; + error_t rv = pk.from_bin(public_key); + if (rv) return rv; + if (rv = validate_public_key(pk)) return rv; + + coinbase::crypto::tdh2::ciphertext_t ct; + rv = coinbase::convert(ct, ciphertext); + if (rv) return rv; + + auto curve = pk.Q.get_curve(); + if (!curve.valid()) return coinbase::error(E_FORMAT, "public key missing curve"); + + coinbase::crypto::tdh2::pub_shares_t Qi; + Qi.reserve(public_shares.size()); + for (const auto& m : public_shares) { + coinbase::crypto::ecc_point_t P; + rv = P.from_bin(curve, m); + if (rv) return rv; + Qi.emplace_back(std::move(P)); + } + + coinbase::crypto::tdh2::partial_decryptions_t partials; + partials.resize(partial_decryptions.size()); + for (size_t i = 0; i < partial_decryptions.size(); i++) { + rv = coinbase::convert(partials[i], partial_decryptions[i]); + if (rv) return rv; + } + + return coinbase::crypto::tdh2::combine_additive(pk, Qi, label, partials, ct, plaintext); +} + +error_t combine_ac(const access_structure_t& access_structure, mem_t public_key, + const std::vector& party_names, const std::vector& public_shares, + mem_t label, const std::vector& partial_decryption_party_names, + const std::vector& partial_decryptions, mem_t ciphertext, buf_t& plaintext) { + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size(public_key, "public_key", + coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg(label, "label")) return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_arg_max_size( + ciphertext, "ciphertext", coinbase::api::detail::MAX_CIPHERTEXT_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_vec_arg_max_size( + public_shares, "public_shares", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + if (const error_t rv = coinbase::api::detail::validate_mem_vec_arg_max_size( + partial_decryptions, "partial_decryptions", coinbase::api::detail::MAX_OPAQUE_BLOB_SIZE)) + return rv; + + if (party_names.size() != public_shares.size()) + return coinbase::error(E_BADARG, "party_names and public_shares size mismatch"); + if (partial_decryption_party_names.size() != partial_decryptions.size()) + return coinbase::error(E_BADARG, "partial_decryption_party_names and partial_decryptions size mismatch"); + + coinbase::crypto::tdh2::public_key_t pk; + error_t rv = pk.from_bin(public_key); + if (rv) return rv; + if (rv = validate_public_key(pk)) return rv; + + const auto curve = pk.Q.get_curve(); + if (!curve.valid()) return coinbase::error(E_FORMAT, "public key missing curve"); + + // Convert access structure to internal representation and validate that leaf names match `party_names`. + coinbase::crypto::ss::ac_owned_t ac; + rv = coinbase::api::detail::to_internal_access_structure(access_structure, party_names, curve, ac); + if (rv) return rv; + + coinbase::crypto::tdh2::ciphertext_t ct; + rv = coinbase::convert(ct, ciphertext); + if (rv) return rv; + + coinbase::crypto::ss::ac_pub_shares_t pub_shares_map; + pub_shares_map.clear(); + std::set seen_party_names; + for (size_t i = 0; i < party_names.size(); i++) { + const auto name_view = party_names[i]; + if (name_view.empty()) return coinbase::error(E_BADARG, "party name must be non-empty"); + if (!seen_party_names.insert(name_view).second) return coinbase::error(E_BADARG, "duplicate party name"); + coinbase::crypto::ecc_point_t Qi; + rv = Qi.from_bin(curve, public_shares[i]); + if (rv) return rv; + pub_shares_map.emplace(std::string(name_view), std::move(Qi)); + } + + coinbase::crypto::ss::party_map_t partials_map; + partials_map.clear(); + std::set seen_partial_names; + for (size_t i = 0; i < partial_decryptions.size(); i++) { + const auto name_view = partial_decryption_party_names[i]; + if (name_view.empty()) return coinbase::error(E_BADARG, "partial decryption party name must be non-empty"); + if (!seen_partial_names.insert(name_view).second) + return coinbase::error(E_BADARG, "duplicate partial decryption party name"); + + const std::string name(name_view); + if (pub_shares_map.find(name) == pub_shares_map.end()) + return coinbase::error(E_BADARG, "partial decryption party name not in party_names"); + + coinbase::crypto::tdh2::partial_decryption_t partial; + rv = coinbase::convert(partial, partial_decryptions[i]); + if (rv) return rv; + partials_map[name] = std::move(partial); + } + + return coinbase::crypto::tdh2::combine(ac, pk, pub_shares_map, label, partials_map, ct, plaintext); +} + +} // namespace coinbase::api::tdh2 diff --git a/src/cbmpc/c_api/CMakeLists.txt b/src/cbmpc/c_api/CMakeLists.txt new file mode 100644 index 00000000..014bf149 --- /dev/null +++ b/src/cbmpc/c_api/CMakeLists.txt @@ -0,0 +1,18 @@ +add_library(cbmpc_c_api OBJECT "") + +target_sources(cbmpc_c_api PRIVATE + common.cpp + eddsa_mp.cpp + eddsa2pc.cpp + ecdsa_mp.cpp + ecdsa2pc.cpp + pve_base_pke.cpp + pve_batch_ac.cpp + pve_batch_single_recipient.cpp + schnorr_mp.cpp + schnorr2pc.cpp + tdh2.cpp +) + +target_link_libraries(cbmpc_c_api cbmpc_core cbmpc_api) + diff --git a/src/cbmpc/c_api/access_structure_adapter.h b/src/cbmpc/c_api/access_structure_adapter.h new file mode 100644 index 00000000..f8c9b50e --- /dev/null +++ b/src/cbmpc/c_api/access_structure_adapter.h @@ -0,0 +1,123 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +namespace coinbase::capi::detail { + +inline cbmpc_error_t to_cpp_quorum_party_names(const char* const* names, int names_count, + std::vector& out) { + out.clear(); + if (names_count < 0) return E_BADARG; + if (names_count == 0) return CBMPC_SUCCESS; + if (!names) return E_BADARG; + + out.reserve(static_cast(names_count)); + for (int i = 0; i < names_count; i++) { + const char* s = names[i]; + if (!s) return E_BADARG; + if (s[0] == '\0') return E_BADARG; + out.emplace_back(s); + } + return CBMPC_SUCCESS; +} + +inline cbmpc_error_t to_cpp_access_structure(const cbmpc_access_structure_t* in, + coinbase::api::access_structure_t& out) { + if (!in) return E_BADARG; + if (in->nodes_count < 0 || in->child_indices_count < 0) return E_BADARG; + if (in->nodes_count == 0) return E_BADARG; + if (!in->nodes) return E_BADARG; + if (in->child_indices_count > 0 && !in->child_indices) return E_BADARG; + if (in->root_index < 0 || in->root_index >= in->nodes_count) return E_BADARG; + + // state: 0 = unvisited, 1 = visiting, 2 = done + std::vector state(static_cast(in->nodes_count), 0); + + struct builder_t { + static cbmpc_error_t build(const cbmpc_access_structure_t* in, int32_t idx, bool is_root, + std::vector& state, coinbase::api::access_structure_t& out) { + if (idx < 0 || idx >= in->nodes_count) return E_BADARG; + + const auto uidx = static_cast(idx); + if (state[uidx] == 1) return E_BADARG; // cycle + if (state[uidx] == 2) return E_BADARG; // node reuse (DAG) + state[uidx] = 1; + + const cbmpc_access_structure_node_t& n = in->nodes[uidx]; + + if (n.child_indices_offset < 0 || n.child_indices_count < 0) return E_BADARG; + if (n.child_indices_count > 0) { + if (n.child_indices_offset > in->child_indices_count - n.child_indices_count) return E_BADARG; + } + + auto node = coinbase::api::access_structure_t{}; + + switch (n.type) { + case CBMPC_ACCESS_STRUCTURE_NODE_LEAF: { + if (is_root) return E_BADARG; + if (!n.leaf_name || n.leaf_name[0] == '\0') return E_BADARG; + if (n.threshold_k != 0) return E_BADARG; + if (n.child_indices_count != 0) return E_BADARG; + node.type = coinbase::api::access_structure_t::node_type::leaf; + node.leaf_name = std::string_view(n.leaf_name); + } break; + + case CBMPC_ACCESS_STRUCTURE_NODE_AND: + case CBMPC_ACCESS_STRUCTURE_NODE_OR: + case CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD: { + if (n.leaf_name) return E_BADARG; + if (n.child_indices_count <= 0) return E_BADARG; + + if (n.type == CBMPC_ACCESS_STRUCTURE_NODE_AND) { + if (n.threshold_k != 0) return E_BADARG; + node.type = coinbase::api::access_structure_t::node_type::and_node; + } else if (n.type == CBMPC_ACCESS_STRUCTURE_NODE_OR) { + if (n.threshold_k != 0) return E_BADARG; + node.type = coinbase::api::access_structure_t::node_type::or_node; + } else { + if (n.threshold_k < 1) return E_BADARG; + if (n.threshold_k > n.child_indices_count) return E_BADARG; + node.type = coinbase::api::access_structure_t::node_type::threshold; + node.threshold_k = n.threshold_k; + } + + node.children.reserve(static_cast(n.child_indices_count)); + const int32_t off = n.child_indices_offset; + for (int32_t i = 0; i < n.child_indices_count; i++) { + const int32_t child_idx = in->child_indices[static_cast(off + i)]; + coinbase::api::access_structure_t child; + const cbmpc_error_t rv = build(in, child_idx, /*is_root=*/false, state, child); + if (rv) return rv; + node.children.emplace_back(std::move(child)); + } + } break; + + default: + return E_BADARG; + } + + out = std::move(node); + state[uidx] = 2; + return CBMPC_SUCCESS; + } + }; + + const cbmpc_error_t rv = builder_t::build(in, in->root_index, /*is_root=*/true, state, out); + if (rv) return rv; + + // Reject unreachable nodes (must be a single rooted tree). + for (size_t i = 0; i < state.size(); i++) { + if (state[i] != 2) return E_BADARG; + } + + return CBMPC_SUCCESS; +} + +} // namespace coinbase::capi::detail diff --git a/src/cbmpc/c_api/common.cpp b/src/cbmpc/c_api/common.cpp new file mode 100644 index 00000000..8760529c --- /dev/null +++ b/src/cbmpc/c_api/common.cpp @@ -0,0 +1,63 @@ +#include +#include + +#include + +#if defined(__APPLE__) +#include +#elif defined(__linux__) && defined(__GLIBC__) +#include +#endif + +extern "C" { + +static void secure_bzero(void* ptr, size_t size) { + if (!ptr || size == 0) return; + OPENSSL_cleanse(ptr, size); +} + +static size_t malloc_usable_bytes(void* ptr) { + if (!ptr) return 0; +#if defined(__APPLE__) + return malloc_size(ptr); +#elif defined(__linux__) && defined(__GLIBC__) + return malloc_usable_size(ptr); +#else + return 0; +#endif +} + +// NOLINTNEXTLINE(cppcoreguidelines-no-malloc) +void* cbmpc_malloc(size_t size) { + if (size == 0) return nullptr; + return std::malloc(size); +} + +void cbmpc_free(void* ptr) { std::free(ptr); } + +void cbmpc_cmem_free(cmem_t mem) { + if (mem.data && mem.size > 0) secure_bzero(mem.data, static_cast(mem.size)); + cbmpc_free(mem.data); +} + +void cbmpc_cmems_free(cmems_t mems) { + if (mems.data) { + // Prefer wiping the full malloc allocation when possible. This avoids + // relying on `mems.sizes` being well-formed for zeroization. + size_t wipe_len = malloc_usable_bytes(mems.data); + if (wipe_len == 0 && mems.count > 0 && mems.sizes) { + // Best-effort fallback: sum non-negative segment sizes. + size_t total = 0; + for (int i = 0; i < mems.count; i++) { + const int sz = mems.sizes[i]; + if (sz > 0) total += static_cast(sz); + } + wipe_len = total; + } + secure_bzero(mems.data, wipe_len); + } + cbmpc_free(mems.data); + cbmpc_free(mems.sizes); +} + +} // extern "C" diff --git a/src/cbmpc/c_api/ecdsa2pc.cpp b/src/cbmpc/c_api/ecdsa2pc.cpp new file mode 100644 index 00000000..6f52523b --- /dev/null +++ b/src/cbmpc/c_api/ecdsa2pc.cpp @@ -0,0 +1,262 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "util.h" + +namespace { + +using namespace coinbase::capi::detail; + +using api_sign_fn_t = coinbase::error_t (*)(const coinbase::api::job_2p_t&, coinbase::mem_t /*key_blob*/, + coinbase::mem_t /*msg_hash*/, coinbase::buf_t& /*sid*/, + coinbase::buf_t& /*sig_der*/); + +static cbmpc_error_t sign_common(api_sign_fn_t fn, const cbmpc_2pc_job_t* job, cmem_t key_blob, cmem_t msg_hash, + cmem_t sid_in, cmem_t* sid_out, cmem_t* sig_der_out) { + try { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (!sig_der_out) return E_BADARG; + *sig_der_out = cmem_t{nullptr, 0}; + + const auto vjob = validate_2pc_job(job); + if (vjob) return vjob; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + const auto vmh = validate_cmem(msg_hash); + if (vmh) return vmh; + const auto vsi = validate_cmem(sid_in); + if (vsi) return vsi; + + coinbase::api::party_2p_t self_cpp; + const auto pconv = to_cpp_party(job->self, self_cpp); + if (pconv) return pconv; + + job_2p_cpp_ctx_t ctx(job, self_cpp); + coinbase::buf_t sid(sid_in.data, sid_in.size); + coinbase::buf_t sig; + + const coinbase::error_t rv = fn(ctx.job, view_cmem(key_blob), view_cmem(msg_hash), sid, sig); + if (rv) return rv; + + const auto r_sig = alloc_cmem_from_buf(sig, sig_der_out); + if (r_sig) return r_sig; + + if (sid_out) { + const auto r_sid = alloc_cmem_from_buf(sid, sid_out); + if (r_sid) { + cbmpc_cmem_free(*sig_der_out); + *sig_der_out = cmem_t{nullptr, 0}; + return r_sid; + } + } + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (sig_der_out) { + cbmpc_cmem_free(*sig_der_out); + *sig_der_out = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (sig_der_out) { + cbmpc_cmem_free(*sig_der_out); + *sig_der_out = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +} // namespace + +extern "C" { + +cbmpc_error_t cbmpc_ecdsa_2p_dkg(const cbmpc_2pc_job_t* job, cbmpc_curve_id_t curve, cmem_t* out_key_blob) { + try { + if (!out_key_blob) return E_BADARG; + *out_key_blob = cmem_t{nullptr, 0}; + const auto vjob = validate_2pc_job(job); + if (vjob) return vjob; + + coinbase::api::party_2p_t self_cpp; + const auto pconv = to_cpp_party(job->self, self_cpp); + if (pconv) return pconv; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + job_2p_cpp_ctx_t ctx(job, self_cpp); + coinbase::buf_t key_blob; + const coinbase::error_t rv = coinbase::api::ecdsa_2p::dkg(ctx.job, curve_cpp, key_blob); + if (rv) return rv; + + return alloc_cmem_from_buf(key_blob, out_key_blob); + } catch (const std::bad_alloc&) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_ecdsa_2p_refresh(const cbmpc_2pc_job_t* job, cmem_t key_blob, cmem_t* out_new_key_blob) { + try { + if (!out_new_key_blob) return E_BADARG; + *out_new_key_blob = cmem_t{nullptr, 0}; + const auto vjob = validate_2pc_job(job); + if (vjob) return vjob; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::api::party_2p_t self_cpp; + const auto pconv = to_cpp_party(job->self, self_cpp); + if (pconv) return pconv; + + job_2p_cpp_ctx_t ctx(job, self_cpp); + coinbase::buf_t new_key; + const coinbase::error_t rv = coinbase::api::ecdsa_2p::refresh(ctx.job, view_cmem(key_blob), new_key); + if (rv) return rv; + + return alloc_cmem_from_buf(new_key, out_new_key_blob); + } catch (const std::bad_alloc&) { + if (out_new_key_blob) *out_new_key_blob = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_new_key_blob) *out_new_key_blob = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_ecdsa_2p_sign(const cbmpc_2pc_job_t* job, cmem_t key_blob, cmem_t msg_hash, cmem_t sid_in, + cmem_t* sid_out, cmem_t* sig_der_out) { + return sign_common(&coinbase::api::ecdsa_2p::sign, job, key_blob, msg_hash, sid_in, sid_out, sig_der_out); +} + +cbmpc_error_t cbmpc_ecdsa_2p_get_public_key_compressed(cmem_t key_blob, cmem_t* out_pub_key) { + try { + if (!out_pub_key) return E_BADARG; + *out_pub_key = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t pk; + const coinbase::error_t rv = coinbase::api::ecdsa_2p::get_public_key_compressed(view_cmem(key_blob), pk); + if (rv) return rv; + + return alloc_cmem_from_buf(pk, out_pub_key); + } catch (const std::bad_alloc&) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_ecdsa_2p_get_public_share_compressed(cmem_t key_blob, cmem_t* out_public_share) { + try { + if (!out_public_share) return E_BADARG; + *out_public_share = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t Qi; + const coinbase::error_t rv = coinbase::api::ecdsa_2p::get_public_share_compressed(view_cmem(key_blob), Qi); + if (rv) return rv; + return alloc_cmem_from_buf(Qi, out_public_share); + } catch (const std::bad_alloc&) { + if (out_public_share) *out_public_share = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_public_share) *out_public_share = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_ecdsa_2p_detach_private_scalar(cmem_t key_blob, cmem_t* out_public_key_blob, + cmem_t* out_private_scalar) { + try { + if (!out_public_key_blob || !out_private_scalar) return E_BADARG; + *out_public_key_blob = cmem_t{nullptr, 0}; + *out_private_scalar = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t public_blob; + coinbase::buf_t private_scalar; + const coinbase::error_t rv = + coinbase::api::ecdsa_2p::detach_private_scalar(view_cmem(key_blob), public_blob, private_scalar); + if (rv) return rv; + + const auto r1 = alloc_cmem_from_buf(public_blob, out_public_key_blob); + if (r1) return r1; + const auto r2 = alloc_cmem_from_buf(private_scalar, out_private_scalar); + if (r2) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + return r2; + } + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_public_key_blob) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + } + if (out_private_scalar) { + cbmpc_cmem_free(*out_private_scalar); + *out_private_scalar = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (out_public_key_blob) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + } + if (out_private_scalar) { + cbmpc_cmem_free(*out_private_scalar); + *out_private_scalar = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_ecdsa_2p_attach_private_scalar(cmem_t public_key_blob, cmem_t private_scalar, + cmem_t public_share_compressed, cmem_t* out_key_blob) { + try { + if (!out_key_blob) return E_BADARG; + *out_key_blob = cmem_t{nullptr, 0}; + const auto vpb = validate_cmem(public_key_blob); + if (vpb) return vpb; + const auto vx = validate_cmem(private_scalar); + if (vx) return vx; + const auto vq = validate_cmem(public_share_compressed); + if (vq) return vq; + + coinbase::buf_t merged; + const coinbase::error_t rv = coinbase::api::ecdsa_2p::attach_private_scalar( + view_cmem(public_key_blob), view_cmem(private_scalar), view_cmem(public_share_compressed), merged); + if (rv) return rv; + return alloc_cmem_from_buf(merged, out_key_blob); + } catch (const std::bad_alloc&) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +} // extern "C" diff --git a/src/cbmpc/c_api/ecdsa_mp.cpp b/src/cbmpc/c_api/ecdsa_mp.cpp new file mode 100644 index 00000000..2d0ca784 --- /dev/null +++ b/src/cbmpc/c_api/ecdsa_mp.cpp @@ -0,0 +1,454 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "access_structure_adapter.h" +#include "util.h" + +using namespace coinbase::capi::detail; + +extern "C" { + +cbmpc_error_t cbmpc_ecdsa_mp_dkg_additive(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t* out_key_blob, + cmem_t* out_sid) { + try { + if (!out_key_blob || !out_sid) return E_BADARG; + *out_key_blob = cmem_t{nullptr, 0}; + *out_sid = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t key_blob; + coinbase::buf_t sid; + const coinbase::error_t rv = coinbase::api::ecdsa_mp::dkg_additive(ctx.job, curve_cpp, key_blob, sid); + if (rv) return rv; + + const auto r_key = alloc_cmem_from_buf(key_blob, out_key_blob); + if (r_key) return r_key; + + const auto r_sid = alloc_cmem_from_buf(sid, out_sid); + if (r_sid) { + cbmpc_cmem_free(*out_key_blob); + *out_key_blob = cmem_t{nullptr, 0}; + return r_sid; + } + + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_key_blob) { + cbmpc_cmem_free(*out_key_blob); + *out_key_blob = cmem_t{nullptr, 0}; + } + if (out_sid) { + cbmpc_cmem_free(*out_sid); + *out_sid = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (out_key_blob) { + cbmpc_cmem_free(*out_key_blob); + *out_key_blob = cmem_t{nullptr, 0}; + } + if (out_sid) { + cbmpc_cmem_free(*out_sid); + *out_sid = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_ecdsa_mp_dkg_ac(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t sid_in, + const cbmpc_access_structure_t* access_structure, + const char* const* quorum_party_names, int quorum_party_names_count, + cmem_t* out_ac_key_blob, cmem_t* out_sid) { + try { + if (!out_ac_key_blob || !out_sid) return E_BADARG; + *out_ac_key_blob = cmem_t{nullptr, 0}; + *out_sid = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + const auto vsi = validate_cmem(sid_in); + if (vsi) return vsi; + + std::vector quorum_names; + const auto vqn = to_cpp_quorum_party_names(quorum_party_names, quorum_party_names_count, quorum_names); + if (vqn) return vqn; + + coinbase::api::access_structure_t ac_cpp; + const auto vac = to_cpp_access_structure(access_structure, ac_cpp); + if (vac) return vac; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sid(sid_in.data, sid_in.size); + coinbase::buf_t key_blob; + const coinbase::error_t rv = + coinbase::api::ecdsa_mp::dkg_ac(ctx.job, curve_cpp, sid, ac_cpp, quorum_names, key_blob); + if (rv) return rv; + + const auto r_key = alloc_cmem_from_buf(key_blob, out_ac_key_blob); + if (r_key) return r_key; + + const auto r_sid = alloc_cmem_from_buf(sid, out_sid); + if (r_sid) { + cbmpc_cmem_free(*out_ac_key_blob); + *out_ac_key_blob = cmem_t{nullptr, 0}; + return r_sid; + } + + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_ac_key_blob) { + cbmpc_cmem_free(*out_ac_key_blob); + *out_ac_key_blob = cmem_t{nullptr, 0}; + } + if (out_sid) { + cbmpc_cmem_free(*out_sid); + *out_sid = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (out_ac_key_blob) { + cbmpc_cmem_free(*out_ac_key_blob); + *out_ac_key_blob = cmem_t{nullptr, 0}; + } + if (out_sid) { + cbmpc_cmem_free(*out_sid); + *out_sid = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_ecdsa_mp_refresh_additive(const cbmpc_mp_job_t* job, cmem_t sid_in, cmem_t key_blob, + cmem_t* sid_out, cmem_t* out_new_key_blob) { + try { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (!out_new_key_blob) return E_BADARG; + *out_new_key_blob = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + const auto vsi = validate_cmem(sid_in); + if (vsi) return vsi; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sid(sid_in.data, sid_in.size); + coinbase::buf_t new_key; + const coinbase::error_t rv = coinbase::api::ecdsa_mp::refresh_additive(ctx.job, sid, view_cmem(key_blob), new_key); + if (rv) return rv; + + const auto r_key = alloc_cmem_from_buf(new_key, out_new_key_blob); + if (r_key) return r_key; + + if (sid_out) { + const auto r_sid = alloc_cmem_from_buf(sid, sid_out); + if (r_sid) { + cbmpc_cmem_free(*out_new_key_blob); + *out_new_key_blob = cmem_t{nullptr, 0}; + return r_sid; + } + } + + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (out_new_key_blob) { + cbmpc_cmem_free(*out_new_key_blob); + *out_new_key_blob = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (out_new_key_blob) { + cbmpc_cmem_free(*out_new_key_blob); + *out_new_key_blob = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_ecdsa_mp_refresh_ac(const cbmpc_mp_job_t* job, cmem_t sid_in, cmem_t ac_key_blob, + const cbmpc_access_structure_t* access_structure, + const char* const* quorum_party_names, int quorum_party_names_count, + cmem_t* sid_out, cmem_t* out_new_ac_key_blob) { + try { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (!out_new_ac_key_blob) return E_BADARG; + *out_new_ac_key_blob = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + const auto vsi = validate_cmem(sid_in); + if (vsi) return vsi; + const auto vkb = validate_cmem(ac_key_blob); + if (vkb) return vkb; + + std::vector quorum_names; + const auto vqn = to_cpp_quorum_party_names(quorum_party_names, quorum_party_names_count, quorum_names); + if (vqn) return vqn; + + coinbase::api::access_structure_t ac_cpp; + const auto vac = to_cpp_access_structure(access_structure, ac_cpp); + if (vac) return vac; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sid(sid_in.data, sid_in.size); + coinbase::buf_t new_key; + const coinbase::error_t rv = + coinbase::api::ecdsa_mp::refresh_ac(ctx.job, sid, view_cmem(ac_key_blob), ac_cpp, quorum_names, new_key); + if (rv) return rv; + + const auto r_key = alloc_cmem_from_buf(new_key, out_new_ac_key_blob); + if (r_key) return r_key; + + if (sid_out) { + const auto r_sid = alloc_cmem_from_buf(sid, sid_out); + if (r_sid) { + cbmpc_cmem_free(*out_new_ac_key_blob); + *out_new_ac_key_blob = cmem_t{nullptr, 0}; + return r_sid; + } + } + + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (out_new_ac_key_blob) { + cbmpc_cmem_free(*out_new_ac_key_blob); + *out_new_ac_key_blob = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (out_new_ac_key_blob) { + cbmpc_cmem_free(*out_new_ac_key_blob); + *out_new_ac_key_blob = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_ecdsa_mp_sign_additive(const cbmpc_mp_job_t* job, cmem_t key_blob, cmem_t msg_hash, + int32_t sig_receiver, cmem_t* sig_der_out) { + try { + if (!sig_der_out) return E_BADARG; + *sig_der_out = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + const auto vmh = validate_cmem(msg_hash); + if (vmh) return vmh; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sig; + const coinbase::error_t rv = + coinbase::api::ecdsa_mp::sign_additive(ctx.job, view_cmem(key_blob), view_cmem(msg_hash), sig_receiver, sig); + if (rv) return rv; + + return alloc_cmem_from_buf(sig, sig_der_out); + } catch (const std::bad_alloc&) { + if (sig_der_out) { + cbmpc_cmem_free(*sig_der_out); + *sig_der_out = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (sig_der_out) { + cbmpc_cmem_free(*sig_der_out); + *sig_der_out = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_ecdsa_mp_sign_ac(const cbmpc_mp_job_t* job, cmem_t ac_key_blob, + const cbmpc_access_structure_t* access_structure, cmem_t msg_hash, + int32_t sig_receiver, cmem_t* sig_der_out) { + try { + if (!sig_der_out) return E_BADARG; + *sig_der_out = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + const auto vkb = validate_cmem(ac_key_blob); + if (vkb) return vkb; + const auto vmh = validate_cmem(msg_hash); + if (vmh) return vmh; + + coinbase::api::access_structure_t ac_cpp; + const auto vac = to_cpp_access_structure(access_structure, ac_cpp); + if (vac) return vac; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sig; + const coinbase::error_t rv = coinbase::api::ecdsa_mp::sign_ac(ctx.job, view_cmem(ac_key_blob), ac_cpp, + view_cmem(msg_hash), sig_receiver, sig); + if (rv) return rv; + + return alloc_cmem_from_buf(sig, sig_der_out); + } catch (const std::bad_alloc&) { + if (sig_der_out) { + cbmpc_cmem_free(*sig_der_out); + *sig_der_out = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (sig_der_out) { + cbmpc_cmem_free(*sig_der_out); + *sig_der_out = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_ecdsa_mp_get_public_key_compressed(cmem_t key_blob, cmem_t* out_pub_key) { + try { + if (!out_pub_key) return E_BADARG; + *out_pub_key = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t pk; + const coinbase::error_t rv = coinbase::api::ecdsa_mp::get_public_key_compressed(view_cmem(key_blob), pk); + if (rv) return rv; + + return alloc_cmem_from_buf(pk, out_pub_key); + } catch (const std::bad_alloc&) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_ecdsa_mp_get_public_share_compressed(cmem_t key_blob, cmem_t* out_public_share) { + try { + if (!out_public_share) return E_BADARG; + *out_public_share = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t Qi; + const coinbase::error_t rv = coinbase::api::ecdsa_mp::get_public_share_compressed(view_cmem(key_blob), Qi); + if (rv) return rv; + return alloc_cmem_from_buf(Qi, out_public_share); + } catch (const std::bad_alloc&) { + if (out_public_share) *out_public_share = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_public_share) *out_public_share = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_ecdsa_mp_detach_private_scalar(cmem_t key_blob, cmem_t* out_public_key_blob, + cmem_t* out_private_scalar_fixed) { + try { + if (!out_public_key_blob || !out_private_scalar_fixed) return E_BADARG; + *out_public_key_blob = cmem_t{nullptr, 0}; + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t public_blob; + coinbase::buf_t private_scalar_fixed; + const coinbase::error_t rv = + coinbase::api::ecdsa_mp::detach_private_scalar(view_cmem(key_blob), public_blob, private_scalar_fixed); + if (rv) return rv; + + const auto r1 = alloc_cmem_from_buf(public_blob, out_public_key_blob); + if (r1) return r1; + const auto r2 = alloc_cmem_from_buf(private_scalar_fixed, out_private_scalar_fixed); + if (r2) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + return r2; + } + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_public_key_blob) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + } + if (out_private_scalar_fixed) { + cbmpc_cmem_free(*out_private_scalar_fixed); + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (out_public_key_blob) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + } + if (out_private_scalar_fixed) { + cbmpc_cmem_free(*out_private_scalar_fixed); + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_ecdsa_mp_attach_private_scalar(cmem_t public_key_blob, cmem_t private_scalar_fixed, + cmem_t public_share_compressed, cmem_t* out_key_blob) { + try { + if (!out_key_blob) return E_BADARG; + *out_key_blob = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(public_key_blob); + if (vkb) return vkb; + const auto vx = validate_cmem(private_scalar_fixed); + if (vx) return vx; + const auto vq = validate_cmem(public_share_compressed); + if (vq) return vq; + + coinbase::buf_t merged; + const coinbase::error_t rv = coinbase::api::ecdsa_mp::attach_private_scalar( + view_cmem(public_key_blob), view_cmem(private_scalar_fixed), view_cmem(public_share_compressed), merged); + if (rv) return rv; + return alloc_cmem_from_buf(merged, out_key_blob); + } catch (const std::bad_alloc&) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +} // extern "C" diff --git a/src/cbmpc/c_api/eddsa2pc.cpp b/src/cbmpc/c_api/eddsa2pc.cpp new file mode 100644 index 00000000..e96a91f4 --- /dev/null +++ b/src/cbmpc/c_api/eddsa2pc.cpp @@ -0,0 +1,218 @@ +#include + +#include +#include +#include +#include +#include +#include + +#include "util.h" + +using namespace coinbase::capi::detail; + +extern "C" { + +cbmpc_error_t cbmpc_eddsa_2p_dkg(const cbmpc_2pc_job_t* job, cbmpc_curve_id_t curve, cmem_t* out_key_blob) { + try { + if (!out_key_blob) return E_BADARG; + *out_key_blob = cmem_t{nullptr, 0}; + const auto vjob = validate_2pc_job(job); + if (vjob) return vjob; + + coinbase::api::party_2p_t self_cpp; + const auto pconv = to_cpp_party(job->self, self_cpp); + if (pconv) return pconv; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + job_2p_cpp_ctx_t ctx(job, self_cpp); + coinbase::buf_t key_blob; + const coinbase::error_t rv = coinbase::api::eddsa_2p::dkg(ctx.job, curve_cpp, key_blob); + if (rv) return rv; + + return alloc_cmem_from_buf(key_blob, out_key_blob); + } catch (const std::bad_alloc&) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_2p_refresh(const cbmpc_2pc_job_t* job, cmem_t key_blob, cmem_t* out_new_key_blob) { + try { + if (!out_new_key_blob) return E_BADARG; + *out_new_key_blob = cmem_t{nullptr, 0}; + const auto vjob = validate_2pc_job(job); + if (vjob) return vjob; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::api::party_2p_t self_cpp; + const auto pconv = to_cpp_party(job->self, self_cpp); + if (pconv) return pconv; + + job_2p_cpp_ctx_t ctx(job, self_cpp); + coinbase::buf_t new_key; + const coinbase::error_t rv = coinbase::api::eddsa_2p::refresh(ctx.job, view_cmem(key_blob), new_key); + if (rv) return rv; + + return alloc_cmem_from_buf(new_key, out_new_key_blob); + } catch (const std::bad_alloc&) { + if (out_new_key_blob) *out_new_key_blob = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_new_key_blob) *out_new_key_blob = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_2p_sign(const cbmpc_2pc_job_t* job, cmem_t key_blob, cmem_t msg, cmem_t* sig_out) { + try { + if (!sig_out) return E_BADARG; + *sig_out = cmem_t{nullptr, 0}; + const auto vjob = validate_2pc_job(job); + if (vjob) return vjob; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + const auto vm = validate_cmem(msg); + if (vm) return vm; + + coinbase::api::party_2p_t self_cpp; + const auto pconv = to_cpp_party(job->self, self_cpp); + if (pconv) return pconv; + + job_2p_cpp_ctx_t ctx(job, self_cpp); + coinbase::buf_t sig; + const coinbase::error_t rv = coinbase::api::eddsa_2p::sign(ctx.job, view_cmem(key_blob), view_cmem(msg), sig); + if (rv) return rv; + + return alloc_cmem_from_buf(sig, sig_out); + } catch (const std::bad_alloc&) { + if (sig_out) *sig_out = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (sig_out) *sig_out = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_2p_get_public_key_compressed(cmem_t key_blob, cmem_t* out_pub_key) { + try { + if (!out_pub_key) return E_BADARG; + *out_pub_key = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t pk; + const coinbase::error_t rv = coinbase::api::eddsa_2p::get_public_key_compressed(view_cmem(key_blob), pk); + if (rv) return rv; + + return alloc_cmem_from_buf(pk, out_pub_key); + } catch (const std::bad_alloc&) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_2p_get_public_share_compressed(cmem_t key_blob, cmem_t* out_public_share) { + try { + if (!out_public_share) return E_BADARG; + *out_public_share = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t Qi; + const coinbase::error_t rv = coinbase::api::eddsa_2p::get_public_share_compressed(view_cmem(key_blob), Qi); + if (rv) return rv; + return alloc_cmem_from_buf(Qi, out_public_share); + } catch (const std::bad_alloc&) { + if (out_public_share) *out_public_share = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_public_share) *out_public_share = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_2p_detach_private_scalar(cmem_t key_blob, cmem_t* out_public_key_blob, + cmem_t* out_private_scalar_fixed) { + try { + if (!out_public_key_blob || !out_private_scalar_fixed) return E_BADARG; + *out_public_key_blob = cmem_t{nullptr, 0}; + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t public_blob; + coinbase::buf_t private_scalar_fixed; + const coinbase::error_t rv = + coinbase::api::eddsa_2p::detach_private_scalar(view_cmem(key_blob), public_blob, private_scalar_fixed); + if (rv) return rv; + + const auto r1 = alloc_cmem_from_buf(public_blob, out_public_key_blob); + if (r1) return r1; + const auto r2 = alloc_cmem_from_buf(private_scalar_fixed, out_private_scalar_fixed); + if (r2) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + return r2; + } + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_public_key_blob) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + } + if (out_private_scalar_fixed) { + cbmpc_cmem_free(*out_private_scalar_fixed); + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (out_public_key_blob) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + } + if (out_private_scalar_fixed) { + cbmpc_cmem_free(*out_private_scalar_fixed); + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_2p_attach_private_scalar(cmem_t public_key_blob, cmem_t private_scalar_fixed, + cmem_t public_share_compressed, cmem_t* out_key_blob) { + try { + if (!out_key_blob) return E_BADARG; + *out_key_blob = cmem_t{nullptr, 0}; + const auto vpb = validate_cmem(public_key_blob); + if (vpb) return vpb; + const auto vx = validate_cmem(private_scalar_fixed); + if (vx) return vx; + const auto vq = validate_cmem(public_share_compressed); + if (vq) return vq; + + coinbase::buf_t merged; + const coinbase::error_t rv = coinbase::api::eddsa_2p::attach_private_scalar( + view_cmem(public_key_blob), view_cmem(private_scalar_fixed), view_cmem(public_share_compressed), merged); + if (rv) return rv; + return alloc_cmem_from_buf(merged, out_key_blob); + } catch (const std::bad_alloc&) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +} // extern "C" diff --git a/src/cbmpc/c_api/eddsa_mp.cpp b/src/cbmpc/c_api/eddsa_mp.cpp new file mode 100644 index 00000000..20042525 --- /dev/null +++ b/src/cbmpc/c_api/eddsa_mp.cpp @@ -0,0 +1,439 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "access_structure_adapter.h" +#include "util.h" + +using namespace coinbase::capi::detail; + +extern "C" { + +cbmpc_error_t cbmpc_eddsa_mp_dkg_additive(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t* out_key_blob, + cmem_t* out_sid) { + try { + if (!out_key_blob || !out_sid) return E_BADARG; + *out_key_blob = cmem_t{nullptr, 0}; + *out_sid = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t key_blob; + coinbase::buf_t sid; + const coinbase::error_t rv = coinbase::api::eddsa_mp::dkg_additive(ctx.job, curve_cpp, key_blob, sid); + if (rv) return rv; + + const auto r_key = alloc_cmem_from_buf(key_blob, out_key_blob); + if (r_key) return r_key; + + const auto r_sid = alloc_cmem_from_buf(sid, out_sid); + if (r_sid) { + cbmpc_cmem_free(*out_key_blob); + *out_key_blob = cmem_t{nullptr, 0}; + return r_sid; + } + + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_key_blob) { + cbmpc_cmem_free(*out_key_blob); + *out_key_blob = cmem_t{nullptr, 0}; + } + if (out_sid) { + cbmpc_cmem_free(*out_sid); + *out_sid = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (out_key_blob) { + cbmpc_cmem_free(*out_key_blob); + *out_key_blob = cmem_t{nullptr, 0}; + } + if (out_sid) { + cbmpc_cmem_free(*out_sid); + *out_sid = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_mp_dkg_ac(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t sid_in, + const cbmpc_access_structure_t* access_structure, + const char* const* quorum_party_names, int quorum_party_names_count, + cmem_t* out_ac_key_blob, cmem_t* out_sid) { + try { + if (!out_ac_key_blob || !out_sid) return E_BADARG; + *out_ac_key_blob = cmem_t{nullptr, 0}; + *out_sid = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + const auto vsi = validate_cmem(sid_in); + if (vsi) return vsi; + + std::vector quorum_names; + const auto vqn = to_cpp_quorum_party_names(quorum_party_names, quorum_party_names_count, quorum_names); + if (vqn) return vqn; + + coinbase::api::access_structure_t ac_cpp; + const auto vac = to_cpp_access_structure(access_structure, ac_cpp); + if (vac) return vac; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sid(sid_in.data, sid_in.size); + coinbase::buf_t key_blob; + const coinbase::error_t rv = + coinbase::api::eddsa_mp::dkg_ac(ctx.job, curve_cpp, sid, ac_cpp, quorum_names, key_blob); + if (rv) return rv; + + const auto r_key = alloc_cmem_from_buf(key_blob, out_ac_key_blob); + if (r_key) return r_key; + + const auto r_sid = alloc_cmem_from_buf(sid, out_sid); + if (r_sid) { + cbmpc_cmem_free(*out_ac_key_blob); + *out_ac_key_blob = cmem_t{nullptr, 0}; + return r_sid; + } + + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_ac_key_blob) { + cbmpc_cmem_free(*out_ac_key_blob); + *out_ac_key_blob = cmem_t{nullptr, 0}; + } + if (out_sid) { + cbmpc_cmem_free(*out_sid); + *out_sid = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (out_ac_key_blob) { + cbmpc_cmem_free(*out_ac_key_blob); + *out_ac_key_blob = cmem_t{nullptr, 0}; + } + if (out_sid) { + cbmpc_cmem_free(*out_sid); + *out_sid = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_mp_refresh_additive(const cbmpc_mp_job_t* job, cmem_t sid_in, cmem_t key_blob, + cmem_t* sid_out, cmem_t* out_new_key_blob) { + try { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (!out_new_key_blob) return E_BADARG; + *out_new_key_blob = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + const auto vsi = validate_cmem(sid_in); + if (vsi) return vsi; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sid(sid_in.data, sid_in.size); + coinbase::buf_t new_key; + const coinbase::error_t rv = coinbase::api::eddsa_mp::refresh_additive(ctx.job, sid, view_cmem(key_blob), new_key); + if (rv) return rv; + + const auto r_key = alloc_cmem_from_buf(new_key, out_new_key_blob); + if (r_key) return r_key; + + if (sid_out) { + const auto r_sid = alloc_cmem_from_buf(sid, sid_out); + if (r_sid) { + cbmpc_cmem_free(*out_new_key_blob); + *out_new_key_blob = cmem_t{nullptr, 0}; + return r_sid; + } + } + + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (out_new_key_blob) { + cbmpc_cmem_free(*out_new_key_blob); + *out_new_key_blob = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (out_new_key_blob) { + cbmpc_cmem_free(*out_new_key_blob); + *out_new_key_blob = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_mp_refresh_ac(const cbmpc_mp_job_t* job, cmem_t sid_in, cmem_t ac_key_blob, + const cbmpc_access_structure_t* access_structure, + const char* const* quorum_party_names, int quorum_party_names_count, + cmem_t* sid_out, cmem_t* out_new_ac_key_blob) { + try { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (!out_new_ac_key_blob) return E_BADARG; + *out_new_ac_key_blob = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + const auto vsi = validate_cmem(sid_in); + if (vsi) return vsi; + const auto vkb = validate_cmem(ac_key_blob); + if (vkb) return vkb; + + std::vector quorum_names; + const auto vqn = to_cpp_quorum_party_names(quorum_party_names, quorum_party_names_count, quorum_names); + if (vqn) return vqn; + + coinbase::api::access_structure_t ac_cpp; + const auto vac = to_cpp_access_structure(access_structure, ac_cpp); + if (vac) return vac; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sid(sid_in.data, sid_in.size); + coinbase::buf_t new_key; + const coinbase::error_t rv = + coinbase::api::eddsa_mp::refresh_ac(ctx.job, sid, view_cmem(ac_key_blob), ac_cpp, quorum_names, new_key); + if (rv) return rv; + + const auto r_key = alloc_cmem_from_buf(new_key, out_new_ac_key_blob); + if (r_key) return r_key; + + if (sid_out) { + const auto r_sid = alloc_cmem_from_buf(sid, sid_out); + if (r_sid) { + cbmpc_cmem_free(*out_new_ac_key_blob); + *out_new_ac_key_blob = cmem_t{nullptr, 0}; + return r_sid; + } + } + + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (out_new_ac_key_blob) { + cbmpc_cmem_free(*out_new_ac_key_blob); + *out_new_ac_key_blob = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (out_new_ac_key_blob) { + cbmpc_cmem_free(*out_new_ac_key_blob); + *out_new_ac_key_blob = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_mp_sign_additive(const cbmpc_mp_job_t* job, cmem_t key_blob, cmem_t msg, int32_t sig_receiver, + cmem_t* sig_out) { + try { + if (!sig_out) return E_BADARG; + *sig_out = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + const auto vm = validate_cmem(msg); + if (vm) return vm; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sig; + const coinbase::error_t rv = + coinbase::api::eddsa_mp::sign_additive(ctx.job, view_cmem(key_blob), view_cmem(msg), sig_receiver, sig); + if (rv) return rv; + + return alloc_cmem_from_buf(sig, sig_out); + } catch (const std::bad_alloc&) { + if (sig_out) *sig_out = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (sig_out) *sig_out = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_mp_sign_ac(const cbmpc_mp_job_t* job, cmem_t ac_key_blob, + const cbmpc_access_structure_t* access_structure, cmem_t msg, int32_t sig_receiver, + cmem_t* sig_out) { + try { + if (!sig_out) return E_BADARG; + *sig_out = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + const auto vkb = validate_cmem(ac_key_blob); + if (vkb) return vkb; + const auto vm = validate_cmem(msg); + if (vm) return vm; + + coinbase::api::access_structure_t ac_cpp; + const auto vac = to_cpp_access_structure(access_structure, ac_cpp); + if (vac) return vac; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sig; + const coinbase::error_t rv = + coinbase::api::eddsa_mp::sign_ac(ctx.job, view_cmem(ac_key_blob), ac_cpp, view_cmem(msg), sig_receiver, sig); + if (rv) return rv; + + return alloc_cmem_from_buf(sig, sig_out); + } catch (const std::bad_alloc&) { + if (sig_out) *sig_out = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (sig_out) *sig_out = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_mp_get_public_key_compressed(cmem_t key_blob, cmem_t* out_pub_key) { + try { + if (!out_pub_key) return E_BADARG; + *out_pub_key = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t pk; + const coinbase::error_t rv = coinbase::api::eddsa_mp::get_public_key_compressed(view_cmem(key_blob), pk); + if (rv) return rv; + + return alloc_cmem_from_buf(pk, out_pub_key); + } catch (const std::bad_alloc&) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_mp_get_public_share_compressed(cmem_t key_blob, cmem_t* out_public_share) { + try { + if (!out_public_share) return E_BADARG; + *out_public_share = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t Qi; + const coinbase::error_t rv = coinbase::api::eddsa_mp::get_public_share_compressed(view_cmem(key_blob), Qi); + if (rv) return rv; + return alloc_cmem_from_buf(Qi, out_public_share); + } catch (const std::bad_alloc&) { + if (out_public_share) *out_public_share = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_public_share) *out_public_share = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_mp_detach_private_scalar(cmem_t key_blob, cmem_t* out_public_key_blob, + cmem_t* out_private_scalar_fixed) { + try { + if (!out_public_key_blob || !out_private_scalar_fixed) return E_BADARG; + *out_public_key_blob = cmem_t{nullptr, 0}; + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t public_blob; + coinbase::buf_t private_scalar_fixed; + const coinbase::error_t rv = + coinbase::api::eddsa_mp::detach_private_scalar(view_cmem(key_blob), public_blob, private_scalar_fixed); + if (rv) return rv; + + const auto r1 = alloc_cmem_from_buf(public_blob, out_public_key_blob); + if (r1) return r1; + const auto r2 = alloc_cmem_from_buf(private_scalar_fixed, out_private_scalar_fixed); + if (r2) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + return r2; + } + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_public_key_blob) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + } + if (out_private_scalar_fixed) { + cbmpc_cmem_free(*out_private_scalar_fixed); + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (out_public_key_blob) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + } + if (out_private_scalar_fixed) { + cbmpc_cmem_free(*out_private_scalar_fixed); + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_eddsa_mp_attach_private_scalar(cmem_t public_key_blob, cmem_t private_scalar_fixed, + cmem_t public_share_compressed, cmem_t* out_key_blob) { + try { + if (!out_key_blob) return E_BADARG; + *out_key_blob = cmem_t{nullptr, 0}; + const auto vpb = validate_cmem(public_key_blob); + if (vpb) return vpb; + const auto vx = validate_cmem(private_scalar_fixed); + if (vx) return vx; + const auto vq = validate_cmem(public_share_compressed); + if (vq) return vq; + + coinbase::buf_t merged; + const coinbase::error_t rv = coinbase::api::eddsa_mp::attach_private_scalar( + view_cmem(public_key_blob), view_cmem(private_scalar_fixed), view_cmem(public_share_compressed), merged); + if (rv) return rv; + return alloc_cmem_from_buf(merged, out_key_blob); + } catch (const std::bad_alloc&) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +} // extern "C" diff --git a/src/cbmpc/c_api/pve_base_pke.cpp b/src/cbmpc/c_api/pve_base_pke.cpp new file mode 100644 index 00000000..4f3cd818 --- /dev/null +++ b/src/cbmpc/c_api/pve_base_pke.cpp @@ -0,0 +1,565 @@ +#include + +#include +#include +#include +#include +#include +#include + +#include "pve_internal.h" +#include "util.h" + +using namespace coinbase::capi::detail; + +namespace { +using coinbase::capi::pve_detail::c_base_pke_adapter_t; +using coinbase::capi::pve_detail::ecies_p256_hsm_ecdh_cpp; +using coinbase::capi::pve_detail::rsa_oaep_hsm_decap_cpp; + +static coinbase::error_t kem_encap_shim(void* ctx, coinbase::mem_t ek_bytes, coinbase::mem_t rho32, + coinbase::buf_t& out_kem_ct, coinbase::buf_t& out_kem_ss) { + auto* kem = static_cast(ctx); + if (!kem || !kem->encap) return E_BADARG; + if (rho32.size != 32) return E_BADARG; + + cmem_t kem_ct{nullptr, 0}; + cmem_t kem_ss{nullptr, 0}; + const cbmpc_error_t rv = kem->encap(kem->ctx, cmem_t{const_cast(ek_bytes.data), ek_bytes.size}, + cmem_t{const_cast(rho32.data), rho32.size}, &kem_ct, &kem_ss); + if (rv) { + if (kem_ct.data) cbmpc_cmem_free(kem_ct); + if (kem_ss.data) cbmpc_cmem_free(kem_ss); + return rv; + } + + if (kem_ct.size < 0 || (kem_ct.size > 0 && !kem_ct.data)) { + cbmpc_free(kem_ct.data); + if (kem_ss.data) cbmpc_cmem_free(kem_ss); + return E_FORMAT; + } + if (kem_ss.size < 0 || (kem_ss.size > 0 && !kem_ss.data)) { + cbmpc_free(kem_ss.data); + if (kem_ct.data) cbmpc_cmem_free(kem_ct); + return E_FORMAT; + } + out_kem_ct = coinbase::buf_t(kem_ct.data, kem_ct.size); + out_kem_ss = coinbase::buf_t(kem_ss.data, kem_ss.size); + cbmpc_cmem_free(kem_ct); + cbmpc_cmem_free(kem_ss); + return CBMPC_SUCCESS; +} + +static coinbase::error_t kem_decap_shim(void* ctx, const void* dk_handle, coinbase::mem_t kem_ct, + coinbase::buf_t& out_kem_ss) { + auto* kem = static_cast(ctx); + if (!kem || !kem->decap) return E_BADARG; + if (!dk_handle) return E_BADARG; + const auto* dk_mem = static_cast(dk_handle); + + cmem_t kem_ss{nullptr, 0}; + const cbmpc_error_t rv = kem->decap(kem->ctx, cmem_t{const_cast(dk_mem->data), dk_mem->size}, + cmem_t{const_cast(kem_ct.data), kem_ct.size}, &kem_ss); + if (rv) { + if (kem_ss.data) cbmpc_cmem_free(kem_ss); + return rv; + } + + if (kem_ss.size < 0 || (kem_ss.size > 0 && !kem_ss.data)) { + cbmpc_free(kem_ss.data); + return E_FORMAT; + } + out_kem_ss = coinbase::buf_t(kem_ss.data, kem_ss.size); + cbmpc_cmem_free(kem_ss); + return CBMPC_SUCCESS; +} + +class c_base_kem_adapter_t final : public coinbase::api::pve::base_pke_i { + public: + explicit c_base_kem_adapter_t(const cbmpc_pve_base_kem_t* kem) : kem_(kem) {} + + coinbase::error_t encrypt(coinbase::mem_t ek, coinbase::mem_t label, coinbase::mem_t plain, coinbase::mem_t rho, + coinbase::buf_t& out_ct) const override { + if (!kem_ || !kem_->encap) return E_BADARG; + + coinbase::mpc::pve_runtime_kem_callbacks_t callbacks; + callbacks.ctx = const_cast(kem_); + callbacks.encap = kem_encap_shim; + callbacks.decap = kem_decap_shim; + + coinbase::mpc::pve_runtime_kem_ek_t ek_i; + ek_i.ek_bytes = ek; + ek_i.callbacks = &callbacks; + + return coinbase::mpc::pve_base_pke_runtime_kem().encrypt(coinbase::mpc::pve_keyref(ek_i), label, plain, rho, + out_ct); + } + + coinbase::error_t decrypt(coinbase::mem_t dk, coinbase::mem_t label, coinbase::mem_t ct, + coinbase::buf_t& out_plain) const override { + if (!kem_ || !kem_->decap) return E_BADARG; + + coinbase::mpc::pve_runtime_kem_callbacks_t callbacks; + callbacks.ctx = const_cast(kem_); + callbacks.encap = kem_encap_shim; + callbacks.decap = kem_decap_shim; + + const coinbase::mem_t dk_mem(dk.data, dk.size); + coinbase::mpc::pve_runtime_kem_dk_t dk_i; + dk_i.dk_handle = &dk_mem; + dk_i.callbacks = &callbacks; + + return coinbase::mpc::pve_base_pke_runtime_kem().decrypt(coinbase::mpc::pve_keyref(dk_i), label, ct, out_plain); + } + + private: + const cbmpc_pve_base_kem_t* kem_ = nullptr; +}; + +} // namespace + +extern "C" { + +cbmpc_error_t cbmpc_pve_generate_base_pke_rsa_keypair(cmem_t* out_ek, cmem_t* out_dk) { + try { + if (!out_ek || !out_dk) return E_BADARG; + *out_ek = cmem_t{nullptr, 0}; + *out_dk = cmem_t{nullptr, 0}; + + coinbase::buf_t ek_blob; + coinbase::buf_t dk_blob; + const coinbase::error_t rv = coinbase::api::pve::generate_base_pke_rsa_keypair(ek_blob, dk_blob); + if (rv) return rv; + + const cbmpc_error_t rv_ek = alloc_cmem_from_buf(ek_blob, out_ek); + if (rv_ek) return rv_ek; + const cbmpc_error_t rv_dk = alloc_cmem_from_buf(dk_blob, out_dk); + if (rv_dk) { + cbmpc_cmem_free(*out_ek); + *out_ek = cmem_t{nullptr, 0}; + return rv_dk; + } + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_ek && out_ek->data) cbmpc_cmem_free(*out_ek); + if (out_dk && out_dk->data) cbmpc_cmem_free(*out_dk); + if (out_ek) *out_ek = cmem_t{nullptr, 0}; + if (out_dk) *out_dk = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_ek && out_ek->data) cbmpc_cmem_free(*out_ek); + if (out_dk && out_dk->data) cbmpc_cmem_free(*out_dk); + if (out_ek) *out_ek = cmem_t{nullptr, 0}; + if (out_dk) *out_dk = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_generate_base_pke_ecies_p256_keypair(cmem_t* out_ek, cmem_t* out_dk) { + try { + if (!out_ek || !out_dk) return E_BADARG; + *out_ek = cmem_t{nullptr, 0}; + *out_dk = cmem_t{nullptr, 0}; + + coinbase::buf_t ek_blob; + coinbase::buf_t dk_blob; + const coinbase::error_t rv = coinbase::api::pve::generate_base_pke_ecies_p256_keypair(ek_blob, dk_blob); + if (rv) return rv; + + const cbmpc_error_t rv_ek = alloc_cmem_from_buf(ek_blob, out_ek); + if (rv_ek) return rv_ek; + const cbmpc_error_t rv_dk = alloc_cmem_from_buf(dk_blob, out_dk); + if (rv_dk) { + cbmpc_cmem_free(*out_ek); + *out_ek = cmem_t{nullptr, 0}; + return rv_dk; + } + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_ek && out_ek->data) cbmpc_cmem_free(*out_ek); + if (out_dk && out_dk->data) cbmpc_cmem_free(*out_dk); + if (out_ek) *out_ek = cmem_t{nullptr, 0}; + if (out_dk) *out_dk = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_ek && out_ek->data) cbmpc_cmem_free(*out_ek); + if (out_dk && out_dk->data) cbmpc_cmem_free(*out_dk); + if (out_ek) *out_ek = cmem_t{nullptr, 0}; + if (out_dk) *out_dk = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_base_pke_ecies_p256_ek_from_oct(cmem_t pub_key_oct, cmem_t* out_ek) { + try { + if (!out_ek) return E_BADARG; + *out_ek = cmem_t{nullptr, 0}; + + const auto vpk = validate_cmem(pub_key_oct); + if (vpk) return vpk; + + coinbase::buf_t ek_blob; + const coinbase::error_t rv = coinbase::api::pve::base_pke_ecies_p256_ek_from_oct(view_cmem(pub_key_oct), ek_blob); + if (rv) return rv; + + return alloc_cmem_from_buf(ek_blob, out_ek); + } catch (const std::bad_alloc&) { + if (out_ek) *out_ek = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_ek) *out_ek = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_base_pke_rsa_ek_from_modulus(cmem_t modulus, cmem_t* out_ek) { + try { + if (!out_ek) return E_BADARG; + *out_ek = cmem_t{nullptr, 0}; + + const auto vm = validate_cmem(modulus); + if (vm) return vm; + + coinbase::buf_t ek_blob; + const coinbase::error_t rv = coinbase::api::pve::base_pke_rsa_ek_from_modulus(view_cmem(modulus), ek_blob); + if (rv) return rv; + + return alloc_cmem_from_buf(ek_blob, out_ek); + } catch (const std::bad_alloc&) { + if (out_ek) *out_ek = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_ek) *out_ek = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_decrypt_rsa_oaep_hsm(cbmpc_curve_id_t curve, cmem_t dk_handle, cmem_t ek, cmem_t ciphertext, + cmem_t label, const cbmpc_pve_rsa_oaep_hsm_decap_t* cb, cmem_t* out_x) { + try { + if (!out_x) return E_BADARG; + *out_x = cmem_t{nullptr, 0}; + if (!cb || !cb->decap) return E_BADARG; + + const auto vdk = validate_cmem(dk_handle); + if (vdk) return vdk; + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vl = validate_cmem(label); + if (vl) return vl; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::api::pve::rsa_oaep_hsm_decap_cb_t cb_cpp; + cb_cpp.ctx = const_cast(cb); + cb_cpp.decap = rsa_oaep_hsm_decap_cpp; + + coinbase::buf_t x_out; + const coinbase::error_t rv = coinbase::api::pve::decrypt_rsa_oaep_hsm( + curve_cpp, view_cmem(dk_handle), view_cmem(ek), view_cmem(ciphertext), view_cmem(label), cb_cpp, x_out); + if (rv) return rv; + + return alloc_cmem_from_buf(x_out, out_x); + } catch (const std::bad_alloc&) { + if (out_x) *out_x = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_x) *out_x = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_decrypt_ecies_p256_hsm(cbmpc_curve_id_t curve, cmem_t dk_handle, cmem_t ek, cmem_t ciphertext, + cmem_t label, const cbmpc_pve_ecies_p256_hsm_ecdh_t* cb, cmem_t* out_x) { + try { + if (!out_x) return E_BADARG; + *out_x = cmem_t{nullptr, 0}; + if (!cb || !cb->ecdh) return E_BADARG; + + const auto vdk = validate_cmem(dk_handle); + if (vdk) return vdk; + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vl = validate_cmem(label); + if (vl) return vl; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::api::pve::ecies_p256_hsm_ecdh_cb_t cb_cpp; + cb_cpp.ctx = const_cast(cb); + cb_cpp.ecdh = ecies_p256_hsm_ecdh_cpp; + + coinbase::buf_t x_out; + const coinbase::error_t rv = coinbase::api::pve::decrypt_ecies_p256_hsm( + curve_cpp, view_cmem(dk_handle), view_cmem(ek), view_cmem(ciphertext), view_cmem(label), cb_cpp, x_out); + if (rv) return rv; + + return alloc_cmem_from_buf(x_out, out_x); + } catch (const std::bad_alloc&) { + if (out_x) *out_x = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_x) *out_x = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_encrypt_with_kem(const cbmpc_pve_base_kem_t* kem, cbmpc_curve_id_t curve, cmem_t ek, + cmem_t label, cmem_t x, cmem_t* out_ciphertext) { + try { + if (!out_ciphertext) return E_BADARG; + *out_ciphertext = cmem_t{nullptr, 0}; + if (!kem || !kem->encap) return E_BADARG; + + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vl = validate_cmem(label); + if (vl) return vl; + const auto vx = validate_cmem(x); + if (vx) return vx; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + c_base_kem_adapter_t adapter(kem); + coinbase::buf_t ct; + const coinbase::error_t rv = + coinbase::api::pve::encrypt(adapter, curve_cpp, view_cmem(ek), view_cmem(label), view_cmem(x), ct); + if (rv) return rv; + + return alloc_cmem_from_buf(ct, out_ciphertext); + } catch (const std::bad_alloc&) { + if (out_ciphertext) *out_ciphertext = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_ciphertext) *out_ciphertext = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_verify_with_kem(const cbmpc_pve_base_kem_t* kem, cbmpc_curve_id_t curve, cmem_t ek, + cmem_t ciphertext, cmem_t Q_compressed, cmem_t label) { + try { + if (!kem || !kem->encap) return E_BADARG; + + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vq = validate_cmem(Q_compressed); + if (vq) return vq; + const auto vl = validate_cmem(label); + if (vl) return vl; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + c_base_kem_adapter_t adapter(kem); + return coinbase::api::pve::verify(adapter, curve_cpp, view_cmem(ek), view_cmem(ciphertext), view_cmem(Q_compressed), + view_cmem(label)); + } catch (const std::bad_alloc&) { + return E_INSUFFICIENT; + } catch (...) { + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_decrypt_with_kem(const cbmpc_pve_base_kem_t* kem, cbmpc_curve_id_t curve, cmem_t dk, cmem_t ek, + cmem_t ciphertext, cmem_t label, cmem_t* out_x) { + try { + if (!out_x) return E_BADARG; + *out_x = cmem_t{nullptr, 0}; + if (!kem || !kem->decap) return E_BADARG; + + const auto vdk = validate_cmem(dk); + if (vdk) return vdk; + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vl = validate_cmem(label); + if (vl) return vl; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + c_base_kem_adapter_t adapter(kem); + coinbase::buf_t x_out; + const coinbase::error_t rv = coinbase::api::pve::decrypt(adapter, curve_cpp, view_cmem(dk), view_cmem(ek), + view_cmem(ciphertext), view_cmem(label), x_out); + if (rv) return rv; + + return alloc_cmem_from_buf(x_out, out_x); + } catch (const std::bad_alloc&) { + if (out_x) *out_x = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_x) *out_x = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_encrypt(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, cmem_t ek, cmem_t label, + cmem_t x, cmem_t* out_ciphertext) { + try { + if (!out_ciphertext) return E_BADARG; + *out_ciphertext = cmem_t{nullptr, 0}; + + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vl = validate_cmem(label); + if (vl) return vl; + const auto vx = validate_cmem(x); + if (vx) return vx; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::buf_t ct; + coinbase::error_t rv = UNINITIALIZED_ERROR; + if (base_pke) { + c_base_pke_adapter_t adapter(base_pke); + rv = coinbase::api::pve::encrypt(adapter, curve_cpp, view_cmem(ek), view_cmem(label), view_cmem(x), ct); + } else { + rv = coinbase::api::pve::encrypt(curve_cpp, view_cmem(ek), view_cmem(label), view_cmem(x), ct); + } + if (rv) return rv; + + return alloc_cmem_from_buf(ct, out_ciphertext); + } catch (const std::bad_alloc&) { + if (out_ciphertext) *out_ciphertext = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_ciphertext) *out_ciphertext = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_verify(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, cmem_t ek, + cmem_t ciphertext, cmem_t Q_compressed, cmem_t label) { + try { + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vq = validate_cmem(Q_compressed); + if (vq) return vq; + const auto vl = validate_cmem(label); + if (vl) return vl; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + if (base_pke) { + c_base_pke_adapter_t adapter(base_pke); + return coinbase::api::pve::verify(adapter, curve_cpp, view_cmem(ek), view_cmem(ciphertext), + view_cmem(Q_compressed), view_cmem(label)); + } + return coinbase::api::pve::verify(curve_cpp, view_cmem(ek), view_cmem(ciphertext), view_cmem(Q_compressed), + view_cmem(label)); + } catch (const std::bad_alloc&) { + return E_INSUFFICIENT; + } catch (...) { + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_decrypt(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, cmem_t dk, cmem_t ek, + cmem_t ciphertext, cmem_t label, cmem_t* out_x) { + try { + if (!out_x) return E_BADARG; + *out_x = cmem_t{nullptr, 0}; + + const auto vdk = validate_cmem(dk); + if (vdk) return vdk; + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vl = validate_cmem(label); + if (vl) return vl; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::buf_t x_out; + coinbase::error_t rv = UNINITIALIZED_ERROR; + if (base_pke) { + c_base_pke_adapter_t adapter(base_pke); + rv = coinbase::api::pve::decrypt(adapter, curve_cpp, view_cmem(dk), view_cmem(ek), view_cmem(ciphertext), + view_cmem(label), x_out); + } else { + rv = coinbase::api::pve::decrypt(curve_cpp, view_cmem(dk), view_cmem(ek), view_cmem(ciphertext), view_cmem(label), + x_out); + } + if (rv) return rv; + + return alloc_cmem_from_buf(x_out, out_x); + } catch (const std::bad_alloc&) { + if (out_x) *out_x = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_x) *out_x = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_get_Q(cmem_t ciphertext, cmem_t* out_Q_compressed) { + try { + if (!out_Q_compressed) return E_BADARG; + *out_Q_compressed = cmem_t{nullptr, 0}; + + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + + coinbase::buf_t Q; + const coinbase::error_t rv = coinbase::api::pve::get_public_key_compressed(view_cmem(ciphertext), Q); + if (rv) return rv; + + return alloc_cmem_from_buf(Q, out_Q_compressed); + } catch (const std::bad_alloc&) { + if (out_Q_compressed) *out_Q_compressed = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_Q_compressed) *out_Q_compressed = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_get_Label(cmem_t ciphertext, cmem_t* out_label) { + try { + if (!out_label) return E_BADARG; + *out_label = cmem_t{nullptr, 0}; + + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + + coinbase::buf_t L; + const coinbase::error_t rv = coinbase::api::pve::get_Label(view_cmem(ciphertext), L); + if (rv) return rv; + + return alloc_cmem_from_buf(L, out_label); + } catch (const std::bad_alloc&) { + if (out_label) *out_label = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_label) *out_label = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +} // extern "C" diff --git a/src/cbmpc/c_api/pve_batch_ac.cpp b/src/cbmpc/c_api/pve_batch_ac.cpp new file mode 100644 index 00000000..3703828e --- /dev/null +++ b/src/cbmpc/c_api/pve_batch_ac.cpp @@ -0,0 +1,387 @@ +#include + +#include +#include +#include +#include + +#include "access_structure_adapter.h" +#include "pve_internal.h" +#include "util.h" + +using namespace coinbase::capi::detail; +using coinbase::capi::pve_detail::c_base_pke_adapter_t; +using coinbase::capi::pve_detail::ecies_p256_hsm_ecdh_cpp; +using coinbase::capi::pve_detail::rsa_oaep_hsm_decap_cpp; + +namespace { + +static cbmpc_error_t validate_leaf_key_mapping(const char* const* leaf_names, const cmem_t* leaf_keys, int leaf_count, + coinbase::api::pve::leaf_keys_t& out) { + out.clear(); + if (leaf_count < 0) return E_BADARG; + if (leaf_count == 0) return CBMPC_SUCCESS; + if (!leaf_names || !leaf_keys) return E_BADARG; + + for (int i = 0; i < leaf_count; i++) { + const char* name = leaf_names[i]; + if (!name || name[0] == '\0') return E_BADARG; + const auto vk = validate_cmem(leaf_keys[i]); + if (vk) return vk; + + const auto [it, inserted] = + out.emplace(std::string_view(name), coinbase::mem_t(leaf_keys[i].data, leaf_keys[i].size)); + if (!inserted) return E_BADARG; // duplicate leaf name + } + return CBMPC_SUCCESS; +} + +static cbmpc_error_t validate_leaf_share_mapping(const char* const* leaf_names, const cmem_t* leaf_shares, + int leaf_count, coinbase::api::pve::leaf_shares_t& out) { + out.clear(); + if (leaf_count < 0) return E_BADARG; + if (leaf_count == 0) return CBMPC_SUCCESS; + if (!leaf_names || !leaf_shares) return E_BADARG; + + for (int i = 0; i < leaf_count; i++) { + const char* name = leaf_names[i]; + if (!name || name[0] == '\0') return E_BADARG; + const auto vs = validate_cmem(leaf_shares[i]); + if (vs) return vs; + + const auto [it, inserted] = + out.emplace(std::string_view(name), coinbase::mem_t(leaf_shares[i].data, leaf_shares[i].size)); + if (!inserted) return E_BADARG; // duplicate leaf name + } + return CBMPC_SUCCESS; +} + +} // namespace + +extern "C" { + +cbmpc_error_t cbmpc_pve_ac_encrypt(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, + const cbmpc_access_structure_t* ac, const char* const* leaf_names, + const cmem_t* leaf_eks, int leaf_count, cmem_t label, cmems_t xs, + cmem_t* out_ciphertext) { + try { + if (!out_ciphertext) return E_BADARG; + *out_ciphertext = cmem_t{nullptr, 0}; + + const auto vl = validate_cmem(label); + if (vl) return vl; + + std::vector xs_cpp; + const auto vxs = view_cmems(xs, xs_cpp); + if (vxs) return vxs; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::api::access_structure_t ac_cpp; + const auto aconv = coinbase::capi::detail::to_cpp_access_structure(ac, ac_cpp); + if (aconv) return aconv; + + coinbase::api::pve::leaf_keys_t keys_cpp; + const auto kmap = validate_leaf_key_mapping(leaf_names, leaf_eks, leaf_count, keys_cpp); + if (kmap) return kmap; + + coinbase::buf_t ct; + coinbase::error_t rv = UNINITIALIZED_ERROR; + if (base_pke) { + c_base_pke_adapter_t adapter(base_pke); + rv = coinbase::api::pve::encrypt_ac(adapter, curve_cpp, ac_cpp, keys_cpp, view_cmem(label), xs_cpp, ct); + } else { + rv = coinbase::api::pve::encrypt_ac(curve_cpp, ac_cpp, keys_cpp, view_cmem(label), xs_cpp, ct); + } + if (rv) return rv; + + return alloc_cmem_from_buf(ct, out_ciphertext); + } catch (const std::bad_alloc&) { + if (out_ciphertext) *out_ciphertext = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_ciphertext) *out_ciphertext = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_ac_verify(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, + const cbmpc_access_structure_t* ac, const char* const* leaf_names, + const cmem_t* leaf_eks, int leaf_count, cmem_t ciphertext, cmems_t Qs_compressed, + cmem_t label) { + try { + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vl = validate_cmem(label); + if (vl) return vl; + + std::vector qs_cpp; + const auto vqs = view_cmems(Qs_compressed, qs_cpp); + if (vqs) return vqs; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::api::access_structure_t ac_cpp; + const auto aconv = coinbase::capi::detail::to_cpp_access_structure(ac, ac_cpp); + if (aconv) return aconv; + + coinbase::api::pve::leaf_keys_t keys_cpp; + const auto kmap = validate_leaf_key_mapping(leaf_names, leaf_eks, leaf_count, keys_cpp); + if (kmap) return kmap; + + if (base_pke) { + c_base_pke_adapter_t adapter(base_pke); + return coinbase::api::pve::verify_ac(adapter, curve_cpp, ac_cpp, keys_cpp, view_cmem(ciphertext), qs_cpp, + view_cmem(label)); + } + return coinbase::api::pve::verify_ac(curve_cpp, ac_cpp, keys_cpp, view_cmem(ciphertext), qs_cpp, view_cmem(label)); + } catch (const std::bad_alloc&) { + return E_INSUFFICIENT; + } catch (...) { + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_ac_partial_decrypt_attempt(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, + const cbmpc_access_structure_t* ac, cmem_t ciphertext, + int attempt_index, const char* leaf_name, cmem_t dk, cmem_t label, + cmem_t* out_share) { + try { + if (!out_share) return E_BADARG; + *out_share = cmem_t{nullptr, 0}; + + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vdk = validate_cmem(dk); + if (vdk) return vdk; + const auto vl = validate_cmem(label); + if (vl) return vl; + if (!leaf_name || leaf_name[0] == '\0') return E_BADARG; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::api::access_structure_t ac_cpp; + const auto aconv = coinbase::capi::detail::to_cpp_access_structure(ac, ac_cpp); + if (aconv) return aconv; + + coinbase::buf_t share; + coinbase::error_t rv = UNINITIALIZED_ERROR; + if (base_pke) { + c_base_pke_adapter_t adapter(base_pke); + rv = coinbase::api::pve::partial_decrypt_ac_attempt(adapter, curve_cpp, ac_cpp, view_cmem(ciphertext), + attempt_index, std::string_view(leaf_name), view_cmem(dk), + view_cmem(label), share); + } else { + rv = coinbase::api::pve::partial_decrypt_ac_attempt(curve_cpp, ac_cpp, view_cmem(ciphertext), attempt_index, + std::string_view(leaf_name), view_cmem(dk), view_cmem(label), + share); + } + if (rv) return rv; + + return alloc_cmem_from_buf(share, out_share); + } catch (const std::bad_alloc&) { + if (out_share) *out_share = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_share) *out_share = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_ac_partial_decrypt_attempt_rsa_oaep_hsm(cbmpc_curve_id_t curve, + const cbmpc_access_structure_t* ac, cmem_t ciphertext, + int attempt_index, const char* leaf_name, + cmem_t dk_handle, cmem_t ek, cmem_t label, + const cbmpc_pve_rsa_oaep_hsm_decap_t* cb, + cmem_t* out_share) { + try { + if (!out_share) return E_BADARG; + *out_share = cmem_t{nullptr, 0}; + if (!cb || !cb->decap) return E_BADARG; + + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vdk = validate_cmem(dk_handle); + if (vdk) return vdk; + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vl = validate_cmem(label); + if (vl) return vl; + if (!leaf_name || leaf_name[0] == '\0') return E_BADARG; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::api::access_structure_t ac_cpp; + const auto aconv = coinbase::capi::detail::to_cpp_access_structure(ac, ac_cpp); + if (aconv) return aconv; + + coinbase::api::pve::rsa_oaep_hsm_decap_cb_t cb_cpp; + cb_cpp.ctx = const_cast(cb); + cb_cpp.decap = rsa_oaep_hsm_decap_cpp; + + coinbase::buf_t share; + const coinbase::error_t rv = coinbase::api::pve::partial_decrypt_ac_attempt_rsa_oaep_hsm( + curve_cpp, ac_cpp, view_cmem(ciphertext), attempt_index, std::string_view(leaf_name), view_cmem(dk_handle), + view_cmem(ek), view_cmem(label), cb_cpp, share); + if (rv) return rv; + + return alloc_cmem_from_buf(share, out_share); + } catch (const std::bad_alloc&) { + if (out_share) *out_share = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_share) *out_share = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_ac_partial_decrypt_attempt_ecies_p256_hsm(cbmpc_curve_id_t curve, + const cbmpc_access_structure_t* ac, cmem_t ciphertext, + int attempt_index, const char* leaf_name, + cmem_t dk_handle, cmem_t ek, cmem_t label, + const cbmpc_pve_ecies_p256_hsm_ecdh_t* cb, + cmem_t* out_share) { + try { + if (!out_share) return E_BADARG; + *out_share = cmem_t{nullptr, 0}; + if (!cb || !cb->ecdh) return E_BADARG; + + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vdk = validate_cmem(dk_handle); + if (vdk) return vdk; + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vl = validate_cmem(label); + if (vl) return vl; + if (!leaf_name || leaf_name[0] == '\0') return E_BADARG; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::api::access_structure_t ac_cpp; + const auto aconv = coinbase::capi::detail::to_cpp_access_structure(ac, ac_cpp); + if (aconv) return aconv; + + coinbase::api::pve::ecies_p256_hsm_ecdh_cb_t cb_cpp; + cb_cpp.ctx = const_cast(cb); + cb_cpp.ecdh = ecies_p256_hsm_ecdh_cpp; + + coinbase::buf_t share; + const coinbase::error_t rv = coinbase::api::pve::partial_decrypt_ac_attempt_ecies_p256_hsm( + curve_cpp, ac_cpp, view_cmem(ciphertext), attempt_index, std::string_view(leaf_name), view_cmem(dk_handle), + view_cmem(ek), view_cmem(label), cb_cpp, share); + if (rv) return rv; + + return alloc_cmem_from_buf(share, out_share); + } catch (const std::bad_alloc&) { + if (out_share) *out_share = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_share) *out_share = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_ac_combine(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, + const cbmpc_access_structure_t* ac, const char* const* quorum_leaf_names, + const cmem_t* quorum_shares, int quorum_count, cmem_t ciphertext, int attempt_index, + cmem_t label, cmems_t* out_xs) { + try { + if (!out_xs) return E_BADARG; + *out_xs = cmems_t{0, nullptr, nullptr}; + + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vl = validate_cmem(label); + if (vl) return vl; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::api::access_structure_t ac_cpp; + const auto aconv = coinbase::capi::detail::to_cpp_access_structure(ac, ac_cpp); + if (aconv) return aconv; + + coinbase::api::pve::leaf_shares_t quorum_cpp; + const auto qmap = validate_leaf_share_mapping(quorum_leaf_names, quorum_shares, quorum_count, quorum_cpp); + if (qmap) return qmap; + + std::vector xs_cpp; + coinbase::error_t rv = UNINITIALIZED_ERROR; + if (base_pke) { + c_base_pke_adapter_t adapter(base_pke); + rv = coinbase::api::pve::combine_ac(adapter, curve_cpp, ac_cpp, view_cmem(ciphertext), attempt_index, + view_cmem(label), quorum_cpp, xs_cpp); + } else { + rv = coinbase::api::pve::combine_ac(curve_cpp, ac_cpp, view_cmem(ciphertext), attempt_index, view_cmem(label), + quorum_cpp, xs_cpp); + } + if (rv) return rv; + + return alloc_cmems_from_bufs(xs_cpp, out_xs); + } catch (const std::bad_alloc&) { + if (out_xs) *out_xs = cmems_t{0, nullptr, nullptr}; + return E_INSUFFICIENT; + } catch (...) { + if (out_xs) *out_xs = cmems_t{0, nullptr, nullptr}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_ac_get_count(cmem_t ciphertext, int* out_batch_count) { + try { + if (!out_batch_count) return E_BADARG; + *out_batch_count = 0; + + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + + int n = 0; + const coinbase::error_t rv = coinbase::api::pve::get_ac_batch_count(view_cmem(ciphertext), n); + if (rv) return rv; + + *out_batch_count = n; + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_batch_count) *out_batch_count = 0; + return E_INSUFFICIENT; + } catch (...) { + if (out_batch_count) *out_batch_count = 0; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_ac_get_Qs(cmem_t ciphertext, cmems_t* out_Qs_compressed) { + try { + if (!out_Qs_compressed) return E_BADARG; + *out_Qs_compressed = cmems_t{0, nullptr, nullptr}; + + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + + std::vector Qs; + const coinbase::error_t rv = coinbase::api::pve::get_public_keys_compressed_ac(view_cmem(ciphertext), Qs); + if (rv) return rv; + + return alloc_cmems_from_bufs(Qs, out_Qs_compressed); + } catch (const std::bad_alloc&) { + if (out_Qs_compressed) *out_Qs_compressed = cmems_t{0, nullptr, nullptr}; + return E_INSUFFICIENT; + } catch (...) { + if (out_Qs_compressed) *out_Qs_compressed = cmems_t{0, nullptr, nullptr}; + return E_GENERAL; + } +} + +} // extern "C" diff --git a/src/cbmpc/c_api/pve_batch_single_recipient.cpp b/src/cbmpc/c_api/pve_batch_single_recipient.cpp new file mode 100644 index 00000000..6d7a3062 --- /dev/null +++ b/src/cbmpc/c_api/pve_batch_single_recipient.cpp @@ -0,0 +1,276 @@ +#include + +#include +#include +#include +#include + +#include "pve_internal.h" +#include "util.h" + +using namespace coinbase::capi::detail; +using coinbase::capi::pve_detail::c_base_pke_adapter_t; +using coinbase::capi::pve_detail::ecies_p256_hsm_ecdh_cpp; +using coinbase::capi::pve_detail::rsa_oaep_hsm_decap_cpp; + +extern "C" { + +cbmpc_error_t cbmpc_pve_batch_encrypt(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, cmem_t ek, + cmem_t label, cmems_t xs, cmem_t* out_ciphertext) { + try { + if (!out_ciphertext) return E_BADARG; + *out_ciphertext = cmem_t{nullptr, 0}; + + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vl = validate_cmem(label); + if (vl) return vl; + + std::vector xs_cpp; + const auto vxs = view_cmems(xs, xs_cpp); + if (vxs) return vxs; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::buf_t ct; + coinbase::error_t rv = UNINITIALIZED_ERROR; + if (base_pke) { + c_base_pke_adapter_t adapter(base_pke); + rv = coinbase::api::pve::encrypt_batch(adapter, curve_cpp, view_cmem(ek), view_cmem(label), xs_cpp, ct); + } else { + rv = coinbase::api::pve::encrypt_batch(curve_cpp, view_cmem(ek), view_cmem(label), xs_cpp, ct); + } + if (rv) return rv; + + return alloc_cmem_from_buf(ct, out_ciphertext); + } catch (const std::bad_alloc&) { + if (out_ciphertext) *out_ciphertext = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_ciphertext) *out_ciphertext = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_batch_verify(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, cmem_t ek, + cmem_t ciphertext, cmems_t Qs_compressed, cmem_t label) { + try { + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vl = validate_cmem(label); + if (vl) return vl; + + std::vector Qs_cpp; + const auto vqs = view_cmems(Qs_compressed, Qs_cpp); + if (vqs) return vqs; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + if (base_pke) { + c_base_pke_adapter_t adapter(base_pke); + return coinbase::api::pve::verify_batch(adapter, curve_cpp, view_cmem(ek), view_cmem(ciphertext), Qs_cpp, + view_cmem(label)); + } + return coinbase::api::pve::verify_batch(curve_cpp, view_cmem(ek), view_cmem(ciphertext), Qs_cpp, view_cmem(label)); + } catch (const std::bad_alloc&) { + return E_INSUFFICIENT; + } catch (...) { + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_batch_decrypt(const cbmpc_pve_base_pke_t* base_pke, cbmpc_curve_id_t curve, cmem_t dk, + cmem_t ek, cmem_t ciphertext, cmem_t label, cmems_t* out_xs) { + try { + if (!out_xs) return E_BADARG; + *out_xs = cmems_t{0, nullptr, nullptr}; + + const auto vdk = validate_cmem(dk); + if (vdk) return vdk; + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vl = validate_cmem(label); + if (vl) return vl; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + std::vector xs_cpp; + coinbase::error_t rv = UNINITIALIZED_ERROR; + if (base_pke) { + c_base_pke_adapter_t adapter(base_pke); + rv = coinbase::api::pve::decrypt_batch(adapter, curve_cpp, view_cmem(dk), view_cmem(ek), view_cmem(ciphertext), + view_cmem(label), xs_cpp); + } else { + rv = coinbase::api::pve::decrypt_batch(curve_cpp, view_cmem(dk), view_cmem(ek), view_cmem(ciphertext), + view_cmem(label), xs_cpp); + } + if (rv) return rv; + + return alloc_cmems_from_bufs(xs_cpp, out_xs); + } catch (const std::bad_alloc&) { + if (out_xs) *out_xs = cmems_t{0, nullptr, nullptr}; + return E_INSUFFICIENT; + } catch (...) { + if (out_xs) *out_xs = cmems_t{0, nullptr, nullptr}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_batch_decrypt_rsa_oaep_hsm(cbmpc_curve_id_t curve, cmem_t dk_handle, cmem_t ek, + cmem_t ciphertext, cmem_t label, + const cbmpc_pve_rsa_oaep_hsm_decap_t* cb, cmems_t* out_xs) { + try { + if (!out_xs) return E_BADARG; + *out_xs = cmems_t{0, nullptr, nullptr}; + if (!cb) return E_BADARG; + + const auto vdk = validate_cmem(dk_handle); + if (vdk) return vdk; + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vl = validate_cmem(label); + if (vl) return vl; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::api::pve::rsa_oaep_hsm_decap_cb_t cb_cpp; + cb_cpp.ctx = const_cast(cb); + cb_cpp.decap = rsa_oaep_hsm_decap_cpp; + + std::vector xs_cpp; + const coinbase::error_t rv = coinbase::api::pve::decrypt_batch_rsa_oaep_hsm( + curve_cpp, view_cmem(dk_handle), view_cmem(ek), view_cmem(ciphertext), view_cmem(label), cb_cpp, xs_cpp); + if (rv) return rv; + + return alloc_cmems_from_bufs(xs_cpp, out_xs); + } catch (const std::bad_alloc&) { + if (out_xs) *out_xs = cmems_t{0, nullptr, nullptr}; + return E_INSUFFICIENT; + } catch (...) { + if (out_xs) *out_xs = cmems_t{0, nullptr, nullptr}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_batch_decrypt_ecies_p256_hsm(cbmpc_curve_id_t curve, cmem_t dk_handle, cmem_t ek, + cmem_t ciphertext, cmem_t label, + const cbmpc_pve_ecies_p256_hsm_ecdh_t* cb, cmems_t* out_xs) { + try { + if (!out_xs) return E_BADARG; + *out_xs = cmems_t{0, nullptr, nullptr}; + if (!cb) return E_BADARG; + + const auto vdk = validate_cmem(dk_handle); + if (vdk) return vdk; + const auto vek = validate_cmem(ek); + if (vek) return vek; + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vl = validate_cmem(label); + if (vl) return vl; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::api::pve::ecies_p256_hsm_ecdh_cb_t cb_cpp; + cb_cpp.ctx = const_cast(cb); + cb_cpp.ecdh = ecies_p256_hsm_ecdh_cpp; + + std::vector xs_cpp; + const coinbase::error_t rv = coinbase::api::pve::decrypt_batch_ecies_p256_hsm( + curve_cpp, view_cmem(dk_handle), view_cmem(ek), view_cmem(ciphertext), view_cmem(label), cb_cpp, xs_cpp); + if (rv) return rv; + + return alloc_cmems_from_bufs(xs_cpp, out_xs); + } catch (const std::bad_alloc&) { + if (out_xs) *out_xs = cmems_t{0, nullptr, nullptr}; + return E_INSUFFICIENT; + } catch (...) { + if (out_xs) *out_xs = cmems_t{0, nullptr, nullptr}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_batch_get_count(cmem_t ciphertext, int* out_batch_count) { + try { + if (!out_batch_count) return E_BADARG; + *out_batch_count = 0; + + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + + int n = 0; + const coinbase::error_t rv = coinbase::api::pve::get_batch_count(view_cmem(ciphertext), n); + if (rv) return rv; + + *out_batch_count = n; + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_batch_count) *out_batch_count = 0; + return E_INSUFFICIENT; + } catch (...) { + if (out_batch_count) *out_batch_count = 0; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_batch_get_Qs(cmem_t ciphertext, cmems_t* out_Qs_compressed) { + try { + if (!out_Qs_compressed) return E_BADARG; + *out_Qs_compressed = cmems_t{0, nullptr, nullptr}; + + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + + std::vector Qs; + const coinbase::error_t rv = coinbase::api::pve::get_public_keys_compressed_batch(view_cmem(ciphertext), Qs); + if (rv) return rv; + + return alloc_cmems_from_bufs(Qs, out_Qs_compressed); + } catch (const std::bad_alloc&) { + if (out_Qs_compressed) *out_Qs_compressed = cmems_t{0, nullptr, nullptr}; + return E_INSUFFICIENT; + } catch (...) { + if (out_Qs_compressed) *out_Qs_compressed = cmems_t{0, nullptr, nullptr}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_pve_batch_get_Label(cmem_t ciphertext, cmem_t* out_label) { + try { + if (!out_label) return E_BADARG; + *out_label = cmem_t{nullptr, 0}; + + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + + coinbase::buf_t L; + const coinbase::error_t rv = coinbase::api::pve::get_Label_batch(view_cmem(ciphertext), L); + if (rv) return rv; + + return alloc_cmem_from_buf(L, out_label); + } catch (const std::bad_alloc&) { + if (out_label) *out_label = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_label) *out_label = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +} // extern "C" diff --git a/src/cbmpc/c_api/pve_internal.h b/src/cbmpc/c_api/pve_internal.h new file mode 100644 index 00000000..804ab67d --- /dev/null +++ b/src/cbmpc/c_api/pve_internal.h @@ -0,0 +1,111 @@ +#pragma once + +#include +#include +#include +#include + +namespace coinbase::capi::pve_detail { + +class c_base_pke_adapter_t final : public coinbase::api::pve::base_pke_i { + public: + explicit c_base_pke_adapter_t(const cbmpc_pve_base_pke_t* p) : p_(p) {} + + coinbase::error_t encrypt(coinbase::mem_t ek, coinbase::mem_t label, coinbase::mem_t plain, coinbase::mem_t rho, + coinbase::buf_t& out_ct) const override { + if (!p_ || !p_->encrypt) return E_BADARG; + + cmem_t out{nullptr, 0}; + const cbmpc_error_t rv = p_->encrypt( + p_->ctx, cmem_t{const_cast(ek.data), ek.size}, cmem_t{const_cast(label.data), label.size}, + cmem_t{const_cast(plain.data), plain.size}, cmem_t{const_cast(rho.data), rho.size}, &out); + if (rv) { + if (out.data) cbmpc_cmem_free(out); + return rv; + } + + if (out.size < 0 || (out.size > 0 && !out.data)) { + // Callback violated the ABI contract; do not attempt to read. + cbmpc_free(out.data); + return E_FORMAT; + } + out_ct = coinbase::buf_t(out.data, out.size); + cbmpc_cmem_free(out); + return CBMPC_SUCCESS; + } + + coinbase::error_t decrypt(coinbase::mem_t dk, coinbase::mem_t label, coinbase::mem_t ct, + coinbase::buf_t& out_plain) const override { + if (!p_ || !p_->decrypt) return E_BADARG; + + cmem_t out{nullptr, 0}; + const cbmpc_error_t rv = p_->decrypt(p_->ctx, cmem_t{const_cast(dk.data), dk.size}, + cmem_t{const_cast(label.data), label.size}, + cmem_t{const_cast(ct.data), ct.size}, &out); + if (rv) { + if (out.data) cbmpc_cmem_free(out); + return rv; + } + + if (out.size < 0 || (out.size > 0 && !out.data)) { + cbmpc_free(out.data); + return E_FORMAT; + } + out_plain = coinbase::buf_t(out.data, out.size); + cbmpc_cmem_free(out); + return CBMPC_SUCCESS; + } + + private: + const cbmpc_pve_base_pke_t* p_ = nullptr; +}; + +inline coinbase::error_t rsa_oaep_hsm_decap_cpp(void* ctx, coinbase::mem_t dk_handle, coinbase::mem_t kem_ct, + coinbase::buf_t& out_kem_ss) { + const auto* cb = static_cast(ctx); + if (!cb || !cb->decap) return E_BADARG; + + cmem_t kem_ss{nullptr, 0}; + const cbmpc_error_t rv = cb->decap(cb->ctx, cmem_t{const_cast(dk_handle.data), dk_handle.size}, + cmem_t{const_cast(kem_ct.data), kem_ct.size}, &kem_ss); + if (rv) { + if (kem_ss.data) cbmpc_cmem_free(kem_ss); + return rv; + } + + if (kem_ss.size < 0 || (kem_ss.size > 0 && !kem_ss.data)) { + cbmpc_free(kem_ss.data); + return E_FORMAT; + } + out_kem_ss = coinbase::buf_t(kem_ss.data, kem_ss.size); + cbmpc_cmem_free(kem_ss); + return CBMPC_SUCCESS; +} + +inline coinbase::error_t ecies_p256_hsm_ecdh_cpp(void* ctx, coinbase::mem_t dk_handle, coinbase::mem_t kem_ct, + coinbase::buf_t& out_dh_x32) { + const auto* cb = static_cast(ctx); + if (!cb || !cb->ecdh) return E_BADARG; + + cmem_t dh{nullptr, 0}; + const cbmpc_error_t rv = cb->ecdh(cb->ctx, cmem_t{const_cast(dk_handle.data), dk_handle.size}, + cmem_t{const_cast(kem_ct.data), kem_ct.size}, &dh); + if (rv) { + if (dh.data) cbmpc_cmem_free(dh); + return rv; + } + if (dh.size < 0 || (dh.size > 0 && !dh.data)) { + cbmpc_free(dh.data); + return E_FORMAT; + } + if (dh.size != 32) { + cbmpc_cmem_free(dh); + return E_CRYPTO; + } + + out_dh_x32 = coinbase::buf_t(dh.data, dh.size); + cbmpc_cmem_free(dh); + return CBMPC_SUCCESS; +} + +} // namespace coinbase::capi::pve_detail diff --git a/src/cbmpc/c_api/schnorr2pc.cpp b/src/cbmpc/c_api/schnorr2pc.cpp new file mode 100644 index 00000000..799dad84 --- /dev/null +++ b/src/cbmpc/c_api/schnorr2pc.cpp @@ -0,0 +1,239 @@ +#include + +#include +#include +#include +#include +#include +#include + +#include "util.h" + +using namespace coinbase::capi::detail; + +extern "C" { + +cbmpc_error_t cbmpc_schnorr_2p_dkg(const cbmpc_2pc_job_t* job, cbmpc_curve_id_t curve, cmem_t* out_key_blob) { + try { + if (!out_key_blob) return E_BADARG; + *out_key_blob = cmem_t{nullptr, 0}; + const auto vjob = validate_2pc_job(job); + if (vjob) return vjob; + + coinbase::api::party_2p_t self_cpp; + const auto pconv = to_cpp_party(job->self, self_cpp); + if (pconv) return pconv; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + job_2p_cpp_ctx_t ctx(job, self_cpp); + coinbase::buf_t key_blob; + const coinbase::error_t rv = coinbase::api::schnorr_2p::dkg(ctx.job, curve_cpp, key_blob); + if (rv) return rv; + + return alloc_cmem_from_buf(key_blob, out_key_blob); + } catch (const std::bad_alloc&) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_2p_refresh(const cbmpc_2pc_job_t* job, cmem_t key_blob, cmem_t* out_new_key_blob) { + try { + if (!out_new_key_blob) return E_BADARG; + *out_new_key_blob = cmem_t{nullptr, 0}; + const auto vjob = validate_2pc_job(job); + if (vjob) return vjob; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::api::party_2p_t self_cpp; + const auto pconv = to_cpp_party(job->self, self_cpp); + if (pconv) return pconv; + + job_2p_cpp_ctx_t ctx(job, self_cpp); + coinbase::buf_t new_key; + const coinbase::error_t rv = coinbase::api::schnorr_2p::refresh(ctx.job, view_cmem(key_blob), new_key); + if (rv) return rv; + + return alloc_cmem_from_buf(new_key, out_new_key_blob); + } catch (const std::bad_alloc&) { + if (out_new_key_blob) *out_new_key_blob = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_new_key_blob) *out_new_key_blob = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_2p_sign(const cbmpc_2pc_job_t* job, cmem_t key_blob, cmem_t msg, cmem_t* sig_out) { + try { + if (!sig_out) return E_BADARG; + *sig_out = cmem_t{nullptr, 0}; + const auto vjob = validate_2pc_job(job); + if (vjob) return vjob; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + const auto vm = validate_cmem(msg); + if (vm) return vm; + + coinbase::api::party_2p_t self_cpp; + const auto pconv = to_cpp_party(job->self, self_cpp); + if (pconv) return pconv; + + job_2p_cpp_ctx_t ctx(job, self_cpp); + coinbase::buf_t sig; + const coinbase::error_t rv = coinbase::api::schnorr_2p::sign(ctx.job, view_cmem(key_blob), view_cmem(msg), sig); + if (rv) return rv; + + return alloc_cmem_from_buf(sig, sig_out); + } catch (const std::bad_alloc&) { + if (sig_out) *sig_out = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (sig_out) *sig_out = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_2p_get_public_key_compressed(cmem_t key_blob, cmem_t* out_pub_key) { + try { + if (!out_pub_key) return E_BADARG; + *out_pub_key = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t pk; + const coinbase::error_t rv = coinbase::api::schnorr_2p::get_public_key_compressed(view_cmem(key_blob), pk); + if (rv) return rv; + + return alloc_cmem_from_buf(pk, out_pub_key); + } catch (const std::bad_alloc&) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_2p_extract_public_key_xonly(cmem_t key_blob, cmem_t* out_pub_key) { + try { + if (!out_pub_key) return E_BADARG; + *out_pub_key = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t pk; + const coinbase::error_t rv = coinbase::api::schnorr_2p::extract_public_key_xonly(view_cmem(key_blob), pk); + if (rv) return rv; + + return alloc_cmem_from_buf(pk, out_pub_key); + } catch (const std::bad_alloc&) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_2p_get_public_share_compressed(cmem_t key_blob, cmem_t* out_public_share) { + try { + if (!out_public_share) return E_BADARG; + *out_public_share = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t Qi; + const coinbase::error_t rv = coinbase::api::schnorr_2p::get_public_share_compressed(view_cmem(key_blob), Qi); + if (rv) return rv; + return alloc_cmem_from_buf(Qi, out_public_share); + } catch (const std::bad_alloc&) { + if (out_public_share) *out_public_share = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_public_share) *out_public_share = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_2p_detach_private_scalar(cmem_t key_blob, cmem_t* out_public_key_blob, + cmem_t* out_private_scalar_fixed) { + try { + if (!out_public_key_blob || !out_private_scalar_fixed) return E_BADARG; + *out_public_key_blob = cmem_t{nullptr, 0}; + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t public_blob; + coinbase::buf_t private_scalar_fixed; + const coinbase::error_t rv = + coinbase::api::schnorr_2p::detach_private_scalar(view_cmem(key_blob), public_blob, private_scalar_fixed); + if (rv) return rv; + + const auto r1 = alloc_cmem_from_buf(public_blob, out_public_key_blob); + if (r1) return r1; + const auto r2 = alloc_cmem_from_buf(private_scalar_fixed, out_private_scalar_fixed); + if (r2) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + return r2; + } + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_public_key_blob) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + } + if (out_private_scalar_fixed) { + cbmpc_cmem_free(*out_private_scalar_fixed); + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (out_public_key_blob) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + } + if (out_private_scalar_fixed) { + cbmpc_cmem_free(*out_private_scalar_fixed); + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_2p_attach_private_scalar(cmem_t public_key_blob, cmem_t private_scalar_fixed, + cmem_t public_share_compressed, cmem_t* out_key_blob) { + try { + if (!out_key_blob) return E_BADARG; + *out_key_blob = cmem_t{nullptr, 0}; + const auto vpb = validate_cmem(public_key_blob); + if (vpb) return vpb; + const auto vx = validate_cmem(private_scalar_fixed); + if (vx) return vx; + const auto vq = validate_cmem(public_share_compressed); + if (vq) return vq; + + coinbase::buf_t merged; + const coinbase::error_t rv = coinbase::api::schnorr_2p::attach_private_scalar( + view_cmem(public_key_blob), view_cmem(private_scalar_fixed), view_cmem(public_share_compressed), merged); + if (rv) return rv; + return alloc_cmem_from_buf(merged, out_key_blob); + } catch (const std::bad_alloc&) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +} // extern "C" diff --git a/src/cbmpc/c_api/schnorr_mp.cpp b/src/cbmpc/c_api/schnorr_mp.cpp new file mode 100644 index 00000000..006c6fcc --- /dev/null +++ b/src/cbmpc/c_api/schnorr_mp.cpp @@ -0,0 +1,461 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "access_structure_adapter.h" +#include "util.h" + +using namespace coinbase::capi::detail; + +extern "C" { + +cbmpc_error_t cbmpc_schnorr_mp_dkg_additive(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t* out_key_blob, + cmem_t* out_sid) { + try { + if (!out_key_blob || !out_sid) return E_BADARG; + *out_key_blob = cmem_t{nullptr, 0}; + *out_sid = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t key_blob; + coinbase::buf_t sid; + const coinbase::error_t rv = coinbase::api::schnorr_mp::dkg_additive(ctx.job, curve_cpp, key_blob, sid); + if (rv) return rv; + + const auto r_key = alloc_cmem_from_buf(key_blob, out_key_blob); + if (r_key) return r_key; + + const auto r_sid = alloc_cmem_from_buf(sid, out_sid); + if (r_sid) { + cbmpc_cmem_free(*out_key_blob); + *out_key_blob = cmem_t{nullptr, 0}; + return r_sid; + } + + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_key_blob) { + cbmpc_cmem_free(*out_key_blob); + *out_key_blob = cmem_t{nullptr, 0}; + } + if (out_sid) { + cbmpc_cmem_free(*out_sid); + *out_sid = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (out_key_blob) { + cbmpc_cmem_free(*out_key_blob); + *out_key_blob = cmem_t{nullptr, 0}; + } + if (out_sid) { + cbmpc_cmem_free(*out_sid); + *out_sid = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_mp_dkg_ac(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t sid_in, + const cbmpc_access_structure_t* access_structure, + const char* const* quorum_party_names, int quorum_party_names_count, + cmem_t* out_ac_key_blob, cmem_t* out_sid) { + try { + if (!out_ac_key_blob || !out_sid) return E_BADARG; + *out_ac_key_blob = cmem_t{nullptr, 0}; + *out_sid = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + const auto vsi = validate_cmem(sid_in); + if (vsi) return vsi; + + std::vector quorum_names; + const auto vqn = to_cpp_quorum_party_names(quorum_party_names, quorum_party_names_count, quorum_names); + if (vqn) return vqn; + + coinbase::api::access_structure_t ac_cpp; + const auto vac = to_cpp_access_structure(access_structure, ac_cpp); + if (vac) return vac; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sid(sid_in.data, sid_in.size); + coinbase::buf_t key_blob; + const coinbase::error_t rv = + coinbase::api::schnorr_mp::dkg_ac(ctx.job, curve_cpp, sid, ac_cpp, quorum_names, key_blob); + if (rv) return rv; + + const auto r_key = alloc_cmem_from_buf(key_blob, out_ac_key_blob); + if (r_key) return r_key; + + const auto r_sid = alloc_cmem_from_buf(sid, out_sid); + if (r_sid) { + cbmpc_cmem_free(*out_ac_key_blob); + *out_ac_key_blob = cmem_t{nullptr, 0}; + return r_sid; + } + + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_ac_key_blob) { + cbmpc_cmem_free(*out_ac_key_blob); + *out_ac_key_blob = cmem_t{nullptr, 0}; + } + if (out_sid) { + cbmpc_cmem_free(*out_sid); + *out_sid = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (out_ac_key_blob) { + cbmpc_cmem_free(*out_ac_key_blob); + *out_ac_key_blob = cmem_t{nullptr, 0}; + } + if (out_sid) { + cbmpc_cmem_free(*out_sid); + *out_sid = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_mp_refresh_additive(const cbmpc_mp_job_t* job, cmem_t sid_in, cmem_t key_blob, + cmem_t* sid_out, cmem_t* out_new_key_blob) { + try { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (!out_new_key_blob) return E_BADARG; + *out_new_key_blob = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + const auto vsi = validate_cmem(sid_in); + if (vsi) return vsi; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sid(sid_in.data, sid_in.size); + coinbase::buf_t new_key; + const coinbase::error_t rv = + coinbase::api::schnorr_mp::refresh_additive(ctx.job, sid, view_cmem(key_blob), new_key); + if (rv) return rv; + + const auto r_key = alloc_cmem_from_buf(new_key, out_new_key_blob); + if (r_key) return r_key; + + if (sid_out) { + const auto r_sid = alloc_cmem_from_buf(sid, sid_out); + if (r_sid) { + cbmpc_cmem_free(*out_new_key_blob); + *out_new_key_blob = cmem_t{nullptr, 0}; + return r_sid; + } + } + + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (out_new_key_blob) { + cbmpc_cmem_free(*out_new_key_blob); + *out_new_key_blob = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (out_new_key_blob) { + cbmpc_cmem_free(*out_new_key_blob); + *out_new_key_blob = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_mp_refresh_ac(const cbmpc_mp_job_t* job, cmem_t sid_in, cmem_t ac_key_blob, + const cbmpc_access_structure_t* access_structure, + const char* const* quorum_party_names, int quorum_party_names_count, + cmem_t* sid_out, cmem_t* out_new_ac_key_blob) { + try { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (!out_new_ac_key_blob) return E_BADARG; + *out_new_ac_key_blob = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + const auto vsi = validate_cmem(sid_in); + if (vsi) return vsi; + const auto vkb = validate_cmem(ac_key_blob); + if (vkb) return vkb; + + std::vector quorum_names; + const auto vqn = to_cpp_quorum_party_names(quorum_party_names, quorum_party_names_count, quorum_names); + if (vqn) return vqn; + + coinbase::api::access_structure_t ac_cpp; + const auto vac = to_cpp_access_structure(access_structure, ac_cpp); + if (vac) return vac; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sid(sid_in.data, sid_in.size); + coinbase::buf_t new_key; + const coinbase::error_t rv = + coinbase::api::schnorr_mp::refresh_ac(ctx.job, sid, view_cmem(ac_key_blob), ac_cpp, quorum_names, new_key); + if (rv) return rv; + + const auto r_key = alloc_cmem_from_buf(new_key, out_new_ac_key_blob); + if (r_key) return r_key; + + if (sid_out) { + const auto r_sid = alloc_cmem_from_buf(sid, sid_out); + if (r_sid) { + cbmpc_cmem_free(*out_new_ac_key_blob); + *out_new_ac_key_blob = cmem_t{nullptr, 0}; + return r_sid; + } + } + + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (out_new_ac_key_blob) { + cbmpc_cmem_free(*out_new_ac_key_blob); + *out_new_ac_key_blob = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (sid_out) *sid_out = cmem_t{nullptr, 0}; + if (out_new_ac_key_blob) { + cbmpc_cmem_free(*out_new_ac_key_blob); + *out_new_ac_key_blob = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_mp_sign_additive(const cbmpc_mp_job_t* job, cmem_t key_blob, cmem_t msg, + int32_t sig_receiver, cmem_t* sig_out) { + try { + if (!sig_out) return E_BADARG; + *sig_out = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + const auto vm = validate_cmem(msg); + if (vm) return vm; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sig; + const coinbase::error_t rv = + coinbase::api::schnorr_mp::sign_additive(ctx.job, view_cmem(key_blob), view_cmem(msg), sig_receiver, sig); + if (rv) return rv; + + return alloc_cmem_from_buf(sig, sig_out); + } catch (const std::bad_alloc&) { + if (sig_out) *sig_out = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (sig_out) *sig_out = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_mp_sign_ac(const cbmpc_mp_job_t* job, cmem_t ac_key_blob, + const cbmpc_access_structure_t* access_structure, cmem_t msg, + int32_t sig_receiver, cmem_t* sig_out) { + try { + if (!sig_out) return E_BADARG; + *sig_out = cmem_t{nullptr, 0}; + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + const auto vkb = validate_cmem(ac_key_blob); + if (vkb) return vkb; + const auto vm = validate_cmem(msg); + if (vm) return vm; + + coinbase::api::access_structure_t ac_cpp; + const auto vac = to_cpp_access_structure(access_structure, ac_cpp); + if (vac) return vac; + + job_mp_cpp_ctx_t ctx(job); + + coinbase::buf_t sig; + const coinbase::error_t rv = + coinbase::api::schnorr_mp::sign_ac(ctx.job, view_cmem(ac_key_blob), ac_cpp, view_cmem(msg), sig_receiver, sig); + if (rv) return rv; + + return alloc_cmem_from_buf(sig, sig_out); + } catch (const std::bad_alloc&) { + if (sig_out) *sig_out = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (sig_out) *sig_out = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_mp_get_public_key_compressed(cmem_t key_blob, cmem_t* out_pub_key) { + try { + if (!out_pub_key) return E_BADARG; + *out_pub_key = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t pk; + const coinbase::error_t rv = coinbase::api::schnorr_mp::get_public_key_compressed(view_cmem(key_blob), pk); + if (rv) return rv; + + return alloc_cmem_from_buf(pk, out_pub_key); + } catch (const std::bad_alloc&) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_mp_extract_public_key_xonly(cmem_t key_blob, cmem_t* out_pub_key) { + try { + if (!out_pub_key) return E_BADARG; + *out_pub_key = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t pk; + const coinbase::error_t rv = coinbase::api::schnorr_mp::extract_public_key_xonly(view_cmem(key_blob), pk); + if (rv) return rv; + + return alloc_cmem_from_buf(pk, out_pub_key); + } catch (const std::bad_alloc&) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_pub_key) *out_pub_key = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_mp_get_public_share_compressed(cmem_t key_blob, cmem_t* out_public_share) { + try { + if (!out_public_share) return E_BADARG; + *out_public_share = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t Qi; + const coinbase::error_t rv = coinbase::api::schnorr_mp::get_public_share_compressed(view_cmem(key_blob), Qi); + if (rv) return rv; + return alloc_cmem_from_buf(Qi, out_public_share); + } catch (const std::bad_alloc&) { + if (out_public_share) *out_public_share = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_public_share) *out_public_share = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_mp_detach_private_scalar(cmem_t key_blob, cmem_t* out_public_key_blob, + cmem_t* out_private_scalar_fixed) { + try { + if (!out_public_key_blob || !out_private_scalar_fixed) return E_BADARG; + *out_public_key_blob = cmem_t{nullptr, 0}; + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + const auto vkb = validate_cmem(key_blob); + if (vkb) return vkb; + + coinbase::buf_t public_blob; + coinbase::buf_t private_scalar_fixed; + const coinbase::error_t rv = + coinbase::api::schnorr_mp::detach_private_scalar(view_cmem(key_blob), public_blob, private_scalar_fixed); + if (rv) return rv; + + const auto r1 = alloc_cmem_from_buf(public_blob, out_public_key_blob); + if (r1) return r1; + const auto r2 = alloc_cmem_from_buf(private_scalar_fixed, out_private_scalar_fixed); + if (r2) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + return r2; + } + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + if (out_public_key_blob) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + } + if (out_private_scalar_fixed) { + cbmpc_cmem_free(*out_private_scalar_fixed); + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + } + return E_INSUFFICIENT; + } catch (...) { + if (out_public_key_blob) { + cbmpc_cmem_free(*out_public_key_blob); + *out_public_key_blob = cmem_t{nullptr, 0}; + } + if (out_private_scalar_fixed) { + cbmpc_cmem_free(*out_private_scalar_fixed); + *out_private_scalar_fixed = cmem_t{nullptr, 0}; + } + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_schnorr_mp_attach_private_scalar(cmem_t public_key_blob, cmem_t private_scalar_fixed, + cmem_t public_share_compressed, cmem_t* out_key_blob) { + try { + if (!out_key_blob) return E_BADARG; + *out_key_blob = cmem_t{nullptr, 0}; + const auto vpb = validate_cmem(public_key_blob); + if (vpb) return vpb; + const auto vx = validate_cmem(private_scalar_fixed); + if (vx) return vx; + const auto vq = validate_cmem(public_share_compressed); + if (vq) return vq; + + coinbase::buf_t merged; + const coinbase::error_t rv = coinbase::api::schnorr_mp::attach_private_scalar( + view_cmem(public_key_blob), view_cmem(private_scalar_fixed), view_cmem(public_share_compressed), merged); + if (rv) return rv; + return alloc_cmem_from_buf(merged, out_key_blob); + } catch (const std::bad_alloc&) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_key_blob) *out_key_blob = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +} // extern "C" diff --git a/src/cbmpc/c_api/tdh2.cpp b/src/cbmpc/c_api/tdh2.cpp new file mode 100644 index 00000000..4807e162 --- /dev/null +++ b/src/cbmpc/c_api/tdh2.cpp @@ -0,0 +1,341 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "access_structure_adapter.h" +#include "util.h" + +namespace { + +using namespace coinbase::capi::detail; + +struct tdh2_dkg_out_guard_t { + cmem_t* out_public_key = nullptr; + cmems_t* out_public_shares = nullptr; + cmem_t* out_private_share = nullptr; + cmem_t* out_sid = nullptr; + bool released = false; + + tdh2_dkg_out_guard_t(cmem_t* pk, cmems_t* pub, cmem_t* priv, cmem_t* sid) + : out_public_key(pk), out_public_shares(pub), out_private_share(priv), out_sid(sid) {} + + void release() { released = true; } + + void cleanup() const { + cbmpc_cmem_free(*out_public_key); + cbmpc_cmems_free(*out_public_shares); + cbmpc_cmem_free(*out_private_share); + cbmpc_cmem_free(*out_sid); + + *out_public_key = cmem_t{nullptr, 0}; + *out_public_shares = cmems_t{0, nullptr, nullptr}; + *out_private_share = cmem_t{nullptr, 0}; + *out_sid = cmem_t{nullptr, 0}; + } + + ~tdh2_dkg_out_guard_t() { + if (!released) cleanup(); + } +}; + +} // namespace + +extern "C" { + +cbmpc_error_t cbmpc_tdh2_dkg_additive(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t* out_public_key, + cmems_t* out_public_shares, cmem_t* out_private_share, cmem_t* out_sid) { + try { + if (!out_public_key || !out_public_shares || !out_private_share || !out_sid) return E_BADARG; + *out_public_key = cmem_t{nullptr, 0}; + *out_public_shares = cmems_t{0, nullptr, nullptr}; + *out_private_share = cmem_t{nullptr, 0}; + *out_sid = cmem_t{nullptr, 0}; + tdh2_dkg_out_guard_t out_guard(out_public_key, out_public_shares, out_private_share, out_sid); + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::capi::detail::job_mp_cpp_ctx_t job_ctx(job); + + coinbase::buf_t pk; + std::vector pub_shares; + coinbase::buf_t priv_share; + coinbase::buf_t sid; + + const coinbase::error_t rv = + coinbase::api::tdh2::dkg_additive(job_ctx.job, curve_cpp, pk, pub_shares, priv_share, sid); + if (rv) return rv; + + auto r1 = alloc_cmem_from_buf(pk, out_public_key); + if (r1) return r1; + auto r2 = alloc_cmems_from_bufs(pub_shares, out_public_shares); + if (r2) return r2; + auto r3 = alloc_cmem_from_buf(priv_share, out_private_share); + if (r3) return r3; + + auto r4 = alloc_cmem_from_buf(sid, out_sid); + if (r4) return r4; + + out_guard.release(); + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + return E_INSUFFICIENT; + } catch (...) { + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_tdh2_dkg_ac(const cbmpc_mp_job_t* job, cbmpc_curve_id_t curve, cmem_t sid_in, + const cbmpc_access_structure_t* access_structure, const char* const* quorum_party_names, + int quorum_party_names_count, cmem_t* out_public_key, cmems_t* out_public_shares, + cmem_t* out_private_share, cmem_t* out_sid) { + try { + if (!out_public_key || !out_public_shares || !out_private_share || !out_sid) return E_BADARG; + *out_public_key = cmem_t{nullptr, 0}; + *out_public_shares = cmems_t{0, nullptr, nullptr}; + *out_private_share = cmem_t{nullptr, 0}; + *out_sid = cmem_t{nullptr, 0}; + tdh2_dkg_out_guard_t out_guard(out_public_key, out_public_shares, out_private_share, out_sid); + + const auto vjob = validate_mp_job(job); + if (vjob) return vjob; + const auto vsid = validate_cmem(sid_in); + if (vsid) return vsid; + if (!access_structure) return E_BADARG; + + coinbase::api::curve_id curve_cpp; + const auto cconv = to_cpp_curve(curve, curve_cpp); + if (cconv) return cconv; + + coinbase::api::access_structure_t ac_cpp; + const auto ac_rv = to_cpp_access_structure(access_structure, ac_cpp); + if (ac_rv) return ac_rv; + + std::vector quorum_cpp; + const auto qrv = to_cpp_quorum_party_names(quorum_party_names, quorum_party_names_count, quorum_cpp); + if (qrv) return qrv; + + coinbase::capi::detail::job_mp_cpp_ctx_t job_ctx(job); + + coinbase::buf_t pk; + std::vector pub_shares; + coinbase::buf_t priv_share; + coinbase::buf_t sid(view_cmem(sid_in)); + + const coinbase::error_t rv = + coinbase::api::tdh2::dkg_ac(job_ctx.job, curve_cpp, sid, ac_cpp, quorum_cpp, pk, pub_shares, priv_share); + if (rv) return rv; + + auto r1 = alloc_cmem_from_buf(pk, out_public_key); + if (r1) return r1; + auto r2 = alloc_cmems_from_bufs(pub_shares, out_public_shares); + if (r2) return r2; + auto r3 = alloc_cmem_from_buf(priv_share, out_private_share); + if (r3) return r3; + auto r4 = alloc_cmem_from_buf(sid, out_sid); + if (r4) return r4; + + out_guard.release(); + return CBMPC_SUCCESS; + } catch (const std::bad_alloc&) { + return E_INSUFFICIENT; + } catch (...) { + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_tdh2_encrypt(cmem_t public_key, cmem_t plaintext, cmem_t label, cmem_t* out_ciphertext) { + try { + if (!out_ciphertext) return E_BADARG; + *out_ciphertext = cmem_t{nullptr, 0}; + const auto vpk = validate_cmem(public_key); + if (vpk) return vpk; + const auto vpt = validate_cmem(plaintext); + if (vpt) return vpt; + const auto vlb = validate_cmem(label); + if (vlb) return vlb; + + coinbase::buf_t ct; + const coinbase::error_t rv = + coinbase::api::tdh2::encrypt(view_cmem(public_key), view_cmem(plaintext), view_cmem(label), ct); + if (rv) return rv; + return alloc_cmem_from_buf(ct, out_ciphertext); + } catch (const std::bad_alloc&) { + if (out_ciphertext) *out_ciphertext = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_ciphertext) *out_ciphertext = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_tdh2_verify(cmem_t public_key, cmem_t ciphertext, cmem_t label) { + try { + const auto vpk = validate_cmem(public_key); + if (vpk) return vpk; + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vlb = validate_cmem(label); + if (vlb) return vlb; + return coinbase::api::tdh2::verify(view_cmem(public_key), view_cmem(ciphertext), view_cmem(label)); + } catch (const std::bad_alloc&) { + return E_INSUFFICIENT; + } catch (...) { + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_tdh2_partial_decrypt(cmem_t private_share, cmem_t ciphertext, cmem_t label, + cmem_t* out_partial_decryption) { + try { + if (!out_partial_decryption) return E_BADARG; + *out_partial_decryption = cmem_t{nullptr, 0}; + const auto vps = validate_cmem(private_share); + if (vps) return vps; + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + const auto vlb = validate_cmem(label); + if (vlb) return vlb; + + coinbase::buf_t partial; + const coinbase::error_t rv = coinbase::api::tdh2::partial_decrypt(view_cmem(private_share), view_cmem(ciphertext), + view_cmem(label), partial); + if (rv) return rv; + return alloc_cmem_from_buf(partial, out_partial_decryption); + } catch (const std::bad_alloc&) { + if (out_partial_decryption) *out_partial_decryption = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_partial_decryption) *out_partial_decryption = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_tdh2_combine_additive(cmem_t public_key, cmems_t public_shares, cmem_t label, + cmems_t partial_decryptions, cmem_t ciphertext, cmem_t* out_plaintext) { + try { + if (!out_plaintext) return E_BADARG; + *out_plaintext = cmem_t{nullptr, 0}; + const auto vpk = validate_cmem(public_key); + if (vpk) return vpk; + const auto vlb = validate_cmem(label); + if (vlb) return vlb; + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + + std::vector pub_shares; + std::vector partials; + auto rv = view_cmems(public_shares, pub_shares); + if (rv) return rv; + rv = view_cmems(partial_decryptions, partials); + if (rv) return rv; + + coinbase::buf_t plain; + const coinbase::error_t crv = coinbase::api::tdh2::combine_additive( + view_cmem(public_key), pub_shares, view_cmem(label), partials, view_cmem(ciphertext), plain); + if (crv) return crv; + + return alloc_cmem_from_buf(plain, out_plaintext); + } catch (const std::bad_alloc&) { + if (out_plaintext) *out_plaintext = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_plaintext) *out_plaintext = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +cbmpc_error_t cbmpc_tdh2_combine_ac(const cbmpc_access_structure_t* access_structure, cmem_t public_key, + const char* const* party_names, int party_names_count, cmems_t public_shares, + cmem_t label, const char* const* partial_decryption_party_names, + int partial_decryption_party_names_count, cmems_t partial_decryptions, + cmem_t ciphertext, cmem_t* out_plaintext) { + try { + if (!out_plaintext) return E_BADARG; + *out_plaintext = cmem_t{nullptr, 0}; + if (!access_structure) return E_BADARG; + + const auto vpk = validate_cmem(public_key); + if (vpk) return vpk; + const auto vlb = validate_cmem(label); + if (vlb) return vlb; + const auto vct = validate_cmem(ciphertext); + if (vct) return vct; + + // Convert access structure. + coinbase::api::access_structure_t ac_cpp; + const auto ac_rv = to_cpp_access_structure(access_structure, ac_cpp); + if (ac_rv) return ac_rv; + + // Convert party names. + std::vector party_names_cpp; + { + if (party_names_count < 0) return E_BADARG; + if (party_names_count == 0) return E_BADARG; + if (!party_names) return E_BADARG; + party_names_cpp.reserve(static_cast(party_names_count)); + for (int i = 0; i < party_names_count; i++) { + const char* s = party_names[i]; + if (!s) return E_BADARG; + if (s[0] == '\0') return E_BADARG; + party_names_cpp.emplace_back(s); + } + } + + // Convert public shares. + std::vector public_shares_cpp; + auto rv = view_cmems(public_shares, public_shares_cpp); + if (rv) return rv; + if (public_shares_cpp.size() != static_cast(party_names_count)) return E_BADARG; + + // Convert partial decryption party names. + std::vector partial_names_cpp; + { + if (partial_decryption_party_names_count < 0) return E_BADARG; + if (partial_decryption_party_names_count == 0) return E_BADARG; + if (!partial_decryption_party_names) return E_BADARG; + partial_names_cpp.reserve(static_cast(partial_decryption_party_names_count)); + for (int i = 0; i < partial_decryption_party_names_count; i++) { + const char* s = partial_decryption_party_names[i]; + if (!s) return E_BADARG; + if (s[0] == '\0') return E_BADARG; + partial_names_cpp.emplace_back(s); + } + } + + // Convert partial decryptions. + std::vector partials_cpp; + rv = view_cmems(partial_decryptions, partials_cpp); + if (rv) return rv; + if (partials_cpp.size() != static_cast(partial_decryption_party_names_count)) return E_BADARG; + + coinbase::buf_t plain; + const coinbase::error_t crv = coinbase::api::tdh2::combine_ac( + ac_cpp, view_cmem(public_key), party_names_cpp, public_shares_cpp, view_cmem(label), partial_names_cpp, + partials_cpp, view_cmem(ciphertext), plain); + if (crv) return crv; + + return alloc_cmem_from_buf(plain, out_plaintext); + } catch (const std::bad_alloc&) { + if (out_plaintext) *out_plaintext = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } catch (...) { + if (out_plaintext) *out_plaintext = cmem_t{nullptr, 0}; + return E_GENERAL; + } +} + +} // extern "C" diff --git a/src/cbmpc/c_api/transport_adapter.h b/src/cbmpc/c_api/transport_adapter.h new file mode 100644 index 00000000..d5c54dfe --- /dev/null +++ b/src/cbmpc/c_api/transport_adapter.h @@ -0,0 +1,127 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace coinbase::capi::detail { + +inline void transport_free(const cbmpc_transport_t* t, void* ptr) { + if (!ptr) return; + if (t && t->free) { + t->free(t->ctx, ptr); + } else { + cbmpc_free(ptr); + } +} + +class c_transport_adapter_t final : public coinbase::api::data_transport_i { + public: + explicit c_transport_adapter_t(const cbmpc_transport_t* t) : t_(t) {} + + coinbase::error_t send(coinbase::api::party_idx_t receiver, coinbase::mem_t msg) override { + if (!t_ || !t_->send) return E_BADARG; + return t_->send(t_->ctx, receiver, msg.data, msg.size); + } + + coinbase::error_t receive(coinbase::api::party_idx_t sender, coinbase::buf_t& msg) override { + if (!t_ || !t_->receive) return E_BADARG; + + cmem_t in{nullptr, 0}; + const coinbase::error_t rv = t_->receive(t_->ctx, sender, &in); + if (rv) { + // Best-effort cleanup: integrators may allocate output buffers before + // returning an error. + transport_free(t_, in.data); + return rv; + } + + if (in.size < 0 || (in.size > 0 && !in.data)) { + transport_free(t_, in.data); + return E_FORMAT; + } + msg = (in.size == 0) ? coinbase::buf_t() : coinbase::buf_t(in.data, in.size); + transport_free(t_, in.data); + return SUCCESS; + } + + coinbase::error_t receive_all(const std::vector& senders, + std::vector& msgs) override { + if (!t_ || !t_->receive_all) return E_NOT_SUPPORTED; + if (senders.size() > static_cast(INT_MAX)) return E_RANGE; + + std::vector senders_i32; + senders_i32.reserve(senders.size()); + for (auto s : senders) senders_i32.push_back(static_cast(s)); + + cmems_t out{0, nullptr, nullptr}; + const coinbase::error_t rv = + t_->receive_all(t_->ctx, senders_i32.data(), static_cast(senders_i32.size()), &out); + if (rv) { + // Best-effort cleanup: integrators may allocate output buffers before + // returning an error. + transport_free(t_, out.data); + if (out.sizes && reinterpret_cast(out.sizes) != out.data) transport_free(t_, out.sizes); + return rv; + } + + if (out.count < 0 || out.count != static_cast(senders.size())) { + transport_free(t_, out.data); + transport_free(t_, out.sizes); + return E_FORMAT; + } + if (out.count > 0 && !out.sizes) { + transport_free(t_, out.data); + return E_FORMAT; + } + + int total = 0; + for (int i = 0; i < out.count; i++) { + const int sz = out.sizes[i]; + if (sz < 0) { + transport_free(t_, out.data); + transport_free(t_, out.sizes); + return E_FORMAT; + } + if (sz > INT_MAX - total) { + transport_free(t_, out.data); + transport_free(t_, out.sizes); + return E_RANGE; + } + total += sz; + } + if (total > 0 && !out.data) { + transport_free(t_, out.sizes); + return E_FORMAT; + } + + msgs.clear(); + msgs.reserve(static_cast(out.count)); + + int offset = 0; + for (int i = 0; i < out.count; i++) { + const int sz = out.sizes[i]; + if (sz == 0) { + msgs.emplace_back(); + continue; + } + msgs.emplace_back(out.data + offset, sz); + offset += sz; + } + + transport_free(t_, out.data); + transport_free(t_, out.sizes); + return SUCCESS; + } + + private: + const cbmpc_transport_t* t_; +}; + +} // namespace coinbase::capi::detail diff --git a/src/cbmpc/c_api/util.h b/src/cbmpc/c_api/util.h new file mode 100644 index 00000000..451a7c98 --- /dev/null +++ b/src/cbmpc/c_api/util.h @@ -0,0 +1,206 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "transport_adapter.h" + +namespace coinbase::capi::detail { + +inline cbmpc_error_t validate_cmem(cmem_t m) { + if (m.size < 0) return E_BADARG; + if (m.size > 0 && !m.data) return E_BADARG; + return CBMPC_SUCCESS; +} + +inline coinbase::mem_t view_cmem(cmem_t m) { return coinbase::mem_t(m.data, m.size); } + +inline cbmpc_error_t validate_2pc_job(const cbmpc_2pc_job_t* job) { + if (!job) return E_BADARG; + if (!job->p1_name || !job->p2_name) return E_BADARG; + if (job->p1_name[0] == '\0' || job->p2_name[0] == '\0') return E_BADARG; + if (std::strcmp(job->p1_name, job->p2_name) == 0) return E_BADARG; + if (!job->transport || !job->transport->send || !job->transport->receive) return E_BADARG; + return CBMPC_SUCCESS; +} + +inline cbmpc_error_t validate_mp_job(const cbmpc_mp_job_t* job) { + if (!job) return E_BADARG; + if (job->party_names_count < 0) return E_BADARG; + // Match the public C++ API contract: MP jobs require at least 2 parties. + if (job->party_names_count < 2) return E_BADARG; + if (job->party_names_count > 64) return E_RANGE; + if (job->self < 0 || job->self >= job->party_names_count) return E_BADARG; + + if (!job->party_names) return E_BADARG; + std::unordered_set names; + names.reserve(static_cast(job->party_names_count)); + for (int i = 0; i < job->party_names_count; i++) { + const char* name = job->party_names[i]; + if (!name) return E_BADARG; + if (name[0] == '\0') return E_BADARG; + if (!names.insert(std::string_view(name)).second) return E_BADARG; // duplicate + } + if (!job->transport || !job->transport->send || !job->transport->receive || !job->transport->receive_all) + return E_BADARG; + return CBMPC_SUCCESS; +} + +inline cbmpc_error_t to_cpp_party(cbmpc_2pc_party_t p, coinbase::api::party_2p_t& out) { + switch (p) { + case CBMPC_2PC_P1: + out = coinbase::api::party_2p_t::p1; + return CBMPC_SUCCESS; + case CBMPC_2PC_P2: + out = coinbase::api::party_2p_t::p2; + return CBMPC_SUCCESS; + } + return E_BADARG; +} + +inline cbmpc_error_t to_cpp_curve(cbmpc_curve_id_t c, coinbase::api::curve_id& out) { + switch (c) { + case CBMPC_CURVE_P256: + out = coinbase::api::curve_id::p256; + return CBMPC_SUCCESS; + case CBMPC_CURVE_SECP256K1: + out = coinbase::api::curve_id::secp256k1; + return CBMPC_SUCCESS; + case CBMPC_CURVE_ED25519: + out = coinbase::api::curve_id::ed25519; + return CBMPC_SUCCESS; + } + return E_BADARG; +} + +struct job_2p_cpp_ctx_t { + c_transport_adapter_t transport; + coinbase::api::job_2p_t job; + + job_2p_cpp_ctx_t(const cbmpc_2pc_job_t* job_c, coinbase::api::party_2p_t self) + : transport(job_c->transport), job{self, job_c->p1_name, job_c->p2_name, transport} {} +}; + +struct job_mp_cpp_ctx_t { + c_transport_adapter_t transport; + coinbase::api::job_mp_t job; + + explicit job_mp_cpp_ctx_t(const cbmpc_mp_job_t* job_c) + : transport(job_c->transport), job{job_c->self, make_party_names(job_c), transport} {} + + private: + static std::vector make_party_names(const cbmpc_mp_job_t* job_c) { + std::vector out; + out.reserve(static_cast(job_c->party_names_count)); + for (int i = 0; i < job_c->party_names_count; i++) out.emplace_back(job_c->party_names[i]); + return out; + } +}; + +inline cbmpc_error_t alloc_cmem_from_buf(const coinbase::buf_t& buf, cmem_t* out) { + if (!out) return E_BADARG; + out->data = nullptr; + out->size = 0; + + const int n = buf.size(); + if (n < 0) return E_FORMAT; + if (n == 0) return CBMPC_SUCCESS; + + uint8_t* data = static_cast(cbmpc_malloc(static_cast(n))); + if (!data) return E_INSUFFICIENT; + std::memmove(data, buf.data(), static_cast(n)); + out->data = data; + out->size = n; + return CBMPC_SUCCESS; +} + +inline cbmpc_error_t alloc_cmems_from_bufs(const std::vector& bufs, cmems_t* out) { + if (!out) return E_BADARG; + out->count = 0; + out->data = nullptr; + out->sizes = nullptr; + if (bufs.empty()) return CBMPC_SUCCESS; + + if (bufs.size() > static_cast(INT_MAX)) return E_RANGE; + + int total = 0; + for (const auto& b : bufs) { + const int sz = b.size(); + if (sz < 0) return E_FORMAT; + if (sz > INT_MAX - total) return E_RANGE; + total += sz; + } + + const int count = static_cast(bufs.size()); + out->count = count; + + out->sizes = static_cast(cbmpc_malloc(sizeof(int) * static_cast(count))); + if (!out->sizes) { + *out = cmems_t{0, nullptr, nullptr}; + return E_INSUFFICIENT; + } + + if (total > 0) { + out->data = static_cast(cbmpc_malloc(static_cast(total))); + if (!out->data) { + cbmpc_free(out->sizes); + *out = cmems_t{0, nullptr, nullptr}; + return E_INSUFFICIENT; + } + } + + int offset = 0; + for (int i = 0; i < count; i++) { + const int sz = bufs[i].size(); + out->sizes[i] = sz; + if (sz) { + std::memmove(out->data + offset, bufs[i].data(), static_cast(sz)); + offset += sz; + } + } + + return CBMPC_SUCCESS; +} + +inline cbmpc_error_t view_cmems(cmems_t in, std::vector& out) { + out.clear(); + if (in.count < 0) return E_BADARG; + if (in.count == 0) return CBMPC_SUCCESS; + if (!in.sizes) return E_BADARG; + + int total = 0; + for (int i = 0; i < in.count; i++) { + const int sz = in.sizes[i]; + if (sz < 0) return E_BADARG; + if (sz > INT_MAX - total) return E_RANGE; + total += sz; + } + if (total > 0 && !in.data) return E_BADARG; + + out.reserve(static_cast(in.count)); + int offset = 0; + for (int i = 0; i < in.count; i++) { + const int sz = in.sizes[i]; + if (sz == 0) { + out.emplace_back(nullptr, 0); + continue; + } + out.emplace_back(in.data + offset, sz); + offset += sz; + } + return CBMPC_SUCCESS; +} + +} // namespace coinbase::capi::detail diff --git a/src/cbmpc/core/CMakeLists.txt b/src/cbmpc/core/CMakeLists.txt index de4d98db..d74ab7a2 100755 --- a/src/cbmpc/core/CMakeLists.txt +++ b/src/cbmpc/core/CMakeLists.txt @@ -3,7 +3,7 @@ add_library(cbmpc_core OBJECT "") # Link OpenSSL before precompiled headers to ensure headers are available link_openssl(cbmpc_core) -target_precompile_headers(cbmpc_core PUBLIC "precompiled.h") +target_precompile_headers(cbmpc_core PUBLIC "${ROOT_DIR}/include/cbmpc/core/precompiled.h") target_sources(cbmpc_core PRIVATE buf.cpp diff --git a/src/cbmpc/core/buf.cpp b/src/cbmpc/core/buf.cpp index a9343b82..84e87788 100644 --- a/src/cbmpc/core/buf.cpp +++ b/src/cbmpc/core/buf.cpp @@ -1,13 +1,16 @@ -#include "buf.h" +#include -#include -#include +#include +#include +#include namespace coinbase { buf_t::buf_t() noexcept(true) : s(0) { static_assert(sizeof(buf_t) == 40, "Invalid buf_t size."); } buf_t::buf_t(int new_size) : s(new_size) { // NOLINT(*init*) + // NOTE: `buf_t(int)` intentionally leaves the buffer contents uninitialized. + // Callers must fully overwrite `size()` bytes before reading from `data()`. if (new_size > short_size) set_long_ptr(new byte_t[new_size]); } @@ -267,22 +270,24 @@ void buf_t::reverse() { mem_t(*this).reverse(); } std::string buf_t::to_string() const { return std::string(const_char_ptr(data()), s); } -byte_ptr buf_t::get_long_ptr() const { return ((byte_ptr*)m)[0]; } +byte_ptr buf_t::get_long_ptr() const { + byte_ptr ptr; + std::memcpy(&ptr, m, sizeof(ptr)); + return ptr; +} -void buf_t::set_long_ptr(byte_ptr ptr) { ((byte_ptr*)m)[0] = ptr; } +void buf_t::set_long_ptr(byte_ptr ptr) { std::memcpy(m, &ptr, sizeof(ptr)); } void buf_t::assign_short(const_byte_ptr src, int src_size) { for (int i = 0; i < src_size; i++) m[i] = src[i]; s = src_size; } -/** - * @notes: - * - Even though the size of m is 36, the next 4 bytes is used to store the size of the buffer (`s`). - * - Therefore, override ((uint64_t*)m)[4] is safe, even though it is seemingly out of bounds. - */ void buf_t::assign_short(const buf_t& src) { - for (int i = 0; i < 5; i++) ((uint64_t*)m)[i] = ((uint64_t*)src.m)[i]; + // Copy the entire short inline storage and the size. Keep this UB-free: `m` is byte-aligned, + // so type-punning via `(uint64_t*)` / `(byte_ptr*)` can violate strict-aliasing and alignment. + std::memcpy(m, src.m, short_size); + s = src.s; } void buf_t::assign_long_ptr(byte_ptr ptr, int size) { @@ -345,7 +350,11 @@ void buf_t::convert_last(converter_t& converter) { if (!converter.is_calc_size()) memmove(converter.current(), data(), size()); } else { if (converter.is_error()) return; - int s = converter.get_size() - converter.get_offset(); + const int s = converter.get_size() - converter.get_offset(); + if (s < 0 || !converter.at_least(s)) { + converter.set_error(); + return; + } memmove(alloc(s), converter.current(), s); } converter.forward(size()); @@ -359,10 +368,11 @@ void memmove_reverse(byte_ptr dst, const_byte_ptr src, int size) { void mem_t::reverse() { int l = 0; int r = size - 1; + byte_ptr p = const_cast(data); while (l < r) { - uint8_t t = data[l]; - data[l] = data[r]; - data[r] = t; + uint8_t t = p[l]; + p[l] = p[r]; + p[r] = t; l++; r--; } @@ -392,7 +402,9 @@ size_t mem_t::non_crypto_hash() const { uint32_t x = 1; while (n >= 4) { - x ^= *(const uint32_t*)p; + uint32_t chunk; + std::memcpy(&chunk, p, sizeof(chunk)); + x ^= chunk; x ^= x << 13; x ^= x >> 17; x ^= x << 5; @@ -412,7 +424,10 @@ size_t mem_t::non_crypto_hash() const { return x; } -std::string mem_t::to_string() const { return std::string(const_char_ptr(data), size); } +std::string mem_t::to_string() const { + if (size <= 0 || !data) return ""; + return std::string(const_char_ptr(data), static_cast(size)); +} // ------------------------- bits_t --------------------- @@ -476,9 +491,16 @@ void bits_t::free() { void bits_t::copy_from(const bits_t& src) { if (&src == this) return; + if (src.bits == 0) { + free(); + return; + } + alloc(src.bits); int n = bits_to_limbs(bits); + cb_assert(n > 0); + cb_assert(data && src.data); memmove(data, src.data, n * sizeof(limb_t)); } @@ -541,10 +563,17 @@ bits_t bits_t::from_bin(mem_t src) { } void bits_t::resize(int count) { + cb_assert(count >= 0 && "bits_t::resize: count must be non-negative"); + if (count == bits) return; + + // If we're growing, ensure the currently-unused tail bits are cleared before they become "visible". + if (count > bits && bits > 0) bzero_unused(); + int n_old = bits_to_limbs(bits); int n_new = bits_to_limbs(count); if (n_old == n_new) { bits = count; + if (bits > 0) bzero_unused(); return; } @@ -554,11 +583,14 @@ void bits_t::resize(int count) { } limb_t* old_data = data; - data = new limb_t[n_new]; bits = count; + data = new limb_t[n_new]; + memset(data, 0, n_new * sizeof(limb_t)); + int n_copy = std::min(n_old, n_new); if (n_copy) memmove(data, old_data, n_copy * sizeof(limb_t)); + bzero_unused(); if (n_old) { secure_bzero((byte_ptr)old_data, n_old * int(sizeof(limb_t))); @@ -594,6 +626,8 @@ void bits_t::set(int index, bool value) { } void bits_t::append(bool value) { + cb_assert(bits >= 0 && "bits_t::append: invalid bit count"); + cb_assert(bits < INT_MAX && "bits_t::append: bit count overflow"); resize(bits + 1); set(bits - 1, value); } @@ -611,21 +645,6 @@ bits_t::ref_t bits_t::operator[](int index) { return ref_t(data, index); } -bool bits_t::equ(const bits_t& src1, const bits_t& src2) { - if (src1.bits != src2.bits) return false; - - int n = src1.bits / 64; - if (n > 0) { - if (0 != memcmp(src1.data, src2.data, n * sizeof(uint64_t))) return false; - } - - for (int i = n * 64; i < src1.bits; i++) { - if (src1[i] != src2[i]) return false; - } - - return true; -} - bits_t& bits_t::operator^=(const bits_t& src) { cb_assert(src.bits == bits); int n = bits_to_limbs(bits); @@ -648,14 +667,32 @@ bits_t operator^(const bits_t& src1, const bits_t& src2) { bits_t& bits_t::operator+=(const bits_t& src2) { int n1 = count(); int n2 = src2.count(); + cb_assert(n1 >= 0 && n2 >= 0); + cb_assert(n2 <= INT_MAX - n1 && "bits_t::operator+=: size overflow"); + const int new_count = n1 + n2; + + // Special-case self-append: `to_bin()` returns a view into `data`, and `resize()` may reallocate + free the old + // buffer, turning that view into a dangling pointer. This can trigger UAF for `x += x` on the byte-aligned fast path. + if (&src2 == this) { + resize(new_count); + if ((n1 % 8) == 0) { + const int bytes = bits_to_bytes(n1); + // Append the original bytes by copying within the resized buffer. + memmove(byte_ptr(data) + bytes, data, bytes); + } else { + // Bit-level append (no overlap: read from [0..n1), write to [n1..n1+n2)). + for (int i = 0; i < n2; i++) (*this)[n1 + i] = (*this)[i]; + } + return *this; + } mem_t src1_mem = to_bin(); mem_t src2_mem = src2.to_bin(); - resize(n1 + n2); - if ((n1 % 8) == 0) + resize(new_count); + if ((n1 % 8) == 0) { memmove(byte_ptr(data) + src1_mem.size, src2_mem.data, src2_mem.size); - else { + } else { for (int i = 0; i < n2; i++) (*this)[n1 + i] = src2[i]; } return *this; @@ -664,7 +701,10 @@ bits_t& bits_t::operator+=(const bits_t& src2) { bits_t bits_t::operator+(const bits_t& src2) const { int n1 = count(); int n2 = src2.count(); - bits_t dst(n1 + n2); + cb_assert(n1 >= 0 && n2 >= 0); + cb_assert(n2 <= INT_MAX - n1 && "bits_t::operator+: size overflow"); + const int new_count = n1 + n2; + bits_t dst(new_count); mem_t dst_mem = dst.to_bin(); mem_t src1_mem = to_bin(); @@ -705,7 +745,20 @@ std::vector buf_t::from_mems(const std::vector& in) { } // namespace coinbase +namespace coinbase { + std::ostream& operator<<(std::ostream& os, mem_t mem) { + // NOTE: `mem_t` is frequently used to carry opaque blobs, including secrets (key shares, plaintexts, etc.). + // Dumping full hex in production is a common source of accidental secret leakage via logs. + // + // In Debug builds, keep the full hex dump for developer ergonomics. + // In non-Debug builds, redact content and print only the size. +#ifdef _DEBUG os << strext::to_hex(mem); +#else + os << ""; +#endif return os; } + +} // namespace coinbase diff --git a/src/cbmpc/core/buf128.cpp b/src/cbmpc/core/buf128.cpp index 9d497dee..b40af329 100644 --- a/src/cbmpc/core/buf128.cpp +++ b/src/cbmpc/core/buf128.cpp @@ -1,4 +1,4 @@ -#include +#include namespace coinbase { @@ -89,6 +89,8 @@ buf128_t& buf128_t::operator=(mem_t src) { return *this = load(src.data); } +buf128_t::operator mem_t() const { return mem_t(byte_ptr(this), sizeof(buf128_t)); } + uint64_t buf128_t::lo() const { return u128_lo(value); } uint64_t buf128_t::hi() const { return u128_hi(value); } buf128_t buf128_t::load(const_byte_ptr src) noexcept(true) { return u128(u128_load(src)); } @@ -111,7 +113,8 @@ bool buf128_t::get_bit(int index) const { cb_assert(index >= 0 && index < 128); int n = index / 64; index %= 64; - return ((((const uint64_t*)(this))[n] >> index) & 1) != 0; + const uint64_t limb = (n == 0) ? lo() : hi(); + return ((limb >> index) & 1) != 0; } void buf128_t::set_bit(int index, bool bit) { @@ -181,6 +184,7 @@ buf128_t::reverse_bytes() const { buf128_t buf128_t::operator<<(unsigned n) const { cb_assert(n < 128); + if (n == 0) return *this; uint64_t l = lo(); uint64_t r = hi(); if (n == 64) { @@ -199,6 +203,7 @@ buf128_t buf128_t::operator<<(unsigned n) const { buf128_t buf128_t::operator>>(unsigned n) const { cb_assert(n < 128); + if (n == 0) return *this; uint64_t l = lo(); uint64_t r = hi(); if (n == 64) { diff --git a/src/cbmpc/core/buf256.cpp b/src/cbmpc/core/buf256.cpp old mode 100755 new mode 100644 index 117c391d..e2298179 --- a/src/cbmpc/core/buf256.cpp +++ b/src/cbmpc/core/buf256.cpp @@ -1,4 +1,4 @@ -#include +#include namespace coinbase { @@ -7,6 +7,8 @@ buf256_t& buf256_t::operator=(mem_t src) { return *this = load(src.data); } +buf256_t::operator mem_t() const { return mem_t(byte_ptr(this), sizeof(buf256_t)); } + buf256_t& buf256_t::operator=(const buf_t& src) { cb_assert(src.size() == sizeof(buf256_t)); return *this = load(src.data()); @@ -46,21 +48,16 @@ void buf256_t::save(byte_ptr dst) const { bool buf256_t::get_bit(int index) const { cb_assert(index >= 0 && index < 256); - int n = index / 64; - index %= 64; - return ((((const uint64_t*)(this))[n] >> index) & 1) != 0; + if (index < 128) return lo.get_bit(index); + return hi.get_bit(index - 128); } void buf256_t::set_bit(int index, bool value) { cb_assert(index >= 0 && index < 256); - int n = index / 64; - index %= 64; - uint64_t mask = uint64_t(1) << index; - - if (value) - ((uint64_t*)(this))[n] |= mask; + if (index < 128) + lo.set_bit(index, value); else - ((uint64_t*)(this))[n] &= ~mask; + hi.set_bit(index - 128, value); } bool buf256_t::operator==(const buf256_t& src) const { return ((src.lo ^ lo) | (src.hi ^ hi)) == ZERO128; } @@ -149,6 +146,7 @@ void buf256_t::convert(coinbase::converter_t& converter) { buf256_t buf256_t::operator<<(unsigned n) const { cb_assert(n < 256); + if (n == 0) return *this; buf128_t l = lo; buf128_t r = hi; if (n == 128) { @@ -167,6 +165,7 @@ buf256_t buf256_t::operator<<(unsigned n) const { buf256_t buf256_t::operator>>(unsigned n) const { cb_assert(n < 256); + if (n == 0) return *this; buf128_t l = lo; buf128_t r = hi; if (n == 128) { @@ -236,7 +235,7 @@ buf256_t::caryless_mul(buf128_t a, buf128_t b) { buf256_t m = buf256_t::make(a, ZERO128); for (int i = 0; i < 128; i++) { - if (b.get_bit(i)) r ^= m; + r ^= m & b.get_bit(i); m <<= 1; } diff --git a/src/cbmpc/core/cmem.h b/src/cbmpc/core/cmem.h deleted file mode 100644 index d58c74b5..00000000 --- a/src/cbmpc/core/cmem.h +++ /dev/null @@ -1,21 +0,0 @@ -#pragma once -#include - -#ifdef __cplusplus -extern "C" { -#endif - -typedef struct tag_cmem_t { - uint8_t* data; - int size; -} cmem_t; - -typedef struct tag_cmems_t { - int count; - uint8_t* data; - int* sizes; -} cmems_t; - -#ifdef __cplusplus -} -#endif diff --git a/src/cbmpc/core/convert.cpp b/src/cbmpc/core/convert.cpp index b3301f08..61d20205 100644 --- a/src/cbmpc/core/convert.cpp +++ b/src/cbmpc/core/convert.cpp @@ -1,4 +1,4 @@ -#include "convert.h" +#include namespace coinbase { @@ -94,7 +94,7 @@ void converter_t::convert(std::string& value) { convert(value_size); if (write) { - if (pointer) memmove(current(), &value[0], value_size); + if (pointer) memmove(current(), value.data(), value_size); } else { if (value_size < 0) { set_error(); @@ -106,7 +106,7 @@ void converter_t::convert(std::string& value) { return; } value.resize(value_size); - memmove(&value[0], current(), value_size); + memmove(value.data(), current(), value_size); } forward(value_size); } @@ -115,7 +115,8 @@ converter_t::converter_t(bool _write) : write(_write), rv_error(0), pointer(null converter_t::converter_t(byte_ptr out) : write(true), rv_error(0), pointer(out), offset(0), size(0) {} -converter_t::converter_t(mem_t src) : write(false), rv_error(0), pointer(src.data), offset(0), size(src.size) {} +converter_t::converter_t(mem_t src) + : write(false), rv_error(0), pointer(const_cast(src.data)), offset(0), size(src.size) {} void converter_t::set_error() { if (rv_error) return; @@ -130,7 +131,7 @@ void converter_t::set_error(error_t rv) { void converter_t::convert_len(uint32_t& len) { byte_t b = 0; if (write) { - cb_assert(len <= 0x1fffffff); + cb_assert(len <= MAX_CONVERT_LEN); if (len <= 0x7f) { b = byte_t(len); convert(b); @@ -168,13 +169,22 @@ void converter_t::convert_len(uint32_t& len) { } if ((b & 0x80) == 0) { len = b; + if (len > MAX_CONVERT_LEN) { + set_error(); + len = 0; + } return; } if ((b & 0x40) == 0) { len = b & 0x3f; convert(b); len = (len << 8) | b; - if (is_error()) len = 0; + if (is_error()) + len = 0; + else if (len > MAX_CONVERT_LEN) { + set_error(); + len = 0; + } return; } if ((b & 0x20) == 0) { @@ -183,7 +193,12 @@ void converter_t::convert_len(uint32_t& len) { len = (len << 8) | b; convert(b); len = (len << 8) | b; - if (is_error()) len = 0; + if (is_error()) + len = 0; + else if (len > MAX_CONVERT_LEN) { + set_error(); + len = 0; + } return; } len = b & 0x1f; @@ -193,7 +208,12 @@ void converter_t::convert_len(uint32_t& len) { len = (len << 8) | b; convert(b); len = (len << 8) | b; - if (is_error()) len = 0; + if (is_error()) + len = 0; + else if (len > MAX_CONVERT_LEN) { + set_error(); + len = 0; + } } } @@ -217,34 +237,6 @@ void converter_t::convert(std::vector& value) { } } -void convertable_t::factory_t::register_type(def_t* def, uint64_t code_type) { - g_convertable_factory.instance().map[code_type] = def; -} - -convertable_t* convertable_t::factory_t::create(uint64_t code_type) { - const auto& map = g_convertable_factory.instance().map; - const auto i = map.find(code_type); - if (i == map.end()) return nullptr; - return i->second->create(); -} - -convertable_t* convertable_t::factory_t::create(mem_t mem, bool convert) { - if (mem.size < sizeof(uint64_t)) return nullptr; - - uint64_t code_type = be_get_8(mem.data); - convertable_t* obj = create(code_type); - if (!convert) return obj; - - if (!obj) return nullptr; - - converter_t converter(mem); - obj->convert(converter); - if (!converter.is_error()) return obj; - - delete obj; - return nullptr; -} - uint64_t converter_t::convert_code_type(uint64_t code, uint64_t code2, uint64_t code3, uint64_t code4, uint64_t code5, uint64_t code6, uint64_t code7, uint64_t code8) { uint64_t value = code; diff --git a/src/cbmpc/core/error.cpp b/src/cbmpc/core/error.cpp old mode 100755 new mode 100644 index e683f83d..3ba96a9b --- a/src/cbmpc/core/error.cpp +++ b/src/cbmpc/core/error.cpp @@ -1,8 +1,7 @@ -#include "error.h" - -#include +#include #include -#include +#include +#include #if !defined(_DEBUG) // #define JSON_ERR diff --git a/src/cbmpc/core/extended_uint.cpp b/src/cbmpc/core/extended_uint.cpp index 77f4245c..b329dde9 100644 --- a/src/cbmpc/core/extended_uint.cpp +++ b/src/cbmpc/core/extended_uint.cpp @@ -1,7 +1,6 @@ -#include "extended_uint.h" - #include -#include +#include +#include namespace coinbase { diff --git a/src/cbmpc/core/strext.cpp b/src/cbmpc/core/strext.cpp index 6ee8ccb8..04d4c41f 100755 --- a/src/cbmpc/core/strext.cpp +++ b/src/cbmpc/core/strext.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include namespace coinbase { size_t insensitive_hasher_t::operator()(const std::string& key) const { @@ -64,7 +64,7 @@ std::vector strext::split_to_words(const std::string& str) { std::vector strext::tokenize(const std::string& str, const std::string& delim) { // static std::vector out; - buf_t buf(const_byte_ptr(str.c_str()), int(str.length()) + 1); + coinbase::buf_t buf(const_byte_ptr(str.c_str()), int(str.length()) + 1); char_ptr dup = char_ptr(buf.data()); char* save = nullptr; const_char_ptr token = strtok_r(dup, delim.c_str(), &save); @@ -124,8 +124,11 @@ void strext::print_hex_byte(char_ptr str, uint8_t value) { *str++ = hex[value & 15]; } -std::string strext::to_hex(mem_t mem) { - std::string out(mem.size * 2, char(0)); +std::string strext::to_hex(coinbase::mem_t mem) { + if (mem.size <= 0 || !mem.data) return ""; + + const size_t n = static_cast(mem.size); + std::string out(n * 2, char(0)); char_ptr s = buffer(out); for (int i = 0; i < mem.size; i++, s += 2) print_hex_byte(s, mem.data[i]); return out; @@ -146,7 +149,7 @@ std::string strext::to_hex(uint32_t src) { return print_hex(src, 4); } std::string strext::to_hex(uint64_t src) { return print_hex(src, 8); } -bool strext::from_hex(buf_t& dst, const std::string& src) { +bool strext::from_hex(coinbase::buf_t& dst, const std::string& src) { int length = (int)src.length(); if (length & 1) return false; int dst_size = length / 2; diff --git a/src/cbmpc/core/thread.h b/src/cbmpc/core/thread.h deleted file mode 100755 index e171ac98..00000000 --- a/src/cbmpc/core/thread.h +++ /dev/null @@ -1,42 +0,0 @@ -#pragma once - -namespace coinbase { - -template -class global_t { - public: - global_t() noexcept(true) { change_ref_count(+1); } - ~global_t() { - if (change_ref_count(-1)) return; - T* ptr = instance_ptr(false); - if (ptr) ptr->~T(); - } - T& instance() { return *instance_ptr(true); } - - private: - static T* instance_ptr(bool force) { - static std::once_flag once; - static bool initialized = false; - if (!force && !initialized) return nullptr; - - static unsigned char __attribute__((aligned(16))) buf[sizeof(T)]; - - std::call_once(once, []() { - new ((T*)buf) T(); - initialized = true; - }); - return (T*)buf; - } - static int change_ref_count(int x) { - static int ref_count = 0; - return ref_count += x; - } -}; - -template -class global_init_t : public global_t { - public: - global_init_t() : global_t() { global_t::instance(); } -}; - -} // namespace coinbase diff --git a/src/cbmpc/crypto/CMakeLists.txt b/src/cbmpc/crypto/CMakeLists.txt index 7d4cbaf9..170f03d5 100755 --- a/src/cbmpc/crypto/CMakeLists.txt +++ b/src/cbmpc/crypto/CMakeLists.txt @@ -17,7 +17,6 @@ target_sources(cbmpc_crypto PRIVATE base_rsa.cpp base_paillier.cpp base_rsa_oaep.cpp - base_pki.cpp drbg.cpp ro.cpp diff --git a/src/cbmpc/crypto/base.cpp b/src/cbmpc/crypto/base.cpp index 8d7fca72..b576a414 100644 --- a/src/cbmpc/crypto/base.cpp +++ b/src/cbmpc/crypto/base.cpp @@ -1,8 +1,6 @@ -#include "base.h" - -#include - -#include "scope.h" +#include +#include +#include namespace coinbase::crypto { // clang-format off @@ -107,7 +105,7 @@ void gen_random(byte_ptr output, int size) { cb_assert(res > 0); } -void gen_random(mem_t out) { gen_random(out.data, out.size); } +void gen_random(mem_t out) { gen_random(const_cast(out.data), out.size); } bool gen_random_bool() { uint8_t temp = 0; @@ -125,7 +123,7 @@ buf_t gen_random_bitlen(int bitlen) { return gen_random(coinbase::bits_to_bytes( coinbase::bits_t gen_random_bits(int count) { coinbase::bits_t out(count); - gen_random(mem_t(out).data, coinbase::bits_to_bytes(count)); + gen_random(const_cast(mem_t(out).data), coinbase::bits_to_bytes(count)); return out; } @@ -233,7 +231,7 @@ void aes_gcm_t::encrypt_final(mem_t tag) // tag.data is output int out_size = 0; cb_assert(0 < EVP_EncryptFinal_ex(cipher.ctx, NULL, &out_size)); cb_assert(out_size == 0); - cb_assert(0 < EVP_CIPHER_CTX_ctrl(cipher.ctx, EVP_CTRL_GCM_GET_TAG, tag.size, tag.data)); + cb_assert(0 < EVP_CIPHER_CTX_ctrl(cipher.ctx, EVP_CTRL_GCM_GET_TAG, tag.size, const_cast(tag.data))); } void aes_gcm_t::decrypt_init(mem_t key, mem_t iv, mem_t auth) { @@ -248,7 +246,7 @@ void aes_gcm_t::decrypt_init(mem_t key, mem_t iv, mem_t auth) { } error_t aes_gcm_t::decrypt_final(mem_t tag) { - cb_assert(0 < EVP_CIPHER_CTX_ctrl(cipher.ctx, EVP_CTRL_GCM_SET_TAG, tag.size, tag.data)); + cb_assert(0 < EVP_CIPHER_CTX_ctrl(cipher.ctx, EVP_CTRL_GCM_SET_TAG, tag.size, const_cast(tag.data))); int dummy = 0; if (0 >= EVP_DecryptFinal_ex(cipher.ctx, NULL, &dummy)) return coinbase::error(E_CRYPTO); return SUCCESS; @@ -280,7 +278,7 @@ void aes_gmac_t::final(mem_t out) { cb_assert(0 < EVP_EncryptUpdate(ctx, &dummy, &out_size, &dummy, 0)); cb_assert(0 < EVP_EncryptFinal_ex(ctx, NULL, &out_size)); cb_assert(out_size == 0); - cb_assert(0 < EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, out.size, out.data)); + cb_assert(0 < EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, out.size, const_cast(out.data))); } buf_t aes_gmac_t::final(int size) { diff --git a/src/cbmpc/crypto/base_bn.cpp b/src/cbmpc/crypto/base_bn.cpp index d7a10a6b..cf4dec76 100644 --- a/src/cbmpc/crypto/base_bn.cpp +++ b/src/cbmpc/crypto/base_bn.cpp @@ -1,4 +1,4 @@ -#include +#include extern "C" void bn_correct_top(BIGNUM* a); @@ -10,10 +10,13 @@ static thread_local const mod_t* g_thread_local_storage_modo = nullptr; static const mod_t* thread_local_storage_mod() { return g_thread_local_storage_modo; } /** * @notes: - * - Although static code analysis marks this is dangerous, it is safe the way we use it: - * - We use it in the MODULE macros such that in `MODULE(q) { operation }`, all operations - * are done modulo q. In this case, the mod is set once and the pointer is valid until - * we exit the scope of the MODULE. + * - Static analysis flags this as dangerous because it is a single thread-local pointer that affects all `bn_t` + * arithmetic on the current thread. + * - It is intended to be used via the `MODULO(q) { ... }` macro so operations inside the block are performed + * modulo `q`. + * - The macro resets the modulus in the `for` loop update clause; if the block exits early (e.g., `return`, + * `break`, `throw`, `goto`), the modulus may remain set and affect subsequent operations on the same thread. + * - The modulus is not stacked, so `MODULO(...)` must not be nested. */ static void thread_local_storage_set_mod(const mod_t* ptr) { g_thread_local_storage_modo = ptr; } @@ -121,20 +124,33 @@ bn_t::operator BIGNUM*() { return &val; } void bn_t::correct_top() const { bn_correct_top((BIGNUM*)&val); } int64_t bn_t::get_int64() const { - int64_t result = (int64_t)BN_get_word(*this); - if (BN_is_negative(*this)) result = -result; - return result; + const bool neg = BN_is_negative(*this); + cb_assert(BN_num_bits(*this) <= 64); + + const uint64_t abs_val = static_cast(BN_get_word(*this)); + if (!neg) { + cb_assert(abs_val <= static_cast(INT64_MAX)); + return static_cast(abs_val); + } + + cb_assert(abs_val <= static_cast(INT64_MAX) + 1); + if (abs_val == static_cast(INT64_MAX) + 1) return INT64_MIN; + return -static_cast(abs_val); } void bn_t::set_int64(int64_t src) { bool neg = src < 0; - if (neg) src = -src; - int res = BN_set_word(*this, (BN_ULONG)src); + uint64_t abs_val = neg ? -static_cast(src) : static_cast(src); + int res = BN_set_word(*this, static_cast(abs_val)); cb_assert(res); if (neg) BN_set_negative(*this, 1); } -bn_t::operator int() const { return (int)get_int64(); } +bn_t::operator int() const { + int64_t val = get_int64(); + cb_assert(val >= INT_MIN && val <= INT_MAX); + return static_cast(val); +} bn_t& bn_t::operator=(int src) { set_int64(src); @@ -214,9 +230,9 @@ bn_t& bn_t::operator+=(int src2) { int res; if (src2 >= 0) - res = BN_add_word(*this, src2); + res = BN_add_word(*this, static_cast(src2)); else - res = BN_sub_word(*this, -src2); + res = BN_sub_word(*this, static_cast(-static_cast(src2))); cb_assert(res); return *this; } @@ -227,9 +243,9 @@ bn_t& bn_t::operator-=(int src2) { int res; if (src2 >= 0) - res = BN_sub_word(*this, src2); + res = BN_sub_word(*this, static_cast(src2)); else - res = BN_add_word(*this, -src2); + res = BN_add_word(*this, static_cast(-static_cast(src2))); cb_assert(res); return *this; } @@ -239,8 +255,8 @@ bn_t& bn_t::operator*=(int src2) { if (mod) return *this = mod->mul(*this, mod->mod(src2)); bool neg = src2 < 0; - if (neg) src2 = -src2; - int res = BN_mul_word(*this, src2); + const BN_ULONG abs_src2 = neg ? static_cast(-static_cast(src2)) : static_cast(src2); + int res = BN_mul_word(*this, abs_src2); cb_assert(res); if (neg) BN_set_negative(*this, !BN_is_negative(*this)); cb_assert(res); @@ -270,9 +286,9 @@ bn_t operator+(const bn_t& src1, int src2) { int res; bn_t result = src1; if (src2 >= 0) - res = BN_add_word(result, src2); + res = BN_add_word(result, static_cast(src2)); else - res = BN_sub_word(result, -src2); + res = BN_sub_word(result, static_cast(-static_cast(src2))); cb_assert(res); return result; } @@ -294,9 +310,9 @@ bn_t operator-(const bn_t& src1, int src2) { bn_t result = src1; int res; if (src2 >= 0) - res = BN_sub_word(result, src2); + res = BN_sub_word(result, static_cast(src2)); else - res = BN_add_word(result, -src2); + res = BN_add_word(result, static_cast(-static_cast(src2))); cb_assert(res); return result; } @@ -317,8 +333,8 @@ bn_t operator*(const bn_t& src1, int src2) { bn_t result = src1; bool neg = src2 < 0; - if (neg) src2 = -src2; - int res = BN_mul_word(result, src2); + const BN_ULONG abs_src2 = neg ? static_cast(-static_cast(src2)) : static_cast(src2); + int res = BN_mul_word(result, abs_src2); cb_assert(res); if (neg) BN_set_negative(result, !BN_is_negative(result)); return result; @@ -365,27 +381,35 @@ bn_t bn_t::div(const bn_t& src1, const bn_t& src2, bn_t* rem) { // static } bn_t& bn_t::operator<<=(int value) { - int res = BN_lshift(*this, *this, value); + if (value <= 0) return *this; + + const int res = BN_lshift(*this, *this, value); cb_assert(res); return *this; } bn_t& bn_t::operator>>=(int value) { - int res = BN_rshift(*this, *this, value); + if (value <= 0) return *this; + + const int res = BN_rshift(*this, *this, value); cb_assert(res); return *this; } bn_t bn_t::lshift(int n) const { + if (n <= 0) return *this; + bn_t result; - int res = BN_lshift(result, *this, n); + const int res = BN_lshift(result, *this, n); cb_assert(res); return result; } bn_t bn_t::rshift(int n) const { + if (n <= 0) return *this; + bn_t result; - int res = BN_rshift(result, *this, n); + const int res = BN_rshift(result, *this, n); cb_assert(res); return result; } @@ -495,12 +519,28 @@ bn_t bn_t::from_bin(mem_t mem) { // static } std::vector bn_t::vector_from_bin(mem_t mem, int n, int size, const mod_t& q) { // static - std::vector result(n); - cb_assert(mem.size == n * size); - for (int i = 0; i < n; i++, mem = mem.skip(size)) result[i] = bn_t::from_bin(mem.take(size)) % q; + std::vector result; + const error_t rv = vector_from_bin(mem, n, size, q, result); + if (rv) return {}; return result; } +error_t bn_t::vector_from_bin(mem_t mem, int n, int size, const mod_t& q, std::vector& out) { // static + if (n < 0) return coinbase::error(E_BADARG, "vector_from_bin: negative n", /*to_print_stack_trace=*/false); + if (size < 0) + return coinbase::error(E_BADARG, "vector_from_bin: negative element size", /*to_print_stack_trace=*/false); + if (mem.size < 0) + return coinbase::error(E_BADARG, "vector_from_bin: negative input size", /*to_print_stack_trace=*/false); + + const int64_t expected_size = static_cast(n) * static_cast(size); + if (expected_size != static_cast(mem.size)) + return coinbase::error(E_BADARG, "vector_from_bin: input size mismatch", /*to_print_stack_trace=*/false); + + out.resize(n); + for (int i = 0; i < n; i++, mem = mem.skip(size)) out[i] = bn_t::from_bin(mem.take(size)) % q; + return SUCCESS; +} + bn_t bn_t::from_bin_bitlen(mem_t mem, int bits) { // static cb_assert(mem.size == coinbase::bits_to_bytes(bits)); // Handle the 0-bit / empty-input case without indexing into `mem`. @@ -530,17 +570,39 @@ std::string bn_t::to_hex() const { return result; } +error_t bn_t::from_string(const_char_ptr str, bn_t& result) { + if (!str || *str == '\0') return coinbase::error(E_BADARG, "from_string: empty or null input"); + bn_t tmp; + BIGNUM* ptr = tmp; + int n = BN_dec2bn(&ptr, str); + if (n <= 0 || static_cast(n) != strlen(str)) + return coinbase::error(E_BADARG, "from_string: invalid decimal string"); + result = std::move(tmp); + return SUCCESS; +} + bn_t bn_t::from_string(const_char_ptr str) { bn_t result; - BIGNUM* ptr = result; - cb_assert(0 != BN_dec2bn(&ptr, str)); + error_t rv = from_string(str, result); + cb_assert(rv == 0); return result; } +error_t bn_t::from_hex(const_char_ptr str, bn_t& result) { + if (!str || *str == '\0') return coinbase::error(E_BADARG, "from_hex: empty or null input"); + bn_t tmp; + BIGNUM* ptr = tmp; + int n = BN_hex2bn(&ptr, str); + if (n <= 0 || static_cast(n) != strlen(str)) + return coinbase::error(E_BADARG, "from_hex: invalid hexadecimal string"); + result = std::move(tmp); + return SUCCESS; +} + bn_t bn_t::from_hex(const_char_ptr str) { bn_t result; - BIGNUM* ptr = result; - cb_assert(0 != BN_hex2bn(&ptr, str)); + error_t rv = from_hex(str, result); + cb_assert(rv == 0); return result; } @@ -550,43 +612,53 @@ int bn_t::sign() const { return +1; } -bn_t operator<<(const bn_t& src1, int src2) { - bn_t result; - int res = BN_lshift(result, src1, src2); - cb_assert(res); - return result; -} +bn_t operator<<(const bn_t& src1, int src2) { return src1.lshift(src2); } -bn_t operator>>(const bn_t& src1, int src2) { - bn_t result; - int res = BN_rshift(result, src1, src2); - cb_assert(res); - return result; -} +bn_t operator>>(const bn_t& src1, int src2) { return src1.rshift(src2); } void bn_t::convert(coinbase::converter_t& converter) { - uint32_t neg = sign() < 0; - uint32_t value_size = get_bin_size(); - uint32_t header = (value_size << 1) | neg; - converter.convert_len(header); + static_assert(MAX_SERIALIZED_BIGNUM_BYTES <= (coinbase::converter_t::MAX_CONVERT_LEN >> 1), + "CBMPC_MAX_SERIALIZED_BIGNUM_BYTES must be <= converter_t::MAX_CONVERT_LEN / 2"); if (converter.is_write()) { + const uint32_t neg = sign() < 0 ? 1u : 0u; + const int value_size = get_bin_size(); + cb_assert(value_size >= 0); + cb_assert(static_cast(value_size) <= MAX_SERIALIZED_BIGNUM_BYTES); + + uint32_t header = (static_cast(value_size) << 1) | neg; + converter.convert_len(header); if (!converter.is_calc_size()) to_bin(converter.current()); - } else { - neg = header & 1; - value_size = header >> 1; - if (converter.is_error() || !converter.at_least(value_size)) { - converter.set_error(); - return; - } - if (value_size == 0 && neg) { - converter.set_error(); - return; - } - auto res = BN_bin2bn(converter.current(), value_size, *this); - if (!res) throw std::bad_alloc(); - if (neg) BN_set_negative(*this, 1); + converter.forward(value_size); + return; } + + uint32_t header = 0; + converter.convert_len(header); + + const uint32_t neg = header & 1; + const uint32_t value_size_u32 = header >> 1; + if (value_size_u32 > MAX_SERIALIZED_BIGNUM_BYTES || value_size_u32 > static_cast(INT_MAX)) { + converter.set_error(); + return; + } + const int value_size = static_cast(value_size_u32); + + if (converter.is_error() || !converter.at_least(value_size)) { + converter.set_error(); + return; + } + if (value_size == 0 && neg) { + converter.set_error(); + return; + } + auto res = BN_bin2bn(converter.current(), value_size, *this); + if (!res) { + // Do not throw here: a malicious peer can trigger BN_bin2bn() failures by sending oversized bignums. + converter.set_error(); + return; + } + if (neg) BN_set_negative(*this, 1); converter.forward(value_size); } diff --git a/src/cbmpc/crypto/base_ec_core.cpp b/src/cbmpc/crypto/base_ec_core.cpp index 2dcdc268..83a2eda6 100644 --- a/src/cbmpc/crypto/base_ec_core.cpp +++ b/src/cbmpc/crypto/base_ec_core.cpp @@ -1,5 +1,5 @@ -#include "base_ec_core.h" +#include namespace coinbase::crypto { diff --git a/src/cbmpc/crypto/base_ecc.cpp b/src/cbmpc/crypto/base_ecc.cpp index 0e88e95b..8519c08e 100644 --- a/src/cbmpc/crypto/base_ecc.cpp +++ b/src/cbmpc/crypto/base_ecc.cpp @@ -1,12 +1,11 @@ -#include -#include -#include -#include - -#include "base_ecc_secp256k1.h" -#include "base_eddsa.h" -#include "base_pki.h" -#include "ec25519_core.h" +#include +#include +#include +#include +#include +#include +#include +#include namespace coinbase::crypto { @@ -892,7 +891,7 @@ ecc_point_t ecc_point_t::operator-() const { bool ecc_point_t::operator==(const ecc_point_t& val) const { if (!ptr) return val.ptr == nullptr; - if (!val.ptr) return ptr != nullptr; + if (!val.ptr) return false; if (!curve) return false; if (curve != val.curve) return false; return curve.ptr->equ_points(*this, val); diff --git a/src/cbmpc/crypto/base_ecc_secp256k1.cpp b/src/cbmpc/crypto/base_ecc_secp256k1.cpp index dd004bb2..69063492 100644 --- a/src/cbmpc/crypto/base_ecc_secp256k1.cpp +++ b/src/cbmpc/crypto/base_ecc_secp256k1.cpp @@ -1,4 +1,4 @@ -#include "base_ecc_secp256k1.h" +#include // clang-format off #include "secp256k1/src/assumptions.h" @@ -361,6 +361,8 @@ bn_t hash_message(const bn_t& rx, const ecc_point_t& pub_key, mem_t message) { error_t verify(const ecc_point_t& pub_key, mem_t m, mem_t sig) { error_t rv = UNINITIALIZED_ERROR; if (sig.size != 64) return coinbase::error(E_BADARG, "BIP340 verify: sig size != 64"); + if (m.size != 32) return coinbase::error(E_BADARG, "BIP340 verify: msg size != 32"); + if (!m.data) return coinbase::error(E_BADARG, "BIP340 verify: msg is null"); ecurve_t curve = curve_secp256k1; const mod_t& q = curve.order(); diff --git a/src/cbmpc/crypto/base_eddsa.cpp b/src/cbmpc/crypto/base_eddsa.cpp index 23781970..d347d687 100644 --- a/src/cbmpc/crypto/base_eddsa.cpp +++ b/src/cbmpc/crypto/base_eddsa.cpp @@ -1,5 +1,5 @@ -#include -#include +#include +#include namespace coinbase::crypto { @@ -179,7 +179,7 @@ buf_t ecurve_ed_t::prv_to_der(const ecc_prv_key_t& K) const { cb_assert(K.ed_bin.size() == ed25519::prv_bin_size()); buf_t out(ed25519::pkcs8_prefix.size + ed25519::prv_bin_size()); memmove(out.data(), ed25519::pkcs8_prefix.data, ed25519::pkcs8_prefix.size); - memmove(out.data() + ed25519::x509_prefix.size, K.ed_bin.data(), ed25519::prv_bin_size()); + memmove(out.data() + ed25519::pkcs8_prefix.size, K.ed_bin.data(), ed25519::prv_bin_size()); return out; } diff --git a/src/cbmpc/crypto/base_hash.cpp b/src/cbmpc/crypto/base_hash.cpp index 5f3f47de..f55a8292 100644 --- a/src/cbmpc/crypto/base_hash.cpp +++ b/src/cbmpc/crypto/base_hash.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include // NOLINTBEGIN(*magic-number*) namespace coinbase::crypto { @@ -49,15 +49,15 @@ static const uint8_t SHA3_384_oid[] = {0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, static const uint8_t SHA3_512_oid[] = {0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x0a, 0x05, 0x00, 0x04, 0x40}; -static const EVP_MD *evp_sha256() noexcept(true) { return EVP_sha256(); } -static const EVP_MD *evp_sha384() noexcept(true) { return EVP_sha384(); } -static const EVP_MD *evp_sha512() noexcept(true) { return EVP_sha512(); } -static const EVP_MD *evp_sha3_256() noexcept(true) { return EVP_sha3_256(); } -static const EVP_MD *evp_sha3_384() noexcept(true) { return EVP_sha3_384(); } -static const EVP_MD *evp_sha3_512() noexcept(true) { return EVP_sha3_512(); } -static const EVP_MD *evp_blake2s256() noexcept(true) { return EVP_blake2s256(); } -static const EVP_MD *evp_blake2b512() noexcept(true) { return EVP_blake2b512(); } -static const EVP_MD *evp_ripemd160() noexcept(true) { return EVP_ripemd160(); } +static const EVP_MD* evp_sha256() noexcept(true) { return EVP_sha256(); } +static const EVP_MD* evp_sha384() noexcept(true) { return EVP_sha384(); } +static const EVP_MD* evp_sha512() noexcept(true) { return EVP_sha512(); } +static const EVP_MD* evp_sha3_256() noexcept(true) { return EVP_sha3_256(); } +static const EVP_MD* evp_sha3_384() noexcept(true) { return EVP_sha3_384(); } +static const EVP_MD* evp_sha3_512() noexcept(true) { return EVP_sha3_512(); } +static const EVP_MD* evp_blake2s256() noexcept(true) { return EVP_blake2s256(); } +static const EVP_MD* evp_blake2b512() noexcept(true) { return EVP_blake2b512(); } +static const EVP_MD* evp_ripemd160() noexcept(true) { return EVP_ripemd160(); } static const hash_alg_t alg_nohash = {hash_e::none, 0, 0, 0, 0, mem_t(), mem_t(), nullptr}; static const hash_alg_t alg_sha256 = { @@ -81,7 +81,7 @@ static const hash_alg_t alg_blake2s = {hash_e::blake2s, 32, 64, 0, 0, mem_t(), m static const hash_alg_t alg_blake2b = {hash_e::blake2b, 64, 128, 0, 0, mem_t(), mem_t(), evp_blake2b512()}; static const hash_alg_t alg_ripemd160 = {hash_e::ripemd160, 20, 64, 20, 8, mem_t(), mem_t(), evp_ripemd160()}; -const hash_alg_t &hash_alg_t::get(hash_e type) // static +const hash_alg_t& hash_alg_t::get(hash_e type) // static { switch (type) { case hash_e::sha256: @@ -121,20 +121,20 @@ void hash_t::free() { ctx_ptr = nullptr; } -hash_t &hash_t::init() { +hash_t& hash_t::init() { if (!ctx_ptr) ctx_ptr = ::EVP_MD_CTX_new(); ::EVP_DigestInit(ctx_ptr, alg.md); return *this; } -hash_t &hash_t::update(const_byte_ptr ptr, int size) { +hash_t& hash_t::update(const_byte_ptr ptr, int size) { ::EVP_DigestUpdate(ctx_ptr, ptr, size); return *this; } void hash_t::final(byte_ptr out) { ::EVP_DigestFinal(ctx_ptr, out, NULL); } -void hash_t::copy_state(hash_t &dst) { EVP_MD_CTX_copy(dst.ctx_ptr, ctx_ptr); } +void hash_t::copy_state(hash_t& dst) { EVP_MD_CTX_copy(dst.ctx_ptr, ctx_ptr); } buf_t hash_t::final() { buf_t out(alg.size); @@ -156,20 +156,20 @@ buf_t hmac_t::final() { return out; } -hmac_t &hmac_t::init(mem_t key) { +hmac_t& hmac_t::init(mem_t key) { if (!ctx_ptr) { - EVP_MAC *mac = EVP_MAC_fetch(NULL, "HMAC", NULL); + EVP_MAC* mac = EVP_MAC_fetch(NULL, "HMAC", NULL); ctx_ptr = EVP_MAC_CTX_new(mac); EVP_MAC_free(mac); } OSSL_PARAM params[2]; - params[0] = OSSL_PARAM_construct_utf8_string("digest", (char *)EVP_MD_name(alg.md), 0); + params[0] = OSSL_PARAM_construct_utf8_string("digest", (char*)EVP_MD_name(alg.md), 0); params[1] = OSSL_PARAM_construct_end(); EVP_MAC_init(ctx_ptr, key.data, key.size, params); return *this; } -hmac_t &hmac_t::update(const_byte_ptr ptr, int size) { +hmac_t& hmac_t::update(const_byte_ptr ptr, int size) { EVP_MAC_update(ctx_ptr, ptr, size); return *this; } @@ -181,7 +181,7 @@ void hmac_t::final(byte_ptr out) { ctx_ptr = nullptr; } -void hmac_t::copy_state(hmac_t &dst) { +void hmac_t::copy_state(hmac_t& dst) { if (dst.ctx_ptr) EVP_MAC_CTX_free(dst.ctx_ptr); dst.ctx_ptr = EVP_MAC_CTX_dup(ctx_ptr); } @@ -224,28 +224,28 @@ const uint64_t sha512_k[80] = { // ULL = uint64 // -------------------------- RFC 5869 HKDF ---------------------------- buf_t hkdf_extract(hash_e type, mem_t salt, mem_t ikm) { - const hash_alg_t &alg = hash_alg_t::get(type); + const hash_alg_t& alg = hash_alg_t::get(type); buf_t prk(alg.size); - EVP_KDF *kdf = EVP_KDF_fetch(NULL, "HKDF", NULL); + EVP_KDF* kdf = EVP_KDF_fetch(NULL, "HKDF", NULL); cb_assert(kdf && "EVP_KDF_fetch(HKDF) failed"); - EVP_KDF_CTX *kctx = EVP_KDF_CTX_new(kdf); + EVP_KDF_CTX* kctx = EVP_KDF_CTX_new(kdf); EVP_KDF_free(kdf); cb_assert(kctx && "EVP_KDF_CTX_new failed"); int mode = EVP_KDF_HKDF_MODE_EXTRACT_ONLY; OSSL_PARAM params[7]; int pidx = 0; - params[pidx++] = OSSL_PARAM_construct_utf8_string(OSSL_KDF_PARAM_DIGEST, (char *)EVP_MD_name(alg.md), 0); - params[pidx++] = OSSL_PARAM_construct_octet_string(OSSL_KDF_PARAM_KEY, (void *)ikm.data, (size_t)ikm.size); + params[pidx++] = OSSL_PARAM_construct_utf8_string(OSSL_KDF_PARAM_DIGEST, (char*)EVP_MD_name(alg.md), 0); + params[pidx++] = OSSL_PARAM_construct_octet_string(OSSL_KDF_PARAM_KEY, (void*)ikm.data, (size_t)ikm.size); buf_t zero_salt; if (salt.size > 0) { - params[pidx++] = OSSL_PARAM_construct_octet_string(OSSL_KDF_PARAM_SALT, (void *)salt.data, (size_t)salt.size); + params[pidx++] = OSSL_PARAM_construct_octet_string(OSSL_KDF_PARAM_SALT, (void*)salt.data, (size_t)salt.size); } else { zero_salt.resize(alg.size); memset(zero_salt.data(), 0, zero_salt.size()); params[pidx++] = - OSSL_PARAM_construct_octet_string(OSSL_KDF_PARAM_SALT, (void *)zero_salt.data(), (size_t)zero_salt.size()); + OSSL_PARAM_construct_octet_string(OSSL_KDF_PARAM_SALT, (void*)zero_salt.data(), (size_t)zero_salt.size()); } params[pidx++] = OSSL_PARAM_construct_int(OSSL_KDF_PARAM_MODE, &mode); params[pidx] = OSSL_PARAM_construct_end(); @@ -257,7 +257,7 @@ buf_t hkdf_extract(hash_e type, mem_t salt, mem_t ikm) { } buf_t hkdf_expand(hash_e type, mem_t prk, mem_t info, int out_len) { - const hash_alg_t &alg = hash_alg_t::get(type); + const hash_alg_t& alg = hash_alg_t::get(type); const int hash_len = alg.size; cb_assert(out_len >= 0); const int n = (out_len + hash_len - 1) / hash_len; @@ -265,19 +265,19 @@ buf_t hkdf_expand(hash_e type, mem_t prk, mem_t info, int out_len) { buf_t okm(out_len); - EVP_KDF *kdf = EVP_KDF_fetch(NULL, "HKDF", NULL); + EVP_KDF* kdf = EVP_KDF_fetch(NULL, "HKDF", NULL); cb_assert(kdf && "EVP_KDF_fetch(HKDF) failed"); - EVP_KDF_CTX *kctx = EVP_KDF_CTX_new(kdf); + EVP_KDF_CTX* kctx = EVP_KDF_CTX_new(kdf); EVP_KDF_free(kdf); cb_assert(kctx && "EVP_KDF_CTX_new failed"); int mode = EVP_KDF_HKDF_MODE_EXPAND_ONLY; OSSL_PARAM params[6]; int pidx = 0; - params[pidx++] = OSSL_PARAM_construct_utf8_string(OSSL_KDF_PARAM_DIGEST, (char *)EVP_MD_name(alg.md), 0); - params[pidx++] = OSSL_PARAM_construct_octet_string(OSSL_KDF_PARAM_KEY, (void *)prk.data, (size_t)prk.size); + params[pidx++] = OSSL_PARAM_construct_utf8_string(OSSL_KDF_PARAM_DIGEST, (char*)EVP_MD_name(alg.md), 0); + params[pidx++] = OSSL_PARAM_construct_octet_string(OSSL_KDF_PARAM_KEY, (void*)prk.data, (size_t)prk.size); if (info.size > 0) { - params[pidx++] = OSSL_PARAM_construct_octet_string(OSSL_KDF_PARAM_INFO, (void *)info.data, (size_t)info.size); + params[pidx++] = OSSL_PARAM_construct_octet_string(OSSL_KDF_PARAM_INFO, (void*)info.data, (size_t)info.size); } params[pidx++] = OSSL_PARAM_construct_int(OSSL_KDF_PARAM_MODE, &mode); params[pidx] = OSSL_PARAM_construct_end(); diff --git a/src/cbmpc/crypto/base_mod.cpp b/src/cbmpc/crypto/base_mod.cpp index 197f1c03..96cb10ec 100644 --- a/src/cbmpc/crypto/base_mod.cpp +++ b/src/cbmpc/crypto/base_mod.cpp @@ -1,14 +1,10 @@ -#include -#include -#include +#include +#include +#include namespace coinbase::crypto { -#if defined(TARGET_IPHONE_SIMULATOR) && TARGET_IPHONE_SIMULATOR -static thread_local int vartime_scope = 1; -#else static thread_local int vartime_scope = 0; -#endif vartime_scope_t::vartime_scope_t() { vartime_scope++; } vartime_scope_t::~vartime_scope_t() { vartime_scope--; } @@ -86,7 +82,7 @@ mod_t& mod_t::operator=(mod_t&& src) { return *this; } -void mod_t::check(const bn_t& a) const { assert(is_in_range(a) && "out of range for constant-time operations"); } +void mod_t::check(const bn_t& a) const { cb_assert(is_in_range(a) && "out of range for constant-time operations"); } bool mod_t::is_in_range(const bn_t& a) const { return a.sign() >= 0 && a < m; } @@ -293,12 +289,27 @@ void mod_t::scr_inv(bn_t& res, const bn_t& in) const { void mod_t::random_masking_inv(bn_t& r, const bn_t& a) const { // Eventhough, this function is not truely constant-time, the running time is not dependent on the input (bn_t a). // Therefore, it doesn't leak any information of the input. - bn_t mask = rand(); - bn_t masked_a = mul(a, mask); - masked_a.correct_top(); - auto res = BN_mod_inverse(r, masked_a, m, bn_t::thread_local_storage_bn_ctx()); - cb_assert(res && "mod_t::random_masking_inv failed"); - r = mul(r, mask); + // `BN_mod_inverse` fails when the operand is not invertible modulo `m`. + // Even when `a` is invertible, a random mask might not be (e.g. if `m` is composite), + // which would make `a*mask` non-invertible. Retry with a fresh mask. + // + // Bound retries to avoid hanging forever when `a` itself isn't invertible. + constexpr int max_attempts = 128; + bn_t mask; + bn_t masked_a; + for (int attempt = 0; attempt < max_attempts; attempt++) { + mask = rand(); + if (mask == 0) continue; + + masked_a = mul(a, mask); + masked_a.correct_top(); + if (BN_mod_inverse(r, masked_a, m, bn_t::thread_local_storage_bn_ctx())) { + r = mul(r, mask); + return; + } + } + + cb_assert(false && "mod_t::random_masking_inv failed"); } void mod_t::_inv(bn_t& r, const bn_t& a, inv_algo_e alg) const { @@ -503,8 +514,8 @@ bn_t mod_t::N_inv_mod_phiN_2048(const bn_t& N, const bn_t& phiN) { cb_assert(res); return result; } - assert(!phiN.is_odd()); - assert(N.is_odd()); + cb_assert(!phiN.is_odd()); + cb_assert(N.is_odd()); bn_t N_minus_phiN = LARGEST_PRIME_MOD_2048.sub(N, phiN); N_minus_phiN.correct_top(); mod_t mod_N_minus_phiN(N_minus_phiN, false); diff --git a/src/cbmpc/crypto/base_paillier.cpp b/src/cbmpc/crypto/base_paillier.cpp index 37736f81..2a5ed08a 100644 --- a/src/cbmpc/crypto/base_paillier.cpp +++ b/src/cbmpc/crypto/base_paillier.cpp @@ -1,5 +1,7 @@ -#include -#include +#include + +#include +#include namespace coinbase::crypto { @@ -19,6 +21,7 @@ void paillier_t::convert(coinbase::converter_t& converter) { } if (!converter.is_write()) { + if (converter.is_error()) return; if (has_private) update_private(); else @@ -61,6 +64,25 @@ void paillier_t::update_private() { inv_phi_N = N.inv(phi_N); + // Precompute N^{-1} mod 2^bit_size for constant-time L(u) extraction during decryption. + // (See `paillier_t::decrypt` for details.) + { + bn_t two_pow = bn_t(1) << bit_size; // 2^2048 + auto* inv = BN_mod_inverse(inv_N_mod_2k, N.value(), two_pow, bn_t::thread_local_storage_bn_ctx()); + cb_assert(inv && "paillier_t::update_private: failed to invert N mod 2^bit_size"); + + constexpr int BN_ULONG_BITS = int(sizeof(BN_ULONG) * 8); + static_assert(bit_size % BN_ULONG_BITS == 0, "Paillier bit_size must be BN_ULONG-word aligned"); + constexpr int k_words = bit_size / BN_ULONG_BITS; + BIGNUM& inv_bn = *(BIGNUM*)inv_N_mod_2k; + cb_assert(bn_wexpand(&inv_bn, k_words)); + // Zero any unused high words and force a fixed-top representation so decrypt can read `d[0..k_words)`. + for (int i = inv_bn.top; i < k_words; i++) inv_bn.d[i] = 0; + inv_bn.top = k_words; + inv_bn.neg = 0; + inv_bn.flags |= BN_FLG_FIXED_TOP | BN_FLG_CONSTTIME; + } + // p^2 bn_t p_sqr = p * p; @@ -211,12 +233,41 @@ bn_t paillier_t::decrypt(const bn_t& src) const { } // Side-channel note: - // This is the Paillier L(u) step: L(c1) = (c1 - 1) / N, with c1 ∈ Z_{N^2}. This division uses generic bignum - // arithmetic and is not designed to be strictly constant-time with respect to `c1`'s value. - // We acknowledge and accept the residual timing side-channel risk, but in our current threat model Paillier - // decryption is not exposed as a high-resolution timing oracle, so we consider this risk negligible in practice. - // If that assumption changes, replace this with a fixed-size, constant-time extraction of L(c1). - bn_t m1 = (c1 - 1) / N; + // This is the Paillier L(u) step: L(c1) = (c1 - 1) / N, with c1 ∈ Z_{N^2}. + // For odd N, division by N can be replaced by multiplication with N^{-1} modulo 2^k: + // (c1 - 1) = N * L(c1) ⇒ L(c1) ≡ (c1 - 1) * N^{-1} (mod 2^k) + // With k = 2048 and 0 ≤ L(c1) < N < 2^k, this recovers L(c1) exactly from the low k bits. + constexpr int BN_ULONG_BITS = int(sizeof(BN_ULONG) * 8); + static_assert(bit_size % BN_ULONG_BITS == 0, "Paillier bit_size must be BN_ULONG-word aligned"); + constexpr int k_words = bit_size / BN_ULONG_BITS; + + const BIGNUM& c1_bn = *(const BIGNUM*)c1; + cb_assert(c1_bn.top >= 0); + + // tmp_low = (c1 - 1) mod 2^k (little-endian words) + BN_ULONG tmp_low[k_words]; + BN_ULONG borrow = 1; + for (int i = 0; i < k_words; i++) { + BN_ULONG w = 0; + if (i < c1_bn.top) w = c1_bn.d[i]; + tmp_low[i] = w - borrow; + borrow = (borrow && (w == 0)) ? 1 : 0; + } + + // m1_words = tmp_low * inv_N_mod_2k mod 2^k + const BIGNUM& inv_bn = *(const BIGNUM*)inv_N_mod_2k; + cb_assert(inv_bn.top == k_words); + BN_ULONG prod[k_words * 2]; + bn_mul_normal(prod, tmp_low, k_words, inv_bn.d, k_words); + + bn_t m1; + BIGNUM& m1_bn = *(BIGNUM*)m1; + cb_assert(bn_wexpand(&m1_bn, k_words)); + std::copy(prod, prod + k_words, m1_bn.d); + m1_bn.top = k_words; + m1_bn.neg = 0; + m1_bn.flags |= BN_FLG_FIXED_TOP | BN_FLG_CONSTTIME; + MODULO(N) m1 *= inv_phi_N; return m1; } diff --git a/src/cbmpc/crypto/base_pki.cpp b/src/cbmpc/crypto/base_pki.cpp deleted file mode 100644 index ee2a95a2..00000000 --- a/src/cbmpc/crypto/base_pki.cpp +++ /dev/null @@ -1,76 +0,0 @@ -#include "base_pki.h" - -namespace coinbase::crypto { - -// For unified PKE types - -pub_key_t pub_key_t::from(const rsa_pub_key_t& src) { - pub_key_t out; - out.rsa_key = src; - out.key_type = key_type_e::RSA; - return out; -} - -pub_key_t pub_key_t::from(const ecc_pub_key_t& src) { - pub_key_t out; - out.ecc_key = src; - out.key_type = key_type_e::ECC; - return out; -} - -prv_key_t prv_key_t::from(const rsa_prv_key_t& src) { - prv_key_t out; - out.rsa_key = src; - out.key_type = key_type_e::RSA; - return out; -} - -prv_key_t prv_key_t::from(const ecc_prv_key_t& src) { - prv_key_t out; - out.ecc_key = src; - out.key_type = key_type_e::ECC; - return out; -} - -pub_key_t prv_key_t::pub() const { - if (key_type == key_type_e::ECC) - return pub_key_t::from(ecc_key.pub()); - else if (key_type == key_type_e::RSA) - return pub_key_t::from(rsa_key.pub()); - cb_assert(false && "Invalid key type"); - return pub_key_t(); -} - -error_t prv_key_t::execute(mem_t enc_info, buf_t& dec_info) const { - if (key_type == key_type_e::ECC) { - return ecc_key.execute(enc_info, dec_info); - } else if (key_type == key_type_e::RSA) - return rsa_key.execute(enc_info, dec_info); - else - return coinbase::error(E_BADARG, "Invalid key type"); -} - -// ------------------------- PKI ciphertext -------------------- -error_t ciphertext_t::encrypt(const pub_key_t& pub_key, mem_t label, mem_t plain, drbg_aes_ctr_t* drbg) { - key_type = pub_key.get_type(); - if (key_type == key_type_e::ECC) { - return ecies.encrypt(pub_key.ecc(), label, plain, drbg); - } else if (key_type == key_type_e::RSA) { - return rsa_kem.encrypt(pub_key.rsa(), label, plain, drbg); - } else { - return coinbase::error(E_BADARG, "Invalid key type to encrypt"); - } -} - -error_t ciphertext_t::decrypt(const prv_key_t& prv_key, mem_t label, buf_t& plain) const { - error_t rv = UNINITIALIZED_ERROR; - if (prv_key.get_type() != key_type) return coinbase::error(E_BADARG, "Key type and ciphertext mismatch"); - if (key_type == key_type_e::ECC) { - return ecies.decrypt(prv_key.ecc(), label, plain); - } else if (key_type == key_type_e::RSA) { - return rsa_kem.decrypt(prv_key.rsa(), label, plain); - } - return coinbase::error(E_BADARG); -} - -} // namespace coinbase::crypto diff --git a/src/cbmpc/crypto/base_rsa.cpp b/src/cbmpc/crypto/base_rsa.cpp index 739f3589..0eb8823d 100644 --- a/src/cbmpc/crypto/base_rsa.cpp +++ b/src/cbmpc/crypto/base_rsa.cpp @@ -1,10 +1,9 @@ #include -#include -#include - -#include "base_pki.h" -#include "scope.h" +#include +#include +#include +#include namespace coinbase::crypto { @@ -23,7 +22,7 @@ enum { // ------------------------------ rsa_pub_key_t ------------------------- -error_t rsa_pub_key_t::encrypt_raw(mem_t in, buf_t &out) const { +error_t rsa_pub_key_t::encrypt_raw(mem_t in, buf_t& out) const { int n_size = size(); if (n_size != in.size) return coinbase::error(E_CRYPTO); @@ -58,12 +57,12 @@ int rsa_pub_key_t::size() const { return EVP_PKEY_get_size(ptr); } -void rsa_pub_key_t::set(RSA_BASE *&rsa, const BIGNUM *n, const BIGNUM *e) { +void rsa_pub_key_t::set(RSA_BASE*& rsa, const BIGNUM* n, const BIGNUM* e) { cb_assert(n && e); - OSSL_PARAM_BLD *param_bld = OSSL_PARAM_BLD_new(); + OSSL_PARAM_BLD* param_bld = OSSL_PARAM_BLD_new(); OSSL_PARAM_BLD_push_BN(param_bld, "n", n); OSSL_PARAM_BLD_push_BN(param_bld, "e", e); - OSSL_PARAM *params = OSSL_PARAM_BLD_to_param(param_bld); + OSSL_PARAM* params = OSSL_PARAM_BLD_to_param(param_bld); scoped_ptr_t ctx = EVP_PKEY_CTX_new_from_name(NULL, "RSA", NULL); cb_assert(EVP_PKEY_fromdata_init(ctx) > 0); @@ -73,7 +72,7 @@ void rsa_pub_key_t::set(RSA_BASE *&rsa, const BIGNUM *n, const BIGNUM *e) { OSSL_PARAM_BLD_free(param_bld); } -rsa_pub_key_t::data_t rsa_pub_key_t::get(const EVP_PKEY *pkey) { +rsa_pub_key_t::data_t rsa_pub_key_t::get(const EVP_PKEY* pkey) { data_t data; data.n = NULL; data.e = NULL; @@ -97,7 +96,7 @@ rsa_pub_key_t::data_t rsa_pub_key_t::get(const EVP_PKEY *pkey) { return data; } -void rsa_pub_key_t::convert(coinbase::converter_t &converter) { +void rsa_pub_key_t::convert(coinbase::converter_t& converter) { uint8_t parts = 0; bn_t e, n; @@ -112,6 +111,9 @@ void rsa_pub_key_t::convert(coinbase::converter_t &converter) { parts |= part_n; n = bn_t(data.n); } + + BN_free(data.n); + BN_free(data.e); } converter.convert(parts); @@ -138,11 +140,11 @@ void rsa_pub_key_t::convert(coinbase::converter_t &converter) { // ------------------------------ rsa_prv_key_t ------------------------- -error_t rsa_prv_key_t::execute(mem_t enc_info, buf_t &dec_info) const { +error_t rsa_prv_key_t::execute(mem_t enc_info, buf_t& dec_info) const { return rsa_oaep_t(*this).execute(hash_e::sha256, hash_e::sha256, mem_t(), enc_info, dec_info); } -error_t rsa_prv_key_t::sign_pkcs1(mem_t in, hash_e hash_alg, buf_t &signature) const { +error_t rsa_prv_key_t::sign_pkcs1(mem_t in, hash_e hash_alg, buf_t& signature) const { buf_t buf; unsigned int signature_size = size(); @@ -158,7 +160,7 @@ error_t rsa_prv_key_t::sign_pkcs1(mem_t in, hash_e hash_alg, buf_t &signature) c return SUCCESS; } -error_t rsa_prv_key_t::decrypt_raw(mem_t in, buf_t &out) const { +error_t rsa_prv_key_t::decrypt_raw(mem_t in, buf_t& out) const { int n_size = size(); if (in.size != n_size) return coinbase::error(E_CRYPTO); @@ -173,15 +175,10 @@ error_t rsa_prv_key_t::decrypt_raw(mem_t in, buf_t &out) const { void rsa_prv_key_t::create() { free(); } -void rsa_prv_key_t::generate(int bits, const bn_t &e) { +void rsa_prv_key_t::generate(int bits) { create(); ptr = EVP_RSA_gen(bits); -} - -void rsa_prv_key_t::generate(int bits, int e) { - if (e == 0) e = 65537; - bn_t pub_exp(e); - generate(bits, pub_exp); + cb_assert(ptr); } int rsa_prv_key_t::size() const { @@ -189,29 +186,29 @@ int rsa_prv_key_t::size() const { return EVP_PKEY_get_size(ptr); } -rsa_prv_key_t::data_t rsa_prv_key_t::get(const RSA_BASE *rsa) { +rsa_prv_key_t::data_t rsa_prv_key_t::get(const RSA_BASE* rsa) { data_t data; - OSSL_PARAM *params = NULL; + OSSL_PARAM* params = NULL; cb_assert(EVP_PKEY_todata(rsa, EVP_PKEY_PUBLIC_KEY, ¶ms)); - const OSSL_PARAM *param_e = OSSL_PARAM_locate_const(params, "e"); + const OSSL_PARAM* param_e = OSSL_PARAM_locate_const(params, "e"); cb_assert(param_e); - BIGNUM *e_ptr = data.e; - const OSSL_PARAM *param_n = OSSL_PARAM_locate_const(params, "n"); + BIGNUM* e_ptr = data.e; + const OSSL_PARAM* param_n = OSSL_PARAM_locate_const(params, "n"); cb_assert(param_n); - BIGNUM *n_ptr = data.n; + BIGNUM* n_ptr = data.n; cb_assert(OSSL_PARAM_get_BN(param_e, &e_ptr) > 0); cb_assert(OSSL_PARAM_get_BN(param_n, &n_ptr) > 0); OSSL_PARAM_free(params); params = NULL; cb_assert(EVP_PKEY_todata(rsa, EVP_PKEY_PRIVATE_KEY, ¶ms)); - const OSSL_PARAM *param_p = OSSL_PARAM_locate_const(params, "rsa-factor1"); + const OSSL_PARAM* param_p = OSSL_PARAM_locate_const(params, "rsa-factor1"); cb_assert(param_p); - BIGNUM *p_ptr = data.p; - const OSSL_PARAM *param_q = OSSL_PARAM_locate_const(params, "rsa-factor2"); + BIGNUM* p_ptr = data.p; + const OSSL_PARAM* param_q = OSSL_PARAM_locate_const(params, "rsa-factor2"); cb_assert(param_q); - BIGNUM *q_ptr = data.q; + BIGNUM* q_ptr = data.q; cb_assert(OSSL_PARAM_get_BN(param_p, &p_ptr) > 0); cb_assert(OSSL_PARAM_get_BN(param_q, &q_ptr) > 0); OSSL_PARAM_free(params); @@ -219,13 +216,181 @@ rsa_prv_key_t::data_t rsa_prv_key_t::get(const RSA_BASE *rsa) { return data; } +void rsa_prv_key_t::set(RSA_BASE*& rsa, const BIGNUM* n, const BIGNUM* e, const BIGNUM* d) { + cb_assert(n && e && d); + OSSL_PARAM_BLD* param_bld = OSSL_PARAM_BLD_new(); + cb_assert(param_bld); + + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_N, n) > 0); + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_E, e) > 0); + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_D, d) > 0); + + OSSL_PARAM* params = OSSL_PARAM_BLD_to_param(param_bld); + cb_assert(params); + + scoped_ptr_t ctx = EVP_PKEY_CTX_new_from_name(NULL, "RSA", NULL); + cb_assert(ctx); + cb_assert(EVP_PKEY_fromdata_init(ctx) > 0); + cb_assert(EVP_PKEY_fromdata(ctx, &rsa, EVP_PKEY_KEYPAIR, params) > 0); + + OSSL_PARAM_free(params); + OSSL_PARAM_BLD_free(param_bld); +} + +void rsa_prv_key_t::set(RSA_BASE*& rsa, const BIGNUM* n, const BIGNUM* e, const BIGNUM* d, const BIGNUM* p, + const BIGNUM* q) { + cb_assert(n && e && d && p && q); + OSSL_PARAM_BLD* param_bld = OSSL_PARAM_BLD_new(); + cb_assert(param_bld); + + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_N, n) > 0); + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_E, e) > 0); + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_D, d) > 0); + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_FACTOR1, p) > 0); + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_FACTOR2, q) > 0); + + OSSL_PARAM* params = OSSL_PARAM_BLD_to_param(param_bld); + cb_assert(params); + + scoped_ptr_t ctx = EVP_PKEY_CTX_new_from_name(NULL, "RSA", NULL); + cb_assert(ctx); + cb_assert(EVP_PKEY_fromdata_init(ctx) > 0); + cb_assert(EVP_PKEY_fromdata(ctx, &rsa, EVP_PKEY_KEYPAIR, params) > 0); + + OSSL_PARAM_free(params); + OSSL_PARAM_BLD_free(param_bld); +} + +void rsa_prv_key_t::set(RSA_BASE*& rsa, const BIGNUM* n, const BIGNUM* e, const BIGNUM* d, const BIGNUM* p, + const BIGNUM* q, const BIGNUM* dp, const BIGNUM* dq, const BIGNUM* qinv) { + cb_assert(n && e && d && p && q && dp && dq && qinv); + OSSL_PARAM_BLD* param_bld = OSSL_PARAM_BLD_new(); + cb_assert(param_bld); + + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_N, n) > 0); + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_E, e) > 0); + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_D, d) > 0); + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_FACTOR1, p) > 0); + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_FACTOR2, q) > 0); + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_EXPONENT1, dp) > 0); + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_EXPONENT2, dq) > 0); + cb_assert(OSSL_PARAM_BLD_push_BN(param_bld, OSSL_PKEY_PARAM_RSA_COEFFICIENT1, qinv) > 0); + + OSSL_PARAM* params = OSSL_PARAM_BLD_to_param(param_bld); + cb_assert(params); + + scoped_ptr_t ctx = EVP_PKEY_CTX_new_from_name(NULL, "RSA", NULL); + cb_assert(ctx); + cb_assert(EVP_PKEY_fromdata_init(ctx) > 0); + cb_assert(EVP_PKEY_fromdata(ctx, &rsa, EVP_PKEY_KEYPAIR, params) > 0); + + OSSL_PARAM_free(params); + OSSL_PARAM_BLD_free(param_bld); +} + +void rsa_prv_key_t::set(RSA_BASE*& rsa, const data_t& data) { + // Require full factors to reconstruct. + if (data.n == 0 || data.e == 0 || data.p == 0 || data.q == 0) { + cb_assert(false && "Incomplete RSA private key data"); + return; + } + + // Validate n == p*q. + bn_t n_check = data.p * data.q; + if (n_check != data.n) { + cb_assert(false && "Invalid RSA key data (n != p*q)"); + return; + } + + const bn_t p_minus_1 = data.p - 1; + const bn_t q_minus_1 = data.q - 1; + const bn_t phi_n = p_minus_1 * q_minus_1; + + // Compute d = e^{-1} mod phi(n). + bn_t d; + { + vartime_scope_t scope; + auto res = BN_mod_inverse(d, data.e, phi_n, bn_t::thread_local_storage_bn_ctx()); + cb_assert(res); + } + + // CRT parameters. + bn_t dp; + bn_t::div(d, p_minus_1, &dp); + bn_t dq; + bn_t::div(d, q_minus_1, &dq); + + bn_t qinv; + { + vartime_scope_t scope; + auto res = BN_mod_inverse(qinv, data.q, data.p, bn_t::thread_local_storage_bn_ctx()); + cb_assert(res); + } + + set(rsa, data.n, data.e, d, data.p, data.q, dp, dq, qinv); +} + +void rsa_prv_key_t::convert(coinbase::converter_t& converter) { + uint8_t parts = 0; + bn_t e, n, p, q; + + if (converter.is_write()) { + data_t data = get(); + if (data.e != 0) { + parts |= part_e; + e = data.e; + } + if (data.n != 0) { + parts |= part_n; + n = data.n; + } + if (data.p != 0) { + parts |= part_p; + p = data.p; + } + if (data.q != 0) { + parts |= part_q; + q = data.q; + } + } + + converter.convert(parts); + if (converter.is_error()) return; + + if (parts & part_e) converter.convert(e); + if (parts & part_n) converter.convert(n); + if (parts & part_p) converter.convert(p); + if (parts & part_q) converter.convert(q); + + if (!converter.is_write() && !converter.is_error()) { + create(); + switch (parts) { + case 0: + break; + case part_e | part_n | part_p | part_q: { + data_t data; + data.e = e; + data.n = n; + data.p = p; + data.q = q; + set(ptr, data); + break; + } + default: + converter.set_error(); + free(); + return; + } + } +} + rsa_pub_key_t rsa_prv_key_t::pub() const { rsa_pub_key_t pub_key; pub_key.set(get_n(), get_e()); return pub_key; } -error_t rsa_oaep_t::execute(hash_e hash_alg, hash_e mgf_alg, mem_t label, mem_t in, buf_t &out) const { +error_t rsa_oaep_t::execute(hash_e hash_alg, hash_e mgf_alg, mem_t label, mem_t in, buf_t& out) const { error_t rv = UNINITIALIZED_ERROR; if (!hash_alg_t::get(hash_alg).valid()) return coinbase::error(E_BADARG); if (!hash_alg_t::get(mgf_alg).valid()) return coinbase::error(E_BADARG); @@ -239,12 +404,12 @@ error_t rsa_oaep_t::execute(hash_e hash_alg, hash_e mgf_alg, mem_t label, mem_t return SUCCESS; } -error_t rsa_oaep_t::execute(void *ctx, int hash_alg, int mgf_alg, mem_t label, mem_t in, buf_t &out) { +error_t rsa_oaep_t::execute(void* ctx, int hash_alg, int mgf_alg, mem_t label, mem_t in, buf_t& out) { error_t rv = UNINITIALIZED_ERROR; if (!hash_alg_t::get(hash_e(hash_alg)).valid()) return coinbase::error(E_BADARG); if (!hash_alg_t::get(hash_e(mgf_alg)).valid()) return coinbase::error(E_BADARG); - const rsa_prv_key_t *key = (const rsa_prv_key_t *)ctx; + const rsa_prv_key_t* key = (const rsa_prv_key_t*)ctx; if (rv = key->decrypt_oaep(in, hash_e(hash_alg), hash_e(mgf_alg), label, out)) return rv; return SUCCESS; } diff --git a/src/cbmpc/crypto/base_rsa.h b/src/cbmpc/crypto/base_rsa.h deleted file mode 100644 index 2ee91e5c..00000000 --- a/src/cbmpc/crypto/base_rsa.h +++ /dev/null @@ -1,139 +0,0 @@ -#pragma once - -#include "base_bn.h" -#include "scope.h" - -typedef EVP_PKEY RSA_BASE; - -namespace coinbase::crypto { - -const int RSA_KEY_LENGTH = 2048; -class rsa_pub_key_t : public scoped_ptr_t { - public: - int size() const; - - static error_t pad_oaep(int bits, mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, buf_t &out); - static error_t pad_oaep_with_seed(int bits, mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, mem_t seed, - buf_t &out); - - error_t encrypt_raw(mem_t in, buf_t &out) const; - error_t encrypt_oaep(mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, buf_t &out) const; - error_t encrypt_oaep_with_seed(mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, mem_t seed, buf_t &out) const; - error_t verify_pkcs1(mem_t data, hash_e hash_alg, mem_t signature) const; - - buf_t to_der() const; - buf_t to_der_pkcs1() const; - error_t from_der(mem_t der); - - bn_t get_e() const { return bn_t(get().e); } - bn_t get_n() const { return bn_t(get().n); } - void set(const BIGNUM *n, const BIGNUM *e) { - create(); - set(ptr, n, e); - } - - void convert(coinbase::converter_t &converter); - - bool operator==(const rsa_pub_key_t &val) const { return EVP_PKEY_eq(ptr, val.ptr); } - bool operator!=(const rsa_pub_key_t &val) const { return !EVP_PKEY_eq(ptr, val.ptr); } - - private: - struct data_t { - BIGNUM *n = nullptr, *e = nullptr; - }; - - static data_t get(const RSA_BASE *ptr); - static void set(RSA_BASE *&rsa, const BIGNUM *n, const BIGNUM *e); - - data_t get() const { return get(ptr); } - void create(); -}; - -class rsa_prv_key_t : public scoped_ptr_t { - public: - error_t execute(mem_t enc_info, buf_t &dec_info) const; - - rsa_pub_key_t pub() const; - int size() const; - - void generate(int bits, int e = 65537); - void generate(int bits, const bn_t &e); - - error_t decrypt_raw(mem_t in, buf_t &out) const; - error_t decrypt_oaep(mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, buf_t &out) const; - error_t sign_pkcs1(mem_t data, hash_e hash_alg, buf_t &sig) const; - - buf_t to_der() const; - error_t from_der(mem_t der); - - void convert(coinbase::converter_t &converter); - - bn_t get_e() const { return bn_t(get().e); } - bn_t get_n() const { return bn_t(get().n); } - bn_t get_p() const { return bn_t(get().p); } - bn_t get_q() const { return bn_t(get().q); } - - void set(const BIGNUM *n, const BIGNUM *e, const BIGNUM *d) { - create(); - set(ptr, n, e, d); - } - void set(const BIGNUM *n, const BIGNUM *e, const BIGNUM *d, const BIGNUM *p, const BIGNUM *q) { - create(); - set(ptr, n, e, d, p, q); - } - void set(const BIGNUM *n, const BIGNUM *e, const BIGNUM *d, const BIGNUM *p, const BIGNUM *q, const BIGNUM *dp, - const BIGNUM *dq, const BIGNUM *qinv) { - create(); - set(ptr, n, e, d, p, q, dp, dq, qinv); - } - error_t recover_factors(); - void set_paillier(const BIGNUM *n, const BIGNUM *p, const BIGNUM *q, const BIGNUM *dp, const BIGNUM *dq, - const BIGNUM *qinv); - - private: - struct data_t { - bn_t n, e; - bn_t p, q; - }; - static data_t get(const RSA_BASE *ptr); - static void set(RSA_BASE *rsa, const BIGNUM *n, const BIGNUM *e, const BIGNUM *d); - static void set(RSA_BASE *rsa, const BIGNUM *n, const BIGNUM *e, const BIGNUM *d, const BIGNUM *p, const BIGNUM *q); - static void set(RSA_BASE *rsa, const BIGNUM *n, const BIGNUM *e, const BIGNUM *d, const BIGNUM *p, const BIGNUM *q, - const BIGNUM *dp, const BIGNUM *dq, const BIGNUM *qinv); - static void set(RSA_BASE *rsa, const data_t &data); - - data_t get() const { return get(ptr); } - void create(); -}; - -class rsa_oaep_t { - public: - typedef error_t (*exec_t)(void *ctx, int hash_alg, int mgf_alg, mem_t label, mem_t input, buf_t &output); - - rsa_oaep_t(const rsa_prv_key_t &_key) : key(&_key), exec(nullptr), ctx(nullptr) {} - rsa_oaep_t(exec_t _exec, void *_ctx) : key(nullptr), exec(_exec), ctx(_ctx) {} - - error_t execute(hash_e hash_alg, hash_e mgf_alg, mem_t label, mem_t in, buf_t &out) const; - static error_t execute(void *ctx, int hash_alg, int mgf_alg, mem_t label, mem_t in, buf_t &out); - - private: - exec_t exec; - void *ctx; - const rsa_prv_key_t *key; -}; - -static int evp_md_size(hash_e type) { return hash_alg_t::get(type).size; } -static int evp_digest_init_ex(hash_t &ctx, hash_e type, void *impl) { - ctx.init(); - return 1; -} -static int evp_digest_update(hash_t &ctx, const void *d, size_t cnt) { - ctx.update(const_byte_ptr(d), int(cnt)); - return 1; -} -static int evp_digest_final_ex(hash_t &ctx, unsigned char *md, unsigned int *s) { - ctx.final(md); - return 1; -} - -} // namespace coinbase::crypto diff --git a/src/cbmpc/crypto/base_rsa_oaep.cpp b/src/cbmpc/crypto/base_rsa_oaep.cpp index 9426a63d..b8c67e20 100644 --- a/src/cbmpc/crypto/base_rsa_oaep.cpp +++ b/src/cbmpc/crypto/base_rsa_oaep.cpp @@ -1,4 +1,4 @@ -#include "base.h" +#include /* * Written by Ulf Moeller. This software is distributed on an "AS IS" basis, @@ -20,13 +20,13 @@ namespace coinbase::crypto { -static int mgf1_xor(unsigned char *out, size_t outlen, const unsigned char *seed, size_t seedlen, const EVP_MD *md, - OSSL_LIB_CTX *libctx, const char *propq) { +static int mgf1_xor(unsigned char* out, size_t outlen, const unsigned char* seed, size_t seedlen, const EVP_MD* md, + OSSL_LIB_CTX* libctx, const char* propq) { unsigned char dig[EVP_MAX_MD_SIZE]; unsigned int counter = 0; size_t done = 0; unsigned int mdsize = 0; - EVP_MD_CTX *ctx = EVP_MD_CTX_new(); + EVP_MD_CTX* ctx = EVP_MD_CTX_new(); if (ctx == NULL) return -1; mdsize = EVP_MD_get_size(md); @@ -72,10 +72,10 @@ static int mgf1_xor(unsigned char *out, size_t outlen, const unsigned char *seed * to avoid complicating an already difficult enough function. */ // NOLINTBEGIN -static int ossl_rsa_padding_add_PKCS1_OAEP_mgf1_ex(OSSL_LIB_CTX *libctx, unsigned char *to, int tlen, - const unsigned char *from, int flen, const unsigned char *param, - int plen, const EVP_MD *md, const EVP_MD *mgf1md, - const unsigned char *seed_data, int seedlen) { +static int ossl_rsa_padding_add_PKCS1_OAEP_mgf1_ex(OSSL_LIB_CTX* libctx, unsigned char* to, int tlen, + const unsigned char* from, int flen, const unsigned char* param, + int plen, const EVP_MD* md, const EVP_MD* mgf1md, + const unsigned char* seed_data, int seedlen) { int rv = 0; int emlen = tlen - 1; unsigned char *db, *seed; @@ -115,7 +115,7 @@ static int ossl_rsa_padding_add_PKCS1_OAEP_mgf1_ex(OSSL_LIB_CTX *libctx, unsigne db = to + mdlen + 1; /* step 3a: hash the additional input */ - if (!EVP_Digest((void *)param, plen, db, NULL, md, NULL)) goto err; + if (!EVP_Digest((void*)param, plen, db, NULL, md, NULL)) goto err; /* step 3b: zero bytes array of length nLen - KLen - 2 HLen -2 */ memset(db + mdlen, 0, emlen - flen - 2 * mdlen - 1); /* step 3c: DB = HA || PS || 00000001 || K */ @@ -138,7 +138,7 @@ static int ossl_rsa_padding_add_PKCS1_OAEP_mgf1_ex(OSSL_LIB_CTX *libctx, unsigne // NOLINTEND error_t rsa_pub_key_t::pad_oaep_with_seed(int bits, mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, mem_t seed, - buf_t &out) // static + buf_t& out) // static { int key_size = coinbase::bits_to_bytes(bits); if (0 >= ossl_rsa_padding_add_PKCS1_OAEP_mgf1_ex(NULL, out.alloc(key_size), key_size, in.data, in.size, label.data, @@ -148,13 +148,13 @@ error_t rsa_pub_key_t::pad_oaep_with_seed(int bits, mem_t in, hash_e hash_alg, h return SUCCESS; } -error_t rsa_pub_key_t::pad_oaep(int bits, mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, buf_t &out) // static +error_t rsa_pub_key_t::pad_oaep(int bits, mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, buf_t& out) // static { int seed_size = hash_alg_t::get(hash_alg).size; return pad_oaep_with_seed(bits, in, hash_alg, mgf_alg, label, gen_random(seed_size), out); } -error_t rsa_prv_key_t::decrypt_oaep(mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, buf_t &out) const { +error_t rsa_prv_key_t::decrypt_oaep(mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, buf_t& out) const { int n_size = size(); if (in.size != n_size) return coinbase::error(E_CRYPTO); @@ -169,9 +169,9 @@ error_t rsa_prv_key_t::decrypt_oaep(mem_t in, hash_e hash_alg, hash_e mgf_alg, m // using OPENSSL_free. Only pass memory allocated by OPENSSL_malloc, otherwise // we risk invalid-free crashes. if (label.size > 0) { - auto openssl_deleter = [](uint8_t *p) { OPENSSL_free(p); }; + auto openssl_deleter = [](uint8_t* p) { OPENSSL_free(p); }; std::unique_ptr label_ptr( - static_cast(OPENSSL_memdup(label.data, static_cast(label.size))), openssl_deleter); + static_cast(OPENSSL_memdup(label.data, static_cast(label.size))), openssl_deleter); if (!label_ptr) return openssl_error("RSA decrypt OAEP error"); if (EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, label_ptr.get(), label.size) <= 0) { return openssl_error("RSA decrypt OAEP error"); @@ -189,14 +189,14 @@ error_t rsa_prv_key_t::decrypt_oaep(mem_t in, hash_e hash_alg, hash_e mgf_alg, m } error_t rsa_pub_key_t::encrypt_oaep_with_seed(mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, mem_t seed, - buf_t &out) const { + buf_t& out) const { error_t rv = UNINITIALIZED_ERROR; buf_t padded; if (rv = pad_oaep_with_seed(size() * 8, in, hash_alg, mgf_alg, label, seed, padded)) return rv; return rv = encrypt_raw(padded, out); } -error_t rsa_pub_key_t::encrypt_oaep(mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, buf_t &out) const { +error_t rsa_pub_key_t::encrypt_oaep(mem_t in, hash_e hash_alg, hash_e mgf_alg, mem_t label, buf_t& out) const { error_t rv = UNINITIALIZED_ERROR; buf_t padded; if (rv = pad_oaep(size() * 8, in, hash_alg, mgf_alg, label, padded)) return rv; diff --git a/src/cbmpc/crypto/drbg.cpp b/src/cbmpc/crypto/drbg.cpp index 8c9a4f30..7276384d 100644 --- a/src/cbmpc/crypto/drbg.cpp +++ b/src/cbmpc/crypto/drbg.cpp @@ -1,4 +1,4 @@ -#include +#include namespace coinbase::crypto { @@ -33,7 +33,7 @@ void drbg_aes_ctr_t::seed(mem_t in) { void drbg_aes_ctr_t::gen(mem_t out) { out.bzero(); - ctr.update(out, out.data); + ctr.update(out, const_cast(out.data)); } bn_t drbg_aes_ctr_t::gen_bn(const mod_t& mod) { return gen_bn(mod.get_bits_count() + SEC_P_STAT) % mod; } diff --git a/src/cbmpc/crypto/ec25519_core.cpp b/src/cbmpc/crypto/ec25519_core.cpp index d7293b86..6b824ecc 100644 --- a/src/cbmpc/crypto/ec25519_core.cpp +++ b/src/cbmpc/crypto/ec25519_core.cpp @@ -1,10 +1,8 @@ -#include "ec25519_core.h" - -#include -#include - -#include "base_ec_core.h" +#include +#include +#include +#include #define EXTENDED_COORD @@ -1005,6 +1003,7 @@ extern "C" int ED25519_sign_with_scalar(uint8_t* out_sig, const uint8_t* message for (int i = 0; i < 32; i++) az[i] = scalar_bin[31 - i]; sign_with_nonce(out_sig, message, message_len, public_key, az, nonce); + OPENSSL_cleanse(nonce, sizeof(nonce)); OPENSSL_cleanse(az, sizeof(az)); return 1; } diff --git a/src/cbmpc/crypto/elgamal.cpp b/src/cbmpc/crypto/elgamal.cpp index bb6be60d..ce42467d 100644 --- a/src/cbmpc/crypto/elgamal.cpp +++ b/src/cbmpc/crypto/elgamal.cpp @@ -1,7 +1,6 @@ -#include "elgamal.h" - -#include -#include +#include +#include +#include namespace coinbase::crypto { diff --git a/src/cbmpc/crypto/lagrange.cpp b/src/cbmpc/crypto/lagrange.cpp index 489298f3..67d3a3b5 100644 --- a/src/cbmpc/crypto/lagrange.cpp +++ b/src/cbmpc/crypto/lagrange.cpp @@ -1,4 +1,4 @@ -#include "lagrange.h" +#include namespace coinbase::crypto { diff --git a/src/cbmpc/crypto/pki_ffi.h b/src/cbmpc/crypto/pki_ffi.h deleted file mode 100644 index fa68faa2..00000000 --- a/src/cbmpc/crypto/pki_ffi.h +++ /dev/null @@ -1,35 +0,0 @@ -#pragma once - -#include -#include - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// Forward declarations of functions that retrieve callbacks exposed via FFI. - -// Digital signature functions -typedef int (*ffi_sign_fn)(cmem_t /* sk */, cmem_t /* hash */, cmem_t* /* signature out */); -typedef int (*ffi_verify_fn)(cmem_t /* vk */, cmem_t /* hash */, cmem_t /* signature */); - -ffi_sign_fn get_ffi_sign_fn(void); -ffi_verify_fn get_ffi_verify_fn(void); - -// KEM functions -typedef int (*ffi_kem_encap_fn)(cmem_t /* ek_bytes */, cmem_t /* rho */, cmem_t* /* kem_ct out */, - cmem_t* /* kem_ss out */); -// Private key is treated as an opaque, process-local handle managed by the host. -// It must not be serialized or inspected by the callee. -typedef int (*ffi_kem_decap_fn)(const void* /* dk_handle */, cmem_t /* kem_ct */, cmem_t* /* kem_ss out */); -typedef int (*ffi_kem_dk_to_ek_fn)(const void* /* dk_handle */, cmem_t* /* out ek_bytes */); - -ffi_kem_encap_fn get_ffi_kem_encap_fn(void); -ffi_kem_decap_fn get_ffi_kem_decap_fn(void); -ffi_kem_dk_to_ek_fn get_ffi_kem_dk_to_ek_fn(void); - -#ifdef __cplusplus -} -#endif diff --git a/src/cbmpc/crypto/ro.cpp b/src/cbmpc/crypto/ro.cpp index a2aa7ffc..6a766661 100644 --- a/src/cbmpc/crypto/ro.cpp +++ b/src/cbmpc/crypto/ro.cpp @@ -1,6 +1,5 @@ -#include "ro.h" - -#include +#include +#include namespace coinbase::crypto::ro { // random oracle diff --git a/src/cbmpc/crypto/secret_sharing.cpp b/src/cbmpc/crypto/secret_sharing.cpp index b3ce3efb..278135cb 100644 --- a/src/cbmpc/crypto/secret_sharing.cpp +++ b/src/cbmpc/crypto/secret_sharing.cpp @@ -1,13 +1,12 @@ -#include "secret_sharing.h" - -#include -#include -#include -#include +#include +#include +#include +#include +#include namespace coinbase::crypto::ss { -std::vector share_and(const mod_t &q, const bn_t &x, const int n, crypto::drbg_aes_ctr_t *drbg) { +std::vector share_and(const mod_t& q, const bn_t& x, const int n, crypto::drbg_aes_ctr_t* drbg) { cb_assert(n > 0); std::vector shares(n); bn_t sum = 0; @@ -22,9 +21,9 @@ std::vector share_and(const mod_t &q, const bn_t &x, const int n, crypto:: return shares; } -std::pair, std::vector> share_threshold(const mod_t &q, const bn_t &a, const int threshold, - const int n, const std::vector &pids, - crypto::drbg_aes_ctr_t *drbg) { +std::pair, std::vector> share_threshold(const mod_t& q, const bn_t& a, const int threshold, + const int n, const std::vector& pids, + crypto::drbg_aes_ctr_t* drbg) { std::vector shares(n); std::vector b(threshold); cb_assert(threshold > 0); @@ -48,12 +47,12 @@ node_t::~node_t() { for (auto node : children) delete node; } -void node_t::add_child_node(node_t *node) { +void node_t::add_child_node(node_t* node) { children.push_back(node); node->parent = this; } -error_t node_t::validate_tree(std::set &names) const { +error_t node_t::validate_tree(std::set& names) const { error_t rv = UNINITIALIZED_ERROR; if (name.empty() && parent) return coinbase::error(E_BADARG, "unnamed node"); if (!parent && !name.empty()) return coinbase::error(E_BADARG, "named root node"); @@ -83,13 +82,13 @@ error_t node_t::validate_tree(std::set &names) const { return coinbase::error(E_BADARG, "invalid node type"); } - for (const node_t *child : children) + for (const node_t* child : children) if (rv = child->validate_tree(names)) return rv; return SUCCESS; } -void node_t::convert_node(coinbase::converter_t &c) { +void node_t::convert_node(coinbase::converter_t& c) { int temp = int(type); c.convert(temp); type = node_e(temp); @@ -99,7 +98,7 @@ void node_t::convert_node(coinbase::converter_t &c) { c.convert_len(n); for (int i = 0; i < n; i++) { - node_t *child = c.is_write() ? children[i] : new node_t(); + node_t* child = c.is_write() ? children[i] : new node_t(); child->convert_node(c); if (c.is_error()) { @@ -112,25 +111,35 @@ void node_t::convert_node(coinbase::converter_t &c) { } // ac stands for access structure -void ac_owned_t::convert(coinbase::converter_t &c) // static +void ac_owned_t::convert(coinbase::converter_t& c) // static { bool exists = (root != nullptr); c.convert(exists); error_t rv = UNINITIALIZED_ERROR; if (exists) { + c.convert(curve); + if (!c.is_write() && !curve.valid()) { + rv = coinbase::error(E_FORMAT, "access structure: invalid curve"); + c.set_error(rv); + delete root; + root = nullptr; + return; + } if (!c.is_write()) { delete root; root = new node_t(); } - ((node_t *)root)->convert_node(c); + ((node_t*)root)->convert_node(c); if (c.is_write()) return; if (!c.is_error()) { - rv = root->validate_tree(); + rv = validate_tree(); if (rv == 0) return; } + } else if (!c.is_write()) { + curve = nullptr; } delete root; @@ -138,13 +147,13 @@ void ac_owned_t::convert(coinbase::converter_t &c) // static if (rv) c.set_error(rv); } -std::vector node_t::get_sorted_children() const { - std::vector sorted = children; - std::sort(sorted.begin(), sorted.end(), [](node_t *n1, node_t *n2) -> auto { return n1->name < n2->name; }); +std::vector node_t::get_sorted_children() const { + std::vector sorted = children; + std::sort(sorted.begin(), sorted.end(), [](node_t* n1, node_t* n2) -> auto { return n1->name < n2->name; }); return sorted; } -static int find_child_index(const node_t *node, const std::string &name) { +static int find_child_index(const node_t* node, const std::string& name) { int n = int(node->children.size()); for (int i = 0; i < n; i++) { if (node->children[i]->name == name) return i; @@ -152,9 +161,9 @@ static int find_child_index(const node_t *node, const std::string &name) { return -1; } -node_t *node_t::clone() const { - node_t *node = new node_t(type, name, threshold); - for (const node_t *child : children) { +node_t* node_t::clone() const { + node_t* node = new node_t(type, name, threshold); + for (const node_t* child : children) { node->add_child_node(child->clone()); } return node; @@ -162,7 +171,7 @@ node_t *node_t::clone() const { void node_t::remove_and_delete() { if (parent) { - auto &parent_list = parent->children; + auto& parent_list = parent->children; auto it = std::find(parent_list.begin(), parent_list.end(), this); if (it != parent_list.end()) parent_list.erase(it); } @@ -171,7 +180,7 @@ void node_t::remove_and_delete() { std::string node_t::get_path() const { std::string path; - const node_t *node = this; + const node_t* node = this; while (node) { if (path.empty()) path = node->name; @@ -182,27 +191,27 @@ std::string node_t::get_path() const { return path; } -bn_t node_t::pid_from_path(const std::string &path) { return pid_from_name(strext::tokenize(path, "/").back()); } +bn_t node_t::pid_from_path(const std::string& path) { return pid_from_name(strext::tokenize(path, "/").back()); } bn_t node_t::get_pid() const { return pid_from_name(name); } -const node_t *node_t::find(const pname_t &name) const { +const node_t* node_t::find(const pname_t& name) const { if (this->name == name) return this; for (const auto child : children) { - const node_t *res = child->find(name); + const node_t* res = child->find(name); if (res) return res; } return nullptr; } -static void list_leaf_paths_recursive(const node_t *node, const std::string &parent_path, - std::vector &list) { +static void list_leaf_paths_recursive(const node_t* node, const std::string& parent_path, + std::vector& list) { std::string path = get_node_path(parent_path, node); if (node->type == node_e::LEAF) { list.push_back(path); } else { - for (const node_t *child : node->children) list_leaf_paths_recursive(child, path, list); + for (const node_t* child : node->children) list_leaf_paths_recursive(child, path, list); } } @@ -212,11 +221,11 @@ std::vector node_t::list_leaf_paths() const { return list; } -static void list_leaf_names_recursive(const node_t *node, std::set &list) { +static void list_leaf_names_recursive(const node_t* node, std::set& list) { if (node->type == node_e::LEAF) { list.insert(node->name); } else { - for (const node_t *child : node->children) list_leaf_names_recursive(child, list); + for (const node_t* child : node->children) list_leaf_names_recursive(child, list); } } @@ -226,7 +235,7 @@ std::set node_t::list_leaf_names() const { return list; } -bool node_t::enough_for_quorum(const std::set &names) const { +bool node_t::enough_for_quorum(const std::set& names) const { int count = 0; switch (type) { @@ -258,10 +267,10 @@ bool node_t::enough_for_quorum(const std::set &names) const { return false; } -static void share_recursive(const mod_t &q, const ecc_point_t &G, const node_t *node, const bn_t &a, - const bool output_additional_data, ac_shares_t &ac_shares, - ac_internal_shares_t &ac_internal_shares, ac_internal_pub_shares_t &ac_internal_pub_shares, - drbg_aes_ctr_t *drbg) { +static void share_recursive(const mod_t& q, const ecc_point_t& G, const node_t* node, const bn_t& a, + const bool output_additional_data, ac_shares_t& ac_shares, + ac_internal_shares_t& ac_internal_shares, ac_internal_pub_shares_t& ac_internal_pub_shares, + drbg_aes_ctr_t* drbg) { auto sorted_children = node->get_sorted_children(); int n = int(sorted_children.size()); @@ -307,33 +316,40 @@ static void share_recursive(const mod_t &q, const ecc_point_t &G, const node_t * } } -ac_shares_t ac_t::share(const mod_t &q, const bn_t &x, drbg_aes_ctr_t *drbg) const { +ac_shares_t ac_t::share(const mod_t& q, const bn_t& x, drbg_aes_ctr_t* drbg) const { ac_shares_t shares; ac_internal_shares_t dummy; ac_internal_pub_shares_t dummy_pub; bool output_additional_data = false; - share_recursive(q, G, root, x, output_additional_data, shares, dummy, dummy_pub, drbg); + ecc_point_t dummy_G; + share_recursive(q, dummy_G, root, x, output_additional_data, shares, dummy, dummy_pub, drbg); return shares; } -error_t ac_t::share_with_internals(const mod_t &q, const bn_t &x, ac_shares_t &shares, - ac_internal_shares_t &ac_internal_shares, - ac_internal_pub_shares_t &ac_internal_pub_shares, drbg_aes_ctr_t *drbg) const { +error_t ac_t::share_with_internals(const mod_t& q, const bn_t& x, ac_shares_t& shares, + ac_internal_shares_t& ac_internal_shares, + ac_internal_pub_shares_t& ac_internal_pub_shares, drbg_aes_ctr_t* drbg) const { + if (!root) return coinbase::error(E_BADARG, "missing root"); + if (!curve.valid()) return coinbase::error(E_BADARG, "missing curve"); + if (q != curve.order()) return coinbase::error(E_BADARG, "invalid modulus"); bool output_additional_data = true; - share_recursive(q, G, root, x, output_additional_data, shares, ac_internal_shares, ac_internal_pub_shares, drbg); + share_recursive(q, curve.generator(), root, x, output_additional_data, shares, ac_internal_shares, + ac_internal_pub_shares, drbg); return SUCCESS; } -error_t ac_t::verify_share_against_ancestors_pub_data(const ecc_point_t &Q, const bn_t &si, - const ac_internal_pub_shares_t &pub_data, - const pname_t &leaf) const { - vartime_scope_t vartime_scope; +error_t ac_t::verify_share_against_ancestors_pub_data(const ecc_point_t& Q, const bn_t& si, + const ac_internal_pub_shares_t& pub_data, + const pname_t& leaf) const { + if (!curve.valid()) return coinbase::error(E_BADARG, "missing curve"); + if (Q.get_curve() != curve) return coinbase::error(E_BADARG, "curve mismatch"); auto node = find(leaf); if (node == nullptr || node->type != node_e::LEAF) return coinbase::error(E_NOT_FOUND); - ecc_point_t expected_pub_share = si * G; - const node_t *child = nullptr; + ecc_point_t expected_pub_share = si * curve.generator(); + vartime_scope_t vartime_scope; + const node_t* child = nullptr; while (node != nullptr) { auto sorted_children = node->get_sorted_children(); @@ -346,7 +362,7 @@ error_t ac_t::verify_share_against_ancestors_pub_data(const ecc_point_t &Q, cons return coinbase::error(E_CRYPTO); } } else if (node->type == node_e::AND) { - ecc_point_t expected_sum = Q.get_curve().infinity(); + ecc_point_t expected_sum = curve.infinity(); for (size_t i = 0; i < sorted_children.size(); i++) { auto child_pub_shares = pub_data.at(sorted_children[i]->name); expected_sum += child_pub_shares; @@ -384,17 +400,20 @@ error_t ac_t::verify_share_against_ancestors_pub_data(const ecc_point_t &Q, cons return SUCCESS; } -static error_t reconstruct_recursive(const mod_t &q, const node_t *node, const ac_shares_t &shares, bn_t &x) { +static error_t reconstruct_recursive(const mod_t& q, const node_t* node, const ac_shares_t& shares, bn_t& x) { error_t rv = UNINITIALIZED_ERROR; int n = node->get_n(); switch (node->type) { case node_e::LEAF: { - const auto &[found, share] = lookup(shares, node->name); + const auto& [found, share] = lookup(shares, node->name); if (!found) { + // Missing leaf shares are expected in threshold/OR reconstructions. + // Do not emit stack traces / dylog output for this control-flow condition. + dylog_disable_scope_t dylog_disable_scope; return coinbase::error(E_INSUFFICIENT); } - x = share; + x = *share; } break; case node_e::OR: for (int i = 0; i < n; i++) { @@ -449,23 +468,23 @@ static error_t reconstruct_recursive(const mod_t &q, const node_t *node, const a return SUCCESS; } -error_t ac_t::reconstruct(const mod_t &q, const ac_shares_t &shares, bn_t &x) const { +error_t ac_t::reconstruct(const mod_t& q, const ac_shares_t& shares, bn_t& x) const { return reconstruct_recursive(q, root, shares, x); } -static error_t reconstruct_exponent_recursive(const node_t *node, const ac_pub_shares_t &shares, ecc_point_t &P) { +static error_t reconstruct_exponent_recursive(const node_t* node, const ac_pub_shares_t& shares, ecc_point_t& P) { error_t rv = UNINITIALIZED_ERROR; int n = node->get_n(); - const pname_t &name = node->name; + const pname_t& name = node->name; switch (node->type) { case node_e::LEAF: { - const auto &[found, share] = lookup(shares, name); + const auto& [found, share] = lookup(shares, name); if (!found) { dylog_disable_scope_t dylog_disable_scope; return coinbase::error(E_INSUFFICIENT, "missing share for leaf node " + name); } - P = share; + P = *share; } break; case node_e::OR: @@ -525,15 +544,23 @@ static error_t reconstruct_exponent_recursive(const node_t *node, const ac_pub_s return SUCCESS; } -error_t ac_t::reconstruct_exponent(const ac_pub_shares_t &shares, ecc_point_t &P) const { +error_t ac_t::reconstruct_exponent(const ac_pub_shares_t& shares, ecc_point_t& P) const { + if (!root) return coinbase::error(E_BADARG, "missing root"); + if (!curve.valid()) return coinbase::error(E_BADARG, "missing curve"); + allow_ecc_infinity_t allow_ecc_infinity; + for (const auto& [name, share] : shares) { + error_t rv = curve.check(share); + if (rv) return coinbase::error(rv, "invalid share point for " + name); + } + return reconstruct_exponent_recursive(root, shares, P); } -static void list_pub_data_nodes_recursive(const node_t *node, std::set &node_set) { +static void list_pub_data_nodes_recursive(const node_t* node, std::set& node_set) { if (node->type == node_e::LEAF) { return; } - for (const node_t *child : node->children) { + for (const node_t* child : node->children) { list_pub_data_nodes_recursive(child, node_set); } if (node->type == node_e::AND || node->type == node_e::THRESHOLD) { @@ -541,8 +568,8 @@ static void list_pub_data_nodes_recursive(const node_t *node, std::set ac_t::list_pub_data_nodes() const { - std::set nodes; +std::set ac_t::list_pub_data_nodes() const { + std::set nodes; list_pub_data_nodes_recursive(root, nodes); return nodes; } diff --git a/src/cbmpc/crypto/secret_sharing.h b/src/cbmpc/crypto/secret_sharing.h deleted file mode 100644 index a2493083..00000000 --- a/src/cbmpc/crypto/secret_sharing.h +++ /dev/null @@ -1,176 +0,0 @@ -#pragma once - -#include -#include - -namespace coinbase::crypto::ss { - -template -using party_map_t = std::map; - -std::vector share_and(const mod_t &q, const bn_t &x, const int n, crypto::drbg_aes_ctr_t *drbg = nullptr); -std::pair, std::vector> share_threshold(const mod_t &q, const bn_t &a, const int threshold, - const int n, const std::vector &pids, - crypto::drbg_aes_ctr_t *drbg = nullptr); - -enum class node_e { - NONE = 0, - LEAF = 1, - AND = 2, - OR = 3, - THRESHOLD = 4, -}; - -class node_t; - -typedef party_map_t ac_shares_t; -typedef party_map_t ac_internal_shares_t; -typedef party_map_t ac_pub_shares_t; -typedef party_map_t ac_internal_pub_shares_t; - -class ac_t; -class ac_owned_t; - -struct node_t { - friend class ac_t; - friend class ac_owned_t; - - node_e type; - pname_t name; - int threshold; - std::vector children; - node_t *parent = nullptr; - - node_t(node_e _type, pname_t _name, int _threshold = 0) : type(_type), name(_name), threshold(_threshold) {} - - node_t(node_e _type, pname_t _name, int _threshold, std::initializer_list nodes) - : type(_type), name(_name), threshold(_threshold), children(nodes) { - for (auto child : nodes) { - child->parent = this; - } - } - - ~node_t(); - node_t *clone() const; - - int get_n() const { return int(children.size()); } - std::string get_path() const; - - static bn_t pid_from_path(const std::string &path); - bn_t get_pid() const; - - std::vector list_leaf_paths() const; - std::set list_leaf_names() const; - const node_t *find(const pname_t &path) const; - void add_child_node(node_t *node); - void remove_and_delete(); - - error_t validate_tree() const { - std::set names; - return validate_tree(names); - } - error_t validate_tree(std::set &names) const; - bool enough_for_quorum(const std::set &names) const; - - std::vector get_sorted_children() const; - - private: - node_t() {} - void convert_node(coinbase::converter_t &c); -}; - -static std::string get_node_path(const std::string &parent_path, const node_t *node) { - if (!node->parent) return ""; - return parent_path + "/" + node->name; -} - -class ac_t { - public: - explicit ac_t() {} - explicit ac_t(const node_t *_root) : root(_root) {} - - const node_t *get_root() const { return root; } - bool has_root() const { return root != nullptr; } - - error_t validate_tree() const { - if (!root) return coinbase::error(E_BADARG, "missing root"); - return root->validate_tree(); - } - - const node_t *find(const pname_t &name) const { return root->find(name); } - std::set list_leaf_names() const { return root->list_leaf_names(); } - std::set list_pub_data_nodes() const; - int get_pub_data_size(const node_t *node) const { - if (node->type == node_e::AND) - return node->get_n(); - else if (node->type == node_e::THRESHOLD) - return node->threshold; - else - return 0; - } - - bool enough_for_quorum(const std::set names) const { return root ? root->enough_for_quorum(names) : false; } - template - bool enough_for_quorum(const party_map_t &map) const { - std::set names; - for (const auto &[name, value] : map) names.insert(name); - return root ? root->enough_for_quorum(names) : false; - } - - /** - * @specs: - * - basic-primitives-spec | ac-Share-1P - */ - ac_shares_t share(const mod_t &q, const bn_t &x, drbg_aes_ctr_t *drbg = nullptr) const; - error_t share_with_internals(const mod_t &q, const bn_t &x, ac_shares_t &shares, - ac_internal_shares_t &ac_internal_shares, - ac_internal_pub_shares_t &ac_internal_pub_shares, drbg_aes_ctr_t *drbg = nullptr) const; - error_t verify_share_against_ancestors_pub_data(const ecc_point_t &Q, const bn_t &si, - const ac_internal_pub_shares_t &pub_data, const pname_t &leaf) const; - - /** - * @specs: - * - basic-primitives-spec | ac-Reconstruct-1P - */ - error_t reconstruct(const mod_t &q, const ac_shares_t &shares, bn_t &x) const; - - /** - * @specs: - * - basic-primitives-spec | ac-Reconstruct-Exponent-1P - */ - error_t reconstruct_exponent(const ac_pub_shares_t &shares, ecc_point_t &P) const; - - const node_t *root = nullptr; - ecc_point_t G; -}; - -class ac_owned_t : public ac_t { - public: - ac_owned_t() = default; - explicit ac_owned_t(const node_t *_root) { assign(_root); } - explicit ac_owned_t(const ac_t &ac) { assign(ac.root); } - ~ac_owned_t() { delete root; } - void assign(const node_t *_root) { - delete root; - root = _root->clone(); - } - ac_owned_t(const ac_owned_t &src) : ac_t() { assign(src.root); } - ac_owned_t(ac_owned_t &&src) : ac_t() { - root = src.root; - src.root = nullptr; - } - ac_owned_t &operator=(const ac_owned_t &src) { - if (&src != this) assign(src.root); - return *this; - } - ac_owned_t &operator=(ac_owned_t &&src) { - if (&src != this) { - root = src.root; - src.root = nullptr; - } - return *this; - } - void convert(coinbase::converter_t &c); -}; - -} // namespace coinbase::crypto::ss \ No newline at end of file diff --git a/src/cbmpc/crypto/tdh2.cpp b/src/cbmpc/crypto/tdh2.cpp index cb77153c..25d41a85 100644 --- a/src/cbmpc/crypto/tdh2.cpp +++ b/src/cbmpc/crypto/tdh2.cpp @@ -1,6 +1,5 @@ -#include "tdh2.h" - -#include +#include +#include namespace coinbase::crypto::tdh2 { @@ -9,11 +8,10 @@ constexpr int tag_size = 16; ciphertext_t public_key_t::encrypt(mem_t plain, mem_t label) const { const auto& curve = Q.get_curve(); - const mod_t& q = curve.order(); buf_t iv = gen_random(iv_size); - bn_t r = bn_t::rand(q); - bn_t s = bn_t::rand(q); + bn_t r = curve.get_random_value(); + bn_t s = curve.get_random_value(); return encrypt(plain, label, r, s, iv); } @@ -51,10 +49,12 @@ error_t ciphertext_t::verify(const public_key_t& pub_key, mem_t label) const { const mod_t& q = curve.order(); if (label != L) return coinbase::error(E_CRYPTO, "ciphertext_t::verify: label mismatch"); + if (iv.size() != iv_size) return coinbase::error(E_CRYPTO, "ciphertext_t::verify: invalid iv"); + if (!q.is_in_range(e) || !q.is_in_range(f)) return coinbase::error(E_CRYPTO, "ciphertext_t::verify: invalid scalar"); if (rv = curve.check(R1)) return coinbase::error(rv, "ciphertext_t::verify: check R1 failed"); if (rv = curve.check(R2)) return coinbase::error(rv, "ciphertext_t::verify: check R2 failed"); - if (Gamma != ro::hash_curve(mem_t("TDH2-Gamma"), Q).curve(Q.get_curve())) + if (Gamma != ro::hash_curve(mem_t("TDH2-Gamma"), Q, pub_key.sid).curve(Q.get_curve())) return coinbase::error(E_CRYPTO, "ciphertext_t::verify: Gamma mismatch"); ecc_point_t W1 = f * G - e * R1; @@ -79,7 +79,7 @@ error_t private_share_t::decrypt(const ciphertext_t& ciphertext, mem_t label, bn_t& ei = partial_decryption.ei; bn_t& fi = partial_decryption.fi; - partial_decryption.pid = pid; + partial_decryption.rid = rid; Xi = x * R1; bn_t si = curve.get_random_value(); @@ -112,6 +112,7 @@ error_t partial_decryption_t::check_partial_decryption_helper(const ecc_point_t& const auto& G = curve.generator(); const mod_t& q = curve.order(); + if (!q.is_in_range(ei) || !q.is_in_range(fi)) return coinbase::error(E_CRYPTO); const ecc_point_t& R1 = ciphertext.R1; ecc_point_t Yi = fi * R1 - ei * Xi; @@ -140,9 +141,9 @@ error_t combine_additive(const public_key_t& pub_key, const pub_shares_t& Qi, me for (int i = 0; i < n; i++) { const partial_decryption_t& partial_decryption = partial_decryptions[i]; - int pid = partial_decryption.pid; - if (pid < 1 || pid > n) return coinbase::error(E_CRYPTO); - if (rv = partial_decryption.check_partial_decryption_helper(Qi[pid - 1], ciphertext, curve)) return rv; + const int rid = partial_decryption.rid; + if (rid < 1 || rid > n) return coinbase::error(E_CRYPTO); + if (rv = partial_decryption.check_partial_decryption_helper(Qi[rid - 1], ciphertext, curve)) return rv; V += partial_decryption.Xi; } diff --git a/src/cbmpc/ffi/CMakeLists.txt b/src/cbmpc/ffi/CMakeLists.txt deleted file mode 100644 index 52bbbeec..00000000 --- a/src/cbmpc/ffi/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -add_library(cbmpc_ffi OBJECT - cmem_adapter.cpp - pki.cpp -) - -target_include_directories(cbmpc_ffi PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..) - -# Ensure OpenSSL headers/libs are available (matches other components) -link_openssl(cbmpc_ffi) - - diff --git a/src/cbmpc/ffi/cmem_adapter.cpp b/src/cbmpc/ffi/cmem_adapter.cpp deleted file mode 100644 index bbceacc1..00000000 --- a/src/cbmpc/ffi/cmem_adapter.cpp +++ /dev/null @@ -1,83 +0,0 @@ -#include "cmem_adapter.h" - -#include -#include - -extern "C" { -// NOLINTNEXTLINE(cppcoreguidelines-no-malloc) -void* cgo_malloc(int size) { return std::malloc(static_cast(size)); } - -// NOLINTNEXTLINE(cppcoreguidelines-no-malloc) -void cgo_free(void* ptr) { std::free(ptr); } -} // extern "C" - -namespace coinbase::ffi { - -buf_t copy_from_cmem_and_free(cmem_t cmem) { - buf_t buf(cmem.data, cmem.size); - cgo_free(cmem.data); - return buf; -} - -cmem_t copy_to_cmem(mem_t mem) { - cmem_t out{nullptr, mem.size}; - if (mem.size > 0) { - out.data = static_cast(cgo_malloc(mem.size)); - if (out.data) std::memmove(out.data, mem.data, mem.size); - } - return out; -} - -cmem_t copy_to_cmem(const buf_t& buf) { return copy_to_cmem(mem_t(buf)); } - -std::vector view_cmems(cmems_t cmems) { - std::vector out; - if (cmems.count == 0) return out; - out.reserve(cmems.count); - int offset = 0; - for (int i = 0; i < cmems.count; i++) { - const int sz = cmems.sizes[i]; - out.emplace_back(cmems.data + offset, sz); - offset += sz; - } - return out; -} - -std::vector bufs_from_cmems(cmems_t cmems) { - auto mems = view_cmems(cmems); - std::vector bufs; - bufs.reserve(mems.size()); - for (const auto& m : mems) bufs.emplace_back(m); - return bufs; -} - -cmems_t copy_to_cmems(const std::vector& mems) { - cmems_t out{0, nullptr, nullptr}; - const auto count = static_cast(mems.size()); - if (count == 0) return out; - - // Calculate total bytes. - int total = 0; - for (const auto& m : mems) total += m.size; - - out.count = count; - out.data = static_cast(cgo_malloc(total)); - out.sizes = static_cast(cgo_malloc(sizeof(int) * count)); - if (!out.data || !out.sizes) { - cgo_free(out.data); - cgo_free(out.sizes); - return cmems_t{0, nullptr, nullptr}; - } - - int offset = 0; - for (int i = 0; i < count; i++) { - out.sizes[i] = mems[i].size; - if (mems[i].size) { - std::memmove(out.data + offset, mems[i].data, mems[i].size); - offset += mems[i].size; - } - } - return out; -} - -} // namespace coinbase::ffi diff --git a/src/cbmpc/ffi/cmem_adapter.h b/src/cbmpc/ffi/cmem_adapter.h deleted file mode 100644 index 087c66c7..00000000 --- a/src/cbmpc/ffi/cmem_adapter.h +++ /dev/null @@ -1,35 +0,0 @@ -#pragma once - -#include - -#include -#include - -// C-callable allocators used by FFI layers (e.g., cgo). -extern "C" { -void* cgo_malloc(int size); -void cgo_free(void* ptr); -} - -namespace coinbase::ffi { - -// Non-owning view of a cmem_t as mem_t. -inline mem_t view(cmem_t cmem) { return mem_t(cmem.data, cmem.size); } - -// Copy cmem into a new buf_t and free the source buffer. -buf_t copy_from_cmem_and_free(cmem_t cmem); - -// Copy mem/buf into freshly allocated cmem_t owned by the caller. -cmem_t copy_to_cmem(mem_t mem); -cmem_t copy_to_cmem(const buf_t& buf); - -// Non-owning view of cmems_t (no freeing). -std::vector view_cmems(cmems_t cmems); - -// Copy cmems_t into new buffers (does not free the source). -std::vector bufs_from_cmems(cmems_t cmems); - -// Convert a flat list of mem views into cmems_t (data + sizes). -cmems_t copy_to_cmems(const std::vector& mems); - -} // namespace coinbase::ffi diff --git a/src/cbmpc/ffi/pki.cpp b/src/cbmpc/ffi/pki.cpp deleted file mode 100644 index 2ad9a9d9..00000000 --- a/src/cbmpc/ffi/pki.cpp +++ /dev/null @@ -1,15 +0,0 @@ -#include "pki.h" - -// Weak stubs for callback getters so core C++ unit/integration tests can link -// without any language-specific FFI layer (Go, Python, Rust, …). When an FFI -// layer is linked, its strong definitions override these stubs. -extern "C" { - -__attribute__((weak)) ffi_verify_fn get_ffi_verify_fn(void) { return nullptr; } -__attribute__((weak)) ffi_sign_fn get_ffi_sign_fn(void) { return nullptr; } - -__attribute__((weak)) ffi_kem_encap_fn get_ffi_kem_encap_fn(void) { return nullptr; } -__attribute__((weak)) ffi_kem_decap_fn get_ffi_kem_decap_fn(void) { return nullptr; } -__attribute__((weak)) ffi_kem_dk_to_ek_fn get_ffi_kem_dk_to_ek_fn(void) { return nullptr; } - -} // extern "C" diff --git a/src/cbmpc/ffi/pki.h b/src/cbmpc/ffi/pki.h deleted file mode 100644 index 3194904a..00000000 --- a/src/cbmpc/ffi/pki.h +++ /dev/null @@ -1,114 +0,0 @@ -#pragma once - -#include -#include - -#include "cmem_adapter.h" - -namespace coinbase::crypto { - -// External KEM types (encapsulate/decapsulate via FFI) -struct ffi_kem_ek_t : public buf_t { - using buf_t::buf_t; - using buf_t::operator=; -}; - -struct ffi_kem_dk_t { - void* handle = nullptr; // Opaque process-local handle to the private key - - ffi_kem_dk_t() = default; - explicit ffi_kem_dk_t(void* h) : handle(h) {} - - // Derive the public key using user-supplied callback. - ffi_kem_ek_t pub() const { - ffi_kem_dk_to_ek_fn derive_fn = get_ffi_kem_dk_to_ek_fn(); - cb_assert(derive_fn && "ffi_kem_dk_to_ek_fn not set"); - - cmem_t out{}; - int rc = derive_fn(static_cast(handle), &out); - cb_assert(rc == 0 && "ffi_kem_dk_to_ek_fn failed"); - - ffi_kem_ek_t ek; - ek = ffi::copy_from_cmem_and_free(out); - return ek; - } -}; - -// Opaque container for the KEM ciphertext produced by the external PKI. -struct ffi_kem_ct_t : public buf_t { - using buf_t::operator=; - using buf_t::buf_t; -}; - -// Policy adapter that uses the external KEM FFI: -// - encapsulate: produce (kem_ct, kem_ss) -// - decapsulate: recover kem_ss from kem_ct -struct kem_policy_ffi_t { - using ek_t = ffi_kem_ek_t; - using dk_t = ffi_kem_dk_t; - - static error_t encapsulate(const ek_t& pub_key, buf_t& kem_ct, buf_t& kem_ss, drbg_aes_ctr_t* drbg) { - ffi_kem_encap_fn enc_fn = get_ffi_kem_encap_fn(); - if (!enc_fn) return E_BADARG; - constexpr int rho_size = 32; - buf_t rho = drbg ? drbg->gen(rho_size) : gen_random(rho_size); - cmem_t ct_out{}; - cmem_t ss_out{}; - int rc = enc_fn(cmem_t{pub_key.data(), pub_key.size()}, cmem_t{rho.data(), rho.size()}, &ct_out, &ss_out); - if (rc) return E_CRYPTO; - kem_ct = ffi::copy_from_cmem_and_free(ct_out); - kem_ss = ffi::copy_from_cmem_and_free(ss_out); - return SUCCESS; - } - - static error_t decapsulate(const dk_t& prv_key, mem_t kem_ct, buf_t& kem_ss) { - ffi_kem_decap_fn dec_fn = get_ffi_kem_decap_fn(); - if (!dec_fn) return E_BADARG; - cmem_t ss_out{}; - cmem_t kem_ct_c = cmem_t{kem_ct.data, kem_ct.size}; - int rc = dec_fn(static_cast(prv_key.handle), kem_ct_c, &ss_out); - if (rc) return E_CRYPTO; - kem_ss = ffi::copy_from_cmem_and_free(ss_out); - return SUCCESS; - } -}; - -// External Signing types -struct ffi_sign_sk_t : public buf_t { - using buf_t::buf_t; - using buf_t::operator=; - - ffi_sign_sk_t(const buf_t& other) : buf_t(other) {} - ffi_sign_sk_t(buf_t&& other) : buf_t(std::move(other)) {} - - buf_t sign(mem_t hash) const { - ffi_sign_fn sign_fn = get_ffi_sign_fn(); - if (!sign_fn) return buf_t(); - cmem_t out{}; - int rc = sign_fn(cmem_t{this->data(), this->size()}, cmem_t{hash.data, hash.size}, &out); - if (rc) return buf_t(); - return ffi::copy_from_cmem_and_free(out); - } -}; - -struct ffi_sign_vk_t : public buf_t { - using buf_t::buf_t; - using buf_t::operator=; - - // Allow construction from a signing key (they share format here) - ffi_sign_vk_t(const ffi_sign_sk_t& sk) : buf_t(sk) {} - - error_t verify(mem_t hash, mem_t signature) const { - ffi_verify_fn verify_fn = get_ffi_verify_fn(); - if (!verify_fn) return E_BADARG; - int rc = verify_fn(cmem_t{this->data(), this->size()}, cmem_t{hash.data, hash.size}, - cmem_t{signature.data, signature.size}); - if (rc) return E_CRYPTO; - return SUCCESS; - } -}; - -using ffi_pke_t = hybrid_pke_t>; -using ffi_sign_scheme_t = sign_scheme_t; - -} // namespace coinbase::crypto diff --git a/src/cbmpc/protocol/CMakeLists.txt b/src/cbmpc/protocol/CMakeLists.txt index 9255d675..b881d34b 100644 --- a/src/cbmpc/protocol/CMakeLists.txt +++ b/src/cbmpc/protocol/CMakeLists.txt @@ -1,12 +1,11 @@ add_library(cbmpc_protocol OBJECT "") -target_precompile_headers(cbmpc_protocol PUBLIC "mpc_job.h") +target_precompile_headers(cbmpc_protocol PUBLIC "${ROOT_DIR}/include-internal/cbmpc/internal/protocol/mpc_job.h") target_sources(cbmpc_protocol PRIVATE int_commitment.cpp mpc_job.cpp - mpc_job_session.cpp agree_random.cpp ot.cpp diff --git a/src/cbmpc/protocol/agree_random.cpp b/src/cbmpc/protocol/agree_random.cpp index c9c83e2f..a677ee7c 100644 --- a/src/cbmpc/protocol/agree_random.cpp +++ b/src/cbmpc/protocol/agree_random.cpp @@ -1,7 +1,6 @@ -#include "agree_random.h" - -#include -#include +#include +#include +#include namespace coinbase::mpc { diff --git a/src/cbmpc/protocol/data_transport.h b/src/cbmpc/protocol/data_transport.h deleted file mode 100644 index 7c728175..00000000 --- a/src/cbmpc/protocol/data_transport.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once -#include -#include - -#include - -namespace coinbase::mpc { -using party_idx_t = int32_t; // forward declaration to avoid including mpc_job.h - -class data_transport_interface_t { - public: - virtual error_t send(party_idx_t receiver, mem_t msg) = 0; - virtual error_t receive(party_idx_t sender, buf_t& msg) = 0; - virtual error_t receive_all(const std::vector& senders, std::vector& message) = 0; - virtual ~data_transport_interface_t() = default; -}; - -} // namespace coinbase::mpc diff --git a/src/cbmpc/protocol/ec_dkg.cpp b/src/cbmpc/protocol/ec_dkg.cpp index 7892d405..086872f2 100644 --- a/src/cbmpc/protocol/ec_dkg.cpp +++ b/src/cbmpc/protocol/ec_dkg.cpp @@ -1,12 +1,10 @@ -#include "ec_dkg.h" - -#include -#include -#include -#include -#include - -#include "util.h" +#include +#include +#include +#include +#include +#include +#include using namespace coinbase::crypto::ss; @@ -263,9 +261,9 @@ error_t key_share_mp_t::refresh(job_mp_t& job, buf_t& sid, const key_share_mp_t& return SUCCESS; } -error_t key_share_mp_t::threshold_dkg_or_refresh(job_mp_t& job, const ecurve_t& curve, buf_t& sid, - const crypto::ss::ac_t ac, const party_set_t& quorum_party_set, - key_share_mp_t& key, key_share_mp_t& new_key, bool is_refresh) { +error_t key_share_mp_t::dkg_or_refresh_ac(job_mp_t& job, const ecurve_t& curve, buf_t& sid, const crypto::ss::ac_t ac, + const party_set_t& quorum_party_set, key_share_mp_t& key, + key_share_mp_t& new_key, bool is_refresh) { error_t rv = UNINITIALIZED_ERROR; const auto& G = curve.generator(); @@ -461,18 +459,17 @@ error_t key_share_mp_t::threshold_dkg_or_refresh(job_mp_t& job, const ecurve_t& return SUCCESS; } -error_t key_share_mp_t::threshold_dkg(job_mp_t& job, const ecurve_t& curve, buf_t& sid, const crypto::ss::ac_t ac, - const party_set_t& quorum_party_set, key_share_mp_t& key) { +error_t key_share_mp_t::dkg_ac(job_mp_t& job, const ecurve_t& curve, buf_t& sid, const crypto::ss::ac_t ac, + const party_set_t& quorum_party_set, key_share_mp_t& key) { key_share_mp_t dummy_new_key; bool is_refresh = false; - return threshold_dkg_or_refresh(job, curve, sid, ac, quorum_party_set, key, dummy_new_key, is_refresh); + return dkg_or_refresh_ac(job, curve, sid, ac, quorum_party_set, key, dummy_new_key, is_refresh); } -error_t key_share_mp_t::threshold_refresh(job_mp_t& job, const ecurve_t& curve, buf_t& sid, const crypto::ss::ac_t ac, - const party_set_t& quorum_party_set, key_share_mp_t& key, - key_share_mp_t& new_key) { +error_t key_share_mp_t::refresh_ac(job_mp_t& job, const ecurve_t& curve, buf_t& sid, const crypto::ss::ac_t ac, + const party_set_t& quorum_party_set, key_share_mp_t& key, key_share_mp_t& new_key) { bool is_refresh = true; - return threshold_dkg_or_refresh(job, curve, sid, ac, quorum_party_set, key, new_key, is_refresh); + return dkg_or_refresh_ac(job, curve, sid, ac, quorum_party_set, key, new_key, is_refresh); } error_t key_share_mp_t::reconstruct_additive_share(const mod_t& q, const node_t* node, @@ -531,10 +528,10 @@ error_t key_share_mp_t::reconstruct_additive_share(const mod_t& q, const node_t* break; case node_e::THRESHOLD: { - std::vector pids(node->threshold); + std::vector quorum_pids; + quorum_pids.reserve(n); bn_t share = 0; bn_t share_pid = 0; - int count = 0; for (int i = 0; i < n; i++) { bn_t share_from_child; @@ -545,25 +542,40 @@ error_t key_share_mp_t::reconstruct_additive_share(const mod_t& q, const node_t* } if (rv) return rv; + if (!child_is_in_quorum) continue; + + const bn_t child_pid = node->children[i]->get_pid(); + quorum_pids.push_back(child_pid); + if (share_from_child != 0) { - share_pid = node->children[i]->get_pid(); + share_pid = child_pid; share = share_from_child; } - - if (count < node->threshold && child_is_in_quorum) { - pids[count] = node->children[i]->get_pid(); - count++; - } - if (count == node->threshold && share_pid != 0) break; } - if (count < node->threshold) { + if (int(quorum_pids.size()) < node->threshold) { dylog_disable_scope_t dylog_disable_scope; return coinbase::error(E_INSUFFICIENT); } is_in_quorum = true; - additive_share = crypto::lagrange_partial_interpolate(0, {share}, {share_pid}, pids, q); + // Target party is outside the selected quorum subtree for this threshold node. + if (share_pid == 0) { + additive_share = 0; + break; + } + + std::vector interp_pids; + interp_pids.reserve(node->threshold); + interp_pids.push_back(share_pid); + for (const auto& pid : quorum_pids) { + if (int(interp_pids.size()) == node->threshold) break; + if (pid == share_pid) continue; + interp_pids.push_back(pid); + } + cb_assert(int(interp_pids.size()) == node->threshold); + + additive_share = crypto::lagrange_partial_interpolate(0, {share}, {share_pid}, interp_pids, q); } break; case node_e::NONE: { return coinbase::error(E_CRYPTO, "key_share_mp_t::reconstruct_additive_share: none node"); @@ -631,10 +643,10 @@ error_t key_share_mp_t::reconstruct_pub_additive_shares(const crypto::ss::node_t break; case node_e::THRESHOLD: { - std::vector pids(node->threshold); + std::vector quorum_pids; + quorum_pids.reserve(n); ecc_point_t share = curve.infinity(); bn_t share_pid = 0; - int count = 0; for (int i = 0; i < n; i++) { ecc_point_t share_from_child = curve.infinity(); @@ -646,25 +658,40 @@ error_t key_share_mp_t::reconstruct_pub_additive_shares(const crypto::ss::node_t } if (rv) return rv; + if (!child_is_in_quorum) continue; + + const bn_t child_pid = node->children[i]->get_pid(); + quorum_pids.push_back(child_pid); + if (!share_from_child.is_infinity()) { - share_pid = node->children[i]->get_pid(); + share_pid = child_pid; share = share_from_child; } - - if (count < node->threshold && child_is_in_quorum) { - pids[count] = node->children[i]->get_pid(); - count++; - } - if (count == node->threshold && share_pid != 0) break; } - if (count < node->threshold) { + if (int(quorum_pids.size()) < node->threshold) { dylog_disable_scope_t dylog_disable_scope; return coinbase::error(E_INSUFFICIENT); } is_in_quorum = true; - pub_additive_shares = crypto::lagrange_partial_interpolate_exponent(0, {share}, {share_pid}, pids); + // Target party is outside the selected quorum subtree for this threshold node. + if (share_pid == 0) { + pub_additive_shares = curve.infinity(); + break; + } + + std::vector interp_pids; + interp_pids.reserve(node->threshold); + interp_pids.push_back(share_pid); + for (const auto& pid : quorum_pids) { + if (int(interp_pids.size()) == node->threshold) break; + if (pid == share_pid) continue; + interp_pids.push_back(pid); + } + cb_assert(int(interp_pids.size()) == node->threshold); + + pub_additive_shares = crypto::lagrange_partial_interpolate_exponent(0, {share}, {share_pid}, interp_pids); } break; case node_e::NONE: { return coinbase::error(E_CRYPTO, "key_share_mp_t::reconstruct_pub_additive_shares: none node"); diff --git a/src/cbmpc/protocol/ecdsa_2p.cpp b/src/cbmpc/protocol/ecdsa_2p.cpp index 60d23ef2..4c88e774 100644 --- a/src/cbmpc/protocol/ecdsa_2p.cpp +++ b/src/cbmpc/protocol/ecdsa_2p.cpp @@ -1,12 +1,10 @@ -#include "ecdsa_2p.h" - -#include -#include -#include -#include -#include - -#include "util.h" +#include +#include +#include +#include +#include +#include +#include namespace coinbase::mpc::ecdsa2pc { @@ -370,6 +368,12 @@ error_t sign_batch_impl(job_2p_t& job, buf_t& sid, const key_t& key, const std:: return coinbase::error(rv, "zk_ecdsa_sign_2pc_integer_commit_t::verify failed"); } + { + crypto::vartime_scope_t vartime_scope; + if (rv = key.paillier.verify_cipher(c[i])) + return coinbase::error(rv, "ecdsa_2p: invalid Paillier ciphertext from counterparty"); + } + bn_t s = key.paillier.decrypt(c[i]); s = q.mod(s); diff --git a/src/cbmpc/protocol/ecdsa_mp.cpp b/src/cbmpc/protocol/ecdsa_mp.cpp index e999821c..2cd9d8a7 100644 --- a/src/cbmpc/protocol/ecdsa_mp.cpp +++ b/src/cbmpc/protocol/ecdsa_mp.cpp @@ -1,12 +1,10 @@ -#include "ecdsa_mp.h" - -#include -#include -#include -#include -#include - -#include "util.h" +#include +#include +#include +#include +#include +#include +#include using namespace coinbase::mpc; @@ -26,14 +24,14 @@ error_t refresh(job_mp_t& job, buf_t& sid, key_t& key, key_t& new_key) { return eckey::key_share_mp_t::refresh(job, sid, key, new_key); } -error_t threshold_dkg(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, - const party_set_t& quorum_party_set, key_t& key) { - return eckey::key_share_mp_t::threshold_dkg(job, curve, sid, ac, quorum_party_set, key); +error_t dkg_ac(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, + const party_set_t& quorum_party_set, key_t& key) { + return eckey::key_share_mp_t::dkg_ac(job, curve, sid, ac, quorum_party_set, key); } -error_t threshold_refresh(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, - const party_set_t& quorum_party_set, key_t& key, key_t& new_key) { - return eckey::key_share_mp_t::threshold_refresh(job, curve, sid, ac, quorum_party_set, key, new_key); +error_t refresh_ac(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, + const party_set_t& quorum_party_set, key_t& key, key_t& new_key) { + return eckey::key_share_mp_t::refresh_ac(job, curve, sid, ac, quorum_party_set, key, new_key); } error_t sign(job_mp_t& job, key_t& key, mem_t msg, const party_idx_t sig_receiver, diff --git a/src/cbmpc/protocol/eddsa.cpp b/src/cbmpc/protocol/eddsa.cpp index c84b5627..04140c85 100644 --- a/src/cbmpc/protocol/eddsa.cpp +++ b/src/cbmpc/protocol/eddsa.cpp @@ -1,4 +1,4 @@ -#include "eddsa.h" +#include namespace coinbase::mpc::eddsa2pc { diff --git a/src/cbmpc/protocol/hd_keyset_ecdsa_2p.cpp b/src/cbmpc/protocol/hd_keyset_ecdsa_2p.cpp index 4891bef7..291f4476 100644 --- a/src/cbmpc/protocol/hd_keyset_ecdsa_2p.cpp +++ b/src/cbmpc/protocol/hd_keyset_ecdsa_2p.cpp @@ -1,14 +1,11 @@ #include - -#include "hd_keyset_ecdsa_2p.h" - -#include -#include -#include -#include -#include - -#include "ec_dkg.h" +#include +#include +#include +#include +#include +#include +#include using namespace coinbase; diff --git a/src/cbmpc/protocol/hd_keyset_eddsa_2p.cpp b/src/cbmpc/protocol/hd_keyset_eddsa_2p.cpp index e12ed5ee..79b9ee7d 100644 --- a/src/cbmpc/protocol/hd_keyset_eddsa_2p.cpp +++ b/src/cbmpc/protocol/hd_keyset_eddsa_2p.cpp @@ -1,15 +1,12 @@ #include - -#include "hd_keyset_eddsa_2p.h" - -#include -#include -#include -#include -#include -#include - -#include "ec_dkg.h" +#include +#include +#include +#include +#include +#include +#include +#include using namespace coinbase; diff --git a/src/cbmpc/protocol/hd_tree_bip32.cpp b/src/cbmpc/protocol/hd_tree_bip32.cpp index 99510009..70a08270 100644 --- a/src/cbmpc/protocol/hd_tree_bip32.cpp +++ b/src/cbmpc/protocol/hd_tree_bip32.cpp @@ -1,4 +1,4 @@ -#include "hd_tree_bip32.h" +#include namespace coinbase::mpc { diff --git a/src/cbmpc/protocol/int_commitment.cpp b/src/cbmpc/protocol/int_commitment.cpp index 0cbcef59..e4ce12e8 100644 --- a/src/cbmpc/protocol/int_commitment.cpp +++ b/src/cbmpc/protocol/int_commitment.cpp @@ -1,7 +1,6 @@ -#include "int_commitment.h" - -#include -#include +#include +#include +#include namespace coinbase::crypto { @@ -185,7 +184,7 @@ unknown_order_pedersen_params_t::unknown_order_pedersen_params_t() { z_str[i++] = strdup("9130340523321926272646990853417370038246155271170179947744837498945093083918369505121825035642905085882612022313486190849414014926827617638936812641156205771369190239803479284488724400226760954524960554895495988761809767942777164170315841970551814286171643167034381931535010840481809260886551448387589429940369813833273720645832961952092083624307446361239081395560482989899883795369068826490408731703011866911409160003142092586648089306664342893166839255512060998626110596952719610858703015224466036386602297023203256982857428427159082738503923519979729078671601444454253521664147647619127385074665320804577489732119207709414698586838539417256185822604256"); // clang-format on - assert(i == SEC_P_COM); + cb_assert(i == SEC_P_COM); }; ; } // namespace coinbase::crypto \ No newline at end of file diff --git a/src/cbmpc/protocol/mpc_job.cpp b/src/cbmpc/protocol/mpc_job.cpp index 826676b0..9fbde233 100644 --- a/src/cbmpc/protocol/mpc_job.cpp +++ b/src/cbmpc/protocol/mpc_job.cpp @@ -1,4 +1,4 @@ -#include "mpc_job.h" +#include namespace coinbase::mpc { @@ -14,8 +14,8 @@ error_t job_mp_t::send_to_parties(party_set_t set, const std::vector& in) // default implementation simply by receiving one by one error_t job_mp_t::receive_many_impl(std::vector from_set, std::vector& outs) { - if (!transport_ptr) return E_NET_GENERAL; - return transport_ptr->receive_all(from_set, outs); + if (!transport_raw) return E_NET_GENERAL; + return transport_raw->receive_all(from_set, outs); } error_t job_mp_t::receive_from_parties(party_set_t set, std::vector& v) { diff --git a/src/cbmpc/protocol/ot.cpp b/src/cbmpc/protocol/ot.cpp index cd14f29a..c8031229 100644 --- a/src/cbmpc/protocol/ot.cpp +++ b/src/cbmpc/protocol/ot.cpp @@ -1,6 +1,5 @@ -#include "ot.h" - -#include +#include +#include namespace coinbase::mpc { @@ -162,7 +161,8 @@ static void matrix_transposition(uint8_t const* inp, uint8_t* out, int nrows, in for (int rr = 0; rr < nrows; rr += 16) { for (int cc = 0; cc < ncols; cc += 8) { for (int i = 0; i < 16; ++i) tmp.b[i] = INP_BYTE(rr + i, cc); - for (int i = 8; --i >= 0; tmp = lshift64x2(tmp)) *(uint16_t*)&OUT_BYTE(rr, cc + i) = high16x8(tmp); + for (int i = 8; --i >= 0; tmp = lshift64x2(tmp)) + coinbase::le_set_2(byte_ptr(&OUT_BYTE(rr, cc + i)), high16x8(tmp)); } } } diff --git a/src/cbmpc/protocol/pve.cpp b/src/cbmpc/protocol/pve.cpp index 114e6b4a..3d67892c 100644 --- a/src/cbmpc/protocol/pve.cpp +++ b/src/cbmpc/protocol/pve.cpp @@ -1,26 +1,28 @@ -#include "pve.h" - #include -#include +#include +#include namespace coinbase::mpc { -ec_pve_t::ec_pve_t() : base_pke(pve_base_pke_unified()) {} - -void ec_pve_t::encrypt(const void* ek, mem_t label, ecurve_t curve, const bn_t& _x) { +error_t ec_pve_t::encrypt(const pve_base_pke_i& base_pke, pve_keyref_t ek, mem_t label, ecurve_t curve, + const bn_t& _x) { + error_t rv = UNINITIALIZED_ERROR; const auto& G = curve.generator(); const mod_t& q = curve.order(); bn_t bn_x = _x % q; - Q = bn_x * G; + ecc_point_t Q_local = bn_x * G; buf128_t r0[kappa]; buf128_t r1[kappa]; buf_t c0[kappa]; buf_t c1[kappa]; ecc_point_t X0[kappa]; ecc_point_t X1[kappa]; - L = buf_t(label); - buf_t inner_label = genPVELabelWithPoint(label, Q); + buf_t L_local = buf_t(label); + buf_t inner_label = genPVELabelWithPoint(label, Q_local); + bn_t x_rows_local[kappa]; + buf128_t r_local[kappa]; + buf_t c_local[kappa]; for (int i = 0; i < kappa; i++) { bn_t x0, x1; @@ -36,25 +38,36 @@ void ec_pve_t::encrypt(const void* ek, mem_t label, ecurve_t curve, const bn_t& MODULO(q) x1 = bn_x - x0; buf_t rho1 = drbg1.gen(rho_size); - base_pke.encrypt(ek, inner_label, x0.to_bin(), rho0, c0[i]); + if (rv = base_pke.encrypt(ek, inner_label, x0.to_bin(), rho0, c0[i])) return rv; X0[i] = x0 * G; - base_pke.encrypt(ek, inner_label, x1.to_bin(), rho1, c1[i]); - X1[i] = Q - X0[i]; + if (rv = base_pke.encrypt(ek, inner_label, x1.to_bin(), rho1, c1[i])) return rv; + X1[i] = Q_local - X0[i]; - x_rows[i] = x1; // output. will be cleared out if later bi == 0 + x_rows_local[i] = x1; // output. will be cleared out if later bi == 0 } - b = crypto::ro::hash_string(Q, label, c0, c1, X0, X1).bitlen(kappa); + buf128_t b_local; + b_local = crypto::ro::hash_string(Q_local, label, c0, c1, X0, X1).bitlen(kappa); for (int i = 0; i < kappa; i++) { - bool bi = b.get_bit(i); - r[i] = bi ? r1[i] : r0[i]; - c[i] = bi ? c0[i] : c1[i]; - if (!bi) x_rows[i] = 0; // clear the output + bool bi = b_local.get_bit(i); + r_local[i] = bi ? r1[i] : r0[i]; + c_local[i] = bi ? c0[i] : c1[i]; + if (!bi) x_rows_local[i] = 0; // clear the output } + + Q = std::move(Q_local); + L = std::move(L_local); + b = b_local; + for (int i = 0; i < kappa; i++) { + x_rows[i] = std::move(x_rows_local[i]); + r[i] = r_local[i]; + c[i] = std::move(c_local[i]); + } + return SUCCESS; } -error_t ec_pve_t::verify(const void* ek, const ecc_point_t& Q, mem_t label) const { +error_t ec_pve_t::verify(const pve_base_pke_i& base_pke, pve_keyref_t ek, const ecc_point_t& Q, mem_t label) const { error_t rv = UNINITIALIZED_ERROR; ecurve_t curve = Q.get_curve(); if (rv = curve.check(Q)) return coinbase::error(rv, "ec_pve_t::verify: check Q failed"); @@ -64,6 +77,7 @@ error_t ec_pve_t::verify(const void* ek, const ecc_point_t& Q, mem_t label) cons const auto& G = curve.generator(); const mod_t& q = curve.order(); + const int max_scalar_size = q.get_bin_size(); buf_t c0[kappa]; buf_t c1[kappa]; @@ -71,17 +85,23 @@ error_t ec_pve_t::verify(const void* ek, const ecc_point_t& Q, mem_t label) cons ecc_point_t X1[kappa]; for (int i = 0; i < kappa; i++) { + if (x_rows[i].get_bin_size() > max_scalar_size) return coinbase::error(E_CRYPTO); + if (!q.is_in_range(x_rows[i])) return coinbase::error(E_CRYPTO); + bool bi = b.get_bit(i); crypto::drbg_aes_ctr_t drbg(r[i]); - bn_t xi = x_rows[i]; - if (!bi) xi = drbg.gen_bn(q); + bn_t xi; + if (bi) + xi = x_rows[i]; + else + xi = drbg.gen_bn(q); buf_t rho = drbg.gen(rho_size); X0[i] = xi * G; X1[i] = Q - X0[i]; - base_pke.encrypt(ek, inner_label, xi.to_bin(), rho, c0[i]); + if (rv = base_pke.encrypt(ek, inner_label, xi.to_bin(), rho, c0[i])) return rv; c1[i] = c[i]; if (bi) { @@ -98,12 +118,19 @@ error_t ec_pve_t::verify(const void* ek, const ecc_point_t& Q, mem_t label) cons error_t ec_pve_t::restore_from_decrypted(int row_index, mem_t decrypted_x_buf, ecurve_t curve, bn_t& x_value) const { const mod_t& q = curve.order(); const auto& G = curve.generator(); + const int max_scalar_size = q.get_bin_size(); bool bi = b.get_bit(row_index); + if (decrypted_x_buf.size > max_scalar_size) return coinbase::error(E_CRYPTO); bn_t x_bi_bar = bn_t::from_bin(decrypted_x_buf); - bn_t x_bi = x_rows[row_index]; - - if (!bi) { + if (!q.is_in_range(x_bi_bar)) return coinbase::error(E_CRYPTO); + + bn_t x_bi; + if (bi) { + if (x_rows[row_index].get_bin_size() > max_scalar_size) return coinbase::error(E_CRYPTO); + if (!q.is_in_range(x_rows[row_index])) return coinbase::error(E_CRYPTO); + x_bi = x_rows[row_index]; + } else { crypto::drbg_aes_ctr_t drbg0(r[row_index]); x_bi = drbg0.gen_bn(q); } @@ -117,10 +144,10 @@ error_t ec_pve_t::restore_from_decrypted(int row_index, mem_t decrypted_x_buf, e return SUCCESS; } -error_t ec_pve_t::decrypt(const void* dk, const void* ek, mem_t label, ecurve_t curve, bn_t& x_out, - bool skip_verify) const { +error_t ec_pve_t::decrypt(const pve_base_pke_i& base_pke, pve_keyref_t dk, pve_keyref_t ek, mem_t label, ecurve_t curve, + bn_t& x_out, bool skip_verify) const { error_t rv = UNINITIALIZED_ERROR; - if (!skip_verify && (rv = verify(ek, Q, label))) return rv; + if (!skip_verify && (rv = verify(base_pke, ek, Q, label))) return rv; buf_t inner_label = genPVELabelWithPoint(label, Q); diff --git a/src/cbmpc/protocol/pve.h b/src/cbmpc/protocol/pve.h deleted file mode 100644 index c25e3129..00000000 --- a/src/cbmpc/protocol/pve.h +++ /dev/null @@ -1,90 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace coinbase::mpc { - -class ec_pve_t { - public: - // Default to unified PKE when not provided explicitly - ec_pve_t(); - explicit ec_pve_t(const pve_base_pke_i& base_pke) : base_pke(base_pke) {} - - // Custom copy/move ctors bind the reference member correctly - ec_pve_t(const ec_pve_t& other) : base_pke(other.base_pke), L(other.L), Q(other.Q), b(other.b) { - for (int i = 0; i < kappa; ++i) { - x_rows[i] = other.x_rows[i]; - r[i] = other.r[i]; - c[i] = other.c[i]; - } - } - ec_pve_t(ec_pve_t&& other) noexcept - : base_pke(other.base_pke), L(std::move(other.L)), Q(std::move(other.Q)), b(other.b) { - for (int i = 0; i < kappa; ++i) { - x_rows[i] = std::move(other.x_rows[i]); - r[i] = std::move(other.r[i]); - c[i] = std::move(other.c[i]); - } - } - // Assignment operators copy payload fields; reference member remains bound - ec_pve_t& operator=(const ec_pve_t& other) { - if (this == &other) return *this; - L = other.L; - Q = other.Q; - b = other.b; - for (int i = 0; i < kappa; ++i) { - x_rows[i] = other.x_rows[i]; - r[i] = other.r[i]; - c[i] = other.c[i]; - } - return *this; - } - ec_pve_t& operator=(ec_pve_t&& other) noexcept { - if (this == &other) return *this; - L = std::move(other.L); - Q = std::move(other.Q); - b = other.b; - for (int i = 0; i < kappa; ++i) { - x_rows[i] = std::move(other.x_rows[i]); - r[i] = std::move(other.r[i]); - c[i] = std::move(other.c[i]); - } - return *this; - } - - const static int kappa = SEC_P_COM; - const static int rho_size = 32; - - void encrypt(const void* ek, mem_t label, ecurve_t curve, const bn_t& x); - error_t verify(const void* ek, const ecc_point_t& Q, mem_t label) const; - error_t decrypt(const void* dk, const void* ek, mem_t label, ecurve_t curve, bn_t& x, bool skip_verify = false) const; - - const ecc_point_t& get_Q() const { return Q; } - const buf_t& get_Label() const { return L; } - - void convert(coinbase::converter_t& converter) { - converter.convert(Q, L, b); - for (int i = 0; i < kappa; i++) { - converter.convert(x_rows[i]); - converter.convert(r[i]); - converter.convert(c[i]); - } - } - - private: - const pve_base_pke_i& base_pke; - - buf_t L; - ecc_point_t Q; - buf128_t b; - - bn_t x_rows[kappa]; - buf128_t r[kappa]; - buf_t c[kappa]; - - error_t restore_from_decrypted(int row_index, mem_t decrypted_x_buf, ecurve_t curve, bn_t& x_value) const; -}; - -} // namespace coinbase::mpc diff --git a/src/cbmpc/protocol/pve_ac.cpp b/src/cbmpc/protocol/pve_ac.cpp index 2c507713..64c249c5 100644 --- a/src/cbmpc/protocol/pve_ac.cpp +++ b/src/cbmpc/protocol/pve_ac.cpp @@ -1,4 +1,4 @@ -#include "pve_ac.h" +#include using namespace coinbase::crypto; @@ -20,10 +20,10 @@ static error_t batch_from_bin(ecurve_t curve, int batch_size, mem_t bin, std::ve return SUCCESS; } -ec_pve_ac_t::ec_pve_ac_t() : base_pke(pve_base_pke_unified()), rows(kappa) {} - -void ec_pve_ac_t::encrypt_row(const ss::ac_t& ac, const pks_t& ac_pks, mem_t L, ecurve_t curve, mem_t seed, mem_t plain, - buf_t& c, std::vector& quorum_c) const { +error_t ec_pve_ac_t::encrypt_row(const pve_base_pke_i& base_pke, const ss::ac_t& ac, const pks_t& ac_pks, mem_t L, + ecurve_t curve, mem_t seed, mem_t plain, buf_t& c, + std::vector& quorum_c) const { + error_t rv = UNINITIALIZED_ERROR; const mod_t& q = curve.order(); crypto::drbg_aes_ctr_t drbg(seed); bn_t K = drbg.gen_bn(q); @@ -32,7 +32,7 @@ void ec_pve_ac_t::encrypt_row(const ss::ac_t& ac, const pks_t& ac_pks, mem_t L, for (const auto& [path, pub_key_ptr] : ac_pks) { ciphertext_adapter_t item; buf_t ct_ser; - base_pke.encrypt(pub_key_ptr, L, K_shares[path].to_bin(), drbg.gen(32), ct_ser); + if (rv = base_pke.encrypt(pub_key_ptr, L, K_shares[path].to_bin(), drbg.gen(32), ct_ser)) return rv; item.ct_ser = ct_ser; quorum_c.push_back(std::move(item)); } @@ -42,35 +42,38 @@ void ec_pve_ac_t::encrypt_row(const ss::ac_t& ac, const pks_t& ac_pks, mem_t L, mem_t iv = k_and_iv.skip(32); crypto::aes_gcm_t::encrypt(k_aes, iv, L, tag_size, plain, c); + return SUCCESS; } -void ec_pve_ac_t::encrypt_row0(const ss::ac_t& ac, const pks_t& ac_pks, mem_t L, ecurve_t curve, mem_t r0_1, mem_t r0_2, - int batch_size, std::vector& x0, buf_t& c0, - std::vector& quorum_c0) const { +error_t ec_pve_ac_t::encrypt_row0(const pve_base_pke_i& base_pke, const ss::ac_t& ac, const pks_t& ac_pks, mem_t L, + ecurve_t curve, mem_t r0_1, mem_t r0_2, int batch_size, std::vector& x0, + buf_t& c0, std::vector& quorum_c0) const { const mod_t& q = curve.order(); x0.resize(batch_size); crypto::drbg_aes_ctr_t drbg(r0_1); for (int j = 0; j < batch_size; j++) x0[j] = drbg.gen_bn(q); - encrypt_row(ac, ac_pks, L, curve, - r0_2, // seed - r0_1, // plain - c0, // output - quorum_c0 // output + return encrypt_row(base_pke, ac, ac_pks, L, curve, + r0_2, // seed + r0_1, // plain + c0, // output + quorum_c0 // output ); } -void ec_pve_ac_t::encrypt_row1(const ss::ac_t& ac, const pks_t& ac_pks, mem_t L, ecurve_t curve, mem_t r1, mem_t x1_bin, - buf_t& c1, std::vector& quorum_c1) const { - encrypt_row(ac, ac_pks, L, curve, - r1, // seed - x1_bin, // plain - c1, // output - quorum_c1 // output +error_t ec_pve_ac_t::encrypt_row1(const pve_base_pke_i& base_pke, const ss::ac_t& ac, const pks_t& ac_pks, mem_t L, + ecurve_t curve, mem_t r1, mem_t x1_bin, buf_t& c1, + std::vector& quorum_c1) const { + return encrypt_row(base_pke, ac, ac_pks, L, curve, + r1, // seed + x1_bin, // plain + c1, // output + quorum_c1 // output ); } -void ec_pve_ac_t::encrypt(const ss::ac_t& ac, const pks_t& ac_pks, mem_t label, ecurve_t curve, - const std::vector& _x) { +error_t ec_pve_ac_t::encrypt(const pve_base_pke_i& base_pke, const ss::ac_t& ac, const pks_t& ac_pks, mem_t label, + ecurve_t curve, const std::vector& _x) { + error_t rv = UNINITIALIZED_ERROR; int batch_size = int(_x.size()); const auto& G = curve.generator(); const mod_t& q = curve.order(); @@ -102,14 +105,15 @@ void ec_pve_ac_t::encrypt(const ss::ac_t& ac, const pks_t& ac_pks, mem_t label, r1[i] = crypto::gen_random_bitlen(SEC_P_COM); std::vector x0; - encrypt_row0(ac, ac_pks, L, curve, r0_1[i], r0_2[i], batch_size, x0, c0[i], quorum_c0[i]); + if (rv = encrypt_row0(base_pke, ac, ac_pks, L, curve, r0_1[i], r0_2[i], batch_size, x0, c0[i], quorum_c0[i])) + return rv; std::vector x1(batch_size); for (int j = 0; j < batch_size; j++) MODULO(q) x1[j] = x[j] - x0[j]; row_t& row = rows[i]; row.x_bin = batch_to_bin(curve, x1); - encrypt_row1(ac, ac_pks, L, curve, r1[i], row.x_bin, c1[i], quorum_c1[i]); + if (rv = encrypt_row1(base_pke, ac, ac_pks, L, curve, r1[i], row.x_bin, c1[i], quorum_c1[i])) return rv; for (int j = 0; j < batch_size; j++) { X0[i][j] = x0[j] * G; @@ -126,10 +130,11 @@ void ec_pve_ac_t::encrypt(const ss::ac_t& ac, const pks_t& ac_pks, mem_t label, rows[i].quorum_c = bit ? quorum_c0[i] : quorum_c1[i]; if (!bit) rows[i].x_bin.free(); // clear output } + return SUCCESS; } -error_t ec_pve_ac_t::verify(const ss::ac_t& ac, const pks_t& ac_pks, const std::vector& Q, - mem_t label) const { +error_t ec_pve_ac_t::verify(const pve_base_pke_i& base_pke, const ss::ac_t& ac, const pks_t& ac_pks, + const std::vector& Q, mem_t label) const { error_t rv = UNINITIALIZED_ERROR; int batch_size = int(Q.size()); if (batch_size == 0) return coinbase::error(E_BADARG); @@ -164,14 +169,14 @@ error_t ec_pve_ac_t::verify(const ss::ac_t& ac, const pks_t& ac_pks, const std:: quorum_c0[i] = row.quorum_c; if (rv = batch_from_bin(curve, batch_size, row.x_bin, xb)) return rv; mem_t r1 = row.r; - encrypt_row1(ac, ac_pks, L, curve, r1, row.x_bin, c1[i], quorum_c1[i]); + if (rv = encrypt_row1(base_pke, ac, ac_pks, L, curve, r1, row.x_bin, c1[i], quorum_c1[i])) return rv; } else { c1[i] = row.c; quorum_c1[i] = row.quorum_c; if (row.r.size() != 32) return coinbase::error(E_CRYPTO); mem_t r0_1 = row.r.take(16); mem_t r0_2 = row.r.skip(16); - encrypt_row0(ac, ac_pks, L, curve, r0_1, r0_2, batch_size, xb, c0[i], quorum_c0[i]); + if (rv = encrypt_row0(base_pke, ac, ac_pks, L, curve, r0_1, r0_2, batch_size, xb, c0[i], quorum_c0[i])) return rv; } for (int j = 0; j < batch_size; j++) { @@ -200,8 +205,9 @@ error_t ec_pve_ac_t::find_quorum_ciphertext(const std::vector& sort return SUCCESS; } -error_t ec_pve_ac_t::party_decrypt_row(const ss::ac_t& ac, int row_index, const std::string& path, - const void* prv_key_ptr, mem_t label, bn_t& out_share) const { +error_t ec_pve_ac_t::party_decrypt_row(const pve_base_pke_i& base_pke, const ss::ac_t& ac, int row_index, + const std::string& path, pve_keyref_t prv_key, mem_t label, + bn_t& out_share) const { error_t rv = UNINITIALIZED_ERROR; if (row_index < 0 || row_index >= kappa) return coinbase::error(E_RANGE); if (Q.empty()) return coinbase::error(E_BADARG); @@ -217,20 +223,23 @@ error_t ec_pve_ac_t::party_decrypt_row(const ss::ac_t& ac, int row_index, const if (rv = find_quorum_ciphertext(sorted_leaves, path, row, c)) return rv; buf_t plain; - if (rv = base_pke.decrypt(prv_key_ptr, L, c->ct_ser, plain)) return rv; + if (rv = base_pke.decrypt(prv_key, L, c->ct_ser, plain)) return rv; out_share = bn_t::from_bin(plain); return SUCCESS; } -error_t ec_pve_ac_t::aggregate_to_restore_row(const ss::ac_t& ac, int row_index, mem_t label, - const std::map& quorum_decrypted, std::vector& x, - bool skip_verify, const pks_t& all_ac_pks) const { +error_t ec_pve_ac_t::aggregate_to_restore_row(const pve_base_pke_i& base_pke, const ss::ac_t& ac, int row_index, + mem_t label, const std::map& quorum_decrypted, + std::vector& x, bool skip_verify, const pks_t& all_ac_pks) const { error_t rv = UNINITIALIZED_ERROR; if (row_index < 0 || row_index >= kappa) return coinbase::error(E_RANGE); if (Q.empty()) return coinbase::error(E_BADARG); - if (!skip_verify && !all_ac_pks.empty()) { - if (rv = verify(ac, all_ac_pks, Q, label)) return rv; + if (!skip_verify) { + if (all_ac_pks.empty()) { + return coinbase::error(E_BADARG, "all_ac_pks is required when skip_verify is false"); + } + if (rv = verify(base_pke, ac, all_ac_pks, Q, label)) return rv; } const row_t& row = rows[row_index]; diff --git a/src/cbmpc/protocol/pve_ac.h b/src/cbmpc/protocol/pve_ac.h deleted file mode 100644 index 99cd35b4..00000000 --- a/src/cbmpc/protocol/pve_ac.h +++ /dev/null @@ -1,101 +0,0 @@ -#pragma once - -#include -#include - -#include "pve.h" - -namespace coinbase::mpc { - -class ec_pve_ac_t { - public: - struct ciphertext_adapter_t { - buf_t ct_ser; - void convert(coinbase::converter_t& converter) { converter.convert(ct_ser); } - }; - - typedef std::map pks_t; // maps leaf path -> encryption key pointer - typedef std::map sks_t; // maps leaf path -> decryption key pointer - - static constexpr int kappa = SEC_P_COM; - static constexpr std::size_t iv_size = crypto::KEM_AEAD_IV_SIZE; - static constexpr std::size_t tag_size = crypto::KEM_AEAD_TAG_SIZE; - static constexpr std::size_t iv_bitlen = iv_size * 8; - - // Default to unified PKE when not provided explicitly - ec_pve_ac_t(); - explicit ec_pve_ac_t(const pve_base_pke_i& base_pke) : base_pke(base_pke), rows(kappa) {} - - void convert(coinbase::converter_t& converter) { - converter.convert(Q, L, b); - - for (int i = 0; i < kappa; i++) { - converter.convert(rows[i].x_bin); - converter.convert(rows[i].r); - converter.convert(rows[i].c); - converter.convert(rows[i].quorum_c); - } - } - - /** - * @specs: - * - publicly-verifiable-encryption-spec | vencrypt-batch-many-1P - */ - void encrypt(const crypto::ss::ac_t& ac, const pks_t& ac_pks, mem_t label, ecurve_t curve, - const std::vector& x); - - /** - * @specs: - * - publicly-verifiable-encryption-spec | vverify-batch-many-1P - */ - error_t verify(const crypto::ss::ac_t& ac, const pks_t& ac_pks, const std::vector& Q, mem_t label) const; - - /** - * @specs: - * - publicly-verifiable-encryption-spec | vdecrypt-local-batch-many-1P - * - * @notes: - * Each party calls party_decrypt_row to produce its share for a specific row. - * Then, the caller aggregates shares using aggregate_to_restore_row to recover x. - * This is different from the spec since the decryption is not done in a loop, rather at each - * invocation, a single row is decrypted. As a result, it is the responsibility of the caller application - * to call this api multiple times if needed. - */ - error_t party_decrypt_row(const crypto::ss::ac_t& ac, int row_index, const std::string& path, const void* prv_key_ptr, - mem_t label, bn_t& out_share) const; - - /** - * @specs: - * - publicly-verifiable-encryption-spec | vdecrypt-combine-batch-many-1P - */ - error_t aggregate_to_restore_row(const crypto::ss::ac_t& ac, int row_index, mem_t label, - const std::map& quorum_decrypted, std::vector& x, - bool skip_verify = false, const pks_t& all_ac_pks = pks_t()) const; - const std::vector& get_Q() const { return Q; } - - private: - const pve_base_pke_i& base_pke; - std::vector Q; - buf_t L; - buf128_t b; - struct row_t { - buf_t x_bin, r, c; - std::vector quorum_c; - }; - std::vector rows; - - void encrypt_row(const crypto::ss::ac_t& ac, const pks_t& ac_pks, mem_t label, ecurve_t curve, mem_t seed, - mem_t plain, buf_t& c, std::vector& quorum_c) const; - - void encrypt_row0(const crypto::ss::ac_t& ac, const pks_t& ac_pks, mem_t label, ecurve_t curve, mem_t r0_1, - mem_t r0_2, int batch_size, std::vector& x0, buf_t& c, - std::vector& quorum_c) const; - - void encrypt_row1(const crypto::ss::ac_t& ac, const pks_t& ac_pks, mem_t label, ecurve_t curve, mem_t r1, - mem_t x1_bin, buf_t& c, std::vector& quorum_c) const; - - static error_t find_quorum_ciphertext(const std::vector& sorted_leaves, const std::string& path, - const row_t& row, const ciphertext_adapter_t*& c); -}; - -} // namespace coinbase::mpc diff --git a/src/cbmpc/protocol/pve_base.cpp b/src/cbmpc/protocol/pve_base.cpp index 6ba57ed3..0470f0b2 100644 --- a/src/cbmpc/protocol/pve_base.cpp +++ b/src/cbmpc/protocol/pve_base.cpp @@ -1,4 +1,4 @@ -#include "pve_base.h" +#include namespace coinbase::mpc { @@ -7,20 +7,24 @@ namespace { // Generic helper to invoke a specific HPKE-like type from the base_pke. template struct pve_base_pke_impl_t : public pve_base_pke_i { - error_t encrypt(const void* ek, mem_t label, mem_t plain, mem_t rho, buf_t& out_ct) const override { + error_t encrypt(pve_keyref_t ek, mem_t label, mem_t plain, mem_t rho, buf_t& out_ct) const override { + const EK* pub_key = ek.get(); + if (!pub_key) return coinbase::error(E_BADARG, "invalid encryption key"); crypto::drbg_aes_ctr_t drbg(rho); CT ct; error_t rv = UNINITIALIZED_ERROR; - if (rv = ct.encrypt(*static_cast(ek), label, plain, &drbg)) return rv; + if (rv = ct.encrypt(*pub_key, label, plain, &drbg)) return rv; out_ct = ser(ct); return SUCCESS; } - error_t decrypt(const void* dk, mem_t label, mem_t ct_ser, buf_t& out_plain) const override { + error_t decrypt(pve_keyref_t dk, mem_t label, mem_t ct_ser, buf_t& out_plain) const override { + const DK* prv_key = dk.get(); + if (!prv_key) return coinbase::error(E_BADARG, "invalid decryption key"); error_t rv = UNINITIALIZED_ERROR; CT ct; if (rv = deser(ct_ser, ct)) return rv; - if (rv = ct.decrypt(*static_cast(dk), label, out_plain)) return rv; + if (rv = ct.decrypt(*prv_key, label, out_plain)) return rv; return SUCCESS; } }; @@ -34,11 +38,8 @@ const pve_base_pke_impl_t> base_pke_ecies; -const pve_base_pke_impl_t base_pke_unified; - } // namespace -const pve_base_pke_i& pve_base_pke_unified() { return base_pke_unified; } const pve_base_pke_i& pve_base_pke_rsa() { return base_pke_rsa; } const pve_base_pke_i& pve_base_pke_ecies() { return base_pke_ecies; } diff --git a/src/cbmpc/protocol/pve_base.h b/src/cbmpc/protocol/pve_base.h deleted file mode 100644 index 321b7207..00000000 --- a/src/cbmpc/protocol/pve_base.h +++ /dev/null @@ -1,79 +0,0 @@ -#pragma once - -#include -#include - -namespace coinbase::mpc { - -struct pve_base_pke_i { - virtual ~pve_base_pke_i() = default; - virtual error_t encrypt(const void* ek, mem_t label, mem_t plain, mem_t rho, buf_t& out_ct) const = 0; - virtual error_t decrypt(const void* dk, mem_t label, mem_t ct, buf_t& out_plain) const = 0; -}; - -// Generic adapter that turns any KEM policy into a PVE base PKE via kem_aead_ciphertext_t -template -struct kem_pve_base_pke_t : public pve_base_pke_i { - using EK = typename KEM_POLICY::ek_t; - using DK = typename KEM_POLICY::dk_t; - using CT = crypto::kem_aead_ciphertext_t; - - error_t encrypt(const void* ek, mem_t label, mem_t plain, mem_t rho, buf_t& out_ct) const override { - crypto::drbg_aes_ctr_t drbg(rho); - CT ct; - error_t rv = ct.encrypt(*static_cast(ek), label, plain, &drbg); - if (rv) return rv; - out_ct = ser(ct); - return SUCCESS; - } - - error_t decrypt(const void* dk, mem_t label, mem_t ct_ser, buf_t& out_plain) const override { - error_t rv = UNINITIALIZED_ERROR; - CT ct; - if (rv = deser(ct_ser, ct)) return rv; - return ct.decrypt(*static_cast(dk), label, out_plain); - } -}; - -template -inline const pve_base_pke_i& kem_pve_base_pke() { - static const kem_pve_base_pke_t pke; - return pke; -} - -// Accessors to built-in base PKE implementations for testing and convenience -const pve_base_pke_i& pve_base_pke_unified(); -const pve_base_pke_i& pve_base_pke_rsa(); -const pve_base_pke_i& pve_base_pke_ecies(); - -/** - * @notes: - * - This is the underlying encryption used in PVE - */ -template -buf_t pve_base_encrypt(const typename HPKE_T::ek_t& pub_key, mem_t label, const buf_t& plaintext, mem_t rho) { - crypto::drbg_aes_ctr_t drbg(rho); - typename HPKE_T::ct_t ct; - ct.encrypt(pub_key, label, plaintext, &drbg); - return ser(ct); -} - -/** - * @notes: - * - This is the underlying decryption used in PVE - */ -template -error_t pve_base_decrypt(const typename HPKE_T::dk_t& prv_key, mem_t label, mem_t ciphertext, buf_t& plain) { - error_t rv = UNINITIALIZED_ERROR; - typename HPKE_T::ct_t ct; - if (rv = deser(ciphertext, ct)) return rv; - if (rv = ct.decrypt(prv_key, label, plain)) return rv; - return SUCCESS; -} - -template -static buf_t genPVELabelWithPoint(mem_t label, const T& Q) { - return buf_t(label) + "-" + strext::to_hex(crypto::sha256_t::hash(Q)); -} - -} // namespace coinbase::mpc \ No newline at end of file diff --git a/src/cbmpc/protocol/pve_batch.cpp b/src/cbmpc/protocol/pve_batch.cpp index 9af3fe91..7b9d7d90 100644 --- a/src/cbmpc/protocol/pve_batch.cpp +++ b/src/cbmpc/protocol/pve_batch.cpp @@ -1,13 +1,10 @@ -#include "pve_batch.h" +#include namespace coinbase::mpc { -ec_pve_batch_t::ec_pve_batch_t(int batch_count) : base_pke(pve_base_pke_unified()), n(batch_count), rows(kappa) { - cb_assert(batch_count > 0 && batch_count <= MAX_BATCH_COUNT); - Q.resize(n); -} - -void ec_pve_batch_t::encrypt(const void* ek, mem_t label, ecurve_t curve, const std::vector& _x) { +error_t ec_pve_batch_t::encrypt(const pve_base_pke_i& base_pke, pve_keyref_t ek, mem_t label, ecurve_t curve, + const std::vector& _x) { + error_t rv = UNINITIALIZED_ERROR; cb_assert(n > 0 && n <= MAX_BATCH_COUNT); cb_assert(int(_x.size()) == n); @@ -45,7 +42,9 @@ void ec_pve_batch_t::encrypt(const void* ek, mem_t label, ecurve_t curve, const buf_t rho0 = drbg02.gen(rho_size); buf_t rho1 = drbg1.gen(rho_size); - std::vector x0 = bn_t::vector_from_bin(x0_source_bin, n, curve_size + coinbase::bits_to_bytes(SEC_P_STAT), q); + std::vector x0; + if (rv = bn_t::vector_from_bin(x0_source_bin, n, curve_size + coinbase::bits_to_bytes(SEC_P_STAT), q, x0)) + return rv; std::vector x1(n); for (int j = 0; j < n; j++) { MODULO(q) x1[j] = x[j] - x0[j]; @@ -56,8 +55,8 @@ void ec_pve_batch_t::encrypt(const void* ek, mem_t label, ecurve_t curve, const buf_t x1_bin = bn_t::vector_to_bin(x1, curve_size); - base_pke.encrypt(ek, inner_label, r01[i], rho0, c0[i]); - base_pke.encrypt(ek, inner_label, x1_bin, rho1, c1[i]); + if (rv = base_pke.encrypt(ek, inner_label, r01[i], rho0, c0[i])) return rv; + if (rv = base_pke.encrypt(ek, inner_label, x1_bin, rho1, c1[i])) return rv; rows[i].x_bin = x1_bin; // some of these will be reset to zero later based on `bi` } @@ -69,9 +68,11 @@ void ec_pve_batch_t::encrypt(const void* ek, mem_t label, ecurve_t curve, const rows[i].c = bi ? c0[i] : c1[i]; if (!bi) rows[i].x_bin.free(); } + return SUCCESS; } -error_t ec_pve_batch_t::verify(const void* ek, const std::vector& Q, mem_t label) const { +error_t ec_pve_batch_t::verify(const pve_base_pke_i& base_pke, pve_keyref_t ek, const std::vector& Q, + mem_t label) const { error_t rv = UNINITIALIZED_ERROR; if (n <= 0 || n > MAX_BATCH_COUNT) return coinbase::error(E_BADARG); if (int(Q.size()) != n) return coinbase::error(E_BADARG); @@ -103,25 +104,26 @@ error_t ec_pve_batch_t::verify(const void* ek, const std::vector& Q if (bi) { c0[i] = rows[i].c; - xi = bn_t::vector_from_bin(rows[i].x_bin, n, curve_size, q); - if (rows[i].r.size() != 16) return coinbase::error(E_CRYPTO); + if (rows[i].x_bin.size() != n * curve_size) return coinbase::error(E_CRYPTO); + if (rv = bn_t::vector_from_bin(rows[i].x_bin, n, curve_size, q, xi)) return rv; crypto::drbg_aes_ctr_t drbg1(rows[i].r); buf_t rho1 = drbg1.gen(rho_size); - base_pke.encrypt(ek, inner_label, bn_t::vector_to_bin(xi, curve_size), rho1, c1[i]); + if (rv = base_pke.encrypt(ek, inner_label, bn_t::vector_to_bin(xi, curve_size), rho1, c1[i])) return rv; } else { c1[i] = rows[i].c; if (rows[i].r.size() != 32) return coinbase::error(E_CRYPTO); crypto::drbg_aes_ctr_t drbg01(rows[i].r.take(16)); buf_t x0_source_bin = drbg01.gen(n * (curve_size + coinbase::bits_to_bytes(SEC_P_STAT))); - xi = bn_t::vector_from_bin(x0_source_bin, n, curve_size + coinbase::bits_to_bytes(SEC_P_STAT), q); + if (rv = bn_t::vector_from_bin(x0_source_bin, n, curve_size + coinbase::bits_to_bytes(SEC_P_STAT), q, xi)) + return rv; crypto::drbg_aes_ctr_t drbg02(rows[i].r.skip(16)); buf_t rho0 = drbg02.gen(rho_size); - base_pke.encrypt(ek, inner_label, rows[i].r.take(16), rho0, c0[i]); + if (rv = base_pke.encrypt(ek, inner_label, rows[i].r.take(16), rho0, c0[i])) return rv; } X0[i].resize(n); @@ -141,28 +143,40 @@ error_t ec_pve_batch_t::verify(const void* ek, const std::vector& Q error_t ec_pve_batch_t::restore_from_decrypted(int row_index, mem_t decrypted_x_buf, ecurve_t curve, std::vector& x) const { - if (row_index > kappa) return coinbase::error(E_BADARG); + if (row_index < 0 || row_index >= kappa) return coinbase::error(E_BADARG); const mod_t& q = curve.order(); const auto& G = curve.generator(); int curve_size = curve.size(); - buf_t r01, x1_bin; - bool bi = b.get_bit(row_index); + const row_t& row = rows[row_index]; + const bool bi = b.get_bit(row_index); + + mem_t r01; + mem_t x1_bin; if (bi) { - x1_bin = rows[row_index].x_bin; + // When bi=1, the opened value is x1_bin (stored), and the unopened value is r01 (decrypted). + if (row.x_bin.size() != n * curve_size) return coinbase::error(E_CRYPTO); + x1_bin = row.x_bin; r01 = decrypted_x_buf; + if (r01.size != 16) return coinbase::error(E_CRYPTO); } else { + // When bi=0, the opened value is r01||r02 (stored in row.r), and the unopened value is x1_bin (decrypted). + if (row.r.size() != 32) return coinbase::error(E_CRYPTO); + r01 = row.r.take(16); x1_bin = decrypted_x_buf; - if (rows[row_index].r.size() != 32) return coinbase::error(E_CRYPTO); - r01 = rows[row_index].r.take(16); + if (x1_bin.size != n * curve_size) return coinbase::error(E_CRYPTO); } - crypto::drbg_aes_ctr_t drbg01(r01); // decrypted_x_buf = r01 + crypto::drbg_aes_ctr_t drbg01(r01); buf_t x0_source_bin = drbg01.gen(n * (curve_size + coinbase::bits_to_bytes(SEC_P_STAT))); - std::vector x0 = bn_t::vector_from_bin(x0_source_bin, n, curve_size + coinbase::bits_to_bytes(SEC_P_STAT), q); + std::vector x0; + error_t rv = bn_t::vector_from_bin(x0_source_bin, n, curve_size + coinbase::bits_to_bytes(SEC_P_STAT), q, x0); + if (rv) return rv; - std::vector x1 = bn_t::vector_from_bin(x1_bin, n, curve_size, q); + std::vector x1; + rv = bn_t::vector_from_bin(x1_bin, n, curve_size, q, x1); + if (rv) return rv; for (int i = 0; i < n; i++) { MODULO(q) x[i] = x0[i] + x1[i]; @@ -172,11 +186,11 @@ error_t ec_pve_batch_t::restore_from_decrypted(int row_index, mem_t decrypted_x_ return SUCCESS; } -error_t ec_pve_batch_t::decrypt(const void* dk, const void* ek, mem_t label, ecurve_t curve, std::vector& xs, - bool skip_verify) const { +error_t ec_pve_batch_t::decrypt(const pve_base_pke_i& base_pke, pve_keyref_t dk, pve_keyref_t ek, mem_t label, + ecurve_t curve, std::vector& xs, bool skip_verify) const { error_t rv = UNINITIALIZED_ERROR; xs.resize(n); - if (!skip_verify && (rv = verify(ek, Q, label))) return rv; + if (!skip_verify && (rv = verify(base_pke, ek, Q, label))) return rv; if (label != this->L) return coinbase::error(E_CRYPTO); buf_t inner_label = genPVELabelWithPoint(label, Q); diff --git a/src/cbmpc/protocol/pve_batch.h b/src/cbmpc/protocol/pve_batch.h deleted file mode 100644 index 810e144f..00000000 --- a/src/cbmpc/protocol/pve_batch.h +++ /dev/null @@ -1,108 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace coinbase::mpc { - -class ec_pve_batch_t { - public: - // Default to unified PKE when not provided explicitly - explicit ec_pve_batch_t(int batch_count); - ec_pve_batch_t(int batch_count, const pve_base_pke_i& base_pke) : base_pke(base_pke), n(batch_count), rows(kappa) { - cb_assert(batch_count > 0 && batch_count <= MAX_BATCH_COUNT); - Q.resize(n); - } - - // Custom copy/move ctors bind the reference member correctly - ec_pve_batch_t(const ec_pve_batch_t& other) - : base_pke(other.base_pke), n(other.n), L(other.L), Q(other.Q), b(other.b), rows(other.rows) {} - ec_pve_batch_t(ec_pve_batch_t&& other) noexcept - : base_pke(other.base_pke), - n(other.n), - L(std::move(other.L)), - Q(std::move(other.Q)), - b(other.b), - rows(std::move(other.rows)) {} - // Assignment operators copy payload fields; reference member remains bound - ec_pve_batch_t& operator=(const ec_pve_batch_t& other) { - if (this == &other) return *this; - n = other.n; - L = other.L; - Q = other.Q; - b = other.b; - rows = other.rows; - return *this; - } - ec_pve_batch_t& operator=(ec_pve_batch_t&& other) noexcept { - if (this == &other) return *this; - n = other.n; - L = std::move(other.L); - Q = std::move(other.Q); - b = other.b; - rows = std::move(other.rows); - return *this; - } - - const static int kappa = SEC_P_COM; - // Upper bound to prevent integer-overflow and unbounded memory allocation when `n` is untrusted. - // This is a defensive limit; callers should treat any larger batch as invalid input. - static constexpr int MAX_BATCH_COUNT = 100000; - // We assume the base encryption scheme requires 32 bytes of randomness. If it needs more, it can be changed to use - // DRBG with 32 bytes of randomness as the seed. - const static int rho_size = 32; - - /** - * @specs: - * - publicly-verifiable-encryption-spec | vencrypt-batch-1P - */ - void encrypt(const void* ek, mem_t label, ecurve_t curve, const std::vector& x); - - /** - * @specs: - * - publicly-verifiable-encryption-spec | vverify-batch-1P - */ - error_t verify(const void* ek, const std::vector& Q, mem_t label) const; - - /** - * @specs: - * - publicly-verifiable-encryption-spec | vdecrypt-batch-1P - */ - error_t decrypt(const void* dk, const void* ek, mem_t label, ecurve_t curve, std::vector& x, - bool skip_verify = false) const; - - void convert(coinbase::converter_t& converter) { - if (int(Q.size()) != n) { - converter.set_error(); - return; - } - - converter.convert(Q, L, b); - - for (int i = 0; i < kappa; i++) { - converter.convert(rows[i].x_bin); - converter.convert(rows[i].r); - converter.convert(rows[i].c); - } - } - - private: - const pve_base_pke_i& base_pke; - int n; - - buf_t L; - std::vector Q; - buf128_t b; - - struct row_t { - buf_t x_bin; - buf_t r; - buf_t c; - }; - std::vector rows; - - error_t restore_from_decrypted(int row_index, mem_t decrypted_x_buf, ecurve_t curve, std::vector& xs) const; -}; - -} // namespace coinbase::mpc \ No newline at end of file diff --git a/src/cbmpc/protocol/schnorr_2p.cpp b/src/cbmpc/protocol/schnorr_2p.cpp index 0f55a9dc..ce37ea41 100644 --- a/src/cbmpc/protocol/schnorr_2p.cpp +++ b/src/cbmpc/protocol/schnorr_2p.cpp @@ -1,10 +1,8 @@ -#include "schnorr_2p.h" - -#include -#include -#include - -#include "util.h" +#include +#include +#include +#include +#include namespace coinbase::mpc::schnorr2p { @@ -76,6 +74,8 @@ error_t sign_batch(job_2p_t& job, key_t& key, const std::vector& msgs, st if (variant == variant_e::BIP340) { if (curve != crypto::curve_secp256k1) return coinbase::error(E_BADARG, "BIP340 variant requires secp256k1 curve"); for (int i = 0; i < n_sigs; i++) { + if (msgs[i].size != 32) return coinbase::error(E_BADARG, "schnorr_2p: BIP340 msg size != 32"); + if (!msgs[i].data) return coinbase::error(E_BADARG, "schnorr_2p: BIP340 msg is null"); bn_t rx, ry; R[i].get_coordinates(rx, ry); if (ry.is_odd()) { @@ -133,4 +133,4 @@ error_t sign_batch(job_2p_t& job, key_t& key, const std::vector& msgs, st return SUCCESS; } -} // namespace coinbase::mpc::schnorr2p \ No newline at end of file +} // namespace coinbase::mpc::schnorr2p diff --git a/src/cbmpc/protocol/schnorr_mp.cpp b/src/cbmpc/protocol/schnorr_mp.cpp index 4861e9ee..1c2db3b0 100644 --- a/src/cbmpc/protocol/schnorr_mp.cpp +++ b/src/cbmpc/protocol/schnorr_mp.cpp @@ -1,14 +1,12 @@ -#include "schnorr_mp.h" - #include #include -#include -#include -#include -#include - -#include "util.h" +#include +#include +#include +#include +#include +#include #define _i msg #define _j received(j) @@ -26,14 +24,14 @@ error_t refresh(job_mp_t& job, buf_t& sid, key_t& key, key_t& new_key) { return eckey::key_share_mp_t::refresh(job, sid, key, new_key); } -error_t threshold_dkg(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, - const party_set_t& quorum_party_set, key_t& key) { - return eckey::key_share_mp_t::threshold_dkg(job, curve, sid, ac, quorum_party_set, key); +error_t dkg_ac(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, + const party_set_t& quorum_party_set, key_t& key) { + return eckey::key_share_mp_t::dkg_ac(job, curve, sid, ac, quorum_party_set, key); } -error_t threshold_refresh(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, - const party_set_t& quorum_party_set, key_t& key, key_t& new_key) { - return eckey::key_share_mp_t::threshold_refresh(job, curve, sid, ac, quorum_party_set, key, new_key); +error_t refresh_ac(job_mp_t& job, ecurve_t curve, buf_t& sid, const crypto::ss::ac_t ac, + const party_set_t& quorum_party_set, key_t& key, key_t& new_key) { + return eckey::key_share_mp_t::refresh_ac(job, curve, sid, ac, quorum_party_set, key, new_key); } static bn_t calc_eddsa_HRAM(const ecc_point_t& R, const ecc_point_t& Q, mem_t in) { @@ -133,6 +131,8 @@ error_t sign_batch(job_mp_t& job, key_t& key, const std::vector& msgs, pa return coinbase::error(E_BADARG, "BIP340 variant requires secp256k1 curve"); bn_t rx, ry; for (size_t l = 0; l < msgs.size(); l++) { + if (msgs[l].size != 32) return coinbase::error(E_BADARG, "schnorr_mp: BIP340 msg size != 32"); + if (!msgs[l].data) return coinbase::error(E_BADARG, "schnorr_mp: BIP340 msg is null"); R[l].get_coordinates(rx, ry); if (ry.is_odd()) ki[l] = q - ki[l]; e[l] = crypto::bip340::hash_message(rx, key.Q, msgs[l]); diff --git a/src/cbmpc/zk/fischlin.cpp b/src/cbmpc/zk/fischlin.cpp index 35e37a83..1a201b66 100644 --- a/src/cbmpc/zk/fischlin.cpp +++ b/src/cbmpc/zk/fischlin.cpp @@ -1,6 +1,5 @@ -#include "fischlin.h" - -#include +#include +#include namespace coinbase::zk { diff --git a/src/cbmpc/zk/small_primes.cpp b/src/cbmpc/zk/small_primes.cpp index 0d7e725a..caa8fb97 100755 --- a/src/cbmpc/zk/small_primes.cpp +++ b/src/cbmpc/zk/small_primes.cpp @@ -1,6 +1,8 @@ -#include "small_primes.h" +#include // NOLINTBEGIN(*avoid-magic-numbers*) +namespace coinbase::zk { + const unsigned small_primes[small_primes_count] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, @@ -719,3 +721,5 @@ const unsigned small_primes[small_primes_count] = { 104711, 104717, 104723, 104729, }; // NOLINTEND(*avoid-magic-numbers*) + +} // namespace coinbase::zk diff --git a/src/cbmpc/zk/zk_ec.cpp b/src/cbmpc/zk/zk_ec.cpp index 3347e2cf..5f3b4f81 100644 --- a/src/cbmpc/zk/zk_ec.cpp +++ b/src/cbmpc/zk/zk_ec.cpp @@ -1,8 +1,7 @@ -#include "zk_ec.h" - -#include -#include -#include +#include +#include +#include +#include namespace coinbase::zk { @@ -78,6 +77,8 @@ error_t uc_dl_t::verify(const ecc_point_t& Q, mem_t session_id, uint64_t aux) co ecc_point_t A_sum = curve.infinity(); for (int i = 0; i < rho; i++) { + if (rv = crypto::check_right_open_range(0, z[i], q)) return rv; + bn_t sigma = bn_t::rand_bitlen(SEC_P_STAT); MODULO(q) { z_sum += sigma * z[i]; @@ -247,6 +248,8 @@ error_t uc_batch_dl_finite_difference_impl_t::verify(const std::vector -#include +#include +#include +#include namespace coinbase::zk { @@ -91,6 +90,9 @@ error_t uc_elgamal_com_t::verify(const ecc_point_t& Q, const elg_com_t& UV, mem_ ecc_point_t B_sum = curve.infinity(); for (int i = 0; i < rho; i++) { + if (rv = crypto::check_right_open_range(0, z1[i], q)) return rv; + if (rv = crypto::check_right_open_range(0, z2[i], q)) return rv; + bn_t sigma = bn_t::rand_bitlen(SEC_P_STAT); MODULO(q) { z1_sum += sigma * z1[i]; diff --git a/src/cbmpc/zk/zk_paillier.cpp b/src/cbmpc/zk/zk_paillier.cpp index 8dddb957..45a97c1a 100644 --- a/src/cbmpc/zk/zk_paillier.cpp +++ b/src/cbmpc/zk/zk_paillier.cpp @@ -1,8 +1,6 @@ -#include "zk_paillier.h" - -#include - -#include "small_primes.h" +#include +#include +#include namespace coinbase::zk { @@ -13,7 +11,7 @@ void valid_paillier_t::prove(const crypto::paillier_t& paillier, mem_t session_i bn_t N_inv = mod_t::N_inv_mod_phiN_2048(N, phi_N); - assert(SEC_P_COM == 128 && "security parameter changed, please update the code"); + static_assert(SEC_P_COM == 128, "security parameter changed, please update the code"); buf128_t k = crypto::ro::hash_string(N, session_id, aux).bitlen128(); crypto::drbg_aes_ctr_t drbg(k); @@ -506,7 +504,7 @@ error_t pdl_t::verify(const bn_t& c_key, const crypto::paillier_t& paillier, con } bn_t qq = q * q; - if (N.get_bits_count() < 2048 || N < ((qq << (SEC_P_STAT + 1)) + (qq << 1))) return coinbase::error(E_CRYPTO); + if (N.get_bits_count() < 2048 || N <= ((qq << (SEC_P_STAT + 1)) + (qq << 1))) return coinbase::error(E_CRYPTO); const mod_t& NN = paillier.get_NN(); if (rv = coinbase::crypto::check_open_range(0, c_r, NN)) return rv; diff --git a/src/cbmpc/zk/zk_pedersen.cpp b/src/cbmpc/zk/zk_pedersen.cpp index a2580b6b..57623ceb 100644 --- a/src/cbmpc/zk/zk_pedersen.cpp +++ b/src/cbmpc/zk/zk_pedersen.cpp @@ -1,6 +1,5 @@ -#include "zk_pedersen.h" - -#include "small_primes.h" +#include +#include namespace coinbase::zk { @@ -26,9 +25,9 @@ pedersen_commitment_params_t::pedersen_commitment_params_t() { }; p = mod_t(bn_t::from_bin(mem_t(PED_P_BIN, sizeof(PED_P_BIN))), /* multiplicative_dense */ true); - assert(bn_t(p).prime()); + cb_assert(bn_t(p).prime()); p_tag = mod_t((bn_t(p) - 1) / 2, /* multiplicative_dense */ true); - assert(bn_t(p_tag).prime()); + cb_assert(bn_t(p_tag).prime()); sqrt_g = 2; g = 4; diff --git a/src/cbmpc/zk/zk_unknown_order.cpp b/src/cbmpc/zk/zk_unknown_order.cpp index c908cd5b..1c5e8417 100644 --- a/src/cbmpc/zk/zk_unknown_order.cpp +++ b/src/cbmpc/zk/zk_unknown_order.cpp @@ -1,7 +1,6 @@ -#include "zk_unknown_order.h" - -#include -#include +#include +#include +#include namespace coinbase::zk { void unknown_order_dl_t::prove(const bn_t& a, const bn_t& b, const mod_t& N, const int l, const bn_t& w, mem_t sid, diff --git a/tests/dudect/dudect_util/dudect.h b/tests/dudect/dudect_util/dudect.h index c5d979e0..00d994e6 100644 --- a/tests/dudect/dudect_util/dudect.h +++ b/tests/dudect/dudect_util/dudect.h @@ -116,23 +116,23 @@ typedef struct { } ttest_ctx_t; typedef struct { - int64_t *ticks; - int64_t *exec_times; - uint8_t *input_data; - uint8_t *classes; - dudect_config_t *config; - ttest_ctx_t *ttest_ctxs[DUDECT_TESTS]; - int64_t *percentiles; + int64_t* ticks; + int64_t* exec_times; + uint8_t* input_data; + uint8_t* classes; + dudect_config_t* config; + ttest_ctx_t* ttest_ctxs[DUDECT_TESTS]; + int64_t* percentiles; } dudect_ctx_t; typedef enum { DUDECT_LEAKAGE_FOUND = 0, DUDECT_NO_LEAKAGE_EVIDENCE_YET } dudect_state_t; /* Public API */ -DUDECT_VISIBILITY inline int dudect_init(dudect_ctx_t *ctx, dudect_config_t *conf); -DUDECT_VISIBILITY inline dudect_state_t dudect_main(dudect_ctx_t *c); -DUDECT_VISIBILITY inline int dudect_free(dudect_ctx_t *ctx); -DUDECT_VISIBILITY inline void randombytes(uint8_t *x, size_t how_much); +DUDECT_VISIBILITY inline int dudect_init(dudect_ctx_t* ctx, dudect_config_t* conf); +DUDECT_VISIBILITY inline dudect_state_t dudect_main(dudect_ctx_t* c); +DUDECT_VISIBILITY inline int dudect_free(dudect_ctx_t* ctx); +DUDECT_VISIBILITY inline void randombytes(uint8_t* x, size_t how_much); DUDECT_VISIBILITY inline uint8_t randombit(void); /* Public configuration */ @@ -143,8 +143,8 @@ DUDECT_VISIBILITY inline uint8_t randombit(void); #include // kill this -extern void prepare_inputs(dudect_config_t *c, uint8_t *input_data, uint8_t *classes); -extern uint8_t do_one_computation(uint8_t *data); +extern void prepare_inputs(dudect_config_t* c, uint8_t* input_data, uint8_t* classes); +extern uint8_t do_one_computation(uint8_t* data); #endif /* DUDECT_H_INCLUDED */ @@ -176,7 +176,7 @@ extern uint8_t do_one_computation(uint8_t *data); see https://en.wikipedia.org/wiki/Welch%27s_t-test */ -static void t_push(ttest_ctx_t *ctx, double x, uint8_t clazz) { +static void t_push(ttest_ctx_t* ctx, double x, uint8_t clazz) { assert(clazz == 0 || clazz == 1); ctx->n[clazz]++; /* @@ -188,7 +188,7 @@ static void t_push(ttest_ctx_t *ctx, double x, uint8_t clazz) { ctx->m2[clazz] = ctx->m2[clazz] + delta * (x - ctx->mean[clazz]); } -static double t_compute(ttest_ctx_t *ctx) { +static double t_compute(ttest_ctx_t* ctx) { double var[2] = {0.0, 0.0}; var[0] = ctx->m2[0] / (ctx->n[0] - 1); var[1] = ctx->m2[1] / (ctx->n[1] - 1); @@ -198,7 +198,7 @@ static double t_compute(ttest_ctx_t *ctx) { return t_value; } -static void t_init(ttest_ctx_t *ctx) { +static void t_init(ttest_ctx_t* ctx) { for (int clazz = 0; clazz < 2; clazz++) { ctx->mean[clazz] = 0.0; ctx->m2[clazz] = 0.0; @@ -206,9 +206,9 @@ static void t_init(ttest_ctx_t *ctx) { } } -static int cmp(const int64_t *a, const int64_t *b) { return (int)(*a - *b); } +static int cmp(const int64_t* a, const int64_t* b) { return (int)(*a - *b); } -static int64_t percentile(int64_t *a_sorted, double which, size_t size) { +static int64_t percentile(int64_t* a_sorted, double which, size_t size) { size_t array_position = (size_t)((double)size * (double)which); assert(array_position < size); return a_sorted[array_position]; @@ -220,8 +220,8 @@ static int64_t percentile(int64_t *a_sorted, double which, size_t size) { the measurements distribution, but there's not more science than that. */ -static void prepare_percentiles(dudect_ctx_t *ctx) { - qsort(ctx->exec_times, ctx->config->number_measurements, sizeof(int64_t), (int (*)(const void *, const void *))cmp); +static void prepare_percentiles(dudect_ctx_t* ctx) { + qsort(ctx->exec_times, ctx->config->number_measurements, sizeof(int64_t), (int (*)(const void*, const void*))cmp); for (size_t i = 0; i < DUDECT_NUMBER_PERCENTILES; i++) { ctx->percentiles[i] = percentile(ctx->exec_times, 1 - (pow(0.5, 10 * (double)(i + 1) / DUDECT_NUMBER_PERCENTILES)), ctx->config->number_measurements); @@ -229,7 +229,7 @@ static void prepare_percentiles(dudect_ctx_t *ctx) { } /* this comes from ebacs */ -void randombytes(uint8_t *x, size_t how_much) { +void randombytes(uint8_t* x, size_t how_much) { ssize_t i; static int fd = -1; @@ -334,7 +334,7 @@ static inline int64_t cpucycles(void) { #define t_threshold_bananas 500 // test failed, with overwhelming probability #define t_threshold_moderate 10 // test failed. Pankaj likes 4.5 but let's be more lenient -static void measure(dudect_ctx_t *ctx) { +static void measure(dudect_ctx_t* ctx) { for (size_t i = 0; i < ctx->config->number_measurements; i++) { ctx->ticks[i] = cpucycles(); do_one_computation(ctx->input_data + i * ctx->config->chunk_size); @@ -345,7 +345,7 @@ static void measure(dudect_ctx_t *ctx) { } } -static void update_statistics(dudect_ctx_t *ctx) { +static void update_statistics(dudect_ctx_t* ctx) { for (size_t i = 10 /* discard the first few measurements */; i < (ctx->config->number_measurements - 1); i++) { int64_t difference = ctx->exec_times[i]; @@ -373,7 +373,7 @@ static void update_statistics(dudect_ctx_t *ctx) { } #if DUDECT_TRACE -static void report_test(ttest_ctx_t *x) { +static void report_test(ttest_ctx_t* x) { if (x->n[0] > DUDECT_ENOUGH_MEASUREMENTS) { double tval = t_compute(x); printf(" abs(t): %4.2f, number measurements: %f\n", tval, x->n[0] + x->n[1]); @@ -383,7 +383,7 @@ static void report_test(ttest_ctx_t *x) { } #endif /* DUDECT_TRACE */ -static ttest_ctx_t *max_test(dudect_ctx_t *ctx) { +static ttest_ctx_t* max_test(dudect_ctx_t* ctx) { size_t ret = 0; double max = 0; for (size_t i = 0; i < DUDECT_TESTS; i++) { @@ -398,7 +398,7 @@ static ttest_ctx_t *max_test(dudect_ctx_t *ctx) { return ctx->ttest_ctxs[ret]; } -static dudect_state_t report(dudect_ctx_t *ctx) { +static dudect_state_t report(dudect_ctx_t* ctx) { #if DUDECT_TRACE for (size_t i = 0; i < DUDECT_TESTS; i++) { @@ -417,7 +417,7 @@ static dudect_state_t report(dudect_ctx_t *ctx) { report_test(ctx->ttest_ctxs[1 + DUDECT_NUMBER_PERCENTILES]); #endif /* DUDECT_TRACE */ - ttest_ctx_t *t = max_test(ctx); + ttest_ctx_t* t = max_test(ctx); double max_t = fabs(t_compute(t)); double number_traces_max_t = t->n[0] + t->n[1]; double max_tau = max_t / sqrt(number_traces_max_t); @@ -464,7 +464,7 @@ static dudect_state_t report(dudect_ctx_t *ctx) { return DUDECT_NO_LEAKAGE_EVIDENCE_YET; } -dudect_state_t dudect_main(dudect_ctx_t *ctx) { +dudect_state_t dudect_main(dudect_ctx_t* ctx) { prepare_inputs(ctx->config, ctx->input_data, ctx->classes); measure(ctx); @@ -483,22 +483,22 @@ dudect_state_t dudect_main(dudect_ctx_t *ctx) { return ret; } -int dudect_init(dudect_ctx_t *ctx, dudect_config_t *conf) { - ctx->config = (dudect_config_t *)calloc(1, sizeof(*conf)); +int dudect_init(dudect_ctx_t* ctx, dudect_config_t* conf) { + ctx->config = (dudect_config_t*)calloc(1, sizeof(*conf)); ctx->config->number_measurements = conf->number_measurements; ctx->config->chunk_size = conf->chunk_size; - ctx->ticks = (int64_t *)calloc(ctx->config->number_measurements, sizeof(int64_t)); - ctx->exec_times = (int64_t *)calloc(ctx->config->number_measurements, sizeof(int64_t)); - ctx->classes = (uint8_t *)calloc(ctx->config->number_measurements, sizeof(uint8_t)); - ctx->input_data = (uint8_t *)calloc(ctx->config->number_measurements * ctx->config->chunk_size, sizeof(uint8_t)); + ctx->ticks = (int64_t*)calloc(ctx->config->number_measurements, sizeof(int64_t)); + ctx->exec_times = (int64_t*)calloc(ctx->config->number_measurements, sizeof(int64_t)); + ctx->classes = (uint8_t*)calloc(ctx->config->number_measurements, sizeof(uint8_t)); + ctx->input_data = (uint8_t*)calloc(ctx->config->number_measurements * ctx->config->chunk_size, sizeof(uint8_t)); for (int i = 0; i < DUDECT_TESTS; i++) { - ctx->ttest_ctxs[i] = (ttest_ctx_t *)calloc(1, sizeof(ttest_ctx_t)); + ctx->ttest_ctxs[i] = (ttest_ctx_t*)calloc(1, sizeof(ttest_ctx_t)); assert(ctx->ttest_ctxs[i]); t_init(ctx->ttest_ctxs[i]); } - ctx->percentiles = (int64_t *)calloc(DUDECT_NUMBER_PERCENTILES, sizeof(int64_t)); + ctx->percentiles = (int64_t*)calloc(DUDECT_NUMBER_PERCENTILES, sizeof(int64_t)); assert(ctx->ticks); assert(ctx->exec_times); @@ -509,7 +509,7 @@ int dudect_init(dudect_ctx_t *ctx, dudect_config_t *conf) { return 0; } -int dudect_free(dudect_ctx_t *ctx) { +int dudect_free(dudect_ctx_t* ctx) { for (int i = 0; i < DUDECT_TESTS; i++) { free(ctx->ttest_ctxs[i]); } diff --git a/tests/dudect/dudect_util/dudect_implementation.h b/tests/dudect/dudect_util/dudect_implementation.h index eb66eb18..a1f196d5 100644 --- a/tests/dudect/dudect_util/dudect_implementation.h +++ b/tests/dudect/dudect_util/dudect_implementation.h @@ -1,4 +1,4 @@ -#include +#include namespace coinbase::dudect { diff --git a/tests/public_headers_smoke.cpp b/tests/public_headers_smoke.cpp new file mode 100644 index 00000000..d71a17b4 --- /dev/null +++ b/tests/public_headers_smoke.cpp @@ -0,0 +1,44 @@ +// Compile-only smoke test: public headers must not depend on internal headers. +// +// This TU is built with include paths limited to: +// - /include +// - OpenSSL include dir +// +// If any public header includes , this will fail to compile. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/tests/unit/api/test_ecdsa2pc.cpp b/tests/unit/api/test_ecdsa2pc.cpp new file mode 100644 index 00000000..d50a6355 --- /dev/null +++ b/tests/unit/api/test_ecdsa2pc.cpp @@ -0,0 +1,827 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "utils/local_network/network_context.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; + +using coinbase::api::curve_id; +using coinbase::api::data_transport_i; +using coinbase::api::party_idx_t; + +using coinbase::api::ecdsa_2p::party_t; + +using coinbase::testutils::mpc_net_context_t; + +// Key blob codec (mirrors src/cbmpc/api/ecdsa2pc.cpp). +struct key_blob_v1_t { + uint32_t version = 1; + uint32_t role = 0; // 0=p1, 1=p2 + uint32_t curve = 0; // coinbase::api::curve_id + + buf_t Q_compressed; + coinbase::crypto::bn_t x_share; + coinbase::crypto::bn_t c_key; + coinbase::crypto::paillier_t paillier; + + void convert(coinbase::converter_t& c) { c.convert(version, role, curve, Q_compressed, x_share, c_key, paillier); } +}; + +class local_api_transport_t final : public data_transport_i { + public: + explicit local_api_transport_t(std::shared_ptr ctx) : ctx_(std::move(ctx)) {} + + error_t send(party_idx_t receiver, mem_t msg) override { + ctx_->send(receiver, msg); + return SUCCESS; + } + + error_t receive(party_idx_t sender, buf_t& msg) override { return ctx_->receive(sender, msg); } + + error_t receive_all(const std::vector& senders, std::vector& msgs) override { + std::vector s; + s.reserve(senders.size()); + for (auto x : senders) s.push_back(static_cast(x)); + return ctx_->receive_all(s, msgs); + } + + private: + std::shared_ptr ctx_; +}; + +template +static void run_2pc(const std::shared_ptr& c1, const std::shared_ptr& c2, F1&& f1, + F2&& f2, error_t& out_rv1, error_t& out_rv2) { + c1->reset(); + c2->reset(); + + std::atomic aborted{false}; + + std::thread t1([&] { + out_rv1 = f1(); + if (out_rv1 && !aborted.exchange(true)) { + c1->abort(); + c2->abort(); + } + }); + std::thread t2([&] { + out_rv2 = f2(); + if (out_rv2 && !aborted.exchange(true)) { + c1->abort(); + c2->abort(); + } + }); + + t1.join(); + t2.join(); +} + +static void exercise_curve(curve_id curve, const coinbase::crypto::ecurve_t& verify_curve) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + buf_t key_blob_1; + buf_t key_blob_2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + run_2pc( + c1, c2, [&] { return coinbase::api::ecdsa_2p::dkg(job1, curve, key_blob_1); }, + [&] { return coinbase::api::ecdsa_2p::dkg(job2, curve, key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + buf_t pub1; + buf_t pub2; + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_key_compressed(key_blob_1, pub1), SUCCESS); + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_key_compressed(key_blob_2, pub2), SUCCESS); + ASSERT_EQ(pub1, pub2); + + // Deterministic 32-byte "hash" for testing. + buf_t msg_hash(32); + for (int i = 0; i < msg_hash.size(); i++) msg_hash[i] = static_cast(i); + + buf_t sid1; + buf_t sid2; + buf_t sig1; + buf_t sig2; + + run_2pc( + c1, c2, [&] { return coinbase::api::ecdsa_2p::sign(job1, key_blob_1, msg_hash, sid1, sig1); }, + [&] { return coinbase::api::ecdsa_2p::sign(job2, key_blob_2, msg_hash, sid2, sig2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + EXPECT_EQ(sid1, sid2); + EXPECT_GT(sig1.size(), 0); + EXPECT_EQ(sig2.size(), 0); + + // Verify the returned signature against the extracted public key. + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(verify_curve, pub1), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + ASSERT_EQ(verify_key.verify(msg_hash, sig1), SUCCESS); + + // Refresh shares, ensuring public key stays the same. + buf_t new_key_blob_1; + buf_t new_key_blob_2; + run_2pc( + c1, c2, [&] { return coinbase::api::ecdsa_2p::refresh(job1, key_blob_1, new_key_blob_1); }, + [&] { return coinbase::api::ecdsa_2p::refresh(job2, key_blob_2, new_key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + buf_t new_pub1; + buf_t new_pub2; + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_key_compressed(new_key_blob_1, new_pub1), SUCCESS); + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_key_compressed(new_key_blob_2, new_pub2), SUCCESS); + EXPECT_EQ(new_pub1, pub1); + EXPECT_EQ(new_pub2, pub2); + EXPECT_EQ(new_pub1, new_pub2); + + // Sign again with refreshed shares. + buf_t sid3; + buf_t sid4; + buf_t sig3; + buf_t sig4; + run_2pc( + c1, c2, [&] { return coinbase::api::ecdsa_2p::sign(job1, new_key_blob_1, msg_hash, sid3, sig3); }, + [&] { return coinbase::api::ecdsa_2p::sign(job2, new_key_blob_2, msg_hash, sid4, sig4); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + EXPECT_EQ(sid3, sid4); + EXPECT_GT(sig3.size(), 0); + EXPECT_EQ(sig4.size(), 0); + ASSERT_EQ(verify_key.verify(msg_hash, sig3), SUCCESS); +} + +static void exercise_detach_attach(curve_id curve, const coinbase::crypto::ecurve_t& verify_curve) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + buf_t key_blob_1; + buf_t key_blob_2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + run_2pc( + c1, c2, [&] { return coinbase::api::ecdsa_2p::dkg(job1, curve, key_blob_1); }, + [&] { return coinbase::api::ecdsa_2p::dkg(job2, curve, key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + // Refresh (exercise detach/attach on refreshed blobs too). + buf_t refreshed_1; + buf_t refreshed_2; + run_2pc( + c1, c2, [&] { return coinbase::api::ecdsa_2p::refresh(job1, key_blob_1, refreshed_1); }, + [&] { return coinbase::api::ecdsa_2p::refresh(job2, key_blob_2, refreshed_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + buf_t pub1; + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_key_compressed(refreshed_1, pub1), SUCCESS); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(verify_curve, pub1), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + + // Detach into scalar-redacted blob + variable-length scalar. + buf_t public_1; + buf_t public_2; + buf_t x1; + buf_t x2; + ASSERT_EQ(coinbase::api::ecdsa_2p::detach_private_scalar(refreshed_1, public_1, x1), SUCCESS); + ASSERT_EQ(coinbase::api::ecdsa_2p::detach_private_scalar(refreshed_2, public_2, x2), SUCCESS); + EXPECT_GT(public_1.size(), 0); + EXPECT_GT(public_2.size(), 0); + EXPECT_GT(x1.size(), 0); + EXPECT_GT(x2.size(), 0); + + // Capture share points before detaching (public blobs no longer carry them). + buf_t Qi_full_1; + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_share_compressed(refreshed_1, Qi_full_1), SUCCESS); + + buf_t Qi_full_2; + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_share_compressed(refreshed_2, Qi_full_2), SUCCESS); + + // Public blob should not be usable for signing. + buf_t msg_hash(32); + for (int i = 0; i < msg_hash.size(); i++) msg_hash[i] = static_cast(i); + { + class unused_transport_t final : public data_transport_i { + public: + error_t send(party_idx_t /*receiver*/, mem_t /*msg*/) override { return E_GENERAL; } + error_t receive(party_idx_t /*sender*/, buf_t& /*msg*/) override { return E_GENERAL; } + error_t receive_all(const std::vector& /*senders*/, std::vector& /*msgs*/) override { + return E_GENERAL; + } + }; + unused_transport_t t; + const coinbase::api::job_2p_t bad_job{party_t::p1, "p1", "p2", t}; + buf_t sid; + buf_t sig; + EXPECT_NE(coinbase::api::ecdsa_2p::sign(bad_job, public_1, msg_hash, sid, sig), SUCCESS); + } + + // Attach scalars back and sign. + buf_t merged_1; + buf_t merged_2; + ASSERT_EQ(coinbase::api::ecdsa_2p::attach_private_scalar(public_1, x1, Qi_full_1, merged_1), SUCCESS); + ASSERT_EQ(coinbase::api::ecdsa_2p::attach_private_scalar(public_2, x2, Qi_full_2, merged_2), SUCCESS); + + // Leading-zero padded encoding should also be accepted (variable-length encoding). + buf_t x1_padded(static_cast(x1.size()) + 1); + x1_padded[0] = 0x00; + std::memcpy(x1_padded.data() + 1, x1.data(), static_cast(x1.size())); + buf_t merged_1_padded; + ASSERT_EQ(coinbase::api::ecdsa_2p::attach_private_scalar(public_1, x1_padded, Qi_full_1, merged_1_padded), SUCCESS); + + buf_t sid1; + buf_t sid2; + buf_t sig1; + buf_t sig2; + run_2pc( + c1, c2, [&] { return coinbase::api::ecdsa_2p::sign(job1, merged_1, msg_hash, sid1, sig1); }, + [&] { return coinbase::api::ecdsa_2p::sign(job2, merged_2, msg_hash, sid2, sig2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + EXPECT_EQ(sid1, sid2); + EXPECT_GT(sig1.size(), 0); + EXPECT_EQ(sig2.size(), 0); + ASSERT_EQ(verify_key.verify(msg_hash, sig1), SUCCESS); + + // Negative: wrong scalar should fail to attach. + buf_t bad_x1 = x1; + bad_x1[0] ^= 0x01; + buf_t bad_merged; + EXPECT_NE(coinbase::api::ecdsa_2p::attach_private_scalar(public_1, bad_x1, Qi_full_1, bad_merged), SUCCESS); +} + +} // namespace + +TEST(ApiEcdsa2pc, DkgSignRefreshSign) { + exercise_curve(curve_id::secp256k1, coinbase::crypto::curve_secp256k1); + exercise_curve(curve_id::p256, coinbase::crypto::curve_p256); +} + +TEST(ApiEcdsa2pc, KeyBlobPrivScalar_NoPubSign) { + exercise_detach_attach(curve_id::secp256k1, coinbase::crypto::curve_secp256k1); + exercise_detach_attach(curve_id::p256, coinbase::crypto::curve_p256); +} + +TEST(ApiEcdsa2pc, UnsupportedCurveRejected) { + class unused_transport_t final : public data_transport_i { + public: + error_t send(party_idx_t /*receiver*/, mem_t /*msg*/) override { return E_GENERAL; } + error_t receive(party_idx_t /*sender*/, buf_t& /*msg*/) override { return E_GENERAL; } + error_t receive_all(const std::vector& /*senders*/, std::vector& /*msgs*/) override { + return E_GENERAL; + } + }; + + unused_transport_t t; + buf_t key_blob; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + const error_t rv = coinbase::api::ecdsa_2p::dkg(job, static_cast(42), key_blob); + EXPECT_EQ(rv, E_BADARG); +} + +TEST(ApiEcdsa2pc, InvalidKeyBlobRejected) { + buf_t pub_key; + const error_t rv = coinbase::api::ecdsa_2p::get_public_key_compressed(mem_t(), pub_key); + EXPECT_NE(rv, SUCCESS); +} + +TEST(ApiEcdsa2pc, RejectTamperedP1ShareBinding) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + buf_t key_blob_1; + buf_t key_blob_2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + run_2pc( + c1, c2, [&] { return coinbase::api::ecdsa_2p::dkg(job1, curve_id::secp256k1, key_blob_1); }, + [&] { return coinbase::api::ecdsa_2p::dkg(job2, curve_id::secp256k1, key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + key_blob_v1_t blob; + ASSERT_EQ(coinbase::convert(blob, key_blob_1), SUCCESS); + + // Tamper with P1's share while keeping `c_key` unchanged; deserialization should reject it. + blob.x_share = blob.x_share + 1; + buf_t malformed = coinbase::convert(blob); + + buf_t pub_key; + EXPECT_EQ(coinbase::api::ecdsa_2p::get_public_key_compressed(malformed, pub_key), E_FORMAT); +} + +TEST(ApiEcdsa2pc, TransportSendFailNoDeadlock) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + class fail_first_send_transport_t final : public data_transport_i { + public: + explicit fail_first_send_transport_t(std::shared_ptr ctx) : ctx_(std::move(ctx)) {} + + error_t send(party_idx_t /*receiver*/, mem_t /*msg*/) override { + if (!failed_.exchange(true)) return E_NET_GENERAL; + return E_NET_GENERAL; + } + + error_t receive(party_idx_t sender, buf_t& msg) override { return ctx_->receive(sender, msg); } + + error_t receive_all(const std::vector& senders, std::vector& msgs) override { + std::vector s; + s.reserve(senders.size()); + for (auto x : senders) s.push_back(static_cast(x)); + return ctx_->receive_all(s, msgs); + } + + private: + std::shared_ptr ctx_; + std::atomic failed_{false}; + }; + + fail_first_send_transport_t t1(c1); + local_api_transport_t t2(c2); + + buf_t key_blob_1; + buf_t key_blob_2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + run_2pc( + c1, c2, [&] { return coinbase::api::ecdsa_2p::dkg(job1, curve_id::secp256k1, key_blob_1); }, + [&] { return coinbase::api::ecdsa_2p::dkg(job2, curve_id::secp256k1, key_blob_2); }, rv1, rv2); + + EXPECT_NE(rv1, SUCCESS); + EXPECT_NE(rv2, SUCCESS); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +// ========================================================================== +// Negative test helpers +// ========================================================================== + +namespace { + +class noop_transport_t final : public data_transport_i { + public: + error_t send(party_idx_t, mem_t) override { return E_GENERAL; } + error_t receive(party_idx_t, buf_t&) override { return E_GENERAL; } + error_t receive_all(const std::vector&, std::vector&) override { return E_GENERAL; } +}; + +static coinbase::api::job_2p_t make_noop_job(noop_transport_t& t, party_t self = party_t::p1) { + return {self, "p1", "p2", t}; +} + +static void generate_key_blobs(curve_id curve, buf_t& blob1, buf_t& blob2) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + run_2pc( + c1, c2, [&] { return coinbase::api::ecdsa_2p::dkg(job1, curve, blob1); }, + [&] { return coinbase::api::ecdsa_2p::dkg(job2, curve, blob2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); +} + +} // namespace + +class ApiEcdsa2pcNegWithBlobs : public ::testing::Test { + protected: + static void SetUpTestSuite() { generate_key_blobs(curve_id::secp256k1, blob1_, blob2_); } + + static buf_t blob1_; + static buf_t blob2_; +}; + +buf_t ApiEcdsa2pcNegWithBlobs::blob1_; +buf_t ApiEcdsa2pcNegWithBlobs::blob2_; + +// ========================================================================== +// Negative: dkg +// ========================================================================== + +TEST(ApiEcdsa2pc, NegDkgEdwardsCurve) { + noop_transport_t t; + auto job = make_noop_job(t); + buf_t key_blob; + EXPECT_NE(coinbase::api::ecdsa_2p::dkg(job, curve_id::ed25519, key_blob), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegDkgInvalidCurveValues) { + noop_transport_t t; + auto job = make_noop_job(t); + for (uint32_t val : {0u, 4u, 100u, 255u}) { + buf_t key_blob; + EXPECT_NE(coinbase::api::ecdsa_2p::dkg(job, static_cast(val), key_blob), SUCCESS) + << "Expected failure for curve_id=" << val; + } +} + +// ========================================================================== +// Negative: get_public_key_compressed +// ========================================================================== + +TEST(ApiEcdsa2pc, NegGetPubKeyGarbageBlob) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x02, 0x03, 0x04}; + buf_t pub; + EXPECT_NE(coinbase::api::ecdsa_2p::get_public_key_compressed(mem_t(garbage, sizeof(garbage)), pub), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegGetPubKeyOversizedBlob) { + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t pub; + EXPECT_NE(coinbase::api::ecdsa_2p::get_public_key_compressed(big, pub), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegGetPubKeyAllZeroBlob) { + uint8_t zeros[64] = {}; + buf_t pub; + EXPECT_NE(coinbase::api::ecdsa_2p::get_public_key_compressed(mem_t(zeros, sizeof(zeros)), pub), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegGetPubKeyOneByte) { + uint8_t one = 0x00; + buf_t pub; + EXPECT_NE(coinbase::api::ecdsa_2p::get_public_key_compressed(mem_t(&one, 1), pub), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegKeyBlobVersionZero) { + buf_t tampered(blob1_.size()); + std::memcpy(tampered.data(), blob1_.data(), static_cast(blob1_.size())); + tampered[0] = 0x00; + tampered[1] = 0x00; + tampered[2] = 0x00; + tampered[3] = 0x00; + buf_t pub; + EXPECT_NE(coinbase::api::ecdsa_2p::get_public_key_compressed(tampered, pub), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegKeyBlobWrongVersion) { + buf_t tampered(blob1_.size()); + std::memcpy(tampered.data(), blob1_.data(), static_cast(blob1_.size())); + tampered[0] = 0x00; + tampered[1] = 0x00; + tampered[2] = 0x00; + tampered[3] = 0x02; + buf_t pub; + EXPECT_NE(coinbase::api::ecdsa_2p::get_public_key_compressed(tampered, pub), SUCCESS); +} + +// ========================================================================== +// Negative: get_public_share_compressed +// ========================================================================== + +TEST(ApiEcdsa2pc, NegGetPubShareAllZeroBlob) { + uint8_t zeros[64] = {}; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::get_public_share_compressed(mem_t(zeros, sizeof(zeros)), out), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegGetPubShareEmptyBlob) { + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::get_public_share_compressed(mem_t(), out), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegGetPubShareGarbageBlob) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x02, 0x03, 0x04}; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::get_public_share_compressed(mem_t(garbage, sizeof(garbage)), out), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegGetPubShareOversizedBlob) { + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::get_public_share_compressed(big, out), SUCCESS); +} + +// ========================================================================== +// Negative: detach_private_scalar +// ========================================================================== + +TEST(ApiEcdsa2pc, NegDetachAllZeroBlob) { + uint8_t zeros[64] = {}; + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::ecdsa_2p::detach_private_scalar(mem_t(zeros, sizeof(zeros)), pub_blob, scalar), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegDetachEmptyBlob) { + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::ecdsa_2p::detach_private_scalar(mem_t(), pub_blob, scalar), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegDetachGarbageBlob) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x02, 0x03, 0x04}; + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::ecdsa_2p::detach_private_scalar(mem_t(garbage, sizeof(garbage)), pub_blob, scalar), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegDetachOversizedBlob) { + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::ecdsa_2p::detach_private_scalar(big, pub_blob, scalar), SUCCESS); +} + +// ========================================================================== +// Negative: attach_private_scalar +// ========================================================================== + +TEST(ApiEcdsa2pc, NegAttachEmptyPublicKeyBlob) { + uint8_t scalar[] = {0x01}; + uint8_t point[33] = {}; + point[0] = 0x02; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::attach_private_scalar(mem_t(), mem_t(scalar, 1), mem_t(point, 33), out), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegAttachGarbagePublicKeyBlob) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + uint8_t scalar[] = {0x01}; + uint8_t point[33] = {}; + point[0] = 0x02; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::attach_private_scalar(mem_t(garbage, sizeof(garbage)), mem_t(scalar, 1), + mem_t(point, 33), out), + SUCCESS); +} + +TEST(ApiEcdsa2pc, NegAttachOversizedPublicKeyBlob) { + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + uint8_t scalar[] = {0x01}; + uint8_t point[33] = {}; + point[0] = 0x02; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::attach_private_scalar(big, mem_t(scalar, 1), mem_t(point, 33), out), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegAttachEmptyPrivateScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_2p::detach_private_scalar(blob1_, pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_share_compressed(blob1_, Qi), SUCCESS); + + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::attach_private_scalar(pub_blob, mem_t(), Qi, out), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegAttachGarbagePrivateScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_2p::detach_private_scalar(blob1_, pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_share_compressed(blob1_, Qi), SUCCESS); + + uint8_t garbage[512]; + std::memset(garbage, 0xFF, sizeof(garbage)); + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::attach_private_scalar(pub_blob, mem_t(garbage, sizeof(garbage)), Qi, out), + SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegAttachGarbagePublicShare) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_2p::detach_private_scalar(blob1_, pub_blob, x), SUCCESS); + + uint8_t bad_point[33]; + bad_point[0] = 0x05; + std::memset(bad_point + 1, 0xAB, 32); + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::attach_private_scalar(pub_blob, x, mem_t(bad_point, 33), out), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegAttachEmptyPublicShare) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_2p::detach_private_scalar(blob1_, pub_blob, x), SUCCESS); + + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::attach_private_scalar(pub_blob, x, mem_t(), out), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegAttachSwappedScalars) { + buf_t pub1, x1, pub2, x2; + ASSERT_EQ(coinbase::api::ecdsa_2p::detach_private_scalar(blob1_, pub1, x1), SUCCESS); + ASSERT_EQ(coinbase::api::ecdsa_2p::detach_private_scalar(blob2_, pub2, x2), SUCCESS); + + buf_t Qi1; + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_share_compressed(blob1_, Qi1), SUCCESS); + + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::attach_private_scalar(pub1, x2, Qi1, out), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegAttachSwappedPublicShares) { + buf_t pub1, x1; + ASSERT_EQ(coinbase::api::ecdsa_2p::detach_private_scalar(blob1_, pub1, x1), SUCCESS); + + buf_t Qi2; + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_share_compressed(blob2_, Qi2), SUCCESS); + + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::attach_private_scalar(pub1, x1, Qi2, out), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegAttachZeroScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_2p::detach_private_scalar(blob1_, pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_share_compressed(blob1_, Qi), SUCCESS); + + uint8_t zero[32] = {}; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::attach_private_scalar(pub_blob, mem_t(zero, 32), Qi, out), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegAttachSingleByteZeroScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_2p::detach_private_scalar(blob1_, pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_share_compressed(blob1_, Qi), SUCCESS); + + uint8_t zero_byte = 0x00; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::attach_private_scalar(pub_blob, mem_t(&zero_byte, 1), Qi, out), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegAttachAllZeroPublicShare) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_2p::detach_private_scalar(blob1_, pub_blob, x), SUCCESS); + + uint8_t zero_point[33] = {}; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_2p::attach_private_scalar(pub_blob, x, mem_t(zero_point, 33), out), SUCCESS); +} + +// ========================================================================== +// Negative: sign (input validation, pre-protocol) +// ========================================================================== + +TEST(ApiEcdsa2pc, NegSignEmptyKeyBlob) { + noop_transport_t t; + auto job = make_noop_job(t); + buf_t msg_hash(32); + buf_t sid, sig; + EXPECT_NE(coinbase::api::ecdsa_2p::sign(job, mem_t(), msg_hash, sid, sig), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegSignGarbageKeyBlob) { + noop_transport_t t; + auto job = make_noop_job(t); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg_hash(32); + buf_t sid, sig; + EXPECT_NE(coinbase::api::ecdsa_2p::sign(job, mem_t(garbage, sizeof(garbage)), msg_hash, sid, sig), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegSignOversizedKeyBlob) { + noop_transport_t t; + auto job = make_noop_job(t); + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t msg_hash(32); + buf_t sid, sig; + EXPECT_NE(coinbase::api::ecdsa_2p::sign(job, big, msg_hash, sid, sig), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegSignAllZeroKeyBlob) { + noop_transport_t t; + auto job = make_noop_job(t); + uint8_t zeros[64] = {}; + buf_t msg_hash(32); + buf_t sid, sig; + EXPECT_NE(coinbase::api::ecdsa_2p::sign(job, mem_t(zeros, sizeof(zeros)), msg_hash, sid, sig), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegSignEmptyMsgHash) { + noop_transport_t t; + auto job = make_noop_job(t); + buf_t sid, sig; + EXPECT_NE(coinbase::api::ecdsa_2p::sign(job, blob1_, mem_t(), sid, sig), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegSignOversizedMsgHash) { + noop_transport_t t; + auto job = make_noop_job(t); + buf_t huge_hash(65); + std::memset(huge_hash.data(), 0x42, static_cast(huge_hash.size())); + buf_t sid, sig; + EXPECT_NE(coinbase::api::ecdsa_2p::sign(job, blob1_, huge_hash, sid, sig), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegSignRoleMismatch) { + noop_transport_t t; + auto job = make_noop_job(t, party_t::p2); + buf_t msg_hash(32); + buf_t sid, sig; + EXPECT_NE(coinbase::api::ecdsa_2p::sign(job, blob1_, msg_hash, sid, sig), SUCCESS); +} + +// ========================================================================== +// Negative: refresh (input validation, pre-protocol) +// ========================================================================== + +TEST(ApiEcdsa2pc, NegRefreshEmptyKeyBlob) { + noop_transport_t t; + auto job = make_noop_job(t); + buf_t new_blob; + EXPECT_NE(coinbase::api::ecdsa_2p::refresh(job, mem_t(), new_blob), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegRefreshGarbageKeyBlob) { + noop_transport_t t; + auto job = make_noop_job(t); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t new_blob; + EXPECT_NE(coinbase::api::ecdsa_2p::refresh(job, mem_t(garbage, sizeof(garbage)), new_blob), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegRefreshOversizedKeyBlob) { + noop_transport_t t; + auto job = make_noop_job(t); + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t new_blob; + EXPECT_NE(coinbase::api::ecdsa_2p::refresh(job, big, new_blob), SUCCESS); +} + +TEST(ApiEcdsa2pc, NegRefreshAllZeroKeyBlob) { + noop_transport_t t; + auto job = make_noop_job(t); + uint8_t zeros[64] = {}; + buf_t new_blob; + EXPECT_NE(coinbase::api::ecdsa_2p::refresh(job, mem_t(zeros, sizeof(zeros)), new_blob), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegRefreshRoleMismatch) { + noop_transport_t t; + auto job = make_noop_job(t, party_t::p2); + buf_t new_blob; + EXPECT_NE(coinbase::api::ecdsa_2p::refresh(job, blob1_, new_blob), SUCCESS); +} diff --git a/tests/unit/api/test_ecdsa_mp.cpp b/tests/unit/api/test_ecdsa_mp.cpp new file mode 100644 index 00000000..c210bf14 --- /dev/null +++ b/tests/unit/api/test_ecdsa_mp.cpp @@ -0,0 +1,673 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "utils/local_network/network_context.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; + +using coinbase::api::curve_id; +using coinbase::api::data_transport_i; +using coinbase::api::job_mp_t; +using coinbase::api::party_idx_t; + +using coinbase::testutils::mpc_net_context_t; + +class local_api_transport_t final : public data_transport_i { + public: + explicit local_api_transport_t(std::shared_ptr ctx) : ctx_(std::move(ctx)) {} + + error_t send(party_idx_t receiver, mem_t msg) override { + ctx_->send(receiver, msg); + return SUCCESS; + } + + error_t receive(party_idx_t sender, buf_t& msg) override { return ctx_->receive(sender, msg); } + + error_t receive_all(const std::vector& senders, std::vector& msgs) override { + std::vector s; + s.reserve(senders.size()); + for (auto x : senders) s.push_back(static_cast(x)); + return ctx_->receive_all(s, msgs); + } + + private: + std::shared_ptr ctx_; +}; + +template +static void run_mp(const std::vector>& peers, F&& f, std::vector& out_rv) { + for (const auto& p : peers) p->reset(); + + out_rv.assign(peers.size(), UNINITIALIZED_ERROR); + std::atomic aborted{false}; + std::vector threads; + threads.reserve(peers.size()); + + for (size_t i = 0; i < peers.size(); i++) { + threads.emplace_back([&, i] { + out_rv[i] = f(static_cast(i)); + if (out_rv[i] && !aborted.exchange(true)) { + for (const auto& p : peers) p->abort(); + } + }); + } + for (auto& t : threads) t.join(); +} + +static void exercise_4p() { + constexpr int n = 4; + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector> transports; + transports.reserve(n); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + + std::vector names = {"p0", "p1", "p2", "p3"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + std::vector keys(n); + std::vector new_keys(n); + std::vector sids(n); + std::vector sigs(n); + std::vector new_sigs(n); + std::vector rvs; + + buf_t msg_hash(32); + for (int i = 0; i < msg_hash.size(); i++) msg_hash[i] = static_cast(i); + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::dkg_additive(job, curve_id::secp256k1, keys[static_cast(i)], + sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) EXPECT_EQ(sids[0], sids[static_cast(i)]); + + buf_t pub0; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_key_compressed(keys[0], pub0), SUCCESS); + EXPECT_EQ(pub0.size(), 33); + for (int i = 1; i < n; i++) { + buf_t pub_i; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_key_compressed(keys[static_cast(i)], pub_i), SUCCESS); + EXPECT_EQ(pub_i, pub0); + } + + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub0), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + + // Change the party ordering ("role" indices) between protocols. + // Example: a party that was at index 1 ("p1") moves to index 2. + const std::vector name_views2 = {names[0], names[2], names[1], names[3]}; + // Map new role index -> old role index (DKG) for the same party name. + const int perm[n] = {0, 2, 1, 3}; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views2, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::sign_additive(job, keys[static_cast(perm[i])], msg_hash, + /*sig_receiver=*/2, sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + EXPECT_GT(sigs[2].size(), 0); + for (int i = 0; i < n; i++) { + if (i == 2) continue; + EXPECT_EQ(sigs[static_cast(i)].size(), 0); + } + ASSERT_EQ(verify_key.verify(msg_hash, sigs[2]), SUCCESS); + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views2, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::refresh_additive(job, sids[static_cast(perm[i])], + keys[static_cast(perm[i])], + new_keys[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) EXPECT_EQ(sids[0], sids[static_cast(i)]); + + for (int i = 0; i < n; i++) { + buf_t pub_i; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_key_compressed(new_keys[static_cast(i)], pub_i), SUCCESS); + EXPECT_EQ(pub_i, pub0); + } + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views2, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::sign_additive(job, new_keys[static_cast(i)], msg_hash, + /*sig_receiver=*/2, new_sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + EXPECT_GT(new_sigs[2].size(), 0); + for (int i = 0; i < n; i++) { + if (i == 2) continue; + EXPECT_EQ(new_sigs[static_cast(i)].size(), 0); + } + ASSERT_EQ(verify_key.verify(msg_hash, new_sigs[2]), SUCCESS); +} + +class dummy_transport_t final : public data_transport_i { + public: + error_t send(party_idx_t /*receiver*/, mem_t /*msg*/) override { return E_GENERAL; } + error_t receive(party_idx_t /*sender*/, buf_t& /*msg*/) override { return E_GENERAL; } + error_t receive_all(const std::vector& /*senders*/, std::vector& /*msgs*/) override { + return E_GENERAL; + } +}; + +} // namespace + +TEST(ApiEcdsaMp, DkgSignRefreshSign4p) { exercise_4p(); } + +TEST(ApiEcdsaMp, RejectsInvalidJobSelf) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t bad_job{/*self=*/3, names, t}; + + buf_t key; + buf_t sid; + EXPECT_EQ(coinbase::api::ecdsa_mp::dkg_additive(bad_job, curve_id::secp256k1, key, sid), E_BADARG); +} + +TEST(ApiEcdsaMp, RejectsDuplicatePartyNames) { + dummy_transport_t t; + std::vector names = {"dup", "dup"}; + job_mp_t bad_job{/*self=*/0, names, t}; + + buf_t key; + buf_t sid; + EXPECT_EQ(coinbase::api::ecdsa_mp::dkg_additive(bad_job, curve_id::secp256k1, key, sid), E_BADARG); +} + +TEST(ApiEcdsaMp, RejectsInvalidSigReceiver) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t sig; + EXPECT_EQ(coinbase::api::ecdsa_mp::sign_additive(job, mem_t(), mem_t(), /*sig_receiver=*/5, sig), E_BADARG); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ +// +TEST(ApiEcdsaMp, NegDkgEdwardsCurve) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key, sid; + EXPECT_EQ(coinbase::api::ecdsa_mp::dkg_additive(job, curve_id::ed25519, key, sid), E_BADARG); +} + +TEST(ApiEcdsaMp, NegDkgCurveZero) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key, sid; + EXPECT_NE(coinbase::api::ecdsa_mp::dkg_additive(job, static_cast(0), key, sid), SUCCESS); +} + +TEST(ApiEcdsaMp, NegDkgCurveFour) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key, sid; + EXPECT_NE(coinbase::api::ecdsa_mp::dkg_additive(job, static_cast(4), key, sid), SUCCESS); +} + +TEST(ApiEcdsaMp, NegDkgCurve255) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key, sid; + EXPECT_NE(coinbase::api::ecdsa_mp::dkg_additive(job, static_cast(255), key, sid), SUCCESS); +} + +TEST(ApiEcdsaMp, NegDkgSingleParty) { + dummy_transport_t t; + std::vector names = {"p0"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key, sid; + EXPECT_EQ(coinbase::api::ecdsa_mp::dkg_additive(job, curve_id::secp256k1, key, sid), E_BADARG); +} + +TEST(ApiEcdsaMp, NegDkgEmptyPartyNames) { + dummy_transport_t t; + std::vector names; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key, sid; + EXPECT_EQ(coinbase::api::ecdsa_mp::dkg_additive(job, curve_id::secp256k1, key, sid), E_BADARG); +} + +TEST(ApiEcdsaMp, NegGetPubKeyEmpty) { + buf_t pub; + EXPECT_NE(coinbase::api::ecdsa_mp::get_public_key_compressed(mem_t(), pub), SUCCESS); +} + +TEST(ApiEcdsaMp, NegGetPubKeyGarbage) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t pub; + EXPECT_NE(coinbase::api::ecdsa_mp::get_public_key_compressed(mem_t(garbage, sizeof(garbage)), pub), SUCCESS); +} + +TEST(ApiEcdsaMp, NegGetPubKeyAllZero) { + uint8_t zeros[64] = {}; + buf_t pub; + EXPECT_NE(coinbase::api::ecdsa_mp::get_public_key_compressed(mem_t(zeros, sizeof(zeros)), pub), SUCCESS); +} + +TEST(ApiEcdsaMp, NegGetPubKeyOneByte) { + uint8_t one = 0x00; + buf_t pub; + EXPECT_NE(coinbase::api::ecdsa_mp::get_public_key_compressed(mem_t(&one, 1), pub), SUCCESS); +} + +TEST(ApiEcdsaMp, NegGetPubKeyOversized) { + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t pub; + EXPECT_NE(coinbase::api::ecdsa_mp::get_public_key_compressed(big, pub), SUCCESS); +} + +TEST(ApiEcdsaMp, NegGetPubShareEmpty) { + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::get_public_share_compressed(mem_t(), out), SUCCESS); +} + +TEST(ApiEcdsaMp, NegGetPubShareGarbage) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::get_public_share_compressed(mem_t(garbage, sizeof(garbage)), out), SUCCESS); +} + +TEST(ApiEcdsaMp, NegGetPubShareAllZero) { + uint8_t zeros[64] = {}; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::get_public_share_compressed(mem_t(zeros, sizeof(zeros)), out), SUCCESS); +} + +TEST(ApiEcdsaMp, NegGetPubShareOneByte) { + uint8_t one = 0x00; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::get_public_share_compressed(mem_t(&one, 1), out), SUCCESS); +} + +TEST(ApiEcdsaMp, NegGetPubShareOversized) { + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::get_public_share_compressed(big, out), SUCCESS); +} + +TEST(ApiEcdsaMp, NegDetachEmpty) { + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::ecdsa_mp::detach_private_scalar(mem_t(), pub_blob, scalar), SUCCESS); +} + +TEST(ApiEcdsaMp, NegDetachGarbage) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::ecdsa_mp::detach_private_scalar(mem_t(garbage, sizeof(garbage)), pub_blob, scalar), SUCCESS); +} + +TEST(ApiEcdsaMp, NegDetachAllZero) { + uint8_t zeros[64] = {}; + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::ecdsa_mp::detach_private_scalar(mem_t(zeros, sizeof(zeros)), pub_blob, scalar), SUCCESS); +} + +TEST(ApiEcdsaMp, NegDetachOneByte) { + uint8_t one = 0x00; + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::ecdsa_mp::detach_private_scalar(mem_t(&one, 1), pub_blob, scalar), SUCCESS); +} + +TEST(ApiEcdsaMp, NegDetachOversized) { + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::ecdsa_mp::detach_private_scalar(big, pub_blob, scalar), SUCCESS); +} + +TEST(ApiEcdsaMp, NegAttachEmptyPublicKeyBlob) { + uint8_t scalar[32] = {0x01}; + uint8_t point[33] = {}; + point[0] = 0x02; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(mem_t(), mem_t(scalar, 32), mem_t(point, 33), out), SUCCESS); +} + +TEST(ApiEcdsaMp, NegAttachGarbagePublicKeyBlob) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + uint8_t scalar[32] = {0x01}; + uint8_t point[33] = {}; + point[0] = 0x02; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(mem_t(garbage, sizeof(garbage)), mem_t(scalar, 32), + mem_t(point, 33), out), + SUCCESS); +} + +TEST(ApiEcdsaMp, NegAttachOversizedPublicKeyBlob) { + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + uint8_t scalar[32] = {0x01}; + uint8_t point[33] = {}; + point[0] = 0x02; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(big, mem_t(scalar, 32), mem_t(point, 33), out), SUCCESS); +} + +TEST(ApiEcdsaMp, NegRefreshEmptyKeyBlob) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t sid, new_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::refresh_additive(job, sid, mem_t(), new_blob), SUCCESS); +} + +TEST(ApiEcdsaMp, NegRefreshGarbageKeyBlob) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t sid, new_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::refresh_additive(job, sid, mem_t(garbage, sizeof(garbage)), new_blob), SUCCESS); +} + +TEST(ApiEcdsaMp, NegRefreshAllZeroKeyBlob) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + uint8_t zeros[64] = {}; + buf_t sid, new_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::refresh_additive(job, sid, mem_t(zeros, sizeof(zeros)), new_blob), SUCCESS); +} + +TEST(ApiEcdsaMp, NegRefreshOneByteKeyBlob) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + uint8_t one = 0x00; + buf_t sid, new_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::refresh_additive(job, sid, mem_t(&one, 1), new_blob), SUCCESS); +} + +TEST(ApiEcdsaMp, NegRefreshOversizedKeyBlob) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t sid, new_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::refresh_additive(job, sid, big, new_blob), SUCCESS); +} + +TEST(ApiEcdsaMp, NegSignEmptyKeyBlob) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t msg_hash(32); + buf_t sig; + EXPECT_NE(coinbase::api::ecdsa_mp::sign_additive(job, mem_t(), msg_hash, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiEcdsaMp, NegSignGarbageKeyBlob) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg_hash(32); + buf_t sig; + EXPECT_NE( + coinbase::api::ecdsa_mp::sign_additive(job, mem_t(garbage, sizeof(garbage)), msg_hash, /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEcdsaMp, NegSignAllZeroKeyBlob) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + uint8_t zeros[64] = {}; + buf_t msg_hash(32); + buf_t sig; + EXPECT_NE(coinbase::api::ecdsa_mp::sign_additive(job, mem_t(zeros, sizeof(zeros)), msg_hash, /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEcdsaMp, NegSignOneByteKeyBlob) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + uint8_t one = 0x00; + buf_t msg_hash(32); + buf_t sig; + EXPECT_NE(coinbase::api::ecdsa_mp::sign_additive(job, mem_t(&one, 1), msg_hash, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiEcdsaMp, NegSignOversizedKeyBlob) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t msg_hash(32); + buf_t sig; + EXPECT_NE(coinbase::api::ecdsa_mp::sign_additive(job, big, msg_hash, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiEcdsaMp, NegSignSigReceiverNegative) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t sig; + EXPECT_EQ(coinbase::api::ecdsa_mp::sign_additive(job, mem_t(), mem_t(), /*sig_receiver=*/-1, sig), E_BADARG); +} + +namespace { +static void generate_mp_key_blobs(curve_id curve, int n, std::vector& blobs) { + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector> transports; + transports.reserve(n); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + + std::vector names; + std::vector name_views; + for (int i = 0; i < n; i++) names.push_back("p" + std::to_string(i)); + for (const auto& nm : names) name_views.emplace_back(nm); + + blobs.resize(n); + std::vector sids(n); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::dkg_additive(job, curve, blobs[static_cast(i)], + sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); +} +} // namespace + +class ApiEcdsaMpNegWithBlobs : public ::testing::Test { + protected: + static void SetUpTestSuite() { generate_mp_key_blobs(curve_id::secp256k1, 3, blobs_); } + static std::vector blobs_; +}; +std::vector ApiEcdsaMpNegWithBlobs::blobs_; + +TEST_F(ApiEcdsaMpNegWithBlobs, NegSignEmptyMsgHash) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t sig; + EXPECT_NE(coinbase::api::ecdsa_mp::sign_additive(job, blobs_[0], mem_t(), /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST_F(ApiEcdsaMpNegWithBlobs, NegSignOversizedMsgHash) { + dummy_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t huge_hash(65); + std::memset(huge_hash.data(), 0x42, static_cast(huge_hash.size())); + buf_t sig; + EXPECT_NE(coinbase::api::ecdsa_mp::sign_additive(job, blobs_[0], huge_hash, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST_F(ApiEcdsaMpNegWithBlobs, NegAttachEmptyPrivateScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(pub_blob, mem_t(), Qi, out), SUCCESS); +} + +TEST_F(ApiEcdsaMpNegWithBlobs, NegAttachGarbagePrivateScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + uint8_t garbage[512]; + std::memset(garbage, 0xFF, sizeof(garbage)); + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(pub_blob, mem_t(garbage, sizeof(garbage)), Qi, out), + SUCCESS); +} + +TEST_F(ApiEcdsaMpNegWithBlobs, NegAttachWrongSizeScalar31) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + uint8_t short_scalar[31]; + std::memset(short_scalar, 0x01, sizeof(short_scalar)); + buf_t out; + EXPECT_NE( + coinbase::api::ecdsa_mp::attach_private_scalar(pub_blob, mem_t(short_scalar, sizeof(short_scalar)), Qi, out), + SUCCESS); +} + +TEST_F(ApiEcdsaMpNegWithBlobs, NegAttachWrongSizeScalar33) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + uint8_t long_scalar[33]; + std::memset(long_scalar, 0x01, sizeof(long_scalar)); + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(pub_blob, mem_t(long_scalar, sizeof(long_scalar)), Qi, out), + SUCCESS); +} + +TEST_F(ApiEcdsaMpNegWithBlobs, NegAttachZeroScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + uint8_t zero[32] = {}; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(pub_blob, mem_t(zero, 32), Qi, out), SUCCESS); +} + +TEST_F(ApiEcdsaMpNegWithBlobs, NegAttachSingleByteScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + uint8_t one_byte = 0x01; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(pub_blob, mem_t(&one_byte, 1), Qi, out), SUCCESS); +} + +TEST_F(ApiEcdsaMpNegWithBlobs, NegAttachEmptyPublicShare) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(pub_blob, x, mem_t(), out), SUCCESS); +} + +TEST_F(ApiEcdsaMpNegWithBlobs, NegAttachAllZeroPublicShare) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + uint8_t zero_point[33] = {}; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(pub_blob, x, mem_t(zero_point, 33), out), SUCCESS); +} + +TEST_F(ApiEcdsaMpNegWithBlobs, NegAttachGarbagePublicShare) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + uint8_t bad_point[33]; + bad_point[0] = 0x05; + std::memset(bad_point + 1, 0xAB, 32); + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(pub_blob, x, mem_t(bad_point, 33), out), SUCCESS); +} diff --git a/tests/unit/api/test_ecdsa_mp_ac.cpp b/tests/unit/api/test_ecdsa_mp_ac.cpp new file mode 100644 index 00000000..ee852ed1 --- /dev/null +++ b/tests/unit/api/test_ecdsa_mp_ac.cpp @@ -0,0 +1,1163 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; + +using coinbase::api::curve_id; +using coinbase::api::job_mp_t; +using coinbase::api::party_idx_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::api_harness::failing_transport_t; +using coinbase::testutils::api_harness::local_api_transport_t; +using coinbase::testutils::api_harness::run_mp; + +static std::vector> make_peers(int n) { + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + return peers; +} + +static std::vector> make_transports( + const std::vector>& peers) { + std::vector> transports; + transports.reserve(peers.size()); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + return transports; +} + +static buf_t make_msg_hash32() { + buf_t msg_hash(32); + for (int i = 0; i < msg_hash.size(); i++) msg_hash[i] = static_cast(i); + return msg_hash; +} + +} // namespace + +TEST(ApiEcdsaMpAc, DkgRefreshSign4p) { + constexpr int n = 4; + + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector> transports; + transports.reserve(n); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + + std::vector names = {"p0", "p1", "p2", "p3"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + // THRESHOLD[2](p0, p1, p2, p3) + const coinbase::api::access_structure_t ac = + coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + }); + + // Only p0 and p1 actively contribute to the DKG/refresh. + const std::vector quorum_party_names = {names[0], names[1]}; + + std::vector key_blobs(n); + std::vector sids(n); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sids[static_cast(i)], ac, + quorum_party_names, key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) EXPECT_EQ(sids[0], sids[static_cast(i)]); + + buf_t pub0; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_key_compressed(key_blobs[0], pub0), SUCCESS); + EXPECT_EQ(pub0.size(), 33); + for (int i = 1; i < n; i++) { + buf_t pub_i; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_key_compressed(key_blobs[static_cast(i)], pub_i), SUCCESS); + EXPECT_EQ(pub_i, pub0); + } + + // Verify a signature using only the online signing quorum parties. + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub0), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + + buf_t msg_hash(32); + for (int i = 0; i < msg_hash.size(); i++) msg_hash[i] = static_cast(i); + + std::vector> sign_peers = {peers[0], peers[1]}; + std::vector> sign_transports = {transports[0], transports[1]}; + + constexpr int quorum_n = 2; + std::vector sigs(quorum_n); + run_mp( + sign_peers, + [&](int i) { + job_mp_t job{static_cast(i), quorum_party_names, *sign_transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::sign_ac(job, key_blobs[static_cast(i)], ac, msg_hash, + /*sig_receiver=*/0, sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + ASSERT_GT(sigs[0].size(), 0); + EXPECT_EQ(sigs[1].size(), 0); + ASSERT_EQ(verify_key.verify(msg_hash, sigs[0]), SUCCESS); + + // Threshold refresh. + std::vector new_key_blobs(n); + std::vector refresh_sids(n); + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::refresh_ac(job, refresh_sids[static_cast(i)], + key_blobs[static_cast(i)], ac, quorum_party_names, + new_key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) EXPECT_EQ(refresh_sids[0], refresh_sids[static_cast(i)]); + + for (int i = 0; i < n; i++) { + buf_t pub_i; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_key_compressed(new_key_blobs[static_cast(i)], pub_i), + SUCCESS); + EXPECT_EQ(pub_i, pub0); + } + + std::vector sigs2(quorum_n); + run_mp( + sign_peers, + [&](int i) { + job_mp_t job{static_cast(i), quorum_party_names, *sign_transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::sign_ac(job, new_key_blobs[static_cast(i)], ac, msg_hash, + /*sig_receiver=*/0, sigs2[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + ASSERT_GT(sigs2[0].size(), 0); + EXPECT_EQ(sigs2[1].size(), 0); + ASSERT_EQ(verify_key.verify(msg_hash, sigs2[0]), SUCCESS); +} + +TEST(ApiEcdsaMpAc, RejectsAccessStructureLeafMismatch) { + constexpr int n = 3; + + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector> transports; + transports.reserve(n); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + + std::vector names = {"p0", "p1", "p2"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + // Missing leaf "p2" -> should reject before protocol starts. + const coinbase::api::access_structure_t bad_ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector quorum_party_names = {names[0], names[1]}; + + std::vector key_blobs(n); + std::vector sids(n); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sids[static_cast(i)], bad_ac, + quorum_party_names, key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, E_BADARG); +} + +TEST(ApiEcdsaMpAc, ComplexAccess_DkgRefreshSign4p) { + constexpr int n = 4; + + std::vector> peers = make_peers(n); + std::vector> transports = make_transports(peers); + + std::vector names = {"p0", "p1", "p2", "p3"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + // Access structure: + // THRESHOLD[2]( + // AND(p0, p1), + // OR(p2, p3) + // ) + // + // Equivalent policy: p0 AND p1 AND (p2 OR p3). + const coinbase::api::access_structure_t ac = + coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }), + coinbase::api::access_structure_t::Or({ + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + }), + }); + + // DKG quorum must satisfy the access structure. Choose {p0, p1, p2}. + const std::vector dkg_quorum_party_names = {names[0], names[1], names[2]}; + + std::vector key_blobs(n); + std::vector sids(n); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sids[static_cast(i)], ac, + dkg_quorum_party_names, key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) EXPECT_EQ(sids[0], sids[static_cast(i)]); + + buf_t pub0; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_key_compressed(key_blobs[0], pub0), SUCCESS); + EXPECT_EQ(pub0.size(), 33); + for (int i = 1; i < n; i++) { + buf_t pub_i; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_key_compressed(key_blobs[static_cast(i)], pub_i), SUCCESS); + EXPECT_EQ(pub_i, pub0); + } + + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub0), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + const buf_t msg_hash = make_msg_hash32(); + + // Signing quorum A: {p0, p1, p2} (satisfies OR via p2). + const std::vector quorum_a = {names[0], names[1], names[2]}; + + { + std::vector> sign_peers = make_peers(static_cast(quorum_a.size())); + std::vector> sign_transports = make_transports(sign_peers); + + std::vector sigs(quorum_a.size()); + run_mp( + sign_peers, + [&](int i) { + job_mp_t job{static_cast(i), quorum_a, *sign_transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::sign_ac(job, key_blobs[static_cast(i)], ac, msg_hash, + /*sig_receiver=*/0, sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + ASSERT_GT(sigs[0].size(), 0); + for (size_t i = 1; i < sigs.size(); i++) EXPECT_EQ(sigs[i].size(), 0); + ASSERT_EQ(verify_key.verify(msg_hash, sigs[0]), SUCCESS); + } + + // Signing quorum B: {p0, p1, p3} (satisfies OR via p3). + const std::vector quorum_b = {names[0], names[1], names[3]}; + const std::vector quorum_b_key_blob_indices = {0, 1, 3}; + + { + std::vector> sign_peers = make_peers(static_cast(quorum_b.size())); + std::vector> sign_transports = make_transports(sign_peers); + + std::vector sigs(quorum_b.size()); + run_mp( + sign_peers, + [&](int i) { + job_mp_t job{static_cast(i), quorum_b, *sign_transports[static_cast(i)]}; + const size_t key_blob_idx = quorum_b_key_blob_indices[static_cast(i)]; + return coinbase::api::ecdsa_mp::sign_ac(job, key_blobs[key_blob_idx], ac, msg_hash, /*sig_receiver=*/0, + sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + ASSERT_GT(sigs[0].size(), 0); + for (size_t i = 1; i < sigs.size(); i++) EXPECT_EQ(sigs[i].size(), 0); + ASSERT_EQ(verify_key.verify(msg_hash, sigs[0]), SUCCESS); + } + + // Threshold refresh with a different quorum that still satisfies the access structure. + // Choose {p0, p1, p3}. + std::vector new_key_blobs(n); + std::vector refresh_sids(n); + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::refresh_ac(job, refresh_sids[static_cast(i)], + key_blobs[static_cast(i)], ac, quorum_b, + new_key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) EXPECT_EQ(refresh_sids[0], refresh_sids[static_cast(i)]); + + for (int i = 0; i < n; i++) { + buf_t pub_i; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_key_compressed(new_key_blobs[static_cast(i)], pub_i), + SUCCESS); + EXPECT_EQ(pub_i, pub0); + } + + // Ensure we can still sign using `sign_ac` for quorum B after refresh. + + { + std::vector> sign_peers = make_peers(static_cast(quorum_b.size())); + std::vector> sign_transports = make_transports(sign_peers); + + std::vector sigs(quorum_b.size()); + run_mp( + sign_peers, + [&](int i) { + job_mp_t job{static_cast(i), quorum_b, *sign_transports[static_cast(i)]}; + const size_t key_blob_idx = quorum_b_key_blob_indices[static_cast(i)]; + return coinbase::api::ecdsa_mp::sign_ac(job, new_key_blobs[key_blob_idx], ac, msg_hash, /*sig_receiver=*/0, + sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + ASSERT_GT(sigs[0].size(), 0); + for (size_t i = 1; i < sigs.size(); i++) EXPECT_EQ(sigs[i].size(), 0); + ASSERT_EQ(verify_key.verify(msg_hash, sigs[0]), SUCCESS); + } +} + +TEST(ApiEcdsaMpAc, RejectsMalformedAccessStructure) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const std::vector quorum = {names[0], names[1]}; + + buf_t sid; + buf_t key_blob; + + // Root leaf is not supported. + { + const auto ac = coinbase::api::access_structure_t::leaf(names[0]); + EXPECT_EQ(coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sid, ac, quorum, key_blob), E_BADARG); + } + + // Leaf node with children. + { + coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + ac.children[0].children.push_back(coinbase::api::access_structure_t::leaf("x")); + EXPECT_EQ(coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sid, ac, quorum, key_blob), E_BADARG); + } + + // Internal node with no children. + { + coinbase::api::access_structure_t ac; + ac.type = coinbase::api::access_structure_t::node_type::and_node; + EXPECT_EQ(coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sid, ac, quorum, key_blob), E_BADARG); + } + + // Invalid threshold_k. + { + coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + /*k=*/0, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + EXPECT_EQ(coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sid, ac, quorum, key_blob), E_BADARG); + } + + // Duplicate leaf names should be rejected. + { + coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + EXPECT_EQ(coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sid, ac, quorum, key_blob), E_BADARG); + } +} + +TEST(ApiEcdsaMpAc, RejectDkgQuorumInsufficient) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2", "p3"}; + job_mp_t job{/*self=*/0, names, t}; + + const coinbase::api::access_structure_t ac = + coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }), + coinbase::api::access_structure_t::Or({ + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + }), + }); + + // Missing p1 => does not satisfy AND(p0, p1). + const std::vector bad_quorum = {names[0], names[2], names[3]}; + + buf_t sid; + buf_t key_blob; + EXPECT_EQ(coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sid, ac, bad_quorum, key_blob), E_BADARG); +} + +TEST(ApiEcdsaMpAc, RejectsDkgQuorumUnknownPartyName) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + const coinbase::api::access_structure_t ac = + coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + }); + + const std::vector bad_quorum = {names[0], "unknown"}; + + buf_t sid; + buf_t key_blob; + EXPECT_EQ(coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sid, ac, bad_quorum, key_blob), E_BADARG); +} + +TEST(ApiEcdsaMpAc, KeyBlobPrivScalar_NoPubSign) { + constexpr int n = 5; + + std::vector> peers = make_peers(n); + std::vector> transports = make_transports(peers); + + std::vector names = {"p0", "p1", "p2", "p3", "p4"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + // THRESHOLD[3](p0, p1, p2, p3, p4) + const coinbase::api::access_structure_t ac = + coinbase::api::access_structure_t::Threshold(3, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + coinbase::api::access_structure_t::leaf(names[4]), + }); + + const std::vector quorum_party_names = {names[0], names[1], names[2]}; + + std::vector key_blobs(n); + std::vector sids(n); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sids[static_cast(i)], ac, + quorum_party_names, key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + + // Detach/attach private scalar for each party and ensure the public blob preserves Qi_self extraction. + std::vector redacted(n); + std::vector x_fixed(n); + std::vector merged(n); + for (int i = 0; i < n; i++) { + ASSERT_EQ( + coinbase::api::ecdsa_mp::detach_private_scalar(key_blobs[static_cast(i)], redacted[i], x_fixed[i]), + SUCCESS); + EXPECT_GT(redacted[i].size(), 0); + EXPECT_EQ(x_fixed[i].size(), 32); // secp256k1 order size + + buf_t Qi_full; + buf_t Qi_redacted; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_share_compressed(key_blobs[static_cast(i)], Qi_full), + SUCCESS); + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_share_compressed(redacted[i], Qi_redacted), SUCCESS); + EXPECT_EQ(Qi_full, Qi_redacted); + + ASSERT_EQ(coinbase::api::ecdsa_mp::attach_private_scalar(redacted[i], x_fixed[i], Qi_full, merged[i]), SUCCESS); + EXPECT_GT(merged[i].size(), 0); + } + + // Public blob should not be usable for signing. + // Avoid spinning up a full protocol run here: sign_ac rejects at key blob parsing + // before any transport calls, so a single local call is sufficient. + const buf_t msg_hash = make_msg_hash32(); + { + failing_transport_t t; + job_mp_t job{/*self=*/0, quorum_party_names, t}; + buf_t sig; + EXPECT_NE(coinbase::api::ecdsa_mp::sign_ac(job, redacted[0], ac, msg_hash, /*sig_receiver=*/0, sig), SUCCESS); + } + + // Merged blobs should be usable for signing. + std::vector> sign_peers = {peers[0], peers[1], peers[2]}; + std::vector> sign_transports = {transports[0], transports[1], transports[2]}; + std::vector sigs(quorum_party_names.size()); + run_mp( + sign_peers, + [&](int i) { + job_mp_t job{static_cast(i), quorum_party_names, *sign_transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::sign_ac(job, merged[static_cast(i)], ac, msg_hash, /*sig_receiver=*/0, + sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + ASSERT_GT(sigs[0].size(), 0); + for (size_t i = 1; i < sigs.size(); i++) EXPECT_EQ(sigs[i].size(), 0); + + // Negative: merging the wrong scalar should fail. + buf_t Qi0; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_share_compressed(key_blobs[0], Qi0), SUCCESS); + buf_t bad_x = x_fixed[0]; + bad_x[0] ^= 0x01; + buf_t bad_merged; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(redacted[0], bad_x, Qi0, bad_merged), SUCCESS); +} + +TEST(ApiEcdsaMpAc, SignAcAndRefreshNegativeCases) { + constexpr int n = 4; + + std::vector> peers = make_peers(n); + std::vector> transports = make_transports(peers); + + std::vector names = {"p0", "p1", "p2", "p3"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + // Strict access structure requires all 4 parties to satisfy: AND(p0, p1, p2, p3). + const coinbase::api::access_structure_t strict_ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + }); + + // DKG quorum must satisfy the access structure. + const std::vector dkg_quorum_party_names = {names[0], names[1], names[2], names[3]}; + + std::vector key_blobs(n); + std::vector sids(n); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sids[static_cast(i)], strict_ac, + dkg_quorum_party_names, key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + + const buf_t msg_hash = make_msg_hash32(); + + // `sign_ac` should reject when the key blob does not match `job.self`. + { + const std::vector quorum_missing_self = {names[1], names[2]}; + failing_transport_t t; + job_mp_t job{/*self=*/0, quorum_missing_self, t}; // self == "p1" + buf_t sig; + EXPECT_EQ(coinbase::api::ecdsa_mp::sign_ac(job, key_blobs[0], strict_ac, msg_hash, /*sig_receiver=*/0, sig), + E_BADARG); + } + + // `sign_ac` should reject when the online signing quorum does not satisfy the access structure. + { + const std::vector insufficient_quorum = {names[0], names[1]}; + failing_transport_t t; + job_mp_t job{/*self=*/0, insufficient_quorum, t}; + buf_t sig; + EXPECT_EQ(coinbase::api::ecdsa_mp::sign_ac(job, key_blobs[0], strict_ac, msg_hash, /*sig_receiver=*/0, sig), + E_INSUFFICIENT); + } + + // refresh_ac should reject additive (v1) key blobs. + { + const std::vector quorum = {names[0], names[1]}; + + // Produce an additive-share (v1) key blob via additive DKG. + constexpr int quorum_n = 2; + std::vector> dkg_peers = make_peers(quorum_n); + std::vector> dkg_transports = make_transports(dkg_peers); + + std::vector additive_key_blobs(quorum_n); + std::vector additive_sids(quorum_n); + run_mp( + dkg_peers, + [&](int i) { + job_mp_t job{static_cast(i), quorum, *dkg_transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::dkg_additive(job, curve_id::secp256k1, + additive_key_blobs[static_cast(i)], + additive_sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + + failing_transport_t t; + job_mp_t job{/*self=*/0, quorum, t}; + + // Any access structure will do since the call should fail at key blob deserialization. + const coinbase::api::access_structure_t quorum_ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(quorum[0]), + coinbase::api::access_structure_t::leaf(quorum[1]), + }); + + buf_t sid; + buf_t new_key_blob; + EXPECT_EQ(coinbase::api::ecdsa_mp::refresh_ac(job, sid, additive_key_blobs[0], quorum_ac, quorum, new_key_blob), + E_FORMAT); + } +} + +TEST(ApiEcdsaMpAc, SignAcRejectsWrongAccessStructure) { + constexpr int n = 4; + + std::vector> peers = make_peers(n); + std::vector> transports = make_transports(peers); + + std::vector names = {"p0", "p1", "p2", "p3"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + // DKG access structure requires all parties: AND(p0, p1, p2, p3) + const coinbase::api::access_structure_t strict_ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + }); + const std::vector strict_quorum = {names[0], names[1], names[2], names[3]}; + + std::vector key_blobs(n); + std::vector sids(n); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sids[static_cast(i)], strict_ac, + strict_quorum, key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + + // Try to "weaken" the policy by supplying a different access structure that allows a 2-of-4 quorum. + const coinbase::api::access_structure_t wrong_ac = + coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + }); + const std::vector wrong_quorum = {names[0], names[1]}; + + failing_transport_t t; + job_mp_t job{/*self=*/0, wrong_quorum, t}; + buf_t sig_der; + const buf_t msg_hash = make_msg_hash32(); + EXPECT_NE(coinbase::api::ecdsa_mp::sign_ac(job, key_blobs[0], wrong_ac, msg_hash, /*sig_receiver=*/0, sig_der), + SUCCESS); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +#include + +namespace { + +using coinbase::mem_t; + +static void generate_mp_ac_key_blobs(curve_id curve, int n, std::vector& blobs) { + auto peers = make_peers(n); + auto transports = make_transports(peers); + + std::vector names; + std::vector name_views; + for (int i = 0; i < n; i++) names.push_back("p" + std::to_string(i)); + for (const auto& nm : names) name_views.emplace_back(nm); + + const auto ac = coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + }); + const std::vector quorum = {names[0], names[1]}; + + blobs.resize(n); + std::vector sids(n); + std::vector rvs; + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::ecdsa_mp::dkg_ac(job, curve, sids[static_cast(i)], ac, quorum, + blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); +} + +} // namespace + +class ApiEcdsaMpAcNegWithBlobs : public ::testing::Test { + protected: + static void SetUpTestSuite() { generate_mp_ac_key_blobs(curve_id::secp256k1, 4, blobs_); } + static std::vector blobs_; +}; + +std::vector ApiEcdsaMpAcNegWithBlobs::blobs_; + +TEST(ApiEcdsaMpAc, NegDkgEdwardsCurve) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector quorum = {names[0], names[1]}; + buf_t sid, key_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::ed25519, sid, ac, quorum, key_blob), SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegDkgInvalidCurveValues) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector quorum = {names[0], names[1]}; + for (uint32_t val : {0u, 4u, 255u}) { + buf_t sid, key_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::dkg_ac(job, static_cast(val), sid, ac, quorum, key_blob), SUCCESS) + << "Expected failure for curve_id=" << val; + } +} + +TEST(ApiEcdsaMpAc, NegDkgEmptyQuorum) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector empty_quorum; + buf_t sid, key_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sid, ac, empty_quorum, key_blob), SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegDkgSinglePartyJob) { + failing_transport_t t; + std::vector names = {"p0"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::Threshold(1, { + coinbase::api::access_structure_t::leaf(names[0]), + }); + const std::vector quorum = {names[0]}; + buf_t sid, key_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sid, ac, quorum, key_blob), SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegDkgEmptyPartyName) { + failing_transport_t t; + std::vector names = {"p0", "", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + }); + const std::vector quorum = {names[0], names[1], names[2]}; + buf_t sid, key_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sid, ac, quorum, key_blob), SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegDkgThresholdExceedsChildren) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::Threshold(3, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector quorum = {names[0], names[1]}; + buf_t sid, key_blob; + EXPECT_EQ(coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sid, ac, quorum, key_blob), E_BADARG); +} + +TEST(ApiEcdsaMpAc, NegDkgNegativeThreshold) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = + coinbase::api::access_structure_t::Threshold(-1, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector quorum = {names[0], names[1]}; + buf_t sid, key_blob; + EXPECT_EQ(coinbase::api::ecdsa_mp::dkg_ac(job, curve_id::secp256k1, sid, ac, quorum, key_blob), E_BADARG); +} + +TEST(ApiEcdsaMpAc, NegSignAcEmptyKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + buf_t msg_hash = make_msg_hash32(); + buf_t sig; + EXPECT_NE(coinbase::api::ecdsa_mp::sign_ac(job, mem_t(), ac, msg_hash, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegSignAcGarbageKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg_hash = make_msg_hash32(); + buf_t sig; + EXPECT_NE( + coinbase::api::ecdsa_mp::sign_ac(job, mem_t(garbage, sizeof(garbage)), ac, msg_hash, /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegSignAcAllZeroKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t zeros[64] = {}; + buf_t msg_hash = make_msg_hash32(); + buf_t sig; + EXPECT_NE(coinbase::api::ecdsa_mp::sign_ac(job, mem_t(zeros, sizeof(zeros)), ac, msg_hash, /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegSignAcOversizedKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t msg_hash = make_msg_hash32(); + buf_t sig; + EXPECT_NE(coinbase::api::ecdsa_mp::sign_ac(job, big, ac, msg_hash, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegSignAcEmptyMsgHash) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t sig; + EXPECT_NE( + coinbase::api::ecdsa_mp::sign_ac(job, mem_t(garbage, sizeof(garbage)), ac, mem_t(), /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegSignAcOversizedMsgHash) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t huge_hash(65); + std::memset(huge_hash.data(), 0x42, static_cast(huge_hash.size())); + buf_t sig; + EXPECT_NE( + coinbase::api::ecdsa_mp::sign_ac(job, mem_t(garbage, sizeof(garbage)), ac, huge_hash, /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegSignAcNegativeSigReceiver) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg_hash = make_msg_hash32(); + buf_t sig; + EXPECT_NE( + coinbase::api::ecdsa_mp::sign_ac(job, mem_t(garbage, sizeof(garbage)), ac, msg_hash, /*sig_receiver=*/-1, sig), + SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegSignAcSigReceiverTooLarge) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg_hash = make_msg_hash32(); + buf_t sig; + EXPECT_NE(coinbase::api::ecdsa_mp::sign_ac(job, mem_t(garbage, sizeof(garbage)), ac, msg_hash, + /*sig_receiver=*/static_cast(names.size()), sig), + SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegSignAcEmptyAndAccessStructure) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({}); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg_hash = make_msg_hash32(); + buf_t sig; + EXPECT_NE( + coinbase::api::ecdsa_mp::sign_ac(job, mem_t(garbage, sizeof(garbage)), ac, msg_hash, /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegSignAcEmptyOrAccessStructure) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::Or({}); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg_hash = make_msg_hash32(); + buf_t sig; + EXPECT_NE( + coinbase::api::ecdsa_mp::sign_ac(job, mem_t(garbage, sizeof(garbage)), ac, msg_hash, /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegRefreshAcEmptyKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector quorum = {names[0], names[1]}; + buf_t sid, new_key_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::refresh_ac(job, sid, mem_t(), ac, quorum, new_key_blob), SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegRefreshAcGarbageKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + const std::vector quorum = {names[0], names[1]}; + buf_t sid, new_key_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::refresh_ac(job, sid, mem_t(garbage, sizeof(garbage)), ac, quorum, new_key_blob), + SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegRefreshAcAllZeroKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t zeros[64] = {}; + const std::vector quorum = {names[0], names[1]}; + buf_t sid, new_key_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::refresh_ac(job, sid, mem_t(zeros, sizeof(zeros)), ac, quorum, new_key_blob), + SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegRefreshAcOversizedKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + const std::vector quorum = {names[0], names[1]}; + buf_t sid, new_key_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::refresh_ac(job, sid, big, ac, quorum, new_key_blob), SUCCESS); +} + +TEST(ApiEcdsaMpAc, NegRefreshAcEmptyQuorum) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + const std::vector empty_quorum; + buf_t sid, new_key_blob; + EXPECT_NE( + coinbase::api::ecdsa_mp::refresh_ac(job, sid, mem_t(garbage, sizeof(garbage)), ac, empty_quorum, new_key_blob), + SUCCESS); +} + +TEST_F(ApiEcdsaMpAcNegWithBlobs, NegSignAcEmptyMsgHashValidBlob) { + std::vector names = {"p0", "p1"}; + std::vector name_views(names.begin(), names.end()); + failing_transport_t t; + job_mp_t job{/*self=*/0, name_views, t}; + const auto ac = coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::leaf("p0"), + coinbase::api::access_structure_t::leaf("p1"), + coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3"), + }); + buf_t sig; + EXPECT_NE(coinbase::api::ecdsa_mp::sign_ac(job, blobs_[0], ac, mem_t(), /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST_F(ApiEcdsaMpAcNegWithBlobs, NegSignAcOversizedMsgHashValidBlob) { + std::vector names = {"p0", "p1"}; + std::vector name_views(names.begin(), names.end()); + failing_transport_t t; + job_mp_t job{/*self=*/0, name_views, t}; + const auto ac = coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::leaf("p0"), + coinbase::api::access_structure_t::leaf("p1"), + coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3"), + }); + buf_t huge_hash(65); + std::memset(huge_hash.data(), 0x42, static_cast(huge_hash.size())); + buf_t sig; + EXPECT_NE(coinbase::api::ecdsa_mp::sign_ac(job, blobs_[0], ac, huge_hash, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST_F(ApiEcdsaMpAcNegWithBlobs, NegRefreshAcInvalidAccessStructure) { + std::vector names = {"p0", "p1", "p2", "p3"}; + std::vector name_views(names.begin(), names.end()); + failing_transport_t t; + job_mp_t job{/*self=*/0, name_views, t}; + const auto bad_ac = coinbase::api::access_structure_t::leaf("p0"); + const std::vector quorum = {"p0", "p1"}; + buf_t sid, new_key_blob; + EXPECT_NE(coinbase::api::ecdsa_mp::refresh_ac(job, sid, blobs_[0], bad_ac, quorum, new_key_blob), SUCCESS); +} + +TEST_F(ApiEcdsaMpAcNegWithBlobs, NegAttachWrongScalarSize) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + buf_t wrong_size(16); + std::memset(wrong_size.data(), 0x01, static_cast(wrong_size.size())); + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(pub_blob, wrong_size, Qi, out), SUCCESS); +} + +TEST_F(ApiEcdsaMpAcNegWithBlobs, NegAttachGarbageScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + buf_t garbage_scalar(x.size()); + std::memset(garbage_scalar.data(), 0xDE, static_cast(garbage_scalar.size())); + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(pub_blob, garbage_scalar, Qi, out), SUCCESS); +} + +TEST_F(ApiEcdsaMpAcNegWithBlobs, NegAttachZeroScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::ecdsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + buf_t zero_scalar(x.size()); + std::memset(zero_scalar.data(), 0x00, static_cast(zero_scalar.size())); + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(pub_blob, zero_scalar, Qi, out), SUCCESS); +} + +TEST_F(ApiEcdsaMpAcNegWithBlobs, NegAttachEmptyPublicShare) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(pub_blob, x, mem_t(), out), SUCCESS); +} + +TEST_F(ApiEcdsaMpAcNegWithBlobs, NegAttachAllZeroPublicShare) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::ecdsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + uint8_t zero_point[33] = {}; + buf_t out; + EXPECT_NE(coinbase::api::ecdsa_mp::attach_private_scalar(pub_blob, x, mem_t(zero_point, 33), out), SUCCESS); +} diff --git a/tests/unit/api/test_eddsa2pc.cpp b/tests/unit/api/test_eddsa2pc.cpp new file mode 100644 index 00000000..20b09aac --- /dev/null +++ b/tests/unit/api/test_eddsa2pc.cpp @@ -0,0 +1,207 @@ +#include +#include +#include + +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; + +using coinbase::api::curve_id; +using coinbase::api::eddsa_2p::party_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::api_harness::failing_transport_t; +using coinbase::testutils::api_harness::local_api_transport_t; +using coinbase::testutils::api_harness::run_2pc; + +static void exercise_ed25519() { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + buf_t key_blob_1; + buf_t key_blob_2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + run_2pc( + c1, c2, [&] { return coinbase::api::eddsa_2p::dkg(job1, curve_id::ed25519, key_blob_1); }, + [&] { return coinbase::api::eddsa_2p::dkg(job2, curve_id::ed25519, key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + buf_t pub1; + buf_t pub2; + ASSERT_EQ(coinbase::api::eddsa_2p::get_public_key_compressed(key_blob_1, pub1), SUCCESS); + ASSERT_EQ(coinbase::api::eddsa_2p::get_public_key_compressed(key_blob_2, pub2), SUCCESS); + EXPECT_EQ(pub1.size(), 32); + EXPECT_EQ(pub1, pub2); + + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_ed25519, pub1), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + + // Deterministic 32-byte message for testing. + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + + buf_t sig1; + buf_t sig2; + run_2pc( + c1, c2, [&] { return coinbase::api::eddsa_2p::sign(job1, key_blob_1, msg, sig1); }, + [&] { return coinbase::api::eddsa_2p::sign(job2, key_blob_2, msg, sig2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + EXPECT_EQ(sig1.size(), 64); + EXPECT_EQ(sig2.size(), 0); + ASSERT_EQ(verify_key.verify(msg, sig1), SUCCESS); + + // Refresh and sign again. + buf_t new_key_blob_1; + buf_t new_key_blob_2; + run_2pc( + c1, c2, [&] { return coinbase::api::eddsa_2p::refresh(job1, key_blob_1, new_key_blob_1); }, + [&] { return coinbase::api::eddsa_2p::refresh(job2, key_blob_2, new_key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + buf_t sig3; + buf_t sig4; + run_2pc( + c1, c2, [&] { return coinbase::api::eddsa_2p::sign(job1, new_key_blob_1, msg, sig3); }, + [&] { return coinbase::api::eddsa_2p::sign(job2, new_key_blob_2, msg, sig4); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + EXPECT_EQ(sig3.size(), 64); + EXPECT_EQ(sig4.size(), 0); + ASSERT_EQ(verify_key.verify(msg, sig3), SUCCESS); + + buf_t pub3; + buf_t pub4; + ASSERT_EQ(coinbase::api::eddsa_2p::get_public_key_compressed(new_key_blob_1, pub3), SUCCESS); + ASSERT_EQ(coinbase::api::eddsa_2p::get_public_key_compressed(new_key_blob_2, pub4), SUCCESS); + EXPECT_EQ(pub3, pub4); + EXPECT_EQ(pub3, pub1); + + // Role is fixed to the share: signing with the "wrong" job.self should fail. + buf_t bad_sig1; + buf_t bad_sig2; + run_2pc( + c1, c2, [&] { return coinbase::api::eddsa_2p::sign(job1, key_blob_2, msg, bad_sig1); }, + [&] { return coinbase::api::eddsa_2p::sign(job2, key_blob_2, msg, bad_sig2); }, rv1, rv2); + EXPECT_EQ(rv1, E_BADARG); +} + +} // namespace + +TEST(ApiEdDSA2pc, DkgSignRefreshSign) { exercise_ed25519(); } + +TEST(ApiEdDSA2pc, UnsupportedCurveRejected) { + failing_transport_t t; + buf_t key_blob; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + EXPECT_EQ(coinbase::api::eddsa_2p::dkg(job, curve_id::secp256k1, key_blob), E_BADARG); +} + +TEST(ApiEdDSA2pc, KeyBlobPrivScalar_NoPubSign) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + buf_t key_blob_1; + buf_t key_blob_2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + run_2pc( + c1, c2, [&] { return coinbase::api::eddsa_2p::dkg(job1, curve_id::ed25519, key_blob_1); }, + [&] { return coinbase::api::eddsa_2p::dkg(job2, curve_id::ed25519, key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + // Refresh (exercise detach/attach on refreshed blobs too). + buf_t refreshed_1; + buf_t refreshed_2; + run_2pc( + c1, c2, [&] { return coinbase::api::eddsa_2p::refresh(job1, key_blob_1, refreshed_1); }, + [&] { return coinbase::api::eddsa_2p::refresh(job2, key_blob_2, refreshed_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + buf_t public_1; + buf_t public_2; + buf_t x1_fixed; + buf_t x2_fixed; + ASSERT_EQ(coinbase::api::eddsa_2p::detach_private_scalar(refreshed_1, public_1, x1_fixed), SUCCESS); + ASSERT_EQ(coinbase::api::eddsa_2p::detach_private_scalar(refreshed_2, public_2, x2_fixed), SUCCESS); + EXPECT_EQ(x1_fixed.size(), 32); + EXPECT_EQ(x2_fixed.size(), 32); + + buf_t Qi_full_1; + ASSERT_EQ(coinbase::api::eddsa_2p::get_public_share_compressed(refreshed_1, Qi_full_1), SUCCESS); + + buf_t Qi_full_2; + ASSERT_EQ(coinbase::api::eddsa_2p::get_public_share_compressed(refreshed_2, Qi_full_2), SUCCESS); + + // Public blob should not be usable for signing. + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + { + failing_transport_t ft; + const coinbase::api::job_2p_t bad_job{party_t::p1, "p1", "p2", ft}; + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_2p::sign(bad_job, public_1, msg, sig), SUCCESS); + } + + buf_t merged_1; + buf_t merged_2; + ASSERT_EQ(coinbase::api::eddsa_2p::attach_private_scalar(public_1, x1_fixed, Qi_full_1, merged_1), SUCCESS); + ASSERT_EQ(coinbase::api::eddsa_2p::attach_private_scalar(public_2, x2_fixed, Qi_full_2, merged_2), SUCCESS); + + // Sign again with merged blobs. + buf_t pub; + ASSERT_EQ(coinbase::api::eddsa_2p::get_public_key_compressed(merged_1, pub), SUCCESS); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_ed25519, pub), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + + buf_t sig1; + buf_t sig2; + run_2pc( + c1, c2, [&] { return coinbase::api::eddsa_2p::sign(job1, merged_1, msg, sig1); }, + [&] { return coinbase::api::eddsa_2p::sign(job2, merged_2, msg, sig2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + ASSERT_EQ(sig1.size(), 64); + EXPECT_EQ(sig2.size(), 0); + ASSERT_EQ(verify_key.verify(msg, sig1), SUCCESS); + + // Negative: wrong scalar should fail to attach. + buf_t bad_x = x1_fixed; + bad_x[0] ^= 0x01; + buf_t bad_merged; + EXPECT_NE(coinbase::api::eddsa_2p::attach_private_scalar(public_1, bad_x, Qi_full_1, bad_merged), SUCCESS); +} diff --git a/tests/unit/api/test_eddsa_mp.cpp b/tests/unit/api/test_eddsa_mp.cpp new file mode 100644 index 00000000..27d5b2b8 --- /dev/null +++ b/tests/unit/api/test_eddsa_mp.cpp @@ -0,0 +1,619 @@ +#include +#include +#include +#include + +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; + +using coinbase::api::curve_id; +using coinbase::api::job_mp_t; +using coinbase::api::party_idx_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::api_harness::failing_transport_t; +using coinbase::testutils::api_harness::local_api_transport_t; +using coinbase::testutils::api_harness::run_mp; + +static void exercise_4p_role_change() { + constexpr int n = 4; + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector> transports; + transports.reserve(n); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + + std::vector names = {"p0", "p1", "p2", "p3"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + std::vector keys(n); + std::vector new_keys(n); + std::vector sids(n); + std::vector sigs(n); + std::vector new_sigs(n); + std::vector rvs; + + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::eddsa_mp::dkg_additive(job, curve_id::ed25519, keys[static_cast(i)], + sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) EXPECT_EQ(sids[0], sids[static_cast(i)]); + + buf_t pub0; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_key_compressed(keys[0], pub0), SUCCESS); + EXPECT_EQ(pub0.size(), 32); + for (int i = 1; i < n; i++) { + buf_t pub_i; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_key_compressed(keys[static_cast(i)], pub_i), SUCCESS); + EXPECT_EQ(pub_i, pub0); + } + + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_ed25519, pub0), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + + // Change the party ordering ("role" indices) between protocols. + const std::vector name_views2 = {names[0], names[2], names[1], names[3]}; + // Map new role index -> old role index (DKG) for the same party name. + const int perm[n] = {0, 2, 1, 3}; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views2, *transports[static_cast(i)]}; + return coinbase::api::eddsa_mp::sign_additive(job, keys[static_cast(perm[i])], msg, /*sig_receiver=*/2, + sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + EXPECT_EQ(sigs[2].size(), 64); + for (int i = 0; i < n; i++) { + if (i == 2) continue; + EXPECT_EQ(sigs[static_cast(i)].size(), 0); + } + ASSERT_EQ(verify_key.verify(msg, sigs[2]), SUCCESS); + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views2, *transports[static_cast(i)]}; + return coinbase::api::eddsa_mp::refresh_additive(job, sids[static_cast(perm[i])], + keys[static_cast(perm[i])], + new_keys[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) EXPECT_EQ(sids[0], sids[static_cast(i)]); + + for (int i = 0; i < n; i++) { + buf_t pub_i; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_key_compressed(new_keys[static_cast(i)], pub_i), SUCCESS); + EXPECT_EQ(pub_i, pub0); + } + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views2, *transports[static_cast(i)]}; + return coinbase::api::eddsa_mp::sign_additive(job, new_keys[static_cast(i)], msg, /*sig_receiver=*/2, + new_sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + EXPECT_EQ(new_sigs[2].size(), 64); + for (int i = 0; i < n; i++) { + if (i == 2) continue; + EXPECT_EQ(new_sigs[static_cast(i)].size(), 0); + } + ASSERT_EQ(verify_key.verify(msg, new_sigs[2]), SUCCESS); +} + +} // namespace + +TEST(ApiEdDSAMp, DkgSignRefreshSignRoleChange4p) { exercise_4p_role_change(); } + +TEST(ApiEdDSAMp, RejectsInvalidSigReceiver) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t sig; + EXPECT_EQ(coinbase::api::eddsa_mp::sign_additive(job, coinbase::mem_t(), coinbase::mem_t(), /*sig_receiver=*/5, sig), + E_BADARG); +} + +TEST(ApiEdDSAMp, UnsupportedCurveRejected) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key; + buf_t sid; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_additive(job, curve_id::secp256k1, key, sid), E_BADARG); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +TEST(ApiEdDSAMp, NegDkgP256Curve) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key, sid; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_additive(job, curve_id::p256, key, sid), E_BADARG); +} + +TEST(ApiEdDSAMp, NegDkgCurveZero) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key, sid; + EXPECT_NE(coinbase::api::eddsa_mp::dkg_additive(job, static_cast(0), key, sid), SUCCESS); +} + +TEST(ApiEdDSAMp, NegDkgCurveFour) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key, sid; + EXPECT_NE(coinbase::api::eddsa_mp::dkg_additive(job, static_cast(4), key, sid), SUCCESS); +} + +TEST(ApiEdDSAMp, NegDkgCurve255) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key, sid; + EXPECT_NE(coinbase::api::eddsa_mp::dkg_additive(job, static_cast(255), key, sid), SUCCESS); +} + +TEST(ApiEdDSAMp, NegDkgSingleParty) { + failing_transport_t t; + std::vector names = {"p0"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key, sid; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_additive(job, curve_id::ed25519, key, sid), E_BADARG); +} + +TEST(ApiEdDSAMp, NegDkgEmptyPartyNames) { + failing_transport_t t; + std::vector names; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key, sid; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_additive(job, curve_id::ed25519, key, sid), E_BADARG); +} + +TEST(ApiEdDSAMp, NegDkgDuplicatePartyNames) { + failing_transport_t t; + std::vector names = {"dup", "dup"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key, sid; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_additive(job, curve_id::ed25519, key, sid), E_BADARG); +} + +TEST(ApiEdDSAMp, NegDkgSelfOutOfRange) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/3, names, t}; + + buf_t key, sid; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_additive(job, curve_id::ed25519, key, sid), E_BADARG); +} + +TEST(ApiEdDSAMp, NegGetPubKeyEmpty) { + buf_t pub; + EXPECT_NE(coinbase::api::eddsa_mp::get_public_key_compressed(coinbase::mem_t(), pub), SUCCESS); +} + +TEST(ApiEdDSAMp, NegGetPubKeyGarbage) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t pub; + EXPECT_NE(coinbase::api::eddsa_mp::get_public_key_compressed(coinbase::mem_t(garbage, sizeof(garbage)), pub), + SUCCESS); +} + +TEST(ApiEdDSAMp, NegGetPubKeyAllZero) { + uint8_t zeros[64] = {}; + buf_t pub; + EXPECT_NE(coinbase::api::eddsa_mp::get_public_key_compressed(coinbase::mem_t(zeros, sizeof(zeros)), pub), SUCCESS); +} + +TEST(ApiEdDSAMp, NegGetPubKeyOneByte) { + uint8_t one = 0x00; + buf_t pub; + EXPECT_NE(coinbase::api::eddsa_mp::get_public_key_compressed(coinbase::mem_t(&one, 1), pub), SUCCESS); +} + +TEST(ApiEdDSAMp, NegGetPubKeyOversized) { + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t pub; + EXPECT_NE(coinbase::api::eddsa_mp::get_public_key_compressed(big, pub), SUCCESS); +} + +TEST(ApiEdDSAMp, NegGetPubShareEmpty) { + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::get_public_share_compressed(coinbase::mem_t(), out), SUCCESS); +} + +TEST(ApiEdDSAMp, NegGetPubShareGarbage) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::get_public_share_compressed(coinbase::mem_t(garbage, sizeof(garbage)), out), + SUCCESS); +} + +TEST(ApiEdDSAMp, NegGetPubShareAllZero) { + uint8_t zeros[64] = {}; + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::get_public_share_compressed(coinbase::mem_t(zeros, sizeof(zeros)), out), SUCCESS); +} + +TEST(ApiEdDSAMp, NegGetPubShareOneByte) { + uint8_t one = 0x00; + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::get_public_share_compressed(coinbase::mem_t(&one, 1), out), SUCCESS); +} + +TEST(ApiEdDSAMp, NegGetPubShareOversized) { + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::get_public_share_compressed(big, out), SUCCESS); +} + +TEST(ApiEdDSAMp, NegDetachEmpty) { + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::eddsa_mp::detach_private_scalar(coinbase::mem_t(), pub_blob, scalar), SUCCESS); +} + +TEST(ApiEdDSAMp, NegDetachGarbage) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::eddsa_mp::detach_private_scalar(coinbase::mem_t(garbage, sizeof(garbage)), pub_blob, scalar), + SUCCESS); +} + +TEST(ApiEdDSAMp, NegDetachAllZero) { + uint8_t zeros[64] = {}; + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::eddsa_mp::detach_private_scalar(coinbase::mem_t(zeros, sizeof(zeros)), pub_blob, scalar), + SUCCESS); +} + +TEST(ApiEdDSAMp, NegDetachOneByte) { + uint8_t one = 0x00; + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::eddsa_mp::detach_private_scalar(coinbase::mem_t(&one, 1), pub_blob, scalar), SUCCESS); +} + +TEST(ApiEdDSAMp, NegDetachOversized) { + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::eddsa_mp::detach_private_scalar(big, pub_blob, scalar), SUCCESS); +} + +TEST(ApiEdDSAMp, NegAttachEmptyPublicKeyBlob) { + uint8_t scalar[32] = {0x01}; + uint8_t point[32] = {}; + point[0] = 0x02; + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(coinbase::mem_t(), coinbase::mem_t(scalar, 32), + coinbase::mem_t(point, 32), out), + SUCCESS); +} + +TEST(ApiEdDSAMp, NegAttachGarbagePublicKeyBlob) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + uint8_t scalar[32] = {0x01}; + uint8_t point[32] = {}; + point[0] = 0x02; + buf_t out; + EXPECT_NE( + coinbase::api::eddsa_mp::attach_private_scalar(coinbase::mem_t(garbage, sizeof(garbage)), + coinbase::mem_t(scalar, 32), coinbase::mem_t(point, 32), out), + SUCCESS); +} + +TEST(ApiEdDSAMp, NegAttachOversizedPublicKeyBlob) { + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + uint8_t scalar[32] = {0x01}; + uint8_t point[32] = {}; + point[0] = 0x02; + buf_t out; + EXPECT_NE( + coinbase::api::eddsa_mp::attach_private_scalar(big, coinbase::mem_t(scalar, 32), coinbase::mem_t(point, 32), out), + SUCCESS); +} + +TEST(ApiEdDSAMp, NegRefreshEmptyKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t sid, new_blob; + EXPECT_NE(coinbase::api::eddsa_mp::refresh_additive(job, sid, coinbase::mem_t(), new_blob), SUCCESS); +} + +TEST(ApiEdDSAMp, NegRefreshGarbageKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t sid, new_blob; + EXPECT_NE(coinbase::api::eddsa_mp::refresh_additive(job, sid, coinbase::mem_t(garbage, sizeof(garbage)), new_blob), + SUCCESS); +} + +TEST(ApiEdDSAMp, NegRefreshAllZeroKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + uint8_t zeros[64] = {}; + buf_t sid, new_blob; + EXPECT_NE(coinbase::api::eddsa_mp::refresh_additive(job, sid, coinbase::mem_t(zeros, sizeof(zeros)), new_blob), + SUCCESS); +} + +TEST(ApiEdDSAMp, NegRefreshOneByteKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + uint8_t one = 0x00; + buf_t sid, new_blob; + EXPECT_NE(coinbase::api::eddsa_mp::refresh_additive(job, sid, coinbase::mem_t(&one, 1), new_blob), SUCCESS); +} + +TEST(ApiEdDSAMp, NegRefreshOversizedKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t sid, new_blob; + EXPECT_NE(coinbase::api::eddsa_mp::refresh_additive(job, sid, big, new_blob), SUCCESS); +} + +TEST(ApiEdDSAMp, NegSignEmptyKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t msg(32); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_additive(job, coinbase::mem_t(), msg, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiEdDSAMp, NegSignGarbageKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg(32); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_additive(job, coinbase::mem_t(garbage, sizeof(garbage)), msg, + /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEdDSAMp, NegSignAllZeroKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + uint8_t zeros[64] = {}; + buf_t msg(32); + buf_t sig; + EXPECT_NE( + coinbase::api::eddsa_mp::sign_additive(job, coinbase::mem_t(zeros, sizeof(zeros)), msg, /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEdDSAMp, NegSignOneByteKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + uint8_t one = 0x00; + buf_t msg(32); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_additive(job, coinbase::mem_t(&one, 1), msg, /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEdDSAMp, NegSignOversizedKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t msg(32); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_additive(job, big, msg, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiEdDSAMp, NegSignSigReceiverNegative) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t sig; + EXPECT_EQ(coinbase::api::eddsa_mp::sign_additive(job, coinbase::mem_t(), coinbase::mem_t(), /*sig_receiver=*/-1, sig), + E_BADARG); +} + +namespace { +using coinbase::mem_t; + +static void generate_eddsa_mp_key_blobs(int n, std::vector& blobs) { + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector> transports; + transports.reserve(n); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + + std::vector names; + std::vector name_views; + for (int i = 0; i < n; i++) names.push_back("p" + std::to_string(i)); + for (const auto& nm : names) name_views.emplace_back(nm); + + blobs.resize(n); + std::vector sids(n); + std::vector rvs; + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::eddsa_mp::dkg_additive(job, curve_id::ed25519, blobs[static_cast(i)], + sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); +} +} // namespace + +class ApiEdDSAMpNegWithBlobs : public ::testing::Test { + protected: + static void SetUpTestSuite() { generate_eddsa_mp_key_blobs(3, blobs_); } + static std::vector blobs_; +}; +std::vector ApiEdDSAMpNegWithBlobs::blobs_; + +TEST_F(ApiEdDSAMpNegWithBlobs, NegSignEmptyMsg) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_additive(job, blobs_[0], mem_t(), /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST_F(ApiEdDSAMpNegWithBlobs, NegAttachEmptyPrivateScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::eddsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(pub_blob, mem_t(), Qi, out), SUCCESS); +} + +TEST_F(ApiEdDSAMpNegWithBlobs, NegAttachGarbagePrivateScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::eddsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + uint8_t garbage[512]; + std::memset(garbage, 0xFF, sizeof(garbage)); + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(pub_blob, mem_t(garbage, sizeof(garbage)), Qi, out), + SUCCESS); +} + +TEST_F(ApiEdDSAMpNegWithBlobs, NegAttachWrongSizeScalar31) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::eddsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + uint8_t short_scalar[31]; + std::memset(short_scalar, 0x01, sizeof(short_scalar)); + buf_t out; + EXPECT_NE( + coinbase::api::eddsa_mp::attach_private_scalar(pub_blob, mem_t(short_scalar, sizeof(short_scalar)), Qi, out), + SUCCESS); +} + +TEST_F(ApiEdDSAMpNegWithBlobs, NegAttachWrongSizeScalar33) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::eddsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + uint8_t long_scalar[33]; + std::memset(long_scalar, 0x01, sizeof(long_scalar)); + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(pub_blob, mem_t(long_scalar, sizeof(long_scalar)), Qi, out), + SUCCESS); +} + +TEST_F(ApiEdDSAMpNegWithBlobs, NegAttachZeroScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::eddsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + uint8_t zero[32] = {}; + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(pub_blob, mem_t(zero, 32), Qi, out), SUCCESS); +} + +TEST_F(ApiEdDSAMpNegWithBlobs, NegAttachEmptyPublicShare) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::eddsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(pub_blob, x, mem_t(), out), SUCCESS); +} + +TEST_F(ApiEdDSAMpNegWithBlobs, NegAttachAllZeroPublicShare) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::eddsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + uint8_t zero_point[32] = {}; + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(pub_blob, x, mem_t(zero_point, 32), out), SUCCESS); +} + +TEST_F(ApiEdDSAMpNegWithBlobs, NegAttachGarbagePublicShare) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::eddsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + uint8_t bad_point[32]; + std::memset(bad_point, 0xAB, sizeof(bad_point)); + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(pub_blob, x, mem_t(bad_point, 32), out), SUCCESS); +} diff --git a/tests/unit/api/test_eddsa_mp_ac.cpp b/tests/unit/api/test_eddsa_mp_ac.cpp new file mode 100644 index 00000000..547db995 --- /dev/null +++ b/tests/unit/api/test_eddsa_mp_ac.cpp @@ -0,0 +1,869 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; + +using coinbase::api::curve_id; +using coinbase::api::job_mp_t; +using coinbase::api::party_idx_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::api_harness::failing_transport_t; +using coinbase::testutils::api_harness::local_api_transport_t; +using coinbase::testutils::api_harness::run_mp; + +} // namespace + +TEST(ApiEdDSAMpAc, DkgRefreshSign4p) { + constexpr int n = 4; + + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector> transports; + transports.reserve(n); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + + std::vector names = {"p0", "p1", "p2", "p3"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + // THRESHOLD[2](p0, p1, p2, p3) + const coinbase::api::access_structure_t ac = + coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + }); + + // Only p0 and p1 actively contribute to the DKG/refresh. + const std::vector quorum_party_names = {names[0], names[1]}; + + std::vector key_blobs(n); + std::vector sids(n); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sids[static_cast(i)], ac, + quorum_party_names, key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) EXPECT_EQ(sids[0], sids[static_cast(i)]); + + buf_t pub0; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_key_compressed(key_blobs[0], pub0), SUCCESS); + EXPECT_EQ(pub0.size(), 32); + for (int i = 1; i < n; i++) { + buf_t pub_i; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_key_compressed(key_blobs[static_cast(i)], pub_i), SUCCESS); + EXPECT_EQ(pub_i, pub0); + } + + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_ed25519, pub0), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + + std::vector> sign_peers = {peers[0], peers[1]}; + std::vector> sign_transports = {transports[0], transports[1]}; + + constexpr int quorum_n = 2; + std::vector sigs(quorum_n); + run_mp( + sign_peers, + [&](int i) { + job_mp_t job{static_cast(i), quorum_party_names, *sign_transports[static_cast(i)]}; + return coinbase::api::eddsa_mp::sign_ac(job, key_blobs[static_cast(i)], ac, msg, /*sig_receiver=*/0, + sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + ASSERT_EQ(sigs[0].size(), 64); + EXPECT_EQ(sigs[1].size(), 0); + ASSERT_EQ(verify_key.verify(msg, sigs[0]), SUCCESS); + + // Threshold refresh. + std::vector new_key_blobs(n); + std::vector refresh_sids(n); + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::eddsa_mp::refresh_ac(job, refresh_sids[static_cast(i)], + key_blobs[static_cast(i)], ac, quorum_party_names, + new_key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) EXPECT_EQ(refresh_sids[0], refresh_sids[static_cast(i)]); + + for (int i = 0; i < n; i++) { + buf_t pub_i; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_key_compressed(new_key_blobs[static_cast(i)], pub_i), + SUCCESS); + EXPECT_EQ(pub_i, pub0); + } + + std::vector sigs2(quorum_n); + run_mp( + sign_peers, + [&](int i) { + job_mp_t job{static_cast(i), quorum_party_names, *sign_transports[static_cast(i)]}; + return coinbase::api::eddsa_mp::sign_ac(job, new_key_blobs[static_cast(i)], ac, msg, + /*sig_receiver=*/0, sigs2[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + ASSERT_EQ(sigs2[0].size(), 64); + EXPECT_EQ(sigs2[1].size(), 0); + ASSERT_EQ(verify_key.verify(msg, sigs2[0]), SUCCESS); +} + +TEST(ApiEdDSAMpAc, KeyBlobPrivScalar_NoPubSign) { + constexpr int n = 4; + + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector> transports; + transports.reserve(n); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + + std::vector names = {"p0", "p1", "p2", "p3"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + // THRESHOLD[2](p0, p1, p2, p3) + const coinbase::api::access_structure_t ac = + coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + }); + const std::vector quorum_party_names = {names[0], names[1]}; + + std::vector key_blobs(n); + std::vector sids(n); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sids[static_cast(i)], ac, + quorum_party_names, key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + + buf_t pub0; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_key_compressed(key_blobs[0], pub0), SUCCESS); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_ed25519, pub0), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + + std::vector public_blobs(n); + std::vector x_fixed(n); + std::vector merged(n); + for (int i = 0; i < n; i++) { + ASSERT_EQ( + coinbase::api::eddsa_mp::detach_private_scalar(key_blobs[static_cast(i)], public_blobs[i], x_fixed[i]), + SUCCESS); + EXPECT_GT(public_blobs[i].size(), 0); + EXPECT_EQ(x_fixed[i].size(), 32); // ed25519 order size + + buf_t Qi_full; + buf_t Qi_public; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_share_compressed(key_blobs[static_cast(i)], Qi_full), + SUCCESS); + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_share_compressed(public_blobs[i], Qi_public), SUCCESS); + EXPECT_EQ(Qi_full, Qi_public); + + ASSERT_EQ(coinbase::api::eddsa_mp::attach_private_scalar(public_blobs[i], x_fixed[i], Qi_full, merged[i]), SUCCESS); + EXPECT_GT(merged[i].size(), 0); + } + + // Public blob should not be usable for signing. + // Avoid spinning up a full protocol run here: sign_ac rejects at key blob parsing + // before any transport calls, so a single local call is sufficient. + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + { + failing_transport_t t; + job_mp_t job{/*self=*/0, quorum_party_names, t}; + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_ac(job, public_blobs[0], ac, msg, /*sig_receiver=*/0, sig), SUCCESS); + } + + // Merged blobs should be usable for signing. + std::vector> sign_peers = {peers[0], peers[1]}; + std::vector> sign_transports = {transports[0], transports[1]}; + constexpr int quorum_n = 2; + std::vector sigs(quorum_n); + run_mp( + sign_peers, + [&](int i) { + job_mp_t job{static_cast(i), quorum_party_names, *sign_transports[static_cast(i)]}; + return coinbase::api::eddsa_mp::sign_ac(job, merged[static_cast(i)], ac, msg, /*sig_receiver=*/0, + sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + ASSERT_EQ(sigs[0].size(), 64); + EXPECT_EQ(sigs[1].size(), 0); + ASSERT_EQ(verify_key.verify(msg, sigs[0]), SUCCESS); + + // Negative: merging the wrong scalar should fail. + buf_t Qi0; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_share_compressed(key_blobs[0], Qi0), SUCCESS); + buf_t bad_x = x_fixed[0]; + bad_x[0] ^= 0x01; + buf_t bad_merged; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(public_blobs[0], bad_x, Qi0, bad_merged), SUCCESS); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +TEST(ApiEdDSAMpAc, NegDkgWeierstrassCurves) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector quorum = {names[0], names[1]}; + { + buf_t sid, key_blob; + EXPECT_NE(coinbase::api::eddsa_mp::dkg_ac(job, curve_id::p256, sid, ac, quorum, key_blob), SUCCESS); + } + { + buf_t sid, key_blob; + EXPECT_NE(coinbase::api::eddsa_mp::dkg_ac(job, curve_id::secp256k1, sid, ac, quorum, key_blob), SUCCESS); + } +} + +TEST(ApiEdDSAMpAc, NegDkgInvalidCurveValues) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector quorum = {names[0], names[1]}; + for (uint32_t val : {0u, 4u, 255u}) { + buf_t sid, key_blob; + EXPECT_NE(coinbase::api::eddsa_mp::dkg_ac(job, static_cast(val), sid, ac, quorum, key_blob), SUCCESS) + << "Expected failure for curve_id=" << val; + } +} + +TEST(ApiEdDSAMpAc, NegDkgEmptyQuorum) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector empty_quorum; + buf_t sid, key_blob; + EXPECT_NE(coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sid, ac, empty_quorum, key_blob), SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegDkgSinglePartyJob) { + failing_transport_t t; + std::vector names = {"p0"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::Threshold(1, { + coinbase::api::access_structure_t::leaf(names[0]), + }); + const std::vector quorum = {names[0]}; + buf_t sid, key_blob; + EXPECT_NE(coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sid, ac, quorum, key_blob), SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegDkgEmptyPartyName) { + failing_transport_t t; + std::vector names = {"p0", "", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + }); + const std::vector quorum = {names[0], names[1], names[2]}; + buf_t sid, key_blob; + EXPECT_NE(coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sid, ac, quorum, key_blob), SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegDkgRootLeaf) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::leaf(names[0]); + const std::vector quorum = {names[0], names[1]}; + buf_t sid, key_blob; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sid, ac, quorum, key_blob), E_BADARG); +} + +TEST(ApiEdDSAMpAc, NegDkgThresholdExceedsChildren) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::Threshold(3, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector quorum = {names[0], names[1]}; + buf_t sid, key_blob; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sid, ac, quorum, key_blob), E_BADARG); +} + +TEST(ApiEdDSAMpAc, NegDkgThresholdZero) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::Threshold(0, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector quorum = {names[0], names[1]}; + buf_t sid, key_blob; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sid, ac, quorum, key_blob), E_BADARG); +} + +TEST(ApiEdDSAMpAc, NegDkgNegativeThreshold) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = + coinbase::api::access_structure_t::Threshold(-1, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector quorum = {names[0], names[1]}; + buf_t sid, key_blob; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sid, ac, quorum, key_blob), E_BADARG); +} + +TEST(ApiEdDSAMpAc, NegDkgEmptyAnd) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({}); + const std::vector quorum = {names[0], names[1]}; + buf_t sid, key_blob; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sid, ac, quorum, key_blob), E_BADARG); +} + +TEST(ApiEdDSAMpAc, NegDkgDuplicateLeaves) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector quorum = {names[0], names[1]}; + buf_t sid, key_blob; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sid, ac, quorum, key_blob), E_BADARG); +} + +TEST(ApiEdDSAMpAc, NegDkgInternalNodeNoChildren) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + coinbase::api::access_structure_t ac; + ac.type = coinbase::api::access_structure_t::node_type::and_node; + const std::vector quorum = {names[0], names[1]}; + buf_t sid, key_blob; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sid, ac, quorum, key_blob), E_BADARG); +} + +TEST(ApiEdDSAMpAc, NegDkgUnknownPartyInQuorum) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + }); + const std::vector bad_quorum = {names[0], "unknown"}; + buf_t sid, key_blob; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sid, ac, bad_quorum, key_blob), E_BADARG); +} + +TEST(ApiEdDSAMpAc, NegDkgInsufficientQuorum) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2", "p3"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = + coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }), + coinbase::api::access_structure_t::Or({ + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + }), + }); + const std::vector bad_quorum = {names[0], names[2], names[3]}; + buf_t sid, key_blob; + EXPECT_EQ(coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sid, ac, bad_quorum, key_blob), E_BADARG); +} + +#include + +namespace { + +using coinbase::mem_t; + +static void generate_eddsa_mp_ac_key_blobs(int n, std::vector& blobs) { + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector> transports; + transports.reserve(n); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + + std::vector names; + std::vector name_views; + for (int i = 0; i < n; i++) names.push_back("p" + std::to_string(i)); + for (const auto& nm : names) name_views.emplace_back(nm); + + const auto ac = coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + }); + const std::vector quorum = {names[0], names[1]}; + + blobs.resize(n); + std::vector sids(n); + std::vector rvs; + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::eddsa_mp::dkg_ac(job, curve_id::ed25519, sids[static_cast(i)], ac, quorum, + blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); +} + +} // namespace + +class ApiEdDSAMpAcNegWithBlobs : public ::testing::Test { + protected: + static void SetUpTestSuite() { generate_eddsa_mp_ac_key_blobs(4, blobs_); } + static std::vector blobs_; +}; + +std::vector ApiEdDSAMpAcNegWithBlobs::blobs_; + +TEST(ApiEdDSAMpAc, NegSignAcEmptyKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_ac(job, mem_t(), ac, msg, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegSignAcGarbageKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_ac(job, mem_t(garbage, sizeof(garbage)), ac, msg, /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegSignAcAllZeroKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t zeros[64] = {}; + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_ac(job, mem_t(zeros, sizeof(zeros)), ac, msg, /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegSignAcOversizedKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_ac(job, big, ac, msg, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegSignAcEmptyMsg) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t sig; + EXPECT_NE( + coinbase::api::eddsa_mp::sign_ac(job, mem_t(garbage, sizeof(garbage)), ac, mem_t(), /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegSignAcNegativeSigReceiver) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_ac(job, mem_t(garbage, sizeof(garbage)), ac, msg, /*sig_receiver=*/-1, sig), + SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegSignAcSigReceiverTooLarge) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_ac(job, mem_t(garbage, sizeof(garbage)), ac, msg, + /*sig_receiver=*/static_cast(names.size()), sig), + SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegSignAcEmptyAndAccessStructure) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({}); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_ac(job, mem_t(garbage, sizeof(garbage)), ac, msg, /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegSignAcEmptyOrAccessStructure) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::Or({}); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_ac(job, mem_t(garbage, sizeof(garbage)), ac, msg, /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegRefreshAcEmptyKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + const std::vector quorum = {names[0], names[1]}; + buf_t sid, new_key_blob; + EXPECT_NE(coinbase::api::eddsa_mp::refresh_ac(job, sid, mem_t(), ac, quorum, new_key_blob), SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegRefreshAcGarbageKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + const std::vector quorum = {names[0], names[1]}; + buf_t sid, new_key_blob; + EXPECT_NE(coinbase::api::eddsa_mp::refresh_ac(job, sid, mem_t(garbage, sizeof(garbage)), ac, quorum, new_key_blob), + SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegRefreshAcAllZeroKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t zeros[64] = {}; + const std::vector quorum = {names[0], names[1]}; + buf_t sid, new_key_blob; + EXPECT_NE(coinbase::api::eddsa_mp::refresh_ac(job, sid, mem_t(zeros, sizeof(zeros)), ac, quorum, new_key_blob), + SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegRefreshAcOversizedKeyBlob) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + buf_t big(1024 * 1024 + 1); + std::memset(big.data(), 0xAB, static_cast(big.size())); + const std::vector quorum = {names[0], names[1]}; + buf_t sid, new_key_blob; + EXPECT_NE(coinbase::api::eddsa_mp::refresh_ac(job, sid, big, ac, quorum, new_key_blob), SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegRefreshAcEmptyQuorum) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + const auto ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + }); + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + const std::vector empty_quorum; + buf_t sid, new_key_blob; + EXPECT_NE( + coinbase::api::eddsa_mp::refresh_ac(job, sid, mem_t(garbage, sizeof(garbage)), ac, empty_quorum, new_key_blob), + SUCCESS); +} + +TEST(ApiEdDSAMpAc, NegRefreshAcAdditiveKeyBlob) { + constexpr int quorum_n = 2; + std::vector> dkg_peers; + dkg_peers.reserve(quorum_n); + for (int i = 0; i < quorum_n; i++) dkg_peers.push_back(std::make_shared(i)); + for (const auto& p : dkg_peers) p->init_with_peers(dkg_peers); + + std::vector> dkg_transports; + dkg_transports.reserve(quorum_n); + for (const auto& p : dkg_peers) dkg_transports.push_back(std::make_shared(p)); + + const std::vector quorum = {"p0", "p1"}; + + std::vector additive_key_blobs(quorum_n); + std::vector additive_sids(quorum_n); + std::vector rvs; + run_mp( + dkg_peers, + [&](int i) { + job_mp_t job{static_cast(i), quorum, *dkg_transports[static_cast(i)]}; + return coinbase::api::eddsa_mp::dkg_additive(job, curve_id::ed25519, additive_key_blobs[static_cast(i)], + additive_sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + + failing_transport_t t; + job_mp_t job{/*self=*/0, quorum, t}; + + const auto quorum_ac = coinbase::api::access_structure_t::And({ + coinbase::api::access_structure_t::leaf(quorum[0]), + coinbase::api::access_structure_t::leaf(quorum[1]), + }); + + buf_t sid, new_key_blob; + EXPECT_EQ(coinbase::api::eddsa_mp::refresh_ac(job, sid, additive_key_blobs[0], quorum_ac, quorum, new_key_blob), + E_FORMAT); +} + +TEST_F(ApiEdDSAMpAcNegWithBlobs, NegSignAcEmptyMsgValidBlob) { + std::vector names = {"p0", "p1"}; + std::vector name_views(names.begin(), names.end()); + failing_transport_t t; + job_mp_t job{/*self=*/0, name_views, t}; + const auto ac = coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::leaf("p0"), + coinbase::api::access_structure_t::leaf("p1"), + coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3"), + }); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_ac(job, blobs_[0], ac, mem_t(), /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST_F(ApiEdDSAMpAcNegWithBlobs, NegAttachWrongScalarSize) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::eddsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + { + buf_t short_scalar(31); + std::memset(short_scalar.data(), 0x01, static_cast(short_scalar.size())); + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(pub_blob, short_scalar, Qi, out), SUCCESS); + } + { + buf_t long_scalar(33); + std::memset(long_scalar.data(), 0x01, static_cast(long_scalar.size())); + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(pub_blob, long_scalar, Qi, out), SUCCESS); + } +} + +TEST_F(ApiEdDSAMpAcNegWithBlobs, NegAttachZeroScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::eddsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + buf_t zero_scalar(x.size()); + std::memset(zero_scalar.data(), 0x00, static_cast(zero_scalar.size())); + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(pub_blob, zero_scalar, Qi, out), SUCCESS); +} + +TEST_F(ApiEdDSAMpAcNegWithBlobs, NegAttachGarbageScalar) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::eddsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t Qi; + ASSERT_EQ(coinbase::api::eddsa_mp::get_public_share_compressed(blobs_[0], Qi), SUCCESS); + + buf_t garbage_scalar(x.size()); + std::memset(garbage_scalar.data(), 0xDE, static_cast(garbage_scalar.size())); + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(pub_blob, garbage_scalar, Qi, out), SUCCESS); +} + +TEST_F(ApiEdDSAMpAcNegWithBlobs, NegAttachEmptyPublicShare) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::eddsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(pub_blob, x, mem_t(), out), SUCCESS); +} + +TEST_F(ApiEdDSAMpAcNegWithBlobs, NegAttachAllZeroPublicShare) { + buf_t pub_blob, x; + ASSERT_EQ(coinbase::api::eddsa_mp::detach_private_scalar(blobs_[0], pub_blob, x), SUCCESS); + + uint8_t zero_point[32] = {}; + buf_t out; + EXPECT_NE(coinbase::api::eddsa_mp::attach_private_scalar(pub_blob, x, mem_t(zero_point, 32), out), SUCCESS); +} + +TEST_F(ApiEdDSAMpAcNegWithBlobs, NegRefreshAcInvalidAccessStructure) { + std::vector names = {"p0", "p1", "p2", "p3"}; + std::vector name_views(names.begin(), names.end()); + failing_transport_t t; + job_mp_t job{/*self=*/0, name_views, t}; + const auto bad_ac = coinbase::api::access_structure_t::leaf("p0"); + const std::vector quorum = {"p0", "p1"}; + buf_t sid, new_key_blob; + EXPECT_NE(coinbase::api::eddsa_mp::refresh_ac(job, sid, blobs_[0], bad_ac, quorum, new_key_blob), SUCCESS); +} + +TEST_F(ApiEdDSAMpAcNegWithBlobs, NegSignAcWrongAccessStructure) { + std::vector names = {"p0", "p1"}; + std::vector name_views(names.begin(), names.end()); + failing_transport_t t; + job_mp_t job{/*self=*/0, name_views, t}; + const auto wrong_ac = + coinbase::api::access_structure_t::Threshold(3, { + coinbase::api::access_structure_t::leaf("p0"), + coinbase::api::access_structure_t::leaf("p1"), + coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3"), + }); + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_ac(job, blobs_[0], wrong_ac, msg, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST_F(ApiEdDSAMpAcNegWithBlobs, NegSignAcInsufficientQuorum) { + std::vector name_views = {"p0"}; + failing_transport_t t; + job_mp_t job{/*self=*/0, name_views, t}; + const auto ac = coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::leaf("p0"), + coinbase::api::access_structure_t::leaf("p1"), + coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3"), + }); + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + buf_t sig; + EXPECT_NE(coinbase::api::eddsa_mp::sign_ac(job, blobs_[0], ac, msg, /*sig_receiver=*/0, sig), SUCCESS); +} diff --git a/tests/unit/api/test_hd_keyset_ecdsa_2p.cpp b/tests/unit/api/test_hd_keyset_ecdsa_2p.cpp new file mode 100644 index 00000000..2c67ac05 --- /dev/null +++ b/tests/unit/api/test_hd_keyset_ecdsa_2p.cpp @@ -0,0 +1,582 @@ +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; + +using coinbase::api::curve_id; + +using coinbase::api::ecdsa_2p::party_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::api_harness::failing_transport_t; +using coinbase::testutils::api_harness::local_api_transport_t; +using coinbase::testutils::api_harness::run_2pc; + +static void exercise_curve(curve_id curve, const coinbase::crypto::ecurve_t& verify_curve) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + buf_t keyset1; + buf_t keyset2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + run_2pc( + c1, c2, [&] { return coinbase::api::hd_keyset_ecdsa_2p::dkg(job1, curve, keyset1); }, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::dkg(job2, curve, keyset2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + buf_t root_pub1; + buf_t root_pub2; + ASSERT_EQ(coinbase::api::hd_keyset_ecdsa_2p::extract_root_public_key_compressed(keyset1, root_pub1), SUCCESS); + ASSERT_EQ(coinbase::api::hd_keyset_ecdsa_2p::extract_root_public_key_compressed(keyset2, root_pub2), SUCCESS); + ASSERT_EQ(root_pub1, root_pub2); + + // Deterministic 32-byte "hash" for testing. + buf_t msg_hash(32); + for (int i = 0; i < msg_hash.size(); i++) msg_hash[i] = static_cast(0xA0 + i); + + // Derive 2 keys. + coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; // demo-only + + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t{{0, 0}}); + non_hard.push_back(coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t{{0, 1}}); + + std::vector derived1; + std::vector derived2; + buf_t sid1; + buf_t sid2; + + run_2pc( + c1, c2, + [&] { + return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job1, keyset1, hard, non_hard, sid1, derived1); + }, + [&] { + return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job2, keyset2, hard, non_hard, sid2, derived2); + }, + rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + ASSERT_EQ(sid1, sid2); + ASSERT_EQ(derived1.size(), non_hard.size()); + ASSERT_EQ(derived2.size(), non_hard.size()); + + // Derived pubkeys must match across parties. + for (size_t i = 0; i < non_hard.size(); i++) { + buf_t pub_a; + buf_t pub_b; + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_key_compressed(derived1[i], pub_a), SUCCESS); + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_key_compressed(derived2[i], pub_b), SUCCESS); + EXPECT_EQ(pub_a, pub_b); + } + + // Sign using the first derived key. + buf_t sig1; + buf_t sig2; + buf_t sid3; + buf_t sid4; + run_2pc( + c1, c2, [&] { return coinbase::api::ecdsa_2p::sign(job1, derived1[0], msg_hash, sid3, sig1); }, + [&] { return coinbase::api::ecdsa_2p::sign(job2, derived2[0], msg_hash, sid4, sig2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + EXPECT_EQ(sid3, sid4); + EXPECT_GT(sig1.size(), 0); + EXPECT_EQ(sig2.size(), 0); + + buf_t derived_pub; + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_key_compressed(derived1[0], derived_pub), SUCCESS); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(verify_curve, derived_pub), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + ASSERT_EQ(verify_key.verify(msg_hash, sig1), SUCCESS); + + // Refresh keyset shares. + buf_t keyset1_ref; + buf_t keyset2_ref; + run_2pc( + c1, c2, [&] { return coinbase::api::hd_keyset_ecdsa_2p::refresh(job1, keyset1, keyset1_ref); }, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::refresh(job2, keyset2, keyset2_ref); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + buf_t root_pub1_ref; + buf_t root_pub2_ref; + ASSERT_EQ(coinbase::api::hd_keyset_ecdsa_2p::extract_root_public_key_compressed(keyset1_ref, root_pub1_ref), SUCCESS); + ASSERT_EQ(coinbase::api::hd_keyset_ecdsa_2p::extract_root_public_key_compressed(keyset2_ref, root_pub2_ref), SUCCESS); + EXPECT_EQ(root_pub1_ref, root_pub1); + EXPECT_EQ(root_pub2_ref, root_pub2); + + // Derive again; derived pubkeys should remain stable. + std::vector derived1_ref; + std::vector derived2_ref; + buf_t sid5; + buf_t sid6; + run_2pc( + c1, c2, + [&] { + return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job1, keyset1_ref, hard, non_hard, sid5, + derived1_ref); + }, + [&] { + return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job2, keyset2_ref, hard, non_hard, sid6, + derived2_ref); + }, + rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + EXPECT_EQ(derived1_ref.size(), derived1.size()); + EXPECT_EQ(derived2_ref.size(), derived2.size()); + + for (size_t i = 0; i < derived1.size(); i++) { + buf_t pub_old; + buf_t pub_new; + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_key_compressed(derived1[i], pub_old), SUCCESS); + ASSERT_EQ(coinbase::api::ecdsa_2p::get_public_key_compressed(derived1_ref[i], pub_new), SUCCESS); + EXPECT_EQ(pub_old, pub_new); + } +} + +} // namespace + +TEST(ApiHdKeysetEcdsa2p, DkgDeriveSignRefreshDerive) { + exercise_curve(curve_id::secp256k1, coinbase::crypto::curve_secp256k1); + exercise_curve(curve_id::p256, coinbase::crypto::curve_p256); +} + +TEST(ApiHdKeysetEcdsa2p, UnsupportedCurveRejected) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + EXPECT_EQ(coinbase::api::hd_keyset_ecdsa_2p::dkg(job, curve_id::ed25519, keyset), E_BADARG); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +TEST(ApiHdKeysetEcdsa2p, DkgInvalidCurveZero) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + EXPECT_EQ(coinbase::api::hd_keyset_ecdsa_2p::dkg(job, curve_id(0), keyset), E_BADARG); +} + +TEST(ApiHdKeysetEcdsa2p, DkgInvalidCurveFour) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + EXPECT_EQ(coinbase::api::hd_keyset_ecdsa_2p::dkg(job, curve_id(4), keyset), E_BADARG); +} + +TEST(ApiHdKeysetEcdsa2p, DkgInvalidCurve255) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + EXPECT_EQ(coinbase::api::hd_keyset_ecdsa_2p::dkg(job, curve_id(255), keyset), E_BADARG); +} + +TEST(ApiHdKeysetEcdsa2p, DkgEmptyP1Name) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "", "p2", t}; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::dkg(job, curve_id::secp256k1, keyset), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, DkgEmptyP2Name) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "", t}; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::dkg(job, curve_id::secp256k1, keyset), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, DkgSameP1P2Name) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "same", "same", t}; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::dkg(job, curve_id::secp256k1, keyset), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, ExtractRootPubKeyEmptyBlob) { + buf_t out; + buf_t empty; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::extract_root_public_key_compressed(empty, out), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, ExtractRootPubKeyGarbageBlob) { + buf_t out; + buf_t garbage(4); + garbage[0] = 0xDE; + garbage[1] = 0xAD; + garbage[2] = 0xBE; + garbage[3] = 0xEF; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::extract_root_public_key_compressed(garbage, out), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, ExtractRootPubKeyAllZeroBlob) { + buf_t out; + buf_t zeros(64); + std::memset(zeros.data(), 0, zeros.size()); + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::extract_root_public_key_compressed(zeros, out), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, ExtractRootPubKeyOneByteBlob) { + buf_t out; + buf_t one(1); + one[0] = 0x42; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::extract_root_public_key_compressed(one, out), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, ExtractRootPubKeyOversizedBlob) { + buf_t out; + buf_t oversized(1048577); + std::memset(oversized.data(), 0xAA, oversized.size()); + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::extract_root_public_key_compressed(oversized, out), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, RefreshEmptyBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + buf_t empty; + buf_t out; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::refresh(job, empty, out), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, RefreshGarbageBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + buf_t garbage(4); + garbage[0] = 0xDE; + garbage[1] = 0xAD; + garbage[2] = 0xBE; + garbage[3] = 0xEF; + buf_t out; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::refresh(job, garbage, out), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, RefreshAllZeroBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + buf_t zeros(64); + std::memset(zeros.data(), 0, zeros.size()); + buf_t out; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::refresh(job, zeros, out), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, RefreshOneByteBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + buf_t one(1); + one[0] = 0x42; + buf_t out; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::refresh(job, one, out), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, RefreshOversizedBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + buf_t oversized(1048577); + std::memset(oversized.data(), 0xAA, oversized.size()); + buf_t out; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::refresh(job, oversized, out), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, DeriveEmptyBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + buf_t empty; + coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t{{0, 0}}); + buf_t sid; + std::vector out; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job, empty, hard, non_hard, sid, out), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, DeriveGarbageBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + buf_t garbage(4); + garbage[0] = 0xDE; + garbage[1] = 0xAD; + garbage[2] = 0xBE; + garbage[3] = 0xEF; + coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t{{0, 0}}); + buf_t sid; + std::vector out; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job, garbage, hard, non_hard, sid, out), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, DeriveAllZeroBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + buf_t zeros(64); + std::memset(zeros.data(), 0, zeros.size()); + coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t{{0, 0}}); + buf_t sid; + std::vector out; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job, zeros, hard, non_hard, sid, out), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, DeriveOneByteBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + buf_t one(1); + one[0] = 0x42; + coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t{{0, 0}}); + buf_t sid; + std::vector out; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job, one, hard, non_hard, sid, out), SUCCESS); +} + +TEST(ApiHdKeysetEcdsa2p, DeriveOversizedBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + buf_t oversized(1048577); + std::memset(oversized.data(), 0xAA, oversized.size()); + coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t{{0, 0}}); + buf_t sid; + std::vector out; + EXPECT_NE(coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job, oversized, hard, non_hard, sid, out), SUCCESS); +} + +namespace { + +using coinbase::mem_t; + +static void generate_ecdsa_hd_keyset_blobs(curve_id curve, buf_t& blob1, buf_t& blob2) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + error_t rv1 = UNINITIALIZED_ERROR, rv2 = UNINITIALIZED_ERROR; + run_2pc( + c1, c2, [&] { return coinbase::api::hd_keyset_ecdsa_2p::dkg(job1, curve, blob1); }, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::dkg(job2, curve, blob2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); +} + +} // namespace + +class ApiHdKeysetEcdsa2pNegWithBlobs : public ::testing::Test { + protected: + static void SetUpTestSuite() { generate_ecdsa_hd_keyset_blobs(curve_id::secp256k1, blob1_, blob2_); } + static buf_t blob1_; + static buf_t blob2_; +}; +buf_t ApiHdKeysetEcdsa2pNegWithBlobs::blob1_; +buf_t ApiHdKeysetEcdsa2pNegWithBlobs::blob2_; + +TEST_F(ApiHdKeysetEcdsa2pNegWithBlobs, RefreshRoleMismatchP1BlobWithP2Job) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + const coinbase::api::job_2p_t job1{party_t::p2, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p1, "p1", "p2", t2}; + + buf_t out1, out2; + error_t rv1 = UNINITIALIZED_ERROR, rv2 = UNINITIALIZED_ERROR; + run_2pc( + c1, c2, [&] { return coinbase::api::hd_keyset_ecdsa_2p::refresh(job1, blob1_, out1); }, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::refresh(job2, blob2_, out2); }, rv1, rv2); + EXPECT_NE(rv1, SUCCESS); + EXPECT_NE(rv2, SUCCESS); +} + +TEST_F(ApiHdKeysetEcdsa2pNegWithBlobs, RefreshRoleMismatchP2BlobWithP1Job) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + buf_t out1, out2; + error_t rv1 = UNINITIALIZED_ERROR, rv2 = UNINITIALIZED_ERROR; + run_2pc( + c1, c2, [&] { return coinbase::api::hd_keyset_ecdsa_2p::refresh(job1, blob2_, out1); }, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::refresh(job2, blob1_, out2); }, rv1, rv2); + EXPECT_NE(rv1, SUCCESS); + EXPECT_NE(rv2, SUCCESS); +} + +TEST_F(ApiHdKeysetEcdsa2pNegWithBlobs, DeriveRoleMismatchP1BlobWithP2Job) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + const coinbase::api::job_2p_t job1{party_t::p2, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p1, "p1", "p2", t2}; + + coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t{{0, 0}}); + + buf_t sid1, sid2; + std::vector out1, out2; + error_t rv1 = UNINITIALIZED_ERROR, rv2 = UNINITIALIZED_ERROR; + run_2pc( + c1, c2, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job1, blob1_, hard, non_hard, sid1, out1); }, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job2, blob2_, hard, non_hard, sid2, out2); }, + rv1, rv2); + EXPECT_NE(rv1, SUCCESS); + EXPECT_NE(rv2, SUCCESS); +} + +TEST_F(ApiHdKeysetEcdsa2pNegWithBlobs, DeriveEmptyNonHardenedPaths) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector non_hard; + + buf_t sid1, sid2; + std::vector out1, out2; + error_t rv1 = UNINITIALIZED_ERROR, rv2 = UNINITIALIZED_ERROR; + run_2pc( + c1, c2, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job1, blob1_, hard, non_hard, sid1, out1); }, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job2, blob2_, hard, non_hard, sid2, out2); }, + rv1, rv2); + bool both_ok = (rv1 == SUCCESS) && (rv2 == SUCCESS); + bool both_fail = (rv1 != SUCCESS) && (rv2 != SUCCESS); + EXPECT_TRUE(both_ok || both_fail); + if (both_ok) { + EXPECT_EQ(out1.size(), 0u); + EXPECT_EQ(out2.size(), 0u); + } +} + +TEST_F(ApiHdKeysetEcdsa2pNegWithBlobs, DeriveDuplicateNonHardenedPaths) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t dup_path; + dup_path.indices = {0, 0}; + std::vector non_hard = {dup_path, dup_path}; + + buf_t sid1, sid2; + std::vector out1, out2; + error_t rv1 = UNINITIALIZED_ERROR, rv2 = UNINITIALIZED_ERROR; + run_2pc( + c1, c2, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job1, blob1_, hard, non_hard, sid1, out1); }, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job2, blob2_, hard, non_hard, sid2, out2); }, + rv1, rv2); + EXPECT_EQ(rv1, E_BADARG); + EXPECT_EQ(rv2, E_BADARG); +} + +TEST_F(ApiHdKeysetEcdsa2pNegWithBlobs, DeriveEmptyHardenedPath) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t hard; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_ecdsa_2p::bip32_path_t{{0, 0}}); + + buf_t sid1, sid2; + std::vector out1, out2; + error_t rv1 = UNINITIALIZED_ERROR, rv2 = UNINITIALIZED_ERROR; + run_2pc( + c1, c2, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job1, blob1_, hard, non_hard, sid1, out1); }, + [&] { return coinbase::api::hd_keyset_ecdsa_2p::derive_ecdsa_2p_keys(job2, blob2_, hard, non_hard, sid2, out2); }, + rv1, rv2); + EXPECT_NE(rv1, UNINITIALIZED_ERROR); + EXPECT_NE(rv2, UNINITIALIZED_ERROR); + EXPECT_EQ(rv1, rv2); +} diff --git a/tests/unit/api/test_hd_keyset_eddsa_2p.cpp b/tests/unit/api/test_hd_keyset_eddsa_2p.cpp new file mode 100644 index 00000000..7f858877 --- /dev/null +++ b/tests/unit/api/test_hd_keyset_eddsa_2p.cpp @@ -0,0 +1,458 @@ +#include +#include +#include + +#include +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; + +using coinbase::api::curve_id; +using coinbase::api::eddsa_2p::party_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::api_harness::failing_transport_t; +using coinbase::testutils::api_harness::local_api_transport_t; +using coinbase::testutils::api_harness::run_2pc; + +static void exercise_ed25519() { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + buf_t keyset1; + buf_t keyset2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + run_2pc( + c1, c2, [&] { return coinbase::api::hd_keyset_eddsa_2p::dkg(job1, curve_id::ed25519, keyset1); }, + [&] { return coinbase::api::hd_keyset_eddsa_2p::dkg(job2, curve_id::ed25519, keyset2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + buf_t root_pub1; + buf_t root_pub2; + ASSERT_EQ(coinbase::api::hd_keyset_eddsa_2p::extract_root_public_key_compressed(keyset1, root_pub1), SUCCESS); + ASSERT_EQ(coinbase::api::hd_keyset_eddsa_2p::extract_root_public_key_compressed(keyset2, root_pub2), SUCCESS); + EXPECT_EQ(root_pub1.size(), 32); + EXPECT_EQ(root_pub1, root_pub2); + + // Deterministic 32-byte message for testing. + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(0x11 + i); + + // Derive two keys. + coinbase::api::hd_keyset_eddsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_eddsa_2p::bip32_path_t{{0, 0}}); + non_hard.push_back(coinbase::api::hd_keyset_eddsa_2p::bip32_path_t{{0, 1}}); + + std::vector derived1; + std::vector derived2; + buf_t sid1; + buf_t sid2; + run_2pc( + c1, c2, + [&] { + return coinbase::api::hd_keyset_eddsa_2p::derive_eddsa_2p_keys(job1, keyset1, hard, non_hard, sid1, derived1); + }, + [&] { + return coinbase::api::hd_keyset_eddsa_2p::derive_eddsa_2p_keys(job2, keyset2, hard, non_hard, sid2, derived2); + }, + rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + EXPECT_EQ(sid1, sid2); + ASSERT_EQ(derived1.size(), non_hard.size()); + ASSERT_EQ(derived2.size(), non_hard.size()); + + for (size_t i = 0; i < derived1.size(); i++) { + buf_t pub_a; + buf_t pub_b; + ASSERT_EQ(coinbase::api::eddsa_2p::get_public_key_compressed(derived1[i], pub_a), SUCCESS); + ASSERT_EQ(coinbase::api::eddsa_2p::get_public_key_compressed(derived2[i], pub_b), SUCCESS); + EXPECT_EQ(pub_a.size(), 32); + EXPECT_EQ(pub_a, pub_b); + } + + // Sign with derived key #0. + buf_t derived_pub; + ASSERT_EQ(coinbase::api::eddsa_2p::get_public_key_compressed(derived1[0], derived_pub), SUCCESS); + coinbase::crypto::ecc_point_t derived_Q; + ASSERT_EQ(derived_Q.from_bin(coinbase::crypto::curve_ed25519, derived_pub), SUCCESS); + const coinbase::crypto::ecc_pub_key_t derived_verify_key(derived_Q); + + buf_t sig1; + buf_t sig2; + run_2pc( + c1, c2, [&] { return coinbase::api::eddsa_2p::sign(job1, derived1[0], msg, sig1); }, + [&] { return coinbase::api::eddsa_2p::sign(job2, derived2[0], msg, sig2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + EXPECT_EQ(sig1.size(), 64); + EXPECT_EQ(sig2.size(), 0); + ASSERT_EQ(derived_verify_key.verify(msg, sig1), SUCCESS); + + // Refresh shares and ensure root pub stays same and derived pub stays same. + buf_t keyset1_ref; + buf_t keyset2_ref; + run_2pc( + c1, c2, [&] { return coinbase::api::hd_keyset_eddsa_2p::refresh(job1, keyset1, keyset1_ref); }, + [&] { return coinbase::api::hd_keyset_eddsa_2p::refresh(job2, keyset2, keyset2_ref); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + buf_t root_pub1_ref; + buf_t root_pub2_ref; + ASSERT_EQ(coinbase::api::hd_keyset_eddsa_2p::extract_root_public_key_compressed(keyset1_ref, root_pub1_ref), SUCCESS); + ASSERT_EQ(coinbase::api::hd_keyset_eddsa_2p::extract_root_public_key_compressed(keyset2_ref, root_pub2_ref), SUCCESS); + EXPECT_EQ(root_pub1_ref, root_pub1); + EXPECT_EQ(root_pub2_ref, root_pub2); + + std::vector derived1_ref; + std::vector derived2_ref; + buf_t sid3; + buf_t sid4; + run_2pc( + c1, c2, + [&] { + return coinbase::api::hd_keyset_eddsa_2p::derive_eddsa_2p_keys(job1, keyset1_ref, hard, non_hard, sid3, + derived1_ref); + }, + [&] { + return coinbase::api::hd_keyset_eddsa_2p::derive_eddsa_2p_keys(job2, keyset2_ref, hard, non_hard, sid4, + derived2_ref); + }, + rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + ASSERT_EQ(derived1_ref.size(), derived1.size()); + ASSERT_EQ(derived2_ref.size(), derived2.size()); + + for (size_t i = 0; i < derived1.size(); i++) { + buf_t pub_old; + buf_t pub_new; + ASSERT_EQ(coinbase::api::eddsa_2p::get_public_key_compressed(derived1[i], pub_old), SUCCESS); + ASSERT_EQ(coinbase::api::eddsa_2p::get_public_key_compressed(derived1_ref[i], pub_new), SUCCESS); + EXPECT_EQ(pub_old, pub_new); + } +} + +} // namespace + +TEST(ApiHdKeysetEddsa2p, DkgDeriveSignRefreshDerive) { exercise_ed25519(); } + +TEST(ApiHdKeysetEddsa2p, UnsupportedCurveRejected) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + EXPECT_EQ(coinbase::api::hd_keyset_eddsa_2p::dkg(job, curve_id::secp256k1, keyset), E_BADARG); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +TEST(ApiHdKeysetEddsa2p, DkgInvalidCurveP256) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + EXPECT_EQ(coinbase::api::hd_keyset_eddsa_2p::dkg(job, curve_id::p256, keyset), E_BADARG); +} + +TEST(ApiHdKeysetEddsa2p, DkgInvalidCurveZero) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + EXPECT_EQ(coinbase::api::hd_keyset_eddsa_2p::dkg(job, curve_id(0), keyset), E_BADARG); +} + +TEST(ApiHdKeysetEddsa2p, DkgInvalidCurveFour) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + EXPECT_EQ(coinbase::api::hd_keyset_eddsa_2p::dkg(job, curve_id(4), keyset), E_BADARG); +} + +TEST(ApiHdKeysetEddsa2p, DkgInvalidCurve255) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + EXPECT_EQ(coinbase::api::hd_keyset_eddsa_2p::dkg(job, curve_id(255), keyset), E_BADARG); +} + +TEST(ApiHdKeysetEddsa2p, DkgEmptyP1Name) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "", "p2", t}; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::dkg(job, curve_id::ed25519, keyset), SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, DkgEmptyP2Name) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "", t}; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::dkg(job, curve_id::ed25519, keyset), SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, DkgSamePartyNames) { + failing_transport_t t; + buf_t keyset; + const coinbase::api::job_2p_t job{party_t::p1, "same", "same", t}; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::dkg(job, curve_id::ed25519, keyset), SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, ExtractRootPubKeyEmptyBlob) { + buf_t out; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::extract_root_public_key_compressed(mem_t(), out), SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, ExtractRootPubKeyGarbage) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t out; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::extract_root_public_key_compressed(mem_t(garbage, sizeof(garbage)), out), + SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, ExtractRootPubKeyAllZero) { + uint8_t zeros[64] = {}; + buf_t out; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::extract_root_public_key_compressed(mem_t(zeros, sizeof(zeros)), out), + SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, ExtractRootPubKeyOneByte) { + uint8_t one = 0; + buf_t out; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::extract_root_public_key_compressed(mem_t(&one, 1), out), SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, ExtractRootPubKeyOversized) { + buf_t big(1024 * 1024 + 1); + buf_t out; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::extract_root_public_key_compressed(big, out), SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, RefreshEmptyBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + buf_t new_keyset; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::refresh(job, mem_t(), new_keyset), SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, RefreshGarbageBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t new_keyset; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::refresh(job, mem_t(garbage, sizeof(garbage)), new_keyset), SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, RefreshAllZeroBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + uint8_t zeros[64] = {}; + buf_t new_keyset; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::refresh(job, mem_t(zeros, sizeof(zeros)), new_keyset), SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, RefreshOneByteBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + uint8_t one = 0; + buf_t new_keyset; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::refresh(job, mem_t(&one, 1), new_keyset), SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, RefreshOversizedBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + buf_t big(1024 * 1024 + 1); + buf_t new_keyset; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::refresh(job, big, new_keyset), SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, DeriveEmptyBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + coinbase::api::hd_keyset_eddsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_eddsa_2p::bip32_path_t{{0, 0}}); + buf_t sid; + std::vector out; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::derive_eddsa_2p_keys(job, mem_t(), hard, non_hard, sid, out), SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, DeriveGarbageBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + coinbase::api::hd_keyset_eddsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_eddsa_2p::bip32_path_t{{0, 0}}); + buf_t sid; + std::vector out; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::derive_eddsa_2p_keys(job, mem_t(garbage, sizeof(garbage)), hard, + non_hard, sid, out), + SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, DeriveAllZeroBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + uint8_t zeros[64] = {}; + coinbase::api::hd_keyset_eddsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_eddsa_2p::bip32_path_t{{0, 0}}); + buf_t sid; + std::vector out; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::derive_eddsa_2p_keys(job, mem_t(zeros, sizeof(zeros)), hard, non_hard, + sid, out), + SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, DeriveOneByteBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + uint8_t one = 0; + coinbase::api::hd_keyset_eddsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_eddsa_2p::bip32_path_t{{0, 0}}); + buf_t sid; + std::vector out; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::derive_eddsa_2p_keys(job, mem_t(&one, 1), hard, non_hard, sid, out), + SUCCESS); +} + +TEST(ApiHdKeysetEddsa2p, DeriveOversizedBlob) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + buf_t big(1024 * 1024 + 1); + coinbase::api::hd_keyset_eddsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_eddsa_2p::bip32_path_t{{0, 0}}); + buf_t sid; + std::vector out; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::derive_eddsa_2p_keys(job, big, hard, non_hard, sid, out), SUCCESS); +} + +namespace { + +static void generate_eddsa_hd_keyset_blobs(buf_t& blob1, buf_t& blob2) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + error_t rv1 = UNINITIALIZED_ERROR, rv2 = UNINITIALIZED_ERROR; + run_2pc( + c1, c2, [&] { return coinbase::api::hd_keyset_eddsa_2p::dkg(job1, curve_id::ed25519, blob1); }, + [&] { return coinbase::api::hd_keyset_eddsa_2p::dkg(job2, curve_id::ed25519, blob2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); +} + +} // namespace + +class ApiHdKeysetEddsa2pNegWithBlobs : public ::testing::Test { + protected: + static void SetUpTestSuite() { generate_eddsa_hd_keyset_blobs(blob1_, blob2_); } + static buf_t blob1_; + static buf_t blob2_; +}; +buf_t ApiHdKeysetEddsa2pNegWithBlobs::blob1_; +buf_t ApiHdKeysetEddsa2pNegWithBlobs::blob2_; + +TEST_F(ApiHdKeysetEddsa2pNegWithBlobs, RefreshRoleMismatchP1BlobP2Job) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p2, "p1", "p2", t}; + buf_t new_keyset; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::refresh(job, blob1_, new_keyset), SUCCESS); +} + +TEST_F(ApiHdKeysetEddsa2pNegWithBlobs, RefreshRoleMismatchP2BlobP1Job) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + buf_t new_keyset; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::refresh(job, blob2_, new_keyset), SUCCESS); +} + +TEST_F(ApiHdKeysetEddsa2pNegWithBlobs, DeriveRoleMismatchP1BlobP2Job) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p2, "p1", "p2", t}; + coinbase::api::hd_keyset_eddsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_eddsa_2p::bip32_path_t{{0, 0}}); + buf_t sid; + std::vector out; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::derive_eddsa_2p_keys(job, blob1_, hard, non_hard, sid, out), SUCCESS); +} + +TEST_F(ApiHdKeysetEddsa2pNegWithBlobs, DeriveDuplicateNonHardenedPaths) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + coinbase::api::hd_keyset_eddsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_eddsa_2p::bip32_path_t{{0, 0}}); + non_hard.push_back(coinbase::api::hd_keyset_eddsa_2p::bip32_path_t{{0, 0}}); + buf_t sid; + std::vector out; + EXPECT_EQ(coinbase::api::hd_keyset_eddsa_2p::derive_eddsa_2p_keys(job, blob1_, hard, non_hard, sid, out), E_BADARG); +} + +TEST_F(ApiHdKeysetEddsa2pNegWithBlobs, DeriveEmptyNonHardenedPaths) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + coinbase::api::hd_keyset_eddsa_2p::bip32_path_t hard; + hard.indices = {0x8000002c, 0x80000000, 0x80000000}; + std::vector empty_paths; + buf_t sid; + std::vector out; + error_t rv = coinbase::api::hd_keyset_eddsa_2p::derive_eddsa_2p_keys(job, blob1_, hard, empty_paths, sid, out); + if (rv == SUCCESS) { + EXPECT_TRUE(out.empty()); + } +} + +TEST_F(ApiHdKeysetEddsa2pNegWithBlobs, DeriveEmptyHardenedPath) { + failing_transport_t t; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + coinbase::api::hd_keyset_eddsa_2p::bip32_path_t empty_hard; + std::vector non_hard; + non_hard.push_back(coinbase::api::hd_keyset_eddsa_2p::bip32_path_t{{0, 0}}); + buf_t sid; + std::vector out; + EXPECT_NE(coinbase::api::hd_keyset_eddsa_2p::derive_eddsa_2p_keys(job, blob1_, empty_hard, non_hard, sid, out), + SUCCESS); +} diff --git a/tests/unit/api/test_pve.cpp b/tests/unit/api/test_pve.cpp new file mode 100644 index 00000000..9610ad52 --- /dev/null +++ b/tests/unit/api/test_pve.cpp @@ -0,0 +1,741 @@ +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; + +using coinbase::api::curve_id; + +class toy_base_pke_t final : public coinbase::api::pve::base_pke_i { + public: + explicit toy_base_pke_t(bool mutate) : mutate_(mutate) {} + + error_t encrypt(mem_t /*ek*/, mem_t /*label*/, mem_t plain, mem_t /*rho*/, buf_t& out_ct) const override { + out_ct = buf_t(plain); + if (mutate_) { + const uint8_t b = 0x42; + out_ct += mem_t(&b, 1); + } + return SUCCESS; + } + + error_t decrypt(mem_t /*dk*/, mem_t /*label*/, mem_t ct, buf_t& out_plain) const override { + out_plain = buf_t(ct); + if (mutate_) { + if (out_plain.size() < 1) return E_FORMAT; + out_plain.resize(out_plain.size() - 1); + } + return SUCCESS; + } + + private: + bool mutate_ = false; +}; + +static buf_t expected_Q(curve_id cid, mem_t x) { + const coinbase::crypto::ecurve_t curve = (cid == curve_id::p256) ? coinbase::crypto::curve_p256 + : (cid == curve_id::secp256k1) ? coinbase::crypto::curve_secp256k1 + : (cid == curve_id::ed25519) ? coinbase::crypto::curve_ed25519 + : coinbase::crypto::ecurve_t(); + cb_assert(curve.valid()); + + const coinbase::crypto::bn_t bn_x = coinbase::crypto::bn_t::from_bin(x) % curve.order(); + const coinbase::crypto::ecc_point_t Q = bn_x * curve.generator(); + return Q.to_compressed_bin(); +} + +// Mirror of the cbmpc base-PKE key blob format used by `coinbase::api::pve`. +// This is test-only plumbing so we can build HSM stubs using software keys. +constexpr uint32_t base_pke_key_blob_version_v1 = 1; +enum class base_pke_key_type_v1 : uint32_t { + rsa_oaep_2048 = 1, + ecies_p256 = 2, +}; + +struct base_pke_dk_blob_v1_t { + uint32_t version = base_pke_key_blob_version_v1; + uint32_t key_type = static_cast(base_pke_key_type_v1::rsa_oaep_2048); + + coinbase::crypto::rsa_prv_key_t rsa_dk; + coinbase::crypto::ecc_prv_key_t ecies_dk; + + void convert(coinbase::converter_t& c) { + c.convert(version, key_type); + switch (static_cast(key_type)) { + case base_pke_key_type_v1::rsa_oaep_2048: + c.convert(rsa_dk); + return; + case base_pke_key_type_v1::ecies_p256: + c.convert(ecies_dk); + return; + default: + c.set_error(); + return; + } + } +}; + +static error_t parse_rsa_prv_from_dk_blob(mem_t dk_blob, coinbase::crypto::rsa_prv_key_t& out_sk) { + base_pke_dk_blob_v1_t blob; + error_t rv = coinbase::convert(blob, dk_blob); + if (rv) return rv; + if (blob.version != base_pke_key_blob_version_v1) return E_FORMAT; + if (static_cast(blob.key_type) != base_pke_key_type_v1::rsa_oaep_2048) return E_BADARG; + out_sk = blob.rsa_dk; + return SUCCESS; +} + +static error_t parse_ecies_prv_from_dk_blob(mem_t dk_blob, coinbase::crypto::ecc_prv_key_t& out_sk) { + base_pke_dk_blob_v1_t blob; + error_t rv = coinbase::convert(blob, dk_blob); + if (rv) return rv; + if (blob.version != base_pke_key_blob_version_v1) return E_FORMAT; + if (static_cast(blob.key_type) != base_pke_key_type_v1::ecies_p256) return E_BADARG; + out_sk = blob.ecies_dk; + return SUCCESS; +} + +} // namespace + +TEST(ApiPve, EncryptVerifyDecrypt_CustomBasePke) { + const toy_base_pke_t base_pke(/*mutate=*/false); + + const curve_id curve = curve_id::secp256k1; + const buf_t ek = buf_t("ek"); + const buf_t dk = buf_t("dk"); + const buf_t label = buf_t("label"); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i); + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(base_pke, curve, ek, label, x_mem, ct), SUCCESS); + ASSERT_GT(ct.size(), 0); + + buf_t Q_ct; + ASSERT_EQ(coinbase::api::pve::get_public_key_compressed(ct, Q_ct), SUCCESS); + + buf_t L_ct; + ASSERT_EQ(coinbase::api::pve::get_Label(ct, L_ct), SUCCESS); + EXPECT_EQ(L_ct, label); + + const buf_t Q_expected = expected_Q(curve, x_mem); + EXPECT_EQ(Q_ct, Q_expected); + + ASSERT_EQ(coinbase::api::pve::verify(base_pke, curve, ek, ct, Q_expected, label), SUCCESS); + + buf_t x_out; + ASSERT_EQ(coinbase::api::pve::decrypt(base_pke, curve, dk, ek, ct, label, x_out), SUCCESS); + EXPECT_EQ(x_out.size(), 32); + EXPECT_EQ(x_out, buf_t(x_mem)); +} + +TEST(ApiPve, EncVerDec_DefBasePke_EciesBlob) { + const curve_id curve = curve_id::secp256k1; + const buf_t label = buf_t("label"); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(0xA0 + i); + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + + buf_t ek_blob; + buf_t dk_blob; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_ecies_p256_keypair(ek_blob, dk_blob), SUCCESS); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(curve, ek_blob, label, x_mem, ct), SUCCESS); + + const buf_t Q_expected = expected_Q(curve, x_mem); + ASSERT_EQ(coinbase::api::pve::verify(curve, ek_blob, ct, Q_expected, label), SUCCESS); + + buf_t x_out; + ASSERT_EQ(coinbase::api::pve::decrypt(curve, dk_blob, ek_blob, ct, label, x_out), SUCCESS); + EXPECT_EQ(x_out.size(), 32); + EXPECT_EQ(x_out, buf_t(x_mem)); +} + +TEST(ApiPve, EncVerDec_DefBasePke_RsaBlob) { + const curve_id curve = curve_id::secp256k1; + const buf_t label = buf_t("label"); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(0x11 + i); + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + + buf_t ek_blob; + buf_t dk_blob; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_rsa_keypair(ek_blob, dk_blob), SUCCESS); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(curve, ek_blob, label, x_mem, ct), SUCCESS); + + const buf_t Q_expected = expected_Q(curve, x_mem); + ASSERT_EQ(coinbase::api::pve::verify(curve, ek_blob, ct, Q_expected, label), SUCCESS); + + buf_t x_out; + ASSERT_EQ(coinbase::api::pve::decrypt(curve, dk_blob, ek_blob, ct, label, x_out), SUCCESS); + EXPECT_EQ(x_out.size(), 32); + EXPECT_EQ(x_out, buf_t(x_mem)); +} + +TEST(ApiPve, EncryptRejectsOversizedX) { + const curve_id curve = curve_id::secp256k1; + const buf_t label = buf_t("label"); + + buf_t ek_blob; + buf_t dk_blob; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_rsa_keypair(ek_blob, dk_blob), SUCCESS); + + // secp256k1 order is 32 bytes; ensure oversize inputs are rejected. + std::array x_bytes{}; + for (int i = 0; i < 33; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + + buf_t ct; + EXPECT_EQ(coinbase::api::pve::encrypt(curve, ek_blob, label, x_mem, ct), E_RANGE); +} + +TEST(ApiPve, DecryptRsaOaepHsm_UsesCallback) { + struct ctx_t { + coinbase::crypto::rsa_prv_key_t sk; + } ctx; + + const curve_id curve = curve_id::secp256k1; + const buf_t label = buf_t("label"); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(0x22 + i); + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + + buf_t ek_blob; + buf_t dk_blob; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_rsa_keypair(ek_blob, dk_blob), SUCCESS); + ASSERT_EQ(parse_rsa_prv_from_dk_blob(dk_blob, ctx.sk), SUCCESS); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(curve, ek_blob, label, x_mem, ct), SUCCESS); + + coinbase::api::pve::rsa_oaep_hsm_decap_cb_t cb; + cb.ctx = &ctx; + cb.decap = +[](void* c, mem_t /*dk_handle*/, mem_t kem_ct, buf_t& out_kem_ss) -> error_t { + auto* ctxp = static_cast(c); + // OAEP label is empty per cbmpc's RSA KEM policy. + return ctxp->sk.decrypt_oaep(kem_ct, coinbase::crypto::hash_e::sha256, coinbase::crypto::hash_e::sha256, mem_t(), + out_kem_ss); + }; + + buf_t x_out; + ASSERT_EQ( + coinbase::api::pve::decrypt_rsa_oaep_hsm(curve, /*dk_handle=*/buf_t("hsm-handle"), ek_blob, ct, label, cb, x_out), + SUCCESS); + EXPECT_EQ(x_out, buf_t(x_mem)); +} + +TEST(ApiPve, DecryptEciesP256Hsm_UsesCallback) { + struct ctx_t { + coinbase::crypto::ecc_prv_key_t sk; + } ctx; + + const curve_id curve = curve_id::secp256k1; + const buf_t label = buf_t("label"); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(0x33 + i); + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + + buf_t ek_blob; + buf_t dk_blob; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_ecies_p256_keypair(ek_blob, dk_blob), SUCCESS); + ASSERT_EQ(parse_ecies_prv_from_dk_blob(dk_blob, ctx.sk), SUCCESS); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(curve, ek_blob, label, x_mem, ct), SUCCESS); + + coinbase::api::pve::ecies_p256_hsm_ecdh_cb_t cb; + cb.ctx = &ctx; + cb.ecdh = +[](void* c, mem_t /*dk_handle*/, mem_t kem_ct, buf_t& out_dh_x32) -> error_t { + auto* ctxp = static_cast(c); + coinbase::crypto::ecc_point_t E; + error_t rv = E.from_oct(coinbase::crypto::curve_p256, kem_ct); + if (rv) return rv; + if (rv = coinbase::crypto::curve_p256.check(E)) return rv; + out_dh_x32 = ctxp->sk.ecdh(E); + return SUCCESS; + }; + + buf_t x_out; + ASSERT_EQ(coinbase::api::pve::decrypt_ecies_p256_hsm(curve, /*dk_handle=*/buf_t("hsm-handle"), ek_blob, ct, label, cb, + x_out), + SUCCESS); + EXPECT_EQ(x_out, buf_t(x_mem)); +} + +TEST(ApiPve, VerifyRejectsWrongLabel) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const curve_id curve = curve_id::secp256k1; + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + + std::array x_bytes{}; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(base_pke, curve, ek, label, x_mem, ct), SUCCESS); + + const buf_t Q_expected = expected_Q(curve, x_mem); + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::verify(base_pke, curve, ek, ct, Q_expected, /*label=*/buf_t("wrong")), SUCCESS); +} + +TEST(ApiPve, VerifyRejectsWrongQ) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const curve_id curve = curve_id::secp256k1; + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + + std::array x_bytes{}; + x_bytes[0] = 7; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(base_pke, curve, ek, label, x_mem, ct), SUCCESS); + + // Wrong Q: flip one byte. + buf_t Q_wrong = expected_Q(curve, x_mem); + Q_wrong[0] ^= 0x01; + + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::verify(base_pke, curve, ek, ct, Q_wrong, label), SUCCESS); +} + +TEST(ApiPve, BasePkeMismatchRejected) { + const toy_base_pke_t base_pke1(/*mutate=*/false); + const toy_base_pke_t base_pke2(/*mutate=*/true); + + const curve_id curve = curve_id::secp256k1; + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + + std::array x_bytes{}; + x_bytes[0] = 9; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(base_pke1, curve, ek, label, x_mem, ct), SUCCESS); + const buf_t Q_expected = expected_Q(curve, x_mem); + + // Verifying with a different base PKE should fail. + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::verify(base_pke2, curve, ek, ct, Q_expected, label), SUCCESS); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +TEST(ApiPveNeg, EncryptInvalidCurve) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + dylog_disable_scope_t no_log_err; + for (int c : {0, 4, 255}) { + EXPECT_NE(coinbase::api::pve::encrypt(base_pke, static_cast(c), ek, label, x_mem, ct), SUCCESS); + } +} + +TEST(ApiPveNeg, EncryptEmptyEk) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::encrypt(base_pke, curve_id::secp256k1, mem_t(), label, x_mem, ct), SUCCESS); +} + +TEST(ApiPveNeg, EncryptEmptyLabel) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t ek = buf_t("ek"); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::encrypt(base_pke, curve_id::secp256k1, ek, mem_t(), x_mem, ct), SUCCESS); +} + +TEST(ApiPveNeg, EncryptEmptyX) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + buf_t ct; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::encrypt(base_pke, curve_id::secp256k1, ek, label, mem_t(), ct), SUCCESS); +} + +TEST(ApiPveNeg, EncryptGarbageEk) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::encrypt(curve_id::secp256k1, mem_t(garbage, 4), label, x_mem, ct), SUCCESS); +} + +TEST(ApiPveNeg, VerifyInvalidCurve) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(base_pke, curve_id::secp256k1, ek, label, x_mem, ct), SUCCESS); + const buf_t Q = expected_Q(curve_id::secp256k1, x_mem); + dylog_disable_scope_t no_log_err; + for (int c : {0, 4, 255}) { + EXPECT_NE(coinbase::api::pve::verify(base_pke, static_cast(c), ek, ct, Q, label), SUCCESS); + } +} + +TEST(ApiPveNeg, VerifyEmptyEk) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(base_pke, curve_id::secp256k1, ek, label, x_mem, ct), SUCCESS); + const buf_t Q = expected_Q(curve_id::secp256k1, x_mem); + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::verify(base_pke, curve_id::secp256k1, mem_t(), ct, Q, label), SUCCESS); +} + +TEST(ApiPveNeg, VerifyEmptyCiphertext) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + const buf_t Q = buf_t("Q"); + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::verify(base_pke, curve_id::secp256k1, ek, mem_t(), Q, label), SUCCESS); +} + +TEST(ApiPveNeg, VerifyEmptyQ) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(base_pke, curve_id::secp256k1, ek, label, x_mem, ct), SUCCESS); + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::verify(base_pke, curve_id::secp256k1, ek, ct, mem_t(), label), SUCCESS); +} + +TEST(ApiPveNeg, VerifyEmptyLabel) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(base_pke, curve_id::secp256k1, ek, label, x_mem, ct), SUCCESS); + const buf_t Q = expected_Q(curve_id::secp256k1, x_mem); + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::verify(base_pke, curve_id::secp256k1, ek, ct, Q, mem_t()), SUCCESS); +} + +TEST(ApiPveNeg, VerifyGarbageCiphertext) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + const buf_t Q = buf_t("Q"); + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::verify(base_pke, curve_id::secp256k1, ek, mem_t(garbage, 4), Q, label), SUCCESS); +} + +TEST(ApiPveNeg, DecryptInvalidCurve) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t ek = buf_t("ek"); + const buf_t dk = buf_t("dk"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(base_pke, curve_id::secp256k1, ek, label, x_mem, ct), SUCCESS); + buf_t x_out; + dylog_disable_scope_t no_log_err; + for (int c : {0, 4, 255}) { + EXPECT_NE(coinbase::api::pve::decrypt(base_pke, static_cast(c), dk, ek, ct, label, x_out), SUCCESS); + } +} + +TEST(ApiPveNeg, DecryptEmptyDk) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(base_pke, curve_id::secp256k1, ek, label, x_mem, ct), SUCCESS); + buf_t x_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::decrypt(base_pke, curve_id::secp256k1, mem_t(), ek, ct, label, x_out), SUCCESS); +} + +TEST(ApiPveNeg, DecryptEmptyEk) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t dk = buf_t("dk"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(base_pke, curve_id::secp256k1, buf_t("ek"), label, x_mem, ct), SUCCESS); + buf_t x_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::decrypt(base_pke, curve_id::secp256k1, dk, mem_t(), ct, label, x_out), SUCCESS); +} + +TEST(ApiPveNeg, DecryptEmptyCiphertext) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t dk = buf_t("dk"); + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + buf_t x_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::decrypt(base_pke, curve_id::secp256k1, dk, ek, mem_t(), label, x_out), SUCCESS); +} + +TEST(ApiPveNeg, DecryptEmptyLabel) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t ek = buf_t("ek"); + const buf_t dk = buf_t("dk"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(base_pke, curve_id::secp256k1, ek, label, x_mem, ct), SUCCESS); + buf_t x_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::decrypt(base_pke, curve_id::secp256k1, dk, ek, ct, mem_t(), x_out), SUCCESS); +} + +TEST(ApiPveNeg, DecryptGarbageCiphertext) { + const toy_base_pke_t base_pke(/*mutate=*/false); + const buf_t dk = buf_t("dk"); + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t x_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::decrypt(base_pke, curve_id::secp256k1, dk, ek, mem_t(garbage, 4), label, x_out), + SUCCESS); +} + +TEST(ApiPveNeg, DecryptRsaHsmEmptyDkHandle) { + buf_t ek_blob, dk_blob; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_rsa_keypair(ek_blob, dk_blob), SUCCESS); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(curve_id::secp256k1, ek_blob, buf_t("label"), x_mem, ct), SUCCESS); + coinbase::api::pve::rsa_oaep_hsm_decap_cb_t cb; + cb.decap = +[](void*, mem_t, mem_t, buf_t&) -> error_t { return SUCCESS; }; + buf_t x_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE( + coinbase::api::pve::decrypt_rsa_oaep_hsm(curve_id::secp256k1, mem_t(), ek_blob, ct, buf_t("label"), cb, x_out), + SUCCESS); +} + +TEST(ApiPveNeg, DecryptRsaHsmNullCallback) { + buf_t ek_blob, dk_blob; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_rsa_keypair(ek_blob, dk_blob), SUCCESS); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(curve_id::secp256k1, ek_blob, buf_t("label"), x_mem, ct), SUCCESS); + coinbase::api::pve::rsa_oaep_hsm_decap_cb_t cb; + buf_t x_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::decrypt_rsa_oaep_hsm(curve_id::secp256k1, buf_t("handle"), ek_blob, ct, buf_t("label"), + cb, x_out), + SUCCESS); +} + +TEST(ApiPveNeg, DecryptRsaHsmEkTypeMismatch) { + buf_t ecies_ek, ecies_dk; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_ecies_p256_keypair(ecies_ek, ecies_dk), SUCCESS); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(curve_id::secp256k1, ecies_ek, buf_t("label"), x_mem, ct), SUCCESS); + coinbase::api::pve::rsa_oaep_hsm_decap_cb_t cb; + cb.decap = +[](void*, mem_t, mem_t, buf_t&) -> error_t { return SUCCESS; }; + buf_t x_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::decrypt_rsa_oaep_hsm(curve_id::secp256k1, buf_t("handle"), ecies_ek, ct, buf_t("label"), + cb, x_out), + SUCCESS); +} + +TEST(ApiPveNeg, DecryptEciesHsmEmptyDkHandle) { + buf_t ek_blob, dk_blob; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_ecies_p256_keypair(ek_blob, dk_blob), SUCCESS); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(curve_id::secp256k1, ek_blob, buf_t("label"), x_mem, ct), SUCCESS); + coinbase::api::pve::ecies_p256_hsm_ecdh_cb_t cb; + cb.ecdh = +[](void*, mem_t, mem_t, buf_t&) -> error_t { return SUCCESS; }; + buf_t x_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE( + coinbase::api::pve::decrypt_ecies_p256_hsm(curve_id::secp256k1, mem_t(), ek_blob, ct, buf_t("label"), cb, x_out), + SUCCESS); +} + +TEST(ApiPveNeg, DecryptEciesHsmNullCallback) { + buf_t ek_blob, dk_blob; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_ecies_p256_keypair(ek_blob, dk_blob), SUCCESS); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(curve_id::secp256k1, ek_blob, buf_t("label"), x_mem, ct), SUCCESS); + coinbase::api::pve::ecies_p256_hsm_ecdh_cb_t cb; + buf_t x_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::decrypt_ecies_p256_hsm(curve_id::secp256k1, buf_t("handle"), ek_blob, ct, + buf_t("label"), cb, x_out), + SUCCESS); +} + +TEST(ApiPveNeg, DecryptEciesHsmEkTypeMismatch) { + buf_t rsa_ek, rsa_dk; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_rsa_keypair(rsa_ek, rsa_dk), SUCCESS); + std::array x_bytes{}; + x_bytes[0] = 1; + const mem_t x_mem(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt(curve_id::secp256k1, rsa_ek, buf_t("label"), x_mem, ct), SUCCESS); + coinbase::api::pve::ecies_p256_hsm_ecdh_cb_t cb; + cb.ecdh = +[](void*, mem_t, mem_t, buf_t&) -> error_t { return SUCCESS; }; + buf_t x_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::decrypt_ecies_p256_hsm(curve_id::secp256k1, buf_t("handle"), rsa_ek, ct, buf_t("label"), + cb, x_out), + SUCCESS); +} + +TEST(ApiPveNeg, GetPublicKeyCompressedEmptyCiphertext) { + buf_t Q; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::get_public_key_compressed(mem_t(), Q), SUCCESS); +} + +TEST(ApiPveNeg, GetPublicKeyCompressedGarbageCiphertext) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t Q; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::get_public_key_compressed(mem_t(garbage, 4), Q), SUCCESS); +} + +TEST(ApiPveNeg, GetLabelEmptyCiphertext) { + buf_t label; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::get_Label(mem_t(), label), SUCCESS); +} + +TEST(ApiPveNeg, GetLabelGarbageCiphertext) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t label; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::get_Label(mem_t(garbage, 4), label), SUCCESS); +} + +TEST(ApiPveNeg, EciesP256EkFromOctEmptyInput) { + buf_t ek; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::base_pke_ecies_p256_ek_from_oct(mem_t(), ek), SUCCESS); +} + +TEST(ApiPveNeg, EciesP256EkFromOctGarbageInput) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t ek; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::base_pke_ecies_p256_ek_from_oct(mem_t(garbage, 4), ek), SUCCESS); +} + +TEST(ApiPveNeg, EciesP256EkFromOctWrongSize) { + std::array data{}; + buf_t ek; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::base_pke_ecies_p256_ek_from_oct(mem_t(data.data(), static_cast(data.size())), ek), + SUCCESS); +} + +TEST(ApiPveNeg, EciesP256EkFromOctAllZero65) { + std::array data{}; + buf_t ek; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::base_pke_ecies_p256_ek_from_oct(mem_t(data.data(), static_cast(data.size())), ek), + SUCCESS); +} + +TEST(ApiPveNeg, RsaEkFromModulusEmptyInput) { + buf_t ek; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::base_pke_rsa_ek_from_modulus(mem_t(), ek), SUCCESS); +} + +TEST(ApiPveNeg, RsaEkFromModulusGarbageInput) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t ek; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::base_pke_rsa_ek_from_modulus(mem_t(garbage, 4), ek), SUCCESS); +} + +TEST(ApiPveNeg, RsaEkFromModulusWrongSize) { + std::array data{}; + data[0] = 1; + buf_t ek; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::base_pke_rsa_ek_from_modulus(mem_t(data.data(), static_cast(data.size())), ek), + SUCCESS); +} + +TEST(ApiPveNeg, RsaEkFromModulusAllZero256) { + std::array data{}; + buf_t ek; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::base_pke_rsa_ek_from_modulus(mem_t(data.data(), static_cast(data.size())), ek), + SUCCESS); +} diff --git a/tests/unit/api/test_pve_ac.cpp b/tests/unit/api/test_pve_ac.cpp new file mode 100644 index 00000000..0622041d --- /dev/null +++ b/tests/unit/api/test_pve_ac.cpp @@ -0,0 +1,897 @@ +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; + +using coinbase::api::curve_id; + +static buf_t expected_Q(curve_id cid, mem_t x) { + const coinbase::crypto::ecurve_t curve = (cid == curve_id::p256) ? coinbase::crypto::curve_p256 + : (cid == curve_id::secp256k1) ? coinbase::crypto::curve_secp256k1 + : (cid == curve_id::ed25519) ? coinbase::crypto::curve_ed25519 + : coinbase::crypto::ecurve_t(); + cb_assert(curve.valid()); + + const coinbase::crypto::bn_t bn_x = coinbase::crypto::bn_t::from_bin(x) % curve.order(); + const coinbase::crypto::ecc_point_t Q = bn_x * curve.generator(); + return Q.to_compressed_bin(); +} + +class toy_base_pke_t final : public coinbase::api::pve::base_pke_i { + public: + error_t encrypt(mem_t /*ek*/, mem_t /*label*/, mem_t plain, mem_t /*rho*/, buf_t& out_ct) const override { + out_ct = buf_t(plain); + return SUCCESS; + } + + error_t decrypt(mem_t /*dk*/, mem_t /*label*/, mem_t ct, buf_t& out_plain) const override { + out_plain = buf_t(ct); + return SUCCESS; + } +}; + +} // namespace + +TEST(ApiPveAc, EncVer_PDec_Agg_DefPke_Rsa) { + const curve_id curve = curve_id::secp256k1; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + constexpr int n = 4; + std::array, n> xs_bytes{}; + for (int i = 0; i < n; i++) { + for (int j = 0; j < 32; j++) xs_bytes[static_cast(i)][static_cast(j)] = static_cast(i + j); + } + std::vector xs; + xs.reserve(n); + for (int i = 0; i < n; i++) xs.emplace_back(xs_bytes[static_cast(i)].data(), 32); + + std::array eks{}; + std::array dks{}; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_rsa_keypair(eks[0], dks[0]), SUCCESS); + ASSERT_EQ(coinbase::api::pve::generate_base_pke_rsa_keypair(eks[1], dks[1]), SUCCESS); + ASSERT_EQ(coinbase::api::pve::generate_base_pke_rsa_keypair(eks[2], dks[2]), SUCCESS); + + coinbase::api::pve::leaf_keys_t ac_pks; + ASSERT_TRUE(ac_pks.emplace("p1", mem_t(eks[0].data(), eks[0].size())).second); + ASSERT_TRUE(ac_pks.emplace("p2", mem_t(eks[1].data(), eks[1].size())).second); + ASSERT_TRUE(ac_pks.emplace("p3", mem_t(eks[2].data(), eks[2].size())).second); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_ac(curve, ac, ac_pks, label, xs, ct), SUCCESS); + + int batch_count = 0; + ASSERT_EQ(coinbase::api::pve::get_ac_batch_count(ct, batch_count), SUCCESS); + ASSERT_EQ(batch_count, n); + + std::vector Qs_expected; + Qs_expected.reserve(n); + for (int i = 0; i < n; i++) Qs_expected.push_back(expected_Q(curve, xs[static_cast(i)])); + + std::vector Qs_expected_mem; + Qs_expected_mem.reserve(n); + for (const auto& q : Qs_expected) Qs_expected_mem.emplace_back(q.data(), q.size()); + + ASSERT_EQ(coinbase::api::pve::verify_ac(curve, ac, ac_pks, ct, Qs_expected_mem, label), SUCCESS); + + const int attempt_index = 0; + buf_t share_p1; + buf_t share_p2; + ASSERT_EQ(coinbase::api::pve::partial_decrypt_ac_attempt(curve, ac, ct, attempt_index, "p1", + mem_t(dks[0].data(), dks[0].size()), label, share_p1), + SUCCESS); + ASSERT_EQ(coinbase::api::pve::partial_decrypt_ac_attempt(curve, ac, ct, attempt_index, "p2", + mem_t(dks[1].data(), dks[1].size()), label, share_p2), + SUCCESS); + + coinbase::api::pve::leaf_shares_t quorum; + ASSERT_TRUE(quorum.emplace("p1", mem_t(share_p1.data(), share_p1.size())).second); + ASSERT_TRUE(quorum.emplace("p2", mem_t(share_p2.data(), share_p2.size())).second); + + std::vector xs_out; + ASSERT_EQ(coinbase::api::pve::combine_ac(curve, ac, ct, attempt_index, label, quorum, xs_out), SUCCESS); + + ASSERT_EQ(xs_out.size(), static_cast(n)); + for (int i = 0; i < n; i++) EXPECT_EQ(xs_out[static_cast(i)], buf_t(xs[static_cast(i)])); + + // Insufficient quorum should fail. + coinbase::api::pve::leaf_shares_t insufficient; + ASSERT_TRUE(insufficient.emplace("p1", mem_t(share_p1.data(), share_p1.size())).second); + std::vector xs_out2; + EXPECT_NE(coinbase::api::pve::combine_ac(curve, ac, ct, attempt_index, label, insufficient, xs_out2), SUCCESS); +} + +TEST(ApiPveAc, EncryptRejectsOversizedX) { + const curve_id curve = curve_id::secp256k1; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + std::array x_bytes{}; + for (int i = 0; i < 33; i++) x_bytes[static_cast(i)] = static_cast(0x20 + i); + std::vector xs; + xs.emplace_back(x_bytes.data(), static_cast(x_bytes.size())); + + std::array eks{}; + std::array dks{}; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_rsa_keypair(eks[0], dks[0]), SUCCESS); + ASSERT_EQ(coinbase::api::pve::generate_base_pke_rsa_keypair(eks[1], dks[1]), SUCCESS); + ASSERT_EQ(coinbase::api::pve::generate_base_pke_rsa_keypair(eks[2], dks[2]), SUCCESS); + + coinbase::api::pve::leaf_keys_t ac_pks; + ASSERT_TRUE(ac_pks.emplace("p1", mem_t(eks[0].data(), eks[0].size())).second); + ASSERT_TRUE(ac_pks.emplace("p2", mem_t(eks[1].data(), eks[1].size())).second); + ASSERT_TRUE(ac_pks.emplace("p3", mem_t(eks[2].data(), eks[2].size())).second); + + buf_t ct; + EXPECT_EQ(coinbase::api::pve::encrypt_ac(curve, ac, ac_pks, label, xs, ct), E_RANGE); +} + +TEST(ApiPveAc, EncVer_PartDec_Agg_CustomBasePke) { + const toy_base_pke_t base_pke; + + const curve_id curve = curve_id::secp256k1; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + constexpr int n = 3; + std::array, n> xs_bytes{}; + for (int i = 0; i < n; i++) { + for (int j = 0; j < 32; j++) + xs_bytes[static_cast(i)][static_cast(j)] = static_cast(0x77 + i + j); + } + std::vector xs; + xs.reserve(n); + for (int i = 0; i < n; i++) xs.emplace_back(xs_bytes[static_cast(i)].data(), 32); + + // Toy per-leaf keys. + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + + coinbase::api::pve::leaf_keys_t ac_pks; + ASSERT_TRUE(ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())).second); + ASSERT_TRUE(ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())).second); + ASSERT_TRUE(ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())).second); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_ac(base_pke, curve, ac, ac_pks, label, xs, ct), SUCCESS); + + std::vector Qs_expected; + Qs_expected.reserve(n); + for (int i = 0; i < n; i++) Qs_expected.push_back(expected_Q(curve, xs[static_cast(i)])); + + std::vector Qs_expected_mem; + Qs_expected_mem.reserve(n); + for (const auto& q : Qs_expected) Qs_expected_mem.emplace_back(q.data(), q.size()); + + ASSERT_EQ(coinbase::api::pve::verify_ac(base_pke, curve, ac, ac_pks, ct, Qs_expected_mem, label), SUCCESS); + + const int attempt_index = 0; + buf_t share_p1; + buf_t share_p3; + ASSERT_EQ(coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, curve, ac, ct, attempt_index, "p1", ek1, label, + share_p1), + SUCCESS); + ASSERT_EQ(coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, curve, ac, ct, attempt_index, "p3", ek3, label, + share_p3), + SUCCESS); + + coinbase::api::pve::leaf_shares_t quorum; + ASSERT_TRUE(quorum.emplace("p1", mem_t(share_p1.data(), share_p1.size())).second); + ASSERT_TRUE(quorum.emplace("p3", mem_t(share_p3.data(), share_p3.size())).second); + + std::vector xs_out; + ASSERT_EQ(coinbase::api::pve::combine_ac(base_pke, curve, ac, ct, attempt_index, label, quorum, xs_out), SUCCESS); + + ASSERT_EQ(xs_out.size(), static_cast(n)); + for (int i = 0; i < n; i++) EXPECT_EQ(xs_out[static_cast(i)], buf_t(xs[static_cast(i)])); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +TEST(ApiPveAcNeg, EncryptAc_InvalidCurve) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + buf_t ct; + EXPECT_NE(coinbase::api::pve::encrypt_ac(base_pke, static_cast(0), ac, ac_pks, label, xs, ct), SUCCESS); + EXPECT_NE(coinbase::api::pve::encrypt_ac(base_pke, static_cast(4), ac, ac_pks, label, xs, ct), SUCCESS); + EXPECT_NE(coinbase::api::pve::encrypt_ac(base_pke, static_cast(255), ac, ac_pks, label, xs, ct), SUCCESS); +} + +TEST(ApiPveAcNeg, EncryptAc_EmptyLabel) { + const toy_base_pke_t base_pke; + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + buf_t ct; + EXPECT_NE(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, mem_t(), xs, ct), SUCCESS); +} + +TEST(ApiPveAcNeg, EncryptAc_EmptyXsVector) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + std::vector xs; + buf_t ct; + EXPECT_NE(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); +} + +TEST(ApiPveAcNeg, EncryptAc_XsWithEmptyElement) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + std::vector xs; + xs.emplace_back(mem_t()); + buf_t ct; + EXPECT_NE(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); +} + +TEST(ApiPveAcNeg, EncryptAc_EmptyAcPks) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + coinbase::api::pve::leaf_keys_t ac_pks; + buf_t ct; + EXPECT_NE(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); +} + +TEST(ApiPveAcNeg, EncryptAc_AcPksMissingLeaf) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + + buf_t ct; + EXPECT_NE(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); +} + +TEST(ApiPveAcNeg, EncryptAc_AcPksExtraLeaf) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + const buf_t ek4 = buf_t("ek4"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + ac_pks.emplace("unknown", mem_t(ek4.data(), ek4.size())); + + buf_t ct; + EXPECT_NE(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); +} + +TEST(ApiPveAcNeg, EncryptAc_AcNoLeaves) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold(2, {}); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + coinbase::api::pve::leaf_keys_t ac_pks; + buf_t ct; + EXPECT_NE(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); +} + +TEST(ApiPveAcNeg, VerifyAc_InvalidCurve) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); + + buf_t Q = expected_Q(curve_id::secp256k1, xs[0]); + std::vector Qs; + Qs.emplace_back(Q.data(), Q.size()); + + EXPECT_NE(coinbase::api::pve::verify_ac(base_pke, static_cast(0), ac, ac_pks, ct, Qs, label), SUCCESS); + EXPECT_NE(coinbase::api::pve::verify_ac(base_pke, static_cast(4), ac, ac_pks, ct, Qs, label), SUCCESS); + EXPECT_NE(coinbase::api::pve::verify_ac(base_pke, static_cast(255), ac, ac_pks, ct, Qs, label), SUCCESS); +} + +TEST(ApiPveAcNeg, VerifyAc_EmptyCiphertext) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + buf_t Q_dummy = buf_t("Q"); + std::vector Qs; + Qs.emplace_back(Q_dummy.data(), Q_dummy.size()); + + EXPECT_NE(coinbase::api::pve::verify_ac(base_pke, curve_id::secp256k1, ac, ac_pks, mem_t(), Qs, label), SUCCESS); +} + +TEST(ApiPveAcNeg, VerifyAc_EmptyQsCompressed) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); + + std::vector empty_Qs; + EXPECT_NE(coinbase::api::pve::verify_ac(base_pke, curve_id::secp256k1, ac, ac_pks, ct, empty_Qs, label), SUCCESS); +} + +TEST(ApiPveAcNeg, VerifyAc_EmptyLabel) { + const toy_base_pke_t base_pke; + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + buf_t ct; + const buf_t label = buf_t("label"); + ASSERT_EQ(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); + + buf_t Q = expected_Q(curve_id::secp256k1, xs[0]); + std::vector Qs; + Qs.emplace_back(Q.data(), Q.size()); + + EXPECT_NE(coinbase::api::pve::verify_ac(base_pke, curve_id::secp256k1, ac, ac_pks, ct, Qs, mem_t()), SUCCESS); +} + +TEST(ApiPveAcNeg, VerifyAc_GarbageCiphertext) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + const std::array garbage = {0xDE, 0xAD, 0xBE, 0xEF}; + const mem_t garbage_ct(garbage.data(), 4); + + buf_t Q_dummy = buf_t("Q"); + std::vector Qs; + Qs.emplace_back(Q_dummy.data(), Q_dummy.size()); + + EXPECT_NE(coinbase::api::pve::verify_ac(base_pke, curve_id::secp256k1, ac, ac_pks, garbage_ct, Qs, label), SUCCESS); +} + +TEST(ApiPveAcNeg, PartialDecryptAcAttempt_InvalidCurve) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); + + buf_t share; + EXPECT_NE(coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, static_cast(0), ac, ct, 0, "p1", ek1, + label, share), + SUCCESS); + EXPECT_NE(coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, static_cast(4), ac, ct, 0, "p1", ek1, + label, share), + SUCCESS); + EXPECT_NE(coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, static_cast(255), ac, ct, 0, "p1", ek1, + label, share), + SUCCESS); +} + +TEST(ApiPveAcNeg, PartialDecryptAcAttempt_EmptyCiphertext) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t dk = buf_t("dk1"); + buf_t share; + EXPECT_NE(coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, curve_id::secp256k1, ac, mem_t(), 0, "p1", dk, + label, share), + SUCCESS); +} + +TEST(ApiPveAcNeg, PartialDecryptAcAttempt_EmptyDk) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); + + buf_t share; + EXPECT_NE(coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, curve_id::secp256k1, ac, ct, 0, "p1", mem_t(), + label, share), + SUCCESS); +} + +TEST(ApiPveAcNeg, PartialDecryptAcAttempt_EmptyLabel) { + const toy_base_pke_t base_pke; + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + buf_t ct; + const buf_t label = buf_t("label"); + ASSERT_EQ(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); + + buf_t share; + EXPECT_NE(coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, curve_id::secp256k1, ac, ct, 0, "p1", ek1, mem_t(), + share), + SUCCESS); +} + +TEST(ApiPveAcNeg, PartialDecryptAcAttempt_EmptyLeafName) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); + + const buf_t dk = buf_t("dk1"); + buf_t share; + EXPECT_NE( + coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, curve_id::secp256k1, ac, ct, 0, "", dk, label, share), + SUCCESS); +} + +TEST(ApiPveAcNeg, PartialDecryptAcAttempt_UnknownLeafName) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); + + const buf_t dk = buf_t("dk_nonexistent"); + buf_t share; + EXPECT_NE(coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, curve_id::secp256k1, ac, ct, 0, "nonexistent", dk, + label, share), + SUCCESS); +} + +TEST(ApiPveAcNeg, PartialDecryptAcAttempt_GarbageCiphertext) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const std::array garbage = {0xDE, 0xAD, 0xBE, 0xEF}; + const mem_t garbage_ct(garbage.data(), 4); + + const buf_t dk = buf_t("dk1"); + buf_t share; + EXPECT_NE(coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, curve_id::secp256k1, ac, garbage_ct, 0, "p1", dk, + label, share), + SUCCESS); +} + +TEST(ApiPveAcNeg, CombineAc_InvalidCurve) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); + + buf_t share_p1; + buf_t share_p2; + ASSERT_EQ(coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, curve_id::secp256k1, ac, ct, 0, "p1", ek1, label, + share_p1), + SUCCESS); + ASSERT_EQ(coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, curve_id::secp256k1, ac, ct, 0, "p2", ek2, label, + share_p2), + SUCCESS); + + coinbase::api::pve::leaf_shares_t quorum; + quorum.emplace("p1", mem_t(share_p1.data(), share_p1.size())); + quorum.emplace("p2", mem_t(share_p2.data(), share_p2.size())); + + std::vector xs_out; + EXPECT_NE(coinbase::api::pve::combine_ac(base_pke, static_cast(0), ac, ct, 0, label, quorum, xs_out), + SUCCESS); + EXPECT_NE(coinbase::api::pve::combine_ac(base_pke, static_cast(4), ac, ct, 0, label, quorum, xs_out), + SUCCESS); + EXPECT_NE(coinbase::api::pve::combine_ac(base_pke, static_cast(255), ac, ct, 0, label, quorum, xs_out), + SUCCESS); +} + +TEST(ApiPveAcNeg, CombineAc_EmptyCiphertext) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t share_dummy = buf_t("share_dummy_32bytes_padding_here"); + coinbase::api::pve::leaf_shares_t quorum; + quorum.emplace("p1", mem_t(share_dummy.data(), share_dummy.size())); + quorum.emplace("p2", mem_t(share_dummy.data(), share_dummy.size())); + + std::vector xs_out; + EXPECT_NE(coinbase::api::pve::combine_ac(base_pke, curve_id::secp256k1, ac, mem_t(), 0, label, quorum, xs_out), + SUCCESS); +} + +TEST(ApiPveAcNeg, CombineAc_EmptyLabel) { + const toy_base_pke_t base_pke; + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + buf_t ct; + const buf_t label = buf_t("label"); + ASSERT_EQ(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); + + buf_t share_p1; + buf_t share_p2; + ASSERT_EQ(coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, curve_id::secp256k1, ac, ct, 0, "p1", ek1, label, + share_p1), + SUCCESS); + ASSERT_EQ(coinbase::api::pve::partial_decrypt_ac_attempt(base_pke, curve_id::secp256k1, ac, ct, 0, "p2", ek2, label, + share_p2), + SUCCESS); + + coinbase::api::pve::leaf_shares_t quorum; + quorum.emplace("p1", mem_t(share_p1.data(), share_p1.size())); + quorum.emplace("p2", mem_t(share_p2.data(), share_p2.size())); + + std::vector xs_out; + EXPECT_NE(coinbase::api::pve::combine_ac(base_pke, curve_id::secp256k1, ac, ct, 0, mem_t(), quorum, xs_out), SUCCESS); +} + +TEST(ApiPveAcNeg, CombineAc_EmptyQuorumShares) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const buf_t ek1 = buf_t("ek1"); + const buf_t ek2 = buf_t("ek2"); + const buf_t ek3 = buf_t("ek3"); + coinbase::api::pve::leaf_keys_t ac_pks; + ac_pks.emplace("p1", mem_t(ek1.data(), ek1.size())); + ac_pks.emplace("p2", mem_t(ek2.data(), ek2.size())); + ac_pks.emplace("p3", mem_t(ek3.data(), ek3.size())); + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i + 1); + std::vector xs; + xs.emplace_back(x_bytes.data(), 32); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_ac(base_pke, curve_id::secp256k1, ac, ac_pks, label, xs, ct), SUCCESS); + + coinbase::api::pve::leaf_shares_t empty_quorum; + std::vector xs_out; + EXPECT_NE(coinbase::api::pve::combine_ac(base_pke, curve_id::secp256k1, ac, ct, 0, label, empty_quorum, xs_out), + SUCCESS); +} + +TEST(ApiPveAcNeg, CombineAc_GarbageCiphertext) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p1"), coinbase::api::access_structure_t::leaf("p2"), + coinbase::api::access_structure_t::leaf("p3")}); + + const std::array garbage = {0xDE, 0xAD, 0xBE, 0xEF}; + const mem_t garbage_ct(garbage.data(), 4); + + const buf_t share_dummy = buf_t("share_dummy_32bytes_padding_here"); + coinbase::api::pve::leaf_shares_t quorum; + quorum.emplace("p1", mem_t(share_dummy.data(), share_dummy.size())); + quorum.emplace("p2", mem_t(share_dummy.data(), share_dummy.size())); + + std::vector xs_out; + EXPECT_NE(coinbase::api::pve::combine_ac(base_pke, curve_id::secp256k1, ac, garbage_ct, 0, label, quorum, xs_out), + SUCCESS); +} + +TEST(ApiPveAcNeg, GetAcBatchCount_EmptyCiphertext) { + int count = 0; + EXPECT_NE(coinbase::api::pve::get_ac_batch_count(mem_t(), count), SUCCESS); +} + +TEST(ApiPveAcNeg, GetAcBatchCount_GarbageCiphertext) { + const std::array garbage = {0xDE, 0xAD, 0xBE, 0xEF}; + int count = 0; + EXPECT_NE(coinbase::api::pve::get_ac_batch_count(mem_t(garbage.data(), 4), count), SUCCESS); +} + +TEST(ApiPveAcNeg, GetPublicKeysCompressedAc_EmptyCiphertext) { + std::vector Qs; + EXPECT_NE(coinbase::api::pve::get_public_keys_compressed_ac(mem_t(), Qs), SUCCESS); +} + +TEST(ApiPveAcNeg, GetPublicKeysCompressedAc_GarbageCiphertext) { + const std::array garbage = {0xDE, 0xAD, 0xBE, 0xEF}; + std::vector Qs; + EXPECT_NE(coinbase::api::pve::get_public_keys_compressed_ac(mem_t(garbage.data(), 4), Qs), SUCCESS); +} diff --git a/tests/unit/api/test_pve_batch.cpp b/tests/unit/api/test_pve_batch.cpp new file mode 100644 index 00000000..a14a03c3 --- /dev/null +++ b/tests/unit/api/test_pve_batch.cpp @@ -0,0 +1,426 @@ +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; + +using coinbase::api::curve_id; + +static buf_t expected_Q(curve_id cid, mem_t x) { + const coinbase::crypto::ecurve_t curve = (cid == curve_id::p256) ? coinbase::crypto::curve_p256 + : (cid == curve_id::secp256k1) ? coinbase::crypto::curve_secp256k1 + : (cid == curve_id::ed25519) ? coinbase::crypto::curve_ed25519 + : coinbase::crypto::ecurve_t(); + cb_assert(curve.valid()); + + const coinbase::crypto::bn_t bn_x = coinbase::crypto::bn_t::from_bin(x) % curve.order(); + const coinbase::crypto::ecc_point_t Q = bn_x * curve.generator(); + return Q.to_compressed_bin(); +} + +class toy_base_pke_t final : public coinbase::api::pve::base_pke_i { + public: + error_t encrypt(mem_t /*ek*/, mem_t /*label*/, mem_t plain, mem_t /*rho*/, buf_t& out_ct) const override { + out_ct = buf_t(plain); + return SUCCESS; + } + + error_t decrypt(mem_t /*dk*/, mem_t /*label*/, mem_t ct, buf_t& out_plain) const override { + out_plain = buf_t(ct); + return SUCCESS; + } +}; + +} // namespace + +TEST(ApiPveBatch, EncVerDec_DefBasePke_RsaBlob) { + const curve_id curve = curve_id::secp256k1; + const buf_t label = buf_t("label"); + + buf_t ek_blob; + buf_t dk_blob; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_rsa_keypair(ek_blob, dk_blob), SUCCESS); + + constexpr int n = 4; + std::array, n> xs_bytes{}; + for (int i = 0; i < n; i++) { + for (int j = 0; j < 32; j++) xs_bytes[static_cast(i)][static_cast(j)] = static_cast(i + j); + } + std::vector xs; + xs.reserve(n); + for (int i = 0; i < n; i++) xs.emplace_back(xs_bytes[static_cast(i)].data(), 32); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_batch(curve, ek_blob, label, xs, ct), SUCCESS); + + int batch_count = 0; + ASSERT_EQ(coinbase::api::pve::get_batch_count(ct, batch_count), SUCCESS); + ASSERT_EQ(batch_count, n); + + buf_t label_extracted; + ASSERT_EQ(coinbase::api::pve::get_Label_batch(ct, label_extracted), SUCCESS); + ASSERT_EQ(label_extracted, label); + + std::vector Qs_extracted; + ASSERT_EQ(coinbase::api::pve::get_public_keys_compressed_batch(ct, Qs_extracted), SUCCESS); + ASSERT_EQ(Qs_extracted.size(), static_cast(n)); + + std::vector Qs_expected; + Qs_expected.reserve(n); + for (int i = 0; i < n; i++) Qs_expected.push_back(expected_Q(curve, xs[static_cast(i)])); + + for (int i = 0; i < n; i++) EXPECT_EQ(Qs_extracted[static_cast(i)], Qs_expected[static_cast(i)]); + + std::vector Qs_expected_mem; + Qs_expected_mem.reserve(n); + for (const auto& q : Qs_expected) Qs_expected_mem.emplace_back(q.data(), q.size()); + + ASSERT_EQ(coinbase::api::pve::verify_batch(curve, ek_blob, ct, Qs_expected_mem, label), SUCCESS); + + std::vector xs_out; + ASSERT_EQ(coinbase::api::pve::decrypt_batch(curve, dk_blob, ek_blob, ct, label, xs_out), SUCCESS); + ASSERT_EQ(xs_out.size(), static_cast(n)); + for (int i = 0; i < n; i++) EXPECT_EQ(xs_out[static_cast(i)], buf_t(xs[static_cast(i)])); +} + +TEST(ApiPveBatch, EncryptRejectsOversizedX) { + const curve_id curve = curve_id::secp256k1; + const buf_t label = buf_t("label"); + + buf_t ek_blob; + buf_t dk_blob; + ASSERT_EQ(coinbase::api::pve::generate_base_pke_rsa_keypair(ek_blob, dk_blob), SUCCESS); + + std::array x_bytes{}; + for (int i = 0; i < 33; i++) x_bytes[static_cast(i)] = static_cast(0x10 + i); + std::vector xs; + xs.emplace_back(x_bytes.data(), static_cast(x_bytes.size())); + + buf_t ct; + EXPECT_EQ(coinbase::api::pve::encrypt_batch(curve, ek_blob, label, xs, ct), E_RANGE); +} + +TEST(ApiPveBatch, EncryptVerifyDecrypt_CustomBasePke) { + const toy_base_pke_t base_pke; + + const curve_id curve = curve_id::secp256k1; + const buf_t ek = buf_t("ek"); + const buf_t dk = buf_t("dk"); + const buf_t label = buf_t("label"); + + constexpr int n = 3; + std::array, n> xs_bytes{}; + for (int i = 0; i < n; i++) { + for (int j = 0; j < 32; j++) + xs_bytes[static_cast(i)][static_cast(j)] = static_cast(0x77 + i + j); + } + std::vector xs; + xs.reserve(n); + for (int i = 0; i < n; i++) xs.emplace_back(xs_bytes[static_cast(i)].data(), 32); + + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_batch(base_pke, curve, ek, label, xs, ct), SUCCESS); + + std::vector Qs_expected; + Qs_expected.reserve(n); + for (int i = 0; i < n; i++) Qs_expected.push_back(expected_Q(curve, xs[static_cast(i)])); + + std::vector Qs_expected_mem; + Qs_expected_mem.reserve(n); + for (const auto& q : Qs_expected) Qs_expected_mem.emplace_back(q.data(), q.size()); + + ASSERT_EQ(coinbase::api::pve::verify_batch(base_pke, curve, ek, ct, Qs_expected_mem, label), SUCCESS); + + std::vector xs_out; + ASSERT_EQ(coinbase::api::pve::decrypt_batch(base_pke, curve, dk, ek, ct, label, xs_out), SUCCESS); + ASSERT_EQ(xs_out.size(), static_cast(n)); + for (int i = 0; i < n; i++) EXPECT_EQ(xs_out[static_cast(i)], buf_t(xs[static_cast(i)])); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +TEST(ApiPveBatchNeg, EncryptBatchInvalidCurve) { + const toy_base_pke_t base_pke; + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + std::vector xs; + xs.emplace_back(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + dylog_disable_scope_t no_log_err; + for (int c : {0, 4, 255}) { + EXPECT_NE(coinbase::api::pve::encrypt_batch(base_pke, static_cast(c), ek, label, xs, ct), SUCCESS); + } +} + +TEST(ApiPveBatchNeg, EncryptBatchEmptyEk) { + const toy_base_pke_t base_pke; + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + std::vector xs; + xs.emplace_back(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::encrypt_batch(base_pke, curve_id::secp256k1, mem_t(), label, xs, ct), SUCCESS); +} + +TEST(ApiPveBatchNeg, EncryptBatchEmptyLabel) { + const toy_base_pke_t base_pke; + const buf_t ek = buf_t("ek"); + std::array x_bytes{}; + x_bytes[0] = 1; + std::vector xs; + xs.emplace_back(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::encrypt_batch(base_pke, curve_id::secp256k1, ek, mem_t(), xs, ct), SUCCESS); +} + +TEST(ApiPveBatchNeg, EncryptBatchEmptyXsVector) { + const toy_base_pke_t base_pke; + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::vector xs; + buf_t ct; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::encrypt_batch(base_pke, curve_id::secp256k1, ek, label, xs, ct), SUCCESS); +} + +TEST(ApiPveBatchNeg, EncryptBatchXsWithEmptyElement) { + const toy_base_pke_t base_pke; + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::vector xs; + xs.emplace_back(mem_t()); + buf_t ct; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::encrypt_batch(base_pke, curve_id::secp256k1, ek, label, xs, ct), SUCCESS); +} + +TEST(ApiPveBatchNeg, VerifyBatchInvalidCurve) { + const toy_base_pke_t base_pke; + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + std::vector xs; + xs.emplace_back(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_batch(base_pke, curve_id::secp256k1, ek, label, xs, ct), SUCCESS); + const buf_t Q = expected_Q(curve_id::secp256k1, xs[0]); + std::vector Qs; + Qs.emplace_back(Q.data(), Q.size()); + dylog_disable_scope_t no_log_err; + for (int c : {0, 4, 255}) { + EXPECT_NE(coinbase::api::pve::verify_batch(base_pke, static_cast(c), ek, ct, Qs, label), SUCCESS); + } +} + +TEST(ApiPveBatchNeg, VerifyBatchEmptyEk) { + const toy_base_pke_t base_pke; + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + std::vector xs; + xs.emplace_back(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_batch(base_pke, curve_id::secp256k1, ek, label, xs, ct), SUCCESS); + const buf_t Q = expected_Q(curve_id::secp256k1, xs[0]); + std::vector Qs; + Qs.emplace_back(Q.data(), Q.size()); + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::verify_batch(base_pke, curve_id::secp256k1, mem_t(), ct, Qs, label), SUCCESS); +} + +TEST(ApiPveBatchNeg, VerifyBatchEmptyCiphertext) { + const toy_base_pke_t base_pke; + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + const buf_t Q = buf_t("Q"); + std::vector Qs; + Qs.emplace_back(Q.data(), Q.size()); + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::verify_batch(base_pke, curve_id::secp256k1, ek, mem_t(), Qs, label), SUCCESS); +} + +TEST(ApiPveBatchNeg, VerifyBatchEmptyQsVector) { + const toy_base_pke_t base_pke; + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + std::vector xs; + xs.emplace_back(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_batch(base_pke, curve_id::secp256k1, ek, label, xs, ct), SUCCESS); + std::vector Qs; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::verify_batch(base_pke, curve_id::secp256k1, ek, ct, Qs, label), SUCCESS); +} + +TEST(ApiPveBatchNeg, VerifyBatchEmptyLabel) { + const toy_base_pke_t base_pke; + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + std::vector xs; + xs.emplace_back(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_batch(base_pke, curve_id::secp256k1, ek, label, xs, ct), SUCCESS); + const buf_t Q = expected_Q(curve_id::secp256k1, xs[0]); + std::vector Qs; + Qs.emplace_back(Q.data(), Q.size()); + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::verify_batch(base_pke, curve_id::secp256k1, ek, ct, Qs, mem_t()), SUCCESS); +} + +TEST(ApiPveBatchNeg, VerifyBatchGarbageCiphertext) { + const toy_base_pke_t base_pke; + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + const buf_t Q = buf_t("Q"); + std::vector Qs; + Qs.emplace_back(Q.data(), Q.size()); + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::verify_batch(base_pke, curve_id::secp256k1, ek, mem_t(garbage, 4), Qs, label), SUCCESS); +} + +TEST(ApiPveBatchNeg, DecryptBatchInvalidCurve) { + const toy_base_pke_t base_pke; + const buf_t ek = buf_t("ek"); + const buf_t dk = buf_t("dk"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + std::vector xs; + xs.emplace_back(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_batch(base_pke, curve_id::secp256k1, ek, label, xs, ct), SUCCESS); + std::vector xs_out; + dylog_disable_scope_t no_log_err; + for (int c : {0, 4, 255}) { + EXPECT_NE(coinbase::api::pve::decrypt_batch(base_pke, static_cast(c), dk, ek, ct, label, xs_out), + SUCCESS); + } +} + +TEST(ApiPveBatchNeg, DecryptBatchEmptyDk) { + const toy_base_pke_t base_pke; + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + std::vector xs; + xs.emplace_back(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_batch(base_pke, curve_id::secp256k1, ek, label, xs, ct), SUCCESS); + std::vector xs_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::decrypt_batch(base_pke, curve_id::secp256k1, mem_t(), ek, ct, label, xs_out), SUCCESS); +} + +TEST(ApiPveBatchNeg, DecryptBatchEmptyEk) { + const toy_base_pke_t base_pke; + const buf_t dk = buf_t("dk"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + std::vector xs; + xs.emplace_back(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_batch(base_pke, curve_id::secp256k1, buf_t("ek"), label, xs, ct), SUCCESS); + std::vector xs_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::decrypt_batch(base_pke, curve_id::secp256k1, dk, mem_t(), ct, label, xs_out), SUCCESS); +} + +TEST(ApiPveBatchNeg, DecryptBatchEmptyCiphertext) { + const toy_base_pke_t base_pke; + const buf_t dk = buf_t("dk"); + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + std::vector xs_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::decrypt_batch(base_pke, curve_id::secp256k1, dk, ek, mem_t(), label, xs_out), SUCCESS); +} + +TEST(ApiPveBatchNeg, DecryptBatchEmptyLabel) { + const toy_base_pke_t base_pke; + const buf_t ek = buf_t("ek"); + const buf_t dk = buf_t("dk"); + const buf_t label = buf_t("label"); + std::array x_bytes{}; + x_bytes[0] = 1; + std::vector xs; + xs.emplace_back(x_bytes.data(), static_cast(x_bytes.size())); + buf_t ct; + ASSERT_EQ(coinbase::api::pve::encrypt_batch(base_pke, curve_id::secp256k1, ek, label, xs, ct), SUCCESS); + std::vector xs_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::decrypt_batch(base_pke, curve_id::secp256k1, dk, ek, ct, mem_t(), xs_out), SUCCESS); +} + +TEST(ApiPveBatchNeg, DecryptBatchGarbageCiphertext) { + const toy_base_pke_t base_pke; + const buf_t dk = buf_t("dk"); + const buf_t ek = buf_t("ek"); + const buf_t label = buf_t("label"); + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + std::vector xs_out; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::decrypt_batch(base_pke, curve_id::secp256k1, dk, ek, mem_t(garbage, 4), label, xs_out), + SUCCESS); +} + +TEST(ApiPveBatchNeg, GetBatchCountEmptyCiphertext) { + int count = 0; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::get_batch_count(mem_t(), count), SUCCESS); +} + +TEST(ApiPveBatchNeg, GetBatchCountGarbageCiphertext) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + int count = 0; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::get_batch_count(mem_t(garbage, 4), count), SUCCESS); +} + +TEST(ApiPveBatchNeg, GetPublicKeysCompressedBatchEmptyCiphertext) { + std::vector Qs; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::get_public_keys_compressed_batch(mem_t(), Qs), SUCCESS); +} + +TEST(ApiPveBatchNeg, GetPublicKeysCompressedBatchGarbageCiphertext) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + std::vector Qs; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::get_public_keys_compressed_batch(mem_t(garbage, 4), Qs), SUCCESS); +} + +TEST(ApiPveBatchNeg, GetLabelBatchEmptyCiphertext) { + buf_t label; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::get_Label_batch(mem_t(), label), SUCCESS); +} + +TEST(ApiPveBatchNeg, GetLabelBatchGarbageCiphertext) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t label; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::pve::get_Label_batch(mem_t(garbage, 4), label), SUCCESS); +} diff --git a/tests/unit/api/test_schnorr2pc.cpp b/tests/unit/api/test_schnorr2pc.cpp new file mode 100644 index 00000000..5d8f9c26 --- /dev/null +++ b/tests/unit/api/test_schnorr2pc.cpp @@ -0,0 +1,523 @@ +#include +#include +#include + +#include +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; + +using coinbase::api::curve_id; +using coinbase::api::schnorr_2p::party_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::api_harness::failing_transport_t; +using coinbase::testutils::api_harness::local_api_transport_t; +using coinbase::testutils::api_harness::run_2pc; + +struct key_blob_v1_t { + uint32_t version = 1; + uint32_t role = 0; + uint32_t curve = 0; + buf_t Q_compressed; + coinbase::crypto::bn_t x_share; + + void convert(coinbase::converter_t& c) { c.convert(version, role, curve, Q_compressed, x_share); } +}; + +static void exercise_secp256k1_bip340() { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + buf_t key_blob_1; + buf_t key_blob_2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + run_2pc( + c1, c2, [&] { return coinbase::api::schnorr_2p::dkg(job1, curve_id::secp256k1, key_blob_1); }, + [&] { return coinbase::api::schnorr_2p::dkg(job2, curve_id::secp256k1, key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + buf_t pub1; + buf_t pub2; + ASSERT_EQ(coinbase::api::schnorr_2p::get_public_key_compressed(key_blob_1, pub1), SUCCESS); + ASSERT_EQ(coinbase::api::schnorr_2p::get_public_key_compressed(key_blob_2, pub2), SUCCESS); + EXPECT_EQ(pub1.size(), 33); + EXPECT_EQ(pub1, pub2); + + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub1), SUCCESS); + + buf_t pub_xonly; + ASSERT_EQ(coinbase::api::schnorr_2p::extract_public_key_xonly(key_blob_1, pub_xonly), SUCCESS); + EXPECT_EQ(pub_xonly.size(), 32); + EXPECT_EQ(pub_xonly, Q.get_x().to_bin(32)); + + // Deterministic 32-byte message for testing. + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + + buf_t sig1; + buf_t sig2; + run_2pc( + c1, c2, [&] { return coinbase::api::schnorr_2p::sign(job1, key_blob_1, msg, sig1); }, + [&] { return coinbase::api::schnorr_2p::sign(job2, key_blob_2, msg, sig2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + EXPECT_EQ(sig1.size(), 64); + EXPECT_EQ(sig2.size(), 0); + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, msg, sig1), SUCCESS); + + // Refresh and sign again. + buf_t new_key_blob_1; + buf_t new_key_blob_2; + run_2pc( + c1, c2, [&] { return coinbase::api::schnorr_2p::refresh(job1, key_blob_1, new_key_blob_1); }, + [&] { return coinbase::api::schnorr_2p::refresh(job2, key_blob_2, new_key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + buf_t sig3; + buf_t sig4; + run_2pc( + c1, c2, [&] { return coinbase::api::schnorr_2p::sign(job1, new_key_blob_1, msg, sig3); }, + [&] { return coinbase::api::schnorr_2p::sign(job2, new_key_blob_2, msg, sig4); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + EXPECT_EQ(sig3.size(), 64); + EXPECT_EQ(sig4.size(), 0); + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, msg, sig3), SUCCESS); + + buf_t pub3; + buf_t pub4; + ASSERT_EQ(coinbase::api::schnorr_2p::get_public_key_compressed(new_key_blob_1, pub3), SUCCESS); + ASSERT_EQ(coinbase::api::schnorr_2p::get_public_key_compressed(new_key_blob_2, pub4), SUCCESS); + EXPECT_EQ(pub3, pub4); + EXPECT_EQ(pub3, pub1); + + // Role is fixed to the share: signing with the "wrong" job.self should fail. + buf_t bad_sig1; + buf_t bad_sig2; + run_2pc( + c1, c2, [&] { return coinbase::api::schnorr_2p::sign(job1, key_blob_2, msg, bad_sig1); }, + [&] { return coinbase::api::schnorr_2p::sign(job2, key_blob_2, msg, bad_sig2); }, rv1, rv2); + EXPECT_EQ(rv1, E_BADARG); +} + +} // namespace + +TEST(ApiSchnorr2pc, DkgSignRefreshSign) { exercise_secp256k1_bip340(); } + +TEST(ApiSchnorr2pc, UnsupportedCurveRejected) { + failing_transport_t t; + buf_t key_blob; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", t}; + EXPECT_EQ(coinbase::api::schnorr_2p::dkg(job, curve_id::p256, key_blob), E_BADARG); +} + +TEST(ApiSchnorr2pc, RejectsOutOfRangeXShareInKeyBlob) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + buf_t key_blob_1; + buf_t key_blob_2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + run_2pc( + c1, c2, [&] { return coinbase::api::schnorr_2p::dkg(job1, curve_id::secp256k1, key_blob_1); }, + [&] { return coinbase::api::schnorr_2p::dkg(job2, curve_id::secp256k1, key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + key_blob_v1_t blob; + ASSERT_EQ(coinbase::convert(blob, key_blob_1), SUCCESS); + + // `x_share == q` is out of range; valid shares must satisfy 0 <= x_share < q. + blob.x_share = coinbase::crypto::bn_t(coinbase::crypto::curve_secp256k1.order()); + buf_t malformed_blob = coinbase::convert(blob); + + buf_t pub; + EXPECT_EQ(coinbase::api::schnorr_2p::get_public_key_compressed(malformed_blob, pub), E_FORMAT); +} + +TEST(ApiSchnorr2pc, KeyBlobPrivScalar_NoPubSign) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + buf_t key_blob_1; + buf_t key_blob_2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + run_2pc( + c1, c2, [&] { return coinbase::api::schnorr_2p::dkg(job1, curve_id::secp256k1, key_blob_1); }, + [&] { return coinbase::api::schnorr_2p::dkg(job2, curve_id::secp256k1, key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + // Refresh (exercise detach/attach on refreshed blobs too). + buf_t refreshed_1; + buf_t refreshed_2; + run_2pc( + c1, c2, [&] { return coinbase::api::schnorr_2p::refresh(job1, key_blob_1, refreshed_1); }, + [&] { return coinbase::api::schnorr_2p::refresh(job2, key_blob_2, refreshed_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + // Detach into public blob + scalar. + buf_t public_1; + buf_t public_2; + buf_t x1_fixed; + buf_t x2_fixed; + ASSERT_EQ(coinbase::api::schnorr_2p::detach_private_scalar(refreshed_1, public_1, x1_fixed), SUCCESS); + ASSERT_EQ(coinbase::api::schnorr_2p::detach_private_scalar(refreshed_2, public_2, x2_fixed), SUCCESS); + EXPECT_EQ(x1_fixed.size(), 32); + EXPECT_EQ(x2_fixed.size(), 32); + + // Capture share points before detaching (public blobs no longer carry them). + buf_t Qi_full_1; + ASSERT_EQ(coinbase::api::schnorr_2p::get_public_share_compressed(refreshed_1, Qi_full_1), SUCCESS); + + buf_t Qi_full_2; + ASSERT_EQ(coinbase::api::schnorr_2p::get_public_share_compressed(refreshed_2, Qi_full_2), SUCCESS); + + // Public blob should not be usable for signing. + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + { + failing_transport_t ft; + const coinbase::api::job_2p_t bad_job{party_t::p1, "p1", "p2", ft}; + buf_t sig; + EXPECT_NE(coinbase::api::schnorr_2p::sign(bad_job, public_1, msg, sig), SUCCESS); + } + + // Attach scalars back and sign. + buf_t merged_1; + buf_t merged_2; + ASSERT_EQ(coinbase::api::schnorr_2p::attach_private_scalar(public_1, x1_fixed, Qi_full_1, merged_1), SUCCESS); + ASSERT_EQ(coinbase::api::schnorr_2p::attach_private_scalar(public_2, x2_fixed, Qi_full_2, merged_2), SUCCESS); + + buf_t pub; + ASSERT_EQ(coinbase::api::schnorr_2p::get_public_key_compressed(merged_1, pub), SUCCESS); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub), SUCCESS); + + buf_t sig1; + buf_t sig2; + run_2pc( + c1, c2, [&] { return coinbase::api::schnorr_2p::sign(job1, merged_1, msg, sig1); }, + [&] { return coinbase::api::schnorr_2p::sign(job2, merged_2, msg, sig2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + ASSERT_EQ(sig1.size(), 64); + EXPECT_EQ(sig2.size(), 0); + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, msg, sig1), SUCCESS); + + // Negative: wrong scalar should fail to attach. + buf_t bad_x = x1_fixed; + bad_x[0] ^= 0x01; + buf_t bad_merged; + EXPECT_NE(coinbase::api::schnorr_2p::attach_private_scalar(public_1, bad_x, Qi_full_1, bad_merged), SUCCESS); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +namespace { + +using coinbase::mem_t; + +static void generate_schnorr_key_blobs(buf_t& blob1, buf_t& blob2) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + run_2pc( + c1, c2, [&] { return coinbase::api::schnorr_2p::dkg(job1, curve_id::secp256k1, blob1); }, + [&] { return coinbase::api::schnorr_2p::dkg(job2, curve_id::secp256k1, blob2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); +} + +} // namespace + +class ApiSchnorr2pcNegWithBlobs : public ::testing::Test { + protected: + static void SetUpTestSuite() { generate_schnorr_key_blobs(blob1_, blob2_); } + static buf_t blob1_; + static buf_t blob2_; +}; + +buf_t ApiSchnorr2pcNegWithBlobs::blob1_; +buf_t ApiSchnorr2pcNegWithBlobs::blob2_; + +TEST(ApiSchnorr2pcNeg, DkgInvalidCurve0) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", ft}; + buf_t key_blob; + EXPECT_NE(coinbase::api::schnorr_2p::dkg(job, static_cast(0), key_blob), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, DkgInvalidCurve4) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", ft}; + buf_t key_blob; + EXPECT_NE(coinbase::api::schnorr_2p::dkg(job, static_cast(4), key_blob), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, DkgInvalidCurve255) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", ft}; + buf_t key_blob; + EXPECT_NE(coinbase::api::schnorr_2p::dkg(job, static_cast(255), key_blob), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, DkgUnsupportedCurveP256) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", ft}; + buf_t key_blob; + EXPECT_NE(coinbase::api::schnorr_2p::dkg(job, curve_id::p256, key_blob), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, DkgUnsupportedCurveEd25519) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", ft}; + buf_t key_blob; + EXPECT_NE(coinbase::api::schnorr_2p::dkg(job, curve_id::ed25519, key_blob), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, DkgEmptyP1Name) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p1, "", "p2", ft}; + buf_t key_blob; + EXPECT_NE(coinbase::api::schnorr_2p::dkg(job, curve_id::secp256k1, key_blob), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, DkgSameP1P2Name) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p1", ft}; + buf_t key_blob; + EXPECT_NE(coinbase::api::schnorr_2p::dkg(job, curve_id::secp256k1, key_blob), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, GetPubKeyCompressedEmptyBlob) { + buf_t pub; + EXPECT_NE(coinbase::api::schnorr_2p::get_public_key_compressed(mem_t(), pub), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, GetPubKeyCompressedGarbageBlob) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x02, 0x03, 0x04}; + buf_t pub; + EXPECT_NE(coinbase::api::schnorr_2p::get_public_key_compressed(mem_t(garbage, sizeof(garbage)), pub), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, GetPubKeyCompressedOneByteBlob) { + const uint8_t one = 0x00; + buf_t pub; + EXPECT_NE(coinbase::api::schnorr_2p::get_public_key_compressed(mem_t(&one, 1), pub), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, GetPubKeyCompressedAllZeroBlob) { + uint8_t zeros[64] = {}; + buf_t pub; + EXPECT_NE(coinbase::api::schnorr_2p::get_public_key_compressed(mem_t(zeros, sizeof(zeros)), pub), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, ExtractPubKeyXonlyEmptyBlob) { + buf_t pub; + EXPECT_NE(coinbase::api::schnorr_2p::extract_public_key_xonly(mem_t(), pub), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, ExtractPubKeyXonlyGarbageBlob) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x02, 0x03, 0x04}; + buf_t pub; + EXPECT_NE(coinbase::api::schnorr_2p::extract_public_key_xonly(mem_t(garbage, sizeof(garbage)), pub), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, GetPublicShareCompressedEmptyBlob) { + buf_t out; + EXPECT_NE(coinbase::api::schnorr_2p::get_public_share_compressed(mem_t(), out), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, GetPublicShareCompressedGarbageBlob) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x02, 0x03, 0x04}; + buf_t out; + EXPECT_NE(coinbase::api::schnorr_2p::get_public_share_compressed(mem_t(garbage, sizeof(garbage)), out), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, DetachPrivateScalarEmptyBlob) { + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::schnorr_2p::detach_private_scalar(mem_t(), pub_blob, scalar), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, DetachPrivateScalarGarbageBlob) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x02, 0x03, 0x04}; + buf_t pub_blob, scalar; + EXPECT_NE(coinbase::api::schnorr_2p::detach_private_scalar(mem_t(garbage, sizeof(garbage)), pub_blob, scalar), + SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, AttachPrivateScalarEmptyPubBlob) { + uint8_t scalar_bytes[32] = {0x01}; + uint8_t point[33] = {}; + point[0] = 0x02; + buf_t out; + EXPECT_NE(coinbase::api::schnorr_2p::attach_private_scalar(mem_t(), mem_t(scalar_bytes, 32), mem_t(point, 33), out), + SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, AttachPrivateScalarGarbagePubBlob) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + uint8_t scalar_bytes[32] = {0x01}; + uint8_t point[33] = {}; + point[0] = 0x02; + buf_t out; + EXPECT_NE(coinbase::api::schnorr_2p::attach_private_scalar(mem_t(garbage, sizeof(garbage)), mem_t(scalar_bytes, 32), + mem_t(point, 33), out), + SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, AttachPrivateScalarWrongSizeScalar) { + uint8_t point[33] = {}; + point[0] = 0x02; + buf_t dummy_pub(64); + buf_t out; + + uint8_t short_scalar[31] = {0x01}; + EXPECT_NE(coinbase::api::schnorr_2p::attach_private_scalar(dummy_pub, mem_t(short_scalar, 31), mem_t(point, 33), out), + SUCCESS); + + uint8_t long_scalar[33] = {0x01}; + EXPECT_NE(coinbase::api::schnorr_2p::attach_private_scalar(dummy_pub, mem_t(long_scalar, 33), mem_t(point, 33), out), + SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, AttachPrivateScalarEmptyScalar) { + uint8_t point[33] = {}; + point[0] = 0x02; + buf_t dummy_pub(64); + buf_t out; + EXPECT_NE(coinbase::api::schnorr_2p::attach_private_scalar(dummy_pub, mem_t(), mem_t(point, 33), out), SUCCESS); +} + +TEST(ApiSchnorr2pcNeg, AttachPrivateScalarEmptyPubShare) { + uint8_t scalar_bytes[32] = {0x01}; + buf_t dummy_pub(64); + buf_t out; + EXPECT_NE(coinbase::api::schnorr_2p::attach_private_scalar(dummy_pub, mem_t(scalar_bytes, 32), mem_t(), out), + SUCCESS); +} + +TEST_F(ApiSchnorr2pcNegWithBlobs, RefreshRoleMismatchP1BlobP2Job) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p2, "p1", "p2", ft}; + buf_t new_blob; + EXPECT_NE(coinbase::api::schnorr_2p::refresh(job, blob1_, new_blob), SUCCESS); +} + +TEST_F(ApiSchnorr2pcNegWithBlobs, RefreshRoleMismatchP2BlobP1Job) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", ft}; + buf_t new_blob; + EXPECT_NE(coinbase::api::schnorr_2p::refresh(job, blob2_, new_blob), SUCCESS); +} + +TEST_F(ApiSchnorr2pcNegWithBlobs, RefreshEmptyBlob) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", ft}; + buf_t new_blob; + EXPECT_NE(coinbase::api::schnorr_2p::refresh(job, mem_t(), new_blob), SUCCESS); +} + +TEST_F(ApiSchnorr2pcNegWithBlobs, RefreshGarbageBlob) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", ft}; + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t new_blob; + EXPECT_NE(coinbase::api::schnorr_2p::refresh(job, mem_t(garbage, sizeof(garbage)), new_blob), SUCCESS); +} + +TEST_F(ApiSchnorr2pcNegWithBlobs, SignMsgNot32Bytes) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", ft}; + buf_t sig; + + buf_t msg_31(31); + EXPECT_NE(coinbase::api::schnorr_2p::sign(job, blob1_, msg_31, sig), SUCCESS); + + buf_t msg_33(33); + EXPECT_NE(coinbase::api::schnorr_2p::sign(job, blob1_, msg_33, sig), SUCCESS); + + EXPECT_NE(coinbase::api::schnorr_2p::sign(job, blob1_, mem_t(), sig), SUCCESS); +} + +TEST_F(ApiSchnorr2pcNegWithBlobs, SignEmptyBlob) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", ft}; + buf_t msg(32); + buf_t sig; + EXPECT_NE(coinbase::api::schnorr_2p::sign(job, mem_t(), msg, sig), SUCCESS); +} + +TEST_F(ApiSchnorr2pcNegWithBlobs, SignGarbageBlob) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p1, "p1", "p2", ft}; + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg(32); + buf_t sig; + EXPECT_NE(coinbase::api::schnorr_2p::sign(job, mem_t(garbage, sizeof(garbage)), msg, sig), SUCCESS); +} + +TEST_F(ApiSchnorr2pcNegWithBlobs, SignRoleMismatch) { + failing_transport_t ft; + const coinbase::api::job_2p_t job{party_t::p2, "p1", "p2", ft}; + buf_t msg(32); + buf_t sig; + EXPECT_NE(coinbase::api::schnorr_2p::sign(job, blob1_, msg, sig), SUCCESS); +} diff --git a/tests/unit/api/test_schnorr_mp.cpp b/tests/unit/api/test_schnorr_mp.cpp new file mode 100644 index 00000000..41184011 --- /dev/null +++ b/tests/unit/api/test_schnorr_mp.cpp @@ -0,0 +1,337 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; + +using coinbase::api::curve_id; +using coinbase::api::job_mp_t; +using coinbase::api::party_idx_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::api_harness::failing_transport_t; +using coinbase::testutils::api_harness::local_api_transport_t; +using coinbase::testutils::api_harness::run_mp; + +static void exercise_4p_role_change() { + constexpr int n = 4; + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector> transports; + transports.reserve(n); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + + std::vector names = {"p0", "p1", "p2", "p3"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + std::vector keys(n); + std::vector new_keys(n); + std::vector sids(n); + std::vector sigs(n); + std::vector new_sigs(n); + std::vector rvs; + + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::schnorr_mp::dkg_additive(job, curve_id::secp256k1, keys[static_cast(i)], + sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) EXPECT_EQ(sids[0], sids[static_cast(i)]); + + buf_t pub0; + ASSERT_EQ(coinbase::api::schnorr_mp::get_public_key_compressed(keys[0], pub0), SUCCESS); + EXPECT_EQ(pub0.size(), 33); + for (int i = 1; i < n; i++) { + buf_t pub_i; + ASSERT_EQ(coinbase::api::schnorr_mp::get_public_key_compressed(keys[static_cast(i)], pub_i), SUCCESS); + EXPECT_EQ(pub_i, pub0); + } + + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub0), SUCCESS); + + buf_t xonly0; + ASSERT_EQ(coinbase::api::schnorr_mp::extract_public_key_xonly(keys[0], xonly0), SUCCESS); + EXPECT_EQ(xonly0.size(), 32); + EXPECT_EQ(xonly0, Q.get_x().to_bin(32)); + + // Change the party ordering ("role" indices) between protocols. + const std::vector name_views2 = {names[0], names[2], names[1], names[3]}; + // Map new role index -> old role index (DKG) for the same party name. + const int perm[n] = {0, 2, 1, 3}; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views2, *transports[static_cast(i)]}; + return coinbase::api::schnorr_mp::sign_additive(job, keys[static_cast(perm[i])], msg, + /*sig_receiver=*/2, sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + EXPECT_EQ(sigs[2].size(), 64); + for (int i = 0; i < n; i++) { + if (i == 2) continue; + EXPECT_EQ(sigs[static_cast(i)].size(), 0); + } + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, msg, sigs[2]), SUCCESS); + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views2, *transports[static_cast(i)]}; + return coinbase::api::schnorr_mp::refresh_additive(job, sids[static_cast(perm[i])], + keys[static_cast(perm[i])], + new_keys[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) EXPECT_EQ(sids[0], sids[static_cast(i)]); + + for (int i = 0; i < n; i++) { + buf_t pub_i; + ASSERT_EQ(coinbase::api::schnorr_mp::get_public_key_compressed(new_keys[static_cast(i)], pub_i), SUCCESS); + EXPECT_EQ(pub_i, pub0); + } + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views2, *transports[static_cast(i)]}; + return coinbase::api::schnorr_mp::sign_additive(job, new_keys[static_cast(i)], msg, /*sig_receiver=*/2, + new_sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + EXPECT_EQ(new_sigs[2].size(), 64); + for (int i = 0; i < n; i++) { + if (i == 2) continue; + EXPECT_EQ(new_sigs[static_cast(i)].size(), 0); + } + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, msg, new_sigs[2]), SUCCESS); +} + +} // namespace + +TEST(ApiSchnorrMp, DkgSignRefreshSignRoleChange4p) { exercise_4p_role_change(); } + +TEST(ApiSchnorrMp, RejectsInvalidSigReceiver) { + failing_transport_t t; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t sig; + EXPECT_EQ(coinbase::api::schnorr_mp::sign_additive(job, mem_t(), mem_t(), /*sig_receiver=*/5, sig), E_BADARG); +} + +TEST(ApiSchnorrMp, UnsupportedCurveRejected) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t key; + buf_t sid; + EXPECT_EQ(coinbase::api::schnorr_mp::dkg_additive(job, curve_id::p256, key, sid), E_BADARG); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +TEST(ApiSchnorrMpNeg, DkgInvalidCurve) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + buf_t key, sid; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::dkg_additive(job, curve_id(0), key, sid), SUCCESS); + EXPECT_NE(coinbase::api::schnorr_mp::dkg_additive(job, curve_id(4), key, sid), SUCCESS); + EXPECT_NE(coinbase::api::schnorr_mp::dkg_additive(job, curve_id(255), key, sid), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, DkgEmptyPartyName) { + failing_transport_t ft; + std::vector names = {}; + job_mp_t job{/*self=*/0, names, ft}; + buf_t key, sid; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::dkg_additive(job, curve_id::secp256k1, key, sid), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, DkgSingleParty) { + failing_transport_t ft; + std::vector names = {"p0"}; + job_mp_t job{/*self=*/0, names, ft}; + buf_t key, sid; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::dkg_additive(job, curve_id::secp256k1, key, sid), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, GetPubKeyCompressedEmptyBlob) { + buf_t pub; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::get_public_key_compressed(mem_t(), pub), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, GetPubKeyCompressedGarbageBlob) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t pub; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::get_public_key_compressed(mem_t(garbage, 4), pub), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, ExtractPubKeyXonlyEmptyBlob) { + buf_t pub; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::extract_public_key_xonly(mem_t(), pub), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, ExtractPubKeyXonlyGarbageBlob) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t pub; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::extract_public_key_xonly(mem_t(garbage, 4), pub), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, GetPublicShareCompressedEmptyBlob) { + buf_t Qi; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::get_public_share_compressed(mem_t(), Qi), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, GetPublicShareCompressedGarbageBlob) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t Qi; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::get_public_share_compressed(mem_t(garbage, 4), Qi), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, DetachPrivateScalarEmptyBlob) { + buf_t out_pub, out_scalar; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::detach_private_scalar(mem_t(), out_pub, out_scalar), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, DetachPrivateScalarGarbageBlob) { + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t out_pub, out_scalar; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::detach_private_scalar(mem_t(garbage, 4), out_pub, out_scalar), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, AttachPrivateScalarEmptyPubBlob) { + buf_t scalar(32); + buf_t Qi(33); + buf_t out_blob; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::attach_private_scalar(mem_t(), scalar, Qi, out_blob), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, AttachPrivateScalarWrongSizeScalar) { + const uint8_t pub[] = {0x01, 0x02, 0x03, 0x04}; + buf_t scalar_16(16); + buf_t Qi(33); + buf_t out_blob; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::attach_private_scalar(mem_t(pub, 4), scalar_16, Qi, out_blob), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, SignAdditiveInvalidSigReceiverNeg1) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + buf_t msg(32); + buf_t sig; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::sign_additive(job, mem_t(), msg, /*sig_receiver=*/-1, sig), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, SignAdditiveMsg31Bytes) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + buf_t msg(31); + buf_t sig; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::sign_additive(job, mem_t(), msg, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, SignAdditiveMsg33Bytes) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + buf_t msg(33); + buf_t sig; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::sign_additive(job, mem_t(), msg, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, SignAdditiveMsg0Bytes) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + buf_t sig; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::sign_additive(job, mem_t(), mem_t(), /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, SignAdditiveEmptyKeyBlob) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + buf_t msg(32); + buf_t sig; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::sign_additive(job, mem_t(), msg, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, SignAdditiveGarbageKeyBlob) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg(32); + buf_t sig; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::sign_additive(job, mem_t(garbage, 4), msg, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, RefreshAdditiveEmptyKeyBlob) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + buf_t sid, new_key; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::refresh_additive(job, sid, mem_t(), new_key), SUCCESS); +} + +TEST(ApiSchnorrMpNeg, RefreshAdditiveGarbageKeyBlob) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t sid, new_key; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::refresh_additive(job, sid, mem_t(garbage, 4), new_key), SUCCESS); +} diff --git a/tests/unit/api/test_schnorr_mp_ac.cpp b/tests/unit/api/test_schnorr_mp_ac.cpp new file mode 100644 index 00000000..eb77c996 --- /dev/null +++ b/tests/unit/api/test_schnorr_mp_ac.cpp @@ -0,0 +1,358 @@ +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; + +using coinbase::api::curve_id; +using coinbase::api::job_mp_t; +using coinbase::api::party_idx_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::api_harness::failing_transport_t; +using coinbase::testutils::api_harness::local_api_transport_t; +using coinbase::testutils::api_harness::run_mp; + +} // namespace + +TEST(ApiSchnorrMpAc, DkgRefreshSign4p) { + constexpr int n = 4; + + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector> transports; + transports.reserve(n); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + + std::vector names = {"p0", "p1", "p2", "p3"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + // THRESHOLD[2](p0, p1, p2, p3) + const coinbase::api::access_structure_t ac = + coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + }); + + // Only p0 and p1 actively contribute to the DKG/refresh. + const std::vector quorum_party_names = {names[0], names[1]}; + + std::vector key_blobs(n); + std::vector sids(n); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::schnorr_mp::dkg_ac(job, curve_id::secp256k1, sids[static_cast(i)], ac, + quorum_party_names, key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) EXPECT_EQ(sids[0], sids[static_cast(i)]); + + buf_t pub0; + ASSERT_EQ(coinbase::api::schnorr_mp::get_public_key_compressed(key_blobs[0], pub0), SUCCESS); + EXPECT_EQ(pub0.size(), 33); + for (int i = 1; i < n; i++) { + buf_t pub_i; + ASSERT_EQ(coinbase::api::schnorr_mp::get_public_key_compressed(key_blobs[static_cast(i)], pub_i), SUCCESS); + EXPECT_EQ(pub_i, pub0); + } + + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub0), SUCCESS); + + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + + std::vector> sign_peers = {peers[0], peers[1]}; + std::vector> sign_transports = {transports[0], transports[1]}; + + constexpr int quorum_n = 2; + std::vector sigs(quorum_n); + run_mp( + sign_peers, + [&](int i) { + job_mp_t job{static_cast(i), quorum_party_names, *sign_transports[static_cast(i)]}; + return coinbase::api::schnorr_mp::sign_ac(job, key_blobs[static_cast(i)], ac, msg, /*sig_receiver=*/0, + sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + ASSERT_EQ(sigs[0].size(), 64); + EXPECT_EQ(sigs[1].size(), 0); + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, msg, sigs[0]), SUCCESS); + + // Threshold refresh. + std::vector new_key_blobs(n); + std::vector refresh_sids(n); + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::schnorr_mp::refresh_ac(job, refresh_sids[static_cast(i)], + key_blobs[static_cast(i)], ac, quorum_party_names, + new_key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) EXPECT_EQ(refresh_sids[0], refresh_sids[static_cast(i)]); + + for (int i = 0; i < n; i++) { + buf_t pub_i; + ASSERT_EQ(coinbase::api::schnorr_mp::get_public_key_compressed(new_key_blobs[static_cast(i)], pub_i), + SUCCESS); + EXPECT_EQ(pub_i, pub0); + } + + std::vector sigs2(quorum_n); + run_mp( + sign_peers, + [&](int i) { + job_mp_t job{static_cast(i), quorum_party_names, *sign_transports[static_cast(i)]}; + return coinbase::api::schnorr_mp::sign_ac(job, new_key_blobs[static_cast(i)], ac, msg, + /*sig_receiver=*/0, sigs2[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + ASSERT_EQ(sigs2[0].size(), 64); + EXPECT_EQ(sigs2[1].size(), 0); + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, msg, sigs2[0]), SUCCESS); +} + +TEST(ApiSchnorrMpAc, KeyBlobPrivScalar_NoPubSign) { + constexpr int n = 4; + + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector> transports; + transports.reserve(n); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + + std::vector names = {"p0", "p1", "p2", "p3"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + // THRESHOLD[2](p0, p1, p2, p3) + const coinbase::api::access_structure_t ac = + coinbase::api::access_structure_t::Threshold(2, { + coinbase::api::access_structure_t::leaf(names[0]), + coinbase::api::access_structure_t::leaf(names[1]), + coinbase::api::access_structure_t::leaf(names[2]), + coinbase::api::access_structure_t::leaf(names[3]), + }); + const std::vector quorum_party_names = {names[0], names[1]}; + + std::vector key_blobs(n); + std::vector sids(n); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::schnorr_mp::dkg_ac(job, curve_id::secp256k1, sids[static_cast(i)], ac, + quorum_party_names, key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + + buf_t pub0; + ASSERT_EQ(coinbase::api::schnorr_mp::get_public_key_compressed(key_blobs[0], pub0), SUCCESS); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub0), SUCCESS); + + std::vector public_blobs(n); + std::vector x_fixed(n); + std::vector merged(n); + for (int i = 0; i < n; i++) { + ASSERT_EQ(coinbase::api::schnorr_mp::detach_private_scalar(key_blobs[static_cast(i)], public_blobs[i], + x_fixed[i]), + SUCCESS); + EXPECT_GT(public_blobs[i].size(), 0); + EXPECT_EQ(x_fixed[i].size(), 32); // secp256k1 order size + + buf_t Qi_full; + buf_t Qi_public; + ASSERT_EQ(coinbase::api::schnorr_mp::get_public_share_compressed(key_blobs[static_cast(i)], Qi_full), + SUCCESS); + ASSERT_EQ(coinbase::api::schnorr_mp::get_public_share_compressed(public_blobs[i], Qi_public), SUCCESS); + EXPECT_EQ(Qi_full, Qi_public); + + ASSERT_EQ(coinbase::api::schnorr_mp::attach_private_scalar(public_blobs[i], x_fixed[i], Qi_full, merged[i]), + SUCCESS); + EXPECT_GT(merged[i].size(), 0); + } + + // Public blob should not be usable for signing. + // Avoid spinning up a full protocol run here: sign_ac rejects at key blob parsing + // before any transport calls, so a single local call is sufficient. + buf_t msg(32); + for (int i = 0; i < msg.size(); i++) msg[i] = static_cast(i); + { + failing_transport_t t; + job_mp_t job{/*self=*/0, quorum_party_names, t}; + buf_t sig; + EXPECT_NE(coinbase::api::schnorr_mp::sign_ac(job, public_blobs[0], ac, msg, /*sig_receiver=*/0, sig), SUCCESS); + } + + // Merged blobs should be usable for signing. + std::vector> sign_peers = {peers[0], peers[1]}; + std::vector> sign_transports = {transports[0], transports[1]}; + constexpr int quorum_n = 2; + std::vector sigs(quorum_n); + run_mp( + sign_peers, + [&](int i) { + job_mp_t job{static_cast(i), quorum_party_names, *sign_transports[static_cast(i)]}; + return coinbase::api::schnorr_mp::sign_ac(job, merged[static_cast(i)], ac, msg, /*sig_receiver=*/0, + sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + ASSERT_EQ(sigs[0].size(), 64); + EXPECT_EQ(sigs[1].size(), 0); + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, msg, sigs[0]), SUCCESS); + + // Negative: merging the wrong scalar should fail. + buf_t Qi0; + ASSERT_EQ(coinbase::api::schnorr_mp::get_public_share_compressed(key_blobs[0], Qi0), SUCCESS); + buf_t bad_x = x_fixed[0]; + bad_x[0] ^= 0x01; + buf_t bad_merged; + EXPECT_NE(coinbase::api::schnorr_mp::attach_private_scalar(public_blobs[0], bad_x, Qi0, bad_merged), SUCCESS); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +TEST(ApiSchnorrMpAcNeg, DkgAcInvalidCurve) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p0"), coinbase::api::access_structure_t::leaf("p1"), + coinbase::api::access_structure_t::leaf("p2")}); + const std::vector quorum = {"p0", "p1"}; + buf_t key, sid; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::dkg_ac(job, curve_id(0), sid, ac, quorum, key), SUCCESS); + EXPECT_NE(coinbase::api::schnorr_mp::dkg_ac(job, curve_id(4), sid, ac, quorum, key), SUCCESS); + EXPECT_NE(coinbase::api::schnorr_mp::dkg_ac(job, curve_id(255), sid, ac, quorum, key), SUCCESS); +} + +TEST(ApiSchnorrMpAcNeg, SignAcMsg31Bytes) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p0"), coinbase::api::access_structure_t::leaf("p1"), + coinbase::api::access_structure_t::leaf("p2")}); + buf_t msg(31); + buf_t sig; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::sign_ac(job, coinbase::mem_t(), ac, msg, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiSchnorrMpAcNeg, SignAcMsg33Bytes) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p0"), coinbase::api::access_structure_t::leaf("p1"), + coinbase::api::access_structure_t::leaf("p2")}); + buf_t msg(33); + buf_t sig; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::sign_ac(job, coinbase::mem_t(), ac, msg, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiSchnorrMpAcNeg, SignAcMsg0Bytes) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p0"), coinbase::api::access_structure_t::leaf("p1"), + coinbase::api::access_structure_t::leaf("p2")}); + buf_t sig; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::sign_ac(job, coinbase::mem_t(), ac, coinbase::mem_t(), /*sig_receiver=*/0, sig), + SUCCESS); +} + +TEST(ApiSchnorrMpAcNeg, SignAcInvalidSigReceiverNeg1) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p0"), coinbase::api::access_structure_t::leaf("p1"), + coinbase::api::access_structure_t::leaf("p2")}); + buf_t msg(32); + buf_t sig; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::sign_ac(job, coinbase::mem_t(), ac, msg, /*sig_receiver=*/-1, sig), SUCCESS); +} + +TEST(ApiSchnorrMpAcNeg, SignAcInvalidSigReceiverTooLarge) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p0"), coinbase::api::access_structure_t::leaf("p1"), + coinbase::api::access_structure_t::leaf("p2")}); + buf_t msg(32); + buf_t sig; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::sign_ac(job, coinbase::mem_t(), ac, msg, /*sig_receiver=*/5, sig), SUCCESS); +} + +TEST(ApiSchnorrMpAcNeg, SignAcEmptyKeyBlob) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p0"), coinbase::api::access_structure_t::leaf("p1"), + coinbase::api::access_structure_t::leaf("p2")}); + buf_t msg(32); + buf_t sig; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::sign_ac(job, coinbase::mem_t(), ac, msg, /*sig_receiver=*/0, sig), SUCCESS); +} + +TEST(ApiSchnorrMpAcNeg, SignAcGarbageKeyBlob) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + const coinbase::api::access_structure_t ac = coinbase::api::access_structure_t::Threshold( + 2, {coinbase::api::access_structure_t::leaf("p0"), coinbase::api::access_structure_t::leaf("p1"), + coinbase::api::access_structure_t::leaf("p2")}); + const uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t msg(32); + buf_t sig; + dylog_disable_scope_t no_log_err; + EXPECT_NE(coinbase::api::schnorr_mp::sign_ac(job, coinbase::mem_t(garbage, 4), ac, msg, /*sig_receiver=*/0, sig), + SUCCESS); +} diff --git a/tests/unit/api/test_tdh2.cpp b/tests/unit/api/test_tdh2.cpp new file mode 100644 index 00000000..1064356e --- /dev/null +++ b/tests/unit/api/test_tdh2.cpp @@ -0,0 +1,570 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; + +using coinbase::api::access_structure_t; +using coinbase::api::curve_id; +using coinbase::api::job_mp_t; +using coinbase::api::party_idx_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::api_harness::failing_transport_t; +using coinbase::testutils::api_harness::local_api_transport_t; +using coinbase::testutils::api_harness::run_mp; + +static void exercise_dkg_round_trip(curve_id curve) { + constexpr int n = 3; + + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector> transports; + transports.reserve(n); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + + std::vector names = {"p0", "p1", "p2"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + std::vector public_keys(n); + std::vector> public_shares(n); + std::vector private_shares(n); + std::vector sids(n); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::tdh2::dkg_additive(job, curve, public_keys[static_cast(i)], + public_shares[static_cast(i)], + private_shares[static_cast(i)], sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) { + EXPECT_EQ(public_keys[0], public_keys[static_cast(i)]); + EXPECT_EQ(public_shares[0], public_shares[static_cast(i)]); + EXPECT_EQ(sids[0], sids[static_cast(i)]); + } + ASSERT_EQ(public_shares[0].size(), static_cast(n)); + + buf_t plaintext(32); + for (int i = 0; i < plaintext.size(); i++) plaintext[i] = static_cast(0xA5 ^ i); + const mem_t label("tdh2-label"); + + buf_t ciphertext; + ASSERT_EQ(coinbase::api::tdh2::encrypt(public_keys[0], plaintext, label, ciphertext), SUCCESS); + ASSERT_GT(ciphertext.size(), 0); + ASSERT_EQ(coinbase::api::tdh2::verify(public_keys[0], ciphertext, label), SUCCESS); + + std::vector partials(n); + for (int i = 0; i < n; i++) { + ASSERT_EQ(coinbase::api::tdh2::partial_decrypt(private_shares[static_cast(i)], ciphertext, label, + partials[static_cast(i)]), + SUCCESS); + } + + const auto public_shares_mems = buf_t::to_mems(public_shares[0]); + const auto partials_mems = buf_t::to_mems(partials); + + buf_t decrypted; + ASSERT_EQ(coinbase::api::tdh2::combine_additive(public_keys[0], public_shares_mems, label, partials_mems, ciphertext, + decrypted), + SUCCESS); + EXPECT_EQ(mem_t(decrypted), mem_t(plaintext)); + + const mem_t wrong_label("wrong-label"); + EXPECT_NE(coinbase::api::tdh2::verify(public_keys[0], ciphertext, wrong_label), SUCCESS); +} + +static void exercise_dkg_ac_round_trip(curve_id curve) { + constexpr int n = 3; + + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector> transports; + transports.reserve(n); + for (const auto& p : peers) transports.push_back(std::make_shared(p)); + + std::vector names = {"p0", "p1", "p2"}; + std::vector name_views; + name_views.reserve(names.size()); + for (const auto& name : names) name_views.emplace_back(name); + + const access_structure_t ac = access_structure_t::Threshold( + 2, {access_structure_t::leaf("p0"), access_structure_t::leaf("p1"), access_structure_t::leaf("p2")}); + const std::vector quorum = {"p0", "p1"}; + + std::vector public_keys(n); + std::vector> public_shares(n); + std::vector private_shares(n); + std::vector sids(n); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + job_mp_t job{static_cast(i), name_views, *transports[static_cast(i)]}; + return coinbase::api::tdh2::dkg_ac(job, curve, sids[static_cast(i)], ac, quorum, + public_keys[static_cast(i)], public_shares[static_cast(i)], + private_shares[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, SUCCESS); + for (int i = 1; i < n; i++) { + EXPECT_EQ(public_keys[0], public_keys[static_cast(i)]); + EXPECT_EQ(public_shares[0], public_shares[static_cast(i)]); + EXPECT_EQ(sids[0], sids[static_cast(i)]); + } + + buf_t plaintext(32); + for (int i = 0; i < plaintext.size(); i++) plaintext[i] = static_cast(0x3C ^ i); + const mem_t label("tdh2-label"); + + buf_t ciphertext; + ASSERT_EQ(coinbase::api::tdh2::encrypt(public_keys[0], plaintext, label, ciphertext), SUCCESS); + ASSERT_EQ(coinbase::api::tdh2::verify(public_keys[0], ciphertext, label), SUCCESS); + + // Collect partial decryptions from a 2-party quorum. + buf_t partial0; + buf_t partial1; + ASSERT_EQ(coinbase::api::tdh2::partial_decrypt(private_shares[0], ciphertext, label, partial0), SUCCESS); + ASSERT_EQ(coinbase::api::tdh2::partial_decrypt(private_shares[1], ciphertext, label, partial1), SUCCESS); + + const std::vector partial_names = {"p0", "p1"}; + const std::vector partial_bufs = {partial0, partial1}; + + const auto public_shares_mems = buf_t::to_mems(public_shares[0]); + const auto partials_mems = buf_t::to_mems(partial_bufs); + + buf_t decrypted; + ASSERT_EQ(coinbase::api::tdh2::combine_ac(ac, public_keys[0], name_views, public_shares_mems, label, partial_names, + partials_mems, ciphertext, decrypted), + SUCCESS); + EXPECT_EQ(mem_t(decrypted), mem_t(plaintext)); + + // Not enough shares should fail. + buf_t decrypted2; + const std::vector one_name = {"p0"}; + const std::vector one_partial = {partial0}; + const auto one_partial_mems = buf_t::to_mems(one_partial); + EXPECT_NE(coinbase::api::tdh2::combine_ac(ac, public_keys[0], name_views, public_shares_mems, label, one_name, + one_partial_mems, ciphertext, decrypted2), + SUCCESS); +} + +} // namespace + +TEST(ApiTdh2, RoundTripEncryptDecrypt) { + exercise_dkg_round_trip(coinbase::api::curve_id::secp256k1); + exercise_dkg_round_trip(coinbase::api::curve_id::p256); + exercise_dkg_ac_round_trip(coinbase::api::curve_id::secp256k1); + exercise_dkg_ac_round_trip(coinbase::api::curve_id::p256); +} + +TEST(ApiTdh2, InvalidCurveRejected) { + failing_transport_t t; + std::vector names = {"p0", "p1"}; + job_mp_t job{/*self=*/0, names, t}; + + buf_t public_key; + std::vector public_shares; + buf_t private_share; + buf_t sid; + EXPECT_EQ( + coinbase::api::tdh2::dkg_additive(job, static_cast(42), public_key, public_shares, private_share, sid), + E_BADARG); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +#include + +TEST(ApiTdh2Neg, DkgAdditiveInvalidCurve0) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + buf_t public_key; + std::vector public_shares; + buf_t private_share; + buf_t sid; + EXPECT_NE( + coinbase::api::tdh2::dkg_additive(job, static_cast(0), public_key, public_shares, private_share, sid), + SUCCESS); +} + +TEST(ApiTdh2Neg, DkgAdditiveInvalidCurve255) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + buf_t public_key; + std::vector public_shares; + buf_t private_share; + buf_t sid; + EXPECT_NE( + coinbase::api::tdh2::dkg_additive(job, static_cast(255), public_key, public_shares, private_share, sid), + SUCCESS); +} + +TEST(ApiTdh2Neg, DkgAdditiveEd25519Rejected) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + buf_t public_key; + std::vector public_shares; + buf_t private_share; + buf_t sid; + EXPECT_NE(coinbase::api::tdh2::dkg_additive(job, curve_id::ed25519, public_key, public_shares, private_share, sid), + SUCCESS); +} + +TEST(ApiTdh2Neg, DkgAdditiveEmptyPartyList) { + failing_transport_t ft; + std::vector names = {}; + job_mp_t job{/*self=*/0, names, ft}; + buf_t public_key; + std::vector public_shares; + buf_t private_share; + buf_t sid; + EXPECT_NE(coinbase::api::tdh2::dkg_additive(job, curve_id::secp256k1, public_key, public_shares, private_share, sid), + SUCCESS); +} + +TEST(ApiTdh2Neg, DkgAdditiveSingleParty) { + failing_transport_t ft; + std::vector names = {"p0"}; + job_mp_t job{/*self=*/0, names, ft}; + buf_t public_key; + std::vector public_shares; + buf_t private_share; + buf_t sid; + EXPECT_NE(coinbase::api::tdh2::dkg_additive(job, curve_id::secp256k1, public_key, public_shares, private_share, sid), + SUCCESS); +} + +TEST(ApiTdh2Neg, DkgAcInvalidCurve0) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + const access_structure_t ac = access_structure_t::Threshold( + 2, {access_structure_t::leaf("p0"), access_structure_t::leaf("p1"), access_structure_t::leaf("p2")}); + const std::vector quorum = {"p0", "p1"}; + buf_t public_key; + std::vector public_shares; + buf_t private_share; + buf_t sid; + EXPECT_NE(coinbase::api::tdh2::dkg_ac(job, static_cast(0), sid, ac, quorum, public_key, public_shares, + private_share), + SUCCESS); +} + +TEST(ApiTdh2Neg, DkgAcEd25519Rejected) { + failing_transport_t ft; + std::vector names = {"p0", "p1", "p2"}; + job_mp_t job{/*self=*/0, names, ft}; + const access_structure_t ac = access_structure_t::Threshold( + 2, {access_structure_t::leaf("p0"), access_structure_t::leaf("p1"), access_structure_t::leaf("p2")}); + const std::vector quorum = {"p0", "p1"}; + buf_t public_key; + std::vector public_shares; + buf_t private_share; + buf_t sid; + EXPECT_NE( + coinbase::api::tdh2::dkg_ac(job, curve_id::ed25519, sid, ac, quorum, public_key, public_shares, private_share), + SUCCESS); +} + +TEST(ApiTdh2Neg, EncryptEmptyPublicKey) { + dylog_disable_scope_t no_log; + buf_t ct; + buf_t pt(16); + EXPECT_NE(coinbase::api::tdh2::encrypt(mem_t(), pt, mem_t("label"), ct), SUCCESS); +} + +TEST(ApiTdh2Neg, EncryptGarbagePublicKey) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t ct; + buf_t pt(16); + EXPECT_NE(coinbase::api::tdh2::encrypt(mem_t(garbage, 4), pt, mem_t("label"), ct), SUCCESS); +} + +TEST(ApiTdh2Neg, EncryptEmptyPlaintext) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t ct; + EXPECT_NE(coinbase::api::tdh2::encrypt(mem_t(garbage, 4), mem_t(), mem_t("label"), ct), SUCCESS); +} + +TEST(ApiTdh2Neg, EncryptEmptyLabel) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t ct; + buf_t pt(16); + EXPECT_NE(coinbase::api::tdh2::encrypt(mem_t(garbage, 4), pt, mem_t(), ct), SUCCESS); +} + +TEST(ApiTdh2Neg, VerifyEmptyPublicKey) { + dylog_disable_scope_t no_log; + EXPECT_NE(coinbase::api::tdh2::verify(mem_t(), mem_t("ct"), mem_t("label")), SUCCESS); +} + +TEST(ApiTdh2Neg, VerifyGarbagePublicKey) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + EXPECT_NE(coinbase::api::tdh2::verify(mem_t(garbage, 4), mem_t("ct"), mem_t("label")), SUCCESS); +} + +TEST(ApiTdh2Neg, VerifyEmptyCiphertext) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + EXPECT_NE(coinbase::api::tdh2::verify(mem_t(garbage, 4), mem_t(), mem_t("label")), SUCCESS); +} + +TEST(ApiTdh2Neg, VerifyGarbageCiphertext) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + EXPECT_NE(coinbase::api::tdh2::verify(mem_t(garbage, 4), mem_t(garbage, 4), mem_t("label")), SUCCESS); +} + +TEST(ApiTdh2Neg, VerifyEmptyLabel) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + EXPECT_NE(coinbase::api::tdh2::verify(mem_t(garbage, 4), mem_t("ct"), mem_t()), SUCCESS); +} + +TEST(ApiTdh2Neg, PartialDecryptEmptyPrivateShare) { + dylog_disable_scope_t no_log; + buf_t pd; + EXPECT_NE(coinbase::api::tdh2::partial_decrypt(mem_t(), mem_t("ct"), mem_t("label"), pd), SUCCESS); +} + +TEST(ApiTdh2Neg, PartialDecryptGarbagePrivateShare) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t pd; + EXPECT_NE(coinbase::api::tdh2::partial_decrypt(mem_t(garbage, 4), mem_t("ct"), mem_t("label"), pd), SUCCESS); +} + +TEST(ApiTdh2Neg, PartialDecryptEmptyCiphertext) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t pd; + EXPECT_NE(coinbase::api::tdh2::partial_decrypt(mem_t(garbage, 4), mem_t(), mem_t("label"), pd), SUCCESS); +} + +TEST(ApiTdh2Neg, PartialDecryptGarbageCiphertext) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t pd; + EXPECT_NE(coinbase::api::tdh2::partial_decrypt(mem_t(garbage, 4), mem_t(garbage, 4), mem_t("label"), pd), SUCCESS); +} + +TEST(ApiTdh2Neg, PartialDecryptEmptyLabel) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t pd; + EXPECT_NE(coinbase::api::tdh2::partial_decrypt(mem_t(garbage, 4), mem_t("ct"), mem_t(), pd), SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAdditiveEmptyPublicKey) { + dylog_disable_scope_t no_log; + std::vector ps = {mem_t("s1")}; + std::vector pd = {mem_t("d1")}; + buf_t pt; + EXPECT_NE(coinbase::api::tdh2::combine_additive(mem_t(), ps, mem_t("label"), pd, mem_t("ct"), pt), SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAdditiveGarbagePublicKey) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + std::vector ps = {mem_t("s1")}; + std::vector pd = {mem_t("d1")}; + buf_t pt; + EXPECT_NE(coinbase::api::tdh2::combine_additive(mem_t(garbage, 4), ps, mem_t("label"), pd, mem_t("ct"), pt), SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAdditiveEmptyCiphertext) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + std::vector ps = {mem_t("s1")}; + std::vector pd = {mem_t("d1")}; + buf_t pt; + EXPECT_NE(coinbase::api::tdh2::combine_additive(mem_t(garbage, 4), ps, mem_t("label"), pd, mem_t(), pt), SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAdditiveGarbageCiphertext) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + std::vector ps = {mem_t("s1")}; + std::vector pd = {mem_t("d1")}; + buf_t pt; + EXPECT_NE(coinbase::api::tdh2::combine_additive(mem_t(garbage, 4), ps, mem_t("label"), pd, mem_t(garbage, 4), pt), + SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAdditiveEmptyLabel) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + std::vector ps = {mem_t("s1")}; + std::vector pd = {mem_t("d1")}; + buf_t pt; + EXPECT_NE(coinbase::api::tdh2::combine_additive(mem_t(garbage, 4), ps, mem_t(), pd, mem_t("ct"), pt), SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAdditiveSizeMismatch) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + std::vector ps = {mem_t("s1"), mem_t("s2")}; + std::vector pd = {mem_t("d1")}; + buf_t pt; + EXPECT_NE(coinbase::api::tdh2::combine_additive(mem_t(garbage, 4), ps, mem_t("label"), pd, mem_t("ct"), pt), SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAdditiveEmptyVectors) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + std::vector ps; + std::vector pd; + buf_t pt; + EXPECT_NE(coinbase::api::tdh2::combine_additive(mem_t(garbage, 4), ps, mem_t("label"), pd, mem_t("ct"), pt), SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAcEmptyPublicKey) { + dylog_disable_scope_t no_log; + const access_structure_t ac = access_structure_t::Threshold( + 2, {access_structure_t::leaf("p0"), access_structure_t::leaf("p1"), access_structure_t::leaf("p2")}); + std::vector party_names = {"p0", "p1", "p2"}; + std::vector ps = {mem_t("s0"), mem_t("s1"), mem_t("s2")}; + std::vector pn = {"p0", "p1"}; + std::vector pd = {mem_t("d0"), mem_t("d1")}; + buf_t pt; + EXPECT_NE(coinbase::api::tdh2::combine_ac(ac, mem_t(), party_names, ps, mem_t("label"), pn, pd, mem_t("ct"), pt), + SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAcGarbagePublicKey) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + const access_structure_t ac = access_structure_t::Threshold( + 2, {access_structure_t::leaf("p0"), access_structure_t::leaf("p1"), access_structure_t::leaf("p2")}); + std::vector party_names = {"p0", "p1", "p2"}; + std::vector ps = {mem_t("s0"), mem_t("s1"), mem_t("s2")}; + std::vector pn = {"p0", "p1"}; + std::vector pd = {mem_t("d0"), mem_t("d1")}; + buf_t pt; + EXPECT_NE( + coinbase::api::tdh2::combine_ac(ac, mem_t(garbage, 4), party_names, ps, mem_t("label"), pn, pd, mem_t("ct"), pt), + SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAcEmptyCiphertext) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + const access_structure_t ac = access_structure_t::Threshold( + 2, {access_structure_t::leaf("p0"), access_structure_t::leaf("p1"), access_structure_t::leaf("p2")}); + std::vector party_names = {"p0", "p1", "p2"}; + std::vector ps = {mem_t("s0"), mem_t("s1"), mem_t("s2")}; + std::vector pn = {"p0", "p1"}; + std::vector pd = {mem_t("d0"), mem_t("d1")}; + buf_t pt; + EXPECT_NE( + coinbase::api::tdh2::combine_ac(ac, mem_t(garbage, 4), party_names, ps, mem_t("label"), pn, pd, mem_t(), pt), + SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAcEmptyLabel) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + const access_structure_t ac = access_structure_t::Threshold( + 2, {access_structure_t::leaf("p0"), access_structure_t::leaf("p1"), access_structure_t::leaf("p2")}); + std::vector party_names = {"p0", "p1", "p2"}; + std::vector ps = {mem_t("s0"), mem_t("s1"), mem_t("s2")}; + std::vector pn = {"p0", "p1"}; + std::vector pd = {mem_t("d0"), mem_t("d1")}; + buf_t pt; + EXPECT_NE(coinbase::api::tdh2::combine_ac(ac, mem_t(garbage, 4), party_names, ps, mem_t(), pn, pd, mem_t("ct"), pt), + SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAcPartySharesSizeMismatch) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + const access_structure_t ac = access_structure_t::Threshold( + 2, {access_structure_t::leaf("p0"), access_structure_t::leaf("p1"), access_structure_t::leaf("p2")}); + std::vector party_names = {"p0", "p1", "p2"}; + std::vector ps = {mem_t("s0"), mem_t("s1")}; + std::vector pn = {"p0", "p1"}; + std::vector pd = {mem_t("d0"), mem_t("d1")}; + buf_t pt; + EXPECT_NE( + coinbase::api::tdh2::combine_ac(ac, mem_t(garbage, 4), party_names, ps, mem_t("label"), pn, pd, mem_t("ct"), pt), + SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAcPartialsSizeMismatch) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + const access_structure_t ac = access_structure_t::Threshold( + 2, {access_structure_t::leaf("p0"), access_structure_t::leaf("p1"), access_structure_t::leaf("p2")}); + std::vector party_names = {"p0", "p1", "p2"}; + std::vector ps = {mem_t("s0"), mem_t("s1"), mem_t("s2")}; + std::vector pn = {"p0", "p1"}; + std::vector pd = {mem_t("d0")}; + buf_t pt; + EXPECT_NE( + coinbase::api::tdh2::combine_ac(ac, mem_t(garbage, 4), party_names, ps, mem_t("label"), pn, pd, mem_t("ct"), pt), + SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAcEmptyPartyName) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + const access_structure_t ac = access_structure_t::Threshold( + 2, {access_structure_t::leaf("p0"), access_structure_t::leaf("p1"), access_structure_t::leaf("p2")}); + std::vector party_names = {"p0", "", "p2"}; + std::vector ps = {mem_t("s0"), mem_t("s1"), mem_t("s2")}; + std::vector pn = {"p0", "p1"}; + std::vector pd = {mem_t("d0"), mem_t("d1")}; + buf_t pt; + EXPECT_NE( + coinbase::api::tdh2::combine_ac(ac, mem_t(garbage, 4), party_names, ps, mem_t("label"), pn, pd, mem_t("ct"), pt), + SUCCESS); +} + +TEST(ApiTdh2Neg, CombineAcDuplicatePartyName) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + const access_structure_t ac = access_structure_t::Threshold( + 2, {access_structure_t::leaf("p0"), access_structure_t::leaf("p1"), access_structure_t::leaf("p2")}); + std::vector party_names = {"p0", "p0", "p2"}; + std::vector ps = {mem_t("s0"), mem_t("s1"), mem_t("s2")}; + std::vector pn = {"p0", "p1"}; + std::vector pd = {mem_t("d0"), mem_t("d1")}; + buf_t pt; + EXPECT_NE( + coinbase::api::tdh2::combine_ac(ac, mem_t(garbage, 4), party_names, ps, mem_t("label"), pn, pd, mem_t("ct"), pt), + SUCCESS); +} diff --git a/tests/unit/api/test_transport_harness.h b/tests/unit/api/test_transport_harness.h new file mode 100644 index 00000000..82beb1b5 --- /dev/null +++ b/tests/unit/api/test_transport_harness.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include "utils/local_network/network_context.h" + +namespace coinbase::testutils::api_harness { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; +using coinbase::api::data_transport_i; +using coinbase::api::party_idx_t; +using coinbase::testutils::mpc_net_context_t; + +class local_api_transport_t final : public data_transport_i { + public: + explicit local_api_transport_t(std::shared_ptr ctx) : ctx_(std::move(ctx)) {} + + error_t send(party_idx_t receiver, mem_t msg) override { + ctx_->send(receiver, msg); + return SUCCESS; + } + + error_t receive(party_idx_t sender, buf_t& msg) override { return ctx_->receive(sender, msg); } + + error_t receive_all(const std::vector& senders, std::vector& msgs) override { + std::vector s; + s.reserve(senders.size()); + for (auto x : senders) s.push_back(static_cast(x)); + return ctx_->receive_all(s, msgs); + } + + private: + std::shared_ptr ctx_; +}; + +template +inline void run_2pc(const std::shared_ptr& c1, const std::shared_ptr& c2, F1&& f1, + F2&& f2, error_t& out_rv1, error_t& out_rv2) { + c1->reset(); + c2->reset(); + + std::atomic aborted{false}; + + std::thread t1([&] { + out_rv1 = f1(); + if (out_rv1 && !aborted.exchange(true)) { + c1->abort(); + c2->abort(); + } + }); + std::thread t2([&] { + out_rv2 = f2(); + if (out_rv2 && !aborted.exchange(true)) { + c1->abort(); + c2->abort(); + } + }); + + t1.join(); + t2.join(); +} + +template +inline void run_mp(const std::vector>& peers, F&& f, std::vector& out_rv) { + for (const auto& p : peers) p->reset(); + + out_rv.assign(peers.size(), UNINITIALIZED_ERROR); + std::atomic aborted{false}; + std::vector threads; + threads.reserve(peers.size()); + + for (size_t i = 0; i < peers.size(); i++) { + threads.emplace_back([&, i] { + out_rv[i] = f(static_cast(i)); + if (out_rv[i] && !aborted.exchange(true)) { + for (const auto& p : peers) p->abort(); + } + }); + } + for (auto& t : threads) t.join(); +} + +class failing_transport_t final : public data_transport_i { + public: + error_t send(party_idx_t /*receiver*/, mem_t /*msg*/) override { return E_GENERAL; } + error_t receive(party_idx_t /*sender*/, buf_t& /*msg*/) override { return E_GENERAL; } + error_t receive_all(const std::vector& /*senders*/, std::vector& /*msgs*/) override { + return E_GENERAL; + } +}; + +} // namespace coinbase::testutils::api_harness diff --git a/tests/unit/c_api/test_curve_validation.cpp b/tests/unit/c_api/test_curve_validation.cpp new file mode 100644 index 00000000..3bd5b114 --- /dev/null +++ b/tests/unit/c_api/test_curve_validation.cpp @@ -0,0 +1,69 @@ +#include + +#include +#include +#include +#include + +namespace { + +static cbmpc_error_t dummy_send(void* /*ctx*/, int32_t /*receiver*/, const uint8_t* /*data*/, int /*size*/) { + return E_GENERAL; +} + +static cbmpc_error_t dummy_receive(void* /*ctx*/, int32_t /*sender*/, cmem_t* /*out_msg*/) { return E_GENERAL; } + +static cbmpc_error_t dummy_receive_all(void* /*ctx*/, const int32_t* /*senders*/, int /*senders_count*/, + cmems_t* /*out_msgs*/) { + return E_GENERAL; +} + +} // namespace + +TEST(CApiCurveValidation, Ecdsa2pcRejectsInvalidCurve) { + const cbmpc_transport_t t = { + /*ctx=*/nullptr, + /*send=*/dummy_send, + /*receive=*/dummy_receive, + /*receive_all=*/nullptr, + /*free=*/nullptr, + }; + + cmem_t key_blob = {reinterpret_cast(0x1), 123}; + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", "p2", &t}; + const cbmpc_error_t rv = cbmpc_ecdsa_2p_dkg(&job, static_cast(42), &key_blob); + EXPECT_EQ(rv, E_BADARG); + EXPECT_EQ(key_blob.data, nullptr); + EXPECT_EQ(key_blob.size, 0); +} + +TEST(CApiCurveValidation, Tdh2RejectsInvalidCurve) { + const cbmpc_transport_t t = { + /*ctx=*/nullptr, + /*send=*/dummy_send, + /*receive=*/dummy_receive, + /*receive_all=*/dummy_receive_all, + /*free=*/nullptr, + }; + + const char* names[2] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {/*self=*/0, /*party_names=*/names, /*party_names_count=*/2, /*transport=*/&t}; + + cmem_t pk = {reinterpret_cast(0x1), 123}; + cmems_t pub_shares = {123, reinterpret_cast(0x1), reinterpret_cast(0x1)}; + cmem_t priv_share = {reinterpret_cast(0x1), 123}; + cmem_t sid = {reinterpret_cast(0x1), 123}; + + const cbmpc_error_t rv = + cbmpc_tdh2_dkg_additive(&job, static_cast(42), &pk, &pub_shares, &priv_share, &sid); + EXPECT_EQ(rv, E_BADARG); + EXPECT_EQ(pk.data, nullptr); + EXPECT_EQ(pk.size, 0); + EXPECT_EQ(pub_shares.count, 0); + EXPECT_EQ(pub_shares.data, nullptr); + EXPECT_EQ(pub_shares.sizes, nullptr); + EXPECT_EQ(priv_share.data, nullptr); + EXPECT_EQ(priv_share.size, 0); + EXPECT_EQ(sid.data, nullptr); + EXPECT_EQ(sid.size, 0); +} diff --git a/tests/unit/c_api/test_ecdsa2pc.cpp b/tests/unit/c_api/test_ecdsa2pc.cpp new file mode 100644 index 00000000..9c8c9cdc --- /dev/null +++ b/tests/unit/c_api/test_ecdsa2pc.cpp @@ -0,0 +1,861 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "utils/local_network/network_context.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; + +using coinbase::api::party_idx_t; +using coinbase::testutils::mpc_net_context_t; + +struct transport_ctx_t { + std::shared_ptr net; + std::atomic* free_calls = nullptr; +}; + +static cbmpc_error_t transport_send(void* ctx, int32_t receiver, const uint8_t* data, int size) { + if (!ctx) return E_BADARG; + if (size < 0) return E_BADARG; + if (size > 0 && !data) return E_BADARG; + auto* c = static_cast(ctx); + c->net->send(static_cast(receiver), mem_t(data, size)); + return CBMPC_SUCCESS; +} + +static cbmpc_error_t transport_receive(void* ctx, int32_t sender, cmem_t* out_msg) { + if (!out_msg) return E_BADARG; + *out_msg = cmem_t{nullptr, 0}; + if (!ctx) return E_BADARG; + + auto* c = static_cast(ctx); + buf_t msg; + const error_t rv = c->net->receive(static_cast(sender), msg); + if (rv) return rv; + + const int n = msg.size(); + if (n < 0) return E_FORMAT; + if (n == 0) return CBMPC_SUCCESS; + + out_msg->data = static_cast(cbmpc_malloc(static_cast(n))); + if (!out_msg->data) return E_INSUFFICIENT; + out_msg->size = n; + std::memmove(out_msg->data, msg.data(), static_cast(n)); + return CBMPC_SUCCESS; +} + +static cbmpc_error_t transport_receive_all(void* ctx, const int32_t* senders, int senders_count, cmems_t* out_msgs) { + if (!out_msgs) return E_BADARG; + *out_msgs = cmems_t{0, nullptr, nullptr}; + if (!ctx) return E_BADARG; + if (senders_count < 0) return E_BADARG; + if (senders_count > 0 && !senders) return E_BADARG; + + auto* c = static_cast(ctx); + std::vector s; + s.reserve(static_cast(senders_count)); + for (int i = 0; i < senders_count; i++) s.push_back(static_cast(senders[i])); + + std::vector msgs; + const error_t rv = c->net->receive_all(s, msgs); + if (rv) return rv; + if (msgs.size() != static_cast(senders_count)) return E_GENERAL; + + // Flatten into (data + sizes) buffers. + int total = 0; + for (const auto& m : msgs) { + const int sz = m.size(); + if (sz < 0) return E_FORMAT; + if (sz > INT_MAX - total) return E_RANGE; + total += sz; + } + + out_msgs->count = senders_count; + out_msgs->sizes = static_cast(cbmpc_malloc(sizeof(int) * static_cast(senders_count))); + if (!out_msgs->sizes) { + *out_msgs = cmems_t{0, nullptr, nullptr}; + return E_INSUFFICIENT; + } + + if (total > 0) { + out_msgs->data = static_cast(cbmpc_malloc(static_cast(total))); + if (!out_msgs->data) { + cbmpc_free(out_msgs->sizes); + *out_msgs = cmems_t{0, nullptr, nullptr}; + return E_INSUFFICIENT; + } + } + + int offset = 0; + for (int i = 0; i < senders_count; i++) { + const int sz = msgs[i].size(); + out_msgs->sizes[i] = sz; + if (sz) { + std::memmove(out_msgs->data + offset, msgs[i].data(), static_cast(sz)); + offset += sz; + } + } + + return CBMPC_SUCCESS; +} + +static void transport_free(void* ctx, void* ptr) { + if (!ptr) return; + auto* c = static_cast(ctx); + if (c && c->free_calls) c->free_calls->fetch_add(1); + cbmpc_free(ptr); +} + +template +static void run_2pc(const std::shared_ptr& c1, const std::shared_ptr& c2, F1&& f1, + F2&& f2, cbmpc_error_t& out_rv1, cbmpc_error_t& out_rv2) { + c1->reset(); + c2->reset(); + + std::atomic aborted{false}; + + std::thread t1([&] { + out_rv1 = f1(); + if (out_rv1 && !aborted.exchange(true)) { + c1->abort(); + c2->abort(); + } + }); + std::thread t2([&] { + out_rv2 = f2(); + if (out_rv2 && !aborted.exchange(true)) { + c1->abort(); + c2->abort(); + } + }); + + t1.join(); + t2.join(); +} + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +} // namespace + +TEST(CApiEcdsa2pc, DkgSignRefreshSign) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + std::atomic free_calls_1{0}; + std::atomic free_calls_2{0}; + transport_ctx_t ctx1{c1, &free_calls_1}; + transport_ctx_t ctx2{c2, &free_calls_2}; + + const cbmpc_transport_t t1 = { + /*ctx=*/&ctx1, + /*send=*/transport_send, + /*receive=*/transport_receive, + /*receive_all=*/transport_receive_all, + /*free=*/transport_free, + }; + const cbmpc_transport_t t2 = { + /*ctx=*/&ctx2, + /*send=*/transport_send, + /*receive=*/transport_receive, + /*receive_all=*/transport_receive_all, + /*free=*/transport_free, + }; + + cmem_t key_blob_1{nullptr, 0}; + cmem_t key_blob_2{nullptr, 0}; + cbmpc_error_t rv1 = UNINITIALIZED_ERROR; + cbmpc_error_t rv2 = UNINITIALIZED_ERROR; + + const cbmpc_2pc_job_t job1 = {CBMPC_2PC_P1, "p1", "p2", &t1}; + const cbmpc_2pc_job_t job2 = {CBMPC_2PC_P2, "p1", "p2", &t2}; + run_2pc( + c1, c2, [&] { return cbmpc_ecdsa_2p_dkg(&job1, CBMPC_CURVE_SECP256K1, &key_blob_1); }, + [&] { return cbmpc_ecdsa_2p_dkg(&job2, CBMPC_CURVE_SECP256K1, &key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, CBMPC_SUCCESS); + ASSERT_EQ(rv2, CBMPC_SUCCESS); + ASSERT_GT(key_blob_1.size, 0); + ASSERT_GT(key_blob_2.size, 0); + + cmem_t pub1{nullptr, 0}; + cmem_t pub2{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_get_public_key_compressed(key_blob_1, &pub1), CBMPC_SUCCESS); + ASSERT_EQ(cbmpc_ecdsa_2p_get_public_key_compressed(key_blob_2, &pub2), CBMPC_SUCCESS); + expect_eq(pub1, pub2); + + uint8_t msg_hash_bytes[32]; + for (int i = 0; i < 32; i++) msg_hash_bytes[i] = static_cast(i); + const cmem_t msg_hash = {msg_hash_bytes, 32}; + const cmem_t sid_in = {nullptr, 0}; + + cmem_t sid_out1{nullptr, 0}; + cmem_t sid_out2{nullptr, 0}; + cmem_t sig1{nullptr, 0}; + cmem_t sig2{nullptr, 0}; + + run_2pc( + c1, c2, [&] { return cbmpc_ecdsa_2p_sign(&job1, key_blob_1, msg_hash, sid_in, &sid_out1, &sig1); }, + [&] { return cbmpc_ecdsa_2p_sign(&job2, key_blob_2, msg_hash, sid_in, &sid_out2, &sig2); }, rv1, rv2); + ASSERT_EQ(rv1, CBMPC_SUCCESS); + ASSERT_EQ(rv2, CBMPC_SUCCESS); + ASSERT_GT(sig1.size, 0); + ASSERT_EQ(sig2.size, 0); + expect_eq(sid_out1, sid_out2); + + // Verify signature against the returned public key. + const buf_t pub_buf(pub1.data, pub1.size); + const buf_t hash_buf(msg_hash_bytes, 32); + const buf_t sig_buf(sig1.data, sig1.size); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub_buf), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + ASSERT_EQ(verify_key.verify(hash_buf, sig_buf), SUCCESS); + + cmem_t new_key_blob_1{nullptr, 0}; + cmem_t new_key_blob_2{nullptr, 0}; + run_2pc( + c1, c2, [&] { return cbmpc_ecdsa_2p_refresh(&job1, key_blob_1, &new_key_blob_1); }, + [&] { return cbmpc_ecdsa_2p_refresh(&job2, key_blob_2, &new_key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, CBMPC_SUCCESS); + ASSERT_EQ(rv2, CBMPC_SUCCESS); + + cmem_t pub3{nullptr, 0}; + cmem_t pub4{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_get_public_key_compressed(new_key_blob_1, &pub3), CBMPC_SUCCESS); + ASSERT_EQ(cbmpc_ecdsa_2p_get_public_key_compressed(new_key_blob_2, &pub4), CBMPC_SUCCESS); + expect_eq(pub3, pub4); + expect_eq(pub1, pub3); + + // The adapter must free buffers returned by receive/receive_all using our callback. + EXPECT_GT(free_calls_1.load(), 0); + EXPECT_GT(free_calls_2.load(), 0); + + cbmpc_cmem_free(pub1); + cbmpc_cmem_free(pub2); + cbmpc_cmem_free(pub3); + cbmpc_cmem_free(pub4); + cbmpc_cmem_free(sig1); + cbmpc_cmem_free(sig2); + cbmpc_cmem_free(sid_out1); + cbmpc_cmem_free(sid_out2); + cbmpc_cmem_free(key_blob_1); + cbmpc_cmem_free(key_blob_2); + cbmpc_cmem_free(new_key_blob_1); + cbmpc_cmem_free(new_key_blob_2); +} + +TEST(CApiEcdsa2pc, ValidatesArgs) { + cmem_t out{reinterpret_cast(0x1), 123}; + const cbmpc_2pc_job_t bad_job = {CBMPC_2PC_P1, "p1", "p2", nullptr}; + EXPECT_EQ(cbmpc_ecdsa_2p_dkg(&bad_job, CBMPC_CURVE_SECP256K1, &out), E_BADARG); + EXPECT_EQ(out.data, nullptr); + EXPECT_EQ(out.size, 0); + + // Missing sig_der_out is invalid. + EXPECT_EQ(cbmpc_ecdsa_2p_sign(nullptr, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, nullptr, nullptr), + E_BADARG); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +// ========================================================================== +// Negative test helpers +// ========================================================================== + +namespace { + +static cbmpc_error_t noop_send(void*, int32_t, const uint8_t*, int) { return E_GENERAL; } +static cbmpc_error_t noop_receive(void*, int32_t, cmem_t*) { return E_GENERAL; } +static cbmpc_error_t noop_receive_all(void*, const int32_t*, int, cmems_t*) { return E_GENERAL; } + +static const cbmpc_transport_t noop_capi_transport = {nullptr, noop_send, noop_receive, noop_receive_all, nullptr}; + +static void capi_generate_key_blobs(cbmpc_curve_id_t curve, cmem_t& blob1, cmem_t& blob2) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + transport_ctx_t ctx1{c1, nullptr}; + transport_ctx_t ctx2{c2, nullptr}; + + const cbmpc_transport_t t1 = {&ctx1, transport_send, transport_receive, transport_receive_all, transport_free}; + const cbmpc_transport_t t2 = {&ctx2, transport_send, transport_receive, transport_receive_all, transport_free}; + + blob1 = {nullptr, 0}; + blob2 = {nullptr, 0}; + cbmpc_error_t rv1 = UNINITIALIZED_ERROR; + cbmpc_error_t rv2 = UNINITIALIZED_ERROR; + + const cbmpc_2pc_job_t job1 = {CBMPC_2PC_P1, "p1", "p2", &t1}; + const cbmpc_2pc_job_t job2 = {CBMPC_2PC_P2, "p1", "p2", &t2}; + run_2pc( + c1, c2, [&] { return cbmpc_ecdsa_2p_dkg(&job1, curve, &blob1); }, + [&] { return cbmpc_ecdsa_2p_dkg(&job2, curve, &blob2); }, rv1, rv2); + ASSERT_EQ(rv1, CBMPC_SUCCESS); + ASSERT_EQ(rv2, CBMPC_SUCCESS); + ASSERT_GT(blob1.size, 0); + ASSERT_GT(blob2.size, 0); +} + +} // namespace + +class CApiEcdsa2pcNegWithBlobs : public ::testing::Test { + protected: + static void SetUpTestSuite() { capi_generate_key_blobs(CBMPC_CURVE_SECP256K1, blob1_, blob2_); } + + static void TearDownTestSuite() { + cbmpc_cmem_free(blob1_); + cbmpc_cmem_free(blob2_); + blob1_ = {nullptr, 0}; + blob2_ = {nullptr, 0}; + } + + static cmem_t blob1_; + static cmem_t blob2_; +}; + +cmem_t CApiEcdsa2pcNegWithBlobs::blob1_ = {nullptr, 0}; +cmem_t CApiEcdsa2pcNegWithBlobs::blob2_ = {nullptr, 0}; + +// ========================================================================== +// Negative: dkg +// ========================================================================== + +TEST(CApiEcdsa2pc, NegDkgNullOutput) { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", "p2", &noop_capi_transport}; + EXPECT_EQ(cbmpc_ecdsa_2p_dkg(&job, CBMPC_CURVE_SECP256K1, nullptr), E_BADARG); +} + +TEST(CApiEcdsa2pc, NegDkgNullJob) { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_dkg(nullptr, CBMPC_CURVE_SECP256K1, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiEcdsa2pc, NegDkgInvalidJobFields) { + cmem_t out{nullptr, 0}; + + { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", "p2", nullptr}; + EXPECT_EQ(cbmpc_ecdsa_2p_dkg(&job, CBMPC_CURVE_SECP256K1, &out), E_BADARG); + } + { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, nullptr, "p2", &noop_capi_transport}; + EXPECT_EQ(cbmpc_ecdsa_2p_dkg(&job, CBMPC_CURVE_SECP256K1, &out), E_BADARG); + } + { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", nullptr, &noop_capi_transport}; + EXPECT_EQ(cbmpc_ecdsa_2p_dkg(&job, CBMPC_CURVE_SECP256K1, &out), E_BADARG); + } + { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "", "p2", &noop_capi_transport}; + EXPECT_EQ(cbmpc_ecdsa_2p_dkg(&job, CBMPC_CURVE_SECP256K1, &out), E_BADARG); + } + { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", "", &noop_capi_transport}; + EXPECT_EQ(cbmpc_ecdsa_2p_dkg(&job, CBMPC_CURVE_SECP256K1, &out), E_BADARG); + } + { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "same", "same", &noop_capi_transport}; + EXPECT_EQ(cbmpc_ecdsa_2p_dkg(&job, CBMPC_CURVE_SECP256K1, &out), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.send = nullptr; + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", "p2", &bad_t}; + EXPECT_EQ(cbmpc_ecdsa_2p_dkg(&job, CBMPC_CURVE_SECP256K1, &out), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.receive = nullptr; + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", "p2", &bad_t}; + EXPECT_EQ(cbmpc_ecdsa_2p_dkg(&job, CBMPC_CURVE_SECP256K1, &out), E_BADARG); + } +} + +TEST(CApiEcdsa2pc, NegDkgInvalidCurves) { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", "p2", &noop_capi_transport}; + + { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_dkg(&job, CBMPC_CURVE_ED25519, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + for (int val : {0, 4, 100, 255}) { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_dkg(&job, static_cast(val), &out), CBMPC_SUCCESS) + << "Expected failure for curve_id=" << val; + EXPECT_EQ(out.data, nullptr); + } +} + +TEST(CApiEcdsa2pc, NegDkgInvalidParty) { + const cbmpc_2pc_job_t job = {static_cast(5), "p1", "p2", &noop_capi_transport}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_dkg(&job, CBMPC_CURVE_SECP256K1, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +// ========================================================================== +// Negative: get_public_key_compressed +// ========================================================================== + +TEST(CApiEcdsa2pc, NegGetPubKeyNullOutput) { + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_2p_get_public_key_compressed(cmem_t{dummy, 1}, nullptr), E_BADARG); +} + +TEST(CApiEcdsa2pc, NegGetPubKeyBadBlob) { + { + uint8_t zeros[64] = {}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_get_public_key_compressed(cmem_t{zeros, 64}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t one = 0x00; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_get_public_key_compressed(cmem_t{&one, 1}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_get_public_key_compressed(cmem_t{nullptr, 0}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_get_public_key_compressed(cmem_t{garbage, 4}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t data[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_2p_get_public_key_compressed(cmem_t{data, -1}, &out), E_BADARG); + EXPECT_EQ(out.data, nullptr); + } + { + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_2p_get_public_key_compressed(cmem_t{nullptr, 10}, &out), E_BADARG); + EXPECT_EQ(out.data, nullptr); + } +} + +// ========================================================================== +// Negative: get_public_share_compressed +// ========================================================================== + +TEST(CApiEcdsa2pc, NegGetPubShareNullOutput) { + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_2p_get_public_share_compressed(cmem_t{dummy, 1}, nullptr), E_BADARG); +} + +TEST(CApiEcdsa2pc, NegGetPubShareBadBlob) { + { + uint8_t zeros[64] = {}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_get_public_share_compressed(cmem_t{zeros, 64}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_get_public_share_compressed(cmem_t{nullptr, 0}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_get_public_share_compressed(cmem_t{garbage, 4}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t data[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_2p_get_public_share_compressed(cmem_t{data, -1}, &out), E_BADARG); + EXPECT_EQ(out.data, nullptr); + } +} + +// ========================================================================== +// Negative: detach_private_scalar +// ========================================================================== + +TEST(CApiEcdsa2pc, NegDetachNullOutputs) { + uint8_t dummy[] = {0x01}; + cmem_t blob = {dummy, 1}; + cmem_t out1{nullptr, 0}; + cmem_t out2{nullptr, 0}; + + EXPECT_EQ(cbmpc_ecdsa_2p_detach_private_scalar(blob, nullptr, &out2), E_BADARG); + EXPECT_EQ(cbmpc_ecdsa_2p_detach_private_scalar(blob, &out1, nullptr), E_BADARG); + EXPECT_EQ(cbmpc_ecdsa_2p_detach_private_scalar(blob, nullptr, nullptr), E_BADARG); +} + +TEST(CApiEcdsa2pc, NegDetachBadBlob) { + { + uint8_t zeros[64] = {}; + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_detach_private_scalar(cmem_t{zeros, 64}, &pub, &scalar), CBMPC_SUCCESS); + EXPECT_EQ(pub.data, nullptr); + EXPECT_EQ(scalar.data, nullptr); + } + { + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_detach_private_scalar(cmem_t{nullptr, 0}, &pub, &scalar), CBMPC_SUCCESS); + EXPECT_EQ(pub.data, nullptr); + EXPECT_EQ(scalar.data, nullptr); + } + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_detach_private_scalar(cmem_t{garbage, 4}, &pub, &scalar), CBMPC_SUCCESS); + EXPECT_EQ(pub.data, nullptr); + EXPECT_EQ(scalar.data, nullptr); + } + { + uint8_t data[] = {0x01}; + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_2p_detach_private_scalar(cmem_t{data, -1}, &pub, &scalar), E_BADARG); + EXPECT_EQ(pub.data, nullptr); + EXPECT_EQ(scalar.data, nullptr); + } +} + +// ========================================================================== +// Negative: attach_private_scalar +// ========================================================================== + +TEST(CApiEcdsa2pc, NegAttachNullOutput) { + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_2p_attach_private_scalar(cmem_t{dummy, 1}, cmem_t{dummy, 1}, cmem_t{dummy, 1}, nullptr), + E_BADARG); +} + +TEST(CApiEcdsa2pc, NegAttachBadCmemInputs) { + cmem_t out{nullptr, 0}; + + { + uint8_t scalar[] = {0x01}; + uint8_t point[33] = {}; + point[0] = 0x02; + EXPECT_NE(cbmpc_ecdsa_2p_attach_private_scalar(cmem_t{nullptr, 0}, cmem_t{scalar, 1}, cmem_t{point, 33}, &out), + CBMPC_SUCCESS); + } + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + uint8_t scalar[] = {0x01}; + uint8_t point[33] = {}; + point[0] = 0x02; + EXPECT_NE(cbmpc_ecdsa_2p_attach_private_scalar(cmem_t{garbage, 4}, cmem_t{scalar, 1}, cmem_t{point, 33}, &out), + CBMPC_SUCCESS); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_2p_attach_private_scalar(cmem_t{data, -1}, cmem_t{data, 1}, cmem_t{data, 1}, &out), E_BADARG); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_2p_attach_private_scalar(cmem_t{data, 1}, cmem_t{data, -1}, cmem_t{data, 1}, &out), E_BADARG); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_2p_attach_private_scalar(cmem_t{data, 1}, cmem_t{data, 1}, cmem_t{data, -1}, &out), E_BADARG); + } +} + +TEST_F(CApiEcdsa2pcNegWithBlobs, NegAttachEmptyPrivateScalar) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_detach_private_scalar(blob1_, &pub, &x), CBMPC_SUCCESS); + + cmem_t Qi{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_get_public_share_compressed(blob1_, &Qi), CBMPC_SUCCESS); + + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_attach_private_scalar(pub, cmem_t{nullptr, 0}, Qi, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); + cbmpc_cmem_free(Qi); +} + +TEST_F(CApiEcdsa2pcNegWithBlobs, NegAttachGarbagePrivateScalar) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_detach_private_scalar(blob1_, &pub, &x), CBMPC_SUCCESS); + + cmem_t Qi{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_get_public_share_compressed(blob1_, &Qi), CBMPC_SUCCESS); + + uint8_t garbage[512]; + std::memset(garbage, 0xFF, sizeof(garbage)); + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_attach_private_scalar(pub, cmem_t{garbage, 512}, Qi, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); + cbmpc_cmem_free(Qi); +} + +TEST_F(CApiEcdsa2pcNegWithBlobs, NegAttachGarbagePublicShare) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_detach_private_scalar(blob1_, &pub, &x), CBMPC_SUCCESS); + + uint8_t bad_point[33]; + bad_point[0] = 0x05; + std::memset(bad_point + 1, 0xAB, 32); + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_attach_private_scalar(pub, x, cmem_t{bad_point, 33}, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); +} + +TEST_F(CApiEcdsa2pcNegWithBlobs, NegAttachSwappedScalars) { + cmem_t pub1{nullptr, 0}; + cmem_t x1{nullptr, 0}; + cmem_t pub2{nullptr, 0}; + cmem_t x2{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_detach_private_scalar(blob1_, &pub1, &x1), CBMPC_SUCCESS); + ASSERT_EQ(cbmpc_ecdsa_2p_detach_private_scalar(blob2_, &pub2, &x2), CBMPC_SUCCESS); + + cmem_t Qi1{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_get_public_share_compressed(blob1_, &Qi1), CBMPC_SUCCESS); + + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_attach_private_scalar(pub1, x2, Qi1, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub1); + cbmpc_cmem_free(x1); + cbmpc_cmem_free(pub2); + cbmpc_cmem_free(x2); + cbmpc_cmem_free(Qi1); +} + +TEST_F(CApiEcdsa2pcNegWithBlobs, NegAttachSwappedPublicShares) { + cmem_t pub1{nullptr, 0}; + cmem_t x1{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_detach_private_scalar(blob1_, &pub1, &x1), CBMPC_SUCCESS); + + cmem_t Qi2{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_get_public_share_compressed(blob2_, &Qi2), CBMPC_SUCCESS); + + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_attach_private_scalar(pub1, x1, Qi2, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub1); + cbmpc_cmem_free(x1); + cbmpc_cmem_free(Qi2); +} + +TEST_F(CApiEcdsa2pcNegWithBlobs, NegAttachZeroScalar) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_detach_private_scalar(blob1_, &pub, &x), CBMPC_SUCCESS); + + cmem_t Qi{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_get_public_share_compressed(blob1_, &Qi), CBMPC_SUCCESS); + + uint8_t zero[32] = {}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_attach_private_scalar(pub, cmem_t{zero, 32}, Qi, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); + cbmpc_cmem_free(Qi); +} + +TEST_F(CApiEcdsa2pcNegWithBlobs, NegAttachSingleByteZeroScalar) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_detach_private_scalar(blob1_, &pub, &x), CBMPC_SUCCESS); + + cmem_t Qi{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_get_public_share_compressed(blob1_, &Qi), CBMPC_SUCCESS); + + uint8_t zero_byte = 0x00; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_attach_private_scalar(pub, cmem_t{&zero_byte, 1}, Qi, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); + cbmpc_cmem_free(Qi); +} + +TEST_F(CApiEcdsa2pcNegWithBlobs, NegAttachAllZeroPublicShare) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_2p_detach_private_scalar(blob1_, &pub, &x), CBMPC_SUCCESS); + + uint8_t zero_point[33] = {}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_attach_private_scalar(pub, x, cmem_t{zero_point, 33}, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); +} + +// ========================================================================== +// Negative: sign +// ========================================================================== + +TEST(CApiEcdsa2pc, NegSignNullSigOutput) { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", "p2", &noop_capi_transport}; + uint8_t hash[32] = {}; + EXPECT_EQ(cbmpc_ecdsa_2p_sign(&job, cmem_t{nullptr, 0}, cmem_t{hash, 32}, cmem_t{nullptr, 0}, nullptr, nullptr), + E_BADARG); +} + +TEST(CApiEcdsa2pc, NegSignNullJob) { + uint8_t hash[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_sign(nullptr, cmem_t{nullptr, 0}, cmem_t{hash, 32}, cmem_t{nullptr, 0}, nullptr, &sig), + CBMPC_SUCCESS); +} + +TEST(CApiEcdsa2pc, NegSignBadKeyBlob) { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", "p2", &noop_capi_transport}; + uint8_t hash[32] = {}; + + { + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_sign(&job, cmem_t{nullptr, 0}, cmem_t{hash, 32}, cmem_t{nullptr, 0}, nullptr, &sig), + CBMPC_SUCCESS); + } + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_sign(&job, cmem_t{garbage, 4}, cmem_t{hash, 32}, cmem_t{nullptr, 0}, nullptr, &sig), + CBMPC_SUCCESS); + } + { + uint8_t data[] = {0x01}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_2p_sign(&job, cmem_t{data, -1}, cmem_t{hash, 32}, cmem_t{nullptr, 0}, nullptr, &sig), + E_BADARG); + } + { + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_2p_sign(&job, cmem_t{nullptr, 10}, cmem_t{hash, 32}, cmem_t{nullptr, 0}, nullptr, &sig), + E_BADARG); + } + { + uint8_t zeros[64] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_sign(&job, cmem_t{zeros, 64}, cmem_t{hash, 32}, cmem_t{nullptr, 0}, nullptr, &sig), + CBMPC_SUCCESS); + } +} + +TEST_F(CApiEcdsa2pcNegWithBlobs, NegSignEmptyMsgHash) { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", "p2", &noop_capi_transport}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_sign(&job, blob1_, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, nullptr, &sig), CBMPC_SUCCESS); +} + +TEST(CApiEcdsa2pc, NegSignBadMsgHash) { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", "p2", &noop_capi_transport}; + uint8_t dummy_blob[] = {0x01}; + + { + uint8_t data[] = {0x01}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_2p_sign(&job, cmem_t{dummy_blob, 1}, cmem_t{data, -1}, cmem_t{nullptr, 0}, nullptr, &sig), + E_BADARG); + } + { + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_2p_sign(&job, cmem_t{dummy_blob, 1}, cmem_t{nullptr, 10}, cmem_t{nullptr, 0}, nullptr, &sig), + E_BADARG); + } +} + +TEST_F(CApiEcdsa2pcNegWithBlobs, NegSignOversizedMsgHash) { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", "p2", &noop_capi_transport}; + uint8_t huge_hash[65]; + std::memset(huge_hash, 0x42, sizeof(huge_hash)); + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_sign(&job, blob1_, cmem_t{huge_hash, 65}, cmem_t{nullptr, 0}, nullptr, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsa2pcNegWithBlobs, NegSignRoleMismatch) { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P2, "p1", "p2", &noop_capi_transport}; + uint8_t hash[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_sign(&job, blob1_, cmem_t{hash, 32}, cmem_t{nullptr, 0}, nullptr, &sig), CBMPC_SUCCESS); +} + +// ========================================================================== +// Negative: refresh +// ========================================================================== + +TEST(CApiEcdsa2pc, NegRefreshNullOutput) { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", "p2", &noop_capi_transport}; + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_2p_refresh(&job, cmem_t{dummy, 1}, nullptr), E_BADARG); +} + +TEST(CApiEcdsa2pc, NegRefreshNullJob) { + uint8_t dummy[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_refresh(nullptr, cmem_t{dummy, 1}, &out), CBMPC_SUCCESS); +} + +TEST(CApiEcdsa2pc, NegRefreshBadKeyBlob) { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P1, "p1", "p2", &noop_capi_transport}; + + { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_refresh(&job, cmem_t{nullptr, 0}, &out), CBMPC_SUCCESS); + } + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_refresh(&job, cmem_t{garbage, 4}, &out), CBMPC_SUCCESS); + } + { + uint8_t data[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_2p_refresh(&job, cmem_t{data, -1}, &out), E_BADARG); + } + { + uint8_t zeros[64] = {}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_refresh(&job, cmem_t{zeros, 64}, &out), CBMPC_SUCCESS); + } +} + +TEST_F(CApiEcdsa2pcNegWithBlobs, NegRefreshRoleMismatch) { + const cbmpc_2pc_job_t job = {CBMPC_2PC_P2, "p1", "p2", &noop_capi_transport}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_2p_refresh(&job, blob1_, &out), CBMPC_SUCCESS); +} diff --git a/tests/unit/c_api/test_ecdsa_mp.cpp b/tests/unit/c_api/test_ecdsa_mp.cpp new file mode 100644 index 00000000..ae66ef1b --- /dev/null +++ b/tests/unit/c_api/test_ecdsa_mp.cpp @@ -0,0 +1,996 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "utils/local_network/network_context.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; + +using coinbase::api::party_idx_t; +using coinbase::testutils::mpc_net_context_t; + +struct transport_ctx_t { + std::shared_ptr net; + std::atomic* free_calls = nullptr; +}; + +static cbmpc_error_t transport_send(void* ctx, int32_t receiver, const uint8_t* data, int size) { + if (!ctx) return E_BADARG; + if (size < 0) return E_BADARG; + if (size > 0 && !data) return E_BADARG; + auto* c = static_cast(ctx); + c->net->send(static_cast(receiver), mem_t(data, size)); + return CBMPC_SUCCESS; +} + +static cbmpc_error_t transport_receive(void* ctx, int32_t sender, cmem_t* out_msg) { + if (!out_msg) return E_BADARG; + *out_msg = cmem_t{nullptr, 0}; + if (!ctx) return E_BADARG; + + auto* c = static_cast(ctx); + buf_t msg; + const error_t rv = c->net->receive(static_cast(sender), msg); + if (rv) return rv; + + const int n = msg.size(); + if (n < 0) return E_FORMAT; + if (n == 0) return CBMPC_SUCCESS; + + out_msg->data = static_cast(cbmpc_malloc(static_cast(n))); + if (!out_msg->data) return E_INSUFFICIENT; + out_msg->size = n; + std::memmove(out_msg->data, msg.data(), static_cast(n)); + return CBMPC_SUCCESS; +} + +static cbmpc_error_t transport_receive_all(void* ctx, const int32_t* senders, int senders_count, cmems_t* out_msgs) { + if (!out_msgs) return E_BADARG; + *out_msgs = cmems_t{0, nullptr, nullptr}; + if (!ctx) return E_BADARG; + if (senders_count < 0) return E_BADARG; + if (senders_count > 0 && !senders) return E_BADARG; + + auto* c = static_cast(ctx); + std::vector s; + s.reserve(static_cast(senders_count)); + for (int i = 0; i < senders_count; i++) s.push_back(static_cast(senders[i])); + + std::vector msgs; + const error_t rv = c->net->receive_all(s, msgs); + if (rv) return rv; + if (msgs.size() != static_cast(senders_count)) return E_GENERAL; + + // Flatten into (data + sizes) buffers. + int total = 0; + for (const auto& m : msgs) { + const int sz = m.size(); + if (sz < 0) return E_FORMAT; + if (sz > INT_MAX - total) return E_RANGE; + total += sz; + } + + out_msgs->count = senders_count; + out_msgs->sizes = static_cast(cbmpc_malloc(sizeof(int) * static_cast(senders_count))); + if (!out_msgs->sizes) { + *out_msgs = cmems_t{0, nullptr, nullptr}; + return E_INSUFFICIENT; + } + + if (total > 0) { + out_msgs->data = static_cast(cbmpc_malloc(static_cast(total))); + if (!out_msgs->data) { + cbmpc_free(out_msgs->sizes); + *out_msgs = cmems_t{0, nullptr, nullptr}; + return E_INSUFFICIENT; + } + } + + int offset = 0; + for (int i = 0; i < senders_count; i++) { + const int sz = msgs[i].size(); + out_msgs->sizes[i] = sz; + if (sz) { + std::memmove(out_msgs->data + offset, msgs[i].data(), static_cast(sz)); + offset += sz; + } + } + + return CBMPC_SUCCESS; +} + +static void transport_free(void* ctx, void* ptr) { + if (!ptr) return; + auto* c = static_cast(ctx); + if (c && c->free_calls) c->free_calls->fetch_add(1); + cbmpc_free(ptr); +} + +template +static void run_mp(const std::vector>& peers, F&& f, + std::vector& out_rv) { + for (const auto& p : peers) p->reset(); + + out_rv.assign(peers.size(), UNINITIALIZED_ERROR); + std::atomic aborted{false}; + std::vector threads; + threads.reserve(peers.size()); + + for (size_t i = 0; i < peers.size(); i++) { + threads.emplace_back([&, i] { + out_rv[i] = f(static_cast(i)); + if (out_rv[i] && !aborted.exchange(true)) { + for (const auto& p : peers) p->abort(); + } + }); + } + for (auto& t : threads) t.join(); +} + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +} // namespace + +TEST(CApiEcdsaMp, DkgSignRefreshSign4p) { + constexpr int n = 4; + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::atomic free_calls[n]; + transport_ctx_t ctx[n]; + cbmpc_transport_t transports[n]; + for (int i = 0; i < n; i++) { + free_calls[i].store(0); + ctx[i] = transport_ctx_t{peers[static_cast(i)], &free_calls[i]}; + transports[i] = cbmpc_transport_t{ + /*ctx=*/&ctx[i], + /*send=*/transport_send, + /*receive=*/transport_receive, + /*receive_all=*/transport_receive_all, + /*free=*/transport_free, + }; + } + + const char* party_names[n] = {"p0", "p1", "p2", "p3"}; + + std::vector key_blobs(n, cmem_t{nullptr, 0}); + std::vector sids(n, cmem_t{nullptr, 0}); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[i], + }; + return cbmpc_ecdsa_mp_dkg_additive(&job, CBMPC_CURVE_SECP256K1, &key_blobs[static_cast(i)], + &sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) { + ASSERT_GT(key_blobs[static_cast(i)].size, 0); + ASSERT_GT(sids[static_cast(i)].size, 0); + } + for (int i = 1; i < n; i++) expect_eq(sids[0], sids[static_cast(i)]); + + cmem_t pub0{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_mp_get_public_key_compressed(key_blobs[0], &pub0), CBMPC_SUCCESS); + ASSERT_EQ(pub0.size, 33); + for (int i = 1; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_mp_get_public_key_compressed(key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + const buf_t pub_buf(pub0.data, pub0.size); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub_buf), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + + // Change the party ordering ("role" indices) between protocols. + // Example: a party that was at index 1 ("p1") moves to index 2. + const char* party_names2[n] = {"p0", "p2", "p1", "p3"}; + // Map new role index -> old role index (DKG) for the same party name. + const int perm[n] = {0, 2, 1, 3}; + + uint8_t msg_hash_bytes[32]; + for (int i = 0; i < 32; i++) msg_hash_bytes[i] = static_cast(i); + const cmem_t msg_hash = {msg_hash_bytes, 32}; + + std::vector sigs(n, cmem_t{nullptr, 0}); + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names2, + /*party_names_count=*/n, + /*transport=*/&transports[i], + }; + return cbmpc_ecdsa_mp_sign_additive(&job, key_blobs[static_cast(perm[i])], msg_hash, + /*sig_receiver=*/2, &sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_GT(sigs[2].size, 0); + for (int i = 0; i < n; i++) { + if (i == 2) continue; + ASSERT_EQ(sigs[static_cast(i)].size, 0); + } + ASSERT_EQ(verify_key.verify(buf_t(msg_hash_bytes, 32), buf_t(sigs[2].data, sigs[2].size)), SUCCESS); + + std::vector new_key_blobs(n, cmem_t{nullptr, 0}); + std::vector sid_outs(n, cmem_t{nullptr, 0}); + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names2, + /*party_names_count=*/n, + /*transport=*/&transports[i], + }; + return cbmpc_ecdsa_mp_refresh_additive( + &job, sids[static_cast(perm[i])], key_blobs[static_cast(perm[i])], + &sid_outs[static_cast(i)], &new_key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) ASSERT_GT(new_key_blobs[static_cast(i)].size, 0); + for (int i = 1; i < n; i++) expect_eq(sid_outs[0], sid_outs[static_cast(i)]); + expect_eq(sids[0], sid_outs[0]); + + for (int i = 0; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_mp_get_public_key_compressed(new_key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + std::vector new_sigs(n, cmem_t{nullptr, 0}); + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names2, + /*party_names_count=*/n, + /*transport=*/&transports[i], + }; + return cbmpc_ecdsa_mp_sign_additive(&job, new_key_blobs[static_cast(i)], msg_hash, /*sig_receiver=*/2, + &new_sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_GT(new_sigs[2].size, 0); + for (int i = 0; i < n; i++) { + if (i == 2) continue; + ASSERT_EQ(new_sigs[static_cast(i)].size, 0); + } + ASSERT_EQ(verify_key.verify(buf_t(msg_hash_bytes, 32), buf_t(new_sigs[2].data, new_sigs[2].size)), SUCCESS); + + for (int i = 0; i < n; i++) EXPECT_GT(free_calls[i].load(), 0); + + cbmpc_cmem_free(pub0); + for (auto m : new_sigs) cbmpc_cmem_free(m); + for (auto m : sid_outs) cbmpc_cmem_free(m); + for (auto m : new_key_blobs) cbmpc_cmem_free(m); + for (auto m : sigs) cbmpc_cmem_free(m); + for (auto m : sids) cbmpc_cmem_free(m); + for (auto m : key_blobs) cbmpc_cmem_free(m); +} + +TEST(CApiEcdsaMp, ValidatesArgs) { + cmem_t key{reinterpret_cast(0x1), 123}; + cmem_t sid{reinterpret_cast(0x1), 123}; + + const cbmpc_transport_t bad_transport = {/*ctx=*/nullptr, /*send=*/nullptr, /*receive=*/nullptr, + /*receive_all=*/nullptr, + /*free=*/nullptr}; + const char* names[2] = {"p0", "p1"}; + const cbmpc_mp_job_t bad_job = {/*self=*/0, /*party_names=*/names, /*party_names_count=*/2, + /*transport=*/&bad_transport}; + + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_additive(&bad_job, CBMPC_CURVE_SECP256K1, &key, &sid), E_BADARG); + EXPECT_EQ(key.data, nullptr); + EXPECT_EQ(key.size, 0); + EXPECT_EQ(sid.data, nullptr); + EXPECT_EQ(sid.size, 0); + + // Missing sig_der_out is invalid. + EXPECT_EQ(cbmpc_ecdsa_mp_sign_additive(nullptr, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, 0, nullptr), E_BADARG); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +// ========================================================================== +// Negative test helpers +// ========================================================================== + +namespace { + +static cbmpc_error_t noop_send(void*, int32_t, const uint8_t*, int) { return E_GENERAL; } +static cbmpc_error_t noop_receive(void*, int32_t, cmem_t*) { return E_GENERAL; } +static cbmpc_error_t noop_receive_all(void*, const int32_t*, int, cmems_t*) { return E_GENERAL; } + +static const cbmpc_transport_t noop_capi_transport = {nullptr, noop_send, noop_receive, noop_receive_all, nullptr}; + +static void capi_generate_mp_key_blobs(cbmpc_curve_id_t curve, int n, std::vector& blobs) { + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector ctxs(n); + std::vector transports(n); + for (int i = 0; i < n; i++) { + ctxs[i] = transport_ctx_t{peers[static_cast(i)], nullptr}; + transports[i] = + cbmpc_transport_t{&ctxs[i], transport_send, transport_receive, transport_receive_all, transport_free}; + } + + std::vector names; + for (int i = 0; i < n; i++) names.push_back("p" + std::to_string(i)); + std::vector name_ptrs; + for (const auto& nm : names) name_ptrs.push_back(nm.c_str()); + + blobs.resize(n, cmem_t{nullptr, 0}); + std::vector sids(n, cmem_t{nullptr, 0}); + std::vector rvs; + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = {i, name_ptrs.data(), n, &transports[static_cast(i)]}; + return cbmpc_ecdsa_mp_dkg_additive(&job, curve, &blobs[static_cast(i)], &sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (auto m : sids) cbmpc_cmem_free(m); +} + +} // namespace + +class CApiEcdsaMpNegWithBlobs : public ::testing::Test { + protected: + static void SetUpTestSuite() { capi_generate_mp_key_blobs(CBMPC_CURVE_SECP256K1, 3, blobs_); } + + static void TearDownTestSuite() { + for (auto m : blobs_) cbmpc_cmem_free(m); + blobs_.clear(); + } + + static std::vector blobs_; +}; + +std::vector CApiEcdsaMpNegWithBlobs::blobs_; + +// ========================================================================== +// Negative: dkg +// ========================================================================== + +TEST(CApiEcdsaMp, NegDkgNullOutKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_capi_transport}; + cmem_t sid{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_additive(&job, CBMPC_CURVE_SECP256K1, nullptr, &sid), E_BADARG); +} + +TEST(CApiEcdsaMp, NegDkgNullOutSid) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_capi_transport}; + cmem_t key{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_additive(&job, CBMPC_CURVE_SECP256K1, &key, nullptr), E_BADARG); +} + +TEST(CApiEcdsaMp, NegDkgNullJob) { + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_dkg_additive(nullptr, CBMPC_CURVE_SECP256K1, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); + EXPECT_EQ(sid.data, nullptr); +} + +TEST(CApiEcdsaMp, NegDkgInvalidJobFields) { + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + + { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, nullptr}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_additive(&job, CBMPC_CURVE_SECP256K1, &key, &sid), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.send = nullptr; + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &bad_t}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_additive(&job, CBMPC_CURVE_SECP256K1, &key, &sid), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.receive = nullptr; + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &bad_t}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_additive(&job, CBMPC_CURVE_SECP256K1, &key, &sid), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.receive_all = nullptr; + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &bad_t}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_additive(&job, CBMPC_CURVE_SECP256K1, &key, &sid), E_BADARG); + } +} + +TEST(CApiEcdsaMp, NegDkgInvalidCurves) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_capi_transport}; + + { + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_dkg_additive(&job, CBMPC_CURVE_ED25519, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); + } + for (int val : {0, 4, 255}) { + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_dkg_additive(&job, static_cast(val), &key, &sid), CBMPC_SUCCESS) + << "Expected failure for curve_id=" << val; + EXPECT_EQ(key.data, nullptr); + } +} + +TEST(CApiEcdsaMp, NegDkgInvalidParty) { + { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {3, names, 3, &noop_capi_transport}; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_dkg_additive(&job, CBMPC_CURVE_SECP256K1, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); + } + { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {-1, names, 2, &noop_capi_transport}; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_dkg_additive(&job, CBMPC_CURVE_SECP256K1, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); + } + { + const char* names[] = {"p0"}; + const cbmpc_mp_job_t job = {0, names, 1, &noop_capi_transport}; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_dkg_additive(&job, CBMPC_CURVE_SECP256K1, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); + } + { + const cbmpc_mp_job_t job = {0, nullptr, 0, &noop_capi_transport}; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_dkg_additive(&job, CBMPC_CURVE_SECP256K1, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); + } +} + +TEST(CApiEcdsaMp, NegDkgDuplicatePartyNames) { + const char* names[] = {"p0", "p0", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_dkg_additive(&job, CBMPC_CURVE_SECP256K1, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); +} + +// ========================================================================== +// Negative: get_public_key_compressed +// ========================================================================== + +TEST(CApiEcdsaMp, NegGetPubKeyNullOutput) { + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_mp_get_public_key_compressed(cmem_t{dummy, 1}, nullptr), E_BADARG); +} + +TEST(CApiEcdsaMp, NegGetPubKeyBadBlob) { + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_get_public_key_compressed(cmem_t{garbage, 4}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t data[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_get_public_key_compressed(cmem_t{data, -1}, &out), E_BADARG); + EXPECT_EQ(out.data, nullptr); + } + { + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_get_public_key_compressed(cmem_t{nullptr, 10}, &out), E_BADARG); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t zeros[64] = {}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_get_public_key_compressed(cmem_t{zeros, 64}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t one = 0x00; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_get_public_key_compressed(cmem_t{&one, 1}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_get_public_key_compressed(cmem_t{nullptr, 0}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } +} + +TEST(CApiEcdsaMp, NegGetPubKeyOversizedBlob) { + uint8_t huge[4096]; + std::memset(huge, 0x42, sizeof(huge)); + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_get_public_key_compressed(cmem_t{huge, 4096}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +// ========================================================================== +// Negative: get_public_share_compressed +// ========================================================================== + +TEST(CApiEcdsaMp, NegGetPubShareNullOutput) { + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_mp_get_public_share_compressed(cmem_t{dummy, 1}, nullptr), E_BADARG); +} + +TEST(CApiEcdsaMp, NegGetPubShareBadBlob) { + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_get_public_share_compressed(cmem_t{garbage, 4}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t data[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_get_public_share_compressed(cmem_t{data, -1}, &out), E_BADARG); + EXPECT_EQ(out.data, nullptr); + } + { + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_get_public_share_compressed(cmem_t{nullptr, 10}, &out), E_BADARG); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t zeros[64] = {}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_get_public_share_compressed(cmem_t{zeros, 64}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_get_public_share_compressed(cmem_t{nullptr, 0}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } +} + +TEST(CApiEcdsaMp, NegGetPubShareOversizedBlob) { + uint8_t huge[4096]; + std::memset(huge, 0x42, sizeof(huge)); + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_get_public_share_compressed(cmem_t{huge, 4096}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +// ========================================================================== +// Negative: detach_private_scalar +// ========================================================================== + +TEST(CApiEcdsaMp, NegDetachNullOutputs) { + uint8_t dummy[] = {0x01}; + cmem_t blob = {dummy, 1}; + cmem_t out1{nullptr, 0}; + cmem_t out2{nullptr, 0}; + + EXPECT_EQ(cbmpc_ecdsa_mp_detach_private_scalar(blob, nullptr, &out2), E_BADARG); + EXPECT_EQ(cbmpc_ecdsa_mp_detach_private_scalar(blob, &out1, nullptr), E_BADARG); + EXPECT_EQ(cbmpc_ecdsa_mp_detach_private_scalar(blob, nullptr, nullptr), E_BADARG); +} + +TEST(CApiEcdsaMp, NegDetachBadBlob) { + { + uint8_t zeros[64] = {}; + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_detach_private_scalar(cmem_t{zeros, 64}, &pub, &scalar), CBMPC_SUCCESS); + EXPECT_EQ(pub.data, nullptr); + EXPECT_EQ(scalar.data, nullptr); + } + { + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_detach_private_scalar(cmem_t{nullptr, 0}, &pub, &scalar), CBMPC_SUCCESS); + EXPECT_EQ(pub.data, nullptr); + EXPECT_EQ(scalar.data, nullptr); + } + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_detach_private_scalar(cmem_t{garbage, 4}, &pub, &scalar), CBMPC_SUCCESS); + EXPECT_EQ(pub.data, nullptr); + EXPECT_EQ(scalar.data, nullptr); + } + { + uint8_t data[] = {0x01}; + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_detach_private_scalar(cmem_t{data, -1}, &pub, &scalar), E_BADARG); + EXPECT_EQ(pub.data, nullptr); + EXPECT_EQ(scalar.data, nullptr); + } +} + +// ========================================================================== +// Negative: attach_private_scalar +// ========================================================================== + +TEST(CApiEcdsaMp, NegAttachNullOutput) { + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_mp_attach_private_scalar(cmem_t{dummy, 1}, cmem_t{dummy, 1}, cmem_t{dummy, 1}, nullptr), + E_BADARG); +} + +TEST(CApiEcdsaMp, NegAttachBadCmemInputs) { + cmem_t out{nullptr, 0}; + + { + uint8_t scalar[] = {0x01}; + uint8_t point[33] = {}; + point[0] = 0x02; + EXPECT_NE(cbmpc_ecdsa_mp_attach_private_scalar(cmem_t{nullptr, 0}, cmem_t{scalar, 1}, cmem_t{point, 33}, &out), + CBMPC_SUCCESS); + } + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + uint8_t scalar[] = {0x01}; + uint8_t point[33] = {}; + point[0] = 0x02; + EXPECT_NE(cbmpc_ecdsa_mp_attach_private_scalar(cmem_t{garbage, 4}, cmem_t{scalar, 1}, cmem_t{point, 33}, &out), + CBMPC_SUCCESS); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_mp_attach_private_scalar(cmem_t{data, -1}, cmem_t{data, 1}, cmem_t{data, 1}, &out), E_BADARG); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_mp_attach_private_scalar(cmem_t{data, 1}, cmem_t{data, -1}, cmem_t{data, 1}, &out), E_BADARG); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_mp_attach_private_scalar(cmem_t{data, 1}, cmem_t{data, 1}, cmem_t{data, -1}, &out), E_BADARG); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_mp_attach_private_scalar(cmem_t{nullptr, 10}, cmem_t{data, 1}, cmem_t{data, 1}, &out), + E_BADARG); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_mp_attach_private_scalar(cmem_t{data, 1}, cmem_t{nullptr, 10}, cmem_t{data, 1}, &out), + E_BADARG); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_mp_attach_private_scalar(cmem_t{data, 1}, cmem_t{data, 1}, cmem_t{nullptr, 10}, &out), + E_BADARG); + } +} + +TEST_F(CApiEcdsaMpNegWithBlobs, NegAttachEmptyScalar) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_mp_detach_private_scalar(blobs_[0], &pub, &x), CBMPC_SUCCESS); + + cmem_t Qi{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_mp_get_public_share_compressed(blobs_[0], &Qi), CBMPC_SUCCESS); + + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_attach_private_scalar(pub, cmem_t{nullptr, 0}, Qi, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); + cbmpc_cmem_free(Qi); +} + +TEST_F(CApiEcdsaMpNegWithBlobs, NegAttachZeroScalar) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_mp_detach_private_scalar(blobs_[0], &pub, &x), CBMPC_SUCCESS); + + cmem_t Qi{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_mp_get_public_share_compressed(blobs_[0], &Qi), CBMPC_SUCCESS); + + uint8_t zero[32] = {}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_attach_private_scalar(pub, cmem_t{zero, 32}, Qi, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); + cbmpc_cmem_free(Qi); +} + +TEST_F(CApiEcdsaMpNegWithBlobs, NegAttachGarbageScalar) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_mp_detach_private_scalar(blobs_[0], &pub, &x), CBMPC_SUCCESS); + + cmem_t Qi{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_mp_get_public_share_compressed(blobs_[0], &Qi), CBMPC_SUCCESS); + + uint8_t garbage[512]; + std::memset(garbage, 0xFF, sizeof(garbage)); + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_attach_private_scalar(pub, cmem_t{garbage, 512}, Qi, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); + cbmpc_cmem_free(Qi); +} + +// ========================================================================== +// Negative: sign_additive +// ========================================================================== + +TEST(CApiEcdsaMp, NegSignNullSigOutput) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t hash[32] = {}; + EXPECT_EQ(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{nullptr, 0}, cmem_t{hash, 32}, 0, nullptr), E_BADARG); +} + +TEST(CApiEcdsaMp, NegSignNullJob) { + uint8_t hash[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_additive(nullptr, cmem_t{nullptr, 0}, cmem_t{hash, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST(CApiEcdsaMp, NegSignInvalidJob) { + uint8_t hash[32] = {}; + + { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, nullptr}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{nullptr, 0}, cmem_t{hash, 32}, 0, &sig), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.send = nullptr; + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &bad_t}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{nullptr, 0}, cmem_t{hash, 32}, 0, &sig), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.receive = nullptr; + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &bad_t}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{nullptr, 0}, cmem_t{hash, 32}, 0, &sig), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.receive_all = nullptr; + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &bad_t}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{nullptr, 0}, cmem_t{hash, 32}, 0, &sig), E_BADARG); + } +} + +TEST(CApiEcdsaMp, NegSignBadKeyBlob) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t hash[32] = {}; + + { + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{nullptr, 0}, cmem_t{hash, 32}, 0, &sig), CBMPC_SUCCESS); + } + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{garbage, 4}, cmem_t{hash, 32}, 0, &sig), CBMPC_SUCCESS); + } + { + uint8_t data[] = {0x01}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{data, -1}, cmem_t{hash, 32}, 0, &sig), E_BADARG); + } + { + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{nullptr, 10}, cmem_t{hash, 32}, 0, &sig), E_BADARG); + } + { + uint8_t zeros[64] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{zeros, 64}, cmem_t{hash, 32}, 0, &sig), CBMPC_SUCCESS); + } +} + +TEST(CApiEcdsaMp, NegSignBadMsgHash) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t dummy_blob[] = {0x01}; + + { + uint8_t data[] = {0x01}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{dummy_blob, 1}, cmem_t{data, -1}, 0, &sig), E_BADARG); + } + { + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{dummy_blob, 1}, cmem_t{nullptr, 10}, 0, &sig), E_BADARG); + } +} + +TEST(CApiEcdsaMp, NegSignOversizedKeyBlob) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t hash[32] = {}; + uint8_t huge[4096]; + std::memset(huge, 0x42, sizeof(huge)); + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{huge, 4096}, cmem_t{hash, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpNegWithBlobs, NegSignOversizedMsgHash) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t huge_hash[65]; + std::memset(huge_hash, 0x42, sizeof(huge_hash)); + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_additive(&job, blobs_[0], cmem_t{huge_hash, 65}, 0, &sig), CBMPC_SUCCESS); +} + +TEST(CApiEcdsaMp, NegSignInvalidSigReceiver) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t hash[32] = {}; + uint8_t dummy[] = {0x01}; + + { + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{dummy, 1}, cmem_t{hash, 32}, -1, &sig), CBMPC_SUCCESS); + } + { + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{dummy, 1}, cmem_t{hash, 32}, 100, &sig), CBMPC_SUCCESS); + } +} + +TEST_F(CApiEcdsaMpNegWithBlobs, NegSignEmptyKeyBlob) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t hash[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{nullptr, 0}, cmem_t{hash, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpNegWithBlobs, NegSignAllZeroKeyBlob) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t hash[32] = {}; + uint8_t zeros[256] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_additive(&job, cmem_t{zeros, 256}, cmem_t{hash, 32}, 0, &sig), CBMPC_SUCCESS); +} + +// ========================================================================== +// Negative: refresh_additive +// ========================================================================== + +TEST(CApiEcdsaMp, NegRefreshNullOutput) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_ecdsa_mp_refresh_additive(&job, cmem_t{nullptr, 0}, cmem_t{dummy, 1}, nullptr, nullptr), E_BADARG); +} + +TEST(CApiEcdsaMp, NegRefreshNullJob) { + uint8_t dummy[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_refresh_additive(nullptr, cmem_t{nullptr, 0}, cmem_t{dummy, 1}, nullptr, &out), + CBMPC_SUCCESS); +} + +TEST(CApiEcdsaMp, NegRefreshInvalidJob) { + uint8_t dummy[] = {0x01}; + uint8_t hash[32] = {}; + + { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, nullptr}; + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_refresh_additive(&job, cmem_t{hash, 32}, cmem_t{dummy, 1}, nullptr, &out), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.send = nullptr; + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &bad_t}; + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_refresh_additive(&job, cmem_t{hash, 32}, cmem_t{dummy, 1}, nullptr, &out), E_BADARG); + } +} + +TEST(CApiEcdsaMp, NegRefreshBadKeyBlob) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + + { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_refresh_additive(&job, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, nullptr, &out), + CBMPC_SUCCESS); + } + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_refresh_additive(&job, cmem_t{nullptr, 0}, cmem_t{garbage, 4}, nullptr, &out), + CBMPC_SUCCESS); + } + { + uint8_t data[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_refresh_additive(&job, cmem_t{nullptr, 0}, cmem_t{data, -1}, nullptr, &out), E_BADARG); + } + { + uint8_t zeros[64] = {}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_refresh_additive(&job, cmem_t{nullptr, 0}, cmem_t{zeros, 64}, nullptr, &out), + CBMPC_SUCCESS); + } +} + +TEST(CApiEcdsaMp, NegRefreshOversizedKeyBlob) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t huge[4096]; + std::memset(huge, 0x42, sizeof(huge)); + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_refresh_additive(&job, cmem_t{nullptr, 0}, cmem_t{huge, 4096}, nullptr, &out), + CBMPC_SUCCESS); +} diff --git a/tests/unit/c_api/test_ecdsa_mp_ac.cpp b/tests/unit/c_api/test_ecdsa_mp_ac.cpp new file mode 100644 index 00000000..16464563 --- /dev/null +++ b/tests/unit/c_api/test_ecdsa_mp_ac.cpp @@ -0,0 +1,896 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::capi_harness::make_transport; +using coinbase::testutils::capi_harness::run_mp; +using coinbase::testutils::capi_harness::transport_ctx_t; + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +static void make_peers(int n, std::vector>& peers) { + peers.clear(); + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); +} + +static void make_transports(const std::vector>& peers, + std::vector& ctxs, std::vector& transports) { + ctxs.resize(peers.size()); + transports.resize(peers.size()); + for (size_t i = 0; i < peers.size(); i++) { + ctxs[i] = transport_ctx_t{peers[i], /*free_calls=*/nullptr}; + transports[i] = make_transport(&ctxs[i]); + } +} + +} // namespace + +TEST(CApiEcdsaMpAc, ComplexAccess_DkgRefreshSign4p) { + constexpr int n = 4; + + // Full 4-party network for threshold DKG/refresh. + std::vector> peers; + make_peers(n, peers); + + std::vector ctxs; + std::vector transports; + make_transports(peers, ctxs, transports); + + const char* party_names[n] = {"p0", "p1", "p2", "p3"}; + + // Access structure: + // THRESHOLD[2]( + // AND(p0, p1), + // OR(p2, p3) + // ) + const int32_t child_indices[] = {1, 2, 3, 4, 5, 6}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, /*leaf_name=*/nullptr, /*k=*/2, /*off=*/0, /*cnt=*/2}, + {CBMPC_ACCESS_STRUCTURE_NODE_AND, /*leaf_name=*/nullptr, /*k=*/0, /*off=*/2, /*cnt=*/2}, + {CBMPC_ACCESS_STRUCTURE_NODE_OR, /*leaf_name=*/nullptr, /*k=*/0, /*off=*/4, /*cnt=*/2}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p0", /*k=*/0, /*off=*/0, /*cnt=*/0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p1", /*k=*/0, /*off=*/0, /*cnt=*/0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p2", /*k=*/0, /*off=*/0, /*cnt=*/0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p3", /*k=*/0, /*off=*/0, /*cnt=*/0}, + }; + const cbmpc_access_structure_t ac = { + /*nodes=*/nodes, + /*nodes_count=*/static_cast(sizeof(nodes) / sizeof(nodes[0])), + /*child_indices=*/child_indices, + /*child_indices_count=*/static_cast(sizeof(child_indices) / sizeof(child_indices[0])), + /*root_index=*/0, + }; + + // DKG quorum must satisfy the access structure. Use {p0, p1, p2}. + const char* dkg_quorum[] = {"p0", "p1", "p2"}; + + std::vector key_blobs(n, cmem_t{nullptr, 0}); + std::vector sids(n, cmem_t{nullptr, 0}); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[static_cast(i)], + }; + return cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, + /*sid_in=*/cmem_t{nullptr, 0}, &ac, dkg_quorum, + /*quorum_party_names_count=*/3, &key_blobs[static_cast(i)], + &sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) { + ASSERT_GT(key_blobs[static_cast(i)].size, 0); + ASSERT_GT(sids[static_cast(i)].size, 0); + } + for (int i = 1; i < n; i++) expect_eq(sids[0], sids[static_cast(i)]); + + cmem_t pub0{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_mp_get_public_key_compressed(key_blobs[0], &pub0), CBMPC_SUCCESS); + ASSERT_EQ(pub0.size, 33); + for (int i = 1; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_mp_get_public_key_compressed(key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + const buf_t pub_buf(pub0.data, pub0.size); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub_buf), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + + uint8_t msg_hash_bytes[32]; + for (int i = 0; i < 32; i++) msg_hash_bytes[i] = static_cast(i); + const cmem_t msg_hash = {msg_hash_bytes, 32}; + + // Signing quorum A: {p0, p1, p2} + const char* quorum_a[] = {"p0", "p1", "p2"}; + const cmem_t quorum_a_key_blobs[] = {key_blobs[0], key_blobs[1], key_blobs[2]}; + + { + std::vector> sign_peers; + make_peers(3, sign_peers); + + std::vector sign_ctxs; + std::vector sign_transports; + make_transports(sign_peers, sign_ctxs, sign_transports); + + std::vector sigs(3, cmem_t{nullptr, 0}); + run_mp( + sign_peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/quorum_a, + /*party_names_count=*/3, + /*transport=*/&sign_transports[static_cast(i)], + }; + return cbmpc_ecdsa_mp_sign_ac(&job, quorum_a_key_blobs[static_cast(i)], &ac, msg_hash, + /*sig_receiver=*/0, &sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_GT(sigs[0].size, 0); + EXPECT_EQ(sigs[1].size, 0); + EXPECT_EQ(sigs[2].size, 0); + ASSERT_EQ(verify_key.verify(buf_t(msg_hash_bytes, 32), buf_t(sigs[0].data, sigs[0].size)), SUCCESS); + + for (auto m : sigs) cbmpc_cmem_free(m); + } + + // Signing quorum B: {p0, p1, p3} + const char* quorum_b[] = {"p0", "p1", "p3"}; + const cmem_t quorum_b_key_blobs[] = {key_blobs[0], key_blobs[1], key_blobs[3]}; + + { + std::vector> sign_peers; + make_peers(3, sign_peers); + + std::vector sign_ctxs; + std::vector sign_transports; + make_transports(sign_peers, sign_ctxs, sign_transports); + + std::vector sigs(3, cmem_t{nullptr, 0}); + run_mp( + sign_peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/quorum_b, + /*party_names_count=*/3, + /*transport=*/&sign_transports[static_cast(i)], + }; + return cbmpc_ecdsa_mp_sign_ac(&job, quorum_b_key_blobs[static_cast(i)], &ac, msg_hash, + /*sig_receiver=*/0, &sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_GT(sigs[0].size, 0); + EXPECT_EQ(sigs[1].size, 0); + EXPECT_EQ(sigs[2].size, 0); + ASSERT_EQ(verify_key.verify(buf_t(msg_hash_bytes, 32), buf_t(sigs[0].data, sigs[0].size)), SUCCESS); + + for (auto m : sigs) cbmpc_cmem_free(m); + } + + // Threshold refresh with quorum B. + std::vector new_key_blobs(n, cmem_t{nullptr, 0}); + std::vector refresh_sids(n, cmem_t{nullptr, 0}); + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[static_cast(i)], + }; + return cbmpc_ecdsa_mp_refresh_ac(&job, + /*sid_in=*/cmem_t{nullptr, 0}, key_blobs[static_cast(i)], &ac, + quorum_b, + /*quorum_party_names_count=*/3, &refresh_sids[static_cast(i)], + &new_key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 1; i < n; i++) expect_eq(refresh_sids[0], refresh_sids[static_cast(i)]); + + for (int i = 0; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_ecdsa_mp_get_public_key_compressed(new_key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + const cmem_t quorum_b_new_key_blobs[] = {new_key_blobs[0], new_key_blobs[1], new_key_blobs[3]}; + + { + std::vector> sign_peers; + make_peers(3, sign_peers); + + std::vector sign_ctxs; + std::vector sign_transports; + make_transports(sign_peers, sign_ctxs, sign_transports); + + std::vector sigs(3, cmem_t{nullptr, 0}); + run_mp( + sign_peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/quorum_b, + /*party_names_count=*/3, + /*transport=*/&sign_transports[static_cast(i)], + }; + return cbmpc_ecdsa_mp_sign_ac(&job, quorum_b_new_key_blobs[static_cast(i)], &ac, msg_hash, + /*sig_receiver=*/0, &sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_GT(sigs[0].size, 0); + EXPECT_EQ(sigs[1].size, 0); + EXPECT_EQ(sigs[2].size, 0); + ASSERT_EQ(verify_key.verify(buf_t(msg_hash_bytes, 32), buf_t(sigs[0].data, sigs[0].size)), SUCCESS); + + for (auto m : sigs) cbmpc_cmem_free(m); + } + + cbmpc_cmem_free(pub0); + for (auto m : refresh_sids) cbmpc_cmem_free(m); + for (auto m : new_key_blobs) cbmpc_cmem_free(m); + for (auto m : sids) cbmpc_cmem_free(m); + for (auto m : key_blobs) cbmpc_cmem_free(m); +} + +TEST(CApiEcdsaMpAc, RejectInvalidAccessStructEncoding) { + // Dummy transport (won't be used; inputs fail before any I/O). + const cbmpc_transport_t transport = { + /*ctx=*/nullptr, + /*send=*/[](void*, int32_t, const uint8_t*, int) -> cbmpc_error_t { return E_GENERAL; }, + /*receive=*/[](void*, int32_t, cmem_t*) -> cbmpc_error_t { return E_GENERAL; }, + /*receive_all=*/[](void*, const int32_t*, int, cmems_t*) -> cbmpc_error_t { return E_GENERAL; }, + /*free=*/nullptr, + }; + + const char* party_names[2] = {"p0", "p1"}; + const cbmpc_mp_job_t job = { + /*self=*/0, + /*party_names=*/party_names, + /*party_names_count=*/2, + /*transport=*/&transport, + }; + + const char* quorum[] = {"p0", "p1"}; + + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + + // Root leaf is rejected. + { + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p0", /*k=*/0, /*off=*/0, /*cnt=*/0}, + }; + const cbmpc_access_structure_t ac = { + /*nodes=*/nodes, + /*nodes_count=*/1, + /*child_indices=*/nullptr, + /*child_indices_count=*/0, + /*root_index=*/0, + }; + EXPECT_EQ( + cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); + EXPECT_EQ(out_key.data, nullptr); + EXPECT_EQ(out_key.size, 0); + EXPECT_EQ(out_sid.data, nullptr); + EXPECT_EQ(out_sid.size, 0); + } + + // Cycle (node 0 references itself). + { + const int32_t child_indices[] = {0}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_AND, /*leaf_name=*/nullptr, /*k=*/0, /*off=*/0, /*cnt=*/1}, + }; + const cbmpc_access_structure_t ac = { + /*nodes=*/nodes, + /*nodes_count=*/1, + /*child_indices=*/child_indices, + /*child_indices_count=*/1, + /*root_index=*/0, + }; + EXPECT_EQ( + cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); + } +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +// --------------------------------------------------------------------------- +// Helpers for negative tests +// --------------------------------------------------------------------------- + +namespace { + +static const cbmpc_transport_t noop_transport = { + nullptr, + [](void*, int32_t, const uint8_t*, int) -> cbmpc_error_t { return E_GENERAL; }, + [](void*, int32_t, cmem_t*) -> cbmpc_error_t { return E_GENERAL; }, + [](void*, const int32_t*, int, cmems_t*) -> cbmpc_error_t { return E_GENERAL; }, + nullptr, +}; + +static cbmpc_access_structure_t make_simple_ac_2of4() { + static const int32_t ci[] = {1, 2, 3, 4}; + static const cbmpc_access_structure_node_t nd[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 4}, {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p3", 0, 0, 0}, + }; + return {nd, 5, ci, 4, 0}; +} + +static void capi_generate_mp_ac_key_blobs(cbmpc_curve_id_t curve, int n, std::vector& blobs) { + std::vector> peers; + make_peers(n, peers); + + std::vector ctxs; + std::vector transports; + make_transports(peers, ctxs, transports); + + std::vector names; + for (int i = 0; i < n; i++) names.push_back("p" + std::to_string(i)); + std::vector name_ptrs; + for (const auto& nm : names) name_ptrs.push_back(nm.c_str()); + + const int32_t child_indices[] = {1, 2, 3, 4}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 4}, {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p3", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 5, child_indices, 4, 0}; + + const char* quorum[] = {"p0", "p1"}; + + blobs.resize(n, cmem_t{nullptr, 0}); + std::vector sids(n, cmem_t{nullptr, 0}); + std::vector rvs; + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = {i, name_ptrs.data(), n, &transports[static_cast(i)]}; + return cbmpc_ecdsa_mp_dkg_ac(&job, curve, cmem_t{nullptr, 0}, &ac, quorum, 2, &blobs[static_cast(i)], + &sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (auto m : sids) cbmpc_cmem_free(m); +} + +} // namespace + +class CApiEcdsaMpAcNegWithBlobs : public ::testing::Test { + protected: + static void SetUpTestSuite() { capi_generate_mp_ac_key_blobs(CBMPC_CURVE_SECP256K1, 4, blobs_); } + static void TearDownTestSuite() { + for (auto m : blobs_) cbmpc_cmem_free(m); + blobs_.clear(); + } + static std::vector blobs_; +}; +std::vector CApiEcdsaMpAcNegWithBlobs::blobs_; + +// =========================================================================== +// Negative: DKG AC +// =========================================================================== + +TEST(CApiEcdsaMpAc, NegDkgAcNullOutKey) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_sid{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, nullptr, &out_sid), + E_BADARG); +} + +TEST(CApiEcdsaMpAc, NegDkgAcNullOutSid) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, nullptr), + E_BADARG); +} + +TEST(CApiEcdsaMpAc, NegDkgAcNullJob) { + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_NE( + cbmpc_ecdsa_mp_dkg_ac(nullptr, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + CBMPC_SUCCESS); +} + +TEST(CApiEcdsaMpAc, NegDkgAcJobNullTransport) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, nullptr}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEcdsaMpAc, NegDkgAcInvalidCurveEd25519) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + CBMPC_SUCCESS); +} + +TEST(CApiEcdsaMpAc, NegDkgAcInvalidCurveZero) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_dkg_ac(&job, static_cast(0), cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, + &out_sid), + CBMPC_SUCCESS); +} + +TEST(CApiEcdsaMpAc, NegDkgAcInvalidCurve4) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_dkg_ac(&job, static_cast(4), cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, + &out_sid), + CBMPC_SUCCESS); +} + +TEST(CApiEcdsaMpAc, NegDkgAcInvalidCurve255) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_dkg_ac(&job, static_cast(255), cmem_t{nullptr, 0}, &ac, quorum, 2, + &out_key, &out_sid), + CBMPC_SUCCESS); +} + +TEST(CApiEcdsaMpAc, NegDkgAcNullAccessStructure) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_EQ( + cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, nullptr, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEcdsaMpAc, NegDkgAcNodesNull) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + const cbmpc_access_structure_t ac = {nullptr, 5, nullptr, 0, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEcdsaMpAc, NegDkgAcNodesCountZero) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + const cbmpc_access_structure_node_t dummy_node = {CBMPC_ACCESS_STRUCTURE_NODE_AND, nullptr, 0, 0, 0}; + const cbmpc_access_structure_t ac = {&dummy_node, 0, nullptr, 0, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEcdsaMpAc, NegDkgAcNodesCountNegative) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + const cbmpc_access_structure_node_t dummy_node = {CBMPC_ACCESS_STRUCTURE_NODE_AND, nullptr, 0, 0, 0}; + const cbmpc_access_structure_t ac = {&dummy_node, -1, nullptr, 0, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEcdsaMpAc, NegDkgAcChildIndicesNullWithPositiveCount) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 4}, {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p3", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 5, nullptr, 4, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEcdsaMpAc, NegDkgAcChildIndicesCountNegative) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + const int32_t ci[] = {1, 2, 3, 4}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 4}, {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p3", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 5, ci, -1, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEcdsaMpAc, NegDkgAcRootIndexNegative) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac_base = make_simple_ac_2of4(); + const cbmpc_access_structure_t ac = {ac_base.nodes, ac_base.nodes_count, ac_base.child_indices, + ac_base.child_indices_count, -1}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEcdsaMpAc, NegDkgAcRootIndexTooLarge) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac_base = make_simple_ac_2of4(); + const cbmpc_access_structure_t ac = {ac_base.nodes, ac_base.nodes_count, ac_base.child_indices, + ac_base.child_indices_count, 999}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEcdsaMpAc, NegDkgAcNullQuorum) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, nullptr, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEcdsaMpAc, NegDkgAcNegativeQuorumCount) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, -1, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEcdsaMpAc, NegDkgAcZeroQuorumCount) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 0, &out_key, &out_sid), + CBMPC_SUCCESS); +} + +// =========================================================================== +// Negative: Sign AC +// =========================================================================== + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcNullSigDerOut) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + uint8_t hash[32] = {}; + EXPECT_EQ(cbmpc_ecdsa_mp_sign_ac(&job, blobs_[0], &ac, cmem_t{hash, 32}, 0, nullptr), E_BADARG); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcNullJob) { + const auto ac = make_simple_ac_2of4(); + uint8_t hash[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_ac(nullptr, blobs_[0], &ac, cmem_t{hash, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcJobNullTransport) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, nullptr}; + const auto ac = make_simple_ac_2of4(); + uint8_t hash[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_sign_ac(&job, blobs_[0], &ac, cmem_t{hash, 32}, 0, &sig), E_BADARG); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcNullAccessStructure) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + uint8_t hash[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_sign_ac(&job, blobs_[0], nullptr, cmem_t{hash, 32}, 0, &sig), E_BADARG); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcGarbageKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + uint8_t hash[32] = {}; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_ac(&job, cmem_t{garbage, 4}, &ac, cmem_t{hash, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcEmptyKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + uint8_t hash[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_ac(&job, cmem_t{nullptr, 0}, &ac, cmem_t{hash, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcAllZeroKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + uint8_t hash[32] = {}; + uint8_t zeros[64] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_ac(&job, cmem_t{zeros, 64}, &ac, cmem_t{hash, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcNegativeSizeKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + uint8_t hash[32] = {}; + uint8_t data[] = {0x01}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_sign_ac(&job, cmem_t{data, -1}, &ac, cmem_t{hash, 32}, 0, &sig), E_BADARG); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcOversizedKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + uint8_t hash[32] = {}; + std::vector huge(1024 * 1024, 0x42); + cmem_t sig{nullptr, 0}; + EXPECT_NE( + cbmpc_ecdsa_mp_sign_ac(&job, cmem_t{huge.data(), static_cast(huge.size())}, &ac, cmem_t{hash, 32}, 0, &sig), + CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcEmptyMsgHash) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_ac(&job, blobs_[0], &ac, cmem_t{nullptr, 0}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcGarbageMsgHash) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + uint8_t garbage[] = {0xDE, 0xAD}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_ac(&job, blobs_[0], &ac, cmem_t{garbage, 2}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcNegativeSizeMsgHash) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + uint8_t hash[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_sign_ac(&job, blobs_[0], &ac, cmem_t{hash, -1}, 0, &sig), E_BADARG); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcOversizedMsgHash) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + uint8_t huge_hash[65]; + std::memset(huge_hash, 0x42, sizeof(huge_hash)); + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_ac(&job, blobs_[0], &ac, cmem_t{huge_hash, 65}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcSigReceiverNegative) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + uint8_t hash[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_ac(&job, blobs_[0], &ac, cmem_t{hash, 32}, -1, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegSignAcSigReceiverTooLarge) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + uint8_t hash[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_sign_ac(&job, blobs_[0], &ac, cmem_t{hash, 32}, 999, &sig), CBMPC_SUCCESS); +} + +// =========================================================================== +// Negative: Refresh AC +// =========================================================================== + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegRefreshAcNullOutNewKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, blobs_[0], &ac, quorum, 2, &sid_out, nullptr), + E_BADARG); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegRefreshAcNullJob) { + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_refresh_ac(nullptr, cmem_t{nullptr, 0}, blobs_[0], &ac, quorum, 2, &sid_out, &out_new), + CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegRefreshAcJobNullTransport) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, nullptr}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, blobs_[0], &ac, quorum, 2, &sid_out, &out_new), + E_BADARG); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegRefreshAcNullAccessStructure) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, blobs_[0], nullptr, quorum, 2, &sid_out, &out_new), + E_BADARG); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegRefreshAcGarbageKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, cmem_t{garbage, 4}, &ac, quorum, 2, &sid_out, &out_new), + CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegRefreshAcEmptyKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, &ac, quorum, 2, &sid_out, &out_new), + CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegRefreshAcAllZeroKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + uint8_t zeros[64] = {}; + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, cmem_t{zeros, 64}, &ac, quorum, 2, &sid_out, &out_new), + CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegRefreshAcOversizedKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + std::vector huge(1024 * 1024, 0x42); + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_NE(cbmpc_ecdsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, cmem_t{huge.data(), static_cast(huge.size())}, &ac, + quorum, 2, &sid_out, &out_new), + CBMPC_SUCCESS); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegRefreshAcNullQuorum) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, blobs_[0], &ac, nullptr, 2, &sid_out, &out_new), + E_BADARG); +} + +TEST_F(CApiEcdsaMpAcNegWithBlobs, NegRefreshAcNegativeQuorumCount) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of4(); + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_EQ(cbmpc_ecdsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, blobs_[0], &ac, quorum, -1, &sid_out, &out_new), + E_BADARG); +} diff --git a/tests/unit/c_api/test_eddsa2pc.cpp b/tests/unit/c_api/test_eddsa2pc.cpp new file mode 100644 index 00000000..424052f1 --- /dev/null +++ b/tests/unit/c_api/test_eddsa2pc.cpp @@ -0,0 +1,144 @@ +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::capi_harness::make_transport; +using coinbase::testutils::capi_harness::run_2pc; +using coinbase::testutils::capi_harness::transport_ctx_t; + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +} // namespace + +TEST(CApiEdDSA2pc, DkgSignRefreshSign) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + std::atomic free_calls_1{0}; + std::atomic free_calls_2{0}; + transport_ctx_t ctx1{c1, &free_calls_1}; + transport_ctx_t ctx2{c2, &free_calls_2}; + + const cbmpc_transport_t t1 = make_transport(&ctx1); + const cbmpc_transport_t t2 = make_transport(&ctx2); + + cmem_t key_blob_1{nullptr, 0}; + cmem_t key_blob_2{nullptr, 0}; + cbmpc_error_t rv1 = UNINITIALIZED_ERROR; + cbmpc_error_t rv2 = UNINITIALIZED_ERROR; + + const cbmpc_2pc_job_t job1 = {CBMPC_2PC_P1, "p1", "p2", &t1}; + const cbmpc_2pc_job_t job2 = {CBMPC_2PC_P2, "p1", "p2", &t2}; + run_2pc( + c1, c2, [&] { return cbmpc_eddsa_2p_dkg(&job1, CBMPC_CURVE_ED25519, &key_blob_1); }, + [&] { return cbmpc_eddsa_2p_dkg(&job2, CBMPC_CURVE_ED25519, &key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, CBMPC_SUCCESS); + ASSERT_EQ(rv2, CBMPC_SUCCESS); + ASSERT_GT(key_blob_1.size, 0); + ASSERT_GT(key_blob_2.size, 0); + + cmem_t pub1{nullptr, 0}; + cmem_t pub2{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_2p_get_public_key_compressed(key_blob_1, &pub1), CBMPC_SUCCESS); + ASSERT_EQ(cbmpc_eddsa_2p_get_public_key_compressed(key_blob_2, &pub2), CBMPC_SUCCESS); + expect_eq(pub1, pub2); + ASSERT_EQ(pub1.size, 32); + + const buf_t pub_buf(pub1.data, pub1.size); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_ed25519, pub_buf), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + + uint8_t msg_bytes[32]; + for (int i = 0; i < 32; i++) msg_bytes[i] = static_cast(i); + const cmem_t msg = {msg_bytes, 32}; + + cmem_t sig1{nullptr, 0}; + cmem_t sig2{nullptr, 0}; + run_2pc( + c1, c2, [&] { return cbmpc_eddsa_2p_sign(&job1, key_blob_1, msg, &sig1); }, + [&] { return cbmpc_eddsa_2p_sign(&job2, key_blob_2, msg, &sig2); }, rv1, rv2); + ASSERT_EQ(rv1, CBMPC_SUCCESS); + ASSERT_EQ(rv2, CBMPC_SUCCESS); + ASSERT_EQ(sig1.size, 64); + ASSERT_EQ(sig2.size, 0); + ASSERT_EQ(verify_key.verify(buf_t(msg_bytes, 32), buf_t(sig1.data, sig1.size)), SUCCESS); + + cmem_t new_key_blob_1{nullptr, 0}; + cmem_t new_key_blob_2{nullptr, 0}; + run_2pc( + c1, c2, [&] { return cbmpc_eddsa_2p_refresh(&job1, key_blob_1, &new_key_blob_1); }, + [&] { return cbmpc_eddsa_2p_refresh(&job2, key_blob_2, &new_key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, CBMPC_SUCCESS); + ASSERT_EQ(rv2, CBMPC_SUCCESS); + ASSERT_GT(new_key_blob_1.size, 0); + ASSERT_GT(new_key_blob_2.size, 0); + + cmem_t pub3{nullptr, 0}; + cmem_t pub4{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_2p_get_public_key_compressed(new_key_blob_1, &pub3), CBMPC_SUCCESS); + ASSERT_EQ(cbmpc_eddsa_2p_get_public_key_compressed(new_key_blob_2, &pub4), CBMPC_SUCCESS); + expect_eq(pub3, pub4); + expect_eq(pub1, pub3); + + cmem_t sig3{nullptr, 0}; + cmem_t sig4{nullptr, 0}; + run_2pc( + c1, c2, [&] { return cbmpc_eddsa_2p_sign(&job1, new_key_blob_1, msg, &sig3); }, + [&] { return cbmpc_eddsa_2p_sign(&job2, new_key_blob_2, msg, &sig4); }, rv1, rv2); + ASSERT_EQ(rv1, CBMPC_SUCCESS); + ASSERT_EQ(rv2, CBMPC_SUCCESS); + ASSERT_EQ(sig3.size, 64); + ASSERT_EQ(sig4.size, 0); + ASSERT_EQ(verify_key.verify(buf_t(msg_bytes, 32), buf_t(sig3.data, sig3.size)), SUCCESS); + + EXPECT_GT(free_calls_1.load(), 0); + EXPECT_GT(free_calls_2.load(), 0); + + cbmpc_cmem_free(pub1); + cbmpc_cmem_free(pub2); + cbmpc_cmem_free(pub3); + cbmpc_cmem_free(pub4); + cbmpc_cmem_free(sig1); + cbmpc_cmem_free(sig2); + cbmpc_cmem_free(sig3); + cbmpc_cmem_free(sig4); + cbmpc_cmem_free(key_blob_1); + cbmpc_cmem_free(key_blob_2); + cbmpc_cmem_free(new_key_blob_1); + cbmpc_cmem_free(new_key_blob_2); +} + +TEST(CApiEdDSA2pc, ValidatesArgs) { + cmem_t out{reinterpret_cast(0x1), 123}; + const cbmpc_2pc_job_t bad_job = {CBMPC_2PC_P1, "p1", "p2", nullptr}; + EXPECT_EQ(cbmpc_eddsa_2p_dkg(&bad_job, CBMPC_CURVE_ED25519, &out), E_BADARG); + EXPECT_EQ(out.data, nullptr); + EXPECT_EQ(out.size, 0); + + // Missing sig_out is invalid. + EXPECT_EQ(cbmpc_eddsa_2p_sign(nullptr, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, nullptr), E_BADARG); +} diff --git a/tests/unit/c_api/test_eddsa_mp.cpp b/tests/unit/c_api/test_eddsa_mp.cpp new file mode 100644 index 00000000..61ead377 --- /dev/null +++ b/tests/unit/c_api/test_eddsa_mp.cpp @@ -0,0 +1,933 @@ +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::capi_harness::make_transport; +using coinbase::testutils::capi_harness::run_mp; +using coinbase::testutils::capi_harness::transport_ctx_t; + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +} // namespace + +TEST(CApiEdDSAMp, DkgSignRefreshSignRoleChange4p) { + constexpr int n = 4; + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::atomic free_calls[n]; + transport_ctx_t ctx[n]; + cbmpc_transport_t transports[n]; + for (int i = 0; i < n; i++) { + free_calls[i].store(0); + ctx[i] = transport_ctx_t{peers[static_cast(i)], &free_calls[i]}; + transports[i] = make_transport(&ctx[i]); + } + + const char* party_names[n] = {"p0", "p1", "p2", "p3"}; + + std::vector key_blobs(n, cmem_t{nullptr, 0}); + std::vector sids(n, cmem_t{nullptr, 0}); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[i], + }; + return cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_ED25519, &key_blobs[static_cast(i)], + &sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) { + ASSERT_GT(key_blobs[static_cast(i)].size, 0); + ASSERT_GT(sids[static_cast(i)].size, 0); + } + for (int i = 1; i < n; i++) expect_eq(sids[0], sids[static_cast(i)]); + + cmem_t pub0{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_get_public_key_compressed(key_blobs[0], &pub0), CBMPC_SUCCESS); + ASSERT_EQ(pub0.size, 32); + for (int i = 1; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_get_public_key_compressed(key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + const buf_t pub_buf(pub0.data, pub0.size); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_ed25519, pub_buf), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + + // Change the party ordering ("role" indices) between protocols. + const char* party_names2[n] = {"p0", "p2", "p1", "p3"}; + // Map new role index -> old role index (DKG) for the same party name. + const int perm[n] = {0, 2, 1, 3}; + + uint8_t msg_bytes[32]; + for (int i = 0; i < 32; i++) msg_bytes[i] = static_cast(i); + const cmem_t msg = {msg_bytes, 32}; + + std::vector sigs(n, cmem_t{nullptr, 0}); + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names2, + /*party_names_count=*/n, + /*transport=*/&transports[i], + }; + return cbmpc_eddsa_mp_sign_additive(&job, key_blobs[static_cast(perm[i])], msg, /*sig_receiver=*/2, + &sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_EQ(sigs[2].size, 64); + for (int i = 0; i < n; i++) { + if (i == 2) continue; + ASSERT_EQ(sigs[static_cast(i)].size, 0); + } + ASSERT_EQ(verify_key.verify(buf_t(msg_bytes, 32), buf_t(sigs[2].data, sigs[2].size)), SUCCESS); + + std::vector new_key_blobs(n, cmem_t{nullptr, 0}); + std::vector sid_outs(n, cmem_t{nullptr, 0}); + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names2, + /*party_names_count=*/n, + /*transport=*/&transports[i], + }; + return cbmpc_eddsa_mp_refresh_additive( + &job, sids[static_cast(perm[i])], key_blobs[static_cast(perm[i])], + &sid_outs[static_cast(i)], &new_key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) ASSERT_GT(new_key_blobs[static_cast(i)].size, 0); + for (int i = 1; i < n; i++) expect_eq(sid_outs[0], sid_outs[static_cast(i)]); + expect_eq(sids[0], sid_outs[0]); + + for (int i = 0; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_get_public_key_compressed(new_key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + std::vector new_sigs(n, cmem_t{nullptr, 0}); + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names2, + /*party_names_count=*/n, + /*transport=*/&transports[i], + }; + return cbmpc_eddsa_mp_sign_additive(&job, new_key_blobs[static_cast(i)], msg, /*sig_receiver=*/2, + &new_sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_EQ(new_sigs[2].size, 64); + for (int i = 0; i < n; i++) { + if (i == 2) continue; + ASSERT_EQ(new_sigs[static_cast(i)].size, 0); + } + ASSERT_EQ(verify_key.verify(buf_t(msg_bytes, 32), buf_t(new_sigs[2].data, new_sigs[2].size)), SUCCESS); + + for (int i = 0; i < n; i++) EXPECT_GT(free_calls[i].load(), 0); + + cbmpc_cmem_free(pub0); + for (auto m : new_sigs) cbmpc_cmem_free(m); + for (auto m : sid_outs) cbmpc_cmem_free(m); + for (auto m : new_key_blobs) cbmpc_cmem_free(m); + for (auto m : sigs) cbmpc_cmem_free(m); + for (auto m : sids) cbmpc_cmem_free(m); + for (auto m : key_blobs) cbmpc_cmem_free(m); +} + +TEST(CApiEdDSAMp, ValidatesArgs) { + cmem_t key{reinterpret_cast(0x1), 123}; + cmem_t sid{reinterpret_cast(0x1), 123}; + + const cbmpc_transport_t bad_transport = {/*ctx=*/nullptr, /*send=*/nullptr, /*receive=*/nullptr, + /*receive_all=*/nullptr, + /*free=*/nullptr}; + const char* names[2] = {"p0", "p1"}; + const cbmpc_mp_job_t bad_job = {/*self=*/0, /*party_names=*/names, /*party_names_count=*/2, + /*transport=*/&bad_transport}; + + EXPECT_EQ(cbmpc_eddsa_mp_dkg_additive(&bad_job, CBMPC_CURVE_ED25519, &key, &sid), E_BADARG); + EXPECT_EQ(key.data, nullptr); + EXPECT_EQ(key.size, 0); + EXPECT_EQ(sid.data, nullptr); + EXPECT_EQ(sid.size, 0); + + // Missing sig_out is invalid. + EXPECT_EQ(cbmpc_eddsa_mp_sign_additive(nullptr, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, 0, nullptr), E_BADARG); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +namespace { + +static cbmpc_error_t noop_send(void*, int32_t, const uint8_t*, int) { return E_GENERAL; } +static cbmpc_error_t noop_receive(void*, int32_t, cmem_t*) { return E_GENERAL; } +static cbmpc_error_t noop_receive_all(void*, const int32_t*, int, cmems_t*) { return E_GENERAL; } + +static const cbmpc_transport_t noop_capi_transport = {nullptr, noop_send, noop_receive, noop_receive_all, nullptr}; + +static void capi_generate_eddsa_mp_key_blobs(int n, std::vector& blobs) { + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::vector ctxs(n); + std::vector transports(n); + for (int i = 0; i < n; i++) { + ctxs[i] = transport_ctx_t{peers[static_cast(i)], nullptr}; + transports[i] = make_transport(&ctxs[i]); + } + + std::vector names; + for (int i = 0; i < n; i++) names.push_back("p" + std::to_string(i)); + std::vector name_ptrs; + for (const auto& nm : names) name_ptrs.push_back(nm.c_str()); + + blobs.resize(n, cmem_t{nullptr, 0}); + std::vector sids(n, cmem_t{nullptr, 0}); + std::vector rvs; + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = {i, name_ptrs.data(), n, &transports[static_cast(i)]}; + return cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_ED25519, &blobs[static_cast(i)], + &sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (auto m : sids) cbmpc_cmem_free(m); +} + +} // namespace + +class CApiEdDSAMpNegWithBlobs : public ::testing::Test { + protected: + static void SetUpTestSuite() { capi_generate_eddsa_mp_key_blobs(3, blobs_); } + static void TearDownTestSuite() { + for (auto m : blobs_) cbmpc_cmem_free(m); + blobs_.clear(); + } + static std::vector blobs_; +}; +std::vector CApiEdDSAMpNegWithBlobs::blobs_; + +// ========================================================================== +// Negative: dkg +// ========================================================================== + +TEST(CApiEdDSAMp, NegDkgNullOutKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_capi_transport}; + cmem_t sid{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_ED25519, nullptr, &sid), E_BADARG); +} + +TEST(CApiEdDSAMp, NegDkgNullOutSid) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_capi_transport}; + cmem_t key{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_ED25519, &key, nullptr), E_BADARG); +} + +TEST(CApiEdDSAMp, NegDkgNullJob) { + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_additive(nullptr, CBMPC_CURVE_ED25519, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); + EXPECT_EQ(sid.data, nullptr); +} + +TEST(CApiEdDSAMp, NegDkgInvalidJobFields) { + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + + { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, nullptr}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_ED25519, &key, &sid), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.send = nullptr; + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &bad_t}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_ED25519, &key, &sid), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.receive = nullptr; + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &bad_t}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_ED25519, &key, &sid), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.receive_all = nullptr; + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &bad_t}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_ED25519, &key, &sid), E_BADARG); + } +} + +TEST(CApiEdDSAMp, NegDkgInvalidCurves) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_capi_transport}; + + { + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_SECP256K1, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); + } + { + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_P256, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); + } + for (int val : {0, 4, 255}) { + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_additive(&job, static_cast(val), &key, &sid), CBMPC_SUCCESS) + << "Expected failure for curve_id=" << val; + EXPECT_EQ(key.data, nullptr); + } +} + +TEST(CApiEdDSAMp, NegDkgInvalidParty) { + { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {3, names, 3, &noop_capi_transport}; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_ED25519, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); + } + { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {-1, names, 2, &noop_capi_transport}; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_ED25519, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); + } + { + const char* names[] = {"p0"}; + const cbmpc_mp_job_t job = {0, names, 1, &noop_capi_transport}; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_ED25519, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); + } + { + const cbmpc_mp_job_t job = {0, nullptr, 0, &noop_capi_transport}; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_ED25519, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); + } +} + +TEST(CApiEdDSAMp, NegDkgDuplicatePartyNames) { + const char* names[] = {"p0", "p0", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_additive(&job, CBMPC_CURVE_ED25519, &key, &sid), CBMPC_SUCCESS); + EXPECT_EQ(key.data, nullptr); +} + +// ========================================================================== +// Negative: get_public_key_compressed +// ========================================================================== + +TEST(CApiEdDSAMp, NegGetPubKeyNullOutput) { + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_eddsa_mp_get_public_key_compressed(cmem_t{dummy, 1}, nullptr), E_BADARG); +} + +TEST(CApiEdDSAMp, NegGetPubKeyBadBlob) { + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_get_public_key_compressed(cmem_t{garbage, 4}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t data[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_get_public_key_compressed(cmem_t{data, -1}, &out), E_BADARG); + EXPECT_EQ(out.data, nullptr); + } + { + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_get_public_key_compressed(cmem_t{nullptr, 10}, &out), E_BADARG); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t zeros[64] = {}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_get_public_key_compressed(cmem_t{zeros, 64}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t one = 0x00; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_get_public_key_compressed(cmem_t{&one, 1}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_get_public_key_compressed(cmem_t{nullptr, 0}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } +} + +TEST(CApiEdDSAMp, NegGetPubKeyOversizedBlob) { + uint8_t huge[4096]; + std::memset(huge, 0x42, sizeof(huge)); + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_get_public_key_compressed(cmem_t{huge, 4096}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +// ========================================================================== +// Negative: get_public_share_compressed +// ========================================================================== + +TEST(CApiEdDSAMp, NegGetPubShareNullOutput) { + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_eddsa_mp_get_public_share_compressed(cmem_t{dummy, 1}, nullptr), E_BADARG); +} + +TEST(CApiEdDSAMp, NegGetPubShareBadBlob) { + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_get_public_share_compressed(cmem_t{garbage, 4}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t data[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_get_public_share_compressed(cmem_t{data, -1}, &out), E_BADARG); + EXPECT_EQ(out.data, nullptr); + } + { + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_get_public_share_compressed(cmem_t{nullptr, 10}, &out), E_BADARG); + EXPECT_EQ(out.data, nullptr); + } + { + uint8_t zeros[64] = {}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_get_public_share_compressed(cmem_t{zeros, 64}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } + { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_get_public_share_compressed(cmem_t{nullptr, 0}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); + } +} + +TEST(CApiEdDSAMp, NegGetPubShareOversizedBlob) { + uint8_t huge[4096]; + std::memset(huge, 0x42, sizeof(huge)); + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_get_public_share_compressed(cmem_t{huge, 4096}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +// ========================================================================== +// Negative: detach_private_scalar +// ========================================================================== + +TEST(CApiEdDSAMp, NegDetachNullOutputs) { + uint8_t dummy[] = {0x01}; + cmem_t blob = {dummy, 1}; + cmem_t out1{nullptr, 0}; + cmem_t out2{nullptr, 0}; + + EXPECT_EQ(cbmpc_eddsa_mp_detach_private_scalar(blob, nullptr, &out2), E_BADARG); + EXPECT_EQ(cbmpc_eddsa_mp_detach_private_scalar(blob, &out1, nullptr), E_BADARG); + EXPECT_EQ(cbmpc_eddsa_mp_detach_private_scalar(blob, nullptr, nullptr), E_BADARG); +} + +TEST(CApiEdDSAMp, NegDetachBadBlob) { + { + uint8_t zeros[64] = {}; + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_detach_private_scalar(cmem_t{zeros, 64}, &pub, &scalar), CBMPC_SUCCESS); + EXPECT_EQ(pub.data, nullptr); + EXPECT_EQ(scalar.data, nullptr); + } + { + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_detach_private_scalar(cmem_t{nullptr, 0}, &pub, &scalar), CBMPC_SUCCESS); + EXPECT_EQ(pub.data, nullptr); + EXPECT_EQ(scalar.data, nullptr); + } + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_detach_private_scalar(cmem_t{garbage, 4}, &pub, &scalar), CBMPC_SUCCESS); + EXPECT_EQ(pub.data, nullptr); + EXPECT_EQ(scalar.data, nullptr); + } + { + uint8_t data[] = {0x01}; + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_detach_private_scalar(cmem_t{data, -1}, &pub, &scalar), E_BADARG); + EXPECT_EQ(pub.data, nullptr); + EXPECT_EQ(scalar.data, nullptr); + } +} + +// ========================================================================== +// Negative: attach_private_scalar +// ========================================================================== + +TEST(CApiEdDSAMp, NegAttachNullOutput) { + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_eddsa_mp_attach_private_scalar(cmem_t{dummy, 1}, cmem_t{dummy, 1}, cmem_t{dummy, 1}, nullptr), + E_BADARG); +} + +TEST(CApiEdDSAMp, NegAttachBadCmemInputs) { + cmem_t out{nullptr, 0}; + + { + uint8_t scalar[] = {0x01}; + uint8_t point[32] = {}; + EXPECT_NE(cbmpc_eddsa_mp_attach_private_scalar(cmem_t{nullptr, 0}, cmem_t{scalar, 1}, cmem_t{point, 32}, &out), + CBMPC_SUCCESS); + } + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + uint8_t scalar[] = {0x01}; + uint8_t point[32] = {}; + EXPECT_NE(cbmpc_eddsa_mp_attach_private_scalar(cmem_t{garbage, 4}, cmem_t{scalar, 1}, cmem_t{point, 32}, &out), + CBMPC_SUCCESS); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_eddsa_mp_attach_private_scalar(cmem_t{data, -1}, cmem_t{data, 1}, cmem_t{data, 1}, &out), E_BADARG); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_eddsa_mp_attach_private_scalar(cmem_t{data, 1}, cmem_t{data, -1}, cmem_t{data, 1}, &out), E_BADARG); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_eddsa_mp_attach_private_scalar(cmem_t{data, 1}, cmem_t{data, 1}, cmem_t{data, -1}, &out), E_BADARG); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_eddsa_mp_attach_private_scalar(cmem_t{nullptr, 10}, cmem_t{data, 1}, cmem_t{data, 1}, &out), + E_BADARG); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_eddsa_mp_attach_private_scalar(cmem_t{data, 1}, cmem_t{nullptr, 10}, cmem_t{data, 1}, &out), + E_BADARG); + } + { + uint8_t data[] = {0x01}; + EXPECT_EQ(cbmpc_eddsa_mp_attach_private_scalar(cmem_t{data, 1}, cmem_t{data, 1}, cmem_t{nullptr, 10}, &out), + E_BADARG); + } +} + +// ========================================================================== +// Negative: sign_additive +// ========================================================================== + +TEST(CApiEdDSAMp, NegSignNullSigOutput) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t msg[32] = {}; + EXPECT_EQ(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{nullptr, 0}, cmem_t{msg, 32}, 0, nullptr), E_BADARG); +} + +TEST(CApiEdDSAMp, NegSignNullJob) { + uint8_t msg[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_additive(nullptr, cmem_t{nullptr, 0}, cmem_t{msg, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST(CApiEdDSAMp, NegSignInvalidJob) { + uint8_t msg[32] = {}; + + { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, nullptr}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{nullptr, 0}, cmem_t{msg, 32}, 0, &sig), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.send = nullptr; + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &bad_t}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{nullptr, 0}, cmem_t{msg, 32}, 0, &sig), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.receive = nullptr; + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &bad_t}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{nullptr, 0}, cmem_t{msg, 32}, 0, &sig), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.receive_all = nullptr; + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &bad_t}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{nullptr, 0}, cmem_t{msg, 32}, 0, &sig), E_BADARG); + } +} + +TEST(CApiEdDSAMp, NegSignBadKeyBlob) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t msg[32] = {}; + + { + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{nullptr, 0}, cmem_t{msg, 32}, 0, &sig), CBMPC_SUCCESS); + } + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{garbage, 4}, cmem_t{msg, 32}, 0, &sig), CBMPC_SUCCESS); + } + { + uint8_t data[] = {0x01}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{data, -1}, cmem_t{msg, 32}, 0, &sig), E_BADARG); + } + { + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{nullptr, 10}, cmem_t{msg, 32}, 0, &sig), E_BADARG); + } + { + uint8_t zeros[64] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{zeros, 64}, cmem_t{msg, 32}, 0, &sig), CBMPC_SUCCESS); + } +} + +TEST(CApiEdDSAMp, NegSignBadMsg) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t dummy_blob[] = {0x01}; + + { + uint8_t data[] = {0x01}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{dummy_blob, 1}, cmem_t{data, -1}, 0, &sig), E_BADARG); + } + { + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{dummy_blob, 1}, cmem_t{nullptr, 10}, 0, &sig), E_BADARG); + } +} + +TEST(CApiEdDSAMp, NegSignOversizedKeyBlob) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t msg[32] = {}; + uint8_t huge[4096]; + std::memset(huge, 0x42, sizeof(huge)); + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{huge, 4096}, cmem_t{msg, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST(CApiEdDSAMp, NegSignInvalidSigReceiver) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t msg[32] = {}; + uint8_t dummy[] = {0x01}; + + { + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{dummy, 1}, cmem_t{msg, 32}, -1, &sig), CBMPC_SUCCESS); + } + { + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{dummy, 1}, cmem_t{msg, 32}, 100, &sig), CBMPC_SUCCESS); + } +} + +TEST_F(CApiEdDSAMpNegWithBlobs, NegSignEmptyKeyBlob) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t msg[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{nullptr, 0}, cmem_t{msg, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEdDSAMpNegWithBlobs, NegSignAllZeroKeyBlob) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t msg[32] = {}; + uint8_t zeros[256] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_additive(&job, cmem_t{zeros, 256}, cmem_t{msg, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEdDSAMpNegWithBlobs, NegSignEmptyMsg) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_additive(&job, blobs_[0], cmem_t{nullptr, 0}, 0, &sig), CBMPC_SUCCESS); +} + +// ========================================================================== +// Negative: refresh_additive +// ========================================================================== + +TEST(CApiEdDSAMp, NegRefreshNullOutput) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_eddsa_mp_refresh_additive(&job, cmem_t{nullptr, 0}, cmem_t{dummy, 1}, nullptr, nullptr), E_BADARG); +} + +TEST(CApiEdDSAMp, NegRefreshNullJob) { + uint8_t dummy[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_refresh_additive(nullptr, cmem_t{nullptr, 0}, cmem_t{dummy, 1}, nullptr, &out), + CBMPC_SUCCESS); +} + +TEST(CApiEdDSAMp, NegRefreshInvalidJob) { + uint8_t dummy[] = {0x01}; + uint8_t sid_data[32] = {}; + + { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, nullptr}; + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_refresh_additive(&job, cmem_t{sid_data, 32}, cmem_t{dummy, 1}, nullptr, &out), E_BADARG); + } + { + cbmpc_transport_t bad_t = noop_capi_transport; + bad_t.send = nullptr; + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &bad_t}; + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_refresh_additive(&job, cmem_t{sid_data, 32}, cmem_t{dummy, 1}, nullptr, &out), E_BADARG); + } +} + +TEST(CApiEdDSAMp, NegRefreshBadKeyBlob) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + + { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_refresh_additive(&job, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, nullptr, &out), + CBMPC_SUCCESS); + } + { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_refresh_additive(&job, cmem_t{nullptr, 0}, cmem_t{garbage, 4}, nullptr, &out), + CBMPC_SUCCESS); + } + { + uint8_t data[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_refresh_additive(&job, cmem_t{nullptr, 0}, cmem_t{data, -1}, nullptr, &out), E_BADARG); + } + { + uint8_t zeros[64] = {}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_refresh_additive(&job, cmem_t{nullptr, 0}, cmem_t{zeros, 64}, nullptr, &out), + CBMPC_SUCCESS); + } +} + +TEST(CApiEdDSAMp, NegRefreshOversizedKeyBlob) { + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t job = {0, names, 3, &noop_capi_transport}; + uint8_t huge[4096]; + std::memset(huge, 0x42, sizeof(huge)); + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_refresh_additive(&job, cmem_t{nullptr, 0}, cmem_t{huge, 4096}, nullptr, &out), + CBMPC_SUCCESS); +} + +// ========================================================================== +// Negative (fixture): attach_private_scalar +// ========================================================================== + +TEST_F(CApiEdDSAMpNegWithBlobs, NegAttachEmptyScalar) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_detach_private_scalar(blobs_[0], &pub, &x), CBMPC_SUCCESS); + + cmem_t Qi{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_get_public_share_compressed(blobs_[0], &Qi), CBMPC_SUCCESS); + + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_attach_private_scalar(pub, cmem_t{nullptr, 0}, Qi, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); + cbmpc_cmem_free(Qi); +} + +TEST_F(CApiEdDSAMpNegWithBlobs, NegAttachZeroScalar) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_detach_private_scalar(blobs_[0], &pub, &x), CBMPC_SUCCESS); + + cmem_t Qi{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_get_public_share_compressed(blobs_[0], &Qi), CBMPC_SUCCESS); + + uint8_t zero[32] = {}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_attach_private_scalar(pub, cmem_t{zero, 32}, Qi, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); + cbmpc_cmem_free(Qi); +} + +TEST_F(CApiEdDSAMpNegWithBlobs, NegAttachGarbageScalar) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_detach_private_scalar(blobs_[0], &pub, &x), CBMPC_SUCCESS); + + cmem_t Qi{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_get_public_share_compressed(blobs_[0], &Qi), CBMPC_SUCCESS); + + uint8_t garbage[512]; + std::memset(garbage, 0xFF, sizeof(garbage)); + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_attach_private_scalar(pub, cmem_t{garbage, 512}, Qi, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); + cbmpc_cmem_free(Qi); +} + +TEST_F(CApiEdDSAMpNegWithBlobs, NegAttachWrongSizeScalar) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_detach_private_scalar(blobs_[0], &pub, &x), CBMPC_SUCCESS); + + cmem_t Qi{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_get_public_share_compressed(blobs_[0], &Qi), CBMPC_SUCCESS); + + { + uint8_t short_scalar[31]; + std::memset(short_scalar, 0x42, sizeof(short_scalar)); + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_attach_private_scalar(pub, cmem_t{short_scalar, 31}, Qi, &out), CBMPC_SUCCESS); + } + { + uint8_t long_scalar[33]; + std::memset(long_scalar, 0x42, sizeof(long_scalar)); + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_attach_private_scalar(pub, cmem_t{long_scalar, 33}, Qi, &out), CBMPC_SUCCESS); + } + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); + cbmpc_cmem_free(Qi); +} + +TEST_F(CApiEdDSAMpNegWithBlobs, NegAttachEmptyPublicShare) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_detach_private_scalar(blobs_[0], &pub, &x), CBMPC_SUCCESS); + + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_attach_private_scalar(pub, x, cmem_t{nullptr, 0}, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); +} + +TEST_F(CApiEdDSAMpNegWithBlobs, NegAttachAllZeroPublicShare) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_detach_private_scalar(blobs_[0], &pub, &x), CBMPC_SUCCESS); + + uint8_t zero_share[32] = {}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_attach_private_scalar(pub, x, cmem_t{zero_share, 32}, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); +} + +TEST_F(CApiEdDSAMpNegWithBlobs, NegAttachGarbagePublicShare) { + cmem_t pub{nullptr, 0}; + cmem_t x{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_detach_private_scalar(blobs_[0], &pub, &x), CBMPC_SUCCESS); + + uint8_t garbage_share[32]; + std::memset(garbage_share, 0xFF, sizeof(garbage_share)); + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_attach_private_scalar(pub, x, cmem_t{garbage_share, 32}, &out), CBMPC_SUCCESS); + + cbmpc_cmem_free(pub); + cbmpc_cmem_free(x); +} diff --git a/tests/unit/c_api/test_eddsa_mp_ac.cpp b/tests/unit/c_api/test_eddsa_mp_ac.cpp new file mode 100644 index 00000000..5c7463f9 --- /dev/null +++ b/tests/unit/c_api/test_eddsa_mp_ac.cpp @@ -0,0 +1,777 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::capi_harness::make_transport; +using coinbase::testutils::capi_harness::run_mp; +using coinbase::testutils::capi_harness::transport_ctx_t; + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +static void make_peers(int n, std::vector>& peers) { + peers.clear(); + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); +} + +static void make_transports(const std::vector>& peers, + std::vector& ctxs, std::vector& transports) { + ctxs.resize(peers.size()); + transports.resize(peers.size()); + for (size_t i = 0; i < peers.size(); i++) { + ctxs[i] = transport_ctx_t{peers[i], /*free_calls=*/nullptr}; + transports[i] = make_transport(&ctxs[i]); + } +} + +} // namespace + +TEST(CApiEdDSAMpAc, DkgRefreshSign2of3) { + constexpr int n = 3; + + // Full 3-party network for threshold DKG/refresh. + std::vector> peers; + make_peers(n, peers); + + std::vector ctxs; + std::vector transports; + make_transports(peers, ctxs, transports); + + const char* party_names[n] = {"p0", "p1", "p2"}; + + // Access structure: THRESHOLD[2](p0, p1, p2) + const int32_t child_indices[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, /*leaf_name=*/nullptr, /*k=*/2, /*off=*/0, /*cnt=*/3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p0", /*k=*/0, /*off=*/0, /*cnt=*/0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p1", /*k=*/0, /*off=*/0, /*cnt=*/0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p2", /*k=*/0, /*off=*/0, /*cnt=*/0}, + }; + const cbmpc_access_structure_t ac = { + /*nodes=*/nodes, + /*nodes_count=*/static_cast(sizeof(nodes) / sizeof(nodes[0])), + /*child_indices=*/child_indices, + /*child_indices_count=*/static_cast(sizeof(child_indices) / sizeof(child_indices[0])), + /*root_index=*/0, + }; + + // Only p0 and p1 actively contribute to DKG/refresh. + const char* quorum[] = {"p0", "p1"}; + + std::vector key_blobs(n, cmem_t{nullptr, 0}); + std::vector sids(n, cmem_t{nullptr, 0}); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[static_cast(i)], + }; + return cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, /*sid_in=*/cmem_t{nullptr, 0}, &ac, quorum, + /*quorum_party_names_count=*/2, &key_blobs[static_cast(i)], + &sids[static_cast(i)]); + }, + rvs); + + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) { + ASSERT_GT(key_blobs[static_cast(i)].size, 0); + ASSERT_GT(sids[static_cast(i)].size, 0); + } + for (int i = 1; i < n; i++) expect_eq(sids[0], sids[static_cast(i)]); + + cmem_t pub0{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_get_public_key_compressed(key_blobs[0], &pub0), CBMPC_SUCCESS); + ASSERT_EQ(pub0.size, 32); + for (int i = 1; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_get_public_key_compressed(key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + const buf_t pub_buf(pub0.data, pub0.size); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_ed25519, pub_buf), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + + uint8_t msg_bytes[32]; + for (int i = 0; i < 32; i++) msg_bytes[i] = static_cast(0x11 + i); + const cmem_t msg = {msg_bytes, 32}; + + // Signing quorum: {p0, p1} + const char* sign_party_names[2] = {"p0", "p1"}; + const cmem_t sign_key_blobs[2] = {key_blobs[0], key_blobs[1]}; + + { + std::vector> sign_peers; + make_peers(2, sign_peers); + + std::vector sign_ctxs; + std::vector sign_transports; + make_transports(sign_peers, sign_ctxs, sign_transports); + + std::vector sigs(2, cmem_t{nullptr, 0}); + run_mp( + sign_peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/sign_party_names, + /*party_names_count=*/2, + /*transport=*/&sign_transports[static_cast(i)], + }; + return cbmpc_eddsa_mp_sign_ac(&job, sign_key_blobs[static_cast(i)], &ac, msg, /*sig_receiver=*/0, + &sigs[static_cast(i)]); + }, + rvs); + + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_EQ(sigs[0].size, 64); + EXPECT_EQ(sigs[1].size, 0); + ASSERT_EQ(verify_key.verify(buf_t(msg_bytes, 32), buf_t(sigs[0].data, sigs[0].size)), SUCCESS); + + for (auto m : sigs) cbmpc_cmem_free(m); + } + + // Threshold refresh. + std::vector new_key_blobs(n, cmem_t{nullptr, 0}); + std::vector refresh_sids(n, cmem_t{nullptr, 0}); + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[static_cast(i)], + }; + return cbmpc_eddsa_mp_refresh_ac(&job, /*sid_in=*/cmem_t{nullptr, 0}, key_blobs[static_cast(i)], &ac, + quorum, /*quorum_party_names_count=*/2, &refresh_sids[static_cast(i)], + &new_key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) ASSERT_GT(new_key_blobs[static_cast(i)].size, 0); + for (int i = 1; i < n; i++) expect_eq(refresh_sids[0], refresh_sids[static_cast(i)]); + + for (int i = 0; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_get_public_key_compressed(new_key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + const cmem_t sign_new_key_blobs[2] = {new_key_blobs[0], new_key_blobs[1]}; + + { + std::vector> sign_peers; + make_peers(2, sign_peers); + + std::vector sign_ctxs; + std::vector sign_transports; + make_transports(sign_peers, sign_ctxs, sign_transports); + + std::vector sigs(2, cmem_t{nullptr, 0}); + run_mp( + sign_peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/sign_party_names, + /*party_names_count=*/2, + /*transport=*/&sign_transports[static_cast(i)], + }; + return cbmpc_eddsa_mp_sign_ac(&job, sign_new_key_blobs[static_cast(i)], &ac, msg, /*sig_receiver=*/0, + &sigs[static_cast(i)]); + }, + rvs); + + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_EQ(sigs[0].size, 64); + EXPECT_EQ(sigs[1].size, 0); + ASSERT_EQ(verify_key.verify(buf_t(msg_bytes, 32), buf_t(sigs[0].data, sigs[0].size)), SUCCESS); + + for (auto m : sigs) cbmpc_cmem_free(m); + } + + cbmpc_cmem_free(pub0); + for (auto m : refresh_sids) cbmpc_cmem_free(m); + for (auto m : new_key_blobs) cbmpc_cmem_free(m); + for (auto m : sids) cbmpc_cmem_free(m); + for (auto m : key_blobs) cbmpc_cmem_free(m); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +// =========================================================================== +// Helpers for negative tests +// =========================================================================== + +namespace { + +static const cbmpc_transport_t noop_transport = { + nullptr, + [](void*, int32_t, const uint8_t*, int) -> cbmpc_error_t { return E_GENERAL; }, + [](void*, int32_t, cmem_t*) -> cbmpc_error_t { return E_GENERAL; }, + [](void*, const int32_t*, int, cmems_t*) -> cbmpc_error_t { return E_GENERAL; }, + nullptr, +}; + +static cbmpc_access_structure_t make_simple_ac_2of3() { + static const int32_t ci[] = {1, 2, 3}; + static const cbmpc_access_structure_node_t nd[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + return {nd, 4, ci, 3, 0}; +} + +static void capi_generate_eddsa_ac_key_blobs(int n, std::vector& blobs) { + std::vector> peers; + make_peers(n, peers); + + std::vector ctxs; + std::vector transports; + make_transports(peers, ctxs, transports); + + std::vector names; + for (int i = 0; i < n; i++) names.push_back("p" + std::to_string(i)); + std::vector name_ptrs; + for (const auto& nm : names) name_ptrs.push_back(nm.c_str()); + + const cbmpc_access_structure_t ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + + blobs.resize(n, cmem_t{nullptr, 0}); + std::vector sids(n, cmem_t{nullptr, 0}); + std::vector rvs; + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = {i, name_ptrs.data(), n, &transports[static_cast(i)]}; + return cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, 2, + &blobs[static_cast(i)], &sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (auto m : sids) cbmpc_cmem_free(m); +} + +} // namespace + +class CApiEdDSAMpAcNegWithBlobs : public ::testing::Test { + protected: + static void SetUpTestSuite() { capi_generate_eddsa_ac_key_blobs(3, blobs_); } + static void TearDownTestSuite() { + for (auto m : blobs_) cbmpc_cmem_free(m); + blobs_.clear(); + } + static std::vector blobs_; +}; +std::vector CApiEdDSAMpAcNegWithBlobs::blobs_; + +// =========================================================================== +// Negative: DKG AC +// =========================================================================== + +TEST(CApiEdDSAMpAc, NegDkgAcNullOutKey) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_sid{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, 2, nullptr, &out_sid), + E_BADARG); +} + +TEST(CApiEdDSAMpAc, NegDkgAcNullOutSid) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, nullptr), + E_BADARG); +} + +TEST(CApiEdDSAMpAc, NegDkgAcNullJob) { + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_ac(nullptr, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + CBMPC_SUCCESS); +} + +TEST(CApiEdDSAMpAc, NegDkgAcJobNullTransport) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, nullptr}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEdDSAMpAc, NegDkgAcInvalidCurveSecp256k1) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + CBMPC_SUCCESS); +} + +TEST(CApiEdDSAMpAc, NegDkgAcInvalidCurveP256) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_P256, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + CBMPC_SUCCESS); +} + +TEST(CApiEdDSAMpAc, NegDkgAcInvalidCurveZero) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_ac(&job, static_cast(0), cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, + &out_sid), + CBMPC_SUCCESS); +} + +TEST(CApiEdDSAMpAc, NegDkgAcInvalidCurve4) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_ac(&job, static_cast(4), cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, + &out_sid), + CBMPC_SUCCESS); +} + +TEST(CApiEdDSAMpAc, NegDkgAcInvalidCurve255) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_ac(&job, static_cast(255), cmem_t{nullptr, 0}, &ac, quorum, 2, + &out_key, &out_sid), + CBMPC_SUCCESS); +} + +TEST(CApiEdDSAMpAc, NegDkgAcNullAccessStructure) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_EQ( + cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, nullptr, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEdDSAMpAc, NegDkgAcNodesNull) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + const cbmpc_access_structure_t ac = {nullptr, 4, nullptr, 0, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEdDSAMpAc, NegDkgAcNodesCountZero) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + const cbmpc_access_structure_node_t dummy_node = {CBMPC_ACCESS_STRUCTURE_NODE_AND, nullptr, 0, 0, 0}; + const cbmpc_access_structure_t ac = {&dummy_node, 0, nullptr, 0, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEdDSAMpAc, NegDkgAcNodesCountNegative) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + const cbmpc_access_structure_node_t dummy_node = {CBMPC_ACCESS_STRUCTURE_NODE_AND, nullptr, 0, 0, 0}; + const cbmpc_access_structure_t ac = {&dummy_node, -1, nullptr, 0, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEdDSAMpAc, NegDkgAcChildIndicesNullWithPositiveCount) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, nullptr, 3, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEdDSAMpAc, NegDkgAcChildIndicesCountNegative) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + const int32_t ci[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, ci, -1, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEdDSAMpAc, NegDkgAcRootIndexNegative) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac_base = make_simple_ac_2of3(); + const cbmpc_access_structure_t ac = {ac_base.nodes, ac_base.nodes_count, ac_base.child_indices, + ac_base.child_indices_count, -1}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEdDSAMpAc, NegDkgAcRootIndexTooLarge) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac_base = make_simple_ac_2of3(); + const cbmpc_access_structure_t ac = {ac_base.nodes, ac_base.nodes_count, ac_base.child_indices, + ac_base.child_indices_count, 999}; + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEdDSAMpAc, NegDkgAcNullQuorum) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, nullptr, 2, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEdDSAMpAc, NegDkgAcNegativeQuorumCount) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, -1, &out_key, &out_sid), + E_BADARG); +} + +TEST(CApiEdDSAMpAc, NegDkgAcZeroQuorumCount) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t out_key{nullptr, 0}; + cmem_t out_sid{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, cmem_t{nullptr, 0}, &ac, quorum, 0, &out_key, &out_sid), + CBMPC_SUCCESS); +} + +// =========================================================================== +// Negative: Sign AC +// =========================================================================== + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegSignAcNullSigOut) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + uint8_t msg[32] = {}; + EXPECT_EQ(cbmpc_eddsa_mp_sign_ac(&job, blobs_[0], &ac, cmem_t{msg, 32}, 0, nullptr), E_BADARG); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegSignAcNullJob) { + const auto ac = make_simple_ac_2of3(); + uint8_t msg[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_ac(nullptr, blobs_[0], &ac, cmem_t{msg, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegSignAcJobNullTransport) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, nullptr}; + const auto ac = make_simple_ac_2of3(); + uint8_t msg[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_sign_ac(&job, blobs_[0], &ac, cmem_t{msg, 32}, 0, &sig), E_BADARG); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegSignAcNullAccessStructure) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + uint8_t msg[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_sign_ac(&job, blobs_[0], nullptr, cmem_t{msg, 32}, 0, &sig), E_BADARG); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegSignAcGarbageKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + uint8_t msg[32] = {}; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_ac(&job, cmem_t{garbage, 4}, &ac, cmem_t{msg, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegSignAcEmptyKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + uint8_t msg[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_ac(&job, cmem_t{nullptr, 0}, &ac, cmem_t{msg, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegSignAcAllZeroKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + uint8_t msg[32] = {}; + uint8_t zeros[64] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_ac(&job, cmem_t{zeros, 64}, &ac, cmem_t{msg, 32}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegSignAcNegativeSizeKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + uint8_t msg[32] = {}; + uint8_t data[] = {0x01}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_sign_ac(&job, cmem_t{data, -1}, &ac, cmem_t{msg, 32}, 0, &sig), E_BADARG); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegSignAcOversizedKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + uint8_t msg[32] = {}; + std::vector huge(1024 * 1024, 0x42); + cmem_t sig{nullptr, 0}; + EXPECT_NE( + cbmpc_eddsa_mp_sign_ac(&job, cmem_t{huge.data(), static_cast(huge.size())}, &ac, cmem_t{msg, 32}, 0, &sig), + CBMPC_SUCCESS); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegSignAcEmptyMsg) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_ac(&job, blobs_[0], &ac, cmem_t{nullptr, 0}, 0, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegSignAcNegativeSizeMsg) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + uint8_t msg[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_sign_ac(&job, blobs_[0], &ac, cmem_t{msg, -1}, 0, &sig), E_BADARG); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegSignAcSigReceiverNegative) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + uint8_t msg[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_ac(&job, blobs_[0], &ac, cmem_t{msg, 32}, -1, &sig), CBMPC_SUCCESS); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegSignAcSigReceiverTooLarge) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + uint8_t msg[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_sign_ac(&job, blobs_[0], &ac, cmem_t{msg, 32}, 999, &sig), CBMPC_SUCCESS); +} + +// =========================================================================== +// Negative: Refresh AC +// =========================================================================== + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegRefreshAcNullOutNewKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, blobs_[0], &ac, quorum, 2, &sid_out, nullptr), + E_BADARG); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegRefreshAcNullJob) { + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_refresh_ac(nullptr, cmem_t{nullptr, 0}, blobs_[0], &ac, quorum, 2, &sid_out, &out_new), + CBMPC_SUCCESS); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegRefreshAcJobNullTransport) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, nullptr}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, blobs_[0], &ac, quorum, 2, &sid_out, &out_new), + E_BADARG); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegRefreshAcNullAccessStructure) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, blobs_[0], nullptr, quorum, 2, &sid_out, &out_new), + E_BADARG); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegRefreshAcGarbageKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, cmem_t{garbage, 4}, &ac, quorum, 2, &sid_out, &out_new), + CBMPC_SUCCESS); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegRefreshAcEmptyKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, &ac, quorum, 2, &sid_out, &out_new), + CBMPC_SUCCESS); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegRefreshAcAllZeroKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + uint8_t zeros[64] = {}; + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, cmem_t{zeros, 64}, &ac, quorum, 2, &sid_out, &out_new), + CBMPC_SUCCESS); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegRefreshAcOversizedKeyBlob) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + std::vector huge(1024 * 1024, 0x42); + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_NE(cbmpc_eddsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, cmem_t{huge.data(), static_cast(huge.size())}, &ac, + quorum, 2, &sid_out, &out_new), + CBMPC_SUCCESS); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegRefreshAcNullQuorum) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, blobs_[0], &ac, nullptr, 2, &sid_out, &out_new), + E_BADARG); +} + +TEST_F(CApiEdDSAMpAcNegWithBlobs, NegRefreshAcNegativeQuorumCount) { + const char* names[] = {"p0", "p1"}; + const cbmpc_mp_job_t job = {0, names, 2, &noop_transport}; + const auto ac = make_simple_ac_2of3(); + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + cmem_t out_new{nullptr, 0}; + EXPECT_EQ(cbmpc_eddsa_mp_refresh_ac(&job, cmem_t{nullptr, 0}, blobs_[0], &ac, quorum, -1, &sid_out, &out_new), + E_BADARG); +} diff --git a/tests/unit/c_api/test_eddsa_mp_threshold.cpp b/tests/unit/c_api/test_eddsa_mp_threshold.cpp new file mode 100644 index 00000000..9bf15583 --- /dev/null +++ b/tests/unit/c_api/test_eddsa_mp_threshold.cpp @@ -0,0 +1,230 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::capi_harness::make_transport; +using coinbase::testutils::capi_harness::run_mp; +using coinbase::testutils::capi_harness::transport_ctx_t; + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +static void make_peers(int n, std::vector>& peers) { + peers.clear(); + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); +} + +static void make_transports(const std::vector>& peers, + std::vector& ctxs, std::vector& transports) { + ctxs.resize(peers.size()); + transports.resize(peers.size()); + for (size_t i = 0; i < peers.size(); i++) { + ctxs[i] = transport_ctx_t{peers[i], /*free_calls=*/nullptr}; + transports[i] = make_transport(&ctxs[i]); + } +} + +} // namespace + +TEST(CApiEdDSAMpThreshold, DkgRefreshSign2of3) { + constexpr int n = 3; + + // Full 3-party network for threshold DKG/refresh. + std::vector> peers; + make_peers(n, peers); + + std::vector ctxs; + std::vector transports; + make_transports(peers, ctxs, transports); + + const char* party_names[n] = {"p0", "p1", "p2"}; + + // Access structure: THRESHOLD[2](p0, p1, p2) + const int32_t child_indices[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, /*leaf_name=*/nullptr, /*k=*/2, /*off=*/0, /*cnt=*/3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p0", /*k=*/0, /*off=*/0, /*cnt=*/0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p1", /*k=*/0, /*off=*/0, /*cnt=*/0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p2", /*k=*/0, /*off=*/0, /*cnt=*/0}, + }; + const cbmpc_access_structure_t ac = { + /*nodes=*/nodes, + /*nodes_count=*/static_cast(sizeof(nodes) / sizeof(nodes[0])), + /*child_indices=*/child_indices, + /*child_indices_count=*/static_cast(sizeof(child_indices) / sizeof(child_indices[0])), + /*root_index=*/0, + }; + + // Only p0 and p1 actively contribute to DKG/refresh. + const char* quorum[] = {"p0", "p1"}; + + std::vector key_blobs(n, cmem_t{nullptr, 0}); + std::vector sids(n, cmem_t{nullptr, 0}); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[static_cast(i)], + }; + return cbmpc_eddsa_mp_dkg_ac(&job, CBMPC_CURVE_ED25519, /*sid_in=*/cmem_t{nullptr, 0}, &ac, quorum, + /*quorum_party_names_count=*/2, &key_blobs[static_cast(i)], + &sids[static_cast(i)]); + }, + rvs); + + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) { + ASSERT_GT(key_blobs[static_cast(i)].size, 0); + ASSERT_GT(sids[static_cast(i)].size, 0); + } + for (int i = 1; i < n; i++) expect_eq(sids[0], sids[static_cast(i)]); + + cmem_t pub0{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_get_public_key_compressed(key_blobs[0], &pub0), CBMPC_SUCCESS); + ASSERT_EQ(pub0.size, 32); + for (int i = 1; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_get_public_key_compressed(key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + const buf_t pub_buf(pub0.data, pub0.size); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_ed25519, pub_buf), SUCCESS); + const coinbase::crypto::ecc_pub_key_t verify_key(Q); + + uint8_t msg_bytes[32]; + for (int i = 0; i < 32; i++) msg_bytes[i] = static_cast(0x11 + i); + const cmem_t msg = {msg_bytes, 32}; + + // Signing quorum: {p0, p1} + const char* sign_party_names[2] = {"p0", "p1"}; + const cmem_t sign_key_blobs[2] = {key_blobs[0], key_blobs[1]}; + + { + std::vector> sign_peers; + make_peers(2, sign_peers); + + std::vector sign_ctxs; + std::vector sign_transports; + make_transports(sign_peers, sign_ctxs, sign_transports); + + std::vector sigs(2, cmem_t{nullptr, 0}); + run_mp( + sign_peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/sign_party_names, + /*party_names_count=*/2, + /*transport=*/&sign_transports[static_cast(i)], + }; + return cbmpc_eddsa_mp_sign_ac(&job, sign_key_blobs[static_cast(i)], &ac, msg, /*sig_receiver=*/0, + &sigs[static_cast(i)]); + }, + rvs); + + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_EQ(sigs[0].size, 64); + EXPECT_EQ(sigs[1].size, 0); + ASSERT_EQ(verify_key.verify(buf_t(msg_bytes, 32), buf_t(sigs[0].data, sigs[0].size)), SUCCESS); + + for (auto m : sigs) cbmpc_cmem_free(m); + } + + // Threshold refresh. + std::vector new_key_blobs(n, cmem_t{nullptr, 0}); + std::vector refresh_sids(n, cmem_t{nullptr, 0}); + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[static_cast(i)], + }; + return cbmpc_eddsa_mp_refresh_ac(&job, /*sid_in=*/cmem_t{nullptr, 0}, key_blobs[static_cast(i)], &ac, + quorum, /*quorum_party_names_count=*/2, &refresh_sids[static_cast(i)], + &new_key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) ASSERT_GT(new_key_blobs[static_cast(i)].size, 0); + for (int i = 1; i < n; i++) expect_eq(refresh_sids[0], refresh_sids[static_cast(i)]); + + for (int i = 0; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_eddsa_mp_get_public_key_compressed(new_key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + const cmem_t sign_new_key_blobs[2] = {new_key_blobs[0], new_key_blobs[1]}; + + { + std::vector> sign_peers; + make_peers(2, sign_peers); + + std::vector sign_ctxs; + std::vector sign_transports; + make_transports(sign_peers, sign_ctxs, sign_transports); + + std::vector sigs(2, cmem_t{nullptr, 0}); + run_mp( + sign_peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/sign_party_names, + /*party_names_count=*/2, + /*transport=*/&sign_transports[static_cast(i)], + }; + return cbmpc_eddsa_mp_sign_ac(&job, sign_new_key_blobs[static_cast(i)], &ac, msg, /*sig_receiver=*/0, + &sigs[static_cast(i)]); + }, + rvs); + + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_EQ(sigs[0].size, 64); + EXPECT_EQ(sigs[1].size, 0); + ASSERT_EQ(verify_key.verify(buf_t(msg_bytes, 32), buf_t(sigs[0].data, sigs[0].size)), SUCCESS); + + for (auto m : sigs) cbmpc_cmem_free(m); + } + + cbmpc_cmem_free(pub0); + for (auto m : refresh_sids) cbmpc_cmem_free(m); + for (auto m : new_key_blobs) cbmpc_cmem_free(m); + for (auto m : sids) cbmpc_cmem_free(m); + for (auto m : key_blobs) cbmpc_cmem_free(m); +} diff --git a/tests/unit/c_api/test_pve.cpp b/tests/unit/c_api/test_pve.cpp new file mode 100644 index 00000000..083a2ee0 --- /dev/null +++ b/tests/unit/c_api/test_pve.cpp @@ -0,0 +1,721 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace { + +using coinbase::buf_t; +using coinbase::mem_t; + +static cbmpc_error_t toy_encrypt(void* ctx, cmem_t /*ek*/, cmem_t /*label*/, cmem_t plain, cmem_t /*rho*/, + cmem_t* out_ct) { + if (!out_ct) return E_BADARG; + *out_ct = cmem_t{nullptr, 0}; + if (plain.size < 0) return E_BADARG; + if (plain.size > 0 && !plain.data) return E_BADARG; + + const int mode = ctx ? *static_cast(ctx) : 0; + const int extra = (mode == 0) ? 0 : 1; + const int n = plain.size + extra; + if (n == 0) return CBMPC_SUCCESS; + + out_ct->data = static_cast(cbmpc_malloc(static_cast(n))); + if (!out_ct->data) return E_INSUFFICIENT; + out_ct->size = n; + if (plain.size) std::memmove(out_ct->data, plain.data, static_cast(plain.size)); + if (extra) out_ct->data[n - 1] = 0x42; + return CBMPC_SUCCESS; +} + +static cbmpc_error_t toy_decrypt(void* ctx, cmem_t /*dk*/, cmem_t /*label*/, cmem_t ct, cmem_t* out_plain) { + if (!out_plain) return E_BADARG; + *out_plain = cmem_t{nullptr, 0}; + if (ct.size < 0) return E_BADARG; + if (ct.size > 0 && !ct.data) return E_BADARG; + + const int mode = ctx ? *static_cast(ctx) : 0; + const int extra = (mode == 0) ? 0 : 1; + if (ct.size < extra) return E_FORMAT; + + const int n = ct.size - extra; + if (n == 0) return CBMPC_SUCCESS; + + out_plain->data = static_cast(cbmpc_malloc(static_cast(n))); + if (!out_plain->data) return E_INSUFFICIENT; + out_plain->size = n; + std::memmove(out_plain->data, ct.data, static_cast(n)); + return CBMPC_SUCCESS; +} + +static buf_t expected_Q(cbmpc_curve_id_t curve_id, mem_t x) { + const coinbase::crypto::ecurve_t curve = (curve_id == CBMPC_CURVE_P256) ? coinbase::crypto::curve_p256 + : (curve_id == CBMPC_CURVE_SECP256K1) ? coinbase::crypto::curve_secp256k1 + : (curve_id == CBMPC_CURVE_ED25519) ? coinbase::crypto::curve_ed25519 + : coinbase::crypto::ecurve_t(); + cb_assert(curve.valid()); + + const coinbase::crypto::bn_t bn_x = coinbase::crypto::bn_t::from_bin(x) % curve.order(); + const coinbase::crypto::ecc_point_t Q = bn_x * curve.generator(); + return Q.to_compressed_bin(); +} + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +// Mirror of the cbmpc base-PKE key blob format used by `cbmpc_pve_generate_base_pke_*`. +// Test-only plumbing to build HSM stubs using software keys. +constexpr uint32_t base_pke_key_blob_version_v1 = 1; +enum class base_pke_key_type_v1 : uint32_t { + rsa_oaep_2048 = 1, + ecies_p256 = 2, +}; + +struct base_pke_dk_blob_v1_t { + uint32_t version = base_pke_key_blob_version_v1; + uint32_t key_type = static_cast(base_pke_key_type_v1::rsa_oaep_2048); + + coinbase::crypto::rsa_prv_key_t rsa_dk; + coinbase::crypto::ecc_prv_key_t ecies_dk; + + void convert(coinbase::converter_t& c) { + c.convert(version, key_type); + switch (static_cast(key_type)) { + case base_pke_key_type_v1::rsa_oaep_2048: + c.convert(rsa_dk); + return; + case base_pke_key_type_v1::ecies_p256: + c.convert(ecies_dk); + return; + default: + c.set_error(); + return; + } + } +}; + +static cbmpc_error_t parse_rsa_prv_from_dk_blob(cmem_t dk_blob, coinbase::crypto::rsa_prv_key_t& out_sk) { + base_pke_dk_blob_v1_t blob; + const coinbase::error_t rv = coinbase::convert(blob, mem_t(dk_blob.data, dk_blob.size)); + if (rv) return rv; + if (blob.version != base_pke_key_blob_version_v1) return E_FORMAT; + if (static_cast(blob.key_type) != base_pke_key_type_v1::rsa_oaep_2048) return E_BADARG; + out_sk = blob.rsa_dk; + return CBMPC_SUCCESS; +} + +static cbmpc_error_t parse_ecies_prv_from_dk_blob(cmem_t dk_blob, coinbase::crypto::ecc_prv_key_t& out_sk) { + base_pke_dk_blob_v1_t blob; + const coinbase::error_t rv = coinbase::convert(blob, mem_t(dk_blob.data, dk_blob.size)); + if (rv) return rv; + if (blob.version != base_pke_key_blob_version_v1) return E_FORMAT; + if (static_cast(blob.key_type) != base_pke_key_type_v1::ecies_p256) return E_BADARG; + out_sk = blob.ecies_dk; + return CBMPC_SUCCESS; +} + +} // namespace + +TEST(CApiPve, EncryptVerifyDecrypt_CustomBasePke) { + int mode0 = 0; + const cbmpc_pve_base_pke_t base_pke0 = { + /*ctx=*/&mode0, + /*encrypt=*/toy_encrypt, + /*decrypt=*/toy_decrypt, + }; + + const cbmpc_curve_id_t curve = CBMPC_CURVE_SECP256K1; + const cmem_t ek = {reinterpret_cast(const_cast("ek")), 2}; + const cmem_t dk = {reinterpret_cast(const_cast("dk")), 2}; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(i); + const cmem_t x = {x_bytes.data(), static_cast(x_bytes.size())}; + + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_encrypt(&base_pke0, curve, ek, label, x, &ct), CBMPC_SUCCESS); + ASSERT_GT(ct.size, 0); + + cmem_t Q_ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_get_Q(ct, &Q_ct), CBMPC_SUCCESS); + + cmem_t L_ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_get_Label(ct, &L_ct), CBMPC_SUCCESS); + expect_eq(L_ct, label); + + const buf_t Q_expected_buf = expected_Q(curve, mem_t(x.data, x.size)); + ASSERT_EQ(Q_ct.size, Q_expected_buf.size()); + ASSERT_EQ(std::memcmp(Q_ct.data, Q_expected_buf.data(), static_cast(Q_ct.size)), 0); + + ASSERT_EQ(cbmpc_pve_verify(&base_pke0, curve, ek, ct, Q_ct, label), CBMPC_SUCCESS); + + cmem_t x_out{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_decrypt(&base_pke0, curve, dk, ek, ct, label, &x_out), CBMPC_SUCCESS); + ASSERT_EQ(x_out.size, 32); + ASSERT_EQ(std::memcmp(x_out.data, x.data, 32), 0); + + cbmpc_cmem_free(x_out); + cbmpc_cmem_free(L_ct); + cbmpc_cmem_free(Q_ct); + cbmpc_cmem_free(ct); +} + +TEST(CApiPve, EncVerDec_DefBasePke_EciesBlob) { + const cbmpc_curve_id_t curve = CBMPC_CURVE_SECP256K1; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(0xB0 + i); + const cmem_t x = {x_bytes.data(), static_cast(x_bytes.size())}; + + cmem_t ek_blob{nullptr, 0}; + cmem_t dk_blob{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_generate_base_pke_ecies_p256_keypair(&ek_blob, &dk_blob), CBMPC_SUCCESS); + + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_encrypt(/*base_pke=*/nullptr, curve, ek_blob, label, x, &ct), CBMPC_SUCCESS); + + const buf_t Q_expected_buf = expected_Q(curve, mem_t(x.data, x.size)); + const cmem_t Q_expected = {const_cast(Q_expected_buf.data()), Q_expected_buf.size()}; + ASSERT_EQ(cbmpc_pve_verify(/*base_pke=*/nullptr, curve, ek_blob, ct, Q_expected, label), CBMPC_SUCCESS); + + cmem_t x_out{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_decrypt(/*base_pke=*/nullptr, curve, dk_blob, ek_blob, ct, label, &x_out), CBMPC_SUCCESS); + ASSERT_EQ(x_out.size, 32); + ASSERT_EQ(std::memcmp(x_out.data, x.data, 32), 0); + + cbmpc_cmem_free(x_out); + cbmpc_cmem_free(ct); + cbmpc_cmem_free(dk_blob); + cbmpc_cmem_free(ek_blob); +} + +TEST(CApiPve, EncVerDec_DefBasePke_RsaBlob) { + const cbmpc_curve_id_t curve = CBMPC_CURVE_SECP256K1; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(0xC0 + i); + const cmem_t x = {x_bytes.data(), static_cast(x_bytes.size())}; + + cmem_t ek_blob{nullptr, 0}; + cmem_t dk_blob{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_generate_base_pke_rsa_keypair(&ek_blob, &dk_blob), CBMPC_SUCCESS); + + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_encrypt(/*base_pke=*/nullptr, curve, ek_blob, label, x, &ct), CBMPC_SUCCESS); + + const buf_t Q_expected_buf = expected_Q(curve, mem_t(x.data, x.size)); + const cmem_t Q_expected = {const_cast(Q_expected_buf.data()), Q_expected_buf.size()}; + ASSERT_EQ(cbmpc_pve_verify(/*base_pke=*/nullptr, curve, ek_blob, ct, Q_expected, label), CBMPC_SUCCESS); + + cmem_t x_out{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_decrypt(/*base_pke=*/nullptr, curve, dk_blob, ek_blob, ct, label, &x_out), CBMPC_SUCCESS); + ASSERT_EQ(x_out.size, 32); + ASSERT_EQ(std::memcmp(x_out.data, x.data, 32), 0); + + cbmpc_cmem_free(x_out); + cbmpc_cmem_free(ct); + cbmpc_cmem_free(dk_blob); + cbmpc_cmem_free(ek_blob); +} + +static cbmpc_error_t rsa_oaep_hsm_decap_cb(void* ctx, cmem_t /*dk_handle*/, cmem_t kem_ct, cmem_t* out_kem_ss) { + if (!ctx || !out_kem_ss) return E_BADARG; + *out_kem_ss = cmem_t{nullptr, 0}; + auto* sk = static_cast(ctx); + + coinbase::buf_t kem_ss; + const coinbase::error_t rv = sk->decrypt_oaep(mem_t(kem_ct.data, kem_ct.size), coinbase::crypto::hash_e::sha256, + coinbase::crypto::hash_e::sha256, mem_t(), kem_ss); + if (rv) return rv; + + out_kem_ss->data = static_cast(cbmpc_malloc(static_cast(kem_ss.size()))); + if (!out_kem_ss->data) return E_INSUFFICIENT; + out_kem_ss->size = kem_ss.size(); + std::memmove(out_kem_ss->data, kem_ss.data(), static_cast(kem_ss.size())); + return CBMPC_SUCCESS; +} + +TEST(CApiPve, DecryptRsaOaepHsm_UsesCallback) { + const cbmpc_curve_id_t curve = CBMPC_CURVE_SECP256K1; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(0x44 + i); + const cmem_t x = {x_bytes.data(), static_cast(x_bytes.size())}; + + cmem_t ek_blob{nullptr, 0}; + cmem_t dk_blob{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_generate_base_pke_rsa_keypair(&ek_blob, &dk_blob), CBMPC_SUCCESS); + + coinbase::crypto::rsa_prv_key_t sk; + ASSERT_EQ(parse_rsa_prv_from_dk_blob(dk_blob, sk), CBMPC_SUCCESS); + + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_encrypt(/*base_pke=*/nullptr, curve, ek_blob, label, x, &ct), CBMPC_SUCCESS); + + cbmpc_pve_rsa_oaep_hsm_decap_t cb; + cb.ctx = &sk; + cb.decap = rsa_oaep_hsm_decap_cb; + + cmem_t x_out{nullptr, 0}; + ASSERT_EQ( + cbmpc_pve_decrypt_rsa_oaep_hsm(curve, + /*dk_handle=*/cmem_t{reinterpret_cast(const_cast("hsm")), 3}, + ek_blob, ct, label, &cb, &x_out), + CBMPC_SUCCESS); + ASSERT_EQ(x_out.size, 32); + ASSERT_EQ(std::memcmp(x_out.data, x.data, 32), 0); + + cbmpc_cmem_free(x_out); + cbmpc_cmem_free(ct); + cbmpc_cmem_free(dk_blob); + cbmpc_cmem_free(ek_blob); +} + +static cbmpc_error_t ecies_p256_hsm_ecdh_cb(void* ctx, cmem_t /*dk_handle*/, cmem_t kem_ct, cmem_t* out_dh_x32) { + if (!ctx || !out_dh_x32) return E_BADARG; + *out_dh_x32 = cmem_t{nullptr, 0}; + auto* sk = static_cast(ctx); + + coinbase::crypto::ecc_point_t E; + coinbase::error_t rv = E.from_oct(coinbase::crypto::curve_p256, mem_t(kem_ct.data, kem_ct.size)); + if (rv) return rv; + if (rv = coinbase::crypto::curve_p256.check(E)) return rv; + + const coinbase::buf_t dh = sk->ecdh(E); + if (dh.size() != 32) return E_FORMAT; + + out_dh_x32->data = static_cast(cbmpc_malloc(32)); + if (!out_dh_x32->data) return E_INSUFFICIENT; + out_dh_x32->size = 32; + std::memmove(out_dh_x32->data, dh.data(), 32); + return CBMPC_SUCCESS; +} + +TEST(CApiPve, DecryptEciesP256Hsm_UsesCallback) { + const cbmpc_curve_id_t curve = CBMPC_CURVE_SECP256K1; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(0x55 + i); + const cmem_t x = {x_bytes.data(), static_cast(x_bytes.size())}; + + cmem_t ek_blob{nullptr, 0}; + cmem_t dk_blob{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_generate_base_pke_ecies_p256_keypair(&ek_blob, &dk_blob), CBMPC_SUCCESS); + + coinbase::crypto::ecc_prv_key_t sk; + ASSERT_EQ(parse_ecies_prv_from_dk_blob(dk_blob, sk), CBMPC_SUCCESS); + + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_encrypt(/*base_pke=*/nullptr, curve, ek_blob, label, x, &ct), CBMPC_SUCCESS); + + cbmpc_pve_ecies_p256_hsm_ecdh_t cb; + cb.ctx = &sk; + cb.ecdh = ecies_p256_hsm_ecdh_cb; + + cmem_t x_out{nullptr, 0}; + ASSERT_EQ( + cbmpc_pve_decrypt_ecies_p256_hsm(curve, + /*dk_handle=*/cmem_t{reinterpret_cast(const_cast("hsm")), 3}, + ek_blob, ct, label, &cb, &x_out), + CBMPC_SUCCESS); + ASSERT_EQ(x_out.size, 32); + ASSERT_EQ(std::memcmp(x_out.data, x.data, 32), 0); + + cbmpc_cmem_free(x_out); + cbmpc_cmem_free(ct); + cbmpc_cmem_free(dk_blob); + cbmpc_cmem_free(ek_blob); +} + +struct toy_kem_ctx_t { + uint8_t tag = 0; +}; + +static cbmpc_error_t toy_kem_encap(void* ctx, cmem_t /*ek*/, cmem_t rho32, cmem_t* out_kem_ct, cmem_t* out_kem_ss) { + if (!ctx || !out_kem_ct || !out_kem_ss) return E_BADARG; + if (rho32.size != 32) return E_BADARG; + *out_kem_ct = cmem_t{nullptr, 0}; + *out_kem_ss = cmem_t{nullptr, 0}; + + const auto* c = static_cast(ctx); + out_kem_ct->data = static_cast(cbmpc_malloc(32)); + out_kem_ss->data = static_cast(cbmpc_malloc(32)); + if (!out_kem_ct->data || !out_kem_ss->data) { + if (out_kem_ct->data) cbmpc_free(out_kem_ct->data); + if (out_kem_ss->data) cbmpc_free(out_kem_ss->data); + *out_kem_ct = cmem_t{nullptr, 0}; + *out_kem_ss = cmem_t{nullptr, 0}; + return E_INSUFFICIENT; + } + out_kem_ct->size = 32; + out_kem_ss->size = 32; + + for (int i = 0; i < 32; i++) { + const uint8_t b = static_cast(rho32.data[i] ^ c->tag); + out_kem_ct->data[i] = b; + out_kem_ss->data[i] = b; + } + return CBMPC_SUCCESS; +} + +static cbmpc_error_t toy_kem_decap(void* ctx, cmem_t /*dk*/, cmem_t kem_ct, cmem_t* out_kem_ss) { + if (!ctx || !out_kem_ss) return E_BADARG; + if (kem_ct.size != 32) return E_BADARG; + *out_kem_ss = cmem_t{nullptr, 0}; + + out_kem_ss->data = static_cast(cbmpc_malloc(32)); + if (!out_kem_ss->data) return E_INSUFFICIENT; + out_kem_ss->size = 32; + std::memmove(out_kem_ss->data, kem_ct.data, 32); + return CBMPC_SUCCESS; +} + +TEST(CApiPve, EncVerDec_CustomKem_TwoInstOneProc) { + const cbmpc_curve_id_t curve = CBMPC_CURVE_SECP256K1; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + const cmem_t ek = {reinterpret_cast(const_cast("ek")), 2}; + const cmem_t dk = {reinterpret_cast(const_cast("dk")), 2}; + + std::array x_bytes{}; + for (int i = 0; i < 32; i++) x_bytes[static_cast(i)] = static_cast(0x66 + i); + const cmem_t x = {x_bytes.data(), static_cast(x_bytes.size())}; + + toy_kem_ctx_t ctx_a{.tag = 0x00}; + toy_kem_ctx_t ctx_b{.tag = 0xFF}; + const cbmpc_pve_base_kem_t kem_a = {&ctx_a, toy_kem_encap, toy_kem_decap}; + const cbmpc_pve_base_kem_t kem_b = {&ctx_b, toy_kem_encap, toy_kem_decap}; + + cmem_t ct_a{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_encrypt_with_kem(&kem_a, curve, ek, label, x, &ct_a), CBMPC_SUCCESS); + + const buf_t Q_expected_buf = expected_Q(curve, mem_t(x.data, x.size)); + const cmem_t Q_expected = {const_cast(Q_expected_buf.data()), Q_expected_buf.size()}; + + ASSERT_EQ(cbmpc_pve_verify_with_kem(&kem_a, curve, ek, ct_a, Q_expected, label), CBMPC_SUCCESS); + dylog_disable_scope_t no_log_err; + EXPECT_NE(cbmpc_pve_verify_with_kem(&kem_b, curve, ek, ct_a, Q_expected, label), CBMPC_SUCCESS); + + cmem_t x_out{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_decrypt_with_kem(&kem_a, curve, dk, ek, ct_a, label, &x_out), CBMPC_SUCCESS); + ASSERT_EQ(x_out.size, 32); + ASSERT_EQ(std::memcmp(x_out.data, x.data, 32), 0); + cbmpc_cmem_free(x_out); + cbmpc_cmem_free(ct_a); +} + +TEST(CApiPve, BasePkeMismatchRejected) { + int mode0 = 0; + int mode1 = 1; + const cbmpc_pve_base_pke_t base_pke0 = { + /*ctx=*/&mode0, + /*encrypt=*/toy_encrypt, + /*decrypt=*/toy_decrypt, + }; + const cbmpc_pve_base_pke_t base_pke1 = { + /*ctx=*/&mode1, + /*encrypt=*/toy_encrypt, + /*decrypt=*/toy_decrypt, + }; + + const cbmpc_curve_id_t curve = CBMPC_CURVE_SECP256K1; + const cmem_t ek = {reinterpret_cast(const_cast("ek")), 2}; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + + std::array x_bytes{}; + x_bytes[0] = 9; + const cmem_t x = {x_bytes.data(), static_cast(x_bytes.size())}; + + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_encrypt(&base_pke0, curve, ek, label, x, &ct), CBMPC_SUCCESS); + + const buf_t Q_expected_buf = expected_Q(curve, mem_t(x.data, x.size)); + cmem_t Q_expected{const_cast(Q_expected_buf.data()), Q_expected_buf.size()}; + + dylog_disable_scope_t no_log_err; + EXPECT_NE(cbmpc_pve_verify(&base_pke1, curve, ek, ct, Q_expected, label), CBMPC_SUCCESS); + + cbmpc_cmem_free(ct); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +TEST(CApiPveNeg, Encrypt) { + dylog_disable_scope_t no_log; + cmem_t ek{nullptr, 0}; + cmem_t dk{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_generate_base_pke_ecies_p256_keypair(&ek, &dk), CBMPC_SUCCESS); + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + std::array x_bytes{}; + x_bytes[0] = 1; + const cmem_t x = {x_bytes.data(), 32}; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t ct{nullptr, 0}; + + EXPECT_EQ(cbmpc_pve_encrypt(nullptr, CBMPC_CURVE_SECP256K1, ek, label, x, nullptr), E_BADARG); + EXPECT_NE(cbmpc_pve_encrypt(nullptr, static_cast(0), ek, label, x, &ct), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_encrypt(nullptr, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, label, x, &ct), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_encrypt(nullptr, CBMPC_CURVE_SECP256K1, ek, cmem_t{nullptr, 0}, x, &ct), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_encrypt(nullptr, CBMPC_CURVE_SECP256K1, ek, label, cmem_t{nullptr, 0}, &ct), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_encrypt(nullptr, CBMPC_CURVE_SECP256K1, cmem_t{garbage, 4}, label, x, &ct), CBMPC_SUCCESS); + + cbmpc_cmem_free(dk); + cbmpc_cmem_free(ek); +} + +TEST(CApiPveNeg, Verify) { + dylog_disable_scope_t no_log; + cmem_t ek{nullptr, 0}; + cmem_t dk{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_generate_base_pke_ecies_p256_keypair(&ek, &dk), CBMPC_SUCCESS); + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + std::array x_bytes{}; + x_bytes[0] = 1; + const cmem_t x = {x_bytes.data(), 32}; + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_encrypt(nullptr, CBMPC_CURVE_SECP256K1, ek, label, x, &ct), CBMPC_SUCCESS); + const auto Q_buf = expected_Q(CBMPC_CURVE_SECP256K1, coinbase::mem_t(x.data, x.size)); + const cmem_t Q = {const_cast(Q_buf.data()), Q_buf.size()}; + + EXPECT_NE(cbmpc_pve_verify(nullptr, static_cast(0), ek, ct, Q, label), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_verify(nullptr, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, ct, Q, label), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_verify(nullptr, CBMPC_CURVE_SECP256K1, ek, cmem_t{nullptr, 0}, Q, label), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_verify(nullptr, CBMPC_CURVE_SECP256K1, ek, ct, cmem_t{nullptr, 0}, label), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_verify(nullptr, CBMPC_CURVE_SECP256K1, ek, ct, Q, cmem_t{nullptr, 0}), CBMPC_SUCCESS); + + cbmpc_cmem_free(ct); + cbmpc_cmem_free(dk); + cbmpc_cmem_free(ek); +} + +TEST(CApiPveNeg, Decrypt) { + dylog_disable_scope_t no_log; + cmem_t ek{nullptr, 0}; + cmem_t dk{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_generate_base_pke_ecies_p256_keypair(&ek, &dk), CBMPC_SUCCESS); + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + std::array x_bytes{}; + x_bytes[0] = 1; + const cmem_t x = {x_bytes.data(), 32}; + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_encrypt(nullptr, CBMPC_CURVE_SECP256K1, ek, label, x, &ct), CBMPC_SUCCESS); + cmem_t x_out{nullptr, 0}; + + EXPECT_EQ(cbmpc_pve_decrypt(nullptr, CBMPC_CURVE_SECP256K1, dk, ek, ct, label, nullptr), E_BADARG); + EXPECT_NE(cbmpc_pve_decrypt(nullptr, static_cast(0), dk, ek, ct, label, &x_out), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_decrypt(nullptr, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, ek, ct, label, &x_out), + CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_decrypt(nullptr, CBMPC_CURVE_SECP256K1, dk, cmem_t{nullptr, 0}, ct, label, &x_out), + CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_decrypt(nullptr, CBMPC_CURVE_SECP256K1, dk, ek, cmem_t{nullptr, 0}, label, &x_out), + CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_decrypt(nullptr, CBMPC_CURVE_SECP256K1, dk, ek, ct, cmem_t{nullptr, 0}, &x_out), CBMPC_SUCCESS); + + cbmpc_cmem_free(ct); + cbmpc_cmem_free(dk); + cbmpc_cmem_free(ek); +} + +TEST(CApiPveNeg, GetQ) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + EXPECT_EQ(cbmpc_pve_get_Q(cmem_t{nullptr, 0}, nullptr), E_BADARG); + cmem_t Q{nullptr, 0}; + EXPECT_NE(cbmpc_pve_get_Q(cmem_t{nullptr, 0}, &Q), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_get_Q(cmem_t{garbage, 4}, &Q), CBMPC_SUCCESS); +} + +TEST(CApiPveNeg, GetLabel) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + EXPECT_EQ(cbmpc_pve_get_Label(cmem_t{nullptr, 0}, nullptr), E_BADARG); + cmem_t label{nullptr, 0}; + EXPECT_NE(cbmpc_pve_get_Label(cmem_t{nullptr, 0}, &label), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_get_Label(cmem_t{garbage, 4}, &label), CBMPC_SUCCESS); +} + +TEST(CApiPveNeg, GenerateBasePkeRsaKeypair) { + dylog_disable_scope_t no_log; + cmem_t ek{nullptr, 0}; + cmem_t dk{nullptr, 0}; + EXPECT_EQ(cbmpc_pve_generate_base_pke_rsa_keypair(nullptr, &dk), E_BADARG); + EXPECT_EQ(cbmpc_pve_generate_base_pke_rsa_keypair(&ek, nullptr), E_BADARG); +} + +TEST(CApiPveNeg, GenerateBasePkeEciesP256Keypair) { + dylog_disable_scope_t no_log; + cmem_t ek{nullptr, 0}; + cmem_t dk{nullptr, 0}; + EXPECT_EQ(cbmpc_pve_generate_base_pke_ecies_p256_keypair(nullptr, &dk), E_BADARG); + EXPECT_EQ(cbmpc_pve_generate_base_pke_ecies_p256_keypair(&ek, nullptr), E_BADARG); +} + +TEST(CApiPveNeg, BasePkeEciesP256EkFromOct) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + std::array wrong_size{}; + cmem_t ek{nullptr, 0}; + EXPECT_EQ(cbmpc_pve_base_pke_ecies_p256_ek_from_oct(cmem_t{nullptr, 0}, nullptr), E_BADARG); + EXPECT_NE(cbmpc_pve_base_pke_ecies_p256_ek_from_oct(cmem_t{nullptr, 0}, &ek), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_base_pke_ecies_p256_ek_from_oct(cmem_t{garbage, 4}, &ek), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_base_pke_ecies_p256_ek_from_oct(cmem_t{wrong_size.data(), 64}, &ek), CBMPC_SUCCESS); +} + +TEST(CApiPveNeg, BasePkeRsaEkFromModulus) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + std::array wrong_size{}; + wrong_size[0] = 1; + std::array all_zero{}; + cmem_t ek{nullptr, 0}; + EXPECT_EQ(cbmpc_pve_base_pke_rsa_ek_from_modulus(cmem_t{nullptr, 0}, nullptr), E_BADARG); + EXPECT_NE(cbmpc_pve_base_pke_rsa_ek_from_modulus(cmem_t{nullptr, 0}, &ek), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_base_pke_rsa_ek_from_modulus(cmem_t{garbage, 4}, &ek), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_base_pke_rsa_ek_from_modulus(cmem_t{wrong_size.data(), 128}, &ek), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_base_pke_rsa_ek_from_modulus(cmem_t{all_zero.data(), 256}, &ek), CBMPC_SUCCESS); +} + +TEST(CApiPveNeg, DecryptRsaOaepHsm) { + dylog_disable_scope_t no_log; + cmem_t ek{nullptr, 0}; + cmem_t dk{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_generate_base_pke_rsa_keypair(&ek, &dk), CBMPC_SUCCESS); + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + std::array x_bytes{}; + x_bytes[0] = 1; + const cmem_t x = {x_bytes.data(), 32}; + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_encrypt(nullptr, CBMPC_CURVE_SECP256K1, ek, label, x, &ct), CBMPC_SUCCESS); + coinbase::crypto::rsa_prv_key_t sk; + ASSERT_EQ(parse_rsa_prv_from_dk_blob(dk, sk), CBMPC_SUCCESS); + cbmpc_pve_rsa_oaep_hsm_decap_t cb; + cb.ctx = &sk; + cb.decap = rsa_oaep_hsm_decap_cb; + cbmpc_pve_rsa_oaep_hsm_decap_t cb_null_decap; + cb_null_decap.ctx = nullptr; + cb_null_decap.decap = nullptr; + cmem_t x_out{nullptr, 0}; + const cmem_t dk_handle = {reinterpret_cast(const_cast("hsm")), 3}; + + EXPECT_EQ(cbmpc_pve_decrypt_rsa_oaep_hsm(CBMPC_CURVE_SECP256K1, dk_handle, ek, ct, label, &cb, nullptr), E_BADARG); + EXPECT_EQ(cbmpc_pve_decrypt_rsa_oaep_hsm(CBMPC_CURVE_SECP256K1, dk_handle, ek, ct, label, nullptr, &x_out), E_BADARG); + EXPECT_EQ(cbmpc_pve_decrypt_rsa_oaep_hsm(CBMPC_CURVE_SECP256K1, dk_handle, ek, ct, label, &cb_null_decap, &x_out), + E_BADARG); + EXPECT_NE(cbmpc_pve_decrypt_rsa_oaep_hsm(CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, ek, ct, label, &cb, &x_out), + CBMPC_SUCCESS); + + cbmpc_cmem_free(ct); + cbmpc_cmem_free(dk); + cbmpc_cmem_free(ek); +} + +TEST(CApiPveNeg, DecryptEciesP256Hsm) { + dylog_disable_scope_t no_log; + cmem_t ek{nullptr, 0}; + cmem_t dk{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_generate_base_pke_ecies_p256_keypair(&ek, &dk), CBMPC_SUCCESS); + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + std::array x_bytes{}; + x_bytes[0] = 1; + const cmem_t x = {x_bytes.data(), 32}; + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_encrypt(nullptr, CBMPC_CURVE_SECP256K1, ek, label, x, &ct), CBMPC_SUCCESS); + coinbase::crypto::ecc_prv_key_t ecies_sk; + ASSERT_EQ(parse_ecies_prv_from_dk_blob(dk, ecies_sk), CBMPC_SUCCESS); + cbmpc_pve_ecies_p256_hsm_ecdh_t cb; + cb.ctx = &ecies_sk; + cb.ecdh = ecies_p256_hsm_ecdh_cb; + cbmpc_pve_ecies_p256_hsm_ecdh_t cb_null_ecdh; + cb_null_ecdh.ctx = nullptr; + cb_null_ecdh.ecdh = nullptr; + cmem_t x_out{nullptr, 0}; + const cmem_t dk_handle = {reinterpret_cast(const_cast("hsm")), 3}; + + EXPECT_EQ(cbmpc_pve_decrypt_ecies_p256_hsm(CBMPC_CURVE_SECP256K1, dk_handle, ek, ct, label, &cb, nullptr), E_BADARG); + EXPECT_EQ(cbmpc_pve_decrypt_ecies_p256_hsm(CBMPC_CURVE_SECP256K1, dk_handle, ek, ct, label, nullptr, &x_out), + E_BADARG); + EXPECT_EQ(cbmpc_pve_decrypt_ecies_p256_hsm(CBMPC_CURVE_SECP256K1, dk_handle, ek, ct, label, &cb_null_ecdh, &x_out), + E_BADARG); + EXPECT_NE(cbmpc_pve_decrypt_ecies_p256_hsm(CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, ek, ct, label, &cb, &x_out), + CBMPC_SUCCESS); + + cbmpc_cmem_free(ct); + cbmpc_cmem_free(dk); + cbmpc_cmem_free(ek); +} + +TEST(CApiPveNeg, EncryptWithKem) { + dylog_disable_scope_t no_log; + const cmem_t ek = {reinterpret_cast(const_cast("ek")), 2}; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + std::array x_bytes{}; + x_bytes[0] = 1; + const cmem_t x = {x_bytes.data(), 32}; + toy_kem_ctx_t ctx{.tag = 0x00}; + const cbmpc_pve_base_kem_t kem = {&ctx, toy_kem_encap, toy_kem_decap}; + const cbmpc_pve_base_kem_t kem_null_encap = {&ctx, nullptr, toy_kem_decap}; + cmem_t ct{nullptr, 0}; + + EXPECT_EQ(cbmpc_pve_encrypt_with_kem(&kem, CBMPC_CURVE_SECP256K1, ek, label, x, nullptr), E_BADARG); + EXPECT_EQ(cbmpc_pve_encrypt_with_kem(nullptr, CBMPC_CURVE_SECP256K1, ek, label, x, &ct), E_BADARG); + EXPECT_EQ(cbmpc_pve_encrypt_with_kem(&kem_null_encap, CBMPC_CURVE_SECP256K1, ek, label, x, &ct), E_BADARG); +} + +TEST(CApiPveNeg, VerifyWithKem) { + dylog_disable_scope_t no_log; + const cmem_t ek = {reinterpret_cast(const_cast("ek")), 2}; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + std::array x_bytes{}; + x_bytes[0] = 1; + const cmem_t x = {x_bytes.data(), 32}; + toy_kem_ctx_t ctx{.tag = 0x00}; + const cbmpc_pve_base_kem_t kem = {&ctx, toy_kem_encap, toy_kem_decap}; + const cbmpc_pve_base_kem_t kem_null_encap = {&ctx, nullptr, toy_kem_decap}; + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_encrypt_with_kem(&kem, CBMPC_CURVE_SECP256K1, ek, label, x, &ct), CBMPC_SUCCESS); + const auto Q_buf = expected_Q(CBMPC_CURVE_SECP256K1, coinbase::mem_t(x.data, x.size)); + const cmem_t Q = {const_cast(Q_buf.data()), Q_buf.size()}; + + EXPECT_EQ(cbmpc_pve_verify_with_kem(nullptr, CBMPC_CURVE_SECP256K1, ek, ct, Q, label), E_BADARG); + EXPECT_EQ(cbmpc_pve_verify_with_kem(&kem_null_encap, CBMPC_CURVE_SECP256K1, ek, ct, Q, label), E_BADARG); + + cbmpc_cmem_free(ct); +} + +TEST(CApiPveNeg, DecryptWithKem) { + dylog_disable_scope_t no_log; + const cmem_t ek = {reinterpret_cast(const_cast("ek")), 2}; + const cmem_t dk = {reinterpret_cast(const_cast("dk")), 2}; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + std::array x_bytes{}; + x_bytes[0] = 1; + const cmem_t x = {x_bytes.data(), 32}; + toy_kem_ctx_t ctx{.tag = 0x00}; + const cbmpc_pve_base_kem_t kem = {&ctx, toy_kem_encap, toy_kem_decap}; + const cbmpc_pve_base_kem_t kem_null_decap = {&ctx, toy_kem_encap, nullptr}; + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_encrypt_with_kem(&kem, CBMPC_CURVE_SECP256K1, ek, label, x, &ct), CBMPC_SUCCESS); + cmem_t x_out{nullptr, 0}; + + EXPECT_EQ(cbmpc_pve_decrypt_with_kem(&kem, CBMPC_CURVE_SECP256K1, dk, ek, ct, label, nullptr), E_BADARG); + EXPECT_EQ(cbmpc_pve_decrypt_with_kem(nullptr, CBMPC_CURVE_SECP256K1, dk, ek, ct, label, &x_out), E_BADARG); + EXPECT_EQ(cbmpc_pve_decrypt_with_kem(&kem_null_decap, CBMPC_CURVE_SECP256K1, dk, ek, ct, label, &x_out), E_BADARG); + + cbmpc_cmem_free(ct); +} diff --git a/tests/unit/c_api/test_pve_ac.cpp b/tests/unit/c_api/test_pve_ac.cpp new file mode 100644 index 00000000..bbe22465 --- /dev/null +++ b/tests/unit/c_api/test_pve_ac.cpp @@ -0,0 +1,739 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace { + +using coinbase::buf_t; +using coinbase::mem_t; + +struct fake_hsm_ecies_p256_t { + std::unordered_map keys; +}; + +static cbmpc_error_t hsm_ecies_p256_ecdh_cb(void* ctx, cmem_t dk_handle, cmem_t kem_ct, cmem_t* out_dh_x32) { + if (!out_dh_x32) return E_BADARG; + *out_dh_x32 = cmem_t{nullptr, 0}; + if (!ctx) return E_BADARG; + if (dk_handle.size <= 0 || !dk_handle.data) return E_BADARG; + if (kem_ct.size < 0) return E_BADARG; + if (kem_ct.size > 0 && !kem_ct.data) return E_BADARG; + + const auto* hsm = static_cast(ctx); + const std::string handle(reinterpret_cast(dk_handle.data), static_cast(dk_handle.size)); + const auto it = hsm->keys.find(handle); + if (it == hsm->keys.end()) return E_BADARG; + + coinbase::crypto::ecc_point_t E; + coinbase::error_t rv = E.from_oct(coinbase::crypto::curve_p256, mem_t(kem_ct.data, kem_ct.size)); + if (rv) return rv; + if (rv = coinbase::crypto::curve_p256.check(E)) return rv; + + const buf_t dh = it->second.ecdh(E); + if (dh.size() != 32) return CBMPC_E_CRYPTO; + + out_dh_x32->data = static_cast(cbmpc_malloc(32)); + if (!out_dh_x32->data) return E_INSUFFICIENT; + out_dh_x32->size = 32; + std::memmove(out_dh_x32->data, dh.data(), 32); + return CBMPC_SUCCESS; +} + +static buf_t expected_Q(cbmpc_curve_id_t curve_id, mem_t x) { + const coinbase::crypto::ecurve_t curve = (curve_id == CBMPC_CURVE_P256) ? coinbase::crypto::curve_p256 + : (curve_id == CBMPC_CURVE_SECP256K1) ? coinbase::crypto::curve_secp256k1 + : (curve_id == CBMPC_CURVE_ED25519) ? coinbase::crypto::curve_ed25519 + : coinbase::crypto::ecurve_t(); + cb_assert(curve.valid()); + + const coinbase::crypto::bn_t bn_x = coinbase::crypto::bn_t::from_bin(x) % curve.order(); + const coinbase::crypto::ecc_point_t Q = bn_x * curve.generator(); + return Q.to_compressed_bin(); +} + +static cbmpc_error_t toy_encrypt(void* /*ctx*/, cmem_t /*ek*/, cmem_t /*label*/, cmem_t plain, cmem_t /*rho*/, + cmem_t* out_ct) { + if (!out_ct) return E_BADARG; + *out_ct = cmem_t{nullptr, 0}; + if (plain.size < 0) return E_BADARG; + if (plain.size > 0 && !plain.data) return E_BADARG; + + const int n = plain.size; + if (n == 0) return CBMPC_SUCCESS; + + out_ct->data = static_cast(cbmpc_malloc(static_cast(n))); + if (!out_ct->data) return E_INSUFFICIENT; + out_ct->size = n; + std::memmove(out_ct->data, plain.data, static_cast(n)); + return CBMPC_SUCCESS; +} + +static cbmpc_error_t toy_decrypt(void* /*ctx*/, cmem_t /*dk*/, cmem_t /*label*/, cmem_t ct, cmem_t* out_plain) { + if (!out_plain) return E_BADARG; + *out_plain = cmem_t{nullptr, 0}; + if (ct.size < 0) return E_BADARG; + if (ct.size > 0 && !ct.data) return E_BADARG; + + const int n = ct.size; + if (n == 0) return CBMPC_SUCCESS; + + out_plain->data = static_cast(cbmpc_malloc(static_cast(n))); + if (!out_plain->data) return E_INSUFFICIENT; + out_plain->size = n; + std::memmove(out_plain->data, ct.data, static_cast(n)); + return CBMPC_SUCCESS; +} + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +} // namespace + +TEST(CApiPveAc, EncVer_PDec_Agg_DefPke_Rsa) { + const cbmpc_curve_id_t curve = CBMPC_CURVE_SECP256K1; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + + // Access structure: THRESHOLD(2-of-3) over leaves {p1,p2,p3}. + const char* p1 = "p1"; + const char* p2 = "p2"; + const char* p3 = "p3"; + + const std::array child_indices = {1, 2, 3}; + const std::array nodes = { + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, /*leaf_name=*/nullptr, /*k=*/2, + /*off=*/0, /*cnt=*/3}, + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_LEAF, p1, /*k=*/0, /*off=*/0, /*cnt=*/0}, + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_LEAF, p2, /*k=*/0, /*off=*/0, /*cnt=*/0}, + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_LEAF, p3, /*k=*/0, /*off=*/0, /*cnt=*/0}, + }; + + const cbmpc_access_structure_t ac = { + /*nodes=*/nodes.data(), + /*nodes_count=*/static_cast(nodes.size()), + /*child_indices=*/child_indices.data(), + /*child_indices_count=*/static_cast(child_indices.size()), + /*root_index=*/0, + }; + + constexpr int n = 4; + std::array(n) * 32> xs_flat{}; + std::array xs_sizes{}; + for (int i = 0; i < n; i++) { + xs_sizes[static_cast(i)] = 32; + for (int j = 0; j < 32; j++) xs_flat[static_cast(i * 32 + j)] = static_cast(i + j); + } + const cmems_t xs_in = {n, xs_flat.data(), xs_sizes.data()}; + + // Per-leaf base-PKE key blobs. + std::array eks = {cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}}; + std::array dks = {cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}}; + ASSERT_EQ(cbmpc_pve_generate_base_pke_rsa_keypair(&eks[0], &dks[0]), CBMPC_SUCCESS); + ASSERT_EQ(cbmpc_pve_generate_base_pke_rsa_keypair(&eks[1], &dks[1]), CBMPC_SUCCESS); + ASSERT_EQ(cbmpc_pve_generate_base_pke_rsa_keypair(&eks[2], &dks[2]), CBMPC_SUCCESS); + + const std::array leaf_names = {p1, p2, p3}; + + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_ac_encrypt(/*base_pke=*/nullptr, curve, &ac, leaf_names.data(), eks.data(), + static_cast(eks.size()), label, xs_in, &ct), + CBMPC_SUCCESS); + + int batch_count = 0; + ASSERT_EQ(cbmpc_pve_ac_get_count(ct, &batch_count), CBMPC_SUCCESS); + ASSERT_EQ(batch_count, n); + + // Expected Qs. + std::array(n) * 33> Qs_flat{}; + std::array Qs_sizes{}; + for (int i = 0; i < n; i++) { + const mem_t xi(xs_flat.data() + i * 32, 32); + const buf_t qi = expected_Q(curve, xi); + ASSERT_EQ(qi.size(), 33); + Qs_sizes[static_cast(i)] = qi.size(); + std::memmove(Qs_flat.data() + i * 33, qi.data(), static_cast(qi.size())); + } + const cmems_t Qs_expected = {n, Qs_flat.data(), Qs_sizes.data()}; + + ASSERT_EQ(cbmpc_pve_ac_verify(/*base_pke=*/nullptr, curve, &ac, leaf_names.data(), eks.data(), + static_cast(eks.size()), ct, Qs_expected, label), + CBMPC_SUCCESS); + + const int attempt_index = 0; + cmem_t share_p1{nullptr, 0}; + cmem_t share_p2{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_ac_partial_decrypt_attempt(/*base_pke=*/nullptr, curve, &ac, ct, attempt_index, p1, dks[0], label, + &share_p1), + CBMPC_SUCCESS); + ASSERT_EQ(cbmpc_pve_ac_partial_decrypt_attempt(/*base_pke=*/nullptr, curve, &ac, ct, attempt_index, p2, dks[1], label, + &share_p2), + CBMPC_SUCCESS); + + const std::array quorum_names = {p1, p2}; + const std::array quorum_shares = {share_p1, share_p2}; + + cmems_t xs_out{0, nullptr, nullptr}; + ASSERT_EQ(cbmpc_pve_ac_combine(/*base_pke=*/nullptr, curve, &ac, quorum_names.data(), quorum_shares.data(), + static_cast(quorum_shares.size()), ct, attempt_index, label, &xs_out), + CBMPC_SUCCESS); + ASSERT_EQ(xs_out.count, n); + + int off = 0; + for (int i = 0; i < n; i++) { + ASSERT_EQ(xs_out.sizes[i], 32); + ASSERT_EQ(std::memcmp(xs_out.data + off, xs_flat.data() + i * 32, 32), 0); + off += xs_out.sizes[i]; + } + + // Insufficient quorum should fail. + const std::array q1_names = {p1}; + const std::array q1_shares = {share_p1}; + cmems_t xs_out2{0, nullptr, nullptr}; + EXPECT_NE(cbmpc_pve_ac_combine(/*base_pke=*/nullptr, curve, &ac, q1_names.data(), q1_shares.data(), + static_cast(q1_shares.size()), ct, attempt_index, label, &xs_out2), + CBMPC_SUCCESS); + + cbmpc_cmem_free(share_p2); + cbmpc_cmem_free(share_p1); + cbmpc_cmems_free(xs_out); + cbmpc_cmem_free(ct); + for (int i = 0; i < 3; i++) { + cbmpc_cmem_free(dks[static_cast(i)]); + cbmpc_cmem_free(eks[static_cast(i)]); + } +} + +TEST(CApiPveAc, EncVer_PartDec_Agg_CustomBasePke) { + const cbmpc_pve_base_pke_t base_pke = { + /*ctx=*/nullptr, + /*encrypt=*/toy_encrypt, + /*decrypt=*/toy_decrypt, + }; + + const cbmpc_curve_id_t curve = CBMPC_CURVE_SECP256K1; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + + const char* p1 = "p1"; + const char* p2 = "p2"; + const char* p3 = "p3"; + const std::array child_indices = {1, 2, 3}; + const std::array nodes = { + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, /*leaf_name=*/nullptr, /*k=*/2, + /*off=*/0, /*cnt=*/3}, + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_LEAF, p1, /*k=*/0, /*off=*/0, /*cnt=*/0}, + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_LEAF, p2, /*k=*/0, /*off=*/0, /*cnt=*/0}, + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_LEAF, p3, /*k=*/0, /*off=*/0, /*cnt=*/0}, + }; + const cbmpc_access_structure_t ac = { + /*nodes=*/nodes.data(), + /*nodes_count=*/static_cast(nodes.size()), + /*child_indices=*/child_indices.data(), + /*child_indices_count=*/static_cast(child_indices.size()), + /*root_index=*/0, + }; + + constexpr int n = 3; + std::array(n) * 32> xs_flat{}; + std::array xs_sizes{}; + for (int i = 0; i < n; i++) { + xs_sizes[static_cast(i)] = 32; + for (int j = 0; j < 32; j++) xs_flat[static_cast(i * 32 + j)] = static_cast(0x77 + i + j); + } + const cmems_t xs_in = {n, xs_flat.data(), xs_sizes.data()}; + + const cmem_t ek1 = {reinterpret_cast(const_cast("ek1")), 3}; + const cmem_t ek2 = {reinterpret_cast(const_cast("ek2")), 3}; + const cmem_t ek3 = {reinterpret_cast(const_cast("ek3")), 3}; + const std::array eks = {ek1, ek2, ek3}; + const std::array leaf_names = {p1, p2, p3}; + + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_ac_encrypt(&base_pke, curve, &ac, leaf_names.data(), eks.data(), static_cast(eks.size()), + label, xs_in, &ct), + CBMPC_SUCCESS); + + // Verify expected Qs. + std::array(n) * 33> Qs_flat{}; + std::array Qs_sizes{}; + for (int i = 0; i < n; i++) { + const mem_t xi(xs_flat.data() + i * 32, 32); + const buf_t qi = expected_Q(curve, xi); + ASSERT_EQ(qi.size(), 33); + Qs_sizes[static_cast(i)] = qi.size(); + std::memmove(Qs_flat.data() + i * 33, qi.data(), static_cast(qi.size())); + } + const cmems_t Qs_expected = {n, Qs_flat.data(), Qs_sizes.data()}; + + ASSERT_EQ(cbmpc_pve_ac_verify(&base_pke, curve, &ac, leaf_names.data(), eks.data(), static_cast(eks.size()), ct, + Qs_expected, label), + CBMPC_SUCCESS); + + const int attempt_index = 0; + cmem_t share_p1{nullptr, 0}; + cmem_t share_p3{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_ac_partial_decrypt_attempt(&base_pke, curve, &ac, ct, attempt_index, p1, ek1, label, &share_p1), + CBMPC_SUCCESS); + ASSERT_EQ(cbmpc_pve_ac_partial_decrypt_attempt(&base_pke, curve, &ac, ct, attempt_index, p3, ek3, label, &share_p3), + CBMPC_SUCCESS); + + const std::array quorum_names = {p1, p3}; + const std::array quorum_shares = {share_p1, share_p3}; + + cmems_t xs_out{0, nullptr, nullptr}; + ASSERT_EQ(cbmpc_pve_ac_combine(&base_pke, curve, &ac, quorum_names.data(), quorum_shares.data(), + static_cast(quorum_shares.size()), ct, attempt_index, label, &xs_out), + CBMPC_SUCCESS); + ASSERT_EQ(xs_out.count, n); + + int off = 0; + for (int i = 0; i < n; i++) { + ASSERT_EQ(xs_out.sizes[i], 32); + ASSERT_EQ(std::memcmp(xs_out.data + off, xs_flat.data() + i * 32, 32), 0); + off += xs_out.sizes[i]; + } + + cbmpc_cmem_free(share_p3); + cbmpc_cmem_free(share_p1); + cbmpc_cmems_free(xs_out); + cbmpc_cmem_free(ct); +} + +TEST(CApiPveAc, EncVer_PDec_Agg_DefPke_EciesHsmRow) { + const cbmpc_curve_id_t curve = CBMPC_CURVE_SECP256K1; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + + // Access structure: THRESHOLD(2-of-3) over leaves {p1,p2,p3}. + const char* p1 = "p1"; + const char* p2 = "p2"; + const char* p3 = "p3"; + const std::array child_indices = {1, 2, 3}; + const std::array nodes = { + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, /*leaf_name=*/nullptr, /*k=*/2, + /*off=*/0, /*cnt=*/3}, + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_LEAF, p1, /*k=*/0, /*off=*/0, /*cnt=*/0}, + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_LEAF, p2, /*k=*/0, /*off=*/0, /*cnt=*/0}, + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_LEAF, p3, /*k=*/0, /*off=*/0, /*cnt=*/0}, + }; + const cbmpc_access_structure_t ac = { + /*nodes=*/nodes.data(), + /*nodes_count=*/static_cast(nodes.size()), + /*child_indices=*/child_indices.data(), + /*child_indices_count=*/static_cast(child_indices.size()), + /*root_index=*/0, + }; + + constexpr int n = 4; + std::array(n) * 32> xs_flat{}; + std::array xs_sizes{}; + for (int i = 0; i < n; i++) { + xs_sizes[static_cast(i)] = 32; + for (int j = 0; j < 32; j++) xs_flat[static_cast(i * 32 + j)] = static_cast(0x55 + i + j); + } + const cmems_t xs_in = {n, xs_flat.data(), xs_sizes.data()}; + + // Simulated HSM that stores P-256 private keys. The library only sees opaque handles. + fake_hsm_ecies_p256_t hsm; + std::array handles = {"hsm-ecies-p256-p1", "hsm-ecies-p256-p2", "hsm-ecies-p256-p3"}; + + std::array eks = {cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}}; + for (int i = 0; i < 3; i++) { + coinbase::crypto::ecc_prv_key_t prv; + prv.generate(coinbase::crypto::curve_p256); + + const buf_t pub_oct = prv.pub().to_oct(); + ASSERT_EQ(pub_oct.size(), 65); + + const cmem_t pub_oct_mem = {pub_oct.data(), pub_oct.size()}; + ASSERT_EQ(cbmpc_pve_base_pke_ecies_p256_ek_from_oct(pub_oct_mem, &eks[static_cast(i)]), CBMPC_SUCCESS); + ASSERT_GT(eks[static_cast(i)].size, 0); + + hsm.keys.emplace(handles[static_cast(i)], std::move(prv)); + } + + const std::array leaf_names = {p1, p2, p3}; + + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_ac_encrypt(/*base_pke=*/nullptr, curve, &ac, leaf_names.data(), eks.data(), + static_cast(eks.size()), label, xs_in, &ct), + CBMPC_SUCCESS); + + // Verify expected Qs. + std::array(n) * 33> Qs_flat{}; + std::array Qs_sizes{}; + for (int i = 0; i < n; i++) { + const mem_t xi(xs_flat.data() + i * 32, 32); + const buf_t qi = expected_Q(curve, xi); + ASSERT_EQ(qi.size(), 33); + Qs_sizes[static_cast(i)] = qi.size(); + std::memmove(Qs_flat.data() + i * 33, qi.data(), static_cast(qi.size())); + } + const cmems_t Qs_expected = {n, Qs_flat.data(), Qs_sizes.data()}; + + ASSERT_EQ(cbmpc_pve_ac_verify(/*base_pke=*/nullptr, curve, &ac, leaf_names.data(), eks.data(), + static_cast(eks.size()), ct, Qs_expected, label), + CBMPC_SUCCESS); + + const cbmpc_pve_ecies_p256_hsm_ecdh_t cb = { + /*ctx=*/&hsm, + /*ecdh=*/hsm_ecies_p256_ecdh_cb, + }; + + const int attempt_index = 0; + const cmem_t h1 = {reinterpret_cast(handles[0].data()), static_cast(handles[0].size())}; + const cmem_t h3 = {reinterpret_cast(handles[2].data()), static_cast(handles[2].size())}; + + cmem_t share_p1{nullptr, 0}; + cmem_t share_p3{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_ac_partial_decrypt_attempt_ecies_p256_hsm(curve, &ac, ct, attempt_index, p1, h1, eks[0], label, + &cb, &share_p1), + CBMPC_SUCCESS); + ASSERT_EQ(cbmpc_pve_ac_partial_decrypt_attempt_ecies_p256_hsm(curve, &ac, ct, attempt_index, p3, h3, eks[2], label, + &cb, &share_p3), + CBMPC_SUCCESS); + + const std::array quorum_names = {p1, p3}; + const std::array quorum_shares = {share_p1, share_p3}; + + cmems_t xs_out{0, nullptr, nullptr}; + ASSERT_EQ(cbmpc_pve_ac_combine(/*base_pke=*/nullptr, curve, &ac, quorum_names.data(), quorum_shares.data(), + static_cast(quorum_shares.size()), ct, attempt_index, label, &xs_out), + CBMPC_SUCCESS); + ASSERT_EQ(xs_out.count, n); + + int off = 0; + for (int i = 0; i < n; i++) { + ASSERT_EQ(xs_out.sizes[i], 32); + ASSERT_EQ(std::memcmp(xs_out.data + off, xs_flat.data() + i * 32, 32), 0); + off += xs_out.sizes[i]; + } + + cbmpc_cmem_free(share_p3); + cbmpc_cmem_free(share_p1); + cbmpc_cmems_free(xs_out); + cbmpc_cmem_free(ct); + for (int i = 0; i < 3; i++) cbmpc_cmem_free(eks[static_cast(i)]); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +static cbmpc_error_t noop_rsa_decap(void* /*ctx*/, cmem_t /*dk_handle*/, cmem_t /*kem_ct*/, cmem_t* /*out*/) { + return E_BADARG; +} + +static cbmpc_error_t noop_ecies_ecdh(void* /*ctx*/, cmem_t /*dk_handle*/, cmem_t /*kem_ct*/, cmem_t* /*out*/) { + return E_BADARG; +} + +class CApiPveAcNeg : public ::testing::Test { + protected: + void SetUp() override { + base_pke_ = {nullptr, toy_encrypt, toy_decrypt}; + label_ = {reinterpret_cast(const_cast("label")), 5}; + ek1_ = {reinterpret_cast(const_cast("ek1")), 3}; + ek2_ = {reinterpret_cast(const_cast("ek2")), 3}; + ek3_ = {reinterpret_cast(const_cast("ek3")), 3}; + child_indices_ = {1, 2, 3}; + nodes_ = { + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_LEAF, p1_, 0, 0, 0}, + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_LEAF, p2_, 0, 0, 0}, + cbmpc_access_structure_node_t{CBMPC_ACCESS_STRUCTURE_NODE_LEAF, p3_, 0, 0, 0}, + }; + ac_ = {nodes_.data(), static_cast(nodes_.size()), child_indices_.data(), + static_cast(child_indices_.size()), 0}; + leaf_names_ = {p1_, p2_, p3_}; + eks_ = {ek1_, ek2_, ek3_}; + for (int j = 0; j < 32; j++) xs_flat_[static_cast(j)] = static_cast(j); + xs_sizes_[0] = 32; + xs_ = {1, xs_flat_.data(), xs_sizes_.data()}; + } + + const char* p1_ = "p1"; + const char* p2_ = "p2"; + const char* p3_ = "p3"; + cbmpc_pve_base_pke_t base_pke_{}; + cbmpc_curve_id_t curve_ = CBMPC_CURVE_SECP256K1; + cmem_t label_{}; + cmem_t ek1_{}; + cmem_t ek2_{}; + cmem_t ek3_{}; + std::array child_indices_{}; + std::array nodes_{}; + cbmpc_access_structure_t ac_{}; + std::array leaf_names_{}; + std::array eks_{}; + std::array xs_flat_{}; + std::array xs_sizes_{}; + cmems_t xs_{}; +}; + +TEST_F(CApiPveAcNeg, EncryptNullOutCiphertext) { + EXPECT_EQ(cbmpc_pve_ac_encrypt(&base_pke_, curve_, &ac_, leaf_names_.data(), eks_.data(), 3, label_, xs_, nullptr), + E_BADARG); +} + +TEST_F(CApiPveAcNeg, EncryptInvalidCurve) { + cmem_t ct{nullptr, 0}; + EXPECT_NE(cbmpc_pve_ac_encrypt(&base_pke_, static_cast(0), &ac_, leaf_names_.data(), eks_.data(), 3, + label_, xs_, &ct), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, EncryptNullAc) { + cmem_t ct{nullptr, 0}; + EXPECT_EQ(cbmpc_pve_ac_encrypt(&base_pke_, curve_, nullptr, leaf_names_.data(), eks_.data(), 3, label_, xs_, &ct), + E_BADARG); +} + +TEST_F(CApiPveAcNeg, EncryptNullLeafNames) { + cmem_t ct{nullptr, 0}; + EXPECT_EQ(cbmpc_pve_ac_encrypt(&base_pke_, curve_, &ac_, nullptr, eks_.data(), 3, label_, xs_, &ct), E_BADARG); +} + +TEST_F(CApiPveAcNeg, EncryptNullLeafEks) { + cmem_t ct{nullptr, 0}; + EXPECT_EQ(cbmpc_pve_ac_encrypt(&base_pke_, curve_, &ac_, leaf_names_.data(), nullptr, 3, label_, xs_, &ct), E_BADARG); +} + +TEST_F(CApiPveAcNeg, EncryptLeafCountZero) { + cmem_t ct{nullptr, 0}; + EXPECT_NE(cbmpc_pve_ac_encrypt(&base_pke_, curve_, &ac_, leaf_names_.data(), eks_.data(), 0, label_, xs_, &ct), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, EncryptLeafCountMismatch) { + cmem_t ct{nullptr, 0}; + EXPECT_NE(cbmpc_pve_ac_encrypt(&base_pke_, curve_, &ac_, leaf_names_.data(), eks_.data(), 2, label_, xs_, &ct), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, EncryptEmptyLabelEmptyXs) { + const cmem_t empty_label = {nullptr, 0}; + const cmems_t empty_xs = {0, nullptr, nullptr}; + cmem_t ct{nullptr, 0}; + EXPECT_NE( + cbmpc_pve_ac_encrypt(&base_pke_, curve_, &ac_, leaf_names_.data(), eks_.data(), 3, empty_label, empty_xs, &ct), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, EncryptXsWithEmptyElement) { + std::array sizes = {0}; + const cmems_t xs = {1, nullptr, sizes.data()}; + cmem_t ct{nullptr, 0}; + EXPECT_NE(cbmpc_pve_ac_encrypt(&base_pke_, curve_, &ac_, leaf_names_.data(), eks_.data(), 3, label_, xs, &ct), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, VerifyInvalidCurve) { + const cmems_t empty_Qs = {0, nullptr, nullptr}; + EXPECT_NE(cbmpc_pve_ac_verify(&base_pke_, static_cast(0), &ac_, leaf_names_.data(), eks_.data(), 3, + cmem_t{nullptr, 0}, empty_Qs, label_), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, VerifyNullAc) { + const cmems_t empty_Qs = {0, nullptr, nullptr}; + EXPECT_EQ(cbmpc_pve_ac_verify(&base_pke_, curve_, nullptr, leaf_names_.data(), eks_.data(), 3, cmem_t{nullptr, 0}, + empty_Qs, label_), + E_BADARG); +} + +TEST_F(CApiPveAcNeg, VerifyEmptyCiphertextEmptyLabel) { + const cmems_t empty_Qs = {0, nullptr, nullptr}; + EXPECT_NE(cbmpc_pve_ac_verify(&base_pke_, curve_, &ac_, leaf_names_.data(), eks_.data(), 3, cmem_t{nullptr, 0}, + empty_Qs, cmem_t{nullptr, 0}), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, VerifyGarbageCiphertext) { + uint8_t garbage[4] = {0xDE, 0xAD, 0xBE, 0xEF}; + const cmem_t ct = {garbage, 4}; + const cmems_t empty_Qs = {0, nullptr, nullptr}; + EXPECT_NE(cbmpc_pve_ac_verify(&base_pke_, curve_, &ac_, leaf_names_.data(), eks_.data(), 3, ct, empty_Qs, label_), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, PartialDecryptNullOutShare) { + EXPECT_EQ(cbmpc_pve_ac_partial_decrypt_attempt(&base_pke_, curve_, &ac_, cmem_t{nullptr, 0}, 0, p1_, + cmem_t{nullptr, 0}, label_, nullptr), + E_BADARG); +} + +TEST_F(CApiPveAcNeg, PartialDecryptInvalidCurve) { + cmem_t share{nullptr, 0}; + EXPECT_NE(cbmpc_pve_ac_partial_decrypt_attempt(&base_pke_, static_cast(0), &ac_, cmem_t{nullptr, 0}, + 0, p1_, cmem_t{nullptr, 0}, label_, &share), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, PartialDecryptNullAc) { + cmem_t share{nullptr, 0}; + EXPECT_EQ(cbmpc_pve_ac_partial_decrypt_attempt(&base_pke_, curve_, nullptr, cmem_t{nullptr, 0}, 0, p1_, + cmem_t{nullptr, 0}, label_, &share), + E_BADARG); +} + +TEST_F(CApiPveAcNeg, PartialDecryptEmptyCiphertextDkLabel) { + cmem_t share{nullptr, 0}; + EXPECT_NE(cbmpc_pve_ac_partial_decrypt_attempt(&base_pke_, curve_, &ac_, cmem_t{nullptr, 0}, 0, p1_, + cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, &share), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, PartialDecryptNullAndEmptyLeafName) { + cmem_t share{nullptr, 0}; + EXPECT_NE(cbmpc_pve_ac_partial_decrypt_attempt(&base_pke_, curve_, &ac_, cmem_t{nullptr, 0}, 0, nullptr, + cmem_t{nullptr, 0}, label_, &share), + CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_ac_partial_decrypt_attempt(&base_pke_, curve_, &ac_, cmem_t{nullptr, 0}, 0, "", + cmem_t{nullptr, 0}, label_, &share), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, PartialDecryptGarbageCiphertext) { + uint8_t garbage[4] = {0xDE, 0xAD, 0xBE, 0xEF}; + const cmem_t ct = {garbage, 4}; + cmem_t share{nullptr, 0}; + EXPECT_NE( + cbmpc_pve_ac_partial_decrypt_attempt(&base_pke_, curve_, &ac_, ct, 0, p1_, cmem_t{nullptr, 0}, label_, &share), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, PartialDecryptRsaHsmNullOutShareNullCb) { + const cbmpc_pve_rsa_oaep_hsm_decap_t cb = {nullptr, noop_rsa_decap}; + EXPECT_EQ(cbmpc_pve_ac_partial_decrypt_attempt_rsa_oaep_hsm(curve_, &ac_, cmem_t{nullptr, 0}, 0, p1_, + cmem_t{nullptr, 0}, ek1_, label_, &cb, nullptr), + E_BADARG); + cmem_t share{nullptr, 0}; + EXPECT_EQ(cbmpc_pve_ac_partial_decrypt_attempt_rsa_oaep_hsm(curve_, &ac_, cmem_t{nullptr, 0}, 0, p1_, + cmem_t{nullptr, 0}, ek1_, label_, nullptr, &share), + E_BADARG); + const cbmpc_pve_rsa_oaep_hsm_decap_t cb_null = {nullptr, nullptr}; + EXPECT_EQ(cbmpc_pve_ac_partial_decrypt_attempt_rsa_oaep_hsm(curve_, &ac_, cmem_t{nullptr, 0}, 0, p1_, + cmem_t{nullptr, 0}, ek1_, label_, &cb_null, &share), + E_BADARG); +} + +TEST_F(CApiPveAcNeg, PartialDecryptRsaHsmEmptyDkHandle) { + const cbmpc_pve_rsa_oaep_hsm_decap_t cb = {nullptr, noop_rsa_decap}; + cmem_t share{nullptr, 0}; + EXPECT_NE(cbmpc_pve_ac_partial_decrypt_attempt_rsa_oaep_hsm(curve_, &ac_, cmem_t{nullptr, 0}, 0, p1_, + cmem_t{nullptr, 0}, ek1_, label_, &cb, &share), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, PartialDecryptEciesHsmNullOutShareNullCb) { + const cbmpc_pve_ecies_p256_hsm_ecdh_t cb = {nullptr, noop_ecies_ecdh}; + EXPECT_EQ(cbmpc_pve_ac_partial_decrypt_attempt_ecies_p256_hsm(curve_, &ac_, cmem_t{nullptr, 0}, 0, p1_, + cmem_t{nullptr, 0}, ek1_, label_, &cb, nullptr), + E_BADARG); + cmem_t share{nullptr, 0}; + EXPECT_EQ(cbmpc_pve_ac_partial_decrypt_attempt_ecies_p256_hsm(curve_, &ac_, cmem_t{nullptr, 0}, 0, p1_, + cmem_t{nullptr, 0}, ek1_, label_, nullptr, &share), + E_BADARG); + const cbmpc_pve_ecies_p256_hsm_ecdh_t cb_null = {nullptr, nullptr}; + EXPECT_EQ(cbmpc_pve_ac_partial_decrypt_attempt_ecies_p256_hsm(curve_, &ac_, cmem_t{nullptr, 0}, 0, p1_, + cmem_t{nullptr, 0}, ek1_, label_, &cb_null, &share), + E_BADARG); +} + +TEST_F(CApiPveAcNeg, PartialDecryptEciesHsmEmptyDkHandle) { + const cbmpc_pve_ecies_p256_hsm_ecdh_t cb = {nullptr, noop_ecies_ecdh}; + cmem_t share{nullptr, 0}; + EXPECT_NE(cbmpc_pve_ac_partial_decrypt_attempt_ecies_p256_hsm(curve_, &ac_, cmem_t{nullptr, 0}, 0, p1_, + cmem_t{nullptr, 0}, ek1_, label_, &cb, &share), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, CombineNullOutXs) { + const std::array qn = {p1_, p2_}; + const std::array qs = {cmem_t{nullptr, 0}, cmem_t{nullptr, 0}}; + EXPECT_EQ( + cbmpc_pve_ac_combine(&base_pke_, curve_, &ac_, qn.data(), qs.data(), 2, cmem_t{nullptr, 0}, 0, label_, nullptr), + E_BADARG); +} + +TEST_F(CApiPveAcNeg, CombineInvalidCurve) { + const std::array qn = {p1_, p2_}; + const std::array qs = {cmem_t{nullptr, 0}, cmem_t{nullptr, 0}}; + cmems_t out{0, nullptr, nullptr}; + EXPECT_NE(cbmpc_pve_ac_combine(&base_pke_, static_cast(0), &ac_, qn.data(), qs.data(), 2, + cmem_t{nullptr, 0}, 0, label_, &out), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, CombineNullAc) { + const std::array qn = {p1_, p2_}; + const std::array qs = {cmem_t{nullptr, 0}, cmem_t{nullptr, 0}}; + cmems_t out{0, nullptr, nullptr}; + EXPECT_EQ( + cbmpc_pve_ac_combine(&base_pke_, curve_, nullptr, qn.data(), qs.data(), 2, cmem_t{nullptr, 0}, 0, label_, &out), + E_BADARG); +} + +TEST_F(CApiPveAcNeg, CombineEmptyCiphertextEmptyLabel) { + const std::array qn = {p1_, p2_}; + const std::array qs = {cmem_t{nullptr, 0}, cmem_t{nullptr, 0}}; + cmems_t out{0, nullptr, nullptr}; + EXPECT_NE(cbmpc_pve_ac_combine(&base_pke_, curve_, &ac_, qn.data(), qs.data(), 2, cmem_t{nullptr, 0}, 0, + cmem_t{nullptr, 0}, &out), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, CombineQuorumCountZero) { + cmems_t out{0, nullptr, nullptr}; + EXPECT_NE(cbmpc_pve_ac_combine(&base_pke_, curve_, &ac_, nullptr, nullptr, 0, cmem_t{nullptr, 0}, 0, label_, &out), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, CombineNullQuorumNamesAndShares) { + cmems_t out{0, nullptr, nullptr}; + const std::array qs = {cmem_t{nullptr, 0}, cmem_t{nullptr, 0}}; + EXPECT_EQ(cbmpc_pve_ac_combine(&base_pke_, curve_, &ac_, nullptr, qs.data(), 2, cmem_t{nullptr, 0}, 0, label_, &out), + E_BADARG); + const std::array qn = {p1_, p2_}; + EXPECT_EQ(cbmpc_pve_ac_combine(&base_pke_, curve_, &ac_, qn.data(), nullptr, 2, cmem_t{nullptr, 0}, 0, label_, &out), + E_BADARG); +} + +TEST_F(CApiPveAcNeg, CombineGarbageCiphertext) { + uint8_t garbage[4] = {0xDE, 0xAD, 0xBE, 0xEF}; + const cmem_t ct = {garbage, 4}; + const std::array qn = {p1_, p2_}; + const std::array qs = {cmem_t{nullptr, 0}, cmem_t{nullptr, 0}}; + cmems_t out{0, nullptr, nullptr}; + EXPECT_NE(cbmpc_pve_ac_combine(&base_pke_, curve_, &ac_, qn.data(), qs.data(), 2, ct, 0, label_, &out), + CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, GetCountNullOutput) { EXPECT_EQ(cbmpc_pve_ac_get_count(cmem_t{nullptr, 0}, nullptr), E_BADARG); } + +TEST_F(CApiPveAcNeg, GetQsNullOutput) { EXPECT_EQ(cbmpc_pve_ac_get_Qs(cmem_t{nullptr, 0}, nullptr), E_BADARG); } + +TEST_F(CApiPveAcNeg, GetCountEmptyAndGarbageCiphertext) { + int count = 0; + EXPECT_NE(cbmpc_pve_ac_get_count(cmem_t{nullptr, 0}, &count), CBMPC_SUCCESS); + uint8_t garbage[4] = {0xDE, 0xAD, 0xBE, 0xEF}; + EXPECT_NE(cbmpc_pve_ac_get_count(cmem_t{garbage, 4}, &count), CBMPC_SUCCESS); +} + +TEST_F(CApiPveAcNeg, GetQsEmptyAndGarbageCiphertext) { + cmems_t qs{0, nullptr, nullptr}; + EXPECT_NE(cbmpc_pve_ac_get_Qs(cmem_t{nullptr, 0}, &qs), CBMPC_SUCCESS); + uint8_t garbage[4] = {0xDE, 0xAD, 0xBE, 0xEF}; + EXPECT_NE(cbmpc_pve_ac_get_Qs(cmem_t{garbage, 4}, &qs), CBMPC_SUCCESS); +} diff --git a/tests/unit/c_api/test_pve_batch.cpp b/tests/unit/c_api/test_pve_batch.cpp new file mode 100644 index 00000000..53429c0e --- /dev/null +++ b/tests/unit/c_api/test_pve_batch.cpp @@ -0,0 +1,293 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace { + +using coinbase::buf_t; +using coinbase::mem_t; + +static buf_t expected_Q(cbmpc_curve_id_t curve_id, mem_t x) { + const coinbase::crypto::ecurve_t curve = (curve_id == CBMPC_CURVE_P256) ? coinbase::crypto::curve_p256 + : (curve_id == CBMPC_CURVE_SECP256K1) ? coinbase::crypto::curve_secp256k1 + : (curve_id == CBMPC_CURVE_ED25519) ? coinbase::crypto::curve_ed25519 + : coinbase::crypto::ecurve_t(); + cb_assert(curve.valid()); + + const coinbase::crypto::bn_t bn_x = coinbase::crypto::bn_t::from_bin(x) % curve.order(); + const coinbase::crypto::ecc_point_t Q = bn_x * curve.generator(); + return Q.to_compressed_bin(); +} + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +static cbmpc_error_t toy_encrypt(void* /*ctx*/, cmem_t /*ek*/, cmem_t /*label*/, cmem_t plain, cmem_t /*rho*/, + cmem_t* out_ct) { + if (!out_ct) return E_BADARG; + *out_ct = cmem_t{nullptr, 0}; + if (plain.size < 0) return E_BADARG; + if (plain.size > 0 && !plain.data) return E_BADARG; + + const int n = plain.size; + if (n == 0) return CBMPC_SUCCESS; + + out_ct->data = static_cast(cbmpc_malloc(static_cast(n))); + if (!out_ct->data) return E_INSUFFICIENT; + out_ct->size = n; + std::memmove(out_ct->data, plain.data, static_cast(n)); + return CBMPC_SUCCESS; +} + +static cbmpc_error_t toy_decrypt(void* /*ctx*/, cmem_t /*dk*/, cmem_t /*label*/, cmem_t ct, cmem_t* out_plain) { + if (!out_plain) return E_BADARG; + *out_plain = cmem_t{nullptr, 0}; + if (ct.size < 0) return E_BADARG; + if (ct.size > 0 && !ct.data) return E_BADARG; + + const int n = ct.size; + if (n == 0) return CBMPC_SUCCESS; + + out_plain->data = static_cast(cbmpc_malloc(static_cast(n))); + if (!out_plain->data) return E_INSUFFICIENT; + out_plain->size = n; + std::memmove(out_plain->data, ct.data, static_cast(n)); + return CBMPC_SUCCESS; +} + +} // namespace + +TEST(CApiPveBatch, EncVerDec_DefBasePke_RsaBlob) { + const cbmpc_curve_id_t curve = CBMPC_CURVE_SECP256K1; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + + constexpr int n = 4; + std::array(n) * 32> xs_flat{}; + std::array xs_sizes{}; + for (int i = 0; i < n; i++) { + xs_sizes[static_cast(i)] = 32; + for (int j = 0; j < 32; j++) xs_flat[static_cast(i * 32 + j)] = static_cast(i + j); + } + const cmems_t xs_in = {n, xs_flat.data(), xs_sizes.data()}; + + cmem_t ek_blob{nullptr, 0}; + cmem_t dk_blob{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_generate_base_pke_rsa_keypair(&ek_blob, &dk_blob), CBMPC_SUCCESS); + + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_batch_encrypt(/*base_pke=*/nullptr, curve, ek_blob, label, xs_in, &ct), CBMPC_SUCCESS); + + int batch_count = 0; + ASSERT_EQ(cbmpc_pve_batch_get_count(ct, &batch_count), CBMPC_SUCCESS); + ASSERT_EQ(batch_count, n); + + cmem_t L_ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_batch_get_Label(ct, &L_ct), CBMPC_SUCCESS); + expect_eq(L_ct, label); + cbmpc_cmem_free(L_ct); + + std::array(n) * 33> Qs_flat{}; + std::array Qs_sizes{}; + for (int i = 0; i < n; i++) { + const mem_t xi(xs_flat.data() + i * 32, 32); + const buf_t qi = expected_Q(curve, xi); + ASSERT_EQ(qi.size(), 33); + Qs_sizes[static_cast(i)] = qi.size(); + std::memmove(Qs_flat.data() + i * 33, qi.data(), static_cast(qi.size())); + } + const cmems_t Qs_expected = {n, Qs_flat.data(), Qs_sizes.data()}; + + ASSERT_EQ(cbmpc_pve_batch_verify(/*base_pke=*/nullptr, curve, ek_blob, ct, Qs_expected, label), CBMPC_SUCCESS); + + cmems_t xs_out{0, nullptr, nullptr}; + ASSERT_EQ(cbmpc_pve_batch_decrypt(/*base_pke=*/nullptr, curve, dk_blob, ek_blob, ct, label, &xs_out), CBMPC_SUCCESS); + ASSERT_EQ(xs_out.count, n); + + int off = 0; + for (int i = 0; i < n; i++) { + ASSERT_EQ(xs_out.sizes[i], 32); + ASSERT_EQ(std::memcmp(xs_out.data + off, xs_flat.data() + i * 32, 32), 0); + off += xs_out.sizes[i]; + } + + cbmpc_cmems_free(xs_out); + cbmpc_cmem_free(ct); + cbmpc_cmem_free(dk_blob); + cbmpc_cmem_free(ek_blob); +} + +TEST(CApiPveBatch, EncryptVerifyDecrypt_CustomBasePke) { + const cbmpc_pve_base_pke_t base_pke = { + /*ctx=*/nullptr, + /*encrypt=*/toy_encrypt, + /*decrypt=*/toy_decrypt, + }; + + const cbmpc_curve_id_t curve = CBMPC_CURVE_SECP256K1; + const cmem_t ek = {reinterpret_cast(const_cast("ek")), 2}; + const cmem_t dk = {reinterpret_cast(const_cast("dk")), 2}; + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + + constexpr int n = 3; + std::array(n) * 32> xs_flat{}; + std::array xs_sizes{}; + for (int i = 0; i < n; i++) { + xs_sizes[static_cast(i)] = 32; + for (int j = 0; j < 32; j++) xs_flat[static_cast(i * 32 + j)] = static_cast(0x77 + i + j); + } + const cmems_t xs_in = {n, xs_flat.data(), xs_sizes.data()}; + + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_batch_encrypt(&base_pke, curve, ek, label, xs_in, &ct), CBMPC_SUCCESS); + + std::array(n) * 33> Qs_flat{}; + std::array Qs_sizes{}; + for (int i = 0; i < n; i++) { + const mem_t xi(xs_flat.data() + i * 32, 32); + const buf_t qi = expected_Q(curve, xi); + ASSERT_EQ(qi.size(), 33); + Qs_sizes[static_cast(i)] = qi.size(); + std::memmove(Qs_flat.data() + i * 33, qi.data(), static_cast(qi.size())); + } + const cmems_t Qs_expected = {n, Qs_flat.data(), Qs_sizes.data()}; + + ASSERT_EQ(cbmpc_pve_batch_verify(&base_pke, curve, ek, ct, Qs_expected, label), CBMPC_SUCCESS); + + cmems_t xs_out{0, nullptr, nullptr}; + ASSERT_EQ(cbmpc_pve_batch_decrypt(&base_pke, curve, dk, ek, ct, label, &xs_out), CBMPC_SUCCESS); + ASSERT_EQ(xs_out.count, n); + + int off = 0; + for (int i = 0; i < n; i++) { + ASSERT_EQ(xs_out.sizes[i], 32); + ASSERT_EQ(std::memcmp(xs_out.data + off, xs_flat.data() + i * 32, 32), 0); + off += xs_out.sizes[i]; + } + + cbmpc_cmems_free(xs_out); + cbmpc_cmem_free(ct); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +#include + +TEST(CApiPveBatchNeg, Encrypt) { + dylog_disable_scope_t no_log; + cmem_t ek{nullptr, 0}; + cmem_t dk{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_generate_base_pke_ecies_p256_keypair(&ek, &dk), CBMPC_SUCCESS); + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + std::array x_bytes{}; + x_bytes[0] = 1; + int x_size = 32; + const cmems_t xs = {1, x_bytes.data(), &x_size}; + cmem_t ct{nullptr, 0}; + + EXPECT_EQ(cbmpc_pve_batch_encrypt(nullptr, CBMPC_CURVE_SECP256K1, ek, label, xs, nullptr), E_BADARG); + EXPECT_NE(cbmpc_pve_batch_encrypt(nullptr, static_cast(0), ek, label, xs, &ct), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_batch_encrypt(nullptr, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, label, xs, &ct), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_batch_encrypt(nullptr, CBMPC_CURVE_SECP256K1, ek, cmem_t{nullptr, 0}, xs, &ct), CBMPC_SUCCESS); + const cmems_t empty_xs = {0, nullptr, nullptr}; + EXPECT_NE(cbmpc_pve_batch_encrypt(nullptr, CBMPC_CURVE_SECP256K1, ek, label, empty_xs, &ct), CBMPC_SUCCESS); + + cbmpc_cmem_free(dk); + cbmpc_cmem_free(ek); +} + +TEST(CApiPveBatchNeg, Verify) { + dylog_disable_scope_t no_log; + cmem_t ek{nullptr, 0}; + cmem_t dk{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_generate_base_pke_ecies_p256_keypair(&ek, &dk), CBMPC_SUCCESS); + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + std::array x_bytes{}; + x_bytes[0] = 1; + int x_size = 32; + const cmems_t xs = {1, x_bytes.data(), &x_size}; + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_batch_encrypt(nullptr, CBMPC_CURVE_SECP256K1, ek, label, xs, &ct), CBMPC_SUCCESS); + cmems_t Qs{0, nullptr, nullptr}; + ASSERT_EQ(cbmpc_pve_batch_get_Qs(ct, &Qs), CBMPC_SUCCESS); + + EXPECT_NE(cbmpc_pve_batch_verify(nullptr, static_cast(0), ek, ct, Qs, label), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_batch_verify(nullptr, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, ct, Qs, label), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_batch_verify(nullptr, CBMPC_CURVE_SECP256K1, ek, cmem_t{nullptr, 0}, Qs, label), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_batch_verify(nullptr, CBMPC_CURVE_SECP256K1, ek, ct, Qs, cmem_t{nullptr, 0}), CBMPC_SUCCESS); + + cbmpc_cmems_free(Qs); + cbmpc_cmem_free(ct); + cbmpc_cmem_free(dk); + cbmpc_cmem_free(ek); +} + +TEST(CApiPveBatchNeg, Decrypt) { + dylog_disable_scope_t no_log; + cmem_t ek{nullptr, 0}; + cmem_t dk{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_generate_base_pke_ecies_p256_keypair(&ek, &dk), CBMPC_SUCCESS); + const cmem_t label = {reinterpret_cast(const_cast("label")), 5}; + std::array x_bytes{}; + x_bytes[0] = 1; + int x_size = 32; + const cmems_t xs = {1, x_bytes.data(), &x_size}; + cmem_t ct{nullptr, 0}; + ASSERT_EQ(cbmpc_pve_batch_encrypt(nullptr, CBMPC_CURVE_SECP256K1, ek, label, xs, &ct), CBMPC_SUCCESS); + cmems_t xs_out{0, nullptr, nullptr}; + + EXPECT_EQ(cbmpc_pve_batch_decrypt(nullptr, CBMPC_CURVE_SECP256K1, dk, ek, ct, label, nullptr), E_BADARG); + EXPECT_NE(cbmpc_pve_batch_decrypt(nullptr, static_cast(0), dk, ek, ct, label, &xs_out), + CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_batch_decrypt(nullptr, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, ek, ct, label, &xs_out), + CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_batch_decrypt(nullptr, CBMPC_CURVE_SECP256K1, dk, cmem_t{nullptr, 0}, ct, label, &xs_out), + CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_batch_decrypt(nullptr, CBMPC_CURVE_SECP256K1, dk, ek, cmem_t{nullptr, 0}, label, &xs_out), + CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_batch_decrypt(nullptr, CBMPC_CURVE_SECP256K1, dk, ek, ct, cmem_t{nullptr, 0}, &xs_out), + CBMPC_SUCCESS); + + cbmpc_cmem_free(ct); + cbmpc_cmem_free(dk); + cbmpc_cmem_free(ek); +} + +TEST(CApiPveBatchNeg, GetCount) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + EXPECT_EQ(cbmpc_pve_batch_get_count(cmem_t{nullptr, 0}, nullptr), E_BADARG); + int count = 0; + EXPECT_NE(cbmpc_pve_batch_get_count(cmem_t{nullptr, 0}, &count), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_batch_get_count(cmem_t{garbage, 4}, &count), CBMPC_SUCCESS); +} + +TEST(CApiPveBatchNeg, GetQs) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + EXPECT_EQ(cbmpc_pve_batch_get_Qs(cmem_t{nullptr, 0}, nullptr), E_BADARG); + cmems_t Qs{0, nullptr, nullptr}; + EXPECT_NE(cbmpc_pve_batch_get_Qs(cmem_t{nullptr, 0}, &Qs), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_batch_get_Qs(cmem_t{garbage, 4}, &Qs), CBMPC_SUCCESS); +} + +TEST(CApiPveBatchNeg, GetLabel) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + EXPECT_EQ(cbmpc_pve_batch_get_Label(cmem_t{nullptr, 0}, nullptr), E_BADARG); + cmem_t label{nullptr, 0}; + EXPECT_NE(cbmpc_pve_batch_get_Label(cmem_t{nullptr, 0}, &label), CBMPC_SUCCESS); + EXPECT_NE(cbmpc_pve_batch_get_Label(cmem_t{garbage, 4}, &label), CBMPC_SUCCESS); +} diff --git a/tests/unit/c_api/test_schnorr2pc.cpp b/tests/unit/c_api/test_schnorr2pc.cpp new file mode 100644 index 00000000..14d07d5a --- /dev/null +++ b/tests/unit/c_api/test_schnorr2pc.cpp @@ -0,0 +1,363 @@ +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::capi_harness::make_transport; +using coinbase::testutils::capi_harness::run_2pc; +using coinbase::testutils::capi_harness::transport_ctx_t; + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +} // namespace + +TEST(CApiSchnorr2pc, DkgSignRefreshSign) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + std::atomic free_calls_1{0}; + std::atomic free_calls_2{0}; + transport_ctx_t ctx1{c1, &free_calls_1}; + transport_ctx_t ctx2{c2, &free_calls_2}; + + const cbmpc_transport_t t1 = make_transport(&ctx1); + const cbmpc_transport_t t2 = make_transport(&ctx2); + + cmem_t key_blob_1{nullptr, 0}; + cmem_t key_blob_2{nullptr, 0}; + cbmpc_error_t rv1 = UNINITIALIZED_ERROR; + cbmpc_error_t rv2 = UNINITIALIZED_ERROR; + + const cbmpc_2pc_job_t job1 = {CBMPC_2PC_P1, "p1", "p2", &t1}; + const cbmpc_2pc_job_t job2 = {CBMPC_2PC_P2, "p1", "p2", &t2}; + run_2pc( + c1, c2, [&] { return cbmpc_schnorr_2p_dkg(&job1, CBMPC_CURVE_SECP256K1, &key_blob_1); }, + [&] { return cbmpc_schnorr_2p_dkg(&job2, CBMPC_CURVE_SECP256K1, &key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, CBMPC_SUCCESS); + ASSERT_EQ(rv2, CBMPC_SUCCESS); + ASSERT_GT(key_blob_1.size, 0); + ASSERT_GT(key_blob_2.size, 0); + + cmem_t pub1{nullptr, 0}; + cmem_t pub2{nullptr, 0}; + ASSERT_EQ(cbmpc_schnorr_2p_get_public_key_compressed(key_blob_1, &pub1), CBMPC_SUCCESS); + ASSERT_EQ(cbmpc_schnorr_2p_get_public_key_compressed(key_blob_2, &pub2), CBMPC_SUCCESS); + expect_eq(pub1, pub2); + ASSERT_EQ(pub1.size, 33); + + const buf_t pub_buf(pub1.data, pub1.size); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub_buf), SUCCESS); + + cmem_t xonly1{nullptr, 0}; + ASSERT_EQ(cbmpc_schnorr_2p_extract_public_key_xonly(key_blob_1, &xonly1), CBMPC_SUCCESS); + ASSERT_EQ(xonly1.size, 32); + const buf_t expected_xonly = Q.get_x().to_bin(32); + ASSERT_EQ(std::memcmp(expected_xonly.data(), xonly1.data, 32), 0); + + uint8_t msg_bytes[32]; + for (int i = 0; i < 32; i++) msg_bytes[i] = static_cast(i); + const cmem_t msg = {msg_bytes, 32}; + + cmem_t sig1{nullptr, 0}; + cmem_t sig2{nullptr, 0}; + run_2pc( + c1, c2, [&] { return cbmpc_schnorr_2p_sign(&job1, key_blob_1, msg, &sig1); }, + [&] { return cbmpc_schnorr_2p_sign(&job2, key_blob_2, msg, &sig2); }, rv1, rv2); + ASSERT_EQ(rv1, CBMPC_SUCCESS); + ASSERT_EQ(rv2, CBMPC_SUCCESS); + ASSERT_EQ(sig1.size, 64); + ASSERT_EQ(sig2.size, 0); + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, coinbase::mem_t(msg_bytes, 32), coinbase::mem_t(sig1.data, 64)), + SUCCESS); + + cmem_t new_key_blob_1{nullptr, 0}; + cmem_t new_key_blob_2{nullptr, 0}; + run_2pc( + c1, c2, [&] { return cbmpc_schnorr_2p_refresh(&job1, key_blob_1, &new_key_blob_1); }, + [&] { return cbmpc_schnorr_2p_refresh(&job2, key_blob_2, &new_key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, CBMPC_SUCCESS); + ASSERT_EQ(rv2, CBMPC_SUCCESS); + ASSERT_GT(new_key_blob_1.size, 0); + ASSERT_GT(new_key_blob_2.size, 0); + + cmem_t pub3{nullptr, 0}; + cmem_t pub4{nullptr, 0}; + ASSERT_EQ(cbmpc_schnorr_2p_get_public_key_compressed(new_key_blob_1, &pub3), CBMPC_SUCCESS); + ASSERT_EQ(cbmpc_schnorr_2p_get_public_key_compressed(new_key_blob_2, &pub4), CBMPC_SUCCESS); + expect_eq(pub3, pub4); + expect_eq(pub1, pub3); + + cmem_t sig3{nullptr, 0}; + cmem_t sig4{nullptr, 0}; + run_2pc( + c1, c2, [&] { return cbmpc_schnorr_2p_sign(&job1, new_key_blob_1, msg, &sig3); }, + [&] { return cbmpc_schnorr_2p_sign(&job2, new_key_blob_2, msg, &sig4); }, rv1, rv2); + ASSERT_EQ(rv1, CBMPC_SUCCESS); + ASSERT_EQ(rv2, CBMPC_SUCCESS); + ASSERT_EQ(sig3.size, 64); + ASSERT_EQ(sig4.size, 0); + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, coinbase::mem_t(msg_bytes, 32), coinbase::mem_t(sig3.data, 64)), + SUCCESS); + + EXPECT_GT(free_calls_1.load(), 0); + EXPECT_GT(free_calls_2.load(), 0); + + cbmpc_cmem_free(pub1); + cbmpc_cmem_free(pub2); + cbmpc_cmem_free(pub3); + cbmpc_cmem_free(pub4); + cbmpc_cmem_free(xonly1); + cbmpc_cmem_free(sig1); + cbmpc_cmem_free(sig2); + cbmpc_cmem_free(sig3); + cbmpc_cmem_free(sig4); + cbmpc_cmem_free(key_blob_1); + cbmpc_cmem_free(key_blob_2); + cbmpc_cmem_free(new_key_blob_1); + cbmpc_cmem_free(new_key_blob_2); +} + +TEST(CApiSchnorr2pc, ValidatesArgs) { + cmem_t out{reinterpret_cast(0x1), 123}; + const cbmpc_2pc_job_t bad_job = {CBMPC_2PC_P1, "p1", "p2", nullptr}; + EXPECT_EQ(cbmpc_schnorr_2p_dkg(&bad_job, CBMPC_CURVE_SECP256K1, &out), E_BADARG); + EXPECT_EQ(out.data, nullptr); + EXPECT_EQ(out.size, 0); + + // Missing sig_out is invalid. + EXPECT_EQ(cbmpc_schnorr_2p_sign(nullptr, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, nullptr), E_BADARG); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +TEST(CApiSchnorr2pcNeg, DkgNullOutput) { + const cbmpc_2pc_job_t bad_job = {CBMPC_2PC_P1, "p1", "p2", nullptr}; + EXPECT_EQ(cbmpc_schnorr_2p_dkg(&bad_job, CBMPC_CURVE_SECP256K1, nullptr), E_BADARG); +} + +TEST(CApiSchnorr2pcNeg, DkgNullJob) { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_dkg(nullptr, CBMPC_CURVE_SECP256K1, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, DkgInvalidCurve) { + const cbmpc_2pc_job_t bad_job = {CBMPC_2PC_P1, "p1", "p2", nullptr}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_dkg(&bad_job, static_cast(0), &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, RefreshNullOutput) { + const cbmpc_2pc_job_t bad_job = {CBMPC_2PC_P1, "p1", "p2", nullptr}; + EXPECT_EQ(cbmpc_schnorr_2p_refresh(&bad_job, cmem_t{nullptr, 0}, nullptr), E_BADARG); +} + +TEST(CApiSchnorr2pcNeg, RefreshNullJob) { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_refresh(nullptr, cmem_t{nullptr, 0}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, RefreshEmptyKeyBlob) { + const cbmpc_2pc_job_t bad_job = {CBMPC_2PC_P1, "p1", "p2", nullptr}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_refresh(&bad_job, cmem_t{nullptr, 0}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, RefreshGarbageKeyBlob) { + const cbmpc_2pc_job_t bad_job = {CBMPC_2PC_P1, "p1", "p2", nullptr}; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_refresh(&bad_job, cmem_t{garbage, 4}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, SignNullSigOut) { + const cbmpc_2pc_job_t bad_job = {CBMPC_2PC_P1, "p1", "p2", nullptr}; + EXPECT_EQ(cbmpc_schnorr_2p_sign(&bad_job, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, nullptr), E_BADARG); +} + +TEST(CApiSchnorr2pcNeg, SignNullJob) { + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_sign(nullptr, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, &sig), CBMPC_SUCCESS); + EXPECT_EQ(sig.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, SignEmptyKeyBlob) { + const cbmpc_2pc_job_t bad_job = {CBMPC_2PC_P1, "p1", "p2", nullptr}; + uint8_t msg[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_sign(&bad_job, cmem_t{nullptr, 0}, cmem_t{msg, 32}, &sig), CBMPC_SUCCESS); + EXPECT_EQ(sig.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, SignEmptyMsg) { + const cbmpc_2pc_job_t bad_job = {CBMPC_2PC_P1, "p1", "p2", nullptr}; + uint8_t blob[] = {0x01}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_sign(&bad_job, cmem_t{blob, 1}, cmem_t{nullptr, 0}, &sig), CBMPC_SUCCESS); + EXPECT_EQ(sig.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, SignMsgWrongSize) { + const cbmpc_2pc_job_t bad_job = {CBMPC_2PC_P1, "p1", "p2", nullptr}; + uint8_t blob[] = {0x01}; + { + uint8_t msg[31] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_sign(&bad_job, cmem_t{blob, 1}, cmem_t{msg, 31}, &sig), CBMPC_SUCCESS); + EXPECT_EQ(sig.data, nullptr); + } + { + uint8_t msg[33] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_sign(&bad_job, cmem_t{blob, 1}, cmem_t{msg, 33}, &sig), CBMPC_SUCCESS); + EXPECT_EQ(sig.data, nullptr); + } +} + +TEST(CApiSchnorr2pcNeg, GetPubKeyNullOutput) { + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_schnorr_2p_get_public_key_compressed(cmem_t{dummy, 1}, nullptr), E_BADARG); +} + +TEST(CApiSchnorr2pcNeg, GetPubKeyEmptyBlob) { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_get_public_key_compressed(cmem_t{nullptr, 0}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, GetPubKeyGarbageBlob) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_get_public_key_compressed(cmem_t{garbage, 4}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, ExtractXonlyNullOutput) { + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_schnorr_2p_extract_public_key_xonly(cmem_t{dummy, 1}, nullptr), E_BADARG); +} + +TEST(CApiSchnorr2pcNeg, ExtractXonlyEmptyBlob) { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_extract_public_key_xonly(cmem_t{nullptr, 0}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, ExtractXonlyGarbageBlob) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_extract_public_key_xonly(cmem_t{garbage, 4}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, GetPubShareNullOutput) { + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_schnorr_2p_get_public_share_compressed(cmem_t{dummy, 1}, nullptr), E_BADARG); +} + +TEST(CApiSchnorr2pcNeg, GetPubShareEmptyBlob) { + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_get_public_share_compressed(cmem_t{nullptr, 0}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, GetPubShareGarbageBlob) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_get_public_share_compressed(cmem_t{garbage, 4}, &out), CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, DetachNullPubOutput) { + uint8_t dummy[] = {0x01}; + cmem_t scalar{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_2p_detach_private_scalar(cmem_t{dummy, 1}, nullptr, &scalar), E_BADARG); +} + +TEST(CApiSchnorr2pcNeg, DetachNullScalarOutput) { + uint8_t dummy[] = {0x01}; + cmem_t pub{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_2p_detach_private_scalar(cmem_t{dummy, 1}, &pub, nullptr), E_BADARG); +} + +TEST(CApiSchnorr2pcNeg, DetachEmptyBlob) { + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_detach_private_scalar(cmem_t{nullptr, 0}, &pub, &scalar), CBMPC_SUCCESS); + EXPECT_EQ(pub.data, nullptr); + EXPECT_EQ(scalar.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, DetachGarbageBlob) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_detach_private_scalar(cmem_t{garbage, 4}, &pub, &scalar), CBMPC_SUCCESS); + EXPECT_EQ(pub.data, nullptr); + EXPECT_EQ(scalar.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, AttachNullOutput) { + uint8_t dummy[] = {0x01}; + EXPECT_EQ(cbmpc_schnorr_2p_attach_private_scalar(cmem_t{dummy, 1}, cmem_t{dummy, 1}, cmem_t{dummy, 1}, nullptr), + E_BADARG); +} + +TEST(CApiSchnorr2pcNeg, AttachEmptyPubBlob) { + uint8_t dummy[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_attach_private_scalar(cmem_t{nullptr, 0}, cmem_t{dummy, 1}, cmem_t{dummy, 1}, &out), + CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, AttachEmptyScalar) { + uint8_t dummy[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_attach_private_scalar(cmem_t{dummy, 1}, cmem_t{nullptr, 0}, cmem_t{dummy, 1}, &out), + CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, AttachEmptyPubShare) { + uint8_t dummy[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_attach_private_scalar(cmem_t{dummy, 1}, cmem_t{dummy, 1}, cmem_t{nullptr, 0}, &out), + CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} + +TEST(CApiSchnorr2pcNeg, AttachGarbagePubBlob) { + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + uint8_t dummy[] = {0x01}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_2p_attach_private_scalar(cmem_t{garbage, 4}, cmem_t{dummy, 1}, cmem_t{dummy, 1}, &out), + CBMPC_SUCCESS); + EXPECT_EQ(out.data, nullptr); +} diff --git a/tests/unit/c_api/test_schnorr_mp.cpp b/tests/unit/c_api/test_schnorr_mp.cpp new file mode 100644 index 00000000..bcd41def --- /dev/null +++ b/tests/unit/c_api/test_schnorr_mp.cpp @@ -0,0 +1,445 @@ +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::capi_harness::make_transport; +using coinbase::testutils::capi_harness::run_mp; +using coinbase::testutils::capi_harness::transport_ctx_t; + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +} // namespace + +TEST(CApiSchnorrMp, DkgSignRefreshSignRoleChange4p) { + constexpr int n = 4; + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + std::atomic free_calls[n]; + transport_ctx_t ctx[n]; + cbmpc_transport_t transports[n]; + for (int i = 0; i < n; i++) { + free_calls[i].store(0); + ctx[i] = transport_ctx_t{peers[static_cast(i)], &free_calls[i]}; + transports[i] = make_transport(&ctx[i]); + } + + const char* party_names[n] = {"p0", "p1", "p2", "p3"}; + + std::vector key_blobs(n, cmem_t{nullptr, 0}); + std::vector sids(n, cmem_t{nullptr, 0}); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[i], + }; + return cbmpc_schnorr_mp_dkg_additive(&job, CBMPC_CURVE_SECP256K1, &key_blobs[static_cast(i)], + &sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) { + ASSERT_GT(key_blobs[static_cast(i)].size, 0); + ASSERT_GT(sids[static_cast(i)].size, 0); + } + for (int i = 1; i < n; i++) expect_eq(sids[0], sids[static_cast(i)]); + + cmem_t pub0{nullptr, 0}; + ASSERT_EQ(cbmpc_schnorr_mp_get_public_key_compressed(key_blobs[0], &pub0), CBMPC_SUCCESS); + ASSERT_EQ(pub0.size, 33); + for (int i = 1; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_schnorr_mp_get_public_key_compressed(key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + const buf_t pub_buf(pub0.data, pub0.size); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub_buf), SUCCESS); + + cmem_t xonly0{nullptr, 0}; + ASSERT_EQ(cbmpc_schnorr_mp_extract_public_key_xonly(key_blobs[0], &xonly0), CBMPC_SUCCESS); + ASSERT_EQ(xonly0.size, 32); + const buf_t expected_xonly = Q.get_x().to_bin(32); + ASSERT_EQ(std::memcmp(expected_xonly.data(), xonly0.data, 32), 0); + + // Change the party ordering ("role" indices) between protocols. + const char* party_names2[n] = {"p0", "p2", "p1", "p3"}; + // Map new role index -> old role index (DKG) for the same party name. + const int perm[n] = {0, 2, 1, 3}; + + uint8_t msg_bytes[32]; + for (int i = 0; i < 32; i++) msg_bytes[i] = static_cast(i); + const cmem_t msg = {msg_bytes, 32}; + + std::vector sigs(n, cmem_t{nullptr, 0}); + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names2, + /*party_names_count=*/n, + /*transport=*/&transports[i], + }; + return cbmpc_schnorr_mp_sign_additive(&job, key_blobs[static_cast(perm[i])], msg, /*sig_receiver=*/2, + &sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_EQ(sigs[2].size, 64); + for (int i = 0; i < n; i++) { + if (i == 2) continue; + ASSERT_EQ(sigs[static_cast(i)].size, 0); + } + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, coinbase::mem_t(msg_bytes, 32), coinbase::mem_t(sigs[2].data, 64)), + SUCCESS); + + std::vector new_key_blobs(n, cmem_t{nullptr, 0}); + std::vector sid_outs(n, cmem_t{nullptr, 0}); + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names2, + /*party_names_count=*/n, + /*transport=*/&transports[i], + }; + return cbmpc_schnorr_mp_refresh_additive( + &job, sids[static_cast(perm[i])], key_blobs[static_cast(perm[i])], + &sid_outs[static_cast(i)], &new_key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) ASSERT_GT(new_key_blobs[static_cast(i)].size, 0); + for (int i = 1; i < n; i++) expect_eq(sid_outs[0], sid_outs[static_cast(i)]); + expect_eq(sids[0], sid_outs[0]); + + for (int i = 0; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_schnorr_mp_get_public_key_compressed(new_key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + std::vector new_sigs(n, cmem_t{nullptr, 0}); + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names2, + /*party_names_count=*/n, + /*transport=*/&transports[i], + }; + return cbmpc_schnorr_mp_sign_additive(&job, new_key_blobs[static_cast(i)], msg, /*sig_receiver=*/2, + &new_sigs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_EQ(new_sigs[2].size, 64); + for (int i = 0; i < n; i++) { + if (i == 2) continue; + ASSERT_EQ(new_sigs[static_cast(i)].size, 0); + } + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, coinbase::mem_t(msg_bytes, 32), coinbase::mem_t(new_sigs[2].data, 64)), + SUCCESS); + + for (int i = 0; i < n; i++) EXPECT_GT(free_calls[i].load(), 0); + + cbmpc_cmem_free(pub0); + cbmpc_cmem_free(xonly0); + for (auto m : new_sigs) cbmpc_cmem_free(m); + for (auto m : sid_outs) cbmpc_cmem_free(m); + for (auto m : new_key_blobs) cbmpc_cmem_free(m); + for (auto m : sigs) cbmpc_cmem_free(m); + for (auto m : sids) cbmpc_cmem_free(m); + for (auto m : key_blobs) cbmpc_cmem_free(m); +} + +TEST(CApiSchnorrMp, ValidatesArgs) { + cmem_t key{reinterpret_cast(0x1), 123}; + cmem_t sid{reinterpret_cast(0x1), 123}; + + const cbmpc_transport_t bad_transport = {/*ctx=*/nullptr, /*send=*/nullptr, /*receive=*/nullptr, + /*receive_all=*/nullptr, + /*free=*/nullptr}; + const char* names[2] = {"p0", "p1"}; + const cbmpc_mp_job_t bad_job = {/*self=*/0, /*party_names=*/names, /*party_names_count=*/2, + /*transport=*/&bad_transport}; + + EXPECT_EQ(cbmpc_schnorr_mp_dkg_additive(&bad_job, CBMPC_CURVE_SECP256K1, &key, &sid), E_BADARG); + EXPECT_EQ(key.data, nullptr); + EXPECT_EQ(key.size, 0); + EXPECT_EQ(sid.data, nullptr); + EXPECT_EQ(sid.size, 0); + + // Missing sig_out is invalid. + EXPECT_EQ(cbmpc_schnorr_mp_sign_additive(nullptr, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, 0, nullptr), E_BADARG); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +#include + +TEST(CApiSchnorrMpNeg, DkgAdditiveNullOutKeyBlob) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + cmem_t sid{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_dkg_additive(&bad_job, CBMPC_CURVE_SECP256K1, nullptr, &sid), E_BADARG); +} + +TEST(CApiSchnorrMpNeg, DkgAdditiveNullOutSid) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + cmem_t key{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_dkg_additive(&bad_job, CBMPC_CURVE_SECP256K1, &key, nullptr), E_BADARG); +} + +TEST(CApiSchnorrMpNeg, DkgAdditiveNullJob) { + dylog_disable_scope_t no_log; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_dkg_additive(nullptr, CBMPC_CURVE_SECP256K1, &key, &sid), E_BADARG); +} + +TEST(CApiSchnorrMpNeg, DkgAdditiveInvalidCurve) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_dkg_additive(&bad_job, static_cast(0), &key, &sid), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, RefreshAdditiveNullOutNewKeyBlob) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + cmem_t sid_out{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_refresh_additive(&bad_job, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, &sid_out, nullptr), + E_BADARG); +} + +TEST(CApiSchnorrMpNeg, RefreshAdditiveNullJob) { + dylog_disable_scope_t no_log; + cmem_t sid_out{nullptr, 0}; + cmem_t new_key{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_refresh_additive(nullptr, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, &sid_out, &new_key), + E_BADARG); +} + +TEST(CApiSchnorrMpNeg, RefreshAdditiveEmptyKeyBlob) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + cmem_t sid_out{nullptr, 0}; + cmem_t new_key{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_refresh_additive(&bad_job, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, &sid_out, &new_key), + CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, RefreshAdditiveGarbageKeyBlob) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t sid_out{nullptr, 0}; + cmem_t new_key{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_refresh_additive(&bad_job, cmem_t{nullptr, 0}, cmem_t{garbage, 4}, &sid_out, &new_key), + CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, SignAdditiveNullSigOut) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + EXPECT_EQ(cbmpc_schnorr_mp_sign_additive(&bad_job, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, 0, nullptr), E_BADARG); +} + +TEST(CApiSchnorrMpNeg, SignAdditiveNullJob) { + dylog_disable_scope_t no_log; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_sign_additive(nullptr, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, 0, &sig), E_BADARG); +} + +TEST(CApiSchnorrMpNeg, SignAdditiveEmptyKeyBlob) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_sign_additive(&bad_job, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, 0, &sig), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, SignAdditiveEmptyMsg) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + uint8_t key_bytes[] = {0x01}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_sign_additive(&bad_job, cmem_t{key_bytes, 1}, cmem_t{nullptr, 0}, 0, &sig), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, SignAdditiveMsgWrongSize) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + uint8_t short_msg[31] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_sign_additive(&bad_job, cmem_t{nullptr, 0}, cmem_t{short_msg, 31}, 0, &sig), + CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, SignAdditiveInvalidSigReceiver) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + uint8_t msg[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_sign_additive(&bad_job, cmem_t{nullptr, 0}, cmem_t{msg, 32}, -1, &sig), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, GetPublicKeyCompressedNullOutput) { + dylog_disable_scope_t no_log; + EXPECT_EQ(cbmpc_schnorr_mp_get_public_key_compressed(cmem_t{nullptr, 0}, nullptr), E_BADARG); +} + +TEST(CApiSchnorrMpNeg, GetPublicKeyCompressedEmptyKeyBlob) { + dylog_disable_scope_t no_log; + cmem_t pub{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_get_public_key_compressed(cmem_t{nullptr, 0}, &pub), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, GetPublicKeyCompressedGarbageKeyBlob) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t pub{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_get_public_key_compressed(cmem_t{garbage, 4}, &pub), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, ExtractPublicKeyXonlyNullOutput) { + dylog_disable_scope_t no_log; + EXPECT_EQ(cbmpc_schnorr_mp_extract_public_key_xonly(cmem_t{nullptr, 0}, nullptr), E_BADARG); +} + +TEST(CApiSchnorrMpNeg, ExtractPublicKeyXonlyEmptyKeyBlob) { + dylog_disable_scope_t no_log; + cmem_t pub{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_extract_public_key_xonly(cmem_t{nullptr, 0}, &pub), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, ExtractPublicKeyXonlyGarbageKeyBlob) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t pub{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_extract_public_key_xonly(cmem_t{garbage, 4}, &pub), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, GetPublicShareCompressedNullOutput) { + dylog_disable_scope_t no_log; + EXPECT_EQ(cbmpc_schnorr_mp_get_public_share_compressed(cmem_t{nullptr, 0}, nullptr), E_BADARG); +} + +TEST(CApiSchnorrMpNeg, GetPublicShareCompressedEmptyKeyBlob) { + dylog_disable_scope_t no_log; + cmem_t share{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_get_public_share_compressed(cmem_t{nullptr, 0}, &share), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, GetPublicShareCompressedGarbageKeyBlob) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t share{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_get_public_share_compressed(cmem_t{garbage, 4}, &share), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, DetachPrivateScalarNullOutPublicKeyBlob) { + dylog_disable_scope_t no_log; + cmem_t scalar{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_detach_private_scalar(cmem_t{nullptr, 0}, nullptr, &scalar), E_BADARG); +} + +TEST(CApiSchnorrMpNeg, DetachPrivateScalarNullOutPrivateScalarFixed) { + dylog_disable_scope_t no_log; + cmem_t pub{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_detach_private_scalar(cmem_t{nullptr, 0}, &pub, nullptr), E_BADARG); +} + +TEST(CApiSchnorrMpNeg, DetachPrivateScalarEmptyKeyBlob) { + dylog_disable_scope_t no_log; + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_detach_private_scalar(cmem_t{nullptr, 0}, &pub, &scalar), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, DetachPrivateScalarGarbageKeyBlob) { + dylog_disable_scope_t no_log; + uint8_t garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; + cmem_t pub{nullptr, 0}; + cmem_t scalar{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_detach_private_scalar(cmem_t{garbage, 4}, &pub, &scalar), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, AttachPrivateScalarNullOutKeyBlob) { + dylog_disable_scope_t no_log; + EXPECT_EQ(cbmpc_schnorr_mp_attach_private_scalar(cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, nullptr), + E_BADARG); +} + +TEST(CApiSchnorrMpNeg, AttachPrivateScalarEmptyPublicKeyBlob) { + dylog_disable_scope_t no_log; + uint8_t scalar[] = {0x01}; + uint8_t share[] = {0x02}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_attach_private_scalar(cmem_t{nullptr, 0}, cmem_t{scalar, 1}, cmem_t{share, 1}, &out), + CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, AttachPrivateScalarEmptyPrivateScalarFixed) { + dylog_disable_scope_t no_log; + uint8_t pub[] = {0x01}; + uint8_t share[] = {0x02}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_attach_private_scalar(cmem_t{pub, 1}, cmem_t{nullptr, 0}, cmem_t{share, 1}, &out), + CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpNeg, AttachPrivateScalarEmptyPublicShareCompressed) { + dylog_disable_scope_t no_log; + uint8_t pub[] = {0x01}; + uint8_t scalar[] = {0x02}; + cmem_t out{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_attach_private_scalar(cmem_t{pub, 1}, cmem_t{scalar, 1}, cmem_t{nullptr, 0}, &out), + CBMPC_SUCCESS); +} diff --git a/tests/unit/c_api/test_schnorr_mp_ac.cpp b/tests/unit/c_api/test_schnorr_mp_ac.cpp new file mode 100644 index 00000000..8ae4db53 --- /dev/null +++ b/tests/unit/c_api/test_schnorr_mp_ac.cpp @@ -0,0 +1,493 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::capi_harness::make_transport; +using coinbase::testutils::capi_harness::run_mp; +using coinbase::testutils::capi_harness::transport_ctx_t; + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +static void make_peers(int n, std::vector>& peers) { + peers.clear(); + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); +} + +static void make_transports(const std::vector>& peers, + std::vector& ctxs, std::vector& transports) { + ctxs.resize(peers.size()); + transports.resize(peers.size()); + for (size_t i = 0; i < peers.size(); i++) { + ctxs[i] = transport_ctx_t{peers[i], /*free_calls=*/nullptr}; + transports[i] = make_transport(&ctxs[i]); + } +} + +} // namespace + +TEST(CApiSchnorrMpAc, DkgRefreshSign2of3) { + constexpr int n = 3; + + // Full 3-party network for threshold DKG/refresh. + std::vector> peers; + make_peers(n, peers); + + std::vector ctxs; + std::vector transports; + make_transports(peers, ctxs, transports); + + const char* party_names[n] = {"p0", "p1", "p2"}; + + // Access structure: THRESHOLD[2](p0, p1, p2) + const int32_t child_indices[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, /*leaf_name=*/nullptr, /*k=*/2, /*off=*/0, /*cnt=*/3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p0", /*k=*/0, /*off=*/0, /*cnt=*/0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p1", /*k=*/0, /*off=*/0, /*cnt=*/0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p2", /*k=*/0, /*off=*/0, /*cnt=*/0}, + }; + const cbmpc_access_structure_t ac = { + /*nodes=*/nodes, + /*nodes_count=*/static_cast(sizeof(nodes) / sizeof(nodes[0])), + /*child_indices=*/child_indices, + /*child_indices_count=*/static_cast(sizeof(child_indices) / sizeof(child_indices[0])), + /*root_index=*/0, + }; + + // Only p0 and p1 actively contribute to DKG/refresh. + const char* quorum[] = {"p0", "p1"}; + + std::vector key_blobs(n, cmem_t{nullptr, 0}); + std::vector sids(n, cmem_t{nullptr, 0}); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[static_cast(i)], + }; + return cbmpc_schnorr_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, /*sid_in=*/cmem_t{nullptr, 0}, &ac, quorum, + /*quorum_party_names_count=*/2, &key_blobs[static_cast(i)], + &sids[static_cast(i)]); + }, + rvs); + + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) { + ASSERT_GT(key_blobs[static_cast(i)].size, 0); + ASSERT_GT(sids[static_cast(i)].size, 0); + } + for (int i = 1; i < n; i++) expect_eq(sids[0], sids[static_cast(i)]); + + cmem_t pub0{nullptr, 0}; + ASSERT_EQ(cbmpc_schnorr_mp_get_public_key_compressed(key_blobs[0], &pub0), CBMPC_SUCCESS); + ASSERT_EQ(pub0.size, 33); + for (int i = 1; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_schnorr_mp_get_public_key_compressed(key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + const buf_t pub_buf(pub0.data, pub0.size); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub_buf), SUCCESS); + + uint8_t msg_bytes[32]; + for (int i = 0; i < 32; i++) msg_bytes[i] = static_cast(0x22 + i); + const cmem_t msg = {msg_bytes, 32}; + + // Signing quorum: {p0, p1} + const char* sign_party_names[2] = {"p0", "p1"}; + const cmem_t sign_key_blobs[2] = {key_blobs[0], key_blobs[1]}; + + { + std::vector> sign_peers; + make_peers(2, sign_peers); + + std::vector sign_ctxs; + std::vector sign_transports; + make_transports(sign_peers, sign_ctxs, sign_transports); + + std::vector sigs(2, cmem_t{nullptr, 0}); + run_mp( + sign_peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/sign_party_names, + /*party_names_count=*/2, + /*transport=*/&sign_transports[static_cast(i)], + }; + return cbmpc_schnorr_mp_sign_ac(&job, sign_key_blobs[static_cast(i)], &ac, msg, /*sig_receiver=*/0, + &sigs[static_cast(i)]); + }, + rvs); + + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_EQ(sigs[0].size, 64); + EXPECT_EQ(sigs[1].size, 0); + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, mem_t(msg_bytes, 32), mem_t(sigs[0].data, sigs[0].size)), SUCCESS); + + for (auto m : sigs) cbmpc_cmem_free(m); + } + + // Threshold refresh. + std::vector new_key_blobs(n, cmem_t{nullptr, 0}); + std::vector refresh_sids(n, cmem_t{nullptr, 0}); + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[static_cast(i)], + }; + return cbmpc_schnorr_mp_refresh_ac(&job, /*sid_in=*/cmem_t{nullptr, 0}, key_blobs[static_cast(i)], &ac, + quorum, /*quorum_party_names_count=*/2, + &refresh_sids[static_cast(i)], + &new_key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) ASSERT_GT(new_key_blobs[static_cast(i)].size, 0); + for (int i = 1; i < n; i++) expect_eq(refresh_sids[0], refresh_sids[static_cast(i)]); + + for (int i = 0; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_schnorr_mp_get_public_key_compressed(new_key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + const cmem_t sign_new_key_blobs[2] = {new_key_blobs[0], new_key_blobs[1]}; + + { + std::vector> sign_peers; + make_peers(2, sign_peers); + + std::vector sign_ctxs; + std::vector sign_transports; + make_transports(sign_peers, sign_ctxs, sign_transports); + + std::vector sigs(2, cmem_t{nullptr, 0}); + run_mp( + sign_peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/sign_party_names, + /*party_names_count=*/2, + /*transport=*/&sign_transports[static_cast(i)], + }; + return cbmpc_schnorr_mp_sign_ac(&job, sign_new_key_blobs[static_cast(i)], &ac, msg, + /*sig_receiver=*/0, &sigs[static_cast(i)]); + }, + rvs); + + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_EQ(sigs[0].size, 64); + EXPECT_EQ(sigs[1].size, 0); + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, mem_t(msg_bytes, 32), mem_t(sigs[0].data, sigs[0].size)), SUCCESS); + + for (auto m : sigs) cbmpc_cmem_free(m); + } + + cbmpc_cmem_free(pub0); + for (auto m : refresh_sids) cbmpc_cmem_free(m); + for (auto m : new_key_blobs) cbmpc_cmem_free(m); + for (auto m : sids) cbmpc_cmem_free(m); + for (auto m : key_blobs) cbmpc_cmem_free(m); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +#include + +TEST(CApiSchnorrMpAcNeg, DkgAcNullOutAcKeyBlob) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + const int32_t ci[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, ci, 3, 0}; + const char* quorum[] = {"p0", "p1"}; + cmem_t sid{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_dkg_ac(&bad_job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, nullptr, &sid), + E_BADARG); +} + +TEST(CApiSchnorrMpAcNeg, DkgAcNullOutSid) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + const int32_t ci[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, ci, 3, 0}; + const char* quorum[] = {"p0", "p1"}; + cmem_t key{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_dkg_ac(&bad_job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &key, nullptr), + E_BADARG); +} + +TEST(CApiSchnorrMpAcNeg, DkgAcNullJob) { + dylog_disable_scope_t no_log; + const int32_t ci[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, ci, 3, 0}; + const char* quorum[] = {"p0", "p1"}; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_dkg_ac(nullptr, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, &ac, quorum, 2, &key, &sid), + E_BADARG); +} + +TEST(CApiSchnorrMpAcNeg, DkgAcInvalidCurve) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + const int32_t ci[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, ci, 3, 0}; + const char* quorum[] = {"p0", "p1"}; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_dkg_ac(&bad_job, static_cast(0), cmem_t{nullptr, 0}, &ac, quorum, 2, + &key, &sid), + CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpAcNeg, DkgAcNullAc) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + const char* quorum[] = {"p0", "p1"}; + cmem_t key{nullptr, 0}; + cmem_t sid{nullptr, 0}; + EXPECT_EQ( + cbmpc_schnorr_mp_dkg_ac(&bad_job, CBMPC_CURVE_SECP256K1, cmem_t{nullptr, 0}, nullptr, quorum, 2, &key, &sid), + E_BADARG); +} + +TEST(CApiSchnorrMpAcNeg, RefreshAcNullOutNewAcKeyBlob) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + const int32_t ci[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, ci, 3, 0}; + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + EXPECT_EQ( + cbmpc_schnorr_mp_refresh_ac(&bad_job, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, &ac, quorum, 2, &sid_out, nullptr), + E_BADARG); +} + +TEST(CApiSchnorrMpAcNeg, RefreshAcNullJob) { + dylog_disable_scope_t no_log; + const int32_t ci[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, ci, 3, 0}; + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + cmem_t new_key{nullptr, 0}; + EXPECT_EQ( + cbmpc_schnorr_mp_refresh_ac(nullptr, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, &ac, quorum, 2, &sid_out, &new_key), + E_BADARG); +} + +TEST(CApiSchnorrMpAcNeg, RefreshAcEmptyKeyBlob) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + const int32_t ci[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, ci, 3, 0}; + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + cmem_t new_key{nullptr, 0}; + EXPECT_NE( + cbmpc_schnorr_mp_refresh_ac(&bad_job, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, &ac, quorum, 2, &sid_out, &new_key), + CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpAcNeg, RefreshAcNullAc) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + const char* quorum[] = {"p0", "p1"}; + cmem_t sid_out{nullptr, 0}; + cmem_t new_key{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_refresh_ac(&bad_job, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, nullptr, quorum, 2, &sid_out, + &new_key), + E_BADARG); +} + +TEST(CApiSchnorrMpAcNeg, SignAcNullSigOut) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + const int32_t ci[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, ci, 3, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_sign_ac(&bad_job, cmem_t{nullptr, 0}, &ac, cmem_t{nullptr, 0}, 0, nullptr), E_BADARG); +} + +TEST(CApiSchnorrMpAcNeg, SignAcNullJob) { + dylog_disable_scope_t no_log; + const int32_t ci[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, ci, 3, 0}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_sign_ac(nullptr, cmem_t{nullptr, 0}, &ac, cmem_t{nullptr, 0}, 0, &sig), E_BADARG); +} + +TEST(CApiSchnorrMpAcNeg, SignAcNullAc) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + cmem_t sig{nullptr, 0}; + EXPECT_EQ(cbmpc_schnorr_mp_sign_ac(&bad_job, cmem_t{nullptr, 0}, nullptr, cmem_t{nullptr, 0}, 0, &sig), E_BADARG); +} + +TEST(CApiSchnorrMpAcNeg, SignAcEmptyKeyBlob) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + const int32_t ci[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, ci, 3, 0}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_sign_ac(&bad_job, cmem_t{nullptr, 0}, &ac, cmem_t{nullptr, 0}, 0, &sig), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpAcNeg, SignAcEmptyMsg) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + const int32_t ci[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, ci, 3, 0}; + uint8_t key_bytes[] = {0x01}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_sign_ac(&bad_job, cmem_t{key_bytes, 1}, &ac, cmem_t{nullptr, 0}, 0, &sig), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpAcNeg, SignAcMsgWrongSize) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + const int32_t ci[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, ci, 3, 0}; + uint8_t short_msg[31] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_sign_ac(&bad_job, cmem_t{nullptr, 0}, &ac, cmem_t{short_msg, 31}, 0, &sig), CBMPC_SUCCESS); +} + +TEST(CApiSchnorrMpAcNeg, SignAcInvalidSigReceiver) { + dylog_disable_scope_t no_log; + const char* names[] = {"p0", "p1", "p2"}; + const cbmpc_mp_job_t bad_job = {0, names, 3, nullptr}; + const int32_t ci[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, + }; + const cbmpc_access_structure_t ac = {nodes, 4, ci, 3, 0}; + uint8_t msg[32] = {}; + cmem_t sig{nullptr, 0}; + EXPECT_NE(cbmpc_schnorr_mp_sign_ac(&bad_job, cmem_t{nullptr, 0}, &ac, cmem_t{msg, 32}, -1, &sig), CBMPC_SUCCESS); +} diff --git a/tests/unit/c_api/test_schnorr_mp_threshold.cpp b/tests/unit/c_api/test_schnorr_mp_threshold.cpp new file mode 100644 index 00000000..5a075b00 --- /dev/null +++ b/tests/unit/c_api/test_schnorr_mp_threshold.cpp @@ -0,0 +1,231 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::capi_harness::make_transport; +using coinbase::testutils::capi_harness::run_mp; +using coinbase::testutils::capi_harness::transport_ctx_t; + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +static void make_peers(int n, std::vector>& peers) { + peers.clear(); + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); +} + +static void make_transports(const std::vector>& peers, + std::vector& ctxs, std::vector& transports) { + ctxs.resize(peers.size()); + transports.resize(peers.size()); + for (size_t i = 0; i < peers.size(); i++) { + ctxs[i] = transport_ctx_t{peers[i], /*free_calls=*/nullptr}; + transports[i] = make_transport(&ctxs[i]); + } +} + +} // namespace + +TEST(CApiSchnorrMpThreshold, DkgRefreshSign2of3) { + constexpr int n = 3; + + // Full 3-party network for threshold DKG/refresh. + std::vector> peers; + make_peers(n, peers); + + std::vector ctxs; + std::vector transports; + make_transports(peers, ctxs, transports); + + const char* party_names[n] = {"p0", "p1", "p2"}; + + // Access structure: THRESHOLD[2](p0, p1, p2) + const int32_t child_indices[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, /*leaf_name=*/nullptr, /*k=*/2, /*off=*/0, /*cnt=*/3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p0", /*k=*/0, /*off=*/0, /*cnt=*/0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p1", /*k=*/0, /*off=*/0, /*cnt=*/0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p2", /*k=*/0, /*off=*/0, /*cnt=*/0}, + }; + const cbmpc_access_structure_t ac = { + /*nodes=*/nodes, + /*nodes_count=*/static_cast(sizeof(nodes) / sizeof(nodes[0])), + /*child_indices=*/child_indices, + /*child_indices_count=*/static_cast(sizeof(child_indices) / sizeof(child_indices[0])), + /*root_index=*/0, + }; + + // Only p0 and p1 actively contribute to DKG/refresh. + const char* quorum[] = {"p0", "p1"}; + + std::vector key_blobs(n, cmem_t{nullptr, 0}); + std::vector sids(n, cmem_t{nullptr, 0}); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[static_cast(i)], + }; + return cbmpc_schnorr_mp_dkg_ac(&job, CBMPC_CURVE_SECP256K1, /*sid_in=*/cmem_t{nullptr, 0}, &ac, quorum, + /*quorum_party_names_count=*/2, &key_blobs[static_cast(i)], + &sids[static_cast(i)]); + }, + rvs); + + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) { + ASSERT_GT(key_blobs[static_cast(i)].size, 0); + ASSERT_GT(sids[static_cast(i)].size, 0); + } + for (int i = 1; i < n; i++) expect_eq(sids[0], sids[static_cast(i)]); + + cmem_t pub0{nullptr, 0}; + ASSERT_EQ(cbmpc_schnorr_mp_get_public_key_compressed(key_blobs[0], &pub0), CBMPC_SUCCESS); + ASSERT_EQ(pub0.size, 33); + for (int i = 1; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_schnorr_mp_get_public_key_compressed(key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + const buf_t pub_buf(pub0.data, pub0.size); + coinbase::crypto::ecc_point_t Q; + ASSERT_EQ(Q.from_bin(coinbase::crypto::curve_secp256k1, pub_buf), SUCCESS); + + uint8_t msg_bytes[32]; + for (int i = 0; i < 32; i++) msg_bytes[i] = static_cast(0x22 + i); + const cmem_t msg = {msg_bytes, 32}; + + // Signing quorum: {p0, p1} + const char* sign_party_names[2] = {"p0", "p1"}; + const cmem_t sign_key_blobs[2] = {key_blobs[0], key_blobs[1]}; + + { + std::vector> sign_peers; + make_peers(2, sign_peers); + + std::vector sign_ctxs; + std::vector sign_transports; + make_transports(sign_peers, sign_ctxs, sign_transports); + + std::vector sigs(2, cmem_t{nullptr, 0}); + run_mp( + sign_peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/sign_party_names, + /*party_names_count=*/2, + /*transport=*/&sign_transports[static_cast(i)], + }; + return cbmpc_schnorr_mp_sign_ac(&job, sign_key_blobs[static_cast(i)], &ac, msg, /*sig_receiver=*/0, + &sigs[static_cast(i)]); + }, + rvs); + + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_EQ(sigs[0].size, 64); + EXPECT_EQ(sigs[1].size, 0); + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, mem_t(msg_bytes, 32), mem_t(sigs[0].data, sigs[0].size)), SUCCESS); + + for (auto m : sigs) cbmpc_cmem_free(m); + } + + // Threshold refresh. + std::vector new_key_blobs(n, cmem_t{nullptr, 0}); + std::vector refresh_sids(n, cmem_t{nullptr, 0}); + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[static_cast(i)], + }; + return cbmpc_schnorr_mp_refresh_ac(&job, /*sid_in=*/cmem_t{nullptr, 0}, key_blobs[static_cast(i)], &ac, + quorum, /*quorum_party_names_count=*/2, + &refresh_sids[static_cast(i)], + &new_key_blobs[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) ASSERT_GT(new_key_blobs[static_cast(i)].size, 0); + for (int i = 1; i < n; i++) expect_eq(refresh_sids[0], refresh_sids[static_cast(i)]); + + for (int i = 0; i < n; i++) { + cmem_t pub_i{nullptr, 0}; + ASSERT_EQ(cbmpc_schnorr_mp_get_public_key_compressed(new_key_blobs[static_cast(i)], &pub_i), CBMPC_SUCCESS); + expect_eq(pub_i, pub0); + cbmpc_cmem_free(pub_i); + } + + const cmem_t sign_new_key_blobs[2] = {new_key_blobs[0], new_key_blobs[1]}; + + { + std::vector> sign_peers; + make_peers(2, sign_peers); + + std::vector sign_ctxs; + std::vector sign_transports; + make_transports(sign_peers, sign_ctxs, sign_transports); + + std::vector sigs(2, cmem_t{nullptr, 0}); + run_mp( + sign_peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/sign_party_names, + /*party_names_count=*/2, + /*transport=*/&sign_transports[static_cast(i)], + }; + return cbmpc_schnorr_mp_sign_ac(&job, sign_new_key_blobs[static_cast(i)], &ac, msg, + /*sig_receiver=*/0, &sigs[static_cast(i)]); + }, + rvs); + + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + ASSERT_EQ(sigs[0].size, 64); + EXPECT_EQ(sigs[1].size, 0); + ASSERT_EQ(coinbase::crypto::bip340::verify(Q, mem_t(msg_bytes, 32), mem_t(sigs[0].data, sigs[0].size)), SUCCESS); + + for (auto m : sigs) cbmpc_cmem_free(m); + } + + cbmpc_cmem_free(pub0); + for (auto m : refresh_sids) cbmpc_cmem_free(m); + for (auto m : new_key_blobs) cbmpc_cmem_free(m); + for (auto m : sids) cbmpc_cmem_free(m); + for (auto m : key_blobs) cbmpc_cmem_free(m); +} diff --git a/tests/unit/c_api/test_tdh2.cpp b/tests/unit/c_api/test_tdh2.cpp new file mode 100644 index 00000000..722bdee4 --- /dev/null +++ b/tests/unit/c_api/test_tdh2.cpp @@ -0,0 +1,657 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include "test_transport_harness.h" + +namespace { + +using coinbase::testutils::mpc_net_context_t; +using coinbase::testutils::capi_harness::make_transport; +using coinbase::testutils::capi_harness::run_mp; +using coinbase::testutils::capi_harness::transport_ctx_t; + +static void expect_eq(cmem_t a, cmem_t b) { + ASSERT_EQ(a.size, b.size); + if (a.size > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(a.size)), 0); + } +} + +static void expect_eq_cmems(cmems_t a, cmems_t b) { + ASSERT_EQ(a.count, b.count); + if (a.count == 0) return; + ASSERT_NE(a.sizes, nullptr); + ASSERT_NE(b.sizes, nullptr); + int total_a = 0; + int total_b = 0; + for (int i = 0; i < a.count; i++) { + ASSERT_EQ(a.sizes[i], b.sizes[i]); + total_a += a.sizes[i]; + total_b += b.sizes[i]; + } + ASSERT_EQ(total_a, total_b); + if (total_a > 0) { + ASSERT_NE(a.data, nullptr); + ASSERT_NE(b.data, nullptr); + ASSERT_EQ(std::memcmp(a.data, b.data, static_cast(total_a)), 0); + } +} + +static cmems_t pack_cmems_copy(const std::vector& mems) { + cmems_t out{0, nullptr, nullptr}; + if (mems.empty()) return out; + if (mems.size() > static_cast(INT_MAX)) return out; + + int total = 0; + for (const auto& m : mems) { + if (m.size < 0) return cmems_t{0, nullptr, nullptr}; + if (m.size > INT_MAX - total) return cmems_t{0, nullptr, nullptr}; + total += m.size; + } + + out.count = static_cast(mems.size()); + out.sizes = static_cast(cbmpc_malloc(sizeof(int) * mems.size())); + if (!out.sizes) return cmems_t{0, nullptr, nullptr}; + out.data = (total > 0) ? static_cast(cbmpc_malloc(static_cast(total))) : nullptr; + if (total > 0 && !out.data) { + cbmpc_free(out.sizes); + return cmems_t{0, nullptr, nullptr}; + } + + int offset = 0; + for (int i = 0; i < out.count; i++) { + out.sizes[i] = mems[i].size; + if (mems[i].size) { + std::memmove(out.data + offset, mems[i].data, static_cast(mems[i].size)); + offset += mems[i].size; + } + } + + return out; +} + +} // namespace + +TEST(CApiTdh2, DkgRoundTripEncryptDecrypt) { + constexpr int n = 3; + + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + transport_ctx_t ctx[n]; + cbmpc_transport_t transports[n]; + for (int i = 0; i < n; i++) { + ctx[i] = transport_ctx_t{peers[static_cast(i)], /*free_calls=*/nullptr}; + transports[i] = make_transport(&ctx[i]); + } + + const char* party_names[n] = {"p0", "p1", "p2"}; + + std::vector public_keys(n, cmem_t{nullptr, 0}); + std::vector public_shares(n, cmems_t{0, nullptr, nullptr}); + std::vector private_shares(n, cmem_t{nullptr, 0}); + std::vector sids(n, cmem_t{nullptr, 0}); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[i], + }; + return cbmpc_tdh2_dkg_additive(&job, CBMPC_CURVE_SECP256K1, &public_keys[static_cast(i)], + &public_shares[static_cast(i)], &private_shares[static_cast(i)], + &sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 0; i < n; i++) { + ASSERT_GT(public_keys[static_cast(i)].size, 0); + ASSERT_EQ(public_shares[static_cast(i)].count, n); + ASSERT_GT(private_shares[static_cast(i)].size, 0); + ASSERT_GT(sids[static_cast(i)].size, 0); + } + + for (int i = 1; i < n; i++) { + expect_eq(public_keys[0], public_keys[static_cast(i)]); + expect_eq_cmems(public_shares[0], public_shares[static_cast(i)]); + expect_eq(sids[0], sids[static_cast(i)]); + } + + uint8_t plaintext_bytes[32]; + for (int i = 0; i < 32; i++) plaintext_bytes[i] = static_cast(0xA5 ^ i); + const cmem_t plaintext = {plaintext_bytes, 32}; + const cmem_t label = {reinterpret_cast(const_cast("tdh2-label")), 9}; + + cmem_t ciphertext{nullptr, 0}; + ASSERT_EQ(cbmpc_tdh2_encrypt(public_keys[0], plaintext, label, &ciphertext), CBMPC_SUCCESS); + ASSERT_GT(ciphertext.size, 0); + ASSERT_EQ(cbmpc_tdh2_verify(public_keys[0], ciphertext, label), CBMPC_SUCCESS); + + std::vector partials(n, cmem_t{nullptr, 0}); + for (int i = 0; i < n; i++) { + ASSERT_EQ(cbmpc_tdh2_partial_decrypt(private_shares[static_cast(i)], ciphertext, label, + &partials[static_cast(i)]), + CBMPC_SUCCESS); + ASSERT_GT(partials[static_cast(i)].size, 0); + } + + const cmems_t partials_flat = pack_cmems_copy(partials); + ASSERT_EQ(partials_flat.count, n); + + cmem_t decrypted{nullptr, 0}; + ASSERT_EQ(cbmpc_tdh2_combine_additive(public_keys[0], public_shares[0], label, partials_flat, ciphertext, &decrypted), + CBMPC_SUCCESS); + ASSERT_EQ(decrypted.size, 32); + ASSERT_NE(decrypted.data, nullptr); + EXPECT_EQ(std::memcmp(decrypted.data, plaintext_bytes, 32), 0); + + // Wrong label should fail verification. + const cmem_t wrong_label = {reinterpret_cast(const_cast("wrong-label")), 11}; + EXPECT_NE(cbmpc_tdh2_verify(public_keys[0], ciphertext, wrong_label), CBMPC_SUCCESS); + + cbmpc_cmem_free(decrypted); + cbmpc_cmems_free(partials_flat); + for (auto p : partials) cbmpc_cmem_free(p); + cbmpc_cmem_free(ciphertext); + for (auto m : sids) cbmpc_cmem_free(m); + for (auto m : private_shares) cbmpc_cmem_free(m); + for (auto m : public_shares) cbmpc_cmems_free(m); + for (auto m : public_keys) cbmpc_cmem_free(m); +} + +TEST(CApiTdh2, ThresholdDkg_Combine2of3) { + constexpr int n = 3; + + std::vector> peers; + peers.reserve(n); + for (int i = 0; i < n; i++) peers.push_back(std::make_shared(i)); + for (const auto& p : peers) p->init_with_peers(peers); + + transport_ctx_t ctx[n]; + cbmpc_transport_t transports[n]; + for (int i = 0; i < n; i++) { + ctx[i] = transport_ctx_t{peers[static_cast(i)], /*free_calls=*/nullptr}; + transports[i] = make_transport(&ctx[i]); + } + + const char* party_names[n] = {"p0", "p1", "p2"}; + + // Access structure: THRESHOLD[2](p0, p1, p2) + const int32_t child_indices[] = {1, 2, 3}; + const cbmpc_access_structure_node_t nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, /*leaf_name=*/nullptr, /*k=*/2, /*off=*/0, /*cnt=*/3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p0", /*k=*/0, /*off=*/0, /*cnt=*/0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p1", /*k=*/0, /*off=*/0, /*cnt=*/0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, /*leaf_name=*/"p2", /*k=*/0, /*off=*/0, /*cnt=*/0}, + }; + const cbmpc_access_structure_t ac = { + /*nodes=*/nodes, + /*nodes_count=*/static_cast(sizeof(nodes) / sizeof(nodes[0])), + /*child_indices=*/child_indices, + /*child_indices_count=*/static_cast(sizeof(child_indices) / sizeof(child_indices[0])), + /*root_index=*/0, + }; + + const char* dkg_quorum[] = {"p0", "p1"}; + + std::vector public_keys(n, cmem_t{nullptr, 0}); + std::vector public_shares(n, cmems_t{0, nullptr, nullptr}); + std::vector private_shares(n, cmem_t{nullptr, 0}); + std::vector sids(n, cmem_t{nullptr, 0}); + std::vector rvs; + + run_mp( + peers, + [&](int i) { + const cbmpc_mp_job_t job = { + /*self=*/i, + /*party_names=*/party_names, + /*party_names_count=*/n, + /*transport=*/&transports[i], + }; + return cbmpc_tdh2_dkg_ac(&job, CBMPC_CURVE_P256, + /*sid_in=*/cmem_t{nullptr, 0}, &ac, dkg_quorum, + /*quorum_party_names_count=*/2, &public_keys[static_cast(i)], + &public_shares[static_cast(i)], &private_shares[static_cast(i)], + &sids[static_cast(i)]); + }, + rvs); + for (auto rv : rvs) ASSERT_EQ(rv, CBMPC_SUCCESS); + for (int i = 1; i < n; i++) { + expect_eq(public_keys[0], public_keys[static_cast(i)]); + expect_eq_cmems(public_shares[0], public_shares[static_cast(i)]); + expect_eq(sids[0], sids[static_cast(i)]); + } + + uint8_t plaintext_bytes[32]; + for (int i = 0; i < 32; i++) plaintext_bytes[i] = static_cast(0x5A ^ i); + const cmem_t plaintext = {plaintext_bytes, 32}; + const cmem_t label = {reinterpret_cast(const_cast("tdh2-label")), 9}; + + cmem_t ciphertext{nullptr, 0}; + ASSERT_EQ(cbmpc_tdh2_encrypt(public_keys[0], plaintext, label, &ciphertext), CBMPC_SUCCESS); + ASSERT_EQ(cbmpc_tdh2_verify(public_keys[0], ciphertext, label), CBMPC_SUCCESS); + + cmem_t partial0{nullptr, 0}; + cmem_t partial1{nullptr, 0}; + ASSERT_EQ(cbmpc_tdh2_partial_decrypt(private_shares[0], ciphertext, label, &partial0), CBMPC_SUCCESS); + ASSERT_EQ(cbmpc_tdh2_partial_decrypt(private_shares[1], ciphertext, label, &partial1), CBMPC_SUCCESS); + + const char* partial_names[] = {"p0", "p1"}; + const std::vector partial_vec = {partial0, partial1}; + const cmems_t partials_flat = pack_cmems_copy(partial_vec); + + cmem_t decrypted{nullptr, 0}; + ASSERT_EQ(cbmpc_tdh2_combine_ac(&ac, public_keys[0], party_names, n, public_shares[0], label, partial_names, 2, + partials_flat, ciphertext, &decrypted), + CBMPC_SUCCESS); + ASSERT_EQ(decrypted.size, 32); + EXPECT_EQ(std::memcmp(decrypted.data, plaintext_bytes, 32), 0); + + // Not enough partial decryptions should fail. + const char* one_name[] = {"p0"}; + const std::vector one_partial_vec = {partial0}; + const cmems_t one_partials = pack_cmems_copy(one_partial_vec); + cmem_t decrypted2{nullptr, 0}; + EXPECT_NE(cbmpc_tdh2_combine_ac(&ac, public_keys[0], party_names, n, public_shares[0], label, one_name, 1, + one_partials, ciphertext, &decrypted2), + CBMPC_SUCCESS); + + cbmpc_cmem_free(decrypted2); + cbmpc_cmems_free(one_partials); + cbmpc_cmem_free(decrypted); + cbmpc_cmems_free(partials_flat); + cbmpc_cmem_free(partial0); + cbmpc_cmem_free(partial1); + cbmpc_cmem_free(ciphertext); + for (auto m : sids) cbmpc_cmem_free(m); + for (auto m : private_shares) cbmpc_cmem_free(m); + for (auto m : public_shares) cbmpc_cmems_free(m); + for (auto m : public_keys) cbmpc_cmem_free(m); +} + +TEST(CApiTdh2, ValidatesArgs) { + cmem_t pk{reinterpret_cast(0x1), 123}; + cmem_t priv{reinterpret_cast(0x1), 123}; + cmem_t sid{reinterpret_cast(0x1), 123}; + cmems_t pub{123, reinterpret_cast(0x1), reinterpret_cast(0x1)}; + + const cbmpc_transport_t bad_transport = {/*ctx=*/nullptr, /*send=*/nullptr, /*receive=*/nullptr, + /*receive_all=*/nullptr, + /*free=*/nullptr}; + const char* names[2] = {"p0", "p1"}; + const cbmpc_mp_job_t bad_job = {/*self=*/0, /*party_names=*/names, /*party_names_count=*/2, + /*transport=*/&bad_transport}; + + EXPECT_EQ(cbmpc_tdh2_dkg_additive(&bad_job, CBMPC_CURVE_SECP256K1, &pk, &pub, &priv, &sid), E_BADARG); + EXPECT_EQ(pk.data, nullptr); + EXPECT_EQ(pk.size, 0); + EXPECT_EQ(pub.count, 0); + EXPECT_EQ(priv.data, nullptr); + EXPECT_EQ(priv.size, 0); + EXPECT_EQ(sid.data, nullptr); + EXPECT_EQ(sid.size, 0); + + EXPECT_EQ(cbmpc_tdh2_encrypt(cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, cmem_t{nullptr, 0}, nullptr), E_BADARG); +} + +// ------------ Disclaimer: All the following tests have been generated by AI ------------ + +// --------------------------------------------------------------------------- +// Negative tests +// --------------------------------------------------------------------------- + +#include + +namespace { + +const cmem_t empty_cmem{nullptr, 0}; +const cmems_t empty_cmems{0, nullptr, nullptr}; + +uint8_t g_garbage[] = {0xDE, 0xAD, 0xBE, 0xEF}; +const cmem_t garbage_cmem{g_garbage, 4}; + +const char* g_names[] = {"p0", "p1", "p2"}; +const cbmpc_mp_job_t g_bad_job = {0, g_names, 3, nullptr}; + +const int32_t g_child_indices[] = {1, 2, 3}; +const cbmpc_access_structure_node_t g_nodes[] = { + {CBMPC_ACCESS_STRUCTURE_NODE_THRESHOLD, nullptr, 2, 0, 3}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p0", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p1", 0, 0, 0}, + {CBMPC_ACCESS_STRUCTURE_NODE_LEAF, "p2", 0, 0, 0}, +}; +const cbmpc_access_structure_t g_ac = {g_nodes, 4, g_child_indices, 3, 0}; + +} // namespace + +// --- dkg_additive --- + +TEST(CApiTdh2Neg, DkgAdditiveNullOutPk) { + dylog_disable_scope_t no_log; + cmems_t pub = empty_cmems; + cmem_t priv = empty_cmem; + cmem_t sid = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_dkg_additive(&g_bad_job, CBMPC_CURVE_SECP256K1, nullptr, &pub, &priv, &sid), E_BADARG); +} + +TEST(CApiTdh2Neg, DkgAdditiveNullOutPubShares) { + dylog_disable_scope_t no_log; + cmem_t pk = empty_cmem; + cmem_t priv = empty_cmem; + cmem_t sid = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_dkg_additive(&g_bad_job, CBMPC_CURVE_SECP256K1, &pk, nullptr, &priv, &sid), E_BADARG); +} + +TEST(CApiTdh2Neg, DkgAdditiveNullOutPrivShare) { + dylog_disable_scope_t no_log; + cmem_t pk = empty_cmem; + cmems_t pub = empty_cmems; + cmem_t sid = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_dkg_additive(&g_bad_job, CBMPC_CURVE_SECP256K1, &pk, &pub, nullptr, &sid), E_BADARG); +} + +TEST(CApiTdh2Neg, DkgAdditiveNullOutSid) { + dylog_disable_scope_t no_log; + cmem_t pk = empty_cmem; + cmems_t pub = empty_cmems; + cmem_t priv = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_dkg_additive(&g_bad_job, CBMPC_CURVE_SECP256K1, &pk, &pub, &priv, nullptr), E_BADARG); +} + +TEST(CApiTdh2Neg, DkgAdditiveNullJob) { + dylog_disable_scope_t no_log; + cmem_t pk = empty_cmem; + cmems_t pub = empty_cmems; + cmem_t priv = empty_cmem; + cmem_t sid = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_dkg_additive(nullptr, CBMPC_CURVE_SECP256K1, &pk, &pub, &priv, &sid), E_BADARG); +} + +TEST(CApiTdh2Neg, DkgAdditiveInvalidCurve) { + dylog_disable_scope_t no_log; + cmem_t pk = empty_cmem; + cmems_t pub = empty_cmems; + cmem_t priv = empty_cmem; + cmem_t sid = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_dkg_additive(&g_bad_job, static_cast(0), &pk, &pub, &priv, &sid), E_BADARG); +} + +TEST(CApiTdh2Neg, DkgAdditiveEd25519Rejected) { + dylog_disable_scope_t no_log; + cmem_t pk = empty_cmem; + cmems_t pub = empty_cmems; + cmem_t priv = empty_cmem; + cmem_t sid = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_dkg_additive(&g_bad_job, CBMPC_CURVE_ED25519, &pk, &pub, &priv, &sid), E_BADARG); +} + +// --- dkg_ac --- + +TEST(CApiTdh2Neg, DkgAcNullOutPk) { + dylog_disable_scope_t no_log; + const char* quorum[] = {"p0", "p1"}; + cmems_t pub = empty_cmems; + cmem_t priv = empty_cmem; + cmem_t sid = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_dkg_ac(&g_bad_job, CBMPC_CURVE_P256, empty_cmem, &g_ac, quorum, 2, nullptr, &pub, &priv, &sid), + E_BADARG); +} + +TEST(CApiTdh2Neg, DkgAcNullOutSid) { + dylog_disable_scope_t no_log; + const char* quorum[] = {"p0", "p1"}; + cmem_t pk = empty_cmem; + cmems_t pub = empty_cmems; + cmem_t priv = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_dkg_ac(&g_bad_job, CBMPC_CURVE_P256, empty_cmem, &g_ac, quorum, 2, &pk, &pub, &priv, nullptr), + E_BADARG); +} + +TEST(CApiTdh2Neg, DkgAcNullJob) { + dylog_disable_scope_t no_log; + const char* quorum[] = {"p0", "p1"}; + cmem_t pk = empty_cmem; + cmems_t pub = empty_cmems; + cmem_t priv = empty_cmem; + cmem_t sid = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_dkg_ac(nullptr, CBMPC_CURVE_P256, empty_cmem, &g_ac, quorum, 2, &pk, &pub, &priv, &sid), + E_BADARG); +} + +TEST(CApiTdh2Neg, DkgAcInvalidCurve) { + dylog_disable_scope_t no_log; + const char* quorum[] = {"p0", "p1"}; + cmem_t pk = empty_cmem; + cmems_t pub = empty_cmems; + cmem_t priv = empty_cmem; + cmem_t sid = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_dkg_ac(&g_bad_job, static_cast(0), empty_cmem, &g_ac, quorum, 2, &pk, &pub, + &priv, &sid), + E_BADARG); +} + +TEST(CApiTdh2Neg, DkgAcNullAc) { + dylog_disable_scope_t no_log; + const char* quorum[] = {"p0", "p1"}; + cmem_t pk = empty_cmem; + cmems_t pub = empty_cmems; + cmem_t priv = empty_cmem; + cmem_t sid = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_dkg_ac(&g_bad_job, CBMPC_CURVE_P256, empty_cmem, nullptr, quorum, 2, &pk, &pub, &priv, &sid), + E_BADARG); +} + +// --- encrypt --- + +TEST(CApiTdh2Neg, EncryptNullOutput) { + dylog_disable_scope_t no_log; + EXPECT_EQ(cbmpc_tdh2_encrypt(garbage_cmem, garbage_cmem, garbage_cmem, nullptr), E_BADARG); +} + +TEST(CApiTdh2Neg, EncryptEmptyPk) { + dylog_disable_scope_t no_log; + cmem_t ct = empty_cmem; + EXPECT_NE(cbmpc_tdh2_encrypt(empty_cmem, garbage_cmem, garbage_cmem, &ct), CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, EncryptGarbagePk) { + dylog_disable_scope_t no_log; + cmem_t ct = empty_cmem; + EXPECT_NE(cbmpc_tdh2_encrypt(garbage_cmem, garbage_cmem, garbage_cmem, &ct), CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, EncryptEmptyPlaintext) { + dylog_disable_scope_t no_log; + cmem_t ct = empty_cmem; + EXPECT_NE(cbmpc_tdh2_encrypt(garbage_cmem, empty_cmem, garbage_cmem, &ct), CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, EncryptEmptyLabel) { + dylog_disable_scope_t no_log; + cmem_t ct = empty_cmem; + EXPECT_NE(cbmpc_tdh2_encrypt(garbage_cmem, garbage_cmem, empty_cmem, &ct), CBMPC_SUCCESS); +} + +// --- verify --- + +TEST(CApiTdh2Neg, VerifyEmptyPk) { + dylog_disable_scope_t no_log; + EXPECT_NE(cbmpc_tdh2_verify(empty_cmem, garbage_cmem, garbage_cmem), CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, VerifyGarbagePk) { + dylog_disable_scope_t no_log; + EXPECT_NE(cbmpc_tdh2_verify(garbage_cmem, garbage_cmem, garbage_cmem), CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, VerifyEmptyCt) { + dylog_disable_scope_t no_log; + EXPECT_NE(cbmpc_tdh2_verify(garbage_cmem, empty_cmem, garbage_cmem), CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, VerifyGarbageCt) { + dylog_disable_scope_t no_log; + EXPECT_NE(cbmpc_tdh2_verify(garbage_cmem, garbage_cmem, garbage_cmem), CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, VerifyEmptyLabel) { + dylog_disable_scope_t no_log; + EXPECT_NE(cbmpc_tdh2_verify(garbage_cmem, garbage_cmem, empty_cmem), CBMPC_SUCCESS); +} + +// --- partial_decrypt --- + +TEST(CApiTdh2Neg, PartialDecryptNullOutput) { + dylog_disable_scope_t no_log; + EXPECT_EQ(cbmpc_tdh2_partial_decrypt(garbage_cmem, garbage_cmem, garbage_cmem, nullptr), E_BADARG); +} + +TEST(CApiTdh2Neg, PartialDecryptEmptyPrivShare) { + dylog_disable_scope_t no_log; + cmem_t out = empty_cmem; + EXPECT_NE(cbmpc_tdh2_partial_decrypt(empty_cmem, garbage_cmem, garbage_cmem, &out), CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, PartialDecryptGarbagePrivShare) { + dylog_disable_scope_t no_log; + cmem_t out = empty_cmem; + EXPECT_NE(cbmpc_tdh2_partial_decrypt(garbage_cmem, garbage_cmem, garbage_cmem, &out), CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, PartialDecryptEmptyCt) { + dylog_disable_scope_t no_log; + cmem_t out = empty_cmem; + EXPECT_NE(cbmpc_tdh2_partial_decrypt(garbage_cmem, empty_cmem, garbage_cmem, &out), CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, PartialDecryptGarbageCt) { + dylog_disable_scope_t no_log; + cmem_t out = empty_cmem; + EXPECT_NE(cbmpc_tdh2_partial_decrypt(garbage_cmem, garbage_cmem, garbage_cmem, &out), CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, PartialDecryptEmptyLabel) { + dylog_disable_scope_t no_log; + cmem_t out = empty_cmem; + EXPECT_NE(cbmpc_tdh2_partial_decrypt(garbage_cmem, garbage_cmem, empty_cmem, &out), CBMPC_SUCCESS); +} + +// --- combine_additive --- + +TEST(CApiTdh2Neg, CombineAdditiveNullOutput) { + dylog_disable_scope_t no_log; + EXPECT_EQ(cbmpc_tdh2_combine_additive(garbage_cmem, empty_cmems, garbage_cmem, empty_cmems, garbage_cmem, nullptr), + E_BADARG); +} + +TEST(CApiTdh2Neg, CombineAdditiveEmptyPk) { + dylog_disable_scope_t no_log; + cmem_t out = empty_cmem; + EXPECT_NE(cbmpc_tdh2_combine_additive(empty_cmem, empty_cmems, garbage_cmem, empty_cmems, garbage_cmem, &out), + CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, CombineAdditiveGarbagePk) { + dylog_disable_scope_t no_log; + cmem_t out = empty_cmem; + EXPECT_NE(cbmpc_tdh2_combine_additive(garbage_cmem, empty_cmems, garbage_cmem, empty_cmems, garbage_cmem, &out), + CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, CombineAdditiveEmptyCt) { + dylog_disable_scope_t no_log; + cmem_t out = empty_cmem; + EXPECT_NE(cbmpc_tdh2_combine_additive(garbage_cmem, empty_cmems, garbage_cmem, empty_cmems, empty_cmem, &out), + CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, CombineAdditiveEmptyLabel) { + dylog_disable_scope_t no_log; + cmem_t out = empty_cmem; + EXPECT_NE(cbmpc_tdh2_combine_additive(garbage_cmem, empty_cmems, empty_cmem, empty_cmems, garbage_cmem, &out), + CBMPC_SUCCESS); +} + +// --- combine_ac --- + +TEST(CApiTdh2Neg, CombineAcNullOutput) { + dylog_disable_scope_t no_log; + const char* party_names[] = {"p0", "p1", "p2"}; + const char* partial_names[] = {"p0", "p1"}; + EXPECT_EQ(cbmpc_tdh2_combine_ac(&g_ac, garbage_cmem, party_names, 3, empty_cmems, garbage_cmem, partial_names, 2, + empty_cmems, garbage_cmem, nullptr), + E_BADARG); +} + +TEST(CApiTdh2Neg, CombineAcNullAc) { + dylog_disable_scope_t no_log; + const char* party_names[] = {"p0", "p1", "p2"}; + const char* partial_names[] = {"p0", "p1"}; + cmem_t out = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_combine_ac(nullptr, garbage_cmem, party_names, 3, empty_cmems, garbage_cmem, partial_names, 2, + empty_cmems, garbage_cmem, &out), + E_BADARG); +} + +TEST(CApiTdh2Neg, CombineAcNullPartyNames) { + dylog_disable_scope_t no_log; + const char* partial_names[] = {"p0", "p1"}; + cmem_t out = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_combine_ac(&g_ac, garbage_cmem, nullptr, 3, empty_cmems, garbage_cmem, partial_names, 2, + empty_cmems, garbage_cmem, &out), + E_BADARG); +} + +TEST(CApiTdh2Neg, CombineAcNullPartialNames) { + dylog_disable_scope_t no_log; + const char* party_names[] = {"p0", "p1", "p2"}; + cmem_t out = empty_cmem; + EXPECT_EQ(cbmpc_tdh2_combine_ac(&g_ac, garbage_cmem, party_names, 3, empty_cmems, garbage_cmem, nullptr, 2, + empty_cmems, garbage_cmem, &out), + E_BADARG); +} + +TEST(CApiTdh2Neg, CombineAcEmptyPk) { + dylog_disable_scope_t no_log; + const char* party_names[] = {"p0", "p1", "p2"}; + const char* partial_names[] = {"p0", "p1"}; + cmem_t out = empty_cmem; + EXPECT_NE(cbmpc_tdh2_combine_ac(&g_ac, empty_cmem, party_names, 3, empty_cmems, garbage_cmem, partial_names, 2, + empty_cmems, garbage_cmem, &out), + CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, CombineAcEmptyCt) { + dylog_disable_scope_t no_log; + const char* party_names[] = {"p0", "p1", "p2"}; + const char* partial_names[] = {"p0", "p1"}; + cmem_t out = empty_cmem; + EXPECT_NE(cbmpc_tdh2_combine_ac(&g_ac, garbage_cmem, party_names, 3, empty_cmems, garbage_cmem, partial_names, 2, + empty_cmems, empty_cmem, &out), + CBMPC_SUCCESS); +} + +TEST(CApiTdh2Neg, CombineAcEmptyLabel) { + dylog_disable_scope_t no_log; + const char* party_names[] = {"p0", "p1", "p2"}; + const char* partial_names[] = {"p0", "p1"}; + cmem_t out = empty_cmem; + EXPECT_NE(cbmpc_tdh2_combine_ac(&g_ac, garbage_cmem, party_names, 3, empty_cmems, empty_cmem, partial_names, 2, + empty_cmems, garbage_cmem, &out), + CBMPC_SUCCESS); +} diff --git a/tests/unit/c_api/test_transport_harness.h b/tests/unit/c_api/test_transport_harness.h new file mode 100644 index 00000000..c0049936 --- /dev/null +++ b/tests/unit/c_api/test_transport_harness.h @@ -0,0 +1,179 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "utils/local_network/network_context.h" + +namespace coinbase::testutils::capi_harness { + +using coinbase::buf_t; +using coinbase::error_t; +using coinbase::mem_t; +using coinbase::api::party_idx_t; +using coinbase::testutils::mpc_net_context_t; + +struct transport_ctx_t { + std::shared_ptr net; + std::atomic* free_calls = nullptr; +}; + +inline cbmpc_error_t transport_send(void* ctx, int32_t receiver, const uint8_t* data, int size) { + if (!ctx) return E_BADARG; + if (size < 0) return E_BADARG; + if (size > 0 && !data) return E_BADARG; + auto* c = static_cast(ctx); + c->net->send(static_cast(receiver), mem_t(data, size)); + return CBMPC_SUCCESS; +} + +inline cbmpc_error_t transport_receive(void* ctx, int32_t sender, cmem_t* out_msg) { + if (!out_msg) return E_BADARG; + *out_msg = cmem_t{nullptr, 0}; + if (!ctx) return E_BADARG; + + auto* c = static_cast(ctx); + buf_t msg; + const error_t rv = c->net->receive(static_cast(sender), msg); + if (rv) return rv; + + const int n = msg.size(); + if (n < 0) return E_FORMAT; + if (n == 0) return CBMPC_SUCCESS; + + out_msg->data = static_cast(cbmpc_malloc(static_cast(n))); + if (!out_msg->data) return E_INSUFFICIENT; + out_msg->size = n; + std::memmove(out_msg->data, msg.data(), static_cast(n)); + return CBMPC_SUCCESS; +} + +inline cbmpc_error_t transport_receive_all(void* ctx, const int32_t* senders, int senders_count, cmems_t* out_msgs) { + if (!out_msgs) return E_BADARG; + *out_msgs = cmems_t{0, nullptr, nullptr}; + if (!ctx) return E_BADARG; + if (senders_count < 0) return E_BADARG; + if (senders_count > 0 && !senders) return E_BADARG; + + auto* c = static_cast(ctx); + std::vector s; + s.reserve(static_cast(senders_count)); + for (int i = 0; i < senders_count; i++) s.push_back(static_cast(senders[i])); + + std::vector msgs; + const error_t rv = c->net->receive_all(s, msgs); + if (rv) return rv; + if (msgs.size() != static_cast(senders_count)) return E_GENERAL; + + // Flatten into (data + sizes) buffers. + int total = 0; + for (const auto& m : msgs) { + const int sz = m.size(); + if (sz < 0) return E_FORMAT; + if (sz > INT_MAX - total) return E_RANGE; + total += sz; + } + + out_msgs->count = senders_count; + out_msgs->sizes = static_cast(cbmpc_malloc(sizeof(int) * static_cast(senders_count))); + if (!out_msgs->sizes) { + *out_msgs = cmems_t{0, nullptr, nullptr}; + return E_INSUFFICIENT; + } + + if (total > 0) { + out_msgs->data = static_cast(cbmpc_malloc(static_cast(total))); + if (!out_msgs->data) { + cbmpc_free(out_msgs->sizes); + *out_msgs = cmems_t{0, nullptr, nullptr}; + return E_INSUFFICIENT; + } + } + + int offset = 0; + for (int i = 0; i < senders_count; i++) { + const int sz = msgs[i].size(); + out_msgs->sizes[i] = sz; + if (sz) { + std::memmove(out_msgs->data + offset, msgs[i].data(), static_cast(sz)); + offset += sz; + } + } + + return CBMPC_SUCCESS; +} + +inline void transport_free(void* ctx, void* ptr) { + if (!ptr) return; + auto* c = static_cast(ctx); + if (c && c->free_calls) c->free_calls->fetch_add(1); + cbmpc_free(ptr); +} + +inline cbmpc_transport_t make_transport(transport_ctx_t* ctx) { + return cbmpc_transport_t{ + /*ctx=*/ctx, + /*send=*/transport_send, + /*receive=*/transport_receive, + /*receive_all=*/transport_receive_all, + /*free=*/transport_free, + }; +} + +template +inline void run_2pc(const std::shared_ptr& c1, const std::shared_ptr& c2, F1&& f1, + F2&& f2, cbmpc_error_t& out_rv1, cbmpc_error_t& out_rv2) { + c1->reset(); + c2->reset(); + + std::atomic aborted{false}; + + std::thread t1([&] { + out_rv1 = f1(); + if (out_rv1 && !aborted.exchange(true)) { + c1->abort(); + c2->abort(); + } + }); + std::thread t2([&] { + out_rv2 = f2(); + if (out_rv2 && !aborted.exchange(true)) { + c1->abort(); + c2->abort(); + } + }); + + t1.join(); + t2.join(); +} + +template +inline void run_mp(const std::vector>& peers, F&& f, + std::vector& out_rv) { + for (const auto& p : peers) p->reset(); + + out_rv.assign(peers.size(), UNINITIALIZED_ERROR); + std::atomic aborted{false}; + std::vector threads; + threads.reserve(peers.size()); + + for (size_t i = 0; i < peers.size(); i++) { + threads.emplace_back([&, i] { + out_rv[i] = f(static_cast(i)); + if (out_rv[i] && !aborted.exchange(true)) { + for (const auto& p : peers) p->abort(); + } + }); + } + for (auto& t : threads) t.join(); +} + +} // namespace coinbase::testutils::capi_harness diff --git a/tests/unit/core/test_buf.cpp b/tests/unit/core/test_buf.cpp index 4b6e0f2d..363ac1fc 100644 --- a/tests/unit/core/test_buf.cpp +++ b/tests/unit/core/test_buf.cpp @@ -1,5 +1,6 @@ #include +#include #include #include @@ -36,6 +37,34 @@ TEST(Buf, ConstructFromMem) { EXPECT_EQ(buf.to_string(), test_str); } +TEST(Mem, NegativeSizeToStringIsSafe) { + const coinbase::mem_t mem(nullptr, -1); + EXPECT_EQ(mem.to_string(), ""); +} + +TEST(Mem, NegativeSizeStreamIsSafe) { + const coinbase::mem_t mem(nullptr, -1); + std::ostringstream oss; + oss << mem; +#ifdef _DEBUG + EXPECT_EQ(oss.str(), ""); +#else + EXPECT_EQ(oss.str(), ""); +#endif +} + +TEST(Mem, NullDataPositiveSizeIsSafe) { + const coinbase::mem_t mem(nullptr, 5); + EXPECT_EQ(mem.to_string(), ""); + std::ostringstream oss; + oss << mem; +#ifdef _DEBUG + EXPECT_EQ(oss.str(), ""); +#else + EXPECT_EQ(oss.str(), ""); +#endif +} + TEST(Buf, CopyConstructor) { coinbase::buf_t original(5); for (int i = 0; i < 5; ++i) { @@ -88,7 +117,8 @@ TEST(Buf, Resize) { for (int i = 0; i < 5; ++i) { EXPECT_EQ(buf[i], static_cast(i)); } - // The remaining bytes might be uninitialized, but ensure no crash occurs. + // Note: `resize()` preserves previous bytes but does not guarantee initialization + // of newly-grown regions. } TEST(Buf, PlusOperator) { diff --git a/tests/unit/core/test_buf128.cpp b/tests/unit/core/test_buf128.cpp index a6229626..a2ca7085 100644 --- a/tests/unit/core/test_buf128.cpp +++ b/tests/unit/core/test_buf128.cpp @@ -107,6 +107,13 @@ TEST(Buf128, BitwiseOperations) { } TEST(Buf128, Shifts) { + // Shift by 0 should be a no-op. + { + auto z = buf128_t::make(0x0123456789ABCDEFULL, 0xFEDCBA9876543210ULL); + EXPECT_EQ((z << 0), z); + EXPECT_EQ((z >> 0), z); + } + // left shift auto b = buf128_t::make(0x00000000000000FFULL, 0ULL); b = b << 8; diff --git a/tests/unit/core/test_buf256.cpp b/tests/unit/core/test_buf256.cpp index fc539ca0..ed41ed02 100644 --- a/tests/unit/core/test_buf256.cpp +++ b/tests/unit/core/test_buf256.cpp @@ -116,6 +116,14 @@ TEST(Buf256Test, BitwiseOperations) { } TEST(Buf256Test, Shifts) { + // Shift by 0 should be a no-op. + { + auto z = buf256_t::make(buf128_t::make(0x0123456789ABCDEFULL, 0x0011223344556677ULL), + buf128_t::make(0x8899AABBCCDDEEFFULL, 0xFEDCBA9876543210ULL)); + EXPECT_EQ((z << 0), z); + EXPECT_EQ((z >> 0), z); + } + // left shift auto lo = buf128_t::make(0x00000000000000FFULL, 0ULL); auto hi = buf128_t::make(0ULL, 0ULL); diff --git a/tests/unit/core/test_convert.cpp b/tests/unit/core/test_convert.cpp index e1ac0e3d..792f72a7 100644 --- a/tests/unit/core/test_convert.cpp +++ b/tests/unit/core/test_convert.cpp @@ -2,8 +2,8 @@ #include #include -#include -#include +#include +#include #include "utils/test_macros.h" @@ -122,4 +122,37 @@ TEST(CoreConvert, CustomStruct) { EXPECT_EQ(out.s, ""); } -} // namespace \ No newline at end of file +TEST(CoreConvert, ConvertLenRejectsOversizedLengths) { + // Encode a 4-byte length prefix > converter_t::MAX_CONVERT_LEN (64 MiB). + // len = 0x04000001 -> bytes: 0xE4 0x00 0x00 0x01 + byte_t bin[] = {0xE4, 0x00, 0x00, 0x01}; + converter_t converter(mem_t(bin, sizeof(bin))); + uint32_t len = 0; + converter.convert_len(len); + EXPECT_NE(converter.get_rv(), SUCCESS); + EXPECT_EQ(len, 0u); +} + +TEST(CoreConvert, ConvertLenAllowsMaxValue) { + // len = 0x04000000 (64 MiB) -> bytes: 0xE4 0x00 0x00 0x00 + byte_t bin[] = {0xE4, 0x00, 0x00, 0x00}; + converter_t converter(mem_t(bin, sizeof(bin))); + uint32_t len = 0; + converter.convert_len(len); + EXPECT_EQ(converter.get_rv(), SUCCESS); + EXPECT_EQ(len, converter_t::MAX_CONVERT_LEN); +} + +TEST(CoreConvert, ConvertLastRejectsNegRemainingSize) { + byte_t bin[] = {0x42}; + converter_t converter(mem_t(bin, sizeof(bin))); + // Simulate parser-state corruption / misuse: offset moved past the source size. + converter.forward(2); + + buf_t out; + out.convert_last(converter); + EXPECT_NE(converter.get_rv(), SUCCESS); + EXPECT_EQ(out.size(), 0); +} + +} // namespace diff --git a/tests/unit/core/test_error.cpp b/tests/unit/core/test_error.cpp index 5db4880a..548616dc 100644 --- a/tests/unit/core/test_error.cpp +++ b/tests/unit/core/test_error.cpp @@ -7,6 +7,8 @@ namespace { +using coinbase::error_t; + error_t inner_func() { return coinbase::error(E_BADARG, "inner error msg"); } error_t outer_func() { diff --git a/tests/unit/core/test_util.cpp b/tests/unit/core/test_util.cpp index 84a175fa..04d0cff6 100644 --- a/tests/unit/core/test_util.cpp +++ b/tests/unit/core/test_util.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include using namespace coinbase; @@ -75,10 +75,12 @@ TEST(CoreUtils, LookupInMap) { auto [found1, value1] = lookup(sampleMap, 2); EXPECT_TRUE(found1); - EXPECT_EQ(value1, "two"); + ASSERT_NE(value1, nullptr); + EXPECT_EQ(*value1, "two"); auto [found2, value2] = lookup(sampleMap, 99); EXPECT_FALSE(found2); + EXPECT_EQ(value2, nullptr); } // Test has in container diff --git a/tests/unit/crypto/test_base.cpp b/tests/unit/crypto/test_base.cpp index 37c30417..d2d37f73 100644 --- a/tests/unit/crypto/test_base.cpp +++ b/tests/unit/crypto/test_base.cpp @@ -1,6 +1,6 @@ #include -#include +#include #include "utils/test_macros.h" @@ -59,6 +59,18 @@ TEST(BaseTest, TestGenRandomHelpers) { SUCCEED() << "Generated a random int: " << r_int; } +TEST(BaseTest, BitsSelfAppendIsSafe) { + // Ensure `x += x` works correctly and does not rely on dangling views during resize(). + // Use a byte-aligned bit count to exercise the fast-path in bits_t::operator+=. + bits_t x = gen_random_bits(128); + bits_t expected = x + x; + + bits_t y = x; + y += y; + // `bits_t::equ` is intentionally not part of the API; compare the binary representation instead. + EXPECT_TRUE(mem_t(expected) == mem_t(y)); +} + TEST(BaseTest, TestAES_CTR) { buf_t key = bn_t(0x00).to_bin(16); buf_t iv = bn_t(0x01).to_bin(16); diff --git a/tests/unit/crypto/test_base_bn.cpp b/tests/unit/crypto/test_base_bn.cpp index d05e0257..e1c8bbe8 100644 --- a/tests/unit/crypto/test_base_bn.cpp +++ b/tests/unit/crypto/test_base_bn.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include #include "utils/test_macros.h" @@ -28,6 +28,38 @@ TEST(BigNumber, Subtraction) { EXPECT_EQ(bn_t(999) - bn_t(0), 999); } +TEST(BigNumber, IntOperatorsHandleIntMin) { + const int v = INT_MIN; + + bn_t abs_v; + abs_v.set_int64(2147483648LL); + + // Non-modular operators + EXPECT_EQ(bn_t(0) + v, bn_t(v)); + EXPECT_EQ(bn_t(1) * v, bn_t(v)); + + EXPECT_EQ(bn_t(0) - v, abs_v); + EXPECT_EQ(bn_t(0) * v, 0); + + bn_t x = 0; + EXPECT_NO_THROW(x += v); + EXPECT_EQ(x, bn_t(v)); + + bn_t y = 0; + EXPECT_NO_THROW(y -= v); + EXPECT_EQ(y, abs_v); + + bn_t z = 1; + EXPECT_NO_THROW(z *= v); + EXPECT_EQ(z, bn_t(v)); + + // Modular path uses `mod_t::mod(int)` internally; ensure INT_MIN does not trigger UB. + const mod_t& q = crypto::curve_ed25519.order(); + bn_t expected_mod; + MODULO(q) { expected_mod = -abs_v; } + MODULO(q) { EXPECT_EQ(bn_t(0) + v, expected_mod); } +} + TEST(BigNumber, Multiplication) { EXPECT_EQ(bn_t(123) * bn_t(456), 56088); EXPECT_EQ(bn_t(-123) * bn_t(456), -56088); @@ -89,6 +121,18 @@ TEST(BigNumber, ShiftOperators) { bn_t val3 = val2 >> 2; EXPECT_EQ(val3, 10); + + // Negative shifts are treated as no-ops. + bn_t neg1(32); + neg1 <<= -5; + EXPECT_EQ(neg1, 32); + + bn_t neg2(1); + neg2 >>= -10; + EXPECT_EQ(neg2, 1); + + EXPECT_EQ(bn_t(5) << -3, 5); + EXPECT_EQ(bn_t(10) >> -2, 10); } TEST(BigNumber, BitwiseSetAndCheck) { @@ -130,6 +174,46 @@ TEST(BigNumber, RangeCheck) { EXPECT_ER_MSG(check_open_range(bn_t(3), bn_t(5), bn_t(5)), "check_open_range failed"); } +TEST(BigNumber, FromStringValidInput) { + bn_t result; + EXPECT_OK(bn_t::from_string("0", result)); + EXPECT_EQ(result, 0); + EXPECT_OK(bn_t::from_string("12345", result)); + EXPECT_EQ(result, 12345); + EXPECT_OK(bn_t::from_string("-42", result)); + EXPECT_EQ(result, -42); +} + +TEST(BigNumber, FromStringRejectsInvalidInput) { + bn_t result; + EXPECT_ER(bn_t::from_string("", result)); + EXPECT_ER(bn_t::from_string("not_a_number", result)); + EXPECT_ER(bn_t::from_string("1234ncc", result)); + EXPECT_ER(bn_t::from_string("0xAB", result)); + EXPECT_ER(bn_t::from_string(nullptr, result)); +} + +TEST(BigNumber, FromHexValidInput) { + bn_t result; + EXPECT_OK(bn_t::from_hex("0", result)); + EXPECT_EQ(result, 0); + EXPECT_OK(bn_t::from_hex("FF", result)); + EXPECT_EQ(result, 255); + EXPECT_OK(bn_t::from_hex("-1A", result)); + EXPECT_EQ(result, -26); + EXPECT_OK(bn_t::from_hex("abc", result)); + EXPECT_EQ(result, 0xabc); +} + +TEST(BigNumber, FromHexRejectsInvalidInput) { + bn_t result; + EXPECT_ER(bn_t::from_hex("", result)); + EXPECT_ER(bn_t::from_hex("0xAB", result)); + EXPECT_ER(bn_t::from_hex("ZZZZ", result)); + EXPECT_ER(bn_t::from_hex("1234ncc", result)); + EXPECT_ER(bn_t::from_hex(nullptr, result)); +} + TEST(BigNumber, CompareMatchesOpenSSL) { auto expect_cmp = [](const bn_t& a, const bn_t& b) { int ct = bn_t::compare(a, b); @@ -194,4 +278,36 @@ TEST(BigNumber, GetBinSize) { EXPECT_EQ(a, 1); EXPECT_EQ(a.get_bin_size(), 1); } + +TEST(BigNumber, ConvertDeserializeOversizedNoThrow) { + const uint32_t value_size = bn_t::MAX_SERIALIZED_BIGNUM_BYTES + 1; + const uint32_t header = value_size << 1; // neg=0 + + int header_size = 0; + { + converter_t sizer(true); + uint32_t tmp = header; + sizer.convert_len(tmp); + header_size = sizer.get_offset(); + } + + ASSERT_GT(header_size, 0); + ASSERT_LE(value_size, static_cast(INT_MAX - header_size)); + const int value_size_int = static_cast(value_size); + + buf_t buffer(header_size + value_size_int); + { + converter_t writer(buffer.data()); + uint32_t tmp = header; + writer.convert_len(tmp); + memset(writer.current(), 0x7F, static_cast(value_size_int)); + writer.forward(value_size_int); + } + + bn_t result; + converter_t reader{mem_t(buffer)}; + EXPECT_NO_THROW(result.convert(reader)); + EXPECT_TRUE(reader.is_error()); + EXPECT_EQ(reader.get_offset(), header_size); +} } // namespace diff --git a/tests/unit/crypto/test_base_ecc.cpp b/tests/unit/crypto/test_base_ecc.cpp index 9bad17ec..c0d9bf3a 100644 --- a/tests/unit/crypto/test_base_ecc.cpp +++ b/tests/unit/crypto/test_base_ecc.cpp @@ -1,12 +1,13 @@ #include -#include -#include +#include +#include #include "utils/test_macros.h" namespace { using namespace coinbase::crypto; +using coinbase::buf_t; class ECC : public ::testing::Test { protected: @@ -21,8 +22,8 @@ class ECC : public ::testing::Test { TEST_F(ECC, secp256k1) { ecurve_t curve = curve_secp256k1; - const mod_t &q = curve.order(); - const auto &G = curve.generator(); + const mod_t& q = curve.order(); + const auto& G = curve.generator(); EXPECT_TRUE(G.is_on_curve()); ecc_point_t GG = G; @@ -76,7 +77,7 @@ TEST_F(ECC, SigningScheme2) { std::cout << "======================================== len: " << len << std::endl; for (int i = 0; i < 5; i++) { ecurve_t curve = curve_ed25519; - const mod_t &q = curve.order(); + const mod_t& q = curve.order(); ecc_prv_key_t prv_key; prv_key.generate(curve); diff --git a/tests/unit/crypto/test_base_hash.cpp b/tests/unit/crypto/test_base_hash.cpp index f4526bbf..796febbb 100644 --- a/tests/unit/crypto/test_base_hash.cpp +++ b/tests/unit/crypto/test_base_hash.cpp @@ -1,14 +1,14 @@ #include #include -#include +#include namespace { using namespace coinbase; using namespace coinbase::crypto; -TEST(BaseHash, MemTVectorEncodesBoundariesAndLength) { +TEST(BaseHash, MemVecEncodesBoundsAndLen) { const std::vector msgs_a = {mem_t("a"), mem_t("bc")}; // concat: "abc" const std::vector msgs_b = {mem_t("ab"), mem_t("c")}; // concat: "abc" const std::vector msgs_c = {mem_t("abc")}; // concat: "abc" diff --git a/tests/unit/crypto/test_base_mod.cpp b/tests/unit/crypto/test_base_mod.cpp index 4454ce56..c9a532df 100644 --- a/tests/unit/crypto/test_base_mod.cpp +++ b/tests/unit/crypto/test_base_mod.cpp @@ -1,9 +1,9 @@ #include #include -#include -#include -#include +#include +#include +#include #include "utils/test_macros.h" @@ -52,8 +52,8 @@ TEST(Mod, Add) { EXPECT_EQ(c, 13); #ifdef _DEBUG - EXPECT_DEATH(q.add(overflow_a, b), "out of range for constant-time operations"); - EXPECT_DEATH(q.add(a, overflow_b), "out of range for constant-time operations"); + EXPECT_CB_ASSERT(q.add(overflow_a, b), "out of range for constant-time operations"); + EXPECT_CB_ASSERT(q.add(a, overflow_b), "out of range for constant-time operations"); #endif { diff --git a/tests/unit/crypto/test_base_pki.cpp b/tests/unit/crypto/test_base_pki.cpp index 38297355..05b8dad6 100644 --- a/tests/unit/crypto/test_base_pki.cpp +++ b/tests/unit/crypto/test_base_pki.cpp @@ -1,43 +1,14 @@ #include -#include -#include +#include #include "utils/test_macros.h" -extern "C" { -// Override weak symbols for FFI KEM to provide simple deterministic stubs -static int test_kem_encap(cmem_t /*ek_bytes*/, cmem_t rho, cmem_t* kem_ct_out, cmem_t* kem_ss_out) { - buf_t ss = coinbase::ffi::view(rho).take(32); - buf_t ct = ss; // trivial ct for stub - *kem_ct_out = coinbase::ffi::copy_to_cmem(ct); - *kem_ss_out = coinbase::ffi::copy_to_cmem(ss); - return 0; -} - -static int test_kem_decap(const void* /*dk_handle*/, cmem_t kem_ct, cmem_t* kem_ss_out) { - *kem_ss_out = coinbase::ffi::copy_to_cmem(coinbase::ffi::view(kem_ct)); - return 0; -} - -static int test_kem_dk_to_ek(const void* dk_handle, cmem_t* out_ek) { - if (dk_handle) { - const cmem_t* cm = static_cast(dk_handle); - *out_ek = coinbase::ffi::copy_to_cmem(coinbase::ffi::view(*cm)); - } else { - *out_ek = cmem_t{nullptr, 0}; - } - return 0; -} - -ffi_kem_encap_fn get_ffi_kem_encap_fn(void) { return test_kem_encap; } -ffi_kem_decap_fn get_ffi_kem_decap_fn(void) { return test_kem_decap; } -ffi_kem_dk_to_ek_fn get_ffi_kem_dk_to_ek_fn(void) { return test_kem_dk_to_ek; } -} - namespace { using namespace coinbase::crypto; +using coinbase::buf_t; +using coinbase::mem_t; class PKI : public ::testing::Test { protected: @@ -93,59 +64,17 @@ TEST_F(PKI, ECDH_P256_KEM_EncapDecap_HPKE) { EXPECT_EQ(ss1, ss2); } -TEST_F(PKI, HybrideRSAEncryptDecrypt) { - prv_key_t prv_key = prv_key_t::from(rsa_prv_key); - pub_key_t pub_key = pub_key_t::from(rsa_pub_key); - +TEST_F(PKI, RSA_KEM_AEAD_EncryptDecrypt) { drbg_aes_ctr_t drbg(gen_random(32)); - ciphertext_t ciphertext; - ciphertext.encrypt(pub_key, label, plaintext, &drbg); - EXPECT_EQ(ciphertext.key_type, key_type_e::RSA); - - { - buf_t decrypted; - EXPECT_OK(ciphertext.decrypt(prv_key, label, decrypted)); - EXPECT_EQ(decrypted, plaintext); - } - { - buf_t decrypted; - EXPECT_OK(ciphertext.decrypt(prv_key, label, decrypted)); - EXPECT_EQ(decrypted, plaintext); - } - - { - buf_t decrypted; - EXPECT_OK(ciphertext.decrypt(prv_key, label, decrypted)); - EXPECT_EQ(decrypted, plaintext); - } - { - buf_t decrypted; - EXPECT_OK(ciphertext.decrypt(prv_key, label, decrypted)); - EXPECT_EQ(decrypted, plaintext); - } -} - -TEST_F(PKI, POINT_CONVERSION_HYBRID) { - prv_key_t prv_key = prv_key_t::from(ecc_prv_key); - pub_key_t pub_key = pub_key_t::from(ecc_pub_key); - - drbg_aes_ctr_t drbg(gen_random(32)); - - ciphertext_t ciphertext; - ciphertext.encrypt(pub_key, label, plaintext, &drbg); - EXPECT_EQ(ciphertext.key_type, key_type_e::ECC); + rsa_pke_t::ct_t c1, c2; + EXPECT_OK(c1.encrypt(rsa_pub_key, label, plaintext, &drbg)); + EXPECT_OK(c2.encrypt(rsa_pub_key, label, plaintext, &drbg)); + EXPECT_NE(coinbase::convert(c1), coinbase::convert(c2)); - { - buf_t decrypted; - EXPECT_OK(ciphertext.decrypt(prv_key, label, decrypted)); - EXPECT_EQ(decrypted, plaintext); - } - { - buf_t decrypted; - EXPECT_OK(ciphertext.decrypt(prv_key, label, decrypted)); - EXPECT_EQ(decrypted, plaintext); - } + buf_t decrypted; + EXPECT_OK(c1.decrypt(rsa_prv_key, label, decrypted)); + EXPECT_EQ(decrypted, plaintext); } // ----------------------------------------------------------------------------- @@ -185,25 +114,4 @@ TEST(HPKE_KEM_P256, DeterministicVector) { SUCCEED(); } -TEST(FFI_KEM, EncryptDecrypt) { - ffi_kem_ek_t ek; - ek = buf_t("dummy-ek"); - - cmem_t dk_bytes{reinterpret_cast(const_cast("dummy-dk")), 8}; - ffi_kem_dk_t dk; - dk.handle = static_cast(&dk_bytes); - - buf_t label = buf_t("label"); - buf_t plaintext = buf_t("plaintext for FFI KEM"); - - drbg_aes_ctr_t drbg(gen_random(32)); - - kem_aead_ciphertext_t ct; - EXPECT_OK(ct.encrypt(ek, label, plaintext, &drbg)); - - buf_t decrypted; - EXPECT_OK(ct.decrypt(dk, label, decrypted)); - EXPECT_EQ(decrypted, plaintext); -} - } // namespace \ No newline at end of file diff --git a/tests/unit/crypto/test_base_rsa.cpp b/tests/unit/crypto/test_base_rsa.cpp index 64e379b5..99ad1858 100644 --- a/tests/unit/crypto/test_base_rsa.cpp +++ b/tests/unit/crypto/test_base_rsa.cpp @@ -1,13 +1,14 @@ #include -#include -#include -#include +#include +#include +#include #include "utils/test_macros.h" namespace { using namespace coinbase::crypto; +using coinbase::buf_t; TEST(RSA, EncryptDecrypt) { rsa_prv_key_t prv_key; diff --git a/tests/unit/crypto/test_commitment.cpp b/tests/unit/crypto/test_commitment.cpp index 48294154..0d1d3152 100644 --- a/tests/unit/crypto/test_commitment.cpp +++ b/tests/unit/crypto/test_commitment.cpp @@ -1,14 +1,15 @@ #include -#include -#include -#include +#include +#include +#include #include "utils/test_macros.h" namespace { using namespace coinbase::crypto; +using coinbase::buf_t; TEST(CryptoCommitment, AdditionalInputSid) { buf_t sid = gen_random_bitlen(SEC_P_COM); @@ -61,7 +62,7 @@ TEST(CryptoCommitment, LocalSidAndReceiverPid) { EXPECT_ER(com3.open(a)); // incorrect receiver pid } -TEST(CryptoCommitment, AdditionalInputSid_AlternativeFormat) { +TEST(CryptoCommitment, AdditionalSid_AltFormat) { buf_t sid = gen_random_bitlen(SEC_P_COM); mpc_pid_t pid = pid_from_name("test"); commitment_t com1; diff --git a/tests/unit/crypto/test_ecc.cpp b/tests/unit/crypto/test_ecc.cpp index d57c2e2a..ea0bb766 100644 --- a/tests/unit/crypto/test_ecc.cpp +++ b/tests/unit/crypto/test_ecc.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include using namespace coinbase; using namespace coinbase::crypto; diff --git a/tests/unit/crypto/test_eddsa.cpp b/tests/unit/crypto/test_eddsa.cpp index abf97fb9..502fd189 100644 --- a/tests/unit/crypto/test_eddsa.cpp +++ b/tests/unit/crypto/test_eddsa.cpp @@ -1,15 +1,15 @@ #include -#include -#include -#include +#include +#include +#include using namespace coinbase; using namespace coinbase::crypto; namespace { -TEST(CryptoEdDSA, RejectTorsionPointsAndFixInfinityEquality) { +TEST(CryptoEdDSA, RejectTorsionAndFixInfinityEq) { crypto::vartime_scope_t vartime_scope; ecurve_t curve = crypto::curve_ed25519; @@ -103,7 +103,7 @@ TEST(CryptoEdDSA, hash_to_point) { EXPECT_EQ(in_group_counter, point_counter); } -TEST(CryptoEdDSA, mul_by_order_is_infinity_for_subgroup_points) { +TEST(CryptoEdDSA, MulByOrderIsInfinityForSubgroup) { crypto::vartime_scope_t vartime_scope; ecurve_t curve = crypto::curve_ed25519; const bn_t q = curve.order().value(); diff --git a/tests/unit/crypto/test_elgamal.cpp b/tests/unit/crypto/test_elgamal.cpp index 70e380dc..04ae4a22 100644 --- a/tests/unit/crypto/test_elgamal.cpp +++ b/tests/unit/crypto/test_elgamal.cpp @@ -1,6 +1,6 @@ #include -#include +#include using namespace coinbase; using namespace coinbase::crypto; diff --git a/tests/unit/crypto/test_hkdf_rfc5869.cpp b/tests/unit/crypto/test_hkdf_rfc5869.cpp index 00640e52..a8f024a6 100644 --- a/tests/unit/crypto/test_hkdf_rfc5869.cpp +++ b/tests/unit/crypto/test_hkdf_rfc5869.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include #include "utils/test_macros.h" diff --git a/tests/unit/crypto/test_hpke_rfc9180_json.cpp b/tests/unit/crypto/test_hpke_rfc9180_json.cpp index 4d6faa48..36aa9c72 100644 --- a/tests/unit/crypto/test_hpke_rfc9180_json.cpp +++ b/tests/unit/crypto/test_hpke_rfc9180_json.cpp @@ -2,8 +2,8 @@ #include #include -#include -#include +#include +#include #include "utils/test_macros.h" diff --git a/tests/unit/crypto/test_lagrange.cpp b/tests/unit/crypto/test_lagrange.cpp index d14aae1c..221f1cf7 100644 --- a/tests/unit/crypto/test_lagrange.cpp +++ b/tests/unit/crypto/test_lagrange.cpp @@ -1,6 +1,6 @@ #include -#include +#include #include "utils/test_macros.h" diff --git a/tests/unit/crypto/test_ro.cpp b/tests/unit/crypto/test_ro.cpp index a9c5fd1f..00dc3dd9 100644 --- a/tests/unit/crypto/test_ro.cpp +++ b/tests/unit/crypto/test_ro.cpp @@ -1,11 +1,14 @@ #include -#include +#include using namespace coinbase::crypto; namespace { +using coinbase::buf_t; +using coinbase::mem_t; + TEST(RandomOracle, EncodeAndUpdateHappyPath) { ro::hmac_state_t s1; s1.encode_and_update(0); diff --git a/tests/unit/crypto/test_secret_sharing.cpp b/tests/unit/crypto/test_secret_sharing.cpp index 5cebfae2..08825f27 100644 --- a/tests/unit/crypto/test_secret_sharing.cpp +++ b/tests/unit/crypto/test_secret_sharing.cpp @@ -1,8 +1,8 @@ #include -#include -#include -#include +#include +#include +#include #include "utils/data/ac.h" #include "utils/test_macros.h" @@ -34,8 +34,8 @@ TEST_F(SSNode, ValidateTestNodes) { TEST_F(SSNode, InvalidNode) { node_t root(node_e::AND, "root", 0); - node_t *child1 = new node_t(node_e::LEAF, "child1", 0); - node_t *child2 = new node_t(node_e::LEAF, "child2", 0); + node_t* child1 = new node_t(node_e::LEAF, "child1", 0); + node_t* child2 = new node_t(node_e::LEAF, "child2", 0); root.add_child_node(child1); root.add_child_node(child2); @@ -44,13 +44,13 @@ TEST_F(SSNode, InvalidNode) { root.name = ""; EXPECT_OK(root.validate_tree()); - node_t *child3 = new node_t(node_e::THRESHOLD, "child3", 2); + node_t* child3 = new node_t(node_e::THRESHOLD, "child3", 2); root.add_child_node(child3); EXPECT_ER(root.validate_tree()); // threshold node with no child - node_t *child31 = new node_t(node_e::LEAF, "child31", 0); + node_t* child31 = new node_t(node_e::LEAF, "child31", 0); child3->add_child_node(child31); EXPECT_ER(root.validate_tree()); // threshold node with not enough child - node_t *child32 = new node_t(node_e::LEAF, "child32", 0); + node_t* child32 = new node_t(node_e::LEAF, "child32", 0); child3->add_child_node(child32); EXPECT_OK(root.validate_tree()); // threshold node with not enough child @@ -58,8 +58,8 @@ TEST_F(SSNode, InvalidNode) { } TEST_F(SSNode, NodeClone) { - for (const auto &root : all_roots) { - node_t *clone = root->clone(); + for (const auto& root : all_roots) { + node_t* clone = root->clone(); EXPECT_EQ(clone->children.size(), root->children.size()); delete clone; } @@ -79,7 +79,7 @@ class SecretSharing : public coinbase::testutils::TestAC { n = 5; } - bool correctly_reconstructable(const ac_t &ac_ref, const ac_shares_t &shares, const ss::node_t *root) { + bool correctly_reconstructable(const ac_t& ac_ref, const ac_shares_t& shares, const ss::node_t* root) { bn_t reconstructed_x; if (ac_ref.enough_for_quorum(shares)) { @@ -96,7 +96,7 @@ TEST_F(SecretSharing, ListLeaves) { ac_t ac(test_root); auto leaves = ac.list_leaf_names(); EXPECT_EQ(leaves.size(), 24); - for (const auto &leaf : leaves) { + for (const auto& leaf : leaves) { EXPECT_TRUE(test_root->find(leaf)); } std::set leaves_set(leaves.begin(), leaves.end()); @@ -147,7 +147,7 @@ TEST_F(SecretSharing, ShareAnd) { EXPECT_EQ(shares.size(), n); bn_t sum = 0; - for (const auto &share : shares) { + for (const auto& share : shares) { MODULO(q) sum += share; } EXPECT_EQ(sum, x); @@ -246,18 +246,71 @@ TEST_F(SecretSharing, ACEnoughQuorumAndReconstruct) { shares = ac.share(q, x, nullptr); ac_shares_t minimal_shares; - for (const auto &name : valid_quorum) { + for (const auto& name : valid_quorum) { minimal_shares[name] = shares[name]; } EXPECT_TRUE(ac.enough_for_quorum(minimal_shares)); EXPECT_TRUE(correctly_reconstructable(ac, minimal_shares, test_root)); ac_shares_t malicious_shares; - for (const auto &name : valid_quorum) { + for (const auto& name : valid_quorum) { malicious_shares[name] = bn_t::rand(q); } EXPECT_TRUE(ac.enough_for_quorum(malicious_shares)); EXPECT_FALSE(correctly_reconstructable(ac, malicious_shares, test_root)); } +TEST_F(SecretSharing, ACReconstructExponentEd25519) { + vartime_scope_t vartime_scope; + ecurve_t curve = curve_ed25519; + const mod_t q = curve.order(); + const bn_t x = bn_t::rand(q); + + ac_t ac(test_root); + ac.curve = curve; + + const ac_shares_t shares = ac.share(q, x, nullptr); + ac_pub_shares_t pub_shares; + for (const auto& [name, si] : shares) { + pub_shares[name] = si * curve.generator(); + } + + ecc_point_t P; + EXPECT_OK(ac.reconstruct_exponent(pub_shares, P)); + EXPECT_EQ(P, x * curve.generator()); +} + +TEST_F(SecretSharing, ReconstructExpRejectsNonSubgroup) { + vartime_scope_t vartime_scope; + ecurve_t curve = curve_ed25519; + const mod_t q = curve.order(); + const bn_t x = bn_t::rand(q); + + ac_t ac(test_root); + ac.curve = curve; + + const ac_shares_t shares = ac.share(q, x, nullptr); + ac_pub_shares_t pub_shares; + for (const auto& [name, si] : shares) { + pub_shares[name] = si * curve.generator(); + } + + // Ed25519 order-2 torsion point (x=0, y=-1): on-curve but not in the prime-order subgroup. + uint8_t order2[32]; + order2[0] = 0xec; + for (int i = 1; i < 31; i++) order2[i] = 0xff; + order2[31] = 0x7f; + + ecc_point_t T(curve); + ASSERT_EQ(T.from_bin(curve, coinbase::mem_t(order2, 32)), SUCCESS); + ASSERT_TRUE(T.is_on_curve()); + ASSERT_FALSE(T.is_infinity()); + ASSERT_FALSE(T.is_in_subgroup()); + + pub_shares["leaf1"] = T; + + ecc_point_t P; + EXPECT_ER(ac.reconstruct_exponent(pub_shares, P)); +} + } // namespace \ No newline at end of file diff --git a/tests/unit/crypto/test_tdh2.cpp b/tests/unit/crypto/test_tdh2.cpp index af82ddc5..85d320d5 100644 --- a/tests/unit/crypto/test_tdh2.cpp +++ b/tests/unit/crypto/test_tdh2.cpp @@ -1,6 +1,6 @@ #include -#include +#include #include "utils/data/ac.h" #include "utils/data/tdh2.h" @@ -45,6 +45,7 @@ TEST_F(TDH2, AddCompleteness) { } TEST_F(TDH2, ACCompleteness) { + test_ac.curve = curve_p256; public_key_t enc_key; ss::ac_pub_shares_t pub_shares; ss::party_map_t dec_shares; @@ -74,4 +75,26 @@ TEST_F(TDH2, ACCompleteness) { EXPECT_EQ(plain, decrypted); } +TEST_F(TDH2, CiphertextRoundTripKeepsLabel) { + int n = 3; + std::vector dec_shares; + + public_key_t enc_key; + crypto::tdh2::pub_shares_t pub_shares; + testutils::generate_additive_shares(n, enc_key, pub_shares, dec_shares, curve_p256); + + const buf_t label = crypto::gen_random(10); + const buf_t wrong_label = buf_t("wrong-label"); + const buf_t plain = crypto::gen_random(32); + + const ciphertext_t ciphertext = enc_key.encrypt(plain, label); + const buf_t serialized = coinbase::convert(ciphertext); + + ciphertext_t roundtrip; + ASSERT_EQ(coinbase::convert(roundtrip, serialized), SUCCESS); + EXPECT_EQ(roundtrip.L, label); + EXPECT_OK(roundtrip.verify(enc_key, label)); + EXPECT_ER(roundtrip.verify(enc_key, wrong_label)); +} + } // namespace diff --git a/tests/unit/protocol/test_agree_random.cpp b/tests/unit/protocol/test_agree_random.cpp index c5ef596e..04cf302e 100644 --- a/tests/unit/protocol/test_agree_random.cpp +++ b/tests/unit/protocol/test_agree_random.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include #include "utils/local_network/mpc_tester.h" #include "utils/test_macros.h" diff --git a/tests/unit/protocol/test_broadcast.cpp b/tests/unit/protocol/test_broadcast.cpp index 7973e7ef..ee2a64e9 100644 --- a/tests/unit/protocol/test_broadcast.cpp +++ b/tests/unit/protocol/test_broadcast.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include #include "utils/local_network/mpc_tester.h" #include "utils/test_macros.h" diff --git a/tests/unit/protocol/test_ec_dkg.cpp b/tests/unit/protocol/test_ec_dkg.cpp index c3266441..ce021601 100644 --- a/tests/unit/protocol/test_ec_dkg.cpp +++ b/tests/unit/protocol/test_ec_dkg.cpp @@ -1,8 +1,8 @@ #include -#include -#include -#include +#include +#include +#include #include "utils/local_network/mpc_tester.h" #include "utils/test_macros.h" @@ -26,7 +26,7 @@ static void RunDkgAndAdditiveShareTest(crypto::ss::node_t* root_node, const std: ecurve_t curve = crypto::curve_secp256k1; const auto& G = curve.generator(); ss::ac_t ac; - ac.G = G; + ac.curve = curve; ac.root = root_node; mpc::party_set_t quorum_party_set; @@ -36,8 +36,8 @@ static void RunDkgAndAdditiveShareTest(crypto::ss::node_t* root_node, const std: buf_t sid_dkg = crypto::gen_random(16); mpc_runner_t all_parties_runner(pnames); all_parties_runner.run_mpc([&](mpc::job_mp_t& job) { - ASSERT_OK(coinbase::mpc::eckey::key_share_mp_t::threshold_dkg(job, curve, sid_dkg, ac, quorum_party_set, - keyshares[job.get_party_idx()])); + ASSERT_OK(coinbase::mpc::eckey::key_share_mp_t::dkg_ac(job, curve, sid_dkg, ac, quorum_party_set, + keyshares[job.get_party_idx()])); }); // Basic key consistency @@ -76,7 +76,7 @@ TEST(ECDKG, ReconstructPubAdditiveShares) { ecurve_t curve = crypto::curve_secp256k1; const auto& G = curve.generator(); ss::ac_t ac; - ac.G = G; + ac.curve = curve; ac.root = root_node; std::set quorum = {"p0", "p1", "p2"}; @@ -129,7 +129,7 @@ TEST(ECDKG, ReconstructPubAdditiveShares) { EXPECT_EQ(additive_share.Qis["p2"], expected_p2); } -TEST(ECDKG, ReconstructPubAdditiveShares_ORNode) { +TEST(ECDKG, ReconstructPubShares_OR) { // OR(p0, AND(p1, THRESHOLD[1](p2, p3))) with additive quorum {p1, p2} ss::node_t* root_node = new ss::node_t( ss::node_e::OR, "", 0, @@ -146,7 +146,7 @@ TEST(ECDKG, ReconstructPubAdditiveShares_ORNode) { RunDkgAndAdditiveShareTest(root_node, pnames, dkg_quorum_indices, additive_quorum); } -TEST(ECDKG, ReconstructPubAdditiveShares_Threshold2of3) { +TEST(ECDKG, ReconstructPubShares_Thres2of3) { // THRESHOLD[2](p0, p1, p2) with additive quorum {p0, p2} ss::node_t* root_node = new ss::node_t(ss::node_e::THRESHOLD, "", 2, @@ -159,7 +159,33 @@ TEST(ECDKG, ReconstructPubAdditiveShares_Threshold2of3) { RunDkgAndAdditiveShareTest(root_node, pnames, dkg_quorum_indices, additive_quorum); } -TEST(ECDKG, ReconstructPubAdditiveShares_ThresholdNofN_ANDEquivalent) { +TEST(ECDKG, ReconstructPub_2of3_AllQuorum) { + // THRESHOLD[2](p0, p1, p2) with all leaves in additive quorum (|quorum| > threshold). + ss::node_t* root_node = + new ss::node_t(ss::node_e::THRESHOLD, "", 2, + {new ss::node_t(ss::node_e::LEAF, "p0"), new ss::node_t(ss::node_e::LEAF, "p1"), + new ss::node_t(ss::node_e::LEAF, "p2")}); + + std::vector pnames = {"p0", "p1", "p2"}; + std::set dkg_quorum_indices = {0, 1, 2}; + std::set additive_quorum = {"p0", "p1", "p2"}; + RunDkgAndAdditiveShareTest(root_node, pnames, dkg_quorum_indices, additive_quorum); +} + +TEST(ECDKG, ReconstructPub_1of3_AllQuorum) { + // THRESHOLD[1](p0, p1, p2) with all leaves in additive quorum. + ss::node_t* root_node = + new ss::node_t(ss::node_e::THRESHOLD, "", 1, + {new ss::node_t(ss::node_e::LEAF, "p0"), new ss::node_t(ss::node_e::LEAF, "p1"), + new ss::node_t(ss::node_e::LEAF, "p2")}); + + std::vector pnames = {"p0", "p1", "p2"}; + std::set dkg_quorum_indices = {0, 1, 2}; + std::set additive_quorum = {"p0", "p1", "p2"}; + RunDkgAndAdditiveShareTest(root_node, pnames, dkg_quorum_indices, additive_quorum); +} + +TEST(ECDKG, ReconstructPub_NofN_AndEq) { // THRESHOLD[3](p0, p1, p2) with additive quorum {p0, p1, p2} (equivalent to AND) ss::node_t* root_node = new ss::node_t(ss::node_e::THRESHOLD, "", 3, @@ -172,7 +198,7 @@ TEST(ECDKG, ReconstructPubAdditiveShares_ThresholdNofN_ANDEquivalent) { RunDkgAndAdditiveShareTest(root_node, pnames, dkg_quorum_indices, additive_quorum); } -TEST(ECDKG, ReconstructPubAdditiveShares_Threshold3of4_LargerLeaves) { +TEST(ECDKG, ReconstructPub_3of4_LargeLeaves) { // THRESHOLD[3](p0, p1, p2, p3) with additive quorum {p0, p1, p2} ss::node_t* root_node = new ss::node_t(ss::node_e::THRESHOLD, "", 3, @@ -185,7 +211,7 @@ TEST(ECDKG, ReconstructPubAdditiveShares_Threshold3of4_LargerLeaves) { RunDkgAndAdditiveShareTest(root_node, pnames, dkg_quorum_indices, additive_quorum); } -TEST(ECDKG, ThresholdDkgRejectsInvalidAccessStructureDuplicateLeaves) { +TEST(ECDKG, ThresholdDkgRejectsDupLeaves) { dylog_disable_scope_t dylog_disable_scope; ecurve_t curve = crypto::curve_secp256k1; @@ -197,7 +223,7 @@ TEST(ECDKG, ThresholdDkgRejectsInvalidAccessStructureDuplicateLeaves) { {new node_t(node_e::LEAF, "p1"), new node_t(node_e::LEAF, "p1"), new node_t(node_e::LEAF, "p2")}); ss::ac_t ac; - ac.G = G; + ac.curve = curve; ac.root = &root; std::vector pnames = {"p0", "p1", "p2"}; @@ -208,7 +234,7 @@ TEST(ECDKG, ThresholdDkgRejectsInvalidAccessStructureDuplicateLeaves) { buf_t sid = crypto::gen_random(16); key_share_mp_t key; - error_t rv = key_share_mp_t::threshold_dkg(job, curve, sid, ac, quorum_party_set, key); + error_t rv = key_share_mp_t::dkg_ac(job, curve, sid, ac, quorum_party_set, key); EXPECT_EQ(uint32_t(rv), uint32_t(E_BADARG)); } diff --git a/tests/unit/protocol/test_ecdsa_2p.cpp b/tests/unit/protocol/test_ecdsa_2p.cpp index 5b4ab922..efc10620 100644 --- a/tests/unit/protocol/test_ecdsa_2p.cpp +++ b/tests/unit/protocol/test_ecdsa_2p.cpp @@ -1,6 +1,6 @@ #include -#include +#include #include "utils/local_network/mpc_tester.h" @@ -183,54 +183,6 @@ TEST_F(ECDSA2PC, KeygenSign) { check_key_pair(keys[0], keys[1]); } -TEST_F(ECDSA2PC, ParallelKSRS8) { - int parallel_count = 4; - std::vector> data(parallel_count); - for (int i = 0; i < parallel_count; i++) { - int len = i + 1; - data[i].resize(len); - for (int j = 0; j < len; j++) data[i][j] = coinbase::crypto::gen_random(32); - } - std::vector> keys(parallel_count, std::vector(2)); - std::vector> new_keys(parallel_count, std::vector(2)); - buf_t sid = coinbase::crypto::gen_random_bitlen(SEC_P_COM); - - mpc_runner->run_2pc_parallel(parallel_count, [&data, &keys, &new_keys, &sid](job_parallel_2p_t& job, int th_i) { - error_t rv = UNINITIALIZED_ERROR; - auto party_index = job.get_party_idx(); - ecurve_t curve = coinbase::crypto::curve_secp256k1; - - ecdsa2pc::key_t& key = keys[th_i][party_index]; - rv = ecdsa2pc::dkg(job, curve, key); - ASSERT_EQ(rv, 0); - - std::vector sig_bufs; - buf_t session_id; - rv = sign_batch(job, session_id, key, buf_t::to_mems(data[th_i]), sig_bufs); - ASSERT_EQ(rv, 0); - - ecdsa2pc::key_t& new_key = new_keys[th_i][party_index]; - rv = ecdsa2pc::refresh(job, key, new_key); - ASSERT_EQ(rv, 0); - - EXPECT_EQ(new_key.role, key.role); - EXPECT_EQ(new_key.curve, key.curve); - EXPECT_EQ(new_key.Q, key.Q); - EXPECT_NE(new_key.x_share, key.x_share); - - std::vector new_sig_bufs; - rv = sign_batch(job, session_id, key, buf_t::to_mems(data[th_i]), new_sig_bufs); - ASSERT_EQ(rv, 0); - rv = sign_with_global_abort_batch(job, session_id, key, buf_t::to_mems(data[th_i]), new_sig_bufs); - ASSERT_EQ(rv, 0); - }); - - for (int i = 0; i < parallel_count; i++) { - check_key_pair(keys[i][0], keys[i][1]); - check_key_pair(new_keys[i][0], new_keys[i][1]); - } -} - TEST_F(ECDSA2PC, Integer_Commit) { error_t rv = UNINITIALIZED_ERROR; ecurve_t curve = coinbase::crypto::curve_secp256k1; diff --git a/tests/unit/protocol/test_ecdsa_mp.cpp b/tests/unit/protocol/test_ecdsa_mp.cpp index 31779150..b2efb100 100644 --- a/tests/unit/protocol/test_ecdsa_mp.cpp +++ b/tests/unit/protocol/test_ecdsa_mp.cpp @@ -1,8 +1,8 @@ #include -#include -#include -#include +#include +#include +#include #include "utils/local_network/mpc_tester.h" #include "utils/test_macros.h" @@ -180,57 +180,6 @@ TEST_F(ECDSA4PC, KeygenSignRefreshSign) { check_keys(new_keys); } -TEST_F(ECDSA4PC, ParallelKSRS8) { - int parallel_count = 8; - std::vector data(parallel_count); - for (int i = 0; i < parallel_count; i++) { - data[i] = crypto::gen_random(32); - } - std::vector> keys(parallel_count, std::vector(4)); - std::vector> new_keys(parallel_count, std::vector(4)); - - mpc_runner->run_mpc_parallel(parallel_count, [&keys, &new_keys, &data](job_parallel_mp_t& job, int th_i) { - std::vector> ot_role_map = test_ot_role(4); - error_t rv = UNINITIALIZED_ERROR; - auto party_index = job.get_party_idx(); - ecdsampc::key_t& key = keys[th_i][party_index]; - ecurve_t curve = crypto::curve_secp256k1; - - buf_t sid; - rv = ecdsampc::dkg(job, curve, key, sid); - ASSERT_EQ(rv, 0); - - buf_t sig; - rv = sign(job, key, data[th_i], party_idx_t(0), ot_role_map, sig); - ASSERT_EQ(rv, 0); - - if (party_index == 0) { - crypto::ecc_pub_key_t ecc_verify_key(key.Q); - EXPECT_OK(ecc_verify_key.verify(data[th_i], sig)); - } - - ecdsampc::key_t& new_key = new_keys[th_i][party_index]; - rv = ecdsampc::refresh(job, sid, key, new_key); - ASSERT_EQ(rv, 0); - EXPECT_EQ(new_key.Q, key.Q); - EXPECT_NE(new_key.x_share, key.x_share); - - buf_t new_sig; - rv = sign(job, new_key, data[th_i], party_idx_t(0), ot_role_map, new_sig); - ASSERT_EQ(rv, 0); - - if (party_index == 0) { - crypto::ecc_pub_key_t ecc_verify_key(key.Q); - EXPECT_OK(ecc_verify_key.verify(data[th_i], new_sig)); - } - }); - - for (int i = 0; i < parallel_count; i++) { - check_keys(keys[i]); - check_keys(new_keys[i]); - } -} - TEST(ECDSAMPCThreshold, DKG) { int n = 5; std::vector pnames = {"party-0", "party-1", "party-2", "party-3", "party-4"}; @@ -277,14 +226,13 @@ TEST(ECDSAMPCThreshold, DKG) { new crypto::ss::node_t(crypto::ss::node_e::LEAF, pnames[4]), })}); crypto::ss::ac_t ac; - ac.G = G; + ac.curve = curve; ac.root = root_node; // DKG is an n-party protocol mpc_runner_t all_parties_runner(pnames); all_parties_runner.run_mpc([&curve, &keyshares, &quorum_party_set, &ac, &sid_dkg](mpc::job_mp_t& job) { - EXPECT_OK(eckey::key_share_mp_t::threshold_dkg(job, curve, sid_dkg, ac, quorum_party_set, - keyshares[job.get_party_idx()])); + EXPECT_OK(eckey::key_share_mp_t::dkg_ac(job, curve, sid_dkg, ac, quorum_party_set, keyshares[job.get_party_idx()])); }); for (int i = 0; i < n; i++) { @@ -313,9 +261,8 @@ TEST(ECDSAMPCThreshold, DKG) { // Refresh is an n-party protocol all_parties_runner.run_mpc([&](mpc::job_mp_t& job) { - ASSERT_OK(eckey::key_share_mp_t::threshold_refresh(job, curve, sid_refresh, ac, quorum_party_set, - keyshares[job.get_party_idx()], - new_keyshares[job.get_party_idx()])); + ASSERT_OK(eckey::key_share_mp_t::refresh_ac(job, curve, sid_refresh, ac, quorum_party_set, + keyshares[job.get_party_idx()], new_keyshares[job.get_party_idx()])); }); ASSERT_EQ(sid_refresh.size(), 16); ASSERT_NE(sid_refresh, sid_dkg); diff --git a/tests/unit/protocol/test_hdmpc_ecdsa_2p.cpp b/tests/unit/protocol/test_hdmpc_ecdsa_2p.cpp index 9f4da18a..7e067851 100644 --- a/tests/unit/protocol/test_hdmpc_ecdsa_2p.cpp +++ b/tests/unit/protocol/test_hdmpc_ecdsa_2p.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include #include "utils/local_network/mpc_tester.h" @@ -196,69 +196,4 @@ TEST_F(HDMPC_ECDSA_2P, SignSequential) { }); } -TEST_F(HDMPC_ECDSA_2P, SignParallel) { - int DATA_COUNT = 3; - std::vector data(DATA_COUNT); - for (int i = 0; i < data.size(); i++) data[i] = coinbase::crypto::gen_random(32); - buf_t session_id = coinbase::crypto::gen_random(32); - - mpc_runner->run_2pc_parallel(1, [&data, &session_id, &DATA_COUNT](job_parallel_2p_t& job, int dummy) { - error_t rv = UNINITIALIZED_ERROR; - auto role = job.get_party(); - ecurve_t curve = coinbase::crypto::curve_secp256k1; - - key_share_ecdsa_hdmpc_2p_t key; - rv = key_share_ecdsa_hdmpc_2p_t::dkg(job, curve, key); - ASSERT_EQ(rv, 0); - - bip32_path_t hardened_path; - std::vector non_hardened_paths(DATA_COUNT); - - hardened_path.append(1); - hardened_path.append(2); - hardened_path.append(3); - - for (int i = 0; i < DATA_COUNT; i++) { - non_hardened_paths[i].append((i + 1) * 4 + 0); - non_hardened_paths[i].append((i + 1) * 4 + 1); - } - - int n_sigs = (int)non_hardened_paths.size(); - std::vector sigs(n_sigs); - std::vector derived_keys(n_sigs); - - rv = key_share_ecdsa_hdmpc_2p_t::derive_keys(job, key, hardened_path, non_hardened_paths, session_id, derived_keys); - - ASSERT_EQ(rv, 0); - - std::vector threads; - job.set_parallel_count(n_sigs); - std::mutex update_sig_mtx; - - for (int i = 0; i < n_sigs; i++) { - threads.emplace_back([i, &derived_keys, &data, &job, &sigs, &update_sig_mtx]() { - auto _derived_key = derived_keys[i]; - auto _data = data[i]; - int parallel_count = sigs.size(); - job_parallel_2p_t parallel_job = - job.get_parallel_job(parallel_count, parallel_id_t(i)); // create a new job from network - buf_t _sig; - buf_t empty_sid; // empty session id -> sign will generate sid internally - - error_t rv = ecdsa2pc::sign(parallel_job, empty_sid, _derived_key, _data, _sig); - - { - std::unique_lock lk(update_sig_mtx, std::defer_lock); - sigs[i] = _sig; - } - - ASSERT_EQ(rv, 0); - }); - } - for (auto& th : threads) th.join(); - - job.set_parallel_count(0); - }); -} - } // namespace \ No newline at end of file diff --git a/tests/unit/protocol/test_hdmpc_eddsa_2p.cpp b/tests/unit/protocol/test_hdmpc_eddsa_2p.cpp index c1eca148..34ddd7f9 100644 --- a/tests/unit/protocol/test_hdmpc_eddsa_2p.cpp +++ b/tests/unit/protocol/test_hdmpc_eddsa_2p.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include #include "utils/local_network/mpc_tester.h" @@ -173,68 +173,4 @@ TEST_F(HDMPC_EdDSA_2P, SignSequential) { } }); } - -TEST_F(HDMPC_EdDSA_2P, SignParallel) { - int DATA_COUNT = 3; - std::vector data(DATA_COUNT); - for (int i = 0; i < data.size(); i++) data[i] = coinbase::crypto::gen_random(32); - buf_t session_id = coinbase::crypto::gen_random(32); - - mpc_runner->run_2pc_parallel(1, [&data, &session_id, &DATA_COUNT](job_parallel_2p_t& job, int dummy) { - error_t rv = UNINITIALIZED_ERROR; - auto role = job.get_party(); - ecurve_t curve = coinbase::crypto::curve_ed25519; - - key_share_eddsa_hdmpc_2p_t key; - rv = key_share_eddsa_hdmpc_2p_t::dkg(job, curve, key); - ASSERT_EQ(rv, 0); - - bip32_path_t hardened_path; - std::vector non_hardened_paths(DATA_COUNT); - - hardened_path.append(1); - hardened_path.append(2); - hardened_path.append(3); - - for (int i = 0; i < DATA_COUNT; i++) { - non_hardened_paths[i].append((i + 1) * 4 + 0); - non_hardened_paths[i].append((i + 1) * 4 + 1); - } - - int n_sigs = (int)non_hardened_paths.size(); - std::vector sigs(n_sigs); - std::vector derived_keys(n_sigs); - - rv = key_share_eddsa_hdmpc_2p_t::derive_keys(job, key, hardened_path, non_hardened_paths, session_id, derived_keys); - - ASSERT_EQ(rv, 0); - - std::vector threads; - job.set_parallel_count(n_sigs); - std::mutex update_sig_mtx; - - for (int i = 0; i < n_sigs; i++) { - threads.emplace_back([i, &derived_keys, &data, &job, &sigs, &update_sig_mtx]() { - auto _derived_key = derived_keys[i]; - auto _data = data[i]; - int parallel_count = sigs.size(); - job_parallel_2p_t parallel_job = - job.get_parallel_job(parallel_count, parallel_id_t(i)); // create a new job from network - buf_t _sig; - - error_t rv = eddsa2pc::sign(parallel_job, _derived_key, _data, _sig); - - { - std::unique_lock lk(update_sig_mtx, std::defer_lock); - sigs[i] = _sig; - } - - ASSERT_EQ(rv, 0); - }); - } - for (auto& th : threads) th.join(); - - job.set_parallel_count(0); - }); -} } // namespace \ No newline at end of file diff --git a/tests/unit/protocol/test_int_commitment.cpp b/tests/unit/protocol/test_int_commitment.cpp index 6c9d761d..8abb6812 100644 --- a/tests/unit/protocol/test_int_commitment.cpp +++ b/tests/unit/protocol/test_int_commitment.cpp @@ -1,8 +1,8 @@ #include -#include -#include -#include +#include +#include +#include #include "utils/test_macros.h" diff --git a/tests/unit/protocol/test_mpc_network.cpp b/tests/unit/protocol/test_mpc_network.cpp index fe7b0336..41da351b 100644 --- a/tests/unit/protocol/test_mpc_network.cpp +++ b/tests/unit/protocol/test_mpc_network.cpp @@ -1,6 +1,6 @@ #include -#include +#include #include "utils/local_network/mpc_tester.h" @@ -38,31 +38,6 @@ TEST_F(Network2PC, BasicMessaging) { }); } -typedef std::function lambda_2p_t; - -TEST_F(Network2PC, ParallelMessaging) { - int parallel_count = 50; - std::atomic finished(0); - std::mutex send_cond_mutex; - - mpc_runner->run_2pc_parallel(parallel_count, [&finished](job_parallel_2p_t& job, int th_i) { - error_t rv = UNINITIALIZED_ERROR; - buf_t data; - buf_t want(mem_t("test_data:" + std::to_string(th_i * 10000))); - if (job.is_p1()) data = want; - if (job.is_p2()) EXPECT_NE(data, want); - - rv = job.p1_to_p2(data); - ASSERT_EQ(rv, 0); - - EXPECT_EQ(data, want); - finished++; - }); - - // To verify that that all threads finished - EXPECT_EQ(finished, parallel_count * 2); -} - TEST_F(Network4PC, BasicBroadcast) { mpc_runner->run_mpc([](job_mp_t& job) { error_t rv = UNINITIALIZED_ERROR; @@ -79,70 +54,6 @@ TEST_F(Network4PC, BasicBroadcast) { }); } -TEST_F(Network4PC, ParallelBroadcasting) { - int parallel_count = 3; - std::atomic finished(0); - - mpc_runner->run_mpc_parallel(parallel_count, [&finished](job_mp_t& job, int th_i) { - error_t rv = UNINITIALIZED_ERROR; - auto party_index = job.get_party_idx(); - auto data = - job.uniform_msg(buf_t("test_data:" + std::to_string(party_index) + "-thread" + std::to_string(th_i))); - rv = job.plain_broadcast(data); - ASSERT_EQ(rv, 0); - - for (int j = 0; j < 4; j++) { - EXPECT_EQ(data.received(j), buf_t("test_data:" + std::to_string(j) + "-thread" + std::to_string(th_i))); - } - finished++; - }); - - EXPECT_EQ(finished, parallel_count * 4); -} - -class Network2PC_ParallelReceiveError : public Network2PC, public ::testing::WithParamInterface {}; - -TEST_P(Network2PC_ParallelReceiveError, DoesNotDeadlock) { - int parallel_count = 8; - const int abort_th = GetParam(); - std::atomic finished(0); - - auto* runner = mpc_runner.get(); - mpc_runner->run_2pc_parallel(parallel_count, [&, runner, abort_th](job_parallel_2p_t& job, int th_i) { - if (job.is_p2() && th_i == abort_th) { - runner->abort_connection(); - } - buf_t data("x"); - job.p1_to_p2(data); - finished++; - }); - - EXPECT_EQ(finished, parallel_count * 2); -} -INSTANTIATE_TEST_SUITE_P(, Network2PC_ParallelReceiveError, ::testing::Values(0, 1)); - -class Network4PC_ParallelReceiveAllError : public Network4PC, public ::testing::WithParamInterface {}; - -TEST_P(Network4PC_ParallelReceiveAllError, DoesNotDeadlock) { - int parallel_count = 8; - const int abort_th = GetParam(); - std::atomic finished(0); - - auto* runner = mpc_runner.get(); - mpc_runner->run_mpc_parallel(parallel_count, [&, runner, abort_th](job_mp_t& job, int th_i) { - if (job.get_party_idx() == 0 && th_i == abort_th) { - runner->abort_connection(); - } - - auto data = job.uniform_msg(buf_t("x")); - job.plain_broadcast(data); - finished++; - }); - - EXPECT_EQ(finished, parallel_count * 4); -} -INSTANTIATE_TEST_SUITE_P(, Network4PC_ParallelReceiveAllError, ::testing::Values(0, 1)); - TEST_F(Network4PC, MessageWrapperCopySafety) { mpc_runner->run_mpc([](job_mp_t& job) { // nonuniform_msg_t copy then use-after-source-destruction should be safe @@ -225,65 +136,6 @@ TEST_P(NetworkMPC, PairwiseAndBroadcast) { EXPECT_EQ(data.msg, buf_t("test_data:" + std::to_string(party_index))); }); } - -TEST_P(NetworkMPC, ParallelBroadcasting) { - int n_parties = GetParam(); - int parallel_count = 16; - - auto mpc_runner = std::make_unique(n_parties); - std::atomic finished(0); - - mpc_runner->run_mpc_parallel(parallel_count, [&finished, &n_parties](job_mp_t& job, int th_i) { - error_t rv = UNINITIALIZED_ERROR; - auto party_index = job.get_party_idx(); - auto data = - job.uniform_msg(buf_t("test_data:" + std::to_string(party_index) + "-thread" + std::to_string(th_i))); - rv = job.plain_broadcast(data); - ASSERT_EQ(rv, 0); - - for (int j = 0; j < n_parties; j++) { - EXPECT_EQ(data.received(j), buf_t("test_data:" + std::to_string(j) + "-thread" + std::to_string(th_i))); - } - for (int i = 0; i < 10; i++) { - auto data2 = - job.uniform_msg(buf_t("test_data:" + std::to_string(party_index) + "-thread" + std::to_string(th_i))); - rv = job.plain_broadcast(data2); - ASSERT_EQ(rv, 0); - } - finished++; - }); - - // To verify that that all threads finished - EXPECT_EQ(finished, parallel_count * n_parties); -} INSTANTIATE_TEST_SUITE_P(, NetworkMPC, testing::Values(2, 4, 5, 10, 32, 64)); -TEST_F(Network2PC, SequentialThenParallel) { - int PARALLEL_COUNT = 3; - std::vector data(PARALLEL_COUNT); - for (int i = 0; i < data.size(); i++) data[i] = crypto::gen_random_bitlen(128); - - mpc_runner->run_2pc_parallel(1, [&data, PARALLEL_COUNT](job_parallel_2p_t& job, int dummy) { - error_t rv = UNINITIALIZED_ERROR; - auto role = job.get_party(); - - rv = job.p1_to_p2(data[0]); - - std::vector threads; - job.set_parallel_count(PARALLEL_COUNT); - - for (int i = 0; i < PARALLEL_COUNT; i++) { - threads.emplace_back([&data, &job, PARALLEL_COUNT, i]() { - job_parallel_2p_t parallel_job = job.get_parallel_job(PARALLEL_COUNT, parallel_id_t(i)); - - error_t rv = parallel_job.p1_to_p2(data[i]); - ASSERT_EQ(rv, 0); - }); - } - for (auto& th : threads) th.join(); - - job.set_parallel_count(0); - }); -} - } // namespace \ No newline at end of file diff --git a/tests/unit/protocol/test_ot.cpp b/tests/unit/protocol/test_ot.cpp index 1c851b01..4e0142a4 100644 --- a/tests/unit/protocol/test_ot.cpp +++ b/tests/unit/protocol/test_ot.cpp @@ -1,8 +1,8 @@ #include -#include -#include -#include +#include +#include +#include #include "utils/test_macros.h" diff --git a/tests/unit/protocol/test_parallel_transport_oob.cpp b/tests/unit/protocol/test_parallel_transport_oob.cpp deleted file mode 100644 index 5da8e9f0..00000000 --- a/tests/unit/protocol/test_parallel_transport_oob.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include -#include - -#include -#include -#include - -namespace { - -using namespace coinbase; -using namespace coinbase::mpc; - -class fixed_buf_transport_t final : public data_transport_interface_t { - public: - explicit fixed_buf_transport_t(buf_t malicious) : malicious_buf_(std::move(malicious)) {} - - error_t send(party_idx_t /*receiver*/, mem_t /*msg*/) override { return SUCCESS; } - - error_t receive(party_idx_t /*sender*/, buf_t& msg) override { - msg = malicious_buf_; - return SUCCESS; - } - - error_t receive_all(const std::vector& senders, std::vector& message) override { - message.assign(senders.size(), malicious_buf_); - return SUCCESS; - } - - private: - buf_t malicious_buf_; -}; - -struct scoped_log_sink_t { - scoped_log_sink_t() : prev_(coinbase::out_log_fun) { coinbase::out_log_fun = &scoped_log_sink_t::discard; } - ~scoped_log_sink_t() { coinbase::out_log_fun = prev_; } - scoped_log_sink_t(const scoped_log_sink_t&) = delete; - scoped_log_sink_t& operator=(const scoped_log_sink_t&) = delete; - - private: - static void discard(int /*mode*/, const char* /*str*/) {} - coinbase::out_log_str_f prev_; -}; - -TEST(ParallelDataTransportOOB, MaliciousVectorLenZeroReceive) { - scoped_log_sink_t logs; - - // A single byte `0x00` decodes to vector length = 0 (via convert_len). - buf_t malicious(1); - malicious[0] = 0x00; - - auto transport = std::make_shared(malicious); - parallel_data_transport_t network(transport, /*_parallel_count=*/2); - - error_t rv0 = UNINITIALIZED_ERROR; - error_t rv1 = UNINITIALIZED_ERROR; - buf_t out0, out1; - - std::thread t0([&] { rv0 = network.receive(/*sender=*/0, /*parallel_id=*/0, out0); }); - std::thread t1([&] { rv1 = network.receive(/*sender=*/0, /*parallel_id=*/1, out1); }); - t0.join(); - t1.join(); - - EXPECT_EQ(rv0, E_FORMAT); - EXPECT_EQ(rv1, E_FORMAT); - EXPECT_TRUE(out0.empty()); - EXPECT_TRUE(out1.empty()); -} - -TEST(ParallelDataTransportOOB, MaliciousVectorLenZeroReceiveAll) { - scoped_log_sink_t logs; - - buf_t malicious(1); - malicious[0] = 0x00; - - auto transport = std::make_shared(malicious); - parallel_data_transport_t network(transport, /*_parallel_count=*/2); - - const std::vector senders = {0, 1, 2}; - - error_t rv0 = UNINITIALIZED_ERROR; - error_t rv1 = UNINITIALIZED_ERROR; - std::vector outs0(senders.size()); - std::vector outs1(senders.size()); - - std::thread t0([&] { rv0 = network.receive_all(senders, /*parallel_id=*/0, outs0); }); - std::thread t1([&] { rv1 = network.receive_all(senders, /*parallel_id=*/1, outs1); }); - t0.join(); - t1.join(); - - EXPECT_EQ(rv0, E_FORMAT); - EXPECT_EQ(rv1, E_FORMAT); - for (const auto& m : outs0) EXPECT_TRUE(m.empty()); - for (const auto& m : outs1) EXPECT_TRUE(m.empty()); -} - -} // namespace diff --git a/tests/unit/protocol/test_pve.cpp b/tests/unit/protocol/test_pve.cpp index d21b3ef4..15e1f5e2 100644 --- a/tests/unit/protocol/test_pve.cpp +++ b/tests/unit/protocol/test_pve.cpp @@ -1,14 +1,14 @@ #include -#include -#include -#include -#include -#include +#include +#include +#include +#include #include "utils/test_macros.h" using namespace coinbase; +using namespace coinbase::crypto; using namespace coinbase::mpc; namespace { @@ -19,176 +19,123 @@ struct toy_kem_policy_t { struct ek_t {}; struct dk_t {}; - static error_t encapsulate(const ek_t &, buf_t &kem_ct, buf_t &kem_ss, crypto::drbg_aes_ctr_t *drbg) { + static error_t encapsulate(const ek_t&, buf_t& kem_ct, buf_t& kem_ss, crypto::drbg_aes_ctr_t* drbg) { kem_ss = drbg ? drbg->gen(32) : crypto::gen_random(32); kem_ct = kem_ss; // trivial, not secure, only for test return SUCCESS; } - static error_t decapsulate(const dk_t &, mem_t kem_ct, buf_t &kem_ss) { + static error_t decapsulate(const dk_t&, mem_t kem_ct, buf_t& kem_ss) { kem_ss = buf_t(kem_ct); return SUCCESS; } }; -class PVE : public testing::Test { - protected: - void SetUp() override { - // Generate RSA keys - rsa_prv_key1.generate(2048); - rsa_prv_key2.generate(2048); - - // Generate ECC key - ecc_prv_key.generate(crypto::curve_p256); - - // Unified valid pairs - valid_unified = { - {crypto::pub_key_t::from(rsa_prv_key1.pub()), crypto::prv_key_t::from(rsa_prv_key1)}, - {crypto::pub_key_t::from(rsa_prv_key2.pub()), crypto::prv_key_t::from(rsa_prv_key2)}, - {crypto::pub_key_t::from(ecc_prv_key.pub()), crypto::prv_key_t::from(ecc_prv_key)}, - }; - - // Unified invalid pairs (mismatched) - invalid_unified = { - {crypto::pub_key_t::from(rsa_prv_key1.pub()), crypto::prv_key_t::from(rsa_prv_key2)}, - {crypto::pub_key_t::from(rsa_prv_key2.pub()), crypto::prv_key_t::from(rsa_prv_key1)}, - {crypto::pub_key_t::from(rsa_prv_key1.pub()), crypto::prv_key_t::from(ecc_prv_key)}, - {crypto::pub_key_t::from(rsa_prv_key2.pub()), crypto::prv_key_t::from(ecc_prv_key)}, - {crypto::pub_key_t::from(ecc_prv_key.pub()), crypto::prv_key_t::from(rsa_prv_key1)}, - {crypto::pub_key_t::from(ecc_prv_key.pub()), crypto::prv_key_t::from(rsa_prv_key2)}, - }; - } - +TEST(PVE, RSA_Completeness) { const ecurve_t curve = crypto::curve_p256; - const mod_t &q = curve.order(); - const crypto::ecc_generator_point_t &G = curve.generator(); - - // Keys - crypto::rsa_prv_key_t rsa_prv_key1, rsa_prv_key2; - crypto::ecc_prv_key_t ecc_prv_key; - - // Unified pairs - std::vector> valid_unified; - std::vector> invalid_unified; -}; - -// Define alias for fixture used by batch tests -typedef PVE PVEBatch; + const mod_t& q = curve.order(); + const crypto::ecc_generator_point_t& G = curve.generator(); -TEST_F(PVE, DefaultUnified_Completeness) { - for (const auto &kp : valid_unified) { - const auto &pub_key = kp.first; - const auto &prv_key = kp.second; - - ec_pve_t pve; // defaults to base_pke_unified - bn_t x = bn_t::rand(q); - ecc_point_t X = x * G; - - pve.encrypt(&pub_key, "test-label", curve, x); - EXPECT_OK(pve.verify(&pub_key, X, "test-label")); + rsa_prv_key_t rsa_sk; + rsa_sk.generate(2048); + rsa_pub_key_t rsa_pk(rsa_sk.pub()); - bn_t decrypted_x; - EXPECT_OK(pve.decrypt(&prv_key, &pub_key, "test-label", curve, decrypted_x)); - EXPECT_EQ(x, decrypted_x); - } -} + ec_pve_t pve; + bn_t x = bn_t::rand(q); + ecc_point_t X = x * G; -TEST_F(PVE, DefaultUnified_VerifyWithWrongLabel) { - for (const auto &kp : valid_unified) { - const auto &pub_key = kp.first; - ec_pve_t pve; - bn_t x = bn_t::rand(q); - ecc_point_t X = x * G; + EXPECT_OK(pve.encrypt(pve_base_pke_rsa(), pve_keyref(rsa_pk), "test-label", curve, x)); + EXPECT_OK(pve.verify(pve_base_pke_rsa(), pve_keyref(rsa_pk), X, "test-label")); - pve.encrypt(&pub_key, "test-label", curve, x); - dylog_disable_scope_t no_log_err; - EXPECT_ER(pve.verify(&pub_key, X, "wrong-label")); - } + bn_t decrypted_x; + EXPECT_OK(pve.decrypt(pve_base_pke_rsa(), pve_keyref(rsa_sk), pve_keyref(rsa_pk), "test-label", curve, decrypted_x)); + EXPECT_EQ(x, decrypted_x); } -TEST_F(PVE, DefaultUnified_VerifyWithWrongQ) { - for (const auto &kp : valid_unified) { - const auto &pub_key = kp.first; - ec_pve_t pve; - bn_t x = bn_t::rand(q); - - pve.encrypt(&pub_key, "test-label", curve, x); - dylog_disable_scope_t no_log_err; - EXPECT_ER(pve.verify(&pub_key, bn_t::rand(q) * G, "test-label")); - } -} +TEST(PVE, ECIES_Completeness) { + const ecurve_t curve = crypto::curve_p256; + const mod_t& q = curve.order(); + const crypto::ecc_generator_point_t& G = curve.generator(); -TEST_F(PVE, DefaultUnified_DecryptWithWrongLabel) { - for (const auto &kp : valid_unified) { - const auto &pub_key = kp.first; - const auto &prv_key = kp.second; + ecc_prv_key_t ecc_sk; + ecc_sk.generate(crypto::curve_p256); + ecc_pub_key_t ecc_pk(ecc_sk.pub()); - ec_pve_t pve; - bn_t x = bn_t::rand(q); + ec_pve_t pve; + bn_t x = bn_t::rand(q); + ecc_point_t X = x * G; - pve.encrypt(&pub_key, "test-label", curve, x); + EXPECT_OK(pve.encrypt(pve_base_pke_ecies(), pve_keyref(ecc_pk), "test-label", curve, x)); + EXPECT_OK(pve.verify(pve_base_pke_ecies(), pve_keyref(ecc_pk), X, "test-label")); - bn_t decrypted_x; - dylog_disable_scope_t no_log_err; - EXPECT_ER(pve.decrypt(&prv_key, &pub_key, "wrong-label", curve, decrypted_x)); - EXPECT_NE(x, decrypted_x); - } + bn_t decrypted_x; + EXPECT_OK( + pve.decrypt(pve_base_pke_ecies(), pve_keyref(ecc_sk), pve_keyref(ecc_pk), "test-label", curve, decrypted_x)); + EXPECT_EQ(x, decrypted_x); } -TEST_F(PVE, DefaultUnified_DecryptWithWrongKey) { - for (const auto &kp : invalid_unified) { - const auto &pub_key = kp.first; - const auto &prv_key = kp.second; +TEST(PVE, ECIES_VerifyWithWrongLabel) { + const ecurve_t curve = crypto::curve_p256; + const mod_t& q = curve.order(); + const crypto::ecc_generator_point_t& G = curve.generator(); - ec_pve_t pve; - bn_t x = bn_t::rand(q); + ecc_prv_key_t ecc_sk; + ecc_sk.generate(crypto::curve_p256); + ecc_pub_key_t ecc_pk(ecc_sk.pub()); - pve.encrypt(&pub_key, "test-label", curve, x); + ec_pve_t pve; + bn_t x = bn_t::rand(q); + ecc_point_t X = x * G; - bn_t decrypted_x; - dylog_disable_scope_t no_log_err; - EXPECT_ER(pve.decrypt(&prv_key, &pub_key, "test-label", curve, decrypted_x)); - EXPECT_NE(x, decrypted_x); - } + EXPECT_OK(pve.encrypt(pve_base_pke_ecies(), pve_keyref(ecc_pk), "test-label", curve, x)); + dylog_disable_scope_t no_log_err; + EXPECT_ER(pve.verify(pve_base_pke_ecies(), pve_keyref(ecc_pk), X, "wrong-label")); } -TEST_F(PVE, RSA_Completeness) { - crypto::rsa_prv_key_t rsa_sk; - rsa_sk.generate(2048); - crypto::rsa_pub_key_t rsa_pk(rsa_sk.pub()); +TEST(PVE, ECIES_VerifyWithWrongQ) { + const ecurve_t curve = crypto::curve_p256; + const mod_t& q = curve.order(); + const crypto::ecc_generator_point_t& G = curve.generator(); - ec_pve_t pve(pve_base_pke_rsa()); - bn_t x = bn_t::rand(q); - ecc_point_t X = x * G; + ecc_prv_key_t ecc_sk; + ecc_sk.generate(crypto::curve_p256); + ecc_pub_key_t ecc_pk(ecc_sk.pub()); - pve.encrypt(&rsa_pk, "test-label", curve, x); - EXPECT_OK(pve.verify(&rsa_pk, X, "test-label")); + ec_pve_t pve; + bn_t x = bn_t::rand(q); - bn_t decrypted_x; - EXPECT_OK(pve.decrypt(&rsa_sk, &rsa_pk, "test-label", curve, decrypted_x)); - EXPECT_EQ(x, decrypted_x); + EXPECT_OK(pve.encrypt(pve_base_pke_ecies(), pve_keyref(ecc_pk), "test-label", curve, x)); + dylog_disable_scope_t no_log_err; + EXPECT_ER(pve.verify(pve_base_pke_ecies(), pve_keyref(ecc_pk), bn_t::rand(q) * G, "test-label")); } -TEST_F(PVE, ECIES_Completeness) { - crypto::ecc_prv_key_t ecc_sk; +TEST(PVE, ECIES_DecryptWithWrongLabel) { + const ecurve_t curve = crypto::curve_p256; + const mod_t& q = curve.order(); + + ecc_prv_key_t ecc_sk; ecc_sk.generate(crypto::curve_p256); - crypto::ecc_pub_key_t ecc_pk(ecc_sk.pub()); + ecc_pub_key_t ecc_pk(ecc_sk.pub()); - ec_pve_t pve(pve_base_pke_ecies()); + ec_pve_t pve; bn_t x = bn_t::rand(q); - ecc_point_t X = x * G; - pve.encrypt(&ecc_pk, "test-label", curve, x); - EXPECT_OK(pve.verify(&ecc_pk, X, "test-label")); + EXPECT_OK(pve.encrypt(pve_base_pke_ecies(), pve_keyref(ecc_pk), "test-label", curve, x)); bn_t decrypted_x; - EXPECT_OK(pve.decrypt(&ecc_sk, &ecc_pk, "test-label", curve, decrypted_x)); - EXPECT_EQ(x, decrypted_x); + dylog_disable_scope_t no_log_err; + EXPECT_ER( + pve.decrypt(pve_base_pke_ecies(), pve_keyref(ecc_sk), pve_keyref(ecc_pk), "wrong-label", curve, decrypted_x)); + EXPECT_NE(x, decrypted_x); } -TEST_F(PVE, CustomKEM_Completeness) { - const mpc::pve_base_pke_i &base_pke = mpc::kem_pve_base_pke(); - mpc::ec_pve_t pve(base_pke); +TEST(PVE, CustomKEM_Completeness) { + const ecurve_t curve = crypto::curve_p256; + const mod_t& q = curve.order(); + const crypto::ecc_generator_point_t& G = curve.generator(); + + const mpc::pve_base_pke_i& base_pke = mpc::kem_pve_base_pke(); + mpc::ec_pve_t pve; toy_kem_policy_t::ek_t ek; toy_kem_policy_t::dk_t dk; @@ -196,83 +143,109 @@ TEST_F(PVE, CustomKEM_Completeness) { bn_t x = bn_t::rand(q); ecc_point_t X = x * G; - pve.encrypt(&ek, "test-label", curve, x); - EXPECT_OK(pve.verify(&ek, X, "test-label")); + EXPECT_OK(pve.encrypt(base_pke, pve_keyref(ek), "test-label", curve, x)); + EXPECT_OK(pve.verify(base_pke, pve_keyref(ek), X, "test-label")); bn_t decrypted_x; - EXPECT_OK(pve.decrypt(&dk, &ek, "test-label", curve, decrypted_x)); + EXPECT_OK(pve.decrypt(base_pke, pve_keyref(dk), pve_keyref(ek), "test-label", curve, decrypted_x)); EXPECT_EQ(x, decrypted_x); } -typedef PVE PVEBatch; - -TEST_F(PVEBatch, Completeness) { - int n = 20; - for (const auto &[pub_key, prv_key] : valid_unified) { - pve_batch_t pve_batch(n); - std::vector xs(n); - std::vector Xs(n); - for (int i = 0; i < n; i++) { - xs[i] = (i > n / 2) ? bn_t(i) : bn_t::rand(q); - Xs[i] = xs[i] * G; - } - - pve_batch.encrypt(&pub_key, "test-label", curve, xs); - EXPECT_OK(pve_batch.verify(&pub_key, Xs, "test-label")); - - std::vector decrypted_xs; - EXPECT_OK(pve_batch.decrypt(&prv_key, &pub_key, "test-label", curve, decrypted_xs)); - EXPECT_EQ(xs, decrypted_xs); +TEST(PVEBatch, Completeness_ECIES) { + const ecurve_t curve = crypto::curve_p256; + const mod_t& q = curve.order(); + const crypto::ecc_generator_point_t& G = curve.generator(); + + const int n = 20; + ecc_prv_key_t ecc_sk; + ecc_sk.generate(crypto::curve_p256); + ecc_pub_key_t ecc_pk(ecc_sk.pub()); + + pve_batch_t pve_batch(n); + std::vector xs(n); + std::vector Xs(n); + for (int i = 0; i < n; i++) { + xs[i] = (i > n / 2) ? bn_t(i) : bn_t::rand(q); + Xs[i] = xs[i] * G; } + + EXPECT_OK(pve_batch.encrypt(pve_base_pke_ecies(), pve_keyref(ecc_pk), "test-label", curve, xs)); + EXPECT_OK(pve_batch.verify(pve_base_pke_ecies(), pve_keyref(ecc_pk), Xs, "test-label")); + + std::vector decrypted_xs; + EXPECT_OK(pve_batch.decrypt(pve_base_pke_ecies(), pve_keyref(ecc_sk), pve_keyref(ecc_pk), "test-label", curve, + decrypted_xs)); + EXPECT_EQ(xs, decrypted_xs); } -TEST_F(PVEBatch, RejectsHugeBatchCount) { +TEST(PVEBatch, RejectsHugeBatchCount) { dylog_disable_scope_t no_log_err; EXPECT_CB_ASSERT(pve_batch_t(200000000), "batch_count"); } -TEST_F(PVEBatch, VerifyWithWrongLabel) { - for (const auto &[pub_key, prv_key] : valid_unified) { - pve_batch_t pve_batch(1); - bn_t x = bn_t::rand(q); - ecc_point_t X = x * G; +TEST(PVEBatch, VerifyWithWrongLabel_ECIES) { + const ecurve_t curve = crypto::curve_p256; + const mod_t& q = curve.order(); + const crypto::ecc_generator_point_t& G = curve.generator(); - pve_batch.encrypt(&pub_key, "test-label", curve, {x}); - dylog_disable_scope_t no_log_err; - EXPECT_ER(pve_batch.verify(&pub_key, {X}, "wrong-label")); - } + ecc_prv_key_t ecc_sk; + ecc_sk.generate(crypto::curve_p256); + ecc_pub_key_t ecc_pk(ecc_sk.pub()); + + pve_batch_t pve_batch(1); + bn_t x = bn_t::rand(q); + ecc_point_t X = x * G; + + EXPECT_OK(pve_batch.encrypt(pve_base_pke_ecies(), pve_keyref(ecc_pk), "test-label", curve, {x})); + dylog_disable_scope_t no_log_err; + EXPECT_ER(pve_batch.verify(pve_base_pke_ecies(), pve_keyref(ecc_pk), {X}, "wrong-label")); } -TEST_F(PVEBatch, VerifyWithWrongPublicKey) { - for (const auto &[pub_key, prv_key] : valid_unified) { - pve_batch_t pve_batch(1); - bn_t x = bn_t::rand(q); - ecc_point_t X = x * G; +TEST(PVEBatch, VerifyWithWrongPublicKey_ECIES) { + const ecurve_t curve = crypto::curve_p256; + const mod_t& q = curve.order(); + const crypto::ecc_generator_point_t& G = curve.generator(); - pve_batch.encrypt(&pub_key, "test-label", curve, {x}); - dylog_disable_scope_t no_log_err; - EXPECT_ER(pve_batch.verify(&pub_key, {bn_t::rand(q) * G}, "test-label")); - } + ecc_prv_key_t ecc_sk; + ecc_sk.generate(crypto::curve_p256); + ecc_pub_key_t ecc_pk(ecc_sk.pub()); + + pve_batch_t pve_batch(1); + bn_t x = bn_t::rand(q); + + EXPECT_OK(pve_batch.encrypt(pve_base_pke_ecies(), pve_keyref(ecc_pk), "test-label", curve, {x})); + dylog_disable_scope_t no_log_err; + EXPECT_ER(pve_batch.verify(pve_base_pke_ecies(), pve_keyref(ecc_pk), {bn_t::rand(q) * G}, "test-label")); } -TEST_F(PVEBatch, DecryptWithWrongLabel) { - for (const auto &[pub_key, prv_key] : valid_unified) { - pve_batch_t pve_batch(1); - std::vector xs = {bn_t::rand(q)}; +TEST(PVEBatch, DecryptWithWrongLabel_ECIES) { + const ecurve_t curve = crypto::curve_p256; + const mod_t& q = curve.order(); - pve_batch.encrypt(&pub_key, "test-label", curve, xs); + ecc_prv_key_t ecc_sk; + ecc_sk.generate(crypto::curve_p256); + ecc_pub_key_t ecc_pk(ecc_sk.pub()); - std::vector decrypted_xs; - dylog_disable_scope_t no_log_err; - EXPECT_ER(pve_batch.decrypt(&prv_key, &pub_key, "wrong-label", curve, decrypted_xs)); - EXPECT_NE(xs, decrypted_xs); - } + pve_batch_t pve_batch(1); + std::vector xs = {bn_t::rand(q)}; + + EXPECT_OK(pve_batch.encrypt(pve_base_pke_ecies(), pve_keyref(ecc_pk), "test-label", curve, xs)); + + std::vector decrypted_xs; + dylog_disable_scope_t no_log_err; + EXPECT_ER(pve_batch.decrypt(pve_base_pke_ecies(), pve_keyref(ecc_sk), pve_keyref(ecc_pk), "wrong-label", curve, + decrypted_xs)); + EXPECT_NE(xs, decrypted_xs); } -TEST_F(PVEBatch, CustomKEM_Batch_Completeness) { - const mpc::pve_base_pke_i &base_pke = mpc::kem_pve_base_pke(); - int n = 8; - mpc::ec_pve_batch_t pve_batch(n, base_pke); +TEST(PVEBatch, CustomKEM_Batch_Completeness) { + const ecurve_t curve = crypto::curve_p256; + const mod_t& q = curve.order(); + const crypto::ecc_generator_point_t& G = curve.generator(); + + const mpc::pve_base_pke_i& base_pke = mpc::kem_pve_base_pke(); + const int n = 8; + mpc::ec_pve_batch_t pve_batch(n); toy_kem_policy_t::ek_t ek; toy_kem_policy_t::dk_t dk; @@ -284,11 +257,11 @@ TEST_F(PVEBatch, CustomKEM_Batch_Completeness) { Xs[i] = xs[i] * G; } - pve_batch.encrypt(&ek, "test-label", curve, xs); - EXPECT_OK(pve_batch.verify(&ek, Xs, "test-label")); + EXPECT_OK(pve_batch.encrypt(base_pke, pve_keyref(ek), "test-label", curve, xs)); + EXPECT_OK(pve_batch.verify(base_pke, pve_keyref(ek), Xs, "test-label")); std::vector decrypted_xs; - EXPECT_OK(pve_batch.decrypt(&dk, &ek, "test-label", curve, decrypted_xs)); + EXPECT_OK(pve_batch.decrypt(base_pke, pve_keyref(dk), pve_keyref(ek), "test-label", curve, decrypted_xs)); EXPECT_EQ(xs, decrypted_xs); } diff --git a/tests/unit/protocol/test_pve_ac.cpp b/tests/unit/protocol/test_pve_ac.cpp index a3ce48d8..7e1c11ed 100644 --- a/tests/unit/protocol/test_pve_ac.cpp +++ b/tests/unit/protocol/test_pve_ac.cpp @@ -1,9 +1,9 @@ #include -#include -#include -#include -#include +#include +#include +#include +#include #include "utils/data/ac.h" #include "utils/test_macros.h" @@ -27,12 +27,6 @@ class PVEAC : public testutils::TestAC { ecurve_t curve; mod_t q; ecc_generator_point_t G; - crypto::prv_key_t get_prv_key(int participant_index) const { - if (participant_index & 1) - return crypto::prv_key_t::from(get_ecc_prv_key(participant_index)); - else - return crypto::prv_key_t::from(get_rsa_prv_key(participant_index)); - } crypto::ecc_prv_key_t get_ecc_prv_key(int participant_index) const { crypto::ecc_prv_key_t prv_key_ecc; @@ -47,19 +41,19 @@ class PVEAC : public testutils::TestAC { } }; -TEST_F(PVEAC, PKI) { +TEST_F(PVEAC, ECC) { error_t rv = UNINITIALIZED_ERROR; ss::ac_t ac(test_root); auto leaves = ac.list_leaf_names(); - std::map pub_keys_val; - std::map prv_keys_val; + std::map pub_keys_val; + std::map prv_keys_val; ec_pve_ac_t::pks_t pub_keys; ec_pve_ac_t::sks_t prv_keys; int participant_index = 0; for (auto path : leaves) { - auto prv_key = get_prv_key(participant_index); + auto prv_key = get_ecc_prv_key(participant_index); if (!ac.enough_for_quorum(pub_keys_val)) { prv_keys_val[path] = prv_key; } @@ -67,8 +61,8 @@ TEST_F(PVEAC, PKI) { participant_index++; } - for (auto &kv : pub_keys_val) pub_keys[kv.first] = &kv.second; - for (auto &kv : prv_keys_val) prv_keys[kv.first] = &kv.second; + for (auto& kv : pub_keys_val) pub_keys[kv.first] = pve_keyref(kv.second); + for (auto& kv : prv_keys_val) prv_keys[kv.first] = pve_keyref(kv.second); const int n = 20; ec_pve_ac_t pve; @@ -80,37 +74,39 @@ TEST_F(PVEAC, PKI) { } std::string label = "test-label"; - pve.encrypt(ac, pub_keys, label, curve, xs); - rv = pve.verify(ac, pub_keys, Xs, label); + rv = pve.encrypt(pve_base_pke_ecies(), ac, pub_keys, label, curve, xs); + ASSERT_EQ(rv, 0); + rv = pve.verify(pve_base_pke_ecies(), ac, pub_keys, Xs, label); EXPECT_EQ(rv, 0); int row_index = 0; crypto::ss::party_map_t shares; - for (auto &[path, prv_key] : prv_keys) { + for (auto& [path, prv_key] : prv_keys) { bn_t share; - rv = pve.party_decrypt_row(ac, row_index, path, prv_key, label, share); + rv = pve.party_decrypt_row(pve_base_pke_ecies(), ac, row_index, path, prv_key, label, share); ASSERT_EQ(rv, 0); shares[path] = share; } std::vector decrypted_xs; - rv = pve.aggregate_to_restore_row(ac, row_index, label, shares, decrypted_xs, /*skip_verify=*/true); + rv = pve.aggregate_to_restore_row(pve_base_pke_ecies(), ac, row_index, label, shares, decrypted_xs, + /*skip_verify=*/true); ASSERT_EQ(rv, 0); EXPECT_TRUE(xs == decrypted_xs); } -TEST_F(PVEAC, ECC) { +TEST_F(PVEAC, RSA) { error_t rv = UNINITIALIZED_ERROR; ss::ac_t ac(test_root); auto leaves = ac.list_leaf_names(); - std::map pub_keys_val; - std::map prv_keys_val; + std::map pub_keys_val; + std::map prv_keys_val; ec_pve_ac_t::pks_t pub_keys; ec_pve_ac_t::sks_t prv_keys; int participant_index = 0; for (auto path : leaves) { - auto prv_key = get_ecc_prv_key(participant_index); + auto prv_key = get_rsa_prv_key(participant_index); if (!ac.enough_for_quorum(pub_keys_val)) { prv_keys_val[path] = prv_key; } @@ -118,11 +114,11 @@ TEST_F(PVEAC, ECC) { participant_index++; } - for (auto &kv : pub_keys_val) pub_keys[kv.first] = &kv.second; - for (auto &kv : prv_keys_val) prv_keys[kv.first] = &kv.second; + for (auto& kv : pub_keys_val) pub_keys[kv.first] = pve_keyref(kv.second); + for (auto& kv : prv_keys_val) prv_keys[kv.first] = pve_keyref(kv.second); const int n = 20; - ec_pve_ac_t pve(pve_base_pke_ecies()); + ec_pve_ac_t pve; std::vector xs(n); std::vector Xs(n); for (int i = 0; i < n; i++) { @@ -131,37 +127,39 @@ TEST_F(PVEAC, ECC) { } std::string label = "test-label"; - pve.encrypt(ac, pub_keys, label, curve, xs); - rv = pve.verify(ac, pub_keys, Xs, label); + rv = pve.encrypt(pve_base_pke_rsa(), ac, pub_keys, label, curve, xs); + ASSERT_EQ(rv, 0); + rv = pve.verify(pve_base_pke_rsa(), ac, pub_keys, Xs, label); EXPECT_EQ(rv, 0); int row_index = 0; crypto::ss::party_map_t shares; - for (auto &[path, prv_key] : prv_keys) { + for (auto& [path, prv_key] : prv_keys) { bn_t share; - rv = pve.party_decrypt_row(ac, row_index, path, prv_key, label, share); + rv = pve.party_decrypt_row(pve_base_pke_rsa(), ac, row_index, path, prv_key, label, share); ASSERT_EQ(rv, 0); shares[path] = share; } std::vector decrypted_xs; - rv = pve.aggregate_to_restore_row(ac, row_index, label, shares, decrypted_xs, /*skip_verify=*/true); + rv = pve.aggregate_to_restore_row(pve_base_pke_rsa(), ac, row_index, label, shares, decrypted_xs, + /*skip_verify=*/true); ASSERT_EQ(rv, 0); EXPECT_TRUE(xs == decrypted_xs); } -TEST_F(PVEAC, RSA) { +TEST_F(PVEAC, AggFail_VerifyOn_MissingLeafPub) { error_t rv = UNINITIALIZED_ERROR; ss::ac_t ac(test_root); auto leaves = ac.list_leaf_names(); - std::map pub_keys_val; - std::map prv_keys_val; + std::map pub_keys_val; + std::map prv_keys_val; ec_pve_ac_t::pks_t pub_keys; ec_pve_ac_t::sks_t prv_keys; int participant_index = 0; - for (auto path : leaves) { - auto prv_key = get_rsa_prv_key(participant_index); + for (const auto& path : leaves) { + auto prv_key = get_ecc_prv_key(participant_index); if (!ac.enough_for_quorum(pub_keys_val)) { prv_keys_val[path] = prv_key; } @@ -169,35 +167,29 @@ TEST_F(PVEAC, RSA) { participant_index++; } - for (auto &kv : pub_keys_val) pub_keys[kv.first] = &kv.second; - for (auto &kv : prv_keys_val) prv_keys[kv.first] = &kv.second; + for (const auto& kv : pub_keys_val) pub_keys[kv.first] = pve_keyref(kv.second); + for (const auto& kv : prv_keys_val) prv_keys[kv.first] = pve_keyref(kv.second); - const int n = 20; - ec_pve_ac_t pve(pve_base_pke_rsa()); + const int n = 8; + ec_pve_ac_t pve; std::vector xs(n); - std::vector Xs(n); - for (int i = 0; i < n; i++) { - xs[i] = bn_t::rand(q); - Xs[i] = xs[i] * G; - } + for (int i = 0; i < n; i++) xs[i] = bn_t::rand(q); std::string label = "test-label"; - pve.encrypt(ac, pub_keys, label, curve, xs); - rv = pve.verify(ac, pub_keys, Xs, label); - EXPECT_EQ(rv, 0); + ASSERT_OK(pve.encrypt(pve_base_pke_ecies(), ac, pub_keys, label, curve, xs)); int row_index = 0; crypto::ss::party_map_t shares; - for (auto &[path, prv_key] : prv_keys) { + for (auto& [path, prv_key] : prv_keys) { bn_t share; - rv = pve.party_decrypt_row(ac, row_index, path, prv_key, label, share); - ASSERT_EQ(rv, 0); + ASSERT_OK(pve.party_decrypt_row(pve_base_pke_ecies(), ac, row_index, path, prv_key, label, share)); shares[path] = share; } + std::vector decrypted_xs; - rv = pve.aggregate_to_restore_row(ac, row_index, label, shares, decrypted_xs, /*skip_verify=*/true); - ASSERT_EQ(rv, 0); - EXPECT_TRUE(xs == decrypted_xs); + EXPECT_ER(pve.aggregate_to_restore_row(pve_base_pke_ecies(), ac, row_index, label, shares, decrypted_xs, + /*skip_verify=*/false, + /*all_ac_pks=*/ec_pve_ac_t::pks_t{})); } } // namespace diff --git a/tests/unit/protocol/test_schnorr_2p.cpp b/tests/unit/protocol/test_schnorr_2p.cpp index a5bf2a16..52c40292 100644 --- a/tests/unit/protocol/test_schnorr_2p.cpp +++ b/tests/unit/protocol/test_schnorr_2p.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include #include "utils/local_network/mpc_tester.h" @@ -105,51 +105,4 @@ TEST_F(BIP340_2PC, KeygenSignRefreshSign) { check_key_pair(new_keys[0], new_keys[1]); } -TEST_F(EdDSA2PC, ParallelKSRS8) { - int parallel_count = 8; - std::vector> data(30); - std::vector> data_buf(30); - for (int i = 0; i < parallel_count; i++) { - int len = i + 1; - data[i].resize(len); - data_buf[i].resize(len); - for (int j = 0; j < len; j++) data[i][j] = data_buf[i][j] = crypto::gen_random(32); - } - std::vector> keys(parallel_count, std::vector(2)); - std::vector> new_keys(parallel_count, std::vector(2)); - - mpc_runner->run_2pc_parallel(parallel_count, [&data, &keys, &new_keys](job_parallel_2p_t& job, int th_i) { - error_t rv = UNINITIALIZED_ERROR; - auto party_index = job.get_party_idx(); - ecurve_t curve = crypto::curve_ed25519; - - eddsa2pc::key_t& key = keys[th_i][party_index]; - buf_t sid; - rv = eckey::key_share_2p_t::dkg(job, curve, key, sid); - ASSERT_EQ(rv, 0); - - std::vector sig_bufs; - rv = eddsa2pc::sign_batch(job, key, data[th_i], sig_bufs); - ASSERT_EQ(rv, 0); - - eddsa2pc::key_t& new_key = new_keys[th_i][party_index]; - rv = eckey::key_share_2p_t::refresh(job, key, new_key); - ASSERT_EQ(rv, 0); - - EXPECT_EQ(new_key.role, key.role); - EXPECT_EQ(new_key.curve, key.curve); - EXPECT_EQ(new_key.Q, key.Q); - EXPECT_NE(new_key.x_share, key.x_share); - - std::vector new_sig_bufs; - rv = eddsa2pc::sign_batch(job, new_key, data[th_i], new_sig_bufs); - ASSERT_EQ(rv, 0); - }); - - for (int i = 0; i < parallel_count; i++) { - check_key_pair(keys[i][0], keys[i][1]); - check_key_pair(new_keys[i][0], new_keys[i][1]); - } -} - } // namespace \ No newline at end of file diff --git a/tests/unit/protocol/test_schnorr_mp.cpp b/tests/unit/protocol/test_schnorr_mp.cpp index 5f527387..2c4631a3 100644 --- a/tests/unit/protocol/test_schnorr_mp.cpp +++ b/tests/unit/protocol/test_schnorr_mp.cpp @@ -1,6 +1,6 @@ #include -#include +#include #include "utils/local_network/mpc_tester.h" @@ -110,46 +110,4 @@ TEST_F(BIP340_4PC, KeygenSignRefreshSign) { check_keys(new_keys); } -TEST_F(EdDSA_4PC, ParallelKSRS8) { - int parallel_count = 8; - std::vector> data(parallel_count); - for (int i = 0; i < parallel_count; i++) { - int len = i + 1; - data[i].resize(len); - for (int j = 0; j < len; j++) data[i][j] = crypto::gen_random(32); - } - std::vector> keys(parallel_count, std::vector(4)); - std::vector> new_keys(parallel_count, std::vector(4)); - - mpc_runner->run_mpc_parallel(parallel_count, [&keys, &new_keys, &data](job_parallel_mp_t& job, int th_i) { - error_t rv = UNINITIALIZED_ERROR; - auto party_index = job.get_party_idx(); - eddsampc::key_t& key = keys[th_i][party_index]; - ecurve_t curve = crypto::curve_ed25519; - - buf_t sid; - rv = eckey::key_share_mp_t::dkg(job, curve, key, sid); - ASSERT_EQ(rv, 0); - - std::vector sig_buf; - rv = eddsampc::sign_batch(job, key, buf_t::to_mems(data[th_i]), party_idx_t(0), sig_buf); - ASSERT_EQ(rv, 0); - - eddsampc::key_t& new_key = new_keys[th_i][party_index]; - rv = eckey::key_share_mp_t::refresh(job, sid, key, new_key); - ASSERT_EQ(rv, 0); - EXPECT_EQ(new_key.Q, key.Q); - EXPECT_NE(new_key.x_share, key.x_share); - - std::vector new_sig_buf; - rv = eddsampc::sign_batch(job, new_key, buf_t::to_mems(data[th_i]), party_idx_t(0), new_sig_buf); - ASSERT_EQ(rv, 0); - }); - - for (int i = 0; i < parallel_count; i++) { - check_keys(keys[i]); - check_keys(new_keys[i]); - } -} - } // namespace diff --git a/tests/unit/protocol/test_util.cpp b/tests/unit/protocol/test_util.cpp index 7c61c393..f3c4d866 100644 --- a/tests/unit/protocol/test_util.cpp +++ b/tests/unit/protocol/test_util.cpp @@ -1,6 +1,6 @@ #include -#include "cbmpc/protocol/util.h" +#include using namespace coinbase::crypto; diff --git a/tests/unit/zk/test_zk.cpp b/tests/unit/zk/test_zk.cpp index 79f5e040..bd385cba 100644 --- a/tests/unit/zk/test_zk.cpp +++ b/tests/unit/zk/test_zk.cpp @@ -1,6 +1,6 @@ #include -#include +#include #include "utils/data/zk_completeness.h" #include "utils/test_macros.h" diff --git a/tests/utils/crypto/nizk.h b/tests/utils/crypto/nizk.h index 0da0bae2..990c4ccc 100644 --- a/tests/utils/crypto/nizk.h +++ b/tests/utils/crypto/nizk.h @@ -1,5 +1,5 @@ #pragma once -#include +#include struct test_nizk_t { uint64_t aux = 0; diff --git a/tests/utils/data/ac.h b/tests/utils/data/ac.h index 8c6d0230..62d954f8 100644 --- a/tests/utils/data/ac.h +++ b/tests/utils/data/ac.h @@ -2,7 +2,7 @@ #include -#include +#include #include "test_node.h" diff --git a/tests/utils/data/mpc_data_generator.h b/tests/utils/data/mpc_data_generator.h index 6fe926c2..3a4011fb 100644 --- a/tests/utils/data/mpc_data_generator.h +++ b/tests/utils/data/mpc_data_generator.h @@ -1,6 +1,6 @@ #pragma once -#include -#include +#include +#include #include "data/data_generator.h" diff --git a/tests/utils/data/sampler/base.h b/tests/utils/data/sampler/base.h index 70715c83..494c8c90 100644 --- a/tests/utils/data/sampler/base.h +++ b/tests/utils/data/sampler/base.h @@ -3,9 +3,9 @@ #include #include -#include -#include -#include +#include +#include +#include namespace coinbase::test { template diff --git a/tests/utils/data/sampler/bn.h b/tests/utils/data/sampler/bn.h index 4bb78063..6e32ac1a 100644 --- a/tests/utils/data/sampler/bn.h +++ b/tests/utils/data/sampler/bn.h @@ -2,10 +2,10 @@ #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include "base.h" diff --git a/tests/utils/data/sampler/buf.cpp b/tests/utils/data/sampler/buf.cpp index f68af446..7667bf52 100644 --- a/tests/utils/data/sampler/buf.cpp +++ b/tests/utils/data/sampler/buf.cpp @@ -2,8 +2,9 @@ using namespace coinbase::test; -buf_t buf_sampler_t::sample(const buf_distribution_t& dist, const std::vector& dist_dependencies) { - buf_t a; +coinbase::buf_t buf_sampler_t::sample(const buf_distribution_t& dist, + const std::vector& dist_dependencies) { + coinbase::buf_t a; switch (dist) { case buf_distribution_t::RANDOM_32BYTES_0: { a = crypto::gen_random(32); @@ -14,7 +15,7 @@ buf_t buf_sampler_t::sample(const buf_distribution_t& dist, const std::vector(dist_dependencies[0]); + a = std::get(dist_dependencies[0]); break; default: break; @@ -22,11 +23,11 @@ buf_t buf_sampler_t::sample(const buf_distribution_t& dist, const std::vector& filter_dependencies) { switch (filter) { case buf_filter_t::NOT_SAME_AS_1: - return a != std::get(filter_dependencies[0]); + return a != std::get(filter_dependencies[0]); break; default: break; diff --git a/tests/utils/data/sampler/buf.h b/tests/utils/data/sampler/buf.h index bd2f1158..c4a4941a 100644 --- a/tests/utils/data/sampler/buf.h +++ b/tests/utils/data/sampler/buf.h @@ -2,8 +2,8 @@ #include #include -#include -#include +#include +#include #include "base.h" diff --git a/tests/utils/data/sampler/ecp.h b/tests/utils/data/sampler/ecp.h index 9af1ea17..85dfa9c6 100644 --- a/tests/utils/data/sampler/ecp.h +++ b/tests/utils/data/sampler/ecp.h @@ -2,9 +2,9 @@ #include #include -#include -#include -#include +#include +#include +#include #include "base.h" diff --git a/tests/utils/data/sampler/elgamal.h b/tests/utils/data/sampler/elgamal.h index ee6113b1..252f1e49 100644 --- a/tests/utils/data/sampler/elgamal.h +++ b/tests/utils/data/sampler/elgamal.h @@ -2,8 +2,8 @@ #include #include -#include -#include +#include +#include #include "base.h" diff --git a/tests/utils/data/sampler/paillier.h b/tests/utils/data/sampler/paillier.h index d44ad228..468d479c 100644 --- a/tests/utils/data/sampler/paillier.h +++ b/tests/utils/data/sampler/paillier.h @@ -2,8 +2,8 @@ #include #include -#include -#include +#include +#include #include "base.h" diff --git a/tests/utils/data/tdh2.h b/tests/utils/data/tdh2.h index 7d7209ff..d149b2c2 100644 --- a/tests/utils/data/tdh2.h +++ b/tests/utils/data/tdh2.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include using namespace coinbase::crypto; @@ -20,13 +20,13 @@ void generate_additive_shares(int n, tdh2::public_key_t& enc_key, tdh2::pub_shar for (int i = 0; i < n; i++) { pub_shares[i] = prv_shares[i] * G; } - enc_key.Q = x * G; - enc_key.Gamma = ro::hash_curve(mem_t("TDH2-Gamma"), enc_key.Q).curve(curve); + const buf_t sid = gen_random(32); + enc_key = tdh2::public_key_t(x * G, sid); dec_shares.resize(n); for (int i = 0; i < n; i++) { dec_shares[i].x = prv_shares[i]; - dec_shares[i].pid = i + 1; + dec_shares[i].rid = i + 1; dec_shares[i].pub_key = enc_key; } } @@ -37,16 +37,17 @@ void generate_ac_shares(const ss::ac_t& ac, tdh2::public_key_t& enc_key, ss::ac_ const mod_t& q = curve.order(); bn_t x = curve.get_random_value(); - enc_key.Q = x * G; - enc_key.Gamma = ro::hash_curve(mem_t("TDH2-Gamma"), enc_key.Q).curve(curve); + const buf_t sid = gen_random(32); + enc_key = tdh2::public_key_t(x * G, sid); ss::ac_shares_t prv_shares = ac.share(q, x); pub_shares.clear(); dec_shares.clear(); + int rid = 1; for (const auto& [name, xi] : prv_shares) { pub_shares[name] = xi * G; dec_shares[name].x = xi; - dec_shares[name].pid = ss::node_t::pid_from_path(name); + dec_shares[name].rid = rid++; dec_shares[name].pub_key = enc_key; } } diff --git a/tests/utils/data/test_data_factory.h b/tests/utils/data/test_data_factory.h index 7c189a38..f153ccef 100644 --- a/tests/utils/data/test_data_factory.h +++ b/tests/utils/data/test_data_factory.h @@ -4,12 +4,12 @@ #include #include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include #include "sampler/bn.h" #include "sampler/buf.h" diff --git a/tests/utils/data/test_node.h b/tests/utils/data/test_node.h index b0f9e7a1..b728ec2e 100644 --- a/tests/utils/data/test_node.h +++ b/tests/utils/data/test_node.h @@ -1,6 +1,6 @@ #pragma once -#include +#include namespace coinbase::testutils { diff --git a/tests/utils/data/zk_completeness.h b/tests/utils/data/zk_completeness.h index d4a58ca5..0a45b085 100644 --- a/tests/utils/data/zk_completeness.h +++ b/tests/utils/data/zk_completeness.h @@ -1,11 +1,11 @@ #pragma once -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include namespace coinbase::test::data { diff --git a/tests/utils/data/zk_data_generator.h b/tests/utils/data/zk_data_generator.h index 1d229430..e82f0bbd 100644 --- a/tests/utils/data/zk_data_generator.h +++ b/tests/utils/data/zk_data_generator.h @@ -1,6 +1,6 @@ #pragma once -#include -#include +#include +#include #include "data/data_generator.h" diff --git a/tests/utils/local_network/channel.h b/tests/utils/local_network/channel.h index b42897f4..bae63b82 100644 --- a/tests/utils/local_network/channel.h +++ b/tests/utils/local_network/channel.h @@ -1,5 +1,5 @@ #pragma once -#include +#include namespace coinbase::testutils { diff --git a/tests/utils/local_network/mpc_runner.cpp b/tests/utils/local_network/mpc_runner.cpp index 1b9dcb66..5f48ab8d 100644 --- a/tests/utils/local_network/mpc_runner.cpp +++ b/tests/utils/local_network/mpc_runner.cpp @@ -1,6 +1,6 @@ #include "mpc_runner.h" -#include +#include using namespace coinbase::mpc; @@ -127,47 +127,6 @@ void mpc_runner_t::run_mpc(lambda_mp_t f) { run_mpc_role([&](party_idx_t party_index) { f(*job_mps[party_index]); }); } -void mpc_runner_t::run_2pc_parallel_helper(std::shared_ptr network, party_t role, int th_i, - lambda_2p_parallel_t f) { - parallel_id_t parallel_id = th_i; - job_parallel_2p_t job(role, test_pnames[0], test_pnames[1], network, parallel_id); - f(job, th_i); -} - -void mpc_runner_t::run_2pc_parallel(int n_threads, lambda_2p_parallel_t f) { - run_mpc_role([&](party_idx_t party_index) { - std::shared_ptr network = - std::make_shared(get_data_transport_ptr(party_index), n_threads); - - std::vector threads; - for (int th_i = 0; th_i < n_threads; th_i++) { - threads.emplace_back(run_2pc_parallel_helper, network, party_t(party_index), th_i, f); - } - for (auto& th : threads) th.join(); - }); -} - -void mpc_runner_t::run_mpc_parallel_helper(int n, std::shared_ptr network, - party_idx_t party_index, int th_i, lambda_mp_parallel_t f) { - parallel_id_t parallel_id = th_i; - std::vector pnames(test_pnames.begin(), test_pnames.begin() + n); - job_parallel_mp_t job(party_index, pnames, network, parallel_id); - f(job, th_i); -} - -void mpc_runner_t::run_mpc_parallel(int n_threads, lambda_mp_parallel_t f) { - run_mpc_role([&](party_idx_t party_index) { - std::shared_ptr network = - std::make_shared(get_data_transport_ptr(party_index), n_threads); - - std::vector threads; - for (int th_i = 0; th_i < n_threads; th_i++) { - threads.emplace_back(run_mpc_parallel_helper, n, network, party_index, th_i, f); - } - for (auto& th : threads) th.join(); - }); -} - std::shared_ptr mpc_runner_t::get_data_transport_ptr(party_idx_t role) { return data_transports[role]; } diff --git a/tests/utils/local_network/mpc_runner.h b/tests/utils/local_network/mpc_runner.h index 6a37d6ff..db9aa24d 100644 --- a/tests/utils/local_network/mpc_runner.h +++ b/tests/utils/local_network/mpc_runner.h @@ -10,10 +10,8 @@ class partner_t; typedef std::function lambda_role_t; typedef std::function lambda_2p_t; typedef std::function lambda_mp_t; -typedef std::function lambda_2p_parallel_t; -typedef std::function lambda_mp_parallel_t; -class local_data_transport_t : public mpc::data_transport_interface_t { +class local_data_transport_t : public coinbase::api::data_transport_i { public: local_data_transport_t(const std::shared_ptr& nc_ptr) : net_context_ptr(nc_ptr) {} error_t send(const mpc::party_idx_t receiver, mem_t msg) override; @@ -42,8 +40,6 @@ class mpc_runner_t { void run_2pc(lambda_2p_t f); void run_mpc(lambda_mp_t f); - void run_2pc_parallel(int n_threads, lambda_2p_parallel_t f); - void run_mpc_parallel(int n_threads, lambda_mp_parallel_t f); // In-class declaration (no initializer): static const std::vector test_pnames; @@ -65,10 +61,6 @@ class mpc_runner_t { void set_new_network_mp(); void run_mpc_role(lambda_role_t f); - static void run_2pc_parallel_helper(std::shared_ptr network, mpc::party_t role, - int th_i, lambda_2p_parallel_t f); - static void run_mpc_parallel_helper(int n, std::shared_ptr network, - mpc::party_idx_t party_index, int th_i, lambda_mp_parallel_t f); }; // namespace coinbase::testutils } // namespace coinbase::testutils \ No newline at end of file diff --git a/tests/utils/local_network/network_context.h b/tests/utils/local_network/network_context.h index e3931727..3ca64b65 100644 --- a/tests/utils/local_network/network_context.h +++ b/tests/utils/local_network/network_context.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include "channel.h" diff --git a/tools/benchmark/CMakeLists.txt b/tools/benchmark/CMakeLists.txt index 1c721b9d..f01ddd37 100644 --- a/tools/benchmark/CMakeLists.txt +++ b/tools/benchmark/CMakeLists.txt @@ -21,7 +21,23 @@ set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "Disable benchmark tests" FORCE) add_subdirectory(${REPO_ROOT}/vendors/google-benchmark ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/.benchmark) -set(CBMPC_SOURCE_DIR /usr/local/opt/cbmpc/) +if(NOT DEFINED CBMPC_SOURCE_DIR) + if(DEFINED ENV{CBMPC_PREFIX}) + set(CBMPC_SOURCE_DIR "$ENV{CBMPC_PREFIX}") + elseif(DEFINED ENV{CBMPC_PREFIX_FULL}) + set(CBMPC_SOURCE_DIR "$ENV{CBMPC_PREFIX_FULL}") + else() + # Default to the repo-local full install prefix (no sudo). + set(CBMPC_SOURCE_DIR "${REPO_ROOT}/build/install/full") + endif() +endif() + +set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib") +if(EXISTS "${CBMPC_SOURCE_DIR}/lib/Release/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Release") +elseif(EXISTS "${CBMPC_SOURCE_DIR}/lib/Debug/libcbmpc.a") + set(CBMPC_LIB_DIR "${CBMPC_SOURCE_DIR}/lib/Debug") +endif() add_executable( cbmpc_benchmark @@ -45,14 +61,26 @@ add_executable( # bm_test.cpp ) -target_include_directories(cbmpc_benchmark PRIVATE ${REPO_ROOT}/tests/utils) +add_subdirectory(${REPO_ROOT}/tests/utils ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/.cbmpc_test_util) +target_include_directories(cbmpc_test_util PRIVATE ${CBMPC_SOURCE_DIR}/include) +if(EXISTS "${CBMPC_SOURCE_DIR}/include-internal") + target_include_directories(cbmpc_test_util PRIVATE ${CBMPC_SOURCE_DIR}/include-internal) +endif() +# `cbmpc_test_util` (local_network runner) historically used the parallel transport +# demo helpers. Keep the implementation linked here so clean Linux builds don't +# depend on core library internals. +target_sources(cbmpc_test_util PRIVATE ${REPO_ROOT}/demo-cpp/common/mpc_job_session.cpp) + target_include_directories(cbmpc_benchmark PRIVATE ${CBMPC_SOURCE_DIR}/include) +if(EXISTS "${CBMPC_SOURCE_DIR}/include-internal") + target_include_directories(cbmpc_benchmark PRIVATE ${CBMPC_SOURCE_DIR}/include-internal) +endif() +target_include_directories(cbmpc_benchmark PRIVATE ${REPO_ROOT}/demo-cpp/common) -target_link_directories(cbmpc_benchmark PUBLIC ${CBMPC_SOURCE_DIR}/lib) +target_link_directories(cbmpc_benchmark PUBLIC ${CBMPC_LIB_DIR}) link_openssl(cbmpc_benchmark) target_link_libraries(cbmpc_benchmark PUBLIC benchmark::benchmark PRIVATE cbmpc) -target_link_libraries(cbmpc_benchmark - PRIVATE ${REPO_ROOT}/lib/Release/libcbmpc_test_util.a) +target_link_libraries(cbmpc_benchmark PRIVATE cbmpc_test_util) if(IS_LINUX) link_openssl(cbmpc_benchmark) diff --git a/tools/benchmark/benchmark.makefile b/tools/benchmark/benchmark.makefile index 98cdff08..ff0629f6 100644 --- a/tools/benchmark/benchmark.makefile +++ b/tools/benchmark/benchmark.makefile @@ -1,12 +1,12 @@ home:=$(shell pwd) -exe:="build/cbmpc_benchmark" +exe:=build/$(BUILD_TYPE)/cbmpc_benchmark .PHONY: benchmark-build benchmark-build: ${RUN_CMD} 'cd $(home)/tools/benchmark && \ - cmake -Bbuild -DBENCHMARK_DOWNLOAD_DEPENDENCIES=ON -DCMAKE_BUILD_TYPE=Release && \ - cmake --build build/ -- -j$(NCORES)' + cmake -Bbuild/$(BUILD_TYPE) -DBENCHMARK_DOWNLOAD_DEPENDENCIES=ON -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) -DCBMPC_SOURCE_DIR="$(CBMPC_PREFIX_FULL)" && \ + cmake --build build/$(BUILD_TYPE)/ -- -j$(CMAKE_NCORES)' .PHONY: benchmark-run # e.g. make benchmark-run unit=us filter=ZK benchmark-run: diff --git a/tools/benchmark/bm_agree_random.cpp b/tools/benchmark/bm_agree_random.cpp index 3c1dfa95..8b227ef9 100644 --- a/tools/benchmark/bm_agree_random.cpp +++ b/tools/benchmark/bm_agree_random.cpp @@ -2,8 +2,8 @@ #include -#include -#include +#include +#include #include "mpc_util.h" diff --git a/tools/benchmark/bm_commitment.cpp b/tools/benchmark/bm_commitment.cpp index 6141fe45..0259b5c9 100644 --- a/tools/benchmark/bm_commitment.cpp +++ b/tools/benchmark/bm_commitment.cpp @@ -1,8 +1,8 @@ #include -#include -#include -#include +#include +#include +#include #include "util.h" diff --git a/tools/benchmark/bm_core_bn.cpp b/tools/benchmark/bm_core_bn.cpp index 62f9d1a4..315d6f58 100644 --- a/tools/benchmark/bm_core_bn.cpp +++ b/tools/benchmark/bm_core_bn.cpp @@ -1,7 +1,7 @@ #include -#include +#include #define bit_len_lb 1 << 8 #define bit_len_ub 1 << 12 diff --git a/tools/benchmark/bm_drbg.cpp b/tools/benchmark/bm_drbg.cpp index a3208635..a70d7bc3 100644 --- a/tools/benchmark/bm_drbg.cpp +++ b/tools/benchmark/bm_drbg.cpp @@ -1,10 +1,11 @@ #include -#include -#include +#include +#include #include "util.h" +using namespace coinbase; using namespace coinbase::crypto; static void DRBG_String(benchmark::State& state) diff --git a/tools/benchmark/bm_ecdsa.cpp b/tools/benchmark/bm_ecdsa.cpp index f89bbb9c..59a1620c 100644 --- a/tools/benchmark/bm_ecdsa.cpp +++ b/tools/benchmark/bm_ecdsa.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include #include "local_network/mpc_runner.h" #include "mpc_util.h" diff --git a/tools/benchmark/bm_eddsa.cpp b/tools/benchmark/bm_eddsa.cpp index c6409e91..6d5e34ae 100644 --- a/tools/benchmark/bm_eddsa.cpp +++ b/tools/benchmark/bm_eddsa.cpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include "mpc_util.h" diff --git a/tools/benchmark/bm_elgamal.cpp b/tools/benchmark/bm_elgamal.cpp index 68616384..298c19a5 100644 --- a/tools/benchmark/bm_elgamal.cpp +++ b/tools/benchmark/bm_elgamal.cpp @@ -1,6 +1,6 @@ #include -#include +#include #include "util.h" diff --git a/tools/benchmark/bm_elliptic_curve.cpp b/tools/benchmark/bm_elliptic_curve.cpp index 89905fb2..4a7c8642 100644 --- a/tools/benchmark/bm_elliptic_curve.cpp +++ b/tools/benchmark/bm_elliptic_curve.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include using namespace coinbase::crypto; diff --git a/tools/benchmark/bm_hash.cpp b/tools/benchmark/bm_hash.cpp index fde88838..990254e8 100644 --- a/tools/benchmark/bm_hash.cpp +++ b/tools/benchmark/bm_hash.cpp @@ -1,8 +1,9 @@ #include -#include -#include +#include +#include +using namespace coinbase; using namespace coinbase::crypto; static void BM_SHA256(benchmark::State& state) { diff --git a/tools/benchmark/bm_ot.cpp b/tools/benchmark/bm_ot.cpp index 2ecc6824..975dd6e0 100644 --- a/tools/benchmark/bm_ot.cpp +++ b/tools/benchmark/bm_ot.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include #define base_ot_m_lb 1 << 6 #define base_ot_m_ub 1 << 11 diff --git a/tools/benchmark/bm_paillier.cpp b/tools/benchmark/bm_paillier.cpp index 9d50eaa1..f6087759 100644 --- a/tools/benchmark/bm_paillier.cpp +++ b/tools/benchmark/bm_paillier.cpp @@ -1,6 +1,6 @@ #include -#include +#include static void BM_Paillier_Gen(benchmark::State& state) { diff --git a/tools/benchmark/bm_pve.cpp b/tools/benchmark/bm_pve.cpp index 6f74225e..7c2bfd8e 100644 --- a/tools/benchmark/bm_pve.cpp +++ b/tools/benchmark/bm_pve.cpp @@ -1,9 +1,9 @@ #include -#include -#include -#include -#include +#include +#include +#include +#include #include "data/test_node.h" #include "util.h" @@ -11,77 +11,117 @@ using namespace coinbase; using namespace coinbase::mpc; +namespace { +constexpr int RSA_KEY_BITS = 2048; +} // namespace + static void BM_PVE_Encrypt(benchmark::State& state) { ec_pve_t pve; - crypto::pub_key_t pub_key; + const auto& curve = crypto::curve_p256; + const mod_t q = curve.order(); + bn_t x = bn_t::rand(q); + + const pve_base_pke_i* base_pke = nullptr; + pve_keyref_t ek; + + crypto::rsa_prv_key_t rsa_prv_key; + crypto::rsa_pub_key_t rsa_pub_key; + crypto::ecc_prv_key_t ecc_prv_key; + crypto::ecc_pub_key_t ecc_pub_key; + if (state.range(0) == 0) { - crypto::rsa_prv_key_t rsa_prv_key; - rsa_prv_key.generate(2048); - pub_key = crypto::pub_key_t::from(rsa_prv_key.pub()); + rsa_prv_key.generate(RSA_KEY_BITS); + rsa_pub_key = rsa_prv_key.pub(); + base_pke = &pve_base_pke_rsa(); + ek = pve_keyref(rsa_pub_key); } else { - crypto::ecc_prv_key_t ecc_prv_key; - ecc_prv_key.generate(crypto::curve_p256); - pub_key = crypto::pub_key_t::from(ecc_prv_key.pub()); + ecc_prv_key.generate(curve); + ecc_pub_key = ecc_prv_key.pub(); + base_pke = &pve_base_pke_ecies(); + ek = pve_keyref(ecc_pub_key); } - const mod_t q = crypto::curve_p256.order(); - const crypto::ecc_generator_point_t& G = crypto::curve_p256.generator(); - bn_t x = bn_t::rand(q); - ecc_point_t X = x * G; for (auto _ : state) { - pve.encrypt(&pub_key, "test-label", crypto::curve_p256, x); + auto rv = pve.encrypt(*base_pke, ek, "test-label", curve, x); + benchmark::DoNotOptimize(rv); } } static void BM_PVE_Verify(benchmark::State& state) { ec_pve_t pve; - crypto::pub_key_t pub_key; + const auto& curve = crypto::curve_p256; + const mod_t q = curve.order(); + const crypto::ecc_generator_point_t& G = curve.generator(); + bn_t x = bn_t::rand(q); + ecc_point_t X = x * G; + + const pve_base_pke_i* base_pke = nullptr; + pve_keyref_t ek; + + crypto::rsa_prv_key_t rsa_prv_key; + crypto::rsa_pub_key_t rsa_pub_key; + crypto::ecc_prv_key_t ecc_prv_key; + crypto::ecc_pub_key_t ecc_pub_key; + if (state.range(0) == 0) { - crypto::rsa_prv_key_t rsa_prv_key; - rsa_prv_key.generate(2048); - pub_key = crypto::pub_key_t::from(rsa_prv_key.pub()); + rsa_prv_key.generate(RSA_KEY_BITS); + rsa_pub_key = rsa_prv_key.pub(); + base_pke = &pve_base_pke_rsa(); + ek = pve_keyref(rsa_pub_key); } else { - crypto::ecc_prv_key_t ecc_prv_key; - ecc_prv_key.generate(crypto::curve_p256); - pub_key = crypto::pub_key_t::from(ecc_prv_key.pub()); + ecc_prv_key.generate(curve); + ecc_pub_key = ecc_prv_key.pub(); + base_pke = &pve_base_pke_ecies(); + ek = pve_keyref(ecc_pub_key); } - const mod_t q = crypto::curve_p256.order(); - const crypto::ecc_generator_point_t& G = crypto::curve_p256.generator(); - bn_t x = bn_t::rand(q); - ecc_point_t X = x * G; - pve.encrypt(&pub_key, "test-label", crypto::curve_p256, x); + + pve.encrypt(*base_pke, ek, "test-label", curve, x); for (auto _ : state) { - pve.verify(&pub_key, X, "test-label"); + auto rv = pve.verify(*base_pke, ek, X, "test-label"); + benchmark::DoNotOptimize(rv); } } static void BM_PVE_Decrypt(benchmark::State& state) { ec_pve_t pve; - crypto::pub_key_t pub_key; - crypto::prv_key_t prv_key; + const auto& curve = crypto::curve_p256; + const mod_t q = curve.order(); + bn_t x = bn_t::rand(q); + + const pve_base_pke_i* base_pke = nullptr; + pve_keyref_t ek; + pve_keyref_t dk; + + crypto::rsa_prv_key_t rsa_prv_key; + crypto::rsa_pub_key_t rsa_pub_key; + crypto::ecc_prv_key_t ecc_prv_key; + crypto::ecc_pub_key_t ecc_pub_key; + if (state.range(0) == 0) { - crypto::rsa_prv_key_t rsa_prv_key; - rsa_prv_key.generate(2048); - pub_key = crypto::pub_key_t::from(rsa_prv_key.pub()); - prv_key = crypto::prv_key_t::from(rsa_prv_key); + rsa_prv_key.generate(RSA_KEY_BITS); + rsa_pub_key = rsa_prv_key.pub(); + base_pke = &pve_base_pke_rsa(); + ek = pve_keyref(rsa_pub_key); + dk = pve_keyref(rsa_prv_key); } else { - crypto::ecc_prv_key_t ecc_prv_key; - ecc_prv_key.generate(crypto::curve_p256); - pub_key = crypto::pub_key_t::from(ecc_prv_key.pub()); - prv_key = crypto::prv_key_t::from(ecc_prv_key); + ecc_prv_key.generate(curve); + ecc_pub_key = ecc_prv_key.pub(); + base_pke = &pve_base_pke_ecies(); + ek = pve_keyref(ecc_pub_key); + dk = pve_keyref(ecc_prv_key); } - const mod_t q = crypto::curve_p256.order(); - const crypto::ecc_generator_point_t& G = crypto::curve_p256.generator(); - bn_t x = bn_t::rand(q); - ecc_point_t X = x * G; - pve.encrypt(&pub_key, "test-label", crypto::curve_p256, x); + + pve.encrypt(*base_pke, ek, "test-label", curve, x); for (auto _ : state) { - pve.decrypt(&prv_key, &pub_key, "test-label", crypto::curve_p256, x); + bn_t decrypted; + auto rv = pve.decrypt(*base_pke, dk, ek, "test-label", curve, decrypted); + benchmark::DoNotOptimize(rv); + benchmark::DoNotOptimize(decrypted); } } @@ -89,29 +129,36 @@ static void BM_PVE_Batch_Encrypt(benchmark::State& state) { int n = state.range(1); ec_pve_batch_t pve(n); - crypto::pub_key_t pub_key; - if (state.range(0) == 0) { - crypto::rsa_prv_key_t rsa_prv_key; - rsa_prv_key.generate(2048); - pub_key = crypto::pub_key_t::from(rsa_prv_key.pub()); - } else { - crypto::ecc_prv_key_t ecc_prv_key; - ecc_prv_key.generate(crypto::curve_p256); - pub_key = crypto::pub_key_t::from(ecc_prv_key.pub()); - } - const mod_t q = crypto::curve_p256.order(); - const crypto::ecc_generator_point_t& G = crypto::curve_p256.generator(); + const auto& curve = crypto::curve_p256; + const mod_t q = curve.order(); std::vector x(n); for (int i = 0; i < n; i++) { x[i] = bn_t::rand(q); } - std::vector X(n); - for (int i = 0; i < n; i++) { - X[i] = x[i] * G; + + const pve_base_pke_i* base_pke = nullptr; + pve_keyref_t ek; + + crypto::rsa_prv_key_t rsa_prv_key; + crypto::rsa_pub_key_t rsa_pub_key; + crypto::ecc_prv_key_t ecc_prv_key; + crypto::ecc_pub_key_t ecc_pub_key; + + if (state.range(0) == 0) { + rsa_prv_key.generate(RSA_KEY_BITS); + rsa_pub_key = rsa_prv_key.pub(); + base_pke = &pve_base_pke_rsa(); + ek = pve_keyref(rsa_pub_key); + } else { + ecc_prv_key.generate(curve); + ecc_pub_key = ecc_prv_key.pub(); + base_pke = &pve_base_pke_ecies(); + ek = pve_keyref(ecc_pub_key); } for (auto _ : state) { - pve.encrypt(&pub_key, "test-label", crypto::curve_p256, x); + auto rv = pve.encrypt(*base_pke, ek, "test-label", curve, x); + benchmark::DoNotOptimize(rv); } } @@ -119,18 +166,9 @@ static void BM_PVE_Batch_Verify(benchmark::State& state) { int n = state.range(1); ec_pve_batch_t pve(n); - crypto::pub_key_t pub_key; - if (state.range(0) == 0) { - crypto::rsa_prv_key_t rsa_prv_key; - rsa_prv_key.generate(2048); - pub_key = crypto::pub_key_t::from(rsa_prv_key.pub()); - } else { - crypto::ecc_prv_key_t ecc_prv_key; - ecc_prv_key.generate(crypto::curve_p256); - pub_key = crypto::pub_key_t::from(ecc_prv_key.pub()); - } - const mod_t q = crypto::curve_p256.order(); - const crypto::ecc_generator_point_t& G = crypto::curve_p256.generator(); + const auto& curve = crypto::curve_p256; + const mod_t q = curve.order(); + const crypto::ecc_generator_point_t& G = curve.generator(); std::vector x(n); for (int i = 0; i < n; i++) { x[i] = bn_t::rand(q); @@ -139,9 +177,31 @@ static void BM_PVE_Batch_Verify(benchmark::State& state) { for (int i = 0; i < n; i++) { X[i] = x[i] * G; } - pve.encrypt(&pub_key, "test-label", crypto::curve_p256, x); + + const pve_base_pke_i* base_pke = nullptr; + pve_keyref_t ek; + + crypto::rsa_prv_key_t rsa_prv_key; + crypto::rsa_pub_key_t rsa_pub_key; + crypto::ecc_prv_key_t ecc_prv_key; + crypto::ecc_pub_key_t ecc_pub_key; + + if (state.range(0) == 0) { + rsa_prv_key.generate(RSA_KEY_BITS); + rsa_pub_key = rsa_prv_key.pub(); + base_pke = &pve_base_pke_rsa(); + ek = pve_keyref(rsa_pub_key); + } else { + ecc_prv_key.generate(curve); + ecc_pub_key = ecc_prv_key.pub(); + base_pke = &pve_base_pke_ecies(); + ek = pve_keyref(ecc_pub_key); + } + + pve.encrypt(*base_pke, ek, "test-label", curve, x); for (auto _ : state) { - pve.verify(&pub_key, X, "test-label"); + auto rv = pve.verify(*base_pke, ek, X, "test-label"); + benchmark::DoNotOptimize(rv); } } @@ -149,47 +209,44 @@ static void BM_PVE_Batch_Decrypt(benchmark::State& state) { int n = state.range(1); ec_pve_batch_t pve(n); - crypto::pub_key_t pub_key; - crypto::prv_key_t prv_key; - if (state.range(0) == 0) { - crypto::rsa_prv_key_t rsa_prv_key; - rsa_prv_key.generate(2048); - pub_key = crypto::pub_key_t::from(rsa_prv_key.pub()); - prv_key = crypto::prv_key_t::from(rsa_prv_key); - } else { - crypto::ecc_prv_key_t ecc_prv_key; - ecc_prv_key.generate(crypto::curve_p256); - pub_key = crypto::pub_key_t::from(ecc_prv_key.pub()); - prv_key = crypto::prv_key_t::from(ecc_prv_key); - } - const mod_t q = crypto::curve_p256.order(); - const crypto::ecc_generator_point_t& G = crypto::curve_p256.generator(); + const auto& curve = crypto::curve_p256; + const mod_t q = curve.order(); std::vector x(n); for (int i = 0; i < n; i++) { x[i] = bn_t::rand(q); } - pve.encrypt(&pub_key, "test-label", crypto::curve_p256, x); - for (auto _ : state) { - pve.decrypt(&prv_key, &pub_key, "test-label", crypto::curve_p256, x); + const pve_base_pke_i* base_pke = nullptr; + pve_keyref_t ek; + pve_keyref_t dk; + + crypto::rsa_prv_key_t rsa_prv_key; + crypto::rsa_pub_key_t rsa_pub_key; + crypto::ecc_prv_key_t ecc_prv_key; + crypto::ecc_pub_key_t ecc_pub_key; + + if (state.range(0) == 0) { + rsa_prv_key.generate(RSA_KEY_BITS); + rsa_pub_key = rsa_prv_key.pub(); + base_pke = &pve_base_pke_rsa(); + ek = pve_keyref(rsa_pub_key); + dk = pve_keyref(rsa_prv_key); + } else { + ecc_prv_key.generate(curve); + ecc_pub_key = ecc_prv_key.pub(); + base_pke = &pve_base_pke_ecies(); + ek = pve_keyref(ecc_pub_key); + dk = pve_keyref(ecc_prv_key); } -} -crypto::ecc_prv_key_t get_ecc_prv_key(int participant_index) { - crypto::ecc_prv_key_t prv_key_ecc; - prv_key_ecc.generate(crypto::curve_p256); - return prv_key_ecc; -} -crypto::rsa_prv_key_t get_rsa_prv_key(int participant_index) { - crypto::rsa_prv_key_t prv_key_rsa; - prv_key_rsa.generate(2048); - return prv_key_rsa; -} -crypto::prv_key_t get_prv_key(int participant_index) { - if (participant_index & 1) - return crypto::prv_key_t::from(get_ecc_prv_key(participant_index)); - else - return crypto::prv_key_t::from(get_rsa_prv_key(participant_index)); + pve.encrypt(*base_pke, ek, "test-label", curve, x); + + for (auto _ : state) { + std::vector decrypted; + auto rv = pve.decrypt(*base_pke, dk, ek, "test-label", curve, decrypted); + benchmark::DoNotOptimize(rv); + benchmark::DoNotOptimize(decrypted); + } } class PVEACFixture : public benchmark::Fixture { @@ -199,13 +256,13 @@ class PVEACFixture : public benchmark::Fixture { crypto::ecc_generator_point_t G; crypto::ss::ac_t ac; - std::map pub_keys; - std::map prv_keys; + std::map pub_keys; + std::map prv_keys; mpc::ec_pve_ac_t::pks_t pub_key_ptrs; mpc::ec_pve_ac_t::sks_t prv_key_ptrs; std::vector xs; std::vector Xs; - std::string label = "test-label"; + buf_t label = buf_t("test-label"); ec_pve_ac_t pve; @@ -216,21 +273,21 @@ class PVEACFixture : public benchmark::Fixture { ac = crypto::ss::ac_t(testutils::getTestRoot()); auto leaves = ac.list_leaf_names(); - int participant_index = 0; for (auto path : leaves) { - auto prv_key = get_prv_key(participant_index); + crypto::ecc_prv_key_t prv_key; + prv_key.generate(curve); + crypto::ecc_pub_key_t pub_key = prv_key.pub(); if (!ac.enough_for_quorum(pub_keys)) { prv_keys[path] = prv_key; } - pub_keys[path] = prv_key.pub(); - participant_index++; + pub_keys[path] = std::move(pub_key); } // Build pointer maps expected by ec_pve_ac_t pub_key_ptrs.clear(); prv_key_ptrs.clear(); - for (auto& kv : pub_keys) pub_key_ptrs[kv.first] = &kv.second; - for (auto& kv : prv_keys) prv_key_ptrs[kv.first] = &kv.second; + for (auto& kv : pub_keys) pub_key_ptrs[kv.first] = pve_keyref(kv.second); + for (auto& kv : prv_keys) prv_key_ptrs[kv.first] = pve_keyref(kv.second); int n = 20; xs.resize(n); @@ -250,29 +307,35 @@ BENCHMARK(BM_PVE_Batch_Verify)->Name("PVE/vencrypt-batch/Verify")->ArgsProduct({ BENCHMARK(BM_PVE_Batch_Decrypt)->Name("PVE/vencrypt-batch/Decrypt")->ArgsProduct({{0, 1}, {4, 16}}); BENCHMARK_DEFINE_F(PVEACFixture, BM_AC_Encrypt)(benchmark::State& state) { + const auto& base_pke = pve_base_pke_ecies(); for (auto _ : state) { - pve.encrypt(ac, pub_key_ptrs, label, curve, xs); + auto rv = pve.encrypt(base_pke, ac, pub_key_ptrs, label, curve, xs); + benchmark::DoNotOptimize(rv); } } BENCHMARK_DEFINE_F(PVEACFixture, BM_AC_Verify)(benchmark::State& state) { - pve.encrypt(ac, pub_key_ptrs, label, curve, xs); + const auto& base_pke = pve_base_pke_ecies(); + pve.encrypt(base_pke, ac, pub_key_ptrs, label, curve, xs); for (auto _ : state) { - pve.verify(ac, pub_key_ptrs, Xs, label); + auto rv = pve.verify(base_pke, ac, pub_key_ptrs, Xs, label); + benchmark::DoNotOptimize(rv); } } BENCHMARK_DEFINE_F(PVEACFixture, BM_AC_Decrypt)(benchmark::State& state) { - pve.encrypt(ac, pub_key_ptrs, label, curve, xs); + const auto& base_pke = pve_base_pke_ecies(); + pve.encrypt(base_pke, ac, pub_key_ptrs, label, curve, xs); std::vector decrypted_xs; for (auto _ : state) { int row_index = 0; std::map shares; for (auto &kv : prv_key_ptrs) { bn_t share; - auto rv = pve.party_decrypt_row(ac, row_index, kv.first, kv.second, label, share); + auto rv = pve.party_decrypt_row(base_pke, ac, row_index, kv.first, kv.second, label, share); if (rv) benchmark::DoNotOptimize(rv); shares[kv.first] = share; } - auto rv = pve.aggregate_to_restore_row(ac, row_index, label, shares, decrypted_xs, /*skip_verify=*/true); + auto rv = + pve.aggregate_to_restore_row(base_pke, ac, row_index, label, shares, decrypted_xs, /*skip_verify=*/true); if (rv) benchmark::DoNotOptimize(rv); } } diff --git a/tools/benchmark/bm_share.cpp b/tools/benchmark/bm_share.cpp index 0be9e91e..d724e872 100644 --- a/tools/benchmark/bm_share.cpp +++ b/tools/benchmark/bm_share.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include #include "util.h" diff --git a/tools/benchmark/bm_sid.cpp b/tools/benchmark/bm_sid.cpp index 7dce7946..b5077e01 100644 --- a/tools/benchmark/bm_sid.cpp +++ b/tools/benchmark/bm_sid.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include #include "mpc_util.h" diff --git a/tools/benchmark/bm_tdh2.cpp b/tools/benchmark/bm_tdh2.cpp index 69a8ea78..08dc11dc 100644 --- a/tools/benchmark/bm_tdh2.cpp +++ b/tools/benchmark/bm_tdh2.cpp @@ -1,6 +1,6 @@ #include -#include +#include #include "data/tdh2.h" diff --git a/tools/benchmark/bm_test.cpp b/tools/benchmark/bm_test.cpp index 5bdab908..ed7f1a18 100644 --- a/tools/benchmark/bm_test.cpp +++ b/tools/benchmark/bm_test.cpp @@ -2,8 +2,8 @@ #include -#include -#include +#include +#include #include "mpc_util.h" diff --git a/tools/benchmark/bm_zk.cpp b/tools/benchmark/bm_zk.cpp index 63f1901a..759a3cdb 100644 --- a/tools/benchmark/bm_zk.cpp +++ b/tools/benchmark/bm_zk.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include #include "data/zk_data_generator.h" #include "util.h" diff --git a/tools/benchmark/mpc_util.h b/tools/benchmark/mpc_util.h index a0309ad3..cbddd581 100644 --- a/tools/benchmark/mpc_util.h +++ b/tools/benchmark/mpc_util.h @@ -1,9 +1,9 @@ #pragma once #include -#include -#include -#include +#include +#include +#include "mpc_job_session.h" // diff --git a/tools/benchmark/util.h b/tools/benchmark/util.h index b26cd45f..69239f0f 100644 --- a/tools/benchmark/util.h +++ b/tools/benchmark/util.h @@ -1,6 +1,6 @@ #pragma once -#include +#include inline coinbase::crypto::ecurve_t get_curve(int index) { switch (index) { From f7208e0654aedc6bf54fc5f0aa48ab19b2848267 Mon Sep 17 00:00:00 2001 From: Arash Afshar Date: Thu, 12 Mar 2026 09:46:00 -0600 Subject: [PATCH 3/3] more fixes --- .../cbmpc/internal/core/convert.h | 10 ++- include-internal/cbmpc/internal/core/utils.h | 5 +- .../cbmpc/internal/crypto/base_ec_core.h | 16 ----- include-internal/cbmpc/internal/crypto/tdh2.h | 6 +- .../cbmpc/internal/protocol/util.h | 3 + include-internal/cbmpc/internal/zk/fischlin.h | 12 ++-- include/cbmpc/core/macros.h | 4 -- src/cbmpc/api/access_structure_util.h | 2 + src/cbmpc/core/buf.cpp | 1 + src/cbmpc/crypto/base.cpp | 5 +- src/cbmpc/crypto/base_ecc.cpp | 6 +- src/cbmpc/crypto/base_hash.cpp | 2 + src/cbmpc/crypto/secret_sharing.cpp | 22 ++++--- tests/unit/api/test_ecdsa2pc.cpp | 19 ++++++ tests/unit/api/test_eddsa2pc.cpp | 66 +++++++++++++++++++ tests/unit/api/test_pve_ac.cpp | 9 +++ tests/unit/core/test_buf.cpp | 8 +++ tests/unit/core/test_convert.cpp | 19 ++++++ tests/unit/core/test_util.cpp | 16 +++++ tests/unit/crypto/test_base_hash.cpp | 18 +++++ tests/unit/crypto/test_ecc.cpp | 2 + tests/unit/crypto/test_eddsa.cpp | 38 +++++++++++ tests/unit/crypto/test_secret_sharing.cpp | 41 ++++++++++++ tests/unit/crypto/test_tdh2.cpp | 33 ++++++++++ tests/unit/zk/test_zk.cpp | 36 ++++++++++ 25 files changed, 360 insertions(+), 39 deletions(-) mode change 100755 => 100644 include/cbmpc/core/macros.h diff --git a/include-internal/cbmpc/internal/core/convert.h b/include-internal/cbmpc/internal/core/convert.h index 5fdfb0ad..f20d3811 100644 --- a/include-internal/cbmpc/internal/core/convert.h +++ b/include-internal/cbmpc/internal/core/convert.h @@ -254,7 +254,15 @@ template error_t deser(mem_t bin, ARGS&... args) { converter_t converter(bin); converter.convert(args...); - return converter.get_rv(); + error_t rv = converter.get_rv(); + if (rv != SUCCESS) return rv; + + // Strict deserialization: reject trailing bytes + if (converter.get_offset() != converter.get_size()) { + return coinbase::error(E_BADARG); + } + + return SUCCESS; } template diff --git a/include-internal/cbmpc/internal/core/utils.h b/include-internal/cbmpc/internal/core/utils.h index 154f3542..015ca964 100644 --- a/include-internal/cbmpc/internal/core/utils.h +++ b/include-internal/cbmpc/internal/core/utils.h @@ -20,7 +20,10 @@ static std::mutex coutMutex; namespace coinbase { inline int bits_to_bytes_floor(int bits) { return bits >> 3; } -inline int bits_to_bytes(int bits) { return (bits + 7) >> 3; } +inline int bits_to_bytes(int bits) { + cb_assert(bits >= 0); + return (bits + 7) >> 3; +} inline int bytes_to_bits(int bytes) { return bytes << 3; } namespace detail { diff --git a/include-internal/cbmpc/internal/crypto/base_ec_core.h b/include-internal/cbmpc/internal/crypto/base_ec_core.h index f32d88eb..63d59f2d 100644 --- a/include-internal/cbmpc/internal/crypto/base_ec_core.h +++ b/include-internal/cbmpc/internal/crypto/base_ec_core.h @@ -157,22 +157,6 @@ struct edwards_projective_t { affine_y = y * zi; } - static bool is_on_curve(const fe_t& x, const fe_t& y) { - fe_t xx = x * x; - fe_t yy = y * y; - - fe_t t = yy; - if constexpr (a_coeff == -1) - t -= yy; - else if constexpr (a_coeff == 1) - t += yy; - else - return false; - - fe_t d = get_d(); - return t == fe_t::one() + d * xx * yy; - } - static bool is_on_curve(const fe_t& x, const fe_t& y, const fe_t& z) { fe_t xx = x * x; fe_t yy = y * y; diff --git a/include-internal/cbmpc/internal/crypto/tdh2.h b/include-internal/cbmpc/internal/crypto/tdh2.h index ce2a7e5e..845789c1 100644 --- a/include-internal/cbmpc/internal/crypto/tdh2.h +++ b/include-internal/cbmpc/internal/crypto/tdh2.h @@ -59,7 +59,11 @@ struct public_key_t { */ ciphertext_t encrypt(mem_t plain, mem_t label, const bn_t& r, const bn_t& s, mem_t iv) const; - bool valid() const { return Q.valid(); } + bool valid() const { + if (!Q.valid()) return false; + ecc_point_t expected_Gamma = ro::hash_curve(mem_t("TDH2-Gamma"), Q, sid).curve(Q.get_curve()); + return Gamma == expected_Gamma; + } void convert(coinbase::converter_t& converter) { converter.convert(Q, Gamma, sid); } buf_t to_bin() const { return coinbase::convert(*this); } error_t from_bin(mem_t bin) { return coinbase::convert(*this, bin); } diff --git a/include-internal/cbmpc/internal/protocol/util.h b/include-internal/cbmpc/internal/protocol/util.h index 5fa2364e..fc4f7913 100644 --- a/include-internal/cbmpc/internal/protocol/util.h +++ b/include-internal/cbmpc/internal/protocol/util.h @@ -18,6 +18,7 @@ static T SUM(int n, LAMBDA lambda) { template static T SUM(const std::vector& v) { + if (v.empty()) return T{}; T s = v[0]; for (int i = 1; i < int(v.size()); i++) s += v[i]; return s; @@ -25,6 +26,7 @@ static T SUM(const std::vector& v) { template static T SUM(const std::vector>& v) { + if (v.empty()) return T{}; T s = v[0].get(); for (int i = 1; i < int(v.size()); i++) s += v[i].get(); return s; @@ -32,6 +34,7 @@ static T SUM(const std::vector>& v) { template static T SUM(const std::map& m) { + if (m.empty()) return T{}; T s = m.begin()->second; for (auto it = std::next(m.begin()); it != m.end(); ++it) s += it->second; return s; diff --git a/include-internal/cbmpc/internal/zk/fischlin.h b/include-internal/cbmpc/internal/zk/fischlin.h index 68300da2..72c9cbb4 100644 --- a/include-internal/cbmpc/internal/zk/fischlin.h +++ b/include-internal/cbmpc/internal/zk/fischlin.h @@ -48,17 +48,17 @@ struct fischlin_params_t { int rho, b, t; int e_max() const { - cb_assert(t < 32); - return 1 << t; + cb_assert(t < 31); + return 1U << t; } uint32_t b_mask() const { - cb_assert(b < 32); - return (1 << b) - 1; + cb_assert(b < 31); + return (1U << b) - 1; } error_t check() const { if (rho <= 0) return coinbase::error(E_CRYPTO, "rho <= 0"); if (b <= 0) return coinbase::error(E_CRYPTO, "b <= 0"); - if (b >= 32) return coinbase::error(E_CRYPTO, "b >= 32"); + if (b >= 31) return coinbase::error(E_CRYPTO, "b >= 31"); if (int64_t(b) * int64_t(rho) < SEC_P_COM) return coinbase::error(E_CRYPTO, "b * rho < SEC_P_COM"); return SUCCESS; } @@ -66,7 +66,7 @@ struct fischlin_params_t { error_t check_with_effective_b(int effective_b) const { if (rho <= 0) return coinbase::error(E_CRYPTO, "rho <= 0"); if (b <= 0) return coinbase::error(E_CRYPTO, "b <= 0"); - if (b >= 32) return coinbase::error(E_CRYPTO, "b >= 32"); + if (b >= 31) return coinbase::error(E_CRYPTO, "b >= 31"); if (effective_b <= 0) return coinbase::error(E_CRYPTO, "effective_b <= 0"); if (int64_t(rho) * int64_t(effective_b) < SEC_P_COM) diff --git a/include/cbmpc/core/macros.h b/include/cbmpc/core/macros.h old mode 100755 new mode 100644 index 6338ca1e..7e798c57 --- a/include/cbmpc/core/macros.h +++ b/include/cbmpc/core/macros.h @@ -19,10 +19,6 @@ #define DLLEXPORT __attribute__((visibility("default"))) #define DLLEXPORT_DEF DLLEXPORT -#ifndef NULL -#define NULL ((void*)0) -#endif - #define FOR_EACH(i, c) for (auto i = (c).begin(); i != (c).end(); ++i) typedef void* void_ptr; diff --git a/src/cbmpc/api/access_structure_util.h b/src/cbmpc/api/access_structure_util.h index da3b3488..9682ac1b 100644 --- a/src/cbmpc/api/access_structure_util.h +++ b/src/cbmpc/api/access_structure_util.h @@ -47,6 +47,8 @@ inline error_t validate_access_structure_node_impl(const access_structure_t& n, if (static_cast(n.threshold_k) > n.children.size()) return coinbase::error(E_BADARG, "access_structure: threshold_k > children.size()"); break; + default: + return coinbase::error(E_BADARG, "invalid node type"); } for (const auto& ch : n.children) { diff --git a/src/cbmpc/core/buf.cpp b/src/cbmpc/core/buf.cpp index 84e87788..b31d32e8 100644 --- a/src/cbmpc/core/buf.cpp +++ b/src/cbmpc/core/buf.cpp @@ -11,6 +11,7 @@ buf_t::buf_t() noexcept(true) : s(0) { static_assert(sizeof(buf_t) == 40, "Inval buf_t::buf_t(int new_size) : s(new_size) { // NOLINT(*init*) // NOTE: `buf_t(int)` intentionally leaves the buffer contents uninitialized. // Callers must fully overwrite `size()` bytes before reading from `data()`. + cb_assert(new_size >= 0); if (new_size > short_size) set_long_ptr(new byte_t[new_size]); } diff --git a/src/cbmpc/crypto/base.cpp b/src/cbmpc/crypto/base.cpp index b576a414..49ae95f9 100644 --- a/src/cbmpc/crypto/base.cpp +++ b/src/cbmpc/crypto/base.cpp @@ -119,7 +119,10 @@ buf_t gen_random(int size) { return output; } -buf_t gen_random_bitlen(int bitlen) { return gen_random(coinbase::bits_to_bytes(bitlen)); } +buf_t gen_random_bitlen(int bitlen) { + cb_assert(bitlen >= 0); + return gen_random(coinbase::bits_to_bytes(bitlen)); +} coinbase::bits_t gen_random_bits(int count) { coinbase::bits_t out(count); diff --git a/src/cbmpc/crypto/base_ecc.cpp b/src/cbmpc/crypto/base_ecc.cpp index 8519c08e..373913ee 100644 --- a/src/cbmpc/crypto/base_ecc.cpp +++ b/src/cbmpc/crypto/base_ecc.cpp @@ -410,6 +410,7 @@ void ecc_prv_key_t::set(ecurve_t curve, const bn_t& val) { } void ecc_prv_key_t::set_ed_bin(mem_t ed_bin) { + cb_assert(ed_bin.size == ed25519::prv_bin_size()); this->curve = curve_ed25519; this->ed_bin = ed_bin; } @@ -446,7 +447,10 @@ error_t sig_with_pub_key_t::verify_all(const ecc_point_t& Q, mem_t hash, const std::vector& sigs) // static { error_t rv = UNINITIALIZED_ERROR; - ecc_point_t QSum = crypto::curve_p256.infinity(); + if (sigs.empty()) return coinbase::error(E_BADARG, "sig_with_pub_key_t::verify_all: no signatures provided"); + + ecc_point_t QSum = sigs[0].Q.get_curve().infinity(); + for (const auto& s : sigs) { if (rv = s.verify(hash)) return rv; QSum += s.Q; diff --git a/src/cbmpc/crypto/base_hash.cpp b/src/cbmpc/crypto/base_hash.cpp index f55a8292..864d8371 100644 --- a/src/cbmpc/crypto/base_hash.cpp +++ b/src/cbmpc/crypto/base_hash.cpp @@ -128,6 +128,7 @@ hash_t& hash_t::init() { } hash_t& hash_t::update(const_byte_ptr ptr, int size) { + cb_assert(size >= 0); ::EVP_DigestUpdate(ctx_ptr, ptr, size); return *this; } @@ -170,6 +171,7 @@ hmac_t& hmac_t::init(mem_t key) { } hmac_t& hmac_t::update(const_byte_ptr ptr, int size) { + cb_assert(size >= 0); EVP_MAC_update(ctx_ptr, ptr, size); return *this; } diff --git a/src/cbmpc/crypto/secret_sharing.cpp b/src/cbmpc/crypto/secret_sharing.cpp index 278135cb..9446d5d6 100644 --- a/src/cbmpc/crypto/secret_sharing.cpp +++ b/src/cbmpc/crypto/secret_sharing.cpp @@ -24,9 +24,10 @@ std::vector share_and(const mod_t& q, const bn_t& x, const int n, crypto:: std::pair, std::vector> share_threshold(const mod_t& q, const bn_t& a, const int threshold, const int n, const std::vector& pids, crypto::drbg_aes_ctr_t* drbg) { + cb_assert(threshold > 0); + cb_assert(n > 0); std::vector shares(n); std::vector b(threshold); - cb_assert(threshold > 0); shares.resize(n); b.resize(threshold); b[0] = a; @@ -354,8 +355,9 @@ error_t ac_t::verify_share_against_ancestors_pub_data(const ecc_point_t& Q, cons while (node != nullptr) { auto sorted_children = node->get_sorted_children(); - auto pub_shares = pub_data.at(node->name); - ecc_point_t my_pub_share = pub_shares; + auto it = pub_data.find(node->name); + if (it == pub_data.end()) return coinbase::error(E_BADARG, "missing pub_data key"); + ecc_point_t my_pub_share = it->second; if (node->type == node_e::LEAF || node->type == node_e::OR) { if (my_pub_share != expected_pub_share) { @@ -364,15 +366,18 @@ error_t ac_t::verify_share_against_ancestors_pub_data(const ecc_point_t& Q, cons } else if (node->type == node_e::AND) { ecc_point_t expected_sum = curve.infinity(); for (size_t i = 0; i < sorted_children.size(); i++) { - auto child_pub_shares = pub_data.at(sorted_children[i]->name); - expected_sum += child_pub_shares; + auto child_it = pub_data.find(sorted_children[i]->name); + if (child_it == pub_data.end()) return coinbase::error(E_BADARG, "missing pub_data key"); + expected_sum += child_it->second; } if (expected_sum != my_pub_share) return coinbase::error(E_CRYPTO); } else if (node->type == node_e::THRESHOLD) { std::vector quorum(node->threshold); std::vector quorum_pids(node->threshold); for (int i = 0; i < node->threshold; i++) { - quorum[i] = pub_data.at(sorted_children[i]->name); + auto child_it = pub_data.find(sorted_children[i]->name); + if (child_it == pub_data.end()) return coinbase::error(E_BADARG, "missing pub_data key"); + quorum[i] = child_it->second; quorum_pids[i] = sorted_children[i]->get_pid(); } @@ -382,8 +387,9 @@ error_t ac_t::verify_share_against_ancestors_pub_data(const ecc_point_t& Q, cons if (my_pub_share != lagrange_interpolate_exponent(0, quorum, quorum_pids)) return coinbase::error(E_CRYPTO); for (size_t i = node->threshold; i < sorted_children.size(); i++) { - if (pub_data.at(sorted_children[i]->name) != - lagrange_interpolate_exponent(sorted_children[i]->get_pid(), quorum, quorum_pids)) + auto child_it = pub_data.find(sorted_children[i]->name); + if (child_it == pub_data.end()) return coinbase::error(E_BADARG, "missing pub_data key"); + if (child_it->second != lagrange_interpolate_exponent(sorted_children[i]->get_pid(), quorum, quorum_pids)) return coinbase::error(E_CRYPTO); } } else { diff --git a/tests/unit/api/test_ecdsa2pc.cpp b/tests/unit/api/test_ecdsa2pc.cpp index d50a6355..06590044 100644 --- a/tests/unit/api/test_ecdsa2pc.cpp +++ b/tests/unit/api/test_ecdsa2pc.cpp @@ -825,3 +825,22 @@ TEST_F(ApiEcdsa2pcNegWithBlobs, NegRefreshRoleMismatch) { buf_t new_blob; EXPECT_NE(coinbase::api::ecdsa_2p::refresh(job, blob1_, new_blob), SUCCESS); } + +// ========================================================================== +// Negative: sign with null or zero-length message +// ========================================================================== + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegSignNullMessage) { + noop_transport_t t; + auto job = make_noop_job(t); + buf_t sid, sig; + EXPECT_NE(coinbase::api::ecdsa_2p::sign(job, blob1_, mem_t(nullptr, 0), sid, sig), SUCCESS); +} + +TEST_F(ApiEcdsa2pcNegWithBlobs, NegSignZeroLengthMessage) { + noop_transport_t t; + auto job = make_noop_job(t); + buf_t empty_msg; + buf_t sid, sig; + EXPECT_NE(coinbase::api::ecdsa_2p::sign(job, blob1_, empty_msg, sid, sig), SUCCESS); +} diff --git a/tests/unit/api/test_eddsa2pc.cpp b/tests/unit/api/test_eddsa2pc.cpp index 20b09aac..7ec570d3 100644 --- a/tests/unit/api/test_eddsa2pc.cpp +++ b/tests/unit/api/test_eddsa2pc.cpp @@ -12,6 +12,7 @@ namespace { using coinbase::buf_t; using coinbase::error_t; +using coinbase::mem_t; using coinbase::api::curve_id; using coinbase::api::eddsa_2p::party_t; @@ -205,3 +206,68 @@ TEST(ApiEdDSA2pc, KeyBlobPrivScalar_NoPubSign) { buf_t bad_merged; EXPECT_NE(coinbase::api::eddsa_2p::attach_private_scalar(public_1, bad_x, Qi_full_1, bad_merged), SUCCESS); } + +TEST(ApiEdDSA2pc, NegSignNullMessage) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + buf_t key_blob_1; + buf_t key_blob_2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + run_2pc( + c1, c2, [&] { return coinbase::api::eddsa_2p::dkg(job1, curve_id::ed25519, key_blob_1); }, + [&] { return coinbase::api::eddsa_2p::dkg(job2, curve_id::ed25519, key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + buf_t sig1; + buf_t sig2; + run_2pc( + c1, c2, [&] { return coinbase::api::eddsa_2p::sign(job1, key_blob_1, mem_t(nullptr, 0), sig1); }, + [&] { return coinbase::api::eddsa_2p::sign(job2, key_blob_2, mem_t(nullptr, 0), sig2); }, rv1, rv2); + EXPECT_NE(rv1, SUCCESS); +} + +TEST(ApiEdDSA2pc, NegSignZeroLengthMessage) { + auto c1 = std::make_shared(0); + auto c2 = std::make_shared(1); + std::vector> peers = {c1, c2}; + c1->init_with_peers(peers); + c2->init_with_peers(peers); + + local_api_transport_t t1(c1); + local_api_transport_t t2(c2); + + buf_t key_blob_1; + buf_t key_blob_2; + error_t rv1 = UNINITIALIZED_ERROR; + error_t rv2 = UNINITIALIZED_ERROR; + + const coinbase::api::job_2p_t job1{party_t::p1, "p1", "p2", t1}; + const coinbase::api::job_2p_t job2{party_t::p2, "p1", "p2", t2}; + + run_2pc( + c1, c2, [&] { return coinbase::api::eddsa_2p::dkg(job1, curve_id::ed25519, key_blob_1); }, + [&] { return coinbase::api::eddsa_2p::dkg(job2, curve_id::ed25519, key_blob_2); }, rv1, rv2); + ASSERT_EQ(rv1, SUCCESS); + ASSERT_EQ(rv2, SUCCESS); + + buf_t empty_msg; + buf_t sig1; + buf_t sig2; + run_2pc( + c1, c2, [&] { return coinbase::api::eddsa_2p::sign(job1, key_blob_1, empty_msg, sig1); }, + [&] { return coinbase::api::eddsa_2p::sign(job2, key_blob_2, empty_msg, sig2); }, rv1, rv2); + EXPECT_NE(rv1, SUCCESS); +} diff --git a/tests/unit/api/test_pve_ac.cpp b/tests/unit/api/test_pve_ac.cpp index 0622041d..2429b145 100644 --- a/tests/unit/api/test_pve_ac.cpp +++ b/tests/unit/api/test_pve_ac.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -895,3 +896,11 @@ TEST(ApiPveAcNeg, GetPublicKeysCompressedAc_GarbageCiphertext) { std::vector Qs; EXPECT_NE(coinbase::api::pve::get_public_keys_compressed_ac(mem_t(garbage.data(), 4), Qs), SUCCESS); } + +TEST(ApiPveAcNeg, ValidateAccessStructureNodeAcceptsInvalidNodeType) { + coinbase::api::access_structure_t invalid_node; + // Cast an out-of-range value to the enum type. + invalid_node.type = static_cast(99); + error_t rv = coinbase::api::detail::validate_access_structure_node(invalid_node); + EXPECT_NE(rv, SUCCESS); +} \ No newline at end of file diff --git a/tests/unit/core/test_buf.cpp b/tests/unit/core/test_buf.cpp index 363ac1fc..61755a55 100644 --- a/tests/unit/core/test_buf.cpp +++ b/tests/unit/core/test_buf.cpp @@ -196,4 +196,12 @@ TEST(Buf, BzeroAndSecureBzero) { } } +TEST(Buf, ConstructorRejectsNegativeSize) { + // buf_t constructor should validate that size >= 0 + // Negative sizes would bypass guards like `if (buf.size() > 0)` + // and could cause downstream memory corruption + EXPECT_THROW({ coinbase::buf_t buf(-1); }, coinbase::assertion_failed_t); + EXPECT_THROW({ coinbase::buf_t buf(-100); }, coinbase::assertion_failed_t); +} + } // namespace \ No newline at end of file diff --git a/tests/unit/core/test_convert.cpp b/tests/unit/core/test_convert.cpp index 792f72a7..f872c814 100644 --- a/tests/unit/core/test_convert.cpp +++ b/tests/unit/core/test_convert.cpp @@ -155,4 +155,23 @@ TEST(CoreConvert, ConvertLastRejectsNegRemainingSize) { EXPECT_EQ(out.size(), 0); } +TEST(CoreConvert, RejectsTrailingBytes) { + // Strict deserialization should fail when there are unconsumed trailing bytes. + // This prevents message malleability where two different byte sequences + // could deserialize to the same value. + int original_value = 42; + buf_t serialized = coinbase::ser(original_value); + + byte_t garbage[4] = {0xDE, 0xAD, 0xBE, 0xEF}; + buf_t with_trailing = serialized + mem_t(garbage, 4); + + int result = 0; + error_t rv = deser(with_trailing, result); + + EXPECT_NE(rv, SUCCESS); + + EXPECT_OK(deser(serialized, result)); + EXPECT_EQ(result, original_value); +} + } // namespace diff --git a/tests/unit/core/test_util.cpp b/tests/unit/core/test_util.cpp index 04d0cff6..48aac1ca 100644 --- a/tests/unit/core/test_util.cpp +++ b/tests/unit/core/test_util.cpp @@ -3,8 +3,13 @@ #include #include +#include +#include + +#include "utils/test_macros.h" using namespace coinbase; +using namespace coinbase::crypto; // Test bits_to_bytes and bytes_to_bits TEST(CoreUtils, BitAndByteConversions) { @@ -123,4 +128,15 @@ TEST(CoreUtils, ConstantTimeSelectU64) { EXPECT_EQ(result1, val1); EXPECT_EQ(result2, val2); +} + +TEST(ProtocolUtil, SUMCrashesOnEmpty) { + std::vector empty_vec; + EXPECT_NO_FATAL_FAILURE({ (void)SUM(empty_vec); }); + + std::vector> empty_refs; + EXPECT_NO_FATAL_FAILURE({ (void)SUM(empty_refs); }); + + std::map empty_map; + EXPECT_NO_FATAL_FAILURE({ (void)SUM(empty_map); }); } \ No newline at end of file diff --git a/tests/unit/crypto/test_base_hash.cpp b/tests/unit/crypto/test_base_hash.cpp index 796febbb..1a9b32dc 100644 --- a/tests/unit/crypto/test_base_hash.cpp +++ b/tests/unit/crypto/test_base_hash.cpp @@ -22,4 +22,22 @@ TEST(BaseHash, MemVecEncodesBoundsAndLen) { EXPECT_NE(hb, hc); } +TEST(BaseHash, UpdateRejectsNegativeSize) { + // hash_t::update() should validate that size >= 0 + // Negative sizes would be implicitly converted to huge positive size_t values + // when passed to EVP_DigestUpdate(), causing out-of-bounds reads + + hash_t hash(hash_e::sha256); + hash.init(); + byte_t data[4] = {0x01, 0x02, 0x03, 0x04}; + + // Should throw assertion_failed_t for negative size + EXPECT_THROW({ hash.update(data, -1); }, coinbase::assertion_failed_t); + EXPECT_THROW({ hash.update(data, -100); }, coinbase::assertion_failed_t); + + // Valid size should work + EXPECT_NO_THROW({ hash.update(data, 4); }); + EXPECT_NO_THROW({ hash.update(data, 0); }); +} + } // namespace diff --git a/tests/unit/crypto/test_ecc.cpp b/tests/unit/crypto/test_ecc.cpp index ea0bb766..8177e90c 100644 --- a/tests/unit/crypto/test_ecc.cpp +++ b/tests/unit/crypto/test_ecc.cpp @@ -3,6 +3,8 @@ #include #include +#include "test_macros.h" + using namespace coinbase; using namespace coinbase::crypto; diff --git a/tests/unit/crypto/test_eddsa.cpp b/tests/unit/crypto/test_eddsa.cpp index 502fd189..eec19a0e 100644 --- a/tests/unit/crypto/test_eddsa.cpp +++ b/tests/unit/crypto/test_eddsa.cpp @@ -230,4 +230,42 @@ TEST(CryptoEdDSA, subgroup_check) { EXPECT_EQ(got, want); } +TEST(CryptoEdDSA, SetEdBinValidatesKeyLength) { + // Ed25519 private keys must be exactly 32 bytes + // set_ed_bin() should validate this to prevent out-of-bounds reads in sign() + + // Test with too short key (16 bytes) + { + buf_t short_key(16); + short_key.bzero(); + + ecc_prv_key_t key; + // Should throw assertion_failed_t for wrong-length key + EXPECT_THROW({ key.set_ed_bin(short_key); }, coinbase::assertion_failed_t); + } + + // Test with too long key (48 bytes) + { + buf_t long_key(48); + long_key.bzero(); + + ecc_prv_key_t key; + // Should throw assertion_failed_t for wrong-length key + EXPECT_THROW({ key.set_ed_bin(long_key); }, coinbase::assertion_failed_t); + } + + // Test with valid 32-byte key - should succeed + { + buf_t valid_key(32); + valid_key.bzero(); + + ecc_prv_key_t key; + EXPECT_NO_THROW({ key.set_ed_bin(valid_key); }); + + // Verify the key was set correctly + ecc_point_t pub = key.pub(); + EXPECT_TRUE(pub.is_on_curve()); + } +} + } // namespace diff --git a/tests/unit/crypto/test_secret_sharing.cpp b/tests/unit/crypto/test_secret_sharing.cpp index 08825f27..22cf64db 100644 --- a/tests/unit/crypto/test_secret_sharing.cpp +++ b/tests/unit/crypto/test_secret_sharing.cpp @@ -313,4 +313,45 @@ TEST_F(SecretSharing, ReconstructExpRejectsNonSubgroup) { EXPECT_ER(ac.reconstruct_exponent(pub_shares, P)); } +TEST_F(SecretSharing, VerifyShareMissingPubDataKeyCrashes) { + ecurve_t curve = curve_secp256k1; + ac_t ac(simple_and_node, curve); + + ac_shares_t shares; + ac_internal_shares_t internal_shares; + ac_internal_pub_shares_t pub_data; + ASSERT_OK(ac.share_with_internals(q, x, shares, internal_shares, pub_data)); + + pub_data.erase("leaf2"); + + vartime_scope_t vartime_scope; + ecc_point_t Q = x * curve.generator(); + + EXPECT_NO_THROW({ + auto rv = ac.verify_share_against_ancestors_pub_data(Q, shares.at("leaf1"), pub_data, "leaf1"); + EXPECT_ER(rv); + }); +} + +TEST_F(SecretSharing, ShareThresholdRejectsNegativeN) { + // share_threshold() should validate that n > 0 + // Negative n values would convert to SIZE_MAX when constructing std::vector(n) + // causing std::bad_alloc or std::length_error + + int threshold = 3; + std::vector pids = {1, 3, 8}; // Dummy PIDs + + // Should throw assertion_failed_t for negative n + EXPECT_THROW({ share_threshold(q, x, threshold, -1, pids, nullptr); }, coinbase::assertion_failed_t); + + EXPECT_THROW({ share_threshold(q, x, threshold, -100, pids, nullptr); }, coinbase::assertion_failed_t); + + // Valid n should work + { + auto result = share_threshold(q, x, threshold, 3, pids, nullptr); + EXPECT_EQ(result.first.size(), 3); + EXPECT_EQ(result.second.size(), threshold); + } +} + } // namespace \ No newline at end of file diff --git a/tests/unit/crypto/test_tdh2.cpp b/tests/unit/crypto/test_tdh2.cpp index 85d320d5..1531aeb2 100644 --- a/tests/unit/crypto/test_tdh2.cpp +++ b/tests/unit/crypto/test_tdh2.cpp @@ -97,4 +97,37 @@ TEST_F(TDH2, CiphertextRoundTripKeepsLabel) { EXPECT_ER(roundtrip.verify(enc_key, wrong_label)); } +TEST_F(TDH2, PublicKeyValidChecksGammaConsistency) { + // public_key_t::valid() should verify that Gamma is correctly derived from Q and sid + // This prevents attackers from using rogue Gamma values in encryption/decryption + + vartime_scope_t vartime_scope; + ecurve_t curve = curve_secp256k1; + + // Build a legitimate public key + bn_t sk = bn_t::rand(curve.order()); + ecc_point_t Q = sk * curve.generator(); + buf_t sid = crypto::gen_random(32); + + // Constructor derives Gamma from Q and sid + public_key_t pk(Q, mem_t(sid)); + EXPECT_TRUE(pk.valid()); // Should be valid + + // Tamper with Gamma: replace it with a random unrelated point + bn_t rogue_scalar = bn_t::rand(curve.order()); + ecc_point_t rogue_Gamma = rogue_scalar * curve.generator(); + pk.Gamma = rogue_Gamma; + + // valid() should detect the Gamma inconsistency and return false + EXPECT_FALSE(pk.valid()); + + // Test with another invalid Gamma + pk.Gamma = curve.infinity(); + EXPECT_FALSE(pk.valid()); + + // Restore correct Gamma - should be valid again + pk.Gamma = ro::hash_curve(mem_t("TDH2-Gamma"), Q, sid).curve(curve); + EXPECT_TRUE(pk.valid()); +} + } // namespace diff --git a/tests/unit/zk/test_zk.cpp b/tests/unit/zk/test_zk.cpp index bd385cba..523d15a6 100644 --- a/tests/unit/zk/test_zk.cpp +++ b/tests/unit/zk/test_zk.cpp @@ -1,5 +1,6 @@ #include +#include #include #include "utils/data/zk_completeness.h" @@ -96,4 +97,39 @@ TEST_NIZK_COMPLETENESS(ZK_PaillierRangeExpSlack, new test_nizk_paillier_range_ex TEST_NIZK_COMPLETENESS_CURVES(ZK_PDL, test_nizk_pdl); TEST_NIZK_COMPLETENESS(ZK_UnknownOrderDL, new test_unknown_order_dl()); +TEST(FischlinParams, RejectsB31ToPreventUB) { + coinbase::zk::fischlin_params_t p{128, 31, 4}; + EXPECT_NE(p.check(), SUCCESS); +} + +TEST(FischlinParams, RejectsB32) { + coinbase::zk::fischlin_params_t p{128, 32, 4}; + EXPECT_NE(p.check(), SUCCESS); +} + +TEST(FischlinParams, AcceptsValidParams) { + coinbase::zk::fischlin_params_t p{128, 16, 4}; + EXPECT_EQ(p.check(), SUCCESS); +} + +TEST(FischlinParams, AcceptsB30) { + coinbase::zk::fischlin_params_t p{128, 30, 4}; + EXPECT_EQ(p.check(), SUCCESS); +} + +TEST(FischlinParams, AcceptsT30) { + coinbase::zk::fischlin_params_t p{128, 16, 30}; + EXPECT_EQ(p.check(), SUCCESS); +} + +TEST(FischlinParams, BMaskWorksCorrectly) { + coinbase::zk::fischlin_params_t p{128, 8, 4}; + EXPECT_EQ(p.b_mask(), 0xFFu); +} + +TEST(FischlinParams, EMaxWorksCorrectly) { + coinbase::zk::fischlin_params_t p{128, 16, 4}; + EXPECT_EQ(p.e_max(), 16); +} + } // namespace \ No newline at end of file