support default args in calls

pull/2099/head
alpharush 1 year ago
parent e287b2f905
commit 81cb124e2d
  1. 2
      slither/core/declarations/function.py
  2. 5
      slither/vyper_parsing/declarations/function.py
  3. 11
      slither/vyper_parsing/expressions/expression_parsing.py
  4. 35
      tests/unit/slithir/vyper/test_ir_generation.py

@ -138,6 +138,8 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
self._parameters: List["LocalVariable"] = []
self._parameters_ssa: List["LocalIRVariable"] = []
self._parameters_src: SourceMapping = SourceMapping()
# This is used for vyper calls with default arguments
self._default_args_as_expressions: List["Expression"] = []
self._returns: List["LocalVariable"] = []
self._returns_ssa: List["LocalIRVariable"] = []
self._returns_src: SourceMapping = SourceMapping()

@ -31,7 +31,7 @@ class FunctionVyper:
def __init__(
self,
function: Function,
function_data: Dict,
function_data: FunctionDef,
contract_parser: "ContractVyper",
) -> None:
@ -503,7 +503,8 @@ class FunctionVyper:
print(params)
self._function.parameters_src().set_offset(params.src, self._function.compilation_unit)
if params.defaults:
self._function._default_args_as_expressions = params.defaults
for param in params.args:
local_var = self._add_param(param)
self._function.add_parameters(local_var.underlying_variable)

@ -198,6 +198,17 @@ def parse_expression(expression: Dict, caller_context) -> "Expression":
# Since the AST lacks the type of the return values, we recover it.
if isinstance(called.value, Function):
rets = called.value.returns
# Default arguments are not represented in the AST, so we recover them as well.
if called.value._default_args_as_expressions and len(arguments) < len(
called.value.parameters
):
arguments.extend(
[
parse_expression(x, caller_context)
for x in called.value._default_args_as_expressions
]
)
elif isinstance(called.value, SolidityFunction):
rets = called.value.return_type
elif isinstance(called.value, Contract):

@ -10,7 +10,7 @@ import pytest
from slither import Slither
from slither.core.cfg.node import Node, NodeType
from slither.core.declarations import Function, Contract
from slither.core.solidity_types import ArrayType
from slither.core.solidity_types import ArrayType, ElementaryType
from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations import (
@ -71,7 +71,7 @@ def test_phi_entry_point_internal_call(slither_from_vyper_source):
counter: uint256
@internal
def b(y: uint256):
self.counter = y # tainted by x, 1
self.counter = y
@external
def a(x: uint256):
@ -91,3 +91,34 @@ def a(x: uint256):
)
== 1
)
def test_call_with_default_args(slither_from_vyper_source):
with slither_from_vyper_source(
"""
counter: uint256
@internal
def c(y: uint256, config: bool = True):
self.counter = y
@external
def a(x: uint256):
self.c(x)
self.c(1)
@external
def b(x: uint256):
self.c(x, False)
self.c(1, False)
"""
) as sl:
a = sl.contracts[0].get_function_from_signature("a(uint256)")
for node in a.nodes:
for op in node.irs_ssa:
if isinstance(op, InternalCall) and op.function.name == "c":
assert len(op.arguments) == 2
assert op.arguments[1] == Constant("True", ElementaryType("bool"))
b = sl.contracts[0].get_function_from_signature("b(uint256)")
for node in b.nodes:
for op in node.irs_ssa:
if isinstance(op, InternalCall) and op.function.name == "c":
assert len(op.arguments) == 2
assert op.arguments[1] == Constant("False", ElementaryType("bool"))

Loading…
Cancel
Save