diff --git a/hmy/staking.go b/hmy/staking.go index 4484c3fd3..21fe26c34 100644 --- a/hmy/staking.go +++ b/hmy/staking.go @@ -519,7 +519,29 @@ func (hmy *Harmony) GetDelegationsByDelegatorByBlock( } // UndelegationPayouts .. -type UndelegationPayouts map[common.Address]map[common.Address]*big.Int +type UndelegationPayouts struct { + Data map[common.Address]map[common.Address]*big.Int +} + +func NewUndelegationPayouts() *UndelegationPayouts { + return &UndelegationPayouts{ + Data: make(map[common.Address]map[common.Address]*big.Int), + } +} + +func (u *UndelegationPayouts) SetPayoutByDelegatorAddrAndValidatorAddr( + delegator, validator common.Address, amount *big.Int, +) { + if u.Data[delegator] == nil { + u.Data[delegator] = make(map[common.Address]*big.Int) + } + + if totalPayout, ok := u.Data[delegator][validator]; ok { + u.Data[delegator][validator] = new(big.Int).Add(totalPayout, amount) + } else { + u.Data[delegator][validator] = amount + } +} // GetUndelegationPayouts returns the undelegation payouts for each delegator // @@ -528,16 +550,17 @@ type UndelegationPayouts map[common.Address]map[common.Address]*big.Int // This not a problem if a full (archival) DB is used. func (hmy *Harmony) GetUndelegationPayouts( ctx context.Context, epoch *big.Int, -) (UndelegationPayouts, error) { +) (*UndelegationPayouts, error) { if !hmy.IsPreStakingEpoch(epoch) { return nil, fmt.Errorf("not pre-staking epoch or later") } payouts, ok := hmy.undelegationPayoutsCache.Get(epoch.Uint64()) if ok { - return payouts.(UndelegationPayouts), nil + result := payouts.(UndelegationPayouts) + return &result, nil } - undelegationPayouts := UndelegationPayouts{} + undelegationPayouts := NewUndelegationPayouts() // require second to last block as saved undelegations are AFTER undelegations are payed out blockNumber := shard.Schedule.EpochLastBlock(epoch.Uint64()) - 1 undelegationPayoutBlock, err := hmy.BlockByNumber(ctx, rpc.BlockNumber(blockNumber)) @@ -556,14 +579,7 @@ func (hmy *Harmony) GetUndelegationPayouts( for _, delegation := range wrapper.Delegations { withdraw := delegation.RemoveUnlockedUndelegations(epoch, wrapper.LastEpochInCommittee, lockingPeriod, noEarlyUnlock) if withdraw.Cmp(bigZero) == 1 { - if undelegationPayouts[delegation.DelegatorAddress] == nil { - undelegationPayouts[delegation.DelegatorAddress] = make(map[common.Address]*big.Int) - } - if totalPayout, ok := undelegationPayouts[delegation.DelegatorAddress][validator]; ok { - undelegationPayouts[delegation.DelegatorAddress][validator] = new(big.Int).Add(totalPayout, withdraw) - } else { - undelegationPayouts[delegation.DelegatorAddress][validator] = withdraw - } + undelegationPayouts.SetPayoutByDelegatorAddrAndValidatorAddr(validator, delegation.DelegatorAddress, withdraw) } } } diff --git a/rosetta/services/tx_operation.go b/rosetta/services/tx_operation.go index db72c26d4..11524fcf4 100644 --- a/rosetta/services/tx_operation.go +++ b/rosetta/services/tx_operation.go @@ -219,9 +219,9 @@ func GetDelegateOperationForSubAccount(tx *stakingTypes.StakingTransaction, dele // GetSideEffectOperationsFromUndelegationPayouts from the given payouts. // If the startingOperationIndex is provided, all operations will be indexed starting from the given operation index. func GetSideEffectOperationsFromUndelegationPayouts( - payouts hmy.UndelegationPayouts, startingOperationIndex *int64, + payouts *hmy.UndelegationPayouts, startingOperationIndex *int64, ) ([]*types.Operation, *types.Error) { - return getSideEffectOperationsFromUndelegateMap( + return getSideEffectOperationsFromUndelegationPayouts( payouts, common.UndelegationPayoutOperation, startingOperationIndex, ) } @@ -451,8 +451,8 @@ func getCrossShardSenderTransferNativeOperations( } // delegator address => validator address => amount -func getSideEffectOperationsFromUndelegateMap( - valueMap map[ethcommon.Address]map[ethcommon.Address]*big.Int, opType string, startingOperationIndex *int64, +func getSideEffectOperationsFromUndelegationPayouts( + undelegationPayouts *hmy.UndelegationPayouts, opType string, startingOperationIndex *int64, ) ([]*types.Operation, *types.Error) { var opIndex int64 operations := []*types.Operation{} @@ -462,7 +462,7 @@ func getSideEffectOperationsFromUndelegateMap( opIndex = 0 } - for delegator, undelegationMap := range valueMap { + for delegator, undelegationMap := range undelegationPayouts.Data { accID, rosettaError := newAccountIdentifier(delegator) if rosettaError != nil { @@ -498,7 +498,7 @@ func getSideEffectOperationsFromUndelegateMap( return operations, nil } -// getOperationAndTotalAmountFromUndelegationMap is a helper for getSideEffectOperationsFromUndelegateMap which actually +// getOperationAndTotalAmountFromUndelegationMap is a helper for getSideEffectOperationsFromUndelegationPayouts which actually // has some side effect(opIndex will be increased by this function) so be careful while using for other purpose func getOperationAndTotalAmountFromUndelegationMap( delegator ethcommon.Address, opIndex *int64, relatedOpIdentifier *types.OperationIdentifier, opType string, diff --git a/rosetta/services/tx_operation_test.go b/rosetta/services/tx_operation_test.go index 677394001..5e21fed5a 100644 --- a/rosetta/services/tx_operation_test.go +++ b/rosetta/services/tx_operation_test.go @@ -245,11 +245,10 @@ func TestGetStakingOperationsFromDelegate(t *testing.T) { func TestGetSideEffectOperationsFromUndelegationPayouts(t *testing.T) { startingOperationIndex := int64(0) - undelegationPayouts := hmy.UndelegationPayouts{} + undelegationPayouts := hmy.NewUndelegationPayouts() delegator := ethcommon.HexToAddress("0xB5f440B5c6215eEDc1b2E12b4b964fa31f7afa7d") validator := ethcommon.HexToAddress("0x3b8DE43c8F30D3C387840681FED67783f93f1F94") - undelegationPayouts[delegator] = make(map[ethcommon.Address]*big.Int) - undelegationPayouts[delegator][validator] = new(big.Int).SetInt64(4000) + undelegationPayouts.SetPayoutByDelegatorAddrAndValidatorAddr(delegator, validator, new(big.Int).SetInt64(4000)) operations, err := GetSideEffectOperationsFromUndelegationPayouts(undelegationPayouts, &startingOperationIndex) if err != nil { t.Fatal(err)