Skip to content

Commit af527f9

Browse files
authored
Improve zkey loading (#27)
Co-authored-by: nixw <>
1 parent 38c832a commit af527f9

12 files changed

+630
-323
lines changed

src/binfile_utils.cpp

Lines changed: 52 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -12,92 +12,49 @@
1212

1313
namespace BinFileUtils {
1414

15-
BinFile::BinFile(const std::string& fileName, const std::string& _type, uint32_t maxVersion) {
15+
BinFile::BinFile(const std::string& fileName, const std::string& _type, uint32_t maxVersion)
16+
: fileLoader(fileName)
17+
{
18+
addr = fileLoader.dataBuffer();
19+
size = fileLoader.dataSize();
1620

17-
is_fd = true;
18-
struct stat sb;
19-
20-
fd = open(fileName.c_str(), O_RDONLY);
21-
if (fd == -1)
22-
throw std::system_error(errno, std::generic_category(), "open");
23-
24-
25-
if (fstat(fd, &sb) == -1) { /* To obtain file size */
26-
close(fd);
27-
throw std::system_error(errno, std::generic_category(), "fstat");
28-
}
29-
30-
size = sb.st_size;
31-
32-
addr = mmap(nullptr, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
33-
if (addr == MAP_FAILED) {
34-
close(fd);
35-
throw std::system_error(errno, std::generic_category(), "mmap failed");
36-
}
37-
madvise(addr, size, MADV_SEQUENTIAL);
38-
39-
type.assign((const char *)addr, 4);
40-
pos = 4;
41-
42-
if (type != _type) {
43-
munmap(addr, size);
44-
close(fd);
45-
throw std::invalid_argument("Invalid file type. It should be " + _type + " and it is " + type + " filename: " + fileName);
46-
}
47-
48-
version = readU32LE();
49-
if (version > maxVersion) {
50-
munmap(addr, size);
51-
close(fd);
52-
throw std::invalid_argument("Invalid version. It should be <=" + std::to_string(maxVersion) + " and it is " + std::to_string(version));
53-
}
54-
55-
u_int32_t nSections = readU32LE();
56-
57-
58-
for (u_int32_t i=0; i<nSections; i++) {
59-
u_int32_t sType=readU32LE();
60-
u_int64_t sSize=readU64LE();
61-
62-
if (sections.find(sType) == sections.end()) {
63-
sections.insert(std::make_pair(sType, std::vector<Section>()));
64-
}
21+
readFileData(_type, maxVersion);
22+
}
6523

66-
sections[sType].push_back(Section( (void *)((u_int64_t)addr + pos), sSize));
24+
BinFile::BinFile(const void *fileData, size_t fileSize, std::string _type, uint32_t maxVersion) {
6725

68-
pos += sSize;
69-
}
26+
addr = fileData;
27+
size = fileSize;
7028

71-
pos = 0;
72-
readingSection = nullptr;
29+
readFileData(_type, maxVersion);
7330
}
7431

32+
void BinFile::readFileData(std::string _type, uint32_t maxVersion) {
7533

76-
BinFile::BinFile(const void *fileData, size_t fileSize, std::string _type, uint32_t maxVersion) {
34+
const u_int64_t headerSize = 12;
35+
const u_int64_t minSectionSize = 12;
7736

78-
is_fd = false;
79-
fd = -1;
80-
81-
size = fileSize;
82-
addr = malloc(size);
83-
memcpy(addr, fileData, size);
37+
if (size < headerSize) {
38+
throw std::range_error("File is too short.");
39+
}
8440

8541
type.assign((const char *)addr, 4);
8642
pos = 4;
8743

8844
if (type != _type) {
89-
free(addr);
9045
throw std::invalid_argument("Invalid file type. It should be " + _type + " and it is " + type);
9146
}
9247

9348
version = readU32LE();
9449
if (version > maxVersion) {
95-
free(addr);
9650
throw std::invalid_argument("Invalid version. It should be <=" + std::to_string(maxVersion) + " and it is " + std::to_string(version));
9751
}
9852

9953
u_int32_t nSections = readU32LE();
10054

55+
if (size < headerSize + nSections * minSectionSize) {
56+
throw std::range_error("File is too short to contain " + std::to_string(nSections) + " sections.");
57+
}
10158

10259
for (u_int32_t i=0; i<nSections; i++) {
10360
u_int32_t sType=readU32LE();
@@ -110,20 +67,18 @@ BinFile::BinFile(const void *fileData, size_t fileSize, std::string _type, uint3
11067
sections[sType].push_back(Section( (void *)((u_int64_t)addr + pos), sSize));
11168

11269
pos += sSize;
70+
71+
if (pos > size) {
72+
throw std::range_error("Section #" + std::to_string(i) + " is invalid."
73+
". It ends at pos " + std::to_string(pos) +
74+
" but should end before " + std::to_string(size) + ".");
75+
}
11376
}
11477

11578
pos = 0;
116-
readingSection = NULL;
79+
readingSection = nullptr;
11780
}
11881

119-
BinFile::~BinFile() {
120-
if (is_fd) {
121-
munmap(addr, size);
122-
close(fd);
123-
} else {
124-
free(addr);
125-
}
126-
}
12782

12883
void BinFile::startReadSection(u_int32_t sectionId, u_int32_t sectionPos) {
12984

@@ -135,7 +90,7 @@ void BinFile::startReadSection(u_int32_t sectionId, u_int32_t sectionPos) {
13590
throw std::range_error("Section pos too big. There are " + std::to_string(sections[sectionId].size()) + " and it's trying to access section: " + std::to_string(sectionPos));
13691
}
13792

138-
if (readingSection != NULL) {
93+
if (readingSection != nullptr) {
13994
throw std::range_error("Already reading a section");
14095
}
14196

@@ -150,7 +105,7 @@ void BinFile::endReadSection(bool check) {
150105
throw std::range_error("Invalid section size");
151106
}
152107
}
153-
readingSection = NULL;
108+
readingSection = nullptr;
154109
}
155110

156111
void *BinFile::getSectionData(u_int32_t sectionId, u_int32_t sectionPos) {
@@ -169,31 +124,49 @@ void *BinFile::getSectionData(u_int32_t sectionId, u_int32_t sectionPos) {
169124
u_int64_t BinFile::getSectionSize(u_int32_t sectionId, u_int32_t sectionPos) {
170125

171126
if (sections.find(sectionId) == sections.end()) {
172-
throw new std::range_error("Section does not exist: " + std::to_string(sectionId));
127+
throw std::range_error("Section does not exist: " + std::to_string(sectionId));
173128
}
174129

175130
if (sectionPos >= sections[sectionId].size()) {
176-
throw new std::range_error("Section pos too big. There are " + std::to_string(sections[sectionId].size()) + " and it's trying to access section: " + std::to_string(sectionPos));
131+
throw std::range_error("Section pos too big. There are " + std::to_string(sections[sectionId].size()) + " and it's trying to access section: " + std::to_string(sectionPos));
177132
}
178133

179134
return sections[sectionId][sectionPos].size;
180135
}
181136

182137
u_int32_t BinFile::readU32LE() {
138+
const u_int64_t new_pos = pos + 4;
139+
140+
if (new_pos > size) {
141+
throw std::range_error("File pos is too big. There are " + std::to_string(size) + " bytes and it's trying to access byte " + std::to_string(new_pos));
142+
}
143+
183144
u_int32_t res = *((u_int32_t *)((u_int64_t)addr + pos));
184-
pos += 4;
145+
pos = new_pos;
185146
return res;
186147
}
187148

188149
u_int64_t BinFile::readU64LE() {
150+
const u_int64_t new_pos = pos + 8;
151+
152+
if (new_pos > size) {
153+
throw std::range_error("File pos is too big. There are " + std::to_string(size) + " bytes and it's trying to access byte " + std::to_string(new_pos));
154+
}
155+
189156
u_int64_t res = *((u_int64_t *)((u_int64_t)addr + pos));
190-
pos += 8;
157+
pos = new_pos;
191158
return res;
192159
}
193160

194161
void *BinFile::read(u_int64_t len) {
162+
const u_int64_t new_pos = pos + len;
163+
164+
if (new_pos > size) {
165+
throw std::range_error("File pos is too big. There are " + std::to_string(size) + " bytes and it's trying to access byte " + std::to_string(new_pos));
166+
}
167+
195168
void *res = (void *)((u_int64_t)addr + pos);
196-
pos += len;
169+
pos = new_pos;
197170
return res;
198171
}
199172

src/binfile_utils.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
#include <map>
55
#include <vector>
66
#include <memory>
7+
#include "fileloader.hpp"
78

89
namespace BinFileUtils {
910

1011
class BinFile {
1112

12-
bool is_fd;
13-
int fd;
13+
FileLoader fileLoader;
1414

15-
void *addr;
15+
const void *addr;
1616
u_int64_t size;
1717
u_int64_t pos;
1818

@@ -32,14 +32,14 @@ namespace BinFileUtils {
3232

3333
Section *readingSection;
3434

35+
void readFileData(std::string _type, uint32_t maxVersion);
3536

3637
public:
3738

3839
BinFile(const void *fileData, size_t fileSize, std::string _type, uint32_t maxVersion);
3940
BinFile(const std::string& fileName, const std::string& _type, uint32_t maxVersion);
4041
BinFile(const BinFile&) = delete;
4142
BinFile& operator=(const BinFile&) = delete;
42-
~BinFile();
4343

4444
void startReadSection(u_int32_t sectionId, u_int32_t setionPos = 0);
4545
void endReadSection(bool check = true);

src/fileloader.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,23 @@
99

1010
namespace BinFileUtils {
1111

12+
FileLoader::FileLoader()
13+
: fd(-1)
14+
{
15+
}
16+
1217
FileLoader::FileLoader(const std::string& fileName)
18+
: fd(-1)
1319
{
20+
load(fileName);
21+
}
22+
23+
void FileLoader::load(const std::string& fileName)
24+
{
25+
if (fd != -1) {
26+
throw std::invalid_argument("file already loaded");
27+
}
28+
1429
struct stat sb;
1530

1631
fd = open(fileName.c_str(), O_RDONLY);
@@ -26,12 +41,21 @@ FileLoader::FileLoader(const std::string& fileName)
2641
size = sb.st_size;
2742

2843
addr = mmap(nullptr, size, PROT_READ, MAP_PRIVATE, fd, 0);
44+
45+
if (addr == MAP_FAILED) {
46+
close(fd);
47+
throw std::system_error(errno, std::generic_category(), "mmap failed");
48+
}
49+
50+
madvise(addr, size, MADV_SEQUENTIAL);
2951
}
3052

3153
FileLoader::~FileLoader()
3254
{
33-
munmap(addr, size);
34-
close(fd);
55+
if (fd != -1) {
56+
munmap(addr, size);
57+
close(fd);
58+
}
3559
}
3660

3761
} // Namespace

src/fileloader.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@ namespace BinFileUtils {
99
class FileLoader
1010
{
1111
public:
12+
FileLoader();
1213
FileLoader(const std::string& fileName);
1314
~FileLoader();
1415

16+
void load(const std::string& fileName);
17+
1518
void* dataBuffer() { return addr; }
1619
size_t dataSize() const { return size; }
1720

src/groth16.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "random_generator.hpp"
22
#include "logging.hpp"
33
#include "misc.hpp"
4+
#include <sstream>
45
#include <vector>
56
#include <mutex>
67

@@ -84,7 +85,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
8485
auto b = new typename Engine::FrElement[domainSize];
8586
auto c = new typename Engine::FrElement[domainSize];
8687

87-
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
88+
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
8889
for (u_int32_t i=begin; i<end; i++) {
8990
E.fr.copy(a[i], E.fr.zero());
9091
E.fr.copy(b[i], E.fr.zero());
@@ -96,7 +97,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
9697
#define NLOCKS 1024
9798
std::vector<std::mutex> locks(NLOCKS);
9899

99-
threadPool.parallelFor(0, nCoefs, [&] (int begin, int end, int numThread) {
100+
threadPool.parallelFor(0, nCoefs, [&] (int64_t begin, int64_t end, uint64_t idThread) {
100101
for (u_int64_t i=begin; i<end; i++) {
101102
typename Engine::FrElement *ab = (coefs[i].m == 0) ? a : b;
102103
typename Engine::FrElement aux;
@@ -117,7 +118,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
117118
}
118119
});
119120
LOG_TRACE("Calculating c");
120-
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
121+
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
121122
for (u_int64_t i=begin; i<end; i++) {
122123
E.fr.mul(
123124
c[i],
@@ -137,7 +138,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
137138
LOG_DEBUG(E.fr.toString(a[1]).c_str());
138139
LOG_TRACE("Start Shift A");
139140

140-
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
141+
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
141142
for (u_int64_t i=begin; i<end; i++) {
142143
E.fr.mul(a[i], a[i], fft->root(domainPower+1, i));
143144
}
@@ -157,7 +158,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
157158
LOG_DEBUG(E.fr.toString(b[0]).c_str());
158159
LOG_DEBUG(E.fr.toString(b[1]).c_str());
159160
LOG_TRACE("Start Shift B");
160-
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
161+
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
161162
for (u_int64_t i=begin; i<end; i++) {
162163
E.fr.mul(b[i], b[i], fft->root(domainPower+1, i));
163164
}
@@ -177,7 +178,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
177178
LOG_DEBUG(E.fr.toString(c[0]).c_str());
178179
LOG_DEBUG(E.fr.toString(c[1]).c_str());
179180
LOG_TRACE("Start Shift C");
180-
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
181+
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
181182
for (u_int64_t i=begin; i<end; i++) {
182183
E.fr.mul(c[i], c[i], fft->root(domainPower+1, i));
183184
}
@@ -192,7 +193,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
192193
LOG_DEBUG(E.fr.toString(c[1]).c_str());
193194

194195
LOG_TRACE("Start ABC");
195-
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
196+
threadPool.parallelFor(0, domainSize, [&] (int64_t begin, int64_t end, uint64_t idThread) {
196197
for (u_int64_t i=begin; i<end; i++) {
197198
E.fr.mul(a[i], a[i], b[i]);
198199
E.fr.sub(a[i], a[i], c[i]);

0 commit comments

Comments
 (0)