diff --git a/api/service/stagedstreamsync/stage_statesync.go b/api/service/stagedstreamsync/stage_statesync.go index 9391944b7..75326b6ac 100644 --- a/api/service/stagedstreamsync/stage_statesync.go +++ b/api/service/stagedstreamsync/stage_statesync.go @@ -3,11 +3,15 @@ package stagedstreamsync import ( "context" "fmt" + "sync" "time" + "github.com/ethereum/go-ethereum/common" "github.com/harmony-one/harmony/core" "github.com/harmony-one/harmony/internal/utils" + sttypes "github.com/harmony-one/harmony/p2p/stream/types" "github.com/ledgerwatch/erigon-lib/kv" + "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "github.com/rs/zerolog" ) @@ -49,7 +53,7 @@ func NewStageStateSyncCfg(bc core.BlockChain, } // Exec progresses States stage in the forward direction -func (stg *StageStateSync) Exec(ctx context.Context, bool, invalidBlockRevert bool, s *StageState, reverter Reverter, tx kv.RwTx) (err error) { +func (sss *StageStateSync) Exec(ctx context.Context, bool, invalidBlockRevert bool, s *StageState, reverter Reverter, tx kv.RwTx) (err error) { // for short range sync, skip this step if !s.state.initSync { @@ -57,19 +61,29 @@ func (stg *StageStateSync) Exec(ctx context.Context, bool, invalidBlockRevert bo } maxHeight := s.state.status.targetBN - currentHead := stg.configs.bc.CurrentBlock().NumberU64() + currentHead := sss.configs.bc.CurrentBlock().NumberU64() if currentHead >= maxHeight { return nil } - currProgress := stg.configs.bc.CurrentBlock().NumberU64() + currProgress := sss.configs.bc.CurrentBlock().NumberU64() targetHeight := s.state.currentCycle.TargetHeight + + if errV := CreateView(ctx, sss.configs.db, tx, func(etx kv.Tx) error { + if currProgress, err = s.CurrentStageProgress(etx); err != nil { + return err + } + return nil + }); errV != nil { + return errV + } + if currProgress >= targetHeight { return nil } useInternalTx := tx == nil if useInternalTx { var err error - tx, err = stg.configs.db.BeginRw(ctx) + tx, err = sss.configs.db.BeginRw(ctx) if err != nil { return err } @@ -78,34 +92,107 @@ func (stg *StageStateSync) Exec(ctx context.Context, bool, invalidBlockRevert bo // isLastCycle := targetHeight >= maxHeight startTime := time.Now() - startBlock := currProgress - if stg.configs.logProgress { + if sss.configs.logProgress { fmt.Print("\033[s") // save the cursor position } - for i := currProgress + 1; i <= targetHeight; i++ { - // log the stage progress in console - if stg.configs.logProgress { - //calculating block speed + // Fetch blocks from neighbors + root := sss.configs.bc.CurrentBlock().Root() + sdm := newStateDownloadManager(tx, sss.configs.bc, root, sss.configs.concurrency, s.state.logger) + + // Setup workers to fetch blocks from remote node + var wg sync.WaitGroup + + for i := 0; i != s.state.config.Concurrency; i++ { + wg.Add(1) + go sss.runStateWorkerLoop(ctx, sdm, &wg, i, startTime) + } + + wg.Wait() + + if useInternalTx { + if err := tx.Commit(); err != nil { + return err + } + } + + return nil +} + +// runStateWorkerLoop creates a work loop for download states +func (sss *StageStateSync) runStateWorkerLoop(ctx context.Context, sdm *StateDownloadManager, wg *sync.WaitGroup, loopID int, startTime time.Time) { + defer wg.Done() + + for { + select { + case <-ctx.Done(): + return + default: + } + nodes, paths, codes := sdm.GetNextBatch() + if len(nodes)+len(codes) == 0 { + select { + case <-ctx.Done(): + return + case <-time.After(100 * time.Millisecond): + return + } + } + + data, stid, err := sss.downloadStates(ctx, nodes, codes) + if err != nil { + if !errors.Is(err, context.Canceled) { + sss.configs.protocol.StreamFailed(stid, "downloadStates failed") + } + utils.Logger().Error(). + Err(err). + Str("stream", string(stid)). + Msg(WrapStagedSyncMsg("downloadStates failed")) + err = errors.Wrap(err, "request error") + sdm.HandleRequestError(codes, paths, stid, err) + } else if data == nil || len(data) == 0 { + utils.Logger().Warn(). + Str("stream", string(stid)). + Msg(WrapStagedSyncMsg("downloadStates failed, received empty data bytes")) + err := errors.New("downloadStates received empty data bytes") + sdm.HandleRequestError(codes, paths, stid, err) + } + sdm.HandleRequestResult(nodes, paths, data, loopID, stid) + if sss.configs.logProgress { + //calculating block download speed dt := time.Now().Sub(startTime).Seconds() speed := float64(0) if dt > 0 { - speed = float64(currProgress-startBlock) / dt + speed = float64(len(data)) / dt } - blockSpeed := fmt.Sprintf("%.2f", speed) + stateDownloadSpeed := fmt.Sprintf("%.2f", speed) + fmt.Print("\033[u\033[K") // restore the cursor position and clear the line - fmt.Println("insert blocks progress:", currProgress, "/", targetHeight, "(", blockSpeed, "blocks/s", ")") + fmt.Println("state download speed:", stateDownloadSpeed, "states/s") } - } +} - if useInternalTx { - if err := tx.Commit(); err != nil { - return err - } +func (sss *StageStateSync) downloadStates(ctx context.Context, nodes []common.Hash, codes []common.Hash) ([][]byte, sttypes.StreamID, error) { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + hashes := append(codes, nodes...) + data, stid, err := sss.configs.protocol.GetNodeData(ctx, hashes) + if err != nil { + return nil, stid, err + } + if err := validateGetNodeDataResult(hashes, data); err != nil { + return nil, stid, err } + return data, stid, nil +} +func validateGetNodeDataResult(requested []common.Hash, result [][]byte) error { + if len(result) != len(requested) { + return fmt.Errorf("unexpected number of nodes delivered: %v / %v", len(result), len(requested)) + } return nil }