[stream] added a length bytes to the start of p2p base stream (#3552)

Co-authored-by: Rongjian Lan <rongjian.lan@gmail.com>
pull/3565/head
Jacky Wang 4 years ago committed by GitHub
parent 838d14ec47
commit e12698cf44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 72
      p2p/stream/types/stream.go

@ -1,10 +1,14 @@
package sttypes
import (
"io/ioutil"
"bufio"
"encoding/binary"
"fmt"
"io"
"sync"
libp2p_network "github.com/libp2p/go-libp2p-core/network"
"github.com/pkg/errors"
)
// Stream is the interface for streams implemented in each service.
@ -15,12 +19,14 @@ type Stream interface {
ProtoSpec() (ProtoSpec, error)
WriteBytes([]byte) error
ReadBytes() ([]byte, error)
Close() error // Make sure streams can handle multiple calls of Close
Close() error
ResetOnClose() error
}
// BaseStream is the wrapper around
type BaseStream struct {
raw libp2p_network.Stream
rw *bufio.ReadWriter
// parse protocol spec fields
spec ProtoSpec
@ -30,8 +36,10 @@ type BaseStream struct {
// NewBaseStream creates BaseStream as the wrapper of libp2p Stream
func NewBaseStream(st libp2p_network.Stream) *BaseStream {
rw := bufio.NewReadWriter(bufio.NewReader(st), bufio.NewWriter(st))
return &BaseStream{
raw: st,
rw: rw,
}
}
@ -41,7 +49,7 @@ type StreamID string
// Meta return the StreamID of the stream
func (st *BaseStream) ID() StreamID {
return StreamID(st.raw.ID())
return StreamID(st.raw.Conn().ID())
}
// ProtoID return the remote protocol ID of the stream
@ -59,20 +67,64 @@ func (st *BaseStream) ProtoSpec() (ProtoSpec, error) {
// Close close the stream on both sides.
func (st *BaseStream) Close() error {
return st.raw.Reset()
return st.raw.Close()
}
// WriteBytes write the bytes to the stream
const (
maxMsgBytes = 20 * 1024 * 1024 // 20MB
sizeBytes = 4 // uint32
)
// WriteBytes write the bytes to the stream.
// First 4 bytes is used as the size bytes, and the rest is the content
func (st *BaseStream) WriteBytes(b []byte) error {
_, err := st.raw.Write(b)
return err
if len(b) > maxMsgBytes {
return errors.New("message too long")
}
if _, err := st.rw.Write(intToBytes(len(b))); err != nil {
return errors.Wrap(err, "write size bytes")
}
if _, err := st.rw.Write(b); err != nil {
return errors.Wrap(err, "write content")
}
return st.rw.Flush()
}
// ReadMsg read the bytes from the stream
func (st *BaseStream) ReadBytes() ([]byte, error) {
b, err := ioutil.ReadAll(st.raw)
sb := make([]byte, sizeBytes)
_, err := st.rw.Read(sb)
if err != nil {
return nil, errors.Wrap(err, "read size")
}
size := bytesToInt(sb)
if size > maxMsgBytes {
return nil, fmt.Errorf("message size exceed max: %v > %v", size, maxMsgBytes)
}
cb := make([]byte, size)
n, err := io.ReadFull(st.rw, cb)
if err != nil {
return nil, err
return nil, errors.Wrap(err, "read content")
}
if n != size {
return nil, errors.New("ReadBytes sanity failed: byte size")
}
return b, nil
return cb, nil
}
// ResetOnClose reset the stream during the shutdown of the node
func (st *BaseStream) ResetOnClose() error {
return st.raw.Reset()
}
func intToBytes(val int) []byte {
b := make([]byte, sizeBytes) // uint32
binary.LittleEndian.PutUint32(b, uint32(val))
return b
}
func bytesToInt(b []byte) int {
val := binary.LittleEndian.Uint32(b)
return int(val)
}

Loading…
Cancel
Save