Improve slither-check-upgradability to support Truffle directory

pull/184/head
Josselin 6 years ago
parent 6031907607
commit a269413022
  1. 22
      utils/upgradability/__main__.py
  2. 5
      utils/upgradability/compare_function_ids.py

@ -1,3 +1,4 @@
import os
import logging
import argparse
import sys
@ -8,7 +9,9 @@ from .compare_variables_order import compare_variables_order_implementation, com
from .compare_function_ids import compare_function_ids
logging.basicConfig()
logger = logging.getLogger("Slither-check-upgradability")
logging.getLogger("Slither-check-upgradability").setLevel(logging.INFO)
logging.getLogger("Slither").setLevel(logging.INFO)
def parse_args():
@ -35,23 +38,26 @@ def parse_args():
def main():
args = parse_args()
proxy = Slither(vars(args)['proxy.sol'], solc=args.solc)
proxy_filename = vars(args)['proxy.sol']
proxy = Slither(proxy_filename, is_truffle=os.path.isdir(proxy_filename), solc=args.solc, disable_solc_warnings=True)
proxy_name = args.ProxyName
v1 = Slither(vars(args)['implem.sol'], solc=args.solc)
v1_filename = vars(args)['implem.sol']
v1 = Slither(v1_filename, is_truffle=os.path.isdir(v1_filename), solc=args.solc, disable_solc_warnings=True)
v1_name = args.ContractName
last_contract = v1
last_version = v1
last_name = v1_name
if args.new_version:
v2 = Slither(args.new_version, solc=args.solc)
last_contract = v2
v2 = Slither(args.new_version, is_truffle=os.path.isdir(args.new_version), solc=args.solc, disable_solc_warnings=True)
last_version = v2
if args.new_contract_name:
last_name = args.new_contract_name
compare_function_ids(last_contract, proxy)
compare_variables_order_proxy(last_contract, last_name, proxy, proxy_name)
compare_function_ids(last_version, proxy)
compare_variables_order_proxy(last_version, last_name, proxy, proxy_name)
if args.new_version:
compare_variables_order_implementation(v1, v1_name, v2, last_name)

@ -12,11 +12,10 @@ logger = logging.getLogger("CompareFunctions")
logger.setLevel(logging.INFO)
def get_signatures(s):
functions = [contract.functions for contract in s.contracts_derived]
functions = [item for sublist in functions for item in sublist]
functions = s.functions
functions = [f.full_name for f in functions if f.visibility in ['public', 'external']]
variables = [contract.state_variables for contract in s.contracts_derived]
variables = [contract.state_variables for contract in s.contracts]
variables = [item for sublist in variables for item in sublist]
variables = [variable.name+ '()' for variable in variables if variable.visibility in ['public']]
return list(set(functions+variables))

Loading…
Cancel
Save