Skip to content

Commit 441dc14

Browse files
Amxxernestognw
andauthored
Add Bytes32x2Set (OpenZeppelin#5442)
Co-authored-by: Ernesto García <[email protected]>
1 parent 2141d3f commit 441dc14

File tree

7 files changed

+261
-12
lines changed

7 files changed

+261
-12
lines changed

.changeset/lucky-teachers-sip.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'openzeppelin-solidity': minor
3+
---
4+
5+
`EnumerableSet`: Add `Bytes32x2Set` that handles (ordered) pairs of bytes32.

.changeset/ten-peas-mix.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'openzeppelin-solidity': minor
3+
---
4+
5+
`Hashes`: Expose `efficientKeccak256` for hashing non-commutative pairs of bytes32 without allocating extra memory.

contracts/utils/cryptography/Hashes.sol

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ library Hashes {
1515
* NOTE: Equivalent to the `standardNodeHash` in our https://github.com/OpenZeppelin/merkle-tree[JavaScript library].
1616
*/
1717
function commutativeKeccak256(bytes32 a, bytes32 b) internal pure returns (bytes32) {
18-
return a < b ? _efficientKeccak256(a, b) : _efficientKeccak256(b, a);
18+
return a < b ? efficientKeccak256(a, b) : efficientKeccak256(b, a);
1919
}
2020

2121
/**
2222
* @dev Implementation of keccak256(abi.encode(a, b)) that doesn't allocate or expand memory.
2323
*/
24-
function _efficientKeccak256(bytes32 a, bytes32 b) private pure returns (bytes32 value) {
24+
function efficientKeccak256(bytes32 a, bytes32 b) internal pure returns (bytes32 value) {
2525
assembly ("memory-safe") {
2626
mstore(0x00, a)
2727
mstore(0x20, b)

contracts/utils/structs/EnumerableSet.sol

+113
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
pragma solidity ^0.8.20;
66

7+
import {Hashes} from "../cryptography/Hashes.sol";
8+
79
/**
810
* @dev Library for managing
911
* https://en.wikipedia.org/wiki/Set_(abstract_data_type)[sets] of primitive
@@ -372,4 +374,115 @@ library EnumerableSet {
372374

373375
return result;
374376
}
377+
378+
struct Bytes32x2Set {
379+
// Storage of set values
380+
bytes32[2][] _values;
381+
// Position is the index of the value in the `values` array plus 1.
382+
// Position 0 is used to mean a value is not in the self.
383+
mapping(bytes32 valueHash => uint256) _positions;
384+
}
385+
386+
/**
387+
* @dev Add a value to a self. O(1).
388+
*
389+
* Returns true if the value was added to the set, that is if it was not
390+
* already present.
391+
*/
392+
function add(Bytes32x2Set storage self, bytes32[2] memory value) internal returns (bool) {
393+
if (!contains(self, value)) {
394+
self._values.push(value);
395+
// The value is stored at length-1, but we add 1 to all indexes
396+
// and use 0 as a sentinel value
397+
self._positions[_hash(value)] = self._values.length;
398+
return true;
399+
} else {
400+
return false;
401+
}
402+
}
403+
404+
/**
405+
* @dev Removes a value from a self. O(1).
406+
*
407+
* Returns true if the value was removed from the set, that is if it was
408+
* present.
409+
*/
410+
function remove(Bytes32x2Set storage self, bytes32[2] memory value) internal returns (bool) {
411+
// We cache the value's position to prevent multiple reads from the same storage slot
412+
bytes32 valueHash = _hash(value);
413+
uint256 position = self._positions[valueHash];
414+
415+
if (position != 0) {
416+
// Equivalent to contains(self, value)
417+
// To delete an element from the _values array in O(1), we swap the element to delete with the last one in
418+
// the array, and then remove the last element (sometimes called as 'swap and pop').
419+
// This modifies the order of the array, as noted in {at}.
420+
421+
uint256 valueIndex = position - 1;
422+
uint256 lastIndex = self._values.length - 1;
423+
424+
if (valueIndex != lastIndex) {
425+
bytes32[2] memory lastValue = self._values[lastIndex];
426+
427+
// Move the lastValue to the index where the value to delete is
428+
self._values[valueIndex] = lastValue;
429+
// Update the tracked position of the lastValue (that was just moved)
430+
self._positions[_hash(lastValue)] = position;
431+
}
432+
433+
// Delete the slot where the moved value was stored
434+
self._values.pop();
435+
436+
// Delete the tracked position for the deleted slot
437+
delete self._positions[valueHash];
438+
439+
return true;
440+
} else {
441+
return false;
442+
}
443+
}
444+
445+
/**
446+
* @dev Returns true if the value is in the self. O(1).
447+
*/
448+
function contains(Bytes32x2Set storage self, bytes32[2] memory value) internal view returns (bool) {
449+
return self._positions[_hash(value)] != 0;
450+
}
451+
452+
/**
453+
* @dev Returns the number of values on the self. O(1).
454+
*/
455+
function length(Bytes32x2Set storage self) internal view returns (uint256) {
456+
return self._values.length;
457+
}
458+
459+
/**
460+
* @dev Returns the value stored at position `index` in the self. O(1).
461+
*
462+
* Note that there are no guarantees on the ordering of values inside the
463+
* array, and it may change when more values are added or removed.
464+
*
465+
* Requirements:
466+
*
467+
* - `index` must be strictly less than {length}.
468+
*/
469+
function at(Bytes32x2Set storage self, uint256 index) internal view returns (bytes32[2] memory) {
470+
return self._values[index];
471+
}
472+
473+
/**
474+
* @dev Return the entire set in an array
475+
*
476+
* WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
477+
* to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
478+
* this function has an unbounded cost, and using it as part of a state-changing function may render the function
479+
* uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block.
480+
*/
481+
function values(Bytes32x2Set storage self) internal view returns (bytes32[2][] memory) {
482+
return self._values;
483+
}
484+
485+
function _hash(bytes32[2] memory value) private pure returns (bytes32) {
486+
return Hashes.efficientKeccak256(value[0], value[1]);
487+
}
375488
}

scripts/generate/templates/EnumerableSet.js

+120-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ const { TYPES } = require('./EnumerableSet.opts');
55
const header = `\
66
pragma solidity ^0.8.20;
77
8+
import {Hashes} from "../cryptography/Hashes.sol";
9+
810
/**
911
* @dev Library for managing
1012
* https://en.wikipedia.org/wiki/Set_(abstract_data_type)[sets] of primitive
@@ -233,14 +235,131 @@ function values(${name} storage set) internal view returns (${type}[] memory) {
233235
}
234236
`;
235237

238+
const memorySet = ({ name, type }) => `\
239+
struct ${name} {
240+
// Storage of set values
241+
${type}[] _values;
242+
// Position is the index of the value in the \`values\` array plus 1.
243+
// Position 0 is used to mean a value is not in the self.
244+
mapping(bytes32 valueHash => uint256) _positions;
245+
}
246+
247+
/**
248+
* @dev Add a value to a self. O(1).
249+
*
250+
* Returns true if the value was added to the set, that is if it was not
251+
* already present.
252+
*/
253+
function add(${name} storage self, ${type} memory value) internal returns (bool) {
254+
if (!contains(self, value)) {
255+
self._values.push(value);
256+
// The value is stored at length-1, but we add 1 to all indexes
257+
// and use 0 as a sentinel value
258+
self._positions[_hash(value)] = self._values.length;
259+
return true;
260+
} else {
261+
return false;
262+
}
263+
}
264+
265+
/**
266+
* @dev Removes a value from a self. O(1).
267+
*
268+
* Returns true if the value was removed from the set, that is if it was
269+
* present.
270+
*/
271+
function remove(${name} storage self, ${type} memory value) internal returns (bool) {
272+
// We cache the value's position to prevent multiple reads from the same storage slot
273+
bytes32 valueHash = _hash(value);
274+
uint256 position = self._positions[valueHash];
275+
276+
if (position != 0) {
277+
// Equivalent to contains(self, value)
278+
// To delete an element from the _values array in O(1), we swap the element to delete with the last one in
279+
// the array, and then remove the last element (sometimes called as 'swap and pop').
280+
// This modifies the order of the array, as noted in {at}.
281+
282+
uint256 valueIndex = position - 1;
283+
uint256 lastIndex = self._values.length - 1;
284+
285+
if (valueIndex != lastIndex) {
286+
${type} memory lastValue = self._values[lastIndex];
287+
288+
// Move the lastValue to the index where the value to delete is
289+
self._values[valueIndex] = lastValue;
290+
// Update the tracked position of the lastValue (that was just moved)
291+
self._positions[_hash(lastValue)] = position;
292+
}
293+
294+
// Delete the slot where the moved value was stored
295+
self._values.pop();
296+
297+
// Delete the tracked position for the deleted slot
298+
delete self._positions[valueHash];
299+
300+
return true;
301+
} else {
302+
return false;
303+
}
304+
}
305+
306+
/**
307+
* @dev Returns true if the value is in the self. O(1).
308+
*/
309+
function contains(${name} storage self, ${type} memory value) internal view returns (bool) {
310+
return self._positions[_hash(value)] != 0;
311+
}
312+
313+
/**
314+
* @dev Returns the number of values on the self. O(1).
315+
*/
316+
function length(${name} storage self) internal view returns (uint256) {
317+
return self._values.length;
318+
}
319+
320+
/**
321+
* @dev Returns the value stored at position \`index\` in the self. O(1).
322+
*
323+
* Note that there are no guarantees on the ordering of values inside the
324+
* array, and it may change when more values are added or removed.
325+
*
326+
* Requirements:
327+
*
328+
* - \`index\` must be strictly less than {length}.
329+
*/
330+
function at(${name} storage self, uint256 index) internal view returns (${type} memory) {
331+
return self._values[index];
332+
}
333+
334+
/**
335+
* @dev Return the entire set in an array
336+
*
337+
* WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
338+
* to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
339+
* this function has an unbounded cost, and using it as part of a state-changing function may render the function
340+
* uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block.
341+
*/
342+
function values(${name} storage self) internal view returns (${type}[] memory) {
343+
return self._values;
344+
}
345+
`;
346+
347+
const hashes = `\
348+
function _hash(bytes32[2] memory value) private pure returns (bytes32) {
349+
return Hashes.efficientKeccak256(value[0], value[1]);
350+
}
351+
`;
352+
236353
// GENERATE
237354
module.exports = format(
238355
header.trimEnd(),
239356
'library EnumerableSet {',
240357
format(
241358
[].concat(
242359
defaultSet,
243-
TYPES.map(details => customSet(details)),
360+
TYPES.filter(({ size }) => size == undefined).map(details => customSet(details)),
361+
TYPES.filter(({ size }) => size != undefined).map(details => memorySet(details)),
362+
hashes,
244363
),
245364
).trimEnd(),
246365
'}',
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
const { capitalize } = require('../../helpers');
22

3-
const mapType = str => (str == 'uint256' ? 'Uint' : capitalize(str));
3+
const mapType = ({ type, size }) => [type == 'uint256' ? 'Uint' : capitalize(type), size].filter(Boolean).join('x');
44

5-
const formatType = type => ({
6-
name: `${mapType(type)}Set`,
7-
type,
5+
const formatType = ({ type, size = undefined }) => ({
6+
name: `${mapType({ type, size })}Set`,
7+
type: size != undefined ? `${type}[${size}]` : type,
8+
base: size != undefined ? type : undefined,
9+
size,
810
});
911

10-
const TYPES = ['bytes32', 'address', 'uint256'].map(formatType);
12+
const TYPES = [{ type: 'bytes32' }, { type: 'bytes32', size: 2 }, { type: 'address' }, { type: 'uint256' }].map(
13+
formatType,
14+
);
1115

1216
module.exports = { TYPES, formatType };

test/utils/structs/EnumerableSet.test.js

+7-4
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@ async function fixture() {
2020
const mock = await ethers.deployContract('$EnumerableSet');
2121

2222
const env = Object.fromEntries(
23-
TYPES.map(({ name, type }) => [
23+
TYPES.map(({ name, type, base, size }) => [
2424
type,
2525
{
26-
values: Array.from({ length: 3 }, generators[type]),
26+
values: Array.from(
27+
{ length: 3 },
28+
size ? () => Array.from({ length: size }, generators[base]) : generators[type],
29+
),
2730
methods: getMethods(mock, {
2831
add: `$add(uint256,${type})`,
2932
remove: `$remove(uint256,${type})`,
@@ -33,8 +36,8 @@ async function fixture() {
3336
values: `$values_EnumerableSet_${name}(uint256)`,
3437
}),
3538
events: {
36-
addReturn: `return$add_EnumerableSet_${name}_${type}`,
37-
removeReturn: `return$remove_EnumerableSet_${name}_${type}`,
39+
addReturn: `return$add_EnumerableSet_${name}_${type.replace(/[[\]]/g, '_')}`,
40+
removeReturn: `return$remove_EnumerableSet_${name}_${type.replace(/[[\]]/g, '_')}`,
3841
},
3942
},
4043
]),

0 commit comments

Comments
 (0)