diff --git a/scripts/travis_test.sh b/scripts/travis_test.sh index 5e70083b1..e60fa9e71 100755 --- a/scripts/travis_test.sh +++ b/scripts/travis_test.sh @@ -56,6 +56,7 @@ test_slither(){ fi } + test_slither tests/uninitialized.sol "uninitialized-state" test_slither tests/backdoor.sol "backdoor" test_slither tests/backdoor.sol "suicidal" @@ -73,6 +74,7 @@ test_slither tests/low_level_calls.sol "low-level-calls" test_slither tests/const_state_variables.sol "constable-states" test_slither tests/external_function.sol "external-function" test_slither tests/naming_convention.sol "naming-convention" +#test_slither tests/complex_func.sol "complex-function" ### Test scripts diff --git a/slither/__main__.py b/slither/__main__.py index 567811444..017bebaa6 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -101,6 +101,7 @@ def get_detectors_and_printers(): from slither.detectors.attributes.locked_ether import LockedEther from slither.detectors.functions.arbitrary_send import ArbitrarySend from slither.detectors.functions.suicidal import Suicidal + from slither.detectors.functions.complex_function import ComplexFunction from slither.detectors.reentrancy.reentrancy import Reentrancy from slither.detectors.variables.uninitialized_storage_variables import UninitializedStorageVars from slither.detectors.variables.unused_state_variables import UnusedStateVars @@ -126,6 +127,7 @@ def get_detectors_and_printers(): LowLevelCalls, NamingConvention, ConstCandidateStateVars, + #ComplexFunction, ExternalFunction] from slither.printers.summary.function import FunctionSummary diff --git a/slither/analyses/taint/specific_variable.py b/slither/analyses/taint/specific_variable.py index 304b89b4d..fc508c963 100644 --- a/slither/analyses/taint/specific_variable.py +++ b/slither/analyses/taint/specific_variable.py @@ -15,9 +15,7 @@ from .common import iterate_over_irs def make_key(variable): if isinstance(variable, Variable): - key = 'TAINT_{}{}{}'.format(variable.contract.name, - variable.name, - str(type(variable))) + key = 'TAINT_{}'.format(id(variable)) else: assert isinstance(variable, SolidityVariable) key = 'TAINT_{}{}'.format(variable.name, @@ -60,8 +58,6 @@ def _visit_node(node, visited, key): key) taints = iterate_over_irs(node.irs, _transfer_func_, taints) - taints = [v for v in taints if not isinstance(v, (TemporaryVariable, ReferenceVariable))] - node.function.slither.context[key] = list(set(taints)) for son in node.sons: @@ -101,7 +97,7 @@ def is_tainted(variable, taint): if not isinstance(variable, (Variable, SolidityVariable)): return False key = make_key(taint) - return key in variable.context and variable.context[key] + return (key in variable.context and variable.context[key]) or variable == taint def is_tainted_from_key(variable, key): """ diff --git a/slither/core/children/child_node.py b/slither/core/children/child_node.py index 747b7d285..8c16e3106 100644 --- a/slither/core/children/child_node.py +++ b/slither/core/children/child_node.py @@ -10,3 +10,11 @@ class ChildNode(object): @property def node(self): return self._node + + @property + def function(self): + return self.node.function + + @property + def contract(self): + return self.node.function.contract diff --git a/slither/detectors/functions/complex_function.py b/slither/detectors/functions/complex_function.py new file mode 100644 index 000000000..355f1865b --- /dev/null +++ b/slither/detectors/functions/complex_function.py @@ -0,0 +1,114 @@ +from slither.core.declarations.solidity_variables import (SolidityFunction, + SolidityVariableComposed) +from slither.detectors.abstract_detector import (AbstractDetector, + DetectorClassification) +from slither.slithir.operations import (HighLevelCall, + LowLevelCall, + LibraryCall) +from slither.utils.code_complexity import compute_cyclomatic_complexity + + +class ComplexFunction(AbstractDetector): + """ + Module detecting complex functions + A complex function is defined by: + - high cyclomatic complexity + - numerous writes to state variables + - numerous external calls + """ + + + ARGUMENT = 'complex-function' + HELP = 'Complex functions' + IMPACT = DetectorClassification.INFORMATIONAL + CONFIDENCE = DetectorClassification.MEDIUM + + MAX_STATE_VARIABLES = 10 + MAX_EXTERNAL_CALLS = 5 + MAX_CYCLOMATIC_COMPLEXITY = 7 + + CAUSE_CYCLOMATIC = "cyclomatic" + CAUSE_EXTERNAL_CALL = "external_calls" + CAUSE_STATE_VARS = "state_vars" + + + @staticmethod + def detect_complex_func(func): + """Detect the cyclomatic complexity of the contract functions + shouldn't be greater than 7 + """ + result = [] + code_complexity = compute_cyclomatic_complexity(func) + + if code_complexity > ComplexFunction.MAX_CYCLOMATIC_COMPLEXITY: + result.append({ + "func": func, + "cause": ComplexFunction.CAUSE_CYCLOMATIC + }) + + """Detect the number of external calls in the func + shouldn't be greater than 5 + """ + count = 0 + for node in func.nodes: + for ir in node.irs: + if isinstance(ir, (HighLevelCall, LowLevelCall, LibraryCall)): + count += 1 + + if count > ComplexFunction.MAX_EXTERNAL_CALLS: + result.append({ + "func": func, + "cause": ComplexFunction.CAUSE_EXTERNAL_CALL + }) + + """Checks the number of the state variables written + shouldn't be greater than 10 + """ + if len(func.state_variables_written) > ComplexFunction.MAX_STATE_VARIABLES: + result.append({ + "func": func, + "cause": ComplexFunction.CAUSE_STATE_VARS + }) + + return result + + def detect_complex(self, contract): + ret = [] + + for func in contract.all_functions_called: + result = self.detect_complex_func(func) + ret.extend(result) + + return ret + + def detect(self): + results = [] + + for contract in self.contracts: + issues = self.detect_complex(contract) + + for issue in issues: + func, cause = issue.values() + func_name = func.name + + txt = "Complex function in {} Contract: {}, Function: {}" + + if cause == self.CAUSE_EXTERNAL_CALL: + txt += ", Reason: High number of external calls" + if cause == self.CAUSE_CYCLOMATIC: + txt += ", Reason: High number of branches" + if cause == self.CAUSE_STATE_VARS: + txt += ", Reason: High number of modified state variables" + + info = txt.format(self.filename, + contract.name, + func_name) + self.log(info) + + results.append({'vuln': 'ComplexFunc', + 'sourceMapping': func.source_mapping, + 'filename': self.filename, + 'contract': contract.name, + 'func': func_name}) + return results + diff --git a/slither/printers/functions/authorization.py b/slither/printers/functions/authorization.py index ff1666331..b2e2a1086 100644 --- a/slither/printers/functions/authorization.py +++ b/slither/printers/functions/authorization.py @@ -39,4 +39,4 @@ class PrinterWrittenVariablesAndAuthorization(AbstractPrinter): state_variables_written = [v.name for v in function.all_state_variables_written()] msg_sender_condition = self.get_msg_sender_checks(function) table.add_row([function.name, str(state_variables_written), str(msg_sender_condition)]) - self.info(txt + str(table)) + self.info(txt + str(table)) diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index 3748f4403..9b4a98e9c 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -208,7 +208,7 @@ def convert_to_low_level(ir): logger.error('Incorrect conversion to low level {}'.format(ir)) exit(-1) -def convert_to_push(ir): +def convert_to_push(ir, node): """ Convert a call to a PUSH operaiton @@ -221,7 +221,7 @@ def convert_to_push(ir): if isinstance(ir.arguments[0], list): ret = [] - val = TemporaryVariable() + val = TemporaryVariable(node) operation = InitArray(ir.arguments[0], val) ret.append(operation) @@ -419,7 +419,7 @@ def propagate_types(ir, node): # Which leads to return a list of operation if isinstance(t, ArrayType): if ir.function_name == 'push' and len(ir.arguments) == 1: - return convert_to_push(ir) + return convert_to_push(ir, node) elif isinstance(ir, Index): if isinstance(ir.variable_left.type, MappingType): @@ -458,7 +458,9 @@ def propagate_types(ir, node): elif isinstance(ir, Member): # TODO we should convert the reference to a temporary if the member is a length or a balance if ir.variable_right == 'length' and isinstance(ir.variable_left.type, (ElementaryType, ArrayType)): - return Length(ir.variable_left, ir.lvalue) + length = Length(ir.variable_left, ir.lvalue) + ir.lvalue.points_to = ir.variable_left + return ir if ir.variable_right == 'balance' and isinstance(ir.variable_left.type, ElementaryType): return Balance(ir.variable_left, ir.lvalue) left = ir.variable_left @@ -657,7 +659,7 @@ def convert_expression(expression, node): if isinstance(expression, Identifier) and node.type in [NodeType.IF, NodeType.IFLOOP]: result = [Condition(expression.value)] return result - visitor = ExpressionToSlithIR(expression) + visitor = ExpressionToSlithIR(expression, node) result = visitor.result() result = apply_ir_heuristics(result, node) diff --git a/slither/slithir/variables/reference.py b/slither/slithir/variables/reference.py index 9763202ef..defccaaea 100644 --- a/slither/slithir/variables/reference.py +++ b/slither/slithir/variables/reference.py @@ -8,11 +8,12 @@ class ReferenceVariable(ChildNode, Variable): COUNTER = 0 - def __init__(self): + def __init__(self, node): super(ReferenceVariable, self).__init__() self._index = ReferenceVariable.COUNTER ReferenceVariable.COUNTER += 1 self._points_to = None + self._node = node @property def index(self): diff --git a/slither/slithir/variables/temporary.py b/slither/slithir/variables/temporary.py index 910ff51b1..a736a3dd1 100644 --- a/slither/slithir/variables/temporary.py +++ b/slither/slithir/variables/temporary.py @@ -6,10 +6,11 @@ class TemporaryVariable(ChildNode, Variable): COUNTER = 0 - def __init__(self): + def __init__(self, node): super(TemporaryVariable, self).__init__() self._index = TemporaryVariable.COUNTER TemporaryVariable.COUNTER += 1 + self._node = node @property def index(self): diff --git a/slither/utils/code_complexity.py b/slither/utils/code_complexity.py new file mode 100644 index 000000000..efe94648b --- /dev/null +++ b/slither/utils/code_complexity.py @@ -0,0 +1,75 @@ +# Funciton computing the code complexity + +def compute_number_edges(function): + """ + Compute the number of edges of the CFG + Args: + function (core.declarations.function.Function) + Returns: + int + """ + n = 0 + for node in function.nodes: + n += len(node.sons) + return n + + +def compute_strongly_connected_components(function): + """ + Compute strongly connected components + Based on Kosaraju algo + Implem follows wikipedia algo: https://en.wikipedia.org/wiki/Kosaraju%27s_algorithm#The_algorithm + Args: + function (core.declarations.function.Function) + Returns: + list(list(nodes)) + """ + visited = {n:False for n in function.nodes} + assigned = {n:False for n in function.nodes} + components = [] + l = [] + + def visit(node): + if not visited[node]: + visited[node] = True + for son in node.sons: + visit(son) + l.append(node) + + for n in function.nodes: + visit(n) + + def assign(node, root): + if not assigned[node]: + assigned[node] = True + root.append(node) + for father in node.fathers: + assign(father, root) + + for n in l: + component = [] + assign(n, component) + if component: + components.append(component) + + return components + +def compute_cyclomatic_complexity(function): + """ + Compute the cyclomatic complexity of a function + Args: + function (core.declarations.function.Function) + Returns: + int + """ + # from https://en.wikipedia.org/wiki/Cyclomatic_complexity + # M = E - N + 2P + # where M is the complexity + # E number of edges + # N number of nodes + # P number of connected components + + E = compute_number_edges(function) + N = len(function.nodes) + P = len(compute_strongly_connected_components(function)) + return E - N + 2 * P \ No newline at end of file diff --git a/slither/visitors/slithir/expression_to_slithir.py b/slither/visitors/slithir/expression_to_slithir.py index 3dfe07ce9..e1c90c595 100644 --- a/slither/visitors/slithir/expression_to_slithir.py +++ b/slither/visitors/slithir/expression_to_slithir.py @@ -61,8 +61,9 @@ def convert_assignment(left, right, t, return_type): class ExpressionToSlithIR(ExpressionVisitor): - def __init__(self, expression): + def __init__(self, expression, node): self._expression = expression + self._node = node self._result = [] self._visit_expression(self.expression) @@ -104,7 +105,7 @@ class ExpressionToSlithIR(ExpressionVisitor): def _post_binary_operation(self, expression): left = get(expression.expression_left) right = get(expression.expression_right) - val = TemporaryVariable() + val = TemporaryVariable(self._node) operation = Binary(val, left, right, expression.type) self._result.append(operation) @@ -123,18 +124,18 @@ class ExpressionToSlithIR(ExpressionVisitor): if expression.type_call.startswith('tuple(') and expression.type_call != 'tuple()': val = TupleVariable() else: - val = TemporaryVariable() + val = TemporaryVariable(self._node) internal_call = InternalCall(called, len(args), val, expression.type_call) self._result.append(internal_call) set_val(expression, val) else: - val = TemporaryVariable() + val = TemporaryVariable(self._node) # If tuple if expression.type_call.startswith('tuple(') and expression.type_call != 'tuple()': val = TupleVariable() else: - val = TemporaryVariable() + val = TemporaryVariable(self._node) message_call = TmpCall(called, len(args), val, expression.type_call) self._result.append(message_call) @@ -152,7 +153,7 @@ class ExpressionToSlithIR(ExpressionVisitor): def _post_index_access(self, expression): left = get(expression.expression_left) right = get(expression.expression_right) - val = ReferenceVariable() + val = ReferenceVariable(self._node) operation = Index(val, left, right, expression.type) self._result.append(operation) set_val(expression, val) @@ -162,26 +163,26 @@ class ExpressionToSlithIR(ExpressionVisitor): def _post_member_access(self, expression): expr = get(expression.expression) - val = ReferenceVariable() + val = ReferenceVariable(self._node) member = Member(expr, Constant(expression.member_name), val) self._result.append(member) set_val(expression, val) def _post_new_array(self, expression): - val = TemporaryVariable() + val = TemporaryVariable(self._node) operation = TmpNewArray(expression.depth, expression.array_type, val) self._result.append(operation) set_val(expression, val) def _post_new_contract(self, expression): - val = TemporaryVariable() + val = TemporaryVariable(self._node) operation = TmpNewContract(expression.contract_name, val) self._result.append(operation) set_val(expression, val) def _post_new_elementary_type(self, expression): # TODO unclear if this is ever used? - val = TemporaryVariable() + val = TemporaryVariable(self._node) operation = TmpNewElementaryType(expression.type, val) self._result.append(operation) set_val(expression, val) @@ -196,7 +197,7 @@ class ExpressionToSlithIR(ExpressionVisitor): def _post_type_conversion(self, expression): expr = get(expression.expression) - val = TemporaryVariable() + val = TemporaryVariable(self._node) operation = TypeConversion(val, expr, expression.type) self._result.append(operation) set_val(expression, val) @@ -204,7 +205,7 @@ class ExpressionToSlithIR(ExpressionVisitor): def _post_unary_operation(self, expression): value = get(expression.expression) if expression.type in [UnaryOperationType.BANG, UnaryOperationType.TILD]: - lvalue = TemporaryVariable() + lvalue = TemporaryVariable(self._node) operation = Unary(lvalue, value, expression.type) self._result.append(operation) set_val(expression, lvalue) @@ -221,14 +222,14 @@ class ExpressionToSlithIR(ExpressionVisitor): self._result.append(operation) set_val(expression, value) elif expression.type in [UnaryOperationType.PLUSPLUS_POST]: - lvalue = TemporaryVariable() + lvalue = TemporaryVariable(self._node) operation = Assignment(lvalue, value, value.type) self._result.append(operation) operation = Binary(value, value, Constant("1"), BinaryType.ADDITION) self._result.append(operation) set_val(expression, lvalue) elif expression.type in [UnaryOperationType.MINUSMINUS_POST]: - lvalue = TemporaryVariable() + lvalue = TemporaryVariable(self._node) operation = Assignment(lvalue, value, value.type) self._result.append(operation) operation = Binary(value, value, Constant("1"), BinaryType.SUBTRACTION) @@ -237,7 +238,7 @@ class ExpressionToSlithIR(ExpressionVisitor): elif expression.type in [UnaryOperationType.PLUS_PRE]: set_val(expression, value) elif expression.type in [UnaryOperationType.MINUS_PRE]: - lvalue = TemporaryVariable() + lvalue = TemporaryVariable(self._node) operation = Binary(lvalue, Constant("0"), value, BinaryType.SUBTRACTION) self._result.append(operation) set_val(expression, lvalue) diff --git a/tests/complex_func.sol b/tests/complex_func.sol new file mode 100644 index 000000000..cdb716efd --- /dev/null +++ b/tests/complex_func.sol @@ -0,0 +1,88 @@ +pragma solidity ^0.4.24; + +contract Complex { + int numberOfSides = 7; + string shape; + uint i0 = 0; + uint i1 = 0; + uint i2 = 0; + uint i3 = 0; + uint i4 = 0; + uint i5 = 0; + uint i6 = 0; + uint i7 = 0; + uint i8 = 0; + uint i9 = 0; + uint i10 = 0; + + + function computeShape() external { + if (numberOfSides <= 2) { + shape = "Cant be a shape!"; + } else if (numberOfSides == 3) { + shape = "Triangle"; + } else if (numberOfSides == 4) { + shape = "Square"; + } else if (numberOfSides == 5) { + shape = "Pentagon"; + } else if (numberOfSides == 6) { + shape = "Hexagon"; + } else if (numberOfSides == 7) { + shape = "Heptagon"; + } else if (numberOfSides == 8) { + shape = "Octagon"; + } else if (numberOfSides == 9) { + shape = "Nonagon"; + } else if (numberOfSides == 10) { + shape = "Decagon"; + } else if (numberOfSides == 11) { + shape = "Hendecagon"; + } else { + shape = "Your shape is more than 11 sides."; + } + } + + function complexExternalWrites() external { + Increment test1 = new Increment(); + test1.increaseBy1(); + test1.increaseBy1(); + test1.increaseBy1(); + test1.increaseBy1(); + test1.increaseBy1(); + + Increment test2 = new Increment(); + test2.increaseBy1(); + + address test3 = new Increment(); + test3.call(bytes4(keccak256("increaseBy2()"))); + + address test4 = new Increment(); + test4.call(bytes4(keccak256("increaseBy2()"))); + } + + function complexStateVars() external { + i0 = 1; + i1 = 1; + i2 = 1; + i3 = 1; + i4 = 1; + i5 = 1; + i6 = 1; + i7 = 1; + i8 = 1; + i9 = 1; + i10 = 1; + } +} + +contract Increment { + uint i = 0; + + function increaseBy1() public { + i += 1; + } + + function increaseBy2() public { + i += 2; + } +} \ No newline at end of file