From df3a5d45184a447b80ec111211c0149c1168d261 Mon Sep 17 00:00:00 2001 From: Marcin Bojanczyk Date: Fri, 4 Apr 2025 20:36:39 -0700 Subject: [PATCH 01/10] feat(parquet): add variant encoder/decoder --- go.mod | 1 + go.sum | 4 +- parquet/variants/array.go | 248 ++++++++++ parquet/variants/array_test.go | 423 +++++++++++++++++ parquet/variants/builder.go | 201 ++++++++ parquet/variants/builder_test.go | 307 +++++++++++++ parquet/variants/decoder.go | 79 ++++ parquet/variants/decoder_test.go | 259 +++++++++++ parquet/variants/doc.go | 38 ++ parquet/variants/metadata.go | 152 ++++++ parquet/variants/metadata_test.go | 282 ++++++++++++ parquet/variants/object.go | 458 +++++++++++++++++++ parquet/variants/object_test.go | 711 +++++++++++++++++++++++++++++ parquet/variants/primitive.go | 678 +++++++++++++++++++++++++++ parquet/variants/primitive_test.go | 606 ++++++++++++++++++++++++ parquet/variants/testutils.go | 30 ++ parquet/variants/util.go | 154 +++++++ parquet/variants/util_test.go | 274 +++++++++++ parquet/variants/variant.go | 53 +++ 19 files changed, 4956 insertions(+), 2 deletions(-) create mode 100644 parquet/variants/array.go create mode 100644 parquet/variants/array_test.go create mode 100644 parquet/variants/builder.go create mode 100644 parquet/variants/builder_test.go create mode 100644 parquet/variants/decoder.go create mode 100644 parquet/variants/decoder_test.go create mode 100644 parquet/variants/doc.go create mode 100644 parquet/variants/metadata.go create mode 100644 parquet/variants/metadata_test.go create mode 100644 parquet/variants/object.go create mode 100644 parquet/variants/object_test.go create mode 100644 parquet/variants/primitive.go create mode 100644 parquet/variants/primitive_test.go create mode 100644 parquet/variants/testutils.go create mode 100644 parquet/variants/util.go create mode 100644 parquet/variants/util_test.go create mode 100644 parquet/variants/variant.go diff --git a/go.mod b/go.mod index bf9c6888..c8995f2d 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/goccy/go-json v0.10.5 github.com/golang/snappy v1.0.0 github.com/google/flatbuffers v25.2.10+incompatible + github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/hamba/avro/v2 v2.28.0 github.com/klauspost/asmfmt v1.3.2 diff --git a/go.sum b/go.sum index 764cbcbf..337578c9 100644 --- a/go.sum +++ b/go.sum @@ -62,8 +62,8 @@ github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= diff --git a/parquet/variants/array.go b/parquet/variants/array.go new file mode 100644 index 00000000..0d4d0eb5 --- /dev/null +++ 
b/parquet/variants/array.go @@ -0,0 +1,248 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variants + +import ( + "bytes" + "fmt" + "io" + "reflect" +) + +type arrayBuilder struct { + w io.Writer + buf bytes.Buffer + numItems int + offsets []int32 + nextOffset int32 + doneCB doneCB + mdb *metadataBuilder + built bool +} + +func newArrayBuilder(w io.Writer, mdb *metadataBuilder, doneCB doneCB) *arrayBuilder { + return &arrayBuilder{ + w: w, + doneCB: doneCB, + mdb: mdb, + } +} + +var _ ArrayBuilder = (*arrayBuilder)(nil) + +// Write marshals the provided value into the appropriate Variant type and appends it to this array. +func (a *arrayBuilder) Write(val any, opts ...MarshalOpts) error { + return writeCommon(val, &a.buf, a.mdb, a.recordOffset) +} + +// Appends all elements from a provided slice into this array +func (a *arrayBuilder) fromSlice(sl any, opts ...MarshalOpts) error { + val := reflect.ValueOf(sl) + if val.Kind() != reflect.Slice && val.Kind() != reflect.Array { + return fmt.Errorf("not a slice: %v", val.Kind()) + } + for i := range val.Len() { + item := val.Index(i) + if err := a.Write(item.Interface(), opts...); err != nil { + return err + } + } + return nil +} + +func (a *arrayBuilder) recordOffset(size int) { + a.numItems++ + a.offsets = append(a.offsets, a.nextOffset) + a.nextOffset += int32(size) +} + +// Array returns a new ArrayBuilder associated with this array. The marshaled array +// will not be part of the array until the returned ArrayBuilder's Build method is called. +func (a *arrayBuilder) Array() ArrayBuilder { + ab := newArrayBuilder(&a.buf, a.mdb, a.recordOffset) + return ab +} + +// Object returns a new ObjectBuilder associated with this array. The marshaled object +// will not be part of the array until the returned ObjectBuilder's Build() method is called. +func (a *arrayBuilder) Object() ObjectBuilder { + ob := newObjectBuilder(&a.buf, a.mdb, a.recordOffset) + return ob +} + +// Build marshals an Array in Variant format and writes its header and value data to the +// underlying Writer. This prepends serialized metadata about the array (ie. its header, number +// of elements, and element offsets) to the runnig data buffer. +func (a *arrayBuilder) Build() error { + if a.built { + return errAlreadyBuilt + } + large := a.numItems > 0xFF + offsetSize := fieldOffsetSize(a.nextOffset) + + // Preallocate a buffer for the header, number of items, and the field offsets. 
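+	// The sections written below follow the Variant array encoding:
+	//
+	//   [header][num_elements][offset_0 ... offset_n][element data]
+	//
+	// where each offset is offsetSize bytes wide and the trailing offset points one
+	// past the last element's data.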
+ numItemsSize := 1 + if large { + numItemsSize = 4 + } + serializedOffsetSize := (a.numItems + 1) * offsetSize + serializedDataBuf := bytes.NewBuffer(make([]byte, 0, 1+serializedOffsetSize)) + + // Write the header and number of elements in the array + serializedDataBuf.WriteByte(a.header(large, offsetSize)) + encodeNumber(int64(a.numItems), numItemsSize, serializedDataBuf) + + // Write all of the field offsets, including the final offset which is the first index after all + // of the array's elements. + for _, o := range a.offsets { + encodeNumber(int64(o), offsetSize, serializedDataBuf) + } + encodeNumber(int64(a.nextOffset), offsetSize, serializedDataBuf) + + // Calculate the size of this entire array in bytes, then call the recordkeeping callback if configured. + hdrSize, _ := a.w.Write(serializedDataBuf.Bytes()) + dataSize, _ := a.w.Write(a.buf.Bytes()) + totalSize := hdrSize + dataSize + + if a.doneCB != nil { + a.doneCB(totalSize) + } + return nil +} + +func (a *arrayBuilder) header(large bool, offsetSize int) byte { + // Header is one byte: AAABCCDD + // * A: Unused + // * B: Is Large: whether there are more than 255 elements in this array or not + // * C: Field Offset Size Minus One: the number of bytes (minus one) used to encode each Field Offset + // * D: 0x03: the identifier of the Array basic type + hdr := byte(offsetSize - 1) + if large { + hdr |= (1 << 2) + } + // Shift the value header over 2 to allow for the lower to bits to + // denote the array basic type + hdr <<= 2 + hdr |= byte(BasicArray) + return hdr +} + +type arrayData struct { + size int + numElements int + firstOffsetIdx int + firstDataIdx int + offsetWidth int +} + +// Parses array data from a marshaled object (where the different encoded sections start, plus size in bytes +// and number of elements). This also ensures that the entire array exists in the raw buffer. +func getArrayData(raw []byte, offset int) (*arrayData, error) { + if err := checkBounds(raw, offset, offset); err != nil { + return nil, err + } + hdr := raw[offset] + if bt := BasicTypeFromHeader(hdr); bt != BasicArray { + return nil, fmt.Errorf("not an array: %s", bt) + } + + // Get the size of all encoded metadata fields. Bitshift by two to expose the 5 raw value header bits. + hdr >>= 2 + + offsetWidth := int(hdr&0x03) + 1 + numElementsWidth := 1 + if hdr&0x2 != 0 { + numElementsWidth = 4 + } + + numElements, err := readUint(raw, offset+1, numElementsWidth) + if err != nil { + return nil, fmt.Errorf("could not get number of elements: %v", err) + } + firstOffsetIdx := offset + 1 + numElementsWidth // Header plus width of # of elements + lastOffsetIdx := firstOffsetIdx + int(numElements)*offsetWidth + firstDataIdx := lastOffsetIdx + offsetWidth + + // Do some bounds checking to ensure that the entire array is present in the raw buffer. + lastDataOffset, err := readUint(raw, lastOffsetIdx, offsetWidth) + if err != nil { + return nil, fmt.Errorf("could not read last offset: %v", err) + } + lastDataIdx := firstDataIdx + int(lastDataOffset) + if err := checkBounds(raw, offset, lastDataIdx); err != nil { + return nil, fmt.Errorf("array is out of bounds: %v", err) + } + + return &arrayData{ + size: lastDataIdx - offset, + numElements: int(numElements), + firstOffsetIdx: firstOffsetIdx, + firstDataIdx: firstDataIdx, + offsetWidth: offsetWidth, + }, nil +} + +// Unmarshals a Variant array into the provided destination. The destination must be a pointer to either +// a slice, or to the "any" type (which is then populated with []any{}). 
Any passed in slice will be +// cleared before unmarshaling. +func unmarshalArray(raw []byte, md *decodedMetadata, offset int, dest reflect.Value) error { + data, err := getArrayData(raw, offset) + if err != nil { + return err + } + + if kind := dest.Kind(); kind != reflect.Pointer { + return fmt.Errorf("invalid dest, must be non-nil pointer (got kind %s)", kind) + } + if dest.IsNil() { + return fmt.Errorf("invalid dest, must be non-nil pointer") + } + + destElem := dest.Elem() + if destElem.Kind() != reflect.Slice && destElem.Kind() != reflect.Interface { + return fmt.Errorf("invalid dest, must be a pointer to a slice (got pointer to %s)", destElem.Kind()) + } + + // Reset the slice. + var ret reflect.Value + if destElem.Kind() == reflect.Slice { + ret = reflect.MakeSlice(destElem.Type(), 0, data.numElements) + } else if destElem.Kind() == reflect.Interface { + ret = reflect.MakeSlice(reflect.TypeOf([]any{}), 0, data.numElements) + } + + // Iterate through all the elements in the encoded variant. + for i := range data.numElements { + elemOffset, err := readUint(raw, data.firstOffsetIdx+data.offsetWidth*i, data.offsetWidth) + if err != nil { + return err + } + dataIdx := int(elemOffset) + data.firstDataIdx + if err := checkBounds(raw, dataIdx, dataIdx); err != nil { + return err + } + + // Unmarshal the element and append to the slice to return. + newElemValue := reflect.New(ret.Type().Elem()) + if err := unmarshalCommon(raw, md, dataIdx, newElemValue); err != nil { + return err + } + ret = reflect.Append(ret, newElemValue.Elem()) + } + destElem.Set(ret) + return nil +} diff --git a/parquet/variants/array_test.go b/parquet/variants/array_test.go new file mode 100644 index 00000000..dc63a6c2 --- /dev/null +++ b/parquet/variants/array_test.go @@ -0,0 +1,423 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package variants + +import ( + "bytes" + "reflect" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestArrayPrimitives(t *testing.T) { + var buf bytes.Buffer + ab := newArrayBuilder(&buf, nil, nil) + toEncode := []any{true, 256, "hello", 10, []byte{'t', 'h', 'e', 'r', 'e'}} + for _, te := range toEncode { + if err := ab.Write(te); err != nil { + t.Fatalf("Write(%v): %v", te, err) + } + } + if err := ab.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + + wantBytes := []byte{ + 0b00011, // Basic type = array, field_offset_minus_one = 0, is_large = false + 0x05, // 5 elements in the array + 0x00, // Index of "true" + 0x01, // Index of 256 + 0x04, // Index of "hello" + 0x0A, // Index of 10 + 0x0C, // Index of []byte{t,h,e,r,e} + 0x16, // First index after last element + 0b100, // Primitive true (header & value) + 0b10000, // Primitive int16 + 0x00, + 0x01, // 256 encoded + 0b10101, // Basic short string, length = 5 + 'h', + 'e', + 'l', + 'l', + 'o', // "hello" encoded + 0b1100, // Primitive int8 + 0x0A, // 10 encoded + 0b111100, // Primitive binary + 0x05, + 0x00, + 0x00, + 0x00, // bytes of length 5 + 't', + 'h', + 'e', + 'r', + 'e', // []byte{t,h,e,r,e} encoded + } + + diffByteArrays(t, buf.Bytes(), wantBytes) + + // Decode and ensure we got what was expected. + var got []any + if err := unmarshalArray(buf.Bytes(), &decodedMetadata{}, 0, reflect.ValueOf(&got)); err != nil { + t.Fatalf("unmarshalArray(): %v", err) + } + want := []any{true, int64(256), "hello", int64(10), []byte{'t', 'h', 'e', 'r', 'e'}} + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("Incorrect returned array. Diff (-got +want):\n%s", diff) + } +} + +func TestArrayLarge(t *testing.T) { + var buf bytes.Buffer + ab := newArrayBuilder(&buf, nil, nil) + // Create 256 items, which triggers "is_large" (256 cannot be encoded in one byte) + for i := range 256 { + if err := ab.Write(true); err != nil { + t.Fatalf("Write(iter = %d): %v", i, err) + } + } + if err := ab.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + wantBytes := []byte{ + 0b10111, // Basic type = array, field_offset_minus_one = 1, is_large = true + 0x00, + 0x01, + 0x00, + 0x00, // 256 encoded in 4 bytes + } + // Create the offset section + for i := range 256 { + wantBytes = append(wantBytes, []byte{byte(i), 0}...) + } + wantBytes = append(wantBytes, []byte{0, 1}...) // 256- the first index after all elements. 
+ + // Create 256 trues + for range 256 { + wantBytes = append(wantBytes, 4) // 0x04 is basic type true + } + diffByteArrays(t, buf.Bytes(), wantBytes) +} + +func TestNestedArray(t *testing.T) { + var buf bytes.Buffer + ab := newArrayBuilder(&buf, nil, nil) + + // Create a nested array so that we get {true, 1, {false, 256}, 3} + ab.Write(true) + ab.Write(1) + nested := ab.Array() + nested.Write(false) + nested.Write(256) + nested.Build() + ab.Write(3) + + if err := ab.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + + wantBytes := []byte{ + 0b00011, // Basic type = array, field_offset_minus_one = 0, is_large = false + 0x04, // 4 elements in the array + 0x00, // Index of "true" + 0x01, // Index of 1 + 0x03, // Index of nested array + 0x0C, // Index of 3 + 0x0E, // First index after last element + 0b100, // Primitive true (header & value) + 0b1100, // Primitive int8 + 0x01, // 1 encoded + // Beginning of nested array + 0b00011, // Nested array, basic type = array, field_offset_minus_one = 0, is_large = false + 0x02, // 2 elements in the array + 0x00, // Index of "false" + 0x01, // Index of 256 + 0x04, // First index after last element + 0b1000, // Primitive false (header & value) + 0b10000, // Primitive int16 + 0x00, + 0x01, // 256 encoded + // End of nested array + 0b1100, // Primitive int8 + 0x03, // 3 encoded + } + + diffByteArrays(t, buf.Bytes(), wantBytes) +} + +func TestFromSlice(t *testing.T) { + var buf bytes.Buffer + mdb := newMetadataBuilder() + ab := newArrayBuilder(&buf, mdb, nil) + if err := ab.fromSlice([]any{1, false, 2}); err != nil { + t.Fatalf("Write(): %v", err) + } + if err := ab.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + wantBytes := []byte{ + 0x03, // Basic type = array, field offset size = 1, is_large = false, + 0x03, // 3 items in array + 0x00, // Index of 1 + 0x02, // Index of false + 0x03, // Index of 2 + 0x05, // First index after last element + 0b1100, 0x01, // Int8, value = 1 + 0b1000, // false (value and header) + 0b1100, 0x02, // Int8, value = 2 + } + diffByteArrays(t, buf.Bytes(), wantBytes) +} + +func TestNestedObject(t *testing.T) { + var buf bytes.Buffer + mdb := newMetadataBuilder() + ab := newArrayBuilder(&buf, mdb, nil) + + ab.Write(true) + ab.Write(1) + nested := ab.Object() + nested.Write("a", false) + nested.Write("b", 256) + nested.Build() + + ab.Write(3) + + if err := ab.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + + wantBytes := []byte{ + 0b00011, // Basic type = array, field_offset_minus_one = 0, is_large = false + 0x04, // 4 elements in the array + 0x00, // Index of "true" + 0x01, // Index of 1 + 0x03, // Index of nested object + 0x0E, // Index of 3 + 0x10, // First index after the last element + 0b100, // Primitive true (header & value) + 0b1100, // Primitive int8 + 0x01, // 1 encoded + // Beginning of nested object + 0b00000010, // Basic type = object, field_offset_size & field_id_size = 1, is_large = false + 0x02, // 2 elements + 0x00, // Field ID of "a" + 0x01, // FieldID of "b" + 0x00, // Index of "a"s value (false) + 0x01, // Index of "b"s value (256) + 0x04, // First index after the last value + 0b1000, // Primitive false (header & value) + 0b10000, // Primitive int16 + 0x00, + 0x01, // 256 encoded + // End of nested object + 0b1100, // Primitive int8 + 0x03, // 3 encoded + } + + diffByteArrays(t, buf.Bytes(), wantBytes) +} + +func checkErr(t *testing.T, wantErr bool, err error) { + t.Helper() + if err != nil { + if wantErr { + return + } + t.Fatalf("Unexpected error: %v", err) + } else if wantErr { 
+ t.Fatal("Got no error when one was expected") + } +} + +func TestGetArrayData(t *testing.T) { + cases := []struct { + name string + encoded []byte + offset int + want *arrayData + wantErr bool + }{ + { + name: "Array with offset", + encoded: []byte{ + 0x00, 0x00, // Offset bytes + 0b11, // Basic type = array, field offset size = 1, is_large = false, + 0x03, // 3 elements in the array, + 0x00, // Index of "true" + 0x01, // Index of 256 + 0x04, // Index of "hello" + 0x0A, // First index after last element + 0b100, // Primitive true + 0b10000, 0x00, 0x01, // Primitive int16 val = 256 + 0b10101, 'h', 'e', 'l', 'l', 'o', // Basic short string, length = 5 + }, + offset: 2, + want: &arrayData{ + size: 16, + numElements: 3, + firstOffsetIdx: 4, + firstDataIdx: 8, + offsetWidth: 1, + }, + }, + { + name: "Array with large widths", + encoded: []byte{ + 0b00011011, // Basic type = array, field offset size = 3, is_large = true + 0x03, 0x00, 0x00, 0x00, // 3 elements in the array + 0x00, 0x00, 0x00, // Index of "true" + 0x01, 0x00, 0x00, // Index of 256 + 0x04, 0x00, 0x00, // Index of "hello" + 0x0A, 0x00, 0x00, // First index after last element + 0b100, // Primitive true (header & value) + 0b10000, 0x00, 0x01, // Primitive int16 val = 256 + 0b10101, 'h', 'e', 'l', 'l', 'o', // Basic short string, length = 5 + }, + want: &arrayData{ + size: 27, + numElements: 3, + firstOffsetIdx: 5, + firstDataIdx: 17, + offsetWidth: 3, + }, + }, + { + name: "Not an array", + encoded: []byte{0x00, 0x00}, // Primitive nulls + wantErr: true, + }, + { + name: "Elements would be out of bounds", + encoded: []byte{0b11, 0x03, 0x00, 0x01, 0x04, 0x0A, 0b100, 0b10000, 0x00, 0x01 /* missing string */}, + wantErr: true, + }, + { + name: "Offset is out of bounds", + encoded: []byte{0b11, 0x01, 0x00, 0x01, 0b100}, // Array with one boolean + offset: 10, + wantErr: true, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := getArrayData(c.encoded, c.offset) + checkErr(t, c.wantErr, err) + if diff := cmp.Diff(got, c.want, cmp.AllowUnexported(arrayData{})); diff != "" { + t.Fatalf("Incorrect returned value. Diff (-got + want):\n%s", diff) + } + }) + } +} + +func TestUnmarshalArray(t *testing.T) { + cases := []struct { + name string + md *decodedMetadata + encoded []byte + offset int + // Normally the test will unmarhall into the type in want, but if this override is set, + // it'll try to unmarshal into something of this type. + overrideDecodeType reflect.Type + want any + wantErr bool + }{ + { + name: "Array with large widths and offset", + offset: 3, + encoded: []byte{ + 0x00, 0x00, 0x00, // 3 offset bytes + 0b00011011, // Basic type = array, field offset size = 3, is_large = true + 0x03, 0x00, 0x00, 0x00, // 3 elements in the array + 0x00, 0x00, 0x00, // Index of "true" + 0x01, 0x00, 0x00, // Index of 256 + 0x04, 0x00, 0x00, // Index of "hello" + 0x0A, 0x00, 0x00, // First index after last element + 0b100, // Primitive true (header & value) + 0b10000, 0x00, 0x01, // Primitive int16 val = 256 + 0b10101, 'h', 'e', 'l', 'l', 'o', // Basic short string, length = 5 + }, + want: []any{true, int64(256), "hello"}, + }, + { + name: "Unmarshal into typed array", + encoded: []byte{ + 0b011, // Basic type = array, field offset size = 1, is_large = false + 0x03, // 3 elements in the array + 0x00, 0x02, 0x04, 0x06, // Offsets for 3 integers and the index after the last element. 
+ 0b1100, 0x01, // Primitive int8 val = 1 + 0b1100, 0x02, // Primitive int8 val = 2 + 0b1100, 0x03, // Primitive int8 val = 3 + }, + want: []int{1, 2, 3}, + }, + { + name: "Nested array", + encoded: []byte{ + 0b011, // Basic type = array, field offset size = 1, is_large = false + 0x01, // 1 element in the array + 0x00, 0x06, // Offsets for nested array and index after the last element + 0b011, // Nested array, field offset size = 1, is_large = false + 0x01, // 1 element in the array + 0x00, 0x02, // Offsets for 1 integer and index after the last element + 0b1100, 0x01, //primitive int8 val = 1 + }, + want: []any{[]any{int64(1)}}, + }, + { + name: "Invalid data", + encoded: []byte{0x00, 0x00}, + overrideDecodeType: reflect.TypeOf([]any{}), + wantErr: true, + }, + { + name: "Can't decode into primitive", + encoded: []byte{ + 0b011, // Basic type = array, field offset size = 1, is_large = false + 0x03, // 3 elements in the array + 0x00, 0x02, 0x04, 0x06, // Offsets for 3 integers and the index after the last element. + 0b1100, 0x01, // Primitive int8 val = 1 + 0b1100, 0x02, // Primitive int8 val = 2 + 0b1100, 0x03, // Primitive int8 val = 3 + }, + overrideDecodeType: reflect.TypeOf(""), + wantErr: true, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + typ := reflect.TypeOf(c.want) + if c.overrideDecodeType != nil { + typ = c.overrideDecodeType + } + got := reflect.New(typ) + if err := unmarshalArray(c.encoded, c.md, c.offset, got); err != nil { + if c.wantErr { + return + } + t.Fatalf("Unexpected error: %v", err) + } else if c.wantErr { + t.Fatalf("Got no error when one was expected") + } + diff(t, got.Elem().Interface(), c.want) + }) + } +} diff --git a/parquet/variants/builder.go b/parquet/variants/builder.go new file mode 100644 index 00000000..b8e8aa9b --- /dev/null +++ b/parquet/variants/builder.go @@ -0,0 +1,201 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variants + +import ( + "bytes" + "errors" + "fmt" + "io" + "reflect" +) + +// ArrayBuilder provides a mechanism to build a Variant encoded array. +type ArrayBuilder interface { + Builder + Write(val any, opts ...MarshalOpts) error + Array() ArrayBuilder + Object() ObjectBuilder +} + +// ObjectBuilder provides a mechanism to build a Variant encoded object +type ObjectBuilder interface { + Builder + Write(key string, val any, opts ...MarshalOpts) error + Array(key string) (ArrayBuilder, error) + Object(key string) (ObjectBuilder, error) +} + +// Builder provides a mechanism to build something that's Variant encoded +type Builder interface { + Build() error +} + +// Options for marshaling time types into a Variant. 
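+// They are accepted variadically by Marshal and the builders' Write methods, for
+// example (values and field names below are illustrative):
+//
+//	ab.Write(created, MarshalAsTimestamp, MarshalTimeNTZ)
+//
+// and can also be selected per struct field via the `variant` tag parsed by
+// extractFieldInfo, e.g. `variant:"created_at,timestamp,ntz"`.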
+type MarshalOpts int + +const ( + MarshalTimeNanos MarshalOpts = 1 << iota + MarshalTimeNTZ + MarshalAsDate + MarshalAsTime + MarshalAsTimestamp + MarshalAsUUID +) + +var errAlreadyBuilt = errors.New("component already built") + +// VariantBuilder is a helper to build and encode a Variant +type VariantBuilder struct { + buf bytes.Buffer + builder Builder + typ BasicType + mdb *metadataBuilder + built bool +} + +func NewBuilder() *VariantBuilder { + return &VariantBuilder{ + typ: BasicUndefined, + mdb: newMetadataBuilder(), + } +} + +// Marshals a Variant from a provided value. This will automatically convert Go primitives +// into equivalent Variant values: +// - Slice/Array: Converted into Variant Arrays, with the exception of []byte which is a Variant Binary Primitive +// - map[string]any: Converted into Variant Objects +// - Structs: Converted into Variant Objects. Keys are either the exported struct fields, or the value present +// in the `variant` field annotation. +// - Go primitives: Converted into Variant primitives +func Marshal(val any, opts ...MarshalOpts) (*MarshaledVariant, error) { + b := NewBuilder() + if err := writeCommon(val, &b.buf, b.mdb, nil); err != nil { + return nil, err + } + ev, err := b.Build() + if err != nil { + return nil, err + } + return ev, nil +} + +func (vb *VariantBuilder) check() error { + if vb.built { + return errors.New("Variant has already been built") + } + if vb.typ != BasicUndefined { + return fmt.Errorf("Variant type has already been started as a %q", vb.typ) + } + return nil +} + +// Callback to record the number of bytes written. +type doneCB func(int) + +// Common functionalities in writing Variant encoded data. This will be recursed into from various places. +func writeCommon(val any, buf io.Writer, mdb *metadataBuilder, doneCB doneCB, opts ...MarshalOpts) error { + typ := kindFromValue(val) + switch typ { + case BasicPrimitive: + b, err := marshalPrimitive(val, buf, opts...) + if err != nil { + return fmt.Errorf("marshalPrimitive(): %v", err) + } + if doneCB != nil { + doneCB(b) + } + case BasicObject: + // Objects can be built from structs or maps. + ob := newObjectBuilder(buf, mdb, doneCB) + if reflect.ValueOf(val).Kind() == reflect.Map { + if err := ob.fromMap(val); err != nil { + return err + } + } else { + // No need to check if this is a struct- kindFromValue() has done that already. + if err := ob.fromStruct(val); err != nil { + return err + } + } + if err := ob.Build(); err != nil { + return err + } + case BasicArray: + ab := newArrayBuilder(buf, mdb, doneCB) + if err := ab.fromSlice(val, opts...); err != nil { + return err + } + if err := ab.Build(); err != nil { + return err + } + default: + return fmt.Errorf("unknown basic type: %s", typ) + } + return nil +} + +// Sets this Variant as a primitive, and writes the provided value. +func (vb *VariantBuilder) Primitive(val any, opts ...MarshalOpts) error { + if err := vb.check(); err != nil { + return err + } + vb.typ = BasicPrimitive + _, err := marshalPrimitive(val, &vb.buf, opts...) + return err +} + +// Sets this Variant as an Object and returns an ObjectBuilder. +func (vb *VariantBuilder) Object() (ObjectBuilder, error) { + if err := vb.check(); err != nil { + return nil, err + } + ob := newObjectBuilder(&vb.buf, vb.mdb, nil) + vb.typ = BasicObject + vb.builder = ob + return ob, nil +} + +// Sets this Variant as an Array and returns an ArrayBuilder. 
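+// A typical build sequence, mirroring the builder tests (sketch; error handling elided):
+//
+//	vb := NewBuilder()
+//	ab, _ := vb.Array()
+//	ab.Write(1)
+//	ab.Write("two")
+//	enc, _ := vb.Build() // enc.Metadata and enc.Value hold the encoded Variant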
+func (vb *VariantBuilder) Array() (ArrayBuilder, error) { + if err := vb.check(); err != nil { + return nil, err + } + ab := newArrayBuilder(&vb.buf, vb.mdb, nil) + vb.typ = BasicArray + vb.builder = ab + return ab, nil +} + +// Builds the Variant +func (vb *VariantBuilder) Build() (*MarshaledVariant, error) { + // Indicate that all building has completed to prevent any mutation. + vb.built = true + + var encoded MarshaledVariant + encoded.Metadata = vb.mdb.Build() + + // Build an object or an array if necessary + if vb.builder != nil { + if err := vb.builder.Build(); err != nil && err != errAlreadyBuilt { + return nil, err + } + } + + encoded.Value = vb.buf.Bytes() + return &encoded, nil +} diff --git a/parquet/variants/builder_test.go b/parquet/variants/builder_test.go new file mode 100644 index 00000000..c91ab1eb --- /dev/null +++ b/parquet/variants/builder_test.go @@ -0,0 +1,307 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variants + +import ( + "bytes" + "testing" +) + +func TestVariantMarshal(t *testing.T) { + emptyMetadata := []byte{0x01, 0x00, 0x00} + cases := []struct { + name string + val any + wantEncoded *MarshaledVariant + wantErr bool + }{ + { + name: "Primitive", + val: 123, + wantEncoded: func() *MarshaledVariant { + var buf bytes.Buffer + marshalPrimitive(123, &buf) + return &MarshaledVariant{ + Metadata: emptyMetadata, + Value: buf.Bytes(), + } + }(), + }, + { + name: "Array", + val: []any{123, "hello", false, []any{321, "olleh", true}}, + wantEncoded: func() *MarshaledVariant { + var buf bytes.Buffer + mdb := newMetadataBuilder() + ab := newArrayBuilder(&buf, mdb, nil) + ab.Write(123) + ab.Write("hello") + ab.Write(false) + + sub := ab.Array() + sub.Write(321) + sub.Write("olleh") + sub.Write(true) + sub.Build() + + ab.Build() + return &MarshaledVariant{ + Metadata: emptyMetadata, + Value: buf.Bytes(), + } + }(), + }, + { + name: "Struct", + val: struct { + FieldKey string + TagKey int `variant:"tag_key"` + Arr []int `variant:"array"` + unexported bool + }{"hello", 1, []int{1, 2, 3}, false}, + wantEncoded: func() *MarshaledVariant { + var buf bytes.Buffer + mdb := newMetadataBuilder() + ob := newObjectBuilder(&buf, mdb, nil) + ob.Write("FieldKey", "hello") + ob.Write("tag_key", 1) + + ab, err := ob.Array("array") + if err != nil { + t.Fatalf("Array(): %v", err) + } + ab.Write(1) + ab.Write(2) + ab.Write(3) + ab.Build() + + ob.Build() + return &MarshaledVariant{ + Metadata: mdb.Build(), + Value: buf.Bytes(), + } + }(), + }, + { + name: "Struct pointer", + val: &struct { + Field1 string + Field2 int + }{"hello", 123}, + wantEncoded: func() *MarshaledVariant { + var buf bytes.Buffer + mdb := newMetadataBuilder() + ob := newObjectBuilder(&buf, mdb, nil) + ob.Write("Field1", "hello") + ob.Write("Field2", 123) + 
ob.Build() + return &MarshaledVariant{ + Metadata: mdb.Build(), + Value: buf.Bytes(), + } + }(), + }, + { + name: "Valid map", + // Map iteration order is undefined so only use one key here to test. Rely on the tests in object.go to cover maps more fully. + val: map[string]int{"solitary_key": 1}, + wantEncoded: func() *MarshaledVariant { + var buf bytes.Buffer + mdb := newMetadataBuilder() + ob := newObjectBuilder(&buf, mdb, nil) + ob.Write("solitary_key", 1) + ob.Build() + return &MarshaledVariant{ + Metadata: mdb.Build(), + Value: buf.Bytes(), + } + }(), + }, + { + name: "Nil", + val: nil, + wantEncoded: &MarshaledVariant{ + Metadata: emptyMetadata, + Value: []byte{0x00}, + }, + }, + { + name: "Invalid map", + val: map[int]string{1: "hello"}, + wantErr: true, + }, + { + name: "Invalid value", + val: func() {}, + wantErr: true, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + encoded, err := Marshal(c.val) + if err != nil { + if c.wantErr { + return + } + t.Fatalf("Unexpected error: %v", err) + } else if c.wantErr { + t.Fatalf("Got no error when one was expected") + } + diff(t, encoded, c.wantEncoded) + }) + } +} + +func TestVariantBuilderPrimitive(t *testing.T) { + vb := NewBuilder() + vb.Primitive(1) + ev, err := vb.Build() + if err != nil { + t.Fatalf("Build(): %v", err) + } + + // Check metadata + md := newMetadataBuilder().Build() + diffByteArrays(t, ev.Metadata, md) + + // Check value + var buf bytes.Buffer + marshalPrimitive(1, &buf) + diffByteArrays(t, ev.Value, buf.Bytes()) +} + +func TestVariantBuilderArray(t *testing.T) { + vb := NewBuilder() + ab, err := vb.Array() + if err != nil { + t.Fatalf("Array(): %v", err) + } + + buildArray := func(ab ArrayBuilder) { + ab.Write(1) + ab.Write(true) + nested := ab.Array() + nested.Write("hello") + nested.Build() + } + + buildArray(ab) + + ev, err := vb.Build() + if err != nil { + t.Fatalf("Build(): %v", err) + } + + // Check metadata + md := newMetadataBuilder().Build() + diffByteArrays(t, ev.Metadata, md) + + // Check value + wantArray := func() []byte { + var buf bytes.Buffer + mdb := newMetadataBuilder() + ab := newArrayBuilder(&buf, mdb, nil) + buildArray(ab) + ab.Build() + return buf.Bytes() + }() + diffByteArrays(t, ev.Value, wantArray) +} + +func TestVariantBuilderObject(t *testing.T) { + vb := NewBuilder() + ob, err := vb.Object() + if err != nil { + t.Fatalf("Object(): %v", err) + } + + buildObject := func(ob ObjectBuilder) { + ob.Write("b", 1) + ob.Write("c", 2) + ob.Write("a", 3) + nested, _ := ob.Object("d") + nested.Write("a", true) + nested.Write("e", "nested") + nested.Build() + } + buildObject(ob) + ev, err := vb.Build() + if err != nil { + t.Fatalf("Build(): %v", err) + } + + wantEncoded := func() ([]byte, []byte) { + var buf bytes.Buffer + mdb := newMetadataBuilder() + ob := newObjectBuilder(&buf, mdb, nil) + buildObject(ob) + ob.Build() + md := mdb.Build() + return md, buf.Bytes() + } + + wantMetadata, wantValue := wantEncoded() + + diffByteArrays(t, ev.Metadata, wantMetadata) + diffByteArrays(t, ev.Value, wantValue) +} + +func TestCannotChangeVariantType(t *testing.T) { + vbPrim := NewBuilder() + if err := vbPrim.Primitive(1); err != nil { + t.Fatalf("Write first time: %v", err) + } + + if err := vbPrim.Primitive(2); err == nil { + t.Fatal("Primitive already started") + } + if _, err := vbPrim.Array(); err == nil { + t.Fatal("Prmitive already started") + } + if _, err := vbPrim.Object(); err == nil { + t.Fatal("Primitive already started") + } + + vbArr := NewBuilder() + if _, err := 
vbArr.Array(); err != nil { + t.Fatalf("Array first time: %v", err) + } + if err := vbArr.Primitive(1); err == nil { + t.Fatal("Array already started") + } + if _, err := vbArr.Array(); err == nil { + t.Fatalf("Array already started") + } + if _, err := vbArr.Object(); err == nil { + t.Fatalf("Array already started") + } + + vbObj := NewBuilder() + if _, err := vbObj.Object(); err != nil { + t.Fatalf("Object first time: %v", err) + } + if err := vbObj.Primitive(1); err == nil { + t.Fatal("Object alrady started") + } + if _, err := vbObj.Array(); err == nil { + t.Fatal("Object already started") + } + if _, err := vbObj.Object(); err == nil { + t.Fatal("Object already started") + } +} diff --git a/parquet/variants/decoder.go b/parquet/variants/decoder.go new file mode 100644 index 00000000..9a85a065 --- /dev/null +++ b/parquet/variants/decoder.go @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variants + +import ( + "errors" + "fmt" + "reflect" +) + +// Decode provides a way to decode an encoded Variant into a Go type. The mapping of Variant +// types to Go types is: +// - Null: nil +// - Boolean: bool +// - Int8, Int16, Int32, Int64: int64 +// - Float: float32 +// - Double: float64 +// - Time: time.Time +// - Timestamp (all varieties): time.Time +// - String: string +// - Binary: []byte +// - UUID: string +// - Array: []any +// - Object: map[string]any + +// DecodeInto provides a way to decode an encoded Variant into a provided non-nil pointer dest if possible. +// For structs, this will attempt to match field names or fields annotated with the `variant` field annotation. 
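+// A minimal usage sketch (illustrative; integer widths follow the Decode mapping above):
+//
+//	enc, _ := Marshal(map[string]any{"a": 1, "b": []any{true, "x"}})
+//	var got any
+//	_ = DecodeInto(enc, &got)
+//	// got is map[string]any{"a": int64(1), "b": []any{true, "x"}}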
+// TODO(Finish this comment) +func DecodeInto(encoded *MarshaledVariant, dest any) error { + destVal := reflect.ValueOf(dest) + if kind := destVal.Kind(); kind != reflect.Pointer { + return fmt.Errorf("dest must be a pointer (got %s)", kind) + } + if destVal.IsNil() { + return errors.New("dest pointer must not be nil") + } + + md, err := decodeMetadata(encoded.Metadata) + if err != nil { + return fmt.Errorf("could not decode metadata: %v", err) + } + + return unmarshalCommon(encoded.Value, md, 0, destVal) +} + +func unmarshalCommon(raw []byte, md *decodedMetadata, offset int, dest reflect.Value) error { + if err := checkBounds(raw, offset, offset); err != nil { + return err + } + switch bt := BasicTypeFromHeader(raw[offset]); bt { + case BasicPrimitive, BasicShortString: + if err := unmarshalPrimitive(raw, offset, dest); err != nil { + return fmt.Errorf("could not decode primitive: %v", err) + } + case BasicArray: + if err := unmarshalArray(raw, md, offset, dest); err != nil { + return fmt.Errorf("could not decode array: %v", err) + } + case BasicObject: + if err := unmarshalObject(raw, md, offset, dest); err != nil { + return fmt.Errorf("could not decode object: %v", err) + } + } + return nil +} diff --git a/parquet/variants/decoder_test.go b/parquet/variants/decoder_test.go new file mode 100644 index 00000000..3db54e4f --- /dev/null +++ b/parquet/variants/decoder_test.go @@ -0,0 +1,259 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variants + +import ( + "reflect" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func mustEncodeVariant(t *testing.T, val any) *MarshaledVariant { + t.Helper() + ev, err := Marshal(val) + if err != nil { + t.Fatalf("Marshal(): %v", err) + } + return ev +} + +func TestUnmarshal(t *testing.T) { + emptyMetadata := func() []byte { + return []byte{0x01, 0x00, 0x00} // Basic, but valid, metadata with no keys. 
+ } + cases := []struct { + name string + encoded *MarshaledVariant + want any + wantErr bool + }{ + { + name: "Primitive", + encoded: mustEncodeVariant(t, "hello"), + want: "hello", + }, + { + name: "Array", + encoded: mustEncodeVariant(t, []any{"hello", 256, true}), + want: []any{"hello", int64(256), true}, + }, + { + name: "Object into map", + encoded: mustEncodeVariant(t, struct { + Key1 int `variant:"key1"` + Key2 []byte `variant:"key2"` + Key3 string `variant:"key3"` + }{1234, []byte{'b', 'y', 't', 'e'}, "hello"}), + want: map[string]any{ + "key1": int64(1234), + "key2": []byte{'b', 'y', 't', 'e'}, + "key3": "hello", + }, + }, + { + name: "Complex", + encoded: mustEncodeVariant(t, []any{ + 1234, struct { + Key1 string `variant:"key1"` + Arr []any `variant:"array"` + }{"hello", []any{false, true, "hello"}}, + "fin"}), + want: []any{ + int64(1234), + map[string]any{ + "key1": "hello", + "array": []any{false, true, "hello"}, + }, + "fin", + }, + }, + { + name: "Nil primitive", + encoded: mustEncodeVariant(t, nil), + want: nil, + }, + { + name: "Missing metadata", + encoded: &MarshaledVariant{ + Value: []byte{0x00}, // Primitive nil + }, + wantErr: true, + }, + { + name: "Missing Value", + encoded: &MarshaledVariant{ + Metadata: emptyMetadata(), + }, + wantErr: true, + }, + { + name: "Malformed array", + encoded: &MarshaledVariant{ + Metadata: emptyMetadata(), + Value: []byte{0x03, 0x02}, // Array, length 2, no other items. + }, + wantErr: true, + }, + { + name: "Object missing key", + encoded: func() *MarshaledVariant { + builder := NewBuilder() + ob, err := builder.Object() + if err != nil { + t.Fatalf("Object(): %v", err) + } + ob.Write("key", "value") + encoded, err := builder.Build() + if err != nil { + t.Fatalf("Build(): %v", err) + } + encoded.Metadata = emptyMetadata() + return encoded + }(), + wantErr: true, + }, + { + name: "Malformed object", + encoded: func() *MarshaledVariant { + builder := NewBuilder() + ob, err := builder.Object() + if err != nil { + t.Fatalf("Object(): %v", err) + } + ob.Write("key", "value") + encoded, err := builder.Build() + if err != nil { + t.Fatalf("Build(): %v", err) + } + encoded.Value = encoded.Value[:len(encoded.Value)-2] // Truncate + return encoded + }(), + wantErr: true, + }, + { + name: "Malformed primitive", + encoded: &MarshaledVariant{ + Metadata: emptyMetadata(), + Value: []byte{0xFD, 'a'}, // Short string, length 63 + }, + wantErr: true, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + var got any + err := DecodeInto(c.encoded, &got) + if err != nil { + if c.wantErr { + return + } + t.Fatalf("Unexpected error: %v", err) + } else if c.wantErr { + t.Fatalf("Got no error when one was expected") + } + if diff := cmp.Diff(got, c.want); diff != "" { + t.Fatalf("Incorrect data returned. 
Diff (-got, +want):\n%s", diff) + } + }) + } +} + +func TestUnmarshalWithTypes(t *testing.T) { + cases := []struct { + name string + ev *MarshaledVariant + decodeType reflect.Type + want reflect.Value + wantErr bool + }{ + { + name: "Primitive int", + ev: mustEncodeVariant(t, 1), + decodeType: reflect.TypeOf(int(0)), + want: reflect.ValueOf(int(1)), + }, + { + name: "Primitive string", + ev: mustEncodeVariant(t, "hello"), + decodeType: reflect.TypeOf(""), + want: reflect.ValueOf("hello"), + }, + { + name: "Nested array", + ev: mustEncodeVariant(t, []any{[]any{1}}), + decodeType: reflect.TypeOf([]any{}), + want: reflect.ValueOf([]any{[]any{int64(1)}}), + }, + { + name: "Complex object into map", + ev: mustEncodeVariant(t, map[string]any{ + "key1": 123, + "key2": []any{true, false, "hello", []any{1, 2, 3}}, + "key3": map[string]any{ + "key1": "foo", + "key4": "bar", + }, + }), + decodeType: reflect.TypeOf(map[string]any{}), + want: reflect.ValueOf(map[string]any{ + "key1": int64(123), + "key2": []any{true, false, "hello", []any{int64(1), int64(2), int64(3)}}, + "key3": map[string]any{ + "key1": "foo", + "key4": "bar", + }, + }), + }, + { + name: "Object to map", + ev: mustEncodeVariant(t, map[string]int{"a": 1, "b": 2}), + decodeType: reflect.TypeOf(map[string]any{}), + want: reflect.ValueOf(map[string]any{"a": int64(1), "b": int64(2)}), + }, + // TODO: add tests to decode into struct + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + decodedMD, err := decodeMetadata(c.ev.Metadata) + if err != nil { + t.Fatalf("decodeMetadata(): %v", err) + } + + var decoded reflect.Value + if c.decodeType.Kind() == reflect.Map { + // Create a pointer to the map. + underlying := reflect.MakeMap(c.decodeType) + decoded = reflect.New(c.decodeType) + decoded.Elem().Set(underlying) + } else { + decoded = reflect.New(c.decodeType) + } + if err := unmarshalCommon(c.ev.Value, decodedMD, 0, decoded); err != nil { + if c.wantErr { + return + } + t.Fatalf("Unexpected error: %v", err) + } else if c.wantErr { + t.Fatalf("Got no error when one was expected") + } + diff(t, decoded.Elem().Interface(), c.want.Interface()) + }) + } +} diff --git a/parquet/variants/doc.go b/parquet/variants/doc.go new file mode 100644 index 00000000..1faa329d --- /dev/null +++ b/parquet/variants/doc.go @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This package contains utilities to marshal and unmarshal data to and from the Variant +// encoding format as described in +// [the Variant encoding spec]: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md +// +// There are two main ways to create a marshaled Variant: +// +// 1. Using `variants.Marshal()`. 
Simply pass in the value you'd like to marshal, and all the +// type inference is done for you. Structs and string-keyed maps will be converted into +// objects, slices will be converted into arrays, and primitives will be encoded with the +// appropriate primitive type. This will feel like the JSON library's `Marshal()`. +// 2. Using `variants.NewBuilder()`. This allows you to build out your Variant bit by bit. +// +// To convert from a marshaled Variant back to a type, use `variants.Unmarshal()`. Like the JSON +// `Unmarshal()`, this takes in a pointer to a value to "fill up." Objects can be unmarshaled into +// either structs or string-keyed maps, arrays can be unmarshaled into slices, and primitives into +// primitives. +// +// This library does have a few shortcomings, namely in that the Metadata is always marshaled with +// unordered keys (done to make marshaling considerably easier to code up), and that currently, +// unmarshaling decodes the whole Variant, not just a specific field. + +package variants diff --git a/parquet/variants/metadata.go b/parquet/variants/metadata.go new file mode 100644 index 00000000..3c402b8f --- /dev/null +++ b/parquet/variants/metadata.go @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variants + +import ( + "bytes" + "fmt" + "strings" +) + +const ( + versionMask = 0x0F + sortedMask = 0x10 + offsetMask = 0xC0 + + version = 0x01 +) + +type decodedMetadata struct { + keys []string +} + +func (d *decodedMetadata) At(i int) (string, bool) { + if i >= len(d.keys) { + return "", false + } + return d.keys[i], true +} + +func decodeMetadata(raw []byte) (*decodedMetadata, error) { + if len(raw) == 0 { + return nil, fmt.Errorf("invalid metadata") + } + + // Ensure the version is something recognizable. + if ver := raw[0] & versionMask; ver != version { + return nil, fmt.Errorf("invalid version (got %d, want %d)", ver, version) + } + + // Get the offset size. + offsetSize := int((raw[0] >> 6) + 1) + + // Get the number of elements in the dictionary. + elems, err := readUint(raw, 1, offsetSize) + if err != nil { + return nil, err + } + + var keys []string + if elems > 0 { + keys = make([]string, int(elems)) + for i := range int(elems) { + // Offset here is the first index of the offset list, which is the + // first element after the header and the size. 
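+			// Dictionary layout: [header][dict_size][offset_0 ... offset_n][key bytes];
+			// key i occupies bytes offsets[i]..offsets[i+1] of the key section.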
+ offset := offsetSize + 1 + raw, err := readNthItem(raw, offset, i, offsetSize, int(elems)) + if err != nil { + return nil, err + } + keys[i] = string(raw) + } + } + return &decodedMetadata{keys: keys}, err +} + +type metadataBuilder struct { + keyToIdx map[string]int + utf8Keys [][]byte + keyBytes int +} + +func newMetadataBuilder() *metadataBuilder { + return &metadataBuilder{ + keyToIdx: make(map[string]int), + } +} + +func (m *metadataBuilder) Build() []byte { + // Build the header. + hdr := byte(version) + offsetSize := m.calculateOffsetBytes() + hdr |= byte(offsetSize-1) << 6 + + mdSize := 1 + offsetSize*(len(m.utf8Keys)+1) + m.keyBytes + + buf := bytes.NewBuffer(make([]byte, 0, mdSize)) + buf.WriteByte(hdr) + + // Write the number of elements in the dictionary. + encodeNumber(int64(len(m.utf8Keys)), offsetSize, buf) + + // Write all of the offsets. + var currOffset int64 + for _, k := range m.utf8Keys { + encodeNumber(currOffset, offsetSize, buf) + currOffset += int64(len(k)) + } + encodeNumber(currOffset, offsetSize, buf) + + // Write all of the keys. + for _, k := range m.utf8Keys { + buf.Write(k) + } + + return buf.Bytes() +} + +func (m *metadataBuilder) calculateOffsetBytes() int { + maxNum := m.keyBytes + 1 + if dictLen := len(m.utf8Keys); dictLen > maxNum { + maxNum = dictLen + } + return fieldOffsetSize(int32(maxNum)) +} + +// Add adds a key to the metadata dictionary if not already present, and returns the index +// that the key is present. +func (m *metadataBuilder) Add(key string) int { + // Key already present, nothing to do. + if idx, ok := m.keyToIdx[key]; ok { + return idx + } + + // Ensure the passed in string is in UTF8 form (replacing invalid sequences with + // a replacement character), and append to the key slice. + keyBytes := []byte(strings.ToValidUTF8(key, "\uFFFD")) + idx := len(m.utf8Keys) + m.keyToIdx[key] = idx + m.utf8Keys = append(m.utf8Keys, keyBytes) + m.keyBytes += len(keyBytes) + + return idx +} + +func (m *metadataBuilder) KeyID(key string) (int, bool) { + id, ok := m.keyToIdx[key] + return id, ok +} diff --git a/parquet/variants/metadata_test.go b/parquet/variants/metadata_test.go new file mode 100644 index 00000000..66dac6ad --- /dev/null +++ b/parquet/variants/metadata_test.go @@ -0,0 +1,282 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package variants + +import ( + "bytes" + "math/rand" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestBuildMetadata(t *testing.T) { + cases := []struct { + name string + keys []string + wantKeys []string + wantEncoded []byte + }{ + { + name: "No keys added", + wantEncoded: []byte{ + 0b00000001, // Header: offset_size = 1, not sorted, version = 1 + 0x00, // Dictionary size of 0 + 0x00, // First (and last) index of elements in the empty dictionary + }, + }, + { + name: "Small number of items", + keys: []string{"b", "c", "a"}, + wantKeys: []string{"b", "c", "a"}, + wantEncoded: []byte{ + 0b00000001, // Header: offset_size = 1, not sorted, version = 1 + 0x03, // Dictionary size of 3 + 0x00, // Index of first item "b" + 0x01, // Index of second item "c" + 0x02, // Index of third item "a" + 0x03, // First index after the last item + 'b', + 'c', + 'a', + }, + }, + { + name: "Dedupe similar keys", + keys: []string{"b", "c", "a", "a", "a", "a", "b", "b", "c", "c", "c"}, + wantKeys: []string{"b", "c", "a"}, + wantEncoded: []byte{ + 0b00000001, // Header: offset_size = 1, not sorted, version = 1 + 0x03, // Dictionary size of 3 + 0x00, // Index of first item "b" + 0x01, // Index of second item "c" + 0x02, // Index of third item "a" + 0x03, // First index after the last item + 'b', + 'c', + 'a', + }, + }, + { + name: "Large number of keys (encoded in more than one byte)", + keys: func() []string { + keys := make([]string, 26*26) + idx := 0 + for i := range 26 { + for j := range 26 { + keys[idx] = string([]byte{byte('a' + i), byte('a' + j)}) + idx++ + } + } + return keys + }(), + wantKeys: largeKeysString(), + wantEncoded: largeEncodedMetadata(), + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + mdb := newMetadataBuilder() + for _, k := range c.keys { + mdb.Add(k) + } + md := mdb.Build() + + diffByteArrays(t, md, c.wantEncoded) + + // Decode and check keys. This does a bit of double duty and a lot of these cases are + // covered in TestDecodeMetadata below, but it is useful to prove that anything that + // can be encoded can also be decoded. + decoded, err := decodeMetadata(md) + if err != nil { + t.Fatalf("decodeMetadata(): %v", err) + } + if diff := cmp.Diff(decoded.keys, c.wantKeys); diff != "" { + t.Fatalf("Received incorrect keys. Diff (-want +got):\n%s", diff) + } + }) + } +} + +// Helpers to create a Metadata struct that has a large list of keys (26^2) that will +// take more than one byte to encode. +func largeListKeyBytes() [][]byte { + keys := make([][]byte, 26*26) + idx := 0 + for i := range 26 { + for j := range 26 { + keys[idx] = []byte{byte('a' + i), byte('a' + j)} + idx++ + } + } + return keys +} + +func largeEncodedMetadata() []byte { + // Offset size = 2 + // Total size of encoded metadata is: + // * Header: 1 + // * Number of elements: 2 + // * Offset table: (26*26 + 1)*2 + // * Elements: 26*26*2 + buf := bytes.NewBuffer(make([]byte, 0, 1+2+(26*26+1)*2+(26*26*2))) + buf.WriteByte(0b01000001) // offset_size_minus_one = 1, is_sorted = false, version = 1 + + encodeNumber(676, 2, buf) // Encode the number of elements + + // Encode the offsets. NB: each key is 2 bytes. 
+ for i := range 676 + 1 { + encodeNumber(int64(i*2), 2, buf) + } + + for _, k := range largeListKeyBytes() { + buf.Write(k) + } + + return buf.Bytes() +} + +func largeKeysString() []string { + rawKeys := largeListKeyBytes() + keys := make([]string, len(rawKeys)) + for i, k := range rawKeys { + keys[i] = string(k) + } + return keys +} + +// This test does duplicate some coverage provided in TestBuildMetadata, but is specifically +// focused on the decoding side of things. +func TestDecodeMetadata(t *testing.T) { + cases := []struct { + name string + raw []byte + want []string + wantErr bool + }{ + { + name: "Valid metadata with no elements", + raw: []byte{ + 0x01, // Base header, version = 1, + 0x00, // Zero length + 0x00, // First and last element + }, + }, + { + name: "Valid metadata, large number of elements", + raw: largeEncodedMetadata(), + want: largeKeysString(), + }, + { + name: "Zero length metadata", + wantErr: true, + }, + { + name: "Invalid version: 0", + raw: []byte{0x00, 0x00}, + wantErr: true, + }, + { + name: "Invalid version: 2", + raw: []byte{0x02, 0x00}, + wantErr: true, + }, + { + name: "Bad number of elements", + raw: []byte{0b11000001, 0x00}, // Offset size = 4, should be out of bounds read + wantErr: true, + }, + { + name: "Missing elements", + raw: []byte{0x01, 0x02, 0x00, 0x01, 0x02}, + wantErr: true, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := decodeMetadata(c.raw) + if err != nil { + if c.wantErr { + return + } + t.Fatalf("Unexpected error: %v", err) + } else if c.wantErr { + t.Fatal("Got no error when one was expected") + } + if diff := cmp.Diff(got.keys, c.want); diff != "" { + t.Fatalf("Received incorrect keys. Diff (-want +got):\n%s", diff) + } + }) + } +} + +func buildRandomKey(l int) string { + buf := make([]byte, l) + // Create a random ascii character between decimal 33 (!) and decimal 126 (~) + for i := range l { + randChar := byte(rand.Intn(94)) + 33 + buf[i] = randChar + } + return string(buf) +} + +func TestOffsetCalculation(t *testing.T) { + cases := []struct { + name string + keyLen int + numKeys int + wantHdr byte + }{ + { + name: "Offset length 1", + keyLen: 1, + numKeys: 1, + wantHdr: 0b00000001, + }, + { + name: "Offset length 2", + keyLen: 4, + numKeys: 256, + wantHdr: 0b01000001, + }, + { + name: "Offset length 3", + keyLen: 1<<16 + 1, + numKeys: 1, + wantHdr: 0b10000001, + }, + { + name: "Offset length 4", + keyLen: 1<<24 + 1, + numKeys: 1, + wantHdr: 0b11000001, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + mdb := newMetadataBuilder() + for range c.numKeys { + mdb.Add(buildRandomKey(c.keyLen)) + } + md := mdb.Build() + if got, want := md[0], c.wantHdr; got != want { + t.Fatalf("Incorrect header: got %x, want %x", got, want) + } + }) + } +} diff --git a/parquet/variants/object.go b/parquet/variants/object.go new file mode 100644 index 00000000..2031f64a --- /dev/null +++ b/parquet/variants/object.go @@ -0,0 +1,458 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variants + +import ( + "bytes" + "errors" + "fmt" + "io" + "reflect" + "sort" + "strings" +) + +// Container to keep track of the metadata for a given element in an object. This is mainly +// used so that it's possible to sort by key during build time (as required by spec) and +// not need to go doing a bunch of lookups to find field IDs and offsets. +type objectKey struct { + fieldID int + offset int + key string +} + +type objectBuilder struct { + w io.Writer + buf bytes.Buffer + objKeys []objectKey + fieldIDs map[int]struct{} + maxFieldID int + numItems int + offsets []int + nextOffset int + doneCB doneCB + mdb *metadataBuilder + built bool +} + +var _ ObjectBuilder = (*objectBuilder)(nil) + +func newObjectBuilder(w io.Writer, mdb *metadataBuilder, doneCB doneCB) *objectBuilder { + return &objectBuilder{ + w: w, + mdb: mdb, + doneCB: doneCB, + fieldIDs: make(map[int]struct{}), + } +} + +// Write marshals the provided value into the appropriate Variant type and adds it to this object with +// the provided key. Keys must be unique per object (though nested objects may share the same key). +func (o *objectBuilder) Write(key string, val any, opts ...MarshalOpts) error { + if err := o.checkKey(key); err != nil { + return err + } + return writeCommon(val, &o.buf, o.mdb, func(size int) { + o.record(key, size) + }) +} + +// Extracts field info from a struct field, namely the key name and any options associated with the field. +// If the Variant key name is not present in the `variant` annotation, the struct's field name will be +// used as the key. +// +// This function assumes the field is exported. +func extractFieldInfo(field reflect.StructField) (string, []MarshalOpts) { + var opts []MarshalOpts + + tag, ok := field.Tag.Lookup("variant") + if !ok || tag == "" { + return field.Name, nil + } + + // Tag is of the form "key_name,comma,separated,flags" + parts := strings.Split(tag, ",") + if len(parts) == 1 { + return tag, nil + } + + keyName := parts[0] + if keyName == "" { + keyName = field.Name + } + + for _, optStr := range parts[1:] { + switch strings.ToLower(optStr) { + case "nanos": + opts = append(opts, MarshalTimeNanos) + case "ntz": + opts = append(opts, MarshalTimeNTZ) + case "date": + opts = append(opts, MarshalAsDate) + case "time": + opts = append(opts, MarshalAsTime) + case "timestamp": + opts = append(opts, MarshalAsTimestamp) + case "uuid": + opts = append(opts, MarshalAsUUID) + } + } + + return keyName, opts +} + +// Creates an object from a struct. Key names are determined by either the struct's field name, or +// by a value in a `variant` field annotation (with the annotation taking precedence). +func (o *objectBuilder) fromStruct(st any) error { + stVal := reflect.ValueOf(st) + + // Get the underlying struct if this is a pointer to one. 
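+ // Only one level of indirection is unwrapped; a nil pointer ends up failing
+ // the struct check below.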
+ if stVal.Kind() == reflect.Pointer { + stVal = stVal.Elem() + } + if stVal.Kind() != reflect.Struct { + return fmt.Errorf("not a struct: %s", stVal.Kind()) + } + typ := stVal.Type() + + for i := range typ.NumField() { + field := typ.Field(i) + + // Skip unexported fields + if !field.IsExported() { + continue + } + + // Get the tag. If not present, use the field name. + key, opts := extractFieldInfo(field) + if err := o.Write(key, stVal.Field(i).Interface(), opts...); err != nil { + return err + } + } + return nil +} + +// Creates an object from a string-keyed map. +func (o *objectBuilder) fromMap(m any) error { + mapVal := reflect.ValueOf(m) + + // Make sure this is a map with string keys. + if mapVal.Kind() != reflect.Map { + return fmt.Errorf("not a map: %v", mapVal.Kind()) + } + if keyKind := mapVal.Type().Key().Kind(); keyKind != reflect.String { + return fmt.Errorf("map does not have string keys: %v", keyKind) + } + + for _, keyVal := range mapVal.MapKeys() { + valVal := mapVal.MapIndex(keyVal) + if err := o.Write(keyVal.Interface().(string), valVal.Interface()); err != nil { + return err + } + } + return nil +} + +// Keys within a given object must be unique +func (o *objectBuilder) checkKey(key string) error { + fieldID, ok := o.mdb.KeyID(key) + if ok { + if _, ok := o.fieldIDs[fieldID]; ok { + return fmt.Errorf("mutiple insertion of key %q in object", key) + } + } + return nil +} + +// Array returns a new ArrayBuilder associated with this object. The marshaled array will +// not be part of the object until the returned ArrayBuilder's Build method is called. +func (o *objectBuilder) Array(key string) (ArrayBuilder, error) { + if err := o.checkKey(key); err != nil { + return nil, err + } + ab := newArrayBuilder(&o.buf, o.mdb, func(size int) { + o.record(key, size) + }) + return ab, nil +} + +// Object returns a new ObjectBuilder associated with this object. The marshaled object will +// not be part of this object until the returned ObjectBuilder's Build method is called. +// +// NB. A nested object can contain a key that also exists in the parent object. +func (o *objectBuilder) Object(key string) (ObjectBuilder, error) { + if err := o.checkKey(key); err != nil { + return nil, err + } + ob := newObjectBuilder(&o.buf, o.mdb, func(size int) { + o.record(key, size) + }) + return ob, nil +} + +// Bookkeeping to record information about an elements key, offset, field ID, and +// to keep a running track of the number of items and the max field ID seen. +func (o *objectBuilder) record(key string, size int) { + o.numItems++ + currOffset := o.nextOffset + o.offsets = append(o.offsets, currOffset) + o.nextOffset += size + + fieldID := o.mdb.Add(key) + o.objKeys = append(o.objKeys, objectKey{ + fieldID: fieldID, + offset: currOffset, + key: key, + }) + o.fieldIDs[fieldID] = struct{}{} + + if fieldID > o.maxFieldID { + o.maxFieldID = fieldID + } +} + +// Build writes the marshaled object to the builders io.Writer. This prepends serialized +// metadata about the object (ie. its header, number of elements, sorted field IDs, and +// offsets) to the running data buffer. +func (o *objectBuilder) Build() error { + if o.built { + return errAlreadyBuilt + } + numItemsSize := 1 + if isLarge(o.numItems) { + numItemsSize = 4 + } + offsetSize := fieldOffsetSize(int32(o.nextOffset)) + fieldIDSize := fieldOffsetSize(int32(o.maxFieldID)) + + // Sort the object keys as per the spec. 
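+ // Field IDs and offsets are written in lexicographical key order rather than
+ // insertion order, which lets readers look fields up efficiently (e.g. by binary search).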
+ sort.Slice(o.objKeys, func(i, j int) bool { + return strings.Compare(o.objKeys[i].key, o.objKeys[j].key) < 0 + }) + + // Preallocate a buffer for the header, number of items, field IDs, and field offsets + serializedFieldIDSize := fieldIDSize * o.numItems + serializedFieldOffsetSize := offsetSize * (o.numItems + 1) + + serializedDataBuf := bytes.NewBuffer(make([]byte, 0, 1+numItemsSize+serializedFieldIDSize+serializedFieldOffsetSize)) + serializedDataBuf.WriteByte(o.header(fieldIDSize, offsetSize)) + + encodeNumber(int64(o.numItems), numItemsSize, serializedDataBuf) + for _, k := range o.objKeys { + encodeNumber(int64(k.fieldID), fieldIDSize, serializedDataBuf) + } + for _, k := range o.objKeys { + encodeNumber(int64(k.offset), offsetSize, serializedDataBuf) + } + encodeNumber(int64(o.nextOffset), offsetSize, serializedDataBuf) + + // Write everything to the writer. + hdrSize, _ := o.w.Write(serializedDataBuf.Bytes()) + dataSize, _ := o.w.Write(o.buf.Bytes()) + + totalSize := hdrSize + dataSize + if o.doneCB != nil { + o.doneCB(totalSize) + } + return nil +} + +func (o *objectBuilder) header(fieldIDSize, fieldOffsetSize int) byte { + // Header is one byte: ABCCDDEE + // * A: Unused + // * B: Is Large: whether there are more than 255 elements in this object or not. + // * C: Field ID Size Minus One: the number of bytes (minus one) used to encode each Field ID + // * D: Field Offset Size Minus One: the number of bytes (minus one) used to encode each Field Offset + // * E: 0x02: the identifier of the Object basic type + hdr := byte(fieldOffsetSize - 1) + hdr |= byte((fieldIDSize - 1) << 2) + if isLarge(o.numItems) { + hdr |= byte(1 << 4) + } + + // Basic type is the lower two bits of the header. Shift the Object specific bits over 2. + hdr <<= 2 + hdr |= byte(BasicObject) + return hdr +} + +type objectData struct { + size int + numElements int + firstFieldIDIdx int + firstOffsetIdx int + firstDataIdx int + fieldIDWidth int + offsetWidth int +} + +// Parses object data from a marshaled object (where the different encoded sections start, +// plus size in bytes and number of elements), plus ensures that the entire object is present +// in the raw buffer. +func getObjectData(raw []byte, offset int) (*objectData, error) { + if err := checkBounds(raw, offset, offset); err != nil { + return nil, err + } + + hdr := raw[offset] + if bt := BasicTypeFromHeader(hdr); bt != BasicObject { + return nil, fmt.Errorf("not an object: %s", bt) + } + + // Get the size of all encoded metadata fields. Bitshift by two to expose the 5 raw value header bits. + hdr >>= 2 + + offsetWidth := int(hdr&0x03) + 1 + fieldIDWidth := int((hdr>>2)&0x03) + 1 + + numElementsWidth := 1 + if hdr&0x08 != 0 { + numElementsWidth = 4 + } + + numElements, err := readUint(raw, offset+1, numElementsWidth) + if err != nil { + return nil, fmt.Errorf("could not get number of elements: %v", err) + } + + firstFieldIDIdx := offset + 1 + numElementsWidth // Header plus width of # of elements + firstOffsetIdx := firstFieldIDIdx + int(numElements)*fieldIDWidth + firstDataIdx := firstOffsetIdx + int(numElements+1)*offsetWidth + lastDataOffset, err := readUint(raw, firstDataIdx-offsetWidth, offsetWidth) + if err != nil { + return nil, fmt.Errorf("could not read last offset: %v", err) + } + lastDataIdx := firstDataIdx + int(lastDataOffset) + + // Also do some bounds checking to ensure that the entire object is represented in the raw buffer. 
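+ // The last entry in the offset table gives the total size of the value data, so
+ // lastDataIdx marks the end of the entire encoded object.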
+ if err := checkBounds(raw, offset, int(lastDataIdx)); err != nil { + return nil, fmt.Errorf("object is out of bounds: %v", err) + } + return &objectData{ + size: lastDataIdx - offset, + numElements: int(numElements), + firstFieldIDIdx: firstFieldIDIdx, + firstOffsetIdx: firstOffsetIdx, + firstDataIdx: firstDataIdx, + fieldIDWidth: fieldIDWidth, + offsetWidth: offsetWidth, + }, nil +} + +// Unmarshals a Variant object into the provided destination. The destination must be a pointer +// to one of three types: +// - A struct (unmarshal will map the Variant fields to struct fields by name, or contents of the `variant` annotation) +// - A string-keyed map (the passed in map will be cleared) +// - The "any" type. This will be returned as a map[string]any +func unmarshalObject(raw []byte, md *decodedMetadata, offset int, destPtr reflect.Value) error { + data, err := getObjectData(raw, offset) + if err != nil { + return err + } + + if kind := destPtr.Kind(); kind != reflect.Pointer { + return fmt.Errorf("invalid dest, must be non-nil pointer (got kind %s)", kind) + } + if destPtr.IsNil() { + return errors.New("invalid dest, must be non-nil pointer") + } + + destElem := destPtr.Elem() + + switch kind := destElem.Kind(); kind { + case reflect.Struct: + // Nothing to do. + case reflect.Interface: + // Create a new map[string]any + newMap := reflect.MakeMap(reflect.TypeOf(map[string]any{})) + destElem.Set(newMap) + destElem = newMap + case reflect.Map: + if keyKind := destElem.Type().Key().Kind(); keyKind != reflect.String { + return fmt.Errorf("invalid dest map- must have a string for a key, got %s", keyKind) + } + // Clear out the map to start fresh. + destElem.Clear() + default: + return fmt.Errorf("invalid kind- must be a string-keyed map, struct, or any, got %s", kind) + } + + destType := destElem.Type() + + // For slightly faster lookups, preprocess the struct to get a mapping from field name to field ID. + // We only care about settable struct fields. + fieldIDMap := make(map[string]int) + if destElem.Kind() == reflect.Struct { + for i := range destType.NumField() { + structField := destElem.Field(i) + if structField.CanSet() { + fieldName, _ := extractFieldInfo(destType.Field(i)) + fieldIDMap[fieldName] = i + + // Zero out the field if possible to start fresh. + structField.Set(reflect.Zero(structField.Type())) + } + } + } + + // Iterate through all elements in the encoded Variant. + for i := range data.numElements { + variantFieldID, err := readUint(raw, data.firstFieldIDIdx+data.fieldIDWidth*i, data.fieldIDWidth) + if err != nil { + return err + } + variantKey, ok := md.At(int(variantFieldID)) + if !ok { + return fmt.Errorf("key ID %d not present in metadata dictionary", i) + } + + // Get the new element value depending on whether this is a struct or a map + var newElemValue reflect.Value + if destElem.Kind() == reflect.Struct { + // Get pointer to the field within the struct. + structFieldID, ok := fieldIDMap[variantKey] + if !ok { + continue + } + field := destElem.Field(structFieldID) + newElemValue = field.Addr() + } else { + // New element within the map. + newElemValue = reflect.New(destType.Elem()) + } + + // Set the element value based on what's encoded in the Variant. 
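+ // Each element's value begins at firstDataIdx plus its entry in the offset table.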
+ elemOffset, err := readUint(raw, data.firstOffsetIdx+data.offsetWidth*i, data.offsetWidth) + if err != nil { + return err + } + dataIdx := int(elemOffset) + data.firstDataIdx + if err := checkBounds(raw, dataIdx, dataIdx); err != nil { + return err + } + if err := unmarshalCommon(raw, md, dataIdx, newElemValue); err != nil { + return err + } + + // Structs already have a pointer to the value and are set. For maps, set the value here. + if destElem.Kind() == reflect.Map { + destElem.SetMapIndex(reflect.ValueOf(variantKey), newElemValue.Elem()) + } + } + + return nil +} diff --git a/parquet/variants/object_test.go b/parquet/variants/object_test.go new file mode 100644 index 00000000..de53d5e9 --- /dev/null +++ b/parquet/variants/object_test.go @@ -0,0 +1,711 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variants + +import ( + "bytes" + "reflect" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestObjectFromStruct(t *testing.T) { + var buf bytes.Buffer + mdb := newMetadataBuilder() + ob := newObjectBuilder(&buf, mdb, nil) + + type nested struct { + Baz string `variant:"d"` + } + + st := struct { + Foo int `variant:"b"` + Bar bool `variant:"a"` + unexported int `variant:"c"` + Nest nested + }{1, true, 2, nested{Baz: "hi"}} + + if err := ob.fromStruct(&st); err != nil { + t.Fatalf("fromStruct(): %v", err) + } + if err := ob.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + + wantBytes := []byte{ + 0x02, // Basic object, all sizes = 1 byte + 0x03, // 3 items + // Keys are inserted in the order "b", "a", "d", "Nest". "Nest" comes first + // since N (value 78) < a (value 97). + 0x03, // Nested's index + 0x01, // a's index + 0x00, // b's index + 0x03, // Nest's value + 0x02, // a's value + 0x00, // b's value + 0x0B, // end + 0b1100, 0x01, // b's value, int8 val = 1 + 0b0100, // a's value + // Nested object + 0x02, // Basic object, all sizes = 1 byte + 0x01, // 1 item + 0x02, // d's index in the dictionary + 0x00, // d's offset + 0x03, // last item + 0b1001, 'h', 'i', // Basic short string, length 2 + // End of nested object + } + got := buf.Bytes() + + diffByteArrays(t, got, wantBytes) + + // Check the metadata keys as well to ensure the struct tags were picked up approrpiately. + encodedMD := mdb.Build() + decodedMetadata, err := decodeMetadata(encodedMD) + if err != nil { + t.Fatalf("Metadata decode error: %v", err) + } + gotKeys, wantKeys := decodedMetadata.keys, []string{"b", "a", "d", "Nest"} + if diff := cmp.Diff(gotKeys, wantKeys); diff != "" { + t.Fatalf("Incorrect metadata keys. 
Diff (-got +want):\n%s", diff) + } +} + +func TestObjectFromMap(t *testing.T) { + var buf bytes.Buffer + mdb := newMetadataBuilder() + ob := newObjectBuilder(&buf, mdb, nil) + + toWrite := map[string]any{ + "key1": int64(1), + "key2": false, + "key3": int64(2), + } + if err := ob.fromMap(toWrite); err != nil { + t.Fatalf("fromMap(): %v", err) + } + if err := ob.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + decodedMD, err := decodeMetadata(mdb.Build()) + if err != nil { + t.Fatalf("decodeMetadata(): %v", err) + } + + // Decode into a map and ensure things are correct. Can't really compare bytes here since the + // iteration order of a map is undefined. + got := map[string]any{} + dest := reflect.ValueOf(&got) + if err := unmarshalObject(buf.Bytes(), decodedMD, 0, dest); err != nil { + t.Fatalf("unmarshalObject(): %v", err) + } + diff(t, got, toWrite) +} + +func TestObjectPrimitive(t *testing.T) { + var buf bytes.Buffer + mdb := newMetadataBuilder() + + ob := newObjectBuilder(&buf, mdb, nil) + if err := ob.Write("b", 3); err != nil { + t.Fatalf("Write(b): %v", err) + } + if err := ob.Write("a", 1); err != nil { + t.Fatalf("Write(a): %v", err) + } + if err := ob.Write("c", 2); err != nil { + t.Fatalf("Write(c): %v", err) + } + + if err := ob.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + + wantBytes := []byte{ + 0b00000010, // Basic type = object, field_offset_size & field_id_size = 1, is_large = false + 0x03, // 3 elements + 0x01, // Field ID of "a" + 0x00, // Field ID of "b" + 0x02, // Field ID of "c" + 0x02, // Index of "a"s value (1) + 0x00, // Index of "b"s value (3) + 0x04, // Index of "c"s value (2) + 0x06, // First index after elements + 0b1100, // "b"s header- basic Int8 + 0x03, // "b"'s value of 3 + 0b1100, // "a"s header- basic Int8 + 0x01, // "a"s value of 1 + 0b1100, // "c"s header- basic Int8 + 0x02, // "c"s value of 2 + } + + diff(t, buf.Bytes(), wantBytes) + + // Check the metadata keys as well. 
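+ // NB: the metadata dictionary preserves insertion order ("b", "a", "c"); only the
+ // object's field ID list above is sorted by key.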
+ encodedMD := mdb.Build() + decodedMetadata, err := decodeMetadata(encodedMD) + if err != nil { + t.Fatalf("Metadata decode error: %v", err) + } + gotKeys, wantKeys := decodedMetadata.keys, []string{"b", "a", "c"} + diff(t, gotKeys, wantKeys) +} + +func TestObjectNestedArray(t *testing.T) { + var buf bytes.Buffer + mdb := newMetadataBuilder() + ob := newObjectBuilder(&buf, mdb, nil) + if err := ob.Write("b", 3); err != nil { + t.Fatalf("Write(b): %v", err) + } + + arr, err := ob.Array("a") + if err != nil { + t.Fatalf("Array(a): %v", err) + } + for _, val := range []any{true, 123} { + if err := arr.Write(val); err != nil { + t.Fatalf("arr.Write(%v): %v", val, err) + } + } + arr.Build() + ob.Write("c", 8) + + if err := ob.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + + wantBytes := []byte{ + 0b00000010, // Basic type = object, field_offset_size & field_id_size = 1, is_large = false + 0x03, // 3 elements + 0x01, // Field ID of "a" + 0x00, // Field ID of "b" + 0x02, // Field ID of "c" + 0x02, // Index of "a"s value (array {true, 123}) + 0x00, // Index of "b"s value (3) + 0x0A, // Index of "c"s value (8) + 0x0C, // First index after the last value + 0b1100, // "b"s header- basic Int8 + 0x03, // "b"s value of 3 + // Beginning of array + 0b011, // "a"s header- basic array + 0x02, // 2 elements in "a"s array + 0x00, // Index of first element (true) + 0x01, // Index of second element (123) + 0x03, // First index after the last value + 0b100, // First element (basic true- header and value) + 0b1100, // Second element- basic Int8 + 0x7B, // 123 encoded + // End of array + 0b1100, // "c"s header- basic Int8 + 0x08, // "c"s value of 8 + } + diffByteArrays(t, buf.Bytes(), wantBytes) + + // Check the metadata keys as well. + encodedMD := mdb.Build() + decodedMetadata, err := decodeMetadata(encodedMD) + if err != nil { + t.Fatalf("Metadata decode error: %v", err) + } + gotKeys, wantKeys := decodedMetadata.keys, []string{"b", "a", "c"} + diff(t, gotKeys, wantKeys) +} + +func TestObjectNestedObjectSharingKeys(t *testing.T) { + var buf bytes.Buffer + mdb := newMetadataBuilder() + ob := newObjectBuilder(&buf, mdb, nil) + + if err := ob.Write("a", true); err != nil { + t.Fatalf("Write(a): %v", err) + } + + nestedOb, err := ob.Object("b") + if err != nil { + t.Fatalf("Object(b): %v", err) + } + + // Same key can exist in a nested object + if err := nestedOb.Write("b", 123); err != nil { + t.Fatalf("Nested Object Write(b): %v", err) + } + nestedOb.Build() + + if err := ob.Build(); err != nil { + t.Fatalf("Object Build(): %v", err) + } + + wantBytes := []byte{ + 0b10, // Basic type = object, field_offset_size & field_id_size = 1, is_large = false + 0x02, // 2 elements + 0x00, // Field ID of "a", + 0x01, // Field ID of "b", + 0x00, // Index of first element (true) + 0x01, // Index of second element (nested object) + 0x08, // First index after the last value + 0b100, // "a"s header & value (basic true) + // Beginning of nested object + 0b10, // "b"s header- basic object, field_offset_size & field_id_size = 1, is_large = false + 0x01, // 1 element + 0x01, // Field ID of "b" + 0x00, // Index of "b"s value (123) + 0x02, // First index after the last value + 0b1100, // "b"s header- basic Int8 + 0x7B, // 123 encoded + // End of nested object + } + diffByteArrays(t, buf.Bytes(), wantBytes) + + // Check the metadata keys as well. 
+ encodedMD := mdb.Build() + decodedMetadata, err := decodeMetadata(encodedMD) + if err != nil { + t.Fatalf("Metadata decode error: %v", err) + } + gotKeys, wantKeys := decodedMetadata.keys, []string{"a", "b"} + diff(t, gotKeys, wantKeys) +} + +func TestObjectData(t *testing.T) { + cases := []struct { + name string + encoded []byte + offset int + want *objectData + wantErr bool + }{ + { + name: "Basic object no offset", + encoded: []byte{ + 0b00000010, // Basic type = object, field_offset_size & field_id_size = 1, is_large = false + 0x03, // 3 elements + 0x01, // Field ID of "a" + 0x00, // Field ID of "b" + 0x02, // Field ID of "c" + 0x02, // Index of "a"s value (1) + 0x00, // Index of "b"s value (3) + 0x04, // Index of "c"s value (2) + 0x06, // First index after elements + 0b1100, // "b"s header- basic Int8 + 0x03, // "b"'s value of 3 + 0b1100, // "a"s header- basic Int8 + 0x01, // "a"s value of 1 + 0b1100, // "c"s header- basic Int8 + 0x02, // "c"s value of 2 + }, + want: &objectData{ + size: 15, + numElements: 3, + firstFieldIDIdx: 2, + firstOffsetIdx: 5, + firstDataIdx: 9, + fieldIDWidth: 1, + offsetWidth: 1, + }, + }, + { + name: "Basic object with offset", + encoded: []byte{ + 0x00, 0b00000010, 0x03, 0x01, 0x00, 0x02, 0x02, 0x00, + 0x04, 0x06, 0b1100, 0x03, 0b1100, 0x01, 0b1100, 0x02, + }, + offset: 1, + want: &objectData{ + size: 15, + numElements: 3, + firstFieldIDIdx: 3, + firstOffsetIdx: 6, + firstDataIdx: 10, + fieldIDWidth: 1, + offsetWidth: 1, + }, + }, + { + name: "Object with larger widths", + encoded: []byte{ + 0b01100110, // Basic type = object, field offset size = 2, field ID size = 3, is_large = true, + 0x03, 0x00, 0x00, 0x00, // 3 elements (encoded in 4 bytes due to is_large) + 0x01, 0x00, 0x00, // Field ID of "a" + 0x00, 0x00, 0x00, // Field ID of "b" + 0x02, 0x00, 0x00, // Field ID of "c" + 0x02, 0x00, // Index of "a"s value (1) + 0x00, 0x00, // Index of "b"s value (3) + 0x04, 0x00, // Index of "c"s value (2) + 0x06, 0x00, // First index after elements + 0b1100, 0x03, // Basic Int8 value of 3 + 0b1100, 0x01, // Basic Int8 value of 1 + 0b1100, 0x02, // Basic Int8 value of 2 + }, + want: &objectData{ + size: 28, + numElements: 3, + firstFieldIDIdx: 5, + firstOffsetIdx: 14, + firstDataIdx: 22, + fieldIDWidth: 3, + offsetWidth: 2, + }, + }, + { + name: "Incorrect basic type", + encoded: []byte{0x00, 0x00, 0x00, 0x00, 0x00}, + wantErr: true, + }, + { + name: "Num elements out of bounds", + encoded: []byte{0x02}, + wantErr: true, + }, + { + name: "Object out of bounds", + encoded: []byte{0x02, 0x03, 0x01, 0x00, 0x02, 0x02, 0x00, 0x04, 0x06}, + wantErr: true, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := getObjectData(c.encoded, c.offset) + if err != nil { + if c.wantErr { + return + } + t.Fatalf("Unexpected error: %v", err) + } else if c.wantErr { + t.Fatal("Got no error when one was expected") + } + if diff := cmp.Diff(got, c.want, cmp.AllowUnexported(objectData{})); diff != "" { + t.Fatalf("Incorrect returned data. Diff (-got, +want)\n%s", diff) + } + }) + } +} + +func TestUnmarshalObject(t *testing.T) { + cases := []struct { + name string + md *decodedMetadata + encoded []byte + offset int + // Normally the test will unmarhall into the type in want, but if this override is set, + // it'll try to unmarshal into something of this type. 
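+ // For example, forcing a decode into map[int]any exercises the string-key requirement.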
+ overrideDecodeType reflect.Type + want any + wantErr bool + }{ + { + name: "Object built of primitives into map", + md: &decodedMetadata{keys: []string{"b", "a", "c"}}, + encoded: []byte{ + 0b01111110, // Basic type = object, field offset & field ID size = 4, is_large = true, + 0x03, 0x00, 0x00, 0x00, // 3 elements (encoded in 4 bytes due to is_large) + 0x01, 0x00, 0x00, 0x00, // Field ID of "a" + 0x00, 0x00, 0x00, 0x00, // Field ID of "b" + 0x02, 0x00, 0x00, 0x00, // Field ID of "c" + 0x02, 0x00, 0x00, 0x00, // Index of "a"s value (1) + 0x00, 0x00, 0x00, 0x00, // Index of "b"s value (3) + 0x04, 0x00, 0x00, 0x00, // Index of "c"s value (2) + 0x06, 0x00, 0x00, 0x00, // First index after elements + 0b1100, 0x03, // Basic Int8 value of 3 + 0b1100, 0x01, // Basic Int8 value of 1 + 0b1100, 0x02, // Basic Int8 value of 2 + }, + want: map[string]any{ + "a": int64(1), + "b": int64(3), + "c": int64(2), + }, + }, + { + name: "Object built of primitives into map with offset", + md: &decodedMetadata{keys: []string{"b", "a", "c"}}, + offset: 3, + encoded: []byte{ + 0x00, 0x00, 0x00, // 3 offset bytes + 0b01111110, // Basic type = object, field offset & field ID size = 4, is_large = true, + 0x03, 0x00, 0x00, 0x00, // 3 elements (encoded in 4 bytes due to is_large) + 0x01, 0x00, 0x00, 0x00, // Field ID of "a" + 0x00, 0x00, 0x00, 0x00, // Field ID of "b" + 0x02, 0x00, 0x00, 0x00, // Field ID of "c" + 0x02, 0x00, 0x00, 0x00, // Index of "a"s value (1) + 0x00, 0x00, 0x00, 0x00, // Index of "b"s value (3) + 0x04, 0x00, 0x00, 0x00, // Index of "c"s value (2) + 0x06, 0x00, 0x00, 0x00, // First index after elements + 0b1100, 0x03, // Basic Int8 value of 3 + 0b1100, 0x01, // Basic Int8 value of 1 + 0b1100, 0x02, // Basic Int8 value of 2 + }, + want: map[string]any{ + "a": int64(1), + "b": int64(3), + "c": int64(2), + }, + }, + { + name: "Object built of primitives into struct", + md: &decodedMetadata{keys: []string{"b", "a", "c"}}, + encoded: []byte{ + 0b01111110, // Basic type = object, field offset & field ID size = 4, is_large = true, + 0x03, 0x00, 0x00, 0x00, // 3 elements (encoded in 4 bytes due to is_large) + 0x01, 0x00, 0x00, 0x00, // Field ID of "a" + 0x00, 0x00, 0x00, 0x00, // Field ID of "b" + 0x02, 0x00, 0x00, 0x00, // Field ID of "c" + 0x02, 0x00, 0x00, 0x00, // Index of "a"s value (1) + 0x00, 0x00, 0x00, 0x00, // Index of "b"s value (3) + 0x04, 0x00, 0x00, 0x00, // Index of "c"s value (2) + 0x06, 0x00, 0x00, 0x00, // First index after elements + 0b1100, 0x03, // Basic Int8 value of 3 + 0b1100, 0x01, // Basic Int8 value of 1 + 0b1100, 0x02, // Basic Int8 value of 2 + }, + want: struct { + A int `variant:"a"` + B int `variant:"b"` + C int `variant:"c"` + }{1, 3, 2}, + }, + { + name: "Complex object into map", + md: &decodedMetadata{keys: []string{"key1", "key2", "array", "otherkey"}}, + encoded: func() []byte { + var buf bytes.Buffer + mdb := newMetadataBuilder() + + builder := newObjectBuilder(&buf, mdb, nil) + + builder.Write("key1", 123) + builder.Write("key2", "hello") + ab, err := builder.Array("array") + if err != nil { + t.Fatalf("Array('array'): %v", err) + } + ab.Write(false) + ab.Write("substr") + ab.Build() + + builder.Write("otherkey", []byte{'b', 'y', 't', 'e'}) + if err := builder.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + return buf.Bytes() + }(), + want: map[string]any{ + "key1": int64(123), + "key2": "hello", + "array": []any{false, "substr"}, + "otherkey": []byte{'b', 'y', 't', 'e'}, + }, + }, + { + name: "Complex object into struct", + md: &decodedMetadata{keys: 
[]string{"key1", "key2", "array", "otherkey"}}, + encoded: func() []byte { + var buf bytes.Buffer + mdb := newMetadataBuilder() + + builder := newObjectBuilder(&buf, mdb, nil) + + builder.Write("key1", 123) + builder.Write("key2", "hello") + ab, err := builder.Array("array") + if err != nil { + t.Fatalf("Array('array'): %v", err) + } + ab.Write(false) + ab.Write("substr") + ab.Build() + + builder.Write("otherkey", []byte{'b', 'y', 't', 'e'}) + if err := builder.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + return buf.Bytes() + }(), + want: struct { + K1 int `variant:"key1"` + K2 string `variant:"key2"` + Arr []any `variant:"array"` + Other []byte `variant:"otherkey"` + }{123, "hello", []any{false, "substr"}, []byte{'b', 'y', 't', 'e'}}, + }, + { + name: "Unmarshal skips non-present fields in struct", + md: &decodedMetadata{keys: []string{"key1", "key2", "key3"}}, + encoded: func() []byte { + var buf bytes.Buffer + mdb := newMetadataBuilder() + builder := newObjectBuilder(&buf, mdb, nil) + builder.Write("key1", 123) + builder.Write("key2", "hello") + builder.Write("key3", false) + if err := builder.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + return buf.Bytes() + }(), + want: struct { + K1 int `variant:"key1"` + // key2 is undefined + K3 bool `variant:"key3"` + }{123, false}, + }, + { + name: "Unmarshal into typed map", + md: &decodedMetadata{keys: []string{"key1", "key2", "key3"}}, + encoded: func() []byte { + var buf bytes.Buffer + mdb := newMetadataBuilder() + builder := newObjectBuilder(&buf, mdb, nil) + builder.Write("key1", 123) + builder.Write("key2", 234) + builder.Write("key3", 345) + if err := builder.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + return buf.Bytes() + }(), + want: map[string]int{ + "key1": 123, + "key2": 234, + "key3": 345, + }, + }, + { + name: "Malformed raw data", + md: &decodedMetadata{keys: []string{"key1", "key2", "key3"}}, + encoded: func() []byte { + var buf bytes.Buffer + mdb := newMetadataBuilder() + builder := newObjectBuilder(&buf, mdb, nil) + builder.Write("key1", 123) + builder.Write("key2", 234) + builder.Write("key3", 345) + if err := builder.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + return buf.Bytes()[1:] // Lop off the first byte + }(), + overrideDecodeType: reflect.TypeOf(map[string]any{}), + wantErr: true, + }, + { + name: "Maps must be string keyed", + md: &decodedMetadata{keys: []string{"key1", "key2", "key3"}}, + encoded: func() []byte { + var buf bytes.Buffer + mdb := newMetadataBuilder() + builder := newObjectBuilder(&buf, mdb, nil) + builder.Write("key1", 123) + builder.Write("key2", 234) + builder.Write("key3", 345) + if err := builder.Build(); err != nil { + t.Fatalf("Build(): %v", err) + } + return buf.Bytes() + }(), + overrideDecodeType: reflect.TypeOf(map[int]any{}), + wantErr: true, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + var got reflect.Value + typ := reflect.TypeOf(c.want) + if c.overrideDecodeType != nil { + typ = c.overrideDecodeType + } + if typ.Kind() == reflect.Map { + underlying := reflect.MakeMap(typ) + got = reflect.New(typ) + got.Elem().Set(underlying) + } else { + got = reflect.New(typ) + } + if err := unmarshalObject(c.encoded, c.md, c.offset, got); err != nil { + if c.wantErr { + return + } + t.Fatalf("Unexpected error: %v", err) + } else if c.wantErr { + t.Fatalf("Got no error when one was expected") + } + diff(t, got.Elem().Interface(), c.want) + }) + } +} + +func TestExtractFieldInfo(t *testing.T) { + type testStruct struct 
{ + NoTags int + JustName int `variant:"just_name"` + EmptyTag int `variant:""` + WithOpts int `variant:"with_opts,ntz,date,nanos,time"` + OptsNoName int `variant:",uuid"` + UnknownOpt int `variant:"unknown,not_defined_opt"` + } + cases := []struct { + name string + field int + wantName string + wantOpts []MarshalOpts + }{ + { + name: "No tags", + field: 0, + wantName: "NoTags", + }, + { + name: "Field tag with just name", + field: 1, + wantName: "just_name", + }, + { + name: "Empty tag uses struct field name", + field: 2, + wantName: "EmptyTag", + }, + { + name: "Field tag with name and options", + field: 3, + wantName: "with_opts", + wantOpts: []MarshalOpts{MarshalTimeNTZ, MarshalAsDate, MarshalTimeNanos, MarshalAsTime}, + }, + { + name: "Just options, no name uses struct field name", + field: 4, + wantName: "OptsNoName", + wantOpts: []MarshalOpts{MarshalAsUUID}, + }, + { + name: "Ignore unknown options", + field: 5, + wantName: "unknown", + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + val := reflect.ValueOf(testStruct{}) + gotName, gotOpts := extractFieldInfo(val.Type().Field(c.field)) + if gotName != c.wantName { + t.Errorf("Incorrect name. Got %q, want %q", gotName, c.wantName) + } + diff(t, gotOpts, c.wantOpts) + }) + } +} diff --git a/parquet/variants/primitive.go b/parquet/variants/primitive.go new file mode 100644 index 00000000..8ab48784 --- /dev/null +++ b/parquet/variants/primitive.go @@ -0,0 +1,678 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variants + +import ( + "fmt" + "io" + "math" + "reflect" + "strings" + "time" +) + +// Variant primitive type IDs. 
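+// These values are intended to mirror the primitive type IDs from the Variant
+// encoding spec; they occupy the upper six bits of a primitive value header
+// (see primitiveHeader below).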
+type primitiveType int + +const ( + primitiveInvalid primitiveType = -1 + primitiveNull primitiveType = 0 + primitiveTrue primitiveType = 1 + primitiveFalse primitiveType = 2 + primitiveInt8 primitiveType = 3 + primitiveInt16 primitiveType = 4 + primitiveInt32 primitiveType = 5 + primitiveInt64 primitiveType = 6 + primitiveDouble primitiveType = 7 + primitiveDecimal4 primitiveType = 8 // TODO + primitiveDecimal8 primitiveType = 9 // TODO + primitiveDecimal16 primitiveType = 10 // TODO + primitiveDate primitiveType = 11 + primitiveTimestampMicros primitiveType = 12 + primitiveTimestampNTZMicros primitiveType = 13 + primitiveFloat primitiveType = 14 + primitiveBinary primitiveType = 15 + primitiveString primitiveType = 16 + primitiveTimeNTZ primitiveType = 17 + primitiveTimestampNanos primitiveType = 18 + primitiveTimestampNTZNanos primitiveType = 19 + primitiveUUID primitiveType = 20 +) + +func (pt primitiveType) String() string { + switch pt { + case primitiveNull: + return "Null" + case primitiveFalse, primitiveTrue: + return "Boolean" + case primitiveInt8: + return "Int8" + case primitiveInt16: + return "Int16" + case primitiveInt32: + return "Int32" + case primitiveInt64: + return "Int64" + case primitiveDouble: + return "Double" + case primitiveDecimal4: + return "Decimal4" + case primitiveDecimal8: + return "Decimal8" + case primitiveDecimal16: + return "Decimal16" + case primitiveDate: + return "Date" + case primitiveTimestampMicros: + return "Timestamp(micros)" + case primitiveTimestampNTZMicros: + return "TimestampNTZ(micros)" + case primitiveFloat: + return "Float" + case primitiveBinary: + return "Binary" + case primitiveString: + return "String" + case primitiveTimeNTZ: + return "TimeNTZ" + case primitiveTimestampNanos: + return "Timestamp(nanos)" + case primitiveTimestampNTZNanos: + return "TimestampNTZ(nanos)" + case primitiveUUID: + return "UUID" + } + return "Invalid" +} + +func validPrimitiveValue(prim primitiveType) error { + if prim < primitiveNull || prim > primitiveUUID { + return fmt.Errorf("invalid primitive type: %d", prim) + } + return nil +} + +func primitiveFromHeader(hdr byte) (primitiveType, error) { + // Special case the basic type of Short String and call it a Primitive String. + bt := BasicTypeFromHeader(hdr) + if bt == BasicShortString { + return primitiveString, nil + } else if bt == BasicPrimitive { + prim := primitiveType(hdr >> 2) + if err := validPrimitiveValue(prim); err != nil { + return primitiveInvalid, err + } + return prim, nil + } + return primitiveInvalid, fmt.Errorf("header is not of a primitive or short string basic type: %s", bt) +} + +func primitiveHeader(prim primitiveType) (byte, error) { + if err := validPrimitiveValue(prim); err != nil { + return 0, err + } + hdr := byte(prim << 2) + hdr |= byte(BasicPrimitive) + return hdr, nil +} + +// marshalPrimitive takes in a primitive value, asserts its type, then marshals the data according to the Variant spec +// into the provided writer, returning the number of bytes written. +// +// Time can be provided in various ways- either by a time.Time struct, or by an int64 when the EncodeAs{Date,Time,Timestamp} +// options are provided. By default, timestamps are written as microseconds- to use nanoseconds pass in EncodeTimeAsNanos. +// Timezone information can be determined from a time.Time struct. Otherwise, by default, timestamps will be written with +// local timezone set. 
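+//
+// Illustrative sketch of typical calls (sizes follow from the encodings described above):
+//
+//	var buf bytes.Buffer
+//	marshalPrimitive(int64(42), &buf)                    // encoded as Int8: 1 header byte + 1 value byte
+//	marshalPrimitive("hi", &buf)                         // short string: 1 header byte + 2 data bytes
+//	marshalPrimitive(time.Now(), &buf, MarshalTimeNanos) // timestamp in nanoseconds: 1 header byte + 8 value bytes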
+func marshalPrimitive(v any, w io.Writer, opts ...MarshalOpts) (int, error) { + var allOpts MarshalOpts + for _, o := range opts { + allOpts |= o + } + switch val := v.(type) { + case bool: + return marshalBoolean(val, w), nil + case int: + return marshalInt(int64(val), w), nil + case int8: + return marshalInt(int64(val), w), nil + case int16: + return marshalInt(int64(val), w), nil + case int32: + return marshalInt(int64(val), w), nil + case int64: + if allOpts&MarshalAsTime != 0 { + encodeTimestamp(val, allOpts&MarshalTimeNanos != 0, allOpts&MarshalTimeNTZ != 0, w) + } + return marshalInt(val, w), nil + case float32: + return marshalFloat(val, w), nil + case float64: + return marshalDouble(val, w), nil + case string: + if allOpts&MarshalAsUUID != 0 { + return marshalUUID([]byte(val), w), nil + } + return marshalString(val, w), nil + case []byte: + if allOpts&MarshalAsUUID != 0 { + return marshalUUID([]byte(val), w), nil + } + return marshalBinary(val, w), nil + case time.Time: + if allOpts&MarshalAsDate != 0 { + return marshalDate(val, w), nil + } + return marshalTimestamp(val, allOpts&MarshalTimeNanos != 0, w), nil + } + if v == nil { + return marshalNull(w), nil + } + return -1, fmt.Errorf("unsupported primitive type") +} + +// unmarshals a primitive (or a short string) into dest. dest must be a non-nil pointer to variable that is +// compatible with the Variant value to decode. Some conversions can take place: +// - Integer values: Higher widths can be decoded into smaller widths so long as they don't overflow. Also, +// integral values can be decoded into floats (also so long as they don't overflow). +// - Time/timestamps: Can be decoded into either int64 or time.Time, the latter of which will carry time zone information +// - Strings and binary: Can be decoded into either string or []byte +// +// If the Variant primitive is of the Null type, dest will be set to its zero value. +func unmarshalPrimitive(raw []byte, offset int, destPtr reflect.Value) error { + dest := destPtr.Elem() + kind := dest.Kind() + + if err := checkBounds(raw, offset, offset); err != nil { + return err + } + + prim, err := primitiveFromHeader(raw[offset]) + if err != nil { + return err + } + + switch prim { + case primitiveNull: + dest.Set(reflect.Zero(dest.Type())) + case primitiveTrue, primitiveFalse: + if kind != reflect.Bool && kind != reflect.Interface { + return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) + } + dest.Set(reflect.ValueOf(prim == primitiveTrue)) + case primitiveInt8, primitiveInt16, primitiveInt32, primitiveInt64: + iv, err := decodeIntPhysical(raw, offset) + if err != nil { + return err + } + if kind == reflect.Interface { + dest.Set(reflect.ValueOf(iv)) + } else if dest.CanInt() { + if dest.OverflowInt(iv) { + return fmt.Errorf("int value of %d will overflow dest", iv) + } + switch kind { + case reflect.Int: + dest.Set(reflect.ValueOf(int(iv))) + case reflect.Int8: + dest.Set(reflect.ValueOf(int8(iv))) + case reflect.Int16: + dest.Set(reflect.ValueOf(int16(iv))) + case reflect.Int32: + dest.Set(reflect.ValueOf(int32(iv))) + case reflect.Int64: + dest.Set(reflect.ValueOf(iv)) + default: + panic("unhandled int value") + } + } else if dest.CanFloat() { + // Converting from an int64 to a float64 can potentially lose precision, but it's still a valid + // conversion and can be supported here. 
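+ // (For example, 1<<53 + 1 cannot be represented exactly as a float64.)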
+ fv := float64(iv) + if dest.OverflowFloat(fv) { + return fmt.Errorf("value of %d will overflow dest", iv) + } + switch kind { + case reflect.Float32: + dest.Set(reflect.ValueOf(float32(fv))) + case reflect.Float64: + dest.Set(reflect.ValueOf(fv)) + default: + panic("unhandled float value") + } + } else { + return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) + } + case primitiveFloat: + fv, err := unmarshalFloat(raw, offset) + if err != nil { + return err + } + if !dest.CanFloat() && kind != reflect.Interface { + return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) + } + switch kind { + case reflect.Float32, reflect.Interface: + dest.Set(reflect.ValueOf(fv)) + case reflect.Float64: + dest.Set(reflect.ValueOf(float64(fv))) + default: + panic("unhandled float value") + } + case primitiveDouble: + dv, err := unmarshalDouble(raw, offset) + if err != nil { + return err + } + if !dest.CanFloat() && kind != reflect.Interface { + return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) + } + switch kind { + case reflect.Float32: + if dest.OverflowFloat(dv) { + return fmt.Errorf("value of %f will overflow dest", dv) + } + dest.Set(reflect.ValueOf(float32(dv))) + case reflect.Float64, reflect.Interface: + dest.Set(reflect.ValueOf(dv)) + default: + panic("unhandled float value") + } + case primitiveTimeNTZ, primitiveTimestampMicros, primitiveTimestampNTZMicros, + primitiveTimestampNanos, primitiveTimestampNTZNanos: + tsv, err := readUint(raw, offset+1, 8) + if err != nil { + return err + } + + // Time can be decoded into either an int64 (the physical time), or into a time.Time struct. + // Anything else is invalid. + if kind == reflect.Int64 || kind == reflect.Interface { + dest.Set(reflect.ValueOf(int64(tsv))) + } else if kind == reflect.Uint64 { + dest.Set(reflect.ValueOf(tsv)) + } else if dest.Type() == reflect.TypeOf(time.Time{}) { + var t time.Time + if prim == primitiveTimeNTZ { + // TimeNTZ for Variants is UTC=false (ie. 
local timezone) and in microseconds + t = time.Date(0, 0, 0, 0, 0, 0, 1000*int(tsv), time.Local) + } else { + if prim == primitiveTimestampMicros || prim == primitiveTimestampNTZMicros { + t = time.UnixMicro(int64(tsv)) + } else { + sec := int64(tsv / 1e9) + nsec := int64(tsv % 1e9) + t = time.Unix(sec, nsec) + } + if prim == primitiveTimestampMicros || prim == primitiveTimestampNanos { + t = t.In(time.Local) + } + } + dest.Set(reflect.ValueOf(t)) + } else { + return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) + } + case primitiveString: + str, err := unmarshalString(raw, offset) + if err != nil { + return err + } + if kind == reflect.String || kind == reflect.Interface { + dest.Set(reflect.ValueOf(str)) + } else if kind == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8 { + dest.Set(reflect.ValueOf([]byte(str))) + } else { + return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) + } + case primitiveBinary: + bytes, err := unmarshalBinary(raw, offset) + if err != nil { + return err + } + if kind == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8 || kind == reflect.Interface { + dest.Set(reflect.ValueOf(bytes)) + } else if kind == reflect.String { + dest.Set(reflect.ValueOf(string(bytes))) + } else { + return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) + } + case primitiveUUID: + bytes, err := unmarshalUUID(raw, offset) + if err != nil { + return err + } + if kind == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8 || kind == reflect.Interface { + dest.Set(reflect.ValueOf(bytes)) + } else if kind == reflect.String { + dest.Set(reflect.ValueOf(string(bytes))) + } else { + return fmt.Errorf("cannot decode Variant UUID into dest %s", kind) + } + default: + return fmt.Errorf("unknown primitive: %s", prim) + } + + return nil +} + +func marshalNull(w io.Writer) int { + hdr, _ := primitiveHeader(primitiveNull) + w.Write([]byte{hdr}) + return 1 +} + +func marshalBoolean(b bool, w io.Writer) int { + var hdr byte + if b { + hdr, _ = primitiveHeader(primitiveTrue) + } else { + hdr, _ = primitiveHeader(primitiveFalse) + } + w.Write([]byte{hdr}) + return 1 +} + +func unmarshalBoolean(raw []byte, offset int) (bool, error) { + prim, err := primitiveFromHeader(raw[offset]) + if err != nil { + return false, err + } + return prim == primitiveTrue, nil +} + +// Encodes an integer with the appropriate primitive header. This encodes the int +// into the minimal space necessary regardless of the width that's passed in (eg. 
an +// int64 of value 1 will be encoded into an int8) +func marshalInt(val int64, w io.Writer) int { + var hdr byte + var size int + if val < math.MaxInt8 && val > math.MinInt8 { + hdr, _ = primitiveHeader(primitiveInt8) + size = 1 + } else if val < math.MaxInt16 && val > math.MinInt16 { + hdr, _ = primitiveHeader(primitiveInt16) + size = 2 + } else if val < math.MaxInt32 && val > math.MinInt32 { + hdr, _ = primitiveHeader(primitiveInt32) + size = 4 + } else { + hdr, _ = primitiveHeader(primitiveInt64) + size = 8 + } + w.Write([]byte{hdr}) + encodeNumber(val, size, w) + return size + 1 +} + +func decodeIntPhysical(raw []byte, offset int) (int64, error) { + typ, _ := primitiveFromHeader(raw[offset]) + var size int + switch typ { + case primitiveInt8: + size = 1 + case primitiveInt16: + size = 2 + case primitiveInt32, primitiveDate: + size = 4 + case primitiveInt64: + size = 8 + default: + return -1, fmt.Errorf("not an integral type: %s", typ) + } + val, err := readInt(raw, offset+1, size) + if err != nil { + return -1, err + } + + // Do a conversion dance from the minimal width to int64 to catch + // negative numbers. + switch typ { + case primitiveInt8: + return int64(int8(val)), nil + case primitiveInt16: + return int64(int16(val)), nil + case primitiveInt32: + return int64(int32(val)), nil + default: + return int64(val), nil + } +} + +func marshalFloat(val float32, w io.Writer) int { + buf := make([]byte, 5) + hdr, _ := primitiveHeader(primitiveFloat) + buf[0] = hdr + bits := math.Float32bits(val) + for i := range 4 { + buf[i+1] = byte(bits) + bits >>= 8 + } + w.Write(buf) + return 5 +} + +func marshalDouble(val float64, w io.Writer) int { + buf := make([]byte, 9) + hdr, _ := primitiveHeader(primitiveDouble) + buf[0] = hdr + bits := math.Float64bits(val) + for i := range 8 { + buf[i+1] = byte(bits) + bits >>= 8 + } + w.Write(buf) + return 9 +} + +func unmarshalFloat(raw []byte, offset int) (float32, error) { + v, err := readUint(raw, offset+1, 4) + if err != nil { + return -1, err + } + return math.Float32frombits(uint32(v)), nil +} + +func unmarshalDouble(raw []byte, offset int) (float64, error) { + v, err := readUint(raw, offset+1, 8) + if err != nil { + return -1, err + } + return math.Float64frombits(v), nil +} + +func encodePrimitiveBytes(b []byte, w io.Writer) int { + encodeNumber(int64(len(b)), 4, w) + w.Write(b) + return len(b) + 4 +} + +func marshalString(str string, w io.Writer) int { + str = strings.ToValidUTF8(str, "\uFFFD") + + // If the string is 63 characters or less, encode this as a short string to save space. + strlen := len(str) + if strlen < 0x3F { + hdr := byte(strlen << 2) + hdr |= byte(BasicShortString) + w.Write([]byte{hdr}) + w.Write([]byte(str)) + return 1 + strlen + } + + // Otherwise, encode this as a basic string. + hdr, _ := primitiveHeader(primitiveString) + w.Write([]byte{hdr}) + return 1 + encodePrimitiveBytes([]byte(strings.ToValidUTF8(str, "\uFFFD")), w) +} + +func marshalUUID(uuid []byte, w io.Writer) int { + hdr, _ := primitiveHeader(primitiveUUID) + w.Write([]byte{hdr}) + + // A UUID is 16 bytes. Either pad or truncate to this length. + if len(uuid) > 16 { + uuid = uuid[:16] + } else if pad := 16 - len(uuid); pad > 0 { + uuid = append(uuid, make([]byte, pad)...) 
+ } + w.Write(uuid) + return 17 +} + +func unmarshalUUID(raw []byte, offset int) ([]byte, error) { + if err := checkBounds(raw, offset, offset+17); err != nil { + return nil, err + } + return raw[offset+1 : offset+17], nil +} + +func unmarshalString(raw []byte, offset int) (string, error) { + // Determine if the string is a short string, or a basic string. + maxPos := len(raw) + if offset >= maxPos { + return "", fmt.Errorf("offset is out of bounds: trying to access position %d, max position is %d", offset, maxPos) + } + bt := BasicTypeFromHeader(raw[offset]) + + if bt == BasicShortString { + l := int(raw[offset] >> 2) + endIdx := 1 + l + offset + if endIdx > maxPos { + return "", fmt.Errorf("end index is out of bounds: trying to access position %d, max position is %d", endIdx, maxPos) + } + return string(raw[offset+1 : endIdx]), nil + } + + b, err := getBytes(raw, offset+1) + if err != nil { + return "", err + } + return string(b), nil +} + +func getBytes(raw []byte, offset int) ([]byte, error) { + l, err := readUint(raw, offset, 4) + if err != nil { + return nil, fmt.Errorf("could not read length: %v", err) + } + maxIdx := offset + 4 + int(l) + if len(raw) < maxIdx { + return nil, fmt.Errorf("bytes are out of bounds") + } + return raw[offset+4 : maxIdx], nil +} + +func marshalBinary(b []byte, w io.Writer) int { + hdr, _ := primitiveHeader(primitiveBinary) + w.Write([]byte{hdr}) + return 1 + encodePrimitiveBytes(b, w) +} + +func unmarshalBinary(raw []byte, offset int) ([]byte, error) { + return getBytes(raw, offset+1) +} + +func marshalTimestamp(t time.Time, nanos bool, w io.Writer) int { + var typ primitiveType + var ts int64 + ntz := t.Location() == time.UTC + if nanos { + ts = t.UnixNano() + if ntz { + typ = primitiveTimestampNTZNanos + } else { + typ = primitiveTimestampNanos + } + } else { + ts = t.UnixMicro() + if ntz { + typ = primitiveTimestampNTZMicros + } else { + typ = primitiveTimestampMicros + } + } + hdr, _ := primitiveHeader(typ) + w.Write([]byte{hdr}) + encodeNumber(ts, 8, w) + return 9 +} + +func encodeTimestamp(t int64, nanos, ntz bool, w io.Writer) int { + var typ primitiveType + if nanos { + if ntz { + typ = primitiveTimestampNTZNanos + } else { + typ = primitiveTimestampNanos + } + } else { + if ntz { + typ = primitiveTimestampNTZMicros + } else { + typ = primitiveTimestampMicros + } + } + hdr, _ := primitiveHeader(typ) + w.Write([]byte{hdr}) + encodeNumber(t, 8, w) + return 9 +} + +// func decodeTimestamp(raw []byte, offset int) (int64, error) { +// ts, err := readUint(raw, offset+1, 8) +// if err != nil { +// return -1, err +// } +// return int64(ts), nil +// } + +func unmarshalTimestamp(raw []byte, offset int) (time.Time, error) { + typ, _ := primitiveFromHeader(raw[offset]) + ts, err := readUint(raw, offset+1, 8) + if err != nil { + return time.Time{}, err + } + var ret time.Time + if typ == primitiveTimestampMicros || typ == primitiveTimestampNTZMicros { + ret = time.UnixMicro(int64(ts)) + } else { + ret = time.Unix(0, int64(ts)) + } + if typ == primitiveTimestampNTZMicros || typ == primitiveTimestampNTZNanos { + ret = ret.UTC() + } else { + ret = ret.Local() + } + return ret, nil +} + +func marshalDate(t time.Time, w io.Writer) int { + epoch := time.Unix(0, 0) + since := t.Sub(epoch) + days := int64(since.Hours() / 24) + hdr, _ := primitiveHeader(primitiveDate) + w.Write([]byte{hdr}) + encodeNumber(days, 4, w) + return 5 +} + +func unmarshalDate(raw []byte, offset int) (time.Time, error) { + days, err := readUint(raw, offset+1, 4) + if err != nil { + return 
time.Time{}, err + } + return time.Unix(0, 0).Add(time.Hour * 24 * time.Duration(days)), nil +} diff --git a/parquet/variants/primitive_test.go b/parquet/variants/primitive_test.go new file mode 100644 index 00000000..371b8aa0 --- /dev/null +++ b/parquet/variants/primitive_test.go @@ -0,0 +1,606 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variants + +import ( + "bytes" + "reflect" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func diffByteArrays(t *testing.T, got, want []byte) { + t.Helper() + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("Incorrect encoding. Diff (-got, +want):\n%s", diff) + } +} + +func checkSize(t *testing.T, wantSize int, buf []byte) { + t.Helper() + if gotSize := len(buf); gotSize != wantSize { + t.Errorf("Incorrect reported size: got %d, want %d", gotSize, wantSize) + } +} + +func TestBoolean(t *testing.T) { + var b bytes.Buffer + size := marshalBoolean(true, &b) + encodedTrue := b.Bytes() + checkSize(t, size, encodedTrue) + diffByteArrays(t, encodedTrue, []byte{0b100}) + got, err := unmarshalBoolean(encodedTrue, 0) + if err != nil { + t.Fatalf("unmarshalBoolean(): %v", err) + } + if got != true { + t.Fatalf("Incorrect boolean returned. Got %t, want true", got) + } + + b.Reset() + marshalBoolean(false, &b) + encodedFalse := b.Bytes() + diffByteArrays(t, encodedFalse, []byte{0b1000}) + got, err = unmarshalBoolean(encodedFalse, 0) + if err != nil { + t.Fatalf("unmarshalBoolean(): %v", err) + } + if got != false { + t.Fatalf("Incorrect boolean returned. 
Got %t, want false", got) + } +} + +func TestInt(t *testing.T) { + cases := []struct { + name string + val int64 + wantHdr byte + wantHexVal []byte + }{ + { + name: "Positive Int8", + val: 8, + wantHdr: 0b1100, + wantHexVal: []byte{0x08}, + }, + { + name: "Negative Int8", + val: -8, + wantHdr: 0b1100, + wantHexVal: []byte{0xF8}, + }, + { + name: "Positive Int16", + val: 200, + wantHdr: 0b10000, + wantHexVal: []byte{0xC8, 0x00}, + }, + { + name: "NegativeInt16", + val: -200, + wantHdr: 0b10000, + wantHexVal: []byte{0x38, 0xFF}, + }, + { + name: "Positive Int32", + val: 32768, + wantHdr: 0b10100, + wantHexVal: []byte{0x00, 0x80, 0x00, 0x00}, + }, + { + name: "Negative Int32", + val: -32768, + wantHdr: 0b10100, + wantHexVal: []byte{0x00, 0x80, 0xFF, 0xFF}, + }, + { + name: "Positive Int64", + val: 9223372036854775807, + wantHdr: 0b11000, + wantHexVal: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F}, + }, + { + name: "Negative Int64", + val: -9223372036854775807, + wantHdr: 0b11000, + wantHexVal: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80}, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + var b bytes.Buffer + size := marshalInt(c.val, &b) + encoded := b.Bytes() + checkSize(t, size, encoded) + if gotHdr := encoded[0]; gotHdr != c.wantHdr { + t.Fatalf("Incorrect header: got %x, want %x", gotHdr, c.wantHdr) + } + diffByteArrays(t, encoded[1:], c.wantHexVal) + gotVal, err := decodeIntPhysical(encoded, 0) + if err != nil { + t.Fatalf("decodeIntPhysical(): %v", err) + } + if wantVal := c.val; gotVal != wantVal { + t.Fatalf("Incorrect decoded value: got %d, want %d", gotVal, wantVal) + } + }) + } +} + +func TestUUID(t *testing.T) { + cases := []struct { + name string + uuid []byte + want []byte + }{ + { + name: "UUID no padding", + uuid: []byte("sixteencharacter"), + want: []byte{ + 0b1010000, // Basic primitive UUID + 's', 'i', 'x', 't', 'e', 'e', 'n', + 'c', 'h', 'a', 'r', 'a', 'c', 't', 'e', 'r', + }, + }, + { + name: "UUID truncation", + uuid: []byte("sixteencharacters"), + want: []byte{ + 0b1010000, // Basic primitive UUID + 's', 'i', 'x', 't', 'e', 'e', 'n', + 'c', 'h', 'a', 'r', 'a', 'c', 't', 'e', 'r', + }, + }, + { + name: "UUID padding", + uuid: []byte("small"), + want: []byte{ + 0b1010000, // Basic primitive UUID + 's', 'm', 'a', 'l', 'l', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + var b bytes.Buffer + size := marshalUUID(c.uuid, &b) + if size != 17 { + t.Fatalf("Incorrect size. Got %d, want 17", size) + } + diff(t, b.Bytes(), c.want) + + gotUUID, err := unmarshalUUID(b.Bytes(), 0) + if err != nil { + t.Fatalf("unmarshalUUID(): %v", err) + } + diff(t, gotUUID, c.want[1:]) + }) + } +} + +func TestFloat(t *testing.T) { + var b bytes.Buffer + size := marshalFloat(1.1, &b) + encodedFloat := b.Bytes() + checkSize(t, size, encodedFloat) + diffByteArrays(t, encodedFloat, []byte{ + 0b111000, // Primitive type, float + 0xCD, + 0xCC, + 0x8C, + 0x3F, // 0x3F8C CCCD ~= 1.1 encoded + }) + got, err := unmarshalFloat(encodedFloat, 0) + if err != nil { + t.Fatalf("unmarshalFloat(): %v", err) + } + if want := float32(1.1); got != want { + t.Fatalf("Incorrect float returned. 
Got %.2f, want %.2f", got, want) + } +} + +func TestDouble(t *testing.T) { + var b bytes.Buffer + size := marshalDouble(1.1, &b) + encodedDouble := b.Bytes() + checkSize(t, size, encodedDouble) + diffByteArrays(t, encodedDouble, []byte{ + 0b11100, // Primitive type, double + 0x9A, + 0x99, + 0x99, + 0x99, + 0x99, + 0x99, + 0xF1, + 0x3F, // 0x3FF1 9999 9999 999A ~= 1.1 encoded + }) + got, err := unmarshalDouble(encodedDouble, 0) + if err != nil { + t.Fatalf("unmarshalDouble(): %v", err) + } + if want := float64(1.1); got != want { + t.Fatalf("Incorrect double returned. Got %.2f, want %.2f", got, want) + } +} + +func mustMarshalPrimitive(t *testing.T, val any, opts ...MarshalOpts) []byte { + t.Helper() + var buf bytes.Buffer + if _, err := marshalPrimitive(val, &buf, opts...); err != nil { + t.Fatalf("marshalPrimitive(): %v", err) + } + return buf.Bytes() +} + +func TestUnmarshalPrimitive(t *testing.T) { + cases := []struct { + name string + encoded []byte + offset int + unmarshalType reflect.Type + want any + wantErr bool + }{ + { + name: "Unmarshal bool (with offset)", + encoded: append([]byte{0, 0}, mustMarshalPrimitive(t, true)...), + offset: 2, + want: true, + }, + { + name: "Unmarshal into int (with offset)", + encoded: append([]byte{0, 0}, mustMarshalPrimitive(t, 1)...), // Encodes to Int8 + offset: 2, + want: int(1), + }, + { + name: "Unmarshal into int8", + encoded: mustMarshalPrimitive(t, 1), + want: int8(1), + }, + { + name: "Unmarshal into int16", + encoded: mustMarshalPrimitive(t, 1), + want: int16(1), + }, + { + name: "Unmarshal into int32", + encoded: mustMarshalPrimitive(t, 1), + want: int32(1), + }, + { + name: "Unmarshal into int64", + encoded: mustMarshalPrimitive(t, 1), + want: int64(1), + }, + { + name: "Unmarshal negative", + encoded: mustMarshalPrimitive(t, -100), + want: -100, + }, + { + name: "Unmarshal int into float32", + encoded: mustMarshalPrimitive(t, 1), + want: float32(1), + }, + { + name: "unmarsUnmarshalhal float32", + encoded: mustMarshalPrimitive(t, float32(1.2)), + want: float32(1.2), + }, + { + name: "Unmarshal float64 (with offset)", + encoded: append([]byte{1, 1}, mustMarshalPrimitive(t, float64(1.2))...), + offset: 2, + want: float64(1.2), + }, + { + name: "Unmarshal timestamp into int64", + encoded: mustMarshalPrimitive(t, time.Unix(123, 0).Local(), MarshalTimeNanos), + want: time.Unix(123, 0).UnixNano(), + }, + { + name: "Unmarshal timestamp into time", + encoded: mustMarshalPrimitive(t, time.UnixMilli(1742967183000)), + want: time.UnixMilli(1742967183000), + }, + { + name: "Unmarshal timestamp into time (nanos)", + encoded: mustMarshalPrimitive(t, time.UnixMilli(1742967183000), MarshalTimeNanos), + want: time.Unix(0, 1742967183000000000), + }, + { + name: "Unmarshal short string with offset", + encoded: append([]byte{0, 0}, mustMarshalPrimitive(t, "hello")...), + offset: 2, + want: "hello", + }, + { + name: "Unmarshal basic string", + encoded: []byte{0b1000000, 0x03, 0x00, 0x00, 0x00, 'a', 'b', 'c'}, + want: "abc", + }, + { + name: "Unmarshal string into byte slice", + encoded: mustMarshalPrimitive(t, "hello"), + want: []byte("hello"), + }, + { + name: "Unmarshal binary into byte slice", + encoded: mustMarshalPrimitive(t, []byte{'b', 'y', 't', 'e'}), + want: []byte("byte"), + }, + { + name: "Unmarshal binary into string", + encoded: mustMarshalPrimitive(t, []byte{'b', 'y', 't', 'e'}), + want: "byte", + }, + { + name: "Unmarshal empty binary", + encoded: mustMarshalPrimitive(t, []byte{}), + want: []byte{}, + }, + { + name: "Unmarshal UUID to byte 
slice", + encoded: mustMarshalPrimitive(t, []byte("sixteencharacter"), MarshalAsUUID), + want: []byte("sixteencharacter"), + }, + { + name: "Unmarshal UUID to string", + encoded: mustMarshalPrimitive(t, "sixteencharacter", MarshalAsUUID), + want: "sixteencharacter", + }, + { + name: "Unmarshal into int8 would overflow", + encoded: mustMarshalPrimitive(t, 12345), + unmarshalType: reflect.TypeOf(int8(0)), + wantErr: true, + }, + { + name: "Cannot unmarshal int into non-int type", + encoded: mustMarshalPrimitive(t, 1), + unmarshalType: reflect.TypeOf(string("")), + wantErr: true, + }, + { + name: "Cannot unmarshal string into non-string type", + encoded: mustMarshalPrimitive(t, "hello"), + unmarshalType: reflect.TypeOf(int(1)), + wantErr: true, + }, + { + name: "Cannot unmarshal binary into non-binary type", + encoded: mustMarshalPrimitive(t, []byte{0, 1, 2}), + unmarshalType: reflect.TypeOf(int(1)), + wantErr: true, + }, + { + name: "Malformed value", + encoded: mustMarshalPrimitive(t, 256)[:1], // int16 is usually 3 bytes + unmarshalType: reflect.TypeOf(int(1)), + wantErr: true, + }, + { + name: "Short string out of bounds", + encoded: []byte{0b1001, 'a'}, + unmarshalType: reflect.TypeOf(""), + wantErr: true, + }, + { + name: "Binary out of bounds", + encoded: []byte{0b111100, 0x09, 0x00, 0x00, 0x00, 'a', 'b', 'c'}, + unmarshalType: reflect.TypeOf([]byte{}), + wantErr: true, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + typ := reflect.TypeOf(c.want) + if c.unmarshalType != nil { + typ = c.unmarshalType + } + got := reflect.New(typ) + if err := unmarshalPrimitive(c.encoded, c.offset, got); err != nil { + if c.wantErr { + return + } + t.Fatalf("Unexpected error: %v", err) + } else if c.wantErr { + t.Fatalf("Got no error when one was expected") + } + + diff(t, got.Elem().Interface(), c.want) + }) + } +} + +func TestString(t *testing.T) { + cases := []struct { + name string + str string + wantEncoded []byte + }{ + { + name: "Short string", + str: "short", + wantEncoded: []byte{ + 0b0010101, // Short string type, length=5 + 's', 'h', 'o', 'r', 't', + }, + }, + { + name: "Basic string", + str: "abcdefghijklmnopqrstuvwxyz1234567890abcdefghijklmnopqrstuvwxyz1234567890", + wantEncoded: append([]byte{ + 0b1000000, + 0x48, 0x00, 0x00, 0x00, // Length of 72 + }, []byte("abcdefghijklmnopqrstuvwxyz1234567890abcdefghijklmnopqrstuvwxyz1234567890")...), + }, + { + name: "Empty string", + str: "", + wantEncoded: []byte{0b01}, // Short string basic type, length = 0 + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + var b bytes.Buffer + size := marshalString(c.str, &b) + checkSize(t, size, c.wantEncoded) + + gotEncoded := b.Bytes() + diff(t, gotEncoded, c.wantEncoded) + }) + } +} + +func TestBinary(t *testing.T) { + cases := []struct { + name string + bin []byte + wantEncoded []byte + }{ + { + name: "Binary data", + bin: []byte("hello"), + wantEncoded: []byte{ + 0b111100, // Primitive type, binary + 0x05, 0x00, 0x00, 0x00, // Length of 5 + 'h', 'e', 'l', 'l', 'o', + }, + }, + { + name: "Empty data", + bin: []byte{}, + wantEncoded: []byte{0b111100, 0x00, 0x00, 0x00, 0x00}, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + var b bytes.Buffer + size := marshalBinary(c.bin, &b) + checkSize(t, size, c.wantEncoded) + diff(t, b.Bytes(), c.wantEncoded) + }) + } +} + +func TestTimestamp(t *testing.T) { + cases := []struct { + name string + nanos bool + ntz bool + wantHdr byte + }{ + { + name: "Nanos NTZ", + nanos: true, + ntz: true, 
+ wantHdr: 0b1001100, + }, + { + name: "Nanos", + nanos: true, + wantHdr: 0b1001000, + }, + { + name: "Micros NTZ", + ntz: true, + wantHdr: 0b110100, + }, + { + name: "Micros", + wantHdr: 0b110000, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + ref := time.UnixMicro(1000000000) + if c.ntz { + ref = ref.UTC() + } else { + ref = ref.Local() + } + var b bytes.Buffer + size := marshalTimestamp(ref, c.nanos, &b) + wantEncoded := []byte{c.wantHdr} + if c.nanos { + wantEncoded = append(wantEncoded, []byte{ + 0x00, + 0x10, + 0xA5, + 0xD4, + 0xE8, + 0x00, + 0x00, + 0x00, // Binary encoding of 1,000,000,000,000 + }...) + } else { + wantEncoded = append(wantEncoded, []byte{ + 0x00, + 0xCA, + 0x9A, + 0x3B, + 0x00, + 0x00, + 0x00, + 0x00, // Binary encoding of 1,000,000,000 + }...) + } + encodedTimestamp := b.Bytes() + checkSize(t, size, encodedTimestamp) + diffByteArrays(t, encodedTimestamp, wantEncoded) + got, err := unmarshalTimestamp(b.Bytes(), 0) + if err != nil { + t.Fatalf("unmarshalTimestamp(): %v", err) + } + if want := ref; got != want { + t.Fatalf("Timestamps differ: got %s, want %s", got, want) + } + }) + } +} + +func TestDate(t *testing.T) { + day := time.Unix(0, 0).Add(10000 * 24 * time.Hour) + var b bytes.Buffer + size := marshalDate(day, &b) + encodedDate := b.Bytes() + checkSize(t, size, encodedDate) + diffByteArrays(t, encodedDate, []byte{ + 0b101100, // Primitive type, date + 0x10, + 0x27, + 0x00, + 0x00, // 10000 = 0x0000 2710 + }) + got, err := unmarshalDate(encodedDate, 0) + if err != nil { + t.Fatalf("unmarshalDate(): %v", err) + } + if want := day; got != want { + t.Fatalf("Incorrect date: got %s, want %s", got, want) + } +} diff --git a/parquet/variants/testutils.go b/parquet/variants/testutils.go new file mode 100644 index 00000000..937226b4 --- /dev/null +++ b/parquet/variants/testutils.go @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variants + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func diff(t *testing.T, got, want any, cmpOpts ...cmp.Option) { + t.Helper() + if d := cmp.Diff(got, want, cmpOpts...); d != "" { + t.Fatalf("Incorrect returned value. Diff (-got, +want):\n%s", d) + } +} diff --git a/parquet/variants/util.go b/parquet/variants/util.go new file mode 100644 index 00000000..b3f36082 --- /dev/null +++ b/parquet/variants/util.go @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variants + +import ( + "fmt" + "io" + "reflect" + "time" +) + +// Reads a little-endian encoded uint (betwen 1 and 8 bytes wide) from a raw buffer at a specified +// offset and returns its value. If any part of the read would be out of bounds, this returns an error. +func readUint(raw []byte, offset, size int) (uint64, error) { + if size < 1 || size > 8 { + return 0, fmt.Errorf("invalid size, must be in range [1,8]: %d", size) + } + if maxPos := offset + size; maxPos > len(raw) { + return 0, fmt.Errorf("out of bounds: trying to access position %d, max position is %d", maxPos, len(raw)) + } + var ret uint64 + for i := range size { + ret |= uint64(raw[i+offset]) << (8 * i) + } + return ret, nil +} + +// Reads a little-endian encoded integer (between 1 and 8 bytes wide) from a raw buffer at a specified offset. +func readInt(raw []byte, offset, size int) (int64, error) { + u, err := readUint(raw, offset, size) + if err != nil { + return -1, err + } + return int64(u), nil +} + +func fieldOffsetSize(maxSize int32) int { + if maxSize < 0xFF { + return 1 + } else if maxSize < 0xFFFF { + return 2 + } else if maxSize < 0xFFFFFF { + return 3 + } + return 4 +} + +// Checks that a given range is in the provided raw buffer. +func checkBounds(raw []byte, low, high int) error { + maxPos := len(raw) + if low >= maxPos { + return fmt.Errorf("out of bounds: trying to access position %d, max is %d", low, maxPos) + } + if high > maxPos { + return fmt.Errorf("out of bounds: trying to access position %d, max is %d", high, maxPos) + } + return nil +} + +// Encodes a number of a specified width in little-endian format and writes to a writer. +func encodeNumber(val int64, size int, w io.Writer) { + buf := make([]byte, size) + for i := range size { + buf[i] = byte(val) + val >>= 8 + } + w.Write(buf) +} + +func isLarge(numItems int) bool { + return numItems > 0xFF +} + +// Returns the basic type the passed in value should be encoded as, or undefined if it cannot be handled. +func kindFromValue(val any) BasicType { + if val == nil { + return BasicPrimitive + } + v := reflect.ValueOf(val) + + if v.Kind() == reflect.Pointer { + v = v.Elem() + } + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Bool, + reflect.String, reflect.Float32, reflect.Float64: + return BasicPrimitive + case reflect.Struct: + // Time is considered a primitive. All other structs are objects. + if v.Type() == reflect.TypeOf(time.Time{}) { + return BasicPrimitive + } + return BasicObject + case reflect.Array, reflect.Slice: + typ := v.Type() + if typ.Elem().Kind() == reflect.Uint8 { + return BasicPrimitive + } + return BasicArray + case reflect.Map: + // Only maps with string keys are supported. + typ := v.Type() + if typ.Key().Kind() == reflect.String { + return BasicObject + } + } + return BasicUndefined +} + +// Returns the nth item (zero indexed) in a serialized list (ie. a serialized Array, or serialized Metadata). +// The offset should be the index of the first offset listing. 
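The list layout assumed here is the one exercised by util_test.go further down: numElements+1 little-endian offsets of offsetSize bytes each, starting at offset, followed immediately by the element bytes, with element i spanning [offsets[i], offsets[i+1]) relative to the first element byte. A minimal sketch with one-byte offsets, mirroring the "Third item" test case below (the wrapper name is illustrative only):

package variants

func exampleReadNthItem() ([]byte, error) {
	// raw[0]    padding byte (the offset table starts at index 1)
	// raw[1:5]  numElements+1 = 4 one-byte offsets: 0, 1, 2, 3
	// raw[5:8]  element bytes 0xAA, 0xBB, 0xCC
	raw := []byte{0x00, 0x00, 0x01, 0x02, 0x03, 0xAA, 0xBB, 0xCC}

	// Third element (item=2): offsets[2]=2 and offsets[3]=3, both relative to
	// index 5 (offset + (numElements+1)*offsetSize), so this returns raw[7:8],
	// i.e. []byte{0xCC}.
	return readNthItem(raw, 1, 2, 1, 3)
}
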
+func readNthItem(raw []byte, offset, item, offsetSize, numElements int) ([]byte, error) { + if err := checkBounds(raw, offset, offset); err != nil { + return nil, err + } + + if item > numElements { + return nil, fmt.Errorf("item number is greater than number of elements (%d vs %d)", item, numElements) + } + + // Calculate the range to return by getting the upper and lower bound of the item. + lowerBound, err := readUint(raw, offset+item*offsetSize, offsetSize) + if err != nil { + return nil, err + } + upperBound, err := readUint(raw, offset+(item+1)*offsetSize, offsetSize) + if err != nil { + return nil, err + } + firstElemIdx := offset + (numElements+1)*offsetSize + + lowIdx := firstElemIdx + int(lowerBound) + highIdx := firstElemIdx + int(upperBound) + + if err := checkBounds(raw, lowIdx, highIdx); err != nil { + return nil, err + } + + return raw[lowIdx:highIdx], nil +} diff --git a/parquet/variants/util_test.go b/parquet/variants/util_test.go new file mode 100644 index 00000000..50036106 --- /dev/null +++ b/parquet/variants/util_test.go @@ -0,0 +1,274 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variants + +import ( + "testing" + "time" +) + +func TestReadUint(t *testing.T) { + cases := []struct { + name string + raw []byte + offset int + size int + want uint64 + wantErr bool + }{ + { + name: "Read uint8, offset=1", + raw: []byte{0x00, 0x05}, + offset: 1, + size: 1, + want: 5, + }, + { + name: "Read uint16, offset=1", + raw: []byte{0x00, 0x00, 0x01}, // 256 + offset: 1, + size: 2, + want: 256, + }, + { + name: "Read uint32, offset=1", + raw: []byte{0x00, 0x00, 0x00, 0x01, 0x00}, // 65536 + offset: 1, + size: 4, + want: 65536, + }, + { + name: "Read uint64, offset=1", + raw: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00}, // 4294967293 + offset: 1, + size: 8, + want: 4294967296, + }, + { + name: "Empty raw buffer", + offset: 0, + size: 1, + wantErr: true, + }, + { + name: "Not enough bytes for offset", + raw: []byte{0x00}, + offset: 1, + size: 1, + wantErr: true, + }, + { + name: "Not enough bytes to read", + raw: []byte{0x00}, + offset: 0, + size: 2, + wantErr: true, + }, + { + name: "Invalid size 0", + raw: []byte{0x00}, + offset: 0, + size: 0, + wantErr: true, + }, + { + name: "Invalid size too big", + raw: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + offset: 0, + size: 9, + wantErr: true, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := readUint(c.raw, c.offset, c.size) + if err != nil { + if c.wantErr { + return + } + t.Fatalf("Unexpected error: %v", err) + } else if c.wantErr { + t.Fatalf("Got no error when one was expected") + } + if got != c.want { + t.Fatalf("Incorrect value returned. 
Got %d, want %d", got, c.want) + } + }) + } +} + +func TestKindFromValue(t *testing.T) { + cases := []struct { + name string + val any + want BasicType + }{ + { + name: "Int", + val: 123, + want: BasicPrimitive, + }, + { + name: "Int pointer", + val: func() *int { + a := 123 + return &a + }(), + want: BasicPrimitive, + }, + { + name: "Bool", + val: false, + want: BasicPrimitive, + }, + { + name: "Byte slice is primitive", + val: []byte{'a', 'b', 'c'}, + want: BasicPrimitive, + }, + { + name: "Time", + val: time.Unix(100, 100), + want: BasicPrimitive, + }, + { + name: "Struct", + val: struct{ a int }{1}, + want: BasicObject, + }, + { + name: "Struct pointer", + val: &struct{ a int }{1}, + want: BasicObject, + }, + { + name: "Slice is an array", + val: []int{1, 2, 3}, + want: BasicArray, + }, + { + name: "Map with string keys is an object", + val: map[string]bool{"a": true}, + want: BasicObject, + }, + { + name: "Map with non string keys is not supported", + val: map[int]string{1: "a"}, + want: BasicUndefined, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := kindFromValue(c.val); got != c.want { + t.Fatalf("Incorrect kind. Got %s, want %s", got, c.want) + } + }) + } +} + +func TestReadNthItem(t *testing.T) { + cases := []struct { + name string + raw []byte + offset int + item int + offsetSize int + numElements int + want []byte + wantErr bool + }{ + { + name: "Third item, offset=1, offsetSize=1, width=1", + raw: []byte{0x00, 0x00, 0x01, 0x02, 0x03, 0xAA, 0xBB, 0xCC}, + offset: 1, + item: 2, + offsetSize: 1, + numElements: 3, + want: []byte{0xCC}, + }, + { + name: "Second item, offset=1, offsetSize=2, width=1", + raw: []byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0xAA, 0xBB, 0xCC}, + offset: 1, + item: 1, + offsetSize: 2, + numElements: 3, + want: []byte{0xBB}, + }, + { + name: "First item, offset=1, offsetSize=2, width=2", + raw: []byte{ + 0x00, 0x00, 0x00, 0x02, 0x00, 0x04, 0x00, 0x06, 0x00, + 0xAA, 0xAA, 0xBB, 0xBB, 0xCC, 0xCC}, + offset: 1, + item: 0, + offsetSize: 2, + numElements: 3, + want: []byte{0xAA, 0xAA}, + }, + { + name: "Second item, offset=1, offsetSize=2, width=2", + raw: []byte{ + 0x00, 0x00, 0x00, 0x02, 0x00, 0x04, 0x00, 0x06, 0x00, + 0xAA, 0xAA, 0xBB, 0xBB, 0xCC, 0xCC}, + offset: 1, + item: 1, + offsetSize: 2, + numElements: 3, + want: []byte{0xBB, 0xBB}, + }, + { + name: "Offset out of bounds", + raw: []byte{0x00}, + offset: 1, + wantErr: true, + }, + { + name: "Item is greater than numElements", + raw: []byte{0x00, 0x00, 0x01, 0x02, 0x03, 0xAA, 0xBB, 0xCC}, + offset: 1, + item: 4, + offsetSize: 1, + numElements: 3, + wantErr: true, + }, + { + name: "Item is out of bounds", + raw: []byte{0x00, 0x01, 0x02, 0x04, 0xAA, 0xBB, 0xCC}, + offset: 0, + item: 2, + offsetSize: 1, + numElements: 3, + wantErr: true, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := readNthItem(c.raw, c.offset, c.item, c.offsetSize, c.numElements) + if err != nil { + if c.wantErr { + return + } + t.Fatalf("readNthItem(): %v", err) + } else if c.wantErr { + t.Fatalf("readNthItem(): wanted error, got none") + } + diffByteArrays(t, got, c.want) + }) + } +} diff --git a/parquet/variants/variant.go b/parquet/variants/variant.go new file mode 100644 index 00000000..bd6fb83e --- /dev/null +++ b/parquet/variants/variant.go @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variants + +// Basic types +type BasicType int + +const ( + BasicUndefined BasicType = -1 + BasicPrimitive BasicType = 0 + BasicShortString BasicType = 1 + BasicObject BasicType = 2 + BasicArray BasicType = 3 +) + +func (bt BasicType) String() string { + switch bt { + case BasicPrimitive: + return "Primitive" + case BasicShortString: + return "ShortString" + case BasicObject: + return "Object" + case BasicArray: + return "Array" + } + return "Unknown" +} + +// Function to get the Variant basic type from a provided value header +func BasicTypeFromHeader(hdr byte) BasicType { + return BasicType(hdr & 0x3) +} + +// Container to hold a marshaled Variant. +type MarshaledVariant struct { + Metadata []byte + Value []byte +} From ec026eaf90991bfcbfe9ab7cd7aabae1436d2885 Mon Sep 17 00:00:00 2001 From: Marcin Bojanczyk Date: Tue, 22 Apr 2025 16:51:13 -0700 Subject: [PATCH 02/10] Address feedback --- parquet/variants/builder.go | 1 - parquet/variants/decoder.go | 2 +- parquet/variants/object.go | 2 - parquet/variants/object_test.go | 4 +- parquet/variants/primitive.go | 72 ++++++++++++++---------------- parquet/variants/primitive_test.go | 57 ++++++++++++----------- parquet/variants/util.go | 7 +++ parquet/variants/util_test.go | 57 +++++++++++++++++++++++ 8 files changed, 132 insertions(+), 70 deletions(-) diff --git a/parquet/variants/builder.go b/parquet/variants/builder.go index b8e8aa9b..032e069e 100644 --- a/parquet/variants/builder.go +++ b/parquet/variants/builder.go @@ -54,7 +54,6 @@ const ( MarshalAsDate MarshalAsTime MarshalAsTimestamp - MarshalAsUUID ) var errAlreadyBuilt = errors.New("component already built") diff --git a/parquet/variants/decoder.go b/parquet/variants/decoder.go index 9a85a065..7cdd1c4e 100644 --- a/parquet/variants/decoder.go +++ b/parquet/variants/decoder.go @@ -33,7 +33,7 @@ import ( // - Timestamp (all varieties): time.Time // - String: string // - Binary: []byte -// - UUID: string +// - UUID: uuid.UUID // - Array: []any // - Object: map[string]any diff --git a/parquet/variants/object.go b/parquet/variants/object.go index 2031f64a..5d00b492 100644 --- a/parquet/variants/object.go +++ b/parquet/variants/object.go @@ -107,8 +107,6 @@ func extractFieldInfo(field reflect.StructField) (string, []MarshalOpts) { opts = append(opts, MarshalAsTime) case "timestamp": opts = append(opts, MarshalAsTimestamp) - case "uuid": - opts = append(opts, MarshalAsUUID) } } diff --git a/parquet/variants/object_test.go b/parquet/variants/object_test.go index de53d5e9..7c7423ae 100644 --- a/parquet/variants/object_test.go +++ b/parquet/variants/object_test.go @@ -656,7 +656,7 @@ func TestExtractFieldInfo(t *testing.T) { JustName int `variant:"just_name"` EmptyTag int `variant:""` WithOpts int `variant:"with_opts,ntz,date,nanos,time"` - OptsNoName int `variant:",uuid"` + OptsNoName int 
`variant:",ntz"` UnknownOpt int `variant:"unknown,not_defined_opt"` } cases := []struct { @@ -690,7 +690,7 @@ func TestExtractFieldInfo(t *testing.T) { name: "Just options, no name uses struct field name", field: 4, wantName: "OptsNoName", - wantOpts: []MarshalOpts{MarshalAsUUID}, + wantOpts: []MarshalOpts{MarshalTimeNTZ}, }, { name: "Ignore unknown options", diff --git a/parquet/variants/primitive.go b/parquet/variants/primitive.go index 8ab48784..5d4eed01 100644 --- a/parquet/variants/primitive.go +++ b/parquet/variants/primitive.go @@ -23,6 +23,9 @@ import ( "reflect" "strings" "time" + "unsafe" + + "github.com/google/uuid" ) // Variant primitive type IDs. @@ -162,15 +165,11 @@ func marshalPrimitive(v any, w io.Writer, opts ...MarshalOpts) (int, error) { return marshalFloat(val, w), nil case float64: return marshalDouble(val, w), nil + case uuid.UUID: + return marshalUUID(val, w), nil case string: - if allOpts&MarshalAsUUID != 0 { - return marshalUUID([]byte(val), w), nil - } return marshalString(val, w), nil case []byte: - if allOpts&MarshalAsUUID != 0 { - return marshalUUID([]byte(val), w), nil - } return marshalBinary(val, w), nil case time.Time: if allOpts&MarshalAsDate != 0 { @@ -195,6 +194,7 @@ func marshalPrimitive(v any, w io.Writer, opts ...MarshalOpts) (int, error) { func unmarshalPrimitive(raw []byte, offset int, destPtr reflect.Value) error { dest := destPtr.Elem() kind := dest.Kind() + isEmptyInterface := kind == reflect.Interface && dest.NumMethod() == 0 if err := checkBounds(raw, offset, offset); err != nil { return err @@ -218,7 +218,7 @@ func unmarshalPrimitive(raw []byte, offset int, destPtr reflect.Value) error { if err != nil { return err } - if kind == reflect.Interface { + if isEmptyInterface { dest.Set(reflect.ValueOf(iv)) } else if dest.CanInt() { if dest.OverflowInt(iv) { @@ -261,7 +261,7 @@ func unmarshalPrimitive(raw []byte, offset int, destPtr reflect.Value) error { if err != nil { return err } - if !dest.CanFloat() && kind != reflect.Interface { + if !dest.CanFloat() && !isEmptyInterface { return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) } switch kind { @@ -277,7 +277,7 @@ func unmarshalPrimitive(raw []byte, offset int, destPtr reflect.Value) error { if err != nil { return err } - if !dest.CanFloat() && kind != reflect.Interface { + if !dest.CanFloat() && !isEmptyInterface { return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) } switch kind { @@ -300,7 +300,7 @@ func unmarshalPrimitive(raw []byte, offset int, destPtr reflect.Value) error { // Time can be decoded into either an int64 (the physical time), or into a time.Time struct. // Anything else is invalid. 
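A minimal sketch of both destinations, in the style of the package tests and using the package's unexported helpers (the wrapper name is illustrative and errors are elided; an empty-interface destination likewise receives the physical int64):

package variants

import (
	"bytes"
	"reflect"
	"time"
)

func exampleTimestampDests() (int64, time.Time) {
	var buf bytes.Buffer
	// Encode a micros-precision timestamp (nanos=false).
	marshalTimestamp(time.UnixMilli(1742967183000).Local(), false, &buf)
	raw := buf.Bytes()

	var physical int64 // receives the raw microseconds-since-epoch value
	_ = unmarshalPrimitive(raw, 0, reflect.ValueOf(&physical))

	var logical time.Time // receives the decoded time.Time
	_ = unmarshalPrimitive(raw, 0, reflect.ValueOf(&logical))

	return physical, logical
}
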
- if kind == reflect.Int64 || kind == reflect.Interface { + if kind == reflect.Int64 || isEmptyInterface { dest.Set(reflect.ValueOf(int64(tsv))) } else if kind == reflect.Uint64 { dest.Set(reflect.ValueOf(tsv)) @@ -330,9 +330,9 @@ func unmarshalPrimitive(raw []byte, offset int, destPtr reflect.Value) error { if err != nil { return err } - if kind == reflect.String || kind == reflect.Interface { + if kind == reflect.String || isEmptyInterface { dest.Set(reflect.ValueOf(str)) - } else if kind == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8 { + } else if dest.Type() == reflect.TypeOf([]byte{}) { dest.Set(reflect.ValueOf([]byte(str))) } else { return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) @@ -342,7 +342,7 @@ func unmarshalPrimitive(raw []byte, offset int, destPtr reflect.Value) error { if err != nil { return err } - if kind == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8 || kind == reflect.Interface { + if isEmptyInterface || dest.Type() == reflect.TypeOf([]byte{}) { dest.Set(reflect.ValueOf(bytes)) } else if kind == reflect.String { dest.Set(reflect.ValueOf(string(bytes))) @@ -350,14 +350,22 @@ func unmarshalPrimitive(raw []byte, offset int, destPtr reflect.Value) error { return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) } case primitiveUUID: - bytes, err := unmarshalUUID(raw, offset) + u, err := unmarshalUUID(raw, offset) if err != nil { return err } - if kind == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8 || kind == reflect.Interface { + if dest.Type() == reflect.TypeOf(uuid.UUID{}) || isEmptyInterface { + dest.Set(reflect.ValueOf(u)) + } else if dest.Type() == reflect.TypeOf([]byte{}) { + bytes, _ := u.MarshalBinary() dest.Set(reflect.ValueOf(bytes)) + } else if dest.Type() == reflect.TypeOf([16]byte{}) { + bytes, _ := u.MarshalBinary() + var fixed [16]byte + copy(fixed[:], bytes) + dest.Set(reflect.ValueOf(fixed)) } else if kind == reflect.String { - dest.Set(reflect.ValueOf(string(bytes))) + dest.Set(reflect.ValueOf(u.String())) } else { return fmt.Errorf("cannot decode Variant UUID into dest %s", kind) } @@ -518,25 +526,19 @@ func marshalString(str string, w io.Writer) int { return 1 + encodePrimitiveBytes([]byte(strings.ToValidUTF8(str, "\uFFFD")), w) } -func marshalUUID(uuid []byte, w io.Writer) int { +func marshalUUID(u uuid.UUID, w io.Writer) int { hdr, _ := primitiveHeader(primitiveUUID) w.Write([]byte{hdr}) - - // A UUID is 16 bytes. Either pad or truncate to this length. - if len(uuid) > 16 { - uuid = uuid[:16] - } else if pad := 16 - len(uuid); pad > 0 { - uuid = append(uuid, make([]byte, pad)...) 
- } - w.Write(uuid) + m, _ := u.MarshalBinary() // MarshalBinary() can never return an error + w.Write(m) return 17 } -func unmarshalUUID(raw []byte, offset int) ([]byte, error) { +func unmarshalUUID(raw []byte, offset int) (uuid.UUID, error) { if err := checkBounds(raw, offset, offset+17); err != nil { - return nil, err + return uuid.UUID{}, err } - return raw[offset+1 : offset+17], nil + return uuid.FromBytes(raw[offset+1 : offset+17]) } func unmarshalString(raw []byte, offset int) (string, error) { @@ -553,14 +555,16 @@ func unmarshalString(raw []byte, offset int) (string, error) { if endIdx > maxPos { return "", fmt.Errorf("end index is out of bounds: trying to access position %d, max position is %d", endIdx, maxPos) } - return string(raw[offset+1 : endIdx]), nil + strPtr := (*byte)(unsafe.Pointer(&raw[offset+1])) + return unsafe.String(strPtr, l), nil } b, err := getBytes(raw, offset+1) if err != nil { return "", err } - return string(b), nil + strPtr := (*byte)(unsafe.Pointer(&b[0])) + return unsafe.String(strPtr, len(b)), nil } func getBytes(raw []byte, offset int) ([]byte, error) { @@ -631,14 +635,6 @@ func encodeTimestamp(t int64, nanos, ntz bool, w io.Writer) int { return 9 } -// func decodeTimestamp(raw []byte, offset int) (int64, error) { -// ts, err := readUint(raw, offset+1, 8) -// if err != nil { -// return -1, err -// } -// return int64(ts), nil -// } - func unmarshalTimestamp(raw []byte, offset int) (time.Time, error) { typ, _ := primitiveFromHeader(raw[offset]) ts, err := readUint(raw, offset+1, 8) diff --git a/parquet/variants/primitive_test.go b/parquet/variants/primitive_test.go index 371b8aa0..c13fa92c 100644 --- a/parquet/variants/primitive_test.go +++ b/parquet/variants/primitive_test.go @@ -23,6 +23,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/uuid" ) func diffByteArrays(t *testing.T, got, want []byte) { @@ -147,35 +148,21 @@ func TestInt(t *testing.T) { func TestUUID(t *testing.T) { cases := []struct { name string - uuid []byte + uuid uuid.UUID want []byte }{ { name: "UUID no padding", - uuid: []byte("sixteencharacter"), + uuid: func() uuid.UUID { + u, _ := uuid.FromBytes([]byte("sixteencharacter")) + return u + }(), want: []byte{ 0b1010000, // Basic primitive UUID 's', 'i', 'x', 't', 'e', 'e', 'n', 'c', 'h', 'a', 'r', 'a', 'c', 't', 'e', 'r', }, }, - { - name: "UUID truncation", - uuid: []byte("sixteencharacters"), - want: []byte{ - 0b1010000, // Basic primitive UUID - 's', 'i', 'x', 't', 'e', 'e', 'n', - 'c', 'h', 'a', 'r', 'a', 'c', 't', 'e', 'r', - }, - }, - { - name: "UUID padding", - uuid: []byte("small"), - want: []byte{ - 0b1010000, // Basic primitive UUID - 's', 'm', 'a', 'l', 'l', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - }, - }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -190,7 +177,8 @@ func TestUUID(t *testing.T) { if err != nil { t.Fatalf("unmarshalUUID(): %v", err) } - diff(t, gotUUID, c.want[1:]) + gotUUIDBytes, _ := gotUUID.MarshalBinary() + diff(t, gotUUIDBytes, c.want[1:]) }) } } @@ -359,14 +347,31 @@ func TestUnmarshalPrimitive(t *testing.T) { want: []byte{}, }, { - name: "Unmarshal UUID to byte slice", - encoded: mustMarshalPrimitive(t, []byte("sixteencharacter"), MarshalAsUUID), - want: []byte("sixteencharacter"), + name: "Unmarshal UUID", + encoded: func() []byte { + u, _ := uuid.FromBytes([]byte("sixteencharacter")) + return mustMarshalPrimitive(t, u) + }(), + want: func() uuid.UUID { + u, _ := uuid.FromBytes([]byte("sixteencharacter")) + return u + }(), + }, + { + name: "Unmarshal UUID to byte 
slice", + encoded: func() []byte { + u, _ := uuid.FromBytes([]byte("sixteencharacter")) + return mustMarshalPrimitive(t, u) + }(), + want: []byte("sixteencharacter"), }, { - name: "Unmarshal UUID to string", - encoded: mustMarshalPrimitive(t, "sixteencharacter", MarshalAsUUID), - want: "sixteencharacter", + name: "Unmarshal UUID to string", + encoded: func() []byte { + u, _ := uuid.FromBytes([]byte("sixteencharacter")) + return mustMarshalPrimitive(t, u) + }(), + want: "73697874-6565-6e63-6861-726163746572", }, { name: "Unmarshal into int8 would overflow", diff --git a/parquet/variants/util.go b/parquet/variants/util.go index b3f36082..bba2e8ce 100644 --- a/parquet/variants/util.go +++ b/parquet/variants/util.go @@ -68,6 +68,12 @@ func checkBounds(raw []byte, low, high int) error { if high > maxPos { return fmt.Errorf("out of bounds: trying to access position %d, max is %d", high, maxPos) } + if high < low { + return fmt.Errorf("incorrect bounds- high (%d) must higher than or equal to low (%d)", high, low) + } + if low < 0 { + return fmt.Errorf("bounds must be positive, have [%d, %d]", low, high) + } return nil } @@ -107,6 +113,7 @@ func kindFromValue(val any) BasicType { return BasicObject case reflect.Array, reflect.Slice: typ := v.Type() + // Byte arrays are primitives. UUID happens to fall into this bucket too serindiptously. if typ.Elem().Kind() == reflect.Uint8 { return BasicPrimitive } diff --git a/parquet/variants/util_test.go b/parquet/variants/util_test.go index 50036106..f52b9004 100644 --- a/parquet/variants/util_test.go +++ b/parquet/variants/util_test.go @@ -272,3 +272,60 @@ func TestReadNthItem(t *testing.T) { }) } } + +func TestCheckBounds(t *testing.T) { + cases := []struct { + name string + raw []byte + low, high int + wantErr bool + }{ + { + name: "In bounds", + raw: make([]byte, 10), + low: 1, + high: 9, + }, + { + name: "low == high", + raw: make([]byte, 10), + low: 1, + high: 1, + }, + { + name: "Out of bounds (idx == len(raw))", + raw: make([]byte, 10), + low: 10, + high: 10, + wantErr: true, + }, + { + name: "high < low", + raw: make([]byte, 10), + low: 5, + high: 1, + wantErr: true, + }, + { + name: "Negative index", + raw: make([]byte, 10), + low: -1, + high: 1, + wantErr: true, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + err := checkBounds(c.raw, c.low, c.high) + if err != nil { + if c.wantErr { + return + } + t.Fatalf("Unexpected error: %v", err) + } else if c.wantErr { + t.Fatalf("Got no error when one was expected") + } + }) + } +} From 849933442df9b8930b6bd1feda1b6fd74a53f457 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Mon, 28 Apr 2025 18:50:55 -0400 Subject: [PATCH 03/10] some generic cleanups some more cleanup refactor using learnings builder and tests --- arrow/endian/big.go | 16 +- arrow/endian/endian.go | 4 + arrow/endian/little.go | 8 +- arrow/extensions/variant/basic_type_string.go | 28 + arrow/extensions/variant/builder.go | 491 +++++++++++++ arrow/extensions/variant/builder_test.go | 297 ++++++++ .../variant/primitive_type_string.go | 45 ++ arrow/extensions/variant/utils.go | 170 +++++ arrow/extensions/variant/variant.go | 649 ++++++++++++++++++ arrow/extensions/variant/variant_test.go | 512 ++++++++++++++ go.mod | 14 +- go.sum | 28 +- parquet-testing | 2 +- parquet/variants/builder.go | 4 +- parquet/variants/primitive.go | 300 +++++--- parquet/variants/primitive_test.go | 35 +- 16 files changed, 2450 insertions(+), 153 deletions(-) create mode 100644 arrow/extensions/variant/basic_type_string.go create mode 
100644 arrow/extensions/variant/builder.go create mode 100644 arrow/extensions/variant/builder_test.go create mode 100644 arrow/extensions/variant/primitive_type_string.go create mode 100644 arrow/extensions/variant/utils.go create mode 100644 arrow/extensions/variant/variant.go create mode 100644 arrow/extensions/variant/variant_test.go diff --git a/arrow/endian/big.go b/arrow/endian/big.go index 0b925857..9dfc76c0 100644 --- a/arrow/endian/big.go +++ b/arrow/endian/big.go @@ -19,12 +19,22 @@ package endian -import "encoding/binary" - -var Native = binary.BigEndian +import "math/bits" const ( IsBigEndian = true NativeEndian = BigEndian NonNativeEndian = LittleEndian ) + +func FromLE[T uint16 | uint32 | uint64](x T) T { + switch v := any(x).(type) { + case uint16: + return T(bits.Reverse16(v)) + case uint32: + return T(bits.Reverse32(v)) + case uint64: + return T(bits.Reverse64(v)) + } + return x +} diff --git a/arrow/endian/endian.go b/arrow/endian/endian.go index f369945d..ad2b1085 100644 --- a/arrow/endian/endian.go +++ b/arrow/endian/endian.go @@ -17,10 +17,14 @@ package endian import ( + "encoding/binary" + "github.com/apache/arrow-go/v18/arrow/internal/debug" "github.com/apache/arrow-go/v18/arrow/internal/flatbuf" ) +var Native = binary.NativeEndian + type Endianness flatbuf.Endianness const ( diff --git a/arrow/endian/little.go b/arrow/endian/little.go index def1fc64..12fbed30 100644 --- a/arrow/endian/little.go +++ b/arrow/endian/little.go @@ -19,12 +19,12 @@ package endian -import "encoding/binary" - -var Native = binary.LittleEndian - const ( IsBigEndian = false NativeEndian = LittleEndian NonNativeEndian = BigEndian ) + +func FromLE[T uint16 | uint32 | uint64](x T) T { + return x +} diff --git a/arrow/extensions/variant/basic_type_string.go b/arrow/extensions/variant/basic_type_string.go new file mode 100644 index 00000000..31afdc4a --- /dev/null +++ b/arrow/extensions/variant/basic_type_string.go @@ -0,0 +1,28 @@ +// Code generated by "stringer -type=BasicType -linecomment -output=basic_type_string.go"; DO NOT EDIT. + +package variant + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[BasicUndefined - -1] + _ = x[BasicPrimitive-0] + _ = x[BasicShortString-1] + _ = x[BasicObject-2] + _ = x[BasicArray-3] +} + +const _BasicType_name = "UnknownPrimitiveShortStringObjectArray" + +var _BasicType_index = [...]uint8{0, 7, 16, 27, 33, 38} + +func (i BasicType) String() string { + i -= -1 + if i < 0 || i >= BasicType(len(_BasicType_index)-1) { + return "BasicType(" + strconv.FormatInt(int64(i+-1), 10) + ")" + } + return _BasicType_name[_BasicType_index[i]:_BasicType_index[i+1]] +} diff --git a/arrow/extensions/variant/builder.go b/arrow/extensions/variant/builder.go new file mode 100644 index 00000000..2fee6bad --- /dev/null +++ b/arrow/extensions/variant/builder.go @@ -0,0 +1,491 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variant + +import ( + "bytes" + "cmp" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "slices" + "unsafe" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/google/uuid" +) + +type Builder struct { + buf bytes.Buffer + dict map[string]uint32 + dictKeys [][]byte + allowDuplicates bool +} + +func (b *Builder) SetAllowDuplicates(allow bool) { + b.allowDuplicates = allow +} + +func (b *Builder) AddKeys(keys []string) (ids []uint32) { + if b.dict == nil { + b.dict = make(map[string]uint32) + b.dictKeys = make([][]byte, 0, len(keys)) + } + + ids = make([]uint32, len(keys)) + for i, key := range keys { + var ok bool + if ids[i], ok = b.dict[key]; ok { + continue + } + + ids[i] = uint32(len(b.dictKeys)) + b.dict[key] = ids[i] + b.dictKeys = append(b.dictKeys, unsafe.Slice(unsafe.StringData(key), len(key))) + } + + return ids +} + +func (b *Builder) AddKey(key string) (id uint32) { + if b.dict == nil { + b.dict = make(map[string]uint32) + b.dictKeys = make([][]byte, 0, 16) + } + + var ok bool + if id, ok = b.dict[key]; ok { + return id + } + + id = uint32(len(b.dictKeys)) + b.dict[key] = id + b.dictKeys = append(b.dictKeys, unsafe.Slice(unsafe.StringData(key), len(key))) + + return id +} + +func (b *Builder) AppendNull() error { + return b.buf.WriteByte(primitiveHeader(PrimitiveNull)) +} + +func (b *Builder) AppendBool(v bool) error { + var t PrimitiveType + if v { + t = PrimitiveBoolTrue + } else { + t = PrimitiveBoolFalse + } + + return b.buf.WriteByte(primitiveHeader(t)) +} + +type primitiveNumeric interface { + int8 | int16 | int32 | int64 | float32 | float64 | + arrow.Date32 | arrow.Time64 +} + +type buffer interface { + io.Writer + io.ByteWriter +} + +func writeBinary[T string | []byte](w buffer, v T) error { + var t PrimitiveType + switch any(v).(type) { + case string: + t = PrimitiveString + case []byte: + t = PrimitiveBinary + } + + if err := w.WriteByte(primitiveHeader(t)); err != nil { + return err + } + + if err := binary.Write(w, binary.LittleEndian, uint32(len(v))); err != nil { + return err + } + + _, err := w.Write([]byte(v)) + return err +} + +func writeNumeric[T primitiveNumeric](w buffer, v T) error { + var t PrimitiveType + switch any(v).(type) { + case int8: + t = PrimitiveInt8 + case int16: + t = PrimitiveInt16 + case int32: + t = PrimitiveInt32 + case int64: + t = PrimitiveInt64 + case float32: + t = PrimitiveFloat + case float64: + t = PrimitiveDouble + case arrow.Date32: + t = PrimitiveDate + case arrow.Time64: + t = PrimitiveTimeMicrosNTZ + } + + if err := w.WriteByte(primitiveHeader(t)); err != nil { + return err + } + + return binary.Write(w, binary.LittleEndian, v) +} + +func (b *Builder) AppendInt(v int64) error { + b.buf.Grow(9) + switch { + case v >= math.MinInt8 && v <= math.MaxInt8: + return writeNumeric(&b.buf, int8(v)) + case v >= math.MinInt16 && v <= math.MaxInt16: + return writeNumeric(&b.buf, int16(v)) + case v >= math.MinInt32 && v <= math.MaxInt32: + return writeNumeric(&b.buf, int32(v)) + default: + return writeNumeric(&b.buf, v) + } +} + +func (b *Builder) 
AppendFloat32(v float32) error { + b.buf.Grow(5) + return writeNumeric(&b.buf, v) +} + +func (b *Builder) AppendFloat64(v float64) error { + b.buf.Grow(9) + return writeNumeric(&b.buf, v) +} + +func (b *Builder) AppendDate(v arrow.Date32) error { + b.buf.Grow(5) + return writeNumeric(&b.buf, v) +} + +func (b *Builder) AppendTimeMicro(v arrow.Time64) error { + b.buf.Grow(9) + return writeNumeric(&b.buf, v) +} + +func (b *Builder) AppendTimestamp(v arrow.Timestamp, useMicros, useUTC bool) error { + b.buf.Grow(9) + var t PrimitiveType + if useMicros { + t = PrimitiveTimestampMicrosNTZ + } else { + t = PrimitiveTimestampNanosNTZ + } + + if useUTC { + t-- + } + + if err := b.buf.WriteByte(primitiveHeader(t)); err != nil { + return err + } + + return binary.Write(&b.buf, binary.LittleEndian, v) +} + +func (b *Builder) AppendBinary(v []byte) error { + b.buf.Grow(5 + len(v)) + return writeBinary(&b.buf, v) +} + +func (b *Builder) AppendString(v string) error { + if len(v) > maxShortStringSize { + b.buf.Grow(5 + len(v)) + return writeBinary(&b.buf, v) + } + + b.buf.Grow(1 + len(v)) + if err := b.buf.WriteByte(shortStrHeader(len(v))); err != nil { + return err + } + + _, err := b.buf.WriteString(v) + return err +} + +func (b *Builder) AppendUUID(v uuid.UUID) error { + b.buf.Grow(17) + if err := b.buf.WriteByte(primitiveHeader(PrimitiveUUID)); err != nil { + return err + } + + m, _ := v.MarshalBinary() + _, err := b.buf.Write(m) + return err +} + +func (b *Builder) AppendDecimal4(scale uint8, v decimal.Decimal32) error { + b.buf.Grow(6) + if err := b.buf.WriteByte(primitiveHeader(PrimitiveDecimal4)); err != nil { + return err + } + + if err := b.buf.WriteByte(scale); err != nil { + return err + } + + return binary.Write(&b.buf, binary.LittleEndian, int32(v)) +} + +func (b *Builder) AppendDecimal8(scale uint8, v decimal.Decimal64) error { + b.buf.Grow(10) + return errors.Join( + b.buf.WriteByte(primitiveHeader(PrimitiveDecimal8)), + b.buf.WriteByte(scale), + binary.Write(&b.buf, binary.LittleEndian, int64(v)), + ) +} + +func (b *Builder) AppendDecimal16(scale uint8, v decimal.Decimal128) error { + b.buf.Grow(18) + return errors.Join( + b.buf.WriteByte(primitiveHeader(PrimitiveDecimal16)), + b.buf.WriteByte(scale), + binary.Write(&b.buf, binary.LittleEndian, v.LowBits()), + binary.Write(&b.buf, binary.LittleEndian, v.HighBits()), + ) +} + +func (b *Builder) Offset() int { + return b.buf.Len() +} + +func (b *Builder) FinishArray(start int, offsets []int) error { + var ( + dataSize, sz = b.buf.Len() - start, len(offsets) + isLarge = sz > math.MaxUint8 + sizeBytes = 1 + ) + + if isLarge { + sizeBytes = 4 + } + + if dataSize < 0 { + return errors.New("invalid array size") + } + + offsetSize := intSize(dataSize) + headerSize := 1 + sizeBytes + (sz+1)*int(offsetSize) + + // shift the just written data to make room for the header section + b.buf.Grow(headerSize) + av := b.buf.AvailableBuffer() + if _, err := b.buf.Write(av[:headerSize]); err != nil { + return err + } + + bs := b.buf.Bytes() + copy(bs[start+headerSize:], bs[start:start+dataSize]) + + // populate the header + bs[start] = arrayHeader(isLarge, offsetSize) + writeOffset(bs[start+1:], sz, uint8(sizeBytes)) + + offsetsStart := start + 1 + sizeBytes + for i, off := range offsets { + writeOffset(bs[offsetsStart+i*int(offsetSize):], off, offsetSize) + } + writeOffset(bs[offsetsStart+sz*int(offsetSize):], dataSize, offsetSize) + + return nil +} + +type FieldEntry struct { + Key string + ID uint32 + Offset int +} + +func (b *Builder) NextField(start int, 
key string) FieldEntry { + id := b.AddKey(key) + return FieldEntry{ + Key: key, + ID: id, + Offset: b.Offset() - start, + } +} + +func (b *Builder) FinishObject(start int, fields []FieldEntry) error { + slices.SortFunc(fields, func(a, b FieldEntry) int { + return cmp.Compare(a.Key, b.Key) + }) + + sz := len(fields) + var maxID uint32 + if sz > 0 { + maxID = fields[0].ID + } + + // if a duplicate key is found, one of two things happens: + // - if allowDuplicates is true, then the field with the greatest + // offset value (the last appended field) is kept. + // - if allowDuplicates is false, then an error is returned + if b.allowDuplicates { + distinctPos := 0 + // maintain a list of distinct keys in-place + for i := 1; i < sz; i++ { + maxID = max(maxID, fields[i].ID) + if fields[i].ID == fields[i-1].ID { + // found a duplicate key. keep the + // field with a greater offset + if fields[distinctPos].Offset < fields[i].Offset { + fields[distinctPos].Offset = fields[i].Offset + } + } else { + // found distinct key, add field to the list + distinctPos++ + fields[distinctPos] = fields[i] + } + } + + if distinctPos+1 < len(fields) { + sz = distinctPos + 1 + // resize fields to size + fields = fields[:sz] + // sort the fields by offsets so that we can move the value + // data of each field to the new offset without overwriting the + // fields after it. + slices.SortFunc(fields, func(a, b FieldEntry) int { + return cmp.Compare(a.Offset, b.Offset) + }) + + buf := b.buf.Bytes() + curOffset := 0 + for i := range sz { + oldOffset := fields[i].Offset + fieldSize := valueSize(buf[start+oldOffset:]) + copy(buf[start+curOffset:], buf[start+oldOffset:start+oldOffset+fieldSize]) + fields[i].Offset = curOffset + curOffset += fieldSize + } + b.buf.Truncate(start + curOffset) + // change back to sort order by field keys to meet variant spec + slices.SortFunc(fields, func(a, b FieldEntry) int { + return cmp.Compare(a.Key, b.Key) + }) + } + } else { + for i := 1; i < sz; i++ { + maxID = max(maxID, fields[i].ID) + if fields[i].Key == fields[i-1].Key { + return fmt.Errorf("disallowed duplicate key found: %s", fields[i].Key) + } + } + } + + var ( + dataSize = b.buf.Len() - start + isLarge = sz > math.MaxUint8 + sizeBytes = 1 + idSize, offsetSize = intSize(int(maxID)), intSize(dataSize) + ) + + if isLarge { + sizeBytes = 4 + } + + if dataSize < 0 { + return errors.New("invalid object size") + } + + headerSize := 1 + sizeBytes + sz*int(idSize) + (sz+1)*int(offsetSize) + // shift the just written data to make room for the header section + b.buf.Grow(headerSize) + av := b.buf.AvailableBuffer() + if _, err := b.buf.Write(av[:headerSize]); err != nil { + return err + } + + bs := b.buf.Bytes() + copy(bs[start+headerSize:], bs[start:start+dataSize]) + + // populate the header + bs[start] = objectHeader(isLarge, idSize, offsetSize) + writeOffset(bs[start+1:], sz, uint8(sizeBytes)) + + idStart := start + 1 + sizeBytes + offsetStart := idStart + sz*int(idSize) + for i, field := range fields { + writeOffset(bs[idStart+i*int(idSize):], int(field.ID), idSize) + writeOffset(bs[offsetStart+i*int(offsetSize):], field.Offset, offsetSize) + } + writeOffset(bs[offsetStart+sz*int(offsetSize):], dataSize, offsetSize) + return nil +} + +func (b *Builder) Build() (Value, error) { + nkeys := len(b.dictKeys) + totalDictSize := 0 + for _, k := range b.dictKeys { + totalDictSize += len(k) + } + + // determine the number of bytes required per offset entry. + // the largest offset is the one-past-the-end value, the total size. 
+ // It's very unlikely that the number of keys could be larger, but + // incorporate that into the calculation in case of pathological data. + maxSize := max(totalDictSize, nkeys) + if maxSize > maxSizeLimit { + return Value{}, fmt.Errorf("metadata size too large: %d", maxSize) + } + + offsetSize := intSize(int(maxSize)) + offsetStart := 1 + offsetSize + stringStart := int(offsetStart) + (nkeys+1)*int(offsetSize) + metadataSize := stringStart + totalDictSize + + if metadataSize > maxSizeLimit { + return Value{}, fmt.Errorf("metadata size too large: %d", metadataSize) + } + + meta := make([]byte, metadataSize) + + meta[0] = supportedVersion | ((offsetSize - 1) << 6) + if nkeys > 0 && slices.IsSortedFunc(b.dictKeys, bytes.Compare) { + meta[0] |= 1 << 4 + } + writeOffset(meta[1:], nkeys, offsetSize) + + curOffset := 0 + for i, k := range b.dictKeys { + writeOffset(meta[int(offsetStart)+i*int(offsetSize):], curOffset, offsetSize) + curOffset += copy(meta[stringStart+curOffset:], k) + } + writeOffset(meta[int(offsetStart)+nkeys*int(offsetSize):], curOffset, offsetSize) + + return Value{ + value: b.buf.Bytes(), + meta: Metadata{ + data: meta, + keys: b.dictKeys, + }, + }, nil +} diff --git a/arrow/extensions/variant/builder_test.go b/arrow/extensions/variant/builder_test.go new file mode 100644 index 00000000..ae49bcf7 --- /dev/null +++ b/arrow/extensions/variant/builder_test.go @@ -0,0 +1,297 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
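For a dictionary containing the single key "a", the metadata that Build produces is five bytes: a header of 0x11 (version 1 in the low bits, the sorted-strings flag, and a one-byte offset width), the dictionary size 0x01, the offsets 0x00 and 0x01, and the key byte itself. A sketch in the style of the tests that follow, assuming intSize selects a one-byte offset width for a dictionary this small and that the supported version constant is 1 (matching the Version() == 1 assertion below):

package variant_test

import (
	"fmt"

	"github.com/apache/arrow-go/v18/arrow/extensions/variant"
)

func ExampleBuilder_metadata() {
	var b variant.Builder

	// Build an object with one field so the dictionary holds the single key "a".
	start := b.Offset()
	fields := []variant.FieldEntry{b.NextField(start, "a")}
	_ = b.AppendInt(1)
	_ = b.FinishObject(start, fields)

	v, _ := b.Build()

	// Expected bytes: 11 01 00 01 61
	// header | dictionary size | offset 0 | offset 1 | 'a'
	fmt.Printf("% x\n", v.Metadata().Bytes())
}
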
+ +package variant_test + +import ( + "encoding/json" + "testing" + + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/extensions/variant" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildNullValue(t *testing.T) { + var b variant.Builder + b.AppendNull() + + v, err := b.Build() + require.NoError(t, err) + + assert.Equal(t, variant.Null, v.Type()) + assert.EqualValues(t, 1, v.Metadata().Version()) + assert.Zero(t, v.Metadata().DictionarySize()) +} + +func TestBuildPrimitive(t *testing.T) { + tests := []struct { + name string + op func(*variant.Builder) error + }{ + {"primitive_boolean_true", func(b *variant.Builder) error { + return b.AppendBool(true) + }}, + {"primitive_boolean_false", func(b *variant.Builder) error { + return b.AppendBool(false) + }}, + // AppendInt will use the smallest possible int type + {"primitive_int8", func(b *variant.Builder) error { return b.AppendInt(42) }}, + {"primitive_int16", func(b *variant.Builder) error { return b.AppendInt(1234) }}, + {"primitive_int32", func(b *variant.Builder) error { return b.AppendInt(123456) }}, + // FIXME: https://github.com/apache/parquet-testing/issues/82 + // primitive_int64 is an int32 value, but the metadata is int64 + {"primitive_int64", func(b *variant.Builder) error { return b.AppendInt(12345678) }}, + {"primitive_float", func(b *variant.Builder) error { return b.AppendFloat32(1234568000) }}, + {"primitive_double", func(b *variant.Builder) error { return b.AppendFloat64(1234567890.1234) }}, + {"primitive_string", func(b *variant.Builder) error { + return b.AppendString(`This string is longer than 64 bytes and therefore does not fit in a short_string and it also includes several non ascii characters such as 🐢, 💖, ♥️, 🎣 and 🤦!!`) + }}, + {"short_string", func(b *variant.Builder) error { return b.AppendString(`Less than 64 bytes (❤️ with utf8)`) }}, + // 031337deadbeefcafe + {"primitive_binary", func(b *variant.Builder) error { + return b.AppendBinary([]byte{0x03, 0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}) + }}, + {"primitive_decimal4", func(b *variant.Builder) error { return b.AppendDecimal4(2, 1234) }}, + {"primitive_decimal8", func(b *variant.Builder) error { return b.AppendDecimal8(2, 1234567890) }}, + {"primitive_decimal16", func(b *variant.Builder) error { return b.AppendDecimal16(2, decimal128.FromU64(1234567891234567890)) }}, + {"primitive_date", func(b *variant.Builder) error { return b.AppendDate(20194) }}, + {"primitive_timestamp", func(b *variant.Builder) error { return b.AppendTimestamp(1744821296780000, true, true) }}, + {"primitive_timestampntz", func(b *variant.Builder) error { return b.AppendTimestamp(1744806896780000, true, false) }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expected := loadVariant(t, tt.name) + + var b variant.Builder + require.NoError(t, tt.op(&b)) + + v, err := b.Build() + require.NoError(t, err) + + assert.Equal(t, expected.Type(), v.Type()) + assert.Equal(t, expected.Bytes(), v.Bytes()) + assert.Equal(t, expected.Metadata().Bytes(), v.Metadata().Bytes()) + }) + } +} + +func TestBuildInt64(t *testing.T) { + var b variant.Builder + require.NoError(t, b.AppendInt(1234567890987654321)) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Int64, v.Type()) + assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveInt64), + 0xB1, 0x1C, 0x6C, 0xB1, 0xF4, 0x10, 0x22, 0x11}, v.Bytes()) +} + +func TestBuildObjec(t 
*testing.T) { + var b variant.Builder + start := b.Offset() + + fields := make([]variant.FieldEntry, 0, 7) + + fields = append(fields, b.NextField(start, "int_field")) + require.NoError(t, b.AppendInt(1)) + + fields = append(fields, b.NextField(start, "double_field")) + require.NoError(t, b.AppendDecimal4(8, 123456789)) + + fields = append(fields, b.NextField(start, "boolean_true_field")) + require.NoError(t, b.AppendBool(true)) + + fields = append(fields, b.NextField(start, "boolean_false_field")) + require.NoError(t, b.AppendBool(false)) + + fields = append(fields, b.NextField(start, "string_field")) + require.NoError(t, b.AppendString("Apache Parquet")) + + fields = append(fields, b.NextField(start, "null_field")) + require.NoError(t, b.AppendNull()) + + fields = append(fields, b.NextField(start, "timestamp_field")) + require.NoError(t, b.AppendString("2025-04-16T12:34:56.78")) + + require.NoError(t, b.FinishObject(start, fields)) + v, err := b.Build() + require.NoError(t, err) + + assert.Equal(t, variant.Object, v.Type()) + expected := loadVariant(t, "object_primitive") + + assert.Equal(t, expected.Metadata().DictionarySize(), v.Metadata().DictionarySize()) + assert.Equal(t, expected.Metadata().Bytes(), v.Metadata().Bytes()) + assert.Equal(t, expected.Bytes(), v.Bytes()) +} + +func TestBuildObjectNested(t *testing.T) { + var b variant.Builder + + start := b.Offset() + topFields := make([]variant.FieldEntry, 0, 3) + + topFields = append(topFields, b.NextField(start, "id")) + require.NoError(t, b.AppendInt(1)) + + topFields = append(topFields, b.NextField(start, "observation")) + + observeFields := make([]variant.FieldEntry, 0, 3) + observeStart := b.Offset() + observeFields = append(observeFields, b.NextField(observeStart, "location")) + require.NoError(t, b.AppendString("In the Volcano")) + observeFields = append(observeFields, b.NextField(observeStart, "time")) + require.NoError(t, b.AppendString("12:34:56")) + observeFields = append(observeFields, b.NextField(observeStart, "value")) + + valueStart := b.Offset() + valueFields := make([]variant.FieldEntry, 0, 2) + valueFields = append(valueFields, b.NextField(valueStart, "humidity")) + require.NoError(t, b.AppendInt(456)) + valueFields = append(valueFields, b.NextField(valueStart, "temperature")) + require.NoError(t, b.AppendInt(123)) + + require.NoError(t, b.FinishObject(valueStart, valueFields)) + require.NoError(t, b.FinishObject(observeStart, observeFields)) + + topFields = append(topFields, b.NextField(start, "species")) + speciesStart := b.Offset() + speciesFields := make([]variant.FieldEntry, 0, 2) + speciesFields = append(speciesFields, b.NextField(speciesStart, "name")) + require.NoError(t, b.AppendString("lava monster")) + + speciesFields = append(speciesFields, b.NextField(speciesStart, "population")) + require.NoError(t, b.AppendInt(6789)) + + require.NoError(t, b.FinishObject(speciesStart, speciesFields)) + require.NoError(t, b.FinishObject(start, topFields)) + + v, err := b.Build() + require.NoError(t, err) + + out, err := json.Marshal(v) + require.NoError(t, err) + + assert.JSONEq(t, `{ + "id": 1, + "observation": { + "location": "In the Volcano", + "time": "12:34:56", + "value": { + "humidity": 456, + "temperature": 123 + } + }, + "species": { + "name": "lava monster", + "population": 6789 + } + }`, string(out)) +} + +func TestBuildUUID(t *testing.T) { + u := uuid.MustParse("00112233-4455-6677-8899-aabbccddeeff") + + var b variant.Builder + require.NoError(t, b.AppendUUID(u)) + v, err := b.Build() + require.NoError(t, 
err) + assert.Equal(t, variant.UUID, v.Type()) + assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveUUID), + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, v.Bytes()) +} + +func TestBuildTimestampNanos(t *testing.T) { + t.Run("ts nanos tz negative", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.AppendTimestamp(-1, false, true)) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.TimestampNanos, v.Type()) + assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveTimestampNanos), + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, v.Bytes()) + }) + + t.Run("ts nanos tz positive", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.AppendTimestamp(1744877350123456789, false, true)) + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.TimestampNanos, v.Type()) + assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveTimestampNanos), + 0x15, 0xC9, 0xBB, 0x86, 0xB4, 0x0C, 0x37, 0x18}, v.Bytes()) + }) + + t.Run("ts nanos ntz positive", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.AppendTimestamp(1744877350123456789, false, false)) + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.TimestampNanosNTZ, v.Type()) + assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveTimestampNanosNTZ), + 0x15, 0xC9, 0xBB, 0x86, 0xB4, 0x0C, 0x37, 0x18}, v.Bytes()) + }) +} + +func TestBuildArrayValues(t *testing.T) { + t.Run("array primitive", func(t *testing.T) { + var b variant.Builder + + start := b.Offset() + offsets := make([]int, 0, 4) + + offsets = append(offsets, b.Offset()-start) + require.NoError(t, b.AppendInt(2)) + + offsets = append(offsets, b.Offset()-start) + require.NoError(t, b.AppendInt(1)) + + offsets = append(offsets, b.Offset()-start) + require.NoError(t, b.AppendInt(5)) + + offsets = append(offsets, b.Offset()-start) + require.NoError(t, b.AppendInt(9)) + + require.NoError(t, b.FinishArray(start, offsets)) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Array, v.Type()) + + expected := loadVariant(t, "array_primitive") + assert.Equal(t, expected.Metadata().Bytes(), v.Metadata().Bytes()) + assert.Equal(t, expected.Bytes(), v.Bytes()) + }) + + t.Run("array empty", func(t *testing.T) { + var b variant.Builder + + require.NoError(t, b.FinishArray(b.Offset(), nil)) + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Array, v.Type()) + + expected := loadVariant(t, "array_empty") + assert.Equal(t, expected.Metadata().Bytes(), v.Metadata().Bytes()) + assert.Equal(t, expected.Bytes(), v.Bytes()) + }) +} diff --git a/arrow/extensions/variant/primitive_type_string.go b/arrow/extensions/variant/primitive_type_string.go new file mode 100644 index 00000000..f24ed4d4 --- /dev/null +++ b/arrow/extensions/variant/primitive_type_string.go @@ -0,0 +1,45 @@ +// Code generated by "stringer -type=PrimitiveType -linecomment -output=primitive_type_string.go"; DO NOT EDIT. + +package variant + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. 
+ var x [1]struct{} + _ = x[PrimitiveInvalid - -1] + _ = x[PrimitiveNull-0] + _ = x[PrimitiveBoolTrue-1] + _ = x[PrimitiveBoolFalse-2] + _ = x[PrimitiveInt8-3] + _ = x[PrimitiveInt16-4] + _ = x[PrimitiveInt32-5] + _ = x[PrimitiveInt64-6] + _ = x[PrimitiveDouble-7] + _ = x[PrimitiveDecimal4-8] + _ = x[PrimitiveDecimal8-9] + _ = x[PrimitiveDecimal16-10] + _ = x[PrimitiveDate-11] + _ = x[PrimitiveTimestampMicros-12] + _ = x[PrimitiveTimestampMicrosNTZ-13] + _ = x[PrimitiveFloat-14] + _ = x[PrimitiveBinary-15] + _ = x[PrimitiveString-16] + _ = x[PrimitiveTimeMicrosNTZ-17] + _ = x[PrimitiveTimestampNanos-18] + _ = x[PrimitiveTimestampNanosNTZ-19] + _ = x[PrimitiveUUID-20] +} + +const _PrimitiveType_name = "UnknownNullBoolTrueBoolFalseInt8Int16Int32Int64DoubleDecimal32Decimal64Decimal128DateTimestamp(micros)TimestampNTZ(micros)FloatBinaryStringTimeNTZ(micros)Timestamp(nanos)TimestampNTZ(nanos)UUID" + +var _PrimitiveType_index = [...]uint8{0, 7, 11, 19, 28, 32, 37, 42, 47, 53, 62, 71, 81, 85, 102, 122, 127, 133, 139, 154, 170, 189, 193} + +func (i PrimitiveType) String() string { + i -= -1 + if i < 0 || i >= PrimitiveType(len(_PrimitiveType_index)-1) { + return "PrimitiveType(" + strconv.FormatInt(int64(i+-1), 10) + ")" + } + return _PrimitiveType_name[_PrimitiveType_index[i]:_PrimitiveType_index[i+1]] +} diff --git a/arrow/extensions/variant/utils.go b/arrow/extensions/variant/utils.go new file mode 100644 index 00000000..523dce87 --- /dev/null +++ b/arrow/extensions/variant/utils.go @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package variant + +import ( + "encoding/binary" + "math" + "unsafe" + + "github.com/apache/arrow-go/v18/arrow/endian" + "github.com/apache/arrow-go/v18/arrow/internal/debug" +) + +func readLEU32(b []byte) uint32 { + debug.Assert(len(b) <= 4, "buffer too large") + debug.Assert(len(b) >= 1, "buffer too small") + + var result uint32 + v := (*[4]byte)(unsafe.Pointer(&result)) + copy(v[:], b) + + return endian.FromLE(result) +} + +func readLEU64(b []byte) uint64 { + debug.Assert(len(b) <= 8, "buffer too large") + debug.Assert(len(b) >= 1, "buffer too small") + + var result uint64 + v := (*[8]byte)(unsafe.Pointer(&result)) + copy(v[:], b) + + return endian.FromLE(result) +} + +func readExact[T int8 | int16 | int32 | int64 | float32 | float64](b []byte) T { + debug.Assert(len(b) >= binary.Size(T(0)), "buffer size mismatch") + var result T + binary.Decode(b, binary.LittleEndian, &result) + return result +} + +func primitiveHeader(t PrimitiveType) byte { + return (byte(t)<<2 | byte(BasicPrimitive)) +} + +func shortStrHeader(sz int) byte { + return byte(sz<<2) | byte(BasicShortString) +} + +func arrayHeader(large bool, offsetSize uint8) byte { + var largeBit byte + if large { + largeBit = 1 + } + + return (largeBit << (basicTypeBits + 2)) | + ((offsetSize - 1) << basicTypeBits) | byte(BasicArray) +} + +func objectHeader(large bool, idSize, offsetSize uint8) byte { + var largeBit byte + if large { + largeBit = 1 + } + + return (largeBit << (basicTypeBits + 4)) | + ((idSize - 1) << (basicTypeBits + 2)) | + ((offsetSize - 1) << basicTypeBits) | byte(BasicObject) +} + +func intSize(v int) uint8 { + debug.Assert(v <= maxSizeLimit, "size too large") + debug.Assert(v >= 0, "size cannot be negative") + + switch { + case v <= math.MaxUint8: + return 1 + case v <= math.MaxUint16: + return 2 + case v <= 0xFFFFFF: // MaxUint24 + return 3 + default: + return 4 + } +} + +func writeOffset(buf []byte, v int, nbytes uint8) { + debug.Assert(nbytes <= 4, "nbytes size too large") + debug.Assert(nbytes >= 1, "nbytes size too small") + + for i := range nbytes { + buf[i] = byte((v >> (i * 8)) & 0xFF) + } +} + +func valueSize(v []byte) int { + basicType, typeInfo := v[0]&basicTypeMask, (v[0]>>basicTypeBits)&typeInfoMask + switch basicType { + case byte(BasicShortString): + return 1 + int(typeInfo) + case byte(BasicObject): + var szBytes uint8 = 1 + if ((typeInfo >> 4) & 0x1) != 0 { + szBytes = 4 + } + + sz := readLEU32(v[1 : 1+szBytes]) + idSize, offsetSize := ((typeInfo>>2)&0b11)+1, uint32((typeInfo&0b11)+1) + idStart := 1 + szBytes + offsetStart := uint32(idStart) + sz*uint32(idSize) + dataStart := offsetStart + (sz+1)*offsetSize + + idx := offsetStart + sz*uint32(offsetSize) + return int(dataStart + readLEU32(v[idx:idx+offsetSize])) + case byte(BasicArray): + var szBytes uint8 = 1 + if ((typeInfo >> 4) & 0x1) != 0 { + szBytes = 4 + } + + sz := readLEU32(v[1 : 1+szBytes]) + offsetSize, offsetStart := uint32((typeInfo&0b11)+1), uint32(1+szBytes) + dataStart := offsetStart + (sz+1)*offsetSize + + idx := offsetStart + sz*uint32(offsetSize) + return int(dataStart + readLEU32(v[idx:idx+offsetSize])) + default: + switch PrimitiveType(typeInfo) { + case PrimitiveNull, PrimitiveBoolTrue, PrimitiveBoolFalse: + return 1 + case PrimitiveInt8: + return 2 + case PrimitiveInt16: + return 3 + case PrimitiveInt32, PrimitiveDate, PrimitiveFloat, PrimitiveTimeMicrosNTZ: + return 5 + case PrimitiveInt64, PrimitiveDouble, + PrimitiveTimestampMicros, PrimitiveTimestampMicrosNTZ, + PrimitiveTimestampNanos, PrimitiveTimestampNanosNTZ: + 
return 9 + case PrimitiveDecimal4: + return 6 + case PrimitiveDecimal8: + return 10 + case PrimitiveDecimal16: + return 18 + case PrimitiveBinary, PrimitiveString: + return 5 + int(readLEU32(v[1:5])) + case PrimitiveUUID: + return 17 + default: + panic("unknown primitive type") + } + } +} diff --git a/arrow/extensions/variant/variant.go b/arrow/extensions/variant/variant.go new file mode 100644 index 00000000..8b1e3f4e --- /dev/null +++ b/arrow/extensions/variant/variant.go @@ -0,0 +1,649 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variant + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "iter" + "maps" + "slices" + "strings" + "time" + "unsafe" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/internal/debug" + "github.com/google/uuid" +) + +//go:generate go tool stringer -type=BasicType -linecomment -output=basic_type_string.go +//go:generate go tool stringer -type=PrimitiveType -linecomment -output=primitive_type_string.go + +type BasicType int + +const ( + BasicUndefined BasicType = iota - 1 // Unknown + BasicPrimitive // Primitive + BasicShortString // ShortString + BasicObject // Object + BasicArray // Array +) + +func basicTypeFromHeader(hdr byte) BasicType { + return BasicType(hdr & basicTypeMask) +} + +type PrimitiveType int + +const ( + PrimitiveInvalid PrimitiveType = iota - 1 // Unknown + PrimitiveNull // Null + PrimitiveBoolTrue // BoolTrue + PrimitiveBoolFalse // BoolFalse + PrimitiveInt8 // Int8 + PrimitiveInt16 // Int16 + PrimitiveInt32 // Int32 + PrimitiveInt64 // Int64 + PrimitiveDouble // Double + PrimitiveDecimal4 // Decimal32 + PrimitiveDecimal8 // Decimal64 + PrimitiveDecimal16 // Decimal128 + PrimitiveDate // Date + PrimitiveTimestampMicros // Timestamp(micros) + PrimitiveTimestampMicrosNTZ // TimestampNTZ(micros) + PrimitiveFloat // Float + PrimitiveBinary // Binary + PrimitiveString // String + PrimitiveTimeMicrosNTZ // TimeNTZ(micros) + PrimitiveTimestampNanos // Timestamp(nanos) + PrimitiveTimestampNanosNTZ // TimestampNTZ(nanos) + PrimitiveUUID // UUID +) + +func primitiveTypeFromHeader(hdr byte) PrimitiveType { + return PrimitiveType((hdr >> basicTypeBits) & typeInfoMask) +} + +type Type int + +const ( + Object Type = iota + Array + Null + Bool + Int8 + Int16 + Int32 + Int64 + String + Double + Decimal4 + Decimal8 + Decimal16 + Date + TimestampMicros + TimestampMicrosNTZ + Float + Binary + Time + TimestampNanos + TimestampNanosNTZ + UUID +) + +const ( + versionMask uint8 = 0x0F + sortedStrMask uint8 = 0b10000 + basicTypeMask uint8 = 0x3 + basicTypeBits uint8 = 2 + typeInfoMask uint8 = 0x3F + hdrSizeBytes = 1 + minOffsetSizeBytes = 1 + 
maxOffsetSizeBytes = 4
+
+ // mask is applied after shift
+ offsetSizeMask uint8 = 0b11
+ offsetSizeBitShift uint8 = 6
+ supportedVersion = 1
+ maxShortStringSize = 0x3F
+ maxSizeLimit = 128 * 1024 * 1024 // 128MB
+)
+
+var (
+ EmptyMetadataBytes = [3]byte{0x1, 0, 0}
+)
+
+type Metadata struct {
+ data []byte
+ keys [][]byte
+}
+
+func NewMetadata(data []byte) (Metadata, error) {
+ m := Metadata{data: data}
+ if len(data) < hdrSizeBytes+minOffsetSizeBytes*2 {
+ return m, fmt.Errorf("invalid variant metadata: too short: size=%d", len(data))
+ }
+
+ if m.Version() != supportedVersion {
+ return m, fmt.Errorf("invalid variant metadata: unsupported version: %d", m.Version())
+ }
+
+ offsetSz := m.OffsetSize()
+ if offsetSz < minOffsetSizeBytes || offsetSz > maxOffsetSizeBytes {
+ return m, fmt.Errorf("invalid variant metadata: invalid offset size: %d", offsetSz)
+ }
+
+ dictSize, err := m.loadDictionary(offsetSz)
+ if err != nil {
+ return m, err
+ }
+
+ if hdrSizeBytes+int(dictSize+1)*int(offsetSz) > len(m.data) {
+ return m, fmt.Errorf("invalid variant metadata: offset out of range: %d > %d",
+ (dictSize+hdrSizeBytes)*uint32(offsetSz), len(m.data))
+ }
+
+ return m, nil
+}
+
+func (m *Metadata) Clone() Metadata {
+ return Metadata{
+ data: bytes.Clone(m.data),
+ // shallow copy of the values, but the slice is copied
+ // more efficient, and nothing should be mutating the keys
+ // so it's probably safe, but something we should keep in mind
+ keys: slices.Clone(m.keys),
+ }
+}
+
+func (m *Metadata) loadDictionary(offsetSz uint8) (uint32, error) {
+ if int(offsetSz+hdrSizeBytes) > len(m.data) {
+ return 0, errors.New("invalid variant metadata: too short for dictionary size")
+ }
+
+ dictSize := readLEU32(m.data[hdrSizeBytes : hdrSizeBytes+offsetSz])
+ m.keys = make([][]byte, dictSize)
+
+ if dictSize == 0 {
+ return 0, nil
+ }
+
+ // first offset is always 0
+ offsetStart, offsetPos := uint32(0), hdrSizeBytes+offsetSz
+ valuesStart := hdrSizeBytes + (dictSize+2)*uint32(offsetSz)
+ for i := range dictSize {
+ offsetPos += offsetSz
+ end := readLEU32(m.data[offsetPos : offsetPos+offsetSz])
+
+ keySize := end - offsetStart
+ valStart := valuesStart + offsetStart
+ if valStart+keySize > uint32(len(m.data)) {
+ return 0, fmt.Errorf("invalid variant metadata: string data out of range: %d + %d > %d",
+ valStart, keySize, len(m.data))
+ }
+ m.keys[i] = m.data[valStart : valStart+keySize]
+ offsetStart += keySize
+ }
+
+ return dictSize, nil
+}
+
+func (m Metadata) Bytes() []byte { return m.data }
+
+func (m Metadata) Version() uint8 { return m.data[0] & versionMask }
+func (m Metadata) SortedAndUnique() bool { return m.data[0]&sortedStrMask != 0 }
+func (m Metadata) OffsetSize() uint8 {
+ return ((m.data[0] >> offsetSizeBitShift) & offsetSizeMask) + 1
+}
+
+func (m Metadata) DictionarySize() uint32 { return uint32(len(m.keys)) }
+
+func (m Metadata) KeyAt(id uint32) (string, error) {
+ if id >= uint32(len(m.keys)) {
+ return "", fmt.Errorf("invalid variant metadata: id out of range: %d >= %d",
+ id, len(m.keys))
+ }
+
+ return unsafe.String(&m.keys[id][0], len(m.keys[id])), nil
+}
+
+func (m Metadata) IdFor(key string) []uint32 {
+ k := unsafe.Slice(unsafe.StringData(key), len(key))
+
+ var ret []uint32
+ if m.SortedAndUnique() {
+ idx, found := slices.BinarySearchFunc(m.keys, k, bytes.Compare)
+ if found {
+ ret = append(ret, uint32(idx))
+ }
+
+ return ret
+ }
+
+ for i, kb := range m.keys {
+ if bytes.Equal(kb, k) {
+ ret = append(ret, uint32(i))
+ }
+ }
+
+ return ret
+}
+
+type DecimalValue[T
decimal.DecimalTypes] struct {
+ Scale uint8
+ Value decimal.Num[T]
+}
+
+func (v DecimalValue[T]) MarshalJSON() ([]byte, error) {
+ return []byte(v.Value.ToString(int32(v.Scale))), nil
+}
+
+type ArrayValue struct {
+ value []byte
+ meta Metadata
+
+ numElements uint32
+ dataStart uint32
+ offsetSize uint8
+ offsetStart uint8
+}
+
+func (v ArrayValue) MarshalJSON() ([]byte, error) {
+ return json.Marshal(v.Values())
+}
+
+func (v ArrayValue) NumElements() uint32 { return v.numElements }
+
+func (v ArrayValue) Values() []Value {
+ values := make([]Value, v.numElements)
+ for i := range v.numElements {
+ idx := uint32(v.offsetStart) + i*uint32(v.offsetSize)
+ offset := readLEU32(v.value[idx : idx+uint32(v.offsetSize)])
+ values[i] = Value{value: v.value[v.dataStart+offset:], meta: v.meta}
+ }
+ return values
+}
+
+func (v ArrayValue) Value(i uint32) (Value, error) {
+ if i >= v.numElements {
+ return Value{}, fmt.Errorf("%w: invalid array value: index out of range: %d >= %d",
+ arrow.ErrIndex, i, v.numElements)
+ }
+
+ idx := uint32(v.offsetStart) + i*uint32(v.offsetSize)
+ offset := readLEU32(v.value[idx : idx+uint32(v.offsetSize)])
+
+ return Value{meta: v.meta, value: v.value[v.dataStart+offset:]}, nil
+}
+
+type ObjectValue struct {
+ value []byte
+ meta Metadata
+
+ numElements uint32
+ offsetStart uint32
+ dataStart uint32
+ idSize uint8
+ offsetSize uint8
+ idStart uint8
+}
+
+type ObjectField struct {
+ Key string
+ Value Value
+}
+
+func (v ObjectValue) NumElements() uint32 { return v.numElements }
+func (v ObjectValue) ValueByKey(key string) (ObjectField, error) {
+ n := v.numElements
+
+ // if total list size is smaller than threshold, linear search will
+ // likely be faster than a binary search
+ const binarySearchThreshold = 32
+ if n < binarySearchThreshold {
+ for i := range n {
+ idx := uint32(v.idStart) + i*uint32(v.idSize)
+ id := readLEU32(v.value[idx : idx+uint32(v.idSize)])
+ k, err := v.meta.KeyAt(id)
+ if err != nil {
+ return ObjectField{}, fmt.Errorf("invalid object value: fieldID at idx %d is not in metadata", idx)
+ }
+ if k == key {
+ idx := uint32(v.offsetStart) + uint32(v.offsetSize)*i
+ offset := readLEU32(v.value[idx : idx+uint32(v.offsetSize)])
+ return ObjectField{
+ Key: key,
+ Value: Value{value: v.value[v.dataStart+offset:], meta: v.meta}}, nil
+ }
+ }
+ return ObjectField{}, arrow.ErrNotFound
+ }
+
+ i, j := uint32(0), n
+ for i < j {
+ mid := (i + j) >> 1
+ idx := uint32(v.idStart) + mid*uint32(v.idSize)
+ id := readLEU32(v.value[idx : idx+uint32(v.idSize)])
+ k, err := v.meta.KeyAt(id)
+ if err != nil {
+ return ObjectField{}, fmt.Errorf("invalid object value: fieldID at idx %d is not in metadata", idx)
+ }
+
+ switch strings.Compare(k, key) {
+ case -1:
+ i = mid + 1
+ case 0:
+ idx := uint32(v.offsetStart) + uint32(v.offsetSize)*mid
+ offset := readLEU32(v.value[idx : idx+uint32(v.offsetSize)])
+
+ return ObjectField{
+ Key: key,
+ Value: Value{value: v.value[v.dataStart+offset:], meta: v.meta}}, nil
+ case 1:
+ j = mid
+ }
+ }
+
+ return ObjectField{}, arrow.ErrNotFound
+}
+
+func (v ObjectValue) FieldAt(i uint32) (ObjectField, error) {
+ if i >= v.numElements {
+ return ObjectField{}, fmt.Errorf("%w: invalid object value: index out of range: %d >= %d",
+ arrow.ErrIndex, i, v.numElements)
+ }
+
+ idx := uint32(v.idStart) + i*uint32(v.idSize)
+ id := readLEU32(v.value[idx : idx+uint32(v.idSize)])
+ k, err := v.meta.KeyAt(id)
+ if err != nil {
+ return ObjectField{}, fmt.Errorf("invalid object value: fieldID at idx %d is not in metadata", idx)
+ 
} + + offsetIdx := uint32(v.offsetStart) + i*uint32(v.offsetSize) + offset := readLEU32(v.value[offsetIdx : offsetIdx+uint32(v.offsetSize)]) + + return ObjectField{ + Key: k, + Value: Value{value: v.value[v.dataStart+offset:], meta: v.meta}}, nil +} + +func (v ObjectValue) Values() iter.Seq2[string, Value] { + return func(yield func(string, Value) bool) { + for i := range v.numElements { + idx := uint32(v.idStart) + i*uint32(v.idSize) + id := readLEU32(v.value[idx : idx+uint32(v.idSize)]) + k, err := v.meta.KeyAt(id) + if err != nil { + return + } + + offsetIdx := uint32(v.offsetStart) + i*uint32(v.offsetSize) + offset := readLEU32(v.value[offsetIdx : offsetIdx+uint32(v.offsetSize)]) + + if !yield(k, Value{value: v.value[v.dataStart+offset:], meta: v.meta}) { + return + } + } + } +} + +func (v ObjectValue) MarshalJSON() ([]byte, error) { + // for now we'll use a naive approach and just build a map + // then marshal it. This is not the most efficient way to do this + // but it is the simplest and most straightforward. + mapping := make(map[string]Value) + maps.Insert(mapping, v.Values()) + return json.Marshal(mapping) +} + +type Value struct { + value []byte + meta Metadata +} + +func NewWithMetadata(meta Metadata, value []byte) (Value, error) { + if len(value) == 0 { + return Value{}, errors.New("invalid variant value: empty") + } + + return Value{value: value, meta: meta}, nil +} + +func New(meta, value []byte) (Value, error) { + m, err := NewMetadata(meta) + if err != nil { + return Value{}, err + } + + return NewWithMetadata(m, value) +} + +func (v Value) Bytes() []byte { return v.value } + +func (v Value) Clone() Value { return Value{value: bytes.Clone(v.value)} } + +func (v Value) Metadata() Metadata { return v.meta } + +func (v Value) BasicType() BasicType { + return basicTypeFromHeader(v.value[0]) +} + +func (v Value) Type() Type { + switch t := v.BasicType(); t { + case BasicPrimitive: + switch primType := primitiveTypeFromHeader(v.value[0]); primType { + case PrimitiveNull: + return Null + case PrimitiveBoolTrue, PrimitiveBoolFalse: + return Bool + case PrimitiveInt8: + return Int8 + case PrimitiveInt16: + return Int16 + case PrimitiveInt32: + return Int32 + case PrimitiveInt64: + return Int64 + case PrimitiveDouble: + return Double + case PrimitiveDecimal4: + return Decimal4 + case PrimitiveDecimal8: + return Decimal8 + case PrimitiveDecimal16: + return Decimal16 + case PrimitiveDate: + return Date + case PrimitiveTimestampMicros: + return TimestampMicros + case PrimitiveTimestampMicrosNTZ: + return TimestampMicrosNTZ + case PrimitiveFloat: + return Float + case PrimitiveBinary: + return Binary + case PrimitiveString: + return String + case PrimitiveTimeMicrosNTZ: + return Time + case PrimitiveTimestampNanos: + return TimestampNanos + case PrimitiveTimestampNanosNTZ: + return TimestampNanosNTZ + case PrimitiveUUID: + return UUID + default: + panic(fmt.Errorf("invalid primitive type found: %d", primType)) + } + case BasicShortString: + return String + case BasicObject: + return Object + case BasicArray: + return Array + default: + panic(fmt.Errorf("invalid basic type found: %d", t)) + } +} + +func (v Value) Value() any { + switch t := v.BasicType(); t { + case BasicPrimitive: + switch primType := primitiveTypeFromHeader(v.value[0]); primType { + case PrimitiveNull: + return nil + case PrimitiveBoolTrue: + return true + case PrimitiveBoolFalse: + return false + case PrimitiveInt8: + return readExact[int8](v.value[1:]) + case PrimitiveInt16: + return readExact[int16](v.value[1:]) + 
case PrimitiveInt32:
+ return readExact[int32](v.value[1:])
+ case PrimitiveInt64:
+ return readExact[int64](v.value[1:])
+ case PrimitiveDouble:
+ return readExact[float64](v.value[1:])
+ case PrimitiveFloat:
+ return readExact[float32](v.value[1:])
+ case PrimitiveDate:
+ return arrow.Date32(readExact[int32](v.value[1:]))
+ case PrimitiveTimestampMicros, PrimitiveTimestampMicrosNTZ,
+ PrimitiveTimestampNanos, PrimitiveTimestampNanosNTZ:
+ return arrow.Timestamp(readExact[int64](v.value[1:]))
+ case PrimitiveTimeMicrosNTZ:
+ return arrow.Time32(readExact[int32](v.value[1:]))
+ case PrimitiveUUID:
+ debug.Assert(len(v.value[1:]) == 16, "invalid UUID length")
+ return uuid.Must(uuid.FromBytes(v.value[1:]))
+ case PrimitiveBinary:
+ sz := binary.LittleEndian.Uint32(v.value[1:5])
+ return v.value[5 : 5+sz]
+ case PrimitiveString:
+ sz := binary.LittleEndian.Uint32(v.value[1:5])
+ return unsafe.String(&v.value[5], sz)
+ case PrimitiveDecimal4:
+ scale := uint8(v.value[1])
+ val := decimal.Decimal32(readExact[int32](v.value[2:]))
+ return DecimalValue[decimal.Decimal32]{Scale: scale, Value: val}
+ case PrimitiveDecimal8:
+ scale := uint8(v.value[1])
+ val := decimal.Decimal64(readExact[int64](v.value[2:]))
+ return DecimalValue[decimal.Decimal64]{Scale: scale, Value: val}
+ case PrimitiveDecimal16:
+ scale := uint8(v.value[1])
+ lowBits := readLEU64(v.value[2:10])
+ highBits := readExact[int64](v.value[10:])
+ return DecimalValue[decimal.Decimal128]{
+ Scale: scale,
+ Value: decimal128.New(highBits, lowBits),
+ }
+ }
+ case BasicShortString:
+ sz := int(v.value[0] >> 2)
+ return unsafe.String(&v.value[1], sz)
+ case BasicObject:
+ valueHdr := (v.value[0] >> basicTypeBits)
+ fieldOffsetSz := (valueHdr & 0b11) + 1
+ fieldIdSz := ((valueHdr >> 2) & 0b11) + 1
+ isLarge := ((valueHdr >> 4) & 0b1) == 1
+
+ var nelemSize uint8 = 1
+ if isLarge {
+ nelemSize = 4
+ }
+
+ debug.Assert(len(v.value) >= int(1+nelemSize), "invalid object value: too short")
+ numElements := readLEU32(v.value[1 : 1+nelemSize])
+ idStart := uint32(1 + nelemSize)
+ offsetStart := idStart + numElements*uint32(fieldIdSz)
+ dataStart := offsetStart + (numElements+1)*uint32(fieldOffsetSz)
+
+ debug.Assert(dataStart <= uint32(len(v.value)), "invalid object value: dataStart out of range")
+ return ObjectValue{
+ value: v.value,
+ meta: v.meta,
+ numElements: numElements,
+ offsetStart: offsetStart,
+ dataStart: dataStart,
+ idSize: fieldIdSz,
+ offsetSize: fieldOffsetSz,
+ idStart: uint8(idStart),
+ }
+ case BasicArray:
+ valueHdr := (v.value[0] >> basicTypeBits)
+ fieldOffsetSz := (valueHdr & 0b11) + 1
+ isLarge := ((valueHdr >> 2) & 0b1) == 1
+
+ var (
+ sz int
+ offsetStart, dataStart int
+ )
+
+ if isLarge {
+ sz, offsetStart = int(readLEU32(v.value[1:5])), 5
+ } else {
+ sz, offsetStart = int(v.value[1]), 2
+ }
+
+ dataStart = offsetStart + (sz+1)*int(fieldOffsetSz)
+ debug.Assert(dataStart <= len(v.value), "invalid array value: dataStart out of range")
+ return ArrayValue{
+ value: v.value,
+ meta: v.meta,
+ numElements: uint32(sz),
+ dataStart: uint32(dataStart),
+ offsetSize: fieldOffsetSz,
+ offsetStart: uint8(offsetStart),
+ }
+ }
+
+ debug.Assert(false, "unsupported type")
+ return nil
+}
+
+func (v Value) MarshalJSON() ([]byte, error) {
+ result := v.Value()
+ switch t := result.(type) {
+ case arrow.Date32:
+ result = t.FormattedString()
+ case arrow.Timestamp:
+ switch primType := primitiveTypeFromHeader(v.value[0]); primType {
+ case PrimitiveTimestampMicros:
+ result = t.ToTime(arrow.Microsecond).Format("2006-01-02 
15:04:05.999999Z0700") + case PrimitiveTimestampMicrosNTZ: + result = t.ToTime(arrow.Microsecond).In(time.Local).Format("2006-01-02 15:04:05.999999Z0700") + case PrimitiveTimestampNanos: + result = t.ToTime(arrow.Nanosecond).Format("2006-01-02 15:04:05.999999999Z0700") + case PrimitiveTimestampNanosNTZ: + result = t.ToTime(arrow.Nanosecond).In(time.Local).Format("2006-01-02 15:04:05.999999999Z0700") + } + case arrow.Time32: + result = t.ToTime(arrow.Microsecond).In(time.Local).Format("15:04:05.999999Z0700") + } + + return json.Marshal(result) +} diff --git a/arrow/extensions/variant/variant_test.go b/arrow/extensions/variant/variant_test.go new file mode 100644 index 00000000..a9fa5d22 --- /dev/null +++ b/arrow/extensions/variant/variant_test.go @@ -0,0 +1,512 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variant_test + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/extensions/variant" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func getVariantDir() string { + variantDir := os.Getenv("PARQUET_TEST_DATA") + if variantDir == "" { + return "" + } + + return filepath.Join(variantDir, "..", "variant") +} + +func metadataTestFilename(test string) string { + return test + ".metadata" +} + +func valueTestFilename(test string) string { + return test + ".value" +} + +func TestBasicRead(t *testing.T) { + dir := getVariantDir() + if dir == "" { + t.Skip("PARQUET_TEST_DATA not set") + } + + tests := []string{ + // FIXME: null metadata is corrupt, see + // https://github.com/apache/parquet-testing/issues/81 + // "primitive_null.metadata", + "primitive_boolean_true.metadata", + "primitive_boolean_false.metadata", + "primitive_int8.metadata", + "primitive_int16.metadata", + "primitive_int32.metadata", + "primitive_int64.metadata", + "primitive_float.metadata", + "primitive_double.metadata", + "primitive_string.metadata", + "primitive_binary.metadata", + "primitive_date.metadata", + "primitive_decimal4.metadata", + "primitive_decimal8.metadata", + "primitive_decimal16.metadata", + "primitive_timestamp.metadata", + "primitive_timestampntz.metadata", + } + + for _, test := range tests { + t.Run(test, func(t *testing.T) { + fname := filepath.Join(dir, test) + require.FileExists(t, fname, "file %s does not exist", fname) + + metadata, err := os.ReadFile(fname) + require.NoError(t, err) + + m, err := variant.NewMetadata(metadata) + require.NoError(t, err) + assert.EqualValues(t, 1, m.Version()) + _, err = m.KeyAt(0) + assert.Error(t, err) + }) + } + + 
t.Run("object_primitive.metadata", func(t *testing.T) { + fname := filepath.Join(dir, "object_primitive.metadata") + require.FileExists(t, fname, "file %s does not exist", fname) + + metadata, err := os.ReadFile(fname) + require.NoError(t, err) + + m, err := variant.NewMetadata(metadata) + require.NoError(t, err) + assert.EqualValues(t, 1, m.Version()) + + keys := []string{ + "int_field", "double_field", "boolean_true_field", + "boolean_false_field", "string_field", "null_field", + "timestamp_field", + } + + for i, k := range keys { + key, err := m.KeyAt(uint32(i)) + require.NoError(t, err) + assert.Equal(t, k, key) + } + }) +} + +func loadVariant(t *testing.T, test string) variant.Value { + dir := getVariantDir() + if dir == "" { + t.Skip("PARQUET_TEST_DATA not set") + } + + fname := filepath.Join(dir, test) + metadataPath := metadataTestFilename(fname) + valuePath := valueTestFilename(fname) + + metaBytes, err := os.ReadFile(metadataPath) + require.NoError(t, err) + valueBytes, err := os.ReadFile(valuePath) + require.NoError(t, err) + + v, err := variant.New(metaBytes, valueBytes) + require.NoError(t, err) + return v +} + +func TestPrimitiveVariants(t *testing.T) { + tests := []struct { + name string + expected any + variantType variant.Type + jsonStr string + }{ + {"primitive_boolean_true", true, variant.Bool, "true"}, + {"primitive_boolean_false", false, variant.Bool, "false"}, + {"primitive_int8", int8(42), variant.Int8, "42"}, + {"primitive_int16", int16(1234), variant.Int16, "1234"}, + {"primitive_int32", int32(123456), variant.Int32, "123456"}, + // FIXME: https://github.com/apache/parquet-testing/issues/82 + // primitive_int64 is an int32 value, but the metadata is int64 + {"primitive_int64", int32(12345678), variant.Int32, "12345678"}, + {"primitive_float", float32(1234567940.0), variant.Float, "1234568000"}, + {"primitive_double", float64(1234567890.1234), variant.Double, "1234567890.1234"}, + {"primitive_string", + `This string is longer than 64 bytes and therefore does not fit in a short_string and it also includes several non ascii characters such as 🐢, 💖, ♥️, 🎣 and 🤦!!`, + variant.String, `"This string is longer than 64 bytes and therefore does not fit in a short_string and it also includes several non ascii characters such as 🐢, 💖, ♥️, 🎣 and 🤦!!"`}, + {"short_string", `Less than 64 bytes (❤️ with utf8)`, variant.String, `"Less than 64 bytes (❤️ with utf8)"`}, + // 031337deadbeefcafe + {"primitive_binary", []byte{0x03, 0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}, variant.Binary, `"AxM33q2+78r+"`}, + {"primitive_decimal4", variant.DecimalValue[decimal.Decimal32]{ + Scale: 2, + Value: decimal.Decimal32(1234), + }, variant.Decimal4, `12.34`}, + {"primitive_decimal8", variant.DecimalValue[decimal.Decimal64]{ + Scale: 2, + Value: decimal.Decimal64(1234567890), + }, variant.Decimal8, `12345678.90`}, + {"primitive_decimal16", variant.DecimalValue[decimal.Decimal128]{ + Scale: 2, + Value: decimal128.FromU64(1234567891234567890), + }, variant.Decimal16, `12345678912345678.90`}, + // // 2025-04-16 + {"primitive_date", arrow.Date32(20194), variant.Date, `"2025-04-16"`}, + {"primitive_timestamp", arrow.Timestamp(1744821296780000), variant.TimestampMicros, `"2025-04-16 16:34:56.78Z"`}, + {"primitive_timestampntz", arrow.Timestamp(1744806896780000), variant.TimestampMicrosNTZ, `"` + time.UnixMicro(1744806896780000).UTC().In(time.Local).Format("2006-01-02 15:04:05.999999Z0700") + `"`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := loadVariant(t, 
tt.name) + assert.Equal(t, tt.expected, v.Value()) + assert.Equal(t, tt.variantType, v.Type()) + + out, err := json.Marshal(v) + require.NoError(t, err) + assert.Equal(t, tt.jsonStr, string(out)) + }) + } +} + +func primitiveHeader(p variant.PrimitiveType) uint8 { + return (uint8(p) << 2) +} + +func TestNullValue(t *testing.T) { + emptyMeta := variant.EmptyMetadataBytes + nullChars := []byte{primitiveHeader(variant.PrimitiveNull)} + + v, err := variant.New(emptyMeta[:], nullChars) + require.NoError(t, err) + + assert.Equal(t, variant.Null, v.Type()) + + out, err := json.Marshal(v) + require.NoError(t, err) + assert.Equal(t, "null", string(out)) +} + +func TestSimpleInt64(t *testing.T) { + metaBytes := variant.EmptyMetadataBytes[:] + + int64Bytes := []byte{primitiveHeader(variant.PrimitiveInt64), + 0xB1, 0x1C, 0x6C, 0xB1, 0xF4, 0x10, 0x22, 0x11} + + v, err := variant.New(metaBytes, int64Bytes) + require.NoError(t, err) + + assert.Equal(t, variant.Int64, v.Type()) + assert.Equal(t, int64(1234567890987654321), v.Value()) + + negInt64Bytes := []byte{primitiveHeader(variant.PrimitiveInt64), + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF} + + v, err = variant.New(metaBytes, negInt64Bytes) + require.NoError(t, err) + + assert.Equal(t, variant.Int64, v.Type()) + assert.Equal(t, int64(-1), v.Value()) +} + +func TestObjectValues(t *testing.T) { + v := loadVariant(t, "object_primitive") + assert.Equal(t, variant.Object, v.Type()) + + obj := v.Value().(variant.ObjectValue) + assert.EqualValues(t, 7, obj.NumElements()) + + tests := []struct { + field string + expected any + typ variant.Type + }{ + {"int_field", int8(1), variant.Int8}, + {"double_field", variant.DecimalValue[decimal.Decimal32]{ + Scale: 8, Value: decimal.Decimal32(123456789)}, variant.Decimal4}, + {"boolean_true_field", true, variant.Bool}, + {"boolean_false_field", false, variant.Bool}, + {"string_field", "Apache Parquet", variant.String}, + {"null_field", nil, variant.Null}, + {"timestamp_field", "2025-04-16T12:34:56.78", variant.String}, + } + + for _, tt := range tests { + t.Run(tt.field, func(t *testing.T) { + v, err := obj.ValueByKey(tt.field) + require.NoError(t, err) + + assert.Equal(t, tt.typ, v.Value.Type()) + assert.Equal(t, tt.expected, v.Value.Value()) + }) + } + + t.Run("json", func(t *testing.T) { + out, err := json.Marshal(v) + require.NoError(t, err) + + expected := `{ + "boolean_false_field":false, + "boolean_true_field":true, + "double_field":1.23456789, + "int_field":1, + "null_field":null, + "string_field":"Apache Parquet", + "timestamp_field":"2025-04-16T12:34:56.78"}` + + assert.JSONEq(t, expected, string(out)) + }) + + t.Run("invalid_key", func(t *testing.T) { + v, err := obj.ValueByKey("invalid_key") + require.ErrorIs(t, err, arrow.ErrNotFound) + assert.Zero(t, v) + }) + + t.Run("field by index", func(t *testing.T) { + fieldOrder := []string{ + "boolean_false_field", + "boolean_true_field", + "double_field", + "int_field", + "null_field", + "string_field", + "timestamp_field", + } + + for i := range obj.NumElements() { + val, err := obj.FieldAt(i) + require.NoError(t, err) + + assert.Equal(t, fieldOrder[i], val.Key) + } + }) +} + +func TestNestedObjectValues(t *testing.T) { + v := loadVariant(t, "object_nested") + assert.Equal(t, variant.Object, v.Type()) + obj := v.Value().(variant.ObjectValue) + assert.EqualValues(t, 3, obj.NumElements()) + + // trying to get the exists key + id, err := obj.ValueByKey("id") + require.NoError(t, err) + assert.Equal(t, variant.Int8, id.Value.Type()) + assert.Equal(t, int8(1), 
id.Value.Value()) + + observation, err := obj.ValueByKey("observation") + require.NoError(t, err) + assert.Equal(t, variant.Object, observation.Value.Type()) + + species, err := obj.ValueByKey("species") + require.NoError(t, err) + assert.Equal(t, variant.Object, species.Value.Type()) + + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, `{ + "id": 1, + "observation": { + "location": "In the Volcano", + "time": "12:34:56", + "value": { + "humidity": 456, + "temperature": 123 + } + }, + "species": { + "name": "lava monster", + "population": 6789 + } + }`, string(out)) + + t.Run("inner object", func(t *testing.T) { + speciesObj := species.Value.Value().(variant.ObjectValue) + assert.EqualValues(t, 2, speciesObj.NumElements()) + + name, err := speciesObj.ValueByKey("name") + require.NoError(t, err) + assert.Equal(t, variant.String, name.Value.Type()) + assert.Equal(t, "lava monster", name.Value.Value()) + + population, err := speciesObj.ValueByKey("population") + require.NoError(t, err) + assert.Equal(t, variant.Int16, population.Value.Type()) + assert.Equal(t, int16(6789), population.Value.Value()) + }) + + t.Run("inner key outside", func(t *testing.T) { + // only observation should successfully retrieve key + observationKeys := []string{"location", "time", "value"} + observationObj := observation.Value.Value().(variant.ObjectValue) + speciesObj := species.Value.Value().(variant.ObjectValue) + for _, k := range observationKeys { + inner, err := observationObj.ValueByKey(k) + require.NoError(t, err) + assert.Equal(t, k, inner.Key) + + _, err = obj.ValueByKey(k) + require.ErrorIs(t, err, arrow.ErrNotFound) + + _, err = speciesObj.ValueByKey(k) + require.ErrorIs(t, err, arrow.ErrNotFound) + } + }) +} + +func TestUUID(t *testing.T) { + emptyMeta := variant.EmptyMetadataBytes[:] + uuidBytes := []byte{primitiveHeader(variant.PrimitiveUUID), + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, + 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF} + + v, err := variant.New(emptyMeta, uuidBytes) + require.NoError(t, err) + assert.Equal(t, variant.UUID, v.Type()) + assert.Equal(t, uuid.MustParse("00112233-4455-6677-8899-aabbccddeeff"), v.Value()) +} + +func TestTimestampNanos(t *testing.T) { + emptyMeta := variant.EmptyMetadataBytes[:] + + t.Run("ts nanos tz negative", func(t *testing.T) { + data := []byte{primitiveHeader(variant.PrimitiveTimestampNanos), + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF} + v, err := variant.New(emptyMeta, data) + require.NoError(t, err) + assert.Equal(t, variant.TimestampNanos, v.Type()) + assert.Equal(t, arrow.Timestamp(-1), v.Value()) + }) + + t.Run("ts nanos tz positive", func(t *testing.T) { + data := []byte{primitiveHeader(variant.PrimitiveTimestampNanos), + 0x15, 0xC9, 0xBB, 0x86, 0xB4, 0x0C, 0x37, 0x18} + v, err := variant.New(emptyMeta, data) + require.NoError(t, err) + assert.Equal(t, variant.TimestampNanos, v.Type()) + assert.Equal(t, arrow.Timestamp(1744877350123456789), v.Value()) + }) + + t.Run("ts nanos ntz positive", func(t *testing.T) { + data := []byte{primitiveHeader(variant.PrimitiveTimestampNanosNTZ), + 0x15, 0xC9, 0xBB, 0x86, 0xB4, 0x0C, 0x37, 0x18} + v, err := variant.New(emptyMeta, data) + require.NoError(t, err) + assert.Equal(t, variant.TimestampNanosNTZ, v.Type()) + assert.Equal(t, arrow.Timestamp(1744877350123456789), v.Value()) + }) +} + +func TestArrayValues(t *testing.T) { + t.Run("array primitive", func(t *testing.T) { + v := loadVariant(t, "array_primitive") + assert.Equal(t, variant.Array, v.Type()) + + arr := 
v.Value().(variant.ArrayValue) + assert.EqualValues(t, 4, arr.NumElements()) + + elem0, err := arr.Value(0) + require.NoError(t, err) + assert.Equal(t, variant.Int8, elem0.Type()) + assert.Equal(t, int8(2), elem0.Value()) + + elem1, err := arr.Value(1) + require.NoError(t, err) + assert.Equal(t, variant.Int8, elem1.Type()) + assert.Equal(t, int8(1), elem1.Value()) + + elem2, err := arr.Value(2) + require.NoError(t, err) + assert.Equal(t, variant.Int8, elem2.Type()) + assert.Equal(t, int8(5), elem2.Value()) + + elem3, err := arr.Value(3) + require.NoError(t, err) + assert.Equal(t, variant.Int8, elem3.Type()) + assert.Equal(t, int8(9), elem3.Value()) + + _, err = arr.Value(4) + require.ErrorIs(t, err, arrow.ErrIndex) + + out, err := json.Marshal(v) + require.NoError(t, err) + expected := `[2,1,5,9]` + assert.JSONEq(t, expected, string(out)) + }) + + t.Run("empty array", func(t *testing.T) { + v := loadVariant(t, "array_empty") + assert.Equal(t, variant.Array, v.Type()) + + arr := v.Value().(variant.ArrayValue) + assert.EqualValues(t, 0, arr.NumElements()) + _, err := arr.Value(0) + require.ErrorIs(t, err, arrow.ErrIndex) + }) + + t.Run("array nested", func(t *testing.T) { + v := loadVariant(t, "array_nested") + assert.Equal(t, variant.Array, v.Type()) + + arr := v.Value().(variant.ArrayValue) + assert.EqualValues(t, 3, arr.NumElements()) + + elem0, err := arr.Value(0) + require.NoError(t, err) + assert.Equal(t, variant.Object, elem0.Type()) + elemObj0 := elem0.Value().(variant.ObjectValue) + assert.EqualValues(t, 2, elemObj0.NumElements()) + + id, err := elemObj0.ValueByKey("id") + require.NoError(t, err) + assert.Equal(t, variant.Int8, id.Value.Type()) + assert.Equal(t, int8(1), id.Value.Value()) + + elem1, err := arr.Value(1) + require.NoError(t, err) + assert.Equal(t, variant.Null, elem1.Type()) + + elem2, err := arr.Value(2) + require.NoError(t, err) + assert.Equal(t, variant.Object, elem2.Type()) + elemObj2 := elem2.Value().(variant.ObjectValue) + assert.EqualValues(t, 3, elemObj2.NumElements()) + id, err = elemObj2.ValueByKey("id") + require.NoError(t, err) + assert.Equal(t, variant.Int8, id.Value.Type()) + assert.Equal(t, int8(2), id.Value.Value()) + + out, err := json.Marshal(v) + require.NoError(t, err) + expected := `[ + {"id":1, "thing":{"names": ["Contrarian", "Spider"]}}, + null, + {"id":2, "names": ["Apple", "Ray", null], "type": "if"} + ]` + assert.JSONEq(t, expected, string(out)) + }) +} diff --git a/go.mod b/go.mod index c8995f2d..e4d7cf08 100644 --- a/go.mod +++ b/go.mod @@ -44,9 +44,9 @@ require ( github.com/tidwall/sjson v1.2.5 github.com/zeebo/xxh3 v1.0.2 golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 - golang.org/x/sync v0.13.0 - golang.org/x/sys v0.32.0 - golang.org/x/tools v0.32.0 + golang.org/x/sync v0.14.0 + golang.org/x/sys v0.33.0 + golang.org/x/tools v0.33.0 golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da gonum.org/v1/gonum v0.16.0 google.golang.org/grpc v1.72.0 @@ -89,9 +89,9 @@ require ( github.com/tidwall/pretty v1.2.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/mod v0.24.0 // indirect - golang.org/x/net v0.39.0 // indirect - golang.org/x/term v0.31.0 // indirect - golang.org/x/text v0.24.0 // indirect + golang.org/x/net v0.40.0 // indirect + golang.org/x/term v0.32.0 // indirect + golang.org/x/text v0.25.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // 
indirect @@ -101,3 +101,5 @@ require ( modernc.org/strutil v1.2.0 // indirect modernc.org/token v1.1.0 // indirect ) + +tool golang.org/x/tools/cmd/stringer diff --git a/go.sum b/go.sum index 337578c9..068d9c2d 100644 --- a/go.sum +++ b/go.sum @@ -198,8 +198,8 @@ go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= -golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -210,13 +210,13 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= -golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= +golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= -golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -229,28 +229,28 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= -golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys 
v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= -golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= +golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= +golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= -golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.32.0 h1:Q7N1vhpkQv7ybVzLFtTjvQya2ewbwNDZzUgfXGqtMWU= -golang.org/x/tools v0.32.0/go.mod h1:ZxrU41P/wAbZD8EDa6dDCa6XfpkhJ7HFMjHJXfBDu8s= +golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= +golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= diff --git a/parquet-testing b/parquet-testing index 39b91cf8..2dc8bf14 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit 39b91cf853062d92f0d20581d37b20dabe70a6a0 +Subproject commit 2dc8bf140ed6e28652fc347211c7d661714c7f95 diff --git a/parquet/variants/builder.go b/parquet/variants/builder.go index 032e069e..de07146e 100644 --- a/parquet/variants/builder.go +++ b/parquet/variants/builder.go @@ -95,10 +95,10 @@ func Marshal(val any, opts ...MarshalOpts) (*MarshaledVariant, error) { func (vb *VariantBuilder) check() error { if vb.built { - return errors.New("Variant has already been built") + return errors.New("variant has already been built") } if vb.typ != BasicUndefined { - return fmt.Errorf("Variant type has already been started as a %q", vb.typ) + return fmt.Errorf("variant type has already been started as a %q", vb.typ) } return nil } diff --git a/parquet/variants/primitive.go 
b/parquet/variants/primitive.go index 5d4eed01..49449f17 100644 --- a/parquet/variants/primitive.go +++ b/parquet/variants/primitive.go @@ -17,14 +17,18 @@ package variants import ( + "encoding/binary" "fmt" "io" "math" + "math/bits" "reflect" "strings" "time" "unsafe" + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/google/uuid" ) @@ -104,7 +108,7 @@ func (pt primitiveType) String() string { func validPrimitiveValue(prim primitiveType) error { if prim < primitiveNull || prim > primitiveUUID { - return fmt.Errorf("invalid primitive type: %d", prim) + return fmt.Errorf("%w: primitive type: %d", arrow.ErrInvalid, prim) } return nil } @@ -133,6 +137,35 @@ func primitiveHeader(prim primitiveType) (byte, error) { return hdr, nil } +func marshalDecimal[T decimal.Decimal32 | decimal.Decimal64 | decimal.Decimal128](scale int8, val T, w io.Writer) (int, error) { + hdr := [2]byte{0, byte(scale)} + switch v := any(val).(type) { + case decimal.Decimal32: + hdr[0], _ = primitiveHeader(primitiveDecimal4) + if _, err := w.Write(hdr[:]); err != nil { + return 0, err + } + return 6, binary.Write(w, binary.LittleEndian, int32(v)) + case decimal.Decimal64: + hdr[0], _ = primitiveHeader(primitiveDecimal8) + if _, err := w.Write(hdr[:]); err != nil { + return 0, err + } + return 10, binary.Write(w, binary.LittleEndian, int64(v)) + case decimal.Decimal128: + hdr[0], _ = primitiveHeader(primitiveDecimal16) + if _, err := w.Write(hdr[:]); err != nil { + return 0, err + } + if err := binary.Write(w, binary.LittleEndian, v.LowBits()); err != nil { + return 2, err + } + return 18, binary.Write(w, binary.LittleEndian, v.HighBits()) + default: + panic("should never get here") + } +} + // marshalPrimitive takes in a primitive value, asserts its type, then marshals the data according to the Variant spec // into the provided writer, returning the number of bytes written. 
// @@ -149,33 +182,50 @@ func marshalPrimitive(v any, w io.Writer, opts ...MarshalOpts) (int, error) { case bool: return marshalBoolean(val, w), nil case int: - return marshalInt(int64(val), w), nil + if bits.UintSize == 32 { + return marshalNumeric(int32(val), w) + } + return marshalNumeric(int64(val), w) case int8: - return marshalInt(int64(val), w), nil + return marshalNumeric(val, w) + case uint8: + return marshalNumeric(int16(val), w) case int16: - return marshalInt(int64(val), w), nil + return marshalNumeric(val, w) + case uint16: + return marshalNumeric(int32(val), w) case int32: - return marshalInt(int64(val), w), nil + return marshalNumeric(val, w) + case uint32: + return marshalNumeric(int64(val), w) case int64: - if allOpts&MarshalAsTime != 0 { - encodeTimestamp(val, allOpts&MarshalTimeNanos != 0, allOpts&MarshalTimeNTZ != 0, w) - } - return marshalInt(val, w), nil + return marshalNumeric(val, w) + case uint64: + return 0, fmt.Errorf("%w: cannot marshal uint64 values", arrow.ErrInvalid) case float32: - return marshalFloat(val, w), nil + return marshalNumeric(val, w) case float64: - return marshalDouble(val, w), nil + return marshalNumeric(val, w) + case arrow.Date32: + return marshalNumeric(val, w) + case arrow.Date64: + return marshalNumeric(arrow.Date32FromTime(val.ToTime()), w) + case arrow.Time64: + return marshalNumeric(val, w) case uuid.UUID: - return marshalUUID(val, w), nil + return marshalUUID(val, w) case string: - return marshalString(val, w), nil + return marshalString(val, w) case []byte: - return marshalBinary(val, w), nil + return marshalBinary(val, w) + case arrow.Timestamp: + return encodeTimestamp(int64(val), allOpts&MarshalTimeNanos != 0, false, w) + // TODO: add decimal.Decimal32/Decimal64/Decimal128 case time.Time: if allOpts&MarshalAsDate != 0 { - return marshalDate(val, w), nil + return marshalNumeric(arrow.Date32FromTime(val), w) } - return marshalTimestamp(val, allOpts&MarshalTimeNanos != 0, w), nil + return marshalTimestamp(val, allOpts&MarshalTimeNanos != 0, w) } if v == nil { return marshalNull(w), nil @@ -401,30 +451,79 @@ func unmarshalBoolean(raw []byte, offset int) (bool, error) { return prim == primitiveTrue, nil } -// Encodes an integer with the appropriate primitive header. This encodes the int -// into the minimal space necessary regardless of the width that's passed in (eg. 
an -// int64 of value 1 will be encoded into an int8) -func marshalInt(val int64, w io.Writer) int { +func marshalNumeric[T float32 | float64 | int8 | int16 | int32 | int64 | arrow.Date32 | arrow.Time64](val T, w io.Writer) (int, error) { var hdr byte - var size int - if val < math.MaxInt8 && val > math.MinInt8 { + switch any(val).(type) { + case int8: hdr, _ = primitiveHeader(primitiveInt8) - size = 1 - } else if val < math.MaxInt16 && val > math.MinInt16 { + case int16: hdr, _ = primitiveHeader(primitiveInt16) - size = 2 - } else if val < math.MaxInt32 && val > math.MinInt32 { + case int32: hdr, _ = primitiveHeader(primitiveInt32) - size = 4 - } else { + case int64: hdr, _ = primitiveHeader(primitiveInt64) - size = 8 + case float32: + hdr, _ = primitiveHeader(primitiveFloat) + case float64: + hdr, _ = primitiveHeader(primitiveDouble) + case arrow.Date32: + hdr, _ = primitiveHeader(primitiveDate) + case arrow.Time64: + hdr, _ = primitiveHeader(primitiveTimeNTZ) } - w.Write([]byte{hdr}) - encodeNumber(val, size, w) - return size + 1 + + if _, err := w.Write([]byte{hdr}); err != nil { + return 0, err + } + return binary.Size(val) + 1, binary.Write(w, binary.LittleEndian, val) +} + +func marshalBinary[T string | []byte](val T, w io.Writer) (int, error) { + var buf [5]byte + switch any(val).(type) { + case []byte: + buf[0], _ = primitiveHeader(primitiveBinary) + case string: + buf[0], _ = primitiveHeader(primitiveString) + } + + binary.Encode(buf[1:], binary.LittleEndian, int32(len(val))) + n, err := w.Write(buf[:]) + if err != nil { + return n, err + } + + if c, err := w.Write([]byte(val)); err != nil { + return n + c, err + } + + return n + len(val), nil } +// // Encodes an integer with the appropriate primitive header. This encodes the int +// // into the minimal space necessary regardless of the width that's passed in (eg. 
an +// // int64 of value 1 will be encoded into an int8) +// func marshalInt(val int64, w io.Writer) int { +// var hdr byte +// var size int +// if val < math.MaxInt8 && val > math.MinInt8 { +// hdr, _ = primitiveHeader(primitiveInt8) +// size = 1 +// } else if val < math.MaxInt16 && val > math.MinInt16 { +// hdr, _ = primitiveHeader(primitiveInt16) +// size = 2 +// } else if val < math.MaxInt32 && val > math.MinInt32 { +// hdr, _ = primitiveHeader(primitiveInt32) +// size = 4 +// } else { +// hdr, _ = primitiveHeader(primitiveInt64) +// size = 8 +// } +// w.Write([]byte{hdr}) +// encodeNumber(val, size, w) +// return size + 1 +// } + func decodeIntPhysical(raw []byte, offset int) (int64, error) { typ, _ := primitiveFromHeader(raw[offset]) var size int @@ -459,31 +558,26 @@ func decodeIntPhysical(raw []byte, offset int) (int64, error) { } } -func marshalFloat(val float32, w io.Writer) int { - buf := make([]byte, 5) - hdr, _ := primitiveHeader(primitiveFloat) - buf[0] = hdr - bits := math.Float32bits(val) - for i := range 4 { - buf[i+1] = byte(bits) - bits >>= 8 - } - w.Write(buf) - return 5 -} - -func marshalDouble(val float64, w io.Writer) int { - buf := make([]byte, 9) - hdr, _ := primitiveHeader(primitiveDouble) - buf[0] = hdr - bits := math.Float64bits(val) - for i := range 8 { - buf[i+1] = byte(bits) - bits >>= 8 - } - w.Write(buf) - return 9 -} +// func marshalFloat(val float32, w io.Writer) (int, error) { +// buf := make([]byte, 5) +// hdr, _ := primitiveHeader(primitiveFloat) +// buf[0] = hdr +// binary.Encode(buf[1:], binary.LittleEndian, val) +// return w.Write(buf) +// } + +// func marshalDouble(val float64, w io.Writer) int { +// buf := make([]byte, 9) +// hdr, _ := primitiveHeader(primitiveDouble) +// buf[0] = hdr +// bits := math.Float64bits(val) +// for i := range 8 { +// buf[i+1] = byte(bits) +// bits >>= 8 +// } +// w.Write(buf) +// return 9 +// } func unmarshalFloat(raw []byte, offset int) (float32, error) { v, err := readUint(raw, offset+1, 4) @@ -501,13 +595,13 @@ func unmarshalDouble(raw []byte, offset int) (float64, error) { return math.Float64frombits(v), nil } -func encodePrimitiveBytes(b []byte, w io.Writer) int { - encodeNumber(int64(len(b)), 4, w) - w.Write(b) - return len(b) + 4 -} +// func encodePrimitiveBytes(b []byte, w io.Writer) int { +// encodeNumber(int64(len(b)), 4, w) +// w.Write(b) +// return len(b) + 4 +// } -func marshalString(str string, w io.Writer) int { +func marshalString(str string, w io.Writer) (int, error) { str = strings.ToValidUTF8(str, "\uFFFD") // If the string is 63 characters or less, encode this as a short string to save space. @@ -515,23 +609,25 @@ func marshalString(str string, w io.Writer) int { if strlen < 0x3F { hdr := byte(strlen << 2) hdr |= byte(BasicShortString) - w.Write([]byte{hdr}) - w.Write([]byte(str)) - return 1 + strlen + if _, err := w.Write([]byte{hdr}); err != nil { + return 0, err + } + n, err := w.Write([]byte(str)) + return 1 + n, err } - // Otherwise, encode this as a basic string. 
- hdr, _ := primitiveHeader(primitiveString) - w.Write([]byte{hdr}) - return 1 + encodePrimitiveBytes([]byte(strings.ToValidUTF8(str, "\uFFFD")), w) + return marshalBinary(str, w) } -func marshalUUID(u uuid.UUID, w io.Writer) int { +func marshalUUID(u uuid.UUID, w io.Writer) (int, error) { hdr, _ := primitiveHeader(primitiveUUID) - w.Write([]byte{hdr}) + if _, err := w.Write([]byte{hdr}); err != nil { + return 0, err + } + m, _ := u.MarshalBinary() // MarshalBinary() can never return an error - w.Write(m) - return 17 + n, err := w.Write(m) + return 1 + n, err } func unmarshalUUID(raw []byte, offset int) (uuid.UUID, error) { @@ -579,42 +675,27 @@ func getBytes(raw []byte, offset int) ([]byte, error) { return raw[offset+4 : maxIdx], nil } -func marshalBinary(b []byte, w io.Writer) int { - hdr, _ := primitiveHeader(primitiveBinary) - w.Write([]byte{hdr}) - return 1 + encodePrimitiveBytes(b, w) -} +// func marshalBinary(b []byte, w io.Writer) int { +// hdr, _ := primitiveHeader(primitiveBinary) +// w.Write([]byte{hdr}) +// return 1 + encodePrimitiveBytes(b, w) +// } func unmarshalBinary(raw []byte, offset int) ([]byte, error) { return getBytes(raw, offset+1) } -func marshalTimestamp(t time.Time, nanos bool, w io.Writer) int { - var typ primitiveType +func marshalTimestamp(t time.Time, nanos bool, w io.Writer) (int, error) { var ts int64 - ntz := t.Location() == time.UTC if nanos { ts = t.UnixNano() - if ntz { - typ = primitiveTimestampNTZNanos - } else { - typ = primitiveTimestampNanos - } } else { ts = t.UnixMicro() - if ntz { - typ = primitiveTimestampNTZMicros - } else { - typ = primitiveTimestampMicros - } } - hdr, _ := primitiveHeader(typ) - w.Write([]byte{hdr}) - encodeNumber(ts, 8, w) - return 9 + return encodeTimestamp(ts, nanos, t.Location() == time.UTC, w) } -func encodeTimestamp(t int64, nanos, ntz bool, w io.Writer) int { +func encodeTimestamp(t int64, nanos, ntz bool, w io.Writer) (int, error) { var typ primitiveType if nanos { if ntz { @@ -630,9 +711,10 @@ func encodeTimestamp(t int64, nanos, ntz bool, w io.Writer) int { } } hdr, _ := primitiveHeader(typ) - w.Write([]byte{hdr}) - encodeNumber(t, 8, w) - return 9 + if _, err := w.Write([]byte{hdr}); err != nil { + return 0, err + } + return 9, binary.Write(w, binary.LittleEndian, t) } func unmarshalTimestamp(raw []byte, offset int) (time.Time, error) { @@ -655,15 +737,15 @@ func unmarshalTimestamp(raw []byte, offset int) (time.Time, error) { return ret, nil } -func marshalDate(t time.Time, w io.Writer) int { - epoch := time.Unix(0, 0) - since := t.Sub(epoch) - days := int64(since.Hours() / 24) - hdr, _ := primitiveHeader(primitiveDate) - w.Write([]byte{hdr}) - encodeNumber(days, 4, w) - return 5 -} +// func marshalDate(t time.Time, w io.Writer) int { +// epoch := time.Unix(0, 0) +// since := t.Sub(epoch) +// days := int64(since.Hours() / 24) +// hdr, _ := primitiveHeader(primitiveDate) +// w.Write([]byte{hdr}) +// encodeNumber(days, 4, w) +// return 5 +// } func unmarshalDate(raw []byte, offset int) (time.Time, error) { days, err := readUint(raw, offset+1, 4) diff --git a/parquet/variants/primitive_test.go b/parquet/variants/primitive_test.go index c13fa92c..ccd1db43 100644 --- a/parquet/variants/primitive_test.go +++ b/parquet/variants/primitive_test.go @@ -22,8 +22,11 @@ import ( "testing" "time" + "github.com/apache/arrow-go/v18/arrow" "github.com/google/go-cmp/cmp" "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func diffByteArrays(t *testing.T, got, want []byte) { @@ 
-127,7 +130,8 @@ func TestInt(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { var b bytes.Buffer - size := marshalInt(c.val, &b) + size, err := marshalNumeric(c.val, &b) + require.NoError(t, err) encoded := b.Bytes() checkSize(t, size, encoded) if gotHdr := encoded[0]; gotHdr != c.wantHdr { @@ -167,7 +171,8 @@ func TestUUID(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { var b bytes.Buffer - size := marshalUUID(c.uuid, &b) + size, err := marshalUUID(c.uuid, &b) + require.NoError(t, err) if size != 17 { t.Fatalf("Incorrect size. Got %d, want 17", size) } @@ -185,7 +190,8 @@ func TestUUID(t *testing.T) { func TestFloat(t *testing.T) { var b bytes.Buffer - size := marshalFloat(1.1, &b) + size, err := marshalNumeric(1.1, &b) + require.NoError(t, err) encodedFloat := b.Bytes() checkSize(t, size, encodedFloat) diffByteArrays(t, encodedFloat, []byte{ @@ -206,7 +212,8 @@ func TestFloat(t *testing.T) { func TestDouble(t *testing.T) { var b bytes.Buffer - size := marshalDouble(1.1, &b) + size, err := marshalNumeric(float64(1.1), &b) + require.NoError(t, err) encodedDouble := b.Bytes() checkSize(t, size, encodedDouble) diffByteArrays(t, encodedDouble, []byte{ @@ -470,7 +477,8 @@ func TestString(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { var b bytes.Buffer - size := marshalString(c.str, &b) + size, err := marshalString(c.str, &b) + require.NoError(t, err) checkSize(t, size, c.wantEncoded) gotEncoded := b.Bytes() @@ -504,7 +512,8 @@ func TestBinary(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { var b bytes.Buffer - size := marshalBinary(c.bin, &b) + size, err := marshalBinary(c.bin, &b) + require.NoError(t, err) checkSize(t, size, c.wantEncoded) diff(t, b.Bytes(), c.wantEncoded) }) @@ -549,7 +558,8 @@ func TestTimestamp(t *testing.T) { ref = ref.Local() } var b bytes.Buffer - size := marshalTimestamp(ref, c.nanos, &b) + size, err := marshalTimestamp(ref, c.nanos, &b) + require.NoError(t, err) wantEncoded := []byte{c.wantHdr} if c.nanos { wantEncoded = append(wantEncoded, []byte{ @@ -591,7 +601,8 @@ func TestTimestamp(t *testing.T) { func TestDate(t *testing.T) { day := time.Unix(0, 0).Add(10000 * 24 * time.Hour) var b bytes.Buffer - size := marshalDate(day, &b) + size, err := marshalNumeric(arrow.Date32FromTime(day), &b) + require.NoError(t, err) encodedDate := b.Bytes() checkSize(t, size, encodedDate) diffByteArrays(t, encodedDate, []byte{ @@ -602,10 +613,6 @@ func TestDate(t *testing.T) { 0x00, // 10000 = 0x0000 2710 }) got, err := unmarshalDate(encodedDate, 0) - if err != nil { - t.Fatalf("unmarshalDate(): %v", err) - } - if want := day; got != want { - t.Fatalf("Incorrect date: got %s, want %s", got, want) - } + require.NoError(t, err) + assert.Equal(t, got, day) } From c56d099196592bdcaad69d3177f5c2b890951cb9 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 23 May 2025 15:35:37 -0400 Subject: [PATCH 04/10] refactor and redesign. 
add docs --- arrow/extensions/variant/builder.go | 491 ---------- arrow/extensions/variant/builder_test.go | 297 ------ .../variant/basic_type_string.go | 0 parquet/variant/builder.go | 847 ++++++++++++++++++ parquet/variant/builder_test.go | 787 ++++++++++++++++ parquet/variant/doc.go | 142 +++ .../variant/primitive_type_string.go | 0 .../extensions => parquet}/variant/utils.go | 2 +- .../extensions => parquet}/variant/variant.go | 102 ++- .../variant/variant_test.go | 8 +- parquet/variants/array.go | 248 ----- parquet/variants/array_test.go | 423 --------- parquet/variants/builder.go | 200 ----- parquet/variants/builder_test.go | 307 ------- parquet/variants/decoder.go | 79 -- parquet/variants/decoder_test.go | 259 ------ parquet/variants/doc.go | 38 - parquet/variants/metadata.go | 152 ---- parquet/variants/metadata_test.go | 282 ------ parquet/variants/object.go | 456 ---------- parquet/variants/object_test.go | 711 --------------- parquet/variants/primitive.go | 756 ---------------- parquet/variants/primitive_test.go | 618 ------------- parquet/variants/testutils.go | 30 - parquet/variants/util.go | 161 ---- parquet/variants/util_test.go | 331 ------- parquet/variants/variant.go | 53 -- 27 files changed, 1870 insertions(+), 5910 deletions(-) delete mode 100644 arrow/extensions/variant/builder.go delete mode 100644 arrow/extensions/variant/builder_test.go rename {arrow/extensions => parquet}/variant/basic_type_string.go (100%) create mode 100644 parquet/variant/builder.go create mode 100644 parquet/variant/builder_test.go create mode 100644 parquet/variant/doc.go rename {arrow/extensions => parquet}/variant/primitive_type_string.go (100%) rename {arrow/extensions => parquet}/variant/utils.go (98%) rename {arrow/extensions => parquet}/variant/variant.go (80%) rename {arrow/extensions => parquet}/variant/variant_test.go (98%) delete mode 100644 parquet/variants/array.go delete mode 100644 parquet/variants/array_test.go delete mode 100644 parquet/variants/builder.go delete mode 100644 parquet/variants/builder_test.go delete mode 100644 parquet/variants/decoder.go delete mode 100644 parquet/variants/decoder_test.go delete mode 100644 parquet/variants/doc.go delete mode 100644 parquet/variants/metadata.go delete mode 100644 parquet/variants/metadata_test.go delete mode 100644 parquet/variants/object.go delete mode 100644 parquet/variants/object_test.go delete mode 100644 parquet/variants/primitive.go delete mode 100644 parquet/variants/primitive_test.go delete mode 100644 parquet/variants/testutils.go delete mode 100644 parquet/variants/util.go delete mode 100644 parquet/variants/util_test.go delete mode 100644 parquet/variants/variant.go diff --git a/arrow/extensions/variant/builder.go b/arrow/extensions/variant/builder.go deleted file mode 100644 index 2fee6bad..00000000 --- a/arrow/extensions/variant/builder.go +++ /dev/null @@ -1,491 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package variant - -import ( - "bytes" - "cmp" - "encoding/binary" - "errors" - "fmt" - "io" - "math" - "slices" - "unsafe" - - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/arrow/decimal" - "github.com/google/uuid" -) - -type Builder struct { - buf bytes.Buffer - dict map[string]uint32 - dictKeys [][]byte - allowDuplicates bool -} - -func (b *Builder) SetAllowDuplicates(allow bool) { - b.allowDuplicates = allow -} - -func (b *Builder) AddKeys(keys []string) (ids []uint32) { - if b.dict == nil { - b.dict = make(map[string]uint32) - b.dictKeys = make([][]byte, 0, len(keys)) - } - - ids = make([]uint32, len(keys)) - for i, key := range keys { - var ok bool - if ids[i], ok = b.dict[key]; ok { - continue - } - - ids[i] = uint32(len(b.dictKeys)) - b.dict[key] = ids[i] - b.dictKeys = append(b.dictKeys, unsafe.Slice(unsafe.StringData(key), len(key))) - } - - return ids -} - -func (b *Builder) AddKey(key string) (id uint32) { - if b.dict == nil { - b.dict = make(map[string]uint32) - b.dictKeys = make([][]byte, 0, 16) - } - - var ok bool - if id, ok = b.dict[key]; ok { - return id - } - - id = uint32(len(b.dictKeys)) - b.dict[key] = id - b.dictKeys = append(b.dictKeys, unsafe.Slice(unsafe.StringData(key), len(key))) - - return id -} - -func (b *Builder) AppendNull() error { - return b.buf.WriteByte(primitiveHeader(PrimitiveNull)) -} - -func (b *Builder) AppendBool(v bool) error { - var t PrimitiveType - if v { - t = PrimitiveBoolTrue - } else { - t = PrimitiveBoolFalse - } - - return b.buf.WriteByte(primitiveHeader(t)) -} - -type primitiveNumeric interface { - int8 | int16 | int32 | int64 | float32 | float64 | - arrow.Date32 | arrow.Time64 -} - -type buffer interface { - io.Writer - io.ByteWriter -} - -func writeBinary[T string | []byte](w buffer, v T) error { - var t PrimitiveType - switch any(v).(type) { - case string: - t = PrimitiveString - case []byte: - t = PrimitiveBinary - } - - if err := w.WriteByte(primitiveHeader(t)); err != nil { - return err - } - - if err := binary.Write(w, binary.LittleEndian, uint32(len(v))); err != nil { - return err - } - - _, err := w.Write([]byte(v)) - return err -} - -func writeNumeric[T primitiveNumeric](w buffer, v T) error { - var t PrimitiveType - switch any(v).(type) { - case int8: - t = PrimitiveInt8 - case int16: - t = PrimitiveInt16 - case int32: - t = PrimitiveInt32 - case int64: - t = PrimitiveInt64 - case float32: - t = PrimitiveFloat - case float64: - t = PrimitiveDouble - case arrow.Date32: - t = PrimitiveDate - case arrow.Time64: - t = PrimitiveTimeMicrosNTZ - } - - if err := w.WriteByte(primitiveHeader(t)); err != nil { - return err - } - - return binary.Write(w, binary.LittleEndian, v) -} - -func (b *Builder) AppendInt(v int64) error { - b.buf.Grow(9) - switch { - case v >= math.MinInt8 && v <= math.MaxInt8: - return writeNumeric(&b.buf, int8(v)) - case v >= math.MinInt16 && v <= math.MaxInt16: - return writeNumeric(&b.buf, int16(v)) - case v >= math.MinInt32 && v <= math.MaxInt32: - return writeNumeric(&b.buf, int32(v)) - default: - return writeNumeric(&b.buf, v) - } -} - -func (b *Builder) 
AppendFloat32(v float32) error { - b.buf.Grow(5) - return writeNumeric(&b.buf, v) -} - -func (b *Builder) AppendFloat64(v float64) error { - b.buf.Grow(9) - return writeNumeric(&b.buf, v) -} - -func (b *Builder) AppendDate(v arrow.Date32) error { - b.buf.Grow(5) - return writeNumeric(&b.buf, v) -} - -func (b *Builder) AppendTimeMicro(v arrow.Time64) error { - b.buf.Grow(9) - return writeNumeric(&b.buf, v) -} - -func (b *Builder) AppendTimestamp(v arrow.Timestamp, useMicros, useUTC bool) error { - b.buf.Grow(9) - var t PrimitiveType - if useMicros { - t = PrimitiveTimestampMicrosNTZ - } else { - t = PrimitiveTimestampNanosNTZ - } - - if useUTC { - t-- - } - - if err := b.buf.WriteByte(primitiveHeader(t)); err != nil { - return err - } - - return binary.Write(&b.buf, binary.LittleEndian, v) -} - -func (b *Builder) AppendBinary(v []byte) error { - b.buf.Grow(5 + len(v)) - return writeBinary(&b.buf, v) -} - -func (b *Builder) AppendString(v string) error { - if len(v) > maxShortStringSize { - b.buf.Grow(5 + len(v)) - return writeBinary(&b.buf, v) - } - - b.buf.Grow(1 + len(v)) - if err := b.buf.WriteByte(shortStrHeader(len(v))); err != nil { - return err - } - - _, err := b.buf.WriteString(v) - return err -} - -func (b *Builder) AppendUUID(v uuid.UUID) error { - b.buf.Grow(17) - if err := b.buf.WriteByte(primitiveHeader(PrimitiveUUID)); err != nil { - return err - } - - m, _ := v.MarshalBinary() - _, err := b.buf.Write(m) - return err -} - -func (b *Builder) AppendDecimal4(scale uint8, v decimal.Decimal32) error { - b.buf.Grow(6) - if err := b.buf.WriteByte(primitiveHeader(PrimitiveDecimal4)); err != nil { - return err - } - - if err := b.buf.WriteByte(scale); err != nil { - return err - } - - return binary.Write(&b.buf, binary.LittleEndian, int32(v)) -} - -func (b *Builder) AppendDecimal8(scale uint8, v decimal.Decimal64) error { - b.buf.Grow(10) - return errors.Join( - b.buf.WriteByte(primitiveHeader(PrimitiveDecimal8)), - b.buf.WriteByte(scale), - binary.Write(&b.buf, binary.LittleEndian, int64(v)), - ) -} - -func (b *Builder) AppendDecimal16(scale uint8, v decimal.Decimal128) error { - b.buf.Grow(18) - return errors.Join( - b.buf.WriteByte(primitiveHeader(PrimitiveDecimal16)), - b.buf.WriteByte(scale), - binary.Write(&b.buf, binary.LittleEndian, v.LowBits()), - binary.Write(&b.buf, binary.LittleEndian, v.HighBits()), - ) -} - -func (b *Builder) Offset() int { - return b.buf.Len() -} - -func (b *Builder) FinishArray(start int, offsets []int) error { - var ( - dataSize, sz = b.buf.Len() - start, len(offsets) - isLarge = sz > math.MaxUint8 - sizeBytes = 1 - ) - - if isLarge { - sizeBytes = 4 - } - - if dataSize < 0 { - return errors.New("invalid array size") - } - - offsetSize := intSize(dataSize) - headerSize := 1 + sizeBytes + (sz+1)*int(offsetSize) - - // shift the just written data to make room for the header section - b.buf.Grow(headerSize) - av := b.buf.AvailableBuffer() - if _, err := b.buf.Write(av[:headerSize]); err != nil { - return err - } - - bs := b.buf.Bytes() - copy(bs[start+headerSize:], bs[start:start+dataSize]) - - // populate the header - bs[start] = arrayHeader(isLarge, offsetSize) - writeOffset(bs[start+1:], sz, uint8(sizeBytes)) - - offsetsStart := start + 1 + sizeBytes - for i, off := range offsets { - writeOffset(bs[offsetsStart+i*int(offsetSize):], off, offsetSize) - } - writeOffset(bs[offsetsStart+sz*int(offsetSize):], dataSize, offsetSize) - - return nil -} - -type FieldEntry struct { - Key string - ID uint32 - Offset int -} - -func (b *Builder) NextField(start int, 
key string) FieldEntry { - id := b.AddKey(key) - return FieldEntry{ - Key: key, - ID: id, - Offset: b.Offset() - start, - } -} - -func (b *Builder) FinishObject(start int, fields []FieldEntry) error { - slices.SortFunc(fields, func(a, b FieldEntry) int { - return cmp.Compare(a.Key, b.Key) - }) - - sz := len(fields) - var maxID uint32 - if sz > 0 { - maxID = fields[0].ID - } - - // if a duplicate key is found, one of two things happens: - // - if allowDuplicates is true, then the field with the greatest - // offset value (the last appended field) is kept. - // - if allowDuplicates is false, then an error is returned - if b.allowDuplicates { - distinctPos := 0 - // maintain a list of distinct keys in-place - for i := 1; i < sz; i++ { - maxID = max(maxID, fields[i].ID) - if fields[i].ID == fields[i-1].ID { - // found a duplicate key. keep the - // field with a greater offset - if fields[distinctPos].Offset < fields[i].Offset { - fields[distinctPos].Offset = fields[i].Offset - } - } else { - // found distinct key, add field to the list - distinctPos++ - fields[distinctPos] = fields[i] - } - } - - if distinctPos+1 < len(fields) { - sz = distinctPos + 1 - // resize fields to size - fields = fields[:sz] - // sort the fields by offsets so that we can move the value - // data of each field to the new offset without overwriting the - // fields after it. - slices.SortFunc(fields, func(a, b FieldEntry) int { - return cmp.Compare(a.Offset, b.Offset) - }) - - buf := b.buf.Bytes() - curOffset := 0 - for i := range sz { - oldOffset := fields[i].Offset - fieldSize := valueSize(buf[start+oldOffset:]) - copy(buf[start+curOffset:], buf[start+oldOffset:start+oldOffset+fieldSize]) - fields[i].Offset = curOffset - curOffset += fieldSize - } - b.buf.Truncate(start + curOffset) - // change back to sort order by field keys to meet variant spec - slices.SortFunc(fields, func(a, b FieldEntry) int { - return cmp.Compare(a.Key, b.Key) - }) - } - } else { - for i := 1; i < sz; i++ { - maxID = max(maxID, fields[i].ID) - if fields[i].Key == fields[i-1].Key { - return fmt.Errorf("disallowed duplicate key found: %s", fields[i].Key) - } - } - } - - var ( - dataSize = b.buf.Len() - start - isLarge = sz > math.MaxUint8 - sizeBytes = 1 - idSize, offsetSize = intSize(int(maxID)), intSize(dataSize) - ) - - if isLarge { - sizeBytes = 4 - } - - if dataSize < 0 { - return errors.New("invalid object size") - } - - headerSize := 1 + sizeBytes + sz*int(idSize) + (sz+1)*int(offsetSize) - // shift the just written data to make room for the header section - b.buf.Grow(headerSize) - av := b.buf.AvailableBuffer() - if _, err := b.buf.Write(av[:headerSize]); err != nil { - return err - } - - bs := b.buf.Bytes() - copy(bs[start+headerSize:], bs[start:start+dataSize]) - - // populate the header - bs[start] = objectHeader(isLarge, idSize, offsetSize) - writeOffset(bs[start+1:], sz, uint8(sizeBytes)) - - idStart := start + 1 + sizeBytes - offsetStart := idStart + sz*int(idSize) - for i, field := range fields { - writeOffset(bs[idStart+i*int(idSize):], int(field.ID), idSize) - writeOffset(bs[offsetStart+i*int(offsetSize):], field.Offset, offsetSize) - } - writeOffset(bs[offsetStart+sz*int(offsetSize):], dataSize, offsetSize) - return nil -} - -func (b *Builder) Build() (Value, error) { - nkeys := len(b.dictKeys) - totalDictSize := 0 - for _, k := range b.dictKeys { - totalDictSize += len(k) - } - - // determine the number of bytes required per offset entry. - // the largest offset is the one-past-the-end value, the total size. 
- // It's very unlikely that the number of keys could be larger, but - // incorporate that into the calculation in case of pathological data. - maxSize := max(totalDictSize, nkeys) - if maxSize > maxSizeLimit { - return Value{}, fmt.Errorf("metadata size too large: %d", maxSize) - } - - offsetSize := intSize(int(maxSize)) - offsetStart := 1 + offsetSize - stringStart := int(offsetStart) + (nkeys+1)*int(offsetSize) - metadataSize := stringStart + totalDictSize - - if metadataSize > maxSizeLimit { - return Value{}, fmt.Errorf("metadata size too large: %d", metadataSize) - } - - meta := make([]byte, metadataSize) - - meta[0] = supportedVersion | ((offsetSize - 1) << 6) - if nkeys > 0 && slices.IsSortedFunc(b.dictKeys, bytes.Compare) { - meta[0] |= 1 << 4 - } - writeOffset(meta[1:], nkeys, offsetSize) - - curOffset := 0 - for i, k := range b.dictKeys { - writeOffset(meta[int(offsetStart)+i*int(offsetSize):], curOffset, offsetSize) - curOffset += copy(meta[stringStart+curOffset:], k) - } - writeOffset(meta[int(offsetStart)+nkeys*int(offsetSize):], curOffset, offsetSize) - - return Value{ - value: b.buf.Bytes(), - meta: Metadata{ - data: meta, - keys: b.dictKeys, - }, - }, nil -} diff --git a/arrow/extensions/variant/builder_test.go b/arrow/extensions/variant/builder_test.go deleted file mode 100644 index ae49bcf7..00000000 --- a/arrow/extensions/variant/builder_test.go +++ /dev/null @@ -1,297 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package variant_test - -import ( - "encoding/json" - "testing" - - "github.com/apache/arrow-go/v18/arrow/decimal128" - "github.com/apache/arrow-go/v18/arrow/extensions/variant" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestBuildNullValue(t *testing.T) { - var b variant.Builder - b.AppendNull() - - v, err := b.Build() - require.NoError(t, err) - - assert.Equal(t, variant.Null, v.Type()) - assert.EqualValues(t, 1, v.Metadata().Version()) - assert.Zero(t, v.Metadata().DictionarySize()) -} - -func TestBuildPrimitive(t *testing.T) { - tests := []struct { - name string - op func(*variant.Builder) error - }{ - {"primitive_boolean_true", func(b *variant.Builder) error { - return b.AppendBool(true) - }}, - {"primitive_boolean_false", func(b *variant.Builder) error { - return b.AppendBool(false) - }}, - // AppendInt will use the smallest possible int type - {"primitive_int8", func(b *variant.Builder) error { return b.AppendInt(42) }}, - {"primitive_int16", func(b *variant.Builder) error { return b.AppendInt(1234) }}, - {"primitive_int32", func(b *variant.Builder) error { return b.AppendInt(123456) }}, - // FIXME: https://github.com/apache/parquet-testing/issues/82 - // primitive_int64 is an int32 value, but the metadata is int64 - {"primitive_int64", func(b *variant.Builder) error { return b.AppendInt(12345678) }}, - {"primitive_float", func(b *variant.Builder) error { return b.AppendFloat32(1234568000) }}, - {"primitive_double", func(b *variant.Builder) error { return b.AppendFloat64(1234567890.1234) }}, - {"primitive_string", func(b *variant.Builder) error { - return b.AppendString(`This string is longer than 64 bytes and therefore does not fit in a short_string and it also includes several non ascii characters such as 🐢, 💖, ♥️, 🎣 and 🤦!!`) - }}, - {"short_string", func(b *variant.Builder) error { return b.AppendString(`Less than 64 bytes (❤️ with utf8)`) }}, - // 031337deadbeefcafe - {"primitive_binary", func(b *variant.Builder) error { - return b.AppendBinary([]byte{0x03, 0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}) - }}, - {"primitive_decimal4", func(b *variant.Builder) error { return b.AppendDecimal4(2, 1234) }}, - {"primitive_decimal8", func(b *variant.Builder) error { return b.AppendDecimal8(2, 1234567890) }}, - {"primitive_decimal16", func(b *variant.Builder) error { return b.AppendDecimal16(2, decimal128.FromU64(1234567891234567890)) }}, - {"primitive_date", func(b *variant.Builder) error { return b.AppendDate(20194) }}, - {"primitive_timestamp", func(b *variant.Builder) error { return b.AppendTimestamp(1744821296780000, true, true) }}, - {"primitive_timestampntz", func(b *variant.Builder) error { return b.AppendTimestamp(1744806896780000, true, false) }}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - expected := loadVariant(t, tt.name) - - var b variant.Builder - require.NoError(t, tt.op(&b)) - - v, err := b.Build() - require.NoError(t, err) - - assert.Equal(t, expected.Type(), v.Type()) - assert.Equal(t, expected.Bytes(), v.Bytes()) - assert.Equal(t, expected.Metadata().Bytes(), v.Metadata().Bytes()) - }) - } -} - -func TestBuildInt64(t *testing.T) { - var b variant.Builder - require.NoError(t, b.AppendInt(1234567890987654321)) - - v, err := b.Build() - require.NoError(t, err) - assert.Equal(t, variant.Int64, v.Type()) - assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveInt64), - 0xB1, 0x1C, 0x6C, 0xB1, 0xF4, 0x10, 0x22, 0x11}, v.Bytes()) -} - -func TestBuildObjec(t 
*testing.T) { - var b variant.Builder - start := b.Offset() - - fields := make([]variant.FieldEntry, 0, 7) - - fields = append(fields, b.NextField(start, "int_field")) - require.NoError(t, b.AppendInt(1)) - - fields = append(fields, b.NextField(start, "double_field")) - require.NoError(t, b.AppendDecimal4(8, 123456789)) - - fields = append(fields, b.NextField(start, "boolean_true_field")) - require.NoError(t, b.AppendBool(true)) - - fields = append(fields, b.NextField(start, "boolean_false_field")) - require.NoError(t, b.AppendBool(false)) - - fields = append(fields, b.NextField(start, "string_field")) - require.NoError(t, b.AppendString("Apache Parquet")) - - fields = append(fields, b.NextField(start, "null_field")) - require.NoError(t, b.AppendNull()) - - fields = append(fields, b.NextField(start, "timestamp_field")) - require.NoError(t, b.AppendString("2025-04-16T12:34:56.78")) - - require.NoError(t, b.FinishObject(start, fields)) - v, err := b.Build() - require.NoError(t, err) - - assert.Equal(t, variant.Object, v.Type()) - expected := loadVariant(t, "object_primitive") - - assert.Equal(t, expected.Metadata().DictionarySize(), v.Metadata().DictionarySize()) - assert.Equal(t, expected.Metadata().Bytes(), v.Metadata().Bytes()) - assert.Equal(t, expected.Bytes(), v.Bytes()) -} - -func TestBuildObjectNested(t *testing.T) { - var b variant.Builder - - start := b.Offset() - topFields := make([]variant.FieldEntry, 0, 3) - - topFields = append(topFields, b.NextField(start, "id")) - require.NoError(t, b.AppendInt(1)) - - topFields = append(topFields, b.NextField(start, "observation")) - - observeFields := make([]variant.FieldEntry, 0, 3) - observeStart := b.Offset() - observeFields = append(observeFields, b.NextField(observeStart, "location")) - require.NoError(t, b.AppendString("In the Volcano")) - observeFields = append(observeFields, b.NextField(observeStart, "time")) - require.NoError(t, b.AppendString("12:34:56")) - observeFields = append(observeFields, b.NextField(observeStart, "value")) - - valueStart := b.Offset() - valueFields := make([]variant.FieldEntry, 0, 2) - valueFields = append(valueFields, b.NextField(valueStart, "humidity")) - require.NoError(t, b.AppendInt(456)) - valueFields = append(valueFields, b.NextField(valueStart, "temperature")) - require.NoError(t, b.AppendInt(123)) - - require.NoError(t, b.FinishObject(valueStart, valueFields)) - require.NoError(t, b.FinishObject(observeStart, observeFields)) - - topFields = append(topFields, b.NextField(start, "species")) - speciesStart := b.Offset() - speciesFields := make([]variant.FieldEntry, 0, 2) - speciesFields = append(speciesFields, b.NextField(speciesStart, "name")) - require.NoError(t, b.AppendString("lava monster")) - - speciesFields = append(speciesFields, b.NextField(speciesStart, "population")) - require.NoError(t, b.AppendInt(6789)) - - require.NoError(t, b.FinishObject(speciesStart, speciesFields)) - require.NoError(t, b.FinishObject(start, topFields)) - - v, err := b.Build() - require.NoError(t, err) - - out, err := json.Marshal(v) - require.NoError(t, err) - - assert.JSONEq(t, `{ - "id": 1, - "observation": { - "location": "In the Volcano", - "time": "12:34:56", - "value": { - "humidity": 456, - "temperature": 123 - } - }, - "species": { - "name": "lava monster", - "population": 6789 - } - }`, string(out)) -} - -func TestBuildUUID(t *testing.T) { - u := uuid.MustParse("00112233-4455-6677-8899-aabbccddeeff") - - var b variant.Builder - require.NoError(t, b.AppendUUID(u)) - v, err := b.Build() - require.NoError(t, 
err) - assert.Equal(t, variant.UUID, v.Type()) - assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveUUID), - 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, - 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, v.Bytes()) -} - -func TestBuildTimestampNanos(t *testing.T) { - t.Run("ts nanos tz negative", func(t *testing.T) { - var b variant.Builder - require.NoError(t, b.AppendTimestamp(-1, false, true)) - - v, err := b.Build() - require.NoError(t, err) - assert.Equal(t, variant.TimestampNanos, v.Type()) - assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveTimestampNanos), - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, v.Bytes()) - }) - - t.Run("ts nanos tz positive", func(t *testing.T) { - var b variant.Builder - require.NoError(t, b.AppendTimestamp(1744877350123456789, false, true)) - v, err := b.Build() - require.NoError(t, err) - assert.Equal(t, variant.TimestampNanos, v.Type()) - assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveTimestampNanos), - 0x15, 0xC9, 0xBB, 0x86, 0xB4, 0x0C, 0x37, 0x18}, v.Bytes()) - }) - - t.Run("ts nanos ntz positive", func(t *testing.T) { - var b variant.Builder - require.NoError(t, b.AppendTimestamp(1744877350123456789, false, false)) - v, err := b.Build() - require.NoError(t, err) - assert.Equal(t, variant.TimestampNanosNTZ, v.Type()) - assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveTimestampNanosNTZ), - 0x15, 0xC9, 0xBB, 0x86, 0xB4, 0x0C, 0x37, 0x18}, v.Bytes()) - }) -} - -func TestBuildArrayValues(t *testing.T) { - t.Run("array primitive", func(t *testing.T) { - var b variant.Builder - - start := b.Offset() - offsets := make([]int, 0, 4) - - offsets = append(offsets, b.Offset()-start) - require.NoError(t, b.AppendInt(2)) - - offsets = append(offsets, b.Offset()-start) - require.NoError(t, b.AppendInt(1)) - - offsets = append(offsets, b.Offset()-start) - require.NoError(t, b.AppendInt(5)) - - offsets = append(offsets, b.Offset()-start) - require.NoError(t, b.AppendInt(9)) - - require.NoError(t, b.FinishArray(start, offsets)) - - v, err := b.Build() - require.NoError(t, err) - assert.Equal(t, variant.Array, v.Type()) - - expected := loadVariant(t, "array_primitive") - assert.Equal(t, expected.Metadata().Bytes(), v.Metadata().Bytes()) - assert.Equal(t, expected.Bytes(), v.Bytes()) - }) - - t.Run("array empty", func(t *testing.T) { - var b variant.Builder - - require.NoError(t, b.FinishArray(b.Offset(), nil)) - v, err := b.Build() - require.NoError(t, err) - assert.Equal(t, variant.Array, v.Type()) - - expected := loadVariant(t, "array_empty") - assert.Equal(t, expected.Metadata().Bytes(), v.Metadata().Bytes()) - assert.Equal(t, expected.Bytes(), v.Bytes()) - }) -} diff --git a/arrow/extensions/variant/basic_type_string.go b/parquet/variant/basic_type_string.go similarity index 100% rename from arrow/extensions/variant/basic_type_string.go rename to parquet/variant/basic_type_string.go diff --git a/parquet/variant/builder.go b/parquet/variant/builder.go new file mode 100644 index 00000000..70f1d9a1 --- /dev/null +++ b/parquet/variant/builder.go @@ -0,0 +1,847 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package variant + +import ( + "bytes" + "cmp" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "reflect" + "slices" + "strings" + "time" + "unsafe" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/google/uuid" +) + +// Builder is used to construct Variant values by appending data of various types. +// It manages an internal buffer for the value data and a dictionary for field keys. +type Builder struct { + buf bytes.Buffer + dict map[string]uint32 + dictKeys [][]byte + allowDuplicates bool +} + +// SetAllowDuplicates controls whether duplicate keys are allowed in objects. +// When true, the last value for a key is used. When false, an error is returned +// if a duplicate key is detected. +func (b *Builder) SetAllowDuplicates(allow bool) { + b.allowDuplicates = allow +} + +// AddKey adds a key to the builder's dictionary and returns its ID. +// If the key already exists in the dictionary, its existing ID is returned. +func (b *Builder) AddKey(key string) (id uint32) { + if b.dict == nil { + b.dict = make(map[string]uint32) + b.dictKeys = make([][]byte, 0, 16) + } + + var ok bool + if id, ok = b.dict[key]; ok { + return id + } + + id = uint32(len(b.dictKeys)) + b.dict[key] = id + b.dictKeys = append(b.dictKeys, unsafe.Slice(unsafe.StringData(key), len(key))) + + return id +} + +// AppendOpt represents options for appending time-related values. These are only +// used when using the generic Append method that takes an interface{}. +type AppendOpt int16 + +const ( + // OptTimestampNano specifies that timestamps should use nanosecond precision, + // otherwise microsecond precision is used. + OptTimestampNano AppendOpt = 1 << iota + // OptTimestampUTC specifies that timestamps should be in UTC timezone, otherwise + // no time zone (NTZ) is used. + OptTimestampUTC + // OptTimeAsDate specifies that time.Time values should be encoded as dates + OptTimeAsDate + // OptTimeAsTime specifies that time.Time values should be encoded as a time value + OptTimeAsTime +) + +func extractFieldInfo(f reflect.StructField) (name string, o AppendOpt) { + tag := f.Tag.Get("variant") + if tag == "" { + return f.Name, 0 + } + + parts := strings.Split(tag, ",") + if len(parts) == 1 { + return parts[0], 0 + } + + name = parts[0] + if name == "" { + name = f.Name + } + + for _, opt := range parts[1:] { + switch strings.ToLower(opt) { + case "nanos": + o |= OptTimestampNano + case "utc": + o |= OptTimestampUTC + case "date": + o |= OptTimeAsDate + case "time": + o |= OptTimeAsTime + } + } + + return name, o +} + +// Append adds a value of any supported type to the builder. +// +// Any basic primitive type is supported, the AppendOpt options are used to control how +// timestamps are appended (e.g., as microseconds or nanoseconds and timezone). The options +// also control how a [time.Time] value is appended (e.g., as a date, timestamp, or time). +// +// Appending a value with type `[]any` will construct an array appropriately, appending +// each element. 
Calling with a map[string]any will construct an object, recursively calling +// Append for each value, propagating the options. +// +// For other types (arbitrary slices, arrays, maps and structs), reflection is used to determine +// the type and whether we can append it. A nil pointer will append a null, while a non-nil +// pointer will append the value that it points to. +// +// For structs, field tags can be used to control the field names and options. Only exported +// fields are considered, with the field name being used as the key. A struct tag of `variant` +// can be used with the following format and options: +// +// type MyStruct struct { +// Field1 string `variant:"key"` // Use "key" instead of "Field1" as the field name +// Field2 time.Time `variant:"day,date"` // Use "day" instead of "Field2" as the field name +// // append this value as a "date" value +// Time time.Time `variant:",time"` // Use "Time" as the field name, append the value as +// // a "time" value +// Field3 int `variant:"-"` // Ignore this field +// Timestamp time.Time `variant:"ts"` // Use "ts" as the field name, append value as a +// // timestamp(UTC=false,MICROS) +// Ts2 time.Time `variant:"ts2,nanos,utc"` // Use "ts2" as the field name, append value as a +// // timestamp(UTC=true,NANOS) +// } +// +// Options specified in the struct tags will be OR'd with any options passed to the original call +// to Append. +func (b *Builder) Append(v any, opts ...AppendOpt) error { + var o AppendOpt + for _, opt := range opts { + o |= opt + } + + return b.append(v, o) +} + +func (b *Builder) append(v any, o AppendOpt) error { + switch v := v.(type) { + case nil: + return b.AppendNull() + case bool: + return b.AppendBool(v) + case int8: + return b.AppendInt(int64(v)) + case uint8: + return b.AppendInt(int64(v)) + case int16: + return b.AppendInt(int64(v)) + case uint16: + return b.AppendInt(int64(v)) + case int32: + return b.AppendInt(int64(v)) + case uint32: + return b.AppendInt(int64(v)) + case int64: + return b.AppendInt(v) + case int: + return b.AppendInt(int64(v)) + case uint: + return b.AppendInt(int64(v)) + case float32: + return b.AppendFloat32(v) + case float64: + return b.AppendFloat64(v) + case arrow.Date32: + return b.AppendDate(v) + case arrow.Time64: + return b.AppendTimeMicro(v) + case arrow.Timestamp: + return b.AppendTimestamp(v, o&OptTimestampNano == 0, o&OptTimestampUTC != 0) + case []byte: + return b.AppendBinary(v) + case string: + return b.AppendString(v) + case uuid.UUID: + return b.AppendUUID(v) + case time.Time: + switch { + case o&OptTimeAsDate != 0: + return b.AppendDate(arrow.Date32FromTime(v)) + case o&OptTimeAsTime != 0: + t := v.Sub(v.Truncate(24 * time.Hour)) + return b.AppendTimeMicro(arrow.Time64(t.Microseconds())) + default: + unit := arrow.Microsecond + if o&OptTimestampNano != 0 { + unit = arrow.Nanosecond + } + + if o&OptTimestampUTC != 0 { + v = v.UTC() + } + + t, err := arrow.TimestampFromTime(v, unit) + if err != nil { + return err + } + + return b.AppendTimestamp(t, o&OptTimestampNano == 0, o&OptTimestampUTC != 0) + } + case DecimalValue[decimal.Decimal32]: + return b.AppendDecimal4(v.Scale, v.Value.(decimal.Decimal32)) + case DecimalValue[decimal.Decimal64]: + return b.AppendDecimal8(v.Scale, v.Value.(decimal.Decimal64)) + case DecimalValue[decimal.Decimal128]: + return b.AppendDecimal16(v.Scale, v.Value.(decimal.Decimal128)) + case []any: + start, offsets := b.Offset(), make([]int, 0, len(v)) + for _, item := range v { + offsets = append(offsets, b.NextElement(start)) + if err := 
b.append(item, o); err != nil { + return err + } + } + return b.FinishArray(start, offsets) + case map[string]any: + start, fields := b.Offset(), make([]FieldEntry, 0, len(v)) + for key, item := range v { + fields = append(fields, b.NextField(start, key)) + if err := b.append(item, o); err != nil { + return err + } + } + return b.FinishObject(start, fields) + default: + // attempt to use reflection before we give up! + val := reflect.ValueOf(v) + switch val.Kind() { + case reflect.Pointer, reflect.Interface: + if val.IsNil() { + return b.AppendNull() + } + return b.append(val.Elem().Interface(), o) + case reflect.Array, reflect.Slice: + start, offsets := b.Offset(), make([]int, 0, val.Len()) + for _, item := range val.Seq2() { + offsets = append(offsets, b.NextElement(start)) + if err := b.append(item.Interface(), o); err != nil { + return err + } + } + return b.FinishArray(start, offsets) + case reflect.Map: + if val.Type().Key().Kind() != reflect.String { + return fmt.Errorf("unsupported map key type: %s", val.Type().Key()) + } + + start, fields := b.Offset(), make([]FieldEntry, 0, val.Len()) + for k, v := range val.Seq2() { + fields = append(fields, b.NextField(start, k.String())) + if err := b.append(v.Interface(), o); err != nil { + return err + } + } + return b.FinishObject(start, fields) + case reflect.Struct: + start, fields := b.Offset(), make([]FieldEntry, 0, val.NumField()) + + typ := val.Type() + for i := range typ.NumField() { + f := typ.Field(i) + if !f.IsExported() { + continue + } + + name, opt := extractFieldInfo(f) + if name == "-" { + continue + } + + fields = append(fields, b.NextField(start, name)) + if err := b.append(val.Field(i).Interface(), o|opt); err != nil { + return err + } + } + return b.FinishObject(start, fields) + } + } + return fmt.Errorf("cannot append unsupported type to variant: %T", v) +} + +// AppendNull appends a null value to the builder. +func (b *Builder) AppendNull() error { + return b.buf.WriteByte(primitiveHeader(PrimitiveNull)) +} + +// AppendBool appends a boolean value to the builder. +func (b *Builder) AppendBool(v bool) error { + var t PrimitiveType + if v { + t = PrimitiveBoolTrue + } else { + t = PrimitiveBoolFalse + } + + return b.buf.WriteByte(primitiveHeader(t)) +} + +type primitiveNumeric interface { + int8 | int16 | int32 | int64 | float32 | float64 | + arrow.Date32 | arrow.Time64 +} + +type buffer interface { + io.Writer + io.ByteWriter +} + +func writeBinary[T string | []byte](w buffer, v T) error { + var t PrimitiveType + switch any(v).(type) { + case string: + t = PrimitiveString + case []byte: + t = PrimitiveBinary + } + + if err := w.WriteByte(primitiveHeader(t)); err != nil { + return err + } + + if err := binary.Write(w, binary.LittleEndian, uint32(len(v))); err != nil { + return err + } + + _, err := w.Write([]byte(v)) + return err +} + +func writeNumeric[T primitiveNumeric](w buffer, v T) error { + var t PrimitiveType + switch any(v).(type) { + case int8: + t = PrimitiveInt8 + case int16: + t = PrimitiveInt16 + case int32: + t = PrimitiveInt32 + case int64: + t = PrimitiveInt64 + case float32: + t = PrimitiveFloat + case float64: + t = PrimitiveDouble + case arrow.Date32: + t = PrimitiveDate + case arrow.Time64: + t = PrimitiveTimeMicrosNTZ + } + + if err := w.WriteByte(primitiveHeader(t)); err != nil { + return err + } + + return binary.Write(w, binary.LittleEndian, v) +} + +// AppendInt appends an integer value to the builder, using the smallest +// possible integer representation based on the value's range. 
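To make the width selection described above concrete, a tiny self-contained sketch follows. It mirrors AppendInt's bucketing rather than calling the package itself; encodedSize is a hypothetical helper, and each count includes the one-byte primitive header:

package main

import (
	"fmt"
	"math"
)

// encodedSize mirrors the bucketing used by AppendInt: pick the smallest
// signed integer width that can hold the value, plus one header byte.
func encodedSize(v int64) int {
	switch {
	case v >= math.MinInt8 && v <= math.MaxInt8:
		return 1 + 1
	case v >= math.MinInt16 && v <= math.MaxInt16:
		return 1 + 2
	case v >= math.MinInt32 && v <= math.MaxInt32:
		return 1 + 4
	default:
		return 1 + 8
	}
}

func main() {
	for _, v := range []int64{42, 1234, 123456, 1 << 40} {
		fmt.Println(v, "->", encodedSize(v), "bytes") // 2, 3, 5 and 9 bytes respectively
	}
}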
+func (b *Builder) AppendInt(v int64) error { + b.buf.Grow(9) + switch { + case v >= math.MinInt8 && v <= math.MaxInt8: + return writeNumeric(&b.buf, int8(v)) + case v >= math.MinInt16 && v <= math.MaxInt16: + return writeNumeric(&b.buf, int16(v)) + case v >= math.MinInt32 && v <= math.MaxInt32: + return writeNumeric(&b.buf, int32(v)) + default: + return writeNumeric(&b.buf, v) + } +} + +// AppendFloat32 appends a 32-bit floating point value to the builder. +func (b *Builder) AppendFloat32(v float32) error { + b.buf.Grow(5) + return writeNumeric(&b.buf, v) +} + +// AppendFloat64 appends a 64-bit floating point value to the builder. +func (b *Builder) AppendFloat64(v float64) error { + b.buf.Grow(9) + return writeNumeric(&b.buf, v) +} + +// AppendDate appends a date value to the builder. +func (b *Builder) AppendDate(v arrow.Date32) error { + b.buf.Grow(5) + return writeNumeric(&b.buf, v) +} + +// AppendTimeMicro appends a time value with microsecond precision to the builder. +func (b *Builder) AppendTimeMicro(v arrow.Time64) error { + b.buf.Grow(9) + return writeNumeric(&b.buf, v) +} + +// AppendTimestamp appends a timestamp value to the builder. +// The useMicros parameter controls whether microsecond or nanosecond precision is used. +// The useUTC parameter controls whether the timestamp is in UTC timezone or has no time zone (NTZ). +func (b *Builder) AppendTimestamp(v arrow.Timestamp, useMicros, useUTC bool) error { + b.buf.Grow(9) + var t PrimitiveType + if useMicros { + t = PrimitiveTimestampMicrosNTZ + } else { + t = PrimitiveTimestampNanosNTZ + } + + if useUTC { + t-- + } + + if err := b.buf.WriteByte(primitiveHeader(t)); err != nil { + return err + } + + return binary.Write(&b.buf, binary.LittleEndian, v) +} + +// AppendBinary appends a binary value to the builder. +func (b *Builder) AppendBinary(v []byte) error { + b.buf.Grow(5 + len(v)) + return writeBinary(&b.buf, v) +} + +// AppendString appends a string value to the builder. +// Small strings are encoded using the short string representation if small enough. +func (b *Builder) AppendString(v string) error { + if len(v) > maxShortStringSize { + b.buf.Grow(5 + len(v)) + return writeBinary(&b.buf, v) + } + + b.buf.Grow(1 + len(v)) + if err := b.buf.WriteByte(shortStrHeader(len(v))); err != nil { + return err + } + + _, err := b.buf.WriteString(v) + return err +} + +// AppendUUID appends a UUID value to the builder. +func (b *Builder) AppendUUID(v uuid.UUID) error { + b.buf.Grow(17) + if err := b.buf.WriteByte(primitiveHeader(PrimitiveUUID)); err != nil { + return err + } + + m, _ := v.MarshalBinary() + _, err := b.buf.Write(m) + return err +} + +// AppendDecimal4 appends a 4-byte decimal value with the specified scale to the builder. +func (b *Builder) AppendDecimal4(scale uint8, v decimal.Decimal32) error { + b.buf.Grow(6) + if err := b.buf.WriteByte(primitiveHeader(PrimitiveDecimal4)); err != nil { + return err + } + + if err := b.buf.WriteByte(scale); err != nil { + return err + } + + return binary.Write(&b.buf, binary.LittleEndian, int32(v)) +} + +// AppendDecimal8 appends a 8-byte decimal value with the specified scale to the builder. +func (b *Builder) AppendDecimal8(scale uint8, v decimal.Decimal64) error { + b.buf.Grow(10) + return errors.Join( + b.buf.WriteByte(primitiveHeader(PrimitiveDecimal8)), + b.buf.WriteByte(scale), + binary.Write(&b.buf, binary.LittleEndian, int64(v)), + ) +} + +// AppendDecimal16 appends a 16-byte decimal value with the specified scale to the builder. 
+func (b *Builder) AppendDecimal16(scale uint8, v decimal.Decimal128) error { + b.buf.Grow(18) + return errors.Join( + b.buf.WriteByte(primitiveHeader(PrimitiveDecimal16)), + b.buf.WriteByte(scale), + binary.Write(&b.buf, binary.LittleEndian, v.LowBits()), + binary.Write(&b.buf, binary.LittleEndian, v.HighBits()), + ) +} + +// Offset returns the current offset in the builder's buffer. Generally used for +// grabbing a starting point for building an array or object. +func (b *Builder) Offset() int { + return b.buf.Len() +} + +// NextElement returns the offset of the next element relative to the start position. +// Use when building arrays to track element positions. The following creates a variant +// equivalent to `[5, 10]`. +// +// var b variant.Builder +// start, offsets := b.Offset(), make([]int, 0) +// offsets = append(offsets, b.NextElement(start)) +// b.Append(5) +// offsets = append(offsets, b.NextElement(start)) +// b.Append(10) +// b.FinishArray(start, offsets) +// +// The value returned by this is equivalent to `b.Offset() - start`, as offsets are all +// relative to the start position. This allows for creating nested arrays, the following +// creates a variant equivalent to `[5, [10, 20], 30]`. +// +// var b variant.Builder +// start, offsets := b.Offset(), make([]int, 0) +// offsets = append(offsets, b.NextElement(start)) +// b.Append(5) +// offsets = append(offsets, b.NextElement(start)) +// +// nestedStart, nestedOffsets := b.Offset(), make([]int, 0) +// nestedOffsets = append(nestedOffsets, b.NextElement(nestedStart)) +// b.Append(10) +// nestedOffsets = append(nestedOffsets, b.NextElement(nestedStart)) +// b.Append(20) +// b.FinishArray(nestedStart, nestedOffsets) +// +// offsets = append(offsets, b.NextElement(start)) +// b.Append(30) +// b.FinishArray(start, offsets) +func (b *Builder) NextElement(start int) int { + return b.Offset() - start +} + +// FinishArray finalizes an array value in the builder. +// The start parameter is the offset where the array begins. +// The offsets parameter contains the offsets of each element in the array. See [Builder.NextElement] +// for examples of how to use this. +func (b *Builder) FinishArray(start int, offsets []int) error { + var ( + dataSize, sz = b.buf.Len() - start, len(offsets) + isLarge = sz > math.MaxUint8 + sizeBytes = 1 + ) + + if isLarge { + sizeBytes = 4 + } + + if dataSize < 0 { + return errors.New("invalid array size") + } + + offsetSize := intSize(dataSize) + headerSize := 1 + sizeBytes + (sz+1)*int(offsetSize) + + // shift the just written data to make room for the header section + b.buf.Grow(headerSize) + av := b.buf.AvailableBuffer() + if _, err := b.buf.Write(av[:headerSize]); err != nil { + return err + } + + bs := b.buf.Bytes() + copy(bs[start+headerSize:], bs[start:start+dataSize]) + + // populate the header + bs[start] = arrayHeader(isLarge, offsetSize) + writeOffset(bs[start+1:], sz, uint8(sizeBytes)) + + offsetsStart := start + 1 + sizeBytes + for i, off := range offsets { + writeOffset(bs[offsetsStart+i*int(offsetSize):], off, offsetSize) + } + writeOffset(bs[offsetsStart+sz*int(offsetSize):], dataSize, offsetSize) + + return nil +} + +// FieldEntry represents a field in an object, with its key, ID, and offset. +// Usually constructed by using [Builder.NextField] and then passed to [Builder.FinishObject]. +type FieldEntry struct { + Key string + ID uint32 + Offset int +} + +// NextField creates a new field entry for an object with the given key. +// The start parameter is the offset where the object begins. 
The following example would +// construct a variant equivalent to `{"key1": 5, "key2": 10}`. +// +// var b variant.Builder +// start, fields := b.Offset(), make([]variant.FieldEntry, 0) +// fields = append(fields, b.NextField(start, "key1")) +// b.Append(5) +// fields = append(fields, b.NextField(start, "key2")) +// b.Append(10) +// b.FinishObject(start, fields) +// +// This allows for creating nested objects, the following example would create a variant +// equivalent to `{"key1": 5, "key2": {"key3": 10, "key4": 20}, "key5": 30}`. +// +// var b variant.Builder +// start, fields := b.Offset(), make([]variant.FieldEntry, 0) +// fields = append(fields, b.NextField(start, "key1")) +// b.Append(5) +// fields = append(fields, b.NextField(start, "key2")) +// nestedStart, nestedFields := b.Offset(), make([]variant.FieldEntry, 0) +// nestedFields = append(nestedFields, b.NextField(nestedStart, "key3")) +// b.Append(10) +// nestedFields = append(nestedFields, b.NextField(nestedStart, "key4")) +// b.Append(20) +// b.FinishObject(nestedStart, nestedFields) +// fields = append(fields, b.NextField(start, "key5")) +// b.Append(30) +// b.FinishObject(start, fields) +// +// The offset value returned by this is equivalent to `b.Offset() - start`, as offsets are all +// relative to the start position. The key provided will be passed to the [Builder.AddKey] method +// to ensure that the key is added to the dictionary and an ID is assigned. It will re-use existing +// IDs if the key already exists in the dictionary. +func (b *Builder) NextField(start int, key string) FieldEntry { + id := b.AddKey(key) + return FieldEntry{ + Key: key, + ID: id, + Offset: b.Offset() - start, + } +} + +// FinishObject finalizes an object value in the builder. +// The start parameter is the offset where the object begins. +// The fields parameter contains the entries for each field in the object. See [Builder.NextField] +// for examples of how to use this. +// +// The fields are sorted by key before finalizing the object. If duplicate keys are found, +// the last value for a key is kept if [Builder.SetAllowDuplicates] is set to true. If false, +// an error is returned. +func (b *Builder) FinishObject(start int, fields []FieldEntry) error { + slices.SortFunc(fields, func(a, b FieldEntry) int { + return cmp.Compare(a.Key, b.Key) + }) + + sz := len(fields) + var maxID uint32 + if sz > 0 { + maxID = fields[0].ID + } + + // if a duplicate key is found, one of two things happens: + // - if allowDuplicates is true, then the field with the greatest + // offset value (the last appended field) is kept. + // - if allowDuplicates is false, then an error is returned + if b.allowDuplicates { + distinctPos := 0 + // maintain a list of distinct keys in-place + for i := 1; i < sz; i++ { + maxID = max(maxID, fields[i].ID) + if fields[i].ID == fields[i-1].ID { + // found a duplicate key. keep the + // field with a greater offset + if fields[distinctPos].Offset < fields[i].Offset { + fields[distinctPos].Offset = fields[i].Offset + } + } else { + // found distinct key, add field to the list + distinctPos++ + fields[distinctPos] = fields[i] + } + } + + if distinctPos+1 < len(fields) { + sz = distinctPos + 1 + // resize fields to size + fields = fields[:sz] + // sort the fields by offsets so that we can move the value + // data of each field to the new offset without overwriting the + // fields after it. 
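+			// For example, if "a" was appended twice and then "b", only the later
+			// "a" and "b" survive; copying their value bytes toward the start in
+			// offset order keeps the data region contiguous, after which the
+			// buffer is truncated and the fields are re-sorted by key.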
+ slices.SortFunc(fields, func(a, b FieldEntry) int { + return cmp.Compare(a.Offset, b.Offset) + }) + + buf := b.buf.Bytes() + curOffset := 0 + for i := range sz { + oldOffset := fields[i].Offset + fieldSize := valueSize(buf[start+oldOffset:]) + copy(buf[start+curOffset:], buf[start+oldOffset:start+oldOffset+fieldSize]) + fields[i].Offset = curOffset + curOffset += fieldSize + } + b.buf.Truncate(start + curOffset) + // change back to sort order by field keys to meet variant spec + slices.SortFunc(fields, func(a, b FieldEntry) int { + return cmp.Compare(a.Key, b.Key) + }) + } + } else { + for i := 1; i < sz; i++ { + maxID = max(maxID, fields[i].ID) + if fields[i].Key == fields[i-1].Key { + return fmt.Errorf("disallowed duplicate key found: %s", fields[i].Key) + } + } + } + + var ( + dataSize = b.buf.Len() - start + isLarge = sz > math.MaxUint8 + sizeBytes = 1 + idSize, offsetSize = intSize(int(maxID)), intSize(dataSize) + ) + + if isLarge { + sizeBytes = 4 + } + + if dataSize < 0 { + return errors.New("invalid object size") + } + + headerSize := 1 + sizeBytes + sz*int(idSize) + (sz+1)*int(offsetSize) + // shift the just written data to make room for the header section + b.buf.Grow(headerSize) + av := b.buf.AvailableBuffer() + if _, err := b.buf.Write(av[:headerSize]); err != nil { + return err + } + + bs := b.buf.Bytes() + copy(bs[start+headerSize:], bs[start:start+dataSize]) + + // populate the header + bs[start] = objectHeader(isLarge, idSize, offsetSize) + writeOffset(bs[start+1:], sz, uint8(sizeBytes)) + + idStart := start + 1 + sizeBytes + offsetStart := idStart + sz*int(idSize) + for i, field := range fields { + writeOffset(bs[idStart+i*int(idSize):], int(field.ID), idSize) + writeOffset(bs[offsetStart+i*int(offsetSize):], field.Offset, offsetSize) + } + writeOffset(bs[offsetStart+sz*int(offsetSize):], dataSize, offsetSize) + return nil +} + +// Reset truncates the builder's buffer and clears the dictionary while re-using the +// underlying storage where possible. This allows for reusing the builder while keeping +// the total memory usage low. The caveat to this is that any variant value returned +// by calling [Builder.Build] must be cloned with [Value.Clone] before calling this +// method. Otherwise, the byte slice used by the value will be invalidated upon calling +// this method. +// +// For trivial cases where the builder is not reused, this method never needs to be called, +// and the variant built by the builder gets to avoid having to copy the buffer, just referring +// to it directly. +func (b *Builder) Reset() { + b.buf.Reset() + b.dict = make(map[string]uint32) + for i := range b.dictKeys { + b.dictKeys[i] = nil + } + b.dictKeys = b.dictKeys[:0] +} + +// Build creates a Variant Value from the builder's current state. +// The returned Value includes both the value data and the metadata (dictionary). +// +// Importantly, the value data is the returned variant value is not copied here. This will +// return the raw buffer data owned by the builder's buffer. If you wish to reuse a builder, +// then the [Value.Clone] method must be called on the returned value to copy the data before +// calling [Builder.Reset]. This enables trivial cases that don't reuse the builder to avoid +// performing this copy. +func (b *Builder) Build() (Value, error) { + nkeys := len(b.dictKeys) + totalDictSize := 0 + for _, k := range b.dictKeys { + totalDictSize += len(k) + } + + // determine the number of bytes required per offset entry. 
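+	// For example, keys {"a", "bb"} are stored with offsets {0, 1, 3}; the final
+	// offset (3) is the total size of the key bytes and drives the width needed.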
+ // the largest offset is the one-past-the-end value, the total size. + // It's very unlikely that the number of keys could be larger, but + // incorporate that into the calculation in case of pathological data. + maxSize := max(totalDictSize, nkeys) + if maxSize > maxSizeLimit { + return Value{}, fmt.Errorf("metadata size too large: %d", maxSize) + } + + offsetSize := intSize(int(maxSize)) + offsetStart := 1 + offsetSize + stringStart := int(offsetStart) + (nkeys+1)*int(offsetSize) + metadataSize := stringStart + totalDictSize + + if metadataSize > maxSizeLimit { + return Value{}, fmt.Errorf("metadata size too large: %d", metadataSize) + } + + meta := make([]byte, metadataSize) + + meta[0] = supportedVersion | ((offsetSize - 1) << 6) + if nkeys > 0 && slices.IsSortedFunc(b.dictKeys, bytes.Compare) { + meta[0] |= 1 << 4 + } + writeOffset(meta[1:], nkeys, offsetSize) + + curOffset := 0 + for i, k := range b.dictKeys { + writeOffset(meta[int(offsetStart)+i*int(offsetSize):], curOffset, offsetSize) + curOffset += copy(meta[stringStart+curOffset:], k) + } + writeOffset(meta[int(offsetStart)+nkeys*int(offsetSize):], curOffset, offsetSize) + + return Value{ + value: b.buf.Bytes(), + meta: Metadata{ + data: meta, + keys: b.dictKeys, + }, + }, nil +} diff --git a/parquet/variant/builder_test.go b/parquet/variant/builder_test.go new file mode 100644 index 00000000..21292f50 --- /dev/null +++ b/parquet/variant/builder_test.go @@ -0,0 +1,787 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package variant_test + +import ( + "encoding/json" + "testing" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/parquet/variant" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildNullValue(t *testing.T) { + var b variant.Builder + b.AppendNull() + + v, err := b.Build() + require.NoError(t, err) + + assert.Equal(t, variant.Null, v.Type()) + assert.EqualValues(t, 1, v.Metadata().Version()) + assert.Zero(t, v.Metadata().DictionarySize()) +} + +func TestBuildPrimitive(t *testing.T) { + tests := []struct { + name string + op func(*variant.Builder) error + }{ + {"primitive_boolean_true", func(b *variant.Builder) error { + return b.AppendBool(true) + }}, + {"primitive_boolean_false", func(b *variant.Builder) error { + return b.AppendBool(false) + }}, + // AppendInt will use the smallest possible int type + {"primitive_int8", func(b *variant.Builder) error { return b.AppendInt(42) }}, + {"primitive_int16", func(b *variant.Builder) error { return b.AppendInt(1234) }}, + {"primitive_int32", func(b *variant.Builder) error { return b.AppendInt(123456) }}, + // FIXME: https://github.com/apache/parquet-testing/issues/82 + // primitive_int64 is an int32 value, but the metadata is int64 + {"primitive_int64", func(b *variant.Builder) error { return b.AppendInt(12345678) }}, + {"primitive_float", func(b *variant.Builder) error { return b.AppendFloat32(1234568000) }}, + {"primitive_double", func(b *variant.Builder) error { return b.AppendFloat64(1234567890.1234) }}, + {"primitive_string", func(b *variant.Builder) error { + return b.AppendString(`This string is longer than 64 bytes and therefore does not fit in a short_string and it also includes several non ascii characters such as 🐢, 💖, ♥️, 🎣 and 🤦!!`) + }}, + {"short_string", func(b *variant.Builder) error { return b.AppendString(`Less than 64 bytes (❤️ with utf8)`) }}, + // 031337deadbeefcafe + {"primitive_binary", func(b *variant.Builder) error { + return b.AppendBinary([]byte{0x03, 0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}) + }}, + {"primitive_decimal4", func(b *variant.Builder) error { return b.AppendDecimal4(2, 1234) }}, + {"primitive_decimal8", func(b *variant.Builder) error { return b.AppendDecimal8(2, 1234567890) }}, + {"primitive_decimal16", func(b *variant.Builder) error { return b.AppendDecimal16(2, decimal128.FromU64(1234567891234567890)) }}, + {"primitive_date", func(b *variant.Builder) error { return b.AppendDate(20194) }}, + {"primitive_timestamp", func(b *variant.Builder) error { return b.AppendTimestamp(1744821296780000, true, true) }}, + {"primitive_timestampntz", func(b *variant.Builder) error { return b.AppendTimestamp(1744806896780000, true, false) }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expected := loadVariant(t, tt.name) + + var b variant.Builder + require.NoError(t, tt.op(&b)) + + v, err := b.Build() + require.NoError(t, err) + + assert.Equal(t, expected.Type(), v.Type()) + assert.Equal(t, expected.Bytes(), v.Bytes()) + assert.Equal(t, expected.Metadata().Bytes(), v.Metadata().Bytes()) + }) + } +} + +func TestBuildInt64(t *testing.T) { + var b variant.Builder + require.NoError(t, b.AppendInt(1234567890987654321)) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Int64, v.Type()) + assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveInt64), + 
0xB1, 0x1C, 0x6C, 0xB1, 0xF4, 0x10, 0x22, 0x11}, v.Bytes()) +} + +func TestBuildObject(t *testing.T) { + var b variant.Builder + start := b.Offset() + + fields := make([]variant.FieldEntry, 0, 7) + + fields = append(fields, b.NextField(start, "int_field")) + require.NoError(t, b.AppendInt(1)) + + fields = append(fields, b.NextField(start, "double_field")) + require.NoError(t, b.AppendDecimal4(8, 123456789)) + + fields = append(fields, b.NextField(start, "boolean_true_field")) + require.NoError(t, b.AppendBool(true)) + + fields = append(fields, b.NextField(start, "boolean_false_field")) + require.NoError(t, b.AppendBool(false)) + + fields = append(fields, b.NextField(start, "string_field")) + require.NoError(t, b.AppendString("Apache Parquet")) + + fields = append(fields, b.NextField(start, "null_field")) + require.NoError(t, b.AppendNull()) + + fields = append(fields, b.NextField(start, "timestamp_field")) + require.NoError(t, b.AppendString("2025-04-16T12:34:56.78")) + + require.NoError(t, b.FinishObject(start, fields)) + v, err := b.Build() + require.NoError(t, err) + + assert.Equal(t, variant.Object, v.Type()) + expected := loadVariant(t, "object_primitive") + + assert.Equal(t, expected.Metadata().DictionarySize(), v.Metadata().DictionarySize()) + assert.Equal(t, expected.Metadata().Bytes(), v.Metadata().Bytes()) + assert.Equal(t, expected.Bytes(), v.Bytes()) +} + +func TestBuildObjectDuplicateKeys(t *testing.T) { + t.Run("disallow duplicates", func(t *testing.T) { + var b variant.Builder + start := b.Offset() + + fields := make([]variant.FieldEntry, 0, 3) + + fields = append(fields, b.NextField(start, "int_field")) + require.NoError(t, b.AppendInt(1)) + + fields = append(fields, b.NextField(start, "int_field")) + require.NoError(t, b.AppendInt(2)) + + fields = append(fields, b.NextField(start, "int_field")) + require.NoError(t, b.AppendInt(3)) + + require.Error(t, b.FinishObject(start, fields)) + }) + + t.Run("allow duplicates", func(t *testing.T) { + var b variant.Builder + start := b.Offset() + + fields := make([]variant.FieldEntry, 0, 3) + + fields = append(fields, b.NextField(start, "int_field")) + require.NoError(t, b.AppendInt(1)) + + fields = append(fields, b.NextField(start, "string_field")) + require.NoError(t, b.AppendString("Apache Parquet")) + + fields = append(fields, b.NextField(start, "int_field")) + require.NoError(t, b.AppendInt(2)) + + fields = append(fields, b.NextField(start, "int_field")) + require.NoError(t, b.AppendInt(3)) + + fields = append(fields, b.NextField(start, "string_field")) + require.NoError(t, b.AppendString("Apache Arrow")) + + b.SetAllowDuplicates(true) + require.NoError(t, b.FinishObject(start, fields)) + + v, err := b.Build() + require.NoError(t, err) + + assert.Equal(t, variant.Object, v.Type()) + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, `{ + "int_field": 3, + "string_field": "Apache Arrow" + }`, string(out)) + }) +} + +func TestBuildObjectNested(t *testing.T) { + var b variant.Builder + + start := b.Offset() + topFields := make([]variant.FieldEntry, 0, 3) + + topFields = append(topFields, b.NextField(start, "id")) + require.NoError(t, b.AppendInt(1)) + + topFields = append(topFields, b.NextField(start, "observation")) + + observeFields := make([]variant.FieldEntry, 0, 3) + observeStart := b.Offset() + observeFields = append(observeFields, b.NextField(observeStart, "location")) + require.NoError(t, b.AppendString("In the Volcano")) + observeFields = append(observeFields, b.NextField(observeStart, "time")) + 
require.NoError(t, b.AppendString("12:34:56")) + observeFields = append(observeFields, b.NextField(observeStart, "value")) + + valueStart := b.Offset() + valueFields := make([]variant.FieldEntry, 0, 2) + valueFields = append(valueFields, b.NextField(valueStart, "humidity")) + require.NoError(t, b.AppendInt(456)) + valueFields = append(valueFields, b.NextField(valueStart, "temperature")) + require.NoError(t, b.AppendInt(123)) + + require.NoError(t, b.FinishObject(valueStart, valueFields)) + require.NoError(t, b.FinishObject(observeStart, observeFields)) + + topFields = append(topFields, b.NextField(start, "species")) + speciesStart := b.Offset() + speciesFields := make([]variant.FieldEntry, 0, 2) + speciesFields = append(speciesFields, b.NextField(speciesStart, "name")) + require.NoError(t, b.AppendString("lava monster")) + + speciesFields = append(speciesFields, b.NextField(speciesStart, "population")) + require.NoError(t, b.AppendInt(6789)) + + require.NoError(t, b.FinishObject(speciesStart, speciesFields)) + require.NoError(t, b.FinishObject(start, topFields)) + + v, err := b.Build() + require.NoError(t, err) + + out, err := json.Marshal(v) + require.NoError(t, err) + + assert.JSONEq(t, `{ + "id": 1, + "observation": { + "location": "In the Volcano", + "time": "12:34:56", + "value": { + "humidity": 456, + "temperature": 123 + } + }, + "species": { + "name": "lava monster", + "population": 6789 + } + }`, string(out)) +} + +func TestBuildUUID(t *testing.T) { + u := uuid.MustParse("00112233-4455-6677-8899-aabbccddeeff") + + var b variant.Builder + require.NoError(t, b.AppendUUID(u)) + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.UUID, v.Type()) + assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveUUID), + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, v.Bytes()) +} + +func TestBuildTimestampNanos(t *testing.T) { + t.Run("ts nanos tz negative", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.AppendTimestamp(-1, false, true)) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.TimestampNanos, v.Type()) + assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveTimestampNanos), + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, v.Bytes()) + }) + + t.Run("ts nanos tz positive", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.AppendTimestamp(1744877350123456789, false, true)) + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.TimestampNanos, v.Type()) + assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveTimestampNanos), + 0x15, 0xC9, 0xBB, 0x86, 0xB4, 0x0C, 0x37, 0x18}, v.Bytes()) + }) + + t.Run("ts nanos ntz positive", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.AppendTimestamp(1744877350123456789, false, false)) + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.TimestampNanosNTZ, v.Type()) + assert.Equal(t, []byte{primitiveHeader(variant.PrimitiveTimestampNanosNTZ), + 0x15, 0xC9, 0xBB, 0x86, 0xB4, 0x0C, 0x37, 0x18}, v.Bytes()) + }) +} + +func TestBuildArrayValues(t *testing.T) { + t.Run("array primitive", func(t *testing.T) { + var b variant.Builder + + start := b.Offset() + offsets := make([]int, 0, 4) + + offsets = append(offsets, b.NextElement(start)) + require.NoError(t, b.AppendInt(2)) + + offsets = append(offsets, b.NextElement(start)) + require.NoError(t, b.AppendInt(1)) + + offsets = append(offsets, b.NextElement(start)) + require.NoError(t, b.AppendInt(5)) + + 
offsets = append(offsets, b.NextElement(start)) + require.NoError(t, b.AppendInt(9)) + + require.NoError(t, b.FinishArray(start, offsets)) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Array, v.Type()) + + expected := loadVariant(t, "array_primitive") + assert.Equal(t, expected.Metadata().Bytes(), v.Metadata().Bytes()) + assert.Equal(t, expected.Bytes(), v.Bytes()) + }) + + t.Run("array empty", func(t *testing.T) { + var b variant.Builder + + require.NoError(t, b.FinishArray(b.Offset(), nil)) + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Array, v.Type()) + + expected := loadVariant(t, "array_empty") + assert.Equal(t, expected.Metadata().Bytes(), v.Metadata().Bytes()) + assert.Equal(t, expected.Bytes(), v.Bytes()) + }) + + t.Run("array nested", func(t *testing.T) { + var b variant.Builder + + start := b.Offset() + offsets := make([]int, 0, 3) + + { + offsets = append(offsets, b.NextElement(start)) + objStart := b.Offset() + objFields := make([]variant.FieldEntry, 0, 2) + objFields = append(objFields, b.NextField(objStart, "id")) + require.NoError(t, b.AppendInt(1)) + + objFields = append(objFields, b.NextField(objStart, "thing")) + thingObjStart := b.Offset() + thingObjFields := make([]variant.FieldEntry, 0, 1) + thingObjFields = append(thingObjFields, b.NextField(thingObjStart, "names")) + + namesStart := b.Offset() + namesOffsets := make([]int, 0, 2) + namesOffsets = append(namesOffsets, b.NextElement(namesStart)) + require.NoError(t, b.AppendString("Contrarian")) + namesOffsets = append(namesOffsets, b.NextElement(namesStart)) + require.NoError(t, b.AppendString("Spider")) + require.NoError(t, b.FinishArray(namesStart, namesOffsets)) + + require.NoError(t, b.FinishObject(thingObjStart, thingObjFields)) + require.NoError(t, b.FinishObject(objStart, objFields)) + } + { + offsets = append(offsets, b.NextElement(start)) + b.AppendNull() + } + { + offsets = append(offsets, b.NextElement(start)) + objStart := b.Offset() + objFields := make([]variant.FieldEntry, 0, 3) + objFields = append(objFields, b.NextField(objStart, "id")) + require.NoError(t, b.AppendInt(2)) + + objFields = append(objFields, b.NextField(objStart, "names")) + namesStart := b.Offset() + namesOffsets := make([]int, 0, 3) + namesOffsets = append(namesOffsets, b.NextElement(namesStart)) + require.NoError(t, b.AppendString("Apple")) + namesOffsets = append(namesOffsets, b.NextElement(namesStart)) + require.NoError(t, b.AppendString("Ray")) + namesOffsets = append(namesOffsets, b.NextElement(namesStart)) + require.NoError(t, b.AppendNull()) + require.NoError(t, b.FinishArray(namesStart, namesOffsets)) + + objFields = append(objFields, b.NextField(objStart, "type")) + require.NoError(t, b.AppendString("if")) + + require.NoError(t, b.FinishObject(objStart, objFields)) + } + + require.NoError(t, b.FinishArray(start, offsets)) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Array, v.Type()) + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, `[ + {"id": 1, "thing": {"names": ["Contrarian", "Spider"]}}, + null, + {"id": 2, "names": ["Apple", "Ray", null], "type": "if"} + ]`, string(out)) + }) +} + +func TestAppendPrimitives(t *testing.T) { + tests := []struct { + name string + value any + valueOut any + expected variant.Type + }{ + {"null", nil, nil, variant.Null}, + {"bool_true", true, true, variant.Bool}, + {"bool_false", false, false, variant.Bool}, + {"int8", int8(42), int8(42), variant.Int8}, + {"uint8", uint8(42), 
int8(42), variant.Int8}, + {"int16", int16(1234), int16(1234), variant.Int16}, + {"uint16", uint16(1234), int16(1234), variant.Int16}, + {"int32", int32(123456), int32(123456), variant.Int32}, + {"uint32", uint32(123456), int32(123456), variant.Int32}, + {"int64", int64(1234567890123), int64(1234567890123), variant.Int64}, + {"int", int(123456), int32(123456), variant.Int32}, + {"uint", uint(123456), int32(123456), variant.Int32}, + {"float32", float32(123.45), float32(123.45), variant.Float}, + {"float64", float64(123.45), float64(123.45), variant.Double}, + {"string", "test string", "test string", variant.String}, + {"bytes", []byte{1, 2, 3, 4}, []byte{1, 2, 3, 4}, variant.Binary}, + {"date", arrow.Date32(2023), arrow.Date32(2023), variant.Date}, + {"timestamp", arrow.Timestamp(1234567890), arrow.Timestamp(1234567890), variant.TimestampMicrosNTZ}, + {"time", arrow.Time64(123456), arrow.Time64(123456), variant.Time}, + {"uuid", uuid.MustParse("00112233-4455-6677-8899-aabbccddeeff"), + uuid.MustParse("00112233-4455-6677-8899-aabbccddeeff"), variant.UUID}, + {"decimal4", variant.DecimalValue[decimal.Decimal32]{ + Scale: 2, Value: decimal.Decimal32(1234), + }, variant.DecimalValue[decimal.Decimal32]{ + Scale: 2, Value: decimal.Decimal32(1234), + }, variant.Decimal4}, + {"decimal8", variant.DecimalValue[decimal.Decimal64]{ + Scale: 2, Value: decimal.Decimal64(1234), + }, variant.DecimalValue[decimal.Decimal64]{ + Scale: 2, Value: decimal.Decimal64(1234), + }, variant.Decimal8}, + {"decimal16", variant.DecimalValue[decimal.Decimal128]{ + Scale: 2, Value: decimal128.FromU64(1234), + }, variant.DecimalValue[decimal.Decimal128]{ + Scale: 2, Value: decimal128.FromU64(1234), + }, variant.Decimal16}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.Append(tt.value)) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, tt.expected, v.Type()) + assert.Equal(t, tt.valueOut, v.Value()) + }) + } +} + +func TestAppendTimestampOptions(t *testing.T) { + testTime := time.Date(2023, 5, 15, 14, 30, 0, 123456789, time.UTC) + + tests := []struct { + name string + opts []variant.AppendOpt + expected variant.Type + }{ + {"default_micros", nil, variant.TimestampMicrosNTZ}, + {"nanos", []variant.AppendOpt{variant.OptTimestampNano}, variant.TimestampNanosNTZ}, + {"utc_micros", []variant.AppendOpt{variant.OptTimestampUTC}, variant.TimestampMicros}, + {"utc_nanos", []variant.AppendOpt{variant.OptTimestampUTC, variant.OptTimestampNano}, variant.TimestampNanos}, + {"as_date", []variant.AppendOpt{variant.OptTimeAsDate}, variant.Date}, + {"as_time", []variant.AppendOpt{variant.OptTimeAsTime}, variant.Time}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.Append(testTime, tt.opts...)) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, tt.expected, v.Type()) + }) + } +} + +func TestAppendArrays(t *testing.T) { + t.Run("slice_of_any", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.Append([]any{1, "test", true, nil})) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Array, v.Type()) + + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, `[1, "test", true, null]`, string(out)) + }) + + t.Run("slice_of_ints", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.Append([]int{10, 20, 30})) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Array, 
v.Type()) + + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, `[10, 20, 30]`, string(out)) + }) + + t.Run("nested_slices", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.Append([]any{ + []int{1, 2}, + []string{"a", "b"}, + })) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Array, v.Type()) + + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, `[[1, 2], ["a", "b"]]`, string(out)) + }) +} + +func TestAppendMaps(t *testing.T) { + t.Run("map_string_any", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.Append(map[string]any{ + "int": 123, + "str": "test", + "bool": true, + "null": nil, + })) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Object, v.Type()) + + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, `{"bool": true, "int": 123, "null": null, "str": "test"}`, string(out)) + }) + + t.Run("map_string_int", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.Append(map[string]int{ + "int": 123, + "int2": 456, + })) + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Object, v.Type()) + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, `{"int": 123, "int2": 456}`, string(out)) + }) + + t.Run("map_with_nested_objects", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.Append(map[string]any{ + "metadata": map[string]any{ + "id": 1, + "name": "test", + }, + "values": []int{10, 20, 30}, + })) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Object, v.Type()) + + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, `{"metadata": {"id": 1, "name": "test"}, "values": [10, 20, 30]}`, string(out)) + }) + + t.Run("unsupported_map_key", func(t *testing.T) { + var b variant.Builder + err := b.Append(map[int]string{1: "test"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported map key type") + }) +} + +type SimpleStruct struct { + ID int + Name string + IsValid bool +} + +type StructWithTags struct { + ID int `variant:"id"` + Name string `variant:"name"` + Ignored string `variant:"-"` + Timestamp time.Time `variant:"ts,nanos,utc"` + Date time.Time `variant:"date,date"` + TimeOnly time.Time `variant:",time"` +} + +type NestedStruct struct { + ID int `variant:"id"` + Metadata *SimpleStruct `variant:"meta"` + Tags []string `variant:"tags"` +} + +func TestAppendStructs(t *testing.T) { + t.Run("simple_struct", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.Append(SimpleStruct{ + ID: 123, + Name: "test", + IsValid: true, + })) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Object, v.Type()) + + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, `{"ID": 123, "Name": "test", "IsValid": true}`, string(out)) + }) + + t.Run("struct_with_tags", func(t *testing.T) { + testTime := time.Date(2023, 5, 15, 14, 30, 0, 123456789, time.UTC) + var b variant.Builder + require.NoError(t, b.Append(StructWithTags{ + ID: 123, + Name: "test", + Ignored: "should not appear", + Timestamp: testTime, + Date: testTime, + TimeOnly: testTime, + })) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Object, v.Type()) + + obj := v.Value().(variant.ObjectValue) + id, err := obj.ValueByKey("id") + require.NoError(t, err) + assert.Equal(t, variant.Int8, id.Value.Type()) + assert.Equal(t, int8(123), id.Value.Value()) + + name, err := 
obj.ValueByKey("name") + require.NoError(t, err) + assert.Equal(t, variant.String, name.Value.Type()) + assert.Equal(t, "test", name.Value.Value()) + + ignored, err := obj.ValueByKey("Ignored") + require.ErrorIs(t, err, arrow.ErrNotFound) + assert.Zero(t, ignored) + + ts, err := obj.ValueByKey("ts") + require.NoError(t, err) + assert.Equal(t, variant.TimestampNanos, ts.Value.Type()) + assert.Equal(t, arrow.Timestamp(testTime.UnixNano()), ts.Value.Value()) + + date, err := obj.ValueByKey("date") + require.NoError(t, err) + assert.Equal(t, variant.Date, date.Value.Type()) + assert.Equal(t, arrow.Date32FromTime(testTime), date.Value.Value()) + + timeOnly, err := obj.ValueByKey("TimeOnly") + require.NoError(t, err) + assert.Equal(t, variant.Time, timeOnly.Value.Type()) + assert.Equal(t, arrow.Time64(52200123456), timeOnly.Value.Value()) + }) + + t.Run("nested_struct", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.Append(NestedStruct{ + ID: 123, + Metadata: &SimpleStruct{ + ID: 456, + Name: "nested", + IsValid: true, + }, + Tags: []string{"tag1", "tag2"}, + })) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Object, v.Type()) + + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, `{ + "id": 123, + "meta": {"ID": 456, "Name": "nested", "IsValid": true}, + "tags": ["tag1", "tag2"] + }`, string(out)) + }) + + t.Run("nil_struct_pointer", func(t *testing.T) { + var b variant.Builder + require.NoError(t, b.Append(NestedStruct{ + ID: 123, + Metadata: nil, + Tags: []string{"tag1"}, + })) + + v, err := b.Build() + require.NoError(t, err) + assert.Equal(t, variant.Object, v.Type()) + + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, `{"id": 123, "meta": null, "tags": ["tag1"]}`, string(out)) + }) +} + +func TestAppendReset(t *testing.T) { + var b variant.Builder + + // First build + require.NoError(t, b.Append(map[string]any{"key": "value"})) + v1, err := b.Build() + require.NoError(t, err) + + out1, err := json.Marshal(v1) + require.NoError(t, err) + assert.JSONEq(t, `{"key": "value"}`, string(out1)) + v1 = v1.Clone() + + // Reset and build again + b.Reset() + require.NoError(t, b.Append([]int{1, 2, 3})) + v2, err := b.Build() + require.NoError(t, err) + + // First value should still be valid because we cloned it + // before calling Reset + assert.Equal(t, variant.Object, v1.Type()) + out1, err = json.Marshal(v1) + require.NoError(t, err) + assert.JSONEq(t, `{"key": "value"}`, string(out1)) + + // Second value should be different + assert.Equal(t, variant.Array, v2.Type()) + out2, err := json.Marshal(v2) + require.NoError(t, err) + assert.JSONEq(t, `[1, 2, 3]`, string(out2)) + + // Without cloning, the first value would be invalidated + v1Clone := v1.Clone() + b.Reset() + + out3, err := json.Marshal(v1Clone) + require.NoError(t, err) + assert.JSONEq(t, `{"key": "value"}`, string(out3)) +} diff --git a/parquet/variant/doc.go b/parquet/variant/doc.go new file mode 100644 index 00000000..29c6b777 --- /dev/null +++ b/parquet/variant/doc.go @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package variant provides an implementation of the Apache Parquet Variant data type. +// +// The Variant type is a flexible binary format designed to represent complex nested +// data structures with minimal overhead. It supports a wide range of primitive types +// as well as nested arrays and objects (similar to JSON). The format uses a memory-efficient +// binary representation with a separate metadata section for dictionary encoding of keys. +// +// # Key Components +// +// - [Value]: The primary type representing a variant value +// - [Metadata]: Contains information about the dictionary of keys +// - [Builder]: Used to construct variant values +// +// # Format Overview +// +// The variant format consists of two parts: +// +// 1. Metadata: A dictionary of keys used in objects +// 2. Value: The actual data payload +// +// Values can be one of the following types: +// +// - Primitive values (null, bool, int8/16/32/64, float32/64, etc.) +// - Short strings (less than 64 bytes) +// - Long strings and binary data +// - Date, time and timestamp values +// - Decimal values (4, 8, or 16 bytes) +// - Arrays of any variant value +// - Objects with key-value pairs +// +// # Working with Variants +// +// To create a variant value, use the Builder: +// +// var b variant.Builder +// b.Append(map[string]any{ +// "id": 123, +// "name": "example", +// "data": []any{1, 2, 3}, +// }) +// value, err := b.Build() +// +// To parse an existing variant value: +// +// v, err := variant.New(metadataBytes, valueBytes) +// +// You can access the data using the [Value.Value] method which returns the appropriate Go type: +// +// switch v.Type() { +// case variant.Object: +// obj := v.Value().(variant.ObjectValue) +// field, err := obj.ValueByKey("name") +// case variant.Array: +// arr := v.Value().(variant.ArrayValue) +// elem, err := arr.Value(0) +// case variant.String: +// s := v.Value().(string) +// case variant.Int64: +// i := v.Value().(int64) +// } +// +// You can also switch on the type of the result value from the [Value.Value] method: +// +// switch val := v.Value().(type) { +// case nil: +// // ... +// case int32: +// // ... +// case string: +// // ... 
+// case variant.ArrayValue: +// for i, item := range val.Values() { +// // item is a variant.Value +// } +// case variant.ObjectValue: +// for k, item := range val.Values() { +// // k is the field key +// // item is a variant.Value for that field +// } +// } +// +// Values can also be converted to JSON: +// +// jsonBytes, err := json.Marshal(v) +// +// # Low-level Construction +// +// For direct construction of complex nested structures, you can use the low-level +// methods: +// +// var b variant.Builder +// // Start an object +// start := b.Offset() +// fields := make([]variant.FieldEntry, 0) +// +// // Add a field +// fields = append(fields, b.NextField(start, "key")) +// b.AppendString("value") +// +// // Finish the object +// b.FinishObject(start, fields) +// +// value, err := b.Build() +// +// # Using Struct Tags +// +// When appending Go structs, you can use struct tags to control field names and +// encoding options: +// +// type Person struct { +// ID int `variant:"id"` +// Name string `variant:"name"` +// CreatedAt time.Time `variant:"timestamp,nanos,utc"` +// Internal string `variant:"-"` // Ignored field +// } +// +// # Reusing Builders +// +// When reusing a Builder for multiple values, use Reset() to clear it: +// +// var b variant.Builder +// v1, _ := b.Append(data1).Build() +// v1 = v1.Clone() // Clone before reset if you need to keep the value +// b.Reset() +// v2, _ := b.Append(data2).Build() +package variant diff --git a/arrow/extensions/variant/primitive_type_string.go b/parquet/variant/primitive_type_string.go similarity index 100% rename from arrow/extensions/variant/primitive_type_string.go rename to parquet/variant/primitive_type_string.go diff --git a/arrow/extensions/variant/utils.go b/parquet/variant/utils.go similarity index 98% rename from arrow/extensions/variant/utils.go rename to parquet/variant/utils.go index 523dce87..05dee06d 100644 --- a/arrow/extensions/variant/utils.go +++ b/parquet/variant/utils.go @@ -22,7 +22,7 @@ import ( "unsafe" "github.com/apache/arrow-go/v18/arrow/endian" - "github.com/apache/arrow-go/v18/arrow/internal/debug" + "github.com/apache/arrow-go/v18/parquet/internal/debug" ) func readLEU32(b []byte) uint32 { diff --git a/arrow/extensions/variant/variant.go b/parquet/variant/variant.go similarity index 80% rename from arrow/extensions/variant/variant.go rename to parquet/variant/variant.go index 8b1e3f4e..a600c6c7 100644 --- a/arrow/extensions/variant/variant.go +++ b/parquet/variant/variant.go @@ -32,13 +32,14 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/apache/arrow-go/v18/arrow/decimal128" - "github.com/apache/arrow-go/v18/arrow/internal/debug" + "github.com/apache/arrow-go/v18/parquet/internal/debug" "github.com/google/uuid" ) //go:generate go tool stringer -type=BasicType -linecomment -output=basic_type_string.go //go:generate go tool stringer -type=PrimitiveType -linecomment -output=primitive_type_string.go +// BasicType represents the fundamental type category of a variant value. type BasicType int const ( @@ -53,6 +54,7 @@ func basicTypeFromHeader(hdr byte) BasicType { return BasicType(hdr & basicTypeMask) } +// PrimitiveType represents specific primitive data types within the variant format. type PrimitiveType int const ( @@ -84,6 +86,8 @@ func primitiveTypeFromHeader(hdr byte) PrimitiveType { return PrimitiveType((hdr >> basicTypeBits) & typeInfoMask) } +// Type represents the high-level variant data type. 
+// This is what applications typically use to identify the type of a variant value. type Type int const ( @@ -130,14 +134,19 @@ const ( ) var ( + // EmptyMetadataBytes contains a minimal valid metadata section with no dictionary entries. EmptyMetadataBytes = [3]byte{0x1, 0, 0} ) +// Metadata represents the dictionary part of a variant value, which stores +// the keys used in object values. type Metadata struct { data []byte keys [][]byte } +// NewMetadata creates a Metadata instance from a raw byte slice. +// It validates the metadata format and loads the key dictionary. func NewMetadata(data []byte) (Metadata, error) { m := Metadata{data: data} if len(data) < hdrSizeBytes+minOffsetSizeBytes*2 { @@ -166,6 +175,7 @@ func NewMetadata(data []byte) (Metadata, error) { return m, nil } +// Clone creates a deep copy of the metadata. func (m *Metadata) Clone() Metadata { return Metadata{ data: bytes.Clone(m.data), @@ -208,16 +218,25 @@ func (m *Metadata) loadDictionary(offsetSz uint8) (uint32, error) { return dictSize, nil } +// Bytes returns the raw byte representation of the metadata. func (m Metadata) Bytes() []byte { return m.data } -func (m Metadata) Version() uint8 { return m.data[0] & versionMask } +// Version returns the metadata format version. +func (m Metadata) Version() uint8 { return m.data[0] & versionMask } + +// SortedAndUnique returns whether the keys in the metadata dictionary are sorted and unique. func (m Metadata) SortedAndUnique() bool { return m.data[0]&sortedStrMask != 0 } + +// OffsetSize returns the size in bytes used to store offsets in the metadata. func (m Metadata) OffsetSize() uint8 { return ((m.data[0] >> offsetSizeBitShift) & offsetSizeMask) + 1 } +// DictionarySize returns the number of keys in the metadata dictionary. func (m Metadata) DictionarySize() uint32 { return uint32(len(m.keys)) } +// KeyAt returns the string key at the given dictionary ID. +// Returns an error if the ID is out of range. func (m Metadata) KeyAt(id uint32) (string, error) { if id >= uint32(len(m.keys)) { return "", fmt.Errorf("invalid variant metadata: id out of range: %d >= %d", @@ -227,6 +246,12 @@ func (m Metadata) KeyAt(id uint32) (string, error) { return unsafe.String(&m.keys[id][0], len(m.keys[id])), nil } +// IdFor returns the dictionary IDs for the given key. +// If the metadata is sorted and unique, this performs a binary search. +// Otherwise, it performs a linear search. +// +// If the metadata is not sorted and unique, then it's possible that multiple +// IDs will be returned for the same key. func (m Metadata) IdFor(key string) []uint32 { k := unsafe.Slice(unsafe.StringData(key), len(key)) @@ -249,15 +274,19 @@ func (m Metadata) IdFor(key string) []uint32 { return ret } +// DecimalValue represents a decimal number with a specified scale. +// The generic parameter T can be any supported variant decimal type (Decimal32, Decimal64, Decimal128). type DecimalValue[T decimal.DecimalTypes] struct { Scale uint8 Value decimal.Num[T] } +// MarshalJSON implements the json.Marshaler interface for DecimalValue. func (v DecimalValue[T]) MarshalJSON() ([]byte, error) { return []byte(v.Value.ToString(int32(v.Scale))), nil } +// ArrayValue represents an array of variant values. type ArrayValue struct { value []byte meta Metadata @@ -268,22 +297,31 @@ type ArrayValue struct { offsetStart uint8 } +// MarshalJSON implements the json.Marshaler interface for ArrayValue. 
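+// The elements are collected from the Values iterator and encoded as a JSON array.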
func (v ArrayValue) MarshalJSON() ([]byte, error) { - return json.Marshal(v.Values()) + return json.Marshal(slices.Collect(v.Values())) } -func (v ArrayValue) NumElements() uint32 { return v.numElements } +// Len returns the number of elements in the array. +func (v ArrayValue) Len() uint32 { return v.numElements } -func (v ArrayValue) Values() []Value { - values := make([]Value, v.numElements) - for i := range v.numElements { - idx := uint32(v.offsetStart) + i*uint32(v.offsetSize) - offset := readLEU32(v.value[idx : idx+uint32(v.offsetSize)]) - values[i] = Value{value: v.value[v.dataStart+offset:], meta: v.meta} +// Values returns an iterator for the elements in the array, allowing +// for lazy evaluation of the offsets (for the situation where not all elements +// are iterated). +func (v ArrayValue) Values() iter.Seq[Value] { + return func(yield func(Value) bool) { + for i := range v.numElements { + idx := uint32(v.offsetStart) + i*uint32(v.offsetSize) + offset := readLEU32(v.value[idx : idx+uint32(v.offsetSize)]) + if !yield(Value{value: v.value[v.dataStart+offset:], meta: v.meta}) { + return + } + } } - return values } +// Value returns the Value at the specified index. +// Returns an error if the index is out of range. func (v ArrayValue) Value(i uint32) (Value, error) { if i >= v.numElements { return Value{}, fmt.Errorf("%w: invalid array value: index out of range: %d >= %d", @@ -296,6 +334,7 @@ func (v ArrayValue) Value(i uint32) (Value, error) { return Value{meta: v.meta, value: v.value[v.dataStart+offset:]}, nil } +// ObjectValue represents an object (map/dictionary) of key-value pairs. type ObjectValue struct { value []byte meta Metadata @@ -308,12 +347,17 @@ type ObjectValue struct { idStart uint8 } +// ObjectField represents a key-value pair in an object. type ObjectField struct { Key string Value Value } +// NumElements returns the number of fields in the object. func (v ObjectValue) NumElements() uint32 { return v.numElements } + +// ValueByKey returns the field with the specified key. +// Returns arrow.ErrNotFound if the key doesn't exist. func (v ObjectValue) ValueByKey(key string) (ObjectField, error) { n := v.numElements @@ -367,6 +411,8 @@ func (v ObjectValue) ValueByKey(key string) (ObjectField, error) { return ObjectField{}, arrow.ErrNotFound } +// FieldAt returns the field at the specified index. +// Returns an error if the index is out of range. func (v ObjectValue) FieldAt(i uint32) (ObjectField, error) { if i >= v.numElements { return ObjectField{}, fmt.Errorf("%w: invalid object value: index out of range: %d >= %d", @@ -388,6 +434,7 @@ func (v ObjectValue) FieldAt(i uint32) (ObjectField, error) { Value: Value{value: v.value[v.dataStart+offset:], meta: v.meta}}, nil } +// Values returns an iterator over all key-value pairs in the object. func (v ObjectValue) Values() iter.Seq2[string, Value] { return func(yield func(string, Value) bool) { for i := range v.numElements { @@ -408,6 +455,7 @@ func (v ObjectValue) Values() iter.Seq2[string, Value] { } } +// MarshalJSON implements the json.Marshaler interface for ObjectValue. func (v ObjectValue) MarshalJSON() ([]byte, error) { // for now we'll use a naive approach and just build a map // then marshal it. This is not the most efficient way to do this @@ -417,11 +465,13 @@ func (v ObjectValue) MarshalJSON() ([]byte, error) { return json.Marshal(mapping) } +// Value represents a variant value of any type. 
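+// A Value carries its Metadata alongside the raw value bytes so that object keys
+// can be resolved without any additional context.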
type Value struct { value []byte meta Metadata } +// NewWithMetadata creates a Value with the provided metadata and value bytes. func NewWithMetadata(meta Metadata, value []byte) (Value, error) { if len(value) == 0 { return Value{}, errors.New("invalid variant value: empty") @@ -430,6 +480,7 @@ func NewWithMetadata(meta Metadata, value []byte) (Value, error) { return Value{value: value, meta: meta}, nil } +// New creates a Value by parsing both the metadata and value bytes. func New(meta, value []byte) (Value, error) { m, err := NewMetadata(meta) if err != nil { @@ -439,16 +490,25 @@ func New(meta, value []byte) (Value, error) { return NewWithMetadata(m, value) } +// Bytes returns the raw byte representation of the value (excluding metadata). func (v Value) Bytes() []byte { return v.value } -func (v Value) Clone() Value { return Value{value: bytes.Clone(v.value)} } +// Clone creates a deep copy of the value including its metadata. +func (v Value) Clone() Value { + return Value{ + meta: v.meta.Clone(), + value: bytes.Clone(v.value)} +} +// Metadata returns the metadata associated with the value. func (v Value) Metadata() Metadata { return v.meta } +// BasicType returns the fundamental type category of the value. func (v Value) BasicType() BasicType { return basicTypeFromHeader(v.value[0]) } +// Type returns the specific data type of the value. func (v Value) Type() Type { switch t := v.BasicType(); t { case BasicPrimitive: @@ -507,6 +567,21 @@ func (v Value) Type() Type { } } +// Value returns the Go value representation of the variant. +// The returned type depends on the variant type: +// - Null: nil +// - Bool: bool +// - Int8/16/32/64: corresponding int type +// - Float/Double: float32/float64 +// - String: string +// - Binary: []byte +// - Decimal: DecimalValue +// - Date: arrow.Date32 +// - Time: arrow.Time64 +// - Timestamp: arrow.Timestamp +// - UUID: uuid.UUID +// - Object: ObjectValue +// - Array: ArrayValue func (v Value) Value() any { switch t := v.BasicType(); t { case BasicPrimitive: @@ -535,7 +610,7 @@ func (v Value) Value() any { PrimitiveTimestampNanos, PrimitiveTimestampNanosNTZ: return arrow.Timestamp(readExact[int64](v.value[1:])) case PrimitiveTimeMicrosNTZ: - return arrow.Time32(readExact[int32](v.value[1:])) + return arrow.Time64(readExact[int64](v.value[1:])) case PrimitiveUUID: debug.Assert(len(v.value[1:]) == 16, "invalid UUID length") return uuid.Must(uuid.FromBytes(v.value[1:])) @@ -625,6 +700,7 @@ func (v Value) Value() any { return nil } +// MarshalJSON implements the json.Marshaler interface for Value. 
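+// The Go representation returned by [Value.Value] is encoded according to its concrete type.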
func (v Value) MarshalJSON() ([]byte, error) { result := v.Value() switch t := result.(type) { diff --git a/arrow/extensions/variant/variant_test.go b/parquet/variant/variant_test.go similarity index 98% rename from arrow/extensions/variant/variant_test.go rename to parquet/variant/variant_test.go index a9fa5d22..44c582f7 100644 --- a/arrow/extensions/variant/variant_test.go +++ b/parquet/variant/variant_test.go @@ -26,7 +26,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/decimal" "github.com/apache/arrow-go/v18/arrow/decimal128" - "github.com/apache/arrow-go/v18/arrow/extensions/variant" + "github.com/apache/arrow-go/v18/parquet/variant" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -427,7 +427,7 @@ func TestArrayValues(t *testing.T) { assert.Equal(t, variant.Array, v.Type()) arr := v.Value().(variant.ArrayValue) - assert.EqualValues(t, 4, arr.NumElements()) + assert.EqualValues(t, 4, arr.Len()) elem0, err := arr.Value(0) require.NoError(t, err) @@ -463,7 +463,7 @@ func TestArrayValues(t *testing.T) { assert.Equal(t, variant.Array, v.Type()) arr := v.Value().(variant.ArrayValue) - assert.EqualValues(t, 0, arr.NumElements()) + assert.EqualValues(t, 0, arr.Len()) _, err := arr.Value(0) require.ErrorIs(t, err, arrow.ErrIndex) }) @@ -473,7 +473,7 @@ func TestArrayValues(t *testing.T) { assert.Equal(t, variant.Array, v.Type()) arr := v.Value().(variant.ArrayValue) - assert.EqualValues(t, 3, arr.NumElements()) + assert.EqualValues(t, 3, arr.Len()) elem0, err := arr.Value(0) require.NoError(t, err) diff --git a/parquet/variants/array.go b/parquet/variants/array.go deleted file mode 100644 index 0d4d0eb5..00000000 --- a/parquet/variants/array.go +++ /dev/null @@ -1,248 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package variants - -import ( - "bytes" - "fmt" - "io" - "reflect" -) - -type arrayBuilder struct { - w io.Writer - buf bytes.Buffer - numItems int - offsets []int32 - nextOffset int32 - doneCB doneCB - mdb *metadataBuilder - built bool -} - -func newArrayBuilder(w io.Writer, mdb *metadataBuilder, doneCB doneCB) *arrayBuilder { - return &arrayBuilder{ - w: w, - doneCB: doneCB, - mdb: mdb, - } -} - -var _ ArrayBuilder = (*arrayBuilder)(nil) - -// Write marshals the provided value into the appropriate Variant type and appends it to this array. 
-func (a *arrayBuilder) Write(val any, opts ...MarshalOpts) error { - return writeCommon(val, &a.buf, a.mdb, a.recordOffset) -} - -// Appends all elements from a provided slice into this array -func (a *arrayBuilder) fromSlice(sl any, opts ...MarshalOpts) error { - val := reflect.ValueOf(sl) - if val.Kind() != reflect.Slice && val.Kind() != reflect.Array { - return fmt.Errorf("not a slice: %v", val.Kind()) - } - for i := range val.Len() { - item := val.Index(i) - if err := a.Write(item.Interface(), opts...); err != nil { - return err - } - } - return nil -} - -func (a *arrayBuilder) recordOffset(size int) { - a.numItems++ - a.offsets = append(a.offsets, a.nextOffset) - a.nextOffset += int32(size) -} - -// Array returns a new ArrayBuilder associated with this array. The marshaled array -// will not be part of the array until the returned ArrayBuilder's Build method is called. -func (a *arrayBuilder) Array() ArrayBuilder { - ab := newArrayBuilder(&a.buf, a.mdb, a.recordOffset) - return ab -} - -// Object returns a new ObjectBuilder associated with this array. The marshaled object -// will not be part of the array until the returned ObjectBuilder's Build() method is called. -func (a *arrayBuilder) Object() ObjectBuilder { - ob := newObjectBuilder(&a.buf, a.mdb, a.recordOffset) - return ob -} - -// Build marshals an Array in Variant format and writes its header and value data to the -// underlying Writer. This prepends serialized metadata about the array (ie. its header, number -// of elements, and element offsets) to the runnig data buffer. -func (a *arrayBuilder) Build() error { - if a.built { - return errAlreadyBuilt - } - large := a.numItems > 0xFF - offsetSize := fieldOffsetSize(a.nextOffset) - - // Preallocate a buffer for the header, number of items, and the field offsets. - numItemsSize := 1 - if large { - numItemsSize = 4 - } - serializedOffsetSize := (a.numItems + 1) * offsetSize - serializedDataBuf := bytes.NewBuffer(make([]byte, 0, 1+serializedOffsetSize)) - - // Write the header and number of elements in the array - serializedDataBuf.WriteByte(a.header(large, offsetSize)) - encodeNumber(int64(a.numItems), numItemsSize, serializedDataBuf) - - // Write all of the field offsets, including the final offset which is the first index after all - // of the array's elements. - for _, o := range a.offsets { - encodeNumber(int64(o), offsetSize, serializedDataBuf) - } - encodeNumber(int64(a.nextOffset), offsetSize, serializedDataBuf) - - // Calculate the size of this entire array in bytes, then call the recordkeeping callback if configured. 
- hdrSize, _ := a.w.Write(serializedDataBuf.Bytes()) - dataSize, _ := a.w.Write(a.buf.Bytes()) - totalSize := hdrSize + dataSize - - if a.doneCB != nil { - a.doneCB(totalSize) - } - return nil -} - -func (a *arrayBuilder) header(large bool, offsetSize int) byte { - // Header is one byte: AAABCCDD - // * A: Unused - // * B: Is Large: whether there are more than 255 elements in this array or not - // * C: Field Offset Size Minus One: the number of bytes (minus one) used to encode each Field Offset - // * D: 0x03: the identifier of the Array basic type - hdr := byte(offsetSize - 1) - if large { - hdr |= (1 << 2) - } - // Shift the value header over 2 to allow for the lower to bits to - // denote the array basic type - hdr <<= 2 - hdr |= byte(BasicArray) - return hdr -} - -type arrayData struct { - size int - numElements int - firstOffsetIdx int - firstDataIdx int - offsetWidth int -} - -// Parses array data from a marshaled object (where the different encoded sections start, plus size in bytes -// and number of elements). This also ensures that the entire array exists in the raw buffer. -func getArrayData(raw []byte, offset int) (*arrayData, error) { - if err := checkBounds(raw, offset, offset); err != nil { - return nil, err - } - hdr := raw[offset] - if bt := BasicTypeFromHeader(hdr); bt != BasicArray { - return nil, fmt.Errorf("not an array: %s", bt) - } - - // Get the size of all encoded metadata fields. Bitshift by two to expose the 5 raw value header bits. - hdr >>= 2 - - offsetWidth := int(hdr&0x03) + 1 - numElementsWidth := 1 - if hdr&0x2 != 0 { - numElementsWidth = 4 - } - - numElements, err := readUint(raw, offset+1, numElementsWidth) - if err != nil { - return nil, fmt.Errorf("could not get number of elements: %v", err) - } - firstOffsetIdx := offset + 1 + numElementsWidth // Header plus width of # of elements - lastOffsetIdx := firstOffsetIdx + int(numElements)*offsetWidth - firstDataIdx := lastOffsetIdx + offsetWidth - - // Do some bounds checking to ensure that the entire array is present in the raw buffer. - lastDataOffset, err := readUint(raw, lastOffsetIdx, offsetWidth) - if err != nil { - return nil, fmt.Errorf("could not read last offset: %v", err) - } - lastDataIdx := firstDataIdx + int(lastDataOffset) - if err := checkBounds(raw, offset, lastDataIdx); err != nil { - return nil, fmt.Errorf("array is out of bounds: %v", err) - } - - return &arrayData{ - size: lastDataIdx - offset, - numElements: int(numElements), - firstOffsetIdx: firstOffsetIdx, - firstDataIdx: firstDataIdx, - offsetWidth: offsetWidth, - }, nil -} - -// Unmarshals a Variant array into the provided destination. The destination must be a pointer to either -// a slice, or to the "any" type (which is then populated with []any{}). Any passed in slice will be -// cleared before unmarshaling. -func unmarshalArray(raw []byte, md *decodedMetadata, offset int, dest reflect.Value) error { - data, err := getArrayData(raw, offset) - if err != nil { - return err - } - - if kind := dest.Kind(); kind != reflect.Pointer { - return fmt.Errorf("invalid dest, must be non-nil pointer (got kind %s)", kind) - } - if dest.IsNil() { - return fmt.Errorf("invalid dest, must be non-nil pointer") - } - - destElem := dest.Elem() - if destElem.Kind() != reflect.Slice && destElem.Kind() != reflect.Interface { - return fmt.Errorf("invalid dest, must be a pointer to a slice (got pointer to %s)", destElem.Kind()) - } - - // Reset the slice. 
- var ret reflect.Value - if destElem.Kind() == reflect.Slice { - ret = reflect.MakeSlice(destElem.Type(), 0, data.numElements) - } else if destElem.Kind() == reflect.Interface { - ret = reflect.MakeSlice(reflect.TypeOf([]any{}), 0, data.numElements) - } - - // Iterate through all the elements in the encoded variant. - for i := range data.numElements { - elemOffset, err := readUint(raw, data.firstOffsetIdx+data.offsetWidth*i, data.offsetWidth) - if err != nil { - return err - } - dataIdx := int(elemOffset) + data.firstDataIdx - if err := checkBounds(raw, dataIdx, dataIdx); err != nil { - return err - } - - // Unmarshal the element and append to the slice to return. - newElemValue := reflect.New(ret.Type().Elem()) - if err := unmarshalCommon(raw, md, dataIdx, newElemValue); err != nil { - return err - } - ret = reflect.Append(ret, newElemValue.Elem()) - } - destElem.Set(ret) - return nil -} diff --git a/parquet/variants/array_test.go b/parquet/variants/array_test.go deleted file mode 100644 index dc63a6c2..00000000 --- a/parquet/variants/array_test.go +++ /dev/null @@ -1,423 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package variants - -import ( - "bytes" - "reflect" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestArrayPrimitives(t *testing.T) { - var buf bytes.Buffer - ab := newArrayBuilder(&buf, nil, nil) - toEncode := []any{true, 256, "hello", 10, []byte{'t', 'h', 'e', 'r', 'e'}} - for _, te := range toEncode { - if err := ab.Write(te); err != nil { - t.Fatalf("Write(%v): %v", te, err) - } - } - if err := ab.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - - wantBytes := []byte{ - 0b00011, // Basic type = array, field_offset_minus_one = 0, is_large = false - 0x05, // 5 elements in the array - 0x00, // Index of "true" - 0x01, // Index of 256 - 0x04, // Index of "hello" - 0x0A, // Index of 10 - 0x0C, // Index of []byte{t,h,e,r,e} - 0x16, // First index after last element - 0b100, // Primitive true (header & value) - 0b10000, // Primitive int16 - 0x00, - 0x01, // 256 encoded - 0b10101, // Basic short string, length = 5 - 'h', - 'e', - 'l', - 'l', - 'o', // "hello" encoded - 0b1100, // Primitive int8 - 0x0A, // 10 encoded - 0b111100, // Primitive binary - 0x05, - 0x00, - 0x00, - 0x00, // bytes of length 5 - 't', - 'h', - 'e', - 'r', - 'e', // []byte{t,h,e,r,e} encoded - } - - diffByteArrays(t, buf.Bytes(), wantBytes) - - // Decode and ensure we got what was expected. - var got []any - if err := unmarshalArray(buf.Bytes(), &decodedMetadata{}, 0, reflect.ValueOf(&got)); err != nil { - t.Fatalf("unmarshalArray(): %v", err) - } - want := []any{true, int64(256), "hello", int64(10), []byte{'t', 'h', 'e', 'r', 'e'}} - if diff := cmp.Diff(got, want); diff != "" { - t.Fatalf("Incorrect returned array. 
Diff (-got +want):\n%s", diff) - } -} - -func TestArrayLarge(t *testing.T) { - var buf bytes.Buffer - ab := newArrayBuilder(&buf, nil, nil) - // Create 256 items, which triggers "is_large" (256 cannot be encoded in one byte) - for i := range 256 { - if err := ab.Write(true); err != nil { - t.Fatalf("Write(iter = %d): %v", i, err) - } - } - if err := ab.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - wantBytes := []byte{ - 0b10111, // Basic type = array, field_offset_minus_one = 1, is_large = true - 0x00, - 0x01, - 0x00, - 0x00, // 256 encoded in 4 bytes - } - // Create the offset section - for i := range 256 { - wantBytes = append(wantBytes, []byte{byte(i), 0}...) - } - wantBytes = append(wantBytes, []byte{0, 1}...) // 256- the first index after all elements. - - // Create 256 trues - for range 256 { - wantBytes = append(wantBytes, 4) // 0x04 is basic type true - } - diffByteArrays(t, buf.Bytes(), wantBytes) -} - -func TestNestedArray(t *testing.T) { - var buf bytes.Buffer - ab := newArrayBuilder(&buf, nil, nil) - - // Create a nested array so that we get {true, 1, {false, 256}, 3} - ab.Write(true) - ab.Write(1) - nested := ab.Array() - nested.Write(false) - nested.Write(256) - nested.Build() - ab.Write(3) - - if err := ab.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - - wantBytes := []byte{ - 0b00011, // Basic type = array, field_offset_minus_one = 0, is_large = false - 0x04, // 4 elements in the array - 0x00, // Index of "true" - 0x01, // Index of 1 - 0x03, // Index of nested array - 0x0C, // Index of 3 - 0x0E, // First index after last element - 0b100, // Primitive true (header & value) - 0b1100, // Primitive int8 - 0x01, // 1 encoded - // Beginning of nested array - 0b00011, // Nested array, basic type = array, field_offset_minus_one = 0, is_large = false - 0x02, // 2 elements in the array - 0x00, // Index of "false" - 0x01, // Index of 256 - 0x04, // First index after last element - 0b1000, // Primitive false (header & value) - 0b10000, // Primitive int16 - 0x00, - 0x01, // 256 encoded - // End of nested array - 0b1100, // Primitive int8 - 0x03, // 3 encoded - } - - diffByteArrays(t, buf.Bytes(), wantBytes) -} - -func TestFromSlice(t *testing.T) { - var buf bytes.Buffer - mdb := newMetadataBuilder() - ab := newArrayBuilder(&buf, mdb, nil) - if err := ab.fromSlice([]any{1, false, 2}); err != nil { - t.Fatalf("Write(): %v", err) - } - if err := ab.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - wantBytes := []byte{ - 0x03, // Basic type = array, field offset size = 1, is_large = false, - 0x03, // 3 items in array - 0x00, // Index of 1 - 0x02, // Index of false - 0x03, // Index of 2 - 0x05, // First index after last element - 0b1100, 0x01, // Int8, value = 1 - 0b1000, // false (value and header) - 0b1100, 0x02, // Int8, value = 2 - } - diffByteArrays(t, buf.Bytes(), wantBytes) -} - -func TestNestedObject(t *testing.T) { - var buf bytes.Buffer - mdb := newMetadataBuilder() - ab := newArrayBuilder(&buf, mdb, nil) - - ab.Write(true) - ab.Write(1) - nested := ab.Object() - nested.Write("a", false) - nested.Write("b", 256) - nested.Build() - - ab.Write(3) - - if err := ab.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - - wantBytes := []byte{ - 0b00011, // Basic type = array, field_offset_minus_one = 0, is_large = false - 0x04, // 4 elements in the array - 0x00, // Index of "true" - 0x01, // Index of 1 - 0x03, // Index of nested object - 0x0E, // Index of 3 - 0x10, // First index after the last element - 0b100, // Primitive true (header & value) - 
0b1100, // Primitive int8 - 0x01, // 1 encoded - // Beginning of nested object - 0b00000010, // Basic type = object, field_offset_size & field_id_size = 1, is_large = false - 0x02, // 2 elements - 0x00, // Field ID of "a" - 0x01, // FieldID of "b" - 0x00, // Index of "a"s value (false) - 0x01, // Index of "b"s value (256) - 0x04, // First index after the last value - 0b1000, // Primitive false (header & value) - 0b10000, // Primitive int16 - 0x00, - 0x01, // 256 encoded - // End of nested object - 0b1100, // Primitive int8 - 0x03, // 3 encoded - } - - diffByteArrays(t, buf.Bytes(), wantBytes) -} - -func checkErr(t *testing.T, wantErr bool, err error) { - t.Helper() - if err != nil { - if wantErr { - return - } - t.Fatalf("Unexpected error: %v", err) - } else if wantErr { - t.Fatal("Got no error when one was expected") - } -} - -func TestGetArrayData(t *testing.T) { - cases := []struct { - name string - encoded []byte - offset int - want *arrayData - wantErr bool - }{ - { - name: "Array with offset", - encoded: []byte{ - 0x00, 0x00, // Offset bytes - 0b11, // Basic type = array, field offset size = 1, is_large = false, - 0x03, // 3 elements in the array, - 0x00, // Index of "true" - 0x01, // Index of 256 - 0x04, // Index of "hello" - 0x0A, // First index after last element - 0b100, // Primitive true - 0b10000, 0x00, 0x01, // Primitive int16 val = 256 - 0b10101, 'h', 'e', 'l', 'l', 'o', // Basic short string, length = 5 - }, - offset: 2, - want: &arrayData{ - size: 16, - numElements: 3, - firstOffsetIdx: 4, - firstDataIdx: 8, - offsetWidth: 1, - }, - }, - { - name: "Array with large widths", - encoded: []byte{ - 0b00011011, // Basic type = array, field offset size = 3, is_large = true - 0x03, 0x00, 0x00, 0x00, // 3 elements in the array - 0x00, 0x00, 0x00, // Index of "true" - 0x01, 0x00, 0x00, // Index of 256 - 0x04, 0x00, 0x00, // Index of "hello" - 0x0A, 0x00, 0x00, // First index after last element - 0b100, // Primitive true (header & value) - 0b10000, 0x00, 0x01, // Primitive int16 val = 256 - 0b10101, 'h', 'e', 'l', 'l', 'o', // Basic short string, length = 5 - }, - want: &arrayData{ - size: 27, - numElements: 3, - firstOffsetIdx: 5, - firstDataIdx: 17, - offsetWidth: 3, - }, - }, - { - name: "Not an array", - encoded: []byte{0x00, 0x00}, // Primitive nulls - wantErr: true, - }, - { - name: "Elements would be out of bounds", - encoded: []byte{0b11, 0x03, 0x00, 0x01, 0x04, 0x0A, 0b100, 0b10000, 0x00, 0x01 /* missing string */}, - wantErr: true, - }, - { - name: "Offset is out of bounds", - encoded: []byte{0b11, 0x01, 0x00, 0x01, 0b100}, // Array with one boolean - offset: 10, - wantErr: true, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - got, err := getArrayData(c.encoded, c.offset) - checkErr(t, c.wantErr, err) - if diff := cmp.Diff(got, c.want, cmp.AllowUnexported(arrayData{})); diff != "" { - t.Fatalf("Incorrect returned value. Diff (-got + want):\n%s", diff) - } - }) - } -} - -func TestUnmarshalArray(t *testing.T) { - cases := []struct { - name string - md *decodedMetadata - encoded []byte - offset int - // Normally the test will unmarhall into the type in want, but if this override is set, - // it'll try to unmarshal into something of this type. 
- overrideDecodeType reflect.Type - want any - wantErr bool - }{ - { - name: "Array with large widths and offset", - offset: 3, - encoded: []byte{ - 0x00, 0x00, 0x00, // 3 offset bytes - 0b00011011, // Basic type = array, field offset size = 3, is_large = true - 0x03, 0x00, 0x00, 0x00, // 3 elements in the array - 0x00, 0x00, 0x00, // Index of "true" - 0x01, 0x00, 0x00, // Index of 256 - 0x04, 0x00, 0x00, // Index of "hello" - 0x0A, 0x00, 0x00, // First index after last element - 0b100, // Primitive true (header & value) - 0b10000, 0x00, 0x01, // Primitive int16 val = 256 - 0b10101, 'h', 'e', 'l', 'l', 'o', // Basic short string, length = 5 - }, - want: []any{true, int64(256), "hello"}, - }, - { - name: "Unmarshal into typed array", - encoded: []byte{ - 0b011, // Basic type = array, field offset size = 1, is_large = false - 0x03, // 3 elements in the array - 0x00, 0x02, 0x04, 0x06, // Offsets for 3 integers and the index after the last element. - 0b1100, 0x01, // Primitive int8 val = 1 - 0b1100, 0x02, // Primitive int8 val = 2 - 0b1100, 0x03, // Primitive int8 val = 3 - }, - want: []int{1, 2, 3}, - }, - { - name: "Nested array", - encoded: []byte{ - 0b011, // Basic type = array, field offset size = 1, is_large = false - 0x01, // 1 element in the array - 0x00, 0x06, // Offsets for nested array and index after the last element - 0b011, // Nested array, field offset size = 1, is_large = false - 0x01, // 1 element in the array - 0x00, 0x02, // Offsets for 1 integer and index after the last element - 0b1100, 0x01, //primitive int8 val = 1 - }, - want: []any{[]any{int64(1)}}, - }, - { - name: "Invalid data", - encoded: []byte{0x00, 0x00}, - overrideDecodeType: reflect.TypeOf([]any{}), - wantErr: true, - }, - { - name: "Can't decode into primitive", - encoded: []byte{ - 0b011, // Basic type = array, field offset size = 1, is_large = false - 0x03, // 3 elements in the array - 0x00, 0x02, 0x04, 0x06, // Offsets for 3 integers and the index after the last element. - 0b1100, 0x01, // Primitive int8 val = 1 - 0b1100, 0x02, // Primitive int8 val = 2 - 0b1100, 0x03, // Primitive int8 val = 3 - }, - overrideDecodeType: reflect.TypeOf(""), - wantErr: true, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - typ := reflect.TypeOf(c.want) - if c.overrideDecodeType != nil { - typ = c.overrideDecodeType - } - got := reflect.New(typ) - if err := unmarshalArray(c.encoded, c.md, c.offset, got); err != nil { - if c.wantErr { - return - } - t.Fatalf("Unexpected error: %v", err) - } else if c.wantErr { - t.Fatalf("Got no error when one was expected") - } - diff(t, got.Elem().Interface(), c.want) - }) - } -} diff --git a/parquet/variants/builder.go b/parquet/variants/builder.go deleted file mode 100644 index de07146e..00000000 --- a/parquet/variants/builder.go +++ /dev/null @@ -1,200 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package variants - -import ( - "bytes" - "errors" - "fmt" - "io" - "reflect" -) - -// ArrayBuilder provides a mechanism to build a Variant encoded array. -type ArrayBuilder interface { - Builder - Write(val any, opts ...MarshalOpts) error - Array() ArrayBuilder - Object() ObjectBuilder -} - -// ObjectBuilder provides a mechanism to build a Variant encoded object -type ObjectBuilder interface { - Builder - Write(key string, val any, opts ...MarshalOpts) error - Array(key string) (ArrayBuilder, error) - Object(key string) (ObjectBuilder, error) -} - -// Builder provides a mechanism to build something that's Variant encoded -type Builder interface { - Build() error -} - -// Options for marshaling time types into a Variant. -type MarshalOpts int - -const ( - MarshalTimeNanos MarshalOpts = 1 << iota - MarshalTimeNTZ - MarshalAsDate - MarshalAsTime - MarshalAsTimestamp -) - -var errAlreadyBuilt = errors.New("component already built") - -// VariantBuilder is a helper to build and encode a Variant -type VariantBuilder struct { - buf bytes.Buffer - builder Builder - typ BasicType - mdb *metadataBuilder - built bool -} - -func NewBuilder() *VariantBuilder { - return &VariantBuilder{ - typ: BasicUndefined, - mdb: newMetadataBuilder(), - } -} - -// Marshals a Variant from a provided value. This will automatically convert Go primitives -// into equivalent Variant values: -// - Slice/Array: Converted into Variant Arrays, with the exception of []byte which is a Variant Binary Primitive -// - map[string]any: Converted into Variant Objects -// - Structs: Converted into Variant Objects. Keys are either the exported struct fields, or the value present -// in the `variant` field annotation. -// - Go primitives: Converted into Variant primitives -func Marshal(val any, opts ...MarshalOpts) (*MarshaledVariant, error) { - b := NewBuilder() - if err := writeCommon(val, &b.buf, b.mdb, nil); err != nil { - return nil, err - } - ev, err := b.Build() - if err != nil { - return nil, err - } - return ev, nil -} - -func (vb *VariantBuilder) check() error { - if vb.built { - return errors.New("variant has already been built") - } - if vb.typ != BasicUndefined { - return fmt.Errorf("variant type has already been started as a %q", vb.typ) - } - return nil -} - -// Callback to record the number of bytes written. -type doneCB func(int) - -// Common functionalities in writing Variant encoded data. This will be recursed into from various places. -func writeCommon(val any, buf io.Writer, mdb *metadataBuilder, doneCB doneCB, opts ...MarshalOpts) error { - typ := kindFromValue(val) - switch typ { - case BasicPrimitive: - b, err := marshalPrimitive(val, buf, opts...) - if err != nil { - return fmt.Errorf("marshalPrimitive(): %v", err) - } - if doneCB != nil { - doneCB(b) - } - case BasicObject: - // Objects can be built from structs or maps. - ob := newObjectBuilder(buf, mdb, doneCB) - if reflect.ValueOf(val).Kind() == reflect.Map { - if err := ob.fromMap(val); err != nil { - return err - } - } else { - // No need to check if this is a struct- kindFromValue() has done that already. 
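A minimal sketch, written as an in-package godoc-style example, of how Marshal is intended to be called; the record type and its tags are illustrative, and only the presence of the two output buffers is checked because metadata keys are written in insertion order rather than sorted:

package variants

import "fmt"

// ExampleMarshal sketches the intended entry point: struct fields become object
// keys, honoring the `variant` tag, and slices become Variant arrays.
func ExampleMarshal() {
	type record struct {
		Name  string `variant:"name"`
		Count int    `variant:"count"`
		Tags  []any  `variant:"tags"`
	}
	mv, err := Marshal(record{Name: "a", Count: 2, Tags: []any{true, "x"}})
	if err != nil {
		fmt.Println("marshal failed:", err)
		return
	}
	fmt.Println(len(mv.Metadata) > 0, len(mv.Value) > 0)
	// Output: true true
}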
- if err := ob.fromStruct(val); err != nil { - return err - } - } - if err := ob.Build(); err != nil { - return err - } - case BasicArray: - ab := newArrayBuilder(buf, mdb, doneCB) - if err := ab.fromSlice(val, opts...); err != nil { - return err - } - if err := ab.Build(); err != nil { - return err - } - default: - return fmt.Errorf("unknown basic type: %s", typ) - } - return nil -} - -// Sets this Variant as a primitive, and writes the provided value. -func (vb *VariantBuilder) Primitive(val any, opts ...MarshalOpts) error { - if err := vb.check(); err != nil { - return err - } - vb.typ = BasicPrimitive - _, err := marshalPrimitive(val, &vb.buf, opts...) - return err -} - -// Sets this Variant as an Object and returns an ObjectBuilder. -func (vb *VariantBuilder) Object() (ObjectBuilder, error) { - if err := vb.check(); err != nil { - return nil, err - } - ob := newObjectBuilder(&vb.buf, vb.mdb, nil) - vb.typ = BasicObject - vb.builder = ob - return ob, nil -} - -// Sets this Variant as an Array and returns an ArrayBuilder. -func (vb *VariantBuilder) Array() (ArrayBuilder, error) { - if err := vb.check(); err != nil { - return nil, err - } - ab := newArrayBuilder(&vb.buf, vb.mdb, nil) - vb.typ = BasicArray - vb.builder = ab - return ab, nil -} - -// Builds the Variant -func (vb *VariantBuilder) Build() (*MarshaledVariant, error) { - // Indicate that all building has completed to prevent any mutation. - vb.built = true - - var encoded MarshaledVariant - encoded.Metadata = vb.mdb.Build() - - // Build an object or an array if necessary - if vb.builder != nil { - if err := vb.builder.Build(); err != nil && err != errAlreadyBuilt { - return nil, err - } - } - - encoded.Value = vb.buf.Bytes() - return &encoded, nil -} diff --git a/parquet/variants/builder_test.go b/parquet/variants/builder_test.go deleted file mode 100644 index c91ab1eb..00000000 --- a/parquet/variants/builder_test.go +++ /dev/null @@ -1,307 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package variants - -import ( - "bytes" - "testing" -) - -func TestVariantMarshal(t *testing.T) { - emptyMetadata := []byte{0x01, 0x00, 0x00} - cases := []struct { - name string - val any - wantEncoded *MarshaledVariant - wantErr bool - }{ - { - name: "Primitive", - val: 123, - wantEncoded: func() *MarshaledVariant { - var buf bytes.Buffer - marshalPrimitive(123, &buf) - return &MarshaledVariant{ - Metadata: emptyMetadata, - Value: buf.Bytes(), - } - }(), - }, - { - name: "Array", - val: []any{123, "hello", false, []any{321, "olleh", true}}, - wantEncoded: func() *MarshaledVariant { - var buf bytes.Buffer - mdb := newMetadataBuilder() - ab := newArrayBuilder(&buf, mdb, nil) - ab.Write(123) - ab.Write("hello") - ab.Write(false) - - sub := ab.Array() - sub.Write(321) - sub.Write("olleh") - sub.Write(true) - sub.Build() - - ab.Build() - return &MarshaledVariant{ - Metadata: emptyMetadata, - Value: buf.Bytes(), - } - }(), - }, - { - name: "Struct", - val: struct { - FieldKey string - TagKey int `variant:"tag_key"` - Arr []int `variant:"array"` - unexported bool - }{"hello", 1, []int{1, 2, 3}, false}, - wantEncoded: func() *MarshaledVariant { - var buf bytes.Buffer - mdb := newMetadataBuilder() - ob := newObjectBuilder(&buf, mdb, nil) - ob.Write("FieldKey", "hello") - ob.Write("tag_key", 1) - - ab, err := ob.Array("array") - if err != nil { - t.Fatalf("Array(): %v", err) - } - ab.Write(1) - ab.Write(2) - ab.Write(3) - ab.Build() - - ob.Build() - return &MarshaledVariant{ - Metadata: mdb.Build(), - Value: buf.Bytes(), - } - }(), - }, - { - name: "Struct pointer", - val: &struct { - Field1 string - Field2 int - }{"hello", 123}, - wantEncoded: func() *MarshaledVariant { - var buf bytes.Buffer - mdb := newMetadataBuilder() - ob := newObjectBuilder(&buf, mdb, nil) - ob.Write("Field1", "hello") - ob.Write("Field2", 123) - ob.Build() - return &MarshaledVariant{ - Metadata: mdb.Build(), - Value: buf.Bytes(), - } - }(), - }, - { - name: "Valid map", - // Map iteration order is undefined so only use one key here to test. Rely on the tests in object.go to cover maps more fully. 
- val: map[string]int{"solitary_key": 1}, - wantEncoded: func() *MarshaledVariant { - var buf bytes.Buffer - mdb := newMetadataBuilder() - ob := newObjectBuilder(&buf, mdb, nil) - ob.Write("solitary_key", 1) - ob.Build() - return &MarshaledVariant{ - Metadata: mdb.Build(), - Value: buf.Bytes(), - } - }(), - }, - { - name: "Nil", - val: nil, - wantEncoded: &MarshaledVariant{ - Metadata: emptyMetadata, - Value: []byte{0x00}, - }, - }, - { - name: "Invalid map", - val: map[int]string{1: "hello"}, - wantErr: true, - }, - { - name: "Invalid value", - val: func() {}, - wantErr: true, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - encoded, err := Marshal(c.val) - if err != nil { - if c.wantErr { - return - } - t.Fatalf("Unexpected error: %v", err) - } else if c.wantErr { - t.Fatalf("Got no error when one was expected") - } - diff(t, encoded, c.wantEncoded) - }) - } -} - -func TestVariantBuilderPrimitive(t *testing.T) { - vb := NewBuilder() - vb.Primitive(1) - ev, err := vb.Build() - if err != nil { - t.Fatalf("Build(): %v", err) - } - - // Check metadata - md := newMetadataBuilder().Build() - diffByteArrays(t, ev.Metadata, md) - - // Check value - var buf bytes.Buffer - marshalPrimitive(1, &buf) - diffByteArrays(t, ev.Value, buf.Bytes()) -} - -func TestVariantBuilderArray(t *testing.T) { - vb := NewBuilder() - ab, err := vb.Array() - if err != nil { - t.Fatalf("Array(): %v", err) - } - - buildArray := func(ab ArrayBuilder) { - ab.Write(1) - ab.Write(true) - nested := ab.Array() - nested.Write("hello") - nested.Build() - } - - buildArray(ab) - - ev, err := vb.Build() - if err != nil { - t.Fatalf("Build(): %v", err) - } - - // Check metadata - md := newMetadataBuilder().Build() - diffByteArrays(t, ev.Metadata, md) - - // Check value - wantArray := func() []byte { - var buf bytes.Buffer - mdb := newMetadataBuilder() - ab := newArrayBuilder(&buf, mdb, nil) - buildArray(ab) - ab.Build() - return buf.Bytes() - }() - diffByteArrays(t, ev.Value, wantArray) -} - -func TestVariantBuilderObject(t *testing.T) { - vb := NewBuilder() - ob, err := vb.Object() - if err != nil { - t.Fatalf("Object(): %v", err) - } - - buildObject := func(ob ObjectBuilder) { - ob.Write("b", 1) - ob.Write("c", 2) - ob.Write("a", 3) - nested, _ := ob.Object("d") - nested.Write("a", true) - nested.Write("e", "nested") - nested.Build() - } - buildObject(ob) - ev, err := vb.Build() - if err != nil { - t.Fatalf("Build(): %v", err) - } - - wantEncoded := func() ([]byte, []byte) { - var buf bytes.Buffer - mdb := newMetadataBuilder() - ob := newObjectBuilder(&buf, mdb, nil) - buildObject(ob) - ob.Build() - md := mdb.Build() - return md, buf.Bytes() - } - - wantMetadata, wantValue := wantEncoded() - - diffByteArrays(t, ev.Metadata, wantMetadata) - diffByteArrays(t, ev.Value, wantValue) -} - -func TestCannotChangeVariantType(t *testing.T) { - vbPrim := NewBuilder() - if err := vbPrim.Primitive(1); err != nil { - t.Fatalf("Write first time: %v", err) - } - - if err := vbPrim.Primitive(2); err == nil { - t.Fatal("Primitive already started") - } - if _, err := vbPrim.Array(); err == nil { - t.Fatal("Prmitive already started") - } - if _, err := vbPrim.Object(); err == nil { - t.Fatal("Primitive already started") - } - - vbArr := NewBuilder() - if _, err := vbArr.Array(); err != nil { - t.Fatalf("Array first time: %v", err) - } - if err := vbArr.Primitive(1); err == nil { - t.Fatal("Array already started") - } - if _, err := vbArr.Array(); err == nil { - t.Fatalf("Array already started") - } - if _, err := 
vbArr.Object(); err == nil { - t.Fatalf("Array already started") - } - - vbObj := NewBuilder() - if _, err := vbObj.Object(); err != nil { - t.Fatalf("Object first time: %v", err) - } - if err := vbObj.Primitive(1); err == nil { - t.Fatal("Object alrady started") - } - if _, err := vbObj.Array(); err == nil { - t.Fatal("Object already started") - } - if _, err := vbObj.Object(); err == nil { - t.Fatal("Object already started") - } -} diff --git a/parquet/variants/decoder.go b/parquet/variants/decoder.go deleted file mode 100644 index 7cdd1c4e..00000000 --- a/parquet/variants/decoder.go +++ /dev/null @@ -1,79 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package variants - -import ( - "errors" - "fmt" - "reflect" -) - -// Decode provides a way to decode an encoded Variant into a Go type. The mapping of Variant -// types to Go types is: -// - Null: nil -// - Boolean: bool -// - Int8, Int16, Int32, Int64: int64 -// - Float: float32 -// - Double: float64 -// - Time: time.Time -// - Timestamp (all varieties): time.Time -// - String: string -// - Binary: []byte -// - UUID: uuid.UUID -// - Array: []any -// - Object: map[string]any - -// DecodeInto provides a way to decode an encoded Variant into a provided non-nil pointer dest if possible. -// For structs, this will attempt to match field names or fields annotated with the `variant` field annotation. 
-// TODO(Finish this comment) -func DecodeInto(encoded *MarshaledVariant, dest any) error { - destVal := reflect.ValueOf(dest) - if kind := destVal.Kind(); kind != reflect.Pointer { - return fmt.Errorf("dest must be a pointer (got %s)", kind) - } - if destVal.IsNil() { - return errors.New("dest pointer must not be nil") - } - - md, err := decodeMetadata(encoded.Metadata) - if err != nil { - return fmt.Errorf("could not decode metadata: %v", err) - } - - return unmarshalCommon(encoded.Value, md, 0, destVal) -} - -func unmarshalCommon(raw []byte, md *decodedMetadata, offset int, dest reflect.Value) error { - if err := checkBounds(raw, offset, offset); err != nil { - return err - } - switch bt := BasicTypeFromHeader(raw[offset]); bt { - case BasicPrimitive, BasicShortString: - if err := unmarshalPrimitive(raw, offset, dest); err != nil { - return fmt.Errorf("could not decode primitive: %v", err) - } - case BasicArray: - if err := unmarshalArray(raw, md, offset, dest); err != nil { - return fmt.Errorf("could not decode array: %v", err) - } - case BasicObject: - if err := unmarshalObject(raw, md, offset, dest); err != nil { - return fmt.Errorf("could not decode object: %v", err) - } - } - return nil -} diff --git a/parquet/variants/decoder_test.go b/parquet/variants/decoder_test.go deleted file mode 100644 index 3db54e4f..00000000 --- a/parquet/variants/decoder_test.go +++ /dev/null @@ -1,259 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package variants - -import ( - "reflect" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func mustEncodeVariant(t *testing.T, val any) *MarshaledVariant { - t.Helper() - ev, err := Marshal(val) - if err != nil { - t.Fatalf("Marshal(): %v", err) - } - return ev -} - -func TestUnmarshal(t *testing.T) { - emptyMetadata := func() []byte { - return []byte{0x01, 0x00, 0x00} // Basic, but valid, metadata with no keys. 
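A minimal in-package sketch of the decode direction. The destination map is allocated up front because, for typed map destinations, the decoder clears and fills an existing map rather than creating one (only interface destinations get a fresh map[string]any):

package variants

import "fmt"

// ExampleDecodeInto sketches a round trip: integers come back as int64, and
// objects come back as map entries when the destination is a string-keyed map.
func ExampleDecodeInto() {
	mv, err := Marshal(map[string]any{"n": 7})
	if err != nil {
		fmt.Println("marshal failed:", err)
		return
	}
	got := map[string]any{}
	if err := DecodeInto(mv, &got); err != nil {
		fmt.Println("decode failed:", err)
		return
	}
	fmt.Println(got["n"])
	// Output: 7
}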
- } - cases := []struct { - name string - encoded *MarshaledVariant - want any - wantErr bool - }{ - { - name: "Primitive", - encoded: mustEncodeVariant(t, "hello"), - want: "hello", - }, - { - name: "Array", - encoded: mustEncodeVariant(t, []any{"hello", 256, true}), - want: []any{"hello", int64(256), true}, - }, - { - name: "Object into map", - encoded: mustEncodeVariant(t, struct { - Key1 int `variant:"key1"` - Key2 []byte `variant:"key2"` - Key3 string `variant:"key3"` - }{1234, []byte{'b', 'y', 't', 'e'}, "hello"}), - want: map[string]any{ - "key1": int64(1234), - "key2": []byte{'b', 'y', 't', 'e'}, - "key3": "hello", - }, - }, - { - name: "Complex", - encoded: mustEncodeVariant(t, []any{ - 1234, struct { - Key1 string `variant:"key1"` - Arr []any `variant:"array"` - }{"hello", []any{false, true, "hello"}}, - "fin"}), - want: []any{ - int64(1234), - map[string]any{ - "key1": "hello", - "array": []any{false, true, "hello"}, - }, - "fin", - }, - }, - { - name: "Nil primitive", - encoded: mustEncodeVariant(t, nil), - want: nil, - }, - { - name: "Missing metadata", - encoded: &MarshaledVariant{ - Value: []byte{0x00}, // Primitive nil - }, - wantErr: true, - }, - { - name: "Missing Value", - encoded: &MarshaledVariant{ - Metadata: emptyMetadata(), - }, - wantErr: true, - }, - { - name: "Malformed array", - encoded: &MarshaledVariant{ - Metadata: emptyMetadata(), - Value: []byte{0x03, 0x02}, // Array, length 2, no other items. - }, - wantErr: true, - }, - { - name: "Object missing key", - encoded: func() *MarshaledVariant { - builder := NewBuilder() - ob, err := builder.Object() - if err != nil { - t.Fatalf("Object(): %v", err) - } - ob.Write("key", "value") - encoded, err := builder.Build() - if err != nil { - t.Fatalf("Build(): %v", err) - } - encoded.Metadata = emptyMetadata() - return encoded - }(), - wantErr: true, - }, - { - name: "Malformed object", - encoded: func() *MarshaledVariant { - builder := NewBuilder() - ob, err := builder.Object() - if err != nil { - t.Fatalf("Object(): %v", err) - } - ob.Write("key", "value") - encoded, err := builder.Build() - if err != nil { - t.Fatalf("Build(): %v", err) - } - encoded.Value = encoded.Value[:len(encoded.Value)-2] // Truncate - return encoded - }(), - wantErr: true, - }, - { - name: "Malformed primitive", - encoded: &MarshaledVariant{ - Metadata: emptyMetadata(), - Value: []byte{0xFD, 'a'}, // Short string, length 63 - }, - wantErr: true, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - var got any - err := DecodeInto(c.encoded, &got) - if err != nil { - if c.wantErr { - return - } - t.Fatalf("Unexpected error: %v", err) - } else if c.wantErr { - t.Fatalf("Got no error when one was expected") - } - if diff := cmp.Diff(got, c.want); diff != "" { - t.Fatalf("Incorrect data returned. 
Diff (-got, +want):\n%s", diff) - } - }) - } -} - -func TestUnmarshalWithTypes(t *testing.T) { - cases := []struct { - name string - ev *MarshaledVariant - decodeType reflect.Type - want reflect.Value - wantErr bool - }{ - { - name: "Primitive int", - ev: mustEncodeVariant(t, 1), - decodeType: reflect.TypeOf(int(0)), - want: reflect.ValueOf(int(1)), - }, - { - name: "Primitive string", - ev: mustEncodeVariant(t, "hello"), - decodeType: reflect.TypeOf(""), - want: reflect.ValueOf("hello"), - }, - { - name: "Nested array", - ev: mustEncodeVariant(t, []any{[]any{1}}), - decodeType: reflect.TypeOf([]any{}), - want: reflect.ValueOf([]any{[]any{int64(1)}}), - }, - { - name: "Complex object into map", - ev: mustEncodeVariant(t, map[string]any{ - "key1": 123, - "key2": []any{true, false, "hello", []any{1, 2, 3}}, - "key3": map[string]any{ - "key1": "foo", - "key4": "bar", - }, - }), - decodeType: reflect.TypeOf(map[string]any{}), - want: reflect.ValueOf(map[string]any{ - "key1": int64(123), - "key2": []any{true, false, "hello", []any{int64(1), int64(2), int64(3)}}, - "key3": map[string]any{ - "key1": "foo", - "key4": "bar", - }, - }), - }, - { - name: "Object to map", - ev: mustEncodeVariant(t, map[string]int{"a": 1, "b": 2}), - decodeType: reflect.TypeOf(map[string]any{}), - want: reflect.ValueOf(map[string]any{"a": int64(1), "b": int64(2)}), - }, - // TODO: add tests to decode into struct - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - decodedMD, err := decodeMetadata(c.ev.Metadata) - if err != nil { - t.Fatalf("decodeMetadata(): %v", err) - } - - var decoded reflect.Value - if c.decodeType.Kind() == reflect.Map { - // Create a pointer to the map. - underlying := reflect.MakeMap(c.decodeType) - decoded = reflect.New(c.decodeType) - decoded.Elem().Set(underlying) - } else { - decoded = reflect.New(c.decodeType) - } - if err := unmarshalCommon(c.ev.Value, decodedMD, 0, decoded); err != nil { - if c.wantErr { - return - } - t.Fatalf("Unexpected error: %v", err) - } else if c.wantErr { - t.Fatalf("Got no error when one was expected") - } - diff(t, decoded.Elem().Interface(), c.want.Interface()) - }) - } -} diff --git a/parquet/variants/doc.go b/parquet/variants/doc.go deleted file mode 100644 index 1faa329d..00000000 --- a/parquet/variants/doc.go +++ /dev/null @@ -1,38 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This package contains utilities to marshal and unmarshal data to and from the Variant -// encoding format as described in -// [the Variant encoding spec]: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md -// -// There are two main ways to create a marshaled Variant: -// -// 1. Using `variants.Marshal()`. 
Simply pass in the value you'd like to marshal, and all the -// type inference is done for you. Structs and string-keyed maps will be converted into -// objects, slices will be converted into arrays, and primitives will be encoded with the -// appropriate primitive type. This will feel like the JSON library's `Marshal()`. -// 2. Using `variants.NewBuilder()`. This allows you to build out your Variant bit by bit. -// -// To convert from a marshaled Variant back to a type, use `variants.Unmarshal()`. Like the JSON -// `Unmarshal()`, this takes in a pointer to a value to "fill up." Objects can be unmarshaled into -// either structs or string-keyed maps, arrays can be unmarshaled into slices, and primitives into -// primitives. -// -// This library does have a few shortcomings, namely in that the Metadata is always marshaled with -// unordered keys (done to make marshaling considerably easier to code up), and that currently, -// unmarshaling decodes the whole Variant, not just a specific field. - -package variants diff --git a/parquet/variants/metadata.go b/parquet/variants/metadata.go deleted file mode 100644 index 3c402b8f..00000000 --- a/parquet/variants/metadata.go +++ /dev/null @@ -1,152 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package variants - -import ( - "bytes" - "fmt" - "strings" -) - -const ( - versionMask = 0x0F - sortedMask = 0x10 - offsetMask = 0xC0 - - version = 0x01 -) - -type decodedMetadata struct { - keys []string -} - -func (d *decodedMetadata) At(i int) (string, bool) { - if i >= len(d.keys) { - return "", false - } - return d.keys[i], true -} - -func decodeMetadata(raw []byte) (*decodedMetadata, error) { - if len(raw) == 0 { - return nil, fmt.Errorf("invalid metadata") - } - - // Ensure the version is something recognizable. - if ver := raw[0] & versionMask; ver != version { - return nil, fmt.Errorf("invalid version (got %d, want %d)", ver, version) - } - - // Get the offset size. - offsetSize := int((raw[0] >> 6) + 1) - - // Get the number of elements in the dictionary. - elems, err := readUint(raw, 1, offsetSize) - if err != nil { - return nil, err - } - - var keys []string - if elems > 0 { - keys = make([]string, int(elems)) - for i := range int(elems) { - // Offset here is the first index of the offset list, which is the - // first element after the header and the size. 
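As a concrete picture of the layout this loop walks, the dictionary ["a", "bc"] occupies eight bytes: header, entry count, three offsets, then the concatenated key bytes. A small in-package sketch that spells them out and round-trips them through the decoder (mirroring the byte layout asserted in the metadata tests):

package variants

import "fmt"

// metadataLayoutSketch spells out the bytes decodeMetadata expects for the
// dictionary ["a", "bc"] and round-trips them through the decoder.
func metadataLayoutSketch() {
	raw := []byte{
		0b00000001, // header: offset_size_minus_one = 0, not sorted, version = 1
		0x02,       // two dictionary entries
		0x00,       // offset of "a"
		0x01,       // offset of "bc"
		0x03,       // first offset past the last key
		'a', 'b', 'c',
	}
	md, err := decodeMetadata(raw)
	fmt.Println(md.keys, err) // [a bc] <nil>
}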
- offset := offsetSize + 1 - raw, err := readNthItem(raw, offset, i, offsetSize, int(elems)) - if err != nil { - return nil, err - } - keys[i] = string(raw) - } - } - return &decodedMetadata{keys: keys}, err -} - -type metadataBuilder struct { - keyToIdx map[string]int - utf8Keys [][]byte - keyBytes int -} - -func newMetadataBuilder() *metadataBuilder { - return &metadataBuilder{ - keyToIdx: make(map[string]int), - } -} - -func (m *metadataBuilder) Build() []byte { - // Build the header. - hdr := byte(version) - offsetSize := m.calculateOffsetBytes() - hdr |= byte(offsetSize-1) << 6 - - mdSize := 1 + offsetSize*(len(m.utf8Keys)+1) + m.keyBytes - - buf := bytes.NewBuffer(make([]byte, 0, mdSize)) - buf.WriteByte(hdr) - - // Write the number of elements in the dictionary. - encodeNumber(int64(len(m.utf8Keys)), offsetSize, buf) - - // Write all of the offsets. - var currOffset int64 - for _, k := range m.utf8Keys { - encodeNumber(currOffset, offsetSize, buf) - currOffset += int64(len(k)) - } - encodeNumber(currOffset, offsetSize, buf) - - // Write all of the keys. - for _, k := range m.utf8Keys { - buf.Write(k) - } - - return buf.Bytes() -} - -func (m *metadataBuilder) calculateOffsetBytes() int { - maxNum := m.keyBytes + 1 - if dictLen := len(m.utf8Keys); dictLen > maxNum { - maxNum = dictLen - } - return fieldOffsetSize(int32(maxNum)) -} - -// Add adds a key to the metadata dictionary if not already present, and returns the index -// that the key is present. -func (m *metadataBuilder) Add(key string) int { - // Key already present, nothing to do. - if idx, ok := m.keyToIdx[key]; ok { - return idx - } - - // Ensure the passed in string is in UTF8 form (replacing invalid sequences with - // a replacement character), and append to the key slice. - keyBytes := []byte(strings.ToValidUTF8(key, "\uFFFD")) - idx := len(m.utf8Keys) - m.keyToIdx[key] = idx - m.utf8Keys = append(m.utf8Keys, keyBytes) - m.keyBytes += len(keyBytes) - - return idx -} - -func (m *metadataBuilder) KeyID(key string) (int, bool) { - id, ok := m.keyToIdx[key] - return id, ok -} diff --git a/parquet/variants/metadata_test.go b/parquet/variants/metadata_test.go deleted file mode 100644 index 66dac6ad..00000000 --- a/parquet/variants/metadata_test.go +++ /dev/null @@ -1,282 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package variants - -import ( - "bytes" - "math/rand" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestBuildMetadata(t *testing.T) { - cases := []struct { - name string - keys []string - wantKeys []string - wantEncoded []byte - }{ - { - name: "No keys added", - wantEncoded: []byte{ - 0b00000001, // Header: offset_size = 1, not sorted, version = 1 - 0x00, // Dictionary size of 0 - 0x00, // First (and last) index of elements in the empty dictionary - }, - }, - { - name: "Small number of items", - keys: []string{"b", "c", "a"}, - wantKeys: []string{"b", "c", "a"}, - wantEncoded: []byte{ - 0b00000001, // Header: offset_size = 1, not sorted, version = 1 - 0x03, // Dictionary size of 3 - 0x00, // Index of first item "b" - 0x01, // Index of second item "c" - 0x02, // Index of third item "a" - 0x03, // First index after the last item - 'b', - 'c', - 'a', - }, - }, - { - name: "Dedupe similar keys", - keys: []string{"b", "c", "a", "a", "a", "a", "b", "b", "c", "c", "c"}, - wantKeys: []string{"b", "c", "a"}, - wantEncoded: []byte{ - 0b00000001, // Header: offset_size = 1, not sorted, version = 1 - 0x03, // Dictionary size of 3 - 0x00, // Index of first item "b" - 0x01, // Index of second item "c" - 0x02, // Index of third item "a" - 0x03, // First index after the last item - 'b', - 'c', - 'a', - }, - }, - { - name: "Large number of keys (encoded in more than one byte)", - keys: func() []string { - keys := make([]string, 26*26) - idx := 0 - for i := range 26 { - for j := range 26 { - keys[idx] = string([]byte{byte('a' + i), byte('a' + j)}) - idx++ - } - } - return keys - }(), - wantKeys: largeKeysString(), - wantEncoded: largeEncodedMetadata(), - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - mdb := newMetadataBuilder() - for _, k := range c.keys { - mdb.Add(k) - } - md := mdb.Build() - - diffByteArrays(t, md, c.wantEncoded) - - // Decode and check keys. This does a bit of double duty and a lot of these cases are - // covered in TestDecodeMetadata below, but it is useful to prove that anything that - // can be encoded can also be decoded. - decoded, err := decodeMetadata(md) - if err != nil { - t.Fatalf("decodeMetadata(): %v", err) - } - if diff := cmp.Diff(decoded.keys, c.wantKeys); diff != "" { - t.Fatalf("Received incorrect keys. Diff (-want +got):\n%s", diff) - } - }) - } -} - -// Helpers to create a Metadata struct that has a large list of keys (26^2) that will -// take more than one byte to encode. -func largeListKeyBytes() [][]byte { - keys := make([][]byte, 26*26) - idx := 0 - for i := range 26 { - for j := range 26 { - keys[idx] = []byte{byte('a' + i), byte('a' + j)} - idx++ - } - } - return keys -} - -func largeEncodedMetadata() []byte { - // Offset size = 2 - // Total size of encoded metadata is: - // * Header: 1 - // * Number of elements: 2 - // * Offset table: (26*26 + 1)*2 - // * Elements: 26*26*2 - buf := bytes.NewBuffer(make([]byte, 0, 1+2+(26*26+1)*2+(26*26*2))) - buf.WriteByte(0b01000001) // offset_size_minus_one = 1, is_sorted = false, version = 1 - - encodeNumber(676, 2, buf) // Encode the number of elements - - // Encode the offsets. NB: each key is 2 bytes. 
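fieldOffsetSize itself is defined elsewhere in the package; judging by the expectations in these tests, it returns the minimum number of bytes needed to represent its argument, which is why the 26*26 two-byte keys above (1352 key bytes, so a maximum encoded value of 1353) still fit in 2-byte offsets. A standalone sketch of that calculation, offered as an assumption consistent with the test cases rather than as the package's actual implementation:

package main

import "fmt"

// minOffsetBytes is an assumed equivalent of the package's fieldOffsetSize:
// the fewest bytes that can hold max as an unsigned little-endian integer.
func minOffsetBytes(max int32) int {
	switch {
	case max <= 0xFF:
		return 1
	case max <= 0xFFFF:
		return 2
	case max <= 0xFFFFFF:
		return 3
	default:
		return 4
	}
}

func main() {
	fmt.Println(minOffsetBytes(255))     // 1
	fmt.Println(minOffsetBytes(1353))    // 2 (the 26*26 two-byte-key case above)
	fmt.Println(minOffsetBytes(1<<16+1)) // 3
	fmt.Println(minOffsetBytes(1<<24+1)) // 4
}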
- for i := range 676 + 1 { - encodeNumber(int64(i*2), 2, buf) - } - - for _, k := range largeListKeyBytes() { - buf.Write(k) - } - - return buf.Bytes() -} - -func largeKeysString() []string { - rawKeys := largeListKeyBytes() - keys := make([]string, len(rawKeys)) - for i, k := range rawKeys { - keys[i] = string(k) - } - return keys -} - -// This test does duplicate some coverage provided in TestBuildMetadata, but is specifically -// focused on the decoding side of things. -func TestDecodeMetadata(t *testing.T) { - cases := []struct { - name string - raw []byte - want []string - wantErr bool - }{ - { - name: "Valid metadata with no elements", - raw: []byte{ - 0x01, // Base header, version = 1, - 0x00, // Zero length - 0x00, // First and last element - }, - }, - { - name: "Valid metadata, large number of elements", - raw: largeEncodedMetadata(), - want: largeKeysString(), - }, - { - name: "Zero length metadata", - wantErr: true, - }, - { - name: "Invalid version: 0", - raw: []byte{0x00, 0x00}, - wantErr: true, - }, - { - name: "Invalid version: 2", - raw: []byte{0x02, 0x00}, - wantErr: true, - }, - { - name: "Bad number of elements", - raw: []byte{0b11000001, 0x00}, // Offset size = 4, should be out of bounds read - wantErr: true, - }, - { - name: "Missing elements", - raw: []byte{0x01, 0x02, 0x00, 0x01, 0x02}, - wantErr: true, - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - got, err := decodeMetadata(c.raw) - if err != nil { - if c.wantErr { - return - } - t.Fatalf("Unexpected error: %v", err) - } else if c.wantErr { - t.Fatal("Got no error when one was expected") - } - if diff := cmp.Diff(got.keys, c.want); diff != "" { - t.Fatalf("Received incorrect keys. Diff (-want +got):\n%s", diff) - } - }) - } -} - -func buildRandomKey(l int) string { - buf := make([]byte, l) - // Create a random ascii character between decimal 33 (!) and decimal 126 (~) - for i := range l { - randChar := byte(rand.Intn(94)) + 33 - buf[i] = randChar - } - return string(buf) -} - -func TestOffsetCalculation(t *testing.T) { - cases := []struct { - name string - keyLen int - numKeys int - wantHdr byte - }{ - { - name: "Offset length 1", - keyLen: 1, - numKeys: 1, - wantHdr: 0b00000001, - }, - { - name: "Offset length 2", - keyLen: 4, - numKeys: 256, - wantHdr: 0b01000001, - }, - { - name: "Offset length 3", - keyLen: 1<<16 + 1, - numKeys: 1, - wantHdr: 0b10000001, - }, - { - name: "Offset length 4", - keyLen: 1<<24 + 1, - numKeys: 1, - wantHdr: 0b11000001, - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - mdb := newMetadataBuilder() - for range c.numKeys { - mdb.Add(buildRandomKey(c.keyLen)) - } - md := mdb.Build() - if got, want := md[0], c.wantHdr; got != want { - t.Fatalf("Incorrect header: got %x, want %x", got, want) - } - }) - } -} diff --git a/parquet/variants/object.go b/parquet/variants/object.go deleted file mode 100644 index 5d00b492..00000000 --- a/parquet/variants/object.go +++ /dev/null @@ -1,456 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package variants - -import ( - "bytes" - "errors" - "fmt" - "io" - "reflect" - "sort" - "strings" -) - -// Container to keep track of the metadata for a given element in an object. This is mainly -// used so that it's possible to sort by key during build time (as required by spec) and -// not need to go doing a bunch of lookups to find field IDs and offsets. -type objectKey struct { - fieldID int - offset int - key string -} - -type objectBuilder struct { - w io.Writer - buf bytes.Buffer - objKeys []objectKey - fieldIDs map[int]struct{} - maxFieldID int - numItems int - offsets []int - nextOffset int - doneCB doneCB - mdb *metadataBuilder - built bool -} - -var _ ObjectBuilder = (*objectBuilder)(nil) - -func newObjectBuilder(w io.Writer, mdb *metadataBuilder, doneCB doneCB) *objectBuilder { - return &objectBuilder{ - w: w, - mdb: mdb, - doneCB: doneCB, - fieldIDs: make(map[int]struct{}), - } -} - -// Write marshals the provided value into the appropriate Variant type and adds it to this object with -// the provided key. Keys must be unique per object (though nested objects may share the same key). -func (o *objectBuilder) Write(key string, val any, opts ...MarshalOpts) error { - if err := o.checkKey(key); err != nil { - return err - } - return writeCommon(val, &o.buf, o.mdb, func(size int) { - o.record(key, size) - }) -} - -// Extracts field info from a struct field, namely the key name and any options associated with the field. -// If the Variant key name is not present in the `variant` annotation, the struct's field name will be -// used as the key. -// -// This function assumes the field is exported. -func extractFieldInfo(field reflect.StructField) (string, []MarshalOpts) { - var opts []MarshalOpts - - tag, ok := field.Tag.Lookup("variant") - if !ok || tag == "" { - return field.Name, nil - } - - // Tag is of the form "key_name,comma,separated,flags" - parts := strings.Split(tag, ",") - if len(parts) == 1 { - return tag, nil - } - - keyName := parts[0] - if keyName == "" { - keyName = field.Name - } - - for _, optStr := range parts[1:] { - switch strings.ToLower(optStr) { - case "nanos": - opts = append(opts, MarshalTimeNanos) - case "ntz": - opts = append(opts, MarshalTimeNTZ) - case "date": - opts = append(opts, MarshalAsDate) - case "time": - opts = append(opts, MarshalAsTime) - case "timestamp": - opts = append(opts, MarshalAsTimestamp) - } - } - - return keyName, opts -} - -// Creates an object from a struct. Key names are determined by either the struct's field name, or -// by a value in a `variant` field annotation (with the annotation taking precedence). -func (o *objectBuilder) fromStruct(st any) error { - stVal := reflect.ValueOf(st) - - // Get the underlying struct if this is a pointer to one. - if stVal.Kind() == reflect.Pointer { - stVal = stVal.Elem() - } - if stVal.Kind() != reflect.Struct { - return fmt.Errorf("not a struct: %s", stVal.Kind()) - } - typ := stVal.Type() - - for i := range typ.NumField() { - field := typ.Field(i) - - // Skip unexported fields - if !field.IsExported() { - continue - } - - // Get the tag. If not present, use the field name. 
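The accepted tag grammar is `variant:"key_name,flag,flag,..."`: a missing or empty key name falls back to the struct field name, and the flags map onto the MarshalOpts constants. A short in-package sketch (tagSketch is illustrative, not part of the package) of what extractFieldInfo yields for a few representative tags:

package variants

import (
	"fmt"
	"reflect"
)

// tagSketch shows the key names and option counts extractFieldInfo derives
// from `variant` struct tags.
func tagSketch() {
	type s struct {
		A string `variant:"renamed"`
		B string                    // no tag: the field name is used
		C string `variant:",ntz"`   // empty name: field name, plus the NTZ flag
	}
	t := reflect.TypeOf(s{})
	for i := 0; i < t.NumField(); i++ {
		key, opts := extractFieldInfo(t.Field(i))
		fmt.Println(key, len(opts))
	}
	// Prints: renamed 0, B 0, C 1
}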
- key, opts := extractFieldInfo(field) - if err := o.Write(key, stVal.Field(i).Interface(), opts...); err != nil { - return err - } - } - return nil -} - -// Creates an object from a string-keyed map. -func (o *objectBuilder) fromMap(m any) error { - mapVal := reflect.ValueOf(m) - - // Make sure this is a map with string keys. - if mapVal.Kind() != reflect.Map { - return fmt.Errorf("not a map: %v", mapVal.Kind()) - } - if keyKind := mapVal.Type().Key().Kind(); keyKind != reflect.String { - return fmt.Errorf("map does not have string keys: %v", keyKind) - } - - for _, keyVal := range mapVal.MapKeys() { - valVal := mapVal.MapIndex(keyVal) - if err := o.Write(keyVal.Interface().(string), valVal.Interface()); err != nil { - return err - } - } - return nil -} - -// Keys within a given object must be unique -func (o *objectBuilder) checkKey(key string) error { - fieldID, ok := o.mdb.KeyID(key) - if ok { - if _, ok := o.fieldIDs[fieldID]; ok { - return fmt.Errorf("mutiple insertion of key %q in object", key) - } - } - return nil -} - -// Array returns a new ArrayBuilder associated with this object. The marshaled array will -// not be part of the object until the returned ArrayBuilder's Build method is called. -func (o *objectBuilder) Array(key string) (ArrayBuilder, error) { - if err := o.checkKey(key); err != nil { - return nil, err - } - ab := newArrayBuilder(&o.buf, o.mdb, func(size int) { - o.record(key, size) - }) - return ab, nil -} - -// Object returns a new ObjectBuilder associated with this object. The marshaled object will -// not be part of this object until the returned ObjectBuilder's Build method is called. -// -// NB. A nested object can contain a key that also exists in the parent object. -func (o *objectBuilder) Object(key string) (ObjectBuilder, error) { - if err := o.checkKey(key); err != nil { - return nil, err - } - ob := newObjectBuilder(&o.buf, o.mdb, func(size int) { - o.record(key, size) - }) - return ob, nil -} - -// Bookkeeping to record information about an elements key, offset, field ID, and -// to keep a running track of the number of items and the max field ID seen. -func (o *objectBuilder) record(key string, size int) { - o.numItems++ - currOffset := o.nextOffset - o.offsets = append(o.offsets, currOffset) - o.nextOffset += size - - fieldID := o.mdb.Add(key) - o.objKeys = append(o.objKeys, objectKey{ - fieldID: fieldID, - offset: currOffset, - key: key, - }) - o.fieldIDs[fieldID] = struct{}{} - - if fieldID > o.maxFieldID { - o.maxFieldID = fieldID - } -} - -// Build writes the marshaled object to the builders io.Writer. This prepends serialized -// metadata about the object (ie. its header, number of elements, sorted field IDs, and -// offsets) to the running data buffer. -func (o *objectBuilder) Build() error { - if o.built { - return errAlreadyBuilt - } - numItemsSize := 1 - if isLarge(o.numItems) { - numItemsSize = 4 - } - offsetSize := fieldOffsetSize(int32(o.nextOffset)) - fieldIDSize := fieldOffsetSize(int32(o.maxFieldID)) - - // Sort the object keys as per the spec. 
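The sort matters because the field-ID and offset tables that follow are written in lexicographic key order, while the element data itself stays in insertion order. For a two-field object written as "b" then "a" with one-byte widths, that yields the bytes sketched below (an in-package sketch; the expected bytes follow from the encoding above and the int8 primitive encoding asserted in this package's tests):

package variants

import (
	"bytes"
	"fmt"
)

// objectLayoutSketch writes {"b": 1, "a": 2} and prints the serialized bytes.
// Expected with 1-byte widths: header 0x02, count 0x02, field IDs in key order
// (a=1, b=0), offsets in key order (a's value at 2, b's at 0, end 4), then b's
// data followed by a's data.
func objectLayoutSketch() {
	var buf bytes.Buffer
	mdb := newMetadataBuilder()
	ob := newObjectBuilder(&buf, mdb, nil)
	_ = ob.Write("b", 1)
	_ = ob.Write("a", 2)
	_ = ob.Build()
	fmt.Printf("% x\n", buf.Bytes())
	// 02 02 01 00 02 00 04 0c 01 0c 02
}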
- sort.Slice(o.objKeys, func(i, j int) bool { - return strings.Compare(o.objKeys[i].key, o.objKeys[j].key) < 0 - }) - - // Preallocate a buffer for the header, number of items, field IDs, and field offsets - serializedFieldIDSize := fieldIDSize * o.numItems - serializedFieldOffsetSize := offsetSize * (o.numItems + 1) - - serializedDataBuf := bytes.NewBuffer(make([]byte, 0, 1+numItemsSize+serializedFieldIDSize+serializedFieldOffsetSize)) - serializedDataBuf.WriteByte(o.header(fieldIDSize, offsetSize)) - - encodeNumber(int64(o.numItems), numItemsSize, serializedDataBuf) - for _, k := range o.objKeys { - encodeNumber(int64(k.fieldID), fieldIDSize, serializedDataBuf) - } - for _, k := range o.objKeys { - encodeNumber(int64(k.offset), offsetSize, serializedDataBuf) - } - encodeNumber(int64(o.nextOffset), offsetSize, serializedDataBuf) - - // Write everything to the writer. - hdrSize, _ := o.w.Write(serializedDataBuf.Bytes()) - dataSize, _ := o.w.Write(o.buf.Bytes()) - - totalSize := hdrSize + dataSize - if o.doneCB != nil { - o.doneCB(totalSize) - } - return nil -} - -func (o *objectBuilder) header(fieldIDSize, fieldOffsetSize int) byte { - // Header is one byte: ABCCDDEE - // * A: Unused - // * B: Is Large: whether there are more than 255 elements in this object or not. - // * C: Field ID Size Minus One: the number of bytes (minus one) used to encode each Field ID - // * D: Field Offset Size Minus One: the number of bytes (minus one) used to encode each Field Offset - // * E: 0x02: the identifier of the Object basic type - hdr := byte(fieldOffsetSize - 1) - hdr |= byte((fieldIDSize - 1) << 2) - if isLarge(o.numItems) { - hdr |= byte(1 << 4) - } - - // Basic type is the lower two bits of the header. Shift the Object specific bits over 2. - hdr <<= 2 - hdr |= byte(BasicObject) - return hdr -} - -type objectData struct { - size int - numElements int - firstFieldIDIdx int - firstOffsetIdx int - firstDataIdx int - fieldIDWidth int - offsetWidth int -} - -// Parses object data from a marshaled object (where the different encoded sections start, -// plus size in bytes and number of elements), plus ensures that the entire object is present -// in the raw buffer. -func getObjectData(raw []byte, offset int) (*objectData, error) { - if err := checkBounds(raw, offset, offset); err != nil { - return nil, err - } - - hdr := raw[offset] - if bt := BasicTypeFromHeader(hdr); bt != BasicObject { - return nil, fmt.Errorf("not an object: %s", bt) - } - - // Get the size of all encoded metadata fields. Bitshift by two to expose the 5 raw value header bits. - hdr >>= 2 - - offsetWidth := int(hdr&0x03) + 1 - fieldIDWidth := int((hdr>>2)&0x03) + 1 - - numElementsWidth := 1 - if hdr&0x08 != 0 { - numElementsWidth = 4 - } - - numElements, err := readUint(raw, offset+1, numElementsWidth) - if err != nil { - return nil, fmt.Errorf("could not get number of elements: %v", err) - } - - firstFieldIDIdx := offset + 1 + numElementsWidth // Header plus width of # of elements - firstOffsetIdx := firstFieldIDIdx + int(numElements)*fieldIDWidth - firstDataIdx := firstOffsetIdx + int(numElements+1)*offsetWidth - lastDataOffset, err := readUint(raw, firstDataIdx-offsetWidth, offsetWidth) - if err != nil { - return nil, fmt.Errorf("could not read last offset: %v", err) - } - lastDataIdx := firstDataIdx + int(lastDataOffset) - - // Also do some bounds checking to ensure that the entire object is represented in the raw buffer. 
- if err := checkBounds(raw, offset, int(lastDataIdx)); err != nil { - return nil, fmt.Errorf("object is out of bounds: %v", err) - } - return &objectData{ - size: lastDataIdx - offset, - numElements: int(numElements), - firstFieldIDIdx: firstFieldIDIdx, - firstOffsetIdx: firstOffsetIdx, - firstDataIdx: firstDataIdx, - fieldIDWidth: fieldIDWidth, - offsetWidth: offsetWidth, - }, nil -} - -// Unmarshals a Variant object into the provided destination. The destination must be a pointer -// to one of three types: -// - A struct (unmarshal will map the Variant fields to struct fields by name, or contents of the `variant` annotation) -// - A string-keyed map (the passed in map will be cleared) -// - The "any" type. This will be returned as a map[string]any -func unmarshalObject(raw []byte, md *decodedMetadata, offset int, destPtr reflect.Value) error { - data, err := getObjectData(raw, offset) - if err != nil { - return err - } - - if kind := destPtr.Kind(); kind != reflect.Pointer { - return fmt.Errorf("invalid dest, must be non-nil pointer (got kind %s)", kind) - } - if destPtr.IsNil() { - return errors.New("invalid dest, must be non-nil pointer") - } - - destElem := destPtr.Elem() - - switch kind := destElem.Kind(); kind { - case reflect.Struct: - // Nothing to do. - case reflect.Interface: - // Create a new map[string]any - newMap := reflect.MakeMap(reflect.TypeOf(map[string]any{})) - destElem.Set(newMap) - destElem = newMap - case reflect.Map: - if keyKind := destElem.Type().Key().Kind(); keyKind != reflect.String { - return fmt.Errorf("invalid dest map- must have a string for a key, got %s", keyKind) - } - // Clear out the map to start fresh. - destElem.Clear() - default: - return fmt.Errorf("invalid kind- must be a string-keyed map, struct, or any, got %s", kind) - } - - destType := destElem.Type() - - // For slightly faster lookups, preprocess the struct to get a mapping from field name to field ID. - // We only care about settable struct fields. - fieldIDMap := make(map[string]int) - if destElem.Kind() == reflect.Struct { - for i := range destType.NumField() { - structField := destElem.Field(i) - if structField.CanSet() { - fieldName, _ := extractFieldInfo(destType.Field(i)) - fieldIDMap[fieldName] = i - - // Zero out the field if possible to start fresh. - structField.Set(reflect.Zero(structField.Type())) - } - } - } - - // Iterate through all elements in the encoded Variant. - for i := range data.numElements { - variantFieldID, err := readUint(raw, data.firstFieldIDIdx+data.fieldIDWidth*i, data.fieldIDWidth) - if err != nil { - return err - } - variantKey, ok := md.At(int(variantFieldID)) - if !ok { - return fmt.Errorf("key ID %d not present in metadata dictionary", i) - } - - // Get the new element value depending on whether this is a struct or a map - var newElemValue reflect.Value - if destElem.Kind() == reflect.Struct { - // Get pointer to the field within the struct. - structFieldID, ok := fieldIDMap[variantKey] - if !ok { - continue - } - field := destElem.Field(structFieldID) - newElemValue = field.Addr() - } else { - // New element within the map. - newElemValue = reflect.New(destType.Elem()) - } - - // Set the element value based on what's encoded in the Variant. 
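// Editor's sketch, not part of the patch: destinations that satisfy the rules
// described above. The struct, its field names and the keys are invented; the
// `variant` tag behaviour is the one implemented by extractFieldInfo.
type exampleDest struct {
	Name  string `variant:"name"` // matched through the tag
	Count int    // no tag: matched by the field name, i.e. the key "Count"
}

// Any of the following is a valid target for the same encoded object:
//
//	var s exampleDest     // struct:            unmarshalObject(raw, md, 0, reflect.ValueOf(&s))
//	m := map[string]any{} // string-keyed map:  unmarshalObject(raw, md, 0, reflect.ValueOf(&m))
//	var v any             // any:               decoded as a map[string]any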
- elemOffset, err := readUint(raw, data.firstOffsetIdx+data.offsetWidth*i, data.offsetWidth) - if err != nil { - return err - } - dataIdx := int(elemOffset) + data.firstDataIdx - if err := checkBounds(raw, dataIdx, dataIdx); err != nil { - return err - } - if err := unmarshalCommon(raw, md, dataIdx, newElemValue); err != nil { - return err - } - - // Structs already have a pointer to the value and are set. For maps, set the value here. - if destElem.Kind() == reflect.Map { - destElem.SetMapIndex(reflect.ValueOf(variantKey), newElemValue.Elem()) - } - } - - return nil -} diff --git a/parquet/variants/object_test.go b/parquet/variants/object_test.go deleted file mode 100644 index 7c7423ae..00000000 --- a/parquet/variants/object_test.go +++ /dev/null @@ -1,711 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package variants - -import ( - "bytes" - "reflect" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestObjectFromStruct(t *testing.T) { - var buf bytes.Buffer - mdb := newMetadataBuilder() - ob := newObjectBuilder(&buf, mdb, nil) - - type nested struct { - Baz string `variant:"d"` - } - - st := struct { - Foo int `variant:"b"` - Bar bool `variant:"a"` - unexported int `variant:"c"` - Nest nested - }{1, true, 2, nested{Baz: "hi"}} - - if err := ob.fromStruct(&st); err != nil { - t.Fatalf("fromStruct(): %v", err) - } - if err := ob.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - - wantBytes := []byte{ - 0x02, // Basic object, all sizes = 1 byte - 0x03, // 3 items - // Keys are inserted in the order "b", "a", "d", "Nest". "Nest" comes first - // since N (value 78) < a (value 97). - 0x03, // Nested's index - 0x01, // a's index - 0x00, // b's index - 0x03, // Nest's value - 0x02, // a's value - 0x00, // b's value - 0x0B, // end - 0b1100, 0x01, // b's value, int8 val = 1 - 0b0100, // a's value - // Nested object - 0x02, // Basic object, all sizes = 1 byte - 0x01, // 1 item - 0x02, // d's index in the dictionary - 0x00, // d's offset - 0x03, // last item - 0b1001, 'h', 'i', // Basic short string, length 2 - // End of nested object - } - got := buf.Bytes() - - diffByteArrays(t, got, wantBytes) - - // Check the metadata keys as well to ensure the struct tags were picked up approrpiately. - encodedMD := mdb.Build() - decodedMetadata, err := decodeMetadata(encodedMD) - if err != nil { - t.Fatalf("Metadata decode error: %v", err) - } - gotKeys, wantKeys := decodedMetadata.keys, []string{"b", "a", "d", "Nest"} - if diff := cmp.Diff(gotKeys, wantKeys); diff != "" { - t.Fatalf("Incorrect metadata keys. 
Diff (-got +want):\n%s", diff) - } -} - -func TestObjectFromMap(t *testing.T) { - var buf bytes.Buffer - mdb := newMetadataBuilder() - ob := newObjectBuilder(&buf, mdb, nil) - - toWrite := map[string]any{ - "key1": int64(1), - "key2": false, - "key3": int64(2), - } - if err := ob.fromMap(toWrite); err != nil { - t.Fatalf("fromMap(): %v", err) - } - if err := ob.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - decodedMD, err := decodeMetadata(mdb.Build()) - if err != nil { - t.Fatalf("decodeMetadata(): %v", err) - } - - // Decode into a map and ensure things are correct. Can't really compare bytes here since the - // iteration order of a map is undefined. - got := map[string]any{} - dest := reflect.ValueOf(&got) - if err := unmarshalObject(buf.Bytes(), decodedMD, 0, dest); err != nil { - t.Fatalf("unmarshalObject(): %v", err) - } - diff(t, got, toWrite) -} - -func TestObjectPrimitive(t *testing.T) { - var buf bytes.Buffer - mdb := newMetadataBuilder() - - ob := newObjectBuilder(&buf, mdb, nil) - if err := ob.Write("b", 3); err != nil { - t.Fatalf("Write(b): %v", err) - } - if err := ob.Write("a", 1); err != nil { - t.Fatalf("Write(a): %v", err) - } - if err := ob.Write("c", 2); err != nil { - t.Fatalf("Write(c): %v", err) - } - - if err := ob.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - - wantBytes := []byte{ - 0b00000010, // Basic type = object, field_offset_size & field_id_size = 1, is_large = false - 0x03, // 3 elements - 0x01, // Field ID of "a" - 0x00, // Field ID of "b" - 0x02, // Field ID of "c" - 0x02, // Index of "a"s value (1) - 0x00, // Index of "b"s value (3) - 0x04, // Index of "c"s value (2) - 0x06, // First index after elements - 0b1100, // "b"s header- basic Int8 - 0x03, // "b"'s value of 3 - 0b1100, // "a"s header- basic Int8 - 0x01, // "a"s value of 1 - 0b1100, // "c"s header- basic Int8 - 0x02, // "c"s value of 2 - } - - diff(t, buf.Bytes(), wantBytes) - - // Check the metadata keys as well. 
- encodedMD := mdb.Build() - decodedMetadata, err := decodeMetadata(encodedMD) - if err != nil { - t.Fatalf("Metadata decode error: %v", err) - } - gotKeys, wantKeys := decodedMetadata.keys, []string{"b", "a", "c"} - diff(t, gotKeys, wantKeys) -} - -func TestObjectNestedArray(t *testing.T) { - var buf bytes.Buffer - mdb := newMetadataBuilder() - ob := newObjectBuilder(&buf, mdb, nil) - if err := ob.Write("b", 3); err != nil { - t.Fatalf("Write(b): %v", err) - } - - arr, err := ob.Array("a") - if err != nil { - t.Fatalf("Array(a): %v", err) - } - for _, val := range []any{true, 123} { - if err := arr.Write(val); err != nil { - t.Fatalf("arr.Write(%v): %v", val, err) - } - } - arr.Build() - ob.Write("c", 8) - - if err := ob.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - - wantBytes := []byte{ - 0b00000010, // Basic type = object, field_offset_size & field_id_size = 1, is_large = false - 0x03, // 3 elements - 0x01, // Field ID of "a" - 0x00, // Field ID of "b" - 0x02, // Field ID of "c" - 0x02, // Index of "a"s value (array {true, 123}) - 0x00, // Index of "b"s value (3) - 0x0A, // Index of "c"s value (8) - 0x0C, // First index after the last value - 0b1100, // "b"s header- basic Int8 - 0x03, // "b"s value of 3 - // Beginning of array - 0b011, // "a"s header- basic array - 0x02, // 2 elements in "a"s array - 0x00, // Index of first element (true) - 0x01, // Index of second element (123) - 0x03, // First index after the last value - 0b100, // First element (basic true- header and value) - 0b1100, // Second element- basic Int8 - 0x7B, // 123 encoded - // End of array - 0b1100, // "c"s header- basic Int8 - 0x08, // "c"s value of 8 - } - diffByteArrays(t, buf.Bytes(), wantBytes) - - // Check the metadata keys as well. - encodedMD := mdb.Build() - decodedMetadata, err := decodeMetadata(encodedMD) - if err != nil { - t.Fatalf("Metadata decode error: %v", err) - } - gotKeys, wantKeys := decodedMetadata.keys, []string{"b", "a", "c"} - diff(t, gotKeys, wantKeys) -} - -func TestObjectNestedObjectSharingKeys(t *testing.T) { - var buf bytes.Buffer - mdb := newMetadataBuilder() - ob := newObjectBuilder(&buf, mdb, nil) - - if err := ob.Write("a", true); err != nil { - t.Fatalf("Write(a): %v", err) - } - - nestedOb, err := ob.Object("b") - if err != nil { - t.Fatalf("Object(b): %v", err) - } - - // Same key can exist in a nested object - if err := nestedOb.Write("b", 123); err != nil { - t.Fatalf("Nested Object Write(b): %v", err) - } - nestedOb.Build() - - if err := ob.Build(); err != nil { - t.Fatalf("Object Build(): %v", err) - } - - wantBytes := []byte{ - 0b10, // Basic type = object, field_offset_size & field_id_size = 1, is_large = false - 0x02, // 2 elements - 0x00, // Field ID of "a", - 0x01, // Field ID of "b", - 0x00, // Index of first element (true) - 0x01, // Index of second element (nested object) - 0x08, // First index after the last value - 0b100, // "a"s header & value (basic true) - // Beginning of nested object - 0b10, // "b"s header- basic object, field_offset_size & field_id_size = 1, is_large = false - 0x01, // 1 element - 0x01, // Field ID of "b" - 0x00, // Index of "b"s value (123) - 0x02, // First index after the last value - 0b1100, // "b"s header- basic Int8 - 0x7B, // 123 encoded - // End of nested object - } - diffByteArrays(t, buf.Bytes(), wantBytes) - - // Check the metadata keys as well. 
- encodedMD := mdb.Build() - decodedMetadata, err := decodeMetadata(encodedMD) - if err != nil { - t.Fatalf("Metadata decode error: %v", err) - } - gotKeys, wantKeys := decodedMetadata.keys, []string{"a", "b"} - diff(t, gotKeys, wantKeys) -} - -func TestObjectData(t *testing.T) { - cases := []struct { - name string - encoded []byte - offset int - want *objectData - wantErr bool - }{ - { - name: "Basic object no offset", - encoded: []byte{ - 0b00000010, // Basic type = object, field_offset_size & field_id_size = 1, is_large = false - 0x03, // 3 elements - 0x01, // Field ID of "a" - 0x00, // Field ID of "b" - 0x02, // Field ID of "c" - 0x02, // Index of "a"s value (1) - 0x00, // Index of "b"s value (3) - 0x04, // Index of "c"s value (2) - 0x06, // First index after elements - 0b1100, // "b"s header- basic Int8 - 0x03, // "b"'s value of 3 - 0b1100, // "a"s header- basic Int8 - 0x01, // "a"s value of 1 - 0b1100, // "c"s header- basic Int8 - 0x02, // "c"s value of 2 - }, - want: &objectData{ - size: 15, - numElements: 3, - firstFieldIDIdx: 2, - firstOffsetIdx: 5, - firstDataIdx: 9, - fieldIDWidth: 1, - offsetWidth: 1, - }, - }, - { - name: "Basic object with offset", - encoded: []byte{ - 0x00, 0b00000010, 0x03, 0x01, 0x00, 0x02, 0x02, 0x00, - 0x04, 0x06, 0b1100, 0x03, 0b1100, 0x01, 0b1100, 0x02, - }, - offset: 1, - want: &objectData{ - size: 15, - numElements: 3, - firstFieldIDIdx: 3, - firstOffsetIdx: 6, - firstDataIdx: 10, - fieldIDWidth: 1, - offsetWidth: 1, - }, - }, - { - name: "Object with larger widths", - encoded: []byte{ - 0b01100110, // Basic type = object, field offset size = 2, field ID size = 3, is_large = true, - 0x03, 0x00, 0x00, 0x00, // 3 elements (encoded in 4 bytes due to is_large) - 0x01, 0x00, 0x00, // Field ID of "a" - 0x00, 0x00, 0x00, // Field ID of "b" - 0x02, 0x00, 0x00, // Field ID of "c" - 0x02, 0x00, // Index of "a"s value (1) - 0x00, 0x00, // Index of "b"s value (3) - 0x04, 0x00, // Index of "c"s value (2) - 0x06, 0x00, // First index after elements - 0b1100, 0x03, // Basic Int8 value of 3 - 0b1100, 0x01, // Basic Int8 value of 1 - 0b1100, 0x02, // Basic Int8 value of 2 - }, - want: &objectData{ - size: 28, - numElements: 3, - firstFieldIDIdx: 5, - firstOffsetIdx: 14, - firstDataIdx: 22, - fieldIDWidth: 3, - offsetWidth: 2, - }, - }, - { - name: "Incorrect basic type", - encoded: []byte{0x00, 0x00, 0x00, 0x00, 0x00}, - wantErr: true, - }, - { - name: "Num elements out of bounds", - encoded: []byte{0x02}, - wantErr: true, - }, - { - name: "Object out of bounds", - encoded: []byte{0x02, 0x03, 0x01, 0x00, 0x02, 0x02, 0x00, 0x04, 0x06}, - wantErr: true, - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - got, err := getObjectData(c.encoded, c.offset) - if err != nil { - if c.wantErr { - return - } - t.Fatalf("Unexpected error: %v", err) - } else if c.wantErr { - t.Fatal("Got no error when one was expected") - } - if diff := cmp.Diff(got, c.want, cmp.AllowUnexported(objectData{})); diff != "" { - t.Fatalf("Incorrect returned data. Diff (-got, +want)\n%s", diff) - } - }) - } -} - -func TestUnmarshalObject(t *testing.T) { - cases := []struct { - name string - md *decodedMetadata - encoded []byte - offset int - // Normally the test will unmarhall into the type in want, but if this override is set, - // it'll try to unmarshal into something of this type. 
- overrideDecodeType reflect.Type - want any - wantErr bool - }{ - { - name: "Object built of primitives into map", - md: &decodedMetadata{keys: []string{"b", "a", "c"}}, - encoded: []byte{ - 0b01111110, // Basic type = object, field offset & field ID size = 4, is_large = true, - 0x03, 0x00, 0x00, 0x00, // 3 elements (encoded in 4 bytes due to is_large) - 0x01, 0x00, 0x00, 0x00, // Field ID of "a" - 0x00, 0x00, 0x00, 0x00, // Field ID of "b" - 0x02, 0x00, 0x00, 0x00, // Field ID of "c" - 0x02, 0x00, 0x00, 0x00, // Index of "a"s value (1) - 0x00, 0x00, 0x00, 0x00, // Index of "b"s value (3) - 0x04, 0x00, 0x00, 0x00, // Index of "c"s value (2) - 0x06, 0x00, 0x00, 0x00, // First index after elements - 0b1100, 0x03, // Basic Int8 value of 3 - 0b1100, 0x01, // Basic Int8 value of 1 - 0b1100, 0x02, // Basic Int8 value of 2 - }, - want: map[string]any{ - "a": int64(1), - "b": int64(3), - "c": int64(2), - }, - }, - { - name: "Object built of primitives into map with offset", - md: &decodedMetadata{keys: []string{"b", "a", "c"}}, - offset: 3, - encoded: []byte{ - 0x00, 0x00, 0x00, // 3 offset bytes - 0b01111110, // Basic type = object, field offset & field ID size = 4, is_large = true, - 0x03, 0x00, 0x00, 0x00, // 3 elements (encoded in 4 bytes due to is_large) - 0x01, 0x00, 0x00, 0x00, // Field ID of "a" - 0x00, 0x00, 0x00, 0x00, // Field ID of "b" - 0x02, 0x00, 0x00, 0x00, // Field ID of "c" - 0x02, 0x00, 0x00, 0x00, // Index of "a"s value (1) - 0x00, 0x00, 0x00, 0x00, // Index of "b"s value (3) - 0x04, 0x00, 0x00, 0x00, // Index of "c"s value (2) - 0x06, 0x00, 0x00, 0x00, // First index after elements - 0b1100, 0x03, // Basic Int8 value of 3 - 0b1100, 0x01, // Basic Int8 value of 1 - 0b1100, 0x02, // Basic Int8 value of 2 - }, - want: map[string]any{ - "a": int64(1), - "b": int64(3), - "c": int64(2), - }, - }, - { - name: "Object built of primitives into struct", - md: &decodedMetadata{keys: []string{"b", "a", "c"}}, - encoded: []byte{ - 0b01111110, // Basic type = object, field offset & field ID size = 4, is_large = true, - 0x03, 0x00, 0x00, 0x00, // 3 elements (encoded in 4 bytes due to is_large) - 0x01, 0x00, 0x00, 0x00, // Field ID of "a" - 0x00, 0x00, 0x00, 0x00, // Field ID of "b" - 0x02, 0x00, 0x00, 0x00, // Field ID of "c" - 0x02, 0x00, 0x00, 0x00, // Index of "a"s value (1) - 0x00, 0x00, 0x00, 0x00, // Index of "b"s value (3) - 0x04, 0x00, 0x00, 0x00, // Index of "c"s value (2) - 0x06, 0x00, 0x00, 0x00, // First index after elements - 0b1100, 0x03, // Basic Int8 value of 3 - 0b1100, 0x01, // Basic Int8 value of 1 - 0b1100, 0x02, // Basic Int8 value of 2 - }, - want: struct { - A int `variant:"a"` - B int `variant:"b"` - C int `variant:"c"` - }{1, 3, 2}, - }, - { - name: "Complex object into map", - md: &decodedMetadata{keys: []string{"key1", "key2", "array", "otherkey"}}, - encoded: func() []byte { - var buf bytes.Buffer - mdb := newMetadataBuilder() - - builder := newObjectBuilder(&buf, mdb, nil) - - builder.Write("key1", 123) - builder.Write("key2", "hello") - ab, err := builder.Array("array") - if err != nil { - t.Fatalf("Array('array'): %v", err) - } - ab.Write(false) - ab.Write("substr") - ab.Build() - - builder.Write("otherkey", []byte{'b', 'y', 't', 'e'}) - if err := builder.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - return buf.Bytes() - }(), - want: map[string]any{ - "key1": int64(123), - "key2": "hello", - "array": []any{false, "substr"}, - "otherkey": []byte{'b', 'y', 't', 'e'}, - }, - }, - { - name: "Complex object into struct", - md: &decodedMetadata{keys: 
[]string{"key1", "key2", "array", "otherkey"}}, - encoded: func() []byte { - var buf bytes.Buffer - mdb := newMetadataBuilder() - - builder := newObjectBuilder(&buf, mdb, nil) - - builder.Write("key1", 123) - builder.Write("key2", "hello") - ab, err := builder.Array("array") - if err != nil { - t.Fatalf("Array('array'): %v", err) - } - ab.Write(false) - ab.Write("substr") - ab.Build() - - builder.Write("otherkey", []byte{'b', 'y', 't', 'e'}) - if err := builder.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - return buf.Bytes() - }(), - want: struct { - K1 int `variant:"key1"` - K2 string `variant:"key2"` - Arr []any `variant:"array"` - Other []byte `variant:"otherkey"` - }{123, "hello", []any{false, "substr"}, []byte{'b', 'y', 't', 'e'}}, - }, - { - name: "Unmarshal skips non-present fields in struct", - md: &decodedMetadata{keys: []string{"key1", "key2", "key3"}}, - encoded: func() []byte { - var buf bytes.Buffer - mdb := newMetadataBuilder() - builder := newObjectBuilder(&buf, mdb, nil) - builder.Write("key1", 123) - builder.Write("key2", "hello") - builder.Write("key3", false) - if err := builder.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - return buf.Bytes() - }(), - want: struct { - K1 int `variant:"key1"` - // key2 is undefined - K3 bool `variant:"key3"` - }{123, false}, - }, - { - name: "Unmarshal into typed map", - md: &decodedMetadata{keys: []string{"key1", "key2", "key3"}}, - encoded: func() []byte { - var buf bytes.Buffer - mdb := newMetadataBuilder() - builder := newObjectBuilder(&buf, mdb, nil) - builder.Write("key1", 123) - builder.Write("key2", 234) - builder.Write("key3", 345) - if err := builder.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - return buf.Bytes() - }(), - want: map[string]int{ - "key1": 123, - "key2": 234, - "key3": 345, - }, - }, - { - name: "Malformed raw data", - md: &decodedMetadata{keys: []string{"key1", "key2", "key3"}}, - encoded: func() []byte { - var buf bytes.Buffer - mdb := newMetadataBuilder() - builder := newObjectBuilder(&buf, mdb, nil) - builder.Write("key1", 123) - builder.Write("key2", 234) - builder.Write("key3", 345) - if err := builder.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - return buf.Bytes()[1:] // Lop off the first byte - }(), - overrideDecodeType: reflect.TypeOf(map[string]any{}), - wantErr: true, - }, - { - name: "Maps must be string keyed", - md: &decodedMetadata{keys: []string{"key1", "key2", "key3"}}, - encoded: func() []byte { - var buf bytes.Buffer - mdb := newMetadataBuilder() - builder := newObjectBuilder(&buf, mdb, nil) - builder.Write("key1", 123) - builder.Write("key2", 234) - builder.Write("key3", 345) - if err := builder.Build(); err != nil { - t.Fatalf("Build(): %v", err) - } - return buf.Bytes() - }(), - overrideDecodeType: reflect.TypeOf(map[int]any{}), - wantErr: true, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - var got reflect.Value - typ := reflect.TypeOf(c.want) - if c.overrideDecodeType != nil { - typ = c.overrideDecodeType - } - if typ.Kind() == reflect.Map { - underlying := reflect.MakeMap(typ) - got = reflect.New(typ) - got.Elem().Set(underlying) - } else { - got = reflect.New(typ) - } - if err := unmarshalObject(c.encoded, c.md, c.offset, got); err != nil { - if c.wantErr { - return - } - t.Fatalf("Unexpected error: %v", err) - } else if c.wantErr { - t.Fatalf("Got no error when one was expected") - } - diff(t, got.Elem().Interface(), c.want) - }) - } -} - -func TestExtractFieldInfo(t *testing.T) { - type testStruct struct 
{ - NoTags int - JustName int `variant:"just_name"` - EmptyTag int `variant:""` - WithOpts int `variant:"with_opts,ntz,date,nanos,time"` - OptsNoName int `variant:",ntz"` - UnknownOpt int `variant:"unknown,not_defined_opt"` - } - cases := []struct { - name string - field int - wantName string - wantOpts []MarshalOpts - }{ - { - name: "No tags", - field: 0, - wantName: "NoTags", - }, - { - name: "Field tag with just name", - field: 1, - wantName: "just_name", - }, - { - name: "Empty tag uses struct field name", - field: 2, - wantName: "EmptyTag", - }, - { - name: "Field tag with name and options", - field: 3, - wantName: "with_opts", - wantOpts: []MarshalOpts{MarshalTimeNTZ, MarshalAsDate, MarshalTimeNanos, MarshalAsTime}, - }, - { - name: "Just options, no name uses struct field name", - field: 4, - wantName: "OptsNoName", - wantOpts: []MarshalOpts{MarshalTimeNTZ}, - }, - { - name: "Ignore unknown options", - field: 5, - wantName: "unknown", - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - val := reflect.ValueOf(testStruct{}) - gotName, gotOpts := extractFieldInfo(val.Type().Field(c.field)) - if gotName != c.wantName { - t.Errorf("Incorrect name. Got %q, want %q", gotName, c.wantName) - } - diff(t, gotOpts, c.wantOpts) - }) - } -} diff --git a/parquet/variants/primitive.go b/parquet/variants/primitive.go deleted file mode 100644 index 49449f17..00000000 --- a/parquet/variants/primitive.go +++ /dev/null @@ -1,756 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package variants - -import ( - "encoding/binary" - "fmt" - "io" - "math" - "math/bits" - "reflect" - "strings" - "time" - "unsafe" - - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/arrow/decimal" - "github.com/google/uuid" -) - -// Variant primitive type IDs. 
-type primitiveType int - -const ( - primitiveInvalid primitiveType = -1 - primitiveNull primitiveType = 0 - primitiveTrue primitiveType = 1 - primitiveFalse primitiveType = 2 - primitiveInt8 primitiveType = 3 - primitiveInt16 primitiveType = 4 - primitiveInt32 primitiveType = 5 - primitiveInt64 primitiveType = 6 - primitiveDouble primitiveType = 7 - primitiveDecimal4 primitiveType = 8 // TODO - primitiveDecimal8 primitiveType = 9 // TODO - primitiveDecimal16 primitiveType = 10 // TODO - primitiveDate primitiveType = 11 - primitiveTimestampMicros primitiveType = 12 - primitiveTimestampNTZMicros primitiveType = 13 - primitiveFloat primitiveType = 14 - primitiveBinary primitiveType = 15 - primitiveString primitiveType = 16 - primitiveTimeNTZ primitiveType = 17 - primitiveTimestampNanos primitiveType = 18 - primitiveTimestampNTZNanos primitiveType = 19 - primitiveUUID primitiveType = 20 -) - -func (pt primitiveType) String() string { - switch pt { - case primitiveNull: - return "Null" - case primitiveFalse, primitiveTrue: - return "Boolean" - case primitiveInt8: - return "Int8" - case primitiveInt16: - return "Int16" - case primitiveInt32: - return "Int32" - case primitiveInt64: - return "Int64" - case primitiveDouble: - return "Double" - case primitiveDecimal4: - return "Decimal4" - case primitiveDecimal8: - return "Decimal8" - case primitiveDecimal16: - return "Decimal16" - case primitiveDate: - return "Date" - case primitiveTimestampMicros: - return "Timestamp(micros)" - case primitiveTimestampNTZMicros: - return "TimestampNTZ(micros)" - case primitiveFloat: - return "Float" - case primitiveBinary: - return "Binary" - case primitiveString: - return "String" - case primitiveTimeNTZ: - return "TimeNTZ" - case primitiveTimestampNanos: - return "Timestamp(nanos)" - case primitiveTimestampNTZNanos: - return "TimestampNTZ(nanos)" - case primitiveUUID: - return "UUID" - } - return "Invalid" -} - -func validPrimitiveValue(prim primitiveType) error { - if prim < primitiveNull || prim > primitiveUUID { - return fmt.Errorf("%w: primitive type: %d", arrow.ErrInvalid, prim) - } - return nil -} - -func primitiveFromHeader(hdr byte) (primitiveType, error) { - // Special case the basic type of Short String and call it a Primitive String. 
- bt := BasicTypeFromHeader(hdr) - if bt == BasicShortString { - return primitiveString, nil - } else if bt == BasicPrimitive { - prim := primitiveType(hdr >> 2) - if err := validPrimitiveValue(prim); err != nil { - return primitiveInvalid, err - } - return prim, nil - } - return primitiveInvalid, fmt.Errorf("header is not of a primitive or short string basic type: %s", bt) -} - -func primitiveHeader(prim primitiveType) (byte, error) { - if err := validPrimitiveValue(prim); err != nil { - return 0, err - } - hdr := byte(prim << 2) - hdr |= byte(BasicPrimitive) - return hdr, nil -} - -func marshalDecimal[T decimal.Decimal32 | decimal.Decimal64 | decimal.Decimal128](scale int8, val T, w io.Writer) (int, error) { - hdr := [2]byte{0, byte(scale)} - switch v := any(val).(type) { - case decimal.Decimal32: - hdr[0], _ = primitiveHeader(primitiveDecimal4) - if _, err := w.Write(hdr[:]); err != nil { - return 0, err - } - return 6, binary.Write(w, binary.LittleEndian, int32(v)) - case decimal.Decimal64: - hdr[0], _ = primitiveHeader(primitiveDecimal8) - if _, err := w.Write(hdr[:]); err != nil { - return 0, err - } - return 10, binary.Write(w, binary.LittleEndian, int64(v)) - case decimal.Decimal128: - hdr[0], _ = primitiveHeader(primitiveDecimal16) - if _, err := w.Write(hdr[:]); err != nil { - return 0, err - } - if err := binary.Write(w, binary.LittleEndian, v.LowBits()); err != nil { - return 2, err - } - return 18, binary.Write(w, binary.LittleEndian, v.HighBits()) - default: - panic("should never get here") - } -} - -// marshalPrimitive takes in a primitive value, asserts its type, then marshals the data according to the Variant spec -// into the provided writer, returning the number of bytes written. -// -// Time can be provided in various ways- either by a time.Time struct, or by an int64 when the EncodeAs{Date,Time,Timestamp} -// options are provided. By default, timestamps are written as microseconds- to use nanoseconds pass in EncodeTimeAsNanos. -// Timezone information can be determined from a time.Time struct. Otherwise, by default, timestamps will be written with -// local timezone set. 
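// Editor's sketch, not part of the patch: how the time-related options described
// above combine for a time.Time value, following the switch in marshalPrimitive
// below. It assumes the package-level options MarshalTimeNanos and MarshalAsDate
// defined elsewhere in this package.
func marshalTimeExamples() {
	var buf bytes.Buffer
	_, _ = marshalPrimitive(time.Now(), &buf)                   // timestamp in microseconds (the default)
	_, _ = marshalPrimitive(time.Now(), &buf, MarshalTimeNanos) // timestamp in nanoseconds
	_, _ = marshalPrimitive(time.Now(), &buf, MarshalAsDate)    // days since the Unix epoch (Date)
	_, _ = marshalPrimitive(time.Now().UTC(), &buf)             // a UTC location selects the NTZ timestamp types
}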
-func marshalPrimitive(v any, w io.Writer, opts ...MarshalOpts) (int, error) { - var allOpts MarshalOpts - for _, o := range opts { - allOpts |= o - } - switch val := v.(type) { - case bool: - return marshalBoolean(val, w), nil - case int: - if bits.UintSize == 32 { - return marshalNumeric(int32(val), w) - } - return marshalNumeric(int64(val), w) - case int8: - return marshalNumeric(val, w) - case uint8: - return marshalNumeric(int16(val), w) - case int16: - return marshalNumeric(val, w) - case uint16: - return marshalNumeric(int32(val), w) - case int32: - return marshalNumeric(val, w) - case uint32: - return marshalNumeric(int64(val), w) - case int64: - return marshalNumeric(val, w) - case uint64: - return 0, fmt.Errorf("%w: cannot marshal uint64 values", arrow.ErrInvalid) - case float32: - return marshalNumeric(val, w) - case float64: - return marshalNumeric(val, w) - case arrow.Date32: - return marshalNumeric(val, w) - case arrow.Date64: - return marshalNumeric(arrow.Date32FromTime(val.ToTime()), w) - case arrow.Time64: - return marshalNumeric(val, w) - case uuid.UUID: - return marshalUUID(val, w) - case string: - return marshalString(val, w) - case []byte: - return marshalBinary(val, w) - case arrow.Timestamp: - return encodeTimestamp(int64(val), allOpts&MarshalTimeNanos != 0, false, w) - // TODO: add decimal.Decimal32/Decimal64/Decimal128 - case time.Time: - if allOpts&MarshalAsDate != 0 { - return marshalNumeric(arrow.Date32FromTime(val), w) - } - return marshalTimestamp(val, allOpts&MarshalTimeNanos != 0, w) - } - if v == nil { - return marshalNull(w), nil - } - return -1, fmt.Errorf("unsupported primitive type") -} - -// unmarshals a primitive (or a short string) into dest. dest must be a non-nil pointer to variable that is -// compatible with the Variant value to decode. Some conversions can take place: -// - Integer values: Higher widths can be decoded into smaller widths so long as they don't overflow. Also, -// integral values can be decoded into floats (also so long as they don't overflow). -// - Time/timestamps: Can be decoded into either int64 or time.Time, the latter of which will carry time zone information -// - Strings and binary: Can be decoded into either string or []byte -// -// If the Variant primitive is of the Null type, dest will be set to its zero value. 
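// Editor's sketch, not part of the patch: the conversion rules above exercised
// on a single encoded integer. The destination variables are invented; the
// behaviour is the one implemented by unmarshalPrimitive below.
func unmarshalConversions(encodedInt []byte) error {
	var asInt64 int64
	var asFloat float64
	// An encoded Int8/Int16/Int32 value widens into an int64 destination...
	if err := unmarshalPrimitive(encodedInt, 0, reflect.ValueOf(&asInt64)); err != nil {
		return err
	}
	// ...and an integral value may also be decoded into a float destination.
	if err := unmarshalPrimitive(encodedInt, 0, reflect.ValueOf(&asFloat)); err != nil {
		return err
	}
	// Timestamps decode into int64 (the physical value) or time.Time, and
	// strings/binary decode into either string or []byte, per the rules above.
	return nil
}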
-func unmarshalPrimitive(raw []byte, offset int, destPtr reflect.Value) error { - dest := destPtr.Elem() - kind := dest.Kind() - isEmptyInterface := kind == reflect.Interface && dest.NumMethod() == 0 - - if err := checkBounds(raw, offset, offset); err != nil { - return err - } - - prim, err := primitiveFromHeader(raw[offset]) - if err != nil { - return err - } - - switch prim { - case primitiveNull: - dest.Set(reflect.Zero(dest.Type())) - case primitiveTrue, primitiveFalse: - if kind != reflect.Bool && kind != reflect.Interface { - return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) - } - dest.Set(reflect.ValueOf(prim == primitiveTrue)) - case primitiveInt8, primitiveInt16, primitiveInt32, primitiveInt64: - iv, err := decodeIntPhysical(raw, offset) - if err != nil { - return err - } - if isEmptyInterface { - dest.Set(reflect.ValueOf(iv)) - } else if dest.CanInt() { - if dest.OverflowInt(iv) { - return fmt.Errorf("int value of %d will overflow dest", iv) - } - switch kind { - case reflect.Int: - dest.Set(reflect.ValueOf(int(iv))) - case reflect.Int8: - dest.Set(reflect.ValueOf(int8(iv))) - case reflect.Int16: - dest.Set(reflect.ValueOf(int16(iv))) - case reflect.Int32: - dest.Set(reflect.ValueOf(int32(iv))) - case reflect.Int64: - dest.Set(reflect.ValueOf(iv)) - default: - panic("unhandled int value") - } - } else if dest.CanFloat() { - // Converting from an int64 to a float64 can potentially lose precision, but it's still a valid - // conversion and can be supported here. - fv := float64(iv) - if dest.OverflowFloat(fv) { - return fmt.Errorf("value of %d will overflow dest", iv) - } - switch kind { - case reflect.Float32: - dest.Set(reflect.ValueOf(float32(fv))) - case reflect.Float64: - dest.Set(reflect.ValueOf(fv)) - default: - panic("unhandled float value") - } - } else { - return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) - } - case primitiveFloat: - fv, err := unmarshalFloat(raw, offset) - if err != nil { - return err - } - if !dest.CanFloat() && !isEmptyInterface { - return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) - } - switch kind { - case reflect.Float32, reflect.Interface: - dest.Set(reflect.ValueOf(fv)) - case reflect.Float64: - dest.Set(reflect.ValueOf(float64(fv))) - default: - panic("unhandled float value") - } - case primitiveDouble: - dv, err := unmarshalDouble(raw, offset) - if err != nil { - return err - } - if !dest.CanFloat() && !isEmptyInterface { - return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) - } - switch kind { - case reflect.Float32: - if dest.OverflowFloat(dv) { - return fmt.Errorf("value of %f will overflow dest", dv) - } - dest.Set(reflect.ValueOf(float32(dv))) - case reflect.Float64, reflect.Interface: - dest.Set(reflect.ValueOf(dv)) - default: - panic("unhandled float value") - } - case primitiveTimeNTZ, primitiveTimestampMicros, primitiveTimestampNTZMicros, - primitiveTimestampNanos, primitiveTimestampNTZNanos: - tsv, err := readUint(raw, offset+1, 8) - if err != nil { - return err - } - - // Time can be decoded into either an int64 (the physical time), or into a time.Time struct. - // Anything else is invalid. - if kind == reflect.Int64 || isEmptyInterface { - dest.Set(reflect.ValueOf(int64(tsv))) - } else if kind == reflect.Uint64 { - dest.Set(reflect.ValueOf(tsv)) - } else if dest.Type() == reflect.TypeOf(time.Time{}) { - var t time.Time - if prim == primitiveTimeNTZ { - // TimeNTZ for Variants is UTC=false (ie. 
local timezone) and in microseconds - t = time.Date(0, 0, 0, 0, 0, 0, 1000*int(tsv), time.Local) - } else { - if prim == primitiveTimestampMicros || prim == primitiveTimestampNTZMicros { - t = time.UnixMicro(int64(tsv)) - } else { - sec := int64(tsv / 1e9) - nsec := int64(tsv % 1e9) - t = time.Unix(sec, nsec) - } - if prim == primitiveTimestampMicros || prim == primitiveTimestampNanos { - t = t.In(time.Local) - } - } - dest.Set(reflect.ValueOf(t)) - } else { - return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) - } - case primitiveString: - str, err := unmarshalString(raw, offset) - if err != nil { - return err - } - if kind == reflect.String || isEmptyInterface { - dest.Set(reflect.ValueOf(str)) - } else if dest.Type() == reflect.TypeOf([]byte{}) { - dest.Set(reflect.ValueOf([]byte(str))) - } else { - return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) - } - case primitiveBinary: - bytes, err := unmarshalBinary(raw, offset) - if err != nil { - return err - } - if isEmptyInterface || dest.Type() == reflect.TypeOf([]byte{}) { - dest.Set(reflect.ValueOf(bytes)) - } else if kind == reflect.String { - dest.Set(reflect.ValueOf(string(bytes))) - } else { - return fmt.Errorf("cannot decode Variant value of %s into dest %s", prim, kind) - } - case primitiveUUID: - u, err := unmarshalUUID(raw, offset) - if err != nil { - return err - } - if dest.Type() == reflect.TypeOf(uuid.UUID{}) || isEmptyInterface { - dest.Set(reflect.ValueOf(u)) - } else if dest.Type() == reflect.TypeOf([]byte{}) { - bytes, _ := u.MarshalBinary() - dest.Set(reflect.ValueOf(bytes)) - } else if dest.Type() == reflect.TypeOf([16]byte{}) { - bytes, _ := u.MarshalBinary() - var fixed [16]byte - copy(fixed[:], bytes) - dest.Set(reflect.ValueOf(fixed)) - } else if kind == reflect.String { - dest.Set(reflect.ValueOf(u.String())) - } else { - return fmt.Errorf("cannot decode Variant UUID into dest %s", kind) - } - default: - return fmt.Errorf("unknown primitive: %s", prim) - } - - return nil -} - -func marshalNull(w io.Writer) int { - hdr, _ := primitiveHeader(primitiveNull) - w.Write([]byte{hdr}) - return 1 -} - -func marshalBoolean(b bool, w io.Writer) int { - var hdr byte - if b { - hdr, _ = primitiveHeader(primitiveTrue) - } else { - hdr, _ = primitiveHeader(primitiveFalse) - } - w.Write([]byte{hdr}) - return 1 -} - -func unmarshalBoolean(raw []byte, offset int) (bool, error) { - prim, err := primitiveFromHeader(raw[offset]) - if err != nil { - return false, err - } - return prim == primitiveTrue, nil -} - -func marshalNumeric[T float32 | float64 | int8 | int16 | int32 | int64 | arrow.Date32 | arrow.Time64](val T, w io.Writer) (int, error) { - var hdr byte - switch any(val).(type) { - case int8: - hdr, _ = primitiveHeader(primitiveInt8) - case int16: - hdr, _ = primitiveHeader(primitiveInt16) - case int32: - hdr, _ = primitiveHeader(primitiveInt32) - case int64: - hdr, _ = primitiveHeader(primitiveInt64) - case float32: - hdr, _ = primitiveHeader(primitiveFloat) - case float64: - hdr, _ = primitiveHeader(primitiveDouble) - case arrow.Date32: - hdr, _ = primitiveHeader(primitiveDate) - case arrow.Time64: - hdr, _ = primitiveHeader(primitiveTimeNTZ) - } - - if _, err := w.Write([]byte{hdr}); err != nil { - return 0, err - } - return binary.Size(val) + 1, binary.Write(w, binary.LittleEndian, val) -} - -func marshalBinary[T string | []byte](val T, w io.Writer) (int, error) { - var buf [5]byte - switch any(val).(type) { - case []byte: - buf[0], _ = primitiveHeader(primitiveBinary) 
- case string: - buf[0], _ = primitiveHeader(primitiveString) - } - - binary.Encode(buf[1:], binary.LittleEndian, int32(len(val))) - n, err := w.Write(buf[:]) - if err != nil { - return n, err - } - - if c, err := w.Write([]byte(val)); err != nil { - return n + c, err - } - - return n + len(val), nil -} - -// // Encodes an integer with the appropriate primitive header. This encodes the int -// // into the minimal space necessary regardless of the width that's passed in (eg. an -// // int64 of value 1 will be encoded into an int8) -// func marshalInt(val int64, w io.Writer) int { -// var hdr byte -// var size int -// if val < math.MaxInt8 && val > math.MinInt8 { -// hdr, _ = primitiveHeader(primitiveInt8) -// size = 1 -// } else if val < math.MaxInt16 && val > math.MinInt16 { -// hdr, _ = primitiveHeader(primitiveInt16) -// size = 2 -// } else if val < math.MaxInt32 && val > math.MinInt32 { -// hdr, _ = primitiveHeader(primitiveInt32) -// size = 4 -// } else { -// hdr, _ = primitiveHeader(primitiveInt64) -// size = 8 -// } -// w.Write([]byte{hdr}) -// encodeNumber(val, size, w) -// return size + 1 -// } - -func decodeIntPhysical(raw []byte, offset int) (int64, error) { - typ, _ := primitiveFromHeader(raw[offset]) - var size int - switch typ { - case primitiveInt8: - size = 1 - case primitiveInt16: - size = 2 - case primitiveInt32, primitiveDate: - size = 4 - case primitiveInt64: - size = 8 - default: - return -1, fmt.Errorf("not an integral type: %s", typ) - } - val, err := readInt(raw, offset+1, size) - if err != nil { - return -1, err - } - - // Do a conversion dance from the minimal width to int64 to catch - // negative numbers. - switch typ { - case primitiveInt8: - return int64(int8(val)), nil - case primitiveInt16: - return int64(int16(val)), nil - case primitiveInt32: - return int64(int32(val)), nil - default: - return int64(val), nil - } -} - -// func marshalFloat(val float32, w io.Writer) (int, error) { -// buf := make([]byte, 5) -// hdr, _ := primitiveHeader(primitiveFloat) -// buf[0] = hdr -// binary.Encode(buf[1:], binary.LittleEndian, val) -// return w.Write(buf) -// } - -// func marshalDouble(val float64, w io.Writer) int { -// buf := make([]byte, 9) -// hdr, _ := primitiveHeader(primitiveDouble) -// buf[0] = hdr -// bits := math.Float64bits(val) -// for i := range 8 { -// buf[i+1] = byte(bits) -// bits >>= 8 -// } -// w.Write(buf) -// return 9 -// } - -func unmarshalFloat(raw []byte, offset int) (float32, error) { - v, err := readUint(raw, offset+1, 4) - if err != nil { - return -1, err - } - return math.Float32frombits(uint32(v)), nil -} - -func unmarshalDouble(raw []byte, offset int) (float64, error) { - v, err := readUint(raw, offset+1, 8) - if err != nil { - return -1, err - } - return math.Float64frombits(v), nil -} - -// func encodePrimitiveBytes(b []byte, w io.Writer) int { -// encodeNumber(int64(len(b)), 4, w) -// w.Write(b) -// return len(b) + 4 -// } - -func marshalString(str string, w io.Writer) (int, error) { - str = strings.ToValidUTF8(str, "\uFFFD") - - // If the string is 63 characters or less, encode this as a short string to save space. 
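// Editor's note, illustrative and not part of the patch: a short string packs its
// length into the upper six bits of the single header byte, so for "hi"
//
//	hdr = len<<2 | BasicShortString = 2<<2 | 0b01 = 0b00001001
//
// and the encoding is {0b1001, 'h', 'i'}, exactly the bytes the object tests above
// expect. Strings that do not fit fall through to marshalBinary and are written as
// a primitive String header followed by a 4-byte little-endian length and the data.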
- strlen := len(str) - if strlen < 0x3F { - hdr := byte(strlen << 2) - hdr |= byte(BasicShortString) - if _, err := w.Write([]byte{hdr}); err != nil { - return 0, err - } - n, err := w.Write([]byte(str)) - return 1 + n, err - } - - return marshalBinary(str, w) -} - -func marshalUUID(u uuid.UUID, w io.Writer) (int, error) { - hdr, _ := primitiveHeader(primitiveUUID) - if _, err := w.Write([]byte{hdr}); err != nil { - return 0, err - } - - m, _ := u.MarshalBinary() // MarshalBinary() can never return an error - n, err := w.Write(m) - return 1 + n, err -} - -func unmarshalUUID(raw []byte, offset int) (uuid.UUID, error) { - if err := checkBounds(raw, offset, offset+17); err != nil { - return uuid.UUID{}, err - } - return uuid.FromBytes(raw[offset+1 : offset+17]) -} - -func unmarshalString(raw []byte, offset int) (string, error) { - // Determine if the string is a short string, or a basic string. - maxPos := len(raw) - if offset >= maxPos { - return "", fmt.Errorf("offset is out of bounds: trying to access position %d, max position is %d", offset, maxPos) - } - bt := BasicTypeFromHeader(raw[offset]) - - if bt == BasicShortString { - l := int(raw[offset] >> 2) - endIdx := 1 + l + offset - if endIdx > maxPos { - return "", fmt.Errorf("end index is out of bounds: trying to access position %d, max position is %d", endIdx, maxPos) - } - strPtr := (*byte)(unsafe.Pointer(&raw[offset+1])) - return unsafe.String(strPtr, l), nil - } - - b, err := getBytes(raw, offset+1) - if err != nil { - return "", err - } - strPtr := (*byte)(unsafe.Pointer(&b[0])) - return unsafe.String(strPtr, len(b)), nil -} - -func getBytes(raw []byte, offset int) ([]byte, error) { - l, err := readUint(raw, offset, 4) - if err != nil { - return nil, fmt.Errorf("could not read length: %v", err) - } - maxIdx := offset + 4 + int(l) - if len(raw) < maxIdx { - return nil, fmt.Errorf("bytes are out of bounds") - } - return raw[offset+4 : maxIdx], nil -} - -// func marshalBinary(b []byte, w io.Writer) int { -// hdr, _ := primitiveHeader(primitiveBinary) -// w.Write([]byte{hdr}) -// return 1 + encodePrimitiveBytes(b, w) -// } - -func unmarshalBinary(raw []byte, offset int) ([]byte, error) { - return getBytes(raw, offset+1) -} - -func marshalTimestamp(t time.Time, nanos bool, w io.Writer) (int, error) { - var ts int64 - if nanos { - ts = t.UnixNano() - } else { - ts = t.UnixMicro() - } - return encodeTimestamp(ts, nanos, t.Location() == time.UTC, w) -} - -func encodeTimestamp(t int64, nanos, ntz bool, w io.Writer) (int, error) { - var typ primitiveType - if nanos { - if ntz { - typ = primitiveTimestampNTZNanos - } else { - typ = primitiveTimestampNanos - } - } else { - if ntz { - typ = primitiveTimestampNTZMicros - } else { - typ = primitiveTimestampMicros - } - } - hdr, _ := primitiveHeader(typ) - if _, err := w.Write([]byte{hdr}); err != nil { - return 0, err - } - return 9, binary.Write(w, binary.LittleEndian, t) -} - -func unmarshalTimestamp(raw []byte, offset int) (time.Time, error) { - typ, _ := primitiveFromHeader(raw[offset]) - ts, err := readUint(raw, offset+1, 8) - if err != nil { - return time.Time{}, err - } - var ret time.Time - if typ == primitiveTimestampMicros || typ == primitiveTimestampNTZMicros { - ret = time.UnixMicro(int64(ts)) - } else { - ret = time.Unix(0, int64(ts)) - } - if typ == primitiveTimestampNTZMicros || typ == primitiveTimestampNTZNanos { - ret = ret.UTC() - } else { - ret = ret.Local() - } - return ret, nil -} - -// func marshalDate(t time.Time, w io.Writer) int { -// epoch := time.Unix(0, 0) -// since := 
t.Sub(epoch) -// days := int64(since.Hours() / 24) -// hdr, _ := primitiveHeader(primitiveDate) -// w.Write([]byte{hdr}) -// encodeNumber(days, 4, w) -// return 5 -// } - -func unmarshalDate(raw []byte, offset int) (time.Time, error) { - days, err := readUint(raw, offset+1, 4) - if err != nil { - return time.Time{}, err - } - return time.Unix(0, 0).Add(time.Hour * 24 * time.Duration(days)), nil -} diff --git a/parquet/variants/primitive_test.go b/parquet/variants/primitive_test.go deleted file mode 100644 index ccd1db43..00000000 --- a/parquet/variants/primitive_test.go +++ /dev/null @@ -1,618 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package variants - -import ( - "bytes" - "reflect" - "testing" - "time" - - "github.com/apache/arrow-go/v18/arrow" - "github.com/google/go-cmp/cmp" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func diffByteArrays(t *testing.T, got, want []byte) { - t.Helper() - if diff := cmp.Diff(got, want); diff != "" { - t.Fatalf("Incorrect encoding. Diff (-got, +want):\n%s", diff) - } -} - -func checkSize(t *testing.T, wantSize int, buf []byte) { - t.Helper() - if gotSize := len(buf); gotSize != wantSize { - t.Errorf("Incorrect reported size: got %d, want %d", gotSize, wantSize) - } -} - -func TestBoolean(t *testing.T) { - var b bytes.Buffer - size := marshalBoolean(true, &b) - encodedTrue := b.Bytes() - checkSize(t, size, encodedTrue) - diffByteArrays(t, encodedTrue, []byte{0b100}) - got, err := unmarshalBoolean(encodedTrue, 0) - if err != nil { - t.Fatalf("unmarshalBoolean(): %v", err) - } - if got != true { - t.Fatalf("Incorrect boolean returned. Got %t, want true", got) - } - - b.Reset() - marshalBoolean(false, &b) - encodedFalse := b.Bytes() - diffByteArrays(t, encodedFalse, []byte{0b1000}) - got, err = unmarshalBoolean(encodedFalse, 0) - if err != nil { - t.Fatalf("unmarshalBoolean(): %v", err) - } - if got != false { - t.Fatalf("Incorrect boolean returned. 
Got %t, want false", got) - } -} - -func TestInt(t *testing.T) { - cases := []struct { - name string - val int64 - wantHdr byte - wantHexVal []byte - }{ - { - name: "Positive Int8", - val: 8, - wantHdr: 0b1100, - wantHexVal: []byte{0x08}, - }, - { - name: "Negative Int8", - val: -8, - wantHdr: 0b1100, - wantHexVal: []byte{0xF8}, - }, - { - name: "Positive Int16", - val: 200, - wantHdr: 0b10000, - wantHexVal: []byte{0xC8, 0x00}, - }, - { - name: "NegativeInt16", - val: -200, - wantHdr: 0b10000, - wantHexVal: []byte{0x38, 0xFF}, - }, - { - name: "Positive Int32", - val: 32768, - wantHdr: 0b10100, - wantHexVal: []byte{0x00, 0x80, 0x00, 0x00}, - }, - { - name: "Negative Int32", - val: -32768, - wantHdr: 0b10100, - wantHexVal: []byte{0x00, 0x80, 0xFF, 0xFF}, - }, - { - name: "Positive Int64", - val: 9223372036854775807, - wantHdr: 0b11000, - wantHexVal: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F}, - }, - { - name: "Negative Int64", - val: -9223372036854775807, - wantHdr: 0b11000, - wantHexVal: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80}, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - var b bytes.Buffer - size, err := marshalNumeric(c.val, &b) - require.NoError(t, err) - encoded := b.Bytes() - checkSize(t, size, encoded) - if gotHdr := encoded[0]; gotHdr != c.wantHdr { - t.Fatalf("Incorrect header: got %x, want %x", gotHdr, c.wantHdr) - } - diffByteArrays(t, encoded[1:], c.wantHexVal) - gotVal, err := decodeIntPhysical(encoded, 0) - if err != nil { - t.Fatalf("decodeIntPhysical(): %v", err) - } - if wantVal := c.val; gotVal != wantVal { - t.Fatalf("Incorrect decoded value: got %d, want %d", gotVal, wantVal) - } - }) - } -} - -func TestUUID(t *testing.T) { - cases := []struct { - name string - uuid uuid.UUID - want []byte - }{ - { - name: "UUID no padding", - uuid: func() uuid.UUID { - u, _ := uuid.FromBytes([]byte("sixteencharacter")) - return u - }(), - want: []byte{ - 0b1010000, // Basic primitive UUID - 's', 'i', 'x', 't', 'e', 'e', 'n', - 'c', 'h', 'a', 'r', 'a', 'c', 't', 'e', 'r', - }, - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - var b bytes.Buffer - size, err := marshalUUID(c.uuid, &b) - require.NoError(t, err) - if size != 17 { - t.Fatalf("Incorrect size. Got %d, want 17", size) - } - diff(t, b.Bytes(), c.want) - - gotUUID, err := unmarshalUUID(b.Bytes(), 0) - if err != nil { - t.Fatalf("unmarshalUUID(): %v", err) - } - gotUUIDBytes, _ := gotUUID.MarshalBinary() - diff(t, gotUUIDBytes, c.want[1:]) - }) - } -} - -func TestFloat(t *testing.T) { - var b bytes.Buffer - size, err := marshalNumeric(1.1, &b) - require.NoError(t, err) - encodedFloat := b.Bytes() - checkSize(t, size, encodedFloat) - diffByteArrays(t, encodedFloat, []byte{ - 0b111000, // Primitive type, float - 0xCD, - 0xCC, - 0x8C, - 0x3F, // 0x3F8C CCCD ~= 1.1 encoded - }) - got, err := unmarshalFloat(encodedFloat, 0) - if err != nil { - t.Fatalf("unmarshalFloat(): %v", err) - } - if want := float32(1.1); got != want { - t.Fatalf("Incorrect float returned. 
Got %.2f, want %.2f", got, want) - } -} - -func TestDouble(t *testing.T) { - var b bytes.Buffer - size, err := marshalNumeric(float64(1.1), &b) - require.NoError(t, err) - encodedDouble := b.Bytes() - checkSize(t, size, encodedDouble) - diffByteArrays(t, encodedDouble, []byte{ - 0b11100, // Primitive type, double - 0x9A, - 0x99, - 0x99, - 0x99, - 0x99, - 0x99, - 0xF1, - 0x3F, // 0x3FF1 9999 9999 999A ~= 1.1 encoded - }) - got, err := unmarshalDouble(encodedDouble, 0) - if err != nil { - t.Fatalf("unmarshalDouble(): %v", err) - } - if want := float64(1.1); got != want { - t.Fatalf("Incorrect double returned. Got %.2f, want %.2f", got, want) - } -} - -func mustMarshalPrimitive(t *testing.T, val any, opts ...MarshalOpts) []byte { - t.Helper() - var buf bytes.Buffer - if _, err := marshalPrimitive(val, &buf, opts...); err != nil { - t.Fatalf("marshalPrimitive(): %v", err) - } - return buf.Bytes() -} - -func TestUnmarshalPrimitive(t *testing.T) { - cases := []struct { - name string - encoded []byte - offset int - unmarshalType reflect.Type - want any - wantErr bool - }{ - { - name: "Unmarshal bool (with offset)", - encoded: append([]byte{0, 0}, mustMarshalPrimitive(t, true)...), - offset: 2, - want: true, - }, - { - name: "Unmarshal into int (with offset)", - encoded: append([]byte{0, 0}, mustMarshalPrimitive(t, 1)...), // Encodes to Int8 - offset: 2, - want: int(1), - }, - { - name: "Unmarshal into int8", - encoded: mustMarshalPrimitive(t, 1), - want: int8(1), - }, - { - name: "Unmarshal into int16", - encoded: mustMarshalPrimitive(t, 1), - want: int16(1), - }, - { - name: "Unmarshal into int32", - encoded: mustMarshalPrimitive(t, 1), - want: int32(1), - }, - { - name: "Unmarshal into int64", - encoded: mustMarshalPrimitive(t, 1), - want: int64(1), - }, - { - name: "Unmarshal negative", - encoded: mustMarshalPrimitive(t, -100), - want: -100, - }, - { - name: "Unmarshal int into float32", - encoded: mustMarshalPrimitive(t, 1), - want: float32(1), - }, - { - name: "unmarsUnmarshalhal float32", - encoded: mustMarshalPrimitive(t, float32(1.2)), - want: float32(1.2), - }, - { - name: "Unmarshal float64 (with offset)", - encoded: append([]byte{1, 1}, mustMarshalPrimitive(t, float64(1.2))...), - offset: 2, - want: float64(1.2), - }, - { - name: "Unmarshal timestamp into int64", - encoded: mustMarshalPrimitive(t, time.Unix(123, 0).Local(), MarshalTimeNanos), - want: time.Unix(123, 0).UnixNano(), - }, - { - name: "Unmarshal timestamp into time", - encoded: mustMarshalPrimitive(t, time.UnixMilli(1742967183000)), - want: time.UnixMilli(1742967183000), - }, - { - name: "Unmarshal timestamp into time (nanos)", - encoded: mustMarshalPrimitive(t, time.UnixMilli(1742967183000), MarshalTimeNanos), - want: time.Unix(0, 1742967183000000000), - }, - { - name: "Unmarshal short string with offset", - encoded: append([]byte{0, 0}, mustMarshalPrimitive(t, "hello")...), - offset: 2, - want: "hello", - }, - { - name: "Unmarshal basic string", - encoded: []byte{0b1000000, 0x03, 0x00, 0x00, 0x00, 'a', 'b', 'c'}, - want: "abc", - }, - { - name: "Unmarshal string into byte slice", - encoded: mustMarshalPrimitive(t, "hello"), - want: []byte("hello"), - }, - { - name: "Unmarshal binary into byte slice", - encoded: mustMarshalPrimitive(t, []byte{'b', 'y', 't', 'e'}), - want: []byte("byte"), - }, - { - name: "Unmarshal binary into string", - encoded: mustMarshalPrimitive(t, []byte{'b', 'y', 't', 'e'}), - want: "byte", - }, - { - name: "Unmarshal empty binary", - encoded: mustMarshalPrimitive(t, []byte{}), - want: []byte{}, - 
}, - { - name: "Unmarshal UUID", - encoded: func() []byte { - u, _ := uuid.FromBytes([]byte("sixteencharacter")) - return mustMarshalPrimitive(t, u) - }(), - want: func() uuid.UUID { - u, _ := uuid.FromBytes([]byte("sixteencharacter")) - return u - }(), - }, - { - name: "Unmarshal UUID to byte slice", - encoded: func() []byte { - u, _ := uuid.FromBytes([]byte("sixteencharacter")) - return mustMarshalPrimitive(t, u) - }(), - want: []byte("sixteencharacter"), - }, - { - name: "Unmarshal UUID to string", - encoded: func() []byte { - u, _ := uuid.FromBytes([]byte("sixteencharacter")) - return mustMarshalPrimitive(t, u) - }(), - want: "73697874-6565-6e63-6861-726163746572", - }, - { - name: "Unmarshal into int8 would overflow", - encoded: mustMarshalPrimitive(t, 12345), - unmarshalType: reflect.TypeOf(int8(0)), - wantErr: true, - }, - { - name: "Cannot unmarshal int into non-int type", - encoded: mustMarshalPrimitive(t, 1), - unmarshalType: reflect.TypeOf(string("")), - wantErr: true, - }, - { - name: "Cannot unmarshal string into non-string type", - encoded: mustMarshalPrimitive(t, "hello"), - unmarshalType: reflect.TypeOf(int(1)), - wantErr: true, - }, - { - name: "Cannot unmarshal binary into non-binary type", - encoded: mustMarshalPrimitive(t, []byte{0, 1, 2}), - unmarshalType: reflect.TypeOf(int(1)), - wantErr: true, - }, - { - name: "Malformed value", - encoded: mustMarshalPrimitive(t, 256)[:1], // int16 is usually 3 bytes - unmarshalType: reflect.TypeOf(int(1)), - wantErr: true, - }, - { - name: "Short string out of bounds", - encoded: []byte{0b1001, 'a'}, - unmarshalType: reflect.TypeOf(""), - wantErr: true, - }, - { - name: "Binary out of bounds", - encoded: []byte{0b111100, 0x09, 0x00, 0x00, 0x00, 'a', 'b', 'c'}, - unmarshalType: reflect.TypeOf([]byte{}), - wantErr: true, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - typ := reflect.TypeOf(c.want) - if c.unmarshalType != nil { - typ = c.unmarshalType - } - got := reflect.New(typ) - if err := unmarshalPrimitive(c.encoded, c.offset, got); err != nil { - if c.wantErr { - return - } - t.Fatalf("Unexpected error: %v", err) - } else if c.wantErr { - t.Fatalf("Got no error when one was expected") - } - - diff(t, got.Elem().Interface(), c.want) - }) - } -} - -func TestString(t *testing.T) { - cases := []struct { - name string - str string - wantEncoded []byte - }{ - { - name: "Short string", - str: "short", - wantEncoded: []byte{ - 0b0010101, // Short string type, length=5 - 's', 'h', 'o', 'r', 't', - }, - }, - { - name: "Basic string", - str: "abcdefghijklmnopqrstuvwxyz1234567890abcdefghijklmnopqrstuvwxyz1234567890", - wantEncoded: append([]byte{ - 0b1000000, - 0x48, 0x00, 0x00, 0x00, // Length of 72 - }, []byte("abcdefghijklmnopqrstuvwxyz1234567890abcdefghijklmnopqrstuvwxyz1234567890")...), - }, - { - name: "Empty string", - str: "", - wantEncoded: []byte{0b01}, // Short string basic type, length = 0 - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - var b bytes.Buffer - size, err := marshalString(c.str, &b) - require.NoError(t, err) - checkSize(t, size, c.wantEncoded) - - gotEncoded := b.Bytes() - diff(t, gotEncoded, c.wantEncoded) - }) - } -} - -func TestBinary(t *testing.T) { - cases := []struct { - name string - bin []byte - wantEncoded []byte - }{ - { - name: "Binary data", - bin: []byte("hello"), - wantEncoded: []byte{ - 0b111100, // Primitive type, binary - 0x05, 0x00, 0x00, 0x00, // Length of 5 - 'h', 'e', 'l', 'l', 'o', - }, - }, - { - name: "Empty data", - bin: []byte{}, - 
wantEncoded: []byte{0b111100, 0x00, 0x00, 0x00, 0x00}, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - var b bytes.Buffer - size, err := marshalBinary(c.bin, &b) - require.NoError(t, err) - checkSize(t, size, c.wantEncoded) - diff(t, b.Bytes(), c.wantEncoded) - }) - } -} - -func TestTimestamp(t *testing.T) { - cases := []struct { - name string - nanos bool - ntz bool - wantHdr byte - }{ - { - name: "Nanos NTZ", - nanos: true, - ntz: true, - wantHdr: 0b1001100, - }, - { - name: "Nanos", - nanos: true, - wantHdr: 0b1001000, - }, - { - name: "Micros NTZ", - ntz: true, - wantHdr: 0b110100, - }, - { - name: "Micros", - wantHdr: 0b110000, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - ref := time.UnixMicro(1000000000) - if c.ntz { - ref = ref.UTC() - } else { - ref = ref.Local() - } - var b bytes.Buffer - size, err := marshalTimestamp(ref, c.nanos, &b) - require.NoError(t, err) - wantEncoded := []byte{c.wantHdr} - if c.nanos { - wantEncoded = append(wantEncoded, []byte{ - 0x00, - 0x10, - 0xA5, - 0xD4, - 0xE8, - 0x00, - 0x00, - 0x00, // Binary encoding of 1,000,000,000,000 - }...) - } else { - wantEncoded = append(wantEncoded, []byte{ - 0x00, - 0xCA, - 0x9A, - 0x3B, - 0x00, - 0x00, - 0x00, - 0x00, // Binary encoding of 1,000,000,000 - }...) - } - encodedTimestamp := b.Bytes() - checkSize(t, size, encodedTimestamp) - diffByteArrays(t, encodedTimestamp, wantEncoded) - got, err := unmarshalTimestamp(b.Bytes(), 0) - if err != nil { - t.Fatalf("unmarshalTimestamp(): %v", err) - } - if want := ref; got != want { - t.Fatalf("Timestamps differ: got %s, want %s", got, want) - } - }) - } -} - -func TestDate(t *testing.T) { - day := time.Unix(0, 0).Add(10000 * 24 * time.Hour) - var b bytes.Buffer - size, err := marshalNumeric(arrow.Date32FromTime(day), &b) - require.NoError(t, err) - encodedDate := b.Bytes() - checkSize(t, size, encodedDate) - diffByteArrays(t, encodedDate, []byte{ - 0b101100, // Primitive type, date - 0x10, - 0x27, - 0x00, - 0x00, // 10000 = 0x0000 2710 - }) - got, err := unmarshalDate(encodedDate, 0) - require.NoError(t, err) - assert.Equal(t, got, day) -} diff --git a/parquet/variants/testutils.go b/parquet/variants/testutils.go deleted file mode 100644 index 937226b4..00000000 --- a/parquet/variants/testutils.go +++ /dev/null @@ -1,30 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package variants - -import ( - "testing" - - "github.com/google/go-cmp/cmp" -) - -func diff(t *testing.T, got, want any, cmpOpts ...cmp.Option) { - t.Helper() - if d := cmp.Diff(got, want, cmpOpts...); d != "" { - t.Fatalf("Incorrect returned value. 
Diff (-got, +want):\n%s", d) - } -} diff --git a/parquet/variants/util.go b/parquet/variants/util.go deleted file mode 100644 index bba2e8ce..00000000 --- a/parquet/variants/util.go +++ /dev/null @@ -1,161 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package variants - -import ( - "fmt" - "io" - "reflect" - "time" -) - -// Reads a little-endian encoded uint (betwen 1 and 8 bytes wide) from a raw buffer at a specified -// offset and returns its value. If any part of the read would be out of bounds, this returns an error. -func readUint(raw []byte, offset, size int) (uint64, error) { - if size < 1 || size > 8 { - return 0, fmt.Errorf("invalid size, must be in range [1,8]: %d", size) - } - if maxPos := offset + size; maxPos > len(raw) { - return 0, fmt.Errorf("out of bounds: trying to access position %d, max position is %d", maxPos, len(raw)) - } - var ret uint64 - for i := range size { - ret |= uint64(raw[i+offset]) << (8 * i) - } - return ret, nil -} - -// Reads a little-endian encoded integer (between 1 and 8 bytes wide) from a raw buffer at a specified offset. -func readInt(raw []byte, offset, size int) (int64, error) { - u, err := readUint(raw, offset, size) - if err != nil { - return -1, err - } - return int64(u), nil -} - -func fieldOffsetSize(maxSize int32) int { - if maxSize < 0xFF { - return 1 - } else if maxSize < 0xFFFF { - return 2 - } else if maxSize < 0xFFFFFF { - return 3 - } - return 4 -} - -// Checks that a given range is in the provided raw buffer. -func checkBounds(raw []byte, low, high int) error { - maxPos := len(raw) - if low >= maxPos { - return fmt.Errorf("out of bounds: trying to access position %d, max is %d", low, maxPos) - } - if high > maxPos { - return fmt.Errorf("out of bounds: trying to access position %d, max is %d", high, maxPos) - } - if high < low { - return fmt.Errorf("incorrect bounds- high (%d) must higher than or equal to low (%d)", high, low) - } - if low < 0 { - return fmt.Errorf("bounds must be positive, have [%d, %d]", low, high) - } - return nil -} - -// Encodes a number of a specified width in little-endian format and writes to a writer. -func encodeNumber(val int64, size int, w io.Writer) { - buf := make([]byte, size) - for i := range size { - buf[i] = byte(val) - val >>= 8 - } - w.Write(buf) -} - -func isLarge(numItems int) bool { - return numItems > 0xFF -} - -// Returns the basic type the passed in value should be encoded as, or undefined if it cannot be handled. 
-func kindFromValue(val any) BasicType { - if val == nil { - return BasicPrimitive - } - v := reflect.ValueOf(val) - - if v.Kind() == reflect.Pointer { - v = v.Elem() - } - switch v.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Bool, - reflect.String, reflect.Float32, reflect.Float64: - return BasicPrimitive - case reflect.Struct: - // Time is considered a primitive. All other structs are objects. - if v.Type() == reflect.TypeOf(time.Time{}) { - return BasicPrimitive - } - return BasicObject - case reflect.Array, reflect.Slice: - typ := v.Type() - // Byte arrays are primitives. UUID happens to fall into this bucket too serindiptously. - if typ.Elem().Kind() == reflect.Uint8 { - return BasicPrimitive - } - return BasicArray - case reflect.Map: - // Only maps with string keys are supported. - typ := v.Type() - if typ.Key().Kind() == reflect.String { - return BasicObject - } - } - return BasicUndefined -} - -// Returns the nth item (zero indexed) in a serialized list (ie. a serialized Array, or serialized Metadata). -// The offset should be the index of the first offset listing. -func readNthItem(raw []byte, offset, item, offsetSize, numElements int) ([]byte, error) { - if err := checkBounds(raw, offset, offset); err != nil { - return nil, err - } - - if item > numElements { - return nil, fmt.Errorf("item number is greater than number of elements (%d vs %d)", item, numElements) - } - - // Calculate the range to return by getting the upper and lower bound of the item. - lowerBound, err := readUint(raw, offset+item*offsetSize, offsetSize) - if err != nil { - return nil, err - } - upperBound, err := readUint(raw, offset+(item+1)*offsetSize, offsetSize) - if err != nil { - return nil, err - } - firstElemIdx := offset + (numElements+1)*offsetSize - - lowIdx := firstElemIdx + int(lowerBound) - highIdx := firstElemIdx + int(upperBound) - - if err := checkBounds(raw, lowIdx, highIdx); err != nil { - return nil, err - } - - return raw[lowIdx:highIdx], nil -} diff --git a/parquet/variants/util_test.go b/parquet/variants/util_test.go deleted file mode 100644 index f52b9004..00000000 --- a/parquet/variants/util_test.go +++ /dev/null @@ -1,331 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package variants - -import ( - "testing" - "time" -) - -func TestReadUint(t *testing.T) { - cases := []struct { - name string - raw []byte - offset int - size int - want uint64 - wantErr bool - }{ - { - name: "Read uint8, offset=1", - raw: []byte{0x00, 0x05}, - offset: 1, - size: 1, - want: 5, - }, - { - name: "Read uint16, offset=1", - raw: []byte{0x00, 0x00, 0x01}, // 256 - offset: 1, - size: 2, - want: 256, - }, - { - name: "Read uint32, offset=1", - raw: []byte{0x00, 0x00, 0x00, 0x01, 0x00}, // 65536 - offset: 1, - size: 4, - want: 65536, - }, - { - name: "Read uint64, offset=1", - raw: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00}, // 4294967293 - offset: 1, - size: 8, - want: 4294967296, - }, - { - name: "Empty raw buffer", - offset: 0, - size: 1, - wantErr: true, - }, - { - name: "Not enough bytes for offset", - raw: []byte{0x00}, - offset: 1, - size: 1, - wantErr: true, - }, - { - name: "Not enough bytes to read", - raw: []byte{0x00}, - offset: 0, - size: 2, - wantErr: true, - }, - { - name: "Invalid size 0", - raw: []byte{0x00}, - offset: 0, - size: 0, - wantErr: true, - }, - { - name: "Invalid size too big", - raw: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - offset: 0, - size: 9, - wantErr: true, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - got, err := readUint(c.raw, c.offset, c.size) - if err != nil { - if c.wantErr { - return - } - t.Fatalf("Unexpected error: %v", err) - } else if c.wantErr { - t.Fatalf("Got no error when one was expected") - } - if got != c.want { - t.Fatalf("Incorrect value returned. Got %d, want %d", got, c.want) - } - }) - } -} - -func TestKindFromValue(t *testing.T) { - cases := []struct { - name string - val any - want BasicType - }{ - { - name: "Int", - val: 123, - want: BasicPrimitive, - }, - { - name: "Int pointer", - val: func() *int { - a := 123 - return &a - }(), - want: BasicPrimitive, - }, - { - name: "Bool", - val: false, - want: BasicPrimitive, - }, - { - name: "Byte slice is primitive", - val: []byte{'a', 'b', 'c'}, - want: BasicPrimitive, - }, - { - name: "Time", - val: time.Unix(100, 100), - want: BasicPrimitive, - }, - { - name: "Struct", - val: struct{ a int }{1}, - want: BasicObject, - }, - { - name: "Struct pointer", - val: &struct{ a int }{1}, - want: BasicObject, - }, - { - name: "Slice is an array", - val: []int{1, 2, 3}, - want: BasicArray, - }, - { - name: "Map with string keys is an object", - val: map[string]bool{"a": true}, - want: BasicObject, - }, - { - name: "Map with non string keys is not supported", - val: map[int]string{1: "a"}, - want: BasicUndefined, - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - if got := kindFromValue(c.val); got != c.want { - t.Fatalf("Incorrect kind. 
Got %s, want %s", got, c.want) - } - }) - } -} - -func TestReadNthItem(t *testing.T) { - cases := []struct { - name string - raw []byte - offset int - item int - offsetSize int - numElements int - want []byte - wantErr bool - }{ - { - name: "Third item, offset=1, offsetSize=1, width=1", - raw: []byte{0x00, 0x00, 0x01, 0x02, 0x03, 0xAA, 0xBB, 0xCC}, - offset: 1, - item: 2, - offsetSize: 1, - numElements: 3, - want: []byte{0xCC}, - }, - { - name: "Second item, offset=1, offsetSize=2, width=1", - raw: []byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0xAA, 0xBB, 0xCC}, - offset: 1, - item: 1, - offsetSize: 2, - numElements: 3, - want: []byte{0xBB}, - }, - { - name: "First item, offset=1, offsetSize=2, width=2", - raw: []byte{ - 0x00, 0x00, 0x00, 0x02, 0x00, 0x04, 0x00, 0x06, 0x00, - 0xAA, 0xAA, 0xBB, 0xBB, 0xCC, 0xCC}, - offset: 1, - item: 0, - offsetSize: 2, - numElements: 3, - want: []byte{0xAA, 0xAA}, - }, - { - name: "Second item, offset=1, offsetSize=2, width=2", - raw: []byte{ - 0x00, 0x00, 0x00, 0x02, 0x00, 0x04, 0x00, 0x06, 0x00, - 0xAA, 0xAA, 0xBB, 0xBB, 0xCC, 0xCC}, - offset: 1, - item: 1, - offsetSize: 2, - numElements: 3, - want: []byte{0xBB, 0xBB}, - }, - { - name: "Offset out of bounds", - raw: []byte{0x00}, - offset: 1, - wantErr: true, - }, - { - name: "Item is greater than numElements", - raw: []byte{0x00, 0x00, 0x01, 0x02, 0x03, 0xAA, 0xBB, 0xCC}, - offset: 1, - item: 4, - offsetSize: 1, - numElements: 3, - wantErr: true, - }, - { - name: "Item is out of bounds", - raw: []byte{0x00, 0x01, 0x02, 0x04, 0xAA, 0xBB, 0xCC}, - offset: 0, - item: 2, - offsetSize: 1, - numElements: 3, - wantErr: true, - }, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - got, err := readNthItem(c.raw, c.offset, c.item, c.offsetSize, c.numElements) - if err != nil { - if c.wantErr { - return - } - t.Fatalf("readNthItem(): %v", err) - } else if c.wantErr { - t.Fatalf("readNthItem(): wanted error, got none") - } - diffByteArrays(t, got, c.want) - }) - } -} - -func TestCheckBounds(t *testing.T) { - cases := []struct { - name string - raw []byte - low, high int - wantErr bool - }{ - { - name: "In bounds", - raw: make([]byte, 10), - low: 1, - high: 9, - }, - { - name: "low == high", - raw: make([]byte, 10), - low: 1, - high: 1, - }, - { - name: "Out of bounds (idx == len(raw))", - raw: make([]byte, 10), - low: 10, - high: 10, - wantErr: true, - }, - { - name: "high < low", - raw: make([]byte, 10), - low: 5, - high: 1, - wantErr: true, - }, - { - name: "Negative index", - raw: make([]byte, 10), - low: -1, - high: 1, - wantErr: true, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - err := checkBounds(c.raw, c.low, c.high) - if err != nil { - if c.wantErr { - return - } - t.Fatalf("Unexpected error: %v", err) - } else if c.wantErr { - t.Fatalf("Got no error when one was expected") - } - }) - } -} diff --git a/parquet/variants/variant.go b/parquet/variants/variant.go deleted file mode 100644 index bd6fb83e..00000000 --- a/parquet/variants/variant.go +++ /dev/null @@ -1,53 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package variants - -// Basic types -type BasicType int - -const ( - BasicUndefined BasicType = -1 - BasicPrimitive BasicType = 0 - BasicShortString BasicType = 1 - BasicObject BasicType = 2 - BasicArray BasicType = 3 -) - -func (bt BasicType) String() string { - switch bt { - case BasicPrimitive: - return "Primitive" - case BasicShortString: - return "ShortString" - case BasicObject: - return "Object" - case BasicArray: - return "Array" - } - return "Unknown" -} - -// Function to get the Variant basic type from a provided value header -func BasicTypeFromHeader(hdr byte) BasicType { - return BasicType(hdr & 0x3) -} - -// Container to hold a marshaled Variant. -type MarshaledVariant struct { - Metadata []byte - Value []byte -} From 616e71e0de9a2ce0e3ed0d81878a36a161569c55 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 23 May 2025 15:42:07 -0400 Subject: [PATCH 05/10] go mod tidy --- go.mod | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/go.mod b/go.mod index e4d7cf08..5f1df00a 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,6 @@ require ( github.com/goccy/go-json v0.10.5 github.com/golang/snappy v1.0.0 github.com/google/flatbuffers v25.2.10+incompatible - github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/hamba/avro/v2 v2.28.0 github.com/klauspost/asmfmt v1.3.2 @@ -67,6 +66,7 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/fatih/color v1.15.0 // indirect github.com/goccy/go-yaml v1.11.0 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/gookit/color v1.5.4 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/json-iterator/go v1.1.12 // indirect @@ -102,4 +102,3 @@ require ( modernc.org/token v1.1.0 // indirect ) -tool golang.org/x/tools/cmd/stringer From 71ea4d6f0a88dc07ffc0d640e8f555e7babe21b8 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 23 May 2025 15:44:51 -0400 Subject: [PATCH 06/10] exclude generated files from RAT --- dev/release/rat_exclude_files.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index e254fd7a..d9356703 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -34,3 +34,5 @@ parquet/internal/gen-go/parquet/GoUnusedProtection__.go parquet/internal/gen-go/parquet/parquet-consts.go parquet/internal/gen-go/parquet/parquet.go parquet/version_string.go +parquet/variant/basic_type_string.go +parquet/variant/primitive_type_string.go From ee7951a776807505c5be37772b1b1ff80ad45617 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Sat, 24 May 2025 13:36:12 -0400 Subject: [PATCH 07/10] Update parquet/variant/variant.go Co-authored-by: David Li --- parquet/variant/variant.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/parquet/variant/variant.go b/parquet/variant/variant.go index a600c6c7..73a12e20 100644 --- a/parquet/variant/variant.go +++ b/parquet/variant/variant.go @@ -497,7 +497,8 @@ func (v Value) Bytes() []byte { return v.value } func (v Value) Clone() Value { return Value{ meta: v.meta.Clone(), - value: 
bytes.Clone(v.value)} + value: bytes.Clone(v.value), + } } // Metadata returns the metadata associated with the value. From 18fe8d5d9bfee5b937114e7a4470663ca81cf989 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Sat, 24 May 2025 14:59:11 -0400 Subject: [PATCH 08/10] updates and tests from feedback --- parquet/variant/builder.go | 15 +- parquet/variant/variant.go | 46 +++--- parquet/variant/variant_test.go | 275 ++++++++++++++++++++++++++++++++ 3 files changed, 306 insertions(+), 30 deletions(-) diff --git a/parquet/variant/builder.go b/parquet/variant/builder.go index 70f1d9a1..d240d4a3 100644 --- a/parquet/variant/builder.go +++ b/parquet/variant/builder.go @@ -151,8 +151,13 @@ func extractFieldInfo(f reflect.StructField) (name string, o AppendOpt) { // // timestamp(UTC=true,NANOS) // } // +// There is only one case where options can conflict currently: If both [OptTimeAsDate] and +// [OptTimeAsTime] are set, then [OptTimeAsDate] will take precedence. +// // Options specified in the struct tags will be OR'd with any options passed to the original call -// to Append. +// to Append. As a result, if a Struct field tag sets [OptTimeAsTime], but the call to Append +// passes [OptTimeAsDate], then the value will be appended as a date since that option takes +// precedence. func (b *Builder) Append(v any, opts ...AppendOpt) error { var o AppendOpt for _, opt := range opts { @@ -731,10 +736,9 @@ func (b *Builder) FinishObject(start int, fields []FieldEntry) error { } var ( - dataSize = b.buf.Len() - start - isLarge = sz > math.MaxUint8 - sizeBytes = 1 - idSize, offsetSize = intSize(int(maxID)), intSize(dataSize) + dataSize = b.buf.Len() - start + isLarge = sz > math.MaxUint8 + sizeBytes = 1 ) if isLarge { @@ -745,6 +749,7 @@ func (b *Builder) FinishObject(start int, fields []FieldEntry) error { return errors.New("invalid object size") } + idSize, offsetSize := intSize(int(maxID)), intSize(dataSize) headerSize := 1 + sizeBytes + sz*int(idSize) + (sz+1)*int(offsetSize) // shift the just written data to make room for the header section b.buf.Grow(headerSize) diff --git a/parquet/variant/variant.go b/parquet/variant/variant.go index 73a12e20..512cf3d0 100644 --- a/parquet/variant/variant.go +++ b/parquet/variant/variant.go @@ -51,6 +51,9 @@ const ( ) func basicTypeFromHeader(hdr byte) BasicType { + // because we're doing hdr & 0x3, it is impossible for the result + // to be outside of the range of BasicType. Therefore, we don't + // need to perform any checks. The value will always be [0,3] return BasicType(hdr & basicTypeMask) } @@ -136,6 +139,8 @@ const ( var ( // EmptyMetadataBytes contains a minimal valid metadata section with no dictionary entries. 
EmptyMetadataBytes = [3]byte{0x1, 0, 0} + + ErrInvalidMetadata = errors.New("invalid variant metadata") ) // Metadata represents the dictionary part of a variant value, which stores @@ -150,29 +155,15 @@ type Metadata struct { func NewMetadata(data []byte) (Metadata, error) { m := Metadata{data: data} if len(data) < hdrSizeBytes+minOffsetSizeBytes*2 { - return m, fmt.Errorf("invalid variant metadata: too short: size=%d", len(data)) + return m, fmt.Errorf("%w: too short: size=%d", ErrInvalidMetadata, len(data)) } if m.Version() != supportedVersion { - return m, fmt.Errorf("invalid variant metadata: unsupported version: %d", m.Version()) + return m, fmt.Errorf("%w: unsupported version: %d", ErrInvalidMetadata, m.Version()) } offsetSz := m.OffsetSize() - if offsetSz < minOffsetSizeBytes || offsetSz > maxOffsetSizeBytes { - return m, fmt.Errorf("invalid variant metadata: invalid offset size: %d", offsetSz) - } - - dictSize, err := m.loadDictionary(offsetSz) - if err != nil { - return m, err - } - - if hdrSizeBytes+int(dictSize+1)*int(offsetSz) > len(m.data) { - return m, fmt.Errorf("invalid variant metadata: offset out of range: %d > %d", - (dictSize+hdrSizeBytes)*uint32(offsetSz), len(m.data)) - } - - return m, nil + return m, m.loadDictionary(offsetSz) } // Clone creates a deep copy of the metadata. @@ -186,21 +177,26 @@ func (m *Metadata) Clone() Metadata { } } -func (m *Metadata) loadDictionary(offsetSz uint8) (uint32, error) { +func (m *Metadata) loadDictionary(offsetSz uint8) error { if int(offsetSz+hdrSizeBytes) > len(m.data) { - return 0, errors.New("invalid variant metadata: too short for dictionary size") + return fmt.Errorf("%w: too short for dictionary size", ErrInvalidMetadata) } dictSize := readLEU32(m.data[hdrSizeBytes : hdrSizeBytes+offsetSz]) m.keys = make([][]byte, dictSize) if dictSize == 0 { - return 0, nil + return nil } // first offset is always 0 offsetStart, offsetPos := uint32(0), hdrSizeBytes+offsetSz valuesStart := hdrSizeBytes + (dictSize+2)*uint32(offsetSz) + if hdrSizeBytes+int(dictSize+1)*int(offsetSz) > len(m.data) { + return fmt.Errorf("%w: offset out of range: %d > %d", + ErrInvalidMetadata, (dictSize+hdrSizeBytes)*uint32(offsetSz), len(m.data)) + } + for i := range dictSize { offsetPos += offsetSz end := readLEU32(m.data[offsetPos : offsetPos+offsetSz]) @@ -208,14 +204,14 @@ func (m *Metadata) loadDictionary(offsetSz uint8) (uint32, error) { keySize := end - offsetStart valStart := valuesStart + offsetStart if valStart+keySize > uint32(len(m.data)) { - return 0, fmt.Errorf("invalid variant metadata: string data out of range: %d + %d > %d", - valStart, keySize, len(m.data)) + return fmt.Errorf("%w: string data out of range: %d + %d > %d", + ErrInvalidMetadata, valStart, keySize, len(m.data)) } m.keys[i] = m.data[valStart : valStart+keySize] offsetStart += keySize } - return dictSize, nil + return nil } // Bytes returns the raw byte representation of the metadata. 
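A minimal sketch of the metadata layout exercised by this hunk, assuming the header encoding implied by the parsing code and tests in this patch (low bits carry the version, bits 6-7 the offset width minus one) and assuming the standalone module import path github.com/apache/arrow-go/v18; NewMetadata, KeyAt, and IdFor are the accessors used by the tests later in this series:

	package main

	import (
		"fmt"

		"github.com/apache/arrow-go/v18/parquet/variant" // assumed import path
	)

	func main() {
		// Hand-built metadata: version 1, 1-byte offsets, one dictionary key "a".
		raw := []byte{
			0x01,       // header: version 1, 1-byte offsets
			0x01,       // dictionary size = 1
			0x00, 0x01, // offsets: key 0 spans bytes [0, 1) of the key data
			'a',        // key data
		}
		m, err := variant.NewMetadata(raw)
		if err != nil {
			panic(err)
		}
		key, _ := m.KeyAt(0)           // "a"
		fmt.Println(key, m.IdFor("a")) // a [0]
	}
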
@@ -265,8 +261,8 @@ func (m Metadata) IdFor(key string) []uint32 { return ret } - for i, k := range m.keys { - if bytes.Equal(k, k) { + for i, kb := range m.keys { + if bytes.Equal(kb, k) { ret = append(ret, uint32(i)) } } diff --git a/parquet/variant/variant_test.go b/parquet/variant/variant_test.go index 44c582f7..2ef4da38 100644 --- a/parquet/variant/variant_test.go +++ b/parquet/variant/variant_test.go @@ -18,6 +18,7 @@ package variant_test import ( "encoding/json" + "math" "os" "path/filepath" "testing" @@ -114,6 +115,7 @@ func TestBasicRead(t *testing.T) { key, err := m.KeyAt(uint32(i)) require.NoError(t, err) assert.Equal(t, k, key) + assert.Equal(t, uint32(i), m.IdFor(k)[0]) } }) } @@ -400,6 +402,10 @@ func TestTimestampNanos(t *testing.T) { require.NoError(t, err) assert.Equal(t, variant.TimestampNanos, v.Type()) assert.Equal(t, arrow.Timestamp(-1), v.Value()) + + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, `"1969-12-31 23:59:59.999999999Z"`, string(out)) }) t.Run("ts nanos tz positive", func(t *testing.T) { @@ -409,6 +415,10 @@ func TestTimestampNanos(t *testing.T) { require.NoError(t, err) assert.Equal(t, variant.TimestampNanos, v.Type()) assert.Equal(t, arrow.Timestamp(1744877350123456789), v.Value()) + + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, `"2025-04-17 08:09:10.123456789Z"`, string(out)) }) t.Run("ts nanos ntz positive", func(t *testing.T) { @@ -418,6 +428,12 @@ func TestTimestampNanos(t *testing.T) { require.NoError(t, err) assert.Equal(t, variant.TimestampNanosNTZ, v.Type()) assert.Equal(t, arrow.Timestamp(1744877350123456789), v.Value()) + + tm := time.Unix(1744877350123456789/int64(time.Second), 1744877350123456789%int64(time.Second)) + tm = tm.In(time.Local) + out, err := json.Marshal(v) + require.NoError(t, err) + assert.JSONEq(t, tm.Format(`"2006-01-02 15:04:05.999999999Z0700"`), string(out)) }) } @@ -510,3 +526,262 @@ func TestArrayValues(t *testing.T) { assert.JSONEq(t, expected, string(out)) }) } + +func TestInvalidMetadata(t *testing.T) { + tests := []struct { + name string + metadata []byte + errMsg string + }{ + { + name: "empty metadata", + metadata: []byte{}, + errMsg: "too short", + }, + { + name: "unsupported version", + metadata: []byte{0x02, 0x00, 0x00}, // Version != 1 is unsupported + errMsg: "unsupported version", + }, + { + name: "truncated metadata", + metadata: []byte{0x01, 0x05}, // Metadata too short for its header + errMsg: "too short", + }, + { + name: "too short for dict size", + metadata: []byte{0x81, 0x01, 0x00}, // Offset size is 3, not enough bytes + errMsg: "too short for dictionary", + }, + { + name: "key count exceeds metadata size", + metadata: []byte{0x01, 0xFF, 0x00}, // Claims to have many keys but doesn't + errMsg: "out of range", + }, + { + name: "string data out of range", + metadata: []byte{0x01, 0x01, 0x00, 0x05}, + errMsg: "string data out of range", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := variant.NewMetadata(tt.metadata) + require.Error(t, err) + assert.ErrorIs(t, err, variant.ErrInvalidMetadata) + assert.Contains(t, err.Error(), tt.errMsg) + + _, err = variant.New(tt.metadata, []byte{}) + require.Error(t, err) + assert.ErrorIs(t, err, variant.ErrInvalidMetadata) + assert.Contains(t, err.Error(), tt.errMsg) + }) + } +} + +func TestInvalidValue(t *testing.T) { + tests := []struct { + name string + metadata []byte + value []byte + errMsg string + }{ + { + name: "empty value", + metadata: variant.EmptyMetadataBytes[:], 
+ value: []byte{}, + errMsg: "invalid variant value: empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := variant.New(tt.metadata, tt.value) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + }) + } +} + +func TestInvalidObjectAccess(t *testing.T) { + v := loadVariant(t, "object_primitive") + obj := v.Value().(variant.ObjectValue) + + t.Run("field_at_out_of_bounds", func(t *testing.T) { + _, err := obj.FieldAt(obj.NumElements()) + require.Error(t, err) + assert.Contains(t, err.Error(), "out of range") + assert.ErrorIs(t, err, arrow.ErrIndex) + }) + + t.Run("corrupt_id", func(t *testing.T) { + // Create a corrupt variant with invalid field ID + objBytes := v.Bytes() + idPosition := 2 // Assumes field ID is at this position - adjust if needed + + // Make a copy so we don't modify the original + corruptBytes := make([]byte, len(objBytes)) + copy(corruptBytes, objBytes) + + // Set field ID to an invalid value + corruptBytes[idPosition] = 0xFF + + corrupt, err := variant.NewWithMetadata(v.Metadata(), corruptBytes) + require.NoError(t, err) + + corruptObj := corrupt.Value().(variant.ObjectValue) + _, err = corruptObj.FieldAt(0) + require.Error(t, err) + assert.Contains(t, err.Error(), "fieldID") + + _, err = corruptObj.ValueByKey("int_field") + require.Error(t, err) + assert.Contains(t, err.Error(), "fieldID") + }) +} + +func TestInvalidArrayAccess(t *testing.T) { + v := loadVariant(t, "array_primitive") + arr := v.Value().(variant.ArrayValue) + + t.Run("out_of_bounds", func(t *testing.T) { + _, err := arr.Value(arr.Len()) + require.Error(t, err) + assert.Contains(t, err.Error(), "out of range") + assert.ErrorIs(t, err, arrow.ErrIndex) + }) + + t.Run("negative_index", func(t *testing.T) { + _, err := arr.Value(uint32(math.MaxUint32)) + require.Error(t, err) + assert.Contains(t, err.Error(), "out of range") + }) +} + +func TestInvalidBuilderOperations(t *testing.T) { + t.Run("invalid_object_size", func(t *testing.T) { + var b variant.Builder + start := b.Offset() + + // Move offset to before start to create invalid size + b.AppendInt(123) + fields := []variant.FieldEntry{{Key: "test", ID: 0, Offset: -10}} + + err := b.FinishObject(start+10, fields) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid object size") + }) + + t.Run("invalid_array_size", func(t *testing.T) { + var b variant.Builder + start := b.Offset() + + // Move offset to before start to create invalid size + b.AppendInt(123) + offsets := []int{-10} + + err := b.FinishArray(start+10, offsets) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid array size") + }) + +} + +func TestUnsupportedTypes(t *testing.T) { + var b variant.Builder + + tests := []struct { + name string + value interface{} + }{ + { + name: "complex number", + value: complex(1, 2), + }, + { + name: "function", + value: func() {}, + }, + { + name: "channel", + value: make(chan int), + }, + { + name: "map with non-string keys", + value: map[int]string{ + 1: "test", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := b.Append(tt.value) + require.Error(t, err) + }) + } +} + +func TestDuplicateKeys(t *testing.T) { + t.Run("disallow_duplicates", func(t *testing.T) { + var b variant.Builder + b.SetAllowDuplicates(false) // default, but explicit for test clarity + + start := b.Offset() + fields := make([]variant.FieldEntry, 0) + + fields = append(fields, b.NextField(start, "key")) + require.NoError(t, b.AppendInt(1)) + + fields = 
append(fields, b.NextField(start, "key")) + require.NoError(t, b.AppendInt(2)) + + err := b.FinishObject(start, fields) + require.Error(t, err) + assert.Contains(t, err.Error(), "disallowed duplicate key") + }) + + t.Run("allow_duplicates", func(t *testing.T) { + var b variant.Builder + b.SetAllowDuplicates(true) + + start := b.Offset() + fields := make([]variant.FieldEntry, 0) + + fields = append(fields, b.NextField(start, "key")) + require.NoError(t, b.AppendInt(1)) + + fields = append(fields, b.NextField(start, "key")) + require.NoError(t, b.AppendInt(2)) + + require.NoError(t, b.FinishObject(start, fields)) + + v, err := b.Build() + require.NoError(t, err) + + obj := v.Value().(variant.ObjectValue) + field, err := obj.ValueByKey("key") + require.NoError(t, err) + assert.Equal(t, int8(2), field.Value.Value()) + }) +} + +func TestValueCloneConsistency(t *testing.T) { + var b variant.Builder + require.NoError(t, b.AppendString("test")) + + v, err := b.Build() + require.NoError(t, err) + + cloned := v.Clone() + + // Reset should invalidate the original value's buffer + b.Reset() + require.NoError(t, b.AppendInt(123)) + + // Original value's buffer is now used for something else + // But the cloned value should still be valid + assert.Equal(t, variant.String, cloned.Type()) + assert.Equal(t, "test", cloned.Value()) +} From f976e809a0b0615bb03734b2cd5a826e11bdcfcf Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 27 May 2025 16:02:55 -0400 Subject: [PATCH 09/10] rename generated string files --- dev/release/rat_exclude_files.txt | 4 ++-- .../variant/{basic_type_string.go => basic_type_stringer.go} | 2 +- .../{primitive_type_string.go => primitive_type_stringer.go} | 2 +- parquet/variant/variant.go | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) rename parquet/variant/{basic_type_string.go => basic_type_stringer.go} (94%) rename parquet/variant/{primitive_type_string.go => primitive_type_stringer.go} (96%) diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index d9356703..adad69e0 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -34,5 +34,5 @@ parquet/internal/gen-go/parquet/GoUnusedProtection__.go parquet/internal/gen-go/parquet/parquet-consts.go parquet/internal/gen-go/parquet/parquet.go parquet/version_string.go -parquet/variant/basic_type_string.go -parquet/variant/primitive_type_string.go +parquet/variant/basic_type_stringer.go +parquet/variant/primitive_type_stringer.go diff --git a/parquet/variant/basic_type_string.go b/parquet/variant/basic_type_stringer.go similarity index 94% rename from parquet/variant/basic_type_string.go rename to parquet/variant/basic_type_stringer.go index 31afdc4a..e8cf83e4 100644 --- a/parquet/variant/basic_type_string.go +++ b/parquet/variant/basic_type_stringer.go @@ -1,4 +1,4 @@ -// Code generated by "stringer -type=BasicType -linecomment -output=basic_type_string.go"; DO NOT EDIT. +// Code generated by "stringer -type=BasicType -linecomment -output=basic_type_stringer.go"; DO NOT EDIT. 
package variant diff --git a/parquet/variant/primitive_type_string.go b/parquet/variant/primitive_type_stringer.go similarity index 96% rename from parquet/variant/primitive_type_string.go rename to parquet/variant/primitive_type_stringer.go index f24ed4d4..205724c3 100644 --- a/parquet/variant/primitive_type_string.go +++ b/parquet/variant/primitive_type_stringer.go @@ -1,4 +1,4 @@ -// Code generated by "stringer -type=PrimitiveType -linecomment -output=primitive_type_string.go"; DO NOT EDIT. +// Code generated by "stringer -type=PrimitiveType -linecomment -output=primitive_type_stringer.go"; DO NOT EDIT. package variant diff --git a/parquet/variant/variant.go b/parquet/variant/variant.go index 512cf3d0..f4075b93 100644 --- a/parquet/variant/variant.go +++ b/parquet/variant/variant.go @@ -36,8 +36,8 @@ import ( "github.com/google/uuid" ) -//go:generate go tool stringer -type=BasicType -linecomment -output=basic_type_string.go -//go:generate go tool stringer -type=PrimitiveType -linecomment -output=primitive_type_string.go +//go:generate go tool stringer -type=BasicType -linecomment -output=basic_type_stringer.go +//go:generate go tool stringer -type=PrimitiveType -linecomment -output=primitive_type_stringer.go // BasicType represents the fundamental type category of a variant value. type BasicType int From 70b7f90ac882d25a68593b55097c0bd289ba4231 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 27 May 2025 16:05:41 -0400 Subject: [PATCH 10/10] updates from feedback --- parquet/variant/builder.go | 14 ++++++-------- parquet/variant/utils.go | 2 +- parquet/variant/variant.go | 10 +++++----- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/parquet/variant/builder.go b/parquet/variant/builder.go index d240d4a3..e57770b7 100644 --- a/parquet/variant/builder.go +++ b/parquet/variant/builder.go @@ -41,6 +41,7 @@ type Builder struct { buf bytes.Buffer dict map[string]uint32 dictKeys [][]byte + totalDictSize int allowDuplicates bool } @@ -67,6 +68,7 @@ func (b *Builder) AddKey(key string) (id uint32) { id = uint32(len(b.dictKeys)) b.dict[key] = id b.dictKeys = append(b.dictKeys, unsafe.Slice(unsafe.StringData(key), len(key))) + b.totalDictSize += len(key) return id } @@ -804,26 +806,22 @@ func (b *Builder) Reset() { // performing this copy. func (b *Builder) Build() (Value, error) { nkeys := len(b.dictKeys) - totalDictSize := 0 - for _, k := range b.dictKeys { - totalDictSize += len(k) - } // determine the number of bytes required per offset entry. // the largest offset is the one-past-the-end value, the total size. // It's very unlikely that the number of keys could be larger, but // incorporate that into the calculation in case of pathological data. 
- maxSize := max(totalDictSize, nkeys) - if maxSize > maxSizeLimit { + maxSize := max(b.totalDictSize, nkeys) + if maxSize > metadataMaxSizeLimit { return Value{}, fmt.Errorf("metadata size too large: %d", maxSize) } offsetSize := intSize(int(maxSize)) offsetStart := 1 + offsetSize stringStart := int(offsetStart) + (nkeys+1)*int(offsetSize) - metadataSize := stringStart + totalDictSize + metadataSize := stringStart + b.totalDictSize - if metadataSize > maxSizeLimit { + if metadataSize > metadataMaxSizeLimit { return Value{}, fmt.Errorf("metadata size too large: %d", metadataSize) } diff --git a/parquet/variant/utils.go b/parquet/variant/utils.go index 05dee06d..9b8ca24d 100644 --- a/parquet/variant/utils.go +++ b/parquet/variant/utils.go @@ -84,7 +84,7 @@ func objectHeader(large bool, idSize, offsetSize uint8) byte { } func intSize(v int) uint8 { - debug.Assert(v <= maxSizeLimit, "size too large") + debug.Assert(v <= metadataMaxSizeLimit, "size too large") debug.Assert(v >= 0, "size cannot be negative") switch { diff --git a/parquet/variant/variant.go b/parquet/variant/variant.go index f4075b93..9ba05787 100644 --- a/parquet/variant/variant.go +++ b/parquet/variant/variant.go @@ -129,11 +129,11 @@ const ( maxOffsetSizeBytes = 4 // mask is applied after shift - offsetSizeMask uint8 = 0b11 - offsetSizeBitShift uint8 = 6 - supportedVersion = 1 - maxShortStringSize = 0x3F - maxSizeLimit = 128 * 1024 * 1024 // 128MB + offsetSizeMask uint8 = 0b11 + offsetSizeBitShift uint8 = 6 + supportedVersion = 1 + maxShortStringSize = 0x3F + metadataMaxSizeLimit = 128 * 1024 * 1024 // 128MB ) var (