Merge pull request #1601 from crytic/dev-api-improvements

Minor API improvements
pull/1691/head
Feist Josselin 2 years ago committed by GitHub
commit 540985d4bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 40
      slither/core/source_mapping/source_mapping.py
  2. 6
      slither/detectors/functions/codex.py
  3. 18
      slither/printers/summary/constructor_calls.py
  4. 6
      slither/tools/flattening/flattening.py
  5. 55
      slither/utils/arithmetic.py
  6. 7
      slither/utils/standard_libraries.py
  7. 29
      tests/arithmetic_usage/test.sol
  8. 10
      tests/test_features.py

@ -2,8 +2,8 @@ import re
from abc import ABCMeta
from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional
from Crypto.Hash import SHA1
from crytic_compile.utils.naming import Filename
from slither.core.context.context import Context
if TYPE_CHECKING:
@ -56,7 +56,7 @@ class Source:
filename_short: str = self.filename.short if self.filename.short else ""
return f"{filename_short}{lines} ({self.starting_column} - {self.ending_column})"
def _get_lines_str(self, line_descr=""):
def _get_lines_str(self, line_descr: str = "") -> str:
# If the compilation unit was not initialized, it means that the set_offset was never called
# on the corresponding object, which should not happen
@ -66,12 +66,36 @@ class Source:
lines = self.lines
if not lines:
lines = ""
elif len(lines) == 1:
lines = f"{line_prefix}{line_descr}{lines[0]}"
else:
lines = f"{line_prefix}{line_descr}{lines[0]}-{line_descr}{lines[-1]}"
return lines
return ""
if len(lines) == 1:
return f"{line_prefix}{line_descr}{lines[0]}"
return f"{line_prefix}{line_descr}{lines[0]}-{line_descr}{lines[-1]}"
@property
def content(self) -> str:
"""
Return the txt content of the Source
Returns:
"""
# If the compilation unit was not initialized, it means that the set_offset was never called
# on the corresponding object, which should not happen
assert self.compilation_unit
return self.compilation_unit.core.source_code[self.filename.absolute][self.start : self.end]
@property
def content_hash(self) -> str:
"""
Return sha1(self.content)
Returns:
"""
h = SHA1.new()
h.update(self.content.encode("utf8"))
return h.hexdigest()
def __str__(self) -> str:
lines = self._get_lines_str()

@ -118,11 +118,7 @@ class Codex(AbstractDetector):
):
continue
prompt = f"Analyze this Solidity contract and find the vulnerabilities. If you find any vulnerabilities, begin the response with {VULN_FOUND}\n"
src_mapping = contract.source_mapping
content = contract.compilation_unit.core.source_code[src_mapping.filename.absolute]
start = src_mapping.start
end = src_mapping.start + src_mapping.length
prompt += content[start:end]
prompt += contract.source_mapping.content
answer = self._run_codex(logging_file, prompt)

@ -1,24 +1,22 @@
"""
Module printing summary of the contract
"""
from slither.core.declarations import Function
from slither.core.source_mapping.source_mapping import Source
from slither.printers.abstract_printer import AbstractPrinter
from slither.utils import output
def _get_source_code(cst: Function) -> str:
src_mapping: Source = cst.source_mapping
return " " * src_mapping.starting_column + src_mapping.content
class ConstructorPrinter(AbstractPrinter):
WIKI = "https://github.com/crytic/slither/wiki/Printer-documentation#constructor-calls"
ARGUMENT = "constructor-calls"
HELP = "Print the constructors executed"
def _get_soruce_code(self, cst):
src_mapping: Source = cst.source_mapping
content = self.slither.source_code[src_mapping.filename.absolute]
start = src_mapping.start
end = src_mapping.start + src_mapping.length
initial_space = src_mapping.starting_column
return " " * initial_space + content[start:end]
def output(self, _filename):
info = ""
for contract in self.slither.contracts_derived:
@ -27,12 +25,12 @@ class ConstructorPrinter(AbstractPrinter):
cst = contract.constructors_declared
if cst:
stack_name.append(contract.name)
stack_definition.append(self._get_soruce_code(cst))
stack_definition.append(_get_source_code(cst))
for inherited_contract in contract.inheritance:
cst = inherited_contract.constructors_declared
if cst:
stack_name.append(inherited_contract.name)
stack_definition.append(self._get_soruce_code(cst))
stack_definition.append(_get_source_code(cst))
if len(stack_name) > 0:

@ -80,11 +80,7 @@ class Flattening:
def _get_source_code_top_level(self, elems: Sequence[TopLevel]) -> None:
for elem in elems:
src_mapping = elem.source_mapping
content = self._compilation_unit.core.source_code[src_mapping.filename.absolute]
start = src_mapping.start
end = src_mapping.start + src_mapping.length
self._source_codes_top_level[elem] = content[start:end]
self._source_codes_top_level[elem] = elem.source_mapping.content
def _check_abi_encoder_v2(self):
"""

@ -1,6 +1,12 @@
from typing import List, TYPE_CHECKING
from slither.exceptions import SlitherException
from slither.utils.integer_conversion import convert_string_to_fraction
if TYPE_CHECKING:
from slither.core.declarations import Contract, Function
# pylint: disable=too-many-branches
def convert_subdenomination(
value: str, sub: str
@ -31,3 +37,52 @@ def convert_subdenomination(
return int(decimal_value * 60 * 60 * 24 * 7 * 365)
raise SlitherException(f"Subdemonination conversion impossible {decimal_value} {sub}")
# Number of unchecked arithmetic operation needed to be interesting
THRESHOLD_ARITHMETIC_USAGE = 3
def _unchecked_arithemtic_usage(function: "Function") -> bool:
"""
Check if the function has more than THRESHOLD_ARITHMETIC_USAGE unchecked arithmetic operation
Args:
function:
Returns:
"""
# pylint: disable=import-outside-toplevel
from slither.slithir.operations import Binary
score = 0
for node in function.nodes:
if not node.scope.is_checked:
for ir in node.irs:
if isinstance(ir, Binary):
score += 1
if score >= THRESHOLD_ARITHMETIC_USAGE:
return True
return False
def unchecked_arithemtic_usage(contract: "Contract") -> List["Function"]:
"""
Return the list of function with some unchecked arithmetics
Args:
contract:
Returns:
"""
# pylint: disable=import-outside-toplevel
from slither.core.declarations import Function
ret: List[Function] = []
for function in contract.all_functions_called:
if isinstance(function, Function) and _unchecked_arithemtic_usage(function):
ret.append(function)
return ret

@ -68,12 +68,7 @@ def is_openzeppelin(contract: "Contract") -> bool:
def is_openzeppelin_strict(contract: "Contract") -> bool:
start = contract.source_mapping.start
end = start + contract.source_mapping.length
source_code = contract.compilation_unit.core.source_code[
contract.source_mapping.filename.absolute
][start:end]
source_hash = sha1(source_code.encode("utf-8")).hexdigest()
source_hash = sha1(contract.source_mapping.content.encode("utf-8")).hexdigest()
return source_hash in oz_hashes

@ -0,0 +1,29 @@
function protected(uint a, uint b) returns(uint){
return (a + b) * (a + b);
}
function not_protected_asm(uint a, uint b) returns(uint){
uint c;
assembly{
c := mul(add(a,b), add(a,b))
}
return c;
}
function not_protected_unchecked(uint a, uint b) returns(uint){
uint c;
unchecked{
return (a + b) * (a + b);
}
}
contract A{
function f(uint a, uint b) public{
protected(a,b);
not_protected_asm(a, b);
not_protected_unchecked(a, b);
}
}

@ -9,6 +9,7 @@ from slither.core.variables.state_variable import StateVariable
from slither.detectors import all_detectors
from slither.detectors.abstract_detector import AbstractDetector
from slither.slithir.operations import LibraryCall, InternalCall
from slither.utils.arithmetic import unchecked_arithemtic_usage
def _run_all_detectors(slither: Slither) -> None:
@ -150,3 +151,12 @@ def test_private_variable() -> None:
var_read = f.variables_read[0]
assert isinstance(var_read, StateVariable)
assert str(var_read.contract) == "B"
def test_arithmetic_usage() -> None:
solc_select.switch_global_version("0.8.15", always_install=True)
slither = Slither("./tests/arithmetic_usage/test.sol")
assert {
f.source_mapping.content_hash for f in unchecked_arithemtic_usage(slither.contracts[0])
} == {"2b4bc73cf59d486dd9043e840b5028b679354dd9", "e4ecd4d0fda7e762d29aceb8425f2c5d4d0bf962"}

Loading…
Cancel
Save