mirror of https://github.com/crytic/slither
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
1.8 KiB
54 lines
1.8 KiB
import argparse
|
|
import logging
|
|
import sys
|
|
import traceback
|
|
import operator
|
|
import os
|
|
|
|
from .model import train_unsupervised
|
|
from .encode import encode_contract, load_contracts
|
|
from .cache import save_cache
|
|
|
|
# Module-level logger; every message emitted by train() below goes through
# the logger registered under the name "Slither-simil".
logger = logging.getLogger("Slither-simil")
|
|
|
|
def train(args):
    """Train a model on the contracts found in ``args.input`` and cache vectors.

    Steps, in order:

    1. Load the contracts under ``args.input`` with ``load_contracts``.
    2. Encode each contract with ``encode_contract`` and write one line of
       space-joined tokens per non-empty entry to ``last_data_train.txt``
       (relative to the current working directory).
    3. Train a ``skipgram`` model on that file with ``train_unsupervised``
       and save it to ``args.model``.
    4. Compute a sentence vector for every written line and store the
       ``((filename, contract, function), vector)`` pairs in ``cache.npz``
       through ``save_cache``.

    Args:
        args: Parsed command-line namespace. ``args.input`` and ``args.model``
            are read directly; ``args.filename`` is used only in the error
            message. The whole namespace is also forwarded as keyword
            arguments to ``load_contracts`` and ``encode_contract``.

    Raises:
        SystemExit: With status ``-1`` when ``args.input`` is ``None`` or when
            any step raises an exception (the traceback is logged first).
    """
    try:
        last_data_train_filename = "last_data_train.txt"
        model_filename = args.model
        dirname = args.input

        if dirname is None:
            logger.error('The train mode requires the input parameter.')
            # SystemExit does not derive from Exception, so it is not caught
            # by the handler at the bottom of this function.
            sys.exit(-1)

        contracts = load_contracts(dirname, **vars(args))
        logger.info('Saving extracted data into %s', last_data_train_filename)

        # Each entry is (file basename, contract, function, joined tokens);
        # the joined tokens are replaced by their vector after training.
        cache = []
        # Explicit encoding keeps the training file independent of the
        # platform's locale default.
        with open(last_data_train_filename, 'w', encoding='utf-8') as f:
            for contract_path in contracts:
                encoded = encode_contract(contract_path, **vars(args))
                for (filename, contract, function), ir in encoded.items():
                    if not ir:
                        # Nothing to learn from an empty encoding.
                        continue
                    x = " ".join(ir)
                    f.write(x + "\n")
                    cache.append((os.path.basename(filename), contract, function, x))

        logger.info('Starting training')
        model = train_unsupervised(input=last_data_train_filename, model='skipgram')
        logger.info('Training complete')
        logger.info('Saving model')
        model.save_model(model_filename)

        # Swap every stored token line for the vector the trained model
        # assigns to it, keyed by (filename, contract, function).
        cache = [
            ((filename, contract, function), model.get_sentence_vector(irs))
            for filename, contract, function, irs in cache
        ]

        logger.info('Saving cache in cache.npz')
        save_cache(cache, "cache.npz")
        logger.info('Done!')

    except Exception:
        # getattr keeps the handler itself from failing (and hiding the real
        # traceback) if the namespace carries no ``filename`` attribute.
        logger.error('Error in %s', getattr(args, 'filename', None))
        logger.error(traceback.format_exc())
        sys.exit(-1)
|
|
|