pull/202/head
ggrieco-tob 6 years ago
parent 03b318b55e
commit 8f2bad8905
  1. 2
      utils/similarity/__main__.py
  2. 2
      utils/similarity/cache.py
  3. 2
      utils/similarity/encode.py
  4. 4
      utils/similarity/test.py
  5. 25
      utils/similarity/train.py

@ -11,7 +11,7 @@ from .test import test
from .train import train
logging.basicConfig()
logger = logging.getLogger("Slither")
logger = logging.getLogger("Slither-simil")
slither_simil_usage = "USAGE" # TODO
modes = ["info", "test", "train"]

@ -19,4 +19,4 @@ def load_cache(infile, model, ext=None, solc='solc'):
return cache
def save_cache(cache, outfile):
np.savez(outfile,[np.array(list(cache.items()))])
np.savez(outfile,[np.array(cache)])

@ -172,7 +172,7 @@ def encode_contract(filename, solc):
if function.nodes == []:
continue
x = "-".join([filename,contract.name,function.name])
x = (filename,contract.name,function.name)
r[x] = []

@ -23,6 +23,7 @@ def test(args):
solc = args.solc
infile = args.input
ext = args.filter
if filename is None or contract is None or fname is None or infile is None:
logger.error('The test mode requires filename, contract, fname and input parameters.')
sys.exit(-1)
@ -31,8 +32,7 @@ def test(args):
if len(irs) == 0:
sys.exit(-1)
x = "-".join([filename,contract,fname])
y = " ".join(irs[x])
y = " ".join(irs[(filename,contract,fname)])
fvector = model.get_sentence_vector(y)
cache = load_cache(infile, model, ext=ext, solc=solc)

@ -3,9 +3,11 @@ import logging
import sys
import traceback
import operator
import os
from fastText import train_unsupervised
from .encode import encode_contract, load_contracts
from .encode import encode_contract, load_contracts
from .cache import save_cache
logger = logging.getLogger("Slither-simil")
@ -25,15 +27,26 @@ def train(args):
contracts = load_contracts(dirname, ext=ext, nsamples=nsamples)
logger.info('Saving extracted data into', last_data_train_filename)
cache = []
with open(last_data_train_filename, 'w') as f:
for contract in contracts:
for function,ir in encode_contract(contract,solc).items():
for filename in contracts:
#cache[filename] = dict()
for (filename, contract, function), ir in encode_contract(filename,solc).items():
if ir != []:
f.write(" ".join(ir)+"\n")
x = " ".join(ir)
f.write(x+"\n")
cache.append((os.path.split(filename)[-1], contract, function, x))
logger.info('Starting training')
model = train_unsupervised(input=last_data_train_filename, model='skipgram')
logger.info('Training complete')
model.save_model(model_filename)
print(model.get_words())
for i,(filename, contract, function, irs) in enumerate(cache):
cache[i] = ((filename, contract, function), model.get_sentence_vector(irs))
logger.info('Saved cache in cache.npz')
save_cache(cache, "cache.npz")
except Exception:
logger.error('Error in %s' % args.filename)

Loading…
Cancel
Save