From c8d91b6f417e1db94269f651defba8cad5c25a72 Mon Sep 17 00:00:00 2001 From: Nikhil Parasaram Date: Wed, 21 Aug 2019 17:45:19 +0530 Subject: [PATCH] Add the rest of the precompiles (#990) * Support modular exponentiation for concrete data * Add elliptic curve operations * Fix type hints and refactor code * Support usage of the rest of the native contracts * Remove unused imports * Add tests for elliptic curve functions * Use a constant for native functions count * Update py_ecc version * Use PRECOMPILE_COUNT over hardcoded value * Use shortened list comprehensives * Clean up imports * Use precompile count in checking precompile calls * Refactor code --- mythril/analysis/call_helpers.py | 3 +- mythril/analysis/modules/external_calls.py | 3 +- mythril/analysis/symbolic.py | 6 +- mythril/laser/ethereum/call.py | 13 +- mythril/laser/ethereum/natives.py | 168 +++++++++++++----- mythril/laser/ethereum/util.py | 24 +++ requirements.txt | 2 +- setup.py | 2 +- tests/laser/Precompiles/test_ec_add.py | 27 +++ .../laser/Precompiles/test_elliptic_curves.py | 35 ++++ tests/laser/Precompiles/test_elliptic_mul.py | 27 +++ tests/laser/Precompiles/test_mod_exp.py | 61 +++++++ 12 files changed, 321 insertions(+), 50 deletions(-) create mode 100644 tests/laser/Precompiles/test_ec_add.py create mode 100644 tests/laser/Precompiles/test_elliptic_curves.py create mode 100644 tests/laser/Precompiles/test_elliptic_mul.py create mode 100644 tests/laser/Precompiles/test_mod_exp.py diff --git a/mythril/analysis/call_helpers.py b/mythril/analysis/call_helpers.py index 6cb796df..270ff5af 100644 --- a/mythril/analysis/call_helpers.py +++ b/mythril/analysis/call_helpers.py @@ -4,6 +4,7 @@ from typing import Union from mythril.analysis.ops import VarType, Call, get_variable from mythril.laser.ethereum.state.global_state import GlobalState +from mythril.laser.ethereum.natives import PRECOMPILE_COUNT def get_call_from_state(state: GlobalState) -> Union[Call, None]: @@ -28,7 +29,7 @@ def get_call_from_state(state: GlobalState) -> Union[Call, None]: get_variable(stack[-7]), ) - if to.type == VarType.CONCRETE and 0 < to.val < 5: + if to.type == VarType.CONCRETE and 0 < to.val <= PRECOMPILE_COUNT: return None if meminstart.type == VarType.CONCRETE and meminsz.type == VarType.CONCRETE: diff --git a/mythril/analysis/modules/external_calls.py b/mythril/analysis/modules/external_calls.py index e80c8802..81cd6bd2 100644 --- a/mythril/analysis/modules/external_calls.py +++ b/mythril/analysis/modules/external_calls.py @@ -10,6 +10,7 @@ from mythril.laser.ethereum.transaction.transaction_models import ( from mythril.analysis.modules.base import DetectionModule from mythril.analysis.report import Issue from mythril.laser.smt import UGT, symbol_factory, Or, BitVec +from mythril.laser.ethereum.natives import PRECOMPILE_COUNT from mythril.laser.ethereum.state.global_state import GlobalState from mythril.exceptions import UnsatError from copy import copy @@ -33,7 +34,7 @@ def _is_precompile_call(global_state: GlobalState): constraints += [ Or( to < symbol_factory.BitVecVal(1, 256), - to > symbol_factory.BitVecVal(16, 256), + to > symbol_factory.BitVecVal(PRECOMPILE_COUNT, 256), ) ] diff --git a/mythril/analysis/symbolic.py b/mythril/analysis/symbolic.py index 02b5c8ca..ce98b0ce 100644 --- a/mythril/analysis/symbolic.py +++ b/mythril/analysis/symbolic.py @@ -14,6 +14,7 @@ from mythril.laser.ethereum.strategy.basic import ( BasicSearchStrategy, ) +from mythril.laser.ethereum.natives import PRECOMPILE_COUNT from mythril.laser.ethereum.transaction.symbolic import ( ATTACKER_ADDRESS, CREATOR_ADDRESS, @@ -212,7 +213,10 @@ class SymExecWrapper: get_variable(stack[-7]), ) - if to.type == VarType.CONCRETE and to.val < 5: + if ( + to.type == VarType.CONCRETE + and 0 < to.val <= PRECOMPILE_COUNT + ): # ignore prebuilts continue diff --git a/mythril/laser/ethereum/call.py b/mythril/laser/ethereum/call.py index bc1c360c..24a8e2f3 100644 --- a/mythril/laser/ethereum/call.py +++ b/mythril/laser/ethereum/call.py @@ -10,13 +10,14 @@ import mythril.laser.ethereum.util as util from mythril.laser.ethereum import natives from mythril.laser.ethereum.gas import OPCODE_GAS from mythril.laser.ethereum.state.account import Account +from mythril.laser.ethereum.natives import PRECOMPILE_COUNT from mythril.laser.ethereum.state.calldata import ( BaseCalldata, SymbolicCalldata, ConcreteCalldata, ) from mythril.laser.ethereum.state.global_state import GlobalState -from mythril.laser.smt import BitVec, Bool, is_true +from mythril.laser.smt import BitVec, is_true from mythril.laser.smt import simplify, Expression, symbol_factory from mythril.support.loader import DynLoader @@ -51,7 +52,7 @@ def get_call_parameters( call_data = get_call_data(global_state, memory_input_offset, memory_input_size) if ( isinstance(callee_address, BitVec) - or int(callee_address, 16) >= 5 + or int(callee_address, 16) > PRECOMPILE_COUNT or int(callee_address, 16) == 0 ): callee_account = get_callee_account( @@ -223,8 +224,12 @@ def native_call( call_data: BaseCalldata, memory_out_offset: Union[int, Expression], memory_out_size: Union[int, Expression], -) -> Union[List[GlobalState], None]: - if isinstance(callee_address, BitVec) or not 0 < int(callee_address, 16) < 5: +) -> Optional[List[GlobalState]]: + + if ( + isinstance(callee_address, BitVec) + or not 0 < int(callee_address, 16) <= PRECOMPILE_COUNT + ): return None log.debug("Native contract called: " + callee_address) diff --git a/mythril/laser/ethereum/natives.py b/mythril/laser/ethereum/natives.py index ef38d8c9..991a172c 100644 --- a/mythril/laser/ethereum/natives.py +++ b/mythril/laser/ethereum/natives.py @@ -6,12 +6,20 @@ from typing import List, Union from ethereum.utils import ecrecover_to_pub from py_ecc.secp256k1 import N as secp256k1n +import py_ecc.optimized_bn128 as bn128 from rlp.utils import ALL_BYTES from mythril.laser.ethereum.state.calldata import BaseCalldata, ConcreteCalldata -from mythril.laser.ethereum.util import bytearray_to_int -from ethereum.utils import sha3 -from mythril.laser.smt import Concat, simplify +from mythril.laser.ethereum.util import extract_copy, extract32 +from ethereum.utils import ( + sha3, + big_endian_to_int, + safe_ord, + zpad, + int_to_big_endian, + encode_int32, +) +from ethereum.specials import validate_point log = logging.getLogger(__name__) @@ -22,35 +30,6 @@ class NativeContractException(Exception): pass -def int_to_32bytes( - i: int -) -> bytes: # used because int can't fit as bytes function's input - """ - - :param i: - :return: - """ - o = [0] * 32 - for x in range(32): - o[31 - x] = i & 0xFF - i >>= 8 - return bytes(o) - - -def extract32(data: bytearray, i: int) -> int: - """ - - :param data: - :param i: - :return: - """ - if i >= len(data): - return 0 - o = data[i : min(i + 32, len(data))] - o.extend(bytearray(32 - len(o))) - return bytearray_to_int(o) - - def ecrecover(data: List[int]) -> List[int]: """ @@ -59,14 +38,14 @@ def ecrecover(data: List[int]) -> List[int]: """ # TODO: Add type hints try: - byte_data = bytearray(data) - v = extract32(byte_data, 32) - r = extract32(byte_data, 64) - s = extract32(byte_data, 96) + bytes_data = bytearray(data) + v = extract32(bytes_data, 32) + r = extract32(bytes_data, 64) + s = extract32(bytes_data, 96) except TypeError: raise NativeContractException - message = b"".join([ALL_BYTES[x] for x in byte_data[0:32]]) + message = b"".join([ALL_BYTES[x] for x in bytes_data[0:32]]) if r >= secp256k1n or s >= secp256k1n or v < 27 or v > 28: return [] try: @@ -85,10 +64,10 @@ def sha256(data: List[int]) -> List[int]: :return: """ try: - byte_data = bytes(data) + bytes_data = bytes(data) except TypeError: raise NativeContractException - return list(bytearray(hashlib.sha256(byte_data).digest())) + return list(bytearray(hashlib.sha256(bytes_data).digest())) def ripemd160(data: List[int]) -> List[int]: @@ -120,6 +99,114 @@ def identity(data: List[int]) -> List[int]: return data +def mod_exp(data: List[int]) -> List[int]: + """ + TODO: Some symbolic parts can be handled here + Modular Exponentiation + :param data: Data with + :return: modular exponentiation + """ + bytes_data = bytearray(data) + baselen = extract32(bytes_data, 0) + explen = extract32(bytes_data, 32) + modlen = extract32(bytes_data, 64) + if baselen == 0: + return [0] * modlen + if modlen == 0: + return [] + + first_exp_bytes = extract32(bytes_data, 96 + baselen) >> (8 * max(32 - explen, 0)) + bitlength = -1 + while first_exp_bytes: + bitlength += 1 + first_exp_bytes >>= 1 + + base = bytearray(baselen) + extract_copy(bytes_data, base, 0, 96, baselen) + exp = bytearray(explen) + extract_copy(bytes_data, exp, 0, 96 + baselen, explen) + mod = bytearray(modlen) + extract_copy(bytes_data, mod, 0, 96 + baselen + explen, modlen) + if big_endian_to_int(mod) == 0: + return [0] * modlen + o = pow(big_endian_to_int(base), big_endian_to_int(exp), big_endian_to_int(mod)) + return [safe_ord(x) for x in zpad(int_to_big_endian(o), modlen)] + + +def ec_add(data: List[int]) -> List[int]: + bytes_data = bytearray(data) + x1 = extract32(bytes_data, 0) + y1 = extract32(bytes_data, 32) + x2 = extract32(bytes_data, 64) + y2 = extract32(bytes_data, 96) + p1 = validate_point(x1, y1) + p2 = validate_point(x2, y2) + if p1 is False or p2 is False: + return [] + o = bn128.normalize(bn128.add(p1, p2)) + return [safe_ord(x) for x in (encode_int32(o[0].n) + encode_int32(o[1].n))] + + +def ec_mul(data: List[int]) -> List[int]: + bytes_data = bytearray(data) + x = extract32(bytes_data, 0) + y = extract32(bytes_data, 32) + m = extract32(bytes_data, 64) + p = validate_point(x, y) + if p is False: + return [] + o = bn128.normalize(bn128.multiply(p, m)) + return [safe_ord(c) for c in (encode_int32(o[0].n) + encode_int32(o[1].n))] + + +def ec_pair(data: List[int]) -> List[int]: + if len(data) % 192: + return [] + + zero = (bn128.FQ2.one(), bn128.FQ2.one(), bn128.FQ2.zero()) + exponent = bn128.FQ12.one() + bytes_data = bytearray(data) + for i in range(0, len(bytes_data), 192): + x1 = extract32(bytes_data, i) + y1 = extract32(bytes_data, i + 32) + x2_i = extract32(bytes_data, i + 64) + x2_r = extract32(bytes_data, i + 96) + y2_i = extract32(bytes_data, i + 128) + y2_r = extract32(bytes_data, i + 160) + p1 = validate_point(x1, y1) + if p1 is False: + return [] + for v in (x2_i, x2_r, y2_i, y2_r): + if v >= bn128.field_modulus: + return [] + fq2_x = bn128.FQ2([x2_r, x2_i]) + fq2_y = bn128.FQ2([y2_r, y2_i]) + if (fq2_x, fq2_y) != (bn128.FQ2.zero(), bn128.FQ2.zero()): + p2 = (fq2_x, fq2_y, bn128.FQ2.one()) + if not bn128.is_on_curve(p2, bn128.b2): + return [] + else: + p2 = zero + if bn128.multiply(p2, bn128.curve_order)[-1] != bn128.FQ2.zero(): + return [] + exponent *= bn128.pairing(p2, p1, final_exponentiate=False) + result = bn128.final_exponentiate(exponent) == bn128.FQ12.one() + return [0] * 31 + [1 if result else 0] + + +PRECOMPILE_FUNCTIONS = ( + ecrecover, + sha256, + ripemd160, + identity, + mod_exp, + ec_add, + ec_mul, + ec_pair, +) +PRECOMPILE_COUNT = len(PRECOMPILE_FUNCTIONS) + + def native_contracts(address: int, data: BaseCalldata) -> List[int]: """Takes integer address 1, 2, 3, 4. @@ -127,11 +214,10 @@ def native_contracts(address: int, data: BaseCalldata) -> List[int]: :param data: :return: """ - functions = (ecrecover, sha256, ripemd160, identity) if isinstance(data, ConcreteCalldata): concrete_data = data.concrete(None) else: raise NativeContractException() - return functions[address - 1](concrete_data) + return PRECOMPILE_FUNCTIONS[address - 1](concrete_data) diff --git a/mythril/laser/ethereum/util.py b/mythril/laser/ethereum/util.py index 9cb5d950..4191173d 100644 --- a/mythril/laser/ethereum/util.py +++ b/mythril/laser/ethereum/util.py @@ -150,3 +150,27 @@ def bytearray_to_int(arr): for a in arr: o = (o << 8) + a return o + + +def extract_copy( + data: bytearray, mem: bytearray, memstart: int, datastart: int, size: int +): + for i in range(size): + if datastart + i < len(data): + mem[memstart + i] = data[datastart + i] + else: + mem[memstart + i] = 0 + + +def extract32(data: bytearray, i: int) -> int: + """ + + :param data: + :param i: + :return: + """ + if i >= len(data): + return 0 + o = data[i : min(i + 32, len(data))] + o.extend(bytearray(32 - len(o))) + return bytearray_to_int(o) diff --git a/requirements.txt b/requirements.txt index 11382df3..6901873e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ coloredlogs>=10.0 configparser>=3.5.0 coverage -py_ecc==1.4.2 +py_ecc==1.6.0 eth_abi==1.3.0 eth-account>=0.1.0a2,<=0.3.0 ethereum>=2.3.2 diff --git a/setup.py b/setup.py index ce115193..bee87da0 100755 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ REQUIRES_PYTHON = ">=3.5.0" # What packages are required for this module to be executed? REQUIRED = [ "coloredlogs>=10.0", - "py_ecc==1.4.2", + "py_ecc==1.6.0", "ethereum>=2.3.2", "z3-solver>=4.8.5.0", "requests", diff --git a/tests/laser/Precompiles/test_ec_add.py b/tests/laser/Precompiles/test_ec_add.py new file mode 100644 index 00000000..4ede2cf0 --- /dev/null +++ b/tests/laser/Precompiles/test_ec_add.py @@ -0,0 +1,27 @@ +from mock import patch +from eth_utils import decode_hex +from mythril.laser.ethereum.natives import ec_add +from py_ecc.optimized_bn128 import FQ + +VECTOR_A = decode_hex( + "0000000000000000000000000000000000000000000000000000000000000001" + "0000000000000000000000000000000000000000000000000000000000000020" + "0000000000000000000000000000000000000000000000000000000000000020" + "03" + "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2e" + "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f" +) + + +def test_ec_add_sanity(): + assert ec_add(VECTOR_A) == [] + + +@patch("mythril.laser.ethereum.natives.validate_point", return_value=1) +@patch("mythril.laser.ethereum.natives.bn128.add", return_value=1) +@patch("mythril.laser.ethereum.natives.bn128.normalize") +def test_ec_add(f1, f2, f3): + FQ.fielf_modulus = 128 + a = FQ(val=1) + f1.return_value = (a, a) + assert ec_add(VECTOR_A) == ([0] * 31 + [1]) * 2 diff --git a/tests/laser/Precompiles/test_elliptic_curves.py b/tests/laser/Precompiles/test_elliptic_curves.py new file mode 100644 index 00000000..28908f58 --- /dev/null +++ b/tests/laser/Precompiles/test_elliptic_curves.py @@ -0,0 +1,35 @@ +from mock import patch +from mythril.laser.ethereum.natives import ec_pair +from py_ecc.optimized_bn128 import FQ + + +def test_ec_pair_192_check(): + vec_c = [0] * 100 + assert ec_pair(vec_c) == [] + + +@patch("mythril.laser.ethereum.natives.validate_point", return_value=1) +@patch("mythril.laser.ethereum.natives.bn128.is_on_curve", return_value=True) +@patch("mythril.laser.ethereum.natives.bn128.pairing", return_value=1) +@patch("mythril.laser.ethereum.natives.bn128.normalize") +def test_ec_pair(f1, f2, f3, f4): + FQ.fielf_modulus = 100 + a = FQ(val=1) + f1.return_value = (a, a) + vec_c = [0] * 192 + assert ec_pair(vec_c) == [0] * 31 + [1] + + +@patch("mythril.laser.ethereum.natives.validate_point", return_value=False) +def test_ec_pair_point_validation_failure(f1): + vec_c = [0] * 192 + assert ec_pair(vec_c) == [] + + +@patch("mythril.laser.ethereum.natives.validate_point", return_value=1) +def test_ec_pair_field_exceed_mod(f1): + FQ.fielf_modulus = 100 + a = FQ(val=1) + f1.return_value = (a, a) + vec_c = [10] * 192 + assert ec_pair(vec_c) == [] diff --git a/tests/laser/Precompiles/test_elliptic_mul.py b/tests/laser/Precompiles/test_elliptic_mul.py new file mode 100644 index 00000000..c3b3be9d --- /dev/null +++ b/tests/laser/Precompiles/test_elliptic_mul.py @@ -0,0 +1,27 @@ +from mock import patch +from eth_utils import decode_hex +from mythril.laser.ethereum.natives import ec_mul +from py_ecc.optimized_bn128 import FQ + +VECTOR_A = decode_hex( + "0000000000000000000000000000000000000000000000000000000000000001" + "0000000000000000000000000000000000000000000000000000000000000020" + "0000000000000000000000000000000000000000000000000000000000000020" + "03" + "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2e" + "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f" +) + + +@patch("mythril.laser.ethereum.natives.validate_point", return_value=1) +@patch("mythril.laser.ethereum.natives.bn128.multiply", return_value=1) +@patch("mythril.laser.ethereum.natives.bn128.normalize") +def test_ec_mul(f1, f2, f3): + FQ.fielf_modulus = 128 + a = FQ(val=1) + f1.return_value = (a, a) + assert ec_mul(VECTOR_A) == ([0] * 31 + [1]) * 2 + + +def test_ec_mul_validation_failure(): + assert ec_mul(VECTOR_A) == [] diff --git a/tests/laser/Precompiles/test_mod_exp.py b/tests/laser/Precompiles/test_mod_exp.py new file mode 100644 index 00000000..d050c929 --- /dev/null +++ b/tests/laser/Precompiles/test_mod_exp.py @@ -0,0 +1,61 @@ +import pytest +from eth_utils import decode_hex +from mythril.laser.ethereum.natives import mod_exp +from ethereum.utils import big_endian_to_int + + +EIP198_VECTOR_A = decode_hex( + "0000000000000000000000000000000000000000000000000000000000000001" + "0000000000000000000000000000000000000000000000000000000000000020" + "0000000000000000000000000000000000000000000000000000000000000020" + "03" + "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2e" + "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f" +) + +EIP198_VECTOR_B = decode_hex( + "0000000000000000000000000000000000000000000000000000000000000000" + "0000000000000000000000000000000000000000000000000000000000000020" + "0000000000000000000000000000000000000000000000000000000000000020" + "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2e" + "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f" +) + +EIP198_VECTOR_C = decode_hex( + "0000000000000000000000000000000000000000000000000000000000000001" + "0000000000000000000000000000000000000000000000000000000000000002" + "0000000000000000000000000000000000000000000000000000000000000020" + "03" + "ffff" + "8000000000000000000000000000000000000000000000000000000000000000" + "07" +) + +EIP198_VECTOR_D = decode_hex( + "0000000000000000000000000000000000000000000000000000000000000001" + "0000000000000000000000000000000000000000000000000000000000000002" + "0000000000000000000000000000000000000000000000000000000000000020" + "03" + "ffff" + "80" +) + + +@pytest.mark.parametrize( + "data,expected", + ( + (EIP198_VECTOR_A, 1), + (EIP198_VECTOR_B, 0), + ( + EIP198_VECTOR_C, + 26689440342447178617115869845918039756797228267049433585260346420242739014315, + ), + ( + EIP198_VECTOR_D, + 26689440342447178617115869845918039756797228267049433585260346420242739014315, + ), + ), +) +def test_modexp_result(data, expected): + actual = mod_exp(data) + assert big_endian_to_int(actual) == expected