diff --git a/network/buffer.go b/network/buffer.go index 97350a5..defb791 100644 --- a/network/buffer.go +++ b/network/buffer.go @@ -17,6 +17,9 @@ import ( // size. var ErrNotEnoughSpace = errors.New("not enough space") +// ErrUnknownType is returned when attempting to write values of an unknown type +var ErrUnknownType = errors.New("unknown type") + // Buffer contains the binary representation of multiple ValueLists and state // optimally write the next ValueList. type Buffer struct { @@ -248,7 +251,7 @@ func (b *Buffer) writeValues(values []api.Value) error { case api.Derive: binary.Write(b.buffer, binary.BigEndian, uint8(dsTypeDerive)) default: - panic("unexpected type") + return ErrUnknownType } } @@ -264,7 +267,7 @@ func (b *Buffer) writeValues(values []api.Value) error { case api.Derive: binary.Write(b.buffer, binary.BigEndian, int64(v)) default: - panic("unexpected type") + return ErrUnknownType } } diff --git a/network/buffer_test.go b/network/buffer_test.go index afb5856..5ddaae7 100644 --- a/network/buffer_test.go +++ b/network/buffer_test.go @@ -143,3 +143,16 @@ func TestWriteInt(t *testing.T) { t.Errorf("got %v, want %v", got, want) } } + +func TestUnknownType(t *testing.T) { + vl, err := Parse([]byte{0x00, 0x06, 0x00, 0x0f, 0x00, 0x01, 0x00, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, ParseOpts{}) + if err != nil { + t.Errorf("Error parsing input %v", err) + } + + s1 := NewBuffer(0) + if err := s1.Write(vl[0]); err == nil { + t.Errorf("Writing bad stream should return an error") + } + +} diff --git a/network/parse.go b/network/parse.go index 48fb184..48ea6a3 100644 --- a/network/parse.go +++ b/network/parse.go @@ -34,6 +34,14 @@ func Parse(b []byte, opts ParseOpts) ([]api.ValueList, error) { return parse(b, None, opts) } +func readUint16(buf *bytes.Buffer) (uint16, error) { + read := buf.Next(2) + if len(read) != 2 { + return 0, ErrInvalid + } + return binary.BigEndian.Uint16(read), nil +} + func parse(b []byte, sl SecurityLevel, opts ParseOpts) ([]api.ValueList, error) { var valueLists []api.ValueList @@ -41,8 +49,15 @@ func parse(b []byte, sl SecurityLevel, opts ParseOpts) ([]api.ValueList, error) buf := bytes.NewBuffer(b) for buf.Len() > 0 { - partType := binary.BigEndian.Uint16(buf.Next(2)) - partLength := int(binary.BigEndian.Uint16(buf.Next(2))) + partType, err := readUint16(buf) + if err != nil { + return nil, ErrInvalid + } + partLengthUnsigned, err := readUint16(buf) + if err != nil { + return nil, ErrInvalid + } + partLength := int(partLengthUnsigned) if partLength < 5 || partLength-4 > buf.Len() { return valueLists, fmt.Errorf("invalid length %d", partLength) diff --git a/network/parse_test.go b/network/parse_test.go index 2ccb167..9032e8d 100644 --- a/network/parse_test.go +++ b/network/parse_test.go @@ -77,3 +77,10 @@ func TestParseString(t *testing.T) { t.Errorf("got (%q, nil), want (\"\", ErrorInvalid)", got) } } + +func TestOneByte(t *testing.T) { + _, err := Parse([]byte{0}, ParseOpts{}) + if err == nil { + t.Errorf("Parsing byte stream containing single zero byte should return an error") + } +}