You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
324 lines
8.9 KiB
324 lines
8.9 KiB
4 years ago
|
package downloader
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"math/big"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/harmony-one/harmony/consensus/engine"
|
||
|
staking "github.com/harmony-one/harmony/staking/types"
|
||
|
|
||
|
"github.com/harmony-one/harmony/block"
|
||
|
|
||
|
"github.com/ethereum/go-ethereum/common"
|
||
|
"github.com/ethereum/go-ethereum/event"
|
||
|
|
||
|
"github.com/harmony-one/harmony/core/types"
|
||
|
"github.com/harmony-one/harmony/internal/params"
|
||
|
"github.com/harmony-one/harmony/p2p/stream/common/streammanager"
|
||
|
syncproto "github.com/harmony-one/harmony/p2p/stream/protocols/sync"
|
||
|
sttypes "github.com/harmony-one/harmony/p2p/stream/types"
|
||
|
"github.com/harmony-one/harmony/shard"
|
||
|
)
|
||
|
|
||
|
type testBlockChain struct {
|
||
|
curBN uint64
|
||
|
insertErrHook func(bn uint64) error
|
||
|
lock sync.Mutex
|
||
|
}
|
||
|
|
||
|
func newTestBlockChain(curBN uint64, insertErrHook func(bn uint64) error) *testBlockChain {
|
||
|
return &testBlockChain{
|
||
|
curBN: curBN,
|
||
|
insertErrHook: insertErrHook,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (bc *testBlockChain) CurrentBlock() *types.Block {
|
||
|
bc.lock.Lock()
|
||
|
defer bc.lock.Unlock()
|
||
|
|
||
|
return makeTestBlock(bc.curBN)
|
||
|
}
|
||
|
|
||
|
func (bc *testBlockChain) CurrentHeader() *block.Header {
|
||
|
bc.lock.Lock()
|
||
|
defer bc.lock.Unlock()
|
||
|
|
||
|
return makeTestBlock(bc.curBN).Header()
|
||
|
}
|
||
|
|
||
|
func (bc *testBlockChain) currentBlockNumber() uint64 {
|
||
|
bc.lock.Lock()
|
||
|
defer bc.lock.Unlock()
|
||
|
|
||
|
return bc.curBN
|
||
|
}
|
||
|
|
||
|
func (bc *testBlockChain) InsertChain(chain types.Blocks, verifyHeaders bool) (int, error) {
|
||
|
bc.lock.Lock()
|
||
|
defer bc.lock.Unlock()
|
||
|
|
||
|
for i, block := range chain {
|
||
|
if bc.insertErrHook != nil {
|
||
|
if err := bc.insertErrHook(block.NumberU64()); err != nil {
|
||
|
return i, err
|
||
|
}
|
||
|
}
|
||
|
if block.NumberU64() <= bc.curBN {
|
||
|
continue
|
||
|
}
|
||
|
if block.NumberU64() != bc.curBN+1 {
|
||
|
return i, fmt.Errorf("not expected block number: %v / %v", block.NumberU64(), bc.curBN+1)
|
||
|
}
|
||
|
bc.curBN++
|
||
|
}
|
||
|
return len(chain), nil
|
||
|
}
|
||
|
|
||
|
func (bc *testBlockChain) changeBlockNumber(val uint64) {
|
||
|
bc.lock.Lock()
|
||
|
defer bc.lock.Unlock()
|
||
|
|
||
|
bc.curBN = val
|
||
|
}
|
||
|
|
||
|
func (bc *testBlockChain) ShardID() uint32 { return 0 }
|
||
|
func (bc *testBlockChain) ReadShardState(epoch *big.Int) (*shard.State, error) { return nil, nil }
|
||
|
func (bc *testBlockChain) Config() *params.ChainConfig { return nil }
|
||
|
func (bc *testBlockChain) WriteCommitSig(blockNum uint64, lastCommits []byte) error { return nil }
|
||
|
func (bc *testBlockChain) GetHeader(hash common.Hash, number uint64) *block.Header { return nil }
|
||
|
func (bc *testBlockChain) GetHeaderByNumber(number uint64) *block.Header { return nil }
|
||
|
func (bc *testBlockChain) GetHeaderByHash(hash common.Hash) *block.Header { return nil }
|
||
|
func (bc *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block { return nil }
|
||
|
func (bc *testBlockChain) ReadValidatorList() ([]common.Address, error) { return nil, nil }
|
||
|
func (bc *testBlockChain) ReadCommitSig(blockNum uint64) ([]byte, error) { return nil, nil }
|
||
|
func (bc *testBlockChain) ReadBlockRewardAccumulator(uint64) (*big.Int, error) { return nil, nil }
|
||
|
func (bc *testBlockChain) ValidatorCandidates() []common.Address { return nil }
|
||
|
func (bc *testBlockChain) Engine() engine.Engine { return nil }
|
||
|
func (bc *testBlockChain) ReadValidatorInformation(addr common.Address) (*staking.ValidatorWrapper, error) {
|
||
|
return nil, nil
|
||
|
}
|
||
|
func (bc *testBlockChain) ReadValidatorSnapshot(addr common.Address) (*staking.ValidatorSnapshot, error) {
|
||
|
return nil, nil
|
||
|
}
|
||
|
func (bc *testBlockChain) ReadValidatorSnapshotAtEpoch(epoch *big.Int, addr common.Address) (*staking.ValidatorSnapshot, error) {
|
||
|
return nil, nil
|
||
|
}
|
||
|
func (bc *testBlockChain) ReadValidatorStats(addr common.Address) (*staking.ValidatorStats, error) {
|
||
|
return nil, nil
|
||
|
}
|
||
|
func (bc *testBlockChain) SuperCommitteeForNextEpoch(beacon engine.ChainReader, header *block.Header, isVerify bool) (*shard.State, error) {
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
type testInsertHelper struct {
|
||
|
bc *testBlockChain
|
||
|
}
|
||
|
|
||
|
func (ch *testInsertHelper) verifyAndInsertBlock(block *types.Block) error {
|
||
|
_, err := ch.bc.InsertChain(types.Blocks{block}, true)
|
||
|
return err
|
||
|
}
|
||
|
func (ch *testInsertHelper) verifyAndInsertBlocks(blocks types.Blocks) (int, error) {
|
||
|
return ch.bc.InsertChain(blocks, true)
|
||
|
}
|
||
|
|
||
|
const (
|
||
|
initStreamNum = 32
|
||
|
minStreamNum = 16
|
||
|
)
|
||
|
|
||
|
type testSyncProtocol struct {
|
||
|
streamIDs []sttypes.StreamID
|
||
|
remoteChain *testBlockChain
|
||
|
requestErrHook func(uint64) error
|
||
|
|
||
|
curIndex int
|
||
|
numStreams int
|
||
|
lock sync.Mutex
|
||
|
}
|
||
|
|
||
|
func newTestSyncProtocol(targetBN uint64, numStreams int, requestErrHook func(uint64) error) *testSyncProtocol {
|
||
|
return &testSyncProtocol{
|
||
|
streamIDs: makeStreamIDs(numStreams),
|
||
|
remoteChain: newTestBlockChain(targetBN, nil),
|
||
|
requestErrHook: requestErrHook,
|
||
|
curIndex: 0,
|
||
|
numStreams: numStreams,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (sp *testSyncProtocol) GetCurrentBlockNumber(ctx context.Context, opts ...syncproto.Option) (uint64, sttypes.StreamID, error) {
|
||
|
sp.lock.Lock()
|
||
|
defer sp.lock.Unlock()
|
||
|
|
||
|
bn := sp.remoteChain.currentBlockNumber()
|
||
|
|
||
|
return bn, sp.nextStreamID(), nil
|
||
|
}
|
||
|
|
||
|
func (sp *testSyncProtocol) GetBlocksByNumber(ctx context.Context, bns []uint64, opts ...syncproto.Option) ([]*types.Block, sttypes.StreamID, error) {
|
||
|
sp.lock.Lock()
|
||
|
defer sp.lock.Unlock()
|
||
|
|
||
|
res := make([]*types.Block, 0, len(bns))
|
||
|
for _, bn := range bns {
|
||
|
if sp.requestErrHook != nil {
|
||
|
if err := sp.requestErrHook(bn); err != nil {
|
||
|
return nil, sp.nextStreamID(), err
|
||
|
}
|
||
|
}
|
||
|
if bn > sp.remoteChain.currentBlockNumber() {
|
||
|
res = append(res, nil)
|
||
|
} else {
|
||
|
res = append(res, makeTestBlock(bn))
|
||
|
}
|
||
|
}
|
||
|
return res, sp.nextStreamID(), nil
|
||
|
}
|
||
|
|
||
|
func (sp *testSyncProtocol) GetBlockHashes(ctx context.Context, bns []uint64, opts ...syncproto.Option) ([]common.Hash, sttypes.StreamID, error) {
|
||
|
sp.lock.Lock()
|
||
|
defer sp.lock.Unlock()
|
||
|
|
||
|
res := make([]common.Hash, 0, len(bns))
|
||
|
for _, bn := range bns {
|
||
|
if sp.requestErrHook != nil {
|
||
|
if err := sp.requestErrHook(bn); err != nil {
|
||
|
return nil, sp.nextStreamID(), err
|
||
|
}
|
||
|
}
|
||
|
if bn > sp.remoteChain.currentBlockNumber() {
|
||
|
res = append(res, emptyHash)
|
||
|
} else {
|
||
|
res = append(res, makeTestBlockHash(bn))
|
||
|
}
|
||
|
}
|
||
|
return res, sp.nextStreamID(), nil
|
||
|
}
|
||
|
|
||
|
func (sp *testSyncProtocol) GetBlocksByHashes(ctx context.Context, hs []common.Hash, opts ...syncproto.Option) ([]*types.Block, sttypes.StreamID, error) {
|
||
|
sp.lock.Lock()
|
||
|
defer sp.lock.Unlock()
|
||
|
|
||
|
res := make([]*types.Block, 0, len(hs))
|
||
|
for _, h := range hs {
|
||
|
bn := testHashToNumber(h)
|
||
|
if sp.requestErrHook != nil {
|
||
|
if err := sp.requestErrHook(bn); err != nil {
|
||
|
return nil, sp.nextStreamID(), err
|
||
|
}
|
||
|
}
|
||
|
if bn > sp.remoteChain.currentBlockNumber() {
|
||
|
res = append(res, nil)
|
||
|
} else {
|
||
|
res = append(res, makeTestBlock(bn))
|
||
|
}
|
||
|
}
|
||
|
return res, sp.nextStreamID(), nil
|
||
|
}
|
||
|
|
||
|
func (sp *testSyncProtocol) RemoveStream(target sttypes.StreamID) {
|
||
|
sp.lock.Lock()
|
||
|
defer sp.lock.Unlock()
|
||
|
|
||
|
for i, stid := range sp.streamIDs {
|
||
|
if stid == target {
|
||
|
if i == len(sp.streamIDs)-1 {
|
||
|
sp.streamIDs = sp.streamIDs[:i]
|
||
|
} else {
|
||
|
sp.streamIDs = append(sp.streamIDs[:i], sp.streamIDs[i+1:]...)
|
||
|
}
|
||
|
// mock discovery
|
||
|
if len(sp.streamIDs) < minStreamNum {
|
||
|
sp.streamIDs = append(sp.streamIDs, makeStreamID(sp.numStreams))
|
||
|
sp.numStreams++
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (sp *testSyncProtocol) NumStreams() int {
|
||
|
sp.lock.Lock()
|
||
|
defer sp.lock.Unlock()
|
||
|
|
||
|
return len(sp.streamIDs)
|
||
|
}
|
||
|
|
||
|
func (sp *testSyncProtocol) SubscribeAddStreamEvent(ch chan<- streammanager.EvtStreamAdded) event.Subscription {
|
||
|
var evtFeed event.Feed
|
||
|
go func() {
|
||
|
sp.lock.Lock()
|
||
|
num := len(sp.streamIDs)
|
||
|
sp.lock.Unlock()
|
||
|
for i := 0; i != num; i++ {
|
||
|
evtFeed.Send(streammanager.EvtStreamAdded{Stream: nil})
|
||
|
}
|
||
|
}()
|
||
|
return evtFeed.Subscribe(ch)
|
||
|
}
|
||
|
|
||
|
// TODO: add with whitelist stuff
|
||
|
func (sp *testSyncProtocol) nextStreamID() sttypes.StreamID {
|
||
|
if sp.curIndex >= len(sp.streamIDs) {
|
||
|
sp.curIndex = 0
|
||
|
}
|
||
|
index := sp.curIndex
|
||
|
sp.curIndex++
|
||
|
if sp.curIndex >= len(sp.streamIDs) {
|
||
|
sp.curIndex = 0
|
||
|
}
|
||
|
return sp.streamIDs[index]
|
||
|
}
|
||
|
|
||
|
func (sp *testSyncProtocol) changeBlockNumber(val uint64) {
|
||
|
sp.remoteChain.changeBlockNumber(val)
|
||
|
}
|
||
|
|
||
|
func makeStreamIDs(size int) []sttypes.StreamID {
|
||
|
res := make([]sttypes.StreamID, 0, size)
|
||
|
for i := 0; i != size; i++ {
|
||
|
res = append(res, makeStreamID(i))
|
||
|
}
|
||
|
return res
|
||
|
}
|
||
|
|
||
|
func makeStreamID(index int) sttypes.StreamID {
|
||
|
return sttypes.StreamID(fmt.Sprintf("test stream %v", index))
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
hashNumberMap = map[common.Hash]uint64{}
|
||
|
computed uint64
|
||
|
hashNumberLock sync.Mutex
|
||
|
)
|
||
|
|
||
|
func testHashToNumber(h common.Hash) uint64 {
|
||
|
hashNumberLock.Lock()
|
||
|
defer hashNumberLock.Unlock()
|
||
|
|
||
|
if h == emptyHash {
|
||
|
panic("not allowed")
|
||
|
}
|
||
|
if bn, ok := hashNumberMap[h]; ok {
|
||
|
return bn
|
||
|
}
|
||
|
for ; ; computed++ {
|
||
|
ch := makeTestBlockHash(computed)
|
||
|
hashNumberMap[ch] = computed
|
||
|
if ch == h {
|
||
|
return computed
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func testNumberToHashes(nums []uint64) []common.Hash {
|
||
|
hashes := make([]common.Hash, 0, len(nums))
|
||
|
for _, num := range nums {
|
||
|
hashes = append(hashes, makeTestBlockHash(num))
|
||
|
}
|
||
|
return hashes
|
||
|
}
|