diff --git a/utils/slither_format/formatters/naming_convention.py b/utils/slither_format/formatters/naming_convention.py index 307613b78..14fcccb4f 100644 --- a/utils/slither_format/formatters/naming_convention.py +++ b/utils/slither_format/formatters/naming_convention.py @@ -1,11 +1,11 @@ import re import logging from slither.core.expressions.identifier import Identifier -from slither.core.declarations import Structure -from slither.core.solidity_types import UserDefinedType from slither.slithir.operations import NewContract from slither.slithir.operations import Member from slither.visitors.expression.read_var_syntactic import ReadVarSyntactic +from slither.core.solidity_types import UserDefinedType, MappingType +from slither.core.declarations import Enum, Contract, Structure from ..exceptions import FormatError from ..utils.patches import create_patch @@ -354,7 +354,152 @@ def _create_patch_parameter_declaration(slither, result, element): ################################################################################### ################################################################################### -def _create_patch_contract_uses(slither, result, element): +# group 1: beginning of the from type +# group 2: beginning of the to type +# nested mapping are within the group 1 +RE_MAPPING = '[ ]*mapping[ ]*\([ ]*([\=\>\(\) a-zA-Z0-9\._\[\]]*)[ ]*=>[ ]*([a-zA-Z0-9\._\[\]]*)\)' + +def _explore_types(slither, result, target, convert, type, filename_source_code, start, end): + if isinstance(type, UserDefinedType): + # Patch type based on contract/enum + if isinstance(type.type, (Enum, Contract)): + if type.type == target: + old_str = type.type.name + new_str = convert(old_str) + + loc_start = start + loc_end = loc_start + len(old_str) + + create_patch(result, + filename_source_code, + loc_start, + loc_end, + old_str, + new_str) + + + else: + # Patch type based on structure + assert isinstance(type.type, Structure) + if type.type == target: + old_str = type.type.name + new_str = convert(old_str) + + loc_start = start + loc_end = loc_start + len(old_str) + + create_patch(result, + filename_source_code, + loc_start, + loc_end, + old_str, + new_str) + + # Structure contain a list of elements, that might need patching + # .elems return a list of VariableStructure + _explore_variables_declaration(slither, + type.type.elems.values(), + result, + target, + convert) + + if isinstance(type, MappingType): + # Mapping has three steps: + # Convertir the "from" type + # Convertir the "to" type + # Convertir nested type in the "from" + # Ex: mapping (mapping (badName => uint) => uint) + + # Do the comparison twice, so we can factor together the re matching + if isinstance(type.type_from, UserDefinedType) or target in [type.type_from, type.type_to]: + + old_str = type.type.name + new_str = convert(old_str) + + full_txt_start = start + full_txt_end = end + full_txt = slither.source_code[filename_source_code][full_txt_start:full_txt_end] + re_match = re.match(RE_MAPPING, full_txt) + + if type.type_from == target: + + loc_start = start + re_match.start(1) + loc_end = loc_start + len(old_str) + + create_patch(result, + filename_source_code, + loc_start, + loc_end, + old_str, + new_str) + + if type.type_to == target: + + loc_start = start + re_match.start(2) + loc_end = loc_start + len(old_str) + + create_patch(result, + filename_source_code, + loc_start, + loc_end, + old_str, + new_str) + + if isinstance(type.type_from, UserDefinedType): + loc_start = start + re_match.start(1) + loc_end = start + re_match.start(2) + _explore_types(slither, result, target, convert, type, filename_source_code, loc_start, loc_end) + + + +def _explore_variables_declaration(slither, variables, result, target, convert): + for variable in variables: + filename_source_code = variable.source_code['filename_absolute'] + full_txt_start = variable.source_code['start'] + full_txt_end = full_txt_start + variable.source_code['length'] + full_txt = slither.source_code[filename_source_code][full_txt_start:full_txt_end] + + _explore_types(slither, + result, + target, + convert, + variable.type, + filename_source_code, + full_txt_start, + variable.source_code['start'] + variable.source_code['length']) + + + if variable == target: + old_str = variable.name + new_str = convert(old_str) + + # The name is after the space + matches = re.finditer('[ ]*', full_txt) + # Look for the end offset of the largest list of ' ' + loc_start = max(matches, key=lambda x:len(x.group())).end() - 1 + loc_end = loc_start + len(old_str) + + create_patch(result, + filename_source_code, + loc_start, + loc_end, + old_str, + new_str) + + +def _explore_contract_declaration(slither, result, target, convert): + for contract in slither.derived_contracts: + _explore_variables_declaration(slither, contract.state_variable, result, target, convert) + for st in contract.structures: + _explore_variables_declaration(slither, st.elem.values(), result, target, convert) + + +def _convert_capitalize_contract(name): + return name.capitalize() + + + +def _create_patch_contract_uses_old(slither, result, element): in_file, in_file_str, old_str_of_interest, loc_start, loc_end = _unpack_info(slither, element) name = get_name(element) @@ -367,6 +512,7 @@ def _create_patch_contract_uses(slither, result, element): # To-do: Deep-check aggregate types (struct and mapping) svs = target_contract.variables for sv in svs: + print(sv.type) if (str(sv.type) == name): old_str_of_interest = in_file_str[sv.source_mapping['start']:(sv.source_mapping['start'] + sv.source_mapping['length'])]