[stream] added a length bytes to the start of p2p base stream (#3552)

Co-authored-by: Rongjian Lan <rongjian.lan@gmail.com>
pull/3565/head
Jacky Wang 4 years ago committed by GitHub
parent 838d14ec47
commit e12698cf44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 72
      p2p/stream/types/stream.go

@ -1,10 +1,14 @@
package sttypes package sttypes
import ( import (
"io/ioutil" "bufio"
"encoding/binary"
"fmt"
"io"
"sync" "sync"
libp2p_network "github.com/libp2p/go-libp2p-core/network" libp2p_network "github.com/libp2p/go-libp2p-core/network"
"github.com/pkg/errors"
) )
// Stream is the interface for streams implemented in each service. // Stream is the interface for streams implemented in each service.
@ -15,12 +19,14 @@ type Stream interface {
ProtoSpec() (ProtoSpec, error) ProtoSpec() (ProtoSpec, error)
WriteBytes([]byte) error WriteBytes([]byte) error
ReadBytes() ([]byte, error) ReadBytes() ([]byte, error)
Close() error // Make sure streams can handle multiple calls of Close Close() error
ResetOnClose() error
} }
// BaseStream is the wrapper around // BaseStream is the wrapper around
type BaseStream struct { type BaseStream struct {
raw libp2p_network.Stream raw libp2p_network.Stream
rw *bufio.ReadWriter
// parse protocol spec fields // parse protocol spec fields
spec ProtoSpec spec ProtoSpec
@ -30,8 +36,10 @@ type BaseStream struct {
// NewBaseStream creates BaseStream as the wrapper of libp2p Stream // NewBaseStream creates BaseStream as the wrapper of libp2p Stream
func NewBaseStream(st libp2p_network.Stream) *BaseStream { func NewBaseStream(st libp2p_network.Stream) *BaseStream {
rw := bufio.NewReadWriter(bufio.NewReader(st), bufio.NewWriter(st))
return &BaseStream{ return &BaseStream{
raw: st, raw: st,
rw: rw,
} }
} }
@ -41,7 +49,7 @@ type StreamID string
// Meta return the StreamID of the stream // Meta return the StreamID of the stream
func (st *BaseStream) ID() StreamID { func (st *BaseStream) ID() StreamID {
return StreamID(st.raw.ID()) return StreamID(st.raw.Conn().ID())
} }
// ProtoID return the remote protocol ID of the stream // ProtoID return the remote protocol ID of the stream
@ -59,20 +67,64 @@ func (st *BaseStream) ProtoSpec() (ProtoSpec, error) {
// Close close the stream on both sides. // Close close the stream on both sides.
func (st *BaseStream) Close() error { func (st *BaseStream) Close() error {
return st.raw.Reset() return st.raw.Close()
} }
// WriteBytes write the bytes to the stream const (
maxMsgBytes = 20 * 1024 * 1024 // 20MB
sizeBytes = 4 // uint32
)
// WriteBytes write the bytes to the stream.
// First 4 bytes is used as the size bytes, and the rest is the content
func (st *BaseStream) WriteBytes(b []byte) error { func (st *BaseStream) WriteBytes(b []byte) error {
_, err := st.raw.Write(b) if len(b) > maxMsgBytes {
return err return errors.New("message too long")
}
if _, err := st.rw.Write(intToBytes(len(b))); err != nil {
return errors.Wrap(err, "write size bytes")
}
if _, err := st.rw.Write(b); err != nil {
return errors.Wrap(err, "write content")
}
return st.rw.Flush()
} }
// ReadMsg read the bytes from the stream // ReadMsg read the bytes from the stream
func (st *BaseStream) ReadBytes() ([]byte, error) { func (st *BaseStream) ReadBytes() ([]byte, error) {
b, err := ioutil.ReadAll(st.raw) sb := make([]byte, sizeBytes)
_, err := st.rw.Read(sb)
if err != nil {
return nil, errors.Wrap(err, "read size")
}
size := bytesToInt(sb)
if size > maxMsgBytes {
return nil, fmt.Errorf("message size exceed max: %v > %v", size, maxMsgBytes)
}
cb := make([]byte, size)
n, err := io.ReadFull(st.rw, cb)
if err != nil { if err != nil {
return nil, err return nil, errors.Wrap(err, "read content")
}
if n != size {
return nil, errors.New("ReadBytes sanity failed: byte size")
} }
return b, nil return cb, nil
}
// ResetOnClose reset the stream during the shutdown of the node
func (st *BaseStream) ResetOnClose() error {
return st.raw.Reset()
}
func intToBytes(val int) []byte {
b := make([]byte, sizeBytes) // uint32
binary.LittleEndian.PutUint32(b, uint32(val))
return b
}
func bytesToInt(b []byte) int {
val := binary.LittleEndian.Uint32(b)
return int(val)
} }

Loading…
Cancel
Save