From ea8d232f09552ba4f1837d1f34a1d08d74bc81e4 Mon Sep 17 00:00:00 2001 From: Feist Josselin Date: Fri, 13 Jan 2023 15:17:04 +0100 Subject: [PATCH] Minor API improvements --- slither/core/source_mapping/source_mapping.py | 40 +++++++++++--- slither/detectors/functions/codex.py | 6 +- slither/printers/summary/constructor_calls.py | 18 +++--- slither/tools/flattening/flattening.py | 6 +- slither/utils/arithmetic.py | 55 +++++++++++++++++++ slither/utils/standard_libraries.py | 7 +-- tests/arithmetic_usage/test.sol | 29 ++++++++++ tests/test_features.py | 10 ++++ 8 files changed, 137 insertions(+), 34 deletions(-) create mode 100644 tests/arithmetic_usage/test.sol diff --git a/slither/core/source_mapping/source_mapping.py b/slither/core/source_mapping/source_mapping.py index ee5211c7c..a0fcf354a 100644 --- a/slither/core/source_mapping/source_mapping.py +++ b/slither/core/source_mapping/source_mapping.py @@ -2,8 +2,8 @@ import re from abc import ABCMeta from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional +from Crypto.Hash import SHA1 from crytic_compile.utils.naming import Filename - from slither.core.context.context import Context if TYPE_CHECKING: @@ -56,7 +56,7 @@ class Source: filename_short: str = self.filename.short if self.filename.short else "" return f"{filename_short}{lines} ({self.starting_column} - {self.ending_column})" - def _get_lines_str(self, line_descr=""): + def _get_lines_str(self, line_descr: str = "") -> str: # If the compilation unit was not initialized, it means that the set_offset was never called # on the corresponding object, which should not happen @@ -66,12 +66,36 @@ class Source: lines = self.lines if not lines: - lines = "" - elif len(lines) == 1: - lines = f"{line_prefix}{line_descr}{lines[0]}" - else: - lines = f"{line_prefix}{line_descr}{lines[0]}-{line_descr}{lines[-1]}" - return lines + return "" + if len(lines) == 1: + return f"{line_prefix}{line_descr}{lines[0]}" + + return f"{line_prefix}{line_descr}{lines[0]}-{line_descr}{lines[-1]}" + + @property + def content(self) -> str: + """ + Return the txt content of the Source + + Returns: + + """ + # If the compilation unit was not initialized, it means that the set_offset was never called + # on the corresponding object, which should not happen + assert self.compilation_unit + return self.compilation_unit.core.source_code[self.filename.absolute][self.start : self.end] + + @property + def content_hash(self) -> str: + """ + Return sha1(self.content) + + Returns: + + """ + h = SHA1.new() + h.update(self.content.encode("utf8")) + return h.hexdigest() def __str__(self) -> str: lines = self._get_lines_str() diff --git a/slither/detectors/functions/codex.py b/slither/detectors/functions/codex.py index fb00f64c0..48e1ffaa5 100644 --- a/slither/detectors/functions/codex.py +++ b/slither/detectors/functions/codex.py @@ -114,11 +114,7 @@ class Codex(AbstractDetector): ): continue prompt = f"Analyze this Solidity contract and find the vulnerabilities. If you find any vulnerabilities, begin the response with {VULN_FOUND}\n" - src_mapping = contract.source_mapping - content = contract.compilation_unit.core.source_code[src_mapping.filename.absolute] - start = src_mapping.start - end = src_mapping.start + src_mapping.length - prompt += content[start:end] + prompt += contract.source_mapping.content answer = self._run_codex(logging_file, prompt) diff --git a/slither/printers/summary/constructor_calls.py b/slither/printers/summary/constructor_calls.py index 4593800b9..665c76546 100644 --- a/slither/printers/summary/constructor_calls.py +++ b/slither/printers/summary/constructor_calls.py @@ -1,24 +1,22 @@ """ Module printing summary of the contract """ +from slither.core.declarations import Function from slither.core.source_mapping.source_mapping import Source from slither.printers.abstract_printer import AbstractPrinter from slither.utils import output +def _get_source_code(cst: Function) -> str: + src_mapping: Source = cst.source_mapping + return " " * src_mapping.starting_column + src_mapping.content + + class ConstructorPrinter(AbstractPrinter): WIKI = "https://github.com/crytic/slither/wiki/Printer-documentation#constructor-calls" ARGUMENT = "constructor-calls" HELP = "Print the constructors executed" - def _get_soruce_code(self, cst): - src_mapping: Source = cst.source_mapping - content = self.slither.source_code[src_mapping.filename.absolute] - start = src_mapping.start - end = src_mapping.start + src_mapping.length - initial_space = src_mapping.starting_column - return " " * initial_space + content[start:end] - def output(self, _filename): info = "" for contract in self.slither.contracts_derived: @@ -27,12 +25,12 @@ class ConstructorPrinter(AbstractPrinter): cst = contract.constructors_declared if cst: stack_name.append(contract.name) - stack_definition.append(self._get_soruce_code(cst)) + stack_definition.append(_get_source_code(cst)) for inherited_contract in contract.inheritance: cst = inherited_contract.constructors_declared if cst: stack_name.append(inherited_contract.name) - stack_definition.append(self._get_soruce_code(cst)) + stack_definition.append(_get_source_code(cst)) if len(stack_name) > 0: diff --git a/slither/tools/flattening/flattening.py b/slither/tools/flattening/flattening.py index 74a93ba2d..67b3c00a3 100644 --- a/slither/tools/flattening/flattening.py +++ b/slither/tools/flattening/flattening.py @@ -80,11 +80,7 @@ class Flattening: def _get_source_code_top_level(self, elems: Sequence[TopLevel]) -> None: for elem in elems: - src_mapping = elem.source_mapping - content = self._compilation_unit.core.source_code[src_mapping.filename.absolute] - start = src_mapping.start - end = src_mapping.start + src_mapping.length - self._source_codes_top_level[elem] = content[start:end] + self._source_codes_top_level[elem] = elem.source_mapping.content def _check_abi_encoder_v2(self): """ diff --git a/slither/utils/arithmetic.py b/slither/utils/arithmetic.py index 18ffaec51..0296231af 100644 --- a/slither/utils/arithmetic.py +++ b/slither/utils/arithmetic.py @@ -1,6 +1,12 @@ +from typing import List, TYPE_CHECKING + from slither.exceptions import SlitherException from slither.utils.integer_conversion import convert_string_to_fraction + +if TYPE_CHECKING: + from slither.core.declarations import Contract, Function + # pylint: disable=too-many-branches def convert_subdenomination( value: str, sub: str @@ -31,3 +37,52 @@ def convert_subdenomination( return int(decimal_value * 60 * 60 * 24 * 7 * 365) raise SlitherException(f"Subdemonination conversion impossible {decimal_value} {sub}") + + +# Number of unchecked arithmetic operation needed to be interesting +THRESHOLD_ARITHMETIC_USAGE = 3 + + +def _unchecked_arithemtic_usage(function: "Function") -> bool: + """ + Check if the function has more than THRESHOLD_ARITHMETIC_USAGE unchecked arithmetic operation + + Args: + function: + + Returns: + + """ + + # pylint: disable=import-outside-toplevel + from slither.slithir.operations import Binary + + score = 0 + for node in function.nodes: + if not node.scope.is_checked: + for ir in node.irs: + if isinstance(ir, Binary): + score += 1 + if score >= THRESHOLD_ARITHMETIC_USAGE: + return True + return False + + +def unchecked_arithemtic_usage(contract: "Contract") -> List["Function"]: + """ + Return the list of function with some unchecked arithmetics + + Args: + contract: + + Returns: + + """ + # pylint: disable=import-outside-toplevel + from slither.core.declarations import Function + + ret: List[Function] = [] + for function in contract.all_functions_called: + if isinstance(function, Function) and _unchecked_arithemtic_usage(function): + ret.append(function) + return ret diff --git a/slither/utils/standard_libraries.py b/slither/utils/standard_libraries.py index 897954b95..525771187 100644 --- a/slither/utils/standard_libraries.py +++ b/slither/utils/standard_libraries.py @@ -68,12 +68,7 @@ def is_openzeppelin(contract: "Contract") -> bool: def is_openzeppelin_strict(contract: "Contract") -> bool: - start = contract.source_mapping.start - end = start + contract.source_mapping.length - source_code = contract.compilation_unit.core.source_code[ - contract.source_mapping.filename.absolute - ][start:end] - source_hash = sha1(source_code.encode("utf-8")).hexdigest() + source_hash = sha1(contract.source_mapping.content.encode("utf-8")).hexdigest() return source_hash in oz_hashes diff --git a/tests/arithmetic_usage/test.sol b/tests/arithmetic_usage/test.sol new file mode 100644 index 000000000..7bda7e504 --- /dev/null +++ b/tests/arithmetic_usage/test.sol @@ -0,0 +1,29 @@ +function protected(uint a, uint b) returns(uint){ + return (a + b) * (a + b); +} + +function not_protected_asm(uint a, uint b) returns(uint){ + uint c; + assembly{ + c := mul(add(a,b), add(a,b)) + } + return c; +} + +function not_protected_unchecked(uint a, uint b) returns(uint){ + uint c; + unchecked{ + return (a + b) * (a + b); + } + +} + +contract A{ + + function f(uint a, uint b) public{ + protected(a,b); + not_protected_asm(a, b); + not_protected_unchecked(a, b); + } + +} \ No newline at end of file diff --git a/tests/test_features.py b/tests/test_features.py index 0303a5760..d29a5eb6a 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -9,6 +9,7 @@ from slither.core.variables.state_variable import StateVariable from slither.detectors import all_detectors from slither.detectors.abstract_detector import AbstractDetector from slither.slithir.operations import LibraryCall, InternalCall +from slither.utils.arithmetic import unchecked_arithemtic_usage def _run_all_detectors(slither: Slither) -> None: @@ -150,3 +151,12 @@ def test_private_variable() -> None: var_read = f.variables_read[0] assert isinstance(var_read, StateVariable) assert str(var_read.contract) == "B" + + +def test_arithmetic_usage() -> None: + solc_select.switch_global_version("0.8.15", always_install=True) + slither = Slither("./tests/arithmetic_usage/test.sol") + + assert { + f.source_mapping.content_hash for f in unchecked_arithemtic_usage(slither.contracts[0]) + } == {"2b4bc73cf59d486dd9043e840b5028b679354dd9", "e4ecd4d0fda7e762d29aceb8425f2c5d4d0bf962"}