Merge pull request #1636 from crytic/simplify-enums

use enum strings instead of impl. __str__
pull/1691/head
Feist Josselin 2 years ago committed by GitHub
commit 8548bed8e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 81
      slither/core/cfg/node.py
  2. 2
      slither/core/declarations/function.py
  3. 83
      slither/slithir/operations/binary.py
  4. 14
      slither/slithir/operations/unary.py

@ -66,80 +66,41 @@ if TYPE_CHECKING:
class NodeType(Enum): class NodeType(Enum):
ENTRYPOINT = 0x0 # no expression ENTRYPOINT = "ENTRY_POINT" # no expression
# Node with expression # Nodes that may have an expression
EXPRESSION = 0x10 # normal case EXPRESSION = "EXPRESSION" # normal case
RETURN = 0x11 # RETURN may contain an expression RETURN = "RETURN" # RETURN may contain an expression
IF = 0x12 IF = "IF"
VARIABLE = 0x13 # Declaration of variable VARIABLE = "NEW VARIABLE" # Variable declaration
ASSEMBLY = 0x14 ASSEMBLY = "INLINE ASM"
IFLOOP = 0x15 IFLOOP = "IF_LOOP"
# Merging nodes # Nodes where control flow merges
# Can have phi IR operation # Can have phi IR operation
ENDIF = 0x50 # ENDIF node source mapping points to the if/else body ENDIF = "END_IF" # ENDIF node source mapping points to the if/else "body"
STARTLOOP = 0x51 # STARTLOOP node source mapping points to the entire loop body STARTLOOP = "BEGIN_LOOP" # STARTLOOP node source mapping points to the entire loop "body"
ENDLOOP = 0x52 # ENDLOOP node source mapping points to the entire loop body ENDLOOP = "END_LOOP" # ENDLOOP node source mapping points to the entire loop "body"
# Below the nodes have no expression # Below the nodes do not have an expression but are used to expression CFG structure.
# But are used to expression CFG structure
# Absorbing node # Absorbing node
THROW = 0x20 THROW = "THROW"
# Loop related nodes # Loop related nodes
BREAK = 0x31 BREAK = "BREAK"
CONTINUE = 0x32 CONTINUE = "CONTINUE"
# Only modifier node # Only modifier node
PLACEHOLDER = 0x40 PLACEHOLDER = "_"
TRY = 0x41 TRY = "TRY"
CATCH = 0x42 CATCH = "CATCH"
# Node not related to the CFG # Node not related to the CFG
# Use for state variable declaration # Use for state variable declaration
OTHER_ENTRYPOINT = 0x60 OTHER_ENTRYPOINT = "OTHER_ENTRYPOINT"
# @staticmethod
def __str__(self):
if self == NodeType.ENTRYPOINT:
return "ENTRY_POINT"
if self == NodeType.EXPRESSION:
return "EXPRESSION"
if self == NodeType.RETURN:
return "RETURN"
if self == NodeType.IF:
return "IF"
if self == NodeType.VARIABLE:
return "NEW VARIABLE"
if self == NodeType.ASSEMBLY:
return "INLINE ASM"
if self == NodeType.IFLOOP:
return "IF_LOOP"
if self == NodeType.THROW:
return "THROW"
if self == NodeType.BREAK:
return "BREAK"
if self == NodeType.CONTINUE:
return "CONTINUE"
if self == NodeType.PLACEHOLDER:
return "_"
if self == NodeType.TRY:
return "TRY"
if self == NodeType.CATCH:
return "CATCH"
if self == NodeType.ENDIF:
return "END_IF"
if self == NodeType.STARTLOOP:
return "BEGIN_LOOP"
if self == NodeType.ENDLOOP:
return "END_LOOP"
if self == NodeType.OTHER_ENTRYPOINT:
return "OTHER_ENTRYPOINT"
return f"Unknown type {hex(self.value)}"
# endregion # endregion
@ -1014,7 +975,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
additional_info += " " + str(self.expression) additional_info += " " + str(self.expression)
elif self.variable_declaration: elif self.variable_declaration:
additional_info += " " + str(self.variable_declaration) additional_info += " " + str(self.variable_declaration)
txt = str(self._node_type) + additional_info txt = self._node_type.value + additional_info
return txt return txt

@ -1378,7 +1378,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
content = "" content = ""
content += "digraph{\n" content += "digraph{\n"
for node in self.nodes: for node in self.nodes:
label = f"Node Type: {str(node.type)} {node.node_id}\n" label = f"Node Type: {node.type.value} {node.node_id}\n"
if node.expression and not skip_expressions: if node.expression and not skip_expressions:
label += f"\nEXPRESSION:\n{node.expression}\n" label += f"\nEXPRESSION:\n{node.expression}\n"
if node.irs and not skip_expressions: if node.irs and not skip_expressions:

@ -12,25 +12,25 @@ logger = logging.getLogger("BinaryOperationIR")
class BinaryType(Enum): class BinaryType(Enum):
POWER = 0 # ** POWER = "**"
MULTIPLICATION = 1 # * MULTIPLICATION = "*"
DIVISION = 2 # / DIVISION = "/"
MODULO = 3 # % MODULO = "%"
ADDITION = 4 # + ADDITION = "+"
SUBTRACTION = 5 # - SUBTRACTION = "-"
LEFT_SHIFT = 6 # << LEFT_SHIFT = "<<"
RIGHT_SHIFT = 7 # >> RIGHT_SHIFT = ">>"
AND = 8 # & AND = "&"
CARET = 9 # ^ CARET = "^"
OR = 10 # | OR = "|"
LESS = 11 # < LESS = "<"
GREATER = 12 # > GREATER = ">"
LESS_EQUAL = 13 # <= LESS_EQUAL = "<="
GREATER_EQUAL = 14 # >= GREATER_EQUAL = ">="
EQUAL = 15 # == EQUAL = "=="
NOT_EQUAL = 16 # != NOT_EQUAL = "!="
ANDAND = 17 # && ANDAND = "&&"
OROR = 18 # || OROR = "||"
@staticmethod @staticmethod
def return_bool(operation_type): def return_bool(operation_type):
@ -98,47 +98,6 @@ class BinaryType(Enum):
BinaryType.DIVISION, BinaryType.DIVISION,
] ]
def __str__(self): # pylint: disable=too-many-branches
if self == BinaryType.POWER:
return "**"
if self == BinaryType.MULTIPLICATION:
return "*"
if self == BinaryType.DIVISION:
return "/"
if self == BinaryType.MODULO:
return "%"
if self == BinaryType.ADDITION:
return "+"
if self == BinaryType.SUBTRACTION:
return "-"
if self == BinaryType.LEFT_SHIFT:
return "<<"
if self == BinaryType.RIGHT_SHIFT:
return ">>"
if self == BinaryType.AND:
return "&"
if self == BinaryType.CARET:
return "^"
if self == BinaryType.OR:
return "|"
if self == BinaryType.LESS:
return "<"
if self == BinaryType.GREATER:
return ">"
if self == BinaryType.LESS_EQUAL:
return "<="
if self == BinaryType.GREATER_EQUAL:
return ">="
if self == BinaryType.EQUAL:
return "=="
if self == BinaryType.NOT_EQUAL:
return "!="
if self == BinaryType.ANDAND:
return "&&"
if self == BinaryType.OROR:
return "||"
raise SlithIRError(f"str: Unknown operation type {self} {type(self)})")
class Binary(OperationWithLValue): class Binary(OperationWithLValue):
def __init__(self, result, left_variable, right_variable, operation_type: BinaryType): def __init__(self, result, left_variable, right_variable, operation_type: BinaryType):
@ -178,8 +137,8 @@ class Binary(OperationWithLValue):
@property @property
def type_str(self): def type_str(self):
if self.node.scope.is_checked and self._type.can_be_checked_for_overflow(): if self.node.scope.is_checked and self._type.can_be_checked_for_overflow():
return "(c)" + str(self._type) return "(c)" + self._type.value
return str(self._type) return self._type.value
def __str__(self): def __str__(self):
if isinstance(self.lvalue, ReferenceVariable): if isinstance(self.lvalue, ReferenceVariable):

@ -9,8 +9,8 @@ logger = logging.getLogger("BinaryOperationIR")
class UnaryType(Enum): class UnaryType(Enum):
BANG = 0 # ! BANG = "!"
TILD = 1 # ~ TILD = "~"
@staticmethod @staticmethod
def get_type(operation_type, isprefix): def get_type(operation_type, isprefix):
@ -21,14 +21,6 @@ class UnaryType(Enum):
return UnaryType.TILD return UnaryType.TILD
raise SlithIRError(f"get_type: Unknown operation type {operation_type}") raise SlithIRError(f"get_type: Unknown operation type {operation_type}")
def __str__(self):
if self == UnaryType.BANG:
return "!"
if self == UnaryType.TILD:
return "~"
raise SlithIRError(f"str: Unknown operation type {self}")
class Unary(OperationWithLValue): class Unary(OperationWithLValue):
def __init__(self, result, variable, operation_type): def __init__(self, result, variable, operation_type):
@ -53,7 +45,7 @@ class Unary(OperationWithLValue):
@property @property
def type_str(self): def type_str(self):
return str(self._type) return self._type.value
def __str__(self): def __str__(self):
return f"{self.lvalue} = {self.type_str} {self.rvalue} " return f"{self.lvalue} = {self.type_str} {self.rvalue} "

Loading…
Cancel
Save