diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index b1fa570c6..c8cfeb716 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -13,7 +13,9 @@ from slither.core.expressions import ( TupleExpression, TypeConversion, CallExpression, + MemberAccess, ) +from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression from slither.core.variables import Variable from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int from slither.visitors.expression.expression import ExpressionVisitor @@ -27,7 +29,13 @@ class NotConstant(Exception): KEY = "ConstantFolding" CONSTANT_TYPES_OPERATIONS = Union[ - Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion + Literal, + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, ] @@ -69,6 +77,7 @@ class ConstantFolding(ExpressionVisitor): # pylint: disable=import-outside-toplevel def _post_identifier(self, expression: Identifier) -> None: from slither.core.declarations.solidity_variables import SolidityFunction + from slither.core.declarations.enum import Enum if isinstance(expression.value, Variable): if expression.value.is_constant: @@ -77,7 +86,14 @@ class ConstantFolding(ExpressionVisitor): # Everything outside of literal if isinstance( expr, - (BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ( + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, + ), ): cf = ConstantFolding(expr, self._type) expr = cf.result() @@ -88,7 +104,10 @@ class ConstantFolding(ExpressionVisitor): elif isinstance(expression.value, SolidityFunction): set_val(expression, expression.value) else: - raise NotConstant + # We don't want to raise an error for a direct access to an Enum as they can be converted to a constant value + # We can't handle it here because we don't have the field accessed so we do it in _post_member_access + if not isinstance(expression.value, Enum): + raise NotConstant # pylint: disable=too-many-branches,too-many-statements def _post_binary_operation(self, expression: BinaryOperation) -> None: @@ -96,12 +115,28 @@ class ConstantFolding(ExpressionVisitor): expression_right = expression.expression_right if not isinstance( expression_left, - (Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ( + Literal, + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, + ), ): raise NotConstant if not isinstance( expression_right, - (Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ( + Literal, + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + MemberAccess, + ), ): raise NotConstant left = get_val(expression_left) @@ -205,6 +240,22 @@ class ConstantFolding(ExpressionVisitor): raise NotConstant def _post_call_expression(self, expression: expressions.CallExpression) -> None: + from slither.core.declarations.solidity_variables import SolidityFunction + from slither.core.declarations.enum import Enum + + # pylint: disable=too-many-boolean-expressions + if ( + isinstance(expression.called, Identifier) + and expression.called.value == SolidityFunction("type()") + and len(expression.arguments) == 1 + and ( + isinstance(expression.arguments[0], ElementaryTypeNameExpression) + or isinstance(expression.arguments[0], Identifier) + and isinstance(expression.arguments[0].value, Enum) + ) + ): + # Returning early to support type(ElemType).max/min or type(MyEnum).max/min + return called = get_val(expression.called) args = [get_val(arg) for arg in expression.arguments] if called.name == "keccak256(bytes)": @@ -220,12 +271,70 @@ class ConstantFolding(ExpressionVisitor): def _post_elementary_type_name_expression( self, expression: expressions.ElementaryTypeNameExpression ) -> None: - raise NotConstant + # We don't have to raise an exception to support type(uint112).max or similar + pass def _post_index_access(self, expression: expressions.IndexAccess) -> None: raise NotConstant def _post_member_access(self, expression: expressions.MemberAccess) -> None: + from slither.core.declarations import ( + SolidityFunction, + Contract, + EnumContract, + EnumTopLevel, + Enum, + ) + from slither.core.solidity_types import UserDefinedType + + # pylint: disable=too-many-nested-blocks + if isinstance(expression.expression, CallExpression) and expression.member_name in [ + "min", + "max", + ]: + if isinstance(expression.expression.called, Identifier): + if expression.expression.called.value == SolidityFunction("type()"): + assert len(expression.expression.arguments) == 1 + type_expression_found = expression.expression.arguments[0] + type_found: Union[ElementaryType, UserDefinedType] + if isinstance(type_expression_found, ElementaryTypeNameExpression): + type_expression_found_type = type_expression_found.type + assert isinstance(type_expression_found_type, ElementaryType) + type_found = type_expression_found_type + value = ( + type_found.max if expression.member_name == "max" else type_found.min + ) + set_val(expression, value) + return + # type(enum).max/min + # Case when enum is in another contract e.g. type(C.E).max + if isinstance(type_expression_found, MemberAccess): + contract = type_expression_found.expression.value + assert isinstance(contract, Contract) + for enum in contract.enums: + if enum.name == type_expression_found.member_name: + type_found_in_expression = enum + type_found = UserDefinedType(enum) + break + else: + assert isinstance(type_expression_found, Identifier) + type_found_in_expression = type_expression_found.value + assert isinstance(type_found_in_expression, (EnumContract, EnumTopLevel)) + type_found = UserDefinedType(type_found_in_expression) + value = ( + type_found_in_expression.max + if expression.member_name == "max" + else type_found_in_expression.min + ) + set_val(expression, value) + return + elif isinstance(expression.expression, Identifier) and isinstance( + expression.expression.value, Enum + ): + # Handle direct access to enum field + set_val(expression, expression.expression.value.values.index(expression.member_name)) + return + raise NotConstant def _post_new_array(self, expression: expressions.NewArray) -> None: @@ -272,6 +381,7 @@ class ConstantFolding(ExpressionVisitor): TupleExpression, TypeConversion, CallExpression, + MemberAccess, ), ): raise NotConstant