Merge pull request #1784 from Troublor/array-type

Make type information of NewArray more precise
pull/1936/head
Feist Josselin 1 year ago committed by GitHub
commit bb76a5ba69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 29
      slither/core/expressions/new_array.py
  2. 2
      slither/slithir/convert.py
  3. 2
      slither/slithir/operations/init_array.py
  4. 22
      slither/slithir/operations/new_array.py
  5. 8
      slither/slithir/tmp_operations/tmp_new_array.py
  6. 3
      slither/slithir/utils/ssa.py
  7. 34
      slither/solc_parsing/expressions/expression_parsing.py
  8. 3
      slither/visitors/expression/expression_printer.py
  9. 2
      slither/visitors/slithir/expression_to_slithir.py
  10. 32
      tests/unit/slithir/test_ssa_generation.py

@ -1,32 +1,23 @@
from typing import Union, TYPE_CHECKING
from typing import TYPE_CHECKING
from slither.core.expressions.expression import Expression
from slither.core.solidity_types.type import Type
if TYPE_CHECKING:
from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.solidity_types.type_alias import TypeAliasTopLevel
from slither.core.solidity_types.array_type import ArrayType
class NewArray(Expression):
# note: dont conserve the size of the array if provided
def __init__(
self, depth: int, array_type: Union["TypeAliasTopLevel", "ElementaryType"]
) -> None:
def __init__(self, array_type: "ArrayType") -> None:
super().__init__()
assert isinstance(array_type, Type)
self._depth: int = depth
self._array_type: Type = array_type
# pylint: disable=import-outside-toplevel
from slither.core.solidity_types.array_type import ArrayType
@property
def array_type(self) -> Type:
return self._array_type
assert isinstance(array_type, ArrayType)
self._array_type = array_type
@property
def depth(self) -> int:
return self._depth
def array_type(self) -> "ArrayType":
return self._array_type
def __str__(self):
return "new " + str(self._array_type) + "[]" * self._depth
return "new " + str(self._array_type)

@ -1077,7 +1077,7 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]) -> Union[Call,
return op
if isinstance(ins.ori, TmpNewArray):
n = NewArray(ins.ori.depth, ins.ori.array_type, ins.lvalue)
n = NewArray(ins.ori.array_type, ins.lvalue)
n.set_expression(ins.expression)
return n

@ -41,7 +41,7 @@ class InitArray(OperationWithLValue):
def convert(elem):
if isinstance(elem, (list,)):
return str([convert(x) for x in elem])
return str(elem)
return f"{elem}({elem.type})"
init_values = convert(self.init_values)
return f"{self.lvalue}({self.lvalue.type}) = {init_values}"

@ -1,10 +1,10 @@
from typing import List, Union, TYPE_CHECKING
from slither.slithir.operations.lvalue import OperationWithLValue
from slither.core.solidity_types.array_type import ArrayType
from slither.slithir.operations.call import Call
from slither.core.solidity_types.type import Type
from slither.slithir.operations.lvalue import OperationWithLValue
if TYPE_CHECKING:
from slither.core.solidity_types.type_alias import TypeAliasTopLevel
from slither.slithir.variables.constant import Constant
from slither.slithir.variables.temporary import TemporaryVariable
from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA
@ -13,32 +13,24 @@ if TYPE_CHECKING:
class NewArray(Call, OperationWithLValue):
def __init__(
self,
depth: int,
array_type: "TypeAliasTopLevel",
array_type: "ArrayType",
lvalue: Union["TemporaryVariableSSA", "TemporaryVariable"],
) -> None:
super().__init__()
assert isinstance(array_type, Type)
self._depth = depth
assert isinstance(array_type, ArrayType)
self._array_type = array_type
self._lvalue = lvalue
@property
def array_type(self) -> "TypeAliasTopLevel":
def array_type(self) -> "ArrayType":
return self._array_type
@property
def read(self) -> List["Constant"]:
return self._unroll(self.arguments)
@property
def depth(self) -> int:
return self._depth
def __str__(self):
args = [str(a) for a in self.arguments]
lvalue = self.lvalue
return (
f"{lvalue}({lvalue.type}) = new {self.array_type}{'[]' * self.depth}({','.join(args)})"
)
return f"{lvalue}{lvalue.type}) = new {self.array_type}({','.join(args)})"

@ -6,13 +6,11 @@ from slither.slithir.variables.temporary import TemporaryVariable
class TmpNewArray(OperationWithLValue):
def __init__(
self,
depth: int,
array_type: Type,
lvalue: TemporaryVariable,
) -> None:
super().__init__()
assert isinstance(array_type, Type)
self._depth = depth
self._array_type = array_type
self._lvalue = lvalue
@ -24,9 +22,5 @@ class TmpNewArray(OperationWithLValue):
def read(self):
return []
@property
def depth(self) -> int:
return self._depth
def __str__(self):
return f"{self.lvalue} = new {self.array_type}{'[]' * self._depth}"
return f"{self.lvalue} = new {self.array_type}"

@ -789,10 +789,9 @@ def copy_ir(ir: Operation, *instances) -> Operation:
variable_right = get_variable(ir, lambda x: x.variable_right, *instances)
return Member(variable_left, variable_right, lvalue)
if isinstance(ir, NewArray):
depth = ir.depth
array_type = ir.array_type
lvalue = get_variable(ir, lambda x: x.lvalue, *instances)
new_ir = NewArray(depth, array_type, lvalue)
new_ir = NewArray(array_type, lvalue)
new_ir.arguments = get_rec_values(ir, lambda x: x.arguments, *instances)
return new_ir
if isinstance(ir, NewElementaryType):

@ -559,37 +559,9 @@ def parse_expression(expression: Dict, caller_context: CallerContextExpression)
type_name = children[0]
if type_name[caller_context.get_key()] == "ArrayTypeName":
depth = 0
while type_name[caller_context.get_key()] == "ArrayTypeName":
# Note: dont conserve the size of the array if provided
# We compute it directly
if is_compact_ast:
type_name = type_name["baseType"]
else:
type_name = type_name["children"][0]
depth += 1
if type_name[caller_context.get_key()] == "ElementaryTypeName":
if is_compact_ast:
array_type = ElementaryType(type_name["name"])
else:
array_type = ElementaryType(type_name["attributes"]["name"])
elif type_name[caller_context.get_key()] == "UserDefinedTypeName":
if is_compact_ast:
if "name" not in type_name:
name_type = type_name["pathNode"]["name"]
else:
name_type = type_name["name"]
array_type = parse_type(UnknownType(name_type), caller_context)
else:
array_type = parse_type(
UnknownType(type_name["attributes"]["name"]), caller_context
)
elif type_name[caller_context.get_key()] == "FunctionTypeName":
array_type = parse_type(type_name, caller_context)
else:
raise ParsingError(f"Incorrect type array {type_name}")
array = NewArray(depth, array_type)
array_type = parse_type(type_name, caller_context)
assert isinstance(array_type, ArrayType)
array = NewArray(array_type)
array.set_offset(src, caller_context.compilation_unit)
return array

@ -76,8 +76,7 @@ class ExpressionPrinter(ExpressionVisitor):
def _post_new_array(self, expression: expressions.NewArray) -> None:
array = str(expression.array_type)
depth = expression.depth
val = f"new {array}{'[]' * depth}"
val = f"new {array}"
set_val(expression, val)
def _post_new_contract(self, expression: expressions.NewContract) -> None:

@ -532,7 +532,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
def _post_new_array(self, expression: NewArray) -> None:
val = TemporaryVariable(self._node)
operation = TmpNewArray(expression.depth, expression.array_type, val)
operation = TmpNewArray(expression.array_type, val)
operation.set_expression(expression)
self._result.append(operation)
set_val(expression, val)

@ -1,15 +1,17 @@
# # pylint: disable=too-many-lines
import pathlib
from collections import defaultdict
from argparse import ArgumentTypeError
from collections import defaultdict
from inspect import getsourcefile
from typing import Union, List, Dict, Callable
import pytest
from solc_select.solc_select import valid_version as solc_valid_version
from slither import Slither
from slither.core.cfg.node import Node, NodeType
from slither.core.declarations import Function, Contract
from slither.core.solidity_types import ArrayType
from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations import (
@ -1050,6 +1052,34 @@ def test_issue_1748(slither_from_source):
assert isinstance(assign_op, InitArray)
def test_issue_1776(slither_from_source):
source = """
contract Contract {
function foo() public returns (uint) {
uint[5][10][] memory arr = new uint[5][10][](2);
return 0;
}
}
"""
with slither_from_source(source) as slither:
c = slither.get_contract_from_name("Contract")[0]
f = c.functions[0]
operations = f.slithir_operations
new_op = operations[0]
lvalue = new_op.lvalue
lvalue_type = lvalue.type
assert isinstance(lvalue_type, ArrayType)
assert lvalue_type.is_dynamic
lvalue_type1 = lvalue_type.type
assert isinstance(lvalue_type1, ArrayType)
assert not lvalue_type1.is_dynamic
assert lvalue_type1.length_value.value == "10"
lvalue_type2 = lvalue_type1.type
assert isinstance(lvalue_type2, ArrayType)
assert not lvalue_type2.is_dynamic
assert lvalue_type2.length_value.value == "5"
def test_issue_1846_ternary_in_if(slither_from_source):
source = """
contract Contract {

Loading…
Cancel
Save