Black Lives Matter. Support the Equal Justice Initiative.

Source file src/compress/flate/deflate_test.go

Documentation: compress/flate

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package flate
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"internal/testenv"
    12  	"io"
    13  	"math/rand"
    14  	"os"
    15  	"reflect"
    16  	"runtime/debug"
    17  	"sync"
    18  	"testing"
    19  )
    20  
    21  type deflateTest struct {
    22  	in    []byte
    23  	level int
    24  	out   []byte
    25  }
    26  
    27  type deflateInflateTest struct {
    28  	in []byte
    29  }
    30  
    31  type reverseBitsTest struct {
    32  	in       uint16
    33  	bitCount uint8
    34  	out      uint16
    35  }
    36  
    37  var deflateTests = []*deflateTest{
    38  	{[]byte{}, 0, []byte{1, 0, 0, 255, 255}},
    39  	{[]byte{0x11}, -1, []byte{18, 4, 4, 0, 0, 255, 255}},
    40  	{[]byte{0x11}, DefaultCompression, []byte{18, 4, 4, 0, 0, 255, 255}},
    41  	{[]byte{0x11}, 4, []byte{18, 4, 4, 0, 0, 255, 255}},
    42  
    43  	{[]byte{0x11}, 0, []byte{0, 1, 0, 254, 255, 17, 1, 0, 0, 255, 255}},
    44  	{[]byte{0x11, 0x12}, 0, []byte{0, 2, 0, 253, 255, 17, 18, 1, 0, 0, 255, 255}},
    45  	{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}, 0,
    46  		[]byte{0, 8, 0, 247, 255, 17, 17, 17, 17, 17, 17, 17, 17, 1, 0, 0, 255, 255},
    47  	},
    48  	{[]byte{}, 2, []byte{1, 0, 0, 255, 255}},
    49  	{[]byte{0x11}, 2, []byte{18, 4, 4, 0, 0, 255, 255}},
    50  	{[]byte{0x11, 0x12}, 2, []byte{18, 20, 2, 4, 0, 0, 255, 255}},
    51  	{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}, 2, []byte{18, 132, 2, 64, 0, 0, 0, 255, 255}},
    52  	{[]byte{}, 9, []byte{1, 0, 0, 255, 255}},
    53  	{[]byte{0x11}, 9, []byte{18, 4, 4, 0, 0, 255, 255}},
    54  	{[]byte{0x11, 0x12}, 9, []byte{18, 20, 2, 4, 0, 0, 255, 255}},
    55  	{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}, 9, []byte{18, 132, 2, 64, 0, 0, 0, 255, 255}},
    56  }
    57  
    58  var deflateInflateTests = []*deflateInflateTest{
    59  	{[]byte{}},
    60  	{[]byte{0x11}},
    61  	{[]byte{0x11, 0x12}},
    62  	{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}},
    63  	{[]byte{0x11, 0x10, 0x13, 0x41, 0x21, 0x21, 0x41, 0x13, 0x87, 0x78, 0x13}},
    64  	{largeDataChunk()},
    65  }
    66  
    67  var reverseBitsTests = []*reverseBitsTest{
    68  	{1, 1, 1},
    69  	{1, 2, 2},
    70  	{1, 3, 4},
    71  	{1, 4, 8},
    72  	{1, 5, 16},
    73  	{17, 5, 17},
    74  	{257, 9, 257},
    75  	{29, 5, 23},
    76  }
    77  
    78  func largeDataChunk() []byte {
    79  	result := make([]byte, 100000)
    80  	for i := range result {
    81  		result[i] = byte(i * i & 0xFF)
    82  	}
    83  	return result
    84  }
    85  
    86  func TestBulkHash4(t *testing.T) {
    87  	for _, x := range deflateTests {
    88  		y := x.out
    89  		if len(y) < minMatchLength {
    90  			continue
    91  		}
    92  		y = append(y, y...)
    93  		for j := 4; j < len(y); j++ {
    94  			y := y[:j]
    95  			dst := make([]uint32, len(y)-minMatchLength+1)
    96  			for i := range dst {
    97  				dst[i] = uint32(i + 100)
    98  			}
    99  			bulkHash4(y, dst)
   100  			for i, got := range dst {
   101  				want := hash4(y[i:])
   102  				if got != want && got == uint32(i)+100 {
   103  					t.Errorf("Len:%d Index:%d, want 0x%08x but not modified", len(y), i, want)
   104  				} else if got != want {
   105  					t.Errorf("Len:%d Index:%d, got 0x%08x want:0x%08x", len(y), i, got, want)
   106  				}
   107  			}
   108  		}
   109  	}
   110  }
   111  
   112  func TestDeflate(t *testing.T) {
   113  	for _, h := range deflateTests {
   114  		var buf bytes.Buffer
   115  		w, err := NewWriter(&buf, h.level)
   116  		if err != nil {
   117  			t.Errorf("NewWriter: %v", err)
   118  			continue
   119  		}
   120  		w.Write(h.in)
   121  		w.Close()
   122  		if !bytes.Equal(buf.Bytes(), h.out) {
   123  			t.Errorf("Deflate(%d, %x) = \n%#v, want \n%#v", h.level, h.in, buf.Bytes(), h.out)
   124  		}
   125  	}
   126  }
   127  
   128  // A sparseReader returns a stream consisting of 0s followed by 1<<16 1s.
   129  // This tests missing hash references in a very large input.
   130  type sparseReader struct {
   131  	l   int64
   132  	cur int64
   133  }
   134  
   135  func (r *sparseReader) Read(b []byte) (n int, err error) {
   136  	if r.cur >= r.l {
   137  		return 0, io.EOF
   138  	}
   139  	n = len(b)
   140  	cur := r.cur + int64(n)
   141  	if cur > r.l {
   142  		n -= int(cur - r.l)
   143  		cur = r.l
   144  	}
   145  	for i := range b[0:n] {
   146  		if r.cur+int64(i) >= r.l-1<<16 {
   147  			b[i] = 1
   148  		} else {
   149  			b[i] = 0
   150  		}
   151  	}
   152  	r.cur = cur
   153  	return
   154  }
   155  
   156  func TestVeryLongSparseChunk(t *testing.T) {
   157  	if testing.Short() {
   158  		t.Skip("skipping sparse chunk during short test")
   159  	}
   160  	w, err := NewWriter(io.Discard, 1)
   161  	if err != nil {
   162  		t.Errorf("NewWriter: %v", err)
   163  		return
   164  	}
   165  	if _, err = io.Copy(w, &sparseReader{l: 23e8}); err != nil {
   166  		t.Errorf("Compress failed: %v", err)
   167  		return
   168  	}
   169  }
   170  
   171  type syncBuffer struct {
   172  	buf    bytes.Buffer
   173  	mu     sync.RWMutex
   174  	closed bool
   175  	ready  chan bool
   176  }
   177  
   178  func newSyncBuffer() *syncBuffer {
   179  	return &syncBuffer{ready: make(chan bool, 1)}
   180  }
   181  
   182  func (b *syncBuffer) Read(p []byte) (n int, err error) {
   183  	for {
   184  		b.mu.RLock()
   185  		n, err = b.buf.Read(p)
   186  		b.mu.RUnlock()
   187  		if n > 0 || b.closed {
   188  			return
   189  		}
   190  		<-b.ready
   191  	}
   192  }
   193  
   194  func (b *syncBuffer) signal() {
   195  	select {
   196  	case b.ready <- true:
   197  	default:
   198  	}
   199  }
   200  
   201  func (b *syncBuffer) Write(p []byte) (n int, err error) {
   202  	n, err = b.buf.Write(p)
   203  	b.signal()
   204  	return
   205  }
   206  
   207  func (b *syncBuffer) WriteMode() {
   208  	b.mu.Lock()
   209  }
   210  
   211  func (b *syncBuffer) ReadMode() {
   212  	b.mu.Unlock()
   213  	b.signal()
   214  }
   215  
   216  func (b *syncBuffer) Close() error {
   217  	b.closed = true
   218  	b.signal()
   219  	return nil
   220  }
   221  
// testSync checks that a Writer's Flush lets a reader decode everything
// written so far before the stream is closed.
//
// It compresses input at level into a syncBuffer (read back incrementally
// through NewReader) and, in parallel via io.MultiWriter, into a plain
// bytes.Buffer. The input is written in two halves: the first half is
// followed by Flush, the second by Close. After each half the test reads
// back exactly that many decompressed bytes and compares them. Finally it
// checks that the incremental reader reports io.EOF with nothing left over,
// and that the plain-buffer copy decompresses to the full input.
// Empty input is skipped.
func testSync(t *testing.T, level int, input []byte, name string) {
	if len(input) == 0 {
		return
	}

	t.Logf("--testSync %d, %d, %s", level, len(input), name)
	buf := newSyncBuffer()
	buf1 := new(bytes.Buffer)
	// Hold the syncBuffer's write lock while compressed data is produced;
	// ReadMode below releases it before each read-back.
	buf.WriteMode()
	w, err := NewWriter(io.MultiWriter(buf, buf1), level)
	if err != nil {
		t.Errorf("NewWriter: %v", err)
		return
	}
	r := NewReader(buf)

	// Write the input in two halves, reading each half back before moving on.
	for i := 0; i < 2; i++ {
		var lo, hi int
		if i == 0 {
			lo, hi = 0, (len(input)+1)/2
		} else {
			lo, hi = (len(input)+1)/2, len(input)
		}
		t.Logf("#%d: write %d-%d", i, lo, hi)
		if _, err := w.Write(input[lo:hi]); err != nil {
			t.Errorf("testSync: write: %v", err)
			return
		}
		// Flush after the first half is what this test is about: the
		// read-back below requires the first half to be fully decodable
		// before Close is called.
		if i == 0 {
			if err := w.Flush(); err != nil {
				t.Errorf("testSync: flush: %v", err)
				return
			}
		} else {
			if err := w.Close(); err != nil {
				t.Errorf("testSync: close: %v", err)
			}
		}
		buf.ReadMode()
		// One extra byte of room so an over-read would be visible in m.
		out := make([]byte, hi-lo+1)
		m, err := io.ReadAtLeast(r, out, hi-lo)
		t.Logf("#%d: read %d", i, m)
		if m != hi-lo || err != nil {
			t.Errorf("testSync/%d (%d, %d, %s): read %d: %d, %v (%d left)", i, level, len(input), name, hi-lo, m, err, buf.buf.Len())
			return
		}
		if !bytes.Equal(input[lo:hi], out[:hi-lo]) {
			t.Errorf("testSync/%d: read wrong bytes: %x vs %x", i, input[lo:hi], out[:hi-lo])
			return
		}
		// This test originally checked that after reading
		// the first half of the input, there was nothing left
		// in the read buffer (buf.buf.Len() != 0) but that is
		// not necessarily the case: the write Flush may emit
		// some extra framing bits that are not necessary
		// to process to obtain the first half of the uncompressed
		// data. The test ran correctly most of the time, because
		// the background goroutine had usually read even
		// those extra bits by now, but it's not a useful thing to
		// check.
		// NOTE: this version reads synchronously in the test goroutine;
		// the background goroutine mentioned above no longer exists.
		buf.WriteMode()
	}
	buf.ReadMode()
	// After Close, the incremental reader must report a clean io.EOF.
	out := make([]byte, 10)
	if n, err := r.Read(out); n > 0 || err != io.EOF {
		t.Errorf("testSync (%d, %d, %s): final Read: %d, %v (hex: %x)", level, len(input), name, n, err, out[0:n])
	}
	if buf.buf.Len() != 0 {
		t.Errorf("testSync (%d, %d, %s): extra data at end", level, len(input), name)
	}
	r.Close()

	// stream should work for ordinary reader too
	r = NewReader(buf1)
	out, err = io.ReadAll(r)
	if err != nil {
		t.Errorf("testSync: read: %s", err)
		return
	}
	r.Close()
	if !bytes.Equal(input, out) {
		t.Errorf("testSync: decompress(compress(data)) != data: level=%d input=%s", level, name)
	}
}
   307  
   308  func testToFromWithLevelAndLimit(t *testing.T, level int, input []byte, name string, limit int) {
   309  	var buffer bytes.Buffer
   310  	w, err := NewWriter(&buffer, level)
   311  	if err != nil {
   312  		t.Errorf("NewWriter: %v", err)
   313  		return
   314  	}
   315  	w.Write(input)
   316  	w.Close()
   317  	if limit > 0 && buffer.Len() > limit {
   318  		t.Errorf("level: %d, len(compress(data)) = %d > limit = %d", level, buffer.Len(), limit)
   319  		return
   320  	}
   321  	if limit > 0 {
   322  		t.Logf("level: %d, size:%.2f%%, %d b\n", level, float64(buffer.Len()*100)/float64(limit), buffer.Len())
   323  	}
   324  	r := NewReader(&buffer)
   325  	out, err := io.ReadAll(r)
   326  	if err != nil {
   327  		t.Errorf("read: %s", err)
   328  		return
   329  	}
   330  	r.Close()
   331  	if !bytes.Equal(input, out) {
   332  		t.Errorf("decompress(compress(data)) != data: level=%d input=%s", level, name)
   333  		return
   334  	}
   335  	testSync(t, level, input, name)
   336  }
   337  
   338  func testToFromWithLimit(t *testing.T, input []byte, name string, limit [11]int) {
   339  	for i := 0; i < 10; i++ {
   340  		testToFromWithLevelAndLimit(t, i, input, name, limit[i])
   341  	}
   342  	// Test HuffmanCompression
   343  	testToFromWithLevelAndLimit(t, -2, input, name, limit[10])
   344  }
   345  
   346  func TestDeflateInflate(t *testing.T) {
   347  	t.Parallel()
   348  	for i, h := range deflateInflateTests {
   349  		if testing.Short() && len(h.in) > 10000 {
   350  			continue
   351  		}
   352  		testToFromWithLimit(t, h.in, fmt.Sprintf("#%d", i), [11]int{})
   353  	}
   354  }
   355  
   356  func TestReverseBits(t *testing.T) {
   357  	for _, h := range reverseBitsTests {
   358  		if v := reverseBits(h.in, h.bitCount); v != h.out {
   359  			t.Errorf("reverseBits(%v,%v) = %v, want %v",
   360  				h.in, h.bitCount, v, h.out)
   361  		}
   362  	}
   363  }
   364  
   365  type deflateInflateStringTest struct {
   366  	filename string
   367  	label    string
   368  	limit    [11]int
   369  }
   370  
   371  var deflateInflateStringTests = []deflateInflateStringTest{
   372  	{
   373  		"../testdata/e.txt",
   374  		"2.718281828...",
   375  		[...]int{100018, 50650, 50960, 51150, 50930, 50790, 50790, 50790, 50790, 50790, 43683},
   376  	},
   377  	{
   378  		"../../testdata/Isaac.Newton-Opticks.txt",
   379  		"Isaac.Newton-Opticks",
   380  		[...]int{567248, 218338, 198211, 193152, 181100, 175427, 175427, 173597, 173422, 173422, 325240},
   381  	},
   382  }
   383  
   384  func TestDeflateInflateString(t *testing.T) {
   385  	t.Parallel()
   386  	if testing.Short() && testenv.Builder() == "" {
   387  		t.Skip("skipping in short mode")
   388  	}
   389  	for _, test := range deflateInflateStringTests {
   390  		gold, err := os.ReadFile(test.filename)
   391  		if err != nil {
   392  			t.Error(err)
   393  		}
   394  		testToFromWithLimit(t, gold, test.label, test.limit)
   395  		if testing.Short() {
   396  			break
   397  		}
   398  	}
   399  }
   400  
   401  func TestReaderDict(t *testing.T) {
   402  	const (
   403  		dict = "hello world"
   404  		text = "hello again world"
   405  	)
   406  	var b bytes.Buffer
   407  	w, err := NewWriter(&b, 5)
   408  	if err != nil {
   409  		t.Fatalf("NewWriter: %v", err)
   410  	}
   411  	w.Write([]byte(dict))
   412  	w.Flush()
   413  	b.Reset()
   414  	w.Write([]byte(text))
   415  	w.Close()
   416  
   417  	r := NewReaderDict(&b, []byte(dict))
   418  	data, err := io.ReadAll(r)
   419  	if err != nil {
   420  		t.Fatal(err)
   421  	}
   422  	if string(data) != "hello again world" {
   423  		t.Fatalf("read returned %q want %q", string(data), text)
   424  	}
   425  }
   426  
   427  func TestWriterDict(t *testing.T) {
   428  	const (
   429  		dict = "hello world"
   430  		text = "hello again world"
   431  	)
   432  	var b bytes.Buffer
   433  	w, err := NewWriter(&b, 5)
   434  	if err != nil {
   435  		t.Fatalf("NewWriter: %v", err)
   436  	}
   437  	w.Write([]byte(dict))
   438  	w.Flush()
   439  	b.Reset()
   440  	w.Write([]byte(text))
   441  	w.Close()
   442  
   443  	var b1 bytes.Buffer
   444  	w, _ = NewWriterDict(&b1, 5, []byte(dict))
   445  	w.Write([]byte(text))
   446  	w.Close()
   447  
   448  	if !bytes.Equal(b1.Bytes(), b.Bytes()) {
   449  		t.Fatalf("writer wrote %q want %q", b1.Bytes(), b.Bytes())
   450  	}
   451  }
   452  
   453  // See https://golang.org/issue/2508
   454  func TestRegression2508(t *testing.T) {
   455  	if testing.Short() {
   456  		t.Logf("test disabled with -short")
   457  		return
   458  	}
   459  	w, err := NewWriter(io.Discard, 1)
   460  	if err != nil {
   461  		t.Fatalf("NewWriter: %v", err)
   462  	}
   463  	buf := make([]byte, 1024)
   464  	for i := 0; i < 131072; i++ {
   465  		if _, err := w.Write(buf); err != nil {
   466  			t.Fatalf("writer failed: %v", err)
   467  		}
   468  	}
   469  	w.Close()
   470  }
   471  
   472  func TestWriterReset(t *testing.T) {
   473  	t.Parallel()
   474  	for level := 0; level <= 9; level++ {
   475  		if testing.Short() && level > 1 {
   476  			break
   477  		}
   478  		w, err := NewWriter(io.Discard, level)
   479  		if err != nil {
   480  			t.Fatalf("NewWriter: %v", err)
   481  		}
   482  		buf := []byte("hello world")
   483  		n := 1024
   484  		if testing.Short() {
   485  			n = 10
   486  		}
   487  		for i := 0; i < n; i++ {
   488  			w.Write(buf)
   489  		}
   490  		w.Reset(io.Discard)
   491  
   492  		wref, err := NewWriter(io.Discard, level)
   493  		if err != nil {
   494  			t.Fatalf("NewWriter: %v", err)
   495  		}
   496  
   497  		// DeepEqual doesn't compare functions.
   498  		w.d.fill, wref.d.fill = nil, nil
   499  		w.d.step, wref.d.step = nil, nil
   500  		w.d.bulkHasher, wref.d.bulkHasher = nil, nil
   501  		w.d.bestSpeed, wref.d.bestSpeed = nil, nil
   502  		// hashMatch is always overwritten when used.
   503  		copy(w.d.hashMatch[:], wref.d.hashMatch[:])
   504  		if len(w.d.tokens) != 0 {
   505  			t.Errorf("level %d Writer not reset after Reset. %d tokens were present", level, len(w.d.tokens))
   506  		}
   507  		// As long as the length is 0, we don't care about the content.
   508  		w.d.tokens = wref.d.tokens
   509  
   510  		// We don't care if there are values in the window, as long as it is at d.index is 0
   511  		w.d.window = wref.d.window
   512  		if !reflect.DeepEqual(w, wref) {
   513  			t.Errorf("level %d Writer not reset after Reset", level)
   514  		}
   515  	}
   516  
   517  	levels := []int{0, 1, 2, 5, 9}
   518  	for _, level := range levels {
   519  		t.Run(fmt.Sprint(level), func(t *testing.T) {
   520  			testResetOutput(t, level, nil)
   521  		})
   522  	}
   523  
   524  	t.Run("dict", func(t *testing.T) {
   525  		for _, level := range levels {
   526  			t.Run(fmt.Sprint(level), func(t *testing.T) {
   527  				testResetOutput(t, level, nil)
   528  			})
   529  		}
   530  	})
   531  }
   532  
   533  func testResetOutput(t *testing.T, level int, dict []byte) {
   534  	writeData := func(w *Writer) {
   535  		msg := []byte("now is the time for all good gophers")
   536  		w.Write(msg)
   537  		w.Flush()
   538  
   539  		hello := []byte("hello world")
   540  		for i := 0; i < 1024; i++ {
   541  			w.Write(hello)
   542  		}
   543  
   544  		fill := bytes.Repeat([]byte("x"), 65000)
   545  		w.Write(fill)
   546  	}
   547  
   548  	buf := new(bytes.Buffer)
   549  	var w *Writer
   550  	var err error
   551  	if dict == nil {
   552  		w, err = NewWriter(buf, level)
   553  	} else {
   554  		w, err = NewWriterDict(buf, level, dict)
   555  	}
   556  	if err != nil {
   557  		t.Fatalf("NewWriter: %v", err)
   558  	}
   559  
   560  	writeData(w)
   561  	w.Close()
   562  	out1 := buf.Bytes()
   563  
   564  	buf2 := new(bytes.Buffer)
   565  	w.Reset(buf2)
   566  	writeData(w)
   567  	w.Close()
   568  	out2 := buf2.Bytes()
   569  
   570  	if len(out1) != len(out2) {
   571  		t.Errorf("got %d, expected %d bytes", len(out2), len(out1))
   572  		return
   573  	}
   574  	if !bytes.Equal(out1, out2) {
   575  		mm := 0
   576  		for i, b := range out1[:len(out2)] {
   577  			if b != out2[i] {
   578  				t.Errorf("mismatch index %d: %#02x, expected %#02x", i, out2[i], b)
   579  			}
   580  			mm++
   581  			if mm == 10 {
   582  				t.Fatal("Stopping")
   583  			}
   584  		}
   585  	}
   586  	t.Logf("got %d bytes", len(out1))
   587  }
   588  
   589  // TestBestSpeed tests that round-tripping through deflate and then inflate
   590  // recovers the original input. The Write sizes are near the thresholds in the
   591  // compressor.encSpeed method (0, 16, 128), as well as near maxStoreBlockSize
   592  // (65535).
   593  func TestBestSpeed(t *testing.T) {
   594  	t.Parallel()
   595  	abc := make([]byte, 128)
   596  	for i := range abc {
   597  		abc[i] = byte(i)
   598  	}
   599  	abcabc := bytes.Repeat(abc, 131072/len(abc))
   600  	var want []byte
   601  
   602  	testCases := [][]int{
   603  		{65536, 0},
   604  		{65536, 1},
   605  		{65536, 1, 256},
   606  		{65536, 1, 65536},
   607  		{65536, 14},
   608  		{65536, 15},
   609  		{65536, 16},
   610  		{65536, 16, 256},
   611  		{65536, 16, 65536},
   612  		{65536, 127},
   613  		{65536, 128},
   614  		{65536, 128, 256},
   615  		{65536, 128, 65536},
   616  		{65536, 129},
   617  		{65536, 65536, 256},
   618  		{65536, 65536, 65536},
   619  	}
   620  
   621  	for i, tc := range testCases {
   622  		if i >= 3 && testing.Short() {
   623  			break
   624  		}
   625  		for _, firstN := range []int{1, 65534, 65535, 65536, 65537, 131072} {
   626  			tc[0] = firstN
   627  		outer:
   628  			for _, flush := range []bool{false, true} {
   629  				buf := new(bytes.Buffer)
   630  				want = want[:0]
   631  
   632  				w, err := NewWriter(buf, BestSpeed)
   633  				if err != nil {
   634  					t.Errorf("i=%d, firstN=%d, flush=%t: NewWriter: %v", i, firstN, flush, err)
   635  					continue
   636  				}
   637  				for _, n := range tc {
   638  					want = append(want, abcabc[:n]...)
   639  					if _, err := w.Write(abcabc[:n]); err != nil {
   640  						t.Errorf("i=%d, firstN=%d, flush=%t: Write: %v", i, firstN, flush, err)
   641  						continue outer
   642  					}
   643  					if !flush {
   644  						continue
   645  					}
   646  					if err := w.Flush(); err != nil {
   647  						t.Errorf("i=%d, firstN=%d, flush=%t: Flush: %v", i, firstN, flush, err)
   648  						continue outer
   649  					}
   650  				}
   651  				if err := w.Close(); err != nil {
   652  					t.Errorf("i=%d, firstN=%d, flush=%t: Close: %v", i, firstN, flush, err)
   653  					continue
   654  				}
   655  
   656  				r := NewReader(buf)
   657  				got, err := io.ReadAll(r)
   658  				if err != nil {
   659  					t.Errorf("i=%d, firstN=%d, flush=%t: ReadAll: %v", i, firstN, flush, err)
   660  					continue
   661  				}
   662  				r.Close()
   663  
   664  				if !bytes.Equal(got, want) {
   665  					t.Errorf("i=%d, firstN=%d, flush=%t: corruption during deflate-then-inflate", i, firstN, flush)
   666  					continue
   667  				}
   668  			}
   669  		}
   670  	}
   671  }
   672  
   673  var errIO = errors.New("IO error")
   674  
   675  // failWriter fails with errIO exactly at the nth call to Write.
   676  type failWriter struct{ n int }
   677  
   678  func (w *failWriter) Write(b []byte) (int, error) {
   679  	w.n--
   680  	if w.n == -1 {
   681  		return 0, errIO
   682  	}
   683  	return len(b), nil
   684  }
   685  
   686  func TestWriterPersistentError(t *testing.T) {
   687  	t.Parallel()
   688  	d, err := os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt")
   689  	if err != nil {
   690  		t.Fatalf("ReadFile: %v", err)
   691  	}
   692  	d = d[:10000] // Keep this test short
   693  
   694  	zw, err := NewWriter(nil, DefaultCompression)
   695  	if err != nil {
   696  		t.Fatalf("NewWriter: %v", err)
   697  	}
   698  
   699  	// Sweep over the threshold at which an error is returned.
   700  	// The variable i makes it such that the ith call to failWriter.Write will
   701  	// return errIO. Since failWriter errors are not persistent, we must ensure
   702  	// that flate.Writer errors are persistent.
   703  	for i := 0; i < 1000; i++ {
   704  		fw := &failWriter{i}
   705  		zw.Reset(fw)
   706  
   707  		_, werr := zw.Write(d)
   708  		cerr := zw.Close()
   709  		if werr != errIO && werr != nil {
   710  			t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO)
   711  		}
   712  		if cerr != errIO && fw.n < 0 {
   713  			t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO)
   714  		}
   715  		if fw.n >= 0 {
   716  			// At this point, the failure threshold was sufficiently high enough
   717  			// that we wrote the whole stream without any errors.
   718  			return
   719  		}
   720  	}
   721  }
   722  
   723  func TestBestSpeedMatch(t *testing.T) {
   724  	t.Parallel()
   725  	cases := []struct {
   726  		previous, current []byte
   727  		t, s, want        int32
   728  	}{{
   729  		previous: []byte{0, 0, 0, 1, 2},
   730  		current:  []byte{3, 4, 5, 0, 1, 2, 3, 4, 5},
   731  		t:        -3,
   732  		s:        3,
   733  		want:     6,
   734  	}, {
   735  		previous: []byte{0, 0, 0, 1, 2},
   736  		current:  []byte{2, 4, 5, 0, 1, 2, 3, 4, 5},
   737  		t:        -3,
   738  		s:        3,
   739  		want:     3,
   740  	}, {
   741  		previous: []byte{0, 0, 0, 1, 1},
   742  		current:  []byte{3, 4, 5, 0, 1, 2, 3, 4, 5},
   743  		t:        -3,
   744  		s:        3,
   745  		want:     2,
   746  	}, {
   747  		previous: []byte{0, 0, 0, 1, 2},
   748  		current:  []byte{2, 2, 2, 2, 1, 2, 3, 4, 5},
   749  		t:        -1,
   750  		s:        0,
   751  		want:     4,
   752  	}, {
   753  		previous: []byte{0, 0, 0, 1, 2, 3, 4, 5, 2, 2},
   754  		current:  []byte{2, 2, 2, 2, 1, 2, 3, 4, 5},
   755  		t:        -7,
   756  		s:        4,
   757  		want:     5,
   758  	}, {
   759  		previous: []byte{9, 9, 9, 9, 9},
   760  		current:  []byte{2, 2, 2, 2, 1, 2, 3, 4, 5},
   761  		t:        -1,
   762  		s:        0,
   763  		want:     0,
   764  	}, {
   765  		previous: []byte{9, 9, 9, 9, 9},
   766  		current:  []byte{9, 2, 2, 2, 1, 2, 3, 4, 5},
   767  		t:        0,
   768  		s:        1,
   769  		want:     0,
   770  	}, {
   771  		previous: []byte{},
   772  		current:  []byte{9, 2, 2, 2, 1, 2, 3, 4, 5},
   773  		t:        -5,
   774  		s:        1,
   775  		want:     0,
   776  	}, {
   777  		previous: []byte{},
   778  		current:  []byte{9, 2, 2, 2, 1, 2, 3, 4, 5},
   779  		t:        -1,
   780  		s:        1,
   781  		want:     0,
   782  	}, {
   783  		previous: []byte{},
   784  		current:  []byte{2, 2, 2, 2, 1, 2, 3, 4, 5},
   785  		t:        0,
   786  		s:        1,
   787  		want:     3,
   788  	}, {
   789  		previous: []byte{3, 4, 5},
   790  		current:  []byte{3, 4, 5},
   791  		t:        -3,
   792  		s:        0,
   793  		want:     3,
   794  	}, {
   795  		previous: make([]byte, 1000),
   796  		current:  make([]byte, 1000),
   797  		t:        -1000,
   798  		s:        0,
   799  		want:     maxMatchLength - 4,
   800  	}, {
   801  		previous: make([]byte, 200),
   802  		current:  make([]byte, 500),
   803  		t:        -200,
   804  		s:        0,
   805  		want:     maxMatchLength - 4,
   806  	}, {
   807  		previous: make([]byte, 200),
   808  		current:  make([]byte, 500),
   809  		t:        0,
   810  		s:        1,
   811  		want:     maxMatchLength - 4,
   812  	}, {
   813  		previous: make([]byte, maxMatchLength-4),
   814  		current:  make([]byte, 500),
   815  		t:        -(maxMatchLength - 4),
   816  		s:        0,
   817  		want:     maxMatchLength - 4,
   818  	}, {
   819  		previous: make([]byte, 200),
   820  		current:  make([]byte, 500),
   821  		t:        -200,
   822  		s:        400,
   823  		want:     100,
   824  	}, {
   825  		previous: make([]byte, 10),
   826  		current:  make([]byte, 500),
   827  		t:        200,
   828  		s:        400,
   829  		want:     100,
   830  	}}
   831  	for i, c := range cases {
   832  		e := deflateFast{prev: c.previous}
   833  		got := e.matchLen(c.s, c.t, c.current)
   834  		if got != c.want {
   835  			t.Errorf("Test %d: match length, want %d, got %d", i, c.want, got)
   836  		}
   837  	}
   838  }
   839  
   840  func TestBestSpeedMaxMatchOffset(t *testing.T) {
   841  	t.Parallel()
   842  	const abc, xyz = "abcdefgh", "stuvwxyz"
   843  	for _, matchBefore := range []bool{false, true} {
   844  		for _, extra := range []int{0, inputMargin - 1, inputMargin, inputMargin + 1, 2 * inputMargin} {
   845  			for offsetAdj := -5; offsetAdj <= +5; offsetAdj++ {
   846  				report := func(desc string, err error) {
   847  					t.Errorf("matchBefore=%t, extra=%d, offsetAdj=%d: %s%v",
   848  						matchBefore, extra, offsetAdj, desc, err)
   849  				}
   850  
   851  				offset := maxMatchOffset + offsetAdj
   852  
   853  				// Make src to be a []byte of the form
   854  				//	"%s%s%s%s%s" % (abc, zeros0, xyzMaybe, abc, zeros1)
   855  				// where:
   856  				//	zeros0 is approximately maxMatchOffset zeros.
   857  				//	xyzMaybe is either xyz or the empty string.
   858  				//	zeros1 is between 0 and 30 zeros.
   859  				// The difference between the two abc's will be offset, which
   860  				// is maxMatchOffset plus or minus a small adjustment.
   861  				src := make([]byte, offset+len(abc)+extra)
   862  				copy(src, abc)
   863  				if !matchBefore {
   864  					copy(src[offset-len(xyz):], xyz)
   865  				}
   866  				copy(src[offset:], abc)
   867  
   868  				buf := new(bytes.Buffer)
   869  				w, err := NewWriter(buf, BestSpeed)
   870  				if err != nil {
   871  					report("NewWriter: ", err)
   872  					continue
   873  				}
   874  				if _, err := w.Write(src); err != nil {
   875  					report("Write: ", err)
   876  					continue
   877  				}
   878  				if err := w.Close(); err != nil {
   879  					report("Writer.Close: ", err)
   880  					continue
   881  				}
   882  
   883  				r := NewReader(buf)
   884  				dst, err := io.ReadAll(r)
   885  				r.Close()
   886  				if err != nil {
   887  					report("ReadAll: ", err)
   888  					continue
   889  				}
   890  
   891  				if !bytes.Equal(dst, src) {
   892  					report("", fmt.Errorf("bytes differ after round-tripping"))
   893  					continue
   894  				}
   895  			}
   896  		}
   897  	}
   898  }
   899  
   900  func TestBestSpeedShiftOffsets(t *testing.T) {
   901  	// Test if shiftoffsets properly preserves matches and resets out-of-range matches
   902  	// seen in https://github.com/golang/go/issues/4142
   903  	enc := newDeflateFast()
   904  
   905  	// testData may not generate internal matches.
   906  	testData := make([]byte, 32)
   907  	rng := rand.New(rand.NewSource(0))
   908  	for i := range testData {
   909  		testData[i] = byte(rng.Uint32())
   910  	}
   911  
   912  	// Encode the testdata with clean state.
   913  	// Second part should pick up matches from the first block.
   914  	wantFirstTokens := len(enc.encode(nil, testData))
   915  	wantSecondTokens := len(enc.encode(nil, testData))
   916  
   917  	if wantFirstTokens <= wantSecondTokens {
   918  		t.Fatalf("test needs matches between inputs to be generated")
   919  	}
   920  	// Forward the current indicator to before wraparound.
   921  	enc.cur = bufferReset - int32(len(testData))
   922  
   923  	// Part 1 before wrap, should match clean state.
   924  	got := len(enc.encode(nil, testData))
   925  	if wantFirstTokens != got {
   926  		t.Errorf("got %d, want %d tokens", got, wantFirstTokens)
   927  	}
   928  
   929  	// Verify we are about to wrap.
   930  	if enc.cur != bufferReset {
   931  		t.Errorf("got %d, want e.cur to be at bufferReset (%d)", enc.cur, bufferReset)
   932  	}
   933  
   934  	// Part 2 should match clean state as well even if wrapped.
   935  	got = len(enc.encode(nil, testData))
   936  	if wantSecondTokens != got {
   937  		t.Errorf("got %d, want %d token", got, wantSecondTokens)
   938  	}
   939  
   940  	// Verify that we wrapped.
   941  	if enc.cur >= bufferReset {
   942  		t.Errorf("want e.cur to be < bufferReset (%d), got %d", bufferReset, enc.cur)
   943  	}
   944  
   945  	// Forward the current buffer, leaving the matches at the bottom.
   946  	enc.cur = bufferReset
   947  	enc.shiftOffsets()
   948  
   949  	// Ensure that no matches were picked up.
   950  	got = len(enc.encode(nil, testData))
   951  	if wantFirstTokens != got {
   952  		t.Errorf("got %d, want %d tokens", got, wantFirstTokens)
   953  	}
   954  }
   955  
   956  func TestMaxStackSize(t *testing.T) {
   957  	// This test must not run in parallel with other tests as debug.SetMaxStack
   958  	// affects all goroutines.
   959  	n := debug.SetMaxStack(1 << 16)
   960  	defer debug.SetMaxStack(n)
   961  
   962  	var wg sync.WaitGroup
   963  	defer wg.Wait()
   964  
   965  	b := make([]byte, 1<<20)
   966  	for level := HuffmanOnly; level <= BestCompression; level++ {
   967  		// Run in separate goroutine to increase probability of stack regrowth.
   968  		wg.Add(1)
   969  		go func(level int) {
   970  			defer wg.Done()
   971  			zw, err := NewWriter(io.Discard, level)
   972  			if err != nil {
   973  				t.Errorf("level %d, NewWriter() = %v, want nil", level, err)
   974  			}
   975  			if n, err := zw.Write(b); n != len(b) || err != nil {
   976  				t.Errorf("level %d, Write() = (%d, %v), want (%d, nil)", level, n, err, len(b))
   977  			}
   978  			if err := zw.Close(); err != nil {
   979  				t.Errorf("level %d, Close() = %v, want nil", level, err)
   980  			}
   981  			zw.Reset(io.Discard)
   982  		}(level)
   983  	}
   984  }
   985  

View as plain text