#!/usr/bin/env bash
# CI check for slither-simil: encode one known function in "info" mode and
# compare the logged encoding with the reference output in tests/simil.

### Install requisites

pip3.6 install pybind11
pip3.6 install https://github.com/facebookresearch/fastText/archive/0.2.0.zip

### Test slither-simil

DIR_TESTS="tests/simil"
ACTUAL="test_1.txt"
EXPECTED="$DIR_TESTS/test_1.txt"

# The model argument is intentionally empty: info mode only loads a model when
# the given path is an existing file, and the encoding itself needs none.
slither-simil info "" --filename "$DIR_TESTS/../complex_func.sol" --fname Complex.complexExternalWrites --solc solc-0.4.25 > "$ACTUAL" 2>&1

# Use diff's exit status rather than its stdout: a missing or unreadable
# reference file prints only to stderr and would otherwise look like a match.
if ! diff "$ACTUAL" "$EXPECTED" > /dev/null
then
    echo "slither-simil failed"
    cat "$ACTUAL"
    cat "$EXPECTED"
    # Exit statuses are limited to 0-255; use a plain non-zero status.
    exit 1
fi

rm "$ACTUAL"
def parse_args():
    """Build the slither-simil command-line parser and parse ``sys.argv``.

    Declares the two positionals (``mode`` and ``model``), the tool-specific
    options, and then lets ``cryticparser`` register its own options on the
    same parser. When the program is started without any argument, the help
    text is printed to stderr and the process exits with status 1.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(
        description='Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector')

    # Positional arguments.
    parser.add_argument('mode', help="|".join(modes))
    parser.add_argument('model', help='model.bin')

    # Value-taking options, declared in the order they appear in --help:
    # (flag, destination attribute, extra add_argument keywords).
    store_options = (
        ('--filename', 'filename', {'help': 'contract.sol'}),
        ('--fname', 'fname', {'help': 'Target function'}),
        ('--ext', 'ext', {'help': 'Extension to filter contracts'}),
        ('--nsamples', 'nsamples', {
            'type': int,
            'help': 'Number of contract samples used for training'}),
        ('--ntop', 'ntop', {
            'type': int,
            'default': 10,
            'help': 'Number of more similar contracts to show for testing'}),
        ('--input', 'input', {'help': 'File or directory used as input'}),
    )
    for flag, destination, extra in store_options:
        parser.add_argument(flag, action='store', dest=destination, **extra)

    parser.add_argument('--version',
                        help='displays the current version',
                        version="0.0",
                        action='version')

    cryticparser.init(parser)

    # No arguments at all: show the full help instead of argparse's error.
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)

    return parser.parse_args()
def main():
    """Entry point of slither-simil: parse the arguments and run the mode.

    The logger level is set to INFO, then the handler matching ``args.mode``
    is called with the parsed arguments. An unknown mode is reported through
    the logger and terminates the process with ``sys.exit(-1)``.
    """
    args = parse_args()
    logger.setLevel(logging.INFO)

    # One handler per supported mode; each one receives the parsed namespace.
    handlers = {
        "info": info,
        "train": train,
        "test": test,
        "plot": plot,
    }

    handler = handlers.get(args.mode)
    if handler is None:
        logger.error('Invalid mode!. It should be one of these: %s' % ", ".join(modes))
        sys.exit(-1)

    handler(args)
def load_cache(infile, nsamples=None):
    """Load a cache of encoded functions previously written by ``save_cache``.

    Args:
        infile: Path of the ``.npz`` archive to read.
        nsamples: Maximum number of entries to load. ``None`` loads all of
            them; ``0`` loads none.

    Returns:
        A dict mapping every stored key to its stored value, filled in the
        order the entries appear in the archive.
    """
    cache = dict()
    # SECURITY: the archive holds Python objects (tuple keys), so NumPy has to
    # unpickle it. Unpickling can execute arbitrary code: only load cache
    # files that come from a trusted source. Recent NumPy releases refuse to
    # load object arrays unless allow_pickle is passed explicitly.
    with np.load(infile, allow_pickle=True) as data:
        # save_cache stores one array with a leading axis of length 1.
        entries = data['arr_0'][0]
        for i, (key, value) in enumerate(entries):
            # Check the limit before inserting so that exactly `nsamples`
            # entries are returned.
            if nsamples is not None and i >= nsamples:
                break
            cache[key] = value

    return cache


def save_cache(cache, outfile):
    """Write a cache of encoded functions to a ``.npz`` archive.

    Args:
        cache: Sequence of ``(key, value)`` pairs, or a dict whose items are
            such pairs.
        outfile: Destination path handed to ``numpy.savez``.
    """
    pairs = cache.items() if isinstance(cache, dict) else cache

    # Fill an explicit (n, 2) object array cell by cell. Letting np.array()
    # infer the shape from ragged (tuple, vector) pairs is rejected by newer
    # NumPy releases and is ambiguous when the lengths happen to match.
    entries = np.empty((len(cache), 2), dtype=object)
    for i, (key, value) in enumerate(pairs):
        entries[i, 0] = key
        entries[i, 1] = value

    # The leading axis keeps the on-disk layout that load_cache expects.
    np.savez(outfile, entries[np.newaxis])
def parse_target(target):
    """Split a ``--fname`` target into its contract and function names.

    Args:
        target: ``None``, ``'function'`` or ``'Contract.function'``.

    Returns:
        A ``(contract, function)`` tuple. ``contract`` is ``None`` when the
        target has no contract part; both are ``None`` when ``target`` is
        ``None``.

    Raises:
        ValueError: If the target has more than one ``.`` separator. Callers
            unpack the result, so silently returning ``None`` here would only
            surface later as an unrelated unpacking error.
    """
    if target is None:
        return None, None

    parts = target.split('.')
    if len(parts) == 1:
        return None, parts[0]
    if len(parts) == 2:
        return parts[0], parts[1]

    raise ValueError("Invalid target. It should be 'function' or 'Contract.function'")


def load_and_encode(infile, vmodel, ext=None, nsamples=None, **kwargs):
    """Return the sentence vector of every function found in ``infile``.

    Args:
        infile: Either a ``.npz`` cache (loaded as is) or a directory that is
            searched for contract files.
        vmodel: Model object providing ``get_sentence_vector``.
        ext: Optional file-name suffix used to filter contract files.
        nsamples: Optional limit on cache entries / contract files.
        **kwargs: Forwarded to ``encode_contract``.

    Returns:
        A dict mapping the keys produced by ``encode_contract`` (or stored in
        the cache) to their vectors.
    """
    if infile.endswith(".npz"):
        return load_cache(infile, nsamples=nsamples)

    r = dict()
    for contract_file in load_contracts(infile, ext=ext, nsamples=nsamples):
        for x, ir in encode_contract(contract_file, **kwargs).items():
            # Functions that produced no IR tokens have nothing to embed.
            if ir:
                r[x] = vmodel.get_sentence_vector(" ".join(ir))

    return r


def load_contracts(dirname, ext=None, nsamples=None, **kwargs):
    """List the files below ``dirname``, optionally filtered and truncated.

    Args:
        dirname: Root directory, walked recursively.
        ext: When given, only file names ending with this suffix are kept.
        nsamples: When given, at most this many paths are returned.
        **kwargs: Accepted and ignored so callers can pass a whole argument
            namespace.

    Returns:
        A list of file paths, each one built from the directory that actually
        contains the file.
    """
    r = []
    for dirpath, _subdirs, files in os.walk(dirname):
        for f in files:
            if ext is None or f.endswith(ext):
                # Join with the containing directory only; the sub-directory
                # names reported by os.walk are not part of this file's path.
                r.append(os.path.join(dirpath, f))

    if nsamples is None:
        return r

    # TODO: shuffle
    return r[:nsamples]


def ntype(_type):
    """Reduce a type object to the coarse type token used in the encoding.

    Array types whose element type is not elementary become
    ``user_defined_array`` and user defined types become
    ``user_defined_type``; everything else starts from ``str(_type)``. Storage
    location suffixes are removed, any name mentioning ``struct``, ``enum``,
    ``tuple``, ``contract`` or ``mapping`` collapses to that keyword, and
    remaining spaces are replaced by underscores.
    """
    if isinstance(_type, ElementaryType):
        name = str(_type)
    elif isinstance(_type, ArrayType):
        if isinstance(_type.type, ElementaryType):
            name = str(_type)
        else:
            name = "user_defined_array"
    elif isinstance(_type, (Structure, Enum, MappingType)):
        name = str(_type)
    elif isinstance(_type, UserDefinedType):
        name = "user_defined_type"  # TODO: this could be Contract, Enum or Struct
    else:
        name = str(_type)

    name = name.replace(" memory", "")
    name = name.replace(" storage ref", "")

    # The first keyword found wins, so the order of this tuple matters.
    for keyword in ("struct", "enum", "tuple", "contract", "mapping"):
        if keyword in name:
            return keyword

    return name.replace(" ", "_")
def encode_ir(ir):
    """Encode one SlithIR operation or variable as a single text token.

    The tokens are the "words" the similarity model is trained on, so every
    literal below is part of the model vocabulary and must stay stable.
    Operands of a few operations (assignment, condition, push, delete,
    transfer, send) are encoded recursively.

    Args:
        ir: A SlithIR operation, a variable, or a function object.

    Returns:
        The token for ``ir``, or an empty string when its type has no
        encoding (an error is logged in that case).
    """
    # The order of the checks matters: the first matching isinstance wins
    # (e.g. LibraryCall is tested before HighLevelCall).

    # operations
    if isinstance(ir, Assignment):
        return '({}):=({})'.format(encode_ir(ir.lvalue), encode_ir(ir.rvalue))
    if isinstance(ir, Index):
        return 'index({})'.format(ntype(ir._type))
    if isinstance(ir, Member):
        return 'member'
    if isinstance(ir, Length):
        return 'length'
    if isinstance(ir, Balance):
        return 'balance'
    if isinstance(ir, Binary):
        return 'binary({})'.format(ir.type_str)
    if isinstance(ir, Unary):
        return 'unary({})'.format(ir.type_str)
    if isinstance(ir, Condition):
        return 'condition({})'.format(encode_ir(ir.value))
    if isinstance(ir, NewStructure):
        return 'new_structure'
    if isinstance(ir, NewContract):
        return 'new_contract'
    if isinstance(ir, NewArray):
        return 'new_array({})'.format(ntype(ir._array_type))
    if isinstance(ir, NewElementaryType):
        return 'new_elementary({})'.format(ntype(ir._type))
    if isinstance(ir, Push):
        return 'push({},{})'.format(encode_ir(ir.value), encode_ir(ir.lvalue))
    if isinstance(ir, Delete):
        return 'delete({},{})'.format(encode_ir(ir.lvalue), encode_ir(ir.variable))
    if isinstance(ir, SolidityCall):
        return 'solidity_call({})'.format(ir.function.full_name)
    if isinstance(ir, InternalCall):
        return 'internal_call({})'.format(ntype(ir._type_call))
    if isinstance(ir, EventCall):  # is this useful?
        return 'event'
    if isinstance(ir, LibraryCall):
        return 'library_call'
    if isinstance(ir, InternalDynamicCall):
        return 'internal_dynamic_call'
    if isinstance(ir, HighLevelCall):  # TODO: improve
        return 'high_level_call'
    if isinstance(ir, LowLevelCall):  # TODO: improve
        return 'low_level_call'
    if isinstance(ir, TypeConversion):
        return 'type_conversion({})'.format(ntype(ir.type))
    if isinstance(ir, Return):  # this can be improved using values
        return 'return'
    if isinstance(ir, Transfer):
        return 'transfer({})'.format(encode_ir(ir.call_value))
    if isinstance(ir, Send):
        return 'send({})'.format(encode_ir(ir.call_value))
    if isinstance(ir, Unpack):  # TODO: improve
        return 'unpack'
    if isinstance(ir, InitArray):  # TODO: improve
        return 'init_array'
    if isinstance(ir, Function):  # TODO: investigate this
        return 'function_solc'

    # variables
    if isinstance(ir, Constant):
        return 'constant({})'.format(ntype(ir._type))
    if isinstance(ir, SolidityVariableComposed):
        return 'solidity_variable_composed({})'.format(ir.name)
    if isinstance(ir, SolidityVariable):
        # NOTE(review): unlike the other tokens, the name is not wrapped in
        # parentheses. Kept as is because changing it alters the vocabulary.
        return 'solidity_variable{}'.format(ir.name)
    if isinstance(ir, TemporaryVariable):
        return 'temporary_variable'
    if isinstance(ir, ReferenceVariable):
        return 'reference({})'.format(ntype(ir._type))
    if isinstance(ir, LocalVariable):
        return 'local_solc_variable({})'.format(ir._location)
    if isinstance(ir, StateVariable):
        return 'state_solc_variable({})'.format(ntype(ir._type))
    # NOTE(review): if LocalVariableInitFromTuple derives from LocalVariable,
    # the LocalVariable check above already catches it and this branch never
    # runs - confirm before reordering, since that would change the tokens.
    if isinstance(ir, LocalVariableInitFromTuple):
        return 'local_variable_init_tuple'
    if isinstance(ir, TupleVariable):
        return 'tuple_variable'

    # default: the message must be the format string and the type its
    # argument, otherwise logging fails while formatting the record.
    simil_logger.error("%s is missing encoding!", type(ir))
    return ''
def encode_contract(cfilename, **kwargs):
    """Encode every implemented function declared in a Solidity file.

    Args:
        cfilename: Path of the file handed to ``Slither``.
        **kwargs: Forwarded unchanged to ``Slither``; the ``solc`` entry, when
            present, is also used in the compilation error message.

    Returns:
        A dict mapping ``(cfilename, contract_name, function_name)`` to the
        list of IR tokens of that function. Functions without nodes are
        skipped. The dict is empty when the file cannot be analysed.
    """
    r = dict()

    # Init slither
    try:
        slither = Slither(cfilename, **kwargs)
    except Exception:
        # Best effort: a file that fails to compile is reported and skipped.
        # Catching Exception (not everything) keeps Ctrl-C and sys.exit
        # working, and .get() avoids a KeyError when no solc was supplied.
        simil_logger.error("Compilation failed for %s using %s", cfilename, kwargs.get('solc'))
        return r

    # Iterate over all the contracts
    for contract in slither.contracts:

        # Only functions declared in the contract itself, not inherited ones
        for function in contract.functions_not_inherited:

            # Functions without any node have nothing to encode
            if not function.nodes:
                continue

            x = (cfilename, contract.name, function.name)
            tokens = []

            # Collect the SlithIR operations of every node that carries an
            # expression, in node order
            for node in function.nodes:
                if node.expression:
                    tokens.extend(encode_ir(ir) for ir in node.irs)

            r[x] = tokens

    return r
def info(args):
    """Run the ``info`` mode.

    Without ``--filename`` and ``--fname`` the words known to the model are
    listed. With both of them (``--fname`` as ``Contract.function``) the
    encoding of that function is logged, followed by its sentence vector when
    a model file was found.

    Args:
        args: Parsed command-line namespace; reads ``model``, ``filename``
            and ``fname`` and forwards all of its entries to
            ``encode_contract``.

    The process exits with status -1 on any error and with status 0 after
    listing the model words.
    """
    try:

        # The model is optional here: it is only loaded when the path exists.
        model = args.model
        if os.path.isfile(model):
            model = load_model(model)
        else:
            model = None

        filename = args.filename
        contract, fname = parse_target(args.fname)

        if filename is None and contract is None and fname is None:
            # Listing the vocabulary is impossible without a loaded model.
            if model is None:
                logger.error('The model file %s was not found.', args.model)
                sys.exit(-1)
            logger.info("%s uses the following words:",args.model)
            for word in model.get_words():
                logger.info(word)
            sys.exit(0)

        if filename is None or contract is None or fname is None:
            logger.error('The encode mode requires filename, contract and fname parameters.')
            sys.exit(-1)

        irs = encode_contract(filename, **vars(args))
        if len(irs) == 0:
            sys.exit(-1)

        x = (filename,contract,fname)
        if x not in irs:
            # Report a wrong target explicitly instead of a KeyError traceback.
            logger.error('Function %s was not found in contract %s of %s', fname, contract, filename)
            sys.exit(-1)

        y = " ".join(irs[x])

        logger.info("Function {} in contract {} is encoded as:".format(fname, contract))
        logger.info(y)
        if model is not None:
            fvector = model.get_sentence_vector(y)
            logger.info(fvector)

    except Exception:
        # sys.exit() raises SystemExit, which is not caught here.
        logger.error('Error in %s' % args.filename)
        logger.error(traceback.format_exc())
        sys.exit(-1)
def plot(args):
    """Run the ``plot`` mode: project matching functions to 2D and save a plot.

    All functions named like the ``--fname`` target (restricted to the given
    contract when the target is ``Contract.function``) are collected from
    ``--input``, their vectors are reduced to two PCA components and drawn as
    a labelled scatter plot saved to ``plot.png``.

    Args:
        args: Parsed command-line namespace; reads ``model``, ``fname`` and
            ``input`` and forwards all of its entries to ``load_and_encode``.

    The process exits with status -1 when sklearn/matplotlib are missing,
    when parameters are missing, when nothing matches, or on any error.
    """
    if decomposition is None or plt is None:
        logger.error("ERROR: In order to use plot mode in slither-simil, you need to install sklearn and matplotlib:")
        logger.error("$ pip3 install sklearn matplotlib --user")
        sys.exit(-1)

    try:

        model = load_model(args.model)
        contract, fname = parse_target(args.fname)
        infile = args.input

        if fname is None or infile is None:
            logger.error('The plot mode requieres fname and input parameters.')
            sys.exit(-1)

        logger.info('Loading data..')
        # The model is a required positional argument of load_and_encode.
        cache = load_and_encode(infile, model, **vars(args))

        data = list()
        fs = list()

        logger.info('Procesing data..')
        for (f,c,n),y in cache.items():
            # A target without contract part matches the function everywhere.
            if (c == contract or contract is None) and n == fname:
                fs.append(f)
                data.append(y)

        if len(data) == 0:
            logger.error('No contract was found with function %s', fname)
            sys.exit(-1)

        data = np.array(data)
        pca = decomposition.PCA(n_components=2)
        tdata = pca.fit_transform(data)

        logger.info('Plotting data..')
        plt.figure(figsize=(20,10))
        assert(len(tdata) == len(fs))
        for ([x,y],l) in zip(tdata, fs):
            # Small random jitter, presumably so that identical vectors do
            # not draw exactly on top of each other.
            x = random.gauss(0, 0.01) + x
            y = random.gauss(0, 0.01) + y
            plt.scatter(x, y, c='blue')
            plt.text(x-0.001,y+0.001, l)

        logger.info('Saving figure to plot.png..')
        plt.savefig('plot.png', bbox_inches='tight')

    except Exception:
        # sys.exit() raises SystemExit, which is not caught here.
        logger.error('Error in %s' % args.filename)
        logger.error(traceback.format_exc())
        sys.exit(-1)
def similarity(v1, v2):
    """Return the cosine similarity of two vectors.

    Args:
        v1: First vector (anything ``numpy.dot`` / ``numpy.linalg.norm``
            accept).
        v2: Second vector, same length as ``v1``.

    Returns:
        ``dot(v1, v2) / (|v1| * |v2|)``, or ``0.0`` when either vector has
        zero norm. Without that guard the result would be NaN, and NaN scores
        make the ranking produced by sorting on this value unreliable.
    """
    n1 = np.linalg.norm(v1)
    n2 = np.linalg.norm(v2)
    if n1 == 0 or n2 == 0:
        return 0.0
    return np.dot(v1, v2) / n1 / n2
def test(args):
    """Run the ``test`` mode: rank stored functions by similarity to a target.

    The target function (``--filename`` plus ``--fname`` given as
    ``Contract.function``) is encoded and embedded with the model, compared
    with every function obtained from ``--input``, and the ``--ntop`` best
    matches are logged as a table.

    Args:
        args: Parsed command-line namespace; reads ``model``, ``filename``,
            ``fname``, ``input`` and ``ntop`` and forwards all of its entries
            to ``encode_contract`` and ``load_and_encode``.

    The process exits with status -1 on missing parameters or on any error.
    """
    try:
        model = load_model(args.model)
        filename = args.filename
        contract, fname = parse_target(args.fname)
        infile = args.input
        ntop = args.ntop

        if filename is None or contract is None or fname is None or infile is None:
            logger.error('The test mode requires filename, contract, fname and input parameters.')
            sys.exit(-1)

        irs = encode_contract(filename, **vars(args))
        if len(irs) == 0:
            sys.exit(-1)

        target = (filename, contract, fname)
        if target not in irs:
            # Report a wrong target explicitly instead of a KeyError traceback.
            logger.error('Function %s was not found in contract %s of %s', fname, contract, filename)
            sys.exit(-1)

        fvector = model.get_sentence_vector(" ".join(irs[target]))
        cache = load_and_encode(infile, model, **vars(args))

        scores = {key: similarity(fvector, vector) for key, vector in cache.items()}

        # Highest score first.
        ranking = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
        logger.info("Reviewed %d functions, listing the %d most similar ones:", len(ranking), ntop)
        format_table = "{: <65} {: <20} {: <20} {: <10}"
        logger.info(format_table.format(*["filename", "contract", "function", "score"]))
        for key, score in ranking[:ntop]:
            score = str(round(score, 3))
            logger.info(format_table.format(*(list(key)+[score])))

    except Exception:
        # sys.exit() raises SystemExit, which is not caught here.
        logger.error('Error in %s' % args.filename)
        logger.error(traceback.format_exc())
        sys.exit(-1)


def train(args):
    """Run the ``train`` mode: build a model and a vector cache from contracts.

    Every contract file found below ``--input`` is encoded; one line per
    non-empty function is written to ``last_data_train.txt``, an unsupervised
    skipgram model is trained on that file and saved to the ``model`` path,
    and the vector of every encoded function is saved to ``cache.npz``.

    Args:
        args: Parsed command-line namespace; reads ``model`` and ``input``
            and forwards all of its entries to ``load_contracts`` and
            ``encode_contract``.

    The process exits with status -1 on missing parameters, when nothing
    could be encoded, or on any error.
    """
    try:
        last_data_train_filename = "last_data_train.txt"
        model_filename = args.model
        dirname = args.input

        if dirname is None:
            logger.error('The train mode requires the input parameter.')
            sys.exit(-1)

        contracts = load_contracts(dirname, **vars(args))
        logger.info('Saving extracted data into %s', last_data_train_filename)
        encoded = []
        with open(last_data_train_filename, 'w') as f:
            for path in contracts:
                for (cfile, contract, function), ir in encode_contract(path, **vars(args)).items():
                    if ir:
                        sentence = " ".join(ir)
                        f.write(sentence+"\n")
                        # Only the base name of the file is kept in the cache key.
                        encoded.append((os.path.split(cfile)[-1], contract, function, sentence))

        if not encoded:
            # Nothing was written, so there is nothing to train on.
            logger.error('No function could be encoded from %s', dirname)
            sys.exit(-1)

        logger.info('Starting training')
        model = train_unsupervised(input=last_data_train_filename, model='skipgram')
        logger.info('Training complete')
        logger.info('Saving model')
        model.save_model(model_filename)

        cache = [((cfile, contract, function), model.get_sentence_vector(sentence))
                 for cfile, contract, function, sentence in encoded]

        logger.info('Saving cache in cache.npz')
        save_cache(cache, "cache.npz")
        logger.info('Done!')

    except Exception:
        # sys.exit() raises SystemExit, which is not caught here.
        logger.error('Error in %s' % args.filename)
        logger.error(traceback.format_exc())
        sys.exit(-1)