pax_global_header00006660000000000000000000000064145047507640014526gustar00rootroot0000000000000052 comment=07821e22d8655dcce7d23d1038a8d325e5dc234b oxy-2.0.0/000077500000000000000000000000001450475076400123445ustar00rootroot00000000000000oxy-2.0.0/.github/000077500000000000000000000000001450475076400137045ustar00rootroot00000000000000oxy-2.0.0/.github/workflows/000077500000000000000000000000001450475076400157415ustar00rootroot00000000000000oxy-2.0.0/.github/workflows/go-cross.yml000066400000000000000000000014431450475076400202220ustar00rootroot00000000000000name: Go Matrix on: push: branches: - master pull_request: jobs: cross: name: Go runs-on: ${{ matrix.os }} env: CGO_ENABLED: 0 strategy: matrix: go-version: [ stable, oldstable ] os: [ubuntu-latest, macos-latest] # TODO ignore windows but need to be added in the future # os: [ubuntu-latest, macos-latest, windows-latest] steps: # https://github.com/marketplace/actions/checkout - name: Checkout code uses: actions/checkout@v4 # https://github.com/marketplace/actions/setup-go-environment - name: Set up Go ${{ matrix.go-version }} uses: actions/setup-go@v4 with: go-version: ${{ matrix.go-version }} - name: Test run: go test -v -cover ./... oxy-2.0.0/.github/workflows/pr.yml000066400000000000000000000020621450475076400171050ustar00rootroot00000000000000name: Main on: pull_request: jobs: main: name: Main Process runs-on: ubuntu-latest env: GO_VERSION: stable GOLANGCI_LINT_VERSION: v1.54.2 steps: # https://github.com/marketplace/actions/checkout - name: Check out code uses: actions/checkout@v4 with: fetch-depth: 0 # https://github.com/marketplace/actions/setup-go-environment - name: Set up Go ${{ env.GO_VERSION }} uses: actions/setup-go@v4 with: go-version: ${{ env.GO_VERSION }} - name: Check and get dependencies run: | go mod tidy git diff --exit-code go.mod git diff --exit-code go.sum # https://golangci-lint.run/usage/install#other-ci - name: Install golangci-lint ${{ env.GOLANGCI_LINT_VERSION }} run: | curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin ${GOLANGCI_LINT_VERSION} golangci-lint --version - name: Make run: make oxy-2.0.0/.gitignore000066400000000000000000000004441450475076400143360ustar00rootroot00000000000000# Compiled Object files, Static and Dynamic libs (Shared Objects) *.o *.a *.so # Folders _obj _test # Architecture specific extensions/prefixes *.[568vq] [568vq].out *.cgo1.go *.cgo2.c _cgo_defun.c _cgo_gotypes.go _cgo_export.* _testmain.go *.exe *.test *.prof .idea/ flymake_* vendor/oxy-2.0.0/.golangci.yml000066400000000000000000000072101450475076400147300ustar00rootroot00000000000000run: deadline: 5m skip-files: [ ] skip-dirs: ["internal/holsterv4"] linters-settings: govet: enable-all: true disable: - fieldalignment - shadow gocyclo: min-complexity: 15 maligned: suggest-new: true goconst: min-len: 5 min-occurrences: 3 misspell: locale: US funlen: lines: -1 statements: 50 godox: keywords: - FIXME gofumpt: extra-rules: false depguard: rules: main: deny: - pkg: "github.com/instana/testify" desc: not allowed - pkg: "github.com/pkg/errors" desc: Should be replaced by standard lib errors package gocritic: enabled-tags: - diagnostic - style - performance disabled-checks: - sloppyReassign - rangeValCopy - octalLiteral - paramTypeCombine # already handle by gofumpt.extra-rules - httpNoBody - unnamedResult - deferInLoop # TODO(ldez) should be use on the project settings: hugeParam: sizeThreshold: 100 linters: enable-all: true disable: - deadcode # deprecated - 
exhaustivestruct # deprecated - golint # deprecated - ifshort # deprecated - interfacer # deprecated - maligned # deprecated - nosnakecase # deprecated - scopelint # deprecated - structcheck # deprecated - varcheck # deprecated - sqlclosecheck # not relevant (SQL) - rowserrcheck # not relevant (SQL) - execinquery # not relevant (SQL) - cyclop # duplicate of gocyclo - lll - dupl - wsl - nlreturn - gomnd - goerr113 - wrapcheck - exhaustive - exhaustruct - testpackage - tparallel - paralleltest - prealloc - ifshort - forcetypeassert - bodyclose # Too many false positives: https://github.com/timakin/bodyclose/issues/30 - varnamelen - noctx - tagliatelle - nilnil - ireturn - nonamedreturns - gochecknoglobals # TODO(ldez) should be use on the project - nestif # TODO(ldez) should be use on the project issues: exclude-use-default: false max-per-linter: 0 max-same-issues: 0 exclude: - 'ST1000: at least one file in a package should have a package comment' # TODO(ldez) must be fixed - 'package-comments: should have a package comment' - 'Error return value of .((os\\.)?std(out|err)\\..*|.*Close|.*Flush|os\\.Remove(All)?|.*printf?|os\\.(Un)?Setenv). is not checked' - 'SA1019: http.CloseNotifier has been deprecated' - 'exported: func name will be used as roundrobin.RoundRobinRequestRewriteListener by other packages'# TODO(ldez) must be fixed - 'G101: Potential hardcoded credentials' # TODO(ldez) https://github.com/golangci/golangci-lint/issues/4037 exclude-rules: - path: .*_test.go linters: - funlen - gosec - path: testutils/.+ linters: - gosec - path: cbreaker/cbreaker_test.go text: "`statsNetErrors` - `threshold` always receives `0.6`" # TODO(ldez) must be fixed - path: buffer/buffer.go text: "(cognitive|cyclomatic) complexity \\d+ of func `\\(\\*Buffer\\)\\.ServeHTTP` is high" # TODO(ldez) must be fixed - path: buffer/buffer.go text: "Function 'ServeHTTP' has too many statements" # TODO(ldez) must be fixed - path: forward/fwd.go text: "(cognitive|cyclomatic) complexity \\d+ of func `\\(\\*httpForwarder\\)\\.serveWebSocket` is high" # TODO(ldez) must be fixed - path: forward/fwd.go text: "Function 'serveWebSocket' has too many statements" # TODO(ldez) must be fixed - path: utils/handler.go text: "ifElseChain: rewrite if-else to switch statement" # TODO(ldez) must be fixed oxy-2.0.0/LICENSE000066400000000000000000000260751450475076400133630ustar00rootroot00000000000000Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright {yyyy} {name of copyright owner} Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. oxy-2.0.0/Makefile000066400000000000000000000004351450475076400140060ustar00rootroot00000000000000.PHONY: default clean checks test test-verbose export GO111MODULE=on default: clean checks test test: clean go test -race -cover -count 1 ./... test-verbose: clean go test -v -race -cover ./... clean: find . 
-name flymake_* -delete
	rm -f cover.out

checks:
	golangci-lint run

oxy-2.0.0/README.md

Oxy [![Build Status](https://travis-ci.org/vulcand/oxy.svg?branch=master)](https://travis-ci.org/vulcand/oxy)
=====

Oxy is a Go library with HTTP handlers that enhance the HTTP standard library:

* [Buffer](https://pkg.go.dev/github.com/vulcand/oxy/buffer) retries and buffers requests and responses
* [Stream](https://pkg.go.dev/github.com/vulcand/oxy/stream) passes through requests and supports chunked encoding with a configurable flush interval
* [Forward](https://pkg.go.dev/github.com/vulcand/oxy/forward) forwards requests to a remote location and rewrites headers
* [Roundrobin](https://pkg.go.dev/github.com/vulcand/oxy/roundrobin) is a round-robin load balancer
* [Circuit Breaker](https://pkg.go.dev/github.com/vulcand/oxy/cbreaker) is a Hystrix-style circuit breaker
* [Connlimit](https://pkg.go.dev/github.com/vulcand/oxy/connlimit) limits simultaneous connections
* [Ratelimit](https://pkg.go.dev/github.com/vulcand/oxy/ratelimit) is a rate limiter (based on the token bucket algorithm)
* [Trace](https://pkg.go.dev/github.com/vulcand/oxy/trace) is a structured request and response logger

It is designed to be fully compatible with the http standard library, and easy to customize and reuse.

Status
------

* Initial design is completed
* Covered by tests
* Used as a reverse proxy engine in [Vulcand](https://github.com/vulcand/vulcand)

Quickstart
-----------

Every handler is an ``http.Handler``, so writing and plugging in a middleware is easy. Let us write a simple reverse proxy as an example:

Simple reverse proxy
====================

```go
import (
	"net/http"

	"github.com/vulcand/oxy/v2/forward"
	"github.com/vulcand/oxy/v2/testutils"
)

// Forwards incoming requests to whatever location URL points to, adds proper forwarding headers
fwd := forward.New(false)

redirect := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
	// let us forward this request to another server
	req.URL = testutils.ParseURI("http://localhost:63450")
	fwd.ServeHTTP(w, req)
})

// that's it! our reverse proxy is ready!
s := &http.Server{
	Addr:    ":8080",
	Handler: redirect,
}
s.ListenAndServe()
```

As a next step, let us add a round-robin load balancer:

```go
import (
	"net/http"

	"github.com/vulcand/oxy/v2/forward"
	"github.com/vulcand/oxy/v2/roundrobin"
)

// Forwards incoming requests to whatever location URL points to, adds proper forwarding headers
fwd := forward.New(false)
lb, _ := roundrobin.New(fwd)

lb.UpsertServer(url1)
lb.UpsertServer(url2)

s := &http.Server{
	Addr:    ":8080",
	Handler: lb,
}
s.ListenAndServe()
```

What if we want to handle retries and replay the request in case of errors? The `buffer` handler will help:

```go
import (
	"net/http"

	"github.com/vulcand/oxy/v2/forward"
	"github.com/vulcand/oxy/v2/buffer"
	"github.com/vulcand/oxy/v2/roundrobin"
)

// Forwards incoming requests to whatever location URL points to, adds proper forwarding headers
fwd := forward.New(false)
lb, _ := roundrobin.New(fwd)

// buffer will read the request body and replay the request again if forward returned a status
// corresponding to a network error (e.g. Gateway Timeout)
buffer, _ := buffer.New(lb, buffer.Retry(`IsNetworkError() && Attempts() < 2`))

lb.UpsertServer(url1)
lb.UpsertServer(url2)

// that's it! our reverse proxy is ready!
s := &http.Server{
	Addr:    ":8080",
	Handler: buffer,
}
s.ListenAndServe()
```

oxy-2.0.0/buffer/buffer.go

/*
Package buffer provides http.Handler middleware that solves several problems when dealing with http requests:

Reads the entire request and response into a buffer, optionally buffering it to disk for large requests.

Checks the limits for the requests and responses, rejecting the request if a limit is exceeded.

Changes the request content-transfer-encoding from chunked and provides the total size to the handlers.

Examples of a buffering middleware:

	// sample HTTP handler.
	handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte("hello"))
	})

	// Buffer will read the body into a buffer before passing the request to the handler,
	// calculate the total size of the request and transform it from chunked encoding
	// before passing it to the server
	buffer.New(handler)

	// This version will buffer up to 2MB in memory and will serialize any extra
	// to a temporary file; if the request size exceeds 10MB it will reject the request
	buffer.New(handler, buffer.MemRequestBodyBytes(2 * 1024 * 1024), buffer.MaxRequestBodyBytes(10 * 1024 * 1024))

	// Will do the same as above, but with responses
	buffer.New(handler, buffer.MemResponseBodyBytes(2 * 1024 * 1024), buffer.MaxResponseBodyBytes(10 * 1024 * 1024))

	// Buffer will replay the request while the retry predicate matches (here, for up to 3 attempts in total)
	// before returning the response
	buffer.New(handler, buffer.Retry(`IsNetworkError() && Attempts() <= 2`))
*/
package buffer

import (
	"bufio"
	"fmt"
	"io"
	"net"
	"net/http"
	"reflect"

	"github.com/mailgun/multibuf"
	"github.com/vulcand/oxy/v2/utils"
)

const (
	// DefaultMemBodyBytes Store up to 1MB in RAM.
	DefaultMemBodyBytes = 1048576

	// DefaultMaxBodyBytes No limit by default.
	DefaultMaxBodyBytes = -1

	// DefaultMaxRetryAttempts Maximum retry attempts.
	DefaultMaxRetryAttempts = 10
)

var errHandler utils.ErrorHandler = &SizeErrHandler{}

// Buffer is responsible for buffering requests and responses.
// It buffers large requests and responses to disk.
type Buffer struct {
	maxRequestBodyBytes int64
	memRequestBodyBytes int64

	maxResponseBodyBytes int64
	memResponseBodyBytes int64

	retryPredicate hpredicate

	next       http.Handler
	errHandler utils.ErrorHandler

	verbose bool
	log     utils.Logger
}

// New returns a new buffer middleware. New() function supports optional functional arguments.
func New(next http.Handler, setters ...Option) (*Buffer, error) {
	strm := &Buffer{
		next: next,

		maxRequestBodyBytes: DefaultMaxBodyBytes,
		memRequestBodyBytes: DefaultMemBodyBytes,

		maxResponseBodyBytes: DefaultMaxBodyBytes,
		memResponseBodyBytes: DefaultMemBodyBytes,

		log: &utils.NoopLogger{},
	}

	for _, s := range setters {
		if err := s(strm); err != nil {
			return nil, err
		}
	}

	if strm.errHandler == nil {
		strm.errHandler = errHandler
	}

	return strm, nil
}

// Wrap sets the next handler to be called by the buffer handler.
func (b *Buffer) Wrap(next http.Handler) error { b.next = next return nil } func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) { if b.verbose { dump := utils.DumpHTTPRequest(req) b.log.Debug("vulcand/oxy/buffer: begin ServeHttp on request: %s", dump) defer b.log.Debug("vulcand/oxy/buffer: completed ServeHttp on request: %s", dump) } if err := b.checkLimit(req); err != nil { b.log.Error("vulcand/oxy/buffer: request body over limit, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } // Read the body while keeping limits in mind. This reader controls the maximum bytes // to read into memory and disk. This reader returns an error if the total request size exceeds the // predefined MaxSizeBytes. This can occur if we got chunked request, in this case ContentLength would be set to -1 // and the reader would be unbounded bufio in the http.Server body, err := multibuf.New(req.Body, multibuf.MaxBytes(b.maxRequestBodyBytes), multibuf.MemBytes(b.memRequestBodyBytes)) if err != nil || body == nil { if req.Context().Err() != nil { b.log.Error("vulcand/oxy/buffer: error when reading request body, err: %v", req.Context().Err()) b.errHandler.ServeHTTP(w, req, req.Context().Err()) return } b.log.Error("vulcand/oxy/buffer: error when reading request body, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } // Set request body to buffered reader that can replay the read and execute Seek // Note that we don't change the original request body as it's handled by the http server // and we don't want to mess with standard library defer func() { if body != nil { errClose := body.Close() if errClose != nil { b.log.Error("vulcand/oxy/buffer: failed to close body, err: %v", errClose) } } }() // We need to set ContentLength based on known request size. The incoming request may have been // set without content length or using chunked TransferEncoding totalSize, err := body.Size() if err != nil { b.log.Error("vulcand/oxy/buffer: failed to get request size, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } if totalSize == 0 { body = nil } outReq := b.copyRequest(req, body, totalSize) attempt := 1 for { // We create a special writer that will limit the response size, buffer it to disk if necessary writer, err := multibuf.NewWriterOnce(multibuf.MaxBytes(b.maxResponseBodyBytes), multibuf.MemBytes(b.memResponseBodyBytes)) if err != nil { b.log.Error("vulcand/oxy/buffer: failed create response writer, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } // We are mimicking http.ResponseWriter to replace writer with our special writer bw := &bufferWriter{ header: make(http.Header), buffer: writer, responseWriter: w, log: b.log, } defer bw.Close() b.next.ServeHTTP(bw, outReq) if bw.hijacked { b.log.Debug("vulcand/oxy/buffer: connection was hijacked downstream. 
Not taking any action in buffer.") return } var reader multibuf.MultiReader if bw.expectBody(outReq) { rdr, err := writer.Reader() if err != nil { b.log.Error("vulcand/oxy/buffer: failed to read response, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } defer rdr.Close() reader = rdr } if (b.retryPredicate == nil || attempt > DefaultMaxRetryAttempts) || !b.retryPredicate(&context{r: req, attempt: attempt, responseCode: bw.code}) { utils.CopyHeaders(w.Header(), bw.Header()) w.WriteHeader(bw.code) if reader != nil { _, _ = io.Copy(w, reader) } return } attempt++ if body != nil { if _, err := body.Seek(0, 0); err != nil { b.log.Error("vulcand/oxy/buffer: failed to rewind response body, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } } outReq = b.copyRequest(req, body, totalSize) b.log.Debug("vulcand/oxy/buffer: retry Request(%v %v) attempt %v", req.Method, req.URL, attempt) } } func (b *Buffer) copyRequest(req *http.Request, body io.ReadCloser, bodySize int64) *http.Request { o := *req o.URL = utils.CopyURL(req.URL) o.Header = make(http.Header) utils.CopyHeaders(o.Header, req.Header) o.ContentLength = bodySize // remove TransferEncoding that could have been previously set because we have transformed the request from chunked encoding o.TransferEncoding = []string{} // http.Transport will close the request body on any error, we are controlling the close process ourselves, so we override the closer here if body == nil { o.Body = io.NopCloser(req.Body) } else { o.Body = io.NopCloser(body.(io.Reader)) } return &o } func (b *Buffer) checkLimit(req *http.Request) error { if b.maxRequestBodyBytes <= 0 { return nil } if req.ContentLength > b.maxRequestBodyBytes { return &multibuf.MaxSizeReachedError{MaxSize: b.maxRequestBodyBytes} } return nil } type bufferWriter struct { header http.Header code int buffer multibuf.WriterOnce responseWriter http.ResponseWriter hijacked bool log utils.Logger } // RFC2616 #4.4. func (b *bufferWriter) expectBody(r *http.Request) bool { if r.Method == http.MethodHead { return false } if (b.code >= 100 && b.code < 200) || b.code == 204 || b.code == 304 { return false } // refer to https://github.com/vulcand/oxy/issues/113 // if b.header.Get("Content-Length") == "" && b.header.Get("Transfer-Encoding") == "" { // return false // } if b.header.Get("Content-Length") == "0" { return false } // Support for gRPC, gRPC Web. if grpcStatus := b.header.Get("Grpc-Status"); grpcStatus != "" && grpcStatus != "0" { return false } return true } func (b *bufferWriter) Close() error { return b.buffer.Close() } func (b *bufferWriter) Header() http.Header { return b.header } func (b *bufferWriter) Write(buf []byte) (int, error) { length, err := b.buffer.Write(buf) if err != nil { // Since go1.11 (https://github.com/golang/go/commit/8f38f28222abccc505b9a1992deecfe3e2cb85de) // if the writer returns an error, the reverse proxy panics b.log.Error("write: %v", err) length = len(buf) } return length, nil } // WriteHeader sets rw.Code. func (b *bufferWriter) WriteHeader(code int) { b.code = code } // CloseNotify CloseNotifier interface - this allows downstream connections to be terminated when the client terminates. func (b *bufferWriter) CloseNotify() <-chan bool { if cn, ok := b.responseWriter.(http.CloseNotifier); ok { return cn.CloseNotify() } b.log.Warn("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. 
Returning dummy channel.", reflect.TypeOf(b.responseWriter)) return make(<-chan bool) } // Hijack This allows connections to be hijacked for websockets for instance. func (b *bufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hi, ok := b.responseWriter.(http.Hijacker); ok { conn, rw, err := hi.Hijack() if err == nil { b.hijacked = true } return conn, rw, err } b.log.Warn("Upstream ResponseWriter of type %v does not implement http.Hijacker.", reflect.TypeOf(b.responseWriter)) return nil, nil, fmt.Errorf("the response writer wrapped in this proxy does not implement http.Hijacker. Its type is: %v", reflect.TypeOf(b.responseWriter)) } // SizeErrHandler Size error handler. type SizeErrHandler struct{} func (e *SizeErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { //nolint:errorlint // must be changed if _, ok := err.(*multibuf.MaxSizeReachedError); ok { w.WriteHeader(http.StatusRequestEntityTooLarge) _, _ = w.Write([]byte(http.StatusText(http.StatusRequestEntityTooLarge))) return } utils.DefaultHandler.ServeHTTP(w, req, err) } oxy-2.0.0/buffer/buffer_test.go000066400000000000000000000313101450475076400164520ustar00rootroot00000000000000package buffer import ( "bufio" "crypto/tls" "fmt" "io" "net" "net/http" "net/http/httptest" "strconv" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/forward" "github.com/vulcand/oxy/v2/testutils" "github.com/vulcand/oxy/v2/utils" ) func TestSimple(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello", string(body)) } func TestChunkedEncodingSuccess(t *testing.T) { var reqBody string var contentLength int64 srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { body, err := io.ReadAll(req.Body) require.NoError(t, err) reqBody = string(body) contentLength = req.ContentLength _, _ = w.Write([]byte("hello")) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) conn, err := net.Dial("tcp", testutils.ParseURI(proxy.URL).Host) require.NoError(t, err) _, _ = fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: 127.0.0.1:8080\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n5\r\ntest1\r\n5\r\ntest2\r\n0\r\n\r\n") status, err := bufio.NewReader(conn).ReadString('\n') require.NoError(t, err) assert.Equal(t, "testtest1test2", reqBody) assert.Equal(t, "HTTP/1.1 200 OK\r\n", status) assert.EqualValues(t, len(reqBody), contentLength) } func TestChunkedEncodingLimitReached(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { 
_, _ = w.Write([]byte("hello")) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr, MemRequestBodyBytes(4), MaxRequestBodyBytes(8)) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) conn, err := net.Dial("tcp", testutils.ParseURI(proxy.URL).Host) require.NoError(t, err) _, _ = fmt.Fprint(conn, "POST / HTTP/1.1\r\nHost: 127.0.0.1:8080\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n5\r\ntest1\r\n5\r\ntest2\r\n0\r\n\r\n") status, err := bufio.NewReader(conn).ReadString('\n') require.NoError(t, err) assert.Equal(t, "HTTP/1.1 413 Request Entity Too Large\r\n", status) } func TestChunkedResponse(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { h := w.(http.Hijacker) conn, _, _ := h.Hijack() _, _ = fmt.Fprintf(conn, "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n5\r\ntest1\r\n5\r\ntest2\r\n0\r\n\r\n") _ = conn.Close() }) t.Cleanup(srv.Close) fwd := forward.New(false) rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, "testtest1test2", string(body)) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, strconv.Itoa(len("testtest1test2")), re.Header.Get("Content-Length")) } func TestRequestLimitReached(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr, MaxRequestBodyBytes(4)) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL, testutils.Body("this request is too long")) require.NoError(t, err) assert.Equal(t, http.StatusRequestEntityTooLarge, re.StatusCode) } func TestResponseLimitReached(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello, this response is too large")) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr, MaxResponseBodyBytes(4)) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, re.StatusCode) } func TestFileStreamingResponse(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello, this response is too large to fit in memory")) }) t.Cleanup(srv.Close) // forwarder will proxy the 
request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr, MemResponseBodyBytes(4)) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello, this response is too large to fit in memory", string(body)) } func TestCustomErrorHandler(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello, this response is too large")) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, req *http.Request, err error) { w.WriteHeader(http.StatusTeapot) _, _ = w.Write([]byte(http.StatusText(http.StatusTeapot))) }) st, err := New(rdr, MaxResponseBodyBytes(4), ErrorHandler(errHandler)) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusTeapot, re.StatusCode) } func TestNotModified(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusNotModified) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusNotModified, re.StatusCode) } func TestNoBody(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusOK) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // Make sure that stream handler preserves TLS settings. 
func TestPreservesTLS(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) var cs *tls.ConnectionState // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { cs = req.TLS req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewUnstartedServer(st) proxy.StartTLS() t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.NotNil(t, cs) } func TestNotNilBody(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) // During a request check if the request body is no nil before sending to the next middleware // Because we can send a POST request without body assert.NotNil(t, req.Body) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr, MaxRequestBodyBytes(10)) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello", string(body)) re, body, err = testutils.Post(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello", string(body)) } func TestGRPCErrorResponse(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Header().Set("Grpc-Status", "10" /* ABORTED */) w.WriteHeader(http.StatusOK) // To skip the "Content-Length" header. w.(http.Flusher).Flush() }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Empty(t, body) } func TestGRPCOKResponse(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Header().Set("Grpc-Status", "0" /* OK */) _, _ = w.Write([]byte("grpc-body")) w.WriteHeader(http.StatusOK) // To skip the "Content-Length" header. 
w.(http.Flusher).Flush() }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "grpc-body", string(body)) } oxy-2.0.0/buffer/options.go000066400000000000000000000051631450475076400156440ustar00rootroot00000000000000package buffer import ( "fmt" "github.com/vulcand/oxy/v2/utils" ) // Option represents an option you can pass to New. type Option func(b *Buffer) error // Logger defines the logger used by Buffer. func Logger(l utils.Logger) Option { return func(b *Buffer) error { b.log = l return nil } } // Verbose additional debug information. func Verbose(verbose bool) Option { return func(b *Buffer) error { b.verbose = verbose return nil } } // Cond Conditional setter. // ex: Cond(a > 4, MemRequestBodyBytes(a)) func Cond(condition bool, setter Option) Option { if !condition { // NoOp setter return func(*Buffer) error { return nil } } return setter } // Retry provides a predicate that allows buffer middleware to replay the request // if it matches certain condition, e.g. returns special error code. Available functions are: // // Attempts() - limits the amount of retry attempts // ResponseCode() - returns http response code // IsNetworkError() - tests if response code is related to networking error // // Example of the predicate: // // `Attempts() <= 2 && ResponseCode() == 502`. func Retry(predicate string) Option { return func(b *Buffer) error { p, err := parseExpression(predicate) if err != nil { return err } b.retryPredicate = p return nil } } // ErrorHandler sets error handler of the server. func ErrorHandler(h utils.ErrorHandler) Option { return func(b *Buffer) error { b.errHandler = h return nil } } // MaxRequestBodyBytes sets the maximum request body size in bytes. func MaxRequestBodyBytes(m int64) Option { return func(b *Buffer) error { if m < 0 { return fmt.Errorf("max bytes should be >= 0 got %d", m) } b.maxRequestBodyBytes = m return nil } } // MemRequestBodyBytes bytes sets the maximum request body to be stored in memory // buffer middleware will serialize the excess to disk. func MemRequestBodyBytes(m int64) Option { return func(b *Buffer) error { if m < 0 { return fmt.Errorf("mem bytes should be >= 0 got %d", m) } b.memRequestBodyBytes = m return nil } } // MaxResponseBodyBytes sets the maximum response body size in bytes. func MaxResponseBodyBytes(m int64) Option { return func(b *Buffer) error { if m < 0 { return fmt.Errorf("max bytes should be >= 0 got %d", m) } b.maxResponseBodyBytes = m return nil } } // MemResponseBodyBytes sets the maximum response body to be stored in memory // buffer middleware will serialize the excess to disk. 
func MemResponseBodyBytes(m int64) Option { return func(b *Buffer) error { if m < 0 { return fmt.Errorf("mem bytes should be >= 0 got %d", m) } b.memResponseBodyBytes = m return nil } } oxy-2.0.0/buffer/retry_test.go000066400000000000000000000051611450475076400163530ustar00rootroot00000000000000package buffer import ( "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/forward" "github.com/vulcand/oxy/v2/roundrobin" "github.com/vulcand/oxy/v2/testutils" ) func TestSuccess(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) t.Cleanup(srv.Close) lb, rt := newBufferMiddleware(t, `IsNetworkError() && Attempts() <= 2`) proxy := httptest.NewServer(rt) t.Cleanup(proxy.Close) require.NoError(t, lb.UpsertServer(testutils.ParseURI(srv.URL))) re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello", string(body)) } func TestRetryOnError(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) t.Cleanup(srv.Close) lb, rt := newBufferMiddleware(t, `IsNetworkError() && Attempts() <= 2`) proxy := httptest.NewServer(rt) t.Cleanup(proxy.Close) require.NoError(t, lb.UpsertServer(testutils.ParseURI("http://localhost:64321"))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(srv.URL))) re, body, err := testutils.Get(proxy.URL, testutils.Body("some request parameters")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello", string(body)) } func TestRetryExceedAttempts(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) t.Cleanup(srv.Close) lb, rt := newBufferMiddleware(t, `IsNetworkError() && Attempts() <= 2`) proxy := httptest.NewServer(rt) t.Cleanup(proxy.Close) require.NoError(t, lb.UpsertServer(testutils.ParseURI("http://localhost:64321"))) require.NoError(t, lb.UpsertServer(testutils.ParseURI("http://localhost:64322"))) require.NoError(t, lb.UpsertServer(testutils.ParseURI("http://localhost:64323"))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(srv.URL))) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusBadGateway, re.StatusCode) } func newBufferMiddleware(t *testing.T, p string) (*roundrobin.RoundRobin, *Buffer) { t.Helper() // forwarder will proxy the request to whatever destination fwd := forward.New(false) // load balancer will round robin request lb, err := roundrobin.New(fwd) require.NoError(t, err) // stream handler will forward requests to redirect, make sure it uses files st, err := New(lb, Retry(p), MemRequestBodyBytes(1)) require.NoError(t, err) return lb, st } oxy-2.0.0/buffer/threshold.go000066400000000000000000000123161450475076400161430ustar00rootroot00000000000000package buffer import ( "fmt" "net/http" "github.com/vulcand/predicate" ) type hpredicate func(*context) bool // IsValidExpression check if it's a valid expression. func IsValidExpression(expr string) bool { _, err := parseExpression(expr) return err == nil } // Parses expression in the go language into Failover predicates. 
func parseExpression(in string) (hpredicate, error) { p, err := predicate.NewParser(predicate.Def{ Operators: predicate.Operators{ AND: and, OR: or, EQ: eq, NEQ: neq, LT: lt, GT: gt, LE: le, GE: ge, }, Functions: map[string]interface{}{ "RequestMethod": requestMethod, "IsNetworkError": isNetworkError, "Attempts": attempts, "ResponseCode": responseCode, }, }) if err != nil { return nil, err } out, err := p.Parse(in) if err != nil { return nil, err } pr, ok := out.(hpredicate) if !ok { return nil, fmt.Errorf("expected predicate, got %T", out) } return pr, nil } // IsNetworkError returns a predicate that returns true if last attempt ended with network error. func isNetworkError() hpredicate { return func(c *context) bool { return c.responseCode == http.StatusBadGateway || c.responseCode == http.StatusGatewayTimeout } } // and returns predicate by joining the passed predicates with logical 'and'. func and(fns ...hpredicate) hpredicate { return func(c *context) bool { for _, fn := range fns { if !fn(c) { return false } } return true } } // or returns predicate by joining the passed predicates with logical 'or'. func or(fns ...hpredicate) hpredicate { return func(c *context) bool { for _, fn := range fns { if fn(c) { return true } } return false } } // not creates negation of the passed predicate. func not(p hpredicate) hpredicate { return func(c *context) bool { return !p(c) } } // eq returns predicate that tests for equality of the value of the mapper and the constant. func eq(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toString: return stringEQ(mapper, value) case toInt: return intEQ(mapper, value) } return nil, fmt.Errorf("unsupported argument: %T", m) } // neq returns predicate that tests for inequality of the value of the mapper and the constant. func neq(m interface{}, value interface{}) (hpredicate, error) { p, err := eq(m, value) if err != nil { return nil, err } return not(p), nil } // lt returns predicate that tests that value of the mapper function is less than the constant. func lt(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toInt: return intLT(mapper, value) default: return nil, fmt.Errorf("unsupported argument: %T", m) } } // le returns predicate that tests that value of the mapper function is less or equal than the constant. func le(m interface{}, value interface{}) (hpredicate, error) { l, err := lt(m, value) if err != nil { return nil, err } e, err := eq(m, value) if err != nil { return nil, err } return func(c *context) bool { return l(c) || e(c) }, nil } // gt returns predicate that tests that value of the mapper function is greater than the constant. func gt(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toInt: return intGT(mapper, value) default: return nil, fmt.Errorf("unsupported argument: %T", m) } } // ge returns predicate that tests that value of the mapper function is less or equal than the constant. 
func ge(m interface{}, value interface{}) (hpredicate, error) {
	g, err := gt(m, value)
	if err != nil {
		return nil, err
	}
	e, err := eq(m, value)
	if err != nil {
		return nil, err
	}
	return func(c *context) bool {
		return g(c) || e(c)
	}, nil
}

func stringEQ(m toString, val interface{}) (hpredicate, error) {
	value, ok := val.(string)
	if !ok {
		return nil, fmt.Errorf("expected string, got %T", val)
	}
	return func(c *context) bool {
		return m(c) == value
	}, nil
}

func intEQ(m toInt, val interface{}) (hpredicate, error) {
	value, ok := val.(int)
	if !ok {
		return nil, fmt.Errorf("expected int, got %T", val)
	}
	return func(c *context) bool {
		return m(c) == value
	}, nil
}

func intLT(m toInt, val interface{}) (hpredicate, error) {
	value, ok := val.(int)
	if !ok {
		return nil, fmt.Errorf("expected int, got %T", val)
	}
	return func(c *context) bool {
		return m(c) < value
	}, nil
}

func intGT(m toInt, val interface{}) (hpredicate, error) {
	value, ok := val.(int)
	if !ok {
		return nil, fmt.Errorf("expected int, got %T", val)
	}
	return func(c *context) bool {
		return m(c) > value
	}, nil
}

type context struct {
	r            *http.Request
	attempt      int
	responseCode int
}

type toString func(c *context) string

type toInt func(c *context) int

// requestMethod returns a mapper of the request to its method, e.g. POST.
func requestMethod() toString {
	return func(c *context) string {
		return c.r.Method
	}
}

// attempts returns a mapper of the request to the number of proxy attempts.
func attempts() toInt {
	return func(c *context) int {
		return c.attempt
	}
}

// responseCode returns a mapper of the request to the last response code; returns 0 if there was no response code.
func responseCode() toInt {
	return func(c *context) int {
		return c.responseCode
	}
}

oxy-2.0.0/cbreaker/cbreaker.go

// Package cbreaker implements a circuit breaker similar to https://github.com/Netflix/Hystrix/wiki/How-it-Works
//
// The Vulcan circuit breaker watches for the error condition to match,
// after which it activates the fallback scenario, e.g. returns the response code
// or redirects the request to another location.
//
// Circuit breakers start in the Standby state first, observing responses and watching location metrics.
//
// Once the circuit breaker condition is met, it enters the "Tripped" state, where it activates the fallback scenario
// for all requests during the FallbackDuration time period and resets the stats for the location.
//
// After the FallbackDuration time period passes, the circuit breaker enters the "Recovering" state; during that state it will
// start passing some traffic back to the endpoints, increasing the amount of passed requests using a linear function:
//
//	allowedRequestsRatio = 0.5 * (Now() - StartRecovery())/RecoveryDuration
//
// Two scenarios are possible in the "Recovering" state:
//  1. The condition matches again: this will reset the state to "Tripped" and reset the timer.
//  2. The condition does not match: the circuit breaker enters the "Standby" state.
//
// It is possible to define actions (e.g. webhooks) on transitions between states:
//
// * OnTripped action is called on transition (Standby -> Tripped)
// * OnStandby action is called on transition (Recovering -> Standby)
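//
// A minimal usage sketch follows; the plain handler and the server wiring are illustrative
// assumptions and not part of this package, while New and the predicate expression language
// are taken from this package and its tests:
//
//	handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
//		w.Write([]byte("hello"))
//	})
//
//	// trip the breaker once more than half of the observed requests end in a network error
//	cb, err := cbreaker.New(handler, `NetworkErrorRatio() > 0.5`)
//	if err != nil {
//		// handle the configuration error
//	}
//
//	http.ListenAndServe(":8080", cb)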
package cbreaker

import (
	"fmt"
	"net/http"
	"sync"
	"time"

	"github.com/vulcand/oxy/v2/internal/holsterv4/clock"
	"github.com/vulcand/oxy/v2/memmetrics"
	"github.com/vulcand/oxy/v2/utils"
)

// CircuitBreaker is an http.Handler that implements the circuit breaker pattern.
type CircuitBreaker struct {
	m       *sync.RWMutex
	metrics *memmetrics.RTMetrics

	condition hpredicate

	fallbackDuration time.Duration
	recoveryDuration time.Duration

	onTripped SideEffect
	onStandby SideEffect

	state cbState
	until clock.Time

	rc *ratioController

	checkPeriod time.Duration
	lastCheck   clock.Time

	fallback http.Handler
	next     http.Handler

	verbose bool
	log     utils.Logger
}

// New creates a new CircuitBreaker middleware.
func New(next http.Handler, expression string, options ...Option) (*CircuitBreaker, error) {
	cb := &CircuitBreaker{
		m:    &sync.RWMutex{},
		next: next,
		// Default values. Might be overwritten by options below.
		checkPeriod:      defaultCheckPeriod,
		fallbackDuration: defaultFallbackDuration,
		recoveryDuration: defaultRecoveryDuration,
		fallback:         defaultFallback,
		log:              &utils.NoopLogger{},
	}

	for _, s := range options {
		if err := s(cb); err != nil {
			return nil, err
		}
	}

	condition, err := parseExpression(expression)
	if err != nil {
		return nil, err
	}
	cb.condition = condition

	mt, err := memmetrics.NewRTMetrics()
	if err != nil {
		return nil, err
	}
	cb.metrics = mt

	return cb, nil
}

func (c *CircuitBreaker) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if c.verbose {
		dump := utils.DumpHTTPRequest(req)

		c.log.Debug("vulcand/oxy/circuitbreaker: begin ServeHttp on request: %s", dump)
		defer c.log.Debug("vulcand/oxy/circuitbreaker: completed ServeHttp on request: %s", dump)
	}

	if c.activateFallback(w, req) {
		c.fallback.ServeHTTP(w, req)
		return
	}

	c.serve(w, req)
}

// Fallback sets the fallback handler to be called by the circuit breaker handler.
func (c *CircuitBreaker) Fallback(f http.Handler) {
	c.fallback = f
}

// Wrap sets the next handler to be called by the circuit breaker handler.
func (c *CircuitBreaker) Wrap(next http.Handler) {
	c.next = next
}

// activateFallback updates the internal state and returns true if the fallback should be used and false otherwise.
func (c *CircuitBreaker) activateFallback(_ http.ResponseWriter, _ *http.Request) bool { // Quick check with read locks optimized for normal operation use-case if c.isStandby() { return false } // Circuit breaker is in tripped or recovering state c.m.Lock() defer c.m.Unlock() c.log.Warn("%v is in error state", c) switch c.state { case stateStandby: // someone else has set it to standby just now return false case stateTripped: if clock.Now().UTC().Before(c.until) { return true } // We have been in active state enough, enter recovering state c.setRecovering() fallthrough case stateRecovering: // We have been in recovering state enough, enter standby and allow request if clock.Now().UTC().After(c.until) { c.setState(stateStandby, clock.Now().UTC()) return false } // ratio controller allows this request if c.rc.allowRequest() { return false } return true } return false } func (c *CircuitBreaker) serve(w http.ResponseWriter, req *http.Request) { start := clock.Now().UTC() p := utils.NewProxyWriterWithLogger(w, c.log) c.next.ServeHTTP(p, req) latency := clock.Now().UTC().Sub(start) c.metrics.Record(p.StatusCode(), latency) // Note that this call is less expensive than it looks -- checkCondition only performs the real check // periodically. Because of that we can afford to call it here on every single response. c.checkAndSet() } func (c *CircuitBreaker) isStandby() bool { c.m.RLock() defer c.m.RUnlock() return c.state == stateStandby } // String returns log-friendly representation of the circuit breaker state. func (c *CircuitBreaker) String() string { switch c.state { case stateTripped, stateRecovering: return fmt.Sprintf("CircuitBreaker(state=%v, until=%v)", c.state, c.until) default: return fmt.Sprintf("CircuitBreaker(state=%v)", c.state) } } // exec executes side effect. func (c *CircuitBreaker) exec(s SideEffect) { if s == nil { return } go func() { if err := s.Exec(); err != nil { c.log.Error("%v side effect failure: %v", c, err) } }() } func (c *CircuitBreaker) setState(state cbState, until time.Time) { c.log.Debug("%v setting state to %v, until %v", c, state, until) c.state = state c.until = until switch state { case stateTripped: c.exec(c.onTripped) case stateStandby: c.exec(c.onStandby) } } func (c *CircuitBreaker) timeToCheck() bool { c.m.RLock() defer c.m.RUnlock() return clock.Now().UTC().After(c.lastCheck) } // Checks if tripping condition matches and sets circuit breaker to the tripped state. func (c *CircuitBreaker) checkAndSet() { if !c.timeToCheck() { return } c.m.Lock() defer c.m.Unlock() // Other goroutine could have updated the lastCheck variable before we grabbed mutex if clock.Now().UTC().Before(c.lastCheck) { return } c.lastCheck = clock.Now().UTC().Add(c.checkPeriod) if c.state == stateTripped { c.log.Debug("%v skip set tripped", c) return } if !c.condition(c) { return } c.setState(stateTripped, clock.Now().UTC().Add(c.fallbackDuration)) c.metrics.Reset() } func (c *CircuitBreaker) setRecovering() { c.setState(stateRecovering, clock.Now().UTC().Add(c.recoveryDuration)) c.rc = newRatioController(c.recoveryDuration, c.log) } // cbState is the state of the circuit breaker. type cbState int func (s cbState) String() string { switch s { case stateStandby: return "standby" case stateTripped: return "tripped" case stateRecovering: return "recovering" } return "undefined" } const ( // CircuitBreaker is passing all requests and watching stats. stateStandby = iota // CircuitBreaker activates fallback scenario for all requests. 
stateTripped // CircuitBreaker allows some requests to go through, rejecting others. stateRecovering ) const ( defaultFallbackDuration = 10 * clock.Second defaultRecoveryDuration = 10 * clock.Second defaultCheckPeriod = 100 * clock.Millisecond ) var defaultFallback = &fallback{} type fallback struct{} func (f *fallback) ServeHTTP(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusServiceUnavailable) _, _ = w.Write([]byte(http.StatusText(http.StatusServiceUnavailable))) } oxy-2.0.0/cbreaker/cbreaker_test.go000066400000000000000000000215051450475076400172710ustar00rootroot00000000000000package cbreaker import ( "fmt" "io" "net/http" "net/http/httptest" "net/url" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/memmetrics" "github.com/vulcand/oxy/v2/testutils" ) const triggerNetRatio = `NetworkErrorRatio() > 0.5` func TestStandbyCycle(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) cb, err := New(handler, triggerNetRatio) require.NoError(t, err) srv := httptest.NewServer(cb) t.Cleanup(srv.Close) re, body, err := testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello", string(body)) } func TestFullCycle(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) done := testutils.FreezeTime() defer done() cb, err := New(handler, triggerNetRatio) require.NoError(t, err) srv := httptest.NewServer(cb) t.Cleanup(srv.Close) re, _, err := testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) cb.metrics = statsNetErrors(0.6) clock.Advance(defaultCheckPeriod + clock.Millisecond) _, _, err = testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, cbState(stateTripped), cb.state) // Some time has passed, but we are still in tripped state. 
clock.Advance(9 * clock.Second) re, _, err = testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, http.StatusServiceUnavailable, re.StatusCode) assert.Equal(t, cbState(stateTripped), cb.state) // We should be in recovering state by now clock.Advance(clock.Second*1 + clock.Millisecond) re, _, err = testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, http.StatusServiceUnavailable, re.StatusCode) assert.Equal(t, cbState(stateRecovering), cb.state) // 5 seconds after we should be allowing some requests to pass clock.Advance(5 * clock.Second) allowed := 0 for i := 0; i < 100; i++ { re, _, err = testutils.Get(srv.URL) if re.StatusCode == http.StatusOK && err == nil { allowed++ } } assert.NotEqual(t, 0, allowed) // After some time, all is good and we should be in stand by mode again clock.Advance(5*clock.Second + clock.Millisecond) re, _, err = testutils.Get(srv.URL) assert.Equal(t, cbState(stateStandby), cb.state) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } func TestRedirectWithPath(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) fallbackRedirectPath, err := NewRedirectFallback(Redirect{ URL: "http://localhost:6000", PreservePath: true, }) require.NoError(t, err) cb, err := New(handler, triggerNetRatio, Fallback(fallbackRedirectPath)) require.NoError(t, err) srv := httptest.NewServer(cb) t.Cleanup(srv.Close) cb.metrics = statsNetErrors(0.6) _, _, err = testutils.Get(srv.URL) require.NoError(t, err) client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return fmt.Errorf("no redirects") }, } re, err := client.Get(srv.URL + "/somePath") require.Error(t, err) assert.Equal(t, http.StatusFound, re.StatusCode) assert.Equal(t, "http://localhost:6000/somePath", re.Header.Get("Location")) } func TestRedirect(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) fallbackRedirect, err := NewRedirectFallback(Redirect{URL: "http://localhost:5000"}) require.NoError(t, err) cb, err := New(handler, triggerNetRatio, Fallback(fallbackRedirect)) require.NoError(t, err) srv := httptest.NewServer(cb) t.Cleanup(srv.Close) cb.metrics = statsNetErrors(0.6) _, _, err = testutils.Get(srv.URL) require.NoError(t, err) client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return fmt.Errorf("no redirects") }, } re, err := client.Get(srv.URL + "/somePath") require.Error(t, err) assert.Equal(t, http.StatusFound, re.StatusCode) assert.Equal(t, "http://localhost:5000", re.Header.Get("Location")) } func TestTriggerDuringRecovery(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) done := testutils.FreezeTime() defer done() cb, err := New(handler, triggerNetRatio, CheckPeriod(clock.Microsecond)) require.NoError(t, err) srv := httptest.NewServer(cb) t.Cleanup(srv.Close) cb.metrics = statsNetErrors(0.6) _, _, err = testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, cbState(stateTripped), cb.state) // We should be in recovering state by now clock.Advance(10*clock.Second + clock.Millisecond) re, _, err := testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, http.StatusServiceUnavailable, re.StatusCode) assert.Equal(t, cbState(stateRecovering), cb.state) // We have matched error condition during recovery state and are going back to tripped state clock.Advance(5 * 
clock.Second) cb.metrics = statsNetErrors(0.6) allowed := 0 for i := 0; i < 100; i++ { re, _, err = testutils.Get(srv.URL) if re.StatusCode == http.StatusOK && err == nil { allowed++ } } assert.NotEqual(t, 0, allowed) assert.Equal(t, cbState(stateTripped), cb.state) } func TestSideEffects(t *testing.T) { srv1Chan := make(chan *http.Request, 1) var srv1Body []byte srv1 := testutils.NewHandler(func(w http.ResponseWriter, r *http.Request) { b, err := io.ReadAll(r.Body) require.NoError(t, err) srv1Body = b _, _ = w.Write([]byte("srv1")) srv1Chan <- r }) defer srv1.Close() srv2Chan := make(chan *http.Request, 1) srv2 := testutils.NewHandler(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("srv2")) err := r.ParseForm() require.NoError(t, err) srv2Chan <- r }) defer srv2.Close() onTripped, err := NewWebhookSideEffect( Webhook{ URL: fmt.Sprintf("%s/post.json", srv1.URL), Method: http.MethodPost, Headers: map[string][]string{"Content-Type": {"application/json"}}, Body: []byte(`{"Key": ["val1", "val2"]}`), }) require.NoError(t, err) onStandby, err := NewWebhookSideEffect( Webhook{ URL: fmt.Sprintf("%s/post", srv2.URL), Method: http.MethodPost, Form: map[string][]string{"key": {"val1", "val2"}}, }) require.NoError(t, err) handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) done := testutils.FreezeTime() defer done() cb, err := New(handler, triggerNetRatio, CheckPeriod(clock.Microsecond), OnTripped(onTripped), OnStandby(onStandby)) require.NoError(t, err) srv := httptest.NewServer(cb) t.Cleanup(srv.Close) cb.metrics = statsNetErrors(0.6) _, _, err = testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, cbState(stateTripped), cb.state) select { case req := <-srv1Chan: assert.Equal(t, http.MethodPost, req.Method) assert.Equal(t, "/post.json", req.URL.Path) assert.Equal(t, `{"Key": ["val1", "val2"]}`, string(srv1Body)) assert.Equal(t, "application/json", req.Header.Get("Content-Type")) case <-clock.After(clock.Second): t.Error("timeout waiting for side effect to kick off") } // Transition to recovering state clock.Advance(10*clock.Second + clock.Millisecond) cb.metrics = statsOK() _, _, err = testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, cbState(stateRecovering), cb.state) // Going back to standby clock.Advance(10*clock.Second + clock.Millisecond) _, _, err = testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, cbState(stateStandby), cb.state) select { case req := <-srv2Chan: assert.Equal(t, http.MethodPost, req.Method) assert.Equal(t, "/post", req.URL.Path) assert.Equal(t, url.Values{"key": []string{"val1", "val2"}}, req.Form) case <-clock.After(clock.Second): t.Error("timeout waiting for side effect to kick off") } } func statsOK() *memmetrics.RTMetrics { m, err := memmetrics.NewRTMetrics() if err != nil { panic(err) } return m } func statsNetErrors(threshold float64) *memmetrics.RTMetrics { m, err := memmetrics.NewRTMetrics() if err != nil { panic(err) } for i := 0; i < 100; i++ { if i < int(threshold*100) { m.Record(http.StatusGatewayTimeout, 0) } else { m.Record(http.StatusOK, 0) } } return m } func statsLatencyAtQuantile(_ float64, value time.Duration) *memmetrics.RTMetrics { m, err := memmetrics.NewRTMetrics() if err != nil { panic(err) } m.Record(http.StatusOK, value) return m } func statsResponseCodes(codes ...statusCode) *memmetrics.RTMetrics { m, err := memmetrics.NewRTMetrics() if err != nil { panic(err) } for _, c := range codes { for i := int64(0); i < c.Count; i++ { 
m.Record(c.Code, 0) } } return m } type statusCode struct { Code int Count int64 } oxy-2.0.0/cbreaker/effect.go000066400000000000000000000034141450475076400157070ustar00rootroot00000000000000package cbreaker import ( "bytes" "fmt" "io" "net/http" "net/url" "strings" "github.com/vulcand/oxy/v2/utils" ) // SideEffect a side effect. type SideEffect interface { Exec() error } // Webhook Web hook. type Webhook struct { URL string Method string Headers http.Header Form url.Values Body []byte } // WebhookSideEffect a web hook side effect. type WebhookSideEffect struct { w Webhook log utils.Logger } // NewWebhookSideEffectsWithLogger creates a new WebhookSideEffect. func NewWebhookSideEffectsWithLogger(w Webhook, l utils.Logger) (*WebhookSideEffect, error) { if w.Method == "" { return nil, fmt.Errorf("supply method") } _, err := url.Parse(w.URL) if err != nil { return nil, err } return &WebhookSideEffect{w: w, log: l}, nil } // NewWebhookSideEffect creates a new WebhookSideEffect. func NewWebhookSideEffect(w Webhook) (*WebhookSideEffect, error) { return NewWebhookSideEffectsWithLogger(w, &utils.NoopLogger{}) } func (w *WebhookSideEffect) getBody() io.Reader { if len(w.w.Form) != 0 { return strings.NewReader(w.w.Form.Encode()) } if len(w.w.Body) != 0 { return bytes.NewBuffer(w.w.Body) } return nil } // Exec execute the side effect. func (w *WebhookSideEffect) Exec() error { r, err := http.NewRequest(w.w.Method, w.w.URL, w.getBody()) if err != nil { return err } if len(w.w.Headers) != 0 { utils.CopyHeaders(r.Header, w.w.Headers) } if len(w.w.Form) != 0 { r.Header.Set("Content-Type", "application/x-www-form-urlencoded") } re, err := http.DefaultClient.Do(r) if err != nil { return err } if re.Body != nil { defer func() { _ = re.Body.Close() }() } body, err := io.ReadAll(re.Body) if err != nil { return err } w.log.Debug("%v got response: (%s): %s", w, re.Status, string(body)) return nil } oxy-2.0.0/cbreaker/fallback.go000066400000000000000000000050031450475076400162060ustar00rootroot00000000000000package cbreaker import ( "net/http" "net/url" "strconv" "github.com/vulcand/oxy/v2/utils" ) // Response response model. type Response struct { StatusCode int ContentType string Body []byte } // ResponseFallback fallback response handler. type ResponseFallback struct { r Response debug bool log utils.Logger } // NewResponseFallback creates a new ResponseFallback. func NewResponseFallback(r Response, options ...ResponseFallbackOption) (*ResponseFallback, error) { rf := &ResponseFallback{r: r, log: &utils.NoopLogger{}} for _, s := range options { if err := s(rf); err != nil { return nil, err } } return rf, nil } func (f *ResponseFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { if f.debug { dump := utils.DumpHTTPRequest(req) f.log.Debug("vulcand/oxy/fallback/response: begin ServeHttp on request: %s", dump) defer f.log.Debug("vulcand/oxy/fallback/response: completed ServeHttp on request: %s", dump) } if f.r.ContentType != "" { w.Header().Set("Content-Type", f.r.ContentType) } w.Header().Set("Content-Length", strconv.Itoa(len(f.r.Body))) w.WriteHeader(f.r.StatusCode) _, err := w.Write(f.r.Body) if err != nil { f.log.Error("vulcand/oxy/fallback/response: failed to write response, err: %v", err) } } // Redirect redirect model. type Redirect struct { URL string PreservePath bool } // RedirectFallback fallback redirect handler. type RedirectFallback struct { r Redirect u *url.URL debug bool log utils.Logger } // NewRedirectFallback creates a new RedirectFallback. 
func NewRedirectFallback(r Redirect, options ...RedirectFallbackOption) (*RedirectFallback, error) { u, err := url.ParseRequestURI(r.URL) if err != nil { return nil, err } rf := &RedirectFallback{r: r, u: u, log: &utils.NoopLogger{}} for _, s := range options { if err := s(rf); err != nil { return nil, err } } return rf, nil } func (f *RedirectFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { if f.debug { dump := utils.DumpHTTPRequest(req) f.log.Debug("vulcand/oxy/fallback/redirect: begin ServeHttp on request: %s", dump) defer f.log.Debug("vulcand/oxy/fallback/redirect: completed ServeHttp on request: %s", dump) } location := f.u.String() if f.r.PreservePath { location += req.URL.Path } w.Header().Set("Location", location) w.WriteHeader(http.StatusFound) _, err := w.Write([]byte(http.StatusText(http.StatusFound))) if err != nil { f.log.Error("vulcand/oxy/fallback/redirect: failed to write response, err: %v", err) } } oxy-2.0.0/cbreaker/options.go000066400000000000000000000057621450475076400161560ustar00rootroot00000000000000package cbreaker import ( "net/http" "time" "github.com/vulcand/oxy/v2/utils" ) // Option represents an option you can pass to New. type Option func(*CircuitBreaker) error // Logger defines the logger used by CircuitBreaker. func Logger(l utils.Logger) Option { return func(c *CircuitBreaker) error { c.log = l return nil } } // Verbose additional debug information. func Verbose(verbose bool) Option { return func(c *CircuitBreaker) error { c.verbose = verbose return nil } } // FallbackDuration is how long the CircuitBreaker will remain in the Tripped // state before trying to recover. func FallbackDuration(d time.Duration) Option { return func(c *CircuitBreaker) error { c.fallbackDuration = d return nil } } // RecoveryDuration is how long the CircuitBreaker will take to ramp up // requests during the Recovering state. func RecoveryDuration(d time.Duration) Option { return func(c *CircuitBreaker) error { c.recoveryDuration = d return nil } } // CheckPeriod is how long the CircuitBreaker will wait between successive // checks of the breaker condition. func CheckPeriod(d time.Duration) Option { return func(c *CircuitBreaker) error { c.checkPeriod = d return nil } } // OnTripped sets a SideEffect to run when entering the Tripped state. // Only one SideEffect can be set for this hook. func OnTripped(s SideEffect) Option { return func(c *CircuitBreaker) error { c.onTripped = s return nil } } // OnStandby sets a SideEffect to run when entering the Standby state. // Only one SideEffect can be set for this hook. func OnStandby(s SideEffect) Option { return func(c *CircuitBreaker) error { c.onStandby = s return nil } } // Fallback defines the http.Handler that the CircuitBreaker should route // requests to when it prevents a request from taking its normal path. func Fallback(h http.Handler) Option { return func(c *CircuitBreaker) error { c.fallback = h return nil } } // ResponseFallbackOption represents an option you can pass to NewResponseFallback. type ResponseFallbackOption func(*ResponseFallback) error // ResponseFallbackLogger defines the logger used by ResponseFallback. func ResponseFallbackLogger(l utils.Logger) ResponseFallbackOption { return func(c *ResponseFallback) error { c.log = l return nil } } // ResponseFallbackDebug additional debug information. 
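// For illustration, these options can be combined when constructing the breaker; backend stands for any http.Handler, the URLs, webhook body and durations are placeholder values rather than defaults, and error handling is elided:
//
//	redirect, _ := cbreaker.NewRedirectFallback(cbreaker.Redirect{
//		URL:          "http://localhost:6000", // placeholder fallback target
//		PreservePath: true,
//	})
//
//	hook, _ := cbreaker.NewWebhookSideEffect(cbreaker.Webhook{
//		URL:    "http://localhost:7000/alert", // placeholder webhook endpoint
//		Method: http.MethodPost,
//		Body:   []byte(`{"event": "tripped"}`),
//	})
//
//	cb, _ := cbreaker.New(backend, `NetworkErrorRatio() > 0.5`,
//		cbreaker.Fallback(redirect),
//		cbreaker.OnTripped(hook),
//		cbreaker.FallbackDuration(20*time.Second),
//	)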
func ResponseFallbackDebug(debug bool) ResponseFallbackOption { return func(c *ResponseFallback) error { c.debug = debug return nil } } // RedirectFallbackOption represents an option you can pass to NewRedirectFallback. type RedirectFallbackOption func(*RedirectFallback) error // RedirectFallbackLogger defines the logger used by ResponseFallback. func RedirectFallbackLogger(l utils.Logger) RedirectFallbackOption { return func(c *RedirectFallback) error { c.log = l return nil } } // RedirectFallbackDebug additional debug information. func RedirectFallbackDebug(debug bool) RedirectFallbackOption { return func(c *RedirectFallback) error { c.debug = debug return nil } } oxy-2.0.0/cbreaker/predicates.go000066400000000000000000000127461450475076400166060ustar00rootroot00000000000000package cbreaker import ( "fmt" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/predicate" ) type hpredicate func(*CircuitBreaker) bool // parseExpression parses expression in the go language into predicates. func parseExpression(in string) (hpredicate, error) { p, err := predicate.NewParser(predicate.Def{ Operators: predicate.Operators{ AND: and, OR: or, EQ: eq, NEQ: neq, LT: lt, LE: le, GT: gt, GE: ge, }, Functions: map[string]interface{}{ "LatencyAtQuantileMS": latencyAtQuantile, "NetworkErrorRatio": networkErrorRatio, "ResponseCodeRatio": responseCodeRatio, }, }) if err != nil { return nil, err } out, err := p.Parse(in) if err != nil { return nil, err } pr, ok := out.(hpredicate) if !ok { return nil, fmt.Errorf("expected predicate, got %T", out) } return pr, nil } type toInt func(c *CircuitBreaker) int type toFloat64 func(c *CircuitBreaker) float64 func latencyAtQuantile(quantile float64) toInt { return func(c *CircuitBreaker) int { h, err := c.metrics.LatencyHistogram() if err != nil { c.log.Error("Failed to get latency histogram, for %v error: %v", c, err) return 0 } return int(h.LatencyAtQuantile(quantile) / clock.Millisecond) } } func networkErrorRatio() toFloat64 { return func(c *CircuitBreaker) float64 { return c.metrics.NetworkErrorRatio() } } func responseCodeRatio(startA, endA, startB, endB int) toFloat64 { return func(c *CircuitBreaker) float64 { return c.metrics.ResponseCodeRatio(startA, endA, startB, endB) } } // or returns predicate by joining the passed predicates with logical 'or'. func or(fns ...hpredicate) hpredicate { return func(c *CircuitBreaker) bool { for _, fn := range fns { if fn(c) { return true } } return false } } // and returns predicate by joining the passed predicates with logical 'and'. func and(fns ...hpredicate) hpredicate { return func(c *CircuitBreaker) bool { for _, fn := range fns { if !fn(c) { return false } } return true } } // not creates negation of the passed predicate. func not(p hpredicate) hpredicate { return func(c *CircuitBreaker) bool { return !p(c) } } // eq returns predicate that tests for equality of the value of the mapper and the constant. func eq(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toInt: return intEQ(mapper, value) case toFloat64: return float64EQ(mapper, value) } return nil, fmt.Errorf("eq: unsupported argument: %T", m) } // neq returns predicate that tests for inequality of the value of the mapper and the constant. func neq(m interface{}, value interface{}) (hpredicate, error) { p, err := eq(m, value) if err != nil { return nil, err } return not(p), nil } // lt returns predicate that tests that value of the mapper function is less than the constant. 
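// For illustration, expressions accepted by parseExpression look like the ones exercised by the package tests:
//
//	NetworkErrorRatio() > 0.5
//	LatencyAtQuantileMS(50.0) > 50
//	ResponseCodeRatio(500, 600, 0, 600) > 0.5
//
// A combined form such as NetworkErrorRatio() > 0.5 && LatencyAtQuantileMS(50.0) > 100 is expected to parse as well, assuming the underlying predicate parser maps && and || onto the AND and OR operators registered above; treat that combination as an untested sketch rather than a guarantee.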
func lt(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toInt: return intLT(mapper, value) case toFloat64: return float64LT(mapper, value) } return nil, fmt.Errorf("lt: unsupported argument: %T", m) } // le returns predicate that tests that value of the mapper function is less than or equal to the constant. func le(m interface{}, value interface{}) (hpredicate, error) { l, err := lt(m, value) if err != nil { return nil, err } e, err := eq(m, value) if err != nil { return nil, err } return func(c *CircuitBreaker) bool { return l(c) || e(c) }, nil } // gt returns predicate that tests that value of the mapper function is greater than the constant. func gt(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toInt: return intGT(mapper, value) case toFloat64: return float64GT(mapper, value) } return nil, fmt.Errorf("gt: unsupported argument: %T", m) } // ge returns predicate that tests that value of the mapper function is greater than or equal to the constant. func ge(m interface{}, value interface{}) (hpredicate, error) { g, err := gt(m, value) if err != nil { return nil, err } e, err := eq(m, value) if err != nil { return nil, err } return func(c *CircuitBreaker) bool { return g(c) || e(c) }, nil } func intEQ(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *CircuitBreaker) bool { return m(c) == value }, nil } func float64EQ(m toFloat64, val interface{}) (hpredicate, error) { value, ok := val.(float64) if !ok { return nil, fmt.Errorf("expected float64, got %T", val) } return func(c *CircuitBreaker) bool { return m(c) == value }, nil } func intLT(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *CircuitBreaker) bool { return m(c) < value }, nil } func intGT(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *CircuitBreaker) bool { return m(c) > value }, nil } func float64LT(m toFloat64, val interface{}) (hpredicate, error) { value, ok := val.(float64) if !ok { return nil, fmt.Errorf("expected float64, got %T", val) } return func(c *CircuitBreaker) bool { return m(c) < value }, nil } func float64GT(m toFloat64, val interface{}) (hpredicate, error) { value, ok := val.(float64) if !ok { return nil, fmt.Errorf("expected float64, got %T", val) } return func(c *CircuitBreaker) bool { return m(c) > value }, nil } oxy-2.0.0/cbreaker/predicates_test.go000066400000000000000000000032251450475076400176350ustar00rootroot00000000000000package cbreaker import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/memmetrics" ) func TestTripped(t *testing.T) { testCases := []struct { expression string metrics *memmetrics.RTMetrics expected bool }{ { expression: "NetworkErrorRatio() > 0.5", metrics: statsNetErrors(0.6), expected: true, }, { expression: "NetworkErrorRatio() < 0.5", metrics: statsNetErrors(0.6), expected: false, }, { expression: "LatencyAtQuantileMS(50.0) > 50", metrics: statsLatencyAtQuantile(50, clock.Millisecond*51), expected: true, }, { expression: "LatencyAtQuantileMS(50.0) < 50", metrics: statsLatencyAtQuantile(50, clock.Millisecond*51), expected: false, }, { expression: "ResponseCodeRatio(500, 600, 0, 600) > 0.5", metrics: 
statsResponseCodes(statusCode{Code: 200, Count: 5}, statusCode{Code: 500, Count: 6}), expected: true, }, { expression: "ResponseCodeRatio(500, 600, 0, 600) > 0.5", metrics: statsResponseCodes(statusCode{Code: 200, Count: 5}, statusCode{Code: 500, Count: 4}), expected: false, }, { // quantile not defined expression: "LatencyAtQuantileMS(40.0) > 50", metrics: statsNetErrors(0.6), expected: false, }, } for _, test := range testCases { test := test t.Run(test.expression, func(t *testing.T) { t.Parallel() p, err := parseExpression(test.expression) require.NoError(t, err) require.NotNil(t, p) assert.Equal(t, test.expected, p(&CircuitBreaker{metrics: test.metrics})) }) } } oxy-2.0.0/cbreaker/ratio.go000066400000000000000000000035611450475076400155740ustar00rootroot00000000000000package cbreaker import ( "fmt" "time" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/utils" ) // ratioController allows passing portions traffic back to the endpoints, // increasing the amount of passed requests using linear function: // // allowedRequestsRatio = 0.5 * (Now() - Start())/Duration type ratioController struct { duration time.Duration start clock.Time allowed int denied int log utils.Logger } func newRatioController(rampUp time.Duration, log utils.Logger) *ratioController { return &ratioController{ duration: rampUp, start: clock.Now().UTC(), log: log, } } func (r *ratioController) String() string { return fmt.Sprintf("RatioController(target=%f, current=%f, allowed=%d, denied=%d)", r.targetRatio(), r.computeRatio(r.allowed, r.denied), r.allowed, r.denied) } func (r *ratioController) allowRequest() bool { r.log.Debug("%v", r) t := r.targetRatio() // This condition answers the question - would we satisfy the target ratio if we allow this request? 
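// For illustration: with the default 10-second recovery duration and 5 seconds elapsed, targetRatio() = 0.5 * 5s/10s = 0.25, so at the midpoint roughly one request in four is let through; the target only reaches 0.5 (allowed == denied) at the very end of the ramp-up window.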
e := r.computeRatio(r.allowed+1, r.denied) if e < t { r.allowed++ r.log.Debug("%v allowed", r) return true } r.denied++ r.log.Debug("%v denied", r) return false } func (r *ratioController) computeRatio(allowed, denied int) float64 { if denied+allowed == 0 { return 0 } return float64(allowed) / float64(denied+allowed) } func (r *ratioController) targetRatio() float64 { // Here's why it's 0.5: // We are watching the following ratio: // ratio = a / (a + d) // We can notice, that once we get to 0.5: // 0.5 = a / (a + d) // we can evaluate that a = d // that means equilibrium, where we would allow all the requests // after this point to achieve ratio of 1 (that can never be reached unless d is 0) // so we stop from there multiplier := 0.5 / float64(r.duration) return multiplier * float64(clock.Now().UTC().Sub(r.start)) } oxy-2.0.0/cbreaker/ratio_test.go000066400000000000000000000023451450475076400166320ustar00rootroot00000000000000package cbreaker import ( "math" "testing" "github.com/stretchr/testify/assert" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/testutils" "github.com/vulcand/oxy/v2/utils" ) func TestRampUp(t *testing.T) { done := testutils.FreezeTime() defer done() duration := 10 * clock.Second rc := newRatioController(duration, &utils.NoopLogger{}) allowed, denied := 0, 0 for i := 0; i < int(duration/clock.Millisecond); i++ { ratio := sendRequest(&allowed, &denied, rc) expected := rc.targetRatio() diff := math.Abs(expected - ratio) t.Log("Ratio", ratio) t.Log("Expected", expected) t.Log("Diff", diff) assert.EqualValues(t, 0, round(diff, 0.5, 1)) clock.Advance(clock.Millisecond) } } func sendRequest(allowed, denied *int, rc *ratioController) float64 { if rc.allowRequest() { *allowed++ } else { *denied++ } if *allowed+*denied == 0 { return 0 } return float64(*allowed) / float64(*allowed+*denied) } func round(val float64, roundOn float64, places int) float64 { pow := math.Pow(10, float64(places)) digit := pow * val _, div := math.Modf(digit) var round float64 if div >= roundOn { round = math.Ceil(digit) } else { round = math.Floor(digit) } return round / pow } oxy-2.0.0/connlimit/000077500000000000000000000000001450475076400143405ustar00rootroot00000000000000oxy-2.0.0/connlimit/connlimit.go000066400000000000000000000063671450475076400166770ustar00rootroot00000000000000// Package connlimit provides control over simultaneous connections coming from the same source package connlimit import ( "fmt" "net/http" "sync" "github.com/vulcand/oxy/v2/utils" ) // ConnLimiter tracks concurrent connection per token // and is capable of rejecting connections if they are failed. type ConnLimiter struct { mutex *sync.Mutex extract utils.SourceExtractor connections map[string]int64 maxConnections int64 totalConnections int64 next http.Handler errHandler utils.ErrorHandler verbose bool log utils.Logger } // New creates a new ConnLimiter. func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...Option) (*ConnLimiter, error) { if extract == nil { return nil, fmt.Errorf("extract function can not be nil") } cl := &ConnLimiter{ mutex: &sync.Mutex{}, extract: extract, maxConnections: maxConnections, connections: make(map[string]int64), next: next, log: &utils.NoopLogger{}, } for _, o := range options { if err := o(cl); err != nil { return nil, err } } if cl.errHandler == nil { cl.errHandler = &ConnErrHandler{ debug: cl.verbose, log: cl.log, } } return cl, nil } // Wrap sets the next handler to be called by connection limiter handler. 
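// For illustration, a caller-side sketch of wiring the limiter; backend stands for any http.Handler, and keying on the Authorization header plus the listen address are assumptions made up for the example:
//
//	extract := utils.ExtractorFunc(func(req *http.Request) (string, int64, error) {
//		// one bucket per Authorization header value; any stable per-client key works
//		return req.Header.Get("Authorization"), 1, nil
//	})
//
//	cl, err := connlimit.New(backend, extract, 10) // at most 10 concurrent requests per token
//	if err != nil {
//		panic(err)
//	}
//
//	_ = http.ListenAndServe(":8080", cl)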
func (cl *ConnLimiter) Wrap(h http.Handler) { cl.next = h } func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) { token, amount, err := cl.extract.Extract(r) if err != nil { cl.log.Error("failed to extract source of the connection: %v", err) cl.errHandler.ServeHTTP(w, r, err) return } if err := cl.acquire(token, amount); err != nil { cl.log.Debug("limiting request source %s: %v", token, err) cl.errHandler.ServeHTTP(w, r, err) return } defer cl.release(token, amount) cl.next.ServeHTTP(w, r) } func (cl *ConnLimiter) acquire(token string, amount int64) error { cl.mutex.Lock() defer cl.mutex.Unlock() connections := cl.connections[token] if connections >= cl.maxConnections { return &MaxConnError{max: cl.maxConnections} } cl.connections[token] += amount cl.totalConnections += amount return nil } func (cl *ConnLimiter) release(token string, amount int64) { cl.mutex.Lock() defer cl.mutex.Unlock() cl.connections[token] -= amount cl.totalConnections -= amount // Otherwise it would grow forever if cl.connections[token] == 0 { delete(cl.connections, token) } } // MaxConnError maximum connections reached error. type MaxConnError struct { max int64 } func (m *MaxConnError) Error() string { return fmt.Sprintf("max connections reached: %d", m.max) } // ConnErrHandler connection limiter error handler. type ConnErrHandler struct { debug bool log utils.Logger } func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { if e.debug { dump := utils.DumpHTTPRequest(req) e.log.Debug("vulcand/oxy/connlimit: begin ServeHttp on request: %s", dump) defer e.log.Debug("vulcand/oxy/connlimit: completed ServeHttp on request: %s", dump) } //nolint:errorlint // must be changed if _, ok := err.(*MaxConnError); ok { w.WriteHeader(http.StatusTooManyRequests) _, _ = w.Write([]byte(err.Error())) return } utils.DefaultHandler.ServeHTTP(w, req, err) } oxy-2.0.0/connlimit/connlimit_test.go000066400000000000000000000060361450475076400177270ustar00rootroot00000000000000package connlimit import ( "fmt" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/testutils" "github.com/vulcand/oxy/v2/utils" ) // We've hit the limit and were able to proceed once the request has completed. 
func TestHitLimitAndRelease(t *testing.T) { wait := make(chan bool) proceed := make(chan bool) finish := make(chan bool) handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { t.Logf("%v", req.Header) if req.Header.Get("Wait") != "" { proceed <- true <-wait } _, _ = w.Write([]byte("hello")) }) cl, err := New(handler, headerLimit, 1) require.NoError(t, err) srv := httptest.NewServer(cl) t.Cleanup(srv.Close) go func() { re, _, errGet := testutils.Get(srv.URL, testutils.Header("Limit", "a"), testutils.Header("wait", "yes")) require.NoError(t, errGet) assert.Equal(t, http.StatusOK, re.StatusCode) finish <- true }() <-proceed re, _, err := testutils.Get(srv.URL, testutils.Header("Limit", "a")) require.NoError(t, err) assert.Equal(t, http.StatusTooManyRequests, re.StatusCode) // request from another source succeeds re, _, err = testutils.Get(srv.URL, testutils.Header("Limit", "b")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) // Once the first request finished, next one succeeds close(wait) <-finish re, _, err = testutils.Get(srv.URL, testutils.Header("Limit", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // We've hit the limit and were able to proceed once the request has completed. func TestCustomHandlers(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, req *http.Request, err error) { w.WriteHeader(http.StatusTeapot) _, _ = w.Write([]byte(http.StatusText(http.StatusTeapot))) }) l, err := New(handler, headerLimit, 0, ErrorHandler(errHandler)) require.NoError(t, err) srv := httptest.NewServer(l) t.Cleanup(srv.Close) re, _, err := testutils.Get(srv.URL, testutils.Header("Limit", "a")) require.NoError(t, err) assert.Equal(t, http.StatusTeapot, re.StatusCode) } // We've hit the limit and were able to proceed once the request has completed. func TestFaultyExtract(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) l, err := New(handler, faultyExtract, 1) require.NoError(t, err) srv := httptest.NewServer(l) t.Cleanup(srv.Close) re, _, err := testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, re.StatusCode) } func headerLimiter(req *http.Request) (string, int64, error) { return req.Header.Get("Limit"), 1, nil } func faultyExtractor(_ *http.Request) (string, int64, error) { return "", -1, fmt.Errorf("oops") } var headerLimit = utils.ExtractorFunc(headerLimiter) var faultyExtract = utils.ExtractorFunc(faultyExtractor) oxy-2.0.0/connlimit/options.go000066400000000000000000000012071450475076400163620ustar00rootroot00000000000000package connlimit import ( "github.com/vulcand/oxy/v2/utils" ) // Option represents an option you can pass to New. type Option func(l *ConnLimiter) error // Logger defines the logger used by ConnLimiter. func Logger(l utils.Logger) Option { return func(cl *ConnLimiter) error { cl.log = l return nil } } // Verbose additional debug information. func Verbose(verbose bool) Option { return func(cl *ConnLimiter) error { cl.verbose = verbose return nil } } // ErrorHandler sets error handler of the server. 
func ErrorHandler(h utils.ErrorHandler) Option { return func(cl *ConnLimiter) error { cl.errHandler = h return nil } } oxy-2.0.0/forward/000077500000000000000000000000001450475076400140105ustar00rootroot00000000000000oxy-2.0.0/forward/example_test.go000066400000000000000000000056771450475076400170500ustar00rootroot00000000000000package forward import ( "crypto/tls" "fmt" "io" "net/http" "net/http/httptest" "net/url" ) func ExampleNew_customErrHandler() { f := New(true) f.ErrorHandler = func(w http.ResponseWriter, req *http.Request, err error) { w.WriteHeader(http.StatusTeapot) _, _ = w.Write([]byte(http.StatusText(http.StatusTeapot))) } proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL, _ = url.ParseRequestURI("http://localhost:63450") f.ServeHTTP(w, req) })) defer proxy.Close() resp, err := http.Get(proxy.URL) if err != nil { fmt.Println(err) return } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) return } fmt.Println(resp.StatusCode) fmt.Println(string(body)) // output: // 418 // I'm a teapot } func ExampleNew_responseModifier() { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) })) defer srv.Close() f := New(true) f.ModifyResponse = func(resp *http.Response) error { resp.Header.Add("X-Test", "CUSTOM") return nil } proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL, _ = url.ParseRequestURI(srv.URL) f.ServeHTTP(w, req) })) defer proxy.Close() resp, err := http.Get(proxy.URL) if err != nil { fmt.Println(err) return } fmt.Println(resp.StatusCode) fmt.Println(resp.Header.Get("X-Test")) // Output: // 200 // CUSTOM } func ExampleNew_customTransport() { srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) })) defer srv.Close() f := New(true) f.Transport = &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL, _ = url.ParseRequestURI(srv.URL) f.ServeHTTP(w, req) })) defer proxy.Close() resp, err := http.Get(proxy.URL) if err != nil { fmt.Println(err) return } body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println(err) return } fmt.Println(resp.StatusCode) fmt.Println(string(body)) // Output: // 200 // hello } func ExampleNewStateListener() { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) })) defer srv.Close() f := New(true) f.ModifyResponse = func(resp *http.Response) error { resp.Header.Add("X-Test", "CUSTOM") return nil } stateLn := NewStateListener(f, func(u *url.URL, i int) { fmt.Println(u.Hostname(), i) }) proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL, _ = url.ParseRequestURI(srv.URL) stateLn.ServeHTTP(w, req) })) defer proxy.Close() resp, err := http.Get(proxy.URL) if err != nil { fmt.Println(err) return } fmt.Println(resp.StatusCode) // Output: // 127.0.0.1 0 // 127.0.0.1 1 // 200 } oxy-2.0.0/forward/fwd.go000066400000000000000000000024721450475076400151240ustar00rootroot00000000000000// Package forward creates a pre-configured httputil.ReverseProxy. package forward import ( "net/http" "net/http/httputil" "net/url" "github.com/vulcand/oxy/v2/utils" ) // New creates a new ReverseProxy. 
func New(passHostHeader bool) *httputil.ReverseProxy { h := NewHeaderRewriter() return &httputil.ReverseProxy{ Director: func(request *http.Request) { modifyRequest(request) h.Rewrite(request) if !passHostHeader { request.Host = request.URL.Host } }, ErrorHandler: utils.DefaultHandler.ServeHTTP, } } // Modify the request to handle the target URL. func modifyRequest(outReq *http.Request) { u := getURLFromRequest(outReq) outReq.URL.Path = u.Path outReq.URL.RawPath = u.RawPath outReq.URL.RawQuery = u.RawQuery outReq.RequestURI = "" // Outgoing request should not have RequestURI outReq.Proto = "HTTP/1.1" outReq.ProtoMajor = 1 outReq.ProtoMinor = 1 } func getURLFromRequest(req *http.Request) *url.URL { // If the Request was created by Go via a real HTTP request, // RequestURI will contain the original query string. // If the Request was created in code, // RequestURI will be empty, and we will use the URL object instead u := req.URL if req.RequestURI != "" { parsedURL, err := url.ParseRequestURI(req.RequestURI) if err == nil { return parsedURL } } return u } oxy-2.0.0/forward/fwd_test.go000066400000000000000000000045471450475076400161700ustar00rootroot00000000000000package forward import ( "net/http" "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/testutils" ) func TestDefaultErrHandler(t *testing.T) { f := New(true) proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI("http://localhost:63450") f.ServeHTTP(w, req) })) t.Cleanup(proxy.Close) resp, err := http.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusBadGateway, resp.StatusCode) } func TestXForwardedHostHeader(t *testing.T) { tests := []struct { Description string PassHostHeader bool TargetURL string ProxyfiedURL string ExpectedXForwardedHost string }{ { Description: "XForwardedHost without PassHostHeader", PassHostHeader: false, TargetURL: "http://xforwardedhost.com", ProxyfiedURL: "http://backend.com", ExpectedXForwardedHost: "xforwardedhost.com", }, { Description: "XForwardedHost with PassHostHeader", PassHostHeader: true, TargetURL: "http://xforwardedhost.com", ProxyfiedURL: "http://backend.com", ExpectedXForwardedHost: "xforwardedhost.com", }, } for _, test := range tests { test := test t.Run(test.Description, func(t *testing.T) { t.Parallel() f := New(true) r, err := http.NewRequest(http.MethodGet, test.TargetURL, nil) require.NoError(t, err) backendURL, err := url.Parse(test.ProxyfiedURL) require.NoError(t, err) r.URL = backendURL f.Director(r) require.Equal(t, test.ExpectedXForwardedHost, r.Header.Get(XForwardedHost)) }) } } func TestForwardedProto(t *testing.T) { var proto string srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { proto = req.Header.Get(XForwardedProto) _, _ = w.Write([]byte("hello")) })) t.Cleanup(srv.Close) f := New(true) proxy := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) })) proxy.StartTLS() t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "https", proto) } oxy-2.0.0/forward/fwd_websocket_test.go000066400000000000000000000357031450475076400202340ustar00rootroot00000000000000package forward import ( "bufio" "crypto/tls" "errors" "fmt" "net" "net/http" "net/http/httptest" "testing" 
gorillawebsocket "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/testutils" "golang.org/x/net/websocket" ) func TestWebSocketTCPClose(t *testing.T) { f := New(true) errChan := make(chan error, 1) upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := upgrader.Upgrade(w, r, nil) if err != nil { return } defer func(c *gorillawebsocket.Conn) { _ = c.Close() }(c) for { _, _, err := c.ReadMessage() if err != nil { errChan <- err break } } })) t.Cleanup(srv.Close) proxy := createProxyWithForwarder(f, srv.URL) proxyAddr := proxy.Listener.Addr().String() _, conn, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), ).open() require.NoError(t, err) _ = conn.Close() serverErr := <-errChan wsErr := &gorillawebsocket.CloseError{} assert.ErrorAs(t, serverErr, &wsErr) assert.Equal(t, 1006, wsErr.Code) } func TestWebSocketPingPong(t *testing.T) { f := New(true) upgrader := gorillawebsocket.Upgrader{ HandshakeTimeout: 10 * clock.Second, CheckOrigin: func(*http.Request) bool { return true }, } mux := http.NewServeMux() mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) { ws, err := upgrader.Upgrade(writer, request, nil) require.NoError(t, err) ws.SetPingHandler(func(appData string) error { _ = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong")) return nil }) _, _, _ = ws.ReadMessage() }) srv := httptest.NewServer(mux) t.Cleanup(srv.Close) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) t.Cleanup(proxy.Close) serverAddr := proxy.Listener.Addr().String() headers := http.Header{} webSocketURL := "ws://" + serverAddr + "/ws" headers.Add("Origin", webSocketURL) conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) require.NoError(t, err, "Error during Dial with response: %+v", resp) defer conn.Close() goodErr := fmt.Errorf("signal: %s", "Good data") badErr := fmt.Errorf("signal: %s", "Bad data") conn.SetPongHandler(func(data string) error { if data == "PingPong" { return goodErr } return badErr }) _ = conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), clock.Now().Add(clock.Second)) _, _, err = conn.ReadMessage() if !errors.Is(err, goodErr) { require.NoError(t, err) } } func TestWebSocketEcho(t *testing.T) { f := New(true) mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { msg := make([]byte, 4) _, _ = conn.Read(msg) t.Log(string(msg)) _, _ = conn.Write(msg) _ = conn.Close() })) srv := httptest.NewServer(mux) t.Cleanup(srv.Close) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) t.Cleanup(proxy.Close) serverAddr := proxy.Listener.Addr().String() headers := http.Header{} webSocketURL := "ws://" + serverAddr + "/ws" headers.Add("Origin", webSocketURL) conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) require.NoError(t, err, "Error during Dial with response: %+v", resp) _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK")) t.Log(conn.ReadMessage()) _ = conn.Close() } func TestWebSocketPassHost(t *testing.T) { testCases := []struct { desc string passHost bool expected string }{ { desc: "PassHost false", passHost: false, }, { desc: "PassHost 
true", passHost: true, expected: "example.com", }, } for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { f := New(test.passHost) mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { req := conn.Request() if test.passHost { require.Equal(t, test.expected, req.Host) } else { require.NotEqual(t, test.expected, req.Host) } msg := make([]byte, 4) _, _ = conn.Read(msg) t.Log(string(msg)) _, _ = conn.Write(msg) _ = conn.Close() })) srv := httptest.NewServer(mux) t.Cleanup(srv.Close) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) t.Cleanup(proxy.Close) serverAddr := proxy.Listener.Addr().String() headers := http.Header{} webSocketURL := "ws://" + serverAddr + "/ws" headers.Add("Origin", webSocketURL) headers.Add("Host", "example.com") conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) require.NoError(t, err, "Error during Dial with response: %+v", resp) _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK")) t.Log(conn.ReadMessage()) _ = conn.Close() }) } } func TestWebSocketServerWithoutCheckOrigin(t *testing.T) { f := New(true) upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := upgrader.Upgrade(w, r, nil) if err != nil { return } defer c.Close() for { mt, message, err := c.ReadMessage() if err != nil { break } err = c.WriteMessage(mt, message) if err != nil { break } } })) t.Cleanup(srv.Close) proxy := createProxyWithForwarder(f, srv.URL) t.Cleanup(proxy.Close) proxyAddr := proxy.Listener.Addr().String() resp, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("ok"), withOrigin("http://127.0.0.2"), ).send() require.NoError(t, err) assert.Equal(t, "ok", resp) } func TestWebSocketRequestWithOrigin(t *testing.T) { f := New(true) upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := upgrader.Upgrade(w, r, nil) if err != nil { return } defer c.Close() for { mt, message, err := c.ReadMessage() if err != nil { break } err = c.WriteMessage(mt, message) if err != nil { break } } })) t.Cleanup(srv.Close) proxy := createProxyWithForwarder(f, srv.URL) t.Cleanup(proxy.Close) proxyAddr := proxy.Listener.Addr().String() _, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("echo"), withOrigin("http://127.0.0.2"), ).send() require.EqualError(t, err, "bad status") resp, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("ok"), ).send() require.NoError(t, err) assert.Equal(t, "ok", resp) } func TestWebSocketRequestWithQueryParams(t *testing.T) { f := New(true) upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return } defer conn.Close() assert.Equal(t, "test", r.URL.Query().Get("query")) for { mt, message, err := conn.ReadMessage() if err != nil { break } err = conn.WriteMessage(mt, message) if err != nil { break } } })) t.Cleanup(srv.Close) proxy := createProxyWithForwarder(f, srv.URL) t.Cleanup(proxy.Close) proxyAddr := proxy.Listener.Addr().String() resp, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws?query=test"), withData("ok"), ).send() require.NoError(t, err) 
assert.Equal(t, "ok", resp) } func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) { f := New(true) mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { _ = conn.Close() })) srv := httptest.NewServer(mux) t.Cleanup(srv.Close) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) w.Header().Set("HEADER-KEY", "HEADER-VALUE") f.ServeHTTP(w, req) }) t.Cleanup(proxy.Close) serverAddr := proxy.Listener.Addr().String() headers := http.Header{} webSocketURL := "ws://" + serverAddr + "/ws" headers.Add("Origin", webSocketURL) conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) require.NoError(t, err, "Error during Dial with response: %+v", err, resp) t.Cleanup(func() { _ = conn.Close() }) assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY")) } func TestWebSocketRequestWithEncodedChar(t *testing.T) { f := New(true) upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return } defer conn.Close() assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath()) for { mt, message, err := conn.ReadMessage() if err != nil { break } err = conn.WriteMessage(mt, message) if err != nil { break } } })) t.Cleanup(srv.Close) proxy := createProxyWithForwarder(f, srv.URL) t.Cleanup(proxy.Close) proxyAddr := proxy.Listener.Addr().String() resp, err := newWebsocketRequest( withServer(proxyAddr), withPath("/%3A%2F%2F"), withData("ok"), ).send() require.NoError(t, err) assert.Equal(t, "ok", resp) } func TestWebSocketUpgradeFailed(t *testing.T) { f := New(true) mux := http.NewServeMux() mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusBadRequest) }) srv := httptest.NewServer(mux) t.Cleanup(srv.Close) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { path := req.URL.Path // keep the original path if path == "/ws" { // Set new backend URL req.URL = testutils.ParseURI(srv.URL) req.URL.Path = path f.ServeHTTP(w, req) } }) t.Cleanup(proxy.Close) proxyAddr := proxy.Listener.Addr().String() conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout) require.NoError(t, err) defer conn.Close() req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil) require.NoError(t, err) req.Header.Add("upgrade", "websocket") req.Header.Add("Connection", "upgrade") _ = req.Write(conn) // First request works with 400 br := bufio.NewReader(conn) resp, err := http.ReadResponse(br, req) require.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) } func TestForwardsWebsocketTraffic(t *testing.T) { f := New(true) mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { _, _ = conn.Write([]byte("ok")) _ = conn.Close() })) srv := httptest.NewServer(mux) t.Cleanup(srv.Close) proxy := createProxyWithForwarder(f, srv.URL) t.Cleanup(proxy.Close) proxyAddr := proxy.Listener.Addr().String() resp, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("echo"), ).send() require.NoError(t, err) assert.Equal(t, "ok", resp) } func createTLSWebsocketServer() *httptest.Server { upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return } defer conn.Close() for { mt, message, err := conn.ReadMessage() if err != nil { 
break } err = conn.WriteMessage(mt, message) if err != nil { break } } })) return srv } func createProxyWithForwarder(forwarder http.Handler, uri string) *httptest.Server { return testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { path := req.URL.Path // keep the original path // Set new backend URL req.URL = testutils.ParseURI(uri) req.URL.Path = path forwarder.ServeHTTP(w, req) }) } func TestWebSocketTransferTLSConfig(t *testing.T) { srv := createTLSWebsocketServer() t.Cleanup(srv.Close) forwarderWithoutTLSConfig := New(true) proxyWithoutTLSConfig := createProxyWithForwarder(forwarderWithoutTLSConfig, srv.URL) defer proxyWithoutTLSConfig.Close() proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String() _, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("ok"), ).send() require.EqualError(t, err, "bad status") transport := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } forwarderWithTLSConfig := New(true) forwarderWithTLSConfig.Transport = transport proxyWithTLSConfig := createProxyWithForwarder(forwarderWithTLSConfig, srv.URL) defer proxyWithTLSConfig.Close() proxyAddr = proxyWithTLSConfig.Listener.Addr().String() resp, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("ok"), ).send() require.NoError(t, err) assert.Equal(t, "ok", resp) http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} forwarderWithTLSConfigFromDefaultTransport := New(true) proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(forwarderWithTLSConfigFromDefaultTransport, srv.URL) defer proxyWithTLSConfig.Close() proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String() resp, err = newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("ok"), ).send() require.NoError(t, err) assert.Equal(t, "ok", resp) } const dialTimeout = clock.Second type websocketRequestOpt func(w *websocketRequest) func withServer(server string) websocketRequestOpt { return func(w *websocketRequest) { w.ServerAddr = server } } func withPath(path string) websocketRequestOpt { return func(w *websocketRequest) { w.Path = path } } func withData(data string) websocketRequestOpt { return func(w *websocketRequest) { w.Data = data } } func withOrigin(origin string) websocketRequestOpt { return func(w *websocketRequest) { w.Origin = origin } } func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest { wsrequest := &websocketRequest{} for _, opt := range opts { opt(wsrequest) } if wsrequest.Origin == "" { wsrequest.Origin = "http://" + wsrequest.ServerAddr } if wsrequest.Config == nil { wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin) } return wsrequest } type websocketRequest struct { ServerAddr string Path string Data string Origin string Config *websocket.Config } func (w *websocketRequest) send() (string, error) { conn, _, err := w.open() if err != nil { return "", err } defer conn.Close() if _, err := conn.Write([]byte(w.Data)); err != nil { return "", err } msg := make([]byte, 512) var n int n, err = conn.Read(msg) if err != nil { return "", err } received := string(msg[:n]) return received, nil } func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) { client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout) if err != nil { return nil, nil, err } conn, err := websocket.NewClient(w.Config, client) if err != nil { return nil, nil, err } return conn, 
client, err } oxy-2.0.0/forward/headers.go000066400000000000000000000033511450475076400157540ustar00rootroot00000000000000package forward // X-* Header names. const ( XForwardedProto = "X-Forwarded-Proto" XForwardedFor = "X-Forwarded-For" XForwardedHost = "X-Forwarded-Host" XForwardedPort = "X-Forwarded-Port" XForwardedServer = "X-Forwarded-Server" XRealIP = "X-Real-Ip" ) // Headers names. const ( Connection = "Connection" KeepAlive = "Keep-Alive" ProxyAuthenticate = "Proxy-Authenticate" ProxyAuthorization = "Proxy-Authorization" Te = "Te" // canonicalized version of "TE" Trailers = "Trailers" TransferEncoding = "Transfer-Encoding" Upgrade = "Upgrade" ContentLength = "Content-Length" ) // WebSocket Header names. const ( SecWebsocketKey = "Sec-Websocket-Key" SecWebsocketVersion = "Sec-Websocket-Version" SecWebsocketExtensions = "Sec-Websocket-Extensions" SecWebsocketAccept = "Sec-Websocket-Accept" ) // XHeaders X-* headers. var XHeaders = []string{ XForwardedProto, XForwardedFor, XForwardedHost, XForwardedPort, XForwardedServer, XRealIP, } // HopHeaders Hop-by-hop headers. These are removed when sent to the backend. // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html // Copied from reverseproxy.go, too bad. var HopHeaders = []string{ Connection, KeepAlive, ProxyAuthenticate, ProxyAuthorization, Te, // canonicalized version of "TE" Trailers, TransferEncoding, Upgrade, } // WebsocketDialHeaders Websocket dial headers. var WebsocketDialHeaders = []string{ Upgrade, Connection, SecWebsocketKey, SecWebsocketVersion, SecWebsocketExtensions, SecWebsocketAccept, } // WebsocketUpgradeHeaders Websocket upgrade headers. var WebsocketUpgradeHeaders = []string{ Upgrade, Connection, SecWebsocketAccept, SecWebsocketExtensions, } oxy-2.0.0/forward/middlewares.go000066400000000000000000000014601450475076400166400ustar00rootroot00000000000000package forward import ( "net/http" "net/url" ) // Connection states. const ( StateConnected = iota StateDisconnected ) // URLForwardingStateListener URL forwarding state listener. type URLForwardingStateListener func(*url.URL, int) // StateListener listens on state change for urls. type StateListener struct { next http.Handler stateListener URLForwardingStateListener } // NewStateListener creates a new StateListener middleware. func NewStateListener(next http.Handler, stateListener URLForwardingStateListener) *StateListener { return &StateListener{next: next, stateListener: stateListener} } func (s *StateListener) ServeHTTP(rw http.ResponseWriter, req *http.Request) { s.stateListener(req.URL, StateConnected) s.next.ServeHTTP(rw, req) s.stateListener(req.URL, StateDisconnected) } oxy-2.0.0/forward/rewrite.go000066400000000000000000000036741450475076400160320ustar00rootroot00000000000000package forward import ( "net" "net/http" "os" "strings" "github.com/vulcand/oxy/v2/utils" ) // NewHeaderRewriter creates a new HeaderRewriter middleware. func NewHeaderRewriter() *HeaderRewriter { h, err := os.Hostname() if err != nil { h = "localhost" } return &HeaderRewriter{TrustForwardHeader: true, Hostname: h} } // HeaderRewriter is responsible for removing hop-by-hop headers and setting forwarding headers. type HeaderRewriter struct { TrustForwardHeader bool Hostname string } // clean up IP in case if it is ipv6 address and it has {zone} information in it, like "[fe80::d806:a55d:eb1b:49cc%vEthernet (vmxnet3 Ethernet Adapter - Virtual Switch)]:64692". func ipv6fix(clientIP string) string { return strings.Split(clientIP, "%")[0] } // Rewrite request headers. 
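// A minimal illustrative sketch (the request values below are made up): how a
// HeaderRewriter is typically applied to an inbound request before it is
// proxied to a backend.
func headerRewriteSketch() {
	req, _ := http.NewRequest(http.MethodGet, "http://backend.internal/app", nil)
	req.RemoteAddr = "10.13.14.15:4567"
	req.Host = "public.example.com"

	rewriter := NewHeaderRewriter()
	rewriter.Rewrite(req)

	// After Rewrite the proxy-facing headers are populated, for example:
	//   X-Real-Ip:          10.13.14.15
	//   X-Forwarded-Proto:  http
	//   X-Forwarded-Port:   80
	//   X-Forwarded-Host:   public.example.com
	//   X-Forwarded-Server: <hostname of the machine running the proxy>
}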
func (rw *HeaderRewriter) Rewrite(req *http.Request) { if !rw.TrustForwardHeader { utils.RemoveHeaders(req.Header, XHeaders...) } if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { clientIP = ipv6fix(clientIP) if req.Header.Get(XRealIP) == "" { req.Header.Set(XRealIP, clientIP) } } xfProto := req.Header.Get(XForwardedProto) if xfProto == "" { if req.TLS != nil { req.Header.Set(XForwardedProto, "https") } else { req.Header.Set(XForwardedProto, "http") } } if xfPort := req.Header.Get(XForwardedPort); xfPort == "" { req.Header.Set(XForwardedPort, forwardedPort(req)) } if xfHost := req.Header.Get(XForwardedHost); xfHost == "" && req.Host != "" { req.Header.Set(XForwardedHost, req.Host) } if rw.Hostname != "" { req.Header.Set(XForwardedServer, rw.Hostname) } } func forwardedPort(req *http.Request) string { if req == nil { return "" } if _, port, err := net.SplitHostPort(req.Host); err == nil && port != "" { return port } if req.Header.Get(XForwardedProto) == "https" || req.Header.Get(XForwardedProto) == "wss" { return "443" } if req.TLS != nil { return "443" } return "80" } oxy-2.0.0/forward/rewrite_test.go000066400000000000000000000020721450475076400170600ustar00rootroot00000000000000package forward import ( "testing" "github.com/stretchr/testify/assert" ) func TestIPv6Fix(t *testing.T) { testCases := []struct { desc string clientIP string expected string }{ { desc: "empty", clientIP: "", expected: "", }, { desc: "ipv4 localhost", clientIP: "127.0.0.1", expected: "127.0.0.1", }, { desc: "ipv4", clientIP: "10.13.14.15", expected: "10.13.14.15", }, { desc: "ipv6 zone", clientIP: `fe80::d806:a55d:eb1b:49cc%vEthernet (vmxnet3 Ethernet Adapter - Virtual Switch)`, expected: "fe80::d806:a55d:eb1b:49cc", }, { desc: "ipv6 medium", clientIP: `fe80::1`, expected: "fe80::1", }, { desc: "ipv6 small", clientIP: `2000::`, expected: "2000::", }, { desc: "ipv6", clientIP: `2001:3452:4952:2837::`, expected: "2001:3452:4952:2837::", }, } for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() actual := ipv6fix(test.clientIP) assert.Equal(t, test.expected, actual) }) } } oxy-2.0.0/go.mod000066400000000000000000000013661450475076400134600ustar00rootroot00000000000000module github.com/vulcand/oxy/v2 go 1.19 require ( github.com/HdrHistogram/hdrhistogram-go v1.1.2 github.com/gorilla/websocket v1.5.0 github.com/mailgun/multibuf v0.1.2 github.com/segmentio/fasthash v1.0.3 github.com/stretchr/testify v1.8.4 github.com/vulcand/predicate v1.2.0 golang.org/x/net v0.15.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/gravitational/trace v1.1.16-0.20220114165159-14a9a7dd6aaf // indirect github.com/jonboulle/clockwork v0.2.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect golang.org/x/crypto v0.13.0 // indirect golang.org/x/sys v0.12.0 // indirect golang.org/x/term v0.12.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) oxy-2.0.0/go.sum000066400000000000000000000317751450475076400135140ustar00rootroot00000000000000cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/HdrHistogram/hdrhistogram-go v1.1.2 
h1:5IcZpTvzydCQeHzK4Ef/D5rrSqwxob0t8PQPMybUNFM= github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.5.4 h1:L8R9j+yAqZuZjsqh/z+F1NCffTKKLShY6zXTItVIZ8M= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gravitational/trace v1.1.16-0.20220114165159-14a9a7dd6aaf h1:C1GPyPJrOlJlIrcaBBiBpDsqZena2Ks8spa5xZqr1XQ= github.com/gravitational/trace v1.1.16-0.20220114165159-14a9a7dd6aaf/go.mod h1:zXqxTI6jXDdKnlf8s+nT+3c8LrwUEy3yNpO4XJL90lA= github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mailgun/multibuf v0.1.2 h1:QE9kE27lK6LFZB4aYNVtUPlWVHVCT0zpgUr2uoq/+jk= github.com/mailgun/multibuf v0.1.2/go.mod h1:E+sUhIy69qgT6EM57kCPdUTlHnjTuxQBO/yf6af9Hes= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtrmhM= github.com/segmentio/fasthash v1.0.3/go.mod h1:waKX8l2N8yckOgmSsXJi7x1ZfdKZ4x7KRMzBtS3oedY= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/vulcand/predicate v1.2.0 h1:uFsW1gcnnR7R+QTID+FVcs0sSYlIGntoGOTb3rQJt50= github.com/vulcand/predicate v1.2.0/go.mod h1:VipoNYXny6c8N381zGUWkjuuNHiRbeAZhE7Qm9c+2GA= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136 h1:A1gGSx58LAGVHUUsOf7IiR0u8Xb6W51gRwfDBhkdcaw= golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= 
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU= golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= gonum.org/v1/gonum v0.8.2 h1:CCXrcPKiGGotvnN6jfUsKk4rRqm7q09/YbKb5xCEvtM= gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= oxy-2.0.0/internal/000077500000000000000000000000001450475076400141605ustar00rootroot00000000000000oxy-2.0.0/internal/holsterv4/000077500000000000000000000000001450475076400161125ustar00rootroot00000000000000oxy-2.0.0/internal/holsterv4/LICENSE000066400000000000000000000261351450475076400171260ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. oxy-2.0.0/internal/holsterv4/README.md000066400000000000000000000016171450475076400173760ustar00rootroot00000000000000# What is this? This is a vendored copy of 2 packages (`clock` and `collections`) from the github.com/mailgun/holster@v4.2.5 module. The `clock` package was completely copied over and the following modifications were made: * pkg/errors was replaced with the stdlib errors package / fmt.Errorf's %w; * import names changed in blackbox test packages; * a small race condition in the testing logic was fixed using the provided mutex. The `collections` package only contains the priority_queue and ttlmap and corresponding test files. The only changes made to those files were to adjust the package names to use the vendored packages. ## Why TL;DR: holster is a utility repo with many dependencies and even with graph pruning using it in oxy can transitively impact oxy users in negative ways by forcing version bumps (at the least). 
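As a minimal sketch of the pkg/errors replacement mentioned above (the function and message here are only examples), error wrapping in the vendored files now uses the standard library:

```go
package example

import "fmt"

// wrapParseError mirrors the style used after the swap: fmt.Errorf with %w
// instead of github.com/pkg/errors.Wrap.
func wrapParseError(err error) error {
	return fmt.Errorf("while parsing duration: %w", err)
}
```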
Full details can be found here: https://github.com/vulcand/oxy/pull/223 oxy-2.0.0/internal/holsterv4/clock/000077500000000000000000000000001450475076400172055ustar00rootroot00000000000000oxy-2.0.0/internal/holsterv4/clock/README.md000066400000000000000000000031071450475076400204650ustar00rootroot00000000000000# Clock A drop-in (almost) replacement for the system `time` package. It provides a way to make scheduled calls, timers and tickers deterministic in tests. By default it forwards all calls to the system `time` package. In tests, however, it is possible to enable the frozen clock mode and advance time manually to make scheduled events trigger at certain moments. # Usage ```go package foo import ( "testing" "github.com/vulcand/oxy/internal/holsterv4/clock" "github.com/stretchr/testify/assert" ) func TestSleep(t *testing.T) { // Freeze switches the clock package to the frozen clock mode. You need to // advance time manually from now on. Note that all scheduled events, timers // and tickers created before this call keep operating in real time. // // The initial time is set to now here, but you can set any datetime. clock.Freeze(clock.Now()) // Do not forget to revert the effect of Freeze at the end of the test. defer clock.Unfreeze() var fired bool clock.AfterFunc(100*clock.Millisecond, func() { fired = true }) clock.Advance(93*clock.Millisecond) // Advance fires all events, timers and tickers that are scheduled for the // passed period of time. Note that scheduled functions are called from // within Advance, unlike the system time package that calls them in their // own goroutines. assert.Equal(t, 99*clock.Millisecond, clock.Advance(6*clock.Millisecond)) assert.False(t, fired) assert.Equal(t, 100*clock.Millisecond, clock.Advance(1*clock.Millisecond)) assert.True(t, fired) } ``` oxy-2.0.0/internal/holsterv4/clock/clock.go000066400000000000000000000061051450475076400206310ustar00rootroot00000000000000//go:build !holster_test_mode // Package clock provides the same functions as the system package time. In // production it forwards all calls to the system time package, but in tests // the time can be frozen by calling Freeze function and from that point it has // to be advanced manually with Advance function making all scheduled calls // deterministic. // // The functions provided by the package have the same parameters and return // values as their system counterparts with a few exceptions. Where either // *time.Timer or *time.Ticker is returned by a system function, the clock // package counterpart returns clock.Timer or clock.Ticker interface // respectively. The interfaces provide API as respective structs except C is // not a channel, but a function that returns <-chan time.Time. package clock import "time" var ( frozenAt time.Time realtime = &systemTime{} provider Clock = realtime ) // Freeze after this function is called all time related functions start // generate deterministic timers that are triggered by Advance function. It is // supposed to be used in tests only. Returns an Unfreezer so it can be a // one-liner in tests: defer clock.Freeze(clock.Now()).Unfreeze() func Freeze(now time.Time) Unfreezer { frozenAt = now.UTC() provider = &frozenTime{now: now} return Unfreezer{} } type Unfreezer struct{} func (u Unfreezer) Unfreeze() { Unfreeze() } // Unfreeze reverses effect of Freeze. func Unfreeze() { provider = realtime } // Realtime returns a clock provider wrapping the SDK's time package. It is // supposed to be used in tests when time is frozen to schedule test timeouts.
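// A minimal illustrative sketch (the helper name and the 5-second timeout are
// arbitrary): while the package-level clock is frozen in a test, Realtime
// still provides real wall-clock timers, e.g. to bound how long a test waits.
func realtimeTimeoutSketch(done chan struct{}) bool {
	select {
	case <-done:
		// The code under test signalled completion.
		return true
	case <-Realtime().After(5 * time.Second):
		// Real-time timeout, unaffected by Freeze/Advance.
		return false
	}
}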
func Realtime() Clock { return realtime } // Makes the deterministic time move forward by the specified duration, firing // timers along the way in the natural order. It returns how much time has // passed since it was frozen. So you can assert on the return value in tests // to make it explicit where you stand on the deterministic time scale. func Advance(d time.Duration) time.Duration { ft, ok := provider.(*frozenTime) if !ok { panic("Freeze time first!") } ft.advance(d) return Now().UTC().Sub(frozenAt) } // Wait4Scheduled blocks until either there are n or more scheduled events, or // the timeout elapses. It returns true if the wait condition has been met // before the timeout expired, false otherwise. func Wait4Scheduled(count int, timeout time.Duration) bool { return provider.Wait4Scheduled(count, timeout) } // Now see time.Now. func Now() time.Time { return provider.Now() } // Sleep see time.Sleep. func Sleep(d time.Duration) { provider.Sleep(d) } // After see time.After. func After(d time.Duration) <-chan time.Time { return provider.After(d) } // NewTimer see time.NewTimer. func NewTimer(d time.Duration) Timer { return provider.NewTimer(d) } // AfterFunc see time.AfterFunc. func AfterFunc(d time.Duration, f func()) Timer { return provider.AfterFunc(d, f) } // NewTicker see time.Ticker. func NewTicker(d time.Duration) Ticker { return provider.NewTicker(d) } // Tick see time.Tick. func Tick(d time.Duration) <-chan time.Time { return provider.Tick(d) } oxy-2.0.0/internal/holsterv4/clock/clock_mutex.go000066400000000000000000000070661450475076400220620ustar00rootroot00000000000000//go:build holster_test_mode // Package clock provides the same functions as the system package time. In // production it forwards all calls to the system time package, but in tests // the time can be frozen by calling Freeze function and from that point it has // to be advanced manually with Advance function making all scheduled calls // deterministic. // // The functions provided by the package have the same parameters and return // values as their system counterparts with a few exceptions. Where either // *time.Timer or *time.Ticker is returned by a system function, the clock // package counterpart returns clock.Timer or clock.Ticker interface // respectively. The interfaces provide API as respective structs except C is // not a channel, but a function that returns <-chan time.Time. package clock import ( "sync" "time" ) var ( frozenAt time.Time realtime = &systemTime{} provider Clock = realtime rwMutex = sync.RWMutex{} ) // Freeze after this function is called all time related functions start // generate deterministic timers that are triggered by Advance function. It is // supposed to be used in tests only. Returns an Unfreezer so it can be a // one-liner in tests: defer clock.Freeze(clock.Now()).Unfreeze() func Freeze(now time.Time) Unfreezer { frozenAt = now.UTC() rwMutex.Lock() defer rwMutex.Unlock() provider = &frozenTime{now: now} return Unfreezer{} } type Unfreezer struct{} func (u Unfreezer) Unfreeze() { Unfreeze() } // Unfreeze reverses effect of Freeze. func Unfreeze() { rwMutex.Lock() defer rwMutex.Unlock() provider = realtime } // Realtime returns a clock provider wrapping the SDK's time package. It is // supposed to be used in tests when time is frozen to schedule test timeouts. func Realtime() Clock { return realtime } // Makes the deterministic time move forward by the specified duration, firing // timers along the way in the natural order. 
It returns how much time has // passed since it was frozen. So you can assert on the return value in tests // to make it explicit where you stand on the deterministic time scale. func Advance(d time.Duration) time.Duration { rwMutex.RLock() ft, ok := provider.(*frozenTime) rwMutex.RUnlock() if !ok { panic("Freeze time first!") } ft.advance(d) return Now().UTC().Sub(frozenAt) } // Wait4Scheduled blocks until either there are n or more scheduled events, or // the timeout elapses. It returns true if the wait condition has been met // before the timeout expired, false otherwise. func Wait4Scheduled(count int, timeout time.Duration) bool { rwMutex.RLock() defer rwMutex.RUnlock() return provider.Wait4Scheduled(count, timeout) } // Now see time.Now. func Now() time.Time { rwMutex.RLock() defer rwMutex.RUnlock() return provider.Now() } // Sleep see time.Sleep. func Sleep(d time.Duration) { rwMutex.RLock() defer rwMutex.RUnlock() provider.Sleep(d) } // After see time.After. func After(d time.Duration) <-chan time.Time { rwMutex.RLock() defer rwMutex.RUnlock() return provider.After(d) } // NewTimer see time.NewTimer. func NewTimer(d time.Duration) Timer { rwMutex.RLock() defer rwMutex.RUnlock() return provider.NewTimer(d) } // AfterFunc see time.AfterFunc. func AfterFunc(d time.Duration, f func()) Timer { rwMutex.RLock() defer rwMutex.RUnlock() return provider.AfterFunc(d, f) } // NewTicker see time.Ticker. func NewTicker(d time.Duration) Ticker { rwMutex.RLock() defer rwMutex.RUnlock() return provider.NewTicker(d) } // Tick see time.Tick. func Tick(d time.Duration) <-chan time.Time { rwMutex.RLock() defer rwMutex.RUnlock() return provider.Tick(d) } oxy-2.0.0/internal/holsterv4/clock/duration.go000066400000000000000000000025521450475076400213650ustar00rootroot00000000000000package clock import ( "encoding/json" "fmt" ) type DurationJSON struct { Duration Duration } func NewDurationJSON(v interface{}) (DurationJSON, error) { switch v := v.(type) { case Duration: return DurationJSON{Duration: v}, nil case float64: return DurationJSON{Duration: Duration(v)}, nil case int64: return DurationJSON{Duration: Duration(v)}, nil case int: return DurationJSON{Duration: Duration(v)}, nil case []byte: duration, err := ParseDuration(string(v)) if err != nil { return DurationJSON{}, fmt.Errorf("while parsing []byte: %w", err) } return DurationJSON{Duration: duration}, nil case string: duration, err := ParseDuration(v) if err != nil { return DurationJSON{}, fmt.Errorf("while parsing string: %w", err) } return DurationJSON{Duration: duration}, nil default: return DurationJSON{}, fmt.Errorf("bad type %T", v) } } func NewDurationJSONOrPanic(v interface{}) DurationJSON { d, err := NewDurationJSON(v) if err != nil { panic(err) } return d } func (d DurationJSON) MarshalJSON() ([]byte, error) { return json.Marshal(d.Duration.String()) } func (d *DurationJSON) UnmarshalJSON(b []byte) error { var v interface{} var err error if err = json.Unmarshal(b, &v); err != nil { return err } *d, err = NewDurationJSON(v) return err } func (d DurationJSON) String() string { return d.Duration.String() } oxy-2.0.0/internal/holsterv4/clock/duration_test.go000066400000000000000000000031531450475076400224220ustar00rootroot00000000000000package clock_test import ( "encoding/json" "testing" "github.com/stretchr/testify/suite" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" ) type DurationSuite struct { suite.Suite } func TestDurationSuite(t *testing.T) { suite.Run(t, new(DurationSuite)) } func (s *DurationSuite) TestNewOk() { for _, v := 
range []interface{}{ 42 * clock.Second, int(42000000000), int64(42000000000), 42000000000., "42s", []byte("42s"), } { d, err := clock.NewDurationJSON(v) s.Nil(err) s.Equal(42*clock.Second, d.Duration) } } func (s *DurationSuite) TestNewError() { for _, tc := range []struct { v interface{} errMsg string }{{ v: "foo", errMsg: "while parsing string: time: invalid duration \"foo\"", }, { v: []byte("foo"), errMsg: "while parsing []byte: time: invalid duration \"foo\"", }, { v: true, errMsg: "bad type bool", }} { d, err := clock.NewDurationJSON(tc.v) s.Equal(tc.errMsg, err.Error()) s.Equal(clock.DurationJSON{}, d) } } func (s *DurationSuite) TestUnmarshal() { for _, v := range []string{ `{"foo": 42000000000}`, `{"foo": 0.42e11}`, `{"foo": "42s"}`, } { var withDuration struct { Foo clock.DurationJSON `json:"foo"` } err := json.Unmarshal([]byte(v), &withDuration) s.Nil(err) s.Equal(42*clock.Second, withDuration.Foo.Duration) } } func (s *DurationSuite) TestMarshalling() { d, err := clock.NewDurationJSON(42 * clock.Second) s.Nil(err) encoded, err := d.MarshalJSON() s.Nil(err) var decoded clock.DurationJSON err = decoded.UnmarshalJSON(encoded) s.Nil(err) s.Equal(d, decoded) s.Equal("42s", decoded.String()) } oxy-2.0.0/internal/holsterv4/clock/frozen.go000066400000000000000000000105041450475076400210370ustar00rootroot00000000000000package clock import ( "errors" "sync" "time" ) type frozenTime struct { mu sync.Mutex now time.Time timers []*frozenTimer waiter *waiter } type waiter struct { count int signalCh chan struct{} } func (ft *frozenTime) Now() time.Time { ft.mu.Lock() defer ft.mu.Unlock() return ft.now } func (ft *frozenTime) Sleep(d time.Duration) { <-ft.NewTimer(d).C() } func (ft *frozenTime) After(d time.Duration) <-chan time.Time { return ft.NewTimer(d).C() } func (ft *frozenTime) NewTimer(d time.Duration) Timer { return ft.AfterFunc(d, nil) } func (ft *frozenTime) AfterFunc(d time.Duration, f func()) Timer { t := &frozenTimer{ ft: ft, when: ft.Now().Add(d), f: f, } if f == nil { t.c = make(chan time.Time, 1) } ft.startTimer(t) return t } func (ft *frozenTime) advance(d time.Duration) { ft.mu.Lock() defer ft.mu.Unlock() ft.now = ft.now.Add(d) for t := ft.nextExpired(); t != nil; t = ft.nextExpired() { // Send the timer expiration time to the timer channel if it is // defined. But make sure not to block on the send if the channel is // full. This behavior will make a ticker skip beats if it readers are // not fast enough. if t.c != nil { select { case t.c <- t.when: default: } } // If it is a ticking timer then schedule next tick, otherwise mark it // as stopped. if t.interval != 0 { t.when = t.when.Add(t.interval) t.stopped = false ft.unlockedStartTimer(t) } // If a function is associated with the timer then call it, but make // sure to release the lock for the time of call it is necessary // because the lock is not re-entrant but the function may need to // start another timer or ticker. 
if t.f != nil { func() { ft.mu.Unlock() defer ft.mu.Lock() t.f() }() } } } func (ft *frozenTime) stopTimer(t *frozenTimer) bool { ft.mu.Lock() defer ft.mu.Unlock() if t.stopped { return false } for i, curr := range ft.timers { if curr == t { t.stopped = true copy(ft.timers[i:], ft.timers[i+1:]) lastIdx := len(ft.timers) - 1 ft.timers[lastIdx] = nil ft.timers = ft.timers[:lastIdx] return true } } return false } func (ft *frozenTime) nextExpired() *frozenTimer { if len(ft.timers) == 0 { return nil } t := ft.timers[0] if ft.now.Before(t.when) { return nil } copy(ft.timers, ft.timers[1:]) lastIdx := len(ft.timers) - 1 ft.timers[lastIdx] = nil ft.timers = ft.timers[:lastIdx] t.stopped = true return t } func (ft *frozenTime) startTimer(t *frozenTimer) { ft.mu.Lock() defer ft.mu.Unlock() ft.unlockedStartTimer(t) if ft.waiter == nil { return } if len(ft.timers) >= ft.waiter.count { close(ft.waiter.signalCh) } } func (ft *frozenTime) unlockedStartTimer(t *frozenTimer) { pos := 0 for _, curr := range ft.timers { if t.when.Before(curr.when) { break } pos++ } ft.timers = append(ft.timers, nil) copy(ft.timers[pos+1:], ft.timers[pos:]) ft.timers[pos] = t } type frozenTimer struct { ft *frozenTime when time.Time interval time.Duration stopped bool c chan time.Time f func() } func (t *frozenTimer) C() <-chan time.Time { return t.c } func (t *frozenTimer) Stop() bool { return t.ft.stopTimer(t) } func (t *frozenTimer) Reset(d time.Duration) bool { active := t.ft.stopTimer(t) t.when = t.ft.Now().Add(d) t.ft.startTimer(t) return active } type frozenTicker struct { t *frozenTimer } func (t *frozenTicker) C() <-chan time.Time { return t.t.C() } func (t *frozenTicker) Stop() { t.t.Stop() } func (ft *frozenTime) NewTicker(d time.Duration) Ticker { if d <= 0 { panic(errors.New("non-positive interval for NewTicker")) } t := &frozenTimer{ ft: ft, when: ft.Now().Add(d), interval: d, c: make(chan time.Time, 1), } ft.startTimer(t) return &frozenTicker{t} } func (ft *frozenTime) Tick(d time.Duration) <-chan time.Time { if d <= 0 { return nil } return ft.NewTicker(d).C() } func (ft *frozenTime) Wait4Scheduled(count int, timeout time.Duration) bool { ft.mu.Lock() if len(ft.timers) >= count { ft.mu.Unlock() return true } if ft.waiter != nil { panic("Concurrent call") } ft.waiter = &waiter{count, make(chan struct{})} ft.mu.Unlock() success := false select { case <-ft.waiter.signalCh: success = true case <-time.After(timeout): } ft.mu.Lock() ft.waiter = nil ft.mu.Unlock() return success } oxy-2.0.0/internal/holsterv4/clock/frozen_test.go000066400000000000000000000164431450475076400221060ustar00rootroot00000000000000package clock import ( "fmt" "testing" "time" "github.com/stretchr/testify/suite" ) func TestFreezeUnfreeze(t *testing.T) { defer Freeze(Now()).Unfreeze() } type FrozenSuite struct { suite.Suite epoch time.Time } func TestFrozenSuite(t *testing.T) { suite.Run(t, new(FrozenSuite)) } func (s *FrozenSuite) SetupSuite() { var err error s.epoch, err = time.Parse(time.RFC3339, "2009-02-19T00:00:00Z") s.Require().NoError(err) } func (s *FrozenSuite) SetupTest() { Freeze(s.epoch) } func (s *FrozenSuite) TearDownTest() { Unfreeze() } func (s *FrozenSuite) TestAdvanceNow() { s.Require().Equal(s.epoch, Now()) s.Require().Equal(42*time.Millisecond, Advance(42*time.Millisecond)) s.Require().Equal(s.epoch.Add(42*time.Millisecond), Now()) s.Require().Equal(55*time.Millisecond, Advance(13*time.Millisecond)) s.Require().Equal(74*time.Millisecond, Advance(19*time.Millisecond)) s.Require().Equal(s.epoch.Add(74*time.Millisecond), 
Now()) } func (s *FrozenSuite) TestSleep() { hits := make(chan int, 100) delays := []int{60, 100, 90, 131, 999, 5} for i, tc := range []struct { desc string fn func(delayMs int) }{{ desc: "Sleep", fn: func(delay int) { Sleep(time.Duration(delay) * time.Millisecond) hits <- delay }, }, { desc: "After", fn: func(delay int) { <-After(time.Duration(delay) * time.Millisecond) hits <- delay }, }, { desc: "AfterFunc", fn: func(delay int) { AfterFunc(time.Duration(delay)*time.Millisecond, func() { hits <- delay }) }, }, { desc: "NewTimer", fn: func(delay int) { t := NewTimer(time.Duration(delay) * time.Millisecond) <-t.C() hits <- delay }, }} { fmt.Printf("Test case #%d: %s", i, tc.desc) for _, delay := range delays { go tc.fn(delay) } // Spin-wait for all goroutines to fall asleep. ft := provider.(*frozenTime) for { var brk bool ft.mu.Lock() if len(ft.timers) == len(delays) { brk = true } ft.mu.Unlock() if brk { break } time.Sleep(10 * time.Millisecond) } runningMs := 0 for i, delayMs := range []int{5, 60, 90, 100, 131, 999} { fmt.Printf("Checking timer #%d, delay=%d\n", i, delayMs) delta := delayMs - runningMs - 1 Advance(time.Duration(delta) * time.Millisecond) // Check before each timer deadline that it is not triggered yet. s.assertHits(hits, []int{}) // When Advance(1 * time.Millisecond) // Then s.assertHits(hits, []int{delayMs}) runningMs += delta + 1 } Advance(1000 * time.Millisecond) s.assertHits(hits, []int{}) } } // Timers scheduled to trigger at the same time do that in the order they were // created. func (s *FrozenSuite) TestSameTime() { var hits []int AfterFunc(100, func() { hits = append(hits, 3) }) AfterFunc(100, func() { hits = append(hits, 1) }) AfterFunc(99, func() { hits = append(hits, 2) }) AfterFunc(100, func() { hits = append(hits, 5) }) AfterFunc(101, func() { hits = append(hits, 4) }) AfterFunc(101, func() { hits = append(hits, 6) }) // When Advance(100) // Then s.Require().Equal([]int{2, 3, 1, 5}, hits) } func (s *FrozenSuite) TestTimerStop() { hits := []int{} AfterFunc(100, func() { hits = append(hits, 1) }) t := AfterFunc(100, func() { hits = append(hits, 2) }) AfterFunc(100, func() { hits = append(hits, 3) }) Advance(99) s.Require().Equal([]int{}, hits) // When active1 := t.Stop() active2 := t.Stop() // Then s.Require().Equal(true, active1) s.Require().Equal(false, active2) Advance(1) s.Require().Equal([]int{1, 3}, hits) } func (s *FrozenSuite) TestReset() { hits := []int{} t1 := AfterFunc(100, func() { hits = append(hits, 1) }) t2 := AfterFunc(100, func() { hits = append(hits, 2) }) AfterFunc(100, func() { hits = append(hits, 3) }) Advance(99) s.Require().Equal([]int{}, hits) // When active1 := t1.Reset(1) // Reset to the same time active2 := t2.Reset(7) // Then s.Require().Equal(true, active1) s.Require().Equal(true, active2) Advance(1) s.Require().Equal([]int{3, 1}, hits) Advance(5) s.Require().Equal([]int{3, 1}, hits) Advance(1) s.Require().Equal([]int{3, 1, 2}, hits) } // Reset to the same time just puts the timer at the end of the trigger list // for the date. 
func (s *FrozenSuite) TestResetSame() { hits := []int{} t := AfterFunc(100, func() { hits = append(hits, 1) }) AfterFunc(100, func() { hits = append(hits, 2) }) AfterFunc(100, func() { hits = append(hits, 3) }) AfterFunc(101, func() { hits = append(hits, 4) }) Advance(9) // When active := t.Reset(91) // Then s.Require().Equal(true, active) Advance(90) s.Require().Equal([]int{}, hits) Advance(1) s.Require().Equal([]int{2, 3, 1}, hits) } func (s *FrozenSuite) TestTicker() { t := NewTicker(100) Advance(99) s.assertNotFired(t.C()) Advance(1) s.Require().Equal(<-t.C(), s.epoch.Add(100)) Advance(750) s.Require().Equal(<-t.C(), s.epoch.Add(200)) Advance(49) s.assertNotFired(t.C()) Advance(1) s.Require().Equal(<-t.C(), s.epoch.Add(900)) t.Stop() Advance(300) s.assertNotFired(t.C()) } func (s *FrozenSuite) TestTickerZero() { defer func() { recover() }() NewTicker(0) s.Fail("Should panic") } func (s *FrozenSuite) TestTick() { ch := Tick(100) Advance(99) s.assertNotFired(ch) Advance(1) s.Require().Equal(<-ch, s.epoch.Add(100)) Advance(750) s.Require().Equal(<-ch, s.epoch.Add(200)) Advance(49) s.assertNotFired(ch) Advance(1) s.Require().Equal(<-ch, s.epoch.Add(900)) } func (s *FrozenSuite) TestTickZero() { ch := Tick(0) s.Require().Nil(ch) } func (s *FrozenSuite) TestNewStoppedTimer() { t := NewStoppedTimer() // When/Then select { case <-t.C(): s.Fail("Timer should not have fired") default: } s.Require().Equal(false, t.Stop()) } func (s *FrozenSuite) TestWait4Scheduled() { After(100 * Millisecond) After(100 * Millisecond) s.Require().Equal(false, Wait4Scheduled(3, 0)) startedCh := make(chan struct{}) resultCh := make(chan bool) go func() { close(startedCh) resultCh <- Wait4Scheduled(3, 5*Second) }() // Allow some time for waiter to be set and start waiting for a signal. <-startedCh time.Sleep(50 * Millisecond) // When After(100 * Millisecond) // Then s.Require().Equal(true, <-resultCh) } // If there is enough timers scheduled already, then a shortcut execution path // is taken and Wait4Scheduled returns immediately. func (s *FrozenSuite) TestWait4ScheduledImmediate() { After(100 * Millisecond) After(100 * Millisecond) // When/Then s.Require().Equal(true, Wait4Scheduled(2, 0)) } func (s *FrozenSuite) TestSince() { s.Require().Equal(Duration(0), Since(Now())) s.Require().Equal(-Millisecond, Since(Now().Add(Millisecond))) s.Require().Equal(Millisecond, Since(Now().Add(-Millisecond))) } func (s *FrozenSuite) TestUntil() { s.Require().Equal(Duration(0), Until(Now())) s.Require().Equal(Millisecond, Until(Now().Add(Millisecond))) s.Require().Equal(-Millisecond, Until(Now().Add(-Millisecond))) } func (s *FrozenSuite) assertHits(got <-chan int, want []int) { for i, w := range want { var g int select { case g = <-got: case <-time.After(100 * time.Millisecond): s.Failf("Missing hit", "want=%v", w) return } s.Require().Equal(w, g, "Hit #%d", i) } for { select { case g := <-got: s.Failf("Unexpected hit", "got=%v", g) default: return } } } func (s *FrozenSuite) assertNotFired(ch <-chan time.Time) { select { case <-ch: s.Fail("Premature fire") default: } } oxy-2.0.0/internal/holsterv4/clock/go19.go000066400000000000000000000044031450475076400203140ustar00rootroot00000000000000// +build go1.9 // This file introduces aliases to allow using of the clock package as a // drop-in replacement of the standard time package. 
package clock import "time" type ( Time = time.Time Duration = time.Duration Location = time.Location Weekday = time.Weekday Month = time.Month ParseError = time.ParseError ) const ( Nanosecond = time.Nanosecond Microsecond = time.Microsecond Millisecond = time.Millisecond Second = time.Second Minute = time.Minute Hour = time.Hour Sunday = time.Sunday Monday = time.Monday Tuesday = time.Tuesday Wednesday = time.Wednesday Thursday = time.Thursday Friday = time.Friday Saturday = time.Saturday January = time.January February = time.February March = time.March April = time.April May = time.May June = time.June July = time.July August = time.August September = time.September October = time.October November = time.November December = time.December ANSIC = time.ANSIC UnixDate = time.UnixDate RubyDate = time.RubyDate RFC822 = time.RFC822 RFC822Z = time.RFC822Z RFC850 = time.RFC850 RFC1123 = time.RFC1123 RFC1123Z = time.RFC1123Z RFC3339 = time.RFC3339 RFC3339Nano = time.RFC3339Nano Kitchen = time.Kitchen Stamp = time.Stamp StampMilli = time.StampMilli StampMicro = time.StampMicro StampNano = time.StampNano ) var ( UTC = time.UTC Local = time.Local ) func Date(year int, month Month, day, hour, min, sec, nsec int, loc *Location) Time { return time.Date(year, month, day, hour, min, sec, nsec, loc) } func FixedZone(name string, offset int) *Location { return time.FixedZone(name, offset) } func LoadLocation(name string) (*Location, error) { return time.LoadLocation(name) } func Parse(layout, value string) (Time, error) { return time.Parse(layout, value) } func ParseDuration(s string) (Duration, error) { return time.ParseDuration(s) } func ParseInLocation(layout, value string, loc *Location) (Time, error) { return time.ParseInLocation(layout, value, loc) } func Unix(sec int64, nsec int64) Time { return time.Unix(sec, nsec) } func Since(t Time) Duration { return provider.Now().Sub(t) } func Until(t Time) Duration { return t.Sub(provider.Now()) } oxy-2.0.0/internal/holsterv4/clock/interface.go000066400000000000000000000014001450475076400214670ustar00rootroot00000000000000package clock import "time" // Timer see time.Timer. type Timer interface { C() <-chan time.Time Stop() bool Reset(d time.Duration) bool } // Ticker see time.Ticker. type Ticker interface { C() <-chan time.Time Stop() } // NewStoppedTimer returns a stopped timer. Call Reset to get it ticking. func NewStoppedTimer() Timer { t := NewTimer(42 * time.Hour) t.Stop() return t } // Clock is an interface that mimics the one of the SDK time package. type Clock interface { Now() time.Time Sleep(d time.Duration) After(d time.Duration) <-chan time.Time NewTimer(d time.Duration) Timer AfterFunc(d time.Duration, f func()) Timer NewTicker(d time.Duration) Ticker Tick(d time.Duration) <-chan time.Time Wait4Scheduled(n int, timeout time.Duration) bool } oxy-2.0.0/internal/holsterv4/clock/rfc822.go000066400000000000000000000062251450475076400205470ustar00rootroot00000000000000package clock import ( "strconv" "time" ) var datetimeLayouts = [48]string{ // Day first month 2nd abbreviated. "Mon, 2 Jan 2006 15:04:05 MST", "Mon, 2 Jan 2006 15:04:05 -0700", "Mon, 2 Jan 2006 15:04:05 -0700 (MST)", "2 Jan 2006 15:04:05 MST", "2 Jan 2006 15:04:05 -0700", "2 Jan 2006 15:04:05 -0700 (MST)", "Mon, 2 Jan 2006 15:04 MST", "Mon, 2 Jan 2006 15:04 -0700", "Mon, 2 Jan 2006 15:04 -0700 (MST)", "2 Jan 2006 15:04 MST", "2 Jan 2006 15:04 -0700", "2 Jan 2006 15:04 -0700 (MST)", // Month first day 2nd abbreviated. 
"Mon, Jan 2 2006 15:04:05 MST", "Mon, Jan 2 2006 15:04:05 -0700", "Mon, Jan 2 2006 15:04:05 -0700 (MST)", "Jan 2 2006 15:04:05 MST", "Jan 2 2006 15:04:05 -0700", "Jan 2 2006 15:04:05 -0700 (MST)", "Mon, Jan 2 2006 15:04 MST", "Mon, Jan 2 2006 15:04 -0700", "Mon, Jan 2 2006 15:04 -0700 (MST)", "Jan 2 2006 15:04 MST", "Jan 2 2006 15:04 -0700", "Jan 2 2006 15:04 -0700 (MST)", // Day first month 2nd not abbreviated. "Mon, 2 January 2006 15:04:05 MST", "Mon, 2 January 2006 15:04:05 -0700", "Mon, 2 January 2006 15:04:05 -0700 (MST)", "2 January 2006 15:04:05 MST", "2 January 2006 15:04:05 -0700", "2 January 2006 15:04:05 -0700 (MST)", "Mon, 2 January 2006 15:04 MST", "Mon, 2 January 2006 15:04 -0700", "Mon, 2 January 2006 15:04 -0700 (MST)", "2 January 2006 15:04 MST", "2 January 2006 15:04 -0700", "2 January 2006 15:04 -0700 (MST)", // Month first day 2nd not abbreviated. "Mon, January 2 2006 15:04:05 MST", "Mon, January 2 2006 15:04:05 -0700", "Mon, January 2 2006 15:04:05 -0700 (MST)", "January 2 2006 15:04:05 MST", "January 2 2006 15:04:05 -0700", "January 2 2006 15:04:05 -0700 (MST)", "Mon, January 2 2006 15:04 MST", "Mon, January 2 2006 15:04 -0700", "Mon, January 2 2006 15:04 -0700 (MST)", "January 2 2006 15:04 MST", "January 2 2006 15:04 -0700", "January 2 2006 15:04 -0700 (MST)", } // Allows seamless JSON encoding/decoding of rfc822 formatted timestamps. // https://www.ietf.org/rfc/rfc822.txt section 5. type RFC822Time struct { Time } // NewRFC822Time creates RFC822Time from a standard Time. The created value is // truncated down to second precision because RFC822 does not allow for better. func NewRFC822Time(t Time) RFC822Time { return RFC822Time{Time: t.Truncate(Second)} } // ParseRFC822Time parses an RFC822 time string. func ParseRFC822Time(s string) (Time, error) { var t time.Time var err error for _, layout := range datetimeLayouts { t, err = Parse(layout, s) if err == nil { return t, err } } return t, err } // NewRFC822Time creates RFC822Time from a Unix timestamp (seconds from Epoch). func NewRFC822TimeFromUnix(timestamp int64) RFC822Time { return RFC822Time{Time: Unix(timestamp, 0).UTC()} } func (t RFC822Time) MarshalJSON() ([]byte, error) { return []byte(strconv.Quote(t.Format(RFC1123))), nil } func (t *RFC822Time) UnmarshalJSON(s []byte) error { q, err := strconv.Unquote(string(s)) if err != nil { return err } parsed, err := ParseRFC822Time(q) if err != nil { return err } t.Time = parsed return nil } func (t RFC822Time) String() string { return t.Format(RFC1123) } func (t RFC822Time) StringWithOffset() string { return t.Format(RFC1123Z) } oxy-2.0.0/internal/holsterv4/clock/rfc822_test.go000066400000000000000000000151231450475076400216030ustar00rootroot00000000000000package clock import ( "encoding/json" "fmt" "testing" "time" "github.com/stretchr/testify/assert" ) type testStruct struct { Time RFC822Time `json:"ts"` } func TestRFC822New(t *testing.T) { stdTime, err := Parse(RFC3339, "2019-08-29T11:20:07.123456+03:00") assert.NoError(t, err) rfc822TimeFromTime := NewRFC822Time(stdTime) rfc822TimeFromUnix := NewRFC822TimeFromUnix(stdTime.Unix()) assert.True(t, rfc822TimeFromTime.Equal(rfc822TimeFromUnix.Time), "want=%s, got=%s", rfc822TimeFromTime.Time, rfc822TimeFromUnix.Time) // Parsing from numerical offset to abbreviated offset is not always reliable. In this // context Go will fallback to the known numerical offset. 
assert.Equal(t, "Thu, 29 Aug 2019 11:20:07 +0300", rfc822TimeFromTime.String()) assert.Equal(t, "Thu, 29 Aug 2019 08:20:07 UTC", rfc822TimeFromUnix.String()) } // NewRFC822Time truncates to second precision. func TestRFC822SecondPrecision(t *testing.T) { stdTime1, err := Parse(RFC3339, "2019-08-29T11:20:07.111111+03:00") assert.NoError(t, err) stdTime2, err := Parse(RFC3339, "2019-08-29T11:20:07.999999+03:00") assert.NoError(t, err) assert.False(t, stdTime1.Equal(stdTime2)) rfc822Time1 := NewRFC822Time(stdTime1) rfc822Time2 := NewRFC822Time(stdTime2) assert.True(t, rfc822Time1.Equal(rfc822Time2.Time), "want=%s, got=%s", rfc822Time1.Time, rfc822Time2.Time) } // Marshaled representation is truncated down to second precision. func TestRFC822Marshaling(t *testing.T) { stdTime, err := Parse(RFC3339Nano, "2019-08-29T11:20:07.123456789+03:30") assert.NoError(t, err) ts := testStruct{Time: NewRFC822Time(stdTime)} encoded, err := json.Marshal(&ts) assert.NoError(t, err) assert.Equal(t, `{"ts":"Thu, 29 Aug 2019 11:20:07 +0330"}`, string(encoded)) } func TestRFC822Unmarshaling(t *testing.T) { for i, tc := range []struct { inRFC822 string outRFC3339 string outRFC822 string }{{ inRFC822: "Thu, 29 Aug 2019 11:20:07 GMT", outRFC3339: "2019-08-29T11:20:07Z", outRFC822: "Thu, 29 Aug 2019 11:20:07 GMT", }, { inRFC822: "Thu, 29 Aug 2019 11:20:07 MSK", // Extrapolating the numerical offset from an abbreviated offset is unreliable. In // this test case the RFC3339 will have the incorrect result due to limitation in // Go's time parser. outRFC3339: "2019-08-29T11:20:07Z", outRFC822: "Thu, 29 Aug 2019 11:20:07 MSK", }, { inRFC822: "Thu, 29 Aug 2019 11:20:07 -0000", outRFC3339: "2019-08-29T11:20:07Z", outRFC822: "Thu, 29 Aug 2019 11:20:07 -0000", }, { inRFC822: "Thu, 29 Aug 2019 11:20:07 +0000", outRFC3339: "2019-08-29T11:20:07Z", outRFC822: "Thu, 29 Aug 2019 11:20:07 +0000", }, { inRFC822: "Thu, 29 Aug 2019 11:20:07 +0300", outRFC3339: "2019-08-29T11:20:07+03:00", outRFC822: "Thu, 29 Aug 2019 11:20:07 +0300", }, { inRFC822: "Thu, 29 Aug 2019 11:20:07 +0330", outRFC3339: "2019-08-29T11:20:07+03:30", outRFC822: "Thu, 29 Aug 2019 11:20:07 +0330", }, { inRFC822: "Sun, 01 Sep 2019 11:20:07 +0300", outRFC3339: "2019-09-01T11:20:07+03:00", outRFC822: "Sun, 01 Sep 2019 11:20:07 +0300", }, { inRFC822: "Sun, 1 Sep 2019 11:20:07 +0300", outRFC3339: "2019-09-01T11:20:07+03:00", outRFC822: "Sun, 01 Sep 2019 11:20:07 +0300", }, { inRFC822: "Sun, 1 Sep 2019 11:20:07 +0300", outRFC3339: "2019-09-01T11:20:07+03:00", outRFC822: "Sun, 01 Sep 2019 11:20:07 +0300", }, { inRFC822: "Sun, 1 Sep 2019 11:20:07 UTC", outRFC3339: "2019-09-01T11:20:07Z", outRFC822: "Sun, 01 Sep 2019 11:20:07 UTC", }, { inRFC822: "Sun, 1 Sep 2019 11:20:07 UTC", outRFC3339: "2019-09-01T11:20:07Z", outRFC822: "Sun, 01 Sep 2019 11:20:07 UTC", }, { inRFC822: "Sun, 1 Sep 2019 11:20:07 GMT", outRFC3339: "2019-09-01T11:20:07Z", outRFC822: "Sun, 01 Sep 2019 11:20:07 GMT", }, { inRFC822: "Fri, 21 Nov 1997 09:55:06 -0600 (MDT)", outRFC3339: "1997-11-21T09:55:06-06:00", outRFC822: "Fri, 21 Nov 1997 09:55:06 MDT", }} { t.Run(tc.inRFC822, func(t *testing.T) { tcDesc := fmt.Sprintf("Test case #%d: %v", i, tc) var ts testStruct inEncoded := []byte(fmt.Sprintf(`{"ts":"%s"}`, tc.inRFC822)) err := json.Unmarshal(inEncoded, &ts) assert.NoError(t, err, tcDesc) assert.Equal(t, tc.outRFC3339, ts.Time.Format(RFC3339), tcDesc) actualEncoded, err := json.Marshal(&ts) assert.NoError(t, err, tcDesc) outEncoded := fmt.Sprintf(`{"ts":"%s"}`, tc.outRFC822) assert.Equal(t, outEncoded, 
string(actualEncoded), tcDesc) }) } } func TestRFC822UnmarshalingError(t *testing.T) { for _, tc := range []struct { inEncoded string outError string }{{ inEncoded: `{"ts": "Thu, 29 Aug 2019 11:20:07"}`, outError: `parsing time "Thu, 29 Aug 2019 11:20:07" as "January 2 2006 15:04 -0700 (MST)": cannot parse "Thu, 29 Aug 2019 11:20:07" as "January"`, }, { inEncoded: `{"ts": "foo"}`, outError: `parsing time "foo" as "January 2 2006 15:04 -0700 (MST)": cannot parse "foo" as "January"`, }, { inEncoded: `{"ts": 42}`, outError: "invalid syntax", }} { t.Run(tc.inEncoded, func(t *testing.T) { var ts testStruct err := json.Unmarshal([]byte(tc.inEncoded), &ts) assert.EqualError(t, err, tc.outError) }) } } func TestParseRFC822Time(t *testing.T) { for _, tt := range []struct { rfc822Time string }{ {"Thu, 3 Jun 2021 12:01:05 MST"}, {"Thu, 3 Jun 2021 12:01:05 -0700"}, {"Thu, 3 Jun 2021 12:01:05 -0700 (MST)"}, {"2 Jun 2021 17:06:41 GMT"}, {"2 Jun 2021 17:06:41 -0700"}, {"2 Jun 2021 17:06:41 -0700 (MST)"}, {"Mon, 30 August 2021 11:05:00 -0400"}, {"Thu, 3 June 2021 12:01:05 MST"}, {"Thu, 3 June 2021 12:01:05 -0700"}, {"Thu, 3 June 2021 12:01:05 -0700 (MST)"}, {"2 June 2021 17:06:41 GMT"}, {"2 June 2021 17:06:41 -0700"}, {"2 June 2021 17:06:41 -0700 (MST)"}, {"Wed, Nov 03 2021 17:48:06 CST"}, {"Wed, November 03 2021 17:48:06 CST"}, // Timestamps without seconds. {"Sun, 31 Oct 2021 12:10 -5000"}, {"Thu, 3 Jun 2021 12:01 MST"}, {"Thu, 3 Jun 2021 12:01 -0700"}, {"Thu, 3 Jun 2021 12:01 -0700 (MST)"}, {"2 Jun 2021 17:06 GMT"}, {"2 Jun 2021 17:06 -0700"}, {"2 Jun 2021 17:06 -0700 (MST)"}, {"Mon, 30 August 2021 11:05 -0400"}, {"Thu, 3 June 2021 12:01 MST"}, {"Thu, 3 June 2021 12:01 -0700"}, {"Thu, 3 June 2021 12:01 -0700 (MST)"}, {"2 June 2021 17:06 GMT"}, {"2 June 2021 17:06 -0700"}, {"2 June 2021 17:06 -0700 (MST)"}, {"Wed, Nov 03 2021 17:48 CST"}, {"Wed, November 03 2021 17:48 CST"}, } { t.Run(tt.rfc822Time, func(t *testing.T) { _, err := ParseRFC822Time(tt.rfc822Time) assert.NoError(t, err) }) } } func TestStringWithOffset(t *testing.T) { now := time.Now().UTC() r := NewRFC822Time(now) assert.Equal(t, now.Format(time.RFC1123Z), r.StringWithOffset()) } oxy-2.0.0/internal/holsterv4/clock/system.go000066400000000000000000000022541450475076400210630ustar00rootroot00000000000000package clock import "time" type systemTime struct{} func (st *systemTime) Now() time.Time { return time.Now() } func (st *systemTime) Sleep(d time.Duration) { time.Sleep(d) } func (st *systemTime) After(d time.Duration) <-chan time.Time { return time.After(d) } type systemTimer struct { t *time.Timer } func (st *systemTime) NewTimer(d time.Duration) Timer { t := time.NewTimer(d) return &systemTimer{t} } func (st *systemTime) AfterFunc(d time.Duration, f func()) Timer { t := time.AfterFunc(d, f) return &systemTimer{t} } func (t *systemTimer) C() <-chan time.Time { return t.t.C } func (t *systemTimer) Stop() bool { return t.t.Stop() } func (t *systemTimer) Reset(d time.Duration) bool { return t.t.Reset(d) } type systemTicker struct { t *time.Ticker } func (t *systemTicker) C() <-chan time.Time { return t.t.C } func (t *systemTicker) Stop() { t.t.Stop() } func (st *systemTime) NewTicker(d time.Duration) Ticker { t := time.NewTicker(d) return &systemTicker{t} } func (st *systemTime) Tick(d time.Duration) <-chan time.Time { return time.Tick(d) } func (st *systemTime) Wait4Scheduled(count int, timeout time.Duration) bool { panic("Not supported") } 
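// Compile-time check (added here for illustration) that the real-time
// implementation satisfies the Clock interface declared in interface.go.
var _ Clock = (*systemTime)(nil)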
oxy-2.0.0/internal/holsterv4/clock/system_test.go000066400000000000000000000050731450475076400221240ustar00rootroot00000000000000package clock import ( "testing" "time" "github.com/stretchr/testify/assert" ) func TestSleep(t *testing.T) { start := Now() // When Sleep(100 * time.Millisecond) // Then if Now().Sub(start) < 100*time.Millisecond { assert.Fail(t, "Sleep did not last long enough") } } func TestAfter(t *testing.T) { start := Now() // When end := <-After(100 * time.Millisecond) // Then if end.Sub(start) < 100*time.Millisecond { assert.Fail(t, "Sleep did not last long enough") } } func TestAfterFunc(t *testing.T) { start := Now() endCh := make(chan time.Time, 1) // When AfterFunc(100*time.Millisecond, func() { endCh <- time.Now() }) // Then end := <-endCh if end.Sub(start) < 100*time.Millisecond { assert.Fail(t, "Sleep did not last long enough") } } func TestNewTimer(t *testing.T) { start := Now() // When timer := NewTimer(100 * time.Millisecond) // Then end := <-timer.C() if end.Sub(start) < 100*time.Millisecond { assert.Fail(t, "Sleep did not last long enough") } } func TestTimerStop(t *testing.T) { timer := NewTimer(50 * time.Millisecond) // When active := timer.Stop() // Then assert.Equal(t, true, active) time.Sleep(100) select { case <-timer.C(): assert.Fail(t, "Timer should not have fired") default: } } func TestTimerReset(t *testing.T) { t.Skip("fail on the CI for darwin") start := time.Now() timer := NewTimer(300 * time.Millisecond) // When timer.Reset(100 * time.Millisecond) // Then end := <-timer.C() if end.Sub(start) >= 150*time.Millisecond { assert.Fail(t, "Waited too long", end.Sub(start).String()) } } func TestNewTicker(t *testing.T) { start := Now() // When timer := NewTicker(100 * time.Millisecond) // Then end := <-timer.C() if end.Sub(start) < 100*time.Millisecond { assert.Fail(t, "Sleep did not last long enough") } end = <-timer.C() if end.Sub(start) < 200*time.Millisecond { assert.Fail(t, "Sleep did not last long enough") } timer.Stop() time.Sleep(150) select { case <-timer.C(): assert.Fail(t, "Ticker should not have fired") default: } } func TestTick(t *testing.T) { start := Now() // When ch := Tick(100 * time.Millisecond) // Then end := <-ch if end.Sub(start) < 100*time.Millisecond { assert.Fail(t, "Sleep did not last long enough") } end = <-ch if end.Sub(start) < 200*time.Millisecond { assert.Fail(t, "Sleep did not last long enough") } } func TestNewStoppedTimer(t *testing.T) { timer := NewStoppedTimer() // When/Then select { case <-timer.C(): assert.Fail(t, "Timer should not have fired") default: } assert.Equal(t, false, timer.Stop()) } oxy-2.0.0/internal/holsterv4/collections/000077500000000000000000000000001450475076400204305ustar00rootroot00000000000000oxy-2.0.0/internal/holsterv4/collections/README.md000066400000000000000000000010721450475076400217070ustar00rootroot00000000000000## Priority Queue Provides a Priority Queue implementation as described [here](https://en.wikipedia.org/wiki/Priority_queue) ```go queue := collections.NewPriorityQueue() queue.Push(&collections.PQItem{ Value: "thing3", Priority: 3, }) queue.Push(&collections.PQItem{ Value: "thing1", Priority: 1, }) queue.Push(&collections.PQItem{ Value: "thing2", Priority: 2, }) // Pops item off the queue according to the priority instead of the Push() order item := queue.Pop() fmt.Printf("Item: %s", item.Value.(string)) // Output: Item: thing1 ``` oxy-2.0.0/internal/holsterv4/collections/priority_queue.go000066400000000000000000000042601450475076400240460ustar00rootroot00000000000000/* Copyright 
2017 Mailgun Technologies Inc Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package collections import ( "container/heap" ) // An PQItem is something we manage in a priority queue. type PQItem struct { Value interface{} Priority int // The priority of the item in the queue. // The index is needed by update and is maintained by the heap.Interface methods. index int // The index of the item in the heap. } // Implements a PriorityQueue type PriorityQueue struct { impl *pqImpl } func NewPriorityQueue() *PriorityQueue { mh := &pqImpl{} heap.Init(mh) return &PriorityQueue{impl: mh} } func (p PriorityQueue) Len() int { return p.impl.Len() } func (p *PriorityQueue) Push(el *PQItem) { heap.Push(p.impl, el) } func (p *PriorityQueue) Pop() *PQItem { el := heap.Pop(p.impl) return el.(*PQItem) } func (p *PriorityQueue) Peek() *PQItem { return (*p.impl)[0] } // Modifies the priority and value of an Item in the queue. func (p *PriorityQueue) Update(el *PQItem, priority int) { heap.Remove(p.impl, el.index) el.Priority = priority heap.Push(p.impl, el) } func (p *PriorityQueue) Remove(el *PQItem) { heap.Remove(p.impl, el.index) } // Actual Implementation using heap.Interface type pqImpl []*PQItem func (mh pqImpl) Len() int { return len(mh) } func (mh pqImpl) Less(i, j int) bool { return mh[i].Priority < mh[j].Priority } func (mh pqImpl) Swap(i, j int) { mh[i], mh[j] = mh[j], mh[i] mh[i].index = i mh[j].index = j } func (mh *pqImpl) Push(x interface{}) { n := len(*mh) item := x.(*PQItem) item.index = n *mh = append(*mh, item) } func (mh *pqImpl) Pop() interface{} { old := *mh n := len(old) item := old[n-1] item.index = -1 // for safety *mh = old[0 : n-1] return item } oxy-2.0.0/internal/holsterv4/collections/priority_queue_test.go000066400000000000000000000045451450475076400251130ustar00rootroot00000000000000/* Copyright 2017 Mailgun Technologies Inc Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package collections_test import ( "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/vulcand/oxy/v2/internal/holsterv4/collections" ) func toPtr(i int) interface{} { return &i } func toInt(i interface{}) int { return *(i.(*int)) } func TestPeek(t *testing.T) { mh := collections.NewPriorityQueue() el := &collections.PQItem{ Value: toPtr(1), Priority: 5, } mh.Push(el) assert.Equal(t, 1, toInt(mh.Peek().Value)) assert.Equal(t, 1, mh.Len()) el = &collections.PQItem{ Value: toPtr(2), Priority: 1, } mh.Push(el) assert.Equal(t, 2, mh.Len()) assert.Equal(t, 2, toInt(mh.Peek().Value)) assert.Equal(t, 2, toInt(mh.Peek().Value)) assert.Equal(t, 2, mh.Len()) el = mh.Pop() assert.Equal(t, 2, toInt(el.Value)) assert.Equal(t, 1, mh.Len()) assert.Equal(t, 1, toInt(mh.Peek().Value)) mh.Pop() assert.Equal(t, 0, mh.Len()) } func TestUpdate(t *testing.T) { mh := collections.NewPriorityQueue() x := &collections.PQItem{ Value: toPtr(1), Priority: 4, } y := &collections.PQItem{ Value: toPtr(2), Priority: 3, } z := &collections.PQItem{ Value: toPtr(3), Priority: 8, } mh.Push(x) mh.Push(y) mh.Push(z) assert.Equal(t, 2, toInt(mh.Peek().Value)) mh.Update(z, 1) assert.Equal(t, 3, toInt(mh.Peek().Value)) mh.Update(x, 0) assert.Equal(t, 1, toInt(mh.Peek().Value)) } func ExampleNewPriorityQueue() { queue := collections.NewPriorityQueue() queue.Push(&collections.PQItem{ Value: "thing3", Priority: 3, }) queue.Push(&collections.PQItem{ Value: "thing1", Priority: 1, }) queue.Push(&collections.PQItem{ Value: "thing2", Priority: 2, }) // Pops item off the queue according to the priority instead of the Push() order item := queue.Pop() fmt.Printf("Item: %s", item.Value.(string)) // Output: Item: thing1 } oxy-2.0.0/internal/holsterv4/collections/ttlmap.go000066400000000000000000000117541450475076400222700ustar00rootroot00000000000000/* Copyright 2017 Mailgun Technologies Inc Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package collections import ( "fmt" "sync" "time" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" ) type TTLMap struct { // Optionally specifies a callback function to be // executed when an entry has expired OnExpire func(key string, i interface{}) capacity int elements map[string]*mapElement expiryTimes *PriorityQueue mutex *sync.RWMutex } type mapElement struct { key string value interface{} heapEl *PQItem } func NewTTLMap(capacity int) *TTLMap { if capacity <= 0 { capacity = 0 } return &TTLMap{ capacity: capacity, elements: make(map[string]*mapElement), expiryTimes: NewPriorityQueue(), mutex: &sync.RWMutex{}, } } func (m *TTLMap) Set(key string, value interface{}, ttlSeconds int) error { expiryTime, err := m.toEpochSeconds(ttlSeconds) if err != nil { return err } m.mutex.Lock() defer m.mutex.Unlock() return m.set(key, value, expiryTime) } func (m *TTLMap) Len() int { m.mutex.RLock() defer m.mutex.RUnlock() return len(m.elements) } func (m *TTLMap) Get(key string) (interface{}, bool) { value, mapEl, expired := m.lockNGet(key) if mapEl == nil { return nil, false } if expired { m.lockNDel(mapEl) return nil, false } return value, true } func (m *TTLMap) Increment(key string, value int, ttlSeconds int) (int, error) { expiryTime, err := m.toEpochSeconds(ttlSeconds) if err != nil { return 0, err } m.mutex.Lock() defer m.mutex.Unlock() mapEl, expired := m.get(key) if mapEl == nil || expired { m.set(key, value, expiryTime) return value, nil } currentValue, ok := mapEl.value.(int) if !ok { return 0, fmt.Errorf("Expected existing value to be integer, got %T", mapEl.value) } currentValue += value m.set(key, currentValue, expiryTime) return currentValue, nil } func (m *TTLMap) GetInt(key string) (int, bool, error) { valueI, exists := m.Get(key) if !exists { return 0, false, nil } value, ok := valueI.(int) if !ok { return 0, false, fmt.Errorf("Expected existing value to be integer, got %T", valueI) } return value, true, nil } func (m *TTLMap) set(key string, value interface{}, expiryTime int) error { if mapEl, ok := m.elements[key]; ok { mapEl.value = value m.expiryTimes.Update(mapEl.heapEl, expiryTime) return nil } if len(m.elements) >= m.capacity { m.freeSpace(1) } heapEl := &PQItem{ Priority: expiryTime, } mapEl := &mapElement{ key: key, value: value, heapEl: heapEl, } heapEl.Value = mapEl m.elements[key] = mapEl m.expiryTimes.Push(heapEl) return nil } func (m *TTLMap) lockNGet(key string) (value interface{}, mapEl *mapElement, expired bool) { m.mutex.RLock() defer m.mutex.RUnlock() mapEl, expired = m.get(key) value = nil if mapEl != nil { value = mapEl.value } return value, mapEl, expired } func (m *TTLMap) get(key string) (*mapElement, bool) { mapEl, ok := m.elements[key] if !ok { return nil, false } now := int(clock.Now().Unix()) expired := mapEl.heapEl.Priority <= now return mapEl, expired } func (m *TTLMap) lockNDel(mapEl *mapElement) { m.mutex.Lock() defer m.mutex.Unlock() // Map element could have been updated. Now that we have a lock // retrieve it again and check if it is still expired. 
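	// This is the check-again-under-the-write-lock pattern: the expiry was first
	// observed under the read lock in lockNGet, and a concurrent Set or Increment
	// may have refreshed the entry in the meantime.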
var ok bool if mapEl, ok = m.elements[mapEl.key]; !ok { return } now := int(clock.Now().Unix()) if mapEl.heapEl.Priority > now { return } if m.OnExpire != nil { m.OnExpire(mapEl.key, mapEl.value) } delete(m.elements, mapEl.key) m.expiryTimes.Remove(mapEl.heapEl) } func (m *TTLMap) freeSpace(count int) { removed := m.RemoveExpired(count) if removed >= count { return } m.RemoveLastUsed(count - removed) } func (m *TTLMap) RemoveExpired(iterations int) int { removed := 0 now := int(clock.Now().Unix()) for i := 0; i < iterations; i += 1 { if len(m.elements) == 0 { break } heapEl := m.expiryTimes.Peek() if heapEl.Priority > now { break } m.expiryTimes.Pop() mapEl := heapEl.Value.(*mapElement) delete(m.elements, mapEl.key) removed += 1 } return removed } func (m *TTLMap) RemoveLastUsed(iterations int) { for i := 0; i < iterations; i += 1 { if len(m.elements) == 0 { return } heapEl := m.expiryTimes.Pop() mapEl := heapEl.Value.(*mapElement) delete(m.elements, mapEl.key) } } func (m *TTLMap) toEpochSeconds(ttlSeconds int) (int, error) { if ttlSeconds <= 0 { return 0, fmt.Errorf("ttlSeconds should be >= 0, got %d", ttlSeconds) } return int(clock.Now().Add(time.Second * time.Duration(ttlSeconds)).Unix()), nil } oxy-2.0.0/internal/holsterv4/collections/ttlmap_test.go000066400000000000000000000160131450475076400233200ustar00rootroot00000000000000/* Copyright 2017 Mailgun Technologies Inc Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ package collections import ( "testing" "github.com/stretchr/testify/suite" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" ) type TTLMapSuite struct { suite.Suite } func TestTTLMapSuite(t *testing.T) { suite.Run(t, new(TTLMapSuite)) } func (s *TTLMapSuite) SetupTest() { clock.Freeze(clock.Date(2012, 3, 4, 5, 6, 7, 0, clock.UTC)) } func (s *TTLMapSuite) TearDownSuite() { clock.Unfreeze() } func (s *TTLMapSuite) TestSetWrong() { m := NewTTLMap(1) err := m.Set("a", 1, -1) s.Require().EqualError(err, "ttlSeconds should be >= 0, got -1") err = m.Set("a", 1, 0) s.Require().EqualError(err, "ttlSeconds should be >= 0, got 0") _, err = m.Increment("a", 1, 0) s.Require().EqualError(err, "ttlSeconds should be >= 0, got 0") _, err = m.Increment("a", 1, -1) s.Require().EqualError(err, "ttlSeconds should be >= 0, got -1") } func (s *TTLMapSuite) TestRemoveExpiredEmpty() { m := NewTTLMap(1) m.RemoveExpired(100) } func (s *TTLMapSuite) TestRemoveLastUsedEmpty() { m := NewTTLMap(1) m.RemoveLastUsed(100) } func (s *TTLMapSuite) TestGetSetExpire() { m := NewTTLMap(1) err := m.Set("a", 1, 1) s.Require().Equal(nil, err) valI, exists := m.Get("a") s.Require().Equal(true, exists) s.Require().Equal(1, valI) clock.Advance(1 * clock.Second) _, exists = m.Get("a") s.Require().Equal(false, exists) } func (s *TTLMapSuite) TestSetOverwrite() { m := NewTTLMap(1) err := m.Set("o", 1, 1) s.Require().Equal(nil, err) valI, exists := m.Get("o") s.Require().Equal(true, exists) s.Require().Equal(1, valI) err = m.Set("o", 2, 1) s.Require().Equal(nil, err) valI, exists = m.Get("o") s.Require().Equal(true, exists) s.Require().Equal(2, valI) } func (s *TTLMapSuite) TestRemoveExpiredEdgeCase() { m := NewTTLMap(1) err := m.Set("a", 1, 1) s.Require().Equal(nil, err) clock.Advance(1 * clock.Second) err = m.Set("b", 2, 1) s.Require().Equal(nil, err) valI, exists := m.Get("a") s.Require().Equal(false, exists) valI, exists = m.Get("b") s.Require().Equal(true, exists) s.Require().Equal(2, valI) s.Require().Equal(1, m.Len()) } func (s *TTLMapSuite) TestRemoveOutOfCapacity() { m := NewTTLMap(2) err := m.Set("a", 1, 5) s.Require().Equal(nil, err) clock.Advance(1 * clock.Second) err = m.Set("b", 2, 6) s.Require().Equal(nil, err) err = m.Set("c", 3, 10) s.Require().Equal(nil, err) valI, exists := m.Get("a") s.Require().Equal(false, exists) valI, exists = m.Get("b") s.Require().Equal(true, exists) s.Require().Equal(2, valI) valI, exists = m.Get("c") s.Require().Equal(true, exists) s.Require().Equal(3, valI) s.Require().Equal(2, m.Len()) } func (s *TTLMapSuite) TestGetNotExists() { m := NewTTLMap(1) _, exists := m.Get("a") s.Require().Equal(false, exists) } func (s *TTLMapSuite) TestGetIntNotExists() { m := NewTTLMap(1) _, exists, err := m.GetInt("a") s.Require().Equal(nil, err) s.Require().Equal(false, exists) } func (s *TTLMapSuite) TestGetInvalidType() { m := NewTTLMap(1) m.Set("a", "banana", 5) _, _, err := m.GetInt("a") s.Require().EqualError(err, "Expected existing value to be integer, got string") _, err = m.Increment("a", 4, 1) s.Require().EqualError(err, "Expected existing value to be integer, got string") } func (s *TTLMapSuite) TestIncrementGetExpire() { m := NewTTLMap(1) m.Increment("a", 5, 1) val, exists, err := m.GetInt("a") s.Require().Equal(nil, err) s.Require().Equal(true, exists) s.Require().Equal(5, val) clock.Advance(1 * clock.Second) m.Increment("a", 4, 1) val, exists, err = m.GetInt("a") s.Require().Equal(nil, err) s.Require().Equal(true, exists) s.Require().Equal(4, val) } func (s *TTLMapSuite) 
TestIncrementOverwrite() { m := NewTTLMap(1) m.Increment("a", 5, 1) val, exists, err := m.GetInt("a") s.Require().Equal(nil, err) s.Require().Equal(true, exists) s.Require().Equal(5, val) m.Increment("a", 4, 1) val, exists, err = m.GetInt("a") s.Require().Equal(nil, err) s.Require().Equal(true, exists) s.Require().Equal(9, val) } func (s *TTLMapSuite) TestIncrementOutOfCapacity() { m := NewTTLMap(1) m.Increment("a", 5, 1) val, exists, err := m.GetInt("a") s.Require().Equal(nil, err) s.Require().Equal(true, exists) s.Require().Equal(5, val) m.Increment("b", 4, 1) val, exists, err = m.GetInt("b") s.Require().Equal(nil, err) s.Require().Equal(true, exists) s.Require().Equal(4, val) val, exists, err = m.GetInt("a") s.Require().Equal(nil, err) s.Require().Equal(false, exists) } func (s *TTLMapSuite) TestIncrementRemovesExpired() { m := NewTTLMap(2) m.Increment("a", 1, 1) m.Increment("b", 2, 2) clock.Advance(1 * clock.Second) m.Increment("c", 3, 3) val, exists, err := m.GetInt("a") s.Require().Equal(nil, err) s.Require().Equal(false, exists) val, exists, err = m.GetInt("b") s.Require().Equal(nil, err) s.Require().Equal(true, exists) s.Require().Equal(2, val) val, exists, err = m.GetInt("c") s.Require().Equal(nil, err) s.Require().Equal(true, exists) s.Require().Equal(3, val) } func (s *TTLMapSuite) TestIncrementRemovesLastUsed() { m := NewTTLMap(2) m.Increment("a", 1, 10) m.Increment("b", 2, 11) m.Increment("c", 3, 12) val, exists, err := m.GetInt("a") s.Require().Equal(nil, err) s.Require().Equal(false, exists) val, exists, err = m.GetInt("b") s.Require().Equal(nil, err) s.Require().Equal(true, exists) s.Require().Equal(2, val) val, exists, err = m.GetInt("c") s.Require().Equal(nil, err) s.Require().Equal(true, exists) s.Require().Equal(3, val) } func (s *TTLMapSuite) TestIncrementUpdatesTtl() { m := NewTTLMap(1) m.Increment("a", 1, 1) m.Increment("a", 1, 10) clock.Advance(1 * clock.Second) val, exists, err := m.GetInt("a") s.Require().Equal(nil, err) s.Require().Equal(true, exists) s.Require().Equal(2, val) } func (s *TTLMapSuite) TestUpdate() { m := NewTTLMap(1) m.Increment("a", 1, 1) m.Increment("a", 1, 10) clock.Advance(1 * clock.Second) val, exists, err := m.GetInt("a") s.Require().Equal(nil, err) s.Require().Equal(true, exists) s.Require().Equal(2, val) } func (s *TTLMapSuite) TestCallOnExpire() { var called bool var key string var val interface{} m := NewTTLMap(1) m.OnExpire = func(k string, el interface{}) { called = true key = k val = el } err := m.Set("a", 1, 1) s.Require().Equal(nil, err) valI, exists := m.Get("a") s.Require().Equal(true, exists) s.Require().Equal(1, valI) clock.Advance(1 * clock.Second) _, exists = m.Get("a") s.Require().Equal(false, exists) s.Require().Equal(true, called) s.Require().Equal("a", key) s.Require().Equal(1, val) } oxy-2.0.0/memmetrics/000077500000000000000000000000001450475076400145115ustar00rootroot00000000000000oxy-2.0.0/memmetrics/anomaly.go000066400000000000000000000064041450475076400165040ustar00rootroot00000000000000package memmetrics import ( "math" "sort" "time" ) // SplitLatencies provides simple anomaly detection for requests latencies. // it splits values into good or bad category based on the threshold and the median value. // If all values are not far from the median, it will return all values in 'good' set. // Precision is the smallest value to consider, e.g. if set to millisecond, microseconds will be ignored. 
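// For example (mirroring the cases in anomaly_test.go), latencies of 40ms, 60ms
// and 1000ms with millisecond precision are split as:
//
//	good, bad := SplitLatencies([]time.Duration{
//		40 * time.Millisecond, 60 * time.Millisecond, 1000 * time.Millisecond,
//	}, time.Millisecond)
//	// good: {40ms, 60ms}, bad: {1000ms}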
func SplitLatencies(values []time.Duration, precision time.Duration) (good map[time.Duration]bool, bad map[time.Duration]bool) { // Find the max latency M and then map each latency L to the ratio L/M and then call SplitFloat64 v2r := map[float64]time.Duration{} ratios := make([]float64, len(values)) m := maxTime(values) for i, v := range values { ratio := float64(v/precision+1) / float64(m/precision+1) // +1 is to avoid division by 0 v2r[ratio] = v ratios[i] = ratio } good, bad = make(map[time.Duration]bool), make(map[time.Duration]bool) // Note that multiplier makes this function way less sensitive than ratios detector, this is to avoid noise. vgood, vbad := SplitFloat64(2, 0, ratios) for r := range vgood { good[v2r[r]] = true } for r := range vbad { bad[v2r[r]] = true } return good, bad } // SplitRatios provides simple anomaly detection for ratio values, that are all in the range [0, 1] // it splits values into good or bad category based on the threshold and the median value. // If all values are not far from the median, it will return all values in 'good' set. func SplitRatios(values []float64) (good map[float64]bool, bad map[float64]bool) { return SplitFloat64(1.5, 0, values) } // SplitFloat64 provides simple anomaly detection for skewed data sets with no particular distribution. // In essence it applies the formula if(v > median(values) + threshold * medianAbsoluteDeviation) -> anomaly // There's a corner case where there are just 2 values, so by definition there's no value that exceeds the threshold. // This case is solved by introducing additional value that we know is good, e.g. 0. That helps to improve the detection results // on such data sets. func SplitFloat64(threshold, sentinel float64, values []float64) (good map[float64]bool, bad map[float64]bool) { good, bad = make(map[float64]bool), make(map[float64]bool) var newValues []float64 if len(values)%2 == 0 { newValues = make([]float64, len(values)+1) copy(newValues, values) // Add a sentinel endpoint so we can distinguish outliers better newValues[len(newValues)-1] = sentinel } else { newValues = values } m := median(newValues) mAbs := medianAbsoluteDeviation(newValues) for _, v := range values { if v > (m+mAbs)*threshold { bad[v] = true } else { good[v] = true } } return good, bad } func median(values []float64) float64 { vals := make([]float64, len(values)) copy(vals, values) sort.Float64s(vals) l := len(vals) if l%2 != 0 { return vals[l/2] } return (vals[l/2-1] + vals[l/2]) / 2.0 } func medianAbsoluteDeviation(values []float64) float64 { m := median(values) distances := make([]float64, len(values)) for i, v := range values { distances[i] = math.Abs(v - m) } return median(distances) } func maxTime(vals []time.Duration) time.Duration { val := vals[0] for _, v := range vals { if v > val { val = v } } return val } oxy-2.0.0/memmetrics/anomaly_test.go000066400000000000000000000073411450475076400175440ustar00rootroot00000000000000package memmetrics import ( "strconv" "testing" "time" "github.com/stretchr/testify/assert" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" ) func TestMedian(t *testing.T) { testCases := []struct { desc string values []float64 expected float64 }{ { desc: "2 values", values: []float64{0.1, 0.2}, expected: (float64(0.1) + float64(0.2)) / 2.0, }, { desc: "3 values", values: []float64{0.3, 0.2, 0.5}, expected: 0.3, }, } for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() actual := median(test.values) assert.Equal(t, test.expected, actual) }) } } func 
TestSplitRatios(t *testing.T) { testCases := []struct { values []float64 good []float64 bad []float64 }{ { values: []float64{0, 0}, good: []float64{0}, bad: []float64{}, }, { values: []float64{0, 1}, good: []float64{0}, bad: []float64{1}, }, { values: []float64{0.1, 0.1}, good: []float64{0.1}, bad: []float64{}, }, { values: []float64{0.15, 0.1}, good: []float64{0.15, 0.1}, bad: []float64{}, }, { values: []float64{0.01, 0.01}, good: []float64{0.01}, bad: []float64{}, }, { values: []float64{0.012, 0.01, 1}, good: []float64{0.012, 0.01}, bad: []float64{1}, }, { values: []float64{0, 0, 1, 1}, good: []float64{0}, bad: []float64{1}, }, { values: []float64{0, 0.1, 0.1, 0}, good: []float64{0}, bad: []float64{0.1}, }, { values: []float64{0, 0.01, 0.1, 0}, good: []float64{0}, bad: []float64{0.01, 0.1}, }, { values: []float64{0, 0.01, 0.02, 1}, good: []float64{0, 0.01, 0.02}, bad: []float64{1}, }, { values: []float64{0, 0, 0, 0, 0, 0.01, 0.02, 1}, good: []float64{0}, bad: []float64{0.01, 0.02, 1}, }, } for ind, test := range testCases { test := test t.Run(strconv.Itoa(ind), func(t *testing.T) { t.Parallel() good, bad := SplitRatios(test.values) vgood := make(map[float64]bool, len(test.good)) for _, v := range test.good { vgood[v] = true } vbad := make(map[float64]bool, len(test.bad)) for _, v := range test.bad { vbad[v] = true } assert.Equal(t, vgood, good) assert.Equal(t, vbad, bad) }) } } func TestSplitLatencies(t *testing.T) { testCases := []struct { values []int good []int bad []int }{ { values: []int{0, 0}, good: []int{0}, bad: []int{}, }, { values: []int{1, 2}, good: []int{1, 2}, bad: []int{}, }, { values: []int{1, 2, 4}, good: []int{1, 2, 4}, bad: []int{}, }, { values: []int{8, 8, 18}, good: []int{8}, bad: []int{18}, }, { values: []int{32, 28, 11, 26, 19, 51, 25, 39, 28, 26, 8, 97}, good: []int{32, 28, 11, 26, 19, 51, 25, 39, 28, 26, 8}, bad: []int{97}, }, { values: []int{1, 2, 4, 40}, good: []int{1, 2, 4}, bad: []int{40}, }, { values: []int{40, 60, 1000}, good: []int{40, 60}, bad: []int{1000}, }, } for ind, test := range testCases { test := test t.Run(strconv.Itoa(ind), func(t *testing.T) { t.Parallel() values := make([]time.Duration, len(test.values)) for i, d := range test.values { values[i] = clock.Millisecond * time.Duration(d) } good, bad := SplitLatencies(values, clock.Millisecond) vgood := make(map[time.Duration]bool, len(test.good)) for _, v := range test.good { vgood[time.Duration(v)*clock.Millisecond] = true } assert.Equal(t, vgood, good) vbad := make(map[time.Duration]bool, len(test.bad)) for _, v := range test.bad { vbad[time.Duration(v)*clock.Millisecond] = true } assert.Equal(t, vbad, bad) }) } } oxy-2.0.0/memmetrics/counter.go000066400000000000000000000071101450475076400165160ustar00rootroot00000000000000package memmetrics import ( "fmt" "time" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" ) type rcOption func(*RollingCounter) error // RollingCounter Calculates in memory failure rate of an endpoint using rolling window of a predefined size. type RollingCounter struct { resolution time.Duration values []int countedBuckets int // how many samples in different buckets have we collected so far lastBucket int // last recorded bucket lastUpdated clock.Time } // NewCounter creates a counter with fixed amount of buckets that are rotated every resolution period. // E.g. 10 buckets with 1 second means that every new second the bucket is refreshed, so it maintains 10 seconds rolling window. // By default, creates a bucket with 10 buckets and 1 second resolution. 
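// A short usage sketch (see counter_test.go for the exact rolling-window semantics):
//
//	cnt, err := NewCounter(10, clock.Second) // 10 one-second buckets
//	if err != nil {
//		// handle err
//	}
//	cnt.Inc(1)
//	total := cnt.Count() // events observed over the 10-second window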
func NewCounter(buckets int, resolution time.Duration, options ...rcOption) (*RollingCounter, error) { if buckets <= 0 { return nil, fmt.Errorf("buckets should be >= 0") } if resolution < clock.Second { return nil, fmt.Errorf("resolution should be larger than a second") } rc := &RollingCounter{ lastBucket: -1, resolution: resolution, values: make([]int, buckets), } for _, o := range options { if err := o(rc); err != nil { return nil, err } } return rc, nil } // Append appends a counter. func (c *RollingCounter) Append(o *RollingCounter) error { c.Inc(int(o.Count())) return nil } // Clone clones a counter. func (c *RollingCounter) Clone() *RollingCounter { c.cleanup() other := &RollingCounter{ resolution: c.resolution, values: make([]int, len(c.values)), lastBucket: c.lastBucket, lastUpdated: c.lastUpdated, } copy(other.values, c.values) return other } // Reset resets a counter. func (c *RollingCounter) Reset() { c.lastBucket = -1 c.countedBuckets = 0 c.lastUpdated = clock.Time{} for i := range c.values { c.values[i] = 0 } } // CountedBuckets gets counted buckets. func (c *RollingCounter) CountedBuckets() int { return c.countedBuckets } // Count counts. func (c *RollingCounter) Count() int64 { c.cleanup() return c.sum() } // Resolution gets resolution. func (c *RollingCounter) Resolution() time.Duration { return c.resolution } // Buckets gets buckets. func (c *RollingCounter) Buckets() int { return len(c.values) } // WindowSize gets windows size. func (c *RollingCounter) WindowSize() time.Duration { return time.Duration(len(c.values)) * c.resolution } // Inc increments counter. func (c *RollingCounter) Inc(v int) { c.cleanup() c.incBucketValue(v) } func (c *RollingCounter) incBucketValue(v int) { now := clock.Now().UTC() bucket := c.getBucket(now) c.values[bucket] += v c.lastUpdated = now // Update usage stats if we haven't collected enough data if c.countedBuckets < len(c.values) { // Only update if we have advanced to the next bucket and not incremented the value // in the current bucket. if c.lastBucket != bucket { c.lastBucket = bucket c.countedBuckets++ } } } // Returns the number in the moving window bucket that this slot occupies. func (c *RollingCounter) getBucket(t time.Time) int { return int(t.Truncate(c.resolution).Unix() % int64(len(c.values))) } // Reset buckets that were not updated. 
func (c *RollingCounter) cleanup() { now := clock.Now().UTC() for i := 0; i < len(c.values); i++ { now = now.Add(time.Duration(-1*i) * c.resolution) if now.Truncate(c.resolution).After(c.lastUpdated.Truncate(c.resolution)) { c.values[c.getBucket(now)] = 0 } else { break } } } func (c *RollingCounter) sum() int64 { out := int64(0) for _, v := range c.values { out += int64(v) } return out } oxy-2.0.0/memmetrics/counter_test.go000066400000000000000000000010221450475076400175510ustar00rootroot00000000000000package memmetrics import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" ) func TestCloneExpired(t *testing.T) { clock.Freeze(clock.Date(2012, 3, 4, 5, 6, 7, 0, clock.UTC)) cnt, err := NewCounter(3, clock.Second) require.NoError(t, err) cnt.Inc(1) clock.Advance(clock.Second) cnt.Inc(1) clock.Advance(clock.Second) cnt.Inc(1) clock.Advance(clock.Second) out := cnt.Clone() assert.EqualValues(t, 2, out.Count()) } oxy-2.0.0/memmetrics/histogram.go000066400000000000000000000117521450475076400170430ustar00rootroot00000000000000package memmetrics import ( "fmt" "time" "github.com/HdrHistogram/hdrhistogram-go" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" ) // HDRHistogram is a tiny wrapper around github.com/HdrHistogram/hdrhistogram-go that provides convenience functions for measuring http latencies. type HDRHistogram struct { // lowest trackable value low int64 // highest trackable value high int64 // significant figures sigfigs int h *hdrhistogram.Histogram } // NewHDRHistogram creates a new HDRHistogram. func NewHDRHistogram(low, high int64, sigfigs int) (h *HDRHistogram, err error) { defer func() { if msg := recover(); msg != nil { err = fmt.Errorf("%s", msg) } }() return &HDRHistogram{ low: low, high: high, sigfigs: sigfigs, h: hdrhistogram.New(low, high, sigfigs), }, nil } // Export exports a HDRHistogram. func (h *HDRHistogram) Export() *HDRHistogram { var hist *hdrhistogram.Histogram if h.h != nil { snapshot := h.h.Export() hist = hdrhistogram.Import(snapshot) } return &HDRHistogram{low: h.low, high: h.high, sigfigs: h.sigfigs, h: hist} } // LatencyAtQuantile sets latency at quantile with microsecond precision. func (h *HDRHistogram) LatencyAtQuantile(q float64) time.Duration { return time.Duration(h.ValueAtQuantile(q)) * clock.Microsecond } // RecordLatencies Records latencies with microsecond precision. func (h *HDRHistogram) RecordLatencies(d time.Duration, n int64) error { return h.RecordValues(int64(d/clock.Microsecond), n) } // Reset resets a HDRHistogram. func (h *HDRHistogram) Reset() { h.h.Reset() } // ValueAtQuantile sets value at quantile. func (h *HDRHistogram) ValueAtQuantile(q float64) int64 { return h.h.ValueAtQuantile(q) } // RecordValues sets record values. func (h *HDRHistogram) RecordValues(v, n int64) error { return h.h.RecordValues(v, n) } // Merge merges a HDRHistogram. func (h *HDRHistogram) Merge(other *HDRHistogram) error { if other == nil { return fmt.Errorf("other is nil") } h.h.Merge(other.h) return nil } type rhOption func(r *RollingHDRHistogram) error // RollingHDRHistogram holds multiple histograms and rotates every period. // It provides resulting histogram as a result of a call of 'Merged' function. type RollingHDRHistogram struct { idx int lastRoll clock.Time period time.Duration bucketCount int low int64 high int64 sigfigs int buckets []*HDRHistogram } // NewRollingHDRHistogram created a new RollingHDRHistogram. 
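// The parameters mirror NewHDRHistogram plus a roll period and a bucket count;
// a usage sketch (constructor arguments borrowed from histogram_test.go):
//
//	h, err := NewRollingHDRHistogram(
//		1,            // min trackable value
//		3600000,      // max trackable value
//		3,            // significant figures
//		clock.Second, // roll period
//		2,            // sub-histograms in the window
//	)
//	_ = h.RecordLatencies(5*clock.Millisecond, 1)
//	merged, _ := h.Merged()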
func NewRollingHDRHistogram(low, high int64, sigfigs int, period time.Duration, bucketCount int, options ...rhOption) (*RollingHDRHistogram, error) { rh := &RollingHDRHistogram{ bucketCount: bucketCount, period: period, low: low, high: high, sigfigs: sigfigs, } for _, o := range options { if err := o(rh); err != nil { return nil, err } } buckets := make([]*HDRHistogram, rh.bucketCount) for i := range buckets { h, err := NewHDRHistogram(low, high, sigfigs) if err != nil { return nil, err } buckets[i] = h } rh.buckets = buckets return rh, nil } // Export exports a RollingHDRHistogram. func (r *RollingHDRHistogram) Export() *RollingHDRHistogram { export := &RollingHDRHistogram{} export.idx = r.idx export.lastRoll = r.lastRoll export.period = r.period export.bucketCount = r.bucketCount export.low = r.low export.high = r.high export.sigfigs = r.sigfigs exportBuckets := make([]*HDRHistogram, len(r.buckets)) for i, hist := range r.buckets { exportBuckets[i] = hist.Export() } export.buckets = exportBuckets return export } // Append appends a RollingHDRHistogram. func (r *RollingHDRHistogram) Append(o *RollingHDRHistogram) error { if r.bucketCount != o.bucketCount || r.period != o.period || r.low != o.low || r.high != o.high || r.sigfigs != o.sigfigs { return fmt.Errorf("can't merge") } for i := range r.buckets { if err := r.buckets[i].Merge(o.buckets[i]); err != nil { return err } } return nil } // Reset resets a RollingHDRHistogram. func (r *RollingHDRHistogram) Reset() { r.idx = 0 r.lastRoll = clock.Now().UTC() for _, b := range r.buckets { b.Reset() } } func (r *RollingHDRHistogram) rotate() { r.idx = (r.idx + 1) % len(r.buckets) r.buckets[r.idx].Reset() } // Merged gets merged histogram. func (r *RollingHDRHistogram) Merged() (*HDRHistogram, error) { m, err := NewHDRHistogram(r.low, r.high, r.sigfigs) if err != nil { return m, err } for _, h := range r.buckets { if errMerge := m.Merge(h); errMerge != nil { return nil, errMerge } } return m, nil } func (r *RollingHDRHistogram) getHist() *HDRHistogram { if clock.Now().UTC().Sub(r.lastRoll) >= r.period { r.rotate() r.lastRoll = clock.Now().UTC() } return r.buckets[r.idx] } // RecordLatencies sets records latencies. func (r *RollingHDRHistogram) RecordLatencies(v time.Duration, n int64) error { return r.getHist().RecordLatencies(v, n) } // RecordValues sets record values. 
func (r *RollingHDRHistogram) RecordValues(v, n int64) error { return r.getHist().RecordValues(v, n) } oxy-2.0.0/memmetrics/histogram_test.go000066400000000000000000000075521450475076400201050ustar00rootroot00000000000000package memmetrics import ( "testing" "github.com/HdrHistogram/hdrhistogram-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/testutils" ) func TestMerge(t *testing.T) { a, err := NewHDRHistogram(1, 3600000, 2) require.NoError(t, err) err = a.RecordValues(1, 2) require.NoError(t, err) b, err := NewHDRHistogram(1, 3600000, 2) require.NoError(t, err) require.NoError(t, b.RecordValues(2, 1)) require.NoError(t, a.Merge(b)) assert.EqualValues(t, 1, a.ValueAtQuantile(50)) assert.EqualValues(t, 2, a.ValueAtQuantile(100)) } func TestMergeNil(t *testing.T) { a, err := NewHDRHistogram(1, 3600000, 1) require.NoError(t, err) require.Error(t, a.Merge(nil)) } func TestRotation(t *testing.T) { done := testutils.FreezeTime() defer done() h, err := NewRollingHDRHistogram( 1, // min value 3600000, // max value 3, // significant figures clock.Second, 2, // 2 histograms in a window ) require.NoError(t, err) require.NotNil(t, h) err = h.RecordValues(5, 1) require.NoError(t, err) m, err := h.Merged() require.NoError(t, err) assert.EqualValues(t, 5, m.ValueAtQuantile(100)) clock.Advance(clock.Second) require.NoError(t, h.RecordValues(2, 1)) require.NoError(t, h.RecordValues(1, 1)) m, err = h.Merged() require.NoError(t, err) assert.EqualValues(t, 5, m.ValueAtQuantile(100)) // rotate, this means that the old value would evaporate clock.Advance(clock.Second) require.NoError(t, h.RecordValues(1, 1)) m, err = h.Merged() require.NoError(t, err) assert.EqualValues(t, 2, m.ValueAtQuantile(100)) } func TestReset(t *testing.T) { done := testutils.FreezeTime() defer done() h, err := NewRollingHDRHistogram( 1, // min value 3600000, // max value 3, // significant figures clock.Second, 2, // 2 histograms in a window ) require.NoError(t, err) require.NotNil(t, h) require.NoError(t, h.RecordValues(5, 1)) m, err := h.Merged() require.NoError(t, err) assert.EqualValues(t, 5, m.ValueAtQuantile(100)) clock.Advance(clock.Second) require.NoError(t, h.RecordValues(2, 1)) require.NoError(t, h.RecordValues(1, 1)) m, err = h.Merged() require.NoError(t, err) assert.EqualValues(t, 5, m.ValueAtQuantile(100)) h.Reset() require.NoError(t, h.RecordValues(5, 1)) m, err = h.Merged() require.NoError(t, err) assert.EqualValues(t, 5, m.ValueAtQuantile(100)) clock.Advance(clock.Second) require.NoError(t, h.RecordValues(2, 1)) require.NoError(t, h.RecordValues(1, 1)) m, err = h.Merged() require.NoError(t, err) assert.EqualValues(t, 5, m.ValueAtQuantile(100)) } func TestHDRHistogramExportReturnsNewCopy(t *testing.T) { // Create HDRHistogram instance a := HDRHistogram{ low: 1, high: 2, sigfigs: 3, h: hdrhistogram.New(0, 1, 2), } // Get a copy and modify the original b := a.Export() a.low = 11 a.high = 12 a.sigfigs = 4 a.h = nil // Assert the copy has not been modified assert.EqualValues(t, 1, b.low) assert.EqualValues(t, 2, b.high) assert.Equal(t, 3, b.sigfigs) require.NotNil(t, b.h) } func TestRollingHDRHistogramExportReturnsNewCopy(t *testing.T) { origTime := clock.Now() done := testutils.FreezeTime() defer done() a := RollingHDRHistogram{ idx: 1, lastRoll: origTime, period: 2 * clock.Second, bucketCount: 3, low: 4, high: 5, sigfigs: 1, buckets: []*HDRHistogram{}, } b := a.Export() a.idx = 11 a.lastRoll = clock.Now().Add(1 * 
clock.Minute) a.period = 12 * clock.Second a.bucketCount = 13 a.low = 14 a.high = 15 a.sigfigs = 1 a.buckets = nil assert.Equal(t, 1, b.idx) assert.Equal(t, origTime, b.lastRoll) assert.Equal(t, 2*clock.Second, b.period) assert.Equal(t, 3, b.bucketCount) assert.Equal(t, int64(4), b.low) assert.EqualValues(t, 5, b.high) assert.NotNil(t, b.buckets) } oxy-2.0.0/memmetrics/options.go000066400000000000000000000011171450475076400165330ustar00rootroot00000000000000package memmetrics // RTOption represents an option you can pass to NewRTMetrics. type RTOption func(r *RTMetrics) error // RTCounter set a builder function for Counter. func RTCounter(fn NewCounterFn) RTOption { return func(r *RTMetrics) error { r.newCounter = fn return nil } } // RTHistogram set a builder function for RollingHDRHistogram. func RTHistogram(fn NewRollingHistogramFn) RTOption { return func(r *RTMetrics) error { r.newHist = fn return nil } } // RatioOption represents an option you can pass to NewRatioCounter. type RatioOption func(r *RatioCounter) error oxy-2.0.0/memmetrics/ratio.go000066400000000000000000000043651450475076400161660ustar00rootroot00000000000000package memmetrics import "time" // RatioCounter calculates a ratio of a/a+b over a rolling window of predefined buckets. type RatioCounter struct { a *RollingCounter b *RollingCounter } // NewRatioCounter creates a new RatioCounter. func NewRatioCounter(buckets int, resolution time.Duration, options ...RatioOption) (*RatioCounter, error) { rc := &RatioCounter{} for _, o := range options { if err := o(rc); err != nil { return nil, err } } a, err := NewCounter(buckets, resolution) if err != nil { return nil, err } b, err := NewCounter(buckets, resolution) if err != nil { return nil, err } rc.a = a rc.b = b return rc, nil } // Reset resets the counter. func (r *RatioCounter) Reset() { r.a.Reset() r.b.Reset() } // IsReady returns true if the counter is ready. func (r *RatioCounter) IsReady() bool { return r.a.countedBuckets+r.b.countedBuckets >= len(r.a.values) } // CountA gets count A. func (r *RatioCounter) CountA() int64 { return r.a.Count() } // CountB gets count B. func (r *RatioCounter) CountB() int64 { return r.b.Count() } // Resolution gets resolution. func (r *RatioCounter) Resolution() time.Duration { return r.a.Resolution() } // Buckets gets buckets. func (r *RatioCounter) Buckets() int { return r.a.Buckets() } // WindowSize gets windows size. func (r *RatioCounter) WindowSize() time.Duration { return r.a.WindowSize() } // ProcessedCount gets processed count. func (r *RatioCounter) ProcessedCount() int64 { return r.CountA() + r.CountB() } // Ratio gets ratio. func (r *RatioCounter) Ratio() float64 { a := r.a.Count() b := r.b.Count() // No data, return ok if a+b == 0 { return 0 } return float64(a) / float64(a+b) } // IncA increments counter A. func (r *RatioCounter) IncA(v int) { r.a.Inc(v) } // IncB increments counter B. func (r *RatioCounter) IncB(v int) { r.b.Inc(v) } // TestMeter a test meter. type TestMeter struct { Rate float64 NotReady bool WindowSize time.Duration } // GetWindowSize gets windows size. func (tm *TestMeter) GetWindowSize() time.Duration { return tm.WindowSize } // IsReady returns true if the meter is ready. func (tm *TestMeter) IsReady() bool { return !tm.NotReady } // GetRate gets rate. 
func (tm *TestMeter) GetRate() float64 { return tm.Rate } oxy-2.0.0/memmetrics/ratio_test.go000066400000000000000000000075001450475076400172170ustar00rootroot00000000000000package memmetrics import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/testutils" ) func TestNewRatioCounterInvalidParams(t *testing.T) { done := testutils.FreezeTime() defer done() // Bad buckets count _, err := NewRatioCounter(0, clock.Second) require.Error(t, err) // Too precise resolution _, err = NewRatioCounter(10, clock.Millisecond) require.Error(t, err) } func TestNotReady(t *testing.T) { done := testutils.FreezeTime() defer done() // No data fr, err := NewRatioCounter(10, clock.Second) require.NoError(t, err) assert.Equal(t, false, fr.IsReady()) assert.Equal(t, 0.0, fr.Ratio()) // Not enough data fr, err = NewRatioCounter(10, clock.Second) require.NoError(t, err) fr.CountA() assert.Equal(t, false, fr.IsReady()) } func TestNoB(t *testing.T) { done := testutils.FreezeTime() defer done() fr, err := NewRatioCounter(1, clock.Second) require.NoError(t, err) fr.IncA(1) assert.Equal(t, true, fr.IsReady()) assert.Equal(t, 1.0, fr.Ratio()) } func TestNoA(t *testing.T) { done := testutils.FreezeTime() defer done() fr, err := NewRatioCounter(1, clock.Second) require.NoError(t, err) fr.IncB(1) assert.Equal(t, true, fr.IsReady()) assert.Equal(t, 0.0, fr.Ratio()) } // Make sure that data is properly calculated over several buckets. func TestMultipleBuckets(t *testing.T) { done := testutils.FreezeTime() defer done() fr, err := NewRatioCounter(3, clock.Second) require.NoError(t, err) fr.IncB(1) clock.Advance(clock.Second) fr.IncA(1) clock.Advance(clock.Second) fr.IncA(1) assert.Equal(t, true, fr.IsReady()) assert.Equal(t, float64(2)/float64(3), fr.Ratio()) } // Make sure that data is properly calculated over several buckets // When we overwrite old data when the window is rolling. func TestOverwriteBuckets(t *testing.T) { done := testutils.FreezeTime() defer done() fr, err := NewRatioCounter(3, clock.Second) require.NoError(t, err) fr.IncB(1) clock.Advance(clock.Second) fr.IncA(1) clock.Advance(clock.Second) fr.IncA(1) // This time we should overwrite the old data points clock.Advance(clock.Second) fr.IncA(1) fr.IncB(2) assert.Equal(t, true, fr.IsReady()) assert.Equal(t, float64(3)/float64(5), fr.Ratio()) } // Make sure we cleanup the data after periods of inactivity // So it does not mess up the stats. 
func TestInactiveBuckets(t *testing.T) { done := testutils.FreezeTime() defer done() fr, err := NewRatioCounter(3, clock.Second) require.NoError(t, err) fr.IncB(1) clock.Advance(clock.Second) fr.IncA(1) clock.Advance(clock.Second) fr.IncA(1) // This time we should overwrite the old data points with new data clock.Advance(clock.Second) fr.IncA(1) fr.IncB(2) // Jump to the last bucket and change the data clock.Advance(clock.Second * 2) fr.IncB(1) assert.Equal(t, true, fr.IsReady()) assert.Equal(t, float64(1)/float64(4), fr.Ratio()) } func TestLongPeriodsOfInactivity(t *testing.T) { done := testutils.FreezeTime() defer done() fr, err := NewRatioCounter(2, clock.Second) require.NoError(t, err) fr.IncB(1) clock.Advance(clock.Second) fr.IncA(1) assert.Equal(t, true, fr.IsReady()) assert.Equal(t, 0.5, fr.Ratio()) // This time we should overwrite all data points clock.Advance(100 * clock.Second) fr.IncA(1) assert.Equal(t, 1.0, fr.Ratio()) } func TestNewRatioCounterReset(t *testing.T) { done := testutils.FreezeTime() defer done() fr, err := NewRatioCounter(1, clock.Second) require.NoError(t, err) fr.IncB(1) fr.IncA(1) assert.Equal(t, true, fr.IsReady()) assert.Equal(t, 0.5, fr.Ratio()) // Reset the counter fr.Reset() assert.Equal(t, false, fr.IsReady()) // Now add some stats fr.IncA(2) // We are game again! assert.Equal(t, true, fr.IsReady()) assert.Equal(t, 1.0, fr.Ratio()) } oxy-2.0.0/memmetrics/roundtrip.go000066400000000000000000000153521450475076400170740ustar00rootroot00000000000000package memmetrics import ( "errors" "net/http" "sync" "time" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" ) // NewRTMetricsFn builder function type. type NewRTMetricsFn func() (*RTMetrics, error) // NewCounterFn builder function type. type NewCounterFn func() (*RollingCounter, error) // NewRollingHistogramFn builder function type. type NewRollingHistogramFn func() (*RollingHDRHistogram, error) // RTMetrics provides aggregated performance metrics for HTTP requests processing // such as round trip latency, response codes counters network error and total requests. // all counters are collected as rolling window counters with defined precision, histograms // are a rolling window histograms with defined precision as well. // See RTOptions for more detail on parameters. type RTMetrics struct { total *RollingCounter netErrors *RollingCounter statusCodes map[int]*RollingCounter statusCodesLock sync.RWMutex histogram *RollingHDRHistogram histogramLock sync.RWMutex newCounter NewCounterFn newHist NewRollingHistogramFn } // NewRTMetrics returns new instance of metrics collector. func NewRTMetrics(settings ...RTOption) (*RTMetrics, error) { m := &RTMetrics{ statusCodes: make(map[int]*RollingCounter), statusCodesLock: sync.RWMutex{}, } for _, s := range settings { if err := s(m); err != nil { return nil, err } } if m.newCounter == nil { m.newCounter = func() (*RollingCounter, error) { return NewCounter(counterBuckets, counterResolution) } } if m.newHist == nil { m.newHist = func() (*RollingHDRHistogram, error) { return NewRollingHDRHistogram(histMin, histMax, histSignificantFigures, histPeriod, histBuckets) } } h, err := m.newHist() if err != nil { return nil, err } netErrors, err := m.newCounter() if err != nil { return nil, err } total, err := m.newCounter() if err != nil { return nil, err } m.histogram = h m.netErrors = netErrors m.total = total return m, nil } // Export Returns a new RTMetrics which is a copy of the current one. 
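// The copy is built under read locks by cloning every rolling counter and
// exporting the histogram, so it can be inspected without coordinating with
// writers of the original instance.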
func (m *RTMetrics) Export() *RTMetrics { m.statusCodesLock.RLock() defer m.statusCodesLock.RUnlock() m.histogramLock.RLock() defer m.histogramLock.RUnlock() export := &RTMetrics{} export.statusCodesLock = sync.RWMutex{} export.histogramLock = sync.RWMutex{} export.total = m.total.Clone() export.netErrors = m.netErrors.Clone() exportStatusCodes := map[int]*RollingCounter{} for code, rollingCounter := range m.statusCodes { exportStatusCodes[code] = rollingCounter.Clone() } export.statusCodes = exportStatusCodes if m.histogram != nil { export.histogram = m.histogram.Export() } export.newCounter = m.newCounter export.newHist = m.newHist return export } // CounterWindowSize gets the total window size. func (m *RTMetrics) CounterWindowSize() time.Duration { return m.total.WindowSize() } // NetworkErrorRatio calculates the amount of network errors such as timeouts and dropped connections // that occurred in the given time window compared to the total requests count. func (m *RTMetrics) NetworkErrorRatio() float64 { if m.total.Count() == 0 { return 0 } return float64(m.netErrors.Count()) / float64(m.total.Count()) } // ResponseCodeRatio calculates ratio of count(startA to endA) / count(startB to endB). func (m *RTMetrics) ResponseCodeRatio(startA, endA, startB, endB int) float64 { a := int64(0) b := int64(0) m.statusCodesLock.RLock() defer m.statusCodesLock.RUnlock() for code, v := range m.statusCodes { if code < endA && code >= startA { a += v.Count() } if code < endB && code >= startB { b += v.Count() } } if b != 0 { return float64(a) / float64(b) } return 0 } // Append appends data from another metrics instance to this one. func (m *RTMetrics) Append(other *RTMetrics) error { if m == other { return errors.New("RTMetrics cannot append to self") } if err := m.total.Append(other.total); err != nil { return err } if err := m.netErrors.Append(other.netErrors); err != nil { return err } copied := other.Export() m.statusCodesLock.Lock() defer m.statusCodesLock.Unlock() m.histogramLock.Lock() defer m.histogramLock.Unlock() for code, c := range copied.statusCodes { o, ok := m.statusCodes[code] if ok { if err := o.Append(c); err != nil { return err } } else { m.statusCodes[code] = c.Clone() } } return m.histogram.Append(copied.histogram) } // Record records a metric. func (m *RTMetrics) Record(code int, duration time.Duration) { m.total.Inc(1) if code == http.StatusGatewayTimeout || code == http.StatusBadGateway { m.netErrors.Inc(1) } _ = m.recordStatusCode(code) _ = m.recordLatency(duration) } // TotalCount returns total count of processed requests collected. func (m *RTMetrics) TotalCount() int64 { return m.total.Count() } // NetworkErrorCount returns the total count of network errors observed. func (m *RTMetrics) NetworkErrorCount() int64 { return m.netErrors.Count() } // StatusCodesCounts returns a map with counts of the response codes. func (m *RTMetrics) StatusCodesCounts() map[int]int64 { sc := make(map[int]int64) m.statusCodesLock.RLock() defer m.statusCodesLock.RUnlock() for k, v := range m.statusCodes { if v.Count() != 0 { sc[k] = v.Count() } } return sc } // LatencyHistogram computes and returns the resulting histogram with the latencies observed. func (m *RTMetrics) LatencyHistogram() (*HDRHistogram, error) { m.histogramLock.Lock() defer m.histogramLock.Unlock() return m.histogram.Merged() } // Reset resets the metrics.
func (m *RTMetrics) Reset() { m.statusCodesLock.Lock() defer m.statusCodesLock.Unlock() m.histogramLock.Lock() defer m.histogramLock.Unlock() m.histogram.Reset() m.total.Reset() m.netErrors.Reset() m.statusCodes = make(map[int]*RollingCounter) } func (m *RTMetrics) recordLatency(d time.Duration) error { m.histogramLock.Lock() defer m.histogramLock.Unlock() return m.histogram.RecordLatencies(d, 1) } func (m *RTMetrics) recordStatusCode(statusCode int) error { m.statusCodesLock.Lock() if c, ok := m.statusCodes[statusCode]; ok { c.Inc(1) m.statusCodesLock.Unlock() return nil } m.statusCodesLock.Unlock() m.statusCodesLock.Lock() defer m.statusCodesLock.Unlock() // Check if another goroutine has written our counter already if c, ok := m.statusCodes[statusCode]; ok { c.Inc(1) return nil } c, err := m.newCounter() if err != nil { return err } c.Inc(1) m.statusCodes[statusCode] = c return nil } const ( counterBuckets = 10 counterResolution = clock.Second histMin = 1 histMax = 3600000000 // 1 hour in microseconds histSignificantFigures = 2 // significant figures (1% precision) histBuckets = 6 // number of sub-histograms in a rolling histogram histPeriod = 10 * clock.Second // roll time ) oxy-2.0.0/memmetrics/roundtrip_test.go000066400000000000000000000073421450475076400201330ustar00rootroot00000000000000package memmetrics import ( "runtime" "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/testutils" ) func TestDefaults(t *testing.T) { done := testutils.FreezeTime() defer done() rr, err := NewRTMetrics() require.NoError(t, err) require.NotNil(t, rr) rr.Record(200, clock.Second) rr.Record(502, 2*clock.Second) rr.Record(200, clock.Second) rr.Record(200, clock.Second) assert.EqualValues(t, 1, rr.NetworkErrorCount()) assert.EqualValues(t, 4, rr.TotalCount()) assert.Equal(t, map[int]int64{502: 1, 200: 3}, rr.StatusCodesCounts()) assert.Equal(t, float64(1)/float64(4), rr.NetworkErrorRatio()) assert.Equal(t, 1.0/3.0, rr.ResponseCodeRatio(500, 503, 200, 300)) h, err := rr.LatencyHistogram() require.NoError(t, err) assert.Equal(t, 2, int(h.LatencyAtQuantile(100)/clock.Second)) rr.Reset() assert.EqualValues(t, 0, rr.NetworkErrorCount()) assert.EqualValues(t, 0, rr.TotalCount()) assert.Equal(t, map[int]int64{}, rr.StatusCodesCounts()) assert.Equal(t, float64(0), rr.NetworkErrorRatio()) assert.Equal(t, float64(0), rr.ResponseCodeRatio(500, 503, 200, 300)) h, err = rr.LatencyHistogram() require.NoError(t, err) assert.Equal(t, time.Duration(0), h.LatencyAtQuantile(100)) } func TestAppend(t *testing.T) { done := testutils.FreezeTime() defer done() rr, err := NewRTMetrics() require.NoError(t, err) require.NotNil(t, rr) rr.Record(200, clock.Second) rr.Record(502, 2*clock.Second) rr.Record(200, clock.Second) rr.Record(200, clock.Second) rr2, err := NewRTMetrics() require.NoError(t, err) require.NotNil(t, rr2) rr2.Record(200, 3*clock.Second) rr2.Record(501, 3*clock.Second) rr2.Record(200, 3*clock.Second) rr2.Record(200, 3*clock.Second) require.NoError(t, rr2.Append(rr)) assert.Equal(t, map[int]int64{501: 1, 502: 1, 200: 6}, rr2.StatusCodesCounts()) assert.EqualValues(t, 1, rr2.NetworkErrorCount()) h, err := rr2.LatencyHistogram() require.NoError(t, err) assert.EqualValues(t, 3, h.LatencyAtQuantile(100)/clock.Second) } func TestConcurrentRecords(t *testing.T) { // This test asserts a race condition which requires parallelism runtime.GOMAXPROCS(100) rr, err := NewRTMetrics() 
require.NoError(t, err) for code := 0; code < 100; code++ { for numRecords := 0; numRecords < 10; numRecords++ { go func(statusCode int) { _ = rr.recordStatusCode(statusCode) }(code) } } } func TestRTMetricExportReturnsNewCopy(t *testing.T) { a := RTMetrics{ statusCodes: map[int]*RollingCounter{}, statusCodesLock: sync.RWMutex{}, histogram: &RollingHDRHistogram{}, histogramLock: sync.RWMutex{}, } var err error a.total, err = NewCounter(1, clock.Second) require.NoError(t, err) a.netErrors, err = NewCounter(1, clock.Second) require.NoError(t, err) a.newCounter = func() (*RollingCounter, error) { return NewCounter(counterBuckets, counterResolution) } a.newHist = func() (*RollingHDRHistogram, error) { return NewRollingHDRHistogram(histMin, histMax, histSignificantFigures, histPeriod, histBuckets) } b := a.Export() a.total = nil a.netErrors = nil a.statusCodes = nil a.histogram = nil a.newCounter = nil a.newHist = nil assert.NotNil(t, b.total) assert.NotNil(t, b.netErrors) assert.NotNil(t, b.statusCodes) assert.NotNil(t, b.histogram) assert.NotNil(t, b.newCounter) assert.NotNil(t, b.newHist) // a and b should have different locks locksSucceed := make(chan bool) go func() { a.statusCodesLock.Lock() b.statusCodesLock.Lock() a.histogramLock.Lock() b.histogramLock.Lock() locksSucceed <- true }() for { select { case <-locksSucceed: return case <-clock.After(10 * clock.Second): t.FailNow() } } } oxy-2.0.0/ratelimit/000077500000000000000000000000001450475076400143365ustar00rootroot00000000000000oxy-2.0.0/ratelimit/bucket.go000066400000000000000000000105401450475076400161420ustar00rootroot00000000000000package ratelimit import ( "fmt" "time" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" ) // UndefinedDelay default delay. const UndefinedDelay = -1 // rate defines token bucket parameters. type rate struct { period time.Duration average int64 burst int64 } func (r *rate) String() string { return fmt.Sprintf("rate(%v/%v, burst=%v)", r.average, r.period, r.burst) } // tokenBucket implements the token bucket algorithm (http://en.wikipedia.org/wiki/Token_bucket) type tokenBucket struct { // The time period controlled by the bucket in nanoseconds. period time.Duration // The number of nanoseconds that it takes to add one more token to the total // number of available tokens. It effectively caches the value that could // have been otherwise deduced from refillRate. timePerToken time.Duration // The maximum number of tokens that can accumulate in the bucket. burst int64 // The number of tokens available for consumption at the moment. It can // never be larger than the capacity. availableTokens int64 // Tells when availableTokens was last updated. lastRefresh clock.Time // The number of tokens consumed the last time. lastConsumed int64 } // newTokenBucket creates a `tokenBucket` instance for the specified `Rate`. func newTokenBucket(rate *rate) *tokenBucket { period := rate.period if period == 0 { period = clock.Nanosecond } return &tokenBucket{ period: period, timePerToken: time.Duration(int64(period) / rate.average), burst: rate.burst, lastRefresh: clock.Now().UTC(), availableTokens: rate.burst, } } // consume makes an attempt to consume the specified number of tokens from the // bucket.
If there are enough tokens available then `0, nil` is returned; if the // number of tokens to consume is larger than the burst size, then an error is returned // and the delay is not defined; otherwise a non-zero delay is returned that tells // how much time the caller needs to wait until the desired number of tokens // will become available for consumption. func (tb *tokenBucket) consume(tokens int64) (time.Duration, error) { tb.updateAvailableTokens() tb.lastConsumed = 0 if tokens > tb.burst { return UndefinedDelay, fmt.Errorf("requested tokens larger than max tokens") } if tb.availableTokens < tokens { return tb.timeTillAvailable(tokens), nil } tb.availableTokens -= tokens tb.lastConsumed = tokens return 0, nil } // rollback reverts the effect of the most recent consumption. If the most recent // `consume` resulted in an error or a burst overflow, and therefore did not // modify the number of available tokens, then `rollback` won't do that either. // It is safe to call this method multiple times, for the second and all // following calls have no effect. func (tb *tokenBucket) rollback() { tb.availableTokens += tb.lastConsumed tb.lastConsumed = 0 } // update modifies `average` and `burst` fields of the token bucket according // to the provided `Rate`. func (tb *tokenBucket) update(rate *rate) error { if rate.period != tb.period { return fmt.Errorf("period mismatch: %v != %v", tb.period, rate.period) } tb.timePerToken = time.Duration(int64(tb.period) / rate.average) tb.burst = rate.burst if tb.availableTokens > rate.burst { tb.availableTokens = rate.burst } return nil } // timeTillAvailable returns the number of nanoseconds that we need to // wait until the specified number of tokens becomes available for consumption. func (tb *tokenBucket) timeTillAvailable(tokens int64) time.Duration { missingTokens := tokens - tb.availableTokens return time.Duration(missingTokens) * tb.timePerToken } // updateAvailableTokens updates the number of tokens available for consumption. // It is calculated based on the refill rate, the time passed since last refresh, // and is limited by the bucket capacity.
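// A worked example of the refill arithmetic (illustrative only; the figures are
// made up): with average=10 and period=1s, timePerToken is 100ms. If 250ms have
// passed since lastRefresh and 2 tokens were available, the refill adds
// int64(250ms/100ms) = 2 tokens, so 4 become available, and lastRefresh is moved
// forward because at least one whole token was added:
//
//	timePerToken := time.Second / 10                        // 100ms
//	added := int64((250 * time.Millisecond) / timePerToken) // 2
//	available := int64(2) + added                           // 4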
func (tb *tokenBucket) updateAvailableTokens() { now := clock.Now().UTC() timePassed := now.Sub(tb.lastRefresh) if tb.timePerToken == 0 { return } tokens := tb.availableTokens + int64(timePassed/tb.timePerToken) // If we haven't added any tokens that means that not enough time has passed, // in this case do not adjust last refill checkpoint, otherwise it will be // always moving in time in case of frequent requests that exceed the rate if tokens != tb.availableTokens { tb.lastRefresh = now tb.availableTokens = tokens } if tb.availableTokens > tb.burst { tb.availableTokens = tb.burst } } oxy-2.0.0/ratelimit/bucket_test.go000066400000000000000000000241341450475076400172050ustar00rootroot00000000000000package ratelimit import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/testutils" ) func TestConsumeSingleToken(t *testing.T) { done := testutils.FreezeTime() defer done() tb := newTokenBucket(&rate{period: clock.Second, average: 1, burst: 1}) // First request passes delay, err := tb.consume(1) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) // Next request does not pass the same second delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, clock.Second, delay) // Second later, the request passes clock.Advance(clock.Second) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) // Five seconds later, still only one request is allowed // because maxBurst is 1 clock.Advance(5 * clock.Second) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) // The next one is forbidden delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, clock.Second, delay) } func TestFastConsumption(t *testing.T) { done := testutils.FreezeTime() defer done() tb := newTokenBucket(&rate{period: clock.Second, average: 1, burst: 1}) // First request passes delay, err := tb.consume(1) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) // Try 200 ms later clock.Advance(clock.Millisecond * 200) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, clock.Second, delay) // Try 700 ms later clock.Advance(clock.Millisecond * 700) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, clock.Second, delay) // Try 100 ms later, success! 
clock.Advance(clock.Millisecond * 100) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) } func TestConsumeMultipleTokens(t *testing.T) { done := testutils.FreezeTime() defer done() tb := newTokenBucket(&rate{period: clock.Second, average: 3, burst: 5}) delay, err := tb.consume(3) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) delay, err = tb.consume(2) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) delay, err = tb.consume(1) require.NoError(t, err) assert.NotEqual(t, time.Duration(0), delay) } func TestDelayIsCorrect(t *testing.T) { done := testutils.FreezeTime() defer done() tb := newTokenBucket(&rate{period: clock.Second, average: 3, burst: 5}) // Exhaust initial capacity delay, err := tb.consume(5) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) delay, err = tb.consume(3) require.NoError(t, err) assert.NotEqual(t, time.Duration(0), delay) // Now wait provided delay and make sure we can consume now clock.Advance(delay) delay, err = tb.consume(3) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) } // Make sure requests that exceed burst size are not allowed. func TestExceedsBurst(t *testing.T) { done := testutils.FreezeTime() defer done() tb := newTokenBucket(&rate{period: clock.Second, average: 1, burst: 10}) _, err := tb.consume(11) require.Error(t, err) } func TestConsumeBurst(t *testing.T) { done := testutils.FreezeTime() defer done() tb := newTokenBucket(&rate{period: clock.Second, average: 2, burst: 5}) // In two seconds we would have 5 tokens clock.Advance(2 * clock.Second) // Lets consume 5 at once delay, err := tb.consume(5) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) } func TestConsumeEstimate(t *testing.T) { done := testutils.FreezeTime() defer done() tb := newTokenBucket(&rate{period: clock.Second, average: 2, burst: 4}) // Consume all burst at once delay, err := tb.consume(4) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) // Now try to consume it and face delay delay, err = tb.consume(4) require.NoError(t, err) assert.Equal(t, time.Duration(2)*clock.Second, delay) } // If a rate with different period is passed to the `update` method, then an // error is returned but the state of the bucket remains valid and unchanged. func TestUpdateInvalidPeriod(t *testing.T) { done := testutils.FreezeTime() defer done() // Given tb := newTokenBucket(&rate{period: clock.Second, average: 10, burst: 20}) _, err := tb.consume(15) // 5 tokens available require.NoError(t, err) // When err = tb.update(&rate{period: clock.Second + 1, average: 30, burst: 40}) // still 5 tokens available require.Error(t, err) // Then // ...check that rate did not change clock.Advance(500 * clock.Millisecond) delay, err := tb.consume(11) require.NoError(t, err) assert.Equal(t, 100*clock.Millisecond, delay) delay, err = tb.consume(10) require.NoError(t, err) // 0 available assert.Equal(t, time.Duration(0), delay) // ...check that burst did not change clock.Advance(40 * clock.Second) _, err = tb.consume(21) require.Error(t, err) delay, err = tb.consume(20) require.NoError(t, err) // 0 available assert.Equal(t, time.Duration(0), delay) } // If the capacity of the bucket is increased by the update then it takes some // time to fill the bucket with tokens up to the new capacity. 
func TestUpdateBurstIncreased(t *testing.T) { done := testutils.FreezeTime() defer done() // Given tb := newTokenBucket(&rate{period: clock.Second, average: 10, burst: 20}) _, err := tb.consume(15) // 5 tokens available require.NoError(t, err) // When err = tb.update(&rate{period: clock.Second, average: 10, burst: 50}) // still 5 tokens available require.NoError(t, err) // Then delay, err := tb.consume(50) require.NoError(t, err) assert.Equal(t, clock.Second/10*45, delay) } // If the capacity of the bucket is decreased by the update then the number of // available tokens is clipped to the new, smaller capacity. func TestUpdateBurstDecreased(t *testing.T) { done := testutils.FreezeTime() defer done() // Given tb := newTokenBucket(&rate{period: clock.Second, average: 10, burst: 50}) _, err := tb.consume(15) // 35 tokens available require.NoError(t, err) // When err = tb.update(&rate{period: clock.Second, average: 10, burst: 20}) // the number of available tokens reduced to 20. require.NoError(t, err) // Then delay, err := tb.consume(21) require.Error(t, err) assert.Equal(t, time.Duration(-1), delay) } // If rate is updated then it affects the bucket refill speed. func TestUpdateRateChanged(t *testing.T) { done := testutils.FreezeTime() defer done() // Given tb := newTokenBucket(&rate{period: clock.Second, average: 10, burst: 20}) _, err := tb.consume(15) // 5 tokens available require.NoError(t, err) // When err = tb.update(&rate{period: clock.Second, average: 20, burst: 20}) // still 5 tokens available require.NoError(t, err) // Then delay, err := tb.consume(20) require.NoError(t, err) assert.Equal(t, clock.Second/20*15, delay) } // Only the most recent consumption is reverted by `Rollback`. func TestRollback(t *testing.T) { done := testutils.FreezeTime() defer done() // Given tb := newTokenBucket(&rate{period: clock.Second, average: 10, burst: 20}) _, err := tb.consume(8) // 12 tokens available require.NoError(t, err) _, err = tb.consume(7) // 5 tokens available require.NoError(t, err) // When tb.rollback() // 12 tokens available // Then delay, err := tb.consume(12) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, 100*clock.Millisecond, delay) } // It is safe to call `Rollback` several times. The second and all subsequent // calls just do nothing. func TestRollbackSeveralTimes(t *testing.T) { done := testutils.FreezeTime() defer done() // Given tb := newTokenBucket(&rate{period: clock.Second, average: 10, burst: 20}) _, err := tb.consume(8) // 12 tokens available require.NoError(t, err) tb.rollback() // 20 tokens available // When tb.rollback() // still 20 tokens available tb.rollback() // still 20 tokens available tb.rollback() // still 20 tokens available // Then: all 20 tokens can be consumed delay, err := tb.consume(20) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, 100*clock.Millisecond, delay) } // If previous consumption returned a delay due to an attempt to consume more // tokens than are available, then `Rollback` has no effect.
func TestRollbackAfterAvailableExceeded(t *testing.T) { done := testutils.FreezeTime() defer done() // Given tb := newTokenBucket(&rate{period: clock.Second, average: 10, burst: 20}) _, err := tb.consume(8) // 12 tokens available require.NoError(t, err) delay, err := tb.consume(15) // still 12 tokens available require.NoError(t, err) assert.Equal(t, 300*clock.Millisecond, delay) // When tb.rollback() // Previous operation consumed 0 tokens, so rollback has no effect. // Then delay, err = tb.consume(12) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, 100*clock.Millisecond, delay) } // If previous consumption returned a error due to an attempt to consume more // tokens then the bucket's burst size, then `Rollback` has no effect. func TestRollbackAfterError(t *testing.T) { done := testutils.FreezeTime() defer done() // Given tb := newTokenBucket(&rate{period: clock.Second, average: 10, burst: 20}) _, err := tb.consume(8) // 12 tokens available require.NoError(t, err) delay, err := tb.consume(21) // still 12 tokens available require.Error(t, err) assert.Equal(t, time.Duration(-1), delay) // When tb.rollback() // Previous operation consumed 0 tokens, so rollback has no effect. // Then delay, err = tb.consume(12) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, 100*clock.Millisecond, delay) } func TestDivisionByZeroOnPeriod(t *testing.T) { var emptyPeriod int64 tb := newTokenBucket(&rate{period: time.Duration(emptyPeriod), average: 2, burst: 2}) _, err := tb.consume(1) assert.NoError(t, err) err = tb.update(&rate{period: clock.Nanosecond, average: 1, burst: 1}) assert.NoError(t, err) } oxy-2.0.0/ratelimit/bucketset.go000066400000000000000000000061621450475076400166630ustar00rootroot00000000000000package ratelimit import ( "fmt" "sort" "strings" "time" ) // TokenBucketSet represents a set of TokenBucket covering different time periods. type TokenBucketSet struct { buckets map[time.Duration]*tokenBucket maxPeriod time.Duration } // NewTokenBucketSet creates a `TokenBucketSet` from the specified `rates`. func NewTokenBucketSet(rates *RateSet) *TokenBucketSet { tbs := new(TokenBucketSet) // In the majority of cases we will have only one bucket. tbs.buckets = make(map[time.Duration]*tokenBucket, len(rates.m)) for _, rate := range rates.m { newBucket := newTokenBucket(rate) tbs.buckets[rate.period] = newBucket tbs.maxPeriod = maxDuration(tbs.maxPeriod, rate.period) } return tbs } // Update brings the buckets in the set in accordance with the provided `rates`. func (tbs *TokenBucketSet) Update(rates *RateSet) { // Update existing buckets and delete those that have no corresponding spec. for _, bucket := range tbs.buckets { if rate, ok := rates.m[bucket.period]; ok { _ = bucket.update(rate) } else { delete(tbs.buckets, bucket.period) } } // Add missing buckets. for _, rate := range rates.m { if _, ok := tbs.buckets[rate.period]; !ok { newBucket := newTokenBucket(rate) tbs.buckets[rate.period] = newBucket } } // Identify the maximum period in the set tbs.maxPeriod = 0 for _, bucket := range tbs.buckets { tbs.maxPeriod = maxDuration(tbs.maxPeriod, bucket.period) } } // Consume consume tokens. 
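// A minimal usage sketch (not part of the upstream source; the 100 req/s limit
// with a burst of 200 is arbitrary): one token is consumed per request across
// every configured period, and a non-zero delay means the caller should throttle
// the request for that long:
//
//	rates := NewRateSet()
//	_ = rates.Add(time.Second, 100, 200)
//	tbs := NewTokenBucketSet(rates)
//	delay, err := tbs.Consume(1)
//	if err == nil && delay > 0 {
//		// throttle or reject the request
//	}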
func (tbs *TokenBucketSet) Consume(tokens int64) (time.Duration, error) { var maxDelay time.Duration = UndefinedDelay var firstErr error for _, tokenBucket := range tbs.buckets { // We keep calling `Consume` even after a error is returned for one of // buckets because that allows us to simplify the rollback procedure, // that is to just call `Rollback` for all buckets. delay, err := tokenBucket.consume(tokens) if firstErr == nil { if err != nil { firstErr = err } else { maxDelay = maxDuration(maxDelay, delay) } } } // If we could not make ALL buckets consume tokens for whatever reason, // then rollback consumption for all of them. if firstErr != nil || maxDelay > 0 { for _, tokenBucket := range tbs.buckets { tokenBucket.rollback() } } return maxDelay, firstErr } // GetMaxPeriod returns the max period. func (tbs *TokenBucketSet) GetMaxPeriod() time.Duration { return tbs.maxPeriod } // debugState returns string that reflects the current state of all buckets in // this set. It is intended to be used for debugging and testing only. func (tbs *TokenBucketSet) debugState() string { periods := make([]int64, 0, len(tbs.buckets)) for period := range tbs.buckets { periods = append(periods, int64(period)) } sort.Slice(periods, func(i, j int) bool { return periods[i] < periods[j] }) bucketRepr := make([]string, 0, len(tbs.buckets)) for _, period := range periods { bucket := tbs.buckets[time.Duration(period)] bucketRepr = append(bucketRepr, fmt.Sprintf("{%v: %v}", bucket.period, bucket.availableTokens)) } return strings.Join(bucketRepr, ", ") } func maxDuration(x, y time.Duration) time.Duration { if x > y { return x } return y } oxy-2.0.0/ratelimit/bucketset_test.go000066400000000000000000000145631450475076400177260ustar00rootroot00000000000000package ratelimit import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/testutils" ) // A value returned by `MaxPeriod` corresponds to the longest bucket time period. func TestLongestPeriod(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(1*clock.Second, 10, 20)) require.NoError(t, rates.Add(7*clock.Second, 10, 20)) require.NoError(t, rates.Add(5*clock.Second, 11, 21)) done := testutils.FreezeTime() defer done() // When tbs := NewTokenBucketSet(rates) // Then assert.Equal(t, 7*clock.Second, tbs.maxPeriod) } // Successful token consumption updates state of all buckets in the set. func TestConsume(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(1*clock.Second, 10, 20)) require.NoError(t, rates.Add(10*clock.Second, 20, 50)) done := testutils.FreezeTime() defer done() tbs := NewTokenBucketSet(rates) // When delay, err := tbs.Consume(15) require.NoError(t, err) // Then assert.Equal(t, time.Duration(0), delay) assert.Equal(t, "{1s: 5}, {10s: 35}", tbs.debugState()) } // As time goes by all set buckets are refilled with appropriate rates. func TestConsumeRefill(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(10*clock.Second, 10, 20)) require.NoError(t, rates.Add(100*clock.Second, 20, 50)) done := testutils.FreezeTime() defer done() tbs := NewTokenBucketSet(rates) _, err := tbs.Consume(15) require.NoError(t, err) assert.Equal(t, "{10s: 5}, {1m40s: 35}", tbs.debugState()) // When clock.Advance(10 * clock.Second) delay, err := tbs.Consume(0) // Consumes nothing but forces an internal state update. 
require.NoError(t, err) // Then assert.Equal(t, time.Duration(0), delay) assert.Equal(t, "{10s: 15}, {1m40s: 37}", tbs.debugState()) } // If the first bucket in the set does not have enough tokens to allow the desired // consumption then an appropriate delay is returned. func TestConsumeLimitedBy1st(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(10*clock.Second, 10, 10)) require.NoError(t, rates.Add(100*clock.Second, 20, 20)) done := testutils.FreezeTime() defer done() tbs := NewTokenBucketSet(rates) _, err := tbs.Consume(5) require.NoError(t, err) assert.Equal(t, "{10s: 5}, {1m40s: 15}", tbs.debugState()) // When delay, err := tbs.Consume(10) require.NoError(t, err) // Then assert.Equal(t, 5*clock.Second, delay) assert.Equal(t, "{10s: 5}, {1m40s: 15}", tbs.debugState()) } // If the second bucket in the set does not have enough tokens to allow the desired // consumption then an appropriate delay is returned. func TestConsumeLimitedBy2st(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(10*clock.Second, 10, 10)) require.NoError(t, rates.Add(100*clock.Second, 20, 20)) done := testutils.FreezeTime() defer done() tbs := NewTokenBucketSet(rates) _, err := tbs.Consume(10) require.NoError(t, err) clock.Advance(10 * clock.Second) _, err = tbs.Consume(10) require.NoError(t, err) clock.Advance(5 * clock.Second) _, err = tbs.Consume(0) require.NoError(t, err) assert.Equal(t, "{10s: 5}, {1m40s: 3}", tbs.debugState()) // When delay, err := tbs.Consume(10) require.NoError(t, err) // Then assert.Equal(t, 7*(5*clock.Second), delay) assert.Equal(t, "{10s: 5}, {1m40s: 3}", tbs.debugState()) } // An attempt to consume more tokens than the smallest bucket capacity results // in an error. func TestConsumeMoreThenBurst(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(1*clock.Second, 10, 20)) require.NoError(t, rates.Add(10*clock.Second, 50, 100)) done := testutils.FreezeTime() defer done() tbs := NewTokenBucketSet(rates) _, err := tbs.Consume(5) require.NoError(t, err) assert.Equal(t, "{1s: 15}, {10s: 95}", tbs.debugState()) // When _, err = tbs.Consume(21) require.Error(t, err) // Then assert.Equal(t, "{1s: 15}, {10s: 95}", tbs.debugState()) } // Update operation can add buckets. func TestUpdateMore(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(1*clock.Second, 10, 20)) require.NoError(t, rates.Add(10*clock.Second, 20, 50)) require.NoError(t, rates.Add(20*clock.Second, 45, 90)) done := testutils.FreezeTime() defer done() tbs := NewTokenBucketSet(rates) _, err := tbs.Consume(5) require.NoError(t, err) assert.Equal(t, "{1s: 15}, {10s: 45}, {20s: 85}", tbs.debugState()) rates = NewRateSet() require.NoError(t, rates.Add(10*clock.Second, 30, 40)) require.NoError(t, rates.Add(11*clock.Second, 30, 40)) require.NoError(t, rates.Add(12*clock.Second, 30, 40)) require.NoError(t, rates.Add(13*clock.Second, 30, 40)) // When tbs.Update(rates) // Then assert.Equal(t, "{10s: 40}, {11s: 40}, {12s: 40}, {13s: 40}", tbs.debugState()) assert.Equal(t, 13*clock.Second, tbs.maxPeriod) } // Update operation can remove buckets.
func TestUpdateLess(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(1*clock.Second, 10, 20)) require.NoError(t, rates.Add(10*clock.Second, 20, 50)) require.NoError(t, rates.Add(20*clock.Second, 45, 90)) require.NoError(t, rates.Add(30*clock.Second, 50, 100)) done := testutils.FreezeTime() defer done() tbs := NewTokenBucketSet(rates) _, err := tbs.Consume(5) require.NoError(t, err) assert.Equal(t, "{1s: 15}, {10s: 45}, {20s: 85}, {30s: 95}", tbs.debugState()) rates = NewRateSet() require.NoError(t, rates.Add(10*clock.Second, 25, 20)) require.NoError(t, rates.Add(20*clock.Second, 30, 21)) // When tbs.Update(rates) // Then assert.Equal(t, "{10s: 20}, {20s: 21}", tbs.debugState()) assert.Equal(t, 20*clock.Second, tbs.maxPeriod) } // Update operation can replace all buckets at once. func TestUpdateAllDifferent(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(10*clock.Second, 20, 50)) require.NoError(t, rates.Add(30*clock.Second, 50, 100)) done := testutils.FreezeTime() defer done() tbs := NewTokenBucketSet(rates) _, err := tbs.Consume(5) require.NoError(t, err) assert.Equal(t, "{10s: 45}, {30s: 95}", tbs.debugState()) rates = NewRateSet() require.NoError(t, rates.Add(1*clock.Second, 10, 40)) require.NoError(t, rates.Add(60*clock.Second, 100, 150)) // When tbs.Update(rates) // Then assert.Equal(t, "{1s: 40}, {1m0s: 150}", tbs.debugState()) assert.Equal(t, 60*clock.Second, tbs.maxPeriod) } oxy-2.0.0/ratelimit/options.go000066400000000000000000000016741450475076400163670ustar00rootroot00000000000000package ratelimit import ( "fmt" "github.com/vulcand/oxy/v2/utils" ) // TokenLimiterOption token limiter option type. type TokenLimiterOption func(l *TokenLimiter) error // ErrorHandler sets error handler of the server. func ErrorHandler(h utils.ErrorHandler) TokenLimiterOption { return func(cl *TokenLimiter) error { cl.errHandler = h return nil } } // ExtractRates sets the rate extractor. func ExtractRates(e RateExtractor) TokenLimiterOption { return func(cl *TokenLimiter) error { cl.extractRates = e return nil } } // Capacity sets the capacity. func Capacity(capacity int) TokenLimiterOption { return func(cl *TokenLimiter) error { if capacity <= 0 { return fmt.Errorf("bad capacity: %v", capacity) } cl.capacity = capacity return nil } } // Logger defines the logger the TokenLimiter will use. func Logger(l utils.Logger) TokenLimiterOption { return func(tl *TokenLimiter) error { tl.log = l return nil } } oxy-2.0.0/ratelimit/tokenlimiter.go000066400000000000000000000124601450475076400173760ustar00rootroot00000000000000// Package ratelimit Token-bucket-based request rate limiter package ratelimit import ( "fmt" "net/http" "sync" "time" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/internal/holsterv4/collections" "github.com/vulcand/oxy/v2/utils" ) // DefaultCapacity default capacity. const DefaultCapacity = 65536 // RateSet maintains a set of rates. It can contain only one rate per period at a time. type RateSet struct { m map[time.Duration]*rate } // NewRateSet creates an empty `RateSet` instance. func NewRateSet() *RateSet { rs := new(RateSet) rs.m = make(map[time.Duration]*rate) return rs } // Add adds a rate to the set. If there is a rate with the same period in the // set then the new rate overrides the old one.
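// A minimal construction sketch (not part of the upstream source): the RateSet
// built with Add is handed to New as the default limit, with clients keyed by a
// header-based extractor. The handler `next`, the header name and the figures
// are assumptions made for the example:
//
//	rates := NewRateSet()
//	_ = rates.Add(time.Second, 10, 20)
//	_ = rates.Add(time.Minute, 100, 200)
//	extract := utils.ExtractorFunc(func(req *http.Request) (string, int64, error) {
//		return req.Header.Get("X-Client-Id"), 1, nil
//	})
//	limiter, _ := New(next, extract, rates)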
func (rs *RateSet) Add(period time.Duration, average int64, burst int64) error { if period <= 0 { return fmt.Errorf("invalid period: %v", period) } if average <= 0 { return fmt.Errorf("invalid average: %v", average) } if burst <= 0 { return fmt.Errorf("invalid burst: %v", burst) } rs.m[period] = &rate{period: period, average: average, burst: burst} return nil } func (rs *RateSet) String() string { return fmt.Sprint(rs.m) } // RateExtractor rate extractor. type RateExtractor interface { Extract(r *http.Request) (*RateSet, error) } // RateExtractorFunc rate extractor function type. type RateExtractorFunc func(r *http.Request) (*RateSet, error) // Extract extract from request. func (e RateExtractorFunc) Extract(r *http.Request) (*RateSet, error) { return e(r) } // TokenLimiter implements rate limiting middleware. type TokenLimiter struct { defaultRates *RateSet extract utils.SourceExtractor extractRates RateExtractor mutex sync.Mutex bucketSets *collections.TTLMap errHandler utils.ErrorHandler capacity int next http.Handler log utils.Logger } // New constructs a `TokenLimiter` middleware instance. func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet, opts ...TokenLimiterOption) (*TokenLimiter, error) { if defaultRates == nil || len(defaultRates.m) == 0 { return nil, fmt.Errorf("provide default rates") } if extract == nil { return nil, fmt.Errorf("provide extract function") } tl := &TokenLimiter{ next: next, defaultRates: defaultRates, extract: extract, log: &utils.NoopLogger{}, } for _, o := range opts { if err := o(tl); err != nil { return nil, err } } setDefaults(tl) tl.bucketSets = collections.NewTTLMap(tl.capacity) return tl, nil } // Wrap sets the next handler to be called by token limiter handler. func (tl *TokenLimiter) Wrap(next http.Handler) { tl.next = next } func (tl *TokenLimiter) ServeHTTP(w http.ResponseWriter, req *http.Request) { source, amount, err := tl.extract.Extract(req) if err != nil { tl.errHandler.ServeHTTP(w, req, err) return } if err := tl.consumeRates(req, source, amount); err != nil { tl.log.Warn("limiting request %v %v, limit: %v", req.Method, req.URL, err) tl.errHandler.ServeHTTP(w, req, err) return } tl.next.ServeHTTP(w, req) } func (tl *TokenLimiter) consumeRates(req *http.Request, source string, amount int64) error { tl.mutex.Lock() defer tl.mutex.Unlock() effectiveRates := tl.resolveRates(req) bucketSetI, exists := tl.bucketSets.Get(source) var bucketSet *TokenBucketSet if exists { bucketSet = bucketSetI.(*TokenBucketSet) bucketSet.Update(effectiveRates) } else { bucketSet = NewTokenBucketSet(effectiveRates) // We set ttl as 10 times rate period. E.g. if rate is 100 requests/second per client ip // the counters for this ip will expire after 10 seconds of inactivity err := tl.bucketSets.Set(source, bucketSet, int(bucketSet.maxPeriod/clock.Second)*10+1) if err != nil { return err } } delay, err := bucketSet.Consume(amount) if err != nil { return err } if delay > 0 { return &MaxRateError{Delay: delay} } return nil } // effectiveRates retrieves rates to be applied to the request. func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet { // If configuration mapper is not specified for this instance, then return // the default bucket specs. if tl.extractRates == nil { return tl.defaultRates } rates, err := tl.extractRates.Extract(req) if err != nil { tl.log.Error("Failed to retrieve rates: %v", err) return tl.defaultRates } // If the returned rate set is empty then used the default one. 
if len(rates.m) == 0 { return tl.defaultRates } return rates } // MaxRateError max rate error. type MaxRateError struct { Delay time.Duration } func (m *MaxRateError) Error() string { return fmt.Sprintf("max rate reached: retry-in %v", m.Delay) } // RateErrHandler error handler. type RateErrHandler struct{} func (e *RateErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { //nolint:errorlint // must be changed if rerr, ok := err.(*MaxRateError); ok { w.Header().Set("Retry-After", fmt.Sprintf("%.0f", rerr.Delay.Seconds())) w.Header().Set("X-Retry-In", rerr.Delay.String()) w.WriteHeader(http.StatusTooManyRequests) _, _ = w.Write([]byte(err.Error())) return } utils.DefaultHandler.ServeHTTP(w, req, err) } var defaultErrHandler = &RateErrHandler{} func setDefaults(tl *TokenLimiter) { if tl.capacity <= 0 { tl.capacity = DefaultCapacity } if tl.errHandler == nil { tl.errHandler = defaultErrHandler } } oxy-2.0.0/ratelimit/tokenlimiter_test.go000066400000000000000000000231771450475076400204440ustar00rootroot00000000000000package ratelimit import ( "fmt" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/testutils" "github.com/vulcand/oxy/v2/utils" ) func TestRateSetAdd(t *testing.T) { rs := NewRateSet() // Invalid period err := rs.Add(0, 1, 1) require.Error(t, err) // Invalid Average err = rs.Add(clock.Second, 0, 1) require.Error(t, err) // Invalid Burst err = rs.Add(clock.Second, 1, 0) require.Error(t, err) err = rs.Add(clock.Second, 1, 1) require.NoError(t, err) assert.Equal(t, rs.String(), "map[1s:rate(1/1s, burst=1)]") } // We've hit the limit and were able to proceed on the next time run. func TestHitLimit(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) rates := NewRateSet() err := rates.Add(clock.Second, 1, 1) require.NoError(t, err) done := testutils.FreezeTime() defer done() l, err := New(handler, headerLimit, rates) require.NoError(t, err) srv := httptest.NewServer(l) t.Cleanup(srv.Close) re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) // Next request from the same source hits rate limit re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, 429, re.StatusCode) // Second later, the request from this ip will succeed clock.Advance(clock.Second) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // We've failed to extract client ip. func TestFailure(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) rates := NewRateSet() err := rates.Add(clock.Second, 1, 1) require.NoError(t, err) done := testutils.FreezeTime() defer done() l, err := New(handler, faultyExtract, rates) require.NoError(t, err) srv := httptest.NewServer(l) t.Cleanup(srv.Close) re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, re.StatusCode) } // Make sure rates from different ips are controlled separately. 
func TestIsolation(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) rates := NewRateSet() err := rates.Add(clock.Second, 1, 1) require.NoError(t, err) done := testutils.FreezeTime() defer done() l, err := New(handler, headerLimit, rates) require.NoError(t, err) srv := httptest.NewServer(l) t.Cleanup(srv.Close) re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) // Next request from the same source hits rate limit re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, 429, re.StatusCode) // The request from other source can proceed re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "b")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // Make sure that expiration works (Expiration is triggered after significant amount of time passes). func TestExpiration(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) rates := NewRateSet() err := rates.Add(clock.Second, 1, 1) require.NoError(t, err) done := testutils.FreezeTime() defer done() l, err := New(handler, headerLimit, rates) require.NoError(t, err) srv := httptest.NewServer(l) t.Cleanup(srv.Close) re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) // Next request from the same source hits rate limit re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, 429, re.StatusCode) // 24 hours later, the request from this ip will succeed clock.Advance(24 * clock.Hour) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // If rate limiting configuration is valid, then it is applied. func TestExtractRates(t *testing.T) { // Given extractRates := func(*http.Request) (*RateSet, error) { rates := NewRateSet() err := rates.Add(clock.Second, 2, 2) if err != nil { return nil, err } err = rates.Add(60*clock.Second, 10, 10) if err != nil { return nil, err } return rates, nil } rates := NewRateSet() err := rates.Add(clock.Second, 1, 1) require.NoError(t, err) handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) done := testutils.FreezeTime() defer done() tl, err := New(handler, headerLimit, rates, ExtractRates(RateExtractorFunc(extractRates))) require.NoError(t, err) srv := httptest.NewServer(tl) t.Cleanup(srv.Close) // When/Then: The configured rate is applied, which 2 req/second re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, 429, re.StatusCode) clock.Advance(clock.Second) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // If configMapper returns error, then the default rate is applied. 
func TestBadRateExtractor(t *testing.T) { // Given extractor := func(*http.Request) (*RateSet, error) { return nil, fmt.Errorf("boom") } rates := NewRateSet() err := rates.Add(clock.Second, 1, 1) require.NoError(t, err) handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) done := testutils.FreezeTime() defer done() l, err := New(handler, headerLimit, rates, ExtractRates(RateExtractorFunc(extractor))) require.NoError(t, err) srv := httptest.NewServer(l) t.Cleanup(srv.Close) // When/Then: The default rate is applied, which 1 req/second re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, 429, re.StatusCode) clock.Advance(clock.Second) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // If configMapper returns empty rates, then the default rate is applied. func TestExtractorEmpty(t *testing.T) { // Given extractor := func(*http.Request) (*RateSet, error) { return NewRateSet(), nil } rates := NewRateSet() err := rates.Add(clock.Second, 1, 1) require.NoError(t, err) handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) done := testutils.FreezeTime() defer done() l, err := New(handler, headerLimit, rates, ExtractRates(RateExtractorFunc(extractor))) require.NoError(t, err) srv := httptest.NewServer(l) t.Cleanup(srv.Close) // When/Then: The default rate is applied, which 1 req/second re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, 429, re.StatusCode) clock.Advance(clock.Second) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } func TestInvalidParams(t *testing.T) { // Rates are missing rs := NewRateSet() err := rs.Add(clock.Second, 1, 1) require.NoError(t, err) // Empty _, err = New(nil, nil, rs) require.Error(t, err) // Rates are empty _, err = New(nil, nil, NewRateSet()) require.Error(t, err) // Bad capacity _, err = New(nil, headerLimit, rs, Capacity(-1)) require.Error(t, err) } // We've hit the limit and were able to proceed on the next time run. 
func TestOptions(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) rates := NewRateSet() err := rates.Add(clock.Second, 1, 1) require.NoError(t, err) errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, req *http.Request, err error) { w.WriteHeader(http.StatusTeapot) _, _ = w.Write([]byte(http.StatusText(http.StatusTeapot))) }) done := testutils.FreezeTime() defer done() l, err := New(handler, headerLimit, rates, ErrorHandler(errHandler)) require.NoError(t, err) srv := httptest.NewServer(l) t.Cleanup(srv.Close) re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusTeapot, re.StatusCode) } func headerLimiter(req *http.Request) (string, int64, error) { return req.Header.Get("Source"), 1, nil } func faultyExtractor(_ *http.Request) (string, int64, error) { return "", -1, fmt.Errorf("oops") } var headerLimit = utils.ExtractorFunc(headerLimiter) var faultyExtract = utils.ExtractorFunc(faultyExtractor) oxy-2.0.0/roundrobin/000077500000000000000000000000001450475076400145255ustar00rootroot00000000000000oxy-2.0.0/roundrobin/RequestRewriteListener.go000066400000000000000000000002531450475076400215540ustar00rootroot00000000000000package roundrobin import "net/http" // RequestRewriteListener function to rewrite request. type RequestRewriteListener func(oldReq *http.Request, newReq *http.Request) oxy-2.0.0/roundrobin/options.go000066400000000000000000000060071450475076400165520ustar00rootroot00000000000000package roundrobin import ( "fmt" "time" "github.com/vulcand/oxy/v2/utils" ) // RebalancerOption represents an option you can pass to NewRebalancer. type RebalancerOption func(*Rebalancer) error // RebalancerBackoff sets a beck off duration. func RebalancerBackoff(d time.Duration) RebalancerOption { return func(r *Rebalancer) error { r.backoffDuration = d return nil } } // RebalancerMeter sets a Meter builder function. func RebalancerMeter(newMeter NewMeterFn) RebalancerOption { return func(r *Rebalancer) error { r.newMeter = newMeter return nil } } // RebalancerErrorHandler is a functional argument that sets error handler of the server. func RebalancerErrorHandler(h utils.ErrorHandler) RebalancerOption { return func(r *Rebalancer) error { r.errHandler = h return nil } } // RebalancerStickySession sets a sticky session. func RebalancerStickySession(stickySession *StickySession) RebalancerOption { return func(r *Rebalancer) error { r.stickySession = stickySession return nil } } // RebalancerRequestRewriteListener is a functional argument that sets error handler of the server. func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOption { return func(r *Rebalancer) error { r.requestRewriteListener = rrl return nil } } // RebalancerLogger defines the logger used by Rebalancer. func RebalancerLogger(l utils.Logger) RebalancerOption { return func(rb *Rebalancer) error { rb.log = l return nil } } // RebalancerDebug additional debug information. func RebalancerDebug(debug bool) RebalancerOption { return func(rb *Rebalancer) error { rb.debug = debug return nil } } // ServerOption provides various options for server, e.g. weight. type ServerOption func(*server) error // Weight is an optional functional argument that sets weight of the server. 
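// For example (illustrative only; lb and u are assumed to be an existing
// BalancerHandler and a parsed backend *url.URL):
//
//	_ = lb.UpsertServer(u, Weight(3))
//
// Servers with a higher weight are intended to receive proportionally more
// requests than their peers.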
func Weight(w int) ServerOption { return func(s *server) error { if w < 0 { return fmt.Errorf("Weight should be >= 0") } s.weight = w return nil } } // LBOption provides options for load balancer. type LBOption func(*RoundRobin) error // ErrorHandler is a functional argument that sets error handler of the server. func ErrorHandler(h utils.ErrorHandler) LBOption { return func(s *RoundRobin) error { s.errHandler = h return nil } } // EnableStickySession enable sticky session. func EnableStickySession(stickySession *StickySession) LBOption { return func(s *RoundRobin) error { s.stickySession = stickySession return nil } } // RoundRobinRequestRewriteListener is a functional argument that sets error handler of the server. func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption { return func(s *RoundRobin) error { s.requestRewriteListener = rrl return nil } } // Logger defines the logger the RoundRobin will use. func Logger(l utils.Logger) LBOption { return func(r *RoundRobin) error { r.log = l return nil } } // Verbose additional debug information. func Verbose(verbose bool) LBOption { return func(r *RoundRobin) error { r.verbose = verbose return nil } } oxy-2.0.0/roundrobin/rebalancer.go000066400000000000000000000247441450475076400171650ustar00rootroot00000000000000package roundrobin import ( "fmt" "net/http" "net/url" "sync" "time" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/memmetrics" "github.com/vulcand/oxy/v2/utils" ) const ( // FSMMaxWeight is the maximum weight that handler will set for the server. FSMMaxWeight = 4096 // FSMGrowFactor Multiplier for the server weight. FSMGrowFactor = 4 ) // splitThreshold tells how far the value should go from the median + median absolute deviation before it is considered an outlier. const splitThreshold = 1.5 // BalancerHandler the balancer/handler interface. type BalancerHandler interface { Servers() []*url.URL ServeHTTP(w http.ResponseWriter, req *http.Request) ServerWeight(u *url.URL) (int, bool) RemoveServer(u *url.URL) error UpsertServer(u *url.URL, options ...ServerOption) error NextServer() (*url.URL, error) Next() http.Handler } // Meter measures server performance and returns its relative value via rating. type Meter interface { Rating() float64 Record(int, time.Duration) IsReady() bool } // NewMeterFn type of functions to create new Meter. type NewMeterFn func() (Meter, error) // Rebalancer increases weights on servers that perform better than others. It also rolls back to original weights // if the servers have changed. It is designed as a wrapper on top of the round-robin. type Rebalancer struct { // mutex mtx *sync.Mutex // Time that freezes state machine to accumulate stats after updating the weights backoffDuration time.Duration // Timer is set to give probing some time to take place timer clock.Time // server records that remember original weights servers []*rbServer // next is internal load balancer next in chain next BalancerHandler // errHandler is HTTP handler called in case of errors errHandler utils.ErrorHandler ratings []float64 // creates new meters newMeter NewMeterFn // sticky session object stickySession *StickySession requestRewriteListener RequestRewriteListener debug bool log utils.Logger } // NewRebalancer creates a new Rebalancer. 
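// A minimal usage sketch (not part of the upstream source): wrap an existing
// BalancerHandler (referred to as lb below) so that backend weights adapt to the
// observed response codes; the 30-second back-off and the single backend URL are
// arbitrary choices for the example:
//
//	rb, err := NewRebalancer(lb, RebalancerBackoff(30*time.Second))
//	if err != nil {
//		// handle the configuration error
//	}
//	_ = rb.UpsertServer(backendURL, Weight(1))
//	// rb is an http.Handler; serve requests through it instead of lb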
func NewRebalancer(handler BalancerHandler, opts ...RebalancerOption) (*Rebalancer, error) { rb := &Rebalancer{ mtx: &sync.Mutex{}, next: handler, stickySession: nil, log: &utils.NoopLogger{}, } for _, o := range opts { if err := o(rb); err != nil { return nil, err } } if rb.backoffDuration == 0 { rb.backoffDuration = 10 * clock.Second } if rb.newMeter == nil { rb.newMeter = func() (Meter, error) { rc, err := memmetrics.NewRatioCounter(10, clock.Second) if err != nil { return nil, err } return &codeMeter{ r: rc, codeS: http.StatusInternalServerError, codeE: http.StatusGatewayTimeout + 1, }, nil } } if rb.errHandler == nil { rb.errHandler = utils.DefaultHandler } return rb, nil } // Servers gets all servers. func (rb *Rebalancer) Servers() []*url.URL { rb.mtx.Lock() defer rb.mtx.Unlock() return rb.next.Servers() } func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { if rb.debug { dump := utils.DumpHTTPRequest(req) rb.log.Debug("vulcand/oxy/roundrobin/rebalancer: begin ServeHttp on request: %s", dump) defer rb.log.Debug("vulcand/oxy/roundrobin/rebalancer: completed ServeHttp on request: %s", dump) } pw := utils.NewProxyWriter(w) start := clock.Now().UTC() // make shallow copy of request before changing anything to avoid side effects newReq := *req stuck := false if rb.stickySession != nil { cookieURL, present, err := rb.stickySession.GetBackend(&newReq, rb.Servers()) if err != nil { rb.log.Warn("vulcand/oxy/roundrobin/rebalancer: error using server from cookie: %v", err) } if present { newReq.URL = cookieURL stuck = true } } if !stuck { fwdURL, err := rb.next.NextServer() if err != nil { rb.errHandler.ServeHTTP(w, req, err) return } if rb.debug { // log which backend URL we're sending this request to rb.log.Debug("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL (%s) :%s", fwdURL, utils.DumpHTTPRequest(req)) } if rb.stickySession != nil { rb.stickySession.StickBackend(fwdURL, w) } newReq.URL = fwdURL } // Emit event to a listener if one exists if rb.requestRewriteListener != nil { rb.requestRewriteListener(req, &newReq) } rb.next.Next().ServeHTTP(pw, &newReq) rb.recordMetrics(newReq.URL, pw.StatusCode(), clock.Now().UTC().Sub(start)) rb.adjustWeights() } func (rb *Rebalancer) recordMetrics(u *url.URL, code int, latency time.Duration) { rb.mtx.Lock() defer rb.mtx.Unlock() if srv, i := rb.findServer(u); i != -1 { srv.meter.Record(code, latency) } } func (rb *Rebalancer) reset() { for _, s := range rb.servers { s.curWeight = s.origWeight _ = rb.next.UpsertServer(s.url, Weight(s.origWeight)) } rb.timer = clock.Now().UTC().Add(-1 * clock.Second) rb.ratings = make([]float64, len(rb.servers)) } // Wrap sets the next handler to be called by rebalancer handler. func (rb *Rebalancer) Wrap(next BalancerHandler) error { if rb.next != nil { return fmt.Errorf("already bound to %T", rb.next) } rb.next = next return nil } // UpsertServer upsert a server. func (rb *Rebalancer) UpsertServer(u *url.URL, options ...ServerOption) error { rb.mtx.Lock() defer rb.mtx.Unlock() if err := rb.next.UpsertServer(u, options...); err != nil { return err } weight, _ := rb.next.ServerWeight(u) if err := rb.upsertServer(u, weight); err != nil { _ = rb.next.RemoveServer(u) return err } rb.reset() return nil } // RemoveServer remove a server. 
func (rb *Rebalancer) RemoveServer(u *url.URL) error { rb.mtx.Lock() defer rb.mtx.Unlock() return rb.removeServer(u) } func (rb *Rebalancer) removeServer(u *url.URL) error { _, i := rb.findServer(u) if i == -1 { return fmt.Errorf("%v not found", u) } if err := rb.next.RemoveServer(u); err != nil { return err } rb.servers = append(rb.servers[:i], rb.servers[i+1:]...) rb.reset() return nil } func (rb *Rebalancer) upsertServer(u *url.URL, weight int) error { if s, i := rb.findServer(u); i != -1 { s.origWeight = weight } meter, err := rb.newMeter() if err != nil { return err } rbSrv := &rbServer{ url: utils.CopyURL(u), origWeight: weight, curWeight: weight, meter: meter, } rb.servers = append(rb.servers, rbSrv) return nil } func (rb *Rebalancer) findServer(u *url.URL) (*rbServer, int) { if len(rb.servers) == 0 { return nil, -1 } for i, s := range rb.servers { if sameURL(u, s.url) { return s, i } } return nil, -1 } // adjustWeights Called on every load balancer ServeHTTP call, returns the suggested weights // on every call, can adjust weights if needed. func (rb *Rebalancer) adjustWeights() { rb.mtx.Lock() defer rb.mtx.Unlock() // In this case adjusting weights would have no effect, so do nothing if len(rb.servers) < 2 { return } // Metrics are not ready if !rb.metricsReady() { return } if !rb.timerExpired() { return } if rb.markServers() { if rb.setMarkedWeights() { rb.setTimer() } } else { // No servers that are different by their quality, so converge weights if rb.convergeWeights() { rb.setTimer() } } } func (rb *Rebalancer) applyWeights() { for _, srv := range rb.servers { rb.log.Debug("upsert server %v, weight %v", srv.url, srv.curWeight) _ = rb.next.UpsertServer(srv.url, Weight(srv.curWeight)) } } func (rb *Rebalancer) setMarkedWeights() bool { changed := false // Increase weights on servers marked as good for _, srv := range rb.servers { if srv.good { weight := increase(srv.curWeight) if weight <= FSMMaxWeight { rb.log.Debug("increasing weight of %v from %v to %v", srv.url, srv.curWeight, weight) srv.curWeight = weight changed = true } } } if changed { rb.normalizeWeights() rb.applyWeights() return true } return false } func (rb *Rebalancer) setTimer() { rb.timer = clock.Now().UTC().Add(rb.backoffDuration) } func (rb *Rebalancer) timerExpired() bool { return rb.timer.Before(clock.Now().UTC()) } func (rb *Rebalancer) metricsReady() bool { for _, s := range rb.servers { if !s.meter.IsReady() { return false } } return true } // markServers splits servers into two groups of servers with bad and good failure rate. // It does compare relative performances of the servers though, so if all servers have approximately the same error rate // this function returns the result as if all servers are equally good. 
func (rb *Rebalancer) markServers() bool { for i, srv := range rb.servers { rb.ratings[i] = srv.meter.Rating() } g, b := memmetrics.SplitFloat64(splitThreshold, 0, rb.ratings) for i, srv := range rb.servers { if g[rb.ratings[i]] { srv.good = true } else { srv.good = false } } if len(g) != 0 && len(b) != 0 { rb.log.Debug("bad: %v good: %v, ratings: %v", b, g, rb.ratings) } return len(g) != 0 && len(b) != 0 } func (rb *Rebalancer) convergeWeights() bool { // If we have previously changed servers try to restore weights to the original state changed := false for _, s := range rb.servers { if s.origWeight == s.curWeight { continue } changed = true newWeight := decrease(s.origWeight, s.curWeight) rb.log.Debug("decreasing weight of %v from %v to %v", s.url, s.curWeight, newWeight) s.curWeight = newWeight } if !changed { return false } rb.normalizeWeights() rb.applyWeights() return true } func (rb *Rebalancer) weightsGcd() int { divisor := -1 for _, w := range rb.servers { if divisor == -1 { divisor = w.curWeight } else { divisor = gcd(divisor, w.curWeight) } } return divisor } func (rb *Rebalancer) normalizeWeights() { gcd := rb.weightsGcd() if gcd <= 1 { return } for _, s := range rb.servers { s.curWeight /= gcd } } func increase(weight int) int { return weight * FSMGrowFactor } func decrease(target, current int) int { adjusted := current / FSMGrowFactor if adjusted < target { return target } return adjusted } // rebalancer server record that keeps track of the original weight supplied by user. type rbServer struct { url *url.URL origWeight int // original weight supplied by user curWeight int // current weight good bool meter Meter } type codeMeter struct { r *memmetrics.RatioCounter codeS int codeE int } // Rating gets ratio. func (n *codeMeter) Rating() float64 { return n.r.Ratio() } // Record records a meter. func (n *codeMeter) Record(code int, _ time.Duration) { if code >= n.codeS && code < n.codeE { n.r.IncA(1) } else { n.r.IncB(1) } } // IsReady returns true if the counter is ready. 
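//
// codeMeter is only the default; any implementation of Meter can be plugged
// in through the RebalancerMeter option. A hedged sketch, where myMeter is a
// made-up type that satisfies Meter:
//
//	rb, _ := roundrobin.NewRebalancer(lb,
//		roundrobin.RebalancerMeter(func() (roundrobin.Meter, error) {
//			return &myMeter{}, nil
//		}))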
func (n *codeMeter) IsReady() bool { return n.r.IsReady() } oxy-2.0.0/roundrobin/rebalancer_test.go000066400000000000000000000263401450475076400202160ustar00rootroot00000000000000package roundrobin import ( "io" "net/http" "net/http/httptest" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/forward" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/testutils" ) func TestRebalancerNormalOperation(t *testing.T) { a, b := testutils.NewResponder("a"), testutils.NewResponder("b") defer a.Close() defer b.Close() fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) rb, err := NewRebalancer(lb) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) assert.Equal(t, a.URL, rb.Servers()[0].String()) proxy := httptest.NewServer(rb) t.Cleanup(proxy.Close) assert.Equal(t, []string{"a", "a", "a"}, seq(t, proxy.URL, 3)) } func TestRebalancerNoServers(t *testing.T) { fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) rb, err := NewRebalancer(lb) require.NoError(t, err) proxy := httptest.NewServer(rb) t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, re.StatusCode) } func TestRebalancerRemoveServer(t *testing.T) { a, b := testutils.NewResponder("a"), testutils.NewResponder("b") defer a.Close() defer b.Close() fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) rb, err := NewRebalancer(lb) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(rb) t.Cleanup(proxy.Close) assert.Equal(t, []string{"a", "b", "a"}, seq(t, proxy.URL, 3)) require.NoError(t, rb.RemoveServer(testutils.ParseURI(a.URL))) assert.Equal(t, []string{"b", "b", "b"}, seq(t, proxy.URL, 3)) } // Test scenario when one server goes down after what it recovers. 
func TestRebalancerRecovery(t *testing.T) { a, b := testutils.NewResponder("a"), testutils.NewResponder("b") defer a.Close() defer b.Close() fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) newMeter := func() (Meter, error) { return &testMeter{}, nil } done := testutils.FreezeTime() defer done() rb, err := NewRebalancer(lb, RebalancerMeter(newMeter)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) rb.servers[0].meter.(*testMeter).rating = 0.3 proxy := httptest.NewServer(rb) t.Cleanup(proxy.Close) for i := 0; i < 6; i++ { _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) clock.Advance(rb.backoffDuration + clock.Second) } assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, FSMMaxWeight, rb.servers[1].curWeight) assert.Equal(t, 1, lb.servers[0].weight) assert.Equal(t, FSMMaxWeight, lb.servers[1].weight) // server a is now recovering, the weights should go back to the original state rb.servers[0].meter.(*testMeter).rating = 0 for i := 0; i < 6; i++ { _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) clock.Advance(rb.backoffDuration + clock.Second) } assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, 1, rb.servers[1].curWeight) // Make sure we have applied the weights to the inner load balancer assert.Equal(t, 1, lb.servers[0].weight) assert.Equal(t, 1, lb.servers[1].weight) } // Test scenario when increaing the weight on good endpoints made it worse. func TestRebalancerCascading(t *testing.T) { a, b, d := testutils.NewResponder("a"), testutils.NewResponder("b"), testutils.NewResponder("d") defer a.Close() defer b.Close() defer d.Close() fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) newMeter := func() (Meter, error) { return &testMeter{}, nil } done := testutils.FreezeTime() defer done() rb, err := NewRebalancer(lb, RebalancerMeter(newMeter)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(d.URL)) require.NoError(t, err) rb.servers[0].meter.(*testMeter).rating = 0.3 proxy := httptest.NewServer(rb) t.Cleanup(proxy.Close) for i := 0; i < 6; i++ { _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) clock.Advance(rb.backoffDuration + clock.Second) } // We have increased the load, and the situation became worse as the other servers started failing assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, FSMMaxWeight, rb.servers[1].curWeight) assert.Equal(t, FSMMaxWeight, rb.servers[2].curWeight) // server a is now recovering, the weights should go back to the original state rb.servers[0].meter.(*testMeter).rating = 0.3 rb.servers[1].meter.(*testMeter).rating = 0.2 rb.servers[2].meter.(*testMeter).rating = 0.2 for i := 0; i < 6; i++ { _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) clock.Advance(rb.backoffDuration + clock.Second) } // the algo reverted it back assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, 1, rb.servers[1].curWeight) assert.Equal(t, 1, rb.servers[2].curWeight) } // Test scenario when all servers started failing. 
func TestRebalancerAllBad(t *testing.T) { a, b, d := testutils.NewResponder("a"), testutils.NewResponder("b"), testutils.NewResponder("d") defer a.Close() defer b.Close() defer d.Close() fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) newMeter := func() (Meter, error) { return &testMeter{}, nil } done := testutils.FreezeTime() defer done() rb, err := NewRebalancer(lb, RebalancerMeter(newMeter)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(d.URL)) require.NoError(t, err) rb.servers[0].meter.(*testMeter).rating = 0.12 rb.servers[1].meter.(*testMeter).rating = 0.13 rb.servers[2].meter.(*testMeter).rating = 0.11 proxy := httptest.NewServer(rb) t.Cleanup(proxy.Close) for i := 0; i < 6; i++ { _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) clock.Advance(rb.backoffDuration + clock.Second) } // load balancer does nothing assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, 1, rb.servers[1].curWeight) assert.Equal(t, 1, rb.servers[2].curWeight) } // Removing the server resets the state. func TestRebalancerReset(t *testing.T) { a, b, d := testutils.NewResponder("a"), testutils.NewResponder("b"), testutils.NewResponder("d") defer a.Close() defer b.Close() defer d.Close() fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) newMeter := func() (Meter, error) { return &testMeter{}, nil } done := testutils.FreezeTime() defer done() rb, err := NewRebalancer(lb, RebalancerMeter(newMeter)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(d.URL)) require.NoError(t, err) rb.servers[0].meter.(*testMeter).rating = 0.3 rb.servers[1].meter.(*testMeter).rating = 0 rb.servers[2].meter.(*testMeter).rating = 0 proxy := httptest.NewServer(rb) t.Cleanup(proxy.Close) for i := 0; i < 6; i++ { _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) clock.Advance(rb.backoffDuration + clock.Second) } // load balancer changed weights assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, FSMMaxWeight, rb.servers[1].curWeight) assert.Equal(t, FSMMaxWeight, rb.servers[2].curWeight) // Removing servers has reset the state err = rb.RemoveServer(testutils.ParseURI(d.URL)) require.NoError(t, err) assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, 1, rb.servers[1].curWeight) } func TestRebalancerRequestRewriteListenerLive(t *testing.T) { a, b := testutils.NewResponder("a"), testutils.NewResponder("b") defer a.Close() defer b.Close() fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) done := testutils.FreezeTime() defer done() rb, err := NewRebalancer(lb, RebalancerBackoff(clock.Millisecond)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI("http://localhost:62345")) require.NoError(t, err) proxy := httptest.NewServer(rb) t.Cleanup(proxy.Close) for i := 0; i < 1000; i++ { _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) if i%10 == 0 { clock.Advance(rb.backoffDuration + clock.Second) } } // load balancer changed weights 
assert.Equal(t, FSMMaxWeight, rb.servers[0].curWeight) assert.Equal(t, FSMMaxWeight, rb.servers[1].curWeight) assert.Equal(t, 1, rb.servers[2].curWeight) } func TestRebalancerRequestRewriteListener(t *testing.T) { a, b := testutils.NewResponder("a"), testutils.NewResponder("b") defer a.Close() defer b.Close() fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) rb, err := NewRebalancer(lb, RebalancerRequestRewriteListener(func(oldReq *http.Request, newReq *http.Request) { })) require.NoError(t, err) assert.NotNil(t, rb.requestRewriteListener) } func TestRebalancerStickySession(t *testing.T) { a, b, x := testutils.NewResponder("a"), testutils.NewResponder("b"), testutils.NewResponder("x") defer a.Close() defer b.Close() defer x.Close() sticky := NewStickySession("test") require.NotNil(t, sticky) fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) rb, err := NewRebalancer(lb, RebalancerStickySession(sticky)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(x.URL)) require.NoError(t, err) proxy := httptest.NewServer(rb) t.Cleanup(proxy.Close) for i := 0; i < 10; i++ { req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, "a", string(body)) } require.NoError(t, rb.RemoveServer(testutils.ParseURI(a.URL))) assert.Equal(t, []string{"b", "x", "b"}, seq(t, proxy.URL, 3)) require.NoError(t, rb.RemoveServer(testutils.ParseURI(b.URL))) assert.Equal(t, []string{"x", "x", "x"}, seq(t, proxy.URL, 3)) } type testMeter struct { rating float64 notReady bool } func (tm *testMeter) Rating() float64 { return tm.rating } func (tm *testMeter) Record(int, time.Duration) { } func (tm *testMeter) IsReady() bool { return !tm.notReady } oxy-2.0.0/roundrobin/rr.go000066400000000000000000000144231450475076400155030ustar00rootroot00000000000000// Package roundrobin implements dynamic weighted round robin load balancer http handler package roundrobin import ( "errors" "fmt" "net/http" "net/url" "sync" "github.com/vulcand/oxy/v2/utils" ) // ErrNoServers indicates that there are no servers registered for the given Backend. var ErrNoServers = errors.New("no servers in the pool") // RoundRobin implements dynamic weighted round-robin load balancer http handler. type RoundRobin struct { mutex *sync.Mutex next http.Handler errHandler utils.ErrorHandler // Current index (starts from -1) index int servers []*server currentWeight int stickySession *StickySession requestRewriteListener RequestRewriteListener verbose bool log utils.Logger } // New created a new RoundRobin. func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) { rr := &RoundRobin{ next: next, index: -1, mutex: &sync.Mutex{}, servers: []*server{}, stickySession: nil, log: &utils.NoopLogger{}, } for _, o := range opts { if err := o(rr); err != nil { return nil, err } } if rr.errHandler == nil { rr.errHandler = utils.DefaultHandler } return rr, nil } // Next returns the next handler. 
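//
// Next is mainly useful to wrapping middlewares that pick a backend
// themselves and then need the underlying forwarder; the Rebalancer, for
// example, calls Next().ServeHTTP after rewriting the request URL. A hedged
// sketch (someBackendURL is a placeholder chosen by the caller):
//
//	req.URL = someBackendURL
//	rr.Next().ServeHTTP(w, req) // bypasses rr's own server selection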
func (r *RoundRobin) Next() http.Handler { return r.next } func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { if r.verbose { dump := utils.DumpHTTPRequest(req) r.log.Debug("vulcand/oxy/roundrobin/rr: begin ServeHttp on request: %s", dump) defer r.log.Debug("vulcand/oxy/roundrobin/rr: completed ServeHttp on request: %s", dump) } // make shallow copy of request before changing anything to avoid side effects newReq := *req stuck := false if r.stickySession != nil { cookieURL, present, err := r.stickySession.GetBackend(&newReq, r.Servers()) if err != nil { r.log.Warn("vulcand/oxy/roundrobin/rr: error using server from cookie: %v", err) } if present { newReq.URL = cookieURL stuck = true } } if !stuck { uri, err := r.NextServer() if err != nil { r.errHandler.ServeHTTP(w, req, err) return } if r.stickySession != nil { r.stickySession.StickBackend(uri, w) } newReq.URL = uri } if r.verbose { // log which backend URL we're sending this request to dump := utils.DumpHTTPRequest(req) r.log.Debug("vulcand/oxy/roundrobin/rr: Forwarding this request to URL (%s): %s", newReq.URL, dump) } // Emit event to a listener if one exists if r.requestRewriteListener != nil { r.requestRewriteListener(req, &newReq) } r.next.ServeHTTP(w, &newReq) } // NextServer gets the next server. func (r *RoundRobin) NextServer() (*url.URL, error) { srv, err := r.nextServer() if err != nil { return nil, err } return utils.CopyURL(srv.url), nil } func (r *RoundRobin) nextServer() (*server, error) { r.mutex.Lock() defer r.mutex.Unlock() if len(r.servers) == 0 { return nil, ErrNoServers } // The algorithm below may look messy, but it is actually very simple: // it calculates the GCD and subtracts it on every iteration, which interleaves servers // and lets us avoid rebuilding an iterator every time weights are readjusted // GCD across all enabled servers gcd := r.weightGcd() // Maximum weight across all enabled servers max := r.maxWeight() for { r.index = (r.index + 1) % len(r.servers) if r.index == 0 { r.currentWeight -= gcd if r.currentWeight <= 0 { r.currentWeight = max if r.currentWeight == 0 { return nil, fmt.Errorf("all servers have 0 weight") } } } srv := r.servers[r.index] if srv.weight >= r.currentWeight { return srv, nil } } } // RemoveServer removes a server. func (r *RoundRobin) RemoveServer(u *url.URL) error { r.mutex.Lock() defer r.mutex.Unlock() e, index := r.findServerByURL(u) if e == nil { return fmt.Errorf("server not found") } r.servers = append(r.servers[:index], r.servers[index+1:]...) r.resetState() return nil } // Servers gets the server URLs. func (r *RoundRobin) Servers() []*url.URL { r.mutex.Lock() defer r.mutex.Unlock() out := make([]*url.URL, len(r.servers)) for i, srv := range r.servers { out[i] = srv.url } return out } // ServerWeight gets the server weight. func (r *RoundRobin) ServerWeight(u *url.URL) (int, bool) { r.mutex.Lock() defer r.mutex.Unlock() if s, _ := r.findServerByURL(u); s != nil { return s.weight, true } return -1, false } // UpsertServer adds the server to the load balancer; if the server is already present, only its options (such as the weight) are updated.
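//
// A hedged sketch, assuming rr was created with New (the backend URL is a
// placeholder):
//
//	u, _ := url.Parse("http://10.0.0.5:8080")
//	_ = rr.UpsertServer(u)                       // added with the default weight (1)
//	_ = rr.UpsertServer(u, roundrobin.Weight(3)) // same URL again: only the weight changes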
func (r *RoundRobin) UpsertServer(u *url.URL, options ...ServerOption) error { r.mutex.Lock() defer r.mutex.Unlock() if u == nil { return fmt.Errorf("server URL can't be nil") } if s, _ := r.findServerByURL(u); s != nil { for _, o := range options { if err := o(s); err != nil { return err } } r.resetState() return nil } srv := &server{url: utils.CopyURL(u)} for _, o := range options { if err := o(srv); err != nil { return err } } if srv.weight == 0 { srv.weight = defaultWeight } r.servers = append(r.servers, srv) r.resetState() return nil } func (r *RoundRobin) resetIterator() { r.index = -1 r.currentWeight = 0 } func (r *RoundRobin) resetState() { r.resetIterator() } func (r *RoundRobin) findServerByURL(u *url.URL) (*server, int) { if len(r.servers) == 0 { return nil, -1 } for i, s := range r.servers { if sameURL(u, s.url) { return s, i } } return nil, -1 } func (r *RoundRobin) maxWeight() int { max := -1 for _, s := range r.servers { if s.weight > max { max = s.weight } } return max } func (r *RoundRobin) weightGcd() int { divisor := -1 for _, s := range r.servers { if divisor == -1 { divisor = s.weight } else { divisor = gcd(divisor, s.weight) } } return divisor } func gcd(a, b int) int { for b != 0 { a, b = b, a%b } return a } // Set additional parameters for the server can be supplied when adding server. type server struct { url *url.URL // Relative weight for the enpoint to other enpoints in the load balancer weight int } var defaultWeight = 1 // SetDefaultWeight sets the default server weight. func SetDefaultWeight(weight int) error { if weight < 0 { return fmt.Errorf("default weight should be >= 0") } defaultWeight = weight return nil } func sameURL(a, b *url.URL) bool { return a.Path == b.Path && a.Host == b.Host && a.Scheme == b.Scheme } oxy-2.0.0/roundrobin/rr_test.go000066400000000000000000000125221450475076400165400ustar00rootroot00000000000000package roundrobin import ( "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/forward" "github.com/vulcand/oxy/v2/testutils" "github.com/vulcand/oxy/v2/utils" ) func TestNoServers(t *testing.T) { fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, re.StatusCode) } func TestRemoveBadServer(t *testing.T) { lb, err := New(nil) require.NoError(t, err) assert.Error(t, lb.RemoveServer(testutils.ParseURI("http://google.com"))) } func TestCustomErrHandler(t *testing.T) { errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, req *http.Request, err error) { w.WriteHeader(http.StatusTeapot) _, _ = w.Write([]byte(http.StatusText(http.StatusTeapot))) }) fwd := forward.New(false) lb, err := New(fwd, ErrorHandler(errHandler)) require.NoError(t, err) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusTeapot, re.StatusCode) } func TestOneServer(t *testing.T) { a := testutils.NewResponder("a") defer a.Close() fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) require.NoError(t, lb.UpsertServer(testutils.ParseURI(a.URL))) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) assert.Equal(t, []string{"a", "a", "a"}, seq(t, proxy.URL, 3)) } func TestSimple(t *testing.T) { a := testutils.NewResponder("a") defer a.Close() b := testutils.NewResponder("b") defer 
b.Close() fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) require.NoError(t, lb.UpsertServer(testutils.ParseURI(a.URL))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(b.URL))) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) assert.Equal(t, []string{"a", "b", "a"}, seq(t, proxy.URL, 3)) } func TestRemoveServer(t *testing.T) { a := testutils.NewResponder("a") defer a.Close() b := testutils.NewResponder("b") defer b.Close() fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) require.NoError(t, lb.UpsertServer(testutils.ParseURI(a.URL))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(b.URL))) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) assert.Equal(t, []string{"a", "b", "a"}, seq(t, proxy.URL, 3)) err = lb.RemoveServer(testutils.ParseURI(a.URL)) require.NoError(t, err) assert.Equal(t, []string{"b", "b", "b"}, seq(t, proxy.URL, 3)) } func TestUpsertSame(t *testing.T) { a := testutils.NewResponder("a") defer a.Close() fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) require.NoError(t, lb.UpsertServer(testutils.ParseURI(a.URL))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(a.URL))) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) assert.Equal(t, []string{"a", "a", "a"}, seq(t, proxy.URL, 3)) } func TestUpsertWeight(t *testing.T) { a := testutils.NewResponder("a") defer a.Close() b := testutils.NewResponder("b") defer b.Close() fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) require.NoError(t, lb.UpsertServer(testutils.ParseURI(a.URL))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(b.URL))) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) assert.Equal(t, []string{"a", "b", "a"}, seq(t, proxy.URL, 3)) assert.NoError(t, lb.UpsertServer(testutils.ParseURI(b.URL), Weight(3))) assert.Equal(t, []string{"b", "b", "a", "b"}, seq(t, proxy.URL, 4)) } func TestWeighted(t *testing.T) { require.NoError(t, SetDefaultWeight(0)) defer func() { _ = SetDefaultWeight(1) }() a := testutils.NewResponder("a") defer a.Close() b := testutils.NewResponder("b") defer b.Close() z := testutils.NewResponder("z") defer z.Close() fwd := forward.New(false) lb, err := New(fwd) require.NoError(t, err) require.NoError(t, lb.UpsertServer(testutils.ParseURI(a.URL), Weight(3))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(b.URL), Weight(2))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(z.URL), Weight(0))) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) assert.Equal(t, []string{"a", "a", "b", "a", "b", "a"}, seq(t, proxy.URL, 6)) w, ok := lb.ServerWeight(testutils.ParseURI(a.URL)) assert.Equal(t, 3, w) assert.Equal(t, true, ok) w, ok = lb.ServerWeight(testutils.ParseURI(b.URL)) assert.Equal(t, 2, w) assert.Equal(t, true, ok) w, ok = lb.ServerWeight(testutils.ParseURI(z.URL)) assert.Equal(t, 0, w) assert.Equal(t, true, ok) w, ok = lb.ServerWeight(testutils.ParseURI("http://caramba:4000")) assert.Equal(t, -1, w) assert.Equal(t, false, ok) } func TestRequestRewriteListener(t *testing.T) { a := testutils.NewResponder("a") defer a.Close() b := testutils.NewResponder("b") defer b.Close() fwd := forward.New(false) lb, err := New(fwd, RoundRobinRequestRewriteListener(func(oldReq *http.Request, newReq *http.Request) {})) require.NoError(t, err) assert.NotNil(t, lb.requestRewriteListener) } func seq(t *testing.T, url string, repeat int) []string { t.Helper() var out []string for i := 0; i < repeat; i++ { _, body, err := testutils.Get(url) require.NoError(t, err) out = 
append(out, string(body)) } return out } oxy-2.0.0/roundrobin/stickycookie/000077500000000000000000000000001450475076400172255ustar00rootroot00000000000000oxy-2.0.0/roundrobin/stickycookie/aes_value.go000066400000000000000000000066431450475076400215310ustar00rootroot00000000000000package stickycookie import ( "crypto/aes" "crypto/cipher" "crypto/rand" "encoding/base64" "encoding/binary" "errors" "fmt" "io" "net/url" "strconv" "strings" "time" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" ) // AESValue manages hashed sticky value. type AESValue struct { block cipher.AEAD ttl time.Duration } // NewAESValue takes a fixed-size key and returns an CookieValue or an error. // Key size must be exactly one of 16, 24, or 32 bytes to select AES-128, AES-192, or AES-256. func NewAESValue(key []byte, ttl time.Duration) (*AESValue, error) { block, err := aes.NewCipher(key) if err != nil { return nil, err } gcm, err := cipher.NewGCM(block) if err != nil { return nil, err } return &AESValue{block: gcm, ttl: ttl}, nil } // Get hashes the sticky value. func (v *AESValue) Get(raw *url.URL) string { base := raw.String() if v.ttl > 0 { base = fmt.Sprintf("%s|%d", base, clock.Now().UTC().Add(v.ttl).Unix()) } // Nonce is the 64bit nanosecond-resolution time, plus 32bits of crypto/rand, for 96bits (12Bytes). // Theoretically, if 2^32 calls were made in 1 nanoseconds, there might be a repeat. // Adds ~765ns, and 4B heap in 1 alloc nonce := make([]byte, 12) binary.PutVarint(nonce, clock.Now().UnixNano()) rpend := make([]byte, 4) if _, err := io.ReadFull(rand.Reader, rpend); err != nil { // This is a near-impossible error condition on Linux systems. // An error here means rand.Reader (and thus getrandom(2), and thus /dev/urandom) returned // less than 4 bytes of data. /dev/urandom is guaranteed to always return the number of // bytes requested up to 512 bytes on modern kernels. Behavior on non-Linux systems // varies, of course. panic(err) } for i := 0; i < 4; i++ { nonce[i+8] = rpend[i] } obfuscated := v.block.Seal(nil, nonce, []byte(base), nil) // We append the 12byte nonce onto the end of the message obfuscated = append(obfuscated, nonce...) obfuscatedStr := base64.RawURLEncoding.EncodeToString(obfuscated) return obfuscatedStr } // FindURL gets url from array that match the value. 
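//
// AESValue is normally not used directly; it is handed to a StickySession so
// that the affinity cookie carries an encrypted backend reference instead of
// the raw URL. A hedged sketch (the key, TTL and cookie name are placeholders;
// the key must be exactly 16, 24 or 32 bytes long):
//
//	cv, _ := stickycookie.NewAESValue([]byte("0123456789abcdef"), 24*time.Hour)
//	sticky := roundrobin.NewStickySession("lb-backend").SetCookieValue(cv)
//	lb, _ := roundrobin.New(fwd, roundrobin.EnableStickySession(sticky))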
func (v *AESValue) FindURL(raw string, urls []*url.URL) (*url.URL, error) { rawURL, err := v.fromValue(raw) if err != nil { return nil, err } for _, u := range urls { ok, err := areURLEqual(rawURL, u) if err != nil { return nil, err } if ok { return u, nil } } return nil, nil } func (v *AESValue) fromValue(obfuscatedStr string) (string, error) { obfuscated, err := base64.RawURLEncoding.DecodeString(obfuscatedStr) if err != nil { return "", err } // The first len-12 bytes is the ciphertext, the last 12 bytes is the nonce n := len(obfuscated) - 12 if n <= 0 { // Protect against range errors causing panics return "", errors.New("post-base64-decoded string is too short") } nonce := obfuscated[n:] obfuscated = obfuscated[:n] raw, err := v.block.Open(nil, nonce, obfuscated, nil) if err != nil { return "", err } if v.ttl > 0 { rawParts := strings.Split(string(raw), "|") if len(rawParts) < 2 { return "", fmt.Errorf("TTL set but cookie doesn't contain an expiration: '%s'", raw) } // validate the ttl i, err := strconv.ParseInt(rawParts[1], 10, 64) if err != nil { return "", err } if clock.Now().UTC().After(clock.Unix(i, 0).UTC()) { strTime := clock.Unix(i, 0).UTC().String() return "", fmt.Errorf("TTL expired: '%s' (%s)", raw, strTime) } raw = []byte(rawParts[0]) } return string(raw), nil } oxy-2.0.0/roundrobin/stickycookie/cookie_value.go000066400000000000000000000013471450475076400222260ustar00rootroot00000000000000package stickycookie import "net/url" // CookieValue interface to manage the sticky cookie value format. // It will be used by the load balancer to generate the sticky cookie value and to retrieve the matching url. type CookieValue interface { // Get converts raw value to an expected sticky format. Get(*url.URL) string // FindURL gets url from array that match the value. FindURL(string, []*url.URL) (*url.URL, error) } // areURLEqual compare a string to a url and check if the string is the same as the url value. func areURLEqual(normalized string, u *url.URL) (bool, error) { u1, err := url.Parse(normalized) if err != nil { return false, err } return u1.Scheme == u.Scheme && u1.Host == u.Host && u1.Path == u.Path, nil } oxy-2.0.0/roundrobin/stickycookie/fallback_value.go000066400000000000000000000016071450475076400225130ustar00rootroot00000000000000package stickycookie import ( "errors" "net/url" ) // FallbackValue manages hashed sticky value. type FallbackValue struct { from CookieValue to CookieValue } // NewFallbackValue creates a new FallbackValue. func NewFallbackValue(from CookieValue, to CookieValue) (*FallbackValue, error) { if from == nil || to == nil { return nil, errors.New("from and to are mandatory") } return &FallbackValue{from: from, to: to}, nil } // Get hashes the sticky value. func (v *FallbackValue) Get(raw *url.URL) string { return v.to.Get(raw) } // FindURL gets url from array that match the value. // If it is a symmetric algorithm, it decodes the URL, otherwise it compares the ciphered values. 
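//
// FallbackValue is mostly useful for migrations: cookies written with the old
// scheme keep matching while new cookies are issued with the new scheme. A
// hedged sketch (the salt and cookie name are placeholders):
//
//	from := &stickycookie.RawValue{}
//	to := &stickycookie.HashValue{Salt: "example-salt"}
//	cv, _ := stickycookie.NewFallbackValue(from, to) // matches either form, writes the "to" form
//	sticky := roundrobin.NewStickySession("lb-backend").SetCookieValue(cv)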
func (v *FallbackValue) FindURL(raw string, urls []*url.URL) (*url.URL, error) { findURL, err := v.from.FindURL(raw, urls) if findURL != nil { return findURL, err } return v.to.FindURL(raw, urls) } oxy-2.0.0/roundrobin/stickycookie/fallback_value_test.go000066400000000000000000000133571450475076400235570ustar00rootroot00000000000000package stickycookie import ( "fmt" "net/url" "path" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" ) func TestFallbackValue_FindURL(t *testing.T) { servers := []*url.URL{ {Scheme: "https", Host: "10.10.10.42", Path: "/"}, {Scheme: "http", Host: "10.10.10.10", Path: "/foo"}, {Scheme: "http", Host: "10.10.10.11", Path: "/", User: url.User("John Doe")}, {Scheme: "http", Host: "10.10.10.10", Path: "/"}, } aesValue, err := NewAESValue([]byte("95Bx9JkKX3xbd7z3"), 5*clock.Second) require.NoError(t, err) values := []struct { Name string CookieValue CookieValue }{ {Name: "rawValue", CookieValue: &RawValue{}}, {Name: "hashValue", CookieValue: &HashValue{Salt: "foo"}}, {Name: "aesValue", CookieValue: aesValue}, } for _, from := range values { from := from for _, to := range values { to := to t.Run(fmt.Sprintf("From: %s, To %s", from.Name, to.Name), func(t *testing.T) { t.Parallel() value, err := NewFallbackValue(from.CookieValue, to.CookieValue) if from.CookieValue == nil && to.CookieValue == nil { assert.Error(t, err) return } require.NoError(t, err) if from.CookieValue != nil { // URL found From value findURL, err := value.FindURL(from.CookieValue.Get(servers[0]), servers) require.NoError(t, err) assert.Equal(t, servers[0], findURL) // URL not found From value findURL, _ = value.FindURL(from.CookieValue.Get(mustJoin(t, servers[0], "bar")), servers) assert.Nil(t, findURL) } if to.CookieValue != nil { // URL found To Value findURL, err := value.FindURL(to.CookieValue.Get(servers[0]), servers) require.NoError(t, err) assert.Equal(t, servers[0], findURL) // URL not found To value findURL, _ = value.FindURL(to.CookieValue.Get(mustJoin(t, servers[0], "bar")), servers) assert.Nil(t, findURL) } }) } } } func TestFallbackValue_FindURL_error(t *testing.T) { servers := []*url.URL{ {Scheme: "http", Host: "10.10.10.10", Path: "/"}, {Scheme: "https", Host: "10.10.10.42", Path: "/"}, {Scheme: "http", Host: "10.10.10.10", Path: "/foo"}, {Scheme: "http", Host: "10.10.10.11", Path: "/", User: url.User("John Doe")}, } hashValue := &HashValue{Salt: "foo"} rawValue := &RawValue{} aesValue, err := NewAESValue([]byte("95Bx9JkKX3xbd7z3"), 5*clock.Second) require.NoError(t, err) tests := []struct { name string From CookieValue To CookieValue rawValue string want *url.URL expectError bool expectErrorOnNew bool }{ { name: "From RawValue To HashValue with RawValue value", From: rawValue, To: hashValue, rawValue: "http://10.10.10.10/", want: servers[0], }, { name: "From RawValue To HashValue with RawValue non matching value", From: rawValue, To: hashValue, rawValue: "http://24.10.10.10/", }, { name: "From RawValue To HashValue with HashValue value", From: rawValue, To: hashValue, rawValue: hashValue.Get(mustParse(t, "http://10.10.10.10/")), want: servers[0], }, { name: "From RawValue To HashValue with HashValue non matching value", From: rawValue, To: hashValue, rawValue: hashValue.Get(mustParse(t, "http://24.10.10.10/")), }, { name: "From HashValue To AESValue with AESValue value", From: hashValue, To: aesValue, rawValue: aesValue.Get(mustParse(t, "http://10.10.10.10/")), want: servers[0], }, { name: "From 
HashValue To AESValue with AESValue non matching value", From: hashValue, To: aesValue, rawValue: aesValue.Get(mustParse(t, "http://24.10.10.10/")), }, { name: "From HashValue To AESValue with HashValue value", From: hashValue, To: aesValue, rawValue: hashValue.Get(mustParse(t, "http://10.10.10.10/")), want: servers[0], }, { name: "From HashValue To AESValue with AESValue non matching value", From: hashValue, To: aesValue, rawValue: aesValue.Get(mustParse(t, "http://24.10.10.10/")), }, { name: "From AESValue To AESValue with AESValue value", From: aesValue, To: aesValue, rawValue: aesValue.Get(mustParse(t, "http://10.10.10.10/")), want: servers[0], }, { name: "From AESValue To AESValue with AESValue non matching value", From: aesValue, To: aesValue, rawValue: aesValue.Get(mustParse(t, "http://24.10.10.10/")), }, { name: "From AESValue To HashValue with HashValue non matching value", From: aesValue, To: hashValue, rawValue: hashValue.Get(mustParse(t, "http://24.10.10.10/")), }, { name: "From nil To RawValue", To: hashValue, rawValue: "http://24.10.10.10/", expectErrorOnNew: true, }, { name: "From RawValue To nil", From: rawValue, rawValue: "http://24.10.10.10/", expectErrorOnNew: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { value, err := NewFallbackValue(tt.From, tt.To) if tt.expectErrorOnNew { assert.Error(t, err) return } require.NoError(t, err) findURL, err := value.FindURL(tt.rawValue, servers) if tt.expectError { assert.Error(t, err) return } require.NoError(t, err) assert.Equal(t, tt.want, findURL) }) } } func mustJoin(t *testing.T, u *url.URL, part string) *url.URL { t.Helper() nu, err := u.Parse(path.Join(u.Path, part)) require.NoError(t, err) return nu } func mustParse(t *testing.T, raw string) *url.URL { t.Helper() u, err := url.Parse(raw) require.NoError(t, err) return u } oxy-2.0.0/roundrobin/stickycookie/hash_value.go000066400000000000000000000014601450475076400216740ustar00rootroot00000000000000package stickycookie import ( "fmt" "net/url" "github.com/segmentio/fasthash/fnv1a" ) // HashValue manages hashed sticky value. type HashValue struct { // Salt secret to anonymize the hashed cookie Salt string } // Get hashes the sticky value. func (v *HashValue) Get(raw *url.URL) string { return v.hash(raw.String()) } // FindURL gets url from array that match the value. func (v *HashValue) FindURL(raw string, urls []*url.URL) (*url.URL, error) { for _, u := range urls { if raw == v.hash(normalized(u)) { return u, nil } } return nil, nil } func (v *HashValue) hash(input string) string { return fmt.Sprintf("%x", fnv1a.HashString64(v.Salt+input)) } func normalized(u *url.URL) string { normalized := url.URL{Scheme: u.Scheme, Host: u.Host, Path: u.Path} return normalized.String() } oxy-2.0.0/roundrobin/stickycookie/raw_value.go000066400000000000000000000010001450475076400215300ustar00rootroot00000000000000package stickycookie import ( "net/url" ) // RawValue is a no-op that returns the raw strings as-is. type RawValue struct{} // Get returns the raw value. func (v *RawValue) Get(raw *url.URL) string { return raw.String() } // FindURL gets url from array that match the value. 
func (v *RawValue) FindURL(raw string, urls []*url.URL) (*url.URL, error) { for _, u := range urls { ok, err := areURLEqual(raw, u) if err != nil { return nil, err } if ok { return u, nil } } return nil, nil } oxy-2.0.0/roundrobin/stickysessions.go000066400000000000000000000043441450475076400201560ustar00rootroot00000000000000package roundrobin import ( "errors" "net/http" "net/url" "time" "github.com/vulcand/oxy/v2/roundrobin/stickycookie" ) // CookieOptions has all the options one would like to set on the affinity cookie. type CookieOptions struct { HTTPOnly bool Secure bool Path string Domain string Expires time.Time MaxAge int SameSite http.SameSite } // StickySession is a mixin for load balancers that implements layer 7 (http cookie) session affinity. type StickySession struct { cookieName string cookieValue stickycookie.CookieValue options CookieOptions } // NewStickySession creates a new StickySession. func NewStickySession(cookieName string) *StickySession { return &StickySession{cookieName: cookieName, cookieValue: &stickycookie.RawValue{}} } // NewStickySessionWithOptions creates a new StickySession whilst allowing for options to // shape its affinity cookie such as "httpOnly" or "secure". func NewStickySessionWithOptions(cookieName string, options CookieOptions) *StickySession { return &StickySession{cookieName: cookieName, options: options, cookieValue: &stickycookie.RawValue{}} } // SetCookieValue set the CookieValue for the StickySession. func (s *StickySession) SetCookieValue(value stickycookie.CookieValue) *StickySession { s.cookieValue = value return s } // GetBackend returns the backend URL stored in the sticky cookie, iff the backend is still in the valid list of servers. func (s *StickySession) GetBackend(req *http.Request, servers []*url.URL) (*url.URL, bool, error) { cookie, err := req.Cookie(s.cookieName) if err != nil { if errors.Is(err, http.ErrNoCookie) { return nil, false, nil } return nil, false, err } server, err := s.cookieValue.FindURL(cookie.Value, servers) return server, server != nil, err } // StickBackend creates and sets the cookie. 
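//
// The attributes of the emitted cookie come from the CookieOptions passed to
// NewStickySessionWithOptions. A hedged sketch (all values are placeholders):
//
//	sticky := roundrobin.NewStickySessionWithOptions("lb-backend", roundrobin.CookieOptions{
//		HTTPOnly: true,
//		Secure:   true,
//		Path:     "/api",
//		MaxAge:   3600,
//		SameSite: http.SameSiteLaxMode,
//	})
//	lb, _ := roundrobin.New(fwd, roundrobin.EnableStickySession(sticky))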
func (s *StickySession) StickBackend(backend *url.URL, w http.ResponseWriter) { opt := s.options cp := "/" if opt.Path != "" { cp = opt.Path } cookie := &http.Cookie{ Name: s.cookieName, Value: s.cookieValue.Get(backend), Path: cp, Domain: opt.Domain, Expires: opt.Expires, MaxAge: opt.MaxAge, Secure: opt.Secure, HttpOnly: opt.HTTPOnly, SameSite: opt.SameSite, } http.SetCookie(w, cookie) } oxy-2.0.0/roundrobin/stickysessions_test.go000066400000000000000000000412301450475076400212100ustar00rootroot00000000000000package roundrobin import ( "fmt" "io" "net/http" "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/forward" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/roundrobin/stickycookie" "github.com/vulcand/oxy/v2/testutils" ) func TestBasic(t *testing.T) { a := testutils.NewResponder("a") b := testutils.NewResponder("b") defer a.Close() defer b.Close() fwd := forward.New(false) sticky := NewStickySession("test") require.NotNil(t, sticky) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) client := http.DefaultClient for i := 0; i < 10; i++ { req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) resp, err := client.Do(req) require.NoError(t, err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, "a", string(body)) } } func TestBasicWithHashValue(t *testing.T) { a := testutils.NewResponder("a") b := testutils.NewResponder("b") defer a.Close() defer b.Close() fwd := forward.New(false) sticky := NewStickySession("test") require.NotNil(t, sticky) sticky.SetCookieValue(&stickycookie.HashValue{Salt: "foo"}) require.NotNil(t, sticky.cookieValue) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) client := http.DefaultClient var cookie *http.Cookie for i := 0; i < 10; i++ { req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) if cookie != nil { req.AddCookie(cookie) } resp, err := client.Do(req) require.NoError(t, err) body, err := io.ReadAll(resp.Body) defer resp.Body.Close() require.NoError(t, err) assert.Equal(t, "a", string(body)) if cookie == nil { // The first request will set the cookie value cookie = resp.Cookies()[0] } assert.Equal(t, "test", cookie.Name) assert.Equal(t, sticky.cookieValue.Get(mustParse(t, a.URL)), cookie.Value) } } func TestBasicWithAESValue(t *testing.T) { a := testutils.NewResponder("a") b := testutils.NewResponder("b") defer a.Close() defer b.Close() fwd := forward.New(false) sticky := NewStickySession("test") require.NotNil(t, sticky) aesValue, err := stickycookie.NewAESValue([]byte("95Bx9JkKX3xbd7z3"), 5*clock.Second) require.NoError(t, err) sticky.SetCookieValue(aesValue) require.NotNil(t, sticky.cookieValue) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := 
httptest.NewServer(lb) t.Cleanup(proxy.Close) client := http.DefaultClient var cookie *http.Cookie for i := 0; i < 10; i++ { req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) if cookie != nil { req.AddCookie(cookie) } resp, err := client.Do(req) require.NoError(t, err) body, err := io.ReadAll(resp.Body) defer resp.Body.Close() require.NoError(t, err) assert.Equal(t, "a", string(body)) if cookie == nil { // The first request will set the cookie value cookie = resp.Cookies()[0] } assert.Equal(t, "test", cookie.Name) } } func TestStickyCookie(t *testing.T) { a := testutils.NewResponder("a") b := testutils.NewResponder("b") defer a.Close() defer b.Close() fwd := forward.New(false) sticky := NewStickySession("test") require.NotNil(t, sticky) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) resp, err := http.Get(proxy.URL) require.NoError(t, err) cookie := resp.Cookies()[0] assert.Equal(t, "test", cookie.Name) assert.Equal(t, a.URL, cookie.Value) } func TestStickyCookieWithOptions(t *testing.T) { a := testutils.NewResponder("a") b := testutils.NewResponder("b") defer a.Close() defer b.Close() testCases := []struct { desc string name string options CookieOptions expected *http.Cookie }{ { desc: "no options", name: "test", options: CookieOptions{}, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/", Raw: fmt.Sprintf("test=%s; Path=/", a.URL), }, }, { desc: "HTTPOnly", name: "test", options: CookieOptions{ HTTPOnly: true, }, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/", HttpOnly: true, Raw: fmt.Sprintf("test=%s; Path=/; HttpOnly", a.URL), Unparsed: nil, }, }, { desc: "Secure", name: "test", options: CookieOptions{ Secure: true, }, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/", Secure: true, Raw: fmt.Sprintf("test=%s; Path=/; Secure", a.URL), }, }, { desc: "Path", name: "test", options: CookieOptions{ Path: "/foo", }, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/foo", Raw: fmt.Sprintf("test=%s; Path=/foo", a.URL), }, }, { desc: "Domain", name: "test", options: CookieOptions{ Domain: "example.org", }, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/", Domain: "example.org", Raw: fmt.Sprintf("test=%s; Path=/; Domain=example.org", a.URL), }, }, { desc: "Expires", name: "test", options: CookieOptions{ Expires: clock.Date(1955, 11, 12, 1, 22, 0, 0, clock.UTC), }, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/", Expires: clock.Date(1955, 11, 12, 1, 22, 0, 0, clock.UTC), RawExpires: "Sat, 12 Nov 1955 01:22:00 GMT", Raw: fmt.Sprintf("test=%s; Path=/; Expires=Sat, 12 Nov 1955 01:22:00 GMT", a.URL), }, }, { desc: "MaxAge", name: "test", options: CookieOptions{ MaxAge: -20, }, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/", MaxAge: -1, Raw: fmt.Sprintf("test=%s; Path=/; Max-Age=0", a.URL), }, }, { desc: "SameSite", name: "test", options: CookieOptions{ SameSite: http.SameSiteNoneMode, }, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/", SameSite: http.SameSiteNoneMode, Raw: fmt.Sprintf("test=%s; Path=/; SameSite=None", a.URL), }, }, } for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { fwd := forward.New(false) sticky := NewStickySessionWithOptions(test.name, test.options) require.NotNil(t, sticky) lb, err := New(fwd, 
EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) resp, err := http.Get(proxy.URL) require.NoError(t, err) require.Len(t, resp.Cookies(), 1) assert.Equal(t, test.expected, resp.Cookies()[0]) }) } } func TestRemoveRespondingServer(t *testing.T) { a := testutils.NewResponder("a") b := testutils.NewResponder("b") defer a.Close() defer b.Close() fwd := forward.New(false) sticky := NewStickySession("test") require.NotNil(t, sticky) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) client := http.DefaultClient for i := 0; i < 10; i++ { req, errReq := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, errReq) req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) resp, errReq := client.Do(req) require.NoError(t, errReq) defer resp.Body.Close() body, errReq := io.ReadAll(resp.Body) require.NoError(t, errReq) assert.Equal(t, "a", string(body)) } err = lb.RemoveServer(testutils.ParseURI(a.URL)) require.NoError(t, err) // Now, use the organic cookie response in our next requests. req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) resp, err := client.Do(req) require.NoError(t, err) assert.Equal(t, "test", resp.Cookies()[0].Name) assert.Equal(t, b.URL, resp.Cookies()[0].Value) for i := 0; i < 10; i++ { req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, "b", string(body)) } } func TestRemoveAllServers(t *testing.T) { a := testutils.NewResponder("a") b := testutils.NewResponder("b") defer a.Close() defer b.Close() fwd := forward.New(false) sticky := NewStickySession("test") require.NotNil(t, sticky) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) client := http.DefaultClient for i := 0; i < 10; i++ { req, errReq := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, errReq) req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) resp, errReq := client.Do(req) require.NoError(t, errReq) defer resp.Body.Close() body, errReq := io.ReadAll(resp.Body) require.NoError(t, errReq) assert.Equal(t, "a", string(body)) } err = lb.RemoveServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.RemoveServer(testutils.ParseURI(b.URL)) require.NoError(t, err) // Now, use the organic cookie response in our next requests. 
req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) resp, err := client.Do(req) require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) } func TestBadCookieVal(t *testing.T) { a := testutils.NewResponder("a") defer a.Close() fwd := forward.New(false) sticky := NewStickySession("test") require.NotNil(t, sticky) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) t.Cleanup(proxy.Close) client := http.DefaultClient req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) req.AddCookie(&http.Cookie{Name: "test", Value: "This is a patently invalid url! You can't parse it! :-)"}) resp, err := client.Do(req) require.NoError(t, err) body, err := io.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, "a", string(body)) // Now, cycle off the good server to cause an error err = lb.RemoveServer(testutils.ParseURI(a.URL)) require.NoError(t, err) resp, err = client.Do(req) require.NoError(t, err) _, err = io.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) } func TestStickySession_GetBackend(t *testing.T) { cookieName := "Test-Cookie" servers := []*url.URL{ {Scheme: "http", Host: "10.10.10.10", Path: "/"}, {Scheme: "https", Host: "10.10.10.42", Path: "/"}, {Scheme: "http", Host: "10.10.10.10", Path: "/foo"}, {Scheme: "http", Host: "10.10.10.11", Path: "/", User: url.User("John Doe")}, } rawValue := &stickycookie.RawValue{} hashValue := &stickycookie.HashValue{} saltyHashValue := &stickycookie.HashValue{Salt: "test salt"} aesValue, err := stickycookie.NewAESValue([]byte("95Bx9JkKX3xbd7z3"), 5*clock.Second) require.NoError(t, err) aesValueInfinite, err := stickycookie.NewAESValue([]byte("95Bx9JkKX3xbd7z3"), 0) require.NoError(t, err) aesValueExpired, err := stickycookie.NewAESValue([]byte("95Bx9JkKX3xbd7z3"), 1*clock.Nanosecond) require.NoError(t, err) tests := []struct { name string CookieValue stickycookie.CookieValue cookie *http.Cookie want *url.URL expectError bool }{ { name: "NoCookies", }, { name: "Cookie no matched", cookie: &http.Cookie{Name: "not" + cookieName, Value: "http://10.10.10.10/"}, }, { name: "Cookie not a URL", cookie: &http.Cookie{Name: cookieName, Value: "foo://foo bar"}, expectError: true, }, { name: "Simple", cookie: &http.Cookie{Name: cookieName, Value: "http://10.10.10.10/"}, want: servers[0], }, { name: "Host no match for needle", cookie: &http.Cookie{Name: cookieName, Value: "http://10.10.10.255/"}, }, { name: "Scheme no match for needle", cookie: &http.Cookie{Name: cookieName, Value: "https://10.10.10.10/"}, }, { name: "Path no match for needle", cookie: &http.Cookie{Name: cookieName, Value: "http://10.10.10.10/foo/bar"}, }, { name: "With user in haystack but not in needle", cookie: &http.Cookie{Name: cookieName, Value: "http://10.10.10.11/"}, want: servers[3], }, { name: "With user in haystack and in needle", cookie: &http.Cookie{Name: cookieName, Value: "http://John%20Doe@10.10.10.11/"}, want: servers[3], }, { name: "Cookie no matched with RawValue", CookieValue: rawValue, cookie: &http.Cookie{Name: "not" + cookieName, Value: rawValue.Get(mustParse(t, "http://10.10.10.10/"))}, }, { name: "Cookie no matched with HashValue", CookieValue: hashValue, cookie: &http.Cookie{Name: "not" + cookieName, Value: hashValue.Get(mustParse(t, 
"http://10.10.10.10/"))}, }, { name: "Cookie value not matched with HashValue", CookieValue: hashValue, cookie: &http.Cookie{Name: cookieName, Value: hashValue.Get(mustParse(t, "http://10.10.10.255/"))}, }, { name: "simple with HashValue", CookieValue: hashValue, cookie: &http.Cookie{Name: cookieName, Value: hashValue.Get(mustParse(t, "http://10.10.10.10/"))}, want: servers[0], }, { name: "simple with HashValue and salt", CookieValue: saltyHashValue, cookie: &http.Cookie{Name: cookieName, Value: saltyHashValue.Get(mustParse(t, "http://10.10.10.10/"))}, want: servers[0], }, { name: "Cookie value not matched with AESValue", CookieValue: aesValue, cookie: &http.Cookie{Name: cookieName, Value: aesValue.Get(mustParse(t, "http://10.10.10.255/"))}, }, { name: "simple with AESValue", CookieValue: aesValue, cookie: &http.Cookie{Name: cookieName, Value: aesValue.Get(mustParse(t, "http://10.10.10.10/"))}, want: servers[0], }, { name: "Cookie value not matched with AESValue with ttl 0s", CookieValue: aesValueInfinite, cookie: &http.Cookie{Name: cookieName, Value: aesValueInfinite.Get(mustParse(t, "http://10.10.10.255/"))}, }, { name: "simple with AESValue with ttl 0s", CookieValue: aesValueInfinite, cookie: &http.Cookie{Name: cookieName, Value: aesValueInfinite.Get(mustParse(t, "http://10.10.10.10/"))}, want: servers[0], }, { name: "simple with AESValue with ttl 0s", CookieValue: aesValueInfinite, cookie: &http.Cookie{Name: cookieName, Value: aesValueInfinite.Get(mustParse(t, "http://10.10.10.10/"))}, want: servers[0], }, { name: "simple with AESValue with expired ttl", CookieValue: aesValueExpired, cookie: &http.Cookie{Name: cookieName, Value: aesValueExpired.Get(mustParse(t, "http://10.10.10.10/"))}, expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := NewStickySession(cookieName) if tt.CookieValue != nil { s.SetCookieValue(tt.CookieValue) } req := httptest.NewRequest(http.MethodGet, "http://foo", nil) if tt.cookie != nil { req.AddCookie(tt.cookie) } got, _, err := s.GetBackend(req, servers) if tt.expectError { require.Error(t, err) return } require.NoError(t, err) assert.Equal(t, tt.want, got) }) } } func mustParse(t *testing.T, raw string) *url.URL { t.Helper() u, err := url.Parse(raw) require.NoError(t, err) return u } oxy-2.0.0/stream/000077500000000000000000000000001450475076400136375ustar00rootroot00000000000000oxy-2.0.0/stream/options.go000066400000000000000000000006741450475076400156700ustar00rootroot00000000000000package stream import ( "github.com/vulcand/oxy/v2/utils" ) // Option represents an option you can pass to New. type Option func(s *Stream) error // Logger defines the logger used by Stream. func Logger(l utils.Logger) Option { return func(s *Stream) error { s.log = l return nil } } // Verbose additional debug information. func Verbose(verbose bool) Option { return func(s *Stream) error { s.verbose = verbose return nil } } oxy-2.0.0/stream/stream.go000066400000000000000000000043151450475076400154640ustar00rootroot00000000000000/* Package stream provides http.Handler middleware that passes-through the entire request Stream works around several limitations caused by buffering implementations, but also introduces certain risks. Workarounds for buffering limitations: 1. Streaming really large chunks of data (large file transfers, or streaming videos, etc.) 2. Streaming (chunking) sparse data. For example, an implementation might send a health check or a heart beat over a long-lived connection. This does not play well with buffering. Risks: 1. 
Connections could survive for very long periods of time. 2. There is no easy way to enforce limits on size/time of a connection. Examples of a streaming middleware: // sample HTTP handler. handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) // Stream will literally pass through to the next handler without ANY buffering // or validation of the data. stream.New(handler) */ package stream import ( "net/http" "github.com/vulcand/oxy/v2/utils" ) // DefaultMaxBodyBytes No limit by default. const DefaultMaxBodyBytes = -1 // Stream is responsible for buffering requests and responses // It buffers large requests and responses to disk,. type Stream struct { maxRequestBodyBytes int64 maxResponseBodyBytes int64 next http.Handler verbose bool log utils.Logger } // New returns a new streamer middleware. New() function supports optional functional arguments. func New(next http.Handler, setters ...Option) (*Stream, error) { strm := &Stream{ next: next, maxRequestBodyBytes: DefaultMaxBodyBytes, maxResponseBodyBytes: DefaultMaxBodyBytes, log: &utils.NoopLogger{}, } for _, s := range setters { if err := s(strm); err != nil { return nil, err } } return strm, nil } // Wrap sets the next handler to be called by stream handler. func (s *Stream) Wrap(next http.Handler) error { s.next = next return nil } func (s *Stream) ServeHTTP(w http.ResponseWriter, req *http.Request) { if s.verbose { dump := utils.DumpHTTPRequest(req) s.log.Debug("vulcand/oxy/stream: begin ServeHttp on request: %s", dump) defer s.log.Debug("vulcand/oxy/stream: completed ServeHttp on request: %s", dump) } s.next.ServeHTTP(w, req) } oxy-2.0.0/stream/stream_test.go000066400000000000000000000202671450475076400165270ustar00rootroot00000000000000package stream import ( "bufio" "crypto/tls" "fmt" "io" "net" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/forward" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/testutils" ) func TestSimple(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello", string(body)) } func TestChunkedEncodingSuccess(t *testing.T) { var reqBody string var contentLength int64 srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { body, err := io.ReadAll(req.Body) require.NoError(t, err) reqBody = string(body) contentLength = req.ContentLength w.WriteHeader(http.StatusOK) flusher, ok := w.(http.Flusher) if !ok { panic("expected http.ResponseWriter to be an http.Flusher") } _, _ = fmt.Fprint(w, "Response") flusher.Flush() clock.Sleep(500 * clock.Millisecond) _, _ = fmt.Fprint(w, "in") flusher.Flush() clock.Sleep(500 * clock.Millisecond) _, _ = fmt.Fprint(w, "Chunks") flusher.Flush() }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our 
redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) conn, err := net.Dial("tcp", testutils.ParseURI(proxy.URL).Host) require.NoError(t, err) _, _ = fmt.Fprint(conn, "POST / HTTP/1.1\r\nHost: 127.0.0.1\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n5\r\ntest1\r\n5\r\ntest2\r\n0\r\n\r\n") reader := bufio.NewReader(conn) status, err := reader.ReadString('\n') require.NoError(t, err) _, err = reader.ReadString('\n') // content type require.NoError(t, err) _, err = reader.ReadString('\n') // Date require.NoError(t, err) transferEncoding, err := reader.ReadString('\n') require.NoError(t, err) assert.Equal(t, "Transfer-Encoding: chunked\r\n", transferEncoding) assert.Equal(t, int64(-1), contentLength) assert.Equal(t, "testtest1test2", reqBody) assert.Equal(t, "HTTP/1.1 200 OK\r\n", status) } func TestRequestLimitReached(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL, testutils.Body("this request is too long")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } func TestResponseLimitReached(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello, this response is too large")) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } func TestFileStreamingResponse(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello, this response is too large to fit in memory")) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello, this response is too large to fit in memory", string(body)) } func TestCustomErrorHandler(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello, this response is too large")) }) t.Cleanup(srv.Close) 
// forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } func TestNotModified(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusNotModified) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusNotModified, re.StatusCode) } func TestNoBody(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusOK) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // Make sure that stream handler preserves TLS settings. func TestPreservesTLS(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) }) t.Cleanup(srv.Close) // forwarder will proxy the request to whatever destination fwd := forward.New(false) var cs *tls.ConnectionState // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { cs = req.TLS req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewUnstartedServer(st) proxy.StartTLS() t.Cleanup(proxy.Close) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.NotNil(t, cs) } oxy-2.0.0/stream/threshold.go000066400000000000000000000123161450475076400161650ustar00rootroot00000000000000package stream import ( "fmt" "net/http" "github.com/vulcand/predicate" ) type hpredicate func(*context) bool // IsValidExpression check if it's a valid expression. func IsValidExpression(expr string) bool { _, err := parseExpression(expr) return err == nil } // Parses expression in the go language into Failover predicates. 
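// An expression of the following shape is expected to parse, given the operators and functions registered below (an illustrative sketch, not a documented API guarantee):
//
//	IsNetworkError() && Attempts() <= 2
//
// It yields a predicate matching requests whose last attempt ended in a network-level error and that have been retried at most twice.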
func parseExpression(in string) (hpredicate, error) { p, err := predicate.NewParser(predicate.Def{ Operators: predicate.Operators{ AND: and, OR: or, EQ: eq, NEQ: neq, LT: lt, GT: gt, LE: le, GE: ge, }, Functions: map[string]interface{}{ "RequestMethod": requestMethod, "IsNetworkError": isNetworkError, "Attempts": attempts, "ResponseCode": responseCode, }, }) if err != nil { return nil, err } out, err := p.Parse(in) if err != nil { return nil, err } pr, ok := out.(hpredicate) if !ok { return nil, fmt.Errorf("expected predicate, got %T", out) } return pr, nil } // isNetworkError returns a predicate that returns true if the last attempt ended with a network error. func isNetworkError() hpredicate { return func(c *context) bool { return c.responseCode == http.StatusBadGateway || c.responseCode == http.StatusGatewayTimeout } } // and returns predicate by joining the passed predicates with logical 'and'. func and(fns ...hpredicate) hpredicate { return func(c *context) bool { for _, fn := range fns { if !fn(c) { return false } } return true } } // or returns predicate by joining the passed predicates with logical 'or'. func or(fns ...hpredicate) hpredicate { return func(c *context) bool { for _, fn := range fns { if fn(c) { return true } } return false } } // not creates negation of the passed predicate. func not(p hpredicate) hpredicate { return func(c *context) bool { return !p(c) } } // eq returns predicate that tests for equality of the value of the mapper and the constant. func eq(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toString: return stringEQ(mapper, value) case toInt: return intEQ(mapper, value) } return nil, fmt.Errorf("unsupported argument: %T", m) } // neq returns predicate that tests for inequality of the value of the mapper and the constant. func neq(m interface{}, value interface{}) (hpredicate, error) { p, err := eq(m, value) if err != nil { return nil, err } return not(p), nil } // lt returns predicate that tests that value of the mapper function is less than the constant. func lt(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toInt: return intLT(mapper, value) default: return nil, fmt.Errorf("unsupported argument: %T", m) } } // le returns predicate that tests that value of the mapper function is less than or equal to the constant. func le(m interface{}, value interface{}) (hpredicate, error) { l, err := lt(m, value) if err != nil { return nil, err } e, err := eq(m, value) if err != nil { return nil, err } return func(c *context) bool { return l(c) || e(c) }, nil } // gt returns predicate that tests that value of the mapper function is greater than the constant. func gt(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toInt: return intGT(mapper, value) default: return nil, fmt.Errorf("unsupported argument: %T", m) } } // ge returns predicate that tests that value of the mapper function is greater than or equal to the constant.
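// For example (illustrative), an expression such as ResponseCode() >= 500 builds a predicate that matches once the last response carried any 5xx status code.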
func ge(m interface{}, value interface{}) (hpredicate, error) { g, err := gt(m, value) if err != nil { return nil, err } e, err := eq(m, value) if err != nil { return nil, err } return func(c *context) bool { return g(c) || e(c) }, nil } func stringEQ(m toString, val interface{}) (hpredicate, error) { value, ok := val.(string) if !ok { return nil, fmt.Errorf("expected string, got %T", val) } return func(c *context) bool { return m(c) == value }, nil } func intEQ(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *context) bool { return m(c) == value }, nil } func intLT(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *context) bool { return m(c) < value }, nil } func intGT(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *context) bool { return m(c) > value }, nil } type context struct { r *http.Request attempt int responseCode int } type toString func(c *context) string type toInt func(c *context) int // RequestMethod returns mapper of the request to its method e.g. POST. func requestMethod() toString { return func(c *context) string { return c.r.Method } } // Attempts returns mapper of the request to the number of proxy attempts. func attempts() toInt { return func(c *context) int { return c.attempt } } // ResponseCode returns mapper of the request to the last response code, returns 0 if there was no response code. func responseCode() toInt { return func(c *context) int { return c.responseCode } } oxy-2.0.0/testutils/000077500000000000000000000000001450475076400144045ustar00rootroot00000000000000oxy-2.0.0/testutils/utils.go000066400000000000000000000074721450475076400161050ustar00rootroot00000000000000package testutils import ( "crypto/tls" "errors" "io" "net/http" "net/http/httptest" "net/url" "strings" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/utils" ) // NewHandler creates a new Server. func NewHandler(handler http.HandlerFunc) *httptest.Server { return httptest.NewServer(handler) } // NewResponder creates a new Server with response. func NewResponder(response string) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(response)) })) } // ParseURI is the version of url.ParseRequestURI that panics if incorrect, helpful to shorten the tests. func ParseURI(uri string) *url.URL { out, err := url.ParseRequestURI(uri) if err != nil { panic(err) } return out } // ReqOpts request options. type ReqOpts struct { Host string Method string Body string Headers http.Header Auth *utils.BasicAuth } // ReqOption request option type. type ReqOption func(o *ReqOpts) error // Method sets request method. func Method(m string) ReqOption { return func(o *ReqOpts) error { o.Method = m return nil } } // Host sets request host. func Host(h string) ReqOption { return func(o *ReqOpts) error { o.Host = h return nil } } // Body sets request body. func Body(b string) ReqOption { return func(o *ReqOpts) error { o.Body = b return nil } } // Header sets request header. func Header(name, val string) ReqOption { return func(o *ReqOpts) error { if o.Headers == nil { o.Headers = make(http.Header) } o.Headers.Add(name, val) return nil } } // Headers sets request headers. 
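// For example (an illustrative sketch; srv is assumed to be an httptest.Server started elsewhere):
//
//	h := http.Header{"X-Req-A": []string{"1"}}
//	re, _, err := Get(srv.URL, Headers(h))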
func Headers(h http.Header) ReqOption { return func(o *ReqOpts) error { if o.Headers == nil { o.Headers = make(http.Header) } utils.CopyHeaders(o.Headers, h) return nil } } // BasicAuth sets request basic auth. func BasicAuth(username, password string) ReqOption { return func(o *ReqOpts) error { o.Auth = &utils.BasicAuth{ Username: username, Password: password, } return nil } } // MakeRequest create and do a request. func MakeRequest(uri string, opts ...ReqOption) (*http.Response, []byte, error) { o := &ReqOpts{} for _, s := range opts { if err := s(o); err != nil { return nil, nil, err } } if o.Method == "" { o.Method = http.MethodGet } request, err := http.NewRequest(o.Method, uri, strings.NewReader(o.Body)) if err != nil { return nil, nil, err } if o.Headers != nil { utils.CopyHeaders(request.Header, o.Headers) } if o.Auth != nil { request.Header.Set("Authorization", o.Auth.String()) } if o.Host != "" { request.Host = o.Host } var tr *http.Transport if strings.HasPrefix(uri, "https") { tr = &http.Transport{ DisableKeepAlives: true, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, ServerName: request.Host, }, } } else { tr = &http.Transport{ DisableKeepAlives: true, } } client := &http.Client{ Transport: tr, CheckRedirect: func(req *http.Request, via []*http.Request) error { return errors.New("no redirects") }, } response, err := client.Do(request) if err == nil { bodyBytes, errRead := io.ReadAll(response.Body) return response, bodyBytes, errRead } return response, nil, err } // Get do a GET request. func Get(uri string, opts ...ReqOption) (*http.Response, []byte, error) { opts = append(opts, Method(http.MethodGet)) return MakeRequest(uri, opts...) } // Post do a POST request. func Post(uri string, opts ...ReqOption) (*http.Response, []byte, error) { opts = append(opts, Method(http.MethodPost)) return MakeRequest(uri, opts...) } // FreezeTime to the predetermined time. Returns a function that should be // deferred to unfreeze time. Meant for testing. func FreezeTime() func() { clock.Freeze(clock.Date(2012, 3, 4, 5, 6, 7, 0, clock.UTC)) return clock.Unfreeze } oxy-2.0.0/trace/000077500000000000000000000000001450475076400134425ustar00rootroot00000000000000oxy-2.0.0/trace/options.go000066400000000000000000000015771450475076400154760ustar00rootroot00000000000000package trace import "github.com/vulcand/oxy/v2/utils" // Option is a functional option setter for Tracer. type Option func(*Tracer) error // ErrorHandler is a functional argument that sets error handler of the server. func ErrorHandler(h utils.ErrorHandler) Option { return func(t *Tracer) error { t.errHandler = h return nil } } // RequestHeaders adds request headers to capture. func RequestHeaders(headers ...string) Option { return func(t *Tracer) error { t.reqHeaders = append(t.reqHeaders, headers...) return nil } } // ResponseHeaders adds response headers to capture. func ResponseHeaders(headers ...string) Option { return func(t *Tracer) error { t.respHeaders = append(t.respHeaders, headers...) return nil } } // Logger defines the logger the tracer will use. 
func Logger(l utils.Logger) Option { return func(t *Tracer) error { t.log = l return nil } } oxy-2.0.0/trace/trace.go000066400000000000000000000133631450475076400150750ustar00rootroot00000000000000// Package trace implement structured logging of requests package trace import ( "crypto/tls" "encoding/json" "fmt" "io" "net/http" "strconv" "time" "github.com/vulcand/oxy/v2/internal/holsterv4/clock" "github.com/vulcand/oxy/v2/utils" ) // Tracer records request and response emitting JSON structured data to the output. type Tracer struct { errHandler utils.ErrorHandler next http.Handler reqHeaders []string respHeaders []string writer io.Writer log utils.Logger } // New creates a new Tracer middleware that emits all the request/response information in structured format // to writer and passes the request to the next handler. It can optionally capture request and response headers, // see RequestHeaders and ResponseHeaders options for details. func New(next http.Handler, writer io.Writer, opts ...Option) (*Tracer, error) { t := &Tracer{ writer: writer, next: next, log: &utils.NoopLogger{}, } for _, o := range opts { if err := o(t); err != nil { return nil, err } } if t.errHandler == nil { t.errHandler = utils.DefaultHandler } return t, nil } func (t *Tracer) ServeHTTP(w http.ResponseWriter, req *http.Request) { start := clock.Now() pw := utils.NewProxyWriterWithLogger(w, t.log) t.next.ServeHTTP(pw, req) l := t.newRecord(req, pw, clock.Since(start)) if err := json.NewEncoder(t.writer).Encode(l); err != nil { t.log.Error("Failed to marshal request: %v", err) } } func (t *Tracer) newRecord(req *http.Request, pw *utils.ProxyWriter, diff time.Duration) *Record { return &Record{ Request: Request{ Method: req.Method, URL: req.URL.String(), TLS: newTLS(req), BodyBytes: bodyBytes(req.Header), Headers: captureHeaders(req.Header, t.reqHeaders), }, Response: Response{ Code: pw.StatusCode(), BodyBytes: bodyBytes(pw.Header()), Roundtrip: float64(diff) / float64(clock.Millisecond), Headers: captureHeaders(pw.Header(), t.respHeaders), }, } } func captureHeaders(in http.Header, headers []string) http.Header { if len(headers) == 0 || in == nil { return nil } out := make(http.Header, len(headers)) for _, h := range headers { vals, ok := in[h] if !ok || len(out[h]) != 0 { continue } for i := range vals { out.Add(h, vals[i]) } } return out } // Record represents a structured request and response record. type Record struct { Request Request `json:"request"` Response Response `json:"response"` } // Request contains information about an HTTP request. type Request struct { Method string `json:"method"` // Method - request method BodyBytes int64 `json:"body_bytes"` // BodyBytes - size of request body in bytes URL string `json:"url"` // URL - Request URL Headers http.Header `json:"headers,omitempty"` // Headers - optional request headers, will be recorded if configured TLS *TLS `json:"tls,omitempty"` // TLS - optional TLS record, will be recorded if it's a TLS connection } // Response contains information about HTTP response. type Response struct { Code int `json:"code"` // Code - response status code Roundtrip float64 `json:"roundtrip"` // Roundtrip - round trip time in milliseconds Headers http.Header `json:"headers,omitempty"` // Headers - optional headers, will be recorded if configured BodyBytes int64 `json:"body_bytes"` // BodyBytes - size of response body in bytes } // TLS contains information about this TLS connection. 
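// In the emitted record this struct appears under request.tls, for example (illustrative values):
//
//	{"version":"TLS12","resume":false,"cipher_suite":"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256","server":"example.com"}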
type TLS struct { Version string `json:"version"` // Version - TLS version Resume bool `json:"resume"` // Resume tells if the session has been re-used (session tickets) CipherSuite string `json:"cipher_suite"` // CipherSuite contains cipher suite used for this connection Server string `json:"server"` // Server contains server name used in SNI } func newTLS(req *http.Request) *TLS { if req.TLS == nil { return nil } return &TLS{ Version: versionToString(req.TLS.Version), Resume: req.TLS.DidResume, CipherSuite: csToString(req.TLS.CipherSuite), Server: req.TLS.ServerName, } } func versionToString(v uint16) string { switch v { case tls.VersionTLS10: return "TLS10" case tls.VersionTLS11: return "TLS11" case tls.VersionTLS12: return "TLS12" } return fmt.Sprintf("unknown: %x", v) } func csToString(cs uint16) string { switch cs { case tls.TLS_RSA_WITH_RC4_128_SHA: return "TLS_RSA_WITH_RC4_128_SHA" case tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: return "TLS_RSA_WITH_3DES_EDE_CBC_SHA" case tls.TLS_RSA_WITH_AES_128_CBC_SHA: return "TLS_RSA_WITH_AES_128_CBC_SHA" case tls.TLS_RSA_WITH_AES_256_CBC_SHA: return "TLS_RSA_WITH_AES_256_CBC_SHA" case tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: return "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA" case tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: return "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA" case tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA" case tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA: return "TLS_ECDHE_RSA_WITH_RC4_128_SHA" case tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: return "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA" case tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: return "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA" case tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA" case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" case tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" } return fmt.Sprintf("unknown: %x", cs) } func bodyBytes(h http.Header) int64 { length := h.Get("Content-Length") if length == "" { return 0 } bytes, err := strconv.ParseInt(length, 10, 0) if err == nil { return bytes } return 0 } oxy-2.0.0/trace/trace_test.go000066400000000000000000000056431450475076400161360ustar00rootroot00000000000000package trace import ( "bufio" "bytes" "crypto/tls" "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/v2/testutils" "github.com/vulcand/oxy/v2/utils" ) func TestTraceSimple(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Header().Set("Content-Length", "5") _, _ = w.Write([]byte("hello")) }) trace := &bytes.Buffer{} tr, err := New(handler, trace) require.NoError(t, err) srv := httptest.NewServer(tr) t.Cleanup(srv.Close) re, _, err := testutils.MakeRequest(srv.URL+"/hello", testutils.Method(http.MethodPost), testutils.Body("123456")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) var r *Record require.NoError(t, json.Unmarshal(trace.Bytes(), &r)) assert.Equal(t, http.MethodPost, r.Request.Method) assert.Equal(t, "/hello", r.Request.URL) assert.Equal(t, http.StatusOK, r.Response.Code) assert.EqualValues(t, 6, r.Request.BodyBytes) assert.NotEqual(t, float64(0), r.Response.Roundtrip) assert.EqualValues(t, 5, r.Response.BodyBytes) } func TestTraceCaptureHeaders(t *testing.T) { respHeaders := http.Header{ "X-Re-1": []string{"6", "7"}, "X-Re-2": []string{"2", 
"3"}, } handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { utils.CopyHeaders(w.Header(), respHeaders) _, _ = w.Write([]byte("hello")) }) trace := &bytes.Buffer{} tr, err := New(handler, trace, RequestHeaders("X-Req-B", "X-Req-A"), ResponseHeaders("X-Re-1", "X-Re-2")) require.NoError(t, err) srv := httptest.NewServer(tr) t.Cleanup(srv.Close) reqHeaders := http.Header{"X-Req-A": []string{"1", "2"}, "X-Req-B": []string{"3", "4"}} re, _, err := testutils.Get(srv.URL+"/hello", testutils.Headers(reqHeaders)) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) var r *Record require.NoError(t, json.Unmarshal(trace.Bytes(), &r)) assert.Equal(t, reqHeaders, r.Request.Headers) assert.Equal(t, respHeaders, r.Response.Headers) } func TestTraceTLS(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { _, _ = w.Write([]byte("hello")) }) trace := &bytes.Buffer{} tr, err := New(handler, trace) require.NoError(t, err) srv := httptest.NewUnstartedServer(tr) srv.StartTLS() t.Cleanup(srv.Close) config := &tls.Config{ InsecureSkipVerify: true, } u, err := url.Parse(srv.URL) require.NoError(t, err) conn, err := tls.Dial("tcp", u.Host, config) require.NoError(t, err) _, _ = fmt.Fprint(conn, "GET / HTTP/1.0\r\n\r\n") status, err := bufio.NewReader(conn).ReadString('\n') require.NoError(t, err) assert.Equal(t, "HTTP/1.0 200 OK\r\n", status) state := conn.ConnectionState() _ = conn.Close() var r *Record require.NoError(t, json.Unmarshal(trace.Bytes(), &r)) assert.Equal(t, versionToString(state.Version), r.Request.TLS.Version) } oxy-2.0.0/utils/000077500000000000000000000000001450475076400135045ustar00rootroot00000000000000oxy-2.0.0/utils/auth.go000066400000000000000000000022351450475076400147760ustar00rootroot00000000000000package utils import ( "encoding/base64" "fmt" "strings" ) // BasicAuth basic auth information. type BasicAuth struct { Username string Password string } func (ba *BasicAuth) String() string { encoded := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", ba.Username, ba.Password))) return fmt.Sprintf("Basic %s", encoded) } // ParseAuthHeader creates a new BasicAuth from header values. func ParseAuthHeader(header string) (*BasicAuth, error) { values := strings.Fields(header) if len(values) != 2 { return nil, fmt.Errorf("failed to parse header '%s'", header) } authType := strings.ToLower(values[0]) if authType != "basic" { return nil, fmt.Errorf("expected basic auth type, got '%s'", authType) } encodedString := values[1] decodedString, err := base64.StdEncoding.DecodeString(encodedString) if err != nil { return nil, fmt.Errorf("failed to parse header '%s', base64 failed: %w", header, err) } values = strings.SplitN(string(decodedString), ":", 2) if len(values) != 2 { return nil, fmt.Errorf("failed to parse header '%s', expected separator ':'", header) } return &BasicAuth{Username: values[0], Password: values[1]}, nil } oxy-2.0.0/utils/auth_test.go000066400000000000000000000026131450475076400160350ustar00rootroot00000000000000package utils import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // Just to make sure we don't panic, return err and not // username and pass and cover the function. 
func TestParseBadHeaders(t *testing.T) { headers := []string{ // just empty string "", // missing auth type "justplainstring", // unknown auth type "Whut justplainstring", // invalid base64 "Basic Shmasic", // random encoded string "Basic YW55IGNhcm5hbCBwbGVhcw==", } for _, h := range headers { _, err := ParseAuthHeader(h) require.Error(t, err) } } // Just to make sure we don't panic, return err and not // username and pass and cover the function. func TestParseSuccess(t *testing.T) { headers := []struct { Header string Expected BasicAuth }{ { "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", BasicAuth{Username: "Aladdin", Password: "open sesame"}, }, // Make sure that String() produces valid header { (&BasicAuth{Username: "Alice", Password: "Here's bob"}).String(), BasicAuth{Username: "Alice", Password: "Here's bob"}, }, // empty pass { "Basic QWxhZGRpbjo=", BasicAuth{Username: "Aladdin", Password: ""}, }, } for _, h := range headers { request, err := ParseAuthHeader(h.Header) require.NoError(t, err) assert.Equal(t, h.Expected.Username, request.Username) assert.Equal(t, h.Expected.Password, request.Password) } } oxy-2.0.0/utils/dumpreq.go000066400000000000000000000027661450475076400155230ustar00rootroot00000000000000package utils import ( "crypto/tls" "encoding/json" "fmt" "mime/multipart" "net/http" "net/url" ) // SerializableHTTPRequest serializable HTTP request. // //nolint:musttag // Cannot be changed more now. type SerializableHTTPRequest struct { Method string URL *url.URL Proto string // "HTTP/1.0" ProtoMajor int // 1 ProtoMinor int // 0 Header http.Header ContentLength int64 TransferEncoding []string Host string Form url.Values PostForm url.Values MultipartForm *multipart.Form Trailer http.Header RemoteAddr string RequestURI string TLS *tls.ConnectionState } // Clone clone a request. func Clone(r *http.Request) *SerializableHTTPRequest { if r == nil { return nil } rc := new(SerializableHTTPRequest) rc.Method = r.Method rc.URL = r.URL rc.Proto = r.Proto rc.ProtoMajor = r.ProtoMajor rc.ProtoMinor = r.ProtoMinor rc.Header = r.Header rc.ContentLength = r.ContentLength rc.Host = r.Host rc.RemoteAddr = r.RemoteAddr rc.RequestURI = r.RequestURI return rc } // ToJSON serializes to JSON. func (s *SerializableHTTPRequest) ToJSON() string { jsonVal, err := json.Marshal(s) if err != nil || jsonVal == nil { return fmt.Sprintf("error marshaling SerializableHTTPRequest to json: %s", err) } return string(jsonVal) } // DumpHTTPRequest dump a HTTP request to JSON. func DumpHTTPRequest(req *http.Request) string { return Clone(req).ToJSON() } oxy-2.0.0/utils/dumpreq_test.go000066400000000000000000000012131450475076400165440ustar00rootroot00000000000000package utils import ( "net/http" "net/url" "testing" "github.com/stretchr/testify/assert" ) type readCloserTestImpl struct{} func (r *readCloserTestImpl) Read(_ []byte) (n int, err error) { return 0, nil } func (r *readCloserTestImpl) Close() error { return nil } // Just to make sure we don't panic, return err and not // username and pass and cover the function. 
func TestHttpReqToString(t *testing.T) { req := &http.Request{ URL: &url.URL{Host: "localhost:2374", Path: "/unittest"}, Method: http.MethodDelete, Cancel: make(chan struct{}), Body: &readCloserTestImpl{}, } assert.True(t, len(DumpHTTPRequest(req)) > 0) } oxy-2.0.0/utils/handler.go000066400000000000000000000032701450475076400154520ustar00rootroot00000000000000package utils import ( "context" "errors" "io" "net" "net/http" ) // StatusClientClosedRequest non-standard HTTP status code for client disconnection. const StatusClientClosedRequest = 499 // StatusClientClosedRequestText non-standard HTTP status for client disconnection. const StatusClientClosedRequestText = "Client Closed Request" // ErrorHandler error handler. type ErrorHandler interface { ServeHTTP(w http.ResponseWriter, req *http.Request, err error) } // DefaultHandler default error handler. var DefaultHandler ErrorHandler = &StdHandler{log: &NoopLogger{}} // StdHandler Standard error handler. type StdHandler struct { log Logger } func (e *StdHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request, err error) { statusCode := http.StatusInternalServerError //nolint:errorlint // must be changed if e, ok := err.(net.Error); ok { if e.Timeout() { statusCode = http.StatusGatewayTimeout } else { statusCode = http.StatusBadGateway } } else if errors.Is(err, io.EOF) { statusCode = http.StatusBadGateway } else if errors.Is(err, context.Canceled) { statusCode = StatusClientClosedRequest } w.WriteHeader(statusCode) _, _ = w.Write([]byte(statusText(statusCode))) e.log.Debug("'%d %s' caused by: %v", statusCode, statusText(statusCode), err) } func statusText(statusCode int) string { if statusCode == StatusClientClosedRequest { return StatusClientClosedRequestText } return http.StatusText(statusCode) } // ErrorHandlerFunc error handler function type. type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, error) // ServeHTTP calls f(w, r). func (f ErrorHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request, err error) { f(w, r, err) } oxy-2.0.0/utils/handler_test.go000066400000000000000000000013211450475076400165040ustar00rootroot00000000000000package utils import ( "bytes" "net/http" "net/http/httptest" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestDefaultHandlerErrors(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { h := w.(http.Hijacker) conn, _, _ := h.Hijack() conn.Close() })) t.Cleanup(srv.Close) request, err := http.NewRequest(http.MethodGet, srv.URL, strings.NewReader("")) require.NoError(t, err) _, err = http.DefaultTransport.RoundTrip(request) w := NewBufferWriter(NopWriteCloser(&bytes.Buffer{}), &NoopLogger{}) DefaultHandler.ServeHTTP(w, nil, err) assert.Equal(t, http.StatusBadGateway, w.Code) } oxy-2.0.0/utils/log.go000066400000000000000000000010121450475076400146060ustar00rootroot00000000000000package utils // Logger the logger interface. type Logger interface { Debug(msg string, args ...any) Info(msg string, args ...any) Warn(msg string, args ...any) Error(msg string, args ...any) } // NoopLogger a noop logger. type NoopLogger struct{} // Debug noop. func (*NoopLogger) Debug(string, ...interface{}) {} // Info noop. func (*NoopLogger) Info(string, ...interface{}) {} // Warn noop. func (*NoopLogger) Warn(string, ...interface{}) {} // Error noop. 
func (*NoopLogger) Error(string, ...interface{}) {} oxy-2.0.0/utils/netutils.go000066400000000000000000000115741450475076400157120ustar00rootroot00000000000000package utils import ( "bufio" "fmt" "io" "net" "net/http" "net/url" "reflect" ) // ProxyWriter calls recorder, used to debug logs. type ProxyWriter struct { w http.ResponseWriter code int length int64 log Logger } // NewProxyWriter creates a new ProxyWriter. func NewProxyWriter(w http.ResponseWriter) *ProxyWriter { return NewProxyWriterWithLogger(w, &NoopLogger{}) } // NewProxyWriterWithLogger creates a new ProxyWriter. func NewProxyWriterWithLogger(w http.ResponseWriter, l Logger) *ProxyWriter { return &ProxyWriter{ w: w, log: l, } } // StatusCode gets status code. func (p *ProxyWriter) StatusCode() int { if p.code == 0 { // per contract standard lib will set this to http.StatusOK if not set // by user, here we avoid the confusion by mirroring this logic return http.StatusOK } return p.code } // GetLength gets content length. func (p *ProxyWriter) GetLength() int64 { return p.length } // Header gets response header. func (p *ProxyWriter) Header() http.Header { return p.w.Header() } func (p *ProxyWriter) Write(buf []byte) (int, error) { p.length += int64(len(buf)) return p.w.Write(buf) } // WriteHeader writes status code. func (p *ProxyWriter) WriteHeader(code int) { p.code = code p.w.WriteHeader(code) } // Flush flush the writer. func (p *ProxyWriter) Flush() { if f, ok := p.w.(http.Flusher); ok { f.Flush() } } // CloseNotify returns a channel that receives at most a single value (true) // when the client connection has gone away. func (p *ProxyWriter) CloseNotify() <-chan bool { if cn, ok := p.w.(http.CloseNotifier); ok { return cn.CloseNotify() } p.log.Debug("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(p.w)) return make(<-chan bool) } // Hijack lets the caller take over the connection. func (p *ProxyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hi, ok := p.w.(http.Hijacker); ok { return hi.Hijack() } p.log.Debug("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(p.w)) return nil, nil, fmt.Errorf("the response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(p.w)) } // NewBufferWriter creates a new BufferWriter. func NewBufferWriter(w io.WriteCloser, l Logger) *BufferWriter { return &BufferWriter{ W: w, H: make(http.Header), log: l, } } // BufferWriter buffer writer. type BufferWriter struct { H http.Header Code int W io.WriteCloser log Logger } // Close close the writer. func (b *BufferWriter) Close() error { return b.W.Close() } // Header gets response header. func (b *BufferWriter) Header() http.Header { return b.H } func (b *BufferWriter) Write(buf []byte) (int, error) { return b.W.Write(buf) } // WriteHeader writes status code. func (b *BufferWriter) WriteHeader(code int) { b.Code = code } // CloseNotify returns a channel that receives at most a single value (true) // when the client connection has gone away. func (b *BufferWriter) CloseNotify() <-chan bool { if cn, ok := b.W.(http.CloseNotifier); ok { return cn.CloseNotify() } b.log.Warn("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(b.W)) return make(<-chan bool) } // Hijack lets the caller take over the connection. 
func (b *BufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hi, ok := b.W.(http.Hijacker); ok { return hi.Hijack() } b.log.Debug("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(b.W)) return nil, nil, fmt.Errorf("the response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(b.W)) } type nopWriteCloser struct { io.Writer } func (*nopWriteCloser) Close() error { return nil } // NopWriteCloser returns a WriteCloser with a no-op Close method wrapping // the provided Writer w. func NopWriteCloser(w io.Writer) io.WriteCloser { return &nopWriteCloser{Writer: w} } // CopyURL provides update safe copy by avoiding shallow copying User field. func CopyURL(i *url.URL) *url.URL { out := *i if i.User != nil { u := *i.User out.User = &u } return &out } // CopyHeaders copies http headers from source to destination, it // does not override, but adds multiple headers. func CopyHeaders(dst http.Header, src http.Header) { for k, vv := range src { dst[k] = append(dst[k], vv...) } } // HasHeaders determines whether any of the header names is present in the http headers. func HasHeaders(names []string, headers http.Header) bool { for _, h := range names { if headers.Get(h) != "" { return true } } return false } // RemoveHeaders removes the header with the given names from the headers map. func RemoveHeaders(headers http.Header, names ...string) { for _, h := range names { headers.Del(h) } } oxy-2.0.0/utils/netutils_test.go000066400000000000000000000046061450475076400167470ustar00rootroot00000000000000package utils import ( "net/http" "net/url" "testing" "github.com/stretchr/testify/assert" ) // Make sure copy does it right, so the copied url is safe to alter without modifying the other. func TestCopyUrl(t *testing.T) { userinfo := url.UserPassword("foo", "secret") urlA := &url.URL{ Scheme: "http", Host: "localhost:5000", Path: "/upstream", Opaque: "opaque", RawQuery: "a=1&b=2", Fragment: "#hello", User: userinfo, } urlB := CopyURL(urlA) assert.Equal(t, urlA, urlB) *userinfo = *url.User("bar") assert.Equal(t, urlA.User, userinfo) assert.NotEqual(t, urlA.User, urlB.User) urlB.Scheme = "https" assert.NotEqual(t, urlA, urlB) } // Make sure copy headers is not shallow and copies all headers. 
func TestCopyHeaders(t *testing.T) { source, destination := make(http.Header), make(http.Header) source.Add("a", "b") source.Add("c", "d") CopyHeaders(destination, source) assert.Equal(t, "b", destination.Get("a")) assert.Equal(t, "d", destination.Get("c")) // make sure that altering source does not affect the destination source.Del("a") assert.Equal(t, "", source.Get("a")) assert.Equal(t, "b", destination.Get("a")) } func TestHasHeaders(t *testing.T) { source := make(http.Header) source.Add("a", "b") source.Add("c", "d") assert.True(t, HasHeaders([]string{"a", "f"}, source)) assert.False(t, HasHeaders([]string{"i", "j"}, source)) } func TestRemoveHeaders(t *testing.T) { source := make(http.Header) source.Add("a", "b") source.Add("a", "m") source.Add("c", "d") RemoveHeaders(source, "a") assert.Equal(t, "", source.Get("a")) assert.Equal(t, "d", source.Get("c")) } func BenchmarkCopyHeaders(b *testing.B) { dstHeaders := make([]http.Header, 0, b.N) sourceHeaders := make([]http.Header, 0, b.N) for n := 0; n < b.N; n++ { // example from a reverse proxy merging headers d := http.Header{} d.Add("Request-Id", "1bd36bcc-a0d1-4fc7-aedc-20bbdefa27c5") dstHeaders = append(dstHeaders, d) s := http.Header{} s.Add("Content-Length", "374") s.Add("Context-Type", "text/html; charset=utf-8") s.Add("Etag", `"op14g6ae"`) s.Add("Last-Modified", "Wed, 26 Apr 2017 18:24:06 GMT") s.Add("Server", "Caddy") s.Add("Date", "Fri, 28 Apr 2017 15:54:01 GMT") s.Add("Accept-Ranges", "bytes") sourceHeaders = append(sourceHeaders, s) } b.ResetTimer() for n := 0; n < b.N; n++ { CopyHeaders(dstHeaders[n], sourceHeaders[n]) } } oxy-2.0.0/utils/source.go000066400000000000000000000035261450475076400153410ustar00rootroot00000000000000package utils import ( "fmt" "net/http" "strings" ) // SourceExtractor extracts the source from the request, e.g. that may be client ip, or particular header that // identifies the source. amount stands for amount of connections the source consumes, usually 1 for connection limiters // error should be returned when source can not be identified. type SourceExtractor interface { Extract(req *http.Request) (token string, amount int64, err error) } // ExtractorFunc extractor function type. type ExtractorFunc func(req *http.Request) (token string, amount int64, err error) // Extract extract from request. func (f ExtractorFunc) Extract(req *http.Request) (string, int64, error) { return f(req) } // ExtractSource extract source function type. type ExtractSource func(req *http.Request) // NewExtractor creates a new SourceExtractor. func NewExtractor(variable string) (SourceExtractor, error) { if variable == "client.ip" { return ExtractorFunc(extractClientIP), nil } if variable == "request.host" { return ExtractorFunc(extractHost), nil } if strings.HasPrefix(variable, "request.header.") { header := strings.TrimPrefix(variable, "request.header.") if header == "" { return nil, fmt.Errorf("wrong header: %s", header) } return makeHeaderExtractor(header), nil } return nil, fmt.Errorf("unsupported limiting variable: '%s'", variable) } func extractClientIP(req *http.Request) (string, int64, error) { vals := strings.SplitN(req.RemoteAddr, ":", 2) if vals[0] == "" { return "", 0, fmt.Errorf("failed to parse client IP: %v", req.RemoteAddr) } return vals[0], 1, nil } func extractHost(req *http.Request) (string, int64, error) { return req.Host, 1, nil } func makeHeaderExtractor(header string) SourceExtractor { return ExtractorFunc(func(req *http.Request) (string, int64, error) { return req.Header.Get(header), 1, nil }) }
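// The sketch below is illustrative and not part of the original package: it combines NewExtractor and
// ExtractorFunc into a custom SourceExtractor that keys on the Authorization header when present and
// falls back to the client IP otherwise. The helper name newAuthOrIPExtractor is hypothetical.
func newAuthOrIPExtractor() (SourceExtractor, error) {
	// Reuse the built-in client IP extractor as the fallback.
	ipExtractor, err := NewExtractor("client.ip")
	if err != nil {
		return nil, err
	}
	return ExtractorFunc(func(req *http.Request) (string, int64, error) {
		if auth := req.Header.Get("Authorization"); auth != "" {
			// Treat each distinct credential as its own source, consuming one unit per request.
			return auth, 1, nil
		}
		return ipExtractor.Extract(req)
	}), nil
}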