From 36e1ac1e011aab8fa22b47e6de5ff4d728f647a6 Mon Sep 17 00:00:00 2001 From: webthethird Date: Mon, 27 Mar 2023 10:37:36 -0500 Subject: [PATCH] Better handling of state variable signatures especially contract-type variables and functions that return them --- slither/utils/code_generation.py | 51 ++++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/slither/utils/code_generation.py b/slither/utils/code_generation.py index f0c433ac3..fc7578040 100644 --- a/slither/utils/code_generation.py +++ b/slither/utils/code_generation.py @@ -1,19 +1,25 @@ # Functions for generating Solidity code from typing import TYPE_CHECKING, Optional -from slither.utils.type import convert_type_for_solidity_signature_to_string -from slither.core.solidity_types.user_defined_type import UserDefinedType -from slither.core.declarations import Structure, Enum +from slither.utils.type import ( + convert_type_for_solidity_signature_to_string, + export_nested_types_from_variable, + export_return_type_from_variable, +) +from slither.core.solidity_types import UserDefinedType, MappingType, ArrayType +from slither.core.declarations import Structure, Enum, Contract if TYPE_CHECKING: - from slither.core.declarations import FunctionContract, Contract, CustomErrorContract + from slither.core.declarations import FunctionContract, CustomErrorContract + from slither.core.variables import StateVariable def generate_interface(contract: "Contract", unroll_structs: bool = True) -> str: """ Generates code for a Solidity interface to the contract. Args: - contract: A Contract object + contract: A Contract object. + unroll_structs: Specifies whether to use structures' underlying types instead of the user-defined type. Returns: A string with the code for an interface, with function stubs for all public or external functions and @@ -30,7 +36,7 @@ def generate_interface(contract: "Contract", unroll_structs: bool = True) -> str for struct in contract.structures: interface += generate_struct_interface_str(struct) for var in contract.state_variables_entry_points: - interface += f" function {var.signature_str.replace('returns', 'external returns ')};\n" + interface += generate_interface_variable_signature(var, unroll_structs) for func in contract.functions_entry_points: if func.is_constructor or func.is_fallback or func.is_receive: continue @@ -41,6 +47,35 @@ def generate_interface(contract: "Contract", unroll_structs: bool = True) -> str return interface +def generate_interface_variable_signature( + var: "StateVariable", unroll_structs: bool = True +) -> Optional[str]: + if unroll_structs: + params = [ + convert_type_for_solidity_signature_to_string(x).replace("(", "").replace(")", "") + for x in export_nested_types_from_variable(var) + ] + returns = [ + convert_type_for_solidity_signature_to_string(x).replace("(", "").replace(")", "") + for x in export_return_type_from_variable(var) + ] + else: + _, params, _ = var.signature + returns = [] + _type = var.type + while isinstance(_type, MappingType): + _type = _type.type_to + while isinstance(_type, (ArrayType, UserDefinedType)): + _type = _type.type + ret = str(_type) + if isinstance(_type, Structure): + ret += " memory" + elif isinstance(_type, Contract): + ret = "address" + returns.append(ret) + return f" function {var.name}({','.join(params)}) external returns ({', '.join(returns)});\n" + + def generate_interface_function_signature( func: "FunctionContract", unroll_structs: bool = True ) -> Optional[str]: @@ -72,6 +107,8 @@ def generate_interface_function_signature( if unroll_structs else f"{str(ret.type.type)} memory" if isinstance(ret.type, UserDefinedType) and isinstance(ret.type.type, (Structure, Enum)) + else "address" + if isinstance(ret.type, UserDefinedType) and isinstance(ret.type.type, Contract) else str(ret.type) for ret in func.returns ] @@ -81,6 +118,8 @@ def generate_interface_function_signature( else f"{str(param.type.type)} memory" if isinstance(param.type, UserDefinedType) and isinstance(param.type.type, (Structure, Enum)) + else "address" + if isinstance(param.type, UserDefinedType) and isinstance(param.type.type, Contract) else str(param.type) for param in func.parameters ]