Skip to content

Commit

Permalink
Improve zkey loading (#27)
Browse files Browse the repository at this point in the history
Co-authored-by: nixw <>
olomix authored Dec 10, 2024
1 parent 38c832a commit af527f9
Showing 12 changed files with 630 additions and 323 deletions.
2 changes: 1 addition & 1 deletion depends/ffiasm
Submodule ffiasm updated 5 files
+5 −4 c/fft.cpp
+2 −0 c/fft.hpp
+31 −24 c/misc.hpp
+37 −36 c/msm.cpp
+3 −1 c/msm.hpp
131 changes: 52 additions & 79 deletions src/binfile_utils.cpp
Original file line number Diff line number Diff line change
@@ -12,92 +12,49 @@

namespace BinFileUtils {

BinFile::BinFile(const std::string& fileName, const std::string& _type, uint32_t maxVersion) {
BinFile::BinFile(const std::string& fileName, const std::string& _type, uint32_t maxVersion)
: fileLoader(fileName)
{
addr = fileLoader.dataBuffer();
size = fileLoader.dataSize();

is_fd = true;
struct stat sb;

fd = open(fileName.c_str(), O_RDONLY);
if (fd == -1)
throw std::system_error(errno, std::generic_category(), "open");


if (fstat(fd, &sb) == -1) { /* To obtain file size */
close(fd);
throw std::system_error(errno, std::generic_category(), "fstat");
}

size = sb.st_size;

addr = mmap(nullptr, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
if (addr == MAP_FAILED) {
close(fd);
throw std::system_error(errno, std::generic_category(), "mmap failed");
}
madvise(addr, size, MADV_SEQUENTIAL);

type.assign((const char *)addr, 4);
pos = 4;

if (type != _type) {
munmap(addr, size);
close(fd);
throw std::invalid_argument("Invalid file type. It should be " + _type + " and it is " + type + " filename: " + fileName);
}

version = readU32LE();
if (version > maxVersion) {
munmap(addr, size);
close(fd);
throw std::invalid_argument("Invalid version. It should be <=" + std::to_string(maxVersion) + " and it is " + std::to_string(version));
}

u_int32_t nSections = readU32LE();


for (u_int32_t i=0; i<nSections; i++) {
u_int32_t sType=readU32LE();
u_int64_t sSize=readU64LE();

if (sections.find(sType) == sections.end()) {
sections.insert(std::make_pair(sType, std::vector<Section>()));
}
readFileData(_type, maxVersion);
}

sections[sType].push_back(Section( (void *)((u_int64_t)addr + pos), sSize));
BinFile::BinFile(const void *fileData, size_t fileSize, std::string _type, uint32_t maxVersion) {

pos += sSize;
}
addr = fileData;
size = fileSize;

pos = 0;
readingSection = nullptr;
readFileData(_type, maxVersion);
}

void BinFile::readFileData(std::string _type, uint32_t maxVersion) {

BinFile::BinFile(const void *fileData, size_t fileSize, std::string _type, uint32_t maxVersion) {
const u_int64_t headerSize = 12;
const u_int64_t minSectionSize = 12;

is_fd = false;
fd = -1;

size = fileSize;
addr = malloc(size);
memcpy(addr, fileData, size);
if (size < headerSize) {
throw std::range_error("File is too short.");
}

type.assign((const char *)addr, 4);
pos = 4;

if (type != _type) {
free(addr);
throw std::invalid_argument("Invalid file type. It should be " + _type + " and it is " + type);
}

version = readU32LE();
if (version > maxVersion) {
free(addr);
throw std::invalid_argument("Invalid version. It should be <=" + std::to_string(maxVersion) + " and it is " + std::to_string(version));
}

u_int32_t nSections = readU32LE();

if (size < headerSize + nSections * minSectionSize) {
throw std::range_error("File is too short to contain " + std::to_string(nSections) + " sections.");
}

for (u_int32_t i=0; i<nSections; i++) {
u_int32_t sType=readU32LE();
@@ -110,20 +67,18 @@ BinFile::BinFile(const void *fileData, size_t fileSize, std::string _type, uint3
sections[sType].push_back(Section( (void *)((u_int64_t)addr + pos), sSize));

pos += sSize;

if (pos > size) {
throw std::range_error("Section #" + std::to_string(i) + " is invalid."
". It ends at pos " + std::to_string(pos) +
" but should end before " + std::to_string(size) + ".");
}
}

pos = 0;
readingSection = NULL;
readingSection = nullptr;
}

BinFile::~BinFile() {
if (is_fd) {
munmap(addr, size);
close(fd);
} else {
free(addr);
}
}

void BinFile::startReadSection(u_int32_t sectionId, u_int32_t sectionPos) {

@@ -135,7 +90,7 @@ void BinFile::startReadSection(u_int32_t sectionId, u_int32_t sectionPos) {
throw std::range_error("Section pos too big. There are " + std::to_string(sections[sectionId].size()) + " and it's trying to access section: " + std::to_string(sectionPos));
}

if (readingSection != NULL) {
if (readingSection != nullptr) {
throw std::range_error("Already reading a section");
}

@@ -150,7 +105,7 @@ void BinFile::endReadSection(bool check) {
throw std::range_error("Invalid section size");
}
}
readingSection = NULL;
readingSection = nullptr;
}

void *BinFile::getSectionData(u_int32_t sectionId, u_int32_t sectionPos) {
@@ -169,31 +124,49 @@ void *BinFile::getSectionData(u_int32_t sectionId, u_int32_t sectionPos) {
u_int64_t BinFile::getSectionSize(u_int32_t sectionId, u_int32_t sectionPos) {

if (sections.find(sectionId) == sections.end()) {
throw new std::range_error("Section does not exist: " + std::to_string(sectionId));
throw std::range_error("Section does not exist: " + std::to_string(sectionId));
}

if (sectionPos >= sections[sectionId].size()) {
throw new std::range_error("Section pos too big. There are " + std::to_string(sections[sectionId].size()) + " and it's trying to access section: " + std::to_string(sectionPos));
throw std::range_error("Section pos too big. There are " + std::to_string(sections[sectionId].size()) + " and it's trying to access section: " + std::to_string(sectionPos));
}

return sections[sectionId][sectionPos].size;
}

u_int32_t BinFile::readU32LE() {
const u_int64_t new_pos = pos + 4;

if (new_pos > size) {
throw std::range_error("File pos is too big. There are " + std::to_string(size) + " bytes and it's trying to access byte " + std::to_string(new_pos));
}

u_int32_t res = *((u_int32_t *)((u_int64_t)addr + pos));
pos += 4;
pos = new_pos;
return res;
}

u_int64_t BinFile::readU64LE() {
const u_int64_t new_pos = pos + 8;

if (new_pos > size) {
throw std::range_error("File pos is too big. There are " + std::to_string(size) + " bytes and it's trying to access byte " + std::to_string(new_pos));
}

u_int64_t res = *((u_int64_t *)((u_int64_t)addr + pos));
pos += 8;
pos = new_pos;
return res;
}

void *BinFile::read(u_int64_t len) {
const u_int64_t new_pos = pos + len;

if (new_pos > size) {
throw std::range_error("File pos is too big. There are " + std::to_string(size) + " bytes and it's trying to access byte " + std::to_string(new_pos));
}

void *res = (void *)((u_int64_t)addr + pos);
pos += len;
pos = new_pos;
return res;
}

8 changes: 4 additions & 4 deletions src/binfile_utils.hpp
Original file line number Diff line number Diff line change
@@ -4,15 +4,15 @@
#include <map>
#include <vector>
#include <memory>
#include "fileloader.hpp"

namespace BinFileUtils {

class BinFile {

bool is_fd;
int fd;
FileLoader fileLoader;

void *addr;
const void *addr;
u_int64_t size;
u_int64_t pos;

@@ -32,14 +32,14 @@ namespace BinFileUtils {

Section *readingSection;

void readFileData(std::string _type, uint32_t maxVersion);

public:

BinFile(const void *fileData, size_t fileSize, std::string _type, uint32_t maxVersion);
BinFile(const std::string& fileName, const std::string& _type, uint32_t maxVersion);
BinFile(const BinFile&) = delete;
BinFile& operator=(const BinFile&) = delete;
~BinFile();

void startReadSection(u_int32_t sectionId, u_int32_t setionPos = 0);
void endReadSection(bool check = true);
28 changes: 26 additions & 2 deletions src/fileloader.cpp
Original file line number Diff line number Diff line change
@@ -9,8 +9,23 @@

namespace BinFileUtils {

FileLoader::FileLoader()
: fd(-1)
{
}

FileLoader::FileLoader(const std::string& fileName)
: fd(-1)
{
load(fileName);
}

void FileLoader::load(const std::string& fileName)
{
if (fd != -1) {
throw std::invalid_argument("file already loaded");
}

struct stat sb;

fd = open(fileName.c_str(), O_RDONLY);
@@ -26,12 +41,21 @@ FileLoader::FileLoader(const std::string& fileName)
size = sb.st_size;

addr = mmap(nullptr, size, PROT_READ, MAP_PRIVATE, fd, 0);

if (addr == MAP_FAILED) {
close(fd);
throw std::system_error(errno, std::generic_category(), "mmap failed");
}

madvise(addr, size, MADV_SEQUENTIAL);
}

FileLoader::~FileLoader()
{
munmap(addr, size);
close(fd);
if (fd != -1) {
munmap(addr, size);
close(fd);
}
}

} // Namespace
3 changes: 3 additions & 0 deletions src/fileloader.hpp
Original file line number Diff line number Diff line change
@@ -9,9 +9,12 @@ namespace BinFileUtils {
class FileLoader
{
public:
FileLoader();
FileLoader(const std::string& fileName);
~FileLoader();

void load(const std::string& fileName);

void* dataBuffer() { return addr; }
size_t dataSize() const { return size; }

15 changes: 8 additions & 7 deletions src/groth16.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "random_generator.hpp"
#include "logging.hpp"
#include "misc.hpp"
#include <sstream>
#include <vector>
#include <mutex>

@@ -84,7 +85,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
auto b = new typename Engine::FrElement[domainSize];
auto c = new typename Engine::FrElement[domainSize];

threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
for (u_int32_t i=begin; i<end; i++) {
E.fr.copy(a[i], E.fr.zero());
E.fr.copy(b[i], E.fr.zero());
@@ -96,7 +97,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
#define NLOCKS 1024
std::vector<std::mutex> locks(NLOCKS);

threadPool.parallelFor(0, nCoefs, [&] (int begin, int end, int numThread) {
threadPool.parallelFor(0, nCoefs, [&] (int64_t begin, int64_t end, uint64_t idThread) {
for (u_int64_t i=begin; i<end; i++) {
typename Engine::FrElement *ab = (coefs[i].m == 0) ? a : b;
typename Engine::FrElement aux;
@@ -117,7 +118,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
}
});
LOG_TRACE("Calculating c");
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(
c[i],
@@ -137,7 +138,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(a[1]).c_str());
LOG_TRACE("Start Shift A");

threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(a[i], a[i], fft->root(domainPower+1, i));
}
@@ -157,7 +158,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(b[0]).c_str());
LOG_DEBUG(E.fr.toString(b[1]).c_str());
LOG_TRACE("Start Shift B");
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(b[i], b[i], fft->root(domainPower+1, i));
}
@@ -177,7 +178,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(c[0]).c_str());
LOG_DEBUG(E.fr.toString(c[1]).c_str());
LOG_TRACE("Start Shift C");
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(c[i], c[i], fft->root(domainPower+1, i));
}
@@ -192,7 +193,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(c[1]).c_str());

LOG_TRACE("Start ABC");
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(a[i], a[i], b[i]);
E.fr.sub(a[i], a[i], c[i]);
133 changes: 53 additions & 80 deletions src/main_prover.cpp
Original file line number Diff line number Diff line change
@@ -1,103 +1,76 @@
#include <iostream>
#include <fstream>
#include <gmp.h>
#include <memory>
#include <string>
#include <vector>
#include <stdexcept>
#include <nlohmann/json.hpp>

#include <alt_bn128.hpp>
#include "binfile_utils.hpp"
#include "zkey_utils.hpp"
#include "wtns_utils.hpp"
#include "groth16.hpp"

using json = nlohmann::json;

#define handle_error(msg) \
do { perror(msg); exit(EXIT_FAILURE); } while (0)
#include "prover.h"
#include "fileloader.hpp"

int main(int argc, char **argv)
{
if (argc != 5) {
std::cerr << "Invalid number of parameters:\n";
std::cerr << "Usage: prover <circuit.zkey> <witness.wtns> <proof.json> <public.json>\n";
std::cerr << "Invalid number of parameters" << std::endl;
std::cerr << "Usage: prover <circuit.zkey> <witness.wtns> <proof.json> <public.json>" << std::endl;
return EXIT_FAILURE;
}

mpz_t altBbn128r;

mpz_init(altBbn128r);
mpz_set_str(altBbn128r, "21888242871839275222246405745257275088548364400416034343698204186575808495617", 10);

try {
std::string zkeyFilename = argv[1];
std::string wtnsFilename = argv[2];
std::string proofFilename = argv[3];
std::string publicFilename = argv[4];

auto zkey = BinFileUtils::openExisting(zkeyFilename, "zkey", 1);
auto zkeyHeader = ZKeyUtils::loadHeader(zkey.get());

std::string proofStr;
if (mpz_cmp(zkeyHeader->rPrime, altBbn128r) != 0) {
throw std::invalid_argument( "zkey curve not supported" );
const std::string zkeyFilename = argv[1];
const std::string wtnsFilename = argv[2];
const std::string proofFilename = argv[3];
const std::string publicFilename = argv[4];

BinFileUtils::FileLoader zkeyFile(zkeyFilename);
BinFileUtils::FileLoader wtnsFile(wtnsFilename);
std::vector<char> publicBuffer;
std::vector<char> proofBuffer;
unsigned long long publicSize = 0;
unsigned long long proofSize = 0;
char errorMsg[1024];

int error = groth16_public_size_for_zkey_buf(
zkeyFile.dataBuffer(),
zkeyFile.dataSize(),
&publicSize,
errorMsg,
sizeof(errorMsg));

if (error != PROVER_OK) {
throw std::runtime_error(errorMsg);
}

auto wtns = BinFileUtils::openExisting(wtnsFilename, "wtns", 2);
auto wtnsHeader = WtnsUtils::loadHeader(wtns.get());

if (mpz_cmp(wtnsHeader->prime, altBbn128r) != 0) {
throw std::invalid_argument( "different wtns curve" );
groth16_proof_size(&proofSize);

publicBuffer.resize(publicSize);
proofBuffer.resize(proofSize);

error = groth16_prover(
zkeyFile.dataBuffer(),
zkeyFile.dataSize(),
wtnsFile.dataBuffer(),
wtnsFile.dataSize(),
proofBuffer.data(),
&proofSize,
publicBuffer.data(),
&publicSize,
errorMsg,
sizeof(errorMsg));

if (error != PROVER_OK) {
throw std::runtime_error(errorMsg);
}

auto prover = Groth16::makeProver<AltBn128::Engine>(
zkeyHeader->nVars,
zkeyHeader->nPublic,
zkeyHeader->domainSize,
zkeyHeader->nCoefs,
zkeyHeader->vk_alpha1,
zkeyHeader->vk_beta1,
zkeyHeader->vk_beta2,
zkeyHeader->vk_delta1,
zkeyHeader->vk_delta2,
zkey->getSectionData(4), // Coefs
zkey->getSectionData(5), // pointsA
zkey->getSectionData(6), // pointsB1
zkey->getSectionData(7), // pointsB2
zkey->getSectionData(8), // pointsC
zkey->getSectionData(9) // pointsH1
);
AltBn128::FrElement *wtnsData = (AltBn128::FrElement *)wtns->getSectionData(2);
auto proof = prover->prove(wtnsData);

std::ofstream proofFile;
proofFile.open (proofFilename);
proofFile << proof->toJson();
proofFile.close();
std::ofstream proofFile(proofFilename);
proofFile.write(proofBuffer.data(), proofSize);

std::ofstream publicFile;
publicFile.open (publicFilename);

json jsonPublic;
AltBn128::FrElement aux;
for (int i=1; i<=zkeyHeader->nPublic; i++) {
AltBn128::Fr.toMontgomery(aux, wtnsData[i]);
jsonPublic.push_back(AltBn128::Fr.toString(aux));
}
std::ofstream publicFile(publicFilename);
publicFile.write(publicBuffer.data(), publicSize);

publicFile << jsonPublic;
publicFile.close();

} catch (std::exception* e) {
mpz_clear(altBbn128r);
std::cerr << e->what() << '\n';
return EXIT_FAILURE;
} catch (std::exception& e) {
mpz_clear(altBbn128r);
std::cerr << e.what() << '\n';
std::cerr << "Error: " << e.what() << std::endl;
return EXIT_FAILURE;

}

mpz_clear(altBbn128r);
exit(EXIT_SUCCESS);
}
488 changes: 372 additions & 116 deletions src/prover.cpp

Large diffs are not rendered by default.

113 changes: 94 additions & 19 deletions src/prover.h
Original file line number Diff line number Diff line change
@@ -16,9 +16,12 @@ extern "C" {
* @returns PROVER_OK in case of success, and the size of public buffer is written to public_size
*/
int
groth16_public_size_for_zkey_buf(const void *zkey_buffer, unsigned long zkey_size,
size_t *public_size,
char *error_msg, unsigned long error_msg_maxsize);
groth16_public_size_for_zkey_buf(
const void *zkey_buffer,
unsigned long long zkey_size,
unsigned long long *public_size,
char *error_msg,
unsigned long long error_msg_maxsize);

/**
* groth16_public_size_for_zkey_file calculates minimum buffer size for
@@ -30,37 +33,109 @@ groth16_public_size_for_zkey_buf(const void *zkey_buffer, unsigned long zkey_siz
* PROVER_ERROR - in case of an error, error_msg contains the error message
*/
int
groth16_public_size_for_zkey_file(const char *zkey_fname,
unsigned long *public_size,
char *error_msg, unsigned long error_msg_maxsize);
groth16_public_size_for_zkey_file(
const char *zkey_fname,
unsigned long long *public_size,
char *error_msg,
unsigned long long error_msg_maxsize);

/**
* groth16_prover
* Returns buffer size to output proof as json string
*/
void
groth16_proof_size(
unsigned long long *proof_size);

/**
* Initializes 'prover_object' with a pointer to a new prover object.
* @return error code:
* PROVER_OK - in case of success
* PPOVER_ERROR - in case of an error
*/
int
groth16_prover_create(
void **prover_object,
const void *zkey_buffer,
unsigned long long zkey_size,
char *error_msg,
unsigned long long error_msg_maxsize);

/**
* Initializes 'prover_object' with a pointer to a new prover object.
* @return error code:
* PROVER_OK - in case of success
* PPOVER_ERROR - in case of an error
*/
int
groth16_prover_create_zkey_file(
void **prover_object,
const char *zkey_file_path,
char *error_msg,
unsigned long long error_msg_maxsize);

/**
* Proves 'wtns_buffer' and saves results to 'proof_buffer' and 'public_buffer'.
* @return error code:
* PROVER_OK - in case of success
* PPOVER_ERROR - in case of an error
* PROVER_ERROR_SHORT_BUFFER - in case of a short buffer error, also updates proof_size and public_size with actual proof and public sizess
* PROVER_ERROR_SHORT_BUFFER - in case of a short buffer error, also updates proof_size and public_size with actual proof and public sizes
*/
int
groth16_prover(const void *zkey_buffer, unsigned long zkey_size,
const void *wtns_buffer, unsigned long wtns_size,
char *proof_buffer, unsigned long *proof_size,
char *public_buffer, unsigned long *public_size,
char *error_msg, unsigned long error_msg_maxsize);
groth16_prover_prove(
void *prover_object,
const void *wtns_buffer,
unsigned long long wtns_size,
char *proof_buffer,
unsigned long long *proof_size,
char *public_buffer,
unsigned long long *public_size,
char *error_msg,
unsigned long long error_msg_maxsize);

/**
* Destroys 'prover_object'.
*/
void
groth16_prover_destroy(void *prover_object);

/**
* groth16_prover
* @return error code:
* PROVER_OK - in case of success
* PPOVER_ERROR - in case of an error
* PROVER_ERROR_SHORT_BUFFER - in case of a short buffer error, also updates proof_size and public_size with actual proof and public sizess
* PROVER_ERROR_SHORT_BUFFER - in case of a short buffer error, also updates proof_size and public_size with actual proof and public sizes
*/
int
groth16_prover(
const void *zkey_buffer,
unsigned long long zkey_size,
const void *wtns_buffer,
unsigned long long wtns_size,
char *proof_buffer,
unsigned long long *proof_size,
char *public_buffer,
unsigned long long *public_size,
char *error_msg,
unsigned long long error_msg_maxsize);

/**
* groth16_prover_zkey_file
* @return error code:
* PROVER_OK - in case of success
* PPOVER_ERROR - in case of an error
* PROVER_ERROR_SHORT_BUFFER - in case of a short buffer error, also updates proof_size and public_size with actual proof and public sizes
*/
int
groth16_prover_zkey_file(const char *zkey_file_path,
const void *wtns_buffer, unsigned long wtns_size,
char *proof_buffer, unsigned long *proof_size,
char *public_buffer, unsigned long *public_size,
char *error_msg, unsigned long error_msg_maxsize);
groth16_prover_zkey_file(
const char *zkey_file_path,
const void *wtns_buffer,
unsigned long long wtns_size,
char *proof_buffer,
unsigned long long *proof_size,
char *public_buffer,
unsigned long long *public_size,
char *error_msg,
unsigned long long error_msg_maxsize);

#ifdef __cplusplus
}
14 changes: 7 additions & 7 deletions src/test_public_size.c
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@
#include "prover.h"

int
test_groth16_public_size(const char *zkey_fname, size_t *public_size) {
test_groth16_public_size(const char *zkey_fname, unsigned long long *public_size) {
int ret_val = 0;
const int error_sz = 256;
char error_msg[error_sz];
@@ -54,7 +54,7 @@ test_groth16_public_size(const char *zkey_fname, size_t *public_size) {

int ok = groth16_public_size_for_zkey_buf(buf, sb.st_size, public_size, error_msg, error_sz);
if (ok == 0) {
printf("Public size: %lu\n", *public_size);
printf("Public size: %llu\n", *public_size);
} else {
printf("Error: %s\n", error_msg);
ret_val = 1;
@@ -72,13 +72,13 @@ test_groth16_public_size(const char *zkey_fname, size_t *public_size) {

int
test_groth16_public_size_for_zkey_file(const char *zkey_fname,
size_t *public_size) {
unsigned long long *public_size) {
const int err_ln = 256;
char error_msg[err_ln];
int ret = groth16_public_size_for_zkey_file(zkey_fname, public_size, error_msg, err_ln);

if (ret == 0) {
printf("Public size: %lu\n", *public_size);
printf("Public size: %llu\n", *public_size);
} else {
printf("Error: %s\n", error_msg);
}
@@ -98,7 +98,7 @@ main(int argc, char *argv[]) {
int ret_val = 0;
clock_t start = clock();

size_t public_size = 0;
unsigned long long public_size = 0;

int test_groth16_public_size_ok =
test_groth16_public_size(argv[1], &public_size);
@@ -114,7 +114,7 @@ main(int argc, char *argv[]) {
if (public_size != want_pub_size) {
printf("test_groth16_public_size expected public signals buf size: %ld\n",
want_pub_size);
printf("test_groth16_public_size actual public signals buf size: %lu\n",
printf("test_groth16_public_size actual public signals buf size: %llu\n",
public_size);
ret_val = 1;
}
@@ -135,7 +135,7 @@ main(int argc, char *argv[]) {
if (public_size != want_pub_size) {
printf("test_groth16_public_size_for_zkey_file expected public signals buf size: %ld\n",
want_pub_size);
printf("test_groth16_public_size_for_zkey_file actual public signals buf size: %lu\n",
printf("test_groth16_public_size_for_zkey_file actual public signals buf size: %llu\n",
public_size);
ret_val = 1;
}
9 changes: 5 additions & 4 deletions src/wtns_utils.cpp
Original file line number Diff line number Diff line change
@@ -3,25 +3,26 @@
namespace WtnsUtils {

Header::Header() {
mpz_init(prime);
}

Header::~Header() {
mpz_clear(prime);
}

std::unique_ptr<Header> loadHeader(BinFileUtils::BinFile *f) {
Header *h = new Header();
std::unique_ptr<Header> h(new Header());

f->startReadSection(1);

h->n8 = f->readU32LE();
mpz_init(h->prime);
mpz_import(h->prime, h->n8, -1, 1, -1, 0, f->read(h->n8));

h->nVars = f->readU32LE();

f->endReadSection();

return std::unique_ptr<Header>(h);
return h;
}

} // NAMESPACE
} // NAMESPACE
9 changes: 5 additions & 4 deletions src/zkey_utils.cpp
Original file line number Diff line number Diff line change
@@ -6,6 +6,8 @@ namespace ZKeyUtils {


Header::Header() {
mpz_init(qPrime);
mpz_init(rPrime);
}

Header::~Header() {
@@ -15,7 +17,8 @@ Header::~Header() {


std::unique_ptr<Header> loadHeader(BinFileUtils::BinFile *f) {
auto h = new Header();

std::unique_ptr<Header> h(new Header());

f->startReadSection(1);
uint32_t protocol = f->readU32LE();
@@ -27,11 +30,9 @@ std::unique_ptr<Header> loadHeader(BinFileUtils::BinFile *f) {
f->startReadSection(2);

h->n8q = f->readU32LE();
mpz_init(h->qPrime);
mpz_import(h->qPrime, h->n8q, -1, 1, -1, 0, f->read(h->n8q));

h->n8r = f->readU32LE();
mpz_init(h->rPrime);
mpz_import(h->rPrime, h->n8r , -1, 1, -1, 0, f->read(h->n8r));

h->nVars = f->readU32LE();
@@ -48,7 +49,7 @@ std::unique_ptr<Header> loadHeader(BinFileUtils::BinFile *f) {

h->nCoefs = f->getSectionSize(4) / (12 + h->n8r);

return std::unique_ptr<Header>(h);
return h;
}

} // namespace

0 comments on commit af527f9

Please sign in to comment.