Improve support for type .max/.min and Enums

pull/2574/head
Simone 3 weeks ago
parent afac8c4bf0
commit 48d7d9beef
  1. 122
      slither/visitors/expression/constants_folding.py

@ -13,7 +13,9 @@ from slither.core.expressions import (
TupleExpression,
TypeConversion,
CallExpression,
MemberAccess,
)
from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression
from slither.core.variables import Variable
from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int
from slither.visitors.expression.expression import ExpressionVisitor
@ -27,7 +29,13 @@ class NotConstant(Exception):
KEY = "ConstantFolding"
CONSTANT_TYPES_OPERATIONS = Union[
Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion
Literal,
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
]
@ -69,6 +77,7 @@ class ConstantFolding(ExpressionVisitor):
# pylint: disable=import-outside-toplevel
def _post_identifier(self, expression: Identifier) -> None:
from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.declarations.enum import Enum
if isinstance(expression.value, Variable):
if expression.value.is_constant:
@ -77,7 +86,14 @@ class ConstantFolding(ExpressionVisitor):
# Everything outside of literal
if isinstance(
expr,
(BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
(
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
),
):
cf = ConstantFolding(expr, self._type)
expr = cf.result()
@ -88,7 +104,10 @@ class ConstantFolding(ExpressionVisitor):
elif isinstance(expression.value, SolidityFunction):
set_val(expression, expression.value)
else:
raise NotConstant
# We don't want to raise an error for a direct access to an Enum as they can be converted to a constant value
# We can't handle it here because we don't have the field accessed so we do it in _post_member_access
if not isinstance(expression.value, Enum):
raise NotConstant
# pylint: disable=too-many-branches,too-many-statements
def _post_binary_operation(self, expression: BinaryOperation) -> None:
@ -96,12 +115,28 @@ class ConstantFolding(ExpressionVisitor):
expression_right = expression.expression_right
if not isinstance(
expression_left,
(Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
(
Literal,
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
),
):
raise NotConstant
if not isinstance(
expression_right,
(Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion),
(
Literal,
BinaryOperation,
UnaryOperation,
Identifier,
TupleExpression,
TypeConversion,
MemberAccess,
),
):
raise NotConstant
left = get_val(expression_left)
@ -205,6 +240,22 @@ class ConstantFolding(ExpressionVisitor):
raise NotConstant
def _post_call_expression(self, expression: expressions.CallExpression) -> None:
from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.declarations.enum import Enum
# pylint: disable=too-many-boolean-expressions
if (
isinstance(expression.called, Identifier)
and expression.called.value == SolidityFunction("type()")
and len(expression.arguments) == 1
and (
isinstance(expression.arguments[0], ElementaryTypeNameExpression)
or isinstance(expression.arguments[0], Identifier)
and isinstance(expression.arguments[0].value, Enum)
)
):
# Returning early to support type(ElemType).max/min or type(MyEnum).max/min
return
called = get_val(expression.called)
args = [get_val(arg) for arg in expression.arguments]
if called.name == "keccak256(bytes)":
@ -220,12 +271,70 @@ class ConstantFolding(ExpressionVisitor):
def _post_elementary_type_name_expression(
self, expression: expressions.ElementaryTypeNameExpression
) -> None:
raise NotConstant
# We don't have to raise an exception to support type(uint112).max or similar
pass
def _post_index_access(self, expression: expressions.IndexAccess) -> None:
raise NotConstant
def _post_member_access(self, expression: expressions.MemberAccess) -> None:
from slither.core.declarations import (
SolidityFunction,
Contract,
EnumContract,
EnumTopLevel,
Enum,
)
from slither.core.solidity_types import UserDefinedType
# pylint: disable=too-many-nested-blocks
if isinstance(expression.expression, CallExpression) and expression.member_name in [
"min",
"max",
]:
if isinstance(expression.expression.called, Identifier):
if expression.expression.called.value == SolidityFunction("type()"):
assert len(expression.expression.arguments) == 1
type_expression_found = expression.expression.arguments[0]
type_found: Union[ElementaryType, UserDefinedType]
if isinstance(type_expression_found, ElementaryTypeNameExpression):
type_expression_found_type = type_expression_found.type
assert isinstance(type_expression_found_type, ElementaryType)
type_found = type_expression_found_type
value = (
type_found.max if expression.member_name == "max" else type_found.min
)
set_val(expression, value)
return
# type(enum).max/min
# Case when enum is in another contract e.g. type(C.E).max
if isinstance(type_expression_found, MemberAccess):
contract = type_expression_found.expression.value
assert isinstance(contract, Contract)
for enum in contract.enums:
if enum.name == type_expression_found.member_name:
type_found_in_expression = enum
type_found = UserDefinedType(enum)
break
else:
assert isinstance(type_expression_found, Identifier)
type_found_in_expression = type_expression_found.value
assert isinstance(type_found_in_expression, (EnumContract, EnumTopLevel))
type_found = UserDefinedType(type_found_in_expression)
value = (
type_found_in_expression.max
if expression.member_name == "max"
else type_found_in_expression.min
)
set_val(expression, value)
return
elif isinstance(expression.expression, Identifier) and isinstance(
expression.expression.value, Enum
):
# Handle direct access to enum field
set_val(expression, expression.expression.value.values.index(expression.member_name))
return
raise NotConstant
def _post_new_array(self, expression: expressions.NewArray) -> None:
@ -272,6 +381,7 @@ class ConstantFolding(ExpressionVisitor):
TupleExpression,
TypeConversion,
CallExpression,
MemberAccess,
),
):
raise NotConstant

Loading…
Cancel
Save