|
| 1 | +// SPDX-License-Identifier: GPL-3.0-or-later |
| 2 | +pragma solidity ^0.8; |
| 3 | + |
| 4 | +import {ApproveAndBridge, IERC20} from "./mixin/ApproveAndBridge.sol"; |
| 5 | +import {Math} from "./vendored/Math.sol"; |
| 6 | + |
| 7 | +/// ! @dev UNAUDITED UNTESTED Do not use in production |
| 8 | +/// @dev Performs two steps before bridging via SocketGateway: |
| 9 | +/// 1. Modify input amount in calldata |
| 10 | +/// 2. Modify output amount in calldata |
| 11 | +/// 3. Call SocketGateway.fallback() with the modified calldata |
| 12 | +contract BungeeApproveAndBridge is ApproveAndBridge { |
| 13 | + error InvalidInput(); |
| 14 | + error PositionOutOfBounds(); |
| 15 | + error BridgeFailed(); |
| 16 | + |
| 17 | + /// @dev ModifyCalldataParams is a struct that contains information required to modify SocketGateway calldata |
| 18 | + /// @dev the input amount index, modify output flag, and output amount index |
| 19 | + struct ModifyCalldataParams { |
| 20 | + uint256 inputAmountIdx; |
| 21 | + bool modifyOutput; |
| 22 | + uint256 outputAmountIdx; |
| 23 | + } |
| 24 | + |
| 25 | + /// @dev routeIds on SocketGateway are 4 bytes |
| 26 | + uint8 private constant ROUTE_ID_BYTES_LENGTH = 4; |
| 27 | + /// @dev there are 3 params in ModifyCalldataParams |
| 28 | + uint8 private constant MODIFY_CALLDATA_PARAMS_COUNT = 3; |
| 29 | + /// @dev each ModifyCalldataParams is 32 bytes |
| 30 | + uint8 private constant MODIFY_CALLDATA_LENGTH_BYTES = 32; |
| 31 | + /// @dev total length of the modify calldata bytes |
| 32 | + uint8 private constant MODIFY_CALLDATA_LENGTH = MODIFY_CALLDATA_PARAMS_COUNT * MODIFY_CALLDATA_LENGTH_BYTES; |
| 33 | + /// @dev minimum length of the data payload |
| 34 | + /// @dev should atleast include the routeId and the ModifyCalldataParams |
| 35 | + uint8 private constant MIN_DATA_LENGTH = ROUTE_ID_BYTES_LENGTH + MODIFY_CALLDATA_LENGTH; |
| 36 | + |
| 37 | + /// @dev SocketGateway address |
| 38 | + address public immutable SOCKET_GATEWAY; |
| 39 | + |
| 40 | + constructor(address socketGateway_) { |
| 41 | + require(socketGateway_.code.length > 0, "Socket gateway contract not deployed"); |
| 42 | + |
| 43 | + SOCKET_GATEWAY = socketGateway_; |
| 44 | + } |
| 45 | + |
| 46 | + /** |
| 47 | + * @notice Approval should be given to the SocketGateway address |
| 48 | + * @dev Returns the SocketGateway address |
| 49 | + */ |
| 50 | + function bridgeApprovalTarget() public view override returns (address) { |
| 51 | + return address(SOCKET_GATEWAY); |
| 52 | + } |
| 53 | + |
| 54 | + /** |
| 55 | + * @notice Bridge the token via SocketGateway |
| 56 | + * @dev Modifies SocketGateway calldata to modify the input and output amounts before bridging |
| 57 | + * @param token The token to bridge |
| 58 | + * @param amount The amount of token to bridge |
| 59 | + * @param nativeTokenExtraFee extra fee in native token, if any |
| 60 | + * @param data encoded bytes including SocketGateway calldata and ModifyCalldataParams |
| 61 | + */ |
| 62 | + function bridge(IERC20 token, uint256 amount, uint256 nativeTokenExtraFee, bytes calldata data) internal override { |
| 63 | + // decode & parse data to find positions in calldata to modify |
| 64 | + bytes memory modifiedCalldata = _parseAndModifyCalldata(amount, data); |
| 65 | + |
| 66 | + // execute using the modified calldata via SocketGateway.fallback() |
| 67 | + (bool success,) = address(token) == NATIVE_TOKEN_ADDRESS |
| 68 | + ? address(SOCKET_GATEWAY).call{value: amount + nativeTokenExtraFee}(modifiedCalldata) |
| 69 | + : address(SOCKET_GATEWAY).call{value: nativeTokenExtraFee}(modifiedCalldata); |
| 70 | + if (!success) revert BridgeFailed(); |
| 71 | + } |
| 72 | + |
| 73 | + /** |
| 74 | + * @dev Parses and modifies the calldata to modify the input and output amounts before bridging |
| 75 | + * @param amount Updated input amount to use to modify the calldata |
| 76 | + * @param data encoded bytes including SocketGateway calldata and ModifyCalldataParams |
| 77 | + * @return modifiedCalldata The modified calldata |
| 78 | + */ |
| 79 | + function _parseAndModifyCalldata(uint256 amount, bytes calldata data) internal pure returns (bytes memory) { |
| 80 | + // Parse the data into route calldata and ModifyCalldataParams |
| 81 | + (bytes memory routeCalldata, ModifyCalldataParams memory modifyCalldataParams) = _parseCalldata(data); |
| 82 | + |
| 83 | + // Read the original input amount from the calldata |
| 84 | + // before modifying input amount |
| 85 | + uint256 originalInput = _readUint256({_data: routeCalldata, _index: modifyCalldataParams.inputAmountIdx}); |
| 86 | + |
| 87 | + // Replace the input amount in the calldata |
| 88 | + bytes memory modifiedCalldata = |
| 89 | + _replaceUint256({_original: routeCalldata, _start: modifyCalldataParams.inputAmountIdx, _amount: amount}); |
| 90 | + |
| 91 | + // Optionally replace the output amount if required |
| 92 | + // in case of bridges like Across, need to modify both input and output amounts |
| 93 | + // - decode current input and output amounts from calldata |
| 94 | + // - calculate and apply the percentage diff bw new and old input amount on the old output amount |
| 95 | + // - replace the output amount at the index with the new amount |
| 96 | + // - assumes output amount is always uint256 in SocketGateway impls |
| 97 | + if (modifyCalldataParams.modifyOutput) { |
| 98 | + uint256 originalOutput = _readUint256({_data: routeCalldata, _index: modifyCalldataParams.outputAmountIdx}); |
| 99 | + uint256 newOutput = _applyPctDiff({_base: originalInput, _compare: amount, _target: originalOutput}); |
| 100 | + modifiedCalldata = _replaceUint256({ |
| 101 | + _original: modifiedCalldata, |
| 102 | + _start: modifyCalldataParams.outputAmountIdx, |
| 103 | + _amount: newOutput |
| 104 | + }); |
| 105 | + } |
| 106 | + |
| 107 | + return modifiedCalldata; |
| 108 | + } |
| 109 | + |
| 110 | + /** |
| 111 | + * @dev Parses the calldata to extract the route calldata and ModifyCalldataParams |
| 112 | + * @param _data The calldata to parse |
| 113 | + * @return routeCalldata The SocketGateway route calldata |
| 114 | + * @return modifyCalldataParams The ModifyCalldataParams |
| 115 | + */ |
| 116 | + function _parseCalldata(bytes calldata _data) internal pure returns (bytes memory, ModifyCalldataParams memory) { |
| 117 | + // calldata should have minimum of routeId and ModifyCalldataParams |
| 118 | + if (_data.length < MIN_DATA_LENGTH) revert InvalidInput(); |
| 119 | + uint256 routeCalldataLength = _data.length - MODIFY_CALLDATA_LENGTH; |
| 120 | + |
| 121 | + // Extract the route execution calldata |
| 122 | + bytes memory routeCalldata = _data[:routeCalldataLength]; |
| 123 | + |
| 124 | + // Extract the ModifyCalldataParams |
| 125 | + ModifyCalldataParams memory modifyCalldataParams; |
| 126 | + (modifyCalldataParams.inputAmountIdx, modifyCalldataParams.modifyOutput, modifyCalldataParams.outputAmountIdx) = |
| 127 | + abi.decode(_data[routeCalldataLength:], (uint256, bool, uint256)); |
| 128 | + |
| 129 | + return (routeCalldata, modifyCalldataParams); |
| 130 | + } |
| 131 | + |
| 132 | + /** |
| 133 | + * @dev Replaces a uint256 at a given position in a bytes data with a new uint256 |
| 134 | + * @dev Directly modifies the original bytes data in-place without creating a new copy |
| 135 | + */ |
| 136 | + function _replaceUint256(bytes memory _original, uint256 _start, uint256 _amount) |
| 137 | + internal |
| 138 | + pure |
| 139 | + returns (bytes memory) |
| 140 | + { |
| 141 | + // check if the _start is out of bounds |
| 142 | + if (_start + 32 > _original.length) revert PositionOutOfBounds(); |
| 143 | + |
| 144 | + // Directly modify externalData in-place without creating a new copy |
| 145 | + assembly { |
| 146 | + // Calculate position in memory where we need to write the new amount |
| 147 | + // Write the amount at that position |
| 148 | + mstore(add(add(_original, 32), _start), _amount) |
| 149 | + } |
| 150 | + |
| 151 | + return _original; |
| 152 | + } |
| 153 | + |
| 154 | + /** |
| 155 | + * @dev Reads a uint256 at a given byte index in a bytes array |
| 156 | + */ |
| 157 | + function _readUint256(bytes memory _data, uint256 _index) internal pure returns (uint256 value) { |
| 158 | + if (_data.length < _index + 32) revert PositionOutOfBounds(); |
| 159 | + assembly { |
| 160 | + value := mload(add(add(_data, 0x20), _index)) |
| 161 | + } |
| 162 | + } |
| 163 | + |
| 164 | + /** |
| 165 | + * @dev Applies a percentage difference to a target number |
| 166 | + */ |
| 167 | + function _applyPctDiff(uint256 _base, uint256 _compare, uint256 _target) internal pure returns (uint256) { |
| 168 | + if (_base == 0) revert InvalidInput(); |
| 169 | + return Math.mulDiv({x: _target, y: _compare, denominator: _base}); |
| 170 | + } |
| 171 | +} |
0 commit comments