cut down on branches and nested blocks
 by adding an `is_function_modified` helper function
pull/1699/head
webthethird 2 years ago
parent f08d2afb3e
commit ebd2201bdd
  1. 69
      slither/utils/upgradeability.py

@ -1,7 +1,28 @@
from slither.core.declarations.contract import Contract
from slither.core.declarations.function import Function
def compare(v1: Contract, v2: Contract):
# pylint: disable=too-many-locals
def compare(v1: Contract, v2: Contract) -> dict:
"""
Compares two versions of a contract. Most useful for upgradeable (logic) contracts,
but does not require that Contract.is_upgradeable returns true for either contract.
Args:
v1: Original version of (upgradeable) contract
v2: Updated version of (upgradeable) contract
Returns: dict {
"missing-vars-in-v2": list[Variable],
"new-variables": list[Variable],
"tainted-variables": list[Variable],
"new-functions": list[Function],
"modified-functions": list[Function],
"tainted-functions": list[Function]
}
"""
order_vars1 = [v for v in v1.state_variables if not v.is_constant and not v.is_immutable]
order_vars2 = [v for v in v2.state_variables if not v.is_constant and not v.is_immutable]
func_sigs1 = [function.solidity_signature for function in v1.functions]
@ -13,7 +34,7 @@ def compare(v1: Contract, v2: Contract):
"tainted-variables": [],
"new-functions": [],
"modified-functions": [],
"tainted-functions": []
"tainted-functions": [],
}
# Since this is not a detector, include any missing variables in the v2 contract
@ -26,20 +47,11 @@ def compare(v1: Contract, v2: Contract):
new_modified_functions = []
for sig in func_sigs2:
function = v2.get_function_from_signature(sig)
orig_function = v1.get_function_from_signature(sig)
if sig not in func_sigs1:
new_modified_functions.append(function)
results["new-functions"].append(function)
else:
orig_function = v1.get_function_from_signature(sig)
# If the function content hashes are the same, no need to investigate the function further
if function.source_mapping.content_hash != orig_function.source_mapping.content_hash:
# If the hashes differ, it is possible a change in a name or in a comment could be the only difference
# So we need to resort to walking through the CFG and comparing the IR operations
for i, node in enumerate(function.nodes):
if function in new_modified_functions:
break
for j, ir in enumerate(node.irs):
if ir != orig_function.nodes[i].irs[j]:
elif is_function_modified(orig_function, function):
new_modified_functions.append(function)
results["modified-functions"].append(function)
@ -47,7 +59,9 @@ def compare(v1: Contract, v2: Contract):
for function in v2.functions:
if function in new_modified_functions:
continue
modified_calls = [funct for func in new_modified_functions if func in function.internal_calls]
modified_calls = [
func for func in new_modified_functions if func in function.internal_calls
]
if len(modified_calls) > 0:
results["tainted-functions"].append(function)
@ -57,5 +71,30 @@ def compare(v1: Contract, v2: Contract):
written_by = v2.get_functions_writing_to_variable(var)
if len(order_vars1) <= idx:
results["new-variables"].append(var)
elif any([func in read_by or func in written_by for func in new_modified_functions]):
elif any(func in read_by or func in written_by for func in new_modified_functions):
results["tainted-variables"].append(var)
def is_function_modified(f1: Function, f2: Function) -> bool:
"""
Compares two versions of a function, and returns True if the function has been modified.
First checks whether the functions' content hashes are equal to quickly rule out identical functions.
Walks the CFGs and compares IR operations if hashes differ to rule out false positives, i.e., from changed comments.
Args:
f1: Original version of the function
f2: New version of the function
Returns: True if the functions differ, otherwise False
"""
# If the function content hashes are the same, no need to investigate the function further
if f1.source_mapping.content_hash == f2.source_mapping.content_hash:
return False
# If the hashes differ, it is possible a change in a name or in a comment could be the only difference
# So we need to resort to walking through the CFG and comparing the IR operations
for i, node in enumerate(f2.nodes):
for j, ir in enumerate(node.irs):
if ir != f1.nodes[i].irs[j]:
return True
return False

Loading…
Cancel
Save