Merge pull request #1501 from crytic/call-value-ternary

Call value ternary
pull/1554/head
Feist Josselin 2 years ago committed by GitHub
commit 81c2a46d5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      slither/solc_parsing/declarations/function.py
  2. 2
      slither/solc_parsing/declarations/modifier.py
  3. 194
      slither/utils/expression_manipulations.py
  4. 18
      tests/slithir/ternary_expressions.sol
  5. 6
      tests/slithir/test_ternary_expressions.py

@ -308,7 +308,7 @@ class FunctionSolc(CallerContextExpression):
for node_parser in self._node_to_yulobject.values(): for node_parser in self._node_to_yulobject.values():
node_parser.analyze_expressions() node_parser.analyze_expressions()
self._filter_ternary() self._rewrite_ternary_as_if_else()
self._remove_alone_endif() self._remove_alone_endif()
@ -1336,7 +1336,7 @@ class FunctionSolc(CallerContextExpression):
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def _filter_ternary(self) -> bool: def _rewrite_ternary_as_if_else(self) -> bool:
ternary_found = True ternary_found = True
updated = False updated = False
while ternary_found: while ternary_found:

@ -87,7 +87,7 @@ class ModifierSolc(FunctionSolc):
for node in self._node_to_nodesolc.values(): for node in self._node_to_nodesolc.values():
node.analyze_expressions(self) node.analyze_expressions(self)
self._filter_ternary() self._rewrite_ternary_as_if_else()
self._remove_alone_endif() self._remove_alone_endif()
# self._analyze_read_write() # self._analyze_read_write()

@ -23,20 +23,29 @@ from slither.all_exceptions import SlitherException
# pylint: disable=protected-access # pylint: disable=protected-access
def f_expressions( def f_expressions(
e: AssignmentOperation, x: Union[Identifier, Literal, MemberAccess, IndexAccess] e: Union[AssignmentOperation, BinaryOperation, TupleExpression],
x: Union[Identifier, Literal, MemberAccess, IndexAccess],
) -> None: ) -> None:
e._expressions.append(x) e._expressions.append(x)
def f_call(e, x): def f_call(e: CallExpression, x):
e._arguments.append(x) e._arguments.append(x)
def f_expression(e, x): def f_call_value(e: CallExpression, x):
e._value = x
def f_call_gas(e: CallExpression, x):
e._gas = x
def f_expression(e: Union[TypeConversion, UnaryOperation, MemberAccess], x):
e._expression = x e._expression = x
def f_called(e, x): def f_called(e: CallExpression, x):
e._called = x e._called = x
@ -53,13 +62,20 @@ class SplitTernaryExpression:
self.condition = None self.condition = None
self.copy_expression(expression, self.true_expression, self.false_expression) self.copy_expression(expression, self.true_expression, self.false_expression)
def apply_copy( def conditional_not_ahead(
self, self,
next_expr: Expression, next_expr: Expression,
true_expression: Union[AssignmentOperation, MemberAccess], true_expression: Union[AssignmentOperation, MemberAccess],
false_expression: Union[AssignmentOperation, MemberAccess], false_expression: Union[AssignmentOperation, MemberAccess],
f: Callable, f: Callable,
) -> bool: ) -> bool:
# look ahead for parenthetical expression (.. ? .. : ..)
if (
isinstance(next_expr, TupleExpression)
and len(next_expr.expressions) == 1
and isinstance(next_expr.expressions[0], ConditionalExpression)
):
next_expr = next_expr.expressions[0]
if isinstance(next_expr, ConditionalExpression): if isinstance(next_expr, ConditionalExpression):
f(true_expression, copy.copy(next_expr.then_expression)) f(true_expression, copy.copy(next_expr.then_expression))
@ -71,7 +87,6 @@ class SplitTernaryExpression:
f(false_expression, copy.copy(next_expr)) f(false_expression, copy.copy(next_expr))
return True return True
# pylint: disable=too-many-branches
def copy_expression( def copy_expression(
self, expression: Expression, true_expression: Expression, false_expression: Expression self, expression: Expression, true_expression: Expression, false_expression: Expression
) -> None: ) -> None:
@ -87,57 +102,20 @@ class SplitTernaryExpression:
): ):
return return
# case of lib if isinstance(expression, (AssignmentOperation, BinaryOperation, TupleExpression)):
# (.. ? .. : ..).add
if isinstance(expression, MemberAccess):
next_expr = expression.expression
if self.apply_copy(next_expr, true_expression, false_expression, f_expression):
self.copy_expression(
next_expr, true_expression.expression, false_expression.expression
)
elif isinstance(expression, (AssignmentOperation, BinaryOperation, TupleExpression)):
true_expression._expressions = [] true_expression._expressions = []
false_expression._expressions = [] false_expression._expressions = []
self.convert_expressions(expression, true_expression, false_expression)
for next_expr in expression.expressions:
if isinstance(next_expr, IndexAccess):
# create an index access for each branch
if isinstance(next_expr.expression_right, ConditionalExpression):
next_expr = _handle_ternary_access(
next_expr, true_expression, false_expression
)
if self.apply_copy(next_expr, true_expression, false_expression, f_expressions):
# always on last arguments added
self.copy_expression(
next_expr,
true_expression.expressions[-1],
false_expression.expressions[-1],
)
elif isinstance(expression, CallExpression): elif isinstance(expression, CallExpression):
next_expr = expression.called next_expr = expression.called
self.convert_call_expression(expression, next_expr, true_expression, false_expression)
# case of lib elif isinstance(expression, (TypeConversion, UnaryOperation, MemberAccess)):
# (.. ? .. : ..).add
if self.apply_copy(next_expr, true_expression, false_expression, f_called):
self.copy_expression(next_expr, true_expression.called, false_expression.called)
true_expression._arguments = []
false_expression._arguments = []
for next_expr in expression.arguments:
if self.apply_copy(next_expr, true_expression, false_expression, f_call):
# always on last arguments added
self.copy_expression(
next_expr,
true_expression.arguments[-1],
false_expression.arguments[-1],
)
elif isinstance(expression, (TypeConversion, UnaryOperation)):
next_expr = expression.expression next_expr = expression.expression
if self.apply_copy(next_expr, true_expression, false_expression, f_expression): if self.conditional_not_ahead(
next_expr, true_expression, false_expression, f_expression
):
self.copy_expression( self.copy_expression(
expression.expression, expression.expression,
true_expression.expression, true_expression.expression,
@ -149,34 +127,90 @@ class SplitTernaryExpression:
f"Ternary operation not handled {expression}({type(expression)})" f"Ternary operation not handled {expression}({type(expression)})"
) )
def convert_expressions(
self,
expression: Union[AssignmentOperation, BinaryOperation, TupleExpression],
true_expression: Expression,
false_expression: Expression,
) -> None:
for next_expr in expression.expressions:
# TODO: can we get rid of `NoneType` expressions in `TupleExpression`?
# montyly: this might happen with unnamed tuple (ex: (,,,) = f()), but it needs to be checked
if next_expr:
if isinstance(next_expr, IndexAccess):
self.convert_index_access(next_expr, true_expression, false_expression)
if self.conditional_not_ahead(
next_expr, true_expression, false_expression, f_expressions
):
# always on last arguments added
self.copy_expression(
next_expr,
true_expression.expressions[-1],
false_expression.expressions[-1],
)
def convert_index_access(
self, next_expr: IndexAccess, true_expression: Expression, false_expression: Expression
) -> None:
# create an index access for each branch
# x[if cond ? 1 : 2] -> if cond { x[1] } else { x[2] }
for expr in next_expr.expressions:
if self.conditional_not_ahead(expr, true_expression, false_expression, f_expressions):
self.copy_expression(
expr,
true_expression.expressions[-1],
false_expression.expressions[-1],
)
def _handle_ternary_access( def convert_call_expression(
next_expr: IndexAccess, self,
true_expression: AssignmentOperation, expression: CallExpression,
false_expression: AssignmentOperation, next_expr: Expression,
): true_expression: Expression,
""" false_expression: Expression,
Conditional ternary accesses are split into two accesses, one true and one false ) -> None:
E.g. x[if cond ? 1 : 2] -> if cond { x[1] } else { x[2] } # case of lib
""" # (.. ? .. : ..).add
true_index_access = IndexAccess( if self.conditional_not_ahead(next_expr, true_expression, false_expression, f_called):
next_expr.expression_left, self.copy_expression(next_expr, true_expression.called, false_expression.called)
next_expr.expression_right.then_expression,
next_expr.type, # In order to handle ternaries in both call options, gas and value, we return early if the
) # conditional is not ahead to rewrite both ternaries (see `_rewrite_ternary_as_if_else`).
false_index_access = IndexAccess( if expression.call_gas:
next_expr.expression_left, # case of (..).func{gas: .. ? .. : ..}()
next_expr.expression_right.else_expression, next_expr = expression.call_gas
next_expr.type, if self.conditional_not_ahead(next_expr, true_expression, false_expression, f_call_gas):
) self.copy_expression(
next_expr,
f_expressions( true_expression.call_gas,
true_expression, false_expression.call_gas,
true_index_access, )
) else:
f_expressions( return
false_expression,
false_index_access, if expression.call_value:
) # case of (..).func{value: .. ? .. : ..}()
next_expr = expression.call_value
return next_expr.expression_right if self.conditional_not_ahead(
next_expr, true_expression, false_expression, f_call_value
):
self.copy_expression(
next_expr,
true_expression.call_value,
false_expression.call_value,
)
else:
return
true_expression._arguments = []
false_expression._arguments = []
for expr in expression.arguments:
if self.conditional_not_ahead(expr, true_expression, false_expression, f_call):
# always on last arguments added
self.copy_expression(
expr,
true_expression.arguments[-1],
false_expression.arguments[-1],
)

@ -1,3 +1,7 @@
interface Test {
function test() external payable returns (uint);
function testTuple() external payable returns (uint, uint);
}
contract C { contract C {
// TODO // TODO
// 1) support variable declarations // 1) support variable declarations
@ -21,4 +25,18 @@ contract C {
function d(bool cond, bytes calldata x) external { function d(bool cond, bytes calldata x) external {
bytes1 a = x[cond ? 1 : 2]; bytes1 a = x[cond ? 1 : 2];
} }
function e(address one, address two) public {
uint x = Test(one).test{value: msg.sender == two ? 1 : 2, gas: true ? 2 : gasleft()}();
}
// Parenthetical expression
function f(address one, address two) public {
uint x = Test(one).test{value: msg.sender == two ? 1 : 2, gas: true ? (1 == 1 ? 1 : 2) : gasleft()}();
}
// Unused tuple variable
function g(address one) public {
(, uint x) = Test(one).testTuple();
}
} }

@ -9,10 +9,10 @@ def test_ternary_conversions() -> None:
slither = Slither("./tests/slithir/ternary_expressions.sol") slither = Slither("./tests/slithir/ternary_expressions.sol")
for contract in slither.contracts: for contract in slither.contracts:
for function in contract.functions: for function in contract.functions:
vars_declared = 0
vars_assigned = 0
for node in function.nodes: for node in function.nodes:
if node.type in [NodeType.IF, NodeType.IFLOOP]: if node.type in [NodeType.IF, NodeType.IFLOOP]:
vars_declared = 0
vars_assigned = 0
# Iterate over true and false son # Iterate over true and false son
for inner_node in node.sons: for inner_node in node.sons:
@ -31,7 +31,7 @@ def test_ternary_conversions() -> None:
if isinstance(ir, Assignment): if isinstance(ir, Assignment):
vars_assigned += 1 vars_assigned += 1
assert vars_declared == vars_assigned assert vars_declared == vars_assigned
if __name__ == "__main__": if __name__ == "__main__":

Loading…
Cancel
Save