Refactor parsing

pull/1378/head
Simone 2 years ago
parent 6571ada9a2
commit 3278417af0
  1. 81
      slither/solc_parsing/declarations/contract.py
  2. 81
      slither/solc_parsing/declarations/using_for_top_level.py
  3. 13
      slither/solc_parsing/slither_compilation_unit_solc.py

@ -5,7 +5,7 @@ from slither.core.declarations import Modifier, Event, EnumContract, StructureCo
from slither.core.declarations.contract import Contract
from slither.core.declarations.custom_error_contract import CustomErrorContract
from slither.core.declarations.function_contract import FunctionContract
from slither.core.solidity_types import ElementaryType, TypeAliasContract
from slither.core.solidity_types import ElementaryType, TypeAliasContract, Type
from slither.core.variables.state_variable import StateVariable
from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.declarations.custom_error import CustomErrorSolc
@ -574,7 +574,7 @@ class ContractSolc(CallerContextExpression):
except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing state variable {e}")
def analyze_using_for(self):
def analyze_using_for(self): # pylint: disable=too-many-branches
try:
for father in self._contract.inheritance:
self._contract.using_for.update(father.using_for)
@ -593,18 +593,7 @@ class ContractSolc(CallerContextExpression):
)
else:
# We have a list of functions. A function can be topLevel or a library function
# at this point library function are yet to be parsed so we add the function name
# and add the real function later
for f in using_for["functionList"]:
function_name = f["function"]["name"]
if function_name.find(".") != -1:
# Library function
self._contract.using_for[type_name].append(function_name)
else:
# Top level function
for tl_function in self.compilation_unit.functions_top_level:
if tl_function.name == function_name:
self._contract.using_for[type_name].append(tl_function)
self._analyze_function_list(using_for["functionList"], type_name)
else:
for using_for in self._usingForNotParsed:
children = using_for[self.get_children()]
@ -622,34 +611,42 @@ class ContractSolc(CallerContextExpression):
except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing using for {e}")
def analyze_library_function_using_for(self):
for type_name, full_names in self._contract.using_for.items():
# If it's a string is a library function e.g. L.a
# We add the actual function and remove the string
for full_name in full_names:
if isinstance(full_name, str):
full_name_split = full_name.split(".")
# TODO this doesn't handle the case if there is an import with an alias
# e.g. MyImport.MyLib.a
if len(full_name_split) == 2:
library_name = full_name_split[0]
function_name = full_name_split[1]
# Get the library function
found = False
for c in self.compilation_unit.contracts:
if found:
break
if c.name == library_name:
for f in c.functions:
if f.name == function_name:
self._contract.using_for[type_name].append(f)
found = True
break
self._contract.using_for[type_name].remove(full_name)
else:
self.log_incorrect_parsing(
f"Expected library function instead received {full_name}"
)
def _analyze_function_list(self, function_list: List, type_name: Type):
for f in function_list:
function_name = f["function"]["name"]
if function_name.find(".") != -1:
# Library function
self._analyze_library_function(function_name, type_name)
else:
# Top level function
for tl_function in self.compilation_unit.functions_top_level:
if tl_function.name == function_name:
self._contract.using_for[type_name].append(tl_function)
def _analyze_library_function(self, function_name: str, type_name: Type) -> None:
function_name_split = function_name.split(".")
# TODO this doesn't handle the case if there is an import with an alias
# e.g. MyImport.MyLib.a
if len(function_name_split) == 2:
library_name = function_name_split[0]
function_name = function_name_split[1]
# Get the library function
found = False
for c in self.compilation_unit.contracts:
if found:
break
if c.name == library_name:
for f in c.functions:
if f.name == function_name:
self._contract.using_for[type_name].append(f)
found = True
break
if not found:
self.log_incorrect_parsing(f"Library function not found {function_name}")
else:
self.log_incorrect_parsing(
f"Expected library function instead received {function_name}"
)
def analyze_enums(self):
try:

@ -2,10 +2,11 @@
Using For Top Level module
"""
import logging
from typing import TYPE_CHECKING, Dict, Union
from typing import TYPE_CHECKING, Dict, Union, Any
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations.using_for_top_level import UsingForTopLevel
from slither.core.scope.scope import FileScope
from slither.core.solidity_types import TypeAliasTopLevel
from slither.core.declarations import (
StructureTopLevel,
@ -57,26 +58,12 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few-
if len(full_name_split) == 1:
# Top level function
function_name = full_name_split[0]
for tl_function in self.compilation_unit.functions_top_level:
if tl_function.name == function_name:
self._using_for.using_for[type_name].append(tl_function)
self._propagate_global(type_name)
break
self._analyze_top_level_function(function_name, type_name)
elif len(full_name_split) == 2:
# Library function
library_name = full_name_split[0]
function_name = full_name_split[1]
found = False
for c in self.compilation_unit.contracts:
if found:
break
if c.name == library_name:
for cf in c.functions:
if cf.name == function_name:
self._using_for.using_for[type_name].append(cf)
self._propagate_global(type_name)
found = True
break
self._analyze_library_function(function_name, library_name, type_name)
else:
# probably case if there is an import with an alias we don't handle it for now
# e.g. MyImport.MyLib.a
@ -85,6 +72,35 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few-
)
continue
def _analyze_top_level_function(
self, function_name: str, type_name: Union[TypeAliasTopLevel, UserDefinedType]
) -> None:
for tl_function in self.compilation_unit.functions_top_level:
if tl_function.name == function_name:
self._using_for.using_for[type_name].append(tl_function)
self._propagate_global(type_name)
break
def _analyze_library_function(
self,
function_name: str,
library_name: str,
type_name: Union[TypeAliasTopLevel, UserDefinedType],
) -> None:
found = False
for c in self.compilation_unit.contracts:
if found:
break
if c.name == library_name:
for cf in c.functions:
if cf.name == function_name:
self._using_for.using_for[type_name].append(cf)
self._propagate_global(type_name)
found = True
break
if not found:
LOGGER.warning(f"Library {library_name} - function {function_name} not found")
def _propagate_global(self, type_name: Union[TypeAliasTopLevel, UserDefinedType]) -> None:
if self._global:
for scope in self.compilation_unit.scopes.values():
@ -93,24 +109,29 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few-
if alias == type_name:
scope.usingFor.add(self._using_for)
elif isinstance(type_name, UserDefinedType):
underlying = type_name.type
if isinstance(underlying, StructureTopLevel):
for struct in scope.structures.values():
if struct == underlying:
scope.usingFor.add(self._using_for)
elif isinstance(underlying, EnumTopLevel):
for enum in scope.enums.values():
if enum == underlying:
scope.usingFor.add(self._using_for)
else:
LOGGER.error(
f"Error when propagating global {underlying} {type(underlying)} not a StructTopLevel or EnumTopLevel"
)
self._propagate_global_UserDefinedType(scope, type_name)
else:
LOGGER.error(
f"Error when propagating global using for {type_name} {type(type_name)}"
)
def _propagate_global_UserDefinedType(
self, scope: Dict[Any, FileScope], type_name: UserDefinedType
):
underlying = type_name.type
if isinstance(underlying, StructureTopLevel):
for struct in scope.structures.values():
if struct == underlying:
scope.usingFor.add(self._using_for)
elif isinstance(underlying, EnumTopLevel):
for enum in scope.enums.values():
if enum == underlying:
scope.usingFor.add(self._using_for)
else:
LOGGER.error(
f"Error when propagating global {underlying} {type(underlying)} not a StructTopLevel or EnumTopLevel"
)
@property
def is_compact_ast(self) -> bool:
return self._slither_parser.is_compact_ast

@ -509,11 +509,7 @@ Please rename it, this name is reserved for Slither's internals"""
# Then we analyse state variables, functions and modifiers
self._analyze_third_part(contracts_to_be_analyzed, libraries)
self._analyze_top_level_using_for()
# Convert library function (at the moment are string) in using for that specifies list of functions
# to actual function
self._analyze_library_function_using_for(contracts_to_be_analyzed)
self._analyze_using_for(contracts_to_be_analyzed)
self._parsed = True
@ -625,9 +621,11 @@ Please rename it, this name is reserved for Slither's internals"""
else:
contracts_to_be_analyzed += [contract]
def _analyze_library_function_using_for(self, contracts_to_be_analyzed: List[ContractSolc]):
def _analyze_using_for(self, contracts_to_be_analyzed: List[ContractSolc]):
self._analyze_top_level_using_for()
for c in contracts_to_be_analyzed:
c.analyze_library_function_using_for()
c.analyze_using_for()
def _analyze_enums(self, contract: ContractSolc):
# Enum must be analyzed first
@ -651,7 +649,6 @@ Please rename it, this name is reserved for Slither's internals"""
# Event can refer to struct
contract.analyze_events()
contract.analyze_using_for()
contract.analyze_custom_errors()
contract.set_is_analyzed(True)

Loading…
Cancel
Save