diff --git a/utils/similarity/__main__.py b/utils/similarity/__main__.py
index 2e2abd26c..9e9860fd8 100755
--- a/utils/similarity/__main__.py
+++ b/utils/similarity/__main__.py
@@ -51,6 +51,12 @@ def parse_args():
                         dest='fname',
                         help='Function name')
 
+    parser.add_argument('--nsamples',
+                        action='store',
+                        type=int,
+                        dest='nsamples',
+                        help='Number of contract samples used for training')
+
     parser.add_argument('--input',
                         action='store',
                         dest='input',
diff --git a/utils/similarity/encode.py b/utils/similarity/encode.py
index 6c6c48906..6630b1bfc 100644
--- a/utils/similarity/encode.py
+++ b/utils/similarity/encode.py
@@ -49,8 +49,8 @@ def ntype(_type):
     else:
         _type = str(_type)
 
-    _type = _type.replace("_memory","")
-    _type = _type.replace("_storage_ref","")
+    _type = _type.replace(" memory","")
+    _type = _type.replace(" storage ref","")
 
     if "struct" in _type:
         return "struct"
diff --git a/utils/similarity/train.py b/utils/similarity/train.py
index 6892c8c80..8f8e7a888 100755
--- a/utils/similarity/train.py
+++ b/utils/similarity/train.py
@@ -12,23 +12,27 @@ logger = logging.getLogger("Slither-simil")
 def train(args):
 
     try:
+        # The encoded IR is written to this file, which is then passed as input to train_unsupervised.
+        last_data_train_filename = "last_data_train.txt"
         model_filename = args.model
         solc = args.solc
         dirname = args.input
         ext = args.filter
+        nsamples = args.nsamples
 
         if dirname is None:
             logger.error('The train mode requires the input parameter.')
             sys.exit(-1)
 
-        contracts = load_contracts(dirname, ext=ext, nsamples=None)
-        with open("data.txt", 'w') as f:
+        contracts = load_contracts(dirname, ext=ext, nsamples=nsamples)
+        logger.info('Saving extracted data into %s', last_data_train_filename)
+        with open(last_data_train_filename, 'w') as f:
             for contract in contracts:
                 for function,ir in encode_contract(contract,solc).items():
                     if ir != []:
                         f.write(" ".join(ir)+"\n")
 
-        model = train_unsupervised(input='data.txt', model='skipgram')
+        model = train_unsupervised(input=last_data_train_filename, model='skipgram')
         model.save_model(model_filename)
         print(model.get_words())
 