From 8efff87e810c07e119a59176008981221953d864 Mon Sep 17 00:00:00 2001 From: Anton Bukov Date: Sun, 25 Dec 2022 12:14:45 +0400 Subject: [PATCH] Refactor to reduce code size and avoid memory structs --- contracts/ERC1155Pods.sol | 24 ++++----- contracts/ERC20Pods.sol | 24 ++++----- contracts/ERC721Pods.sol | 24 ++++----- contracts/Pod.sol | 3 +- contracts/TokenPodsLib.sol | 104 +++++++++++++++---------------------- package.json | 2 +- 6 files changed, 74 insertions(+), 107 deletions(-) diff --git a/contracts/ERC1155Pods.sol b/contracts/ERC1155Pods.sol index cf103b5..e031f19 100644 --- a/contracts/ERC1155Pods.sol +++ b/contracts/ERC1155Pods.sol @@ -9,7 +9,7 @@ import "./TokenPodsLib.sol"; import "./libs/ReentrancyGuard.sol"; abstract contract ERC1155Pods is ERC1155, IERC1155Pods, ReentrancyGuardExt { - using TokenPodsLib for TokenPodsLib.Info; + using TokenPodsLib for TokenPodsLib.Data; using ReentrancyGuardLib for ReentrancyGuardLib.Data; error ZeroPodsLimit(); @@ -29,19 +29,19 @@ abstract contract ERC1155Pods is ERC1155, IERC1155Pods, ReentrancyGuardExt { } function hasPod(address account, address pod, uint256 id) public view virtual returns(bool) { - return _info(id).hasPod(account, pod); + return _pods[id].hasPod(account, pod); } function podsCount(address account, uint256 id) public view virtual returns(uint256) { - return _info(id).podsCount(account); + return _pods[id].podsCount(account); } function podAt(address account, uint256 index, uint256 id) public view virtual returns(address) { - return _info(id).podAt(account, index); + return _pods[id].podAt(account, index); } function pods(address account, uint256 id) public view virtual returns(address[] memory) { - return _info(id).pods(account); + return _pods[id].pods(account); } function balanceOf(address account, uint256 id) public nonReentrantView(_guard) view override(IERC1155, ERC1155) virtual returns(uint256) { @@ -49,23 +49,19 @@ abstract contract ERC1155Pods is ERC1155, IERC1155Pods, ReentrancyGuardExt { } function podBalanceOf(address pod, address account, uint256 id) public nonReentrantView(_guard) view returns(uint256) { - return _info(id).podBalanceOf(account, pod, super.balanceOf(msg.sender, id)); + return _pods[id].podBalanceOf(account, pod, super.balanceOf(msg.sender, id)); } function addPod(address pod, uint256 id) public virtual { - if (_info(id).addPod(msg.sender, pod, balanceOf(msg.sender, id)) > podsLimit) revert PodsLimitReachedForAccount(); + if (_pods[id].addPod(msg.sender, pod, balanceOf(msg.sender, id), podCallGasLimit) > podsLimit) revert PodsLimitReachedForAccount(); } function removePod(address pod, uint256 id) public virtual { - _info(id).removePod(msg.sender, pod, balanceOf(msg.sender, id)); + _pods[id].removePod(msg.sender, pod, balanceOf(msg.sender, id), podCallGasLimit); } function removeAllPods(uint256 id) public virtual { - _info(id).removeAllPods(msg.sender, balanceOf(msg.sender, id)); - } - - function _info(uint256 id) private view returns(TokenPodsLib.Info memory) { - return TokenPodsLib.makeInfo(_pods[id], podCallGasLimit); + _pods[id].removeAllPods(msg.sender, balanceOf(msg.sender, id), podCallGasLimit); } // ERC1155 Overrides @@ -82,7 +78,7 @@ abstract contract ERC1155Pods is ERC1155, IERC1155Pods, ReentrancyGuardExt { unchecked { for (uint256 i = 0; i < ids.length; i++) { - _info(ids[i]).updateBalancesWithTokenId(from, to, amounts[i], ids[i]); + _pods[i].updateBalancesWithTokenId(from, to, amounts[i], ids[i], podCallGasLimit); } } } diff --git a/contracts/ERC20Pods.sol b/contracts/ERC20Pods.sol index 680b3ec..e6d61fb 100644 --- a/contracts/ERC20Pods.sol +++ b/contracts/ERC20Pods.sol @@ -10,7 +10,7 @@ import "./TokenPodsLib.sol"; import "./libs/ReentrancyGuard.sol"; abstract contract ERC20Pods is ERC20, IERC20Pods, ReentrancyGuardExt { - using TokenPodsLib for TokenPodsLib.Info; + using TokenPodsLib for TokenPodsLib.Data; using ReentrancyGuardLib for ReentrancyGuardLib.Data; error ZeroPodsLimit(); @@ -30,19 +30,19 @@ abstract contract ERC20Pods is ERC20, IERC20Pods, ReentrancyGuardExt { } function hasPod(address account, address pod) public view virtual returns(bool) { - return _info().hasPod(account, pod); + return _pods.hasPod(account, pod); } function podsCount(address account) public view virtual returns(uint256) { - return _info().podsCount(account); + return _pods.podsCount(account); } function podAt(address account, uint256 index) public view virtual returns(address) { - return _info().podAt(account, index); + return _pods.podAt(account, index); } function pods(address account) public view virtual returns(address[] memory) { - return _info().pods(account); + return _pods.pods(account); } function balanceOf(address account) public nonReentrantView(_guard) view override(IERC20, ERC20) virtual returns(uint256) { @@ -50,29 +50,25 @@ abstract contract ERC20Pods is ERC20, IERC20Pods, ReentrancyGuardExt { } function podBalanceOf(address pod, address account) public nonReentrantView(_guard) view virtual returns(uint256) { - return _info().podBalanceOf(account, pod, super.balanceOf(account)); + return _pods.podBalanceOf(account, pod, super.balanceOf(account)); } function addPod(address pod) public virtual { - if (_info().addPod(msg.sender, pod, balanceOf(msg.sender)) > podsLimit) revert PodsLimitReachedForAccount(); + if (_pods.addPod(msg.sender, pod, balanceOf(msg.sender), podCallGasLimit) > podsLimit) revert PodsLimitReachedForAccount(); } function removePod(address pod) public virtual { - _info().removePod(msg.sender, pod, balanceOf(msg.sender)); + _pods.removePod(msg.sender, pod, balanceOf(msg.sender), podCallGasLimit); } function removeAllPods() public virtual { - _info().removeAllPods(msg.sender, balanceOf(msg.sender)); - } - - function _info() private view returns(TokenPodsLib.Info memory) { - return TokenPodsLib.makeInfo(_pods, podCallGasLimit); + _pods.removeAllPods(msg.sender, balanceOf(msg.sender), podCallGasLimit); } // ERC20 Overrides function _afterTokenTransfer(address from, address to, uint256 amount) internal nonReentrant(_guard) override virtual { super._afterTokenTransfer(from, to, amount); - _info().updateBalances(from, to, amount); + _pods.updateBalances(from, to, amount, podCallGasLimit); } } diff --git a/contracts/ERC721Pods.sol b/contracts/ERC721Pods.sol index c4129b3..1ee301f 100644 --- a/contracts/ERC721Pods.sol +++ b/contracts/ERC721Pods.sol @@ -11,7 +11,7 @@ import "./TokenPodsLib.sol"; import "./libs/ReentrancyGuard.sol"; abstract contract ERC721Pods is ERC721, IERC721Pods, ReentrancyGuardExt { - using TokenPodsLib for TokenPodsLib.Info; + using TokenPodsLib for TokenPodsLib.Data; using ReentrancyGuardLib for ReentrancyGuardLib.Data; error ZeroPodsLimit(); @@ -31,19 +31,19 @@ abstract contract ERC721Pods is ERC721, IERC721Pods, ReentrancyGuardExt { } function hasPod(address account, address pod) public view virtual returns(bool) { - return _info().hasPod(account, pod); + return _pods.hasPod(account, pod); } function podsCount(address account) public view virtual returns(uint256) { - return _info().podsCount(account); + return _pods.podsCount(account); } function podAt(address account, uint256 index) public view virtual returns(address) { - return _info().podAt(account, index); + return _pods.podAt(account, index); } function pods(address account) public view virtual returns(address[] memory) { - return _info().pods(account); + return _pods.pods(account); } function balanceOf(address account) public nonReentrantView(_guard) view override(IERC721, ERC721) virtual returns(uint256) { @@ -51,29 +51,25 @@ abstract contract ERC721Pods is ERC721, IERC721Pods, ReentrancyGuardExt { } function podBalanceOf(address pod, address account) public nonReentrantView(_guard) view virtual returns(uint256) { - return _info().podBalanceOf(account, pod, super.balanceOf(account)); + return _pods.podBalanceOf(account, pod, super.balanceOf(account)); } function addPod(address pod) public virtual { - if (_info().addPod(msg.sender, pod, balanceOf(msg.sender)) > podsLimit) revert PodsLimitReachedForAccount(); + if (_pods.addPod(msg.sender, pod, balanceOf(msg.sender), podCallGasLimit) > podsLimit) revert PodsLimitReachedForAccount(); } function removePod(address pod) public virtual { - _info().removePod(msg.sender, pod, balanceOf(msg.sender)); + _pods.removePod(msg.sender, pod, balanceOf(msg.sender), podCallGasLimit); } function removeAllPods() public virtual { - _info().removeAllPods(msg.sender, balanceOf(msg.sender)); - } - - function _info() private view returns(TokenPodsLib.Info memory) { - return TokenPodsLib.makeInfo(_pods, podCallGasLimit); + _pods.removeAllPods(msg.sender, balanceOf(msg.sender), podCallGasLimit); } // ERC721 Overrides function _afterTokenTransfer(address from, address to, uint256 firstTokenId, uint256 batchSize) internal nonReentrant(_guard) override virtual { super._afterTokenTransfer(from, to, firstTokenId, batchSize); - _info().updateBalances(from, to, batchSize); + _pods.updateBalances(from, to, batchSize, podCallGasLimit); } } diff --git a/contracts/Pod.sol b/contracts/Pod.sol index 11f28c3..06adc62 100644 --- a/contracts/Pod.sol +++ b/contracts/Pod.sol @@ -3,9 +3,10 @@ pragma solidity ^0.8.0; import "./interfaces/IPod.sol"; +import "./interfaces/IPodWithId.sol"; import "./interfaces/IERC20Pods.sol"; -abstract contract Pod is IPod { +abstract contract Pod is IPod, IPodWithId { error AccessDenied(); IERC20Pods public immutable token; diff --git a/contracts/TokenPodsLib.sol b/contracts/TokenPodsLib.sol index 8e44bea..e5883fc 100644 --- a/contracts/TokenPodsLib.sol +++ b/contracts/TokenPodsLib.sol @@ -22,103 +22,88 @@ library TokenPodsLib { type DataPtr is uint256; struct Data { - mapping(address => AddressSet.Data) pods; + mapping(address => AddressSet.Data) _pods; } - struct Info { - DataPtr data; - uint256 podCallGasLimit; + function hasPod(Data storage self, address account, address pod) internal view returns(bool) { + return self._pods[account].contains(pod); } - function makeInfo(Data storage data, uint256 podCallGasLimit_) internal pure returns(Info memory info) { - DataPtr ptr; - assembly { // solhint-disable-line no-inline-assembly - ptr := data.slot - } - info.data = ptr; - info.podCallGasLimit = podCallGasLimit_; - } - - function hasPod(Info memory self, address account, address pod) internal view returns(bool) { - return _getData(self).pods[account].contains(pod); - } - - function podsCount(Info memory self, address account) internal view returns(uint256) { - return _getData(self).pods[account].length(); + function podsCount(Data storage self, address account) internal view returns(uint256) { + return self._pods[account].length(); } - function podAt(Info memory self, address account, uint256 index) internal view returns(address) { - return _getData(self).pods[account].at(index); + function podAt(Data storage self, address account, uint256 index) internal view returns(address) { + return self._pods[account].at(index); } - function pods(Info memory self, address account) internal view returns(address[] memory) { - return _getData(self).pods[account].items.get(); + function pods(Data storage self, address account) internal view returns(address[] memory) { + return self._pods[account].items.get(); } - function podBalanceOf(Info memory self, address account, address pod, uint256 balance) internal view returns(uint256) { - if (_getData(self).pods[account].contains(pod)) { + function podBalanceOf(Data storage self, address account, address pod, uint256 balance) internal view returns(uint256) { + if (self._pods[account].contains(pod)) { return balance; } return 0; } - function addPod(Info memory self, address account, address pod, uint256 balance) internal returns(uint256) { - return _addPod(self, account, pod, balance); + function addPod(Data storage self, address account, address pod, uint256 balance, uint256 podCallGasLimit) internal returns(uint256) { + return _addPod(self, account, pod, balance, podCallGasLimit); } - function removePod(Info memory self, address account, address pod, uint256 balance) internal { - _removePod(self, account, pod, balance); + function removePod(Data storage self, address account, address pod, uint256 balance, uint256 podCallGasLimit) internal { + _removePod(self, account, pod, balance, podCallGasLimit); } - function removeAllPods(Info memory self, address account, uint256 balance) internal { - _removeAllPods(self, account, balance); + function removeAllPods(Data storage self, address account, uint256 balance, uint256 podCallGasLimit) internal { + _removeAllPods(self, account, balance, podCallGasLimit); } - function _addPod(Info memory self, address account, address pod, uint256 balance) private returns(uint256) { + function _addPod(Data storage self, address account, address pod, uint256 balance, uint256 podCallGasLimit) private returns(uint256) { if (pod == address(0)) revert InvalidPodAddress(); - if (!_getData(self).pods[account].add(pod)) revert PodAlreadyAdded(); + if (!self._pods[account].add(pod)) revert PodAlreadyAdded(); emit PodAdded(account, pod); if (balance > 0) { - _notifyPod(pod, address(0), account, balance, 0, false, self.podCallGasLimit); + _notifyPod(pod, address(0), account, balance, 0, false, podCallGasLimit); } - return _getData(self).pods[account].length(); + return self._pods[account].length(); } - function _removePod(Info memory self, address account, address pod, uint256 balance) private { - if (!_getData(self).pods[account].remove(pod)) revert PodNotFound(); + function _removePod(Data storage self, address account, address pod, uint256 balance, uint256 podCallGasLimit) private { + if (!self._pods[account].remove(pod)) revert PodNotFound(); if (balance > 0) { - _notifyPod(pod, account, address(0), balance, 0, false, self.podCallGasLimit); + _notifyPod(pod, account, address(0), balance, 0, false, podCallGasLimit); } } - function _removeAllPods(Info memory self, address account, uint256 balance) private { - address[] memory items = _getData(self).pods[account].items.get(); + function _removeAllPods(Data storage self, address account, uint256 balance, uint256 podCallGasLimit) private { + address[] memory items = self._pods[account].items.get(); unchecked { for (uint256 i = items.length; i > 0; i--) { - _getData(self).pods[account].remove(items[i - 1]); + self._pods[account].remove(items[i - 1]); emit PodRemoved(account, items[i - 1]); if (balance > 0) { - _notifyPod(items[i - 1], account, address(0), balance, 0, false, self.podCallGasLimit); + _notifyPod(items[i - 1], account, address(0), balance, 0, false, podCallGasLimit); } } } } - function updateBalances(Info memory self, address from, address to, uint256 amount) internal { - _updateBalances(self, from, to, amount, 0, false); + function updateBalances(Data storage self, address from, address to, uint256 amount, uint256 podCallGasLimit) internal { + _updateBalances(self, from, to, amount, 0, false, podCallGasLimit); } - function updateBalancesWithTokenId(Info memory self, address from, address to, uint256 amount, uint256 id) internal { - _updateBalances(self, from, to, amount, id, true); + function updateBalancesWithTokenId(Data storage self, address from, address to, uint256 amount, uint256 id, uint256 podCallGasLimit) internal { + _updateBalances(self, from, to, amount, id, true, podCallGasLimit); } - function _updateBalances(Info memory self, address from, address to, uint256 amount, uint256 id, bool hasId) private { + function _updateBalances(Data storage self, address from, address to, uint256 amount, uint256 id, bool hasId, uint256 podCallGasLimit) private { unchecked { if (amount > 0 && from != to) { - uint256 gasLimit = self.podCallGasLimit; - address[] memory a = _getData(self).pods[from].items.get(); - address[] memory b = _getData(self).pods[to].items.get(); + address[] memory a = self._pods[from].items.get(); + address[] memory b = self._pods[to].items.get(); uint256 aLength = a.length; uint256 bLength = b.length; @@ -129,7 +114,7 @@ library TokenPodsLib { for (j = 0; j < bLength; j++) { if (pod == b[j]) { // Both parties are participating of the same Pod - _notifyPod(pod, from, to, amount, id, hasId, gasLimit); + _notifyPod(pod, from, to, amount, id, hasId, podCallGasLimit); b[j] = address(0); break; } @@ -137,7 +122,7 @@ library TokenPodsLib { if (j == bLength) { // Sender is participating in a Pod, but receiver is not - _notifyPod(pod, from, address(0), amount, id, hasId, gasLimit); + _notifyPod(pod, from, address(0), amount, id, hasId, podCallGasLimit); } } @@ -145,7 +130,7 @@ library TokenPodsLib { address pod = b[j]; if (pod != address(0)) { // Receiver is participating in a Pod, but sender is not - _notifyPod(pod, address(0), to, amount, id, hasId, gasLimit); + _notifyPod(pod, address(0), to, amount, id, hasId, podCallGasLimit); } } } @@ -155,7 +140,7 @@ library TokenPodsLib { /// @notice Assembly implementation of the gas limited call to avoid return gas bomb, // moreover call to a destructed pod would also revert even inside try-catch block in Solidity 0.8.17 /// @dev try IPod(pod).updateBalances{gas: _POD_CALL_GAS_LIMIT}(from, to, amount) {} catch {} - function _notifyPod(address pod, address from, address to, uint256 amount, uint256 id, bool hasId, uint256 gasLimit) private { + function _notifyPod(address pod, address from, address to, uint256 amount, uint256 id, bool hasId, uint256 podCallGasLimit) private { bytes4 selector = hasId ? IPodWithId.updateBalancesWithTokenId.selector : IPod.updateBalances.selector; bytes4 exception = InsufficientGas.selector; assembly { // solhint-disable-line no-inline-assembly @@ -168,18 +153,11 @@ library TokenPodsLib { mstore(add(ptr, 0x64), id) } - if lt(div(mul(gas(), 63), 64), gasLimit) { + if lt(div(mul(gas(), 63), 64), podCallGasLimit) { mstore(0, exception) revert(0, 4) } - pop(call(gasLimit, pod, 0, ptr, add(0x64, mul(hasId, 0x20)), 0, 0)) - } - } - - function _getData(Info memory info) private pure returns(Data storage data) { - DataPtr ptr = info.data; - assembly { // solhint-disable-line no-inline-assembly - data.slot := ptr + pop(call(podCallGasLimit, pod, 0, ptr, add(0x64, mul(hasId, 0x20)), 0, 0)) } } } diff --git a/package.json b/package.json index 2423673..9f96d7c 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@1inch/token-pods", - "version": "0.0.17", + "version": "0.1.0", "description": "ERC20 extension enabling external smart contract based Pods to track balances of those users who opted-in to these Pods", "repository": { "type": "git",