mirror of https://github.com/crytic/slither
commit
c017f7aed2
@ -0,0 +1,21 @@ |
|||||||
|
#!/usr/bin/env bash |
||||||
|
|
||||||
|
### Install requisites |
||||||
|
|
||||||
|
pip3.6 install pybind11 |
||||||
|
pip3.6 install https://github.com/facebookresearch/fastText/archive/0.2.0.zip |
||||||
|
|
||||||
|
### Test slither-simil |
||||||
|
|
||||||
|
DIR_TESTS="tests/simil" |
||||||
|
slither-simil info "" --filename $DIR_TESTS/../complex_func.sol --fname Complex.complexExternalWrites --solc solc-0.4.25 > test_1.txt 2>&1 |
||||||
|
DIFF=$(diff test_1.txt "$DIR_TESTS/test_1.txt") |
||||||
|
if [ "$DIFF" != "" ] |
||||||
|
then |
||||||
|
echo "slither-simil failed" |
||||||
|
cat test_1.txt |
||||||
|
cat "$DIR_TESTS/test_1.txt" |
||||||
|
exit -1 |
||||||
|
fi |
||||||
|
|
||||||
|
rm test_1.txt |
@ -0,0 +1,2 @@ |
|||||||
|
INFO:Slither-simil:Function complexExternalWrites in contract Complex is encoded as: |
||||||
|
INFO:Slither-simil:new_contract (local_solc_variable(default)):=(temporary_variable) high_level_call high_level_call high_level_call high_level_call high_level_call new_contract (local_solc_variable(default)):=(temporary_variable) high_level_call new_contract (local_solc_variable(default)):=(temporary_variable) solidity_call(keccak256()) type_conversion(bytes4) low_level_call new_contract (local_solc_variable(default)):=(temporary_variable) solidity_call(keccak256()) type_conversion(bytes4) low_level_call |
@ -0,0 +1 @@ |
|||||||
|
from .model import load_model |
@ -0,0 +1,107 @@ |
|||||||
|
#!/usr/bin/env python3 |
||||||
|
|
||||||
|
import argparse |
||||||
|
import logging |
||||||
|
import sys |
||||||
|
import traceback |
||||||
|
import operator |
||||||
|
|
||||||
|
from crytic_compile import cryticparser |
||||||
|
|
||||||
|
from .info import info |
||||||
|
from .test import test |
||||||
|
from .train import train |
||||||
|
from .plot import plot |
||||||
|
|
||||||
|
logging.basicConfig() |
||||||
|
logger = logging.getLogger("Slither-simil") |
||||||
|
|
||||||
|
modes = ["info", "test", "train", "plot"] |
||||||
|
|
||||||
|
def parse_args(): |
||||||
|
parser = argparse.ArgumentParser(description='Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector') |
||||||
|
|
||||||
|
parser.add_argument('mode', |
||||||
|
help="|".join(modes)) |
||||||
|
|
||||||
|
parser.add_argument('model', |
||||||
|
help='model.bin') |
||||||
|
|
||||||
|
parser.add_argument('--filename', |
||||||
|
action='store', |
||||||
|
dest='filename', |
||||||
|
help='contract.sol') |
||||||
|
|
||||||
|
parser.add_argument('--fname', |
||||||
|
action='store', |
||||||
|
dest='fname', |
||||||
|
help='Target function') |
||||||
|
|
||||||
|
parser.add_argument('--ext', |
||||||
|
action='store', |
||||||
|
dest='ext', |
||||||
|
help='Extension to filter contracts') |
||||||
|
|
||||||
|
parser.add_argument('--nsamples', |
||||||
|
action='store', |
||||||
|
type=int, |
||||||
|
dest='nsamples', |
||||||
|
help='Number of contract samples used for training') |
||||||
|
|
||||||
|
parser.add_argument('--ntop', |
||||||
|
action='store', |
||||||
|
type=int, |
||||||
|
dest='ntop', |
||||||
|
default=10, |
||||||
|
help='Number of more similar contracts to show for testing') |
||||||
|
|
||||||
|
parser.add_argument('--input', |
||||||
|
action='store', |
||||||
|
dest='input', |
||||||
|
help='File or directory used as input') |
||||||
|
|
||||||
|
parser.add_argument('--version', |
||||||
|
help='displays the current version', |
||||||
|
version="0.0", |
||||||
|
action='version') |
||||||
|
|
||||||
|
cryticparser.init(parser) |
||||||
|
|
||||||
|
if len(sys.argv) == 1: |
||||||
|
parser.print_help(sys.stderr) |
||||||
|
sys.exit(1) |
||||||
|
|
||||||
|
args = parser.parse_args() |
||||||
|
return args |
||||||
|
|
||||||
|
# endregion |
||||||
|
################################################################################### |
||||||
|
################################################################################### |
||||||
|
# region Main |
||||||
|
################################################################################### |
||||||
|
################################################################################### |
||||||
|
|
||||||
|
def main(): |
||||||
|
args = parse_args() |
||||||
|
|
||||||
|
default_log = logging.INFO |
||||||
|
logger.setLevel(default_log) |
||||||
|
|
||||||
|
mode = args.mode |
||||||
|
|
||||||
|
if mode == "info": |
||||||
|
info(args) |
||||||
|
elif mode == "train": |
||||||
|
train(args) |
||||||
|
elif mode == "test": |
||||||
|
test(args) |
||||||
|
elif mode == "plot": |
||||||
|
plot(args) |
||||||
|
else: |
||||||
|
logger.error('Invalid mode!. It should be one of these: %s' % ", ".join(modes)) |
||||||
|
sys.exit(-1) |
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
main() |
||||||
|
|
||||||
|
# endregion |
@ -0,0 +1,22 @@ |
|||||||
|
import sys |
||||||
|
|
||||||
|
try: |
||||||
|
import numpy as np |
||||||
|
except ImportError: |
||||||
|
print("ERROR: in order to use slither-simil, you need to install numpy") |
||||||
|
print("$ pip3 install numpy --user\n") |
||||||
|
sys.exit(-1) |
||||||
|
|
||||||
|
def load_cache(infile, nsamples=None): |
||||||
|
cache = dict() |
||||||
|
with np.load(infile) as data: |
||||||
|
array = data['arr_0'][0] |
||||||
|
for i,(x,y) in enumerate(array): |
||||||
|
cache[x] = y |
||||||
|
if i == nsamples: |
||||||
|
break |
||||||
|
|
||||||
|
return cache |
||||||
|
|
||||||
|
def save_cache(cache, outfile): |
||||||
|
np.savez(outfile,[np.array(cache)]) |
@ -0,0 +1,214 @@ |
|||||||
|
import logging |
||||||
|
import os |
||||||
|
|
||||||
|
from slither import Slither |
||||||
|
from slither.core.declarations import Structure, Enum, SolidityVariableComposed, SolidityVariable, Function |
||||||
|
from slither.core.solidity_types import ElementaryType, ArrayType, MappingType, UserDefinedType |
||||||
|
from slither.core.variables.local_variable import LocalVariable |
||||||
|
from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple |
||||||
|
from slither.core.variables.state_variable import StateVariable |
||||||
|
from slither.slithir.operations import Assignment, Index, Member, Length, Balance, Binary, \ |
||||||
|
Unary, Condition, NewArray, NewStructure, NewContract, NewElementaryType, \ |
||||||
|
SolidityCall, Push, Delete, EventCall, LibraryCall, InternalDynamicCall, \ |
||||||
|
HighLevelCall, LowLevelCall, TypeConversion, Return, Transfer, Send, Unpack, InitArray, InternalCall |
||||||
|
from slither.slithir.variables import TemporaryVariable, TupleVariable, Constant, ReferenceVariable |
||||||
|
from .cache import load_cache |
||||||
|
|
||||||
|
simil_logger = logging.getLogger("Slither-simil") |
||||||
|
compiler_logger = logging.getLogger("CryticCompile") |
||||||
|
compiler_logger.setLevel(logging.CRITICAL) |
||||||
|
slither_logger = logging.getLogger("Slither") |
||||||
|
slither_logger.setLevel(logging.CRITICAL) |
||||||
|
|
||||||
|
def parse_target(target): |
||||||
|
if target is None: |
||||||
|
return None, None |
||||||
|
|
||||||
|
parts = target.split('.') |
||||||
|
if len(parts) == 1: |
||||||
|
return None, parts[0] |
||||||
|
elif len(parts) == 2: |
||||||
|
return parts |
||||||
|
else: |
||||||
|
simil_logger.error("Invalid target. It should be 'function' or 'Contract.function'") |
||||||
|
|
||||||
|
def load_and_encode(infile, vmodel, ext=None, nsamples=None, **kwargs): |
||||||
|
r = dict() |
||||||
|
if infile.endswith(".npz"): |
||||||
|
r = load_cache(infile, nsamples=nsamples) |
||||||
|
else: |
||||||
|
contracts = load_contracts(infile, ext=ext, nsamples=nsamples) |
||||||
|
for contract in contracts: |
||||||
|
for x,ir in encode_contract(contract, **kwargs).items(): |
||||||
|
if ir != []: |
||||||
|
y = " ".join(ir) |
||||||
|
r[x] = vmodel.get_sentence_vector(y) |
||||||
|
|
||||||
|
return r |
||||||
|
|
||||||
|
def load_contracts(dirname, ext=None, nsamples=None, **kwargs): |
||||||
|
r = [] |
||||||
|
walk = list(os.walk(dirname)) |
||||||
|
for x, y, files in walk: |
||||||
|
for f in files: |
||||||
|
if ext is None or f.endswith(ext): |
||||||
|
r.append(x + "/".join(y) + "/" + f) |
||||||
|
|
||||||
|
if nsamples is None: |
||||||
|
return r |
||||||
|
else: |
||||||
|
# TODO: shuffle |
||||||
|
return r[:nsamples] |
||||||
|
|
||||||
|
def ntype(_type): |
||||||
|
if isinstance(_type, ElementaryType): |
||||||
|
_type = str(_type) |
||||||
|
elif isinstance(_type, ArrayType): |
||||||
|
if isinstance(_type.type, ElementaryType): |
||||||
|
_type = str(_type) |
||||||
|
else: |
||||||
|
_type = "user_defined_array" |
||||||
|
elif isinstance(_type, Structure): |
||||||
|
_type = str(_type) |
||||||
|
elif isinstance(_type, Enum): |
||||||
|
_type = str(_type) |
||||||
|
elif isinstance(_type, MappingType): |
||||||
|
_type = str(_type) |
||||||
|
elif isinstance(_type, UserDefinedType): |
||||||
|
_type = "user_defined_type" # TODO: this could be Contract, Enum or Struct |
||||||
|
else: |
||||||
|
_type = str(_type) |
||||||
|
|
||||||
|
_type = _type.replace(" memory","") |
||||||
|
_type = _type.replace(" storage ref","") |
||||||
|
|
||||||
|
if "struct" in _type: |
||||||
|
return "struct" |
||||||
|
elif "enum" in _type: |
||||||
|
return "enum" |
||||||
|
elif "tuple" in _type: |
||||||
|
return "tuple" |
||||||
|
elif "contract" in _type: |
||||||
|
return "contract" |
||||||
|
elif "mapping" in _type: |
||||||
|
return "mapping" |
||||||
|
else: |
||||||
|
return _type.replace(" ","_") |
||||||
|
|
||||||
|
def encode_ir(ir): |
||||||
|
# operations |
||||||
|
if isinstance(ir, Assignment): |
||||||
|
return '({}):=({})'.format(encode_ir(ir.lvalue), encode_ir(ir.rvalue)) |
||||||
|
if isinstance(ir, Index): |
||||||
|
return 'index({})'.format(ntype(ir._type)) |
||||||
|
if isinstance(ir, Member): |
||||||
|
return 'member' #.format(ntype(ir._type)) |
||||||
|
if isinstance(ir, Length): |
||||||
|
return 'length' |
||||||
|
if isinstance(ir, Balance): |
||||||
|
return 'balance' |
||||||
|
if isinstance(ir, Binary): |
||||||
|
return 'binary({})'.format(ir.type_str) |
||||||
|
if isinstance(ir, Unary): |
||||||
|
return 'unary({})'.format(ir.type_str) |
||||||
|
if isinstance(ir, Condition): |
||||||
|
return 'condition({})'.format(encode_ir(ir.value)) |
||||||
|
if isinstance(ir, NewStructure): |
||||||
|
return 'new_structure' |
||||||
|
if isinstance(ir, NewContract): |
||||||
|
return 'new_contract' |
||||||
|
if isinstance(ir, NewArray): |
||||||
|
return 'new_array({})'.format(ntype(ir._array_type)) |
||||||
|
if isinstance(ir, NewElementaryType): |
||||||
|
return 'new_elementary({})'.format(ntype(ir._type)) |
||||||
|
if isinstance(ir, Push): |
||||||
|
return 'push({},{})'.format(encode_ir(ir.value), encode_ir(ir.lvalue)) |
||||||
|
if isinstance(ir, Delete): |
||||||
|
return 'delete({},{})'.format(encode_ir(ir.lvalue), encode_ir(ir.variable)) |
||||||
|
if isinstance(ir, SolidityCall): |
||||||
|
return 'solidity_call({})'.format(ir.function.full_name) |
||||||
|
if isinstance(ir, InternalCall): |
||||||
|
return 'internal_call({})'.format(ntype(ir._type_call)) |
||||||
|
if isinstance(ir, EventCall): # is this useful? |
||||||
|
return 'event' |
||||||
|
if isinstance(ir, LibraryCall): |
||||||
|
return 'library_call' |
||||||
|
if isinstance(ir, InternalDynamicCall): |
||||||
|
return 'internal_dynamic_call' |
||||||
|
if isinstance(ir, HighLevelCall): # TODO: improve |
||||||
|
return 'high_level_call' |
||||||
|
if isinstance(ir, LowLevelCall): # TODO: improve |
||||||
|
return 'low_level_call' |
||||||
|
if isinstance(ir, TypeConversion): |
||||||
|
return 'type_conversion({})'.format(ntype(ir.type)) |
||||||
|
if isinstance(ir, Return): # this can be improved using values |
||||||
|
return 'return' #.format(ntype(ir.type)) |
||||||
|
if isinstance(ir, Transfer): |
||||||
|
return 'transfer({})'.format(encode_ir(ir.call_value)) |
||||||
|
if isinstance(ir, Send): |
||||||
|
return 'send({})'.format(encode_ir(ir.call_value)) |
||||||
|
if isinstance(ir, Unpack): # TODO: improve |
||||||
|
return 'unpack' |
||||||
|
if isinstance(ir, InitArray): # TODO: improve |
||||||
|
return 'init_array' |
||||||
|
if isinstance(ir, Function): # TODO: investigate this |
||||||
|
return 'function_solc' |
||||||
|
|
||||||
|
# variables |
||||||
|
if isinstance(ir, Constant): |
||||||
|
return 'constant({})'.format(ntype(ir._type)) |
||||||
|
if isinstance(ir, SolidityVariableComposed): |
||||||
|
return 'solidity_variable_composed({})'.format(ir.name) |
||||||
|
if isinstance(ir, SolidityVariable): |
||||||
|
return 'solidity_variable{}'.format(ir.name) |
||||||
|
if isinstance(ir, TemporaryVariable): |
||||||
|
return 'temporary_variable' |
||||||
|
if isinstance(ir, ReferenceVariable): |
||||||
|
return 'reference({})'.format(ntype(ir._type)) |
||||||
|
if isinstance(ir, LocalVariable): |
||||||
|
return 'local_solc_variable({})'.format(ir._location) |
||||||
|
if isinstance(ir, StateVariable): |
||||||
|
return 'state_solc_variable({})'.format(ntype(ir._type)) |
||||||
|
if isinstance(ir, LocalVariableInitFromTuple): |
||||||
|
return 'local_variable_init_tuple' |
||||||
|
if isinstance(ir, TupleVariable): |
||||||
|
return 'tuple_variable' |
||||||
|
|
||||||
|
# default |
||||||
|
else: |
||||||
|
simil_logger.error(type(ir),"is missing encoding!") |
||||||
|
return '' |
||||||
|
|
||||||
|
def encode_contract(cfilename, **kwargs): |
||||||
|
r = dict() |
||||||
|
|
||||||
|
# Init slither |
||||||
|
try: |
||||||
|
slither = Slither(cfilename, **kwargs) |
||||||
|
except: |
||||||
|
simil_logger.error("Compilation failed for %s using %s", cfilename, kwargs['solc']) |
||||||
|
return r |
||||||
|
|
||||||
|
# Iterate over all the contracts |
||||||
|
for contract in slither.contracts: |
||||||
|
|
||||||
|
# Iterate over all the functions |
||||||
|
for function in contract.functions_not_inherited: |
||||||
|
|
||||||
|
if function.nodes == []: |
||||||
|
continue |
||||||
|
|
||||||
|
x = (cfilename,contract.name,function.name) |
||||||
|
|
||||||
|
r[x] = [] |
||||||
|
|
||||||
|
# Iterate over the nodes of the function |
||||||
|
for node in function.nodes: |
||||||
|
# Print the Solidity expression of the nodes |
||||||
|
# And the SlithIR operations |
||||||
|
if node.expression: |
||||||
|
for ir in node.irs: |
||||||
|
r[x].append(encode_ir(ir)) |
||||||
|
return r |
||||||
|
|
||||||
|
|
@ -0,0 +1,54 @@ |
|||||||
|
import logging |
||||||
|
import sys |
||||||
|
import os.path |
||||||
|
import traceback |
||||||
|
|
||||||
|
from .model import load_model |
||||||
|
from .encode import parse_target, encode_contract |
||||||
|
|
||||||
|
logging.basicConfig() |
||||||
|
logger = logging.getLogger("Slither-simil") |
||||||
|
|
||||||
|
def info(args): |
||||||
|
|
||||||
|
try: |
||||||
|
|
||||||
|
model = args.model |
||||||
|
if os.path.isfile(model): |
||||||
|
model = load_model(model) |
||||||
|
else: |
||||||
|
model = None |
||||||
|
|
||||||
|
filename = args.filename |
||||||
|
contract, fname = parse_target(args.fname) |
||||||
|
solc = args.solc |
||||||
|
|
||||||
|
if filename is None and contract is None and fname is None: |
||||||
|
logger.info("%s uses the following words:",args.model) |
||||||
|
for word in model.get_words(): |
||||||
|
logger.info(word) |
||||||
|
sys.exit(0) |
||||||
|
|
||||||
|
if filename is None or contract is None or fname is None: |
||||||
|
logger.error('The encode mode requires filename, contract and fname parameters.') |
||||||
|
sys.exit(-1) |
||||||
|
|
||||||
|
irs = encode_contract(filename, **vars(args)) |
||||||
|
if len(irs) == 0: |
||||||
|
sys.exit(-1) |
||||||
|
|
||||||
|
x = (filename,contract,fname) |
||||||
|
y = " ".join(irs[x]) |
||||||
|
|
||||||
|
logger.info("Function {} in contract {} is encoded as:".format(fname, contract)) |
||||||
|
logger.info(y) |
||||||
|
if model is not None: |
||||||
|
fvector = model.get_sentence_vector(y) |
||||||
|
logger.info(fvector) |
||||||
|
|
||||||
|
except Exception: |
||||||
|
logger.error('Error in %s' % args.filename) |
||||||
|
logger.error(traceback.format_exc()) |
||||||
|
sys.exit(-1) |
||||||
|
|
||||||
|
|
@ -0,0 +1,9 @@ |
|||||||
|
import sys |
||||||
|
|
||||||
|
try: |
||||||
|
from fastText import load_model |
||||||
|
from fastText import train_unsupervised |
||||||
|
except ImportError: |
||||||
|
print("ERROR: in order to use slither-simil, you need to install fastText 0.2.0:") |
||||||
|
print("$ pip3 install https://github.com/facebookresearch/fastText/archive/0.2.0.zip --user\n") |
||||||
|
sys.exit(-1) |
@ -0,0 +1,78 @@ |
|||||||
|
import logging |
||||||
|
import sys |
||||||
|
import traceback |
||||||
|
import operator |
||||||
|
import numpy as np |
||||||
|
import random |
||||||
|
|
||||||
|
from .model import load_model |
||||||
|
from .encode import load_and_encode, parse_target |
||||||
|
|
||||||
|
try: |
||||||
|
from sklearn import decomposition |
||||||
|
import matplotlib.pyplot as plt |
||||||
|
except ImportError: |
||||||
|
decomposition = None |
||||||
|
plt = None |
||||||
|
|
||||||
|
logger = logging.getLogger("Slither-simil") |
||||||
|
|
||||||
|
def plot(args): |
||||||
|
|
||||||
|
if decomposition is None or plt is None: |
||||||
|
logger.error("ERROR: In order to use plot mode in slither-simil, you need to install sklearn and matplotlib:") |
||||||
|
logger.error("$ pip3 install sklearn matplotlib --user") |
||||||
|
sys.exit(-1) |
||||||
|
|
||||||
|
try: |
||||||
|
|
||||||
|
model = args.model |
||||||
|
model = load_model(model) |
||||||
|
filename = args.filename |
||||||
|
#contract = args.contract |
||||||
|
contract, fname = parse_target(args.fname) |
||||||
|
#solc = args.solc |
||||||
|
infile = args.input |
||||||
|
#ext = args.filter |
||||||
|
#nsamples = args.nsamples |
||||||
|
|
||||||
|
if fname is None or infile is None: |
||||||
|
logger.error('The plot mode requieres fname and input parameters.') |
||||||
|
sys.exit(-1) |
||||||
|
|
||||||
|
logger.info('Loading data..') |
||||||
|
cache = load_and_encode(infile, **vars(args)) |
||||||
|
|
||||||
|
data = list() |
||||||
|
fs = list() |
||||||
|
|
||||||
|
logger.info('Procesing data..') |
||||||
|
for (f,c,n),y in cache.items(): |
||||||
|
if (c == contract or contract is None) and n == fname: |
||||||
|
fs.append(f) |
||||||
|
data.append(y) |
||||||
|
|
||||||
|
if len(data) == 0: |
||||||
|
logger.error('No contract was found with function %s', fname) |
||||||
|
sys.exit(-1) |
||||||
|
|
||||||
|
data = np.array(data) |
||||||
|
pca = decomposition.PCA(n_components=2) |
||||||
|
tdata = pca.fit_transform(data) |
||||||
|
|
||||||
|
logger.info('Plotting data..') |
||||||
|
plt.figure(figsize=(20,10)) |
||||||
|
assert(len(tdata) == len(fs)) |
||||||
|
for ([x,y],l) in zip(tdata, fs): |
||||||
|
x = random.gauss(0, 0.01) + x |
||||||
|
y = random.gauss(0, 0.01) + y |
||||||
|
plt.scatter(x, y, c='blue') |
||||||
|
plt.text(x-0.001,y+0.001, l) |
||||||
|
|
||||||
|
logger.info('Saving figure to plot.png..') |
||||||
|
plt.savefig('plot.png', bbox_inches='tight') |
||||||
|
|
||||||
|
except Exception: |
||||||
|
logger.error('Error in %s' % args.filename) |
||||||
|
logger.error(traceback.format_exc()) |
||||||
|
sys.exit(-1) |
@ -0,0 +1,6 @@ |
|||||||
|
import numpy as np |
||||||
|
|
||||||
|
def similarity(v1, v2): |
||||||
|
n1 = np.linalg.norm(v1) |
||||||
|
n2 = np.linalg.norm(v2) |
||||||
|
return np.dot(v1, v2) / n1 / n2 |
@ -0,0 +1,54 @@ |
|||||||
|
import argparse |
||||||
|
import logging |
||||||
|
import sys |
||||||
|
import traceback |
||||||
|
import operator |
||||||
|
import numpy as np |
||||||
|
|
||||||
|
from .model import load_model |
||||||
|
from .encode import encode_contract, load_and_encode, parse_target |
||||||
|
from .cache import save_cache |
||||||
|
from .similarity import similarity |
||||||
|
|
||||||
|
logger = logging.getLogger("Slither-simil") |
||||||
|
|
||||||
|
def test(args): |
||||||
|
|
||||||
|
try: |
||||||
|
model = args.model |
||||||
|
model = load_model(model) |
||||||
|
filename = args.filename |
||||||
|
contract, fname = parse_target(args.fname) |
||||||
|
infile = args.input |
||||||
|
ntop = args.ntop |
||||||
|
|
||||||
|
if filename is None or contract is None or fname is None or infile is None: |
||||||
|
logger.error('The test mode requires filename, contract, fname and input parameters.') |
||||||
|
sys.exit(-1) |
||||||
|
|
||||||
|
irs = encode_contract(filename, **vars(args)) |
||||||
|
if len(irs) == 0: |
||||||
|
sys.exit(-1) |
||||||
|
|
||||||
|
y = " ".join(irs[(filename,contract,fname)]) |
||||||
|
|
||||||
|
fvector = model.get_sentence_vector(y) |
||||||
|
cache = load_and_encode(infile, model, **vars(args)) |
||||||
|
#save_cache("cache.npz", cache) |
||||||
|
|
||||||
|
r = dict() |
||||||
|
for x,y in cache.items(): |
||||||
|
r[x] = similarity(fvector, y) |
||||||
|
|
||||||
|
r = sorted(r.items(), key=operator.itemgetter(1), reverse=True) |
||||||
|
logger.info("Reviewed %d functions, listing the %d most similar ones:", len(r), ntop) |
||||||
|
format_table = "{: <65} {: <20} {: <20} {: <10}" |
||||||
|
logger.info(format_table.format(*["filename", "contract", "function", "score"])) |
||||||
|
for x,score in r[:ntop]: |
||||||
|
score = str(round(score, 3)) |
||||||
|
logger.info(format_table.format(*(list(x)+[score]))) |
||||||
|
|
||||||
|
except Exception: |
||||||
|
logger.error('Error in %s' % args.filename) |
||||||
|
logger.error(traceback.format_exc()) |
||||||
|
sys.exit(-1) |
@ -0,0 +1,54 @@ |
|||||||
|
import argparse |
||||||
|
import logging |
||||||
|
import sys |
||||||
|
import traceback |
||||||
|
import operator |
||||||
|
import os |
||||||
|
|
||||||
|
from .model import train_unsupervised |
||||||
|
from .encode import encode_contract, load_contracts |
||||||
|
from .cache import save_cache |
||||||
|
|
||||||
|
logger = logging.getLogger("Slither-simil") |
||||||
|
|
||||||
|
def train(args): |
||||||
|
|
||||||
|
try: |
||||||
|
last_data_train_filename = "last_data_train.txt" |
||||||
|
model_filename = args.model |
||||||
|
dirname = args.input |
||||||
|
nsamples = args.nsamples |
||||||
|
|
||||||
|
if dirname is None: |
||||||
|
logger.error('The train mode requires the input parameter.') |
||||||
|
sys.exit(-1) |
||||||
|
|
||||||
|
contracts = load_contracts(dirname, **vars(args)) |
||||||
|
logger.info('Saving extracted data into %s', last_data_train_filename) |
||||||
|
cache = [] |
||||||
|
with open(last_data_train_filename, 'w') as f: |
||||||
|
for filename in contracts: |
||||||
|
#cache[filename] = dict() |
||||||
|
for (filename, contract, function), ir in encode_contract(filename, **vars(args)).items(): |
||||||
|
if ir != []: |
||||||
|
x = " ".join(ir) |
||||||
|
f.write(x+"\n") |
||||||
|
cache.append((os.path.split(filename)[-1], contract, function, x)) |
||||||
|
|
||||||
|
logger.info('Starting training') |
||||||
|
model = train_unsupervised(input=last_data_train_filename, model='skipgram') |
||||||
|
logger.info('Training complete') |
||||||
|
logger.info('Saving model') |
||||||
|
model.save_model(model_filename) |
||||||
|
|
||||||
|
for i,(filename, contract, function, irs) in enumerate(cache): |
||||||
|
cache[i] = ((filename, contract, function), model.get_sentence_vector(irs)) |
||||||
|
|
||||||
|
logger.info('Saving cache in cache.npz') |
||||||
|
save_cache(cache, "cache.npz") |
||||||
|
logger.info('Done!') |
||||||
|
|
||||||
|
except Exception: |
||||||
|
logger.error('Error in %s' % args.filename) |
||||||
|
logger.error(traceback.format_exc()) |
||||||
|
sys.exit(-1) |
Loading…
Reference in new issue