first version of slither-simil

pull/202/head
ggrieco-tob 6 years ago
parent e2e6873dab
commit 43fa9c73df
  1. 3
      setup.py
  2. 0
      utils/similarity/__init__.py
  3. 107
      utils/similarity/__main__.py
  4. 22
      utils/similarity/cache.py
  5. 168
      utils/similarity/encode.py
  6. 47
      utils/similarity/info.py
  7. 6
      utils/similarity/similarity.py
  8. 49
      utils/similarity/test.py
  9. 37
      utils/similarity/train.py

@ -15,7 +15,8 @@ setup(
'console_scripts': [
'slither = slither.__main__:main',
'slither-check-upgradeability = utils.upgradeability.__main__:main',
'slither-find-paths = utils.possible_paths.__main__:main'
'slither-find-paths = utils.possible_paths.__main__:main',
'slither-simil = utils.similarity.__main__:main'
]
}
)

@ -0,0 +1,107 @@
#!/usr/bin/env python3
import argparse
import logging
import sys
import traceback
import operator
import numpy as np
from .info import info
from .test import test
from .train import train
logging.basicConfig()
logger = logging.getLogger("Slither")
slither_simil_usage = "USAGE" # TODO
modes = ["info", "test", "train"]
def parse_args():
parser = argparse.ArgumentParser(description='',
usage=slither_simil_usage)
parser.add_argument('mode',
help="|".join(modes))
parser.add_argument('model',
help='model.bin')
parser.add_argument('--solc',
help='solc path',
action='store',
default='solc')
parser.add_argument('--filename',
action='store',
dest='filename',
help='contract.sol')
parser.add_argument('--contract',
action='store',
dest='contract',
help='Contract')
parser.add_argument('--filter',
action='store',
dest='filter',
help='Extension to filter contracts')
parser.add_argument('--fname',
action='store',
dest='fname',
help='Function name')
parser.add_argument('--input',
action='store',
dest='input',
help='File or directory used as input')
parser.add_argument('--version',
help='displays the current version',
version="0.0",
action='version')
if len(sys.argv) == 1:
parser.print_help(sys.stderr)
sys.exit(1)
args = parser.parse_args()
return args
# endregion
###################################################################################
###################################################################################
# region Main
###################################################################################
###################################################################################
def main():
args = parse_args()
default_log = logging.INFO
logger.setLevel(default_log)
try:
mode = args.mode
if mode == "info":
info(args)
elif mode == "train":
train(args)
elif mode == "test":
test(args)
else:
logger.error('Invalid mode!. It should be one of these: %s' % ", ".join(modes))
sys.exit(-1)
except Exception:
logger.error('Error in %s' % args.filename)
logger.error(traceback.format_exc())
sys.exit(-1)
if __name__ == '__main__':
main()
# endregion

@ -0,0 +1,22 @@
import numpy as np
from .encode import encode_contract, load_contracts
def load_cache(infile, model, ext=None, solc='solc'):
cache = dict()
if infile.endswith(".npz"):
with np.load(infile) as data:
array = data['arr_0'][0]
for x,y in array:
cache[x] = y
else:
contracts = load_contracts(infile, ext=ext)
for contract in contracts:
for x,ir in encode_contract(contract, solc=solc).items():
if ir != []:
y = " ".join(ir)
cache[x] = model.get_sentence_vector(y)
return cache
def save_cache(cache, outfile):
np.savez(outfile,[np.array(list(cache.items()))])

@ -0,0 +1,168 @@
import os
import sys
from slither import Slither
from slither.slithir.operations import *
from slither.slithir.variables import *
from slither.core.declarations import *
from slither.solc_parsing.declarations.function import *
from slither.solc_parsing.variables.state_variable import *
from slither.solc_parsing.variables.local_variable import *
from slither.solc_parsing.variables.local_variable_init_from_tuple import *
def load_contracts(dirname, ext=None):
r = []
walk = list(os.walk(dirname))
for x, y, files in walk:
for f in files:
if ext is None or f.endswith(ext):
r.append(x + "/".join(y) + "/" + f)
return r
def ntype(_type):
if type(_type) is not str:
_type = str(_type)
if "struct" in _type:
return "struct"
elif "enum" in _type:
return "enum"
elif "tuple" in _type:
return "tuple"
elif "contract" in _type:
return "contract"
elif "mapping" in _type:
return "mapping"
elif "." in _type or _type[0].isupper():
return "<name>"
else:
return _type.replace(" ","_")
def encode_ir(ir):
# operations
if isinstance(ir, Assignment):
return '({}):=({})'.format(encode_ir(ir.lvalue), encode_ir(ir.rvalue))
if isinstance(ir, Index):
return 'index({})'.format(ntype(ir._type))
if isinstance(ir, Member):
return 'member' #.format(ntype(ir._type))
if isinstance(ir, Length):
return 'length'
if isinstance(ir, Balance):
return 'balance'
if isinstance(ir, Binary):
return 'binary({})'.format(ir.type_str)
if isinstance(ir, Unary):
return 'unary({})'.format(ir.type_str)
if isinstance(ir, Condition):
return 'condition({})'.format(encode_ir(ir.value))
if isinstance(ir, NewStructure):
return 'new_structure'
if isinstance(ir, NewContract):
return 'new_contract'
if isinstance(ir, NewArray):
return 'new_array({})'.format(ntype(ir._array_type))
if isinstance(ir, NewElementaryType):
return 'new_elementary({})'.format(ntype(ir._type))
if isinstance(ir, Push):
return 'push({},{})'.format(encode_ir(ir.value), encode_ir(ir.lvalue))
if isinstance(ir, Delete):
return 'delete({},{})'.format(encode_ir(ir.lvalue), encode_ir(ir.variable))
if isinstance(ir, SolidityCall):
return 'solidity_call({})'.format(ir.function.full_name)
if isinstance(ir, InternalCall):
return 'internal_call({})'.format(ntype(ir._type_call))
if isinstance(ir, EventCall): # is this useful?
return 'event'
if isinstance(ir, LibraryCall):
return 'library_call'
if isinstance(ir, InternalDynamicCall):
return 'internal_dynamic_call'
if isinstance(ir, HighLevelCall): # TODO: improve
return 'high_level_call'
if isinstance(ir, LowLevelCall): # TODO: improve
return 'low_level_call'
if isinstance(ir, TypeConversion):
return 'type_conversion({})'.format(ntype(ir.type))
if isinstance(ir, Return): # this can be improved using values
return 'return' #.format(ntype(ir.type))
if isinstance(ir, Transfer):
return 'transfer({})'.format(encode_ir(ir.call_value))
if isinstance(ir, Send):
return 'send({})'.format(encode_ir(ir.call_value))
if isinstance(ir, Unpack): # TODO: improve
return 'unpack'
if isinstance(ir, InitArray): # TODO: improve
return 'init_array'
if isinstance(ir, FunctionSolc): # TODO: investigate this
return 'function_solc'
# variables
if isinstance(ir, Constant):
return 'constant({})'.format(ntype(ir._type))
if isinstance(ir, SolidityVariableComposed):
return 'solidity_variable_composed({})'.format(ir.name)
if isinstance(ir, SolidityVariable):
return 'solidity_variable{}'.format(ir.name)
if isinstance(ir, TemporaryVariable):
return 'temporary_variable'
if isinstance(ir, ReferenceVariable):
return 'reference({})'.format(ntype(ir._type))
if isinstance(ir, LocalVariableSolc):
return 'local_solc_variable({})'.format(ir._location)
if isinstance(ir, StateVariableSolc):
return 'state_solc_variable({})'.format(ntype(ir._type))
if isinstance(ir, LocalVariableInitFromTupleSolc):
return 'local_variable_init_tuple'
if isinstance(ir, TupleVariable):
return 'tuple_variable'
# default
else:
print(type(ir),"is missing encoding!")
#sys.exit(1)
return ''
def encode_contract(filename, solc):
r = dict()
# Init slither
try:
slither = Slither(filename, solc=solc)
except:
print("Compilation failed")
return r
# Iterate over all the contracts
for contract in slither.contracts:
# Iterate over all the functions
for function in contract.functions:
# Dont explore inherited functions
if function.contract == contract:
if function.nodes == []:
continue
x = "-".join([filename,contract.name,function.name])
r[x] = []
# Iterate over the nodes of the function
for node in function.nodes:
# Print the Solidity expression of the nodes
# And the SlithIR operations
if node.expression:
#print('\tSolidity expression: {}'.format(node.expression))
#print('\tSlithIR:')
for ir in node.irs:
#print(ir)
r[x].append(encode_ir(ir))
#print('\t\t\t{}'.format(ir))
return r

@ -0,0 +1,47 @@
import logging
import sys
import traceback
from fastText import load_model
from .encode import encode_contract
logging.basicConfig()
logger = logging.getLogger("Slither")
def info(args):
try:
model = args.model
model = load_model(model)
filename = args.filename
contract = args.contract
solc = args.solc
fname = args.fname
if filename is None and contract is None and fname is None:
print(args.model,"uses the following words:")
for word in model.get_words():
print(word)
sys.exit(0)
if filename is None or contract is None or fname is None:
logger.error('The encode mode requires filename, contract and fname parameters.')
sys.exit(-1)
irs = encode_contract(filename, solc=solc)
if len(irs) == 0:
sys.exit(-1)
x = "-".join([filename,contract,fname])
y = " ".join(irs[x])
fvector = model.get_sentence_vector(y)
print("Function {} in contract {} is encoded as:".format(fname, contract))
print(y)
print(fvector)
except Exception:
logger.error('Error in %s' % args.filename)
logger.error(traceback.format_exc())
sys.exit(-1)

@ -0,0 +1,6 @@
import numpy as np
def similarity(v1, v2):
n1 = np.linalg.norm(v1)
n2 = np.linalg.norm(v2)
return np.dot(v1, v2) / n1 / n2

@ -0,0 +1,49 @@
import argparse
import logging
import sys
import traceback
import operator
import numpy as np
from fastText import load_model
from .encode import encode_contract, load_contracts
from .cache import load_cache, save_cache
from .similarity import similarity
logger = logging.getLogger("crytic-pred")
def test(args):
try:
model = args.model
model = load_model(model)
filename = args.filename
contract = args.contract
fname = args.fname
solc = args.solc
infile = args.input
ext = args.filter
if filename is None or contract is None or fname is None or infile is None:
logger.error('The test mode requires filename, contract, fname and input parameters.')
sys.exit(-1)
irs = encode_contract(filename,solc=solc)
x = "-".join([filename,contract,fname])
y = " ".join(irs[x])
fvector = model.get_sentence_vector(y)
cache = load_cache(infile, model, ext=ext, solc=solc)
#save_cache("cache.npz", cache)
r = dict()
for x,y in cache.items():
r[x] = similarity(fvector, y)
r = sorted(r.items(), key=operator.itemgetter(1), reverse=True)
for x,score in r[:10]:
print(x,score)
except Exception:
logger.error('Error in %s' % args.filename)
logger.error(traceback.format_exc())
sys.exit(-1)

@ -0,0 +1,37 @@
import argparse
import logging
import sys
import traceback
import operator
from fastText import train_unsupervised
from .encode import encode_contract, load_contracts
logger = logging.getLogger("crytic-pred")
def train(args):
try:
model_filename = args.model
solc = args.solc
dirname = args.input
if dirname is None:
logger.error('The train mode requires the directory parameter.')
sys.exit(-1)
contracts = load_contracts(dirname)
with open("data.txt", 'w') as f:
for contract in contracts:
for function,ir in encode_contract(contract,solc).items():
if ir != []:
f.write(" ".join(ir)+"\n")
model = train_unsupervised(input='data.txt', model='skipgram')
model.save_model(model_filename)
print(model.get_words())
except Exception:
logger.error('Error in %s' % args.filename)
logger.error(traceback.format_exc())
sys.exit(-1)
Loading…
Cancel
Save