Fix loop bounds (#1336)

* new algo for creating loop counter keys

* Good old black

* Revert to debug log level

* Fix type annotation

* Minor modification of checksum generation

* Greater trace length and some black

* Count continous loops

Co-authored-by: Bernhard Mueller <b-mueller@users.noreply.github.com>
pull/1338/head
Nikhil Parasaram 5 years ago committed by GitHub
parent 0589cdbc76
commit fa3207c5b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 73
      mythril/laser/ethereum/strategy/extensions/bounded_loops.py
  2. 2
      tests/instructions/create2_test.py

@ -14,11 +14,13 @@ class JumpdestCountAnnotation(StateAnnotation):
"""State annotation that counts the number of jumps per destination.""" """State annotation that counts the number of jumps per destination."""
def __init__(self) -> None: def __init__(self) -> None:
self._reached_count = {} # type: Dict[str, int] self._reached_count = {} # type: Dict[int, int]
self.trace = [] # type: List[int]
def __copy__(self): def __copy__(self):
result = JumpdestCountAnnotation() result = JumpdestCountAnnotation()
result._reached_count = copy(self._reached_count) result._reached_count = copy(self._reached_count)
result.trace = copy(self.trace)
return result return result
@ -43,6 +45,39 @@ class BoundedLoopsStrategy(BasicSearchStrategy):
self, super_strategy.work_list, super_strategy.max_depth self, super_strategy.work_list, super_strategy.max_depth
) )
def calculate_hash(self, i, j, trace):
"""
calculate hash(trace[i: j])
:param i:
:param j:
:param trace:
:return: hash(trace[i: j])
"""
key = 0
size = 0
for itr in range(i, j):
key |= trace[itr] << ((itr - i) * 8)
size += 1
return key
def count_key(self, trace, key, start, size):
"""
Count continuous loops in the trace.
:param trace:
:param key:
:param size:
:return:
"""
count = 0
i = start
while i >= 0:
if self.calculate_hash(i, i + size, trace) != key:
break
count += 1
i -= size
return count
def get_strategic_global_state(self) -> GlobalState: def get_strategic_global_state(self) -> GlobalState:
""" Returns the next state """ Returns the next state
@ -66,34 +101,40 @@ class BoundedLoopsStrategy(BasicSearchStrategy):
cur_instr = state.get_current_instruction() cur_instr = state.get_current_instruction()
if ( annotation.trace.append(cur_instr["address"])
cur_instr["opcode"].upper() != "JUMPDEST"
or state.environment.code.instruction_list[state.mstate.prev_pc][ if cur_instr["opcode"].upper() != "JUMPDEST":
"opcode"
]
!= "JUMPI"
):
return state return state
# create unique instruction identifier # create unique instruction identifier
key = "{};{};{}".format(
cur_instr["opcode"], cur_instr["address"], state.mstate.prev_pc
)
if key in annotation._reached_count: found = False
annotation._reached_count[key] += 1 for i in range(len(annotation.trace) - 3, 0, -1):
if (
annotation.trace[i] == annotation.trace[-2]
and annotation.trace[i + 1] == annotation.trace[-1]
):
found = True
break
if found:
key = self.calculate_hash(
i, len(annotation.trace) - 1, annotation.trace
)
size = len(annotation.trace) - i - 1
count = self.count_key(annotation.trace, key, i, size)
else: else:
annotation._reached_count[key] = 1 count = 0
# The creation transaction gets a higher loop bound to give it a better chance of success. # The creation transaction gets a higher loop bound to give it a better chance of success.
# TODO: There's probably a nicer way to do this # TODO: There's probably a nicer way to do this
if isinstance( if isinstance(
state.current_transaction, ContractCreationTransaction state.current_transaction, ContractCreationTransaction
) and annotation._reached_count[key] < max(8, self.bound): ) and count < max(8, self.bound):
return state return state
elif annotation._reached_count[key] > self.bound: elif count > self.bound:
log.debug("Loop bound reached, skipping state") log.debug("Loop bound reached, skipping state")
continue continue

@ -24,7 +24,7 @@ def generate_salted_address(code_str, salt, caller):
salt = "0" * (64 - len(salt)) + salt salt = "0" * (64 - len(salt)) + salt
contract_address = int( contract_address = int(
get_code_hash("0xff" + addr + salt + get_code_hash(code_str)[2:])[26:], 16, get_code_hash("0xff" + addr + salt + get_code_hash(code_str)[2:])[26:], 16
) )
return contract_address return contract_address

Loading…
Cancel
Save