pax_global_header00006660000000000000000000000064145137402530014516gustar00rootroot0000000000000052 comment=02b3ca3432ded5ecc34ba10c90017113835a1ec0 handlers-1.5.2/000077500000000000000000000000001451374025300133235ustar00rootroot00000000000000handlers-1.5.2/.editorconfig000066400000000000000000000005041451374025300157770ustar00rootroot00000000000000; https://editorconfig.org/ root = true [*] insert_final_newline = true charset = utf-8 trim_trailing_whitespace = true indent_style = space indent_size = 2 [{Makefile,go.mod,go.sum,*.go,.gitmodules}] indent_style = tab indent_size = 4 [*.md] indent_size = 4 trim_trailing_whitespace = false eclint_indent_style = unsethandlers-1.5.2/.github/000077500000000000000000000000001451374025300146635ustar00rootroot00000000000000handlers-1.5.2/.github/workflows/000077500000000000000000000000001451374025300167205ustar00rootroot00000000000000handlers-1.5.2/.github/workflows/issues.yml000066400000000000000000000007351451374025300207630ustar00rootroot00000000000000# Add all the issues created to the project. name: Add issue or pull request to Project on: issues: types: - opened pull_request_target: types: - opened - reopened jobs: add-to-project: runs-on: ubuntu-latest steps: - name: Add issue to project uses: actions/add-to-project@v0.5.0 with: project-url: https://github.com/orgs/gorilla/projects/4 github-token: ${{ secrets.ADD_TO_PROJECT_TOKEN }} handlers-1.5.2/.github/workflows/security.yml000066400000000000000000000013561451374025300213170ustar00rootroot00000000000000name: Security on: push: branches: - main pull_request: branches: - main permissions: contents: read jobs: scan: strategy: matrix: go: ['1.20','1.21'] fail-fast: true runs-on: ubuntu-latest steps: - name: Checkout Code uses: actions/checkout@v3 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} cache: false - name: Run GoSec uses: securego/gosec@master with: args: -exclude-dir examples ./... 
- name: Run GoVulnCheck uses: golang/govulncheck-action@v1 with: go-version-input: ${{ matrix.go }} go-package: ./... handlers-1.5.2/.github/workflows/test.yml000066400000000000000000000013701451374025300204230ustar00rootroot00000000000000name: Test on: push: branches: - main pull_request: branches: - main permissions: contents: read jobs: unit: strategy: matrix: go: ['1.20','1.21'] os: [ubuntu-latest, macos-latest, windows-latest] fail-fast: true runs-on: ${{ matrix.os }} steps: - name: Checkout Code uses: actions/checkout@v3 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} cache: false - name: Run Tests run: go test -race -cover -coverprofile=coverage -covermode=atomic -v ./... - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: files: ./coverage handlers-1.5.2/.github/workflows/verify.yml000066400000000000000000000011501451374025300207440ustar00rootroot00000000000000name: Verify on: push: branches: - main pull_request: branches: - main permissions: contents: read jobs: lint: strategy: matrix: go: ['1.20','1.21'] fail-fast: true runs-on: ubuntu-latest steps: - name: Checkout Code uses: actions/checkout@v3 - name: Setup Go ${{ matrix.go }} uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} cache: false - name: Run GolangCI-Lint uses: golangci/golangci-lint-action@v3 with: version: v1.53 args: --timeout=5m handlers-1.5.2/.gitignore000066400000000000000000000000741451374025300153140ustar00rootroot00000000000000# Output of the go test coverage tool coverage.coverprofile handlers-1.5.2/LICENSE000066400000000000000000000027111451374025300143310ustar00rootroot00000000000000Copyright (c) 2023 The Gorilla Authors. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
handlers-1.5.2/Makefile000066400000000000000000000016621451374025300147700ustar00rootroot00000000000000GO_LINT=$(shell which golangci-lint 2> /dev/null || echo '') GO_LINT_URI=github.com/golangci/golangci-lint/cmd/golangci-lint@latest GO_SEC=$(shell which gosec 2> /dev/null || echo '') GO_SEC_URI=github.com/securego/gosec/v2/cmd/gosec@latest GO_VULNCHECK=$(shell which govulncheck 2> /dev/null || echo '') GO_VULNCHECK_URI=golang.org/x/vuln/cmd/govulncheck@latest .PHONY: verify verify: sec govulncheck lint test .PHONY: lint lint: $(if $(GO_LINT), ,go install $(GO_LINT_URI)) @echo "##### Running golangci-lint #####" golangci-lint run -v .PHONY: sec sec: $(if $(GO_SEC), ,go install $(GO_SEC_URI)) @echo "##### Running gosec #####" gosec ./... .PHONY: govulncheck govulncheck: $(if $(GO_VULNCHECK), ,go install $(GO_VULNCHECK_URI)) @echo "##### Running govulncheck #####" govulncheck ./... .PHONY: test test: @echo "##### Running tests #####" go test -race -cover -coverprofile=coverage.coverprofile -covermode=atomic -v ./... handlers-1.5.2/README.md000066400000000000000000000054501451374025300146060ustar00rootroot00000000000000# gorilla/handlers ![Testing](https://github.com/gorilla/handlers/actions/workflows/test.yml/badge.svg) [![Codecov](https://codecov.io/github/gorilla/handlers/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/handlers) [![GoDoc](https://godoc.org/github.com/gorilla/handlers?status.svg)](https://godoc.org/github.com/gorilla/handlers) [![Sourcegraph](https://sourcegraph.com/github.com/gorilla/handlers/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/handlers?badge) Package handlers is a collection of handlers (aka "HTTP middleware") for use with Go's `net/http` package (or any framework supporting `http.Handler`), including: * [**LoggingHandler**](https://godoc.org/github.com/gorilla/handlers#LoggingHandler) for logging HTTP requests in the Apache [Common Log Format](http://httpd.apache.org/docs/2.2/logs.html#common). 
* [**CombinedLoggingHandler**](https://godoc.org/github.com/gorilla/handlers#CombinedLoggingHandler) for logging HTTP requests in the Apache [Combined Log Format](http://httpd.apache.org/docs/2.2/logs.html#combined) commonly used by both Apache and nginx. * [**CompressHandler**](https://godoc.org/github.com/gorilla/handlers#CompressHandler) for gzipping responses. * [**ContentTypeHandler**](https://godoc.org/github.com/gorilla/handlers#ContentTypeHandler) for validating requests against a list of accepted content types. * [**MethodHandler**](https://godoc.org/github.com/gorilla/handlers#MethodHandler) for matching HTTP methods against handlers in a `map[string]http.Handler` * [**ProxyHeaders**](https://godoc.org/github.com/gorilla/handlers#ProxyHeaders) for populating `r.RemoteAddr` and `r.URL.Scheme` based on the `X-Forwarded-For`, `X-Real-IP`, `X-Forwarded-Proto` and RFC7239 `Forwarded` headers when running a Go server behind a HTTP reverse proxy. * [**CanonicalHost**](https://godoc.org/github.com/gorilla/handlers#CanonicalHost) for re-directing to the preferred host when handling multiple domains (i.e. multiple CNAME aliases). * [**RecoveryHandler**](https://godoc.org/github.com/gorilla/handlers#RecoveryHandler) for recovering from unexpected panics. Other handlers are documented [on the Gorilla website](https://www.gorillatoolkit.org/pkg/handlers). ## Example A simple example using `handlers.LoggingHandler` and `handlers.CompressHandler`: ```go import ( "net/http" "github.com/gorilla/handlers" ) func main() { r := http.NewServeMux() // Only log requests to our admin dashboard to stdout r.Handle("/admin", handlers.LoggingHandler(os.Stdout, http.HandlerFunc(ShowAdminDashboard))) r.HandleFunc("/", ShowIndex) // Wrap our server with our gzip handler to gzip compress all responses. http.ListenAndServe(":8000", handlers.CompressHandler(r)) } ``` ## License BSD licensed. See the included LICENSE file for details. 
handlers-1.5.2/canonical.go000066400000000000000000000036441451374025300156100ustar00rootroot00000000000000package handlers import ( "net/http" "net/url" "strings" ) type canonical struct { h http.Handler domain string code int } // CanonicalHost is HTTP middleware that re-directs requests to the canonical // domain. It accepts a domain and a status code (e.g. 301 or 302) and // re-directs clients to this domain. The existing request path is maintained. // // Note: If the provided domain is considered invalid by url.Parse or otherwise // returns an empty scheme or host, clients are not re-directed. // // Example: // // r := mux.NewRouter() // canonical := handlers.CanonicalHost("http://www.gorillatoolkit.org", 302) // r.HandleFunc("/route", YourHandler) // // log.Fatal(http.ListenAndServe(":7000", canonical(r))) func CanonicalHost(domain string, code int) func(h http.Handler) http.Handler { fn := func(h http.Handler) http.Handler { return canonical{h, domain, code} } return fn } func (c canonical) ServeHTTP(w http.ResponseWriter, r *http.Request) { dest, err := url.Parse(c.domain) if err != nil { // Call the next handler if the provided domain fails to parse. c.h.ServeHTTP(w, r) return } if dest.Scheme == "" || dest.Host == "" { // Call the next handler if the scheme or host are empty. // Note that url.Parse won't fail on in this case. c.h.ServeHTTP(w, r) return } if !strings.EqualFold(cleanHost(r.Host), dest.Host) { // Re-build the destination URL dest := dest.Scheme + "://" + dest.Host + r.URL.Path if r.URL.RawQuery != "" { dest += "?" + r.URL.RawQuery } http.Redirect(w, r, dest, c.code) return } c.h.ServeHTTP(w, r) } // cleanHost cleans invalid Host headers by stripping anything after '/' or ' '. // This is backported from Go 1.5 (in response to issue #11206) and attempts to // mitigate malformed Host headers that do not match the format in RFC7230. 
func cleanHost(in string) string { if i := strings.IndexAny(in, " /"); i != -1 { return in[:i] } return in } handlers-1.5.2/canonical_test.go000066400000000000000000000066041451374025300166460ustar00rootroot00000000000000package handlers import ( "bufio" "bytes" "log" "net/http" "net/http/httptest" "net/url" "strings" "testing" ) func TestCleanHost(t *testing.T) { tests := []struct { in, want string }{ {"www.google.com", "www.google.com"}, {"www.google.com foo", "www.google.com"}, {"www.google.com/foo", "www.google.com"}, {" first character is a space", ""}, } for _, tt := range tests { got := cleanHost(tt.in) if tt.want != got { t.Errorf("cleanHost(%q) = %q, want %q", tt.in, got, tt.want) } } } func TestCanonicalHost(t *testing.T) { gorilla := "http://www.gorillatoolkit.org" rr := httptest.NewRecorder() r := newRequest(http.MethodGet, "http://www.example.com/") testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) // Test a re-direct: should return a 302 Found. CanonicalHost(gorilla, http.StatusFound)(testHandler).ServeHTTP(rr, r) if rr.Code != http.StatusFound { t.Fatalf("bad status: got %v want %v", rr.Code, http.StatusFound) } if rr.Header().Get("Location") != gorilla+r.URL.Path { t.Fatalf("bad re-direct: got %q want %q", rr.Header().Get("Location"), gorilla+r.URL.Path) } } func TestKeepsQueryString(t *testing.T) { google := "https://www.google.com" rr := httptest.NewRecorder() querystring := url.Values{"q": {"golang"}, "format": {"json"}}.Encode() r := newRequest(http.MethodGet, "http://www.example.com/search?"+querystring) testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CanonicalHost(google, http.StatusFound)(testHandler).ServeHTTP(rr, r) want := google + r.URL.Path + "?" 
+ querystring if rr.Header().Get("Location") != want { t.Fatalf("bad re-direct: got %q want %q", rr.Header().Get("Location"), want) } } func TestBadDomain(t *testing.T) { rr := httptest.NewRecorder() r := newRequest(http.MethodGet, "http://www.example.com/") testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) // Test a bad domain - should return 200 OK. CanonicalHost("%", http.StatusFound)(testHandler).ServeHTTP(rr, r) if rr.Code != http.StatusOK { t.Fatalf("bad status: got %v want %v", rr.Code, http.StatusOK) } } func TestEmptyHost(t *testing.T) { rr := httptest.NewRecorder() r := newRequest(http.MethodGet, "http://www.example.com/") testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) // Test a domain that returns an empty url.Host from url.Parse. CanonicalHost("hello.com", http.StatusFound)(testHandler).ServeHTTP(rr, r) if rr.Code != http.StatusOK { t.Fatalf("bad status: got %v want %v", rr.Code, http.StatusOK) } } func TestHeaderWrites(t *testing.T) { gorilla := "http://www.gorillatoolkit.org" testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) // Catch the log output to ensure we don't write multiple headers. var b bytes.Buffer buf := bufio.NewWriter(&b) tl := log.New(buf, "test: ", log.Lshortfile) srv := httptest.NewServer( CanonicalHost(gorilla, http.StatusFound)(testHandler)) defer srv.Close() srv.Config.ErrorLog = tl _, err := http.Get(srv.URL) if err != nil { t.Fatal(err) } err = buf.Flush() if err != nil { t.Fatal(err) } // We rely on the error not changing: net/http does not export it. if strings.Contains(b.String(), "multiple response.WriteHeader calls") { t.Fatalf("re-direct did not return early: multiple header writes") } } handlers-1.5.2/compress.go000066400000000000000000000073511451374025300155130ustar00rootroot00000000000000// Copyright 2013 The Gorilla Authors. All rights reserved. 
// Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package handlers import ( "compress/flate" "compress/gzip" "io" "net/http" "strings" "github.com/felixge/httpsnoop" ) const acceptEncoding string = "Accept-Encoding" type compressResponseWriter struct { compressor io.Writer w http.ResponseWriter } func (cw *compressResponseWriter) WriteHeader(c int) { cw.w.Header().Del("Content-Length") cw.w.WriteHeader(c) } func (cw *compressResponseWriter) Write(b []byte) (int, error) { h := cw.w.Header() if h.Get("Content-Type") == "" { h.Set("Content-Type", http.DetectContentType(b)) } h.Del("Content-Length") return cw.compressor.Write(b) } func (cw *compressResponseWriter) ReadFrom(r io.Reader) (int64, error) { return io.Copy(cw.compressor, r) } type flusher interface { Flush() error } func (cw *compressResponseWriter) Flush() { // Flush compressed data if compressor supports it. if f, ok := cw.compressor.(flusher); ok { _ = f.Flush() } // Flush HTTP response. if f, ok := cw.w.(http.Flusher); ok { f.Flush() } } // CompressHandler gzip compresses HTTP responses for clients that support it // via the 'Accept-Encoding' header. // // Compressing TLS traffic may leak the page contents to an attacker if the // page contains user input: http://security.stackexchange.com/a/102015/12208 func CompressHandler(h http.Handler) http.Handler { return CompressHandlerLevel(h, gzip.DefaultCompression) } // CompressHandlerLevel gzip compresses HTTP responses with specified compression level // for clients that support it via the 'Accept-Encoding' header. // // The compression level should be gzip.DefaultCompression, gzip.NoCompression, // or any integer value between gzip.BestSpeed and gzip.BestCompression inclusive. // gzip.DefaultCompression is used in case of invalid compression level. 
func CompressHandlerLevel(h http.Handler, level int) http.Handler { if level < gzip.DefaultCompression || level > gzip.BestCompression { level = gzip.DefaultCompression } const ( gzipEncoding = "gzip" flateEncoding = "deflate" ) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // detect what encoding to use var encoding string for _, curEnc := range strings.Split(r.Header.Get(acceptEncoding), ",") { curEnc = strings.TrimSpace(curEnc) if curEnc == gzipEncoding || curEnc == flateEncoding { encoding = curEnc break } } // always add Accept-Encoding to Vary to prevent intermediate caches corruption w.Header().Add("Vary", acceptEncoding) // if we weren't able to identify an encoding we're familiar with, pass on the // request to the handler and return if encoding == "" { h.ServeHTTP(w, r) return } if r.Header.Get("Upgrade") != "" { h.ServeHTTP(w, r) return } // wrap the ResponseWriter with the writer for the chosen encoding var encWriter io.WriteCloser if encoding == gzipEncoding { encWriter, _ = gzip.NewWriterLevel(w, level) } else if encoding == flateEncoding { encWriter, _ = flate.NewWriter(w, level) } defer encWriter.Close() w.Header().Set("Content-Encoding", encoding) r.Header.Del(acceptEncoding) cw := &compressResponseWriter{ w: w, compressor: encWriter, } w = httpsnoop.Wrap(w, httpsnoop.Hooks{ Write: func(httpsnoop.WriteFunc) httpsnoop.WriteFunc { return cw.Write }, WriteHeader: func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc { return cw.WriteHeader }, Flush: func(httpsnoop.FlushFunc) httpsnoop.FlushFunc { return cw.Flush }, ReadFrom: func(rff httpsnoop.ReadFromFunc) httpsnoop.ReadFromFunc { return cw.ReadFrom }, }) h.ServeHTTP(w, r) }) } handlers-1.5.2/compress_test.go000066400000000000000000000177641451374025300165630ustar00rootroot00000000000000// Copyright 2013 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
package handlers import ( "bufio" "bytes" "compress/gzip" "io" "log" "net" "net/http" "net/http/httptest" "net/url" "os" "path/filepath" "strconv" "testing" ) var contentType = "text/plain; charset=utf-8" func compressedRequest(w *httptest.ResponseRecorder, compression string) { CompressHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Length", strconv.Itoa(9*1024)) w.Header().Set("Content-Type", contentType) for i := 0; i < 1024; i++ { _, err := io.WriteString(w, "Gorilla!\n") if err != nil { log.Printf("error on writting to http.ResponseWriter: %v\n", err) } } })).ServeHTTP(w, &http.Request{ Method: http.MethodGet, Header: http.Header{ acceptEncoding: []string{compression}, }, }) } func TestCompressHandlerNoCompression(t *testing.T) { w := httptest.NewRecorder() compressedRequest(w, "") resp := w.Result() if enc := resp.Header.Get("Content-Encoding"); enc != "" { t.Errorf("wrong content encoding, got %q want %q", enc, "") } if ct := resp.Header.Get("Content-Type"); ct != contentType { t.Errorf("wrong content type, got %q want %q", ct, contentType) } if w.Body.Len() != 1024*9 { t.Errorf("wrong len, got %d want %d", w.Body.Len(), 1024*9) } if l := resp.Header.Get("Content-Length"); l != "9216" { t.Errorf("wrong content-length. got %q expected %d", l, 1024*9) } if v := resp.Header.Get("Vary"); v != acceptEncoding { t.Errorf("wrong vary. 
got %s expected %s", v, acceptEncoding) } } func TestAcceptEncodingIsDropped(t *testing.T) { tCases := []struct { name, compression, expect string isPresent bool }{ { "accept-encoding-gzip", "gzip", "", false, }, { "accept-encoding-deflate", "deflate", "", false, }, { "accept-encoding-gzip,deflate", "gzip,deflate", "", false, }, { "accept-encoding-gzip,deflate,something", "gzip,deflate,something", "", false, }, { "accept-encoding-unknown", "unknown", "unknown", true, }, } for _, tCase := range tCases { ch := CompressHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { acceptEnc := r.Header.Get(acceptEncoding) if acceptEnc == "" && tCase.isPresent { t.Fatalf("%s: expected 'Accept-Encoding' header to be present but was not", tCase.name) } if acceptEnc != "" { if !tCase.isPresent { t.Fatalf("%s: expected 'Accept-Encoding' header to be dropped but was still present having value %q", tCase.name, acceptEnc) } if acceptEnc != tCase.expect { t.Fatalf("%s: expected 'Accept-Encoding' to be %q but was %q", tCase.name, tCase.expect, acceptEnc) } } })) w := httptest.NewRecorder() ch.ServeHTTP(w, &http.Request{ Method: http.MethodGet, Header: http.Header{ acceptEncoding: []string{tCase.compression}, }, }) } } func TestCompressHandlerGzip(t *testing.T) { w := httptest.NewRecorder() compressedRequest(w, "gzip") resp := w.Result() if resp.Header.Get("Content-Encoding") != "gzip" { t.Errorf("wrong content encoding, got %q want %q", resp.Header.Get("Content-Encoding"), "gzip") } if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" { t.Errorf("wrong content type, got %s want %s", resp.Header.Get("Content-Type"), "text/plain; charset=utf-8") } if w.Body.Len() != 72 { t.Errorf("wrong len, got %d want %d", w.Body.Len(), 72) } if l := resp.Header.Get("Content-Length"); l != "" { t.Errorf("wrong content-length. 
got %q expected %q", l, "") } } func TestCompressHandlerDeflate(t *testing.T) { w := httptest.NewRecorder() compressedRequest(w, "deflate") resp := w.Result() if resp.Header.Get("Content-Encoding") != "deflate" { t.Fatalf("wrong content encoding, got %q want %q", resp.Header.Get("Content-Encoding"), "deflate") } if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" { t.Fatalf("wrong content type, got %s want %s", resp.Header.Get("Content-Type"), "text/plain; charset=utf-8") } if w.Body.Len() != 54 { t.Fatalf("wrong len, got %d want %d", w.Body.Len(), 54) } } func TestCompressHandlerGzipDeflate(t *testing.T) { w := httptest.NewRecorder() compressedRequest(w, "gzip, deflate ") resp := w.Result() if resp.Header.Get("Content-Encoding") != "gzip" { t.Fatalf("wrong content encoding, got %q want %q", resp.Header.Get("Content-Encoding"), "gzip") } if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" { t.Fatalf("wrong content type, got %s want %s", resp.Header.Get("Content-Type"), "text/plain; charset=utf-8") } } // Make sure we can compress and serve an *os.File properly. We need // to use a real http server to trigger the net/http sendfile special // case. 
func TestCompressFile(t *testing.T) { dir, err := os.MkdirTemp("", "gorilla_compress") if err != nil { t.Fatal(err) } defer os.RemoveAll(dir) err = os.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hello"), 0o644) if err != nil { t.Fatal(err) } s := httptest.NewServer(CompressHandler(http.FileServer(http.Dir(dir)))) defer s.Close() url := &url.URL{Scheme: "http", Host: s.Listener.Addr().String(), Path: "/hello.txt"} req, err := http.NewRequest(http.MethodGet, url.String(), nil) if err != nil { t.Fatal(err) } req.Header.Set(acceptEncoding, "gzip") res, err := http.DefaultClient.Do(req) if err != nil { t.Fatal(err) } if res.StatusCode != http.StatusOK { t.Fatalf("expected OK, got %q", res.Status) } var got bytes.Buffer gr, err := gzip.NewReader(res.Body) if err != nil { t.Fatal(err) } _, err = io.Copy(&got, gr) if err != nil { t.Fatal(err) } if got.String() != "hello" { t.Errorf("expected hello, got %q", got.String()) } } type fullyFeaturedResponseWriter struct{} // Header/Write/WriteHeader implement the http.ResponseWriter interface. func (fullyFeaturedResponseWriter) Header() http.Header { return http.Header{} } func (fullyFeaturedResponseWriter) Write([]byte) (int, error) { return 0, nil } func (fullyFeaturedResponseWriter) WriteHeader(int) {} // Flush implements the http.Flusher interface. func (fullyFeaturedResponseWriter) Flush() {} // Hijack implements the http.Hijacker interface. func (fullyFeaturedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return nil, nil, nil } func TestCompressHandlerPreserveInterfaces(t *testing.T) { // Compile time validation fullyFeaturedResponseWriter implements all the // interfaces we're asserting in the test case below. 
var ( _ http.Flusher = fullyFeaturedResponseWriter{} _ http.Hijacker = fullyFeaturedResponseWriter{} ) var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { comp := r.Header.Get(acceptEncoding) if _, ok := rw.(http.Flusher); !ok { t.Errorf("ResponseWriter lost http.Flusher interface for %q", comp) } if _, ok := rw.(http.Hijacker); !ok { t.Errorf("ResponseWriter lost http.Hijacker interface for %q", comp) } }) h = CompressHandler(h) var rw fullyFeaturedResponseWriter r, err := http.NewRequest(http.MethodGet, "/", nil) if err != nil { t.Fatalf("Failed to create test request: %v", err) } r.Header.Set(acceptEncoding, "gzip") h.ServeHTTP(rw, r) r.Header.Set(acceptEncoding, "deflate") h.ServeHTTP(rw, r) } type paltryResponseWriter struct{} func (paltryResponseWriter) Header() http.Header { return http.Header{} } func (paltryResponseWriter) Write([]byte) (int, error) { return 0, nil } func (paltryResponseWriter) WriteHeader(int) {} func TestCompressHandlerDoesntInventInterfaces(t *testing.T) { var h http.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { if _, ok := rw.(http.Hijacker); ok { t.Error("ResponseWriter shouldn't implement http.Hijacker") } }) h = CompressHandler(h) var rw paltryResponseWriter r, err := http.NewRequest(http.MethodGet, "/", nil) if err != nil { t.Fatalf("Failed to create test request: %v", err) } r.Header.Set(acceptEncoding, "gzip") h.ServeHTTP(rw, r) } handlers-1.5.2/cors.go000066400000000000000000000227141451374025300146260ustar00rootroot00000000000000package handlers import ( "net/http" "strconv" "strings" ) // CORSOption represents a functional option for configuring the CORS middleware. 
type CORSOption func(*cors) error type cors struct { h http.Handler allowedHeaders []string allowedMethods []string allowedOrigins []string allowedOriginValidator OriginValidator exposedHeaders []string maxAge int ignoreOptions bool allowCredentials bool optionStatusCode int } // OriginValidator takes an origin string and returns whether or not that origin is allowed. type OriginValidator func(string) bool var ( defaultCorsOptionStatusCode = http.StatusOK defaultCorsMethods = []string{http.MethodGet, http.MethodHead, http.MethodPost} defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"} // (WebKit/Safari v9 sends the Origin header by default in AJAX requests). ) const ( corsOptionMethod string = http.MethodOptions corsAllowOriginHeader string = "Access-Control-Allow-Origin" corsExposeHeadersHeader string = "Access-Control-Expose-Headers" corsMaxAgeHeader string = "Access-Control-Max-Age" corsAllowMethodsHeader string = "Access-Control-Allow-Methods" corsAllowHeadersHeader string = "Access-Control-Allow-Headers" corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials" corsRequestMethodHeader string = "Access-Control-Request-Method" corsRequestHeadersHeader string = "Access-Control-Request-Headers" corsOriginHeader string = "Origin" corsVaryHeader string = "Vary" corsOriginMatchAll string = "*" ) func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get(corsOriginHeader) if !ch.isOriginAllowed(origin) { if r.Method != corsOptionMethod || ch.ignoreOptions { ch.h.ServeHTTP(w, r) } return } if r.Method == corsOptionMethod { if ch.ignoreOptions { ch.h.ServeHTTP(w, r) return } if _, ok := r.Header[corsRequestMethodHeader]; !ok { w.WriteHeader(http.StatusBadRequest) return } method := r.Header.Get(corsRequestMethodHeader) if !ch.isMatch(method, ch.allowedMethods) { w.WriteHeader(http.StatusMethodNotAllowed) return } requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",") 
allowedHeaders := []string{} for _, v := range requestHeaders { canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) { continue } if !ch.isMatch(canonicalHeader, ch.allowedHeaders) { w.WriteHeader(http.StatusForbidden) return } allowedHeaders = append(allowedHeaders, canonicalHeader) } if len(allowedHeaders) > 0 { w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ",")) } if ch.maxAge > 0 { w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge)) } if !ch.isMatch(method, defaultCorsMethods) { w.Header().Set(corsAllowMethodsHeader, method) } } else if len(ch.exposedHeaders) > 0 { w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ",")) } if ch.allowCredentials { w.Header().Set(corsAllowCredentialsHeader, "true") } if len(ch.allowedOrigins) > 1 { w.Header().Set(corsVaryHeader, corsOriginHeader) } returnOrigin := origin if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 { returnOrigin = "*" } else { for _, o := range ch.allowedOrigins { // A configuration of * is different than explicitly setting an allowed // origin. Returning arbitrary origin headers in an access control allow // origin header is unsafe and is not required by any use case. if o == corsOriginMatchAll { returnOrigin = "*" break } } } w.Header().Set(corsAllowOriginHeader, returnOrigin) if r.Method == corsOptionMethod { w.WriteHeader(ch.optionStatusCode) return } ch.h.ServeHTTP(w, r) } // CORS provides Cross-Origin Resource Sharing middleware. // Example: // // import ( // "net/http" // // "github.com/gorilla/handlers" // "github.com/gorilla/mux" // ) // // func main() { // r := mux.NewRouter() // r.HandleFunc("/users", UserEndpoint) // r.HandleFunc("/projects", ProjectEndpoint) // // // Apply the CORS middleware to our top-level router, with the defaults. 
// http.ListenAndServe(":8000", handlers.CORS()(r)) // } func CORS(opts ...CORSOption) func(http.Handler) http.Handler { return func(h http.Handler) http.Handler { ch := parseCORSOptions(opts...) ch.h = h return ch } } func parseCORSOptions(opts ...CORSOption) *cors { ch := &cors{ allowedMethods: defaultCorsMethods, allowedHeaders: defaultCorsHeaders, allowedOrigins: []string{}, optionStatusCode: defaultCorsOptionStatusCode, } for _, option := range opts { _ = option(ch) //TODO: @bharat-rajani, return error to caller if not nil? } return ch } // // Functional options for configuring CORS. // // AllowedHeaders adds the provided headers to the list of allowed headers in a // CORS request. // This is an append operation so the headers Accept, Accept-Language, // and Content-Language are always allowed. // Content-Type must be explicitly declared if accepting Content-Types other than // application/x-www-form-urlencoded, multipart/form-data, or text/plain. func AllowedHeaders(headers []string) CORSOption { return func(ch *cors) error { for _, v := range headers { normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) if normalizedHeader == "" { continue } if !ch.isMatch(normalizedHeader, ch.allowedHeaders) { ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader) } } return nil } } // AllowedMethods can be used to explicitly allow methods in the // Access-Control-Allow-Methods header. // This is a replacement operation so you must also // pass GET, HEAD, and POST if you wish to support those methods. 
func AllowedMethods(methods []string) CORSOption { return func(ch *cors) error { ch.allowedMethods = []string{} for _, v := range methods { normalizedMethod := strings.ToUpper(strings.TrimSpace(v)) if normalizedMethod == "" { continue } if !ch.isMatch(normalizedMethod, ch.allowedMethods) { ch.allowedMethods = append(ch.allowedMethods, normalizedMethod) } } return nil } } // AllowedOrigins sets the allowed origins for CORS requests, as used in the // 'Allow-Access-Control-Origin' HTTP header. // Note: Passing in a []string{"*"} will allow any domain. func AllowedOrigins(origins []string) CORSOption { return func(ch *cors) error { for _, v := range origins { if v == corsOriginMatchAll { ch.allowedOrigins = []string{corsOriginMatchAll} return nil } } ch.allowedOrigins = origins return nil } } // AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the // 'Allow-Access-Control-Origin' HTTP header. func AllowedOriginValidator(fn OriginValidator) CORSOption { return func(ch *cors) error { ch.allowedOriginValidator = fn return nil } } // OptionStatusCode sets a custom status code on the OPTIONS requests. // Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory // and can be used if you need a custom status code (i.e 204). // // More informations on the spec: // https://fetch.spec.whatwg.org/#cors-preflight-fetch func OptionStatusCode(code int) CORSOption { return func(ch *cors) error { ch.optionStatusCode = code return nil } } // ExposedHeaders can be used to specify headers that are available // and will not be stripped out by the user-agent. 
func ExposedHeaders(headers []string) CORSOption { return func(ch *cors) error { ch.exposedHeaders = []string{} for _, v := range headers { normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) if normalizedHeader == "" { continue } if !ch.isMatch(normalizedHeader, ch.exposedHeaders) { ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader) } } return nil } } // MaxAge determines the maximum age (in seconds) between preflight requests. A // maximum of 10 minutes is allowed. An age above this value will default to 10 // minutes. func MaxAge(age int) CORSOption { return func(ch *cors) error { // Maximum of 10 minutes. if age > 600 { age = 600 } ch.maxAge = age return nil } } // IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead // passing them through to the next handler. This is useful when your application // or framework has a pre-existing mechanism for responding to OPTIONS requests. func IgnoreOptions() CORSOption { return func(ch *cors) error { ch.ignoreOptions = true return nil } } // AllowCredentials can be used to specify that the user agent may pass // authentication details along with the request. 
func AllowCredentials() CORSOption { return func(ch *cors) error { ch.allowCredentials = true return nil } } func (ch *cors) isOriginAllowed(origin string) bool { if origin == "" { return false } if ch.allowedOriginValidator != nil { return ch.allowedOriginValidator(origin) } if len(ch.allowedOrigins) == 0 { return true } for _, allowedOrigin := range ch.allowedOrigins { if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll { return true } } return false } func (ch *cors) isMatch(needle string, haystack []string) bool { for _, v := range haystack { if v == needle { return true } } return false } handlers-1.5.2/cors_test.go000066400000000000000000000323211451374025300156600ustar00rootroot00000000000000package handlers import ( "net/http" "net/http/httptest" "strings" "testing" ) func TestDefaultCORSHandlerReturnsOk(t *testing.T) { r := newRequest(http.MethodGet, "http://www.example.com/") rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS()(testHandler).ServeHTTP(rr, r) resp := rr.Result() if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestDefaultCORSHandlerReturnsOkWithOrigin(t *testing.T) { r := newRequest(http.MethodGet, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS()(testHandler).ServeHTTP(rr, r) resp := rr.Result() if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerIgnoreOptionsFallsThrough(t *testing.T) { r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTeapot) }) CORS(IgnoreOptions())(testHandler).ServeHTTP(rr, r) 
resp := rr.Result() if got, want := resp.StatusCode, http.StatusTeapot; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerSetsExposedHeaders(t *testing.T) { // Test default configuration. r := newRequest(http.MethodGet, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS(ExposedHeaders([]string{"X-CORS-TEST"}))(testHandler).ServeHTTP(rr, r) resp := rr.Result() if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } header := resp.Header.Get(corsExposeHeadersHeader) if got, want := header, "X-Cors-Test"; got != want { t.Fatalf("bad header: expected %q header, got empty header for method.", want) } } func TestCORSHandlerUnsetRequestMethodForPreflightBadRequest(t *testing.T) { r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r) resp := rr.Result() if got, want := resp.StatusCode, http.StatusBadRequest; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerInvalidRequestMethodForPreflightMethodNotAllowed(t *testing.T) { r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) r.Header.Set(corsRequestMethodHeader, http.MethodDelete) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS()(testHandler).ServeHTTP(rr, r) resp := rr.Result() if got, want := resp.StatusCode, http.StatusMethodNotAllowed; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) { r := newRequest(http.MethodOptions, 
"http://www.example.com/") r.Header.Set("Origin", r.URL.String()) r.Header.Set(corsRequestMethodHeader, http.MethodGet) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("Options request must not be passed to next handler") }) CORS()(testHandler).ServeHTTP(rr, r) resp := rr.Result() if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCode(t *testing.T) { statusCode := http.StatusNoContent r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) r.Header.Set(corsRequestMethodHeader, http.MethodGet) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("Options request must not be passed to next handler") }) CORS(OptionStatusCode(statusCode))(testHandler).ServeHTTP(rr, r) resp := rr.Result() if got, want := resp.StatusCode, statusCode; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) { r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) r.Header.Set(corsRequestMethodHeader, http.MethodGet) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("Options request must not be passed to next handler") }) CORS(AllowedOrigins([]string{}))(testHandler).ServeHTTP(rr, r) resp := rr.Result() if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerAllowedMethodForPreflight(t *testing.T) { r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) r.Header.Set(corsRequestMethodHeader, http.MethodDelete) rc := 
httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rc, r) resp := rc.Result() if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } header := resp.Header.Get(corsAllowMethodsHeader) if header != http.MethodDelete { t.Fatalf("bad header: expected %q method header, got %q header.", http.MethodDelete, header) } } func TestCORSHandlerAllowMethodsNotSetForSimpleRequestPreflight(t *testing.T) { for _, method := range defaultCorsMethods { r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) r.Header.Set(corsRequestMethodHeader, method) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS()(testHandler).ServeHTTP(rr, r) resp := rr.Result() if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } header := resp.Header.Get(corsAllowMethodsHeader) if got, want := header, ""; got != want { t.Fatalf("bad header: expected %q method header, got %q.", want, got) } } } func TestCORSHandlerAllowedHeaderNotSetForSimpleRequestPreflight(t *testing.T) { for _, simpleHeader := range defaultCorsHeaders { r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) r.Header.Set(corsRequestMethodHeader, http.MethodGet) r.Header.Set(corsRequestHeadersHeader, simpleHeader) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS()(testHandler).ServeHTTP(rr, r) resp := rr.Result() if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } header := resp.Header.Get(corsAllowHeadersHeader) if got, want := header, ""; got != want { t.Fatalf("bad header: expected %q header, got %q.", want, got) } } } func 
TestCORSHandlerAllowedHeaderForPreflight(t *testing.T) { r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) r.Header.Set(corsRequestMethodHeader, http.MethodPost) r.Header.Set(corsRequestHeadersHeader, "Content-Type") rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS(AllowedHeaders([]string{"Content-Type"}))(testHandler).ServeHTTP(rr, r) resp := rr.Result() if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } header := resp.Header.Get(corsAllowHeadersHeader) if got, want := header, "Content-Type"; got != want { t.Fatalf("bad header: expected %q header, got %q header.", want, got) } } func TestCORSHandlerInvalidHeaderForPreflightForbidden(t *testing.T) { r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) r.Header.Set(corsRequestMethodHeader, http.MethodPost) r.Header.Set(corsRequestHeadersHeader, "Content-Type") rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS()(testHandler).ServeHTTP(rr, r) resp := rr.Result() if got, want := resp.StatusCode, http.StatusForbidden; got != want { t.Fatalf("bad status: got %v want %v", got, want) } } func TestCORSHandlerMaxAgeForPreflight(t *testing.T) { r := newRequest(http.MethodOptions, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) r.Header.Set(corsRequestMethodHeader, http.MethodPost) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS(MaxAge(3500))(testHandler).ServeHTTP(rr, r) resp := rr.Result() if got, want := resp.StatusCode, http.StatusOK; got != want { t.Fatalf("bad status: got %v want %v", got, want) } header := resp.Header.Get(corsMaxAgeHeader) if got, want := header, "600"; got != want { t.Fatalf("bad header: expected %q to be %q, got %q.", 
corsMaxAgeHeader, want, got) } } func TestCORSHandlerAllowedCredentials(t *testing.T) { r := newRequest(http.MethodGet, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS(AllowCredentials())(testHandler).ServeHTTP(rr, r) resp := rr.Result() if status := resp.StatusCode; status != http.StatusOK { t.Fatalf("bad status: got %v want %v", status, http.StatusOK) } header := resp.Header.Get(corsAllowCredentialsHeader) if got, want := header, "true"; got != want { t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowCredentialsHeader, want, got) } } func TestCORSHandlerMultipleAllowOriginsSetsVaryHeader(t *testing.T) { r := newRequest(http.MethodGet, "http://www.example.com/") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS(AllowedOrigins([]string{r.URL.String(), "http://google.com"}))(testHandler).ServeHTTP(rr, r) resp := rr.Result() if status := resp.StatusCode; status != http.StatusOK { t.Fatalf("bad status: got %v want %v", status, http.StatusOK) } header := resp.Header.Get(corsVaryHeader) if got, want := header, corsOriginHeader; got != want { t.Fatalf("bad header: expected %s to be %q, got %q.", corsVaryHeader, want, got) } } func TestCORSWithMultipleHandlers(t *testing.T) { var lastHandledBy string corsMiddleware := CORS() testHandler1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { lastHandledBy = "testHandler1" }) testHandler2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { lastHandledBy = "testHandler2" }) r1 := newRequest(http.MethodGet, "http://www.example.com/") rr1 := httptest.NewRecorder() handler1 := corsMiddleware(testHandler1) corsMiddleware(testHandler2) handler1.ServeHTTP(rr1, r1) if lastHandledBy != "testHandler1" { t.Fatalf("bad CORS() registration: Handler served should 
be Handler registered") } } func TestCORSOriginValidatorWithImplicitStar(t *testing.T) { r := newRequest(http.MethodGet, "http://a.example.com") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) originValidator := func(origin string) bool { return strings.HasSuffix(origin, ".example.com") } CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r) resp := rr.Result() header := resp.Header.Get(corsAllowOriginHeader) if got, want := header, r.URL.String(); got != want { t.Fatalf("bad header: expected %s to be %q, got %q.", corsAllowOriginHeader, want, got) } } func TestCORSOriginValidatorWithExplicitStar(t *testing.T) { r := newRequest(http.MethodGet, "http://a.example.com") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) originValidator := func(origin string) bool { return strings.HasSuffix(origin, ".example.com") } CORS( AllowedOriginValidator(originValidator), AllowedOrigins([]string{"*"}), )(testHandler).ServeHTTP(rr, r) resp := rr.Result() header := resp.Header.Get(corsAllowOriginHeader) if got, want := header, "*"; got != want { t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got) } } func TestCORSAllowStar(t *testing.T) { r := newRequest(http.MethodGet, "http://a.example.com") r.Header.Set("Origin", r.URL.String()) rr := httptest.NewRecorder() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) CORS()(testHandler).ServeHTTP(rr, r) resp := rr.Result() header := resp.Header.Get(corsAllowOriginHeader) if got, want := header, "*"; got != want { t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got) } } handlers-1.5.2/doc.go000066400000000000000000000005441451374025300144220ustar00rootroot00000000000000/* Package handlers is a collection of handlers (aka 
"HTTP middleware") for use with Go's net/http package (or any framework supporting http.Handler). The package includes handlers for logging in standardised formats, compressing HTTP responses, validating content types and other useful tools for manipulating requests and responses. */ package handlers handlers-1.5.2/go.mod000066400000000000000000000001311451374025300144240ustar00rootroot00000000000000module github.com/gorilla/handlers go 1.20 require github.com/felixge/httpsnoop v1.0.3 handlers-1.5.2/go.sum000066400000000000000000000002571451374025300144620ustar00rootroot00000000000000github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk= github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= handlers-1.5.2/handlers.go000066400000000000000000000106761451374025300154640ustar00rootroot00000000000000// Copyright 2013 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package handlers import ( "bufio" "fmt" "net" "net/http" "sort" "strings" ) // MethodHandler is an http.Handler that dispatches to a handler whose key in the // MethodHandler's map matches the name of the HTTP request's method, eg: GET // // If the request's method is OPTIONS and OPTIONS is not a key in the map then // the handler responds with a status of 200 and sets the Allow header to a // comma-separated list of available methods. // // If the request's method doesn't match any of its keys the handler responds // with a status of HTTP 405 "Method Not Allowed" and sets the Allow header to a // comma-separated list of available methods. 
type MethodHandler map[string]http.Handler

// ServeHTTP dispatches req to the handler registered under req.Method.
//
// When no handler is registered for the method, the Allow header is set to
// the sorted, comma-separated list of registered methods; OPTIONS requests
// are then answered with 200 and every other method with 405.
func (h MethodHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if handler, ok := h[req.Method]; ok {
		handler.ServeHTTP(w, req)
		return
	}

	allow := make([]string, 0, len(h))
	for k := range h {
		allow = append(allow, k)
	}
	// Map iteration order is random; sort for a deterministic header value.
	sort.Strings(allow)
	w.Header().Set("Allow", strings.Join(allow, ", "))

	if req.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}
	http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}

// responseLogger is a wrapper around http.ResponseWriter that keeps track of
// the HTTP status code and the number of body bytes written.
type responseLogger struct {
	w      http.ResponseWriter
	status int
	size   int
}

// Write forwards b to the wrapped writer and adds the number of bytes it
// reports as written to the running size.
func (l *responseLogger) Write(b []byte) (int, error) {
	size, err := l.w.Write(b)
	l.size += size
	return size, err
}

// WriteHeader forwards the status code to the wrapped writer and records it.
func (l *responseLogger) WriteHeader(s int) {
	l.w.WriteHeader(s)
	l.status = s
}

// Status returns the recorded HTTP status code.
func (l *responseLogger) Status() int {
	return l.status
}

// Size returns the number of body bytes written so far.
func (l *responseLogger) Size() int {
	return l.size
}

// Hijack takes over the underlying connection via the wrapped writer's
// http.Hijacker implementation.
//
// If the wrapped writer does not implement http.Hijacker, Hijack returns an
// error wrapping http.ErrNotSupported instead of panicking.
func (l *responseLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hj, ok := l.w.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("handlers: %T does not implement http.Hijacker: %w", l.w, http.ErrNotSupported)
	}

	conn, rw, err := hj.Hijack()
	if err == nil && l.status == 0 {
		// The status will be StatusSwitchingProtocols if there was no error and
		// WriteHeader has not been called yet
		l.status = http.StatusSwitchingProtocols
	}
	return conn, rw, err
}

// isContentType validates the Content-Type header matches the supplied
// contentType. That is, its type and subtype match. Parameters after ';'
// (such as charset) are ignored, as are surrounding whitespace and letter
// case, because media types are case-insensitive.
func isContentType(h http.Header, contentType string) bool {
	ct := h.Get("Content-Type")
	if i := strings.IndexByte(ct, ';'); i != -1 {
		ct = ct[:i]
	}
	return strings.EqualFold(strings.TrimSpace(ct), contentType)
}

// ContentTypeHandler wraps and returns a http.Handler, validating the request
// content type is compatible with the contentTypes list. It writes a HTTP 415
// error if that fails.
//
// Only PUT, POST, and PATCH requests are considered.
func ContentTypeHandler(h http.Handler, contentTypes ...string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !(r.Method == http.MethodPut || r.Method == http.MethodPost || r.Method == http.MethodPatch) { h.ServeHTTP(w, r) return } for _, ct := range contentTypes { if isContentType(r.Header, ct) { h.ServeHTTP(w, r) return } } http.Error(w, fmt.Sprintf("Unsupported content type %q; expected one of %q", r.Header.Get("Content-Type"), contentTypes), http.StatusUnsupportedMediaType) }) } const ( // HTTPMethodOverrideHeader is a commonly used // http header to override a request method. HTTPMethodOverrideHeader = "X-HTTP-Method-Override" // HTTPMethodOverrideFormKey is a commonly used // HTML form key to override a request method. HTTPMethodOverrideFormKey = "_method" ) // HTTPMethodOverrideHandler wraps and returns a http.Handler which checks for // the X-HTTP-Method-Override header or the _method form key, and overrides (if // valid) request.Method with its value. // // This is especially useful for HTTP clients that don't support many http verbs. // It isn't secure to override e.g a GET to a POST, so only POST requests are // considered. Likewise, the override method can only be a "write" method: PUT, // PATCH or DELETE. // // Form method takes precedence over header method. func HTTPMethodOverrideHandler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { om := r.FormValue(HTTPMethodOverrideFormKey) if om == "" { om = r.Header.Get(HTTPMethodOverrideHeader) } if om == http.MethodPut || om == http.MethodPatch || om == http.MethodDelete { r.Method = om } } h.ServeHTTP(w, r) }) } handlers-1.5.2/handlers_test.go000066400000000000000000000122751451374025300165200ustar00rootroot00000000000000// Copyright 2013 The Gorilla Authors. All rights reserved. 
// Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package handlers import ( "io" "log" "net/http" "net/http/httptest" "net/url" "strings" "testing" ) const ( ok = "ok\n" notAllowed = "Method not allowed\n" ) var okHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, err := w.Write([]byte(ok)) if err != nil { log.Fatalf("error on writing to http.ResponseWriter: %v", err) } }) func newRequest(method, url string) *http.Request { req, err := http.NewRequest(method, url, nil) if err != nil { panic(err) } return req } func TestMethodHandler(t *testing.T) { tests := []struct { req *http.Request handler http.Handler code int allow string // Contents of the Allow header body string }{ // No handlers {newRequest(http.MethodGet, "/foo"), MethodHandler{}, http.StatusMethodNotAllowed, "", notAllowed}, {newRequest(http.MethodOptions, "/foo"), MethodHandler{}, http.StatusOK, "", ""}, // A single handler {newRequest(http.MethodGet, "/foo"), MethodHandler{http.MethodGet: okHandler}, http.StatusOK, "", ok}, {newRequest(http.MethodPost, "/foo"), MethodHandler{http.MethodGet: okHandler}, http.StatusMethodNotAllowed, http.MethodGet, notAllowed}, // Multiple handlers {newRequest(http.MethodGet, "/foo"), MethodHandler{http.MethodGet: okHandler, http.MethodPost: okHandler}, http.StatusOK, "", ok}, {newRequest(http.MethodPost, "/foo"), MethodHandler{http.MethodGet: okHandler, http.MethodPost: okHandler}, http.StatusOK, "", ok}, {newRequest(http.MethodDelete, "/foo"), MethodHandler{http.MethodGet: okHandler, http.MethodPost: okHandler}, http.StatusMethodNotAllowed, "GET, POST", notAllowed}, {newRequest(http.MethodOptions, "/foo"), MethodHandler{http.MethodGet: okHandler, http.MethodPost: okHandler}, http.StatusOK, "GET, POST", ""}, // Override OPTIONS {newRequest(http.MethodOptions, "/foo"), MethodHandler{http.MethodOptions: okHandler}, http.StatusOK, "", ok}, } for i, test := range tests { rec := 
httptest.NewRecorder() test.handler.ServeHTTP(rec, test.req) resp := rec.Result() if resp.StatusCode != test.code { t.Fatalf("%d: wrong code, got %d want %d", i, resp.StatusCode, test.code) } if allow := resp.Header.Get("Allow"); allow != test.allow { t.Fatalf("%d: wrong Allow, got %s want %s", i, allow, test.allow) } respBodyBytes, err := io.ReadAll(resp.Body) if err != nil { t.Errorf("io error while reading response body %v", err) } if body := string(respBodyBytes); body != test.body { t.Fatalf("%d: wrong body, got %q want %q", i, body, test.body) } } } func TestContentTypeHandler(t *testing.T) { tests := []struct { Method string AllowContentTypes []string ContentType string Code int }{ {http.MethodPost, []string{"application/json"}, "application/json", http.StatusOK}, {http.MethodPost, []string{"application/json", "application/xml"}, "application/json", http.StatusOK}, {http.MethodPost, []string{"application/json"}, "application/json; charset=utf-8", http.StatusOK}, {http.MethodPost, []string{"application/json"}, "application/json+xxx", http.StatusUnsupportedMediaType}, {http.MethodPost, []string{"application/json"}, "text/plain", http.StatusUnsupportedMediaType}, {http.MethodGet, []string{"application/json"}, "", http.StatusOK}, {http.MethodGet, []string{}, "", http.StatusOK}, } for _, test := range tests { r, err := http.NewRequest(test.Method, "/", nil) if err != nil { t.Error(err) continue } h := ContentTypeHandler(okHandler, test.AllowContentTypes...) 
r.Header.Set("Content-Type", test.ContentType) w := httptest.NewRecorder() h.ServeHTTP(w, r) if w.Code != test.Code { t.Errorf("expected %d, got %d", test.Code, w.Code) } } } func TestHTTPMethodOverride(t *testing.T) { tests := []struct { Method string OverrideMethod string ExpectedMethod string }{ {http.MethodPost, http.MethodPut, http.MethodPut}, {http.MethodPost, http.MethodPatch, http.MethodPatch}, {http.MethodPost, http.MethodDelete, http.MethodDelete}, {http.MethodPut, http.MethodDelete, http.MethodPut}, {http.MethodGet, http.MethodGet, http.MethodGet}, {http.MethodHead, http.MethodHead, http.MethodHead}, {http.MethodGet, http.MethodPut, http.MethodGet}, {http.MethodHead, http.MethodDelete, http.MethodHead}, } for _, test := range tests { h := HTTPMethodOverrideHandler(okHandler) reqs := make([]*http.Request, 0, 2) rHeader, err := http.NewRequest(test.Method, "/", nil) if err != nil { t.Error(err) } rHeader.Header.Set(HTTPMethodOverrideHeader, test.OverrideMethod) reqs = append(reqs, rHeader) f := url.Values{HTTPMethodOverrideFormKey: []string{test.OverrideMethod}} rForm, err := http.NewRequest(test.Method, "/", strings.NewReader(f.Encode())) if err != nil { t.Error(err) } rForm.Header.Set("Content-Type", "application/x-www-form-urlencoded") reqs = append(reqs, rForm) for _, r := range reqs { w := httptest.NewRecorder() h.ServeHTTP(w, r) if r.Method != test.ExpectedMethod { t.Errorf("Expected %s, got %s", test.ExpectedMethod, r.Method) } } } } handlers-1.5.2/logging.go000066400000000000000000000157541451374025300153140ustar00rootroot00000000000000// Copyright 2013 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package handlers import ( "io" "net" "net/http" "net/url" "strconv" "time" "unicode/utf8" "github.com/felixge/httpsnoop" ) // Logging // LogFormatterParams is the structure any formatter will be handed when time to log comes. 
type LogFormatterParams struct {
	// Request is the request that was served.
	Request *http.Request
	// URL is the copy of the request URL taken before the wrapped handler ran.
	URL url.URL
	// TimeStamp is the time at which handling of the request started.
	TimeStamp time.Time
	// StatusCode is the response status recorded by the logger.
	StatusCode int
	// Size is the number of response body bytes recorded by the logger.
	Size int
}

// LogFormatter gives the signature of the formatter function passed to CustomLoggingHandler.
type LogFormatter func(writer io.Writer, params LogFormatterParams)

// loggingHandler is the http.Handler implementation for LoggingHandlerTo and its
// friends. It serves requests with handler and then passes writer and the
// collected request details to formatter.
type loggingHandler struct {
	writer    io.Writer
	handler   http.Handler
	formatter LogFormatter
}

// ServeHTTP runs the wrapped handler with a response writer that records the
// status code and body size, then hands the collected details to the
// formatter.
func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	t := time.Now()
	logger, w := makeLogger(w)
	// Copy the URL before the wrapped handler runs so the log entry reflects
	// the URL as received even if the handler rewrites req.URL.
	url := *req.URL

	h.handler.ServeHTTP(w, req)
	if req.MultipartForm != nil {
		// Remove any temporary files associated with the multipart form.
		err := req.MultipartForm.RemoveAll()
		if err != nil {
			// NOTE(review): returning here means the request is not logged
			// when the cleanup fails - confirm this is intended.
			return
		}
	}

	params := LogFormatterParams{
		Request:    req,
		URL:        url,
		TimeStamp:  t,
		StatusCode: logger.Status(),
		Size:       logger.Size(),
	}

	h.formatter(h.writer, params)
}

// makeLogger returns a responseLogger for w, with its status initialised to
// http.StatusOK, together with a ResponseWriter produced by httpsnoop.Wrap
// whose Write and WriteHeader hooks are replaced by the logger's methods.
func makeLogger(w http.ResponseWriter) (*responseLogger, http.ResponseWriter) {
	logger := &responseLogger{w: w, status: http.StatusOK}
	// The hooks ignore the function they are given and return the logger's
	// method instead; the logger itself forwards to w.
	return logger, httpsnoop.Wrap(w, httpsnoop.Hooks{
		Write: func(httpsnoop.WriteFunc) httpsnoop.WriteFunc {
			return logger.Write
		},
		WriteHeader: func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
			return logger.WriteHeader
		},
	})
}

// lowerhex holds the digits used when writing hexadecimal escapes.
const lowerhex = "0123456789abcdef"

// appendQuoted appends s to buf with escaping applied and returns the
// extended slice. It does not add surrounding quote characters.
//
// Escaping rules, in the order they are tested:
//   - a byte that does not decode as valid UTF-8 becomes \xNN;
//   - '"' and '\' are preceded by a backslash;
//   - runes for which strconv.IsPrint reports true are copied as UTF-8;
//   - \a, \b, \f, \n, \r, \t and \v use their short escapes;
//   - any other rune below ' ' becomes \xNN;
//   - remaining runes below 0x10000 become \uXXXX, all others \UXXXXXXXX.
func appendQuoted(buf []byte, s string) []byte {
	var runeTmp [utf8.UTFMax]byte
	for width := 0; len(s) > 0; s = s[width:] { //nolint: wastedassign //TODO: why width starts from 0 and reassigned as 1
		r := rune(s[0])
		width = 1
		if r >= utf8.RuneSelf {
			// Multi-byte lead byte: decode the full rune and its width.
			r, width = utf8.DecodeRuneInString(s)
		}
		if width == 1 && r == utf8.RuneError {
			// Invalid UTF-8: emit the raw byte as a hex escape.
			buf = append(buf, `\x`...)
			buf = append(buf, lowerhex[s[0]>>4])
			buf = append(buf, lowerhex[s[0]&0xF])
			continue
		}
		if r == rune('"') || r == '\\' { // always backslashed
			buf = append(buf, '\\')
			buf = append(buf, byte(r))
			continue
		}
		if strconv.IsPrint(r) {
			n := utf8.EncodeRune(runeTmp[:], r)
			buf = append(buf, runeTmp[:n]...)
			continue
		}
		switch r {
		case '\a':
			buf = append(buf, `\a`...)
		case '\b':
			buf = append(buf, `\b`...)
		case '\f':
			buf = append(buf, `\f`...)
		case '\n':
			buf = append(buf, `\n`...)
		case '\r':
			buf = append(buf, `\r`...)
		case '\t':
			buf = append(buf, `\t`...)
		case '\v':
			buf = append(buf, `\v`...)
		default:
			switch {
			case r < ' ':
				buf = append(buf, `\x`...)
				buf = append(buf, lowerhex[s[0]>>4])
				buf = append(buf, lowerhex[s[0]&0xF])
			case r > utf8.MaxRune:
				// Out-of-range value: substitute U+FFFD and share the \u path.
				r = 0xFFFD
				fallthrough
			case r < 0x10000:
				buf = append(buf, `\u`...)
				// Note: s here is a bit-shift count shadowing the string s.
				for s := 12; s >= 0; s -= 4 {
					buf = append(buf, lowerhex[r>>uint(s)&0xF])
				}
			default:
				buf = append(buf, `\U`...)
				// Note: s here is a bit-shift count shadowing the string s.
				for s := 28; s >= 0; s -= 4 {
					buf = append(buf, lowerhex[r>>uint(s)&0xF])
				}
			}
		}
	}
	return buf
}

// buildCommonLogLine builds a log entry for req in Apache Common Log Format.
// url supplies the user name and, when req.RequestURI is empty, the request URI.
// ts is the timestamp with which the entry should be logged.
// status and size are used to provide the response HTTP status and size.
//
// The returned line has the form
//
//	host - username [timestamp] "METHOD uri PROTO" status size
//
// and carries no trailing newline.
func buildCommonLogLine(req *http.Request, url url.URL, ts time.Time, status int, size int) []byte {
	// The user name falls back to "-" when the URL carries none.
	username := "-"
	if url.User != nil {
		if name := url.User.Username(); name != "" {
			username = name
		}
	}

	// Log the host without its port; if RemoteAddr cannot be split, use it as is.
	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		host = req.RemoteAddr
	}

	uri := req.RequestURI

	// Requests using the CONNECT method over HTTP/2.0 must use
	// the authority field (aka r.Host) to identify the target.
	// Refer: https://httpwg.github.io/specs/rfc7540.html#CONNECT
	if req.ProtoMajor == 2 && req.Method == "CONNECT" {
		uri = req.Host
	}
	if uri == "" {
		uri = url.RequestURI()
	}

	// Reserve roughly 1.5x the size of the known fields plus fixed text so the
	// appends below (including escaping of uri) rarely need to grow buf.
	buf := make([]byte, 0, 3*(len(host)+len(username)+len(req.Method)+len(uri)+len(req.Proto)+50)/2)
	buf = append(buf, host...)
	buf = append(buf, " - "...)
	buf = append(buf, username...)
	buf = append(buf, " ["...)
	buf = append(buf, ts.Format("02/Jan/2006:15:04:05 -0700")...)
	buf = append(buf, `] "`...)
	buf = append(buf, req.Method...)
	buf = append(buf, " "...)
	buf = appendQuoted(buf, uri)
	buf = append(buf, " "...)
	buf = append(buf, req.Proto...)
	buf = append(buf, `" `...)
	buf = append(buf, strconv.Itoa(status)...)
	buf = append(buf, " "...)
	buf = append(buf, strconv.Itoa(size)...)
	return buf
}

// writeLog writes a log entry for params.Request to writer in Apache Common
// Log Format, followed by a newline.
// params.TimeStamp is the timestamp with which the entry is logged;
// params.StatusCode and params.Size provide the response HTTP status and size.
// An error returned by writer.Write is ignored.
func writeLog(writer io.Writer, params LogFormatterParams) {
	buf := buildCommonLogLine(params.Request, params.URL, params.TimeStamp, params.StatusCode, params.Size)
	buf = append(buf, '\n')
	_, _ = writer.Write(buf)
}

// writeCombinedLog writes a log entry for params.Request to writer in Apache
// Combined Log Format: the Common Log Format line followed by the quoted
// referer and user agent, and a newline.
// params.TimeStamp is the timestamp with which the entry is logged;
// params.StatusCode and params.Size provide the response HTTP status and size.
// An error returned by writer.Write is ignored.
func writeCombinedLog(writer io.Writer, params LogFormatterParams) {
	buf := buildCommonLogLine(params.Request, params.URL, params.TimeStamp, params.StatusCode, params.Size)
	buf = append(buf, ` "`...)
	buf = appendQuoted(buf, params.Request.Referer())
	buf = append(buf, `" "`...)
	buf = appendQuoted(buf, params.Request.UserAgent())
	buf = append(buf, '"', '\n')
	_, _ = writer.Write(buf)
}

// CombinedLoggingHandler returns a http.Handler that wraps h and logs requests to out in
// Apache Combined Log Format.
//
// See http://httpd.apache.org/docs/2.2/logs.html#combined for a description of this format.
//
// CombinedLoggingHandler always sets the ident field of the log to -.
func CombinedLoggingHandler(out io.Writer, h http.Handler) http.Handler {
	return loggingHandler{out, h, writeCombinedLog}
}

// LoggingHandler returns a http.Handler that wraps h and logs requests to out in
// Apache Common Log Format (CLF).
//
// See http://httpd.apache.org/docs/2.2/logs.html#common for a description of this format.
// // LoggingHandler always sets the ident field of the log to - // // Example: // // r := mux.NewRouter() // r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { // w.Write([]byte("This is a catch-all route")) // }) // loggedRouter := handlers.LoggingHandler(os.Stdout, r) // http.ListenAndServe(":1123", loggedRouter) func LoggingHandler(out io.Writer, h http.Handler) http.Handler { return loggingHandler{out, h, writeLog} } // CustomLoggingHandler provides a way to supply a custom log formatter // while taking advantage of the mechanisms in this package. func CustomLoggingHandler(out io.Writer, h http.Handler, f LogFormatter) http.Handler { return loggingHandler{out, h, f} } handlers-1.5.2/logging_test.go000066400000000000000000000245371451374025300163520ustar00rootroot00000000000000// Copyright 2013 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package handlers import ( "bytes" "crypto/rand" "encoding/base64" "errors" "fmt" "io/fs" "mime/multipart" "net/http" "net/http/httptest" "net/url" "os" "path/filepath" "strings" "testing" "time" ) func TestMakeLogger(t *testing.T) { rec := httptest.NewRecorder() logger, w := makeLogger(rec) // initial status if logger.Status() != http.StatusOK { t.Fatalf("wrong status, got %d want %d", logger.Status(), http.StatusOK) } // WriteHeader w.WriteHeader(http.StatusInternalServerError) if logger.Status() != http.StatusInternalServerError { t.Fatalf("wrong status, got %d want %d", logger.Status(), http.StatusInternalServerError) } // Write _, err := w.Write([]byte(ok)) if err != nil { t.Fatalf("error while writing to http.ResponseWriter %v", err) return } if logger.Size() != len(ok) { t.Fatalf("wrong size, got %d want %d", logger.Size(), len(ok)) } // Header w.Header().Set("key", "value") if val := w.Header().Get("key"); val != "value" { t.Fatalf("wrong header, got %s want %s", val, "value") } } func TestLoggerCleanup(t 
*testing.T) { rbuf := make([]byte, 128) if _, err := rand.Read(rbuf); err != nil { t.Fatalf("Failed to generate random content: %v", err) } contents := base64.StdEncoding.EncodeToString(rbuf) var body bytes.Buffer body.WriteString(fmt.Sprintf(` --boundary Content-Disposition: form-data; name="buzz"; filename="example.txt" %s --boundary-- `, contents)) r := multipart.NewReader(&body, "boundary") form, err := r.ReadForm(0) // small max memory to force flush to disk if err != nil { t.Fatalf("Failed to read multipart form: %v", err) } tmpFiles, err := os.ReadDir(os.TempDir()) if err != nil { t.Fatalf("Failed to list %s: %v", os.TempDir(), err) } var tmpFile string for _, f := range tmpFiles { if !strings.HasPrefix(f.Name(), "multipart-") { continue } path := filepath.Join(os.TempDir(), f.Name()) switch b, fileError := os.ReadFile(path); { case fileError != nil: t.Fatalf("Failed to read %s: %v", path, err) case string(b) != contents: continue default: tmpFile = path } } if tmpFile == "" { t.Fatal("Could not find multipart form tmp file") } req := newRequest(http.MethodGet, "/subdir/asdf") req.MultipartForm = form var buf bytes.Buffer handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL.Path = "/" // simulate http.StripPrefix and friends w.WriteHeader(http.StatusOK) }) logger := LoggingHandler(&buf, handler) logger.ServeHTTP(httptest.NewRecorder(), req) if _, osStatErr := os.Stat(tmpFile); osStatErr == nil || !errors.Is(osStatErr, fs.ErrNotExist) { t.Fatalf("Expected %s to not exist, got %v", tmpFile, osStatErr) } } func TestLogPathRewrites(t *testing.T) { var buf bytes.Buffer handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL.Path = "/" // simulate http.StripPrefix and friends w.WriteHeader(http.StatusOK) }) logger := LoggingHandler(&buf, handler) logger.ServeHTTP(httptest.NewRecorder(), newRequest(http.MethodGet, "/subdir/asdf")) if !strings.Contains(buf.String(), "GET /subdir/asdf HTTP") { t.Fatalf("Got 
log %#v, wanted substring %#v", buf.String(), "GET /subdir/asdf HTTP") } } func BenchmarkWriteLog(b *testing.B) { loc, err := time.LoadLocation("Europe/Warsaw") if err != nil { b.Fatalf(err.Error()) } ts := time.Date(1983, 0o5, 26, 3, 30, 45, 0, loc) req := newRequest(http.MethodGet, "http://example.com") req.RemoteAddr = "192.168.100.5" b.ResetTimer() params := LogFormatterParams{ Request: req, URL: *req.URL, TimeStamp: ts, StatusCode: http.StatusUnauthorized, Size: 500, } buf := &bytes.Buffer{} for i := 0; i < b.N; i++ { buf.Reset() writeLog(buf, params) } } func TestLogFormatterWriteLog_Scenario1(t *testing.T) { formatter := writeLog expected := "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 200 100\n" LoggingScenario1(t, formatter, expected) } func TestLogFormatterCombinedLog_Scenario1(t *testing.T) { formatter := writeCombinedLog expected := "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 200 100 \"http://example.com\" " + "\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) " + "AppleWebKit/537.33 (KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33\"\n" LoggingScenario1(t, formatter, expected) } func TestLogFormatterWriteLog_Scenario2(t *testing.T) { formatter := writeLog expected := "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"CONNECT www.example.com:443 HTTP/2.0\" 200 100\n" LoggingScenario2(t, formatter, expected) } func TestLogFormatterCombinedLog_Scenario2(t *testing.T) { formatter := writeCombinedLog expected := "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"CONNECT www.example.com:443 HTTP/2.0\" 200 100 \"http://example.com\" " + "\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) " + "AppleWebKit/537.33 (KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33\"\n" LoggingScenario2(t, formatter, expected) } func TestLogFormatterWriteLog_Scenario3(t *testing.T) { formatter := writeLog expected := "192.168.100.5 - kamil [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 401 500\n" LoggingScenario3(t, formatter, 
expected) } func TestLogFormatterCombinedLog_Scenario3(t *testing.T) { formatter := writeCombinedLog expected := "192.168.100.5 - kamil [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 401 500 \"http://example.com\" " + "\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) " + "AppleWebKit/537.33 (KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33\"\n" LoggingScenario3(t, formatter, expected) } func TestLogFormatterWriteLog_Scenario4(t *testing.T) { formatter := writeLog expected := "192.168.100.5 - - [26/May/1983:03:30:45 +0200] \"GET /test?abc=hello%20world&a=b%3F HTTP/1.1\" 200 100\n" LoggingScenario4(t, formatter, expected) } func TestLogFormatterCombinedLog_Scenario5(t *testing.T) { formatter := writeCombinedLog expected := "::1 - kamil [26/May/1983:03:30:45 +0200] \"GET / HTTP/1.1\" 200 100 \"http://example.com\" " + "\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) " + "AppleWebKit/537.33 (KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33\"\n" LoggingScenario5(t, formatter, expected) } func LoggingScenario1(t *testing.T, formatter LogFormatter, expected string) { loc, err := time.LoadLocation("Europe/Warsaw") if err != nil { panic(err) } ts := time.Date(1983, 0o5, 26, 3, 30, 45, 0, loc) // A typical request with an OK response req := constructTypicalRequestOk() buf := new(bytes.Buffer) params := LogFormatterParams{ Request: req, URL: *req.URL, TimeStamp: ts, StatusCode: http.StatusOK, Size: 100, } formatter(buf, params) log := buf.String() if log != expected { t.Fatalf("wrong log, got %q want %q", log, expected) } } func LoggingScenario2(t *testing.T, formatter LogFormatter, expected string) { loc, err := time.LoadLocation("Europe/Warsaw") if err != nil { panic(err) } ts := time.Date(1983, 0o5, 26, 3, 30, 45, 0, loc) // CONNECT request over http/2.0 req := constructConnectRequest() buf := new(bytes.Buffer) params := LogFormatterParams{ Request: req, URL: *req.URL, TimeStamp: ts, StatusCode: http.StatusOK, Size: 100, } formatter(buf, params) log := buf.String() 
if log != expected { t.Fatalf("wrong log, got %q want %q", log, expected) } } func LoggingScenario3(t *testing.T, formatter LogFormatter, expected string) { loc, err := time.LoadLocation("Europe/Warsaw") if err != nil { panic(err) } ts := time.Date(1983, 0o5, 26, 3, 30, 45, 0, loc) // Request with an unauthorized user req := constructTypicalRequestOk() req.URL.User = url.User("kamil") buf := new(bytes.Buffer) params := LogFormatterParams{ Request: req, URL: *req.URL, TimeStamp: ts, StatusCode: http.StatusUnauthorized, Size: 500, } formatter(buf, params) log := buf.String() if log != expected { t.Fatalf("wrong log, got %q want %q", log, expected) } } func LoggingScenario4(t *testing.T, formatter LogFormatter, expected string) { loc, err := time.LoadLocation("Europe/Warsaw") if err != nil { panic(err) } ts := time.Date(1983, 0o5, 26, 3, 30, 45, 0, loc) // Request with url encoded parameters req := constructEncodedRequest() buf := new(bytes.Buffer) params := LogFormatterParams{ Request: req, URL: *req.URL, TimeStamp: ts, StatusCode: http.StatusOK, Size: 100, } formatter(buf, params) log := buf.String() if log != expected { t.Fatalf("wrong log, got %q want %q", log, expected) } } func LoggingScenario5(t *testing.T, formatter LogFormatter, expected string) { loc, err := time.LoadLocation("Europe/Warsaw") if err != nil { panic(err) } ts := time.Date(1983, 0o5, 26, 3, 30, 45, 0, loc) req := constructTypicalRequestOk() req.URL.User = url.User("kamil") req.RemoteAddr = "::1" buf := new(bytes.Buffer) params := LogFormatterParams{ Request: req, URL: *req.URL, TimeStamp: ts, StatusCode: http.StatusOK, Size: 100, } formatter(buf, params) log := buf.String() if log != expected { t.Fatalf("wrong log, got %q want %q", log, expected) } } // A typical request with an OK response. 
// constructConnectRequest returns a hand-built CONNECT request over HTTP/2.0
// aimed at www.example.com:443, with a fixed remote address and the Referer
// and User-Agent headers set.
func constructConnectRequest() *http.Request {
	const authority = "www.example.com:443"

	headers := http.Header{}
	headers.Set("Referer", "http://example.com")
	headers.Set(
		"User-Agent",
		"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.33 "+
			"(KHTML, like Gecko) Chrome/27.0.1430.0 Safari/537.33",
	)

	return &http.Request{
		Method:     http.MethodConnect,
		Host:       authority,
		Proto:      "HTTP/2.0",
		ProtoMajor: 2,
		ProtoMinor: 0,
		RemoteAddr: "192.168.100.5",
		Header:     headers,
		URL:        &url.URL{Host: authority},
	}
}
forRegex = regexp.MustCompile(`(?i)(?:for=)([^(;|,| )]+)`) // Allows for a sub-match for the first instance of scheme (http|https) // prefixed by 'proto='. The match is case-insensitive. protoRegex = regexp.MustCompile(`(?i)(?:proto=)(https|http)`) ) // ProxyHeaders inspects common reverse proxy headers and sets the corresponding // fields in the HTTP request struct. These are X-Forwarded-For and X-Real-IP // for the remote (client) IP address, X-Forwarded-Proto or X-Forwarded-Scheme // for the scheme (http|https), X-Forwarded-Host for the host and the RFC7239 // Forwarded header, which may include both client IPs and schemes. // // NOTE: This middleware should only be used when behind a reverse // proxy like nginx, HAProxy or Apache. Reverse proxies that don't (or are // configured not to) strip these headers from client requests, or where these // headers are accepted "as is" from a remote client (e.g. when Go is not behind // a proxy), can manifest as a vulnerability if your application uses these // headers for validating the 'trustworthiness' of a request. func ProxyHeaders(h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { // Set the remote IP with the value passed from the proxy. if fwd := getIP(r); fwd != "" { r.RemoteAddr = fwd } // Set the scheme (proto) with the value passed from the proxy. if scheme := getScheme(r); scheme != "" { r.URL.Scheme = scheme } // Set the host with the value passed by the proxy if r.Header.Get(xForwardedHost) != "" { r.Host = r.Header.Get(xForwardedHost) } // Call the next handler in the chain. h.ServeHTTP(w, r) } return http.HandlerFunc(fn) } // getIP retrieves the IP from the X-Forwarded-For, X-Real-IP and RFC7239 // Forwarded headers (in that order). func getIP(r *http.Request) string { var addr string switch { case r.Header.Get(xForwardedFor) != "": fwd := r.Header.Get(xForwardedFor) // Only grab the first (client) address. 
Note that '192.168.0.1, // 10.1.1.1' is a valid key for X-Forwarded-For where addresses after // the first may represent forwarding proxies earlier in the chain. s := strings.Index(fwd, ", ") if s == -1 { s = len(fwd) } addr = fwd[:s] case r.Header.Get(xRealIP) != "": addr = r.Header.Get(xRealIP) case r.Header.Get(forwarded) != "": // match should contain at least two elements if the protocol was // specified in the Forwarded header. The first element will always be // the 'for=' capture, which we ignore. In the case of multiple IP // addresses (for=8.8.8.8, 8.8.4.4,172.16.1.20 is valid) we only // extract the first, which should be the client IP. if match := forRegex.FindStringSubmatch(r.Header.Get(forwarded)); len(match) > 1 { // IPv6 addresses in Forwarded headers are quoted-strings. We strip // these quotes. addr = strings.Trim(match[1], `"`) } } return addr } // getScheme retrieves the scheme from the X-Forwarded-Proto and RFC7239 // Forwarded headers (in that order). func getScheme(r *http.Request) string { var scheme string // Retrieve the scheme from X-Forwarded-Proto. if proto := r.Header.Get(xForwardedProto); proto != "" { scheme = strings.ToLower(proto) } else if proto = r.Header.Get(xForwardedScheme); proto != "" { scheme = strings.ToLower(proto) } else if proto = r.Header.Get(forwarded); proto != "" { // match should contain at least two elements if the protocol was // specified in the Forwarded header. The first element will always be // the 'proto=' capture, which we ignore. In the case of multiple proto // parameters (invalid) we only extract the first. 
if match := protoRegex.FindStringSubmatch(proto); len(match) > 1 { scheme = strings.ToLower(match[1]) } } return scheme } handlers-1.5.2/proxy_headers_test.go000066400000000000000000000067241451374025300175760ustar00rootroot00000000000000package handlers import ( "net/http" "net/http/httptest" "testing" ) type headerTable struct { key string // header key val string // header val expected string // expected result } func TestGetIP(t *testing.T) { headers := []headerTable{ {xForwardedFor, "8.8.8.8", "8.8.8.8"}, // Single address {xForwardedFor, "8.8.8.8, 8.8.4.4", "8.8.8.8"}, // Multiple {xForwardedFor, "[2001:db8:cafe::17]:4711", "[2001:db8:cafe::17]:4711"}, // IPv6 address {xForwardedFor, "", ""}, // None {xRealIP, "8.8.8.8", "8.8.8.8"}, // Single address {xRealIP, "8.8.8.8, 8.8.4.4", "8.8.8.8, 8.8.4.4"}, // Multiple {xRealIP, "[2001:db8:cafe::17]:4711", "[2001:db8:cafe::17]:4711"}, // IPv6 address {xRealIP, "", ""}, // None {forwarded, `for="_gazonk"`, "_gazonk"}, // Hostname {forwarded, `For="[2001:db8:cafe::17]:4711`, `[2001:db8:cafe::17]:4711`}, // IPv6 address {forwarded, `for=192.0.2.60;proto=http;by=203.0.113.43`, `192.0.2.60`}, // Multiple params {forwarded, `for=192.0.2.43, for=198.51.100.17`, "192.0.2.43"}, // Multiple params {forwarded, `for="workstation.local",for=198.51.100.17`, "workstation.local"}, // Hostname } for _, v := range headers { req := &http.Request{ Header: http.Header{ v.key: []string{v.val}, }, } res := getIP(req) if res != v.expected { t.Fatalf("wrong header for %s: got %s want %s", v.key, res, v.expected) } } } func TestGetScheme(t *testing.T) { headers := []headerTable{ {xForwardedProto, "https", "https"}, {xForwardedProto, "http", "http"}, {xForwardedProto, "HTTP", "http"}, {xForwardedScheme, "https", "https"}, {xForwardedScheme, "http", "http"}, {xForwardedScheme, "HTTP", "http"}, {forwarded, `For="[2001:db8:cafe::17]:4711`, ""}, // No proto {forwarded, `for=192.0.2.43, for=198.51.100.17;proto=https`, "https"}, // Multiple params 
before proto {forwarded, `for=172.32.10.15; proto=https;by=127.0.0.1`, "https"}, // Space before proto {forwarded, `for=192.0.2.60;proto=http;by=203.0.113.43`, "http"}, // Multiple params } for _, v := range headers { req := &http.Request{ Header: http.Header{ v.key: []string{v.val}, }, } res := getScheme(req) if res != v.expected { t.Fatalf("wrong header for %s: got %s want %s", v.key, res, v.expected) } } } // Test the middleware end-to-end. func TestProxyHeaders(t *testing.T) { rr := httptest.NewRecorder() r := newRequest(http.MethodGet, "/") r.Header.Set(xForwardedFor, "8.8.8.8") r.Header.Set(xForwardedProto, "https") r.Header.Set(xForwardedHost, "google.com") var ( addr string proto string host string ) ProxyHeaders(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { addr = r.RemoteAddr proto = r.URL.Scheme host = r.Host })).ServeHTTP(rr, r) if rr.Code != http.StatusOK { t.Fatalf("bad status: got %d want %d", rr.Code, http.StatusOK) } if addr != r.Header.Get(xForwardedFor) { t.Fatalf("wrong address: got %s want %s", addr, r.Header.Get(xForwardedFor)) } if proto != r.Header.Get(xForwardedProto) { t.Fatalf("wrong address: got %s want %s", proto, r.Header.Get(xForwardedProto)) } if host != r.Header.Get(xForwardedHost) { t.Fatalf("wrong address: got %s want %s", host, r.Header.Get(xForwardedHost)) } } handlers-1.5.2/recovery.go000066400000000000000000000046431451374025300155170ustar00rootroot00000000000000package handlers import ( "log" "net/http" "runtime/debug" ) // RecoveryHandlerLogger is an interface used by the recovering handler to print logs. type RecoveryHandlerLogger interface { Println(...interface{}) } type recoveryHandler struct { handler http.Handler logger RecoveryHandlerLogger printStack bool } // RecoveryOption provides a functional approach to define // configuration for a handler; such as setting the logging // whether or not to print stack traces on panic. 
type RecoveryOption func(http.Handler) func parseRecoveryOptions(h http.Handler, opts ...RecoveryOption) http.Handler { for _, option := range opts { option(h) } return h } // RecoveryHandler is HTTP middleware that recovers from a panic, // logs the panic, writes http.StatusInternalServerError, and // continues to the next handler. // // Example: // // r := mux.NewRouter() // r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { // panic("Unexpected error!") // }) // // http.ListenAndServe(":1123", handlers.RecoveryHandler()(r)) func RecoveryHandler(opts ...RecoveryOption) func(h http.Handler) http.Handler { return func(h http.Handler) http.Handler { r := &recoveryHandler{handler: h} return parseRecoveryOptions(r, opts...) } } // RecoveryLogger is a functional option to override // the default logger. func RecoveryLogger(logger RecoveryHandlerLogger) RecoveryOption { return func(h http.Handler) { r := h.(*recoveryHandler) //nolint:errcheck //TODO: // @bharat-rajani should return type-assertion error but would break the API? r.logger = logger } } // PrintRecoveryStack is a functional option to enable // or disable printing stack traces on panic. func PrintRecoveryStack(shouldPrint bool) RecoveryOption { return func(h http.Handler) { r := h.(*recoveryHandler) //nolint:errcheck //TODO: // @bharat-rajani should return type-assertion error but would break the API? r.printStack = shouldPrint } } func (h recoveryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { defer func() { if err := recover(); err != nil { w.WriteHeader(http.StatusInternalServerError) h.log(err) } }() h.handler.ServeHTTP(w, req) } func (h recoveryHandler) log(v ...interface{}) { if h.logger != nil { h.logger.Println(v...) } else { log.Println(v...) 
} if h.printStack { stack := string(debug.Stack()) if h.logger != nil { h.logger.Println(stack) } else { log.Println(stack) } } } handlers-1.5.2/recovery_test.go000066400000000000000000000032301451374025300165450ustar00rootroot00000000000000package handlers import ( "bytes" "log" "net/http" "net/http/httptest" "strings" "testing" ) func TestRecoveryLoggerWithDefaultOptions(t *testing.T) { var buf bytes.Buffer log.SetOutput(&buf) handler := RecoveryHandler() handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { panic("Unexpected error!") }) recovery := handler(handlerFunc) recovery.ServeHTTP(httptest.NewRecorder(), newRequest(http.MethodGet, "/subdir/asdf")) if !strings.Contains(buf.String(), "Unexpected error!") { t.Fatalf("Got log %#v, wanted substring %#v", buf.String(), "Unexpected error!") } } func TestRecoveryLoggerWithCustomLogger(t *testing.T) { var buf bytes.Buffer logger := log.New(&buf, "", log.LstdFlags) handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { panic("Unexpected error!") }) t.Run("Without print stack", func(t *testing.T) { handler := RecoveryHandler(RecoveryLogger(logger), PrintRecoveryStack(false)) recovery := handler(handlerFunc) recovery.ServeHTTP(httptest.NewRecorder(), newRequest(http.MethodGet, "/subdir/asdf")) if !strings.Contains(buf.String(), "Unexpected error!") { t.Fatalf("Got log %#v, wanted substring %#v", buf.String(), "Unexpected error!") } }) t.Run("With print stack enabled", func(t *testing.T) { handler := RecoveryHandler(RecoveryLogger(logger), PrintRecoveryStack(true)) recovery := handler(handlerFunc) recovery.ServeHTTP(httptest.NewRecorder(), newRequest(http.MethodGet, "/subdir/asdf")) if !strings.Contains(buf.String(), "runtime/debug.Stack") { t.Fatalf("Got log %#v, wanted substring %#v", buf.String(), "runtime/debug.Stack") } }) }