diff --git a/mythril/laser/ethereum/state/world_state.py b/mythril/laser/ethereum/state/world_state.py index fe3ef791..57d0879d 100644 --- a/mythril/laser/ethereum/state/world_state.py +++ b/mythril/laser/ethereum/state/world_state.py @@ -1,10 +1,10 @@ """This module contains a representation of the EVM's world state.""" from copy import copy from random import randint -from typing import Dict, List, Iterator, Optional, TYPE_CHECKING, cast +from typing import Dict, List, Iterator, Optional, Tuple, TYPE_CHECKING, cast from mythril.support.loader import DynLoader -from mythril.laser.smt import symbol_factory, Array, BitVec, If, Or +from mythril.laser.smt import symbol_factory, Array, BitVec, If, Or, And, Not, Bool from ethereum.utils import mk_contract_address from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.annotation import StateAnnotation @@ -13,6 +13,8 @@ from mythril.laser.ethereum.state.constraints import Constraints if TYPE_CHECKING: from mythril.laser.ethereum.cfg import Node +CONSTRAINT_DIFFERENCE_LIMIT = 15 + class WorldState: """The WorldState class represents the world state as described in the @@ -46,14 +48,12 @@ class WorldState: """ # Merge constraints - c1 = self.constraints.compress() - c2 = state.constraints.compress() - self.constraints = Constraints([Or(c1, c2)]) + self.constraints, condition1, _ = self._merge_constraints(state.constraints) # Merge balances - self.balances = cast(Array, If(c1, self.balances, state.balances)) + self.balances = cast(Array, If(condition1, self.balances, state.balances)) self.starting_balances = cast( - Array, If(c1, self.starting_balances, state.starting_balances) + Array, If(condition1, self.starting_balances, state.starting_balances) ) # Merge accounts @@ -61,7 +61,9 @@ class WorldState: if address not in self._accounts: self.put_account(account) else: - self._accounts[address].merge_accounts(account, c1, self.balances) + self._accounts[address].merge_accounts( + account, condition1, self.balances + ) # Merge annotations self._merge_annotations(state) @@ -69,6 +71,51 @@ class WorldState: # Merge Node self.node.merge_nodes(state.node, self.constraints) + def _merge_constraints( + self, constraints: Constraints + ) -> Tuple[Constraints, Bool, Bool]: + dict1, dict2 = {}, {} + for constraint in self.constraints: + dict1[constraint] = True + for constraint in constraints: + dict2[constraint] = True + c1, c2 = symbol_factory.Bool(True), symbol_factory.Bool(True) + new_constraint1, new_constraint2 = symbol_factory.Bool(True), symbol_factory.Bool(True) + same_constraints = Constraints() + for key in dict1: + if key not in dict2: + c1 = And(c1, key) + if Not(key) not in dict2: + new_constraint1 = And(new_constraint1, key) + else: + same_constraints.append(key) + for key in dict2: + if key not in dict1: + c2 = And(c2, key) + if Not(key) not in dict1: + new_constraint2 = And(new_constraint2, key) + else: + same_constraints.append(key) + merge_constraints = same_constraints + [Or(new_constraint1, new_constraint2)] + return merge_constraints, c1, c2 + + def _check_constraint_merge(self, constraints: Constraints) -> bool: + dict1, dict2 = {}, {} + for constraint in self.constraints: + dict1[constraint] = True + for constraint in constraints: + dict2[constraint] = True + c1, c2 = 0, 0 + for key in dict1: + if key not in dict2 and Not(key) not in dict2: + c1 += 1 + for key in dict2: + if key not in dict1 and Not(key) not in dict1: + c2 += 1 + if c1 + c2 > CONSTRAINT_DIFFERENCE_LIMIT: + return False + return True + def _check_merge_annotations(self, state: "WorldState"): """ @@ -77,6 +124,8 @@ class WorldState: """ if len(state.annotations) != len(self._annotations): return False + if self._check_constraint_merge(state.constraints) is False: + return False for v1, v2 in zip(state.annotations, self._annotations): if v1.check_merge_annotation(v2) is False: # type: ignore return False