diff --git a/core/staking_verifier.go b/core/staking_verifier.go new file mode 100644 index 000000000..9ce48266a --- /dev/null +++ b/core/staking_verifier.go @@ -0,0 +1,266 @@ +package core + +import ( + "bytes" + "math/big" + + "github.com/pkg/errors" + + "github.com/harmony-one/harmony/core/vm" + common2 "github.com/harmony-one/harmony/internal/common" + staking "github.com/harmony-one/harmony/staking/types" +) + +var ( + errStateDBIsMissing = errors.New("no stateDB was provided") + errChainContextMissing = errors.New("no chain context was provided") + errEpochMissing = errors.New("no epoch was provided") + errBlockNumMissing = errors.New("no block number was provided") +) + +// TODO: add unit tests to check staking msg verification + +// VerifyAndCreateValidatorFromMsg verifies the create validator message using +// the stateDB, epoch, & blocknumber and returns the validatorWrapper created +// in the process. +// +// Note that this function never updates the stateDB, it only reads from stateDB. +func VerifyAndCreateValidatorFromMsg( + stateDB vm.StateDB, epoch *big.Int, blockNum *big.Int, msg *staking.CreateValidator, +) (*staking.ValidatorWrapper, error) { + if stateDB == nil { + return nil, errStateDBIsMissing + } + if epoch == nil { + return nil, errEpochMissing + } + if blockNum == nil { + return nil, errBlockNumMissing + } + if msg.Amount.Sign() == -1 { + return nil, errNegativeAmount + } + if stateDB.IsValidator(msg.ValidatorAddress) { + return nil, errors.Wrapf(errValidatorExist, common2.MustAddressToBech32(msg.ValidatorAddress)) + } + if !CanTransfer(stateDB, msg.ValidatorAddress, msg.Amount) { + return nil, errInsufficientBalanceForStake + } + v, err := staking.CreateValidatorFromNewMsg(msg, blockNum) + if err != nil { + return nil, err + } + wrapper := &staking.ValidatorWrapper{} + wrapper.Validator = *v + wrapper.Delegations = []staking.Delegation{ + staking.NewDelegation(v.Address, msg.Amount), + } + wrapper.Snapshot.Epoch = epoch + wrapper.Snapshot.NumBlocksSigned = big.NewInt(0) + wrapper.Snapshot.NumBlocksToSign = big.NewInt(0) + if err := wrapper.SanityCheck(); err != nil { + return nil, err + } + return wrapper, nil +} + +// VerifyAndEditValidatorFromMsg verifies the edit validator message using +// the stateDB, chainContext and returns the edited validatorWrapper. +// +// Note that this function never updates the stateDB, it only reads from stateDB. +func VerifyAndEditValidatorFromMsg( + stateDB vm.StateDB, chainContext ChainContext, blockNum *big.Int, msg *staking.EditValidator, +) (*staking.ValidatorWrapper, error) { + if stateDB == nil { + return nil, errStateDBIsMissing + } + if chainContext == nil { + return nil, errChainContextMissing + } + if blockNum == nil { + return nil, errBlockNumMissing + } + if !stateDB.IsValidator(msg.ValidatorAddress) { + return nil, errValidatorNotExist + } + wrapper := stateDB.GetStakingInfo(msg.ValidatorAddress) + if wrapper == nil { + return nil, errValidatorNotExist + } + if err := staking.UpdateValidatorFromEditMsg(&wrapper.Validator, msg); err != nil { + return nil, err + } + newRate := wrapper.Validator.Rate + if newRate.GT(wrapper.Validator.MaxRate) { + return nil, errCommissionRateChangeTooHigh + } + + // TODO: make sure we are reading from the correct snapshot + snapshotValidator, err := chainContext.ReadValidatorSnapshot(wrapper.Address) + if err != nil { + return nil, err + } + rateAtBeginningOfEpoch := snapshotValidator.Validator.Rate + + if rateAtBeginningOfEpoch.IsNil() || (!newRate.IsNil() && !rateAtBeginningOfEpoch.Equal(newRate)) { + wrapper.Validator.UpdateHeight = blockNum + } + + if newRate.Sub(rateAtBeginningOfEpoch).Abs().GT(wrapper.Validator.MaxChangeRate) { + return nil, errCommissionRateChangeTooFast + } + + if err := wrapper.SanityCheck(); err != nil { + return nil, err + } + return wrapper, nil +} + +// VerifyAndDelegateFromMsg verifies the delegate message using the stateDB +// and returns the balance to be deducted by the delegator as well as the +// validatorWrapper with the delegation applied to it. +// +// Note that this function never updates the stateDB, it only reads from stateDB. +func VerifyAndDelegateFromMsg( + stateDB vm.StateDB, msg *staking.Delegate, +) (*staking.ValidatorWrapper, *big.Int, error) { + if stateDB == nil { + return nil, nil, errStateDBIsMissing + } + if msg.Amount.Sign() == -1 { + return nil, nil, errNegativeAmount + } + if !stateDB.IsValidator(msg.ValidatorAddress) { + return nil, nil, errValidatorNotExist + } + wrapper := stateDB.GetStakingInfo(msg.ValidatorAddress) + if wrapper == nil { + return nil, nil, errValidatorNotExist + } + // Check for redelegation + for i := range wrapper.Delegations { + delegation := &wrapper.Delegations[i] + if bytes.Equal(delegation.DelegatorAddress.Bytes(), msg.DelegatorAddress.Bytes()) { + totalInUndelegation := delegation.TotalInUndelegation() + balance := stateDB.GetBalance(msg.DelegatorAddress) + // If the sum of normal balance and the total amount of tokens in undelegation is greater than the amount to delegate + if big.NewInt(0).Add(totalInUndelegation, balance).Cmp(msg.Amount) >= 0 { + // Check if it can use tokens in undelegation to delegate (redelegate) + delegateBalance := big.NewInt(0).Set(msg.Amount) + // Use the latest undelegated token first as it has the longest remaining locking time. + i := len(delegation.Undelegations) - 1 + for ; i >= 0; i-- { + if delegation.Undelegations[i].Amount.Cmp(delegateBalance) <= 0 { + delegateBalance.Sub(delegateBalance, delegation.Undelegations[i].Amount) + } else { + delegation.Undelegations[i].Amount.Sub(delegation.Undelegations[i].Amount, delegateBalance) + delegateBalance = big.NewInt(0) + break + } + } + delegation.Undelegations = delegation.Undelegations[:i+1] + delegation.Amount.Add(delegation.Amount, msg.Amount) + if err := wrapper.SanityCheck(); err != nil { + return nil, nil, err + } + // Return remaining balance to be deducted for delegation + if delegateBalance.Cmp(big.NewInt(0)) < 0 { + return nil, nil, errInsufficientBalanceForStake // shouldn't really happen + } + return wrapper, delegateBalance, nil + } + return nil, nil, errors.Wrapf( + errInsufficientBalanceForStake, + "total-delegated %s own-current-balance %s amount-to-delegate %s", + totalInUndelegation.String(), + balance.String(), + msg.Amount.String(), + ) + } + } + // If no redelegation, create new delegation + if !CanTransfer(stateDB, msg.DelegatorAddress, msg.Amount) { + return nil, nil, errInsufficientBalanceForStake + } + wrapper.Delegations = append(wrapper.Delegations, staking.NewDelegation(msg.DelegatorAddress, msg.Amount)) + if err := wrapper.SanityCheck(); err != nil { + return nil, nil, err + } + return wrapper, msg.Amount, nil +} + +// VerifyAndUndelegateFromMsg verifies the undelegate validator message +// using the stateDB & chainContext and returns the edited validatorWrapper +// with the undelegation applied to it. +// +// Note that this function never updates the stateDB, it only reads from stateDB. +func VerifyAndUndelegateFromMsg( + stateDB vm.StateDB, epoch *big.Int, msg *staking.Undelegate, +) (*staking.ValidatorWrapper, error) { + if stateDB == nil { + return nil, errStateDBIsMissing + } + if epoch == nil { + return nil, errEpochMissing + } + if msg.Amount.Sign() == -1 { + return nil, errNegativeAmount + } + if !stateDB.IsValidator(msg.ValidatorAddress) { + return nil, errValidatorNotExist + } + wrapper := stateDB.GetStakingInfo(msg.ValidatorAddress) + if wrapper == nil { + return nil, errValidatorNotExist + } + for i := range wrapper.Delegations { + delegation := &wrapper.Delegations[i] + if bytes.Equal(delegation.DelegatorAddress.Bytes(), msg.DelegatorAddress.Bytes()) { + if err := delegation.Undelegate(epoch, msg.Amount); err != nil { + return nil, err + } + if err := wrapper.SanityCheck(); err != nil { + return nil, err + } + return wrapper, nil + } + } + return nil, errNoDelegationToUndelegate +} + +// VerifyAndCollectRewardsFromDelegation verifies and collects rewards +// from the given delegation slice using the stateDB. It returns all of the +// edited validatorWrappers and the sum total of the rewards. +// +// Note that this function never updates the stateDB, it only reads from stateDB. +func VerifyAndCollectRewardsFromDelegation( + stateDB vm.StateDB, delegations []staking.DelegationIndex, +) ([]*staking.ValidatorWrapper, *big.Int, error) { + if stateDB == nil { + return nil, nil, errStateDBIsMissing + } + updatedValidatorWrappers := []*staking.ValidatorWrapper{} + totalRewards := big.NewInt(0) + for i := range delegations { + delegation := &delegations[i] + wrapper := stateDB.GetStakingInfo(delegation.ValidatorAddress) + if wrapper == nil { + return nil, nil, errValidatorNotExist + } + if uint64(len(wrapper.Delegations)) > delegation.Index { + delegation := &wrapper.Delegations[delegation.Index] + if delegation.Reward.Cmp(big.NewInt(0)) > 0 { + totalRewards.Add(totalRewards, delegation.Reward) + } + delegation.Reward.SetUint64(0) + } + if err := wrapper.SanityCheck(); err != nil { + return nil, nil, err + } + updatedValidatorWrappers = append(updatedValidatorWrappers, wrapper) + } + if totalRewards.Int64() == 0 { + return nil, nil, errNoRewardsToCollect + } + return updatedValidatorWrappers, totalRewards, nil +} diff --git a/core/state_processor.go b/core/state_processor.go index 69ed0e7d3..f7fbe0a60 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -282,7 +282,7 @@ func StakingToMessage(tx *staking.StakingTransaction, blockNum *big.Int) (types. return types.Message{}, err } - msg := types.NewStakingMessage(from, tx.Nonce(), tx.Gas(), tx.Price(), payload, blockNum) + msg := types.NewStakingMessage(from, tx.Nonce(), tx.Gas(), tx.GasPrice(), payload, blockNum) stkType := tx.StakingType() if _, ok := types.StakingTypeMap[stkType]; !ok { return types.Message{}, staking.ErrInvalidStakingKind diff --git a/core/state_transition.go b/core/state_transition.go index e67434122..6d3ebaf56 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -17,7 +17,6 @@ package core import ( - "bytes" "math" "math/big" @@ -25,7 +24,6 @@ import ( "github.com/ethereum/go-ethereum/rlp" "github.com/harmony-one/harmony/core/types" "github.com/harmony-one/harmony/core/vm" - common2 "github.com/harmony-one/harmony/internal/common" "github.com/harmony-one/harmony/internal/params" "github.com/harmony-one/harmony/internal/utils" staking "github.com/harmony-one/harmony/staking/types" @@ -295,8 +293,6 @@ func (st *StateTransition) StakingTransitionDb() (usedGas uint64, err error) { // Pay intrinsic gas // TODO: propose staking-specific formula for staking transaction gas, err := IntrinsicGas(st.data, false, homestead, msg.Type() == types.StakeCreateVal) - // TODO Remove this logging - utils.Logger().Info().Uint64("Using", gas).Msg("Gas cost of staking transaction being processed") if err != nil { return 0, err @@ -314,54 +310,51 @@ func (st *StateTransition) StakingTransitionDb() (usedGas uint64, err error) { if err = rlp.DecodeBytes(msg.Data(), stkMsg); err != nil { return 0, err } - utils.Logger().Info().Msgf("[DEBUG STAKING] staking type: %s, txn: %+v", msg.Type(), stkMsg) + utils.Logger().Info().Msgf("[DEBUG STAKING] staking type: %s, gas: %d, txn: %+v", msg.Type(), gas, stkMsg) if msg.From() != stkMsg.ValidatorAddress { return 0, errInvalidSigner } - err = st.applyCreateValidatorTx(stkMsg, msg.BlockNum()) - + err = st.verifyAndApplyCreateValidatorTx(stkMsg, msg.BlockNum()) case types.StakeEditVal: stkMsg := &staking.EditValidator{} if err = rlp.DecodeBytes(msg.Data(), stkMsg); err != nil { return 0, err } - utils.Logger().Info().Msgf("[DEBUG STAKING] staking type: %s, txn: %+v", msg.Type(), stkMsg) + utils.Logger().Info().Msgf("[DEBUG STAKING] staking type: %s, gas: %d, txn: %+v", msg.Type(), gas, stkMsg) if msg.From() != stkMsg.ValidatorAddress { return 0, errInvalidSigner } - err = st.applyEditValidatorTx(stkMsg, msg.BlockNum()) - + err = st.verifyAndApplyEditValidatorTx(stkMsg, msg.BlockNum()) case types.Delegate: stkMsg := &staking.Delegate{} if err = rlp.DecodeBytes(msg.Data(), stkMsg); err != nil { return 0, err } - utils.Logger().Info().Msgf("[DEBUG STAKING] staking type: %s, txn: %+v", msg.Type(), stkMsg) + utils.Logger().Info().Msgf("[DEBUG STAKING] staking type: %s, gas: %d, txn: %+v", msg.Type(), gas, stkMsg) if msg.From() != stkMsg.DelegatorAddress { return 0, errInvalidSigner } - err = st.applyDelegateTx(stkMsg) - + err = st.verifyAndApplyDelegateTx(stkMsg) case types.Undelegate: stkMsg := &staking.Undelegate{} if err = rlp.DecodeBytes(msg.Data(), stkMsg); err != nil { return 0, err } - utils.Logger().Info().Msgf("[DEBUG STAKING] staking type: %s, txn: %+v", msg.Type(), stkMsg) + utils.Logger().Info().Msgf("[DEBUG STAKING] staking type: %s, gas: %d, txn: %+v", msg.Type(), gas, stkMsg) if msg.From() != stkMsg.DelegatorAddress { return 0, errInvalidSigner } - err = st.applyUndelegateTx(stkMsg) + err = st.verifyAndApplyUndelegateTx(stkMsg) case types.CollectRewards: stkMsg := &staking.CollectRewards{} if err = rlp.DecodeBytes(msg.Data(), stkMsg); err != nil { return 0, err } - utils.Logger().Info().Msgf("[DEBUG STAKING] staking type: %s, txn: %+v", msg.Type(), stkMsg) + utils.Logger().Info().Msgf("[DEBUG STAKING] staking type: %s, gas: %d, txn: %+v", msg.Type(), gas, stkMsg) if msg.From() != stkMsg.DelegatorAddress { return 0, errInvalidSigner } - err = st.applyCollectRewards(stkMsg) + err = st.verifyAndApplyCollectRewards(stkMsg) default: return 0, staking.ErrInvalidStakingKind } @@ -373,232 +366,68 @@ func (st *StateTransition) StakingTransitionDb() (usedGas uint64, err error) { return st.gasUsed(), err } -func (st *StateTransition) applyCreateValidatorTx(createValidator *staking.CreateValidator, blockNum *big.Int) error { - if createValidator.Amount.Sign() == -1 { - return errNegativeAmount - } - - if val := createValidator.ValidatorAddress; st.state.IsValidator(val) { - return errors.Wrapf(errValidatorExist, common2.MustAddressToBech32(val)) - } - - if !CanTransfer(st.state, createValidator.ValidatorAddress, createValidator.Amount) { - return errInsufficientBalanceForStake - } - - v, err := staking.CreateValidatorFromNewMsg(createValidator, blockNum) +func (st *StateTransition) verifyAndApplyCreateValidatorTx( + createValidator *staking.CreateValidator, blockNum *big.Int, +) error { + wrapper, err := VerifyAndCreateValidatorFromMsg(st.state, st.evm.EpochNumber, blockNum, createValidator) if err != nil { return err } - - zero := big.NewInt(0) - wrapper := staking.ValidatorWrapper{} - wrapper.Validator = *v - wrapper.Delegations = []staking.Delegation{ - staking.NewDelegation(v.Address, createValidator.Amount), - } - wrapper.Snapshot.Epoch = st.evm.EpochNumber - wrapper.Snapshot.NumBlocksSigned = zero - wrapper.Snapshot.NumBlocksToSign = zero - - if err := st.state.UpdateStakingInfo(v.Address, &wrapper); err != nil { + if err := st.state.UpdateStakingInfo(wrapper.Validator.Address, wrapper); err != nil { return err } - - st.state.SetValidatorFlag(v.Address) - st.state.SubBalance(v.Address, createValidator.Amount) + st.state.SetValidatorFlag(wrapper.Validator.Address) + st.state.SubBalance(wrapper.Address, createValidator.Amount) return nil } -func (st *StateTransition) applyEditValidatorTx( +func (st *StateTransition) verifyAndApplyEditValidatorTx( editValidator *staking.EditValidator, blockNum *big.Int, ) error { - if !st.state.IsValidator(editValidator.ValidatorAddress) { - return errValidatorNotExist - } - - wrapper := st.state.GetStakingInfo(editValidator.ValidatorAddress) - if wrapper == nil { - return errValidatorNotExist - } - - if err := staking.UpdateValidatorFromEditMsg(&wrapper.Validator, editValidator); err != nil { + wrapper, err := VerifyAndEditValidatorFromMsg(st.state, st.bc, blockNum, editValidator) + if err != nil { return err } - newRate := wrapper.Validator.Rate + return st.state.UpdateStakingInfo(wrapper.Address, wrapper) +} - // TODO: make sure we are reading from the correct snapshot - snapshotValidator, err := st.bc.ReadValidatorSnapshot(wrapper.Address) +func (st *StateTransition) verifyAndApplyDelegateTx(delegate *staking.Delegate) error { + wrapper, balanceToBeDeducted, err := VerifyAndDelegateFromMsg(st.state, delegate) if err != nil { return err } - rateAtBeginningOfEpoch := snapshotValidator.Validator.Rate - - if rateAtBeginningOfEpoch.IsNil() || (!newRate.IsNil() && !rateAtBeginningOfEpoch.Equal(newRate)) { - wrapper.Validator.UpdateHeight = blockNum - } - - if newRate.Sub(rateAtBeginningOfEpoch).Abs().GT(wrapper.Validator.MaxChangeRate) { - return errCommissionRateChangeTooFast - } - - if newRate.GT(wrapper.Validator.MaxRate) { - return errCommissionRateChangeTooHigh - } - - if err := st.state.UpdateStakingInfo(wrapper.Address, wrapper); err != nil { + if err := st.state.UpdateStakingInfo(wrapper.Validator.Address, wrapper); err != nil { return err } + st.state.SubBalance(delegate.DelegatorAddress, balanceToBeDeducted) return nil } -func (st *StateTransition) applyDelegateTx(delegate *staking.Delegate) error { - if delegate.Amount.Sign() == -1 { - return errNegativeAmount - } - - if !st.state.IsValidator(delegate.ValidatorAddress) { - return errValidatorNotExist - } - wrapper := st.state.GetStakingInfo(delegate.ValidatorAddress) - if wrapper == nil { - return errValidatorNotExist - } - - stateDB := st.state - delegatorExist := false - for i := range wrapper.Delegations { - delegation := &wrapper.Delegations[i] - if bytes.Equal(delegation.DelegatorAddress.Bytes(), delegate.DelegatorAddress.Bytes()) { - delegatorExist = true - totalInUndelegation := delegation.TotalInUndelegation() - balance := stateDB.GetBalance(delegate.DelegatorAddress) - // If the sum of normal balance and the total amount of tokens in undelegation is greater than the amount to delegate - if big.NewInt(0).Add(totalInUndelegation, balance).Cmp(delegate.Amount) >= 0 { - // Firstly use the tokens in undelegation to delegate (redelegate) - delegateBalance := big.NewInt(0).Set(delegate.Amount) - // Use the latest undelegated token first as it has the longest remaining locking time. - i := len(delegation.Undelegations) - 1 - for ; i >= 0; i-- { - if delegation.Undelegations[i].Amount.Cmp(delegateBalance) <= 0 { - delegateBalance.Sub(delegateBalance, delegation.Undelegations[i].Amount) - } else { - delegation.Undelegations[i].Amount.Sub(delegation.Undelegations[i].Amount, delegateBalance) - delegateBalance = big.NewInt(0) - break - } - } - - delegation.Undelegations = delegation.Undelegations[:i+1] - delegation.Amount.Add(delegation.Amount, delegate.Amount) - err := stateDB.UpdateStakingInfo(wrapper.Validator.Address, wrapper) - - if err != nil { - return err - } - // Secondly, if all locked token are used, try use the balance. - if delegateBalance.Cmp(big.NewInt(0)) > 0 { - stateDB.SubBalance(delegate.DelegatorAddress, delegateBalance) - return nil - } - // This shouldn't really happen - return errInsufficientBalanceForStake - } - return errors.Wrapf( - errInsufficientBalanceForStake, - "total-delegated %s own-current-balance %s amount-to-delegate %s", - totalInUndelegation.String(), - balance.String(), - delegate.Amount.String(), - ) - } - } - - if !delegatorExist { - if CanTransfer(stateDB, delegate.DelegatorAddress, delegate.Amount) { - newDelegator := staking.NewDelegation(delegate.DelegatorAddress, delegate.Amount) - wrapper.Delegations = append(wrapper.Delegations, newDelegator) - - if err := stateDB.UpdateStakingInfo(wrapper.Validator.Address, wrapper); err == nil { - stateDB.SubBalance(delegate.DelegatorAddress, delegate.Amount) - } else { - return err - } - } - } - - return nil -} - -func (st *StateTransition) applyUndelegateTx(undelegate *staking.Undelegate) error { - if undelegate.Amount.Sign() == -1 { - return errNegativeAmount - } - - if !st.state.IsValidator(undelegate.ValidatorAddress) { - return errValidatorNotExist - } - wrapper := st.state.GetStakingInfo(undelegate.ValidatorAddress) - if wrapper == nil { - return errValidatorNotExist - } - - stateDB := st.state - delegatorExist := false - for i := range wrapper.Delegations { - delegation := &wrapper.Delegations[i] - if bytes.Equal(delegation.DelegatorAddress.Bytes(), undelegate.DelegatorAddress.Bytes()) { - delegatorExist = true - - err := delegation.Undelegate(st.evm.EpochNumber, undelegate.Amount) - if err != nil { - return err - } - err = stateDB.UpdateStakingInfo(wrapper.Validator.Address, wrapper) - return err - } - } - if !delegatorExist { - return errNoDelegationToUndelegate +func (st *StateTransition) verifyAndApplyUndelegateTx(undelegate *staking.Undelegate) error { + wrapper, err := VerifyAndUndelegateFromMsg(st.state, st.evm.EpochNumber, undelegate) + if err != nil { + return err } - return nil + return st.state.UpdateStakingInfo(wrapper.Validator.Address, wrapper) } -func (st *StateTransition) applyCollectRewards(collectRewards *staking.CollectRewards) error { +func (st *StateTransition) verifyAndApplyCollectRewards(collectRewards *staking.CollectRewards) error { if st.bc == nil { return errors.New("[CollectRewards] No chain context provided") } - chainContext := st.bc - delegations, err := chainContext.ReadDelegationsByDelegator(collectRewards.DelegatorAddress) - + delegations, err := st.bc.ReadDelegationsByDelegator(collectRewards.DelegatorAddress) if err != nil { return err } - - totalRewards := big.NewInt(0) - for i := range delegations { - wrapper := st.state.GetStakingInfo(delegations[i].ValidatorAddress) - if wrapper == nil { - return errValidatorNotExist - } - - if uint64(len(wrapper.Delegations)) > delegations[i].Index { - delegation := &wrapper.Delegations[delegations[i].Index] - if delegation.Reward.Cmp(big.NewInt(0)) > 0 { - totalRewards.Add(totalRewards, delegation.Reward) - } - - delegation.Reward.SetUint64(0) - } - - err = st.state.UpdateStakingInfo(wrapper.Validator.Address, wrapper) - if err != nil { + updatedValidatorWrappers, totalRewards, err := VerifyAndCollectRewardsFromDelegation(st.state, delegations) + if err != nil { + return err + } + for _, wrapper := range updatedValidatorWrappers { + if err := st.state.UpdateStakingInfo(wrapper.Validator.Address, wrapper); err != nil { return err } } - if totalRewards.Int64() == 0 { - return errNoRewardsToCollect - } st.state.AddBalance(collectRewards.DelegatorAddress, totalRewards) return nil } diff --git a/core/tx_journal.go b/core/tx_journal.go index 6392e83e8..e6a305b33 100644 --- a/core/tx_journal.go +++ b/core/tx_journal.go @@ -17,15 +17,22 @@ package core import ( - "errors" "io" "os" + "github.com/pkg/errors" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/rlp" "github.com/harmony-one/harmony/core/types" "github.com/harmony-one/harmony/internal/utils" + staking "github.com/harmony-one/harmony/staking/types" +) + +const ( + plainTxID = uint64(1) + stakingTxID = uint64(2) ) // errNoActiveJournal is returned if a transaction is attempted to be inserted @@ -55,9 +62,26 @@ func newTxJournal(path string) *txJournal { } } +// writeJournalTx writes a transaction journal tx to file with a leading uint64 +// to identify the written transaction. +func writeJournalTx(writer io.WriteCloser, tx types.PoolTransaction) error { + if _, ok := tx.(*types.Transaction); ok { + if _, err := writer.Write([]byte{byte(plainTxID)}); err != nil { + return err + } + } else if _, ok := tx.(*staking.StakingTransaction); ok { + if _, err := writer.Write([]byte{byte(stakingTxID)}); err != nil { + return err + } + } else { + return types.ErrUnknownPoolTxType + } + return tx.EncodeRLP(writer) +} + // load parses a transaction journal dump from disk, loading its contents into // the specified pool. -func (journal *txJournal) load(add func([]*types.Transaction) []error) error { +func (journal *txJournal) load(add func(types.PoolTransactions) []error) error { // Skip the parsing if the journal file doesn't exist at all if _, err := os.Stat(journal.path); os.IsNotExist(err) { return nil @@ -76,11 +100,12 @@ func (journal *txJournal) load(add func([]*types.Transaction) []error) error { // Inject all transactions from the journal into the pool stream := rlp.NewStream(input, 0) total, dropped := 0, 0 + batch := types.PoolTransactions{} // Create a method to load a limited batch of transactions and bump the // appropriate progress counters. Then use this method to load all the // journaled transactions in small-ish batches. - loadBatch := func(txs types.Transactions) { + loadBatch := func(txs types.PoolTransactions) { for _, err := range add(txs) { if err != nil { utils.Logger().Error().Err(err).Msg("Failed to add journaled transaction") @@ -88,21 +113,41 @@ func (journal *txJournal) load(add func([]*types.Transaction) []error) error { } } } - var ( - failure error - batch types.Transactions - ) for { - // Parse the next transaction and terminate on error - tx := new(types.Transaction) - if err = stream.Decode(tx); err != nil { - if err != io.EOF { - failure = err + // Parse the next transaction and terminate on errors + var tx types.PoolTransaction + switch txType, err := stream.Uint(); txType { + case plainTxID: + tx = new(types.Transaction) + case stakingTxID: + tx = new(staking.StakingTransaction) + default: + if err != nil { + if err == io.EOF { // reached end of journal file, exit with no error after loading batch + err = nil + } else { + utils.Logger().Info(). + Int("transactions", total). + Int("dropped", dropped). + Msg("Loaded local transaction journal") + } + if batch.Len() > 0 { + loadBatch(batch) + } + return err } + } + + if err = stream.Decode(tx); err != nil { + // should never hit EOF here with the leading ID journal tx encoding scheme. + utils.Logger().Info(). + Int("transactions", total). + Int("dropped", dropped). + Msg("Loaded local transaction journal") if batch.Len() > 0 { loadBatch(batch) } - break + return err } // New transaction parsed, queue up for later, import if threshold is reached total++ @@ -112,25 +157,19 @@ func (journal *txJournal) load(add func([]*types.Transaction) []error) error { batch = batch[:0] } } - utils.Logger().Info(). - Int("transactions", total). - Int("dropped", dropped). - Msg("Loaded local transaction journal") - - return failure } // insert adds the specified transaction to the local disk journal. -func (journal *txJournal) insert(tx *types.Transaction) error { +func (journal *txJournal) insert(tx types.PoolTransaction) error { if journal.writer == nil { return errNoActiveJournal } - return rlp.Encode(journal.writer, tx) + return writeJournalTx(journal.writer, tx) } // rotate regenerates the transaction journal based on the current contents of // the transaction pool. -func (journal *txJournal) rotate(all map[common.Address]types.Transactions) error { +func (journal *txJournal) rotate(all map[common.Address]types.PoolTransactions) error { // Close the current journal (if any is open) if journal.writer != nil { if err := journal.writer.Close(); err != nil { @@ -146,7 +185,7 @@ func (journal *txJournal) rotate(all map[common.Address]types.Transactions) erro journaled := 0 for _, txs := range all { for _, tx := range txs { - if err = rlp.Encode(replacement, tx); err != nil { + if err = writeJournalTx(replacement, tx); err != nil { replacement.Close() return err } diff --git a/core/tx_list.go b/core/tx_list.go index 0939f5ffb..611cc9de6 100644 --- a/core/tx_list.go +++ b/core/tx_list.go @@ -51,27 +51,27 @@ func (h *nonceHeap) Pop() interface{} { // txSortedMap is a nonce->transaction hash map with a heap based index to allow // iterating over the contents in a nonce-incrementing way. type txSortedMap struct { - items map[uint64]*types.Transaction // Hash map storing the transaction data - index *nonceHeap // Heap of nonces of all the stored transactions (non-strict mode) - cache types.Transactions // Cache of the transactions already sorted + items map[uint64]types.PoolTransaction // Hash map storing the transaction data + index *nonceHeap // Heap of nonces of all the stored transactions (non-strict mode) + cache types.PoolTransactions // Cache of the transactions already sorted } // newTxSortedMap creates a new nonce-sorted transaction map. func newTxSortedMap() *txSortedMap { return &txSortedMap{ - items: make(map[uint64]*types.Transaction), + items: make(map[uint64]types.PoolTransaction), index: new(nonceHeap), } } // Get retrieves the current transactions associated with the given nonce. -func (m *txSortedMap) Get(nonce uint64) *types.Transaction { +func (m *txSortedMap) Get(nonce uint64) types.PoolTransaction { return m.items[nonce] } // Put inserts a new transaction into the map, also updating the map's nonce // index. If a transaction already exists with the same nonce, it's overwritten. -func (m *txSortedMap) Put(tx *types.Transaction) { +func (m *txSortedMap) Put(tx types.PoolTransaction) { nonce := tx.Nonce() if m.items[nonce] == nil { heap.Push(m.index, nonce) @@ -82,8 +82,8 @@ func (m *txSortedMap) Put(tx *types.Transaction) { // Forward removes all transactions from the map with a nonce lower than the // provided threshold. Every removed transaction is returned for any post-removal // maintenance. -func (m *txSortedMap) Forward(threshold uint64) types.Transactions { - var removed types.Transactions +func (m *txSortedMap) Forward(threshold uint64) types.PoolTransactions { + var removed types.PoolTransactions // Pop off heap items until the threshold is reached for m.index.Len() > 0 && (*m.index)[0] < threshold { @@ -100,8 +100,8 @@ func (m *txSortedMap) Forward(threshold uint64) types.Transactions { // Filter iterates over the list of transactions and removes all of them for which // the specified function evaluates to true. -func (m *txSortedMap) Filter(filter func(*types.Transaction) bool) types.Transactions { - var removed types.Transactions +func (m *txSortedMap) Filter(filter func(types.PoolTransaction) bool) types.PoolTransactions { + var removed types.PoolTransactions // Collect all the transactions to filter out for nonce, tx := range m.items { @@ -125,13 +125,13 @@ func (m *txSortedMap) Filter(filter func(*types.Transaction) bool) types.Transac // Cap places a hard limit on the number of items, returning all transactions // exceeding that limit. -func (m *txSortedMap) Cap(threshold int) types.Transactions { +func (m *txSortedMap) Cap(threshold int) types.PoolTransactions { // Short circuit if the number of items is under the limit if len(m.items) <= threshold { return nil } // Otherwise gather and drop the highest nonce'd transactions - var drops types.Transactions + var drops types.PoolTransactions sort.Sort(*m.index) for size := len(m.items); size > threshold; size-- { @@ -176,13 +176,13 @@ func (m *txSortedMap) Remove(nonce uint64) bool { // Note, all transactions with nonces lower than start will also be returned to // prevent getting into and invalid state. This is not something that should ever // happen but better to be self correcting than failing! -func (m *txSortedMap) Ready(start uint64) types.Transactions { +func (m *txSortedMap) Ready(start uint64) types.PoolTransactions { // Short circuit if no transactions are available if m.index.Len() == 0 || (*m.index)[0] > start { return nil } // Otherwise start accumulating incremental transactions - var ready types.Transactions + var ready types.PoolTransactions for next := (*m.index)[0]; m.index.Len() > 0 && (*m.index)[0] == next; next++ { ready = append(ready, m.items[next]) delete(m.items, next) @@ -201,17 +201,17 @@ func (m *txSortedMap) Len() int { // Flatten creates a nonce-sorted slice of transactions based on the loosely // sorted internal representation. The result of the sorting is cached in case // it's requested again before any modifications are made to the contents. -func (m *txSortedMap) Flatten() types.Transactions { +func (m *txSortedMap) Flatten() types.PoolTransactions { // If the sorting was not cached yet, create and cache it if m.cache == nil { - m.cache = make(types.Transactions, 0, len(m.items)) + m.cache = make(types.PoolTransactions, 0, len(m.items)) for _, tx := range m.items { m.cache = append(m.cache, tx) } - sort.Sort(types.TxByNonce(m.cache)) + sort.Sort(types.PoolTxByNonce(m.cache)) } // Copy the cache to prevent accidental modifications - txs := make(types.Transactions, len(m.cache)) + txs := make(types.PoolTransactions, len(m.cache)) copy(txs, m.cache) return txs } @@ -240,7 +240,7 @@ func newTxList(strict bool) *txList { // Overlaps returns whether the transaction specified has the same nonce as one // already contained within the list. -func (l *txList) Overlaps(tx *types.Transaction) bool { +func (l *txList) Overlaps(tx types.PoolTransaction) bool { return l.txs.Get(tx.Nonce()) != nil } @@ -249,7 +249,7 @@ func (l *txList) Overlaps(tx *types.Transaction) bool { // // If the new transaction is accepted into the list, the lists' cost and gas // thresholds are also potentially updated. -func (l *txList) Add(tx *types.Transaction, priceBump uint64) (bool, *types.Transaction) { +func (l *txList) Add(tx types.PoolTransaction, priceBump uint64) (bool, types.PoolTransaction) { // If there's an older better transaction, abort old := l.txs.Get(tx.Nonce()) if old != nil { @@ -275,7 +275,7 @@ func (l *txList) Add(tx *types.Transaction, priceBump uint64) (bool, *types.Tran // Forward removes all transactions from the list with a nonce lower than the // provided threshold. Every removed transaction is returned for any post-removal // maintenance. -func (l *txList) Forward(threshold uint64) types.Transactions { +func (l *txList) Forward(threshold uint64) types.PoolTransactions { return l.txs.Forward(threshold) } @@ -288,7 +288,7 @@ func (l *txList) Forward(threshold uint64) types.Transactions { // a point in calculating all the costs or if the balance covers all. If the threshold // is lower than the costgas cap, the caps will be reset to a new high after removing // the newly invalidated transactions. -func (l *txList) Filter(costLimit *big.Int, gasLimit uint64) (types.Transactions, types.Transactions) { +func (l *txList) Filter(costLimit *big.Int, gasLimit uint64) (types.PoolTransactions, types.PoolTransactions) { // If all transactions are below the threshold, short circuit if l.costcap.Cmp(costLimit) <= 0 && l.gascap <= gasLimit { return nil, nil @@ -297,10 +297,10 @@ func (l *txList) Filter(costLimit *big.Int, gasLimit uint64) (types.Transactions l.gascap = gasLimit // Filter out all the transactions above the account's funds - removed := l.txs.Filter(func(tx *types.Transaction) bool { return tx.Cost().Cmp(costLimit) > 0 || tx.Gas() > gasLimit }) + removed := l.txs.Filter(func(tx types.PoolTransaction) bool { return tx.Cost().Cmp(costLimit) > 0 || tx.Gas() > gasLimit }) // If the list was strict, filter anything above the lowest nonce - var invalids types.Transactions + var invalids types.PoolTransactions if l.strict && len(removed) > 0 { lowest := uint64(math.MaxUint64) @@ -309,21 +309,21 @@ func (l *txList) Filter(costLimit *big.Int, gasLimit uint64) (types.Transactions lowest = nonce } } - invalids = l.txs.Filter(func(tx *types.Transaction) bool { return tx.Nonce() > lowest }) + invalids = l.txs.Filter(func(tx types.PoolTransaction) bool { return tx.Nonce() > lowest }) } return removed, invalids } // Cap places a hard limit on the number of items, returning all transactions // exceeding that limit. -func (l *txList) Cap(threshold int) types.Transactions { +func (l *txList) Cap(threshold int) types.PoolTransactions { return l.txs.Cap(threshold) } // Remove deletes a transaction from the maintained list, returning whether the // transaction was found, and also returning any transaction invalidated due to // the deletion (strict mode only). -func (l *txList) Remove(tx *types.Transaction) (bool, types.Transactions) { +func (l *txList) Remove(tx types.PoolTransaction) (bool, types.PoolTransactions) { // Remove the transaction from the set nonce := tx.Nonce() if removed := l.txs.Remove(nonce); !removed { @@ -331,7 +331,7 @@ func (l *txList) Remove(tx *types.Transaction) (bool, types.Transactions) { } // In strict mode, filter out non-executable transactions if l.strict { - return true, l.txs.Filter(func(tx *types.Transaction) bool { return tx.Nonce() > nonce }) + return true, l.txs.Filter(func(tx types.PoolTransaction) bool { return tx.Nonce() > nonce }) } return true, nil } @@ -343,7 +343,7 @@ func (l *txList) Remove(tx *types.Transaction) (bool, types.Transactions) { // Note, all transactions with nonces lower than start will also be returned to // prevent getting into and invalid state. This is not something that should ever // happen but better to be self correcting than failing! -func (l *txList) Ready(start uint64) types.Transactions { +func (l *txList) Ready(start uint64) types.PoolTransactions { return l.txs.Ready(start) } @@ -360,13 +360,13 @@ func (l *txList) Empty() bool { // Flatten creates a nonce-sorted slice of transactions based on the loosely // sorted internal representation. The result of the sorting is cached in case // it's requested again before any modifications are made to the contents. -func (l *txList) Flatten() types.Transactions { +func (l *txList) Flatten() types.PoolTransactions { return l.txs.Flatten() } // priceHeap is a heap.Interface implementation over transactions for retrieving // price-sorted transactions to discard when the pool fills up. -type priceHeap []*types.Transaction +type priceHeap []types.PoolTransaction func (h priceHeap) Len() int { return len(h) } func (h priceHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } @@ -384,7 +384,7 @@ func (h priceHeap) Less(i, j int) bool { } func (h *priceHeap) Push(x interface{}) { - *h = append(*h, x.(*types.Transaction)) + *h = append(*h, x.(types.PoolTransaction)) } func (h *priceHeap) Pop() interface{} { @@ -412,7 +412,7 @@ func newTxPricedList(all *txLookup) *txPricedList { } // Put inserts a new transaction into the heap. -func (l *txPricedList) Put(tx *types.Transaction) { +func (l *txPricedList) Put(tx types.PoolTransaction) { heap.Push(l.items, tx) } @@ -429,7 +429,7 @@ func (l *txPricedList) Removed() { reheap := make(priceHeap, 0, l.all.Count()) l.stales, l.items = 0, &reheap - l.all.Range(func(hash common.Hash, tx *types.Transaction) bool { + l.all.Range(func(hash common.Hash, tx types.PoolTransaction) bool { *l.items = append(*l.items, tx) return true }) @@ -438,13 +438,13 @@ func (l *txPricedList) Removed() { // Cap finds all the transactions below the given price threshold, drops them // from the priced list and returns them for further removal from the entire pool. -func (l *txPricedList) Cap(threshold *big.Int, local *accountSet) types.Transactions { - drop := make(types.Transactions, 0, 128) // Remote underpriced transactions to drop - save := make(types.Transactions, 0, 64) // Local underpriced transactions to keep +func (l *txPricedList) Cap(threshold *big.Int, local *accountSet) types.PoolTransactions { + drop := make(types.PoolTransactions, 0, 128) // Remote underpriced transactions to drop + save := make(types.PoolTransactions, 0, 64) // Local underpriced transactions to keep for len(*l.items) > 0 { // Discard stale transactions if found during cleanup - tx := heap.Pop(l.items).(*types.Transaction) + tx := heap.Pop(l.items).(types.PoolTransaction) if l.all.Get(tx.Hash()) == nil { l.stales-- continue @@ -469,14 +469,14 @@ func (l *txPricedList) Cap(threshold *big.Int, local *accountSet) types.Transact // Underpriced checks whether a transaction is cheaper than (or as cheap as) the // lowest priced transaction currently being tracked. -func (l *txPricedList) Underpriced(tx *types.Transaction, local *accountSet) bool { +func (l *txPricedList) Underpriced(tx types.PoolTransaction, local *accountSet) bool { // Local transactions cannot be underpriced if local.containsTx(tx) { return false } // Discard stale price points if found at the heap start for len(*l.items) > 0 { - head := []*types.Transaction(*l.items)[0] + head := types.PoolTransactions(*l.items)[0] if l.all.Get(head.Hash()) == nil { l.stales-- heap.Pop(l.items) @@ -489,19 +489,19 @@ func (l *txPricedList) Underpriced(tx *types.Transaction, local *accountSet) boo utils.Logger().Error().Msg("Pricing query for empty pool") // This cannot happen, print to catch programming errors return false } - cheapest := []*types.Transaction(*l.items)[0] + cheapest := types.PoolTransactions(*l.items)[0] return cheapest.GasPrice().Cmp(tx.GasPrice()) >= 0 } // Discard finds a number of most underpriced transactions, removes them from the // priced list and returns them for further removal from the entire pool. -func (l *txPricedList) Discard(count int, local *accountSet) types.Transactions { - drop := make(types.Transactions, 0, count) // Remote underpriced transactions to drop - save := make(types.Transactions, 0, 64) // Local underpriced transactions to keep +func (l *txPricedList) Discard(count int, local *accountSet) types.PoolTransactions { + drop := make(types.PoolTransactions, 0, count) // Remote underpriced transactions to drop + save := make(types.PoolTransactions, 0, 64) // Local underpriced transactions to keep for len(*l.items) > 0 && count > 0 { // Discard stale transactions if found during cleanup - tx := heap.Pop(l.items).(*types.Transaction) + tx := heap.Pop(l.items).(types.PoolTransaction) if l.all.Get(tx.Hash()) == nil { l.stales-- continue diff --git a/core/tx_list_test.go b/core/tx_list_test.go index e68e8257a..2651b9dc1 100644 --- a/core/tx_list_test.go +++ b/core/tx_list_test.go @@ -30,7 +30,7 @@ func TestStrictTxListAdd(t *testing.T) { // Generate a list of transactions to insert key, _ := crypto.GenerateKey() - txs := make(types.Transactions, 1024) + txs := make(types.PoolTransactions, 1024) for i := 0; i < len(txs); i++ { txs[i] = transaction(uint64(i), 0, key) } diff --git a/core/tx_pool.go b/core/tx_pool.go index 87a782c59..55a2c2d55 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -37,6 +37,8 @@ import ( "github.com/harmony-one/harmony/core/types" hmyCommon "github.com/harmony-one/harmony/internal/common" "github.com/harmony-one/harmony/internal/utils" + "github.com/harmony-one/harmony/shard" + staking "github.com/harmony-one/harmony/staking/types" ) const ( @@ -85,6 +87,10 @@ var ( // attempting to be added to the pool. ErrKnownTransaction = errors.New("known transaction") + // ErrInvalidMsgForStakingDirective is returned if a staking message does not + // match the related directive + ErrInvalidMsgForStakingDirective = errors.New("staking message does not match directive message") + // ErrBlacklistFrom is returned if a transaction's from/source address is blacklisted ErrBlacklistFrom = errors.New("`from` address of transaction in blacklist") @@ -242,30 +248,33 @@ type TxPool struct { wg sync.WaitGroup // for shutdown sync - txnErrorSink func([]types.RPCTransactionError) + errorReporter *txPoolErrorReporter // The reporter for the tx error sinks homestead bool } // NewTxPool creates a new transaction pool to gather, sort and filter inbound // transactions from the network. -func NewTxPool(config TxPoolConfig, chainconfig *params.ChainConfig, chain blockChain, txnErrorSink func([]types.RPCTransactionError)) *TxPool { +func NewTxPool(config TxPoolConfig, chainconfig *params.ChainConfig, chain blockChain, + txnErrorSink func([]types.RPCTransactionError), + stakingTxnErrorSink func([]staking.RPCTransactionError), +) *TxPool { // Sanitize the input to ensure no vulnerable gas prices are set config = (&config).sanitize() // Create the transaction pool with its initial settings pool := &TxPool{ - config: config, - chainconfig: chainconfig, - chain: chain, - signer: types.NewEIP155Signer(chainconfig.ChainID), - pending: make(map[common.Address]*txList), - queue: make(map[common.Address]*txList), - beats: make(map[common.Address]time.Time), - all: newTxLookup(), - chainHeadCh: make(chan ChainHeadEvent, chainHeadChanSize), - gasPrice: new(big.Int).SetUint64(config.PriceLimit), - txnErrorSink: txnErrorSink, + config: config, + chainconfig: chainconfig, + chain: chain, + signer: types.NewEIP155Signer(chainconfig.ChainID), + pending: make(map[common.Address]*txList), + queue: make(map[common.Address]*txList), + beats: make(map[common.Address]time.Time), + all: newTxLookup(), + chainHeadCh: make(chan ChainHeadEvent, chainHeadChanSize), + gasPrice: new(big.Int).SetUint64(config.PriceLimit), + errorReporter: newTxPoolErrorReporter(txnErrorSink, stakingTxnErrorSink), } pool.locals = newAccountSet(pool.signer) for _, addr := range config.Locals { @@ -394,7 +403,7 @@ func (pool *TxPool) lockedReset(oldHead, newHead *block.Header) { // of the transaction pool is valid with regard to the chain state. func (pool *TxPool) reset(oldHead, newHead *block.Header) { // If we're reorging an old state, reinject all dropped transactions - var reinject types.Transactions + var reinject types.PoolTransactions if oldHead != nil && oldHead.Hash() != newHead.ParentHash() { // If the reorg is too deep, avoid doing it (will happen during fast sync) @@ -405,14 +414,19 @@ func (pool *TxPool) reset(oldHead, newHead *block.Header) { utils.Logger().Debug().Uint64("depth", depth).Msg("Skipping deep transaction reorg") } else { // Reorg seems shallow enough to pull in all transactions into memory - var discarded, included types.Transactions + var discarded, included types.PoolTransactions var ( rem = pool.chain.GetBlock(oldHead.Hash(), oldHead.Number().Uint64()) add = pool.chain.GetBlock(newHead.Hash(), newHead.Number().Uint64()) ) for rem.NumberU64() > add.NumberU64() { - discarded = append(discarded, rem.Transactions()...) + for _, tx := range rem.Transactions() { + discarded = append(discarded, tx) + } + for _, tx := range rem.StakingTransactions() { + discarded = append(discarded, tx) + } if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { utils.Logger().Error(). Str("block", oldHead.Number().String()). @@ -422,7 +436,12 @@ func (pool *TxPool) reset(oldHead, newHead *block.Header) { } } for add.NumberU64() > rem.NumberU64() { - included = append(included, add.Transactions()...) + for _, tx := range add.Transactions() { + included = append(included, tx) + } + for _, tx := range add.StakingTransactions() { + included = append(included, tx) + } if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { utils.Logger().Error(). Str("block", newHead.Number().String()). @@ -432,7 +451,12 @@ func (pool *TxPool) reset(oldHead, newHead *block.Header) { } } for rem.Hash() != add.Hash() { - discarded = append(discarded, rem.Transactions()...) + for _, tx := range rem.Transactions() { + discarded = append(discarded, tx) + } + for _, tx := range rem.StakingTransactions() { + discarded = append(discarded, tx) + } if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { utils.Logger().Error(). Str("block", oldHead.Number().String()). @@ -440,7 +464,12 @@ func (pool *TxPool) reset(oldHead, newHead *block.Header) { Msg("Unrooted old chain seen by tx pool") return } - included = append(included, add.Transactions()...) + for _, tx := range add.Transactions() { + included = append(included, tx) + } + for _, tx := range add.StakingTransactions() { + included = append(included, tx) + } if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { utils.Logger().Error(). Str("block", newHead.Number().String()). @@ -449,7 +478,7 @@ func (pool *TxPool) reset(oldHead, newHead *block.Header) { return } } - reinject = types.TxDifference(discarded, included) + reinject = types.PoolTxDifference(discarded, included) } } // Initialize the internal state to the current head @@ -566,15 +595,15 @@ func (pool *TxPool) stats() (int, int) { // Content retrieves the data content of the transaction pool, returning all the // pending as well as queued transactions, grouped by account and sorted by nonce. -func (pool *TxPool) Content() (map[common.Address]types.Transactions, map[common.Address]types.Transactions) { +func (pool *TxPool) Content() (map[common.Address]types.PoolTransactions, map[common.Address]types.PoolTransactions) { pool.mu.Lock() defer pool.mu.Unlock() - pending := make(map[common.Address]types.Transactions) + pending := make(map[common.Address]types.PoolTransactions) for addr, list := range pool.pending { pending[addr] = list.Flatten() } - queued := make(map[common.Address]types.Transactions) + queued := make(map[common.Address]types.PoolTransactions) for addr, list := range pool.queue { queued[addr] = list.Flatten() } @@ -584,11 +613,11 @@ func (pool *TxPool) Content() (map[common.Address]types.Transactions, map[common // Pending retrieves all currently processable transactions, grouped by origin // account and sorted by nonce. The returned transaction set is a copy and can be // freely modified by calling code. -func (pool *TxPool) Pending() (map[common.Address]types.Transactions, error) { +func (pool *TxPool) Pending() (map[common.Address]types.PoolTransactions, error) { pool.mu.Lock() defer pool.mu.Unlock() - pending := make(map[common.Address]types.Transactions) + pending := make(map[common.Address]types.PoolTransactions) for addr, list := range pool.pending { pending[addr] = list.Flatten() } @@ -606,8 +635,8 @@ func (pool *TxPool) Locals() []common.Address { // local retrieves all currently known local transactions, grouped by origin // account and sorted by nonce. The returned transaction set is a copy and can be // freely modified by calling code. -func (pool *TxPool) local() map[common.Address]types.Transactions { - txs := make(map[common.Address]types.Transactions) +func (pool *TxPool) local() map[common.Address]types.PoolTransactions { + txs := make(map[common.Address]types.PoolTransactions) for addr := range pool.locals.accounts { if pending := pool.pending[addr]; pending != nil { txs[addr] = append(txs[addr], pending.Flatten()...) @@ -621,7 +650,7 @@ func (pool *TxPool) local() map[common.Address]types.Transactions { // validateTx checks whether a transaction is valid according to the consensus // rules and adheres to some heuristic limits of the local node (price and size). -func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { +func (pool *TxPool) validateTx(tx types.PoolTransaction, local bool) error { // Heuristic limit, reject transactions over 32KB to prevent DOS attacks if tx.Size() > 32*1024 { return errors.WithMessagef(ErrOversizedData, "transaction size is %s", tx.Size().String()) @@ -636,7 +665,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { return errors.WithMessagef(ErrGasLimit, "transaction gas is %d", tx.Gas()) } // Make sure the transaction is signed properly - from, err := types.Sender(pool.signer, tx) + from, err := types.PoolTransactionSender(pool.signer, tx) if err != nil { if b32, err := hmyCommon.AddressToBech32(from); err == nil { return errors.WithMessagef(ErrInvalidSender, "transaction sender is %s", b32) @@ -675,17 +704,131 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { if pool.currentState.GetBalance(from).Cmp(tx.Cost()) < 0 { return ErrInsufficientFunds } - // TODO(Daniel): add support for staking txn - create validator - intrGas, err := IntrinsicGas(tx.Data(), tx.To() == nil, pool.homestead, false) + intrGas := uint64(0) + stakingTx, isStakingTx := tx.(*staking.StakingTransaction) + if isStakingTx { + intrGas, err = IntrinsicGas(tx.Data(), false, pool.homestead, stakingTx.StakingType() == staking.DirectiveCreateValidator) + } else { + intrGas, err = IntrinsicGas(tx.Data(), tx.To() == nil, pool.homestead, false) + } if err != nil { return err } if tx.Gas() < intrGas { return errors.WithMessagef(ErrIntrinsicGas, "transaction gas is %d", tx.Gas()) } + // Do more checks if it is a staking transaction + if isStakingTx { + return pool.validateStakingTx(stakingTx) + } return nil } +// validateStakingTx checks the staking message based on the staking directive +func (pool *TxPool) validateStakingTx(tx *staking.StakingTransaction) error { + // from address already validated + from, _ := types.PoolTransactionSender(pool.signer, tx) + b32, _ := hmyCommon.AddressToBech32(from) + + switch tx.StakingType() { + case staking.DirectiveCreateValidator: + msg, err := staking.RLPDecodeStakeMsg(tx.Data(), staking.DirectiveCreateValidator) + if err != nil { + return err + } + stkMsg, ok := msg.(*staking.CreateValidator) + if !ok { + return ErrInvalidMsgForStakingDirective + } + if from != stkMsg.ValidatorAddress { + return errors.WithMessagef(ErrInvalidSender, "staking transaction sender is %s", b32) + } + currentBlockNumber := pool.chain.CurrentBlock().Number() + pendingBlockNumber := new(big.Int).Add(currentBlockNumber, big.NewInt(1)) + pendingEpoch := pool.chain.CurrentBlock().Epoch() + if shard.Schedule.IsLastBlock(currentBlockNumber.Uint64()) { + pendingEpoch = new(big.Int).Add(pendingEpoch, big.NewInt(1)) + } + _, err = VerifyAndCreateValidatorFromMsg(pool.currentState, pendingEpoch, pendingBlockNumber, stkMsg) + return err + case staking.DirectiveEditValidator: + msg, err := staking.RLPDecodeStakeMsg(tx.Data(), staking.DirectiveEditValidator) + if err != nil { + return err + } + stkMsg, ok := msg.(*staking.EditValidator) + if !ok { + return ErrInvalidMsgForStakingDirective + } + if from != stkMsg.ValidatorAddress { + return errors.WithMessagef(ErrInvalidSender, "staking transaction sender is %s", b32) + } + chainContext, ok := pool.chain.(ChainContext) + if !ok { + chainContext = nil // might use testing blockchain, set to nil for verifier to handle. + } + pendingBlockNumber := new(big.Int).Add(pool.chain.CurrentBlock().Number(), big.NewInt(1)) + _, err = VerifyAndEditValidatorFromMsg(pool.currentState, chainContext, pendingBlockNumber, stkMsg) + return err + case staking.DirectiveDelegate: + msg, err := staking.RLPDecodeStakeMsg(tx.Data(), staking.DirectiveDelegate) + if err != nil { + return err + } + stkMsg, ok := msg.(*staking.Delegate) + if !ok { + return ErrInvalidMsgForStakingDirective + } + if from != stkMsg.DelegatorAddress { + return errors.WithMessagef(ErrInvalidSender, "staking transaction sender is %s", b32) + } + _, _, err = VerifyAndDelegateFromMsg(pool.currentState, stkMsg) + return err + case staking.DirectiveUndelegate: + msg, err := staking.RLPDecodeStakeMsg(tx.Data(), staking.DirectiveUndelegate) + if err != nil { + return err + } + stkMsg, ok := msg.(*staking.Undelegate) + if !ok { + return ErrInvalidMsgForStakingDirective + } + if from != stkMsg.DelegatorAddress { + return errors.WithMessagef(ErrInvalidSender, "staking transaction sender is %s", b32) + } + pendingEpoch := pool.chain.CurrentBlock().Epoch() + if shard.Schedule.IsLastBlock(pool.chain.CurrentBlock().Number().Uint64()) { + pendingEpoch = new(big.Int).Add(pendingEpoch, big.NewInt(1)) + } + _, err = VerifyAndUndelegateFromMsg(pool.currentState, pendingEpoch, stkMsg) + return err + case staking.DirectiveCollectRewards: + msg, err := staking.RLPDecodeStakeMsg(tx.Data(), staking.DirectiveCollectRewards) + if err != nil { + return err + } + stkMsg, ok := msg.(*staking.CollectRewards) + if !ok { + return ErrInvalidMsgForStakingDirective + } + if from != stkMsg.DelegatorAddress { + return errors.WithMessagef(ErrInvalidSender, "staking transaction sender is %s", b32) + } + chain, ok := pool.chain.(ChainContext) + if !ok { + return nil // for testing, chain could be testing blockchain + } + delegations, err := chain.ReadDelegationsByDelegator(stkMsg.DelegatorAddress) + if err != nil { + return err + } + _, _, err = VerifyAndCollectRewardsFromDelegation(pool.currentState, delegations) + return err + default: + return staking.ErrInvalidStakingKind + } +} + // add validates a transaction and inserts it into the non-executable queue for // later pending promotion and execution. If the transaction is a replacement for // an already pending or queued one, it overwrites the previous and returns this @@ -694,7 +837,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { // If a newly added transaction is marked as local, its sending account will be // whitelisted, preventing any associated transaction from being dropped out of // the pool due to pricing constraints. -func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) { +func (pool *TxPool) add(tx types.PoolTransaction, local bool) (bool, error) { logger := utils.Logger().With().Stack().Logger() // If the transaction is already known, discard it hash := tx.Hash() @@ -733,7 +876,7 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) { } } // If the transaction is replacing an already pending one, do directly - from, _ := types.Sender(pool.signer, tx) // already validated + from, _ := types.PoolTransactionSender(pool.signer, tx) // already validated if list := pool.pending[from]; list != nil && list.Overlaps(tx) { // Nonce already pending, check if required price bump is met inserted, old := list.Add(tx, pool.config.PriceBump) @@ -758,7 +901,7 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) { Msg("Pooled new executable transaction") // We've directly injected a replacement transaction, notify subsystems - // go pool.txFeed.Send(NewTxsEvent{types.Transactions{tx}}) + // go pool.txFeed.Send(NewTxsEvent{types.PoolTransactions{tx}}) return old != nil, nil } @@ -786,7 +929,7 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) { // Add adds a transaction to the pool if valid and passes it to the tx relay // backend -func (pool *TxPool) Add(ctx context.Context, tx *types.Transaction) error { +func (pool *TxPool) Add(ctx context.Context, tx *types.PoolTransaction) error { // TODO(ricl): placeholder // TODO(minhdoan): follow with richard why we need this. As of now TxPool is not used now. return nil @@ -795,9 +938,9 @@ func (pool *TxPool) Add(ctx context.Context, tx *types.Transaction) error { // enqueueTx inserts a new transaction into the non-executable transaction queue. // // Note, this method assumes the pool lock is held! -func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, error) { +func (pool *TxPool) enqueueTx(hash common.Hash, tx types.PoolTransaction) (bool, error) { // Try to insert the transaction into the future queue - from, _ := types.Sender(pool.signer, tx) // already validated + from, _ := types.PoolTransactionSender(pool.signer, tx) // already validated if pool.queue[from] == nil { pool.queue[from] = newTxList(false) } @@ -822,7 +965,7 @@ func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, er // journalTx adds the specified transaction to the local disk journal if it is // deemed to have been sent from a local account. -func (pool *TxPool) journalTx(from common.Address, tx *types.Transaction) { +func (pool *TxPool) journalTx(from common.Address, tx types.PoolTransaction) { // Only journal if it's enabled and the transaction is local if pool.journal == nil || !pool.locals.contains(from) { return @@ -836,7 +979,7 @@ func (pool *TxPool) journalTx(from common.Address, tx *types.Transaction) { // and returns whether it was inserted or an older was better. // // Note, this method assumes the pool lock is held! -func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.Transaction) bool { +func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx types.PoolTransaction) bool { // Try to insert the transaction into the pending queue if pool.pending[addr] == nil { pool.pending[addr] = newTxList(true) @@ -874,33 +1017,33 @@ func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.T // AddLocal enqueues a single transaction into the pool if it is valid, marking // the sender as a local one in the mean time, ensuring it goes around the local // pricing constraints. -func (pool *TxPool) AddLocal(tx *types.Transaction) error { +func (pool *TxPool) AddLocal(tx types.PoolTransaction) error { return pool.addTx(tx, !pool.config.NoLocals) } // AddRemote enqueues a single transaction into the pool if it is valid. If the // sender is not among the locally tracked ones, full pricing constraints will // apply. -func (pool *TxPool) AddRemote(tx *types.Transaction) error { +func (pool *TxPool) AddRemote(tx types.PoolTransaction) error { return pool.addTx(tx, false) } // AddLocals enqueues a batch of transactions into the pool if they are valid, // marking the senders as a local ones in the mean time, ensuring they go around // the local pricing constraints. -func (pool *TxPool) AddLocals(txs []*types.Transaction) []error { +func (pool *TxPool) AddLocals(txs types.PoolTransactions) []error { return pool.addTxs(txs, !pool.config.NoLocals) } // AddRemotes enqueues a batch of transactions into the pool if they are valid. // If the senders are not among the locally tracked ones, full pricing constraints // will apply. -func (pool *TxPool) AddRemotes(txs []*types.Transaction) []error { +func (pool *TxPool) AddRemotes(txs types.PoolTransactions) []error { return pool.addTxs(txs, false) } // addTx enqueues a single transaction into the pool if it is valid. -func (pool *TxPool) addTx(tx *types.Transaction, local bool) error { +func (pool *TxPool) addTx(tx types.PoolTransaction, local bool) error { pool.mu.Lock() defer pool.mu.Unlock() @@ -909,20 +1052,24 @@ func (pool *TxPool) addTx(tx *types.Transaction, local bool) error { if err != nil { errCause := errors.Cause(err) if errCause != ErrKnownTransaction { - pool.txnErrorSink([]types.RPCTransactionError{*types.NewRPCTransactionError(tx.Hash(), err)}) + pool.errorReporter.add(tx, err) } return errCause } // If we added a new transaction, run promotion checks and return if !replace { - from, _ := types.Sender(pool.signer, tx) // already validated + from, _ := types.PoolTransactionSender(pool.signer, tx) // already validated pool.promoteExecutables([]common.Address{from}) } + if err := pool.errorReporter.report(); err != nil { + utils.Logger().Error().Err(err). + Msg("could not report failed transactions in tx pool when adding 1 tx") + } return nil } // addTxs attempts to queue a batch of transactions if they are valid. -func (pool *TxPool) addTxs(txs []*types.Transaction, local bool) []error { +func (pool *TxPool) addTxs(txs types.PoolTransactions, local bool) []error { pool.mu.Lock() defer pool.mu.Unlock() @@ -931,22 +1078,22 @@ func (pool *TxPool) addTxs(txs []*types.Transaction, local bool) []error { // addTxsLocked attempts to queue a batch of transactions if they are valid, // whilst assuming the transaction pool lock is already held. -func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) []error { +func (pool *TxPool) addTxsLocked(txs types.PoolTransactions, local bool) []error { // Add the batch of transaction, tracking the accepted ones dirty := make(map[common.Address]struct{}) - errs := make([]error, len(txs)) - erroredTxns := []types.RPCTransactionError{} + errs := make([]error, txs.Len()) for i, tx := range txs { replace, err := pool.add(tx, local) if err == nil && !replace { - from, _ := types.Sender(pool.signer, tx) // already validated + from, _ := types.PoolTransactionSender(pool.signer, tx) // already validated dirty[from] = struct{}{} } - if err != nil && err != ErrKnownTransaction { - erroredTxns = append(erroredTxns, *types.NewRPCTransactionError(tx.Hash(), err)) + errCause := errors.Cause(err) + if err != nil && errCause != ErrKnownTransaction { + pool.errorReporter.add(tx, err) } - errs[i] = errors.Cause(err) + errs[i] = errCause } // Only reprocess the internal state if something was actually added if len(dirty) > 0 { @@ -957,7 +1104,10 @@ func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) []error { pool.promoteExecutables(addrs) } - pool.txnErrorSink(erroredTxns) + if err := pool.errorReporter.report(); err != nil { + utils.Logger().Error().Err(err). + Msg("could not report failed transactions in tx pool when adding txs") + } return errs } @@ -970,7 +1120,7 @@ func (pool *TxPool) Status(hashes []common.Hash) []TxStatus { status := make([]TxStatus, len(hashes)) for i, hash := range hashes { if tx := pool.all.Get(hash); tx != nil { - from, _ := types.Sender(pool.signer, tx) // already validated + from, _ := types.PoolTransactionSender(pool.signer, tx) // already validated if pool.pending[from] != nil && pool.pending[from].txs.items[tx.Nonce()] != nil { status[i] = TxStatusPending } else { @@ -983,7 +1133,7 @@ func (pool *TxPool) Status(hashes []common.Hash) []TxStatus { // Get returns a transaction if it is contained in the pool // and nil otherwise. -func (pool *TxPool) Get(hash common.Hash) *types.Transaction { +func (pool *TxPool) Get(hash common.Hash) types.PoolTransaction { return pool.all.Get(hash) } @@ -995,7 +1145,7 @@ func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { if tx == nil { return } - addr, _ := types.Sender(pool.signer, tx) // already validated during insertion + addr, _ := types.PoolTransactionSender(pool.signer, tx) // already validated during insertion // Remove it from the list of known transactions pool.all.Remove(hash) @@ -1012,7 +1162,9 @@ func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { } // Postpone any invalidated transactions for _, tx := range invalids { - pool.enqueueTx(tx.Hash(), tx) + if _, err := pool.enqueueTx(tx.Hash(), tx); err != nil { + pool.errorReporter.add(tx, err) + } } // Update the account nonce if needed if nonce := tx.Nonce(); pool.pendingState.GetNonce(addr) > nonce { @@ -1028,6 +1180,11 @@ func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { delete(pool.queue, addr) } } + + if err := pool.errorReporter.report(); err != nil { + utils.Logger().Error().Err(err). + Msg("could not report failed transactions in tx pool when removing tx from queue") + } } // promoteExecutables moves transactions that have become processable from the @@ -1035,9 +1192,8 @@ func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { // invalidated transactions (low nonce, low balance) are deleted. func (pool *TxPool) promoteExecutables(accounts []common.Address) { // Track the promoted transactions to broadcast them at once - var promoted []*types.Transaction + var promoted types.PoolTransactions logger := utils.Logger().With().Stack().Logger() - erroredTxns := []types.RPCTransactionError{} // Gather all the accounts potentially needing updates if accounts == nil { @@ -1057,10 +1213,6 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { for _, tx := range list.Forward(nonce) { hash := tx.Hash() logger.Warn().Str("hash", hash.Hex()).Msg("Removed old queued transaction") - if pool.chain.CurrentBlock().Transaction(hash) == nil { - err := fmt.Errorf("old transaction, nonce %d is too low", nonce) - erroredTxns = append(erroredTxns, *types.NewRPCTransactionError(tx.Hash(), err)) - } pool.all.Remove(hash) pool.priced.Removed() } @@ -1069,8 +1221,6 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { for _, tx := range drops { hash := tx.Hash() logger.Warn().Str("hash", hash.Hex()).Msg("Removed unpayable queued transaction") - err := fmt.Errorf("unpayable transaction, out of gas or balance of %d cannot pay cost of %d", tx.Value(), tx.Cost()) - erroredTxns = append(erroredTxns, *types.NewRPCTransactionError(tx.Hash(), err)) pool.all.Remove(hash) pool.priced.Removed() queuedNofundsCounter.Inc(1) @@ -1088,8 +1238,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { for _, tx := range list.Cap(int(pool.config.AccountQueue)) { hash := tx.Hash() logger.Warn().Str("hash", hash.Hex()).Msg("Removed cap-exceeding queued transaction") - err := fmt.Errorf("exceeds cap for queued transactions for account %s", addr.String()) - erroredTxns = append(erroredTxns, *types.NewRPCTransactionError(tx.Hash(), err)) + pool.errorReporter.add(tx, fmt.Errorf("exceeds cap for queued transactions for account %s", addr.String())) pool.all.Remove(hash) pool.priced.Removed() queuedRateLimitCounter.Inc(1) @@ -1138,8 +1287,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { for _, tx := range list.Cap(list.Len() - 1) { // Drop the transaction from the global pools too hash := tx.Hash() - err := fmt.Errorf("fairness-exceeding pending transaction") - erroredTxns = append(erroredTxns, *types.NewRPCTransactionError(tx.Hash(), err)) + pool.errorReporter.add(tx, fmt.Errorf("fairness-exceeding pending transaction")) pool.all.Remove(hash) pool.priced.Removed() @@ -1162,8 +1310,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { for _, tx := range list.Cap(list.Len() - 1) { // Drop the transaction from the global pools too hash := tx.Hash() - err := fmt.Errorf("fairness-exceeding pending transaction") - erroredTxns = append(erroredTxns, *types.NewRPCTransactionError(tx.Hash(), err)) + pool.errorReporter.add(tx, fmt.Errorf("fairness-exceeding pending transaction")) pool.all.Remove(hash) pool.priced.Removed() @@ -1204,8 +1351,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { // Drop all transactions if they are less than the overflow if size := uint64(list.Len()); size <= drop { for _, tx := range list.Flatten() { - err := fmt.Errorf("exceeds global cap for queued transactions") - erroredTxns = append(erroredTxns, *types.NewRPCTransactionError(tx.Hash(), err)) + pool.errorReporter.add(tx, fmt.Errorf("exceeds global cap for queued transactions")) pool.removeTx(tx.Hash(), true) } drop -= size @@ -1215,8 +1361,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { // Otherwise drop only last few transactions txs := list.Flatten() for i := len(txs) - 1; i >= 0 && drop > 0; i-- { - err := fmt.Errorf("exceeds global cap for queued transactions") - erroredTxns = append(erroredTxns, *types.NewRPCTransactionError(txs[i].Hash(), err)) + pool.errorReporter.add(txs[i], fmt.Errorf("exceeds global cap for queued transactions")) pool.removeTx(txs[i].Hash(), true) drop-- queuedRateLimitCounter.Inc(1) @@ -1224,7 +1369,10 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { } } - pool.txnErrorSink(erroredTxns) + if err := pool.errorReporter.report(); err != nil { + logger.Error().Err(err). + Msg("could not report failed transactions in tx pool when promoting executables") + } } // demoteUnexecutables removes invalid and processed transactions from the pools @@ -1233,7 +1381,6 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { func (pool *TxPool) demoteUnexecutables() { // Iterate over all accounts and demote any non-executable transactions logger := utils.Logger().With().Stack().Logger() - erroredTxns := []types.RPCTransactionError{} for addr, list := range pool.pending { nonce := pool.currentState.GetNonce(addr) @@ -1242,10 +1389,6 @@ func (pool *TxPool) demoteUnexecutables() { for _, tx := range list.Forward(nonce) { hash := tx.Hash() logger.Warn().Str("hash", hash.Hex()).Msg("Removed old pending transaction") - if pool.chain.CurrentBlock().Transaction(hash) == nil { - err := fmt.Errorf("old transaction, nonce %d is too low", nonce) - erroredTxns = append(erroredTxns, *types.NewRPCTransactionError(tx.Hash(), err)) - } pool.all.Remove(hash) pool.priced.Removed() } @@ -1254,8 +1397,6 @@ func (pool *TxPool) demoteUnexecutables() { for _, tx := range drops { hash := tx.Hash() logger.Warn().Str("hash", hash.Hex()).Msg("Removed unpayable pending transaction") - err := fmt.Errorf("unpayable transaction, out of gas or balance of %d cannot pay cost of %d", tx.Value(), tx.Cost()) - erroredTxns = append(erroredTxns, *types.NewRPCTransactionError(tx.Hash(), err)) pool.all.Remove(hash) pool.priced.Removed() pendingNofundsCounter.Inc(1) @@ -1263,18 +1404,18 @@ func (pool *TxPool) demoteUnexecutables() { for _, tx := range invalids { hash := tx.Hash() logger.Warn().Str("hash", hash.Hex()).Msg("Demoting pending transaction") - err := fmt.Errorf("demoting pending transaction") - erroredTxns = append(erroredTxns, *types.NewRPCTransactionError(tx.Hash(), err)) - pool.enqueueTx(hash, tx) + if _, err := pool.enqueueTx(hash, tx); err != nil { + pool.errorReporter.add(tx, err) + } } // If there's a gap in front, alert (should never happen) and postpone all transactions if list.Len() > 0 && list.txs.Get(nonce) == nil { for _, tx := range list.Cap(0) { hash := tx.Hash() logger.Error().Str("hash", hash.Hex()).Msg("Demoting invalidated transaction") - err := fmt.Errorf("demoting invalid transaction") - erroredTxns = append(erroredTxns, *types.NewRPCTransactionError(tx.Hash(), err)) - pool.enqueueTx(hash, tx) + if _, err := pool.enqueueTx(hash, tx); err != nil { + pool.errorReporter.add(tx, err) + } } } // Delete the entire queue entry if it became empty. @@ -1283,8 +1424,61 @@ func (pool *TxPool) demoteUnexecutables() { delete(pool.beats, addr) } - pool.txnErrorSink(erroredTxns) + if err := pool.errorReporter.report(); err != nil { + logger.Error().Err(err). + Msg("could not report failed transactions in tx pool when demoting unexecutables") + } + } +} + +// txPoolErrorReporter holds and reports transaction errors in the tx-pool. +// Format assumes that error i in errors corresponds to transaction i in transactions. +type txPoolErrorReporter struct { + transactions types.PoolTransactions + errors []error + txnErrorReportSink func([]types.RPCTransactionError) + stkTxnErrorReportSink func([]staking.RPCTransactionError) +} + +func newTxPoolErrorReporter(txnErrorSink func([]types.RPCTransactionError), + stakingTxnErrorSink func([]staking.RPCTransactionError), +) *txPoolErrorReporter { + return &txPoolErrorReporter{ + transactions: types.PoolTransactions{}, + errors: []error{}, + txnErrorReportSink: txnErrorSink, + stkTxnErrorReportSink: stakingTxnErrorSink, + } +} + +func (txErrs *txPoolErrorReporter) add(tx types.PoolTransaction, err error) { + txErrs.transactions = append(txErrs.transactions, tx) + txErrs.errors = append(txErrs.errors, err) +} + +func (txErrs *txPoolErrorReporter) reset() { + txErrs.transactions = types.PoolTransactions{} + txErrs.errors = []error{} +} + +// report errors thrown in the tx pool to the appropriate error sink. +// It resets the held errors after the errors are reported to the sink. +func (txErrs *txPoolErrorReporter) report() error { + plainTxErrors := []types.RPCTransactionError{} + stakingTxErrors := []staking.RPCTransactionError{} + for i, tx := range txErrs.transactions { + if plainTx, ok := tx.(*types.Transaction); ok { + plainTxErrors = append(plainTxErrors, types.NewRPCTransactionError(plainTx.Hash(), txErrs.errors[i])) + } else if stakingTx, ok := tx.(*staking.StakingTransaction); ok { + stakingTxErrors = append(stakingTxErrors, staking.NewRPCTransactionError(stakingTx.Hash(), stakingTx.StakingType(), txErrs.errors[i])) + } else { + return types.ErrUnknownPoolTxType + } } + txErrs.txnErrorReportSink(plainTxErrors) + txErrs.stkTxnErrorReportSink(stakingTxErrors) + txErrs.reset() + return nil } // addressByHeartbeat is an account address tagged with its last activity timestamp. @@ -1324,8 +1518,8 @@ func (as *accountSet) contains(addr common.Address) bool { // containsTx checks if the sender of a given tx is within the set. If the sender // cannot be derived, this method returns false. -func (as *accountSet) containsTx(tx *types.Transaction) bool { - if addr, err := types.Sender(as.signer, tx); err == nil { +func (as *accountSet) containsTx(tx types.PoolTransaction) bool { + if addr, err := types.PoolTransactionSender(as.signer, tx); err == nil { return as.contains(addr) } return false @@ -1360,19 +1554,19 @@ func (as *accountSet) flatten() []common.Address { // peeking into the pool in TxPool.Get without having to acquire the widely scoped // TxPool.mu mutex. type txLookup struct { - all map[common.Hash]*types.Transaction + all map[common.Hash]types.PoolTransaction lock sync.RWMutex } // newTxLookup returns a new txLookup structure. func newTxLookup() *txLookup { return &txLookup{ - all: make(map[common.Hash]*types.Transaction), + all: make(map[common.Hash]types.PoolTransaction), } } // Range calls f on each key and value present in the map. -func (t *txLookup) Range(f func(hash common.Hash, tx *types.Transaction) bool) { +func (t *txLookup) Range(f func(hash common.Hash, tx types.PoolTransaction) bool) { t.lock.RLock() defer t.lock.RUnlock() @@ -1384,7 +1578,7 @@ func (t *txLookup) Range(f func(hash common.Hash, tx *types.Transaction) bool) { } // Get returns a transaction if it exists in the lookup, or nil if not found. -func (t *txLookup) Get(hash common.Hash) *types.Transaction { +func (t *txLookup) Get(hash common.Hash) types.PoolTransaction { t.lock.RLock() defer t.lock.RUnlock() @@ -1400,7 +1594,7 @@ func (t *txLookup) Count() int { } // Add adds a transaction to the lookup. -func (t *txLookup) Add(tx *types.Transaction) { +func (t *txLookup) Add(tx types.PoolTransaction) { t.lock.Lock() defer t.lock.Unlock() diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index 3992f7044..98f82a74d 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -30,16 +30,24 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" + "github.com/harmony-one/bls/ffi/go/bls" blockfactory "github.com/harmony-one/harmony/block/factory" "github.com/harmony-one/harmony/common/denominations" "github.com/harmony-one/harmony/core/state" "github.com/harmony-one/harmony/core/types" + "github.com/harmony-one/harmony/crypto/hash" "github.com/harmony-one/harmony/internal/params" + "github.com/harmony-one/harmony/numeric" + "github.com/harmony-one/harmony/shard" + staking "github.com/harmony-one/harmony/staking/types" ) -// testTxPoolConfig is a transaction pool configuration without stateful disk -// sideeffects used during testing. -var testTxPoolConfig TxPoolConfig +var ( + // testTxPoolConfig is a transaction pool configuration without stateful disk sideeffects used during testing. + testTxPoolConfig TxPoolConfig + testBLSPubKey = "30b2c38b1316da91e068ac3bd8751c0901ef6c02a1d58bc712104918302c6ed03d5894671d0c816dad2b4d303320f202" + testBLSPrvKey = "c6d7603520311f7a4e6aac0b26701fc433b75b38df504cd416ef2b900cd66205" +) func init() { testTxPoolConfig = DefaultTxPoolConfig @@ -70,21 +78,67 @@ func (bc *testBlockChain) SubscribeChainHeadEvent(ch chan<- ChainHeadEvent) even return bc.chainHeadFeed.Subscribe(ch) } -func transaction(nonce uint64, gaslimit uint64, key *ecdsa.PrivateKey) *types.Transaction { +// TODO: more staking tests in tx pool & testing lib +func stakingCreateValidatorTransaction(key *ecdsa.PrivateKey) (*staking.StakingTransaction, error) { + stakePayloadMaker := func() (staking.Directive, interface{}) { + p := &bls.PublicKey{} + p.DeserializeHexStr(testBLSPubKey) + pub := shard.BlsPublicKey{} + pub.FromLibBLSPublicKey(p) + messageBytes := []byte(staking.BlsVerificationStr) + privateKey := &bls.SecretKey{} + privateKey.DeserializeHexStr(testBLSPrvKey) + msgHash := hash.Keccak256(messageBytes) + signature := privateKey.SignHash(msgHash[:]) + var sig shard.BLSSignature + copy(sig[:], signature.Serialize()) + + ra, _ := numeric.NewDecFromStr("0.7") + maxRate, _ := numeric.NewDecFromStr("1") + maxChangeRate, _ := numeric.NewDecFromStr("0.5") + return staking.DirectiveCreateValidator, staking.CreateValidator{ + Description: staking.Description{ + Name: "SuperHero", + Identity: "YouWouldNotKnow", + Website: "Secret Website", + SecurityContact: "LicenseToKill", + Details: "blah blah blah", + }, + CommissionRates: staking.CommissionRates{ + Rate: ra, + MaxRate: maxRate, + MaxChangeRate: maxChangeRate, + }, + MinSelfDelegation: big.NewInt(1e18), + MaxTotalDelegation: big.NewInt(3e18), + ValidatorAddress: crypto.PubkeyToAddress(key.PublicKey), + SlotPubKeys: []shard.BlsPublicKey{pub}, + SlotKeySigs: []shard.BLSSignature{sig}, + Amount: big.NewInt(1e18), + } + } + + gasPrice := big.NewInt(10000) + tx, _ := staking.NewStakingTransaction(0, 1e10, gasPrice, stakePayloadMaker) + return staking.Sign(tx, staking.NewEIP155Signer(tx.ChainID()), key) +} + +func transaction(nonce uint64, gaslimit uint64, key *ecdsa.PrivateKey) types.PoolTransaction { return pricedTransaction(nonce, gaslimit, big.NewInt(1), key) } -func pricedTransaction(nonce uint64, gaslimit uint64, gasprice *big.Int, key *ecdsa.PrivateKey) *types.Transaction { - tx, _ := types.SignTx(types.NewTransaction(nonce, common.Address{}, 0, big.NewInt(100), gaslimit, gasprice, nil), types.HomesteadSigner{}, key) - return tx +func pricedTransaction(nonce uint64, gaslimit uint64, gasprice *big.Int, key *ecdsa.PrivateKey) types.PoolTransaction { + signedTx, _ := types.SignTx(types.NewTransaction(nonce, common.Address{}, 0, big.NewInt(100), gaslimit, gasprice, nil), types.HomesteadSigner{}, key) + return signedTx } func setupTxPool() (*TxPool, *ecdsa.PrivateKey) { statedb, _ := state.New(common.Hash{}, state.NewDatabase(ethdb.NewMemDatabase())) - blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} + blockchain := &testBlockChain{statedb, 1e18, new(event.Feed)} key, _ := crypto.GenerateKey() - pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain, func([]types.RPCTransactionError) {}) + pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain, + func([]types.RPCTransactionError) {}, func([]staking.RPCTransactionError) {}) return pool, key } @@ -146,8 +200,8 @@ func validateEvents(events chan NewTxsEvent, count int) error { return nil } -func deriveSender(tx *types.Transaction) (common.Address, error) { - return types.Sender(types.HomesteadSigner{}, tx) +func deriveSender(tx types.PoolTransaction) (common.Address, error) { + return types.PoolTransactionSender(types.HomesteadSigner{}, tx) } type testChain struct { @@ -194,7 +248,8 @@ func TestStateChangeDuringTransactionPoolReset(t *testing.T) { tx0 := transaction(0, 100000, key) tx1 := transaction(1, 100000, key) - pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain, func([]types.RPCTransactionError) {}) + pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain, + func([]types.RPCTransactionError) {}, func([]staking.RPCTransactionError) {}) defer pool.Stop() nonce := pool.State().GetNonce(address) @@ -202,7 +257,7 @@ func TestStateChangeDuringTransactionPoolReset(t *testing.T) { t.Fatalf("Invalid nonce, want 0, got %d", nonce) } - pool.AddRemotes(types.Transactions{tx0, tx1}) + pool.AddRemotes(types.PoolTransactions{tx0, tx1}) nonce = pool.State().GetNonce(address) if nonce != 2 { @@ -261,6 +316,65 @@ func TestInvalidTransactions(t *testing.T) { } } +func TestCreateValidatorTransaction(t *testing.T) { + t.Parallel() + + pool, _ := setupTxPool() + defer pool.Stop() + + fromKey, _ := crypto.GenerateKey() + stx, err := stakingCreateValidatorTransaction(fromKey) + if err != nil { + t.Errorf("cannot create new staking transaction, %v\n", err) + } + senderAddr, _ := stx.SenderAddress() + pool.currentState.AddBalance(senderAddr, big.NewInt(1e18)) + + err = pool.AddRemote(stx) + if err != nil { + t.Error(err.Error()) + } + + if pool.pending[senderAddr] == nil || pool.pending[senderAddr].Len() != 1 { + t.Error("Expected 1 pending transaction") + } +} + +func TestMixedTransactions(t *testing.T) { + t.Parallel() + + pool, _ := setupTxPool() + defer pool.Stop() + + fromKey, _ := crypto.GenerateKey() + stx, err := stakingCreateValidatorTransaction(fromKey) + if err != nil { + t.Errorf("cannot create new staking transaction, %v\n", err) + } + stxAddr, _ := stx.SenderAddress() + pool.currentState.AddBalance(stxAddr, big.NewInt(1e18)) + + goodFromKey, _ := crypto.GenerateKey() + tx := transaction(0, 25000, goodFromKey) + txAddr, _ := deriveSender(tx) + pool.currentState.AddBalance(txAddr, big.NewInt(50100)) + + errs := pool.AddRemotes(types.PoolTransactions{stx, tx}) + for _, err := range errs { + if err != nil { + t.Error(err) + } + } + + if pool.pending[stxAddr] == nil || pool.pending[stxAddr].Len() != 1 { + t.Error("Expected 1 pending transaction") + } + + if pool.pending[txAddr] == nil || pool.pending[txAddr].Len() != 1 { + t.Error("Expected 1 pending transaction") + } +} + func TestBlacklistedTransactions(t *testing.T) { // DO NOT parallelize, test will add accounts to tx pool config. @@ -284,20 +398,20 @@ func TestBlacklistedTransactions(t *testing.T) { pool.currentState.AddBalance(goodFromAcc, big.NewInt(50100)) (*DefaultTxPoolConfig.Blacklist)[bannedToAcc] = struct{}{} - err := pool.AddRemotes([]*types.Transaction{badTx}) + err := pool.AddRemotes(types.PoolTransactions{badTx}) if err[0] != ErrBlacklistTo { t.Error("expected", ErrBlacklistTo, "got", err[0]) } delete(*DefaultTxPoolConfig.Blacklist, bannedToAcc) (*DefaultTxPoolConfig.Blacklist)[bannedFromAcc] = struct{}{} - err = pool.AddRemotes([]*types.Transaction{badTx}) + err = pool.AddRemotes(types.PoolTransactions{badTx}) if err[0] != ErrBlacklistFrom { t.Error("expected", ErrBlacklistFrom, "got", err[0]) } // to acc is same for bad and good tx, so keep off blacklist for valid tx check - err = pool.AddRemotes([]*types.Transaction{goodTx}) + err = pool.AddRemotes(types.PoolTransactions{goodTx}) if err[0] != nil { t.Error("expected", nil, "got", err[0]) } @@ -346,9 +460,9 @@ func TestTransactionQueue(t *testing.T) { pool.currentState.AddBalance(from, big.NewInt(1000)) pool.lockedReset(nil, nil) - pool.enqueueTx(tx1.Hash(), tx1) - pool.enqueueTx(tx2.Hash(), tx2) - pool.enqueueTx(tx3.Hash(), tx3) + pool.enqueueTx(tx.Hash(), tx1) + pool.enqueueTx(tx.Hash(), tx2) + pool.enqueueTx(tx.Hash(), tx3) pool.promoteExecutables([]common.Address{from}) @@ -366,7 +480,9 @@ func TestTransactionNegativeValue(t *testing.T) { pool, key := setupTxPool() defer pool.Stop() - tx, _ := types.SignTx(types.NewTransaction(0, common.Address{}, 0, big.NewInt(-1), 100, big.NewInt(1), nil), types.HomesteadSigner{}, key) + tx, _ := types.SignTx( + types.NewTransaction(0, common.Address{}, 0, big.NewInt(-1), 100, big.NewInt(1), nil), + types.HomesteadSigner{}, key) from, _ := deriveSender(tx) pool.currentState.AddBalance(from, big.NewInt(1)) if err := pool.AddRemote(tx); err != ErrNegativeValue { @@ -420,9 +536,15 @@ func TestTransactionDoubleNonce(t *testing.T) { resetState() signer := types.HomesteadSigner{} - tx1, _ := types.SignTx(types.NewTransaction(0, common.Address{}, 0, big.NewInt(100), 100000, big.NewInt(1), nil), signer, key) - tx2, _ := types.SignTx(types.NewTransaction(0, common.Address{}, 0, big.NewInt(100), 1000000, big.NewInt(2), nil), signer, key) - tx3, _ := types.SignTx(types.NewTransaction(0, common.Address{}, 0, big.NewInt(100), 1000000, big.NewInt(1), nil), signer, key) + tx1, _ := types.SignTx( + types.NewTransaction(0, common.Address{}, 0, big.NewInt(100), 100000, big.NewInt(1), nil), + signer, key) + tx2, _ := types.SignTx( + types.NewTransaction(0, common.Address{}, 0, big.NewInt(100), 1000000, big.NewInt(2), nil), + signer, key) + tx3, _ := types.SignTx( + types.NewTransaction(0, common.Address{}, 0, big.NewInt(100), 1000000, big.NewInt(1), nil), + signer, key) // Add the first two transaction, ensure higher priced stays only if replace, err := pool.add(tx1, false); err != nil || replace { @@ -435,8 +557,8 @@ func TestTransactionDoubleNonce(t *testing.T) { if pool.pending[addr].Len() != 1 { t.Error("expected 1 pending transactions, got", pool.pending[addr].Len()) } - if tx := pool.pending[addr].txs.items[0]; tx.Hash() != tx2.Hash() { - t.Errorf("transaction mismatch: have %x, want %x", tx.Hash(), tx2.Hash()) + if tx := pool.pending[addr].txs.items[0]; tx.Hash() != (*tx2).Hash() { + t.Errorf("transaction mismatch: have %x, want %x", tx.Hash(), (*tx2).Hash()) } // Add the third transaction and ensure it's not saved (smaller price) pool.add(tx3, false) @@ -444,8 +566,8 @@ func TestTransactionDoubleNonce(t *testing.T) { if pool.pending[addr].Len() != 1 { t.Error("expected 1 pending transactions, got", pool.pending[addr].Len()) } - if tx := pool.pending[addr].txs.items[0]; tx.Hash() != tx2.Hash() { - t.Errorf("transaction mismatch: have %x, want %x", tx.Hash(), tx2.Hash()) + if tx := pool.pending[addr].txs.items[0]; tx.Hash() != (*tx2).Hash() { + t.Errorf("transaction mismatch: have %x, want %x", tx.Hash(), (*tx2).Hash()) } // Ensure the total transaction count is correct if pool.all.Count() != 1 { @@ -604,7 +726,8 @@ func TestTransactionPostponing(t *testing.T) { statedb, _ := state.New(common.Hash{}, state.NewDatabase(ethdb.NewMemDatabase())) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} - pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain, func([]types.RPCTransactionError) {}) + pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain, + func([]types.RPCTransactionError) {}, func([]staking.RPCTransactionError) {}) defer pool.Stop() // Create two test accounts to produce different gap profiles with @@ -618,11 +741,11 @@ func TestTransactionPostponing(t *testing.T) { pool.currentState.AddBalance(crypto.PubkeyToAddress(keys[i].PublicKey), big.NewInt(50100)) } // Add a batch consecutive pending transactions for validation - txs := []*types.Transaction{} + txs := types.PoolTransactions{} for i, key := range keys { for j := 0; j < 100; j++ { - var tx *types.Transaction + var tx types.PoolTransaction if (i+j)%2 == 0 { tx = transaction(uint64(j), 25000, key) } else { @@ -766,7 +889,8 @@ func testTransactionQueueGlobalLimiting(t *testing.T, nolocals bool) { config.NoLocals = nolocals config.GlobalQueue = config.AccountQueue*3 - 1 // reduce the queue limits to shorten test time (-1 to make it non divisible) - pool := NewTxPool(config, params.TestChainConfig, blockchain, func([]types.RPCTransactionError) {}) + pool := NewTxPool(config, params.TestChainConfig, blockchain, + func([]types.RPCTransactionError) {}, func([]staking.RPCTransactionError) {}) defer pool.Stop() // Create a number of test accounts and fund them (last one will be the local) @@ -780,7 +904,7 @@ func testTransactionQueueGlobalLimiting(t *testing.T, nolocals bool) { // Generate and queue a batch of transactions nonces := make(map[common.Address]uint64) - txs := make(types.Transactions, 0, 3*config.GlobalQueue) + txs := make(types.PoolTransactions, 0, 3*config.GlobalQueue) for len(txs) < cap(txs) { key := keys[rand.Intn(len(keys)-1)] // skip adding transactions with the local account addr := crypto.PubkeyToAddress(key.PublicKey) @@ -854,7 +978,8 @@ func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) { config.Lifetime = time.Second config.NoLocals = nolocals - pool := NewTxPool(config, params.TestChainConfig, blockchain, func([]types.RPCTransactionError) {}) + pool := NewTxPool(config, params.TestChainConfig, blockchain, + func([]types.RPCTransactionError) {}, func([]staking.RPCTransactionError) {}) defer pool.Stop() // Create two test accounts to ensure remotes expire but locals do not @@ -929,7 +1054,7 @@ func testTransactionLimitingEquivalency(t *testing.T, origin uint64) { account2, _ := deriveSender(transaction(0, 0, key2)) pool2.currentState.AddBalance(account2, big.NewInt(1000000)) - txs := []*types.Transaction{} + txs := types.PoolTransactions{} for i := uint64(0); i < testTxPoolConfig.AccountQueue+5; i++ { txs = append(txs, transaction(origin+i, 100000, key2)) } @@ -966,7 +1091,8 @@ func TestTransactionPendingGlobalLimiting(t *testing.T) { config := testTxPoolConfig config.GlobalSlots = config.AccountSlots * 10 - pool := NewTxPool(config, params.TestChainConfig, blockchain, func([]types.RPCTransactionError) {}) + pool := NewTxPool(config, params.TestChainConfig, blockchain, + func([]types.RPCTransactionError) {}, func([]staking.RPCTransactionError) {}) defer pool.Stop() // Create a number of test accounts and fund them @@ -978,7 +1104,7 @@ func TestTransactionPendingGlobalLimiting(t *testing.T) { // Generate and queue a batch of transactions nonces := make(map[common.Address]uint64) - txs := types.Transactions{} + txs := types.PoolTransactions{} for _, key := range keys { addr := crypto.PubkeyToAddress(key.PublicKey) for j := 0; j < int(config.GlobalSlots)/len(keys)*2; j++ { @@ -1014,7 +1140,8 @@ func TestTransactionCapClearsFromAll(t *testing.T) { config.AccountQueue = 2 config.GlobalSlots = 8 - pool := NewTxPool(config, params.TestChainConfig, blockchain, func([]types.RPCTransactionError) {}) + pool := NewTxPool(config, params.TestChainConfig, blockchain, + func([]types.RPCTransactionError) {}, func([]staking.RPCTransactionError) {}) defer pool.Stop() // Create a number of test accounts and fund them @@ -1022,7 +1149,7 @@ func TestTransactionCapClearsFromAll(t *testing.T) { addr := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(addr, big.NewInt(1000000)) - txs := types.Transactions{} + txs := types.PoolTransactions{} for j := 0; j < int(config.GlobalSlots)*2; j++ { txs = append(txs, transaction(uint64(j), 100000, key)) } @@ -1046,7 +1173,8 @@ func TestTransactionPendingMinimumAllowance(t *testing.T) { config := testTxPoolConfig config.GlobalSlots = 0 - pool := NewTxPool(config, params.TestChainConfig, blockchain, func([]types.RPCTransactionError) {}) + pool := NewTxPool(config, params.TestChainConfig, blockchain, + func([]types.RPCTransactionError) {}, func([]staking.RPCTransactionError) {}) defer pool.Stop() // Create a number of test accounts and fund them @@ -1058,7 +1186,7 @@ func TestTransactionPendingMinimumAllowance(t *testing.T) { // Generate and queue a batch of transactions nonces := make(map[common.Address]uint64) - txs := types.Transactions{} + txs := types.PoolTransactions{} for _, key := range keys { addr := crypto.PubkeyToAddress(key.PublicKey) for j := 0; j < int(config.AccountSlots)*2; j++ { @@ -1088,7 +1216,8 @@ func TestTransactionPoolRepricingKeepsLocals(t *testing.T) { statedb, _ := state.New(common.Hash{}, state.NewDatabase(ethdb.NewMemDatabase())) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} - pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain, func([]types.RPCTransactionError) {}) + pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain, + func([]types.RPCTransactionError) {}, func([]staking.RPCTransactionError) {}) defer pool.Stop() // Create a number of test accounts and fund them @@ -1167,7 +1296,8 @@ func testTransactionJournaling(t *testing.T, nolocals bool) { config.Journal = journal config.Rejournal = time.Second - pool := NewTxPool(config, params.TestChainConfig, blockchain, func([]types.RPCTransactionError) {}) + pool := NewTxPool(config, params.TestChainConfig, blockchain, + func([]types.RPCTransactionError) {}, func([]staking.RPCTransactionError) {}) // Create two test accounts to ensure remotes expire but locals do not local, _ := crypto.GenerateKey() @@ -1204,7 +1334,8 @@ func testTransactionJournaling(t *testing.T, nolocals bool) { statedb.SetNonce(crypto.PubkeyToAddress(local.PublicKey), 1) blockchain = &testBlockChain{statedb, 1000000, new(event.Feed)} - pool = NewTxPool(config, params.TestChainConfig, blockchain, func([]types.RPCTransactionError) {}) + pool = NewTxPool(config, params.TestChainConfig, blockchain, + func([]types.RPCTransactionError) {}, func([]staking.RPCTransactionError) {}) pending, queued = pool.Stats() if queued != 0 { @@ -1230,7 +1361,8 @@ func testTransactionJournaling(t *testing.T, nolocals bool) { statedb.SetNonce(crypto.PubkeyToAddress(local.PublicKey), 1) blockchain = &testBlockChain{statedb, 1000000, new(event.Feed)} - pool = NewTxPool(config, params.TestChainConfig, blockchain, func([]types.RPCTransactionError) {}) + pool = NewTxPool(config, params.TestChainConfig, blockchain, + func([]types.RPCTransactionError) {}, func([]staking.RPCTransactionError) {}) pending, queued = pool.Stats() if pending != 0 { @@ -1260,7 +1392,8 @@ func TestTransactionStatusCheck(t *testing.T) { statedb, _ := state.New(common.Hash{}, state.NewDatabase(ethdb.NewMemDatabase())) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} - pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain, func([]types.RPCTransactionError) {}) + pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain, + func([]types.RPCTransactionError) {}, func([]staking.RPCTransactionError) {}) defer pool.Stop() // Create the test accounts to check various transaction statuses with @@ -1270,7 +1403,7 @@ func TestTransactionStatusCheck(t *testing.T) { pool.currentState.AddBalance(crypto.PubkeyToAddress(keys[i].PublicKey), big.NewInt(1000000)) } // Generate and queue a batch of transactions, both pending and queued - txs := types.Transactions{} + txs := types.PoolTransactions{} txs = append(txs, pricedTransaction(0, 100000, big.NewInt(1), keys[0])) // Pending only txs = append(txs, pricedTransaction(0, 100000, big.NewInt(1), keys[1])) // Pending and queued @@ -1366,7 +1499,7 @@ func BenchmarkPoolInsert(b *testing.B) { account, _ := deriveSender(transaction(0, 0, key)) pool.currentState.AddBalance(account, big.NewInt(1000000)) - txs := make(types.Transactions, b.N) + txs := make(types.PoolTransactions, b.N) for i := 0; i < b.N; i++ { txs[i] = transaction(uint64(i), 100000, key) } @@ -1390,9 +1523,9 @@ func benchmarkPoolBatchInsert(b *testing.B, size int) { account, _ := deriveSender(transaction(0, 0, key)) pool.currentState.AddBalance(account, big.NewInt(1000000)) - batches := make([]types.Transactions, b.N) + batches := make([]types.PoolTransactions, b.N) for i := 0; i < b.N; i++ { - batches[i] = make(types.Transactions, size) + batches[i] = make(types.PoolTransactions, size) for j := 0; j < size; j++ { batches[i][j] = transaction(uint64(size*i+j), 100000, key) } diff --git a/core/types/transaction.go b/core/types/transaction.go index e2488ea26..42f48f5c7 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -79,8 +79,8 @@ type RPCTransactionError struct { } // NewRPCTransactionError ... -func NewRPCTransactionError(hash common.Hash, err error) *RPCTransactionError { - return &RPCTransactionError{ +func NewRPCTransactionError(hash common.Hash, err error) RPCTransactionError { + return RPCTransactionError{ TxHashID: hash.Hex(), TimestampOfRejection: time.Now().Unix(), ErrMessage: err.Error(), diff --git a/core/types/tx_pool.go b/core/types/tx_pool.go new file mode 100644 index 000000000..7c42d47e3 --- /dev/null +++ b/core/types/tx_pool.go @@ -0,0 +1,96 @@ +package types + +import ( + "io" + "math/big" + + "github.com/pkg/errors" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" + staking "github.com/harmony-one/harmony/staking/types" +) + +var ( + // ErrUnknownPoolTxType is returned when attempting to assert a PoolTransaction to its concrete type + ErrUnknownPoolTxType = errors.New("unknown transaction type in tx-pool") +) + +// PoolTransaction is the general transaction interface used by the tx pool +type PoolTransaction interface { + Hash() common.Hash + Nonce() uint64 + ChainID() *big.Int + To() *common.Address + Size() common.StorageSize + Data() []byte + GasPrice() *big.Int + Gas() uint64 + Cost() *big.Int + Value() *big.Int + EncodeRLP(w io.Writer) error + DecodeRLP(s *rlp.Stream) error + Protected() bool +} + +// PoolTransactionSender returns the address derived from the signature (V, R, S) u +// sing secp256k1 elliptic curve and an error if it failed deriving or upon an +// incorrect signature. +// +// Sender may cache the address, allowing it to be used regardless of +// signing method. The cache is invalidated if the cached signer does +// not match the signer used in the current call. +// +// Note that the signer is an interface since different txs have different signers. +func PoolTransactionSender(signer interface{}, tx PoolTransaction) (common.Address, error) { + if plainTx, ok := tx.(*Transaction); ok { + if sig, ok := signer.(Signer); ok { + return Sender(sig, plainTx) + } + } else if stakingTx, ok := tx.(*staking.StakingTransaction); ok { + return stakingTx.SenderAddress() + } + return common.Address{}, errors.WithMessage(ErrUnknownPoolTxType, "when fetching transaction sender") +} + +// PoolTransactions is a PoolTransactions slice type for basic sorting. +type PoolTransactions []PoolTransaction + +// Len returns the length of s. +func (s PoolTransactions) Len() int { return len(s) } + +// Swap swaps the i'th and the j'th element in s. +func (s PoolTransactions) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// GetRlp implements Rlpable and returns the i'th element of s in rlp. +func (s PoolTransactions) GetRlp(i int) []byte { + enc, _ := rlp.EncodeToBytes(s[i]) + return enc +} + +// PoolTxDifference returns a new set which is the difference between a and b. +func PoolTxDifference(a, b PoolTransactions) PoolTransactions { + keep := make(PoolTransactions, 0, len(a)) + + remove := make(map[common.Hash]struct{}) + for _, tx := range b { + remove[tx.Hash()] = struct{}{} + } + + for _, tx := range a { + if _, ok := remove[tx.Hash()]; !ok { + keep = append(keep, tx) + } + } + + return keep +} + +// PoolTxByNonce implements the sort interface to allow sorting a list of transactions +// by their nonces. This is usually only useful for sorting transactions from a +// single account, otherwise a nonce comparison doesn't make much sense. +type PoolTxByNonce PoolTransactions + +func (s PoolTxByNonce) Len() int { return len(s) } +func (s PoolTxByNonce) Less(i, j int) bool { return (s[i]).Nonce() < (s[j]).Nonce() } +func (s PoolTxByNonce) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/hmy/api_backend.go b/hmy/api_backend.go index be9ca6d99..6b1a84d9d 100644 --- a/hmy/api_backend.go +++ b/hmy/api_backend.go @@ -48,7 +48,7 @@ func (b *APIBackend) GetBlock(ctx context.Context, hash common.Hash) (*types.Blo } // GetPoolTransaction ... -func (b *APIBackend) GetPoolTransaction(hash common.Hash) *types.Transaction { +func (b *APIBackend) GetPoolTransaction(hash common.Hash) types.PoolTransaction { return b.hmy.txPool.Get(hash) } @@ -208,12 +208,12 @@ func (b *APIBackend) SubscribeLogsEvent(ch chan<- []*types.Log) event.Subscripti // GetPoolTransactions returns pool transactions. // TODO: this is not implemented or verified yet for harmony. -func (b *APIBackend) GetPoolTransactions() (types.Transactions, error) { +func (b *APIBackend) GetPoolTransactions() (types.PoolTransactions, error) { pending, err := b.hmy.txPool.Pending() if err != nil { return nil, err } - var txs types.Transactions + var txs types.PoolTransactions for _, batch := range pending { txs = append(txs, batch...) } diff --git a/internal/hmyapi/apiv1/backend.go b/internal/hmyapi/apiv1/backend.go index b9662ee0a..13021bf74 100644 --- a/internal/hmyapi/apiv1/backend.go +++ b/internal/hmyapi/apiv1/backend.go @@ -52,8 +52,8 @@ type Backend interface { // TxPool API SendTx(ctx context.Context, signedTx *types.Transaction) error // GetTransaction(ctx context.Context, txHash common.Hash) (*types.Transaction, common.Hash, uint64, uint64, error) - GetPoolTransactions() (types.Transactions, error) - GetPoolTransaction(txHash common.Hash) *types.Transaction + GetPoolTransactions() (types.PoolTransactions, error) + GetPoolTransaction(txHash common.Hash) types.PoolTransaction GetPoolNonce(ctx context.Context, addr common.Address) (uint64, error) // Stats() (pending int, queued int) // TxPoolContent() (map[common.Address]types.Transactions, map[common.Address]types.Transactions) diff --git a/internal/hmyapi/apiv1/transactionpool.go b/internal/hmyapi/apiv1/transactionpool.go index 4c18487c2..09d5f044a 100644 --- a/internal/hmyapi/apiv1/transactionpool.go +++ b/internal/hmyapi/apiv1/transactionpool.go @@ -106,7 +106,7 @@ func (s *PublicTransactionPoolAPI) GetTransactionByBlockHashAndIndex(ctx context return nil } -// GetTransactionByHash returns the transaction for the given hash +// GetTransactionByHash returns the plain transaction for the given hash func (s *PublicTransactionPoolAPI) GetTransactionByHash(ctx context.Context, hash common.Hash) *RPCTransaction { // Try to return an already finalized transaction tx, blockHash, blockNumber, index := rawdb.ReadTransaction(s.b.ChainDb(), hash) @@ -117,15 +117,11 @@ func (s *PublicTransactionPoolAPI) GetTransactionByHash(ctx context.Context, has if tx != nil { return newRPCTransaction(tx, blockHash, blockNumber, block.Time().Uint64(), index) } - // No finalized transaction, try to retrieve it from the pool - if tx = s.b.GetPoolTransaction(hash); tx != nil { - return newRPCPendingTransaction(tx) - } // Transaction unknown, return as such return nil } -// GetStakingTransactionByHash returns the transaction for the given hash +// GetStakingTransactionByHash returns the staking transaction for the given hash func (s *PublicTransactionPoolAPI) GetStakingTransactionByHash(ctx context.Context, hash common.Hash) *RPCStakingTransaction { // Try to return an already finalized transaction stx, blockHash, blockNumber, index := rawdb.ReadStakingTransaction(s.b.ChainDb(), hash) @@ -136,6 +132,7 @@ func (s *PublicTransactionPoolAPI) GetStakingTransactionByHash(ctx context.Conte if stx != nil { return newRPCStakingTransaction(stx, blockHash, blockNumber, block.Time().Uint64(), index) } + // Transaction unknown, return as such return nil } @@ -332,7 +329,7 @@ func (s *PublicTransactionPoolAPI) GetTransactionReceipt(ctx context.Context, ha return fields, nil } -// PendingTransactions returns the transactions that are in the transaction pool +// PendingTransactions returns the plain transactions that are in the transaction pool func (s *PublicTransactionPoolAPI) PendingTransactions() ([]*RPCTransaction, error) { pending, err := s.b.GetPoolTransactions() if err != nil { @@ -340,7 +337,32 @@ func (s *PublicTransactionPoolAPI) PendingTransactions() ([]*RPCTransaction, err } transactions := make([]*RPCTransaction, len(pending)) for i := range pending { - transactions[i] = newRPCPendingTransaction(pending[i]) + if plainTx, ok := pending[i].(*types.Transaction); ok { + transactions[i] = newRPCPendingTransaction(plainTx) + } else if _, ok := pending[i].(*staking.StakingTransaction); ok { + continue // Do not return staking transactions here. + } else { + return nil, types.ErrUnknownPoolTxType + } + } + return transactions, nil +} + +// PendingStakingTransactions returns the staking transactions that are in the transaction pool +func (s *PublicTransactionPoolAPI) PendingStakingTransactions() ([]*RPCStakingTransaction, error) { + pending, err := s.b.GetPoolTransactions() + if err != nil { + return nil, err + } + transactions := make([]*RPCStakingTransaction, len(pending)) + for i := range pending { + if _, ok := pending[i].(*types.Transaction); ok { + continue // Do not return plain transactions here + } else if stakingTx, ok := pending[i].(*staking.StakingTransaction); ok { + transactions[i] = newRPCPendingStakingTransaction(stakingTx) + } else { + return nil, types.ErrUnknownPoolTxType + } } return transactions, nil } diff --git a/internal/hmyapi/apiv1/types.go b/internal/hmyapi/apiv1/types.go index b425efcbb..81e23a293 100644 --- a/internal/hmyapi/apiv1/types.go +++ b/internal/hmyapi/apiv1/types.go @@ -296,7 +296,7 @@ func newRPCStakingTransaction(tx *types2.StakingTransaction, blockHash common.Ha result := &RPCStakingTransaction{ Gas: hexutil.Uint64(tx.Gas()), - GasPrice: (*hexutil.Big)(tx.Price()), + GasPrice: (*hexutil.Big)(tx.GasPrice()), Hash: tx.Hash(), Nonce: hexutil.Uint64(tx.Nonce()), Timestamp: hexutil.Uint64(timestamp), @@ -326,6 +326,11 @@ func newRPCPendingTransaction(tx *types.Transaction) *RPCTransaction { return newRPCTransaction(tx, common.Hash{}, 0, 0, 0) } +// newRPCPendingStakingTransaction returns a pending transaction that will serialize to the RPC representation +func newRPCPendingStakingTransaction(tx *types2.StakingTransaction) *RPCStakingTransaction { + return newRPCStakingTransaction(tx, common.Hash{}, 0, 0, 0) +} + // RPCBlock represents a block that will serialize to the RPC representation of a block type RPCBlock struct { Number *hexutil.Big `json:"number"` diff --git a/internal/hmyapi/apiv2/backend.go b/internal/hmyapi/apiv2/backend.go index 1765132d9..25e27d974 100644 --- a/internal/hmyapi/apiv2/backend.go +++ b/internal/hmyapi/apiv2/backend.go @@ -52,8 +52,8 @@ type Backend interface { // TxPool API SendTx(ctx context.Context, signedTx *types.Transaction) error // GetTransaction(ctx context.Context, txHash common.Hash) (*types.Transaction, common.Hash, uint64, uint64, error) - GetPoolTransactions() (types.Transactions, error) - GetPoolTransaction(txHash common.Hash) *types.Transaction + GetPoolTransactions() (types.PoolTransactions, error) + GetPoolTransaction(txHash common.Hash) types.PoolTransaction GetPoolNonce(ctx context.Context, addr common.Address) (uint64, error) // Stats() (pending int, queued int) // TxPoolContent() (map[common.Address]types.Transactions, map[common.Address]types.Transactions) diff --git a/internal/hmyapi/apiv2/transactionpool.go b/internal/hmyapi/apiv2/transactionpool.go index a1eb85d85..60c2f7e4b 100644 --- a/internal/hmyapi/apiv2/transactionpool.go +++ b/internal/hmyapi/apiv2/transactionpool.go @@ -115,10 +115,6 @@ func (s *PublicTransactionPoolAPI) GetTransactionByHash(ctx context.Context, has if tx != nil { return newRPCTransaction(tx, blockHash, blockNumber, block.Time().Uint64(), index) } - // No finalized transaction, try to retrieve it from the pool - if tx = s.b.GetPoolTransaction(hash); tx != nil { - return newRPCPendingTransaction(tx) - } // Transaction unknown, return as such return nil } @@ -134,6 +130,7 @@ func (s *PublicTransactionPoolAPI) GetStakingTransactionByHash(ctx context.Conte if stx != nil { return newRPCStakingTransaction(stx, blockHash, blockNumber, block.Time().Uint64(), index) } + // Transaction unknown, return as such return nil } @@ -330,17 +327,17 @@ func (s *PublicTransactionPoolAPI) GetTransactionReceipt(ctx context.Context, ha return fields, nil } -// PendingTransactions returns the transactions that are in the transaction pool +// PendingTransactions returns the plain transactions that are in the transaction pool // and have a from address that is one of the accounts this node manages. func (s *PublicTransactionPoolAPI) PendingTransactions() ([]*RPCTransaction, error) { pending, err := s.b.GetPoolTransactions() if err != nil { return nil, err } - accounts := make(map[common.Address]struct{}) + managedAccounts := make(map[common.Address]struct{}) for _, wallet := range s.b.AccountManager().Wallets() { for _, account := range wallet.Accounts() { - accounts[account.Address] = struct{}{} + managedAccounts[account.Address] = struct{}{} } } transactions := make([]*RPCTransaction, 0, len(pending)) @@ -349,9 +346,48 @@ func (s *PublicTransactionPoolAPI) PendingTransactions() ([]*RPCTransaction, err if tx.Protected() { signer = types.NewEIP155Signer(tx.ChainID()) } - from, _ := types.Sender(signer, tx) - if _, exists := accounts[from]; exists { - transactions = append(transactions, newRPCPendingTransaction(tx)) + from, _ := types.PoolTransactionSender(signer, tx) + if _, exists := managedAccounts[from]; exists { + if plainTx, ok := tx.(*types.Transaction); ok { + transactions = append(transactions, newRPCPendingTransaction(plainTx)) + } else if _, ok := tx.(*staking.StakingTransaction); ok { + continue // Do not return staking transactions here + } else { + return nil, types.ErrUnknownPoolTxType + } + } + } + return transactions, nil +} + +// PendingStakingTransactions returns the staking transactions that are in the transaction pool +// and have a from address that is one of the accounts this node manages. +func (s *PublicTransactionPoolAPI) PendingStakingTransactions() ([]*RPCStakingTransaction, error) { + pending, err := s.b.GetPoolTransactions() + if err != nil { + return nil, err + } + managedAccounts := make(map[common.Address]struct{}) + for _, wallet := range s.b.AccountManager().Wallets() { + for _, account := range wallet.Accounts() { + managedAccounts[account.Address] = struct{}{} + } + } + transactions := make([]*RPCStakingTransaction, 0, len(pending)) + for _, tx := range pending { + var signer types.Signer = types.HomesteadSigner{} + if tx.Protected() { + signer = types.NewEIP155Signer(tx.ChainID()) + } + from, _ := types.PoolTransactionSender(signer, tx) + if _, exists := managedAccounts[from]; exists { + if _, ok := tx.(*types.Transaction); ok { + continue // Do not return plain transactions here + } else if stakingTx, ok := tx.(*staking.StakingTransaction); ok { + transactions = append(transactions, newRPCPendingStakingTransaction(stakingTx)) + } else { + return nil, types.ErrUnknownPoolTxType + } } } return transactions, nil diff --git a/internal/hmyapi/apiv2/types.go b/internal/hmyapi/apiv2/types.go index 24064a40f..79d909873 100644 --- a/internal/hmyapi/apiv2/types.go +++ b/internal/hmyapi/apiv2/types.go @@ -297,7 +297,7 @@ func newRPCStakingTransaction(tx *types2.StakingTransaction, blockHash common.Ha result := &RPCStakingTransaction{ Gas: tx.Gas(), - GasPrice: tx.Price(), + GasPrice: tx.GasPrice(), Hash: tx.Hash(), Nonce: tx.Nonce(), Timestamp: timestamp, @@ -327,6 +327,11 @@ func newRPCPendingTransaction(tx *types.Transaction) *RPCTransaction { return newRPCTransaction(tx, common.Hash{}, 0, 0, 0) } +// newRPCPendingStakingTransaction returns a pending transaction that will serialize to the RPC representation +func newRPCPendingStakingTransaction(tx *types2.StakingTransaction) *RPCStakingTransaction { + return newRPCStakingTransaction(tx, common.Hash{}, 0, 0, 0) +} + // RPCBlock represents a block that will serialize to the RPC representation of a block type RPCBlock struct { Number *big.Int `json:"number"` diff --git a/internal/hmyapi/backend.go b/internal/hmyapi/backend.go index c59646559..672870e20 100644 --- a/internal/hmyapi/backend.go +++ b/internal/hmyapi/backend.go @@ -54,8 +54,8 @@ type Backend interface { // TxPool API SendTx(ctx context.Context, signedTx *types.Transaction) error // GetTransaction(ctx context.Context, txHash common.Hash) (*types.Transaction, common.Hash, uint64, uint64, error) - GetPoolTransactions() (types.Transactions, error) - GetPoolTransaction(txHash common.Hash) *types.Transaction + GetPoolTransactions() (types.PoolTransactions, error) + GetPoolTransaction(txHash common.Hash) types.PoolTransaction GetPoolNonce(ctx context.Context, addr common.Address) (uint64, error) // Stats() (pending int, queued int) // TxPoolContent() (map[common.Address]types.Transactions, map[common.Address]types.Transactions) diff --git a/node/node.go b/node/node.go index 694e89cd5..8139c236e 100644 --- a/node/node.go +++ b/node/node.go @@ -134,9 +134,6 @@ type Node struct { CxPool *core.CxPool // pool for missing cross shard receipts resend - pendingStakingTransactions map[common.Hash]*staking.StakingTransaction // All the staking transactions received but not yet processed for Consensus - pendingStakingTxMutex sync.Mutex - Worker *worker.Worker BeaconWorker *worker.Worker // worker for beacon chain @@ -282,50 +279,45 @@ func (node *Node) tryBroadcastStaking(stakingTx *staking.StakingTransaction) { // Add new transactions to the pending transaction list. func (node *Node) addPendingTransactions(newTxs types.Transactions) { - node.TxPool.AddRemotes(newTxs) + poolTxs := types.PoolTransactions{} + for _, tx := range newTxs { + poolTxs = append(poolTxs, tx) + } + node.TxPool.AddRemotes(poolTxs) pendingCount, queueCount := node.TxPool.Stats() - utils.Logger().Info().Int("length of newTxs", len(newTxs)).Int("totalPending", pendingCount).Int("totalQueued", queueCount).Msg("Got more transactions") + utils.Logger().Info(). + Int("length of newTxs", len(newTxs)). + Int("totalPending", pendingCount). + Int("totalQueued", queueCount). + Msg("Got more transactions") } // Add new staking transactions to the pending staking transaction list. func (node *Node) addPendingStakingTransactions(newStakingTxs staking.StakingTransactions) { - // TODO: incorporate staking txn into TxPool if node.NodeConfig.ShardID == shard.BeaconChainShardID && node.Blockchain().Config().IsPreStaking(node.Blockchain().CurrentHeader().Epoch()) { - node.pendingStakingTxMutex.Lock() + poolTxs := types.PoolTransactions{} for _, tx := range newStakingTxs { - const txPoolLimit = 1000 - if s := len(node.pendingStakingTransactions); s >= txPoolLimit { - utils.Logger().Info(). - Int("tx-pool-size", s). - Int("tx-pool-limit", txPoolLimit). - Msg("Current staking txn pool reached limit") - break - } - if _, ok := node.pendingStakingTransactions[tx.Hash()]; !ok { - node.pendingStakingTransactions[tx.Hash()] = tx - } + poolTxs = append(poolTxs, tx) } + node.TxPool.AddRemotes(poolTxs) + pendingCount, queueCount := node.TxPool.Stats() utils.Logger().Info(). - Int("length of newStakingTxs", len(newStakingTxs)). - Int("totalPending", len(node.pendingStakingTransactions)). + Int("length of newStakingTxs", len(poolTxs)). + Int("totalPending", pendingCount). + Int("totalQueued", queueCount). Msg("Got more staking transactions") - node.pendingStakingTxMutex.Unlock() } } // AddPendingStakingTransaction staking transactions -func (node *Node) AddPendingStakingTransaction( - newStakingTx *staking.StakingTransaction) { - // TODO: everyone should record staking txns, not just leader - if node.Consensus.IsLeader() && - node.NodeConfig.ShardID == shard.BeaconChainShardID { +func (node *Node) AddPendingStakingTransaction(newStakingTx *staking.StakingTransaction) { + if node.NodeConfig.ShardID == shard.BeaconChainShardID { node.addPendingStakingTransactions(staking.StakingTransactions{newStakingTx}) - } else { - utils.Logger().Info().Str("Hash", newStakingTx.Hash().Hex()).Msg("Broadcasting Staking Tx") - node.tryBroadcastStaking(newStakingTx) } + utils.Logger().Info().Str("Hash", newStakingTx.Hash().Hex()).Msg("Broadcasting Staking Tx") + node.tryBroadcastStaking(newStakingTx) } // AddPendingTransaction adds one new transaction to the pending transaction list. @@ -333,9 +325,6 @@ func (node *Node) AddPendingStakingTransaction( func (node *Node) AddPendingTransaction(newTx *types.Transaction) { if newTx.ShardID() == node.NodeConfig.ShardID { node.addPendingTransactions(types.Transactions{newTx}) - if node.NodeConfig.Role() != nodeconfig.ExplorerNode { - return - } } utils.Logger().Info().Str("Hash", newTx.Hash().Hex()).Msg("Broadcasting Tx") node.tryBroadcast(newTx) @@ -515,7 +504,16 @@ func New(host p2p.Host, consensusObj *consensus.Consensus, } node.errorSink.Unlock() } - }) + }, + func(payload []staking.RPCTransactionError) { + node.errorSink.Lock() + for i := range payload { + node.errorSink.failedStakingTxns.Value = payload[i] + node.errorSink.failedStakingTxns = node.errorSink.failedStakingTxns.Next() + } + node.errorSink.Unlock() + }, + ) node.CxPool = core.NewCxPool(core.CxPoolSize) node.Worker = worker.New(node.Blockchain().Config(), blockchain, chain.Engine) @@ -526,7 +524,6 @@ func New(host p2p.Host, consensusObj *consensus.Consensus, } node.pendingCXReceipts = make(map[string]*types.CXReceiptsProof) - node.pendingStakingTransactions = make(map[common.Hash]*staking.StakingTransaction) node.Consensus.VerifiedNewBlock = make(chan *types.Block) chain.Engine.SetRewarder(node.Consensus.Decider.(reward.Distributor)) chain.Engine.SetBeaconchain(beaconChain) diff --git a/node/node_handler_test.go b/node/node_handler_test.go index 2f8cfc7f2..21ec9fb1d 100644 --- a/node/node_handler_test.go +++ b/node/node_handler_test.go @@ -39,7 +39,6 @@ func TestAddNewBlock(t *testing.T) { stks := staking.StakingTransactions{} node.Worker.CommitTransactions( txs, stks, common.Address{}, - func([]staking.RPCTransactionError) {}, ) block, _ := node.Worker.FinalizeNewBlock( []byte{}, []byte{}, 0, common.Address{}, nil, nil, nil, @@ -77,7 +76,6 @@ func TestVerifyNewBlock(t *testing.T) { stks := staking.StakingTransactions{} node.Worker.CommitTransactions( txs, stks, common.Address{}, - func([]staking.RPCTransactionError) {}, ) block, _ := node.Worker.FinalizeNewBlock( []byte{}, []byte{}, 0, common.Address{}, nil, nil, nil, diff --git a/node/node_newblock.go b/node/node_newblock.go index f357dd570..94f88ac12 100644 --- a/node/node_newblock.go +++ b/node/node_newblock.go @@ -103,34 +103,35 @@ func (node *Node) proposeNewBlock() (*types.Block, error) { } // Prepare transactions including staking transactions - pending, err := node.TxPool.Pending() + pendingPoolTxs, err := node.TxPool.Pending() if err != nil { utils.Logger().Err(err).Msg("Failed to fetch pending transactions") return nil, err } - - // TODO: integrate staking transaction into tx pool - pendingStakingTransactions := staking.StakingTransactions{} - // Only process staking transactions after pre-staking epoch happened. - if node.Blockchain().Config().IsPreStaking(node.Worker.GetCurrentHeader().Epoch()) { - node.pendingStakingTxMutex.Lock() - for _, tx := range node.pendingStakingTransactions { - pendingStakingTransactions = append(pendingStakingTransactions, tx) + pendingPlainTxs := map[common.Address]types.Transactions{} + pendingStakingTxs := staking.StakingTransactions{} + for addr, poolTxs := range pendingPoolTxs { + plainTxsPerAcc := types.Transactions{} + for _, tx := range poolTxs { + if plainTx, ok := tx.(*types.Transaction); ok { + plainTxsPerAcc = append(plainTxsPerAcc, plainTx) + } else if stakingTx, ok := tx.(*staking.StakingTransaction); ok { + // Only process staking transactions after pre-staking epoch happened. + if node.Blockchain().Config().IsPreStaking(node.Worker.GetCurrentHeader().Epoch()) { + pendingStakingTxs = append(pendingStakingTxs, stakingTx) + } + } else { + utils.Logger().Err(types.ErrUnknownPoolTxType).Msg("Failed to parse pending transactions") + return nil, types.ErrUnknownPoolTxType + } + } + if plainTxsPerAcc.Len() > 0 { + pendingPlainTxs[addr] = plainTxsPerAcc } - node.pendingStakingTransactions = make(map[common.Hash]*staking.StakingTransaction) - node.pendingStakingTxMutex.Unlock() } if err := node.Worker.CommitTransactions( - pending, pendingStakingTransactions, beneficiary, - func(payload []staking.RPCTransactionError) { - node.errorSink.Lock() - for i := range payload { - node.errorSink.failedStakingTxns.Value = payload[i] - node.errorSink.failedStakingTxns = node.errorSink.failedStakingTxns.Next() - } - node.errorSink.Unlock() - }, + pendingPlainTxs, pendingStakingTxs, beneficiary, ); err != nil { utils.Logger().Error().Err(err).Msg("cannot commit transactions") return nil, err diff --git a/node/worker/worker.go b/node/worker/worker.go index e19e9aa2d..4dd6a9751 100644 --- a/node/worker/worker.go +++ b/node/worker/worker.go @@ -55,7 +55,6 @@ type Worker struct { func (w *Worker) CommitTransactions( pendingNormal map[common.Address]types.Transactions, pendingStaking staking.StakingTransactions, coinbase common.Address, - stkingTxErrorSink func([]staking.RPCTransactionError), ) error { if w.current.gasPool == nil { @@ -64,7 +63,6 @@ func (w *Worker) CommitTransactions( txs := types.NewTransactionsByPriceAndNonce(w.current.signer, pendingNormal) coalescedLogs := []*types.Log{} - erroredStakingTxns := []staking.RPCTransactionError{} // NORMAL for { // If we don't have enough gas for any further transactions then we're done @@ -134,12 +132,6 @@ func (w *Worker) CommitTransactions( logs, err := w.commitStakingTransaction(tx, coinbase) if err != nil { txID := tx.Hash().Hex() - erroredStakingTxns = append(erroredStakingTxns, staking.RPCTransactionError{ - TxHashID: txID, - StakingDirective: tx.StakingType().String(), - TimestampOfRejection: time.Now().Unix(), - ErrMessage: err.Error(), - }) utils.Logger().Error().Err(err). Str("stakingTxId", txID). Msg("Commit staking transaction error") @@ -151,8 +143,6 @@ func (w *Worker) CommitTransactions( } } } - // Here call the error functions - stkingTxErrorSink(erroredStakingTxns) utils.Logger().Info(). Int("newTxns", len(w.current.txs)). diff --git a/node/worker/worker_test.go b/node/worker/worker_test.go index 41016c727..6d2109dd4 100644 --- a/node/worker/worker_test.go +++ b/node/worker/worker_test.go @@ -15,7 +15,6 @@ import ( "github.com/harmony-one/harmony/core/vm" chain2 "github.com/harmony-one/harmony/internal/chain" "github.com/harmony-one/harmony/internal/params" - staking "github.com/harmony-one/harmony/staking/types" ) var ( @@ -80,7 +79,6 @@ func TestCommitTransactions(t *testing.T) { txs[testBankAddress] = types.Transactions{tx} err := worker.CommitTransactions( txs, nil, testBankAddress, - func([]staking.RPCTransactionError) {}, ) if err != nil { t.Error(err) diff --git a/staking/types/rpc-result.go b/staking/types/rpc-result.go deleted file mode 100644 index 41c5a53f8..000000000 --- a/staking/types/rpc-result.go +++ /dev/null @@ -1,9 +0,0 @@ -package types - -// RPCTransactionError .. -type RPCTransactionError struct { - TxHashID string `json:"tx-hash-id"` - StakingDirective string `json:"directive-kind"` - TimestampOfRejection int64 `json:"time-at-rejection"` - ErrMessage string `json:"error-message"` -} diff --git a/staking/types/transaction.go b/staking/types/transaction.go index cfa69fc07..00a0b640b 100644 --- a/staking/types/transaction.go +++ b/staking/types/transaction.go @@ -5,6 +5,7 @@ import ( "io" "math/big" "sync/atomic" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/rlp" @@ -72,6 +73,24 @@ type StakingTransaction struct { from atomic.Value } +// RPCTransactionError .. +type RPCTransactionError struct { + TxHashID string `json:"tx-hash-id"` + StakingDirective string `json:"directive-kind"` + TimestampOfRejection int64 `json:"time-at-rejection"` + ErrMessage string `json:"error-message"` +} + +// NewRPCTransactionError ... +func NewRPCTransactionError(hash common.Hash, directive Directive, err error) RPCTransactionError { + return RPCTransactionError{ + TxHashID: hash.Hex(), + StakingDirective: directive.String(), + TimestampOfRejection: time.Now().Unix(), + ErrMessage: err.Error(), + } +} + // StakeMsgFulfiller is signature of callback intended to produce the StakeMsg type StakeMsgFulfiller func() (Directive, interface{}) @@ -146,11 +165,16 @@ func (tx *StakingTransaction) Gas() uint64 { return tx.data.GasLimit } -// Price returns price of StakingTransaction. -func (tx *StakingTransaction) Price() *big.Int { +// GasPrice returns price of StakingTransaction. +func (tx *StakingTransaction) GasPrice() *big.Int { return tx.data.Price } +// Cost .. +func (tx *StakingTransaction) Cost() *big.Int { + return new(big.Int).Mul(tx.data.Price, new(big.Int).SetUint64(tx.data.GasLimit)) +} + // ChainID is what chain this staking transaction for func (tx *StakingTransaction) ChainID() *big.Int { return deriveChainID(tx.data.V) @@ -184,6 +208,48 @@ func (tx *StakingTransaction) RLPEncodeStakeMsg() (by []byte, err error) { return rlp.EncodeToBytes(tx.data.StakeMsg) } +// Protected .. +func (tx *StakingTransaction) Protected() bool { + return true +} + +// To .. +func (tx *StakingTransaction) To() *common.Address { + return nil +} + +// Data .. +func (tx *StakingTransaction) Data() []byte { + data, err := tx.RLPEncodeStakeMsg() + if err != nil { + return nil + } + return data +} + +// Value .. +func (tx *StakingTransaction) Value() *big.Int { + return new(big.Int).SetInt64(0) +} + +// Size .. +func (tx *StakingTransaction) Size() common.StorageSize { + if size := tx.size.Load(); size != nil { + return size.(common.StorageSize) + } + c := writeCounter(0) + rlp.Encode(&c, &tx.data) + tx.size.Store(common.StorageSize(c)) + return common.StorageSize(c) +} + +type writeCounter common.StorageSize + +func (c *writeCounter) Write(b []byte) (int, error) { + *c += writeCounter(len(b)) + return len(b), nil +} + // RLPDecodeStakeMsg .. func RLPDecodeStakeMsg(payload []byte, d Directive) (interface{}, error) { var oops error diff --git a/staking/types/validator.go b/staking/types/validator.go index e7d21a5f3..ec90625f9 100644 --- a/staking/types/validator.go +++ b/staking/types/validator.go @@ -106,6 +106,83 @@ type Validator struct { Banned bool } +// SanityCheck checks basic requirements of a validator +func (v *Validator) SanityCheck() error { + if _, err := v.EnsureLength(); err != nil { + return err + } + + if len(v.SlotPubKeys) == 0 { + return errNeedAtLeastOneSlotKey + } + + if v.MinSelfDelegation == nil { + return errNilMinSelfDelegation + } + + if v.MaxTotalDelegation == nil { + return errNilMaxTotalDelegation + } + + // MinSelfDelegation must be >= 1 ONE + if v.MinSelfDelegation.Cmp(big.NewInt(denominations.One)) < 0 { + return errors.Wrapf( + errMinSelfDelegationTooSmall, + "delegation-given %s", v.MinSelfDelegation.String(), + ) + } + + // MaxTotalDelegation must not be less than MinSelfDelegation + if v.MaxTotalDelegation.Cmp(v.MinSelfDelegation) < 0 { + return errors.Wrapf( + errInvalidMaxTotalDelegation, + "max-total-delegation %s min-self-delegation %s", + v.MaxTotalDelegation.String(), + v.MinSelfDelegation.String(), + ) + } + + if v.Rate.LT(zeroPercent) || v.Rate.GT(hundredPercent) { + return errors.Wrapf( + errInvalidCommissionRate, "rate:%s", v.Rate.String(), + ) + } + + if v.MaxRate.LT(zeroPercent) || v.MaxRate.GT(hundredPercent) { + return errors.Wrapf( + errInvalidCommissionRate, "rate:%s", v.MaxRate.String(), + ) + } + + if v.MaxChangeRate.LT(zeroPercent) || v.MaxChangeRate.GT(hundredPercent) { + return errors.Wrapf( + errInvalidCommissionRate, "rate:%s", v.MaxChangeRate.String(), + ) + } + + if v.Rate.GT(v.MaxRate) { + return errors.Wrapf( + errCommissionRateTooLarge, "rate:%s", v.MaxRate.String(), + ) + } + + if v.MaxChangeRate.GT(v.MaxRate) { + return errors.Wrapf( + errCommissionRateTooLarge, "rate:%s", v.MaxChangeRate.String(), + ) + } + + allKeys := map[shard.BlsPublicKey]struct{}{} + for i := range v.SlotPubKeys { + if _, ok := allKeys[v.SlotPubKeys[i]]; !ok { + allKeys[v.SlotPubKeys[i]] = struct{}{} + } else { + return errDuplicateSlotKeys + } + } + return nil +} + // MarshalJSON .. func (v *ValidatorStats) MarshalJSON() ([]byte, error) { type t struct { @@ -176,26 +253,9 @@ var ( // SanityCheck checks the basic requirements func (w *ValidatorWrapper) SanityCheck() error { - if len(w.SlotPubKeys) == 0 { - return errNeedAtLeastOneSlotKey - } - - if w.Validator.MinSelfDelegation == nil { - return errNilMinSelfDelegation - } - - if w.Validator.MaxTotalDelegation == nil { - return errNilMaxTotalDelegation - } - - // MinSelfDelegation must be >= 1 ONE - if w.Validator.MinSelfDelegation.Cmp(big.NewInt(denominations.One)) < 0 { - return errors.Wrapf( - errMinSelfDelegationTooSmall, - "delegation-given %s", w.Validator.MinSelfDelegation.String(), - ) + if err := w.Validator.SanityCheck(); err != nil { + return err } - // Self delegation must be >= MinSelfDelegation switch len(w.Delegations) { case 0: @@ -210,17 +270,6 @@ func (w *ValidatorWrapper) SanityCheck() error { ) } } - - // MaxTotalDelegation must not be less than MinSelfDelegation - if w.Validator.MaxTotalDelegation.Cmp(w.Validator.MinSelfDelegation) < 0 { - return errors.Wrapf( - errInvalidMaxTotalDelegation, - "max-total-delegation %s min-self-delegation %s", - w.Validator.MaxTotalDelegation.String(), - w.Validator.MinSelfDelegation.String(), - ) - } - totalDelegation := w.TotalDelegation() // Total delegation must be <= MaxTotalDelegation if totalDelegation.Cmp(w.Validator.MaxTotalDelegation) > 0 { @@ -231,45 +280,6 @@ func (w *ValidatorWrapper) SanityCheck() error { w.Validator.MaxTotalDelegation.String(), ) } - - if w.Validator.Rate.LT(zeroPercent) || w.Validator.Rate.GT(hundredPercent) { - return errors.Wrapf( - errInvalidCommissionRate, "rate:%s", w.Validator.Rate.String(), - ) - } - - if w.Validator.MaxRate.LT(zeroPercent) || w.Validator.MaxRate.GT(hundredPercent) { - return errors.Wrapf( - errInvalidCommissionRate, "rate:%s", w.Validator.MaxRate.String(), - ) - } - - if w.Validator.MaxChangeRate.LT(zeroPercent) || w.Validator.MaxChangeRate.GT(hundredPercent) { - return errors.Wrapf( - errInvalidCommissionRate, "rate:%s", w.Validator.MaxChangeRate.String(), - ) - } - - if w.Validator.Rate.GT(w.Validator.MaxRate) { - return errors.Wrapf( - errCommissionRateTooLarge, "rate:%s", w.Validator.MaxRate.String(), - ) - } - - if w.Validator.MaxChangeRate.GT(w.Validator.MaxRate) { - return errors.Wrapf( - errCommissionRateTooLarge, "rate:%s", w.Validator.MaxChangeRate.String(), - ) - } - - allKeys := map[shard.BlsPublicKey]struct{}{} - for i := range w.Validator.SlotPubKeys { - if _, ok := allKeys[w.Validator.SlotPubKeys[i]]; !ok { - allKeys[w.Validator.SlotPubKeys[i]] = struct{}{} - } else { - return errDuplicateSlotKeys - } - } return nil } @@ -350,13 +360,15 @@ func (v *Validator) GetCommissionRate() numeric.Dec { return v.Commission.Rate } // GetMinSelfDelegation returns the minimum amount the validator must stake func (v *Validator) GetMinSelfDelegation() *big.Int { return v.MinSelfDelegation } -func verifyBLSKeys(pubKeys []shard.BlsPublicKey, pubKeySigs []shard.BLSSignature) error { +// VerifyBLSKeys checks if the public BLS key at index i of pubKeys matches the +// BLS key signature at index i of pubKeysSigs. +func VerifyBLSKeys(pubKeys []shard.BlsPublicKey, pubKeySigs []shard.BLSSignature) error { if len(pubKeys) != len(pubKeySigs) { return errBLSKeysNotMatchSigs } for i := 0; i < len(pubKeys); i++ { - if err := verifyBLSKey(&pubKeys[i], &pubKeySigs[i]); err != nil { + if err := VerifyBLSKey(&pubKeys[i], &pubKeySigs[i]); err != nil { return err } } @@ -364,7 +376,8 @@ func verifyBLSKeys(pubKeys []shard.BlsPublicKey, pubKeySigs []shard.BLSSignature return nil } -func verifyBLSKey(pubKey *shard.BlsPublicKey, pubKeySig *shard.BLSSignature) error { +// VerifyBLSKey checks if the public BLS key matches the BLS signature +func VerifyBLSKey(pubKey *shard.BlsPublicKey, pubKeySig *shard.BLSSignature) error { if len(pubKeySig) == 0 { return errBLSKeysNotMatchSigs } @@ -397,7 +410,7 @@ func CreateValidatorFromNewMsg(val *CreateValidator, blockNum *big.Int) (*Valida commission := Commission{val.CommissionRates, blockNum} pubKeys := append(val.SlotPubKeys[0:0], val.SlotPubKeys...) - if err = verifyBLSKeys(pubKeys, val.SlotKeySigs); err != nil { + if err = VerifyBLSKeys(pubKeys, val.SlotKeySigs); err != nil { return nil, err } @@ -458,7 +471,7 @@ func UpdateValidatorFromEditMsg(validator *Validator, edit *EditValidator) error } } if !found { - if err := verifyBLSKey(edit.SlotKeyToAdd, edit.SlotKeyToAddSig); err != nil { + if err := VerifyBLSKey(edit.SlotKeyToAdd, edit.SlotKeyToAddSig); err != nil { return err } validator.SlotPubKeys = append(validator.SlotPubKeys, *edit.SlotKeyToAdd) diff --git a/test/chain/main.go b/test/chain/main.go index a56b32b50..d01a75f2d 100644 --- a/test/chain/main.go +++ b/test/chain/main.go @@ -129,7 +129,6 @@ func fundFaucetContract(chain *core.BlockChain) { err := contractworker.CommitTransactions( txmap, nil, testUserAddress, - func([]staking.RPCTransactionError) {}, ) if err != nil { fmt.Println(err) @@ -172,7 +171,6 @@ func callFaucetContractToFundAnAddress(chain *core.BlockChain) { err = contractworker.CommitTransactions( txmap, nil, testUserAddress, - func([]staking.RPCTransactionError) {}, ) if err != nil { fmt.Println(err) @@ -206,14 +204,19 @@ func main() { genesis := gspec.MustCommit(database) chain, _ := core.NewBlockChain(database, nil, gspec.Config, chain.Engine(), vm.Config{}, nil) - txpool := core.NewTxPool(core.DefaultTxPoolConfig, chainConfig, chain, func([]types.RPCTransactionError) {}) + txpool := core.NewTxPool(core.DefaultTxPoolConfig, chainConfig, chain, + func([]types.RPCTransactionError) {}, func([]staking.RPCTransactionError) {}) backend := &testWorkerBackend{ db: database, chain: chain, txPool: txpool, } - backend.txPool.AddLocals(pendingTxs) + poolPendingTx := types.PoolTransactions{} + for _, tx := range pendingTxs { + poolPendingTx = append(poolPendingTx, tx) + } + backend.txPool.AddLocals(poolPendingTx) //// Generate a small n-block chain and an uncle block for it n := 3