The home for Hyperlane core contracts, sdk packages, and other infrastructure
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
hyperlane-monorepo/solidity/test/hooks/ProtocolFee.t.sol

182 lines
5.5 KiB

// SPDX-License-Identifier: Apache-2.0
pragma solidity ^0.8.13;
import {Test} from "forge-std/Test.sol";
import {TypeCasts} from "../../contracts/libs/TypeCasts.sol";
import {MessageUtils} from "../isms/IsmTestUtils.sol";
import {StandardHookMetadata} from "../../contracts/hooks/libs/StandardHookMetadata.sol";
import {IPostDispatchHook} from "../../contracts/interfaces/hooks/IPostDispatchHook.sol";
import {ProtocolFee} from "../../contracts/hooks/ProtocolFee.sol";
contract ProtocolFeeTest is Test {
using TypeCasts for address;
ProtocolFee internal fees;
address internal alice = address(0x1); // alice the user
address internal bob = address(0x2); // bob the beneficiary
address internal charlie = address(0x3); // charlie the crock
uint32 internal constant TEST_ORIGIN_DOMAIN = 1;
uint32 internal constant TEST_DESTINATION_DOMAIN = 2;
uint256 internal constant MAX_FEE = 1e16;
uint256 internal constant FEE = 1e16;
bytes internal testMessage;
function setUp() public {
fees = new ProtocolFee(MAX_FEE, FEE, bob, address(this));
testMessage = _encodeTestMessage();
}
function testConstructor() public {
assertEq(fees.protocolFee(), FEE);
}
function testHookType() public {
assertEq(fees.hookType(), uint8(IPostDispatchHook.Types.PROTOCOL_FEE));
}
function testSetProtocolFee(uint256 fee) public {
fee = bound(fee, 0, fees.MAX_PROTOCOL_FEE());
fees.setProtocolFee(fee);
assertEq(fees.protocolFee(), fee);
}
function testSetProtocolFee_revertsWhen_notOwner() public {
uint256 fee = 1e17;
vm.prank(charlie);
vm.expectRevert("Ownable: caller is not the owner");
fees.setProtocolFee(fee);
assertEq(fees.protocolFee(), FEE);
}
function testSetProtocolFee_revertWhen_exceedsMax(uint256 fee) public {
fee = bound(
fee,
fees.MAX_PROTOCOL_FEE() + 1,
10 * fees.MAX_PROTOCOL_FEE()
);
vm.expectRevert("ProtocolFee: exceeds max protocol fee");
fees.setProtocolFee(fee);
assertEq(fees.protocolFee(), FEE);
}
function testSetBeneficiary_revertWhen_notOwner() public {
vm.prank(charlie);
vm.expectRevert("Ownable: caller is not the owner");
fees.setBeneficiary(charlie);
assertEq(fees.beneficiary(), bob);
}
function testQuoteDispatch() public {
assertEq(fees.quoteDispatch("", testMessage), FEE);
}
function testFuzz_postDispatch_inusfficientFees(
uint256 feeRequired,
uint256 feeSent
) public {
feeRequired = bound(feeRequired, 1, fees.MAX_PROTOCOL_FEE());
// bound feeSent to be less than feeRequired
feeSent = bound(feeSent, 0, feeRequired - 1);
vm.deal(alice, feeSent);
fees.setProtocolFee(feeRequired);
uint256 balanceBefore = alice.balance;
vm.prank(alice);
vm.expectRevert("ProtocolFee: insufficient protocol fee");
fees.postDispatch{value: feeSent}("", "");
assertEq(alice.balance, balanceBefore);
}
function testFuzz_postDispatch_sufficientFees(
uint256 feeRequired,
uint256 feeSent
) public {
feeRequired = bound(feeRequired, 1, fees.MAX_PROTOCOL_FEE());
feeSent = bound(feeSent, feeRequired, 10 * feeRequired);
vm.deal(alice, feeSent);
fees.setProtocolFee(feeRequired);
uint256 aliceBalanceBefore = alice.balance;
vm.prank(alice);
fees.postDispatch{value: feeSent}("", testMessage);
assertEq(alice.balance, aliceBalanceBefore - feeRequired);
}
function test_postDispatch_specifyRefundAddress(
uint256 feeRequired,
uint256 feeSent
) public {
bytes memory metadata = StandardHookMetadata.overrideRefundAddress(bob);
feeRequired = bound(feeRequired, 1, fees.MAX_PROTOCOL_FEE());
feeSent = bound(feeSent, feeRequired, 10 * feeRequired);
vm.deal(alice, feeSent);
fees.setProtocolFee(feeRequired);
uint256 aliceBalanceBefore = alice.balance;
uint256 bobBalanceBefore = bob.balance;
vm.prank(alice);
fees.postDispatch{value: feeSent}(metadata, testMessage);
assertEq(alice.balance, aliceBalanceBefore - feeSent);
assertEq(bob.balance, bobBalanceBefore + feeSent - feeRequired);
}
function testFuzz_collectProtocolFee(
uint256 feeRequired,
uint256 dispatchCalls
) public {
feeRequired = bound(feeRequired, 1, fees.MAX_PROTOCOL_FEE());
// no of postDispatch calls to be made
dispatchCalls = bound(dispatchCalls, 1, 1000);
vm.deal(alice, feeRequired * dispatchCalls);
fees.setProtocolFee(feeRequired);
uint256 balanceBefore = bob.balance;
for (uint256 i = 0; i < dispatchCalls; i++) {
vm.prank(alice);
fees.postDispatch{value: feeRequired}("", "");
}
fees.collectProtocolFees();
assertEq(bob.balance, balanceBefore + feeRequired * dispatchCalls);
}
// ============ Helper Functions ============
function _encodeTestMessage() internal view returns (bytes memory) {
return
MessageUtils.formatMessage(
uint8(0),
uint32(1),
TEST_ORIGIN_DOMAIN,
alice.addressToBytes32(),
TEST_DESTINATION_DOMAIN,
alice.addressToBytes32(),
abi.encodePacked("Hello World")
);
}
receive() external payable {}
}