Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 155 additions & 71 deletions test/acvp_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
#!/usr/bin/env python3
# Copyright (c) The mlkem-native project authors
# Copyright (c) The mldsa-native project authors
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT

# ACVP client for ML-DSA
#
# Processes 'internalProjection.json' files from
# https://github.com/usnistgov/ACVP-Server/blob/master/gen-val/json-files
#
# See https://pages.nist.gov/ACVP/draft-celi-acvp-ml-dsa.html and
# https://github.com/usnistgov/ACVP-Server/tree/master/gen-val/json-files
# Invokes `acvp_mldsa{lvl}` under the hood.

import argparse
Expand All @@ -20,7 +19,7 @@

# Optional command prefix used to launch the ACVP binaries through a wrapper
# (e.g. QEMU). EXEC_WRAPPER may hold several whitespace-separated words; each
# word becomes one argv element placed in front of the actual command.
# str.split() without a separator drops leading/trailing/repeated whitespace
# (so no empty argv elements are produced) and yields [] when the variable is
# unset or empty.
exec_prefix = os.environ.get("EXEC_WRAPPER", "").split()


def download_acvp_files(version="v1.1.0.40"):
Expand All @@ -29,9 +28,12 @@ def download_acvp_files(version="v1.1.0.40"):

# Files we need to download for ML-DSA
files_to_download = [
"ML-DSA-keyGen-FIPS204/internalProjection.json",
"ML-DSA-sigGen-FIPS204/internalProjection.json",
"ML-DSA-sigVer-FIPS204/internalProjection.json",
"ML-DSA-keyGen-FIPS204/prompt.json",
"ML-DSA-keyGen-FIPS204/expectedResults.json",
"ML-DSA-sigGen-FIPS204/prompt.json",
"ML-DSA-sigGen-FIPS204/expectedResults.json",
"ML-DSA-sigVer-FIPS204/prompt.json",
"ML-DSA-sigVer-FIPS204/expectedResults.json",
]

# Create directory structure
Expand Down Expand Up @@ -65,22 +67,36 @@ def download_acvp_files(version="v1.1.0.40"):
return True


def loadAcvpData(prompt, expectedResults=None):
    """Load an ACVP prompt file and, optionally, its expected results.

    Args:
        prompt: Path to the prompt file in JSON format.
        expectedResults: Path to the expectedResults file in JSON format,
            or None if no reference results are available.

    Returns:
        The 4-tuple (prompt, promptData, expectedResults, expectedResultsData),
        where the two *Data entries are the parsed JSON documents.
        expectedResultsData is None when expectedResults is None.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the locale's
    # default encoding.
    with open(prompt, "r", encoding="utf-8") as f:
        promptData = json.load(f)

    expectedResultsData = None
    if expectedResults is not None:
        with open(expectedResults, "r", encoding="utf-8") as f:
            expectedResultsData = json.load(f)

    return (prompt, promptData, expectedResults, expectedResultsData)


def loadDefaultAcvpData(version="v1.1.0.40"):
    """Load the prompt/expectedResults pairs for all three ML-DSA modes.

    The files are read from test/.acvp-data/<version>/files (a path relative
    to the current working directory), one sub-directory per mode.

    Args:
        version: Version string selecting the test-vector directory.

    Returns:
        A list with one loadAcvpData() result per mode, in the order
        keyGen, sigGen, sigVer.
    """
    data_dir = f"test/.acvp-data/{version}/files"

    acvp_data = []
    for mode in ("keyGen", "sigGen", "sigVer"):
        mode_dir = f"{data_dir}/ML-DSA-{mode}-FIPS204"
        acvp_data.append(
            loadAcvpData(
                f"{mode_dir}/prompt.json",
                f"{mode_dir}/expectedResults.json",
            )
        )
    return acvp_data


Expand All @@ -107,6 +123,8 @@ def get_acvp_binary(tg):

def run_keyGen_test(tg, tc):
info(f"Running keyGen test case {tc['tcId']} ... ", end="")

results = {"tcId": tc["tcId"]}
acvp_bin = get_acvp_binary(tg)
assert tg["testType"] == "AFT"
acvp_call = exec_prefix + [
Expand All @@ -120,14 +138,12 @@ def run_keyGen_test(tg, tc):
err(f"{acvp_call} failed with error code {result.returncode}")
err(result.stderr)
exit(1)
# Extract results and compare to expected data
# Extract results
for l in result.stdout.splitlines():
(k, v) = l.split("=")
if v != tc[k]:
err("FAIL!")
err(f"Mismatching result for {k}: expected {tc[k]}, got {v}")
exit(1)
info("OK")
results[k] = v
info("done")
return results


def compute_hash(msg, alg):
Expand Down Expand Up @@ -163,13 +179,13 @@ def compute_hash(msg, alg):

def run_sigGen_test(tg, tc):
info(f"Running sigGen test case {tc['tcId']} ... ", end="")
results = {"tcId": tc["tcId"]}
acvp_bin = get_acvp_binary(tg)

assert tg["testType"] == "AFT"

is_deterministic = tg["deterministic"] is True

if tg["preHash"] == "preHash":
if "preHash" in tg and tg["preHash"] == "preHash":
assert len(tc["context"]) <= 2 * 255

# Use specialized SHAKE256 function that computes hash internally
Expand Down Expand Up @@ -200,7 +216,7 @@ def run_sigGen_test(tg, tc):
f"hashAlg={tc['hashAlg']}",
]
elif tg["signatureInterface"] == "external":
assert tc["hashAlg"] == "none"
assert "hashAlg" not in tc or tc["hashAlg"] == "none"
assert len(tc["context"]) <= 2 * 255
assert len(tc["message"]) <= 2 * 65536

Expand All @@ -213,7 +229,7 @@ def run_sigGen_test(tg, tc):
f"context={tc['context']}",
]
else: # signatureInterface=internal
assert tc["hashAlg"] == "none"
assert "hashAlg" not in tc or tc["hashAlg"] == "none"
externalMu = 0
if tg["externalMu"] is True:
externalMu = 1
Expand Down Expand Up @@ -242,21 +258,20 @@ def run_sigGen_test(tg, tc):
err(f"{acvp_call} failed with error code {result.returncode}")
err(result.stderr)
exit(1)
# Extract results and compare to expected data
# Extract results
for l in result.stdout.splitlines():
(k, v) = l.split("=")
if v != tc[k]:
err("FAIL!")
err(f"Mismatching result for {k}: expected {tc[k]}, got {v}")
exit(1)
info("OK")
results[k] = v
info("done")
return results


def run_sigVer_test(tg, tc):
info(f"Running sigVer test case {tc['tcId']} ... ", end="")
results = {"tcId": tc["tcId"]}
acvp_bin = get_acvp_binary(tg)

if tg["preHash"] == "preHash":
if "preHash" in tg and tg["preHash"] == "preHash":
assert len(tc["context"]) <= 2 * 255

# Use specialized SHAKE256 function that computes hash internally
Expand All @@ -281,7 +296,7 @@ def run_sigVer_test(tg, tc):
f"hashAlg={tc['hashAlg']}",
]
elif tg["signatureInterface"] == "external":
assert tc["hashAlg"] == "none"
assert "hashAlg" not in tc or tc["hashAlg"] == "none"
assert len(tc["context"]) <= 2 * 255
assert len(tc["message"]) <= 2 * 65536

Expand All @@ -294,7 +309,7 @@ def run_sigVer_test(tg, tc):
f"pk={tc['pk']}",
]
else: # signatureInterface=internal
assert tc["hashAlg"] == "none"
assert "hashAlg" not in tc or tc["hashAlg"] == "none"
externalMu = 0
if tg["externalMu"] is True:
externalMu = 1
Expand All @@ -314,61 +329,127 @@ def run_sigVer_test(tg, tc):
]

result = subprocess.run(acvp_call, encoding="utf-8", capture_output=True)

if (result.returncode == 0) != tc["testPassed"]:
err("FAIL!")
err(
f"Mismatching verification result: expected {tc['testPassed']}, got {result.returncode == 0}"
)
exit(1)
info("OK")


def runTestSingle(internalProjectionName, internalProjection):
info(f"Running ACVP tests for {internalProjectionName}")

assert internalProjection["algorithm"] == "ML-DSA"
# Extract results
results["testPassed"] = result.returncode == 0
info("done")
return results


def runTestSingle(promptName, prompt, expectedResultName, expectedResult, output):
    """Run all test cases of one ACVP prompt, then validate and/or store them.

    Args:
        promptName: Name of the prompt; only used in log messages.
        prompt: Parsed prompt JSON. Either the prompt dictionary itself or
            the two-element ACVTS list form (see comment below).
        expectedResultName: Name of the expected results; only used in log
            messages.
        expectedResult: Parsed expected results to compare against, or None
            to skip the comparison.
        output: Path of the file the results are written to as JSON, or None
            to not write them.

    At least one of expectedResult and output must be given. If the computed
    results differ from expectedResult, the process exits with status 1.
    """
    info(f"Running ACVP tests for {promptName}")

    assert expectedResult is not None or output is not None

    # The ACVTS data structure is very slightly different from the sample files
    # in the usnistgov/ACVP-Server Github repository:
    # The prompt consists of a 2-element list, where the first element is
    # solely consisting of {"acvVersion": "1.0"} and the second element is
    # the usual prompt containing the test values.
    # See https://pages.nist.gov/ACVP/draft-celi-acvp-ml-dsa.txt for details.
    # We automatically detect that case here and extract the second element
    isAcvts = isinstance(prompt, list)
    if isAcvts:
        assert len(prompt) == 2
        acvVersion = prompt[0]
        assert len(acvVersion) == 1
        prompt = prompt[1]

    assert prompt["algorithm"] == "ML-DSA"
    assert prompt["mode"] in ("keyGen", "sigGen", "sigVer")

    # One runner per supported mode; each takes (test group, test case) and
    # returns the result dictionary for that test case.
    run_test_case = {
        "keyGen": run_keyGen_test,
        "sigGen": run_sigGen_test,
        "sigVer": run_sigVer_test,
    }[prompt["mode"]]

    # copy top level fields into the results
    results = prompt.copy()

    results["testGroups"] = []
    for tg in prompt["testGroups"]:
        tgResult = {
            "tgId": tg["tgId"],
            "tests": [run_test_case(tg, tc) for tc in tg["tests"]],
        }
        results["testGroups"].append(tgResult)

    # In case the testvectors are from the ACVTS server, it is expected
    # that the acvVersion is included in the output results.
    # See note on ACVTS data structure above.
    if isAcvts:
        results = [acvVersion, results]

    # Compare to expected results
    if expectedResult is not None:
        info(f"Comparing results with {expectedResultName}")
        # json.dumps() is guaranteed to preserve insertion order (since Python 3.7)
        # Enforce strictly the same order as in the expected Result
        if json.dumps(results) != json.dumps(expectedResult):
            err("FAIL!")
            err(f"Mismatching result for {promptName}")
            sys.exit(1)
        info("OK")
    else:
        info(
            "Results could not be validated as no expected results were provided to --expected"
        )

    # Write results to file
    if output is not None:
        info(f"Writing results to {output}")
        with open(output, "w", encoding="utf-8") as f:
            json.dump(results, f)


def runTest(data, output):
    """Run the ACVP tests for every loaded prompt.

    Args:
        data: List of 4-tuples
            (promptName, prompt, expectedResultName, expectedResult), as
            returned by loadAcvpData().
        output: Path of the file to write results to, or None. Only allowed
            when data holds exactly one entry.
    """
    # if output is defined we expect only one input
    assert output is None or len(data) == 1

    for promptName, prompt, expectedResultName, expectedResult in data:
        runTestSingle(promptName, prompt, expectedResultName, expectedResult, output)
    info("ALL GOOD!")


def test(prompt, expected, output, version="v1.1.0.40"):
    """Run the ACVP tests on a user-supplied prompt or on the default vectors.

    Args:
        prompt: Path to a prompt file in JSON format, or None to use the
            downloaded default test vectors for `version`.
        expected: Path to the matching expectedResults file, or None.
        output: Path of the file to write the results to, or None.
        version: Version of the default test vectors; only used when prompt
            is None.

    Raises:
        AssertionError: If output is given without a prompt, or if a prompt
            is given with neither expected nor output.
    """
    assert (
        prompt is not None or output is None
    ), "cannot produce output if there is no input"

    assert prompt is None or (
        output is not None or expected is not None
    ), "if there is a prompt, either output or expectedResult required"

    # if prompt is passed, use it
    if prompt is not None:
        data = [loadAcvpData(prompt, expected)]
    else:
        # load data from downloaded files
        data = loadDefaultAcvpData(version)

    runTest(data, output)

# Command-line interface: all three paths are optional. Without --prompt the
# default test vectors are used.
parser = argparse.ArgumentParser(description="ACVP client for ML-DSA")
parser.add_argument(
    "-p", "--prompt", help="Path to prompt file in json format", required=False
)
parser.add_argument(
    "-e",
    "--expected",
    help="Path to expectedResults file in json format",
    required=False,
)
parser.add_argument(
    "-o", "--output", help="Path to output file in json format", required=False
)
parser.add_argument(
"--version",
"-v",
Expand All @@ -377,9 +458,12 @@ def test(version="v1.1.0.40"):
)
args = parser.parse_args()

# The default test vectors are only needed (and only downloaded) when the
# user did not supply a prompt file.
if args.prompt is None:
    print(f"Using ACVP test vectors version {args.version}", file=sys.stderr)

    # Download files if needed
    if not download_acvp_files(args.version):
        print("Failed to download ACVP test files", file=sys.stderr)
        sys.exit(1)

test(args.prompt, args.expected, args.output, args.version)