Skip to content

Commit 0a757ec

Browse files
AmxxRenanSouza2ernestognw
authored
Add sort in memory to Arrays library (#4846)
Co-authored-by: RenanSouza2 <[email protected]> Co-authored-by: Ernesto García <[email protected]>
1 parent 036c3cb commit 0a757ec

File tree

7 files changed

+184
-28
lines changed

7 files changed

+184
-28
lines changed

.changeset/dirty-cobras-smile.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'openzeppelin-solidity': minor
3+
---
4+
5+
`Arrays`: add a `sort` function.

contracts/utils/Arrays.sol

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,69 @@ import {Math} from "./math/Math.sol";
1212
library Arrays {
1313
using StorageSlot for bytes32;
1414

15+
/**
16+
* @dev Sort an array (in memory) in increasing order.
17+
*
18+
* This function does the sorting "in place", meaning that it overrides the input. The object is returned for
19+
* convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array.
20+
*
21+
* NOTE: this function's cost is `O(n · log(n))` in average and `O(n²)` in the worst case, with n the length of the
22+
* array. Using it in view functions that are executed through `eth_call` is safe, but one should be very careful
23+
* when executing this as part of a transaction. If the array being sorted is too large, the sort operation may
24+
* consume more gas than is available in a block, leading to potential DoS.
25+
*/
26+
function sort(uint256[] memory array) internal pure returns (uint256[] memory) {
27+
_quickSort(array, 0, array.length);
28+
return array;
29+
}
30+
31+
/**
32+
* @dev Performs a quick sort on an array in memory. The array is sorted in increasing order.
33+
*
34+
* Invariant: `i <= j <= array.length`. This is the case when initially called by {sort} and is preserved in
35+
* subcalls.
36+
*/
37+
function _quickSort(uint256[] memory array, uint256 i, uint256 j) private pure {
38+
unchecked {
39+
// Can't overflow given `i <= j`
40+
if (j - i < 2) return;
41+
42+
// Use first element as pivot
43+
uint256 pivot = unsafeMemoryAccess(array, i);
44+
// Position where the pivot should be at the end of the loop
45+
uint256 index = i;
46+
47+
for (uint256 k = i + 1; k < j; ++k) {
48+
// Unsafe access is safe given `k < j <= array.length`.
49+
if (unsafeMemoryAccess(array, k) < pivot) {
50+
// If array[k] is smaller than the pivot, we increment the index and move array[k] there.
51+
_swap(array, ++index, k);
52+
}
53+
}
54+
55+
// Swap pivot into place
56+
_swap(array, i, index);
57+
58+
_quickSort(array, i, index); // Sort the left side of the pivot
59+
_quickSort(array, index + 1, j); // Sort the right side of the pivot
60+
}
61+
}
62+
63+
/**
64+
* @dev Swaps the elements at positions `i` and `j` in the `arr` array.
65+
*/
66+
function _swap(uint256[] memory arr, uint256 i, uint256 j) private pure {
67+
assembly {
68+
let start := add(arr, 0x20) // Pointer to the first element of the array
69+
let pos_i := add(start, mul(i, 0x20))
70+
let pos_j := add(start, mul(j, 0x20))
71+
let val_i := mload(pos_i)
72+
let val_j := mload(pos_j)
73+
mstore(pos_i, val_j)
74+
mstore(pos_j, val_i)
75+
}
76+
}
77+
1578
/**
1679
* @dev Searches a sorted `array` and returns the first index that contains
1780
* a value greater or equal to `element`. If no such index exists (i.e. all
@@ -238,7 +301,7 @@ library Arrays {
238301
*
239302
* WARNING: Only use if you are certain `pos` is lower than the array length.
240303
*/
241-
function unsafeMemoryAccess(uint256[] memory arr, uint256 pos) internal pure returns (uint256 res) {
304+
function unsafeMemoryAccess(address[] memory arr, uint256 pos) internal pure returns (address res) {
242305
assembly {
243306
res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
244307
}
@@ -249,7 +312,18 @@ library Arrays {
249312
*
250313
* WARNING: Only use if you are certain `pos` is lower than the array length.
251314
*/
252-
function unsafeMemoryAccess(address[] memory arr, uint256 pos) internal pure returns (address res) {
315+
function unsafeMemoryAccess(bytes32[] memory arr, uint256 pos) internal pure returns (bytes32 res) {
316+
assembly {
317+
res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
318+
}
319+
}
320+
321+
/**
322+
* @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
323+
*
324+
* WARNING: Only use if you are certain `pos` is lower than the array length.
325+
*/
326+
function unsafeMemoryAccess(uint256[] memory arr, uint256 pos) internal pure returns (uint256 res) {
253327
assembly {
254328
res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
255329
}

scripts/generate/templates/Checkpoints.t.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ const header = `\
77
pragma solidity ^0.8.20;
88
99
import {Test} from "forge-std/Test.sol";
10-
import {SafeCast} from "../../../contracts/utils/math/SafeCast.sol";
11-
import {Checkpoints} from "../../../contracts/utils/structs/Checkpoints.sol";
10+
import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol";
11+
import {Checkpoints} from "@openzeppelin/contracts/utils/structs/Checkpoints.sol";
1212
`;
1313

1414
/* eslint-disable max-len */

test/utils/Arrays.t.sol

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// SPDX-License-Identifier: MIT
2+
3+
pragma solidity ^0.8.20;
4+
5+
import {Test} from "forge-std/Test.sol";
6+
import {Arrays} from "@openzeppelin/contracts/utils/Arrays.sol";
7+
8+
contract ArraysTest is Test {
9+
function testSort(uint256[] memory values) public {
10+
Arrays.sort(values);
11+
for (uint256 i = 1; i < values.length; ++i) {
12+
assertLe(values[i - 1], values[i]);
13+
}
14+
}
15+
}

test/utils/Arrays.test.js

Lines changed: 84 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,56 @@ const upperBound = (array, value) => {
1616
return i == -1 ? array.length : i;
1717
};
1818

19+
// By default, js "sort" cast to string and then sort in alphabetical order. Use this to sort numbers.
20+
const compareNumbers = (a, b) => (a > b ? 1 : a < b ? -1 : 0);
21+
1922
const hasDuplicates = array => array.some((v, i) => array.indexOf(v) != i);
2023

2124
describe('Arrays', function () {
25+
const fixture = async () => {
26+
return { mock: await ethers.deployContract('$Arrays') };
27+
};
28+
29+
beforeEach(async function () {
30+
Object.assign(this, await loadFixture(fixture));
31+
});
32+
33+
describe('sort', function () {
34+
for (const length of [0, 1, 2, 8, 32, 128]) {
35+
it(`sort array of length ${length}`, async function () {
36+
this.elements = randomArray(generators.uint256, length);
37+
this.expected = Array.from(this.elements).sort(compareNumbers);
38+
});
39+
40+
if (length > 1) {
41+
it(`sort array of length ${length} (identical elements)`, async function () {
42+
this.elements = Array(length).fill(generators.uint256());
43+
this.expected = this.elements;
44+
});
45+
46+
it(`sort array of length ${length} (already sorted)`, async function () {
47+
this.elements = randomArray(generators.uint256, length).sort(compareNumbers);
48+
this.expected = this.elements;
49+
});
50+
51+
it(`sort array of length ${length} (sorted in reverse order)`, async function () {
52+
this.elements = randomArray(generators.uint256, length).sort(compareNumbers).reverse();
53+
this.expected = Array.from(this.elements).reverse();
54+
});
55+
56+
it(`sort array of length ${length} (almost sorted)`, async function () {
57+
this.elements = randomArray(generators.uint256, length).sort(compareNumbers);
58+
this.expected = Array.from(this.elements);
59+
// rotate (move the last element to the front) for an almost sorted effect
60+
this.elements.unshift(this.elements.pop());
61+
});
62+
}
63+
}
64+
afterEach(async function () {
65+
expect(await this.mock.$sort(this.elements)).to.deep.equal(this.expected);
66+
});
67+
});
68+
2269
describe('search', function () {
2370
for (const [title, { array, tests }] of Object.entries({
2471
'Even number of elements': {
@@ -74,7 +121,7 @@ describe('Arrays', function () {
74121
})) {
75122
describe(title, function () {
76123
const fixture = async () => {
77-
return { mock: await ethers.deployContract('Uint256ArraysMock', [array]) };
124+
return { instance: await ethers.deployContract('Uint256ArraysMock', [array]) };
78125
};
79126

80127
beforeEach(async function () {
@@ -86,20 +133,20 @@ describe('Arrays', function () {
86133
it('[deprecated] findUpperBound', async function () {
87134
// findUpperBound does not support duplicated
88135
if (hasDuplicates(array)) {
89-
expect(await this.mock.findUpperBound(input)).to.be.equal(upperBound(array, input) - 1);
136+
expect(await this.instance.findUpperBound(input)).to.equal(upperBound(array, input) - 1);
90137
} else {
91-
expect(await this.mock.findUpperBound(input)).to.be.equal(lowerBound(array, input));
138+
expect(await this.instance.findUpperBound(input)).to.equal(lowerBound(array, input));
92139
}
93140
});
94141

95142
it('lowerBound', async function () {
96-
expect(await this.mock.lowerBound(input)).to.be.equal(lowerBound(array, input));
97-
expect(await this.mock.lowerBoundMemory(array, input)).to.be.equal(lowerBound(array, input));
143+
expect(await this.instance.lowerBound(input)).to.equal(lowerBound(array, input));
144+
expect(await this.instance.lowerBoundMemory(array, input)).to.equal(lowerBound(array, input));
98145
});
99146

100147
it('upperBound', async function () {
101-
expect(await this.mock.upperBound(input)).to.be.equal(upperBound(array, input));
102-
expect(await this.mock.upperBoundMemory(array, input)).to.be.equal(upperBound(array, input));
148+
expect(await this.instance.upperBound(input)).to.equal(upperBound(array, input));
149+
expect(await this.instance.upperBoundMemory(array, input)).to.equal(upperBound(array, input));
103150
});
104151
});
105152
}
@@ -108,28 +155,44 @@ describe('Arrays', function () {
108155
});
109156

110157
describe('unsafeAccess', function () {
111-
for (const [title, { artifact, elements }] of Object.entries({
158+
for (const [type, { artifact, elements }] of Object.entries({
112159
address: { artifact: 'AddressArraysMock', elements: randomArray(generators.address, 10) },
113160
bytes32: { artifact: 'Bytes32ArraysMock', elements: randomArray(generators.bytes32, 10) },
114161
uint256: { artifact: 'Uint256ArraysMock', elements: randomArray(generators.uint256, 10) },
115162
})) {
116-
describe(title, function () {
117-
const fixture = async () => {
118-
return { mock: await ethers.deployContract(artifact, [elements]) };
119-
};
163+
describe(type, function () {
164+
describe('storage', function () {
165+
const fixture = async () => {
166+
return { instance: await ethers.deployContract(artifact, [elements]) };
167+
};
120168

121-
beforeEach(async function () {
122-
Object.assign(this, await loadFixture(fixture));
123-
});
169+
beforeEach(async function () {
170+
Object.assign(this, await loadFixture(fixture));
171+
});
124172

125-
for (const i in elements) {
126-
it(`unsafeAccess within bounds #${i}`, async function () {
127-
expect(await this.mock.unsafeAccess(i)).to.equal(elements[i]);
173+
for (const i in elements) {
174+
it(`unsafeAccess within bounds #${i}`, async function () {
175+
expect(await this.instance.unsafeAccess(i)).to.equal(elements[i]);
176+
});
177+
}
178+
179+
it('unsafeAccess outside bounds', async function () {
180+
await expect(this.instance.unsafeAccess(elements.length)).to.not.be.rejected;
128181
});
129-
}
182+
});
183+
184+
describe('memory', function () {
185+
const fragment = `$unsafeMemoryAccess(${type}[] arr, uint256 pos)`;
130186

131-
it('unsafeAccess outside bounds', async function () {
132-
await expect(this.mock.unsafeAccess(elements.length)).to.not.be.rejected;
187+
for (const i in elements) {
188+
it(`unsafeMemoryAccess within bounds #${i}`, async function () {
189+
expect(await this.mock[fragment](elements, i)).to.equal(elements[i]);
190+
});
191+
}
192+
193+
it('unsafeMemoryAccess outside bounds', async function () {
194+
await expect(this.mock[fragment](elements, elements.length)).to.not.be.rejected;
195+
});
133196
});
134197
});
135198
}

test/utils/Base64.t.sol

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
pragma solidity ^0.8.20;
44

55
import {Test} from "forge-std/Test.sol";
6-
76
import {Base64} from "@openzeppelin/contracts/utils/Base64.sol";
87

98
contract Base64Test is Test {

test/utils/structs/Checkpoints.t.sol

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
pragma solidity ^0.8.20;
55

66
import {Test} from "forge-std/Test.sol";
7-
import {SafeCast} from "../../../contracts/utils/math/SafeCast.sol";
8-
import {Checkpoints} from "../../../contracts/utils/structs/Checkpoints.sol";
7+
import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol";
8+
import {Checkpoints} from "@openzeppelin/contracts/utils/structs/Checkpoints.sol";
99

1010
contract CheckpointsTrace224Test is Test {
1111
using Checkpoints for Checkpoints.Trace224;

0 commit comments

Comments
 (0)