Add types to function signatures

pull/1757/head
webthethird 2 years ago
parent a292d8de79
commit 78e2ea37da
  1. 50
      slither/utils/upgradeability.py

@ -1,4 +1,4 @@
from typing import Optional, Tuple, List from typing import Optional, Tuple, List, Union
from slither.core.declarations import ( from slither.core.declarations import (
Contract, Contract,
Structure, Structure,
@ -8,6 +8,7 @@ from slither.core.declarations import (
Function, Function,
) )
from slither.core.solidity_types import ( from slither.core.solidity_types import (
Type,
ElementaryType, ElementaryType,
ArrayType, ArrayType,
MappingType, MappingType,
@ -24,6 +25,7 @@ from slither.core.expressions.call_expression import CallExpression
from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.expressions.assignment_operation import AssignmentOperation
from slither.core.cfg.node import Node, NodeType from slither.core.cfg.node import Node, NodeType
from slither.slithir.operations import ( from slither.slithir.operations import (
Operation,
Assignment, Assignment,
Index, Index,
Member, Member,
@ -91,18 +93,16 @@ def compare(
func_sigs1 = [function.solidity_signature for function in v1.functions] func_sigs1 = [function.solidity_signature for function in v1.functions]
func_sigs2 = [function.solidity_signature for function in v2.functions] func_sigs2 = [function.solidity_signature for function in v2.functions]
results = { missing_vars_in_v2 = []
"missing-vars-in-v2": [], new_variables = []
"new-variables": [], tainted_variables = []
"tainted-variables": [], new_functions = []
"new-functions": [], modified_functions = []
"modified-functions": [], tainted_functions = []
"tainted-functions": [],
}
# Since this is not a detector, include any missing variables in the v2 contract # Since this is not a detector, include any missing variables in the v2 contract
if len(order_vars2) < len(order_vars1): if len(order_vars2) < len(order_vars1):
results["missing-vars-in-v2"].extend(get_missing_vars(v1, v2)) missing_vars_in_v2.extend(get_missing_vars(v1, v2))
# Find all new and modified functions in the v2 contract # Find all new and modified functions in the v2 contract
new_modified_functions = [] new_modified_functions = []
@ -112,7 +112,7 @@ def compare(
orig_function = v1.get_function_from_signature(sig) orig_function = v1.get_function_from_signature(sig)
if sig not in func_sigs1: if sig not in func_sigs1:
new_modified_functions.append(function) new_modified_functions.append(function)
results["new-functions"].append(function) new_functions.append(function)
new_modified_function_vars += ( new_modified_function_vars += (
function.state_variables_read + function.state_variables_written function.state_variables_read + function.state_variables_written
) )
@ -120,7 +120,7 @@ def compare(
orig_function, function orig_function, function
): ):
new_modified_functions.append(function) new_modified_functions.append(function)
results["modified-functions"].append(function) modified_functions.append(function)
new_modified_function_vars += ( new_modified_function_vars += (
function.state_variables_read + function.state_variables_written function.state_variables_read + function.state_variables_written
) )
@ -145,27 +145,27 @@ def compare(
and not var.is_immutable and not var.is_immutable
] ]
if len(modified_calls) > 0 or len(tainted_vars) > 0: if len(modified_calls) > 0 or len(tainted_vars) > 0:
results["tainted-functions"].append(function) tainted_functions.append(function)
# Find all new or tainted variables, i.e., variables that are read or written by a new/modified/tainted function # Find all new or tainted variables, i.e., variables that are read or written by a new/modified/tainted function
for var in order_vars2: for var in order_vars2:
read_by = v2.get_functions_reading_from_variable(var) read_by = v2.get_functions_reading_from_variable(var)
written_by = v2.get_functions_writing_to_variable(var) written_by = v2.get_functions_writing_to_variable(var)
if v1.get_state_variable_from_name(var.name) is None: if v1.get_state_variable_from_name(var.name) is None:
results["new-variables"].append(var) new_variables.append(var)
elif any( elif any(
func in read_by or func in written_by func in read_by or func in written_by
for func in new_modified_functions + results["tainted-functions"] for func in new_modified_functions + tainted_functions
): ):
results["tainted-variables"].append(var) tainted_variables.append(var)
return ( return (
results["missing-vars-in-v2"], missing_vars_in_v2,
results["new-variables"], new_variables,
results["tainted-variables"], tainted_variables,
results["new-functions"], new_functions,
results["modified-functions"], modified_functions,
results["tainted-functions"], tainted_functions,
) )
@ -226,7 +226,7 @@ def is_function_modified(f1: Function, f2: Function) -> bool:
return False return False
def ntype(_type): # pylint: disable=too-many-branches def ntype(_type: Union[Type, str]) -> str: # pylint: disable=too-many-branches
if isinstance(_type, ElementaryType): if isinstance(_type, ElementaryType):
_type = str(_type) _type = str(_type)
elif isinstance(_type, ArrayType): elif isinstance(_type, ArrayType):
@ -261,7 +261,9 @@ def ntype(_type): # pylint: disable=too-many-branches
return _type.replace(" ", "_") return _type.replace(" ", "_")
def encode_ir_for_compare(ir) -> str: # pylint: disable=too-many-branches def encode_ir_for_compare(
ir: Union[Operation, Variable]
) -> str: # pylint: disable=too-many-branches
# operations # operations
if isinstance(ir, Assignment): if isinstance(ir, Assignment):
return f"({encode_ir_for_compare(ir.lvalue)}):=({encode_ir_for_compare(ir.rvalue)})" return f"({encode_ir_for_compare(ir.lvalue)}):=({encode_ir_for_compare(ir.rvalue)})"

Loading…
Cancel
Save