Merge pull request #1501 from crytic/call-value-ternary

Call value ternary
pull/1554/head
Feist Josselin 2 years ago committed by GitHub
commit 81c2a46d5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      slither/solc_parsing/declarations/function.py
  2. 2
      slither/solc_parsing/declarations/modifier.py
  3. 194
      slither/utils/expression_manipulations.py
  4. 18
      tests/slithir/ternary_expressions.sol
  5. 6
      tests/slithir/test_ternary_expressions.py

@ -308,7 +308,7 @@ class FunctionSolc(CallerContextExpression):
for node_parser in self._node_to_yulobject.values():
node_parser.analyze_expressions()
self._filter_ternary()
self._rewrite_ternary_as_if_else()
self._remove_alone_endif()
@ -1336,7 +1336,7 @@ class FunctionSolc(CallerContextExpression):
###################################################################################
###################################################################################
def _filter_ternary(self) -> bool:
def _rewrite_ternary_as_if_else(self) -> bool:
ternary_found = True
updated = False
while ternary_found:

@ -87,7 +87,7 @@ class ModifierSolc(FunctionSolc):
for node in self._node_to_nodesolc.values():
node.analyze_expressions(self)
self._filter_ternary()
self._rewrite_ternary_as_if_else()
self._remove_alone_endif()
# self._analyze_read_write()

@ -23,20 +23,29 @@ from slither.all_exceptions import SlitherException
# pylint: disable=protected-access
def f_expressions(
e: AssignmentOperation, x: Union[Identifier, Literal, MemberAccess, IndexAccess]
e: Union[AssignmentOperation, BinaryOperation, TupleExpression],
x: Union[Identifier, Literal, MemberAccess, IndexAccess],
) -> None:
e._expressions.append(x)
def f_call(e, x):
def f_call(e: CallExpression, x):
e._arguments.append(x)
def f_expression(e, x):
def f_call_value(e: CallExpression, x):
e._value = x
def f_call_gas(e: CallExpression, x):
e._gas = x
def f_expression(e: Union[TypeConversion, UnaryOperation, MemberAccess], x):
e._expression = x
def f_called(e, x):
def f_called(e: CallExpression, x):
e._called = x
@ -53,13 +62,20 @@ class SplitTernaryExpression:
self.condition = None
self.copy_expression(expression, self.true_expression, self.false_expression)
def apply_copy(
def conditional_not_ahead(
self,
next_expr: Expression,
true_expression: Union[AssignmentOperation, MemberAccess],
false_expression: Union[AssignmentOperation, MemberAccess],
f: Callable,
) -> bool:
# look ahead for parenthetical expression (.. ? .. : ..)
if (
isinstance(next_expr, TupleExpression)
and len(next_expr.expressions) == 1
and isinstance(next_expr.expressions[0], ConditionalExpression)
):
next_expr = next_expr.expressions[0]
if isinstance(next_expr, ConditionalExpression):
f(true_expression, copy.copy(next_expr.then_expression))
@ -71,7 +87,6 @@ class SplitTernaryExpression:
f(false_expression, copy.copy(next_expr))
return True
# pylint: disable=too-many-branches
def copy_expression(
self, expression: Expression, true_expression: Expression, false_expression: Expression
) -> None:
@ -87,57 +102,20 @@ class SplitTernaryExpression:
):
return
# case of lib
# (.. ? .. : ..).add
if isinstance(expression, MemberAccess):
next_expr = expression.expression
if self.apply_copy(next_expr, true_expression, false_expression, f_expression):
self.copy_expression(
next_expr, true_expression.expression, false_expression.expression
)
elif isinstance(expression, (AssignmentOperation, BinaryOperation, TupleExpression)):
if isinstance(expression, (AssignmentOperation, BinaryOperation, TupleExpression)):
true_expression._expressions = []
false_expression._expressions = []
for next_expr in expression.expressions:
if isinstance(next_expr, IndexAccess):
# create an index access for each branch
if isinstance(next_expr.expression_right, ConditionalExpression):
next_expr = _handle_ternary_access(
next_expr, true_expression, false_expression
)
if self.apply_copy(next_expr, true_expression, false_expression, f_expressions):
# always on last arguments added
self.copy_expression(
next_expr,
true_expression.expressions[-1],
false_expression.expressions[-1],
)
self.convert_expressions(expression, true_expression, false_expression)
elif isinstance(expression, CallExpression):
next_expr = expression.called
self.convert_call_expression(expression, next_expr, true_expression, false_expression)
# case of lib
# (.. ? .. : ..).add
if self.apply_copy(next_expr, true_expression, false_expression, f_called):
self.copy_expression(next_expr, true_expression.called, false_expression.called)
true_expression._arguments = []
false_expression._arguments = []
for next_expr in expression.arguments:
if self.apply_copy(next_expr, true_expression, false_expression, f_call):
# always on last arguments added
self.copy_expression(
next_expr,
true_expression.arguments[-1],
false_expression.arguments[-1],
)
elif isinstance(expression, (TypeConversion, UnaryOperation)):
elif isinstance(expression, (TypeConversion, UnaryOperation, MemberAccess)):
next_expr = expression.expression
if self.apply_copy(next_expr, true_expression, false_expression, f_expression):
if self.conditional_not_ahead(
next_expr, true_expression, false_expression, f_expression
):
self.copy_expression(
expression.expression,
true_expression.expression,
@ -149,34 +127,90 @@ class SplitTernaryExpression:
f"Ternary operation not handled {expression}({type(expression)})"
)
def convert_expressions(
self,
expression: Union[AssignmentOperation, BinaryOperation, TupleExpression],
true_expression: Expression,
false_expression: Expression,
) -> None:
for next_expr in expression.expressions:
# TODO: can we get rid of `NoneType` expressions in `TupleExpression`?
# montyly: this might happen with unnamed tuple (ex: (,,,) = f()), but it needs to be checked
if next_expr:
if isinstance(next_expr, IndexAccess):
self.convert_index_access(next_expr, true_expression, false_expression)
if self.conditional_not_ahead(
next_expr, true_expression, false_expression, f_expressions
):
# always on last arguments added
self.copy_expression(
next_expr,
true_expression.expressions[-1],
false_expression.expressions[-1],
)
def convert_index_access(
self, next_expr: IndexAccess, true_expression: Expression, false_expression: Expression
) -> None:
# create an index access for each branch
# x[if cond ? 1 : 2] -> if cond { x[1] } else { x[2] }
for expr in next_expr.expressions:
if self.conditional_not_ahead(expr, true_expression, false_expression, f_expressions):
self.copy_expression(
expr,
true_expression.expressions[-1],
false_expression.expressions[-1],
)
def _handle_ternary_access(
next_expr: IndexAccess,
true_expression: AssignmentOperation,
false_expression: AssignmentOperation,
):
"""
Conditional ternary accesses are split into two accesses, one true and one false
E.g. x[if cond ? 1 : 2] -> if cond { x[1] } else { x[2] }
"""
true_index_access = IndexAccess(
next_expr.expression_left,
next_expr.expression_right.then_expression,
next_expr.type,
)
false_index_access = IndexAccess(
next_expr.expression_left,
next_expr.expression_right.else_expression,
next_expr.type,
)
f_expressions(
true_expression,
true_index_access,
)
f_expressions(
false_expression,
false_index_access,
)
return next_expr.expression_right
def convert_call_expression(
self,
expression: CallExpression,
next_expr: Expression,
true_expression: Expression,
false_expression: Expression,
) -> None:
# case of lib
# (.. ? .. : ..).add
if self.conditional_not_ahead(next_expr, true_expression, false_expression, f_called):
self.copy_expression(next_expr, true_expression.called, false_expression.called)
# In order to handle ternaries in both call options, gas and value, we return early if the
# conditional is not ahead to rewrite both ternaries (see `_rewrite_ternary_as_if_else`).
if expression.call_gas:
# case of (..).func{gas: .. ? .. : ..}()
next_expr = expression.call_gas
if self.conditional_not_ahead(next_expr, true_expression, false_expression, f_call_gas):
self.copy_expression(
next_expr,
true_expression.call_gas,
false_expression.call_gas,
)
else:
return
if expression.call_value:
# case of (..).func{value: .. ? .. : ..}()
next_expr = expression.call_value
if self.conditional_not_ahead(
next_expr, true_expression, false_expression, f_call_value
):
self.copy_expression(
next_expr,
true_expression.call_value,
false_expression.call_value,
)
else:
return
true_expression._arguments = []
false_expression._arguments = []
for expr in expression.arguments:
if self.conditional_not_ahead(expr, true_expression, false_expression, f_call):
# always on last arguments added
self.copy_expression(
expr,
true_expression.arguments[-1],
false_expression.arguments[-1],
)

@ -1,3 +1,7 @@
interface Test {
function test() external payable returns (uint);
function testTuple() external payable returns (uint, uint);
}
contract C {
// TODO
// 1) support variable declarations
@ -21,4 +25,18 @@ contract C {
function d(bool cond, bytes calldata x) external {
bytes1 a = x[cond ? 1 : 2];
}
function e(address one, address two) public {
uint x = Test(one).test{value: msg.sender == two ? 1 : 2, gas: true ? 2 : gasleft()}();
}
// Parenthetical expression
function f(address one, address two) public {
uint x = Test(one).test{value: msg.sender == two ? 1 : 2, gas: true ? (1 == 1 ? 1 : 2) : gasleft()}();
}
// Unused tuple variable
function g(address one) public {
(, uint x) = Test(one).testTuple();
}
}

@ -9,10 +9,10 @@ def test_ternary_conversions() -> None:
slither = Slither("./tests/slithir/ternary_expressions.sol")
for contract in slither.contracts:
for function in contract.functions:
vars_declared = 0
vars_assigned = 0
for node in function.nodes:
if node.type in [NodeType.IF, NodeType.IFLOOP]:
vars_declared = 0
vars_assigned = 0
# Iterate over true and false son
for inner_node in node.sons:
@ -31,7 +31,7 @@ def test_ternary_conversions() -> None:
if isinstance(ir, Assignment):
vars_assigned += 1
assert vars_declared == vars_assigned
assert vars_declared == vars_assigned
if __name__ == "__main__":

Loading…
Cancel
Save