diff --git a/mythril/analysis/modules/state_change_external_calls.py b/mythril/analysis/modules/state_change_external_calls.py index 456b2af2..01ae8354 100644 --- a/mythril/analysis/modules/state_change_external_calls.py +++ b/mythril/analysis/modules/state_change_external_calls.py @@ -1,7 +1,7 @@ from mythril.analysis.swc_data import REENTRANCY from mythril.analysis.modules.base import DetectionModule from mythril.analysis.report import Issue -from mythril.laser.smt import symbol_factory, UGT, BitVec +from mythril.laser.smt import symbol_factory, UGT, BitVec, Or from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.annotation import StateAnnotation from mythril.analysis import solver @@ -32,12 +32,12 @@ class StateChangeCallsAnnotation(StateAnnotation): new_annotation.state_change_states = self.state_change_states[:] return new_annotation - def get_issue(self) -> Optional[Issue]: + def get_issue(self, global_state: GlobalState) -> Optional[Issue]: if not self.state_change_states: return None severity = "Medium" if self.user_defined_address else "Low" - address = self.call_state.get_current_instruction()["address"] + address = global_state.get_current_instruction()["address"] logging.debug( "[EXTERNAL_CALLS] Detected state changes at addresses: {}".format(address) ) @@ -50,15 +50,15 @@ class StateChangeCallsAnnotation(StateAnnotation): ) return Issue( - contract=self.call_state.environment.active_account.contract_name, - function_name=self.call_state.environment.active_function_name, + contract=global_state.environment.active_account.contract_name, + function_name=global_state.environment.active_function_name, address=address, title="State change after external call", severity=severity, description_head=description_head, description_tail=description_tail, swc_id=REENTRANCY, - bytecode=self.call_state.environment.code.bytecode, + bytecode=global_state.environment.code.bytecode, ) @@ -97,7 +97,10 @@ class StateChange(DetectionModule): constraints + [ UGT(gas, symbol_factory.BitVecVal(2300, 256)), - to > symbol_factory.BitVecVal(16, 256), + Or( + to > symbol_factory.BitVecVal(16, 256), + to == symbol_factory.BitVecVal(0, 256), + ), ] ) @@ -144,7 +147,7 @@ class StateChange(DetectionModule): for annotation in annotations: if not annotation.state_change_states: continue - vulnerabilities.append(annotation.get_issue()) + vulnerabilities.append(annotation.get_issue(global_state)) global_state.annotations.remove(annotation) return vulnerabilities