diff --git a/slither/utils/code_generation.py b/slither/utils/code_generation.py index 951bf4702..ae03e2090 100644 --- a/slither/utils/code_generation.py +++ b/slither/utils/code_generation.py @@ -2,12 +2,14 @@ from typing import TYPE_CHECKING, Optional from slither.utils.type import convert_type_for_solidity_signature_to_string +from slither.core.solidity_types.user_defined_type import UserDefinedType +from slither.core.declarations import Structure, Enum if TYPE_CHECKING: - from slither.core.declarations import FunctionContract, Structure, Contract + from slither.core.declarations import FunctionContract, Contract, CustomErrorContract -def generate_interface(contract: "Contract") -> str: +def generate_interface(contract: "Contract", unroll_structs: bool = True) -> str: """ Generates code for a Solidity interface to the contract. Args: @@ -22,13 +24,7 @@ def generate_interface(contract: "Contract") -> str: name, args = event.signature interface += f" event {name}({', '.join(args)});\n" for error in contract.custom_errors: - args = [ - convert_type_for_solidity_signature_to_string(arg.type) - .replace("(", "") - .replace(")", "") - for arg in error.parameters - ] - interface += f" error {error.name}({', '.join(args)});\n" + interface += generate_custom_error_interface(error, unroll_structs) for enum in contract.enums: interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n" for struct in contract.structures: @@ -38,12 +34,16 @@ def generate_interface(contract: "Contract") -> str: for func in contract.functions_entry_points: if func.is_constructor or func.is_fallback or func.is_receive: continue - interface += f" function {generate_interface_function_signature(func)};\n" + interface += ( + f" function {generate_interface_function_signature(func, unroll_structs)};\n" + ) interface += "}\n\n" return interface -def generate_interface_function_signature(func: "FunctionContract") -> Optional[str]: +def generate_interface_function_signature( + func: "FunctionContract", unroll_structs: bool = True +) -> Optional[str]: """ Generates a string of the form: func_name(type1,type2) external {payable/view/pure} returns (type3) @@ -56,7 +56,7 @@ def generate_interface_function_signature(func: "FunctionContract") -> Optional[ Returns None if the function is private or internal, or is a constructor/fallback/receive. """ - name, parameters, return_vars = func.signature + name, _, _ = func.signature if ( func not in func.contract.functions_entry_points or func.is_constructor @@ -69,16 +69,24 @@ def generate_interface_function_signature(func: "FunctionContract") -> Optional[ payable = " payable" if func.payable else "" returns = [ convert_type_for_solidity_signature_to_string(ret.type).replace("(", "").replace(")", "") + if unroll_structs + else f"{str(ret.type.type)} memory" + if isinstance(ret.type, UserDefinedType) and isinstance(ret.type.type, (Structure, Enum)) + else str(ret.type) for ret in func.returns ] parameters = [ convert_type_for_solidity_signature_to_string(param.type).replace("(", "").replace(")", "") + if unroll_structs + else f"{str(param.type.type)} memory" + if isinstance(param.type, UserDefinedType) and isinstance(param.type.type, (Structure, Enum)) + else str(param.type) for param in func.parameters ] _interface_signature_str = ( name + "(" + ",".join(parameters) + ") external" + payable + pure + view ) - if len(return_vars) > 0: + if len(returns) > 0: _interface_signature_str += " returns (" + ",".join(returns) + ")" return _interface_signature_str @@ -102,3 +110,17 @@ def generate_struct_interface_str(struct: "Structure") -> str: definition += f" {elem.type} {elem.name};\n" definition += " }\n" return definition + + +def generate_custom_error_interface(error: "CustomErrorContract", unroll_structs: bool = True) -> str: + args = [ + convert_type_for_solidity_signature_to_string(arg.type) + .replace("(", "") + .replace(")", "") + if unroll_structs + else str(arg.type.type) + if isinstance(arg.type, UserDefinedType) and isinstance(arg.type.type, (Structure, Enum)) + else str(arg.type) + for arg in error.parameters + ] + return f" error {error.name}({', '.join(args)});\n"