Add option to skip unrolling user-defined-types

pull/1802/head
webthethird 2 years ago
parent 3c55228707
commit 6a3ae82b56
  1. 48
      slither/utils/code_generation.py

@ -2,12 +2,14 @@
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from slither.utils.type import convert_type_for_solidity_signature_to_string from slither.utils.type import convert_type_for_solidity_signature_to_string
from slither.core.solidity_types.user_defined_type import UserDefinedType
from slither.core.declarations import Structure, Enum
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.declarations import FunctionContract, Structure, Contract from slither.core.declarations import FunctionContract, Contract, CustomErrorContract
def generate_interface(contract: "Contract") -> str: def generate_interface(contract: "Contract", unroll_structs: bool = True) -> str:
""" """
Generates code for a Solidity interface to the contract. Generates code for a Solidity interface to the contract.
Args: Args:
@ -22,13 +24,7 @@ def generate_interface(contract: "Contract") -> str:
name, args = event.signature name, args = event.signature
interface += f" event {name}({', '.join(args)});\n" interface += f" event {name}({', '.join(args)});\n"
for error in contract.custom_errors: for error in contract.custom_errors:
args = [ interface += generate_custom_error_interface(error, unroll_structs)
convert_type_for_solidity_signature_to_string(arg.type)
.replace("(", "")
.replace(")", "")
for arg in error.parameters
]
interface += f" error {error.name}({', '.join(args)});\n"
for enum in contract.enums: for enum in contract.enums:
interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n" interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n"
for struct in contract.structures: for struct in contract.structures:
@ -38,12 +34,16 @@ def generate_interface(contract: "Contract") -> str:
for func in contract.functions_entry_points: for func in contract.functions_entry_points:
if func.is_constructor or func.is_fallback or func.is_receive: if func.is_constructor or func.is_fallback or func.is_receive:
continue continue
interface += f" function {generate_interface_function_signature(func)};\n" interface += (
f" function {generate_interface_function_signature(func, unroll_structs)};\n"
)
interface += "}\n\n" interface += "}\n\n"
return interface return interface
def generate_interface_function_signature(func: "FunctionContract") -> Optional[str]: def generate_interface_function_signature(
func: "FunctionContract", unroll_structs: bool = True
) -> Optional[str]:
""" """
Generates a string of the form: Generates a string of the form:
func_name(type1,type2) external {payable/view/pure} returns (type3) func_name(type1,type2) external {payable/view/pure} returns (type3)
@ -56,7 +56,7 @@ def generate_interface_function_signature(func: "FunctionContract") -> Optional[
Returns None if the function is private or internal, or is a constructor/fallback/receive. Returns None if the function is private or internal, or is a constructor/fallback/receive.
""" """
name, parameters, return_vars = func.signature name, _, _ = func.signature
if ( if (
func not in func.contract.functions_entry_points func not in func.contract.functions_entry_points
or func.is_constructor or func.is_constructor
@ -69,16 +69,24 @@ def generate_interface_function_signature(func: "FunctionContract") -> Optional[
payable = " payable" if func.payable else "" payable = " payable" if func.payable else ""
returns = [ returns = [
convert_type_for_solidity_signature_to_string(ret.type).replace("(", "").replace(")", "") convert_type_for_solidity_signature_to_string(ret.type).replace("(", "").replace(")", "")
if unroll_structs
else f"{str(ret.type.type)} memory"
if isinstance(ret.type, UserDefinedType) and isinstance(ret.type.type, (Structure, Enum))
else str(ret.type)
for ret in func.returns for ret in func.returns
] ]
parameters = [ parameters = [
convert_type_for_solidity_signature_to_string(param.type).replace("(", "").replace(")", "") convert_type_for_solidity_signature_to_string(param.type).replace("(", "").replace(")", "")
if unroll_structs
else f"{str(param.type.type)} memory"
if isinstance(param.type, UserDefinedType) and isinstance(param.type.type, (Structure, Enum))
else str(param.type)
for param in func.parameters for param in func.parameters
] ]
_interface_signature_str = ( _interface_signature_str = (
name + "(" + ",".join(parameters) + ") external" + payable + pure + view name + "(" + ",".join(parameters) + ") external" + payable + pure + view
) )
if len(return_vars) > 0: if len(returns) > 0:
_interface_signature_str += " returns (" + ",".join(returns) + ")" _interface_signature_str += " returns (" + ",".join(returns) + ")"
return _interface_signature_str return _interface_signature_str
@ -102,3 +110,17 @@ def generate_struct_interface_str(struct: "Structure") -> str:
definition += f" {elem.type} {elem.name};\n" definition += f" {elem.type} {elem.name};\n"
definition += " }\n" definition += " }\n"
return definition return definition
def generate_custom_error_interface(error: "CustomErrorContract", unroll_structs: bool = True) -> str:
args = [
convert_type_for_solidity_signature_to_string(arg.type)
.replace("(", "")
.replace(")", "")
if unroll_structs
else str(arg.type.type)
if isinstance(arg.type, UserDefinedType) and isinstance(arg.type.type, (Structure, Enum))
else str(arg.type)
for arg in error.parameters
]
return f" error {error.name}({', '.join(args)});\n"

Loading…
Cancel
Save