diff --git a/mythril/laser/ethereum/plugins/implementations/plugin_annotations.py b/mythril/laser/ethereum/plugins/implementations/plugin_annotations.py index 209e562e..b739debc 100644 --- a/mythril/laser/ethereum/plugins/implementations/plugin_annotations.py +++ b/mythril/laser/ethereum/plugins/implementations/plugin_annotations.py @@ -1,4 +1,5 @@ from mythril.laser.ethereum.state.annotation import StateAnnotation +from mythril.laser.smt import If from copy import copy from typing import Dict, List, Set @@ -48,6 +49,20 @@ class DependencyAnnotation(StateAnnotation): elif value not in self.storage_written[iteration]: self.storage_written[iteration].append(value) + def check_merge_annotation(self, other): + return self.has_call == other.has_call and self.path == other.path + + def merge_annotation(self, other: "DependencyAnnotation"): + self.blocks_seen = self.blocks_seen.union(other.blocks_seen) + for v in other.storage_loaded: + if v not in self.storage_loaded: + self.storage_loaded.append(v) + for key, val in other.storage_written.items(): + if key not in self.storage_written: + self.storage_written[key] = val + elif self.storage_written[key] != val: + self.storage_written[key] += val + class WSDependencyAnnotation(StateAnnotation): """Dependency Annotation for World state @@ -63,3 +78,15 @@ class WSDependencyAnnotation(StateAnnotation): result = WSDependencyAnnotation() result.annotations_stack = copy(self.annotations_stack) return result + + def check_merge_annotation(self, other): + if len(self.annotations_stack) != len(other.annotations_stack): + return False + for a1, a2 in zip(self.annotations_stack, other.annotations_stack): + if a1.check_merge_annotation(a2) is False: + return False + return True + + def merge_annotations(self, other): + for a1, a2 in zip(self.annotations_stack, other.annotations_stack): + a1.merge_annotation(a2) diff --git a/mythril/laser/ethereum/plugins/implementations/state_merge.py b/mythril/laser/ethereum/plugins/implementations/state_merge.py index 1b6d0212..ee7d153e 100644 --- a/mythril/laser/ethereum/plugins/implementations/state_merge.py +++ b/mythril/laser/ethereum/plugins/implementations/state_merge.py @@ -1,5 +1,5 @@ from copy import copy -from typing import List +from typing import Dict, List from mythril.laser.ethereum.svm import LaserEVM from mythril.laser.ethereum.plugins.plugin import LaserPlugin from mythril.laser.smt import symbol_factory, simplify, Or @@ -33,22 +33,29 @@ class StateMerge(LaserPlugin): old_states = copy(open_states) while old_size != len(new_states): old_size = len(new_states) - i = 0 new_states = [] - while i < len(old_states) - 1: - if self.check_merge_condition(old_states[i], old_states[i + 1]): - new_states.append( - self.merge_states(old_states[i], old_states[i + 1]) - ) - i += 2 + merged_dict = {} # type: Dict[int, bool] + for i in range(len(old_states)): + if merged_dict.get(i, False): continue - else: + i_is_merged = False + for j in range(i + 1, len(old_states)): + if merged_dict.get(j, False) or not self.check_merge_condition( + old_states[i], old_states[j] + ): + j += 1 + continue + state = old_states[i] + self.merge_states(state, old_states[j]) + merged_dict[j] = True + new_states.append(state) + i_is_merged = True + break + + if i_is_merged is False: new_states.append(old_states[i]) - i += 1 - if i == len(old_states) - 1: - new_states.append(old_states[i]) - old_states = copy(new_states) + old_states = copy(new_states) logging.info( "States reduced from {} to {}".format(num_old_states, len(new_states)) ) @@ -67,5 +74,11 @@ class StateMerge(LaserPlugin): @staticmethod def merge_states(state1: WorldState, state2: WorldState) -> WorldState: + """ + Merge state2 into state1 + :param state1: The state to be merged into + :param state2: The state which is merged into state1 + :return: + """ state1.merge_states(state2) return state1 diff --git a/mythril/laser/ethereum/state/world_state.py b/mythril/laser/ethereum/state/world_state.py index e0179ffc..01471227 100644 --- a/mythril/laser/ethereum/state/world_state.py +++ b/mythril/laser/ethereum/state/world_state.py @@ -44,8 +44,6 @@ class WorldState: :param state: The state to be merged with :return: """ - # combine annotations - self._annotations += state._annotations # Merge constraints c1 = self.constraints.compress() @@ -65,10 +63,35 @@ class WorldState: else: self._accounts[address].merge_accounts(account, c1, self.balances) + # Merge annotations + self._merge_annotations(state) + # Merge Node self.node.merge_nodes(state.node, self.constraints) - def check_merge_condition(self, state): + def _check_merge_annotations(self, state: "WorldState"): + """ + + :param state: + :return: + """ + if len(state.annotations) != len(self._annotations): + return False + for v1, v2 in zip(state.annotations, self._annotations): + if v1.check_merge_annotation(v2) is False: # type: ignore + return False + return True + + def _merge_annotations(self, state: "WorldState"): + """ + + :param state: + :return: + """ + for v1, v2 in zip(state.annotations, self._annotations): + v1.merge_annotations(v2) # type: ignore + + def check_merge_condition(self, state: "WorldState"): """ Checks whether we can merge this state with "state" or not :param state: The state to check the merge-ability with @@ -76,12 +99,16 @@ class WorldState: """ if self.node and not self.node.check_merge_condition(state.node): return False + for address, account in state.accounts.items(): if ( address in self._accounts and self._accounts[address].check_merge_condition(account) is False ): return False + if not self._check_merge_annotations(state): + return False + return True @property