"""Taint-propagation example built on Slither's SlithIR.

Starting from ``msg.sender`` and ``msg.value`` (plus every function
parameter), taint is propagated through the IR of every function until a
full pass adds no new tainted variable. The script then reports tainted
state variables and high-level calls whose destination is tainted.

Usage: python taint.py taint.sol
"""
import sys

from slither.core.declarations.solidity_variables import \
    SolidityVariableComposed
from slither.core.variables.state_variable import StateVariable
from slither.slither import Slither
from slither.slithir.operations.high_level_call import HighLevelCall
from slither.slithir.operations.index import Index
from slither.slithir.variables.reference import ReferenceVariable
from slither.slithir.variables.temporary import TemporaryVariable

# Key under which the current list of tainted variables is stored in
# ``slither.context``. Module-level so visit_node also works when this
# file is imported instead of being run as a script.
KEY = 'TAINT'


def visit_node(node, visited):
    """Propagate taint through ``node`` and, recursively, its successors.

    The taint list is read from ``node.function.slither.context[KEY]``,
    extended with every lvalue written from a tainted read, and written
    back (deduplicated, without temporaries/references) before recursing
    into ``node.sons``.

    Args:
        node: CFG node whose ``irs`` are scanned in order.
        visited: list of nodes already processed in this traversal; it is
            appended to in place so cycles in the CFG terminate.
    """
    if node in visited:
        return
    visited.append(node)

    # Work on a copy so the list stored in the context is never aliased.
    taints = list(node.function.slither.context[KEY])

    # Maps each reference produced by an Index operation in this node to
    # the variable it indexes, so tainting the reference also taints the
    # underlying base variable.
    refs = {}
    for ir in node.irs:
        if isinstance(ir, Index):
            refs[ir.lvalue] = ir.variable_left
            # Only the indexed base propagates taint here, not the index.
            read = [ir.variable_left]
        else:
            read = ir.read

        print(ir)
        print('Refs {}'.format(refs))
        print('Read {}'.format([str(x) for x in read]))
        print('Before {}'.format([str(x) for x in taints]))

        # Guard against IR operations that have no lvalue attribute or
        # whose lvalue is None: there is nothing to taint for them.
        lvalue = getattr(ir, 'lvalue', None)
        if lvalue is not None and any(var_read in taints for var_read in read):
            taints.append(lvalue)
            # Follow reference chains back to the base variable; stop at a
            # reference this node's Index operations did not create instead
            # of raising KeyError.
            while isinstance(lvalue, ReferenceVariable) and lvalue in refs:
                lvalue = refs[lvalue]
                taints.append(lvalue)

        print('After {}'.format([str(x) for x in taints]))
        print()

    # Drop temporaries and references before persisting; presumably they
    # are only meaningful within the IR just scanned.
    taints = [v for v in taints
              if not isinstance(v, (TemporaryVariable, ReferenceVariable))]
    node.function.slither.context[KEY] = list(set(taints))

    for son in node.sons:
        visit_node(son, visited)


def check_call(func, taints):
    """Print a warning for each high-level call in ``func`` to a tainted address.

    Args:
        func: function whose ``nodes`` IR is scanned.
        taints: collection of tainted variables.
    """
    for node in func.nodes:
        for ir in node.irs:
            if isinstance(ir, HighLevelCall) and ir.destination in taints:
                print('Call to tainted address found in {}'.format(func.name))


def main(argv=None):
    """Run the taint analysis on the Solidity file named in ``argv[1]``."""
    if argv is None:
        argv = sys.argv
    if len(argv) != 2:
        print('python taint.py taint.sol')
        sys.exit(-1)

    # Init slither
    slither = Slither(argv[1])

    initial_taint = [SolidityVariableComposed('msg.sender')]
    initial_taint += [SolidityVariableComposed('msg.value')]

    prev_taints = []
    slither.context[KEY] = initial_taint
    # Taint only ever grows, so repeat full passes until one adds nothing.
    while set(prev_taints) != set(slither.context[KEY]):
        prev_taints = list(slither.context[KEY])
        for contract in slither.contracts:
            for function in contract.functions:
                print('Function {}'.format(function.name))
                slither.context[KEY] = list(
                    set(slither.context[KEY] + function.parameters))
                # Functions without a body have no entry point to walk.
                if function.entry_point is not None:
                    visit_node(function.entry_point, [])
                print('All variables tainted : {}'.format(
                    [str(v) for v in slither.context[KEY]]))

    taints = slither.context[KEY]
    print('All state variables tainted : {}'.format(
        [str(v) for v in taints if isinstance(v, StateVariable)]))

    # Check every contract, not just the one left bound by the loop above.
    for contract in slither.contracts:
        for function in contract.functions:
            check_call(function, taints)


if __name__ == "__main__":
    main(sys.argv)