diff --git a/p2p/stream/types/stream.go b/p2p/stream/types/stream.go index 985b64e00..492a77166 100644 --- a/p2p/stream/types/stream.go +++ b/p2p/stream/types/stream.go @@ -1,10 +1,14 @@ package sttypes import ( - "io/ioutil" + "bufio" + "encoding/binary" + "fmt" + "io" "sync" libp2p_network "github.com/libp2p/go-libp2p-core/network" + "github.com/pkg/errors" ) // Stream is the interface for streams implemented in each service. @@ -15,12 +19,14 @@ type Stream interface { ProtoSpec() (ProtoSpec, error) WriteBytes([]byte) error ReadBytes() ([]byte, error) - Close() error // Make sure streams can handle multiple calls of Close + Close() error + ResetOnClose() error } // BaseStream is the wrapper around type BaseStream struct { raw libp2p_network.Stream + rw *bufio.ReadWriter // parse protocol spec fields spec ProtoSpec @@ -30,8 +36,10 @@ type BaseStream struct { // NewBaseStream creates BaseStream as the wrapper of libp2p Stream func NewBaseStream(st libp2p_network.Stream) *BaseStream { + rw := bufio.NewReadWriter(bufio.NewReader(st), bufio.NewWriter(st)) return &BaseStream{ raw: st, + rw: rw, } } @@ -41,7 +49,7 @@ type StreamID string // Meta return the StreamID of the stream func (st *BaseStream) ID() StreamID { - return StreamID(st.raw.ID()) + return StreamID(st.raw.Conn().ID()) } // ProtoID return the remote protocol ID of the stream @@ -59,20 +67,64 @@ func (st *BaseStream) ProtoSpec() (ProtoSpec, error) { // Close close the stream on both sides. func (st *BaseStream) Close() error { - return st.raw.Reset() + return st.raw.Close() } -// WriteBytes write the bytes to the stream +const ( + maxMsgBytes = 20 * 1024 * 1024 // 20MB + sizeBytes = 4 // uint32 +) + +// WriteBytes write the bytes to the stream. +// First 4 bytes is used as the size bytes, and the rest is the content func (st *BaseStream) WriteBytes(b []byte) error { - _, err := st.raw.Write(b) - return err + if len(b) > maxMsgBytes { + return errors.New("message too long") + } + if _, err := st.rw.Write(intToBytes(len(b))); err != nil { + return errors.Wrap(err, "write size bytes") + } + if _, err := st.rw.Write(b); err != nil { + return errors.Wrap(err, "write content") + } + return st.rw.Flush() } // ReadMsg read the bytes from the stream func (st *BaseStream) ReadBytes() ([]byte, error) { - b, err := ioutil.ReadAll(st.raw) + sb := make([]byte, sizeBytes) + _, err := st.rw.Read(sb) + if err != nil { + return nil, errors.Wrap(err, "read size") + } + size := bytesToInt(sb) + if size > maxMsgBytes { + return nil, fmt.Errorf("message size exceed max: %v > %v", size, maxMsgBytes) + } + + cb := make([]byte, size) + n, err := io.ReadFull(st.rw, cb) if err != nil { - return nil, err + return nil, errors.Wrap(err, "read content") + } + if n != size { + return nil, errors.New("ReadBytes sanity failed: byte size") } - return b, nil + return cb, nil +} + +// ResetOnClose reset the stream during the shutdown of the node +func (st *BaseStream) ResetOnClose() error { + return st.raw.Reset() +} + +func intToBytes(val int) []byte { + b := make([]byte, sizeBytes) // uint32 + binary.LittleEndian.PutUint32(b, uint32(val)) + return b +} + +func bytesToInt(b []byte) int { + val := binary.LittleEndian.Uint32(b) + return int(val) }