refactor stage state sync

pull/4465/head
“GheisMohammadi” 1 year ago
parent 9e1249a836
commit 841073da60
No known key found for this signature in database
GPG Key ID: 15073AED3829FE90
  1. 121
      api/service/stagedstreamsync/stage_statesync.go

@ -3,11 +3,15 @@ package stagedstreamsync
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
"time" "time"
"github.com/ethereum/go-ethereum/common"
"github.com/harmony-one/harmony/core" "github.com/harmony-one/harmony/core"
"github.com/harmony-one/harmony/internal/utils" "github.com/harmony-one/harmony/internal/utils"
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
"github.com/ledgerwatch/erigon-lib/kv" "github.com/ledgerwatch/erigon-lib/kv"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/rs/zerolog" "github.com/rs/zerolog"
) )
@ -49,7 +53,7 @@ func NewStageStateSyncCfg(bc core.BlockChain,
} }
// Exec progresses States stage in the forward direction // Exec progresses States stage in the forward direction
func (stg *StageStateSync) Exec(ctx context.Context, bool, invalidBlockRevert bool, s *StageState, reverter Reverter, tx kv.RwTx) (err error) { func (sss *StageStateSync) Exec(ctx context.Context, bool, invalidBlockRevert bool, s *StageState, reverter Reverter, tx kv.RwTx) (err error) {
// for short range sync, skip this step // for short range sync, skip this step
if !s.state.initSync { if !s.state.initSync {
@ -57,19 +61,29 @@ func (stg *StageStateSync) Exec(ctx context.Context, bool, invalidBlockRevert bo
} }
maxHeight := s.state.status.targetBN maxHeight := s.state.status.targetBN
currentHead := stg.configs.bc.CurrentBlock().NumberU64() currentHead := sss.configs.bc.CurrentBlock().NumberU64()
if currentHead >= maxHeight { if currentHead >= maxHeight {
return nil return nil
} }
currProgress := stg.configs.bc.CurrentBlock().NumberU64() currProgress := sss.configs.bc.CurrentBlock().NumberU64()
targetHeight := s.state.currentCycle.TargetHeight targetHeight := s.state.currentCycle.TargetHeight
if errV := CreateView(ctx, sss.configs.db, tx, func(etx kv.Tx) error {
if currProgress, err = s.CurrentStageProgress(etx); err != nil {
return err
}
return nil
}); errV != nil {
return errV
}
if currProgress >= targetHeight { if currProgress >= targetHeight {
return nil return nil
} }
useInternalTx := tx == nil useInternalTx := tx == nil
if useInternalTx { if useInternalTx {
var err error var err error
tx, err = stg.configs.db.BeginRw(ctx) tx, err = sss.configs.db.BeginRw(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -78,34 +92,107 @@ func (stg *StageStateSync) Exec(ctx context.Context, bool, invalidBlockRevert bo
// isLastCycle := targetHeight >= maxHeight // isLastCycle := targetHeight >= maxHeight
startTime := time.Now() startTime := time.Now()
startBlock := currProgress
if stg.configs.logProgress { if sss.configs.logProgress {
fmt.Print("\033[s") // save the cursor position fmt.Print("\033[s") // save the cursor position
} }
for i := currProgress + 1; i <= targetHeight; i++ { // Fetch blocks from neighbors
// log the stage progress in console root := sss.configs.bc.CurrentBlock().Root()
if stg.configs.logProgress { sdm := newStateDownloadManager(tx, sss.configs.bc, root, sss.configs.concurrency, s.state.logger)
//calculating block speed
// Setup workers to fetch blocks from remote node
var wg sync.WaitGroup
for i := 0; i != s.state.config.Concurrency; i++ {
wg.Add(1)
go sss.runStateWorkerLoop(ctx, sdm, &wg, i, startTime)
}
wg.Wait()
if useInternalTx {
if err := tx.Commit(); err != nil {
return err
}
}
return nil
}
// runStateWorkerLoop creates a work loop for download states
func (sss *StageStateSync) runStateWorkerLoop(ctx context.Context, sdm *StateDownloadManager, wg *sync.WaitGroup, loopID int, startTime time.Time) {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
default:
}
nodes, paths, codes := sdm.GetNextBatch()
if len(nodes)+len(codes) == 0 {
select {
case <-ctx.Done():
return
case <-time.After(100 * time.Millisecond):
return
}
}
data, stid, err := sss.downloadStates(ctx, nodes, codes)
if err != nil {
if !errors.Is(err, context.Canceled) {
sss.configs.protocol.StreamFailed(stid, "downloadStates failed")
}
utils.Logger().Error().
Err(err).
Str("stream", string(stid)).
Msg(WrapStagedSyncMsg("downloadStates failed"))
err = errors.Wrap(err, "request error")
sdm.HandleRequestError(codes, paths, stid, err)
} else if data == nil || len(data) == 0 {
utils.Logger().Warn().
Str("stream", string(stid)).
Msg(WrapStagedSyncMsg("downloadStates failed, received empty data bytes"))
err := errors.New("downloadStates received empty data bytes")
sdm.HandleRequestError(codes, paths, stid, err)
}
sdm.HandleRequestResult(nodes, paths, data, loopID, stid)
if sss.configs.logProgress {
//calculating block download speed
dt := time.Now().Sub(startTime).Seconds() dt := time.Now().Sub(startTime).Seconds()
speed := float64(0) speed := float64(0)
if dt > 0 { if dt > 0 {
speed = float64(currProgress-startBlock) / dt speed = float64(len(data)) / dt
} }
blockSpeed := fmt.Sprintf("%.2f", speed) stateDownloadSpeed := fmt.Sprintf("%.2f", speed)
fmt.Print("\033[u\033[K") // restore the cursor position and clear the line fmt.Print("\033[u\033[K") // restore the cursor position and clear the line
fmt.Println("insert blocks progress:", currProgress, "/", targetHeight, "(", blockSpeed, "blocks/s", ")") fmt.Println("state download speed:", stateDownloadSpeed, "states/s")
} }
} }
}
if useInternalTx { func (sss *StageStateSync) downloadStates(ctx context.Context, nodes []common.Hash, codes []common.Hash) ([][]byte, sttypes.StreamID, error) {
if err := tx.Commit(); err != nil { ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
return err defer cancel()
hashes := append(codes, nodes...)
data, stid, err := sss.configs.protocol.GetNodeData(ctx, hashes)
if err != nil {
return nil, stid, err
} }
if err := validateGetNodeDataResult(hashes, data); err != nil {
return nil, stid, err
} }
return data, stid, nil
}
func validateGetNodeDataResult(requested []common.Hash, result [][]byte) error {
if len(result) != len(requested) {
return fmt.Errorf("unexpected number of nodes delivered: %v / %v", len(result), len(requested))
}
return nil return nil
} }

Loading…
Cancel
Save