From 0901e92bf8cc17085e072dbc90294b46e49dd0f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CGheisMohammadi=E2=80=9D?= <36589218+GheisMohammadi@users.noreply.github.com> Date: Thu, 7 Dec 2023 16:32:03 +0800 Subject: [PATCH] add state sync full, complete full state sync stage --- .../stagedstreamsync/stage_statesync_full.go | 449 ++++++++++++++ .../{state_sync.go => state_sync_full.go} | 583 +++++++++++++++--- api/service/stagedstreamsync/syncing.go | 2 +- p2p/stream/protocols/sync/chain.go | 2 +- 4 files changed, 951 insertions(+), 85 deletions(-) create mode 100644 api/service/stagedstreamsync/stage_statesync_full.go rename api/service/stagedstreamsync/{state_sync.go => state_sync_full.go} (80%) diff --git a/api/service/stagedstreamsync/stage_statesync_full.go b/api/service/stagedstreamsync/stage_statesync_full.go new file mode 100644 index 000000000..3e190bdc9 --- /dev/null +++ b/api/service/stagedstreamsync/stage_statesync_full.go @@ -0,0 +1,449 @@ +package stagedstreamsync + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/harmony-one/harmony/core" + "github.com/harmony-one/harmony/internal/utils" + sttypes "github.com/harmony-one/harmony/p2p/stream/types" + "github.com/pkg/errors" + + //sttypes "github.com/harmony-one/harmony/p2p/stream/types" + "github.com/ledgerwatch/erigon-lib/kv" + "github.com/prometheus/client_golang/prometheus" + "github.com/rs/zerolog" +) + +type StageFullStateSync struct { + configs StageFullStateSyncCfg +} + +type StageFullStateSyncCfg struct { + bc core.BlockChain + db kv.RwDB + concurrency int + protocol syncProtocol + logger zerolog.Logger + logProgress bool +} + +func NewStageFullStateSync(cfg StageFullStateSyncCfg) *StageFullStateSync { + return &StageFullStateSync{ + configs: cfg, + } +} + +func NewStageFullStateSyncCfg(bc core.BlockChain, + db kv.RwDB, + concurrency int, + protocol syncProtocol, + logger zerolog.Logger, + logProgress bool) StageFullStateSyncCfg { + + return StageFullStateSyncCfg{ + bc: bc, + db: db, + concurrency: concurrency, + protocol: protocol, + logger: logger, + logProgress: logProgress, + } +} + +// Exec progresses States stage in the forward direction +func (sss *StageFullStateSync) Exec(ctx context.Context, bool, invalidBlockRevert bool, s *StageState, reverter Reverter, tx kv.RwTx) (err error) { + + // for short range sync, skip this step + if !s.state.initSync { + return nil + } // only execute this stage in fast/snap sync mode and once we reach to pivot + + if s.state.status.pivotBlock == nil || + s.state.CurrentBlockNumber() != s.state.status.pivotBlock.NumberU64() || + s.state.status.statesSynced { + return nil + } + + s.state.Debug("STATE SYNC ======================================================>", "started") + // maxHeight := s.state.status.targetBN + // currentHead := s.state.CurrentBlockNumber() + // if currentHead >= maxHeight { + // return nil + // } + // currProgress := s.state.CurrentBlockNumber() + // targetHeight := s.state.currentCycle.TargetHeight + + // if errV := CreateView(ctx, sss.configs.db, tx, func(etx kv.Tx) error { + // if currProgress, err = s.CurrentStageProgress(etx); err != nil { + // return err + // } + // return nil + // }); errV != nil { + // return errV + // } + + // if currProgress >= targetHeight { + // return nil + // } + useInternalTx := tx == nil + if useInternalTx { + var err error + tx, err = sss.configs.db.BeginRw(ctx) + if err != nil { + return err + } + defer tx.Rollback() + } + + // isLastCycle := targetHeight >= maxHeight + startTime := time.Now() + + if 
sss.configs.logProgress { + fmt.Print("\033[s") // save the cursor position + } + + // Fetch states from neighbors + pivotRootHash := s.state.status.pivotBlock.Root() + currentBlockRootHash := s.state.bc.CurrentFastBlock().Root() + scheme := sss.configs.bc.TrieDB().Scheme() + sdm := newFullStateDownloadManager(sss.configs.bc.ChainDb(), scheme, tx, sss.configs.bc, sss.configs.concurrency, s.state.logger) + sdm.setRootHash(currentBlockRootHash) + s.state.Debug("StateSync/setRootHash", pivotRootHash) + s.state.Debug("StateSync/currentFastBlockRoot", currentBlockRootHash) + s.state.Debug("StateSync/pivotBlockNumber", s.state.status.pivotBlock.NumberU64()) + s.state.Debug("StateSync/currentFastBlockNumber", s.state.bc.CurrentFastBlock().NumberU64()) + var wg sync.WaitGroup + for i := 0; i < s.state.config.Concurrency; i++ { + wg.Add(1) + go sss.runStateWorkerLoop(ctx, sdm, &wg, i, startTime, s) + } + wg.Wait() + + // insert block + if err := sss.configs.bc.WriteHeadBlock(s.state.status.pivotBlock); err != nil { + sss.configs.logger.Warn().Err(err). + Uint64("pivot block number", s.state.status.pivotBlock.NumberU64()). + Msg(WrapStagedSyncMsg("insert pivot block failed")) + s.state.Debug("StateSync/pivot/insert/error", err) + // TODO: panic("pivot block is failed to insert in chain.") + return err + } + + // states should be fully synced in this stage + s.state.status.statesSynced = true + + s.state.Debug("StateSync/pivot/num", s.state.status.pivotBlock.NumberU64()) + s.state.Debug("StateSync/pivot/insert", "done") + + /* + gbm := s.state.gbm + + // Setup workers to fetch states from remote node + var wg sync.WaitGroup + curHeight := s.state.CurrentBlockNumber() + + for bn := curHeight + 1; bn <= gbm.targetBN; bn++ { + root := gbm.GetRootHash(bn) + if root == emptyHash { + continue + } + sdm.setRootHash(root) + for i := 0; i < s.state.config.Concurrency; i++ { + wg.Add(1) + go sss.runStateWorkerLoop(ctx, sdm, &wg, i, startTime, s) + } + wg.Wait() + } + */ + + if useInternalTx { + if err := tx.Commit(); err != nil { + return err + } + } + + return nil +} + +// runStateWorkerLoop creates a work loop for download states +func (sss *StageFullStateSync) runStateWorkerLoop(ctx context.Context, sdm *FullStateDownloadManager, wg *sync.WaitGroup, loopID int, startTime time.Time, s *StageState) { + + s.state.Debug("runStateWorkerLoop/info", "started") + + defer wg.Done() + + for { + select { + case <-ctx.Done(): + s.state.Debug("runStateWorkerLoop/ctx/done", "Finished") + return + default: + } + accountTasks, codes, storages, healtask, codetask, err := sdm.GetNextBatch() + s.state.Debug("runStateWorkerLoop/batch/len", len(accountTasks)+len(codes)+len(storages.accounts)) + s.state.Debug("runStateWorkerLoop/batch/heals/len", len(healtask.hashes)+len(codetask.hashes)) + s.state.Debug("runStateWorkerLoop/batch/err", err) + if len(accountTasks)+len(codes)+len(storages.accounts)+len(healtask.hashes)+len(codetask.hashes) == 0 || err != nil { + select { + case <-ctx.Done(): + return + case <-time.After(100 * time.Millisecond): + return + } + } + s.state.Debug("runStateWorkerLoop/batch/accounts", accountTasks) + s.state.Debug("runStateWorkerLoop/batch/codes", codes) + + if len(accountTasks) > 0 { + + task := accountTasks[0] + origin := task.Next + limit := task.Last + root := sdm.root + cap := maxRequestSize + retAccounts, proof, stid, err := sss.configs.protocol.GetAccountRange(ctx, root, origin, limit, uint64(cap)) + if err != nil { + return + } + if err := sdm.HandleAccountRequestResult(task, retAccounts, 
proof, origin[:], limit[:], loopID, stid); err != nil { + return + } + + } else if len(codes)+len(storages.accounts) > 0 { + + if len(codes) > 0 { + stid, err := sss.downloadByteCodes(ctx, sdm, codes, loopID) + if err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + sss.configs.protocol.StreamFailed(stid, "downloadByteCodes failed") + } + utils.Logger().Error(). + Err(err). + Str("stream", string(stid)). + Msg(WrapStagedSyncMsg("downloadByteCodes failed")) + err = errors.Wrap(err, "request error") + sdm.HandleRequestError(accountTasks, codes, storages, healtask, codetask, stid, err) + return + } + } + + if len(storages.accounts) > 0 { + root := sdm.root + roots := storages.roots + accounts := storages.accounts + cap := maxRequestSize + origin := storages.origin + limit := storages.limit + mainTask := storages.mainTask + subTask := storages.subtask + + slots, proof, stid, err := sss.configs.protocol.GetStorageRanges(ctx, root, accounts, origin, limit, uint64(cap)) + if err != nil { + return + } + if err := sdm.HandleStorageRequestResult(mainTask, subTask, accounts, roots, origin, limit, slots, proof, loopID, stid); err != nil { + return + } + } + + // data, stid, err := sss.downloadStates(ctx, accounts, codes, storages) + // if err != nil { + // s.state.Debug("runStateWorkerLoop/downloadStates/error", err) + // if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + // sss.configs.protocol.StreamFailed(stid, "downloadStates failed") + // } + // utils.Logger().Error(). + // Err(err). + // Str("stream", string(stid)). + // Msg(WrapStagedSyncMsg("downloadStates failed")) + // err = errors.Wrap(err, "request error") + // sdm.HandleRequestError(codes, paths, stid, err) + // } else if data == nil || len(data) == 0 { + // s.state.Debug("runStateWorkerLoop/downloadStates/data", "nil array") + // utils.Logger().Warn(). + // Str("stream", string(stid)). 
+ // Msg(WrapStagedSyncMsg("downloadStates failed, received empty data bytes")) + // err := errors.New("downloadStates received empty data bytes") + // sdm.HandleRequestError(codes, paths, stid, err) + // } else { + // s.state.Debug("runStateWorkerLoop/downloadStates/data/len", len(data)) + // sdm.HandleRequestResult(nodes, paths, data, loopID, stid) + // if sss.configs.logProgress { + // //calculating block download speed + // dt := time.Now().Sub(startTime).Seconds() + // speed := float64(0) + // if dt > 0 { + // speed = float64(len(data)) / dt + // } + // stateDownloadSpeed := fmt.Sprintf("%.2f", speed) + + // fmt.Print("\033[u\033[K") // restore the cursor position and clear the line + // fmt.Println("state download speed:", stateDownloadSpeed, "states/s") + // } + // } + + } else { + // assign trie node Heal Tasks + if len(healtask.hashes) > 0 { + root := sdm.root + task := healtask.task + hashes := healtask.hashes + pathsets := healtask.pathsets + paths := healtask.paths + + nodes, stid, err := sss.configs.protocol.GetTrieNodes(ctx, root, pathsets, maxRequestSize) + if err != nil { + return + } + if err := sdm.HandleTrieNodeHealRequestResult(task, paths, hashes, nodes, loopID, stid); err != nil { + return + } + } + + if len(codetask.hashes) > 0 { + task := codetask.task + hashes := codetask.hashes + codes, stid, err := sss.configs.protocol.GetByteCodes(ctx, hashes, maxRequestSize) + if err != nil { + return + } + if err := sdm.HandleBytecodeRequestResult(task, hashes, codes, loopID, stid); err != nil { + return + } + } + } + } +} + +func (sss *StageFullStateSync) downloadByteCodes(ctx context.Context, sdm *FullStateDownloadManager, codeTasks []*byteCodeTasksBundle, loopID int) (stid sttypes.StreamID, err error) { + for _, codeTask := range codeTasks { + // try to get byte codes from remote peer + // if any of them failed, the stid will be the id of the failed stream + retCodes, stid, err := sss.configs.protocol.GetByteCodes(ctx, codeTask.hashes, maxRequestSize) + if err != nil { + return stid, err + } + if err = sdm.HandleBytecodeRequestResult(codeTask.task, codeTask.hashes, retCodes, loopID, stid); err != nil { + return stid, err + } + } + return +} + +func (sss *StageFullStateSync) downloadStorages(ctx context.Context, sdm *FullStateDownloadManager, codeTasks []*byteCodeTasksBundle, loopID int) (stid sttypes.StreamID, err error) { + for _, codeTask := range codeTasks { + // try to get byte codes from remote peer + // if any of them failed, the stid will be the id of failed stream + retCodes, stid, err := sss.configs.protocol.GetByteCodes(ctx, codeTask.hashes, maxRequestSize) + if err != nil { + return stid, err + } + if err = sdm.HandleBytecodeRequestResult(codeTask.task, codeTask.hashes, retCodes, loopID, stid); err != nil { + return stid, err + } + } + return +} + +// func (sss *StageFullStateSync) downloadStates(ctx context.Context, +// root common.Hash, +// origin common.Hash, +// accounts []*accountTask, +// codes []common.Hash, +// storages *storageTaskBundle) ([][]byte, sttypes.StreamID, error) { + +// ctx, cancel := context.WithTimeout(ctx, 10*time.Second) +// defer cancel() + +// // if there is any account task, first we have to complete that +// if len(accounts) > 0 { + +// } +// // hashes := append(codes, nodes...) 
+// // data, stid, err := sss.configs.protocol.GetNodeData(ctx, hashes) +// // if err != nil { +// // return nil, stid, err +// // } +// // if err := validateGetNodeDataResult(hashes, data); err != nil { +// // return nil, stid, err +// // } +// return data, stid, nil +// } + +func (stg *StageFullStateSync) insertChain(gbm *blockDownloadManager, + protocol syncProtocol, + lbls prometheus.Labels, + targetBN uint64) { + +} + +func (stg *StageFullStateSync) saveProgress(s *StageState, tx kv.RwTx) (err error) { + + useInternalTx := tx == nil + if useInternalTx { + var err error + tx, err = stg.configs.db.BeginRw(context.Background()) + if err != nil { + return err + } + defer tx.Rollback() + } + + // save progress + if err = s.Update(tx, s.state.CurrentBlockNumber()); err != nil { + utils.Logger().Error(). + Err(err). + Msgf("[STAGED_SYNC] saving progress for block States stage failed") + return ErrSaveStateProgressFail + } + + if useInternalTx { + if err := tx.Commit(); err != nil { + return err + } + } + return nil +} + +func (stg *StageFullStateSync) Revert(ctx context.Context, firstCycle bool, u *RevertState, s *StageState, tx kv.RwTx) (err error) { + useInternalTx := tx == nil + if useInternalTx { + tx, err = stg.configs.db.BeginRw(ctx) + if err != nil { + return err + } + defer tx.Rollback() + } + + if err = u.Done(tx); err != nil { + return err + } + + if useInternalTx { + if err = tx.Commit(); err != nil { + return err + } + } + return nil +} + +func (stg *StageFullStateSync) CleanUp(ctx context.Context, firstCycle bool, p *CleanUpState, tx kv.RwTx) (err error) { + useInternalTx := tx == nil + if useInternalTx { + tx, err = stg.configs.db.BeginRw(ctx) + if err != nil { + return err + } + defer tx.Rollback() + } + + if useInternalTx { + if err = tx.Commit(); err != nil { + return err + } + } + return nil +} diff --git a/api/service/stagedstreamsync/state_sync.go b/api/service/stagedstreamsync/state_sync_full.go similarity index 80% rename from api/service/stagedstreamsync/state_sync.go rename to api/service/stagedstreamsync/state_sync_full.go index 1bf685826..daf0f4869 100644 --- a/api/service/stagedstreamsync/state_sync.go +++ b/api/service/stagedstreamsync/state_sync_full.go @@ -3,6 +3,7 @@ package stagedstreamsync import ( "bytes" "encoding/json" + "fmt" gomath "math" "math/big" "math/rand" @@ -17,11 +18,14 @@ import ( "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" + + //"github.com/ethereum/go-ethereum/trie/trienode" "github.com/harmony-one/harmony/common/math" "github.com/harmony-one/harmony/core" "github.com/harmony-one/harmony/core/rawdb" "github.com/harmony-one/harmony/core/state" "github.com/harmony-one/harmony/internal/utils" + "github.com/harmony-one/harmony/p2p/stream/protocols/sync/message" sttypes "github.com/harmony-one/harmony/p2p/stream/types" "github.com/ledgerwatch/erigon-lib/kv" "github.com/ledgerwatch/log/v3" @@ -191,7 +195,7 @@ func (t *healRequestSort) Swap(i, j int) { // Merge merges the pathsets, so that several storage requests concerning the // same account are merged into one, to reduce bandwidth. // This operation is moot if t has not first been sorted. 
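+// The merged path sets are returned as *message.TrieNodePathSet values so they can be passed directly to a GetTrieNodes stream request.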
-func (t *healRequestSort) Merge() []TrieNodePathSet { +func (t *healRequestSort) Merge() []*message.TrieNodePathSet { var result []TrieNodePathSet for _, path := range t.syncPaths { pathset := TrieNodePathSet(path) @@ -211,7 +215,20 @@ func (t *healRequestSort) Merge() []TrieNodePathSet { } } } - return result + // convert to array of pointers + result_ptr := make([]*message.TrieNodePathSet, 0) + for _, p := range result { + result_ptr = append(result_ptr, &message.TrieNodePathSet{ + Pathset: p, + }) + } + return result_ptr +} + +type byteCodeTasksBundle struct { + id uint64 //unique id for bytecode task bundle + task *accountTask + hashes []common.Hash } type storageTaskBundle struct { @@ -231,16 +248,16 @@ type healTask struct { codeTasks map[common.Hash]struct{} // Set of byte code tasks currently queued for retrieval, indexed by code hash paths []string hashes []common.Hash - pathsets []TrieNodePathSet + pathsets []*message.TrieNodePathSet task *healTask root common.Hash byteCodeReq bool } type tasks struct { - accountTasks map[uint64]*accountTask // Current account task set being synced - storageTasks map[uint64]*storageTaskBundle // Set of trie node tasks currently queued for retrieval, indexed by path - codeTasks map[common.Hash]struct{} // Set of byte code tasks currently queued for retrieval, indexed by hash + accountTasks map[uint64]*accountTask // Current account task set being synced + storageTasks map[uint64]*storageTaskBundle // Set of trie node tasks currently queued for retrieval, indexed by path + codeTasks map[uint64]*byteCodeTasksBundle // Set of byte code tasks currently queued for retrieval, indexed by hash healer map[uint64]*healTask snapped bool // Flag to signal that snap phase is done } @@ -249,7 +266,7 @@ func newTasks() *tasks { return &tasks{ accountTasks: make(map[uint64]*accountTask, 0), storageTasks: make(map[uint64]*storageTaskBundle, 0), - codeTasks: make(map[common.Hash]struct{}), + codeTasks: make(map[uint64]*byteCodeTasksBundle), healer: make(map[uint64]*healTask, 0), snapped: false, } @@ -272,13 +289,13 @@ func (t *tasks) deleteAccountTask(accountTaskIndex uint64) { } } -func (t *tasks) addCodeTask(h common.Hash) { - t.codeTasks[h] = struct{}{} +func (t *tasks) addCodeTask(id uint64, bytecodeTask *byteCodeTasksBundle) { + t.codeTasks[id] = bytecodeTask } -func (t *tasks) deleteCodeTask(hash common.Hash) { - if _, ok := t.codeTasks[hash]; ok { - delete(t.codeTasks, hash) +func (t *tasks) deleteCodeTask(id uint64) { + if _, ok := t.codeTasks[id]; ok { + delete(t.codeTasks, id) } } @@ -500,33 +517,6 @@ func FullAccountRLP(data []byte) ([]byte, error) { return rlp.EncodeToBytes(account) } -// onHealState is a callback method to invoke when a flat state(account -// or storage slot) is downloaded during the healing stage. The flat states -// can be persisted blindly and can be fixed later in the generation stage. -// Note it's not concurrent safe, please handle the concurrent issue outside. 
-func (s *FullStateDownloadManager) onHealState(paths [][]byte, value []byte) error { - if len(paths) == 1 { - var account types.StateAccount - if err := rlp.DecodeBytes(value, &account); err != nil { - return nil // Returning the error here would drop the remote peer - } - blob := s.SlimAccountRLP(account) - rawdb.WriteAccountSnapshot(s.stateWriter, common.BytesToHash(paths[0]), blob) - s.accountHealed += 1 - s.accountHealedBytes += common.StorageSize(1 + common.HashLength + len(blob)) - } - if len(paths) == 2 { - rawdb.WriteStorageSnapshot(s.stateWriter, common.BytesToHash(paths[0]), common.BytesToHash(paths[1]), value) - s.storageHealed += 1 - s.storageHealedBytes += common.StorageSize(1 + 2*common.HashLength + len(value)) - } - if s.stateWriter.ValueSize() > ethdb.IdealBatchSize { - s.stateWriter.Write() // It's fine to ignore the error here - s.stateWriter.Reset() - } - return nil -} - func (s *FullStateDownloadManager) commitHealer(force bool) { if !force && s.scheduler.MemSize() < ethdb.IdealBatchSize { return @@ -572,7 +562,7 @@ func (s *FullStateDownloadManager) SyncCompleted() { // getNextBatch returns objects with a maximum of n state download // tasks to send to the remote peer. func (s *FullStateDownloadManager) GetNextBatch() (accounts []*accountTask, - codes []common.Hash, + codes []*byteCodeTasksBundle, storages *storageTaskBundle, healtask *healTask, codetask *healTask, @@ -936,13 +926,13 @@ func (s *FullStateDownloadManager) updateStats(written, duplicate, unexpected in // tasks to send to the remote peer. func (s *FullStateDownloadManager) getBatchFromUnprocessed(n int, withHealTasks bool) ( accounts []*accountTask, - codes []common.Hash, + codes []*byteCodeTasksBundle, storages *storageTaskBundle, healtask *healTask, codetask *healTask) { // over trie nodes as those can be written to disk and forgotten about. 
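+	// Bytecode hashes are grouped per account task into a byteCodeTasksBundle with a unique random id,
+	// so a whole bundle can be fetched from a single stream via GetByteCodes.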
- codes = make([]common.Hash, 0, n) + codes = make([]*byteCodeTasksBundle, 0, n) accounts = make([]*accountTask, 0, n) for i, task := range s.tasks.accountTasks { @@ -961,9 +951,12 @@ func (s *FullStateDownloadManager) getBatchFromUnprocessed(n int, withHealTasks accounts = append(accounts, task) s.requesting.addAccountTask(task.id, task) // s.tasks.deleteAccountTask(task) + + // one task account is enough for an stream + return } - cap := n - len(accounts) + cap := n // - len(accounts) for _, task := range s.tasks.accountTasks { // Skip tasks that are already retrieving (or done with) all codes @@ -971,19 +964,42 @@ func (s *FullStateDownloadManager) getBatchFromUnprocessed(n int, withHealTasks continue } + var hashes []common.Hash for hash := range task.codeTasks { delete(task.codeTasks, hash) - codes = append(codes, hash) - s.requesting.addCodeTask(hash) - s.tasks.deleteCodeTask(hash) - // Stop when we've gathered enough requests - if len(codes) >= cap { - return + hashes = append(hashes, hash) + } + + // create a unique id for task bundle + var taskID uint64 + for { + taskID = uint64(rand.Int63()) + if taskID == 0 { + continue } + if _, ok := s.tasks.codeTasks[taskID]; ok { + continue + } + break + } + + bytecodeTask := &byteCodeTasksBundle{ + id: taskID, + hashes: hashes, + task: task, + } + codes = append(codes, bytecodeTask) + + s.requesting.addCodeTask(taskID, bytecodeTask) + //s.tasks.deleteCodeTask(taskID) + + // Stop when we've gathered enough requests + if len(codes) >= cap { + return } } - cap = n - len(accounts) - len(codes) + cap = n - len(codes) // - len(accounts) for accTaskID, task := range s.tasks.accountTasks { // Skip tasks that are already retrieving (or done with) all small states @@ -1118,7 +1134,7 @@ func (s *FullStateDownloadManager) getBatchFromUnprocessed(n int, withHealTasks var ( hashes = make([]common.Hash, 0, cap) paths = make([]string, 0, cap) - pathsets = make([]TrieNodePathSet, 0, cap) + pathsets = make([]*message.TrieNodePathSet, 0, cap) ) for path, hash := range s.tasks.healer[0].trieTasks { delete(s.tasks.healer[0].trieTasks, path) @@ -1228,7 +1244,7 @@ func (s *FullStateDownloadManager) getBatchFromUnprocessed(n int, withHealTasks // sortByAccountPath takes hashes and paths, and sorts them. After that, it generates // the TrieNodePaths and merges paths which belongs to the same account path. -func sortByAccountPath(paths []string, hashes []common.Hash) ([]string, []common.Hash, []trie.SyncPath, []TrieNodePathSet) { +func sortByAccountPath(paths []string, hashes []common.Hash) ([]string, []common.Hash, []trie.SyncPath, []*message.TrieNodePathSet) { var syncPaths []trie.SyncPath for _, path := range paths { syncPaths = append(syncPaths, trie.NewSyncPath([]byte(path))) @@ -1242,14 +1258,14 @@ func sortByAccountPath(paths []string, hashes []common.Hash) ([]string, []common // getBatchFromRetries get the block number batch to be requested from retries. func (s *FullStateDownloadManager) getBatchFromRetries(n int) ( accounts []*accountTask, - codes []common.Hash, + codes []*byteCodeTasksBundle, storages *storageTaskBundle, healtask *healTask, codetask *healTask) { // over trie nodes as those can be written to disk and forgotten about. 
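+	// Retried tasks handed out here are moved back into the requesting set so they are tracked as in-flight again.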
- accounts = make([]*accountTask, 0, n) - codes = make([]common.Hash, 0, n) + accounts = make([]*accountTask, 0) + codes = make([]*byteCodeTasksBundle, 0) for _, task := range s.retries.accountTasks { // Stop when we've gathered enough requests @@ -1263,14 +1279,14 @@ func (s *FullStateDownloadManager) getBatchFromRetries(n int) ( cap := n - len(accounts) - for code := range s.retries.codeTasks { + for _, code := range s.retries.codeTasks { // Stop when we've gathered enough requests if len(codes) >= cap { return } codes = append(codes, code) - s.requesting.addCodeTask(code) - s.retries.deleteCodeTask(code) + s.requesting.addCodeTask(code.id, code) + s.retries.deleteCodeTask(code.id) } cap = n - len(accounts) - len(codes) @@ -1339,7 +1355,7 @@ func (s *FullStateDownloadManager) getBatchFromRetries(n int) ( // HandleRequestError handles the error result func (s *FullStateDownloadManager) HandleRequestError(accounts []*accountTask, - codes []common.Hash, + codes []*byteCodeTasksBundle, storages *storageTaskBundle, healtask *healTask, codetask *healTask, @@ -1354,8 +1370,8 @@ func (s *FullStateDownloadManager) HandleRequestError(accounts []*accountTask, } for _, code := range codes { - s.requesting.deleteCodeTask(code) - s.retries.addCodeTask(code) + s.requesting.deleteCodeTask(code.id) + s.retries.addCodeTask(code.id, code) } if storages != nil { @@ -1374,18 +1390,99 @@ func (s *FullStateDownloadManager) HandleRequestError(accounts []*accountTask, } } +// UnpackAccountRanges retrieves the accounts from the range packet and converts from slim +// wire representation to consensus format. The returned data is RLP encoded +// since it's expected to be serialized to disk without further interpretation. +// +// Note, this method does a round of RLP decoding and re-encoding, so only use it +// once and cache the results if need be. Ideally discard the packet afterwards +// to not double the memory use. +func (s *FullStateDownloadManager) UnpackAccountRanges(retAccounts []*message.AccountData) ([]common.Hash, [][]byte, error) { + var ( + hashes = make([]common.Hash, len(retAccounts)) + accounts = make([][]byte, len(retAccounts)) + ) + for i, acc := range retAccounts { + val, err := FullAccountRLP(acc.Body) + if err != nil { + return nil, nil, fmt.Errorf("invalid account %x: %v", acc.Body, err) + } + hashes[i] = common.BytesToHash(acc.Hash) + accounts[i] = val + } + return hashes, accounts, nil +} + // HandleAccountRequestResult handles get account ranges result -func (s *FullStateDownloadManager) HandleAccountRequestResult(task *accountTask, // Task which this request is filling - hashes []common.Hash, // Account hashes in the returned range - accounts []*types.StateAccount, // Expanded accounts in the returned range - cont bool, // Whether the account range has a continuation +func (s *FullStateDownloadManager) HandleAccountRequestResult(task *accountTask, + retAccounts []*message.AccountData, + proof [][]byte, + origin []byte, + last []byte, loopID int, streamID sttypes.StreamID) error { + hashes, accounts, err := s.UnpackAccountRanges(retAccounts) + if err != nil { + return err + } + + size := common.StorageSize(len(hashes) * common.HashLength) + for _, account := range accounts { + size += common.StorageSize(len(account)) + } + for _, node := range proof { + size += common.StorageSize(len(node)) + } + utils.Logger().Trace(). + Int("hashes", len(hashes)). + Int("accounts", len(accounts)). + Int("proofs", len(proof)). + Interface("bytes", size). 
+ Msg("Delivering range of accounts") + s.lock.Lock() defer s.lock.Unlock() - if err := s.processAccountResponse(task, hashes, accounts, cont); err != nil { + // Response is valid, but check if peer is signalling that it does not have + // the requested data. For account range queries that means the state being + // retrieved was either already pruned remotely, or the peer is not yet + // synced to our head. + if len(hashes) == 0 && len(accounts) == 0 && len(proof) == 0 { + utils.Logger().Debug(). + Interface("root", s.root). + Msg("Peer rejected account range request") + s.lock.Unlock() + return nil + } + root := s.root + s.lock.Unlock() + + // Reconstruct a partial trie from the response and verify it + keys := make([][]byte, len(hashes)) + for i, key := range hashes { + keys[i] = common.CopyBytes(key[:]) + } + nodes := make(ProofList, len(proof)) + for i, node := range proof { + nodes[i] = node + } + cont, err := trie.VerifyRangeProof(root, origin[:], last[:], keys, accounts, nodes.Set()) + if err != nil { + utils.Logger().Warn().Err(err).Msg("Account range failed proof") + // Signal this request as failed, and ready for rescheduling + return err + } + accs := make([]*types.StateAccount, len(accounts)) + for i, account := range accounts { + acc := new(types.StateAccount) + if err := rlp.DecodeBytes(account, acc); err != nil { + panic(err) // We created these blobs, we must be able to decode them + } + accs[i] = acc + } + + if err := s.processAccountResponse(task, hashes, accs, cont); err != nil { return err } @@ -1491,16 +1588,72 @@ func (s *FullStateDownloadManager) processAccountResponse(task *accountTask, // } // HandleBytecodeRequestResult handles get bytecode result -func (s *FullStateDownloadManager) HandleBytecodeRequestResult(task *accountTask, // Task which this request is filling - hashes []common.Hash, // Hashes of the bytecode to avoid double hashing +// it is a callback method to invoke when a batch of contract +// bytes codes are received from a remote peer. +func (s *FullStateDownloadManager) HandleBytecodeRequestResult(task interface{}, // Task which this request is filling + reqHashes []common.Hash, // Hashes of the bytecode to avoid double hashing bytecodes [][]byte, // Actual bytecodes to store into the database (nil = missing) loopID int, streamID sttypes.StreamID) error { + s.lock.RLock() + syncing := !s.snapped + s.lock.RUnlock() + + if syncing { + return s.onByteCodes(task.(*accountTask), bytecodes, reqHashes) + } + return s.onHealByteCodes(task.(*healTask), reqHashes, bytecodes) +} + +// onByteCodes is a callback method to invoke when a batch of contract +// bytes codes are received from a remote peer in the syncing phase. +func (s *FullStateDownloadManager) onByteCodes(task *accountTask, bytecodes [][]byte, reqHashes []common.Hash) error { + var size common.StorageSize + for _, code := range bytecodes { + size += common.StorageSize(len(code)) + } + + utils.Logger().Trace().Int("bytecodes", len(bytecodes)).Interface("bytes", size).Msg("Delivering set of bytecodes") + s.lock.Lock() defer s.lock.Unlock() - if err := s.processBytecodeResponse(task, hashes, bytecodes); err != nil { + // Response is valid, but check if peer is signalling that it does not have + // the requested data. For bytecode range queries that means the peer is not + // yet synced. 
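+	// An empty reply is treated as a soft rejection: nothing is processed and no error is returned to the caller.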
+ if len(bytecodes) == 0 { + utils.Logger().Debug().Msg("Peer rejected bytecode request") + return nil + } + + // Cross reference the requested bytecodes with the response to find gaps + // that the serving node is missing + hasher := sha3.NewLegacyKeccak256().(crypto.KeccakState) + hash := make([]byte, 32) + + codes := make([][]byte, len(reqHashes)) + for i, j := 0, 0; i < len(bytecodes); i++ { + // Find the next hash that we've been served, leaving misses with nils + hasher.Reset() + hasher.Write(bytecodes[i]) + hasher.Read(hash) + + for j < len(reqHashes) && !bytes.Equal(hash, reqHashes[j][:]) { + j++ + } + if j < len(reqHashes) { + codes[j] = bytecodes[i] + j++ + continue + } + // We've either ran out of hashes, or got unrequested data + utils.Logger().Warn().Int("count", len(bytecodes)-i).Msg("Unexpected bytecodes") + // Signal this request as failed, and ready for rescheduling + return errors.New("unexpected bytecode") + } + // Response validated, send it to the scheduler for filling + if err := s.processBytecodeResponse(task, reqHashes, codes); err != nil { return err } @@ -1574,21 +1727,143 @@ func estimateRemainingSlots(hashes int, last common.Hash) (uint64, error) { return space.Uint64() - uint64(hashes), nil } -// HandleStorageRequestResult handles get storages result -func (s *FullStateDownloadManager) HandleStorageRequestResult(mainTask *accountTask, // Task which this response belongs to - subTask *storageTask, // Task which this response is filling - accounts []common.Hash, // Account hashes requested, may be only partially filled - roots []common.Hash, // Storage roots requested, may be only partially filled - hashes [][]common.Hash, // Storage slot hashes in the returned range - storageSlots [][][]byte, // Storage slot values in the returned range - cont bool, // Whether the last storage range has a continuation +// Unpack retrieves the storage slots from the range packet and returns them in +// a split flat format that's more consistent with the internal data structures. +func (s *FullStateDownloadManager) UnpackStorages(slots [][]*message.StorageData) ([][]common.Hash, [][][]byte) { + var ( + hashset = make([][]common.Hash, len(slots)) + slotset = make([][][]byte, len(slots)) + ) + for i, slots := range slots { + hashset[i] = make([]common.Hash, len(slots)) + slotset[i] = make([][]byte, len(slots)) + for j, slot := range slots { + hashset[i][j] = common.BytesToHash(slot.Hash) + slotset[i][j] = slot.Body + } + } + return hashset, slotset +} + +// HandleStorageRequestResult handles get storages result when ranges of storage slots +// are received from a remote peer. 
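+// Each returned range is verified against the requested storage root via trie.VerifyRangeProof before
+// the slots are handed to processStorageResponse; a hash/slot count mismatch rejects the whole response.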
+func (s *FullStateDownloadManager) HandleStorageRequestResult(mainTask *accountTask, + subTask *storageTask, + reqAccounts []common.Hash, + roots []common.Hash, + origin common.Hash, + limit common.Hash, + receivedSlots [][]*message.StorageData, + proof [][]byte, loopID int, streamID sttypes.StreamID) error { s.lock.Lock() defer s.lock.Unlock() - if err := s.processStorageResponse(mainTask, subTask, accounts, roots, hashes, storageSlots, cont); err != nil { + hashes, slots := s.UnpackStorages(receivedSlots) + + // Gather some trace stats to aid in debugging issues + var ( + hashCount int + slotCount int + size common.StorageSize + ) + for _, hashset := range hashes { + size += common.StorageSize(common.HashLength * len(hashset)) + hashCount += len(hashset) + } + for _, slotset := range slots { + for _, slot := range slotset { + size += common.StorageSize(len(slot)) + } + slotCount += len(slotset) + } + for _, node := range proof { + size += common.StorageSize(len(node)) + } + + utils.Logger().Trace(). + Int("accounts", len(hashes)). + Int("hashes", hashCount). + Int("slots", slotCount). + Int("proofs", len(proof)). + Interface("size", size). + Msg("Delivering ranges of storage slots") + + s.lock.Lock() + defer s.lock.Unlock() + + // Reject the response if the hash sets and slot sets don't match, or if the + // peer sent more data than requested. + if len(hashes) != len(slots) { + utils.Logger().Warn(). + Int("hashset", len(hashes)). + Int("slotset", len(slots)). + Msg("Hash and slot set size mismatch") + return errors.New("hash and slot set size mismatch") + } + if len(hashes) > len(reqAccounts) { + utils.Logger().Warn(). + Int("hashset", len(hashes)). + Int("requested", len(reqAccounts)). + Msg("Hash set larger than requested") + return errors.New("hash set larger than requested") + } + // Response is valid, but check if peer is signalling that it does not have + // the requested data. For storage range queries that means the state being + // retrieved was either already pruned remotely, or the peer is not yet + // synced to our head. + if len(hashes) == 0 && len(proof) == 0 { + utils.Logger().Debug().Msg("Peer rejected storage request") + return nil + } + + // Reconstruct the partial tries from the response and verify them + var cont bool + + // If a proof was attached while the response is empty, it indicates that the + // requested range specified with 'origin' is empty. Construct an empty state + // response locally to finalize the range. + if len(hashes) == 0 && len(proof) > 0 { + hashes = append(hashes, []common.Hash{}) + slots = append(slots, [][]byte{}) + } + for i := 0; i < len(hashes); i++ { + // Convert the keys and proofs into an internal format + keys := make([][]byte, len(hashes[i])) + for j, key := range hashes[i] { + keys[j] = common.CopyBytes(key[:]) + } + nodes := make(ProofList, 0, len(proof)) + if i == len(hashes)-1 { + for _, node := range proof { + nodes = append(nodes, node) + } + } + var err error + if len(nodes) == 0 { + // No proof has been attached, the response must cover the entire key + // space and hash to the origin root. 
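+			// The continuation flag is discarded here because a proof-less response must already cover the complete range.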
+ _, err = trie.VerifyRangeProof(roots[i], nil, nil, keys, slots[i], nil) + if err != nil { + utils.Logger().Warn().Err(err).Msg("Storage slots failed proof") + return err + } + } else { + // A proof was attached, the response is only partial, check that the + // returned data is indeed part of the storage trie + proofdb := nodes.Set() + + cont, err = trie.VerifyRangeProof(roots[i], origin[:], limit[:], keys, slots[i], proofdb) + if err != nil { + utils.Logger().Warn().Err(err).Msg("Storage range failed proof") + return err + } + } + } + + if err := s.processStorageResponse(mainTask, subTask, reqAccounts, roots, hashes, slots, cont); err != nil { return err } @@ -1835,18 +2110,72 @@ func (s *FullStateDownloadManager) processStorageResponse(mainTask *accountTask, return nil } -// HandleTrieNodeHealRequestResult handles get trie nodes heal result +// HandleTrieNodeHealRequestResult handles get trie nodes heal result when a batch of trie nodes +// are received from a remote peer. func (s *FullStateDownloadManager) HandleTrieNodeHealRequestResult(task *healTask, // Task which this request is filling - paths []string, // Paths of the trie nodes - hashes []common.Hash, // Hashes of the trie nodes to avoid double hashing - nodes [][]byte, // Actual trie nodes to store into the database (nil = missing) + reqPaths []string, + reqHashes []common.Hash, + trienodes [][]byte, loopID int, streamID sttypes.StreamID) error { s.lock.Lock() defer s.lock.Unlock() - if err := s.processTrienodeHealResponse(task, paths, hashes, nodes); err != nil { + var size common.StorageSize + for _, node := range trienodes { + size += common.StorageSize(len(node)) + } + + utils.Logger().Trace(). + Int("trienodes", len(trienodes)). + Interface("bytes", size). + Msg("Delivering set of healing trienodes") + + // Response is valid, but check if peer is signalling that it does not have + // the requested data. For bytecode range queries that means the peer is not + // yet synced. + if len(trienodes) == 0 { + utils.Logger().Debug().Msg("Peer rejected trienode heal request") + return nil + } + + // Cross reference the requested trienodes with the response to find gaps + // that the serving node is missing + var ( + hasher = sha3.NewLegacyKeccak256().(crypto.KeccakState) + hash = make([]byte, 32) + nodes = make([][]byte, len(reqHashes)) + fills uint64 + ) + for i, j := 0, 0; i < len(trienodes); i++ { + // Find the next hash that we've been served, leaving misses with nils + hasher.Reset() + hasher.Write(trienodes[i]) + hasher.Read(hash) + + for j < len(reqHashes) && !bytes.Equal(hash, reqHashes[j][:]) { + j++ + } + if j < len(reqHashes) { + nodes[j] = trienodes[i] + fills++ + j++ + continue + } + // We've either ran out of hashes, or got unrequested data + utils.Logger().Warn().Int("count", len(trienodes)-i).Msg("Unexpected healing trienodes") + + // Signal this request as failed, and ready for rescheduling + return errors.New("unexpected healing trienode") + } + // Response validated, send it to the scheduler for filling + s.trienodeHealPend.Add(fills) + defer func() { + s.trienodeHealPend.Add(^(fills - 1)) + }() + + if err := s.processTrienodeHealResponse(task, reqPaths, reqHashes, nodes); err != nil { return err } @@ -1959,6 +2288,67 @@ func (s *FullStateDownloadManager) HandleByteCodeHealRequestResult(task *healTas return nil } +// onHealByteCodes is a callback method to invoke when a batch of contract +// bytes codes are received from a remote peer in the healing phase. 
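+// Delivered bytecodes are re-hashed and matched against the requested hashes; any unrequested code
+// fails the whole response so the request can be rescheduled.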
+func (s *FullStateDownloadManager) onHealByteCodes(task *healTask, + reqHashes []common.Hash, + bytecodes [][]byte) error { + + var size common.StorageSize + for _, code := range bytecodes { + size += common.StorageSize(len(code)) + } + + utils.Logger().Trace(). + Int("bytecodes", len(bytecodes)). + Interface("bytes", size). + Msg("Delivering set of healing bytecodes") + + s.lock.Lock() + s.lock.Unlock() + + // Response is valid, but check if peer is signalling that it does not have + // the requested data. For bytecode range queries that means the peer is not + // yet synced. + if len(bytecodes) == 0 { + utils.Logger().Debug().Msg("Peer rejected bytecode heal request") + return nil + } + + // Cross reference the requested bytecodes with the response to find gaps + // that the serving node is missing + hasher := sha3.NewLegacyKeccak256().(crypto.KeccakState) + hash := make([]byte, 32) + + codes := make([][]byte, len(reqHashes)) + for i, j := 0, 0; i < len(bytecodes); i++ { + // Find the next hash that we've been served, leaving misses with nils + hasher.Reset() + hasher.Write(bytecodes[i]) + hasher.Read(hash) + + for j < len(reqHashes) && !bytes.Equal(hash, reqHashes[j][:]) { + j++ + } + if j < len(reqHashes) { + codes[j] = bytecodes[i] + j++ + continue + } + // We've either ran out of hashes, or got unrequested data + utils.Logger().Warn().Int("count", len(bytecodes)-i).Msg("Unexpected healing bytecodes") + + // Signal this request as failed, and ready for rescheduling + return errors.New("unexpected healing bytecode") + } + + if err := s.processBytecodeHealResponse(task, reqHashes, codes); err != nil { + return err + } + + return nil +} + // processBytecodeHealResponse integrates an already validated bytecode response // into the healer tasks. func (s *FullStateDownloadManager) processBytecodeHealResponse(task *healTask, // Task which this request is filling @@ -1992,3 +2382,30 @@ func (s *FullStateDownloadManager) processBytecodeHealResponse(task *healTask, / return nil } + +// onHealState is a callback method to invoke when a flat state(account +// or storage slot) is downloaded during the healing stage. The flat states +// can be persisted blindly and can be fixed later in the generation stage. +// Note it's not concurrent safe, please handle the concurrent issue outside. 
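+// Healed accounts and storage slots are written to the stateWriter batch and flushed once it grows
+// beyond ethdb.IdealBatchSize.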
+func (s *FullStateDownloadManager) onHealState(paths [][]byte, value []byte) error { + if len(paths) == 1 { + var account types.StateAccount + if err := rlp.DecodeBytes(value, &account); err != nil { + return nil // Returning the error here would drop the remote peer + } + blob := s.SlimAccountRLP(account) + rawdb.WriteAccountSnapshot(s.stateWriter, common.BytesToHash(paths[0]), blob) + s.accountHealed += 1 + s.accountHealedBytes += common.StorageSize(1 + common.HashLength + len(blob)) + } + if len(paths) == 2 { + rawdb.WriteStorageSnapshot(s.stateWriter, common.BytesToHash(paths[0]), common.BytesToHash(paths[1]), value) + s.storageHealed += 1 + s.storageHealedBytes += common.StorageSize(1 + 2*common.HashLength + len(value)) + } + if s.stateWriter.ValueSize() > ethdb.IdealBatchSize { + s.stateWriter.Write() // It's fine to ignore the error here + s.stateWriter.Reset() + } + return nil +} diff --git a/api/service/stagedstreamsync/syncing.go b/api/service/stagedstreamsync/syncing.go index 73f050080..e6879a523 100644 --- a/api/service/stagedstreamsync/syncing.go +++ b/api/service/stagedstreamsync/syncing.go @@ -367,7 +367,7 @@ func (s *StagedStreamSync) doSync(downloaderContext context.Context, initSync bo } // add consensus last mile blocks - if s.consensus != nil { + if s.consensus != nil && s.isBeaconNode { if hashes, err := s.addConsensusLastMile(s.Blockchain(), s.consensus); err != nil { utils.Logger().Error().Err(err). Msg("[STAGED_STREAM_SYNC] Add consensus last mile failed") diff --git a/p2p/stream/protocols/sync/chain.go b/p2p/stream/protocols/sync/chain.go index aa4dced3f..3c147c91a 100644 --- a/p2p/stream/protocols/sync/chain.go +++ b/p2p/stream/protocols/sync/chain.go @@ -199,7 +199,7 @@ func (ch *chainHelperImpl) getReceipts(hs []common.Hash) ([]types.Receipts, erro return receipts, nil } -// getAccountRangeRequest +// getAccountRange func (ch *chainHelperImpl) getAccountRange(root common.Hash, origin common.Hash, limit common.Hash, bytes uint64) ([]*message.AccountData, [][]byte, error) { if bytes > softResponseLimit { bytes = softResponseLimit