Static Analyzer for Solidity
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
slither/utils/similarity/train.py

54 lines
1.8 KiB

import argparse
import logging
import sys
import traceback
import operator
import os
from .model import train_unsupervised
from .encode import encode_contract, load_contracts
from .cache import save_cache
logger = logging.getLogger("Slither-simil")


def _write_training_corpus(contracts, corpus_filename, encode_kwargs):
    """Encode every contract and write one IR sentence per line to *corpus_filename*.

    Each item of *contracts* is passed to ``encode_contract`` together with
    *encode_kwargs*. Functions whose IR encoding is empty are skipped.

    Returns:
        A list of ``(basename, contract, function, sentence)`` tuples, one per
        line written to the corpus, in the same order.
    """
    entries = []
    # Explicit UTF-8 so the corpus does not depend on the platform's default
    # locale encoding.
    with open(corpus_filename, "w", encoding="utf-8") as corpus:
        for contract_file in contracts:
            encoded = encode_contract(contract_file, **encode_kwargs)
            # Keys are (filename, contract, function) triples; the filename
            # from the key (not the loop variable) is what gets cached.
            for (filename, contract, function), ir in encoded.items():
                if not ir:
                    continue
                sentence = " ".join(ir)
                corpus.write(sentence + "\n")
                entries.append((os.path.split(filename)[-1], contract, function, sentence))
    return entries


def train(args):
    """Train an unsupervised skipgram model over the functions of the input contracts.

    Steps:
      1. Load contracts from ``args.input`` via ``load_contracts``.
      2. Write the encoded IR of each function to ``last_data_train.txt``.
      3. Train a model with ``train_unsupervised(model='skipgram')`` and save
         it to ``args.model``.
      4. Save one sentence vector per encoded function to ``cache.npz``.

    All of ``vars(args)`` is forwarded as keyword arguments to
    ``load_contracts`` and ``encode_contract``.

    Exits the process with status -1 if ``args.input`` is None, or if any
    step raises an ``Exception`` (the traceback is logged first).
    """
    corpus_filename = "last_data_train.txt"
    cache_filename = "cache.npz"

    dirname = args.input
    if dirname is None:
        logger.error('The train mode requires the input parameter.')
        sys.exit(-1)

    try:
        model_filename = args.model
        options = vars(args)

        contracts = load_contracts(dirname, **options)
        logger.info('Saving extracted data into %s', corpus_filename)
        entries = _write_training_corpus(contracts, corpus_filename, options)

        logger.info('Starting training')
        model = train_unsupervised(input=corpus_filename, model='skipgram')
        logger.info('Training complete')
        logger.info('Saving model')
        model.save_model(model_filename)

        cache = [
            ((filename, contract, function), model.get_sentence_vector(sentence))
            for filename, contract, function, sentence in entries
        ]
        logger.info('Saving cache in %s', cache_filename)
        save_cache(cache, cache_filename)
        logger.info('Done!')
    except Exception:
        # Report the train-mode input (args.input) rather than args.filename,
        # which this mode never reads and which may be None or absent. Reading
        # an absent attribute here would make the handler itself raise.
        logger.exception('Error in %s', dirname)
        sys.exit(-1)