diff --git a/.gitignore b/.gitignore index 51873c6f2..2f5cbe034 100644 --- a/.gitignore +++ b/.gitignore @@ -46,7 +46,7 @@ nosetests.xml coverage.xml *.cover .hypothesis/ - +.vscode/ # Translations *.mo *.pot diff --git a/.travis.yml b/.travis.yml index 665ef18f1..5f5c642cf 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,11 +3,12 @@ os: - linux language: python python: - - 2.7.12 + - 3.6 branches: only: - master + - dev install: - scripts/travis_install.sh diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..88d24606b --- /dev/null +++ b/Dockerfile @@ -0,0 +1,21 @@ +FROM alpine:3.6 + +LABEL name slither +LABEL src "https://github.com/trailofbits/slither" +LABEL creator trailofbits +LABEL dockerfile_maintenance trailofbits +LABEL desc "Static Analyzer for Solidity" + +# Mostly stolen from ethereum/solc. +RUN apk add --no-cache git python3 build-base cmake boost-dev \ +&& sed -i -E -e 's/include /include /' /usr/include/boost/asio/detail/socket_types.hpp \ +&& git clone --depth 1 --recursive -b release https://github.com/ethereum/solidity \ +&& cd /solidity && cmake -DCMAKE_BUILD_TYPE=Release -DTESTS=0 -DSTATIC_LINKING=1 \ +&& cd /solidity && make solc && install -s solc/solc /usr/bin \ +&& cd / && rm -rf solidity \ +&& rm -rf /var/cache/apk/* \ +&& git clone https://github.com/trailofbits/slither.git +WORKDIR slither +RUN python3 setup.py install +ENTRYPOINT ["slither"] +CMD ["tests/uninitialized.sol"] diff --git a/README.md b/README.md index 24979729c..74baa949e 100644 --- a/README.md +++ b/README.md @@ -1,84 +1,115 @@ # Slither, the Solidity source analyzer [![Build Status](https://travis-ci.com/trailofbits/slither.svg?token=JEF97dFy1QsDCfQ2Wusd&branch=master)](https://travis-ci.com/trailofbits/slither) +[![Slack Status](https://empireslacking.herokuapp.com/badge.svg)](https://empireslacking.herokuapp.com) +[![PyPI version](https://badge.fury.io/py/slither-analyzer.svg)](https://badge.fury.io/py/slither-analyzer) -Slither is a Solidity static analysis framework written in Python 3. It provides an API to easily manipulate Solidity code. In addition to exposing a Solidity contracts AST, Slither provides many APIs to quickly check local and state variable usage. +Slither is a Solidity static analysis framework written in Python 3. It runs a suite of vulnerability detectors, prints visual information about contract details, and provides an API to easily write custom analyses. Slither enables developers to find vulnerabilities, enhance their code comphrehension, and quickly prototype custom analyses. -With Slither you can: -- Detect vulnerabilities -- Speed up your understanding of code -- Build custom analyses to answer specific questions -- Quickly prototype a new static analysis techniques +## Features -## How to install - -Slither uses Python 3.6. +* Detects vulnerable Solidity code with low false positives +* Identifies where the error condition occurs in the source code +* Easy integration into continuous integration and Truffle builds +* Built-in 'printers' quickly report crucial contract information +* Detector API to write custom analyses in Python +* Ability to analyze contracts written with Solidity >= 0.4 +* Intermediate representation ([SlithIR](https://github.com/trailofbits/slither/wiki/SlithIR)) enables simple, high-precision analyses +## Usage -```bash -$ python setup.py install +Run Slither on a Truffle application: ``` - -You may also want solc, the Solidity compiler, which can be installed using homebrew: - -```bash -$ brew update -$ brew upgrade -$ brew tap ethereum/ethereum -$ brew install solidity -$ brew linkapps solidity +truffle compile +slither . ``` -or with aptitude: - -```bash -$ sudo add-apt-repository ppa:ethereum/ethereum -$ sudo apt-get update -$ sudo apt-get install solc -``` - -## How to use - -``` -$ slither file.sol +Run Slither on a single file: ``` - -``` -$ slither examples/uninitialized.sol +$ slither tests/uninitialized.sol # argument can be file, folder or glob, be sure to quote the argument when using a glob [..] -INFO:Detectors:Uninitialized state variables in examples/uninitialized.sol, Contract: Uninitialized, Vars: destination, Used in ['transfer'] +INFO:Detectors:Uninitialized state variables in tests/uninitialized.sol, Contract: Uninitialized, Vars: destination, Used in ['transfer'] [..] ``` -If Slither is applied on a directory, it will run on every `.sol` file of the directory. +If Slither is run on a directory, it will run on every `.sol` file in the directory. -## Options +### Configuration -### Configuration * `--solc SOLC`: Path to `solc` (default 'solc') +* `--solc-args SOLC_ARGS`: Add custom solc arguments. `SOLC_ARGS` can contain multiple arguments * `--disable-solc-warnings`: Do not print solc warnings * `--solc-ast`: Use the solc AST file as input (`solc file.sol --ast-json > file.ast.json`) * `--json FILE`: Export results as JSON -* `--solc-args SOLC_ARGS`: Add custom solc arguments. `SOLC_ARGS` can contain multiple arguments. -### Analyses -* `--high`: Run only medium/high severity checks with high confidence -* `--medium`: Run only medium/high severity checks with medium confidence -* `--low`: Run only low severity checks +## Detectors + +By default, all the detectors are run. + +Num | Detector | What it Detects | Impact | Confidence +--- | --- | --- | --- | --- +1 | `suicidal` | [Functions allowing anyone to destruct the contract](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#suicidal) | High | High +2 | `uninitialized-local` | [Uninitialized local variables](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#uninitialized-local-variables) | High | High +3 | `uninitialized-state` | [Uninitialized state variables](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#uninitialized-state-variables) | High | High +4 | `uninitialized-storage` | [Uninitialized storage variables](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#uninitialized-storage-variables) | High | High +5 | `arbitrary-send` | [Functions that send ether to arbitrary destinations](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#functions-that-send-ether-to-arbitrary-destinations) | High | Medium +6 | `reentrancy` | [Reentrancy vulnerabilities](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#reentrancy-vulnerabilities) | High | Medium +7 | `locked-ether` | [Contracts that lock ether](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#contracts-that-lock-ether) | Medium | High +8 | `tx-origin` | [Dangerous usage of `tx.origin`](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#dangerous-usage-of-txorigin) | Medium | Medium +9 | `assembly` | [Assembly usage](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#assembly-usage) | Informational | High +10 | `constable-states` | [State variables that could be declared constant](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#state-variables-that-could-be-declared-constant) | Informational | High +11 | `external-function` | [Public function that could be declared as external](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#public-function-that-could-be-declared-as-external) | Informational | High +12 | `low-level-calls` | [Low level calls](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#low-level-calls) | Informational | High +13 | `naming-convention` | [Conformance to Solidity naming conventions](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#conformance-to-solidity-naming-conventions) | Informational | High +14 | `pragma` | [If different pragma directives are used](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#state-variables-that-could-be-declared-constant) | Informational | High +15 | `solc-version` | [Old versions of Solidity (< 0.4.23)](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#old-versions-of-solidity) | Informational | High +16 | `unused-state` | [Unused state variables](https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#unused-state-variables) | Informational | High + +[Contact us](https://www.trailofbits.com/contact/) to get access to additional detectors. ### Printers -* `--print-summary`: Print a summary of the contracts -* `--print-quick-summary`: Print a quick summary of the contracts -* `--print-inheritance`: Print the inheritance graph -For more information about printers, see the [Printers documentation](docs/PRINTERS.md) +To run a printer, use `--print` and a comma-separated list of printers. + +Num | Printer | Description +--- | --- | --- +1 | `call-graph` | Export the call-graph of the contracts to a dot file +2 | `contract-summary` | Print a summary of the contracts +3 | `function-summary` | Print a summary of the functions +4 | `inheritance` | Print the inheritance relations between contracts +5 | `inheritance-graph` | Export the inheritance graph of each contract to a dot file +6 | `slithir` | Print the slithIR representation of the functions +7 | `vars-and-auth` | Print the state variables written and the authorization of the functions + + +## How to install + +Slither requires Python 3.6+ and [solc](https://github.com/ethereum/solidity/), the Solidity compiler. + +### Using Pip + +``` +$ pip install slither-analyzer +``` + +### Using Git + +```bash +$ git clone https://github.com/trailofbits/slither.git && cd slither +$ python setup.py install +``` + +## Getting Help + +Feel free to stop by our [Slack channel](https://empireslacking.herokuapp.com) (#ethereum) for help using or extending Slither. + +* The [Printer documentation](https://github.com/trailofbits/slither/wiki/Printer-documentation) describes the information Slither is capable of visualizing for each contract. -## Checks available +* The [Detector documentation](https://github.com/trailofbits/slither/wiki/Adding-a-new-detector) describes how to write a new vulnerability analyses. -Check | Purpose | Severity | Confidence ---- | --- | --- | --- -`--uninitialized`| Detect uninitialized variables | High | High +* The [API documentation](https://github.com/trailofbits/slither/wiki/API-examples) describes the methods and objects available for custom analyses. +* The [SlithIR documentation](https://github.com/trailofbits/slither/wiki/SlithIR) describes the SlithIR intermediate representation. ## License -Slither is licensed and distributed under AGPLv3. [Contact us](mailto:opensource@trailofbits.com) if you're looking for an exception to the terms. +Slither is licensed and distributed under the AGPLv3 license. [Contact us](mailto:opensource@trailofbits.com) if you're looking for an exception to the terms. diff --git a/docs/PRINTERS.md b/docs/PRINTERS.md deleted file mode 100644 index 30340089c..000000000 --- a/docs/PRINTERS.md +++ /dev/null @@ -1,96 +0,0 @@ -# Slither Printers - -Slither allows printing contracts information through its printers. - -## Quick Summary -`slither.py file.sol --print-quick-summary` - -Output a quick summary of the contract. -Example: -``` -$ slither.py vulns/0x01293cd77f68341635814c35299ed30ae212789e.sol --print-quick-summary -``` - - -## Summary -`slither.py file.sol --print-summary` - -Output a summary of the contract showing for each function: -- What are the visibility and the modifiers -- What are the state variables read or written -- What are the calls - -Example: -``` -$ slither.py vulns/0x01293cd77f68341635814c35299ed30ae212789e.sol --print-summary -``` -``` -[...] - -INFO:Slither:Contract NBACrypto -Contract vars: [u'ceoAddress', u'cfoAddress', u'teams', u'players', u'teamsAreInitiated', u'playersAreInitiated', u'isPaused'] -+--------------------+------------+--------------+------------------------+---------------+----------------------------------------------+ -| Function | Visibility | Modifiers | Read | Write | Calls | -+--------------------+------------+--------------+------------------------+---------------+----------------------------------------------+ -| pauseGame | public | [u'onlyCeo'] | [] | [u'isPaused'] | [] | -| unPauseGame | public | [u'onlyCeo'] | [] | [u'isPaused'] | [] | -| GetIsPauded | public | [] | [u'isPaused'] | [] | [] | -| purchaseCountry | public | [] | [u'isPaused'] | [u'teams'] | [u'cfoAddress.transfer', u'mul'] | -| | | | | | [u'require', u'teams.ownerAddress.transfer'] | -| purchasePlayer | public | [] | [u'isPaused'] | [u'players'] | [u'cfoAddress.transfer', u'mul'] | -| | | | | | [u'require', u'teams.ownerAddress.transfer'] | -| | | | | | [u'players.ownerAddress.transfer'] | -| modifyPriceCountry | public | [] | [] | [u'teams'] | [u'require'] | -| getTeam | public | [] | [u'teams'] | [] | [] | -| getPlayer | public | [] | [u'players'] | [] | [] | -| getTeamPrice | public | [] | [] | [] | [] | -| getPlayerPrice | public | [] | [] | [] | [] | -| getTeamOwner | public | [] | [] | [] | [] | -| getPlayerOwner | public | [] | [] | [] | [] | -| mul | internal | [] | [] | [] | [u'assert'] | -| div | internal | [] | [] | [] | [] | -| InitiateTeams | public | [u'onlyCeo'] | [u'teamsAreInitiated'] | [] | [u'require', u'teams.push'] | -| addPlayer | public | [u'onlyCeo'] | [] | [] | [u'players.push'] | -+--------------------+------------+--------------+------------------------+---------------+----------------------------------------------+ -``` - -## Inheritance Graph -`slither.py file.sol --print-inheritance` - -Output a graph showing the inheritance interaction between the contracts. -Example: -``` -$ slither examples/DAO.sol --print-inheritance -[...] -INFO:PrinterInheritance:Inheritance Graph: examples/DAO.sol.dot -``` - -The output format is [dot](https://www.graphviz.org/) and can be converted to svg using: -``` -dot examples/DAO.sol.dot -Tsvg -o examples/DAO.svg -``` - -Functions in orange override a parent's functions. If a variable points to another contract, the contract type is written in blue. - - - - -## Variables written and authorization -`slither.py file.sol --print-variables-written-and-authorization` - -Print the variables written and the check on `msg.sender` of each function. -``` -... -INFO:Printers: -Contract MyNewBank -+----------+------------------------+-------------------------+ -| Function | State variable written | Condition on msg.sender | -+----------+------------------------+-------------------------+ -| kill | [] | ['msg.sender != owner'] | -| withdraw | [] | ['msg.sender != owner'] | -| init | [u'owner'] | [] | -| owned | [u'owner'] | [] | -| fallback | [u'deposits'] | [] | -+----------+------------------------+-------------------------+ -``` - diff --git a/docs/imgs/DAO.svg b/docs/imgs/DAO.svg deleted file mode 100644 index 47ea043bb..000000000 --- a/docs/imgs/DAO.svg +++ /dev/null @@ -1,221 +0,0 @@ - - - - - - -%3 - - -TokenInterface - -TokenInterface -Functions: -    balanceOf -    transfer -    transferFrom -    approve -    allowance -Public Variables: -    totalSupply -Private Variables: -    balances -    allowed - - -Token - -Token -Functions: -    balanceOf -    transfer -    transferFrom -    approve -    allowance -Modifiers: -    noEther - - -Token->TokenInterface - - - - -ManagedAccountInterface - -ManagedAccountInterface -Functions: -    payOut -Public Variables: -    owner -    payOwnerOnly -    accumulatedInput - - -ManagedAccount - -ManagedAccount -Functions: -    fallback -    payOut - - -ManagedAccount->ManagedAccountInterface - - - - -TokenCreationInterface - -TokenCreationInterface -Functions: -    createTokenProxy -    refund -    divisor -Public Variables: -    closingTime -    minTokensToCreate -    isFueled -    privateCreation -    extraBalance - (ManagedAccount) -Private Variables: -    weiGiven - - -TokenCreation - -TokenCreation -Functions: -    createTokenProxy -    refund -    divisor - - -TokenCreation->Token - - - - -TokenCreation->TokenCreationInterface - - - - -DAOInterface - -DAOInterface -Functions: -    fallback -    receiveEther -    newProposal -    checkProposalCode -    vote -    executeProposal -    splitDAO -    newContract -    changeAllowedRecipients -    changeProposalDeposit -    retrieveDAOReward -    getMyReward -    withdrawRewardFor -    transferWithoutReward -    transferFromWithoutReward -    halveMinQuorum -    numberOfProposals -    getNewDAOAddress -    isBlocked -    unblockMe -Modifiers: -    onlyTokenholders -Public Variables: -    proposals -    minQuorumDivisor -    lastTimeMinQuorumMet -    curator -    allowedRecipients -    rewardToken -    totalRewardToken -    rewardAccount - (ManagedAccount) -    DAOrewardAccount - (ManagedAccount) -    DAOpaidOut -    paidOut -    blocked -    proposalDeposit -    daoCreator - (DAO_Creator) -Private Variables: -    creationGracePeriod -    minProposalDebatePeriod -    minSplitDebatePeriod -    splitExecutionPeriod -    quorumHalvingPeriod -    executeProposalPeriod -    maxDepositDivisor -    sumOfProposalDeposits - - -DAO - -DAO -Functions: -    fallback -    receiveEther -    newProposal -    checkProposalCode -    vote -    executeProposal -    closeProposal -    splitDAO -    newContract -    retrieveDAOReward -    getMyReward -    withdrawRewardFor -    transfer -    transferWithoutReward -    transferFrom -    transferFromWithoutReward -    transferPaidOut -    changeProposalDeposit -    changeAllowedRecipients -    isRecipientAllowed -    actualBalance -    minQuorum -    halveMinQuorum -    createNewDAO -    numberOfProposals -    getNewDAOAddress -    isBlocked -    unblockMe -Modifiers: -    onlyTokenholders - - -DAO->Token - - - - -DAO->TokenCreation - - - - -DAO->DAOInterface - - - - -DAO_Creator - -DAO_Creator -Functions: -    createDAO - - - diff --git a/docs/imgs/quick-summary.png b/docs/imgs/quick-summary.png deleted file mode 100644 index ed978e192..000000000 Binary files a/docs/imgs/quick-summary.png and /dev/null differ diff --git a/examples/bugs/uninitialized.sol b/examples/bugs/uninitialized.sol deleted file mode 100644 index 89126a4f3..000000000 --- a/examples/bugs/uninitialized.sol +++ /dev/null @@ -1,11 +0,0 @@ -contract Uninitialized{ - - - address destination; - - function transfer() payable{ - - destination.transfer(msg.value); - } - -} diff --git a/examples/printers/authorization.sol b/examples/printers/authorization.sol new file mode 100644 index 000000000..a4b361754 --- /dev/null +++ b/examples/printers/authorization.sol @@ -0,0 +1,25 @@ +pragma solidity ^0.4.24; +contract Owner{ + + address owner; + + modifier onlyOwner(){ + require(msg.sender == owner); + _; + } + +} + +contract MyContract is Owner{ + + mapping(address => uint) balances; + + constructor() public{ + owner = msg.sender; + } + + function mint(uint value) onlyOwner public{ + balances[msg.sender] += value; + } + +} diff --git a/examples/printers/call_graph.sol b/examples/printers/call_graph.sol new file mode 100644 index 000000000..182ccbf52 --- /dev/null +++ b/examples/printers/call_graph.sol @@ -0,0 +1,34 @@ +library Library { + function library_func() { + } +} + +contract ContractA { + uint256 public val = 0; + + function my_func_a() { + keccak256(0); + Library.library_func(); + } +} + +contract ContractB { + ContractA a; + + constructor() { + a = new ContractA(); + } + + function my_func_b() { + a.my_func_a(); + my_second_func_b(); + } + + function my_func_a() { + my_second_func_b(); + } + + function my_second_func_b(){ + a.val(); + } +} \ No newline at end of file diff --git a/examples/printers/call_graph.sol.dot b/examples/printers/call_graph.sol.dot new file mode 100644 index 000000000..5677b7a40 --- /dev/null +++ b/examples/printers/call_graph.sol.dot @@ -0,0 +1,28 @@ +strict digraph { +subgraph cluster_5_Library { +label = "Library" +"5_library_func" [label="library_func"] +} +subgraph cluster_22_ContractA { +label = "ContractA" +"22_my_func_a" [label="my_func_a"] +"22_val" [label="val"] +} +subgraph cluster_63_ContractB { +label = "ContractB" +"63_my_second_func_b" [label="my_second_func_b"] +"63_my_func_a" [label="my_func_a"] +"63_constructor" [label="constructor"] +"63_my_func_b" [label="my_func_b"] +"63_my_func_b" -> "63_my_second_func_b" +"63_my_func_a" -> "63_my_second_func_b" +} +subgraph cluster_solidity { +label = "[Solidity]" +"keccak256()" +"22_my_func_a" -> "keccak256()" +} +"22_my_func_a" -> "5_library_func" +"63_my_func_b" -> "22_my_func_a" +"63_my_second_func_b" -> "22_val" +} \ No newline at end of file diff --git a/examples/printers/call_graph.sol.dot.png b/examples/printers/call_graph.sol.dot.png new file mode 100644 index 000000000..a679acbc4 Binary files /dev/null and b/examples/printers/call_graph.sol.dot.png differ diff --git a/examples/printers/inheritances.sol b/examples/printers/inheritances.sol index a8b383c76..adda74ef7 100644 --- a/examples/printers/inheritances.sol +++ b/examples/printers/inheritances.sol @@ -1,21 +1,16 @@ -contract Contract1{ +pragma solidity ^0.4.24; - uint myvar; - - function myfunc() public{} +contract BaseContract1{ } -contract Contract2{ - - uint public myvar2; - - function myfunc2() public{} - - function privatefunc() private{} +contract BaseContract2{ } -contract Contract3 is Contract1, Contract2{ - - function myfunc() public{} // override myfunc +contract ChildContract1 is BaseContract1{ +} +contract ChildContract2 is BaseContract1, BaseContract2{ } + +contract GrandchildContract1 is ChildContract1{ +} \ No newline at end of file diff --git a/examples/printers/inheritances_graph.sol b/examples/printers/inheritances_graph.sol new file mode 100644 index 000000000..a8b383c76 --- /dev/null +++ b/examples/printers/inheritances_graph.sol @@ -0,0 +1,21 @@ +contract Contract1{ + + uint myvar; + + function myfunc() public{} +} + +contract Contract2{ + + uint public myvar2; + + function myfunc2() public{} + + function privatefunc() private{} +} + +contract Contract3 is Contract1, Contract2{ + + function myfunc() public{} // override myfunc + +} diff --git a/examples/printers/inheritances.sol.dot b/examples/printers/inheritances_graph.sol.dot similarity index 99% rename from examples/printers/inheritances.sol.dot rename to examples/printers/inheritances_graph.sol.dot index 8a22220e0..8ed2c3345 100644 --- a/examples/printers/inheritances.sol.dot +++ b/examples/printers/inheritances_graph.sol.dot @@ -1,7 +1,7 @@ digraph{ +Contract1[shape="box"label=<
Contract1
Public Functions:
myfunc()
Private Variables:
myvar
>]; +Contract2[shape="box"label=<
Contract2
Public Functions:
myfunc2()
Private Functions:
privatefunc()
Public Variables:
myvar2
>]; Contract3 -> Contract2; Contract3 -> Contract1; Contract3[shape="box"label=<
Contract3
Public Functions:
myfunc()
Public Variables:
myvar2
Private Variables:
myvar
>]; -Contract2[shape="box"label=<
Contract2
Public Functions:
myfunc2()
Private Functions:
privatefunc()
Public Variables:
myvar2
>]; -Contract1[shape="box"label=<
Contract1
Public Functions:
myfunc()
Private Variables:
myvar
>]; } \ No newline at end of file diff --git a/examples/printers/inheritances_graph.sol.png b/examples/printers/inheritances_graph.sol.png new file mode 100644 index 000000000..bed1f2a80 Binary files /dev/null and b/examples/printers/inheritances_graph.sol.png differ diff --git a/examples/printers/quick_summary.sol b/examples/printers/quick_summary.sol new file mode 100644 index 000000000..44c690a04 --- /dev/null +++ b/examples/printers/quick_summary.sol @@ -0,0 +1,13 @@ +pragma solidity 0.4.24; + +contract MyContract{ + + function myfunc() public{ + + } + + function myPrivateFunc() private{ + + } + +} diff --git a/examples/printers/quick_summary.sol.png b/examples/printers/quick_summary.sol.png new file mode 100644 index 000000000..45d2143a7 Binary files /dev/null and b/examples/printers/quick_summary.sol.png differ diff --git a/examples/printers/slihtir.sol b/examples/printers/slihtir.sol new file mode 100644 index 000000000..4f898d817 --- /dev/null +++ b/examples/printers/slihtir.sol @@ -0,0 +1,27 @@ +pragma solidity ^0.4.24; + +library UnsafeMath{ + + function add(uint a, uint b) public pure returns(uint){ + return a + b; + } + + function min(uint a, uint b) public pure returns(uint){ + return a - b; + } +} + +contract MyContract{ + using UnsafeMath for uint; + + mapping(address => uint) balances; + + function transfer(address to, uint val) public{ + + balances[msg.sender] = balances[msg.sender].min(val); + balances[to] = balances[to].add(val); + + } + + +} diff --git a/examples/scripts/convert_to_ir.py b/examples/scripts/convert_to_ir.py new file mode 100644 index 000000000..72f79ba64 --- /dev/null +++ b/examples/scripts/convert_to_ir.py @@ -0,0 +1,30 @@ +import sys +from slither.slither import Slither +from slither.slithir.convert import convert_expression + + +if len(sys.argv) != 2: + print('python function_called.py functions_called.sol') + exit(-1) + +# Init slither +slither = Slither(sys.argv[1]) + +# Get the contract +contract = slither.get_contract_from_name('Test') + +# Get the variable +test = contract.get_function_from_signature('one()') + +nodes = test.nodes + +for node in nodes: + if node.expression: + print('Expression:\n\t{}'.format(node.expression)) + irs = convert_expression(node.expression) + print('IR expressions:') + for ir in irs: + print('\t{}'.format(ir)) + print() + + diff --git a/examples/scripts/export_to_dot.py b/examples/scripts/export_to_dot.py new file mode 100644 index 000000000..628747515 --- /dev/null +++ b/examples/scripts/export_to_dot.py @@ -0,0 +1,18 @@ +import sys +from slither.slither import Slither + + +if len(sys.argv) != 2: + print('python function_called.py contract.sol') + exit(-1) + +# Init slither +slither = Slither(sys.argv[1]) + +for contract in slither.contracts: + for function in contract.functions + contract.modifiers: + filename = "{}-{}-{}.dot".format(sys.argv[1], contract.name, function.full_name) + print('Export {}'.format(filename)) + function.slithir_cfg_to_dot(filename) + + diff --git a/examples/scripts/functions_called.py b/examples/scripts/functions_called.py index afa48b870..5f25477d0 100644 --- a/examples/scripts/functions_called.py +++ b/examples/scripts/functions_called.py @@ -1,7 +1,12 @@ +import sys from slither.slither import Slither +if len(sys.argv) != 2: + print('python functions_called.py functions_called.sol') + exit(-1) + # Init slither -slither = Slither('functions_called.sol') +slither = Slither(sys.argv[1]) # Get the contract contract = slither.get_contract_from_name('Contract') @@ -9,7 +14,7 @@ contract = slither.get_contract_from_name('Contract') # Get the variable entry_point = contract.get_function_from_signature('entry_point()') -all_calls = entry_point.all_calls() +all_calls = entry_point.all_internal_calls() all_calls_formated = [f.contract.name + '.' + f.name for f in all_calls] diff --git a/examples/scripts/functions_writing.py b/examples/scripts/functions_writing.py index 04775e0b7..4609e9f6c 100644 --- a/examples/scripts/functions_writing.py +++ b/examples/scripts/functions_writing.py @@ -1,7 +1,12 @@ +import sys from slither.slither import Slither +if len(sys.argv) != 2: + print('python function_writing.py functions_writing.sol') + exit(-1) + # Init slither -slither = Slither('functions_writing.sol') +slither = Slither(sys.argv[1]) # Get the contract contract = slither.get_contract_from_name('Contract') @@ -10,7 +15,7 @@ contract = slither.get_contract_from_name('Contract') var_a = contract.get_state_variable_from_name('a') # Get the functions writing the variable -functions_writing_a = contract.get_functions_writing_variable(var_a) +functions_writing_a = contract.get_functions_writing_to_variable(var_a) # Print the result print('The function writing "a" are {}'.format([f.name for f in functions_writing_a])) diff --git a/examples/scripts/slithIR.py b/examples/scripts/slithIR.py new file mode 100644 index 000000000..04fe255c8 --- /dev/null +++ b/examples/scripts/slithIR.py @@ -0,0 +1,32 @@ +import sys +from slither import Slither + +if len(sys.argv) != 2: + print('python slithIR.py contract.sol') + exit(-1) + +# Init slither +slither = Slither(sys.argv[1]) + +# Iterate over all the contracts +for contract in slither.contracts: + + # Iterate over all the functions + for function in contract.functions: + + # Dont explore inherited functions + if function.contract == contract: + + print('Function: {}'.format(function.name)) + + # Iterate over the nodes of the function + for node in function.nodes: + + # Print the Solidity expression of the nodes + # And the SlithIR operations + if node.expression: + + print('\tSolidity expression: {}'.format(node.expression)) + print('\tSlithIR:') + for ir in node.irs: + print('\t\t\t{}'.format(ir)) diff --git a/examples/scripts/taint_mapping.py b/examples/scripts/taint_mapping.py new file mode 100644 index 000000000..db88b0d91 --- /dev/null +++ b/examples/scripts/taint_mapping.py @@ -0,0 +1,84 @@ +import sys + +from slither.core.declarations.solidity_variables import \ + SolidityVariableComposed +from slither.core.variables.state_variable import StateVariable +from slither.slither import Slither +from slither.slithir.operations.high_level_call import HighLevelCall +from slither.slithir.operations.index import Index +from slither.slithir.variables.reference import ReferenceVariable +from slither.slithir.variables.temporary import TemporaryVariable + + +def visit_node(node, visited): + if node in visited: + return + + visited += [node] + taints = node.function.slither.context[KEY] + + refs = {} + for ir in node.irs: + if isinstance(ir, Index): + refs[ir.lvalue] = ir.variable_left + + if isinstance(ir, Index): + read = [ir.variable_left] + else: + read = ir.read + print(ir) + print('Refs {}'.format(refs)) + print('Read {}'.format([str(x) for x in ir.read])) + print('Before {}'.format([str(x) for x in taints])) + if any(var_read in taints for var_read in read): + taints += [ir.lvalue] + lvalue = ir.lvalue + while isinstance(lvalue, ReferenceVariable): + taints += [refs[lvalue]] + lvalue = refs[lvalue] + + print('After {}'.format([str(x) for x in taints])) + print() + + taints = [v for v in taints if not isinstance(v, (TemporaryVariable, ReferenceVariable))] + + node.function.slither.context[KEY] = list(set(taints)) + + for son in node.sons: + visit_node(son, visited) + +def check_call(func, taints): + for node in func.nodes: + for ir in node.irs: + if isinstance(ir, HighLevelCall): + if ir.destination in taints: + print('Call to tainted address found in {}'.format(function.name)) + +if __name__ == "__main__": + if len(sys.argv) != 2: + print('python taint_mapping.py taint.sol') + exit(-1) + + # Init slither + slither = Slither(sys.argv[1]) + + initial_taint = [SolidityVariableComposed('msg.sender')] + initial_taint += [SolidityVariableComposed('msg.value')] + + KEY = 'TAINT' + + prev_taints = [] + slither.context[KEY] = initial_taint + while(set(prev_taints) != set(slither.context[KEY])): + prev_taints = slither.context[KEY] + for contract in slither.contracts: + for function in contract.functions: + print('Function {}'.format(function.name)) + slither.context[KEY] = list(set(slither.context[KEY] + function.parameters)) + visit_node(function.entry_point, []) + print('All variables tainted : {}'.format([str(v) for v in slither.context[KEY]])) + + print('All state variables tainted : {}'.format([str(v) for v in prev_taints if isinstance(v, StateVariable)])) + + for function in contract.functions: + check_call(function, slither.context[KEY]) diff --git a/examples/scripts/variable_in_condition.py b/examples/scripts/variable_in_condition.py index 58e3f16a0..d931a744b 100644 --- a/examples/scripts/variable_in_condition.py +++ b/examples/scripts/variable_in_condition.py @@ -1,7 +1,12 @@ +import sys from slither.slither import Slither +if len(sys.argv) != 2: + print('python variable_in_condition.py variable_in_condition.sol') + exit(-1) + # Init slither -slither = Slither('variable_in_condition.sol') +slither = Slither(sys.argv[1]) # Get the contract contract = slither.get_contract_from_name('Contract') @@ -10,7 +15,7 @@ contract = slither.get_contract_from_name('Contract') var_a = contract.get_state_variable_from_name('a') # Get the functions reading the variable -functions_reading_a = contract.get_functions_reading_variable(var_a) +functions_reading_a = contract.get_functions_reading_from_variable(var_a) function_using_a_as_condition = [f for f in functions_reading_a if\ f.is_reading_in_conditional_node(var_a) or\ diff --git a/plugin_example/README.md b/plugin_example/README.md new file mode 100644 index 000000000..fc711f5d1 --- /dev/null +++ b/plugin_example/README.md @@ -0,0 +1,19 @@ +# Slither, Plugin Example + +This repo contains an example of plugin for Slither. + +See the [detector documentation](https://github.com/trailofbits/slither/wiki/Adding-a-new-detector). + +## Architecture + +- `setup.py`: Contain the plugin information +- `slither_my_plugin/__init__.py`: Contain `make_plugin()`. The function must return the list of new detectors and printers +- `slither_my_plugin/detectors/example.py`: Detector plugin skeleton. + +Once these files are updated with your plugin, you can install it: +``` +python setup.py develop +``` + +We recommend to use a Python virtual environment (for example: [virtualenvwrapper](https://virtualenvwrapper.readthedocs.io/en/latest/)). + diff --git a/plugin_example/setup.py b/plugin_example/setup.py new file mode 100644 index 000000000..2890e224f --- /dev/null +++ b/plugin_example/setup.py @@ -0,0 +1,17 @@ +from setuptools import setup, find_packages + +setup( + name='slither-my-plugins', + description='This is an example of detectors and printers to Slither.', + url='https://github.com/trailofbits/slither-plugins', + author='Trail of Bits', + version='0.0', + packages=find_packages(), + python_requires='>=3.6', + install_requires=[ + 'slither-analyzer==0.1' + ], + entry_points={ + 'slither_analyzer.plugin': 'slither my-plugin=slither_my_plugin:make_plugin', + } +) diff --git a/plugin_example/slither_my_plugin/__init__.py b/plugin_example/slither_my_plugin/__init__.py new file mode 100644 index 000000000..eabdb147e --- /dev/null +++ b/plugin_example/slither_my_plugin/__init__.py @@ -0,0 +1,8 @@ +from slither_my_plugin.detectors.example import Example + + +def make_plugin(): + plugin_detectors = [Example] + plugin_printers = [] + + return plugin_detectors, plugin_printers diff --git a/plugin_example/slither_my_plugin/detectors/example.py b/plugin_example/slither_my_plugin/detectors/example.py new file mode 100644 index 000000000..be800abe8 --- /dev/null +++ b/plugin_example/slither_my_plugin/detectors/example.py @@ -0,0 +1,19 @@ + +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification + + +class Example(AbstractDetector): + """ + Documentation + """ + + ARGUMENT = 'mydetector' # slither will launch the detector with slither.py --mydetector + HELP = 'Help printed by slither' + IMPACT = DetectorClassification.HIGH + CONFIDENCE = DetectorClassification.HIGH + + def detect(self): + + self.logger('Nothing to detect!') + + return [] diff --git a/scripts/pretty_print_and_sort_json.py b/scripts/pretty_print_and_sort_json.py new file mode 100644 index 000000000..08c0002b2 --- /dev/null +++ b/scripts/pretty_print_and_sort_json.py @@ -0,0 +1,101 @@ +#!/usr/bin/python3 + +''' +the purpose of this file is to sort the json output from the detectors such that +the order is deterministic + +- the keys of a json object are sorted +- json objects in a list will be sorted based on the values of their keys +- lists of strings/numbers are sorted + +''' + +import sys +import json + +raw_json_file = sys.argv[1] +pretty_json_file = sys.argv[2] + +from collections import OrderedDict + +def create_property_val_tuple(d, props_info): + p_names = props_info[0] + p_types = props_info[1] + result = [] + for p in p_names: + if not p in d: # not all objects have the same keys + if p_types[p] is 'number': + result.append(0) # to make sorting work + if p_types[p] is 'string': + result.append("") # to make sorting work + else: + result.append(d[p]) + return tuple(result) + +def get_props_info(list_of_dicts): + found_props = set() + prop_types = dict() + + # gather all prop names + for d in list_of_dicts: + for p in d: + found_props.add(p) + + # create a copy, since we are gonna possibly remove props + props_whose_value_we_can_sort_on = set(found_props) + + # for each object, loop through list of all found property names, + # if the object contains that property, check that it's of type string or number + # if it is, save it's type (used later on for sorting with objects that don't have that property) + # if it's not of type string/number remove it from list of properties to check + # since we cannot sort on non-string/number values + for p in list(found_props): + if p in props_whose_value_we_can_sort_on: # short circuit + for d in list_of_dicts: + if p in props_whose_value_we_can_sort_on: # less shorter short circuit + if p in d: + # we ae only gonna sort key values if they are of type string or number + if not isinstance(d[p], str) and not isinstance(d[p], int): + props_whose_value_we_can_sort_on.remove(p) + # we need to store the type of the value because not each object + # in a list of output objects for 1 detector will have the same + # keys, so if we want to sort based on the values then if a certain object + # does not have a key which another object does have we are gonna + # put in 0 for number and "" for string for that key such that sorting on values + # still works + elif isinstance(d[p], str): + prop_types[p] = 'string' + elif isinstance(d[p], int): + prop_types[p] = 'number' + return (sorted(list(props_whose_value_we_can_sort_on)), prop_types) + +def order_by_prop_value(list_of_dicts): + props_info = get_props_info(list_of_dicts) + return sorted(list_of_dicts, key=lambda d: create_property_val_tuple(d, props_info)) + +def order_dict(d): + result = OrderedDict() # such that we keep the order + for k, v in sorted(d.items()): + if isinstance(v, dict): + result[k] = order_dict(v) + elif type(v) is list: + result[k] = order_list(v) + else: # string/number + result[k] = v + return result + +def order_list(l): + if not l: + return [] + if isinstance(l[0], str): # it's a list of string + return sorted(l) + elif isinstance(l[0], int): # it's a list of numbers + return sorted(l) + elif isinstance(l[0], dict): # it's a list of objects + ordered_by_key = [order_dict(v) for v in l] + ordered_by_val = order_by_prop_value(ordered_by_key) + return ordered_by_val + +with open(raw_json_file, 'r') as json_data: + with open(pretty_json_file, 'w') as out_file: + out_file.write(json.dumps(order_list(json.load(json_data)), sort_keys=False, indent=4, separators=(',',': '))) diff --git a/scripts/tests_generate_expected_json.sh b/scripts/tests_generate_expected_json.sh new file mode 100755 index 000000000..492eb5ba4 --- /dev/null +++ b/scripts/tests_generate_expected_json.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash + +DIR="$(cd "$(dirname "$0")" && pwd)" + +# generate_expected_json file.sol detectors +generate_expected_json(){ + # generate output filename + # e.g. file: uninitialized.sol detector: uninitialized-state + # ---> uninitialized.uninitialized-state.json + output_filename="$(basename $1 .sol).$2.json" + + # run slither detector on input file and save output as json + slither "$1" --disable-solc-warnings --detect "$2" --json "$DIR/tmp-gen.json" + + # convert json file to pretty print and write to destination folder + python "$DIR/pretty_print_and_sort_json.py" "$DIR/tmp-gen.json" "$DIR/../tests/expected_json/$output_filename" + + # remove the raw un-prettified json file + rm "$DIR/tmp-gen.json" +} + +generate_expected_json tests/uninitialized.sol "uninitialized-state" +generate_expected_json tests/backdoor.sol "backdoor" +generate_expected_json tests/backdoor.sol "suicidal" +generate_expected_json tests/pragma.0.4.24.sol "pragma" +generate_expected_json tests/old_solc.sol.json "solc-version" +generate_expected_json tests/reentrancy.sol "reentrancy" +generate_expected_json tests/uninitialized_storage_pointer.sol "uninitialized-storage" +generate_expected_json tests/tx_origin.sol "tx-origin" +generate_expected_json tests/unused_state.sol "unused-state" +generate_expected_json tests/locked_ether.sol "locked-ether" +generate_expected_json tests/arbitrary_send.sol "arbitrary-send" +generate_expected_json tests/inline_assembly_contract.sol "assembly" +generate_expected_json tests/inline_assembly_library.sol "assembly" +generate_expected_json tests/low_level_calls.sol "low-level-calls" +generate_expected_json tests/const_state_variables.sol "constable-states" +generate_expected_json tests/external_function.sol "external-function" +generate_expected_json tests/naming_convention.sol "naming-convention" +generate_expected_json tests/uninitialized_local_variable.sol "uninitialized-local" diff --git a/scripts/travis_install.sh b/scripts/travis_install.sh index 93a6ca45d..21387729f 100755 --- a/scripts/travis_install.sh +++ b/scripts/travis_install.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -pip install -r requirements.txt +python setup.py install function install_solc { sudo wget -O /usr/bin/solc https://github.com/ethereum/solidity/releases/download/v0.4.24/solc-static-linux diff --git a/scripts/travis_test.sh b/scripts/travis_test.sh index e09e17dbb..1b927db32 100755 --- a/scripts/travis_test.sh +++ b/scripts/travis_test.sh @@ -1,13 +1,96 @@ #!/usr/bin/env bash -./slither.py examples/bugs/uninitialized.sol --disable-solc-warnings -if [ $? -ne 1 ]; then +### Test Detectors + +DIR="$(cd "$(dirname "$0")" && pwd)" + +# test_slither file.sol detectors +test_slither(){ + + expected="$DIR/../tests/expected_json/$(basename $1 .sol).$2.json" + actual="$DIR/$(basename $1 .sol).$2.json" + + # run slither detector on input file and save output as json + slither "$1" --disable-solc-warnings --detect "$2" --json "$DIR/tmp-test.json" + + # convert json file to pretty print and write to destination folder + python "$DIR/pretty_print_and_sort_json.py" "$DIR/tmp-test.json" "$actual" + + # remove the raw un-prettified json file + rm "$DIR/tmp-test.json" + + result=$(diff "$expected" "$actual") + + if [ "$result" != "" ]; then + rm "$actual" + echo "" + echo "failed test of file: $1, detector: $2" + echo "" + echo "$result" + echo "" + exit 1 + else + rm "$actual" + fi + + # run slither detector on input file and save output as json + slither "$1" --disable-solc-warnings --detect "$2" --compact-ast --json "$DIR/tmp-test.json" + + # convert json file to pretty print and write to destination folder + python "$DIR/pretty_print_and_sort_json.py" "$DIR/tmp-test.json" "$actual" + + # remove the raw un-prettified json file + rm "$DIR/tmp-test.json" + + result=$(diff "$expected" "$actual") + + if [ "$result" != "" ]; then + rm "$actual" + echo "" + echo "failed test of file: $1, detector: $2" + echo "" + echo "$result" + echo "" + exit 1 + else + rm "$actual" + fi +} + + +test_slither tests/uninitialized.sol "uninitialized-state" +test_slither tests/backdoor.sol "backdoor" +test_slither tests/backdoor.sol "suicidal" +test_slither tests/pragma.0.4.24.sol "pragma" +test_slither tests/old_solc.sol.json "solc-version" +test_slither tests/reentrancy.sol "reentrancy" +test_slither tests/uninitialized_storage_pointer.sol "uninitialized-storage" +test_slither tests/tx_origin.sol "tx-origin" +test_slither tests/unused_state.sol "unused-state" +test_slither tests/locked_ether.sol "locked-ether" +test_slither tests/arbitrary_send.sol "arbitrary-send" +test_slither tests/inline_assembly_contract.sol "assembly" +test_slither tests/inline_assembly_library.sol "assembly" +test_slither tests/low_level_calls.sol "low-level-calls" +test_slither tests/const_state_variables.sol "constable-states" +test_slither tests/external_function.sol "external-function" +test_slither tests/naming_convention.sol "naming-convention" +#test_slither tests/complex_func.sol "complex-function" + +### Test scripts + +python examples/scripts/functions_called.py examples/scripts/functions_called.sol +if [ $? -ne 0 ]; then exit 1 fi -./slither.py examples/bugs/backdoor.sol --disable-solc-warnings -if [ $? -ne 1 ]; then +python examples/scripts/functions_writing.py examples/scripts/functions_writing.sol +if [ $? -ne 0 ]; then exit 1 fi +python examples/scripts/variable_in_condition.py examples/scripts/variable_in_condition.sol +if [ $? -ne 0 ]; then + exit 1 +fi exit 0 diff --git a/setup.py b/setup.py index b4425acf8..99368e649 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,11 @@ from setuptools import setup, find_packages setup( - name='Slither', + name='slither-analyzer', description='Slither is a Solidity static analysis framework written in Python 3.', url='https://github.com/trailofbits/slither', author='Trail of Bits', - version='0.1', + version='0.2.0', packages=find_packages(), python_requires='>=3.6', install_requires=['prettytable>=0.7.2'], diff --git a/slither/__main__.py b/slither/__main__.py index 1e9dfdf23..f39376f6b 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -1,45 +1,81 @@ #!/usr/bin/env python3 -import sys import argparse -import logging -import traceback -import os import glob import json +import logging +import os +import sys +import traceback + +from pkg_resources import iter_entry_points, require +from slither.detectors.abstract_detector import (AbstractDetector, + DetectorClassification) +from slither.printers.abstract_printer import AbstractPrinter from slither.slither import Slither -from slither.detectors.detectors import Detectors -from slither.printers.printers import Printers +from slither.utils.colors import red +from slither.utils.command_line import output_to_markdown, output_detectors, output_printers logging.basicConfig() logger = logging.getLogger("Slither") -def determineChecks(detectors, args): - if args.low: - return detectors.low - elif args.medium: - return detectors.medium + detectors.high - elif args.high: - return detectors.high - elif args.detectors_to_run: - return args.detectors_to_run - else: - return detectors.high + detectors.medium + detectors.low +def process(filename, args, detector_classes, printer_classes): + """ + The core high-level code for running Slither static analysis. + Returns: + list(result), int: Result list and number of contracts analyzed + """ + ast = '--ast-json' + if args.compact_ast: + ast = '--ast-compact-json' + slither = Slither(filename, args.solc, args.disable_solc_warnings, args.solc_args, ast) -def process(filename, args, detectors, printers): - slither = Slither(filename, args.solc, args.disable_solc_warnings, args.solc_args) - if args.printers_to_run: - [printers.run_printer(slither, p) for p in args.printers_to_run] - return [] - else: - checks = determineChecks(detectors, args) - results = [detectors.run_detector(slither, c) for c in checks] - results = [x for x in results if x] # remove empty results - results = [item for sublist in results for item in sublist] #flatten - return results + return _process(slither, detector_classes, printer_classes) + +def _process(slither, detector_classes, printer_classes): + for detector_cls in detector_classes: + slither.register_detector(detector_cls) + + for printer_cls in printer_classes: + slither.register_printer(printer_cls) + + analyzed_contracts_count = len(slither.contracts) + + results = [] + + if not printer_classes: + detector_results = slither.run_detectors() + detector_results = [x for x in detector_results if x] # remove empty results + detector_results = [item for sublist in detector_results for item in sublist] # flatten + + results.extend(detector_results) + + slither.run_printers() # Currently printers does not return results + + return results, analyzed_contracts_count + +def process_truffle(dirname, args, detector_classes, printer_classes): + if not os.path.isdir(os.path.join(dirname, 'build'))\ + or not os.path.isdir(os.path.join(dirname, 'build', 'contracts')): + logger.info(red('No truffle build directory found, did you run `truffle compile`?')) + return (0,0) + + filenames = glob.glob(os.path.join(dirname,'build','contracts', '*.json')) + + all_contracts = [] + all_filenames = [] + + for filename in filenames: + with open(filename) as f: + contract_loaded = json.load(f) + all_contracts.append(contract_loaded['ast']) + all_filenames.append(contract_loaded['sourcePath']) + + slither = Slither(all_contracts, args.solc, args.disable_solc_warnings, args.solc_args) + return _process(slither, detector_classes, printer_classes) def output_json(results, filename): @@ -53,84 +89,101 @@ def exit(results): sys.exit(len(results)) -def main(): - detectors = Detectors() - printers = Printers() - - parser = argparse.ArgumentParser(description='Slither', - usage="slither.py contract.sol [flag]") - - parser.add_argument('filename', - help='contract.sol file') - - parser.add_argument('--solc', - help='solc path', - action='store', - default='solc') +def get_detectors_and_printers(): + """ + NOTE: This contains just a few detectors and printers that we made public. + """ + from slither.detectors.examples.backdoor import Backdoor + from slither.detectors.variables.uninitialized_state_variables import UninitializedStateVarsDetection + from slither.detectors.variables.uninitialized_storage_variables import UninitializedStorageVars + from slither.detectors.variables.uninitialized_local_variables import UninitializedLocalVars + from slither.detectors.attributes.constant_pragma import ConstantPragma + from slither.detectors.attributes.old_solc import OldSolc + from slither.detectors.attributes.locked_ether import LockedEther + from slither.detectors.functions.arbitrary_send import ArbitrarySend + from slither.detectors.functions.suicidal import Suicidal + from slither.detectors.functions.complex_function import ComplexFunction + from slither.detectors.reentrancy.reentrancy import Reentrancy + from slither.detectors.variables.unused_state_variables import UnusedStateVars + from slither.detectors.variables.possible_const_state_variables import ConstCandidateStateVars + from slither.detectors.statements.tx_origin import TxOrigin + from slither.detectors.statements.assembly import Assembly + from slither.detectors.operations.low_level_calls import LowLevelCalls + from slither.detectors.naming_convention.naming_convention import NamingConvention + from slither.detectors.functions.external_function import ExternalFunction + + detectors = [Backdoor, + UninitializedStateVarsDetection, + UninitializedStorageVars, + UninitializedLocalVars, + ConstantPragma, + OldSolc, + Reentrancy, + LockedEther, + ArbitrarySend, + Suicidal, + UnusedStateVars, + TxOrigin, + Assembly, + LowLevelCalls, + NamingConvention, + ConstCandidateStateVars, + #ComplexFunction, + ExternalFunction] + + from slither.printers.summary.function import FunctionSummary + from slither.printers.summary.contract import ContractSummary + from slither.printers.inheritance.inheritance import PrinterInheritance + from slither.printers.inheritance.inheritance_graph import PrinterInheritanceGraph + from slither.printers.call.call_graph import PrinterCallGraph + from slither.printers.functions.authorization import PrinterWrittenVariablesAndAuthorization + from slither.printers.summary.slithir import PrinterSlithIR + from slither.printers.summary.human_summary import PrinterHumanSummary + + printers = [FunctionSummary, + ContractSummary, + PrinterInheritance, + PrinterInheritanceGraph, + PrinterCallGraph, + PrinterWrittenVariablesAndAuthorization, + PrinterSlithIR, + PrinterHumanSummary] + + # Handle plugins! + for entry_point in iter_entry_points(group='slither_analyzer.plugin', name=None): + make_plugin = entry_point.load() + + plugin_detectors, plugin_printers = make_plugin() + + if not all(issubclass(d, AbstractDetector) for d in plugin_detectors): + raise Exception('Error when loading plugin %s, %r is not a detector' % (entry_point, d)) + + if not all(issubclass(p, AbstractPrinter) for p in plugin_printers): + raise Exception('Error when loading plugin %s, %r is not a printer' % (entry_point, p)) + + # We convert those to lists in case someone returns a tuple + detectors += list(plugin_detectors) + printers += list(plugin_printers) + + return detectors, printers - parser.add_argument('--solc-args', - help='Add custom solc arguments. Example: --solc-args "--allow-path /tmp --evm-version byzantium".', - action='store', - default=None) - - parser.add_argument('--disable-solc-warnings', - help='Disable solc warnings', - action='store_true', - default=False) - - parser.add_argument('--solc-ast', - help='Provide the ast solc file', - action='store_true', - default=False) +def main(): + detectors, printers = get_detectors_and_printers() - parser.add_argument('--low', - help='Only low analyses', - action='store_true', - default=False) + main_impl(all_detector_classes=detectors, all_printer_classes=printers) - parser.add_argument('--medium', - help='Only medium and high analyses', - action='store_true', - default=False) - parser.add_argument('--high', - help='Only high analyses', - action='store_true', - default=False) +def main_impl(all_detector_classes, all_printer_classes): + """ + :param all_detector_classes: A list of all detectors that can be included/excluded. + :param all_printer_classes: A list of all printers that can be included. + """ + args = parse_args(all_detector_classes, all_printer_classes) - parser.add_argument('--json', - help='Export results as JSON', - action='store', - default=None) - - for detector_name, Detector in detectors.detectors.items(): - detector_arg = '--{}'.format(Detector.ARGUMENT) - detector_help = 'Detection of ' + Detector.HELP - parser.add_argument(detector_arg, - help=detector_help, - action="append_const", - dest="detectors_to_run", - const=detector_name) - - for printer_name, Printer in printers.printers.items(): - printer_arg = '--{}'.format(Printer.ARGUMENT) - printer_help = Printer.HELP - parser.add_argument(printer_arg, - help=printer_help, - action="append_const", - dest="printers_to_run", - const=printer_name) - - # Debug - parser.add_argument('--debug', - help='Debug mode', - action="store_true", - default=False) + printer_classes = choose_printers(args, all_printer_classes) + detector_classes = choose_detectors(args, all_detector_classes) - args = parser.parse_args() - default_log = logging.INFO - if args.debug: - default_log = logging.DEBUG + default_log = logging.INFO if not args.debug else logging.DEBUG for (l_name, l_level) in [('Slither', default_log), ('Contract', default_log), @@ -146,31 +199,244 @@ def main(): l.setLevel(l_level) try: - filename = sys.argv[1] + filename = args.filename + + globbed_filenames = glob.glob(filename, recursive=True) if os.path.isfile(filename): - results = process(filename, args, detectors, printers) - elif os.path.isdir(filename): + (results, number_contracts) = process(filename, args, detector_classes, printer_classes) + + elif os.path.isfile(os.path.join(filename, 'truffle.js')): + (results, number_contracts) = process_truffle(filename, args, detector_classes, printer_classes) + + elif os.path.isdir(filename) or len(globbed_filenames) > 0: extension = "*.sol" if not args.solc_ast else "*.json" filenames = glob.glob(os.path.join(filename, extension)) - results = [process(filename, args, detectors, printers) for filename in filenames] - results = [item for sublist in results for item in sublist] #flatten - #if args.json: - # output_json(results, args.json) - #exit(results) + if len(filenames) == 0: + filenames = globbed_filenames + number_contracts = 0 + results = [] + for filename in filenames: + (results_tmp, number_contracts_tmp) = process(filename, args, detector_classes, printer_classes) + number_contracts += number_contracts_tmp + results += results_tmp + + else: raise Exception("Unrecognised file/dir path: '#{filename}'".format(filename=filename)) if args.json: output_json(results, args.json) - logger.info('%s analyzed, %d result(s) found', filename, len(results)) + # Dont print the number of result for printers + if printer_classes: + logger.info('%s analyzed (%d contracts)', filename, number_contracts) + else: + logger.info('%s analyzed (%d contracts), %d result(s) found', filename, number_contracts, len(results)) exit(results) - except Exception as e: - logging.error('Error in %s'%sys.argv[1]) + except Exception: + logging.error('Error in %s' % args.filename) logging.error(traceback.format_exc()) sys.exit(-1) +def parse_args(detector_classes, printer_classes): + parser = argparse.ArgumentParser(description='Slither', + usage="slither.py contract.sol [flag]") + + parser.add_argument('filename', + help='contract.sol') + + parser.add_argument('--version', + help='displays the current version', + version=require('slither-analyzer')[0].version, + action='version') + + group_detector = parser.add_argument_group('Detectors') + group_printer = parser.add_argument_group('Printers') + group_solc = parser.add_argument_group('Solc options') + group_misc = parser.add_argument_group('Additional option') + + group_detector.add_argument('--detect', + help='Comma-separated list of detectors, defaults to all, ' + 'available detectors: {}'.format( + ', '.join(d.ARGUMENT for d in detector_classes)), + action='store', + dest='detectors_to_run', + default='all') + + group_printer.add_argument('--print', + help='Comma-separated list fo contract information printers, ' + 'available printers: {}'.format( + ', '.join(d.ARGUMENT for d in printer_classes)), + action='store', + dest='printers_to_run', + default='') + + group_detector.add_argument('--list-detectors', + help='List available detectors', + action=ListDetectors, + nargs=0, + default=False) + + group_printer.add_argument('--list-printers', + help='List available printers', + action=ListPrinters, + nargs=0, + default=False) + + + group_detector.add_argument('--exclude-detectors', + help='Comma-separated list of detectors that should be excluded', + action='store', + dest='detectors_to_exclude', + default='') + + group_detector.add_argument('--exclude-informational', + help='Exclude informational impact analyses', + action='store_true', + default=False) + + group_detector.add_argument('--exclude-low', + help='Exclude low impact analyses', + action='store_true', + default=False) + + group_detector.add_argument('--exclude-medium', + help='Exclude medium impact analyses', + action='store_true', + default=False) + + group_detector.add_argument('--exclude-high', + help='Exclude high impact analyses', + action='store_true', + default=False) + + + group_solc.add_argument('--solc', + help='solc path', + action='store', + default='solc') + + group_solc.add_argument('--solc-args', + help='Add custom solc arguments. Example: --solc-args "--allow-path /tmp --evm-version byzantium".', + action='store', + default=None) + + group_solc.add_argument('--disable-solc-warnings', + help='Disable solc warnings', + action='store_true', + default=False) + + group_solc.add_argument('--solc-ast', + help='Provide the ast solc file', + action='store_true', + default=False) + + group_misc.add_argument('--json', + help='Export results as JSON', + action='store', + default=None) + + + + # debugger command + parser.add_argument('--debug', + help=argparse.SUPPRESS, + action="store_true", + default=False) + + parser.add_argument('--markdown', + help=argparse.SUPPRESS, + action=OutputMarkdown, + nargs=0, + default=False) + + parser.add_argument('--compact-ast', + help=argparse.SUPPRESS, + action='store_true', + default=False) + + if len(sys.argv) == 1: + parser.print_help(sys.stderr) + sys.exit(1) + + args = parser.parse_args() + + return args + +class ListDetectors(argparse.Action): + def __call__(self, parser, *args, **kwargs): + detectors, _ = get_detectors_and_printers() + output_detectors(detectors) + parser.exit() + +class ListPrinters(argparse.Action): + def __call__(self, parser, *args, **kwargs): + _, printers = get_detectors_and_printers() + output_printers(printers) + parser.exit() + +class OutputMarkdown(argparse.Action): + def __call__(self, parser, *args, **kwargs): + detectors, printers = get_detectors_and_printers() + output_to_markdown(detectors, printers) + parser.exit() + + +def choose_detectors(args, all_detector_classes): + # If detectors are specified, run only these ones + + detectors_to_run = [] + detectors = {d.ARGUMENT: d for d in all_detector_classes} + + if args.detectors_to_run == 'all': + detectors_to_run = all_detector_classes + detectors_excluded = args.detectors_to_exclude.split(',') + for d in detectors: + if d in detectors_excluded: + detectors_to_run.remove(detectors[d]) + else: + for d in args.detectors_to_run.split(','): + if d in detectors: + detectors_to_run.append(detectors[d]) + else: + raise Exception('Error: {} is not a detector'.format(d)) + return detectors_to_run + + if args.exclude_informational: + detectors_to_run = [d for d in detectors_to_run if + d.IMPACT != DetectorClassification.INFORMATIONAL] + if args.exclude_low: + detectors_to_run = [d for d in detectors_to_run if + d.IMPACT != DetectorClassification.LOW] + if args.exclude_medium: + detectors_to_run = [d for d in detectors_to_run if + d.IMPACT != DetectorClassification.MEDIUM] + if args.exclude_high: + detectors_to_run = [d for d in detectors_to_run if + d.IMPACT != DetectorClassification.HIGH] + if args.detectors_to_exclude: + detectors_to_run = [d for d in detectors_to_run if + d.ARGUMENT not in args.detectors_to_exclude] + return detectors_to_run + + +def choose_printers(args, all_printer_classes): + printers_to_run = [] + + # disable default printer + if args.printers_to_run == '': + return [] + + printers = {p.ARGUMENT: p for p in all_printer_classes} + for p in args.printers_to_run.split(','): + if p in printers: + printers_to_run.append(printers[p]) + else: + raise Exception('Error: {} is not a printer'.format(p)) + return printers_to_run + + if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/slither/core/solidityTypes/__init__.py b/slither/analyses/__init__.py similarity index 100% rename from slither/core/solidityTypes/__init__.py rename to slither/analyses/__init__.py diff --git a/slither/core/sourceMapping/__init__.py b/slither/analyses/taint/__init__.py similarity index 100% rename from slither/core/sourceMapping/__init__.py rename to slither/analyses/taint/__init__.py diff --git a/slither/analyses/taint/calls.py b/slither/analyses/taint/calls.py new file mode 100644 index 000000000..ac3e08260 --- /dev/null +++ b/slither/analyses/taint/calls.py @@ -0,0 +1,58 @@ +""" + Compute taint on call + + use taint from state_variable + + call from slithIR with a taint set to yes means its destination is tainted +""" +from slither.analyses.taint.state_variables import get_taint as get_taint_state +from slither.core.declarations import SolidityVariableComposed +from slither.slithir.operations import (HighLevelCall, Index, LowLevelCall, + Member, OperationWithLValue, Send, + Transfer) +from slither.slithir.variables import ReferenceVariable + +from .common import iterate_over_irs + +KEY = 'TAINT_CALL_DESTINATION' + +def _transfer_func(ir, read, refs, taints): + if isinstance(ir, OperationWithLValue) and any(var_read in taints for var_read in read): + taints += [ir.lvalue] + lvalue = ir.lvalue + while isinstance(lvalue, ReferenceVariable): + taints += [refs[lvalue]] + lvalue = refs[lvalue] + if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)): + if ir.destination in taints: + ir.context[KEY] = True + + return taints + +def _visit_node(node, visited, taints): + if node in visited: + return + + visited += [node] + + taints = iterate_over_irs(node.irs, _transfer_func, taints) + + for son in node.sons: + _visit_node(son, visited, taints) + +def _run_taint(slither, initial_taint): + if KEY in slither.context: + return + for contract in slither.contracts: + for function in contract.functions: + if not function.is_implemented: + continue + _visit_node(function.entry_point, [], initial_taint + function.parameters) + +def run_taint(slither): + initial_taint = get_taint_state(slither) + initial_taint += [SolidityVariableComposed('msg.sender')] + + if KEY not in slither.context: + _run_taint(slither, initial_taint) + diff --git a/slither/analyses/taint/common.py b/slither/analyses/taint/common.py new file mode 100644 index 000000000..944f65bab --- /dev/null +++ b/slither/analyses/taint/common.py @@ -0,0 +1,18 @@ +from slither.slithir.operations import (Index, Member, Length, Balance) + +def iterate_over_irs(irs, transfer_func, taints): + refs = {} + for ir in irs: + if isinstance(ir, (Index, Member)): + refs[ir.lvalue] = ir.variable_left + + if isinstance(ir, (Length, Balance)): + refs[ir.lvalue] = ir.value + + if isinstance(ir, Index): + read = [ir.variable_left] + else: + read = ir.read + taints = transfer_func(ir, read, refs, taints) + return taints + diff --git a/slither/analyses/taint/specific_variable.py b/slither/analyses/taint/specific_variable.py new file mode 100644 index 000000000..fc508c963 --- /dev/null +++ b/slither/analyses/taint/specific_variable.py @@ -0,0 +1,115 @@ +""" + Compute taint from a specific variable + + Do not propagate taint on protected function or constructor + Propage to state variables + Iterate until it finding a fixpoint +""" +from slither.core.declarations.solidity_variables import SolidityVariable +from slither.core.variables.state_variable import StateVariable +from slither.core.variables.variable import Variable +from slither.slithir.operations import Index, Member, OperationWithLValue +from slither.slithir.variables import ReferenceVariable, TemporaryVariable + +from .common import iterate_over_irs + +def make_key(variable): + if isinstance(variable, Variable): + key = 'TAINT_{}'.format(id(variable)) + else: + assert isinstance(variable, SolidityVariable) + key = 'TAINT_{}{}'.format(variable.name, + str(type(variable))) + return key + +def _transfer_func_with_key(ir, read, refs, taints, key): + if isinstance(ir, OperationWithLValue) and ir.lvalue: + if any(is_tainted_from_key(var_read, key) or var_read in taints for var_read in read): + taints += [ir.lvalue] + ir.lvalue.context[key] = True + lvalue = ir.lvalue + while isinstance(lvalue, ReferenceVariable): + taints += [refs[lvalue]] + lvalue = refs[lvalue] + lvalue.context[key] = True + return taints + +def _visit_node(node, visited, key): + if node in visited: + return + + visited = visited + [node] + taints = node.function.slither.context[key] + + # taints only increase + # if we already see this node with the last taint set + # we dont need to explore itœ + if node in node.slither.context['visited_all_paths']: + if node.slither.context['visited_all_paths'][node] == taints: + return + + node.slither.context['visited_all_paths'][node] = taints + + # use of lambda function, as the key is required for this transfer_func + _transfer_func_ = lambda _ir, _read, _refs, _taints: _transfer_func_with_key(_ir, + _read, + _refs, + _taints, + key) + taints = iterate_over_irs(node.irs, _transfer_func_, taints) + + node.function.slither.context[key] = list(set(taints)) + + for son in node.sons: + _visit_node(son, visited, key) + +def run_taint(slither, taint): + + key = make_key(taint) + + # if a node was already visited by another path + # we will only explore it if the traversal brings + # new variables written + # This speedup the exploration through a light fixpoint + # Its particular useful on 'complex' functions with several loops and conditions + slither.context['visited_all_paths'] = {} + + prev_taints = [] + slither.context[key] = [taint] + # Loop until reaching a fixpoint + while(set(prev_taints) != set(slither.context[key])): + prev_taints = slither.context[key] + for contract in slither.contracts: + for function in contract.functions: + # Dont propagated taint on protected functions + if function.is_implemented and not function.is_protected(): + slither.context[key] = list(set(slither.context[key])) + _visit_node(function.entry_point, [], key) + + slither.context[key] = [v for v in prev_taints if isinstance(v, (StateVariable, SolidityVariable))] + +def is_tainted(variable, taint): + """ + Args: + variable (Variable) + taint (Variable): Root of the taint + """ + if not isinstance(variable, (Variable, SolidityVariable)): + return False + key = make_key(taint) + return (key in variable.context and variable.context[key]) or variable == taint + +def is_tainted_from_key(variable, key): + """ + Args: + variable (Variable) + key (str): key + """ + if not isinstance(variable, (Variable, SolidityVariable)): + return False + return key in variable.context and variable.context[key] + + +def get_state_variable_tainted(slither, taint): + key = make_key(taint) + return slither.context[key] diff --git a/slither/analyses/taint/state_variables.py b/slither/analyses/taint/state_variables.py new file mode 100644 index 000000000..f96853d7d --- /dev/null +++ b/slither/analyses/taint/state_variables.py @@ -0,0 +1,82 @@ +""" + Compute taint on state variables + + Do not propagate taint on protected function + Compute taint from function parameters, msg.sender and msg.value + Iterate until it finding a fixpoint + +""" +from slither.core.declarations.solidity_variables import \ + SolidityVariableComposed +from slither.core.variables.state_variable import StateVariable +from slither.slithir.operations import Index, Member, OperationWithLValue +from slither.slithir.variables import ReferenceVariable, TemporaryVariable + +from .common import iterate_over_irs +KEY = 'TAINT_STATE_VARIABLES' + +def _transfer_func(ir, read, refs, taints): + if isinstance(ir, OperationWithLValue) and any(var_read in taints for var_read in read): + taints += [ir.lvalue] + lvalue = ir.lvalue + while isinstance(lvalue, ReferenceVariable): + taints += [refs[lvalue]] + lvalue = refs[lvalue] + return taints + +def _visit_node(node, visited): + if node in visited: + return + + visited += [node] + taints = node.function.slither.context[KEY] + + taints = iterate_over_irs(node.irs, _transfer_func, taints) + + taints = [v for v in taints if not isinstance(v, (TemporaryVariable, ReferenceVariable))] + + node.function.slither.context[KEY] = list(set(taints)) + + for son in node.sons: + _visit_node(son, visited) + + +def _run_taint(slither, initial_taint): + if KEY in slither.context: + return + + prev_taints = [] + slither.context[KEY] = initial_taint + # Loop until reaching a fixpoint + while(set(prev_taints) != set(slither.context[KEY])): + prev_taints = slither.context[KEY] + for contract in slither.contracts: + for function in contract.functions: + if not function.is_implemented: + continue + # Dont propagated taint on protected functions + if not function.is_protected(): + slither.context[KEY] = list(set(slither.context[KEY] + function.parameters)) + _visit_node(function.entry_point, []) + + slither.context[KEY] = [v for v in prev_taints if isinstance(v, StateVariable)] + +def run_taint(slither, initial_taint=None): + if initial_taint is None: + initial_taint = [SolidityVariableComposed('msg.sender')] + initial_taint += [SolidityVariableComposed('msg.value')] + + if KEY not in slither.context: + _run_taint(slither, initial_taint) + +def get_taint(slither, initial_taint=None): + """ + Return the state variables tainted + Args: + slither: + initial_taint (List Variable) + Returns: + List(StateVariable) + """ + run_taint(slither, initial_taint) + return slither.context[KEY] diff --git a/slither/solcParsing/__init__.py b/slither/analyses/write/__init__.py similarity index 100% rename from slither/solcParsing/__init__.py rename to slither/analyses/write/__init__.py diff --git a/slither/analyses/write/are_variables_written.py b/slither/analyses/write/are_variables_written.py new file mode 100644 index 000000000..d644e3ac2 --- /dev/null +++ b/slither/analyses/write/are_variables_written.py @@ -0,0 +1,57 @@ +""" + Detect if all the given variables are written in all the paths of the function +""" +from slither.core.cfg.node import NodeType +from slither.core.declarations import SolidityFunction +from slither.slithir.operations import (Index, Member, OperationWithLValue, + SolidityCall, Length, Balance) +from slither.slithir.variables import ReferenceVariable + + +def _visit(node, visited, variables_written, variables_to_write): + + if node in visited: + return [] + + visited = visited + [node] + + refs = {} + for ir in node.irs: + if isinstance(ir, SolidityCall): + # TODO convert the revert to a THROW node + if ir.function in [SolidityFunction('revert(string)'), + SolidityFunction('revert()')]: + return [] + + if not isinstance(ir, OperationWithLValue): + continue + if isinstance(ir, (Index, Member)): + refs[ir.lvalue] = ir.variable_left + if isinstance(ir, (Length, Balance)): + refs[ir.lvalue] = ir.value + + variables_written = variables_written + [ir.lvalue] + lvalue = ir.lvalue + while isinstance(lvalue, ReferenceVariable): + variables_written = variables_written + [refs[lvalue]] + lvalue = refs[lvalue] + + ret = [] + if not node.sons and not node.type in [NodeType.THROW, NodeType.RETURN]: + ret += [v for v in variables_to_write if not v in variables_written] + + for son in node.sons: + ret += _visit(son, visited, variables_written, variables_to_write) + return ret + +def are_variables_written(function, variables_to_write): + """ + Return the list of variable that are not written at the end of the function + + Args: + function (Function) + variables_to_write (list Variable): variable that must be written + Returns: + list(Variable): List of variable that are not written (sublist of variables_to_write) + """ + return list(set(_visit(function.entry_point, [], [], variables_to_write))) diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index b94c1025a..dbbbb522d 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -3,19 +3,90 @@ """ import logging -from slither.core.sourceMapping.sourceMapping import SourceMapping -from slither.core.cfg.nodeType import NodeType +from slither.core.children.child_function import ChildFunction +from slither.core.declarations import Contract +from slither.core.declarations.solidity_variables import (SolidityFunction, + SolidityVariable) +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable +from slither.slithir.convert import convert_expression +from slither.slithir.operations import (Balance, HighLevelCall, Index, + InternalCall, Length, LibraryCall, + LowLevelCall, Member, + OperationWithLValue, SolidityCall) +from slither.slithir.variables import (Constant, ReferenceVariable, + TemporaryVariable, TupleVariable) +from slither.visitors.expression.expression_printer import ExpressionPrinter +from slither.visitors.expression.read_var import ReadVar +from slither.visitors.expression.write_var import WriteVar -from slither.visitors.expression.expressionPrinter import ExpressionPrinter -from slither.visitors.expression.readVar import ReadVar -from slither.visitors.expression.writeVar import WriteVar - -from slither.core.children.childFunction import ChildFunction - -from slither.core.declarations.solidityVariables import SolidityFunction logger = logging.getLogger("Node") +class NodeType: + + ENTRYPOINT = 0x0 # no expression + + # Node with expression + + EXPRESSION = 0x10 # normal case + RETURN = 0x11 # RETURN may contain an expression + IF = 0x12 + VARIABLE = 0x13 # Declaration of variable + ASSEMBLY = 0x14 + IFLOOP = 0x15 + + # Below the nodes have no expression + # But are used to expression CFG structure + + # Absorbing node + THROW = 0x20 + + # Loop related nodes + BREAK = 0x31 + CONTINUE = 0x32 + + # Only modifier node + PLACEHOLDER = 0x40 + + # Merging nodes + # Unclear if they will be necessary + ENDIF = 0x50 + STARTLOOP = 0x51 + ENDLOOP = 0x52 + +# @staticmethod + def str(t): + if t == 0x0: + return 'ENTRY_POINT' + if t == 0x10: + return 'EXPRESSION' + if t == 0x11: + return 'RETURN' + if t == 0x12: + return 'IF' + if t == 0x13: + return 'NEW VARIABLE' + if t == 0x14: + return 'INLINE ASM' + if t == 0x15: + return 'IF_LOOP' + if t == 0x20: + return 'THROW' + if t == 0x31: + return 'BREAK' + if t == 0x32: + return 'CONTINUE' + if t == 0x40: + return '_' + if t == 0x50: + return 'END_IF' + if t == 0x51: + return 'BEGIN_LOOP' + if t == 0x52: + return 'END_LOOP' + return 'Unknown type {}'.format(hex(t)) + def link_nodes(n1, n2): n1.add_son(n2) n2.add_father(n1) @@ -36,7 +107,12 @@ class Node(SourceMapping, ChildFunction): self._node_id = node_id self._vars_written = [] self._vars_read = [] - self._calls = [] + self._internal_calls = [] + self._solidity_calls = [] + self._high_level_calls = [] + self._low_level_calls = [] + self._external_calls_as_expressions = [] + self._irs = [] self._state_vars_written = [] self._state_vars_read = [] @@ -46,6 +122,10 @@ class Node(SourceMapping, ChildFunction): self._expression_vars_read = [] self._expression_calls = [] + @property + def slither(self): + return self.function.slither + @property def node_id(self): """Unique node id.""" @@ -63,21 +143,21 @@ class Node(SourceMapping, ChildFunction): """ list(Variable): Variables read (local/state/solidity) """ - return self._vars_read + return list(self._vars_read) @property def state_variables_read(self): """ list(StateVariable): State variables read """ - return self._state_vars_read + return list(self._state_vars_read) @property def solidity_variables_read(self): """ list(SolidityVariable): State variables read """ - return self._solidity_vars_read + return list(self._solidity_vars_read) @property def variables_read_as_expression(self): @@ -88,29 +168,63 @@ class Node(SourceMapping, ChildFunction): """ list(Variable): Variables written (local/state/solidity) """ - return self._vars_written + return list(self._vars_written) @property def state_variables_written(self): """ list(StateVariable): State variables written """ - return self._state_vars_written + return list(self._state_vars_written) @property def variables_written_as_expression(self): return self._expression_vars_written @property - def calls(self): + def internal_calls(self): + """ + list(Function or SolidityFunction): List of internal/soldiity function calls + """ + return list(self._internal_calls) + + @property + def solidity_calls(self): + """ + list(SolidityFunction): List of Soldity calls + """ + return list(self._internal_calls) + + @property + def high_level_calls(self): + """ + list((Contract, Function|Variable)): + List of high level calls (external calls). + A variable is called in case of call to a public state variable + Include library calls + """ + return list(self._high_level_calls) + + @property + def low_level_calls(self): + """ + list((Variable|SolidityVariable, str)): List of low_level call + A low level call is defined by + - the variable called + - the name of the function (call/delegatecall/codecall) + """ + return list(self._low_level_calls) + + @property + def external_calls_as_expressions(self): """ - list(Function or SolidityFunction): List of calls + list(CallExpression): List of message calls (that creates a transaction) """ - return self._calls + return self._external_calls_as_expressions @property def calls_as_expression(self): - return self._expression_calls + return list(self._expression_calls) @property def expression(self): @@ -126,7 +240,16 @@ class Node(SourceMapping, ChildFunction): def add_variable_declaration(self, var): assert self._variable_declaration is None self._variable_declaration = var - self._expression = var.expression + if var.expression: + self._vars_written += [var] + + @property + def variable_declaration(self): + """ + Returns: + LocalVariable + """ + return self._variable_declaration def __str__(self): txt = NodeType.str(self._node_type) + ' '+ str(self.expression) @@ -138,19 +261,25 @@ class Node(SourceMapping, ChildFunction): Returns: bool: True if the node has a require or assert call """ - return self.calls and\ - any(isinstance(c, SolidityFunction) and\ - (c.name in ['require(bool)', 'require(bool,string)', 'assert(bool)'])\ - for c in self.calls) + return any(c.name in ['require(bool)', 'require(bool,string)', 'assert(bool)'] for c in self.internal_calls) def contains_if(self): """ - Check if the node is a conditional node + Check if the node is a IF node Returns: bool: True if the node is a conditional node (IF or IFLOOP) """ return self.type in [NodeType.IF, NodeType.IFLOOP] + def is_conditional(self): + """ + Check if the node is a conditional node + A conditional node is either a IF or a require/assert + Returns: + bool: True if the node is a conditional node + """ + return self.contains_if() or self.contains_require_or_assert() + def add_father(self, father): """ Add a father node @@ -172,9 +301,10 @@ class Node(SourceMapping, ChildFunction): """ Returns the father nodes Returns: - fathers: list of fathers + list(Node): list of fathers """ - return self._fathers + return list(self._fathers) + def remove_father(self, father): """ Remove the father node. Do nothing if the node is not a father @@ -184,6 +314,13 @@ class Node(SourceMapping, ChildFunction): """ self._fathers = [x for x in self._fathers if x.node_id != father.node_id] + def remove_son(self, son): + """ Remove the son node. Do nothing if the node is not a son + + Args: + fathers: list of fathers to add + """ + self._sons = [x for x in self._sons if x.node_id != son.node_id] def add_son(self, son): """ Add a son node @@ -206,7 +343,69 @@ class Node(SourceMapping, ChildFunction): """ Returns the son nodes Returns: - sons: list of sons + list(Node): list of sons + """ + return list(self._sons) + + @property + def irs(self): + """ Returns the slithIR representation + + return + list(slithIR.Operation) """ - return self._sons + return self._irs + + def slithir_generation(self): + if self.expression: + expression = self.expression + self._irs = convert_expression(expression, self) + + self._find_read_write_call() + + def _find_read_write_call(self): + + def is_slithir_var(var): + return isinstance(var, (Constant, ReferenceVariable, TemporaryVariable, TupleVariable)) + for ir in self.irs: + self._vars_read += [v for v in ir.read if not is_slithir_var(v)] + if isinstance(ir, OperationWithLValue): + if isinstance(ir, (Index, Member, Length, Balance)): + continue # Don't consider Member and Index operations -> ReferenceVariable + var = ir.lvalue + # If its a reference, we loop until finding the origin + if isinstance(var, (ReferenceVariable)): + while isinstance(var, ReferenceVariable): + var = var.points_to + # Only store non-slithIR variables + if not is_slithir_var(var): + self._vars_written.append(var) + + if isinstance(ir, InternalCall): + self._internal_calls.append(ir.function) + if isinstance(ir, SolidityCall): + # TODO: consider removing dependancy of solidity_call to internal_call + self._solidity_calls.append(ir.function) + self._internal_calls.append(ir.function) + if isinstance(ir, LowLevelCall): + assert isinstance(ir.destination, (Variable, SolidityVariable)) + self._low_level_calls.append((ir.destination, ir.function_name.value)) + elif isinstance(ir, (HighLevelCall)) and not isinstance(ir, LibraryCall): + if isinstance(ir.destination.type, Contract): + self._high_level_calls.append((ir.destination.type, ir.function)) + else: + self._high_level_calls.append((ir.destination.type.type, ir.function)) + elif isinstance(ir, LibraryCall): + assert isinstance(ir.destination, Contract) + self._high_level_calls.append((ir.destination, ir.function)) + + self._vars_read = list(set(self._vars_read)) + self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] + self._solidity_vars_read = [v for v in self._vars_read if isinstance(v, SolidityVariable)] + self._vars_written = list(set(self._vars_written)) + self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] + self._internal_calls = list(set(self._internal_calls)) + self._solidity_calls = list(set(self._solidity_calls)) + self._high_level_calls = list(set(self._high_level_calls)) + self._low_level_calls = list(set(self._low_level_calls)) diff --git a/slither/core/cfg/nodeType.py b/slither/core/cfg/nodeType.py deleted file mode 100644 index 0ba09db31..000000000 --- a/slither/core/cfg/nodeType.py +++ /dev/null @@ -1,63 +0,0 @@ -class NodeType: - - ENTRYPOINT = 0x0 # no expression - - # Node with expression - - EXPRESSION = 0x10 # normal case - RETURN = 0x11 # RETURN may contain an expression - IF = 0x12 - VARIABLE = 0x13 # Declaration of variable - ASSEMBLY = 0x14 - IFLOOP = 0x15 - - # Below the nodes have no expression - # But are used to expression CFG structure - - # Absorbing node - THROW = 0x20 - - # Loop related nodes - BREAK = 0x31 - CONTINUE = 0x32 - - # Only modifier node - PLACEHOLDER = 0x40 - - # Merging nodes - # Unclear if they will be necessary - ENDIF = 0x50 - STARTLOOP = 0x51 - ENDLOOP = 0x52 - - @staticmethod - def str(t): - if t == 0x0: - return 'EntryPoint' - if t == 0x10: - return 'Expressions' - if t == 0x11: - return 'Return' - if t == 0x12: - return 'If' - if t == 0x13: - return 'New variable' - if t == 0x14: - return 'Inline Assembly' - if t == 0x15: - return 'IfLoop' - if t == 0x20: - return 'Throw' - if t == 0x31: - return 'Break' - if t == 0x32: - return 'Continue' - if t == 0x40: - return '_' - if t == 0x50: - return 'EndIf' - if t == 0x51: - return 'BeginLoop' - if t == 0x52: - return 'EndLoop' - return 'Unknown type {}'.format(hex(t)) diff --git a/slither/core/children/childContract.py b/slither/core/children/child_contract.py similarity index 100% rename from slither/core/children/childContract.py rename to slither/core/children/child_contract.py diff --git a/slither/core/children/childEvent.py b/slither/core/children/child_event.py similarity index 100% rename from slither/core/children/childEvent.py rename to slither/core/children/child_event.py diff --git a/slither/core/children/childFunction.py b/slither/core/children/child_function.py similarity index 100% rename from slither/core/children/childFunction.py rename to slither/core/children/child_function.py diff --git a/slither/core/children/child_node.py b/slither/core/children/child_node.py new file mode 100644 index 000000000..8c16e3106 --- /dev/null +++ b/slither/core/children/child_node.py @@ -0,0 +1,20 @@ + +class ChildNode(object): + def __init__(self): + super(ChildNode, self).__init__() + self._node = None + + def set_node(self, node): + self._node = node + + @property + def node(self): + return self._node + + @property + def function(self): + return self.node.function + + @property + def contract(self): + return self.node.function.contract diff --git a/slither/core/children/childSlither.py b/slither/core/children/child_slither.py similarity index 100% rename from slither/core/children/childSlither.py rename to slither/core/children/child_slither.py diff --git a/slither/core/children/childStructure.py b/slither/core/children/child_structure.py similarity index 100% rename from slither/core/children/childStructure.py rename to slither/core/children/child_structure.py diff --git a/slither/core/declarations/__init__.py b/slither/core/declarations/__init__.py index e69de29bb..b6cddc787 100644 --- a/slither/core/declarations/__init__.py +++ b/slither/core/declarations/__init__.py @@ -0,0 +1,9 @@ +from .contract import Contract +from .enum import Enum +from .event import Event +from .function import Function +from .import_directive import Import +from .modifier import Modifier +from .pragma_directive import Pragma +from .solidity_variables import SolidityVariable, SolidityVariableComposed, SolidityFunction +from .structure import Structure diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index fe270d47a..f1e27d16f 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -2,8 +2,8 @@ Contract module """ import logging -from slither.core.children.childSlither import ChildSlither -from slither.core.sourceMapping.sourceMapping import SourceMapping +from slither.core.children.child_slither import ChildSlither +from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.declarations.function import Function logger = logging.getLogger("Contract") @@ -19,7 +19,7 @@ class Contract(ChildSlither, SourceMapping): self._name = None self._id = None - self._inheritances = [] + self._inheritance = [] self._enums = {} self._structures = {} @@ -52,21 +52,21 @@ class Contract(ChildSlither, SourceMapping): return self._id @property - def inheritances(self): + def inheritance(self): ''' - list(Contract): Inheritances list. Order: the first elem is the first father to be executed + list(Contract): Inheritance list. Order: the first elem is the first father to be executed ''' - return self._inheritances + return list(self._inheritance) @property - def inheritances_reverse(self): + def inheritance_reverse(self): ''' - list(Contract): Inheritances list. Order: the last elem is the first father to be executed + list(Contract): Inheritance list. Order: the last elem is the first father to be executed ''' - return reversed(self._inheritances) + return reversed(self._inheritance) - def setInheritances(self, inheritances): - self._inheritances = inheritances + def setInheritance(self, inheritance): + self._inheritance = inheritance @property def structures(self): @@ -95,6 +95,10 @@ class Contract(ChildSlither, SourceMapping): def modifiers_as_dict(self): return self._modifiers + @property + def constructor(self): + return next((func for func in self.functions if func.is_constructor), None) + @property def functions(self): ''' @@ -110,13 +114,19 @@ class Contract(ChildSlither, SourceMapping): return [f for f in self.functions if f.contract != self] @property - def functions_all_called(self): + def all_functions_called(self): ''' list(Function): List of functions reachable from the contract (include super) ''' - all_calls = (f.all_calls() for f in self.functions) + all_calls = (f.all_internal_calls() for f in self.functions) all_calls = [item for sublist in all_calls for item in sublist] + self.functions - all_calls = set(all_calls) + all_calls = list(set(all_calls)) + + all_constructors = [c.constructor for c in self.inheritance] + all_constructors = list(set([c for c in all_constructors if c])) + + all_calls = set(all_calls+all_constructors) + return [c for c in all_calls if isinstance(c, Function)] def functions_as_dict(self): @@ -135,7 +145,7 @@ class Contract(ChildSlither, SourceMapping): @property def state_variables(self): ''' - list(StateVariable): List of the state variables. + list(StateVariable): List of the state variables. ''' return list(self._variables.values()) @@ -144,7 +154,7 @@ class Contract(ChildSlither, SourceMapping): ''' list(StateVariable): List of the state variables. Alias to self.state_variables ''' - return self.state_variables + return list(self.state_variables) def variables_as_dict(self): return self._variables @@ -154,6 +164,10 @@ class Contract(ChildSlither, SourceMapping): return self._using_for def reverse_using_for(self, name): + ''' + Returns: + (list) + ''' return self._using_for[name] @property @@ -163,13 +177,13 @@ class Contract(ChildSlither, SourceMapping): def __str__(self): return self.name - def get_functions_reading_variable(self, variable): + def get_functions_reading_from_variable(self, variable): ''' Return the functions reading the variable ''' return [f for f in self.functions if f.is_reading(variable)] - def get_functions_writing_variable(self, variable): + def get_functions_writing_to_variable(self, variable): ''' Return the functions writting the variable ''' @@ -289,6 +303,7 @@ class Contract(ChildSlither, SourceMapping): def is_erc20(self): """ Check if the contract is an erc20 token + Note: it does not check for correct return values Returns: bool @@ -302,8 +317,8 @@ class Contract(ChildSlither, SourceMapping): """ Return the function summary Returns: - (str, list, list, list): (name, variables, fuction summaries, modifier summaries) + (str, list, list, list, list): (name, inheritance, variables, fuction summaries, modifier summaries) """ func_summaries = [f.get_summary() for f in self.functions] modif_summaries = [f.get_summary() for f in self.modifiers] - return (self.name, [str(x) for x in self.inheritances], [str(x) for x in self.variables], func_summaries, modif_summaries) + return (self.name, [str(x) for x in self.inheritance], [str(x) for x in self.variables], func_summaries, modif_summaries) diff --git a/slither/core/declarations/enum.py b/slither/core/declarations/enum.py index e112eef2e..c81910c2c 100644 --- a/slither/core/declarations/enum.py +++ b/slither/core/declarations/enum.py @@ -1,5 +1,5 @@ -from slither.core.sourceMapping.sourceMapping import SourceMapping -from slither.core.children.childContract import ChildContract +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.children.child_contract import ChildContract class Enum(ChildContract, SourceMapping): def __init__(self, name, canonical_name, values): diff --git a/slither/core/declarations/event.py b/slither/core/declarations/event.py index d008c5b59..0974c4773 100644 --- a/slither/core/declarations/event.py +++ b/slither/core/declarations/event.py @@ -1,5 +1,5 @@ -from slither.core.children.childContract import ChildContract -from slither.core.sourceMapping.sourceMapping import SourceMapping +from slither.core.children.child_contract import ChildContract +from slither.core.source_mapping.source_mapping import SourceMapping class Event(ChildContract, SourceMapping): diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index cf4de158a..a16438f48 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -3,16 +3,17 @@ """ import logging from itertools import groupby -from slither.core.sourceMapping.sourceMapping import SourceMapping -from slither.core.children.childContract import ChildContract -from slither.core.variables.stateVariable import StateVariable +from slither.core.children.child_contract import ChildContract +from slither.core.declarations.solidity_variables import (SolidityFunction, + SolidityVariable, + SolidityVariableComposed) from slither.core.expressions.identifier import Identifier -from slither.core.expressions.unaryOperation import UnaryOperation -from slither.core.expressions.memberAccess import MemberAccess -from slither.core.expressions.indexAccess import IndexAccess - -from slither.core.declarations.solidityVariables import SolidityVariable, SolidityFunction +from slither.core.expressions.index_access import IndexAccess +from slither.core.expressions.member_access import MemberAccess +from slither.core.expressions.unary_operation import UnaryOperation +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.variables.state_variable import StateVariable logger = logging.getLogger("Function") @@ -42,7 +43,11 @@ class Function(ChildContract, SourceMapping): self._vars_read_or_written = [] self._solidity_vars_read = [] self._state_vars_written = [] - self._calls = [] + self._internal_calls = [] + self._solidity_calls = [] + self._low_level_calls = [] + self._high_level_calls = [] + self._external_calls_as_expressions = [] self._expression_vars_read = [] self._expression_vars_written = [] self._expression_calls = [] @@ -50,6 +55,18 @@ class Function(ChildContract, SourceMapping): self._modifiers = [] self._payable = False + + @property + def return_type(self): + """ + Return the list of return type + If no return, return None + """ + returns = self.returns + if returns: + return [r.type for r in returns] + return None + @property def name(self): """ @@ -67,7 +84,7 @@ class Function(ChildContract, SourceMapping): """ list(Node): List of the nodes """ - return self._nodes + return list(self._nodes) @property def entry_point(self): @@ -130,21 +147,21 @@ class Function(ChildContract, SourceMapping): """ list(LocalVariable): List of the parameters """ - return self._parameters + return list(self._parameters) @property def returns(self): """ list(LocalVariable): List of the return variables """ - return self._returns + return list(self._returns) @property def modifiers(self): """ list(Modifier): List of the modifiers """ - return self._modifiers + return list(self._modifiers) def __str__(self): return self._name @@ -157,6 +174,13 @@ class Function(ChildContract, SourceMapping): """ return list(self._variables.values()) + @property + def local_variables(self): + """ + Return all local variables (dont include paramters and return values) + """ + return list(set(self.variables) - set(self.returns) - set(self.parameters)) + def variables_as_dict(self): return self._variables @@ -165,42 +189,42 @@ class Function(ChildContract, SourceMapping): """ list(Variable): Variables read (local/state/solidity) """ - return self._vars_read + return list(self._vars_read) @property def variables_written(self): """ list(Variable): Variables written (local/state/solidity) """ - return self._vars_written + return list(self._vars_written) @property def state_variables_read(self): """ list(StateVariable): State variables read """ - return self._state_vars_read + return list(self._state_vars_read) @property def solidity_variables_read(self): """ list(SolidityVariable): Solidity variables read """ - return self._solidity_vars_read + return list(self._solidity_vars_read) @property def state_variables_written(self): """ list(StateVariable): State variables written """ - return self._state_vars_written + return list(self._state_vars_written) @property def variables_read_or_written(self): """ list(Variable): Variables read or written (local/state/solidity) """ - return self._vars_read_or_written + return list(self._vars_read_or_written) @property def variables_read_as_expression(self): @@ -211,14 +235,49 @@ class Function(ChildContract, SourceMapping): return self._expression_vars_written @property - def calls(self): + def internal_calls(self): + """ + list(Function or SolidityFunction): List of function calls (that does not create a transaction) + """ + return list(self._internal_calls) + + @property + def solidity_calls(self): + """ + list(SolidityFunction): List of Soldity calls + """ + return list(self._internal_calls) + + @property + def high_level_calls(self): + """ + list((Contract, Function|Variable)): + List of high level calls (external calls). + A variable is called in case of call to a public state variable + Include library calls + """ + return list(self._high_level_calls) + + @property + def low_level_calls(self): """ - list(Function or SolidityFunction): List of calls + list((Variable|SolidityVariable, str)): List of low_level call + A low level call is defined by + - the variable called + - the name of the function (call/delegatecall/codecall) """ - return self._calls + return list(self._low_level_calls) + @property - def calls_as_expression(self): + def external_calls_as_expressions(self): + """ + list(ExpressionCall): List of message calls (that creates a transaction) + """ + return list(self._external_calls_as_expressions) + + @property + def calls_as_expressions(self): return self._expression_calls @property @@ -233,7 +292,8 @@ class Function(ChildContract, SourceMapping): @property def signature(self): """ - (str, list(str), list(str)): Function signature as (name, list parameters type, list return values type) + (str, list(str), list(str)): Function signature as + (name, list parameters type, list return values type) """ return self.name, [str(x.type) for x in self.parameters], [str(x.type) for x in self.returns] @@ -256,8 +316,12 @@ class Function(ChildContract, SourceMapping): return name+'('+','.join(parameters)+')' + @property + def slither(self): + return self.contract.slither + def _filter_state_variables_written(self, expressions): - ret =[] + ret = [] for expression in expressions: if isinstance(expression, Identifier): ret.append(expression) @@ -320,44 +384,49 @@ class Function(ChildContract, SourceMapping): calls = [x for x in calls if x] calls = [item for sublist in calls for item in sublist] # Remove dupplicate if they share the same string representation + # TODO: check if groupby is still necessary here calls = [next(obj) for i, obj in\ groupby(sorted(calls, key=lambda x: str(x)), lambda x: str(x))] self._expression_calls = calls - calls = [x.calls for x in self.nodes] - calls = [x for x in calls if x] - calls = [item for sublist in calls for item in sublist] - calls = [next(obj) for i, obj in\ - groupby(sorted(calls, key=lambda x: str(x)), lambda x: str(x))] - self._calls = [c for c in calls if isinstance(c, (Function, SolidityFunction))] + internal_calls = [x.internal_calls for x in self.nodes] + internal_calls = [x for x in internal_calls if x] + internal_calls = [item for sublist in internal_calls for item in sublist] + internal_calls = [next(obj) for i, obj in + groupby(sorted(internal_calls, key=lambda x: str(x)), lambda x: str(x))] + self._internal_calls = internal_calls - def all_state_variables_read(self): - """ recursive version of variables_read - """ - variables = self.state_variables_read - explored = [self] - to_explore = [c for c in self.calls if isinstance(c, Function) and c not in explored] - to_explore += [m for m in self.modifiers if m not in explored] + self._solidity_calls = [c for c in internal_calls if isinstance(c, SolidityFunction)] - while to_explore: - f = to_explore[0] - to_explore = to_explore[1:] - if f in explored: - continue - explored.append(f) - variables += f.state_variables_read - to_explore += [c for c in f.calls if\ - isinstance(c, Function) and c not in explored and c not in to_explore] - to_explore += [m for m in f.modifiers if m not in explored and m not in to_explore] + low_level_calls = [x.low_level_calls for x in self.nodes] + low_level_calls = [x for x in low_level_calls if x] + low_level_calls = [item for sublist in low_level_calls for item in sublist] + low_level_calls = [next(obj) for i, obj in + groupby(sorted(low_level_calls, key=lambda x: str(x)), lambda x: str(x))] - return list(set(variables)) + self._low_level_calls = low_level_calls - def all_solidity_variables_read(self): - """ recursive version of solidity_read - """ - variables = self.solidity_variables_read + high_level_calls = [x.high_level_calls for x in self.nodes] + high_level_calls = [x for x in high_level_calls if x] + high_level_calls = [item for sublist in high_level_calls for item in sublist] + high_level_calls = [next(obj) for i, obj in + groupby(sorted(high_level_calls, key=lambda x: str(x)), lambda x: str(x))] + + self._high_level_calls = high_level_calls + + external_calls_as_expressions = [x.external_calls_as_expressions for x in self.nodes] + external_calls_as_expressions = [x for x in external_calls_as_expressions if x] + external_calls_as_expressions = [item for sublist in external_calls_as_expressions for item in sublist] + external_calls_as_expressions = [next(obj) for i, obj in + groupby(sorted(external_calls_as_expressions, key=lambda x: str(x)), lambda x: str(x))] + self._external_calls_as_expressions = external_calls_as_expressions + + + def _explore_functions(self, f_new_values): + values = f_new_values(self) explored = [self] - to_explore = [c for c in self.calls if isinstance(c, Function) and c not in explored] + to_explore = [c for c in self.internal_calls if + isinstance(c, Function) and c not in explored] to_explore += [m for m in self.modifiers if m not in explored] while to_explore: @@ -366,75 +435,90 @@ class Function(ChildContract, SourceMapping): if f in explored: continue explored.append(f) - variables += f.solidity_variables_read - to_explore += [c for c in f.calls if\ + + values += f_new_values(f) + + to_explore += [c for c in f.internal_calls if\ isinstance(c, Function) and c not in explored and c not in to_explore] to_explore += [m for m in f.modifiers if m not in explored and m not in to_explore] - return list(set(variables)) + return list(set(values)) - def all_expressions(self): + def all_state_variables_read(self): """ recursive version of variables_read """ - variables = self.expressions - explored = [self] - to_explore = [c for c in self.calls if isinstance(c, Function) and c not in explored] - to_explore += [m for m in self.modifiers if m not in explored] + return self._explore_functions(lambda x: x.state_variables_read) - while to_explore: - f = to_explore[0] - to_explore = to_explore[1:] - if f in explored: - continue - explored.append(f) - variables += f.expressions - to_explore += [c for c in f.calls if\ - isinstance(c, Function) and c not in explored and c not in to_explore] - to_explore += [m for m in f.modifiers if m not in explored and m not in to_explore] + def all_solidity_variables_read(self): + """ recursive version of solidity_read + """ + return self._explore_functions(lambda x: x.solidity_variables_read) - return list(set(variables)) + def all_expressions(self): + """ recursive version of variables_read + """ + return self._explore_functions(lambda x: x.expressions) def all_state_variables_written(self): """ recursive version of variables_written """ - variables = self.state_variables_written - explored = [self] - to_explore = [c for c in self.calls if isinstance(c, Function) and c not in explored] - to_explore += [m for m in self.modifiers if m not in explored] + return self._explore_functions(lambda x: x.state_variables_written) - while to_explore: - f = to_explore[0] - to_explore = to_explore[1:] - if f in explored: - continue - explored.append(f) - variables += f.state_variables_written - to_explore += [c for c in f.calls if\ - isinstance(c, Function) and c not in explored and c not in to_explore] - to_explore += [m for m in f.modifiers if m not in explored and m not in to_explore] + def all_internal_calls(self): + """ recursive version of internal_calls + """ + return self._explore_functions(lambda x: x.internal_calls) - return list(set(variables)) + def all_conditional_state_variables_read(self): + """ + Return the state variable used in a condition - def all_calls(self): - """ recursive version of calls + Over approximate and also return index access + It won't work if the variable is assigned to a temp variable """ - calls = self.calls - explored = [self] - to_explore = [c for c in self.calls if isinstance(c, Function) and c not in explored] - to_explore += [m for m in self.modifiers if m not in explored] + def _explore_func(func): + ret = [n.state_variables_read for n in func.nodes if n.is_conditional()] + return [item for sublist in ret for item in sublist] + return self._explore_functions(lambda x: _explore_func(x)) - while to_explore: - f = to_explore[0] - to_explore = to_explore[1:] - if f in explored: - continue - explored.append(f) - calls += f.calls - to_explore += [c for c in f.calls if\ - isinstance(c, Function) and c not in explored and c not in to_explore] - to_explore += [m for m in f.modifiers if m not in explored and m not in to_explore] + def all_conditional_solidity_variables_read(self): + """ + Return the Soldiity variables directly used in a condtion - return list(set(calls)) + Use of the IR to filter index access + Assumption: the solidity vars are used directly in the conditional node + It won't work if the variable is assigned to a temp variable + """ + from slither.slithir.operations.binary import Binary + def _solidity_variable_in_node(node): + ret = [] + for ir in node.irs: + if isinstance(ir, Binary): + ret += ir.read + return [var for var in ret if isinstance(var, SolidityVariable)] + def _explore_func(func, f): + ret = [f(n) for n in func.nodes if n.is_conditional()] + return [item for sublist in ret for item in sublist] + return self._explore_functions(lambda x: _explore_func(x, _solidity_variable_in_node)) + + def all_solidity_variables_used_as_args(self): + """ + Return the Soldiity variables directly used in a call + + Use of the IR to filter index access + Used to catch check(msg.sender) + """ + from slither.slithir.operations.internal_call import InternalCall + def _solidity_variable_in_node(node): + ret = [] + for ir in node.irs: + if isinstance(ir, InternalCall): + ret += ir.read + return [var for var in ret if isinstance(var, SolidityVariable)] + def _explore_func(func, f): + ret = [f(n) for n in func.nodes] + return [item for sublist in ret for item in sublist] + return self._explore_functions(lambda x: _explore_func(x, _solidity_variable_in_node)) def is_reading(self, variable): """ @@ -508,15 +592,54 @@ class Function(ChildContract, SourceMapping): f.write("}\n") + def slithir_cfg_to_dot(self, filename): + """ + Export the function to a dot file + Args: + filename (str) + """ + from slither.core.cfg.node import NodeType + with open(filename, 'w') as f: + f.write('digraph{\n') + for node in self.nodes: + label = 'Node Type: {}\n'.format(NodeType.str(node.type)) + if node.expression: + label += '\nEXPRESSION:\n{}\n'.format(node.expression) + label += '\nIRs:\n' + '\n'.join([str(ir) for ir in node.irs]) + f.write('{}[label="{}"];\n'.format(node.node_id, label)) + for son in node.sons: + f.write('{}->{};\n'.format(node.node_id, son.node_id)) + + f.write("}\n") + def get_summary(self): """ Return the function summary Returns: - (str, str, list(str), list(str), listr(str), list(str); - name, visibility, modifiers, variables read, variables written, calls + (str, str, list(str), list(str), listr(str), list(str), list(str); + name, visibility, modifiers, vars read, vars written, internal_calls, external_calls_as_expressions """ return (self.name, self.visibility, [str(x) for x in self.modifiers], [str(x) for x in self.state_variables_read + self.solidity_variables_read], [str(x) for x in self.state_variables_written], - [str(x) for x in self.calls]) + [str(x) for x in self.internal_calls], + [str(x) for x in self.external_calls_as_expressions]) + + def is_protected(self): + """ + Determine if the function is protected using a check on msg.sender + + Only detects if msg.sender is directly used in a condition + For example, it wont work for: + address a = msg.sender + require(a == owner) + Returns + (bool) + """ + + if self.is_constructor: + return True + conditional_vars = self.all_conditional_solidity_variables_read() + args_vars = self.all_solidity_variables_used_as_args() + return SolidityVariableComposed('msg.sender') in conditional_vars + args_vars diff --git a/slither/core/declarations/import_directive.py b/slither/core/declarations/import_directive.py new file mode 100644 index 000000000..8e22eb8c9 --- /dev/null +++ b/slither/core/declarations/import_directive.py @@ -0,0 +1,14 @@ +from slither.core.source_mapping.source_mapping import SourceMapping + +class Import(SourceMapping): + + def __init__(self, filename): + super(Import, self).__init__() + self._filename = filename + + @property + def filename(self): + return self._filename + + def __str__(self): + return self.filename diff --git a/slither/core/declarations/pragma_directive.py b/slither/core/declarations/pragma_directive.py new file mode 100644 index 000000000..1747b9229 --- /dev/null +++ b/slither/core/declarations/pragma_directive.py @@ -0,0 +1,21 @@ +from slither.core.source_mapping.source_mapping import SourceMapping + +class Pragma(SourceMapping): + + def __init__(self, directive): + super(Pragma, self).__init__() + self._directive = directive + + @property + def directive(self): + ''' + list(str) + ''' + return self._directive + + @property + def version(self): + return ''.join(self.directive[1:]) + + def __str__(self): + return 'pragma '+''.join(self.directive) diff --git a/slither/core/declarations/solidityVariables.py b/slither/core/declarations/solidityVariables.py deleted file mode 100644 index 7d537f127..000000000 --- a/slither/core/declarations/solidityVariables.py +++ /dev/null @@ -1,48 +0,0 @@ -# https://solidity.readthedocs.io/en/v0.4.24/units-and-global-variables.html - - -SOLIDITY_VARIABLES = ["block", "msg", "now", "tx", "this", "super", 'abi'] - -SOLIDITY_VARIABLES_COMPOSED = ["block.coinbase", "block.difficulty", "block.gaslimit", "block.number", "block.timestamp", "msg.data", "msg.gas", "msg.sender", "msg.sig", "msg.value", "tx.gasprice", "tx.origin"] - -SOLIDITY_FUNCTIONS = ["gasleft()", "assert(bool)", "require(bool)", "require(bool,string)", "revert()", "revert(string)", "addmod(uint256,uint256,uint256)", "mulmod(uint256,uint256,uint256)", "keccak256()", "sha256()", "sha3()", "ripemd160()", "ecrecover(bytes32,uint8,bytes32,bytes32)", "selfdestruct(address)", "suicide(address)", "log0(bytes32)", "log1(bytes32,bytes32)", "log2(bytes32,bytes32,bytes32)", "log3(bytes32,bytes32,bytes32,bytes32)", "blockhash(uint256)"] - -class SolidityVariable: - - def __init__(self, name): - assert name in SOLIDITY_VARIABLES - self._name = name - - @property - def name(self): - return self._name - - def __str__(self): - return self._name - -class SolidityVariableComposed(SolidityVariable): - def __init__(self, name): - assert name in SOLIDITY_VARIABLES_COMPOSED - self._name = name - - @property - def name(self): - return self._name - - def __str__(self): - return self._name - - - -class SolidityFunction: - - def __init__(self, name): - assert name in SOLIDITY_FUNCTIONS - self._name = name - - @property - def name(self): - return self._name - - def __str__(self): - return self._name diff --git a/slither/core/declarations/solidity_variables.py b/slither/core/declarations/solidity_variables.py new file mode 100644 index 000000000..36f0795f1 --- /dev/null +++ b/slither/core/declarations/solidity_variables.py @@ -0,0 +1,147 @@ +# https://solidity.readthedocs.io/en/v0.4.24/units-and-global-variables.html +from slither.core.context.context import Context +from slither.core.solidity_types import ElementaryType + +SOLIDITY_VARIABLES = {"now":'uint256', + "this":'address', + 'abi':'address', # to simplify the conversion, assume that abi return an address + 'msg':'', + 'tx':'', + 'block':'', + 'super':''} + +SOLIDITY_VARIABLES_COMPOSED = {"block.coinbase":"address", + "block.difficulty":"uint256", + "block.gaslimit":"uint256", + "block.number":"uint256", + "block.timestamp":"uint256", + "block.blockhash":"uint256", # alias for blockhash. It's a call + "msg.data":"bytes", + "msg.gas":"uint256", + "msg.sender":"address", + "msg.sig":"bytes4", + "msg.value":"uint256", + "tx.gasprice":"uint256", + "tx.origin":"address"} + + +SOLIDITY_FUNCTIONS = {"gasleft()":['uint256'], + "assert(bool)":[], + "require(bool)":[], + "require(bool,string)":[], + "revert()":[], + "revert(string)":[], + "addmod(uint256,uint256,uint256)":['uint256'], + "mulmod(uint256,uint256,uint256)":['uint256'], + "keccak256()":['bytes32'], + "sha256()":['bytes32'], + "sha3()":['bytes32'], + "ripemd160()":['bytes32'], + "ecrecover(bytes32,uint8,bytes32,bytes32)":['address'], + "selfdestruct(address)":[], + "suicide(address)":[], + "log0(bytes32)":[], + "log1(bytes32,bytes32)":[], + "log2(bytes32,bytes32,bytes32)":[], + "log3(bytes32,bytes32,bytes32,bytes32)":[], + "blockhash(uint256)":['bytes32'], + # the following need a special handling + # as they are recognized as a SolidityVariableComposed + # and converted to a SolidityFunction by SlithIR + "this.balance()":['uint256'], + "abi.encode()":['bytes'], + "abi.encodePacked()":['bytes'], + "abi.encodeWithSelector()":["bytes"], + "abi.encodeWithSignature()":["bytes"]} + +def solidity_function_signature(name): + """ + Return the function signature (containing the return value) + It is useful if a solidity function is used as a pointer + (see exoressionParsing.find_variable documentation) + Args: + name(str): + Returns: + str + """ + return name+' returns({})'.format(','.join(SOLIDITY_FUNCTIONS[name])) + +class SolidityVariable(Context): + + def __init__(self, name): + super(SolidityVariable, self).__init__() + self._check_name(name) + self._name = name + + # dev function, will be removed once the code is stable + def _check_name(self, name): + assert name in SOLIDITY_VARIABLES + + @property + def name(self): + return self._name + + @property + def type(self): + return ElementaryType(SOLIDITY_VARIABLES[self.name]) + + def __str__(self): + return self._name + + def __eq__(self, other): + return self.__class__ == other.__class__ and self.name == other.name + + def __hash__(self): + return hash(self.name) + +class SolidityVariableComposed(SolidityVariable): + def __init__(self, name): + super(SolidityVariableComposed, self).__init__(name) + + def _check_name(self, name): + assert name in SOLIDITY_VARIABLES_COMPOSED + + @property + def name(self): + return self._name + + @property + def type(self): + return ElementaryType(SOLIDITY_VARIABLES_COMPOSED[self.name]) + + def __str__(self): + return self._name + + def __eq__(self, other): + return self.__class__ == other.__class__ and self.name == other.name + + def __hash__(self): + return hash(self.name) + + +class SolidityFunction: + + def __init__(self, name): + assert name in SOLIDITY_FUNCTIONS + self._name = name + + @property + def name(self): + return self._name + + @property + def full_name(self): + return self.name + + @property + def return_type(self): + return [ElementaryType(x) for x in SOLIDITY_FUNCTIONS[self.name]] + + def __str__(self): + return self._name + + def __eq__(self, other): + return self.__class__ == other.__class__ and self.name == other.name + + def __hash__(self): + return hash(self.name) diff --git a/slither/core/declarations/structure.py b/slither/core/declarations/structure.py index 45721df8e..b11fb7e35 100644 --- a/slither/core/declarations/structure.py +++ b/slither/core/declarations/structure.py @@ -1,5 +1,5 @@ -from slither.core.sourceMapping.sourceMapping import SourceMapping -from slither.core.children.childContract import ChildContract +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.children.child_contract import ChildContract from slither.core.variables.variable import Variable diff --git a/slither/core/expressions/__init__.py b/slither/core/expressions/__init__.py index e69de29bb..23690aca0 100644 --- a/slither/core/expressions/__init__.py +++ b/slither/core/expressions/__init__.py @@ -0,0 +1,16 @@ +from .assignment_operation import AssignmentOperation, AssignmentOperationType +from .binary_operation import BinaryOperation, BinaryOperationType +from .call_expression import CallExpression +from .conditional_expression import ConditionalExpression +from .elementary_type_name_expression import ElementaryTypeNameExpression +from .identifier import Identifier +from .index_access import IndexAccess +from .literal import Literal +from .new_array import NewArray +from .new_contract import NewContract +from .new_elementary_type import NewElementaryType +from .super_call_expression import SuperCallExpression +from .super_identifier import SuperIdentifier +from .tuple_expression import TupleExpression +from .type_conversion import TypeConversion +from .unary_operation import UnaryOperation, UnaryOperationType diff --git a/slither/core/expressions/assignmentOperation.py b/slither/core/expressions/assignment_operation.py similarity index 94% rename from slither/core/expressions/assignmentOperation.py rename to slither/core/expressions/assignment_operation.py index eb9fbc3ec..fe7d524bf 100644 --- a/slither/core/expressions/assignmentOperation.py +++ b/slither/core/expressions/assignment_operation.py @@ -1,5 +1,5 @@ import logging -from slither.core.expressions.expressionTyped import ExpressionTyped +from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression @@ -13,7 +13,7 @@ class AssignmentOperationType: ASSIGN_LEFT_SHIFT = 4 # <<= ASSIGN_RIGHT_SHIFT = 5 # >>= ASSIGN_ADDITION = 6 # += - ASSIGN_SUBSTRACTION = 7 # -= + ASSIGN_SUBTRACTION = 7 # -= ASSIGN_MULTIPLICATION = 8 # *= ASSIGN_DIVISION = 9 # /= ASSIGN_MODULO = 10 # %= @@ -35,7 +35,7 @@ class AssignmentOperationType: if operation_type == '+=': return AssignmentOperationType.ASSIGN_ADDITION if operation_type == '-=': - return AssignmentOperationType.ASSIGN_SUBSTRACTION + return AssignmentOperationType.ASSIGN_SUBTRACTION if operation_type == '*=': return AssignmentOperationType.ASSIGN_MULTIPLICATION if operation_type == '/=': @@ -62,7 +62,7 @@ class AssignmentOperationType: return '>>=' if operation_type == AssignmentOperationType.ASSIGN_ADDITION: return '+=' - if operation_type == AssignmentOperationType.ASSIGN_SUBSTRACTION: + if operation_type == AssignmentOperationType.ASSIGN_SUBTRACTION: return '-=' if operation_type == AssignmentOperationType.ASSIGN_MULTIPLICATION: return '*=' diff --git a/slither/core/expressions/binaryOperation.py b/slither/core/expressions/binary_operation.py similarity index 95% rename from slither/core/expressions/binaryOperation.py rename to slither/core/expressions/binary_operation.py index 4a7c22aa3..b404d88cc 100644 --- a/slither/core/expressions/binaryOperation.py +++ b/slither/core/expressions/binary_operation.py @@ -1,5 +1,5 @@ import logging -from slither.core.expressions.expressionTyped import ExpressionTyped +from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression @@ -11,7 +11,7 @@ class BinaryOperationType: DIVISION = 2 # / MODULO = 3 # % ADDITION = 4 # + - SUBSTRACTION = 5 # - + SUBTRACTION = 5 # - LEFT_SHIFT = 6 # << RIGHT_SHIT = 7 # >>> AND = 8 # & @@ -39,7 +39,7 @@ class BinaryOperationType: if operation_type == '+': return BinaryOperationType.ADDITION if operation_type == '-': - return BinaryOperationType.SUBSTRACTION + return BinaryOperationType.SUBTRACTION if operation_type == '<<': return BinaryOperationType.LEFT_SHIFT if operation_type == '>>': @@ -82,7 +82,7 @@ class BinaryOperationType: return '%' if operation_type == BinaryOperationType.ADDITION: return '+' - if operation_type == BinaryOperationType.SUBSTRACTION: + if operation_type == BinaryOperationType.SUBTRACTION: return '-' if operation_type == BinaryOperationType.LEFT_SHIFT: return '<<' @@ -123,7 +123,7 @@ class BinaryOperation(ExpressionTyped): self._type = expression_type @property - def get_expression(self): + def expressions(self): return self._expressions @property diff --git a/slither/core/expressions/callExpression.py b/slither/core/expressions/call_expression.py similarity index 100% rename from slither/core/expressions/callExpression.py rename to slither/core/expressions/call_expression.py diff --git a/slither/core/expressions/conditionalExpression.py b/slither/core/expressions/conditional_expression.py similarity index 100% rename from slither/core/expressions/conditionalExpression.py rename to slither/core/expressions/conditional_expression.py diff --git a/slither/core/expressions/elementaryTypeNameExpression.py b/slither/core/expressions/elementary_type_name_expression.py similarity index 90% rename from slither/core/expressions/elementaryTypeNameExpression.py rename to slither/core/expressions/elementary_type_name_expression.py index 70f73130f..cf8c07c17 100644 --- a/slither/core/expressions/elementaryTypeNameExpression.py +++ b/slither/core/expressions/elementary_type_name_expression.py @@ -2,7 +2,7 @@ This expression does nothing, if a contract used it, its probably a bug """ from slither.core.expressions.expression import Expression -from slither.core.solidityTypes.type import Type +from slither.core.solidity_types.type import Type class ElementaryTypeNameExpression(Expression): diff --git a/slither/core/expressions/expression.py b/slither/core/expressions/expression.py index 620239403..6e9bd7aba 100644 --- a/slither/core/expressions/expression.py +++ b/slither/core/expressions/expression.py @@ -1,4 +1,4 @@ -from slither.core.sourceMapping.sourceMapping import SourceMapping +from slither.core.source_mapping.source_mapping import SourceMapping class Expression( SourceMapping): diff --git a/slither/core/expressions/expressionTyped.py b/slither/core/expressions/expression_typed.py similarity index 100% rename from slither/core/expressions/expressionTyped.py rename to slither/core/expressions/expression_typed.py diff --git a/slither/core/expressions/identifier.py b/slither/core/expressions/identifier.py index 71e104ad7..cb92d8329 100644 --- a/slither/core/expressions/identifier.py +++ b/slither/core/expressions/identifier.py @@ -1,4 +1,4 @@ -from slither.core.expressions.expressionTyped import ExpressionTyped +from slither.core.expressions.expression_typed import ExpressionTyped class Identifier(ExpressionTyped): diff --git a/slither/core/expressions/indexAccess.py b/slither/core/expressions/index_access.py similarity index 86% rename from slither/core/expressions/indexAccess.py rename to slither/core/expressions/index_access.py index 909bca393..5e02bc6fa 100644 --- a/slither/core/expressions/indexAccess.py +++ b/slither/core/expressions/index_access.py @@ -1,5 +1,5 @@ -from slither.core.expressions.expressionTyped import ExpressionTyped -from slither.core.solidityTypes.type import Type +from slither.core.expressions.expression_typed import ExpressionTyped +from slither.core.solidity_types.type import Type class IndexAccess(ExpressionTyped): diff --git a/slither/core/expressions/memberAccess.py b/slither/core/expressions/member_access.py similarity index 86% rename from slither/core/expressions/memberAccess.py rename to slither/core/expressions/member_access.py index 01b72bbee..72d1238b4 100644 --- a/slither/core/expressions/memberAccess.py +++ b/slither/core/expressions/member_access.py @@ -1,6 +1,6 @@ from slither.core.expressions.expression import Expression -from slither.core.expressions.expressionTyped import ExpressionTyped -from slither.core.solidityTypes.type import Type +from slither.core.expressions.expression_typed import ExpressionTyped +from slither.core.solidity_types.type import Type class MemberAccess(ExpressionTyped): diff --git a/slither/core/expressions/newArray.py b/slither/core/expressions/new_array.py similarity index 78% rename from slither/core/expressions/newArray.py rename to slither/core/expressions/new_array.py index fecdaeee7..9c59a75a9 100644 --- a/slither/core/expressions/newArray.py +++ b/slither/core/expressions/new_array.py @@ -1,8 +1,5 @@ -import logging -from .expression import Expression -from slither.core.solidityTypes.type import Type - -logger = logging.getLogger("NewArray") +from slither.core.expressions.expression import Expression +from slither.core.solidity_types.type import Type class NewArray(Expression): diff --git a/slither/core/expressions/newContract.py b/slither/core/expressions/new_contract.py similarity index 100% rename from slither/core/expressions/newContract.py rename to slither/core/expressions/new_contract.py diff --git a/slither/core/expressions/newElementaryType.py b/slither/core/expressions/new_elementary_type.py similarity index 84% rename from slither/core/expressions/newElementaryType.py rename to slither/core/expressions/new_elementary_type.py index de9a3a7cb..b099bed9e 100644 --- a/slither/core/expressions/newElementaryType.py +++ b/slither/core/expressions/new_elementary_type.py @@ -1,5 +1,5 @@ from slither.core.expressions.expression import Expression -from slither.core.solidityTypes.elementaryType import ElementaryType +from slither.core.solidity_types.elementary_type import ElementaryType class NewElementaryType(Expression): diff --git a/slither/core/expressions/superCallExpression.py b/slither/core/expressions/super_call_expression.py similarity index 61% rename from slither/core/expressions/superCallExpression.py rename to slither/core/expressions/super_call_expression.py index e2500b34a..420e324f9 100644 --- a/slither/core/expressions/superCallExpression.py +++ b/slither/core/expressions/super_call_expression.py @@ -1,4 +1,4 @@ from slither.core.expressions.expression import Expression -from slither.core.expressions.callExpression import CallExpression +from slither.core.expressions.call_expression import CallExpression class SuperCallExpression(CallExpression): pass diff --git a/slither/core/expressions/superIdentifier.py b/slither/core/expressions/super_identifier.py similarity index 69% rename from slither/core/expressions/superIdentifier.py rename to slither/core/expressions/super_identifier.py index 2f39f467d..33299b9a9 100644 --- a/slither/core/expressions/superIdentifier.py +++ b/slither/core/expressions/super_identifier.py @@ -1,4 +1,4 @@ -from slither.core.expressions.expressionTyped import ExpressionTyped +from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.identifier import Identifier class SuperIdentifier(Identifier): diff --git a/slither/core/expressions/tupleExpression.py b/slither/core/expressions/tuple_expression.py similarity index 77% rename from slither/core/expressions/tupleExpression.py rename to slither/core/expressions/tuple_expression.py index 29df13e0c..90c7f91c0 100644 --- a/slither/core/expressions/tupleExpression.py +++ b/slither/core/expressions/tuple_expression.py @@ -12,6 +12,6 @@ class TupleExpression(Expression): return self._expressions def __str__(self): - expressions_str = [str(e) for e in self.expressions] - return '(' + ','.join(expressions_str) + ')' + expressions_str = [str(e) for e in self.expressions] + return '(' + ','.join(expressions_str) + ')' diff --git a/slither/core/expressions/typeConversion.py b/slither/core/expressions/type_conversion.py similarity index 81% rename from slither/core/expressions/typeConversion.py rename to slither/core/expressions/type_conversion.py index 4b0e01bdf..7fe165754 100644 --- a/slither/core/expressions/typeConversion.py +++ b/slither/core/expressions/type_conversion.py @@ -1,6 +1,6 @@ -from slither.core.expressions.expressionTyped import ExpressionTyped +from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression -from slither.core.solidityTypes.type import Type +from slither.core.solidity_types.type import Type class TypeConversion(ExpressionTyped): diff --git a/slither/core/expressions/unaryOperation.py b/slither/core/expressions/unary_operation.py similarity index 95% rename from slither/core/expressions/unaryOperation.py rename to slither/core/expressions/unary_operation.py index 549bc3e24..3416cd854 100644 --- a/slither/core/expressions/unaryOperation.py +++ b/slither/core/expressions/unary_operation.py @@ -1,7 +1,7 @@ import logging -from slither.core.expressions.expressionTyped import ExpressionTyped +from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression -from slither.core.solidityTypes.type import Type +from slither.core.solidity_types.type import Type logger = logging.getLogger("UnaryOperation") @@ -101,6 +101,10 @@ class UnaryOperation(ExpressionTyped): def type_str(self): return UnaryOperationType.str(self._type) + @property + def type(self): + return self._type + @property def is_prefix(self): return UnaryOperationType.is_prefix(self._type) diff --git a/slither/core/slitherCore.py b/slither/core/slither_core.py similarity index 59% rename from slither/core/slitherCore.py rename to slither/core/slither_core.py index 15d8fe9ed..cd8f8a113 100644 --- a/slither/core/slitherCore.py +++ b/slither/core/slither_core.py @@ -2,18 +2,26 @@ Main module """ import os +from slither.core.context.context import Context -class Slither: +class Slither(Context): """ Slither static analyzer """ - name_class = 'Slither' def __init__(self): + super(Slither, self).__init__() self._contracts = {} self._filename = None self._source_units = {} self._solc_version = None # '0.3' or '0.4':! + self._pragma_directives = [] + self._import_directives = [] + self._raw_source_code = {} + + @property + def source_units(self): + return self._source_units @property def contracts(self): @@ -23,9 +31,9 @@ class Slither: @property def contracts_derived(self): """list(Contract): List of contracts that are derived and not inherited.""" - inheritances = (x.inheritances for x in self.contracts) - inheritances = (item for sublist in inheritances for item in sublist) - return [c for c in self._contracts.values() if c not in inheritances] + inheritance = (x.inheritance for x in self.contracts) + inheritance = [item for sublist in inheritance for item in sublist] + return [c for c in self._contracts.values() if c not in inheritance] def contracts_as_dict(self): """list(dict(str: Contract): List of contracts as dict: name -> Contract.""" @@ -41,6 +49,21 @@ class Slither: """str: Solidity version.""" return self._solc_version + @property + def pragma_directives(self): + """ list(list(str)): Pragma directives. Example [['solidity', '^', '0.4', '.24']]""" + return self._pragma_directives + + @property + def import_directives(self): + """ list(str): Import directives""" + return self._import_directives + + @property + def source_code(self): + """ {filename: source_code}: source code """ + return self._raw_source_code + def get_contract_from_name(self, contract_name): """ Return a contract from a name @@ -57,4 +80,4 @@ class Slither: """ for c in self.contracts: for f in c.functions: - f.cfg_to_dot(os.path.join(d,'{}.{}.dot'.format(c.name, f.name))) + f.cfg_to_dot(os.path.join(d, '{}.{}.dot'.format(c.name, f.name))) diff --git a/slither/core/solidityTypes/functionType.py b/slither/core/solidityTypes/functionType.py deleted file mode 100644 index 90b7cbf21..000000000 --- a/slither/core/solidityTypes/functionType.py +++ /dev/null @@ -1,26 +0,0 @@ -from slither.core.solidityTypes.type import Type -from slither.core.variables.functionTypeVariable import FunctionTypeVariable -from slither.core.expressions.expression import Expression - -class FunctionType(Type): - - def __init__(self, params, return_values): - assert all(isinstance(x, FunctionTypeVariable) for x in params) - assert all(isinstance(x, FunctionTypeVariable) for x in return_values) - super(FunctionType, self).__init__() - self._params = params - self._return_values = return_values - - @property - def params(self): - return self._params - - @property - def return_values(self): - return self._return_values - - def __str__(self): - params = ".".join([str(x) for x in self._params]) - return_values = ".".join([str(x) for x in self._return_values]) - return 'function({}) returns ({})'.format(params, return_values) - diff --git a/slither/core/solidityTypes/type.py b/slither/core/solidityTypes/type.py deleted file mode 100644 index a7daac91f..000000000 --- a/slither/core/solidityTypes/type.py +++ /dev/null @@ -1,3 +0,0 @@ -from slither.core.sourceMapping.sourceMapping import SourceMapping - -class Type(SourceMapping): pass diff --git a/slither/core/solidityTypes/userDefinedType.py b/slither/core/solidityTypes/userDefinedType.py deleted file mode 100644 index f0ac1d6d2..000000000 --- a/slither/core/solidityTypes/userDefinedType.py +++ /dev/null @@ -1,23 +0,0 @@ -from slither.core.solidityTypes.type import Type - -from slither.core.declarations.structure import Structure -from slither.core.declarations.enum import Enum -from slither.core.declarations.contract import Contract - - -class UserDefinedType(Type): - - def __init__(self, t): - assert isinstance(t, (Contract, Enum, Structure)) - super(UserDefinedType, self).__init__() - self._type = t - - @property - def type(self): - return self._type - - def __str__(self): - if isinstance(self.type, (Enum, Structure)): - return str(self.type.contract)+'.'+str(self.type.name) - return str(self.type.name) - diff --git a/slither/core/solidity_types/__init__.py b/slither/core/solidity_types/__init__.py new file mode 100644 index 000000000..b7ec244bf --- /dev/null +++ b/slither/core/solidity_types/__init__.py @@ -0,0 +1,5 @@ +from .array_type import ArrayType +from .elementary_type import ElementaryType +from .function_type import FunctionType +from .mapping_type import MappingType +from .user_defined_type import UserDefinedType diff --git a/slither/core/solidityTypes/arrayType.py b/slither/core/solidity_types/array_type.py similarity index 53% rename from slither/core/solidityTypes/arrayType.py rename to slither/core/solidity_types/array_type.py index 1eec105d1..49cb825f1 100644 --- a/slither/core/solidityTypes/arrayType.py +++ b/slither/core/solidity_types/array_type.py @@ -1,4 +1,5 @@ -from slither.core.solidityTypes.type import Type +from slither.core.variables.variable import Variable +from slither.core.solidity_types.type import Type from slither.core.expressions.expression import Expression class ArrayType(Type): @@ -21,6 +22,16 @@ class ArrayType(Type): def __str__(self): if self._length: + if isinstance(self._length.value, Variable) and self._length.value.is_constant: + return str(self._type)+'[{}]'.format(str(self._length.value.expression)) return str(self._type)+'[{}]'.format(str(self._length)) return str(self._type)+'[]' + + def __eq__(self, other): + if not isinstance(other, ArrayType): + return False + return self._type == other.type and self.length == other.length + + def __hash__(self): + return hash(str(self)) diff --git a/slither/core/solidityTypes/elementaryType.py b/slither/core/solidity_types/elementary_type.py similarity index 89% rename from slither/core/solidityTypes/elementaryType.py rename to slither/core/solidity_types/elementary_type.py index a2354685b..99435a6f5 100644 --- a/slither/core/solidityTypes/elementaryType.py +++ b/slither/core/solidity_types/elementary_type.py @@ -1,6 +1,6 @@ import itertools -from slither.core.solidityTypes.type import Type +from slither.core.solidity_types.type import Type # see https://solidity.readthedocs.io/en/v0.4.24/miscellaneous.html?highlight=grammar @@ -48,3 +48,11 @@ class ElementaryType(Type): def __str__(self): return self._type + def __eq__(self, other): + if not isinstance(other, ElementaryType): + return False + return self.type == other.type + + def __hash__(self): + return hash(str(self)) + diff --git a/slither/core/solidity_types/function_type.py b/slither/core/solidity_types/function_type.py new file mode 100644 index 000000000..3afb18de4 --- /dev/null +++ b/slither/core/solidity_types/function_type.py @@ -0,0 +1,66 @@ +from slither.core.solidity_types.type import Type +from slither.core.variables.function_type_variable import FunctionTypeVariable +from slither.core.expressions.expression import Expression + +class FunctionType(Type): + + def __init__(self, params, return_values): + assert all(isinstance(x, FunctionTypeVariable) for x in params) + assert all(isinstance(x, FunctionTypeVariable) for x in return_values) + super(FunctionType, self).__init__() + self._params = params + self._return_values = return_values + + @property + def params(self): + return self._params + + @property + def return_values(self): + return self._return_values + + @property + def return_type(self): + return [x.type for x in self.return_values] + + def __str__(self): + # Use x.type + # x.name may be empty + params = ",".join([str(x.type) for x in self._params]) + return_values = ",".join([str(x.type) for x in self._return_values]) + if return_values: + return 'function({}) returns({})'.format(params, return_values) + return 'function({})'.format(params) + + @property + def parameters_signature(self): + ''' + Return the parameters signature(without the return statetement) + ''' + # Use x.type + # x.name may be empty + params = ",".join([str(x.type) for x in self._params]) + return '({})'.format(params) + + @property + def signature(self): + ''' + Return the signature(with the return statetement if it exists) + ''' + # Use x.type + # x.name may be empty + params = ",".join([str(x.type) for x in self._params]) + return_values = ",".join([str(x.type) for x in self._return_values]) + if return_values: + return '({}) returns({})'.format(params, return_values) + return '({})'.format(params) + + + + def __eq__(self, other): + if not isinstance(other, FunctionType): + return False + return self.params == other.params and self.return_values == other.return_values + + def __hash__(self): + return hash(str(self)) diff --git a/slither/core/solidityTypes/mappingType.py b/slither/core/solidity_types/mapping_type.py similarity index 62% rename from slither/core/solidityTypes/mappingType.py rename to slither/core/solidity_types/mapping_type.py index c718c8f81..1c016c8c6 100644 --- a/slither/core/solidityTypes/mappingType.py +++ b/slither/core/solidity_types/mapping_type.py @@ -1,4 +1,4 @@ -from slither.core.solidityTypes.type import Type +from slither.core.solidity_types.type import Type class MappingType(Type): @@ -19,3 +19,12 @@ class MappingType(Type): def __str__(self): return 'mapping({} => {}'.format(str(self._from), str(self._to)) + + def __eq__(self, other): + if not isinstance(other, MappingType): + return False + return self.type_from == other.type_from and self.type_to == other.type_to + + def __hash__(self): + return hash(str(self)) + diff --git a/slither/core/solidity_types/type.py b/slither/core/solidity_types/type.py new file mode 100644 index 000000000..1c2794bec --- /dev/null +++ b/slither/core/solidity_types/type.py @@ -0,0 +1,3 @@ +from slither.core.source_mapping.source_mapping import SourceMapping + +class Type(SourceMapping): pass diff --git a/slither/core/solidity_types/user_defined_type.py b/slither/core/solidity_types/user_defined_type.py new file mode 100644 index 000000000..bfabb3c85 --- /dev/null +++ b/slither/core/solidity_types/user_defined_type.py @@ -0,0 +1,35 @@ +from slither.core.solidity_types.type import Type + + +class UserDefinedType(Type): + + def __init__(self, t): + from slither.core.declarations.structure import Structure + from slither.core.declarations.enum import Enum + from slither.core.declarations.contract import Contract + + assert isinstance(t, (Contract, Enum, Structure)) + super(UserDefinedType, self).__init__() + self._type = t + + @property + def type(self): + return self._type + + def __str__(self): + from slither.core.declarations.structure import Structure + from slither.core.declarations.enum import Enum + + if isinstance(self.type, (Enum, Structure)): + return str(self.type.contract)+'.'+str(self.type.name) + return str(self.type.name) + + def __eq__(self, other): + if not isinstance(other, UserDefinedType): + return False + return self.type == other.type + + + def __hash__(self): + return hash(str(self)) + diff --git a/slither/core/sourceMapping/sourceMapping.py b/slither/core/sourceMapping/sourceMapping.py deleted file mode 100644 index 14b900522..000000000 --- a/slither/core/sourceMapping/sourceMapping.py +++ /dev/null @@ -1,18 +0,0 @@ -from slither.core.context.context import Context - -class SourceMapping(Context): - - def __init__(self): - super(SourceMapping, self).__init__() - self._source_mapping = None - self._offset = None - - def set_source_mapping(self, source_mapping): - self._source_mapping = source_mapping - - @property - def source_mapping(self): - return self._source_mapping - - def set_offset(self, offset): - self._offset = offset diff --git a/slither/solcParsing/cfg/__init__.py b/slither/core/source_mapping/__init__.py similarity index 100% rename from slither/solcParsing/cfg/__init__.py rename to slither/core/source_mapping/__init__.py diff --git a/slither/core/source_mapping/source_mapping.py b/slither/core/source_mapping/source_mapping.py new file mode 100644 index 000000000..88d66db5f --- /dev/null +++ b/slither/core/source_mapping/source_mapping.py @@ -0,0 +1,89 @@ +import re +from slither.core.context.context import Context + +class SourceMapping(Context): + + def __init__(self): + super(SourceMapping, self).__init__() + self._source_mapping = None + + @property + def source_mapping(self): + return self._source_mapping + + @staticmethod + def _compute_line(source_code, start, length): + """ + Compute line(s) number from a start/end offset + Not done in an efficient way + """ + total_length = len(source_code) + source_code = source_code.split('\n') + counter = 0 + i = 0 + lines = [] + while counter < total_length: + counter += len(source_code[i]) +1 + i = i+1 + if counter > start: + lines.append(i) + if counter > start+length: + break + return lines + + @staticmethod + def _convert_source_mapping(offset, slither): + ''' + Convert a text offset to a real offset + see https://solidity.readthedocs.io/en/develop/miscellaneous.html#source-mappings + Returns: + (dict): {'start':0, 'length':0, 'filename': 'file.sol'} + ''' + sourceUnits = slither.source_units + + position = re.findall('([0-9]*):([0-9]*):([-]?[0-9]*)', offset) + if len(position) != 1: + return {} + + s, l, f = position[0] + s = int(s) + l = int(l) + f = int(f) + + if f not in sourceUnits: + return {'start':s, 'length':l} + filename = sourceUnits[f] + + lines = [] + + if filename in slither.source_code: + lines = SourceMapping._compute_line(slither.source_code[filename], s, l) + + return {'start':s, 'length':l, 'filename': filename, 'lines' : lines } + + def set_offset(self, offset, slither): + if isinstance(offset, dict): + self._source_mapping = offset + else: + self._source_mapping = self._convert_source_mapping(offset, slither) + + + @property + def source_mapping_str(self): + + def relative_path(path): + # Remove absolute path for printing + # Truffle returns absolutePath + if '/contracts/' in path: + return path[path.find('/contracts/'):] + return path + + lines = self.source_mapping['lines'] + if not lines: + lines = '' + elif len(lines) == 1: + lines = '#{}'.format(lines[0]) + else: + lines = '#{}-{}'.format(lines[0], lines[-1]) + return '{}{}'.format(relative_path(self.source_mapping['filename']), lines) + diff --git a/slither/core/variables/eventVariable.py b/slither/core/variables/event_variable.py similarity index 58% rename from slither/core/variables/eventVariable.py rename to slither/core/variables/event_variable.py index 26daced04..a6ac7c0a3 100644 --- a/slither/core/variables/eventVariable.py +++ b/slither/core/variables/event_variable.py @@ -1,5 +1,5 @@ from .variable import Variable -from slither.core.children.childEvent import ChildEvent +from slither.core.children.child_event import ChildEvent class EventVariable(ChildEvent, Variable): pass diff --git a/slither/core/variables/functionTypeVariable.py b/slither/core/variables/function_type_variable.py similarity index 100% rename from slither/core/variables/functionTypeVariable.py rename to slither/core/variables/function_type_variable.py diff --git a/slither/core/variables/localVariable.py b/slither/core/variables/localVariable.py deleted file mode 100644 index 53938aaa4..000000000 --- a/slither/core/variables/localVariable.py +++ /dev/null @@ -1,5 +0,0 @@ -from .variable import Variable -from slither.core.children.childFunction import ChildFunction - -class LocalVariable(ChildFunction, Variable): pass - diff --git a/slither/core/variables/local_variable.py b/slither/core/variables/local_variable.py new file mode 100644 index 000000000..2ac3baad2 --- /dev/null +++ b/slither/core/variables/local_variable.py @@ -0,0 +1,48 @@ +from .variable import Variable +from slither.core.children.child_function import ChildFunction +from slither.core.solidity_types.user_defined_type import UserDefinedType +from slither.core.solidity_types.array_type import ArrayType + +from slither.core.declarations.structure import Structure + + +class LocalVariable(ChildFunction, Variable): + + def __init__(self): + super(LocalVariable, self).__init__() + self._location = None + + + def set_location(self, loc): + self._location = loc + + @property + def location(self): + ''' + Variable Location + Can be storage/memory or default + Returns: + (str) + ''' + return self._location + + @property + def is_storage(self): + """ + Return true if the variable is located in storage + See https://solidity.readthedocs.io/en/v0.4.24/types.html?highlight=storage%20location#data-location + Returns: + (bool) + """ + if self.location == 'memory': + return False + if self.location == 'storage': + return True + + if isinstance(self.type, ArrayType): + return True + + if isinstance(self.type, UserDefinedType): + return isinstance(self.type.type, Structure) + + return False diff --git a/slither/core/variables/localVariableInitFromTuple.py b/slither/core/variables/local_variable_init_from_tuple.py similarity index 88% rename from slither/core/variables/localVariableInitFromTuple.py rename to slither/core/variables/local_variable_init_from_tuple.py index d8a314915..09ca7a361 100644 --- a/slither/core/variables/localVariableInitFromTuple.py +++ b/slither/core/variables/local_variable_init_from_tuple.py @@ -1,4 +1,4 @@ -from slither.core.variables.localVariable import LocalVariable +from slither.core.variables.local_variable import LocalVariable class LocalVariableInitFromTuple(LocalVariable): """ diff --git a/slither/core/variables/stateVariable.py b/slither/core/variables/state_variable.py similarity index 56% rename from slither/core/variables/stateVariable.py rename to slither/core/variables/state_variable.py index c5e7f711b..1f241e799 100644 --- a/slither/core/variables/stateVariable.py +++ b/slither/core/variables/state_variable.py @@ -1,4 +1,4 @@ from .variable import Variable -from slither.core.children.childContract import ChildContract +from slither.core.children.child_contract import ChildContract class StateVariable(ChildContract, Variable): pass diff --git a/slither/core/variables/structureVariable.py b/slither/core/variables/structure_variable.py similarity index 57% rename from slither/core/variables/structureVariable.py rename to slither/core/variables/structure_variable.py index 5f4920e2b..1c0c188e1 100644 --- a/slither/core/variables/structureVariable.py +++ b/slither/core/variables/structure_variable.py @@ -1,5 +1,5 @@ from .variable import Variable -from slither.core.children.childStructure import ChildStructure +from slither.core.children.child_structure import ChildStructure class StructureVariable(ChildStructure, Variable): pass diff --git a/slither/core/variables/variable.py b/slither/core/variables/variable.py index a2ac28954..0ab0a99ca 100644 --- a/slither/core/variables/variable.py +++ b/slither/core/variables/variable.py @@ -2,8 +2,9 @@ Variable module """ -from slither.core.sourceMapping.sourceMapping import SourceMapping - +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.solidity_types.type import Type +from slither.core.solidity_types.elementary_type import ElementaryType class Variable(SourceMapping): @@ -17,16 +18,24 @@ class Variable(SourceMapping): self._mappingTo = None self._initial_expression = None self._type = None - self._expression = None self._initialized = None self._visibility = None + self._is_constant = False @property def expression(self): """ Expression: Expression of the node (if initialized) + Initial expression may be different than the expression of the node + where the variable is declared, if its used ternary operator + Ex: uint a = b?1:2 + The expression associated to a is uint a = b?1:2 + But two nodes are created, + one where uint a = 1, + and one where uint a = 2 + """ - return self._expression + return self._initial_expression @property def initialized(self): @@ -53,6 +62,10 @@ class Variable(SourceMapping): def type(self): return self._type + @property + def is_constant(self): + return self._is_constant + @property def visibility(self): ''' @@ -61,6 +74,9 @@ class Variable(SourceMapping): return self._visibility def set_type(self, t): + if isinstance(t, str): + t = ElementaryType(t) + assert isinstance(t, (Type, list)) or t is None self._type = t def __str__(self): diff --git a/slither/detectors/abstractDetector.py b/slither/detectors/abstractDetector.py deleted file mode 100644 index 2d86c32a0..000000000 --- a/slither/detectors/abstractDetector.py +++ /dev/null @@ -1,48 +0,0 @@ -import abc -import re -from slither.detectors.detectorClassification import DetectorClassification -from slither.utils.colors import green, yellow, red - -class IncorrectDetectorInitialization(Exception): - pass - -class AbstractDetector(object, metaclass=abc.ABCMeta): - ARGUMENT = '' # run the detector with slither.py --ARGUMENT - HELP = '' # help information - CLASSIFICATION = None - - HIDDEN_DETECTOR = False # yes if the detector should not be showed - - def __init__(self, slither, logger): - self.slither = slither - self.contracts = slither.contracts - self.filename = slither.filename - self.logger = logger - if self.HELP == '': - raise IncorrectDetectorInitialization('HELP is not initialized') - if self.ARGUMENT == '': - raise IncorrectDetectorInitialization('ARGUMENT is not initialized') - if re.match('^[a-zA-Z0-9_-]*$', self.ARGUMENT) is None: - raise IncorrectDetectorInitialization('ARGUMENT has illegal character') - if not self.CLASSIFICATION in [DetectorClassification.LOW, - DetectorClassification.MEDIUM, - DetectorClassification.HIGH]: - raise IncorrectDetectorInitialization('CLASSIFICATION is not initialized') - - def log(self, info): - if self.logger: - self.logger.info(self.color(info)) - - @abc.abstractmethod - def detect(self): - """TODO Documentation""" - return - - @property - def color(self): - if self.CLASSIFICATION == DetectorClassification.LOW: - return green - if self.CLASSIFICATION == DetectorClassification.MEDIUM: - return yellow - if self.CLASSIFICATION == DetectorClassification.HIGH: - return red diff --git a/slither/detectors/abstract_detector.py b/slither/detectors/abstract_detector.py new file mode 100644 index 000000000..d0015d09b --- /dev/null +++ b/slither/detectors/abstract_detector.py @@ -0,0 +1,81 @@ +import abc +import re + +from slither.utils.colors import green, yellow, red + + +class IncorrectDetectorInitialization(Exception): + pass + + +class DetectorClassification: + HIGH = 0 + MEDIUM = 1 + LOW = 2 + INFORMATIONAL = 3 + + +classification_colors = { + DetectorClassification.INFORMATIONAL: green, + DetectorClassification.LOW: green, + DetectorClassification.MEDIUM: yellow, + DetectorClassification.HIGH: red, +} + +classification_txt = { + DetectorClassification.INFORMATIONAL: 'Informational', + DetectorClassification.LOW: 'Low', + DetectorClassification.MEDIUM: 'Medium', + DetectorClassification.HIGH: 'High', +} + +class AbstractDetector(metaclass=abc.ABCMeta): + ARGUMENT = '' # run the detector with slither.py --ARGUMENT + HELP = '' # help information + IMPACT = None + CONFIDENCE = None + + WIKI = '' + + def __init__(self, slither, logger): + self.slither = slither + self.contracts = slither.contracts + self.filename = slither.filename + self.logger = logger + + if not self.HELP: + raise IncorrectDetectorInitialization('HELP is not initialized {}'.format(self.__class__.__name__)) + + if not self.ARGUMENT: + raise IncorrectDetectorInitialization('ARGUMENT is not initialized {}'.format(self.__class__.__name__)) + + if re.match('^[a-zA-Z0-9_-]*$', self.ARGUMENT) is None: + raise IncorrectDetectorInitialization('ARGUMENT has illegal character {}'.format(self.__class__.__name__)) + + if self.IMPACT not in [DetectorClassification.LOW, + DetectorClassification.MEDIUM, + DetectorClassification.HIGH, + DetectorClassification.INFORMATIONAL]: + raise IncorrectDetectorInitialization('IMPACT is not initialized {}'.format(self.__class__.__name__)) + + if self.CONFIDENCE not in [DetectorClassification.LOW, + DetectorClassification.MEDIUM, + DetectorClassification.HIGH, + DetectorClassification.INFORMATIONAL]: + raise IncorrectDetectorInitialization('CONFIDENCE is not initialized {}'.format(self.__class__.__name__)) + + def log(self, info): + if self.logger: + info = "\n"+info + if self.WIKI != '': + info += 'Reference: {}'.format(self.WIKI) + self.logger.info(self.color(info)) + + @abc.abstractmethod + def detect(self): + """TODO Documentation""" + return + + @property + def color(self): + return classification_colors[self.IMPACT] diff --git a/slither/solcParsing/declarations/__init__.py b/slither/detectors/attributes/__init__.py similarity index 100% rename from slither/solcParsing/declarations/__init__.py rename to slither/detectors/attributes/__init__.py diff --git a/slither/detectors/attributes/constant_pragma.py b/slither/detectors/attributes/constant_pragma.py new file mode 100644 index 000000000..0d96e93e9 --- /dev/null +++ b/slither/detectors/attributes/constant_pragma.py @@ -0,0 +1,38 @@ +""" + Check that the same pragma is used in all the files +""" + +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification + + +class ConstantPragma(AbstractDetector): + """ + Check that the same pragma is used in all the files + """ + + ARGUMENT = 'pragma' + HELP = 'If different pragma directives are used' + IMPACT = DetectorClassification.INFORMATIONAL + CONFIDENCE = DetectorClassification.HIGH + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#state-variables-that-could-be-declared-constant' + + def detect(self): + results = [] + pragma = self.slither.pragma_directives + versions = [p.version for p in pragma] + versions = list(set(versions)) + + if len(versions) > 1: + info = "Different versions of Solidity is used in {}:\n".format(self.filename) + for p in pragma: + info += "\t- {} declares {}\n".format(p.source_mapping_str, str(p)) + self.log(info) + + source = [p.source_mapping for p in pragma] + + results.append({'vuln': 'ConstantPragma', + 'versions': versions, + 'sourceMapping': source}) + + return results diff --git a/slither/detectors/attributes/locked_ether.py b/slither/detectors/attributes/locked_ether.py new file mode 100644 index 000000000..7f3b75024 --- /dev/null +++ b/slither/detectors/attributes/locked_ether.py @@ -0,0 +1,66 @@ +""" + Check if ether are locked in the contract +""" + +from slither.detectors.abstract_detector import (AbstractDetector, + DetectorClassification) +from slither.slithir.operations import (HighLevelCall, LowLevelCall, Send, + Transfer) + + +class LockedEther(AbstractDetector): + """ + """ + + ARGUMENT = 'locked-ether' + HELP = "Contracts that lock ether" + IMPACT = DetectorClassification.MEDIUM + CONFIDENCE = DetectorClassification.HIGH + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#contracts-that-lock-ether' + + @staticmethod + def do_no_send_ether(contract): + functions = contract.all_functions_called + for function in functions: + calls = [c.name for c in function.internal_calls] + if 'suicide(address)' in calls or 'selfdestruct(address)' in calls: + return False + for node in function.nodes: + for ir in node.irs: + if isinstance(ir, (Send, Transfer, HighLevelCall, LowLevelCall)): + if ir.call_value and ir.call_value != 0: + return False + if isinstance(ir, (LowLevelCall)): + if ir.function_name in ['delegatecall', 'callcode']: + return False + return True + + + def detect(self): + results = [] + + for contract in self.slither.contracts_derived: + if contract.is_signature_only(): + continue + funcs_payable = [function for function in contract.functions if function.payable] + if funcs_payable: + if self.do_no_send_ether(contract): + txt = "Contract locking ether found in {}:\n".format(self.filename) + txt += "\tContract {} has payable functions:\n".format(contract.name) + for function in funcs_payable: + txt += "\t - {} ({})\n".format(function.name, function.source_mapping_str) + txt += "\tBut has not function to withdraw the ether\n" + info = txt.format(self.filename, + contract.name, + [f.name for f in funcs_payable]) + self.log(info) + + source = [f.source_mapping for f in funcs_payable] + + results.append({'vuln': 'LockedEther', + 'functions_payable' : [f.name for f in funcs_payable], + 'contract': contract.name, + 'sourceMapping': source}) + + return results diff --git a/slither/detectors/attributes/old_solc.py b/slither/detectors/attributes/old_solc.py new file mode 100644 index 000000000..4bd21c3df --- /dev/null +++ b/slither/detectors/attributes/old_solc.py @@ -0,0 +1,41 @@ +""" + Check if an old version of solc is used + Solidity >= 0.4.23 should be used +""" + +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification + + +class OldSolc(AbstractDetector): + """ + Check if an old version of solc is used + """ + + ARGUMENT = 'solc-version' + HELP = 'Old versions of Solidity (< 0.4.23)' + IMPACT = DetectorClassification.INFORMATIONAL + CONFIDENCE = DetectorClassification.HIGH + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#old-versions-of-solidity' + + @staticmethod + def _convert_pragma(version): + return version.replace('solidity', '').replace('^', '') + + def detect(self): + results = [] + pragma = self.slither.pragma_directives + old_pragma = [p for p in pragma if self._convert_pragma(p.version) not in ['0.4.23', '0.4.24']] + + if old_pragma: + info = "Old version (<0.4.23) of Solidity used in {}:\n".format(self.filename) + for p in old_pragma: + info += "\t- {} declares {}\n".format(p.source_mapping_str, str(p)) + self.log(info) + + source = [p.source_mapping for p in pragma] + results.append({'vuln': 'OldPragma', + 'pragma': [p.version for p in old_pragma], + 'sourceMapping': source}) + + return results diff --git a/slither/detectors/detectorClassification.py b/slither/detectors/detectorClassification.py deleted file mode 100644 index 52dde20a9..000000000 --- a/slither/detectors/detectorClassification.py +++ /dev/null @@ -1,4 +0,0 @@ -class DetectorClassification: - LOW = 0 - MEDIUM = 1 - HIGH = 2 diff --git a/slither/detectors/detectors.py b/slither/detectors/detectors.py deleted file mode 100644 index 7e79b8443..000000000 --- a/slither/detectors/detectors.py +++ /dev/null @@ -1,47 +0,0 @@ -import sys, inspect -import os -import logging - -from slither.detectors.abstractDetector import AbstractDetector -from slither.detectors.detectorClassification import DetectorClassification - -# Detectors must be imported here -from slither.detectors.examples.backdoor import Backdoor -from slither.detectors.variables.uninitializedStateVarsDetection import UninitializedStateVarsDetection - -### - -logger_detector = logging.getLogger("Detectors") - -class Detectors: - - def __init__(self): - self.detectors = {} - self.low = [] - self.medium = [] - self.high = [] - - self._load_detectors() - - def _load_detectors(self): - for name, obj in inspect.getmembers(sys.modules[__name__]): - if inspect.isclass(obj): - if issubclass(obj, AbstractDetector) and name != 'AbstractDetector': - if obj.HIDDEN_DETECTOR: - continue - if name in self.detectors: - raise Exception('Detector name collision: {}'.format(name)) - self.detectors[name] = obj - if obj.CLASSIFICATION == DetectorClassification.LOW: - self.low.append(name) - elif obj.CLASSIFICATION == DetectorClassification.MEDIUM: - self.medium.append(name) - elif obj.CLASSIFICATION == DetectorClassification.HIGH: - self.high.append(name) - else: - raise Exception('Unknown classification') - - def run_detector(self, slither, name): - Detector = self.detectors[name] - instance = Detector(slither, logger_detector) - return instance.detect() diff --git a/slither/detectors/examples/backdoor.py b/slither/detectors/examples/backdoor.py index e4c7d41a7..76fe19fcc 100644 --- a/slither/detectors/examples/backdoor.py +++ b/slither/detectors/examples/backdoor.py @@ -1,25 +1,30 @@ -from slither.detectors.abstractDetector import AbstractDetector -from slither.detectors.detectorClassification import DetectorClassification +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification + class Backdoor(AbstractDetector): """ Detect function named backdoor """ - ARGUMENT = 'backdoor' # slither will launch the detector with slither.py --mydetector + ARGUMENT = 'backdoor' # slither will launch the detector with slither.py --mydetector HELP = 'Function named backdoor (detector example)' - CLASSIFICATION = DetectorClassification.HIGH + IMPACT = DetectorClassification.HIGH + CONFIDENCE = DetectorClassification.HIGH def detect(self): ret = [] for contract in self.slither.contracts_derived: # Check if a function has 'backdoor' in its name - if any('backdoor' in f.name for f in contract.functions): - # Info to be printed - info = 'Backdoor function found in {}'.format(contract.name) - # Print the info - self.log(info) - # Add the result in ret - ret.append({'vuln':'backdoor', 'contract':contract.name}) + for f in contract.functions: + if 'backdoor' in f.name: + # Info to be printed + info = 'Backdoor function found in {}.{} ({})\n' + info = info.format(contract.name, f.name, f.source_mapping_str) + # Print the info + self.log(info) + # Add the result in ret + source = f.source_mapping + ret.append({'vuln': 'backdoor', 'contract': contract.name, 'sourceMapping' : source}) + return ret diff --git a/slither/solcParsing/expressions/__init__.py b/slither/detectors/functions/__init__.py similarity index 100% rename from slither/solcParsing/expressions/__init__.py rename to slither/detectors/functions/__init__.py diff --git a/slither/detectors/functions/arbitrary_send.py b/slither/detectors/functions/arbitrary_send.py new file mode 100644 index 000000000..bb702a59b --- /dev/null +++ b/slither/detectors/functions/arbitrary_send.py @@ -0,0 +1,122 @@ +""" + Module detecting send to arbitrary address + + To avoid FP, it does not report: + - If msg.sender is used as index (withdraw situation) + - If the function is protected + - If the value sent is msg.value (repay situation) + + TODO: dont report if the value is tainted by msg.value +""" + +from slither.analyses.taint.calls import KEY +from slither.analyses.taint.calls import run_taint as run_taint_calls +from slither.analyses.taint.specific_variable import is_tainted +from slither.analyses.taint.specific_variable import \ + run_taint as run_taint_variable +from slither.core.declarations.solidity_variables import (SolidityFunction, + SolidityVariableComposed) +from slither.detectors.abstract_detector import (AbstractDetector, + DetectorClassification) +from slither.slithir.operations import (HighLevelCall, Index, LowLevelCall, + Send, SolidityCall, Transfer) + + +class ArbitrarySend(AbstractDetector): + """ + """ + + ARGUMENT = 'arbitrary-send' + HELP = 'Functions that send ether to arbitrary destinations' + IMPACT = DetectorClassification.HIGH + CONFIDENCE = DetectorClassification.MEDIUM + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#functions-that-send-ether-to-arbitrary-destinations' + + @staticmethod + def arbitrary_send(func): + """ + """ + if func.is_protected(): + return [] + + ret = [] + for node in func.nodes: + for ir in node.irs: + if isinstance(ir, SolidityCall): + if ir.function == SolidityFunction('ecrecover(bytes32,uint8,bytes32,bytes32)'): + return False + if isinstance(ir, Index): + if ir.variable_right == SolidityVariableComposed('msg.sender'): + return False + if is_tainted(ir.variable_right, SolidityVariableComposed('msg.sender')): + return False + if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)): + if ir.call_value is None: + continue + if ir.call_value == SolidityVariableComposed('msg.value'): + continue + if is_tainted(ir.call_value, SolidityVariableComposed('msg.value')): + continue + + if KEY in ir.context: + if ir.context[KEY]: + ret.append(node) + return ret + + + def detect_arbitrary_send(self, contract): + """ + Detect arbitrary send + Args: + contract (Contract) + Returns: + list((Function), (list (Node))) + """ + ret = [] + for f in [f for f in contract.functions if f.contract == contract]: + nodes = self.arbitrary_send(f) + if nodes: + ret.append((f, nodes)) + return ret + + def detect(self): + """ + """ + results = [] + + # Look if the destination of a call is tainted + run_taint_calls(self.slither) + + # Taint msg.value + taint = SolidityVariableComposed('msg.value') + run_taint_variable(self.slither, taint) + + # Taint msg.sender + taint = SolidityVariableComposed('msg.sender') + run_taint_variable(self.slither, taint) + + for c in self.contracts: + arbitrary_send = self.detect_arbitrary_send(c) + for (func, nodes) in arbitrary_send: + calls_str = [str(node.expression) for node in nodes] + + info = "{}{} sends eth to arbirary user\n" + info = info.format(func.contract.name, + func.name) + info += '\tDangerous calls:\n' + for node in nodes: + info += '\t- {} ({})\n'.format(node.expression, node.source_mapping_str) + + self.log(info) + + source_mapping = [node.source_mapping for node in nodes] + + results.append({'vuln': 'ArbitrarySend', + 'sourceMapping': source_mapping, + 'filename': self.filename, + 'contract': func.contract.name, + 'func': func.name, + 'calls': calls_str}) + + return results diff --git a/slither/detectors/functions/complex_function.py b/slither/detectors/functions/complex_function.py new file mode 100644 index 000000000..ab1260e21 --- /dev/null +++ b/slither/detectors/functions/complex_function.py @@ -0,0 +1,116 @@ +from slither.core.declarations.solidity_variables import (SolidityFunction, + SolidityVariableComposed) +from slither.detectors.abstract_detector import (AbstractDetector, + DetectorClassification) +from slither.slithir.operations import (HighLevelCall, + LowLevelCall, + LibraryCall) +from slither.utils.code_complexity import compute_cyclomatic_complexity + + +class ComplexFunction(AbstractDetector): + """ + Module detecting complex functions + A complex function is defined by: + - high cyclomatic complexity + - numerous writes to state variables + - numerous external calls + """ + + + ARGUMENT = 'complex-function' + HELP = 'Complex functions' + IMPACT = DetectorClassification.INFORMATIONAL + CONFIDENCE = DetectorClassification.MEDIUM + + MAX_STATE_VARIABLES = 10 + MAX_EXTERNAL_CALLS = 5 + MAX_CYCLOMATIC_COMPLEXITY = 7 + + CAUSE_CYCLOMATIC = "cyclomatic" + CAUSE_EXTERNAL_CALL = "external_calls" + CAUSE_STATE_VARS = "state_vars" + + + @staticmethod + def detect_complex_func(func): + """Detect the cyclomatic complexity of the contract functions + shouldn't be greater than 7 + """ + result = [] + code_complexity = compute_cyclomatic_complexity(func) + + if code_complexity > ComplexFunction.MAX_CYCLOMATIC_COMPLEXITY: + result.append({ + "func": func, + "cause": ComplexFunction.CAUSE_CYCLOMATIC + }) + + """Detect the number of external calls in the func + shouldn't be greater than 5 + """ + count = 0 + for node in func.nodes: + for ir in node.irs: + if isinstance(ir, (HighLevelCall, LowLevelCall, LibraryCall)): + count += 1 + + if count > ComplexFunction.MAX_EXTERNAL_CALLS: + result.append({ + "func": func, + "cause": ComplexFunction.CAUSE_EXTERNAL_CALL + }) + + """Checks the number of the state variables written + shouldn't be greater than 10 + """ + if len(func.state_variables_written) > ComplexFunction.MAX_STATE_VARIABLES: + result.append({ + "func": func, + "cause": ComplexFunction.CAUSE_STATE_VARS + }) + + return result + + def detect_complex(self, contract): + ret = [] + + for func in contract.all_functions_called: + result = self.detect_complex_func(func) + ret.extend(result) + + return ret + + def detect(self): + results = [] + + for contract in self.contracts: + issues = self.detect_complex(contract) + + for issue in issues: + func, cause = issue.values() + func_name = func.name + + txt = "Complex function in {}\n\t- {}.{} ({})\n" + + if cause == self.CAUSE_EXTERNAL_CALL: + txt += "\t- Reason: High number of external calls" + if cause == self.CAUSE_CYCLOMATIC: + txt += "\t- Reason: High number of branches" + if cause == self.CAUSE_STATE_VARS: + txt += "\t- Reason: High number of modified state variables" + + info = txt.format(self.filename, + contract.name, + func_name, + func.source_mapping_str) + info = info + "\n" + self.log(info) + + results.append({'vuln': 'ComplexFunc', + 'sourceMapping': func.source_mapping, + 'filename': self.filename, + 'contract': contract.name, + 'func': func_name}) + return results + diff --git a/slither/detectors/functions/external_function.py b/slither/detectors/functions/external_function.py new file mode 100644 index 000000000..d2f1cc322 --- /dev/null +++ b/slither/detectors/functions/external_function.py @@ -0,0 +1,74 @@ +from slither.detectors.abstract_detector import (AbstractDetector, + DetectorClassification) +from slither.slithir.operations import (HighLevelCall, SolidityCall ) +from slither.slithir.operations import (InternalCall, InternalDynamicCall) + +class ExternalFunction(AbstractDetector): + """ + Detect public function that could be declared as external + + IMPROVEMENT: Add InternalDynamicCall check + https://github.com/trailofbits/slither/pull/53#issuecomment-432809950 + """ + + ARGUMENT = 'external-function' + HELP = 'Public function that could be declared as external' + IMPACT = DetectorClassification.INFORMATIONAL + CONFIDENCE = DetectorClassification.HIGH + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#public-function-that-could-be-declared-as-external' + + @staticmethod + def detect_functions_called(contract): + """ Returns a list of InternallCall, SolidityCall + calls made in a function + + Returns: + (list): List of all InternallCall, SolidityCall + """ + result = [] + for func in contract.all_functions_called: + for node in func.nodes: + for ir in node.irs: + if isinstance(ir, (InternalCall, SolidityCall)): + result.append(ir.function) + return result + + @staticmethod + def _contains_internal_dynamic_call(contract): + for func in contract.all_functions_called: + for node in func.nodes: + for ir in node.irs: + if isinstance(ir, (InternalDynamicCall)): + return True + return False + + def detect(self): + results = [] + + public_function_calls = [] + all_info = '' + + for contract in self.slither.contracts_derived: + if self._contains_internal_dynamic_call(contract): + continue + + func_list = self.detect_functions_called(contract) + public_function_calls.extend(func_list) + + for func in [f for f in contract.functions if f.visibility == 'public' and\ + not f in public_function_calls and\ + not f.is_constructor]: + txt = "{}.{} ({}) should be declared external\n" + info = txt.format(func.contract.name, + func.name, + func.source_mapping_str) + all_info += info + results.append({'vuln': 'ExternalFunc', + 'sourceMapping': func.source_mapping, + 'filename': self.filename, + 'contract': func.contract.name, + 'func': func.name}) + if all_info != '': + self.log(all_info) + return results diff --git a/slither/detectors/functions/suicidal.py b/slither/detectors/functions/suicidal.py new file mode 100644 index 000000000..8cac68977 --- /dev/null +++ b/slither/detectors/functions/suicidal.py @@ -0,0 +1,73 @@ +""" +Module detecting suicidal contract + +A suicidal contract is an unprotected function that calls selfdestruct +""" + +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification + +class Suicidal(AbstractDetector): + """ + Unprotected function detector + """ + + ARGUMENT = 'suicidal' + HELP = 'Functions allowing anyone to destruct the contract' + IMPACT = DetectorClassification.HIGH + CONFIDENCE = DetectorClassification.HIGH + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#suicidal' + + @staticmethod + def detect_suicidal_func(func): + """ Detect if the function is suicidal + + Detect the public functions calling suicide/selfdestruct without protection + Returns: + (bool): True if the function is suicidal + """ + + if func.is_constructor: + return False + + if func.visibility != 'public': + return False + + calls = [c.name for c in func.internal_calls] + if not ('suicide(address)' in calls or 'selfdestruct(address)' in calls): + return False + + if func.is_protected(): + return False + + return True + + def detect_suicidal(self, contract): + ret = [] + for f in [f for f in contract.functions if f.contract == contract]: + if self.detect_suicidal_func(f): + ret.append(f) + return ret + + def detect(self): + """ Detect the suicidal functions + """ + results = [] + for c in self.contracts: + functions = self.detect_suicidal(c) + for func in functions: + + txt = "{}.{} ({}) allows anyone to destruct the contract\n" + info = txt.format(func.contract.name, + func.name, + func.source_mapping_str) + + self.log(info) + + results.append({'vuln': 'SuicidalFunc', + 'sourceMapping': func.source_mapping, + 'filename': self.filename, + 'contract': c.name, + 'func': func.name}) + + return results diff --git a/slither/solcParsing/solidityTypes/__init__.py b/slither/detectors/naming_convention/__init__.py similarity index 100% rename from slither/solcParsing/solidityTypes/__init__.py rename to slither/detectors/naming_convention/__init__.py diff --git a/slither/detectors/naming_convention/naming_convention.py b/slither/detectors/naming_convention/naming_convention.py new file mode 100644 index 000000000..a2f37ef84 --- /dev/null +++ b/slither/detectors/naming_convention/naming_convention.py @@ -0,0 +1,204 @@ +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +import re + + +class NamingConvention(AbstractDetector): + """ + Check if naming conventions are followed + https://solidity.readthedocs.io/en/v0.4.25/style-guide.html?highlight=naming_convention%20convention#naming_convention-conventions + + Exceptions: + - Allow constant variables name/symbol/decimals to be lowercase (ERC20) + - Allow '_' at the beggining of the mixed_case match for private variables and unused parameters + """ + + ARGUMENT = 'naming-convention' + HELP = 'Conformance to Solidity naming conventions' + IMPACT = DetectorClassification.INFORMATIONAL + CONFIDENCE = DetectorClassification.HIGH + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#conformance-to-solidity-naming-conventions' + + @staticmethod + def is_cap_words(name): + return re.search('^[A-Z]([A-Za-z0-9]+)?_?$', name) is not None + + @staticmethod + def is_mixed_case(name): + return re.search('^[a-z]([A-Za-z0-9]+)?_?$', name) is not None + + @staticmethod + def is_mixed_case_with_underscore(name): + # Allow _ at the beginning to represent private variable + # or unused parameters + return re.search('^[a-z_]([A-Za-z0-9]+)?_?$', name) is not None + + @staticmethod + def is_upper_case_with_underscores(name): + return re.search('^[A-Z0-9_]+_?$', name) is not None + + @staticmethod + def should_avoid_name(name): + return re.search('^[lOI]$', name) is not None + + def detect(self): + + results = [] + all_info = '' + for contract in self.contracts: + + if not self.is_cap_words(contract.name): + info = "Contract '{}' ({}) is not in CapWords\n".format(contract.name, contract.source_mapping_str) + all_info += info + + results.append({'vuln': 'NamingConvention', + 'filename': self.filename, + 'contract': contract.name, + 'sourceMapping': contract.source_mapping}) + + for struct in contract.structures: + if struct.contract != contract: + continue + + if not self.is_cap_words(struct.name): + info = "Struct '{}.{}' ({}) is not in CapWords\n" + info = info.format(struct.contract.name, struct.name, struct.source_mapping_str) + all_info += info + + results.append({'vuln': 'NamingConvention', + 'filename': self.filename, + 'contract': contract.name, + 'struct': struct.name, + 'sourceMapping': struct.source_mapping}) + + for event in contract.events: + if event.contract != contract: + continue + + if not self.is_cap_words(event.name): + info = "Event '{}.{}' ({}) is not in CapWords\n" + info = info.format(event.contract.name, event.name, event.source_mapping_str) + all_info += info + + results.append({'vuln': 'NamingConvention', + 'filename': self.filename, + 'contract': contract.name, + 'event': event.name, + 'sourceMapping': event.source_mapping}) + + for func in contract.functions: + if func.contract != contract: + continue + + if not self.is_mixed_case(func.name): + info = "Function '{}.{}' ({}) is not in mixedCase\n" + info = info.format(func.contract.name, func.name, func.source_mapping_str) + all_info += info + + results.append({'vuln': 'NamingConvention', + 'filename': self.filename, + 'contract': contract.name, + 'function': func.name, + 'sourceMapping': func.source_mapping}) + + for argument in func.parameters: + if argument in func.variables_read_or_written: + correct_naming = self.is_mixed_case(argument.name) + else: + correct_naming = self.is_mixed_case_with_underscore(argument.name) + if not correct_naming: + info = "Parameter '{}' of {}.{} ({}) is not in mixedCase\n" + info = info.format(argument.name, + argument.function.contract.name, + argument.function, + argument.source_mapping_str) + all_info += info + + results.append({'vuln': 'NamingConvention', + 'filename': self.filename, + 'contract': contract.name, + 'function': func.name, + 'argument': argument.name, + 'sourceMapping': argument.source_mapping}) + + for var in contract.state_variables: + if var.contract != contract: + continue + + if self.should_avoid_name(var.name): + if not self.is_upper_case_with_underscores(var.name): + info = "Variable '{}.{}' ({}) used l, O, I, which should not be used\n" + info = info.format(var.contract.name, var.name, var.source_mapping_str) + all_info += info + + results.append({'vuln': 'NamingConvention', + 'filename': self.filename, + 'contract': contract.name, + 'constant': var.name, + 'sourceMapping': var.source_mapping}) + + if var.is_constant is True: + # For ERC20 compatibility + if var.name in ['symbol', 'name', 'decimals']: + continue + + if not self.is_upper_case_with_underscores(var.name): + info = "Constant '{}.{}' ({}) is not in UPPER_CASE_WITH_UNDERSCORES\n" + info = info.format(var.contract.name, var.name, var.source_mapping_str) + all_info += info + + results.append({'vuln': 'NamingConvention', + 'filename': self.filename, + 'contract': contract.name, + 'constant': var.name, + 'sourceMapping': var.source_mapping}) + else: + if var.visibility == 'private': + correct_naming = self.is_mixed_case_with_underscore(var.name) + else: + correct_naming = self.is_mixed_case(var.name) + if not correct_naming: + info = "Variable '{}.{}' ({}) is not in mixedCase\n" + info = info.format(var.contract.name, var.name, var.source_mapping_str) + all_info += info + + + results.append({'vuln': 'NamingConvention', + 'filename': self.filename, + 'contract': contract.name, + 'variable': var.name, + 'sourceMapping': var.source_mapping}) + + for enum in contract.enums: + if enum.contract != contract: + continue + + if not self.is_cap_words(enum.name): + info = "Enum '{}.{}' ({}) is not in CapWords\n" + info = info.format(enum.contract.name, enum.name, enum.source_mapping_str) + all_info += info + + results.append({'vuln': 'NamingConvention', + 'filename': self.filename, + 'contract': contract.name, + 'enum': enum.name, + 'sourceMapping': enum.source_mapping}) + + for modifier in contract.modifiers: + if modifier.contract != contract: + continue + + if not self.is_mixed_case(modifier.name): + info = "Modifier '{}.{}' ({}) is not in mixedCase\n" + info = info.format(modifier.contract.name, modifier.name, modifier.source_mapping_str) + all_info += info + + results.append({'vuln': 'NamingConvention', + 'filename': self.filename, + 'contract': contract.name, + 'modifier': modifier.name, + 'sourceMapping': modifier.source_mapping}) + if all_info != '': + self.log(all_info) + + return results diff --git a/slither/solcParsing/variables/__init__.py b/slither/detectors/operations/__init__.py similarity index 100% rename from slither/solcParsing/variables/__init__.py rename to slither/detectors/operations/__init__.py diff --git a/slither/detectors/operations/low_level_calls.py b/slither/detectors/operations/low_level_calls.py new file mode 100644 index 000000000..509536244 --- /dev/null +++ b/slither/detectors/operations/low_level_calls.py @@ -0,0 +1,62 @@ +""" +Module detecting usage of low level calls +""" + +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.slithir.operations import LowLevelCall + + +class LowLevelCalls(AbstractDetector): + """ + Detect usage of low level calls + """ + + ARGUMENT = 'low-level-calls' + HELP = 'Low level calls' + IMPACT = DetectorClassification.INFORMATIONAL + CONFIDENCE = DetectorClassification.HIGH + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#low-level-calls' + + @staticmethod + def _contains_low_level_calls(node): + """ + Check if the node contains Low Level Calls + Returns: + (bool) + """ + return any(isinstance(ir, LowLevelCall) for ir in node.irs) + + def detect_low_level_calls(self, contract): + ret = [] + for f in [f for f in contract.functions if contract == f.contract]: + nodes = f.nodes + assembly_nodes = [n for n in nodes if + self._contains_low_level_calls(n)] + if assembly_nodes: + ret.append((f, assembly_nodes)) + return ret + + def detect(self): + """ Detect the functions that use low level calls + """ + results = [] + all_info = '' + for c in self.contracts: + values = self.detect_low_level_calls(c) + for func, nodes in values: + info = "Low level call in {}.{} ({})\n" + info = info.format(func.contract.name, func.name, func.source_mapping_str) + all_info += info + + sourceMapping = [n.source_mapping for n in nodes] + + results.append({'vuln': 'Low level call', + 'sourceMapping': sourceMapping, + 'filename': self.filename, + 'contract': func.contract.name, + 'function_name': func.name}) + + if all_info != '': + self.log(all_info) + return results diff --git a/slither/detectors/reentrancy/__init__.py b/slither/detectors/reentrancy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slither/detectors/reentrancy/reentrancy.py b/slither/detectors/reentrancy/reentrancy.py new file mode 100644 index 000000000..847cdc17c --- /dev/null +++ b/slither/detectors/reentrancy/reentrancy.py @@ -0,0 +1,214 @@ +"""" + Re-entrancy detection + + Based on heuristics, it may lead to FP and FN + Iterate over all the nodes of the graph until reaching a fixpoint +""" + +from slither.core.cfg.node import NodeType +from slither.core.declarations import Function, SolidityFunction +from slither.core.expressions import UnaryOperation, UnaryOperationType +from slither.detectors.abstract_detector import (AbstractDetector, + DetectorClassification) +from slither.visitors.expression.export_values import ExportValues +from slither.slithir.operations import (HighLevelCall, LowLevelCall, + LibraryCall, + Send, Transfer) + +class Reentrancy(AbstractDetector): + ARGUMENT = 'reentrancy' + HELP = 'Reentrancy vulnerabilities' + IMPACT = DetectorClassification.HIGH + CONFIDENCE = DetectorClassification.MEDIUM + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#reentrancy-vulnerabilities' + + key = 'REENTRANCY' + + @staticmethod + def _can_callback(node): + """ + Detect if the node contains a call that can + be used to re-entrance + + Consider as valid target: + - low level call + - high level call + + Do not consider Send/Transfer as there is not enough gas + """ + for ir in node.irs: + if isinstance(ir, LowLevelCall): + return True + if isinstance(ir, HighLevelCall) and not isinstance(ir, LibraryCall): + return True + return False + + @staticmethod + def _can_send_eth(node): + """ + Detect if the node can send eth + """ + for ir in node.irs: + if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)): + if ir.call_value: + return True + return False + + def _check_on_call_returned(self, node): + """ + Check if the node is a condtional node where + there is an external call checked + Heuristic: + - The call is a IF node + - It contains a, external call + - The condition is the negation (!) + + This will work only on naive implementation + """ + return isinstance(node.expression, UnaryOperation)\ + and node.expression.type == UnaryOperationType.BANG + + def _explore(self, node, visited): + """ + Explore the CFG and look for re-entrancy + Heuristic: There is a re-entrancy if a state variable is written + after an external call + + node.context will contains the external calls executed + It contains the calls executed in father nodes + + if node.context is not empty, and variables are written, a re-entrancy is possible + """ + if node in visited: + return + + visited = visited + [node] + + # First we add the external calls executed in previous nodes + # send_eth returns the list of calls sending value + # calls returns the list of calls that can callback + # read returns the variable read + fathers_context = {'send_eth':[], 'calls':[], 'read':[]} + + for father in node.fathers: + if self.key in father.context: + fathers_context['send_eth'] += father.context[self.key]['send_eth'] + fathers_context['calls'] += father.context[self.key]['calls'] + fathers_context['read'] += father.context[self.key]['read'] + + # Exclude path that dont bring further information + if node in self.visited_all_paths: + if all(call in self.visited_all_paths[node]['calls'] for call in fathers_context['calls']): + if all(send in self.visited_all_paths[node]['send_eth'] for send in fathers_context['send_eth']): + if all(read in self.visited_all_paths[node]['read'] for read in fathers_context['read']): + return + else: + self.visited_all_paths[node] = {'send_eth':[], 'calls':[], 'read':[]} + + self.visited_all_paths[node]['send_eth'] = list(set(self.visited_all_paths[node]['send_eth'] + fathers_context['send_eth'])) + self.visited_all_paths[node]['calls'] = list(set(self.visited_all_paths[node]['calls'] + fathers_context['calls'])) + self.visited_all_paths[node]['read'] = list(set(self.visited_all_paths[node]['read'] + fathers_context['read'])) + + node.context[self.key] = fathers_context + + contains_call = False + if self._can_callback(node): + node.context[self.key]['calls'] = list(set(node.context[self.key]['calls'] + [node])) + contains_call = True + if self._can_send_eth(node): + node.context[self.key]['send_eth'] = list(set(node.context[self.key]['send_eth'] + [node])) + + + # All the state variables written + state_vars_written = node.state_variables_written + # Add the state variables written in internal calls + for internal_call in node.internal_calls: + # Filter to Function, as internal_call can be a solidity call + if isinstance(internal_call, Function): + state_vars_written += internal_call.all_state_variables_written() + + read_then_written = [(v, node.source_mapping_str) for v in state_vars_written if v in node.context[self.key]['read']] + + node.context[self.key]['read'] = list(set(node.context[self.key]['read'] + node.state_variables_read)) + # If a state variables was read and is then written, there is a dangerous call and + # ether were sent + # We found a potential re-entrancy bug + if (read_then_written and + node.context[self.key]['calls'] and + node.context[self.key]['send_eth']): + # calls are ordered + finding_key = (node.function, + tuple(set(node.context[self.key]['calls'])), + tuple(set(node.context[self.key]['send_eth']))) + finding_vars = read_then_written + if finding_key not in self.result: + self.result[finding_key] = [] + self.result[finding_key] = list(set(self.result[finding_key] + finding_vars)) + + sons = node.sons + if contains_call and self._check_on_call_returned(node): + sons = sons[1:] + + for son in sons: + self._explore(son, visited) + + def detect_reentrancy(self, contract): + """ + """ + for function in contract.functions: + if function.is_implemented: + self._explore(function.entry_point, []) + + def detect(self): + """ + """ + self.result = {} + + # if a node was already visited by another path + # we will only explore it if the traversal brings + # new variables written + # This speedup the exploration through a light fixpoint + # Its particular useful on 'complex' functions with several loops and conditions + self.visited_all_paths = {} + + for c in self.contracts: + self.detect_reentrancy(c) + + results = [] + + for (func, calls, send_eth), varsWritten in self.result.items(): + calls = list(set(calls)) + send_eth = list(set(send_eth)) +# if calls == send_eth: +# calls_info = 'Call: {},'.format(calls_str) +# else: +# calls_info = 'Call: {}, Ether sent: {},'.format(calls_str, send_eth_str) + info = 'Reentrancy in {}.{} ({}):\n' + info = info.format(func.contract.name, func.name, func.source_mapping_str) + info += '\tExternal calls:\n' + for call_info in calls: + info += '\t- {} ({})\n'.format(call_info.expression, call_info.source_mapping_str) + if calls != send_eth: + info += '\tExternal calls sending eth:\n' + for call_info in send_eth: + info += '\t- {} ({})\n'.format(call_info.expression, call_info.source_mapping_str) + info += '\tState variables written after the call(s):\n' + for (v, mapping) in varsWritten: + info += '\t- {} ({})\n'.format(v, mapping) + self.log(info) + + source = [v.source_mapping for (v,_) in varsWritten] + source += [node.source_mapping for node in calls] + source += [node.source_mapping for node in send_eth] + + results.append({'vuln': 'Reentrancy', + 'sourceMapping': source, + 'filename': self.filename, + 'contract': func.contract.name, + 'function_name': func.name, + 'calls': [str(x.expression) for x in calls], + 'send_eth': [str(x.expression) for x in send_eth], + 'varsWritten': [str(x) for (x,_) in varsWritten]}) + + return results diff --git a/slither/detectors/shadowing/shadowingFunctionsDetection.py b/slither/detectors/shadowing/shadowing_functions.py similarity index 80% rename from slither/detectors/shadowing/shadowingFunctionsDetection.py rename to slither/detectors/shadowing/shadowing_functions.py index e41bc3ee0..60d9677df 100644 --- a/slither/detectors/shadowing/shadowingFunctionsDetection.py +++ b/slither/detectors/shadowing/shadowing_functions.py @@ -3,8 +3,8 @@ It is more useful as summary printer than as vuln detection """ -from slither.detectors.abstractDetector import AbstractDetector -from slither.detectors.detectorClassification import DetectorClassification +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification + class ShadowingFunctionsDetection(AbstractDetector): """ @@ -15,14 +15,13 @@ class ShadowingFunctionsDetection(AbstractDetector): ARGUMENT = 'shadowing-function' HELP = 'Function Shadowing' - CLASSIFICATION = DetectorClassification.LOW - - HIDDEN_DETECTOR = True + IMPACT = DetectorClassification.LOW + CONFIDENCE = DetectorClassification.HIGH def detect_shadowing(self, contract): functions_declared = set([x.full_name for x in contract.functions]) ret = {} - for father in contract.inheritances: + for father in contract.inheritance: functions_declared_father = ([x.full_name for x in father.functions]) inter = functions_declared.intersection(functions_declared_father) if inter: @@ -42,7 +41,7 @@ class ShadowingFunctionsDetection(AbstractDetector): shadowing = self.detect_shadowing(c) if shadowing: for contract, funcs in shadowing.items(): - results.append({'vuln':self.vuln_name, + results.append({'vuln': self.vuln_name, 'filename': self.filename, 'contractShadower': c.name, 'contract': contract.name, diff --git a/slither/detectors/statements/__init__.py b/slither/detectors/statements/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slither/detectors/statements/assembly.py b/slither/detectors/statements/assembly.py new file mode 100644 index 000000000..a10697d3a --- /dev/null +++ b/slither/detectors/statements/assembly.py @@ -0,0 +1,62 @@ +""" +Module detecting usage of inline assembly +""" + +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.core.cfg.node import NodeType + + +class Assembly(AbstractDetector): + """ + Detect usage of inline assembly + """ + + ARGUMENT = 'assembly' + HELP = 'Assembly usage' + IMPACT = DetectorClassification.INFORMATIONAL + CONFIDENCE = DetectorClassification.HIGH + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#assembly-usage' + + @staticmethod + def _contains_inline_assembly_use(node): + """ + Check if the node contains ASSEMBLY type + Returns: + (bool) + """ + return node.type == NodeType.ASSEMBLY + + def detect_assembly(self, contract): + ret = [] + for f in contract.functions: + nodes = f.nodes + assembly_nodes = [n for n in nodes if + self._contains_inline_assembly_use(n)] + if assembly_nodes: + ret.append((f, assembly_nodes)) + return ret + + def detect(self): + """ Detect the functions that use inline assembly + """ + results = [] + all_info = '' + for c in self.contracts: + values = self.detect_assembly(c) + for func, nodes in values: + info = "{}.{} uses assembly ({})\n" + info = info.format(func.contract.name, func.name, func.source_mapping_str) + all_info += info + + sourceMapping = [n.source_mapping for n in nodes] + + results.append({'vuln': 'Assembly', + 'sourceMapping': sourceMapping, + 'filename': self.filename, + 'contract': func.contract.name, + 'function_name': func.name}) + + if all_info != '': + self.log(all_info) + return results diff --git a/slither/detectors/statements/tx_origin.py b/slither/detectors/statements/tx_origin.py new file mode 100644 index 000000000..fa5ba4938 --- /dev/null +++ b/slither/detectors/statements/tx_origin.py @@ -0,0 +1,69 @@ +""" +Module detecting usage of `tx.origin` in a conditional node +""" + +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification + +class TxOrigin(AbstractDetector): + """ + Detect usage of tx.origin in a conditional node + """ + + ARGUMENT = 'tx-origin' + HELP = 'Dangerous usage of `tx.origin`' + IMPACT = DetectorClassification.MEDIUM + CONFIDENCE = DetectorClassification.MEDIUM + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#dangerous-usage-of-txorigin' + + @staticmethod + def _contains_incorrect_tx_origin_use(node): + """ + Check if the node read tx.origin and dont read msg.sender + Avoid the FP due to (msg.sender == tx.origin) + Returns: + (bool) + """ + solidity_var_read = node.solidity_variables_read + if solidity_var_read: + return any(v.name == 'tx.origin' for v in solidity_var_read) and\ + all(v.name != 'msg.sender' for v in solidity_var_read) + return False + + def detect_tx_origin(self, contract): + ret = [] + for f in contract.functions: + + nodes = f.nodes + condtional_nodes = [n for n in nodes if n.contains_if() or + n.contains_require_or_assert()] + bad_tx_nodes = [n for n in condtional_nodes if + self._contains_incorrect_tx_origin_use(n)] + if bad_tx_nodes: + ret.append((f, bad_tx_nodes)) + return ret + + def detect(self): + """ Detect the functions that use tx.origin in a conditional node + """ + results = [] + for c in self.contracts: + values = self.detect_tx_origin(c) + for func, nodes in values: + info = "{}.{} uses tx.origin for authorization:\n" + info = info.format(func.contract.name, func.name) + + for node in nodes: + info += "\t- {} ({})\n".format(node.expression, node.source_mapping_str) + + self.log(info) + + sourceMapping = [n.source_mapping for n in nodes] + + results.append({'vuln': 'TxOrigin', + 'sourceMapping': sourceMapping, + 'filename': self.filename, + 'contract': func.contract.name, + 'function_name': func.name}) + + return results diff --git a/slither/detectors/variables/possible_const_state_variables.py b/slither/detectors/variables/possible_const_state_variables.py new file mode 100644 index 000000000..61d34efff --- /dev/null +++ b/slither/detectors/variables/possible_const_state_variables.py @@ -0,0 +1,82 @@ +""" +Module detecting state variables that could be declared as constant +""" + +from collections import defaultdict +from slither.core.solidity_types.elementary_type import ElementaryType +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.slithir.operations import OperationWithLValue +from slither.core.variables.state_variable import StateVariable +from slither.core.expressions.literal import Literal + + +class ConstCandidateStateVars(AbstractDetector): + """ + State variables that could be declared as constant detector. + Not all types for constants are implemented in Solidity as of 0.4.25. + The only supported types are value types and strings (ElementaryType). + Reference: https://solidity.readthedocs.io/en/latest/contracts.html#constant-state-variables + """ + + ARGUMENT = 'constable-states' + HELP = 'State variables that could be declared constant' + IMPACT = DetectorClassification.INFORMATIONAL + CONFIDENCE = DetectorClassification.HIGH + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#state-variables-that-could-be-declared-constant' + + @staticmethod + def lvalues_of_operations_with_lvalue(contract): + ret = [] + for f in contract.all_functions_called + contract.modifiers: + for n in f.nodes: + for ir in n.irs: + if isinstance(ir, OperationWithLValue) and isinstance(ir.lvalue, StateVariable): + ret.append(ir.lvalue) + return ret + + @staticmethod + def non_const_state_variables(contract): + return [variable for variable in contract.state_variables + if not variable.is_constant and type(variable.expression) == Literal] + + def detect_const_candidates(self, contract): + const_candidates = [] + non_const_state_vars = self.non_const_state_variables(contract) + lvalues_of_operations = self.lvalues_of_operations_with_lvalue(contract) + for non_const in non_const_state_vars: + if non_const not in lvalues_of_operations \ + and non_const not in const_candidates \ + and isinstance(non_const.type, ElementaryType): + const_candidates.append(non_const) + + return const_candidates + + def detect(self): + """ Detect state variables that could be const + """ + results = [] + all_info = '' + for c in self.slither.contracts_derived: + const_candidates = self.detect_const_candidates(c) + if const_candidates: + variables_by_contract = defaultdict(list) + + for state_var in const_candidates: + variables_by_contract[state_var.contract.name].append(state_var) + + for contract, variables in variables_by_contract.items(): + variable_names = [v.name for v in variables] + for v in variables: + all_info += "{}.{} should be constant ({})\n".format(contract, v.name, v.source_mapping_str) + + sourceMapping = [v.source_mapping for v in const_candidates] + + results.append({'vuln': 'ConstStateVariableCandidates', + 'sourceMapping': sourceMapping, + 'filename': self.filename, + 'contract': c.name, + 'unusedVars': variable_names}) + if all_info != '': + self.log(all_info) + return results diff --git a/slither/detectors/variables/uninitializedStateVarsDetection.py b/slither/detectors/variables/uninitializedStateVarsDetection.py deleted file mode 100644 index b47641975..000000000 --- a/slither/detectors/variables/uninitializedStateVarsDetection.py +++ /dev/null @@ -1,75 +0,0 @@ -""" - Module detecting state uninitialized variables - Recursively check the called functions - - The heuristic chekcs that: - - state variables are read or called - - the variables does not call push (avoid too many FP) - - Only analyze "leaf" contracts (contracts that are not inherited by another contract) -""" - -from slither.detectors.abstractDetector import AbstractDetector -from slither.detectors.detectorClassification import DetectorClassification - -from slither.visitors.expression.findPush import FindPush - -class UninitializedStateVarsDetection(AbstractDetector): - """ - Constant function detector - """ - - ARGUMENT = 'uninitialized' - HELP = 'Uninitialized state variables' - CLASSIFICATION = DetectorClassification.HIGH - - def detect_uninitialized(self, contract): - # get all the state variables read by all functions - var_read = [f.state_variables_read for f in contract.functions_all_called + contract.modifiers] - # flat list - var_read = [item for sublist in var_read for item in sublist] - # remove state variable that are initiliazed at contract construction - var_read = [v for v in var_read if v.uninitialized] - - # get all the state variables written by the functions - var_written = [f.state_variables_written for f in contract.functions_all_called + contract.modifiers] - # flat list - var_written = [item for sublist in var_written for item in sublist] - - all_push = [f.apply_visitor(FindPush) for f in contract.functions] - # flat list - all_push = [item for sublist in all_push for item in sublist] - - uninitialized_vars = list(set([v for v in var_read if\ - v not in var_written and\ - v not in all_push and\ - str(v.type) not in contract.using_for])) # Note: does not handle using X for * - - return [(v, contract.get_functions_reading_variable(v)) for v in uninitialized_vars] - - - def detect(self): - """ Detect uninitialized state variables - - Recursively visit the calls - Returns: - dict: [contract name] = set(state variable uninitialized) - """ - results = [] - for c in self.slither.contracts_derived: - ret = self.detect_uninitialized(c) - for variable, functions in ret: - info = "Uninitialized state variables in %s, "%self.filename +\ - "Contract: %s, Vars: %s, Used in %s"%(c.name, - str(variable), - [str(f) for f in functions]) - self.log(info) - - results.append({'vuln':'UninitializedStateVars', - 'sourceMapping': c.source_mapping, - 'filename': self.filename, - 'contract': c.name, - 'functions': [str(f) for f in functions], - 'variable': str(variable)}) - - return results diff --git a/slither/detectors/variables/uninitialized_local_variables.py b/slither/detectors/variables/uninitialized_local_variables.py new file mode 100644 index 000000000..111b0cf0e --- /dev/null +++ b/slither/detectors/variables/uninitialized_local_variables.py @@ -0,0 +1,100 @@ +""" + Module detecting state uninitialized local variables + + Recursively explore the CFG to only report uninitialized local variables that are + written before being read +""" + +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification + +from slither.visitors.expression.find_push import FindPush + + +class UninitializedLocalVars(AbstractDetector): + """ + """ + + ARGUMENT = 'uninitialized-local' + HELP = 'Uninitialized local variables' + IMPACT = DetectorClassification.HIGH + CONFIDENCE = DetectorClassification.HIGH + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#uninitialized-local-variables' + + key = "UNINITIALIZEDLOCAL" + + def _detect_uninitialized(self, function, node, visited): + if node in visited: + return + + visited = visited + [node] + + fathers_context = [] + + for father in node.fathers: + if self.key in father.context: + fathers_context += father.context[self.key] + + # Exclude path that dont bring further information + if node in self.visited_all_paths: + if all(f_c in self.visited_all_paths[node] for f_c in fathers_context): + return + else: + self.visited_all_paths[node] = [] + + self.visited_all_paths[node] = list(set(self.visited_all_paths[node] + fathers_context)) + + if self.key in node.context: + fathers_context += node.context[self.key] + + variables_read = node.variables_read + for uninitialized_local_variable in fathers_context: + if uninitialized_local_variable in variables_read: + self.results.append((function, uninitialized_local_variable)) + + # Only save the local variables that are not yet written + uninitialized_local_variables = list(set(fathers_context) - set(node.variables_written)) + node.context[self.key] = uninitialized_local_variables + + for son in node.sons: + self._detect_uninitialized(function, son, visited) + + + def detect(self): + """ Detect uninitialized state variables + + Recursively visit the calls + Returns: + dict: [contract name] = set(state variable uninitialized) + """ + results = [] + + self.results = [] + self.visited_all_paths = {} + + for contract in self.slither.contracts: + for function in contract.functions: + if function.is_implemented: + # dont consider storage variable, as they are detected by another detector + uninitialized_local_variables = [v for v in function.local_variables if not v.is_storage and v.uninitialized] + function.entry_point.context[self.key] = uninitialized_local_variables + self._detect_uninitialized(function, function.entry_point, []) + + for(function, uninitialized_local_variable) in self.results: + var_name = uninitialized_local_variable.name + + info = "{} in {}.{} ({}) is a local variable never initialiazed\n" + info = info.format(var_name, function.contract.name, function.name, uninitialized_local_variable.source_mapping_str) + + self.log(info) + + source = [function.source_mapping, uninitialized_local_variable.source_mapping] + + results.append({'vuln': 'UninitializedLocalVars', + 'sourceMapping': source, + 'filename': self.filename, + 'contract': function.contract.name, + 'function': function.name, + 'variable': var_name}) + + return results diff --git a/slither/detectors/variables/uninitialized_state_variables.py b/slither/detectors/variables/uninitialized_state_variables.py new file mode 100644 index 000000000..3046268a8 --- /dev/null +++ b/slither/detectors/variables/uninitialized_state_variables.py @@ -0,0 +1,93 @@ +""" + Module detecting state uninitialized variables + Recursively check the called functions + + The heuristic checks: + - state variables including mappings/refs + - LibraryCalls, InternalCalls, InternalDynamicCalls with storage variables + + Only analyze "leaf" contracts (contracts that are not inherited by another contract) +""" + +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.core.variables.state_variable import StateVariable +from slither.slithir.variables import ReferenceVariable +from slither.slithir.operations.assignment import Assignment + +from slither.slithir.operations import (OperationWithLValue, Index, Member, + InternalCall, InternalDynamicCall, LibraryCall) + + +class UninitializedStateVarsDetection(AbstractDetector): + """ + Constant function detector + """ + + ARGUMENT = 'uninitialized-state' + HELP = 'Uninitialized state variables' + IMPACT = DetectorClassification.HIGH + CONFIDENCE = DetectorClassification.HIGH + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#uninitialized-state-variables' + + @staticmethod + def written_variables(contract): + ret = [] + for f in contract.all_functions_called + contract.modifiers: + for n in f.nodes: + ret += n.state_variables_written + for ir in n.irs: + if isinstance(ir, LibraryCall) \ + or isinstance(ir, InternalCall): + idx = 0 + if ir.function: + for param in ir.function.parameters: + if param.location == 'storage': + ret.append(ir.arguments[idx]) + idx = idx+1 + + return ret + + @staticmethod + def read_variables(contract): + ret = [] + for f in contract.all_functions_called + contract.modifiers: + ret += f.state_variables_read + return ret + + def detect_uninitialized(self, contract): + written_variables = self.written_variables(contract) + read_variables = self.read_variables(contract) + return [(variable, contract.get_functions_reading_from_variable(variable)) + for variable in contract.state_variables if variable not in written_variables and\ + not variable.expression and\ + variable in read_variables] + + def detect(self): + """ Detect uninitialized state variables + + Recursively visit the calls + Returns: + dict: [contract name] = set(state variable uninitialized) + """ + results = [] + for c in self.slither.contracts_derived: + ret = self.detect_uninitialized(c) + for variable, functions in ret: + info = "{}.{} ({}) is never initialized. It is used in:\n" + info = info.format(variable.contract.name, variable.name, variable.source_mapping_str) + for f in functions: + info += "\t- {} ({})\n".format(f.name, f.source_mapping_str) + self.log(info) + + source = [variable.source_mapping] + source += [f.source_mapping for f in functions] + + results.append({'vuln': 'UninitializedStateVars', + 'sourceMapping': source, + 'filename': self.filename, + 'contract': c.name, + 'functions': [str(f) for f in functions], + 'variable': str(variable)}) + + return results diff --git a/slither/detectors/variables/uninitialized_storage_variables.py b/slither/detectors/variables/uninitialized_storage_variables.py new file mode 100644 index 000000000..ccc3fbc77 --- /dev/null +++ b/slither/detectors/variables/uninitialized_storage_variables.py @@ -0,0 +1,100 @@ +""" + Module detecting state uninitialized storage variables + + Recursively explore the CFG to only report uninitialized storage variables that are + written before being read +""" + +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification + +from slither.visitors.expression.find_push import FindPush + + +class UninitializedStorageVars(AbstractDetector): + """ + """ + + ARGUMENT = 'uninitialized-storage' + HELP = 'Uninitialized storage variables' + IMPACT = DetectorClassification.HIGH + CONFIDENCE = DetectorClassification.HIGH + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#uninitialized-storage-variables' + + # node.context[self.key] contains the uninitialized storage variables + key = "UNINITIALIZEDSTORAGE" + + def _detect_uninitialized(self, function, node, visited): + if node in visited: + return + + visited = visited + [node] + + fathers_context = [] + + for father in node.fathers: + if self.key in father.context: + fathers_context += father.context[self.key] + + # Exclude path that dont bring further information + if node in self.visited_all_paths: + if all(f_c in self.visited_all_paths[node] for f_c in fathers_context): + return + else: + self.visited_all_paths[node] = [] + + self.visited_all_paths[node] = list(set(self.visited_all_paths[node] + fathers_context)) + + if self.key in node.context: + fathers_context += node.context[self.key] + + variables_read = node.variables_read + for uninitialized_storage_variable in fathers_context: + if uninitialized_storage_variable in variables_read: + self.results.append((function, uninitialized_storage_variable)) + + # Only save the storage variables that are not yet written + uninitialized_storage_variables = list(set(fathers_context) - set(node.variables_written)) + node.context[self.key] = uninitialized_storage_variables + + for son in node.sons: + self._detect_uninitialized(function, son, visited) + + + def detect(self): + """ Detect uninitialized state variables + + Recursively visit the calls + Returns: + dict: [contract name] = set(state variable uninitialized) + """ + results = [] + + self.results = [] + self.visited_all_paths = {} + + for contract in self.slither.contracts: + for function in contract.functions: + if function.is_implemented: + uninitialized_storage_variables = [v for v in function.local_variables if v.is_storage and v.uninitialized] + function.entry_point.context[self.key] = uninitialized_storage_variables + self._detect_uninitialized(function, function.entry_point, []) + + for(function, uninitialized_storage_variable) in self.results: + var_name = uninitialized_storage_variable.name + + info = "{} in {}.{} ({}) is a storage variable never initialiazed\n" + info = info.format(var_name, function.contract.name, function.name, uninitialized_storage_variable.source_mapping_str) + + self.log(info) + + source = [function.source_mapping, uninitialized_storage_variable.source_mapping] + + results.append({'vuln': 'UninitializedStorageVars', + 'sourceMapping': source, + 'filename': self.filename, + 'contract': function.contract.name, + 'function': function.name, + 'variable': var_name}) + + return results diff --git a/slither/detectors/variables/unused_state_variables.py b/slither/detectors/variables/unused_state_variables.py new file mode 100644 index 000000000..36a42d267 --- /dev/null +++ b/slither/detectors/variables/unused_state_variables.py @@ -0,0 +1,55 @@ +""" +Module detecting unused state variables +""" + +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification + +class UnusedStateVars(AbstractDetector): + """ + Unused state variables detector + """ + + ARGUMENT = 'unused-state' + HELP = 'Unused state variables' + IMPACT = DetectorClassification.INFORMATIONAL + CONFIDENCE = DetectorClassification.HIGH + + WIKI = 'https://github.com/trailofbits/slither/wiki/Vulnerabilities-Description#unused-state-variables' + + def detect_unused(self, contract): + if contract.is_signature_only(): + return None + # Get all the variables read in all the functions and modifiers + variables_used = [x.state_variables_read + x.state_variables_written for x in + (contract.all_functions_called + contract.modifiers)] + # Flat list + variables_used = [item for sublist in variables_used for item in sublist] + # Return the variables unused that are not public + return [x for x in contract.variables if + x not in variables_used and x.visibility != 'public'] + + def detect(self): + """ Detect unused state variables + """ + results = [] + all_info = '' + for c in self.slither.contracts_derived: + unusedVars = self.detect_unused(c) + if unusedVars: + unusedVarsName = [v.name for v in unusedVars] + info = '' + for var in unusedVars: + info += "{}.{} ({}) is never used\n".format(var.contract.name, var.name, var.source_mapping_str) + + all_info += info + + sourceMapping = [v.source_mapping for v in unusedVars] + + results.append({'vuln': 'unusedStateVars', + 'sourceMapping': sourceMapping, + 'filename': self.filename, + 'contract': c.name, + 'unusedVars': unusedVarsName}) + if all_info != '': + self.log(all_info) + return results diff --git a/slither/printers/abstractPrinter.py b/slither/printers/abstract_printer.py similarity index 74% rename from slither/printers/abstractPrinter.py rename to slither/printers/abstract_printer.py index 06ca714ee..839bb725e 100644 --- a/slither/printers/abstractPrinter.py +++ b/slither/printers/abstract_printer.py @@ -1,20 +1,24 @@ import abc + class IncorrectPrinterInitialization(Exception): pass -class AbstractPrinter(object, metaclass=abc.ABCMeta): - ARGUMENT = '' # run the printer with slither.py --ARGUMENT - HELP = '' # help information + +class AbstractPrinter(metaclass=abc.ABCMeta): + ARGUMENT = '' # run the printer with slither.py --ARGUMENT + HELP = '' # help information def __init__(self, slither, logger): self.slither = slither self.contracts = slither.contracts self.filename = slither.filename self.logger = logger - if self.HELP == '': + + if not self.HELP: raise IncorrectPrinterInitialization('HELP is not initialized') - if self.ARGUMENT == '': + + if not self.ARGUMENT: raise IncorrectPrinterInitialization('ARGUMENT is not initialized') def info(self, info): diff --git a/slither/printers/call/__init__.py b/slither/printers/call/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slither/printers/call/call_graph.py b/slither/printers/call/call_graph.py new file mode 100644 index 000000000..bc446180d --- /dev/null +++ b/slither/printers/call/call_graph.py @@ -0,0 +1,155 @@ +""" + Module printing the call graph + + The call graph shows for each function, + what are the contracts/functions called. + The output is a dot file named filename.dot +""" + +from slither.printers.abstract_printer import AbstractPrinter +from slither.core.declarations.solidity_variables import SolidityFunction +from slither.core.declarations.function import Function +from slither.core.declarations.contract import Contract +from slither.core.expressions.member_access import MemberAccess +from slither.core.expressions.identifier import Identifier +from slither.core.variables.variable import Variable +from slither.core.solidity_types.user_defined_type import UserDefinedType + +# return unique id for contract to use as subgraph name +def _contract_subgraph(contract): + return f'cluster_{contract.id}_{contract.name}' + +# return unique id for contract function to use as node name +def _function_node(contract, function): + return f'{contract.id}_{function.name}' + +# return unique id for solidity function to use as node name +def _solidity_function_node(solidity_function): + return f'{solidity_function.name}' + +# return dot language string to add graph edge +def _edge(from_node, to_node): + return f'"{from_node}" -> "{to_node}"' + +# return dot language string to add graph node (with optional label) +def _node(node, label=None): + return ' '.join(( + f'"{node}"', + f'[label="{label}"]' if label is not None else '', + )) + +class PrinterCallGraph(AbstractPrinter): + ARGUMENT = 'call-graph' + HELP = 'Export the call-graph of the contracts to a dot file' + + def __init__(self, slither, logger): + super(PrinterCallGraph, self).__init__(slither, logger) + + self.contract_functions = {} # contract -> contract functions nodes + self.contract_calls = {} # contract -> contract calls edges + + for contract in slither.contracts: + self.contract_functions[contract] = set() + self.contract_calls[contract] = set() + + self.solidity_functions = set() # solidity function nodes + self.solidity_calls = set() # solidity calls edges + + self.external_calls = set() # external calls edges + + self._process_contracts(slither.contracts) + + def _process_contracts(self, contracts): + for contract in contracts: + for function in contract.functions: + self._process_function(contract, function) + + def _process_function(self, contract, function): + self.contract_functions[contract].add( + _node(_function_node(contract, function), function.name), + ) + + for internal_call in function.internal_calls: + self._process_internal_call(contract, function, internal_call) + for external_call in function.high_level_calls: + self._process_external_call(contract, function, external_call) + + def _process_internal_call(self, contract, function, internal_call): + if isinstance(internal_call, (Function)): + self.contract_calls[contract].add(_edge( + _function_node(contract, function), + _function_node(contract, internal_call), + )) + elif isinstance(internal_call, (SolidityFunction)): + self.solidity_functions.add( + _node(_solidity_function_node(internal_call)), + ) + self.solidity_calls.add(_edge( + _function_node(contract, function), + _solidity_function_node(internal_call), + )) + + def _process_external_call(self, contract, function, external_call): + external_contract, external_function = external_call + + # add variable as node to respective contract + if isinstance(external_function, (Variable)): + self.contract_functions[external_contract].add(_node( + _function_node(external_contract, external_function), + external_function.name + )) + + self.external_calls.add(_edge( + _function_node(contract, function), + _function_node(external_contract, external_function), + )) + + def _render_internal_calls(self): + lines = [] + + for contract in self.contract_functions: + lines.append(f'subgraph {_contract_subgraph(contract)} {{') + lines.append(f'label = "{contract.name}"') + + lines.extend(self.contract_functions[contract]) + lines.extend(self.contract_calls[contract]) + + lines.append('}') + + return '\n'.join(lines) + + def _render_solidity_calls(self): + lines = [] + + lines.append('subgraph cluster_solidity {') + lines.append('label = "[Solidity]"') + + lines.extend(self.solidity_functions) + lines.extend(self.solidity_calls) + + lines.append('}') + + return '\n'.join(lines) + + def _render_external_calls(self): + return '\n'.join(self.external_calls) + + def output(self, filename): + """ + Output the graph in filename + Args: + filename(string) + """ + if not filename.endswith('.dot'): + filename += '.dot' + + self.info(f'Call Graph: {filename}') + + with open(filename, 'w') as f: + f.write('\n'.join([ + 'strict digraph {', + self._render_internal_calls(), + self._render_solidity_calls(), + self._render_external_calls(), + '}', + ])) diff --git a/slither/printers/functions/authorization.py b/slither/printers/functions/authorization.py index b382e9498..b2e2a1086 100644 --- a/slither/printers/functions/authorization.py +++ b/slither/printers/functions/authorization.py @@ -3,17 +3,17 @@ """ from prettytable import PrettyTable -from slither.printers.abstractPrinter import AbstractPrinter +from slither.printers.abstract_printer import AbstractPrinter from slither.core.declarations.function import Function class PrinterWrittenVariablesAndAuthorization(AbstractPrinter): - ARGUMENT = 'print-variables-written-and-authorization' - HELP = 'Print the the variables written and the authorization of the functions' + ARGUMENT = 'vars-and-auth' + HELP = 'Print the state variables written and the authorization of the functions' @staticmethod def get_msg_sender_checks(function): - all_functions = function.all_calls() + [function] + function.modifiers + all_functions = function.all_internal_calls() + [function] + function.modifiers all_nodes = [f.nodes for f in all_functions if isinstance(f, Function)] all_nodes = [item for sublist in all_nodes for item in sublist] @@ -33,10 +33,10 @@ class PrinterWrittenVariablesAndAuthorization(AbstractPrinter): for contract in self.contracts: txt = "\nContract %s\n"%contract.name - table = PrettyTable(["Function", "State variable written", "Condition on msg.sender"]) + table = PrettyTable(["Function", "State variables written", "Conditions on msg.sender"]) for function in contract.functions: state_variables_written = [v.name for v in function.all_state_variables_written()] msg_sender_condition = self.get_msg_sender_checks(function) table.add_row([function.name, str(state_variables_written), str(msg_sender_condition)]) - self.info(txt + str(table)) + self.info(txt + str(table)) diff --git a/slither/printers/inheritance/inheritance.py b/slither/printers/inheritance/inheritance.py new file mode 100644 index 000000000..d7b4a6b11 --- /dev/null +++ b/slither/printers/inheritance/inheritance.py @@ -0,0 +1,46 @@ +""" + Module printing the inheritance relation + + The inheritance shows the relation between the contracts +""" + +from slither.printers.abstract_printer import AbstractPrinter +from slither.utils.colors import blue, green + + +class PrinterInheritance(AbstractPrinter): + ARGUMENT = 'inheritance' + HELP = 'Print the inheritance relations between contracts' + + def _get_child_contracts(self, base): + # Generate function to get all child contracts of a base contract + for child in self.contracts: + if base in child.inheritance: + yield child + + def output(self, filename): + """ + Output the inheritance relation + + _filename is not used + Args: + _filename(string) + """ + info = 'Inheritance\n' + + if not self.contracts: + return + + info += blue('Child_Contract -> ') + green('Base_Contracts') + for child in self.contracts: + info += blue(f'\n+ {child.name}') + if child.inheritance: + info += ' -> ' + green(", ".join(map(str, child.inheritance))) + + info += green('\n\nBase_Contract -> ') + blue('Child_Contracts') + for base in self.contracts: + info += green(f'\n+ {base.name}') + children = list(self._get_child_contracts(base)) + if children: + info += ' -> ' + blue(", ".join(map(str, children))) + self.info(info) diff --git a/slither/printers/inheritance/printerInheritance.py b/slither/printers/inheritance/inheritance_graph.py similarity index 64% rename from slither/printers/inheritance/printerInheritance.py rename to slither/printers/inheritance/inheritance_graph.py index 9170be9e5..f76e9fb07 100644 --- a/slither/printers/inheritance/printerInheritance.py +++ b/slither/printers/inheritance/inheritance_graph.py @@ -6,20 +6,18 @@ The output is a dot file named filename.dot """ -from slither.printers.abstractPrinter import AbstractPrinter -from slither.detectors.shadowing.shadowingFunctionsDetection import ShadowingFunctionsDetection - from slither.core.declarations.contract import Contract +from slither.detectors.shadowing.shadowing_functions import ShadowingFunctionsDetection +from slither.printers.abstract_printer import AbstractPrinter -class PrinterInheritance(AbstractPrinter): - - ARGUMENT = 'print-inheritance' - HELP = 'Print the inheritance graph' +class PrinterInheritanceGraph(AbstractPrinter): + ARGUMENT = 'inheritance-graph' + HELP = 'Export the inheritance graph of each contract to a dot file' def __init__(self, slither, logger): - super(PrinterInheritance, self).__init__(slither, logger) + super(PrinterInheritanceGraph, self).__init__(slither, logger) - inheritance = [x.inheritances for x in slither.contracts] + inheritance = [x.inheritance for x in slither.contracts] self.inheritance = set([item for sublist in inheritance for item in sublist]) shadow = ShadowingFunctionsDetection(slither, None) @@ -38,22 +36,22 @@ class PrinterInheritance(AbstractPrinter): pattern_shadow = ' %s' if contract.name in self.functions_shadowed: if func_name in self.functions_shadowed[contract.name]: - return pattern_shadow%func_name - return pattern%func_name + return pattern_shadow % func_name + return pattern % func_name def _get_pattern_var(self, var, contract): # Html pattern, each line is a row in a table var_name = var.name pattern = ' %s' - pattern_contract = ' %s (%s)' - #pattern_arrow = ' %s' + pattern_contract = ' %s (%s)' + # pattern_arrow = ' %s' if isinstance(var.type, Contract): - return pattern_contract%(var_name, str(var.type)) - #return pattern_arrow%(self._get_port_id(var, contract), var_name) - return pattern%var_name + return pattern_contract % (var_name, str(var.type)) + # return pattern_arrow%(self._get_port_id(var, contract), var_name) + return pattern % var_name def _get_port_id(self, var, contract): - return "%s%s"%(var.name, contract.name) + return "%s%s" % (var.name, contract.name) def _summary(self, contract): """ @@ -61,44 +59,48 @@ class PrinterInheritance(AbstractPrinter): """ ret = '' # Add arrows - for i in contract.inheritances: - ret += '%s -> %s;\n'%(contract.name, i) + for i in contract.inheritance: + ret += '%s -> %s;\n' % (contract.name, i) # Functions visibilities = ['public', 'external'] - public_functions = [self._get_pattern_func(f, contract) for f in contract.functions if not f.is_constructor and f.contract == contract and f.visibility in visibilities] + public_functions = [self._get_pattern_func(f, contract) for f in contract.functions if + not f.is_constructor and f.contract == contract and f.visibility in visibilities] public_functions = ''.join(public_functions) - private_functions = [self._get_pattern_func(f, contract) for f in contract.functions if not f.is_constructor and f.contract == contract and f.visibility not in visibilities] + private_functions = [self._get_pattern_func(f, contract) for f in contract.functions if + not f.is_constructor and f.contract == contract and f.visibility not in visibilities] private_functions = ''.join(private_functions) # Modifiers modifiers = [self._get_pattern_func(m, contract) for m in contract.modifiers if m.contract == contract] modifiers = ''.join(modifiers) # Public variables - public_variables = [self._get_pattern_var(v, contract) for v in contract.variables if v.visibility in visibilities] + public_variables = [self._get_pattern_var(v, contract) for v in contract.variables if + v.visibility in visibilities] public_variables = ''.join(public_variables) - private_variables = [self._get_pattern_var(v, contract) for v in contract.variables if not v.visibility in visibilities] + private_variables = [self._get_pattern_var(v, contract) for v in contract.variables if + not v.visibility in visibilities] private_variables = ''.join(private_variables) # Build the node label - ret += '%s[shape="box"'%contract.name + ret += '%s[shape="box"' % contract.name ret += 'label=< ' - ret += ''%contract.name + ret += '' % contract.name if public_functions: ret += '' - ret += '%s'%public_functions + ret += '%s' % public_functions if private_functions: ret += '' - ret += '%s'%private_functions + ret += '%s' % private_functions if modifiers: ret += '' - ret += '%s'%modifiers + ret += '%s' % modifiers if public_variables: ret += '' - ret += '%s'%public_variables + ret += '%s' % public_variables if private_variables: ret += '' - ret += '%s'%private_variables + ret += '%s' % private_variables ret += '
%s
%s
Public Functions:
Private Functions:
Modifiers:
Public Variables:
Private Variables:
>];\n' return ret @@ -111,7 +113,7 @@ class PrinterInheritance(AbstractPrinter): """ if not filename.endswith('.dot'): filename += ".dot" - info = 'Inheritance Graph: '+filename + info = 'Inheritance Graph: ' + filename self.info(info) with open(filename, 'w') as f: f.write('digraph{\n') diff --git a/slither/printers/printers.py b/slither/printers/printers.py deleted file mode 100644 index a59cd40c8..000000000 --- a/slither/printers/printers.py +++ /dev/null @@ -1,32 +0,0 @@ -import sys, inspect -import logging - -from slither.printers.abstractPrinter import AbstractPrinter - -# Printer must be imported here -from slither.printers.summary.printerSummary import PrinterSummary -from slither.printers.summary.printerQuickSummary import PrinterQuickSummary -from slither.printers.summary.printer_human_summary import PrinterHumanSummary -from slither.printers.inheritance.printerInheritance import PrinterInheritance -from slither.printers.functions.authorization import PrinterWrittenVariablesAndAuthorization - -logger_printer = logging.getLogger("Printers") - -class Printers: - - def __init__(self): - self.printers = {} - self._load_printers() - - def _load_printers(self): - for name, obj in inspect.getmembers(sys.modules[__name__]): - if inspect.isclass(obj): - if issubclass(obj, AbstractPrinter) and name != 'AbstractPrinter': - if name in self.printers: - raise Exception('Printer name collision: {}'.format(name)) - self.printers[name] = obj - - def run_printer(self, slither, name): - Printer = self.printers[name] - instance = Printer(slither, logger_printer) - return instance.output(slither.filename) diff --git a/slither/printers/summary/printerQuickSummary.py b/slither/printers/summary/contract.py similarity index 62% rename from slither/printers/summary/printerQuickSummary.py rename to slither/printers/summary/contract.py index 4c7a2219e..51409ce91 100644 --- a/slither/printers/summary/printerQuickSummary.py +++ b/slither/printers/summary/contract.py @@ -2,13 +2,13 @@ Module printing summary of the contract """ -from slither.printers.abstractPrinter import AbstractPrinter +from slither.printers.abstract_printer import AbstractPrinter from slither.utils.colors import blue, green, magenta -class PrinterQuickSummary(AbstractPrinter): +class ContractSummary(AbstractPrinter): - ARGUMENT = 'print-quick-summary' - HELP = 'Print a quick summary of the contract' + ARGUMENT = 'contract-summary' + HELP = 'Print a summary of the contracts' def output(self, _filename): """ @@ -19,13 +19,13 @@ class PrinterQuickSummary(AbstractPrinter): txt = "" for c in self.contracts: - (name, var, func_summaries, modif_summaries) = c.get_summary() + (name, _inheritance, _var, func_summaries, _modif_summaries) = c.get_summary() txt += blue("\n+ Contract %s\n"%name) - for (f_name, visi, modifiers, read, write, calls) in func_summaries: + for (f_name, visi, _, _, _, _, _) in func_summaries: txt += " - " if visi in ['external', 'public']: txt += green("%s (%s)\n"%(f_name, visi)) - elif visi in ['internal','private']: + elif visi in ['internal', 'private']: txt += magenta("%s (%s)\n"%(f_name, visi)) else: txt += "%s (%s)\n"%(f_name, visi) diff --git a/slither/printers/summary/function.py b/slither/printers/summary/function.py new file mode 100644 index 000000000..a56552035 --- /dev/null +++ b/slither/printers/summary/function.py @@ -0,0 +1,62 @@ +""" + Module printing summary of the contract +""" + +from prettytable import PrettyTable +from slither.printers.abstract_printer import AbstractPrinter + +class FunctionSummary(AbstractPrinter): + + ARGUMENT = 'function-summary' + HELP = 'Print a summary of the functions' + + @staticmethod + def _convert(l): + if l: + n = 2 + l = [l[i:i + n] for i in range(0, len(l), n)] + l = [str(x) for x in l] + return "\n".join(l) + return str(l) + + def output(self, _filename): + """ + _filename is not used + Args: + _filename(string) + """ + + for c in self.contracts: + (name, inheritance, var, func_summaries, modif_summaries) = c.get_summary() + txt = "\nContract %s"%name + txt += '\nContract vars: '+str(var) + txt += '\nInheritance:: '+str(inheritance) + table = PrettyTable(["Function", + "Visibility", + "Modifiers", + "Read", + "Write", + "Internal Calls", + "External Calls"]) + for (f_name, visi, modifiers, read, write, internal_calls, external_calls) in func_summaries: + read = self._convert(read) + write = self._convert(write) + internal_calls = self._convert(internal_calls) + external_calls = self._convert(external_calls) + table.add_row([f_name, visi, modifiers, read, write, internal_calls, external_calls]) + txt += "\n \n"+str(table) + table = PrettyTable(["Modifiers", + "Visibility", + "Read", + "Write", + "Internal Calls", + "External Calls"]) + for (f_name, visi, _, read, write, internal_calls, external_calls) in modif_summaries: + read = self._convert(read) + write = self._convert(write) + internal_calls = self._convert(internal_calls) + external_calls = self._convert(external_calls) + table.add_row([f_name, visi, read, write, internal_calls, external_calls]) + txt += "\n\n"+str(table) + txt += "\n" + self.info(txt) diff --git a/slither/printers/summary/printer_human_summary.py b/slither/printers/summary/human_summary.py similarity index 71% rename from slither/printers/summary/printer_human_summary.py rename to slither/printers/summary/human_summary.py index 1f5e14092..cadd41948 100644 --- a/slither/printers/summary/printer_human_summary.py +++ b/slither/printers/summary/human_summary.py @@ -3,21 +3,20 @@ Module printing summary of the contract """ import logging -from slither.detectors.detectors import Detectors -from slither.printers.abstractPrinter import AbstractPrinter +from slither.printers.abstract_printer import AbstractPrinter from slither.utils.code_complexity import compute_cyclomatic_complexity from slither.utils.colors import green, red, yellow class PrinterHumanSummary(AbstractPrinter): - ARGUMENT = 'print-human-summary' + ARGUMENT = 'human-summary' HELP = 'Print a human readable summary of the contracts' @staticmethod def get_summary_erc20(contract): txt = '' - functions_name = [f.name.lower() for f in contract.functions] - state_variables = [v.name.lower() for v in contract.state_variables] + functions_name = [f.name for f in contract.functions] + state_variables = [v.name for v in contract.state_variables] if 'pause' in functions_name: txt += "\t\t Can be paused? : {}\n".format(yellow('Yes')) @@ -32,7 +31,7 @@ class PrinterHumanSummary(AbstractPrinter): else: txt += "\t\t Minting restriction? : {}\n".format(green('No Minting')) - if 'increaseApproval' in functions_name: + if 'increaseApproval' in functions_name or 'safeIncreaseAllowance' in functions_name: txt += "\t\t ERC20 race condition mitigation: {}\n".format(green('Yes')) else: txt += "\t\t ERC20 race condition mitigation: {}\n".format(red('No')) @@ -40,24 +39,27 @@ class PrinterHumanSummary(AbstractPrinter): return txt def get_detectors_result(self): - detectors = Detectors() # disable detectors logger logger = logging.getLogger('Detectors') logger.setLevel(logging.ERROR) - checks_low = detectors.low - checks_medium = detectors.medium - checks_high = detectors.high + checks_informational = self.slither.detectors_informational + checks_low = self.slither.detectors_low + checks_medium = self.slither.detectors_medium + checks_high = self.slither.detectors_high - issues_low = [detectors.run_detector(self.slither, c) for c in checks_low] + issues_informational = [c.detect() for c in checks_informational] + issues_informational = [item for sublist in issues_informational for item in sublist] + issues_low = [c.detect() for c in checks_low] issues_low = [c for c in issues_low if c] - issues_medium = (detectors.run_detector(self.slither, c) for c in checks_medium) + issues_medium = (c.detect() for c in checks_medium) issues_medium = [c for c in issues_medium if c] - issues_high = [detectors.run_detector(self.slither, c) for c in checks_high] + issues_high = [c.detect() for c in checks_high] issues_high = [c for c in issues_high if c] - txt = "Number of low issues: {}\n".format(green(len(issues_low))) + txt = "Number of informational issues: {}\n".format(green(len(issues_informational))) + txt += "Number of low issues: {}\n".format(green(len(issues_low))) txt += "Number of medium issues: {}\n".format(yellow(len(issues_medium))) txt += "Number of high issues: {}\n".format(red(len(issues_high))) @@ -95,6 +97,7 @@ class PrinterHumanSummary(AbstractPrinter): txt += "\nContract {}\n".format(contract.name) txt += self.is_complex_code(contract) is_erc20 = contract.is_erc20() + txt += '\tNumber of functions:{}'.format(len(contract.functions)) txt += "\tIs ERC20 token: {}\n".format(contract.is_erc20()) if is_erc20: txt += self.get_summary_erc20(contract) diff --git a/slither/printers/summary/printerSummary.py b/slither/printers/summary/printerSummary.py deleted file mode 100644 index d68dea9dc..000000000 --- a/slither/printers/summary/printerSummary.py +++ /dev/null @@ -1,49 +0,0 @@ -""" - Module printing summary of the contract -""" - -from prettytable import PrettyTable -from slither.printers.abstractPrinter import AbstractPrinter - -class PrinterSummary(AbstractPrinter): - - ARGUMENT = 'print-summary' - HELP = 'Print the summary of the contract' - - @staticmethod - def _convert(l): - if l: - n = 2 - l = [l[i:i + n] for i in range(0, len(l), n)] - l = [str(x) for x in l] - return "\n".join(l) - return str(l) - - def output(self, _filename): - """ - _filename is not used - Args: - _filename(string) - """ - - for c in self.contracts: - (name, var, inheritances, func_summaries, modif_summaries) = c.get_summary() - txt = "\nContract %s"%name - txt += '\nContract vars: '+str(var) - txt += '\nInheritances:: '+str(inheritances) - table = PrettyTable(["Function", "Visibility", "Modifiers", "Read", "Write", "Calls"]) - for (f_name, visi, modifiers, read, write, calls) in func_summaries: - read = self._convert(read) - write = self._convert(write) - calls = self._convert(calls) - table.add_row([f_name, visi, modifiers, read, write, calls]) - txt += "\n \n"+str(table) - table = PrettyTable(["Modifiers", "Visibility", "Read", "Write", "Calls"]) - for (f_name, visi, _, read, write, calls) in modif_summaries: - read = self._convert(read) - write = self._convert(write) - calls = self._convert(calls) - table.add_row([f_name, visi, read, write, calls]) - txt += "\n\n"+str(table) - txt += "\n" - self.info(txt) diff --git a/slither/printers/summary/slithir.py b/slither/printers/summary/slithir.py new file mode 100644 index 000000000..d48d3bf87 --- /dev/null +++ b/slither/printers/summary/slithir.py @@ -0,0 +1,42 @@ +""" + Module printing summary of the contract +""" + +from slither.printers.abstract_printer import AbstractPrinter +from slither.utils.colors import blue, green, magenta + +class PrinterSlithIR(AbstractPrinter): + + ARGUMENT = 'slithir' + HELP = 'Print the slithIR representation of the functions' + + def output(self, _filename): + """ + _filename is not used + Args: + _filename(string) + """ + + txt = "" + for contract in self.contracts: + print('Contract {}'.format(contract.name)) + for function in contract.functions: + if function.contract == contract: + print('\tFunction {}'.format(function.full_name)) + for node in function.nodes: + if node.expression: + print('\t\tExpression: {}'.format(node.expression)) + print('\t\tIRs:') + for ir in node.irs: + print('\t\t\t{}'.format(ir)) + for modifier in contract.modifiers: + if modifier.contract == contract: + print('\tModifier {}'.format(modifier.full_name)) + for node in modifier.nodes: + print(node) + if node.expression: + print('\t\tExpression: {}'.format(node.expression)) + print('\t\tIRs:') + for ir in node.irs: + print('\t\t\t{}'.format(ir)) + self.info(txt) diff --git a/slither/slither.py b/slither/slither.py index 5918795a9..d056faa1b 100644 --- a/slither/slither.py +++ b/slither/slither.py @@ -1,19 +1,108 @@ -import os -import sys import logging +import os import subprocess +import sys -from .solcParsing.slitherSolc import SlitherSolc +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.printers.abstract_printer import AbstractPrinter +from .solc_parsing.slitherSolc import SlitherSolc from .utils.colors import red logger = logging.getLogger("Slither") logging.basicConfig() +logger_detector = logging.getLogger("Detectors") +logger_printer = logging.getLogger("Printers") + class Slither(SlitherSolc): - def __init__(self, filename, solc='solc', disable_solc_warnings=False ,solc_arguments=''): + def __init__(self, contract, solc='solc', disable_solc_warnings=False, solc_arguments='', ast_format='--ast-json'): + self._detectors = [] + self._printers = [] + + # json text provided + if isinstance(contract, list): + super(Slither, self).__init__('') + for c in contract: + self._parse_contracts_from_loaded_json(c, c['absolutePath']) + # .json or .sol provided + else: + contracts_json = self._run_solc(contract, solc, disable_solc_warnings, solc_arguments, ast_format) + super(Slither, self).__init__(contract) + + for c in contracts_json: + self._parse_contracts_from_json(c) + + self._analyze_contracts() + + @property + def detectors(self): + return self._detectors + + @property + def detectors_high(self): + return [d for d in self.detectors if d.IMPACT == DetectorClassification.HIGH] + + @property + def detectors_medium(self): + return [d for d in self.detectors if d.IMPACT == DetectorClassification.MEDIUM] + + @property + def detectors_low(self): + return [d for d in self.detectors if d.IMPACT == DetectorClassification.LOW] + + @property + def detectors_informational(self): + return [d for d in self.detectors if d.IMPACT == DetectorClassification.INFORMATIONAL] + + def register_detector(self, detector_class): + """ + :param detector_class: Class inheriting from `AbstractDetector`. + """ + self._check_common_things('detector', detector_class, AbstractDetector, self._detectors) + instance = detector_class(self, logger_detector) + self._detectors.append(instance) + + def register_printer(self, printer_class): + """ + :param printer_class: Class inheriting from `AbstractPrinter`. + """ + self._check_common_things('printer', printer_class, AbstractPrinter, self._printers) + + instance = printer_class(self, logger_printer) + self._printers.append(instance) + + def run_detectors(self): + """ + :return: List of registered detectors results. + """ + + return [d.detect() for d in self._detectors] + + def run_printers(self): + """ + :return: List of registered printers outputs. + """ + + return [p.output(self.filename) for p in self._printers] + + def _check_common_things(self, thing_name, cls, base_cls, instances_list): + + if not issubclass(cls, base_cls) or cls is base_cls: + raise Exception( + "You can't register {!r} as a {}. You need to pass a class that inherits from {}".format( + cls, thing_name, base_cls.__name__ + ) + ) + + if any(isinstance(obj, cls) for obj in instances_list): + raise Exception( + "You can't register {!r} twice.".format(cls) + ) + + def _run_solc(self, filename, solc, disable_solc_warnings, solc_arguments, ast_format): if not os.path.isfile(filename): logger.error('{} does not exist (are you in the correct directory?)'.format(filename)) exit(-1) @@ -30,7 +119,7 @@ class Slither(SlitherSolc): logger.info('Empty AST file: %s', filename) sys.exit(-1) else: - cmd = [solc, filename, '--ast-json'] + cmd = [solc, filename, ast_format] if solc_arguments: # To parse, we first split the string on each '--' solc_args = solc_arguments.split('--') @@ -38,7 +127,7 @@ class Slither(SlitherSolc): # One solc option may have multiple argument sepparated with ' ' # For example: --allow-paths /tmp . # split() removes the delimiter, so we add it again - solc_args = [('--'+x).split(' ', 1) for x in solc_args if x] + solc_args = [('--' + x).split(' ', 1) for x in solc_args if x] # Flat the list of list solc_args = [item for sublist in solc_args for item in sublist] cmd += solc_args @@ -59,11 +148,4 @@ class Slither(SlitherSolc): stdout = stdout.split('\n=') - super(Slither, self).__init__(filename) - for d in stdout: - self.parse_contracts_from_json(d) - - self.analyze_contracts() - - - + return stdout diff --git a/slither/slithir/__init__.py b/slither/slithir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py new file mode 100644 index 000000000..9b4a98e9c --- /dev/null +++ b/slither/slithir/convert.py @@ -0,0 +1,676 @@ +import logging + +from slither.core.declarations import (Contract, Enum, Event, SolidityFunction, + Structure, SolidityVariableComposed, Function, SolidityVariable) +from slither.core.expressions import Identifier, Literal +from slither.core.solidity_types import ElementaryType, UserDefinedType, MappingType, ArrayType, FunctionType +from slither.core.variables.variable import Variable +from slither.slithir.operations import (Assignment, Binary, BinaryType, Call, + Condition, Delete, EventCall, + HighLevelCall, Index, InitArray, + InternalCall, InternalDynamicCall, LibraryCall, + LowLevelCall, Member, NewArray, + NewContract, NewElementaryType, + NewStructure, OperationWithLValue, + Push, Return, Send, SolidityCall, + Transfer, TypeConversion, Unary, + Unpack, Length, Balance) +from slither.slithir.tmp_operations.argument import Argument, ArgumentType +from slither.slithir.tmp_operations.tmp_call import TmpCall +from slither.slithir.tmp_operations.tmp_new_array import TmpNewArray +from slither.slithir.tmp_operations.tmp_new_contract import TmpNewContract +from slither.slithir.tmp_operations.tmp_new_elementary_type import \ + TmpNewElementaryType +from slither.slithir.tmp_operations.tmp_new_structure import TmpNewStructure +from slither.slithir.variables import (Constant, ReferenceVariable, + TemporaryVariable, TupleVariable) +from slither.visitors.slithir.expression_to_slithir import ExpressionToSlithIR + +logger = logging.getLogger('ConvertToIR') + +def is_value(ins): + if isinstance(ins, TmpCall): + if isinstance(ins.ori, Member): + if ins.ori.variable_right == 'value': + return True + return False + +def is_gas(ins): + if isinstance(ins, TmpCall): + if isinstance(ins.ori, Member): + if ins.ori.variable_right == 'gas': + return True + return False + +def integrate_value_gas(result): + was_changed = True + + calls = [] + + while was_changed: + # We loop until we do not find any call to value or gas + was_changed = False + + # Find all the assignments + assigments = {} + for i in result: + if isinstance(i, OperationWithLValue): + assigments[i.lvalue.name] = i + if isinstance(i, TmpCall): + if isinstance(i.called, Variable) and i.called.name in assigments: + ins_ori = assigments[i.called.name] + i.set_ori(ins_ori) + + to_remove = [] + variable_to_replace = {} + + # Replace call to value, gas to an argument of the real call + for idx in range(len(result)): + ins = result[idx] + # value can be shadowed, so we check that the prev ins + # is an Argument + if is_value(ins) and isinstance(result[idx-1], Argument): + was_changed = True + result[idx-1].set_type(ArgumentType.VALUE) + result[idx-1].call_id = ins.ori.variable_left.name + calls.append(ins.ori.variable_left) + to_remove.append(ins) + variable_to_replace[ins.lvalue.name] = ins.ori.variable_left + elif is_gas(ins) and isinstance(result[idx-1], Argument): + was_changed = True + result[idx-1].set_type(ArgumentType.GAS) + result[idx-1].call_id = ins.ori.variable_left.name + calls.append(ins.ori.variable_left) + to_remove.append(ins) + variable_to_replace[ins.lvalue.name] = ins.ori.variable_left + + # Remove the call to value/gas instruction + result = [i for i in result if not i in to_remove] + + # update the real call + for ins in result: + if isinstance(ins, TmpCall): + # use of while if there redirections + while ins.called.name in variable_to_replace: + was_changed = True + ins.call_id = variable_to_replace[ins.called.name].name + calls.append(ins.called) + ins.called = variable_to_replace[ins.called.name] + if isinstance(ins, Argument): + while ins.call_id in variable_to_replace: + was_changed = True + ins.call_id = variable_to_replace[ins.call_id].name + + calls = list(set([str(c) for c in calls])) + idx = 0 + calls_d = {} + for call in calls: + calls_d[str(call)] = idx + idx = idx+1 + + return result + +def propage_type_and_convert_call(result, node): + calls_value = {} + calls_gas = {} + + call_data = [] + + idx = 0 + # use of while len() as result can be modified during the iteration + while idx < len(result): + ins = result[idx] + + if isinstance(ins, TmpCall): + new_ins = extract_tmp_call(ins) + if new_ins: + ins = new_ins + result[idx] = ins + + if isinstance(ins, Argument): + if ins.get_type() in [ArgumentType.GAS]: + assert not ins.call_id in calls_gas + calls_gas[ins.call_id] = ins.argument + elif ins.get_type() in [ArgumentType.VALUE]: + assert not ins.call_id in calls_value + calls_value[ins.call_id] = ins.argument + else: + assert ins.get_type() == ArgumentType.CALL + call_data.append(ins.argument) + + if isinstance(ins, (HighLevelCall, NewContract)): + if ins.call_id in calls_value: + ins.call_value = calls_value[ins.call_id] + if ins.call_id in calls_gas: + ins.call_gas = calls_gas[ins.call_id] + + if isinstance(ins, (Call, NewContract, NewStructure)): + ins.arguments = call_data + call_data = [] + + if is_temporary(ins): + del result[idx] + continue + + new_ins = propagate_types(ins, node) + if new_ins: + if isinstance(new_ins, (list,)): + assert len(new_ins) == 2 + result.insert(idx, new_ins[0]) + result.insert(idx+1, new_ins[1]) + idx = idx + 1 + else: + result[idx] = new_ins + idx = idx +1 + return result + +def convert_to_low_level(ir): + """ + Convert to a transfer/send/or low level call + The funciton assume to receive a correct IR + The checks must be done by the caller + + Additionally convert abi... to solidityfunction + """ + if ir.function_name == 'transfer': + assert len(ir.arguments) == 1 + ir = Transfer(ir.destination, ir.arguments[0]) + return ir + elif ir.function_name == 'send': + assert len(ir.arguments) == 1 + ir = Send(ir.destination, ir.arguments[0], ir.lvalue) + ir.lvalue.set_type(ElementaryType('bool')) + return ir + elif ir.destination.name == 'abi' and ir.function_name in ['encode', + 'encodePacked', + 'encodeWithSelector', + 'encodeWithSignature']: + + call = SolidityFunction('abi.{}()'.format(ir.function_name)) + new_ir = SolidityCall(call, ir.nbr_arguments, ir.lvalue, ir.type_call) + new_ir.arguments = ir.arguments + if isinstance(call.return_type, list) and len(call.return_type) == 1: + new_ir.lvalue.set_type(call.return_type[0]) + else: + new_ir.lvalue.set_type(call.return_type) + return new_ir + elif ir.function_name in ['call', 'delegatecall', 'callcode']: + new_ir = LowLevelCall(ir.destination, + ir.function_name, + ir.nbr_arguments, + ir.lvalue, + ir.type_call) + new_ir.call_gas = ir.call_gas + new_ir.call_value = ir.call_value + new_ir.arguments = ir.arguments + new_ir.lvalue.set_type(ElementaryType('bool')) + return new_ir + logger.error('Incorrect conversion to low level {}'.format(ir)) + exit(-1) + +def convert_to_push(ir, node): + """ + Convert a call to a PUSH operaiton + + The funciton assume to receive a correct IR + The checks must be done by the caller + + May necessitate to create an intermediate operation (InitArray) + As a result, the function return may return a list + """ + if isinstance(ir.arguments[0], list): + ret = [] + + val = TemporaryVariable(node) + operation = InitArray(ir.arguments[0], val) + ret.append(operation) + + ir = Push(ir.destination, val) + + length = Literal(len(operation.init_values)) + t = operation.init_values[0].type + ir.lvalue.set_type(ArrayType(t, length)) + + ret.append(ir) + return ret + + ir = Push(ir.destination, ir.arguments[0]) + return ir + +def look_for_library(contract, ir, node, using_for, t): + for destination in using_for[t]: + lib_contract = contract.slither.get_contract_from_name(str(destination)) + if lib_contract: + lib_call = LibraryCall(lib_contract, + ir.function_name, + ir.nbr_arguments, + ir.lvalue, + ir.type_call) + lib_call.call_gas = ir.call_gas + lib_call.arguments = [ir.destination] + ir.arguments + new_ir = convert_type_library_call(lib_call, lib_contract) + if new_ir: + return new_ir + return None + +def convert_to_library(ir, node, using_for): + contract = node.function.contract + t = ir.destination.type + + if t in using_for: + new_ir = look_for_library(contract, ir, node, using_for, t) + if new_ir: + return new_ir + + if '*' in using_for: + new_ir = look_for_library(contract, ir, node, using_for, '*') + if new_ir: + return new_ir + + return None + +def get_type(t): + """ + Convert a type to a str + If the instance is a Contract, return 'address' instead + """ + if isinstance(t, UserDefinedType): + if isinstance(t.type, Contract): + return 'address' + return str(t) + +def get_sig(ir): + sig = '{}({})' + name = ir.function_name + + args = [] + for arg in ir.arguments: + if isinstance(arg, (list,)): + type_arg = '{}[{}]'.format(get_type(arg[0].type), len(arg)) + elif isinstance(arg, Function): + type_arg = arg.signature_str + else: + type_arg = get_type(arg.type) + args.append(type_arg) + return sig.format(name, ','.join(args)) + +def convert_type_library_call(ir, lib_contract): + sig = get_sig(ir) + func = lib_contract.get_function_from_signature(sig) + if not func: + func = lib_contract.get_state_variable_from_name(ir.function_name) + # In case of multiple binding to the same type + if not func: + # specific lookup when the compiler does implicit conversion + # for example + # myFunc(uint) + # can be called with an uint8 + for function in lib_contract.functions: + if function.name == ir.function_name and len(function.parameters) == len(ir.arguments): + func = function + break + if not func: + return None + ir.function = func + if isinstance(func, Function): + t = func.return_type + # if its not a tuple, return a singleton + if t and len(t) == 1: + t = t[0] + else: + # otherwise its a variable (getter) + t = func.type + if t: + ir.lvalue.set_type(t) + else: + ir.lvalue = None + return ir + +def convert_type_of_high_level_call(ir, contract): + sig = get_sig(ir) + func = contract.get_function_from_signature(sig) + if not func: + func = contract.get_state_variable_from_name(ir.function_name) + if not func: + # specific lookup when the compiler does implicit conversion + # for example + # myFunc(uint) + # can be called with an uint8 + for function in contract.functions: + if function.name == ir.function_name and len(function.parameters) == len(ir.arguments): + func = function + break + # lowlelvel lookup needs to be done at last step + if not func and ir.function_name in ['call', + 'delegatecall', + 'codecall', + 'transfer', + 'send']: + return convert_to_low_level(ir) + if not func: + logger.error('Function not found {}'.format(sig)) + ir.function = func + if isinstance(func, Function): + return_type = func.return_type + # if its not a tuple; return a singleton + if return_type and len(return_type) == 1: + return_type = return_type[0] + else: + # otherwise its a variable (getter) + if isinstance(func.type, MappingType): + return_type = func.type.type_to + elif isinstance(func.type, ArrayType): + return_type = func.type.type + else: + return_type = func.type + if return_type: + ir.lvalue.set_type(return_type) + else: + ir.lvalue = None + + return None + +def propagate_types(ir, node): + # propagate the type + using_for = node.function.contract.using_for + if isinstance(ir, OperationWithLValue): + # Force assignment in case of missing previous correct type + if not ir.lvalue.type: + if isinstance(ir, Assignment): + ir.lvalue.set_type(ir.rvalue.type) + elif isinstance(ir, Binary): + if BinaryType.return_bool(ir.type): + ir.lvalue.set_type(ElementaryType('bool')) + else: + ir.lvalue.set_type(ir.variable_left.type) + elif isinstance(ir, Delete): + # nothing to propagate + pass + elif isinstance(ir, LibraryCall): + return convert_type_library_call(ir, ir.destination) + elif isinstance(ir, HighLevelCall): + t = ir.destination.type + + # Temporary operation (they are removed later) + if t is None: + return + + # convert library + if t in using_for or '*' in using_for: + new_ir = convert_to_library(ir, node, using_for) + if new_ir: + return new_ir + + if isinstance(t, UserDefinedType): + # UserdefinedType + t_type = t.type + if isinstance(t_type, Contract): + contract = node.slither.get_contract_from_name(t_type.name) + return convert_type_of_high_level_call(ir, contract) + + # Convert HighLevelCall to LowLevelCall + if isinstance(t, ElementaryType) and t.name == 'address': + if ir.destination.name == 'this': + return convert_type_of_high_level_call(ir, node.function.contract) + return convert_to_low_level(ir) + + # Convert push operations + # May need to insert a new operation + # Which leads to return a list of operation + if isinstance(t, ArrayType): + if ir.function_name == 'push' and len(ir.arguments) == 1: + return convert_to_push(ir, node) + + elif isinstance(ir, Index): + if isinstance(ir.variable_left.type, MappingType): + ir.lvalue.set_type(ir.variable_left.type.type_to) + elif isinstance(ir.variable_left.type, ArrayType): + ir.lvalue.set_type(ir.variable_left.type.type) + + elif isinstance(ir, InitArray): + length = len(ir.init_values) + t = ir.init_values[0].type + ir.lvalue.set_type(ArrayType(t, length)) + elif isinstance(ir, InternalCall): + # if its not a tuple, return a singleton + return_type = ir.function.return_type + if return_type: + if len(return_type) == 1: + ir.lvalue.set_type(return_type[0]) + elif len(return_type)>1: + ir.lvalue.set_type(return_type) + else: + ir.lvalue = None + elif isinstance(ir, InternalDynamicCall): + # if its not a tuple, return a singleton + return_type = ir.function_type.return_type + if return_type: + if len(return_type) == 1: + ir.lvalue.set_type(return_type[0]) + else: + ir.lvalue.set_type(return_type) + else: + ir.lvalue = None + elif isinstance(ir, LowLevelCall): + # Call are not yet converted + # This should not happen + assert False + elif isinstance(ir, Member): + # TODO we should convert the reference to a temporary if the member is a length or a balance + if ir.variable_right == 'length' and isinstance(ir.variable_left.type, (ElementaryType, ArrayType)): + length = Length(ir.variable_left, ir.lvalue) + ir.lvalue.points_to = ir.variable_left + return ir + if ir.variable_right == 'balance' and isinstance(ir.variable_left.type, ElementaryType): + return Balance(ir.variable_left, ir.lvalue) + left = ir.variable_left + if isinstance(left, (Variable, SolidityVariable)): + t = ir.variable_left.type + elif isinstance(left, (Contract, Enum, Structure)): + t = UserDefinedType(left) + # can be None due to temporary operation + if t: + if isinstance(t, UserDefinedType): + # UserdefinedType + type_t = t.type + if isinstance(type_t, Enum): + ir.lvalue.set_type(t) + elif isinstance(type_t, Structure): + elems = type_t.elems + for elem in elems: + if elem == ir.variable_right: + ir.lvalue.set_type(elems[elem].type) + else: + assert isinstance(type_t, Contract) + elif isinstance(ir, NewArray): + ir.lvalue.set_type(ir.array_type) + elif isinstance(ir, NewContract): + contract = node.slither.get_contract_from_name(ir.contract_name) + ir.lvalue.set_type(UserDefinedType(contract)) + elif isinstance(ir, NewElementaryType): + ir.lvalue.set_type(ir.type) + elif isinstance(ir, NewStructure): + ir.lvalue.set_type(UserDefinedType(ir.structure)) + elif isinstance(ir, Push): + # No change required + pass + elif isinstance(ir, Send): + ir.lvalue.set_type(ElementaryType('bool')) + elif isinstance(ir, SolidityCall): + return_type = ir.function.return_type + if len(return_type) == 1: + ir.lvalue.set_type(return_type[0]) + elif len(return_type)>1: + ir.lvalue.set_type(return_type) + elif isinstance(ir, TypeConversion): + ir.lvalue.set_type(ir.type) + elif isinstance(ir, Unary): + ir.lvalue.set_type(ir.rvalue.type) + elif isinstance(ir, Unpack): + types = ir.tuple.type.type + idx = ir.index + t = types[idx] + ir.lvalue.set_type(t) + elif isinstance(ir, (Argument, TmpCall, TmpNewArray, TmpNewContract, TmpNewStructure, TmpNewElementaryType)): + # temporary operation; they will be removed + pass + else: + logger.error('Not handling {} during type propgation'.format(type(ir))) + exit(-1) + +def apply_ir_heuristics(irs, node): + """ + Apply a set of heuristic to improve slithIR + """ + + irs = integrate_value_gas(irs) + + irs = propage_type_and_convert_call(irs, node) +# irs = remove_temporary(irs) +# irs = replace_calls(irs) + irs = remove_unused(irs) + + find_references_origin(irs) + + #reset_variable_number(irs) + + return irs + +def find_references_origin(irs): + """ + Make lvalue of each Index, Member operation + points to the left variable + """ + for ir in irs: + if isinstance(ir, (Index, Member)): + ir.lvalue.points_to = ir.variable_left + +def is_temporary(ins): + return isinstance(ins, (Argument, + TmpNewElementaryType, + TmpNewContract, + TmpNewArray, + TmpNewStructure)) + + +def remove_temporary(result): + result = [ins for ins in result if not isinstance(ins, (Argument, + TmpNewElementaryType, + TmpNewContract, + TmpNewArray, + TmpNewStructure))] + + return result + +def remove_unused(result): + removed = True + + if not result: + return result + + # dont remove the last elem, as it may be used by RETURN + last_elem = result[-1] + + while removed: + removed = False + + to_keep = [] + to_remove = [] + + # keep variables that are read + # and reference that are written + for ins in result: + to_keep += [str(x) for x in ins.read] + if isinstance(ins, OperationWithLValue) and not isinstance(ins, (Index, Member)): + if isinstance(ins.lvalue, ReferenceVariable): + to_keep += [str(ins.lvalue)] + + for ins in result: + if isinstance(ins, Member): + if not ins.lvalue.name in to_keep and ins != last_elem: + to_remove.append(ins) + removed = True + + result = [i for i in result if not i in to_remove] + return result + + + +def extract_tmp_call(ins): + assert isinstance(ins, TmpCall) + if isinstance(ins.ori, Member): + if isinstance(ins.ori.variable_left, Contract): + libcall = LibraryCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call) + libcall.call_id = ins.call_id + return libcall + msgcall = HighLevelCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call) + msgcall.call_id = ins.call_id + return msgcall + + if isinstance(ins.ori, TmpCall): + r = extract_tmp_call(ins.ori) + return r + if isinstance(ins.called, SolidityVariableComposed): + if str(ins.called) == 'block.blockhash': + ins.called = SolidityFunction('blockhash(uint256)') + elif str(ins.called) == 'this.balance': + return SolidityCall(SolidityFunction('this.balance()'), ins.nbr_arguments, ins.lvalue, ins.type_call) + + if isinstance(ins.called, SolidityFunction): + return SolidityCall(ins.called, ins.nbr_arguments, ins.lvalue, ins.type_call) + + if isinstance(ins.ori, TmpNewElementaryType): + return NewElementaryType(ins.ori.type, ins.lvalue) + + if isinstance(ins.ori, TmpNewContract): + op = NewContract(Constant(ins.ori.contract_name), ins.lvalue) + op.call_id = ins.call_id + return op + + if isinstance(ins.ori, TmpNewArray): + return NewArray(ins.ori.depth, ins.ori.array_type, ins.lvalue) + + if isinstance(ins.called, Structure): + op = NewStructure(ins.called, ins.lvalue) + op.call_id = ins.call_id + return op + + if isinstance(ins.called, Event): + return EventCall(ins.called.name) + + if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType): + return InternalDynamicCall(ins.lvalue, ins.called, ins.called.type) + + raise Exception('Not extracted {} {}'.format(type(ins.called), ins)) + +def convert_expression(expression, node): + # handle standlone expression + # such as return true; + from slither.core.cfg.node import NodeType + if isinstance(expression, Literal) and node.type == NodeType.RETURN: + result = [Return(Constant(expression.value))] + return result + if isinstance(expression, Identifier) and node.type == NodeType.RETURN: + result = [Return(expression.value)] + return result + if isinstance(expression, Literal) and node.type in [NodeType.IF, NodeType.IFLOOP]: + result = [Condition(Constant(expression.value))] + return result + if isinstance(expression, Identifier) and node.type in [NodeType.IF, NodeType.IFLOOP]: + result = [Condition(expression.value)] + return result + visitor = ExpressionToSlithIR(expression, node) + result = visitor.result() + + result = apply_ir_heuristics(result, node) + + if result: + if node.type in [NodeType.IF, NodeType.IFLOOP]: + assert isinstance(result[-1], (OperationWithLValue)) + result.append(Condition(result[-1].lvalue)) + elif node.type == NodeType.RETURN: + # May return None + if isinstance(result[-1], (OperationWithLValue)): + result.append(Return(result[-1].lvalue)) + + return result diff --git a/slither/slithir/operations/__init__.py b/slither/slithir/operations/__init__.py new file mode 100644 index 000000000..ed59260e2 --- /dev/null +++ b/slither/slithir/operations/__init__.py @@ -0,0 +1,30 @@ +from .assignment import Assignment +from .binary import Binary, BinaryType +from .call import Call +from .condition import Condition +from .delete import Delete +from .event_call import EventCall +from .high_level_call import HighLevelCall +from .index import Index +from .init_array import InitArray +from .internal_call import InternalCall +from .internal_dynamic_call import InternalDynamicCall +from .library_call import LibraryCall +from .low_level_call import LowLevelCall +from .lvalue import OperationWithLValue +from .member import Member +from .new_array import NewArray +from .new_elementary_type import NewElementaryType +from .new_contract import NewContract +from .new_structure import NewStructure +from .operation import Operation +from .push import Push +from .return_operation import Return +from .send import Send +from .solidity_call import SolidityCall +from .transfer import Transfer +from .type_conversion import TypeConversion +from .unary import Unary, UnaryType +from .unpack import Unpack +from .length import Length +from .balance import Balance diff --git a/slither/slithir/operations/assignment.py b/slither/slithir/operations/assignment.py new file mode 100644 index 000000000..3ea100f74 --- /dev/null +++ b/slither/slithir/operations/assignment.py @@ -0,0 +1,39 @@ +import logging + +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.variables.variable import Variable +from slither.slithir.variables import TupleVariable +from slither.core.declarations.function import Function +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue + +logger = logging.getLogger("AssignmentOperationIR") + +class Assignment(OperationWithLValue): + + def __init__(self, left_variable, right_variable, variable_return_type): + assert is_valid_lvalue(left_variable) + assert is_valid_rvalue(right_variable) or isinstance(right_variable, (Function, TupleVariable)) + super(Assignment, self).__init__() + self._variables = [left_variable, right_variable] + self._lvalue = left_variable + self._rvalue = right_variable + self._variable_return_type = variable_return_type + + @property + def variables(self): + return list(self._variables) + + @property + def read(self): + return [self.rvalue] + + @property + def variable_return_type(self): + return self._variable_return_type + + @property + def rvalue(self): + return self._rvalue + + def __str__(self): + return '{}({}) := {}({})'.format(self.lvalue, self.lvalue.type, self.rvalue, self.rvalue.type) diff --git a/slither/slithir/operations/balance.py b/slither/slithir/operations/balance.py new file mode 100644 index 000000000..3dd6560fd --- /dev/null +++ b/slither/slithir/operations/balance.py @@ -0,0 +1,26 @@ +import logging +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.declarations import Function +from slither.core.variables.variable import Variable +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue +from slither.core.solidity_types.elementary_type import ElementaryType + +class Balance(OperationWithLValue): + + def __init__(self, value, lvalue): + assert is_valid_rvalue(value) + assert is_valid_lvalue(lvalue) + self._value = value + self._lvalue = lvalue + lvalue.set_type(ElementaryType('uint256')) + + @property + def read(self): + return [self._value] + + @property + def value(self): + return self._value + + def __str__(self): + return "{} -> BALANCE {}".format(self.lvalue, self.value) diff --git a/slither/slithir/operations/binary.py b/slither/slithir/operations/binary.py new file mode 100644 index 000000000..733b6596b --- /dev/null +++ b/slither/slithir/operations/binary.py @@ -0,0 +1,172 @@ +import logging +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.variables.variable import Variable +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue +from slither.core.solidity_types import ElementaryType + +logger = logging.getLogger("BinaryOperationIR") + +class BinaryType(object): + POWER = 0 # ** + MULTIPLICATION = 1 # * + DIVISION = 2 # / + MODULO = 3 # % + ADDITION = 4 # + + SUBTRACTION = 5 # - + LEFT_SHIFT = 6 # << + RIGHT_SHIT = 7 # >> + AND = 8 # & + CARET = 9 # ^ + OR = 10 # | + LESS = 11 # < + GREATER = 12 # > + LESS_EQUAL = 13 # <= + GREATER_EQUAL = 14 # >= + EQUAL = 15 # == + NOT_EQUAL = 16 # != + ANDAND = 17 # && + OROR = 18 # || + + @staticmethod + def return_bool(operation_type): + return operation_type in [BinaryType.OROR, + BinaryType.ANDAND, + BinaryType.LESS, + BinaryType.GREATER, + BinaryType.LESS_EQUAL, + BinaryType.GREATER_EQUAL, + BinaryType.EQUAL, + BinaryType.NOT_EQUAL] + + @staticmethod + def get_type(operation_type): + if operation_type == '**': + return BinaryType.POWER + if operation_type == '*': + return BinaryType.MULTIPLICATION + if operation_type == '/': + return BinaryType.DIVISION + if operation_type == '%': + return BinaryType.MODULO + if operation_type == '+': + return BinaryType.ADDITION + if operation_type == '-': + return BinaryType.SUBTRACTION + if operation_type == '<<': + return BinaryType.LEFT_SHIFT + if operation_type == '>>': + return BinaryType.RIGHT_SHIT + if operation_type == '&': + return BinaryType.AND + if operation_type == '^': + return BinaryType.CARET + if operation_type == '|': + return BinaryType.OR + if operation_type == '<': + return BinaryType.LESS + if operation_type == '>': + return BinaryType.GREATER + if operation_type == '<=': + return BinaryType.LESS_EQUAL + if operation_type == '>=': + return BinaryType.GREATER_EQUAL + if operation_type == '==': + return BinaryType.EQUAL + if operation_type == '!=': + return BinaryType.NOT_EQUAL + if operation_type == '&&': + return BinaryType.ANDAND + if operation_type == '||': + return BinaryType.OROR + + logger.error('get_type: Unknown operation type {})'.format(operation_type)) + exit(-1) + + @staticmethod + def str(operation_type): + if operation_type == BinaryType.POWER: + return '**' + if operation_type == BinaryType.MULTIPLICATION: + return '*' + if operation_type == BinaryType.DIVISION: + return '/' + if operation_type == BinaryType.MODULO: + return '%' + if operation_type == BinaryType.ADDITION: + return '+' + if operation_type == BinaryType.SUBTRACTION: + return '-' + if operation_type == BinaryType.LEFT_SHIFT: + return '<<' + if operation_type == BinaryType.RIGHT_SHIT: + return '>>' + if operation_type == BinaryType.AND: + return '&' + if operation_type == BinaryType.CARET: + return '^' + if operation_type == BinaryType.OR: + return '|' + if operation_type == BinaryType.LESS: + return '<' + if operation_type == BinaryType.GREATER: + return '>' + if operation_type == BinaryType.LESS_EQUAL: + return '<=' + if operation_type == BinaryType.GREATER_EQUAL: + return '>=' + if operation_type == BinaryType.EQUAL: + return '==' + if operation_type == BinaryType.NOT_EQUAL: + return '!=' + if operation_type == BinaryType.ANDAND: + return '&&' + if operation_type == BinaryType.OROR: + return '||' + logger.error('str: Unknown operation type {})'.format(operation_type)) + exit(-1) + +class Binary(OperationWithLValue): + + def __init__(self, result, left_variable, right_variable, operation_type): + assert is_valid_rvalue(left_variable) + assert is_valid_rvalue(right_variable) + assert is_valid_lvalue(result) + super(Binary, self).__init__() + self._variables = [left_variable, right_variable] + self._type = operation_type + self._lvalue = result + if BinaryType.return_bool(operation_type): + result.set_type(ElementaryType('bool')) + else: + result.set_type(left_variable.type) + + @property + def read(self): + return [self.variable_left, self.variable_right] + + @property + def get_variable(self): + return self._variables + + @property + def variable_left(self): + return self._variables[0] + + @property + def variable_right(self): + return self._variables[1] + + @property + def type(self): + return self._type + + @property + def type_str(self): + return BinaryType.str(self._type) + + def __str__(self): + return '{}({}) = {} {} {}'.format(str(self.lvalue), + self.lvalue.type, + self.variable_left, + self.type_str, + self.variable_right) diff --git a/slither/slithir/operations/call.py b/slither/slithir/operations/call.py new file mode 100644 index 000000000..25d929c92 --- /dev/null +++ b/slither/slithir/operations/call.py @@ -0,0 +1,17 @@ + +from slither.slithir.operations.operation import Operation + +class Call(Operation): + + def __init__(self): + super(Call, self).__init__() + self._arguments = [] + + @property + def arguments(self): + return self._arguments + + @arguments.setter + def arguments(self, v): + self._arguments = v + diff --git a/slither/slithir/operations/condition.py b/slither/slithir/operations/condition.py new file mode 100644 index 000000000..0de1467e5 --- /dev/null +++ b/slither/slithir/operations/condition.py @@ -0,0 +1,23 @@ +from slither.slithir.operations.operation import Operation + +from slither.slithir.utils.utils import is_valid_rvalue +class Condition(Operation): + """ + Condition + Only present as last operation in conditional node + """ + def __init__(self, value): + assert is_valid_rvalue(value) + super(Condition, self).__init__() + self._value = value + + @property + def read(self): + return [self.value] + + @property + def value(self): + return self._value + + def __str__(self): + return "CONDITION {}".format(self.value) diff --git a/slither/slithir/operations/delete.py b/slither/slithir/operations/delete.py new file mode 100644 index 000000000..02f10809f --- /dev/null +++ b/slither/slithir/operations/delete.py @@ -0,0 +1,27 @@ +import logging +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.variables.variable import Variable + +from slither.slithir.utils.utils import is_valid_lvalue +class Delete(OperationWithLValue): + """ + Delete has a lvalue, as it has for effect to change the value + of its operand + """ + + def __init__(self, variable): + assert is_valid_lvalue(variable) + super(Delete, self).__init__() + self._variable = variable + self._lvalue = variable + + @property + def read(self): + return [self.variable] + + @property + def variable(self): + return self._variable + + def __str__(self): + return "{} = delete {} ".format(self.lvalue, self.variable) diff --git a/slither/slithir/operations/event_call.py b/slither/slithir/operations/event_call.py new file mode 100644 index 000000000..0c7214d60 --- /dev/null +++ b/slither/slithir/operations/event_call.py @@ -0,0 +1,29 @@ + +from slither.slithir.operations.call import Call +from slither.core.variables.variable import Variable + +class EventCall(Call): + def __init__(self, name): + super(EventCall, self).__init__() + self._name = name + # todo add instance of the Event + + @property + def name(self): + return self._name + + @property + def read(self): + def unroll(l): + ret = [] + for x in l: + if not isinstance(x, list): + ret += [x] + else: + ret += unroll(x) + return ret + return unroll(self.arguments) + + def __str__(self): + args = [str(a) for a in self.arguments] + return 'Emit {}({})'.format(self.name, '.'.join(args)) diff --git a/slither/slithir/operations/high_level_call.py b/slither/slithir/operations/high_level_call.py new file mode 100644 index 000000000..fd54a1113 --- /dev/null +++ b/slither/slithir/operations/high_level_call.py @@ -0,0 +1,122 @@ +from slither.slithir.operations.call import Call +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.variables.variable import Variable +from slither.core.declarations.solidity_variables import SolidityVariable + +from slither.slithir.utils.utils import is_valid_lvalue +from slither.slithir.variables.constant import Constant + +class HighLevelCall(Call, OperationWithLValue): + """ + High level message call + """ + + def __init__(self, destination, function_name, nbr_arguments, result, type_call): + assert isinstance(function_name, Constant) + assert is_valid_lvalue(result) + self._check_destination(destination) + super(HighLevelCall, self).__init__() + self._destination = destination + self._function_name = function_name + self._nbr_arguments = nbr_arguments + self._type_call = type_call + self._lvalue = result + self._callid = None # only used if gas/value != 0 + self._function_instance = None + + self._call_value = None + self._call_gas = None + + # Development function, to be removed once the code is stable + # It is ovveride by LbraryCall + def _check_destination(self, destination): + assert isinstance(destination, (Variable, SolidityVariable)) + + @property + def call_id(self): + return self._callid + + @call_id.setter + def call_id(self, c): + self._callid = c + + @property + def call_value(self): + return self._call_value + + @call_value.setter + def call_value(self, v): + self._call_value = v + + @property + def call_gas(self): + return self._call_gas + + @call_gas.setter + def call_gas(self, v): + self._call_gas = v + + @property + def read(self): + # if array inside the parameters + def unroll(l): + ret = [] + for x in l: + if not isinstance(x, list): + ret += [x] + else: + ret += unroll(x) + return ret + all_read = [self.destination, self.call_gas, self.call_value] + unroll(self.arguments) + # remove None + return [x for x in all_read if x] + + @property + def destination(self): + return self._destination + + @property + def function_name(self): + return self._function_name + + @property + def function(self): + return self._function_instance + + @function.setter + def function(self, function): + self._function_instance = function + + @property + def nbr_arguments(self): + return self._nbr_arguments + + @property + def type_call(self): + return self._type_call + + def __str__(self): + value = '' + gas = '' + if self.call_value: + value = 'value:{}'.format(self.call_value) + if self.call_gas: + gas = 'gas:{}'.format(self.call_gas) + arguments = [] + if self.arguments: + arguments = self.arguments + + txt = '{}HIGH_LEVEL_CALL, dest:{}({}), function:{}, arguments:{} {} {}' + if not self.lvalue: + lvalue = '' + elif isinstance(self.lvalue.type, (list,)): + lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type)) + else: + lvalue = '{}({}) = '.format(self.lvalue, self.lvalue.type) + return txt.format(lvalue, + self.destination, + self.destination.type, + self.function_name, + [str(x) for x in arguments], + value, + gas) diff --git a/slither/slithir/operations/index.py b/slither/slithir/operations/index.py new file mode 100644 index 000000000..ac3140f46 --- /dev/null +++ b/slither/slithir/operations/index.py @@ -0,0 +1,37 @@ +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.variables.variable import Variable +from slither.core.declarations import SolidityVariableComposed +from slither.slithir.variables.reference import ReferenceVariable + +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue + + +class Index(OperationWithLValue): + + def __init__(self, result, left_variable, right_variable, index_type): + super(Index, self).__init__() + assert is_valid_lvalue(left_variable) or left_variable == SolidityVariableComposed('msg.data') + assert is_valid_rvalue(right_variable) + assert isinstance(result, ReferenceVariable) + self._variables = [left_variable, right_variable] + self._type = index_type + self._lvalue = result + + @property + def read(self): + return list(self.variables) + + @property + def variables(self): + return self._variables + + @property + def variable_left(self): + return self._variables[0] + + @property + def variable_right(self): + return self._variables[1] + + def __str__(self): + return "{}({}) -> {}[{}]".format(self.lvalue, self.lvalue.type, self.variable_left, self.variable_right) diff --git a/slither/slithir/operations/init_array.py b/slither/slithir/operations/init_array.py new file mode 100644 index 000000000..e85299dba --- /dev/null +++ b/slither/slithir/operations/init_array.py @@ -0,0 +1,48 @@ +import logging +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.variables.variable import Variable +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue + +class InitArray(OperationWithLValue): + + def __init__(self, init_values, lvalue): + # init_values can be an array of n dimension + # reduce was removed in py3 + def reduce(xs): + result = True + for i in xs: + result = result and i + return result + def check(elem): + if isinstance(elem, (list,)): + return reduce(elem) + return is_valid_rvalue(elem) + assert check(init_values) + self._init_values = init_values + self._lvalue = lvalue + + @property + def read(self): + # if array inside the init values + def unroll(l): + ret = [] + for x in l: + if not isinstance(x, list): + ret += [x] + else: + ret += unroll(x) + return ret + return unroll(self.init_values) + + @property + def init_values(self): + return list(self._init_values) + + def __str__(self): + + def convert(elem): + if isinstance(elem, (list,)): + return str([convert(x) for x in elem]) + return str(elem) + init_values = convert(self.init_values) + return "{}({}) = {}".format(self.lvalue, self.lvalue.type, init_values) diff --git a/slither/slithir/operations/internal_call.py b/slither/slithir/operations/internal_call.py new file mode 100644 index 000000000..606059ca6 --- /dev/null +++ b/slither/slithir/operations/internal_call.py @@ -0,0 +1,55 @@ +from slither.core.declarations.function import Function +from slither.slithir.operations.call import Call +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.variables.variable import Variable + + +class InternalCall(Call, OperationWithLValue): + + def __init__(self, function, nbr_arguments, result, type_call): + assert isinstance(function, Function) + super(InternalCall, self).__init__() + self._function = function + self._nbr_arguments = nbr_arguments + self._type_call = type_call + self._lvalue = result + + @property + def read(self): + # if array inside the parameters + def unroll(l): + ret = [] + for x in l: + if not isinstance(x, list): + ret += [x] + else: + ret += unroll(x) + return ret + return list(unroll(self.arguments)) + + @property + def function(self): + return self._function + + @property + def nbr_arguments(self): + return self._nbr_arguments + + @property + def type_call(self): + return self._type_call + + def __str__(self): + args = [str(a) for a in self.arguments] + if not self.lvalue: + lvalue = '' + elif isinstance(self.lvalue.type, (list,)): + lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type)) + else: + lvalue = '{}({}) = '.format(self.lvalue, self.lvalue.type) + txt = '{}INTERNAL_CALL, {}.{}({})' + return txt.format(lvalue, + self.function.contract.name, + self.function.full_name, + ','.join(args)) + diff --git a/slither/slithir/operations/internal_dynamic_call.py b/slither/slithir/operations/internal_dynamic_call.py new file mode 100644 index 000000000..5e40d3e57 --- /dev/null +++ b/slither/slithir/operations/internal_dynamic_call.py @@ -0,0 +1,54 @@ +from slither.core.declarations.function import Function +from slither.core.solidity_types import FunctionType +from slither.core.variables.variable import Variable +from slither.slithir.operations.call import Call +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.utils.utils import is_valid_lvalue + + +class InternalDynamicCall(Call, OperationWithLValue): + + def __init__(self, lvalue, function, function_type): + assert isinstance(function_type, FunctionType) + assert isinstance(function, Variable) + assert is_valid_lvalue(lvalue) + super(InternalDynamicCall, self).__init__() + self._function = function + self._function_type = function_type + self._lvalue = lvalue + + @property + def read(self): + # if array inside the parameters + def unroll(l): + ret = [] + for x in l: + if not isinstance(x, list): + ret += [x] + else: + ret += unroll(x) + return ret + + return unroll(self.arguments) + [self.function] + + @property + def function(self): + return self._function + + @property + def function_type(self): + return self._function_type + + def __str__(self): + args = [str(a) for a in self.arguments] + if not self.lvalue: + lvalue = '' + elif isinstance(self.lvalue.type, (list,)): + lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type)) + else: + lvalue = '{}({}) = '.format(self.lvalue, self.lvalue.type) + txt = '{}INTERNAL_DYNAMIC_CALL {}({})' + return txt.format(lvalue, + self.function.name, + ','.join(args)) + diff --git a/slither/slithir/operations/length.py b/slither/slithir/operations/length.py new file mode 100644 index 000000000..10b16ff80 --- /dev/null +++ b/slither/slithir/operations/length.py @@ -0,0 +1,26 @@ +import logging +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.declarations import Function +from slither.core.variables.variable import Variable +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue +from slither.core.solidity_types.elementary_type import ElementaryType + +class Length(OperationWithLValue): + + def __init__(self, value, lvalue): + assert is_valid_rvalue(value) + assert is_valid_lvalue(lvalue) + self._value = value + self._lvalue = lvalue + lvalue.set_type(ElementaryType('uint256')) + + @property + def read(self): + return [self._value] + + @property + def value(self): + return self._value + + def __str__(self): + return "{} -> LENGTH {}".format(self.lvalue, self.value) diff --git a/slither/slithir/operations/library_call.py b/slither/slithir/operations/library_call.py new file mode 100644 index 000000000..76abfbe8a --- /dev/null +++ b/slither/slithir/operations/library_call.py @@ -0,0 +1,33 @@ +from slither.slithir.operations.high_level_call import HighLevelCall +from slither.core.declarations.contract import Contract + +class LibraryCall(HighLevelCall): + """ + High level message call + """ + # Development function, to be removed once the code is stable + def _check_destination(self, destination): + assert isinstance(destination, (Contract)) + + def __str__(self): + gas = '' + if self.call_gas: + gas = 'gas:{}'.format(self.call_gas) + arguments = [] + if self.arguments: + arguments = self.arguments + if not self.lvalue: + lvalue = '' + elif isinstance(self.lvalue.type, (list,)): + lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type)) + else: + lvalue = '{}({}) = '.format(self.lvalue, self.lvalue.type) + txt = '{}LIBRARY_CALL, dest:{}, function:{}, arguments:{} {}' + return txt.format(lvalue, + self.destination, + self.function_name, + [str(x) for x in arguments], + gas) + + + diff --git a/slither/slithir/operations/low_level_call.py b/slither/slithir/operations/low_level_call.py new file mode 100644 index 000000000..4f6bf6f34 --- /dev/null +++ b/slither/slithir/operations/low_level_call.py @@ -0,0 +1,91 @@ +from slither.slithir.operations.call import Call +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.variables.variable import Variable +from slither.core.declarations.solidity_variables import SolidityVariable + +from slither.slithir.variables.constant import Constant + +class LowLevelCall(Call, OperationWithLValue): + """ + High level message call + """ + + def __init__(self, destination, function_name, nbr_arguments, result, type_call): + assert isinstance(destination, (Variable, SolidityVariable)) + assert isinstance(function_name, Constant) + super(LowLevelCall, self).__init__() + self._destination = destination + self._function_name = function_name + self._nbr_arguments = nbr_arguments + self._type_call = type_call + self._lvalue = result + self._callid = None # only used if gas/value != 0 + + self._call_value = None + self._call_gas = None + + @property + def call_id(self): + return self._callid + + @call_id.setter + def call_id(self, c): + self._callid = c + + @property + def call_value(self): + return self._call_value + + @call_value.setter + def call_value(self, v): + self._call_value = v + + @property + def call_gas(self): + return self._call_gas + + @call_gas.setter + def call_gas(self, v): + self._call_gas = v + + @property + def read(self): + all_read = [self.destination, self.call_gas, self.call_value] + self.arguments + # remove None + return [x for x in all_read if x] + + @property + def destination(self): + return self._destination + + @property + def function_name(self): + return self._function_name + + @property + def nbr_arguments(self): + return self._nbr_arguments + + @property + def type_call(self): + return self._type_call + + def __str__(self): + value = '' + gas = '' + if self.call_value: + value = 'value:{}'.format(self.call_value) + if self.call_gas: + gas = 'gas:{}'.format(self.call_gas) + arguments = [] + if self.arguments: + arguments = self.arguments + txt = '{}({}) = LOW_LEVEL_CALL, dest:{}, function:{}, arguments:{} {} {}' + return txt.format(self.lvalue, + self.lvalue.type, + self.destination, + self.function_name, + [str(x) for x in arguments], + value, + gas) + diff --git a/slither/slithir/operations/lvalue.py b/slither/slithir/operations/lvalue.py new file mode 100644 index 000000000..ce7976d34 --- /dev/null +++ b/slither/slithir/operations/lvalue.py @@ -0,0 +1,23 @@ +from slither.slithir.operations.operation import Operation + +class OperationWithLValue(Operation): + ''' + Operation with a lvalue + ''' + + def __init__(self): + super(OperationWithLValue, self).__init__() + + self._lvalue = None + + @property + def lvalue(self): + return self._lvalue + + @property + def used(self): + return self.read + [self.lvalue] + + @lvalue.setter + def lvalue(self, lvalue): + self._lvalue = lvalue diff --git a/slither/slithir/operations/member.py b/slither/slithir/operations/member.py new file mode 100644 index 000000000..c3d157df1 --- /dev/null +++ b/slither/slithir/operations/member.py @@ -0,0 +1,37 @@ +from slither.core.expressions.expression import Expression +from slither.core.expressions.expression_typed import ExpressionTyped +from slither.core.solidity_types.type import Type +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue +from slither.slithir.variables.reference import ReferenceVariable +from slither.slithir.variables.constant import Constant + +from slither.core.declarations.contract import Contract +from slither.core.declarations.enum import Enum + +class Member(OperationWithLValue): + + def __init__(self, variable_left, variable_right, result): + assert is_valid_rvalue(variable_left) or isinstance(variable_left, (Contract, Enum)) + assert isinstance(variable_right, Constant) + assert isinstance(result, ReferenceVariable) + super(Member, self).__init__() + self._variable_left = variable_left + self._variable_right = variable_right + self._lvalue = result + + @property + def read(self): + return [self.variable_left, self.variable_right] + + @property + def variable_left(self): + return self._variable_left + + @property + def variable_right(self): + return self._variable_right + + def __str__(self): + return '{}({}) -> {}.{}'.format(self.lvalue, self.lvalue.type, self.variable_left, self.variable_right) + diff --git a/slither/slithir/operations/new_array.py b/slither/slithir/operations/new_array.py new file mode 100644 index 000000000..213a5c487 --- /dev/null +++ b/slither/slithir/operations/new_array.py @@ -0,0 +1,38 @@ +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.operations.call import Call +from slither.core.solidity_types.type import Type + +class NewArray(Call, OperationWithLValue): + + def __init__(self, depth, array_type, lvalue): + super(NewArray, self).__init__() + assert isinstance(array_type, Type) + self._depth = depth + self._array_type = array_type + + self._lvalue = lvalue + + @property + def array_type(self): + return self._array_type + + @property + def read(self): + # if array inside the parameters + def unroll(l): + ret = [] + for x in l: + if not isinstance(x, list): + ret += [x] + else: + ret += unroll(x) + return ret + return unroll(self.arguments) + + @property + def depth(self): + return self._depth + + def __str__(self): + args = [str(a) for a in self.arguments] + return '{} = new {}{}({})'.format(self.lvalue, self.array_type, '[]'*self.depth, ','.join(args)) diff --git a/slither/slithir/operations/new_contract.py b/slither/slithir/operations/new_contract.py new file mode 100644 index 000000000..0646a70db --- /dev/null +++ b/slither/slithir/operations/new_contract.py @@ -0,0 +1,58 @@ +from slither.core.declarations.contract import Contract +from slither.slithir.operations import Call, OperationWithLValue +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue +from slither.slithir.variables.constant import Constant + + +class NewContract(Call, OperationWithLValue): + + def __init__(self, contract_name, lvalue): + assert isinstance(contract_name, Constant) + assert is_valid_lvalue(lvalue) + super(NewContract, self).__init__() + self._contract_name = contract_name + # todo create analyze to add the contract instance + self._lvalue = lvalue + self._callid = None # only used if gas/value != 0 + self._call_value = None + @property + def call_value(self): + return self._call_value + + @call_value.setter + def call_value(self, v): + self._call_value = v + + @property + def call_id(self): + return self._callid + + @call_id.setter + def call_id(self, c): + self._callid = c + + + @property + def contract_name(self): + return self._contract_name + + + @property + def read(self): + # if array inside the parameters + def unroll(l): + ret = [] + for x in l: + if not isinstance(x, list): + ret += [x] + else: + ret += unroll(x) + return ret + return unroll(self.arguments) + + def __str__(self): + value = '' + if self.call_value: + value = 'value:{}'.format(self.call_value) + args = [str(a) for a in self.arguments] + return '{} = new {}({}) {}'.format(self.lvalue, self.contract_name, ','.join(args), value) diff --git a/slither/slithir/operations/new_elementary_type.py b/slither/slithir/operations/new_elementary_type.py new file mode 100644 index 000000000..90714abb6 --- /dev/null +++ b/slither/slithir/operations/new_elementary_type.py @@ -0,0 +1,27 @@ +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.operations.call import Call +from slither.core.solidity_types.elementary_type import ElementaryType + +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue + +class NewElementaryType(Call, OperationWithLValue): + + def __init__(self, new_type, lvalue): + assert isinstance(new_type, ElementaryType) + assert is_valid_lvalue(lvalue) + super(NewElementaryType, self).__init__() + self._type = new_type + self._lvalue = lvalue + + @property + def type(self): + return self._type + + @property + def read(self): + return list(self.arguments) + + def __str__(self): + args = [str(a) for a in self.arguments] + + return '{} = new {}({})'.format(self.lvalue, self._type, ','.join(args)) diff --git a/slither/slithir/operations/new_structure.py b/slither/slithir/operations/new_structure.py new file mode 100644 index 000000000..302ac7ff6 --- /dev/null +++ b/slither/slithir/operations/new_structure.py @@ -0,0 +1,41 @@ +from slither.slithir.operations.call import Call +from slither.slithir.operations.lvalue import OperationWithLValue + +from slither.slithir.utils.utils import is_valid_lvalue + +from slither.core.declarations.structure import Structure + +class NewStructure(Call, OperationWithLValue): + + def __init__(self, structure, lvalue): + super(NewStructure, self).__init__() + assert isinstance(structure, Structure) + assert is_valid_lvalue(lvalue) + self._structure = structure + # todo create analyze to add the contract instance + self._lvalue = lvalue + + @property + def read(self): + # if array inside the parameters + def unroll(l): + ret = [] + for x in l: + if not isinstance(x, list): + ret += [x] + else: + ret += unroll(x) + return ret + return unroll(self.arguments) + + @property + def structure(self): + return self._structure + + @property + def structure_name(self): + return self.structure.name + + def __str__(self): + args = [str(a) for a in self.arguments] + return '{} = new {}({})'.format(self.lvalue, self.structure_name, ','.join(args)) diff --git a/slither/slithir/operations/operation.py b/slither/slithir/operations/operation.py new file mode 100644 index 000000000..c4c338c25 --- /dev/null +++ b/slither/slithir/operations/operation.py @@ -0,0 +1,29 @@ +import abc +from slither.core.context.context import Context + +class AbstractOperation(abc.ABC): + + @property + @abc.abstractmethod + def read(self): + """ + Return the list of variables READ + """ + pass + + @property + @abc.abstractmethod + def used(self): + """ + Return the list of variables used + """ + pass + +class Operation(Context, AbstractOperation): + + @property + def used(self): + """ + By default used is all the variables read + """ + return self.read diff --git a/slither/slithir/operations/push.py b/slither/slithir/operations/push.py new file mode 100644 index 000000000..77a97e65e --- /dev/null +++ b/slither/slithir/operations/push.py @@ -0,0 +1,28 @@ +import logging +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.declarations import Function +from slither.core.variables.variable import Variable +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue + +class Push(OperationWithLValue): + + def __init__(self, array, value): + assert is_valid_rvalue(value) or isinstance(value, Function) + assert is_valid_lvalue(array) + self._value = value + self._lvalue = array + + @property + def read(self): + return [self._value] + + @property + def array(self): + return self._lvalue + + @property + def value(self): + return self._value + + def __str__(self): + return "PUSH {} in {}".format(self.value, self.lvalue) diff --git a/slither/slithir/operations/return_operation.py b/slither/slithir/operations/return_operation.py new file mode 100644 index 000000000..225e40823 --- /dev/null +++ b/slither/slithir/operations/return_operation.py @@ -0,0 +1,27 @@ +from slither.slithir.operations.operation import Operation + +from slither.slithir.variables.tuple import TupleVariable +from slither.slithir.utils.utils import is_valid_rvalue +class Return(Operation): + """ + Return + Only present as last operation in RETURN node + """ + def __init__(self, value): + # Note: Can return None + # ex: return call() + # where call() dont return + assert is_valid_rvalue(value) or isinstance(value, TupleVariable) or value == None + super(Return, self).__init__() + self._value = value + + @property + def read(self): + return [self.value] + + @property + def value(self): + return self._value + + def __str__(self): + return "RETURN {}".format(self.value) diff --git a/slither/slithir/operations/send.py b/slither/slithir/operations/send.py new file mode 100644 index 000000000..201de989a --- /dev/null +++ b/slither/slithir/operations/send.py @@ -0,0 +1,37 @@ +from slither.slithir.operations.call import Call +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.variables.variable import Variable +from slither.core.declarations.solidity_variables import SolidityVariable + +from slither.slithir.utils.utils import is_valid_lvalue +from slither.slithir.variables.constant import Constant + +class Send(Call, OperationWithLValue): + + def __init__(self, destination, value, result): + assert is_valid_lvalue(result) + assert isinstance(destination, (Variable, SolidityVariable)) + super(Send, self).__init__() + self._destination = destination + self._lvalue = result + + self._call_value = value + + @property + def call_value(self): + return self._call_value + + @property + def read(self): + return [self.destination, self.call_value] + + @property + def destination(self): + return self._destination + + + def __str__(self): + value = 'value:{}'.format(self.call_value) + return str(self.lvalue) +' = SEND dest:{} {}'.format(self.destination, value) +# + diff --git a/slither/slithir/operations/solidity_call.py b/slither/slithir/operations/solidity_call.py new file mode 100644 index 000000000..070a70b46 --- /dev/null +++ b/slither/slithir/operations/solidity_call.py @@ -0,0 +1,38 @@ +from slither.core.declarations.solidity_variables import SolidityFunction +from slither.slithir.operations.call import Call +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.variables.variable import Variable + + +class SolidityCall(Call, OperationWithLValue): + + def __init__(self, function, nbr_arguments, result, type_call): + assert isinstance(function, SolidityFunction) + super(SolidityCall, self).__init__() + self._function = function + self._nbr_arguments = nbr_arguments + self._type_call = type_call + self._lvalue = result + + @property + def read(self): + return list(self.arguments) + + @property + def function(self): + return self._function + + @property + def nbr_arguments(self): + return self._nbr_arguments + + @property + def type_call(self): + return self._type_call + + def __str__(self): + args = [str(a) for a in self.arguments] + return str(self.lvalue) +' = SOLIDITY_CALL {}({})'.format(self.function.full_name, ','.join(args)) + # return str(self.lvalue) +' = INTERNALCALL {} (arg {})'.format(self.function, + # self.nbr_arguments) + diff --git a/slither/slithir/operations/transfer.py b/slither/slithir/operations/transfer.py new file mode 100644 index 000000000..b334d02ce --- /dev/null +++ b/slither/slithir/operations/transfer.py @@ -0,0 +1,31 @@ +from slither.slithir.operations.call import Call +from slither.core.variables.variable import Variable +from slither.core.declarations.solidity_variables import SolidityVariable + +class Transfer(Call): + + def __init__(self, destination, value): + assert isinstance(destination, (Variable, SolidityVariable)) + self._destination = destination + super(Transfer, self).__init__() + + self._call_value = value + + + @property + def call_value(self): + return self._call_value + + @property + def read(self): + return [self.destination, self.call_value] + + @property + def destination(self): + return self._destination + + def __str__(self): + value = 'value:{}'.format(self.call_value) + return 'Transfer dest:{} {}'.format(self.destination, value) + + diff --git a/slither/slithir/operations/type_conversion.py b/slither/slithir/operations/type_conversion.py new file mode 100644 index 000000000..661f38e1a --- /dev/null +++ b/slither/slithir/operations/type_conversion.py @@ -0,0 +1,33 @@ +from slither.slithir.operations.lvalue import OperationWithLValue + +from slither.core.variables.variable import Variable + +from slither.core.solidity_types.type import Type +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue + +class TypeConversion(OperationWithLValue): + + def __init__(self, result, variable, variable_type): + assert is_valid_rvalue(variable) + assert is_valid_lvalue(result) + assert isinstance(variable_type, Type) + + self._variable = variable + self._type = variable_type + self._lvalue = result + + + @property + def variable(self): + return self._variable + + @property + def type(self): + return self._type + + @property + def read(self): + return [self.variable] + + def __str__(self): + return str(self.lvalue) +' = CONVERT {} to {}'.format(self.variable, self.type) diff --git a/slither/slithir/operations/unary.py b/slither/slithir/operations/unary.py new file mode 100644 index 000000000..44bcdf0bc --- /dev/null +++ b/slither/slithir/operations/unary.py @@ -0,0 +1,56 @@ +import logging +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.variables.variable import Variable + +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue + +logger = logging.getLogger("BinaryOperationIR") + +class UnaryType: + BANG = 0 # ! + TILD = 1 # ~ + + @staticmethod + def get_type(operation_type, isprefix): + if isprefix: + if operation_type == '!': + return UnaryType.BANG + if operation_type == '~': + return UnaryType.TILD + logger.error('get_type: Unknown operation type {}'.format(operation_type)) + exit(-1) + + @staticmethod + def str(operation_type): + if operation_type == UnaryType.BANG: + return '!' + if operation_type == UnaryType.TILD: + return '~' + + logger.error('str: Unknown operation type {}'.format(operation_type)) + exit(-1) + +class Unary(OperationWithLValue): + + def __init__(self, result, variable, operation_type): + assert is_valid_rvalue(variable) + assert is_valid_lvalue(result) + super(Unary, self).__init__() + self._variable = variable + self._type = operation_type + self._lvalue = result + + @property + def read(self): + return [self._variable] + + @property + def rvalue(self): + return self._variable + + @property + def type_str(self): + return UnaryType.str(self._type) + + def __str__(self): + return "{} = {} {} ".format(self.lvalue, self.type_str, self.rvalue) diff --git a/slither/slithir/operations/unpack.py b/slither/slithir/operations/unpack.py new file mode 100644 index 000000000..5841aec59 --- /dev/null +++ b/slither/slithir/operations/unpack.py @@ -0,0 +1,34 @@ +import logging +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.variables.tuple import TupleVariable + +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue + +class Unpack(OperationWithLValue): + + def __init__(self, result, tuple_var, idx): + assert is_valid_lvalue(result) + assert isinstance(tuple_var, TupleVariable) + assert isinstance(idx, int) + super(Unpack, self).__init__() + self._tuple = tuple_var + self._idx = idx + self._lvalue = result + + @property + def read(self): + return [self.tuple] + + @property + def tuple(self): + return self._tuple + + @property + def index(self): + return self._idx + + def __str__(self): + return "{}({})= UNPACK {} index: {} ".format(self.lvalue, + self.lvalue.type, + self.tuple, + self.index) diff --git a/slither/slithir/tmp_operations/__init__.py b/slither/slithir/tmp_operations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slither/slithir/tmp_operations/argument.py b/slither/slithir/tmp_operations/argument.py new file mode 100644 index 000000000..4c04924c2 --- /dev/null +++ b/slither/slithir/tmp_operations/argument.py @@ -0,0 +1,47 @@ +from enum import Enum +from slither.slithir.operations.operation import Operation + +class ArgumentType(Enum): + CALL = 0 + VALUE = 1 + GAS = 2 + DATA = 3 + +class Argument(Operation): + + def __init__(self, argument): + super(Argument, self).__init__() + self._argument = argument + self._type = ArgumentType.CALL + self._callid = None + + + @property + def argument(self): + return self._argument + + @property + def call_id(self): + return self._callid + + @call_id.setter + def call_id(self, c): + self._callid = c + + @property + def read(self): + return [self.argument] + + def set_type(self, t): + assert isinstance(t, ArgumentType) + self._type = t + + def get_type(self): + return self._type + + def __str__(self): + call_id = 'none' + if self.call_id: + call_id = '(id ({}))'.format(self.call_id) + return 'ARG_{} {} {}'.format(self._type.name, str(self._argument), call_id) + diff --git a/slither/slithir/tmp_operations/tmp_call.py b/slither/slithir/tmp_operations/tmp_call.py new file mode 100644 index 000000000..0d7fda1e9 --- /dev/null +++ b/slither/slithir/tmp_operations/tmp_call.py @@ -0,0 +1,65 @@ +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.variables.variable import Variable +from slither.core.declarations.solidity_variables import SolidityVariableComposed, SolidityFunction +from slither.core.declarations.structure import Structure +from slither.core.declarations.event import Event + + +class TmpCall(OperationWithLValue): + + def __init__(self, called, nbr_arguments, result, type_call): + assert isinstance(called, (Variable, + SolidityVariableComposed, + SolidityFunction, + Structure, + Event)) + super(TmpCall, self).__init__() + self._called = called + self._nbr_arguments = nbr_arguments + self._type_call = type_call + self._lvalue = result + self._ori = None # + self._callid = None + + @property + def call_id(self): + return self._callid + + @property + def read(self): + return [self.called] + + @call_id.setter + def call_id(self, c): + self._callid = c + + @property + def called(self): + return self._called + + @property + def read(self): + return [self.called] + + @called.setter + def called(self, c): + self._called = c + + @property + def nbr_arguments(self): + return self._nbr_arguments + + @property + def type_call(self): + return self._type_call + + @property + def ori(self): + return self._ori + + def set_ori(self, ori): + self._ori = ori + + def __str__(self): + return str(self.lvalue) +' = TMPCALL{} '.format(self.nbr_arguments)+ str(self._called) + diff --git a/slither/slithir/tmp_operations/tmp_new_array.py b/slither/slithir/tmp_operations/tmp_new_array.py new file mode 100644 index 000000000..e4deb4fe7 --- /dev/null +++ b/slither/slithir/tmp_operations/tmp_new_array.py @@ -0,0 +1,27 @@ +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.solidity_types.type import Type + +class TmpNewArray(OperationWithLValue): + + def __init__(self, depth, array_type, lvalue): + super(TmpNewArray, self).__init__() + assert isinstance(array_type, Type) + self._depth = depth + self._array_type = array_type + self._lvalue = lvalue + + @property + def array_type(self): + return self._array_type + + @property + def read(self): + return [] + + @property + def depth(self): + return self._depth + + def __str__(self): + return '{} = new {}{}'.format(self.lvalue, self.array_type, '[]'*self._depth) + diff --git a/slither/slithir/tmp_operations/tmp_new_contract.py b/slither/slithir/tmp_operations/tmp_new_contract.py new file mode 100644 index 000000000..35521d068 --- /dev/null +++ b/slither/slithir/tmp_operations/tmp_new_contract.py @@ -0,0 +1,20 @@ +from slither.slithir.operations.lvalue import OperationWithLValue + +class TmpNewContract(OperationWithLValue): + + def __init__(self, contract_name, lvalue): + super(TmpNewContract, self).__init__() + self._contract_name = contract_name + self._lvalue = lvalue + + @property + def contract_name(self): + return self._contract_name + + @property + def read(self): + return [] + + def __str__(self): + return '{} = new {}'.format(self.lvalue, self.contract_name) + diff --git a/slither/slithir/tmp_operations/tmp_new_elementary_type.py b/slither/slithir/tmp_operations/tmp_new_elementary_type.py new file mode 100644 index 000000000..1fcb65334 --- /dev/null +++ b/slither/slithir/tmp_operations/tmp_new_elementary_type.py @@ -0,0 +1,21 @@ +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.solidity_types.elementary_type import ElementaryType + +class TmpNewElementaryType(OperationWithLValue): + + def __init__(self, new_type, lvalue): + assert isinstance(new_type, ElementaryType) + super(TmpNewElementaryType, self).__init__() + self._type = new_type + self._lvalue = lvalue + + @property + def read(self): + return [] + + @property + def type(self): + return self._type + + def __str__(self): + return '{} = new {}'.format(self.lvalue, self._type) diff --git a/slither/slithir/tmp_operations/tmp_new_structure.py b/slither/slithir/tmp_operations/tmp_new_structure.py new file mode 100644 index 000000000..90f11d115 --- /dev/null +++ b/slither/slithir/tmp_operations/tmp_new_structure.py @@ -0,0 +1,20 @@ +from slither.slithir.operations.lvalue import OperationWithLValue + +class TmpNewStructure(OperationWithLValue): + + def __init__(self, contract_name, lvalue): + super(TmpNewStructure, self).__init__() + self._contract_name = contract_name + self._lvalue = lvalue + + @property + def contract_name(self): + return self._contract_name + + @property + def read(self): + return [] + + def __str__(self): + return '{} = tmpnew {}'.format(self.lvalue, self.contract_name) + diff --git a/slither/slithir/utils/__init__.py b/slither/slithir/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slither/slithir/utils/utils.py b/slither/slithir/utils/utils.py new file mode 100644 index 000000000..74ef1b2cf --- /dev/null +++ b/slither/slithir/utils/utils.py @@ -0,0 +1,16 @@ +from slither.core.variables.local_variable import LocalVariable +from slither.core.variables.state_variable import StateVariable + +from slither.core.declarations.solidity_variables import SolidityVariable + +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.reference import ReferenceVariable +from slither.slithir.variables.tuple import TupleVariable + +def is_valid_rvalue(v): + return isinstance(v, (StateVariable, LocalVariable, TemporaryVariable, Constant, SolidityVariable, ReferenceVariable)) + +def is_valid_lvalue(v): + return isinstance(v, (StateVariable, LocalVariable, TemporaryVariable, ReferenceVariable, TupleVariable)) + diff --git a/slither/slithir/utils/variable_number.py b/slither/slithir/utils/variable_number.py new file mode 100644 index 000000000..d1ebc5b65 --- /dev/null +++ b/slither/slithir/utils/variable_number.py @@ -0,0 +1,23 @@ +from slither.slithir.variables import (Constant, ReferenceVariable, + TemporaryVariable, TupleVariable) +from slither.slithir.operations import OperationWithLValue + +def transform_slithir_vars_to_ssa(function): + """ + Transform slithIR vars to SSA + """ + variables = [] + for node in function.nodes: + for ir in node.irs: + if isinstance(ir, OperationWithLValue) and not ir.lvalue in variables: + variables += [ir.lvalue] + + tmp_variables = [v for v in variables if isinstance(v, TemporaryVariable)] + for idx in range(len(tmp_variables)): + tmp_variables[idx].index = idx + ref_variables = [v for v in variables if isinstance(v, ReferenceVariable)] + for idx in range(len(ref_variables)): + ref_variables[idx].index = idx + tuple_variables = [v for v in variables if isinstance(v, TupleVariable)] + for idx in range(len(tuple_variables)): + tuple_variables[idx].index = idx diff --git a/slither/slithir/variables/__init__.py b/slither/slithir/variables/__init__.py new file mode 100644 index 000000000..ad005c9be --- /dev/null +++ b/slither/slithir/variables/__init__.py @@ -0,0 +1,4 @@ +from .constant import Constant +from .reference import ReferenceVariable +from .temporary import TemporaryVariable +from .tuple import TupleVariable diff --git a/slither/slithir/variables/constant.py b/slither/slithir/variables/constant.py new file mode 100644 index 000000000..f16d686a6 --- /dev/null +++ b/slither/slithir/variables/constant.py @@ -0,0 +1,32 @@ +from slither.core.variables.variable import Variable +from slither.core.solidity_types.elementary_type import ElementaryType + +class Constant(Variable): + + def __init__(self, val): + super(Constant, self).__init__() + assert isinstance(val, str) + if val.isdigit(): + self._type = ElementaryType('uint256') + self._val = int(val) + else: + self._type = ElementaryType('string') + self._val = val + + @property + def value(self): + ''' + Return the value. + If the expression was an hexadecimal delcared as hex'...' + return a str + Returns: + (str, int) + ''' + return self._val + + def __str__(self): + return str(self.value) + + + def __eq__(self, other): + return self.value == other diff --git a/slither/slithir/variables/reference.py b/slither/slithir/variables/reference.py new file mode 100644 index 000000000..defccaaea --- /dev/null +++ b/slither/slithir/variables/reference.py @@ -0,0 +1,49 @@ + +from slither.core.children.child_node import ChildNode +from slither.core.declarations import Contract, Enum, SolidityVariable +from slither.core.variables.variable import Variable + + +class ReferenceVariable(ChildNode, Variable): + + COUNTER = 0 + + def __init__(self, node): + super(ReferenceVariable, self).__init__() + self._index = ReferenceVariable.COUNTER + ReferenceVariable.COUNTER += 1 + self._points_to = None + self._node = node + + @property + def index(self): + return self._index + + @index.setter + def index(self, idx): + self._index = idx + + @property + def points_to(self): + """ + Return the variable pointer by the reference + It is the left member of a Index or Member operator + """ + return self._points_to + + @points_to.setter + def points_to(self, points_to): + # Can only be a rvalue of + # Member or Index operator + from slither.slithir.utils.utils import is_valid_lvalue + assert is_valid_lvalue(points_to) \ + or isinstance(points_to, (SolidityVariable, Contract, Enum)) + + self._points_to = points_to + + @property + def name(self): + return 'REF_{}'.format(self.index) + + def __str__(self): + return self.name diff --git a/slither/slithir/variables/temporary.py b/slither/slithir/variables/temporary.py new file mode 100644 index 000000000..a736a3dd1 --- /dev/null +++ b/slither/slithir/variables/temporary.py @@ -0,0 +1,28 @@ + +from slither.core.variables.variable import Variable +from slither.core.children.child_node import ChildNode + +class TemporaryVariable(ChildNode, Variable): + + COUNTER = 0 + + def __init__(self, node): + super(TemporaryVariable, self).__init__() + self._index = TemporaryVariable.COUNTER + TemporaryVariable.COUNTER += 1 + self._node = node + + @property + def index(self): + return self._index + + @index.setter + def index(self, idx): + self._index = idx + + @property + def name(self): + return 'TMP_{}'.format(self.index) + + def __str__(self): + return self.name diff --git a/slither/slithir/variables/tuple.py b/slither/slithir/variables/tuple.py new file mode 100644 index 000000000..1b9e73353 --- /dev/null +++ b/slither/slithir/variables/tuple.py @@ -0,0 +1,27 @@ + +from slither.core.variables.variable import Variable + +from slither.core.solidity_types.type import Type +class TupleVariable(Variable): + + COUNTER = 0 + + def __init__(self): + super(TupleVariable, self).__init__() + self._index = TupleVariable.COUNTER + TupleVariable.COUNTER += 1 + + @property + def index(self): + return self._index + + @index.setter + def index(self, idx): + self._index = idx + + @property + def name(self): + return 'TUPLE_{}'.format(self.index) + + def __str__(self): + return self.name diff --git a/slither/solcParsing/cfg/nodeSolc.py b/slither/solcParsing/cfg/nodeSolc.py deleted file mode 100644 index 1e953085b..000000000 --- a/slither/solcParsing/cfg/nodeSolc.py +++ /dev/null @@ -1,51 +0,0 @@ -from slither.core.cfg.node import Node -from slither.solcParsing.expressions.expressionParsing import parse_expression -from slither.visitors.expression.readVar import ReadVar -from slither.visitors.expression.writeVar import WriteVar -from slither.visitors.expression.findCalls import FindCalls - -from slither.visitors.expression.exportValues import ExportValues -from slither.core.declarations.solidityVariables import SolidityVariable, SolidityFunction -from slither.core.declarations.function import Function - - -from slither.core.variables.stateVariable import StateVariable - -class NodeSolc(Node): - - def __init__(self, nodeType, nodeId): - super(NodeSolc, self).__init__(nodeType, nodeId) - self._unparsed_expression = None - - def add_unparsed_expression(self, expression): - assert self._unparsed_expression is None - self._unparsed_expression = expression - - def analyze_expressions(self, caller_context): - if self._unparsed_expression: - expression = parse_expression(self._unparsed_expression, caller_context) - - pp = ReadVar(expression) - self._expression_vars_read = pp.result() - vars_read = [ExportValues(v).result() for v in self._expression_vars_read] - self._vars_read = [item for sublist in vars_read for item in sublist] - self._state_vars_read = [x for x in self.variables_read if\ - isinstance(x, (StateVariable))] - self._solidity_vars_read = [x for x in self.variables_read if\ - isinstance(x, (SolidityVariable))] - - pp = WriteVar(expression) - self._expression_vars_written = pp.result() - vars_written = [ExportValues(v).result() for v in self._expression_vars_written] - self._vars_written = [item for sublist in vars_written for item in sublist] - self._state_vars_written = [x for x in self.variables_written if\ - isinstance(x, StateVariable)] - - pp = FindCalls(expression) - self._expression_calls = pp.result() - calls = [ExportValues(c).result() for c in self.calls_as_expression] - calls = [item for sublist in calls for item in sublist] - self._calls = [c for c in calls if isinstance(c, (Function, SolidityFunction))] - - self._unparsed_expression = None - self._expression = expression diff --git a/slither/solcParsing/declarations/eventSolc.py b/slither/solcParsing/declarations/eventSolc.py deleted file mode 100644 index 647e673d5..000000000 --- a/slither/solcParsing/declarations/eventSolc.py +++ /dev/null @@ -1,31 +0,0 @@ -""" - Event module -""" -from slither.solcParsing.variables.eventVariableSolc import EventVariableSolc -from slither.core.declarations.event import Event - -class EventSolc(Event): - """ - Event class - """ - - def __init__(self, event): - super(EventSolc, self).__init__() - self._name = event['attributes']['name'] - self._elems = [] - - elems = event['children'][0] - assert elems['name'] == 'ParameterList' - if 'children' in elems: - self._elemsNotParsed = elems['children'] - else: - self._elemsNotParsed = [] - - def analyze(self, contract): - for elem_to_parse in self._elemsNotParsed: - elem = EventVariableSolc(elem_to_parse) - elem.analyze(contract) - self._elems.append(elem) - - self._elemsNotParsed = [] - diff --git a/slither/solcParsing/declarations/functionSolc.py b/slither/solcParsing/declarations/functionSolc.py deleted file mode 100644 index 1c1a3226b..000000000 --- a/slither/solcParsing/declarations/functionSolc.py +++ /dev/null @@ -1,539 +0,0 @@ -""" - Event module -""" -import logging -from slither.core.declarations.function import Function -from slither.solcParsing.cfg.nodeSolc import NodeSolc -from slither.core.cfg.nodeType import NodeType -from slither.core.cfg.node import link_nodes - -from slither.solcParsing.variables.localVariableSolc import LocalVariableSolc -from slither.solcParsing.variables.localVariableInitFromTupleSolc import LocalVariableInitFromTupleSolc -from slither.solcParsing.variables.variableDeclarationSolc import MultipleVariablesDeclaration - -from slither.solcParsing.expressions.expressionParsing import parse_expression - -from slither.visitors.expression.exportValues import ExportValues -logger = logging.getLogger("FunctionSolc") - -class FunctionSolc(Function): - """ - Event class - """ - # elems = [(type, name)] - - def __init__(self, function): - super(FunctionSolc, self).__init__() - self._name = function['attributes']['name'] - self._functionNotParsed = function - self._params_was_analyzed = False - self._content_was_analyzed = False - - def _analyze_attributes(self): - attributes = self._functionNotParsed['attributes'] - - if 'payable' in attributes: - self._payable = attributes['payable'] - elif 'stateMutability' in attributes: - if attributes['stateMutability'] == 'payable': - self._payable = True - elif attributes['stateMutability'] == 'pure': - self._pure = True - self._view = True - - if 'constant' in attributes: - self._view = attributes['constant'] - - self._is_constructor = False - - if 'isConstructor' in attributes: - self._is_constructor = attributes['isConstructor'] - - if 'visibility' in attributes: - self._visibility = attributes['visibility'] - # old solc - elif 'public' in attributes: - if attributes['public']: - self._visibility = 'public' - else: - self._visibility = 'private' - else: - self._visibility = 'public' - - if 'payable' in attributes: - self._payable = attributes['payable'] - - def _new_node(self, expression): - node = NodeSolc(expression, len(self.nodes)) - node.set_function(self) - self._nodes.append(node) - return node - - def _parse_if(self, ifStatement, node): - # IfStatement = 'if' '(' Expression ')' Statement ( 'else' Statement )? - - children = ifStatement['children'] - condition_node = self._new_node(NodeType.IF) - #condition = parse_expression(children[0], self) - condition = children[0] - condition_node.add_unparsed_expression(condition) - - link_nodes(node, condition_node) - - trueStatement = self._parse_statement(children[1], condition_node) - - endIf_node = self._new_node(NodeType.ENDIF) - link_nodes(trueStatement, endIf_node) - - if len(children) == 3: - falseStatement = self._parse_statement(children[2], condition_node) - - link_nodes(falseStatement, endIf_node) - - else: - link_nodes(condition_node, endIf_node) - - return endIf_node - - def _parse_while(self, whileStatement, node): - # WhileStatement = 'while' '(' Expression ')' Statement - - children = whileStatement['children'] - - node_startWhile = self._new_node(NodeType.STARTLOOP) - - node_condition = self._new_node(NodeType.IFLOOP) - #expression = parse_expression(children[0], self) - expression = children[0] - node_condition.add_unparsed_expression(expression) - - statement = self._parse_statement(children[1], node_condition) - - node_endWhile = self._new_node(NodeType.ENDLOOP) - - link_nodes(node, node_startWhile) - link_nodes(node_startWhile, node_condition) - link_nodes(statement, node_condition) - link_nodes(node_condition, node_endWhile) - - return node_endWhile - - def _parse_for(self, statement, node): - # ForStatement = 'for' '(' (SimpleStatement)? ';' (Expression)? ';' (ExpressionStatement)? ')' Statement - - hasInitExession = True - hasCondition = True - hasLoopExpression = True - - # Old solc version do not prevent in the attributes - # if the loop has a init value /condition or expression - # There is no way to determine that for(a;;) and for(;a;) are different with old solc - if 'attributes' in statement: - if 'initializationExpression' in statement: - if not statement['initializationExpression']: - hasInitExession = False - if 'condition' in statement: - if not statement['condition']: - hasCondition = False - if 'loopExpression' in statement: - if not statement['loopExpression']: - hasLoopExpression = False - - - node_startLoop = self._new_node(NodeType.STARTLOOP) - node_endLoop = self._new_node(NodeType.ENDLOOP) - - children = statement['children'] - - if hasInitExession: - if len(children) >= 2: - if children[0]['name'] in ['VariableDefinitionStatement', - 'VariableDeclarationStatement', - 'ExpressionStatement']: - node_initExpression = self._parse_statement(children[0], node) - link_nodes(node_initExpression, node_startLoop) - else: - hasInitExession = False - else: - hasInitExession = False - - if not hasInitExession: - link_nodes(node, node_startLoop) - node_condition = node_startLoop - - if hasCondition: - if hasInitExession and len(children) >= 2: - candidate = children[1] - else: - candidate = children[0] - if candidate['name'] not in ['VariableDefinitionStatement', - 'VariableDeclarationStatement', - 'ExpressionStatement']: - node_condition = self._new_node(NodeType.IFLOOP) - #expression = parse_expression(candidate, self) - expression = candidate - node_condition.add_unparsed_expression(expression) - link_nodes(node_startLoop, node_condition) - link_nodes(node_condition, node_endLoop) - hasCondition = True - else: - hasCondition = False - - - node_statement = self._parse_statement(children[-1], node_condition) - - node_LoopExpression = node_statement - if hasLoopExpression: - if len(children) > 2: - if children[-2]['name'] == 'ExpressionStatement': - node_LoopExpression = self._parse_statement(children[-2], node_statement) - - link_nodes(node_LoopExpression, node_startLoop) - - return node_endLoop - - def _parse_dowhile(self, doWhilestatement, node): - children = doWhilestatement['children'] - - node_startDoWhile = self._new_node(NodeType.STARTLOOP) - - # same order in the AST as while - node_condition = self._new_node(NodeType.IFLOOP) - #expression = parse_expression(children[0], self) - expression = children[0] - node_condition.add_unparsed_expression(expression) - - statement = self._parse_statement(children[1], node_condition) - node_endDoWhile = self._new_node(NodeType.ENDLOOP) - - link_nodes(node, node_startDoWhile) - link_nodes(node_startDoWhile, node_condition) - link_nodes(statement, node_condition) - link_nodes(node_condition, node_endDoWhile) - - return node_endDoWhile - - def _parse_variable_definition(self, statement, node): - #assert len(statement['children']) == 1 - # if there is, parse default value - #assert not 'attributes' in statement - - try: - local_var = LocalVariableSolc(statement) - #local_var = LocalVariableSolc(statement['children'][0], statement['children'][1::]) - local_var.set_function(self) - local_var.set_offset(statement['src']) - - self._variables[local_var.name] = local_var - #local_var.analyze(self) - - new_node = self._new_node(NodeType.VARIABLE) - new_node.add_variable_declaration(local_var) - link_nodes(node, new_node) - return new_node - except MultipleVariablesDeclaration: - # Custom handling of var (a,b) = .. style declaration - # We split the variabledeclaration in multiple declarations - count = 0 - children = statement['children'] - child = children[0] - while child['name'] == 'VariableDeclaration': - count = count +1 - child = children[count] - - assert len(children) == (count + 1) - tuple_vars = children[count] - - - variables_declaration = children[0:count] - i = 0 - new_node = node - if tuple_vars['name'] == 'TupleExpression': - assert len(tuple_vars['children']) == count - for variable in variables_declaration: - init = tuple_vars['children'][i] - src = variable['src'] - i= i+1 - # Create a fake statement to be consistent - new_statement = {'name':'VariableDefinitionStatement', - 'src': src, - 'children':[variable, init]} - - new_node = self._parse_variable_definition(new_statement, new_node) - else: - # If we have - # var (a, b) = f() - # we can split in multiple declarations, keep the init value and use LocalVariableSolc - # We use LocalVariableInitFromTupleSolc class - assert tuple_vars['name'] in ['FunctionCall', 'Conditional'] - for variable in variables_declaration: - src = variable['src'] - i= i+1 - # Create a fake statement to be consistent - new_statement = {'name':'VariableDefinitionStatement', - 'src': src, - 'children':[variable, tuple_vars]} - - new_node = self._parse_variable_definition_init_tuple(new_statement, i, new_node) - return new_node - - def _parse_variable_definition_init_tuple(self, statement, index, node): - local_var = LocalVariableInitFromTupleSolc(statement, index) - #local_var = LocalVariableSolc(statement['children'][0], statement['children'][1::]) - local_var.set_function(self) - local_var.set_offset(statement['src']) - - self._variables[local_var.name] = local_var -# local_var.analyze(self) - - new_node = self._new_node(NodeType.VARIABLE) - new_node.add_variable_declaration(local_var) - link_nodes(node, new_node) - return new_node - - - def _parse_statement(self, statement, node): - """ - - Return: - node - """ - # Statement = IfStatement | WhileStatement | ForStatement | Block | InlineAssemblyStatement | - # ( DoWhileStatement | PlaceholderStatement | Continue | Break | Return | - # Throw | EmitStatement | SimpleStatement ) ';' - # SimpleStatement = VariableDefinition | ExpressionStatement - - name = statement['name'] - # SimpleStatement = VariableDefinition | ExpressionStatement - if name == 'IfStatement': - node = self._parse_if(statement, node) - elif name == 'WhileStatement': - node = self._parse_while(statement, node) - elif name == 'ForStatement': - node = self._parse_for(statement, node) - elif name == 'Block': - node = self._parse_block(statement, node) - elif name == 'InlineAssembly': - break_node = self._new_node(NodeType.ASSEMBLY) - link_nodes(node, break_node) - node = break_node - elif name == 'DoWhileStatement': - node = self._parse_dowhile(statement, node) - # For Continue / Break / Return / Throw - # The is fixed later - elif name == 'Continue': - continue_node = self._new_node(NodeType.CONTINUE) - link_nodes(node, continue_node) - node = continue_node - elif name == 'Break': - break_node = self._new_node(NodeType.BREAK) - link_nodes(node, break_node) - node = break_node - elif name == 'Return': - return_node = self._new_node(NodeType.RETURN) - link_nodes(node, return_node) - if 'children' in statement and statement['children']: - assert len(statement['children']) == 1 - expression = statement['children'][0] - return_node.add_unparsed_expression(expression) - node = return_node - elif name == 'Throw': - throw_node = self._new_node(NodeType.THROW) - link_nodes(node, throw_node) - node = throw_node - elif name == 'EmitStatement': - #expression = parse_expression(statement['children'][0], self) - expression = statement['children'][0] - new_node = self._new_node(NodeType.EXPRESSION) - new_node.add_unparsed_expression(expression) - link_nodes(node, new_node) - node = new_node - elif name in ['VariableDefinitionStatement', 'VariableDeclarationStatement']: - node = self._parse_variable_definition(statement, node) - elif name == 'ExpressionStatement': - assert len(statement['children']) == 1 - assert not 'attributes' in statement - #expression = parse_expression(statement['children'][0], self) - expression = statement['children'][0] - new_node = self._new_node(NodeType.EXPRESSION) - new_node.add_unparsed_expression(expression) - link_nodes(node, new_node) - node = new_node - else: - logger.error('Statement not parsed %s'%name) - exit(-1) - - return node - - def _parse_block(self, block, node): - ''' - Return: - Node - ''' - assert block['name'] == 'Block' - - for child in block['children']: - node = self._parse_statement(child, node) - return node - - def _parse_cfg(self, cfg): - - assert cfg['name'] == 'Block' - - node = self._new_node(NodeType.ENTRYPOINT) - self._entry_point = node - - if not cfg['children']: - self._is_empty = True - else: - self._is_empty = False - self._parse_block(cfg, node) - self._remove_incorrect_edges() - - def _find_end_loop(self, node, visited): - if node in visited: - return None - - if node.type == NodeType.ENDLOOP: - return node - - visited = visited + [node] - for son in node.sons: - ret = self._find_end_loop(son, visited) - if ret: - return ret - - return None - - def _find_start_loop(self, node, visited): - if node in visited: - return None - - if node.type == NodeType.STARTLOOP: - return node - - visited = visited + [node] - for father in node.fathers: - ret = self._find_start_loop(father, visited) - if ret: - return ret - - return None - - def _fix_break_node(self, node): - end_node = self._find_end_loop(node, []) - - if not end_node: - logger.error('Break in no-loop context {}'.format(node.nodeId())) - exit(-1) - - for son in node.sons: - son.remove_father(node) - node.set_sons([end_node]) - end_node.add_father(node) - - def _fix_continue_node(self, node): - start_node = self._find_start_loop(node, []) - - if not start_node: - logger.error('Continue in no-loop context {}'.format(node.nodeId())) - exit(-1) - - for son in node.sons: - son.remove_father(node) - node.set_sons([start_node]) - start_node.add_father(node) - - def _remove_incorrect_edges(self): - for node in self._nodes: - if node.type in [NodeType.RETURN, NodeType.THROW]: - for son in node.sons: - son.remove_father(node) - node.set_sons([]) - if node.type in [NodeType.BREAK]: - self._fix_break_node(node) - if node.type in [NodeType.CONTINUE]: - self._fix_continue_node(node) - - def _parse_params(self, params): - - assert params['name'] == 'ParameterList' - for param in params['children']: - assert param['name'] == 'VariableDeclaration' - - local_var = LocalVariableSolc(param) - - local_var.set_function(self) - local_var.set_offset(param['src']) - local_var.analyze(self) - - self._variables[local_var.name] = local_var - self._parameters.append(local_var) - - def _parse_returns(self, returns): - - assert returns['name'] == 'ParameterList' - for ret in returns['children']: - assert ret['name'] == 'VariableDeclaration' - - local_var = LocalVariableSolc(ret) - - local_var.set_function(self) - local_var.set_offset(ret['src']) - local_var.analyze(self) - - self._variables[local_var.name] = local_var - self._returns.append(local_var) - - - def _parse_modifier(self, modifier): - m = parse_expression(modifier, self) - self._expression_modifiers.append(m) - self._modifiers += [m for m in ExportValues(m).result() if isinstance(m, Function)] - - - def analyze_params(self): - # Can be re-analyzed due to inheritance - if self._params_was_analyzed: - return - - self._params_was_analyzed = True - - self._analyze_attributes() - - children = self._functionNotParsed['children'] - - params = children[0] - returns = children[1] - - if params: - self._parse_params(params) - if returns: - self._parse_returns(returns) - - def analyze_content(self): - if self._content_was_analyzed: - return - - self._content_was_analyzed = True - - children = self._functionNotParsed['children'] - self._is_implemented = False - for child in children[2:]: - if child['name'] == 'Block': - self._is_implemented = True - self._parse_cfg(child) - continue - - assert child['name'] == 'ModifierInvocation' - - self._parse_modifier(child) - - for local_vars in self.variables: - local_vars.analyze(self) - - for node in self.nodes: - node.analyze_expressions(self) - - self._analyze_read_write() - self._analyze_calls() diff --git a/slither/solcParsing/expressions/expressionParsing.py b/slither/solcParsing/expressions/expressionParsing.py deleted file mode 100644 index 0b24efde9..000000000 --- a/slither/solcParsing/expressions/expressionParsing.py +++ /dev/null @@ -1,359 +0,0 @@ -import logging -import re -from slither.core.expressions.unaryOperation import UnaryOperation, UnaryOperationType -from slither.core.expressions.binaryOperation import BinaryOperation, BinaryOperationType -from slither.core.expressions.literal import Literal -from slither.core.expressions.identifier import Identifier -from slither.core.expressions.superIdentifier import SuperIdentifier -from slither.core.expressions.indexAccess import IndexAccess -from slither.core.expressions.memberAccess import MemberAccess -from slither.core.expressions.tupleExpression import TupleExpression -from slither.core.expressions.conditionalExpression import ConditionalExpression -from slither.core.expressions.assignmentOperation import AssignmentOperation, AssignmentOperationType -from slither.core.expressions.typeConversion import TypeConversion -from slither.core.expressions.callExpression import CallExpression -from slither.core.expressions.superCallExpression import SuperCallExpression -from slither.core.expressions.newArray import NewArray -from slither.core.expressions.newContract import NewContract -from slither.core.expressions.newElementaryType import NewElementaryType -from slither.core.expressions.elementaryTypeNameExpression import ElementaryTypeNameExpression - -from slither.solcParsing.solidityTypes.typeParsing import parse_type, UnknownType - -from slither.core.declarations.contract import Contract -from slither.core.declarations.function import Function - -from slither.core.declarations.solidityVariables import SOLIDITY_VARIABLES, SOLIDITY_FUNCTIONS, SOLIDITY_VARIABLES_COMPOSED -from slither.core.declarations.solidityVariables import SolidityVariable, SolidityFunction, SolidityVariableComposed - -from slither.core.solidityTypes.elementaryType import ElementaryType - -logger = logging.getLogger("ExpressionParsing") - -class VariableNotFound(Exception): pass - -def find_variable(var_name, caller_context): - - if isinstance(caller_context, Contract): - function = None - contract = caller_context - elif isinstance(caller_context, Function): - function = caller_context - contract = function.contract - else: - logger.error('Incorrect caller context') - exit(-1) - - if function: - func_variables = function.variables_as_dict() - if var_name in func_variables: - return func_variables[var_name] - - contract_variables = contract.variables_as_dict() - if var_name in contract_variables: - return contract_variables[var_name] - - functions = contract.functions_as_dict() - if var_name in functions: - return functions[var_name] - - modifiers = contract.modifiers_as_dict() - if var_name in modifiers: - return modifiers[var_name] - - structures = contract.structures_as_dict() - if var_name in structures: - return structures[var_name] - - events = contract.events_as_dict() - if var_name in events: - return events[var_name] - - enums = contract.enums_as_dict() - if var_name in enums: - return enums[var_name] - - # If the enum is refered as its name rather than its canonicalName - enums = {e.name: e for e in contract.enums} - if var_name in enums: - return enums[var_name] - - # Could refer to any enum - all_enums = [c.enums_as_dict() for c in contract.slither.contracts] - all_enums = {k: v for d in all_enums for k, v in d.items()} - if var_name in all_enums: - return all_enums[var_name] - - if var_name in SOLIDITY_VARIABLES: - return SolidityVariable(var_name) - - if var_name in SOLIDITY_FUNCTIONS: - return SolidityFunction(var_name) - - contracts = contract.slither.contracts_as_dict() - if var_name in contracts: - return contracts[var_name] - - raise VariableNotFound('Variable not found: {}'.format(var_name)) - - -def parse_call(expression, caller_context): - attributes = expression['attributes'] - - type_conversion = attributes['type_conversion'] - - children = expression['children'] - if type_conversion: - assert len(children) == 2 - - type_call = parse_type(UnknownType(attributes['type']), caller_context) - type_info = children[0] - assert type_info['name'] in ['ElementaryTypenameExpression', - 'ElementaryTypeNameExpression', - 'Identifier', - 'TupleExpression', - 'IndexAccess', - 'MemberAccess'] - - expression = parse_expression(children[1], caller_context) - t = TypeConversion(expression, type_call) - return t - - assert children - - type_call = attributes['type'] - called = parse_expression(children[0], caller_context) - arguments = [parse_expression(a, caller_context) for a in children[1::]] - - if isinstance(called, SuperCallExpression): - return SuperCallExpression(called, arguments, type_call) - return CallExpression(called, arguments, type_call) - -def parse_super_name(expression): - assert expression['name'] == 'MemberAccess' - attributes = expression['attributes'] - base_name = attributes['member_name'] - - arguments = attributes['type'] - assert arguments.startswith('function ') - # remove function (...() - arguments = arguments[len('function '):] - - arguments = filter_name(arguments) - if ' ' in arguments: - arguments = arguments[:arguments.find(' ')] - - return base_name+arguments - -def filter_name(value): - value = value.replace(' memory', '') - value = value.replace(' storage', '') - value = value.replace('struct ', '') - value = value.replace('contract ', '') - value = value.replace('enum ', '') - value = value.replace(' ref', '') - value = value.replace(' pointer', '') - return value - -def parse_expression(expression, caller_context): - """ - - Returns: - str: expression - """ - # Expression - # = Expression ('++' | '--') - # | NewExpression - # | IndexAccess - # | MemberAccess - # | FunctionCall - # | '(' Expression ')' - # | ('!' | '~' | 'delete' | '++' | '--' | '+' | '-') Expression - # | Expression '**' Expression - # | Expression ('*' | '/' | '%') Expression - # | Expression ('+' | '-') Expression - # | Expression ('<<' | '>>') Expression - # | Expression '&' Expression - # | Expression '^' Expression - # | Expression '|' Expression - # | Expression ('<' | '>' | '<=' | '>=') Expression - # | Expression ('==' | '!=') Expression - # | Expression '&&' Expression - # | Expression '||' Expression - # | Expression '?' Expression ':' Expression - # | Expression ('=' | '|=' | '^=' | '&=' | '<<=' | '>>=' | '+=' | '-=' | '*=' | '/=' | '%=') Expression - # | PrimaryExpression - - # The AST naming does not follow the spec - name = expression['name'] - - if name == 'UnaryOperation': - attributes = expression['attributes'] - assert 'prefix' in attributes - operation_type = UnaryOperationType.get_type(attributes['operator'], attributes['prefix']) - - assert len(expression['children']) == 1 - expression = parse_expression(expression['children'][0], caller_context) - unary_op = UnaryOperation(expression, operation_type) - return unary_op - - elif name == 'BinaryOperation': - attributes = expression['attributes'] - operation_type = BinaryOperationType.get_type(attributes['operator']) - - assert len(expression['children']) == 2 - left_expression = parse_expression(expression['children'][0], caller_context) - right_expression = parse_expression(expression['children'][1], caller_context) - binary_op = BinaryOperation(left_expression, right_expression, operation_type) - return binary_op - - elif name == 'FunctionCall': - return parse_call(expression, caller_context) - - elif name == 'TupleExpression': - if 'children' not in expression : - attributes = expression['attributes'] - components = attributes['components'] - expressions = [parse_expression(c, caller_context) if c else None for c in components] - else: - expressions = [parse_expression(e, caller_context) for e in expression['children']] - t = TupleExpression(expressions) - return t - - elif name == 'Conditional': - children = expression['children'] - assert len(children) == 3 - if_expression = parse_expression(children[0], caller_context) - then_expression = parse_expression(children[1], caller_context) - else_expression = parse_expression(children[2], caller_context) - conditional = ConditionalExpression(if_expression, then_expression, else_expression) - return conditional - - elif name == 'Assignment': - attributes = expression['attributes'] - children = expression['children'] - assert len(expression['children']) == 2 - - left_expression = parse_expression(children[0], caller_context) - right_expression = parse_expression(children[1], caller_context) - operation_type = AssignmentOperationType.get_type(attributes['operator']) - operation_return_type = attributes['type'] - - assignement = AssignmentOperation(left_expression, right_expression, operation_type, operation_return_type) - return assignement - - elif name == 'Literal': - assert 'children' not in expression - value = expression['attributes']['value'] - literal = Literal(value) - return literal - - elif name == 'Identifier': - assert 'children' not in expression - value = expression['attributes']['value'] - if 'type' in expression['attributes']: - t = expression['attributes']['type'] - if t: - found = re.findall('[struct|enum|function|modifier] \(([\[\] a-zA-Z0-9\.,]*)\)', t) - if found: - value = value+'('+found[0]+')' - value = filter_name(value) - - var = find_variable(value, caller_context) - - identifier = Identifier(var) - return identifier - - elif name == 'IndexAccess': - index_type = expression['attributes']['type'] - children = expression['children'] - assert len(children) == 2 - left_expression = parse_expression(children[0], caller_context) - right_expression = parse_expression(children[1], caller_context) - index = IndexAccess(left_expression, right_expression, index_type) - return index - - elif name == 'MemberAccess': - member_name = expression['attributes']['member_name'] - member_type = expression['attributes']['type'] - children = expression['children'] - assert len(children) == 1 - member_expression = parse_expression(children[0], caller_context) - if str(member_expression) == 'super': - super_name = parse_super_name(expression) - inheritances = caller_context.contract.inheritances - var = None - for father in inheritances: - try: - var = find_variable(super_name, father) - break - except VariableNotFound: - continue - if var is None: - raise VariableNotFound('Variable not found: {}'.format(super_name)) - return SuperIdentifier(var) - member_access = MemberAccess(member_name, member_type, member_expression) - if str(member_access) in SOLIDITY_VARIABLES_COMPOSED: - return Identifier(SolidityVariableComposed(str(member_access))) - return member_access - - elif name == 'ElementaryTypeNameExpression': - # nop exression - # uint; - assert 'children' not in expression - value = expression['attributes']['value'] - t = parse_type(UnknownType(value), caller_context) - - return ElementaryTypeNameExpression(t) - - - # NewExpression is not a root expression, it's always the child of another expression - elif name == 'NewExpression': - new_type = expression['attributes']['type'] - - children = expression['children'] - assert len(children) == 1 - #new_expression = parse_expression(children[0]) - - child = children[0] - - if child['name'] == 'ArrayTypeName': - depth = 0 - while child['name'] == 'ArrayTypeName': - # Note: dont conserve the size of the array if provided - #assert len(child['children']) == 1 - child = child['children'][0] - depth += 1 - - if child['name'] == 'ElementaryTypeName': - array_type = ElementaryType(child['attributes']['name']) - elif child['name'] == 'UserDefinedTypeName': - array_type = parse_type(UnknownType(child['attributes']['name']), caller_context) - else: - logger.error('Incorrect type array {}'.format(child)) - exit(-1) - array = NewArray(depth, array_type) - return array - - if child['name'] == 'ElementaryTypeName': - elem_type = ElementaryType(child['attributes']['name']) - new_elem = NewElementaryType(elem_type) - return new_elem - - assert child['name'] == 'UserDefinedTypeName' - - contract_name = child['attributes']['name'] - new = NewContract(contract_name) - return new - - elif name == 'ModifierInvocation': - - children = expression['children'] - called = parse_expression(children[0], caller_context) - arguments = [parse_expression(a, caller_context) for a in children[1::]] - - call = CallExpression(called, arguments, 'Modifier') - return call - - logger.error('Expression not parsed %s'%name) - exit(-1) diff --git a/slither/solcParsing/variables/eventVariableSolc.py b/slither/solcParsing/variables/eventVariableSolc.py deleted file mode 100644 index b38f6b282..000000000 --- a/slither/solcParsing/variables/eventVariableSolc.py +++ /dev/null @@ -1,5 +0,0 @@ - -from .variableDeclarationSolc import VariableDeclarationSolc -from slither.core.variables.eventVariable import EventVariable - -class EventVariableSolc(VariableDeclarationSolc, EventVariable): pass diff --git a/slither/solcParsing/variables/functionTypeVariableSolc.py b/slither/solcParsing/variables/functionTypeVariableSolc.py deleted file mode 100644 index 9ba8f46e0..000000000 --- a/slither/solcParsing/variables/functionTypeVariableSolc.py +++ /dev/null @@ -1,5 +0,0 @@ - -from slither.solcParsing.variables.variableDeclarationSolc import VariableDeclarationSolc -from slither.core.variables.functionTypeVariable import FunctionTypeVariable - -class FunctionTypeVariableSolc(VariableDeclarationSolc, FunctionTypeVariable): pass diff --git a/slither/solcParsing/variables/localVariableSolc.py b/slither/solcParsing/variables/localVariableSolc.py deleted file mode 100644 index e983a86c9..000000000 --- a/slither/solcParsing/variables/localVariableSolc.py +++ /dev/null @@ -1,5 +0,0 @@ - -from .variableDeclarationSolc import VariableDeclarationSolc -from slither.core.variables.localVariable import LocalVariable - -class LocalVariableSolc(VariableDeclarationSolc, LocalVariable): pass diff --git a/slither/solcParsing/variables/stateVariableSolc.py b/slither/solcParsing/variables/stateVariableSolc.py deleted file mode 100644 index 64db8f619..000000000 --- a/slither/solcParsing/variables/stateVariableSolc.py +++ /dev/null @@ -1,5 +0,0 @@ - -from .variableDeclarationSolc import VariableDeclarationSolc -from slither.core.variables.stateVariable import StateVariable - -class StateVariableSolc(VariableDeclarationSolc, StateVariable): pass diff --git a/slither/solcParsing/variables/structureVariableSolc.py b/slither/solcParsing/variables/structureVariableSolc.py deleted file mode 100644 index b7b03b06c..000000000 --- a/slither/solcParsing/variables/structureVariableSolc.py +++ /dev/null @@ -1,5 +0,0 @@ - -from .variableDeclarationSolc import VariableDeclarationSolc -from slither.core.variables.structureVariable import StructureVariable - -class StructureVariableSolc(VariableDeclarationSolc, StructureVariable): pass diff --git a/slither/solcParsing/variables/variableDeclarationSolc.py b/slither/solcParsing/variables/variableDeclarationSolc.py deleted file mode 100644 index bd703cd92..000000000 --- a/slither/solcParsing/variables/variableDeclarationSolc.py +++ /dev/null @@ -1,115 +0,0 @@ -import logging -from slither.solcParsing.expressions.expressionParsing import parse_expression - -from slither.core.variables.variable import Variable - -from slither.solcParsing.solidityTypes.typeParsing import parse_type, UnknownType - -from slither.core.solidityTypes.elementaryType import ElementaryType, NonElementaryType - -logger = logging.getLogger("VariableDeclarationSolcParsing") - -class MultipleVariablesDeclaration(Exception): - ''' - This is raised on - var (a,b) = ... - It should occur only on local variable definition - ''' - pass - -class VariableDeclarationSolc(Variable): - - def __init__(self, var): - ''' - A variable can be declared through a statement, or directly. - If it is through a statement, the following children may contain - the init value. - It may be possible that the variable is declared through a statement, - but the init value is declared at the VariableDeclaration children level - ''' - - super(VariableDeclarationSolc, self).__init__() - self._was_analyzed = False - - if var['name'] in ['VariableDeclarationStatement', 'VariableDefinitionStatement']: - if len(var['children']) == 2: - init = var['children'][1] - elif len(var['children']) == 1: - init = None - elif len(var['children']) > 2: - raise MultipleVariablesDeclaration - else: - logger.error('Variable declaration without children?'+var) - exit(-1) - declaration = var['children'][0] - self._init_from_declaration(declaration, init) - elif var['name'] == 'VariableDeclaration': - self._init_from_declaration(var, None) - else: - logger.error('Incorrect variable declaration type {}'.format(var['name'])) - exit(-1) - - @property - def initialized(self): - return self._initialized - - @property - def uninitialized(self): - return not self._initialized - - - def _init_from_declaration(self, var, init): - assert len(var['children']) <= 2 - assert var['name'] == 'VariableDeclaration' - - attributes = var['attributes'] - self._name = attributes['name'] - - self._typeName = attributes['type'] - self._arrayDepth = 0 - self._isMapping = False - self._mappingFrom = None - self._mappingTo = False - self._initial_expression = None - self._type = None - - if 'visibility' in attributes: - self._visibility = attributes['visibility'] - else: - self._visibility = 'internal' - - if not var['children']: - # It happens on variable declared inside loop declaration - try: - self._type = ElementaryType(self._typeName) - self._elem_to_parse = None - except NonElementaryType: - self._elem_to_parse = UnknownType(self._typeName) - else: - self._elem_to_parse = var['children'][0] - - if init: # there are two way to init a var local in the AST - assert len(var['children']) <= 1 - self._initialized = True - self._initializedNotParsed = init - elif len(var['children']) == 1: - self._initialized = False - self._initializedNotParsed = [] - else: - assert len(var['children']) == 2 - self._initialized = True - self._initializedNotParsed = var['children'][1] - - def analyze(self, caller_context): - # Can be re-analyzed due to inheritance - if self._was_analyzed: - return - self._was_analyzed = True - - if self._elem_to_parse: - self._type = parse_type(self._elem_to_parse, caller_context) - self._elem_to_parse = None - - if self._initialized: - self._initial_expression = parse_expression(self._initializedNotParsed, caller_context) - self._initializedNotParsed = None diff --git a/slither/solc_parsing/__init__.py b/slither/solc_parsing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slither/solc_parsing/cfg/__init__.py b/slither/solc_parsing/cfg/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slither/solc_parsing/cfg/node.py b/slither/solc_parsing/cfg/node.py new file mode 100644 index 000000000..56e73e32c --- /dev/null +++ b/slither/solc_parsing/cfg/node.py @@ -0,0 +1,65 @@ +from slither.core.cfg.node import Node +from slither.core.cfg.node import NodeType +from slither.solc_parsing.expressions.expression_parsing import parse_expression +from slither.visitors.expression.read_var import ReadVar +from slither.visitors.expression.write_var import WriteVar +from slither.visitors.expression.find_calls import FindCalls + +from slither.visitors.expression.export_values import ExportValues +from slither.core.declarations.solidity_variables import SolidityVariable, SolidityFunction +from slither.core.declarations.function import Function + +from slither.core.variables.state_variable import StateVariable + +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.assignment_operation import AssignmentOperation, AssignmentOperationType + +class NodeSolc(Node): + + def __init__(self, nodeType, nodeId): + super(NodeSolc, self).__init__(nodeType, nodeId) + self._unparsed_expression = None + + def add_unparsed_expression(self, expression): + assert self._unparsed_expression is None + self._unparsed_expression = expression + + def analyze_expressions(self, caller_context): + if self.type == NodeType.VARIABLE and not self._expression: + self._expression = self.variable_declaration.expression + if self._unparsed_expression: + expression = parse_expression(self._unparsed_expression, caller_context) + self._expression = expression + self._unparsed_expression = None + + if self.expression: + + if self.type == NodeType.VARIABLE: + # Update the expression to be an assignement to the variable + #print(self.variable_declaration) + self._expression = AssignmentOperation(Identifier(self.variable_declaration), + self.expression, + AssignmentOperationType.ASSIGN, + self.variable_declaration.type) + + expression = self.expression + pp = ReadVar(expression) + self._expression_vars_read = pp.result() + +# self._vars_read = [item for sublist in vars_read for item in sublist] +# self._state_vars_read = [x for x in self.variables_read if\ +# isinstance(x, (StateVariable))] +# self._solidity_vars_read = [x for x in self.variables_read if\ +# isinstance(x, (SolidityVariable))] + + pp = WriteVar(expression) + self._expression_vars_written = pp.result() + +# self._vars_written = [item for sublist in vars_written for item in sublist] +# self._state_vars_written = [x for x in self.variables_written if\ +# isinstance(x, StateVariable)] + + pp = FindCalls(expression) + self._expression_calls = pp.result() + self._external_calls_as_expressions = [c for c in self.calls_as_expression if not isinstance(c.called, Identifier)] + diff --git a/slither/solc_parsing/declarations/__init__.py b/slither/solc_parsing/declarations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slither/solcParsing/declarations/contractSolc04.py b/slither/solc_parsing/declarations/contract.py similarity index 50% rename from slither/solcParsing/declarations/contractSolc04.py rename to slither/solc_parsing/declarations/contract.py index 4caba8e2d..03ff2dc54 100644 --- a/slither/solcParsing/declarations/contractSolc04.py +++ b/slither/solc_parsing/declarations/contract.py @@ -3,12 +3,14 @@ import logging from slither.core.declarations.contract import Contract from slither.core.declarations.enum import Enum -from slither.solcParsing.declarations.structureSolc import StructureSolc -from slither.solcParsing.declarations.eventSolc import EventSolc -from slither.solcParsing.declarations.modifierSolc import ModifierSolc -from slither.solcParsing.declarations.functionSolc import FunctionSolc +from slither.solc_parsing.declarations.structure import StructureSolc +from slither.solc_parsing.declarations.event import EventSolc +from slither.solc_parsing.declarations.modifier import ModifierSolc +from slither.solc_parsing.declarations.function import FunctionSolc -from slither.solcParsing.variables.stateVariableSolc import StateVariableSolc +from slither.solc_parsing.variables.state_variable import StateVariableSolc + +from slither.solc_parsing.solidity_types.type_parsing import parse_type logger = logging.getLogger("ContractSolcParsing") @@ -35,9 +37,14 @@ class ContractSolc04(Contract): self._is_analyzed = False # Export info - self._name = self._data['attributes']['name'] + if self.is_compact_ast: + self._name = self._data['name'] + else: + self._name = self._data['attributes'][self.get_key()] + self._id = self._data['id'] - self._inheritances = [] + + self._inheritance = [] self._parse_contract_info() self._parse_contract_items() @@ -46,11 +53,27 @@ class ContractSolc04(Contract): def is_analyzed(self): return self._is_analyzed + def get_key(self): + return self.slither.get_key() + + def get_children(self, key='nodes'): + if self.is_compact_ast: + return key + return 'children' + + @property + def is_compact_ast(self): + return self.slither.is_compact_ast + def set_is_analyzed(self, is_analyzed): self._is_analyzed = is_analyzed def _parse_contract_info(self): - attributes = self._data['attributes'] + if self.is_compact_ast: + attributes = self._data + else: + attributes = self._data['attributes'] + self.isInterface = False if 'contractKind' in attributes: if attributes['contractKind'] == 'interface': @@ -60,51 +83,64 @@ class ContractSolc04(Contract): self.fullyImplemented = attributes['fullyImplemented'] def _parse_contract_items(self): - if not 'children' in self._data: # empty contract + if not self.get_children() in self._data: # empty contract return - for item in self._data['children']: - if item['name'] == 'FunctionDefinition': + for item in self._data[self.get_children()]: + if item[self.get_key()] == 'FunctionDefinition': self._functionsNotParsed.append(item) - elif item['name'] == 'EventDefinition': + elif item[self.get_key()] == 'EventDefinition': self._eventsNotParsed.append(item) - elif item['name'] == 'InheritanceSpecifier': + elif item[self.get_key()] == 'InheritanceSpecifier': # we dont need to parse it as it is redundant # with self.linearizedBaseContracts continue - elif item['name'] == 'VariableDeclaration': + elif item[self.get_key()] == 'VariableDeclaration': self._variablesNotParsed.append(item) - elif item['name'] == 'EnumDefinition': + elif item[self.get_key()] == 'EnumDefinition': self._enumsNotParsed.append(item) - elif item['name'] == 'ModifierDefinition': + elif item[self.get_key()] == 'ModifierDefinition': self._modifiersNotParsed.append(item) - elif item['name'] == 'StructDefinition': + elif item[self.get_key()] == 'StructDefinition': self._structuresNotParsed.append(item) - elif item['name'] == 'UsingForDirective': + elif item[self.get_key()] == 'UsingForDirective': self._usingForNotParsed.append(item) else: - logger.error('Unknown contract item: '+item['name']) + logger.error('Unknown contract item: '+item[self.get_key()]) exit(-1) return def analyze_using_for(self): - for father in self.inheritances: + for father in self.inheritance: self._using_for.update(father.using_for) - for using_for in self._usingForNotParsed: - children = using_for['children'] - assert children and len(children) <= 2 - if len(children) == 2: - new = children[0]['attributes']['name'] - old = children[1]['attributes']['name'] - else: - new = children[0]['attributes']['name'] - old = '*' - self._using_for[old] = new + if self.is_compact_ast: + for using_for in self._usingForNotParsed: + lib_name = parse_type(using_for['libraryName'], self) + if 'typeName' in using_for and using_for['typeName']: + type_name = parse_type(using_for['typeName'], self) + else: + type_name = '*' + if not type_name in self._using_for: + self.using_for[type_name] = [] + self._using_for[type_name].append(lib_name) + else: + for using_for in self._usingForNotParsed: + children = using_for[self.get_children()] + assert children and len(children) <= 2 + if len(children) == 2: + new = parse_type(children[0], self) + old = parse_type(children[1], self) + else: + new = parse_type(children[0], self) + old = '*' + if not old in self._using_for: + self.using_for[old] = [] + self._using_for[old].append(new) self._usingForNotParsed = [] def analyze_enums(self): - for father in self.inheritances: + for father in self.inheritance: self._enums.update(father.enums_as_dict()) for enum in self._enumsNotParsed: @@ -115,42 +151,54 @@ class ContractSolc04(Contract): def _analyze_enum(self, enum): # Enum can be parsed in one pass - name = enum['attributes']['name'] - if 'canonicalName' in enum['attributes']: - canonicalName = enum['attributes']['canonicalName'] + if self.is_compact_ast: + name = enum['name'] + canonicalName = enum['canonicalName'] else: - canonicalName = self.name + '.' + name + name = enum['attributes'][self.get_key()] + if 'canonicalName' in enum['attributes']: + canonicalName = enum['attributes']['canonicalName'] + else: + canonicalName = self.name + '.' + name values = [] - for child in enum['children']: - assert child['name'] == 'EnumValue' - values.append(child['attributes']['name']) + for child in enum[self.get_children('members')]: + assert child[self.get_key()] == 'EnumValue' + if self.is_compact_ast: + values.append(child['name']) + else: + values.append(child['attributes'][self.get_key()]) new_enum = Enum(name, canonicalName, values) new_enum.set_contract(self) - new_enum.set_offset(enum['src']) + new_enum.set_offset(enum['src'], self.slither) self._enums[canonicalName] = new_enum def _parse_struct(self, struct): - name = struct['attributes']['name'] - if 'canonicalName' in struct['attributes']: - canonicalName = struct['attributes']['canonicalName'] + if self.is_compact_ast: + name = struct['name'] + attributes = struct + else: + name = struct['attributes'][self.get_key()] + attributes = struct['attributes'] + if 'canonicalName' in attributes: + canonicalName = attributes['canonicalName'] else: canonicalName = self.name + '.' + name - if 'children' in struct: - children = struct['children'] + if self.get_children('members') in struct: + children = struct[self.get_children('members')] else: children = [] # empty struct st = StructureSolc(name, canonicalName, children) st.set_contract(self) - st.set_offset(struct['src']) + st.set_offset(struct['src'], self.slither) self._structures[name] = st def _analyze_struct(self, struct): struct.analyze() def parse_structs(self): - for father in self.inheritances_reverse: + for father in self.inheritance_reverse: self._structures.update(father.structures_as_dict()) for struct in self._structuresNotParsed: @@ -163,27 +211,40 @@ class ContractSolc04(Contract): def analyze_events(self): - for father in self.inheritances_reverse: + for father in self.inheritance_reverse: self._events.update(father.events_as_dict()) for event_to_parse in self._eventsNotParsed: - event = EventSolc(event_to_parse) + event = EventSolc(event_to_parse, self) event.analyze(self) + event.set_contract(self) + event.set_offset(event_to_parse['src'], self.slither) self._events[event.full_name] = event self._eventsNotParsed = None def parse_state_variables(self): - for father in self.inheritances_reverse: + for father in self.inheritance_reverse: self._variables.update(father.variables_as_dict()) for varNotParsed in self._variablesNotParsed: var = StateVariableSolc(varNotParsed) - var.set_offset(varNotParsed['src']) + var.set_offset(varNotParsed['src'], self.slither) var.set_contract(self) self._variables[var.name] = var + def analyze_constant_state_variables(self): + from slither.solc_parsing.expressions.expression_parsing import VariableNotFound + for var in self.variables: + if var.is_constant: + # cant parse constant expression based on function calls + try: + var.analyze(self) + except VariableNotFound: + pass + return + def analyze_state_variables(self): for var in self.variables: var.analyze(self) @@ -191,9 +252,9 @@ class ContractSolc04(Contract): def _parse_modifier(self, modifier): - modif = ModifierSolc(modifier) + modif = ModifierSolc(modifier, self) modif.set_contract(self) - modif.set_offset(modifier['src']) + modif.set_offset(modifier['src'], self.slither) self._modifiers_no_params.append(modif) def parse_modifiers(self): @@ -205,9 +266,8 @@ class ContractSolc04(Contract): return def _parse_function(self, function): - func = FunctionSolc(function) - func.set_contract(self) - func.set_offset(function['src']) + func = FunctionSolc(function, self) + func.set_offset(function['src'], self.slither) self._functions_no_params.append(func) def parse_functions(self): @@ -221,7 +281,7 @@ class ContractSolc04(Contract): return def analyze_params_modifiers(self): - for father in self.inheritances_reverse: + for father in self.inheritance_reverse: self._modifiers.update(father.modifiers_as_dict()) for modifier in self._modifiers_no_params: @@ -232,10 +292,15 @@ class ContractSolc04(Contract): return def analyze_params_functions(self): - for father in self.inheritances_reverse: - functions = {k:v for (k,v) in father.functions_as_dict().items()} #if not v.is_constructor} + # keep track of the contracts visited + # to prevent an ovveride due to multiple inheritance of the same contract + # A is B, C, D is C, --> the second C was already seen + contracts_visited = [] + for father in self.inheritance_reverse: + functions = {k:v for (k, v) in father.functions_as_dict().items() + if not v.contract in contracts_visited} + contracts_visited.append(father) self._functions.update(functions) - for function in self._functions_no_params: function.analyze_params() self._functions[function.full_name] = function diff --git a/slither/solc_parsing/declarations/event.py b/slither/solc_parsing/declarations/event.py new file mode 100644 index 000000000..d37a478f2 --- /dev/null +++ b/slither/solc_parsing/declarations/event.py @@ -0,0 +1,43 @@ +""" + Event module +""" +from slither.solc_parsing.variables.event_variable import EventVariableSolc +from slither.core.declarations.event import Event + +class EventSolc(Event): + """ + Event class + """ + + def __init__(self, event, contract): + super(EventSolc, self).__init__() + self._contract = contract + + self._elems = [] + if self.is_compact_ast: + self._name = event['name'] + elems = event['parameters'] + assert elems['nodeType'] == 'ParameterList' + self._elemsNotParsed = elems['parameters'] + else: + self._name = event['attributes']['name'] + elems = event['children'][0] + + assert elems['name'] == 'ParameterList' + if 'children' in elems: + self._elemsNotParsed = elems['children'] + else: + self._elemsNotParsed = [] + + @property + def is_compact_ast(self): + return self.contract.is_compact_ast + + def analyze(self, contract): + for elem_to_parse in self._elemsNotParsed: + elem = EventVariableSolc(elem_to_parse) + elem.analyze(contract) + self._elems.append(elem) + + self._elemsNotParsed = [] + diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py new file mode 100644 index 000000000..5b3952917 --- /dev/null +++ b/slither/solc_parsing/declarations/function.py @@ -0,0 +1,909 @@ +""" + Event module +""" +import logging +from slither.core.declarations.function import Function +from slither.core.cfg.node import NodeType +from slither.solc_parsing.cfg.node import NodeSolc +from slither.core.cfg.node import NodeType +from slither.core.cfg.node import link_nodes + +from slither.solc_parsing.variables.local_variable import LocalVariableSolc +from slither.solc_parsing.variables.local_variable_init_from_tuple import LocalVariableInitFromTupleSolc +from slither.solc_parsing.variables.variable_declaration import MultipleVariablesDeclaration + +from slither.solc_parsing.expressions.expression_parsing import parse_expression + +from slither.visitors.expression.export_values import ExportValues +from slither.visitors.expression.has_conditional import HasConditional + +from slither.utils.expression_manipulations import SplitTernaryExpression + +from slither.slithir.utils.variable_number import transform_slithir_vars_to_ssa + +logger = logging.getLogger("FunctionSolc") + +class FunctionSolc(Function): + """ + Event class + """ + # elems = [(type, name)] + + def __init__(self, function, contract): + super(FunctionSolc, self).__init__() + self._contract = contract + if self.is_compact_ast: + self._name = function['name'] + else: + self._name = function['attributes'][self.get_key()] + self._functionNotParsed = function + self._params_was_analyzed = False + self._content_was_analyzed = False + self._counter_nodes = 0 + + def get_key(self): + return self.slither.get_key() + + def get_children(self, key): + if self.is_compact_ast: + return key + return 'children' + + @property + def is_compact_ast(self): + return self.slither.is_compact_ast + + def _analyze_attributes(self): + if self.is_compact_ast: + attributes = self._functionNotParsed + else: + attributes = self._functionNotParsed['attributes'] + + if 'payable' in attributes: + self._payable = attributes['payable'] + if 'stateMutability' in attributes: + if attributes['stateMutability'] == 'payable': + self._payable = True + elif attributes['stateMutability'] == 'pure': + self._pure = True + self._view = True + elif attributes['stateMutability'] == 'view': + self._view = True + + if 'constant' in attributes: + self._view = attributes['constant'] + + self._is_constructor = False + + if 'isConstructor' in attributes: + self._is_constructor = attributes['isConstructor'] + + if 'visibility' in attributes: + self._visibility = attributes['visibility'] + # old solc + elif 'public' in attributes: + if attributes['public']: + self._visibility = 'public' + else: + self._visibility = 'private' + else: + self._visibility = 'public' + + if 'payable' in attributes: + self._payable = attributes['payable'] + + def _new_node(self, node_type, src): + node = NodeSolc(node_type, self._counter_nodes) + node.set_offset(src, self.slither) + self._counter_nodes += 1 + node.set_function(self) + self._nodes.append(node) + return node + + def _parse_if(self, ifStatement, node): + # IfStatement = 'if' '(' Expression ')' Statement ( 'else' Statement )? + falseStatement = None + + if self.is_compact_ast: + condition = ifStatement['condition'] + # Note: check if the expression could be directly + # parsed here + condition_node = self._new_node(NodeType.IF, ifStatement['src']) + condition_node.add_unparsed_expression(condition) + link_nodes(node, condition_node) + trueStatement = self._parse_statement(ifStatement['trueBody'], condition_node) + if ifStatement['falseBody']: + falseStatement = self._parse_statement(ifStatement['falseBody'], condition_node) + else: + children = ifStatement[self.get_children('children')] + condition = children[0] + # Note: check if the expression could be directly + # parsed here + condition_node = self._new_node(NodeType.IF, ifStatement['src']) + condition_node.add_unparsed_expression(condition) + link_nodes(node, condition_node) + trueStatement = self._parse_statement(children[1], condition_node) + if len(children) == 3: + falseStatement = self._parse_statement(children[2], condition_node) + + endIf_node = self._new_node(NodeType.ENDIF, ifStatement['src']) + link_nodes(trueStatement, endIf_node) + + if falseStatement: + link_nodes(falseStatement, endIf_node) + else: + link_nodes(condition_node, endIf_node) + return endIf_node + +# def _parse_if(self, ifStatement, node): +# # IfStatement = 'if' '(' Expression ')' Statement ( 'else' Statement )? +# +# children = ifStatement[self.get_children('children')] +# condition_node = self._new_node(NodeType.IF) +# #condition = parse_expression(children[0], self) +# condition = children[0] +# condition_node.add_unparsed_expression(condition) +# +# link_nodes(node, condition_node) +# +# trueStatement = self._parse_statement(children[1], condition_node) +# +# +# endIf_node = self._new_node(NodeType.ENDIF) +# link_nodes(trueStatement, endIf_node) +# +# if len(children) == 3: +# falseStatement = self._parse_statement(children[2], condition_node) +# +# link_nodes(falseStatement, endIf_node) +# +# else: +# link_nodes(condition_node, endIf_node) +# +# return endIf_node + + def _parse_while(self, whileStatement, node): + # WhileStatement = 'while' '(' Expression ')' Statement + + node_startWhile = self._new_node(NodeType.STARTLOOP, whileStatement['src']) + node_condition = self._new_node(NodeType.IFLOOP, whileStatement['src']) + + if self.is_compact_ast: + node_condition.add_unparsed_expression(whileStatement['condition']) + statement = self._parse_statement(whileStatement['body'], node_condition) + else: + children = whileStatement[self.get_children('children')] + expression = children[0] + node_condition.add_unparsed_expression(expression) + statement = self._parse_statement(children[1], node_condition) + + node_endWhile = self._new_node(NodeType.ENDLOOP, whileStatement['src']) + + link_nodes(node, node_startWhile) + link_nodes(node_startWhile, node_condition) + link_nodes(statement, node_condition) + link_nodes(node_condition, node_endWhile) + + return node_endWhile + + def _parse_for_compact_ast(self, statement, node): + body = statement['body'] + init_expression = statement['initializationExpression'] + condition = statement['condition'] + loop_expression = statement['loopExpression'] + + node_startLoop = self._new_node(NodeType.STARTLOOP, statement['src']) + node_endLoop = self._new_node(NodeType.ENDLOOP, statement['src']) + + if init_expression: + node_init_expression = self._parse_statement(init_expression, node) + link_nodes(node_init_expression, node_startLoop) + else: + link_nodes(node, node_startLoop) + + if condition: + node_condition = self._new_node(NodeType.IFLOOP, statement['src']) + node_condition.add_unparsed_expression(condition) + link_nodes(node_startLoop, node_condition) + link_nodes(node_condition, node_endLoop) + else: + node_condition = node_startLoop + + node_body = self._parse_statement(body, node_condition) + + if loop_expression: + node_LoopExpression = self._parse_statement(loop_expression, node_body) + link_nodes(node_LoopExpression, node_startLoop) + else: + link_nodes(node_body, node_startLoop) + + if not condition: + if not loop_expression: + # TODO: fix case where loop has no expression + link_nodes(node_startLoop, node_endLoop) + else: + link_nodes(node_LoopExpression, node_endLoop) + + return node_endLoop + + + def _parse_for(self, statement, node): + # ForStatement = 'for' '(' (SimpleStatement)? ';' (Expression)? ';' (ExpressionStatement)? ')' Statement + + # the handling of loop in the legacy ast is too complex + # to integrate the comapct ast + # its cleaner to do it separately + if self.is_compact_ast: + return self._parse_for_compact_ast(statement, node) + + hasInitExession = True + hasCondition = True + hasLoopExpression = True + + # Old solc version do not prevent in the attributes + # if the loop has a init value /condition or expression + # There is no way to determine that for(a;;) and for(;a;) are different with old solc + if 'attributes' in statement: + if 'initializationExpression' in statement: + if not statement['initializationExpression']: + hasInitExession = False + if 'condition' in statement: + if not statement['condition']: + hasCondition = False + if 'loopExpression' in statement: + if not statement['loopExpression']: + hasLoopExpression = False + + + node_startLoop = self._new_node(NodeType.STARTLOOP, statement['src']) + node_endLoop = self._new_node(NodeType.ENDLOOP, statement['src']) + + children = statement[self.get_children('children')] + + if hasInitExession: + if len(children) >= 2: + if children[0][self.get_key()] in ['VariableDefinitionStatement', + 'VariableDeclarationStatement', + 'ExpressionStatement']: + node_initExpression = self._parse_statement(children[0], node) + link_nodes(node_initExpression, node_startLoop) + else: + hasInitExession = False + else: + hasInitExession = False + + if not hasInitExession: + link_nodes(node, node_startLoop) + node_condition = node_startLoop + + if hasCondition: + if hasInitExession and len(children) >= 2: + candidate = children[1] + else: + candidate = children[0] + if candidate[self.get_key()] not in ['VariableDefinitionStatement', + 'VariableDeclarationStatement', + 'ExpressionStatement']: + node_condition = self._new_node(NodeType.IFLOOP, statement['src']) + #expression = parse_expression(candidate, self) + expression = candidate + node_condition.add_unparsed_expression(expression) + link_nodes(node_startLoop, node_condition) + link_nodes(node_condition, node_endLoop) + hasCondition = True + else: + hasCondition = False + + + node_statement = self._parse_statement(children[-1], node_condition) + + node_LoopExpression = node_statement + if hasLoopExpression: + if len(children) > 2: + if children[-2][self.get_key()] == 'ExpressionStatement': + node_LoopExpression = self._parse_statement(children[-2], node_statement) + if not hasCondition: + link_nodes(node_LoopExpression, node_endLoop) + + if not hasCondition and not hasLoopExpression: + link_nodes(node, node_endLoop) + + link_nodes(node_LoopExpression, node_startLoop) + + return node_endLoop + + def _parse_dowhile(self, doWhilestatement, node): + + node_startDoWhile = self._new_node(NodeType.STARTLOOP, doWhilestatement['src']) + node_condition = self._new_node(NodeType.IFLOOP, doWhilestatement['src']) + + if self.is_compact_ast: + node_condition.add_unparsed_expression(doWhilestatement['condition']) + statement = self._parse_statement(doWhilestatement['body'], node_condition) + else: + children = doWhilestatement[self.get_children('children')] + # same order in the AST as while + expression = children[0] + node_condition.add_unparsed_expression(expression) + statement = self._parse_statement(children[1], node_condition) + + node_endDoWhile = self._new_node(NodeType.ENDLOOP, doWhilestatement['src']) + + link_nodes(node, node_startDoWhile) + link_nodes(node_startDoWhile, statement) + link_nodes(statement, node_condition) + link_nodes(node_condition, node_endDoWhile) + + return node_endDoWhile + + def _parse_variable_definition(self, statement, node): + try: + local_var = LocalVariableSolc(statement) + local_var.set_function(self) + local_var.set_offset(statement['src'], self.contract.slither) + + self._variables[local_var.name] = local_var + #local_var.analyze(self) + + new_node = self._new_node(NodeType.VARIABLE, statement['src']) + new_node.add_variable_declaration(local_var) + link_nodes(node, new_node) + return new_node + except MultipleVariablesDeclaration: + # Custom handling of var (a,b) = .. style declaration + if self.is_compact_ast: + variables = statement['declarations'] + count = len(variables) + + if statement['initialValue']['nodeType'] == 'TupleExpression': + inits = statement['initialValue']['components'] + i = 0 + new_node = node + for variable in variables: + init = inits[i] + src = variable['src'] + i = i+1 + + new_statement = {'nodeType':'VariableDefinitionStatement', + 'src': src, + 'declarations':[variable], + 'initialValue':init} + new_node = self._parse_variable_definition(new_statement, new_node) + + else: + # If we have + # var (a, b) = f() + # we can split in multiple declarations, without init + # Then we craft one expression that does the assignment + variables = [] + i = 0 + new_node = node + for variable in statement['declarations']: + i = i+1 + if variable: + src = variable['src'] + # Create a fake statement to be consistent + new_statement = {'nodeType':'VariableDefinitionStatement', + 'src': src, + 'declarations':[variable]} + variables.append(variable) + + new_node = self._parse_variable_definition_init_tuple(new_statement, + i, + new_node) + + var_identifiers = [] + # craft of the expression doing the assignement + for v in variables: + identifier = { + 'nodeType':'Identifier', + 'src': v['src'], + 'name': v['name'], + 'typeDescriptions': { + 'typeString':v['typeDescriptions']['typeString'] + } + } + var_identifiers.append(identifier) + + tuple_expression = {'nodeType':'TupleExpression', + 'src': statement['src'], + 'components':var_identifiers} + + expression = { + 'nodeType' : 'Assignment', + 'src':statement['src'], + 'operator': '=', + 'type':'tuple()', + 'leftHandSide': tuple_expression, + 'rightHandSide': statement['initialValue'], + 'typeDescriptions': {'typeString':'tuple()'} + } + node = new_node + new_node = self._new_node(NodeType.EXPRESSION, statement['src']) + new_node.add_unparsed_expression(expression) + link_nodes(node, new_node) + + + else: + count = 0 + children = statement[self.get_children('children')] + child = children[0] + while child[self.get_key()] == 'VariableDeclaration': + count = count +1 + child = children[count] + + assert len(children) == (count + 1) + tuple_vars = children[count] + + + variables_declaration = children[0:count] + i = 0 + new_node = node + if tuple_vars[self.get_key()] == 'TupleExpression': + assert len(tuple_vars[self.get_children('children')]) == count + for variable in variables_declaration: + init = tuple_vars[self.get_children('children')][i] + src = variable['src'] + i = i+1 + # Create a fake statement to be consistent + new_statement = {self.get_key():'VariableDefinitionStatement', + 'src': src, + self.get_children('children'):[variable, init]} + + new_node = self._parse_variable_definition(new_statement, new_node) + else: + # If we have + # var (a, b) = f() + # we can split in multiple declarations, without init + # Then we craft one expression that does the assignment + assert tuple_vars[self.get_key()] in ['FunctionCall', 'Conditional'] + variables = [] + for variable in variables_declaration: + src = variable['src'] + i = i+1 + # Create a fake statement to be consistent + new_statement = {self.get_key():'VariableDefinitionStatement', + 'src': src, + self.get_children('children'):[variable]} + variables.append(variable) + + new_node = self._parse_variable_definition_init_tuple(new_statement, i, new_node) + var_identifiers = [] + # craft of the expression doing the assignement + for v in variables: + identifier = { + self.get_key() : 'Identifier', + 'src': v['src'], + 'attributes': { + 'value': v['attributes'][self.get_key()], + 'type': v['attributes']['type']} + } + var_identifiers.append(identifier) + + expression = { + self.get_key() : 'Assignment', + 'src':statement['src'], + 'attributes': {'operator': '=', + 'type':'tuple()'}, + self.get_children('children'): + [{self.get_key(): 'TupleExpression', + 'src': statement['src'], + self.get_children('children'): var_identifiers}, + tuple_vars]} + node = new_node + new_node = self._new_node(NodeType.EXPRESSION, statement['src']) + new_node.add_unparsed_expression(expression) + link_nodes(node, new_node) + + + return new_node + + def _parse_variable_definition_init_tuple(self, statement, index, node): + local_var = LocalVariableInitFromTupleSolc(statement, index) + #local_var = LocalVariableSolc(statement[self.get_children('children')][0], statement[self.get_children('children')][1::]) + local_var.set_function(self) + local_var.set_offset(statement['src'], self.contract.slither) + + self._variables[local_var.name] = local_var +# local_var.analyze(self) + + new_node = self._new_node(NodeType.VARIABLE, statement['src']) + new_node.add_variable_declaration(local_var) + link_nodes(node, new_node) + return new_node + + + def _parse_statement(self, statement, node): + """ + + Return: + node + """ + # Statement = IfStatement | WhileStatement | ForStatement | Block | InlineAssemblyStatement | + # ( DoWhileStatement | PlaceholderStatement | Continue | Break | Return | + # Throw | EmitStatement | SimpleStatement ) ';' + # SimpleStatement = VariableDefinition | ExpressionStatement + + name = statement[self.get_key()] + # SimpleStatement = VariableDefinition | ExpressionStatement + if name == 'IfStatement': + node = self._parse_if(statement, node) + elif name == 'WhileStatement': + node = self._parse_while(statement, node) + elif name == 'ForStatement': + node = self._parse_for(statement, node) + elif name == 'Block': + node = self._parse_block(statement, node) + elif name == 'InlineAssembly': + break_node = self._new_node(NodeType.ASSEMBLY, statement['src']) + link_nodes(node, break_node) + node = break_node + elif name == 'DoWhileStatement': + node = self._parse_dowhile(statement, node) + # For Continue / Break / Return / Throw + # The is fixed later + elif name == 'Continue': + continue_node = self._new_node(NodeType.CONTINUE, statement['src']) + link_nodes(node, continue_node) + node = continue_node + elif name == 'Break': + break_node = self._new_node(NodeType.BREAK, statement['src']) + link_nodes(node, break_node) + node = break_node + elif name == 'Return': + return_node = self._new_node(NodeType.RETURN, statement['src']) + link_nodes(node, return_node) + if self.is_compact_ast: + if statement['expression']: + return_node.add_unparsed_expression(statement['expression']) + else: + if self.get_children('children') in statement and statement[self.get_children('children')]: + assert len(statement[self.get_children('children')]) == 1 + expression = statement[self.get_children('children')][0] + return_node.add_unparsed_expression(expression) + node = return_node + elif name == 'Throw': + throw_node = self._new_node(NodeType.THROW, statement['src']) + link_nodes(node, throw_node) + node = throw_node + elif name == 'EmitStatement': + #expression = parse_expression(statement[self.get_children('children')][0], self) + if self.is_compact_ast: + expression = statement['eventCall'] + else: + expression = statement[self.get_children('children')][0] + new_node = self._new_node(NodeType.EXPRESSION, statement['src']) + new_node.add_unparsed_expression(expression) + link_nodes(node, new_node) + node = new_node + elif name in ['VariableDefinitionStatement', 'VariableDeclarationStatement']: + node = self._parse_variable_definition(statement, node) + elif name == 'ExpressionStatement': + #assert len(statement[self.get_children('expression')]) == 1 + #assert not 'attributes' in statement + #expression = parse_expression(statement[self.get_children('children')][0], self) + if self.is_compact_ast: + expression = statement[self.get_children('expression')] + else: + expression = statement[self.get_children('expression')][0] + new_node = self._new_node(NodeType.EXPRESSION, statement['src']) + new_node.add_unparsed_expression(expression) + link_nodes(node, new_node) + node = new_node + else: + logger.error('Statement not parsed %s'%name) + exit(-1) + + return node + + def _parse_block(self, block, node): + ''' + Return: + Node + ''' + assert block[self.get_key()] == 'Block' + + if self.is_compact_ast: + statements = block['statements'] + else: + statements = block[self.get_children('children')] + + for statement in statements: + node = self._parse_statement(statement, node) + return node + + def _parse_cfg(self, cfg): + + assert cfg[self.get_key()] == 'Block' + + node = self._new_node(NodeType.ENTRYPOINT, cfg['src']) + self._entry_point = node + + if self.is_compact_ast: + statements = cfg['statements'] + else: + statements = cfg[self.get_children('children')] + + if not statements: + self._is_empty = True + else: + self._is_empty = False + self._parse_block(cfg, node) + self._remove_incorrect_edges() + self._remove_alone_endif() + + def _find_end_loop(self, node, visited): + if node in visited: + return None + + if node.type == NodeType.ENDLOOP: + return node + + visited = visited + [node] + for son in node.sons: + ret = self._find_end_loop(son, visited) + if ret: + return ret + + return None + + def _find_start_loop(self, node, visited): + if node in visited: + return None + + if node.type == NodeType.STARTLOOP: + return node + + visited = visited + [node] + for father in node.fathers: + ret = self._find_start_loop(father, visited) + if ret: + return ret + + return None + + def _fix_break_node(self, node): + end_node = self._find_end_loop(node, []) + + if not end_node: + logger.error('Break in no-loop context {}'.format(node)) + exit(-1) + + for son in node.sons: + son.remove_father(node) + node.set_sons([end_node]) + end_node.add_father(node) + + def _fix_continue_node(self, node): + start_node = self._find_start_loop(node, []) + + if not start_node: + logger.error('Continue in no-loop context {}'.format(node.nodeId())) + exit(-1) + + for son in node.sons: + son.remove_father(node) + node.set_sons([start_node]) + start_node.add_father(node) + + def _remove_incorrect_edges(self): + for node in self._nodes: + if node.type in [NodeType.RETURN, NodeType.THROW]: + for son in node.sons: + son.remove_father(node) + node.set_sons([]) + if node.type in [NodeType.BREAK]: + self._fix_break_node(node) + if node.type in [NodeType.CONTINUE]: + self._fix_continue_node(node) + + def _remove_alone_endif(self): + """ + Can occur on: + if(..){ + return + } + else{ + return + } + + Iterate until a fix point to remove the ENDIF node + creates on the following pattern + if(){ + return + } + else if(){ + return + } + """ + prev_nodes = [] + while set(prev_nodes) != set(self.nodes): + prev_nodes = self.nodes + to_remove = [] + for node in self.nodes: + if node.type == NodeType.ENDIF and not node.fathers: + for son in node.sons: + son.remove_father(node) + node.set_sons([]) + to_remove.append(node) + self._nodes = [n for n in self.nodes if not n in to_remove] +# + def _parse_params(self, params): + assert params[self.get_key()] == 'ParameterList' + + if self.is_compact_ast: + params = params['parameters'] + else: + params = params[self.get_children('children')] + + for param in params: + assert param[self.get_key()] == 'VariableDeclaration' + + local_var = LocalVariableSolc(param) + + local_var.set_function(self) + local_var.set_offset(param['src'], self.contract.slither) + local_var.analyze(self) + + # see https://solidity.readthedocs.io/en/v0.4.24/types.html?highlight=storage%20location#data-location + if local_var.location == 'default': + local_var.set_location('memory') + + self._variables[local_var.name] = local_var + self._parameters.append(local_var) + + def _parse_returns(self, returns): + + assert returns[self.get_key()] == 'ParameterList' + + if self.is_compact_ast: + returns = returns['parameters'] + else: + returns = returns[self.get_children('children')] + + for ret in returns: + assert ret[self.get_key()] == 'VariableDeclaration' + + local_var = LocalVariableSolc(ret) + + local_var.set_function(self) + local_var.set_offset(ret['src'], self.contract.slither) + local_var.analyze(self) + + # see https://solidity.readthedocs.io/en/v0.4.24/types.html?highlight=storage%20location#data-location + if local_var.location == 'default': + local_var.set_location('memory') + + self._variables[local_var.name] = local_var + self._returns.append(local_var) + + + def _parse_modifier(self, modifier): + m = parse_expression(modifier, self) + self._expression_modifiers.append(m) + self._modifiers += [m for m in ExportValues(m).result() if isinstance(m, Function)] + + + def analyze_params(self): + # Can be re-analyzed due to inheritance + if self._params_was_analyzed: + return + + self._params_was_analyzed = True + + self._analyze_attributes() + + if self.is_compact_ast: + params = self._functionNotParsed['parameters'] + returns = self._functionNotParsed['returnParameters'] + else: + children = self._functionNotParsed[self.get_children('children')] + params = children[0] + returns = children[1] + + if params: + self._parse_params(params) + if returns: + self._parse_returns(returns) + + def analyze_content(self): + if self._content_was_analyzed: + return + + self._content_was_analyzed = True + + if self.is_compact_ast: + body = self._functionNotParsed['body'] + + if body and body[self.get_key()] == 'Block': + self._is_implemented = True + self._parse_cfg(body) + + for modifier in self._functionNotParsed['modifiers']: + self._parse_modifier(modifier) + + else: + children = self._functionNotParsed[self.get_children('children')] + self._is_implemented = False + for child in children[2:]: + if child[self.get_key()] == 'Block': + self._is_implemented = True + self._parse_cfg(child) + continue + + assert child[self.get_key()] == 'ModifierInvocation' + + self._parse_modifier(child) + + for local_vars in self.variables: + local_vars.analyze(self) + + for node in self.nodes: + node.analyze_expressions(self) + + ternary_found = True + while ternary_found: + ternary_found = False + for node in self.nodes: + has_cond = HasConditional(node.expression) + if has_cond.result(): + st = SplitTernaryExpression(node.expression) + condition = st.condition + assert condition + true_expr = st.true_expression + false_expr = st.false_expression + self.split_ternary_node(node, condition, true_expr, false_expr) + ternary_found = True + break + self._remove_alone_endif() + + + def convert_expression_to_slithir(self): + for node in self.nodes: + node.slithir_generation() + transform_slithir_vars_to_ssa(self) + self._analyze_read_write() + self._analyze_calls() + + + def split_ternary_node(self, node, condition, true_expr, false_expr): + condition_node = self._new_node(NodeType.IF, node.source_mapping) + condition_node.add_expression(condition) + condition_node.analyze_expressions(self) + + true_node = self._new_node(node.type, node.source_mapping) + if node.type == NodeType.VARIABLE: + true_node.add_variable_declaration(node.variable_declaration) + true_node.add_expression(true_expr) + true_node.analyze_expressions(self) + + false_node = self._new_node(node.type, node.source_mapping) + if node.type == NodeType.VARIABLE: + false_node.add_variable_declaration(node.variable_declaration) + false_node.add_expression(false_expr) + false_node.analyze_expressions(self) + + endif_node = self._new_node(NodeType.ENDIF, node.source_mapping) + + for father in node.fathers: + father.remove_son(node) + father.add_son(condition_node) + condition_node.add_father(father) + + for son in node.sons: + son.remove_father(node) + son.add_father(endif_node) + endif_node.add_son(son) + + link_nodes(condition_node, true_node) + link_nodes(condition_node, false_node) + + + if not true_node.type in [NodeType.THROW, NodeType.RETURN]: + link_nodes(true_node, endif_node) + if not false_node.type in [NodeType.THROW, NodeType.RETURN]: + link_nodes(false_node, endif_node) + + self._nodes = [n for n in self._nodes if n.node_id != node.node_id] + + diff --git a/slither/solcParsing/declarations/modifierSolc.py b/slither/solc_parsing/declarations/modifier.py similarity index 56% rename from slither/solcParsing/declarations/modifierSolc.py rename to slither/solc_parsing/declarations/modifier.py index 09b57851a..1341779c3 100644 --- a/slither/solcParsing/declarations/modifierSolc.py +++ b/slither/solc_parsing/declarations/modifier.py @@ -2,9 +2,9 @@ Event module """ from slither.core.declarations.modifier import Modifier -from slither.solcParsing.declarations.functionSolc import FunctionSolc +from slither.solc_parsing.declarations.function import FunctionSolc -from slither.core.cfg.nodeType import NodeType +from slither.core.cfg.node import NodeType from slither.core.cfg.node import link_nodes class ModifierSolc(Modifier, FunctionSolc): @@ -19,9 +19,11 @@ class ModifierSolc(Modifier, FunctionSolc): self._analyze_attributes() - children = self._functionNotParsed['children'] - - params = children[0] + if self.is_compact_ast: + params = self._functionNotParsed['parameters'] + else: + children = self._functionNotParsed['children'] + params = children[0] if params: self._parse_params(params) @@ -32,15 +34,24 @@ class ModifierSolc(Modifier, FunctionSolc): self._content_was_analyzed = True - children = self._functionNotParsed['children'] - self._isImplemented = False - if len(children) > 1: - assert len(children) == 2 - block = children[1] - assert block['name'] == 'Block' + if self.is_compact_ast: + body = self._functionNotParsed['body'] + + if body and body[self.get_key()] == 'Block': + self._is_implemented = True + self._parse_cfg(body) + + else: + children = self._functionNotParsed['children'] + self._isImplemented = False - self._parse_cfg(block) + if len(children) > 1: + assert len(children) == 2 + block = children[1] + assert block['name'] == 'Block' + self._isImplemented = True + self._parse_cfg(block) for local_vars in self.variables: local_vars.analyze(self) @@ -52,9 +63,9 @@ class ModifierSolc(Modifier, FunctionSolc): self._analyze_calls() def _parse_statement(self, statement, node): - name = statement['name'] + name = statement[self.get_key()] if name == 'PlaceholderStatement': - placeholder_node = self._new_node(NodeType.PLACEHOLDER) + placeholder_node = self._new_node(NodeType.PLACEHOLDER, statement['src']) link_nodes(node, placeholder_node) return placeholder_node return super(ModifierSolc, self)._parse_statement(statement, node) diff --git a/slither/solcParsing/declarations/structureSolc.py b/slither/solc_parsing/declarations/structure.py similarity index 81% rename from slither/solcParsing/declarations/structureSolc.py rename to slither/solc_parsing/declarations/structure.py index 826f597a5..0026f451e 100644 --- a/slither/solcParsing/declarations/structureSolc.py +++ b/slither/solc_parsing/declarations/structure.py @@ -1,7 +1,7 @@ """ Structure module """ -from slither.solcParsing.variables.structureVariableSolc import StructureVariableSolc +from slither.solc_parsing.variables.structure_variable import StructureVariableSolc from slither.core.declarations.structure import Structure class StructureSolc(Structure): @@ -23,7 +23,7 @@ class StructureSolc(Structure): for elem_to_parse in self._elemsNotParsed: elem = StructureVariableSolc(elem_to_parse) elem.set_structure(self) - elem.set_offset(elem_to_parse['src']) + elem.set_offset(elem_to_parse['src'], self.contract.slither) elem.analyze(self.contract) diff --git a/slither/solc_parsing/expressions/__init__.py b/slither/solc_parsing/expressions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slither/solc_parsing/expressions/expression_parsing.py b/slither/solc_parsing/expressions/expression_parsing.py new file mode 100644 index 000000000..02d095cc8 --- /dev/null +++ b/slither/solc_parsing/expressions/expression_parsing.py @@ -0,0 +1,545 @@ +import logging +import re +from slither.core.expressions.unary_operation import UnaryOperation, UnaryOperationType +from slither.core.expressions.binary_operation import BinaryOperation, BinaryOperationType +from slither.core.expressions.literal import Literal +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.super_identifier import SuperIdentifier +from slither.core.expressions.index_access import IndexAccess +from slither.core.expressions.member_access import MemberAccess +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.conditional_expression import ConditionalExpression +from slither.core.expressions.assignment_operation import AssignmentOperation, AssignmentOperationType +from slither.core.expressions.type_conversion import TypeConversion +from slither.core.expressions.call_expression import CallExpression +from slither.core.expressions.super_call_expression import SuperCallExpression +from slither.core.expressions.new_array import NewArray +from slither.core.expressions.new_contract import NewContract +from slither.core.expressions.new_elementary_type import NewElementaryType +from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression + +from slither.solc_parsing.solidity_types.type_parsing import parse_type, UnknownType + +from slither.core.declarations.contract import Contract +from slither.core.declarations.function import Function + +from slither.core.declarations.solidity_variables import SOLIDITY_VARIABLES, SOLIDITY_FUNCTIONS, SOLIDITY_VARIABLES_COMPOSED +from slither.core.declarations.solidity_variables import SolidityVariable, SolidityFunction, SolidityVariableComposed, solidity_function_signature + +from slither.core.solidity_types import ElementaryType, ArrayType, MappingType, FunctionType + + +logger = logging.getLogger("ExpressionParsing") + +class VariableNotFound(Exception): pass + +def get_pointer_name(variable): + curr_type = variable.type + while(isinstance(curr_type, (ArrayType, MappingType))): + if isinstance(curr_type, ArrayType): + curr_type = curr_type.type + else: + assert isinstance(curr_type, MappingType) + curr_type = curr_type.type_to + + if isinstance(curr_type, (FunctionType)): + return variable.name + curr_type.parameters_signature + return None + + +def find_variable(var_name, caller_context): + + if isinstance(caller_context, Contract): + function = None + contract = caller_context + elif isinstance(caller_context, Function): + function = caller_context + contract = function.contract + else: + logger.error('Incorrect caller context') + exit(-1) + + if function: + func_variables = function.variables_as_dict() + if var_name in func_variables: + return func_variables[var_name] + # A local variable can be a pointer + # for example + # function test(function(uint) internal returns(bool) t) interna{ + # Will have a local variable t which will match the signature + # t(uint256) + func_variables_ptr = {get_pointer_name(f) : f for f in function.variables} + if var_name and var_name in func_variables_ptr: + return func_variables_ptr[var_name] + + contract_variables = contract.variables_as_dict() + if var_name in contract_variables: + return contract_variables[var_name] + + # A state variable can be a pointer + conc_variables_ptr = {get_pointer_name(f) : f for f in contract.variables} + if var_name and var_name in conc_variables_ptr: + return conc_variables_ptr[var_name] + + + functions = contract.functions_as_dict() + if var_name in functions: + return functions[var_name] + + modifiers = contract.modifiers_as_dict() + if var_name in modifiers: + return modifiers[var_name] + + structures = contract.structures_as_dict() + if var_name in structures: + return structures[var_name] + + events = contract.events_as_dict() + if var_name in events: + return events[var_name] + + enums = contract.enums_as_dict() + if var_name in enums: + return enums[var_name] + + # If the enum is refered as its name rather than its canonicalName + enums = {e.name: e for e in contract.enums} + if var_name in enums: + return enums[var_name] + + # Could refer to any enum + all_enums = [c.enums_as_dict() for c in contract.slither.contracts] + all_enums = {k: v for d in all_enums for k, v in d.items()} + if var_name in all_enums: + return all_enums[var_name] + + if var_name in SOLIDITY_VARIABLES: + return SolidityVariable(var_name) + + if var_name in SOLIDITY_FUNCTIONS: + return SolidityFunction(var_name) + + contracts = contract.slither.contracts_as_dict() + if var_name in contracts: + return contracts[var_name] + + raise VariableNotFound('Variable not found: {}'.format(var_name)) + + +def parse_call(expression, caller_context): + + if caller_context.is_compact_ast: + attributes = expression + type_conversion = expression['kind'] == 'typeConversion' + type_return = attributes['typeDescriptions']['typeString'] + + else: + attributes = expression['attributes'] + type_conversion = attributes['type_conversion'] + type_return = attributes['type'] + + if type_conversion: + type_call = parse_type(UnknownType(type_return), caller_context) + + + if caller_context.is_compact_ast: + type_info = expression['expression'] + assert len(expression['arguments']) == 1 + expression_to_parse = expression['arguments'][0] + else: + children = expression['children'] + assert len(children) == 2 + type_info = children[0] + expression_to_parse = children[1] + assert type_info['name'] in ['ElementaryTypenameExpression', + 'ElementaryTypeNameExpression', + 'Identifier', + 'TupleExpression', + 'IndexAccess', + 'MemberAccess'] + + expression = parse_expression(expression_to_parse, caller_context) + t = TypeConversion(expression, type_call) + return t + + if caller_context.is_compact_ast: + called = parse_expression(expression['expression'], caller_context) + arguments = [] + if expression['arguments']: + arguments = [parse_expression(a, caller_context) for a in expression['arguments']] + else: + children = expression['children'] + called = parse_expression(children[0], caller_context) + arguments = [parse_expression(a, caller_context) for a in children[1::]] + + if isinstance(called, SuperCallExpression): + return SuperCallExpression(called, arguments, type_return) + return CallExpression(called, arguments, type_return) + +def parse_super_name(expression, is_compact_ast): + if is_compact_ast: + assert expression['nodeType'] == 'MemberAccess' + attributes = expression + base_name = expression['memberName'] + arguments = expression['typeDescriptions']['typeString'] + else: + assert expression['name'] == 'MemberAccess' + attributes = expression['attributes'] + base_name = attributes['member_name'] + arguments = attributes['type'] + + assert arguments.startswith('function ') + # remove function (...() + arguments = arguments[len('function '):] + + arguments = filter_name(arguments) + if ' ' in arguments: + arguments = arguments[:arguments.find(' ')] + + return base_name+arguments + +def filter_name(value): + value = value.replace(' memory', '') + value = value.replace(' storage', '') + value = value.replace(' external', '') + value = value.replace(' internal', '') + value = value.replace('struct ', '') + value = value.replace('contract ', '') + value = value.replace('enum ', '') + value = value.replace(' ref', '') + value = value.replace(' pointer', '') + value = value.replace(' pure', '') + value = value.replace(' view', '') + value = value.replace(' constant', '') + value = value.replace('function (', 'function(') + value = value.replace('returns (', 'returns(') + + # remove the text remaining after functio(...) + # which should only be ..returns(...) + # nested parenthesis so we use a system of counter on parenthesis + idx = value.find('(') + if idx: + counter = 1 + max_idx = len(value) + while counter: + assert idx < max_idx + idx = idx +1 + if value[idx] == '(': + counter += 1 + elif value[idx] == ')': + counter -= 1 + value = value[:idx+1] + return value + +def parse_expression(expression, caller_context): + """ + + Returns: + str: expression + """ + # Expression + # = Expression ('++' | '--') + # | NewExpression + # | IndexAccess + # | MemberAccess + # | FunctionCall + # | '(' Expression ')' + # | ('!' | '~' | 'delete' | '++' | '--' | '+' | '-') Expression + # | Expression '**' Expression + # | Expression ('*' | '/' | '%') Expression + # | Expression ('+' | '-') Expression + # | Expression ('<<' | '>>') Expression + # | Expression '&' Expression + # | Expression '^' Expression + # | Expression '|' Expression + # | Expression ('<' | '>' | '<=' | '>=') Expression + # | Expression ('==' | '!=') Expression + # | Expression '&&' Expression + # | Expression '||' Expression + # | Expression '?' Expression ':' Expression + # | Expression ('=' | '|=' | '^=' | '&=' | '<<=' | '>>=' | '+=' | '-=' | '*=' | '/=' | '%=') Expression + # | PrimaryExpression + + # The AST naming does not follow the spec + name = expression[caller_context.get_key()] + is_compact_ast = caller_context.is_compact_ast + + if name == 'UnaryOperation': + if is_compact_ast: + attributes = expression + else: + attributes = expression['attributes'] + assert 'prefix' in attributes + operation_type = UnaryOperationType.get_type(attributes['operator'], attributes['prefix']) + + if is_compact_ast: + expression = parse_expression(expression['subExpression'], caller_context) + else: + assert len(expression['children']) == 1 + expression = parse_expression(expression['children'][0], caller_context) + unary_op = UnaryOperation(expression, operation_type) + return unary_op + + elif name == 'BinaryOperation': + if is_compact_ast: + attributes = expression + else: + attributes = expression['attributes'] + operation_type = BinaryOperationType.get_type(attributes['operator']) + + if is_compact_ast: + left_expression = parse_expression(expression['leftExpression'], caller_context) + right_expression = parse_expression(expression['rightExpression'], caller_context) + else: + assert len(expression['children']) == 2 + left_expression = parse_expression(expression['children'][0], caller_context) + right_expression = parse_expression(expression['children'][1], caller_context) + binary_op = BinaryOperation(left_expression, right_expression, operation_type) + return binary_op + + elif name == 'FunctionCall': + return parse_call(expression, caller_context) + + elif name == 'TupleExpression': + """ + For expression like + (a,,c) = (1,2,3) + the AST provides only two children in the left side + We check the type provided (tuple(uint256,,uint256)) + To determine that there is an empty variable + Otherwhise we would not be able to determine that + a = 1, c = 3, and 2 is lost + + Note: this is only possible with Solidity >= 0.4.12 + """ + if is_compact_ast: + expressions = [parse_expression(e, caller_context) if e else None for e in expression['components']] + else: + if 'children' not in expression : + attributes = expression['attributes'] + components = attributes['components'] + expressions = [parse_expression(c, caller_context) if c else None for c in components] + else: + expressions = [parse_expression(e, caller_context) for e in expression['children']] + # Add none for empty tuple items + if "attributes" in expression: + if "type" in expression['attributes']: + t = expression['attributes']['type'] + if ',,' in t or '(,' in t or ',)' in t: + t = t[len('tuple('):-1] + elems = t.split(',') + for idx in range(len(elems)): + if elems[idx] == '': + expressions.insert(idx, None) + t = TupleExpression(expressions) + return t + + elif name == 'Conditional': + if is_compact_ast: + if_expression = parse_expression(expression['condition'], caller_context) + then_expression = parse_expression(expression['trueExpression'], caller_context) + else_expression = parse_expression(expression['falseExpression'], caller_context) + else: + children = expression['children'] + assert len(children) == 3 + if_expression = parse_expression(children[0], caller_context) + then_expression = parse_expression(children[1], caller_context) + else_expression = parse_expression(children[2], caller_context) + conditional = ConditionalExpression(if_expression, then_expression, else_expression) + return conditional + + elif name == 'Assignment': + if is_compact_ast: + left_expression = parse_expression(expression['leftHandSide'], caller_context) + right_expression = parse_expression(expression['rightHandSide'], caller_context) + + operation_type = AssignmentOperationType.get_type(expression['operator']) + + operation_return_type = expression['typeDescriptions']['typeString'] + else: + attributes = expression['attributes'] + children = expression['children'] + assert len(expression['children']) == 2 + left_expression = parse_expression(children[0], caller_context) + right_expression = parse_expression(children[1], caller_context) + + operation_type = AssignmentOperationType.get_type(attributes['operator']) + operation_return_type = attributes['type'] + + assignement = AssignmentOperation(left_expression, right_expression, operation_type, operation_return_type) + return assignement + + elif name == 'Literal': + assert 'children' not in expression + + if is_compact_ast: + value = expression['value'] + if not value: + value = '0x'+expression['hexValue'] + else: + value = expression['attributes']['value'] + if value is None: + # for literal declared as hex + # see https://solidity.readthedocs.io/en/v0.4.25/types.html?highlight=hex#hexadecimal-literals + assert 'hexvalue' in expression['attributes'] + value = '0x'+expression['attributes']['hexvalue'] + literal = Literal(value) + return literal + + elif name == 'Identifier': + assert 'children' not in expression + + t = None + + if caller_context.is_compact_ast: + value = expression['name'] + t = expression['typeDescriptions']['typeString'] + else: + value = expression['attributes']['value'] + if 'type' in expression['attributes']: + t = expression['attributes']['type'] + + if t: + found = re.findall('[struct|enum|function|modifier] \(([\[\] ()a-zA-Z0-9\.,_]*)\)', t) + assert len(found) <= 1 + if found: + value = value+'('+found[0]+')' + value = filter_name(value) + + var = find_variable(value, caller_context) + + identifier = Identifier(var) + return identifier + + elif name == 'IndexAccess': + if is_compact_ast: + index_type = expression['typeDescriptions']['typeString'] + left = expression['baseExpression'] + right = expression['indexExpression'] + else: + index_type = expression['attributes']['type'] + children = expression['children'] + assert len(children) == 2 + left = children[0] + right = children[1] + left_expression = parse_expression(left, caller_context) + right_expression = parse_expression(right, caller_context) + index = IndexAccess(left_expression, right_expression, index_type) + return index + + elif name == 'MemberAccess': + if caller_context.is_compact_ast: + member_name = expression['memberName'] + member_type = expression['typeDescriptions']['typeString'] + member_expression = parse_expression(expression['expression'], caller_context) + else: + member_name = expression['attributes']['member_name'] + member_type = expression['attributes']['type'] + children = expression['children'] + assert len(children) == 1 + member_expression = parse_expression(children[0], caller_context) + if str(member_expression) == 'super': + super_name = parse_super_name(expression, is_compact_ast) + if isinstance(caller_context, Contract): + inheritance = caller_context.inheritance + else: + assert isinstance(caller_context, Function) + inheritance = caller_context.contract.inheritance + var = None + for father in inheritance: + try: + var = find_variable(super_name, father) + break + except VariableNotFound: + continue + if var is None: + raise VariableNotFound('Variable not found: {}'.format(super_name)) + return SuperIdentifier(var) + member_access = MemberAccess(member_name, member_type, member_expression) + if str(member_access) in SOLIDITY_VARIABLES_COMPOSED: + return Identifier(SolidityVariableComposed(str(member_access))) + return member_access + + elif name == 'ElementaryTypeNameExpression': + # nop exression + # uint; + if is_compact_ast: + value = expression['typeName'] + else: + assert 'children' not in expression + value = expression['attributes']['value'] + t = parse_type(UnknownType(value), caller_context) + + return ElementaryTypeNameExpression(t) + + + # NewExpression is not a root expression, it's always the child of another expression + elif name == 'NewExpression': + + if is_compact_ast: + type_name = expression['typeName'] + else: + children = expression['children'] + assert len(children) == 1 + type_name = children[0] + + if type_name[caller_context.get_key()] == 'ArrayTypeName': + depth = 0 + while type_name[caller_context.get_key()] == 'ArrayTypeName': + # Note: dont conserve the size of the array if provided + # We compute it directly + if is_compact_ast: + type_name = type_name['baseType'] + else: + type_name = type_name['children'][0] + depth += 1 + if type_name[caller_context.get_key()] == 'ElementaryTypeName': + if is_compact_ast: + array_type = ElementaryType(type_name['name']) + else: + array_type = ElementaryType(type_name['attributes']['name']) + elif type_name[caller_context.get_key()] == 'UserDefinedTypeName': + if is_compact_ast: + array_type = parse_type(UnknownType(type_name['name']), caller_context) + else: + array_type = parse_type(UnknownType(type_name['attributes']['name']), caller_context) + else: + logger.error('Incorrect type array {}'.format(type_name)) + exit(-1) + array = NewArray(depth, array_type) + return array + + if type_name[caller_context.get_key()] == 'ElementaryTypeName': + if is_compact_ast: + elem_type = ElementaryType(type_name['name']) + else: + elem_type = ElementaryType(type_name['attributes']['name']) + new_elem = NewElementaryType(elem_type) + return new_elem + + assert type_name[caller_context.get_key()] == 'UserDefinedTypeName' + + if is_compact_ast: + contract_name = type_name['name'] + else: + contract_name = type_name['attributes']['name'] + new = NewContract(contract_name) + return new + + elif name == 'ModifierInvocation': + + if is_compact_ast: + called = parse_expression(expression['modifierName'], caller_context) + arguments = [] + if expression['arguments']: + arguments = [parse_expression(a, caller_context) for a in expression['arguments']] + else: + children = expression['children'] + called = parse_expression(children[0], caller_context) + arguments = [parse_expression(a, caller_context) for a in children[1::]] + + call = CallExpression(called, arguments, 'Modifier') + return call + + logger.error('Expression not parsed %s'%name) + exit(-1) diff --git a/slither/solcParsing/slitherSolc.py b/slither/solc_parsing/slitherSolc.py similarity index 52% rename from slither/solcParsing/slitherSolc.py rename to slither/solc_parsing/slitherSolc.py index e56d01984..ec7383fe2 100644 --- a/slither/solcParsing/slitherSolc.py +++ b/slither/solc_parsing/slitherSolc.py @@ -1,11 +1,14 @@ +import os import json import re import logging logger = logging.getLogger("SlitherSolcParsing") -from slither.solcParsing.declarations.contractSolc04 import ContractSolc04 -from slither.core.slitherCore import Slither +from slither.solc_parsing.declarations.contract import ContractSolc04 +from slither.core.slither_core import Slither +from slither.core.declarations.pragma_directive import Pragma +from slither.core.declarations.import_directive import Import class SlitherSolc(Slither): @@ -14,59 +17,121 @@ class SlitherSolc(Slither): self._filename = filename self._contractsNotParsed = [] self._contracts_by_id = {} + self._analyzed = False - def parse_contracts_from_json(self, json_data): - first = json_data.find('{') - if first != -1: - last = json_data.rfind('}') +1 - filename = json_data[0:first] - json_data = json_data[first:last] + self._is_compact_ast = False + def get_key(self): + if self._is_compact_ast: + return 'nodeType' + return 'name' + + def get_children(self): + if self._is_compact_ast: + return 'nodes' + return 'children' + + @property + def is_compact_ast(self): + return self._is_compact_ast + + def _parse_contracts_from_json(self, json_data): + try: data_loaded = json.loads(json_data) + self._parse_contracts_from_loaded_json(data_loaded['ast'], data_loaded['sourcePath']) + return True + except ValueError: + + first = json_data.find('{') + if first != -1: + last = json_data.rfind('}') + 1 + filename = json_data[0:first] + json_data = json_data[first:last] + + data_loaded = json.loads(json_data) + self._parse_contracts_from_loaded_json(data_loaded, filename) + return True + return False + + def _parse_contracts_from_loaded_json(self, data_loaded, filename): + if 'nodeType' in data_loaded: + self._is_compact_ast = True + + if 'sourcePaths' in data_loaded: + for sourcePath in data_loaded['sourcePaths']: + if os.path.isfile(sourcePath): + with open(sourcePath) as f: + source_code = f.read() + self.source_code[sourcePath] = source_code + + if data_loaded[self.get_key()] == 'root': + self._solc_version = '0.3' + logger.error('solc <0.4 is not supported') + return + elif data_loaded[self.get_key()] == 'SourceUnit': + self._solc_version = '0.4' + self._parse_source_unit(data_loaded, filename) + else: + logger.error('solc version is not supported') + return + + for contract_data in data_loaded[self.get_children()]: + # if self.solc_version == '0.3': + # assert contract_data[self.get_key()] == 'Contract' + # contract = ContractSolc03(self, contract_data) + if self.solc_version == '0.4': + assert contract_data[self.get_key()] in ['ContractDefinition', 'PragmaDirective', 'ImportDirective'] + if contract_data[self.get_key()] == 'ContractDefinition': + contract = ContractSolc04(self, contract_data) + if 'src' in contract_data: + contract.set_offset(contract_data['src'], self) + self._contractsNotParsed.append(contract) + elif contract_data[self.get_key()] == 'PragmaDirective': + if self._is_compact_ast: + pragma = Pragma(contract_data['literals']) + else: + pragma = Pragma(contract_data['attributes']["literals"]) + pragma.set_offset(contract_data['src'], self) + self._pragma_directives.append(pragma) + elif contract_data[self.get_key()] == 'ImportDirective': + if self.is_compact_ast: + import_directive = Import(contract_data["absolutePath"]) + else: + import_directive = Import(contract_data['attributes']["absolutePath"]) + import_directive.set_offset(contract_data['src'], self) + self._import_directives.append(import_directive) - if data_loaded['name'] == 'root': - self._solc_version = '0.3' - logger.error('solc <0.4 not supported') - exit(-1) - elif data_loaded['name'] == 'SourceUnit': - self._solc_version = '0.4' - self._parse_source_unit(data_loaded, filename) - else: - logger.error('solc version not supported') - exit(-1) - - for contract_data in data_loaded['children']: -# if self.solc_version == '0.3': -# assert contract_data['name'] == 'Contract' - # contract = ContractSolc03(self, contract_data) - if self.solc_version == '0.4': - assert contract_data['name'] in ['ContractDefinition', 'PragmaDirective', 'ImportDirective'] - if contract_data['name'] == 'ContractDefinition': - contract = ContractSolc04(self, contract_data) - self._contractsNotParsed.append(contract) - - return True - return False def _parse_source_unit(self, data, filename): - if data['name'] != 'SourceUnit': - return -1 # handle solc prior 0.3.6 + if data[self.get_key()] != 'SourceUnit': + return -1 # handle solc prior 0.3.6 # match any char for filename # filename can contain space, /, -, .. name = re.findall('=* (.+) =*', filename) - assert len(name) == 1 - name = name[0] + if name: + assert len(name) == 1 + name = name[0] + else: + name = filename - sourceUnit = -1 # handle old solc, or error + sourceUnit = -1 # handle old solc, or error if 'src' in data: sourceUnit = re.findall('[0-9]*:[0-9]*:([0-9]*)', data['src']) if len(sourceUnit) == 1: sourceUnit = int(sourceUnit[0]) self._source_units[sourceUnit] = name + if os.path.isfile(name) and not name in self.source_code: + with open(name) as f: + source_code = f.read() + self.source_code[name] = source_code + + - def analyze_contracts(self): + def _analyze_contracts(self): + if self._analyzed: + raise Exception('Contract analysis can be run only once!') # First we save all the contracts in a dict # the key is the contractid @@ -74,10 +139,10 @@ class SlitherSolc(Slither): self._contracts_by_id[contract.id] = contract self._contracts[contract.name] = contract - # Update of the inheritances + # Update of the inheritance for contract in self._contractsNotParsed: # remove the first elem in linearizedBaseContracts as it is the contract itself - contract.setInheritances([self._contracts_by_id[i] for i in contract.linearizedBaseContracts[1:]]) + contract.setInheritance([self._contracts_by_id[i] for i in contract.linearizedBaseContracts[1:]]) contracts_to_be_analyzed = self.contracts @@ -101,23 +166,29 @@ class SlitherSolc(Slither): # Then we analyse state variables, functions and modifiers self._analyze_third_part(contracts_to_be_analyzed, libraries) + self._analyzed = True + + self._convert_to_slithir() # TODO refactor the following functions, and use a lambda function + @property + def analyzed(self): + return self._analyzed + def _analyze_all_enums(self, contracts_to_be_analyzed): while contracts_to_be_analyzed: contract = contracts_to_be_analyzed[0] contracts_to_be_analyzed = contracts_to_be_analyzed[1:] - all_father_analyzed = all(father.is_analyzed for father in contract.inheritances) + all_father_analyzed = all(father.is_analyzed for father in contract.inheritance) - if not contract.inheritances or all_father_analyzed: + if not contract.inheritance or all_father_analyzed: self._analyze_enums(contract) else: contracts_to_be_analyzed += [contract] return - def _analyze_first_part(self, contracts_to_be_analyzed, libraries): for lib in libraries: self._parse_struct_var_modifiers_functions(lib) @@ -130,16 +201,15 @@ class SlitherSolc(Slither): contract = contracts_to_be_analyzed[0] contracts_to_be_analyzed = contracts_to_be_analyzed[1:] - all_father_analyzed = all(father.is_analyzed for father in contract.inheritances) + all_father_analyzed = all(father.is_analyzed for father in contract.inheritance) - if not contract.inheritances or all_father_analyzed: + if not contract.inheritance or all_father_analyzed: self._parse_struct_var_modifiers_functions(contract) else: contracts_to_be_analyzed += [contract] return - def _analyze_second_part(self, contracts_to_be_analyzed, libraries): for lib in libraries: self._analyze_struct_events(lib) @@ -152,9 +222,9 @@ class SlitherSolc(Slither): contract = contracts_to_be_analyzed[0] contracts_to_be_analyzed = contracts_to_be_analyzed[1:] - all_father_analyzed = all(father.is_analyzed for father in contract.inheritances) + all_father_analyzed = all(father.is_analyzed for father in contract.inheritance) - if not contract.inheritances or all_father_analyzed: + if not contract.inheritance or all_father_analyzed: self._analyze_struct_events(contract) else: @@ -173,9 +243,9 @@ class SlitherSolc(Slither): contract = contracts_to_be_analyzed[0] contracts_to_be_analyzed = contracts_to_be_analyzed[1:] - all_father_analyzed = all(father.is_analyzed for father in contract.inheritances) + all_father_analyzed = all(father.is_analyzed for father in contract.inheritance) - if not contract.inheritances or all_father_analyzed: + if not contract.inheritance or all_father_analyzed: self._analyze_variables_modifiers_functions(contract) else: @@ -188,13 +258,16 @@ class SlitherSolc(Slither): contract.set_is_analyzed(True) def _parse_struct_var_modifiers_functions(self, contract): - contract.parse_structs() # struct can refer another struct + contract.parse_structs() # struct can refer another struct contract.parse_state_variables() contract.parse_modifiers() contract.parse_functions() contract.set_is_analyzed(True) def _analyze_struct_events(self, contract): + + contract.analyze_constant_state_variables() + # Struct can refer to enum, or state variables contract.analyze_structs() # Event can refer to struct @@ -217,3 +290,8 @@ class SlitherSolc(Slither): contract.set_is_analyzed(True) + def _convert_to_slithir(self): + for contract in self.contracts: + for func in contract.functions + contract.modifiers: + if func.contract == contract: + func.convert_expression_to_slithir() diff --git a/slither/solc_parsing/solidity_types/__init__.py b/slither/solc_parsing/solidity_types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slither/solcParsing/solidityTypes/typeParsing.py b/slither/solc_parsing/solidity_types/type_parsing.py similarity index 52% rename from slither/solcParsing/solidityTypes/typeParsing.py rename to slither/solc_parsing/solidity_types/type_parsing.py index 612aa9d61..76e6d3110 100644 --- a/slither/solcParsing/solidityTypes/typeParsing.py +++ b/slither/solc_parsing/solidity_types/type_parsing.py @@ -1,12 +1,12 @@ import logging -from slither.core.solidityTypes.elementaryType import ElementaryType, ElementaryTypeName -from slither.core.solidityTypes.userDefinedType import UserDefinedType -from slither.core.solidityTypes.arrayType import ArrayType -from slither.core.solidityTypes.mappingType import MappingType -from slither.core.solidityTypes.functionType import FunctionType +from slither.core.solidity_types.elementary_type import ElementaryType, ElementaryTypeName +from slither.core.solidity_types.user_defined_type import UserDefinedType +from slither.core.solidity_types.array_type import ArrayType +from slither.core.solidity_types.mapping_type import MappingType +from slither.core.solidity_types.function_type import FunctionType -from slither.core.variables.functionTypeVariable import FunctionTypeVariable +from slither.core.variables.function_type_variable import FunctionTypeVariable from slither.core.declarations.contract import Contract from slither.core.declarations.function import Function @@ -35,8 +35,18 @@ def _find_from_type_name(name, contract, contracts, structures, enums): return ArrayType(ElementaryType(name_elementary), Literal(depth)) else: return ElementaryType(name_elementary) + # We first look for contract + # To avoid collision + # Ex: a structure with the name of a contract + name_contract = name + if name_contract.startswith('contract '): + name_contract = name_contract[len('contract '):] + if name_contract.startswith('library '): + name_contract = name_contract[len('library '):] + var_type = next((c for c in contracts if c.name == name_contract), None) - var_type = next((st for st in structures if st.name == name), None) + if not var_type: + var_type = next((st for st in structures if st.name == name), None) if not var_type: var_type = next((e for e in enums if e.name == name), None) if not var_type: @@ -69,32 +79,28 @@ def _find_from_type_name(name, contract, contracts, structures, enums): var_type = next((st for st in all_structures if st.contract.name+"."+st.name == name_struct), None) if var_type: return ArrayType(UserDefinedType(var_type), Literal(depth)) - if not var_type: - name_contract = name - if name_contract.startswith('contract '): - name_contract = name_contract[len('contract '):] - var_type = next((c for c in contracts if c.name == name_contract), None) + if not var_type: var_type = next((f for f in contract.functions if f.name == name), None) if not var_type: if name.startswith('function '): - found = re.findall('function \(([a-zA-Z0-9\.,]*)\) returns \(([a-zA-Z0-9\.,]*)\)', name) - assert len(found) == 1 - params = found[0][0].split(',') - return_values = found[0][1].split(',') - params = [_find_from_type_name(p, contract, contracts, structures, enums) for p in params] - return_values = [_find_from_type_name(r, contract, contracts, structures, enums) for r in return_values] - params_vars = [] - return_vars = [] - for p in params: - var = FunctionTypeVariable() - var.set_type(p) - params_vars.append(var) - for r in return_values: - var = FunctionTypeVariable() - var.set_type(r) - return_vars.append(var) - return FunctionType(params_vars, return_vars) + found = re.findall('function \(([ ()a-zA-Z0-9\.,]*)\) returns \(([a-zA-Z0-9\.,]*)\)', name) + assert len(found) == 1 + params = found[0][0].split(',') + return_values = found[0][1].split(',') + params = [_find_from_type_name(p, contract, contracts, structures, enums) for p in params] + return_values = [_find_from_type_name(r, contract, contracts, structures, enums) for r in return_values] + params_vars = [] + return_vars = [] + for p in params: + var = FunctionTypeVariable() + var.set_type(p) + params_vars.append(var) + for r in return_values: + var = FunctionTypeVariable() + var.set_type(r) + return_vars.append(var) + return FunctionType(params_vars, return_vars) if not var_type: if name.startswith('mapping('): found = re.findall('mapping\(([a-zA-Z0-9\.]*) => ([a-zA-Z0-9\.]*)\)', name) @@ -116,8 +122,8 @@ def _find_from_type_name(name, contract, contracts, structures, enums): def parse_type(t, caller_context): # local import to avoid circular dependency - from slither.solcParsing.expressions.expressionParsing import parse_expression - from slither.solcParsing.variables.functionTypeVariableSolc import FunctionTypeVariableSolc + from slither.solc_parsing.expressions.expression_parsing import parse_expression + from slither.solc_parsing.variables.function_type_variable import FunctionTypeVariableSolc if isinstance(caller_context, Contract): contract = caller_context @@ -127,6 +133,14 @@ def parse_type(t, caller_context): logger.error('Incorrect caller context') exit(-1) + + is_compact_ast = caller_context.is_compact_ast + + if is_compact_ast: + key = 'nodeType' + else: + key = 'name' + structures = contract.structures enums = contract.enums contracts = contract.slither.contracts @@ -134,51 +148,69 @@ def parse_type(t, caller_context): if isinstance(t, UnknownType): return _find_from_type_name(t.name, contract, contracts, structures, enums) - elif t['name'] == 'ElementaryTypeName': - return ElementaryType(t['attributes']['name']) + elif t[key] == 'ElementaryTypeName': + if is_compact_ast: + return ElementaryType(t['name']) + return ElementaryType(t['attributes'][key]) - elif t['name'] == 'UserDefinedTypeName': - return _find_from_type_name(t['attributes']['name'], contract, contracts, structures, enums) + elif t[key] == 'UserDefinedTypeName': + if is_compact_ast: + return _find_from_type_name(t['typeDescriptions']['typeString'], contract, contracts, structures, enums) + return _find_from_type_name(t['attributes'][key], contract, contracts, structures, enums) - elif t['name'] == 'ArrayTypeName': + elif t[key] == 'ArrayTypeName': length = None - if len(t['children']) == 2: - length = parse_expression(t['children'][1], caller_context) + if is_compact_ast: + if t['length']: + length = parse_expression(t['length'], caller_context) + array_type = parse_type(t['baseType'], contract) else: - assert len(t['children']) == 1 - array_type = parse_type(t['children'][0], contract) + if len(t['children']) == 2: + length = parse_expression(t['children'][1], caller_context) + else: + assert len(t['children']) == 1 + array_type = parse_type(t['children'][0], contract) return ArrayType(array_type, length) - elif t['name'] == 'Mapping': + elif t[key] == 'Mapping': - assert len(t['children']) == 2 + if is_compact_ast: + mappingFrom = parse_type(t['keyType'], contract) + mappingTo = parse_type(t['valueType'], contract) + else: + assert len(t['children']) == 2 - mappingFrom = parse_type(t['children'][0], contract) - mappingTo = parse_type(t['children'][1], contract) + mappingFrom = parse_type(t['children'][0], contract) + mappingTo = parse_type(t['children'][1], contract) return MappingType(mappingFrom, mappingTo) - elif t['name'] == 'FunctionTypeName': - assert len(t['children']) == 2 + elif t[key] == 'FunctionTypeName': - params = t['children'][0] - return_values = t['children'][1] + if is_compact_ast: + params = t['parameterTypes'] + return_values = t['returnParameterTypes'] + index = 'parameters' + else: + assert len(t['children']) == 2 + params = t['children'][0] + return_values = t['children'][1] + index = 'children' - assert params['name'] == 'ParameterList' - assert return_values['name'] == 'ParameterList' + assert params[key] == 'ParameterList' + assert return_values[key] == 'ParameterList' params_vars = [] return_values_vars = [] - for p in params['children']: + for p in params[index]: var = FunctionTypeVariableSolc(p) - - var.set_offset(p['src']) + var.set_offset(p['src'], caller_context.slither) var.analyze(caller_context) params_vars.append(var) - for p in return_values['children']: + for p in return_values[index]: var = FunctionTypeVariableSolc(p) - var.set_offset(p['src']) + var.set_offset(p['src'], caller_context.slither) var.analyze(caller_context) return_values_vars.append(var) diff --git a/slither/solc_parsing/variables/__init__.py b/slither/solc_parsing/variables/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slither/solc_parsing/variables/event_variable.py b/slither/solc_parsing/variables/event_variable.py new file mode 100644 index 000000000..794f19fe1 --- /dev/null +++ b/slither/solc_parsing/variables/event_variable.py @@ -0,0 +1,5 @@ + +from .variable_declaration import VariableDeclarationSolc +from slither.core.variables.event_variable import EventVariable + +class EventVariableSolc(VariableDeclarationSolc, EventVariable): pass diff --git a/slither/solc_parsing/variables/function_type_variable.py b/slither/solc_parsing/variables/function_type_variable.py new file mode 100644 index 000000000..50af5145c --- /dev/null +++ b/slither/solc_parsing/variables/function_type_variable.py @@ -0,0 +1,5 @@ + +from slither.solc_parsing.variables.variable_declaration import VariableDeclarationSolc +from slither.core.variables.function_type_variable import FunctionTypeVariable + +class FunctionTypeVariableSolc(VariableDeclarationSolc, FunctionTypeVariable): pass diff --git a/slither/solc_parsing/variables/local_variable.py b/slither/solc_parsing/variables/local_variable.py new file mode 100644 index 000000000..1df53261b --- /dev/null +++ b/slither/solc_parsing/variables/local_variable.py @@ -0,0 +1,24 @@ + +from .variable_declaration import VariableDeclarationSolc +from slither.core.variables.local_variable import LocalVariable + +class LocalVariableSolc(VariableDeclarationSolc, LocalVariable): + + def _analyze_variable_attributes(self, attributes): + '''' + Variable Location + Can be storage/memory or default + ''' + if 'storageLocation' in attributes: + location = attributes['storageLocation'] + self._location = location + else: + if 'memory' in attributes['type']: + self._location = 'memory' + elif'storage' in attributes['type']: + self._location = 'storage' + else: + self._location = 'default' + + super(LocalVariableSolc, self)._analyze_variable_attributes(attributes) + diff --git a/slither/solcParsing/variables/localVariableInitFromTupleSolc.py b/slither/solc_parsing/variables/local_variable_init_from_tuple.py similarity index 60% rename from slither/solcParsing/variables/localVariableInitFromTupleSolc.py rename to slither/solc_parsing/variables/local_variable_init_from_tuple.py index 5a73c4ee7..e962e4c3c 100644 --- a/slither/solcParsing/variables/localVariableInitFromTupleSolc.py +++ b/slither/solc_parsing/variables/local_variable_init_from_tuple.py @@ -1,6 +1,6 @@ -from .variableDeclarationSolc import VariableDeclarationSolc -from slither.core.variables.localVariableInitFromTuple import LocalVariableInitFromTuple +from .variable_declaration import VariableDeclarationSolc +from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple class LocalVariableInitFromTupleSolc(VariableDeclarationSolc, LocalVariableInitFromTuple): @@ -8,3 +8,4 @@ class LocalVariableInitFromTupleSolc(VariableDeclarationSolc, LocalVariableInitF super(LocalVariableInitFromTupleSolc, self).__init__(var) self._tuple_index = index + diff --git a/slither/solc_parsing/variables/state_variable.py b/slither/solc_parsing/variables/state_variable.py new file mode 100644 index 000000000..242f441ca --- /dev/null +++ b/slither/solc_parsing/variables/state_variable.py @@ -0,0 +1,5 @@ + +from .variable_declaration import VariableDeclarationSolc +from slither.core.variables.state_variable import StateVariable + +class StateVariableSolc(VariableDeclarationSolc, StateVariable): pass diff --git a/slither/solc_parsing/variables/structure_variable.py b/slither/solc_parsing/variables/structure_variable.py new file mode 100644 index 000000000..f0823f67d --- /dev/null +++ b/slither/solc_parsing/variables/structure_variable.py @@ -0,0 +1,5 @@ + +from .variable_declaration import VariableDeclarationSolc +from slither.core.variables.structure_variable import StructureVariable + +class StructureVariableSolc(VariableDeclarationSolc, StructureVariable): pass diff --git a/slither/solc_parsing/variables/variable_declaration.py b/slither/solc_parsing/variables/variable_declaration.py new file mode 100644 index 000000000..17b1e9c37 --- /dev/null +++ b/slither/solc_parsing/variables/variable_declaration.py @@ -0,0 +1,158 @@ +import logging +from slither.solc_parsing.expressions.expression_parsing import parse_expression + +from slither.core.variables.variable import Variable + +from slither.solc_parsing.solidity_types.type_parsing import parse_type, UnknownType + +from slither.core.solidity_types.elementary_type import ElementaryType, NonElementaryType + +logger = logging.getLogger("VariableDeclarationSolcParsing") + +class MultipleVariablesDeclaration(Exception): + ''' + This is raised on + var (a,b) = ... + It should occur only on local variable definition + ''' + pass + +class VariableDeclarationSolc(Variable): + + def __init__(self, var): + ''' + A variable can be declared through a statement, or directly. + If it is through a statement, the following children may contain + the init value. + It may be possible that the variable is declared through a statement, + but the init value is declared at the VariableDeclaration children level + ''' + + super(VariableDeclarationSolc, self).__init__() + self._was_analyzed = False + self._elem_to_parse = None + self._initializedNotParsed = None + + self._is_compact_ast = False + + if 'nodeType' in var: + self._is_compact_ast = True + nodeType = var['nodeType'] + if nodeType in ['VariableDeclarationStatement', 'VariableDefinitionStatement']: + if len(var['declarations'])>1: + raise MultipleVariablesDeclaration + init = None + if 'initialValue' in var: + init = var['initialValue'] + self._init_from_declaration(var['declarations'][0], init) + elif nodeType == 'VariableDeclaration': + self._init_from_declaration(var, var['value']) + else: + logger.error('Incorrect variable declaration type {}'.format(nodeType)) + exit(-1) + + else: + nodeType = var['name'] + + if nodeType in ['VariableDeclarationStatement', 'VariableDefinitionStatement']: + if len(var['children']) == 2: + init = var['children'][1] + elif len(var['children']) == 1: + init = None + elif len(var['children']) > 2: + raise MultipleVariablesDeclaration + else: + logger.error('Variable declaration without children?'+var) + exit(-1) + declaration = var['children'][0] + self._init_from_declaration(declaration, init) + elif nodeType == 'VariableDeclaration': + self._init_from_declaration(var, None) + else: + logger.error('Incorrect variable declaration type {}'.format(nodeType)) + exit(-1) + + @property + def initialized(self): + return self._initialized + + @property + def uninitialized(self): + return not self._initialized + + def _analyze_variable_attributes(self, attributes): + if 'visibility' in attributes: + self._visibility = attributes['visibility'] + else: + self._visibility = 'internal' + + def _init_from_declaration(self, var, init): + if self._is_compact_ast: + attributes = var + self._typeName = attributes['typeDescriptions']['typeString'] + else: + assert len(var['children']) <= 2 + assert var['name'] == 'VariableDeclaration' + + attributes = var['attributes'] + self._typeName = attributes['type'] + + self._name = attributes['name'] + self._arrayDepth = 0 + self._isMapping = False + self._mappingFrom = None + self._mappingTo = False + self._initial_expression = None + self._type = None + + if 'constant' in attributes: + self._is_constant = attributes['constant'] + + self._analyze_variable_attributes(attributes) + + if self._is_compact_ast: + if var['typeName']: + self._elem_to_parse = var['typeName'] + else: + self._elem_to_parse = UnknownType(var['typeDescriptions']['typeString']) + else: + if not var['children']: + # It happens on variable declared inside loop declaration + try: + self._type = ElementaryType(self._typeName) + self._elem_to_parse = None + except NonElementaryType: + self._elem_to_parse = UnknownType(self._typeName) + else: + self._elem_to_parse = var['children'][0] + + if self._is_compact_ast: + self._initializedNotParsed = init + if init: + self._initialized = True + else: + if init: # there are two way to init a var local in the AST + assert len(var['children']) <= 1 + self._initialized = True + self._initializedNotParsed = init + elif len(var['children']) in [0, 1]: + self._initialized = False + self._initializedNotParsed = [] + else: + assert len(var['children']) == 2 + self._initialized = True + self._initializedNotParsed = var['children'][1] + + def analyze(self, caller_context): + # Can be re-analyzed due to inheritance + if self._was_analyzed: + return + self._was_analyzed = True + + if self._elem_to_parse: + self._type = parse_type(self._elem_to_parse, caller_context) + self._elem_to_parse = None + + if self._initialized: + self._initial_expression = parse_expression(self._initializedNotParsed, caller_context) + self._initializedNotParsed = None diff --git a/slither/utils/command_line.py b/slither/utils/command_line.py new file mode 100644 index 000000000..57542901d --- /dev/null +++ b/slither/utils/command_line.py @@ -0,0 +1,89 @@ +from prettytable import PrettyTable + +from slither.detectors.abstract_detector import classification_txt + +def output_to_markdown(detector_classes, printer_classes): + + def extract_help(detector): + if detector.WIKI == '': + return detector.HELP + return '[{}]({})'.format(detector.HELP, detector.WIKI) + + detectors_list = [] + for detector in detector_classes: + argument = detector.ARGUMENT + # dont show the backdoor example + if argument == 'backdoor': + continue + help_info = extract_help(detector) + impact = detector.IMPACT + confidence = classification_txt[detector.CONFIDENCE] + detectors_list.append((argument, help_info, impact, confidence)) + + # Sort by impact, confidence, and name + detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[3], element[0])) + idx = 1 + for (argument, help_info, impact, confidence) in detectors_list: + print('{} | `{}` | {} | {} | {}'.format(idx, + argument, + help_info, + classification_txt[impact], + confidence)) + idx = idx + 1 + + print() + printers_list = [] + for printer in printer_classes: + argument = printer.ARGUMENT + help_info = printer.HELP + printers_list.append((argument, help_info)) + + # Sort by impact, confidence, and name + printers_list = sorted(printers_list, key=lambda element: (element[0])) + idx = 1 + for (argument, help_info) in printers_list: + print('{} | `{}` | {}'.format(idx, argument, help_info)) + idx = idx + 1 + +def output_detectors(detector_classes): + detectors_list = [] + for detector in detector_classes: + argument = detector.ARGUMENT + # dont show the backdoor example + if argument == 'backdoor': + continue + help_info = detector.HELP + impact = detector.IMPACT + confidence = classification_txt[detector.CONFIDENCE] + detectors_list.append((argument, help_info, impact, confidence)) + table = PrettyTable(["Num", + "Check", + "What it Detects", + "Impact", + "Confidence"]) + + # Sort by impact, confidence, and name + detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[3], element[0])) + idx = 1 + for (argument, help_info, impact, confidence) in detectors_list: + table.add_row([idx, argument, help_info, classification_txt[impact], confidence]) + idx = idx + 1 + print(table) + +def output_printers(printer_classes): + printers_list = [] + for printer in printer_classes: + argument = printer.ARGUMENT + help_info = printer.HELP + printers_list.append((argument, help_info)) + table = PrettyTable(["Num", + "Printer", + "What it Does"]) + + # Sort by impact, confidence, and name + printers_list = sorted(printers_list, key=lambda element: (element[0])) + idx = 1 + for (argument, help_info) in printers_list: + table.add_row([idx, argument, help_info]) + idx = idx + 1 + print(table) diff --git a/slither/utils/expression_manipulations.py b/slither/utils/expression_manipulations.py new file mode 100644 index 000000000..271c6e544 --- /dev/null +++ b/slither/utils/expression_manipulations.py @@ -0,0 +1,121 @@ +""" + We use protected member, to avoid having setter in the expression + as they should be immutable +""" +import copy +from slither.core.expressions.assignment_operation import AssignmentOperation +from slither.core.expressions.binary_operation import BinaryOperation +from slither.core.expressions.call_expression import CallExpression +from slither.core.expressions.conditional_expression import ConditionalExpression +from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.index_access import IndexAccess +from slither.core.expressions.literal import Literal +from slither.core.expressions.member_access import MemberAccess +from slither.core.expressions.new_array import NewArray +from slither.core.expressions.new_contract import NewContract +from slither.core.expressions.new_elementary_type import NewElementaryType +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.type_conversion import TypeConversion +from slither.core.expressions.unary_operation import UnaryOperation + +def f_expressions(e, x): + e._expressions.append(x) + +def f_call(e, x): + e._arguments.append(x) + +def f_expression(e, x): + e._expression = x + +def f_called(e, x): + e._called = x + +class SplitTernaryExpression(object): + + def __init__(self, expression): + + # print(expression) + + if isinstance(expression, ConditionalExpression): + self.true_expression = copy.copy(expression.then_expression) + self.false_expression = copy.copy(expression.else_expression) + self.condition = copy.copy(expression.if_expression) + else: + self.true_expression = copy.copy(expression) + self.false_expression = copy.copy(expression) + self.condition = None + self.copy_expression(expression, self.true_expression, self.false_expression) + + def apply_copy(self, next_expr, true_expression, false_expression, f): + + if isinstance(next_expr, ConditionalExpression): + f(true_expression, copy.copy(next_expr.then_expression)) + f(false_expression, copy.copy(next_expr.else_expression)) + self.condition = copy.copy(next_expr.if_expression) + return False + else: + f(true_expression, copy.copy(next_expr)) + f(false_expression, copy.copy(next_expr)) + return True + + def copy_expression(self, expression, true_expression, false_expression): + if self.condition: + return + + if isinstance(expression, ConditionalExpression): + raise Exception('Nested ternary operator not handled') + + if isinstance(expression, (Literal, Identifier, IndexAccess)): + return None + + # case of lib + # (.. ? .. : ..).add + if isinstance(expression, MemberAccess): + next_expr = expression.expression + if self.apply_copy(next_expr, true_expression, false_expression, f_expression): + self.copy_expression(next_expr, + true_expression.expression, + false_expression.expression) + + elif isinstance(expression, (AssignmentOperation, BinaryOperation, TupleExpression)): + true_expression._expressions = [] + false_expression._expressions = [] + + for next_expr in expression.expressions: + if self.apply_copy(next_expr, true_expression, false_expression, f_expressions): + # always on last arguments added + self.copy_expression(next_expr, + true_expression.expressions[-1], + false_expression.expressions[-1]) + + elif isinstance(expression, CallExpression): + next_expr = expression.called + + # case of lib + # (.. ? .. : ..).add + if self.apply_copy(next_expr, true_expression, false_expression, f_called): + self.copy_expression(next_expr, + true_expression.called, + false_expression.called) + + true_expression._arguments = [] + false_expression._arguments = [] + + for next_expr in expression.arguments: + if self.apply_copy(next_expr, true_expression, false_expression, f_call): + # always on last arguments added + self.copy_expression(next_expr, + true_expression.arguments[-1], + false_expression.arguments[-1]) + + elif isinstance(expression, TypeConversion): + next_expr = expression.expression + if self.apply_copy(next_expr, true_expression, false_expression, f_expression): + self.copy_expression(expression.expression, + true_expression.expression, + false_expression.expression) + + else: + raise Exception('Ternary operation not handled {}'.format(type(expression))) + diff --git a/slither/utils/utils.py b/slither/utils/utils.py deleted file mode 100644 index 8f90ed8fb..000000000 --- a/slither/utils/utils.py +++ /dev/null @@ -1,61 +0,0 @@ -""" - Utils module -""" -import re - -def find_call(call, contract, contracts): - """ Find call in the contract - - Do not respect c3 lineralization - Args: - call: call the find - contract: current contract - contracts: list of contracts - Return: - function: returns the function called (or None if the funciton was not found) -""" - if '.call.value' in str(call): - return None - if '.call.gas.value' in str(call): - return None - for f in contract.functions + contract.modifiers: - if call == f.name: - return f - for father in contract.inheritances: - fatherContract = next((x for x in contracts if x.name == father), None) - if fatherContract: - for f in fatherContract.functions: - if call == f.name: - return f - call_found = find_call(call, fatherContract, contracts) - if call_found: - return call_found - return None - -def convert_offset(offset, sourceUnits): - ''' - Convert a text offset to a real offset - see https://solidity.readthedocs.io/en/develop/miscellaneous.html#source-mappings - - To handle solc <0.3.6: - - If the matching is not found, returns an empty dict - - If the matching is found, but the filename is not knwon, return only start/length - Args: - offset (str): format: '..:..:..' - sourceUnits (dict): map int -> filename - Returns: - (dict): {'start':0, 'length':0, 'filename': 'file.sol'} - ''' - position = re.findall('([0-9]*):([0-9]*):([-]?[0-9]*)', offset) - if len(position) != 1: - return {} - - s, l, f = position[0] - s = int(s) - l = int(l) - f = int(f) - - if f not in sourceUnits: - return {'start':s, 'length':l} - filename = sourceUnits[f] - return {'start':s, 'length':l, 'filename': filename} diff --git a/slither/visitors/expression/exportValues.py b/slither/visitors/expression/export_values.py similarity index 97% rename from slither/visitors/expression/exportValues.py rename to slither/visitors/expression/export_values.py index 18ed9ed67..15afb4dfa 100644 --- a/slither/visitors/expression/exportValues.py +++ b/slither/visitors/expression/export_values.py @@ -1,7 +1,7 @@ from slither.visitors.expression.expression import ExpressionVisitor -from slither.core.expressions.assignmentOperation import AssignmentOperationType +from slither.core.expressions.assignment_operation import AssignmentOperationType from slither.core.variables.variable import Variable diff --git a/slither/visitors/expression/expression.py b/slither/visitors/expression/expression.py index b0486ba5c..fe8eeb821 100644 --- a/slither/visitors/expression/expression.py +++ b/slither/visitors/expression/expression.py @@ -1,20 +1,20 @@ import logging -from slither.core.expressions.assignmentOperation import AssignmentOperation -from slither.core.expressions.binaryOperation import BinaryOperation -from slither.core.expressions.callExpression import CallExpression -from slither.core.expressions.conditionalExpression import ConditionalExpression -from slither.core.expressions.elementaryTypeNameExpression import ElementaryTypeNameExpression +from slither.core.expressions.assignment_operation import AssignmentOperation +from slither.core.expressions.binary_operation import BinaryOperation +from slither.core.expressions.call_expression import CallExpression +from slither.core.expressions.conditional_expression import ConditionalExpression +from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression from slither.core.expressions.identifier import Identifier -from slither.core.expressions.indexAccess import IndexAccess +from slither.core.expressions.index_access import IndexAccess from slither.core.expressions.literal import Literal -from slither.core.expressions.memberAccess import MemberAccess -from slither.core.expressions.newArray import NewArray -from slither.core.expressions.newContract import NewContract -from slither.core.expressions.newElementaryType import NewElementaryType -from slither.core.expressions.tupleExpression import TupleExpression -from slither.core.expressions.typeConversion import TypeConversion -from slither.core.expressions.unaryOperation import UnaryOperation +from slither.core.expressions.member_access import MemberAccess +from slither.core.expressions.new_array import NewArray +from slither.core.expressions.new_contract import NewContract +from slither.core.expressions.new_elementary_type import NewElementaryType +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.type_conversion import TypeConversion +from slither.core.expressions.unary_operation import UnaryOperation logger = logging.getLogger("ExpressionVisitor") @@ -22,8 +22,8 @@ class ExpressionVisitor: def __init__(self, expression): self._expression = expression - self._visit_expression(self.expression) self._result = None + self._visit_expression(self.expression) def result(self): return self._result diff --git a/slither/visitors/expression/expressionPrinter.py b/slither/visitors/expression/expression_printer.py similarity index 100% rename from slither/visitors/expression/expressionPrinter.py rename to slither/visitors/expression/expression_printer.py diff --git a/slither/visitors/expression/findCalls.py b/slither/visitors/expression/find_calls.py similarity index 97% rename from slither/visitors/expression/findCalls.py rename to slither/visitors/expression/find_calls.py index 4f89ef481..6af072329 100644 --- a/slither/visitors/expression/findCalls.py +++ b/slither/visitors/expression/find_calls.py @@ -1,7 +1,7 @@ from slither.visitors.expression.expression import ExpressionVisitor -from slither.core.expressions.assignmentOperation import AssignmentOperationType +from slither.core.expressions.assignment_operation import AssignmentOperationType from slither.core.variables.variable import Variable diff --git a/slither/visitors/expression/findPush.py b/slither/visitors/expression/find_push.py similarity index 95% rename from slither/visitors/expression/findPush.py rename to slither/visitors/expression/find_push.py index 9194b3e4e..76e0d7a56 100644 --- a/slither/visitors/expression/findPush.py +++ b/slither/visitors/expression/find_push.py @@ -1,8 +1,8 @@ from slither.visitors.expression.expression import ExpressionVisitor from slither.core.expressions.identifier import Identifier -from slither.core.expressions.indexAccess import IndexAccess +from slither.core.expressions.index_access import IndexAccess -from slither.visitors.expression.rightValue import RightValue +from slither.visitors.expression.right_value import RightValue key = 'FindPush' diff --git a/slither/visitors/expression/has_conditional.py b/slither/visitors/expression/has_conditional.py new file mode 100644 index 000000000..5378e4d98 --- /dev/null +++ b/slither/visitors/expression/has_conditional.py @@ -0,0 +1,13 @@ + +from slither.visitors.expression.expression import ExpressionVisitor + +class HasConditional(ExpressionVisitor): + + def result(self): + # == True, to convert None to false + return self._result is True + + def _post_conditional_expression(self, expression): +# if self._result is True: +# raise('Slither does not support nested ternary operator') + self._result = True diff --git a/slither/visitors/expression/leftValue.py b/slither/visitors/expression/left_value.py similarity index 97% rename from slither/visitors/expression/leftValue.py rename to slither/visitors/expression/left_value.py index cf679859a..c23c3c06c 100644 --- a/slither/visitors/expression/leftValue.py +++ b/slither/visitors/expression/left_value.py @@ -2,7 +2,7 @@ from slither.visitors.expression.expression import ExpressionVisitor -from slither.core.expressions.assignmentOperation import AssignmentOperationType +from slither.core.expressions.assignment_operation import AssignmentOperationType from slither.core.variables.variable import Variable diff --git a/slither/visitors/expression/readVar.py b/slither/visitors/expression/read_var.py similarity index 95% rename from slither/visitors/expression/readVar.py rename to slither/visitors/expression/read_var.py index 4a10eece7..ae4882d84 100644 --- a/slither/visitors/expression/readVar.py +++ b/slither/visitors/expression/read_var.py @@ -1,10 +1,10 @@ from slither.visitors.expression.expression import ExpressionVisitor -from slither.core.expressions.assignmentOperation import AssignmentOperationType +from slither.core.expressions.assignment_operation import AssignmentOperationType from slither.core.variables.variable import Variable -from slither.core.declarations.solidityVariables import SolidityVariable +from slither.core.declarations.solidity_variables import SolidityVariable key = 'ReadVar' diff --git a/slither/visitors/expression/rightValue.py b/slither/visitors/expression/right_value.py similarity index 97% rename from slither/visitors/expression/rightValue.py rename to slither/visitors/expression/right_value.py index 5f34e45db..718ed392a 100644 --- a/slither/visitors/expression/rightValue.py +++ b/slither/visitors/expression/right_value.py @@ -5,7 +5,7 @@ from slither.visitors.expression.expression import ExpressionVisitor -from slither.core.expressions.assignmentOperation import AssignmentOperationType +from slither.core.expressions.assignment_operation import AssignmentOperationType from slither.core.expressions.expression import Expression from slither.core.variables.variable import Variable diff --git a/slither/visitors/expression/writeVar.py b/slither/visitors/expression/write_var.py similarity index 95% rename from slither/visitors/expression/writeVar.py rename to slither/visitors/expression/write_var.py index 5edb3c3b7..0368b4ad1 100644 --- a/slither/visitors/expression/writeVar.py +++ b/slither/visitors/expression/write_var.py @@ -1,13 +1,13 @@ from slither.visitors.expression.expression import ExpressionVisitor -from slither.core.expressions.assignmentOperation import AssignmentOperation +from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.variables.variable import Variable -from slither.core.expressions.memberAccess import MemberAccess +from slither.core.expressions.member_access import MemberAccess -from slither.core.expressions.indexAccess import IndexAccess +from slither.core.expressions.index_access import IndexAccess key = 'WriteVar' diff --git a/slither/visitors/slithir/__init__.py b/slither/visitors/slithir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slither/visitors/slithir/expression_to_slithir.py b/slither/visitors/slithir/expression_to_slithir.py new file mode 100644 index 000000000..e1c90c595 --- /dev/null +++ b/slither/visitors/slithir/expression_to_slithir.py @@ -0,0 +1,247 @@ +import logging + +from slither.core.declarations import Function, Structure +from slither.core.expressions import (AssignmentOperationType, + UnaryOperationType) +from slither.core.solidity_types.array_type import ArrayType +from slither.slithir.operations import (Assignment, Binary, BinaryType, Delete, + Index, InitArray, InternalCall, Member, + NewArray, NewContract, NewStructure, + TypeConversion, Unary, Unpack) +from slither.slithir.tmp_operations.argument import Argument +from slither.slithir.tmp_operations.tmp_call import TmpCall +from slither.slithir.tmp_operations.tmp_new_array import TmpNewArray +from slither.slithir.tmp_operations.tmp_new_contract import TmpNewContract +from slither.slithir.tmp_operations.tmp_new_elementary_type import \ + TmpNewElementaryType +from slither.slithir.tmp_operations.tmp_new_structure import TmpNewStructure +from slither.slithir.variables import (Constant, ReferenceVariable, + TemporaryVariable, TupleVariable) +from slither.visitors.expression.expression import ExpressionVisitor + +logger = logging.getLogger("VISTIOR:ExpressionToSlithIR") + +key = 'expressionToSlithIR' + +def get(expression): + val = expression.context[key] + # we delete the item to reduce memory use + del expression.context[key] + return val + +def set_val(expression, val): + expression.context[key] = val + +def convert_assignment(left, right, t, return_type): + if t == AssignmentOperationType.ASSIGN: + return Assignment(left, right, return_type) + elif t == AssignmentOperationType.ASSIGN_OR: + return Binary(left, left, right, BinaryType.OR) + elif t == AssignmentOperationType.ASSIGN_CARET: + return Binary(left, left, right, BinaryType.CARET) + elif t == AssignmentOperationType.ASSIGN_AND: + return Binary(left, left, right, BinaryType.AND) + elif t == AssignmentOperationType.ASSIGN_LEFT_SHIFT: + return Binary(left, left, right, BinaryType.LEFT_SHIFT) + elif t == AssignmentOperationType.ASSIGN_RIGHT_SHIFT: + return Binary(left, left, right, BinaryType.RIGHT_SHIT) + elif t == AssignmentOperationType.ASSIGN_ADDITION: + return Binary(left, left, right, BinaryType.ADDITION) + elif t == AssignmentOperationType.ASSIGN_SUBTRACTION: + return Binary(left, left, right, BinaryType.SUBTRACTION) + elif t == AssignmentOperationType.ASSIGN_MULTIPLICATION: + return Binary(left, left, right, BinaryType.MULTIPLICATION) + elif t == AssignmentOperationType.ASSIGN_DIVISION: + return Binary(left, left, right, BinaryType.DIVISION) + elif t == AssignmentOperationType.ASSIGN_MODULO: + return Binary(left, left, right, BinaryType.MODULO) + + logger.error('Missing type during assignment conversion') + exit(-1) + +class ExpressionToSlithIR(ExpressionVisitor): + + def __init__(self, expression, node): + self._expression = expression + self._node = node + self._result = [] + self._visit_expression(self.expression) + + def result(self): + return self._result + + def _post_assignement_operation(self, expression): + left = get(expression.expression_left) + right = get(expression.expression_right) + if isinstance(left, list): # tuple expression: + if isinstance(right, list): # unbox assigment + assert len(left) == len(right) + for idx in range(len(left)): + if not left[idx] is None: + operation = convert_assignment(left[idx], right[idx], expression.type, expression.expression_return_type) + self._result.append(operation) + set_val(expression, None) + else: + assert isinstance(right, TupleVariable) + for idx in range(len(left)): + if not left[idx] is None: + operation = Unpack(left[idx], right, idx) + self._result.append(operation) + set_val(expression, None) + else: + # Init of array, like + # uint8[2] var = [1,2]; + if isinstance(right, list): + operation = InitArray(right, left) + self._result.append(operation) + set_val(expression, left) + else: + operation = convert_assignment(left, right, expression.type, expression.expression_return_type) + self._result.append(operation) + # Return left to handle + # a = b = 1; + set_val(expression, left) + + def _post_binary_operation(self, expression): + left = get(expression.expression_left) + right = get(expression.expression_right) + val = TemporaryVariable(self._node) + + operation = Binary(val, left, right, expression.type) + self._result.append(operation) + set_val(expression, val) + + def _post_call_expression(self, expression): + called = get(expression.called) + args = [get(a) for a in expression.arguments if a] + for arg in args: + arg_ = Argument(arg) + self._result.append(arg_) + if isinstance(called, Function): + # internal call + + # If tuple + if expression.type_call.startswith('tuple(') and expression.type_call != 'tuple()': + val = TupleVariable() + else: + val = TemporaryVariable(self._node) + internal_call = InternalCall(called, len(args), val, expression.type_call) + self._result.append(internal_call) + set_val(expression, val) + else: + val = TemporaryVariable(self._node) + + # If tuple + if expression.type_call.startswith('tuple(') and expression.type_call != 'tuple()': + val = TupleVariable() + else: + val = TemporaryVariable(self._node) + + message_call = TmpCall(called, len(args), val, expression.type_call) + self._result.append(message_call) + set_val(expression, val) + + def _post_conditional_expression(self, expression): + raise Exception('Ternary operator are not convertible to SlithIR {}'.format(expression)) + + def _post_elementary_type_name_expression(self, expression): + set_val(expression, expression.type) + + def _post_identifier(self, expression): + set_val(expression, expression.value) + + def _post_index_access(self, expression): + left = get(expression.expression_left) + right = get(expression.expression_right) + val = ReferenceVariable(self._node) + operation = Index(val, left, right, expression.type) + self._result.append(operation) + set_val(expression, val) + + def _post_literal(self, expression): + set_val(expression, Constant(expression.value)) + + def _post_member_access(self, expression): + expr = get(expression.expression) + val = ReferenceVariable(self._node) + member = Member(expr, Constant(expression.member_name), val) + self._result.append(member) + set_val(expression, val) + + def _post_new_array(self, expression): + val = TemporaryVariable(self._node) + operation = TmpNewArray(expression.depth, expression.array_type, val) + self._result.append(operation) + set_val(expression, val) + + def _post_new_contract(self, expression): + val = TemporaryVariable(self._node) + operation = TmpNewContract(expression.contract_name, val) + self._result.append(operation) + set_val(expression, val) + + def _post_new_elementary_type(self, expression): + # TODO unclear if this is ever used? + val = TemporaryVariable(self._node) + operation = TmpNewElementaryType(expression.type, val) + self._result.append(operation) + set_val(expression, val) + + def _post_tuple_expression(self, expression): + expressions = [get(e) if e else None for e in expression.expressions] + if len(expressions) == 1: + val = expressions[0] + else: + val = expressions + set_val(expression, val) + + def _post_type_conversion(self, expression): + expr = get(expression.expression) + val = TemporaryVariable(self._node) + operation = TypeConversion(val, expr, expression.type) + self._result.append(operation) + set_val(expression, val) + + def _post_unary_operation(self, expression): + value = get(expression.expression) + if expression.type in [UnaryOperationType.BANG, UnaryOperationType.TILD]: + lvalue = TemporaryVariable(self._node) + operation = Unary(lvalue, value, expression.type) + self._result.append(operation) + set_val(expression, lvalue) + elif expression.type in [UnaryOperationType.DELETE]: + operation = Delete(value) + self._result.append(operation) + set_val(expression, value) + elif expression.type in [UnaryOperationType.PLUSPLUS_PRE]: + operation = Binary(value, value, Constant("1"), BinaryType.ADDITION) + self._result.append(operation) + set_val(expression, value) + elif expression.type in [UnaryOperationType.MINUSMINUS_PRE]: + operation = Binary(value, value, Constant("1"), BinaryType.SUBTRACTION) + self._result.append(operation) + set_val(expression, value) + elif expression.type in [UnaryOperationType.PLUSPLUS_POST]: + lvalue = TemporaryVariable(self._node) + operation = Assignment(lvalue, value, value.type) + self._result.append(operation) + operation = Binary(value, value, Constant("1"), BinaryType.ADDITION) + self._result.append(operation) + set_val(expression, lvalue) + elif expression.type in [UnaryOperationType.MINUSMINUS_POST]: + lvalue = TemporaryVariable(self._node) + operation = Assignment(lvalue, value, value.type) + self._result.append(operation) + operation = Binary(value, value, Constant("1"), BinaryType.SUBTRACTION) + self._result.append(operation) + set_val(expression, lvalue) + elif expression.type in [UnaryOperationType.PLUS_PRE]: + set_val(expression, value) + elif expression.type in [UnaryOperationType.MINUS_PRE]: + lvalue = TemporaryVariable(self._node) + operation = Binary(lvalue, Constant("0"), value, BinaryType.SUBTRACTION) + self._result.append(operation) + set_val(expression, lvalue) + else: + raise Exception('Unary operation to IR not supported {}'.format(expression)) + diff --git a/tests/arbitrary_send.sol b/tests/arbitrary_send.sol new file mode 100644 index 000000000..3544a903c --- /dev/null +++ b/tests/arbitrary_send.sol @@ -0,0 +1,41 @@ +contract Test{ + + address destination; + + mapping (address => uint) balances; + + constructor(){ + balances[msg.sender] = 0; + } + + function direct(){ + msg.sender.send(this.balance); + } + + function init(){ + destination = msg.sender; + } + + function indirect(){ + destination.send(this.balance); + } + + // these are legitimate calls + // and should not be detected + function repay() payable{ + msg.sender.transfer(msg.value); + } + + function withdraw(){ + uint val = balances[msg.sender]; + msg.sender.send(val); + } + + function buy() payable{ + uint value_send = msg.value; + uint value_spent = 0 ; // simulate a buy of tokens + uint remaining = value_send - value_spent; + msg.sender.send(remaining); +} + +} diff --git a/examples/bugs/backdoor.sol b/tests/backdoor.sol similarity index 100% rename from examples/bugs/backdoor.sol rename to tests/backdoor.sol diff --git a/tests/complex_func.sol b/tests/complex_func.sol new file mode 100644 index 000000000..cdb716efd --- /dev/null +++ b/tests/complex_func.sol @@ -0,0 +1,88 @@ +pragma solidity ^0.4.24; + +contract Complex { + int numberOfSides = 7; + string shape; + uint i0 = 0; + uint i1 = 0; + uint i2 = 0; + uint i3 = 0; + uint i4 = 0; + uint i5 = 0; + uint i6 = 0; + uint i7 = 0; + uint i8 = 0; + uint i9 = 0; + uint i10 = 0; + + + function computeShape() external { + if (numberOfSides <= 2) { + shape = "Cant be a shape!"; + } else if (numberOfSides == 3) { + shape = "Triangle"; + } else if (numberOfSides == 4) { + shape = "Square"; + } else if (numberOfSides == 5) { + shape = "Pentagon"; + } else if (numberOfSides == 6) { + shape = "Hexagon"; + } else if (numberOfSides == 7) { + shape = "Heptagon"; + } else if (numberOfSides == 8) { + shape = "Octagon"; + } else if (numberOfSides == 9) { + shape = "Nonagon"; + } else if (numberOfSides == 10) { + shape = "Decagon"; + } else if (numberOfSides == 11) { + shape = "Hendecagon"; + } else { + shape = "Your shape is more than 11 sides."; + } + } + + function complexExternalWrites() external { + Increment test1 = new Increment(); + test1.increaseBy1(); + test1.increaseBy1(); + test1.increaseBy1(); + test1.increaseBy1(); + test1.increaseBy1(); + + Increment test2 = new Increment(); + test2.increaseBy1(); + + address test3 = new Increment(); + test3.call(bytes4(keccak256("increaseBy2()"))); + + address test4 = new Increment(); + test4.call(bytes4(keccak256("increaseBy2()"))); + } + + function complexStateVars() external { + i0 = 1; + i1 = 1; + i2 = 1; + i3 = 1; + i4 = 1; + i5 = 1; + i6 = 1; + i7 = 1; + i8 = 1; + i9 = 1; + i10 = 1; + } +} + +contract Increment { + uint i = 0; + + function increaseBy1() public { + i += 1; + } + + function increaseBy2() public { + i += 2; + } +} \ No newline at end of file diff --git a/tests/const_state_variables.sol b/tests/const_state_variables.sol new file mode 100644 index 000000000..f518ceb84 --- /dev/null +++ b/tests/const_state_variables.sol @@ -0,0 +1,37 @@ +pragma solidity ^0.4.24; + + +contract A { + + address constant public MY_ADDRESS = 0xE0f5206BBD039e7b0592d8918820024e2a7437b9; + address public myFriendsAddress = 0xc0ffee254729296a45a3885639AC7E10F9d54979; + + uint public used; + uint public test = 5; + + uint constant X = 32**22 + 8; + string constant TEXT1 = "abc"; + string text2 = "xyz"; + + function setUsed() public { + if (msg.sender == MY_ADDRESS) { + used = test; + } + } +} + + +contract B is A { + + address public mySistersAddress = 0x999999cf1046e68e36E1aA2E0E07105eDDD1f08E; + + function () public { + used = 0; + } + + function setUsed(uint a) public { + if (msg.sender == MY_ADDRESS) { + used = a; + } + } +} diff --git a/tests/expected_json/arbitrary_send.arbitrary-send.json b/tests/expected_json/arbitrary_send.arbitrary-send.json new file mode 100644 index 000000000..f8aa83cc1 --- /dev/null +++ b/tests/expected_json/arbitrary_send.arbitrary-send.json @@ -0,0 +1,40 @@ +[ + { + "calls": [ + "msg.sender.send(this.balance)" + ], + "contract": "Test", + "filename": "tests/arbitrary_send.sol", + "func": "direct", + "sourceMapping": [ + { + "filename": "tests/arbitrary_send.sol", + "length": 29, + "lines": [ + 12 + ], + "start": 174 + } + ], + "vuln": "ArbitrarySend" + }, + { + "calls": [ + "destination.send(this.balance)" + ], + "contract": "Test", + "filename": "tests/arbitrary_send.sol", + "func": "indirect", + "sourceMapping": [ + { + "filename": "tests/arbitrary_send.sol", + "length": 30, + "lines": [ + 20 + ], + "start": 307 + } + ], + "vuln": "ArbitrarySend" + } +] \ No newline at end of file diff --git a/tests/expected_json/backdoor.backdoor.json b/tests/expected_json/backdoor.backdoor.json new file mode 100644 index 000000000..9ebfb2d2b --- /dev/null +++ b/tests/expected_json/backdoor.backdoor.json @@ -0,0 +1,16 @@ +[ + { + "contract": "C", + "sourceMapping": { + "filename": "tests/backdoor.sol", + "length": 74, + "lines": [ + 5, + 6, + 7 + ], + "start": 42 + }, + "vuln": "backdoor" + } +] \ No newline at end of file diff --git a/tests/expected_json/backdoor.suicidal.json b/tests/expected_json/backdoor.suicidal.json new file mode 100644 index 000000000..663b33ef4 --- /dev/null +++ b/tests/expected_json/backdoor.suicidal.json @@ -0,0 +1,18 @@ +[ + { + "contract": "C", + "filename": "tests/backdoor.sol", + "func": "i_am_a_backdoor", + "sourceMapping": { + "filename": "tests/backdoor.sol", + "length": 74, + "lines": [ + 5, + 6, + 7 + ], + "start": 42 + }, + "vuln": "SuicidalFunc" + } +] \ No newline at end of file diff --git a/tests/expected_json/const_state_variables.constable-states.json b/tests/expected_json/const_state_variables.constable-states.json new file mode 100644 index 000000000..eb3a44afa --- /dev/null +++ b/tests/expected_json/const_state_variables.constable-states.json @@ -0,0 +1,88 @@ +[ + { + "contract": "B", + "filename": "tests/const_state_variables.sol", + "sourceMapping": [ + { + "filename": "tests/const_state_variables.sol", + "length": 20, + "lines": [ + 10 + ], + "start": 235 + }, + { + "filename": "tests/const_state_variables.sol", + "length": 20, + "lines": [ + 14 + ], + "start": 331 + }, + { + "filename": "tests/const_state_variables.sol", + "length": 76, + "lines": [ + 7 + ], + "start": 130 + }, + { + "filename": "tests/const_state_variables.sol", + "length": 76, + "lines": [ + 26 + ], + "start": 494 + } + ], + "unusedVars": [ + "myFriendsAddress", + "test", + "text2" + ], + "vuln": "ConstStateVariableCandidates" + }, + { + "contract": "B", + "filename": "tests/const_state_variables.sol", + "sourceMapping": [ + { + "filename": "tests/const_state_variables.sol", + "length": 20, + "lines": [ + 10 + ], + "start": 235 + }, + { + "filename": "tests/const_state_variables.sol", + "length": 20, + "lines": [ + 14 + ], + "start": 331 + }, + { + "filename": "tests/const_state_variables.sol", + "length": 76, + "lines": [ + 7 + ], + "start": 130 + }, + { + "filename": "tests/const_state_variables.sol", + "length": 76, + "lines": [ + 26 + ], + "start": 494 + } + ], + "unusedVars": [ + "mySistersAddress" + ], + "vuln": "ConstStateVariableCandidates" + } +] \ No newline at end of file diff --git a/tests/expected_json/external_function.external-function.json b/tests/expected_json/external_function.external-function.json new file mode 100644 index 000000000..e17984067 --- /dev/null +++ b/tests/expected_json/external_function.external-function.json @@ -0,0 +1,71 @@ +[ + { + "contract": "ContractWithFunctionNotCalled", + "filename": "tests/external_function.sol", + "func": "funcNotCalled", + "sourceMapping": { + "filename": "tests/external_function.sol", + "length": 40, + "lines": [ + 21, + 22, + 23 + ], + "start": 351 + }, + "vuln": "ExternalFunc" + }, + { + "contract": "ContractWithFunctionNotCalled", + "filename": "tests/external_function.sol", + "func": "funcNotCalled2", + "sourceMapping": { + "filename": "tests/external_function.sol", + "length": 41, + "lines": [ + 17, + 18, + 19 + ], + "start": 304 + }, + "vuln": "ExternalFunc" + }, + { + "contract": "ContractWithFunctionNotCalled", + "filename": "tests/external_function.sol", + "func": "funcNotCalled3", + "sourceMapping": { + "filename": "tests/external_function.sol", + "length": 41, + "lines": [ + 13, + 14, + 15 + ], + "start": 257 + }, + "vuln": "ExternalFunc" + }, + { + "contract": "ContractWithFunctionNotCalled2", + "filename": "tests/external_function.sol", + "func": "funcNotCalled", + "sourceMapping": { + "filename": "tests/external_function.sol", + "length": 304, + "lines": [ + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39 + ], + "start": 552 + }, + "vuln": "ExternalFunc" + } +] \ No newline at end of file diff --git a/tests/expected_json/inline_assembly_contract.assembly.json b/tests/expected_json/inline_assembly_contract.assembly.json new file mode 100644 index 000000000..376d3a49d --- /dev/null +++ b/tests/expected_json/inline_assembly_contract.assembly.json @@ -0,0 +1,31 @@ +[ + { + "contract": "GetCode", + "filename": "tests/inline_assembly_contract.sol", + "function_name": "at", + "sourceMapping": [ + { + "filename": "tests/inline_assembly_contract.sol", + "length": 628, + "lines": [ + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20 + ], + "start": 191 + } + ], + "vuln": "Assembly" + } +] \ No newline at end of file diff --git a/tests/expected_json/inline_assembly_library.assembly.json b/tests/expected_json/inline_assembly_library.assembly.json new file mode 100644 index 000000000..5183fc8cd --- /dev/null +++ b/tests/expected_json/inline_assembly_library.assembly.json @@ -0,0 +1,58 @@ +[ + { + "contract": "VectorSum", + "filename": "tests/inline_assembly_library.sol", + "function_name": "sumAsm", + "sourceMapping": [ + { + "filename": "tests/inline_assembly_library.sol", + "length": 114, + "lines": [ + 18, + 19, + 20, + 21 + ], + "start": 720 + } + ], + "vuln": "Assembly" + }, + { + "contract": "VectorSum", + "filename": "tests/inline_assembly_library.sol", + "function_name": "sumPureAsm", + "sourceMapping": [ + { + "filename": "tests/inline_assembly_library.sol", + "length": 677, + "lines": [ + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47 + ], + "start": 1000 + } + ], + "vuln": "Assembly" + } +] \ No newline at end of file diff --git a/tests/expected_json/locked_ether.locked-ether.json b/tests/expected_json/locked_ether.locked-ether.json new file mode 100644 index 000000000..10aa6213a --- /dev/null +++ b/tests/expected_json/locked_ether.locked-ether.json @@ -0,0 +1,21 @@ +[ + { + "contract": "OnlyLocked", + "functions_payable": [ + "receive" + ], + "sourceMapping": [ + { + "filename": "tests/locked_ether.sol", + "length": 72, + "lines": [ + 4, + 5, + 6 + ], + "start": 46 + } + ], + "vuln": "LockedEther" + } +] \ No newline at end of file diff --git a/tests/expected_json/low_level_calls.low-level-calls.json b/tests/expected_json/low_level_calls.low-level-calls.json new file mode 100644 index 000000000..eb0e7529f --- /dev/null +++ b/tests/expected_json/low_level_calls.low-level-calls.json @@ -0,0 +1,18 @@ +[ + { + "contract": "Sender", + "filename": "tests/low_level_calls.sol", + "function_name": "send", + "sourceMapping": [ + { + "filename": "tests/low_level_calls.sol", + "length": 43, + "lines": [ + 6 + ], + "start": 100 + } + ], + "vuln": "Low level call" + } +] \ No newline at end of file diff --git a/tests/expected_json/naming_convention.naming-convention.json b/tests/expected_json/naming_convention.naming-convention.json new file mode 100644 index 000000000..51086bc23 --- /dev/null +++ b/tests/expected_json/naming_convention.naming-convention.json @@ -0,0 +1,223 @@ +[ + { + "contract": "T", + "filename": "tests/naming_convention.sol", + "sourceMapping": { + "filename": "tests/naming_convention.sol", + "length": 17, + "lines": [ + 56 + ], + "start": 695 + }, + "variable": "_myPublicVar", + "vuln": "NamingConvention" + }, + { + "contract": "naming", + "filename": "tests/naming_convention.sol", + "sourceMapping": { + "filename": "tests/naming_convention.sol", + "length": 598, + "lines": [ + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48 + ], + "start": 26 + }, + "vuln": "NamingConvention" + }, + { + "contract": "naming", + "filename": "tests/naming_convention.sol", + "sourceMapping": { + "filename": "tests/naming_convention.sol", + "length": 16, + "lines": [ + 11 + ], + "start": 183 + }, + "variable": "Var_One", + "vuln": "NamingConvention" + }, + { + "contract": "naming", + "filename": "tests/naming_convention.sol", + "sourceMapping": { + "filename": "tests/naming_convention.sol", + "length": 20, + "lines": [ + 14, + 15, + 16 + ], + "start": 227 + }, + "struct": "test", + "vuln": "NamingConvention" + }, + { + "contract": "naming", + "filename": "tests/naming_convention.sol", + "modifier": "CantDo", + "sourceMapping": { + "filename": "tests/naming_convention.sol", + "length": 36, + "lines": [ + 41, + 42, + 43 + ], + "start": 545 + }, + "vuln": "NamingConvention" + }, + { + "contract": "naming", + "filename": "tests/naming_convention.sol", + "function": "GetOne", + "sourceMapping": { + "filename": "tests/naming_convention.sol", + "length": 71, + "lines": [ + 30, + 31, + 32, + 33 + ], + "start": 405 + }, + "vuln": "NamingConvention" + }, + { + "contract": "naming", + "event": "event_", + "filename": "tests/naming_convention.sol", + "sourceMapping": { + "filename": "tests/naming_convention.sol", + "length": 19, + "lines": [ + 23 + ], + "start": 303 + }, + "vuln": "NamingConvention" + }, + { + "contract": "naming", + "enum": "numbers", + "filename": "tests/naming_convention.sol", + "sourceMapping": { + "filename": "tests/naming_convention.sol", + "length": 23, + "lines": [ + 6 + ], + "start": 77 + }, + "vuln": "NamingConvention" + }, + { + "constant": "MY_other_CONSTANT", + "contract": "naming", + "filename": "tests/naming_convention.sol", + "sourceMapping": { + "filename": "tests/naming_convention.sol", + "length": 35, + "lines": [ + 9 + ], + "start": 141 + }, + "vuln": "NamingConvention" + }, + { + "constant": "l", + "contract": "T", + "filename": "tests/naming_convention.sol", + "sourceMapping": { + "filename": "tests/naming_convention.sol", + "length": 10, + "lines": [ + 67 + ], + "start": 847 + }, + "vuln": "NamingConvention" + }, + { + "argument": "Number2", + "contract": "naming", + "filename": "tests/naming_convention.sol", + "function": "setInt", + "sourceMapping": { + "filename": "tests/naming_convention.sol", + "length": 12, + "lines": [ + 35 + ], + "start": 512 + }, + "vuln": "NamingConvention" + }, + { + "argument": "_used", + "contract": "T", + "filename": "tests/naming_convention.sol", + "function": "test", + "sourceMapping": { + "filename": "tests/naming_convention.sol", + "length": 10, + "lines": [ + 59 + ], + "start": 748 + }, + "vuln": "NamingConvention" + } +] \ No newline at end of file diff --git a/tests/expected_json/old_solc.sol.json.solc-version.json b/tests/expected_json/old_solc.sol.json.solc-version.json new file mode 100644 index 000000000..6b5dd3c6f --- /dev/null +++ b/tests/expected_json/old_solc.sol.json.solc-version.json @@ -0,0 +1,16 @@ +[ + { + "pragma": [ + "0.4.21" + ], + "sourceMapping": [ + { + "filename": "old_solc.sol", + "length": 23, + "lines": [], + "start": 0 + } + ], + "vuln": "OldPragma" + } +] \ No newline at end of file diff --git a/tests/expected_json/pragma.0.4.24.pragma.json b/tests/expected_json/pragma.0.4.24.pragma.json new file mode 100644 index 000000000..0601053f2 --- /dev/null +++ b/tests/expected_json/pragma.0.4.24.pragma.json @@ -0,0 +1,27 @@ +[ + { + "sourceMapping": [ + { + "filename": "tests/pragma.0.4.23.sol", + "length": 24, + "lines": [ + 1 + ], + "start": 0 + }, + { + "filename": "tests/pragma.0.4.24.sol", + "length": 23, + "lines": [ + 1 + ], + "start": 0 + } + ], + "versions": [ + "0.4.24", + "^0.4.23" + ], + "vuln": "ConstantPragma" + } +] \ No newline at end of file diff --git a/tests/expected_json/reentrancy.reentrancy.json b/tests/expected_json/reentrancy.reentrancy.json new file mode 100644 index 000000000..26532ccf4 --- /dev/null +++ b/tests/expected_json/reentrancy.reentrancy.json @@ -0,0 +1,47 @@ +[ + { + "calls": [ + "! (msg.sender.call.value(userBalance[msg.sender])())" + ], + "contract": "Reentrancy", + "filename": "tests/reentrancy.sol", + "function_name": "withdrawBalance", + "send_eth": [ + "! (msg.sender.call.value(userBalance[msg.sender])())" + ], + "sourceMapping": [ + { + "filename": "tests/reentrancy.sol", + "length": 37, + "lines": [ + 4 + ], + "start": 52 + }, + { + "filename": "tests/reentrancy.sol", + "length": 92, + "lines": [ + 17, + 18, + 19 + ], + "start": 478 + }, + { + "filename": "tests/reentrancy.sol", + "length": 92, + "lines": [ + 17, + 18, + 19 + ], + "start": 478 + } + ], + "varsWritten": [ + "userBalance" + ], + "vuln": "Reentrancy" + } +] \ No newline at end of file diff --git a/tests/expected_json/tx_origin.tx-origin.json b/tests/expected_json/tx_origin.tx-origin.json new file mode 100644 index 000000000..f88642adf --- /dev/null +++ b/tests/expected_json/tx_origin.tx-origin.json @@ -0,0 +1,36 @@ +[ + { + "contract": "TxOrigin", + "filename": "tests/tx_origin.sol", + "function_name": "bug0", + "sourceMapping": [ + { + "filename": "tests/tx_origin.sol", + "length": 27, + "lines": [ + 10 + ], + "start": 140 + } + ], + "vuln": "TxOrigin" + }, + { + "contract": "TxOrigin", + "filename": "tests/tx_origin.sol", + "function_name": "bug2", + "sourceMapping": [ + { + "filename": "tests/tx_origin.sol", + "length": 57, + "lines": [ + 14, + 15, + 16 + ], + "start": 206 + } + ], + "vuln": "TxOrigin" + } +] \ No newline at end of file diff --git a/tests/expected_json/uninitialized.uninitialized-state.json b/tests/expected_json/uninitialized.uninitialized-state.json new file mode 100644 index 000000000..e57b7fe44 --- /dev/null +++ b/tests/expected_json/uninitialized.uninitialized-state.json @@ -0,0 +1,120 @@ +[ + { + "contract": "Test", + "filename": "tests/uninitialized.sol", + "functions": [ + "use" + ], + "sourceMapping": [ + { + "filename": "tests/uninitialized.sol", + "length": 34, + "lines": [ + 15 + ], + "start": 189 + }, + { + "filename": "tests/uninitialized.sol", + "length": 143, + "lines": [ + 23, + 24, + 25, + 26 + ], + "start": 356 + } + ], + "variable": "balances", + "vuln": "UninitializedStateVars" + }, + { + "contract": "Test2", + "filename": "tests/uninitialized.sol", + "functions": [ + "use" + ], + "sourceMapping": [ + { + "filename": "tests/uninitialized.sol", + "length": 15, + "lines": [ + 45 + ], + "start": 695 + }, + { + "filename": "tests/uninitialized.sol", + "length": 117, + "lines": [ + 53, + 54, + 55, + 56 + ], + "start": 875 + } + ], + "variable": "st", + "vuln": "UninitializedStateVars" + }, + { + "contract": "Test2", + "filename": "tests/uninitialized.sol", + "functions": [ + "init" + ], + "sourceMapping": [ + { + "filename": "tests/uninitialized.sol", + "length": 6, + "lines": [ + 47 + ], + "start": 748 + }, + { + "filename": "tests/uninitialized.sol", + "length": 52, + "lines": [ + 49, + 50, + 51 + ], + "start": 817 + } + ], + "variable": "v", + "vuln": "UninitializedStateVars" + }, + { + "contract": "Uninitialized", + "filename": "tests/uninitialized.sol", + "functions": [ + "transfer" + ], + "sourceMapping": [ + { + "filename": "tests/uninitialized.sol", + "length": 19, + "lines": [ + 5 + ], + "start": 55 + }, + { + "filename": "tests/uninitialized.sol", + "length": 82, + "lines": [ + 7, + 8, + 9 + ], + "start": 81 + } + ], + "variable": "destination", + "vuln": "UninitializedStateVars" + } +] \ No newline at end of file diff --git a/tests/expected_json/uninitialized_local_variable.uninitialized-local.json b/tests/expected_json/uninitialized_local_variable.uninitialized-local.json new file mode 100644 index 000000000..b6d922b15 --- /dev/null +++ b/tests/expected_json/uninitialized_local_variable.uninitialized-local.json @@ -0,0 +1,31 @@ +[ + { + "contract": "Uninitialized", + "filename": "tests/uninitialized_local_variable.sol", + "function": "func", + "sourceMapping": [ + { + "filename": "tests/uninitialized_local_variable.sol", + "length": 18, + "lines": [ + 4 + ], + "start": 77 + }, + { + "filename": "tests/uninitialized_local_variable.sol", + "length": 143, + "lines": [ + 3, + 4, + 5, + 6, + 7 + ], + "start": 29 + } + ], + "variable": "uint_not_init", + "vuln": "UninitializedLocalVars" + } +] \ No newline at end of file diff --git a/tests/expected_json/uninitialized_storage_pointer.uninitialized-storage.json b/tests/expected_json/uninitialized_storage_pointer.uninitialized-storage.json new file mode 100644 index 000000000..eefb6befb --- /dev/null +++ b/tests/expected_json/uninitialized_storage_pointer.uninitialized-storage.json @@ -0,0 +1,32 @@ +[ + { + "contract": "Uninitialized", + "filename": "tests/uninitialized_storage_pointer.sol", + "function": "func", + "sourceMapping": [ + { + "filename": "tests/uninitialized_storage_pointer.sol", + "length": 9, + "lines": [ + 10 + ], + "start": 171 + }, + { + "filename": "tests/uninitialized_storage_pointer.sol", + "length": 138, + "lines": [ + 7, + 8, + 9, + 10, + 11, + 12 + ], + "start": 67 + } + ], + "variable": "st_bug", + "vuln": "UninitializedStorageVars" + } +] \ No newline at end of file diff --git a/tests/expected_json/unused_state.unused-state.json b/tests/expected_json/unused_state.unused-state.json new file mode 100644 index 000000000..b3941f53c --- /dev/null +++ b/tests/expected_json/unused_state.unused-state.json @@ -0,0 +1,20 @@ +[ + { + "contract": "B", + "filename": "tests/unused_state.sol", + "sourceMapping": [ + { + "filename": "tests/unused_state.sol", + "length": 14, + "lines": [ + 4 + ], + "start": 41 + } + ], + "unusedVars": [ + "unused" + ], + "vuln": "unusedStateVars" + } +] \ No newline at end of file diff --git a/tests/external_function.sol b/tests/external_function.sol new file mode 100644 index 000000000..040bb9329 --- /dev/null +++ b/tests/external_function.sol @@ -0,0 +1,70 @@ +pragma solidity ^0.4.24; + +import "./external_function_test_2.sol"; + +contract ContractWithFunctionCalledSuper is ContractWithFunctionCalled { + function callWithSuper() public { + uint256 i = 0; + } +} + +contract ContractWithFunctionNotCalled { + + function funcNotCalled3() public { + + } + + function funcNotCalled2() public { + + } + + function funcNotCalled() public { + + } + + function my_func() internal returns(bool){ + return true; + } + +} + +contract ContractWithFunctionNotCalled2 is ContractWithFunctionCalledSuper { + function funcNotCalled() public { + uint256 i = 0; + address three = new ContractWithFunctionNotCalled(); + three.call(bytes4(keccak256("helloTwo()"))); + super.callWithSuper(); + ContractWithFunctionCalled c = new ContractWithFunctionCalled(); + c.funcCalled(); + } +} + +contract InternalCall { + + function() returns(uint) ptr; + + function set_test1() external{ + ptr = test1; + } + + function set_test2() external{ + ptr = test2; + } + + function test1() public returns(uint){ + return 1; + } + + function test2() public returns(uint){ + return 2; + } + + function test3() public returns(uint){ + return 3; + } + + function exec() external returns(uint){ + return ptr(); + } + +} diff --git a/tests/external_function_test_2.sol b/tests/external_function_test_2.sol new file mode 100644 index 000000000..406494631 --- /dev/null +++ b/tests/external_function_test_2.sol @@ -0,0 +1,7 @@ +pragma solidity ^0.4.24; + +contract ContractWithFunctionCalled { + function funcCalled() external { + uint256 i = 0; + } +} diff --git a/tests/inline_assembly_contract.sol b/tests/inline_assembly_contract.sol new file mode 100644 index 000000000..fd5aa7942 --- /dev/null +++ b/tests/inline_assembly_contract.sol @@ -0,0 +1,22 @@ +pragma solidity ^0.4.0; + +// taken from https://solidity.readthedocs.io/en/v0.4.25/assembly.html + +library GetCode { + function at(address _addr) public view returns (bytes o_code) { + assembly { + // retrieve the size of the code, this needs assembly + let size := extcodesize(_addr) + // allocate output byte array - this could also be done without assembly + // by using o_code = new bytes(size) + o_code := mload(0x40) + // new "memory end" including padding + mstore(0x40, add(o_code, and(add(add(size, 0x20), 0x1f), not(0x1f)))) + // store length in memory + mstore(o_code, size) + // actually retrieve the code, this needs assembly + extcodecopy(_addr, add(o_code, 0x20), 0, size) + } + } +} + diff --git a/tests/inline_assembly_library.sol b/tests/inline_assembly_library.sol new file mode 100644 index 000000000..ce7aef7a5 --- /dev/null +++ b/tests/inline_assembly_library.sol @@ -0,0 +1,49 @@ +pragma solidity ^0.4.16; + +// taken from https://solidity.readthedocs.io/en/v0.4.25/assembly.html + +library VectorSum { + // This function is less efficient because the optimizer currently fails to + // remove the bounds checks in array access. + function sumSolidity(uint[] _data) public view returns (uint o_sum) { + for (uint i = 0; i < _data.length; ++i) + o_sum += _data[i]; + } + + // We know that we only access the array in bounds, so we can avoid the check. + // 0x20 needs to be added to an array because the first slot contains the + // array length. + function sumAsm(uint[] _data) public view returns (uint o_sum) { + for (uint i = 0; i < _data.length; ++i) { + assembly { + o_sum := add(o_sum, mload(add(add(_data, 0x20), mul(i, 0x20)))) + } + } + } + + // Same as above, but accomplish the entire code within inline assembly. + function sumPureAsm(uint[] _data) public view returns (uint o_sum) { + assembly { + // Load the length (first 32 bytes) + let len := mload(_data) + + // Skip over the length field. + // + // Keep temporary variable so it can be incremented in place. + // + // NOTE: incrementing _data would result in an unusable + // _data variable after this assembly block + let data := add(_data, 0x20) + + // Iterate until the bound is not met. + for + { let end := add(data, len) } + lt(data, end) + { data := add(data, 0x20) } + { + o_sum := add(o_sum, mload(data)) + } + } + } +} + diff --git a/tests/locked_ether.sol b/tests/locked_ether.sol new file mode 100644 index 000000000..1e9e57c7d --- /dev/null +++ b/tests/locked_ether.sol @@ -0,0 +1,26 @@ +pragma solidity 0.4.24; +contract Locked{ + + function receive() payable public{ + require(msg.value > 0); + } + +} + +contract Send{ + address owner = msg.sender; + + function withdraw() public{ + owner.transfer(address(this).balance); + } +} + +contract Unlocked is Locked, Send{ + + function withdraw() public{ + super.withdraw(); + } + +} + +contract OnlyLocked is Locked{ } diff --git a/tests/low_level_calls.sol b/tests/low_level_calls.sol new file mode 100644 index 000000000..c5f1e2e46 --- /dev/null +++ b/tests/low_level_calls.sol @@ -0,0 +1,17 @@ +pragma solidity ^0.4.24; + + +contract Sender { + function send(address _receiver) payable { + _receiver.call.value(msg.value).gas(7777)(); + } +} + + +contract Receiver { + uint public balance = 0; + + function () payable { + balance += msg.value; + } +} \ No newline at end of file diff --git a/tests/naming_convention.sol b/tests/naming_convention.sol new file mode 100644 index 000000000..f632aa69e --- /dev/null +++ b/tests/naming_convention.sol @@ -0,0 +1,68 @@ +pragma solidity ^0.4.24; + +contract naming { + + enum Numbers {ONE, TWO} + enum numbers {ONE, TWO} + + uint constant MY_CONSTANT = 1; + uint constant MY_other_CONSTANT = 2; + + uint Var_One = 1; + uint varTwo = 2; + + struct test { + + } + + struct Test { + + } + + event Event_(uint); + event event_(uint); + + function getOne() constant returns (uint) + { + return 1; + } + + function GetOne() constant returns (uint) + { + return 1; + } + + function setInt(uint number1, uint Number2) + { + + } + + + modifier CantDo() { + _; + } + + modifier canDo() { + _; + } +} + +contract Test { + +} + +contract T { + uint private _myPrivateVar; + uint _myPublicVar; + + + function test(uint _unused, uint _used) returns(uint){ + return _used;} + + + uint k = 1; + + uint constant M = 1; + + uint l = 1; +} diff --git a/tests/old_solc.sol b/tests/old_solc.sol new file mode 100644 index 000000000..291d4c96a --- /dev/null +++ b/tests/old_solc.sol @@ -0,0 +1,5 @@ +pragma solidity 0.4.21; + +contract Contract{ + +} diff --git a/tests/old_solc.sol.json b/tests/old_solc.sol.json new file mode 100644 index 000000000..2b784ba3a --- /dev/null +++ b/tests/old_solc.sol.json @@ -0,0 +1,67 @@ +JSON AST: + + +======= old_solc.sol ======= +{ + "attributes" : + { + "absolutePath" : "old_solc.sol", + "exportedSymbols" : + { + "Contract" : + [ + 2 + ] + } + }, + "children" : + [ + { + "attributes" : + { + "literals" : + [ + "solidity", + "0.4", + ".21" + ] + }, + "id" : 1, + "name" : "PragmaDirective", + "src" : "0:23:0" + }, + { + "attributes" : + { + "baseContracts" : + [ + null + ], + "contractDependencies" : + [ + null + ], + "contractKind" : "contract", + "documentation" : null, + "fullyImplemented" : true, + "linearizedBaseContracts" : + [ + 2 + ], + "name" : "Contract", + "nodes" : + [ + null + ], + "scope" : 3 + }, + "id" : 2, + "name" : "ContractDefinition", + "src" : "25:21:0" + } + ], + "id" : 3, + "name" : "SourceUnit", + "src" : "0:47:0" +} +======= old_solc.sol:Contract ======= diff --git a/tests/pragma.0.4.23.sol b/tests/pragma.0.4.23.sol new file mode 100644 index 000000000..6e6a5000f --- /dev/null +++ b/tests/pragma.0.4.23.sol @@ -0,0 +1 @@ +pragma solidity ^0.4.23; diff --git a/tests/pragma.0.4.24.sol b/tests/pragma.0.4.24.sol new file mode 100644 index 000000000..cd946840b --- /dev/null +++ b/tests/pragma.0.4.24.sol @@ -0,0 +1,5 @@ +pragma solidity 0.4.24; + +import "./pragma.0.4.23.sol"; + +contract Test{} diff --git a/tests/reentrancy.sol b/tests/reentrancy.sol new file mode 100644 index 000000000..9f79587a2 --- /dev/null +++ b/tests/reentrancy.sol @@ -0,0 +1,51 @@ +pragma solidity ^0.4.24; + +contract Reentrancy { + mapping (address => uint) userBalance; + + function getBalance(address u) view public returns(uint){ + return userBalance[u]; + } + + function addToBalance() payable public{ + userBalance[msg.sender] += msg.value; + } + + function withdrawBalance() public{ + // send userBalance[msg.sender] ethers to msg.sender + // if mgs.sender is a contract, it will call its fallback function + if( ! (msg.sender.call.value(userBalance[msg.sender])() ) ){ + revert(); + } + userBalance[msg.sender] = 0; + } + + function withdrawBalance_fixed() public{ + // To protect against re-entrancy, the state variable + // has to be change before the call + uint amount = userBalance[msg.sender]; + userBalance[msg.sender] = 0; + if( ! (msg.sender.call.value(amount)() ) ){ + revert(); + } + } + + function withdrawBalance_fixed_2() public{ + // send() and transfer() are safe against reentrancy + // they do not transfer the remaining gas + // and they give just enough gas to execute few instructions + // in the fallback function (no further call possible) + msg.sender.transfer(userBalance[msg.sender]); + userBalance[msg.sender] = 0; + } + + function withdrawBalance_fixed_3() public{ + // The state can be changed + // But it is fine, as it can only occur if the transaction fails + uint amount = userBalance[msg.sender]; + userBalance[msg.sender] = 0; + if( ! (msg.sender.call.value(amount)() ) ){ + userBalance[msg.sender] = amount; + } + } +} diff --git a/tests/reentrancy_indirect.sol b/tests/reentrancy_indirect.sol new file mode 100644 index 000000000..73dd3db00 --- /dev/null +++ b/tests/reentrancy_indirect.sol @@ -0,0 +1,31 @@ +pragma solidity ^0.4.24; + +contract Token{ + function transfer(address to, uint value) returns(bool); + function transferFrom(address from, address to, uint value) returns(bool); +} + +contract Reentrancy { + + mapping(address => mapping(address => uint)) eth_deposed; + mapping(address => mapping(address => uint)) token_deposed; + + function deposit_eth(address token) payable{ + eth_deposed[token][msg.sender] += msg.value; + } + + function deposit_token(address token, uint value){ + token_deposed[token][msg.sender] += value; + require(Token(token).transferFrom(msg.sender, address(this), value)); + } + + function withdraw(address token){ + msg.sender.transfer(eth_deposed[token][msg.sender]); + require(Token(token).transfer(msg.sender, token_deposed[token][msg.sender])); + + eth_deposed[token][msg.sender] = 0; + token_deposed[token][msg.sender] = 0; + + } + +} diff --git a/tests/taint_mapping.sol b/tests/taint_mapping.sol new file mode 100644 index 000000000..c1a24ed38 --- /dev/null +++ b/tests/taint_mapping.sol @@ -0,0 +1,18 @@ +contract Test{ + + mapping(uint => mapping(uint => address)) authorized_destination; + + address destination; + + function init(){ + authorized_destination[0][0] = msg.sender; + } + + function setup(uint idx){ + destination = authorized_destination[0][0]; + } + + function withdraw(){ + destination.transfer(this.balance); + } +} diff --git a/tests/tx_origin.sol b/tests/tx_origin.sol new file mode 100644 index 000000000..93bb5c757 --- /dev/null +++ b/tests/tx_origin.sol @@ -0,0 +1,26 @@ +pragma solidity ^0.4.24; + +contract TxOrigin { + + address owner; + + constructor() { owner = msg.sender; } + + function bug0() { + require(tx.origin == owner); + } + + function bug2() { + if (tx.origin != owner) { + revert(); + } + } + + function legit0(){ + require(tx.origin == msg.sender); + } + + function legit1(){ + tx.origin.transfer(this.balance); + } +} diff --git a/tests/uninitialized.sol b/tests/uninitialized.sol new file mode 100644 index 000000000..2305483e6 --- /dev/null +++ b/tests/uninitialized.sol @@ -0,0 +1,58 @@ +pragma solidity ^0.4.24; + +contract Uninitialized{ + + address destination; + + function transfer() payable public{ + destination.transfer(msg.value); + } + +} + + +contract Test { + mapping (address => uint) balances; + mapping (address => uint) balancesInitialized; + + + function init() { + balancesInitialized[msg.sender] = 0; + } + + function use() { + // random operation to use the mapping + require(balances[msg.sender] == balancesInitialized[msg.sender]); + } +} + +library Lib{ + + struct MyStruct{ + uint val; + } + + function set(MyStruct storage st, uint v){ + st.val = 4; + } + +} + + +contract Test2 { + using Lib for Lib.MyStruct; + + Lib.MyStruct st; + Lib.MyStruct stInitiliazed; + uint v; // v is used as parameter of the lib, but is never init + + function init(){ + stInitiliazed.set(v); + } + + function use(){ + // random operation to use the structure + require(st.val == stInitiliazed.val); + } + +} diff --git a/tests/uninitialized_storage_pointer.sol b/tests/uninitialized_storage_pointer.sol new file mode 100644 index 000000000..e494cfe68 --- /dev/null +++ b/tests/uninitialized_storage_pointer.sol @@ -0,0 +1,14 @@ +contract Uninitialized{ + + struct St{ + uint a; + } + + function func() { + St st; // non init, but never read so its fine + St memory st2; + St st_bug; + st_bug.a; + } + +} diff --git a/tests/unused_state.sol b/tests/unused_state.sol new file mode 100644 index 000000000..7d3e5875a --- /dev/null +++ b/tests/unused_state.sol @@ -0,0 +1,13 @@ +pragma solidity 0.4.24; + +contract A{ + address unused; + address used; +} + +contract B is A{ + + function () public{ + used = 0; + } +}