Skip to content

Commit 6f95269

Browse files
authored
Add MaxEncodedSize to encoder (#691)
Adds a function that returns the expected maximum encoded size for a given input size with the current settings. See #688
1 parent cbc850f commit 6f95269

File tree

3 files changed

+65
-5
lines changed

3 files changed

+65
-5
lines changed

zstd/encoder.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"crypto/rand"
99
"fmt"
1010
"io"
11+
"math"
1112
rdebug "runtime/debug"
1213
"sync"
1314

@@ -639,3 +640,37 @@ func (e *Encoder) EncodeAll(src, dst []byte) []byte {
639640
}
640641
return dst
641642
}
643+
644+
// MaxEncodedSize returns the expected maximum
645+
// size of an encoded block or stream.
646+
func (e *Encoder) MaxEncodedSize(size int) int {
647+
frameHeader := 4 + 2 // magic + frame header & window descriptor
648+
if e.o.dict != nil {
649+
frameHeader += 4
650+
}
651+
// Frame content size:
652+
if size < 256 {
653+
frameHeader++
654+
} else if size < 65536+256 {
655+
frameHeader += 2
656+
} else if size < math.MaxInt32 {
657+
frameHeader += 4
658+
} else {
659+
frameHeader += 8
660+
}
661+
// Final crc
662+
if e.o.crc {
663+
frameHeader += 4
664+
}
665+
666+
// Max overhead is 3 bytes/block.
667+
// There cannot be 0 blocks.
668+
blocks := (size + e.o.blockSize) / e.o.blockSize
669+
670+
// Combine, add padding.
671+
maxSz := frameHeader + 3*blocks + size
672+
if e.o.pad > 1 {
673+
maxSz += calcSkippableFrame(int64(maxSz), int64(e.o.pad))
674+
}
675+
return maxSz
676+
}

zstd/encoder_test.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func TestEncoder_EncodeAllSimple(t *testing.T) {
8585
defer e.Close()
8686
start := time.Now()
8787
dst := e.EncodeAll(in, nil)
88-
t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
88+
//t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
8989
mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
9090
t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
9191

@@ -98,7 +98,7 @@ func TestEncoder_EncodeAllSimple(t *testing.T) {
9898
os.WriteFile("testdata/"+t.Name()+"-z000028.want", in, os.ModePerm)
9999
t.Fatal("Decoded does not match")
100100
}
101-
t.Log("Encoded content matched")
101+
//t.Log("Encoded content matched")
102102
})
103103
}
104104
}
@@ -136,6 +136,9 @@ func TestEncoder_EncodeAllConcurrent(t *testing.T) {
136136
go func() {
137137
defer wg.Done()
138138
dst := e.EncodeAll(in, nil)
139+
if len(dst) > e.MaxEncodedSize(len(in)) {
140+
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(dst), e.MaxEncodedSize(len(in)))
141+
}
139142
//t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
140143
decoded, err := dec.DecodeAll(dst, nil)
141144
if err != nil {
@@ -150,7 +153,7 @@ func TestEncoder_EncodeAllConcurrent(t *testing.T) {
150153
}()
151154
}
152155
wg.Wait()
153-
t.Log("Encoded content matched.", n, "goroutines")
156+
//t.Log("Encoded content matched.", n, "goroutines")
154157
})
155158
}
156159
}
@@ -185,7 +188,10 @@ func TestEncoder_EncodeAllEncodeXML(t *testing.T) {
185188
defer e.Close()
186189
start := time.Now()
187190
dst := e.EncodeAll(in, nil)
188-
t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
191+
if len(dst) > e.MaxEncodedSize(len(in)) {
192+
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(dst), e.MaxEncodedSize(len(in)))
193+
}
194+
//t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
189195
mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
190196
t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
191197

@@ -198,7 +204,7 @@ func TestEncoder_EncodeAllEncodeXML(t *testing.T) {
198204
t.Error("Decoded does not match")
199205
return
200206
}
201-
t.Log("Encoded content matched")
207+
//t.Log("Encoded content matched")
202208
})
203209
}
204210
}
@@ -250,6 +256,9 @@ func TestEncoderRegression(t *testing.T) {
250256
t.Error(err)
251257
}
252258
encoded := enc.EncodeAll(in, nil)
259+
if len(encoded) > enc.MaxEncodedSize(len(in)) {
260+
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(encoded), enc.MaxEncodedSize(len(in)))
261+
}
253262
// Usually too small...
254263
got, err := dec.DecodeAll(encoded, make([]byte, 0, len(in)))
255264
if err != nil {
@@ -268,6 +277,9 @@ func TestEncoderRegression(t *testing.T) {
268277
t.Error(err)
269278
}
270279
encoded = dst.Bytes()
280+
if len(encoded) > enc.MaxEncodedSize(len(in)) {
281+
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(encoded), enc.MaxEncodedSize(len(in)))
282+
}
271283
got, err = dec.DecodeAll(encoded, make([]byte, 0, len(in)/2))
272284
if err != nil {
273285
t.Logf("error: %v\nwant: %v\ngot: %v", err, in, got)

zstd/fuzz_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,10 @@ func FuzzEncoding(f *testing.F) {
210210
}
211211

212212
encoded := enc.EncodeAll(data, make([]byte, 0, bufSize))
213+
if len(encoded) > enc.MaxEncodedSize(len(data)) {
214+
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded), enc.MaxEncodedSize(len(data)))
215+
}
216+
213217
got, err := dec.DecodeAll(encoded, make([]byte, 0, bufSize))
214218
if err != nil {
215219
t.Fatal(fmt.Sprintln("Level", level, "DecodeAll error:", err, "\norg:", len(data), "\nencoded", len(encoded)))
@@ -223,6 +227,9 @@ func FuzzEncoding(f *testing.F) {
223227
t.Fatal(fmt.Sprintln("Level", level, "Close (buffer) error:", err))
224228
}
225229
encoded2 := dst.Bytes()
230+
if len(encoded2) > enc.MaxEncodedSize(len(data)) {
231+
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded2), enc.MaxEncodedSize(len(data)))
232+
}
226233
if !bytes.Equal(encoded, encoded2) {
227234
got, err = dec.DecodeAll(encoded2, got[:0])
228235
if err != nil {
@@ -247,6 +254,9 @@ func FuzzEncoding(f *testing.F) {
247254
}
248255

249256
encoded = enc.EncodeAll(data, encoded[:0])
257+
if len(encoded) > enc.MaxEncodedSize(len(data)) {
258+
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded), enc.MaxEncodedSize(len(data)))
259+
}
250260
got, err = dec.DecodeAll(encoded, got[:0])
251261
if err != nil {
252262
t.Fatal(fmt.Sprintln("Dict Level", level, "DecodeAll error:", err, "\norg:", len(data), "\nencoded", len(encoded)))
@@ -260,6 +270,9 @@ func FuzzEncoding(f *testing.F) {
260270
t.Fatal(fmt.Sprintln("Dict Level", level, "Close (buffer) error:", err))
261271
}
262272
encoded2 = dst.Bytes()
273+
if len(encoded2) > enc.MaxEncodedSize(len(data)) {
274+
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded2), enc.MaxEncodedSize(len(data)))
275+
}
263276
if !bytes.Equal(encoded, encoded2) {
264277
got, err = dec.DecodeAll(encoded2, got[:0])
265278
if err != nil {

0 commit comments

Comments
 (0)