oxy-1.3.0/.gitignore

# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so

# Folders
_obj
_test

# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out

*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*

_testmain.go

*.exe
*.test
*.prof
.idea/
flymake_*
vendor/

oxy-1.3.0/.travis.yml

language: go

go:
  - 1.15.x
  - 1.x

go_import_path: github.com/vulcand/oxy

notifications:
  email:
    on_success: never
    on_failure: change

env:
  - GO111MODULE=on

before_install:
  - GO111MODULE=off go get -u golang.org/x/lint/golint
  - GO111MODULE=off go get -u github.com/client9/misspell/cmd/misspell

install:
  - go mod tidy
  - git diff --exit-code go.mod go.sum

oxy-1.3.0/LICENSE

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright {yyyy} {name of copyright owner}

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

oxy-1.3.0/Makefile

.PHONY: all

export GO111MODULE=on

PKGS := $(shell go list ./... | grep -v '/vendor/')
GOFILES := $(shell go list -f '{{range $$index, $$element := .GoFiles}}{{$$.Dir}}/{{$$element}}{{"\n"}}{{end}}' ./... | grep -v '/vendor/')
TXT_FILES := $(shell find * -type f -not -path 'vendor/**')

default: clean misspell vet check-fmt test

test: clean
	go test -race -cover $(PKGS)

test-verbose: clean
	go test -v -race -cover $(PKGS)

clean:
	find . -name flymake_* -delete
	rm -f cover.out

lint:
	echo "golint:"
	golint -set_exit_status $(PKGS)

vet:
	go vet $(PKGS)

checks: vet lint check-fmt
	staticcheck $(PKGS)
	gosimple $(PKGS)

check-fmt: SHELL := /bin/bash
check-fmt:
	diff -u <(echo -n) <(gofmt -d $(GOFILES))

misspell:
	misspell -source=text -error $(TXT_FILES)

test-package: clean
	go test -v ./$(p)

test-grep-package: clean
	go test -v ./$(p) -check.f=$(e)

cover-package: clean
	go test -v ./$(p) -coverprofile=/tmp/coverage.out
	go tool cover -html=/tmp/coverage.out

sloccount:
	find . -path ./vendor -prune -o -name "*.go" -print0 | xargs -0 wc -l

oxy-1.3.0/README.md

Oxy [![Build Status](https://travis-ci.org/vulcand/oxy.svg?branch=master)](https://travis-ci.org/vulcand/oxy)
=====

Oxy is a Go library with HTTP handlers that enhance the HTTP standard library:

* [Buffer](https://pkg.go.dev/github.com/vulcand/oxy/buffer) retries and buffers requests and responses
* [Stream](https://pkg.go.dev/github.com/vulcand/oxy/stream) passes requests through, supports chunked encoding with a configurable flush interval
* [Forward](https://pkg.go.dev/github.com/vulcand/oxy/forward) forwards requests to a remote location and rewrites headers
* [Roundrobin](https://pkg.go.dev/github.com/vulcand/oxy/roundrobin) is a round-robin load balancer
* [Circuit Breaker](https://pkg.go.dev/github.com/vulcand/oxy/cbreaker) Hystrix-style circuit breaker
* [Connlimit](https://pkg.go.dev/github.com/vulcand/oxy/connlimit) Simultaneous connections limiter
* [Ratelimit](https://pkg.go.dev/github.com/vulcand/oxy/ratelimit) Rate limiter (based on the token bucket algorithm)
* [Trace](https://pkg.go.dev/github.com/vulcand/oxy/trace) Structured request and response logger

It is designed to be fully compatible with the http standard library, and easy to customize and reuse.

Status
------

* Initial design is completed
* Covered by tests
* Used as a reverse proxy engine in [Vulcand](https://github.com/vulcand/vulcand)

Quickstart
-----------

Every handler is an ``http.Handler``, so writing and plugging in a middleware is easy. Let us write a simple reverse proxy as an example:

Simple reverse proxy
====================

```go
import (
	"net/http"

	"github.com/vulcand/oxy/forward"
	"github.com/vulcand/oxy/testutils"
)

// Forwards incoming requests to whatever location URL points to, adds proper forwarding headers
fwd, _ := forward.New()

redirect := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
	// let us forward this request to another server
	req.URL = testutils.ParseURI("http://localhost:63450")
	fwd.ServeHTTP(w, req)
})

// that's it! our reverse proxy is ready!
s := &http.Server{
	Addr:    ":8080",
	Handler: redirect,
}
s.ListenAndServe()
```

As a next step, let us add a round-robin load balancer:

```go
import (
	"net/http"

	"github.com/vulcand/oxy/forward"
	"github.com/vulcand/oxy/roundrobin"
)

// Forwards incoming requests to whatever location URL points to, adds proper forwarding headers
fwd, _ := forward.New()
lb, _ := roundrobin.New(fwd)

lb.UpsertServer(url1)
lb.UpsertServer(url2)

s := &http.Server{
	Addr:    ":8080",
	Handler: lb,
}
s.ListenAndServe()
```

What if we want to handle retries and replay the request in case of errors? The `buffer` handler will help:

```go
import (
	"net/http"

	"github.com/vulcand/oxy/forward"
	"github.com/vulcand/oxy/buffer"
	"github.com/vulcand/oxy/roundrobin"
)

// Forwards incoming requests to whatever location URL points to, adds proper forwarding headers
fwd, _ := forward.New()
lb, _ := roundrobin.New(fwd)

// buffer will read the request body and replay the request in case forward returned a status
// corresponding to a network error (e.g. Gateway Timeout)
buffer, _ := buffer.New(lb, buffer.Retry(`IsNetworkError() && Attempts() < 2`))

lb.UpsertServer(url1)
lb.UpsertServer(url2)

// that's it! our reverse proxy is ready!
s := &http.Server{
	Addr:    ":8080",
	Handler: buffer,
}
s.ListenAndServe()
```

oxy-1.3.0/buffer/buffer.go

/*
Package buffer provides http.Handler middleware that solves several problems when dealing with http requests:

Reads the entire request and response into a buffer, optionally buffering it to disk for large requests.
Checks the limits for the requests and responses, rejecting the request if a limit is exceeded.
Changes request content-transfer-encoding from chunked and provides the total size to the handlers.

Examples of a buffering middleware:

  // sample HTTP handler
  handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    w.Write([]byte("hello"))
  })

  // Buffer will read the body into a buffer before passing the request to the handler,
  // calculate the total size of the request and transform it from chunked encoding
  // before passing it to the server
  buffer.New(handler)

  // This version will buffer up to 2MB in memory and will serialize any extra
  // to a temporary file; if the request size exceeds 10MB it will reject the request
  buffer.New(handler,
    buffer.MemRequestBodyBytes(2 * 1024 * 1024),
    buffer.MaxRequestBodyBytes(10 * 1024 * 1024))

  // Will do the same as above, but with responses
  buffer.New(handler,
    buffer.MemResponseBodyBytes(2 * 1024 * 1024),
    buffer.MaxResponseBodyBytes(10 * 1024 * 1024))

  // Buffer will replay the request on network errors for up to 2 retry attempts
  // before returning the response
  buffer.New(handler, buffer.Retry(`IsNetworkError() && Attempts() <= 2`))
*/
package buffer

import (
	"bufio"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"reflect"

	"github.com/mailgun/multibuf"
	log "github.com/sirupsen/logrus"

	"github.com/vulcand/oxy/utils"
)

const (
	// DefaultMemBodyBytes Store up to 1MB in RAM
	DefaultMemBodyBytes = 1048576

	// DefaultMaxBodyBytes No limit by default
	DefaultMaxBodyBytes = -1

	// DefaultMaxRetryAttempts Maximum retry attempts
	DefaultMaxRetryAttempts = 10
)

var errHandler utils.ErrorHandler = &SizeErrHandler{}

// Buffer is responsible for buffering requests and responses.
// It buffers large requests and responses to disk.
type Buffer struct {
	maxRequestBodyBytes int64
	memRequestBodyBytes int64

	maxResponseBodyBytes int64
	memResponseBodyBytes int64

	retryPredicate hpredicate

	next       http.Handler
	errHandler utils.ErrorHandler

	log *log.Logger
}

// New returns a new buffer middleware. New() supports optional functional arguments.
func New(next http.Handler, setters ...optSetter) (*Buffer, error) {
	strm := &Buffer{
		next: next,

		maxRequestBodyBytes: DefaultMaxBodyBytes,
		memRequestBodyBytes: DefaultMemBodyBytes,

		maxResponseBodyBytes: DefaultMaxBodyBytes,
		memResponseBodyBytes: DefaultMemBodyBytes,

		log: log.StandardLogger(),
	}
	for _, s := range setters {
		if err := s(strm); err != nil {
			return nil, err
		}
	}
	if strm.errHandler == nil {
		strm.errHandler = errHandler
	}

	return strm, nil
}

// Logger defines the logger the buffer will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func Logger(l *log.Logger) optSetter {
	return func(b *Buffer) error {
		b.log = l
		return nil
	}
}

type optSetter func(b *Buffer) error

// CondSetter Conditional setter.
// ex: Cond(a > 4, MemRequestBodyBytes(a)) func CondSetter(condition bool, setter optSetter) optSetter { if !condition { // NoOp setter return func(*Buffer) error { return nil } } return setter } // Retry provides a predicate that allows buffer middleware to replay the request // if it matches certain condition, e.g. returns special error code. Available functions are: // // Attempts() - limits the amount of retry attempts // ResponseCode() - returns http response code // IsNetworkError() - tests if response code is related to networking error // // Example of the predicate: // // `Attempts() <= 2 && ResponseCode() == 502` // func Retry(predicate string) optSetter { return func(b *Buffer) error { p, err := parseExpression(predicate) if err != nil { return err } b.retryPredicate = p return nil } } // ErrorHandler sets error handler of the server func ErrorHandler(h utils.ErrorHandler) optSetter { return func(b *Buffer) error { b.errHandler = h return nil } } // MaxRequestBodyBytes sets the maximum request body size in bytes func MaxRequestBodyBytes(m int64) optSetter { return func(b *Buffer) error { if m < 0 { return fmt.Errorf("max bytes should be >= 0 got %d", m) } b.maxRequestBodyBytes = m return nil } } // MemRequestBodyBytes bytes sets the maximum request body to be stored in memory // buffer middleware will serialize the excess to disk. func MemRequestBodyBytes(m int64) optSetter { return func(b *Buffer) error { if m < 0 { return fmt.Errorf("mem bytes should be >= 0 got %d", m) } b.memRequestBodyBytes = m return nil } } // MaxResponseBodyBytes sets the maximum response body size in bytes func MaxResponseBodyBytes(m int64) optSetter { return func(b *Buffer) error { if m < 0 { return fmt.Errorf("max bytes should be >= 0 got %d", m) } b.maxResponseBodyBytes = m return nil } } // MemResponseBodyBytes sets the maximum response body to be stored in memory // buffer middleware will serialize the excess to disk. func MemResponseBodyBytes(m int64) optSetter { return func(b *Buffer) error { if m < 0 { return fmt.Errorf("mem bytes should be >= 0 got %d", m) } b.memResponseBodyBytes = m return nil } } // Wrap sets the next handler to be called by buffer handler. func (b *Buffer) Wrap(next http.Handler) error { b.next = next return nil } func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) { if b.log.Level >= log.DebugLevel { logEntry := b.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/buffer: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/buffer: completed ServeHttp on request") } if err := b.checkLimit(req); err != nil { b.log.Errorf("vulcand/oxy/buffer: request body over limit, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } // Read the body while keeping limits in mind. This reader controls the maximum bytes // to read into memory and disk. This reader returns an error if the total request size exceeds the // predefined MaxSizeBytes. 
This can occur if we got chunked request, in this case ContentLength would be set to -1 // and the reader would be unbounded bufio in the http.Server body, err := multibuf.New(req.Body, multibuf.MaxBytes(b.maxRequestBodyBytes), multibuf.MemBytes(b.memRequestBodyBytes)) if err != nil || body == nil { b.log.Errorf("vulcand/oxy/buffer: error when reading request body, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } // Set request body to buffered reader that can replay the read and execute Seek // Note that we don't change the original request body as it's handled by the http server // and we don't want to mess with standard library defer func() { if body != nil { errClose := body.Close() if errClose != nil { b.log.Errorf("vulcand/oxy/buffer: failed to close body, err: %v", errClose) } } }() // We need to set ContentLength based on known request size. The incoming request may have been // set without content length or using chunked TransferEncoding totalSize, err := body.Size() if err != nil { b.log.Errorf("vulcand/oxy/buffer: failed to get request size, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } if totalSize == 0 { body = nil } outreq := b.copyRequest(req, body, totalSize) attempt := 1 for { // We create a special writer that will limit the response size, buffer it to disk if necessary writer, err := multibuf.NewWriterOnce(multibuf.MaxBytes(b.maxResponseBodyBytes), multibuf.MemBytes(b.memResponseBodyBytes)) if err != nil { b.log.Errorf("vulcand/oxy/buffer: failed create response writer, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } // We are mimicking http.ResponseWriter to replace writer with our special writer bw := &bufferWriter{ header: make(http.Header), buffer: writer, responseWriter: w, log: b.log, } defer bw.Close() b.next.ServeHTTP(bw, outreq) if bw.hijacked { b.log.Debugf("vulcand/oxy/buffer: connection was hijacked downstream. 
Not taking any action in buffer.") return } var reader multibuf.MultiReader if bw.expectBody(outreq) { rdr, err := writer.Reader() if err != nil { b.log.Errorf("vulcand/oxy/buffer: failed to read response, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } defer rdr.Close() reader = rdr } if (b.retryPredicate == nil || attempt > DefaultMaxRetryAttempts) || !b.retryPredicate(&context{r: req, attempt: attempt, responseCode: bw.code}) { utils.CopyHeaders(w.Header(), bw.Header()) w.WriteHeader(bw.code) if reader != nil { io.Copy(w, reader) } return } attempt++ if body != nil { if _, err := body.Seek(0, 0); err != nil { b.log.Errorf("vulcand/oxy/buffer: failed to rewind response body, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } } outreq = b.copyRequest(req, body, totalSize) b.log.Debugf("vulcand/oxy/buffer: retry Request(%v %v) attempt %v", req.Method, req.URL, attempt) } } func (b *Buffer) copyRequest(req *http.Request, body io.ReadCloser, bodySize int64) *http.Request { o := *req o.URL = utils.CopyURL(req.URL) o.Header = make(http.Header) utils.CopyHeaders(o.Header, req.Header) o.ContentLength = bodySize // remove TransferEncoding that could have been previously set because we have transformed the request from chunked encoding o.TransferEncoding = []string{} // http.Transport will close the request body on any error, we are controlling the close process ourselves, so we override the closer here if body == nil { o.Body = ioutil.NopCloser(req.Body) } else { o.Body = ioutil.NopCloser(body.(io.Reader)) } return &o } func (b *Buffer) checkLimit(req *http.Request) error { if b.maxRequestBodyBytes <= 0 { return nil } if req.ContentLength > b.maxRequestBodyBytes { return &multibuf.MaxSizeReachedError{MaxSize: b.maxRequestBodyBytes} } return nil } type bufferWriter struct { header http.Header code int buffer multibuf.WriterOnce responseWriter http.ResponseWriter hijacked bool log *log.Logger } // RFC2616 #4.4 func (b *bufferWriter) expectBody(r *http.Request) bool { if r.Method == "HEAD" { return false } if (b.code >= 100 && b.code < 200) || b.code == 204 || b.code == 304 { return false } // refer to https://github.com/vulcand/oxy/issues/113 // if b.header.Get("Content-Length") == "" && b.header.Get("Transfer-Encoding") == "" { // return false // } if b.header.Get("Content-Length") == "0" { return false } return true } func (b *bufferWriter) Close() error { return b.buffer.Close() } func (b *bufferWriter) Header() http.Header { return b.header } func (b *bufferWriter) Write(buf []byte) (int, error) { length, err := b.buffer.Write(buf) if err != nil { // Since go1.11 (https://github.com/golang/go/commit/8f38f28222abccc505b9a1992deecfe3e2cb85de) // if the writer returns an error, the reverse proxy panics b.log.Error(err) length = len(buf) } return length, nil } // WriteHeader sets rw.Code. func (b *bufferWriter) WriteHeader(code int) { b.code = code } // CloseNotifier interface - this allows downstream connections to be terminated when the client terminates. func (b *bufferWriter) CloseNotify() <-chan bool { if cn, ok := b.responseWriter.(http.CloseNotifier); ok { return cn.CloseNotify() } b.log.Warningf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(b.responseWriter)) return make(<-chan bool) } // Hijack This allows connections to be hijacked for websockets for instance. 
func (b *bufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hi, ok := b.responseWriter.(http.Hijacker); ok { conn, rw, err := hi.Hijack() if err != nil { b.hijacked = true } return conn, rw, err } b.log.Warningf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(b.responseWriter)) return nil, nil, fmt.Errorf("the response writer wrapped in this proxy does not implement http.Hijacker. Its type is: %v", reflect.TypeOf(b.responseWriter)) } // SizeErrHandler Size error handler type SizeErrHandler struct{} func (e *SizeErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { if _, ok := err.(*multibuf.MaxSizeReachedError); ok { w.WriteHeader(http.StatusRequestEntityTooLarge) w.Write([]byte(http.StatusText(http.StatusRequestEntityTooLarge))) return } utils.DefaultHandler.ServeHTTP(w, req, err) } oxy-1.3.0/buffer/buffer_test.go000066400000000000000000000261261404246664300164620ustar00rootroot00000000000000package buffer import ( "bufio" "crypto/tls" "fmt" "io/ioutil" "net" "net/http" "net/http/httptest" "strconv" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/forward" "github.com/vulcand/oxy/testutils" "github.com/vulcand/oxy/utils" ) func TestSimple(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New() require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello", string(body)) } func TestChunkedEncodingSuccess(t *testing.T) { var reqBody string var contentLength int64 srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { body, err := ioutil.ReadAll(req.Body) require.NoError(t, err) reqBody = string(body) contentLength = req.ContentLength w.Write([]byte("hello")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New() require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() conn, err := net.Dial("tcp", testutils.ParseURI(proxy.URL).Host) require.NoError(t, err) fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: 127.0.0.1:8080\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n5\r\ntest1\r\n5\r\ntest2\r\n0\r\n\r\n") status, err := bufio.NewReader(conn).ReadString('\n') require.NoError(t, err) assert.Equal(t, "testtest1test2", reqBody) assert.Equal(t, "HTTP/1.1 200 OK\r\n", status) assert.EqualValues(t, len(reqBody), contentLength) } func TestChunkedEncodingLimitReached(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New() require.NoError(t, err) // this 
is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr, MemRequestBodyBytes(4), MaxRequestBodyBytes(8)) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() conn, err := net.Dial("tcp", testutils.ParseURI(proxy.URL).Host) require.NoError(t, err) fmt.Fprint(conn, "POST / HTTP/1.1\r\nHost: 127.0.0.1:8080\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n5\r\ntest1\r\n5\r\ntest2\r\n0\r\n\r\n") status, err := bufio.NewReader(conn).ReadString('\n') require.NoError(t, err) assert.Equal(t, "HTTP/1.1 413 Request Entity Too Large\r\n", status) } func TestChunkedResponse(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { h := w.(http.Hijacker) conn, _, _ := h.Hijack() fmt.Fprintf(conn, "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n5\r\ntest1\r\n5\r\ntest2\r\n0\r\n\r\n") conn.Close() }) defer srv.Close() fwd, err := forward.New() require.NoError(t, err) rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, "testtest1test2", string(body)) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, strconv.Itoa(len("testtest1test2")), re.Header.Get("Content-Length")) } func TestRequestLimitReached(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New() require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr, MaxRequestBodyBytes(4)) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, _, err := testutils.Get(proxy.URL, testutils.Body("this request is too long")) require.NoError(t, err) assert.Equal(t, http.StatusRequestEntityTooLarge, re.StatusCode) } func TestResponseLimitReached(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello, this response is too large")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New() require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr, MaxResponseBodyBytes(4)) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, re.StatusCode) } func TestFileStreamingResponse(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello, this response is too large to fit in memory")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New() require.NoError(t, err) // this is our redirect to server rdr := 
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr, MemResponseBodyBytes(4)) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello, this response is too large to fit in memory", string(body)) } func TestCustomErrorHandler(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello, this response is too large")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New() require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, req *http.Request, err error) { w.WriteHeader(http.StatusTeapot) w.Write([]byte(http.StatusText(http.StatusTeapot))) }) st, err := New(rdr, MaxResponseBodyBytes(4), ErrorHandler(errHandler)) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusTeapot, re.StatusCode) } func TestNotModified(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusNotModified) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New() require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusNotModified, re.StatusCode) } func TestNoBody(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusOK) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New() require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // Make sure that stream handler preserves TLS settings func TestPreservesTLS(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("ok")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New() require.NoError(t, err) var cs *tls.ConnectionState // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { cs = req.TLS req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) 
proxy := httptest.NewUnstartedServer(st) proxy.StartTLS() defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.NotNil(t, cs) } func TestNotNilBody(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New() require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) // During a request check if the request body is no nil before sending to the next middleware // Because we can send a POST request without body assert.NotNil(t, req.Body) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr, MaxRequestBodyBytes(10)) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello", string(body)) re, body, err = testutils.Post(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello", string(body)) } oxy-1.3.0/buffer/retry_test.go000066400000000000000000000051151404246664300163510ustar00rootroot00000000000000package buffer import ( "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/forward" "github.com/vulcand/oxy/roundrobin" "github.com/vulcand/oxy/testutils" ) func TestSuccess(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) defer srv.Close() lb, rt := newBufferMiddleware(t, `IsNetworkError() && Attempts() <= 2`) proxy := httptest.NewServer(rt) defer proxy.Close() require.NoError(t, lb.UpsertServer(testutils.ParseURI(srv.URL))) re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello", string(body)) } func TestRetryOnError(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) defer srv.Close() lb, rt := newBufferMiddleware(t, `IsNetworkError() && Attempts() <= 2`) proxy := httptest.NewServer(rt) defer proxy.Close() require.NoError(t, lb.UpsertServer(testutils.ParseURI("http://localhost:64321"))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(srv.URL))) re, body, err := testutils.Get(proxy.URL, testutils.Body("some request parameters")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello", string(body)) } func TestRetryExceedAttempts(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) defer srv.Close() lb, rt := newBufferMiddleware(t, `IsNetworkError() && Attempts() <= 2`) proxy := httptest.NewServer(rt) defer proxy.Close() require.NoError(t, lb.UpsertServer(testutils.ParseURI("http://localhost:64321"))) require.NoError(t, lb.UpsertServer(testutils.ParseURI("http://localhost:64322"))) require.NoError(t, lb.UpsertServer(testutils.ParseURI("http://localhost:64323"))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(srv.URL))) re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusBadGateway, re.StatusCode) } func newBufferMiddleware(t *testing.T, p string) 
(*roundrobin.RoundRobin, *Buffer) { // forwarder will proxy the request to whatever destination fwd, err := forward.New() require.NoError(t, err) // load balancer will round robin request lb, err := roundrobin.New(fwd) require.NoError(t, err) // stream handler will forward requests to redirect, make sure it uses files st, err := New(lb, Retry(p), MemRequestBodyBytes(1)) require.NoError(t, err) return lb, st } oxy-1.3.0/buffer/threshold.go000066400000000000000000000122521404246664300161410ustar00rootroot00000000000000package buffer import ( "fmt" "net/http" "github.com/vulcand/predicate" ) // IsValidExpression check if it's a valid expression func IsValidExpression(expr string) bool { _, err := parseExpression(expr) return err == nil } type context struct { r *http.Request attempt int responseCode int } type hpredicate func(*context) bool // Parses expression in the go language into Failover predicates func parseExpression(in string) (hpredicate, error) { p, err := predicate.NewParser(predicate.Def{ Operators: predicate.Operators{ AND: and, OR: or, EQ: eq, NEQ: neq, LT: lt, GT: gt, LE: le, GE: ge, }, Functions: map[string]interface{}{ "RequestMethod": requestMethod, "IsNetworkError": isNetworkError, "Attempts": attempts, "ResponseCode": responseCode, }, }) if err != nil { return nil, err } out, err := p.Parse(in) if err != nil { return nil, err } pr, ok := out.(hpredicate) if !ok { return nil, fmt.Errorf("expected predicate, got %T", out) } return pr, nil } type toString func(c *context) string type toInt func(c *context) int // RequestMethod returns mapper of the request to its method e.g. POST func requestMethod() toString { return func(c *context) string { return c.r.Method } } // Attempts returns mapper of the request to the number of proxy attempts func attempts() toInt { return func(c *context) int { return c.attempt } } // ResponseCode returns mapper of the request to the last response code, returns 0 if there was no response code. func responseCode() toInt { return func(c *context) int { return c.responseCode } } // IsNetworkError returns a predicate that returns true if last attempt ended with network error. 
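// In this package a "network error" means the last response code was
// 502 (Bad Gateway) or 504 (Gateway Timeout), as checked below.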
func isNetworkError() hpredicate { return func(c *context) bool { return c.responseCode == http.StatusBadGateway || c.responseCode == http.StatusGatewayTimeout } } // and returns predicate by joining the passed predicates with logical 'and' func and(fns ...hpredicate) hpredicate { return func(c *context) bool { for _, fn := range fns { if !fn(c) { return false } } return true } } // or returns predicate by joining the passed predicates with logical 'or' func or(fns ...hpredicate) hpredicate { return func(c *context) bool { for _, fn := range fns { if fn(c) { return true } } return false } } // not creates negation of the passed predicate func not(p hpredicate) hpredicate { return func(c *context) bool { return !p(c) } } // eq returns predicate that tests for equality of the value of the mapper and the constant func eq(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toString: return stringEQ(mapper, value) case toInt: return intEQ(mapper, value) } return nil, fmt.Errorf("unsupported argument: %T", m) } // neq returns predicate that tests for inequality of the value of the mapper and the constant func neq(m interface{}, value interface{}) (hpredicate, error) { p, err := eq(m, value) if err != nil { return nil, err } return not(p), nil } // lt returns predicate that tests that value of the mapper function is less than the constant func lt(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toInt: return intLT(mapper, value) } return nil, fmt.Errorf("unsupported argument: %T", m) } // le returns predicate that tests that value of the mapper function is less or equal than the constant func le(m interface{}, value interface{}) (hpredicate, error) { l, err := lt(m, value) if err != nil { return nil, err } e, err := eq(m, value) if err != nil { return nil, err } return func(c *context) bool { return l(c) || e(c) }, nil } // gt returns predicate that tests that value of the mapper function is greater than the constant func gt(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toInt: return intGT(mapper, value) } return nil, fmt.Errorf("unsupported argument: %T", m) } // ge returns predicate that tests that value of the mapper function is less or equal than the constant func ge(m interface{}, value interface{}) (hpredicate, error) { g, err := gt(m, value) if err != nil { return nil, err } e, err := eq(m, value) if err != nil { return nil, err } return func(c *context) bool { return g(c) || e(c) }, nil } func stringEQ(m toString, val interface{}) (hpredicate, error) { value, ok := val.(string) if !ok { return nil, fmt.Errorf("expected string, got %T", val) } return func(c *context) bool { return m(c) == value }, nil } func intEQ(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *context) bool { return m(c) == value }, nil } func intLT(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *context) bool { return m(c) < value }, nil } func intGT(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *context) bool { return m(c) > value }, nil } 
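The predicate helpers above back the expression language accepted by `buffer.Retry` (parsed by `parseExpression` in this package). Below is a minimal sketch of how a caller outside the package might validate and wire up such an expression; the handler, address, and expression are illustrative only, not part of the repository:

```go
package main

import (
	"log"
	"net/http"

	"github.com/vulcand/oxy/buffer"
)

func main() {
	// Hypothetical retry expression: retry on 502/504 for up to two extra attempts.
	expr := `IsNetworkError() && Attempts() <= 2`

	// IsValidExpression parses the expression without building the middleware,
	// which is handy for validating user-supplied configuration up front.
	if !buffer.IsValidExpression(expr) {
		log.Fatalf("invalid retry expression: %q", expr)
	}

	// Any http.Handler can sit behind the buffer; a trivial one is used here.
	next := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte("hello"))
	})

	b, err := buffer.New(next, buffer.Retry(expr))
	if err != nil {
		log.Fatal(err)
	}

	log.Fatal(http.ListenAndServe(":8080", b))
}
```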
oxy-1.3.0/cbreaker/000077500000000000000000000000001404246664300141215ustar00rootroot00000000000000oxy-1.3.0/cbreaker/cbreaker.go000066400000000000000000000236441404246664300162370ustar00rootroot00000000000000// Package cbreaker implements circuit breaker similar to https://github.com/Netflix/Hystrix/wiki/How-it-Works // // Vulcan circuit breaker watches the error condtion to match // after which it activates the fallback scenario, e.g. returns the response code // or redirects the request to another location // // Circuit breakers start in the Standby state first, observing responses and watching location metrics. // // Once the Circuit breaker condition is met, it enters the "Tripped" state, where it activates fallback scenario // for all requests during the FallbackDuration time period and reset the stats for the location. // // After FallbackDuration time period passes, Circuit breaker enters "Recovering" state, during that state it will // start passing some traffic back to the endpoints, increasing the amount of passed requests using linear function: // // allowedRequestsRatio = 0.5 * (Now() - StartRecovery())/RecoveryDuration // // Two scenarios are possible in the "Recovering" state: // 1. Condition matches again, this will reset the state to "Tripped" and reset the timer. // 2. Condition does not match, circuit breaker enters "Standby" state // // It is possible to define actions (e.g. webhooks) of transitions between states: // // * OnTripped action is called on transition (Standby -> Tripped) // * OnStandby action is called on transition (Recovering -> Standby) // package cbreaker import ( "fmt" "net/http" "sync" "time" "github.com/mailgun/timetools" log "github.com/sirupsen/logrus" "github.com/vulcand/oxy/memmetrics" "github.com/vulcand/oxy/utils" ) // CircuitBreaker is http.Handler that implements circuit breaker pattern type CircuitBreaker struct { m *sync.RWMutex metrics *memmetrics.RTMetrics condition hpredicate fallbackDuration time.Duration recoveryDuration time.Duration onTripped SideEffect onStandby SideEffect state cbState until time.Time rc *ratioController checkPeriod time.Duration lastCheck time.Time fallback http.Handler next http.Handler clock timetools.TimeProvider log *log.Logger } // New creates a new CircuitBreaker middleware func New(next http.Handler, expression string, options ...CircuitBreakerOption) (*CircuitBreaker, error) { cb := &CircuitBreaker{ m: &sync.RWMutex{}, next: next, // Default values. Might be overwritten by options below. clock: &timetools.RealTime{}, checkPeriod: defaultCheckPeriod, fallbackDuration: defaultFallbackDuration, recoveryDuration: defaultRecoveryDuration, fallback: defaultFallback, log: log.StandardLogger(), } for _, s := range options { if err := s(cb); err != nil { return nil, err } } condition, err := parseExpression(expression) if err != nil { return nil, err } cb.condition = condition mt, err := memmetrics.NewRTMetrics() if err != nil { return nil, err } cb.metrics = mt return cb, nil } // Logger defines the logger the circuit breaker will use. // // It defaults to logrus.StandardLogger(), the global logger used by logrus. 
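//
// Illustrative wiring (the handler and logger names are placeholders, not part of this package):
//
//	cb, _ := New(next, `NetworkErrorRatio() > 0.5`, Logger(customLogger))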
func Logger(l *log.Logger) CircuitBreakerOption { return func(c *CircuitBreaker) error { c.log = l return nil } } func (c *CircuitBreaker) ServeHTTP(w http.ResponseWriter, req *http.Request) { if c.log.Level >= log.DebugLevel { logEntry := c.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/circuitbreaker: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/circuitbreaker: completed ServeHttp on request") } if c.activateFallback(w, req) { c.fallback.ServeHTTP(w, req) return } c.serve(w, req) } // Fallback sets the fallback handler to be called by circuit breaker handler. func (c *CircuitBreaker) Fallback(f http.Handler) { c.fallback = f } // Wrap sets the next handler to be called by circuit breaker handler. func (c *CircuitBreaker) Wrap(next http.Handler) { c.next = next } // updateState updates internal state and returns true if fallback should be used and false otherwise func (c *CircuitBreaker) activateFallback(w http.ResponseWriter, req *http.Request) bool { // Quick check with read locks optimized for normal operation use-case if c.isStandby() { return false } // Circuit breaker is in tripped or recovering state c.m.Lock() defer c.m.Unlock() c.log.Warnf("%v is in error state", c) switch c.state { case stateStandby: // someone else has set it to standby just now return false case stateTripped: if c.clock.UtcNow().Before(c.until) { return true } // We have been in active state enough, enter recovering state c.setRecovering() fallthrough case stateRecovering: // We have been in recovering state enough, enter standby and allow request if c.clock.UtcNow().After(c.until) { c.setState(stateStandby, c.clock.UtcNow()) return false } // ratio controller allows this request if c.rc.allowRequest() { return false } return true } return false } func (c *CircuitBreaker) serve(w http.ResponseWriter, req *http.Request) { start := c.clock.UtcNow() p := utils.NewProxyWriterWithLogger(w, c.log) c.next.ServeHTTP(p, req) latency := c.clock.UtcNow().Sub(start) c.metrics.Record(p.StatusCode(), latency) // Note that this call is less expensive than it looks -- checkCondition only performs the real check // periodically. Because of that we can afford to call it here on every single response. 
c.checkAndSet() } func (c *CircuitBreaker) isStandby() bool { c.m.RLock() defer c.m.RUnlock() return c.state == stateStandby } // String returns log-friendly representation of the circuit breaker state func (c *CircuitBreaker) String() string { switch c.state { case stateTripped, stateRecovering: return fmt.Sprintf("CircuitBreaker(state=%v, until=%v)", c.state, c.until) default: return fmt.Sprintf("CircuitBreaker(state=%v)", c.state) } } // exec executes side effect func (c *CircuitBreaker) exec(s SideEffect) { if s == nil { return } go func() { if err := s.Exec(); err != nil { c.log.Errorf("%v side effect failure: %v", c, err) } }() } func (c *CircuitBreaker) setState(new cbState, until time.Time) { c.log.Debugf("%v setting state to %v, until %v", c, new, until) c.state = new c.until = until switch new { case stateTripped: c.exec(c.onTripped) case stateStandby: c.exec(c.onStandby) } } func (c *CircuitBreaker) timeToCheck() bool { c.m.RLock() defer c.m.RUnlock() return c.clock.UtcNow().After(c.lastCheck) } // Checks if tripping condition matches and sets circuit breaker to the tripped state func (c *CircuitBreaker) checkAndSet() { if !c.timeToCheck() { return } c.m.Lock() defer c.m.Unlock() // Other goroutine could have updated the lastCheck variable before we grabbed mutex if !c.clock.UtcNow().After(c.lastCheck) { return } c.lastCheck = c.clock.UtcNow().Add(c.checkPeriod) if c.state == stateTripped { c.log.Debugf("%v skip set tripped", c) return } if !c.condition(c) { return } c.setState(stateTripped, c.clock.UtcNow().Add(c.fallbackDuration)) c.metrics.Reset() } func (c *CircuitBreaker) setRecovering() { c.setState(stateRecovering, c.clock.UtcNow().Add(c.recoveryDuration)) c.rc = newRatioController(c.clock, c.recoveryDuration, c.log) } // CircuitBreakerOption represents an option you can pass to New. // See the documentation for the individual options below. type CircuitBreakerOption func(*CircuitBreaker) error // Clock allows you to fake che CircuitBreaker's view of the current time. // Intended for unit tests. func Clock(clock timetools.TimeProvider) CircuitBreakerOption { return func(c *CircuitBreaker) error { c.clock = clock return nil } } // FallbackDuration is how long the CircuitBreaker will remain in the Tripped // state before trying to recover. func FallbackDuration(d time.Duration) CircuitBreakerOption { return func(c *CircuitBreaker) error { c.fallbackDuration = d return nil } } // RecoveryDuration is how long the CircuitBreaker will take to ramp up // requests during the Recovering state. func RecoveryDuration(d time.Duration) CircuitBreakerOption { return func(c *CircuitBreaker) error { c.recoveryDuration = d return nil } } // CheckPeriod is how long the CircuitBreaker will wait between successive // checks of the breaker condition. func CheckPeriod(d time.Duration) CircuitBreakerOption { return func(c *CircuitBreaker) error { c.checkPeriod = d return nil } } // OnTripped sets a SideEffect to run when entering the Tripped state. // Only one SideEffect can be set for this hook. func OnTripped(s SideEffect) CircuitBreakerOption { return func(c *CircuitBreaker) error { c.onTripped = s return nil } } // OnStandby sets a SideEffect to run when entering the Standby state. // Only one SideEffect can be set for this hook. 
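//
// One possible SideEffect is a webhook built with NewWebhookSideEffect, as exercised in
// cbreaker_test.go; the URL below is only a placeholder:
//
//	hook, _ := NewWebhookSideEffect(Webhook{URL: "http://localhost:5000/standby", Method: http.MethodPost})
//	cb, _ := New(handler, `NetworkErrorRatio() > 0.5`, OnStandby(hook))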
func OnStandby(s SideEffect) CircuitBreakerOption { return func(c *CircuitBreaker) error { c.onStandby = s return nil } } // Fallback defines the http.Handler that the CircuitBreaker should route // requests to when it prevents a request from taking its normal path. func Fallback(h http.Handler) CircuitBreakerOption { return func(c *CircuitBreaker) error { c.fallback = h return nil } } // cbState is the state of the circuit breaker type cbState int func (s cbState) String() string { switch s { case stateStandby: return "standby" case stateTripped: return "tripped" case stateRecovering: return "recovering" } return "undefined" } const ( // CircuitBreaker is passing all requests and watching stats stateStandby = iota // CircuitBreaker activates fallback scenario for all requests stateTripped // CircuitBreaker passes some requests to go through, rejecting others stateRecovering ) const ( defaultFallbackDuration = 10 * time.Second defaultRecoveryDuration = 10 * time.Second defaultCheckPeriod = 100 * time.Millisecond ) var defaultFallback = &fallback{} type fallback struct{} func (f *fallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusServiceUnavailable) w.Write([]byte(http.StatusText(http.StatusServiceUnavailable))) } oxy-1.3.0/cbreaker/cbreaker_test.go000066400000000000000000000217511404246664300172730ustar00rootroot00000000000000package cbreaker import ( "fmt" "io/ioutil" "net/http" "net/http/httptest" "net/url" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/memmetrics" "github.com/vulcand/oxy/testutils" ) const triggerNetRatio = `NetworkErrorRatio() > 0.5` func TestStandbyCycle(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) cb, err := New(handler, triggerNetRatio) require.NoError(t, err) srv := httptest.NewServer(cb) defer srv.Close() re, body, err := testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello", string(body)) } func TestFullCycle(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) clock := testutils.GetClock() cb, err := New(handler, triggerNetRatio, Clock(clock)) require.NoError(t, err) srv := httptest.NewServer(cb) defer srv.Close() re, _, err := testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) cb.metrics = statsNetErrors(0.6) clock.CurrentTime = clock.CurrentTime.Add(defaultCheckPeriod + time.Millisecond) _, _, err = testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, cbState(stateTripped), cb.state) // Some time has passed, but we are still in trapped state. 
clock.CurrentTime = clock.CurrentTime.Add(9 * time.Second) re, _, err = testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, http.StatusServiceUnavailable, re.StatusCode) assert.Equal(t, cbState(stateTripped), cb.state) // We should be in recovering state by now clock.CurrentTime = clock.CurrentTime.Add(time.Second*1 + time.Millisecond) re, _, err = testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, http.StatusServiceUnavailable, re.StatusCode) assert.Equal(t, cbState(stateRecovering), cb.state) // 5 seconds after we should be allowing some requests to pass clock.CurrentTime = clock.CurrentTime.Add(5 * time.Second) allowed := 0 for i := 0; i < 100; i++ { re, _, err = testutils.Get(srv.URL) if re.StatusCode == http.StatusOK && err == nil { allowed++ } } assert.NotEqual(t, 0, allowed) // After some time, all is good and we should be in stand by mode again clock.CurrentTime = clock.CurrentTime.Add(5*time.Second + time.Millisecond) re, _, err = testutils.Get(srv.URL) assert.Equal(t, cbState(stateStandby), cb.state) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } func TestRedirectWithPath(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) fallbackRedirectPath, err := NewRedirectFallback(Redirect{ URL: "http://localhost:6000", PreservePath: true, }) require.NoError(t, err) cb, err := New(handler, triggerNetRatio, Clock(testutils.GetClock()), Fallback(fallbackRedirectPath)) require.NoError(t, err) srv := httptest.NewServer(cb) defer srv.Close() cb.metrics = statsNetErrors(0.6) _, _, err = testutils.Get(srv.URL) require.NoError(t, err) client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return fmt.Errorf("no redirects") }, } re, err := client.Get(srv.URL + "/somePath") require.Error(t, err) assert.Equal(t, http.StatusFound, re.StatusCode) assert.Equal(t, "http://localhost:6000/somePath", re.Header.Get("Location")) } func TestRedirect(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) fallbackRedirect, err := NewRedirectFallback(Redirect{URL: "http://localhost:5000"}) require.NoError(t, err) cb, err := New(handler, triggerNetRatio, Clock(testutils.GetClock()), Fallback(fallbackRedirect)) require.NoError(t, err) srv := httptest.NewServer(cb) defer srv.Close() cb.metrics = statsNetErrors(0.6) _, _, err = testutils.Get(srv.URL) require.NoError(t, err) client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return fmt.Errorf("no redirects") }, } re, err := client.Get(srv.URL + "/somePath") require.Error(t, err) assert.Equal(t, http.StatusFound, re.StatusCode) assert.Equal(t, "http://localhost:5000", re.Header.Get("Location")) } func TestTriggerDuringRecovery(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) clock := testutils.GetClock() cb, err := New(handler, triggerNetRatio, Clock(clock), CheckPeriod(time.Microsecond)) require.NoError(t, err) srv := httptest.NewServer(cb) defer srv.Close() cb.metrics = statsNetErrors(0.6) _, _, err = testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, cbState(stateTripped), cb.state) // We should be in recovering state by now clock.CurrentTime = clock.CurrentTime.Add(10*time.Second + time.Millisecond) re, _, err := testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, http.StatusServiceUnavailable, re.StatusCode) 
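// Right after entering recovery the ratio controller allows almost nothing through, so the fallback (503) is still served.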
assert.Equal(t, cbState(stateRecovering), cb.state) // We have matched error condition during recovery state and are going back to tripped state clock.CurrentTime = clock.CurrentTime.Add(5 * time.Second) cb.metrics = statsNetErrors(0.6) allowed := 0 for i := 0; i < 100; i++ { re, _, err = testutils.Get(srv.URL) if re.StatusCode == http.StatusOK && err == nil { allowed++ } } assert.NotEqual(t, 0, allowed) assert.Equal(t, cbState(stateTripped), cb.state) } func TestSideEffects(t *testing.T) { srv1Chan := make(chan *http.Request, 1) var srv1Body []byte srv1 := testutils.NewHandler(func(w http.ResponseWriter, r *http.Request) { b, err := ioutil.ReadAll(r.Body) require.NoError(t, err) srv1Body = b w.Write([]byte("srv1")) srv1Chan <- r }) defer srv1.Close() srv2Chan := make(chan *http.Request, 1) srv2 := testutils.NewHandler(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("srv2")) err := r.ParseForm() require.NoError(t, err) srv2Chan <- r }) defer srv2.Close() onTripped, err := NewWebhookSideEffect( Webhook{ URL: fmt.Sprintf("%s/post.json", srv1.URL), Method: http.MethodPost, Headers: map[string][]string{"Content-Type": {"application/json"}}, Body: []byte(`{"Key": ["val1", "val2"]}`), }) require.NoError(t, err) onStandby, err := NewWebhookSideEffect( Webhook{ URL: fmt.Sprintf("%s/post", srv2.URL), Method: http.MethodPost, Form: map[string][]string{"key": {"val1", "val2"}}, }) require.NoError(t, err) handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) clock := testutils.GetClock() cb, err := New(handler, triggerNetRatio, Clock(clock), CheckPeriod(time.Microsecond), OnTripped(onTripped), OnStandby(onStandby)) require.NoError(t, err) srv := httptest.NewServer(cb) defer srv.Close() cb.metrics = statsNetErrors(0.6) _, _, err = testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, cbState(stateTripped), cb.state) select { case req := <-srv1Chan: assert.Equal(t, http.MethodPost, req.Method) assert.Equal(t, "/post.json", req.URL.Path) assert.Equal(t, `{"Key": ["val1", "val2"]}`, string(srv1Body)) assert.Equal(t, "application/json", req.Header.Get("Content-Type")) case <-time.After(time.Second): t.Error("timeout waiting for side effect to kick off") } // Transition to recovering state clock.CurrentTime = clock.CurrentTime.Add(10*time.Second + time.Millisecond) cb.metrics = statsOK() _, _, err = testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, cbState(stateRecovering), cb.state) // Going back to standby clock.CurrentTime = clock.CurrentTime.Add(10*time.Second + time.Millisecond) _, _, err = testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, cbState(stateStandby), cb.state) select { case req := <-srv2Chan: assert.Equal(t, http.MethodPost, req.Method) assert.Equal(t, "/post", req.URL.Path) assert.Equal(t, url.Values{"key": []string{"val1", "val2"}}, req.Form) case <-time.After(time.Second): t.Error("timeout waiting for side effect to kick off") } } func statsOK() *memmetrics.RTMetrics { m, err := memmetrics.NewRTMetrics() if err != nil { panic(err) } return m } func statsNetErrors(threshold float64) *memmetrics.RTMetrics { m, err := memmetrics.NewRTMetrics() if err != nil { panic(err) } for i := 0; i < 100; i++ { if i < int(threshold*100) { m.Record(http.StatusGatewayTimeout, 0) } else { m.Record(http.StatusOK, 0) } } return m } func statsLatencyAtQuantile(_ float64, value time.Duration) *memmetrics.RTMetrics { m, err := memmetrics.NewRTMetrics() if err != nil { panic(err) } m.Record(http.StatusOK, value) 
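// With a single recorded sample, every quantile of the latency histogram effectively resolves to value.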
return m } func statsResponseCodes(codes ...statusCode) *memmetrics.RTMetrics { m, err := memmetrics.NewRTMetrics() if err != nil { panic(err) } for _, c := range codes { for i := int64(0); i < c.Count; i++ { m.Record(c.Code, 0) } } return m } type statusCode struct { Code int Count int64 } oxy-1.3.0/cbreaker/effect.go000066400000000000000000000034451404246664300157120ustar00rootroot00000000000000package cbreaker import ( "bytes" "fmt" "io" "io/ioutil" "net/http" "net/url" "strings" log "github.com/sirupsen/logrus" "github.com/vulcand/oxy/utils" ) // SideEffect a side effect type SideEffect interface { Exec() error } // Webhook Web hook type Webhook struct { URL string Method string Headers http.Header Form url.Values Body []byte } // WebhookSideEffect a web hook side effect type WebhookSideEffect struct { w Webhook log *log.Logger } // NewWebhookSideEffectsWithLogger creates a new WebhookSideEffect func NewWebhookSideEffectsWithLogger(w Webhook, l *log.Logger) (*WebhookSideEffect, error) { if w.Method == "" { return nil, fmt.Errorf("Supply method") } _, err := url.Parse(w.URL) if err != nil { return nil, err } return &WebhookSideEffect{w: w, log: l}, nil } // NewWebhookSideEffect creates a new WebhookSideEffect func NewWebhookSideEffect(w Webhook) (*WebhookSideEffect, error) { return NewWebhookSideEffectsWithLogger(w, log.StandardLogger()) } func (w *WebhookSideEffect) getBody() io.Reader { if len(w.w.Form) != 0 { return strings.NewReader(w.w.Form.Encode()) } if len(w.w.Body) != 0 { return bytes.NewBuffer(w.w.Body) } return nil } // Exec execute the side effect func (w *WebhookSideEffect) Exec() error { r, err := http.NewRequest(w.w.Method, w.w.URL, w.getBody()) if err != nil { return err } if len(w.w.Headers) != 0 { utils.CopyHeaders(r.Header, w.w.Headers) } if len(w.w.Form) != 0 { r.Header.Set("Content-Type", "application/x-www-form-urlencoded") } re, err := http.DefaultClient.Do(r) if err != nil { return err } if re.Body != nil { defer re.Body.Close() } body, err := ioutil.ReadAll(re.Body) if err != nil { return err } w.log.Debugf("%v got response: (%s): %s", w, re.Status, string(body)) return nil } oxy-1.3.0/cbreaker/fallback.go000066400000000000000000000054651404246664300162210ustar00rootroot00000000000000package cbreaker import ( "fmt" "net/http" "net/url" "strconv" log "github.com/sirupsen/logrus" "github.com/vulcand/oxy/utils" ) // Response response model type Response struct { StatusCode int ContentType string Body []byte } // ResponseFallback fallback response handler type ResponseFallback struct { r Response log *log.Logger } // NewResponseFallbackWithLogger creates a new ResponseFallback func NewResponseFallbackWithLogger(r Response, l *log.Logger) (*ResponseFallback, error) { if r.StatusCode == 0 { return nil, fmt.Errorf("response code should not be 0") } return &ResponseFallback{r: r, log: l}, nil } // NewResponseFallback creates a new ResponseFallback func NewResponseFallback(r Response) (*ResponseFallback, error) { return NewResponseFallbackWithLogger(r, log.StandardLogger()) } func (f *ResponseFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { if f.log.Level >= log.DebugLevel { logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/fallback/response: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/fallback/response: completed ServeHttp on request") } if f.r.ContentType != "" { w.Header().Set("Content-Type", f.r.ContentType) } w.Header().Set("Content-Length", strconv.Itoa(len(f.r.Body))) 
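// Headers must be set before WriteHeader; anything added afterwards is ignored by net/http.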
w.WriteHeader(f.r.StatusCode) _, err := w.Write(f.r.Body) if err != nil { f.log.Errorf("vulcand/oxy/fallback/response: failed to write response, err: %v", err) } } // Redirect redirect model type Redirect struct { URL string PreservePath bool } // RedirectFallback fallback redirect handler type RedirectFallback struct { r Redirect u *url.URL log *log.Logger } // NewRedirectFallbackWithLogger creates a new RedirectFallback func NewRedirectFallbackWithLogger(r Redirect, l *log.Logger) (*RedirectFallback, error) { u, err := url.ParseRequestURI(r.URL) if err != nil { return nil, err } return &RedirectFallback{r: r, u: u, log: l}, nil } // NewRedirectFallback creates a new RedirectFallback func NewRedirectFallback(r Redirect) (*RedirectFallback, error) { return NewRedirectFallbackWithLogger(r, log.StandardLogger()) } func (f *RedirectFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { if f.log.Level >= log.DebugLevel { logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/fallback/redirect: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/fallback/redirect: completed ServeHttp on request") } location := f.u.String() if f.r.PreservePath { location += req.URL.Path } w.Header().Set("Location", location) w.WriteHeader(http.StatusFound) _, err := w.Write([]byte(http.StatusText(http.StatusFound))) if err != nil { f.log.Errorf("vulcand/oxy/fallback/redirect: failed to write response, err: %v", err) } } oxy-1.3.0/cbreaker/predicates.go000066400000000000000000000126561404246664300166050ustar00rootroot00000000000000package cbreaker import ( "fmt" "time" "github.com/vulcand/predicate" ) type hpredicate func(*CircuitBreaker) bool // parseExpression parses expression in the go language into predicates. 
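// For example, `NetworkErrorRatio() > 0.5` or `LatencyAtQuantileMS(50.0) > 100`; conditions can be combined with && and ||, which map to the AND/OR operators registered below.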
func parseExpression(in string) (hpredicate, error) { p, err := predicate.NewParser(predicate.Def{ Operators: predicate.Operators{ AND: and, OR: or, EQ: eq, NEQ: neq, LT: lt, LE: le, GT: gt, GE: ge, }, Functions: map[string]interface{}{ "LatencyAtQuantileMS": latencyAtQuantile, "NetworkErrorRatio": networkErrorRatio, "ResponseCodeRatio": responseCodeRatio, }, }) if err != nil { return nil, err } out, err := p.Parse(in) if err != nil { return nil, err } pr, ok := out.(hpredicate) if !ok { return nil, fmt.Errorf("expected predicate, got %T", out) } return pr, nil } type toInt func(c *CircuitBreaker) int type toFloat64 func(c *CircuitBreaker) float64 func latencyAtQuantile(quantile float64) toInt { return func(c *CircuitBreaker) int { h, err := c.metrics.LatencyHistogram() if err != nil { c.log.Errorf("Failed to get latency histogram, for %v error: %v", c, err) return 0 } return int(h.LatencyAtQuantile(quantile) / time.Millisecond) } } func networkErrorRatio() toFloat64 { return func(c *CircuitBreaker) float64 { return c.metrics.NetworkErrorRatio() } } func responseCodeRatio(startA, endA, startB, endB int) toFloat64 { return func(c *CircuitBreaker) float64 { return c.metrics.ResponseCodeRatio(startA, endA, startB, endB) } } // or returns predicate by joining the passed predicates with logical 'or' func or(fns ...hpredicate) hpredicate { return func(c *CircuitBreaker) bool { for _, fn := range fns { if fn(c) { return true } } return false } } // and returns predicate by joining the passed predicates with logical 'and' func and(fns ...hpredicate) hpredicate { return func(c *CircuitBreaker) bool { for _, fn := range fns { if !fn(c) { return false } } return true } } // not creates negation of the passed predicate func not(p hpredicate) hpredicate { return func(c *CircuitBreaker) bool { return !p(c) } } // eq returns predicate that tests for equality of the value of the mapper and the constant func eq(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toInt: return intEQ(mapper, value) case toFloat64: return float64EQ(mapper, value) } return nil, fmt.Errorf("eq: unsupported argument: %T", m) } // neq returns predicate that tests for inequality of the value of the mapper and the constant func neq(m interface{}, value interface{}) (hpredicate, error) { p, err := eq(m, value) if err != nil { return nil, err } return not(p), nil } // lt returns predicate that tests that value of the mapper function is less than the constant func lt(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toInt: return intLT(mapper, value) case toFloat64: return float64LT(mapper, value) } return nil, fmt.Errorf("lt: unsupported argument: %T", m) } // le returns predicate that tests that value of the mapper function is less or equal than the constant func le(m interface{}, value interface{}) (hpredicate, error) { l, err := lt(m, value) if err != nil { return nil, err } e, err := eq(m, value) if err != nil { return nil, err } return func(c *CircuitBreaker) bool { return l(c) || e(c) }, nil } // gt returns predicate that tests that value of the mapper function is greater than the constant func gt(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toInt: return intGT(mapper, value) case toFloat64: return float64GT(mapper, value) } return nil, fmt.Errorf("gt: unsupported argument: %T", m) } // ge returns predicate that tests that value of the mapper function is greater or equal than the constant func ge(m interface{},
value interface{}) (hpredicate, error) { g, err := gt(m, value) if err != nil { return nil, err } e, err := eq(m, value) if err != nil { return nil, err } return func(c *CircuitBreaker) bool { return g(c) || e(c) }, nil } func intEQ(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *CircuitBreaker) bool { return m(c) == value }, nil } func float64EQ(m toFloat64, val interface{}) (hpredicate, error) { value, ok := val.(float64) if !ok { return nil, fmt.Errorf("expected float64, got %T", val) } return func(c *CircuitBreaker) bool { return m(c) == value }, nil } func intLT(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *CircuitBreaker) bool { return m(c) < value }, nil } func intGT(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *CircuitBreaker) bool { return m(c) > value }, nil } func float64LT(m toFloat64, val interface{}) (hpredicate, error) { value, ok := val.(float64) if !ok { return nil, fmt.Errorf("expected float64, got %T", val) } return func(c *CircuitBreaker) bool { return m(c) < value }, nil } func float64GT(m toFloat64, val interface{}) (hpredicate, error) { value, ok := val.(float64) if !ok { return nil, fmt.Errorf("expected float64, got %T", val) } return func(c *CircuitBreaker) bool { return m(c) > value }, nil } oxy-1.3.0/cbreaker/predicates_test.go000066400000000000000000000031421404246664300176320ustar00rootroot00000000000000package cbreaker import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/memmetrics" ) func TestTripped(t *testing.T) { testCases := []struct { expression string metrics *memmetrics.RTMetrics expected bool }{ { expression: "NetworkErrorRatio() > 0.5", metrics: statsNetErrors(0.6), expected: true, }, { expression: "NetworkErrorRatio() < 0.5", metrics: statsNetErrors(0.6), expected: false, }, { expression: "LatencyAtQuantileMS(50.0) > 50", metrics: statsLatencyAtQuantile(50, time.Millisecond*51), expected: true, }, { expression: "LatencyAtQuantileMS(50.0) < 50", metrics: statsLatencyAtQuantile(50, time.Millisecond*51), expected: false, }, { expression: "ResponseCodeRatio(500, 600, 0, 600) > 0.5", metrics: statsResponseCodes(statusCode{Code: 200, Count: 5}, statusCode{Code: 500, Count: 6}), expected: true, }, { expression: "ResponseCodeRatio(500, 600, 0, 600) > 0.5", metrics: statsResponseCodes(statusCode{Code: 200, Count: 5}, statusCode{Code: 500, Count: 4}), expected: false, }, { // quantile not defined expression: "LatencyAtQuantileMS(40.0) > 50", metrics: statsNetErrors(0.6), expected: false, }, } for _, test := range testCases { test := test t.Run(test.expression, func(t *testing.T) { t.Parallel() p, err := parseExpression(test.expression) require.NoError(t, err) require.NotNil(t, p) assert.Equal(t, test.expected, p(&CircuitBreaker{metrics: test.metrics})) }) } } oxy-1.3.0/cbreaker/ratio.go000066400000000000000000000036331404246664300155730ustar00rootroot00000000000000package cbreaker import ( "fmt" "time" "github.com/mailgun/timetools" log "github.com/sirupsen/logrus" ) // ratioController allows passing portions of the traffic back to the endpoints, // increasing the amount of passed requests using linear function: // // allowedRequestsRatio = 0.5 * (Now() - Start())/Duration // type ratioController
struct { duration time.Duration start time.Time tm timetools.TimeProvider allowed int denied int log *log.Logger } func newRatioController(tm timetools.TimeProvider, rampUp time.Duration, log *log.Logger) *ratioController { return &ratioController{ duration: rampUp, tm: tm, start: tm.UtcNow(), log: log, } } func (r *ratioController) String() string { return fmt.Sprintf("RatioController(target=%f, current=%f, allowed=%d, denied=%d)", r.targetRatio(), r.computeRatio(r.allowed, r.denied), r.allowed, r.denied) } func (r *ratioController) allowRequest() bool { r.log.Debugf("%v", r) t := r.targetRatio() // This condition answers the question - would we satisfy the target ratio if we allow this request? e := r.computeRatio(r.allowed+1, r.denied) if e < t { r.allowed++ r.log.Debugf("%v allowed", r) return true } r.denied++ r.log.Debugf("%v denied", r) return false } func (r *ratioController) computeRatio(allowed, denied int) float64 { if denied+allowed == 0 { return 0 } return float64(allowed) / float64(denied+allowed) } func (r *ratioController) targetRatio() float64 { // Here's why it's 0.5: // We are watching the following ratio // ratio = a / (a + d) // We can notice, that once we get to 0.5 // 0.5 = a / (a + d) // we can evaluate that a = d // that means equilibrium, where we would allow all the requests // after this point to achieve ratio of 1 (that can never be reached unless d is 0) // so we stop from there multiplier := 0.5 / float64(r.duration) return multiplier * float64(r.tm.UtcNow().Sub(r.start)) } oxy-1.3.0/cbreaker/ratio_test.go000066400000000000000000000021711404246664300166260ustar00rootroot00000000000000package cbreaker import ( "math" "testing" "time" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/vulcand/oxy/testutils" ) func TestRampUp(t *testing.T) { clock := testutils.GetClock() duration := 10 * time.Second rc := newRatioController(clock, duration, log.StandardLogger()) allowed, denied := 0, 0 for i := 0; i < int(duration/time.Millisecond); i++ { ratio := sendRequest(&allowed, &denied, rc) expected := rc.targetRatio() diff := math.Abs(expected - ratio) assert.EqualValues(t, 0, round(diff, 0.5, 1)) clock.CurrentTime = clock.CurrentTime.Add(time.Millisecond) } } func sendRequest(allowed, denied *int, rc *ratioController) float64 { if rc.allowRequest() { *allowed++ } else { *denied++ } if *allowed+*denied == 0 { return 0 } return float64(*allowed) / float64(*allowed+*denied) } func round(val float64, roundOn float64, places int) float64 { pow := math.Pow(10, float64(places)) digit := pow * val _, div := math.Modf(digit) var round float64 if div >= roundOn { round = math.Ceil(digit) } else { round = math.Floor(digit) } return round / pow } oxy-1.3.0/connlimit/000077500000000000000000000000001404246664300143375ustar00rootroot00000000000000oxy-1.3.0/connlimit/connlimit.go000066400000000000000000000073431404246664300166710ustar00rootroot00000000000000// Package connlimit provides control over simultaneous connections coming from the same source package connlimit import ( "fmt" "net/http" "sync" log "github.com/sirupsen/logrus" "github.com/vulcand/oxy/utils" ) // ConnLimiter tracks concurrent connection per token // and is capable of rejecting connections if they are failed type ConnLimiter struct { mutex *sync.Mutex extract utils.SourceExtractor connections map[string]int64 maxConnections int64 totalConnections int64 next http.Handler errHandler utils.ErrorHandler log *log.Logger } // New creates a new ConnLimiter func New(next http.Handler, 
extract utils.SourceExtractor, maxConnections int64, options ...ConnLimitOption) (*ConnLimiter, error) { if extract == nil { return nil, fmt.Errorf("Extract function can not be nil") } cl := &ConnLimiter{ mutex: &sync.Mutex{}, extract: extract, maxConnections: maxConnections, connections: make(map[string]int64), next: next, log: log.StandardLogger(), } for _, o := range options { if err := o(cl); err != nil { return nil, err } } if cl.errHandler == nil { cl.errHandler = &ConnErrHandler{ log: cl.log, } } return cl, nil } // Logger defines the logger the connection limiter will use. // // It defaults to logrus.StandardLogger(), the global logger used by logrus. func Logger(l *log.Logger) ConnLimitOption { return func(cl *ConnLimiter) error { cl.log = l return nil } } // Wrap sets the next handler to be called by connexion limiter handler. func (cl *ConnLimiter) Wrap(h http.Handler) { cl.next = h } func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) { token, amount, err := cl.extract.Extract(r) if err != nil { cl.log.Errorf("failed to extract source of the connection: %v", err) cl.errHandler.ServeHTTP(w, r, err) return } if err := cl.acquire(token, amount); err != nil { cl.log.Debugf("limiting request source %s: %v", token, err) cl.errHandler.ServeHTTP(w, r, err) return } defer cl.release(token, amount) cl.next.ServeHTTP(w, r) } func (cl *ConnLimiter) acquire(token string, amount int64) error { cl.mutex.Lock() defer cl.mutex.Unlock() connections := cl.connections[token] if connections >= cl.maxConnections { return &MaxConnError{max: cl.maxConnections} } cl.connections[token] += amount cl.totalConnections += amount return nil } func (cl *ConnLimiter) release(token string, amount int64) { cl.mutex.Lock() defer cl.mutex.Unlock() cl.connections[token] -= amount cl.totalConnections -= amount // Otherwise it would grow forever if cl.connections[token] == 0 { delete(cl.connections, token) } } // MaxConnError maximum connections reached error type MaxConnError struct { max int64 } func (m *MaxConnError) Error() string { return fmt.Sprintf("max connections reached: %d", m.max) } // ConnErrHandler connection limiter error handler type ConnErrHandler struct { log *log.Logger } func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { if e.log.Level >= log.DebugLevel { logEntry := e.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/connlimit: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/connlimit: completed ServeHttp on request") } if _, ok := err.(*MaxConnError); ok { w.WriteHeader(429) w.Write([]byte(err.Error())) return } utils.DefaultHandler.ServeHTTP(w, req, err) } // ConnLimitOption connection limit option type type ConnLimitOption func(l *ConnLimiter) error // ErrorHandler sets error handler of the server func ErrorHandler(h utils.ErrorHandler) ConnLimitOption { return func(cl *ConnLimiter) error { cl.errHandler = h return nil } } oxy-1.3.0/connlimit/connlimit_test.go000066400000000000000000000057561404246664300177360ustar00rootroot00000000000000package connlimit import ( "fmt" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/testutils" "github.com/vulcand/oxy/utils" ) // We've hit the limit and were able to proceed once the request has completed func TestHitLimitAndRelease(t *testing.T) { wait := make(chan bool) proceed := make(chan bool) finish := make(chan bool) handler := http.HandlerFunc(func(w 
http.ResponseWriter, req *http.Request) { fmt.Println(req.Header) if req.Header.Get("Wait") != "" { proceed <- true <-wait } w.Write([]byte("hello")) }) cl, err := New(handler, headerLimit, 1) require.NoError(t, err) srv := httptest.NewServer(cl) defer srv.Close() go func() { re, _, errGet := testutils.Get(srv.URL, testutils.Header("Limit", "a"), testutils.Header("wait", "yes")) require.NoError(t, errGet) assert.Equal(t, http.StatusOK, re.StatusCode) finish <- true }() <-proceed re, _, err := testutils.Get(srv.URL, testutils.Header("Limit", "a")) require.NoError(t, err) assert.Equal(t, http.StatusTooManyRequests, re.StatusCode) // request from another source succeeds re, _, err = testutils.Get(srv.URL, testutils.Header("Limit", "b")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) // Once the first request finished, next one succeeds close(wait) <-finish re, _, err = testutils.Get(srv.URL, testutils.Header("Limit", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // We've hit the limit and were able to proceed once the request has completed func TestCustomHandlers(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, req *http.Request, err error) { w.WriteHeader(http.StatusTeapot) w.Write([]byte(http.StatusText(http.StatusTeapot))) }) l, err := New(handler, headerLimit, 0, ErrorHandler(errHandler)) require.NoError(t, err) srv := httptest.NewServer(l) defer srv.Close() re, _, err := testutils.Get(srv.URL, testutils.Header("Limit", "a")) require.NoError(t, err) assert.Equal(t, http.StatusTeapot, re.StatusCode) } // We've hit the limit and were able to proceed once the request has completed func TestFaultyExtract(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) l, err := New(handler, faultyExtract, 1) require.NoError(t, err) srv := httptest.NewServer(l) defer srv.Close() re, _, err := testutils.Get(srv.URL) require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, re.StatusCode) } func headerLimiter(req *http.Request) (string, int64, error) { return req.Header.Get("Limit"), 1, nil } func faultyExtractor(_ *http.Request) (string, int64, error) { return "", -1, fmt.Errorf("oops") } var headerLimit = utils.ExtractorFunc(headerLimiter) var faultyExtract = utils.ExtractorFunc(faultyExtractor) oxy-1.3.0/forward/000077500000000000000000000000001404246664300140075ustar00rootroot00000000000000oxy-1.3.0/forward/fwd.go000066400000000000000000000364671404246664300151360ustar00rootroot00000000000000// Package forward implements http handler that forwards requests to remote server // and serves back the response // websocket proxying support based on https://github.com/yhat/wsutil package forward import ( "bytes" "crypto/tls" "errors" "fmt" "io" "net" "net/http" "net/http/httputil" "net/url" "os" "reflect" "strings" "time" "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" "github.com/vulcand/oxy/utils" ) // OxyLogger interface of the internal type OxyLogger interface { log.FieldLogger GetLevel() log.Level } type internalLogger struct { *log.Logger } func (i *internalLogger) GetLevel() log.Level { return i.Level } // ReqRewriter can alter request headers and body type ReqRewriter interface { Rewrite(r *http.Request) } type optSetter func(f *Forwarder) error // PassHostHeader specifies if a client's Host header field should be 
delegated func PassHostHeader(b bool) optSetter { return func(f *Forwarder) error { f.httpForwarder.passHost = b return nil } } // RoundTripper sets a new http.RoundTripper // Forwarder will use http.DefaultTransport as a default round tripper func RoundTripper(r http.RoundTripper) optSetter { return func(f *Forwarder) error { f.httpForwarder.roundTripper = r return nil } } // Rewriter defines a request rewriter for the HTTP forwarder func Rewriter(r ReqRewriter) optSetter { return func(f *Forwarder) error { f.httpForwarder.rewriter = r return nil } } // WebsocketTLSClientConfig define the websocker client TLS configuration func WebsocketTLSClientConfig(tcc *tls.Config) optSetter { return func(f *Forwarder) error { f.httpForwarder.tlsClientConfig = tcc return nil } } // ErrorHandler is a functional argument that sets error handler of the server func ErrorHandler(h utils.ErrorHandler) optSetter { return func(f *Forwarder) error { f.errHandler = h return nil } } // BufferPool specifies a buffer pool for httputil.ReverseProxy. func BufferPool(pool httputil.BufferPool) optSetter { return func(f *Forwarder) error { f.bufferPool = pool return nil } } // Stream specifies if HTTP responses should be streamed. func Stream(stream bool) optSetter { return func(f *Forwarder) error { f.stream = stream return nil } } // Logger defines the logger the forwarder will use. // // It defaults to logrus.StandardLogger(), the global logger used by logrus. func Logger(l log.FieldLogger) optSetter { return func(f *Forwarder) error { if logger, ok := l.(OxyLogger); ok { f.log = logger return nil } if logger, ok := l.(*log.Logger); ok { f.log = &internalLogger{Logger: logger} return nil } return errors.New("the type of the logger must be OxyLogger or logrus.Logger") } } // StateListener defines a state listener for the HTTP forwarder func StateListener(stateListener UrlForwardingStateListener) optSetter { return func(f *Forwarder) error { f.stateListener = stateListener return nil } } // WebsocketConnectionClosedHook defines a hook called when websocket connection is closed func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn)) optSetter { return func(f *Forwarder) error { f.httpForwarder.websocketConnectionClosedHook = hook return nil } } // ResponseModifier defines a response modifier for the HTTP forwarder func ResponseModifier(responseModifier func(*http.Response) error) optSetter { return func(f *Forwarder) error { f.httpForwarder.modifyResponse = responseModifier return nil } } // StreamingFlushInterval defines a streaming flush interval for the HTTP forwarder func StreamingFlushInterval(flushInterval time.Duration) optSetter { return func(f *Forwarder) error { f.httpForwarder.flushInterval = flushInterval return nil } } // Forwarder wraps two traffic forwarding implementations: HTTP and websockets. 
// It decides based on the specified request which implementation to use type Forwarder struct { *httpForwarder *handlerContext stateListener UrlForwardingStateListener stream bool } // handlerContext defines a handler context for error reporting and logging type handlerContext struct { errHandler utils.ErrorHandler } // httpForwarder is a handler that can reverse proxy // HTTP traffic type httpForwarder struct { roundTripper http.RoundTripper rewriter ReqRewriter passHost bool flushInterval time.Duration modifyResponse func(*http.Response) error tlsClientConfig *tls.Config log OxyLogger bufferPool httputil.BufferPool websocketConnectionClosedHook func(req *http.Request, conn net.Conn) } const defaultFlushInterval = time.Duration(100) * time.Millisecond // Connection states const ( StateConnected = iota StateDisconnected ) // UrlForwardingStateListener URL forwarding state listener type UrlForwardingStateListener func(*url.URL, int) // New creates an instance of Forwarder based on the provided list of configuration options func New(setters ...optSetter) (*Forwarder, error) { f := &Forwarder{ httpForwarder: &httpForwarder{log: &internalLogger{Logger: log.StandardLogger()}}, handlerContext: &handlerContext{}, } for _, s := range setters { if err := s(f); err != nil { return nil, err } } if !f.stream { f.flushInterval = 0 } else if f.flushInterval == 0 { f.flushInterval = defaultFlushInterval } if f.httpForwarder.rewriter == nil { h, err := os.Hostname() if err != nil { h = "localhost" } f.httpForwarder.rewriter = &HeaderRewriter{TrustForwardHeader: true, Hostname: h} } if f.httpForwarder.roundTripper == nil { f.httpForwarder.roundTripper = http.DefaultTransport } if f.errHandler == nil { f.errHandler = utils.DefaultHandler } if f.tlsClientConfig == nil { if ht, ok := f.httpForwarder.roundTripper.(*http.Transport); ok { f.tlsClientConfig = ht.TLSClientConfig } } f.postConfig() return f, nil } // ServeHTTP decides which forwarder to use based on the specified // request and delegates to the proper implementation func (f *Forwarder) ServeHTTP(w http.ResponseWriter, req *http.Request) { if f.log.GetLevel() >= log.DebugLevel { logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/forward: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/forward: completed ServeHttp on request") } if f.stateListener != nil { f.stateListener(req.URL, StateConnected) defer f.stateListener(req.URL, StateDisconnected) } if IsWebsocketRequest(req) { f.httpForwarder.serveWebSocket(w, req, f.handlerContext) } else { f.httpForwarder.serveHTTP(w, req, f.handlerContext) } } func (f *httpForwarder) getUrlFromRequest(req *http.Request) *url.URL { // If the Request was created by Go via a real HTTP request, RequestURI will // contain the original query string. 
If the Request was created in code, RequestURI // will be empty, and we will use the URL object instead u := req.URL if req.RequestURI != "" { parsedURL, err := url.ParseRequestURI(req.RequestURI) if err == nil { u = parsedURL } else { f.log.Warnf("vulcand/oxy/forward: error when parsing RequestURI: %s", err) } } return u } // Modify the request to handle the target URL func (f *httpForwarder) modifyRequest(outReq *http.Request, target *url.URL) { outReq.URL = utils.CopyURL(outReq.URL) outReq.URL.Scheme = target.Scheme outReq.URL.Host = target.Host u := f.getUrlFromRequest(outReq) outReq.URL.Path = u.Path outReq.URL.RawPath = u.RawPath outReq.URL.RawQuery = u.RawQuery outReq.RequestURI = "" // Outgoing request should not have RequestURI outReq.Proto = "HTTP/1.1" outReq.ProtoMajor = 1 outReq.ProtoMinor = 1 if f.rewriter != nil { f.rewriter.Rewrite(outReq) } // Do not pass client Host header unless optsetter PassHostHeader is set. if !f.passHost { outReq.Host = target.Host } } // serveWebSocket forwards websocket traffic func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, ctx *handlerContext) { if f.log.GetLevel() >= log.DebugLevel { logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/forward/websocket: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/forward/websocket: completed ServeHttp on request") } outReq := f.copyWebSocketRequest(req) dialer := websocket.DefaultDialer if outReq.URL.Scheme == "wss" && f.tlsClientConfig != nil { dialer.TLSClientConfig = f.tlsClientConfig.Clone() // WebSocket is only in http/1.1 dialer.TLSClientConfig.NextProtos = []string{"http/1.1"} } targetConn, resp, err := dialer.DialContext(outReq.Context(), outReq.URL.String(), outReq.Header) if err != nil { if resp == nil { ctx.errHandler.ServeHTTP(w, req, err) } else { f.log.Errorf("vulcand/oxy/forward/websocket: Error dialing %q: %v with resp: %d %s", outReq.Host, err, resp.StatusCode, resp.Status) hijacker, ok := w.(http.Hijacker) if !ok { f.log.Errorf("vulcand/oxy/forward/websocket: %s can not be hijack", reflect.TypeOf(w)) ctx.errHandler.ServeHTTP(w, req, err) return } conn, _, errHijack := hijacker.Hijack() if errHijack != nil { f.log.Errorf("vulcand/oxy/forward/websocket: Failed to hijack responseWriter") ctx.errHandler.ServeHTTP(w, req, errHijack) return } defer func() { conn.Close() if f.websocketConnectionClosedHook != nil { f.websocketConnectionClosedHook(req, conn) } }() errWrite := resp.Write(conn) if errWrite != nil { f.log.Errorf("vulcand/oxy/forward/websocket: Failed to forward response") ctx.errHandler.ServeHTTP(w, req, errWrite) return } } return } // Only the targetConn choose to CheckOrigin or not upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} utils.RemoveHeaders(resp.Header, WebsocketUpgradeHeaders...) 
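// Propagate the backend's handshake response headers to the client, minus the hop-by-hop upgrade headers stripped above.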
utils.CopyHeaders(resp.Header, w.Header()) underlyingConn, err := upgrader.Upgrade(w, req, resp.Header) if err != nil { f.log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err) return } defer func() { underlyingConn.Close() targetConn.Close() if f.websocketConnectionClosedHook != nil { f.websocketConnectionClosedHook(req, underlyingConn.UnderlyingConn()) } }() errClient := make(chan error, 1) errBackend := make(chan error, 1) replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) { forward := func(messageType int, reader io.Reader) error { writer, err := dst.NextWriter(messageType) if err != nil { return err } _, err = io.Copy(writer, reader) if err != nil { return err } return writer.Close() } src.SetPingHandler(func(data string) error { return forward(websocket.PingMessage, bytes.NewReader([]byte(data))) }) src.SetPongHandler(func(data string) error { return forward(websocket.PongMessage, bytes.NewReader([]byte(data))) }) for { msgType, reader, err := src.NextReader() if err != nil { m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err)) if e, ok := err.(*websocket.CloseError); ok { if e.Code != websocket.CloseNoStatusReceived { m = nil // Following codes are not valid on the wire so just close the // underlying TCP connection without sending a close frame. if e.Code != websocket.CloseAbnormalClosure && e.Code != websocket.CloseTLSHandshake { m = websocket.FormatCloseMessage(e.Code, e.Text) } } } errc <- err if m != nil { forward(websocket.CloseMessage, bytes.NewReader([]byte(m))) } break } err = forward(msgType, reader) if err != nil { errc <- err break } } } go replicateWebsocketConn(underlyingConn, targetConn, errClient) go replicateWebsocketConn(targetConn, underlyingConn, errBackend) var message string select { case err = <-errClient: message = "vulcand/oxy/forward/websocket: Error when copying from backend to client: %v" case err = <-errBackend: message = "vulcand/oxy/forward/websocket: Error when copying from client to backend: %v" } if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure { f.log.Errorf(message, err) } } // copyWebsocketRequest makes a copy of the specified request. func (f *httpForwarder) copyWebSocketRequest(req *http.Request) (outReq *http.Request) { outReq = new(http.Request) *outReq = *req // includes shallow copies of maps, but we handle this below outReq.URL = utils.CopyURL(req.URL) outReq.URL.Scheme = req.URL.Scheme // sometimes backends might be registered as HTTP/HTTPS servers so translate URLs to websocket URLs. switch req.URL.Scheme { case "https": outReq.URL.Scheme = "wss" case "http": outReq.URL.Scheme = "ws" } u := f.getUrlFromRequest(outReq) outReq.URL.Path = u.Path outReq.URL.RawPath = u.RawPath outReq.URL.RawQuery = u.RawQuery outReq.RequestURI = "" // Outgoing request should not have RequestURI outReq.URL.Host = req.URL.Host if !f.passHost { outReq.Host = req.URL.Host } outReq.Header = make(http.Header) // gorilla websocket use this header to set the request.Host tested in checkSameOrigin outReq.Header.Set("Host", outReq.Host) utils.CopyHeaders(outReq.Header, req.Header) utils.RemoveHeaders(outReq.Header, WebsocketDialHeaders...) 
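// The dial-specific headers are dropped because gorilla/websocket sets its own handshake headers when dialing and rejects duplicates.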
if f.rewriter != nil { f.rewriter.Rewrite(outReq) } return outReq } // serveHTTP forwards HTTP traffic using the configured transport func (f *httpForwarder) serveHTTP(w http.ResponseWriter, inReq *http.Request, ctx *handlerContext) { if f.log.GetLevel() >= log.DebugLevel { logEntry := f.log.WithField("Request", utils.DumpHttpRequest(inReq)) logEntry.Debug("vulcand/oxy/forward/http: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/forward/http: completed ServeHttp on request") } start := time.Now().UTC() outReq := new(http.Request) *outReq = *inReq // includes shallow copies of maps, but we handle this in Director revproxy := httputil.ReverseProxy{ Director: func(req *http.Request) { f.modifyRequest(req, inReq.URL) }, Transport: f.roundTripper, FlushInterval: f.flushInterval, ModifyResponse: f.modifyResponse, BufferPool: f.bufferPool, ErrorHandler: ctx.errHandler.ServeHTTP, } if f.log.GetLevel() >= log.DebugLevel { pw := utils.NewProxyWriter(w) revproxy.ServeHTTP(pw, outReq) if inReq.TLS != nil { f.log.Debugf("vulcand/oxy/forward/http: Round trip: %v, code: %v, Length: %v, duration: %v tls:version: %x, tls:resume:%t, tls:csuite:%x, tls:server:%v", inReq.URL, pw.StatusCode(), pw.GetLength(), time.Now().UTC().Sub(start), inReq.TLS.Version, inReq.TLS.DidResume, inReq.TLS.CipherSuite, inReq.TLS.ServerName) } else { f.log.Debugf("vulcand/oxy/forward/http: Round trip: %v, code: %v, Length: %v, duration: %v", inReq.URL, pw.StatusCode(), pw.GetLength(), time.Now().UTC().Sub(start)) } } else { revproxy.ServeHTTP(w, outReq) } for key := range w.Header() { if strings.HasPrefix(key, http.TrailerPrefix) { if fl, ok := w.(http.Flusher); ok { fl.Flush() } break } } } // IsWebsocketRequest determines if the specified HTTP request is a // websocket handshake request func IsWebsocketRequest(req *http.Request) bool { containsHeader := func(name, value string) bool { items := strings.Split(req.Header.Get(name), ",") for _, item := range items { if value == strings.ToLower(strings.TrimSpace(item)) { return true } } return false } return containsHeader(Connection, "upgrade") && containsHeader(Upgrade, "websocket") } oxy-1.3.0/forward/fwd_chunked_go1_15_test.go000066400000000000000000000016241404246664300207340ustar00rootroot00000000000000// +build !go1.16 package forward import ( "fmt" "net/http" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/testutils" ) func TestChunkedResponseConversion(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { h := w.(http.Hijacker) conn, _, _ := h.Hijack() fmt.Fprintf(conn, "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n5\r\ntest1\r\n5\r\ntest2\r\n0\r\n\r\n") conn.Close() }) defer srv.Close() f, err := New() require.NoError(t, err) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) defer proxy.Close() re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, "testtest1test2", string(body)) assert.Equal(t, http.StatusOK, re.StatusCode) } oxy-1.3.0/forward/fwd_chunked_test.go000066400000000000000000000023441404246664300176610ustar00rootroot00000000000000// +build go1.16 package forward import ( "fmt" "net/http" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/testutils" ) func TestChunkedResponseConversion(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, 
req *http.Request) { h := w.(http.Hijacker) conn, _, _ := h.Hijack() fmt.Fprintf(conn, "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n5\r\ntest1\r\n5\r\ntest2\r\n0\r\n\r\n") conn.Close() }) defer srv.Close() f, err := New() require.NoError(t, err) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) defer proxy.Close() re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, "testtest1test2", string(body)) assert.Equal(t, http.StatusOK, re.StatusCode) // Also as per RFC2616 (https://tools.ietf.org/html/rfc2616#section-4.4) // "Messages MUST NOT include both a Content-Length header field and a non-identity transfer-coding. // If the message does include a non-identity transfer-coding, the Content-Length MUST be ignored." _, ok := re.Header["Content-Length"] assert.False(t, ok) } oxy-1.3.0/forward/fwd_test.go000066400000000000000000000261731404246664300161660ustar00rootroot00000000000000package forward import ( "context" "io/ioutil" "net/http" "net/http/httptest" "net/url" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/testutils" "github.com/vulcand/oxy/utils" ) // Makes sure hop-by-hop headers are removed func TestForwardHopHeaders(t *testing.T) { called := false var outHeaders http.Header var outHost, expectedHost string srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { called = true outHeaders = req.Header outHost = req.Host w.Write([]byte("hello")) }) defer srv.Close() f, err := New() require.NoError(t, err) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) expectedHost = req.URL.Host f.ServeHTTP(w, req) }) defer proxy.Close() headers := http.Header{ Connection: []string{"close"}, KeepAlive: []string{"timeout=600"}, } re, body, err := testutils.Get(proxy.URL, testutils.Headers(headers)) require.NoError(t, err) assert.Equal(t, "hello", string(body)) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, true, called) assert.Equal(t, "", outHeaders.Get(Connection)) assert.Equal(t, "", outHeaders.Get(KeepAlive)) assert.Equal(t, expectedHost, outHost) } func TestDefaultErrHandler(t *testing.T) { f, err := New() require.NoError(t, err) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI("http://localhost:63450") f.ServeHTTP(w, req) }) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusBadGateway, re.StatusCode) } func TestCustomErrHandler(t *testing.T) { f, err := New(ErrorHandler(utils.ErrorHandlerFunc(func(w http.ResponseWriter, req *http.Request, err error) { w.WriteHeader(http.StatusTeapot) w.Write([]byte(http.StatusText(http.StatusTeapot))) }))) require.NoError(t, err) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI("http://localhost:63450") f.ServeHTTP(w, req) }) defer proxy.Close() re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusTeapot, re.StatusCode) assert.Equal(t, http.StatusText(http.StatusTeapot), string(body)) } func TestResponseModifier(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) defer srv.Close() f, err := New(ResponseModifier(func(resp *http.Response) error { resp.Header.Add("X-Test", 
"CUSTOM") return nil })) require.NoError(t, err) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "CUSTOM", re.Header.Get("X-Test")) } func TestXForwardedHostHeader(t *testing.T) { tests := []struct { Description string PassHostHeader bool TargetUrl string ProxyfiedUrl string ExpectedXForwardedHost string }{ { Description: "XForwardedHost without PassHostHeader", PassHostHeader: false, TargetUrl: "http://xforwardedhost.com", ProxyfiedUrl: "http://backend.com", ExpectedXForwardedHost: "xforwardedhost.com", }, { Description: "XForwardedHost with PassHostHeader", PassHostHeader: true, TargetUrl: "http://xforwardedhost.com", ProxyfiedUrl: "http://backend.com", ExpectedXForwardedHost: "xforwardedhost.com", }, } for _, test := range tests { test := test t.Run(test.Description, func(t *testing.T) { t.Parallel() f, err := New(PassHostHeader(test.PassHostHeader)) require.NoError(t, err) r, err := http.NewRequest(http.MethodGet, test.TargetUrl, nil) require.NoError(t, err) backendUrl, err := url.Parse(test.ProxyfiedUrl) require.NoError(t, err) f.modifyRequest(r, backendUrl) require.Equal(t, test.ExpectedXForwardedHost, r.Header.Get(XForwardedHost)) }) } } // Makes sure hop-by-hop headers are removed func TestForwardedHeaders(t *testing.T) { var outHeaders http.Header srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { outHeaders = req.Header w.Write([]byte("hello")) }) defer srv.Close() f, err := New(Rewriter(&HeaderRewriter{TrustForwardHeader: true, Hostname: "hello"})) require.NoError(t, err) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) defer proxy.Close() headers := http.Header{ XForwardedProto: []string{"httpx"}, XForwardedFor: []string{"192.168.1.1"}, XForwardedServer: []string{"foobar"}, XForwardedHost: []string{"upstream-foobar"}, } re, _, err := testutils.Get(proxy.URL, testutils.Headers(headers)) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "httpx", outHeaders.Get(XForwardedProto)) assert.Contains(t, outHeaders.Get(XForwardedFor), "192.168.1.1") assert.Contains(t, "upstream-foobar", outHeaders.Get(XForwardedHost)) assert.Equal(t, "hello", outHeaders.Get(XForwardedServer)) } func TestCustomRewriter(t *testing.T) { var outHeaders http.Header srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { outHeaders = req.Header w.Write([]byte("hello")) }) defer srv.Close() f, err := New(Rewriter(&HeaderRewriter{TrustForwardHeader: false, Hostname: "hello"})) require.NoError(t, err) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) defer proxy.Close() headers := http.Header{ XForwardedProto: []string{"httpx"}, XForwardedFor: []string{"192.168.1.1"}, } re, _, err := testutils.Get(proxy.URL, testutils.Headers(headers)) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "http", outHeaders.Get(XForwardedProto)) assert.NotContains(t, outHeaders.Get(XForwardedFor), "192.168.1.1") } func TestCustomTransportTimeout(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { time.Sleep(20 * time.Millisecond) w.Write([]byte("hello")) }) defer 
srv.Close() f, err := New(RoundTripper( &http.Transport{ ResponseHeaderTimeout: 5 * time.Millisecond, })) require.NoError(t, err) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusGatewayTimeout, re.StatusCode) } func TestCustomLogger(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) defer srv.Close() f, err := New() require.NoError(t, err) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } func TestRouteForwarding(t *testing.T) { var outPath string srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { outPath = req.RequestURI w.Write([]byte("hello")) }) defer srv.Close() f, err := New() require.NoError(t, err) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) defer proxy.Close() tests := []struct { Path string Query string ExpectedPath string }{ {"/hello", "", "/hello"}, {"//hello", "", "//hello"}, {"///hello", "", "///hello"}, {"/hello", "abc=def&def=123", "/hello?abc=def&def=123"}, {"/log/http%3A%2F%2Fwww.site.com%2Fsomething?a=b", "", "/log/http%3A%2F%2Fwww.site.com%2Fsomething?a=b"}, } for _, test := range tests { proxyURL := proxy.URL + test.Path if test.Query != "" { proxyURL = proxyURL + "?" + test.Query } request, err := http.NewRequest("GET", proxyURL, nil) require.NoError(t, err) re, err := http.DefaultClient.Do(request) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, test.ExpectedPath, outPath) } } func TestForwardedProto(t *testing.T) { var proto string srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { proto = req.Header.Get(XForwardedProto) w.Write([]byte("hello")) }) defer srv.Close() f, err := New() require.NoError(t, err) proxy := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) tproxy := httptest.NewUnstartedServer(proxy) tproxy.StartTLS() defer tproxy.Close() re, _, err := testutils.Get(tproxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "https", proto) } func TestContextWithValueInErrHandler(t *testing.T) { var originalPBool *bool originalBool := false originalPBool = &originalBool f, err := New(ErrorHandler(utils.ErrorHandlerFunc(func(rw http.ResponseWriter, req *http.Request, err error) { test, isBool := req.Context().Value("test").(*bool) if isBool { *test = true } if err != nil { rw.WriteHeader(http.StatusBadGateway) } }))) require.NoError(t, err) proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { // We need a network error req.URL = testutils.ParseURI("http://localhost:63450") newReq := req.WithContext(context.WithValue(req.Context(), "test", originalPBool)) f.ServeHTTP(w, newReq) }) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusBadGateway, re.StatusCode) assert.True(t, *originalPBool) } func TestTeTrailer(t *testing.T) { var teHeader string srv := testutils.NewHandler(func(w http.ResponseWriter, 
req *http.Request) { teHeader = req.Header.Get(Te) w.Write([]byte("hello")) }) defer srv.Close() f, err := New() require.NoError(t, err) proxy := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) tproxy := httptest.NewUnstartedServer(proxy) tproxy.StartTLS() defer tproxy.Close() re, _, err := testutils.Get(tproxy.URL, testutils.Header("Te", "trailers")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "trailers", teHeader) } func TestUnannouncedTrailer(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(200) rw.(http.Flusher).Flush() rw.Header().Add(http.TrailerPrefix+"X-Trailer", "foo") })) proxy, err := New() require.Nil(t, err) proxySrv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) proxy.ServeHTTP(rw, req) })) resp, _ := http.Get(proxySrv.URL) ioutil.ReadAll(resp.Body) require.Equal(t, resp.Trailer.Get("X-Trailer"), "foo") } oxy-1.3.0/forward/fwd_websocket_test.go000066400000000000000000000452011404246664300202250ustar00rootroot00000000000000package forward import ( "bufio" "crypto/tls" "fmt" "net" "net/http" "net/http/httptest" "runtime" "testing" "time" gorillawebsocket "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/testutils" "golang.org/x/net/websocket" ) func TestWebSocketTCPClose(t *testing.T) { f, err := New(PassHostHeader(true)) require.NoError(t, err) errChan := make(chan error, 1) upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := upgrader.Upgrade(w, r, nil) if err != nil { return } defer c.Close() for { _, _, err := c.ReadMessage() if err != nil { errChan <- err break } } })) defer srv.Close() proxy := createProxyWithForwarder(f, srv.URL) proxyAddr := proxy.Listener.Addr().String() _, conn, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), ).open() require.NoError(t, err) conn.Close() serverErr := <-errChan wsErr, ok := serverErr.(*gorillawebsocket.CloseError) assert.Equal(t, true, ok) assert.Equal(t, 1006, wsErr.Code) } func TestWebsocketConnectionClosedHook(t *testing.T) { closed := make(chan struct{}) f, err := New(WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) { close(closed) })) require.NoError(t, err) mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { msg := make([]byte, 4) conn.Read(msg) conn.Write(msg) conn.Close() })) srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { mux.ServeHTTP(w, req) }) defer srv.Close() proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) defer proxy.Close() serverAddr := proxy.Listener.Addr().String() headers := http.Header{} webSocketURL := "ws://" + serverAddr + "/ws" headers.Add("Origin", webSocketURL) conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) require.NoError(t, err, "Error during Dial with response: %+v", resp) conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK")) fmt.Println(conn.ReadMessage()) conn.Close() select { case <-time.After(time.Second): t.Errorf("Websocket Hook not called") case <-closed: } } func TestWebSocketPingPong(t *testing.T) { f, err := New() 
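// This test checks that ping/pong control frames are relayed through the proxy in both directions (see the Ping/Pong handlers below).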
require.NoError(t, err) var upgrader = gorillawebsocket.Upgrader{ HandshakeTimeout: 10 * time.Second, CheckOrigin: func(*http.Request) bool { return true }, } mux := http.NewServeMux() mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) { ws, err := upgrader.Upgrade(writer, request, nil) require.NoError(t, err) ws.SetPingHandler(func(appData string) error { ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong")) return nil }) ws.ReadMessage() }) srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { mux.ServeHTTP(w, req) }) defer srv.Close() proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) defer proxy.Close() serverAddr := proxy.Listener.Addr().String() headers := http.Header{} webSocketURL := "ws://" + serverAddr + "/ws" headers.Add("Origin", webSocketURL) conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) require.NoError(t, err, "Error during Dial with response: %+v", resp) defer conn.Close() goodErr := fmt.Errorf("signal: %s", "Good data") badErr := fmt.Errorf("signal: %s", "Bad data") conn.SetPongHandler(func(data string) error { if data == "PingPong" { return goodErr } return badErr }) conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), time.Now().Add(time.Second)) _, _, err = conn.ReadMessage() if err != goodErr { require.NoError(t, err) } } func TestWebSocketEcho(t *testing.T) { f, err := New() require.NoError(t, err) mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { msg := make([]byte, 4) conn.Read(msg) fmt.Println(string(msg)) conn.Write(msg) conn.Close() })) srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { mux.ServeHTTP(w, req) }) defer srv.Close() proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) defer proxy.Close() serverAddr := proxy.Listener.Addr().String() headers := http.Header{} webSocketURL := "ws://" + serverAddr + "/ws" headers.Add("Origin", webSocketURL) conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) require.NoError(t, err, "Error during Dial with response: %+v", resp) conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK")) fmt.Println(conn.ReadMessage()) conn.Close() } func TestWebSocketPassHost(t *testing.T) { testCases := []struct { desc string passHost bool expected string }{ { desc: "PassHost false", passHost: false, }, { desc: "PassHost true", passHost: true, expected: "example.com", }, } for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { f, err := New() f.passHost = test.passHost require.NoError(t, err) mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { req := conn.Request() if test.passHost { require.Equal(t, test.expected, req.Host) } else { require.NotEqual(t, test.expected, req.Host) } msg := make([]byte, 4) conn.Read(msg) fmt.Println(string(msg)) conn.Write(msg) conn.Close() })) srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { mux.ServeHTTP(w, req) }) defer srv.Close() proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) defer proxy.Close() serverAddr := proxy.Listener.Addr().String() headers := http.Header{} webSocketURL := "ws://" + serverAddr + "/ws" headers.Add("Origin", webSocketURL) 
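// The "Host" header added below is what the backend handler above asserts on: the original
// host should reach the backend only when passHost is true.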
headers.Add("Host", "example.com") conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) require.NoError(t, err, "Error during Dial with response: %+v", resp) conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK")) fmt.Println(conn.ReadMessage()) conn.Close() }) } } func TestWebSocketNumGoRoutine(t *testing.T) { t.Skip("Flaky on goroutine") f, err := New() require.NoError(t, err) mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { msg := make([]byte, 4) conn.Read(msg) fmt.Println(string(msg)) conn.Write(msg) conn.Close() })) srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { mux.ServeHTTP(w, req) }) defer srv.Close() proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) f.ServeHTTP(w, req) }) defer proxy.Close() serverAddr := proxy.Listener.Addr().String() num := runtime.NumGoroutine() headers := http.Header{} webSocketURL := "ws://" + serverAddr + "/ws" headers.Add("Origin", webSocketURL) conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) require.NoError(t, err, "Error during Dial with response: %+v", resp) conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK")) fmt.Println(conn.ReadMessage()) conn.Close() time.Sleep(time.Second) assert.Equal(t, num, runtime.NumGoroutine()) } func TestWebSocketServerWithoutCheckOrigin(t *testing.T) { f, err := New() require.NoError(t, err) upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := upgrader.Upgrade(w, r, nil) if err != nil { return } defer c.Close() for { mt, message, err := c.ReadMessage() if err != nil { break } err = c.WriteMessage(mt, message) if err != nil { break } } })) defer srv.Close() proxy := createProxyWithForwarder(f, srv.URL) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() resp, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("ok"), withOrigin("http://127.0.0.2"), ).send() require.NoError(t, err) assert.Equal(t, "ok", resp) } func TestWebSocketRequestWithOrigin(t *testing.T) { f, err := New(PassHostHeader(true)) require.NoError(t, err) upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := upgrader.Upgrade(w, r, nil) if err != nil { return } defer c.Close() for { mt, message, err := c.ReadMessage() if err != nil { break } err = c.WriteMessage(mt, message) if err != nil { break } } })) defer srv.Close() proxy := createProxyWithForwarder(f, srv.URL) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() _, err = newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("echo"), withOrigin("http://127.0.0.2"), ).send() require.EqualError(t, err, "bad status") resp, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("ok"), ).send() require.NoError(t, err) assert.Equal(t, "ok", resp) } func TestWebSocketRequestWithQueryParams(t *testing.T) { f, err := New(PassHostHeader(true)) require.NoError(t, err) upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return } defer conn.Close() assert.Equal(t, "test", r.URL.Query().Get("query")) for { mt, message, err := conn.ReadMessage() if err != nil { break } err = 
conn.WriteMessage(mt, message) if err != nil { break } } })) defer srv.Close() proxy := createProxyWithForwarder(f, srv.URL) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() resp, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws?query=test"), withData("ok"), ).send() require.NoError(t, err) assert.Equal(t, "ok", resp) } func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) { f, err := New() require.NoError(t, err) mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { conn.Close() })) srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { mux.ServeHTTP(w, req) }) defer srv.Close() proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) w.Header().Set("HEADER-KEY", "HEADER-VALUE") f.ServeHTTP(w, req) }) defer proxy.Close() serverAddr := proxy.Listener.Addr().String() headers := http.Header{} webSocketURL := "ws://" + serverAddr + "/ws" headers.Add("Origin", webSocketURL) conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) require.NoError(t, err, "Error during Dial with response: %+v", err, resp) defer conn.Close() assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY")) } func TestWebSocketRequestWithEncodedChar(t *testing.T) { f, err := New(PassHostHeader(true)) require.NoError(t, err) upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return } defer conn.Close() assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath()) for { mt, message, err := conn.ReadMessage() if err != nil { break } err = conn.WriteMessage(mt, message) if err != nil { break } } })) defer srv.Close() proxy := createProxyWithForwarder(f, srv.URL) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() resp, err := newWebsocketRequest( withServer(proxyAddr), withPath("/%3A%2F%2F"), withData("ok"), ).send() require.NoError(t, err) assert.Equal(t, "ok", resp) } func TestDetectsWebSocketRequest(t *testing.T) { mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { conn.Write([]byte("ok")) conn.Close() })) srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { websocketRequest := IsWebsocketRequest(req) assert.Equal(t, true, websocketRequest) mux.ServeHTTP(w, req) }) defer srv.Close() serverAddr := srv.Listener.Addr().String() resp, err := newWebsocketRequest( withServer(serverAddr), withPath("/ws"), withData("echo"), ).send() require.NoError(t, err) assert.Equal(t, "ok", resp) } func TestWebSocketUpgradeFailed(t *testing.T) { f, err := New() require.NoError(t, err) mux := http.NewServeMux() mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(400) }) srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { mux.ServeHTTP(w, req) }) defer srv.Close() proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { path := req.URL.Path // keep the original path if path == "/ws" { // Set new backend URL req.URL = testutils.ParseURI(srv.URL) req.URL.Path = path websocketRequest := IsWebsocketRequest(req) assert.Equal(t, true, websocketRequest) f.ServeHTTP(w, req) } else { w.WriteHeader(200) } }) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout) require.NoError(t, err) defer conn.Close() req, err := 
http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil) require.NoError(t, err) req.Header.Add("upgrade", "websocket") req.Header.Add("Connection", "upgrade") req.Write(conn) // First request works with 400 br := bufio.NewReader(conn) resp, err := http.ReadResponse(br, req) require.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) } func TestForwardsWebsocketTraffic(t *testing.T) { f, err := New() require.NoError(t, err) mux := http.NewServeMux() mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { conn.Write([]byte("ok")) conn.Close() })) srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { mux.ServeHTTP(w, req) }) defer srv.Close() proxy := createProxyWithForwarder(f, srv.URL) defer proxy.Close() proxyAddr := proxy.Listener.Addr().String() resp, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("echo"), ).send() require.NoError(t, err) assert.Equal(t, "ok", resp) } func createTLSWebsocketServer() *httptest.Server { upgrader := gorillawebsocket.Upgrader{} srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return } defer conn.Close() for { mt, message, err := conn.ReadMessage() if err != nil { break } err = conn.WriteMessage(mt, message) if err != nil { break } } })) return srv } func createProxyWithForwarder(forwarder *Forwarder, URL string) *httptest.Server { return testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { path := req.URL.Path // keep the original path // Set new backend URL req.URL = testutils.ParseURI(URL) req.URL.Path = path forwarder.ServeHTTP(w, req) }) } func TestWebSocketTransferTLSConfig(t *testing.T) { srv := createTLSWebsocketServer() defer srv.Close() forwarderWithoutTLSConfig, err := New(PassHostHeader(true)) require.NoError(t, err) proxyWithoutTLSConfig := createProxyWithForwarder(forwarderWithoutTLSConfig, srv.URL) defer proxyWithoutTLSConfig.Close() proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String() _, err = newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("ok"), ).send() require.EqualError(t, err, "bad status") transport := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } forwarderWithTLSConfig, err := New(PassHostHeader(true), RoundTripper(transport)) require.NoError(t, err) proxyWithTLSConfig := createProxyWithForwarder(forwarderWithTLSConfig, srv.URL) defer proxyWithTLSConfig.Close() proxyAddr = proxyWithTLSConfig.Listener.Addr().String() resp, err := newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("ok"), ).send() require.NoError(t, err) assert.Equal(t, "ok", resp) http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} forwarderWithTLSConfigFromDefaultTransport, err := New(PassHostHeader(true)) require.NoError(t, err) proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(forwarderWithTLSConfigFromDefaultTransport, srv.URL) defer proxyWithTLSConfig.Close() proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String() resp, err = newWebsocketRequest( withServer(proxyAddr), withPath("/ws"), withData("ok"), ).send() require.NoError(t, err) assert.Equal(t, "ok", resp) } const dialTimeout = time.Second type websocketRequestOpt func(w *websocketRequest) func withServer(server string) websocketRequestOpt { return func(w *websocketRequest) { w.ServerAddr = server } } func withPath(path string) websocketRequestOpt { return func(w 
*websocketRequest) { w.Path = path } } func withData(data string) websocketRequestOpt { return func(w *websocketRequest) { w.Data = data } } func withOrigin(origin string) websocketRequestOpt { return func(w *websocketRequest) { w.Origin = origin } } func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest { wsrequest := &websocketRequest{} for _, opt := range opts { opt(wsrequest) } if wsrequest.Origin == "" { wsrequest.Origin = "http://" + wsrequest.ServerAddr } if wsrequest.Config == nil { wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin) } return wsrequest } type websocketRequest struct { ServerAddr string Path string Data string Origin string Config *websocket.Config } func (w *websocketRequest) send() (string, error) { conn, _, err := w.open() if err != nil { return "", err } defer conn.Close() if _, err := conn.Write([]byte(w.Data)); err != nil { return "", err } var msg = make([]byte, 512) var n int n, err = conn.Read(msg) if err != nil { return "", err } received := string(msg[:n]) return received, nil } func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) { client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout) if err != nil { return nil, nil, err } conn, err := websocket.NewClient(w.Config, client) if err != nil { return nil, nil, err } return conn, client, err } oxy-1.3.0/forward/headers.go000066400000000000000000000033401404246664300157510ustar00rootroot00000000000000package forward // Headers const ( XForwardedProto = "X-Forwarded-Proto" XForwardedFor = "X-Forwarded-For" XForwardedHost = "X-Forwarded-Host" XForwardedPort = "X-Forwarded-Port" XForwardedServer = "X-Forwarded-Server" XRealIp = "X-Real-Ip" Connection = "Connection" KeepAlive = "Keep-Alive" ProxyAuthenticate = "Proxy-Authenticate" ProxyAuthorization = "Proxy-Authorization" Te = "Te" // canonicalized version of "TE" Trailers = "Trailers" TransferEncoding = "Transfer-Encoding" Upgrade = "Upgrade" ContentLength = "Content-Length" SecWebsocketKey = "Sec-Websocket-Key" SecWebsocketVersion = "Sec-Websocket-Version" SecWebsocketExtensions = "Sec-Websocket-Extensions" SecWebsocketAccept = "Sec-Websocket-Accept" ) // HopHeaders Hop-by-hop headers. These are removed when sent to the backend. 
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html // Copied from reverseproxy.go, too bad var HopHeaders = []string{ Connection, KeepAlive, ProxyAuthenticate, ProxyAuthorization, Te, // canonicalized version of "TE" Trailers, TransferEncoding, Upgrade, } // WebsocketDialHeaders Websocket dial headers var WebsocketDialHeaders = []string{ Upgrade, Connection, SecWebsocketKey, SecWebsocketVersion, SecWebsocketExtensions, SecWebsocketAccept, } // WebsocketUpgradeHeaders Websocket upgrade headers var WebsocketUpgradeHeaders = []string{ Upgrade, Connection, SecWebsocketAccept, SecWebsocketExtensions, } // XHeaders X-* headers var XHeaders = []string{ XForwardedProto, XForwardedFor, XForwardedHost, XForwardedPort, XForwardedServer, XRealIp, } oxy-1.3.0/forward/post_config.go000066400000000000000000000001071404246664300166460ustar00rootroot00000000000000// +build go1.11 package forward func (f *Forwarder) postConfig() {} oxy-1.3.0/forward/post_config_18.go000066400000000000000000000015341404246664300171630ustar00rootroot00000000000000// +build !go1.11 package forward import ( "context" "net/http" ) type key string const ( teHeader key = "TeHeader" ) type TeTrailerRoundTripper struct { http.RoundTripper } func (t *TeTrailerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { teHeader := req.Context().Value(teHeader) if teHeader != nil { req.Header.Set("Te", teHeader.(string)) } return t.RoundTripper.RoundTrip(req) } type TeTrailerRewriter struct { ReqRewriter } func (t *TeTrailerRewriter) Rewrite(req *http.Request) { if req.Header.Get("Te") == "trailers" { *req = *req.WithContext(context.WithValue(req.Context(), teHeader, req.Header.Get("Te"))) } t.ReqRewriter.Rewrite(req) } func (f *Forwarder) postConfig() { f.roundTripper = &TeTrailerRoundTripper{RoundTripper: f.roundTripper} f.rewriter = &TeTrailerRewriter{ReqRewriter: f.rewriter} } oxy-1.3.0/forward/rewrite.go000066400000000000000000000042231404246664300160200ustar00rootroot00000000000000package forward import ( "net" "net/http" "strings" "github.com/vulcand/oxy/utils" ) // HeaderRewriter is responsible for removing hop-by-hop headers and setting forwarding headers type HeaderRewriter struct { TrustForwardHeader bool Hostname string } // clean up IP in case if it is ipv6 address and it has {zone} information in it, like "[fe80::d806:a55d:eb1b:49cc%vEthernet (vmxnet3 Ethernet Adapter - Virtual Switch)]:64692" func ipv6fix(clientIP string) string { return strings.Split(clientIP, "%")[0] } // Rewrite rewrite request headers func (rw *HeaderRewriter) Rewrite(req *http.Request) { if !rw.TrustForwardHeader { utils.RemoveHeaders(req.Header, XHeaders...) 
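// When forward headers are not trusted, any client-supplied X-Forwarded-* and X-Real-Ip
// values are stripped above before the proxy sets its own values below.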
} if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { clientIP = ipv6fix(clientIP) // If not websocket, done in http.ReverseProxy if IsWebsocketRequest(req) { if prior, ok := req.Header[XForwardedFor]; ok { req.Header.Set(XForwardedFor, strings.Join(prior, ", ")+", "+clientIP) } else { req.Header.Set(XForwardedFor, clientIP) } } if req.Header.Get(XRealIp) == "" { req.Header.Set(XRealIp, clientIP) } } xfProto := req.Header.Get(XForwardedProto) if xfProto == "" { if req.TLS != nil { req.Header.Set(XForwardedProto, "https") } else { req.Header.Set(XForwardedProto, "http") } } if IsWebsocketRequest(req) { if req.Header.Get(XForwardedProto) == "https" { req.Header.Set(XForwardedProto, "wss") } else { req.Header.Set(XForwardedProto, "ws") } } if xfPort := req.Header.Get(XForwardedPort); xfPort == "" { req.Header.Set(XForwardedPort, forwardedPort(req)) } if xfHost := req.Header.Get(XForwardedHost); xfHost == "" && req.Host != "" { req.Header.Set(XForwardedHost, req.Host) } if rw.Hostname != "" { req.Header.Set(XForwardedServer, rw.Hostname) } } func forwardedPort(req *http.Request) string { if req == nil { return "" } if _, port, err := net.SplitHostPort(req.Host); err == nil && port != "" { return port } if req.Header.Get(XForwardedProto) == "https" || req.Header.Get(XForwardedProto) == "wss" { return "443" } if req.TLS != nil { return "443" } return "80" } oxy-1.3.0/forward/rewrite_test.go000066400000000000000000000020721404246664300170570ustar00rootroot00000000000000package forward import ( "testing" "github.com/stretchr/testify/assert" ) func TestIPv6Fix(t *testing.T) { testCases := []struct { desc string clientIP string expected string }{ { desc: "empty", clientIP: "", expected: "", }, { desc: "ipv4 localhost", clientIP: "127.0.0.1", expected: "127.0.0.1", }, { desc: "ipv4", clientIP: "10.13.14.15", expected: "10.13.14.15", }, { desc: "ipv6 zone", clientIP: `fe80::d806:a55d:eb1b:49cc%vEthernet (vmxnet3 Ethernet Adapter - Virtual Switch)`, expected: "fe80::d806:a55d:eb1b:49cc", }, { desc: "ipv6 medium", clientIP: `fe80::1`, expected: "fe80::1", }, { desc: "ipv6 small", clientIP: `2000::`, expected: "2000::", }, { desc: "ipv6", clientIP: `2001:3452:4952:2837::`, expected: "2001:3452:4952:2837::", }, } for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() actual := ipv6fix(test.clientIP) assert.Equal(t, test.expected, actual) }) } } oxy-1.3.0/go.mod000066400000000000000000000016421404246664300134540ustar00rootroot00000000000000module github.com/vulcand/oxy go 1.14 require ( github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd github.com/gorilla/websocket v1.4.2 github.com/gravitational/trace v0.0.0-20190726142706-a535a178675f // indirect github.com/jonboulle/clockwork v0.1.0 // indirect github.com/kr/pretty v0.1.0 // indirect github.com/mailgun/minheap v0.0.0-20170619185613-3dbe6c6bf55f // indirect github.com/mailgun/multibuf v0.0.0-20150714184110-565402cd71fb github.com/mailgun/timetools v0.0.0-20141028012446-7e6055773c51 github.com/mailgun/ttlmap v0.0.0-20170619185759-c1c17f74874f github.com/segmentio/fasthash v1.0.3 github.com/sirupsen/logrus v1.4.2 github.com/stretchr/testify v1.5.1 github.com/vulcand/predicate v1.1.0 golang.org/x/net v0.0.0-20190724013045-ca1201d0de80 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect launchpad.net/gocheck v0.0.0-20140225173054-000000000087 // indirect ) 
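The forward package above is exercised only through httptest handlers in its tests. As a minimal standalone sketch of the same pattern, assuming the import path github.com/vulcand/oxy/forward from the go.mod above and purely hypothetical backend and listen addresses, a proxy built on the Forwarder could look roughly like this:

package main

import (
	"log"
	"net/http"
	"net/url"

	"github.com/vulcand/oxy/forward"
)

func main() {
	// PassHostHeader(true) mirrors several of the tests above and is optional.
	fwd, err := forward.New(forward.PassHostHeader(true))
	if err != nil {
		log.Fatal(err)
	}

	backend, err := url.Parse("http://localhost:8080") // hypothetical backend
	if err != nil {
		log.Fatal(err)
	}

	// Same rewriting pattern as createProxyWithForwarder above: point the request
	// at the backend, keep the original path and query, and let the forwarder
	// handle headers and websocket upgrades.
	redirect := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		u := *backend // copy so each request mutates its own URL value
		u.Path = req.URL.Path
		u.RawQuery = req.URL.RawQuery
		req.URL = &u
		fwd.ServeHTTP(w, req)
	})

	// Hypothetical listen address.
	log.Fatal(http.ListenAndServe(":8081", redirect))
}

This is only a sketch; in practice, timeouts and error handling would typically come from the RoundTripper and ErrorHandler options shown in the tests above.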
oxy-1.3.0/go.sum000066400000000000000000000122011404246664300134720ustar00rootroot00000000000000github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd h1:qMd81Ts1T2OTKmB4acZcyKaMtRnY5Y44NuXGX2GFJ1w= github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gravitational/trace v0.0.0-20190726142706-a535a178675f h1:68WxnfBzJRYktZ30fmIjGQ74RsXYLoeH2/NITPktTMY= github.com/gravitational/trace v0.0.0-20190726142706-a535a178675f/go.mod h1:RvdOUHE4SHqR3oXlFFKnGzms8a5dugHygGw1bqDstYI= github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mailgun/minheap v0.0.0-20170619185613-3dbe6c6bf55f h1:aOqSQstfwSx9+tcM/xiKTio3IVjs7ZL2vU8kI9bI6bM= github.com/mailgun/minheap v0.0.0-20170619185613-3dbe6c6bf55f/go.mod h1:V3EvCedtJTvUYzJF2GZMRB0JMlai+6cBu3VCTQz33GQ= github.com/mailgun/multibuf v0.0.0-20150714184110-565402cd71fb h1:m2FGM8K2LC9Zyt/7zbQNn5Uvf/YV7vFWKtoMcC7hHU8= github.com/mailgun/multibuf v0.0.0-20150714184110-565402cd71fb/go.mod h1:E0vRBBIQUHcRtmL/oR6w/jehh4FJqJFxe86gBnw9gXc= github.com/mailgun/timetools v0.0.0-20141028012446-7e6055773c51 h1:Kg/NPZLLC3aAFr1YToMs98dbCdhootQ1hZIvZU28hAQ= github.com/mailgun/timetools v0.0.0-20141028012446-7e6055773c51/go.mod h1:RYmqHbhWwIz3z9eVmQ2rx82rulEMG0t+Q1bzfc9DYN4= github.com/mailgun/ttlmap v0.0.0-20170619185759-c1c17f74874f h1:ZZYhg16XocqSKPGNQAe0aeweNtFxuedbwwb4fSlg7h4= github.com/mailgun/ttlmap v0.0.0-20170619185759-c1c17f74874f/go.mod h1:8heskWJ5c0v5J9WH89ADhyal1DOZcayll8fSbhB+/9A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtrmhM= github.com/segmentio/fasthash v1.0.3/go.mod h1:waKX8l2N8yckOgmSsXJi7x1ZfdKZ4x7KRMzBtS3oedY= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.5.1 
h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/vulcand/predicate v1.1.0 h1:Gq/uWopa4rx/tnZu2opOSBqHK63Yqlou/SzrbwdJiNg= github.com/vulcand/predicate v1.1.0/go.mod h1:mlccC5IRBoc2cIFmCB8ZM62I3VDb6p2GXESMHa3CnZg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20190724013045-ca1201d0de80 h1:Ao/3l156eZf2AW5wK8a7/smtodRU+gha3+BeqJ69lRk= golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= launchpad.net/gocheck v0.0.0-20140225173054-000000000087 h1:Izowp2XBH6Ya6rv+hqbceQyw/gSGoXfH/UPoTGduL54= launchpad.net/gocheck v0.0.0-20140225173054-000000000087/go.mod h1:hj7XX3B/0A+80Vse0e+BUHsHMTEhd0O4cpUHr/e/BUM= oxy-1.3.0/memmetrics/000077500000000000000000000000001404246664300145105ustar00rootroot00000000000000oxy-1.3.0/memmetrics/anomaly.go000066400000000000000000000064041404246664300165030ustar00rootroot00000000000000package memmetrics import ( "math" "sort" "time" ) // SplitLatencies provides simple anomaly detection for requests latencies. // it splits values into good or bad category based on the threshold and the median value. // If all values are not far from the median, it will return all values in 'good' set. // Precision is the smallest value to consider, e.g. if set to millisecond, microseconds will be ignored. func SplitLatencies(values []time.Duration, precision time.Duration) (good map[time.Duration]bool, bad map[time.Duration]bool) { // Find the max latency M and then map each latency L to the ratio L/M and then call SplitFloat64 v2r := map[float64]time.Duration{} ratios := make([]float64, len(values)) m := maxTime(values) for i, v := range values { ratio := float64(v/precision+1) / float64(m/precision+1) // +1 is to avoid division by 0 v2r[ratio] = v ratios[i] = ratio } good, bad = make(map[time.Duration]bool), make(map[time.Duration]bool) // Note that multiplier makes this function way less sensitive than ratios detector, this is to avoid noise. vgood, vbad := SplitFloat64(2, 0, ratios) for r := range vgood { good[v2r[r]] = true } for r := range vbad { bad[v2r[r]] = true } return good, bad } // SplitRatios provides simple anomaly detection for ratio values, that are all in the range [0, 1] // it splits values into good or bad category based on the threshold and the median value. // If all values are not far from the median, it will return all values in 'good' set. 
func SplitRatios(values []float64) (good map[float64]bool, bad map[float64]bool) { return SplitFloat64(1.5, 0, values) } // SplitFloat64 provides simple anomaly detection for skewed data sets with no particular distribution. // In essence it applies the formula if(v > median(values) + threshold * medianAbsoluteDeviation) -> anomaly // There's a corner case where there are just 2 values, so by definition there's no value that exceeds the threshold. // This case is solved by introducing additional value that we know is good, e.g. 0. That helps to improve the detection results // on such data sets. func SplitFloat64(threshold, sentinel float64, values []float64) (good map[float64]bool, bad map[float64]bool) { good, bad = make(map[float64]bool), make(map[float64]bool) var newValues []float64 if len(values)%2 == 0 { newValues = make([]float64, len(values)+1) copy(newValues, values) // Add a sentinel endpoint so we can distinguish outliers better newValues[len(newValues)-1] = sentinel } else { newValues = values } m := median(newValues) mAbs := medianAbsoluteDeviation(newValues) for _, v := range values { if v > (m+mAbs)*threshold { bad[v] = true } else { good[v] = true } } return good, bad } func median(values []float64) float64 { vals := make([]float64, len(values)) copy(vals, values) sort.Float64s(vals) l := len(vals) if l%2 != 0 { return vals[l/2] } return (vals[l/2-1] + vals[l/2]) / 2.0 } func medianAbsoluteDeviation(values []float64) float64 { m := median(values) distances := make([]float64, len(values)) for i, v := range values { distances[i] = math.Abs(v - m) } return median(distances) } func maxTime(vals []time.Duration) time.Duration { val := vals[0] for _, v := range vals { if v > val { val = v } } return val } oxy-1.3.0/memmetrics/anomaly_test.go000066400000000000000000000072471404246664300175500ustar00rootroot00000000000000package memmetrics import ( "strconv" "testing" "time" "github.com/stretchr/testify/assert" ) func TestMedian(t *testing.T) { testCases := []struct { desc string values []float64 expected float64 }{ { desc: "2 values", values: []float64{0.1, 0.2}, expected: (float64(0.1) + float64(0.2)) / 2.0, }, { desc: "3 values", values: []float64{0.3, 0.2, 0.5}, expected: 0.3, }, } for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() actual := median(test.values) assert.Equal(t, test.expected, actual) }) } } func TestSplitRatios(t *testing.T) { testCases := []struct { values []float64 good []float64 bad []float64 }{ { values: []float64{0, 0}, good: []float64{0}, bad: []float64{}, }, { values: []float64{0, 1}, good: []float64{0}, bad: []float64{1}, }, { values: []float64{0.1, 0.1}, good: []float64{0.1}, bad: []float64{}, }, { values: []float64{0.15, 0.1}, good: []float64{0.15, 0.1}, bad: []float64{}, }, { values: []float64{0.01, 0.01}, good: []float64{0.01}, bad: []float64{}, }, { values: []float64{0.012, 0.01, 1}, good: []float64{0.012, 0.01}, bad: []float64{1}, }, { values: []float64{0, 0, 1, 1}, good: []float64{0}, bad: []float64{1}, }, { values: []float64{0, 0.1, 0.1, 0}, good: []float64{0}, bad: []float64{0.1}, }, { values: []float64{0, 0.01, 0.1, 0}, good: []float64{0}, bad: []float64{0.01, 0.1}, }, { values: []float64{0, 0.01, 0.02, 1}, good: []float64{0, 0.01, 0.02}, bad: []float64{1}, }, { values: []float64{0, 0, 0, 0, 0, 0.01, 0.02, 1}, good: []float64{0}, bad: []float64{0.01, 0.02, 1}, }, } for ind, test := range testCases { test := test t.Run(strconv.Itoa(ind), func(t *testing.T) { t.Parallel() good, bad := 
SplitRatios(test.values) vgood := make(map[float64]bool, len(test.good)) for _, v := range test.good { vgood[v] = true } vbad := make(map[float64]bool, len(test.bad)) for _, v := range test.bad { vbad[v] = true } assert.Equal(t, vgood, good) assert.Equal(t, vbad, bad) }) } } func TestSplitLatencies(t *testing.T) { testCases := []struct { values []int good []int bad []int }{ { values: []int{0, 0}, good: []int{0}, bad: []int{}, }, { values: []int{1, 2}, good: []int{1, 2}, bad: []int{}, }, { values: []int{1, 2, 4}, good: []int{1, 2, 4}, bad: []int{}, }, { values: []int{8, 8, 18}, good: []int{8}, bad: []int{18}, }, { values: []int{32, 28, 11, 26, 19, 51, 25, 39, 28, 26, 8, 97}, good: []int{32, 28, 11, 26, 19, 51, 25, 39, 28, 26, 8}, bad: []int{97}, }, { values: []int{1, 2, 4, 40}, good: []int{1, 2, 4}, bad: []int{40}, }, { values: []int{40, 60, 1000}, good: []int{40, 60}, bad: []int{1000}, }, } for ind, test := range testCases { test := test t.Run(strconv.Itoa(ind), func(t *testing.T) { t.Parallel() values := make([]time.Duration, len(test.values)) for i, d := range test.values { values[i] = time.Millisecond * time.Duration(d) } good, bad := SplitLatencies(values, time.Millisecond) vgood := make(map[time.Duration]bool, len(test.good)) for _, v := range test.good { vgood[time.Duration(v)*time.Millisecond] = true } assert.Equal(t, vgood, good) vbad := make(map[time.Duration]bool, len(test.bad)) for _, v := range test.bad { vbad[time.Duration(v)*time.Millisecond] = true } assert.Equal(t, vbad, bad) }) } } oxy-1.3.0/memmetrics/counter.go000066400000000000000000000075071404246664300165270ustar00rootroot00000000000000package memmetrics import ( "fmt" "time" "github.com/mailgun/timetools" ) type rcOptSetter func(*RollingCounter) error // CounterClock defines a counter clock func CounterClock(c timetools.TimeProvider) rcOptSetter { return func(r *RollingCounter) error { r.clock = c return nil } } // RollingCounter Calculates in memory failure rate of an endpoint using rolling window of a predefined size type RollingCounter struct { clock timetools.TimeProvider resolution time.Duration values []int countedBuckets int // how many samples in different buckets have we collected so far lastBucket int // last recorded bucket lastUpdated time.Time } // NewCounter creates a counter with fixed amount of buckets that are rotated every resolution period. // E.g. 10 buckets with 1 second means that every new second the bucket is refreshed, so it maintains 10 second rolling window. 
// By default creates a bucket with 10 buckets and 1 second resolution func NewCounter(buckets int, resolution time.Duration, options ...rcOptSetter) (*RollingCounter, error) { if buckets <= 0 { return nil, fmt.Errorf("Buckets should be >= 0") } if resolution < time.Second { return nil, fmt.Errorf("Resolution should be larger than a second") } rc := &RollingCounter{ lastBucket: -1, resolution: resolution, values: make([]int, buckets), } for _, o := range options { if err := o(rc); err != nil { return nil, err } } if rc.clock == nil { rc.clock = &timetools.RealTime{} } return rc, nil } // Append append a counter func (c *RollingCounter) Append(o *RollingCounter) error { c.Inc(int(o.Count())) return nil } // Clone clone a counter func (c *RollingCounter) Clone() *RollingCounter { c.cleanup() other := &RollingCounter{ resolution: c.resolution, values: make([]int, len(c.values)), clock: c.clock, lastBucket: c.lastBucket, lastUpdated: c.lastUpdated, } copy(other.values, c.values) return other } // Reset reset a counter func (c *RollingCounter) Reset() { c.lastBucket = -1 c.countedBuckets = 0 c.lastUpdated = time.Time{} for i := range c.values { c.values[i] = 0 } } // CountedBuckets gets counted buckets func (c *RollingCounter) CountedBuckets() int { return c.countedBuckets } // Count counts func (c *RollingCounter) Count() int64 { c.cleanup() return c.sum() } // Resolution gets resolution func (c *RollingCounter) Resolution() time.Duration { return c.resolution } // Buckets gets buckets func (c *RollingCounter) Buckets() int { return len(c.values) } // WindowSize gets windows size func (c *RollingCounter) WindowSize() time.Duration { return time.Duration(len(c.values)) * c.resolution } // Inc increment counter func (c *RollingCounter) Inc(v int) { c.cleanup() c.incBucketValue(v) } func (c *RollingCounter) incBucketValue(v int) { now := c.clock.UtcNow() bucket := c.getBucket(now) c.values[bucket] += v c.lastUpdated = now // Update usage stats if we haven't collected enough data if c.countedBuckets < len(c.values) { // Only update if we have advanced to the next bucket and not incremented the value // in the current bucket. 
if c.lastBucket != bucket { c.lastBucket = bucket c.countedBuckets++ } } } // Returns the number in the moving window bucket that this slot occupies func (c *RollingCounter) getBucket(t time.Time) int { return int(t.Truncate(c.resolution).Unix() % int64(len(c.values))) } // Reset buckets that were not updated func (c *RollingCounter) cleanup() { now := c.clock.UtcNow() for i := 0; i < len(c.values); i++ { now = now.Add(time.Duration(-1*i) * c.resolution) if now.Truncate(c.resolution).After(c.lastUpdated.Truncate(c.resolution)) { c.values[c.getBucket(now)] = 0 } else { break } } } func (c *RollingCounter) sum() int64 { out := int64(0) for _, v := range c.values { out += int64(v) } return out } oxy-1.3.0/memmetrics/counter_test.go000066400000000000000000000011071404246664300175540ustar00rootroot00000000000000package memmetrics import ( "testing" "time" "github.com/mailgun/timetools" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestCloneExpired(t *testing.T) { clockTest := &timetools.FreezedTime{ CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), } cnt, err := NewCounter(3, time.Second, CounterClock(clockTest)) require.NoError(t, err) cnt.Inc(1) clockTest.Sleep(time.Second) cnt.Inc(1) clockTest.Sleep(time.Second) cnt.Inc(1) clockTest.Sleep(time.Second) out := cnt.Clone() assert.EqualValues(t, 2, out.Count()) } oxy-1.3.0/memmetrics/histogram.go000066400000000000000000000123251404246664300170370ustar00rootroot00000000000000package memmetrics import ( "fmt" "time" "github.com/codahale/hdrhistogram" "github.com/mailgun/timetools" ) // HDRHistogram is a tiny wrapper around github.com/codahale/hdrhistogram that provides convenience functions for measuring http latencies type HDRHistogram struct { // lowest trackable value low int64 // highest trackable value high int64 // significant figures sigfigs int h *hdrhistogram.Histogram } // NewHDRHistogram creates a new HDRHistogram func NewHDRHistogram(low, high int64, sigfigs int) (h *HDRHistogram, err error) { defer func() { if msg := recover(); msg != nil { err = fmt.Errorf("%s", msg) } }() return &HDRHistogram{ low: low, high: high, sigfigs: sigfigs, h: hdrhistogram.New(low, high, sigfigs), }, nil } // Export export a HDRHistogram func (h *HDRHistogram) Export() *HDRHistogram { var hist *hdrhistogram.Histogram if h.h != nil { snapshot := h.h.Export() hist = hdrhistogram.Import(snapshot) } return &HDRHistogram{low: h.low, high: h.high, sigfigs: h.sigfigs, h: hist} } // LatencyAtQuantile sets latency at quantile with microsecond precision func (h *HDRHistogram) LatencyAtQuantile(q float64) time.Duration { return time.Duration(h.ValueAtQuantile(q)) * time.Microsecond } // RecordLatencies Records latencies with microsecond precision func (h *HDRHistogram) RecordLatencies(d time.Duration, n int64) error { return h.RecordValues(int64(d/time.Microsecond), n) } // Reset reset a HDRHistogram func (h *HDRHistogram) Reset() { h.h.Reset() } // ValueAtQuantile sets value at quantile func (h *HDRHistogram) ValueAtQuantile(q float64) int64 { return h.h.ValueAtQuantile(q) } // RecordValues sets record values func (h *HDRHistogram) RecordValues(v, n int64) error { return h.h.RecordValues(v, n) } // Merge merge a HDRHistogram func (h *HDRHistogram) Merge(other *HDRHistogram) error { if other == nil { return fmt.Errorf("other is nil") } h.h.Merge(other.h) return nil } type rhOptSetter func(r *RollingHDRHistogram) error // RollingClock sets a clock func RollingClock(clock timetools.TimeProvider) rhOptSetter { return func(r 
*RollingHDRHistogram) error { r.clock = clock return nil } } // RollingHDRHistogram holds multiple histograms and rotates every period. // It provides resulting histogram as a result of a call of 'Merged' function. type RollingHDRHistogram struct { idx int lastRoll time.Time period time.Duration bucketCount int low int64 high int64 sigfigs int buckets []*HDRHistogram clock timetools.TimeProvider } // NewRollingHDRHistogram created a new RollingHDRHistogram func NewRollingHDRHistogram(low, high int64, sigfigs int, period time.Duration, bucketCount int, options ...rhOptSetter) (*RollingHDRHistogram, error) { rh := &RollingHDRHistogram{ bucketCount: bucketCount, period: period, low: low, high: high, sigfigs: sigfigs, } for _, o := range options { if err := o(rh); err != nil { return nil, err } } if rh.clock == nil { rh.clock = &timetools.RealTime{} } buckets := make([]*HDRHistogram, rh.bucketCount) for i := range buckets { h, err := NewHDRHistogram(low, high, sigfigs) if err != nil { return nil, err } buckets[i] = h } rh.buckets = buckets return rh, nil } // Export export a RollingHDRHistogram func (r *RollingHDRHistogram) Export() *RollingHDRHistogram { export := &RollingHDRHistogram{} export.idx = r.idx export.lastRoll = r.lastRoll export.period = r.period export.bucketCount = r.bucketCount export.low = r.low export.high = r.high export.sigfigs = r.sigfigs export.clock = r.clock exportBuckets := make([]*HDRHistogram, len(r.buckets)) for i, hist := range r.buckets { exportBuckets[i] = hist.Export() } export.buckets = exportBuckets return export } // Append append a RollingHDRHistogram func (r *RollingHDRHistogram) Append(o *RollingHDRHistogram) error { if r.bucketCount != o.bucketCount || r.period != o.period || r.low != o.low || r.high != o.high || r.sigfigs != o.sigfigs { return fmt.Errorf("can't merge") } for i := range r.buckets { if err := r.buckets[i].Merge(o.buckets[i]); err != nil { return err } } return nil } // Reset reset a RollingHDRHistogram func (r *RollingHDRHistogram) Reset() { r.idx = 0 r.lastRoll = r.clock.UtcNow() for _, b := range r.buckets { b.Reset() } } func (r *RollingHDRHistogram) rotate() { r.idx = (r.idx + 1) % len(r.buckets) r.buckets[r.idx].Reset() } // Merged gets merged histogram func (r *RollingHDRHistogram) Merged() (*HDRHistogram, error) { m, err := NewHDRHistogram(r.low, r.high, r.sigfigs) if err != nil { return m, err } for _, h := range r.buckets { if errMerge := m.Merge(h); errMerge != nil { return nil, errMerge } } return m, nil } func (r *RollingHDRHistogram) getHist() *HDRHistogram { if r.clock.UtcNow().Sub(r.lastRoll) >= r.period { r.rotate() r.lastRoll = r.clock.UtcNow() } return r.buckets[r.idx] } // RecordLatencies sets records latencies func (r *RollingHDRHistogram) RecordLatencies(v time.Duration, n int64) error { return r.getHist().RecordLatencies(v, n) } // RecordValues set record values func (r *RollingHDRHistogram) RecordValues(v, n int64) error { return r.getHist().RecordValues(v, n) } oxy-1.3.0/memmetrics/histogram_test.go000066400000000000000000000102301404246664300200670ustar00rootroot00000000000000package memmetrics import ( "testing" "time" "github.com/codahale/hdrhistogram" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/testutils" ) func TestMerge(t *testing.T) { a, err := NewHDRHistogram(1, 3600000, 2) require.NoError(t, err) err = a.RecordValues(1, 2) require.NoError(t, err) b, err := NewHDRHistogram(1, 3600000, 2) require.NoError(t, err) require.NoError(t, b.RecordValues(2, 1)) 
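// Merging b into a should make a's quantiles reflect both recorded values, as asserted below.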
require.NoError(t, a.Merge(b)) assert.EqualValues(t, 1, a.ValueAtQuantile(50)) assert.EqualValues(t, 2, a.ValueAtQuantile(100)) } func TestInvalidParams(t *testing.T) { _, err := NewHDRHistogram(1, 3600000, 0) require.Error(t, err) } func TestMergeNil(t *testing.T) { a, err := NewHDRHistogram(1, 3600000, 1) require.NoError(t, err) require.Error(t, a.Merge(nil)) } func TestRotation(t *testing.T) { clock := testutils.GetClock() h, err := NewRollingHDRHistogram( 1, // min value 3600000, // max value 3, // significant figures time.Second, // 1 second is a rolling period 2, // 2 histograms in a window RollingClock(clock)) require.NoError(t, err) require.NotNil(t, h) err = h.RecordValues(5, 1) require.NoError(t, err) m, err := h.Merged() require.NoError(t, err) assert.EqualValues(t, 5, m.ValueAtQuantile(100)) clock.CurrentTime = clock.CurrentTime.Add(time.Second) require.NoError(t, h.RecordValues(2, 1)) require.NoError(t, h.RecordValues(1, 1)) m, err = h.Merged() require.NoError(t, err) assert.EqualValues(t, 5, m.ValueAtQuantile(100)) // rotate, this means that the old value would evaporate clock.CurrentTime = clock.CurrentTime.Add(time.Second) require.NoError(t, h.RecordValues(1, 1)) m, err = h.Merged() require.NoError(t, err) assert.EqualValues(t, 2, m.ValueAtQuantile(100)) } func TestReset(t *testing.T) { clock := testutils.GetClock() h, err := NewRollingHDRHistogram( 1, // min value 3600000, // max value 3, // significant figures time.Second, // 1 second is a rolling period 2, // 2 histograms in a window RollingClock(clock)) require.NoError(t, err) require.NotNil(t, h) require.NoError(t, h.RecordValues(5, 1)) m, err := h.Merged() require.NoError(t, err) assert.EqualValues(t, 5, m.ValueAtQuantile(100)) clock.CurrentTime = clock.CurrentTime.Add(time.Second) require.NoError(t, h.RecordValues(2, 1)) require.NoError(t, h.RecordValues(1, 1)) m, err = h.Merged() require.NoError(t, err) assert.EqualValues(t, 5, m.ValueAtQuantile(100)) h.Reset() require.NoError(t, h.RecordValues(5, 1)) m, err = h.Merged() require.NoError(t, err) assert.EqualValues(t, 5, m.ValueAtQuantile(100)) clock.CurrentTime = clock.CurrentTime.Add(time.Second) require.NoError(t, h.RecordValues(2, 1)) require.NoError(t, h.RecordValues(1, 1)) m, err = h.Merged() require.NoError(t, err) assert.EqualValues(t, 5, m.ValueAtQuantile(100)) } func TestHDRHistogramExportReturnsNewCopy(t *testing.T) { // Create HDRHistogram instance a := HDRHistogram{ low: 1, high: 2, sigfigs: 3, h: hdrhistogram.New(0, 1, 2), } // Get a copy and modify the original b := a.Export() a.low = 11 a.high = 12 a.sigfigs = 4 a.h = nil // Assert the copy has not been modified assert.EqualValues(t, 1, b.low) assert.EqualValues(t, 2, b.high) assert.Equal(t, 3, b.sigfigs) require.NotNil(t, b.h) } func TestRollingHDRHistogramExportReturnsNewCopy(t *testing.T) { origTime := time.Now() a := RollingHDRHistogram{ idx: 1, lastRoll: origTime, period: 2 * time.Second, bucketCount: 3, low: 4, high: 5, sigfigs: 1, buckets: []*HDRHistogram{}, clock: testutils.GetClock(), } b := a.Export() a.idx = 11 a.lastRoll = time.Now().Add(1 * time.Minute) a.period = 12 * time.Second a.bucketCount = 13 a.low = 14 a.high = 15 a.sigfigs = 1 a.buckets = nil a.clock = nil assert.Equal(t, 1, b.idx) assert.Equal(t, origTime, b.lastRoll) assert.Equal(t, 2*time.Second, b.period) assert.Equal(t, 3, b.bucketCount) assert.Equal(t, int64(4), b.low) assert.EqualValues(t, 5, b.high) assert.NotNil(t, b.buckets) assert.NotNil(t, b.clock) } 
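Since histogram.go above only defines the types and their methods, here is a minimal, hypothetical usage sketch of the rolling histogram. The parameters mirror TestRotation above, and the import path comes from the go.mod earlier in the archive:

package main

import (
	"fmt"
	"log"
	"time"

	"github.com/vulcand/oxy/memmetrics"
)

func main() {
	// Values 1..3600000 (microseconds), 3 significant figures, a 1 second rolling
	// period and 2 sub-histograms in the window, as in TestRotation above.
	rh, err := memmetrics.NewRollingHDRHistogram(1, 3600000, 3, time.Second, 2)
	if err != nil {
		log.Fatal(err)
	}

	// Latencies are recorded with microsecond precision into the current bucket.
	if err := rh.RecordLatencies(15*time.Millisecond, 1); err != nil {
		log.Fatal(err)
	}
	if err := rh.RecordLatencies(40*time.Millisecond, 1); err != nil {
		log.Fatal(err)
	}

	// Merged folds the buckets of the rolling window into a single HDRHistogram.
	m, err := rh.Merged()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(m.LatencyAtQuantile(50), m.LatencyAtQuantile(99))
}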
oxy-1.3.0/memmetrics/ratio.go000066400000000000000000000051631404246664300161620ustar00rootroot00000000000000package memmetrics import ( "time" "github.com/mailgun/timetools" ) type ratioOptSetter func(r *RatioCounter) error // RatioClock sets a clock func RatioClock(clock timetools.TimeProvider) ratioOptSetter { return func(r *RatioCounter) error { r.clock = clock return nil } } // RatioCounter calculates a ratio of a/a+b over a rolling window of predefined buckets type RatioCounter struct { clock timetools.TimeProvider a *RollingCounter b *RollingCounter } // NewRatioCounter creates a new RatioCounter func NewRatioCounter(buckets int, resolution time.Duration, options ...ratioOptSetter) (*RatioCounter, error) { rc := &RatioCounter{} for _, o := range options { if err := o(rc); err != nil { return nil, err } } if rc.clock == nil { rc.clock = &timetools.RealTime{} } a, err := NewCounter(buckets, resolution, CounterClock(rc.clock)) if err != nil { return nil, err } b, err := NewCounter(buckets, resolution, CounterClock(rc.clock)) if err != nil { return nil, err } rc.a = a rc.b = b return rc, nil } // Reset reset the counter func (r *RatioCounter) Reset() { r.a.Reset() r.b.Reset() } // IsReady returns true if the counter is ready func (r *RatioCounter) IsReady() bool { return r.a.countedBuckets+r.b.countedBuckets >= len(r.a.values) } // CountA gets count A func (r *RatioCounter) CountA() int64 { return r.a.Count() } // CountB gets count B func (r *RatioCounter) CountB() int64 { return r.b.Count() } // Resolution gets resolution func (r *RatioCounter) Resolution() time.Duration { return r.a.Resolution() } // Buckets gets buckets func (r *RatioCounter) Buckets() int { return r.a.Buckets() } // WindowSize gets windows size func (r *RatioCounter) WindowSize() time.Duration { return r.a.WindowSize() } // ProcessedCount gets processed count func (r *RatioCounter) ProcessedCount() int64 { return r.CountA() + r.CountB() } // Ratio gets ratio func (r *RatioCounter) Ratio() float64 { a := r.a.Count() b := r.b.Count() // No data, return ok if a+b == 0 { return 0 } return float64(a) / float64(a+b) } // IncA increment counter A func (r *RatioCounter) IncA(v int) { r.a.Inc(v) } // IncB increment counter B func (r *RatioCounter) IncB(v int) { r.b.Inc(v) } // TestMeter a test meter type TestMeter struct { Rate float64 NotReady bool WindowSize time.Duration } // GetWindowSize gets windows size func (tm *TestMeter) GetWindowSize() time.Duration { return tm.WindowSize } // IsReady returns true if the meter is ready func (tm *TestMeter) IsReady() bool { return !tm.NotReady } // GetRate gets rate func (tm *TestMeter) GetRate() float64 { return tm.Rate } oxy-1.3.0/memmetrics/ratio_test.go000066400000000000000000000101021404246664300172060ustar00rootroot00000000000000package memmetrics import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/testutils" ) func TestNewRatioCounterInvalidParams(t *testing.T) { clock := testutils.GetClock() // Bad buckets count _, err := NewRatioCounter(0, time.Second, RatioClock(clock)) require.Error(t, err) // Too precise resolution _, err = NewRatioCounter(10, time.Millisecond, RatioClock(clock)) require.Error(t, err) } func TestNotReady(t *testing.T) { clock := testutils.GetClock() // No data fr, err := NewRatioCounter(10, time.Second, RatioClock(clock)) require.NoError(t, err) assert.Equal(t, false, fr.IsReady()) assert.Equal(t, 0.0, fr.Ratio()) // Not enough data fr, err = NewRatioCounter(10, time.Second, 
RatioClock(clock)) require.NoError(t, err) fr.CountA() assert.Equal(t, false, fr.IsReady()) } func TestNoB(t *testing.T) { fr, err := NewRatioCounter(1, time.Second, RatioClock(testutils.GetClock())) require.NoError(t, err) fr.IncA(1) assert.Equal(t, true, fr.IsReady()) assert.Equal(t, 1.0, fr.Ratio()) } func TestNoA(t *testing.T) { fr, err := NewRatioCounter(1, time.Second, RatioClock(testutils.GetClock())) require.NoError(t, err) fr.IncB(1) assert.Equal(t, true, fr.IsReady()) assert.Equal(t, 0.0, fr.Ratio()) } // Make sure that data is properly calculated over several buckets func TestMultipleBuckets(t *testing.T) { clock := testutils.GetClock() fr, err := NewRatioCounter(3, time.Second, RatioClock(clock)) require.NoError(t, err) fr.IncB(1) clock.CurrentTime = clock.CurrentTime.Add(time.Second) fr.IncA(1) clock.CurrentTime = clock.CurrentTime.Add(time.Second) fr.IncA(1) assert.Equal(t, true, fr.IsReady()) assert.Equal(t, float64(2)/float64(3), fr.Ratio()) } // Make sure that data is properly calculated over several buckets // When we overwrite old data when the window is rolling func TestOverwriteBuckets(t *testing.T) { clock := testutils.GetClock() fr, err := NewRatioCounter(3, time.Second, RatioClock(clock)) require.NoError(t, err) fr.IncB(1) clock.CurrentTime = clock.CurrentTime.Add(time.Second) fr.IncA(1) clock.CurrentTime = clock.CurrentTime.Add(time.Second) fr.IncA(1) // This time we should overwrite the old data points clock.CurrentTime = clock.CurrentTime.Add(time.Second) fr.IncA(1) fr.IncB(2) assert.Equal(t, true, fr.IsReady()) assert.Equal(t, float64(3)/float64(5), fr.Ratio()) } // Make sure we cleanup the data after periods of inactivity // So it does not mess up the stats func TestInactiveBuckets(t *testing.T) { clock := testutils.GetClock() fr, err := NewRatioCounter(3, time.Second, RatioClock(clock)) require.NoError(t, err) fr.IncB(1) clock.CurrentTime = clock.CurrentTime.Add(time.Second) fr.IncA(1) clock.CurrentTime = clock.CurrentTime.Add(time.Second) fr.IncA(1) // This time we should overwrite the old data points with new data clock.CurrentTime = clock.CurrentTime.Add(time.Second) fr.IncA(1) fr.IncB(2) // Jump to the last bucket and change the data clock.CurrentTime = clock.CurrentTime.Add(time.Second * 2) fr.IncB(1) assert.Equal(t, true, fr.IsReady()) assert.Equal(t, float64(1)/float64(4), fr.Ratio()) } func TestLongPeriodsOfInactivity(t *testing.T) { clock := testutils.GetClock() fr, err := NewRatioCounter(2, time.Second, RatioClock(clock)) require.NoError(t, err) fr.IncB(1) clock.CurrentTime = clock.CurrentTime.Add(time.Second) fr.IncA(1) assert.Equal(t, true, fr.IsReady()) assert.Equal(t, 0.5, fr.Ratio()) // This time we should overwrite all data points clock.CurrentTime = clock.CurrentTime.Add(100 * time.Second) fr.IncA(1) assert.Equal(t, 1.0, fr.Ratio()) } func TestNewRatioCounterReset(t *testing.T) { fr, err := NewRatioCounter(1, time.Second, RatioClock(testutils.GetClock())) require.NoError(t, err) fr.IncB(1) fr.IncA(1) assert.Equal(t, true, fr.IsReady()) assert.Equal(t, 0.5, fr.Ratio()) // Reset the counter fr.Reset() assert.Equal(t, false, fr.IsReady()) // Now add some stats fr.IncA(2) // We are game again! 
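// With a single one-second bucket, one new sample is enough for IsReady to flip back to true.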
assert.Equal(t, true, fr.IsReady()) assert.Equal(t, 1.0, fr.Ratio()) } oxy-1.3.0/memmetrics/roundtrip.go000066400000000000000000000166141404246664300170750ustar00rootroot00000000000000package memmetrics import ( "errors" "net/http" "sync" "time" "github.com/mailgun/timetools" ) // RTMetrics provides aggregated performance metrics for HTTP requests processing // such as round trip latency, response codes counters network error and total requests. // all counters are collected as rolling window counters with defined precision, histograms // are a rolling window histograms with defined precision as well. // See RTOptions for more detail on parameters. type RTMetrics struct { total *RollingCounter netErrors *RollingCounter statusCodes map[int]*RollingCounter statusCodesLock sync.RWMutex histogram *RollingHDRHistogram histogramLock sync.RWMutex newCounter NewCounterFn newHist NewRollingHistogramFn clock timetools.TimeProvider } type rrOptSetter func(r *RTMetrics) error // NewRTMetricsFn builder function type type NewRTMetricsFn func() (*RTMetrics, error) // NewCounterFn builder function type type NewCounterFn func() (*RollingCounter, error) // NewRollingHistogramFn builder function type type NewRollingHistogramFn func() (*RollingHDRHistogram, error) // RTCounter set a builder function for Counter func RTCounter(new NewCounterFn) rrOptSetter { return func(r *RTMetrics) error { r.newCounter = new return nil } } // RTHistogram set a builder function for RollingHistogram func RTHistogram(fn NewRollingHistogramFn) rrOptSetter { return func(r *RTMetrics) error { r.newHist = fn return nil } } // RTClock sets a clock func RTClock(clock timetools.TimeProvider) rrOptSetter { return func(r *RTMetrics) error { r.clock = clock return nil } } // NewRTMetrics returns new instance of metrics collector. 
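// The RTCounter, RTHistogram and RTClock options above can be passed as settings to override
// the default rolling counter, rolling histogram and clock.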
func NewRTMetrics(settings ...rrOptSetter) (*RTMetrics, error) { m := &RTMetrics{ statusCodes: make(map[int]*RollingCounter), statusCodesLock: sync.RWMutex{}, } for _, s := range settings { if err := s(m); err != nil { return nil, err } } if m.clock == nil { m.clock = &timetools.RealTime{} } if m.newCounter == nil { m.newCounter = func() (*RollingCounter, error) { return NewCounter(counterBuckets, counterResolution, CounterClock(m.clock)) } } if m.newHist == nil { m.newHist = func() (*RollingHDRHistogram, error) { return NewRollingHDRHistogram(histMin, histMax, histSignificantFigures, histPeriod, histBuckets, RollingClock(m.clock)) } } h, err := m.newHist() if err != nil { return nil, err } netErrors, err := m.newCounter() if err != nil { return nil, err } total, err := m.newCounter() if err != nil { return nil, err } m.histogram = h m.netErrors = netErrors m.total = total return m, nil } // Export Returns a new RTMetrics which is a copy of the current one func (m *RTMetrics) Export() *RTMetrics { m.statusCodesLock.RLock() defer m.statusCodesLock.RUnlock() m.histogramLock.RLock() defer m.histogramLock.RUnlock() export := &RTMetrics{} export.statusCodesLock = sync.RWMutex{} export.histogramLock = sync.RWMutex{} export.total = m.total.Clone() export.netErrors = m.netErrors.Clone() exportStatusCodes := map[int]*RollingCounter{} for code, rollingCounter := range m.statusCodes { exportStatusCodes[code] = rollingCounter.Clone() } export.statusCodes = exportStatusCodes if m.histogram != nil { export.histogram = m.histogram.Export() } export.newCounter = m.newCounter export.newHist = m.newHist export.clock = m.clock return export } // CounterWindowSize gets total windows size func (m *RTMetrics) CounterWindowSize() time.Duration { return m.total.WindowSize() } // NetworkErrorRatio calculates the amont of network errors such as time outs and dropped connection // that occurred in the given time window compared to the total requests count. func (m *RTMetrics) NetworkErrorRatio() float64 { if m.total.Count() == 0 { return 0 } return float64(m.netErrors.Count()) / float64(m.total.Count()) } // ResponseCodeRatio calculates ratio of count(startA to endA) / count(startB to endB) func (m *RTMetrics) ResponseCodeRatio(startA, endA, startB, endB int) float64 { a := int64(0) b := int64(0) m.statusCodesLock.RLock() defer m.statusCodesLock.RUnlock() for code, v := range m.statusCodes { if code < endA && code >= startA { a += v.Count() } if code < endB && code >= startB { b += v.Count() } } if b != 0 { return float64(a) / float64(b) } return 0 } // Append append a metric func (m *RTMetrics) Append(other *RTMetrics) error { if m == other { return errors.New("RTMetrics cannot append to self") } if err := m.total.Append(other.total); err != nil { return err } if err := m.netErrors.Append(other.netErrors); err != nil { return err } copied := other.Export() m.statusCodesLock.Lock() defer m.statusCodesLock.Unlock() m.histogramLock.Lock() defer m.histogramLock.Unlock() for code, c := range copied.statusCodes { o, ok := m.statusCodes[code] if ok { if err := o.Append(c); err != nil { return err } } else { m.statusCodes[code] = c.Clone() } } return m.histogram.Append(copied.histogram) } // Record records a metric func (m *RTMetrics) Record(code int, duration time.Duration) { m.total.Inc(1) if code == http.StatusGatewayTimeout || code == http.StatusBadGateway { m.netErrors.Inc(1) } m.recordStatusCode(code) m.recordLatency(duration) } // TotalCount returns total count of processed requests collected. 
func (m *RTMetrics) TotalCount() int64 { return m.total.Count() } // NetworkErrorCount returns total count of processed requests observed func (m *RTMetrics) NetworkErrorCount() int64 { return m.netErrors.Count() } // StatusCodesCounts returns map with counts of the response codes func (m *RTMetrics) StatusCodesCounts() map[int]int64 { sc := make(map[int]int64) m.statusCodesLock.RLock() defer m.statusCodesLock.RUnlock() for k, v := range m.statusCodes { if v.Count() != 0 { sc[k] = v.Count() } } return sc } // LatencyHistogram computes and returns resulting histogram with latencies observed. func (m *RTMetrics) LatencyHistogram() (*HDRHistogram, error) { m.histogramLock.Lock() defer m.histogramLock.Unlock() return m.histogram.Merged() } // Reset reset metrics func (m *RTMetrics) Reset() { m.statusCodesLock.Lock() defer m.statusCodesLock.Unlock() m.histogramLock.Lock() defer m.histogramLock.Unlock() m.histogram.Reset() m.total.Reset() m.netErrors.Reset() m.statusCodes = make(map[int]*RollingCounter) } func (m *RTMetrics) recordLatency(d time.Duration) error { m.histogramLock.Lock() defer m.histogramLock.Unlock() return m.histogram.RecordLatencies(d, 1) } func (m *RTMetrics) recordStatusCode(statusCode int) error { m.statusCodesLock.Lock() if c, ok := m.statusCodes[statusCode]; ok { c.Inc(1) m.statusCodesLock.Unlock() return nil } m.statusCodesLock.Unlock() m.statusCodesLock.Lock() defer m.statusCodesLock.Unlock() // Check if another goroutine has written our counter already if c, ok := m.statusCodes[statusCode]; ok { c.Inc(1) return nil } c, err := m.newCounter() if err != nil { return err } c.Inc(1) m.statusCodes[statusCode] = c return nil } const ( counterBuckets = 10 counterResolution = time.Second histMin = 1 histMax = 3600000000 // 1 hour in microseconds histSignificantFigures = 2 // significant figures (1% precision) histBuckets = 6 // number of sub-histograms in a rolling histogram histPeriod = 10 * time.Second // roll time ) oxy-1.3.0/memmetrics/roundtrip_test.go000066400000000000000000000075771404246664300201440ustar00rootroot00000000000000package memmetrics import ( "runtime" "sync" "testing" "time" "github.com/mailgun/timetools" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/testutils" ) func TestDefaults(t *testing.T) { rr, err := NewRTMetrics(RTClock(testutils.GetClock())) require.NoError(t, err) require.NotNil(t, rr) rr.Record(200, time.Second) rr.Record(502, 2*time.Second) rr.Record(200, time.Second) rr.Record(200, time.Second) assert.EqualValues(t, 1, rr.NetworkErrorCount()) assert.EqualValues(t, 4, rr.TotalCount()) assert.Equal(t, map[int]int64{502: 1, 200: 3}, rr.StatusCodesCounts()) assert.Equal(t, float64(1)/float64(4), rr.NetworkErrorRatio()) assert.Equal(t, 1.0/3.0, rr.ResponseCodeRatio(500, 503, 200, 300)) h, err := rr.LatencyHistogram() require.NoError(t, err) assert.Equal(t, 2, int(h.LatencyAtQuantile(100)/time.Second)) rr.Reset() assert.EqualValues(t, 0, rr.NetworkErrorCount()) assert.EqualValues(t, 0, rr.TotalCount()) assert.Equal(t, map[int]int64{}, rr.StatusCodesCounts()) assert.Equal(t, float64(0), rr.NetworkErrorRatio()) assert.Equal(t, float64(0), rr.ResponseCodeRatio(500, 503, 200, 300)) h, err = rr.LatencyHistogram() require.NoError(t, err) assert.Equal(t, time.Duration(0), h.LatencyAtQuantile(100)) } func TestAppend(t *testing.T) { clock := testutils.GetClock() rr, err := NewRTMetrics(RTClock(clock)) require.NoError(t, err) require.NotNil(t, rr) rr.Record(200, time.Second) rr.Record(502, 2*time.Second) 
rr.Record(200, time.Second) rr.Record(200, time.Second) rr2, err := NewRTMetrics(RTClock(clock)) require.NoError(t, err) require.NotNil(t, rr2) rr2.Record(200, 3*time.Second) rr2.Record(501, 3*time.Second) rr2.Record(200, 3*time.Second) rr2.Record(200, 3*time.Second) require.NoError(t, rr2.Append(rr)) assert.Equal(t, map[int]int64{501: 1, 502: 1, 200: 6}, rr2.StatusCodesCounts()) assert.EqualValues(t, 1, rr2.NetworkErrorCount()) h, err := rr2.LatencyHistogram() require.NoError(t, err) assert.EqualValues(t, 3, h.LatencyAtQuantile(100)/time.Second) } func TestConcurrentRecords(t *testing.T) { // This test asserts a race condition which requires parallelism runtime.GOMAXPROCS(100) rr, err := NewRTMetrics(RTClock(testutils.GetClock())) require.NoError(t, err) for code := 0; code < 100; code++ { for numRecords := 0; numRecords < 10; numRecords++ { go func(statusCode int) { _ = rr.recordStatusCode(statusCode) }(code) } } } func TestRTMetricExportReturnsNewCopy(t *testing.T) { a := RTMetrics{ clock: &timetools.RealTime{}, statusCodes: map[int]*RollingCounter{}, statusCodesLock: sync.RWMutex{}, histogram: &RollingHDRHistogram{}, histogramLock: sync.RWMutex{}, } var err error a.total, err = NewCounter(1, time.Second, CounterClock(a.clock)) require.NoError(t, err) a.netErrors, err = NewCounter(1, time.Second, CounterClock(a.clock)) require.NoError(t, err) a.newCounter = func() (*RollingCounter, error) { return NewCounter(counterBuckets, counterResolution, CounterClock(a.clock)) } a.newHist = func() (*RollingHDRHistogram, error) { return NewRollingHDRHistogram(histMin, histMax, histSignificantFigures, histPeriod, histBuckets, RollingClock(a.clock)) } b := a.Export() a.total = nil a.netErrors = nil a.statusCodes = nil a.histogram = nil a.newCounter = nil a.newHist = nil a.clock = nil assert.NotNil(t, b.total) assert.NotNil(t, b.netErrors) assert.NotNil(t, b.statusCodes) assert.NotNil(t, b.histogram) assert.NotNil(t, b.newCounter) assert.NotNil(t, b.newHist) assert.NotNil(t, b.clock) // a and b should have different locks locksSucceed := make(chan bool) go func() { a.statusCodesLock.Lock() b.statusCodesLock.Lock() a.histogramLock.Lock() b.histogramLock.Lock() locksSucceed <- true }() for { select { case <-locksSucceed: return case <-time.After(10 * time.Second): t.FailNow() } } } oxy-1.3.0/ratelimit/000077500000000000000000000000001404246664300143355ustar00rootroot00000000000000oxy-1.3.0/ratelimit/bucket.go000066400000000000000000000107271404246664300161500ustar00rootroot00000000000000package ratelimit import ( "fmt" "time" "github.com/mailgun/timetools" ) // UndefinedDelay default delay const UndefinedDelay = -1 // rate defines token bucket parameters. type rate struct { period time.Duration average int64 burst int64 } func (r *rate) String() string { return fmt.Sprintf("rate(%v/%v, burst=%v)", r.average, r.period, r.burst) } // tokenBucket Implements token bucket algorithm (http://en.wikipedia.org/wiki/Token_bucket) type tokenBucket struct { // The time period controlled by the bucket in nanoseconds. period time.Duration // The number of nanoseconds that takes to add one more token to the total // number of available tokens. It effectively caches the value that could // have been otherwise deduced from refillRate. timePerToken time.Duration // The maximum number of tokens that can be accumulate in the bucket. burst int64 // The number of tokens available for consumption at the moment. It can // nether be larger then capacity. 
availableTokens int64 // Interface that gives current time (so tests can override) clock timetools.TimeProvider // Tells when tokensAvailable was updated the last time. lastRefresh time.Time // The number of tokens consumed the last time. lastConsumed int64 } // newTokenBucket crates a `tokenBucket` instance for the specified `Rate`. func newTokenBucket(rate *rate, clock timetools.TimeProvider) *tokenBucket { period := rate.period if period == 0 { period = time.Nanosecond } return &tokenBucket{ period: period, timePerToken: time.Duration(int64(period) / rate.average), burst: rate.burst, clock: clock, lastRefresh: clock.UtcNow(), availableTokens: rate.burst, } } // consume makes an attempt to consume the specified number of tokens from the // bucket. If there are enough tokens available then `0, nil` is returned; if // tokens to consume is larger than the burst size, then an error is returned // and the delay is not defined; otherwise returned a none zero delay that tells // how much time the caller needs to wait until the desired number of tokens // will become available for consumption. func (tb *tokenBucket) consume(tokens int64) (time.Duration, error) { tb.updateAvailableTokens() tb.lastConsumed = 0 if tokens > tb.burst { return UndefinedDelay, fmt.Errorf("requested tokens larger than max tokens") } if tb.availableTokens < tokens { return tb.timeTillAvailable(tokens), nil } tb.availableTokens -= tokens tb.lastConsumed = tokens return 0, nil } // rollback reverts effect of the most recent consumption. If the most recent // `consume` resulted in an error or a burst overflow, and therefore did not // modify the number of available tokens, then `rollback` won't do that either. // It is safe to call this method multiple times, for the second and all // following calls have no effect. func (tb *tokenBucket) rollback() { tb.availableTokens += tb.lastConsumed tb.lastConsumed = 0 } // update modifies `average` and `burst` fields of the token bucket according // to the provided `Rate` func (tb *tokenBucket) update(rate *rate) error { if rate.period != tb.period { return fmt.Errorf("period mismatch: %v != %v", tb.period, rate.period) } tb.timePerToken = time.Duration(int64(tb.period) / rate.average) tb.burst = rate.burst if tb.availableTokens > rate.burst { tb.availableTokens = rate.burst } return nil } // timeTillAvailable returns the number of nanoseconds that we need to // wait until the specified number of tokens becomes available for consumption. func (tb *tokenBucket) timeTillAvailable(tokens int64) time.Duration { missingTokens := tokens - tb.availableTokens return time.Duration(missingTokens) * tb.timePerToken } // updateAvailableTokens updates the number of tokens available for consumption. // It is calculated based on the refill rate, the time passed since last refresh, // and is limited by the bucket capacity. 
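//
// Worked example (numbers chosen for illustration): with average=10 and
// period=1s, timePerToken is 100ms. If 250ms have passed since lastRefresh,
// int64(timePassed/timePerToken) adds 2 tokens and lastRefresh jumps to now;
// if only 50ms have passed, no token is added and lastRefresh is left untouched,
// so the elapsed time keeps accumulating until a full token's worth has passed.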
func (tb *tokenBucket) updateAvailableTokens() { now := tb.clock.UtcNow() timePassed := now.Sub(tb.lastRefresh) if tb.timePerToken == 0 { return } tokens := tb.availableTokens + int64(timePassed/tb.timePerToken) // If we haven't added any tokens that means that not enough time has passed, // in this case do not adjust last refill checkpoint, otherwise it will be // always moving in time in case of frequent requests that exceed the rate if tokens != tb.availableTokens { tb.lastRefresh = now tb.availableTokens = tokens } if tb.availableTokens > tb.burst { tb.availableTokens = tb.burst } } oxy-1.3.0/ratelimit/bucket_test.go000066400000000000000000000235761404246664300172150ustar00rootroot00000000000000package ratelimit import ( "testing" "time" "github.com/mailgun/timetools" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/testutils" ) func TestConsumeSingleToken(t *testing.T) { clock := testutils.GetClock() tb := newTokenBucket(&rate{period: time.Second, average: 1, burst: 1}, clock) // First request passes delay, err := tb.consume(1) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) // Next request does not pass the same second delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, time.Second, delay) // Second later, the request passes clock.Sleep(time.Second) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) // Five seconds later, still only one request is allowed // because maxBurst is 1 clock.Sleep(5 * time.Second) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) // The next one is forbidden delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, time.Second, delay) } func TestFastConsumption(t *testing.T) { clock := testutils.GetClock() tb := newTokenBucket(&rate{period: time.Second, average: 1, burst: 1}, clock) // First request passes delay, err := tb.consume(1) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) // Try 200 ms later clock.Sleep(time.Millisecond * 200) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, time.Second, delay) // Try 700 ms later clock.Sleep(time.Millisecond * 700) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, time.Second, delay) // Try 100 ms later, success! 
clock.Sleep(time.Millisecond * 100) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) } func TestConsumeMultipleTokens(t *testing.T) { tb := newTokenBucket(&rate{period: time.Second, average: 3, burst: 5}, testutils.GetClock()) delay, err := tb.consume(3) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) delay, err = tb.consume(2) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) delay, err = tb.consume(1) require.NoError(t, err) assert.NotEqual(t, time.Duration(0), delay) } func TestDelayIsCorrect(t *testing.T) { clock := testutils.GetClock() tb := newTokenBucket(&rate{period: time.Second, average: 3, burst: 5}, clock) // Exhaust initial capacity delay, err := tb.consume(5) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) delay, err = tb.consume(3) require.NoError(t, err) assert.NotEqual(t, time.Duration(0), delay) // Now wait provided delay and make sure we can consume now clock.Sleep(delay) delay, err = tb.consume(3) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) } // Make sure requests that exceed burst size are not allowed func TestExceedsBurst(t *testing.T) { tb := newTokenBucket(&rate{period: time.Second, average: 1, burst: 10}, testutils.GetClock()) _, err := tb.consume(11) require.Error(t, err) } func TestConsumeBurst(t *testing.T) { tb := newTokenBucket(&rate{period: time.Second, average: 2, burst: 5}, testutils.GetClock()) // In two seconds we would have 5 tokens testutils.GetClock().Sleep(2 * time.Second) // Lets consume 5 at once delay, err := tb.consume(5) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) } func TestConsumeEstimate(t *testing.T) { tb := newTokenBucket(&rate{period: time.Second, average: 2, burst: 4}, testutils.GetClock()) // Consume all burst at once delay, err := tb.consume(4) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) // Now try to consume it and face delay delay, err = tb.consume(4) require.NoError(t, err) assert.Equal(t, time.Duration(2)*time.Second, delay) } // If a rate with different period is passed to the `update` method, then an // error is returned but the state of the bucket remains valid and unchanged. func TestUpdateInvalidPeriod(t *testing.T) { clock := testutils.GetClock() // Given tb := newTokenBucket(&rate{period: time.Second, average: 10, burst: 20}, clock) _, err := tb.consume(15) // 5 tokens available require.NoError(t, err) // When err = tb.update(&rate{period: time.Second + 1, average: 30, burst: 40}) // still 5 tokens available require.Error(t, err) // Then // ...check that rate did not change clock.Sleep(500 * time.Millisecond) delay, err := tb.consume(11) require.NoError(t, err) assert.Equal(t, 100*time.Millisecond, delay) delay, err = tb.consume(10) require.NoError(t, err) // 0 available assert.Equal(t, time.Duration(0), delay) // ...check that burst did not change clock.Sleep(40 * time.Second) _, err = tb.consume(21) require.Error(t, err) delay, err = tb.consume(20) require.NoError(t, err) // 0 available assert.Equal(t, time.Duration(0), delay) } // If the capacity of the bucket is increased by the update then it takes some // time to fill the bucket with tokens up to the new capacity. 
func TestUpdateBurstIncreased(t *testing.T) {
	clock := testutils.GetClock()

	// Given
	tb := newTokenBucket(&rate{period: time.Second, average: 10, burst: 20}, clock)
	_, err := tb.consume(15) // 5 tokens available
	require.NoError(t, err)

	// When
	err = tb.update(&rate{period: time.Second, average: 10, burst: 50}) // still 5 tokens available
	require.NoError(t, err)

	// Then
	delay, err := tb.consume(50)
	require.NoError(t, err)
	assert.Equal(t, time.Duration(time.Second/10*45), delay)
}

// If the capacity of the bucket is decreased by the update then the number of
// available tokens is clipped to the new, smaller capacity.
func TestUpdateBurstDecreased(t *testing.T) {
	clock := testutils.GetClock()

	// Given
	tb := newTokenBucket(&rate{period: time.Second, average: 10, burst: 50}, clock)
	_, err := tb.consume(15) // 35 tokens available
	require.NoError(t, err)

	// When
	err = tb.update(&rate{period: time.Second, average: 10, burst: 20}) // the number of available tokens reduced to 20.
	require.NoError(t, err)

	// Then
	delay, err := tb.consume(21)
	require.Error(t, err)
	assert.Equal(t, time.Duration(-1), delay)
}

// If rate is updated then it affects the bucket refill speed.
func TestUpdateRateChanged(t *testing.T) {
	clock := testutils.GetClock()

	// Given
	tb := newTokenBucket(&rate{period: time.Second, average: 10, burst: 20}, clock)
	_, err := tb.consume(15) // 5 tokens available
	require.NoError(t, err)

	// When
	err = tb.update(&rate{period: time.Second, average: 20, burst: 20}) // still 5 tokens available
	require.NoError(t, err)

	// Then
	delay, err := tb.consume(20)
	require.NoError(t, err)
	assert.Equal(t, time.Duration(time.Second/20*15), delay)
}

// Only the most recent consumption is reverted by `Rollback`.
func TestRollback(t *testing.T) {
	clock := testutils.GetClock()

	// Given
	tb := newTokenBucket(&rate{period: time.Second, average: 10, burst: 20}, clock)
	_, err := tb.consume(8) // 12 tokens available
	require.NoError(t, err)
	_, err = tb.consume(7) // 5 tokens available
	require.NoError(t, err)

	// When
	tb.rollback() // 12 tokens available

	// Then
	delay, err := tb.consume(12)
	require.NoError(t, err)
	assert.Equal(t, time.Duration(0), delay)
	delay, err = tb.consume(1)
	require.NoError(t, err)
	assert.Equal(t, 100*time.Millisecond, delay)
}

// It is safe to call `Rollback` several times. The second and all subsequent
// calls just do nothing.
func TestRollbackSeveralTimes(t *testing.T) {
	// Given
	tb := newTokenBucket(&rate{period: time.Second, average: 10, burst: 20}, testutils.GetClock())
	_, err := tb.consume(8) // 12 tokens available
	require.NoError(t, err)
	tb.rollback() // 20 tokens available

	// When
	tb.rollback() // still 20 tokens available
	tb.rollback() // still 20 tokens available
	tb.rollback() // still 20 tokens available

	// Then: all 20 tokens can be consumed
	delay, err := tb.consume(20)
	require.NoError(t, err)
	assert.Equal(t, time.Duration(0), delay)
	delay, err = tb.consume(1)
	require.NoError(t, err)
	assert.Equal(t, 100*time.Millisecond, delay)
}

// If previous consumption returned a delay due to an attempt to consume more
// tokens than there are available, then `Rollback` has no effect.
func TestRollbackAfterAvailableExceeded(t *testing.T) { // Given tb := newTokenBucket(&rate{period: time.Second, average: 10, burst: 20}, testutils.GetClock()) _, err := tb.consume(8) // 12 tokens available require.NoError(t, err) delay, err := tb.consume(15) // still 12 tokens available require.NoError(t, err) assert.Equal(t, 300*time.Millisecond, delay) // When tb.rollback() // Previous operation consumed 0 tokens, so rollback has no effect. // Then delay, err = tb.consume(12) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, 100*time.Millisecond, delay) } // If previous consumption returned a error due to an attempt to consume more // tokens then the bucket's burst size, then `Rollback` has no effect. func TestRollbackAfterError(t *testing.T) { clock := testutils.GetClock() // Given tb := newTokenBucket(&rate{period: time.Second, average: 10, burst: 20}, clock) _, err := tb.consume(8) // 12 tokens available require.NoError(t, err) delay, err := tb.consume(21) // still 12 tokens available require.Error(t, err) assert.Equal(t, time.Duration(-1), delay) // When tb.rollback() // Previous operation consumed 0 tokens, so rollback has no effect. // Then delay, err = tb.consume(12) require.NoError(t, err) assert.Equal(t, time.Duration(0), delay) delay, err = tb.consume(1) require.NoError(t, err) assert.Equal(t, 100*time.Millisecond, delay) } func TestDivisionByZeroOnPeriod(t *testing.T) { clock := &timetools.RealTime{} var emptyPeriod int64 tb := newTokenBucket(&rate{period: time.Duration(emptyPeriod), average: 2, burst: 2}, clock) _, err := tb.consume(1) assert.NoError(t, err) err = tb.update(&rate{period: time.Nanosecond, average: 1, burst: 1}) assert.NoError(t, err) } oxy-1.3.0/ratelimit/bucketset.go000066400000000000000000000063221404246664300166600ustar00rootroot00000000000000package ratelimit import ( "fmt" "sort" "strings" "time" "github.com/mailgun/timetools" ) // TokenBucketSet represents a set of TokenBucket covering different time periods. type TokenBucketSet struct { buckets map[time.Duration]*tokenBucket maxPeriod time.Duration clock timetools.TimeProvider } // NewTokenBucketSet creates a `TokenBucketSet` from the specified `rates`. func NewTokenBucketSet(rates *RateSet, clock timetools.TimeProvider) *TokenBucketSet { tbs := new(TokenBucketSet) tbs.clock = clock // In the majority of cases we will have only one bucket. tbs.buckets = make(map[time.Duration]*tokenBucket, len(rates.m)) for _, rate := range rates.m { newBucket := newTokenBucket(rate, clock) tbs.buckets[rate.period] = newBucket tbs.maxPeriod = maxDuration(tbs.maxPeriod, rate.period) } return tbs } // Update brings the buckets in the set in accordance with the provided `rates`. func (tbs *TokenBucketSet) Update(rates *RateSet) { // Update existing buckets and delete those that have no corresponding spec. for _, bucket := range tbs.buckets { if rate, ok := rates.m[bucket.period]; ok { bucket.update(rate) } else { delete(tbs.buckets, bucket.period) } } // Add missing buckets. 
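	// A bucket that appears in the new rate set but is not tracked yet is created
	// from scratch; note that newTokenBucket starts it with a full burst of tokens
	// (availableTokens is initialised to rate.burst).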
for _, rate := range rates.m { if _, ok := tbs.buckets[rate.period]; !ok { newBucket := newTokenBucket(rate, tbs.clock) tbs.buckets[rate.period] = newBucket } } // Identify the maximum period in the set tbs.maxPeriod = 0 for _, bucket := range tbs.buckets { tbs.maxPeriod = maxDuration(tbs.maxPeriod, bucket.period) } } // Consume consume tokens func (tbs *TokenBucketSet) Consume(tokens int64) (time.Duration, error) { var maxDelay time.Duration = UndefinedDelay var firstErr error for _, tokenBucket := range tbs.buckets { // We keep calling `Consume` even after a error is returned for one of // buckets because that allows us to simplify the rollback procedure, // that is to just call `Rollback` for all buckets. delay, err := tokenBucket.consume(tokens) if firstErr == nil { if err != nil { firstErr = err } else { maxDelay = maxDuration(maxDelay, delay) } } } // If we could not make ALL buckets consume tokens for whatever reason, // then rollback consumption for all of them. if firstErr != nil || maxDelay > 0 { for _, tokenBucket := range tbs.buckets { tokenBucket.rollback() } } return maxDelay, firstErr } // GetMaxPeriod returns the max period func (tbs *TokenBucketSet) GetMaxPeriod() time.Duration { return tbs.maxPeriod } // debugState returns string that reflects the current state of all buckets in // this set. It is intended to be used for debugging and testing only. func (tbs *TokenBucketSet) debugState() string { periods := sort.IntSlice(make([]int, 0, len(tbs.buckets))) for period := range tbs.buckets { periods = append(periods, int(period)) } sort.Sort(periods) bucketRepr := make([]string, 0, len(tbs.buckets)) for _, period := range periods { bucket := tbs.buckets[time.Duration(period)] bucketRepr = append(bucketRepr, fmt.Sprintf("{%v: %v}", bucket.period, bucket.availableTokens)) } return strings.Join(bucketRepr, ", ") } func maxDuration(x time.Duration, y time.Duration) time.Duration { if x > y { return x } return y } oxy-1.3.0/ratelimit/bucketset_test.go000066400000000000000000000143051404246664300177170ustar00rootroot00000000000000package ratelimit import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/testutils" ) // A value returned by `MaxPeriod` corresponds to the longest bucket time period. func TestLongestPeriod(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(1*time.Second, 10, 20)) require.NoError(t, rates.Add(7*time.Second, 10, 20)) require.NoError(t, rates.Add(5*time.Second, 11, 21)) clock := testutils.GetClock() // When tbs := NewTokenBucketSet(rates, clock) // Then assert.Equal(t, 7*time.Second, tbs.maxPeriod) } // Successful token consumption updates state of all buckets in the set. func TestConsume(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(1*time.Second, 10, 20)) require.NoError(t, rates.Add(10*time.Second, 20, 50)) clock := testutils.GetClock() tbs := NewTokenBucketSet(rates, clock) // When delay, err := tbs.Consume(15) require.NoError(t, err) // Then assert.Equal(t, time.Duration(0), delay) assert.Equal(t, "{1s: 5}, {10s: 35}", tbs.debugState()) } // As time goes by all set buckets are refilled with appropriate rates. 
func TestConsumeRefill(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(10*time.Second, 10, 20)) require.NoError(t, rates.Add(100*time.Second, 20, 50)) clock := testutils.GetClock() tbs := NewTokenBucketSet(rates, clock) _, err := tbs.Consume(15) require.NoError(t, err) assert.Equal(t, "{10s: 5}, {1m40s: 35}", tbs.debugState()) // When clock.Sleep(10 * time.Second) delay, err := tbs.Consume(0) // Consumes nothing but forces an internal state update. require.NoError(t, err) // Then assert.Equal(t, time.Duration(0), delay) assert.Equal(t, "{10s: 15}, {1m40s: 37}", tbs.debugState()) } // If the first bucket in the set has no enough tokens to allow desired // consumption then an appropriate delay is returned. func TestConsumeLimitedBy1st(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(10*time.Second, 10, 10)) require.NoError(t, rates.Add(100*time.Second, 20, 20)) clock := testutils.GetClock() tbs := NewTokenBucketSet(rates, clock) _, err := tbs.Consume(5) require.NoError(t, err) assert.Equal(t, "{10s: 5}, {1m40s: 15}", tbs.debugState()) // When delay, err := tbs.Consume(10) require.NoError(t, err) // Then assert.Equal(t, 5*time.Second, delay) assert.Equal(t, "{10s: 5}, {1m40s: 15}", tbs.debugState()) } // If the second bucket in the set has no enough tokens to allow desired // consumption then an appropriate delay is returned. func TestConsumeLimitedBy2st(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(10*time.Second, 10, 10)) require.NoError(t, rates.Add(100*time.Second, 20, 20)) clock := testutils.GetClock() tbs := NewTokenBucketSet(rates, clock) _, err := tbs.Consume(10) require.NoError(t, err) clock.Sleep(10 * time.Second) _, err = tbs.Consume(10) require.NoError(t, err) clock.Sleep(5 * time.Second) _, err = tbs.Consume(0) require.NoError(t, err) assert.Equal(t, "{10s: 5}, {1m40s: 3}", tbs.debugState()) // When delay, err := tbs.Consume(10) require.NoError(t, err) // Then assert.Equal(t, 7*(5*time.Second), delay) assert.Equal(t, "{10s: 5}, {1m40s: 3}", tbs.debugState()) } // An attempt to consume more tokens then the smallest bucket capacity results // in error. func TestConsumeMoreThenBurst(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(1*time.Second, 10, 20)) require.NoError(t, rates.Add(10*time.Second, 50, 100)) clock := testutils.GetClock() tbs := NewTokenBucketSet(rates, clock) _, err := tbs.Consume(5) require.NoError(t, err) assert.Equal(t, "{1s: 15}, {10s: 95}", tbs.debugState()) // When _, err = tbs.Consume(21) require.Error(t, err) // Then assert.Equal(t, "{1s: 15}, {10s: 95}", tbs.debugState()) } // Update operation can add buckets. func TestUpdateMore(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(1*time.Second, 10, 20)) require.NoError(t, rates.Add(10*time.Second, 20, 50)) require.NoError(t, rates.Add(20*time.Second, 45, 90)) clock := testutils.GetClock() tbs := NewTokenBucketSet(rates, clock) _, err := tbs.Consume(5) require.NoError(t, err) assert.Equal(t, "{1s: 15}, {10s: 45}, {20s: 85}", tbs.debugState()) rates = NewRateSet() require.NoError(t, rates.Add(10*time.Second, 30, 40)) require.NoError(t, rates.Add(11*time.Second, 30, 40)) require.NoError(t, rates.Add(12*time.Second, 30, 40)) require.NoError(t, rates.Add(13*time.Second, 30, 40)) // When tbs.Update(rates) // Then assert.Equal(t, "{10s: 40}, {11s: 40}, {12s: 40}, {13s: 40}", tbs.debugState()) assert.Equal(t, 13*time.Second, tbs.maxPeriod) } // Update operation can remove buckets. 
func TestUpdateLess(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(1*time.Second, 10, 20)) require.NoError(t, rates.Add(10*time.Second, 20, 50)) require.NoError(t, rates.Add(20*time.Second, 45, 90)) require.NoError(t, rates.Add(30*time.Second, 50, 100)) clock := testutils.GetClock() tbs := NewTokenBucketSet(rates, clock) _, err := tbs.Consume(5) require.NoError(t, err) assert.Equal(t, "{1s: 15}, {10s: 45}, {20s: 85}, {30s: 95}", tbs.debugState()) rates = NewRateSet() require.NoError(t, rates.Add(10*time.Second, 25, 20)) require.NoError(t, rates.Add(20*time.Second, 30, 21)) // When tbs.Update(rates) // Then assert.Equal(t, "{10s: 20}, {20s: 21}", tbs.debugState()) assert.Equal(t, 20*time.Second, tbs.maxPeriod) } // Update operation can remove buckets. func TestUpdateAllDifferent(t *testing.T) { // Given rates := NewRateSet() require.NoError(t, rates.Add(10*time.Second, 20, 50)) require.NoError(t, rates.Add(30*time.Second, 50, 100)) clock := testutils.GetClock() tbs := NewTokenBucketSet(rates, clock) _, err := tbs.Consume(5) require.NoError(t, err) assert.Equal(t, "{10s: 45}, {30s: 95}", tbs.debugState()) rates = NewRateSet() require.NoError(t, rates.Add(1*time.Second, 10, 40)) require.NoError(t, rates.Add(60*time.Second, 100, 150)) // When tbs.Update(rates) // Then assert.Equal(t, "{1s: 40}, {1m0s: 150}", tbs.debugState()) assert.Equal(t, 60*time.Second, tbs.maxPeriod) } oxy-1.3.0/ratelimit/tokenlimiter.go000066400000000000000000000147031404246664300173770ustar00rootroot00000000000000// Package ratelimit Tokenbucket based request rate limiter package ratelimit import ( "fmt" "net/http" "sync" "time" "github.com/mailgun/timetools" "github.com/mailgun/ttlmap" log "github.com/sirupsen/logrus" "github.com/vulcand/oxy/utils" ) // DefaultCapacity default capacity const DefaultCapacity = 65536 // RateSet maintains a set of rates. It can contain only one rate per period at a time. type RateSet struct { m map[time.Duration]*rate } // NewRateSet crates an empty `RateSet` instance. func NewRateSet() *RateSet { rs := new(RateSet) rs.m = make(map[time.Duration]*rate) return rs } // Add adds a rate to the set. If there is a rate with the same period in the // set then the new rate overrides the old one. func (rs *RateSet) Add(period time.Duration, average int64, burst int64) error { if period <= 0 { return fmt.Errorf("invalid period: %v", period) } if average <= 0 { return fmt.Errorf("invalid average: %v", average) } if burst <= 0 { return fmt.Errorf("invalid burst: %v", burst) } rs.m[period] = &rate{period: period, average: average, burst: burst} return nil } func (rs *RateSet) String() string { return fmt.Sprint(rs.m) } // RateExtractor rate extractor type RateExtractor interface { Extract(r *http.Request) (*RateSet, error) } // RateExtractorFunc rate extractor function type type RateExtractorFunc func(r *http.Request) (*RateSet, error) // Extract extract from request func (e RateExtractorFunc) Extract(r *http.Request) (*RateSet, error) { return e(r) } // TokenLimiter implements rate limiting middleware. type TokenLimiter struct { defaultRates *RateSet extract utils.SourceExtractor extractRates RateExtractor clock timetools.TimeProvider mutex sync.Mutex bucketSets *ttlmap.TtlMap errHandler utils.ErrorHandler capacity int next http.Handler log *log.Logger } // New constructs a `TokenLimiter` middleware instance. 
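//
// A minimal wiring sketch (illustrative only; `next` stands for whatever
// http.Handler you want to protect, and the header-based extractor is an
// assumption made for the example, not part of this package):
//
//	rates := NewRateSet()
//	if err := rates.Add(time.Second, 100, 200); err != nil {
//		// handle error
//	}
//	extract := utils.ExtractorFunc(func(req *http.Request) (string, int64, error) {
//		return req.Header.Get("X-Client-Id"), 1, nil
//	})
//	tl, err := New(next, extract, rates)
//	if err != nil {
//		// handle error
//	}
//	// tl is an http.Handler, e.g. http.ListenAndServe(":8080", tl)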
func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet, opts ...TokenLimiterOption) (*TokenLimiter, error) { if defaultRates == nil || len(defaultRates.m) == 0 { return nil, fmt.Errorf("provide default rates") } if extract == nil { return nil, fmt.Errorf("provide extract function") } tl := &TokenLimiter{ next: next, defaultRates: defaultRates, extract: extract, log: log.StandardLogger(), } for _, o := range opts { if err := o(tl); err != nil { return nil, err } } setDefaults(tl) bucketSets, err := ttlmap.NewMapWithProvider(tl.capacity, tl.clock) if err != nil { return nil, err } tl.bucketSets = bucketSets return tl, nil } // Logger defines the logger the token limiter will use. // // It defaults to logrus.StandardLogger(), the global logger used by logrus. func Logger(l *log.Logger) TokenLimiterOption { return func(tl *TokenLimiter) error { tl.log = l return nil } } // Wrap sets the next handler to be called by token limiter handler. func (tl *TokenLimiter) Wrap(next http.Handler) { tl.next = next } func (tl *TokenLimiter) ServeHTTP(w http.ResponseWriter, req *http.Request) { source, amount, err := tl.extract.Extract(req) if err != nil { tl.errHandler.ServeHTTP(w, req, err) return } if err := tl.consumeRates(req, source, amount); err != nil { tl.log.Warnf("limiting request %v %v, limit: %v", req.Method, req.URL, err) tl.errHandler.ServeHTTP(w, req, err) return } tl.next.ServeHTTP(w, req) } func (tl *TokenLimiter) consumeRates(req *http.Request, source string, amount int64) error { tl.mutex.Lock() defer tl.mutex.Unlock() effectiveRates := tl.resolveRates(req) bucketSetI, exists := tl.bucketSets.Get(source) var bucketSet *TokenBucketSet if exists { bucketSet = bucketSetI.(*TokenBucketSet) bucketSet.Update(effectiveRates) } else { bucketSet = NewTokenBucketSet(effectiveRates, tl.clock) // We set ttl as 10 times rate period. E.g. if rate is 100 requests/second per client ip // the counters for this ip will expire after 10 seconds of inactivity tl.bucketSets.Set(source, bucketSet, int(bucketSet.maxPeriod/time.Second)*10+1) } delay, err := bucketSet.Consume(amount) if err != nil { return err } if delay > 0 { return &MaxRateError{Delay: delay} } return nil } // effectiveRates retrieves rates to be applied to the request. func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet { // If configuration mapper is not specified for this instance, then return // the default bucket specs. if tl.extractRates == nil { return tl.defaultRates } rates, err := tl.extractRates.Extract(req) if err != nil { tl.log.Errorf("Failed to retrieve rates: %v", err) return tl.defaultRates } // If the returned rate set is empty then used the default one. 
if len(rates.m) == 0 { return tl.defaultRates } return rates } // MaxRateError max rate error type MaxRateError struct { Delay time.Duration } func (m *MaxRateError) Error() string { return fmt.Sprintf("max rate reached: retry-in %v", m.Delay) } // RateErrHandler error handler type RateErrHandler struct{} func (e *RateErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { if rerr, ok := err.(*MaxRateError); ok { w.Header().Set("Retry-After", fmt.Sprintf("%.0f", rerr.Delay.Seconds())) w.Header().Set("X-Retry-In", rerr.Delay.String()) w.WriteHeader(http.StatusTooManyRequests) w.Write([]byte(err.Error())) return } utils.DefaultHandler.ServeHTTP(w, req, err) } // TokenLimiterOption token limiter option type type TokenLimiterOption func(l *TokenLimiter) error // ErrorHandler sets error handler of the server func ErrorHandler(h utils.ErrorHandler) TokenLimiterOption { return func(cl *TokenLimiter) error { cl.errHandler = h return nil } } // ExtractRates sets the rate extractor func ExtractRates(e RateExtractor) TokenLimiterOption { return func(cl *TokenLimiter) error { cl.extractRates = e return nil } } // Clock sets the clock func Clock(clock timetools.TimeProvider) TokenLimiterOption { return func(cl *TokenLimiter) error { cl.clock = clock return nil } } // Capacity sets the capacity func Capacity(cap int) TokenLimiterOption { return func(cl *TokenLimiter) error { if cap <= 0 { return fmt.Errorf("bad capacity: %v", cap) } cl.capacity = cap return nil } } var defaultErrHandler = &RateErrHandler{} func setDefaults(tl *TokenLimiter) { if tl.capacity <= 0 { tl.capacity = DefaultCapacity } if tl.clock == nil { tl.clock = &timetools.RealTime{} } if tl.errHandler == nil { tl.errHandler = defaultErrHandler } } oxy-1.3.0/ratelimit/tokenlimiter_test.go000066400000000000000000000227141404246664300204370ustar00rootroot00000000000000package ratelimit import ( "fmt" "net/http" "net/http/httptest" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/testutils" "github.com/vulcand/oxy/utils" ) func TestRateSetAdd(t *testing.T) { rs := NewRateSet() // Invalid period err := rs.Add(0, 1, 1) require.Error(t, err) // Invalid Average err = rs.Add(time.Second, 0, 1) require.Error(t, err) // Invalid Burst err = rs.Add(time.Second, 1, 0) require.Error(t, err) err = rs.Add(time.Second, 1, 1) require.NoError(t, err) assert.Equal(t, fmt.Sprint(rs), "map[1s:rate(1/1s, burst=1)]") } // We've hit the limit and were able to proceed on the next time run func TestHitLimit(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) rates := NewRateSet() err := rates.Add(time.Second, 1, 1) require.NoError(t, err) clock := testutils.GetClock() l, err := New(handler, headerLimit, rates, Clock(clock)) require.NoError(t, err) srv := httptest.NewServer(l) defer srv.Close() re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) // Next request from the same source hits rate limit re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, 429, re.StatusCode) // Second later, the request from this ip will succeed clock.Sleep(time.Second) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // We've failed to extract client ip func TestFailure(t *testing.T) { handler := 
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) rates := NewRateSet() err := rates.Add(time.Second, 1, 1) require.NoError(t, err) clock := testutils.GetClock() l, err := New(handler, faultyExtract, rates, Clock(clock)) require.NoError(t, err) srv := httptest.NewServer(l) defer srv.Close() re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, re.StatusCode) } // Make sure rates from different ips are controlled separately func TestIsolation(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) rates := NewRateSet() err := rates.Add(time.Second, 1, 1) require.NoError(t, err) clock := testutils.GetClock() l, err := New(handler, headerLimit, rates, Clock(clock)) require.NoError(t, err) srv := httptest.NewServer(l) defer srv.Close() re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) // Next request from the same source hits rate limit re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, 429, re.StatusCode) // The request from other source can proceed re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "b")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // Make sure that expiration works (Expiration is triggered after significant amount of time passes) func TestExpiration(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) rates := NewRateSet() err := rates.Add(time.Second, 1, 1) require.NoError(t, err) clock := testutils.GetClock() l, err := New(handler, headerLimit, rates, Clock(clock)) require.NoError(t, err) srv := httptest.NewServer(l) defer srv.Close() re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) // Next request from the same source hits rate limit re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, 429, re.StatusCode) // 24 hours later, the request from this ip will succeed clock.Sleep(24 * time.Hour) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // If rate limiting configuration is valid, then it is applied. 
func TestExtractRates(t *testing.T) { // Given extractRates := func(*http.Request) (*RateSet, error) { rates := NewRateSet() err := rates.Add(time.Second, 2, 2) if err != nil { return nil, err } err = rates.Add(60*time.Second, 10, 10) if err != nil { return nil, err } return rates, nil } rates := NewRateSet() err := rates.Add(time.Second, 1, 1) require.NoError(t, err) handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) clock := testutils.GetClock() tl, err := New(handler, headerLimit, rates, Clock(clock), ExtractRates(RateExtractorFunc(extractRates))) require.NoError(t, err) srv := httptest.NewServer(tl) defer srv.Close() // When/Then: The configured rate is applied, which 2 req/second re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, 429, re.StatusCode) clock.Sleep(time.Second) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // If configMapper returns error, then the default rate is applied. func TestBadRateExtractor(t *testing.T) { // Given extractor := func(*http.Request) (*RateSet, error) { return nil, fmt.Errorf("boom") } rates := NewRateSet() err := rates.Add(time.Second, 1, 1) require.NoError(t, err) handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) clock := testutils.GetClock() l, err := New(handler, headerLimit, rates, Clock(clock), ExtractRates(RateExtractorFunc(extractor))) require.NoError(t, err) srv := httptest.NewServer(l) defer srv.Close() // When/Then: The default rate is applied, which 1 req/second re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, 429, re.StatusCode) clock.Sleep(time.Second) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // If configMapper returns empty rates, then the default rate is applied. 
func TestExtractorEmpty(t *testing.T) { // Given extractor := func(*http.Request) (*RateSet, error) { return NewRateSet(), nil } rates := NewRateSet() err := rates.Add(time.Second, 1, 1) require.NoError(t, err) handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) clock := testutils.GetClock() l, err := New(handler, headerLimit, rates, Clock(clock), ExtractRates(RateExtractorFunc(extractor))) require.NoError(t, err) srv := httptest.NewServer(l) defer srv.Close() // When/Then: The default rate is applied, which 1 req/second re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, 429, re.StatusCode) clock.Sleep(time.Second) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } func TestInvalidParams(t *testing.T) { // Rates are missing rs := NewRateSet() err := rs.Add(time.Second, 1, 1) require.NoError(t, err) // Empty _, err = New(nil, nil, rs) require.Error(t, err) // Rates are empty _, err = New(nil, nil, NewRateSet()) require.Error(t, err) // Bad capacity _, err = New(nil, headerLimit, rs, Capacity(-1)) require.Error(t, err) } // We've hit the limit and were able to proceed on the next time run func TestOptions(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) rates := NewRateSet() err := rates.Add(time.Second, 1, 1) require.NoError(t, err) errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, req *http.Request, err error) { w.WriteHeader(http.StatusTeapot) w.Write([]byte(http.StatusText(http.StatusTeapot))) }) clock := testutils.GetClock() l, err := New(handler, headerLimit, rates, ErrorHandler(errHandler), Clock(clock)) require.NoError(t, err) srv := httptest.NewServer(l) defer srv.Close() re, _, err := testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) re, _, err = testutils.Get(srv.URL, testutils.Header("Source", "a")) require.NoError(t, err) assert.Equal(t, http.StatusTeapot, re.StatusCode) } func headerLimiter(req *http.Request) (string, int64, error) { return req.Header.Get("Source"), 1, nil } func faultyExtractor(_ *http.Request) (string, int64, error) { return "", -1, fmt.Errorf("oops") } var headerLimit = utils.ExtractorFunc(headerLimiter) var faultyExtract = utils.ExtractorFunc(faultyExtractor) oxy-1.3.0/roundrobin/000077500000000000000000000000001404246664300145245ustar00rootroot00000000000000oxy-1.3.0/roundrobin/RequestRewriteListener.go000066400000000000000000000002521404246664300215520ustar00rootroot00000000000000package roundrobin import "net/http" // RequestRewriteListener function to rewrite request type RequestRewriteListener func(oldReq *http.Request, newReq *http.Request) oxy-1.3.0/roundrobin/rebalancer.go000066400000000000000000000277501404246664300171640ustar00rootroot00000000000000package roundrobin import ( "fmt" "net/http" "net/url" "sync" "time" "github.com/mailgun/timetools" log "github.com/sirupsen/logrus" "github.com/vulcand/oxy/memmetrics" "github.com/vulcand/oxy/utils" ) // RebalancerOption - functional option setter for rebalancer type RebalancerOption func(*Rebalancer) error // Meter measures server performance and returns it's relative value via rating type Meter interface { 
Rating() float64 Record(int, time.Duration) IsReady() bool } // NewMeterFn type of functions to create new Meter type NewMeterFn func() (Meter, error) // Rebalancer increases weights on servers that perform better than others. It also rolls back to original weights // if the servers have changed. It is designed as a wrapper on top of the roundrobin. type Rebalancer struct { // mutex mtx *sync.Mutex // As usual, control time in tests clock timetools.TimeProvider // Time that freezes state machine to accumulate stats after updating the weights backoffDuration time.Duration // Timer is set to give probing some time to take place timer time.Time // server records that remember original weights servers []*rbServer // next is internal load balancer next in chain next balancerHandler // errHandler is HTTP handler called in case of errors errHandler utils.ErrorHandler ratings []float64 // creates new meters newMeter NewMeterFn // sticky session object stickySession *StickySession requestRewriteListener RequestRewriteListener log *log.Logger } // RebalancerClock sets a clock func RebalancerClock(clock timetools.TimeProvider) RebalancerOption { return func(r *Rebalancer) error { r.clock = clock return nil } } // RebalancerBackoff sets a beck off duration func RebalancerBackoff(d time.Duration) RebalancerOption { return func(r *Rebalancer) error { r.backoffDuration = d return nil } } // RebalancerMeter sets a Meter builder function func RebalancerMeter(newMeter NewMeterFn) RebalancerOption { return func(r *Rebalancer) error { r.newMeter = newMeter return nil } } // RebalancerErrorHandler is a functional argument that sets error handler of the server func RebalancerErrorHandler(h utils.ErrorHandler) RebalancerOption { return func(r *Rebalancer) error { r.errHandler = h return nil } } // RebalancerStickySession sets a sticky session func RebalancerStickySession(stickySession *StickySession) RebalancerOption { return func(r *Rebalancer) error { r.stickySession = stickySession return nil } } // RebalancerRequestRewriteListener is a functional argument that sets error handler of the server func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOption { return func(r *Rebalancer) error { r.requestRewriteListener = rrl return nil } } // NewRebalancer creates a new Rebalancer func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalancer, error) { rb := &Rebalancer{ mtx: &sync.Mutex{}, next: handler, stickySession: nil, log: log.StandardLogger(), } for _, o := range opts { if err := o(rb); err != nil { return nil, err } } if rb.clock == nil { rb.clock = &timetools.RealTime{} } if rb.backoffDuration == 0 { rb.backoffDuration = 10 * time.Second } if rb.newMeter == nil { rb.newMeter = func() (Meter, error) { rc, err := memmetrics.NewRatioCounter(10, time.Second, memmetrics.RatioClock(rb.clock)) if err != nil { return nil, err } return &codeMeter{ r: rc, codeS: http.StatusInternalServerError, codeE: http.StatusGatewayTimeout + 1, }, nil } } if rb.errHandler == nil { rb.errHandler = utils.DefaultHandler } return rb, nil } // RebalancerLogger defines the logger the rebalancer will use. // // It defaults to logrus.StandardLogger(), the global logger used by logrus. 
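//
// A minimal wiring sketch (illustrative only; the forwarder, the custom logger
// and the server URL are assumptions made for the example):
//
//	fwd, _ := forward.New()
//	lb, _ := New(fwd)
//	rb, _ := NewRebalancer(lb, RebalancerLogger(log.New()))
//	_ = rb.UpsertServer(testutils.ParseURI("http://localhost:8080"))
//	// rb now serves traffic as an http.Handler and adjusts server weights
//	// based on the error ratios reported by each server's Meter.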
func RebalancerLogger(l *log.Logger) RebalancerOption { return func(rb *Rebalancer) error { rb.log = l return nil } } // Servers gets all servers func (rb *Rebalancer) Servers() []*url.URL { rb.mtx.Lock() defer rb.mtx.Unlock() return rb.next.Servers() } func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { if rb.log.Level >= log.DebugLevel { logEntry := rb.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/roundrobin/rebalancer: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/roundrobin/rebalancer: completed ServeHttp on request") } pw := utils.NewProxyWriter(w) start := rb.clock.UtcNow() // make shallow copy of request before changing anything to avoid side effects newReq := *req stuck := false if rb.stickySession != nil { cookieUrl, present, err := rb.stickySession.GetBackend(&newReq, rb.Servers()) if err != nil { log.Warnf("vulcand/oxy/roundrobin/rebalancer: error using server from cookie: %v", err) } if present { newReq.URL = cookieUrl stuck = true } } if !stuck { fwdURL, err := rb.next.NextServer() if err != nil { rb.errHandler.ServeHTTP(w, req, err) return } if log.GetLevel() >= log.DebugLevel { // log which backend URL we're sending this request to log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": fwdURL}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL") } if rb.stickySession != nil { rb.stickySession.StickBackend(fwdURL, &w) } newReq.URL = fwdURL } // Emit event to a listener if one exists if rb.requestRewriteListener != nil { rb.requestRewriteListener(req, &newReq) } rb.next.Next().ServeHTTP(pw, &newReq) rb.recordMetrics(newReq.URL, pw.StatusCode(), rb.clock.UtcNow().Sub(start)) rb.adjustWeights() } func (rb *Rebalancer) recordMetrics(u *url.URL, code int, latency time.Duration) { rb.mtx.Lock() defer rb.mtx.Unlock() if srv, i := rb.findServer(u); i != -1 { srv.meter.Record(code, latency) } } func (rb *Rebalancer) reset() { for _, s := range rb.servers { s.curWeight = s.origWeight rb.next.UpsertServer(s.url, Weight(s.origWeight)) } rb.timer = rb.clock.UtcNow().Add(-1 * time.Second) rb.ratings = make([]float64, len(rb.servers)) } // Wrap sets the next handler to be called by rebalancer handler. func (rb *Rebalancer) Wrap(next balancerHandler) error { if rb.next != nil { return fmt.Errorf("already bound to %T", rb.next) } rb.next = next return nil } // UpsertServer upsert a server func (rb *Rebalancer) UpsertServer(u *url.URL, options ...ServerOption) error { rb.mtx.Lock() defer rb.mtx.Unlock() if err := rb.next.UpsertServer(u, options...); err != nil { return err } weight, _ := rb.next.ServerWeight(u) if err := rb.upsertServer(u, weight); err != nil { rb.next.RemoveServer(u) return err } rb.reset() return nil } // RemoveServer remove a server func (rb *Rebalancer) RemoveServer(u *url.URL) error { rb.mtx.Lock() defer rb.mtx.Unlock() return rb.removeServer(u) } func (rb *Rebalancer) removeServer(u *url.URL) error { _, i := rb.findServer(u) if i == -1 { return fmt.Errorf("%v not found", u) } if err := rb.next.RemoveServer(u); err != nil { return err } rb.servers = append(rb.servers[:i], rb.servers[i+1:]...) 
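	// The i-th record has been spliced out of the slice; reset below restores the
	// remaining servers to their original weights and resets the adaptation timer
	// and ratings.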
rb.reset() return nil } func (rb *Rebalancer) upsertServer(u *url.URL, weight int) error { if s, i := rb.findServer(u); i != -1 { s.origWeight = weight } meter, err := rb.newMeter() if err != nil { return err } rbSrv := &rbServer{ url: utils.CopyURL(u), origWeight: weight, curWeight: weight, meter: meter, } rb.servers = append(rb.servers, rbSrv) return nil } func (rb *Rebalancer) findServer(u *url.URL) (*rbServer, int) { if len(rb.servers) == 0 { return nil, -1 } for i, s := range rb.servers { if sameURL(u, s.url) { return s, i } } return nil, -1 } // adjustWeights Called on every load balancer ServeHTTP call, returns the suggested weights // on every call, can adjust weights if needed. func (rb *Rebalancer) adjustWeights() { rb.mtx.Lock() defer rb.mtx.Unlock() // In this case adjusting weights would have no effect, so do nothing if len(rb.servers) < 2 { return } // Metrics are not ready if !rb.metricsReady() { return } if !rb.timerExpired() { return } if rb.markServers() { if rb.setMarkedWeights() { rb.setTimer() } } else { // No servers that are different by their quality, so converge weights if rb.convergeWeights() { rb.setTimer() } } } func (rb *Rebalancer) applyWeights() { for _, srv := range rb.servers { rb.log.Debugf("upsert server %v, weight %v", srv.url, srv.curWeight) rb.next.UpsertServer(srv.url, Weight(srv.curWeight)) } } func (rb *Rebalancer) setMarkedWeights() bool { changed := false // Increase weights on servers marked as good for _, srv := range rb.servers { if srv.good { weight := increase(srv.curWeight) if weight <= FSMMaxWeight { rb.log.Debugf("increasing weight of %v from %v to %v", srv.url, srv.curWeight, weight) srv.curWeight = weight changed = true } } } if changed { rb.normalizeWeights() rb.applyWeights() return true } return false } func (rb *Rebalancer) setTimer() { rb.timer = rb.clock.UtcNow().Add(rb.backoffDuration) } func (rb *Rebalancer) timerExpired() bool { return rb.timer.Before(rb.clock.UtcNow()) } func (rb *Rebalancer) metricsReady() bool { for _, s := range rb.servers { if !s.meter.IsReady() { return false } } return true } // markServers splits servers into two groups of servers with bad and good failure rate. // It does compare relative performances of the servers though, so if all servers have approximately the same error rate // this function returns the result as if all servers are equally good. 
func (rb *Rebalancer) markServers() bool { for i, srv := range rb.servers { rb.ratings[i] = srv.meter.Rating() } g, b := memmetrics.SplitFloat64(splitThreshold, 0, rb.ratings) for i, srv := range rb.servers { if g[rb.ratings[i]] { srv.good = true } else { srv.good = false } } if len(g) != 0 && len(b) != 0 { rb.log.Debugf("bad: %v good: %v, ratings: %v", b, g, rb.ratings) } return len(g) != 0 && len(b) != 0 } func (rb *Rebalancer) convergeWeights() bool { // If we have previously changed servers try to restore weights to the original state changed := false for _, s := range rb.servers { if s.origWeight == s.curWeight { continue } changed = true newWeight := decrease(s.origWeight, s.curWeight) log.Debugf("decreasing weight of %v from %v to %v", s.url, s.curWeight, newWeight) s.curWeight = newWeight } if !changed { return false } rb.normalizeWeights() rb.applyWeights() return true } func (rb *Rebalancer) weightsGcd() int { divisor := -1 for _, w := range rb.servers { if divisor == -1 { divisor = w.curWeight } else { divisor = gcd(divisor, w.curWeight) } } return divisor } func (rb *Rebalancer) normalizeWeights() { gcd := rb.weightsGcd() if gcd <= 1 { return } for _, s := range rb.servers { s.curWeight = s.curWeight / gcd } } func increase(weight int) int { return weight * FSMGrowFactor } func decrease(target, current int) int { adjusted := current / FSMGrowFactor if adjusted < target { return target } return adjusted } // rebalancer server record that keeps track of the original weight supplied by user type rbServer struct { url *url.URL origWeight int // original weight supplied by user curWeight int // current weight good bool meter Meter } const ( // FSMMaxWeight is the maximum weight that handler will set for the server FSMMaxWeight = 4096 // FSMGrowFactor Multiplier for the server weight FSMGrowFactor = 4 ) type codeMeter struct { r *memmetrics.RatioCounter codeS int codeE int } // Rating gets ratio func (n *codeMeter) Rating() float64 { return n.r.Ratio() } // Record records a meter func (n *codeMeter) Record(code int, d time.Duration) { if code >= n.codeS && code < n.codeE { n.r.IncA(1) } else { n.r.IncB(1) } } // IsReady returns true if the counter is ready func (n *codeMeter) IsReady() bool { return n.r.IsReady() } // splitThreshold tells how far the value should go from the median + median absolute deviation before it is considered an outlier const splitThreshold = 1.5 oxy-1.3.0/roundrobin/rebalancer_test.go000066400000000000000000000271631404246664300202210ustar00rootroot00000000000000package roundrobin import ( "io/ioutil" "net/http" "net/http/httptest" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/forward" "github.com/vulcand/oxy/testutils" ) func TestRebalancerNormalOperation(t *testing.T) { a, b := testutils.NewResponder("a"), testutils.NewResponder("b") defer a.Close() defer b.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) rb, err := NewRebalancer(lb) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) assert.Equal(t, a.URL, rb.Servers()[0].String()) proxy := httptest.NewServer(rb) defer proxy.Close() assert.Equal(t, []string{"a", "a", "a"}, seq(t, proxy.URL, 3)) } func TestRebalancerNoServers(t *testing.T) { fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) rb, err := NewRebalancer(lb) require.NoError(t, err) proxy := httptest.NewServer(rb) defer proxy.Close() re, 
_, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, re.StatusCode) } func TestRebalancerRemoveServer(t *testing.T) { a, b := testutils.NewResponder("a"), testutils.NewResponder("b") defer a.Close() defer b.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) rb, err := NewRebalancer(lb) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(rb) defer proxy.Close() assert.Equal(t, []string{"a", "b", "a"}, seq(t, proxy.URL, 3)) require.NoError(t, rb.RemoveServer(testutils.ParseURI(a.URL))) assert.Equal(t, []string{"b", "b", "b"}, seq(t, proxy.URL, 3)) } // Test scenario when one server goes down and then recovers func TestRebalancerRecovery(t *testing.T) { a, b := testutils.NewResponder("a"), testutils.NewResponder("b") defer a.Close() defer b.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) newMeter := func() (Meter, error) { return &testMeter{}, nil } clock := testutils.GetClock() rb, err := NewRebalancer(lb, RebalancerMeter(newMeter), RebalancerClock(clock)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) rb.servers[0].meter.(*testMeter).rating = 0.3 proxy := httptest.NewServer(rb) defer proxy.Close() for i := 0; i < 6; i++ { _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) clock.CurrentTime = clock.CurrentTime.Add(rb.backoffDuration + time.Second) } assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, FSMMaxWeight, rb.servers[1].curWeight) assert.Equal(t, 1, lb.servers[0].weight) assert.Equal(t, FSMMaxWeight, lb.servers[1].weight) // server a is now recovering, the weights should go back to the original state rb.servers[0].meter.(*testMeter).rating = 0 for i := 0; i < 6; i++ { _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) clock.CurrentTime = clock.CurrentTime.Add(rb.backoffDuration + time.Second) } assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, 1, rb.servers[1].curWeight) // Make sure we have applied the weights to the inner load balancer assert.Equal(t, 1, lb.servers[0].weight) assert.Equal(t, 1, lb.servers[1].weight) } // Test scenario when increasing the weight on good endpoints made it worse func TestRebalancerCascading(t *testing.T) { a, b, d := testutils.NewResponder("a"), testutils.NewResponder("b"), testutils.NewResponder("d") defer a.Close() defer b.Close() defer d.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) newMeter := func() (Meter, error) { return &testMeter{}, nil } clock := testutils.GetClock() rb, err := NewRebalancer(lb, RebalancerMeter(newMeter), RebalancerClock(clock)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(d.URL)) require.NoError(t, err) rb.servers[0].meter.(*testMeter).rating = 0.3 proxy := httptest.NewServer(rb) defer proxy.Close() for i := 0; i < 6; i++ { _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) _, _, err = testutils.Get(proxy.URL)
require.NoError(t, err) clock.CurrentTime = clock.CurrentTime.Add(rb.backoffDuration + time.Second) } // We have increased the load, and the situation became worse as the other servers started failing assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, FSMMaxWeight, rb.servers[1].curWeight) assert.Equal(t, FSMMaxWeight, rb.servers[2].curWeight) // server a is now recovering, the weights should go back to the original state rb.servers[0].meter.(*testMeter).rating = 0.3 rb.servers[1].meter.(*testMeter).rating = 0.2 rb.servers[2].meter.(*testMeter).rating = 0.2 for i := 0; i < 6; i++ { _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) clock.CurrentTime = clock.CurrentTime.Add(rb.backoffDuration + time.Second) } // the algo reverted it back assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, 1, rb.servers[1].curWeight) assert.Equal(t, 1, rb.servers[2].curWeight) } // Test scenario when all servers started failing func TestRebalancerAllBad(t *testing.T) { a, b, d := testutils.NewResponder("a"), testutils.NewResponder("b"), testutils.NewResponder("d") defer a.Close() defer b.Close() defer d.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) newMeter := func() (Meter, error) { return &testMeter{}, nil } clock := testutils.GetClock() rb, err := NewRebalancer(lb, RebalancerMeter(newMeter), RebalancerClock(clock)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(d.URL)) require.NoError(t, err) rb.servers[0].meter.(*testMeter).rating = 0.12 rb.servers[1].meter.(*testMeter).rating = 0.13 rb.servers[2].meter.(*testMeter).rating = 0.11 proxy := httptest.NewServer(rb) defer proxy.Close() for i := 0; i < 6; i++ { _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) clock.CurrentTime = clock.CurrentTime.Add(rb.backoffDuration + time.Second) } // load balancer does nothing assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, 1, rb.servers[1].curWeight) assert.Equal(t, 1, rb.servers[2].curWeight) } // Removing the server resets the state func TestRebalancerReset(t *testing.T) { a, b, d := testutils.NewResponder("a"), testutils.NewResponder("b"), testutils.NewResponder("d") defer a.Close() defer b.Close() defer d.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) newMeter := func() (Meter, error) { return &testMeter{}, nil } clock := testutils.GetClock() rb, err := NewRebalancer(lb, RebalancerMeter(newMeter), RebalancerClock(clock)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(d.URL)) require.NoError(t, err) rb.servers[0].meter.(*testMeter).rating = 0.3 rb.servers[1].meter.(*testMeter).rating = 0 rb.servers[2].meter.(*testMeter).rating = 0 proxy := httptest.NewServer(rb) defer proxy.Close() for i := 0; i < 6; i++ { _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) clock.CurrentTime = clock.CurrentTime.Add(rb.backoffDuration + time.Second) } // load balancer changed weights assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, FSMMaxWeight, 
rb.servers[1].curWeight) assert.Equal(t, FSMMaxWeight, rb.servers[2].curWeight) // Removing servers has reset the state err = rb.RemoveServer(testutils.ParseURI(d.URL)) require.NoError(t, err) assert.Equal(t, 1, rb.servers[0].curWeight) assert.Equal(t, 1, rb.servers[1].curWeight) } func TestRebalancerRequestRewriteListenerLive(t *testing.T) { a, b := testutils.NewResponder("a"), testutils.NewResponder("b") defer a.Close() defer b.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) clock := testutils.GetClock() rb, err := NewRebalancer(lb, RebalancerBackoff(time.Millisecond), RebalancerClock(clock)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI("http://localhost:62345")) require.NoError(t, err) proxy := httptest.NewServer(rb) defer proxy.Close() for i := 0; i < 1000; i++ { _, _, err = testutils.Get(proxy.URL) require.NoError(t, err) if i%10 == 0 { clock.CurrentTime = clock.CurrentTime.Add(rb.backoffDuration + time.Second) } } // load balancer changed weights assert.Equal(t, FSMMaxWeight, rb.servers[0].curWeight) assert.Equal(t, FSMMaxWeight, rb.servers[1].curWeight) assert.Equal(t, 1, rb.servers[2].curWeight) } func TestRebalancerRequestRewriteListener(t *testing.T) { a, b := testutils.NewResponder("a"), testutils.NewResponder("b") defer a.Close() defer b.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) rb, err := NewRebalancer(lb, RebalancerRequestRewriteListener(func(oldReq *http.Request, newReq *http.Request) { })) require.NoError(t, err) assert.NotNil(t, rb.requestRewriteListener) } func TestRebalancerStickySession(t *testing.T) { a, b, x := testutils.NewResponder("a"), testutils.NewResponder("b"), testutils.NewResponder("x") defer a.Close() defer b.Close() defer x.Close() sticky := NewStickySession("test") require.NotNil(t, sticky) fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) rb, err := NewRebalancer(lb, RebalancerStickySession(sticky)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) err = rb.UpsertServer(testutils.ParseURI(x.URL)) require.NoError(t, err) proxy := httptest.NewServer(rb) defer proxy.Close() for i := 0; i < 10; i++ { req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, "a", string(body)) } require.NoError(t, rb.RemoveServer(testutils.ParseURI(a.URL))) assert.Equal(t, []string{"b", "x", "b"}, seq(t, proxy.URL, 3)) require.NoError(t, rb.RemoveServer(testutils.ParseURI(b.URL))) assert.Equal(t, []string{"x", "x", "x"}, seq(t, proxy.URL, 3)) } type testMeter struct { rating float64 notReady bool } func (tm *testMeter) Rating() float64 { return tm.rating } func (tm *testMeter) Record(int, time.Duration) { } func (tm *testMeter) IsReady() bool { return !tm.notReady } oxy-1.3.0/roundrobin/rr.go000066400000000000000000000175701404246664300155100ustar00rootroot00000000000000// Package roundrobin implements dynamic weighted round robin load balancer http handler package roundrobin import ( "fmt" 
"net/http" "net/url" "sync" log "github.com/sirupsen/logrus" "github.com/vulcand/oxy/utils" ) // Weight is an optional functional argument that sets weight of the server func Weight(w int) ServerOption { return func(s *server) error { if w < 0 { return fmt.Errorf("Weight should be >= 0") } s.weight = w return nil } } // ErrorHandler is a functional argument that sets error handler of the server func ErrorHandler(h utils.ErrorHandler) LBOption { return func(s *RoundRobin) error { s.errHandler = h return nil } } // EnableStickySession enable sticky session func EnableStickySession(stickySession *StickySession) LBOption { return func(s *RoundRobin) error { s.stickySession = stickySession return nil } } // RoundRobinRequestRewriteListener is a functional argument that sets error handler of the server func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption { return func(s *RoundRobin) error { s.requestRewriteListener = rrl return nil } } // RoundRobin implements dynamic weighted round robin load balancer http handler type RoundRobin struct { mutex *sync.Mutex next http.Handler errHandler utils.ErrorHandler // Current index (starts from -1) index int servers []*server currentWeight int stickySession *StickySession requestRewriteListener RequestRewriteListener log *log.Logger } // New created a new RoundRobin func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) { rr := &RoundRobin{ next: next, index: -1, mutex: &sync.Mutex{}, servers: []*server{}, stickySession: nil, log: log.StandardLogger(), } for _, o := range opts { if err := o(rr); err != nil { return nil, err } } if rr.errHandler == nil { rr.errHandler = utils.DefaultHandler } return rr, nil } // RoundRobinLogger defines the logger the round robin load balancer will use. // // It defaults to logrus.StandardLogger(), the global logger used by logrus. 
func RoundRobinLogger(l *log.Logger) LBOption { return func(r *RoundRobin) error { r.log = l return nil } } // Next returns the next handler func (r *RoundRobin) Next() http.Handler { return r.next } func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { if r.log.Level >= log.DebugLevel { logEntry := r.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/roundrobin/rr: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/roundrobin/rr: completed ServeHttp on request") } // make shallow copy of request before chaning anything to avoid side effects newReq := *req stuck := false if r.stickySession != nil { cookieURL, present, err := r.stickySession.GetBackend(&newReq, r.Servers()) if err != nil { log.Warnf("vulcand/oxy/roundrobin/rr: error using server from cookie: %v", err) } if present { newReq.URL = cookieURL stuck = true } } if !stuck { url, err := r.NextServer() if err != nil { r.errHandler.ServeHTTP(w, req, err) return } if r.stickySession != nil { r.stickySession.StickBackend(url, &w) } newReq.URL = url } if r.log.Level >= log.DebugLevel { // log which backend URL we're sending this request to r.log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": newReq.URL}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL") } // Emit event to a listener if one exists if r.requestRewriteListener != nil { r.requestRewriteListener(req, &newReq) } r.next.ServeHTTP(w, &newReq) } // NextServer gets the next server func (r *RoundRobin) NextServer() (*url.URL, error) { srv, err := r.nextServer() if err != nil { return nil, err } return utils.CopyURL(srv.url), nil } func (r *RoundRobin) nextServer() (*server, error) { r.mutex.Lock() defer r.mutex.Unlock() if len(r.servers) == 0 { return nil, fmt.Errorf("no servers in the pool") } // The algo below may look messy, but is actually very simple // it calculates the GCD and subtracts it on every iteration, what interleaves servers // and allows us not to build an iterator every time we readjust weights // GCD across all enabled servers gcd := r.weightGcd() // Maximum weight across all enabled servers max := r.maxWeight() for { r.index = (r.index + 1) % len(r.servers) if r.index == 0 { r.currentWeight = r.currentWeight - gcd if r.currentWeight <= 0 { r.currentWeight = max if r.currentWeight == 0 { return nil, fmt.Errorf("all servers have 0 weight") } } } srv := r.servers[r.index] if srv.weight >= r.currentWeight { return srv, nil } } } // RemoveServer remove a server func (r *RoundRobin) RemoveServer(u *url.URL) error { r.mutex.Lock() defer r.mutex.Unlock() e, index := r.findServerByURL(u) if e == nil { return fmt.Errorf("server not found") } r.servers = append(r.servers[:index], r.servers[index+1:]...) 
r.resetState() return nil } // Servers gets servers URL func (r *RoundRobin) Servers() []*url.URL { r.mutex.Lock() defer r.mutex.Unlock() out := make([]*url.URL, len(r.servers)) for i, srv := range r.servers { out[i] = srv.url } return out } // ServerWeight gets the server weight func (r *RoundRobin) ServerWeight(u *url.URL) (int, bool) { r.mutex.Lock() defer r.mutex.Unlock() if s, _ := r.findServerByURL(u); s != nil { return s.weight, true } return -1, false } // UpsertServer In case if server is already present in the load balancer, returns error func (r *RoundRobin) UpsertServer(u *url.URL, options ...ServerOption) error { r.mutex.Lock() defer r.mutex.Unlock() if u == nil { return fmt.Errorf("server URL can't be nil") } if s, _ := r.findServerByURL(u); s != nil { for _, o := range options { if err := o(s); err != nil { return err } } r.resetState() return nil } srv := &server{url: utils.CopyURL(u)} for _, o := range options { if err := o(srv); err != nil { return err } } if srv.weight == 0 { srv.weight = defaultWeight } r.servers = append(r.servers, srv) r.resetState() return nil } func (r *RoundRobin) resetIterator() { r.index = -1 r.currentWeight = 0 } func (r *RoundRobin) resetState() { r.resetIterator() } func (r *RoundRobin) findServerByURL(u *url.URL) (*server, int) { if len(r.servers) == 0 { return nil, -1 } for i, s := range r.servers { if sameURL(u, s.url) { return s, i } } return nil, -1 } func (r *RoundRobin) maxWeight() int { max := -1 for _, s := range r.servers { if s.weight > max { max = s.weight } } return max } func (r *RoundRobin) weightGcd() int { divisor := -1 for _, s := range r.servers { if divisor == -1 { divisor = s.weight } else { divisor = gcd(divisor, s.weight) } } return divisor } func gcd(a, b int) int { for b != 0 { a, b = b, a%b } return a } // ServerOption provides various options for server, e.g. 
weight type ServerOption func(*server) error // LBOption provides options for load balancer type LBOption func(*RoundRobin) error // Additional parameters for the server can be supplied when adding the server type server struct { url *url.URL // Relative weight for the endpoint compared to other endpoints in the load balancer weight int } var defaultWeight = 1 // SetDefaultWeight sets the default server weight func SetDefaultWeight(weight int) error { if weight < 0 { return fmt.Errorf("default weight should be >= 0") } defaultWeight = weight return nil } func sameURL(a, b *url.URL) bool { return a.Path == b.Path && a.Host == b.Host && a.Scheme == b.Scheme } type balancerHandler interface { Servers() []*url.URL ServeHTTP(w http.ResponseWriter, req *http.Request) ServerWeight(u *url.URL) (int, bool) RemoveServer(u *url.URL) error UpsertServer(u *url.URL, options ...ServerOption) error NextServer() (*url.URL, error) Next() http.Handler } oxy-1.3.0/roundrobin/rr_test.go000066400000000000000000000127551404246664300165470ustar00rootroot00000000000000package roundrobin import ( "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/forward" "github.com/vulcand/oxy/testutils" "github.com/vulcand/oxy/utils" ) func TestNoServers(t *testing.T) { fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) proxy := httptest.NewServer(lb) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, re.StatusCode) } func TestRemoveBadServer(t *testing.T) { lb, err := New(nil) require.NoError(t, err) assert.Error(t, lb.RemoveServer(testutils.ParseURI("http://google.com"))) } func TestCustomErrHandler(t *testing.T) { errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, req *http.Request, err error) { w.WriteHeader(http.StatusTeapot) w.Write([]byte(http.StatusText(http.StatusTeapot))) }) fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd, ErrorHandler(errHandler)) require.NoError(t, err) proxy := httptest.NewServer(lb) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusTeapot, re.StatusCode) } func TestOneServer(t *testing.T) { a := testutils.NewResponder("a") defer a.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) require.NoError(t, lb.UpsertServer(testutils.ParseURI(a.URL))) proxy := httptest.NewServer(lb) defer proxy.Close() assert.Equal(t, []string{"a", "a", "a"}, seq(t, proxy.URL, 3)) } func TestSimple(t *testing.T) { a := testutils.NewResponder("a") defer a.Close() b := testutils.NewResponder("b") defer b.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) require.NoError(t, lb.UpsertServer(testutils.ParseURI(a.URL))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(b.URL))) proxy := httptest.NewServer(lb) defer proxy.Close() assert.Equal(t, []string{"a", "b", "a"}, seq(t, proxy.URL, 3)) } func TestRemoveServer(t *testing.T) { a := testutils.NewResponder("a") defer a.Close() b := testutils.NewResponder("b") defer b.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) require.NoError(t, lb.UpsertServer(testutils.ParseURI(a.URL))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(b.URL))) proxy := httptest.NewServer(lb) defer proxy.Close() assert.Equal(t, []string{"a", "b", "a"},
seq(t, proxy.URL, 3)) err = lb.RemoveServer(testutils.ParseURI(a.URL)) require.NoError(t, err) assert.Equal(t, []string{"b", "b", "b"}, seq(t, proxy.URL, 3)) } func TestUpsertSame(t *testing.T) { a := testutils.NewResponder("a") defer a.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) require.NoError(t, lb.UpsertServer(testutils.ParseURI(a.URL))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(a.URL))) proxy := httptest.NewServer(lb) defer proxy.Close() assert.Equal(t, []string{"a", "a", "a"}, seq(t, proxy.URL, 3)) } func TestUpsertWeight(t *testing.T) { a := testutils.NewResponder("a") defer a.Close() b := testutils.NewResponder("b") defer b.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) require.NoError(t, lb.UpsertServer(testutils.ParseURI(a.URL))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(b.URL))) proxy := httptest.NewServer(lb) defer proxy.Close() assert.Equal(t, []string{"a", "b", "a"}, seq(t, proxy.URL, 3)) assert.NoError(t, lb.UpsertServer(testutils.ParseURI(b.URL), Weight(3))) assert.Equal(t, []string{"b", "b", "a", "b"}, seq(t, proxy.URL, 4)) } func TestWeighted(t *testing.T) { require.NoError(t, SetDefaultWeight(0)) defer SetDefaultWeight(1) a := testutils.NewResponder("a") defer a.Close() b := testutils.NewResponder("b") defer b.Close() z := testutils.NewResponder("z") defer z.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd) require.NoError(t, err) require.NoError(t, lb.UpsertServer(testutils.ParseURI(a.URL), Weight(3))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(b.URL), Weight(2))) require.NoError(t, lb.UpsertServer(testutils.ParseURI(z.URL), Weight(0))) proxy := httptest.NewServer(lb) defer proxy.Close() assert.Equal(t, []string{"a", "a", "b", "a", "b", "a"}, seq(t, proxy.URL, 6)) w, ok := lb.ServerWeight(testutils.ParseURI(a.URL)) assert.Equal(t, 3, w) assert.Equal(t, true, ok) w, ok = lb.ServerWeight(testutils.ParseURI(b.URL)) assert.Equal(t, 2, w) assert.Equal(t, true, ok) w, ok = lb.ServerWeight(testutils.ParseURI(z.URL)) assert.Equal(t, 0, w) assert.Equal(t, true, ok) w, ok = lb.ServerWeight(testutils.ParseURI("http://caramba:4000")) assert.Equal(t, -1, w) assert.Equal(t, false, ok) } func TestRequestRewriteListener(t *testing.T) { a := testutils.NewResponder("a") defer a.Close() b := testutils.NewResponder("b") defer b.Close() fwd, err := forward.New() require.NoError(t, err) lb, err := New(fwd, RoundRobinRequestRewriteListener(func(oldReq *http.Request, newReq *http.Request) {})) require.NoError(t, err) assert.NotNil(t, lb.requestRewriteListener) } func seq(t *testing.T, url string, repeat int) []string { var out []string for i := 0; i < repeat; i++ { _, body, err := testutils.Get(url) require.NoError(t, err) out = append(out, string(body)) } return out } oxy-1.3.0/roundrobin/stickycookie/000077500000000000000000000000001404246664300172245ustar00rootroot00000000000000oxy-1.3.0/roundrobin/stickycookie/aes_value.go000066400000000000000000000065621404246664300215300ustar00rootroot00000000000000package stickycookie import ( "crypto/aes" "crypto/cipher" "crypto/rand" "encoding/base64" "encoding/binary" "errors" "fmt" "io" "net/url" "strconv" "strings" "time" ) // AESValue manages hashed sticky value. type AESValue struct { block cipher.AEAD ttl time.Duration } // NewAESValue takes a fixed-size key and returns an CookieValue or an error. 
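// The value is sealed with AES-GCM, so the cookie is both encrypted and authenticated,
// and a non-zero ttl embeds an expiration timestamp that FindURL validates before matching.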
// Key size must be exactly one of 16, 24, or 32 bytes to select AES-128, AES-192, or AES-256. func NewAESValue(key []byte, ttl time.Duration) (*AESValue, error) { block, err := aes.NewCipher(key) if err != nil { return nil, err } gcm, err := cipher.NewGCM(block) if err != nil { return nil, err } return &AESValue{block: gcm, ttl: ttl}, nil } // Get hashes the sticky value. func (v *AESValue) Get(raw *url.URL) string { base := raw.String() if v.ttl > 0 { base = fmt.Sprintf("%s|%d", base, time.Now().UTC().Add(v.ttl).Unix()) } // Nonce is the 64bit nanosecond-resolution time, plus 32bits of crypto/rand, for 96bits (12Bytes). // Theoretically, if 2^32 calls were made in 1 nanoseconds, there might be a repeat. // Adds ~765ns, and 4B heap in 1 alloc nonce := make([]byte, 12) binary.PutVarint(nonce, time.Now().UnixNano()) rpend := make([]byte, 4) if _, err := io.ReadFull(rand.Reader, rpend); err != nil { // This is a near-impossible error condition on Linux systems. // An error here means rand.Reader (and thus getrandom(2), and thus /dev/urandom) returned // less than 4 bytes of data. /dev/urandom is guaranteed to always return the number of // bytes requested up to 512 bytes on modern kernels. Behaviour on non-Linux systems // varies, of course. panic(err) } for i := 0; i < 4; i++ { nonce[i+8] = rpend[i] } obfuscated := v.block.Seal(nil, nonce, []byte(base), nil) // We append the 12byte nonce onto the end of the message obfuscated = append(obfuscated, nonce...) obfuscatedStr := base64.RawURLEncoding.EncodeToString(obfuscated) return obfuscatedStr } // FindURL gets url from array that match the value. func (v *AESValue) FindURL(raw string, urls []*url.URL) (*url.URL, error) { rawURL, err := v.fromValue(raw) if err != nil { return nil, err } for _, u := range urls { ok, err := areURLEqual(rawURL, u) if err != nil { return nil, err } if ok { return u, nil } } return nil, nil } func (v *AESValue) fromValue(obfuscatedStr string) (string, error) { obfuscated, err := base64.RawURLEncoding.DecodeString(obfuscatedStr) if err != nil { return "", err } // The first len-12 bytes is the ciphertext, the last 12 bytes is the nonce n := len(obfuscated) - 12 if n <= 0 { // Protect against range errors causing panics return "", errors.New("post-base64-decoded string is too short") } nonce := obfuscated[n:] obfuscated = obfuscated[:n] raw, err := v.block.Open(nil, nonce, []byte(obfuscated), nil) if err != nil { return "", err } if v.ttl > 0 { rawParts := strings.Split(string(raw), "|") if len(rawParts) < 2 { return "", fmt.Errorf("TTL set but cookie doesn't contain an expiration: '%s'", raw) } // validate the ttl i, err := strconv.ParseInt(rawParts[1], 10, 64) if err != nil { return "", err } if time.Now().UTC().After(time.Unix(i, 0).UTC()) { strTime := time.Unix(i, 0).UTC().String() return "", fmt.Errorf("TTL expired: '%s' (%s)\n", raw, strTime) } raw = []byte(rawParts[0]) } return string(raw), nil } oxy-1.3.0/roundrobin/stickycookie/cookie_value.go000066400000000000000000000013471404246664300222250ustar00rootroot00000000000000package stickycookie import "net/url" // CookieValue interface to manage the sticky cookie value format. // It will be used by the load balancer to generate the sticky cookie value and to retrieve the matching url. type CookieValue interface { // Get converts raw value to an expected sticky format. Get(*url.URL) string // FindURL gets url from array that match the value. 
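	// Implementations return a nil *url.URL together with a nil error when no
	// server matches, so callers need to check for nil as well as for the error.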
FindURL(string, []*url.URL) (*url.URL, error) } // areURLEqual compare a string to a url and check if the string is the same as the url value. func areURLEqual(normalized string, u *url.URL) (bool, error) { u1, err := url.Parse(normalized) if err != nil { return false, err } return u1.Scheme == u.Scheme && u1.Host == u.Host && u1.Path == u.Path, nil } oxy-1.3.0/roundrobin/stickycookie/fallback_value.go000066400000000000000000000016061404246664300225110ustar00rootroot00000000000000package stickycookie import ( "errors" "net/url" ) // FallbackValue manages hashed sticky value. type FallbackValue struct { from CookieValue to CookieValue } // NewFallbackValue creates a new FallbackValue func NewFallbackValue(from CookieValue, to CookieValue) (*FallbackValue, error) { if from == nil || to == nil { return nil, errors.New("from and to are mandatory") } return &FallbackValue{from: from, to: to}, nil } // Get hashes the sticky value. func (v *FallbackValue) Get(raw *url.URL) string { return v.to.Get(raw) } // FindURL gets url from array that match the value. // If it is a symmetric algorithm, it decodes the URL, otherwise it compares the ciphered values. func (v *FallbackValue) FindURL(raw string, urls []*url.URL) (*url.URL, error) { findURL, err := v.from.FindURL(raw, urls) if findURL != nil { return findURL, err } return v.to.FindURL(raw, urls) } oxy-1.3.0/roundrobin/stickycookie/fallback_value_test.go000066400000000000000000000133001404246664300235420ustar00rootroot00000000000000package stickycookie import ( "fmt" "net/url" "path" "testing" "time" "github.com/stretchr/testify/require" "github.com/stretchr/testify/assert" ) func TestFallbackValue_FindURL(t *testing.T) { servers := []*url.URL{ {Scheme: "https", Host: "10.10.10.42", Path: "/"}, {Scheme: "http", Host: "10.10.10.10", Path: "/foo"}, {Scheme: "http", Host: "10.10.10.11", Path: "/", User: url.User("John Doe")}, {Scheme: "http", Host: "10.10.10.10", Path: "/"}, } aesValue, err := NewAESValue([]byte("95Bx9JkKX3xbd7z3"), 5*time.Second) require.NoError(t, err) values := []struct { Name string CookieValue CookieValue }{ {Name: "rawValue", CookieValue: &RawValue{}}, {Name: "hashValue", CookieValue: &HashValue{Salt: "foo"}}, {Name: "aesValue", CookieValue: aesValue}, } for _, from := range values { from := from for _, to := range values { to := to t.Run(fmt.Sprintf("From: %s, To %s", from.Name, to.Name), func(t *testing.T) { t.Parallel() value, err := NewFallbackValue(from.CookieValue, to.CookieValue) if from.CookieValue == nil && to.CookieValue == nil { assert.Error(t, err) return } require.NoError(t, err) if from.CookieValue != nil { // URL found From value findURL, err := value.FindURL(from.CookieValue.Get(servers[0]), servers) require.NoError(t, err) assert.Equal(t, servers[0], findURL) // URL not found From value findURL, _ = value.FindURL(from.CookieValue.Get(mustJoin(t, servers[0], "bar")), servers) assert.Nil(t, findURL) } if to.CookieValue != nil { // URL found To Value findURL, err := value.FindURL(to.CookieValue.Get(servers[0]), servers) require.NoError(t, err) assert.Equal(t, servers[0], findURL) // URL not found To value findURL, _ = value.FindURL(to.CookieValue.Get(mustJoin(t, servers[0], "bar")), servers) assert.Nil(t, findURL) } }) } } } func TestFallbackValue_FindURL_error(t *testing.T) { servers := []*url.URL{ {Scheme: "http", Host: "10.10.10.10", Path: "/"}, {Scheme: "https", Host: "10.10.10.42", Path: "/"}, {Scheme: "http", Host: "10.10.10.10", Path: "/foo"}, {Scheme: "http", Host: "10.10.10.11", Path: "/", User: 
url.User("John Doe")}, } hashValue := &HashValue{Salt: "foo"} rawValue := &RawValue{} aesValue, err := NewAESValue([]byte("95Bx9JkKX3xbd7z3"), 5*time.Second) require.NoError(t, err) tests := []struct { name string From CookieValue To CookieValue rawValue string want *url.URL expectError bool expectErrorOnNew bool }{ { name: "From RawValue To HashValue with RawValue value", From: rawValue, To: hashValue, rawValue: "http://10.10.10.10/", want: servers[0], }, { name: "From RawValue To HashValue with RawValue non matching value", From: rawValue, To: hashValue, rawValue: "http://24.10.10.10/", }, { name: "From RawValue To HashValue with HashValue value", From: rawValue, To: hashValue, rawValue: hashValue.Get(mustParse(t, "http://10.10.10.10/")), want: servers[0], }, { name: "From RawValue To HashValue with HashValue non matching value", From: rawValue, To: hashValue, rawValue: hashValue.Get(mustParse(t, "http://24.10.10.10/")), }, { name: "From HashValue To AESValue with AESValue value", From: hashValue, To: aesValue, rawValue: aesValue.Get(mustParse(t, "http://10.10.10.10/")), want: servers[0], }, { name: "From HashValue To AESValue with AESValue non matching value", From: hashValue, To: aesValue, rawValue: aesValue.Get(mustParse(t, "http://24.10.10.10/")), }, { name: "From HashValue To AESValue with HashValue value", From: hashValue, To: aesValue, rawValue: hashValue.Get(mustParse(t, "http://10.10.10.10/")), want: servers[0], }, { name: "From HashValue To AESValue with AESValue non matching value", From: hashValue, To: aesValue, rawValue: aesValue.Get(mustParse(t, "http://24.10.10.10/")), }, { name: "From AESValue To AESValue with AESValue value", From: aesValue, To: aesValue, rawValue: aesValue.Get(mustParse(t, "http://10.10.10.10/")), want: servers[0], }, { name: "From AESValue To AESValue with AESValue non matching value", From: aesValue, To: aesValue, rawValue: aesValue.Get(mustParse(t, "http://24.10.10.10/")), }, { name: "From AESValue To HashValue with HashValue non matching value", From: aesValue, To: hashValue, rawValue: hashValue.Get(mustParse(t, "http://24.10.10.10/")), }, { name: "From nil To RawValue", To: hashValue, rawValue: "http://24.10.10.10/", expectErrorOnNew: true, }, { name: "From RawValue To nil", From: rawValue, rawValue: "http://24.10.10.10/", expectErrorOnNew: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { value, err := NewFallbackValue(tt.From, tt.To) if tt.expectErrorOnNew { assert.Error(t, err) return } require.NoError(t, err) findURL, err := value.FindURL(tt.rawValue, servers) if tt.expectError { assert.Error(t, err) return } require.NoError(t, err) assert.Equal(t, tt.want, findURL) }) } } func mustJoin(t *testing.T, u *url.URL, part string) *url.URL { t.Helper() nu, err := u.Parse(path.Join(u.Path, part)) require.NoError(t, err) return nu } func mustParse(t *testing.T, raw string) *url.URL { t.Helper() u, err := url.Parse(raw) require.NoError(t, err) return u } oxy-1.3.0/roundrobin/stickycookie/hash_value.go000066400000000000000000000014601404246664300216730ustar00rootroot00000000000000package stickycookie import ( "fmt" "net/url" "github.com/segmentio/fasthash/fnv1a" ) // HashValue manages hashed sticky value. type HashValue struct { // Salt secret to anonymize the hashed cookie Salt string } // Get hashes the sticky value. func (v *HashValue) Get(raw *url.URL) string { return v.hash(raw.String()) } // FindURL gets url from array that match the value. 
func (v *HashValue) FindURL(raw string, urls []*url.URL) (*url.URL, error) { for _, u := range urls { if raw == v.hash(normalized(u)) { return u, nil } } return nil, nil } func (v *HashValue) hash(input string) string { return fmt.Sprintf("%x", fnv1a.HashString64(v.Salt+input)) } func normalized(u *url.URL) string { normalized := url.URL{Scheme: u.Scheme, Host: u.Host, Path: u.Path} return normalized.String() } oxy-1.3.0/roundrobin/stickycookie/raw_value.go000066400000000000000000000010001404246664300215270ustar00rootroot00000000000000package stickycookie import ( "net/url" ) // RawValue is a no-op that returns the raw strings as-is. type RawValue struct{} // Get returns the raw value. func (v *RawValue) Get(raw *url.URL) string { return raw.String() } // FindURL gets url from array that match the value. func (v *RawValue) FindURL(raw string, urls []*url.URL) (*url.URL, error) { for _, u := range urls { ok, err := areURLEqual(raw, u) if err != nil { return nil, err } if ok { return u, nil } } return nil, nil } oxy-1.3.0/roundrobin/stickysessions.go000066400000000000000000000043201404246664300201470ustar00rootroot00000000000000package roundrobin import ( "net/http" "net/url" "time" "github.com/vulcand/oxy/roundrobin/stickycookie" ) // CookieOptions has all the options one would like to set on the affinity cookie type CookieOptions struct { HTTPOnly bool Secure bool Path string Domain string Expires time.Time MaxAge int SameSite http.SameSite } // StickySession is a mixin for load balancers that implements layer 7 (http cookie) session affinity type StickySession struct { cookieName string cookieValue stickycookie.CookieValue options CookieOptions } // NewStickySession creates a new StickySession func NewStickySession(cookieName string) *StickySession { return &StickySession{cookieName: cookieName, cookieValue: &stickycookie.RawValue{}} } // NewStickySessionWithOptions creates a new StickySession whilst allowing for options to // shape its affinity cookie such as "httpOnly" or "secure" func NewStickySessionWithOptions(cookieName string, options CookieOptions) *StickySession { return &StickySession{cookieName: cookieName, options: options, cookieValue: &stickycookie.RawValue{}} } // SetCookieValue set the CookieValue for the StickySession. func (s *StickySession) SetCookieValue(value stickycookie.CookieValue) *StickySession { s.cookieValue = value return s } // GetBackend returns the backend URL stored in the sticky cookie, iff the backend is still in the valid list of servers. 
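// A missing affinity cookie is not an error: in that case it returns (nil, false, nil)
// so the load balancer can fall back to its normal server selection.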
func (s *StickySession) GetBackend(req *http.Request, servers []*url.URL) (*url.URL, bool, error) { cookie, err := req.Cookie(s.cookieName) switch err { case nil: case http.ErrNoCookie: return nil, false, nil default: return nil, false, err } server, err := s.cookieValue.FindURL(cookie.Value, servers) return server, server != nil, err } // StickBackend creates and sets the cookie func (s *StickySession) StickBackend(backend *url.URL, w *http.ResponseWriter) { opt := s.options cp := "/" if opt.Path != "" { cp = opt.Path } cookie := &http.Cookie{ Name: s.cookieName, Value: s.cookieValue.Get(backend), Path: cp, Domain: opt.Domain, Expires: opt.Expires, MaxAge: opt.MaxAge, Secure: opt.Secure, HttpOnly: opt.HTTPOnly, SameSite: opt.SameSite, } http.SetCookie(*w, cookie) } oxy-1.3.0/roundrobin/stickysessions_test.go000066400000000000000000000414331404246664300212140ustar00rootroot00000000000000package roundrobin import ( "fmt" "io/ioutil" "net/http" "net/http/httptest" "net/url" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/forward" "github.com/vulcand/oxy/roundrobin/stickycookie" "github.com/vulcand/oxy/testutils" ) func TestBasic(t *testing.T) { a := testutils.NewResponder("a") b := testutils.NewResponder("b") defer a.Close() defer b.Close() fwd, err := forward.New() require.NoError(t, err) sticky := NewStickySession("test") require.NotNil(t, sticky) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) defer proxy.Close() client := http.DefaultClient for i := 0; i < 10; i++ { req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) resp, err := client.Do(req) require.NoError(t, err) defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, "a", string(body)) } } func TestBasicWithHashValue(t *testing.T) { a := testutils.NewResponder("a") b := testutils.NewResponder("b") defer a.Close() defer b.Close() fwd, err := forward.New() require.NoError(t, err) sticky := NewStickySession("test") require.NotNil(t, sticky) sticky.SetCookieValue(&stickycookie.HashValue{Salt: "foo"}) require.NotNil(t, sticky.cookieValue) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) defer proxy.Close() client := http.DefaultClient var cookie *http.Cookie for i := 0; i < 10; i++ { req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) if cookie != nil { req.AddCookie(cookie) } resp, err := client.Do(req) require.NoError(t, err) body, err := ioutil.ReadAll(resp.Body) defer resp.Body.Close() require.NoError(t, err) assert.Equal(t, "a", string(body)) if cookie == nil { // The first request will set the cookie value cookie = resp.Cookies()[0] } assert.Equal(t, "test", cookie.Name) assert.Equal(t, sticky.cookieValue.Get(mustParse(t, a.URL)), cookie.Value) } } func TestBasicWithAESValue(t *testing.T) { a := testutils.NewResponder("a") b := testutils.NewResponder("b") defer a.Close() defer b.Close() fwd, err := forward.New() require.NoError(t, err) sticky := NewStickySession("test") require.NotNil(t, sticky) 
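	// The 16-byte key below selects AES-128 and the 5s TTL keeps the sealed cookie
	// valid for the handful of requests this test sends.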
aesValue, err := stickycookie.NewAESValue([]byte("95Bx9JkKX3xbd7z3"), 5*time.Second) require.NoError(t, err) sticky.SetCookieValue(aesValue) require.NotNil(t, sticky.cookieValue) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) defer proxy.Close() client := http.DefaultClient var cookie *http.Cookie for i := 0; i < 10; i++ { req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) if cookie != nil { req.AddCookie(cookie) } resp, err := client.Do(req) require.NoError(t, err) body, err := ioutil.ReadAll(resp.Body) defer resp.Body.Close() require.NoError(t, err) assert.Equal(t, "a", string(body)) if cookie == nil { // The first request will set the cookie value cookie = resp.Cookies()[0] } assert.Equal(t, "test", cookie.Name) } } func TestStickyCookie(t *testing.T) { a := testutils.NewResponder("a") b := testutils.NewResponder("b") defer a.Close() defer b.Close() fwd, err := forward.New() require.NoError(t, err) sticky := NewStickySession("test") require.NotNil(t, sticky) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) defer proxy.Close() resp, err := http.Get(proxy.URL) require.NoError(t, err) cookie := resp.Cookies()[0] assert.Equal(t, "test", cookie.Name) assert.Equal(t, a.URL, cookie.Value) } func TestStickyCookieWithOptions(t *testing.T) { a := testutils.NewResponder("a") b := testutils.NewResponder("b") defer a.Close() defer b.Close() testCases := []struct { desc string name string options CookieOptions expected *http.Cookie }{ { desc: "no options", name: "test", options: CookieOptions{}, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/", Raw: fmt.Sprintf("test=%s; Path=/", a.URL), }, }, { desc: "HTTPOnly", name: "test", options: CookieOptions{ HTTPOnly: true, }, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/", HttpOnly: true, Raw: fmt.Sprintf("test=%s; Path=/; HttpOnly", a.URL), Unparsed: nil, }, }, { desc: "Secure", name: "test", options: CookieOptions{ Secure: true, }, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/", Secure: true, Raw: fmt.Sprintf("test=%s; Path=/; Secure", a.URL), }, }, { desc: "Path", name: "test", options: CookieOptions{ Path: "/foo", }, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/foo", Raw: fmt.Sprintf("test=%s; Path=/foo", a.URL), }, }, { desc: "Domain", name: "test", options: CookieOptions{ Domain: "example.org", }, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/", Domain: "example.org", Raw: fmt.Sprintf("test=%s; Path=/; Domain=example.org", a.URL), }, }, { desc: "Expires", name: "test", options: CookieOptions{ Expires: time.Date(1955, 11, 12, 1, 22, 0, 0, time.UTC), }, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/", Expires: time.Date(1955, 11, 12, 1, 22, 0, 0, time.UTC), RawExpires: "Sat, 12 Nov 1955 01:22:00 GMT", Raw: fmt.Sprintf("test=%s; Path=/; Expires=Sat, 12 Nov 1955 01:22:00 GMT", a.URL), }, }, { desc: "MaxAge", name: "test", options: CookieOptions{ MaxAge: -20, }, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/", MaxAge: -1, Raw: fmt.Sprintf("test=%s; Path=/; Max-Age=0", a.URL), }, }, { desc: "SameSite", name: "test", options: 
CookieOptions{ SameSite: http.SameSiteNoneMode, }, expected: &http.Cookie{ Name: "test", Value: a.URL, Path: "/", SameSite: http.SameSiteNoneMode, Raw: fmt.Sprintf("test=%s; Path=/; SameSite=None", a.URL), }, }, } for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { fwd, err := forward.New() require.NoError(t, err) sticky := NewStickySessionWithOptions(test.name, test.options) require.NotNil(t, sticky) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) defer proxy.Close() resp, err := http.Get(proxy.URL) require.NoError(t, err) require.Len(t, resp.Cookies(), 1) assert.Equal(t, test.expected, resp.Cookies()[0]) }) } } func TestRemoveRespondingServer(t *testing.T) { a := testutils.NewResponder("a") b := testutils.NewResponder("b") defer a.Close() defer b.Close() fwd, err := forward.New() require.NoError(t, err) sticky := NewStickySession("test") require.NotNil(t, sticky) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) defer proxy.Close() client := http.DefaultClient for i := 0; i < 10; i++ { req, errReq := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, errReq) req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) resp, errReq := client.Do(req) require.NoError(t, errReq) defer resp.Body.Close() body, errReq := ioutil.ReadAll(resp.Body) require.NoError(t, errReq) assert.Equal(t, "a", string(body)) } err = lb.RemoveServer(testutils.ParseURI(a.URL)) require.NoError(t, err) // Now, use the organic cookie response in our next requests. 
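	// The stale cookie still points at server a, which has just been removed, so the
	// balancer is expected to fall back to server b and re-stick the session there.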
req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) resp, err := client.Do(req) require.NoError(t, err) assert.Equal(t, "test", resp.Cookies()[0].Name) assert.Equal(t, b.URL, resp.Cookies()[0].Value) for i := 0; i < 10; i++ { req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, "b", string(body)) } } func TestRemoveAllServers(t *testing.T) { a := testutils.NewResponder("a") b := testutils.NewResponder("b") defer a.Close() defer b.Close() fwd, err := forward.New() require.NoError(t, err) sticky := NewStickySession("test") require.NotNil(t, sticky) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(b.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) defer proxy.Close() client := http.DefaultClient for i := 0; i < 10; i++ { req, errReq := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, errReq) req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) resp, errReq := client.Do(req) require.NoError(t, errReq) defer resp.Body.Close() body, errReq := ioutil.ReadAll(resp.Body) require.NoError(t, errReq) assert.Equal(t, "a", string(body)) } err = lb.RemoveServer(testutils.ParseURI(a.URL)) require.NoError(t, err) err = lb.RemoveServer(testutils.ParseURI(b.URL)) require.NoError(t, err) // Now, use the organic cookie response in our next requests. req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) req.AddCookie(&http.Cookie{Name: "test", Value: a.URL}) resp, err := client.Do(req) require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) } func TestBadCookieVal(t *testing.T) { a := testutils.NewResponder("a") defer a.Close() fwd, err := forward.New() require.NoError(t, err) sticky := NewStickySession("test") require.NotNil(t, sticky) lb, err := New(fwd, EnableStickySession(sticky)) require.NoError(t, err) err = lb.UpsertServer(testutils.ParseURI(a.URL)) require.NoError(t, err) proxy := httptest.NewServer(lb) defer proxy.Close() client := http.DefaultClient req, err := http.NewRequest(http.MethodGet, proxy.URL, nil) require.NoError(t, err) req.AddCookie(&http.Cookie{Name: "test", Value: "This is a patently invalid url! You can't parse it! 
:-)"}) resp, err := client.Do(req) require.NoError(t, err) body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, "a", string(body)) // Now, cycle off the good server to cause an error err = lb.RemoveServer(testutils.ParseURI(a.URL)) require.NoError(t, err) resp, err = client.Do(req) require.NoError(t, err) _, err = ioutil.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) } func TestStickySession_GetBackend(t *testing.T) { cookieName := "Test-Cookie" servers := []*url.URL{ {Scheme: "http", Host: "10.10.10.10", Path: "/"}, {Scheme: "https", Host: "10.10.10.42", Path: "/"}, {Scheme: "http", Host: "10.10.10.10", Path: "/foo"}, {Scheme: "http", Host: "10.10.10.11", Path: "/", User: url.User("John Doe")}, } rawValue := &stickycookie.RawValue{} hashValue := &stickycookie.HashValue{} saltyHashValue := &stickycookie.HashValue{Salt: "test salt"} aesValue, err := stickycookie.NewAESValue([]byte("95Bx9JkKX3xbd7z3"), 5*time.Second) aesValueInfinite, err := stickycookie.NewAESValue([]byte("95Bx9JkKX3xbd7z3"), 0) require.NoError(t, err) aesValueExpired, err := stickycookie.NewAESValue([]byte("95Bx9JkKX3xbd7z3"), 1*time.Nanosecond) require.NoError(t, err) tests := []struct { name string CookieValue stickycookie.CookieValue cookie *http.Cookie want *url.URL expectError bool }{ { name: "NoCookies", }, { name: "Cookie no matched", cookie: &http.Cookie{Name: "not" + cookieName, Value: "http://10.10.10.10/"}, }, { name: "Cookie not a URL", cookie: &http.Cookie{Name: cookieName, Value: "foo://foo bar"}, expectError: true, }, { name: "Simple", cookie: &http.Cookie{Name: cookieName, Value: "http://10.10.10.10/"}, want: servers[0], }, { name: "Host no match for needle", cookie: &http.Cookie{Name: cookieName, Value: "http://10.10.10.255/"}, }, { name: "Scheme no match for needle", cookie: &http.Cookie{Name: cookieName, Value: "https://10.10.10.10/"}, }, { name: "Path no match for needle", cookie: &http.Cookie{Name: cookieName, Value: "http://10.10.10.10/foo/bar"}, }, { name: "With user in haystack but not in needle", cookie: &http.Cookie{Name: cookieName, Value: "http://10.10.10.11/"}, want: servers[3], }, { name: "With user in haystack and in needle", cookie: &http.Cookie{Name: cookieName, Value: "http://John%20Doe@10.10.10.11/"}, want: servers[3], }, { name: "Cookie no matched with RawValue", CookieValue: rawValue, cookie: &http.Cookie{Name: "not" + cookieName, Value: rawValue.Get(mustParse(t, "http://10.10.10.10/"))}, }, { name: "Cookie no matched with HashValue", CookieValue: hashValue, cookie: &http.Cookie{Name: "not" + cookieName, Value: hashValue.Get(mustParse(t, "http://10.10.10.10/"))}, }, { name: "Cookie value not matched with HashValue", CookieValue: hashValue, cookie: &http.Cookie{Name: cookieName, Value: hashValue.Get(mustParse(t, "http://10.10.10.255/"))}, }, { name: "simple with HashValue", CookieValue: hashValue, cookie: &http.Cookie{Name: cookieName, Value: hashValue.Get(mustParse(t, "http://10.10.10.10/"))}, want: servers[0], }, { name: "simple with HashValue and salt", CookieValue: saltyHashValue, cookie: &http.Cookie{Name: cookieName, Value: saltyHashValue.Get(mustParse(t, "http://10.10.10.10/"))}, want: servers[0], }, { name: "Cookie value not matched with AESValue", CookieValue: aesValue, cookie: &http.Cookie{Name: cookieName, Value: aesValue.Get(mustParse(t, "http://10.10.10.255/"))}, }, { name: "simple with AESValue", CookieValue: aesValue, cookie: &http.Cookie{Name: cookieName, Value: 
aesValue.Get(mustParse(t, "http://10.10.10.10/"))}, want: servers[0], }, { name: "Cookie value not matched with AESValue with ttl 0s", CookieValue: aesValueInfinite, cookie: &http.Cookie{Name: cookieName, Value: aesValueInfinite.Get(mustParse(t, "http://10.10.10.255/"))}, }, { name: "simple with AESValue with ttl 0s", CookieValue: aesValueInfinite, cookie: &http.Cookie{Name: cookieName, Value: aesValueInfinite.Get(mustParse(t, "http://10.10.10.10/"))}, want: servers[0], }, { name: "simple with AESValue with ttl 0s", CookieValue: aesValueInfinite, cookie: &http.Cookie{Name: cookieName, Value: aesValueInfinite.Get(mustParse(t, "http://10.10.10.10/"))}, want: servers[0], }, { name: "simple with AESValue with expired ttl", CookieValue: aesValueExpired, cookie: &http.Cookie{Name: cookieName, Value: aesValueExpired.Get(mustParse(t, "http://10.10.10.10/"))}, expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := NewStickySession(cookieName) if tt.CookieValue != nil { s.SetCookieValue(tt.CookieValue) } req := httptest.NewRequest(http.MethodGet, "http://foo", nil) if tt.cookie != nil { req.AddCookie(tt.cookie) } got, _, err := s.GetBackend(req, servers) if tt.expectError { require.Error(t, err) return } require.NoError(t, err) assert.Equal(t, tt.want, got) }) } } func mustParse(t *testing.T, raw string) *url.URL { t.Helper() u, err := url.Parse(raw) require.NoError(t, err) return u } oxy-1.3.0/stream/000077500000000000000000000000001404246664300136365ustar00rootroot00000000000000oxy-1.3.0/stream/stream.go000066400000000000000000000051361404246664300154650ustar00rootroot00000000000000/* Package stream provides http.Handler middleware that passes-through the entire request Stream works around several limitations caused by buffering implementations, but also introduces certain risks. Workarounds for buffering limitations: 1. Streaming really large chunks of data (large file transfers, or streaming videos, etc.) 2. Streaming (chunking) sparse data. For example, an implementation might send a health check or a heart beat over a long-lived connection. This does not play well with buffering. Risks: 1. Connections could survive for very long periods of time. 2. There is no easy way to enforce limits on size/time of a connection. Examples of a streaming middleware: // sample HTTP handler handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) // Stream will literally pass through to the next handler without ANY buffering // or validation of the data. stream.New(handler) */ package stream import ( "net/http" log "github.com/sirupsen/logrus" "github.com/vulcand/oxy/utils" ) const ( // DefaultMaxBodyBytes No limit by default DefaultMaxBodyBytes = -1 ) // Stream is responsible for buffering requests and responses // It buffers large requests and responses to disk, type Stream struct { maxRequestBodyBytes int64 maxResponseBodyBytes int64 retryPredicate hpredicate next http.Handler errHandler utils.ErrorHandler log *log.Logger } // New returns a new streamer middleware. New() function supports optional functional arguments func New(next http.Handler, setters ...optSetter) (*Stream, error) { strm := &Stream{ next: next, maxRequestBodyBytes: DefaultMaxBodyBytes, maxResponseBodyBytes: DefaultMaxBodyBytes, log: log.StandardLogger(), } for _, s := range setters { if err := s(strm); err != nil { return nil, err } } return strm, nil } // Logger defines the logger the streamer will use. 
// // It defaults to logrus.StandardLogger(), the global logger used by logrus. func Logger(l *log.Logger) optSetter { return func(s *Stream) error { s.log = l return nil } } type optSetter func(s *Stream) error // Wrap sets the next handler to be called by stream handler. func (s *Stream) Wrap(next http.Handler) error { s.next = next return nil } func (s *Stream) ServeHTTP(w http.ResponseWriter, req *http.Request) { if s.log.Level >= log.DebugLevel { logEntry := s.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/stream: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/stream: completed ServeHttp on request") } s.next.ServeHTTP(w, req) } oxy-1.3.0/stream/stream_test.go000066400000000000000000000225371404246664300165300ustar00rootroot00000000000000package stream import ( "bufio" "crypto/tls" "fmt" "io/ioutil" "net" "net/http" "net/http/httptest" "testing" "time" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/forward" "github.com/vulcand/oxy/testutils" ) type noOpNextHTTPHandler struct{} func (n noOpNextHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {} type noOpIoWriter struct{} func (n noOpIoWriter) Write(bytes []byte) (int, error) { return len(bytes), nil } func TestSimple(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New(forward.Stream(true)) require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello", string(body)) } func TestChunkedEncodingSuccess(t *testing.T) { var reqBody string var contentLength int64 srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { body, err := ioutil.ReadAll(req.Body) require.NoError(t, err) reqBody = string(body) contentLength = req.ContentLength w.WriteHeader(200) flusher, ok := w.(http.Flusher) if !ok { panic("expected http.ResponseWriter to be an http.Flusher") } fmt.Fprint(w, "Response") flusher.Flush() time.Sleep(time.Duration(500) * time.Millisecond) fmt.Fprint(w, "in") flusher.Flush() time.Sleep(time.Duration(500) * time.Millisecond) fmt.Fprint(w, "Chunks") flusher.Flush() }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New(forward.Stream(true)) require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() conn, err := net.Dial("tcp", testutils.ParseURI(proxy.URL).Host) require.NoError(t, err) fmt.Fprint(conn, "POST / HTTP/1.1\r\nHost: 127.0.0.1\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n5\r\ntest1\r\n5\r\ntest2\r\n0\r\n\r\n") reader := bufio.NewReader(conn) status, err := reader.ReadString('\n') require.NoError(t, err) _, err = reader.ReadString('\n') // content type 
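// the discarded ReadString calls above and below skip the Content-Type and Date header lines; the Transfer-Encoding assertion further down relies on net/http emitting the headers in that order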
require.NoError(t, err) _, err = reader.ReadString('\n') // Date require.NoError(t, err) transferEncoding, err := reader.ReadString('\n') require.NoError(t, err) assert.Equal(t, "Transfer-Encoding: chunked\r\n", transferEncoding) assert.Equal(t, int64(-1), contentLength) assert.Equal(t, "testtest1test2", reqBody) assert.Equal(t, "HTTP/1.1 200 OK\r\n", status) } func TestRequestLimitReached(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New(forward.Stream(true)) require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, _, err := testutils.Get(proxy.URL, testutils.Body("this request is too long")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } func TestResponseLimitReached(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello, this response is too large")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New(forward.Stream(true)) require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } func TestFileStreamingResponse(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello, this response is too large to fit in memory")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New(forward.Stream(true)) require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, body, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.Equal(t, "hello, this response is too large to fit in memory", string(body)) } func TestCustomErrorHandler(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello, this response is too large")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New(forward.Stream(true)) require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } func TestNotModified(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req 
*http.Request) { w.WriteHeader(http.StatusNotModified) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New(forward.Stream(true)) require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusNotModified, re.StatusCode) } func TestNoBody(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusOK) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New(forward.Stream(true)) require.NoError(t, err) // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewServer(st) defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) } // Make sure that stream handler preserves TLS settings func TestPreservesTLS(t *testing.T) { srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("ok")) }) defer srv.Close() // forwarder will proxy the request to whatever destination fwd, err := forward.New(forward.Stream(true)) require.NoError(t, err) var cs *tls.ConnectionState // this is our redirect to server rdr := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { cs = req.TLS req.URL = testutils.ParseURI(srv.URL) fwd.ServeHTTP(w, req) }) // stream handler will forward requests to redirect st, err := New(rdr) require.NoError(t, err) proxy := httptest.NewUnstartedServer(st) proxy.StartTLS() defer proxy.Close() re, _, err := testutils.Get(proxy.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) assert.NotNil(t, cs) } func BenchmarkLoggingDebugLevel(b *testing.B) { streamer, _ := New(noOpNextHTTPHandler{}) log.SetLevel(log.DebugLevel) log.SetOutput(&noOpIoWriter{}) // Make sure we don't emit a bunch of stuff on screen for i := 0; i < b.N; i++ { heavyServeHTTPLoad(streamer) } } func BenchmarkLoggingInfoLevel(b *testing.B) { streamer, _ := New(noOpNextHTTPHandler{}) log.SetLevel(log.InfoLevel) log.SetOutput(&noOpIoWriter{}) // Make sure we don't emit a bunch of stuff on screen for i := 0; i < b.N; i++ { heavyServeHTTPLoad(streamer) } } func heavyServeHTTPLoad(handler http.Handler) { w := httptest.NewRecorder() r := &http.Request{} handler.ServeHTTP(w, r) } oxy-1.3.0/stream/threshold.go000066400000000000000000000122521404246664300161630ustar00rootroot00000000000000package stream import ( "fmt" "net/http" "github.com/vulcand/predicate" ) // IsValidExpression check if it's a valid expression func IsValidExpression(expr string) bool { _, err := parseExpression(expr) return err == nil } type context struct { r *http.Request attempt int responseCode int } type hpredicate func(*context) bool // Parses expression in the go language into Failover predicates func parseExpression(in string) (hpredicate, error) { p, err := predicate.NewParser(predicate.Def{ Operators: predicate.Operators{ AND: and, OR: or, 
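// the comparison operators below are resolved into typed predicates (string or int mappers) by eq, lt, gt and friends at parse time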
EQ: eq, NEQ: neq, LT: lt, GT: gt, LE: le, GE: ge, }, Functions: map[string]interface{}{ "RequestMethod": requestMethod, "IsNetworkError": isNetworkError, "Attempts": attempts, "ResponseCode": responseCode, }, }) if err != nil { return nil, err } out, err := p.Parse(in) if err != nil { return nil, err } pr, ok := out.(hpredicate) if !ok { return nil, fmt.Errorf("expected predicate, got %T", out) } return pr, nil } type toString func(c *context) string type toInt func(c *context) int // RequestMethod returns a mapper of the request to its method, e.g. POST func requestMethod() toString { return func(c *context) string { return c.r.Method } } // Attempts returns a mapper of the request to the number of proxy attempts func attempts() toInt { return func(c *context) int { return c.attempt } } // ResponseCode returns a mapper of the request to the last response code; it returns 0 if there was no response code. func responseCode() toInt { return func(c *context) int { return c.responseCode } } // IsNetworkError returns a predicate that returns true if the last attempt ended with a network error. func isNetworkError() hpredicate { return func(c *context) bool { return c.responseCode == http.StatusBadGateway || c.responseCode == http.StatusGatewayTimeout } } // and returns a predicate by joining the passed predicates with logical 'and' func and(fns ...hpredicate) hpredicate { return func(c *context) bool { for _, fn := range fns { if !fn(c) { return false } } return true } } // or returns a predicate by joining the passed predicates with logical 'or' func or(fns ...hpredicate) hpredicate { return func(c *context) bool { for _, fn := range fns { if fn(c) { return true } } return false } } // not creates the negation of the passed predicate func not(p hpredicate) hpredicate { return func(c *context) bool { return !p(c) } } // eq returns a predicate that tests for equality of the value of the mapper and the constant func eq(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toString: return stringEQ(mapper, value) case toInt: return intEQ(mapper, value) } return nil, fmt.Errorf("unsupported argument: %T", m) } // neq returns a predicate that tests for inequality of the value of the mapper and the constant func neq(m interface{}, value interface{}) (hpredicate, error) { p, err := eq(m, value) if err != nil { return nil, err } return not(p), nil } // lt returns a predicate that tests that the value of the mapper function is less than the constant func lt(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toInt: return intLT(mapper, value) } return nil, fmt.Errorf("unsupported argument: %T", m) } // le returns a predicate that tests that the value of the mapper function is less than or equal to the constant func le(m interface{}, value interface{}) (hpredicate, error) { l, err := lt(m, value) if err != nil { return nil, err } e, err := eq(m, value) if err != nil { return nil, err } return func(c *context) bool { return l(c) || e(c) }, nil } // gt returns a predicate that tests that the value of the mapper function is greater than the constant func gt(m interface{}, value interface{}) (hpredicate, error) { switch mapper := m.(type) { case toInt: return intGT(mapper, value) } return nil, fmt.Errorf("unsupported argument: %T", m) } // ge returns a predicate that tests that the value of the mapper function is greater than or equal to the constant func ge(m interface{}, value interface{}) (hpredicate, error) { g, err := gt(m, value) if err != nil { return nil, err } e, err := eq(m, value) if 
err != nil { return nil, err } return func(c *context) bool { return g(c) || e(c) }, nil } func stringEQ(m toString, val interface{}) (hpredicate, error) { value, ok := val.(string) if !ok { return nil, fmt.Errorf("expected string, got %T", val) } return func(c *context) bool { return m(c) == value }, nil } func intEQ(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *context) bool { return m(c) == value }, nil } func intLT(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *context) bool { return m(c) < value }, nil } func intGT(m toInt, val interface{}) (hpredicate, error) { value, ok := val.(int) if !ok { return nil, fmt.Errorf("expected int, got %T", val) } return func(c *context) bool { return m(c) > value }, nil } oxy-1.3.0/testutils/000077500000000000000000000000001404246664300144035ustar00rootroot00000000000000oxy-1.3.0/testutils/utils.go000066400000000000000000000073371404246664300161040ustar00rootroot00000000000000package testutils import ( "crypto/tls" "errors" "io/ioutil" "net/http" "net/http/httptest" "net/url" "strings" "time" "github.com/mailgun/timetools" "github.com/vulcand/oxy/utils" ) // NewHandler creates a new Server func NewHandler(handler http.HandlerFunc) *httptest.Server { return httptest.NewServer(handler) } // NewResponder creates a new Server with response func NewResponder(response string) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(response)) })) } // ParseURI is the version of url.ParseRequestURI that panics if incorrect, helpful to shorten the tests func ParseURI(uri string) *url.URL { out, err := url.ParseRequestURI(uri) if err != nil { panic(err) } return out } // ReqOpts request options type ReqOpts struct { Host string Method string Body string Headers http.Header Auth *utils.BasicAuth } // ReqOption request option type type ReqOption func(o *ReqOpts) error // Method sets request method func Method(m string) ReqOption { return func(o *ReqOpts) error { o.Method = m return nil } } // Host sets request host func Host(h string) ReqOption { return func(o *ReqOpts) error { o.Host = h return nil } } // Body sets request body func Body(b string) ReqOption { return func(o *ReqOpts) error { o.Body = b return nil } } // Header sets request header func Header(name, val string) ReqOption { return func(o *ReqOpts) error { if o.Headers == nil { o.Headers = make(http.Header) } o.Headers.Add(name, val) return nil } } // Headers sets request headers func Headers(h http.Header) ReqOption { return func(o *ReqOpts) error { if o.Headers == nil { o.Headers = make(http.Header) } utils.CopyHeaders(o.Headers, h) return nil } } // BasicAuth sets request basic auth func BasicAuth(username, password string) ReqOption { return func(o *ReqOpts) error { o.Auth = &utils.BasicAuth{ Username: username, Password: password, } return nil } } // MakeRequest create and do a request func MakeRequest(url string, opts ...ReqOption) (*http.Response, []byte, error) { o := &ReqOpts{} for _, s := range opts { if err := s(o); err != nil { return nil, nil, err } } if o.Method == "" { o.Method = http.MethodGet } request, err := http.NewRequest(o.Method, url, strings.NewReader(o.Body)) if err != nil { return nil, nil, err } if o.Headers != nil { utils.CopyHeaders(request.Header, o.Headers) } if o.Auth != nil { 
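// o.Auth.String() renders the standard "Basic <base64(username:password)>" Authorization header value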
request.Header.Set("Authorization", o.Auth.String()) } if len(o.Host) != 0 { request.Host = o.Host } var tr *http.Transport if strings.HasPrefix(url, "https") { tr = &http.Transport{ DisableKeepAlives: true, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, ServerName: request.Host, }, } } else { tr = &http.Transport{ DisableKeepAlives: true, } } client := &http.Client{ Transport: tr, CheckRedirect: func(req *http.Request, via []*http.Request) error { return errors.New("no redirects") }, } response, err := client.Do(request) if err == nil { bodyBytes, errRead := ioutil.ReadAll(response.Body) return response, bodyBytes, errRead } return response, nil, err } // Get do a GET request func Get(url string, opts ...ReqOption) (*http.Response, []byte, error) { opts = append(opts, Method(http.MethodGet)) return MakeRequest(url, opts...) } // Post do a POST request func Post(url string, opts ...ReqOption) (*http.Response, []byte, error) { opts = append(opts, Method(http.MethodPost)) return MakeRequest(url, opts...) } // GetClock gets a FreezedTime func GetClock() *timetools.FreezedTime { return &timetools.FreezedTime{ CurrentTime: time.Date(2012, 3, 4, 5, 6, 7, 0, time.UTC), } } oxy-1.3.0/trace/000077500000000000000000000000001404246664300134415ustar00rootroot00000000000000oxy-1.3.0/trace/trace.go000066400000000000000000000152201404246664300150660ustar00rootroot00000000000000// Package trace implement structured logging of requests package trace import ( "crypto/tls" "encoding/json" "fmt" "io" "net/http" "strconv" "time" log "github.com/sirupsen/logrus" "github.com/vulcand/oxy/utils" ) // Option is a functional option setter for Tracer type Option func(*Tracer) error // ErrorHandler is a functional argument that sets error handler of the server func ErrorHandler(h utils.ErrorHandler) Option { return func(t *Tracer) error { t.errHandler = h return nil } } // RequestHeaders adds request headers to capture func RequestHeaders(headers ...string) Option { return func(t *Tracer) error { t.reqHeaders = append(t.reqHeaders, headers...) return nil } } // ResponseHeaders adds response headers to capture func ResponseHeaders(headers ...string) Option { return func(t *Tracer) error { t.respHeaders = append(t.respHeaders, headers...) return nil } } // Tracer records request and response emitting JSON structured data to the output type Tracer struct { errHandler utils.ErrorHandler next http.Handler reqHeaders []string respHeaders []string writer io.Writer log *log.Logger } // New creates a new Tracer middleware that emits all the request/response information in structured format // to writer and passes the request to the next handler. It can optionally capture request and response headers, // see RequestHeaders and ResponseHeaders options for details. func New(next http.Handler, writer io.Writer, opts ...Option) (*Tracer, error) { t := &Tracer{ writer: writer, next: next, log: log.StandardLogger(), } for _, o := range opts { if err := o(t); err != nil { return nil, err } } if t.errHandler == nil { t.errHandler = utils.DefaultHandler } return t, nil } // Logger defines the logger the tracer will use. // // It defaults to logrus.StandardLogger(), the global logger used by logrus. 
func Logger(l *log.Logger) Option { return func(t *Tracer) error { t.log = l return nil } } func (t *Tracer) ServeHTTP(w http.ResponseWriter, req *http.Request) { start := time.Now() pw := utils.NewProxyWriterWithLogger(w, t.log) t.next.ServeHTTP(pw, req) l := t.newRecord(req, pw, time.Since(start)) if err := json.NewEncoder(t.writer).Encode(l); err != nil { t.log.Errorf("Failed to marshal request: %v", err) } } func (t *Tracer) newRecord(req *http.Request, pw *utils.ProxyWriter, diff time.Duration) *Record { return &Record{ Request: Request{ Method: req.Method, URL: req.URL.String(), TLS: newTLS(req), BodyBytes: bodyBytes(req.Header), Headers: captureHeaders(req.Header, t.reqHeaders), }, Response: Response{ Code: pw.StatusCode(), BodyBytes: bodyBytes(pw.Header()), Roundtrip: float64(diff) / float64(time.Millisecond), Headers: captureHeaders(pw.Header(), t.respHeaders), }, } } func newTLS(req *http.Request) *TLS { if req.TLS == nil { return nil } return &TLS{ Version: versionToString(req.TLS.Version), Resume: req.TLS.DidResume, CipherSuite: csToString(req.TLS.CipherSuite), Server: req.TLS.ServerName, } } func captureHeaders(in http.Header, headers []string) http.Header { if len(headers) == 0 || in == nil { return nil } out := make(http.Header, len(headers)) for _, h := range headers { vals, ok := in[h] if !ok || len(out[h]) != 0 { continue } for i := range vals { out.Add(h, vals[i]) } } return out } // Record represents a structured request and response record type Record struct { Request Request `json:"request"` Response Response `json:"response"` } // Request contains information about an HTTP request type Request struct { Method string `json:"method"` // Method - request method BodyBytes int64 `json:"body_bytes"` // BodyBytes - size of request body in bytes URL string `json:"url"` // URL - Request URL Headers http.Header `json:"headers,omitempty"` // Headers - optional request headers, will be recorded if configured TLS *TLS `json:"tls,omitempty"` // TLS - optional TLS record, will be recorded if it's a TLS connection } // Response contains information about HTTP response type Response struct { Code int `json:"code"` // Code - response status code Roundtrip float64 `json:"roundtrip"` // Roundtrip - round trip time in milliseconds Headers http.Header `json:"headers,omitempty"` // Headers - optional headers, will be recorded if configured BodyBytes int64 `json:"body_bytes"` // BodyBytes - size of response body in bytes } // TLS contains information about this TLS connection type TLS struct { Version string `json:"version"` // Version - TLS version Resume bool `json:"resume"` // Resume tells if the session has been re-used (session tickets) CipherSuite string `json:"cipher_suite"` // CipherSuite contains cipher suite used for this connection Server string `json:"server"` // Server contains server name used in SNI } func versionToString(v uint16) string { switch v { case tls.VersionSSL30: return "SSL30" case tls.VersionTLS10: return "TLS10" case tls.VersionTLS11: return "TLS11" case tls.VersionTLS12: return "TLS12" } return fmt.Sprintf("unknown: %x", v) } func csToString(cs uint16) string { switch cs { case tls.TLS_RSA_WITH_RC4_128_SHA: return "TLS_RSA_WITH_RC4_128_SHA" case tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: return "TLS_RSA_WITH_3DES_EDE_CBC_SHA" case tls.TLS_RSA_WITH_AES_128_CBC_SHA: return "TLS_RSA_WITH_AES_128_CBC_SHA" case tls.TLS_RSA_WITH_AES_256_CBC_SHA: return "TLS_RSA_WITH_AES_256_CBC_SHA" case tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: return "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA" case 
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: return "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA" case tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA" case tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA: return "TLS_ECDHE_RSA_WITH_RC4_128_SHA" case tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: return "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA" case tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: return "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA" case tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA" case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" case tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" } return fmt.Sprintf("unknown: %x", cs) } func bodyBytes(h http.Header) int64 { length := h.Get("Content-Length") if length == "" { return 0 } bytes, err := strconv.ParseInt(length, 10, 0) if err == nil { return bytes } return 0 } oxy-1.3.0/trace/trace_test.go000066400000000000000000000055641404246664300161370ustar00rootroot00000000000000package trace import ( "bufio" "bytes" "crypto/tls" "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/testutils" "github.com/vulcand/oxy/utils" ) func TestTraceSimple(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Header().Set("Content-Length", "5") w.Write([]byte("hello")) }) trace := &bytes.Buffer{} tr, err := New(handler, trace) require.NoError(t, err) srv := httptest.NewServer(tr) defer srv.Close() re, _, err := testutils.MakeRequest(srv.URL+"/hello", testutils.Method(http.MethodPost), testutils.Body("123456")) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) var r *Record require.NoError(t, json.Unmarshal(trace.Bytes(), &r)) assert.Equal(t, http.MethodPost, r.Request.Method) assert.Equal(t, "/hello", r.Request.URL) assert.Equal(t, http.StatusOK, r.Response.Code) assert.EqualValues(t, 6, r.Request.BodyBytes) assert.NotEqual(t, float64(0), r.Response.Roundtrip) assert.EqualValues(t, 5, r.Response.BodyBytes) } func TestTraceCaptureHeaders(t *testing.T) { respHeaders := http.Header{ "X-Re-1": []string{"6", "7"}, "X-Re-2": []string{"2", "3"}, } handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { utils.CopyHeaders(w.Header(), respHeaders) w.Write([]byte("hello")) }) trace := &bytes.Buffer{} tr, err := New(handler, trace, RequestHeaders("X-Req-B", "X-Req-A"), ResponseHeaders("X-Re-1", "X-Re-2")) require.NoError(t, err) srv := httptest.NewServer(tr) defer srv.Close() reqHeaders := http.Header{"X-Req-A": []string{"1", "2"}, "X-Req-B": []string{"3", "4"}} re, _, err := testutils.Get(srv.URL+"/hello", testutils.Headers(reqHeaders)) require.NoError(t, err) assert.Equal(t, http.StatusOK, re.StatusCode) var r *Record require.NoError(t, json.Unmarshal(trace.Bytes(), &r)) assert.Equal(t, reqHeaders, r.Request.Headers) assert.Equal(t, respHeaders, r.Response.Headers) } func TestTraceTLS(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello")) }) trace := &bytes.Buffer{} tr, err := New(handler, trace) require.NoError(t, err) srv := httptest.NewUnstartedServer(tr) srv.StartTLS() defer srv.Close() config := &tls.Config{ InsecureSkipVerify: true, } u, err := url.Parse(srv.URL) require.NoError(t, err) conn, err := tls.Dial("tcp", u.Host, config) require.NoError(t, err) fmt.Fprint(conn, "GET 
/ HTTP/1.0\r\n\r\n") status, err := bufio.NewReader(conn).ReadString('\n') require.NoError(t, err) assert.Equal(t, "HTTP/1.0 200 OK\r\n", status) state := conn.ConnectionState() conn.Close() var r *Record require.NoError(t, json.Unmarshal(trace.Bytes(), &r)) assert.Equal(t, versionToString(state.Version), r.Request.TLS.Version) } oxy-1.3.0/utils/000077500000000000000000000000001404246664300135035ustar00rootroot00000000000000oxy-1.3.0/utils/auth.go000066400000000000000000000022501404246664300147720ustar00rootroot00000000000000package utils import ( "encoding/base64" "fmt" "strings" ) // BasicAuth basic auth information type BasicAuth struct { Username string Password string } func (ba *BasicAuth) String() string { encoded := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", ba.Username, ba.Password))) return fmt.Sprintf("Basic %s", encoded) } // ParseAuthHeader creates a new BasicAuth from header values func ParseAuthHeader(header string) (*BasicAuth, error) { values := strings.Fields(header) if len(values) != 2 { return nil, fmt.Errorf(fmt.Sprintf("Failed to parse header '%s'", header)) } authType := strings.ToLower(values[0]) if authType != "basic" { return nil, fmt.Errorf("Expected basic auth type, got '%s'", authType) } encodedString := values[1] decodedString, err := base64.StdEncoding.DecodeString(encodedString) if err != nil { return nil, fmt.Errorf("Failed to parse header '%s', base64 failed: %s", header, err) } values = strings.SplitN(string(decodedString), ":", 2) if len(values) != 2 { return nil, fmt.Errorf("Failed to parse header '%s', expected separator ':'", header) } return &BasicAuth{Username: values[0], Password: values[1]}, nil } oxy-1.3.0/utils/auth_test.go000066400000000000000000000026111404246664300160320ustar00rootroot00000000000000package utils import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // Just to make sure we don't panic, return err and not // username and pass and cover the function func TestParseBadHeaders(t *testing.T) { headers := []string{ // just empty string "", // missing auth type "justplainstring", // unknown auth type "Whut justplainstring", // invalid base64 "Basic Shmasic", // random encoded string "Basic YW55IGNhcm5hbCBwbGVhcw==", } for _, h := range headers { _, err := ParseAuthHeader(h) require.Error(t, err) } } // Just to make sure we don't panic, return err and not // username and pass and cover the function func TestParseSuccess(t *testing.T) { headers := []struct { Header string Expected BasicAuth }{ { "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", BasicAuth{Username: "Aladdin", Password: "open sesame"}, }, // Make sure that String() produces valid header { (&BasicAuth{Username: "Alice", Password: "Here's bob"}).String(), BasicAuth{Username: "Alice", Password: "Here's bob"}, }, // empty pass { "Basic QWxhZGRpbjo=", BasicAuth{Username: "Aladdin", Password: ""}, }, } for _, h := range headers { request, err := ParseAuthHeader(h.Header) require.NoError(t, err) assert.Equal(t, h.Expected.Username, request.Username) assert.Equal(t, h.Expected.Password, request.Password) } } oxy-1.3.0/utils/dumpreq.go000066400000000000000000000027001404246664300155060ustar00rootroot00000000000000package utils import ( "crypto/tls" "encoding/json" "fmt" "mime/multipart" "net/http" "net/url" ) // SerializableHttpRequest serializable HTTP request type SerializableHttpRequest struct { Method string URL *url.URL Proto string // "HTTP/1.0" ProtoMajor int // 1 ProtoMinor int // 0 Header http.Header ContentLength int64 
TransferEncoding []string Host string Form url.Values PostForm url.Values MultipartForm *multipart.Form Trailer http.Header RemoteAddr string RequestURI string TLS *tls.ConnectionState } // Clone clone a request func Clone(r *http.Request) *SerializableHttpRequest { if r == nil { return nil } rc := new(SerializableHttpRequest) rc.Method = r.Method rc.URL = r.URL rc.Proto = r.Proto rc.ProtoMajor = r.ProtoMajor rc.ProtoMinor = r.ProtoMinor rc.Header = r.Header rc.ContentLength = r.ContentLength rc.Host = r.Host rc.RemoteAddr = r.RemoteAddr rc.RequestURI = r.RequestURI return rc } // ToJson serializes to JSON func (s *SerializableHttpRequest) ToJson() string { jsonVal, err := json.Marshal(s) if err != nil || jsonVal == nil { return fmt.Sprintf("Error marshalling SerializableHttpRequest to json: %s", err) } return string(jsonVal) } // DumpHttpRequest dump a HTTP request to JSON func DumpHttpRequest(req *http.Request) string { return Clone(req).ToJson() } oxy-1.3.0/utils/dumpreq_test.go000066400000000000000000000012121404246664300165420ustar00rootroot00000000000000package utils import ( "net/http" "net/url" "testing" "github.com/stretchr/testify/assert" ) type readCloserTestImpl struct{} func (r *readCloserTestImpl) Read(p []byte) (n int, err error) { return 0, nil } func (r *readCloserTestImpl) Close() error { return nil } // Just to make sure we don't panic, return err and not // username and pass and cover the function func TestHttpReqToString(t *testing.T) { req := &http.Request{ URL: &url.URL{Host: "localhost:2374", Path: "/unittest"}, Method: http.MethodDelete, Cancel: make(chan struct{}), Body: &readCloserTestImpl{}, } assert.True(t, len(DumpHttpRequest(req)) > 0) } oxy-1.3.0/utils/handler.go000066400000000000000000000031531404246664300154510ustar00rootroot00000000000000package utils import ( "context" "io" "net" "net/http" log "github.com/sirupsen/logrus" ) // StatusClientClosedRequest non-standard HTTP status code for client disconnection const StatusClientClosedRequest = 499 // StatusClientClosedRequestText non-standard HTTP status for client disconnection const StatusClientClosedRequestText = "Client Closed Request" // ErrorHandler error handler type ErrorHandler interface { ServeHTTP(w http.ResponseWriter, req *http.Request, err error) } // DefaultHandler default error handler var DefaultHandler ErrorHandler = &StdHandler{} // StdHandler Standard error handler type StdHandler struct{} func (e *StdHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { statusCode := http.StatusInternalServerError if e, ok := err.(net.Error); ok { if e.Timeout() { statusCode = http.StatusGatewayTimeout } else { statusCode = http.StatusBadGateway } } else if err == io.EOF { statusCode = http.StatusBadGateway } else if err == context.Canceled { statusCode = StatusClientClosedRequest } w.WriteHeader(statusCode) w.Write([]byte(statusText(statusCode))) log.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err) } func statusText(statusCode int) string { if statusCode == StatusClientClosedRequest { return StatusClientClosedRequestText } return http.StatusText(statusCode) } // ErrorHandlerFunc error handler function type type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, error) // ServeHTTP calls f(w, r). 
func (f ErrorHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request, err error) { f(w, r, err) } oxy-1.3.0/utils/handler_test.go000066400000000000000000000012771404246664300165150ustar00rootroot00000000000000package utils import ( "bytes" "net/http" "net/http/httptest" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestDefaultHandlerErrors(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { h := w.(http.Hijacker) conn, _, _ := h.Hijack() conn.Close() })) defer srv.Close() request, err := http.NewRequest(http.MethodGet, srv.URL, strings.NewReader("")) require.NoError(t, err) _, err = http.DefaultTransport.RoundTrip(request) w := NewBufferWriter(NopWriteCloser(&bytes.Buffer{})) DefaultHandler.ServeHTTP(w, nil, err) assert.Equal(t, http.StatusBadGateway, w.Code) } oxy-1.3.0/utils/netutils.go000066400000000000000000000115751404246664300157120ustar00rootroot00000000000000package utils import ( "bufio" "fmt" "io" "net" "net/http" "net/url" "reflect" log "github.com/sirupsen/logrus" ) // ProxyWriter calls recorder, used to debug logs type ProxyWriter struct { w http.ResponseWriter code int length int64 log *log.Logger } // NewProxyWriter creates a new ProxyWriter func NewProxyWriter(w http.ResponseWriter) *ProxyWriter { return NewProxyWriterWithLogger(w, log.StandardLogger()) } // NewProxyWriterWithLogger creates a new ProxyWriter func NewProxyWriterWithLogger(w http.ResponseWriter, l *log.Logger) *ProxyWriter { return &ProxyWriter{ w: w, log: l, } } // StatusCode gets status code func (p *ProxyWriter) StatusCode() int { if p.code == 0 { // per contract standard lib will set this to http.StatusOK if not set // by user, here we avoid the confusion by mirroring this logic return http.StatusOK } return p.code } // GetLength gets content length func (p *ProxyWriter) GetLength() int64 { return p.length } // Header gets response header func (p *ProxyWriter) Header() http.Header { return p.w.Header() } func (p *ProxyWriter) Write(buf []byte) (int, error) { p.length = p.length + int64(len(buf)) return p.w.Write(buf) } // WriteHeader writes status code func (p *ProxyWriter) WriteHeader(code int) { p.code = code p.w.WriteHeader(code) } // Flush flush the writer func (p *ProxyWriter) Flush() { if f, ok := p.w.(http.Flusher); ok { f.Flush() } } // CloseNotify returns a channel that receives at most a single value (true) // when the client connection has gone away. func (p *ProxyWriter) CloseNotify() <-chan bool { if cn, ok := p.w.(http.CloseNotifier); ok { return cn.CloseNotify() } p.log.Debugf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(p.w)) return make(<-chan bool) } // Hijack lets the caller take over the connection. func (p *ProxyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hi, ok := p.w.(http.Hijacker); ok { return hi.Hijack() } p.log.Debugf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(p.w)) return nil, nil, fmt.Errorf("the response writer that was wrapped in this proxy, does not implement http.Hijacker. 
It is of type: %v", reflect.TypeOf(p.w)) } // NewBufferWriter creates a new BufferWriter func NewBufferWriter(w io.WriteCloser) *BufferWriter { return &BufferWriter{ W: w, H: make(http.Header), } } // BufferWriter buffer writer type BufferWriter struct { H http.Header Code int W io.WriteCloser } // Close close the writer func (b *BufferWriter) Close() error { return b.W.Close() } // Header gets response header func (b *BufferWriter) Header() http.Header { return b.H } func (b *BufferWriter) Write(buf []byte) (int, error) { return b.W.Write(buf) } // WriteHeader writes status code func (b *BufferWriter) WriteHeader(code int) { b.Code = code } // CloseNotify returns a channel that receives at most a single value (true) // when the client connection has gone away. func (b *BufferWriter) CloseNotify() <-chan bool { if cn, ok := b.W.(http.CloseNotifier); ok { return cn.CloseNotify() } log.Warningf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(b.W)) return make(<-chan bool) } // Hijack lets the caller take over the connection. func (b *BufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hi, ok := b.W.(http.Hijacker); ok { return hi.Hijack() } log.Debugf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(b.W)) return nil, nil, fmt.Errorf("the response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(b.W)) } type nopWriteCloser struct { io.Writer } func (*nopWriteCloser) Close() error { return nil } // NopWriteCloser returns a WriteCloser with a no-op Close method wrapping // the provided Writer w. func NopWriteCloser(w io.Writer) io.WriteCloser { return &nopWriteCloser{Writer: w} } // CopyURL provides update safe copy by avoiding shallow copying User field func CopyURL(i *url.URL) *url.URL { out := *i if i.User != nil { out.User = &(*i.User) } return &out } // CopyHeaders copies http headers from source to destination, it // does not overide, but adds multiple headers func CopyHeaders(dst http.Header, src http.Header) { for k, vv := range src { dst[k] = append(dst[k], vv...) 
} } // HasHeaders determines whether any of the header names is present in the http headers func HasHeaders(names []string, headers http.Header) bool { for _, h := range names { if headers.Get(h) != "" { return true } } return false } // RemoveHeaders removes the header with the given names from the headers map func RemoveHeaders(headers http.Header, names ...string) { for _, h := range names { headers.Del(h) } } oxy-1.3.0/utils/netutils_test.go000066400000000000000000000043551404246664300167470ustar00rootroot00000000000000package utils import ( "net/http" "net/url" "testing" "github.com/stretchr/testify/assert" ) // Make sure copy does it right, so the copied url // is safe to alter without modifying the other func TestCopyUrl(t *testing.T) { urlA := &url.URL{ Scheme: "http", Host: "localhost:5000", Path: "/upstream", Opaque: "opaque", RawQuery: "a=1&b=2", Fragment: "#hello", User: &url.Userinfo{}, } urlB := CopyURL(urlA) assert.Equal(t, urlA, urlB) urlB.Scheme = "https" assert.NotEqual(t, urlA, urlB) } // Make sure copy headers is not shallow and copies all headers func TestCopyHeaders(t *testing.T) { source, destination := make(http.Header), make(http.Header) source.Add("a", "b") source.Add("c", "d") CopyHeaders(destination, source) assert.Equal(t, "b", destination.Get("a")) assert.Equal(t, "d", destination.Get("c")) // make sure that altering source does not affect the destination source.Del("a") assert.Equal(t, "", source.Get("a")) assert.Equal(t, "b", destination.Get("a")) } func TestHasHeaders(t *testing.T) { source := make(http.Header) source.Add("a", "b") source.Add("c", "d") assert.True(t, HasHeaders([]string{"a", "f"}, source)) assert.False(t, HasHeaders([]string{"i", "j"}, source)) } func TestRemoveHeaders(t *testing.T) { source := make(http.Header) source.Add("a", "b") source.Add("a", "m") source.Add("c", "d") RemoveHeaders(source, "a") assert.Equal(t, "", source.Get("a")) assert.Equal(t, "d", source.Get("c")) } func BenchmarkCopyHeaders(b *testing.B) { dstHeaders := make([]http.Header, 0, b.N) sourceHeaders := make([]http.Header, 0, b.N) for n := 0; n < b.N; n++ { // example from a reverse proxy merging headers d := http.Header{} d.Add("Request-Id", "1bd36bcc-a0d1-4fc7-aedc-20bbdefa27c5") dstHeaders = append(dstHeaders, d) s := http.Header{} s.Add("Content-Length", "374") s.Add("Context-Type", "text/html; charset=utf-8") s.Add("Etag", `"op14g6ae"`) s.Add("Last-Modified", "Wed, 26 Apr 2017 18:24:06 GMT") s.Add("Server", "Caddy") s.Add("Date", "Fri, 28 Apr 2017 15:54:01 GMT") s.Add("Accept-Ranges", "bytes") sourceHeaders = append(sourceHeaders, s) } b.ResetTimer() for n := 0; n < b.N; n++ { CopyHeaders(dstHeaders[n], sourceHeaders[n]) } } oxy-1.3.0/utils/source.go000066400000000000000000000035311404246664300153340ustar00rootroot00000000000000package utils import ( "fmt" "net/http" "strings" ) // SourceExtractor extracts the source from the request, e.g. that may be client ip, or particular header that // identifies the source. 
amount stands for amount of connections the source consumes, usually 1 for connection limiters // error should be returned when source can not be identified type SourceExtractor interface { Extract(req *http.Request) (token string, amount int64, err error) } // ExtractorFunc extractor function type type ExtractorFunc func(req *http.Request) (token string, amount int64, err error) // Extract extract from request func (f ExtractorFunc) Extract(req *http.Request) (string, int64, error) { return f(req) } // ExtractSource extract source function type type ExtractSource func(req *http.Request) // NewExtractor creates a new SourceExtractor func NewExtractor(variable string) (SourceExtractor, error) { if variable == "client.ip" { return ExtractorFunc(extractClientIP), nil } if variable == "request.host" { return ExtractorFunc(extractHost), nil } if strings.HasPrefix(variable, "request.header.") { header := strings.TrimPrefix(variable, "request.header.") if len(header) == 0 { return nil, fmt.Errorf("wrong header: %s", header) } return makeHeaderExtractor(header), nil } return nil, fmt.Errorf("unsupported limiting variable: '%s'", variable) } func extractClientIP(req *http.Request) (string, int64, error) { vals := strings.SplitN(req.RemoteAddr, ":", 2) if len(vals[0]) == 0 { return "", 0, fmt.Errorf("failed to parse client IP: %v", req.RemoteAddr) } return vals[0], 1, nil } func extractHost(req *http.Request) (string, int64, error) { return req.Host, 1, nil } func makeHeaderExtractor(header string) SourceExtractor { return ExtractorFunc(func(req *http.Request) (string, int64, error) { return req.Header.Get(header), 1, nil }) }
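A minimal usage sketch (an illustration, not a file from the oxy-1.3.0 archive) of the utils.NewExtractor API defined above. The "request.header.X-Forwarded-For" variable, the URL, and the address are assumptions chosen for the example.

package main

import (
	"fmt"
	"net/http"

	"github.com/vulcand/oxy/utils"
)

func main() {
	// Build a SourceExtractor keyed on a request header; "client.ip" and
	// "request.host" are the other supported variables.
	extractor, err := utils.NewExtractor("request.header.X-Forwarded-For")
	if err != nil {
		panic(err)
	}

	req, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("X-Forwarded-For", "203.0.113.7")

	// token identifies the source; amount is the number of connections the
	// source consumes (always 1 for the extractors shipped in this package).
	token, amount, err := extractor.Extract(req)
	if err != nil {
		panic(err)
	}
	fmt.Println(token, amount) // 203.0.113.7 1
}

Returning an amount alongside the token lets connection and rate limiters weight sources differently, even though the built-in extractors always count 1 per request.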