// Copyright 2012 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package httptest import ( "fmt" "io" "net/http" "testing" ) func TestRecorder(t *testing.T) { type checkFunc func(*ResponseRecorder) error check := func(fns ...checkFunc) []checkFunc { return fns } hasStatus := func(wantCode int) checkFunc { return func(rec *ResponseRecorder) error { if rec.Code != wantCode { return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode) } return nil } } hasResultStatus := func(want string) checkFunc { return func(rec *ResponseRecorder) error { if rec.Result().Status != want { return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want) } return nil } } hasResultStatusCode := func(wantCode int) checkFunc { return func(rec *ResponseRecorder) error { if rec.Result().StatusCode != wantCode { return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode) } return nil } } hasResultContents := func(want string) checkFunc { return func(rec *ResponseRecorder) error { contentBytes, err := io.ReadAll(rec.Result().Body) if err != nil { return err } contents := string(contentBytes) if contents != want { return fmt.Errorf("Result().Body = %s; want %s", contents, want) } return nil } } hasContents := func(want string) checkFunc { return func(rec *ResponseRecorder) error { if rec.Body.String() != want { return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want) } return nil } } hasFlush := func(want bool) checkFunc { return func(rec *ResponseRecorder) error { if rec.Flushed != want { return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want) } return nil } } hasOldHeader := func(key, want string) checkFunc { return func(rec *ResponseRecorder) error { if got := rec.HeaderMap.Get(key); got != want { return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want) } return nil } } hasHeader := func(key, want string) checkFunc { return func(rec *ResponseRecorder) error { if got := rec.Result().Header.Get(key); got != want { return fmt.Errorf("final header %s = %q; want %q", key, got, want) } return nil } } hasNotHeaders := func(keys ...string) checkFunc { return func(rec *ResponseRecorder) error { for _, k := range keys { v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)] if ok { return fmt.Errorf("unexpected header %s with value %q", k, v) } } return nil } } hasTrailer := func(key, want string) checkFunc { return func(rec *ResponseRecorder) error { if got := rec.Result().Trailer.Get(key); got != want { return fmt.Errorf("trailer %s = %q; want %q", key, got, want) } return nil } } hasNotTrailers := func(keys ...string) checkFunc { return func(rec *ResponseRecorder) error { trailers := rec.Result().Trailer for _, k := range keys { _, ok := trailers[http.CanonicalHeaderKey(k)] if ok { return fmt.Errorf("unexpected trailer %s", k) } } return nil } } hasContentLength := func(length int64) checkFunc { return func(rec *ResponseRecorder) error { if got := rec.Result().ContentLength; got != length { return fmt.Errorf("ContentLength = %d; want %d", got, length) } return nil } } for _, tt := range [...]struct { name string h func(w http.ResponseWriter, r *http.Request) checks []checkFunc }{ { "200 default", func(w http.ResponseWriter, r *http.Request) {}, check(hasStatus(200), hasContents("")), }, { "first code only", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(201) w.WriteHeader(202) w.Write([]byte("hi")) }, check(hasStatus(201), hasContents("hi")), }, { "write sends 200", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hi first")) w.WriteHeader(201) w.WriteHeader(202) }, check(hasStatus(200), hasContents("hi first"), hasFlush(false)), }, { "write string", func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "hi first") }, check( hasStatus(200), hasContents("hi first"), hasFlush(false), hasHeader("Content-Type", "text/plain; charset=utf-8"), ), }, { "flush", func(w http.ResponseWriter, r *http.Request) { w.(http.Flusher).Flush() // also sends a 200 w.WriteHeader(201) }, check(hasStatus(200), hasFlush(true), hasContentLength(-1)), }, { "Content-Type detection", func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "") }, check(hasHeader("Content-Type", "text/html; charset=utf-8")), }, { "no Content-Type detection with Transfer-Encoding", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Transfer-Encoding", "some encoding") io.WriteString(w, "") }, check(hasHeader("Content-Type", "")), // no header }, { "no Content-Type detection if set explicitly", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "some/type") io.WriteString(w, "") }, check(hasHeader("Content-Type", "some/type")), }, { "Content-Type detection doesn't crash if HeaderMap is nil", func(w http.ResponseWriter, r *http.Request) { // Act as if the user wrote new(httptest.ResponseRecorder) // rather than using NewRecorder (which initializes // HeaderMap) w.(*ResponseRecorder).HeaderMap = nil io.WriteString(w, "") }, check(hasHeader("Content-Type", "text/html; charset=utf-8")), }, { "Header is not changed after write", func(w http.ResponseWriter, r *http.Request) { hdr := w.Header() hdr.Set("Key", "correct") w.WriteHeader(200) hdr.Set("Key", "incorrect") }, check(hasHeader("Key", "correct")), }, { "Trailer headers are correctly recorded", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Non-Trailer", "correct") w.Header().Set("Trailer", "Trailer-A") w.Header().Add("Trailer", "Trailer-B") w.Header().Add("Trailer", "Trailer-C") io.WriteString(w, "") w.Header().Set("Non-Trailer", "incorrect") w.Header().Set("Trailer-A", "valuea") w.Header().Set("Trailer-C", "valuec") w.Header().Set("Trailer-NotDeclared", "should be omitted") w.Header().Set("Trailer:Trailer-D", "with prefix") }, check( hasStatus(200), hasHeader("Content-Type", "text/html; charset=utf-8"), hasHeader("Non-Trailer", "correct"), hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"), hasTrailer("Trailer-A", "valuea"), hasTrailer("Trailer-C", "valuec"), hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"), hasTrailer("Trailer-D", "with prefix"), ), }, { "Header set without any write", // Issue 15560 func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Foo", "1") // Simulate somebody using // new(ResponseRecorder) instead of // using the constructor which sets // this to 200 w.(*ResponseRecorder).Code = 0 }, check( hasOldHeader("X-Foo", "1"), hasStatus(0), hasHeader("X-Foo", "1"), hasResultStatus("200 OK"), hasResultStatusCode(200), ), }, { "HeaderMap vs FinalHeaders", // more for Issue 15560 func(w http.ResponseWriter, r *http.Request) { h := w.Header() h.Set("X-Foo", "1") w.Write([]byte("hi")) h.Set("X-Foo", "2") h.Set("X-Bar", "2") }, check( hasOldHeader("X-Foo", "2"), hasOldHeader("X-Bar", "2"), hasHeader("X-Foo", "1"), hasNotHeaders("X-Bar"), ), }, { "setting Content-Length header", func(w http.ResponseWriter, r *http.Request) { body := "Some body" contentLength := fmt.Sprintf("%d", len(body)) w.Header().Set("Content-Length", contentLength) io.WriteString(w, body) }, check(hasStatus(200), hasContents("Some body"), hasContentLength(9)), }, { "nil ResponseRecorder.Body", // Issue 26642 func(w http.ResponseWriter, r *http.Request) { w.(*ResponseRecorder).Body = nil io.WriteString(w, "hi") }, check(hasResultContents("")), // check we don't crash reading the body }, } { t.Run(tt.name, func(t *testing.T) { r, _ := http.NewRequest("GET", "http://foo.com/", nil) h := http.HandlerFunc(tt.h) rec := NewRecorder() h.ServeHTTP(rec, r) for _, check := range tt.checks { if err := check(rec); err != nil { t.Error(err) } } }) } } // issue 39017 - disallow Content-Length values such as "+3" func TestParseContentLength(t *testing.T) { tests := []struct { cl string want int64 }{ { cl: "3", want: 3, }, { cl: "+3", want: -1, }, { cl: "-3", want: -1, }, { // max int64, for safe conversion before returning cl: "9223372036854775807", want: 9223372036854775807, }, { cl: "9223372036854775808", want: -1, }, } for _, tt := range tests { if got := parseContentLength(tt.cl); got != tt.want { t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want) } } } // Ensure that httptest.Recorder panics when given a non-3 digit (XXX) // status HTTP code. See https://golang.org/issues/45353 func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) { badCodes := []int{ -100, 0, 99, 1000, 20000, } for _, badCode := range badCodes { badCode := badCode t.Run(fmt.Sprintf("Code=%d", badCode), func(t *testing.T) { defer func() { if r := recover(); r == nil { t.Fatal("Expected a panic") } }() handler := func(rw http.ResponseWriter, _ *http.Request) { rw.WriteHeader(badCode) } r, _ := http.NewRequest("GET", "http://example.org/", nil) rw := NewRecorder() handler(rw, r) }) } }